Coverage for src/carbs/guidance.py: 0%

403 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-15 06:48 +0000

1"""Implements RFS guidance algorithms. 

2 

3This module contains the classes and data structures 

4for RFS guidance related algorithms. 

5""" 

6import io 

7import numpy as np 

8import scipy.linalg as la 

9import matplotlib.pyplot as plt 

10from PIL import Image 

11from copy import deepcopy 

12from warnings import warn 

13from scipy.optimize import linear_sum_assignment 

14 

15import gncpy.control as gcontrol 

16import gncpy.plotting as gplot 

17import serums.models as smodels 

18from serums.distances import calculate_ospa 

19from serums.enums import SingleObjectDistance 

20 

21 

def gaussian_density_cost(state_dist, goal_dist, safety_factor, y_ref):
    r"""Implements a GM density based cost function.

    Notes
    -----
    Implements the following cost function based on the difference between
    Gaussian mixtures, with additional terms to improve convergence when
    far from the targets.

    .. math::
        J &= \sum_{k=1}^{T} 10 N_g\sigma_{g,max}^2 \left( \sum_{j=1}^{N_g}
                \sum_{i=1}^{N_g} w_{g,k}^{(j)} w_{g,k}^{(i)}
                \mathcal{N}( \mathbf{m}^{(j)}_{g,k}; \mathbf{m}^{(i)}_{g,k},
                P^{(j)}_{g, k} + P^{(i)}_{g, k} ) \right. \\
            &- \left. 20 \sigma_{d, max}^2 N_d \sum_{j=1}^{N_d} \sum_{i=1}^{N_g}
                w_{d,k}^{(j)} w_{g,k}^{(i)} \mathcal{N}(
                \mathbf{m}^{(j)}_{d, k}; \mathbf{m}^{(i)}_{g, k},
                P^{(j)}_{d, k} + P^{(i)}_{g, k} ) \right) \\
            &+ \sum_{j=1}^{N_d} \sum_{i=1}^{N_d} w_{d,k}^{(j)}
                w_{d,k}^{(i)} \mathcal{N}( \mathbf{m}^{(j)}_{d,k};
                \mathbf{m}^{(i)}_{d,k}, P^{(j)}_{d, k} + P^{(i)}_{d, k} ) \\
            &+ \alpha \sum_{j=1}^{N_d} \sum_{i=1}^{N_g} w_{d,k}^{(j)}
                w_{g,k}^{(i)} \ln{\mathcal{N}( \mathbf{m}^{(j)}_{d,k};
                \mathbf{m}^{(i)}_{g,k}, P^{(j)}_{d, k} + P^{(i)}_{g, k} )}

    The log-density in the last term is evaluated with a log-determinant and
    a linear solve instead of a determinant and an explicit inverse, so it
    does not under/overflow for very small or very large covariances.

    Parameters
    ----------
    state_dist : :class:`serums.models.GaussianMixture`
        Initial state distribution.
    goal_dist : :class:`serums.models.GaussianMixture`
        Desired state distribution.
    safety_factor : float
        Overbounding tuning factor for extra convergence term.
    y_ref : float
        Reference point to use on the sigmoid function, must be in the range
        (0, 1).

    Returns
    -------
    float
        cost.
    """
    all_goals = np.array([m.ravel().tolist() for m in goal_dist.means])
    all_states = np.array([m.ravel().tolist() for m in state_dist.means])
    target_center = np.mean(all_goals, axis=0).reshape((-1, 1))
    num_targets = all_goals.shape[0]
    num_objects = all_states.shape[0]
    state_dim = all_states.shape[1]

    # find radius of influence and shift
    diff = all_goals - target_center.T
    max_dist = np.sqrt(np.max(np.sum(diff * diff, axis=1)))
    radius_of_influence = safety_factor * max_dist
    shift = radius_of_influence + np.log(1 / y_ref - 1)

    # get activation term (sigmoid of the farthest object's distance)
    diff = all_states - target_center.T
    max_dist = np.sqrt(np.max(np.sum(diff * diff, axis=1)))
    activator = 1 / (1 + np.exp(-(max_dist - shift)))

    # get maximum variance
    max_var_obj = max(float(np.max(np.diag(c))) for c in state_dist.covariances)
    max_var_target = max(float(np.max(np.diag(c))) for c in goal_dist.covariances)

    # constant part of the Gaussian log-normalisation, hoisted out of the loops
    log_two_pi_term = -0.5 * state_dim * np.log(2 * np.pi)

    # Loop for all double summation terms
    sum_obj_obj = 0
    sum_obj_target = 0
    quad = 0
    for out_w, out_dist in state_dist:
        # create temporary gaussian object for calculations
        temp_gauss = smodels.Gaussian(mean=out_dist.mean)

        # object to object cost
        for in_w, in_dist in state_dist:
            temp_gauss.covariance = out_dist.covariance + in_dist.covariance
            sum_obj_obj += in_w * out_w * temp_gauss.pdf(in_dist.mean)

        # object to target and quadratic
        for tar_w, tar_dist in goal_dist:
            # object to target
            temp_gauss.covariance = out_dist.covariance + tar_dist.covariance
            sum_obj_target += tar_w * out_w * temp_gauss.pdf(tar_dist.mean)

            # quadratic (log of the Gaussian density at the target mean)
            cov_sum = temp_gauss.covariance
            diff = out_dist.mean - tar_dist.mean
            sign, logabsdet = np.linalg.slogdet(cov_sum)
            # a negative determinant has no real log-density; keep it as NaN
            log_det = logabsdet if sign >= 0 else np.nan
            mahalanobis = diff.T @ la.solve(cov_sum, diff)
            log_term = log_two_pi_term - 0.5 * log_det - 0.5 * mahalanobis
            quad += out_w * tar_w * np.asarray(log_term).item()

    sum_target_target = 0
    for out_w, out_dist in goal_dist:
        temp_gauss = smodels.Gaussian(mean=out_dist.mean)
        for in_w, in_dist in goal_dist:
            temp_gauss.covariance = out_dist.covariance + in_dist.covariance
            sum_target_target += out_w * in_w * temp_gauss.pdf(in_dist.mean)

    return (
        10
        * num_objects
        * max_var_obj
        * (sum_obj_obj - 2 * max_var_target * num_targets * sum_obj_target)
        + sum_target_target
        + activator * quad
    )

