Coverage for src/gncpy/control/elqr.py: 66%

269 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-09-28 15:43 +0000

1import io 

2import matplotlib 

3import matplotlib.pyplot as plt 

4import numpy as np 

5import scipy.linalg as la 

6from PIL import Image 

7 

8import gncpy.dynamics.basic as gdyn 

9import gncpy.math as gmath 

10import gncpy.plotting as gplot 

11from gncpy.control.lqr import LQR 

12 

13 

14class ELQR(LQR): 

15 """Implements an Extended Linear Quadratic Regulator (ELQR) controller. 

16 

17 This is based on 

18 :cite:`Berg2016_ExtendedLQRLocallyOptimalFeedbackControlforSystemswithNonLinearDynamicsandNonQuadraticCost`. 

19 

20 Attributes 

21 ---------- 

22 max_iters : int 

23 Maximum number of iterations to try for convergence. 

24 tol : float 

25 Tolerance for convergence. 

26 ct_come_mats : Nh+1 x N x N numpy array 

27 Cost-to-come matrices, 1 per step in the time horizon. 

28 ct_come_vecs : Nh+1 x N x 1 numpy array 

29 Cost-to-come vectors, 1 per step in the time horizon. 

30 use_custom_cost : bool 

31 Flag indicating if a custom cost function should be used. 

32 gif_frame_skip : int 

33 Number of frames to skip when saving the gif. Set to 1 to take every 

34 frame. 

35 """ 

36 

def __init__(self, max_iters=1e3, tol=1e-4, **kwargs):
    """Set up the iteration limits and empty cost bookkeeping.

    Parameters
    ----------
    max_iters : int, optional
        Maximum number of iterations to try for convergence; the value is
        converted with :code:`int`. The default is 1e3.
    tol : float, optional
        Tolerance on convergence. The default is 1e-4.
    **kwargs : dict
        Additional arguments forwarded to the parent initializer. When
        :code:`time_horizon` is absent it is set to 10 before forwarding.
    """
    kwargs.setdefault("time_horizon", 10)
    super().__init__(**kwargs)

    # convergence settings
    self.tol = tol
    self.max_iters = int(max_iters)

    # cost-to-come storage starts out empty
    self.ct_come_vecs = np.array([])
    self.ct_come_mats = np.array([])

    # cost model configuration
    self._cost_fun = None
    self._quad_modifier = None
    self._non_quad_fun = None
    self.use_custom_cost = False

    # plotting / animation settings
    self.gif_frame_skip = 1
    self._ax = 0

68 

def set_cost_model(
    self,
    Q=None,
    R=None,
    non_quadratic_fun=None,
    quad_modifier=None,
    cost_fun=None,
    skip_validity_check=False,
):
    r"""Configure the cost model used by the controller.

    Supply either both `Q` and `R` (optionally with `quad_modifier`), or
    `cost_fun`. When `Q` and `R` are given, a `non_quadratic_fun` is also
    needed by the rest of the code, but this method does not enforce that.

    Notes
    -----
    The standard cost is assumed to have the form

    .. math::

        J = x_f^T Q x_f + \int_{t_0}^{t_f} x^T Q x + u^T R u + u^T P x + f(\tau, x, u) d\tau

    where :math:`f(t, x, u)` holds the non-quadratic terms and is
    quadratized automatically by the algorithm.

    If a custom cost function is given for :math:`J`, the whole function
    :math:`J(t, x, u)` is quadratized at every timestep instead.

    Parameters
    ----------
    Q : N x N numpy array, optional
        State cost matrix. The default is None.
    R : Nu x Nu numpy array, optional
        Control cost matrix. The default is None.
    non_quadratic_fun : callable, optional
        Non-quadratic part of the standard cost. It must have the form
        :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
        and return a scalar additional cost. When None, any previously
        stored function is kept. The default is None.
    quad_modifier : callable, optional
        Function that modifies the :math:`P, Q, R, q, r` matrices before
        the non-quadratic terms are added. It must have the form
        :code:`f(itr, t, P, Q, R, q, r)` and return a tuple of
        :code:`(P, Q, R, q, r)`. The stored value is always overwritten,
        including with None. The default is None.
    cost_fun : callable, optional
        Custom cost function; it must handle all cases and is numerically
        quadratized at every timestep. It must have the form
        :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
        and return a scalar total cost. Only used when `Q` and `R` are not
        both given. The default is None.
    skip_validity_check : bool
        When True, no error is raised for an invalid input combination.

    Raises
    ------
    RuntimeError
        Invalid combination of input arguments.
    """
    self._quad_modifier = quad_modifier
    if non_quadratic_fun is not None:
        self._non_quad_fun = non_quadratic_fun

    # the quadratic model takes precedence over a custom cost function
    if not (Q is None or R is None):
        super().set_cost_model(Q, R)
        self.use_custom_cost = False
        return

    if cost_fun is not None:
        self._cost_fun = cost_fun
        self.use_custom_cost = True
        return

    if skip_validity_check:
        return

    raise RuntimeError("Invalid combination of inputs.")

