Coverage for src/gncpy/planning/rrt_star.py: 67%

429 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Implements variations of the RRT* algorithm.""" 

2import io 

3import numpy as np 

4import matplotlib.pyplot as plt 

5from matplotlib.patches import Circle, Rectangle 

6from copy import deepcopy 

7from PIL import Image 

8from sys import exit 

9 

10import gncpy.control as gctrl 

11import gncpy.plotting as gplot 

12 

13 

class Node:
    """Vertex of the search tree.

    Attributes
    ----------
    sv : N numpy array
        Flattened state vector of the node.
    u : Nu x Np numpy array
        Control inputs along the trajectory from the parent to this node
        (empty list until the node is connected to the tree).
    path : N x Np numpy array
        State vectors along the trajectory from the parent to this node
        (empty list until the node is connected to the tree).
    parent : :class:`.Node`
        Parent node; an empty list marks a node without a parent (the root).
    cost : float
        Accumulated cost of reaching this node from the root.
    """

    def __init__(self, state):
        """Create a tree node for a state.

        Parameters
        ----------
        state : numpy array
            State of the node; it is flattened with ``ravel``.
        """
        self.sv = state.ravel()
        # tree bookkeeping, overwritten when the node is connected to the tree
        self.parent = []
        self.cost = 0
        # trajectory data from the parent to this node
        self.path = []
        self.u = []

44 

45 

class LQRRRTStar:
    """Implements the LQR-RRT* algorithm.

    Nodes are connected by running an LQR based planner between states, and
    the distance between two states is the quadratic form ``dx.T @ S @ dx``
    where ``S`` is the first output of the planner's ``solve_dare``.

    Attributes
    ----------
    rng : numpy random number generator
        Instance of a random number generator to use.
    start : :class:`.Node`
        Root node of the tree for the segment currently being planned.
    end : :class:`.Node`
        Goal node for the segment currently being planned.
    node_list : list
        List of :class:`.Node` objects in the tree; index 0 is the root.
    min_rand : Np numpy array
        Minimum position values when generating random nodes.
    max_rand : Np numpy array
        Maximum position values when generating random nodes.
    pos_inds : list
        List of position indices in the state vector.
    sampling_fun : callable
        Function that returns a random sample from the state space. Must
        take rng, pos_inds, min_rand, max_rand as inputs and return a numpy
        array of the same size as the state vector.
    goal_sample_rate : float
        Threshold (out of 100) controlling how often the goal state is used
        as the random sample instead of calling `sampling_fun`.
    max_iter : int
        Maximum iterations to search for (per planned segment).
    connect_circle_dist : float
        Scale factor of the radius used when searching for near nodes.
    step_size : float
        Interpolation step, as a fraction of one trajectory step, used when
        resampling planner trajectories.
    expand_dis : float
        Distance to the goal under which a node is considered to reach it.
    gif_frame_skip : int
        An animation frame is saved every this many iterations.
    ell_con : float
        Ellipsoidal constraint value.
    Nobs : int
        Number of obstacles.
    obstacle_list : Nobs x (3 or 4) numpy array
        Each row describes an obstacle; x pos, y pos, (z pos), radius. With 2
        position dimensions a 4 column row is instead a box given by
        x center, y center, width, height.
    P : (2 or 3) x (2 or 3) x Nobs numpy array
        Shape matrix of each obstacle for collision checking.
    planner : :class:`gncpy.control.LQR`
        Planner class instance for predicting paths.
    S : N x N numpy array
        Solution to the discrete time algebraic Riccati equation.
    planner_args : dict
        Extra keyword arguments passed to the planner's ``calculate_control``.
    """

    def __init__(
        self,
        planner=None,
        goal_sample_rate=10,
        max_iter=300,
        connect_circle_dist=2,
        step_size=1,
        expand_dis=1,
        rng=None,
        sampling_fun=None,
    ):
        """Initialize an object.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`, optional
            Planner used to connect nodes, can also be set with
            :meth:`set_control_model`. The default is None.
        goal_sample_rate : int, optional
            See the class attribute. The default is 10.
        max_iter : int, optional
            Maximum iterations per planned segment. The default is 300.
        connect_circle_dist : float, optional
            Scale factor of the near node search radius. The default is 2.
        step_size : float, optional
            Interpolation step used when resampling planner trajectories.
            The default is 1.
        expand_dis : float, optional
            Distance to the goal under which a node reaches the goal. The
            default is 1.
        rng : numpy random generator, optional
            Random number generator, ``np.random.default_rng()`` is used if
            None. The default is None.
        sampling_fun : callable, optional
            Random state sampler, see the class attribute. The default is None.
        """
        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        self.start = None
        self.end = None
        self.node_list = []
        self.min_rand = np.array([])
        self.max_rand = np.array([])
        self.pos_inds = []

        self.sampling_fun = sampling_fun

        self.goal_sample_rate = goal_sample_rate
        self.max_iter = max_iter
        self.connect_circle_dist = connect_circle_dist
        self.step_size = step_size
        self.expand_dis = expand_dis
        self.gif_frame_skip = 1

        # Obstacles
        self.ell_con = None
        self.Nobs = 0
        self.obstacle_list = np.array([])
        self.P = np.zeros((self.numPos, self.numPos, self.Nobs))

        # setup LQR planner
        self.planner = planner
        self.S = np.array([[]])
        self.planner_args = {}

        self._disp = False

        # plotting helpers
        self._fig = None
        self._plt_inds = None
        self._show_planner = False
        self._frame_list = []
        self._fig_h = None
        self._fig_w = None
        self._plt_opts = gplot.init_plotting_opts()

    @property
    def numPos(self):
        """Number of position components in the state vector."""
        return len(self.pos_inds)

    def set_control_model(self, planner, pos_inds, controller_args=None):
        """Set the LQR based planner used to connect nodes.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`
            Planner instance.
        pos_inds : list
            Indices of the position components in the state vector.
        controller_args : dict, optional
            Extra keyword arguments for the planner's ``calculate_control``.
            A copy is stored. The default is None (keep the current ones).

        Raises
        ------
        TypeError
            If `planner` is not an instance of :class:`gncpy.control.LQR`.
        """
        if not isinstance(planner, gctrl.LQR):
            raise TypeError("Must specify an LQR instance")
        self.planner = planner
        self.pos_inds = pos_inds
        if controller_args is not None:
            # copy so the plotting options injected by plan() do not leak
            # back into the caller's dictionary
            self.planner_args = dict(controller_args)

    def set_environment(self, search_area=None, obstacles=None):
        """Set the search area and/or the obstacles.

        Parameters
        ----------
        search_area : 2 x Np numpy array, optional
            First row is the minimum and second row the maximum position
            used when sampling. The default is None (unchanged).
        obstacles : Nobs x (3 or 4) numpy array, optional
            One obstacle per row, see the class attribute `obstacle_list`.
            The default is None (unchanged).
        """
        if search_area is not None:
            self.min_rand = search_area[0, :]
            self.max_rand = search_area[1, :]

        if obstacles is not None:
            self.ell_con = 1
            self.Nobs = obstacles.shape[0]
            self.obstacle_list = obstacles

            dim = int(max(obstacles.shape[1] - 1, 0))

            # P_k = I / r_k^2 so that d.T @ P_k @ d <= 1 inside the sphere
            self.P = np.zeros((dim, dim, self.Nobs))
            for k, r in enumerate(self.obstacle_list[:, -1]):
                self.P[:, :, k] = r ** (-2) * np.eye(dim)

    @property
    def use_box(self):
        """Flag indicating 2D obstacles are boxes (4 column obstacle rows)."""
        return (
            self.numPos == 2
            and self.obstacle_list.ndim == 2
            and self.obstacle_list.shape[1] == 4
        )

    def dx(self, x1, x2):
        """Return the difference ``x1 - x2`` as a column vector."""
        return (x1 - x2).reshape((-1, 1))

    def _make_sphere(self, xc, yc, zc, r):
        """Return x, y, z surface grids of a sphere for 3D plotting."""
        u = np.linspace(0, 2 * np.pi, 39)
        v = np.linspace(0, np.pi, 21)

        x = r * np.outer(np.cos(u), np.sin(v)) + xc
        y = r * np.outer(np.sin(u), np.sin(v)) + yc
        z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + zc

        return x, y, z

    def draw_obstacles(self, ax):
        """Draw every obstacle on the axes with index `ax`."""
        for obs in self.obstacle_list:
            if self.numPos == 2:
                if self.use_box:
                    xy = obs[:2] - obs[2:4] / 2
                    p = Rectangle(xy, obs[2], obs[3], color="k", zorder=1000)
                else:
                    p = Circle(obs[:2], radius=obs[-1], color="k", zorder=1000)
                self._fig.axes[ax].add_patch(p)

            elif self.numPos == 3:
                self._fig.axes[ax].plot_surface(
                    *self._make_sphere(*obs), rstride=3, cstride=3, color="k"
                )

    def draw_start(self, start, ax):
        """Mark the start state on the axes with index `ax`."""
        self._fig.axes[ax].scatter(
            *start.ravel()[self._plt_inds], color="g", marker="o", zorder=1000,
        )

    def draw_end(self, es):
        """Mark an end state on the tree axes."""
        self._fig.axes[0].scatter(
            *es[self._plt_inds], color="r", marker="x", zorder=1000,
        )

    def save_frame(self):
        """Render the figure to an image and append it to the frame list."""
        plt.pause(0.005)
        with io.BytesIO() as buff:
            self._fig.savefig(buff, format="raw")
            buff.seek(0)
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (self._fig_h, self._fig_w, -1)
            )
        self._frame_list.append(Image.fromarray(img))

    def reset_controller_plot(self):
        """Clear and re-initialize the planner (second) subplot."""
        self._fig.axes[1].clear()

        self._fig.axes[1].set_aspect("equal", adjustable="box")
        self._fig.axes[1].set_xlim((self.min_rand[0], self.max_rand[0]))
        self._fig.axes[1].set_ylim((self.min_rand[1], self.max_rand[1]))

        gplot.set_title_label(
            self._fig, 1, self._plt_opts, x_lbl="x pos", use_local=True,
        )

        self.draw_obstacles(1)

        plt.pause(0.01)

    def _plot_path(self, path, **plot_kwargs):
        """Plot the plotted components of an N x Np state path on the tree axes.

        Only 2 or 3 position dimensions are drawn; other sizes are ignored.
        """
        if self.numPos not in (2, 3):
            return
        coords = [path[ind, :] for ind in self._plt_inds[: self.numPos]]
        self._fig.axes[0].plot(*coords, **plot_kwargs)

    def _show_final_traj(self, traj, save_animation):
        """Draw the final trajectory in green and optionally save a frame."""
        self._plot_path(traj, color="g")
        plt.pause(0.005)
        if save_animation:
            self.save_frame()

    def plan_helper(
        self,
        cur_time,
        state_args,
        ctrl_args,
        use_convergence,
        use_first_traj,
        rtol,
        show_animation,
        save_animation,
    ):
        """Grow the tree from :attr:`start` towards :attr:`end`.

        Expects :attr:`start`, :attr:`end` and :attr:`node_list` to be set up
        by :meth:`plan`.

        Returns
        -------
        traj : Np x N numpy array
            State trajectory, empty if no path was found.
        cost : float
            Cost of the trajectory, infinite if no path was found.
        u_traj : numpy array
            Control trajectory (transposed planner output), empty if no path
            was found.
        """
        last_cost = float("inf")
        for i in range(self.max_iter):
            rnd = self.get_random_node()
            if self._disp:
                print(
                    "\tIter: {:4d}, number of nodes: {:5d}, last cost: {:10.4f}".format(
                        i, len(self.node_list), last_cost
                    )
                )

            nearest_ind = self.get_nearest_node_index(
                cur_time, rnd, state_args, ctrl_args
            )
            new_node = self.steer(
                cur_time,
                self.node_list[nearest_ind],
                rnd,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Searching",
            )
            if new_node is None:
                continue
            if self.not_colliding(new_node):
                near_indices = self.find_near_nodes(new_node, nearest_ind)
                new_node = self.choose_parent(
                    cur_time, new_node, near_indices, state_args, ctrl_args
                )
                if new_node is not None:
                    self.node_list.append(new_node)
                    self.rewire(cur_time, new_node, near_indices, state_args, ctrl_args)

            if show_animation and new_node is not None:
                alpha = 0.15 if self.numPos == 2 else 0.1
                self._plot_path(
                    new_node.path, color=(0.5, 0.5, 0.5), alpha=alpha, zorder=-10
                )
                plt.pause(0.005)
                if save_animation and i % self.gif_frame_skip == 0:
                    self.save_frame()

            if new_node is None:
                continue
            last_index = self.search_best_goal_node()
            if last_index is None:
                continue

            traj, u_traj = self.generate_final_course(last_index)
            cost = self.node_list[last_index].cost
            # guard cost == 0 so the relative change is well defined
            converged = (
                use_convergence
                and cost > 0
                and np.abs(last_cost - cost) / cost < rtol
            )
            last_cost = cost
            if use_first_traj or converged:
                if show_animation:
                    self._show_final_traj(traj, save_animation)
                return traj.T, cost, u_traj.T

        if self._disp:
            print("\tReached Max Iteration!!")

        last_index = self.search_best_goal_node()
        if last_index is None:
            if self._disp:
                print("\tCannot find path!!")
            return np.array([[]]).T, float("inf"), np.array([[]]).T

        traj, u_traj = self.generate_final_course(last_index)
        cost = self.node_list[last_index].cost
        if show_animation:
            self._show_final_traj(traj, save_animation)
        return traj.T, cost, u_traj.T

    def _setup_animation(self, first_end, plt_opts, ttl, show_planner, save_animation):
        """Create/prepare the figure used to animate the search."""
        self._show_planner = show_planner and isinstance(self.planner, gctrl.ELQR)
        n_plts = 2 if self._show_planner else 1

        if self._fig is None:
            self._fig = plt.figure()
            if self.numPos == 2:
                for ii in range(n_plts):
                    self._fig.add_subplot(1, n_plts, ii + 1)
                self._fig.axes[0].set_aspect("equal", adjustable="box")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
            elif self.numPos == 3:
                self._fig.add_subplot(1, 1, 1, projection="3d")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
                self._fig.axes[0].set_zlim((self.min_rand[2], self.max_rand[2]))

        # honor caller supplied plotting options instead of ignoring them
        if plt_opts is None:
            self._plt_opts = gplot.init_plotting_opts(f_hndl=self._fig)
        else:
            self._plt_opts = plt_opts
        if ttl is None:
            ttl = "LQR-RRT*"
        gplot.set_title_label(
            self._fig, 0, self._plt_opts, ttl=ttl, x_lbl="x pos", y_lbl="y pos"
        )

        self.draw_start(self.start.sv, 0)
        self.draw_obstacles(0)

        if self._show_planner:
            self.reset_controller_plot()
            gplot.set_title_label(
                self._fig, 0, self._plt_opts, ttl="Tree", use_local=True
            )

        self.draw_end(first_end)

        if self._show_planner:
            self.planner_args.update(
                {
                    "show_animation": True,
                    "save_animation": save_animation,
                    "ax_num": 1,
                    "fig": self._fig,
                    "plt_inds": self._plt_inds,
                    "plt_opts": self._plt_opts,
                }
            )

        self._fig.tight_layout()
        plt.pause(0.1)

        # for stopping simulation with the esc key.
        self._fig.canvas.mpl_connect(
            "key_release_event",
            lambda event: [exit(0) if event.key == "escape" else None],
        )
        self._fig_w, self._fig_h = self._fig.canvas.get_width_height()

        # save first frame of animation
        if save_animation:
            self.save_frame()

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        use_first_traj=True,
        use_convergence=False,
        rtol=1e-3,
        state_args=None,
        ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        plt_inds=None,
        show_planner=True,
    ):
        """Plan a trajectory from the current state through the end state(s).

        Parameters
        ----------
        cur_time : float
            Current time, passed through to the planner.
        cur_state : N x 1 numpy array
            Starting state.
        end_state : N or N x Nwp numpy array
            Single end state, or one end state per column. Columns are
            planned to in order, each segment starting at the previous end
            state with a freshly built tree.
        use_first_traj : bool, optional
            Return as soon as any trajectory to the goal is found. Forced to
            False when `use_convergence` is True. The default is True.
        use_convergence : bool, optional
            Return once the relative change of the best cost is below
            `rtol`. The default is False.
        rtol : float, optional
            Relative tolerance for convergence. The default is 1e-3.
        state_args : tuple, optional
            Extra arguments for the state model. The default is None (empty).
        ctrl_args : tuple, optional
            Extra arguments for the control model. The default is None (empty).
        provide_details : bool, optional
            Also return the cost, control trajectory, figure and frame list.
            The default is False.
        disp : bool, optional
            Print progress messages. The default is True.
        show_animation : bool, optional
            Plot the tree as it is built. The default is False.
        save_animation : bool, optional
            Save animation frames, used together with `show_animation`. The
            default is False.
        plt_opts : optional
            Plotting options from :mod:`gncpy.plotting`, created with
            ``init_plotting_opts`` if None. The default is None.
        ttl : string, optional
            Plot title, "LQR-RRT*" if None. The default is None.
        fig : matplotlib figure, optional
            Figure to draw on, created if None. The default is None.
        plt_inds : list, optional
            State indices plotted on each axis. The default is None which
            gives [0, 1] (or [0, 1, 2] with 3 position dimensions).
        show_planner : bool, optional
            Also animate the planner in a second subplot; only used when the
            planner is a :class:`gncpy.control.ELQR`. The default is True.

        Returns
        -------
        traj : Np x N numpy array
            State trajectory, empty if the first segment failed.
        cost : float
            Summed cost of the successfully planned segments. Only returned
            if `provide_details` is True.
        u_traj : numpy array
            Control trajectory. Only returned if `provide_details` is True.
        fig : matplotlib figure
            Figure used for plotting (None if nothing was plotted). Only
            returned if `provide_details` is True.
        frame_list : list
            Frames saved for the animation. Only returned if
            `provide_details` is True.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if plt_inds is None:
            # 3D plotting indexes a third component, so default to 3 indices
            plt_inds = [0, 1, 2] if self.numPos == 3 else [0, 1]

        if use_convergence:
            use_first_traj = False

        self._disp = disp

        self._plt_inds = plt_inds
        self._fig = fig

        if end_state.ndim == 1:
            end_arr = end_state.copy().reshape((1, -1))
        else:
            end_arr = end_state.copy().T  # flip so each row is an ending state

        # needed before the loop so the start marker can be drawn
        self.start = Node(cur_state.reshape((-1, 1)))

        self._frame_list = []
        self._show_planner = False
        if show_animation:
            self._setup_animation(
                end_arr[0, :], plt_opts, ttl, show_planner, save_animation
            )

        if self._disp:
            print("Starting LQR-RRT* Planning...")

        cost = 0
        u_traj = np.array([])
        traj = np.array([])
        for kk, es in enumerate(end_arr):
            seg_start = cur_state if kk == 0 else end_arr[kk - 1]
            self.start = Node(seg_start.reshape((-1, 1)))
            self.end = Node(es.reshape((-1, 1)))
            # NOTE: in order to reuse the tree, need to recalculate costs and
            # flip paths, for now just create new tree
            self.node_list = [
                deepcopy(self.start),
            ]

            if show_animation:
                self.draw_end(es)

            t, c, u = self.plan_helper(
                cur_time,
                state_args,
                ctrl_args,
                use_convergence,
                use_first_traj,
                rtol,
                show_animation,
                save_animation,
            )
            if t.size == 0:
                break
            if kk == 0:
                traj = t
                u_traj = u
            else:
                traj = np.vstack((traj, t))
                u_traj = np.vstack((u_traj, u))
            cost += c

        details = (cost, u_traj, self._fig, self._frame_list)
        return (traj, *details) if provide_details else traj

    def generate_final_course(self, goal_index):
        """Build the state and control trajectories ending at the goal.

        Walks parent links from ``node_list[goal_index]`` back to the root.

        Returns
        -------
        path : N x Np numpy array
            States from the root to the goal state (last column is
            ``end.sv``).
        u_path : numpy array
            Controls along the path, empty if the goal node is the root.
        """
        # collect the segments goal -> root, concatenate once, then flip
        x_segments = [self.end.sv.reshape(-1, 1)]
        u_segments = []
        node = self.node_list[goal_index]
        while node.parent:
            x_segments.append(np.flip(node.path, 1))
            u_segments.append(np.flip(node.u, 1))
            node = node.parent
        path = np.flip(np.concatenate(x_segments, axis=1), 1)
        if not u_segments:
            return path, np.array([])
        u_path = np.flip(np.concatenate(u_segments, axis=1), 1)
        return path, u_path

    def search_best_goal_node(self):
        """Return the index of the cheapest node that reaches the goal.

        Only nodes with a parent (i.e. reached by a trajectory) within
        `expand_dis` of the goal are candidates; ties resolve to the lowest
        index. Returns None if there is no candidate.
        """
        goal_inds = [
            k
            for k, node in enumerate(self.node_list)
            if node.parent and self.calc_dist_to_goal(node.sv) <= self.expand_dis
        ]
        if not goal_inds:
            return None
        return min(goal_inds, key=lambda k: self.node_list[k].cost)

    def calc_dist_to_goal(self, sv):
        """Euclidean distance between a state vector and the goal state."""
        return np.linalg.norm(self.dx(sv, self.end.sv))

    def rewire(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Reparent near nodes through `new_node` when that is cheaper.

        A rewired node is updated in place (path, controls, cost, parent) so
        any existing children, which reference that object as their parent,
        remain attached; the new costs are then propagated to the subtree.
        Its state vector is left unchanged.
        """
        for i in near_inds:
            near_node = self.node_list[i]
            edge_node = self.steer(
                cur_time,
                new_node,
                near_node,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Rewiring",
            )
            if edge_node is None:
                continue
            edge_node.cost = self.calc_new_cost(
                cur_time, new_node, near_node, state_args, ctrl_args
            )
            no_collision = self.not_colliding(edge_node)
            improved_cost = near_node.cost > edge_node.cost
            if not (no_collision and improved_cost):
                continue

            near_node.path = edge_node.path
            near_node.u = edge_node.u
            near_node.cost = edge_node.cost
            near_node.parent = new_node
            self.propagate_cost_to_leaves(cur_time, new_node, state_args, ctrl_args)

    def propagate_cost_to_leaves(self, cur_time, parent_node, state_args, ctrl_args):
        """Recursively recompute the cost of every descendant of `parent_node`."""
        for node in self.node_list:
            if node.parent is parent_node:
                node.cost = self.calc_new_cost(
                    cur_time, parent_node, node, state_args, ctrl_args
                )
                self.propagate_cost_to_leaves(cur_time, node, state_args, ctrl_args)

    def choose_parent(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Reconnect `new_node` to the near node giving the lowest cost.

        Returns the re-steered node with its parent and cost set, or None if
        no near node can reach it collision free.
        """
        if not near_inds:
            return None
        costs = np.full(len(near_inds), np.inf)
        for kk, i in enumerate(near_inds):
            near_node = self.node_list[i]
            t_node = self.steer(cur_time, near_node, new_node, state_args, ctrl_args)
            if t_node is not None and self.not_colliding(t_node):
                costs[kk] = self.calc_new_cost(
                    cur_time, near_node, new_node, state_args, ctrl_args
                )

        min_ind = int(np.argmin(costs))
        min_cost = costs[min_ind]

        if np.isinf(min_cost):
            if self._disp:
                print("\t\tNo Path - Infinite Cost")
            return None

        best_parent = self.node_list[near_inds[min_ind]]
        new_node = self.steer(
            cur_time,
            best_parent,
            new_node,
            state_args,
            ctrl_args,
            show_controller=True,
            planner_ttl="Generating Node",
        )
        if new_node is None:
            return None
        new_node.parent = best_parent
        new_node.cost = min_cost
        return new_node

    def calc_new_cost(self, cur_time, from_node, to_node, state_args, ctrl_args):
        """Cost of reaching `to_node` through `from_node` (inf if unreachable)."""
        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )
        x_sim_sample, _, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return float("inf")
        return from_node.cost + sum(course_lens)

    def find_near_nodes(self, new_node, nearest_ind):
        """Return indices of tree nodes within the connection radius.

        Distances are ``dx.T @ S @ dx``; the radius shrinks with the tree
        size. Falls back to ``[nearest_ind]`` if no node is inside it.
        """
        nnode = len(self.node_list) + 1
        dist_list = []
        for node in self.node_list:
            diff = self.dx(node.sv, new_node.sv)
            dist_list.append((diff.T @ self.S @ diff).item())
        r = (
            self.connect_circle_dist
            * min(dist_list)
            * (np.log(nnode) / nnode) ** (1 / (new_node.sv.size - 1))
        )
        # enumerate keeps every index, even when several distances are equal
        ind = [k for k, dist in enumerate(dist_list) if dist <= r]
        if not ind:
            ind = [nearest_ind]
        return ind

    def not_colliding(self, node):
        """Check that every sample of a node's path avoids all obstacles.

        Uses axis aligned boxes when :attr:`use_box` is True, otherwise the
        ellipsoids defined by :attr:`P` and :attr:`ell_con`.
        """
        if len(self.obstacle_list) == 0:
            return True

        use_box = self.use_box
        if use_box:
            x = node.path[self.pos_inds[0], :]
            y = node.path[self.pos_inds[1], :]
        for k, xobs in enumerate(self.obstacle_list):
            if use_box:
                left = xobs[0] - xobs[2] / 2
                right = xobs[0] + xobs[2] / 2
                top = xobs[1] + xobs[3] / 2
                bot = xobs[1] - xobs[3] / 2
                over_x = np.logical_and(x >= left, x <= right)
                over_y = np.logical_and(y >= bot, y <= top)
                if np.any(np.logical_and(over_x, over_y)):
                    return False
            else:
                # offset of every path sample (rows) from the obstacle center
                dist = node.path[self.pos_inds, :].T - xobs[:-1].reshape((1, -1))
                # d.T @ P @ d per sample; within ell_con means inside
                d_list = np.einsum("ij,ij->i", dist @ self.P[:, :, k], dist)
                if -min(d_list) + self.ell_con >= 0.0:
                    return False
        return True

    def get_random_node(self):
        """Sample a node, using the goal state ``goal_sample_rate`` % of the time."""
        if self.rng.integers(0, high=100) > self.goal_sample_rate:
            rand_x = self.sampling_fun(
                self.rng, self.pos_inds, self.min_rand, self.max_rand
            )
            rnd = Node(rand_x)
        else:  # goal point sampling
            rnd = Node(self.end.sv.reshape((-1, 1)))
        return rnd

    def get_nearest_node_index(self, cur_time, rnd_node, state_args, ctrl_args):
        """Index of the tree node closest to `rnd_node`.

        Also updates :attr:`S` by solving the DARE about the sampled state.
        """
        self.S = self.planner.solve_dare(
            cur_time,
            rnd_node.sv.reshape((-1, 1)),
            state_args=state_args,
            ctrl_args=ctrl_args,
        )[0]
        dists = []
        for node in self.node_list:
            diff = self.dx(node.sv, rnd_node.sv)
            dists.append((diff.T @ self.S @ diff).item())
        return np.argmin(np.array(dists))

    def call_planner(
        self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs
    ):
        """Run the planner between two states.

        Returns the state and control trajectories from the planner's
        ``calculate_control`` (elements 2 and 3 of its detailed output).
        """
        extra_args = dict()
        extra_args.update(**self.planner_args)
        extra_args.update(**kwargs)

        if self._show_planner and extra_args.get("show_animation", False):
            self.reset_controller_plot()
            self.draw_start(from_state, 1)
            plt.pause(0.01)

            _, _, x_traj, u_traj, _, f_lst = self.planner.calculate_control(
                cur_time,
                from_state.reshape((-1, 1)),
                end_state=to_state.reshape((-1, 1)),
                provide_details=True,
                state_args=state_args,
                ctrl_args=ctrl_args,
                **extra_args,
            )

            # keep the planner's own animation frames in order
            self._frame_list.extend(f_lst)
            return x_traj, u_traj

        return self.planner.calculate_control(
            cur_time,
            from_state.reshape((-1, 1)),
            end_state=to_state.reshape((-1, 1)),
            provide_details=True,
            state_args=state_args,
            ctrl_args=ctrl_args,
            **extra_args,
        )[2:4]

    def steer(
        self,
        cur_time,
        from_node,
        to_node,
        state_args,
        ctrl_args,
        show_controller=False,
        planner_ttl=None,
    ):
        """Create a node by running the planner from `from_node` to `to_node`.

        The returned node holds the resampled trajectory, its parent is
        `from_node`, and its state is the last resampled point. Returns None
        if the planner produced no usable trajectory.
        """
        if self._show_planner:
            self.planner_args["show_animation"] = show_controller
            if self.planner_args.get("show_animation", False):
                self.planner_args["ttl"] = (
                    planner_ttl if planner_ttl is not None else "Controller"
                )

        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )

        x_sim_sample, u_sim_sample, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return None
        new_node = Node(x_sim_sample[:, -1].reshape((-1, 1)))
        new_node.u = u_sim_sample
        new_node.path = x_sim_sample
        new_node.cost = from_node.cost + np.sum(np.abs(course_lens))
        new_node.parent = from_node
        return new_node

    def sample_path(self, x_sim, u_sim):
        """Linearly resample an N x Np state path with :attr:`step_size`.

        Samples are taken at fractions ``arange(0, 1, step_size)`` of every
        step (the final trajectory point itself is not included).

        Returns
        -------
        x_sim_sample : N x M numpy array
            Resampled states.
        u_sim_sample : Nu x M numpy array
            Control applied at each resampled state.
        clen : M - 1 numpy array
            ``d.T @ S @ d`` of each consecutive sample difference ``d``.

        All three are empty lists if there is nothing to resample.
        """
        x_sim_sample = []
        u_sim_sample = []
        if x_sim.size == 0:
            return x_sim_sample, u_sim_sample, []
        for i in range(x_sim.shape[1] - 1):
            for t in np.arange(0.0, 1.0, self.step_size):
                u_sim_sample.append(u_sim[:, i].tolist())
                x_sample = t * x_sim[:, i + 1] + (1.0 - t) * x_sim[:, i]
                x_sim_sample.append(x_sample.tolist())

        # enforce shape if x_sim_sample is empty list
        x_sim_sample = np.array(x_sim_sample).T.reshape((x_sim.shape[0], -1))
        u_sim_sample = np.array(u_sim_sample).T.reshape((u_sim.shape[0], -1))

        diff_x_sim = np.array(
            [
                self.dx(x_sim_sample[:, k + 1], x_sim_sample[:, k])[:, 0]
                for k in range(x_sim_sample.shape[1] - 1)
            ]
        ).T
        if diff_x_sim.size == 0:
            return [], [], []
        clen = np.einsum("ij,ij->i", (diff_x_sim.T @ self.S), diff_x_sim.T)
        return x_sim_sample, u_sim_sample, clen