133 

134 

class ELQR:
    """Implements the ELQR algorithm for swarms.

    Notes
    -----
    This follows :cite:`Thomas2021_RecedingHorizonExtendedLinearQuadraticRegulatorforRFSBasedSwarms`.

    Attributes
    ----------
    max_iters : int
        Maximum number of iterations to optimize.
    tol : float
        Relative tolerance for convergence.
    end_dist : :class:`serums.models.GaussianMixture`
        Desired ending distribution; overwritten by :meth:`plan`.
    non_quad_weight : float
        Scaling applied to the user supplied non-quadratic cost.
    """

    def __init__(self, max_iters=1e3, tol=1e-4):
        """Initialize an object.

        Parameters
        ----------
        max_iters : int, optional
            Maximum number of iterations to optimize. The default is 1e3.
        tol : float, optional
            Relative tolerance for convergence. The default is 1e-4.
        """
        super().__init__()

        self.max_iters = int(max_iters)
        self.tol = tol
        self.end_dist = smodels.GaussianMixture()
        self.non_quad_weight = 1

        self._singleELQR = gcontrol.ELQR()
        self._elqr_lst = []
        self._start_covs = []
        self._start_weights = []
        self._time_vec = np.array([])
        self._cur_ind = None
        self._non_quad_fun = None

        self.__safety_factor = 1
        self.__y_ref = 0.5

    def get_state_dist(self, tt):
        """Calculates the current state distribution.

        Parameters
        ----------
        tt : float
            Current timestep.

        Returns
        -------
        :class:`serums.models.GaussianMixture`
            State distribution.
        """
        # use the stored trajectory row whose time is closest to tt
        kk = int(np.argmin(np.abs(tt - self._time_vec)))
        means = [
            params["traj"][kk, :].copy().reshape((-1, 1)) for params in self._elqr_lst
        ]
        return smodels.GaussianMixture(
            means=means,
            covariances=[c.copy() for c in self._start_covs],
            weights=self._start_weights.copy(),
        )

    def non_quad_fun_factory(self):
        r"""Factory for creating the non-quadratic cost function.

        This should generate a function of the same form needed by the single
        agent controller, that is `time, state, control input, end state,
        is initial flag, is final flag, *args`. For this class, it implements
        a GM density based cost function plus an optional additional cost function
        of the same form. It returns a float for the non-quadratic part of
        the cost.

        Returns
        -------
        callable
            function to calculate cost.
        """

        def non_quadratic_fun(
            tt, state, ctrl_input, end_state, is_initial, is_final, *args
        ):
            # swap the current agent's stored state for the one being evaluated;
            # self._cur_ind is read when the cost is called, not when it is built
            state_dist = self.get_state_dist(tt)
            state_dist.remove_components(
                [self._cur_ind,]  # noqa
            )
            state_dist.add_components(
                state.reshape((-1, 1)),
                self._start_covs[self._cur_ind],
                self._start_weights[self._cur_ind],
            )
            cost = gaussian_density_cost(
                state_dist, self.end_dist, self.__safety_factor, self.__y_ref
            )
            if self._non_quad_fun is not None:
                cost += self.non_quad_weight * self._non_quad_fun(
                    tt, state, ctrl_input, end_state, is_initial, is_final, *args
                )
            return cost

        return non_quadratic_fun

    def set_cost_model(
        self,
        safety_factor=None,
        y_ref=None,
        non_quad_fun=None,
        quad_modifier=None,
        non_quad_weight=None,
    ):
        """Set the parameters for the cost non-quadratic function.

        This sets the parameters for the built-in non-quadratic cost function
        and allows specifying an additional non-quadratic function (see
        :meth:`gncpy.control.ELQR`) for the agents. Any additional non-quadratic
        terms must be set here instead of within the single control model directly
        due to the built-in multi-agent non-quadratic cost. Arguments left as
        None keep their current value.

        Parameters
        ----------
        safety_factor : float, optional
            Additional scaling factor on the radius of influence. The default
            is None.
        y_ref : float, optional
            Reference point on the sigmoid function to use when determining
            the radius of influence. Should be between 0 and 1 (non-inclusive).
            The default is None.
        non_quad_fun : callable, optional
            Optional additional non-quadratic cost. Must have the same form as
            that of the single agent controller. See :meth:`gncpy.control.ELQR`.
            The default is None.
        quad_modifier : callable, optional
            Function to modify the quadratization process, see
            :meth:`gncpy.control.ELQR`. The default is None.
        non_quad_weight : float
            Additional scaling to be applied to the additional non_quadratic
            cost (the non built-in term). The default is None
        """
        if safety_factor is not None:
            self.__safety_factor = safety_factor

        if y_ref is not None:
            self.__y_ref = y_ref

        if non_quad_fun is not None:
            self._non_quad_fun = non_quad_fun

        if quad_modifier is not None:
            self._singleELQR.set_cost_model(
                quad_modifier=quad_modifier, skip_validity_check=True
            )

        if non_quad_weight is not None:
            self.non_quad_weight = non_quad_weight

    def set_control_model(self, singleELQR, quad_modifier=None):
        """Sets the single agent control model used.

        The controller is deep-copied, so later changes to the passed object
        do not affect this planner.

        Parameters
        ----------
        singleELQR : :class:`gncpy.control.ELQR`
            Single agent controller for generating trajectories.
        quad_modifier : callable, optional
            Modifying function for the quadratization. See
            :meth:`gncpy.control.ELQR.set_cost_model`. When given it is
            installed on the copied controller. The default is None.
        """
        self._singleELQR = deepcopy(singleELQR)
        if quad_modifier is not None:
            self._singleELQR.set_cost_model(
                quad_modifier=quad_modifier, skip_validity_check=True
            )

    def find_end_state(self, tt, cur_ind, cur_state):
        """Finds the ending state for the given current state.

        The closest mean (squared Euclidean distance) of the ending
        distribution is selected; `tt` and `cur_ind` are not used here.

        Parameters
        ----------
        tt : float
            Current timestep
        cur_ind : int
            Current index into the state distribution
        cur_state : N x 1 numpy array
            Current state.

        Returns
        -------
        N x 1 numpy array
            Best ending state given the current state.
        """
        all_ends = np.vstack([m.ravel() for m in self.end_dist.means])
        diff = all_ends - cur_state.T
        min_ind = int(np.argmin(np.sum(diff * diff, axis=1)))

        return self.end_dist.means[min_ind]

    def init_elqr_lst(self, tt, start_dist):
        """Initialize the list of single agent ELQR controllers.

        Parameters
        ----------
        tt : float
            current time.
        start_dist : :class:`serums.models.GaussianMixture`
            Starting gaussian mixture.

        Returns
        -------
        num_timesteps : int
            total number of timesteps.

        Raises
        ------
        ValueError
            If `start_dist` has no components.
        """
        self._elqr_lst = []
        num_timesteps = None
        for _, dist in start_dist:
            p = {}
            p["elqr"] = deepcopy(self._singleELQR)
            # placeholder end state; the real one is assigned in the second pass
            end_state = np.zeros((dist.location.size, 1))
            p["old_cost"], num_timesteps, p["traj"], self._time_vec = p["elqr"].reset(
                tt, dist.location, end_state
            )
            self._elqr_lst.append(p)

        if not self._elqr_lst:
            raise ValueError("start_dist must contain at least one component")

        # reset with proper end state
        for ind, ((_, dist), p) in enumerate(zip(start_dist, self._elqr_lst)):
            end_state = self.find_end_state(tt, ind, dist.location)
            p["elqr"].reset(tt, dist.location, end_state)

        return num_timesteps

    def targets_to_wayareas(self, end_states):
        """Converts target locations to wayareas with automatic scaling.

        Performs a Principal Component Analysis (PCA) on the ending state
        locations to create a Gaussian Mixture.

        Parameters
        ----------
        end_states : Nt x N numpy array
            All possible ending states, one per row.

        Returns
        -------
        :class:`serums.models.GaussianMixture`
            Ending state distribution.
        """

        def find_principal_components(data):
            # eigenvectors of the population covariance, largest eigenvalue
            # first, returned one per row
            num_samps, num_feats = data.shape

            mean = np.sum(data, 0) / num_samps
            covars = np.zeros((num_feats, num_feats))
            for ii in range(num_feats):
                for jj in range(num_feats):
                    acc = sum(
                        (data[samp, ii] - mean[ii]) * (data[samp, jj] - mean[jj])
                        for samp in range(num_samps)
                    )
                    covars[ii, jj] = acc / num_samps
            (w, comps) = la.eig(covars)
            inds = np.argsort(w)[::-1]
            return comps[:, inds].T

        def find_largest_proj_dist(new_dirs, old_dirs):
            # for every column of new_dirs, the largest absolute projection of
            # any column of old_dirs onto it
            vals = np.zeros(new_dirs.shape[0])
            for ii in range(new_dirs.shape[1]):
                for jj in range(old_dirs.shape[1]):
                    # .item() extracts the scalar; storing the (1, 1) array
                    # directly relies on a deprecated NumPy conversion
                    proj = np.abs(new_dirs[:, [ii]].T @ old_dirs[:, [jj]]).item()
                    if proj > vals[ii]:
                        vals[ii] = proj
            return vals

        thresh = 1e-2  # lower bound on each wayarea's extent

        wayareas = smodels.GaussianMixture()
        all_ends = np.vstack([s.ravel() for s in end_states])
        # the centroid is appended as an extra sample point
        aug_end_states = np.vstack((all_ends, np.mean(all_ends, axis=0)))

        # directions[:, ii, jj] is the vector from point ii to point jj
        directions = (
            aug_end_states[np.newaxis, :, :] - aug_end_states[:, np.newaxis, :]
        ).transpose(2, 0, 1)

        weight = 1 / len(end_states)
        for wp_ind, center in enumerate(all_ends):
            sample_data = np.delete(aug_end_states, wp_ind, axis=0)
            # no squeeze: the integer index already drops that axis, and
            # squeezing would also collapse a 1-dimensional state space
            sample_dirs = np.delete(directions[:, wp_ind, :], wp_ind, axis=1)
            comps = find_principal_components(sample_data)
            vals = find_largest_proj_dist(comps, sample_dirs)
            vals[vals <= thresh] = thresh

            cov = comps @ np.diag(vals) @ la.inv(comps)
            wayareas.add_components(center.reshape((-1, 1)), cov, weight)

        return wayareas

    def gen_final_traj(
        self,
        num_timesteps,
        start,
        elqr,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
    ):
        """Generates the final trajectory state and control trajectories.

        Parameters
        ----------
        num_timesteps : int
            total number of timesteps.
        start : N x 1 numpy array
            Initial state.
        elqr : :class:`gncpy.control.ELQR`
            Single agent controller for the given starting state.
        state_args : tuple
            Additional arguments for the state matrix.
        ctrl_args : tuple
            Additional arguments for the input matrix.
        cost_args : tuple
            Additional arguments for the cost function.
        inv_state_args : tuple
            Additional arguments for the inverse state transition matrix.
        inv_ctrl_args : tuple
            Additional arguments for the inverse input matrix.

        Returns
        -------
        state_traj : Nh+1 x N numpy array
            state trajectory.
        ctrl_signal : Nh x Nu numpy array
            control signal.
        cost : float
            cost of the trajectory.
        """
        ctrl_signal = np.nan * np.ones((num_timesteps, elqr.u_nom.size))
        state_traj = np.nan * np.ones((num_timesteps + 1, start.size))
        cost = 0
        state_traj[0, :] = start.flatten()
        for kk, tt in enumerate(self._time_vec[:-1]):
            cur_state = state_traj[kk, :].reshape((-1, 1))
            ctrl_signal[kk, :] = (
                elqr.feedback_gain[kk] @ cur_state + elqr.feedthrough_gain[kk]
            ).ravel()
            if elqr.control_constraints is not None:
                ctrl_signal[kk, :] = elqr.control_constraints(
                    tt, ctrl_signal[kk, :].reshape((-1, 1))
                ).ravel()
            cost += elqr.cost_function(
                tt,
                state_traj[kk, :].reshape((-1, 1)),
                ctrl_signal[kk, :].reshape((-1, 1)),
                cost_args,
                is_initial=(kk == 0),
                is_final=False,
            )
            state_traj[kk + 1, :] = elqr.prop_state(
                tt,
                state_traj[kk, :].reshape((-1, 1)),
                ctrl_signal[kk, :].reshape((-1, 1)),
                state_args,
                ctrl_args,
                True,
                inv_state_args,
                inv_ctrl_args,
            ).ravel()

        # terminal cost reuses the last applied control
        cost += elqr.cost_function(
            self._time_vec[-1],
            state_traj[num_timesteps, :].reshape((-1, 1)),
            ctrl_signal[num_timesteps - 1, :].reshape((-1, 1)),
            cost_args,
            is_initial=False,
            is_final=True,
        )

        return state_traj, ctrl_signal, cost

    def draw_init_states(
        self, fig, states, plt_inds, marker, zorder, cmap=None, color=None
    ):
        """Scatter the location of every component on the first axes of `fig`.

        Parameters
        ----------
        fig : matplotlib figure
            Figure whose first axes is drawn on.
        states : iterable
            Yields `(weight, distribution)` pairs; each distribution's
            `location` is plotted.
        plt_inds : list
            Two indices into the state vector giving the x and y values.
        marker : string
            Marker passed to `scatter`.
        zorder : int
            Z-order passed to `scatter`.
        cmap : callable, optional
            Maps a component index to a color; takes priority over `color`.
            The default is None.
        color : optional
            Single color for all components. The default is None.
        """
        kwargs = dict(marker=marker, zorder=zorder,)
        if color is not None:
            kwargs["color"] = color

        for c_ind, (w, dist) in enumerate(states):
            s = dist.location
            if cmap is not None:
                kwargs["color"] = cmap(c_ind)
            fig.axes[0].scatter(s[plt_inds[0], 0], s[plt_inds[1], 0], **kwargs)

    def save_animation(self, fig, fig_h, fig_w, frame_list):
        """Append the current figure to `frame_list` as a PIL image.

        The figure is saved in "raw" format to an in-memory buffer and the
        bytes are reshaped to `(fig_h, fig_w, -1)` before conversion.
        """
        with io.BytesIO() as buff:
            fig.savefig(buff, format="raw")
            buff.seek(0)
            img = np.frombuffer(buff.getvalue(), dtype=np.uint8).reshape(
                (fig_h, fig_w, -1)
            )
        frame_list.append(Image.fromarray(img))

    def init_plot(
        self,
        show_animation,
        save_animation,
        cmap,
        start_dist,
        fig,
        plt_opts,
        ttl,
        plt_inds,
    ):
        """Set up the animation figure, if one is requested.

        When `show_animation` is False nothing is drawn and the inputs are
        passed through. Otherwise a figure is created if `fig` is None (with
        title and starting states drawn), the end states are always added, and
        releasing the escape key in the figure exits the program.

        Returns
        -------
        fig : matplotlib figure
            Figure handle (unchanged if one was supplied or nothing is shown).
        fig_h : int
            Canvas height, None if nothing is shown.
        fig_w : int
            Canvas width, None if nothing is shown.
        frame_list : list
            Holds the first frame when saving the animation, otherwise empty.
        cmap : callable
            Color map for the agents.
        """
        frame_list = []
        fig_h = None
        fig_w = None
        if show_animation:
            if cmap is None:
                cmap = gplot.get_cmap(len(start_dist))

            if fig is None:
                fig = plt.figure()
                fig.add_subplot(1, 1, 1)
                fig.axes[0].set_aspect("equal", adjustable="box")

                if plt_opts is None:
                    plt_opts = gplot.init_plotting_opts(f_hndl=fig)

                if ttl is None:
                    ttl = "Multi-Agent ELQR"

                gplot.set_title_label(fig, 0, plt_opts, ttl=ttl)

                # draw start
                self.draw_init_states(fig, start_dist, plt_inds, "o", 1000, cmap=cmap)

            self.draw_init_states(fig, self.end_dist, plt_inds, "x", 1000, color="r")

            fig.tight_layout()
            plt.pause(0.1)

            # for stopping simulation with the esc key.
            def _quit_on_escape(event):
                # SystemExit is raised directly; the builtin exit() is added
                # by the site module and is not guaranteed to exist
                if event.key == "escape":
                    raise SystemExit(0)

            fig.canvas.mpl_connect("key_release_event", _quit_on_escape)
            fig_w, fig_h = fig.canvas.get_width_height()

            # save first frame of animation
            if save_animation:
                self.save_animation(fig, fig_h, fig_w, frame_list)

        return fig, fig_h, fig_w, frame_list, cmap

    def reset(self, start_dist):
        """Store copies of the starting covariances and weights."""
        self._start_covs = [c.copy() for c in start_dist.covariances]
        self._start_weights = list(start_dist.weights)

    def _install_agent_cost(self, ind, params):
        """Mark agent `ind` as current and reinstall its non-quadratic cost."""
        self._cur_ind = ind
        params["elqr"].set_cost_model(
            non_quadratic_fun=self.non_quad_fun_factory(), skip_validity_check=True
        )

    def output_helper(
        self,
        c_ind,
        num_timesteps,
        start,
        params,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
        costs,
        state_trajs,
        ctrl_signals,
        show_animation,
        fig,
        plt_inds,
        cmap,
    ):
        """Generate one agent's final trajectory and append it to the outputs.

        The cost, state trajectory and control signal from
        :meth:`gen_final_traj` are appended in place to `costs`,
        `state_trajs` and `ctrl_signals`. When `show_animation` is set the
        trajectory is also drawn on the first axes of `fig`.
        """
        self._install_agent_cost(c_ind, params)
        st, cs, c = self.gen_final_traj(
            num_timesteps,
            start,
            params["elqr"],
            state_args,
            ctrl_args,
            cost_args,
            inv_state_args,
            inv_ctrl_args,
        )
        costs.append(c)
        state_trajs.append(st)
        ctrl_signals.append(cs)

        if show_animation:
            fig.axes[0].plot(
                st[:, plt_inds[0]],
                st[:, plt_inds[1]],
                linestyle="-",
                color=cmap(c_ind),
            )
            plt.pause(0.01)

    def create_outputs(
        self,
        start_dist,
        num_timesteps,
        state_args,
        ctrl_args,
        cost_args,
        inv_state_args,
        inv_ctrl_args,
        show_animation,
        fig,
        plt_inds,
        cmap,
    ):
        """Build the final trajectories for every agent.

        Returns
        -------
        state_trajs : list
            State trajectory for each agent.
        costs : list
            Trajectory cost for each agent.
        ctrl_signals : list
            Control signal for each agent.
        """
        costs = []
        state_trajs = []
        ctrl_signals = []
        for c_ind, ((w, dist), params) in enumerate(zip(start_dist, self._elqr_lst)):
            self.output_helper(
                c_ind,
                num_timesteps,
                dist.location,
                params,
                state_args,
                ctrl_args,
                cost_args,
                inv_state_args,
                inv_ctrl_args,
                costs,
                state_trajs,
                ctrl_signals,
                show_animation,
                fig,
                plt_inds,
                cmap,
            )

        return state_trajs, costs, ctrl_signals

    def plan(
        self,
        tt,
        start_dist,
        end_dist,
        state_args=None,
        ctrl_args=None,
        cost_args=None,
        inv_state_args=None,
        inv_ctrl_args=None,
        provide_details=False,
        disp=True,
        show_animation=False,
        save_animation=False,
        plt_opts=None,
        ttl=None,
        fig=None,
        cmap=None,
        plt_inds=None,
        update_end=True,
        update_end_iters=1,
    ):
        """Main planning function.

        Parameters
        ----------
        tt : float
            Starting timestep for the plan.
        start_dist : :class:`serums.models.GaussianMixture`
            Starting state distribution.
        end_dist : :class:`serums.models.GaussianMixture`
            Ending state distribution.
        state_args : tuple, optional
            Additional arguments for getting the state transition matrix. The
            default is None.
        ctrl_args : tuple, optional
            Additional arguments for getting the input matrix. The default is
            None.
        cost_args : tuple, optional
            Additional arguments for the cost function. The default is None.
        inv_state_args : tuple, optional
            Additional arguments to get the inverse state matrix. The default
            is None.
        inv_ctrl_args : tuple, optional
            Additional arguments to get the inverse input matrix. The default
            is None.
        provide_details : bool, optional
            Flag for if optional outputs should be output. The default is False.
        disp : bool, optional
            Flag for if additional text should be printed. The default is True.
        show_animation : bool, optional
            Flag for if an animation is generated. The default is False.
        save_animation : bool, optional
            Flag for saving the animation. Only applies if the animation is
            shown. The default is False.
        plt_opts : dict, optional
            Additional plotting options. See
            :func:`gncpy.plotting.init_plotting_opts`. The default is None.
        ttl : string, optional
            Title of the generated plot. The default is None.
        fig : matplotlib figure, optional
            Handle to the figure. If supplied only the end states are added.
            The default is None.
        cmap : matplotlib colormap, optional
            Color map for the different agents. See :func:`gncpy.plotting.get_cmap`.
            The default is None.
        plt_inds : list, optional
            Indices in the state vector to plot. The default is None.
        update_end : bool, optional
            Flag for re-selecting each agent's end state with
            :meth:`find_end_state` from its propagated final state. The
            default is True.
        update_end_iters : int, optional
            The end states are updated on iterations (other than the first)
            that are a multiple of this value; it is used as a modulus so it
            must be non-zero. The default is 1.

        Returns
        -------
        state_trajs : list
            Each element is an Nh+1xN numpy array.
        costs : list, optional
            Each element is a float for the cost of that trajectory
        ctrl_signals : list, optional
            Each element is an NhxNu numpy array
        fig : matplotlib figure, optional
            Handle to the generated figure
        frame_list : list, optional
            Each element is a PIL image if the animation is being saved.
        """
        if state_args is None:
            state_args = ()
        if ctrl_args is None:
            ctrl_args = ()
        if cost_args is None:
            cost_args = ()
        if inv_state_args is None:
            inv_state_args = ()
        if inv_ctrl_args is None:
            inv_ctrl_args = ()
        if plt_inds is None:
            plt_inds = [0, 1]

        self.end_dist = end_dist

        num_timesteps = self.init_elqr_lst(tt, start_dist)
        old_cost = float("inf")
        self.reset(start_dist)

        fig, fig_h, fig_w, frame_list, cmap = self.init_plot(
            show_animation,
            save_animation,
            cmap,
            start_dist,
            fig,
            plt_opts,
            ttl,
            plt_inds,
        )

        if disp:
            print("Starting ELQR optimization loop...")

        for itr in range(self.max_iters):
            # forward pass for each gaussian, step by step
            for kk in range(num_timesteps):
                for ind, params in enumerate(self._elqr_lst):
                    self._install_agent_cost(ind, params)
                    params["traj"][kk + 1, :] = params["elqr"].forward_pass_step(
                        itr,
                        kk,
                        self._time_vec,
                        params["traj"],
                        state_args,
                        ctrl_args,
                        cost_args,
                        inv_state_args,
                        inv_ctrl_args,
                    )

            # quadratize final cost for each gaussian
            for c_ind, params in enumerate(self._elqr_lst):
                self._install_agent_cost(c_ind, params)
                params["traj"] = params["elqr"].quadratize_final_cost(
                    itr, num_timesteps, params["traj"], self._time_vec, cost_args
                )

            # backward pass for each gaussian
            for kk in range(num_timesteps - 1, -1, -1):
                for c_ind, params in enumerate(self._elqr_lst):
                    self._install_agent_cost(c_ind, params)
                    params["traj"][kk, :] = params["elqr"].backward_pass_step(
                        itr,
                        kk,
                        self._time_vec,
                        params["traj"],
                        state_args,
                        ctrl_args,
                        cost_args,
                        inv_state_args,
                        inv_ctrl_args,
                    )

            # get true cost
            cost = 0
            x = np.nan * np.ones(
                (len(self._elqr_lst), self._elqr_lst[0]["traj"].shape[1])
            )
            # time handed to the control constraints for the terminal control
            # NOTE(review): this is the last pre-terminal time, which is what
            # the stage loop variable held when the terminal loop ran; confirm
            # whether self._time_vec[-1] was intended instead
            terminal_constraint_time = tt
            for kk, step_time in enumerate(self._time_vec[:-1]):
                terminal_constraint_time = step_time
                for ind, params in enumerate(self._elqr_lst):
                    x[ind] = params["traj"][kk].copy()
                    u = (
                        params["elqr"].feedback_gain[kk] @ x[ind].reshape((-1, 1))
                        + params["elqr"].feedthrough_gain[kk]
                    )
                    if params["elqr"].control_constraints is not None:
                        u = params["elqr"].control_constraints(step_time, u)
                    self._install_agent_cost(ind, params)
                    cost += params["elqr"].cost_function(
                        step_time,
                        x[ind].reshape((-1, 1)),
                        u,
                        cost_args,
                        is_initial=(kk == 0),
                        is_final=False,
                    )
                    x[ind] = (
                        params["elqr"]
                        .prop_state(
                            step_time,
                            x[ind].reshape((-1, 1)),
                            u,
                            state_args,
                            ctrl_args,
                            True,
                            inv_state_args,
                            inv_ctrl_args,
                        )
                        .ravel()
                    )

            # terminal cost, optionally re-targeting each agent first
            for ind, params in enumerate(self._elqr_lst):
                if update_end and itr != 0 and itr % update_end_iters == 0:
                    params["elqr"].end_state = self.find_end_state(
                        self._time_vec[-1], ind, x[ind].reshape((-1, 1))
                    )
                u = (
                    params["elqr"].feedback_gain[-1] @ x[ind].reshape((-1, 1))
                    + params["elqr"].feedthrough_gain[-1]
                )
                if params["elqr"].control_constraints is not None:
                    u = params["elqr"].control_constraints(terminal_constraint_time, u)

                self._install_agent_cost(ind, params)
                cost += params["elqr"].cost_function(
                    self._time_vec[-1],
                    x[ind].reshape((-1, 1)),
                    u,
                    cost_args,
                    is_initial=False,
                    is_final=True,
                )

            if disp:
                print("\tIteration: {:3d} Cost: {:10.4f}".format(itr, cost))

            if show_animation:
                for c_ind, params in enumerate(self._elqr_lst):
                    img = params["elqr"].draw_traj(
                        fig,
                        plt_inds,
                        fig_h,
                        fig_w,
                        c_ind == (len(self._elqr_lst) - 1),
                        num_timesteps,
                        self._time_vec,
                        state_args,
                        ctrl_args,
                        inv_state_args,
                        inv_ctrl_args,
                        color=cmap(c_ind),
                        alpha=0.2,
                        zorder=-10,
                    )

                if save_animation:
                    frame_list.append(img)

            # check for convergence; a cost of exactly zero cannot be used as
            # the denominator, so it only converges once it repeats
            if cost == 0:
                converged = old_cost == 0
            else:
                converged = np.abs((old_cost - cost) / cost) < self.tol
            if converged:
                break
            old_cost = cost

        # generate control and state trajectories for all agents
        state_trajs, costs, ctrl_signals = self.create_outputs(
            start_dist,
            num_timesteps,
            state_args,
            ctrl_args,
            cost_args,
            inv_state_args,
            inv_ctrl_args,
            show_animation,
            fig,
            plt_inds,
            cmap,
        )

        if show_animation and save_animation:
            plt.pause(0.01)
            self.save_animation(fig, fig_h, fig_w, frame_list)

        details = (costs, ctrl_signals, fig, frame_list)
        return (state_trajs, *details) if provide_details else state_trajs