145 

def cost_function(
    self, tt, state, ctrl_input, cost_args, is_initial=False, is_final=False
):
    """Evaluate the cost of a state and control input pair.

    Parameters
    ----------
    tt : float
        Current timestep.
    state : N x 1 numpy array
        Current state.
    ctrl_input : Nu x 1 numpy array
        Current control input.
    cost_args : tuple
        Extra arguments unpacked into the non-quadratic function or the
        custom cost function.
    is_initial : bool, optional
        Flag indicating if it's the initial time. The default is False.
    is_final : bool, optional
        Flag indicating if it's the final time. The default is False.

    Returns
    -------
    float
        Cost.
    """
    if self.use_custom_cost:
        # the custom function is responsible for every case
        return self._cost_fun(
            tt, state, ctrl_input, self.end_state, is_initial, is_final, *cost_args
        )

    quad_cost = super().cost_function(
        tt, state, ctrl_input, is_initial=is_initial, is_final=is_final
    )
    if is_final:
        # the non-quadratic term is not added at the final time
        return quad_cost

    extra_cost = self._non_quad_fun(
        tt,
        state,
        ctrl_input,
        self.end_state,
        is_initial,
        is_final,
        *cost_args,
    )
    return quad_cost + extra_cost

194 

def prop_state_forward(
    self, tt, x_hat, u_hat, state_args, ctrl_args, inv_state_args, inv_ctrl_args
):
    """Step the state forward and build the backward (inverse) model.

    Parameters
    ----------
    tt : float
        Current timestep.
    x_hat : N x 1 numpy array
        Current state.
    u_hat : Nu x 1 numpy array
        Current control input.
    state_args : tuple
        Extra arguments for the state matrix calculation.
    ctrl_args : tuple
        Extra arguments for the control matrix calculation.
    inv_state_args : tuple
        Extra arguments for the inverse state matrix calculation.
    inv_ctrl_args : tuple
        Extra arguments for the inverse control matrix calculation.

    Raises
    ------
    NotImplementedError
        If no dynamics object is set.

    Returns
    -------
    x_hat_p : N x 1 numpy array
        Next state.
    ABar : N x N numpy array
        Backward state transition matrix.
    BBar : N x Nu numpy array
        Backward input matrix.
    cBar : N x 1 numpy array
        Backward c vector.
    """
    x_next = self.prop_state(
        tt, x_hat, u_hat, state_args, ctrl_args, True, inv_state_args, inv_ctrl_args
    )

    dyn = self.dynObj
    if dyn is None:
        raise NotImplementedError("Need to implement this")

    if isinstance(dyn, gdyn.NonlinearDynamicsBase):
        # set to inverse dynamics; the sign is not restored here
        if dyn.dt > 0:
            dyn.dt *= -1

        back_A = dyn.get_state_mat(
            tt, x_next, *inv_state_args, u=u_hat, ctrl_args=ctrl_args
        )
        back_B = dyn.get_input_mat(tt, x_next, u_hat, *inv_ctrl_args)
    else:
        # otherwise invert the forward state space model
        fwd_A, fwd_B = self.get_state_space(tt, x_next, u_hat, state_args, ctrl_args)
        back_A = la.inv(fwd_A)
        back_B = -back_A @ fwd_B

    back_c = x_hat - back_A @ x_next - back_B @ u_hat

    return x_next, back_A, back_B, back_c

253 

