Coverage for src/gncpy/planning/rrt_star.py: 67%

429 statements  

« prev     ^ index     » next     coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1"""Implements variations of the RRT* algorithm.""" 

2import io 

3import numpy as np 

4import matplotlib.pyplot as plt 

5from matplotlib.patches import Circle, Rectangle 

6from copy import deepcopy 

7from PIL import Image 

8from sys import exit 


10import gncpy.control as gctrl 

11import gncpy.plotting as gplot 



class Node:  # Each node on the tree has these properties
    """Helper class for nodes in the tree.

    Attributes
    ----------
    sv : N numpy array
        State vector of the node (flattened with ``ravel``, so it may share
        memory with the array given to the constructor).
    u : Nu x Np numpy array
        Control inputs along the edge from the parent to this node. Empty
        list until the node is produced by steering.
    path : N x Np numpy array
        All state vectors along the edge from the parent to this node. Empty
        list until the node is produced by steering.
    parent : :class:`.Node`
        Parent node. An empty list for a node without a parent (e.g. the
        root of the tree).
    cost : float
        Accumulated cost from the root of the tree to this node.
    """

    def __init__(self, state):
        """Initialize an object.

        Parameters
        ----------
        state : numpy array
            State vector for the node; any shape, it is flattened to 1-D.
        """
        self.sv = state.ravel()
        self.u = []
        self.path = []
        # an empty (falsy) list marks "no parent"; tree walks stop on it
        self.parent = []
        self.cost = 0