1005 

1006 

1007class ELQROSPA(ELQR): 

1008 def __init__(self, **kwargs): 

1009 super().__init__(**kwargs) 

1010 

1011 self.end_dist = np.array([]) 

1012 

1013 self.loiter_dist = 1 

1014 

1015 self.__ospa_inds = [] 

1016 self.__ospa_cutoff = 1e3 

1017 self.__ospa_core = SingleObjectDistance.EUCLIDEAN 

1018 

1019 def get_state_dist(self, tt): 

1020 """Calculates the current state distribution. 

1021 

1022 Parameters 

1023 ---------- 

1024 tt : float 

1025 Current timestep. 

1026 

1027 Returns 

1028 ------- 

1029 Na x N numpy array 

1030 State distribution, one row per agent 

1031 """ 

1032 kk = int(np.argmin(np.abs(tt - self._time_vec))) 

1033 return np.vstack([params["traj"][kk, :].flatten() for params in self._elqr_lst]) 

1034 

1035 def non_quad_fun_factory(self): 

1036 """Factory for creating the non-quadratic cost function. 

1037 

1038 This should generate a function of the same form needed by the single 

1039 agent controller, that is `time, state, control input, end state, 

1040 is initial flag, is final flag, *args`. For this class, it implements 

1041 an OSPA based cost function plus an optional additional cost function 

1042 of the same form. It returns a float for the non-quadratic part of 

1043 the cost. 

1044 

1045 Returns 

1046 ------- 

1047 callable 

1048 function to calculate cost. 

1049 """ 