def quadratize_cost(self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args):
    """Quadratizes the cost function.

    The quadratic terms come from the parent class. For timesteps that are
    neither initial nor final, the optional quadratic modifier is applied
    and the non-quadratic function is quadratized numerically about
    :code:`(x_hat, u_hat)`. If the resulting Hessian is not positive
    semi-definite it is regularized by setting the negative eigenvalues to
    zero and reconstructing the matrix.

    Parameters
    ----------
    tt : float
        Current timestep.
    itr : int
        Iteration number.
    x_hat : N x 1 numpy array
        Current state.
    u_hat : Nu x 1 numpy array
        Current control input.
    is_initial : bool
        Flag indicating if this is the first timestep.
    is_final : bool
        Flag indicating if this is the last timestep.
    cost_args : tuple
        Additional arguments for the cost calculation.

    Raises
    ------
    NotImplementedError
        Unsupported configuration (a custom cost function is in use).

    Returns
    -------
    P : Nu x N numpy array
        State and control correlation cost matrix.
    Q : N x N numpy array
        State cost matrix.
    R : Nu x Nu numpy array
        Control input cost matrix.
    q : N x 1
        State cost vector.
    r : Nu x 1
        Control input cost vector.
    """
    if self.use_custom_cost:
        # TODO: get hessians
        raise NotImplementedError("Need to implement this")

    P, Q, R, q, r = super()._determine_cost_matrices(
        tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
    )

    # only interior timesteps receive the non-quadratic contribution
    if is_initial or is_final:
        return P, Q, R, q, r

    if self._quad_modifier is not None:
        P, Q, R, q, r = self._quad_modifier(itr, tt, P, Q, R, q, r)

    xdim = x_hat.size
    udim = u_hat.size
    comb_state = np.vstack((x_hat, u_hat)).ravel()

    def _non_quad(_x, *_args):
        # split the stacked vector back into state and control input
        return self._non_quad_fun(
            tt,
            _x[:xdim],
            _x[xdim:],
            self.end_state,
            is_initial,
            is_final,
            *cost_args,
        )

    big_mat = np.asarray(gmath.get_hessian(comb_state, _non_quad), dtype=float)

    # regularize hessian to keep it pos semi def; a numerical Hessian is
    # only approximately symmetric, so symmetrize it and use eigh, which
    # guarantees real eigenvalues and orthonormal eigenvectors (required
    # for the V diag(w) V^T reconstruction below to be valid)
    big_mat = 0.5 * (big_mat + big_mat.T)
    vals, vecs = np.linalg.eigh(big_mat)
    big_mat = (vecs * np.clip(vals, 0, None)) @ vecs.T

    # extract non-quadratic terms
    non_Q = big_mat[:xdim, :xdim]
    non_P = big_mat[xdim:, :xdim]
    non_R = big_mat[xdim:, xdim:]

    big_vec = gmath.get_jacobian(comb_state, _non_quad)

    non_q = big_vec[:xdim].reshape((xdim, 1)) - (non_Q @ x_hat + non_P.T @ u_hat)
    non_r = big_vec[xdim:].reshape((udim, 1)) - (non_P @ x_hat + non_R @ u_hat)

    # add out of place so arrays handed back by the parent class (which may
    # be its stored cost matrices) are never modified, and so integer-typed
    # matrices do not fail on an in-place float addition
    return P + non_P, Q + non_Q, R + non_R, q + non_q, r + non_r

363 

def quadratize_final_cost(self, itr, num_timesteps, traj, time_vec, cost_args):
    """Quadratize the final cost and update the last trajectory point.

    Parameters
    ----------
    itr : int
        Iteration number.
    num_timesteps : int
        Total number of timesteps; also the index of the final point.
    traj : Nh+1 x N numpy array
        State trajectory, modified in place.
    time_vec : Nh+1 numpy array
        Time vector for the control horizon.
    cost_args : tuple
        Extra arguments for the cost function.

    Returns
    -------
    traj : Nh+1 x N numpy array
        State trajectory with the final row updated.
    """
    last = num_timesteps
    x_prev = traj[last - 1, :].reshape((-1, 1))
    u_hat = self.feedback_gain[last] @ x_prev + self.feedthrough_gain[last]

    # only the state matrix and vector of the final cost are kept
    _, go_mat, _, go_vec, _ = self.quadratize_cost(
        time_vec[-1], itr, x_prev, u_hat, False, True, cost_args
    )
    self.ct_go_mats[last] = go_mat
    self.ct_go_vecs[last] = go_vec

    # final state estimate from the combined cost-to-go and cost-to-come
    total_mat = self.ct_go_mats[last] + self.ct_come_mats[last]
    total_vec = self.ct_go_vecs[last] + self.ct_come_vecs[last]
    traj[last, :] = -(np.linalg.inv(total_mat) @ total_vec).ravel()

    apply_constraint = (
        self.hard_constraints
        and self.dynObj is not None
        and self.dynObj.state_constraint is not None
    )
    if apply_constraint:
        traj[last, :] = self.dynObj.state_constraint(
            time_vec[last], traj[last, :].reshape((-1, 1))
        ).ravel()

    return traj

