Coverage for src/gncpy/control/elqr.py: 66%

266 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1import io 

2import matplotlib 

3import matplotlib.pyplot as plt 

4import numpy as np 

5import scipy.linalg as la 

6from PIL import Image 

7 

8import gncpy.dynamics.basic as gdyn 

9import gncpy.math as gmath 

10import gncpy.plotting as gplot 

11from gncpy.control.lqr import LQR 

12 

13 

class ELQR(LQR):
    """Implements an Extended Linear Quadratic Regulator (ELQR) controller.

    This is based on
    :cite:`Berg2016_ExtendedLQRLocallyOptimalFeedbackControlforSystemswithNonLinearDynamicsandNonQuadraticCost`.

    Attributes
    ----------
    max_iters : int
        Maximum number of iterations to try for convergence.
    tol : float
        Relative tolerance on the change in total cost between iterations
        used to declare convergence.
    ct_come_mats : Nh+1 x N x N numpy array
        Cost-to-come matrices, 1 per step in the time horizon.
    ct_come_vecs : Nh+1 x N x 1 numpy array
        Cost-to-come vectors, 1 per step in the time horizon.
    use_custom_cost : bool
        Flag indicating if a custom cost function should be used.
    gif_frame_skip : int
        Number of frames to skip when saving the gif. Set to 1 to take every
        frame.
    """

    def __init__(self, max_iters=1e3, tol=1e-4, **kwargs):
        """Initialize an object.

        Parameters
        ----------
        max_iters : int, optional
            Maximum number of iterations to try for convergence. The default is
            1e3 (converted to an int).
        tol : float, optional
            Relative tolerance on convergence. The default is 1e-4.
        **kwargs : dict
            Additional arguments passed to :class:`LQR`. If `time_horizon` is
            not given it defaults to 10.
        """
        kwargs.setdefault("time_horizon", 10)

        super().__init__(**kwargs)
        self.max_iters = int(max_iters)
        self.tol = tol

        self.ct_come_mats = np.array([])
        self.ct_come_vecs = np.array([])

        self.use_custom_cost = False
        self.gif_frame_skip = 1
        self._non_quad_fun = None
        self._quad_modifier = None
        self._cost_fun = None

        # index into fig.axes used for all plotting
        self._ax = 0

    def set_cost_model(
        self,
        Q=None,
        R=None,
        non_quadratic_fun=None,
        quad_modifier=None,
        cost_fun=None,
        skip_validity_check=False,
    ):
        r"""Sets the cost model.

        Either `Q` and `R` must be supplied (and optionally `quad_modifier`
        and `non_quadratic_fun`) or `cost_fun`. If `Q` and `R` are given
        without a `non_quadratic_fun` the non-quadratic term is treated as
        zero.

        Notes
        -----
        This assumes the following form for the cost function

        .. math::

            J = x_f^T Q x_f + \int_{t_0}^{t_f} x^T Q x + u^T R u + u^T P x + f(\tau, x, u) d\tau

        where :math:`f(t, x, u)` contains the non-quadratic terms and gets
        automatically quadratized by the algorithm.

        A custom cost function can also be provided for :math:`J` in which case
        the entire function :math:`J(t, x, u)` will be quadratized at every
        timestep.

        Parameters
        ----------
        Q : N x N numpy array, optional
            State cost matrix. The default is None.
        R : Nu x Nu numpy array, optional
            Control cost matrix. The default is None.
        non_quadratic_fun : callable, optional
            Non-quadratic portion of the standard cost function. This should
            have the form
            :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
            and return a scalar additional cost. The default is None.
        quad_modifier : callable, optional
            Function to modify the :math:`P, Q, R, q, r` matrices before
            adding the non-quadratic terms. Must have the form
            :code:`f(itr, t, P, Q, R, q, r)` and return a tuple of
            :code:`(P, Q, R, q, r)`. The default is None.
        cost_fun : callable, optional
            Custom cost function, must handle all cases and will be numerically
            quadratized at every timestep. Must have the form
            :code:`f(t, state, ctrl_input, end_state, is_initial, is_final, *args)`
            and return a scalar total cost. The default is None.
        skip_validity_check : bool
            Flag indicating if the error for an invalid input combination
            should be suppressed.

        Raises
        ------
        RuntimeError
            Invalid combination of input arguments.
        """
        # a previously set non-quadratic function is kept if none is given
        if non_quadratic_fun is not None:
            self._non_quad_fun = non_quadratic_fun

        self._quad_modifier = quad_modifier

        if Q is not None and R is not None:
            super().set_cost_model(Q, R)
            self.use_custom_cost = False

        elif cost_fun is not None:
            self._cost_fun = cost_fun
            self.use_custom_cost = True

        elif not skip_validity_check:
            raise RuntimeError("Invalid combination of inputs.")

    def cost_function(
        self, tt, state, ctrl_input, cost_args, is_initial=False, is_final=False
    ):
        """Calculates the cost for the state and control input.

        Parameters
        ----------
        tt : float
            Current timestep.
        state : N x 1 numpy array
            Current state.
        ctrl_input : Nu x 1 numpy array
            Current control input.
        cost_args : tuple
            Additional arguments for the non-quadratic part or the custom cost
            function.
        is_initial : bool, optional
            Flag indicating if it's the initial time. The default is False.
        is_final : bool, optional
            Flag indicating if it's the final time. The default is False.

        Returns
        -------
        float
            Cost.
        """
        if self.use_custom_cost:
            return self._cost_fun(
                tt, state, ctrl_input, self.end_state, is_initial, is_final, *cost_args
            )

        cost = super().cost_function(
            tt, state, ctrl_input, is_initial=is_initial, is_final=is_final
        )
        # the non-quadratic term is not part of the final cost; a missing
        # non-quadratic function is equivalent to a zero term
        if is_final or self._non_quad_fun is None:
            return cost
        return cost + self._non_quad_fun(
            tt,
            state,
            ctrl_input,
            self.end_state,
            is_initial,
            is_final,
            *cost_args,
        )

    def prop_state_forward(
        self, tt, x_hat, u_hat, state_args, ctrl_args, inv_state_args, inv_ctrl_args
    ):
        """Propagate the state forward and get the backward state space.

        Parameters
        ----------
        tt : float
            Current timestep.
        x_hat : N x 1 numpy array
            Current state.
        u_hat : Nu x 1 numpy array
            Current control input.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Raises
        ------
        NotImplementedError
            If no dynamics object is set.

        Returns
        -------
        x_hat_p : N x 1 numpy array
            Next state.
        ABar : N x N numpy array
            Backward state transition matrix.
        BBar : N x Nu numpy array
            Backward input matrix.
        cBar : N x 1 numpy array
            Backward c vector.
        """
        x_hat_p = self.prop_state(
            tt, x_hat, u_hat, state_args, ctrl_args, True, inv_state_args, inv_ctrl_args
        )

        if self.dynObj is None:
            raise NotImplementedError("Need to implement this")

        if isinstance(self.dynObj, gdyn.NonlinearDynamicsBase):
            # a negative dt makes the dynamics object produce the inverse
            # (backward) dynamics; note this leaves dynObj.dt negative
            if self.dynObj.dt > 0:
                self.dynObj.dt *= -1

            ABar = self.dynObj.get_state_mat(
                tt, x_hat_p, *inv_state_args, u=u_hat, ctrl_args=ctrl_args
            )
            BBar = self.dynObj.get_input_mat(tt, x_hat_p, u_hat, *inv_ctrl_args)
        else:
            # linear dynamics: x_k = A^-1 x_{k+1} - A^-1 B u_k
            A, B = self.get_state_space(tt, x_hat_p, u_hat, state_args, ctrl_args)
            ABar = la.inv(A)
            BBar = -ABar @ B

        # affine offset so the backward linearization passes through x_hat
        cBar = x_hat - ABar @ x_hat_p - BBar @ u_hat

        return x_hat_p, ABar, BBar, cBar

    def quadratize_cost(self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args):
        """Quadratizes the cost function.

        On interior timesteps the non-quadratic portion of the standard cost
        function is numerically quadratized about ``(x_hat, u_hat)``. Its
        Hessian is symmetrized and regularized to be positive semi-definite
        by clipping negative eigenvalues to zero and reconstructing the
        matrix. The initial and final timesteps only use the quadratic part.

        Parameters
        ----------
        tt : float
            Current timestep.
        itr : int
            Iteration number.
        x_hat : N x 1 numpy array
            Current state.
        u_hat : Nu x 1 numpy array
            Current control input.
        is_initial : bool
            Flag indicating if this is the first timestep.
        is_final : bool
            Flag indicating if this is the last timestep.
        cost_args : tuple
            Additional arguments for the cost calculation.

        Raises
        ------
        NotImplementedError
            Unsupported configuration (custom cost functions).

        Returns
        -------
        P : Nu x N numpy array
            State and control correlation cost matrix.
        Q : N x N numpy array
            State cost matrix.
        R : Nu x Nu numpy array
            Control input cost matrix.
        q : N x 1
            State cost vector.
        r : Nu x 1
            Control input cost vector.
        """
        if self.use_custom_cost:
            # TODO: get hessians
            raise NotImplementedError("Need to implement this")

        P, Q, R, q, r = super()._determine_cost_matrices(
            tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
        )

        if is_initial or is_final:
            return P, Q, R, q, r

        if self._quad_modifier is not None:
            P, Q, R, q, r = self._quad_modifier(itr, tt, P, Q, R, q, r)

        # no non-quadratic term means nothing to add
        if self._non_quad_fun is None:
            return P, Q, R, q, r

        xdim = x_hat.size
        udim = u_hat.size
        comb_state = np.vstack((x_hat, u_hat)).ravel()

        def _non_quad(_x, *_args):
            # evaluate the non-quadratic cost on the stacked [x; u] vector
            return self._non_quad_fun(
                tt,
                _x[:xdim],
                _x[xdim:],
                self.end_state,
                is_initial,
                is_final,
                *cost_args,
            )

        big_mat = gmath.get_hessian(comb_state, _non_quad)

        # regularize hessian to keep it pos semi def; symmetrize first so
        # eigh gives real eigenvalues and orthonormal eigenvectors, making
        # V diag(vals) V^T a valid reconstruction
        big_mat = 0.5 * (big_mat + big_mat.T)
        vals, vecs = np.linalg.eigh(big_mat)
        vals = np.clip(vals, 0, None)
        big_mat = vecs @ np.diag(vals) @ vecs.T

        # extract non-quadratic terms
        non_Q = big_mat[:xdim, :xdim]
        non_P = big_mat[xdim:, :xdim]
        non_R = big_mat[xdim:, xdim:]

        big_vec = gmath.get_jacobian(comb_state, _non_quad)

        # linear terms of the second order expansion about (x_hat, u_hat)
        non_q = big_vec[:xdim].reshape((xdim, 1)) - (non_Q @ x_hat + non_P.T @ u_hat)
        non_r = big_vec[xdim:].reshape((udim, 1)) - (non_P @ x_hat + non_R @ u_hat)

        # out-of-place so arrays owned by the base class are not mutated
        Q = Q + non_Q
        q = q + non_q
        R = R + non_R
        r = r + non_r
        P = P + non_P

        return P, Q, R, q, r

    def quadratize_final_cost(self, itr, num_timesteps, traj, time_vec, cost_args):
        """Quadratize the final cost and update the final trajectory point.

        The final cost is quadratized about the last state of the forward
        pass. The resulting matrices become the final cost-to-go terms. The
        final state is then replaced by the minimizer of cost-to-go plus
        cost-to-come, with state constraints applied if enabled.

        Parameters
        ----------
        itr : int
            Iteration number.
        num_timesteps : int
            Total number of timesteps.
        traj : Nh+1 x N numpy array
            State trajectory, modified in place.
        time_vec : Nh+1 numpy array
            Time vector for the control horizon.
        cost_args : tuple
            Extra arguments for the cost function.

        Returns
        -------
        traj : Nh+1 x N numpy array
            State trajectory.
        """
        x_final = traj[num_timesteps, :].reshape((-1, 1))
        u_hat = (
            self.feedback_gain[num_timesteps] @ x_final
            + self.feedthrough_gain[num_timesteps]
        )
        _, Q_f, _, q_f, _ = self.quadratize_cost(
            time_vec[-1],
            itr,
            x_final,
            u_hat,
            False,
            True,
            cost_args,
        )
        self.ct_go_mats[num_timesteps] = Q_f
        self.ct_go_vecs[num_timesteps] = q_f

        x_new = self._ct_minimizer(num_timesteps)
        if (
            self.hard_constraints
            and self.dynObj is not None
            and self.dynObj.state_constraint is not None
        ):
            x_new = self.dynObj.state_constraint(time_vec[num_timesteps], x_new)
        traj[num_timesteps, :] = x_new.ravel()

        return traj

    def _ct_minimizer(self, kk):
        """Return the N x 1 state minimizing cost-to-go plus cost-to-come at step kk."""
        return -np.linalg.solve(
            self.ct_go_mats[kk] + self.ct_come_mats[kk],
            self.ct_go_vecs[kk] + self.ct_come_vecs[kk],
        )

    def _determine_cost_matrices(
        self, tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
    ):
        """Quadratized cost matrices; overrides the base class hook."""
        return self.quadratize_cost(
            tt, itr, x_hat, u_hat, is_initial, is_final, cost_args
        )

    def _back_pass_update_traj(self, x_hat_p, kk):
        """Flattened trajectory point for step kk used by the backward pass."""
        return self._ct_minimizer(kk).ravel()

    def forward_pass_step(
        self,
        itr,
        kk,
        time_vec,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Perform a single step of the forward pass.

        Updates the control gains for step ``kk`` and the cost-to-come for
        step ``kk + 1``.

        Returns
        -------
        numpy array
            Flattened state for step ``kk + 1``.
        """
        tt = time_vec[kk]
        x_hat = traj[kk, :].reshape((-1, 1))

        u_hat = self.feedback_gain[kk] @ x_hat + self.feedthrough_gain[kk]
        if self.hard_constraints and self.control_constraints is not None:
            u_hat = self.control_constraints(tt, u_hat)

        x_hat_p, ABar, BBar, cBar = self.prop_state_forward(
            tt,
            x_hat,
            u_hat,
            state_args,
            ctrl_args,
            inv_state_args,
            inv_ctrl_args,
        )

        # final cost is handled after the forward pass
        P, Q, R, q, r = self._determine_cost_matrices(
            tt, itr, x_hat, u_hat, kk == 0, False, cost_args
        )

        # forward (cost-to-come) recursion from the ELQR paper
        ctm_Q = self.ct_come_mats[kk] + Q
        ctm_Q_A = ctm_Q @ ABar
        ctv_q_ctm_Q_c = self.ct_come_vecs[kk] + q + ctm_Q @ cBar
        CBar = BBar.T @ ctm_Q_A + P @ ABar
        DBar = ABar.T @ ctm_Q_A
        EBar = BBar.T @ ctm_Q @ BBar + R + P @ BBar + BBar.T @ P.T
        dBar = ABar.T @ ctv_q_ctm_Q_c
        eBar = BBar.T @ ctv_q_ctm_Q_c + r + P @ cBar

        self.feedback_gain[kk] = -np.linalg.solve(EBar, CBar)
        self.feedthrough_gain[kk] = -np.linalg.solve(EBar, eBar)

        self.ct_come_mats[kk + 1] = DBar + CBar.T @ self.feedback_gain[kk]
        self.ct_come_vecs[kk + 1] = dBar + CBar.T @ self.feedthrough_gain[kk]

        out = self._ct_minimizer(kk + 1)
        if (
            self.hard_constraints
            and self.dynObj is not None
            and self.dynObj.state_constraint is not None
        ):
            out = self.dynObj.state_constraint(time_vec[kk + 1], out)
        return out.ravel()

    def forward_pass(
        self,
        itr,
        num_timesteps,
        traj,
        state_args,
        ctrl_args,
        cost_args,
        time_vec,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Forward pass for the smoothing.

        Parameters
        ----------
        itr : int
            iteration number.
        num_timesteps : int
            Total number of timesteps.
        traj : Nh+1 x N numpy array
            State trajectory.
        state_args : tuple
            Extra arguments for the state matrix calculation.
        ctrl_args : tuple
            Extra arguments for the control matrix calculation.
        cost_args : tuple
            Extra arguments for the cost function.
        time_vec : Nh+1 numpy array
            Time vector for control horizon.
        inv_state_args : tuple
            Extra arguments for the inverse state matrix calculation.
        inv_ctrl_args : tuple
            Extra arguments for the inverse control matrix calculation.

        Returns
        -------
        traj : Nh+1 x N numpy array
            State trajectory.
        """
        for kk in range(num_timesteps):
            traj[kk + 1, :] = self.forward_pass_step(
                itr,
                kk,
                time_vec,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                inv_state_args,
                inv_ctrl_args,
            )

        return traj

    @staticmethod
    def _grab_frame(fig, fig_w, fig_h):
        """Render ``fig`` to a PIL image, resized to ``fig_w`` x ``fig_h`` if needed."""
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            raw = buff.getvalue()
        width, height = fig.canvas.get_width_height()
        arr = np.frombuffer(raw, dtype=np.uint8).reshape((height, width, -1))
        img = Image.fromarray(arr)
        # arr is (rows, cols, channels), i.e. (height, width, channels)
        if arr.shape[1] != fig_w or arr.shape[0] != fig_h:
            img = img.resize((fig_w, fig_h), Image.BICUBIC)
        return img

    def draw_traj(
        self,
        fig,
        plt_inds,
        fig_h,
        fig_w,
        save_animation,
        num_timesteps,
        time_vec,
        state_args,
        ctrl_args,
        inv_state_args,
        inv_ctrl_args,
        color=None,
        alpha=0.2,
        zorder=-10,
    ):
        """Simulate the current policy and plot the resulting trajectory.

        Returns
        -------
        PIL.Image or list
            The rendered frame if ``save_animation`` is True, otherwise an
            empty list.
        """
        if color is None:
            color = (0.5, 0.5, 0.5)
        state_traj = np.nan * np.ones((num_timesteps + 1, self._init_state.size))
        state_traj[0, :] = self._init_state.flatten()
        for kk, t_k in enumerate(time_vec[:-1]):
            x_k = state_traj[kk, :].reshape((-1, 1))
            u = self.feedback_gain[kk] @ x_k + self.feedthrough_gain[kk]
            state_traj[kk + 1, :] = self.prop_state(
                t_k,
                x_k,
                u,
                state_args,
                ctrl_args,
                True,
                inv_state_args,
                inv_ctrl_args,
            ).ravel()

        # plot backward pass trajectory
        extra_args = dict(
            color=color,
            alpha=alpha,
            zorder=zorder,
        )
        coords = [state_traj[:, ind] for ind in plt_inds[:3]]
        if len(plt_inds) != 3:
            coords = coords[:2]
        fig.axes[self._ax].plot(*coords, **extra_args)

        if matplotlib.get_backend() != "Agg":
            plt.pause(0.005)
        if save_animation:
            return self._grab_frame(fig, fig_w, fig_h)

        return []

    def reset(self, tt, cur_state, end_state):
        """Reset values for calculating the control parameters.

        This generally does not need to be called outside of the class. It is
        automatically called when calculating the control parameters.

        Parameters
        ----------
        tt : float
            Current time when starting the control calculation.
        cur_state : N x 1 numpy array
            Current state.
        end_state : N x 1 numpy array
            Desired/reference ending state.

        Returns
        -------
        old_cost : float
            prior cost.
        num_timesteps : int
            number of timesteps in the horizon.
        traj : Nh+1 x N numpy array
            state trajectory, each row is a timestep.
        time_vec : Nh+1 numpy array
            time vector, spaced by dt starting at tt.
        """
        old_cost = float("inf")
        # dt may have been negated to get inverse dynamics, so use its
        # magnitude; the small epsilon guards against float truncation
        # (e.g. 0.3 / 0.1 == 2.9999999999999996)
        abs_dt = np.abs(self.dt)
        num_timesteps = int(np.floor(self.time_horizon / abs_dt + 1e-9))
        self._init_state = cur_state.reshape((-1, 1)).copy()
        self.end_state = end_state.reshape((-1, 1)).copy()
        traj = np.nan * np.ones((num_timesteps + 1, cur_state.size))
        traj[0, :] = cur_state.flatten()

        self.ct_come_mats = np.zeros(
            (num_timesteps + 1, cur_state.size, cur_state.size)
        )
        self.ct_come_vecs = np.zeros((num_timesteps + 1, cur_state.size, 1))
        self.ct_go_mats = np.zeros(self.ct_come_mats.shape)
        self.ct_go_vecs = np.zeros(self.ct_come_vecs.shape)

        self.feedback_gain = np.zeros(
            (num_timesteps + 1, self.u_nom.size, cur_state.size)
        )
        self.feedthrough_gain = self.u_nom * np.ones(
            (num_timesteps + 1, self.u_nom.size, 1)
        )

        # one entry per timestep over the whole horizon
        time_vec = tt + abs_dt * np.arange(num_timesteps + 1)

        return old_cost, num_timesteps, traj, time_vec

    @staticmethod
    def _has_converged(old_cost, cost, tol):
        """Relative-change convergence test that is safe when ``cost`` is zero."""
        if cost == 0:
            return old_cost == 0
        return np.abs((old_cost - cost) / cost) < tol

    def calculate_control(
        self,
        tt,
        cur_state,
        end_state,
        state_args=None,
        ctrl_args=None,
        cost_args=None,
        inv_state_args=None,
        inv_ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        plt_inds=None,
        ax_num=0,
    ):
        """Calculate the control parameters and state trajectory.

        Parameters
        ----------
        tt : float
            Current time when starting the control calculation.
        cur_state : N x 1 numpy array
            Current state.
        end_state : N x 1 numpy array
            Desired/reference ending state.
        state_args : tuple, optional
            Extra arguments for calculating the state matrix. The default is None.
        ctrl_args : tuple, optional
            Extra arguments for calculating the input matrix. The default is None.
        cost_args : tuple, optional
            Extra arguments for the cost function, either the non-quadratic
            part or the custom function. The default is None.
        inv_state_args : tuple, optional
            Extra arguments for calculating the inverse state matrix. The
            default is None.
        inv_ctrl_args : tuple, optional
            Extra arguments for calculating the inverse input matrix. The
            default is None.
        provide_details : bool, optional
            Flag indicating if additional outputs should be provided. The
            default is False.
        disp : bool, optional
            Flag indicating if extra text should be displayed during the
            optimization. The default is True.
        show_animation : bool, optional
            Flag indicating if an animation should be shown during the
            optimization. The default is False.
        save_animation : bool, optional
            Flag indicating if the frames of the animation should be saved.
            Requires that the animation is shown. The default is False.
        plt_opts : dict, optional
            Options for the plot, see :func:`gncpy.plotting.init_plotting_opts`.
            The default is None.
        ttl : string, optional
            Title for the plot. The default is None.
        fig : matplotlib figure, optional
            Figure to draw on. If supplied only the end point and paths are
            drawn, and the title is updated if supplied; the rest of the figure
            remains unchanged. The default is None which creates a new figure
            and adds a title. If no title is supplied a default is used.
        plt_inds : list, optional
            2 or 3 element list of state indices for plotting (3 gives a 3D
            plot). The default is None which assumes [0, 1].
        ax_num : int, optional
            Index into ``fig.axes`` of the axes to draw on. The default is 0.

        Returns
        -------
        u : Nu x 1 numpy array
            Control input for current timestep.
        cost : float, optional
            Cost of the trajectory
        state_traj : Nh+1 x N numpy array, optional
            State trajectory over the horizon, or until the reference is reached.
        ctrl_signal : Nh x Nu numpy array, optional
            Control inputs over the horizon, or until the reference is reached.
        fig : matplotlib figure, optional
            Figure handle for the generated plot, None if animation is not shown.
        frame_list : list, optional
            Each element is a PIL.Image corresponding to an animation frame.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if cost_args is None:
            cost_args = ()
        if inv_state_args is None:
            inv_state_args = ()
        if inv_ctrl_args is None:
            inv_ctrl_args = ()

        if plt_inds is None:
            plt_inds = [0, 1]
        self._ax = ax_num

        old_cost, num_timesteps, traj, time_vec = self.reset(tt, cur_state, end_state)

        frame_list = []
        if show_animation:
            if fig is None:
                fig = plt.figure()
                if len(plt_inds) == 3:
                    extra_args = {"projection": "3d"}
                else:
                    extra_args = {}
                fig.add_subplot(1, 1, 1, **extra_args)
                try:
                    fig.axes[self._ax].set_aspect("equal", adjustable="box")
                except Exception:
                    # some axes types (e.g. older 3D axes) reject equal aspect
                    pass

                if plt_opts is None:
                    plt_opts = gplot.init_plotting_opts(f_hndl=fig)

                if ttl is None:
                    ttl = "ELQR"

                gplot.set_title_label(
                    fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
                )

                # draw start
                fig.axes[self._ax].scatter(
                    *self._init_state.ravel()[plt_inds],
                    marker="o",
                    color="g",
                    zorder=1000,
                )
                fig.tight_layout()
            elif ttl is not None:
                gplot.set_title_label(
                    fig, self._ax, plt_opts, ttl=ttl, use_local=(self._ax != 0)
                )

            fig.axes[self._ax].scatter(
                *self.end_state.ravel()[plt_inds],
                marker="x",
                color="r",
                zorder=1000,
            )
            if matplotlib.get_backend() != "Agg":
                plt.pause(0.01)

            # for stopping simulation with the esc key; SystemExit does not
            # depend on the site-provided ``exit`` builtin
            def _exit_on_escape(event):
                if event.key == "escape":
                    raise SystemExit(0)

            fig.canvas.mpl_connect("key_release_event", _exit_on_escape)
            fig_w, fig_h = fig.canvas.get_width_height()

            # save first frame of animation
            if save_animation:
                frame_list.append(self._grab_frame(fig, fig_w, fig_h))

        if disp:
            print("Starting ELQR optimization loop...")

        for ii in range(self.max_iters):
            # forward pass
            traj = self.forward_pass(
                ii,
                num_timesteps,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                time_vec,
                inv_state_args,
                inv_ctrl_args,
            )

            # quadratize final cost
            traj = self.quadratize_final_cost(
                ii, num_timesteps, traj, time_vec, cost_args
            )

            # backward pass
            traj = self.backward_pass(
                ii,
                num_timesteps,
                traj,
                state_args,
                ctrl_args,
                cost_args,
                time_vec,
                inv_state_args,
                inv_ctrl_args,
            )

            # get cost by simulating the current policy
            cost = 0
            x = traj[0, :].copy().reshape((-1, 1))
            for kk, t_k in enumerate(time_vec[:-1]):
                u = self.feedback_gain[kk] @ x + self.feedthrough_gain[kk]
                if self.control_constraints is not None:
                    u = self.control_constraints(t_k, u)
                cost += self.cost_function(
                    t_k,
                    x,
                    u,
                    cost_args,
                    is_initial=(kk == 0),
                    is_final=False,
                )
                x = self.prop_state(
                    t_k, x, u, state_args, ctrl_args, True, inv_state_args, inv_ctrl_args
                )
            cost += self.cost_function(
                time_vec[-1],
                x,
                u,
                cost_args,
                is_initial=False,
                is_final=True,
            )

            save_frame = save_animation and ii % self.gif_frame_skip == 0
            if show_animation:
                img = self.draw_traj(
                    fig,
                    plt_inds,
                    fig_h,
                    fig_w,
                    save_frame,
                    num_timesteps,
                    time_vec,
                    state_args,
                    ctrl_args,
                    inv_state_args,
                    inv_ctrl_args,
                )
                if save_frame:
                    frame_list.append(img)

            if disp:
                print("\tIteration: {:3d} Cost: {:10.4f}".format(ii, cost))

            # check for convergence
            if self._has_converged(old_cost, cost, self.tol):
                break
            old_cost = cost

        # create outputs and return
        ctrl_signal = np.nan * np.ones((num_timesteps, self.u_nom.size))
        state_traj = np.nan * np.ones((num_timesteps + 1, self._init_state.size))
        cost = 0
        state_traj[0, :] = self._init_state.flatten()
        for kk, t_k in enumerate(time_vec[:-1]):
            x_k = state_traj[kk, :].reshape((-1, 1))
            ctrl_signal[kk, :] = (
                self.feedback_gain[kk] @ x_k + self.feedthrough_gain[kk]
            ).ravel()
            if self.control_constraints is not None:
                ctrl_signal[kk, :] = self.control_constraints(
                    t_k, ctrl_signal[kk, :].reshape((-1, 1))
                ).ravel()
            u_k = ctrl_signal[kk, :].reshape((-1, 1))

            cost += self.cost_function(
                t_k,
                x_k,
                u_k,
                cost_args,
                is_initial=(kk == 0),
                is_final=False,
            )
            state_traj[kk + 1, :] = self.prop_state(
                t_k,
                x_k,
                u_k,
                state_args,
                ctrl_args,
                True,
                inv_state_args,
                inv_ctrl_args,
            ).ravel()

        cost += self.cost_function(
            time_vec[-1],
            state_traj[num_timesteps, :].reshape((-1, 1)),
            ctrl_signal[num_timesteps - 1, :].reshape((-1, 1)),
            cost_args,
            is_initial=False,
            is_final=True,
        )

        if show_animation:
            extra_args = dict(
                linestyle="-",
                color="g",
            )
            coords = [state_traj[:, ind] for ind in plt_inds[:3]]
            if len(plt_inds) != 3:
                coords = coords[:2]
            fig.axes[self._ax].plot(*coords, **extra_args)
            if matplotlib.get_backend() != "Agg":
                plt.pause(0.001)
            if save_animation:
                frame_list.append(self._grab_frame(fig, fig_w, fig_h))

        u = ctrl_signal[0, :].reshape((-1, 1))
        details = (cost, state_traj, ctrl_signal, fig, frame_list)
        return (u, *details) if provide_details else u