1050 

1051 def non_quadratic_fun( 

1052 tt, state, ctrl_input, end_state, is_initial, is_final, *args 

1053 ): 

1054 if self.__ospa_inds is None: 

1055 inds = [i for i in range(state.size)] 

1056 else: 

1057 inds = self.__ospa_inds 

1058 start_dist = self.get_state_dist(tt) 

1059 start_dist[self._cur_ind, :] = state.flatten() 

1060 start_dist = start_dist[:, inds].T.reshape( 

1061 (len(inds), 1, start_dist.shape[0]) 

1062 ) 

1063 

1064 end_dist = self.end_dist[:, inds].T.reshape( 

1065 (len(inds), 1, self.end_dist.shape[0]) 

1066 ) 

1067 

1068 cost = calculate_ospa( 

1069 start_dist, 

1070 end_dist, 

1071 self.__ospa_cutoff, 

1072 1, 

1073 core_method=self.__ospa_core, 

1074 )[0].item() 

1075 if self._non_quad_fun is not None: 

1076 cost += self.non_quad_weight * self._non_quad_fun( 

1077 tt, state, ctrl_input, end_state, is_initial, is_final, *args 

1078 ) 

1079 return cost 

1080 

1081 return non_quadratic_fun 

1082 

1083 def set_cost_model( 

1084 self, 

1085 ospa_inds=None, 

1086 ospa_cutoff=None, 

1087 ospa_method=None, 

1088 non_quad_fun=None, 

1089 quad_modifier=None, 

1090 non_quad_weight=None, 

1091 ): 