402 

403 def _determine_cost_matrices( 

404 self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args 

405 ): 

406 return self.quadratize_cost( 

407 tt, itr, x_hat, u_hat, is_initial, is_final, cost_args 

408 ) 

409 

410 def _back_pass_update_traj(self, x_hat_p, kk): 

411 return -( 

412 np.linalg.inv(self.ct_go_mats[kk] + self.ct_come_mats[kk]) 

413 @ (self.ct_go_vecs[kk] + self.ct_come_vecs[kk]) 

414 ).ravel() 

415 

def forward_pass_step(
    self,
    itr,
    kk,
    time_vec,
    traj,
    state_args,
    ctrl_args,
    cost_args,
    inv_state_args,
    inv_ctrl_args,
):
    """Run one step of the forward pass.

    Updates the feedback and feedthrough gains at step `kk` and the
    cost-to-come matrix and vector at step `kk + 1`, then returns the state
    estimate for step `kk + 1`.

    Parameters
    ----------
    itr : int
        Iteration number.
    kk : int
        Index of the current timestep.
    time_vec : Nh+1 numpy array
        Time vector for the control horizon.
    traj : Nh+1 x N numpy array
        State trajectory.
    state_args : tuple
        Extra arguments for the state matrix calculation.
    ctrl_args : tuple
        Extra arguments for the control matrix calculation.
    cost_args : tuple
        Extra arguments for the cost function.
    inv_state_args : tuple
        Extra arguments for the inverse state matrix calculation.
    inv_ctrl_args : tuple
        Extra arguments for the inverse control matrix calculation.

    Returns
    -------
    N numpy array
        Flattened state estimate for step `kk + 1`.
    """
    t_now = time_vec[kk]
    x_now = traj[kk, :].reshape((-1, 1))

    u_hat = self.feedback_gain[kk] @ x_now + self.feedthrough_gain[kk]
    if self.hard_constraints and self.control_constraints is not None:
        u_hat = self.control_constraints(t_now, u_hat)

    # the propagated state itself is not needed, only the backward model
    _, ABar, BBar, cBar = self.prop_state_forward(
        t_now,
        x_now,
        u_hat,
        state_args,
        ctrl_args,
        inv_state_args,
        inv_ctrl_args,
    )

    # final cost is handled after the forward pass
    P, Q, R, q, r = self._determine_cost_matrices(
        t_now, itr, x_now, u_hat, kk == 0, False, cost_args
    )

    come_plus_Q = self.ct_come_mats[kk] + Q
    come_Q_A = come_plus_Q @ ABar
    come_vec = self.ct_come_vecs[kk] + q + come_plus_Q @ cBar

    CBar = BBar.T @ come_Q_A + P @ ABar
    DBar = ABar.T @ come_Q_A
    EBar = BBar.T @ come_plus_Q @ BBar + R + P @ BBar + BBar.T @ P.T
    dBar = ABar.T @ come_vec
    eBar = BBar.T @ come_vec + r + P @ cBar

    neg_inv_EBar = -np.linalg.inv(EBar)
    self.feedback_gain[kk] = neg_inv_EBar @ CBar
    self.feedthrough_gain[kk] = neg_inv_EBar @ eBar

    nxt = kk + 1
    self.ct_come_mats[nxt] = DBar + CBar.T @ self.feedback_gain[kk]
    self.ct_come_vecs[nxt] = dBar + CBar.T @ self.feedthrough_gain[kk]

    total_mat = self.ct_go_mats[nxt] + self.ct_come_mats[nxt]
    total_vec = self.ct_go_vecs[nxt] + self.ct_come_vecs[nxt]
    x_est = -(np.linalg.inv(total_mat) @ total_vec)

    apply_constraint = (
        self.hard_constraints
        and self.dynObj is not None
        and self.dynObj.state_constraint is not None
    )
    if apply_constraint:
        x_est = self.dynObj.state_constraint(time_vec[nxt], x_est)

    return x_est.ravel()

