Coverage for src/gncpy/control/lqr.py: 92%
import numpy as np
import scipy.linalg as la

import gncpy.dynamics.basic as gdyn


class LQR:
8 r"""Implements a Linear Quadratic Regulator (LQR) controller.
10 Notes
11 -----
12 It can be either an infinite horizon where a single controller gain :math:`K`
13 is found or a finite horizon where value iteration is preformed during the
14 backward pass to calculate a controller gain matrix at every step of the
15 time horizon. The finite horizon case is consistent with a Receding Horizon
16 formulation of the controller. If non-linear dyanmics are supplied then they
17 are linearized at every step of the finite horizon, or once about the initial
18 state for infinite horizon. This can also track references by supplying an
19 end state. If infinite horizon the the control input is given by
21 .. math::
23 u = K (r - x) + k
25 if the problem is finite horizon the control input (with or without reference
26 tracking) is given by
28 .. math::
30 u = K x + k

    Attributes
    ----------
    dynObj : :class:`gncpy.dynamics.basic.DynamicsBase`
        Dynamics model to generate control parameters for.
    time_horizon : float
        Length of the time horizon for the controller.
    u_nom : Nu x 1 numpy array
        Nominal control input.
    ct_go_mats : Nh+1 x N x N numpy array
        Cost-to-go matrices, 1 per step in the time horizon.
    ct_go_vecs : Nh+1 x N x 1 numpy array
        Cost-to-go vectors, 1 per step in the time horizon.
    feedback_gain : (Nh) x Nu x N numpy array
        Feedback gain matrix. If finite horizon there is 1 per timestep.
    feedthrough_gain : (Nh) x Nu x 1 numpy array
        Feedthrough gain vector. If finite horizon there is 1 per timestep.
    end_state : N x 1 numpy array
        Ending state. This generally does not need to be set directly.
    hard_constraints : bool
        Flag indicating that state constraints should be enforced during
        value propagation.
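
    Examples
    --------
    A minimal finite-horizon sketch. The values are illustrative, and the
    ``DoubleIntegrator`` model with ``(dt,)`` state arguments is an assumption
    about the :mod:`gncpy.dynamics.basic` interface that may need adapting to
    the model in use::

        import numpy as np
        import gncpy.dynamics.basic as gdyn
        from gncpy.control.lqr import LQR

        dt = 0.1
        lqr = LQR(time_horizon=2.0)
        lqr.set_state_model(
            np.zeros((2, 1)), dynObj=gdyn.DoubleIntegrator(), dt=dt
        )
        lqr.set_cost_model(np.diag([10.0, 10.0, 1.0, 1.0]), 0.01 * np.eye(2))
        u, cost, state_traj, ctrl_signal = lqr.calculate_control(
            0.0,
            np.array([[1.0], [-1.0], [0.0], [0.0]]),
            end_state=np.zeros((4, 1)),
            state_args=(dt,),
            provide_details=True,
        )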
52 """

    def __init__(self, time_horizon=float("inf"), hard_constraints: bool = False):
        """Initialize an object.

        Parameters
        ----------
        time_horizon : float, optional
            Time horizon for the controller. The default is float("inf").
        hard_constraints : bool, optional
            Flag indicating that state constraints should be enforced during
            value propagation.
        """
        super().__init__()

        self._Q = None
        self._R = None
        self._P = None

        self.dynObj = None
        self._dt = None

        self.time_horizon = time_horizon

        self.u_nom = np.array([])
        self.ct_go_mats = np.array([])
        self.ct_go_vecs = np.array([])
        self.feedback_gain = np.array([])
        self.feedthrough_gain = np.array([])

        self._init_state = np.array([])
        self.end_state = np.array([])
        self.hard_constraints = hard_constraints
        self.control_constraints = None

    @property
    def dt(self):
        """Timestep."""
        if self.dynObj is not None and isinstance(
            self.dynObj, gdyn.NonlinearDynamicsBase
        ):
            return self.dynObj.dt
        else:
            return self._dt

    @dt.setter
    def dt(self, val):
        if self.dynObj is not None and isinstance(
            self.dynObj, gdyn.NonlinearDynamicsBase
        ):
            self.dynObj.dt = val
        else:
            self._dt = val

    @property
    def Q(self):
        """Read-only state penalty matrix."""
        return self._Q

    @property
    def R(self):
        """Read-only control penalty matrix."""
        return self._R

    def set_state_model(self, u_nom, control_constraints=None, dynObj=None, dt=None):
        """Set the state/dynamics model.

        Parameters
        ----------
        u_nom : Nu x 1 numpy array
            Nominal control input.
        control_constraints : callable, optional
            Function that takes in the timestep and control signal and returns
            the constrained control signal as a Nu x 1 numpy array. The
            default is None.
        dynObj : :class:`gncpy.dynamics.basic.DynamicsBase`, optional
            System dynamics to control. The default is None.
        dt : float, optional
            Timestep to use. Will update the dynamics object if applicable.
            The default is None.
        """
        self.u_nom = u_nom.reshape((-1, 1))
        self.dynObj = dynObj
        if dt is not None:
            self.dt = dt
        if control_constraints is not None:
            self.control_constraints = control_constraints

    def set_cost_model(self, Q, R, P=None):
        r"""Set the cost model used.

        Notes
        -----
        This implements an LQR controller for the cost function

        .. math::
            J = \int_{0}^{t_f} x^T Q x + u^T R u + u^T P x \, d\tau

        Parameters
        ----------
        Q : N x N numpy array
            State cost matrix.
        R : Nu x Nu numpy array
            Control input cost matrix.
        P : Nu x N numpy array, optional
            State and control correlation cost matrix. The default is None,
            which gives a zero matrix.
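
        Examples
        --------
        An illustrative sketch for a hypothetical 4-state, 2-input system
        (the weight values are placeholders, not recommendations)::

            lqr.set_cost_model(np.diag([10.0, 10.0, 1.0, 1.0]), 0.01 * np.eye(2))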
158 """
159 if Q.shape[0] != Q.shape[1]:
160 raise RuntimeError("Q must b a square matrix!")
161 if R.shape[0] != R.shape[1]:
162 raise RuntimeError("R must b a square matrix!")
163 self._Q = Q
164 self._R = R
166 if P is None:
167 self._P = np.zeros((self._R.shape[0], self._Q.shape[0]))
168 else:
169 self._P = P

    def prop_state(
        self,
        tt,
        x_hat,
        u_hat,
        state_args,
        ctrl_args,
        forward,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Propagate the state in time."""
        if self.dynObj is not None:
            if isinstance(self.dynObj, gdyn.NonlinearDynamicsBase):
                if forward and self.dynObj.dt < 0:
                    self.dynObj.dt *= -1  # set to go forward
                elif not forward and self.dynObj.dt > 0:
                    self.dynObj.dt *= -1  # set to go backward

                if forward:
                    return self.dynObj.propagate_state(
                        tt, x_hat, u=u_hat, state_args=state_args, ctrl_args=ctrl_args
                    )
                else:
                    return self.dynObj.propagate_state(
                        tt,
                        x_hat,
                        u=u_hat,
                        state_args=inv_state_args,
                        ctrl_args=inv_ctrl_args,
                    )
            else:
                if forward:
                    return self.dynObj.propagate_state(
                        tt, x_hat, u=u_hat, state_args=state_args, ctrl_args=ctrl_args
                    )
                else:
                    # invert the linear dynamics to step backward in time
                    A, B = self.get_state_space(tt, x_hat, u_hat, state_args, ctrl_args)
                    prev_state = la.inv(A) @ (x_hat - B @ u_hat)
                    if self.dynObj.state_constraint is not None:
                        prev_state = self.dynObj.state_constraint(tt, prev_state)
                    return prev_state

        else:
            raise NotImplementedError()

    def get_state_space(self, tt, x_hat, u_hat, state_args, ctrl_args):
        """Return the :math:`A, B` matrices.

        Parameters
        ----------
        tt : float
            Current timestep.
        x_hat : N x 1 numpy array
            Current state.
        u_hat : Nu x 1 numpy array
            Current control input.
        state_args : tuple
            Extra arguments for the state calculation.
        ctrl_args : tuple
            Extra arguments for the control calculation.

        Raises
        ------
        NotImplementedError
            If unsupported dynamics are used.

        Returns
        -------
        A : N x N numpy array
            State matrix.
        B : N x Nu numpy array
            Input matrix.
        """
        if self.dynObj is not None:
            if isinstance(self.dynObj, gdyn.NonlinearDynamicsBase):
                if self.dynObj.dt < 0:
                    self.dynObj.dt *= -1  # flip back to forward to get forward matrices

                A = self.dynObj.get_state_mat(
                    tt, x_hat, *state_args, u=u_hat, ctrl_args=ctrl_args
                )
                B = self.dynObj.get_input_mat(
                    tt, x_hat, u_hat, state_args=state_args, ctrl_args=ctrl_args
                )

            else:
                A = self.dynObj.get_state_mat(tt, *state_args)
                B = self.dynObj.get_input_mat(tt, *ctrl_args)

        else:
            raise NotImplementedError("Need to implement this case")

        return A, B

    def prop_state_backward(
        self, tt, x_hat, u_hat, state_args, ctrl_args, inv_state_args, inv_ctrl_args
    ):
        """Propagate the state backward and get the forward state space.

        Parameters
        ----------
        tt : float
            Future timestep.
        x_hat : N x 1 numpy array
            Future state.
        u_hat : Nu x 1 numpy array
            Future control input.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        x_hat_p : N x 1 numpy array
            Previous state.
        A : N x N numpy array
            Forward state transition matrix.
        B : N x Nu numpy array
            Forward input matrix.
        c : N x 1 numpy array
            Forward affine offset vector, :math:`c = x_{k+1} - A x_k - B u_k`.
        """
        x_hat_p = self.prop_state(
            tt,
            x_hat,
            u_hat,
            state_args,
            ctrl_args,
            False,
            inv_state_args,
            inv_ctrl_args,
        )

        A, B = self.get_state_space(tt, x_hat_p, u_hat, state_args, ctrl_args)
        c = x_hat - A @ x_hat_p - B @ u_hat

        return x_hat_p, A, B, c

    def _determine_cost_matrices(
        self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
    ):
        """Calculate the cost matrices."""
        P = self._P
        if is_final:
            # penalize deviation from the end state; no control cost at the
            # final step
            Q = self._Q
            q = -(Q @ self.end_state)
            R = np.zeros(self._R.shape)
            r = np.zeros(self.u_nom.shape)
        else:
            if is_initial:
                Q = self._Q
                q = -(Q @ self._init_state)
                R = self._R
                r = -(R @ self.u_nom)
            else:
                # intermediate steps only penalize deviation from the nominal
                # control
                Q = np.zeros(self._Q.shape)
                q = np.zeros(self._init_state.shape)
                R = self._R
                r = -(R @ self.u_nom)
        return P, Q, R, q, r

    def _back_pass_update_traj(self, x_hat_p, kk):
        """Get the next state for the backward pass, helpful for inherited classes."""
        return x_hat_p.ravel()

    def backward_pass_step(
        self,
        itr,
        kk,
        time_vec,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Perform a single step of the backward pass.

        See :meth:`.backward_pass` for a description of the arguments.
        """
        tt = time_vec[kk]
        u_hat = (
            self.feedback_gain[kk] @ traj[kk + 1, :].reshape((-1, 1))
            + self.feedthrough_gain[kk]
        )
        if self.hard_constraints and self.control_constraints is not None:
            u_hat = self.control_constraints(tt, u_hat)
        x_hat_p, A, B, c = self.prop_state_backward(
            tt,
            traj[kk + 1, :].reshape((-1, 1)),
            u_hat,
            state_args,
            ctrl_args,
            inv_state_args,
            inv_ctrl_args,
        )

        P, Q, R, q, r = self._determine_cost_matrices(
            tt, itr, x_hat_p, u_hat, kk == 0, False, cost_args
        )

        # quadratic expansion of the cost-to-go about the current step
        ctm_A = self.ct_go_mats[kk + 1] @ A
        ctv_ctm_c = self.ct_go_vecs[kk + 1] + self.ct_go_mats[kk + 1] @ c
        C = P + B.T @ ctm_A
        D = Q + A.T @ ctm_A
        E = R + B.T @ self.ct_go_mats[kk + 1] @ B
        d = q + A.T @ ctv_ctm_c
        e = r + B.T @ ctv_ctm_c

        # minimizing over the control gives the gains and the Riccati
        # recursion for the cost-to-go terms
        neg_inv_E = -np.linalg.inv(E)
        self.feedback_gain[kk] = neg_inv_E @ C
        self.feedthrough_gain[kk] = neg_inv_E @ e

        self.ct_go_mats[kk] = D + C.T @ self.feedback_gain[kk]
        self.ct_go_vecs[kk] = d + C.T @ self.feedthrough_gain[kk]

        out = self._back_pass_update_traj(x_hat_p, kk)
        if (
            self.hard_constraints
            and self.dynObj is not None
            and self.dynObj.state_constraint is not None
        ):
            out = self.dynObj.state_constraint(time_vec[kk], out.reshape((-1, 1)))
        return out.ravel()

    def backward_pass(
        self,
        itr,
        num_timesteps,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        time_vec,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Backward pass for the finite horizon case.

        Parameters
        ----------
        itr : int
            Iteration number.
        num_timesteps : int
            Total number of timesteps.
        traj : Nh+1 x N numpy array
            State trajectory.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        cost_args : tuple
            Extra arguments for the cost function. Not used for LQR.
        time_vec : Nh+1 numpy array
            Time vector for the control horizon.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        traj : Nh+1 x N numpy array
            State trajectory.
        """
        for kk in range(num_timesteps - 1, -1, -1):
            traj[kk, :] = self.backward_pass_step(
                itr,
                kk,
                time_vec,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                inv_state_args,
                inv_ctrl_args,
            )

        return traj

    def cost_function(self, tt, state, ctrl_input, is_initial=False, is_final=False):
        """Calculate the cost for the state and control input.

        Parameters
        ----------
        tt : float
            Current timestep.
        state : N x 1 numpy array
            Current state.
        ctrl_input : Nu x 1 numpy array
            Current control input.
        is_initial : bool, optional
            Flag indicating if it's the initial time. The default is False.
        is_final : bool, optional
            Flag indicating if it's the final time. The default is False.

        Returns
        -------
        float
            Cost.
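
        Notes
        -----
        Following the logic below, the returned value is the quadratic penalty

        .. math::
            (x - x_{ref})^T Q (x - x_{ref}) + (u - u_{nom})^T R (u - u_{nom})

        where the state term uses the end state at the final time and the
        initial state at the initial time (and is dropped otherwise), and the
        control term is dropped at the final time.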
474 """
475 if is_final:
476 sdiff = state - self.end_state
477 return (sdiff.T @ self._Q @ sdiff).item()
479 else:
480 cost = 0
481 if is_initial:
482 sdiff = state - self._init_state
483 cost += (sdiff.T @ self._Q @ sdiff).item()
485 cdiff = ctrl_input - self.u_nom
486 cost += (cdiff.T @ self._R @ cdiff).item()
487 return cost

    def solve_dare(self, cur_time, cur_state, state_args=None, ctrl_args=None):
        """Solve the discrete algebraic Riccati equation.

        Parameters
        ----------
        cur_time : float
            Current time.
        cur_state : N x 1 numpy array
            Current state.
        state_args : tuple, optional
            Additional arguments for calculating the state matrix. The
            default is None.
        ctrl_args : tuple, optional
            Additional arguments for calculating the input matrix. The
            default is None.

        Returns
        -------
        S : N x N numpy array
            DARE solution.
        F : N x N numpy array
            Discrete state transition matrix.
        G : N x Nu numpy array
            Discrete input matrix.
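
        Notes
        -----
        The infinite horizon feedback gain used by :meth:`.calculate_control`
        follows from the solution as

        .. math::
            K = (G^T S G + R)^{-1} (G^T S F + P)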
512 """
513 if state_args is None:
514 state_args = ()
515 if ctrl_args is None:
516 ctrl_args = ()
517 F, G = self.get_state_space(
518 cur_time, cur_state, self.u_nom, state_args, ctrl_args
519 )
520 S = la.solve_discrete_are(F, G, self._Q, self._R)
522 return S, F, G

    def calculate_control(
        self,
        cur_time,
        cur_state,
        end_state=None,
        end_state_tol=1e-2,
        max_inf_iters=int(1e3),
        check_inds=None,
        state_args=None,
        ctrl_args=None,
        inv_state_args=None,
        inv_ctrl_args=None,
        provide_details=False,
    ):
        """Calculate the control parameters and state trajectory.

        This can track a reference or regulate to zero. It can also be
        infinite or finite horizon.

        Parameters
        ----------
        cur_time : float
            Current time when starting the control calculation.
        cur_state : N x 1 numpy array
            Current state.
        end_state : N x 1 numpy array, optional
            Desired/reference ending state. The default is None.
        end_state_tol : float, optional
            Tolerance on the reference state when calculating the state
            trajectory for the infinite horizon case. The default is 1e-2.
        max_inf_iters : int, optional
            Maximum number of steps to use in the trajectory when finding
            the state trajectory for the infinite horizon case, to avoid
            infinite loops.
        check_inds : list, optional
            List of indices of the state vector to check when determining
            the end condition for the state trajectory calculation for the
            infinite horizon case. The default is None, which checks the
            full state vector.
        state_args : tuple, optional
            Extra arguments for calculating the state matrix. The default is
            None.
        ctrl_args : tuple, optional
            Extra arguments for calculating the input matrix. The default is
            None.
        inv_state_args : tuple, optional
            Extra arguments for calculating the inverse state matrix. The
            default is None.
        inv_ctrl_args : tuple, optional
            Extra arguments for calculating the inverse input matrix. The
            default is None.
        provide_details : bool, optional
            Flag indicating if additional outputs should be provided. The
            default is False.

        Returns
        -------
        u : Nu x 1 numpy array
            Control input for the current timestep.
        cost : float, optional
            Cost of the trajectory.
        state_traj : Nh+1 x N numpy array, optional
            State trajectory over the horizon, or until the reference is
            reached.
        ctrl_signal : Nh x Nu numpy array, optional
            Control inputs over the horizon, or until the reference is
            reached.
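
        Examples
        --------
        A sketch of a finite horizon call (the state dimension and the
        ``state_args`` value are placeholder assumptions that depend on the
        dynamics model in use)::

            u, cost, state_traj, ctrl_signal = lqr.calculate_control(
                0.0,
                np.zeros((4, 1)),
                end_state=np.ones((4, 1)),
                state_args=(0.1,),
                provide_details=True,
            )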
587 """
588 if state_args is None:
589 state_args = ()
590 if ctrl_args is None:
591 ctrl_args = ()
592 if inv_state_args is None:
593 inv_state_args = ()
594 if inv_ctrl_args is None:
595 inv_ctrl_args = ()
596 if end_state is None:
597 end_state = np.zeros((cur_state.size, 1))
599 self._init_state = cur_state.reshape((-1, 1)).copy()
600 self.end_state = end_state.reshape((-1, 1)).copy()
602 if np.isinf(self.time_horizon) or self.time_horizon <= 0:
603 S, F, G = self.solve_dare(
604 cur_time, cur_state, state_args=state_args, ctrl_args=ctrl_args
605 )
606 self.feedback_gain = la.inv(G.T @ S @ G + self._R) @ (G.T @ S @ F + self._P)
607 self.feedthrough_gain = self.u_nom.copy()
608 state_traj = cur_state.reshape((1, -1)).copy()
610 dx = end_state - cur_state
612 if self.dt is None:
613 ctrl_signal = (self.feedback_gain @ dx + self.feedthrough_gain).ravel()
614 if self.control_constraints is not None:
615 ctrl_signal = self.control_constraints(0, ctrl_signal).ravel()
616 cost = np.nan
            else:
                ctrl_signal = None
                cost = self.cost_function(
                    cur_time, cur_state, self.u_nom, is_initial=True, is_final=False
                )

                if check_inds is None:
                    check_inds = range(cur_state.size)
                timestep = cur_time
                done = (
                    np.linalg.norm(
                        state_traj[-1, check_inds] - end_state[check_inds, 0]
                    )
                    <= end_state_tol
                )

                itr = 0
                while not done:
                    # simulate forward with the constant gain until the
                    # reference is reached or the iteration limit is hit
                    timestep += self.dt
                    dx = end_state - state_traj[-1, :].reshape((-1, 1))
                    u = self.feedback_gain @ dx + self.feedthrough_gain
                    if self.control_constraints is not None:
                        u = self.control_constraints(timestep, u)
                    if ctrl_signal is None:
                        ctrl_signal = u.reshape((1, -1))
                    else:
                        ctrl_signal = np.vstack((ctrl_signal, u.ravel()))

                    x = self.prop_state(
                        timestep,
                        state_traj[-1, :].reshape((-1, 1)),
                        u,
                        state_args,
                        ctrl_args,
                        True,
                        (),
                        (),
                    )
                    state_traj = np.vstack((state_traj, x.ravel()))

                    itr += 1
                    done = (
                        np.linalg.norm(
                            state_traj[-1, check_inds] - end_state[check_inds, 0]
                        )
                        <= end_state_tol
                        or itr >= max_inf_iters
                    )

                    cost += self.cost_function(
                        timestep, x, u, is_initial=False, is_final=done
                    )

                # if we start too close to the end state, ctrl_signal can be None
                if ctrl_signal is None:
                    ctrl_signal = self.u_nom.copy().reshape((1, -1))
        else:
            # finite horizon: run a backward pass for the time-varying gains,
            # then a forward pass to extract the trajectory and cost
            num_timesteps = int(self.time_horizon / self.dt)
            # Nh+1 times spanning the horizon with spacing dt
            time_vec = cur_time + self.dt * np.arange(num_timesteps + 1)

            self.ct_go_mats = np.zeros(
                (num_timesteps + 1, cur_state.size, cur_state.size)
            )
            self.ct_go_vecs = np.zeros((num_timesteps + 1, cur_state.size, 1))
            self.ct_go_mats[-1] = self._Q.copy()
            self.ct_go_vecs[-1] = -(self._Q @ self.end_state)
            self.feedback_gain = np.zeros(
                (num_timesteps + 1, self.u_nom.size, cur_state.size)
            )
            self.feedthrough_gain = self.u_nom * np.ones(
                (num_timesteps + 1, self.u_nom.size, 1)
            )

            traj = np.nan * np.ones((num_timesteps + 1, cur_state.size))
            traj[-1, :] = end_state.flatten()
            self.backward_pass(
                0,
                num_timesteps,
                traj,
                state_args,
                ctrl_args,
                (),
                time_vec,
                inv_state_args,
                inv_ctrl_args,
            )

            ctrl_signal = np.nan * np.ones((num_timesteps, self.u_nom.size))
            state_traj = np.nan * np.ones((num_timesteps + 1, cur_state.size))
            state_traj[0, :] = cur_state.flatten()
            cost = 0
            for kk, tt in enumerate(time_vec[:-1]):
                u = (
                    self.feedback_gain[kk] @ state_traj[kk, :].reshape((-1, 1))
                    + self.feedthrough_gain[kk]
                )
                if self.control_constraints is not None:
                    u = self.control_constraints(tt, u)
                ctrl_signal[kk, :] = u.ravel()

                x = self.prop_state(
                    tt,
                    state_traj[kk, :].reshape((-1, 1)),
                    u,
                    state_args,
                    ctrl_args,
                    True,
                    (),
                    (),
                )
                state_traj[kk + 1, :] = x.ravel()

                # kk stops at time_vec.size - 2, which marks the final
                # propagated state since the loop skips the last time
                cost += self.cost_function(
                    tt, x, u, is_initial=False, is_final=(kk == time_vec.size - 2)
                )

        u = ctrl_signal[0, :].reshape((-1, 1))
        details = (cost, state_traj, ctrl_signal)
        return (u, *details) if provide_details else u