1092 """Set the parameters for the cost non-quadratic function. 

1093 

1094 This sets the parameters for the built-in non-quadratic cost function 

1095 and allows specifying an additional non-quadratic function (see 

1096 :meth:`gncpy.control.ELQR`) for the agents. Any additional non-quadratic 

1097 terms must be set here instead of within the single control model directly 

1098 due to the built-in multi-agent non-quadratic cost. 

1099 

1100 Parameters 

1101 ---------- 

1102 non_quad_fun : callable, optional 

1103 Optional additional non-quadratic cost. Must have the same form as 

1104 that of the single agent controller. See :meth:`gncpy.control.ELQR`. 

1105 The default is None. 

1106 quad_modifier : callable, optional 

1107 Function to modify the quadratization process, see 

1108 :meth:`gncpy.control.ELQR`. The default is None. 

1109 non_quad_weight : float 

1110 Additional scaling to be applied to the additional non_quadratic 

1111 cost (the non built-in term). The default is None 

1112 """ 

1113 if ospa_inds is not None: 

1114 self.__ospa_inds = ospa_inds 

1115 

1116 if ospa_cutoff is not None: 

1117 self.__ospa_cutoff = ospa_cutoff 

1118 

1119 if ospa_method is not None: 

1120 self.__ospa_core = ospa_method 

