Coverage for src/gncpy/control/elqr.py: 66%
266 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
import io
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg as la
from PIL import Image

import gncpy.dynamics.basic as gdyn
import gncpy.math as gmath
import gncpy.plotting as gplot
from gncpy.control.lqr import LQR


class ELQR(LQR):
    """Implements an Extended Linear Quadratic Regulator (ELQR) controller.

    This is based on
    :cite:`Berg2016_ExtendedLQRLocallyOptimalFeedbackControlforSystemswithNonLinearDynamicsandNonQuadraticCost`.

    Attributes
    ----------
    max_iters : int
        Maximum number of iterations to try for convergence.
    tol : float
        Tolerance for convergence.
    ct_come_mats : Nh+1 x N x N numpy array
        Cost-to-come matrices, 1 per step in the time horizon.
    ct_come_vecs : Nh+1 x N x 1 numpy array
        Cost-to-come vectors, 1 per step in the time horizon.
    use_custom_cost : bool
        Flag indicating if a custom cost function should be used.
    gif_frame_skip : int
        Number of frames to skip when saving the gif. Set to 1 to take every
        frame.
    """

    def __init__(self, max_iters=1e3, tol=1e-4, **kwargs):
        """Initialize an object.

        Parameters
        ----------
        max_iters : int, optional
            Maximum number of iterations to try for convergence. The default is
            1e3.
        tol : float, optional
            Tolerance on convergence. The default is 1e-4.
        **kwargs : dict
            Additional arguments, forwarded to the parent class. If
            ``time_horizon`` is not given it is set to 10.
        """
        kwargs.setdefault("time_horizon", 10)

        super().__init__(**kwargs)
        self.max_iters = int(max_iters)
        self.tol = tol

        self.ct_come_mats = np.array([])
        self.ct_come_vecs = np.array([])

        self.use_custom_cost = False
        self.gif_frame_skip = 1
        self._non_quad_fun = None
        self._quad_modifier = None
        self._cost_fun = None

        # index of the axes (within the figure) used for the animation
        self._ax = 0

    def set_cost_model(
        self,
        Q=None,
        R=None,
        non_quadratic_fun=None,
        quad_modifier=None,
        cost_fun=None,
        skip_validity_check=False,
    ):
        r"""Sets the cost model.

        Either `Q` and `R` must be supplied (and optionally `quad_modifier`
        and `non_quadratic_fun`) or `cost_fun`. If `Q` and `R` are given
        without a `non_quadratic_fun` the cost is treated as purely quadratic.

        Notes
        -----
        This assumes the following form for the cost function

        .. math::

            J = x_f^T Q x_f + \int_{t_0}^{t_f} x^T Q x + u^T R u + u^T P x + f(\tau, x, u) d\tau

        where :math:`f(t, x, u)` contains the non-quadratic terms and gets
        automatically quadratized by the algorithm.

        A custom cost function can also be provided for :math:`J` in which case
        the entire function :math:`J(t, x, u)` will be quadratized at every
        timestep.

        Parameters
        ----------
        Q : N x N numpy array, optional
            State cost matrix. The default is None.
        R : Nu x Nu numpy array, optional
            Control cost matrix. The default is None.
        non_quadratic_fun : callable, optional
            Non-quadratic portion of the standard cost function. This should
            have the form
            :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
            and return a scalar additional cost. The default is None, which
            leaves any previously set function in place.
        quad_modifier : callable, optional
            Function to modify the :math:`P, Q, R, q, r` matrices before
            adding the non-quadratic terms. Must have the form
            :code:`f(itr, t, P, Q, R, q, r)` and return a tuple of
            :code:`(P, Q, R, q, r)`. The default is None.
        cost_fun : callable, optional
            Custom cost function, must handle all cases and will be numerically
            quadratized at every timestep. Must have the form
            :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
            and return a scalar total cost. The default is None.
        skip_validity_check : bool
            Flag indicating that no error should be raised if the wrong input
            combination is given.

        Raises
        ------
        RuntimeError
            Invalid combination of input arguments.
        """
        if non_quadratic_fun is not None:
            self._non_quad_fun = non_quadratic_fun

        # unlike the non-quadratic function, the modifier is always overwritten
        self._quad_modifier = quad_modifier

        if Q is not None and R is not None:
            super().set_cost_model(Q, R)
            self.use_custom_cost = False

        elif cost_fun is not None:
            self._cost_fun = cost_fun
            self.use_custom_cost = True

        elif not skip_validity_check:
            raise RuntimeError("Invalid combination of inputs.")

    def cost_function(
        self, tt, state, ctrl_input, cost_args, is_initial=False, is_final=False
    ):
        """Calculates the cost for the state and control input.

        Parameters
        ----------
        tt : float
            Current timestep.
        state : N x 1 numpy array
            Current state.
        ctrl_input : Nu x 1 numpy array
            Current control input.
        cost_args : tuple
            Additional arguments for the non-quadratic part or the custom cost
            function.
        is_initial : bool, optional
            Flag indicating if it's the initial time. The default is False.
        is_final : bool, optional
            Flag indicating if it's the final time. The default is False.

        Returns
        -------
        float
            Cost.
        """
        if self.use_custom_cost:
            return self._cost_fun(
                tt, state, ctrl_input, self.end_state, is_initial, is_final, *cost_args
            )

        cost = super().cost_function(
            tt, state, ctrl_input, is_initial=is_initial, is_final=is_final
        )

        # the non-quadratic term is never applied to the final cost, and it is
        # optional so a purely quadratic cost model must not fail here
        if is_final or self._non_quad_fun is None:
            return cost

        return cost + self._non_quad_fun(
            tt,
            state,
            ctrl_input,
            self.end_state,
            is_initial,
            is_final,
            *cost_args,
        )

    def prop_state_forward(
        self, tt, x_hat, u_hat, state_args, ctrl_args, inv_state_args, inv_ctrl_args
    ):
        """Propagate the state forward and get the backward state space.

        Parameters
        ----------
        tt : float
            Current timestep.
        x_hat : N x 1 numpy array
            Current state.
        u_hat : Nu x 1 numpy array
            Current control input.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Raises
        ------
        NotImplementedError
            If no dynamics object has been set.

        Returns
        -------
        x_hat_p : N x 1 numpy array
            Next state.
        Abar : N x N numpy array
            Backward state transition matrix.
        Bbar : N x Nu
            Backward input matrix.
        cbar : N x 1 numpy array
            Backward c vector.
        """
        x_hat_p = self.prop_state(
            tt, x_hat, u_hat, state_args, ctrl_args, True, inv_state_args, inv_ctrl_args
        )

        if self.dynObj is None:
            raise NotImplementedError("Need to implement this")

        if isinstance(self.dynObj, gdyn.NonlinearDynamicsBase):
            # NOTE(review): dt is flipped on the shared dynamics object and not
            # restored here; presumably prop_state resets it -- confirm.
            if self.dynObj.dt > 0:
                self.dynObj.dt *= -1  # set to inverse dynamics

            ABar = self.dynObj.get_state_mat(
                tt, x_hat_p, *inv_state_args, u=u_hat, ctrl_args=ctrl_args
            )
            BBar = self.dynObj.get_input_mat(tt, x_hat_p, u_hat, *inv_ctrl_args)

        else:
            # invert the forward model x_p = A x + B u to get x = ABar x_p + BBar u
            A, B = self.get_state_space(tt, x_hat_p, u_hat, state_args, ctrl_args)
            ABar = la.inv(A)
            BBar = -ABar @ B

        # offset that makes the backward model exact at the linearization point
        cBar = x_hat - ABar @ x_hat_p - BBar @ u_hat

        return x_hat_p, ABar, BBar, cBar

    def quadratize_cost(self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args):
        """Quadratizes the cost function.

        If the Hessian of the non-quadratic portion of the standard cost
        function is not positive semi-definite then it is regularized by
        setting the negative eigenvalues to zero and reconstructing the matrix.
        The non-quadratic portion is only included for intermediate timesteps
        (neither initial nor final) and only if one has been set.

        Parameters
        ----------
        tt : float
            Current timestep.
        itr : int
            Iteration number.
        x_hat : N x 1 numpy array
            Current state.
        u_hat : Nu x 1 numpy array
            Current control input.
        is_initial : bool
            Flag indicating if this is the first timestep.
        is_final : bool
            Flag indicating if this is the last timestep.
        cost_args : tuple
            Additional arguments for the cost calculation.

        Raises
        ------
        NotImplementedError
            Unsupported configuration (custom cost functions).

        Returns
        -------
        P : Nu x N numpy array
            State and control correlation cost matrix.
        Q : N x N numpy array
            State cost matrix.
        R : Nu x Nu numpy array
            Control input cost matrix.
        q : N x 1
            State cost vector.
        r : Nu x 1
            Control input cost vector.
        """
        if self.use_custom_cost:
            # TODO: get hessians
            raise NotImplementedError("Need to implement this")

        P, Q, R, q, r = super()._determine_cost_matrices(
            tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
        )

        if is_initial or is_final:
            return P, Q, R, q, r

        if self._quad_modifier is not None:
            P, Q, R, q, r = self._quad_modifier(itr, tt, P, Q, R, q, r)

        if self._non_quad_fun is None:
            return P, Q, R, q, r

        xdim = x_hat.size
        udim = u_hat.size
        comb_state = np.vstack((x_hat, u_hat)).ravel()

        def non_quad(_x, *_args):
            # evaluate the non-quadratic cost on the stacked [state; control]
            return self._non_quad_fun(
                tt,
                _x[:xdim],
                _x[xdim:],
                self.end_state,
                is_initial,
                is_final,
                *cost_args,
            )

        big_mat = gmath.get_hessian(comb_state, non_quad)

        # regularize hessian to keep it pos semi def; a numerical Hessian is
        # only approximately symmetric, so symmetrize first and use eigh, which
        # guarantees real eigenvalues and orthonormal eigenvectors
        big_mat = 0.5 * (big_mat + big_mat.T)
        vals, vecs = np.linalg.eigh(big_mat)
        vals[vals < 0] = 0
        big_mat = (vecs * vals) @ vecs.T

        # extract non-quadratic terms
        non_Q = big_mat[:xdim, :xdim]
        non_P = big_mat[xdim:, :xdim]
        non_R = big_mat[xdim:, xdim:]

        big_vec = gmath.get_jacobian(comb_state, non_quad)

        # shift the gradient so the quadratic model is expressed about the origin
        non_q = big_vec[:xdim].reshape((xdim, 1)) - (non_Q @ x_hat + non_P.T @ u_hat)
        non_r = big_vec[xdim:].reshape((udim, 1)) - (non_P @ x_hat + non_R @ u_hat)

        # out-of-place sums so arrays owned by the parent class or the
        # modifier are never mutated
        return P + non_P, Q + non_Q, R + non_R, q + non_q, r + non_r

    def _optimal_state(self, kk):
        """Get the state minimizing the total cost at a timestep.

        Parameters
        ----------
        kk : int
            Index of the timestep.

        Returns
        -------
        N x 1 numpy array
            Minimizer of the sum of the cost-to-go and cost-to-come at `kk`.
        """
        return -np.linalg.solve(
            self.ct_go_mats[kk] + self.ct_come_mats[kk],
            self.ct_go_vecs[kk] + self.ct_come_vecs[kk],
        )

    def quadratize_final_cost(self, itr, num_timesteps, traj, time_vec, cost_args):
        """Quadratizes the final cost and updates the final state.

        Sets the last cost-to-go matrix and vector from the quadratized final
        cost and then replaces the last state of the trajectory with the
        minimizer of the total cost, applying the state constraint of the
        dynamics object if hard constraints are enabled.

        Parameters
        ----------
        itr : int
            Iteration number.
        num_timesteps : int
            Total number of timesteps.
        traj : Nh+1 x N numpy array
            State trajectory, modified in place.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        cost_args : tuple
            Extra arguments for the cost function.

        Returns
        -------
        traj : Nh+1 x N numpy array
            State trajectory.
        """
        # NOTE(review): the linearization point is the second to last state,
        # not traj[num_timesteps]; kept as is -- confirm this is intended.
        x_hat = traj[num_timesteps - 1, :].reshape((-1, 1))
        u_hat = (
            self.feedback_gain[num_timesteps] @ x_hat
            + self.feedthrough_gain[num_timesteps]
        )
        (
            _,
            self.ct_go_mats[num_timesteps],
            _,
            self.ct_go_vecs[num_timesteps],
            _,
        ) = self.quadratize_cost(
            time_vec[-1],
            itr,
            x_hat,
            u_hat,
            False,
            True,
            cost_args,
        )
        traj[num_timesteps, :] = self._optimal_state(num_timesteps).ravel()

        if (
            self.hard_constraints
            and self.dynObj is not None
            and self.dynObj.state_constraint is not None
        ):
            traj[num_timesteps, :] = self.dynObj.state_constraint(
                time_vec[num_timesteps], traj[num_timesteps, :].reshape((-1, 1))
            ).ravel()

        return traj

    def _determine_cost_matrices(
        self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
    ):
        """Get the quadratized cost matrices, see :meth:`quadratize_cost`."""
        return self.quadratize_cost(
            tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
        )

    def _back_pass_update_traj(self, x_hat_p, kk):
        """Get the updated state for timestep `kk` as a flat array.

        `x_hat_p` is unused here; the state is the minimizer of the total cost.
        """
        return self._optimal_state(kk).ravel()

    def forward_pass_step(
        self,
        itr,
        kk,
        time_vec,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Performs a single step of the forward pass.

        Updates the feedback and feedthrough gains at `kk` and the cost-to-come
        matrix and vector at `kk + 1`.

        Parameters
        ----------
        itr : int
            Iteration number.
        kk : int
            Index of the current timestep.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        traj : Nh+1 x N numpy array
            State trajectory.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        cost_args : tuple
            Extra arguments for the cost function.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        N numpy array
            State for timestep `kk + 1`.
        """
        tt = time_vec[kk]
        x_hat = traj[kk, :].reshape((-1, 1))

        u_hat = self.feedback_gain[kk] @ x_hat + self.feedthrough_gain[kk]
        if self.hard_constraints and self.control_constraints is not None:
            u_hat = self.control_constraints(tt, u_hat)

        x_hat_p, ABar, BBar, cBar = self.prop_state_forward(
            tt,
            x_hat,
            u_hat,
            state_args,
            ctrl_args,
            inv_state_args,
            inv_ctrl_args,
        )

        # final cost is handled after the forward pass
        P, Q, R, q, r = self._determine_cost_matrices(
            tt, itr, x_hat, u_hat, kk == 0, False, cost_args
        )

        ctm_Q = self.ct_come_mats[kk] + Q
        ctm_Q_A = ctm_Q @ ABar
        ctv_q_ctm_Q_c = self.ct_come_vecs[kk] + q + ctm_Q @ cBar
        CBar = BBar.T @ ctm_Q_A + P @ ABar
        DBar = ABar.T @ ctm_Q_A
        EBar = BBar.T @ ctm_Q @ BBar + R + P @ BBar + BBar.T @ P.T
        dBar = ABar.T @ ctv_q_ctm_Q_c
        eBar = BBar.T @ ctv_q_ctm_Q_c + r + P @ cBar

        neg_inv_EBar = -np.linalg.inv(EBar)
        self.feedback_gain[kk] = neg_inv_EBar @ CBar
        self.feedthrough_gain[kk] = neg_inv_EBar @ eBar

        self.ct_come_mats[kk + 1] = DBar + CBar.T @ self.feedback_gain[kk]
        self.ct_come_vecs[kk + 1] = dBar + CBar.T @ self.feedthrough_gain[kk]

        out = self._optimal_state(kk + 1)
        if (
            self.hard_constraints
            and self.dynObj is not None
            and self.dynObj.state_constraint is not None
        ):
            out = self.dynObj.state_constraint(time_vec[kk + 1], out)
        return out.ravel()

    def forward_pass(
        self,
        itr,
        num_timesteps,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        time_vec,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Forward pass for the smoothing.

        Parameters
        ----------
        itr : int
            iteration number.
        num_timesteps : int
            Total number of timesteps.
        traj : Nh+1 x N numpy array
            State trajectory.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        cost_args : tuple
            Extra arguments for the cost function.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        traj : Nh+1 x N numpy array
            State trajectory.
        """
        for kk in range(num_timesteps):
            traj[kk + 1, :] = self.forward_pass_step(
                itr,
                kk,
                time_vec,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                inv_state_args,
                inv_ctrl_args,
            )

        return traj

    @staticmethod
    def _capture_frame(fig, fig_w, fig_h):
        """Render the figure to a PIL image of size `fig_w` x `fig_h`.

        Parameters
        ----------
        fig : matplotlib figure
            Figure to capture.
        fig_w : int
            Width of the output frame in pixels.
        fig_h : int
            Height of the output frame in pixels.

        Returns
        -------
        PIL.Image
            Captured frame, resized if the canvas no longer matches the
            requested size so that all frames of an animation agree.
        """
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            # raw pixel rows are stored height first, then width
            height_width = fig.canvas.get_width_height()[1::-1]
            arr = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (*height_width, -1)
            )

        img = Image.fromarray(arr)
        # PIL reports size as (width, height)
        if img.size != (fig_w, fig_h):
            img = img.resize((fig_w, fig_h), Image.BICUBIC)
        return img

    def draw_traj(
        self,
        fig,
        plt_inds,
        fig_h,
        fig_w,
        save_animation,
        num_timesteps,
        time_vec,
        state_args,
        ctrl_args,
        inv_state_args,
        inv_ctrl_args,
        color=None,
        alpha=0.2,
        zorder=-10,
    ):
        """Draws the trajectory resulting from the current gains.

        The trajectory is propagated from the initial state using the current
        feedback and feedthrough gains (no control constraints are applied).

        Parameters
        ----------
        fig : matplotlib figure
            Figure to draw on.
        plt_inds : list
            2 or 3 element list of state indices for plotting.
        fig_h : int
            Height of the saved frame in pixels.
        fig_w : int
            Width of the saved frame in pixels.
        save_animation : bool
            Flag indicating if the frame should be captured and returned.
        num_timesteps : int
            Total number of timesteps.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.
        color : tuple, optional
            Line color. The default is None which uses gray.
        alpha : float, optional
            Line transparency. The default is 0.2.
        zorder : int, optional
            Drawing order of the line. The default is -10.

        Returns
        -------
        PIL.Image or list
            The captured frame if `save_animation` is set, otherwise an empty
            list.
        """
        if color is None:
            color = (0.5, 0.5, 0.5)
        state_traj = np.nan * np.ones((num_timesteps + 1, self._init_state.size))
        state_traj[0, :] = self._init_state.flatten()
        for kk, t_k in enumerate(time_vec[:-1]):
            x_k = state_traj[kk, :].reshape((-1, 1))
            u_k = self.feedback_gain[kk] @ x_k + self.feedthrough_gain[kk]
            state_traj[kk + 1, :] = self.prop_state(
                t_k,
                x_k,
                u_k,
                state_args,
                ctrl_args,
                True,
                inv_state_args,
                inv_ctrl_args,
            ).ravel()

        # plot backward pass trajectory
        extra_args = dict(
            color=color,
            alpha=alpha,
            zorder=zorder,
        )
        if len(plt_inds) == 3:
            fig.axes[self._ax].plot(
                state_traj[:, plt_inds[0]],
                state_traj[:, plt_inds[1]],
                state_traj[:, plt_inds[2]],
                **extra_args,
            )
        else:
            fig.axes[self._ax].plot(
                state_traj[:, plt_inds[0]],
                state_traj[:, plt_inds[1]],
                **extra_args,
            )

        if matplotlib.get_backend() != "Agg":
            plt.pause(0.005)
        if save_animation:
            return self._capture_frame(fig, fig_w, fig_h)

        return []

    def reset(self, tt, cur_state, end_state):
        """Reset values for calculating the control parameters.

        This generally does not need to be called outside of the class. It is
        automatically called when calculating the control parameters.

        Parameters
        ----------
        tt : float
            Current time when starting the control calculation.
        cur_state : N x 1 numpy array
            Current state.
        end_state : N x 1 numpy array
            Desired/reference ending state.

        Returns
        -------
        old_cost : float
            prior cost.
        num_timesteps : int
            number of timesteps in the horizon.
        traj : Nh+1 x N numpy array
            state trajectory, each row is a timestep.
        time_vec : Nh+1 numpy array
            time vector.
        """
        old_cost = float("inf")
        # use the magnitude of dt so the step count agrees with the time vector
        abs_dt = np.abs(self.dt)
        num_timesteps = int(self.time_horizon / abs_dt)
        self._init_state = cur_state.reshape((-1, 1)).copy()
        self.end_state = end_state.reshape((-1, 1)).copy()
        traj = np.nan * np.ones((num_timesteps + 1, cur_state.size))
        traj[0, :] = cur_state.flatten()

        self.ct_come_mats = np.zeros(
            (num_timesteps + 1, cur_state.size, cur_state.size)
        )
        self.ct_come_vecs = np.zeros((num_timesteps + 1, cur_state.size, 1))
        self.ct_go_mats = np.zeros(self.ct_come_mats.shape)
        self.ct_go_vecs = np.zeros(self.ct_come_vecs.shape)

        self.feedback_gain = np.zeros(
            (num_timesteps + 1, self.u_nom.size, cur_state.size)
        )
        self.feedthrough_gain = self.u_nom * np.ones(
            (num_timesteps + 1, self.u_nom.size, 1)
        )

        # one entry per step, spaced by dt, so the vector spans the whole
        # horizon; integer multiples avoid the length issues of a float arange
        time_vec = tt + abs_dt * np.arange(num_timesteps + 1)

        return old_cost, num_timesteps, traj, time_vec

    def _rollout(
        self,
        init_state,
        num_timesteps,
        time_vec,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Simulate the closed loop system with the current gains.

        Control constraints are applied whenever they are set.

        Parameters
        ----------
        init_state : numpy array
            State to start the simulation from (N elements).
        num_timesteps : int
            Total number of timesteps.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        cost_args : tuple
            Extra arguments for the cost function.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        cost : float
            Total cost of the trajectory, including the final cost.
        state_traj : Nh+1 x N numpy array
            State trajectory.
        ctrl_signal : Nh x Nu numpy array
            Control inputs.
        """
        ctrl_signal = np.nan * np.ones((num_timesteps, self.u_nom.size))
        state_traj = np.nan * np.ones((num_timesteps + 1, init_state.size))
        state_traj[0, :] = init_state.ravel()

        cost = 0
        for kk, t_k in enumerate(time_vec[:-1]):
            x_k = state_traj[kk, :].reshape((-1, 1))
            u_k = self.feedback_gain[kk] @ x_k + self.feedthrough_gain[kk]
            if self.control_constraints is not None:
                u_k = self.control_constraints(t_k, u_k)
            ctrl_signal[kk, :] = np.ravel(u_k)
            u_k = ctrl_signal[kk, :].reshape((-1, 1))

            cost += self.cost_function(
                t_k,
                x_k,
                u_k,
                cost_args,
                is_initial=(kk == 0),
                is_final=False,
            )
            state_traj[kk + 1, :] = self.prop_state(
                t_k,
                x_k,
                u_k,
                state_args,
                ctrl_args,
                True,
                inv_state_args,
                inv_ctrl_args,
            ).ravel()

        # the final cost reuses the last applied control input
        cost += self.cost_function(
            time_vec[-1],
            state_traj[num_timesteps, :].reshape((-1, 1)),
            ctrl_signal[num_timesteps - 1, :].reshape((-1, 1)),
            cost_args,
            is_initial=False,
            is_final=True,
        )

        return cost, state_traj, ctrl_signal

    def calculate_control(
        self,
        tt,
        cur_state,
        end_state,
        state_args=None,
        ctrl_args=None,
        cost_args=None,
        inv_state_args=None,
        inv_ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        plt_inds=None,
        ax_num=0,
    ):
        """Calculate the control parameters and state trajectory.

        Parameters
        ----------
        tt : float
            Current time when starting the control calculation.
        cur_state : N x 1 numpy array
            Current state.
        end_state : N x 1 numpy array
            Desired/reference ending state.
        state_args : tuple, optional
            Extra arguments for calculating the state matrix. The default is None.
        ctrl_args : tuple, optional
            Extra arguments for calculating the input matrix. The default is None.
        cost_args : tuple, optional
            Extra arguments for the cost function, either the non-quadratic
            part or the custom function. The default is None.
        inv_state_args : tuple, optional
            Extra arguments for calculating the inverse state matrix. The
            default is None.
        inv_ctrl_args : tuple, optional
            Extra arguments for calculating the inverse input matrix. The
            default is None.
        provide_details : bool, optional
            Flag indicating if additional outputs should be provided. The
            default is False.
        disp : bool, optional
            Flag indicating if extra text should be displayed during the
            optimization. The default is True.
        show_animation : bool, optional
            Flag indicating if an animation should be shown during the
            optimization. The default is False.
        save_animation : bool, optional
            Flag indicating if the frames of the animation should be saved.
            Requires that the animation is shown. The default is False.
        plt_opts : dict, optional
            Options for the plot, see :func:`gncpy.plotting.init_plotting_opts`.
            The default is None.
        ttl : string, optional
            Title for the plot. The default is None.
        fig : matplotlib figure, optional
            Figure to draw on. If supplied only the end point and paths are
            drawn, and the title is updated if supplied; the rest of the figure
            remains unchanged. The default is None which creates a new figure
            and adds a title. If no title is supplied a default is used.
        plt_inds : list, optional
            2 (or 3 for a 3d plot) element list of state indices for plotting.
            The default is None which assumes [0, 1].
        ax_num : int, optional
            Index of the axes within the figure to draw on. The default is 0.

        Returns
        -------
        u : Nu x 1 numpy array
            Control input for current timestep.
        cost : float, optional
            Cost of the trajectory
        state_traj : Nh+1 x N numpy array, optional
            State trajectory over the horizon, or until the reference is reached.
        ctrl_signal : Nh x Nu numpy array, optional
            Control inputs over the horizon, or until the reference is reached.
        fig : matplotlib figure, optional
            Figure handle for the generated plot, None if animation is not shown.
        frame_list : list, optional
            Each element is a PIL.Image corresponding to an animation frame.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if cost_args is None:
            cost_args = ()
        if inv_state_args is None:
            inv_state_args = ()
        if inv_ctrl_args is None:
            inv_ctrl_args = ()

        if plt_inds is None:
            plt_inds = [0, 1]
        self._ax = ax_num

        old_cost, num_timesteps, traj, time_vec = self.reset(tt, cur_state, end_state)
        rollout_args = (
            num_timesteps,
            time_vec,
            state_args,
            ctrl_args,
            cost_args,
            inv_state_args,
            inv_ctrl_args,
        )

        frame_list = []
        if show_animation:
            if fig is None:
                fig = plt.figure()
                if len(plt_inds) == 3:
                    extra_args = {"projection": "3d"}
                else:
                    extra_args = {}
                fig.add_subplot(1, 1, 1, **extra_args)
                try:
                    fig.axes[self._ax].set_aspect("equal", adjustable="box")
                except Exception:
                    # best effort only, not every axes type supports this
                    pass

                if plt_opts is None:
                    plt_opts = gplot.init_plotting_opts(f_hndl=fig)

                if ttl is None:
                    ttl = "ELQR"

                gplot.set_title_label(
                    fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
                )

                # draw start
                fig.axes[self._ax].scatter(
                    *self._init_state.ravel()[plt_inds],
                    marker="o",
                    color="g",
                    zorder=1000,
                )
                fig.tight_layout()
            elif ttl is not None:
                # the title helper needs plot options even for a supplied figure
                if plt_opts is None:
                    plt_opts = gplot.init_plotting_opts(f_hndl=fig)
                gplot.set_title_label(
                    fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
                )

            fig.axes[self._ax].scatter(
                *self.end_state.ravel()[plt_inds],
                marker="x",
                color="r",
                zorder=1000,
            )
            if matplotlib.get_backend() != "Agg":
                plt.pause(0.01)

            # for stopping simulation with the esc key.
            def _exit_on_escape(event):
                if event.key == "escape":
                    sys.exit(0)

            fig.canvas.mpl_connect("key_release_event", _exit_on_escape)
            fig_w, fig_h = fig.canvas.get_width_height()

            # save first frame of animation
            if save_animation:
                frame_list.append(self._capture_frame(fig, fig_w, fig_h))

        if disp:
            print("Starting ELQR optimization loop...")

        for ii in range(self.max_iters):
            # forward pass
            traj = self.forward_pass(
                ii,
                num_timesteps,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                time_vec,
                inv_state_args,
                inv_ctrl_args,
            )

            # quadratize final cost
            traj = self.quadratize_final_cost(
                ii, num_timesteps, traj, time_vec, cost_args
            )

            # backward pass
            traj = self.backward_pass(
                ii,
                num_timesteps,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                time_vec,
                inv_state_args,
                inv_ctrl_args,
            )

            # get cost
            cost, _, _ = self._rollout(traj[0, :], *rollout_args)

            if show_animation:
                save_frame = save_animation and ii % self.gif_frame_skip == 0
                img = self.draw_traj(
                    fig,
                    plt_inds,
                    fig_h,
                    fig_w,
                    save_frame,
                    num_timesteps,
                    time_vec,
                    state_args,
                    ctrl_args,
                    inv_state_args,
                    inv_ctrl_args,
                )
                if save_frame:
                    frame_list.append(img)

            if disp:
                print("\tIteration: {:3d} Cost: {:10.4f}".format(ii, cost))

            # check for convergence, a zero cost can not be used as the
            # denominator of the relative change
            if cost == 0:
                converged = old_cost == 0
            else:
                converged = np.abs((old_cost - cost) / cost) < self.tol
            if converged:
                break
            old_cost = cost

        # create outputs and return
        cost, state_traj, ctrl_signal = self._rollout(self._init_state, *rollout_args)

        if show_animation:
            extra_args = dict(
                linestyle="-",
                color="g",
            )
            if len(plt_inds) == 3:
                fig.axes[self._ax].plot(
                    state_traj[:, plt_inds[0]],
                    state_traj[:, plt_inds[1]],
                    state_traj[:, plt_inds[2]],
                    **extra_args,
                )
            else:
                fig.axes[self._ax].plot(
                    state_traj[:, plt_inds[0]], state_traj[:, plt_inds[1]], **extra_args
                )
            if matplotlib.get_backend() != "Agg":
                plt.pause(0.001)
            if save_animation:
                frame_list.append(self._capture_frame(fig, fig_w, fig_h))

        u = ctrl_signal[0, :].reshape((-1, 1))
        details = (cost, state_traj, ctrl_signal, fig, frame_list)
        return (u, *details) if provide_details else u