479 

def forward_pass(
    self,
    itr,
    num_timesteps,
    traj,
    state_args,
    ctrl_args,
    cost_args,
    time_vec,
    inv_state_args,
    inv_ctrl_args,
):
    """Forward pass for the smoothing.

    Calls :meth:`forward_pass_step` for every timestep in order and writes
    each result into the next row of the trajectory.

    Parameters
    ----------
    itr : int
        Iteration number.
    num_timesteps : int
        Total number of timesteps.
    traj : Nh+1 x N numpy array
        State trajectory, modified in place.
    state_args : tuple
        Extra arguments for the state matrix calculation.
    ctrl_args : tuple
        Extra arguments for the control matrix calculation.
    cost_args : tuple
        Extra arguments for the cost function.
    time_vec : Nh+1 numpy array
        Time vector for control horizon.
    inv_state_args : tuple
        Extra arguments for the inverse state matrix calculation.
    inv_ctrl_args : tuple
        Extra arguments for the inverse control matrix calculation.

    Returns
    -------
    traj : Nh+1 x N numpy array
        State trajectory.
    """
    # arguments that are identical for every step
    trailing_args = (state_args, ctrl_args, cost_args, inv_state_args, inv_ctrl_args)

    step = 0
    while step < num_timesteps:
        traj[step + 1, :] = self.forward_pass_step(
            itr, step, time_vec, traj, *trailing_args
        )
        step += 1

    return traj

534 

def draw_traj(
    self,
    fig,
    plt_inds,
    fig_h,
    fig_w,
    save_animation,
    num_timesteps,
    time_vec,
    state_args,
    ctrl_args,
    inv_state_args,
    inv_ctrl_args,
    color=None,
    alpha=0.2,
    zorder=-10,
):
    """Plot the trajectory produced by the current gains.

    The trajectory is rolled out from the stored initial state using the
    current feedback and feedthrough gains, then drawn on the active axis.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to draw on.
    plt_inds : list
        State indices to plot; 3 elements gives a 3d line, otherwise the
        first 2 are used.
    fig_h : int
        Height that saved frames must have.
    fig_w : int
        Width that saved frames must have.
    save_animation : bool
        Flag indicating if the figure should be captured as an image.
    num_timesteps : int
        Total number of timesteps.
    time_vec : Nh+1 numpy array
        Time vector for the control horizon.
    state_args : tuple
        Extra arguments for the state matrix calculation.
    ctrl_args : tuple
        Extra arguments for the control matrix calculation.
    inv_state_args : tuple
        Extra arguments for the inverse state matrix calculation.
    inv_ctrl_args : tuple
        Extra arguments for the inverse control matrix calculation.
    color : tuple, optional
        Line color. The default is None, which uses mid grey.
    alpha : float, optional
        Line transparency. The default is 0.2.
    zorder : int, optional
        Drawing order of the line. The default is -10.

    Returns
    -------
    PIL.Image or list
        The captured frame if `save_animation` is set, otherwise an empty
        list.
    """
    if color is None:
        color = (0.5, 0.5, 0.5)

    state_traj = np.full((num_timesteps + 1, self._init_state.size), np.nan)
    state_traj[0, :] = self._init_state.flatten()
    for kk, tt in enumerate(time_vec[:-1]):
        x_k = state_traj[kk, :].reshape((-1, 1))
        u = self.feedback_gain[kk] @ x_k + self.feedthrough_gain[kk]
        state_traj[kk + 1, :] = self.prop_state(
            tt,
            x_k,
            u,
            state_args,
            ctrl_args,
            True,
            inv_state_args,
            inv_ctrl_args,
        ).ravel()

    # plot backward pass trajectory
    num_axes = 3 if len(plt_inds) == 3 else 2
    coords = [state_traj[:, plt_inds[ii]] for ii in range(num_axes)]
    fig.axes[self._ax].plot(*coords, color=color, alpha=alpha, zorder=zorder)

    if matplotlib.get_backend() != "Agg":
        plt.pause(0.005)

    if not save_animation:
        return []

    with io.BytesIO() as buff:
        fig.savefig(buff, format="raw")
        buff.seek(0)
        # get_width_height() is (width, height); reversed it gives the
        # (rows, cols) layout of the raw pixel buffer
        arr = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
            (*(fig.canvas.get_width_height()[1::-1]), -1)
        )

    img = Image.fromarray(arr)
    # arr is (height, width, channels): rows are compared against the
    # height and columns against the width (the original compared them
    # swapped, which could skip a needed resize)
    if arr.shape[0] != fig_h or arr.shape[1] != fig_w:
        img = img.resize((fig_w, fig_h), Image.BICUBIC)
    return img