1121 

1122 if non_quad_fun is not None: 

1123 self._non_quad_fun = non_quad_fun 

1124 

1125 if quad_modifier is not None: 

1126 self._singleELQR.set_cost_model( 

1127 quad_modifier=quad_modifier, skip_validity_check=True 

1128 ) 

1129 

1130 if non_quad_weight is not None: 

1131 self.non_quad_weight = non_quad_weight 

1132 

1133 def find_end_state(self, tt, cur_ind, cur_state): 

1134 """Finds the ending state for the given current state. 

1135 

1136 Parameters 

1137 ---------- 

1138 tt : float 

1139 Current timestep 

1140 cur_ind : int 

1141 Current index into the state distribution 

1142 cur_state : N x 1 numpy array 

1143 Current state. 

1144 

1145 Returns 

1146 ------- 

1147 N x 1 numpy array 

1148 Best ending state given the current state. 

1149 """ 

1150 if self.__ospa_inds is None: 

1151 inds = [i for i in range(cur_state.size)] 

1152 else: 

1153 inds = self.__ospa_inds 

1154 dist = self.get_state_dist(tt) 

1155 dist[cur_ind] = cur_state.ravel() 

1156 dist = dist[:, inds] 

1157 

1158 end_dist = self.end_dist[:, inds] 