class LQRRRTStar:
    """Implements the LQR-RRT* algorithm.

    Random states are sampled and connected to the tree with the LQR based
    planner. Nearby nodes are rewired whenever routing through a new node is
    cheaper. Edge costs are quadratic forms of consecutive path-state
    differences weighted by :attr:`S`.

    Attributes
    ----------
    rng : numpy random number generator
        Instance of a random number generator to use.
    start : :class:`.Node`
        Root node of the current planning segment.
    end : :class:`.Node`
        Goal node of the current planning segment.
    node_list : list
        List of nodes in the tree; the first entry is the root.
    min_rand : Np numpy array
        Minimum position values when generating random nodes.
    max_rand : Np numpy array
        Maximum position values when generating random nodes.
    pos_inds : list
        List of position indices in the state vector.
    sampling_fun : callable
        Function that returns a random sample from the state space. Must
        take rng, pos_inds, min_rand, max_rand as inputs and return a numpy
        array of the same size as the state vector.
    goal_sample_rate : float
        Approximate percentage (0 to 100) of iterations that sample the goal
        state instead of calling :attr:`sampling_fun`.
    max_iter : int
        Maximum iterations to search for.
    connect_circle_dist : float
        Scale factor of the radius used to find nodes near a new node.
    step_size : float
        Interpolation step, as a fraction of one planner time step, used
        when resampling planner trajectories.
    expand_dis : float
        A node whose state is within this Euclidean distance of the goal
        state is treated as reaching the goal.
    gif_frame_skip : int
        Only every ``gif_frame_skip``-th search iteration is saved as an
        animation frame.
    ell_con : float
        Ellipsoidal constraint value.
    Nobs : int
        Number of obstacles.
    obstacle_list : Nobs x (3 or 4) numpy array
        Each row describes an obstacle. Circles/spheres are x pos, y pos,
        (z pos), radius. With 2 position states and 4 columns, each row is
        instead an axis aligned box: x center, y center, width, height.
    P : (2 or 3) x (2 or 3) x Nobs numpy array
        Ellipsoid shape matrix (``radius**-2`` times identity) of each
        obstacle, used for collision checking.
    planner : :class:`gncpy.control.LQR`
        Planner class instance for predicting paths.
    S : N x N numpy array
        Solution to the discrete time algebraic Riccati equation.
    """

    def __init__(
        self,
        planner=None,
        goal_sample_rate=10,
        max_iter=300,
        connect_circle_dist=2,
        step_size=1,
        expand_dis=1,
        rng=None,
        sampling_fun=None,
    ):
        """Initialize an object.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`, optional
            Planner used to connect states; can also be set with
            :meth:`set_control_model`.
        goal_sample_rate : float, optional
            See :attr:`goal_sample_rate`. The default is 10.
        max_iter : int, optional
            See :attr:`max_iter`. The default is 300.
        connect_circle_dist : float, optional
            See :attr:`connect_circle_dist`. The default is 2.
        step_size : float, optional
            See :attr:`step_size`. The default is 1.
        expand_dis : float, optional
            See :attr:`expand_dis`. The default is 1.
        rng : numpy random generator, optional
            Random number generator; a new default generator is created if
            not given.
        sampling_fun : callable, optional
            See :attr:`sampling_fun`.
        """
        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        self.start = None
        self.end = None
        self.node_list = []
        self.min_rand = np.array([])
        self.max_rand = np.array([])
        self.pos_inds = []

        self.sampling_fun = sampling_fun

        self.goal_sample_rate = goal_sample_rate
        self.max_iter = max_iter
        self.connect_circle_dist = connect_circle_dist
        self.step_size = step_size
        self.expand_dis = expand_dis
        self.gif_frame_skip = 1

        # Obstacles
        self.ell_con = None
        self.Nobs = 0
        self.obstacle_list = np.array([])
        self.P = np.zeros((self.numPos, self.numPos, self.Nobs))

        # setup LQR planner
        self.planner = planner
        self.S = np.array([[]])
        self.planner_args = {}

        self._disp = False

        # plotting helpers
        self._fig = None
        self._plt_inds = None
        self._show_planner = False
        self._frame_list = []
        self._fig_h = None
        self._fig_w = None
        self._plt_opts = gplot.init_plotting_opts()

    @property
    def numPos(self):
        """Number of position states (length of :attr:`pos_inds`)."""
        return len(self.pos_inds)

    def set_control_model(self, planner, pos_inds, controller_args=None):
        """Set the planner and the position indices of the state vector.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`
            Planner used to connect states.
        pos_inds : list
            Indices of the position states in the state vector.
        controller_args : dict, optional
            Extra keyword arguments passed to the planner's
            ``calculate_control`` method.

        Raises
        ------
        TypeError
            If ``planner`` is not a :class:`gncpy.control.LQR` instance.
        """
        if not isinstance(planner, gctrl.LQR):
            raise TypeError("Must specify an LQR instance")
        self.planner = planner
        self.pos_inds = pos_inds
        if controller_args is not None:
            self.planner_args = controller_args

    def set_environment(self, search_area=None, obstacles=None):
        """Set the sampling area and the obstacles.

        Parameters
        ----------
        search_area : 2 x Np numpy array, optional
            First row is the minimum and second row the maximum position
            used when sampling.
        obstacles : Nobs x (3 or 4) numpy array, optional
            See :attr:`obstacle_list`.
        """
        if search_area is not None:
            self.min_rand = search_area[0, :]
            self.max_rand = search_area[1, :]

        if obstacles is not None:
            self.ell_con = 1
            self.Nobs = obstacles.shape[0]
            self.obstacle_list = obstacles

            dim = int(max(obstacles.shape[1] - 1, 0))

            # shape matrix r^-2 * I turns the sphere test into d' P d <= 1
            self.P = np.zeros((dim, dim, self.Nobs))
            for k, r in enumerate(self.obstacle_list[:, -1]):
                self.P[:, :, k] = r ** (-2) * np.eye(dim)

    @property
    def use_box(self):
        """True when obstacles are 2D axis aligned boxes (4 columns, 2D)."""
        return self.numPos == 2 and self.obstacle_list.shape[1] == 4

    def dx(self, x1, x2):
        """Return ``x1 - x2`` as a column vector (current minus reference)."""
        return (x1 - x2).reshape((-1, 1))

    def _make_sphere(self, xc, yc, zc, r):
        """Return surface mesh coordinates of a sphere for 3D plotting."""
        u = np.linspace(0, 2 * np.pi, 39)
        v = np.linspace(0, np.pi, 21)

        x = r * np.outer(np.cos(u), np.sin(v)) + xc
        y = r * np.outer(np.sin(u), np.sin(v)) + yc
        z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + zc

        return x, y, z

    def draw_obstacles(self, ax):
        """Draw every obstacle on the figure axes with index ``ax``."""
        for obs in self.obstacle_list:
            if self.numPos == 2:
                if self.use_box:
                    xy = obs[:2] - obs[2:4] / 2
                    p = Rectangle(xy, obs[2], obs[3], color="k", zorder=1000)
                else:
                    p = Circle(obs[:2], radius=obs[-1], color="k", zorder=1000)
                self._fig.axes[ax].add_patch(p)

            elif self.numPos == 3:
                self._fig.axes[ax].plot_surface(
                    *self._make_sphere(*obs), rstride=3, cstride=3, color="k"
                )

    def draw_start(self, start, ax):
        """Mark the plotted states of ``start`` on axes ``ax``."""
        self._fig.axes[ax].scatter(
            *start.ravel()[self._plt_inds], color="g", marker="o", zorder=1000,
        )

    def draw_end(self, es):
        """Mark the plotted states of the ending state ``es`` on axes 0."""
        self._fig.axes[0].scatter(
            *es[self._plt_inds], color="r", marker="x", zorder=1000,
        )

    def save_frame(self):
        """Render the figure and append it to the animation frame list."""
        plt.pause(0.005)
        with io.BytesIO() as buff:
            self._fig.savefig(buff, format="raw")
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (self._fig_h, self._fig_w, -1)
            )
        self._frame_list.append(Image.fromarray(img))

    def reset_controller_plot(self):
        """Clear the planner axes (axes 1) and redraw the static scene."""
        self._fig.axes[1].clear()

        self._fig.axes[1].set_aspect("equal", adjustable="box")
        self._fig.axes[1].set_xlim((self.min_rand[0], self.max_rand[0]))
        self._fig.axes[1].set_ylim((self.min_rand[1], self.max_rand[1]))

        gplot.set_title_label(
            self._fig, 1, self._plt_opts, x_lbl="x pos", use_local=True,
        )

        self.draw_obstacles(1)

        plt.pause(0.01)

    def _plot_path(self, path, **style):
        """Plot the plotted state rows of an N x Np ``path`` on axes 0."""
        if self.numPos == 2:
            self._fig.axes[0].plot(
                path[self._plt_inds[0], :], path[self._plt_inds[1], :], **style
            )
        elif self.numPos == 3:
            self._fig.axes[0].plot(
                path[self._plt_inds[0], :],
                path[self._plt_inds[1], :],
                path[self._plt_inds[2], :],
                **style,
            )

    def plan_helper(
        self,
        cur_time,
        state_args,
        ctrl_args,
        use_convergence,
        use_first_traj,
        rtol,
        show_animation,
        save_animation,
    ):
        """Grow the tree from :attr:`start` until :attr:`end` is reached.

        Parameters
        ----------
        cur_time : float
            Current time, passed to the planner.
        state_args : tuple
            Extra arguments for the state model.
        ctrl_args : tuple
            Extra arguments for the control model.
        use_convergence : bool
            Stop once the relative change in goal cost is below ``rtol``.
        use_first_traj : bool
            Stop at the first trajectory that reaches the goal.
        rtol : float
            Relative cost tolerance for convergence.
        show_animation : bool
            Plot the tree while searching.
        save_animation : bool
            Save animation frames.

        Returns
        -------
        traj : Nt x N numpy array
            State trajectory; empty if no path was found.
        cost : float
            Cost of the trajectory; infinity if no path was found.
        u_traj : Nt x Nu numpy array
            Control trajectory; empty if no path was found.
        """
        last_cost = float("inf")
        for i in range(self.max_iter):
            rnd = self.get_random_node()
            if self._disp:
                print(
                    "\tIter: {:4d}, number of nodes: {:5d}, last cost: {:10.4f}".format(
                        i, len(self.node_list), last_cost
                    )
                )

            nearest_ind = self.get_nearest_node_index(
                cur_time, rnd, state_args, ctrl_args
            )
            new_node = self.steer(
                cur_time,
                self.node_list[nearest_ind],
                rnd,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Searching",
            )
            if new_node is None:
                continue
            if self.not_colliding(new_node):
                near_indices = self.find_near_nodes(new_node, nearest_ind)
                new_node = self.choose_parent(
                    cur_time, new_node, near_indices, state_args, ctrl_args
                )
                if new_node:
                    self.node_list.append(new_node)
                    self.rewire(cur_time, new_node, near_indices, state_args, ctrl_args)

            if show_animation and new_node is not None:
                self._plot_path(
                    new_node.path,
                    color=(0.5, 0.5, 0.5),
                    alpha=0.15 if self.numPos == 2 else 0.1,
                    zorder=-10,
                )
                plt.pause(0.005)
                if save_animation and i % self.gif_frame_skip == 0:
                    self.save_frame()

            if not new_node:
                continue
            # index 0 is a valid answer, so test against None explicitly
            last_index = self.search_best_goal_node()
            if last_index is None:
                continue

            traj, u_traj = self.generate_final_course(last_index)
            cost = self.node_list[last_index].cost
            # guard the relative change against a zero cost
            converged = (
                use_convergence
                and cost != 0
                and abs((last_cost - cost) / cost) < rtol
            )
            last_cost = cost
            if use_first_traj or converged:
                if show_animation:
                    self._plot_path(traj, color="g")
                    plt.pause(0.005)

                    # save final frame
                    if save_animation:
                        self.save_frame()
                return traj.T, cost, u_traj.T

        if self._disp:
            print("\tReached Max Iteration!!")

        last_index = self.search_best_goal_node()
        if last_index is not None:
            traj, u_traj = self.generate_final_course(last_index)
            cost = self.node_list[last_index].cost
            if show_animation:
                self._plot_path(traj, color="g")
                plt.pause(0.005)
                if save_animation:
                    self.save_frame()
        else:
            traj = np.array([[]])
            u_traj = np.array([[]])
            cost = float("inf")
            if self._disp:
                print("\tCannot find path!!")

        return traj.T, cost, u_traj.T

    def _init_animation(self, end_arr, save_animation, plt_opts, ttl, show_planner):
        """Prepare the figure, draw the static scene and save the first frame."""
        self._show_planner = show_planner and isinstance(self.planner, gctrl.ELQR)
        n_plts = 2 if self._show_planner else 1

        if self._fig is None:
            self._fig = plt.figure()
            if self.numPos == 2:
                for ii in range(n_plts):
                    self._fig.add_subplot(1, n_plts, ii + 1)
                self._fig.axes[0].set_aspect("equal", adjustable="box")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
            elif self.numPos == 3:
                self._fig.add_subplot(1, 1, 1, projection="3d")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
                self._fig.axes[0].set_zlim((self.min_rand[2], self.max_rand[2]))

        # honor caller supplied plotting options instead of discarding them
        if plt_opts is None:
            self._plt_opts = gplot.init_plotting_opts(f_hndl=self._fig)
        else:
            self._plt_opts = plt_opts
        if ttl is None:
            ttl = "LQR-RRT*"
        gplot.set_title_label(
            self._fig, 0, self._plt_opts, ttl=ttl, x_lbl="x pos", y_lbl="y pos"
        )

        self.draw_start(self.start.sv, 0)
        self.draw_obstacles(0)

        if self._show_planner:
            self.reset_controller_plot()
            gplot.set_title_label(
                self._fig, 0, self._plt_opts, ttl="Tree", use_local=True
            )

        self.draw_end(end_arr[0, :])

        if self._show_planner:
            self.planner_args.update(
                {
                    "show_animation": True,
                    "save_animation": save_animation,
                    "ax_num": 1,
                    "fig": self._fig,
                    "plt_inds": self._plt_inds,
                    "plt_opts": self._plt_opts,
                }
            )

        self._fig.tight_layout()
        plt.pause(0.1)

        # for stopping simulation with the esc key.
        self._fig.canvas.mpl_connect(
            "key_release_event",
            lambda event: [exit(0) if event.key == "escape" else None],
        )
        self._fig_w, self._fig_h = self._fig.canvas.get_width_height()

        # save first frame of animation
        if save_animation:
            self.save_frame()

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        use_first_traj=True,
        use_convergence=False,
        rtol=1e-3,
        state_args=None,
        ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        plt_inds=None,
        show_planner=True,
    ):
        """Plan a trajectory from the current state through the ending state(s).

        Parameters
        ----------
        cur_time : float
            Current time, passed to the planner.
        cur_state : N x 1 numpy array
            Starting state.
        end_state : N or N x Ne numpy array
            Ending state, or one ending state per column. Multiple ending
            states are visited in order; a new tree is grown per segment.
        use_first_traj : bool, optional
            Return the first trajectory that reaches the goal. Forced to
            False when ``use_convergence`` is True. The default is True.
        use_convergence : bool, optional
            Keep searching until the relative change in cost is below
            ``rtol``. The default is False.
        rtol : float, optional
            Relative cost tolerance for convergence. The default is 1e-3.
        state_args : tuple, optional
            Extra arguments for the state model. The default is ().
        ctrl_args : tuple, optional
            Extra arguments for the control model. The default is ().
        provide_details : bool, optional
            Also return the cost, control trajectory, figure and animation
            frames. The default is False.
        disp : bool, optional
            Print progress messages. The default is True.
        show_animation : bool, optional
            Plot the tree while planning. The default is False.
        save_animation : bool, optional
            Save animation frames; only used when ``show_animation`` is True.
            The default is False.
        plt_opts : dict, optional
            Plotting options; created with ``gplot.init_plotting_opts`` if
            not given.
        ttl : string, optional
            Plot title. The default is "LQR-RRT*".
        fig : matplotlib figure, optional
            Figure to draw on; a new one is created if not given.
        plt_inds : list, optional
            State indices to plot. The default is [0, 1].
        show_planner : bool, optional
            Also animate the planner in a second subplot; only used when the
            planner is a :class:`gncpy.control.ELQR`. The default is True.

        Returns
        -------
        traj : Nt x N numpy array
            State trajectory.
        cost : float
            Total cost, only if ``provide_details`` is True.
        u_traj : Nt x Nu numpy array
            Control trajectory, only if ``provide_details`` is True.
        fig : matplotlib figure
            Figure handle, only if ``provide_details`` is True.
        frame_list : list
            Animation frames as PIL images, only if ``provide_details`` is True.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if plt_inds is None:
            plt_inds = [0, 1]

        if use_convergence:
            use_first_traj = False

        self._disp = disp

        self._plt_inds = plt_inds
        self._fig = fig

        if len(end_state.shape) == 1:
            _end_arr = end_state.copy().reshape((1, -1))
        else:
            _end_arr = end_state.copy().T  # flip so each row is an ending state

        self.start = Node(cur_state.reshape((-1, 1)))

        self._frame_list = []
        self._show_planner = False
        if show_animation:
            self._init_animation(
                _end_arr, save_animation, plt_opts, ttl, show_planner
            )

        if self._disp:
            print("Starting LQR-RRT* Planning...")

        cost = 0
        u_traj = np.array([])
        traj = np.array([])
        for kk, es in enumerate(_end_arr):
            # each segment starts where the previous one was supposed to end
            seg_start = cur_state if kk == 0 else _end_arr[kk - 1]
            self.start = Node(seg_start.reshape((-1, 1)))
            self.end = Node(es.reshape((-1, 1)))
            # NOTE: in order to reuse the tree, need to recalculate costs and
            # flip paths, for now just create a new tree
            self.node_list = [deepcopy(self.start)]

            if show_animation:
                self.draw_end(es)

            t, c, u = self.plan_helper(
                cur_time,
                state_args,
                ctrl_args,
                use_convergence,
                use_first_traj,
                rtol,
                show_animation,
                save_animation,
            )
            if t.size == 0:
                break
            if kk == 0:
                traj = t
                u_traj = u
            else:
                traj = np.vstack((traj, t))
                u_traj = np.vstack((u_traj, u))
            cost += c

        details = (cost, u_traj, self._fig, self._frame_list)
        return (traj, *details) if provide_details else traj

    def generate_final_course(self, goal_index):
        """Build the trajectory from the root through the given goal node.

        Parameters
        ----------
        goal_index : int
            Index into :attr:`node_list` of a node that reaches the goal.

        Returns
        -------
        path : N x Nt numpy array
            States from the root to the goal; the goal state is appended last.
        u_path : Nu x Nu_t numpy array
            Controls along the path.
        """
        path = self.end.sv.reshape(-1, 1)
        u_path = np.array([])
        node = self.node_list[goal_index]
        # walk back to the root, collecting each edge in reverse order
        while node.parent:
            path = np.append(path, np.flip(node.path, 1), axis=1)

            u_rev = np.flip(node.u, 1)
            if u_path.size == 0:
                u_path = u_rev
            else:
                u_path = np.append(u_path, u_rev, axis=1)

            node = node.parent
        path = np.flip(path, 1)
        u_path = np.flip(u_path, 1)
        return path, u_path

    def search_best_goal_node(self):
        """Find the cheapest node that reaches the goal.

        Only nodes that have a parent (i.e. not the root, which has no edge
        and therefore no trajectory) and lie within :attr:`expand_dis` of the
        goal are considered.

        Returns
        -------
        int or None
            Index into :attr:`node_list` of the cheapest such node (the first
            one on ties), or None if no node reaches the goal.
        """
        goal_inds = [
            ii
            for ii, node in enumerate(self.node_list)
            if node.parent and self.calc_dist_to_goal(node.sv) <= self.expand_dis
        ]
        if not goal_inds:
            return None
        return min(goal_inds, key=lambda ii: self.node_list[ii].cost)

    def calc_dist_to_goal(self, sv):
        """Return the Euclidean distance between ``sv`` and the goal state."""
        return np.linalg.norm(self.dx(sv, self.end.sv))

    def rewire(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Reroute near nodes through ``new_node`` when that is cheaper.

        A rewired node is replaced in :attr:`node_list` by the edge steered
        from ``new_node``. Its children are re-parented to the replacement
        and their costs recomputed.
        """
        for i in near_inds:
            near_node = self.node_list[i]
            edge_node = self.steer(
                cur_time,
                new_node,
                near_node,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Rewiring",
            )
            if edge_node is None:
                continue
            edge_node.cost = self.calc_new_cost(
                cur_time, new_node, near_node, state_args, ctrl_args
            )
            no_collision = self.not_colliding(edge_node)
            improved_cost = near_node.cost > edge_node.cost

            if no_collision and improved_cost:
                # the tree must actually change: swap the node in place and
                # hand the old node's children over to its replacement
                edge_node.parent = new_node
                for node in self.node_list:
                    if node.parent is near_node:
                        node.parent = edge_node
                self.node_list[i] = edge_node
                self.propagate_cost_to_leaves(
                    cur_time, edge_node, state_args, ctrl_args
                )

    def propagate_cost_to_leaves(self, cur_time, parent_node, state_args, ctrl_args):
        """Recompute costs of all descendants of ``parent_node``."""
        for node in self.node_list:
            if node.parent == parent_node:
                node.cost = self.calc_new_cost(
                    cur_time, parent_node, node, state_args, ctrl_args
                )
                self.propagate_cost_to_leaves(cur_time, node, state_args, ctrl_args)

    def choose_parent(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Connect ``new_node`` to the cheapest collision free near node.

        Returns
        -------
        :class:`.Node` or None
            The re-steered node with parent and cost set, or None if no near
            node gives a collision free connection.
        """
        if not near_inds:
            return None
        costs = np.inf * np.ones(len(near_inds))
        for kk, i in enumerate(near_inds):
            near_node = self.node_list[i]
            t_node = self.steer(cur_time, near_node, new_node, state_args, ctrl_args)
            if t_node and self.not_colliding(t_node):
                costs[kk] = self.calc_new_cost(
                    cur_time, near_node, new_node, state_args, ctrl_args
                )

        min_ind = np.argmin(costs)
        min_cost = costs[min_ind]

        if np.isinf(min_cost):
            if self._disp:
                print("\t\tNo Path - Infinite Cost")
            return None

        best_parent = self.node_list[near_inds[min_ind]]
        new_node = self.steer(
            cur_time,
            best_parent,
            new_node,
            state_args,
            ctrl_args,
            show_controller=True,
            planner_ttl="Generating Node",
        )
        if new_node is None:
            return None
        new_node.parent = best_parent
        new_node.cost = min_cost
        return new_node

    def calc_new_cost(self, cur_time, from_node, to_node, state_args, ctrl_args):
        """Return the cost of reaching ``to_node`` through ``from_node``.

        Returns infinity when the planner produces no usable path.
        """
        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )
        x_sim_sample, u_sim_sample, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return float("inf")
        return from_node.cost + sum(course_lens)

    def find_near_nodes(self, new_node, nearest_ind):
        """Return indices of tree nodes close to ``new_node``.

        Distance is the quadratic form weighted by :attr:`S`. The radius
        shrinks with tree size. If nothing is within the radius,
        ``[nearest_ind]`` is returned.
        """
        nnode = len(self.node_list) + 1
        dist_list = [
            (
                self.dx(node.sv, new_node.sv).T
                @ self.S
                @ self.dx(node.sv, new_node.sv)
            ).item()
            for node in self.node_list
        ]
        r = (
            self.connect_circle_dist
            * np.amin(dist_list)
            * (np.log(nnode) / nnode) ** (1 / (new_node.sv.size - 1))
        )
        # enumerate keeps every index, even when distances are equal
        ind = [ii for ii, dist in enumerate(dist_list) if dist <= r]
        if not ind:
            ind = [nearest_ind]
        return ind

    def not_colliding(self, node):
        """Check that a node's path avoids every obstacle.

        Boxes use an axis aligned bounds test on the first two position
        states. Circles/spheres flag a collision when any path state has
        ``d' P d <= ell_con``.

        Returns
        -------
        bool
            True if no obstacle is hit.
        """
        if len(self.obstacle_list) == 0:
            return True

        use_box = self.use_box
        if use_box:
            x = node.path[self.pos_inds[0], :]
            y = node.path[self.pos_inds[1], :]
        for k, xobs in enumerate(self.obstacle_list):
            if use_box:
                left = xobs[0] - xobs[2] / 2
                right = xobs[0] + xobs[2] / 2
                top = xobs[1] + xobs[3] / 2
                bot = xobs[1] - xobs[3] / 2
                over_x = np.logical_and(x >= left, x <= right)
                over_y = np.logical_and(y >= bot, y <= top)
                if np.any(np.logical_and(over_x, over_y)):
                    return False
            else:
                distxyz = (
                    node.path[self.pos_inds, :].T - xobs[:-1].reshape((1, -1))
                ).T
                d_list = np.einsum(
                    "ij,ij->i", (distxyz.T @ self.P[:, :, k]), distxyz.T
                )
                if -min(d_list) + self.ell_con >= 0.0:
                    return False
        return True

    def get_random_node(self):
        """Sample a node, drawing the goal state roughly ``goal_sample_rate``% of the time."""
        if self.rng.integers(0, high=100) > self.goal_sample_rate:
            rand_x = self.sampling_fun(
                self.rng, self.pos_inds, self.min_rand, self.max_rand
            )
            rnd = Node(rand_x)
        else:  # goal point sampling
            rnd = Node(self.end.sv.reshape((-1, 1)))
        return rnd

    def get_nearest_node_index(self, cur_time, rnd_node, state_args, ctrl_args):
        """Return the index of the tree node nearest to ``rnd_node``.

        Also updates :attr:`S` with the planner's DARE solution computed at
        the random node's state. The distance is the ``S`` weighted
        quadratic form.
        """
        self.S = self.planner.solve_dare(
            cur_time,
            rnd_node.sv.reshape((-1, 1)),
            state_args=state_args,
            ctrl_args=ctrl_args,
        )[0]
        dlist = np.array(
            [
                (
                    self.dx(node.sv, rnd_node.sv).T
                    @ self.S
                    @ self.dx(node.sv, rnd_node.sv)
                ).item()
                for node in self.node_list
            ]
        )
        return np.argmin(dlist)

    def call_planner(
        self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs
    ):
        """Run the planner between two states.

        Returns
        -------
        x_traj : numpy array
            State trajectory from the planner.
        u_traj : numpy array
            Control trajectory from the planner.
        """
        extra_args = dict()
        extra_args.update(**self.planner_args)
        extra_args.update(**kwargs)

        if self._show_planner and extra_args["show_animation"]:
            self.reset_controller_plot()
            self.draw_start(from_state, 1)
            plt.pause(0.01)

            _, _, x_traj, u_traj, _, f_lst = self.planner.calculate_control(
                cur_time,
                from_state.reshape((-1, 1)),
                end_state=to_state.reshape((-1, 1)),
                provide_details=True,
                state_args=state_args,
                ctrl_args=ctrl_args,
                **extra_args,
            )

            # keep the planner's animation frames with the tree's frames
            self._frame_list.extend(f_lst)
            return x_traj, u_traj

        return self.planner.calculate_control(
            cur_time,
            from_state.reshape((-1, 1)),
            end_state=to_state.reshape((-1, 1)),
            provide_details=True,
            state_args=state_args,
            ctrl_args=ctrl_args,
            **extra_args,
        )[2:4]

    def steer(
        self,
        cur_time,
        from_node,
        to_node,
        state_args,
        ctrl_args,
        show_controller=False,
        planner_ttl=None,
    ):
        """Steer from ``from_node`` towards ``to_node`` with the planner.

        Returns
        -------
        :class:`.Node` or None
            New node at the end of the resampled path, with its path,
            controls, parent and cost set; None if the planner produced no
            usable path.
        """
        if self._show_planner:
            self.planner_args["show_animation"] = show_controller
        if self.planner_args.get("show_animation", False):
            self.planner_args["ttl"] = (
                planner_ttl if planner_ttl is not None else "Controller"
            )

        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )

        x_sim_sample, u_sim_sample, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return None
        new_node = Node(x_sim_sample[:, -1].reshape((-1, 1)))
        new_node.u = u_sim_sample
        new_node.path = x_sim_sample
        new_node.cost = from_node.cost + np.sum(np.abs(course_lens))
        new_node.parent = from_node
        return new_node

    def sample_path(self, x_sim, u_sim):
        """Interpolate a planner trajectory at :attr:`step_size` increments.

        Parameters
        ----------
        x_sim : N x Nt numpy array
            Planner state trajectory.
        u_sim : Nu x Nt_u numpy array
            Planner control trajectory.

        Returns
        -------
        x_sim_sample : N x Np numpy array or list
            Interpolated states (empty list if nothing usable).
        u_sim_sample : Nu x Np numpy array or list
            Control applied at each interpolated state.
        clen : numpy array or list
            ``S`` weighted quadratic length of each interpolated segment.
        """
        x_sim_sample = []
        u_sim_sample = []
        if x_sim.size == 0:
            clen = []
            return x_sim_sample, u_sim_sample, clen
        for i in range(x_sim.shape[1] - 1):
            for t in np.arange(0.0, 1.0, self.step_size):
                u_sim_sample.append(u_sim[:, i].tolist())
                x_sample = (t * x_sim[:, i + 1] + (1.0 - t) * x_sim[:, i]).reshape(
                    (-1, 1)
                )[:, 0]
                x_sim_sample.append(x_sample.tolist())

        # enforce shape if x_sim_sample is empty list
        x_sim_sample = np.array(x_sim_sample).T.reshape((x_sim.shape[0], -1))
        u_sim_sample = np.array(u_sim_sample).T.reshape((u_sim.shape[0], -1))

        diff_x_sim2 = [
            self.dx(x_sim_sample[:, k + 1], x_sim_sample[:, k])[:, 0]
            for k in range(x_sim_sample.shape[1] - 1)
        ]
        diff_x_sim = np.array(diff_x_sim2).T
        if diff_x_sim.size == 0:
            return [], [], []
        clen = np.einsum("ij,ij->i", (diff_x_sim.T @ self.S), diff_x_sim.T)
        return x_sim_sample, u_sim_sample, clen



class ExtendedLQRRRTStar(LQRRRTStar):
    """LQR-RRT* variant that forwards extra arguments to its planner.

    The inverse state/control arguments and cost arguments given to
    :meth:`plan` are stored and passed to the planner on every call.
    Presumably this is meant for an extended LQR (ELQR) planner; the default
    plot title is "ELQR-RRT*".
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Populated by plan() and forwarded by call_planner().
        self._inv_state_args = None
        self._inv_ctrl_args = None
        self._cost_args = None

    def call_planner(
        self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs
    ):
        """Run the planner with the stored extra arguments attached.

        Planner output is silenced (``disp=False``) unless the stored planner
        arguments or ``kwargs`` set ``disp``. The stored cost and inverse
        arguments always override any same-named keys.
        """
        merged = {"disp": False, **self.planner_args, **kwargs}
        merged["cost_args"] = self._cost_args
        merged["inv_state_args"] = self._inv_state_args
        merged["inv_ctrl_args"] = self._inv_ctrl_args
        return super().call_planner(
            cur_time, from_state, to_state, state_args, ctrl_args, **merged
        )

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        inv_state_args=None,
        inv_ctrl_args=None,
        cost_args=None,
        **kwargs
    ):
        """Store the extra planner arguments, then run the base planner.

        Parameters
        ----------
        inv_state_args : optional
            Forwarded to the planner as ``inv_state_args``.
        inv_ctrl_args : optional
            Forwarded to the planner as ``inv_ctrl_args``.
        cost_args : optional
            Forwarded to the planner as ``cost_args``.
        **kwargs : dict
            Passed to :meth:`LQRRRTStar.plan`; ``ttl`` defaults to "ELQR-RRT*".
        """
        self._cost_args = cost_args
        self._inv_state_args = inv_state_args
        self._inv_ctrl_args = inv_ctrl_args
        kwargs.setdefault("ttl", "ELQR-RRT*")
        return super().plan(cur_time, cur_state, end_state, **kwargs)