608 

def reset(self, tt, cur_state, end_state):
    """Reset values for calculating the control parameters.

    This generally does not need to be called outside of the class. It is
    automatically called when calculating the control parameters.

    Parameters
    ----------
    tt : float
        Current time when starting the control calculation.
    cur_state : N x 1 numpy array
        Current state.
    end_state : N x 1 numpy array
        Desired/reference ending state.

    Returns
    -------
    old_cost : float
        Prior cost, initialized to infinity.
    num_timesteps : int
        Number of timesteps in the horizon.
    traj : Nh+1 x N numpy array
        State trajectory, each row is a timestep. Only the first row is
        filled (with the current state); the rest are NaN.
    time_vec : Nh+1 numpy array
        Time vector, starting at `tt` and spaced by the magnitude of the
        timestep.
    """
    old_cost = float("inf")

    # the small offset keeps exact multiples from losing a step to floating
    # point truncation (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996)
    num_timesteps = int(self.time_horizon / self.dt + 1e-9)
    num_pts = num_timesteps + 1
    n_states = cur_state.size
    n_inputs = self.u_nom.size

    self._init_state = cur_state.reshape((-1, 1)).copy()
    self.end_state = end_state.reshape((-1, 1)).copy()

    traj = np.full((num_pts, n_states), np.nan)
    traj[0, :] = cur_state.flatten()

    self.ct_come_mats = np.zeros((num_pts, n_states, n_states))
    self.ct_come_vecs = np.zeros((num_pts, n_states, 1))
    self.ct_go_mats = np.zeros(self.ct_come_mats.shape)
    self.ct_go_vecs = np.zeros(self.ct_come_vecs.shape)

    self.feedback_gain = np.zeros((num_pts, n_inputs, n_states))
    self.feedthrough_gain = self.u_nom * np.ones((num_pts, n_inputs, 1))

    # one entry per trajectory row, spaced by |dt|; scaling a unit linspace
    # by |dt| (as before) squeezed the whole horizon into a single timestep
    abs_dt = np.abs(self.dt)
    time_vec = tt + abs_dt * np.arange(num_pts)

    return old_cost, num_timesteps, traj, time_vec

661 

