Coverage for src/gncpy/planning/a_star.py: 0%
191 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Implements the A* algorithm and several variations."""
2import io
3import numpy as np
4import matplotlib.pyplot as plt
5from matplotlib.patches import Rectangle
6from PIL import Image
7from sys import exit
9import gncpy.plotting as gplot
class Node:
    """Helper class for grid nodes in A* planning.

    Attributes
    ----------
    indices : 2 numpy array
        row/column index into grid.
    cost : float
        cost of the path up to this node.
    parent_idx : int
        index of the parent node in the closed set
    """

    def __init__(self, indices, cost, parent_idx):
        """Initialize an object.

        Parameters
        ----------
        indices : 2 numpy array
            row/column index into grid.
        cost : float
            cost of the path up to this node.
        parent_idx : int
            index of the parent node in the closed set
        """
        self.indices = indices
        self.cost = cost
        self.parent_idx = parent_idx

    def __repr__(self):
        """Return a short description of the node for debugging."""
        return (
            f"{type(self).__name__}(indices={self.indices!r}, "
            f"cost={self.cost!r}, parent_idx={self.parent_idx!r})"
        )
class AStar:
    """Implements various forms of the A* grid search algorithm.

    This is based on the example implementation found
    `here <https://github.com/AtsushiSakai/PythonRobotics>`_. It currently
    implements

        - Normal A*
        - Beam search
        - Weighted

    Attributes
    ----------
    resolution : 2 numpy array
        Real distance per grid square for x and y dimensions. Should not be set
        directly, use the :meth:`.set_map` function.
    min : 2 numpy array
        Minimum x and y positions in real units. Should not be set directly, use
        the :meth:`.set_map` function.
    max : 2 numpy array
        Maximum x/y positions in real units. Should not be set directly, use
        the :meth:`.set_map` function.
    weight : float
        Constant weighting factor applied to the heuristic. The default is 1 and
        is only used if the weighting function is not overwritten.
    use_beam_search : bool
        Flag indicating if the beam search variation is used.
    beam_search_max_nodes : int
        Maximum number of grid nodes to keep when using the beam search variation.
    motion : 8 x 3 numpy array
        Each row represents a potential action, with the first column being row
        motion in grid squares, second column being column motion, and third
        being the cost of moving.
    """

    def __init__(self, use_beam_search=False, beam_search_max_nodes=30):
        """Initialize an object.

        Parameters
        ----------
        use_beam_search : bool, optional
            Flag indicating if the beam search variation is used. The default
            is False.
        beam_search_max_nodes : int, optional
            Maximum number of grid nodes to keep when using the beam search
            variation. The default is 30.
        """
        # placeholders until set_map is called
        self.resolution = np.full(2, np.nan)
        self.min = np.full(2, np.nan)
        self.max = np.full(2, np.nan)

        self.weight = 1
        self.use_beam_search = use_beam_search
        self.beam_search_max_nodes = beam_search_max_nodes

        # rows are (row step, column step, move cost): 4 straight, 4 diagonal
        diag = np.sqrt(2)
        self.motion = np.array(
            [
                [1, 0, 1],
                [0, 1, 1],
                [-1, 0, 1],
                [0, -1, 1],
                [-1, -1, diag],
                [-1, 1, diag],
                [1, -1, diag],
                [1, 1, diag],
            ]
        )

        self._map = np.array([[]])
        self._obstacles = None
        self._hazards = None

    def ind_to_pos(self, indices):
        """Convert a set of row/column indices to real positions.

        Note that rows are y positions and columns are x positions, this function
        handles the conversion.

        Parameters
        ----------
        indices : 2 numpy array
            row/column index into grid.

        Returns
        -------
        2 numpy array
            real x/y position of the center of the grid square.
        """
        # NOTE: row is y, column is x
        col_row = np.asarray(indices)[[1, 0]]
        return col_row * self.resolution + self.min + self.resolution / 2

    def pos_to_ind(self, pos):
        """Convert a set of x/y positions to grid indices.

        Note that rows are y positions and columns are x positions, this function
        handles the conversion. The position is clamped to the area between
        :attr:`min` and :attr:`max` before it is converted.

        Parameters
        ----------
        pos : 2 numpy array
            Real x/y position.

        Returns
        -------
        2 numpy array
            grid row/column index.
        """
        # NOTE: x is column, y is row, bound to given area
        bounded = np.clip(np.asarray(pos, dtype=float).ravel(), self.min, self.max)
        # NOTE(review): the extra half cell added before flooring means a
        # position exactly on a cell center (as given by ind_to_pos) maps to
        # the next index up. set_map relies on this when it fills obstacle
        # cells, so it is preserved here; confirm this is the intended
        # convention.
        xy_inds = np.floor(
            (bounded - self.min + self.resolution / 2) / self.resolution
        )
        return xy_inds[[1, 0]].astype(int)

    def ravel_ind(self, multi_index):
        """Convert row/column indices into a single index into flattened array.

        Parameters
        ----------
        multi_index : 2 numpy array
            Row/col indices into the grid.

        Returns
        -------
        int
            Flattened index
        """
        return np.ravel_multi_index(multi_index, self._map.shape).item()

    def final_path(self, endNode, closed_set):
        """Calculate the final path and total cost.

        Parameters
        ----------
        endNode : :class:`.Node`
            Ending position.
        closed_set : dict
            Dictionary of path nodes. The keys are flattened indices and the
            values are :class:`.Node`.

        Returns
        -------
        N x 2 numpy array
            Real x/y positions of the path nodes, first row is the start.
        float
            total cost of the path.
        """
        waypoints = [self.ind_to_pos(endNode.indices)]
        parent_idx = endNode.parent_idx
        # -1 marks "no parent"; 0 is the flattened index of grid cell [0, 0]
        # and must still be followed.
        while parent_idx >= 0:
            curNode = closed_set[parent_idx]
            waypoints.append(self.ind_to_pos(curNode.indices))
            parent_idx = curNode.parent_idx

        # flip so first row is starting point
        return np.array(waypoints[::-1], dtype=float), endNode.cost

    def calc_weight(self, node):
        """Calculates the weight applied to the heuristic.

        This can be overridden to calculate custom weights.

        Parameters
        ----------
        node : :class:`.Node`
            Current node.

        Returns
        -------
        float
            weight value.
        """
        return self.weight

    def calc_heuristic(self, endNode, curNode):
        """Calculates the heuristic cost.

        This is the straight line distance, in real units, between the two
        nodes.

        Parameters
        ----------
        endNode : :class:`.Node`
            Goal node.
        curNode : :class:`.Node`
            Current node.

        Returns
        -------
        float
            heuristic cost.
        """
        diff = self.ind_to_pos(endNode.indices) - self.ind_to_pos(curNode.indices)
        return np.sqrt(np.sum(diff * diff))

    def is_valid(self, node):
        """Checks if the node is valid.

        Bounds checks the indices and determines if node corresponds to a wall.

        Parameters
        ----------
        node : :class:`.Node`
            Node to check.

        Returns
        -------
        bool
            True if the node is valid.
        """
        pos = self.ind_to_pos(node.indices)

        # bounds check position
        if np.any(pos < self.min) or np.any(pos > self.max):
            return False

        # obstacles are inf
        return not np.isinf(self._map[node.indices[0], node.indices[1]])

    def _add_cell_patch(self, fig, indices, **patch_kwargs):
        """Draw one grid square as a rectangle on the first axes of the figure."""
        corner = self.ind_to_pos(indices) - self.resolution / 2
        fig.axes[0].add_patch(
            Rectangle(corner, self.resolution[0], self.resolution[1], **patch_kwargs)
        )

    def draw_start(self, fig, startNode):
        """Draw starting node on the figure.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        startNode : :class:`.Node`
            Starting node.
        """
        self._add_cell_patch(fig, startNode.indices, color="g", zorder=1000)

    def draw_end(self, fig, endNode):
        """Draw the ending node on the figure.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        endNode : :class:`.Node`
            Ending node.
        """
        self._add_cell_patch(fig, endNode.indices, color="r", zorder=1000)

    def draw_map(self, fig):
        """Draws the map to the figure.

        Obstacle squares (infinite cost) are drawn in black and hazard squares
        (finite positive cost) are drawn in orange underneath everything else.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        """
        for cell in np.argwhere(np.isinf(self._map)):
            self._add_cell_patch(fig, cell, facecolor="k")

        # inf > 0 is True, so obstacles are excluded explicitly to avoid
        # adding a second, hidden patch for every wall square.
        hazard_mask = np.isfinite(self._map) & (self._map > 0)
        for cell in np.argwhere(hazard_mask):
            self._add_cell_patch(
                fig,
                cell,
                facecolor=(255 / 255, 140 / 255, 0 / 255),
                zorder=-1000,
            )

    def _cell_slices(self, rect, max_inds):
        """Convert a centered x/y/width/height rectangle to row/column slices."""
        center = np.asarray(rect[:2], dtype=float)
        # regions are at least one grid square wide/tall
        half_w = np.array([max(rect[2], self.resolution[0]) / 2, 0])
        half_h = np.array([0, max(rect[3], self.resolution[1]) / 2])

        left = max(self.pos_to_ind(center - half_w)[1], 0)
        right = min(self.pos_to_ind(center + half_w)[1], max_inds[1])
        bot = max(self.pos_to_ind(center - half_h)[0], 0)
        top = min(self.pos_to_ind(center + half_h)[0], max_inds[0])

        # upper index is exclusive
        return slice(bot, top), slice(left, right)

    def set_map(self, min_pos, max_pos, grid_res, obstacles=None, hazards=None):
        """Sets up the map with obstacles and hazards.

        Parameters
        ----------
        min_pos : 2 numpy array
            Min x/y position in real units.
        max_pos : 2 numpy array
            Max x/y position in real units.
        grid_res : 2 numpy array
            Real distance per grid square for x/y positions.
        obstacles : N x 4 numpy array, optional
            Locations of walls, each row is one wall. First column is x position
            second column is y position, third is width, and fourth is height.
            The position is the location of the center. All distances are in
            real units. The default is None.
        hazards : N x 5 numpy array, optional
            Locations of hazards, each row is one hazard. First column is x
            position second column is y position, third is width, fourth is
            height, and the last is the cost of being on that node. The
            position is the location of the center. All distances are in real
            units. The default is None.
        """
        self.resolution = np.asarray(grid_res, dtype=float).ravel()
        # pad by half a square so the requested limits are square centers
        self.min = np.asarray(min_pos, dtype=float).ravel() - self.resolution / 2
        self.max = np.asarray(max_pos, dtype=float).ravel() + self.resolution / 2

        self._obstacles = obstacles
        self._hazards = hazards

        max_inds = self.pos_to_ind(self.max)
        self._map = np.zeros(tuple(int(ii) + 1 for ii in max_inds))

        if self._obstacles is not None:
            for obs in self._obstacles:
                rows, cols = self._cell_slices(obs, max_inds)
                self._map[rows, cols] = np.inf

        if self._hazards is not None:
            for haz in self._hazards:
                rows, cols = self._cell_slices(haz, max_inds)
                # overlapping hazards accumulate
                self._map[rows, cols] += haz[4]

    def get_map_cost(self, indices):
        """Returns the cost of being at the given map indices.

        Parameters
        ----------
        indices : 2 numpy array
            Row/col indices.

        Returns
        -------
        float
            Cost of the grid node.
        """
        return self._map[indices[0], indices[1]]

    @staticmethod
    def _capture_frame(fig):
        """Render the figure to a PIL image for use as an animation frame."""
        fig_w, fig_h = fig.canvas.get_width_height()
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            buff.seek(0)
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (fig_h, fig_w, -1)
            )
        return Image.fromarray(img)

    def plan(
        self,
        start_pos,
        end_pos,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
    ):
        """Runs the search algorithm.

        The setup functions should be called prior to this.

        Parameters
        ----------
        start_pos : 2 numpy array
            Starting x/y pos in real units.
        end_pos : 2 numpy array
            ending x/y pos in real units.
        show_animation : bool, optional
            Flag indicating if an animated plot should be shown during planning.
            The default is False. If shown, escape key can be used to quit.
        save_animation : bool, optional
            Flag for saving each frame of the animation as a PIL image. This
            can later be saved to a gif. The default is False. The animation must
            be shown for this to have an effect.
        plt_opts : dict, optional
            Additional options for the plot from
            :meth:`gncpy.plotting.init_plotting_opts`. The default is None.
        ttl : string, optional
            title string of the plot. The default is None which gives a generic
            name of the algorithm.
        fig : matplotlib figure, optional
            Figure object to plot on. The default is None which makes a new
            figure. If a figure is provided, then only the ending state,
            searched nodes, and best path are added to the figure.

        Returns
        -------
        path : Nx2 numpy array
            Real positions of the path. This is empty (0 rows) if the start or
            end is invalid or if the end can not be reached.
        cost : float
            total cost of the path, infinite if no path exists.
        fig : matplotlib figure
            Figure that was drawn, None if the animation is not shown.
        frame_list : list
            Each element is a PIL image corresponding to an animation frame.
        """
        start_inds = self.pos_to_ind(start_pos)
        end_inds = self.pos_to_ind(end_pos)
        startNode = Node(start_inds, self.get_map_cost(start_inds), -1)
        endNode = Node(end_inds, self.get_map_cost(end_inds), -1)
        n_dims = np.asarray(start_pos).size

        if not self.is_valid(startNode) or not self.is_valid(endNode):
            return np.empty((0, n_dims)), float("inf"), fig, []

        frame_list = []

        if show_animation:
            if fig is None:
                fig = plt.figure()
                fig.add_subplot(1, 1, 1)

                fig.axes[0].set_aspect("equal", adjustable="box")

                fig.axes[0].set_xlim((self.min[0], self.max[0]))
                fig.axes[0].set_ylim((self.min[1], self.max[1]))

                if plt_opts is None:
                    plt_opts = gplot.init_plotting_opts(f_hndl=fig)

                if ttl is None:
                    ttl = "A* Pathfinding"
                    if self.use_beam_search:
                        ttl += " with Beam Search"

                gplot.set_title_label(fig, 0, plt_opts, ttl=ttl)

                # draw map
                self.draw_start(fig, startNode)
                self.draw_map(fig)
                fig.tight_layout()

            self.draw_end(fig, endNode)
            plt.pause(0.01)

            # for stopping simulation with the esc key.
            def _quit_on_escape(event):
                if event.key == "escape":
                    exit(0)

            fig.canvas.mpl_connect("key_release_event", _quit_on_escape)

            # save first frame of animation
            if save_animation:
                frame_list.append(self._capture_frame(fig))

        else:
            fig = None

        open_set = {self.ravel_ind(startNode.indices): startNode}
        closed_set = {}
        found = False

        while open_set:
            keys = list(open_set.keys())
            # rank every open node by path cost plus weighted heuristic
            s_inds = np.argsort(
                [
                    n.cost + self.calc_weight(n) * self.calc_heuristic(endNode, n)
                    for n in open_set.values()
                ]
            )
            c_ind = keys[s_inds[0]]
            curNode = open_set[c_ind]

            if self.use_beam_search and s_inds.size > self.beam_search_max_nodes:
                # walk the ranking from the worst entry back to the cut-off,
                # dropping everything outside the beam
                for ii in s_inds[: self.beam_search_max_nodes - 1 : -1]:
                    del open_set[keys[ii]]

            del open_set[c_ind]

            if show_animation:
                pos = self.ind_to_pos(curNode.indices)
                fig.axes[0].scatter(pos[0], pos[1], marker="x", color="c")
                if len(closed_set) % 10 == 0:
                    plt.pause(0.001)
                    # save frame after pause to make sure frame is drawn
                    if save_animation:
                        frame_list.append(self._capture_frame(fig))

            if np.all(curNode.indices == endNode.indices):
                endNode.parent_idx = curNode.parent_idx
                endNode.cost = curNode.cost
                found = True
                break

            closed_set[c_ind] = curNode

            for m in self.motion:
                new_indices = curNode.indices + m[:2].astype(int)
                node = Node(new_indices, 0, c_ind)
                if not self.is_valid(node):
                    continue
                node.cost = curNode.cost + m[2] + self.get_map_cost(new_indices)

                n_ind = self.ravel_ind(node.indices)

                if n_ind in closed_set:
                    continue

                # keep only the cheapest known way of reaching a square
                if n_ind not in open_set or open_set[n_ind].cost > node.cost:
                    open_set[n_ind] = node

        if not found:
            # the open set ran dry without reaching the goal, there is no
            # path to report
            return np.empty((0, n_dims)), float("inf"), fig, frame_list

        path, cost = self.final_path(endNode, closed_set)

        if show_animation:
            fig.axes[0].plot(path[:, 0], path[:, 1], linestyle="-", color="g")
            plt.pause(0.001)
            if save_animation:
                frame_list.append(self._capture_frame(fig))

        return path, cost, fig, frame_list