Coverage for src/gncpy/planning/rrt_star.py: 67%
429 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Implements variations of the RRT* algorithm."""
2import io
3import numpy as np
4import matplotlib.pyplot as plt
5from matplotlib.patches import Circle, Rectangle
6from copy import deepcopy
7from PIL import Image
8from sys import exit
10import gncpy.control as gctrl
11import gncpy.plotting as gplot
class Node:
    """Vertex of the search tree grown by the RRT* planners.

    Attributes
    ----------
    sv : N numpy array
        Flattened state vector of the node.
    u : Nu x Np numpy array or list
        Control inputs along the edge leading into this node; an empty list
        until the edge is assigned.
    path : N x Np numpy array or list
        Sampled states along the edge leading into this node; an empty list
        until the edge is assigned.
    parent : :class:`.Node` or list
        Parent node. An empty list marks a node without a parent (e.g. the
        tree root).
    cost : float
        Accumulated cost of reaching this node.
    """

    def __init__(self, state):
        """Create a parentless node with zero cost.

        Parameters
        ----------
        state : numpy array
            State vector for the node, flattened before it is stored.
        """
        self.sv = state.ravel()
        # three independent empty lists, filled in later by the planner
        self.u, self.path, self.parent = [], [], []
        self.cost = 0
class LQRRRTStar:
    """Implements the LQR-RRT* algorithm.

    Attributes
    ----------
    rng : numpy random number generator
        Instance of a random number generator to use.
    start : :class:`.Node`
        Starting node of the tree currently being grown.
    end : :class:`.Node`
        Ending (goal) node of the tree currently being grown.
    node_list : list
        List of nodes in the tree.
    min_rand : Np numpy array
        Minimum position values when generating random nodes.
    max_rand : Np numpy array
        Maximum position values when generating random nodes.
    pos_inds : list
        List of position indices in the state vector.
    sampling_fun : callable
        Function that returns a random sample from the state space. Must
        take rng, pos_inds, min_rand, max_rand as inputs and return a numpy
        array of the same size as the state vector.
    goal_sample_rate : float
        Goal biasing threshold out of 100. An iteration samples the goal
        state when a random integer in [0, 100) does not exceed this value.
    max_iter : int
        Maximum iterations to search for.
    connect_circle_dist : float
        Scale factor on the radius used to find nearby nodes for parent
        selection and rewiring.
    step_size : float
        Interpolation step, as a fraction of each trajectory interval, used
        when resampling planner output.
    expand_dis : float
        Nodes within this distance of the goal state are considered to have
        reached it.
    gif_frame_skip : int
        An animation frame is saved every this many iterations.
    ell_con : float
        Ellipsoidal constraint value.
    Nobs : int
        Number of obstacles.
    obstacle_list : Nobs x (3 or 4) numpy array
        Each row describes an obstacle; x pos, y pos, (z pos), radius. With
        two position dimensions a 4 column array is interpreted as boxes;
        x pos, y pos, width, height.
    P : (2 or 3) x (2 or 3) x Nobs numpy array
        Shape matrix of each obstacle for ellipsoidal collision checking.
    planner : :class:`gncpy.control.LQR`
        Planner class instance for predicting paths.
    planner_args : dict
        Extra keyword arguments passed to the planner's ``calculate_control``.
    S : N x N numpy array
        Solution to the discrete time algebraic riccatti equation.
    """

    def __init__(
        self,
        planner=None,
        goal_sample_rate=10,
        max_iter=300,
        connect_circle_dist=2,
        step_size=1,
        expand_dis=1,
        rng=None,
        sampling_fun=None,
    ):
        """Initialize an object.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`, optional
            Planner used to steer between states. The default is None; it
            can be set later with :meth:`set_control_model`.
        goal_sample_rate : float, optional
            Goal biasing threshold out of 100. The default is 10.
        max_iter : int, optional
            Maximum number of search iterations. The default is 300.
        connect_circle_dist : float, optional
            Scale factor of the near node search radius. The default is 2.
        step_size : float, optional
            Trajectory resampling step as a fraction of each interval. The
            default is 1.
        expand_dis : float, optional
            Goal reaching distance. The default is 1.
        rng : numpy random number generator, optional
            Random number generator. The default is None, which creates a new
            ``np.random.default_rng()``.
        sampling_fun : callable, optional
            Random state sampling function, see the class attributes. The
            default is None.
        """
        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        self.start = None
        self.end = None
        self.node_list = []
        self.min_rand = np.array([])
        self.max_rand = np.array([])
        self.pos_inds = []

        self.sampling_fun = sampling_fun

        self.goal_sample_rate = goal_sample_rate
        self.max_iter = max_iter
        self.connect_circle_dist = connect_circle_dist
        self.step_size = step_size
        self.expand_dis = expand_dis
        self.gif_frame_skip = 1

        # Obstacles
        self.ell_con = None
        self.Nobs = 0
        self.obstacle_list = np.array([])
        self.P = np.zeros((self.numPos, self.numPos, self.Nobs))

        # setup LQR planner
        self.planner = planner
        self.S = np.array([[]])
        self.planner_args = {}

        self._disp = False

        # plotting helpers
        self._fig = None
        self._plt_inds = None
        self._show_planner = False
        self._frame_list = []
        self._fig_h = None
        self._fig_w = None
        self._plt_opts = gplot.init_plotting_opts()

    @property
    def numPos(self):
        """Number of position components in the state vector."""
        return len(self.pos_inds)

    def set_control_model(self, planner, pos_inds, controller_args=None):
        """Set the planner used to steer between states.

        Parameters
        ----------
        planner : :class:`gncpy.control.LQR`
            Planner instance.
        pos_inds : list
            Indices of the position components in the state vector.
        controller_args : dict, optional
            Extra keyword arguments for the planner's ``calculate_control``.

        Raises
        ------
        TypeError
            If the planner is not an LQR instance.
        """
        if not isinstance(planner, gctrl.LQR):
            raise TypeError("Must specify an LQR instance")
        self.planner = planner
        self.pos_inds = pos_inds
        if controller_args is not None:
            self.planner_args = controller_args

    def set_environment(self, search_area=None, obstacles=None):
        """Set the search area and obstacles.

        Parameters
        ----------
        search_area : 2 x Np numpy array, optional
            Row 0 holds the minimum and row 1 the maximum position values
            used for random sampling and plot limits.
        obstacles : Nobs x (3 or 4) numpy array, optional
            Obstacle descriptions, see the class attributes.
        """
        if search_area is not None:
            self.min_rand = search_area[0, :]
            self.max_rand = search_area[1, :]

        if obstacles is not None:
            self.ell_con = 1
            self.Nobs = obstacles.shape[0]
            self.obstacle_list = obstacles

            dim = int(max(obstacles.shape[1] - 1, 0))

            self.P = np.zeros((dim, dim, self.Nobs))
            for k, r in enumerate(self.obstacle_list[:, -1]):
                self.P[:, :, k] = r ** (-2) * np.eye(dim)

    @property
    def use_box(self):
        """Whether obstacles are boxes (2D positions and 4 column obstacles)."""
        # the default obstacle list is 1-D, guard before reading shape[1]
        return (
            self.numPos == 2
            and self.obstacle_list.ndim == 2
            and self.obstacle_list.shape[1] == 4
        )

    def dx(self, x1, x2):
        """Return the difference ``x1 - x2`` as a column vector."""
        return (x1 - x2).reshape((-1, 1))

    def _make_sphere(self, xc, yc, zc, r):
        """Return surface mesh coordinates of a sphere for 3D plotting."""
        u = np.linspace(0, 2 * np.pi, 39)
        v = np.linspace(0, np.pi, 21)

        x = r * np.outer(np.cos(u), np.sin(v)) + xc
        y = r * np.outer(np.sin(u), np.sin(v)) + yc
        z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + zc

        return x, y, z

    def draw_obstacles(self, ax):
        """Draw all obstacles on the axes with index ``ax`` of the figure."""
        for obs in self.obstacle_list:
            if self.numPos == 2:
                if self.use_box:
                    xy = obs[:2] - obs[2:4] / 2
                    p = Rectangle(xy, obs[2], obs[3], color="k", zorder=1000)
                else:
                    p = Circle(obs[:2], radius=obs[-1], color="k", zorder=1000)
                self._fig.axes[ax].add_patch(p)

            elif self.numPos == 3:
                self._fig.axes[ax].plot_surface(
                    *self._make_sphere(*obs), rstride=3, cstride=3, color="k"
                )

    def draw_start(self, start, ax):
        """Mark the start state on the axes with index ``ax``."""
        self._fig.axes[ax].scatter(
            *start.ravel()[self._plt_inds], color="g", marker="o", zorder=1000,
        )

    def draw_end(self, es):
        """Mark an end state on the tree axes."""
        self._fig.axes[0].scatter(
            *es[self._plt_inds], color="r", marker="x", zorder=1000,
        )

    def save_frame(self):
        """Render the figure and append it to the animation frame list."""
        plt.pause(0.005)
        with io.BytesIO() as buff:
            self._fig.savefig(buff, format="raw")
            buff.seek(0)
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (self._fig_h, self._fig_w, -1)
            )
        self._frame_list.append(Image.fromarray(img))

    def reset_controller_plot(self):
        """Clear and re-initialize the controller axes (index 1)."""
        self._fig.axes[1].clear()

        self._fig.axes[1].set_aspect("equal", adjustable="box")
        self._fig.axes[1].set_xlim((self.min_rand[0], self.max_rand[0]))
        self._fig.axes[1].set_ylim((self.min_rand[1], self.max_rand[1]))

        gplot.set_title_label(
            self._fig, 1, self._plt_opts, x_lbl="x pos", use_local=True,
        )

        self.draw_obstacles(1)

        plt.pause(0.01)

    def _plot_path(self, path, **plot_kwargs):
        """Plot the position rows of an N x Np ``path`` on the tree axes."""
        if self.numPos == 2:
            n_dims = 2
        elif self.numPos == 3:
            n_dims = 3
        else:
            return
        coords = [path[self._plt_inds[ii], :] for ii in range(n_dims)]
        self._fig.axes[0].plot(*coords, **plot_kwargs)

    def _draw_final_traj(self, traj, save_animation):
        """Draw the final trajectory in green and optionally save a frame."""
        self._plot_path(traj, color="g")
        plt.pause(0.005)
        if save_animation:
            self.save_frame()

    def plan_helper(
        self,
        cur_time,
        state_args,
        ctrl_args,
        use_convergence,
        use_first_traj,
        rtol,
        show_animation,
        save_animation,
    ):
        """Grow the tree from ``self.start`` toward ``self.end``.

        Parameters
        ----------
        cur_time : float
            Current time passed to the planner.
        state_args : tuple
            Extra arguments for the state model.
        ctrl_args : tuple
            Extra arguments for the control model.
        use_convergence : bool
            Stop once the relative change in the best cost drops below rtol.
        use_first_traj : bool
            Stop as soon as any trajectory reaches the goal.
        rtol : float
            Relative tolerance for the convergence check.
        show_animation : bool
            Flag for drawing the search.
        save_animation : bool
            Flag for saving animation frames (only when animating).

        Returns
        -------
        traj : Nt x N numpy array
            State trajectory, empty if no path was found.
        cost : float
            Trajectory cost, inf if no path was found.
        u_traj : Nt x Nu numpy array
            Control trajectory, empty if no path was found.
        """
        last_cost = float("inf")
        for i in range(self.max_iter):
            rnd = self.get_random_node()
            if self._disp:
                print(
                    "\tIter: {:4d}, number of nodes: {:5d}, last cost: {:10.4f}".format(
                        i, len(self.node_list), last_cost
                    )
                )

            nearest_ind = self.get_nearest_node_index(
                cur_time, rnd, state_args, ctrl_args
            )
            new_node = self.steer(
                cur_time,
                self.node_list[nearest_ind],
                rnd,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Searching",
            )
            if new_node is None:
                continue
            if self.not_colliding(new_node):
                near_indices = self.find_near_nodes(new_node, nearest_ind)
                new_node = self.choose_parent(
                    cur_time, new_node, near_indices, state_args, ctrl_args
                )
                if new_node:
                    self.node_list.append(new_node)
                    self.rewire(cur_time, new_node, near_indices, state_args, ctrl_args)

            if show_animation and new_node is not None:
                self._plot_path(
                    new_node.path,
                    color=(0.5, 0.5, 0.5),
                    alpha=0.15 if self.numPos == 2 else 0.1,
                    zorder=-10,
                )
                plt.pause(0.005)
                if save_animation and i % self.gif_frame_skip == 0:
                    self.save_frame()

            if not new_node:
                continue
            last_index = self.search_best_goal_node()
            if last_index is None:
                continue

            traj, u_traj = self.generate_final_course(last_index)
            cost = self.node_list[last_index].cost
            converged = (
                use_convergence
                and np.abs(np.abs(last_cost - cost) / cost) < rtol
            )
            last_cost = cost
            if use_first_traj or converged:
                if show_animation:
                    self._draw_final_traj(traj, save_animation)
                return traj.T, cost, u_traj.T

        if self._disp:
            print("\tReached Max Iteration!!")

        last_index = self.search_best_goal_node()
        if last_index is None:
            if self._disp:
                print("\tCannot find path!!")
            return np.array([[]]).T, float("inf"), np.array([[]]).T

        traj, u_traj = self.generate_final_course(last_index)
        cost = self.node_list[last_index].cost
        if show_animation:
            self._draw_final_traj(traj, save_animation)
        return traj.T, cost, u_traj.T

    def _setup_animation(self, end_arr, save_animation, plt_opts, ttl, show_planner):
        """Create or prepare the figure used to animate the search."""
        self._show_planner = show_planner and isinstance(self.planner, gctrl.ELQR)
        n_plts = 2 if self._show_planner else 1

        if self._fig is None:
            self._fig = plt.figure()
            if self.numPos == 2:
                for ii in range(n_plts):
                    self._fig.add_subplot(1, n_plts, ii + 1)
                self._fig.axes[0].set_aspect("equal", adjustable="box")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
            elif self.numPos == 3:
                self._fig.add_subplot(1, 1, 1, projection="3d")
                self._fig.axes[0].set_xlim((self.min_rand[0], self.max_rand[0]))
                self._fig.axes[0].set_ylim((self.min_rand[1], self.max_rand[1]))
                self._fig.axes[0].set_zlim((self.min_rand[2], self.max_rand[2]))

        # honor user supplied plotting options instead of silently dropping them
        if plt_opts is None:
            self._plt_opts = gplot.init_plotting_opts(f_hndl=self._fig)
        else:
            self._plt_opts = plt_opts
        if ttl is None:
            ttl = "LQR-RRT*"
        gplot.set_title_label(
            self._fig, 0, self._plt_opts, ttl=ttl, x_lbl="x pos", y_lbl="y pos"
        )

        self.draw_start(self.start.sv, 0)
        self.draw_obstacles(0)

        if self._show_planner:
            self.reset_controller_plot()
            gplot.set_title_label(
                self._fig, 0, self._plt_opts, ttl="Tree", use_local=True
            )

        self.draw_end(end_arr[0, :])

        if self._show_planner:
            self.planner_args.update(
                {
                    "show_animation": True,
                    "save_animation": save_animation,
                    "ax_num": 1,
                    "fig": self._fig,
                    "plt_inds": self._plt_inds,
                    "plt_opts": self._plt_opts,
                }
            )

        self._fig.tight_layout()
        plt.pause(0.1)

        def _on_key_release(event):
            # for stopping simulation with the esc key.
            if event.key == "escape":
                exit(0)

        self._fig.canvas.mpl_connect("key_release_event", _on_key_release)
        self._fig_w, self._fig_h = self._fig.canvas.get_width_height()

        # save first frame of animation
        if save_animation:
            self.save_frame()

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        use_first_traj=True,
        use_convergence=False,
        rtol=1e-3,
        state_args=None,
        ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        plt_inds=None,
        show_planner=True,
    ):
        """Plan a trajectory from the current state through the end state(s).

        A new tree is grown for each end state; segment k starts at end state
        k - 1 (or the current state for the first segment).

        Parameters
        ----------
        cur_time : float
            Current time passed to the planner.
        cur_state : N x 1 numpy array
            Starting state.
        end_state : N numpy array or N x Ne numpy array
            Ending state, or one ending state per column.
        use_first_traj : bool, optional
            Return the first trajectory reaching each goal. The default is True.
        use_convergence : bool, optional
            Search until the relative cost change is below rtol; overrides
            use_first_traj. The default is False.
        rtol : float, optional
            Relative tolerance for convergence. The default is 1e-3.
        state_args : tuple, optional
            Extra arguments for the state model. The default is ().
        ctrl_args : tuple, optional
            Extra arguments for the control model. The default is ().
        provide_details : bool, optional
            Also return cost, controls, figure, and frames. The default is False.
        disp : bool, optional
            Print progress. The default is True.
        show_animation : bool, optional
            Animate the search. The default is False.
        save_animation : bool, optional
            Save animation frames (requires show_animation). The default is False.
        plt_opts : dict, optional
            Plotting options; created with ``gplot.init_plotting_opts`` if None.
        ttl : string, optional
            Plot title. The default is None which gives "LQR-RRT*".
        fig : matplotlib figure, optional
            Figure to draw on; created if None.
        plt_inds : list, optional
            State indices to plot. The default is [0, 1].
        show_planner : bool, optional
            Show the controller in a second subplot when the planner is an
            ELQR instance. The default is True.

        Returns
        -------
        traj : Nt x N numpy array
            State trajectory.
        cost : float
            Total cost; inf if any segment failed to find a path. Only
            returned if provide_details is True.
        u_traj : Nt x Nu numpy array
            Control trajectory. Only returned if provide_details is True.
        fig : matplotlib figure
            Figure handle (None when not animating). Only returned if
            provide_details is True.
        frame_list : list
            Saved PIL image frames. Only returned if provide_details is True.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if plt_inds is None:
            plt_inds = [0, 1]

        if use_convergence:
            use_first_traj = False

        self._disp = disp

        self._plt_inds = plt_inds
        self._fig = fig

        if len(end_state.shape) == 1:
            end_arr = end_state.copy().reshape((1, -1))
        else:
            end_arr = end_state.copy().T  # flip so each row is an ending state

        self.start = Node(cur_state.reshape((-1, 1)))

        self._frame_list = []
        self._show_planner = False
        if show_animation:
            self._setup_animation(end_arr, save_animation, plt_opts, ttl, show_planner)

        if self._disp:
            print("Starting LQR-RRT* Planning...")

        cost = 0
        u_traj = np.array([])
        traj = np.array([])
        for kk, es in enumerate(end_arr):
            seg_start = cur_state if kk == 0 else end_arr[kk - 1]
            self.start = Node(seg_start.reshape((-1, 1)))
            self.end = Node(es.reshape((-1, 1)))
            # NOTE: reusing the tree would require recalculating costs and
            # flipping paths, so a new tree is created for every segment
            self.node_list = [
                deepcopy(self.start),
            ]

            if show_animation:
                self.draw_end(es)

            t, c, u = self.plan_helper(
                cur_time,
                state_args,
                ctrl_args,
                use_convergence,
                use_first_traj,
                rtol,
                show_animation,
                save_animation,
            )
            # a failed segment returns inf so the total cost reports failure
            cost += c
            if t.size == 0:
                break
            if kk == 0:
                traj = t
                u_traj = u
            else:
                traj = np.vstack((traj, t))
                u_traj = np.vstack((u_traj, u))

        details = (cost, u_traj, self._fig, self._frame_list)
        return (traj, *details) if provide_details else traj

    def generate_final_course(self, goal_index):
        """Assemble the trajectory from the root to a goal node.

        Parameters
        ----------
        goal_index : int
            Index of the goal reaching node in the tree.

        Returns
        -------
        path : N x Np numpy array
            States from the root to the end state (end state appended last).
        u_path : Nu x Np numpy array
            Controls along the path; an empty array if the node has no parent.
        """
        # collect segments and concatenate once instead of appending in a loop
        path_segs = [self.end.sv.reshape(-1, 1)]
        u_segs = []
        node = self.node_list[goal_index]
        while node.parent:
            path_segs.append(np.flip(node.path, 1))
            u_segs.append(np.flip(node.u, 1))
            node = node.parent
        path = np.flip(np.hstack(path_segs), 1)
        u_path = np.flip(np.hstack(u_segs), 1) if u_segs else np.array([])
        return path, u_path

    def search_best_goal_node(self):
        """Return the index of the cheapest node within expand_dis of the goal.

        The root (which has no parent, hence no edge to build a trajectory
        from) is never returned.

        Returns
        -------
        int or None
            Index into node_list, or None if no node reaches the goal.
        """
        goal_inds = [
            ii
            for ii, node in enumerate(self.node_list)
            if node.parent and self.calc_dist_to_goal(node.sv) <= self.expand_dis
        ]
        if not goal_inds:
            return None
        # min keeps the first index among equal costs
        return min(goal_inds, key=lambda ii: self.node_list[ii].cost)

    def calc_dist_to_goal(self, sv):
        """Return the Euclidean distance between a state and the goal state."""
        return np.linalg.norm(self.dx(sv, self.end.sv))

    @staticmethod
    def _is_ancestor(candidate, node):
        """Return True if candidate is node or on node's parent chain."""
        while node:
            if node is candidate:
                return True
            node = node.parent
        return False

    def rewire(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Re-parent nearby nodes through new_node when that lowers their cost.

        Nodes are updated in place so their children, which reference them
        as parents, remain attached to the tree.
        """
        for i in near_inds:
            near_node = self.node_list[i]
            # re-parenting an ancestor onto new_node would create a cycle
            if self._is_ancestor(near_node, new_node):
                continue
            edge_node = self.steer(
                cur_time,
                new_node,
                near_node,
                state_args,
                ctrl_args,
                show_controller=True,
                planner_ttl="Rewiring",
            )
            if edge_node is None:
                continue
            # steer already accumulated new_node.cost plus the edge cost
            improved_cost = near_node.cost > edge_node.cost
            if not improved_cost or not self.not_colliding(edge_node):
                continue

            near_node.sv = edge_node.sv
            near_node.u = edge_node.u
            near_node.path = edge_node.path
            near_node.cost = edge_node.cost
            near_node.parent = new_node
            self.propagate_cost_to_leaves(cur_time, near_node, state_args, ctrl_args)

    def propagate_cost_to_leaves(self, cur_time, parent_node, state_args, ctrl_args):
        """Recompute the costs of all descendants of parent_node."""
        for node in self.node_list:
            if node.parent == parent_node:
                node.cost = self.calc_new_cost(
                    cur_time, parent_node, node, state_args, ctrl_args
                )
                self.propagate_cost_to_leaves(cur_time, node, state_args, ctrl_args)

    def choose_parent(self, cur_time, new_node, near_inds, state_args, ctrl_args):
        """Reconnect new_node to the collision free near node of lowest cost.

        Returns
        -------
        :class:`.Node` or None
            The re-steered node with its parent and cost set, or None if no
            valid parent exists.
        """
        if not near_inds:
            return None
        costs = np.full(len(near_inds), np.inf)
        for kk, ii in enumerate(near_inds):
            t_node = self.steer(
                cur_time, self.node_list[ii], new_node, state_args, ctrl_args
            )
            if t_node and self.not_colliding(t_node):
                # steer already computed the accumulated cost of this edge
                costs[kk] = t_node.cost

        min_ind = int(np.argmin(costs))
        min_cost = costs[min_ind]

        if np.isinf(min_cost):
            if self._disp:
                print("\t\tNo Path - Infinite Cost")
            return None

        parent = self.node_list[near_inds[min_ind]]
        new_node = self.steer(
            cur_time,
            parent,
            new_node,
            state_args,
            ctrl_args,
            show_controller=True,
            planner_ttl="Generating Node",
        )
        if new_node is None:
            return None
        new_node.parent = parent
        new_node.cost = min_cost
        return new_node

    def calc_new_cost(self, cur_time, from_node, to_node, state_args, ctrl_args):
        """Return from_node's cost plus the cost of the edge to to_node.

        The edge cost is inf if the planner produced no usable trajectory.
        """
        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )
        x_sim_sample, u_sim_sample, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return float("inf")
        return from_node.cost + sum(course_lens)

    def find_near_nodes(self, new_node, nearest_ind):
        """Return indices of tree nodes within the shrinking search radius.

        Distances are quadratic forms weighted by S. Falls back to
        ``[nearest_ind]`` if no node lies inside the radius.
        """
        nnode = len(self.node_list) + 1
        dist_list = []
        for node in self.node_list:
            diff = self.dx(node.sv, new_node.sv)
            dist_list.append((diff.T @ self.S @ diff).item())
        r = (
            self.connect_circle_dist
            * min(dist_list)
            * (np.log(nnode) / nnode) ** (1 / (new_node.sv.size - 1))
        )
        # enumerate keeps every matching index even when distances repeat
        inds = [ii for ii, dist in enumerate(dist_list) if dist <= r]
        return inds if inds else [nearest_ind]

    def not_colliding(self, node):
        """Check a node's path against the box/circle/ellipsoid obstacles."""
        if len(self.obstacle_list) == 0:
            return True

        if self.use_box:
            x = node.path[self.pos_inds[0], :]
            y = node.path[self.pos_inds[1], :]
            for xobs in self.obstacle_list:
                left = xobs[0] - xobs[2] / 2
                right = xobs[0] + xobs[2] / 2
                top = xobs[1] + xobs[3] / 2
                bot = xobs[1] - xobs[3] / 2
                over_x = np.logical_and(x >= left, x <= right)
                over_y = np.logical_and(y >= bot, y <= top)
                if np.any(np.logical_and(over_x, over_y)):
                    return False
            return True

        pos = node.path[self.pos_inds, :]
        for k, xobs in enumerate(self.obstacle_list):
            distxyz = pos - xobs[:-1].reshape((-1, 1))
            d_list = np.einsum("ij,ij->i", (distxyz.T @ self.P[:, :, k]), distxyz.T)
            if -min(d_list) + self.ell_con >= 0.0:
                return False
        return True

    def get_random_node(self):
        """Sample a random node, or the goal node with goal_sample_rate bias."""
        if self.rng.integers(0, high=100) > self.goal_sample_rate:
            rand_x = self.sampling_fun(
                self.rng, self.pos_inds, self.min_rand, self.max_rand
            )
            rnd = Node(rand_x)
        else:  # goal point sampling
            rnd = Node(self.end.sv.reshape((-1, 1)))
        return rnd

    def get_nearest_node_index(self, cur_time, rnd_node, state_args, ctrl_args):
        """Return the index of the tree node closest to rnd_node.

        Also updates S with the DARE solution linearized about rnd_node; S is
        then used as the distance metric.
        """
        self.S = self.planner.solve_dare(
            cur_time,
            rnd_node.sv.reshape((-1, 1)),
            state_args=state_args,
            ctrl_args=ctrl_args,
        )[0]
        dlist = np.array(
            [
                (
                    self.dx(node.sv, rnd_node.sv).T
                    @ self.S
                    @ self.dx(node.sv, rnd_node.sv)
                ).item()
                for node in self.node_list
            ]
        )
        return np.argmin(dlist)

    def call_planner(self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs):
        """Run the planner between two states.

        Keyword arguments override planner_args. Any frames returned by the
        planner are added to the animation frame list.

        Returns
        -------
        x_traj : numpy array
            State trajectory from the planner; callers transpose it so states
            become columns.
        u_traj : numpy array
            Control trajectory from the planner.
        """
        extra_args = {**self.planner_args, **kwargs}

        if self._show_planner and extra_args.get("show_animation", False):
            self.reset_controller_plot()
            self.draw_start(from_state, 1)
            plt.pause(0.01)

        _, _, x_traj, u_traj, _, f_lst = self.planner.calculate_control(
            cur_time,
            from_state.reshape((-1, 1)),
            end_state=to_state.reshape((-1, 1)),
            provide_details=True,
            state_args=state_args,
            ctrl_args=ctrl_args,
            **extra_args,
        )

        self._frame_list.extend(f_lst)
        return x_traj, u_traj

    def steer(
        self,
        cur_time,
        from_node,
        to_node,
        state_args,
        ctrl_args,
        show_controller=False,
        planner_ttl=None,
    ):
        """Plan from from_node toward to_node and build the resulting node.

        Returns
        -------
        :class:`.Node` or None
            Node at the last sampled state, with the sampled path, controls,
            accumulated cost, and from_node as parent; None if the planner
            produced no usable trajectory.
        """
        if self._show_planner:
            self.planner_args["show_animation"] = show_controller
        if self.planner_args.get("show_animation", False):
            self.planner_args["ttl"] = (
                planner_ttl if planner_ttl is not None else "Controller"
            )

        x_sim, u_sim = self.call_planner(
            cur_time, from_node.sv, to_node.sv, state_args, ctrl_args
        )

        x_sim_sample, u_sim_sample, course_lens = self.sample_path(x_sim.T, u_sim.T)
        if len(x_sim_sample) == 0:
            return None
        newNode = Node(x_sim_sample[:, -1].reshape((-1, 1)))
        newNode.u = u_sim_sample
        newNode.path = x_sim_sample
        newNode.cost = from_node.cost + np.sum(np.abs(course_lens))
        newNode.parent = from_node
        return newNode

    def sample_path(self, x_sim, u_sim):
        """Resample a planner trajectory at fractions of step_size per interval.

        Parameters
        ----------
        x_sim : N x Nt numpy array
            State trajectory, one state per column.
        u_sim : Nu x Nt numpy array
            Control trajectory, one control per column.

        Returns
        -------
        x_sim_sample : N x Ns numpy array or list
            Interpolated states; an empty list if nothing could be sampled.
        u_sim_sample : Nu x Ns numpy array or list
            Controls held over each interval; an empty list on failure.
        clen : numpy array or list
            S weighted quadratic form of each consecutive state difference;
            an empty list on failure.
        """
        x_sim_sample = []
        u_sim_sample = []
        if x_sim.size == 0:
            clen = []
            return x_sim_sample, u_sim_sample, clen
        for i in range(x_sim.shape[1] - 1):
            for t in np.arange(0.0, 1.0, self.step_size):
                u_sim_sample.append(u_sim[:, i].tolist())
                x_sample = (t * x_sim[:, i + 1] + (1.0 - t) * x_sim[:, i]).reshape(
                    (-1, 1)
                )[:, 0]
                x_sim_sample.append(x_sample.tolist())

        # enforce shape if x_sim_sample is empty list
        x_sim_sample = np.array(x_sim_sample).T.reshape((x_sim.shape[0], -1))
        u_sim_sample = np.array(u_sim_sample).T.reshape((u_sim.shape[0], -1))

        diff_x_sim2 = [
            self.dx(x_sim_sample[:, k + 1], x_sim_sample[:, k])[:, 0]
            for k in range(x_sim_sample.shape[1] - 1)
        ]
        diff_x_sim = np.array(diff_x_sim2).T
        if diff_x_sim.size == 0:
            return [], [], []
        clen = np.einsum("ij,ij->i", (diff_x_sim.T @ self.S), diff_x_sim.T)
        return x_sim_sample, u_sim_sample, clen
class ExtendedLQRRRTStar(LQRRRTStar):
    """LQR-RRT* variant whose planner calls carry extended-LQR arguments.

    The cost arguments and the inverse state/control arguments given to
    :meth:`plan` are stored on the instance and forwarded to every planner
    call made during the tree search.
    """

    def __init__(self, **kwargs):
        """Initialize the base planner and clear the stored extra arguments.

        Parameters
        ----------
        **kwargs : dict
            Passed unchanged to :class:`LQRRRTStar`.
        """
        super().__init__(**kwargs)
        self._inv_state_args = self._inv_ctrl_args = self._cost_args = None

    def call_planner(
        self, cur_time, from_state, to_state, state_args, ctrl_args, **kwargs
    ):
        """Forward to the base planner call with the stored extra arguments.

        Planner output printing is off unless planner_args or kwargs turn it
        on. The stored cost/inverse arguments always take precedence.
        """
        merged = {"disp": False, **self.planner_args, **kwargs}
        merged["cost_args"] = self._cost_args
        merged["inv_state_args"] = self._inv_state_args
        merged["inv_ctrl_args"] = self._inv_ctrl_args
        return super().call_planner(
            cur_time, from_state, to_state, state_args, ctrl_args, **merged
        )

    def plan(
        self,
        cur_time,
        cur_state,
        end_state,
        inv_state_args=None,
        inv_ctrl_args=None,
        cost_args=None,
        **kwargs
    ):
        """Store the extended planner arguments, then run the base planner.

        The plot title defaults to "ELQR-RRT*" unless ``ttl`` is supplied.
        All remaining keyword arguments go to :meth:`LQRRRTStar.plan`.
        """
        self._inv_state_args, self._inv_ctrl_args, self._cost_args = (
            inv_state_args,
            inv_ctrl_args,
            cost_args,
        )
        kwargs.setdefault("ttl", "ELQR-RRT*")
        return super().plan(cur_time, cur_state, end_state, **kwargs)