1159 if end_dist.shape[0] < dist.shape[0]: 

1160 if end_dist.shape[0] > 1: 

1161 center = np.mean(end_dist, axis=0).reshape((1, -1)) 

1162 else: 

1163 # direction from end to current state 

1164 direc = cur_state.ravel()[inds] - end_dist[0] 

1165 mag = np.linalg.norm(direc) 

1166 if mag > 1e-16: 

1167 direc /= mag 

1168 else: 

1169 # current agent is at the only target, use all agents center instead 

1170 a_cent = np.mean(dist, axis=0) 

1171 direc = a_cent - end_dist[0] 

1172 mag = np.linalg.norm(direc) 

1173 if mag > 1e-16: 

1174 direc /= mag 

1175 else: 

1176 # all agents center is at the only target...just pick x direction 

1177 direc = np.zeros(len(inds)) 

1178 direc[0] = 1 

1179 direc *= self.loiter_dist 

1180 center = (end_dist[0] + direc).reshape((1, -1)) 

1181 

1182 # add enough "fake" targets so there is a 1-to-1 matching 

1183 n_miss = dist.shape[0] - end_dist.shape[0] 

1184 end_dist = np.vstack((end_dist, center * np.ones((n_miss, len(inds))))) 

1185 

1186 distances, a_exists, t_exists = calculate_ospa( 

1187 dist.T.reshape((len(inds), 1, -1)), 

1188 end_dist.T.reshape((len(inds), 1, -1)), 

1189 self.__ospa_cutoff, 

1190 1, 

1191 core_method=self.__ospa_core, 

1192 use_empty=False, 

1193 )[6:] 