def calculate_control(
    self,
    tt,
    cur_state,
    end_state,
    state_args=None,
    ctrl_args=None,
    cost_args=None,
    inv_state_args=None,
    inv_ctrl_args=None,
    provide_details=False,
    disp=True,
    show_animation=False,
    save_animation=False,
    plt_opts=None,
    ttl=None,
    fig=None,
    plt_inds=None,
    ax_num=0,
):
    """Calculate the control parameters and state trajectory.

    Parameters
    ----------
    tt : float
        Current time when starting the control calculation.
    cur_state : N x 1 numpy array
        Current state.
    end_state : N x 1 numpy array
        Desired/reference ending state.
    state_args : tuple, optional
        Extra arguments for calculating the state matrix. The default is None.
    ctrl_args : tuple, optional
        Extra arguments for calculating the input matrix. The default is None.
    cost_args : tuple, optional
        Extra arguments for the cost function, either the non-quadratic
        part or the custom function. The default is None.
    inv_state_args : tuple, optional
        Extra arguments for calculating the inverse state matrix. The
        default is None.
    inv_ctrl_args : tuple, optional
        Extra arguments for calculating the inverse input matrix. The
        default is None.
    provide_details : bool, optional
        Flag indicating if additional outputs should be provided. The
        default is False.
    disp : bool, optional
        Flag indicating if extra text should be displayed during the
        optimization. The default is True.
    show_animation : bool, optional
        Flag indicating if an animation should be shown during the
        optimization. The default is False.
    save_animation : bool, optional
        Flag indicating if the frames of the animation should be saved.
        Requires that the animation is shown. The default is False.
    plt_opts : dict, optional
        Options for the plot, see :func:`gncpy.plotting.init_plotting_opts`.
        The default is None.
    ttl : string, optional
        Title for the plot. The default is None.
    fig : matplotlib figure, optional
        Figure to draw on. If supplied only the end point and paths are
        drawn, and the title is updated if supplied; the rest of the figure
        remains unchanged. The default is None which creates a new figure
        and adds a title. If no title is supplied a default is used.
    plt_inds : list, optional
        2 element list of state indices for plotting (3 elements selects a
        3d plot). The default is None which assumes [0, 1].
    ax_num : int, optional
        Index into the figure's axes of the axis to draw on. The default
        is 0.

    Returns
    -------
    u : Nu x 1 numpy array
        Control input for current timestep.
    cost : float, optional
        Cost of the trajectory
    state_traj : Nh+1 x N numpy array, optional
        State trajectory over the horizon, or until the reference is reached.
    ctrl_signal : Nh x Nu numpy array, optional
        Control inputs over the horizon, or until the reference is reached.
    fig : matplotlib figure, optional
        Figure handle for the generated plot, None if animation is not shown.
    frame_list : list, optional
        Each element is a PIL.Image corresponding to an animation frame.
    """
    state_args = () if state_args is None else state_args
    ctrl_args = () if ctrl_args is None else ctrl_args
    cost_args = () if cost_args is None else cost_args
    inv_state_args = () if inv_state_args is None else inv_state_args
    inv_ctrl_args = () if inv_ctrl_args is None else inv_ctrl_args

    if plt_inds is None:
        plt_inds = [0, 1]
    self._ax = ax_num

    old_cost, num_timesteps, traj, time_vec = self.reset(tt, cur_state, end_state)

    frame_list = []
    fig_w = fig_h = None  # set once the figure exists

    def _grab_frame():
        """Render the figure to a PIL image of size fig_w x fig_h."""
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            buff.seek(0)
            # get_width_height() is (width, height); reversed it gives the
            # (rows, cols) layout of the raw pixel buffer
            arr = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (*(fig.canvas.get_width_height()[1::-1]), -1)
            )
        img = Image.fromarray(arr)
        # arr is (height, width, channels); the original compared rows to
        # the width and columns to the height
        if arr.shape[0] != fig_h or arr.shape[1] != fig_w:
            img = img.resize((fig_w, fig_h), Image.BICUBIC)
        return img

    def _quit_on_escape(event):
        """Stop the program when the escape key is released."""
        if event.key == "escape":
            # same effect as exit(0) without relying on the site module
            raise SystemExit(0)

    if show_animation:
        if fig is None:
            fig = plt.figure()
            extra_args = {"projection": "3d"} if len(plt_inds) == 3 else {}
            fig.add_subplot(1, 1, 1, **extra_args)
            try:
                fig.axes[self._ax].set_aspect("equal", adjustable="box")
            except Exception:
                # best effort only: an unsupported aspect setting must not
                # abort the optimization
                pass

            if plt_opts is None:
                plt_opts = gplot.init_plotting_opts(f_hndl=fig)

            if ttl is None:
                ttl = "ELQR"

            gplot.set_title_label(
                fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
            )

            # draw start
            fig.axes[self._ax].scatter(
                *self._init_state.ravel()[plt_inds],
                marker="o",
                color="g",
                zorder=1000,
            )
            fig.tight_layout()
        elif ttl is not None:
            gplot.set_title_label(
                fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
            )

        # draw the reference end point
        fig.axes[self._ax].scatter(
            *self.end_state.ravel()[plt_inds],
            marker="x",
            color="r",
            zorder=1000,
        )
        if matplotlib.get_backend() != "Agg":
            plt.pause(0.01)

        # for stopping simulation with the esc key.
        fig.canvas.mpl_connect("key_release_event", _quit_on_escape)
        fig_w, fig_h = fig.canvas.get_width_height()

        # save first frame of animation
        if save_animation:
            frame_list.append(_grab_frame())

    if disp:
        print("Starting ELQR optimization loop...")

    tiny = np.finfo(float).tiny
    for ii in range(self.max_iters):
        # progress messages honor the disp flag like all other output
        if disp:
            print(f"Loop Iteration : {ii}")
        # forward pass
        traj = self.forward_pass(
            ii,
            num_timesteps,
            traj,
            state_args,
            ctrl_args,
            cost_args,
            time_vec,
            inv_state_args,
            inv_ctrl_args,
        )

        if disp:
            print("Quadratizing final cost")
        # quadratize final cost
        traj = self.quadratize_final_cost(
            ii, num_timesteps, traj, time_vec, cost_args
        )

        if disp:
            print("Starting Backward Pass")
        # backward pass
        traj = self.backward_pass(
            ii,
            num_timesteps,
            traj,
            state_args,
            ctrl_args,
            cost_args,
            time_vec,
            inv_state_args,
            inv_ctrl_args,
        )

        # get cost
        cost = 0
        x = traj[0, :].copy().reshape((-1, 1))
        for kk, t_k in enumerate(time_vec[:-1]):
            u = self.feedback_gain[kk] @ x + self.feedthrough_gain[kk]
            if self.control_constraints is not None:
                u = self.control_constraints(t_k, u)
            cost += self.cost_function(
                t_k,
                x,
                u,
                cost_args,
                is_initial=(kk == 0),
                is_final=False,
            )
            x = self.prop_state(
                t_k, x, u, state_args, ctrl_args, True, inv_state_args, inv_ctrl_args
            )
        cost += self.cost_function(
            time_vec[-1],
            x,
            u,
            cost_args,
            is_initial=False,
            is_final=True,
        )

        if show_animation:
            keep_frame = save_animation and ii % self.gif_frame_skip == 0
            img = self.draw_traj(
                fig,
                plt_inds,
                fig_h,
                fig_w,
                keep_frame,
                num_timesteps,
                time_vec,
                state_args,
                ctrl_args,
                inv_state_args,
                inv_ctrl_args,
            )
            if keep_frame:
                frame_list.append(img)

        if disp:
            print("\tIteration: {:3d} Cost: {:10.4f}".format(ii, cost))

        # check for convergence; the relative change is tested in
        # multiplied form so a cost of exactly zero cannot divide by zero
        if np.abs(old_cost - cost) < self.tol * np.maximum(np.abs(cost), tiny):
            break
        old_cost = cost

    # create outputs and return
    ctrl_signal = np.nan * np.ones((num_timesteps, self.u_nom.size))
    state_traj = np.nan * np.ones((num_timesteps + 1, self._init_state.size))
    cost = 0
    state_traj[0, :] = self._init_state.flatten()
    for kk, t_k in enumerate(time_vec[:-1]):
        ctrl_signal[kk, :] = (
            self.feedback_gain[kk] @ state_traj[kk, :].reshape((-1, 1))
            + self.feedthrough_gain[kk]
        ).ravel()
        if self.control_constraints is not None:
            ctrl_signal[kk, :] = self.control_constraints(
                t_k, ctrl_signal[kk, :].reshape((-1, 1))
            ).ravel()

        cost += self.cost_function(
            t_k,
            state_traj[kk, :].reshape((-1, 1)),
            ctrl_signal[kk, :].reshape((-1, 1)),
            cost_args,
            is_initial=(kk == 0),
            is_final=False,
        )
        state_traj[kk + 1, :] = self.prop_state(
            t_k,
            state_traj[kk, :].reshape((-1, 1)),
            ctrl_signal[kk, :].reshape((-1, 1)),
            state_args,
            ctrl_args,
            True,
            inv_state_args,
            inv_ctrl_args,
        ).ravel()

    cost += self.cost_function(
        time_vec[-1],
        state_traj[num_timesteps, :].reshape((-1, 1)),
        ctrl_signal[num_timesteps - 1, :].reshape((-1, 1)),
        cost_args,
        is_initial=False,
        is_final=True,
    )

    if show_animation:
        # draw the final trajectory
        extra_args = dict(
            linestyle="-",
            color="g",
        )
        if len(plt_inds) == 3:
            fig.axes[self._ax].plot(
                state_traj[:, plt_inds[0]],
                state_traj[:, plt_inds[1]],
                state_traj[:, plt_inds[2]],
                **extra_args,
            )
        else:
            fig.axes[self._ax].plot(
                state_traj[:, plt_inds[0]], state_traj[:, plt_inds[1]], **extra_args
            )
        if matplotlib.get_backend() != "Agg":
            plt.pause(0.001)
        if save_animation:
            frame_list.append(_grab_frame())

    u = ctrl_signal[0, :].reshape((-1, 1))
    details = (cost, state_traj, ctrl_signal, fig, frame_list)
    return (u, *details) if provide_details else u