Coverage for src/gncpy/planning/a_star.py: 0%

191 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Implements the A* algorithm and several variations.""" 

2import io 

3import numpy as np 

4import matplotlib.pyplot as plt 

5from matplotlib.patches import Rectangle 

6from PIL import Image 

7from sys import exit 

8 

9import gncpy.plotting as gplot 

10 

11 

class Node:
    """Grid cell visited by the A* search.

    Attributes
    ----------
    indices : 2 numpy array
        Row/column location of the cell within the grid.
    cost : float
        Accumulated cost of the path that reaches this cell.
    parent_idx : int
        Index (key) of the preceding node in the closed set.
    """

    def __init__(self, indices, cost, parent_idx):
        """Record the cell location, path cost, and link to the parent.

        Parameters
        ----------
        indices : 2 numpy array
            Row/column location of the cell within the grid.
        cost : float
            Accumulated cost of the path that reaches this cell.
        parent_idx : int
            Index (key) of the preceding node in the closed set.
        """
        self.indices, self.cost, self.parent_idx = indices, cost, parent_idx

40 

41 

class AStar:
    """Implements various forms of the A* grid search algorithm.

    This is based on the example implementation found
    `here <https://github.com/AtsushiSakai/PythonRobotics>`_. It currently
    implements

    - Normal A*
    - Beam search
    - Weighted

    Attributes
    ----------
    resolution : 2 numpy array
        Real distance per grid square for x and y dimensions. Should not be set
        directly, use the :meth:`.set_map` function.
    min : 2 numpy array
        Minimum x and y positions in real units. This is the outer edge of the
        first grid cell, half a cell below the ``min_pos`` given to
        :meth:`.set_map`. Should not be set directly, use the :meth:`.set_map`
        function.
    max : 2 numpy array
        Maximum x/y positions in real units. This is half a cell beyond the
        ``max_pos`` given to :meth:`.set_map`. Should not be set directly, use
        the :meth:`.set_map` function.
    weight : float
        Constant weighting factor applied to the heuristic. The default is 1 and
        is only used if the weighting function is not overwritten.
    use_beam_search : bool
        Flag indicating if the beam search variation is used.
    beam_search_max_nodes : int
        Maximum number of grid nodes to keep when using the beam search
        variation.
    motion : 8 x 3 numpy array
        Each row represents a potential action, with the first column being row
        motion in grid squares, second column being column motion, and third
        being the cost of moving.
    """

    def __init__(self, use_beam_search=False, beam_search_max_nodes=30):
        """Initialize an object.

        Parameters
        ----------
        use_beam_search : bool, optional
            Flag indicating if the beam search variation is used. The default
            is False.
        beam_search_max_nodes : int, optional
            Maximum number of grid nodes to keep when using the beam search
            variation. The default is 30.
        """
        self.resolution = np.nan * np.ones(2)
        self.min = np.nan * np.ones(2)
        self.max = np.nan * np.ones(2)

        self.weight = 1
        self.use_beam_search = use_beam_search
        self.beam_search_max_nodes = beam_search_max_nodes

        # 8-connected moves: axis-aligned steps cost 1, diagonal steps sqrt(2).
        self.motion = np.array(
            [
                [1, 0, 1],
                [0, 1, 1],
                [-1, 0, 1],
                [0, -1, 1],
                [-1, -1, np.sqrt(2)],
                [-1, 1, np.sqrt(2)],
                [1, -1, np.sqrt(2)],
                [1, 1, np.sqrt(2)],
            ]
        )

        # Per-cell cost grid built by set_map: 0 for free cells, inf for
        # obstacles, and hazard costs added on top.
        self._map = np.array([[]])
        self._obstacles = None
        self._hazards = None

    def ind_to_pos(self, indices):
        """Convert a set of row/column indices to real positions.

        Note that rows are y positions and columns are x positions, this function
        handles the conversion. The returned position is the center of the cell.

        Parameters
        ----------
        indices : 2 numpy array
            row/column index into grid.

        Returns
        -------
        2 numpy array
            real x/y position.
        """
        # NOTE: row is y, column is x
        return indices[[1, 0]] * self.resolution + self.min + self.resolution / 2

    def pos_to_ind(self, pos):
        """Convert a set of x/y positions to grid indices.

        Note that rows are y positions and columns are x positions, this function
        handles the conversion. The position is clamped to the map area but the
        resulting indices are not otherwise bounds checked.

        Parameters
        ----------
        pos : 2 numpy array
            Real x/y position.

        Returns
        -------
        2 numpy array
            grid row/column index.
        """
        # NOTE: x is column, y is row, bound to given area
        p = np.clip(pos.ravel(), self.min.ravel(), self.max.ravel())
        # self.min is the lower *edge* of cell 0, so cell i spans
        # [min + i * res, min + (i + 1) * res). Flooring the offset from that
        # edge gives the cell whose center ind_to_pos reports, making the two
        # conversions inverses of each other.
        inds = np.floor((p - self.min.ravel()) / self.resolution.ravel())[[1, 0]]

        return inds.astype(int)

    def ravel_ind(self, multi_index):
        """Convert row/column indices into a single index into flattened array.

        Parameters
        ----------
        multi_index : 2 numpy array
            Row/col indices into the grid.

        Returns
        -------
        int
            Flattened index
        """
        return np.ravel_multi_index(multi_index, self._map.shape).item()

    def final_path(self, endNode, closed_set):
        """Calculate the final path and total cost.

        Parameters
        ----------
        endNode : :class:`.Node`
            Ending position.
        closed_set : dict
            Dictionary of path nodes. The keys are flattened indices and the
            values are :class:`.Node`.

        Returns
        -------
        N x 2 numpy array
            Real x/y positions of the path nodes, starting point first.
        float
            total cost of the path.
        """
        path = np.nan * np.ones((len(closed_set) + 1, endNode.indices.size))
        path[0, :] = self.ind_to_pos(endNode.indices)
        parent_idx = endNode.parent_idx
        p_ind = 1
        # The start node's parent is the sentinel -1. Flattened index 0 (cell
        # [0, 0]) is a legitimate parent and must not end the walk.
        while parent_idx >= 0:
            curNode = closed_set[parent_idx]
            path[p_ind, :] = self.ind_to_pos(curNode.indices)
            p_ind += 1
            parent_idx = curNode.parent_idx

        # remove all the extras
        path = path[:p_ind, :]

        # flip so first row is starting point
        return path[::-1, :], endNode.cost

    def calc_weight(self, node):
        """Calculate the weight applied to the heuristic.

        This can be overridden to calculate custom weights.

        Parameters
        ----------
        node : :class:`.Node`
            Current node.

        Returns
        -------
        float
            weight value.
        """
        return self.weight

    def calc_heuristic(self, endNode, curNode):
        """Calculate the heuristic cost (Euclidean distance in real units).

        Parameters
        ----------
        endNode : :class:`.Node`
            Goal node.
        curNode : :class:`.Node`
            Current node.

        Returns
        -------
        float
            heuristic cost.
        """
        diff = self.ind_to_pos(endNode.indices) - self.ind_to_pos(curNode.indices)
        return np.sqrt(np.sum(diff * diff))

    def is_valid(self, node):
        """Check if the node is valid.

        Bounds checks the indices and determines if node corresponds to a wall.

        Parameters
        ----------
        node : :class:`.Node`
            Node to check.

        Returns
        -------
        bool
            True if the node is valid.
        """
        pos = self.ind_to_pos(node.indices)

        # bounds check position (also guards the map lookup below against
        # negative indices)
        if np.any(pos < self.min) or np.any(pos > self.max):
            return False

        # obstacles are inf
        if np.isinf(self._map[node.indices[0], node.indices[1]]):
            return False

        return True

    def _cell_patch(self, indices, **kwargs):
        """Build a rectangle patch covering the grid cell at ``indices``."""
        corner = self.ind_to_pos(indices) - self.resolution / 2
        return Rectangle(corner, self.resolution[0], self.resolution[1], **kwargs)

    def draw_start(self, fig, startNode):
        """Draw starting node on the figure.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        startNode : :class:`.Node`
            Starting node.
        """
        fig.axes[0].add_patch(
            self._cell_patch(startNode.indices, color="g", zorder=1000)
        )

    def draw_end(self, fig, endNode):
        """Draw the ending node on the figure.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        endNode : :class:`.Node`
            Ending node.
        """
        fig.axes[0].add_patch(self._cell_patch(endNode.indices, color="r", zorder=1000))

    def draw_map(self, fig):
        """Draw the map to the figure.

        Obstacles are drawn in black and hazards in orange behind everything
        else.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        """
        ax = fig.axes[0]
        for ii in np.argwhere(np.isinf(self._map)):
            ax.add_patch(self._cell_patch(ii, facecolor="k"))

        # Only finite positive costs are hazards; inf cells are obstacles and
        # were already drawn above.
        hazard_mask = np.isfinite(self._map) & (self._map > 0)
        for ii in np.argwhere(hazard_mask):
            ax.add_patch(
                self._cell_patch(
                    ii,
                    facecolor=(255 / 255, 140 / 255, 0 / 255),
                    zorder=-1000,
                )
            )

    def _footprint(self, center, size, max_inds):
        """Return ``(bot, top, left, right)`` grid bounds of a rectangle.

        ``center`` is the real x/y center and ``size`` the real width/height;
        each dimension is treated as at least one grid square. Lower bounds are
        clamped to 0 and upper bounds (exclusive) to ``max_inds``.
        """
        width2 = np.array([max(size[0], self.resolution[0]) / 2, 0])
        height2 = np.array([0, max(size[1], self.resolution[1]) / 2])
        left = max(self.pos_to_ind(center - width2)[1], 0)
        right = min(self.pos_to_ind(center + width2)[1], max_inds[1])
        top = min(self.pos_to_ind(center + height2)[0], max_inds[0])
        bot = max(self.pos_to_ind(center - height2)[0], 0)
        return bot, top, left, right

    def set_map(self, min_pos, max_pos, grid_res, obstacles=None, hazards=None):
        """Set up the map with obstacles and hazards.

        Parameters
        ----------
        min_pos : 2 numpy array
            Min x/y position in real units.
        max_pos : 2 numpy array
            Max x/y position in real units.
        grid_res : 2 numpy array
            Real distance per grid square for x/y positions.
        obstacles : N x 4 numpy array, optional
            Locations of walls, each row is one wall. First column is x position
            second column is y position, third is width, and fourth is height.
            The position is the location of the center. All distances are in
            real units. The default is None.
        hazards : N x 5 numpy array, optional
            Locations of hazards, each row is one hazard. First column is x
            position second column is y position, third is width, fourth is
            height, and the last is the cost of being on that node. The position
            is the location of the center. All distances are in real units. The
            default is None.
        """
        self.resolution = grid_res.ravel()
        self.min = min_pos.ravel() - self.resolution / 2
        self.max = max_pos.ravel() + self.resolution / 2

        self._obstacles = obstacles
        self._hazards = hazards

        max_inds = self.pos_to_ind(self.max)
        self._map = np.zeros([ii + 1 for ii in max_inds.tolist()])

        if self._obstacles is not None:
            for obs in self._obstacles:
                bot, top, left, right = self._footprint(obs[:2], obs[2:4], max_inds)
                self._map[bot:top, left:right] = np.inf

        if self._hazards is not None:
            for haz in self._hazards:
                bot, top, left, right = self._footprint(haz[:2], haz[2:4], max_inds)
                self._map[bot:top, left:right] += haz[4]

    def get_map_cost(self, indices):
        """Return the cost of being at the given map indices.

        Parameters
        ----------
        indices : 2 numpy array
            Row/col indices.

        Returns
        -------
        float
            Cost of the grid node.
        """
        return self._map[indices[0], indices[1]]

    @staticmethod
    def _capture_frame(fig, fig_w, fig_h):
        """Render ``fig`` to a raw pixel buffer and wrap it as a PIL image."""
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            buff.seek(0)
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (fig_h, fig_w, -1)
            )
        return Image.fromarray(img)

    def _init_animation(self, fig, startNode, endNode, plt_opts, ttl):
        """Create (or reuse) the animation figure and draw the initial state.

        A new figure gets axes, limits, title, start node, and map. A provided
        figure only gets the end node. Returns the figure and its canvas
        width and height.
        """
        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)

            fig.axes[0].set_aspect("equal", adjustable="box")

            fig.axes[0].set_xlim((self.min[0], self.max[0]))
            fig.axes[0].set_ylim((self.min[1], self.max[1]))

            if plt_opts is None:
                plt_opts = gplot.init_plotting_opts(f_hndl=fig)

            if ttl is None:
                ttl = "A* Pathfinding"
                if self.use_beam_search:
                    ttl += " with Beam Search"

            gplot.set_title_label(fig, 0, plt_opts, ttl=ttl)

            # draw map
            self.draw_start(fig, startNode)
            self.draw_map(fig)
            fig.tight_layout()

        self.draw_end(fig, endNode)
        plt.pause(0.01)

        # for stopping simulation with the esc key.
        fig.canvas.mpl_connect(
            "key_release_event",
            lambda event: [exit(0) if event.key == "escape" else None],
        )
        fig_w, fig_h = fig.canvas.get_width_height()
        return fig, fig_w, fig_h

    def plan(
        self,
        start_pos,
        end_pos,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
    ):
        """Run the search algorithm.

        The setup functions (e.g. :meth:`.set_map`) should be called prior to
        this.

        Parameters
        ----------
        start_pos : 2 numpy array
            Starting x/y pos in real units.
        end_pos : 2 numpy array
            ending x/y pos in real units.
        show_animation : bool, optional
            Flag indicating if an animated plot should be shown during planning.
            The default is False. If shown, escape key can be used to quit.
        save_animation : bool, optional
            Flag for saving each frame of the animation as a PIL image. This
            can later be saved to a gif. The default is False. The animation must
            be shown for this to have an effect.
        plt_opts : dict, optional
            Additional options for the plot from
            :meth:`gncpy.plotting.init_plotting_opts`. The default is None.
        ttl : string, optional
            title string of the plot. The default is None which gives a generic
            name of the algorithm.
        fig : matplotlib figure, optional
            Figure object to plot on. The default is None which makes a new
            figure. If a figure is provided, then only the ending state,
            searched nodes, and best path are added to the figure.

        Returns
        -------
        path : Nx2 numpy array
            Real positions of the path. Empty (0 rows) if the start/end is
            invalid or no path was found.
        cost : float
            total cost of the path, ``inf`` if no path exists.
        fig : matplotlib figure
            Figure that was drawn, None if the animation is not shown.
        frame_list : list
            Each element is a PIL image corresponding to an animation frame.
        """
        start_inds = self.pos_to_ind(start_pos)
        end_inds = self.pos_to_ind(end_pos)
        startNode = Node(start_inds, self.get_map_cost(start_inds), -1)
        endNode = Node(end_inds, self.get_map_cost(end_inds), -1)

        if not self.is_valid(startNode) or not self.is_valid(endNode):
            # Failure is reported as an empty (0 x dim) path with infinite cost.
            return np.empty((0, start_pos.size)), float("inf"), fig, []

        frame_list = []
        fig_w = fig_h = None

        if show_animation:
            fig, fig_w, fig_h = self._init_animation(
                fig, startNode, endNode, plt_opts, ttl
            )

            # save first frame of animation
            if save_animation:
                frame_list.append(self._capture_frame(fig, fig_w, fig_h))

        else:
            fig = None

        open_set = {self.ravel_ind(startNode.indices): startNode}
        closed_set = {}
        found = False

        while open_set:
            keys = list(open_set.keys())
            s_inds = np.argsort(
                [
                    n.cost + self.calc_weight(n) * self.calc_heuristic(endNode, n)
                    for n in open_set.values()
                ]
            )
            c_ind = keys[s_inds[0]]
            curNode = open_set[c_ind]
            if self.use_beam_search and s_inds.size > self.beam_search_max_nodes:
                # Walk the ranking backwards, deleting every node ranked
                # beam_search_max_nodes or worse (0-based) so only the best
                # beam_search_max_nodes remain.
                for ii in s_inds[: self.beam_search_max_nodes - 1 : -1]:
                    del open_set[keys[ii]]

            del open_set[c_ind]

            if show_animation:
                pos = self.ind_to_pos(curNode.indices)
                fig.axes[0].scatter(pos[0], pos[1], marker="x", color="c")
                if len(closed_set) % 10 == 0:
                    plt.pause(0.001)
                    # save frame after pause to make sure frame is drawn
                    if save_animation:
                        frame_list.append(self._capture_frame(fig, fig_w, fig_h))

            if np.all(curNode.indices == endNode.indices):
                endNode.parent_idx = curNode.parent_idx
                endNode.cost = curNode.cost
                found = True
                break

            closed_set[c_ind] = curNode

            for m in self.motion:
                new_indices = curNode.indices + m[:2].astype(int)
                node = Node(new_indices, 0, c_ind)
                if not self.is_valid(node):
                    continue
                node.cost = curNode.cost + m[2] + self.get_map_cost(new_indices)

                n_ind = self.ravel_ind(node.indices)

                if n_ind in closed_set:
                    continue

                if n_ind not in open_set or open_set[n_ind].cost > node.cost:
                    open_set[n_ind] = node

        if found:
            path, cost = self.final_path(endNode, closed_set)
        else:
            # The open set ran dry (or beam search pruned the goal away) before
            # reaching the goal. Report failure the same way as for invalid
            # endpoints instead of a one-point path at the goal.
            path, cost = np.empty((0, start_pos.size)), float("inf")

        if show_animation:
            fig.axes[0].plot(path[:, 0], path[:, 1], linestyle="-", color="g")
            plt.pause(0.001)
            if save_animation:
                frame_list.append(self._capture_frame(fig, fig_w, fig_h))

        return path, cost, fig, frame_list