1194 cont_sub = distances[ 

1195 0 : np.sum(a_exists).astype(int), 0 : np.sum(t_exists).astype(int), 0 

1196 ] 

1197 tar_inds = linear_sum_assignment(cont_sub)[1] 

1198 

1199 min_ind = tar_inds[cur_ind] 

1200 if min_ind >= self.end_dist.shape[0]: 

1201 out = np.zeros((self.end_dist.shape[1], 1)) 

1202 out[inds] = end_dist[min_ind].reshape(out[inds].shape) 

1203 else: 

1204 out = self.end_dist[min_ind, :].reshape((-1, 1)) 

1205 return out 

1206 

1207 def init_elqr_lst(self, tt, start_dist): 

1208 """Initialize the list of single agent ELQR controllers. 

1209 

1210 Parameters 

1211 ---------- 

1212 tt : float 

1213 current time. 

1214 start_dist : Na x N numpy array 

1215 Starting states, one per row. 

1216 

1217 Returns 

1218 ------- 

1219 num_timesteps : int 

1220 total number of timesteps. 

1221 """ 

1222 self._elqr_lst = [] 

1223 # initialize with fake end state 

1224 for ind, s in enumerate(start_dist): 

1225 p = {} 

1226 p["elqr"] = deepcopy(self._singleELQR) 

1227 end_state = np.zeros((s.size, 1)) 

1228 p["old_cost"], num_timesteps, p["traj"], self._time_vec = p["elqr"].reset( 

1229 tt, s.reshape((-1, 1)), end_state 

1230 ) 

1231 self._elqr_lst.append(p) 

1232 

1233 # reset with proper end state 

1234 for ind, (s, p) in enumerate(zip(start_dist, self._elqr_lst)): 

1235 end_state = self.find_end_state(tt, ind, s) 

1236 p["elqr"].reset(tt, s.reshape((-1, 1)), end_state) 

1237 

1238 return num_timesteps 

1239 

1240 def targets_to_wayareas(self, end_states): 

1241 warn("targets_to_wayareas not used by ELQROSPA") 

1242 return None 

1243 

1244 def draw_init_states( 

1245 self, fig, states, plt_inds, marker, zorder, color=None, cmap=None 

1246 ): 

1247 kwargs = dict(marker=marker, zorder=zorder,) 

1248 if color is not None: 

1249 kwargs["color"] = color 

1250 

1251 for c_ind, s in enumerate(states): 

1252 if cmap is not None: 

1253 kwargs["color"] = cmap(c_ind) 

1254 fig.axes[0].scatter(s[plt_inds[0]], s[plt_inds[1]], **kwargs) 

1255 

1256 def reset(self, start_dist): 

1257 pass 

1258 

1259 def create_outputs( 

1260 self, 

1261 start_dist, 

1262 num_timesteps, 

1263 state_args, 

1264 ctrl_args, 

1265 cost_args, 

1266 inv_state_args, 

1267 inv_ctrl_args, 

1268 show_animation, 

1269 fig, 

1270 plt_inds, 

1271 cmap, 

1272 ): 

1273 costs = [] 

1274 state_trajs = [] 

1275 ctrl_signals = [] 

1276 for c_ind, (s, params) in enumerate(zip(start_dist, self._elqr_lst)): 

1277 self.output_helper( 

1278 c_ind, 

1279 num_timesteps, 

1280 s.reshape((-1, 1)), 

1281 params, 

1282 state_args, 

1283 ctrl_args, 

1284 cost_args, 

1285 inv_state_args, 

1286 inv_ctrl_args, 

1287 costs, 

1288 state_trajs, 

1289 ctrl_signals, 

1290 show_animation, 

1291 fig, 

1292 plt_inds, 

1293 cmap, 

1294 ) 

1295 

1296 return state_trajs, costs, ctrl_signals 

1297 

1298 def plan(self, tt, start_dist, end_dist, **kwargs): 

1299 return super().plan(tt, start_dist, end_dist, **kwargs)