825 

826 

class ExtendedLQRRRTStar(LQRRRTStar):
    """Implements the ELQR-RRT* algorithm.

    Same search as :class:`.LQRRRTStar`, but every planner call also
    receives ``cost_args``, ``inv_state_args`` and ``inv_ctrl_args`` (as
    given to :meth:`plan`) and runs with ``disp=False`` unless overridden.
    """

    def __init__(self, **kwargs):
        """Initialize an object.

        Parameters
        ----------
        **kwargs : dict
            Passed directly to :meth:`.LQRRRTStar.__init__`.
        """
        super().__init__(**kwargs)
        self._inv_state_args = None
        self._inv_ctrl_args = None
        self._cost_args = None

    def call_planner(
        self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs
    ):
        """Run the planner with the extra arguments stored by :meth:`plan`.

        Keyword precedence (last wins): ``disp=False`` default, then
        :attr:`planner_args`, then `kwargs`, then the cost/inverse arguments.
        """
        extra_args = dict(disp=False)
        extra_args.update(**self.planner_args)
        extra_args.update(**kwargs)
        extra_args.update(
            dict(
                cost_args=self._cost_args,
                inv_state_args=self._inv_state_args,
                inv_ctrl_args=self._inv_ctrl_args,
            )
        )
        return super().call_planner(
            cur_time, from_state, to_state, state_args, ctrl_args, **extra_args
        )

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        inv_state_args=None,
        inv_ctrl_args=None,
        cost_args=None,
        **kwargs
    ):
        """Plan a trajectory, see :meth:`.LQRRRTStar.plan`.

        Parameters
        ----------
        cur_time : float
            Current time.
        cur_state : N x 1 numpy array
            Starting state.
        end_state : N or N x Nwp numpy array
            End state(s).
        inv_state_args : tuple, optional
            Forwarded to every planner call. The default is None.
        inv_ctrl_args : tuple, optional
            Forwarded to every planner call. The default is None.
        cost_args : tuple, optional
            Forwarded to every planner call. The default is None.
        **kwargs : dict
            Passed to :meth:`.LQRRRTStar.plan`. A missing or None ``ttl``
            becomes "ELQR-RRT*".

        Returns
        -------
        See :meth:`.LQRRRTStar.plan`.
        """
        self._inv_state_args = inv_state_args
        self._inv_ctrl_args = inv_ctrl_args
        self._cost_args = cost_args
        # an explicit ttl=None would otherwise fall back to the base title
        if kwargs.get("ttl") is None:
            kwargs["ttl"] = "ELQR-RRT*"
        return super().plan(cur_time, cur_state, end_state, **kwargs)