Coverage for src/carbs/swarm_estimator/tracker.py: 70%

3695 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-15 06:48 +0000

1"""Implements RFS tracking algorithms. 

2 

3This module contains the classes and data structures 

4for RFS tracking related algorithms. 

5""" 

6 

7import gncpy.filters 

8import numpy as np 

9import numpy.linalg as la 

10import numpy.random as rnd 

11import matplotlib.pyplot as plt 

12from typing import Iterable 

13from matplotlib.patches import Ellipse 

14import matplotlib.animation as animation 

15import abc 

16from copy import deepcopy 

17import warnings 

18import itertools 

19 

20from carbs.utilities.graphs import ( 

21 k_shortest, 

22 murty_m_best, 

23 murty_m_best_all_meas_assigned, 

24) 

25from carbs.utilities.sampling import gibbs, mm_gibbs 

26 

27from gncpy.math import log_sum_exp, get_elem_sym_fnc 

28import gncpy.plotting as pltUtil 

29import gncpy.filters as gfilts 

30import gncpy.errors as gerr 

31 

32import serums.models as smodels 

33from serums.enums import SingleObjectDistance 

34from serums.distances import calculate_ospa, calculate_ospa2, calculate_gospa 

35 

36 

class RandomFiniteSetBase(metaclass=abc.ABCMeta):
    """Generic base class for RFS based filters.

    Attributes
    ----------
    filter : gncpy.filters.BayesFilter
        Filter handling dynamics
    prob_detection : float
        Modeled probability an object is detected
    prob_survive : float
        Modeled probability of object survival
    birth_terms : list
        List of terms in the birth model
    clutter_rate : float
        Rate of clutter
    clutter_den : float
        Density of clutter distribution
    inv_chi2_gate : float
        Chi squared threshold for gating the measurements
    save_covs : bool
        Save covariance matrix for each state during state extraction
    debug_plots : bool
        Saves data needed for extra debugging plots
    ospa : numpy array
        Calculated OSPA value for the given truth data. Must be manually updated
        by a function call.
    ospa_localization : numpy array
        Localization component of the OSPA value. Must be manually updated
        by a function call.
    ospa_cardinality : numpy array
        Cardinality component of the OSPA value. Must be manually updated
        by a function call.
    gospa : numpy array
        Calculated GOSPA value for the given truth data. None until
        :meth:`calculate_gospa` is called.
    gospa_localization : numpy array
        Localization component of the GOSPA value. None until
        :meth:`calculate_gospa` is called.
    gospa_cardinality : numpy array
        Cardinality component of the GOSPA value. None until
        :meth:`calculate_gospa` is called.
    enable_spawning : bool
        Flag for enabling spawning.
    spawn_cov : N x N numpy array
        Covariance for spawned targets.
    spawn_weight : float
        Weight for spawned targets.
    """

    def __init__(
        self,
        in_filter: "gncpy.filters.BayesFilter" = None,
        prob_detection: float = 1,
        prob_survive: float = 1,
        birth_terms: list = None,
        clutter_rate: float = 0,
        clutter_den: float = 0,
        inv_chi2_gate: float = 0,
        save_covs: bool = False,
        debug_plots: bool = False,
        enable_spawning: bool = False,
        spawn_cov: np.ndarray = None,
        spawn_weight: float = None,
    ):
        """Initialize an object.

        Parameters
        ----------
        in_filter
            Inner filter object. A deep copy is stored.
        prob_detection
            Probability of detection.
        prob_survive
            Probability of survival.
        birth_terms
            Birth model. A deep copy is stored; None means no birth terms.
        clutter_rate
            Clutter rate per scan.
        clutter_den
            Clutter density. A numpy array is converted to a Python scalar
            with ``.item()``.
        inv_chi2_gate
            Inverse Chi^2 gating threshold.
        save_covs
            Flag for saving covariances.
        debug_plots
            Flag for enabling debug plots.
        enable_spawning
            Flag for enabling spawning.
        spawn_cov
            Covariance for spawned targets.
        spawn_weight
            Weight for spawned targets.
        """
        if birth_terms is None:
            birth_terms = []
        self.filter = deepcopy(in_filter)
        self.prob_detection = prob_detection
        self.prob_survive = prob_survive
        self.birth_terms = deepcopy(birth_terms)
        self.clutter_rate = clutter_rate
        if isinstance(clutter_den, np.ndarray):
            clutter_den = clutter_den.item()
        self.clutter_den = clutter_den

        self.inv_chi2_gate = inv_chi2_gate

        self.save_covs = save_covs
        self.debug_plots = debug_plots

        self.ospa = None
        self.ospa_localization = None
        self.ospa_cardinality = None
        self._ospa_params = {}

        # Initialized so plot_gospa_history can report "not calculated yet"
        # instead of failing with an AttributeError.
        self.gospa = None
        self.gospa_localization = None
        self.gospa_cardinality = None
        self._gospa_params = {}

        self._states = []  # local copy for internal modification
        # list of lists, one per timestep, inner is all meas at time
        self._meas_tab = []
        self._covs = []  # local copy for internal modification
        self.enable_spawning = enable_spawning
        self.spawn_cov = spawn_cov
        self.spawn_weight = spawn_weight

        super().__init__()

    @property
    def ospa_method(self):
        """The distance metric used in the OSPA calculation (read only)."""
        if "core" in self._ospa_params:
            return self._ospa_params["core"]
        else:
            return None

    @ospa_method.setter
    def ospa_method(self, val):
        warnings.warn("OSPA method is read only. SKIPPING")

    @abc.abstractmethod
    def save_filter_state(self):
        """Generic method for saving key filter variables.

        This must be overridden in the inherited class. It is recommended to keep
        the signature the same to allow for standardized implementation of
        wrapper classes. This should return a single variable that can be passed
        to the loading function to setup a filter to the same internal state
        as the current instance when this function was called.

        Returns
        -------
        filt_state : dict
            Saved variables. The ``"filter"`` entry is a tuple of the inner
            filter's type and its own saved state, or ``(None, None)`` when no
            inner filter is set. Values are stored by reference, not copied.
        """
        filt_state = {}
        if self.filter is not None:
            filt_state["filter"] = (type(self.filter), self.filter.save_filter_state())
        else:
            filt_state["filter"] = (None, self.filter)
        filt_state["prob_detection"] = self.prob_detection
        filt_state["prob_survive"] = self.prob_survive
        filt_state["birth_terms"] = self.birth_terms
        filt_state["clutter_rate"] = self.clutter_rate
        filt_state["clutter_den"] = self.clutter_den
        filt_state["inv_chi2_gate"] = self.inv_chi2_gate
        filt_state["save_covs"] = self.save_covs
        filt_state["debug_plots"] = self.debug_plots
        filt_state["ospa"] = self.ospa
        filt_state["ospa_localization"] = self.ospa_localization
        filt_state["ospa_cardinality"] = self.ospa_cardinality

        filt_state["_states"] = self._states
        filt_state["_meas_tab"] = self._meas_tab
        filt_state["_covs"] = self._covs
        filt_state["_ospa_params"] = self._ospa_params

        # The entries below are missing from states saved by older versions;
        # load_filter_state tolerates their absence.
        filt_state["gospa"] = getattr(self, "gospa", None)
        filt_state["gospa_localization"] = getattr(self, "gospa_localization", None)
        filt_state["gospa_cardinality"] = getattr(self, "gospa_cardinality", None)
        filt_state["_gospa_params"] = getattr(self, "_gospa_params", {})
        filt_state["enable_spawning"] = getattr(self, "enable_spawning", False)
        filt_state["spawn_cov"] = getattr(self, "spawn_cov", None)
        filt_state["spawn_weight"] = getattr(self, "spawn_weight", None)

        return filt_state

    @abc.abstractmethod
    def load_filter_state(self, filt_state):
        """Generic method for loading key filter variables.

        This must be overridden in the inherited class. It is recommended to keep
        the signature the same to allow for standardized implementation of
        wrapper classes. This initializes all internal variables saved by the
        filter save function such that a new instance would generate the same
        output as the original instance that called the save function.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        cls_type = filt_state["filter"][0]
        if cls_type is not None:
            self.filter = cls_type()
            self.filter.load_filter_state(filt_state["filter"][1])
        else:
            # No inner filter was saved. Restore the stored value (None), not
            # the (type, state) tuple itself.
            self.filter = filt_state["filter"][1]
        self.prob_detection = filt_state["prob_detection"]
        self.prob_survive = filt_state["prob_survive"]
        self.birth_terms = filt_state["birth_terms"]
        self.clutter_rate = filt_state["clutter_rate"]
        self.clutter_den = filt_state["clutter_den"]
        self.inv_chi2_gate = filt_state["inv_chi2_gate"]
        self.save_covs = filt_state["save_covs"]
        self.debug_plots = filt_state["debug_plots"]
        self.ospa = filt_state["ospa"]
        self.ospa_localization = filt_state["ospa_localization"]
        self.ospa_cardinality = filt_state["ospa_cardinality"]

        self._states = filt_state["_states"]
        self._meas_tab = filt_state["_meas_tab"]
        self._covs = filt_state["_covs"]
        self._ospa_params = filt_state["_ospa_params"]

        # Optional entries: fall back to the current value (or the default)
        # so that states saved by older versions still load.
        optional = (
            ("gospa", None),
            ("gospa_localization", None),
            ("gospa_cardinality", None),
            ("_gospa_params", {}),
            ("enable_spawning", False),
            ("spawn_cov", None),
            ("spawn_weight", None),
        )
        for key, default in optional:
            setattr(self, key, filt_state.get(key, getattr(self, key, default)))

    @property
    def prob_miss_detection(self):
        """Complement of :py:attr:`.swarm_estimator.RandomFiniteSetBase.prob_detection`."""
        return 1 - self.prob_detection

    @property
    def prob_death(self):
        """Complement of :attr:`carbs.swarm_estimator.RandomFiniteSetBase.prob_survive`."""
        return 1 - self.prob_survive

    @property
    def num_birth_terms(self):
        """Number of terms in the birth model."""
        return len(self.birth_terms)

    @abc.abstractmethod
    def predict(self, t, **kwargs):
        """Abstract method for the prediction step.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    @abc.abstractmethod
    def correct(self, t, m, **kwargs):
        """Abstract method for the correction step.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    @abc.abstractmethod
    def extract_states(self, **kwargs):
        """Abstract method for extracting states."""
        pass

    @abc.abstractmethod
    def cleanup(self, **kwargs):
        """Abstract method that performs the cleanup step of the filter.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    def _gate_meas(self, meas, means, covs, meas_mat_args=None, est_meas_args=None):
        """Gates measurements based on current estimates.

        Notes
        -----
        Gating is performed based on a Gaussian noise model.
        See :cite:`Cox1993_AReviewofStatisticalDataAssociationTechniquesforMotionCorrespondence`
        for details on the chi squared test used. A measurement is kept if its
        squared Mahalanobis distance to at least one estimate is strictly less
        than :attr:`inv_chi2_gate`.

        Parameters
        ----------
        meas : list
            2d numpy arrays of each measurement.
        means : list
            2d numpy arrays of each mean.
        covs : list
            2d numpy array of each covariance.
        meas_mat_args : dict, optional
            keyword arguments to pass to the inner filters get measurement
            matrix function. The default of None means no extra arguments.
        est_meas_args : dict, optional
            keyword arguments to pass to the inner filters get estimate
            matrix function. The default of None means no extra arguments.

        Returns
        -------
        list
            2d numpy arrays of valid measurements, in their original order.
        """
        if meas_mat_args is None:
            meas_mat_args = {}
        if est_meas_args is None:
            est_meas_args = {}
        if len(meas) == 0:
            return []
        valid = set()
        for m, p in zip(means, covs):
            meas_mat = self.filter.get_meas_mat(m, **meas_mat_args)
            est = self.filter.get_est_meas(m, **est_meas_args)
            meas_pred_cov = meas_mat @ p @ meas_mat.T + self.filter.meas_noise
            meas_pred_cov = (meas_pred_cov + meas_pred_cov.T) / 2
            # numpy returns the lower factor L with S = L @ L.T, so the squared
            # Mahalanobis distance v.T @ inv(S) @ v equals ||inv(L) @ v||^2.
            # (Applying inv(L).T instead would give v.T @ inv(L.T @ L) @ v.)
            inv_sqrt_m_cov = la.inv(la.cholesky(meas_pred_cov))

            for ii, z in enumerate(meas):
                if ii in valid:
                    continue
                inov = z - est
                dist = np.sum((inv_sqrt_m_cov @ inov) ** 2)
                if dist < self.inv_chi2_gate:
                    valid.add(ii)
        return [meas[ii] for ii in sorted(valid)]

    def _ospa_setup_tmat(self, truth, state_dim, true_covs, state_inds):
        """Pack the truth data into NaN padded arrays for the OSPA functions.

        None and empty truth entries are skipped. The remaining objects fill
        consecutive columns, and each covariance is stored in the same column
        as the truth state at the same position.

        Returns
        -------
        true_mat : numpy array
            ``state_dim x num_timesteps x max_objects`` truth states.
        true_cov_mat : numpy array
            ``state_dim x state_dim x num_timesteps x max_objects`` truth
            covariances (NaN where not given).
        """
        num_timesteps = len(truth)
        num_objs = 0
        for lst in truth:
            num_objs = max(
                num_objs, sum(1 for _x in lst if _x is not None and _x.size > 0)
            )

        true_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
        true_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))

        if true_covs is not None:
            true_covs = list(true_covs)

        for tt, lst in enumerate(truth):
            if true_covs is not None and tt < len(true_covs):
                step_covs = true_covs[tt]
            else:
                step_covs = None
            obj_num = 0
            for ii, s in enumerate(lst):
                if s is None or s.size == 0:
                    continue
                true_mat[:, tt, obj_num] = s.ravel()[state_inds]
                if (
                    step_covs is not None
                    and ii < len(step_covs)
                    and step_covs[ii] is not None
                ):
                    c = step_covs[ii]
                    true_cov_mat[:, :, tt, obj_num] = c[state_inds][:, state_inds]
                obj_num += 1
        return true_mat, true_cov_mat

    def _ospa_setup_emat(self, state_dim, state_inds):
        """Pack the extracted states into NaN padded arrays for the OSPA functions.

        None and empty states are skipped. The remaining states fill
        consecutive columns. Covariances are only filled when
        :attr:`save_covs` is set, and each one uses the column of the state at
        the same position.

        Returns
        -------
        est_mat : numpy array
            ``state_dim x num_timesteps x max_objects`` estimated states.
        est_cov_mat : numpy array
            ``state_dim x state_dim x num_timesteps x max_objects`` estimated
            covariances (NaN where not available).
        """
        num_timesteps = len(self._states)
        num_objs = 0
        for lst in self._states:
            num_objs = max(num_objs, sum(1 for _x in lst if _x is not None))

        est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
        est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))

        for tt, lst in enumerate(self._states):
            if self.save_covs and tt < len(self._covs):
                step_covs = self._covs[tt]
            else:
                step_covs = None
            obj_num = 0
            for ii, s in enumerate(lst):
                if s is None or s.size == 0:
                    continue
                est_mat[:, tt, obj_num] = s.ravel()[state_inds]
                if (
                    step_covs is not None
                    and ii < len(step_covs)
                    and step_covs[ii] is not None
                ):
                    c = step_covs[ii]
                    est_cov_mat[:, :, tt, obj_num] = c[state_inds][:, state_inds]
                obj_num += 1
        return est_mat, est_cov_mat

    def _ospa_input_check(self, core_method, truth, true_covs):
        """Validate the requested distance method, falling back to Euclidean.

        Mahalanobis needs saved estimate covariances, and Hellinger needs
        truth covariances. A warning is issued when a fallback happens.
        """
        if core_method is None:
            core_method = SingleObjectDistance.EUCLIDEAN
        elif core_method is SingleObjectDistance.MAHALANOBIS and not self.save_covs:
            msg = "Must save covariances to calculate {:s} OSPA. Using {:s} instead"
            warnings.warn(msg.format(core_method, SingleObjectDistance.EUCLIDEAN))
            core_method = SingleObjectDistance.EUCLIDEAN
        elif core_method is SingleObjectDistance.HELLINGER and true_covs is None:
            msg = "Must save covariances to calculate {:s} OSPA. Using {:s} instead"
            warnings.warn(msg.format(core_method, SingleObjectDistance.EUCLIDEAN))
            core_method = SingleObjectDistance.EUCLIDEAN
        return core_method

    def _ospa_find_s_dim(self, truth):
        """Return the size of the first non-empty truth state.

        If the truth has none, the first non-empty extracted state is used
        instead. Returns None when neither contains a non-empty state.
        """
        for source in (truth, self._states):
            for lst in source:
                for _x in lst:
                    if _x is not None and _x.size > 0:
                        return _x.size
        return None

    def _ospa_resolve_inds(self, truth, state_inds):
        """Determine the state dimension and indices used by (G)OSPA.

        Returns
        -------
        state_dim : int or None
            Number of state elements compared, None if it could not be found.
        state_inds : iterable of int or None
            Indices into each state vector. When not given, all indices are
            used. None when the dimension could not be found.
        """
        if state_inds is not None:
            return len(state_inds), state_inds
        state_dim = self._ospa_find_s_dim(truth)
        if state_dim is None:
            return None, None
        return state_dim, range(state_dim)

    def calculate_ospa(
        self,
        truth: Iterable[Iterable[np.ndarray]],
        c: float,
        p: float,
        core_method: "SingleObjectDistance" = None,
        true_covs: Iterable[Iterable[np.ndarray]] = None,
        state_inds: Iterable[int] = None,
    ):
        """Calculates the OSPA distance between the truth at all timesteps.

        Wrapper for :func:`serums.distances.calculate_ospa`. If the state
        dimension cannot be determined, a warning is issued and the OSPA
        arrays are set to zeros.

        Parameters
        ----------
        truth : list
            Each element represents a timestep and is a list of N x 1 numpy array,
            one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty given false point estimate.
        p : float
            The power of the distance term. Higher values penalize outliers
            more.
        core_method : :class:`serums.enums.SingleObjectDistance`, Optional
            The main distance measure to use for the localization component.
            The default value of None implies :attr:`.SingleObjectDistance.EUCLIDEAN`.
        true_covs : list, Optional
            Each element represents a timestep and is a list of N x N numpy arrays
            corresponding to the uncertainty about the true states. Note the
            order must be consistent with the truth data given. This is only
            needed for core methods :attr:`SingleObjectDistance.HELLINGER`. The default
            value is None.
        state_inds : list, optional
            Indices in the state vector to use, will be applied to the truth
            data as well. The default is None which means the full state is
            used.
        """
        # error checking on optional input arguments
        core_method = self._ospa_input_check(core_method, truth, true_covs)

        # setup data structures (dimension checked before building indices)
        state_dim, state_inds = self._ospa_resolve_inds(truth, state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING OSPA calculation")

            nt = len(self._states)
            self.ospa = np.zeros(nt)
            self.ospa_localization = np.zeros(nt)
            self.ospa_cardinality = np.zeros(nt)
            self._ospa_params["core"] = core_method
            self._ospa_params["cutoff"] = c
            self._ospa_params["power"] = p
            return

        true_mat, true_cov_mat = self._ospa_setup_tmat(
            truth, state_dim, true_covs, state_inds
        )
        est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)

        # find OSPA
        (
            self.ospa,
            self.ospa_localization,
            self.ospa_cardinality,
            self._ospa_params["core"],
            self._ospa_params["cutoff"],
            self._ospa_params["power"],
        ) = calculate_ospa(
            est_mat,
            true_mat,
            c,
            p,
            use_empty=True,
            core_method=core_method,
            true_cov_mat=true_cov_mat,
            est_cov_mat=est_cov_mat,
        )[0:6]

    def calculate_gospa(
        self,
        truth: Iterable[Iterable[np.ndarray]],
        c: float,
        p: float,
        a: int,
        core_method: "SingleObjectDistance" = None,
        true_covs: Iterable[Iterable[np.ndarray]] = None,
        state_inds: Iterable[int] = None,
    ):
        """Calculates the GOSPA distance between the truth at all timesteps.

        Wrapper for :func:`serums.distances.calculate_gospa`. If the state
        dimension cannot be determined, a warning is issued and the GOSPA
        arrays are set to zeros.

        Parameters
        ----------
        truth : list
            Each element represents a timestep and is a list of N x 1 numpy array,
            one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty given false point estimate.
        p : float
            The power of the distance term. Higher values penalize outliers
            more.
        a : int
            The normalization factor of the distance term. Appropriately penalizes missed
            or false detection of tracks rather than normalizing by the total maximum
            cardinality.
        core_method : :class:`serums.enums.SingleObjectDistance`, Optional
            The main distance measure to use for the localization component.
            The default value of None implies :attr:`.SingleObjectDistance.EUCLIDEAN`.
        true_covs : list, Optional
            Each element represents a timestep and is a list of N x N numpy arrays
            corresponding to the uncertainty about the true states. Note the
            order must be consistent with the truth data given. This is only
            needed for core methods :attr:`SingleObjectDistance.HELLINGER`. The default
            value is None.
        state_inds : list, optional
            Indices in the state vector to use, will be applied to the truth
            data as well. The default is None which means the full state is
            used.
        """
        # error checking on optional input arguments
        core_method = self._ospa_input_check(core_method, truth, true_covs)

        # setup data structures (dimension checked before building indices)
        state_dim, state_inds = self._ospa_resolve_inds(truth, state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING GOSPA calculation")

            nt = len(self._states)
            self.gospa = np.zeros(nt)
            self.gospa_localization = np.zeros(nt)
            self.gospa_cardinality = np.zeros(nt)
            self._gospa_params["core"] = core_method
            self._gospa_params["cutoff"] = c
            self._gospa_params["power"] = p
            self._gospa_params["normalization"] = a
            return

        true_mat, true_cov_mat = self._ospa_setup_tmat(
            truth, state_dim, true_covs, state_inds
        )
        est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)

        # find GOSPA
        (
            self.gospa,
            self.gospa_localization,
            self.gospa_cardinality,
            self._gospa_params["core"],
            self._gospa_params["cutoff"],
            self._gospa_params["power"],
            self._gospa_params["normalization"],
        ) = calculate_gospa(
            est_mat,
            true_mat,
            c,
            p,
            a,
            use_empty=True,
            core_method=core_method,
            true_cov_mat=true_cov_mat,
            est_cov_mat=est_cov_mat,
        )[0:7]

    def _plt_ospa_hist(self, y_val, time_units, time, ttl, y_lbl, opts):
        """Plot one metric history on a single set of axes.

        A new single-subplot figure is created when ``opts["f_hndl"]`` is
        None. Otherwise the first axes of that figure is drawn on.
        """
        fig = opts["f_hndl"]

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)
        if time is None:
            time = np.arange(y_val.size, dtype=int)
        fig.axes[0].grid(True)
        fig.axes[0].ticklabel_format(useOffset=False)
        fig.axes[0].plot(time, y_val)

        pltUtil.set_title_label(
            fig, 0, opts, ttl=ttl, x_lbl="Time ({})".format(time_units), y_lbl=y_lbl
        )
        fig.tight_layout()

        return fig

    def _plt_ospa_hist_subs(self, y_vals, time_units, time, ttl, y_lbls, opts):
        """Plot several metric histories, one per vertically stacked subplot.

        When ``opts["f_hndl"]`` is None, a new figure is created with one
        subplot per entry of ``y_vals``, sharing the x axis. Only the bottom
        subplot gets the time label.
        """
        fig = opts["f_hndl"]
        new_plot = fig is None
        num_subs = len(y_vals)

        if new_plot:
            fig = plt.figure()
            pltUtil.set_title_label(fig, 0, opts, ttl=ttl)
        for ax, (y_val, y_lbl) in enumerate(zip(y_vals, y_lbls)):
            if new_plot:
                if ax > 0:
                    kwargs = {"sharex": fig.axes[0]}
                else:
                    kwargs = {}
                fig.add_subplot(num_subs, 1, ax + 1, **kwargs)
                fig.axes[ax].grid(True)
                fig.axes[ax].ticklabel_format(useOffset=False)
                kwargs = {"y_lbl": y_lbl}
                if ax == len(y_vals) - 1:
                    kwargs["x_lbl"] = "Time ({})".format(time_units)
                pltUtil.set_title_label(fig, ax, opts, **kwargs)
            if time is None:
                time = np.arange(y_val.size, dtype=int)
            fig.axes[ax].plot(time, y_val)
        if new_plot:
            fig.tight_layout()
        return fig

    def plot_ospa_history(
        self,
        time_units="index",
        time=None,
        main_opts=None,
        sub_opts=None,
        plot_subs=True,
    ):
        """Plots the OSPA history.

        This requires that the OSPA has been calculated by the appropriate
        function first. Otherwise a warning is issued and None is returned.

        Parameters
        ----------
        time_units : string, optional
            Text representing the units of time in the plot. The default is
            'index'.
        time : numpy array, optional
            Vector to use for the x-axis of the plot. If none is given then
            vector indices are used. The default is None.
        main_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the main plot.
        sub_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the sub plot.
        plot_subs : bool, optional
            Flag indicating if the component statistics (cardinality and
            localization) should also be plotted.

        Returns
        -------
        figs : dict
            Dictionary of matplotlib figure objects the data was plotted on.
        """
        if self.ospa is None:
            warnings.warn("OSPA must be calculated before plotting")
            return
        if main_opts is None:
            main_opts = pltUtil.init_plotting_opts()
        if sub_opts is None and plot_subs:
            sub_opts = pltUtil.init_plotting_opts()
        # 'g' accepts both int and float powers ('d' fails for floats)
        fmt = "{:s} OSPA (c = {:.1f}, p = {:g})"
        ttl = fmt.format(
            self._ospa_params["core"],
            self._ospa_params["cutoff"],
            self._ospa_params["power"],
        )
        y_lbl = "OSPA"

        figs = {}
        figs["OSPA"] = self._plt_ospa_hist(
            self.ospa, time_units, time, ttl, y_lbl, main_opts
        )

        if plot_subs:
            fmt = "{:s} OSPA Components (c = {:.1f}, p = {:g})"
            ttl = fmt.format(
                self._ospa_params["core"],
                self._ospa_params["cutoff"],
                self._ospa_params["power"],
            )
            y_lbls = ["Localization", "Cardinality"]
            figs["OSPA_subs"] = self._plt_ospa_hist_subs(
                [self.ospa_localization, self.ospa_cardinality],
                time_units,
                time,
                ttl,
                y_lbls,
                sub_opts,
            )
        return figs

    def plot_gospa_history(
        self,
        time_units="index",
        time=None,
        main_opts=None,
        sub_opts=None,
        plot_subs=True,
    ):
        """Plots the GOSPA history.

        This requires that the GOSPA has been calculated by the appropriate
        function first. Otherwise a warning is issued and None is returned.

        Parameters
        ----------
        time_units : string, optional
            Text representing the units of time in the plot. The default is
            'index'.
        time : numpy array, optional
            Vector to use for the x-axis of the plot. If none is given then
            vector indices are used. The default is None.
        main_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the main plot.
        sub_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the sub plot.
        plot_subs : bool, optional
            Flag indicating if the component statistics (cardinality and
            localization) should also be plotted.

        Returns
        -------
        figs : dict
            Dictionary of matplotlib figure objects the data was plotted on.
        """
        if getattr(self, "gospa", None) is None:
            warnings.warn("GOSPA must be calculated before plotting")
            return
        if main_opts is None:
            main_opts = pltUtil.init_plotting_opts()
        if sub_opts is None and plot_subs:
            sub_opts = pltUtil.init_plotting_opts()
        # 'g' accepts both int and float values ('d' fails for floats)
        fmt = "{:s} GOSPA (c = {:.1f}, p = {:g}, a = {:g})"
        ttl = fmt.format(
            self._gospa_params["core"],
            self._gospa_params["cutoff"],
            self._gospa_params["power"],
            self._gospa_params["normalization"],
        )
        y_lbl = "GOSPA"

        figs = {}
        figs["GOSPA"] = self._plt_ospa_hist(
            self.gospa, time_units, time, ttl, y_lbl, main_opts
        )

        if plot_subs:
            fmt = "{:s} GOSPA Components (c = {:.1f}, p = {:g}, a = {:g})"
            ttl = fmt.format(
                self._gospa_params["core"],
                self._gospa_params["cutoff"],
                self._gospa_params["power"],
                self._gospa_params["normalization"],
            )
            y_lbls = ["Localization", "Cardinality"]
            figs["GOSPA_subs"] = self._plt_ospa_hist_subs(
                [self.gospa_localization, self.gospa_cardinality],
                time_units,
                time,
                ttl,
                y_lbls,
                sub_opts,
            )
        return figs

803 

804 

805class ProbabilityHypothesisDensity(RandomFiniteSetBase): 

806 """Implements the Probability Hypothesis Density filter. 

807 

808 The kwargs in the constructor are passed through to the parent constructor. 

809 

810 Notes 

811 ----- 

812 The filter implementation is based on :cite:`Vo2006_TheGaussianMixtureProbabilityHypothesisDensityFilter` 

813 

814 Attributes 

815 ---------- 

816 gating_on : bool 

817 flag indicating if measurement gating should be performed. The 

818 default is False. 

819 inv_chi2_gate : float 

820 threshold for the chi squared test in the measurement gating. The 

821 default is 0. 

822 extract_threshold : float 

823 threshold for extracting the state. The default is 0.5. 

824 prune_threshold : float 

825 threshold for removing hypotheses. The default is 10**-5. 

826 merge_threshold : float 

827 threshold for merging hypotheses. The default is 4. 

828 max_gauss : int 

829 max number of gaussians to use. The default is 100. 

830 

831 """ 

832 

833 def __init__( 

834 self, 

835 gating_on=False, 

836 inv_chi2_gate=0, 

837 extract_threshold=0.5, 

838 prune_threshold=1e-5, 

839 merge_threshold=4, 

840 max_gauss=100, 

841 **kwargs, 

842 ): 

843 self.gating_on = gating_on 

844 self.inv_chi2_gate = inv_chi2_gate 

845 self.extract_threshold = extract_threshold 

846 self.prune_threshold = prune_threshold 

847 self.merge_threshold = merge_threshold 

848 self.max_gauss = max_gauss 

849 

850 self._gaussMix = smodels.GaussianMixture() 

851 

852 super().__init__(**kwargs) 

853 

854 def save_filter_state(self): 

855 """Saves filter variables so they can be restored later.""" 

856 filt_state = super().save_filter_state() 

857 

858 raise RuntimeError("Not implmented yet") 

859 return filt_state 

860 

861 def load_filter_state(self, filt_state): 

862 """Initializes filter using saved filter state. 

863 

864 Attributes 

865 ---------- 

866 filt_state : dict 

867 Dictionary generated by :meth:`save_filter_state`. 

868 """ 

869 super().load_filter_state(filt_state) 

870 

871 raise RuntimeError("Not implmented yet") 

872 

873 @property 

874 def states(self): 

875 """Read only list of extracted states. 

876 

877 This is a list with 1 element per timestep, and each element is a list 

878 of the best states extracted at that timestep. The order of each 

879 element corresponds to the label order. 

880 """ 

881 if len(self._states) > 0: 

882 return self._states[-1] 

883 else: 

884 return [] 

885 

886 @property 

887 def covariances(self): 

888 """Read only list of extracted covariances. 

889 

890 This is a list with 1 element per timestep, and each element is a list 

891 of the best covariances extracted at that timestep. The order of each 

892 element corresponds to the state order. 

893 

894 Warns 

895 ----- 

896 RuntimeWarning 

897 If the class is not saving the covariances, and returns an 

898 empty list 

899 """ 

900 if not self.save_covs: 

901 warnings.warn("Not saving covariances") 

902 return [] 

903 if len(self._covs) > 0: 

904 return self._covs[-1] 

905 else: 

906 return [] 

907 

908 @property 

909 def cardinality(self): 

910 """Read only cardinality of the RFS.""" 

911 if len(self._states) == 0: 

912 return 0 

913 else: 

914 return len(self._states[-1]) 

915 

916 def _gen_spawned_targets(self, gaussMix): 

917 if self.spawn_cov is not None and self.spawn_weight is not None: 

918 gauss_list = [ 

919 smodels.Gaussian(mean=m.copy(), covariance=self.spawn_cov.copy()) 

920 for m in gaussMix.means 

921 ] 

922 return smodels.GaussianMixture( 

923 distributions=gauss_list, 

924 weights=[self.spawn_weight for ii in range(len(gauss_list))], 

925 ) 

926 else: 

927 raise RuntimeError( 

928 "self.spawn_cov and self.spawn_weight must be specified." 

929 ) 

930 

def predict(self, timestep, filt_args=None):
    """Prediction step of the PHD filter.

    Propagates every Gaussian component to the next time step through the
    inner filter. It then appends the spawned components (when spawning is
    enabled) and the components of every birth term. Because this calls the
    inner filter's predict function, ``filt_args`` must contain any
    information needed by that function.

    Parameters
    ----------
    timestep : float
        Current timestep.
    filt_args : dict, optional
        Passed to the inner filter. The default is None, which is treated
        as an empty dict.

    Returns
    -------
    None.
    """
    # avoid a shared mutable default argument
    if filt_args is None:
        filt_args = {}

    # spawned components are seeded from the mixture *before* propagation
    if self.enable_spawning:
        spawn_mix = self._gen_spawned_targets(self._gaussMix)

    self._gaussMix = self._predict_prob_density(timestep, self._gaussMix, filt_args)

    if self.enable_spawning:
        self._gaussMix.add_components(
            spawn_mix.means, spawn_mix.covariances, spawn_mix.weights
        )

    for gm in self.birth_terms:
        self._gaussMix.add_components(gm.means, gm.covariances, gm.weights)

964 

def _predict_prob_density(self, timestep, probDensity, filt_args):
    """Predicts the probability density.

    Every component weight is scaled by ``self.prob_survive``, and every
    (mean, covariance) pair is pushed through the inner filter's predict.

    Parameters
    ----------
    timestep : float
        Current timestep.
    probDensity : :class:`serums.models.GaussianMixture`
        Probability density to perform prediction on.
    filt_args : dict
        Passed directly to the inner filter.

    Returns
    -------
    gm : :class:`serums.models.GaussianMixture`
        Predicted Gaussian mixture.
    """
    p_s = self.prob_survive
    new_weights = [p_s * w for w in probDensity.weights.copy()]

    new_means = []
    new_covs = []
    for mean, cov in zip(probDensity.means, probDensity.covariances):
        # the inner filter reads and writes its covariance via self.filter.cov
        self.filter.cov = cov
        new_means.append(self.filter.predict(timestep, mean, **filt_args))
        new_covs.append(self.filter.cov.copy())

    return smodels.GaussianMixture(
        means=new_means, covariances=new_covs, weights=new_weights
    )

997 ) 

998 

def correct(
    self, timestep, meas_in, meas_mat_args=None, est_meas_args=None, filt_args=None
):
    """Correction step of the PHD filter.

    Optionally gates the measurements (when ``self.gating_on``) and records
    the used measurements. It then builds the posterior mixture from the
    detection terms produced by the density correction, plus a copy of the
    predicted mixture scaled by ``self.prob_miss_detection`` (missed
    detections).

    Parameters
    ----------
    timestep : float
        Current timestep.
    meas_in : list
        2d numpy arrays representing a measurement. Not modified.
    meas_mat_args : dict, optional
        Keyword arguments to pass to the inner filters get measurement
        matrix function. Only used if gating is on. The default is None
        (treated as an empty dict).
    est_meas_args : dict, optional
        Keyword arguments to pass to the inner filters estimate
        measurements function. Only used if gating is on. The default is
        None (treated as an empty dict).
    filt_args : dict, optional
        Keyword arguments to pass to the inner filters correct function.
        The default is None (treated as an empty dict).

    Todo
    ----
    Fix the measurement gating

    Returns
    -------
    None.
    """
    # avoid shared mutable default arguments
    if meas_mat_args is None:
        meas_mat_args = {}
    if est_meas_args is None:
        est_meas_args = {}
    if filt_args is None:
        filt_args = {}

    meas = deepcopy(meas_in)

    if self.gating_on:
        meas = self._gate_meas(
            meas,
            self._gaussMix.means,
            self._gaussMix.covariances,
            meas_mat_args,
            est_meas_args,
        )
    self._meas_tab.append(meas)

    # missed-detection terms: the predicted mixture scaled by (1 - p_D)
    gmix = deepcopy(self._gaussMix)
    gmix.weights = [self.prob_miss_detection * x for x in gmix.weights]
    gm = self._correct_prob_density(timestep, meas, self._gaussMix, filt_args)
    gm.add_components(gmix.means, gmix.covariances, gmix.weights)

    self._gaussMix = gm

1052 

def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
    """Corrects the probability densities.

    For every measurement, each component of ``probDensity`` is corrected
    by the inner filter. The weight of a corrected component is
    ``prob_detection * w * qz``, where ``qz`` is the second value returned
    by the inner filter's correct. The weights for one measurement are
    normalized by ``clutter_rate * clutter_den`` plus their sum.

    Parameters
    ----------
    timestep : float
        Current timestep.
    meas : list
        2d numpy arrays of each measurement.
    probDensity : :py:class:`serums.models.GaussianMixture`
        Probability density to run correction on.
    filt_args : dict
        Arguments to pass to the inner filter correct function.

    Returns
    -------
    gm : :py:class:`serums.models.GaussianMixture`
        Corrected probability density, one component per
        (measurement, component) pair in measurement-major order. Only
        detection terms are produced here.
    """
    means = []
    covariances = []
    weights = []
    det_weights = [self.prob_detection * x for x in probDensity.weights]
    clutter = self.clutter_rate * self.clutter_den
    for z in meas:
        w_lst = []
        for state, cov, det_w in zip(
            probDensity.means, probDensity.covariances, det_weights
        ):
            self.filter.cov = cov
            (mean, qz) = self.filter.correct(timestep, z, state, **filt_args)
            means.append(mean)
            # copy so stored covariances never alias the filter's internal
            # matrix (consistent with the prediction step)
            covariances.append(self.filter.cov.copy())
            w_lst.append(qz * det_w)
        norm = clutter + sum(w_lst)
        weights.extend([w / norm for w in w_lst])
    return smodels.GaussianMixture(
        means=means, covariances=covariances, weights=weights
    )

1095 

1096 def _prune(self): 

1097 """Removes hypotheses below a threshold. 

1098 

1099 This should be called once per time step after the correction and 

1100 before the state extraction. 

1101 """ 

1102 inds = np.where(np.asarray(self._gaussMix.weights) < self.prune_threshold)[0] 

1103 self._gaussMix.remove_components(inds.flatten().tolist()) 

1104 return inds 

1105 

def _merge(self):
    """Merges nearby hypotheses.

    This is a greedy merge. It repeatedly takes the unmerged component with
    the largest weight and collects every unmerged component whose squared
    Mahalanobis distance to it is at most ``self.merge_threshold``; the
    distance uses the covariance of the picked component. Each group is
    replaced by a single moment-matched component with:

    - weight equal to the sum of the group weights,
    - mean equal to the weighted average of the means,
    - covariance equal to the weighted average of ``P_i + d_i d_i^T``, where
      ``d_i`` is the new mean minus ``m_i``.

    The mixture is then rebuilt from the merged components.
    """
    weights = list(self._gaussMix.weights)
    means = self._gaussMix.means
    covs = self._gaussMix.covariances
    remaining = set(range(len(means)))

    w_lst = []
    m_lst = []
    p_lst = []
    while remaining:
        candidates = sorted(remaining)
        # heaviest remaining component; ties go to the lowest index (like np.argmax)
        jj = max(candidates, key=lambda idx: weights[idx])
        inv_cov = la.inv(covs[jj])

        # jj always belongs to its own group, so every pass removes at least
        # one index and the loop terminates for any threshold
        comp_inds = [jj]
        for ii in candidates:
            if ii == jj:
                continue
            diff = means[ii] - means[jj]
            if np.squeeze(diff.T @ inv_cov @ diff) <= self.merge_threshold:
                comp_inds.append(ii)

        w_new = sum(weights[ii] for ii in comp_inds)
        m_new = sum(weights[ii] * means[ii] for ii in comp_inds) / w_new

        # moment matching: include the spread of the means about the new mean
        p_new = np.zeros_like(np.asarray(covs[jj], dtype=float))
        for ii in comp_inds:
            spread = np.reshape(m_new - means[ii], (-1, 1))
            p_new = p_new + weights[ii] * (covs[ii] + spread @ spread.T)
        p_new = p_new / w_new

        w_lst.append(w_new)
        m_lst.append(m_new)
        p_lst.append(p_new)

        remaining.difference_update(comp_inds)

    self._gaussMix = smodels.GaussianMixture(
        means=m_lst, covariances=p_lst, weights=w_lst
    )

1152 

1153 def _cap(self): 

1154 """Removes least likely hypotheses until a maximum number is reached. 

1155 

1156 This should be called once per time step after pruning and 

1157 before the state extraction. 

1158 """ 

1159 if len(self._gaussMix.weights) > self.max_gauss: 

1160 idx = np.argsort(self._gaussMix.weights) 

1161 w = sum(self._gaussMix.weights) 

1162 self._gaussMix.remove_components(idx[0 : -self.max_gauss]) 

1163 self._gaussMix.weights = [ 

1164 x * (w / sum(self._gaussMix.weights)) for x in self._gaussMix.weights 

1165 ] 

1166 return idx[0 : -self.max_gauss].tolist() 

1167 return [] 

1168 

def extract_states(self):
    """Extracts the best state estimates.

    Every component with weight at least ``self.extract_threshold``
    contributes its mean ``round(weight)`` times. Its covariance is added
    the same number of times when ``self.save_covs`` is set. One list of
    states (and of covariances) is appended per call. This should be
    called once per time step after the correction function.
    """
    weights = np.asarray(self._gaussMix.weights)
    selected = np.flatnonzero(weights >= self.extract_threshold)

    states = []
    covs = []
    for idx in map(int, selected):
        reps = round(self._gaussMix.weights[idx])
        states += [self._gaussMix.means[idx]] * reps
        if self.save_covs:
            covs += [self._gaussMix.covariances[idx]] * reps

    self._states.append(states)
    if self.save_covs:
        self._covs.append(covs)

1188 

def cleanup(
    self,
    enable_prune=True,
    enable_cap=True,
    enable_merge=True,
    enable_extract=True,
    extract_kwargs=None,
):
    """Performs the cleanup step of the filter.

    This can prune, merge, cap, and extract states, in that order. It must
    be called once per timestep. If this is called with `enable_extract` set
    to true then the extract states method does not need to be called
    separately. It is recommended to call this function instead of
    :meth:`extract_states` directly.

    Parameters
    ----------
    enable_prune : bool, optional
        Flag indicating if pruning should be performed. The default is True.
    enable_cap : bool, optional
        Flag indicating if capping should be performed. The default is True.
    enable_merge : bool, optional
        Flag indicating if merging should be performed. The default is True.
    enable_extract : bool, optional
        Flag indicating if state extraction should be performed. The default is True.
    extract_kwargs : dict, optional
        Extra keyword arguments to pass to the extract function. The default
        is None, which passes no extra arguments.
    """
    if enable_prune:
        self._prune()
    if enable_merge:
        self._merge()
    if enable_cap:
        self._cap()
    if enable_extract:
        if extract_kwargs is None:
            extract_kwargs = {}
        self.extract_states(**extract_kwargs)

1229 

def __ani_state_plotting(
    self,
    f_hndl,
    tt,
    states,
    show_sig,
    plt_inds,
    sig_bnd,
    color,
    marker,
    state_lbl,
    added_sig_lbl,
    added_state_lbl,
    scat=None,
):
    """Draws the state estimates of a single animation frame.

    Helper for the animated state plot. It updates (or creates) the scatter
    of state estimates on the first axes of ``f_hndl``. When requested, it
    also adds error-ellipse patches built from the saved covariances of
    timestep ``tt``.

    Parameters
    ----------
    f_hndl : matplotlib figure
        Figure whose first axes is drawn on.
    tt : int
        Timestep index into the saved covariances.
    states : list
        State column vectors for this timestep.
    show_sig : bool
        Flag indicating if error ellipses should be drawn.
    plt_inds : list
        Indices of the state vector to plot; the first two are used.
    sig_bnd : int or float
        Sigma bound passed to the ellipse calculation and used in the
        ellipse legend label.
    color : tuple
        Color of the states and ellipses.
    marker : str
        Marker for the state scatter.
    state_lbl : str
        Legend label for the states.
    added_sig_lbl : bool
        True if the ellipse legend label was already added.
    added_state_lbl : bool
        True if the state legend label was already added.
    scat : matplotlib scatter artist, optional
        Existing scatter to update; a new one is created when None.

    Returns
    -------
    scat
        The scatter artist holding this frame's state estimates.
    """
    ax = f_hndl.axes[0]
    if scat is None:
        # only a scatter created before the label exists should carry it
        extra = {} if added_state_lbl else {"label": state_lbl}
        scat = ax.scatter(
            [], [], color=color, edgecolors=(0, 0, 0), marker=marker, **extra
        )
    if len(states) == 0:
        # clear the previous frame's points so they are not shown as current
        scat.set_offsets(np.empty((0, 2)))
        return scat

    x = np.concatenate(states, axis=1)
    if show_sig:
        i0, i1 = plt_inds[0], plt_inds[1]
        need_sig_lbl = not added_sig_lbl
        for ii, cov in enumerate(self._covs[tt]):
            sig = np.array(
                [[cov[i0, i0], cov[i0, i1]], [cov[i1, i0], cov[i1, i1]]],
                dtype=float,
            )
            w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
            extra = {}
            if need_sig_lbl:
                # label only the first ellipse to avoid duplicate legend entries
                extra["label"] = r"${}\sigma$ Error Ellipses".format(sig_bnd)
                need_sig_lbl = False
            e = Ellipse(
                xy=x[plt_inds, ii],
                width=w,
                height=h,
                angle=a,
                zorder=-10000,
                animated=True,
                **extra,
            )
            e.set_clip_box(ax.bbox)
            e.set_alpha(0.15)
            e.set_facecolor(color)
            ax.add_patch(e)
    scat.set_offsets(x[plt_inds[0:2], :].T)
    return scat

1302 

def plot_states(
    self,
    plt_inds,
    state_lbl="States",
    ttl=None,
    state_color=None,
    x_lbl=None,
    y_lbl=None,
    **kwargs,
):
    """Plots the best estimate for the states.

    This assumes that the states have been extracted. It's designed to plot
    two of the state variables (typically x/y position). The error ellipses
    are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`

    Keyword arguments are processed with
    :meth:`gncpy.plotting.init_plotting_opts`. This function
    implements

        - f_hndl
        - true_states
        - sig_bnd
        - rng
        - meas_inds
        - lgnd_loc
        - marker

    Parameters
    ----------
    plt_inds : list
        List of indices in the state vector to plot; the first two are used.
    state_lbl : string
        Value to appear in legend for the states. Only appears if the
        legend is shown
    ttl : string, optional
        Title for the plot, if None a default title is generated. The default
        is None.
    state_color : tuple, optional
        Color for the states and error ellipses. If None a random RGB color
        is drawn from ``rng``. The default is None.
    x_lbl : string, optional
        Label for the x-axis. The default is None, which gives "x-position".
    y_lbl : string, optional
        Label for the y-axis. The default is None, which gives "y-position".

    Returns
    -------
    Matplotlib figure
        Instance of the matplotlib figure used
    """
    opts = pltUtil.init_plotting_opts(**kwargs)
    f_hndl = opts["f_hndl"]
    true_states = opts["true_states"]
    sig_bnd = opts["sig_bnd"]
    rng = opts["rng"]
    meas_inds = opts["meas_inds"]
    lgnd_loc = opts["lgnd_loc"]
    marker = opts["marker"]
    if ttl is None:
        ttl = "State Estimates"
    if rng is None:
        rng = rnd.default_rng(1)
    if x_lbl is None:
        x_lbl = "x-position"
    if y_lbl is None:
        y_lbl = "y-position"
    plt_meas = meas_inds is not None
    show_sig = sig_bnd is not None and self.save_covs

    s_lst = deepcopy(self._states)
    x_dim = None

    if f_hndl is None:
        f_hndl = plt.figure()
        f_hndl.add_subplot(1, 1, 1)
    ax = f_hndl.axes[0]

    # get state dimension
    for states in s_lst:
        if len(states) > 0:
            x_dim = states[0].size
            break

    added_sig_lbl = False
    added_state_lbl = False
    # the three draws always happen (in r, b, g order) so colors are
    # reproducible for a given rng regardless of state_color
    r = rng.random()
    b = rng.random()
    g = rng.random()
    color = (r, g, b) if state_color is None else state_color

    i0, i1 = plt_inds[0], plt_inds[1]
    for tt, states in enumerate(s_lst):
        if len(states) == 0:
            continue
        x = np.concatenate(states, axis=1)
        if show_sig:
            for ii, cov in enumerate(self._covs[tt]):
                sig = np.array(
                    [[cov[i0, i0], cov[i0, i1]], [cov[i1, i0], cov[i1, i1]]],
                    dtype=float,
                )
                w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
                extra = {}
                if not added_sig_lbl:
                    extra["label"] = r"${}\sigma$ Error Ellipses".format(sig_bnd)
                    added_sig_lbl = True
                e = Ellipse(
                    xy=x[plt_inds, ii],
                    width=w,
                    height=h,
                    angle=a,
                    zorder=-10000,
                    **extra,
                )
                e.set_clip_box(ax.bbox)
                e.set_alpha(0.15)
                e.set_facecolor(color)
                ax.add_patch(e)
        extra = {}
        if not added_state_lbl:
            extra["label"] = state_lbl
            added_state_lbl = True
        ax.scatter(
            x[i0, :],
            x[i1, :],
            color=color,
            edgecolors=(0, 0, 0),
            marker=marker,
            **extra,
        )

    # if true states are available then plot them
    if true_states is not None:
        if x_dim is None:
            for states in true_states:
                if len(states) > 0:
                    x_dim = states[0].size
                    break
        max_true = max((len(s) for s in true_states), default=0)
        # nothing to draw when there are no true states at all
        if x_dim is not None and max_true > 0:
            x = np.nan * np.ones((x_dim, len(true_states), max_true))
            for tt, states in enumerate(true_states):
                for ii, state in enumerate(states):
                    x[:, [tt], ii] = state.copy()
            added_true_lbl = False
            for ii in range(max_true):
                extra = {}
                if not added_true_lbl:
                    extra["label"] = "True Trajectories"
                    added_true_lbl = True
                ax.plot(
                    x[i0, :, ii],
                    x[i1, :, ii],
                    color="k",
                    marker=".",
                    **extra,
                )

    if plt_meas:
        meas_x = []
        meas_y = []
        for meas_tt in self._meas_tab:
            meas_x.extend([m[meas_inds[0]].item() for m in meas_tt])
            meas_y.extend([m[meas_inds[1]].item() for m in meas_tt])
        ax.scatter(
            np.asarray(meas_x),
            np.asarray(meas_y),
            zorder=-1,
            alpha=0.35,
            color=(128 / 255, 128 / 255, 128 / 255),
            marker="^",
            edgecolors=(0, 0, 0),
            label="Measurements",
        )

    ax.grid(True)
    pltUtil.set_title_label(f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl)

    # act on this figure, not whatever pyplot considers "current"
    if lgnd_loc is not None:
        ax.legend(loc=lgnd_loc)
    f_hndl.tight_layout()

    return f_hndl

1521 

def animate_state_plot(
    self,
    plt_inds,
    state_lbl="States",
    state_color=None,
    interval=250,
    repeat=True,
    repeat_delay=1000,
    save_path=None,
    **kwargs,
):
    """Creates an animated plot of the states.

    Parameters
    ----------
    plt_inds : list
        indices of the state vector to plot.
    state_lbl : string, optional
        label for the states. The default is 'States'.
    state_color : tuple, optional
        3-tuple for rgb value. If None a random color is drawn from ``rng``.
        The default is None.
    interval : int, optional
        interval of the animation in ms. The default is 250.
    repeat : bool, optional
        flag indicating if the animation loops. The default is True.
    repeat_delay : int, optional
        delay between loops in ms. The default is 1000.
    save_path : string, optional
        file path and name to save the gif, does not save if not given.
        The default is None.
    **kwargs : dict, optional
        Standard plotting options for
        :meth:`gncpy.plotting.init_plotting_opts`. This function
        implements

            - f_hndl (a new figure is created when None)
            - sig_bnd
            - rng (``numpy.random.default_rng(1)`` when None)
            - meas_inds
            - lgnd_loc
            - marker

    Returns
    -------
    anim :
        handle to the animation.
    """
    opts = pltUtil.init_plotting_opts(**kwargs)
    f_hndl = opts["f_hndl"]
    sig_bnd = opts["sig_bnd"]
    rng = opts["rng"]
    meas_inds = opts["meas_inds"]
    lgnd_loc = opts["lgnd_loc"]
    marker = opts["marker"]

    # same fallbacks as the static state plot so missing options do not crash
    if rng is None:
        rng = rnd.default_rng(1)
    if f_hndl is None:
        f_hndl = plt.figure()
        f_hndl.add_subplot(1, 1, 1)
    ax = f_hndl.axes[0]

    plt_meas = meas_inds is not None
    show_sig = sig_bnd is not None and self.save_covs

    ax.grid(True)
    pltUtil.set_title_label(
        f_hndl,
        0,
        opts,
        ttl="State Estimates",
        x_lbl="x-position",
        y_lbl="y-position",
    )

    fr_number = ax.annotate(
        "0",
        (0, 1),
        xycoords="axes fraction",
        xytext=(10, -10),
        textcoords="offset points",
        ha="left",
        va="top",
        animated=False,
    )

    added_sig_lbl = False
    added_state_lbl = False
    # three draws in r, b, g order keep colors reproducible for a given rng
    r = rng.random()
    b = rng.random()
    g = rng.random()
    s_color = (r, g, b) if state_color is None else state_color
    state_scat = ax.scatter(
        [], [], color=s_color, edgecolors=(0, 0, 0), marker=marker, label=state_lbl
    )

    meas_scat = None
    if plt_meas:
        meas_scat = ax.scatter(
            [],
            [],
            zorder=-1,
            alpha=0.35,
            color=(128 / 255, 128 / 255, 128 / 255),
            marker="^",
            edgecolors="k",
            label="Measurements",
        )

    def update(tt, *fargs):
        nonlocal added_sig_lbl
        nonlocal added_state_lbl
        nonlocal state_scat

        fr_number.set_text("Timestep: {j}".format(j=tt))

        state_scat = self.__ani_state_plotting(
            f_hndl,
            tt,
            self._states[tt],
            show_sig,
            plt_inds,
            sig_bnd,
            s_color,
            marker,
            state_lbl,
            added_sig_lbl,
            added_state_lbl,
            scat=state_scat,
        )
        added_sig_lbl = True
        added_state_lbl = True

        if plt_meas:
            meas_tt = self._meas_tab[tt]
            meas_x = np.asarray([m[meas_inds[0]].item() for m in meas_tt])
            meas_y = np.asarray([m[meas_inds[1]].item() for m in meas_tt])
            meas_scat.set_offsets(np.array([meas_x, meas_y]).T)

    anim = animation.FuncAnimation(
        f_hndl,
        update,
        frames=len(self._states),
        interval=interval,
        repeat_delay=repeat_delay,
        repeat=repeat,
    )

    # attach the legend to this figure rather than pyplot's current axes
    if lgnd_loc is not None:
        ax.legend(loc=lgnd_loc)
    if save_path is not None:
        writer = animation.PillowWriter(fps=30)
        anim.save(save_path, writer=writer)
    return anim

1698 

1699 

1700class CardinalizedPHD(ProbabilityHypothesisDensity): 

1701 """Implements the Cardinalized Probability Hypothesis Density filter. 

1702 

1703 The kwargs in the constructor are passed through to the parent constructor. 

1704 

1705 Notes 

1706 ----- 

1707 The filter implementation is based on 

1708 :cite:`Vo2006_TheCardinalizedProbabilityHypothesisDensityFilterforLinearGaussianMultiTargetModels` 

1709 and :cite:`Vo2007_AnalyticImplementationsoftheCardinalizedProbabilityHypothesisDensityFilter`. 

1710 

1711 Attributes 

1712 ---------- 

1713 agents_per_state : list, optional 

1714 number of agents per state. The default is []. 

1715 """ 

1716 

def __init__(self, max_expected_card=10, **kwargs):
    """Initialize the CPHD filter.

    Parameters
    ----------
    max_expected_card : int, optional
        Largest cardinality represented by the cardinality distribution.
        The default is 10.
    **kwargs : dict
        Passed through to the parent constructor.
    """
    self.agents_per_state = []
    # the property setter stores the value and allocates a cardinality
    # distribution with all of its mass on zero targets
    self.max_expected_card = max_expected_card
    # (argmax, std) of the cardinality distribution, appended by predict
    self._card_time_hist = []
    self._n_states_per_time = []

    super().__init__(**kwargs)

1729 

@property
def max_expected_card(self):
    """Maximum expected cardinality. The default is 10."""
    return self._max_expected_card

@max_expected_card.setter
def max_expected_card(self, x):
    # resizing resets the distribution to "certainly zero targets"
    dist = np.zeros(x + 1)
    dist[0] = 1
    self._card_dist = dist
    self._max_expected_card = x

1740 

@property
def cardinality(self):
    """Cardinality of the RFS: the mode of the cardinality distribution."""
    card_dist = self._card_dist
    return np.argmax(card_dist)

1745 

def predict(self, timestep, **kwargs):
    """Prediction step of the CPHD filter.

    This predicts new hypotheses and propagates them to the next time step
    (via the parent class). It also predicts the cardinality distribution.

    The survival step thins the previous distribution binomially:
    ``C(i, j) p_S^j (1 - p_S)^(i-j)``. The result is then convolved with a
    Poisson birth cardinality whose rate is the total weight of the first
    birth term.

    Parameters
    ----------
    timestep : float
        Current timestep.
    **kwargs : dict, optional
        See :meth:`carbs.swarm_estimator.tracker.ProbabilityHypothesisDensity.predict`
        for the available arguments.

    Returns
    -------
    None.
    """
    super().predict(timestep, **kwargs)

    n_card = self.max_expected_card + 1
    # log_fact[k] = log(k!) for k = 0..max_expected_card
    log_fact = np.zeros(n_card)
    log_fact[1:] = np.cumsum(np.log(np.arange(1, n_card)))

    # survival: binomial thinning, with C(i, j) = i! / (j! (i - j)!).
    # Powers (not logs) of the probabilities keep 0**0 == 1 (no NaN when
    # p_S is 0 or 1).
    survive_cdn_predict = np.zeros(n_card)
    for j in range(n_card):
        terms = np.zeros(n_card)
        for i in range(j, n_card):
            log_binom = log_fact[i] - log_fact[j] - log_fact[i - j]
            terms[i] = (
                np.exp(log_binom)
                * self.prob_survive**j
                * self.prob_death ** (i - j)
                * self._card_dist[i]
            )
        survive_cdn_predict[j] = np.sum(terms)

    if len(self.birth_terms) != 1:
        warnings.warn("Only using the first birth term in cardinality update")
    # NOTE: assumes 1 GM for the birth model
    birth = np.sum(np.array([w for w in self.birth_terms[0].weights]))

    # births: Poisson pmf exp(-b) b^k / k!, convolved with the survivors
    cdn_predict = np.zeros(n_card)
    for n in range(n_card):
        terms = np.zeros(n_card)
        for j in range(n + 1):
            k = n - j
            pois = np.exp(-birth - log_fact[k]) * birth**k
            terms[j] = pois * survive_cdn_predict[j]
        cdn_predict[n] = np.sum(terms)
    self._card_dist = (cdn_predict / np.sum(cdn_predict)).copy()

    self._card_time_hist.append(
        (np.argmax(self._card_dist).item(), np.std(self._card_dist))
    )

1803 

def correct(
    self, timestep, meas_in, meas_mat_args=None, est_meas_args=None, filt_args=None
):
    """Correction step of the CPHD filter.

    This corrects the hypotheses based on the measurements and gates the
    measurements according to the class settings. The density correction
    also updates the cardinality distribution.

    Parameters
    ----------
    timestep : float
        Current timestep.
    meas_in : list
        2d numpy arrays representing a measurement. Not modified.
    meas_mat_args : dict, optional
        Keyword arguments to pass to the inner filters get measurement
        matrix function. Only used if gating is on. The default is None
        (treated as an empty dict).
    est_meas_args : dict, optional
        Keyword arguments to pass to the inner filters estimate
        measurements function. Only used if gating is on. The default is
        None (treated as an empty dict).
    filt_args : dict, optional
        Keyword arguments to pass to the inner filters correct function.
        The default is None (treated as an empty dict).

    Returns
    -------
    None.
    """
    # avoid shared mutable default arguments
    if meas_mat_args is None:
        meas_mat_args = {}
    if est_meas_args is None:
        est_meas_args = {}
    if filt_args is None:
        filt_args = {}

    meas = deepcopy(meas_in)

    if self.gating_on:
        meas = self._gate_meas(
            meas,
            self._gaussMix.means,
            self._gaussMix.covariances,
            meas_mat_args,
            est_meas_args,
        )
    self._meas_tab.append(meas)

    gmix = deepcopy(self._gaussMix)  # predicted gm

    self._gaussMix = self._correct_prob_density(timestep, meas, gmix, filt_args)

1850 

def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
    """Helper function for correction step.

    Runs the inner filter correction for every (measurement, component)
    pair. It then computes the CPHD Upsilon terms from elementary symmetric
    functions, assuming Poisson clutter with rate ``self.clutter_rate`` and
    spatial density ``self.clutter_den``. Finally it builds the posterior
    mixture and updates the cardinality distribution in place.

    Parameters
    ----------
    timestep : float
        Current timestep.
    meas : list
        2d numpy arrays of each measurement.
    probDensity : :py:class:`serums.models.GaussianMixture`
        Predicted mixture; assumed to have at least one component.
    filt_args : dict
        Keyword arguments for the inner filter correct function.

    Returns
    -------
    gmix : :py:class:`serums.models.GaussianMixture`
        Missed-detection components followed by one component per
        (measurement, predicted component) pair.
    """
    w_pred = np.asarray(probDensity.weights, dtype=float).reshape((-1, 1))
    xdim = len(probDensity.means[0])

    plen = len(probDensity.means)
    zlen = len(meas)

    qz_temp = np.zeros((plen, zlen))
    mean_temp = np.zeros((zlen, xdim, plen))
    cov_temp = np.zeros((zlen, plen, xdim, xdim))

    for z_ind in range(zlen):
        for p_ind in range(plen):
            self.filter.cov = probDensity.covariances[p_ind]
            state = probDensity.means[p_ind]

            (mean, qz) = self.filter.correct(
                timestep, meas[z_ind], state, **filt_args
            )
            qz_temp[p_ind, z_ind] = qz
            mean_temp[z_ind, :, p_ind] = np.ndarray.flatten(mean)
            cov_temp[z_ind, p_ind, :, :] = self.filter.cov.copy()

    # Xi values: p_D <w, q(z)> / clutter density, one per measurement
    pdc = self.prob_detection / self.clutter_den
    xivals = pdc * (w_pred.T @ qz_temp).ravel()

    # ravel so both row and column outputs of get_elem_sym_fnc are handled
    esf_E = np.asarray(get_elem_sym_fnc(xivals), dtype=float).ravel()
    esf_D = np.zeros((zlen, zlen))
    for j in range(zlen):
        esf_D[:, j] = np.asarray(
            get_elem_sym_fnc(np.delete(xivals, j)), dtype=float
        ).ravel()

    n_card = self.max_expected_card + 1
    lam = self.clutter_rate
    # the Upsilon terms use the missed-detection probability (1 - p_D)
    pmd = self.prob_miss_detection
    log_fact = np.zeros(n_card)
    log_fact[1:] = np.cumsum(np.log(np.arange(1, n_card)))
    log_tot_w = np.log(np.sum(w_pred))

    def _ups_coef(nn, jj, u, n_meas):
        """Upsilon^u summand (without e_j): e^-lam lam^(n_meas-jj) P^nn_(jj+u)
        (1-p_D)^(nn-jj-u) / W^(jj+u); zero when nn < jj + u."""
        k = jj + u
        if nn < k:
            return 0.0
        log_val = -lam + log_fact[nn] - log_fact[nn - k] - k * log_tot_w
        # powers keep 0**0 == 1 (no NaN for p_D == 1 or zero clutter rate)
        return np.exp(log_val) * lam ** (n_meas - jj) * pmd ** (nn - k)

    ups0_E = np.zeros(n_card)
    ups1_E = np.zeros(n_card)
    ups1_D = np.zeros((n_card, zlen))
    for nn in range(n_card):
        for jj in range(min(zlen, nn) + 1):
            ups0_E[nn] += _ups_coef(nn, jj, 0, zlen) * esf_E[jj]
            ups1_E[nn] += _ups_coef(nn, jj, 1, zlen) * esf_E[jj]
        # measurement-deleted sets have zlen - 1 elements (empty when zlen == 0)
        for jj in range(min(zlen - 1, nn) + 1):
            ups1_D[nn, :] += _ups_coef(nn, jj, 1, zlen - 1) * esf_D[jj, :]

    denom = ups0_E @ self._card_dist

    # missed-detection terms
    gmix = deepcopy(probDensity)
    miss_scale = (ups1_E @ self._card_dist) / denom * pmd
    gmix.weights = [(miss_scale * w).item() for w in w_pred.ravel()]

    # detection terms
    for ee in range(zlen):
        wt_1 = (ups1_D[:, ee] @ self._card_dist) / denom
        w_temp = (
            wt_1
            * self.prob_detection
            * qz_temp[:, ee]
            / self.clutter_den
            * w_pred.ravel()
        )
        for ww in range(plen):
            gmix.add_components(
                mean_temp[ee, :, ww].reshape((xdim, 1)),
                cov_temp[ee, ww, :, :],
                w_temp[ww].item(),
            )

    cdn_update = ups0_E * self._card_dist
    self._card_dist = cdn_update / np.sum(cdn_update)
    # assumes predict is called before correct
    self._card_time_hist[-1] = (
        np.argmax(self._card_dist).item(),
        np.std(self._card_dist),
    )

    return gmix

1968 

def extract_states(self, allow_multiple=True):
    """Extracts the best state estimates.

    Components are visited from largest to smallest weight until the
    number of extracted agents reaches the estimated cardinality. Each
    visited component contributes its mean (and its covariance when
    ``save_covs`` is set). The number of agents it accounts for is recorded
    in ``agents_per_state``. It should be called once per time step after
    the correction function.

    Parameters
    ----------
    allow_multiple : bool
        Flag indicating if extraction is allowed to map a single Gaussian
        to multiple states (``ceil(weight)`` agents, clipped to the
        cardinality). The default is True.
    """
    order = np.argsort(self._gaussMix.weights)[::-1]
    s_lst = []
    c_lst = []
    self.agents_per_state = []
    n_visited = 0
    tot_agents = 0
    for raw_idx in order:
        if tot_agents >= self.cardinality:
            break
        idx = int(raw_idx)

        if not allow_multiple:
            n_agents = 1
        else:
            n_agents = np.ceil(self._gaussMix.weights[idx])
            if n_agents <= 0:
                warnings.warn(
                    "Gaussian weights are 0 before reaching cardinality",
                    RuntimeWarning,
                )
                break
            if tot_agents + n_agents > self.cardinality:
                n_agents = self.cardinality - tot_agents

        tot_agents += n_agents
        self.agents_per_state.append(n_agents)

        s_lst.append(self._gaussMix.means[idx])
        if self.save_covs:
            c_lst.append(self._gaussMix.covariances[idx])
        n_visited += 1

    if tot_agents != self.cardinality:
        warnings.warn("Failed to meet estimated cardinality when extracting!")
    self._states.append(s_lst)
    if self.save_covs:
        self._covs.append(c_lst)
    if self.debug_plots:
        self._n_states_per_time.append(n_visited)

2014 

2015 def plot_card_dist(self, **kwargs): 

2016 """Plots the current cardinality distribution. 

2017 

2018 This assumes that the cardinality distribution has been calculated by 

2019 the class. 

2020 

2021 Parameters 

2022 ---------- 

2023 **kwargs : dict, optional 

2024 Keyword arguments are processed with 

2025 :meth:`gncpy.plotting.init_plotting_opts`. This function 

2026 implements 

2027 

2028 - f_hndl 

2029 

2030 Returns 

2031 ------- 

2032 Matplotlib figure 

2033 Instance of the matplotlib figure used 

2034 

2035 Raises 

2036 ------ 

2037 RuntimeWarning 

2038 If the cardinality distribution is empty. 

2039 """ 

2040 opts = pltUtil.init_plotting_opts(**kwargs) 

2041 f_hndl = opts["f_hndl"] 

2042 

2043 if len(self._card_dist) == 0: 

2044 raise RuntimeWarning("Empty Cardinality") 

2045 return f_hndl 

2046 if f_hndl is None: 

2047 f_hndl = plt.figure() 

2048 f_hndl.add_subplot(1, 1, 1) 

2049 x_vals = np.arange(0, len(self._card_dist)) 

2050 f_hndl.axes[0].bar(x_vals, self._card_dist) 

2051 

2052 pltUtil.set_title_label( 

2053 f_hndl, 

2054 0, 

2055 opts, 

2056 ttl="Cardinality Distribution", 

2057 x_lbl="Cardinality", 

2058 y_lbl="Probability", 

2059 ) 

2060 plt.tight_layout() 

2061 

2062 return f_hndl 

2063 

2064 def plot_card_history( 

2065 self, ttl=None, true_card=None, time_units="index", time=None, **kwargs 

2066 ): 

2067 """Plots the current cardinality time history. 

2068 

2069 This assumes that the cardinality distribution has been calculated by 

2070 the class. 

2071 

2072 Parameters 

2073 ---------- 

2074 ttl : string 

2075 String for the title, if None a default is created. The default is 

2076 None. 

2077 true_card : array like 

2078 List of the true cardinality at each time 

2079 time_units : string, optional 

2080 Text representing the units of time in the plot. The default is 

2081 'index'. 

2082 time : numpy array, optional 

2083 Vector to use for the x-axis of the plot. If none is given then 

2084 vector indices are used. The default is None. 

2085 **kwargs : dict, optional 

2086 Keyword arguments are processed with 

2087 :meth:`gncpy.plotting.init_plotting_opts`. This function 

2088 implements 

2089 

2090 - f_hndl 

2091 - sig_bnd 

2092 - time_vec 

2093 - lgnd_loc 

2094 

2095 Returns 

2096 ------- 

2097 Matplotlib figure 

2098 Instance of the matplotlib figure used 

2099 """ 

2100 opts = pltUtil.init_plotting_opts(**kwargs) 

2101 f_hndl = opts["f_hndl"] 

2102 sig_bnd = opts["sig_bnd"] 

2103 # time_vec = opts["time_vec"] 

2104 lgnd_loc = opts["lgnd_loc"] 

2105 if ttl is None: 

2106 ttl = "Cardinality History" 

2107 if len(self._card_time_hist) == 0: 

2108 raise RuntimeWarning("Empty Cardinality") 

2109 return f_hndl 

2110 if sig_bnd is not None: 

2111 stds = [sig_bnd * x[1] for x in self._card_time_hist] 

2112 card = [x[0] for x in self._card_time_hist] 

2113 

2114 if f_hndl is None: 

2115 f_hndl = plt.figure() 

2116 f_hndl.add_subplot(1, 1, 1) 

2117 if time is None: 

2118 x_vals = [ii for ii in range(0, len(card))] 

2119 else: 

2120 x_vals = time 

2121 f_hndl.axes[0].step( 

2122 x_vals, 

2123 card, 

2124 label="Cardinality", 

2125 color="k", 

2126 linestyle="-", 

2127 where="post", 

2128 ) 

2129 

2130 if true_card is not None: 

2131 if len(true_card) != len(x_vals): 

2132 c_len = len(true_card) 

2133 t_len = len(x_vals) 

2134 msg = "True Cardinality vector length ({})".format( 

2135 c_len 

2136 ) + " does not match time vector length ({})".format(t_len) 

2137 warnings.warn(msg) 

2138 else: 

2139 f_hndl.axes[0].step( 

2140 x_vals, 

2141 true_card, 

2142 color="g", 

2143 label="True Cardinality", 

2144 linestyle="--", 

2145 where="post", 

2146 ) 

2147 if sig_bnd is not None: 

2148 lbl = r"${}\sigma$ Bound".format(sig_bnd) 

2149 f_hndl.axes[0].plot( 

2150 x_vals, 

2151 [x + s for (x, s) in zip(card, stds)], 

2152 linestyle="--", 

2153 color="r", 

2154 label=lbl, 

2155 ) 

2156 f_hndl.axes[0].plot( 

2157 x_vals, [x - s for (x, s) in zip(card, stds)], linestyle="--", color="r" 

2158 ) 

2159 f_hndl.axes[0].ticklabel_format(useOffset=False) 

2160 

2161 if lgnd_loc is not None: 

2162 plt.legend(loc=lgnd_loc) 

2163 plt.grid(True) 

2164 pltUtil.set_title_label( 

2165 f_hndl, 

2166 0, 

2167 opts, 

2168 ttl=ttl, 

2169 x_lbl="Time ({})".format(time_units), 

2170 y_lbl="Cardinality", 

2171 ) 

2172 

2173 plt.tight_layout() 

2174 

2175 return f_hndl 

2176 

2177 def plot_number_states_per_time(self, **kwargs): 

2178 """Plots the number of states per timestep. 

2179 

2180 This is a debug plot for if there are 0 weights in the GM but the 

2181 cardinality is not reached. Debug plots must be turned on prior to 

2182 running the filter. 

2183 

2184 

2185 Parameters 

2186 ---------- 

2187 **kwargs : dict, optional 

2188 Keyword arguments are processed with 

2189 :meth:`gncpy.plotting.init_plotting_opts`. This function 

2190 implements 

2191 

2192 - f_hndl 

2193 - lgnd_loc 

2194 

2195 Returns 

2196 ------- 

2197 f_hndl : matplotlib figure 

2198 handle to the current figure. 

2199 

2200 """ 

2201 opts = pltUtil.init_plotting_opts(**kwargs) 

2202 f_hndl = opts["f_hndl"] 

2203 lgnd_loc = opts["lgnd_loc"] 

2204 

2205 if not self.debug_plots: 

2206 msg = "Debug plots turned off" 

2207 warnings.warn(msg) 

2208 return f_hndl 

2209 if f_hndl is None: 

2210 f_hndl = plt.figure() 

2211 f_hndl.add_subplot(1, 1, 1) 

2212 if lgnd_loc is not None: 

2213 plt.legend(loc=lgnd_loc) 

2214 x_vals = [ii for ii in range(0, len(self._n_states_per_time))] 

2215 

2216 f_hndl.axes[0].plot(x_vals, self._n_states_per_time) 

2217 plt.grid(True) 

2218 pltUtil.set_title_label( 

2219 f_hndl, 

2220 0, 

2221 opts, 

2222 ttl="Gaussians per Timestep", 

2223 x_lbl="Time", 

2224 y_lbl="Number of Gaussians", 

2225 ) 

2226 

2227 return f_hndl 

2228 

2229 

2230class _IMMPHDBase: 

2231 def __init__( 

2232 self, 

2233 filter_lst=None, 

2234 model_trans=None, 

2235 init_weights=None, 

2236 init_means=None, 

2237 init_covs=None, 

2238 **kwargs, 

2239 ): 

2240 super().__init__(**kwargs) 

2241 if not isinstance(self.filter, gfilts.InteractingMultipleModel): 

2242 raise TypeError("Filter must be InteractingMultipleModel") 

2243 if filter_lst is not None and model_trans is not None: 

2244 self.filter.initialize_filter(filter_lst, model_trans) 

2245 if init_means is not None and init_covs is not None: 

2246 self.filter.initialize_states( 

2247 init_means, init_covs, init_weights=init_weights 

2248 ) 

2249 self._filter_states = [] 

2250 

2251 def _predict_prob_density(self, timestep, probDensity, filt_args): 

2252 """Predicts the probability density. 

2253 

2254 Loops over all elements in a probability distribution and performs 

2255 the filter prediction. 

2256 

2257 Parameters 

2258 ---------- 

2259 timestep: float 

2260 current timestep 

2261 probDensity : :class:`serums.models.GaussianMixture` 

2262 Probability density to perform prediction on. 

2263 filt_args : dict 

2264 Passed directly to the inner filter. 

2265 

2266 Returns 

2267 ------- 

2268 gm : :class:`serums.models.GaussianMixture` 

2269 predicted Gaussian mixture. 

2270 

2271 """ 

2272 weights = [self.prob_survive * x for x in probDensity.weights.copy()] 

2273 n_terms = len(probDensity.means) 

2274 covariances = [None] * n_terms 

2275 means = [None] * n_terms 

2276 for ii, (m, P) in enumerate(zip(probDensity.means, probDensity.covariances)): 

2277 self.filter.load_filter_state(self._filter_states[ii]) 

2278 n_mean = self.filter.predict(timestep, **filt_args).reshape((m.shape[0], 1)) 

2279 covariances[ii] = self.filter.cov.copy() 

2280 means[ii] = n_mean 

2281 self._filter_states[ii] = self.filter.save_filter_state() 

2282 return smodels.GaussianMixture( 

2283 means=means, covariances=covariances, weights=weights 

2284 ) 

2285 

2286 def predict(self, timestep, filt_args={}): 

2287 super().predict(timestep, filt_args=filt_args) 

2288 for gm in self.birth_terms: 

2289 for m, c in zip(gm.means, gm.covariances): 

2290 # if len(m) != 1 or len(c) != 1: 

2291 # raise ValueError("only one mean and covariance per filter is supported") 

2292 init_means = [] 

2293 init_covs = [] 

2294 for ii in range(0, len(self.filter.in_filt_list)): 

2295 init_means.append(m) 

2296 init_covs.append(c) 

2297 self.filter.initialize_states(init_means, init_covs) 

2298 self._filter_states.append(self.filter.save_filter_state()) 

2299 # new imm filter state to represent new means 

2300 

2301 def _prune(self): 

2302 inds = super()._prune() 

2303 inds = sorted(inds, reverse=True) 

2304 for ind in inds: 

2305 if ind < len(self._filter_states): 

2306 self._filter_states.pop(ind) 

2307 else: 

2308 raise RuntimeError("Pruned index is greater than filter state length") 

2309 

2310 # remove pruned indices from filter state indicies 

2311 

2312 def _merge(self): 

2313 """Merges nearby hypotheses.""" 

2314 loop_inds = set(range(0, len(self._gaussMix.means))) 

2315 

2316 w_lst = [] 

2317 m_lst = [] 

2318 p_lst = [] 

2319 fs_lst = [] 

2320 while len(loop_inds) > 0: 

2321 jj = int(np.argmax(self._gaussMix.weights)) 

2322 comp_inds = [] 

2323 inv_cov = la.inv(self._gaussMix.covariances[jj]) 

2324 for ii in loop_inds: 

2325 diff = self._gaussMix.means[ii] - self._gaussMix.means[jj] 

2326 val = diff.T @ inv_cov @ diff 

2327 if val <= self.merge_threshold: 

2328 comp_inds.append(ii) 

2329 w_new = sum([self._gaussMix.weights[ii] for ii in comp_inds]) 

2330 m_new = ( 

2331 sum( 

2332 [ 

2333 self._gaussMix.weights[ii] * self._gaussMix.means[ii] 

2334 for ii in comp_inds 

2335 ] 

2336 ) 

2337 / w_new 

2338 ) 

2339 p_new = ( 

2340 sum( 

2341 [ 

2342 self._gaussMix.weights[ii] * self._gaussMix.covariances[ii] 

2343 for ii in comp_inds 

2344 ] 

2345 ) 

2346 / w_new 

2347 ) 

2348 

2349 new_mean_list = [] 

2350 new_cov_list = [] 

2351 new_filt_weights = [] 

2352 for kk in range(0, len(self.filter.in_filt_list)): 

2353 # ml_new = ( sum([self._gaussMix.weights[ii] * self._filter_states[ii]["mean_list"][kk]])) 

2354 ml_new = 0 

2355 cl_new = 0 

2356 fw_new = 0 

2357 

2358 for ii in comp_inds: 

2359 ml_new = ( 

2360 ml_new 

2361 + self._gaussMix.weights[ii] 

2362 * self._filter_states[ii]["mean_list"][kk] 

2363 ) 

2364 cl_new = ( 

2365 cl_new 

2366 + self._gaussMix.weights[ii] 

2367 * self._filter_states[ii]["cov_list"][kk] 

2368 ) 

2369 fw_new = ( 

2370 fw_new 

2371 + self._gaussMix.weights[ii] 

2372 * self._filter_states[ii]["filt_weights"][kk] 

2373 ) 

2374 new_mean_list.append(ml_new / w_new) 

2375 new_cov_list.append(cl_new / w_new) 

2376 new_filt_weights.append(fw_new / w_new) 

2377 self.filter.initialize_states( 

2378 new_mean_list, new_cov_list, init_weights=new_filt_weights 

2379 ) 

2380 fs_lst.append(self.filter.save_filter_state()) 

2381 w_lst.append(w_new) 

2382 m_lst.append(m_new) 

2383 p_lst.append(p_new) 

2384 

2385 loop_inds = loop_inds.symmetric_difference(comp_inds) 

2386 for ii in comp_inds: 

2387 self._gaussMix.weights[ii] = -1 

2388 self._filter_states = fs_lst 

2389 self._gaussMix = smodels.GaussianMixture( 

2390 means=m_lst, covariances=p_lst, weights=w_lst 

2391 ) 

2392 # probably need to overwrite, do this later 

2393 

2394 def _cap(self): 

2395 inds = super()._cap() 

2396 inds = sorted(inds, reverse=True) 

2397 for ind in inds: 

2398 if ind < len(self._filter_states): 

2399 self._filter_states.pop(ind) 

2400 else: 

2401 raise RuntimeError("Capped index is greater than filter state length") 

2402 # remove capped indices from filter state indicies 

2403 

2404 

2405class IMMProbabilityHypothesisDensity(_IMMPHDBase, ProbabilityHypothesisDensity): 

2406 def __init__(self, **kwargs): 

2407 super().__init__(**kwargs) 

2408 # TODO: init_filter_states_for_imm 

2409 

2410 def _correct_prob_density(self, timestep, meas, probDensity, filt_args): 

2411 """Corrects the probability densities. 

2412 

2413 Loops over all elements in a probability distribution and preforms 

2414 the filter correction. 

2415 

2416 Parameters 

2417 ---------- 

2418 meas : list 

2419 2d numpy arrays of each measurement. 

2420 probDensity : :py:class:`serums.models.GaussianMixture` 

2421 probability density to run correction on. 

2422 filt_args : dict 

2423 arguements to pass to the inner filter correct function. 

2424 

2425 Returns 

2426 ------- 

2427 gm : :py:class:`serums.models.GaussianMixture` 

2428 corrected probability density. 

2429 

2430 """ 

2431 means = [] 

2432 covariances = [] 

2433 weights = [] 

2434 # corr_filt_weights = np.zeros(np.shape(self.filter.filt_weights)) 

2435 new_filter_states = [] 

2436 det_weights = [self.prob_detection * x for x in probDensity.weights] 

2437 for z in meas: 

2438 w_lst = [] 

2439 for jj in range(0, len(probDensity.means)): 

2440 self.filter.load_filter_state(self._filter_states[jj]) 

2441 (mean, qz) = self.filter.correct(timestep, z, **filt_args) 

2442 cov = self.filter.cov 

2443 w = qz * det_weights[jj] 

2444 means.append(mean.reshape((5, 1))) 

2445 covariances.append(cov) 

2446 w_lst.append(w) 

2447 new_filter_states.append(self.filter.save_filter_state()) 

2448 weights.extend( 

2449 [x / (self.clutter_rate * self.clutter_den + sum(w_lst)) for x in w_lst] 

2450 ) 

2451 self._filter_states = new_filter_states 

2452 return smodels.GaussianMixture( 

2453 means=means, covariances=covariances, weights=weights 

2454 ) 

2455 

2456 def correct( 

2457 self, timestep, meas_in, meas_mat_args={}, est_meas_args={}, filt_args={} 

2458 ): 

2459 meas = deepcopy(meas_in) 

2460 

2461 if self.gating_on: 

2462 meas = self._gate_meas( 

2463 meas, 

2464 self._gaussMix.means, 

2465 self._gaussMix.covariances, 

2466 meas_mat_args, 

2467 est_meas_args, 

2468 ) 

2469 self._meas_tab.append(meas) 

2470 

2471 gmix = deepcopy(self._gaussMix) 

2472 gmix.weights = [self.prob_miss_detection * x for x in gmix.weights] 

2473 saved_filt_weights = [] 

2474 for filt_state in self._filter_states: 

2475 saved_filt_weights.append(filt_state["filt_weights"].copy()) 

2476 gm = self._correct_prob_density(timestep, meas, self._gaussMix, filt_args) 

2477 gm.add_components(gmix.means, gmix.covariances, gmix.weights) 

2478 

2479 for jj, (m, c) in enumerate(zip(gmix.means, gmix.covariances)): 

2480 # for m, c in zip(m_list, c_list): 

2481 m_list = [] 

2482 c_list = [] 

2483 for ii in range(0, len(self.filter.in_filt_list)): 

2484 m_list.append(m) 

2485 c_list.append(c) 

2486 self.filter.initialize_states( 

2487 m_list, c_list, init_weights=saved_filt_weights[jj] 

2488 ) 

2489 self._filter_states.append(self.filter.save_filter_state()) 

2490 

2491 self._gaussMix = gm 

2492 

2493 

2494class IMMCardinalizedPHD(_IMMPHDBase, CardinalizedPHD): 

2495 def __init__(self, **kwargs): 

2496 super().__init__(**kwargs) 

2497 

2498 def _correct_prob_density(self, timestep, meas, probDensity, filt_args): 

2499 """Helper function for correction step. 

2500 

2501 Loops over all elements in a probability distribution and preforms 

2502 the filter correction. 

2503 """ 

2504 w_pred = np.zeros((len(probDensity.weights), 1)) 

2505 for i in range(0, len(probDensity.weights)): 

2506 w_pred[i] = probDensity.weights[i] 

2507 xdim = len(probDensity.means[0]) 

2508 

2509 plen = len(probDensity.means) 

2510 zlen = len(meas) 

2511 

2512 qz_temp = np.zeros((plen, zlen)) 

2513 mean_temp = np.zeros((zlen, xdim, plen)) 

2514 cov_temp = np.zeros((zlen, plen, xdim, xdim)) 

2515 saved_filt_weights = [] 

2516 for filt_state in self._filter_states: 

2517 saved_filt_weights.append(filt_state["filt_weights"].copy()) 

2518 new_filter_states = [] 

2519 

2520 for z_ind in range(0, zlen): 

2521 for p_ind in range(0, plen): 

2522 state = probDensity.means[p_ind] 

2523 self.filter.load_filter_state(self._filter_states[p_ind]) 

2524 # self.filter.initialize_states(probDensity.means[p_ind], probDensity.covariances[p_ind], 

2525 # init_weights=self.weight_list[p_ind]) 

2526 

2527 (mean, qz) = self.filter.correct(timestep, meas[z_ind], **filt_args) 

2528 qz_temp[p_ind, z_ind] = qz 

2529 mean_temp[z_ind, :, p_ind] = np.ndarray.flatten(mean) 

2530 cov_temp[z_ind, p_ind, :, :] = self.filter.cov.copy() 

2531 new_filter_states.append(self.filter.save_filter_state()) 

2532 # self._filter_states[p_ind] = self.filter.save_filter_state() 

2533 xivals = np.zeros(zlen) 

2534 pdc = self.prob_detection / self.clutter_den 

2535 for e in range(0, zlen): 

2536 xivals[e] = pdc * np.dot(w_pred.T, qz_temp[:, [e]]) 

2537 esfvals_E = get_elem_sym_fnc(xivals) 

2538 esfvals_D = np.zeros((zlen, zlen)) 

2539 

2540 for j in range(0, zlen): 

2541 xi_temp = xivals.copy() 

2542 xi_temp = np.delete(xi_temp, j) 

2543 esfvals_D[:, [j]] = get_elem_sym_fnc(xi_temp) 

2544 ups0_E = np.zeros((self.max_expected_card + 1, 1)) 

2545 ups1_E = np.zeros((self.max_expected_card + 1, 1)) 

2546 ups1_D = np.zeros((self.max_expected_card + 1, zlen)) 

2547 

2548 tot_w_pred = sum(w_pred) 

2549 for nn in range(0, self.max_expected_card + 1): 

2550 terms0_E = np.zeros((min(zlen, nn) + 1)) 

2551 for jj in range(0, min(zlen, nn) + 1): 

2552 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate) 

2553 t2 = sum([np.log(x) for x in range(1, nn + 1)]) 

2554 t3 = -1 * sum([np.log(x) for x in range(1, nn - jj + 1)]) 

2555 t4 = (nn - jj) * np.log(self.prob_death) 

2556 t5 = -jj * np.log(tot_w_pred) 

2557 terms0_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj] 

2558 ups0_E[nn] = np.sum(terms0_E) 

2559 

2560 terms1_E = np.zeros((min(zlen, nn) + 1)) 

2561 for jj in range(0, min(zlen, nn) + 1): 

2562 if nn >= jj + 1: 

2563 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate) 

2564 t2 = sum([np.log(x) for x in range(1, nn + 1)]) 

2565 t3 = -1 * sum([np.log(x) for x in range(1, nn - (jj + 1) + 1)]) 

2566 t4 = (nn - (jj + 1)) * np.log(self.prob_death) 

2567 t5 = -(jj + 1) * np.log(tot_w_pred) 

2568 terms1_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj] 

2569 ups1_E[nn] = np.sum(terms1_E) 

2570 

2571 if zlen != 0: 

2572 terms1_D = np.zeros((min(zlen - 1, nn) + 1, zlen)) 

2573 for ell in range(1, zlen + 1): 

2574 for jj in range(0, min((zlen - 1), nn) + 1): 

2575 if nn >= jj + 1: 

2576 t1 = -self.clutter_rate + ((zlen - 1) - jj) * np.log( 

2577 self.clutter_rate 

2578 ) 

2579 t2 = sum([np.log(x) for x in range(1, nn + 1)]) 

2580 t3 = -1 * sum( 

2581 [np.log(x) for x in range(1, nn - (jj + 1) + 1)] 

2582 ) 

2583 t4 = (nn - (jj + 1)) * np.log(self.prob_death) 

2584 t5 = -(jj + 1) * np.log(tot_w_pred) 

2585 terms1_D[jj, ell - 1] = ( 

2586 np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_D[jj, ell - 1] 

2587 ) 

2588 ups1_D[nn, :] = np.sum(terms1_D, axis=0) 

2589 gmix = deepcopy(probDensity) 

2590 w_update = ( 

2591 ((ups1_E.T @ self._card_dist) / (ups0_E.T @ self._card_dist)) 

2592 * self.prob_miss_detection 

2593 * w_pred 

2594 ) 

2595 

2596 old_filt_states = [] 

2597 gmix.weights = [x.item() for x in w_update] 

2598 for jj, (m, c) in enumerate(zip(gmix.means, gmix.covariances)): 

2599 m_list = [] 

2600 c_list = [] 

2601 for ff in range(0, len(self.filter.in_filt_list)): 

2602 m_list.append(m) 

2603 c_list.append(c) 

2604 self.filter.initialize_states( 

2605 m_list, c_list, init_weights=saved_filt_weights[jj] 

2606 ) 

2607 old_filt_states.append(self.filter.save_filter_state()) 

2608 

2609 for ee in range(0, zlen): 

2610 wt_1 = ( 

2611 (ups1_D[:, [ee]].T @ self._card_dist) / (ups0_E.T @ self._card_dist) 

2612 ).reshape((1, 1)) 

2613 wt_2 = self.prob_detection * qz_temp[:, [ee]] / self.clutter_den * w_pred 

2614 w_temp = wt_1 * wt_2 

2615 for ww in range(0, w_temp.shape[0]): 

2616 gmix.add_components( 

2617 mean_temp[ee, :, ww].reshape((xdim, 1)), 

2618 cov_temp[ee, ww, :, :], 

2619 w_temp[ww].item(), 

2620 ) 

2621 cdn_update = self._card_dist.copy() 

2622 for ii in range(0, len(cdn_update)): 

2623 cdn_update[ii] = ups0_E[ii] * self._card_dist[ii] 

2624 self._card_dist = cdn_update / np.sum(cdn_update) 

2625 # assumes predict is called before correct 

2626 self._card_time_hist[-1] = ( 

2627 np.argmax(self._card_dist).item(), 

2628 np.std(self._card_dist), 

2629 ) 

2630 for filt_state in new_filter_states: 

2631 old_filt_states.append(filt_state) 

2632 self._filter_states = old_filt_states 

2633 return gmix 

2634 

2635 

2636class GeneralizedLabeledMultiBernoulli(RandomFiniteSetBase): 

2637 """Delta-Generalized Labeled Multi-Bernoulli filter. 

2638 

2639 Notes 

2640 ----- 

2641 This is based on :cite:`Vo2013_LabeledRandomFiniteSetsandMultiObjectConjugatePriors` 

2642 and :cite:`Vo2014_LabeledRandomFiniteSetsandtheBayesMultiTargetTrackingFilter` 

2643 It does not account for agents spawned from existing tracks, only agents 

2644 birthed from the given birth model. 

2645 

2646 Attributes 

2647 ---------- 

2648 req_births : int 

2649 Number of requested birth hypotheses 

2650 req_surv : int 

2651 Number of requested surviving hypotheses 

2652 req_upd : int 

2653 Number of requested updated hypotheses 

2654 gating_on : bool 

2655 Determines if measurements are gated 

2656 birth_terms :list 

2657 List of tuples where the first element is a 

2658 :py:class:`gncpy.distributions.GaussianMixture` and 

2659 the second is the birth probability for that term 

2660 prune_threshold : float 

2661 Minimum association probability to keep when pruning 

2662 max_hyps : int 

2663 Maximum number of hypotheses to keep when capping 

2664 decimal_places : int 

2665 Number of decimal places to keep in label. The default is 2. 

2666 save_measurements : bool 

2667 Flag indicating if measurments should be saved. Useful for some extra 

2668 plots. 

2669 """ 

2670 

2671 class _TabEntry: 

2672 def __init__(self): 

2673 self.label = () # time step born, index of birth model born from 

2674 self.distrib_weights_hist = [] # list of weights of the probDensity 

2675 self.filt_states = [] # list of dictionaries from filters save function 

2676 self.meas_assoc_hist = ( 

2677 [] 

2678 ) # list indices into measurement list per time step 

2679 

2680 self.state_hist = [] # list of lists of numpy arrays for each timestep 

2681 self.cov_hist = ( 

2682 [] 

2683 ) # list of lists of numpy arrays for each timestep (or None) 

2684 

2685 """ linear index corresponding to timestep, manually updated. Used 

2686 to index things since timestep in label can have decimals.""" 

2687 self.time_index = None 

2688 

2689 def setup(self, tab): 

2690 """Use to avoid expensive deepcopy.""" 

2691 self.label = tab.label 

2692 self.distrib_weights_hist = tab.distrib_weights_hist.copy() 

2693 self.filt_states = deepcopy(tab.filt_states) 

2694 self.meas_assoc_hist = tab.meas_assoc_hist.copy() 

2695 

2696 self.state_hist = [None] * len(tab.state_hist) 

2697 self.state_hist = [s.copy() for s in [s_lst for s_lst in tab.state_hist]] 

2698 self.cov_hist = [ 

2699 c.copy() if c else [] for c in [c_lst for c_lst in tab.cov_hist] 

2700 ] 

2701 

2702 self.time_index = tab.time_index 

2703 

2704 return self 

2705 

2706 class _HypothesisHelper: 

2707 def __init__(self): 

2708 self.assoc_prob = 0 

2709 self.track_set = [] # indices in lookup table 

2710 

2711 @property 

2712 def num_tracks(self): 

2713 return len(self.track_set) 

2714 

2715 class _ExtractHistHelper: 

2716 def __init__(self): 

2717 self.label = () 

2718 self.meas_ind_hist = [] 

2719 self.b_time_index = None 

2720 self.states = [] 

2721 self.covs = [] 

2722 

2723 def __init__( 

2724 self, 

2725 req_births=None, 

2726 req_surv=None, 

2727 req_upd=None, 

2728 gating_on=False, 

2729 prune_threshold=10**-15, 

2730 max_hyps=3000, 

2731 decimal_places=2, 

2732 save_measurements=False, 

2733 **kwargs, 

2734 ): 

2735 self.req_births = req_births 

2736 self.req_surv = req_surv 

2737 self.req_upd = req_upd 

2738 self.gating_on = gating_on 

2739 self.prune_threshold = prune_threshold 

2740 self.max_hyps = max_hyps 

2741 self.decimal_places = decimal_places 

2742 self.save_measurements = save_measurements 

2743 

2744 self._track_tab = [] # list of all possible tracks 

2745 self._labels = [] # local copy for internal modification 

2746 self._extractable_hists = [] 

2747 

2748 self._filter = None 

2749 self._baseFilter = None 

2750 

2751 hyp0 = self._HypothesisHelper() 

2752 hyp0.assoc_prob = 1 

2753 hyp0.track_set = [] 

2754 self._hypotheses = [hyp0] # list of _HypothesisHelper objects 

2755 

2756 self._card_dist = [] # probability of having index # as cardinality 

2757 

2758 """ linear index corresponding to timestep, manually updated. Used 

2759 to index things since timestep in label can have decimals. Must 

2760 be updated once per time step.""" 

2761 self._time_index_cntr = 0 

2762 

2763 self.ospa2 = None 

2764 self.ospa2_localization = None 

2765 self.ospa2_cardinality = None 

2766 self._ospa2_params = {} 

2767 

2768 super().__init__(**kwargs) 

2769 self._states = [[]] 

2770 

2771 def save_filter_state(self): 

2772 """Saves filter variables so they can be restored later. 

2773 

2774 Note that to pickle the resulting dictionary the :code:`dill` package 

2775 may need to be used due to potential pickling of functions. 

2776 """ 

2777 filt_state = super().save_filter_state() 

2778 

2779 filt_state["req_births"] = self.req_births 

2780 filt_state["req_surv"] = self.req_surv 

2781 filt_state["req_upd"] = self.req_upd 

2782 filt_state["gating_on"] = self.gating_on 

2783 filt_state["prune_threshold"] = self.prune_threshold 

2784 filt_state["max_hyps"] = self.max_hyps 

2785 filt_state["decimal_places"] = self.decimal_places 

2786 filt_state["save_measurements"] = self.save_measurements 

2787 

2788 filt_state["_track_tab"] = self._track_tab 

2789 filt_state["_labels"] = self._labels 

2790 filt_state["_extractable_hists"] = self._extractable_hists 

2791 

2792 if self._baseFilter is not None: 

2793 filt_state["_baseFilter"] = ( 

2794 type(self._baseFilter), 

2795 self._baseFilter.save_filter_state(), 

2796 ) 

2797 else: 

2798 filt_state["_baseFilter"] = (None, self._baseFilter) 

2799 filt_state["_hypotheses"] = self._hypotheses 

2800 filt_state["_card_dist"] = self._card_dist 

2801 filt_state["_time_index_cntr"] = self._time_index_cntr 

2802 

2803 filt_state["ospa2"] = self.ospa2 

2804 filt_state["ospa2_localization"] = self.ospa2_localization 

2805 filt_state["ospa2_cardinality"] = self.ospa2_cardinality 

2806 filt_state["_ospa2_params"] = self._ospa_params 

2807 

2808 return filt_state 

2809 

2810 def load_filter_state(self, filt_state): 

2811 """Initializes filter using saved filter state. 

2812 

2813 Attributes 

2814 ---------- 

2815 filt_state : dict 

2816 Dictionary generated by :meth:`save_filter_state`. 

2817 """ 

2818 super().load_filter_state(filt_state) 

2819 

2820 self.req_births = filt_state["req_births"] 

2821 self.req_surv = filt_state["req_surv"] 

2822 self.req_upd = filt_state["req_upd"] 

2823 self.gating_on = filt_state["gating_on"] 

2824 self.prune_threshold = filt_state["prune_threshold"] 

2825 self.max_hyps = filt_state["max_hyps"] 

2826 self.decimal_places = filt_state["decimal_places"] 

2827 self.save_measurements = filt_state["save_measurements"] 

2828 

2829 self._track_tab = filt_state["_track_tab"] 

2830 self._labels = filt_state["_labels"] 

2831 self._extractable_hists = filt_state["_extractable_hists"] 

2832 

2833 cls_type = filt_state["_baseFilter"][0] 

2834 if cls_type is not None: 

2835 self._baseFilter = cls_type() 

2836 self._baseFilter.load_filter_state(filt_state["_baseFilter"][1]) 

2837 else: 

2838 self._baseFilter = None 

2839 self._hypotheses = filt_state["_hypotheses"] 

2840 self._card_dist = filt_state["_card_dist"] 

2841 self._time_index_cntr = filt_state["_time_index_cntr"] 

2842 

2843 self.ospa2 = filt_state["ospa2"] 

2844 self.ospa2_localization = filt_state["ospa2_localization"] 

2845 self.ospa2_cardinality = filt_state["ospa2_cardinality"] 

2846 self._ospa2_params = filt_state["_ospa2_params"] 

2847 

2848 @property 

2849 def states(self): 

2850 """Read only list of extracted states. 

2851 

2852 This is a list with 1 element per timestep, and each element is a list 

2853 of the best states extracted at that timestep. The order of each 

2854 element corresponds to the label order. 

2855 """ 

2856 return self._states 

2857 

2858 @property 

2859 def labels(self): 

2860 """Read only list of extracted labels. 

2861 

2862 This is a list with 1 element per timestep, and each element is a list 

2863 of the best labels extracted at that timestep. The order of each 

2864 element corresponds to the state order. 

2865 """ 

2866 return self._labels 

2867 

2868 @property 

2869 def covariances(self): 

2870 """Read only list of extracted covariances. 

2871 

2872 This is a list with 1 element per timestep, and each element is a list 

2873 of the best covariances extracted at that timestep. The order of each 

2874 element corresponds to the state order. 

2875 

2876 Raises 

2877 ------ 

2878 RuntimeWarning 

2879 If the class is not saving the covariances, and returns an empty list. 

2880 """ 

2881 if not self.save_covs: 

2882 raise RuntimeWarning("Not saving covariances") 

2883 return [] 

2884 return self._covs 

2885 

2886 @property 

2887 def filter(self): 

2888 """Inner filter handling dynamics, must be a gncpy.filters.BayesFilter.""" 

2889 return self._filter 

2890 

2891 @filter.setter 

2892 def filter(self, val): 

2893 self._baseFilter = deepcopy(val) 

2894 self._filter = val 

2895 

2896 @property 

2897 def cardinality(self): 

2898 """Cardinality estimate.""" 

2899 return np.argmax(self._card_dist) 

2900 

2901 def _init_filt_states(self, distrib): 

2902 filt_states = [None] * len(distrib.means) 

2903 states = [m.copy() for m in distrib.means] 

2904 if self.save_covs: 

2905 covs = [c.copy() for c in distrib.covariances] 

2906 else: 

2907 covs = [] 

2908 weights = distrib.weights.copy() 

2909 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)): 

2910 self._baseFilter.cov = cov.copy() 

2911 if isinstance(self._baseFilter, gfilts.UnscentedKalmanFilter) or isinstance( 

2912 self._baseFilter, gfilts.UKFGaussianScaleMixtureFilter 

2913 ): 

2914 self._baseFilter.init_sigma_points(m) 

2915 filt_states[ii] = self._baseFilter.save_filter_state() 

2916 return filt_states, weights, states, covs 

2917 

2918 def _gen_birth_tab(self, timestep): 

2919 log_cost = [] 

2920 birth_tab = [] 

2921 for ii, (distrib, p) in enumerate(self.birth_terms): 

2922 cost = p / (1 - p) 

2923 log_cost.append(-np.log(cost)) 

2924 entry = self._TabEntry() 

2925 entry.state_hist = [None] 

2926 entry.cov_hist = [None] 

2927 entry.distrib_weights_hist = [None] 

2928 ( 

2929 entry.filt_states, 

2930 entry.distrib_weights_hist[0], 

2931 entry.state_hist[0], 

2932 entry.cov_hist[0], 

2933 ) = self._init_filt_states(distrib) 

2934 entry.label = (round(timestep, self.decimal_places), ii) 

2935 entry.time_index = self._time_index_cntr 

2936 birth_tab.append(entry) 

2937 return birth_tab, log_cost 

2938 

def _gen_birth_hyps(self, paths, hyp_costs):
    """Turn the k-shortest birth paths into normalized birth hypotheses.

    Parameters
    ----------
    paths : list
        Birth track index sets, one per hypothesis.
    hyp_costs : iterable
        Path cost for each entry of ``paths``; ``.item()`` is called on each.

    Returns
    -------
    list
        Birth hypotheses whose ``assoc_prob`` values are normalized in
        linear scale.
    """
    # Log-probability that none of the birth terms occurs.
    log_no_birth = sum([np.log(1 - term[1]) for term in self.birth_terms])

    birth_hyps = []
    for path, path_cost in zip(paths, hyp_costs):
        hyp = self._HypothesisHelper()
        # NOTE: this may suffer from underflow and can be improved
        hyp.assoc_prob = log_no_birth - path_cost.item()
        hyp.track_set = path
        birth_hyps.append(hyp)

    log_norm = log_sum_exp([h.assoc_prob for h in birth_hyps])
    for h in birth_hyps:
        h.assoc_prob = np.exp(h.assoc_prob - log_norm)
    return birth_hyps

2952 

2953 def _inner_predict(self, timestep, filt_state, state, filt_args): 

2954 self.filter.load_filter_state(filt_state) 

2955 new_s = self.filter.predict(timestep, state, **filt_args) 

2956 new_f_state = self.filter.save_filter_state() 

2957 if self.save_covs: 

2958 new_cov = self.filter.cov.copy() 

2959 else: 

2960 new_cov = None 

2961 return new_f_state, new_s, new_cov 

2962 

2963 def _predict_track_tab_entry(self, tab, timestep, filt_args): 

2964 """Updates table entries probability density.""" 

2965 newTab = self._TabEntry().setup(tab) 

2966 new_f_states = [None] * len(newTab.filt_states) 

2967 new_s_hist = [None] * len(newTab.filt_states) 

2968 new_c_hist = [None] * len(newTab.filt_states) 

2969 for ii, (f_state, state) in enumerate( 

2970 zip(newTab.filt_states, newTab.state_hist[-1]) 

2971 ): 

2972 (new_f_states[ii], new_s_hist[ii], new_c_hist[ii]) = self._inner_predict( 

2973 timestep, f_state, state, filt_args 

2974 ) 

2975 newTab.filt_states = new_f_states 

2976 newTab.state_hist.append(new_s_hist) 

2977 newTab.cov_hist.append(new_c_hist) 

2978 newTab.distrib_weights_hist.append(newTab.distrib_weights_hist[-1].copy()) 

2979 return newTab 

2980 

2981 def _gen_surv_tab(self, timestep, filt_args): 

2982 surv_tab = [] 

2983 for ii, track in enumerate(self._track_tab): 

2984 entry = self._predict_track_tab_entry(track, timestep, filt_args) 

2985 

2986 surv_tab.append(entry) 

2987 return surv_tab 

2988 

def _gen_surv_hyps(self, avg_prob_survive, avg_prob_death):
    """Expand each posterior hypothesis into survival hypotheses.

    Hypotheses without tracks are carried over as-is. For any other
    hypothesis, each track costs ``-log(p_survive / p_death)``, and
    ``k_shortest`` enumerates which subsets of its tracks survive. The
    number requested is ``round(req_surv * sqrt(w) / sum(sqrt(w)))``, where
    ``w`` are the hypothesis weights. Weights are accumulated in log space
    and normalized at the end.

    Returns
    -------
    list
        Survival hypotheses with normalized (linear scale) ``assoc_prob``.
    """
    sum_sqrt_w = 0
    for prev_hyp in self._hypotheses:
        sum_sqrt_w += np.sqrt(prev_hyp.assoc_prob)

    surv_hyps = []
    for prev_hyp in self._hypotheses:
        if prev_hyp.num_tracks == 0:
            empty_hyp = self._HypothesisHelper()
            empty_hyp.assoc_prob = np.log(prev_hyp.assoc_prob)
            empty_hyp.track_set = prev_hyp.track_set
            surv_hyps.append(empty_hyp)
            continue

        odds = avg_prob_survive[prev_hyp.track_set] / avg_prob_death[prev_hyp.track_set]
        # One cost per track in this hypothesis.
        log_cost = -np.log(odds)
        k = np.round(self.req_surv * np.sqrt(prev_hyp.assoc_prob) / sum_sqrt_w)
        (paths, path_costs) = k_shortest(np.array(log_cost), k)

        # Log-probability that every track in the hypothesis dies.
        pdeath_log = np.sum(
            [np.log(avg_prob_death[trk]) for trk in prev_hyp.track_set]
        )

        for path, path_cost in zip(paths, path_costs):
            surv_hyp = self._HypothesisHelper()
            surv_hyp.assoc_prob = (
                pdeath_log + np.log(prev_hyp.assoc_prob) - path_cost.item()
            )
            surv_hyp.track_set = (
                [prev_hyp.track_set[j] for j in path] if len(path) > 0 else []
            )
            surv_hyps.append(surv_hyp)

    log_norm = log_sum_exp([h.assoc_prob for h in surv_hyps])
    for h in surv_hyps:
        h.assoc_prob = np.exp(h.assoc_prob - log_norm)
    return surv_hyps

3023 

3024 def _calc_avg_prob_surv_death(self): 

3025 avg_prob_survive = self.prob_survive * np.ones(len(self._track_tab)) 

3026 avg_prob_death = 1 - avg_prob_survive 

3027 

3028 return avg_prob_survive, avg_prob_death 

3029 

3030 def _set_pred_hyps(self, birth_tab, birth_hyps, surv_hyps): 

3031 self._hypotheses = [] 

3032 tot_w = 0 

3033 for b_hyp in birth_hyps: 

3034 for s_hyp in surv_hyps: 

3035 new_hyp = self._HypothesisHelper() 

3036 new_hyp.assoc_prob = b_hyp.assoc_prob * s_hyp.assoc_prob 

3037 tot_w = tot_w + new_hyp.assoc_prob 

3038 surv_lst = [] 

3039 for x in s_hyp.track_set: 

3040 surv_lst.append(x + len(birth_tab)) 

3041 new_hyp.track_set = b_hyp.track_set + surv_lst 

3042 self._hypotheses.append(new_hyp) 

3043 for ii in range(0, len(self._hypotheses)): 

3044 n_val = self._hypotheses[ii].assoc_prob / tot_w 

3045 self._hypotheses[ii].assoc_prob = n_val 

3046 

3047 def _calc_card_dist(self, hyp_lst): 

3048 """Calucaltes the cardinality distribution.""" 

3049 if len(hyp_lst) == 0: 

3050 return [ 

3051 1, 

3052 ] 

3053 card_dist = [] 

3054 for ii in range(0, max(map(lambda x: x.num_tracks, hyp_lst)) + 1): 

3055 card = 0 

3056 for hyp in hyp_lst: 

3057 if hyp.num_tracks == ii: 

3058 card = card + hyp.assoc_prob 

3059 card_dist.append(card) 

3060 return card_dist 

3061 

3062 def _clean_predictions(self): 

3063 hash_lst = [] 

3064 for hyp in self._hypotheses: 

3065 if len(hyp.track_set) == 0: 

3066 lst = [] 

3067 else: 

3068 sorted_inds = hyp.track_set.copy() 

3069 sorted_inds.sort() 

3070 lst = [int(x) for x in sorted_inds] 

3071 h = hash("*".join(map(str, lst))) 

3072 hash_lst.append(h) 

3073 new_hyps = [] 

3074 used_hash = [] 

3075 for ii, h in enumerate(hash_lst): 

3076 if h not in used_hash: 

3077 used_hash.append(h) 

3078 new_hyps.append(self._hypotheses[ii]) 

3079 else: 

3080 new_ii = used_hash.index(h) 

3081 new_hyps[new_ii].assoc_prob += self._hypotheses[ii].assoc_prob 

3082 self._hypotheses = new_hyps 

3083 

def predict(self, timestep, filt_args=None):
    """Prediction step of the GLMB filter.

    This predicts new hypotheses and propagates them to the next time
    step. It also updates the cardinality distribution.

    Parameters
    ----------
    timestep: float
        Current timestep.
    filt_args : dict, optional
        Passed to the inner filter. The default is None, which is treated
        as an empty dict.

    Returns
    -------
    None.
    """
    # Use None as the default so no mutable dict is shared across calls.
    if filt_args is None:
        filt_args = {}

    # Find cost for each birth track, and setup lookup table
    birth_tab, log_cost = self._gen_birth_tab(timestep)

    # get K best hypothesis, and their index in the lookup table
    (paths, hyp_costs) = k_shortest(np.array(log_cost), self.req_births)

    # calculate association probabilities for birth hypothesis
    birth_hyps = self._gen_birth_hyps(paths, hyp_costs)

    # Init and propagate surviving track table
    surv_tab = self._gen_surv_tab(timestep, filt_args)

    # Calculation for average survival/death probabilities
    (avg_prob_survive, avg_prob_death) = self._calc_avg_prob_surv_death()

    # loop over posterior components
    surv_hyps = self._gen_surv_hyps(avg_prob_survive, avg_prob_death)

    # NOTE(review): this cardinality distribution is computed from the
    # survival hypotheses only, before births are combined in below.
    # Confirm that excluding births here is intended.
    self._card_dist = self._calc_card_dist(surv_hyps)

    # Get predicted hypothesis by convolution; births come first in the
    # predicted track table.
    self._track_tab = birth_tab + surv_tab
    self._set_pred_hyps(birth_tab, birth_hyps, surv_hyps)

    # Merge hypotheses that share the same track set
    self._clean_predictions()

3126 

3127 def _inner_correct( 

3128 self, timestep, meas, filt_state, distrib_weight, state, filt_args 

3129 ): 

3130 self.filter.load_filter_state(filt_state) 

3131 cor_state, likely = self.filter.correct(timestep, meas, state, **filt_args) 

3132 new_f_state = self.filter.save_filter_state() 

3133 new_s = cor_state 

3134 if self.save_covs: 

3135 new_c = self.filter.cov.copy() 

3136 else: 

3137 new_c = None 

3138 new_w = distrib_weight * likely 

3139 

3140 return new_f_state, new_s, new_c, new_w 

3141 

3142 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args): 

3143 newTab = self._TabEntry().setup(tab) 

3144 new_f_states = [None] * len(newTab.filt_states) 

3145 new_s_hist = [None] * len(newTab.filt_states) 

3146 new_c_hist = [None] * len(newTab.filt_states) 

3147 new_w = [None] * len(newTab.filt_states) 

3148 depleted = False 

3149 for ii, (f_state, state, w) in enumerate( 

3150 zip( 

3151 newTab.filt_states, 

3152 newTab.state_hist[-1], 

3153 newTab.distrib_weights_hist[-1], 

3154 ) 

3155 ): 

3156 try: 

3157 ( 

3158 new_f_states[ii], 

3159 new_s_hist[ii], 

3160 new_c_hist[ii], 

3161 new_w[ii], 

3162 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args) 

3163 except ( 

3164 gerr.ParticleDepletionError, 

3165 gerr.ParticleEstimationDomainError, 

3166 gerr.ExtremeMeasurementNoiseError, 

3167 ): 

3168 return None, 0 

3169 newTab.filt_states = new_f_states 

3170 newTab.state_hist[-1] = new_s_hist 

3171 newTab.cov_hist[-1] = new_c_hist 

3172 new_w = [w + np.finfo(float).eps for w in new_w] 

3173 if not depleted: 

3174 cost = np.sum(new_w).item() 

3175 newTab.distrib_weights_hist[-1] = [w / cost for w in new_w] 

3176 else: 

3177 cost = 0 

3178 return newTab, cost 

3179 

3180 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args): 

3181 num_pred = len(self._track_tab) 

3182 up_tab = [None] * (num_meas + 1) * num_pred 

3183 

3184 for ii, track in enumerate(self._track_tab): 

3185 up_tab[ii] = self._TabEntry().setup(track) 

3186 up_tab[ii].meas_assoc_hist.append(None) 

3187 # measurement updated tracks 

3188 all_cost_m = np.zeros((num_pred, num_meas)) 

3189 for emm, z in enumerate(meas): 

3190 for ii, ent in enumerate(self._track_tab): 

3191 s_to_ii = num_pred * emm + ii + num_pred 

3192 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry( 

3193 z, ent, timestep, filt_args 

3194 ) 

3195 

3196 # update association history with current measurement index 

3197 if up_tab[s_to_ii] is not None: 

3198 up_tab[s_to_ii].meas_assoc_hist.append(emm) 

3199 all_cost_m[ii, emm] = cost 

3200 return up_tab, all_cost_m 

3201 

def _gen_cor_hyps(
    self, num_meas, avg_prob_detect, avg_prob_miss_detect, all_cost_m
):
    """Generate the measurement-updated (corrected) hypotheses.

    Each predicted hypothesis in ``self._hypotheses`` is expanded into one
    or more corrected hypotheses. Their track sets index into the corrected
    track table. Association weights are accumulated in log space and
    normalized with ``log_sum_exp`` at the end.

    Parameters
    ----------
    num_meas : int
        Number of measurements at this timestep.
    avg_prob_detect : numpy array
        Detection probability for each entry of ``self._track_tab``.
    avg_prob_miss_detect : numpy array
        Missed-detection probability for each entry of ``self._track_tab``.
    all_cost_m : numpy array
        ``num_pred x num_meas`` matrix of per track / per measurement
        update costs.

    Returns
    -------
    up_hyps : list
        Corrected hypotheses with normalized (linear scale) ``assoc_prob``.
    """
    num_pred = len(self._track_tab)
    up_hyps = []
    if num_meas == 0:
        # No measurements: every track is missed. The existing hypothesis
        # objects are updated in place and reused. Their track sets already
        # point at the missed-detection slots of the corrected table.
        for hyp in self._hypotheses:
            pmd_log = np.sum(
                [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
            )
            hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
            up_hyps.append(hyp)
    else:
        clutter = self.clutter_rate * self.clutter_den
        # Sum of sqrt weights, used to split the req_upd budget of
        # assignments across predicted hypotheses.
        ss_w = 0
        for p_hyp in self._hypotheses:
            ss_w += np.sqrt(p_hyp.assoc_prob)
        for p_hyp in self._hypotheses:
            if p_hyp.num_tracks == 0:  # all clutter
                new_hyp = self._HypothesisHelper()
                new_hyp.assoc_prob = (
                    -self.clutter_rate
                    + num_meas * np.log(clutter)
                    + np.log(p_hyp.assoc_prob)
                )
                new_hyp.track_set = p_hyp.track_set.copy()
                up_hyps.append(new_hyp)
            else:
                # Detection / missed-detection odds for each track in the
                # hypothesis, tiled so that row ii spans all measurements.
                pd = np.array([avg_prob_detect[ii] for ii in p_hyp.track_set])
                pmd = np.array([avg_prob_miss_detect[ii] for ii in p_hyp.track_set])
                ratio = pd / pmd

                ratio = ratio.reshape((ratio.size, 1))
                ratio = np.tile(ratio, (1, num_meas))

                # Track-to-measurement score, scaled by the odds and divided
                # by the clutter intensity.
                cost_m = np.zeros(all_cost_m[p_hyp.track_set, :].shape)
                for ii, ts in enumerate(p_hyp.track_set):
                    cost_m[ii, :] = ratio[ii] * all_cost_m[ts, :] / clutter
                # Clamp infinite scores to the largest finite float and
                # non-positive ones to eps, so -log(cost_m) below is finite.
                max_row_inds, max_col_inds = np.where(cost_m >= np.inf)
                if max_row_inds.size > 0:
                    cost_m[max_row_inds, max_col_inds] = np.finfo(float).max
                min_row_inds, min_col_inds = np.where(cost_m <= 0.0)
                if min_row_inds.size > 0:
                    cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps  # 1
                neg_log = -np.log(cost_m)
                # if max_row_inds.size > 0:
                #     neg_log[max_row_inds, max_col_inds] = -np.inf
                # if min_row_inds.size > 0:
                #     neg_log[min_row_inds, min_col_inds] = np.inf

                # Number of best assignments to request for this hypothesis,
                # proportional to the square root of its weight.
                m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
                m = int(m.item())
                # NOTE(review): assumes murty_m_best returns up to m
                # assignments where, per track, 0 means missed and j means
                # measurement j - 1. Then track_set + num_pred * a indexes the
                # layout built by _gen_cor_tab (missed copies first, then one
                # block of num_pred entries per measurement). Confirm against
                # murty_m_best.
                [assigns, costs] = murty_m_best(neg_log, m)

                pmd_log = np.sum(
                    [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
                )
                for a, c in zip(assigns, costs):
                    new_hyp = self._HypothesisHelper()
                    new_hyp.assoc_prob = (
                        -self.clutter_rate
                        + num_meas * np.log(clutter)
                        + pmd_log
                        + np.log(p_hyp.assoc_prob)
                        - c
                    )
                    new_hyp.track_set = list(
                        np.array(p_hyp.track_set) + num_pred * a
                    )
                    up_hyps.append(new_hyp)
    # Normalize the log weights into probabilities.
    lse = log_sum_exp([x.assoc_prob for x in up_hyps])
    for ii in range(0, len(up_hyps)):
        up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse)
    return up_hyps

3276 

3277 def _calc_avg_prob_det_mdet(self): 

3278 avg_prob_detect = self.prob_detection * np.ones(len(self._track_tab)) 

3279 avg_prob_miss_detect = 1 - avg_prob_detect 

3280 

3281 return avg_prob_detect, avg_prob_miss_detect 

3282 

3283 def _clean_updates(self): 

3284 used = [0] * len(self._track_tab) 

3285 for hyp in self._hypotheses: 

3286 for ii in hyp.track_set: 

3287 if self._track_tab[ii] is not None: 

3288 used[ii] += 1 

3289 nnz_inds = [idx for idx, val in enumerate(used) if val != 0] 

3290 track_cnt = len(nnz_inds) 

3291 

3292 new_inds = [None] * len(self._track_tab) 

3293 for ii, v in zip(nnz_inds, [ii for ii in range(0, track_cnt)]): 

3294 new_inds[ii] = v 

3295 # new_tab = [self._TabEntry().setup(self._track_tab[ii]) for ii in nnz_inds] 

3296 new_tab = [self._track_tab[ii] for ii in nnz_inds] 

3297 new_hyps = [] 

3298 for ii, hyp in enumerate(self._hypotheses): 

3299 if len(hyp.track_set) > 0: 

3300 track_set = [new_inds[ii] for ii in hyp.track_set] 

3301 if None in track_set: 

3302 continue 

3303 hyp.track_set = track_set 

3304 new_hyps.append(hyp) 

3305 self._track_tab = new_tab 

3306 self._hypotheses = new_hyps 

3307 

def correct(self, timestep, meas, filt_args=None):
    """Correction step of the GLMB filter.

    Notes
    -----
    This corrects the hypotheses based on the measurements and updates the
    cardinality distribution. Measurement gating is not implemented yet:
    if ``gating_on`` is set a RuntimeWarning is issued and all measurements
    are used.

    Parameters
    ----------
    timestep: float
        Current timestep.
    meas : list
        List of Nm x 1 numpy arrays each representing a measurement.
    filt_args : dict, optional
        keyword arguments to pass to the inner filters correct function.
        The default is None, which is treated as an empty dict.

    .. todo::
        Fix the measurement gating

    Returns
    -------
    None
    """
    # Use None as the default so no mutable dict is shared across calls.
    if filt_args is None:
        filt_args = {}

    # gate measurements by tracks (not implemented; measurements used as-is)
    if self.gating_on:
        warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
    if self.save_measurements:
        self._meas_tab.append(deepcopy(meas))
    num_meas = len(meas)

    # missed detection and measurement updated tracks
    cor_tab, all_cost_m = self._gen_cor_tab(num_meas, meas, timestep, filt_args)

    # Calculation for average detection/missed probabilities
    avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet()

    # component updates
    cor_hyps = self._gen_cor_hyps(num_meas, avg_prob_det, avg_prob_mdet, all_cost_m)

    # save values and cleanup
    self._track_tab = cor_tab
    self._hypotheses = cor_hyps
    self._card_dist = self._calc_card_dist(self._hypotheses)
    self._clean_updates()

3361 

3362 def _extract_helper(self, track): 

3363 states = [None] * len(track.state_hist) 

3364 covs = [None] * len(track.state_hist) 

3365 for ii, (w_lst, s_lst, c_lst) in enumerate( 

3366 zip(track.distrib_weights_hist, track.state_hist, track.cov_hist) 

3367 ): 

3368 idx = np.argmax(w_lst) 

3369 states[ii] = s_lst[idx] 

3370 if self.save_covs: 

3371 covs[ii] = c_lst[idx] 

3372 return states, covs 

3373 

3374 def _update_extract_hist(self, idx_cmp): 

3375 used_meas_inds = [[] for ii in range(self._time_index_cntr)] 

3376 used_labels = [] 

3377 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set) 

3378 for ii, track in enumerate( 

3379 [ 

3380 self._track_tab[trk_ind] 

3381 for trk_ind in self._hypotheses[idx_cmp].track_set 

3382 ] 

3383 ): 

3384 new_extract_hists[ii] = self._ExtractHistHelper() 

3385 new_extract_hists[ii].label = track.label 

3386 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy() 

3387 new_extract_hists[ii].b_time_index = track.time_index 

3388 ( 

3389 new_extract_hists[ii].states, 

3390 new_extract_hists[ii].covs, 

3391 ) = self._extract_helper(track) 

3392 

3393 used_labels.append(track.label) 

3394 

3395 for t_inds_after_b, meas_ind in enumerate( 

3396 new_extract_hists[ii].meas_ind_hist 

3397 ): 

3398 tt = new_extract_hists[ii].b_time_index + t_inds_after_b 

3399 if meas_ind is not None and meas_ind not in used_meas_inds[tt]: 

3400 used_meas_inds[tt].append(meas_ind) 

3401 good_inds = [] 

3402 for ii, existing in enumerate(self._extractable_hists): 

3403 used = existing.label in used_labels 

3404 if used: 

3405 continue 

3406 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist): 

3407 tt = existing.b_time_index + t_inds_after_b 

3408 used = meas_ind is not None and meas_ind in used_meas_inds[tt] 

3409 if used: 

3410 break 

3411 if not used: 

3412 good_inds.append(ii) 

3413 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds] 

3414 self._extractable_hists.extend(new_extract_hists) 

3415 

def extract_states(self, update=True, calc_states=True):
    """Extracts the best state estimates.

    This extracts the best states from the distribution. It should be
    called once per time step after the correction function. The chosen
    hypothesis is the highest-weight one among those whose number of
    tracks equals the most likely cardinality (the argmax of the
    cardinality distribution).

    The extracted state, label, and covariance lists are reset on every
    call. Calling with ``calc_states=False`` therefore leaves them empty
    until a later call calculates states.

    Parameters
    ----------
    update : bool, optional
        Flag indicating if the label history should be updated. This should
        be done once per timestep and can be disabled if calculating states
        after the final timestep. The default is True.
    calc_states : bool, optional
        Flag indicating if the states should be calculated based on the
        label history. This only needs to be done before the states are used.
        It can simply be called once after the end of the simulation. The
        default is True.

    Returns
    -------
    idx_cmp : int or None
        Index of the hypothesis table used when extracting states, or None
        if there are no hypotheses.
    """
    card = np.argmax(self._card_dist)
    tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses])
    weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses])

    self._states = [[] for ii in range(self._time_index_cntr)]
    self._labels = [[] for ii in range(self._time_index_cntr)]
    self._covs = [[] for ii in range(self._time_index_cntr)]

    if len(tracks_per_hyp) == 0:
        return None
    # Zero the weights of hypotheses whose track count differs from the
    # estimated cardinality, then take the best remaining one.
    idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card))
    if update:
        self._update_extract_hist(idx_cmp)
    if calc_states:
        for existing in self._extractable_hists:
            for t_inds_after_b, (s, c) in enumerate(
                zip(existing.states, existing.covs)
            ):
                # Convert the offset from the track's birth into a global
                # timestep index.
                tt = existing.b_time_index + t_inds_after_b
                self._states[tt].append(s)
                self._labels[tt].append(existing.label)
                self._covs[tt].append(c)
    if not update and not calc_states:
        warnings.warn("Extracting states performed no actions")
    return idx_cmp

3472 

def extract_most_prob_states(self, thresh):
    """Extracts the most probable hypotheses up to a threshold.

    Works on a deep copy, so the filter itself is not modified. The first
    hypothesis chosen by
    :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
    is always extracted. Its probability is then zeroed on the copy, and
    further hypotheses are extracted in the same way while their
    association probability is at least ``thresh``. Extraction also stops
    when the chosen hypothesis was already consumed (zero probability), so
    the loop terminates even for ``thresh <= 0``.

    Parameters
    ----------
    thresh : float
        Minimum association probability to extract.

    Returns
    -------
    state_sets : list
        Each element is the state list from the normal
        :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`.
    label_sets : list
        Each element is the label list from the normal
        :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
    cov_sets : list
        Each element is the covariance list from the normal
        :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
        if the covariances are saved.
    probs : list
        Each element is the association probability for the extracted states.
    """
    loc_self = deepcopy(self)
    state_sets = []
    cov_sets = []
    label_sets = []
    probs = []

    is_first = True
    while True:
        idx = loc_self.extract_states()
        if idx is None:
            break
        prob = loc_self._hypotheses[idx].assoc_prob
        # After the first extraction, stop below the threshold, or once a
        # hypothesis that was already zeroed is chosen again. Without the
        # second check a threshold <= 0 would loop forever.
        if not is_first and not (prob >= thresh and prob > 0):
            break
        state_sets.append(loc_self.states.copy())
        label_sets.append(loc_self.labels.copy())
        if loc_self.save_covs:
            cov_sets.append(loc_self.covariances.copy())
        probs.append(prob)
        loc_self._hypotheses[idx].assoc_prob = 0
        is_first = False
    return (state_sets, label_sets, cov_sets, probs)

3525 

3526 def _prune(self): 

3527 """Removes hypotheses below a threshold. 

3528 

3529 This should be called once per time step after the correction and 

3530 before the state extraction. 

3531 """ 

3532 # Find hypotheses with low association probabilities 

3533 temp_assoc_probs = np.array([]) 

3534 for ii in range(0, len(self._hypotheses)): 

3535 temp_assoc_probs = np.append( 

3536 temp_assoc_probs, self._hypotheses[ii].assoc_prob 

3537 ) 

3538 keep_indices = np.argwhere(temp_assoc_probs > self.prune_threshold).T 

3539 keep_indices = keep_indices.flatten() 

3540 

3541 # For re-weighing association probabilities 

3542 new_sum = np.sum(temp_assoc_probs[keep_indices]) 

3543 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices] 

3544 for ii in range(0, len(keep_indices)): 

3545 self._hypotheses[ii].assoc_prob = self._hypotheses[ii].assoc_prob / new_sum 

3546 # Re-calculate cardinality 

3547 self._card_dist = self._calc_card_dist(self._hypotheses) 

3548 

3549 def _cap(self): 

3550 """Removes least likely hypotheses until a maximum number is reached. 

3551 

3552 This should be called once per time step after pruning and 

3553 before the state extraction. 

3554 """ 

3555 # Determine if there are too many hypotheses 

3556 if len(self._hypotheses) > self.max_hyps: 

3557 temp_assoc_probs = np.array([]) 

3558 for ii in range(0, len(self._hypotheses)): 

3559 temp_assoc_probs = np.append( 

3560 temp_assoc_probs, self._hypotheses[ii].assoc_prob 

3561 ) 

3562 sorted_indices = np.argsort(temp_assoc_probs) 

3563 

3564 # Reverse order to get descending array 

3565 sorted_indices = sorted_indices[::-1] 

3566 

3567 # Take the top n assoc_probs, where n = max_hyps 

3568 keep_indices = np.array([], dtype=np.int64) 

3569 for ii in range(0, self.max_hyps): 

3570 keep_indices = np.append(keep_indices, int(sorted_indices[ii])) 

3571 # Assign to class 

3572 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices] 

3573 

3574 # Normalize association probabilities 

3575 new_sum = 0 

3576 for ii in range(0, len(self._hypotheses)): 

3577 new_sum = new_sum + self._hypotheses[ii].assoc_prob 

3578 for ii in range(0, len(self._hypotheses)): 

3579 self._hypotheses[ii].assoc_prob = ( 

3580 self._hypotheses[ii].assoc_prob / new_sum 

3581 ) 

3582 # Re-calculate cardinality 

3583 self._card_dist = self._calc_card_dist(self._hypotheses) 

3584 

def cleanup(
    self,
    enable_prune=True,
    enable_cap=True,
    enable_extract=True,
    extract_kwargs=None,
):
    """Performs the cleanup step of the filter.

    Optionally prunes, caps, and extracts states, in that order. This must
    be called exactly once per timestep, even with all three disabled,
    because it advances the internal linear timestep counter. With
    ``enable_extract`` true there is no need to call
    :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
    separately, and calling this function is the recommended route.

    Parameters
    ----------
    enable_prune : bool, optional
        Flag indicating if pruning should be performed. The default is True.
    enable_cap : bool, optional
        Flag indicating if capping should be performed. The default is True.
    enable_extract : bool, optional
        Flag indicating if state extraction should be performed. The default is True.
    extract_kwargs : dict, optional
        Additional arguments to pass to :meth:`.extract_states`. The
        default is None. Only used if extracting states.

    Returns
    -------
    None.
    """
    # Always advance the timestep counter first.
    self._time_index_cntr += 1

    if enable_prune:
        self._prune()
    if enable_cap:
        self._cap()
    if not enable_extract:
        return
    kwargs = {} if extract_kwargs is None else extract_kwargs
    self.extract_states(**kwargs)

3630 

3631 def _ospa_setup_emat(self, state_dim, state_inds): 

3632 # get sizes 

3633 num_timesteps = len(self.states) 

3634 num_objs = 0 

3635 lbl_to_ind = {} 

3636 

3637 for lst in self.labels: 

3638 for lbl in lst: 

3639 if lbl is None: 

3640 continue 

3641 key = str(lbl) 

3642 if key not in lbl_to_ind: 

3643 lbl_to_ind[key] = num_objs 

3644 num_objs += 1 

3645 # create matrices 

3646 est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs)) 

3647 est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs)) 

3648 

3649 for tt, (lbl_lst, s_lst) in enumerate(zip(self.labels, self.states)): 

3650 for lbl, s in zip(lbl_lst, s_lst): 

3651 if lbl is None: 

3652 continue 

3653 obj_num = lbl_to_ind[str(lbl)] 

3654 est_mat[:, tt, obj_num] = s.ravel()[state_inds] 

3655 if self.save_covs: 

3656 for tt, (lbl_lst, c_lst) in enumerate(zip(self.labels, self.covariances)): 

3657 for lbl, c in zip(lbl_lst, c_lst): 

3658 if lbl is None: 

3659 continue 

3660 est_cov_mat[:, :, tt, lbl_to_ind[str(lbl)]] = c[state_inds][ 

3661 :, state_inds 

3662 ] 

3663 return est_mat, est_cov_mat 

3664 

def calculate_ospa2(
    self,
    truth,
    c,
    p,
    win_len,
    true_covs=None,
    core_method=SingleObjectDistance.MANHATTAN,
    state_inds=None,
):
    """Calculates the OSPA(2) distance between the truth at all timesteps.

    Wrapper for :func:`serums.distances.calculate_ospa2`.

    Parameters
    ----------
    truth : list
        Each element represents a timestep and is a list of N x 1 numpy array,
        one per true agent in the swarm.
    c : float
        Distance cutoff for considering a point properly assigned. This
        influences how cardinality errors are penalized. For :math:`p = 1`
        it is the penalty given false point estimate.
    p : int
        The power of the distance term. Higher values penalize outliers
        more.
    win_len : int
        Number of samples to include in window.
    true_covs : list, Optional
        Each element represents a timestep and is a list of N x N numpy arrays
        corresponding to the uncertainty about the true states. Note the
        order must be consistent with the truth data given. This is only
        needed for core methods :attr:`SingleObjectDistance.HELLINGER`. The default
        value is None.
    core_method : :class:`serums.enums.SingleObjectDistance`, Optional
        The main distance measure to use for the localization component.
        The default value is :attr:`.SingleObjectDistance.MANHATTAN`.
    state_inds : list, optional
        Indices in the state vector to use, will be applied to the truth
        data as well. The default is None which means the full state is
        used.

    Notes
    -----
    If the state dimension cannot be determined, a warning is issued and
    the OSPA(2) attributes are set to zero arrays with one entry per
    extracted timestep. No distance is computed in that case.
    """
    # error checking on optional input arguments
    core_method = self._ospa_input_check(core_method, truth, true_covs)

    # setup data structures
    if state_inds is None:
        state_dim = self._ospa_find_s_dim(truth)
        # Only build the index range once the dimension is known;
        # range(None) would raise before the check below could run.
        if state_dim is not None:
            state_inds = range(state_dim)
    else:
        state_dim = len(state_inds)
    if state_dim is None:
        warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation")

        nt = len(self._states)
        self.ospa2 = np.zeros(nt)
        self.ospa2_localization = np.zeros(nt)
        self.ospa2_cardinality = np.zeros(nt)
        self._ospa2_params["core"] = core_method
        self._ospa2_params["cutoff"] = c
        self._ospa2_params["power"] = p
        self._ospa2_params["win_len"] = win_len
        return
    true_mat, true_cov_mat = self._ospa_setup_tmat(
        truth, state_dim, true_covs, state_inds
    )
    est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)

    # find OSPA
    (
        self.ospa2,
        self.ospa2_localization,
        self.ospa2_cardinality,
        self._ospa2_params["core"],
        self._ospa2_params["cutoff"],
        self._ospa2_params["power"],
        self._ospa2_params["win_len"],
    ) = calculate_ospa2(
        est_mat,
        true_mat,
        c,
        p,
        win_len,
        core_method=core_method,
        true_cov_mat=true_cov_mat,
        est_cov_mat=est_cov_mat,
    )

3752 

def plot_states_labels(
    self,
    plt_inds,
    ttl="Labeled State Trajectories",
    x_lbl=None,
    y_lbl=None,
    meas_tx_fnc=None,
    **kwargs,
):
    """Plots the best estimate for the states and labels.

    This assumes that the states have been extracted. It's designed to plot
    two of the state variables (typically x/y position). The error ellipses
    are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`

    Keyword arguments are processed with
    :meth:`gncpy.plotting.init_plotting_opts`. This function
    implements

        - f_hndl
        - true_states
        - sig_bnd
        - meas_inds
        - lgnd_loc
        - marker

    The processed options are also forwarded to
    :meth:`gncpy.plotting.set_title_label` for title/axis text formatting.

    Parameters
    ----------
    plt_inds : list
        List of indices in the state vector to plot
    ttl : string, optional
        Title of the plot.
    x_lbl : string, optional
        X-axis label for the plot. The default of None uses "x-position".
    y_lbl : string, optional
        Y-axis label for the plot. The default of None uses "y-position".
    meas_tx_fnc : callable, optional
        Takes in the measurement vector as an Nm x 1 numpy array and
        returns a numpy array representing the states to plot (size 2). The
        default is None.

    Returns
    -------
    Matplotlib figure
        Instance of the matplotlib figure used
    """
    opts = pltUtil.init_plotting_opts(**kwargs)
    f_hndl = opts["f_hndl"]
    true_states = opts["true_states"]
    sig_bnd = opts["sig_bnd"]
    meas_inds = opts["meas_inds"]
    lgnd_loc = opts["lgnd_loc"]
    mrkr = opts["marker"]

    if x_lbl is None:
        x_lbl = "x-position"
    if y_lbl is None:
        y_lbl = "y-position"
    meas_specs_given = (
        meas_inds is not None and len(meas_inds) == 2
    ) or meas_tx_fnc is not None
    plt_meas = meas_specs_given and self.save_measurements
    show_sig = sig_bnd is not None and self.save_covs

    s_lst = deepcopy(self.states)
    l_lst = deepcopy(self.labels)
    x_dim = None

    if f_hndl is None:
        f_hndl = plt.figure()
        f_hndl.add_subplot(1, 1, 1)
    ax = f_hndl.axes[0]

    # get state dimension from the first non-empty time step
    for states in s_lst:
        if states is not None and len(states) > 0:
            x_dim = states[0].size
            break

    # get unique labels, in order of first appearance
    u_lbls = []
    for lbls in l_lst:
        if lbls is None:
            continue
        for lbl in lbls:
            if lbl not in u_lbls:
                u_lbls.append(lbl)
    cmap = pltUtil.get_cmap(len(u_lbls))

    # legend entries are only added once per artist type
    added_sig_lbl = False
    added_true_lbl = False
    added_state_lbl = False
    for c_idx, lbl in enumerate(u_lbls):
        # NaN columns mark time steps where this label is absent
        x = np.nan * np.ones((x_dim, len(s_lst)))
        if show_sig:
            sigs = [None] * len(s_lst)
        for tt, lbls in enumerate(l_lst):
            if lbls is None or lbl not in lbls:
                continue
            ii = lbls.index(lbl)
            if s_lst[tt][ii] is not None:
                x[:, [tt]] = s_lst[tt][ii].copy()
            if show_sig:
                cov = self._covs[tt][ii]
                if cov is not None:
                    sig = np.zeros((2, 2))
                    sig[0, 0] = cov[plt_inds[0], plt_inds[0]]
                    sig[0, 1] = cov[plt_inds[0], plt_inds[1]]
                    sig[1, 0] = cov[plt_inds[1], plt_inds[0]]
                    sig[1, 1] = cov[plt_inds[1], plt_inds[1]]
                else:
                    sig = None
                sigs[tt] = sig

        color = cmap(c_idx)

        if show_sig:
            for tt, sig in enumerate(sigs):
                if sig is None:
                    continue
                w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
                ell_kwargs = {}
                if not added_sig_lbl:
                    ell_kwargs["label"] = r"${}\sigma$ Error Ellipses".format(sig_bnd)
                    added_sig_lbl = True
                e = Ellipse(
                    xy=x[plt_inds, tt],
                    width=w,
                    height=h,
                    angle=a,
                    zorder=-10000,
                    **ell_kwargs,
                )
                e.set_clip_box(ax.bbox)
                e.set_alpha(0.2)
                e.set_facecolor(color)
                ax.add_patch(e)

        settings = {
            "color": color,
            "markeredgecolor": "k",
            "marker": mrkr,
            "ls": "--",
        }
        if not added_state_lbl:
            settings["label"] = "States"
            added_state_lbl = True
        ax.plot(x[plt_inds[0], :], x[plt_inds[1], :], **settings)

        # annotate the track with its label at its first fully defined point;
        # skip the annotation if the label never has a complete state
        finite = x[:, ~np.any(np.isnan(x), axis=0)]
        if finite.shape[1] > 0:
            ax.text(
                finite[plt_inds[0], 0],
                finite[plt_inds[1], 0],
                "({}, {})".format(lbl[0], lbl[1]),
                color=color,
            )

    # if true states are available then plot them
    if true_states is not None and any(len(s) > 0 for s in true_states):
        if x_dim is None:
            for states in true_states:
                if len(states) > 0:
                    x_dim = states[0].size
                    break
        max_true = max(len(s) for s in true_states)
        x_true = np.nan * np.ones((x_dim, len(true_states), max_true))
        for tt, states in enumerate(true_states):
            for ii, state in enumerate(states):
                if state is not None and state.size > 0:
                    x_true[:, [tt], ii] = state.copy()
        for ii in range(max_true):
            true_kwargs = {"color": "k", "marker": "."}
            if not added_true_lbl:
                true_kwargs["label"] = "True Trajectories"
                added_true_lbl = True
            ax.plot(x_true[plt_inds[0], :, ii], x_true[plt_inds[1], :, ii], **true_kwargs)

    if plt_meas:
        meas_x = []
        meas_y = []
        for meas_tt in self._meas_tab:
            if meas_tx_fnc is not None:
                tx_meas = [meas_tx_fnc(m) for m in meas_tt]
                meas_x.extend(tm[0].item() for tm in tx_meas)
                meas_y.extend(tm[1].item() for tm in tx_meas)
            else:
                meas_x.extend(m[meas_inds[0]].item() for m in meas_tt)
                meas_y.extend(m[meas_inds[1]].item() for m in meas_tt)
        meas_x = np.asarray(meas_x)
        meas_y = np.asarray(meas_y)
        if meas_x.size > 0:
            ax.scatter(
                meas_x,
                meas_y,
                zorder=-1,
                alpha=0.35,
                color=(128 / 255, 128 / 255, 128 / 255),
                marker="^",
                label="Measurements",
            )

    ax.grid(True)
    pltUtil.set_title_label(f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl)
    # operate on the figure we drew on, not whatever pyplot considers current
    if lgnd_loc is not None:
        ax.legend(loc=lgnd_loc)
    f_hndl.tight_layout()

    return f_hndl

3990 

def plot_card_dist(self, ttl=None, **kwargs):
    """Plots the current cardinality distribution.

    This assumes that the cardinality distribution has been calculated by
    the class.

    Keyword arguments are processed with
    :meth:`gncpy.plotting.init_plotting_opts`. This function
    implements

        - f_hndl

    Parameters
    ----------
    ttl : string
        Title of the plot, if None a default title is generated. The default
        is None.

    Returns
    -------
    Matplotlib figure
        Instance of the matplotlib figure used

    Raises
    ------
    RuntimeWarning
        If the cardinality distribution is empty (raised as an exception,
        not issued through :mod:`warnings`).
    """
    opts = pltUtil.init_plotting_opts(**kwargs)
    f_hndl = opts["f_hndl"]
    if ttl is None:
        ttl = "Cardinality Distribution"
    if len(self._card_dist) == 0:
        raise RuntimeWarning("Empty Cardinality")
    if f_hndl is None:
        f_hndl = plt.figure()
        f_hndl.add_subplot(1, 1, 1)
    x_vals = np.arange(0, len(self._card_dist))
    f_hndl.axes[0].bar(x_vals, self._card_dist)

    pltUtil.set_title_label(
        f_hndl, 0, opts, ttl=ttl, x_lbl="Cardinality", y_lbl="Probability"
    )
    # lay out the figure that was drawn on, not pyplot's current figure
    f_hndl.tight_layout()

    return f_hndl

4033 

def plot_card_history(
    self, time_units="index", time=None, ttl="Cardinality History", **kwargs
):
    """Plots the cardinality history.

    Parameters
    ----------
    time_units : string, optional
        Text representing the units of time in the plot. The default is
        'index'.
    time : numpy array, optional
        Vector to use for the x-axis of the plot. If none is given then
        vector indices are used. The default is None.
    ttl : string, optional
        Title of the plot.
    **kwargs : dict
        Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
        function. Values implemented here are `f_hndl`, and any values
        relating to title/axis text formatting.

    Returns
    -------
    fig : matplotlib figure
        Figure object the data was plotted on.
    """
    # time steps with no extracted state set (None) count as cardinality 0
    card_history = np.array(
        [0 if state_set is None else len(state_set) for state_set in self.states]
    )

    opts = pltUtil.init_plotting_opts(**kwargs)
    fig = opts["f_hndl"]

    if fig is None:
        fig = plt.figure()
        fig.add_subplot(1, 1, 1)
    if time is None:
        time = np.arange(card_history.size, dtype=int)
    fig.axes[0].grid(True)
    fig.axes[0].step(time, card_history, where="post", label="estimated", color="k")
    fig.axes[0].ticklabel_format(useOffset=False)

    pltUtil.set_title_label(
        fig,
        0,
        opts,
        ttl=ttl,
        x_lbl="Time ({})".format(time_units),
        y_lbl="Cardinality",
    )
    fig.tight_layout()

    return fig

4084 

def plot_ospa2_history(
    self,
    time_units="index",
    time=None,
    main_opts=None,
    sub_opts=None,
    plot_subs=True,
):
    """Plots the OSPA2 history.

    This requires that the OSPA2 has been calculated by the appropriate
    function first.

    Parameters
    ----------
    time_units : string, optional
        Text representing the units of time in the plot. The default is
        'index'.
    time : numpy array, optional
        Vector to use for the x-axis of the plot. If none is given then
        vector indices are used. The default is None.
    main_opts : dict, optional
        Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
        function. Values implemented here are `f_hndl`, and any values
        relating to title/axis text formatting. The default of None implies
        the default options are used for the main plot.
    sub_opts : dict, optional
        Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
        function. Values implemented here are `f_hndl`, and any values
        relating to title/axis text formatting. The default of None implies
        the default options are used for the sub plot.
    plot_subs : bool, optional
        Flag indicating if the component statistics (cardinality and
        localization) should also be plotted.

    Returns
    -------
    figs : dict
        Dictionary of matplotlib figure objects the data was plotted on. The
        main plot is stored under ``"OSPA2"`` and, when `plot_subs` is True,
        the component plot under ``"OSPA2_subs"``. If the OSPA2 has not been
        calculated a warning is issued and None is returned.
    """
    if self.ospa2 is None:
        warnings.warn("OSPA must be calculated before plotting")
        return
    if main_opts is None:
        main_opts = pltUtil.init_plotting_opts()
    if sub_opts is None and plot_subs:
        sub_opts = pltUtil.init_plotting_opts()

    params = self._ospa2_params
    param_vals = (
        params["core"],
        params["cutoff"],
        params["power"],
        params["win_len"],
    )
    ttl = "{:s} OSPA2 (c = {:.1f}, p = {:d}, w={:d})".format(*param_vals)
    y_lbl = "OSPA2"

    figs = {}
    figs["OSPA2"] = self._plt_ospa_hist(
        self.ospa2, time_units, time, ttl, y_lbl, main_opts
    )

    if plot_subs:
        ttl = "{:s} OSPA2 Components (c = {:.1f}, p = {:d}, w={:d})".format(
            *param_vals
        )
        y_lbls = ["Localization", "Cardinality"]
        # the component plot uses its own options (previously main_opts was
        # passed here, silently ignoring sub_opts)
        figs["OSPA2_subs"] = self._plt_ospa_hist_subs(
            [self.ospa2_localization, self.ospa2_cardinality],
            time_units,
            time,
            ttl,
            y_lbls,
            sub_opts,
        )
    return figs

4164 

4165 

class _STMGLMBBase:
    """Mixin adding Student's t mixture (STM) handling to a GLMB filter.

    It must be listed before the GLMB class in the bases of a concrete filter
    so that its overrides take precedence in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _init_filt_states(self, distrib):
        """Create saved inner-filter states for each birth mixture component.

        Parameters
        ----------
        distrib
            Mixture with ``means``, ``weights``, ``scalings`` and a shared
            ``dof`` attribute.

        Returns
        -------
        tuple
            ``(filt_states, weights, states, covs)``; ``covs`` holds one
            entry per component (None when covariances are not saved).
        """
        num_comps = len(distrib.means)
        filt_states = [None] * num_comps
        states = [m.copy() for m in distrib.means]
        covs = [None] * num_comps

        weights = distrib.weights.copy()
        self._baseFilter.dof = distrib.dof
        for ii, scale in enumerate(distrib.scalings):
            self._baseFilter.scale = scale.copy()
            filt_states[ii] = self._baseFilter.save_filter_state()
            if self.save_covs:
                # read the covariance from the base filter that was just
                # configured with this component's scale (the working filter
                # is unrelated to this birth term); no copy is needed because
                # the Student's t filter returns a new object for cov
                covs[ii] = self._baseFilter.cov
        return filt_states, weights, states, covs

    def _gate_meas(self, meas, means, covs, **kwargs):
        """Return the measurements that fall inside the chi-squared gate.

        A measurement is kept if it gates with at least one of `means`, using
        the scale matrices gathered from every entry of the track table.
        `covs` is unused.
        """
        # TODO: check this implementation
        if len(meas) == 0:
            return []
        scalings = []
        for ent in self._track_tab:
            scalings.extend(ent.probDensity.scalings)
        valid = set()  # set gives O(1) membership checks in the inner loop
        for m, p in zip(means, scalings):
            meas_mat = self.filter.get_meas_mat(m, **kwargs)
            est = self.filter.get_est_meas(m, **kwargs)
            factor = (
                self.filter.meas_noise_dof
                * (self.filter.dof - 2)
                / (self.filter.dof * (self.filter.meas_noise_dof - 2))
            )
            P_zz = meas_mat @ p @ meas_mat.T + factor * self.filter.meas_noise
            inv_P = la.inv(P_zz)

            for ii, z in enumerate(meas):
                if ii in valid:
                    continue
                innov = z - est
                dist = innov.T @ inv_P @ innov
                if dist < self.inv_chi2_gate:
                    valid.add(ii)
        return [meas[ii] for ii in sorted(valid)]

4213 

4214 

4215# Note: need inherited classes in this order for proper MRO 

class STMGeneralizedLabeledMultiBernoulli(
    _STMGLMBBase, GeneralizedLabeledMultiBernoulli
):
    """Implementation of a STM-GLMB filter.

    Combines the Student's t mixture helpers of :class:`_STMGLMBBase` with
    :class:`GeneralizedLabeledMultiBernoulli`. The mixin is listed first so
    its method overrides are found first in the MRO.
    """

    def __init__(self, **kwargs):
        # every keyword is forwarded along the MRO, mixin first
        super(STMGeneralizedLabeledMultiBernoulli, self).__init__(**kwargs)

4223 

4224 

class _SMCGLMBBase:
    """Mixin adding Sequential Monte Carlo (particle) handling to a GLMB filter.

    The probabilities of detection and survival are evaluated per particle by
    user-supplied callbacks. The mixin must precede the GLMB class in the
    bases of the concrete filter so its overrides are found first in the MRO.

    Parameters
    ----------
    compute_prob_detection : callable, optional
        Called as ``compute_prob_detection(particles, *prob_det_args)``; must
        return per-particle detection probabilities (list or numpy array).
    compute_prob_survive : callable, optional
        Called as ``compute_prob_survive(particles, *prob_surv_args)``; must
        return per-particle survival probabilities (list or numpy array).
    **kwargs : dict
        Forwarded to the next class in the MRO.
    """

    def __init__(
        self, compute_prob_detection=None, compute_prob_survive=None, **kwargs
    ):
        self.compute_prob_detection = compute_prob_detection
        self.compute_prob_survive = compute_prob_survive

        # for wrappers for predict/correct function to handle extra args for private functions
        self._prob_surv_args = ()
        self._prob_det_args = ()

        super().__init__(**kwargs)

    def _init_filt_states(self, distrib):
        """Initialize a single saved particle filter state from `distrib`."""
        self._baseFilter.init_from_dist(distrib, make_copy=True)
        filt_states = [self._baseFilter.save_filter_state()]
        states = [distrib.mean]
        if self.save_covs:
            covs = [distrib.covariance]
        else:
            covs = []
        weights = [1]  # not needed so set to 1

        return filt_states, weights, states, covs

    def _calc_avg_prob_surv_death(self):
        """Particle-weighted average survival (and death) probability per track."""
        avg_prob_survive = np.zeros(len(self._track_tab))
        for tabidx, ent in enumerate(self._track_tab):
            # TODO: fix hack so not using "private" variable outside class
            p_dist = ent.filt_states[0]["_particleDist"]
            p_surv = self.compute_prob_survive(p_dist.particles, *self._prob_surv_args)
            avg_prob_survive[tabidx] = np.sum(np.array(p_dist.weights) * p_surv)
        avg_prob_death = 1 - avg_prob_survive

        return avg_prob_survive, avg_prob_death

    def _inner_predict(self, timestep, filt_state, state, filt_args):
        self.filter.load_filter_state(filt_state)
        if self.filter._particleDist.num_particles > 0:
            new_s = self.filter.predict(timestep, **filt_args)

            # manually update weights to account for prob survive
            # TODO: fix hack so not using "private" variable outside class
            ps = self.compute_prob_survive(
                self.filter._particleDist.particles, *self._prob_surv_args
            )
            new_weights = [
                w * ps[ii] for ii, (_, w) in enumerate(self.filter._particleDist)
            ]
            tot = sum(new_weights)
            if np.abs(tot) == np.inf:
                w_lst = [np.inf] * len(new_weights)
            else:
                w_lst = [w / tot for w in new_weights]
            self.filter._particleDist.update_weights(w_lst)

            new_f_state = self.filter.save_filter_state()
            if self.save_covs:
                new_cov = self.filter.cov.copy()
            else:
                new_cov = None
        else:
            # no particles: nothing to propagate, keep the previous state
            new_f_state = self.filter.save_filter_state()
            new_s = state
            new_cov = self.filter.cov
        return new_f_state, new_s, new_cov

    def predict(self, timestep, prob_surv_args=(), **kwargs):
        """Prediction step of the SMC-GLMB filter.

        This is a wrapper for the parent class to allow for extra parameters.
        See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict` for
        additional details.

        Parameters
        ----------
        timestep : float
            Current timestep.
        prob_surv_args : tuple, optional
            Additional arguments for the `compute_prob_survive` function.
            The default is ().
        **kwargs : dict, optional
            See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict`
        """
        self._prob_surv_args = prob_surv_args
        return super().predict(timestep, **kwargs)

    def _calc_avg_prob_det_mdet(self):
        """Particle-weighted average detection (and miss) probability per track."""
        avg_prob_detect = np.zeros(len(self._track_tab))
        for tabidx, ent in enumerate(self._track_tab):
            # TODO: fix hack so not using "private" variable outside class
            p_dist = ent.filt_states[0]["_particleDist"]
            p_detect = self.compute_prob_detection(
                p_dist.particles, *self._prob_det_args
            )
            avg_prob_detect[tabidx] = np.sum(np.array(p_dist.weights) * p_detect)
        avg_prob_miss_detect = 1 - avg_prob_detect

        return avg_prob_detect, avg_prob_miss_detect

    def _inner_correct(
        self, timestep, meas, filt_state, distrib_weight, state, filt_args
    ):
        self.filter.load_filter_state(filt_state)
        if self.filter._particleDist.num_particles > 0:
            cor_state, likely = self.filter.correct(timestep, meas, **filt_args)[0:2]

            # manually update the particle weights to account for probability of detection
            # TODO: fix hack so not using "private" variable outside class
            pd = self.compute_prob_detection(
                self.filter._particleDist.particles, *self._prob_det_args
            )
            # eps keeps the normalization well defined if every weight is 0
            pd_weight = (
                pd * np.array(self.filter._particleDist.weights) + np.finfo(float).eps
            )
            self.filter._particleDist.update_weights(
                (pd_weight / np.sum(pd_weight)).tolist()
            )

            # determine the partial cost, the remainder is calculated later from
            # the hypothesis
            new_w = np.sum(likely * pd_weight)  # same as cost in this case

            new_f_state = self.filter.save_filter_state()
            new_s = cor_state
            if self.save_covs:
                # copy for consistency with _inner_predict so the stored
                # covariance cannot alias filter internals
                new_c = self.filter.cov.copy()
            else:
                new_c = None
        else:
            new_f_state = self.filter.save_filter_state()
            new_s = state
            new_c = self.filter.cov
            new_w = 0
        return new_f_state, new_s, new_c, new_w

    def correct(self, timestep, meas, prob_det_args=(), **kwargs):
        """Correction step of the SMC-GLMB filter.

        This is a wrapper for the parent class to allow for extra parameters.
        See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct` for
        additional details.

        Parameters
        ----------
        timestep : float
            Current timestep.
        prob_det_args : tuple, optional
            Additional arguments for the `compute_prob_detection` function.
            The default is ().
        **kwargs : dict, optional
            See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct`
        """
        self._prob_det_args = prob_det_args
        return super().correct(timestep, meas, **kwargs)

    def extract_most_prob_states(self, thresh, **kwargs):
        """Extracts the most probable states.

        .. todo::
            Implement this function for the SMC-GLMB filter

        Warns
        -----
        UserWarning
            Always; the function is not implemented and returns None.
        """
        warnings.warn("Not implemented for this class")

4403 

4404 

4405# Note: need inherited classes in this order for proper MRO 

class SMCGeneralizedLabeledMultiBernoulli(
    _SMCGLMBBase, GeneralizedLabeledMultiBernoulli
):
    """Implementation of a Sequential Monte Carlo GLMB filter.

    This is based on :cite:`Vo2014_LabeledRandomFiniteSetsandtheBayesMultiTargetTrackingFilter`
    It does not account for agents spawned from existing tracks, only agents
    birthed from the given birth model.

    Attributes
    ----------
    compute_prob_detection : callable
        Function that takes a list of particles as the first argument and `*args`
        as the next. Returns the probability of detection for each particle
        (list or numpy array).
    compute_prob_survive : callable
        Function that takes a list of particles as the first argument and `*args`
        as the next. Returns the probability of survival for each particle
        (list or numpy array).
    """

    def __init__(self, **kwargs):
        # all options (including the probability callbacks) go through the MRO
        super(SMCGeneralizedLabeledMultiBernoulli, self).__init__(**kwargs)

4427 

4428 

class GSMGeneralizedLabeledMultiBernoulli(GeneralizedLabeledMultiBernoulli):
    """Implementation of a GSM-GLMB filter.

    The GSM-GLMB algorithm is the same regardless of the core filter
    (e.g. QKF GSM, SQKF GSM, UKF GSM), so this class works with any of the
    GSM inner filters from gncpy.filters.
    """

    def __init__(self, **kwargs):
        # nothing GSM specific to set up; defer entirely to the GLMB parent
        super(GSMGeneralizedLabeledMultiBernoulli, self).__init__(**kwargs)

4439 

4440 

class _IMMGLMBBase:
    """Mixin adapting a GLMB filter to use an IMM inner filter.

    It must be listed before the GLMB class in the bases of a concrete filter
    so that its overrides take precedence in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _init_filt_states(self, distrib):
        """Create one saved IMM state per mixture component of `distrib`.

        Every model of the IMM is started from the component's mean and
        covariance.

        Returns
        -------
        tuple
            ``(filt_states, weights, states, covs)``; ``covs`` is empty when
            covariances are not saved.
        """
        num_comps = len(distrib.means)
        filt_states = [None] * num_comps
        states = [mean.copy() for mean in distrib.means]
        covs = [c.copy() for c in distrib.covariances] if self.save_covs else []
        weights = distrib.weights.copy()
        for idx, (mean, cov) in enumerate(zip(distrib.means, distrib.covariances)):
            num_models = len(self._baseFilter.in_filt_list)
            self._baseFilter.initialize_states([mean] * num_models, [cov] * num_models)
            filt_states[idx] = self._baseFilter.save_filter_state()
        return filt_states, weights, states, covs

    def _inner_predict(self, timestep, filt_state, state, filt_args):
        self.filter.load_filter_state(filt_state)
        predicted = self.filter.predict(timestep, **filt_args)
        saved = self.filter.save_filter_state()
        pred_cov = self.filter.cov.copy() if self.save_covs else None
        return saved, predicted, pred_cov

    def _inner_correct(
        self, timestep, meas, filt_state, distrib_weight, state, filt_args
    ):
        self.filter.load_filter_state(filt_state)
        corrected, likelihood = self.filter.correct(timestep, meas, **filt_args)
        saved = self.filter.save_filter_state()
        cor_cov = self.filter.cov.copy() if self.save_covs else None
        # component weight scaled by the measurement likelihood
        return saved, corrected, cor_cov, distrib_weight * likelihood

4489 

4490 

class IMMGeneralizedLabeledMultiBernoulli(
    _IMMGLMBBase, GeneralizedLabeledMultiBernoulli
):
    """An implementation of the IMM-GLMB algorithm.

    The IMM specific overrides come from :class:`_IMMGLMBBase`, which is listed
    first so it takes precedence in the MRO.
    """

    def __init__(self, **kwargs):
        super(IMMGeneralizedLabeledMultiBernoulli, self).__init__(**kwargs)

4498 

4499 

4500class JointGeneralizedLabeledMultiBernoulli(GeneralizedLabeledMultiBernoulli): 

4501 """Implements a Joint Generalized Labeled Multi-Bernoulli Filter. 

4502 

4503 The Joint GLMB is designed to call predict and correct simultaneously, 

4504 as a single joint prediction-correction step. 

4505 Calling them asynchronously may cause poor performance. 

4506 

4507 Notes 

4508 ----- 

4509 This is based on :cite:`Vo2017_AnEfficientImplementationoftheGeneralizedLabeledMultiBernoulliFilter`. 

4510 It does not account for agents spawned from existing tracks, only agents 

4511 birthed from the given birth model. 

4512 """ 

4513 

def __init__(self, rng=None, **kwargs):
    """Initialize the joint GLMB filter.

    Parameters
    ----------
    rng : numpy random generator, optional
        Random generator handed to the Gibbs sampler during correction
        (presumably a :class:`numpy.random.Generator`). If None, a new
        unseeded default generator is created. The default is None.
    **kwargs : dict, optional
        Forwarded to the parent class constructor.
    """
    super().__init__(**kwargs)
    self._old_track_tab_len = len(self._track_tab)
    # predict() only generates birth terms while this flag is True and then
    # clears it; used to denote if the update function should be called or not
    self._update_has_been_called = True
    self._rng = np.random.default_rng() if rng is None else rng

4524 

def save_filter_state(self):
    """Saves filter variables so they can be restored later.

    Extends the parent's saved state with the previous track table length.
    Note that to pickle the resulting dictionary the :code:`dill` package
    may need to be used due to potential pickling of functions.
    """
    state = super().save_filter_state()
    state.update({"_old_track_tab_len": self._old_track_tab_len})
    return state

4536 

def load_filter_state(self, filt_state):
    """Initializes filter using saved filter state.

    Parameters
    ----------
    filt_state : dict
        Dictionary generated by :meth:`save_filter_state`.
    """
    super().load_filter_state(filt_state)
    self._old_track_tab_len = filt_state["_old_track_tab_len"]

4548 

def predict(self, timestep, filt_args=None):
    """Prediction step of the JGLMB filter.

    This predicts new hypotheses and propagates them to the next time step.
    Because this calls the inner filter's predict function, `filt_args` must
    contain any information needed by that function. Birth terms are only
    generated if the correction step ran since the previous prediction;
    otherwise a warning is issued and no births are added.

    Parameters
    ----------
    timestep: float
        Current timestep.
    filt_args : dict, optional
        Passed to the inner filter. The default of None is treated as an
        empty dict.

    Returns
    -------
    None.
    """
    # avoid a shared mutable default argument
    if filt_args is None:
        filt_args = {}

    if self._update_has_been_called:
        # Birth Track Table
        birth_tab = self._gen_birth_tab(timestep)[0]
    else:
        birth_tab = []
        warnings.warn("Joint GLMB should call predict and correct simultaneously")
    self._update_has_been_called = False

    # Survival Track Table
    surv_tab = self._gen_surv_tab(timestep, filt_args)

    # Prediction Track Table
    self._track_tab = birth_tab + surv_tab

4582 

def _unique_faster(self, keys):
    """Return the distinct values of a sorted ``(1, N)`` key array as ints.

    An element is kept when it differs from its successor, i.e. the last
    element of each run of equal values; a trailing NaN pad guarantees the
    final element is always kept.
    """
    padded = np.append(keys, np.nan)
    run_ends = np.not_equal(np.diff(padded, n=1, axis=0), 0)
    return keys[0][np.nonzero(run_ends)].astype(int)

4588 

def _calc_avg_prob_surv_death(self):
    """Per-entry survival and death probabilities for the joint cost matrix.

    The first ``len(self.birth_terms)`` entries take the second element of
    each birth term (presumably its birth probability); the remaining
    ``self._old_track_tab_len`` entries use ``self.prob_survive``.

    Returns
    -------
    tuple of numpy arrays
        ``(avg_surv, 1 - avg_surv)``.
    """
    num_birth = len(self.birth_terms)
    avg_surv = np.zeros(num_birth + self._old_track_tab_len)
    for idx, term in enumerate(self.birth_terms):
        avg_surv[idx] = term[1]
    avg_surv[num_birth:] = self.prob_survive
    return avg_surv, 1 - avg_surv

4599 

def _calc_avg_prob_det_mdet(self):
    """Constant detection and missed-detection probability for every track.

    Returns
    -------
    tuple of numpy arrays
        ``(avg_detect, 1 - avg_detect)``, one entry per track table entry.
    """
    avg_detect = np.ones(len(self._track_tab)) * self.prob_detection
    return avg_detect, 1 - avg_detect

4605 

def _gen_cor_tab(self, num_meas, meas, timestep, filt_args):
    """Build the corrected track table and the per-track measurement costs.

    The table has ``(num_meas + 1) * num_pred`` slots. The first `num_pred`
    slots hold missed-detection copies of the predicted tracks, with None
    appended to their association history. Slot
    ``num_pred * (m + 1) + t`` holds track `t` corrected with measurement `m`,
    which may be None.

    Returns
    -------
    tuple
        ``(up_tab, all_cost_m)`` where ``all_cost_m[t, m]`` is the cost
        returned for correcting track `t` with measurement `m`.
    """
    n_pred = len(self._track_tab)
    up_tab = [None] * ((num_meas + 1) * n_pred)

    for t_idx, track in enumerate(self._track_tab):
        missed = self._TabEntry().setup(track)
        missed.meas_assoc_hist.append(None)
        up_tab[t_idx] = missed

    # measurement updated tracks
    all_cost_m = np.zeros((n_pred, num_meas))
    for t_idx, ent in enumerate(self._track_tab):
        for m_idx, z in enumerate(meas):
            slot = n_pred * (m_idx + 1) + t_idx
            corrected, cost = self._correct_track_tab_entry(z, ent, timestep, filt_args)
            up_tab[slot] = corrected

            # update association history with current measurement index
            if corrected is not None:
                corrected.meas_assoc_hist.append(m_idx)
            all_cost_m[t_idx, m_idx] = cost
    return up_tab, all_cost_m

4628 

def _gen_cor_hyps(
    self,
    num_meas,
    avg_prob_detect,
    avg_prob_miss_detect,
    avg_prob_surv,
    avg_prob_death,
    all_cost_m,
):
    """Generate the corrected hypotheses with a Gibbs sampler (joint GLMB).

    For every prior hypothesis a cost matrix is built over its tracks and the
    birth terms. Its columns are [death | survive & missed | survive &
    detected with each gated measurement]. The Gibbs sampler draws
    assignments from ``-log`` of that matrix, and each sample becomes a new
    hypothesis. The number of samples per hypothesis is proportional to
    ``sqrt`` of its association probability.

    Returns
    -------
    list
        New hypotheses whose ``assoc_prob`` values are normalized to sum to 1.
    """
    # Define clutter
    clutter = self.clutter_rate * self.clutter_den

    # loop invariants shared by every hypothesis
    num_pred = len(self._track_tab)
    num_births = len(self.birth_terms)
    birth_inds = np.arange(0, num_births)

    # Joint cost matrix: [death | survive & missed | survive & detected]
    joint_cost = np.concatenate(
        [
            np.diag(avg_prob_death.ravel()),
            np.diag(avg_prob_surv.ravel() * avg_prob_miss_detect.ravel()),
        ],
        axis=1,
    )
    other_jc_terms = (
        np.tile((avg_prob_surv * avg_prob_detect).reshape((-1, 1)), (1, num_meas))
        * all_cost_m
        / clutter
    )
    joint_cost = np.append(joint_cost, other_jc_terms, axis=1)

    # Gated measurement index matrix
    gate_meas_indices = np.zeros((num_pred, num_meas))
    for ii, ent in enumerate(self._track_tab):
        for jj, g_ind in enumerate(ent.gatemeas):
            gate_meas_indices[ii, jj] = g_ind
    # NOTE(review): always True because the matrix is zero-initialized; the
    # gating of correct() is not implemented yet
    gate_meas_indc = gate_meas_indices >= 0

    # normalizer for the per-hypothesis sample budget
    ss_w = sum(np.sqrt(p_hyp.assoc_prob) for p_hyp in self._hypotheses)
    # hypothesis independent part of the new log association probability
    log_clutter_term = -self.clutter_rate + num_meas * np.log(clutter)

    up_hyp = []
    for p_hyp in self._hypotheses:
        num_tracks = num_births + len(p_hyp.track_set)

        # Hypothesis index masking: birth rows plus this hypothesis' tracks
        tindices = np.concatenate(
            (birth_inds, num_births + np.array(p_hyp.track_set))
        ).astype(int)
        lselmask = np.zeros((num_pred, num_meas), dtype="bool")
        lselmask[tindices] = gate_meas_indc[tindices]

        keys = np.array([np.sort(gate_meas_indices[lselmask])])
        mindices = self._unique_faster(keys)

        comb_tind_cpred = np.append(
            np.append(tindices, num_pred + tindices), [2 * num_pred + mindices]
        )
        cost_m = joint_cost[tindices][:, comb_tind_cpred]
        with warnings.catch_warnings():
            # log(0) -> inf is expected for impossible assignments
            warnings.simplefilter("ignore", RuntimeWarning)
            neg_log = -np.log(cost_m)

        m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
        m = int(m.item()) + 1

        # Gibbs Sampler
        assigns, costs = gibbs(neg_log, m, rng=self._rng)

        # Map sampled column indices back to assignments:
        # death -> -inf, missed detection -> -1, detection -> measurement index
        assigns[assigns < num_tracks] = -np.inf
        missed = (assigns >= num_tracks) & (assigns < 2 * num_tracks)
        assigns[missed] = -1
        assigns[assigns >= 2 * num_tracks] -= 2 * num_tracks
        detected = assigns >= 0
        if np.any(detected):
            assigns[detected] = mindices[assigns[detected].astype(int)]

        # Assign updated hypotheses from gibbs sampler
        for c, cst in enumerate(costs.flatten()):
            update_hyp_cmp_temp = assigns[c,]
            # index into the table built by _gen_cor_tab; deaths stay negative
            update_hyp_cmp_idx = num_pred * (update_hyp_cmp_temp + 1) + np.append(
                np.array([birth_inds]),
                num_births + np.array([p_hyp.track_set]),
            )
            new_hyp = self._HypothesisHelper()
            new_hyp.assoc_prob = log_clutter_term + np.log(p_hyp.assoc_prob) - cst
            new_hyp.track_set = update_hyp_cmp_idx[update_hyp_cmp_idx >= 0].astype(
                int
            )
            up_hyp.append(new_hyp)

    # normalize the log association probabilities
    lse = log_sum_exp([x.assoc_prob for x in up_hyp])
    for hyp in up_hyp:
        hyp.assoc_prob = np.exp(hyp.assoc_prob - lse)
    return up_hyp

4754 

4755 def correct(self, timestep, meas, filt_args={}): 

4756 """Correction step of the JGLMB filter. 

4757 

4758 This corrects the hypotheses based on the measurements and gates the 

4759 measurements according to the class settings. It also updates the 

4760 cardinality distribution. Because this calls the inner filter's correct 

4761 function, the keyword arguments must contain any information needed by 

4762 that function. 

4763 

4764 Parameters 

4765 ---------- 

4766 timestep: float 

4767 Current timestep. 

4768 meas_in : list 

4769 List of Nm x 1 numpy arrays each representing a measuremnt. 

4770 filt_args : dict, optional 

4771 keyword arguments to pass to the inner filters correct function. 

4772 The default is {}. 

4773 

4774 Todo 

4775 ---- 

4776 Fix the measurement gating 

4777 

4778 Returns 

4779 ------- 

4780 None 

4781 """ 

4782 # gating by tracks 

4783 if self.gating_on: 

4784 RuntimeError("Gating not implemented yet. PLEASE TURN OFF GATING") 

4785 # for ent in self._track_tab: 

4786 # ent.gatemeas = self._gate_meas(meas, ent.probDensity.means, 

4787 # ent.probDensity.covariances) 

4788 else: 

4789 for ent in self._track_tab: 

4790 ent.gatemeas = np.arange(0, len(meas)) 

4791 # Pre-calculation of average survival/death probabilities 

4792 avg_prob_surv, avg_prob_death = self._calc_avg_prob_surv_death() 

4793 

4794 # Pre-calculation of average detection/missed probabilities 

4795 avg_prob_detect, avg_prob_miss_detect = self._calc_avg_prob_det_mdet() 

4796 

4797 if self.save_measurements: 

4798 self._meas_tab.append(deepcopy(meas)) 

4799 num_meas = len(meas) 

4800 

4801 # missed detection tracks 

4802 [up_tab, all_cost_m] = self._gen_cor_tab(num_meas, meas, timestep, filt_args) 

4803 

4804 up_hyp = self._gen_cor_hyps( 

4805 num_meas, 

4806 avg_prob_detect, 

4807 avg_prob_miss_detect, 

4808 avg_prob_surv, 

4809 avg_prob_death, 

4810 all_cost_m, 

4811 ) 

4812 

4813 self._track_tab = up_tab 

4814 self._hypotheses = up_hyp 

4815 self._card_dist = self._calc_card_dist(self._hypotheses) 

4816 self._clean_predictions() 

4817 self._clean_updates() 

4818 self._update_has_been_called = True 

4819 self._old_track_tab_len = len(self._track_tab) 

4820 

4821 

4822class STMJointGeneralizedLabeledMultiBernoulli( 

4823 _STMGLMBBase, JointGeneralizedLabeledMultiBernoulli 

4824): 

4825 """Implementation of a STM-JGLMB class.""" 

4826 

4827 def __init__(self, **kwargs): 

4828 super().__init__(**kwargs) 

4829 

4830 

4831class SMCJointGeneralizedLabeledMultiBernoulli( 

4832 _SMCGLMBBase, JointGeneralizedLabeledMultiBernoulli 

4833): 

4834 """Implementation of a SMC-JGLMB filter.""" 

4835 

4836 def __init__(self, **kwargs): 

4837 super().__init__(**kwargs) 

4838 

4839 

4840class GSMJointGeneralizedLabeledMultiBernoulli(JointGeneralizedLabeledMultiBernoulli): 

4841 """Implementation of a GSM-JGLMB filter. 

4842 

4843 The implementation of the GSM-JGLMB fitler does not change for different 

4844 core filters (i.e. QKF GSM, SQKF GSM, UKF GSM, etc.) so this class can use 

4845 any of the GSM inner filters from gncpy.filters 

4846 """ 

4847 

4848 def __init__(self, **kwargs): 

4849 super().__init__(**kwargs) 

4850 

4851 

4852class IMMJointGeneralizedLabeledMultiBernoulli( 

4853 _IMMGLMBBase, JointGeneralizedLabeledMultiBernoulli 

4854): 

4855 """Implementation of an IMM-JGLMB filter.""" 

4856 

4857 def __init__(self, **kwargs): 

4858 super().__init__(**kwargs) 

4859 

4860 

4861class MSJointGeneralizedLabeledMultiBernoulli(JointGeneralizedLabeledMultiBernoulli): 

4862 """Implementation of the Multiple Sensor JGLMB Filter""" 

4863 

4864 def __init__(self, **kwargs): 

4865 super().__init__(**kwargs) 

4866 

4867 def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args): 

4868 num_pred = len(self._track_tab) 

4869 num_sens = len(meas) 

4870 # if len(meas) != len(self.filter.meas_model_list): 

4871 # raise ValueError("measurement lists must match number of measurement models") 

4872 up_tab = [None] * (num_meas + 1) * num_pred 

4873 

4874 for ii, track in enumerate(self._track_tab): 

4875 up_tab[ii] = self._TabEntry().setup(track) 

4876 up_tab[ii].meas_assoc_hist.append(None) 

4877 # measurement updated tracks 

4878 all_cost_m = np.zeros((num_pred, num_meas)) 

4879 for ii, ent in enumerate(self._track_tab): 

4880 for emm, z in enumerate(meas): 

4881 s_to_ii = num_pred * emm + ii + num_pred 

4882 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry( 

4883 z, ent, timestep, filt_args 

4884 ) 

4885 if up_tab[s_to_ii] is not None: 

4886 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm]) 

4887 all_cost_m[ii, emm] = cost 

4888 

4889 return up_tab, all_cost_m 

4890 

4891 def _gen_cor_hyps( 

4892 self, 

4893 num_meas, 

4894 avg_prob_detect, 

4895 avg_prob_miss_detect, 

4896 avg_prob_surv, 

4897 avg_prob_death, 

4898 all_cost_m, 

4899 meas_combs, 

4900 cor_tab=None, 

4901 ): 

4902 # Define clutter 

4903 clutter = self.clutter_rate * self.clutter_den 

4904 

4905 # Joint Cost Matrix 

4906 joint_cost = np.concatenate( 

4907 [ 

4908 np.diag(avg_prob_death.ravel()), 

4909 np.diag(avg_prob_surv.ravel() * avg_prob_miss_detect.ravel()), 

4910 ], 

4911 axis=1, 

4912 ) 

4913 

4914 other_jc_terms = ( 

4915 np.tile((avg_prob_surv * avg_prob_detect).reshape((-1, 1)), (1, num_meas)) 

4916 * all_cost_m 

4917 / (clutter) 

4918 ) 

4919 

4920 # Full joint cost matrix for sensor s 

4921 joint_cost = np.append(joint_cost, other_jc_terms, axis=1) 

4922 

4923 gate_meas_indices = np.zeros((len(self._track_tab), num_meas)) 

4924 for ii in range(0, len(self._track_tab)): 

4925 for jj in range(0, len(self._track_tab[ii].gatemeas)): 

4926 gate_meas_indices[ii][jj] = self._track_tab[ii].gatemeas[jj] 

4927 gate_meas_indc = gate_meas_indices >= 0 

4928 

4929 # Component updates 

4930 ss_w = 0 

4931 up_hyp = [] 

4932 for p_hyp in self._hypotheses: 

4933 ss_w += np.sqrt(p_hyp.assoc_prob) 

4934 for p_hyp in self._hypotheses: 

4935 for ind_lst in meas_combs: 

4936 cpreds = len(self._track_tab) # num_pred 

4937 num_births = len(self.birth_terms) # num_birth_terms 

4938 num_exists = len(p_hyp.track_set) # num_existing_tracks 

4939 num_tracks = num_births + num_exists # num_possible_tracks 

4940 

4941 # Hypothesis index masking 

4942 # all birth terms and tracks included in p_hyp.track_set 

4943 tindices = np.concatenate( 

4944 (np.arange(0, num_births), num_births + np.array(p_hyp.track_set)) 

4945 ).astype(int) 

4946 # lselmask = np.zeros((len(self._track_tab), len(ind_lst)), dtype="bool") 

4947 lselmask = np.zeros((len(self._track_tab), num_meas), dtype="bool") 

4948 # lselmask = np. 

4949 for ii, index in enumerate(ind_lst): 

4950 lselmask[tindices, index] = gate_meas_indc[tindices, index] 

4951 

4952 # verify sort works for 3d arrays similar to 2d arrays, may have to do this list-wise 

4953 keys = np.array([np.sort(gate_meas_indices[lselmask])]) 

4954 # keys = np.array([np.sort(gate_meas_indices[:, ind_lst][lselmask])]) 

4955 # meas_indices 

4956 mindices = self._unique_faster(keys) 

4957 

4958 comb_tind_cpred = np.append( 

4959 np.append(tindices, cpreds + tindices), [2 * cpreds + mindices] 

4960 ) 

4961 # comb_tind_cpred = np.append( 

4962 # np.append(tindices, cpreds + tindices), relevant_meas_inds 

4963 # ).astype(int) 

4964 

4965 cost_m = joint_cost[tindices][:, comb_tind_cpred] 

4966 

4967 with warnings.catch_warnings(): 

4968 warnings.simplefilter("ignore", RuntimeWarning) 

4969 neg_log = -np.log(cost_m) 

4970 

4971 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w) 

4972 m = int(m.item()) + 1 

4973 

4974 # Gibbs Sampler 

4975 [assigns, costs] = gibbs(neg_log, m, rng=self._rng) 

4976 

4977 # Process unique assignments from gibbs sampler 

4978 assigns[assigns < num_tracks] = -np.inf 

4979 for ii in range(np.shape(assigns)[0]): 

4980 if len(np.shape(assigns)) < 2: 

4981 if assigns[ii] >= num_tracks and assigns[ii] < 2 * num_tracks: 

4982 assigns[ii] = -1 

4983 else: 

4984 for jj in range(np.shape(assigns)[1]): 

4985 if ( 

4986 assigns[ii][jj] >= num_tracks 

4987 and assigns[ii][jj] < 2 * num_tracks 

4988 ): 

4989 assigns[ii][jj] = -1 

4990 assigns[assigns >= 2 * num_tracks] -= 2 * num_tracks 

4991 if assigns[assigns >= 0].size != 0: 

4992 assigns[assigns >= 0] = mindices[ 

4993 assigns[assigns >= 0].astype(int)[ 

4994 assigns[assigns >= 0].astype(int) >= 0 

4995 ] 

4996 ] 

4997 # Assign updated hypotheses from gibbs sampler 

4998 for c, cst in enumerate(costs.flatten()): 

4999 update_hyp_cmp_temp = assigns[c,] 

5000 update_hyp_cmp_idx = cpreds * (update_hyp_cmp_temp + 1) + np.append( 

5001 np.array([np.arange(0, num_births)]), 

5002 num_births + np.array([p_hyp.track_set]), 

5003 ) 

5004 new_hyp = self._HypothesisHelper() 

5005 new_hyp.assoc_prob = ( 

5006 -self.clutter_rate 

5007 + num_meas * np.log(clutter) 

5008 + np.log(p_hyp.assoc_prob) 

5009 - cst 

5010 ) 

5011 new_hyp.track_set = update_hyp_cmp_idx[ 

5012 update_hyp_cmp_idx >= 0 

5013 ].astype(int) 

5014 up_hyp.append(new_hyp) 

5015 lse = log_sum_exp([x.assoc_prob for x in up_hyp]) 

5016 

5017 for ii in range(0, len(up_hyp)): 

5018 up_hyp[ii].assoc_prob = np.exp(up_hyp[ii].assoc_prob - lse) 

5019 return up_hyp 

5020 

5021 def correct(self, timestep, meas, filt_args={}): 

5022 """Correction step of the MS-JGLMB filter. 

5023 

5024 This corrects the hypotheses based on the measurements and gates the 

5025 measurements according to the class settings. It also updates the 

5026 cardinality distribution. Because this calls the inner filter's correct 

5027 function, the keyword arguments must contain any information needed by 

5028 that function. 

5029 

5030 Parameters 

5031 ---------- 

5032 timestep: float 

5033 Current timestep. 

5034 meas : list 

5035 List of lists representing sensor measurements containing Nm x 1 numpy arrays each representing a single measurement. 

5036 filt_args : dict, optional 

5037 keyword arguments to pass to the inner filters correct function. 

5038 The default is {}. 

5039 

5040 Todo 

5041 ---- 

5042 Fix the measurement gating 

5043 

5044 Returns 

5045 ------- 

5046 None 

5047 """ 

5048 all_combs = list(itertools.product(*meas)) 

5049 num_meas_per_sens = [len(x) for x in meas] 

5050 num_meas = len(all_combs) 

5051 num_sens = len(meas) 

5052 mnmps = min(num_meas_per_sens) 

5053 

5054 comb_inds = list(itertools.product(*list(np.arange(0, len(x)) for x in meas))) 

5055 comb_inds = [list(ele) for ele in comb_inds] 

5056 

5057 all_meas_combs = list(itertools.combinations(comb_inds, mnmps)) 

5058 all_meas_combs = [list(ele) for ele in all_meas_combs] 

5059 

5060 poss_meas_combs = [] 

5061 

5062 for ii in range(0, len(all_meas_combs)): 

5063 break_flag = False 

5064 cur_comb = [] 

5065 for jj, lst1 in enumerate(all_meas_combs[ii]): 

5066 for kk, lst2 in enumerate(all_meas_combs[ii]): 

5067 if jj == kk: 

5068 continue 

5069 else: 

5070 out = (np.array(lst1) == np.array(lst2)).tolist() 

5071 if any(out): 

5072 break_flag = True 

5073 break 

5074 if break_flag: 

5075 break 

5076 if break_flag: 

5077 pass 

5078 else: 

5079 for lst1 in all_meas_combs[ii]: 

5080 for ii, lst2 in enumerate(comb_inds): 

5081 if lst1 == lst2: 

5082 cur_comb.append(ii) 

5083 poss_meas_combs.append(cur_comb) 

5084 

5085 # gating by tracks 

5086 if self.gating_on: 

5087 RuntimeError("Gating not implemented yet. PLEASE TURN OFF GATING") 

5088 # for ent in self._track_tab: 

5089 # ent.gatemeas = self._gate_meas(meas, ent.probDensity.means, 

5090 # ent.probDensity.covariances) 

5091 else: 

5092 for ent in self._track_tab: 

5093 ent.gatemeas = np.arange(0, len(all_combs)) 

5094 # ent.gatemeas = np.arange(0, len(poss_meas_combs)) 

5095 # Pre-calculation of average survival/death probabilities 

5096 avg_prob_surv, avg_prob_death = self._calc_avg_prob_surv_death() 

5097 

5098 # Pre-calculation of average detection/missed probabilities 

5099 avg_prob_detect, avg_prob_miss_detect = self._calc_avg_prob_det_mdet() 

5100 

5101 if self.save_measurements: 

5102 self._meas_tab.append(deepcopy(meas)) 

5103 # all_combs = list(itertools.product(*meas)) 

5104 

5105 # missed detection tracks 

5106 [up_tab, all_cost_m] = self._gen_cor_tab( 

5107 num_meas, all_combs, timestep, comb_inds, filt_args 

5108 ) 

5109 

5110 up_hyp = self._gen_cor_hyps( 

5111 num_meas, 

5112 avg_prob_detect, 

5113 avg_prob_miss_detect, 

5114 avg_prob_surv, 

5115 avg_prob_death, 

5116 all_cost_m, 

5117 poss_meas_combs, 

5118 cor_tab=up_tab, 

5119 ) 

5120 

5121 self._track_tab = up_tab 

5122 self._hypotheses = up_hyp 

5123 self._card_dist = self._calc_card_dist(self._hypotheses) 

5124 self._clean_predictions() 

5125 self._clean_updates() 

5126 self._update_has_been_called = True 

5127 self._old_track_tab_len = len(self._track_tab) 

5128 

5129 

5130class MSIMMJointGeneralizedLabeledMultiBernoulli( 

5131 _IMMGLMBBase, MSJointGeneralizedLabeledMultiBernoulli 

5132): 

5133 """An implementation of the Multi-Sensor IMM-JGLMB algorithm.""" 

5134 

5135 def __init__(self, **kwargs): 

5136 super().__init__(**kwargs) 

5137 

5138 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args): 

5139 newTab = self._TabEntry().setup(tab) 

5140 new_f_states = [None] * len(newTab.filt_states) 

5141 new_s_hist = [None] * len(newTab.filt_states) 

5142 new_c_hist = [None] * len(newTab.filt_states) 

5143 new_w = [None] * len(newTab.filt_states) 

5144 depleted = False 

5145 for ii, (f_state, state, w) in enumerate( 

5146 zip( 

5147 newTab.filt_states, 

5148 newTab.state_hist[-1], 

5149 newTab.distrib_weights_hist[-1], 

5150 ) 

5151 ): 

5152 try: 

5153 ( 

5154 new_f_states[ii], 

5155 new_s_hist[ii], 

5156 new_c_hist[ii], 

5157 new_w[ii], 

5158 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args) 

5159 except ( 

5160 gerr.ParticleDepletionError, 

5161 gerr.ParticleEstimationDomainError, 

5162 gerr.ExtremeMeasurementNoiseError, 

5163 ): 

5164 return None, 0 

5165 newTab.filt_states = new_f_states 

5166 newTab.state_hist[-1] = new_s_hist 

5167 newTab.cov_hist[-1] = new_c_hist 

5168 new_w = [w + np.finfo(float).eps for w in new_w] 

5169 if not depleted: 

5170 cost = np.sum(new_w).item() 

5171 newTab.distrib_weights_hist[-1] = [w / cost for w in new_w] 

5172 else: 

5173 cost = 0 

5174 return newTab, cost 

5175 

5176 

5177class PoissonMultiBernoulliMixture(RandomFiniteSetBase): 

5178 class _TabEntry: 

5179 def __init__(self): 

5180 self.label = () # time step born, index of birth model born from 

5181 self.distrib_weights_hist = [] # list of weights of the probDensity 

5182 self.exist_prob = None # existence probability of the probDensity 

5183 self.filt_states = [] # list of dictionaries from filters save function 

5184 self.meas_assoc_hist = ( 

5185 [] 

5186 ) # list indices into measurement list per time step 

5187 

5188 self.state_hist = [] # list of lists of numpy arrays for each timestep 

5189 self.cov_hist = ( 

5190 [] 

5191 ) # list of lists of numpy arrays for each timestep (or None) 

5192 

5193 """ linear index corresponding to timestep, manually updated. Used 

5194 to index things since timestep in label can have decimals.""" 

5195 self.time_index = None 

5196 

5197 def setup(self, tab): 

5198 """Use to avoid expensive deepcopy.""" 

5199 self.label = tab.label 

5200 self.distrib_weights_hist = tab.distrib_weights_hist.copy() 

5201 self.exist_prob = tab.exist_prob 

5202 self.filt_states = deepcopy(tab.filt_states) 

5203 self.meas_assoc_hist = tab.meas_assoc_hist.copy() 

5204 

5205 self.state_hist = [None] * len(tab.state_hist) 

5206 self.state_hist = [s.copy() for s in [s_lst for s_lst in tab.state_hist]] 

5207 self.cov_hist = [ 

5208 c.copy() if c else [] for c in [c_lst for c_lst in tab.cov_hist] 

5209 ] 

5210 

5211 self.time_index = tab.time_index 

5212 

5213 return self 

5214 

5215 class _HypothesisHelper: 

5216 def __init__(self): 

5217 self.assoc_prob = 0 

5218 self.track_set = [] # indices in lookup table 

5219 

5220 @property 

5221 def num_tracks(self): 

5222 return len(self.track_set) 

5223 

5224 class _ExtractHistHelper: 

5225 def __init__(self): 

5226 self.label = () 

5227 self.meas_ind_hist = [] 

5228 self.b_time_index = None 

5229 self.states = [] 

5230 self.covs = [] 

5231 

5232 def __init__( 

5233 self, 

5234 req_upd=None, 

5235 gating_on=False, 

5236 prune_threshold=10**-15, 

5237 exist_threshold=10**-15, 

5238 max_hyps=3000, 

5239 decimal_places=2, 

5240 save_measurements=False, 

5241 **kwargs, 

5242 ): 

5243 self.req_upd = req_upd 

5244 self.gating_on = gating_on 

5245 self.prune_threshold = prune_threshold 

5246 self.exist_threshold = exist_threshold 

5247 self.max_hyps = max_hyps 

5248 self.decimal_places = decimal_places 

5249 self.save_measurements = save_measurements 

5250 

5251 self._track_tab = [] # list of all possible tracks 

5252 self._extractable_hists = [] 

5253 

5254 self._filter = None 

5255 self._baseFilter = None 

5256 

5257 hyp0 = self._HypothesisHelper() 

5258 hyp0.assoc_prob = 1 

5259 hyp0.track_set = [] 

5260 self._hypotheses = [hyp0] # list of _HypothesisHelper objects 

5261 

5262 self._card_dist = [] # probability of having index # as cardinality 

5263 

5264 """ linear index corresponding to timestep, manually updated. Used 

5265 to index things since timestep in label can have decimals. Must 

5266 be updated once per time step.""" 

5267 self._time_index_cntr = 0 

5268 

5269 self.ospa2 = None 

5270 self.ospa2_localization = None 

5271 self.ospa2_cardinality = None 

5272 self._ospa2_params = {} 

5273 

5274 super().__init__(**kwargs) 

5275 self._states = [[]] 

5276 

5277 def save_filter_state(self): 

5278 """Saves filter variables so they can be restored later. 

5279 

5280 Note that to pickle the resulting dictionary the :code:`dill` package 

5281 may need to be used due to potential pickling of functions. 

5282 """ 

5283 filt_state = super().save_filter_state() 

5284 

5285 filt_state["req_upd"] = self.req_upd 

5286 filt_state["gating_on"] = self.gating_on 

5287 filt_state["prune_threshold"] = self.prune_threshold 

5288 filt_state["exist_threshold"] = self.exist_threshold 

5289 filt_state["max_hyps"] = self.max_hyps 

5290 filt_state["decimal_places"] = self.decimal_places 

5291 filt_state["save_measurements"] = self.save_measurements 

5292 

5293 filt_state["_track_tab"] = self._track_tab 

5294 filt_state["_extractable_hists"] = self._extractable_hists 

5295 

5296 if self._baseFilter is not None: 

5297 filt_state["_baseFilter"] = ( 

5298 type(self._baseFilter), 

5299 self._baseFilter.save_filter_state(), 

5300 ) 

5301 else: 

5302 filt_state["_baseFilter"] = (None, self._baseFilter) 

5303 filt_state["_hypotheses"] = self._hypotheses 

5304 filt_state["_card_dist"] = self._card_dist 

5305 filt_state["_time_index_cntr"] = self._time_index_cntr 

5306 

5307 filt_state["ospa2"] = self.ospa2 

5308 filt_state["ospa2_localization"] = self.ospa2_localization 

5309 filt_state["ospa2_cardinality"] = self.ospa2_cardinality 

5310 filt_state["_ospa2_params"] = self._ospa_params 

5311 

5312 return filt_state 

5313 

5314 def load_filter_state(self, filt_state): 

5315 """Initializes filter using saved filter state. 

5316 

5317 Attributes 

5318 ---------- 

5319 filt_state : dict 

5320 Dictionary generated by :meth:`save_filter_state`. 

5321 """ 

5322 super().load_filter_state(filt_state) 

5323 

5324 self.req_upd = filt_state["req_upd"] 

5325 self.gating_on = filt_state["gating_on"] 

5326 self.prune_threshold = filt_state["prune_threshold"] 

5327 self.exist_threshold = filt_state["exist_threshold"] 

5328 self.max_hyps = filt_state["max_hyps"] 

5329 self.decimal_places = filt_state["decimal_places"] 

5330 self.save_measurements = filt_state["save_measurements"] 

5331 

5332 self._track_tab = filt_state["_track_tab"] 

5333 self._extractable_hists = filt_state["_extractable_hists"] 

5334 

5335 cls_type = filt_state["_baseFilter"][0] 

5336 if cls_type is not None: 

5337 self._baseFilter = cls_type() 

5338 self._baseFilter.load_filter_state(filt_state["_baseFilter"][1]) 

5339 else: 

5340 self._baseFilter = None 

5341 self._hypotheses = filt_state["_hypotheses"] 

5342 self._card_dist = filt_state["_card_dist"] 

5343 self._time_index_cntr = filt_state["_time_index_cntr"] 

5344 

5345 self.ospa2 = filt_state["ospa2"] 

5346 self.ospa2_localization = filt_state["ospa2_localization"] 

5347 self.ospa2_cardinality = filt_state["ospa2_cardinality"] 

5348 self._ospa2_params = filt_state["_ospa2_params"] 

5349 

5350 @property 

5351 def states(self): 

5352 """Read only list of extracted states. 

5353 

5354 This is a list with 1 element per timestep, and each element is a list 

5355 of the best states extracted at that timestep. The order of each 

5356 element corresponds to the label order. 

5357 """ 

5358 return self._states 

5359 

5360 @property 

5361 def covariances(self): 

5362 """Read only list of extracted covariances. 

5363 

5364 This is a list with 1 element per timestep, and each element is a list 

5365 of the best covariances extracted at that timestep. The order of each 

5366 element corresponds to the state order. 

5367 

5368 Raises 

5369 ------ 

5370 RuntimeWarning 

5371 If the class is not saving the covariances, and returns an empty list. 

5372 """ 

5373 if not self.save_covs: 

5374 raise RuntimeWarning("Not saving covariances") 

5375 return [] 

5376 return self._covs 

5377 

5378 @property 

5379 def filter(self): 

5380 """Inner filter handling dynamics, must be a gncpy.filters.BayesFilter.""" 

5381 return self._filter 

5382 

5383 @filter.setter 

5384 def filter(self, val): 

5385 self._baseFilter = deepcopy(val) 

5386 self._filter = val 

5387 

5388 @property 

5389 def cardinality(self): 

5390 """Cardinality estimate.""" 

5391 return np.argmax(self._card_dist) 

5392 

5393 def _init_filt_states(self, distrib): 

5394 filt_states = [None] * len(distrib.means) 

5395 states = [m.copy() for m in distrib.means] 

5396 if self.save_covs: 

5397 covs = [c.copy() for c in distrib.covariances] 

5398 else: 

5399 covs = [] 

5400 weights = distrib.weights.copy() 

5401 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)): 

5402 self._baseFilter.cov = cov.copy() 

5403 if isinstance(self._baseFilter, gfilts.UnscentedKalmanFilter) or isinstance( 

5404 self._baseFilter, gfilts.UKFGaussianScaleMixtureFilter 

5405 ): 

5406 self._baseFilter.init_sigma_points(m) 

5407 filt_states[ii] = self._baseFilter.save_filter_state() 

5408 return filt_states, weights, states, covs 

5409 

5410 def _inner_predict(self, timestep, filt_state, state, filt_args): 

5411 self.filter.load_filter_state(filt_state) 

5412 new_s = self.filter.predict(timestep, state, **filt_args) 

5413 new_f_state = self.filter.save_filter_state() 

5414 if self.save_covs: 

5415 new_cov = self.filter.cov.copy() 

5416 else: 

5417 new_cov = None 

5418 return new_f_state, new_s, new_cov 

5419 

5420 def _predict_det_tab_entry(self, tab, timestep, filt_args): 

5421 new_tab = self._TabEntry().setup(tab) 

5422 new_f_states = [None] * len(new_tab.filt_states) 

5423 new_s_hist = [None] * len(new_tab.filt_states) 

5424 new_c_hist = [None] * len(new_tab.filt_states) 

5425 for ii, (f_state, state) in enumerate( 

5426 zip(new_tab.filt_states, new_tab.state_hist[-1]) 

5427 ): 

5428 (new_f_states[ii], new_s_hist[ii], new_c_hist[ii]) = self._inner_predict( 

5429 timestep, f_state, state, filt_args 

5430 ) 

5431 new_tab.filt_states = new_f_states 

5432 new_tab.state_hist.append(new_s_hist) 

5433 new_tab.cov_hist.append(new_c_hist) 

5434 new_tab.distrib_weights_hist.append(new_tab.distrib_weights_hist[-1].copy()) 

5435 new_tab.exist_prob = new_tab.exist_prob * self.prob_survive 

5436 return new_tab 

5437 

5438 def _gen_pred_tab(self, timestep, filt_args): 

5439 pred_tab = [] 

5440 

5441 for ii, ent in enumerate(self._track_tab): 

5442 entry = self._predict_det_tab_entry(ent, timestep, filt_args) 

5443 pred_tab.append(entry) 

5444 

5445 return pred_tab 

5446 

5447 def predict(self, timestep, filt_args={}): 

5448 # all objects are propagated forward regardless of previous associations. 

5449 self._track_tab = self._gen_pred_tab(timestep, filt_args) 

5450 

5451 def _calc_avg_prob_det_mdet(self, cor_tab): 

5452 avg_prob_detect = self.prob_detection * np.ones(len(cor_tab)) 

5453 avg_prob_miss_detect = 1 - avg_prob_detect 

5454 

5455 return avg_prob_detect, avg_prob_miss_detect 

5456 

5457 def _inner_correct( 

5458 self, timestep, meas, filt_state, distrib_weight, state, filt_args 

5459 ): 

5460 self.filter.load_filter_state(filt_state) 

5461 cor_state, likely = self.filter.correct(timestep, meas, state, **filt_args) 

5462 new_f_state = self.filter.save_filter_state() 

5463 new_s = cor_state 

5464 if self.save_covs: 

5465 new_c = self.filter.cov.copy() 

5466 else: 

5467 new_c = None 

5468 new_w = distrib_weight * likely 

5469 

5470 return new_f_state, new_s, new_c, new_w 

5471 

5472 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args): 

5473 new_tab = self._TabEntry().setup(tab) 

5474 new_f_states = [None] * len(new_tab.filt_states) 

5475 new_s_hist = [None] * len(new_tab.filt_states) 

5476 new_c_hist = [None] * len(new_tab.filt_states) 

5477 new_w = [None] * len(new_tab.filt_states) 

5478 depleted = False 

5479 for ii, (f_state, state, w) in enumerate( 

5480 zip( 

5481 new_tab.filt_states, 

5482 new_tab.state_hist[-1], 

5483 new_tab.distrib_weights_hist[-1], 

5484 ) 

5485 ): 

5486 try: 

5487 ( 

5488 new_f_states[ii], 

5489 new_s_hist[ii], 

5490 new_c_hist[ii], 

5491 new_w[ii], 

5492 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args) 

5493 except ( 

5494 gerr.ParticleDepletionError, 

5495 gerr.ParticleEstimationDomainError, 

5496 gerr.ExtremeMeasurementNoiseError, 

5497 ): 

5498 return None, 0 

5499 new_tab.filt_states = new_f_states 

5500 new_tab.state_hist[-1] = new_s_hist 

5501 new_tab.cov_hist[-1] = new_c_hist 

5502 new_w = [w + np.finfo(float).eps for w in new_w] 

5503 if not depleted: 

5504 cost = (new_tab.exist_prob * self.prob_detection * np.sum(new_w).item()) / ( 

5505 (1 - new_tab.exist_prob + new_tab.exist_prob * self.prob_miss_detection) 

5506 * np.sum(tab.distrib_weights_hist[-1]).item() 

5507 ) 

5508 # new_tab.distrib_weights_hist[-1] = [w / cost for w in new_w] 

5509 nw_list = [w * new_tab.exist_prob * self.prob_detection for w in new_w] 

5510 new_tab.distrib_weights_hist[-1] = [ 

5511 w / np.sum(nw_list).item() for w in nw_list 

5512 ] 

5513 new_tab.exist_prob = 1 

5514 else: 

5515 cost = 0 

5516 return new_tab, cost 

5517 

5518 def _correct_birth_tab_entry(self, meas, distrib, timestep, filt_args): 

5519 new_tab = self._TabEntry() 

5520 (filt_states, weights, states, covs) = self._init_filt_states(distrib) 

5521 

5522 new_f_states = [None] * len(filt_states) 

5523 new_s_hist = [None] * len(filt_states) 

5524 new_c_hist = [None] * len(filt_states) 

5525 new_w = [None] * len(filt_states) 

5526 depleted = False 

5527 for ii, (f_state, state, w) in enumerate(zip(filt_states, states, weights)): 

5528 try: 

5529 ( 

5530 new_f_states[ii], 

5531 new_s_hist[ii], 

5532 new_c_hist[ii], 

5533 new_w[ii], 

5534 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args) 

5535 except ( 

5536 gerr.ParticleDepletionError, 

5537 gerr.ParticleEstimationDomainError, 

5538 gerr.ExtremeMeasurementNoiseError, 

5539 ): 

5540 return None, 0 

5541 new_tab.filt_states = new_f_states 

5542 new_tab.state_hist = [new_s_hist] 

5543 new_tab.cov_hist = [new_c_hist] 

5544 new_tab.distrib_weights_hist = [] 

5545 new_w = [w + np.finfo(float).eps for w in new_w] 

5546 if not depleted: 

5547 cost = ( 

5548 np.sum(new_w).item() * self.prob_detection 

5549 + self.clutter_rate * self.clutter_den 

5550 ) 

5551 new_tab.distrib_weights_hist.append( 

5552 [w / np.sum(new_w).item() for w in new_w] 

5553 ) 

5554 new_tab.exist_prob = ( 

5555 self.prob_detection 

5556 * cost 

5557 / (self.clutter_rate * self.clutter_den + self.prob_detection * cost) 

5558 ) 

5559 else: 

5560 cost = 0 

5561 new_tab.time_index = self._time_index_cntr 

5562 return new_tab, cost 

5563 

5564 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args): 

5565 num_pred = len(self._track_tab) 

5566 num_birth = len(self.birth_terms) 

5567 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth) 

5568 

5569 # Missed Detection Updates 

5570 for ii, track in enumerate(self._track_tab): 

5571 up_tab[ii] = self._TabEntry().setup(track) 

5572 sum_non_exist_prob = ( 

5573 1 

5574 - up_tab[ii].exist_prob 

5575 + up_tab[ii].exist_prob * self.prob_miss_detection 

5576 ) 

5577 up_tab[ii].distrib_weights_hist.append( 

5578 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]] 

5579 ) 

5580 up_tab[ii].exist_prob = ( 

5581 up_tab[ii].exist_prob * self.prob_miss_detection 

5582 ) / (sum_non_exist_prob) 

5583 up_tab[ii].meas_assoc_hist.append(None) 

5584 # left_cost_m = np.zeros() 

5585 # all_cost_m = np.zeros((num_pred + num_birth * num_meas, num_meas)) 

5586 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas)) 

5587 

5588 # Update for all existing tracks 

5589 for emm, z in enumerate(meas): 

5590 for ii, ent in enumerate(self._track_tab): 

5591 s_to_ii = num_pred * emm + ii + num_pred 

5592 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry( 

5593 z, ent, timestep, filt_args 

5594 ) 

5595 if up_tab[s_to_ii] is not None: 

5596 up_tab[s_to_ii].meas_assoc_hist.append(emm) 

5597 all_cost_m[emm, ii] = cost 

5598 

5599 # Update for all potential new births 

5600 for emm, z in enumerate(meas): 

5601 for ii, b_model in enumerate(self.birth_terms): 

5602 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii 

5603 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry( 

5604 z, b_model, timestep, filt_args 

5605 ) 

5606 if up_tab[s_to_ii] is not None: 

5607 up_tab[s_to_ii].meas_assoc_hist.append(emm) 

5608 all_cost_m[emm, emm + num_pred] = cost 

5609 return up_tab, all_cost_m 

5610 

5611 # TODO: find some way to cherry pick the appropriate all_combs to ensure that 

5612 # measurements are not duplicated before being passed to assignment 

5613 def _gen_cor_hyps( 

5614 self, num_meas, avg_prob_detect, avg_prob_miss_detect, all_cost_m, cor_tab 

5615 ): 

5616 num_pred = len(self._track_tab) 

5617 up_hyps = [] 

5618 if num_meas == 0: 

5619 for hyp in self._hypotheses: 

5620 pmd_log = np.sum( 

5621 [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set] 

5622 ) 

5623 hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob) 

5624 up_hyps.append(hyp) 

5625 else: 

5626 clutter = self.clutter_rate * self.clutter_den 

5627 ss_w = 0 

5628 for p_hyp in self._hypotheses: 

5629 ss_w += np.sqrt(p_hyp.assoc_prob) 

5630 for p_hyp in self._hypotheses: 

5631 if p_hyp.num_tracks == 0: # all clutter 

5632 inds = np.arange(num_pred, num_pred + num_meas).tolist() 

5633 else: 

5634 inds = ( 

5635 p_hyp.track_set 

5636 + np.arange(num_pred, num_pred + num_meas).tolist() 

5637 ) 

5638 

5639 cost_m = all_cost_m[:, inds] 

5640 max_row_inds, max_col_inds = np.where(cost_m >= np.inf) 

5641 if max_row_inds.size > 0: 

5642 cost_m[max_row_inds, max_col_inds] = np.finfo(float).max 

5643 min_row_inds, min_col_inds = np.where(cost_m <= 0.0) 

5644 if min_row_inds.size > 0: 

5645 cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps # 1 

5646 neg_log = -np.log(cost_m) 

5647 # if max_row_inds.size > 0: 

5648 # neg_log[max_row_inds, max_col_inds] = -np.inf 

5649 # if min_row_inds.size > 0: 

5650 # neg_log[min_row_inds, min_col_inds] = np.inf 

5651 

5652 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w) 

5653 m = int(m.item()) 

5654 # if m <1: 

5655 # m=1 

5656 [assigns, costs] = murty_m_best_all_meas_assigned(neg_log, m) 

5657 """assignment matrix consisting of 0 or 1 entries such that each column sums 

5658 to one and each row sums to zero or one""" # (transposed from the paper) 

5659 # assigns = assigns.T 

5660 # assigns = np.delete(assigns, 1, axis=0) 

5661 # costs = np.delete(costs, 1, axis=0) 

5662 

5663 pmd_log = np.sum( 

5664 [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set] 

5665 ) 

5666 for a, c in zip(assigns, costs): 

5667 new_hyp = self._HypothesisHelper() 

5668 new_hyp.assoc_prob = ( 

5669 -self.clutter_rate 

5670 + num_meas * np.log(clutter) 

5671 + pmd_log 

5672 + np.log(p_hyp.assoc_prob) 

5673 - c 

5674 ) 

5675 if p_hyp.num_tracks == 0: 

5676 new_track_list = list(num_pred * a + num_pred * num_meas) 

5677 else: 

5678 # track_inds = np.argwhere(a==1) 

5679 new_track_list = [] 

5680 

5681 for ii, ms in enumerate(a): 

5682 if len(p_hyp.track_set) >= ms: 

5683 new_track_list.append( 

5684 (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5685 ) 

5686 elif len(p_hyp.track_set) == len(a): 

5687 new_track_list.append( 

5688 num_pred * ms - ii * (num_pred - 1) 

5689 ) 

5690 elif len(p_hyp.track_set) < len(a): 

5691 new_track_list.append( 

5692 num_pred * (num_meas + 1) + (ms - num_meas) 

5693 ) 

5694 else: 

5695 new_track_list.append( 

5696 (num_meas + 1) * num_pred 

5697 - (ms - len(p_hyp.track_set) - 1) 

5698 ) 

5699 # if len(p_hyp.track_set) >= ms: 

5700 # new_track_list.append( 

5701 # (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5702 # ) 

5703 # else: 

5704 # new_track_list.append( 

5705 # (num_meas + 1) * num_pred + ms - num_meas 

5706 # ) 

5707 # if len(a) == len(p_hyp.track_set): 

5708 # for ii, (ms, t) in enumerate(zip(a, p_hyp.track_set)): 

5709 # if len(p_hyp.track_set) >= ms: 

5710 # # new_track_list.append(((np.array(t)) * ms + num_pred)) 

5711 # # new_track_list.append((num_pred * ms + np.array(t))) 

5712 # new_track_list.append( 

5713 # (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5714 # ) 

5715 # else: 

5716 # new_track_list.append( 

5717 # num_pred * ms - ii * (num_pred - 1) 

5718 # ) 

5719 # elif len(p_hyp.track_set) < len(a): 

5720 # for ii, ms in enumerate(a): 

5721 # if len(p_hyp.track_set) >= ms: 

5722 # # coiuld be this one, trying -1 first 

5723 # # new_track_list.append(((np.array(p_hyp.track_set[(ms-ii)]) + num_pred) * ms)) 

5724 # # new_track_list.append(((np.array(p_hyp.track_set[(ms-1)]) + num_pred) * ms + num_meas * ii)) 

5725 # # new_track_list.append( 

5726 # # (ms + ii) * num_pred + p_hyp.track_set[(ms - 1)] 

5727 # # ) 

5728 # new_track_list.append( 

5729 # (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5730 # ) 

5731 # elif len(p_hyp.track_set) < ms: 

5732 # new_track_list.append( 

5733 # num_pred * (num_meas + 1) + (ms - num_meas) 

5734 # ) 

5735 # # new_track_list.append(num_meas * num_pred + ms) 

5736 # elif len(p_hyp.track_set) > len(a): 

5737 # # May need to modify this 

5738 # for ii, ms in enumerate(a): 

5739 # if len(p_hyp.track_set) >= ms: 

5740 # new_track_list.append( 

5741 # (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5742 # ) 

5743 # # new_track_list.append( 

5744 # # (ms - 1) * num_pred + p_hyp.track_set[(ms - 1)] 

5745 # # ) 

5746 # # new_track_list.append( 

5747 # # ms * num_pred + p_hyp.track_set[(ms - 1)] 

5748 # # ) 

5749 # elif len(p_hyp.track_set) < ms: 

5750 # new_track_list.append( 

5751 # num_pred * (num_meas + 1) 

5752 # + (ms - 1 - len(p_hyp.track_set)) 

5753 # ) 

5754 

5755 # new_track_list = list(np.array(p_hyp.track_set) + num_pred + num_pred * a)# new_track_list = list(num_pred * a + np.array(p_hyp.track_set)) 

5756 

5757 new_hyp.track_set = new_track_list 

5758 up_hyps.append(new_hyp) 

5759 

5760 lse = log_sum_exp([x.assoc_prob for x in up_hyps]) 

5761 for ii in range(0, len(up_hyps)): 

5762 up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse) 

5763 return up_hyps 

5764 

5765 def _clean_updates(self): 

5766 used = [0] * len(self._track_tab) 

5767 for hyp in self._hypotheses: 

5768 for ii in hyp.track_set: 

5769 if self._track_tab[ii] is not None: 

5770 used[ii] += 1 

5771 nnz_inds = [idx for idx, val in enumerate(used) if val != 0] 

5772 track_cnt = len(nnz_inds) 

5773 

5774 new_inds = [None] * len(self._track_tab) 

5775 for ii, v in zip(nnz_inds, [ii for ii in range(0, track_cnt)]): 

5776 new_inds[ii] = v 

5777 # new_tab = [self._TabEntry().setup(self._track_tab[ii]) for ii in nnz_inds] 

5778 new_tab = [self._track_tab[ii] for ii in nnz_inds] 

5779 new_hyps = [] 

5780 for ii, hyp in enumerate(self._hypotheses): 

5781 if len(hyp.track_set) > 0: 

5782 track_set = [new_inds[ii] for ii in hyp.track_set] 

5783 if None in track_set: 

5784 continue 

5785 hyp.track_set = track_set 

5786 new_hyps.append(hyp) 

5787 self._track_tab = new_tab 

5788 self._hypotheses = new_hyps 

5789 

5790 def _calc_card_dist(self, hyp_lst): 

5791 """Calucaltes the cardinality distribution.""" 

5792 if len(hyp_lst) == 0: 

5793 return [ 

5794 1, 

5795 ] 

5796 card_dist = [] 

5797 for ii in range(0, max(map(lambda x: x.num_tracks, hyp_lst)) + 1): 

5798 card = 0 

5799 for hyp in hyp_lst: 

5800 if hyp.num_tracks == ii: 

5801 card = card + hyp.assoc_prob 

5802 card_dist.append(card) 

5803 return card_dist 

5804 

5805 def correct(self, timestep, meas, filt_args={}): 

5806 """Correction step of the PMBM filter. 

5807 

5808 Notes 

5809 ----- 

5810 This corrects the hypotheses based on the measurements and gates the 

5811 measurements according to the class settings. It also updates the 

5812 cardinality distribution. 

5813 

5814 Parameters 

5815 ---------- 

5816 timestep: float 

5817 Current timestep. 

5818 meas : list 

5819 List of Nm x 1 numpy arrays each representing a measuremnt. 

5820 filt_args : dict, optional 

5821 keyword arguments to pass to the inner filters correct function. 

5822 The default is {}. 

5823 

5824 Returns 

5825 ------- 

5826 None 

5827 """ 

5828 if self.gating_on: 

5829 warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning) 

5830 # means = [] 

5831 # covs = [] 

5832 # for ent in self._track_tab: 

5833 # means.extend(ent.probDensity.means) 

5834 # covs.extend(ent.probDensity.covariances) 

5835 # meas = self._gate_meas(meas, means, covs) 

5836 if self.save_measurements: 

5837 self._meas_tab.append(deepcopy(meas)) 

5838 num_meas = len(meas) 

5839 cor_tab, all_cost_m = self._gen_cor_tab(num_meas, meas, timestep, filt_args) 

5840 

5841 # self._add_birth_hyps(num_meas) 

5842 

5843 avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab) 

5844 

5845 cor_hyps = self._gen_cor_hyps( 

5846 num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, cor_tab 

5847 ) 

5848 

5849 self._track_tab = cor_tab 

5850 self._hypotheses = cor_hyps 

5851 self._card_dist = self._calc_card_dist(self._hypotheses) 

5852 self._clean_updates() 

5853 

5854 def _extract_helper(self, track): 

5855 states = [None] * len(track.state_hist) 

5856 covs = [None] * len(track.state_hist) 

5857 for ii, (w_lst, s_lst, c_lst) in enumerate( 

5858 zip(track.distrib_weights_hist, track.state_hist, track.cov_hist) 

5859 ): 

5860 idx = np.argmax(w_lst) 

5861 states[ii] = s_lst[idx] 

5862 if self.save_covs: 

5863 covs[ii] = c_lst[idx] 

5864 return states, covs 

5865 

5866 def _update_extract_hist(self, idx_cmp): 

5867 used_meas_inds = [[] for ii in range(self._time_index_cntr)] 

5868 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set) 

5869 for ii, track in enumerate( 

5870 [ 

5871 self._track_tab[trk_ind] 

5872 for trk_ind in self._hypotheses[idx_cmp].track_set 

5873 ] 

5874 ): 

5875 new_extract_hists[ii] = self._ExtractHistHelper() 

5876 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy() 

5877 new_extract_hists[ii].b_time_index = track.time_index 

5878 ( 

5879 new_extract_hists[ii].states, 

5880 new_extract_hists[ii].covs, 

5881 ) = self._extract_helper(track) 

5882 

5883 for t_inds_after_b, meas_ind in enumerate( 

5884 new_extract_hists[ii].meas_ind_hist 

5885 ): 

5886 tt = new_extract_hists[ii].b_time_index + t_inds_after_b 

5887 if meas_ind is not None and meas_ind not in used_meas_inds[tt]: 

5888 used_meas_inds[tt].append(meas_ind) 

5889 good_inds = [] 

5890 for ii, existing in enumerate(self._extractable_hists): 

5891 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist): 

5892 tt = existing.b_time_index + t_inds_after_b 

5893 used = meas_ind is not None and meas_ind in used_meas_inds[tt] 

5894 if used: 

5895 break 

5896 if not used: 

5897 good_inds.append(ii) 

5898 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds] 

5899 self._extractable_hists.extend(new_extract_hists) 

5900 

5901 def extract_states(self, update=True, calc_states=True): 

5902 """Extracts the best state estimates. 

5903 

5904 This extracts the best states from the distribution. It should be 

5905 called once per time step after the correction function. This calls 

5906 both the inner filters predict and correct functions so the keyword 

5907 arguments must contain any additional variables needed by those 

5908 functions. 

5909 

5910 Parameters 

5911 ---------- 

5912 update : bool, optional 

5913 Flag indicating if the label history should be updated. This should 

5914 be done once per timestep and can be disabled if calculating states 

5915 after the final timestep. The default is True. 

5916 calc_states : bool, optional 

5917 Flag indicating if the states should be calculated based on the 

5918 label history. This only needs to be done before the states are used. 

5919 It can simply be called once after the end of the simulation. The 

5920 default is true. 

5921 

5922 Returns 

5923 ------- 

5924 idx_cmp : int 

5925 Index of the hypothesis table used when extracting states. 

5926 """ 

5927 card = np.argmax(self._card_dist) 

5928 tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses]) 

5929 weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses]) 

5930 

5931 self._states = [[] for ii in range(self._time_index_cntr)] 

5932 self._covs = [[] for ii in range(self._time_index_cntr)] 

5933 

5934 if len(tracks_per_hyp) == 0: 

5935 return None 

5936 idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card)) 

5937 if update: 

5938 self._update_extract_hist(idx_cmp) 

5939 if calc_states: 

5940 for existing in self._extractable_hists: 

5941 for t_inds_after_b, (s, c) in enumerate( 

5942 zip(existing.states, existing.covs) 

5943 ): 

5944 tt = existing.b_time_index + t_inds_after_b 

5945 self._states[tt].append(s) 

5946 self._covs[tt].append(c) 

5947 if not update and not calc_states: 

5948 warnings.warn("Extracting states performed no actions") 

5949 return idx_cmp 

5950 

5951 def _prune(self): 

5952 """Removes hypotheses below a threshold. 

5953 

5954 This should be called once per time step after the correction and 

5955 before the state extraction. 

5956 """ 

5957 # Find hypotheses with low association probabilities 

5958 temp_assoc_probs = np.array([]) 

5959 for ii in range(0, len(self._hypotheses)): 

5960 temp_assoc_probs = np.append( 

5961 temp_assoc_probs, self._hypotheses[ii].assoc_prob 

5962 ) 

5963 keep_indices = np.argwhere(temp_assoc_probs > self.prune_threshold).T 

5964 keep_indices = keep_indices.flatten() 

5965 

5966 # For re-weighing association probabilities 

5967 new_sum = np.sum(temp_assoc_probs[keep_indices]) 

5968 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices] 

5969 for ii in range(0, len(keep_indices)): 

5970 self._hypotheses[ii].assoc_prob = self._hypotheses[ii].assoc_prob / new_sum 

5971 # Re-calculate cardinality 

5972 self._card_dist = self._calc_card_dist(self._hypotheses) 

5973 

5974 def _cap(self): 

5975 """Removes least likely hypotheses until a maximum number is reached. 

5976 

5977 This should be called once per time step after pruning and 

5978 before the state extraction. 

5979 """ 

5980 # Determine if there are too many hypotheses 

5981 if len(self._hypotheses) > self.max_hyps: 

5982 temp_assoc_probs = np.array([]) 

5983 for ii in range(0, len(self._hypotheses)): 

5984 temp_assoc_probs = np.append( 

5985 temp_assoc_probs, self._hypotheses[ii].assoc_prob 

5986 ) 

5987 sorted_indices = np.argsort(temp_assoc_probs) 

5988 

5989 # Reverse order to get descending array 

5990 sorted_indices = sorted_indices[::-1] 

5991 

5992 # Take the top n assoc_probs, where n = max_hyps 

5993 keep_indices = np.array([], dtype=np.int64) 

5994 for ii in range(0, self.max_hyps): 

5995 keep_indices = np.append(keep_indices, int(sorted_indices[ii])) 

5996 # Assign to class 

5997 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices] 

5998 

5999 # Normalize association probabilities 

6000 new_sum = 0 

6001 for ii in range(0, len(self._hypotheses)): 

6002 new_sum = new_sum + self._hypotheses[ii].assoc_prob 

6003 for ii in range(0, len(self._hypotheses)): 

6004 self._hypotheses[ii].assoc_prob = ( 

6005 self._hypotheses[ii].assoc_prob / new_sum 

6006 ) 

6007 # Re-calculate cardinality 

6008 self._card_dist = self._calc_card_dist(self._hypotheses) 

6009 

6010 def _bern_prune(self): 

6011 """Removes track table entries below a threshold. 

6012 

6013 This should be called once per time step after the correction and 

6014 before the state extraction. 

6015 """ 

6016 used = [0] * len(self._track_tab) 

6017 for ii in range(0, len(self._track_tab)): 

6018 if self._track_tab[ii].exist_prob > self.exist_threshold: 

6019 used[ii] += 1 

6020 

6021 keep_inds = [idx for idx, val in enumerate(used) if val != 0] 

6022 track_cnt = len(keep_inds) 

6023 

6024 new_inds = [None] * len(self._track_tab) 

6025 for ii, v in zip(keep_inds, [ii for ii in range(0, track_cnt)]): 

6026 new_inds[ii] = v 

6027 

6028 # loop over track table and remove pruned entries 

6029 new_tab = [self._track_tab[ii] for ii in keep_inds] 

6030 new_hyps = [] 

6031 for ii, hyp in enumerate(self._hypotheses): 

6032 if len(hyp.track_set) > 0: 

6033 track_set = [new_inds[track_ind] for track_ind in hyp.track_set] 

6034 if None in track_set: 

6035 track_set = [item for item in track_set if item != None] 

6036 hyp.track_set = track_set 

6037 new_hyps.append(hyp) 

6038 

6039 del_inds = [] 

6040 # TODO: ADD CASE FOR NO MORE TRACKS SO THAT WE DON'T REMOVE ALL HYPOTHESES 

6041 # AND OR THE HYPOTHESES HAVE AN EMPTY TRACK SET RATHER THAN A TRACK SET OF NONES 

6042 

6043 for ii in range(0, len(new_hyps)): 

6044 same_inds = [] 

6045 for jj in range(ii, len(new_hyps)): 

6046 if ii == jj or any(jj == x for x in del_inds): 

6047 continue 

6048 if new_hyps[ii].track_set == new_hyps[jj].track_set: 

6049 same_inds.append(jj) 

6050 for jj in same_inds: 

6051 new_hyps[ii].assoc_prob += new_hyps[jj].assoc_prob 

6052 del_inds.append(jj) 

6053 del_inds.sort(reverse=True) 

6054 for ind in del_inds: 

6055 new_hyps.pop(ind) 

6056 

6057 self._track_tab = new_tab 

6058 self._hypotheses = new_hyps 

6059 

6060 def cleanup( 

6061 self, 

6062 enable_prune=True, 

6063 enable_cap=True, 

6064 enable_bern_prune=True, 

6065 enable_extract=True, 

6066 extract_kwargs=None, 

6067 ): 

6068 """Performs the cleanup step of the filter. 

6069 

6070 This can prune, cap, and extract states. It must be called once per 

6071 timestep, even if all three functions are disabled. This is to ensure 

6072 that internal counters for tracking linear timestep indices are properly 

6073 incremented. If this is called with `enable_extract` set to true then 

6074 the extract states method does not need to be called separately. It is 

6075 recommended to call this function instead of 

6076 :meth:`carbs.swarm_estimator.tracker.PoissonMultiBernoulliMixture.extract_states` 

6077 directly. 

6078 

6079 Parameters 

6080 ---------- 

6081 enable_prune : bool, optional 

6082 Flag indicating if prunning should be performed. The default is True. 

6083 enable_cap : bool, optional 

6084 Flag indicating if capping should be performed. The default is True. 

6085 enable_bern_prune: bool, optional 

6086 Flag indicating if bernoulli pruning should be performed. The default is True. 

6087 enable_extract : bool, optional 

6088 Flag indicating if state extraction should be performed. The default is True. 

6089 extract_kwargs : dict, optional 

6090 Additional arguments to pass to :meth:`.extract_states`. The 

6091 default is None. Only used if extracting states. 

6092 

6093 Returns 

6094 ------- 

6095 None. 

6096 

6097 """ 

6098 self._time_index_cntr += 1 

6099 

6100 if enable_prune: 

6101 self._prune() 

6102 if enable_cap: 

6103 self._cap() 

6104 if enable_bern_prune: 

6105 self._bern_prune() 

6106 if enable_extract: 

6107 if extract_kwargs is None: 

6108 extract_kwargs = {} 

6109 self.extract_states(**extract_kwargs) 

6110 

6111 def calculate_ospa2( 

6112 self, 

6113 truth, 

6114 c, 

6115 p, 

6116 win_len, 

6117 true_covs=None, 

6118 core_method=SingleObjectDistance.MANHATTAN, 

6119 state_inds=None, 

6120 ): 

6121 """Calculates the OSPA(2) distance between the truth at all timesteps. 

6122 

6123 Wrapper for :func:`serums.distances.calculate_ospa2`. 

6124 

6125 Parameters 

6126 ---------- 

6127 truth : list 

6128 Each element represents a timestep and is a list of N x 1 numpy array, 

6129 one per true agent in the swarm. 

6130 c : float 

6131 Distance cutoff for considering a point properly assigned. This 

6132 influences how cardinality errors are penalized. For :math:`p = 1` 

6133 it is the penalty given false point estimate. 

6134 p : int 

6135 The power of the distance term. Higher values penalize outliers 

6136 more. 

6137 win_len : int 

6138 Number of samples to include in window. 

6139 core_method : :class:`serums.enums.SingleObjectDistance`, Optional 

6140 The main distance measure to use for the localization component. 

6141 The default value is :attr:`.SingleObjectDistance.MANHATTAN`. 

6142 true_covs : list, Optional 

6143 Each element represents a timestep and is a list of N x N numpy arrays 

6144 corresonponding to the uncertainty about the true states. Note the 

6145 order must be consistent with the truth data given. This is only 

6146 needed for core methods :attr:`SingleObjectDistance.HELLINGER`. The defautl 

6147 value is None. 

6148 state_inds : list, optional 

6149 Indices in the state vector to use, will be applied to the truth 

6150 data as well. The default is None which means the full state is 

6151 used. 

6152 """ 

6153 # error checking on optional input arguments 

6154 core_method = self._ospa_input_check(core_method, truth, true_covs) 

6155 

6156 # setup data structures 

6157 if state_inds is None: 

6158 state_dim = self._ospa_find_s_dim(truth) 

6159 state_inds = range(state_dim) 

6160 else: 

6161 state_dim = len(state_inds) 

6162 if state_dim is None: 

6163 warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation") 

6164 

6165 nt = len(self._states) 

6166 self.ospa2 = np.zeros(nt) 

6167 self.ospa2_localization = np.zeros(nt) 

6168 self.ospa2_cardinality = np.zeros(nt) 

6169 self._ospa2_params["core"] = core_method 

6170 self._ospa2_params["cutoff"] = c 

6171 self._ospa2_params["power"] = p 

6172 self._ospa2_params["win_len"] = win_len 

6173 return 

6174 true_mat, true_cov_mat = self._ospa_setup_tmat( 

6175 truth, state_dim, true_covs, state_inds 

6176 ) 

6177 est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds) 

6178 

6179 # find OSPA 

6180 ( 

6181 self.ospa2, 

6182 self.ospa2_localization, 

6183 self.ospa2_cardinality, 

6184 self._ospa2_params["core"], 

6185 self._ospa2_params["cutoff"], 

6186 self._ospa2_params["power"], 

6187 self._ospa2_params["win_len"], 

6188 ) = calculate_ospa2( 

6189 est_mat, 

6190 true_mat, 

6191 c, 

6192 p, 

6193 win_len, 

6194 core_method=core_method, 

6195 true_cov_mat=true_cov_mat, 

6196 est_cov_mat=est_cov_mat, 

6197 ) 

6198 

6199 def plot_states( 

6200 self, 

6201 plt_inds, 

6202 state_lbl="States", 

6203 ttl=None, 

6204 state_color=None, 

6205 x_lbl=None, 

6206 y_lbl=None, 

6207 **kwargs, 

6208 ): 

6209 """Plots the best estimate for the states. 

6210 

6211 This assumes that the states have been extracted. It's designed to plot 

6212 two of the state variables (typically x/y position). The error ellipses 

6213 are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses` 

6214 

6215 Keyword arguments are processed with 

6216 :meth:`gncpy.plotting.init_plotting_opts`. This function 

6217 implements 

6218 

6219 - f_hndl 

6220 - true_states 

6221 - sig_bnd 

6222 - rng 

6223 - meas_inds 

6224 - lgnd_loc 

6225 - marker 

6226 

6227 Parameters 

6228 ---------- 

6229 plt_inds : list 

6230 List of indices in the state vector to plot 

6231 state_lbl : string 

6232 Value to appear in legend for the states. Only appears if the 

6233 legend is shown 

6234 ttl : string, optional 

6235 Title for the plot, if None a default title is generated. The default 

6236 is None. 

6237 x_lbl : string 

6238 Label for the x-axis. 

6239 y_lbl : string 

6240 Label for the y-axis. 

6241 

6242 Returns 

6243 ------- 

6244 Matplotlib figure 

6245 Instance of the matplotlib figure used 

6246 """ 

6247 opts = pltUtil.init_plotting_opts(**kwargs) 

6248 f_hndl = opts["f_hndl"] 

6249 true_states = opts["true_states"] 

6250 sig_bnd = opts["sig_bnd"] 

6251 rng = opts["rng"] 

6252 meas_inds = opts["meas_inds"] 

6253 lgnd_loc = opts["lgnd_loc"] 

6254 marker = opts["marker"] 

6255 if ttl is None: 

6256 ttl = "State Estimates" 

6257 if rng is None: 

6258 rng = rnd.default_rng(1) 

6259 if x_lbl is None: 

6260 x_lbl = "x-position" 

6261 if y_lbl is None: 

6262 y_lbl = "y-position" 

6263 plt_meas = meas_inds is not None 

6264 show_sig = sig_bnd is not None and self.save_covs 

6265 

6266 s_lst = deepcopy(self._states) 

6267 x_dim = None 

6268 

6269 if f_hndl is None: 

6270 f_hndl = plt.figure() 

6271 f_hndl.add_subplot(1, 1, 1) 

6272 # get state dimension 

6273 for states in s_lst: 

6274 if len(states) > 0: 

6275 x_dim = states[0].size 

6276 break 

6277 # get array of all state values for each label 

6278 added_sig_lbl = False 

6279 added_true_lbl = False 

6280 added_state_lbl = False 

6281 added_meas_lbl = False 

6282 r = rng.random() 

6283 b = rng.random() 

6284 g = rng.random() 

6285 if state_color is None: 

6286 color = (r, g, b) 

6287 else: 

6288 color = state_color 

6289 for tt, states in enumerate(s_lst): 

6290 if len(states) == 0: 

6291 continue 

6292 x = np.concatenate(states, axis=1) 

6293 if show_sig: 

6294 sigs = [None] * len(states) 

6295 for ii, cov in enumerate(self._covs[tt]): 

6296 sig = np.zeros((2, 2)) 

6297 sig[0, 0] = cov[plt_inds[0], plt_inds[0]] 

6298 sig[0, 1] = cov[plt_inds[0], plt_inds[1]] 

6299 sig[1, 0] = cov[plt_inds[1], plt_inds[0]] 

6300 sig[1, 1] = cov[plt_inds[1], plt_inds[1]] 

6301 sigs[ii] = sig 

6302 # plot 

6303 for ii, sig in enumerate(sigs): 

6304 if sig is None: 

6305 continue 

6306 w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd) 

6307 if not added_sig_lbl: 

6308 s = r"${}\sigma$ Error Ellipses".format(sig_bnd) 

6309 e = Ellipse( 

6310 xy=x[plt_inds, ii], 

6311 width=w, 

6312 height=h, 

6313 angle=a, 

6314 zorder=-10000, 

6315 label=s, 

6316 ) 

6317 added_sig_lbl = True 

6318 else: 

6319 e = Ellipse( 

6320 xy=x[plt_inds, ii], 

6321 width=w, 

6322 height=h, 

6323 angle=a, 

6324 zorder=-10000, 

6325 ) 

6326 e.set_clip_box(f_hndl.axes[0].bbox) 

6327 e.set_alpha(0.15) 

6328 e.set_facecolor(color) 

6329 f_hndl.axes[0].add_patch(e) 

6330 if not added_state_lbl: 

6331 f_hndl.axes[0].scatter( 

6332 x[plt_inds[0], :], 

6333 x[plt_inds[1], :], 

6334 color=color, 

6335 edgecolors=(0, 0, 0), 

6336 marker=marker, 

6337 label=state_lbl, 

6338 ) 

6339 added_state_lbl = True 

6340 else: 

6341 f_hndl.axes[0].scatter( 

6342 x[plt_inds[0], :], 

6343 x[plt_inds[1], :], 

6344 color=color, 

6345 edgecolors=(0, 0, 0), 

6346 marker=marker, 

6347 ) 

6348 # if true states are available then plot them 

6349 if true_states is not None: 

6350 if x_dim is None: 

6351 for states in true_states: 

6352 if len(states) > 0: 

6353 x_dim = states[0].size 

6354 break 

6355 max_true = max([len(x) for x in true_states]) 

6356 x = np.nan * np.ones((x_dim, len(true_states), max_true)) 

6357 for tt, states in enumerate(true_states): 

6358 for ii, state in enumerate(states): 

6359 x[:, [tt], ii] = state.copy() 

6360 for ii in range(0, max_true): 

6361 if not added_true_lbl: 

6362 f_hndl.axes[0].plot( 

6363 x[plt_inds[0], :, ii], 

6364 x[plt_inds[1], :, ii], 

6365 color="k", 

6366 marker=".", 

6367 label="True Trajectories", 

6368 ) 

6369 added_true_lbl = True 

6370 else: 

6371 f_hndl.axes[0].plot( 

6372 x[plt_inds[0], :, ii], 

6373 x[plt_inds[1], :, ii], 

6374 color="k", 

6375 marker=".", 

6376 ) 

6377 if plt_meas: 

6378 meas_x = [] 

6379 meas_y = [] 

6380 for meas_tt in self._meas_tab: 

6381 mx_ii = [m[meas_inds[0]].item() for m in meas_tt] 

6382 my_ii = [m[meas_inds[1]].item() for m in meas_tt] 

6383 meas_x.extend(mx_ii) 

6384 meas_y.extend(my_ii) 

6385 color = (128 / 255, 128 / 255, 128 / 255) 

6386 meas_x = np.asarray(meas_x) 

6387 meas_y = np.asarray(meas_y) 

6388 if not added_meas_lbl: 

6389 f_hndl.axes[0].scatter( 

6390 meas_x, 

6391 meas_y, 

6392 zorder=-1, 

6393 alpha=0.35, 

6394 color=color, 

6395 marker="^", 

6396 edgecolors=(0, 0, 0), 

6397 label="Measurements", 

6398 ) 

6399 else: 

6400 f_hndl.axes[0].scatter( 

6401 meas_x, 

6402 meas_y, 

6403 zorder=-1, 

6404 alpha=0.35, 

6405 color=color, 

6406 marker="^", 

6407 edgecolors=(0, 0, 0), 

6408 ) 

6409 f_hndl.axes[0].grid(True) 

6410 pltUtil.set_title_label(f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl) 

6411 

6412 if lgnd_loc is not None: 

6413 plt.legend(loc=lgnd_loc) 

6414 plt.tight_layout() 

6415 

6416 return f_hndl 

6417 

6418 def plot_card_dist(self, ttl=None, **kwargs): 

6419 """Plots the current cardinality distribution. 

6420 

6421 This assumes that the cardinality distribution has been calculated by 

6422 the class. 

6423 

6424 Keywrod arguments are processed with 

6425 :meth:`gncpy.plotting.init_plotting_opts`. This function 

6426 implements 

6427 

6428 - f_hndl 

6429 

6430 Parameters 

6431 ---------- 

6432 ttl : string 

6433 Title of the plot, if None a default title is generated. The default 

6434 is None. 

6435 

6436 Returns 

6437 ------- 

6438 Matplotlib figure 

6439 Instance of the matplotlib figure used 

6440 """ 

6441 opts = pltUtil.init_plotting_opts(**kwargs) 

6442 f_hndl = opts["f_hndl"] 

6443 if ttl is None: 

6444 ttl = "Cardinality Distribution" 

6445 if len(self._card_dist) == 0: 

6446 raise RuntimeWarning("Empty Cardinality") 

6447 return f_hndl 

6448 if f_hndl is None: 

6449 f_hndl = plt.figure() 

6450 f_hndl.add_subplot(1, 1, 1) 

6451 x_vals = np.arange(0, len(self._card_dist)) 

6452 f_hndl.axes[0].bar(x_vals, self._card_dist) 

6453 

6454 pltUtil.set_title_label( 

6455 f_hndl, 0, opts, ttl=ttl, x_lbl="Cardinality", y_lbl="Probability" 

6456 ) 

6457 plt.tight_layout() 

6458 

6459 return f_hndl 

6460 

6461 def plot_card_history( 

6462 self, time_units="index", time=None, ttl="Cardinality History", **kwargs 

6463 ): 

6464 """Plots the cardinality history. 

6465 

6466 Parameters 

6467 ---------- 

6468 time_units : string, optional 

6469 Text representing the units of time in the plot. The default is 

6470 'index'. 

6471 time : numpy array, optional 

6472 Vector to use for the x-axis of the plot. If none is given then 

6473 vector indices are used. The default is None. 

6474 ttl : string, optional 

6475 Title of the plot. 

6476 **kwargs : dict 

6477 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts` 

6478 function. Values implemented here are `f_hndl`, and any values 

6479 relating to title/axis text formatting. 

6480 

6481 Returns 

6482 ------- 

6483 fig : matplotlib figure 

6484 Figure object the data was plotted on. 

6485 """ 

6486 card_history = np.array([len(state_set) for state_set in self.states]) 

6487 

6488 opts = pltUtil.init_plotting_opts(**kwargs) 

6489 fig = opts["f_hndl"] 

6490 

6491 if fig is None: 

6492 fig = plt.figure() 

6493 fig.add_subplot(1, 1, 1) 

6494 if time is None: 

6495 time = np.arange(card_history.size, dtype=int) 

6496 fig.axes[0].grid(True) 

6497 fig.axes[0].step(time, card_history, where="post", label="estimated", color="k") 

6498 fig.axes[0].ticklabel_format(useOffset=False) 

6499 

6500 pltUtil.set_title_label( 

6501 fig, 

6502 0, 

6503 opts, 

6504 ttl=ttl, 

6505 x_lbl="Time ({})".format(time_units), 

6506 y_lbl="Cardinality", 

6507 ) 

6508 fig.tight_layout() 

6509 

6510 return fig 

6511 

6512 def plot_ospa2_history( 

6513 self, 

6514 time_units="index", 

6515 time=None, 

6516 main_opts=None, 

6517 sub_opts=None, 

6518 plot_subs=True, 

6519 ): 

6520 """Plots the OSPA2 history. 

6521 

6522 This requires that the OSPA2 has been calcualted by the approriate 

6523 function first. 

6524 

6525 Parameters 

6526 ---------- 

6527 time_units : string, optional 

6528 Text representing the units of time in the plot. The default is 

6529 'index'. 

6530 time : numpy array, optional 

6531 Vector to use for the x-axis of the plot. If none is given then 

6532 vector indices are used. The default is None. 

6533 main_opts : dict, optional 

6534 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts` 

6535 function. Values implemented here are `f_hndl`, and any values 

6536 relating to title/axis text formatting. The default of None implies 

6537 the default options are used for the main plot. 

6538 sub_opts : dict, optional 

6539 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts` 

6540 function. Values implemented here are `f_hndl`, and any values 

6541 relating to title/axis text formatting. The default of None implies 

6542 the default options are used for the sub plot. 

6543 plot_subs : bool, optional 

6544 Flag indicating if the component statistics (cardinality and 

6545 localization) should also be plotted. 

6546 

6547 Returns 

6548 ------- 

6549 figs : dict 

6550 Dictionary of matplotlib figure objects the data was plotted on. 

6551 """ 

6552 if self.ospa2 is None: 

6553 warnings.warn("OSPA must be calculated before plotting") 

6554 return 

6555 if main_opts is None: 

6556 main_opts = pltUtil.init_plotting_opts() 

6557 if sub_opts is None and plot_subs: 

6558 sub_opts = pltUtil.init_plotting_opts() 

6559 fmt = "{:s} OSPA2 (c = {:.1f}, p = {:d}, w={:d})" 

6560 ttl = fmt.format( 

6561 self._ospa2_params["core"], 

6562 self._ospa2_params["cutoff"], 

6563 self._ospa2_params["power"], 

6564 self._ospa2_params["win_len"], 

6565 ) 

6566 y_lbl = "OSPA2" 

6567 

6568 figs = {} 

6569 figs["OSPA2"] = self._plt_ospa_hist( 

6570 self.ospa2, time_units, time, ttl, y_lbl, main_opts 

6571 ) 

6572 

6573 if plot_subs: 

6574 fmt = "{:s} OSPA2 Components (c = {:.1f}, p = {:d}, w={:d})" 

6575 ttl = fmt.format( 

6576 self._ospa2_params["core"], 

6577 self._ospa2_params["cutoff"], 

6578 self._ospa2_params["power"], 

6579 self._ospa2_params["win_len"], 

6580 ) 

6581 y_lbls = ["Localiztion", "Cardinality"] 

6582 figs["OSPA2_subs"] = self._plt_ospa_hist_subs( 

6583 [self.ospa2_localization, self.ospa2_cardinality], 

6584 time_units, 

6585 time, 

6586 ttl, 

6587 y_lbls, 

6588 main_opts, 

6589 ) 

6590 return figs 

6591 

6592 

6593class LabeledPoissonMultiBernoulliMixture(PoissonMultiBernoulliMixture): 

6594 def __init__(self, **kwargs): 

6595 super().__init__(**kwargs) 

6596 

6597 @property 

6598 def labels(self): 

6599 """Read only list of extracted labels. 

6600 

6601 This is a list with 1 element per timestep, and each element is a list 

6602 of the best labels extracted at that timestep. The order of each 

6603 element corresponds to the state order. 

6604 """ 

6605 return self._labels 

6606 

6607 def _correct_birth_tab_entry(self, meas, distrib, timestep, filt_args): 

6608 new_tab = self._TabEntry() 

6609 filt_states, weights, states, covs = self._init_filt_states(distrib) 

6610 

6611 new_f_states = [None] * len(filt_states) 

6612 new_s_hist = [None] * len(filt_states) 

6613 new_c_hist = [None] * len(filt_states) 

6614 new_w = [None] * len(filt_states) 

6615 depleted = False 

6616 for ii, (f_state, state, w) in enumerate(zip(filt_states, states, weights)): 

6617 try: 

6618 ( 

6619 new_f_states[ii], 

6620 new_s_hist[ii], 

6621 new_c_hist[ii], 

6622 new_w[ii], 

6623 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args) 

6624 except ( 

6625 gerr.ParticleDepletionError, 

6626 gerr.ParticleEstimationDomainError, 

6627 gerr.ExtremeMeasurementNoiseError, 

6628 ): 

6629 return None, 0 

6630 new_tab.filt_states = new_f_states 

6631 new_tab.state_hist = [new_s_hist] 

6632 new_tab.cov_hist = [new_c_hist] 

6633 new_tab.distrib_weights_hist = [] 

6634 new_w = [w + np.finfo(float).eps for w in new_w] 

6635 if not depleted: 

6636 cost = ( 

6637 np.sum(new_w).item() * self.prob_detection 

6638 + self.clutter_rate * self.clutter_den 

6639 ) 

6640 new_tab.distrib_weights_hist.append( 

6641 [w / np.sum(new_w).item() for w in new_w] 

6642 ) 

6643 new_tab.exist_prob = ( 

6644 self.prob_detection 

6645 * cost 

6646 / (self.clutter_rate * self.clutter_den + self.prob_detection * cost) 

6647 ) 

6648 else: 

6649 cost = 0 

6650 new_tab.time_index = self._time_index_cntr 

6651 return new_tab, cost 

6652 

6653 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args): 

6654 num_pred = len(self._track_tab) 

6655 num_birth = len(self.birth_terms) 

6656 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth) 

6657 

6658 # Missed Detection Updates 

6659 for ii, track in enumerate(self._track_tab): 

6660 up_tab[ii] = self._TabEntry().setup(track) 

6661 sum_non_exist_prob = ( 

6662 1 

6663 - up_tab[ii].exist_prob 

6664 + up_tab[ii].exist_prob * self.prob_miss_detection 

6665 ) 

6666 up_tab[ii].distrib_weights_hist.append( 

6667 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]] 

6668 ) 

6669 up_tab[ii].exist_prob = ( 

6670 up_tab[ii].exist_prob * self.prob_miss_detection 

6671 ) / (sum_non_exist_prob) 

6672 up_tab[ii].meas_assoc_hist.append(None) 

6673 # left_cost_m = np.zeros() 

6674 # all_cost_m = np.zeros((num_pred + num_birth * num_meas, num_meas)) 

6675 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas)) 

6676 

6677 # Update for all existing tracks 

6678 for emm, z in enumerate(meas): 

6679 for ii, ent in enumerate(self._track_tab): 

6680 s_to_ii = num_pred * emm + ii + num_pred 

6681 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry( 

6682 z, ent, timestep, filt_args 

6683 ) 

6684 if up_tab[s_to_ii] is not None: 

6685 up_tab[s_to_ii].meas_assoc_hist.append(emm) 

6686 all_cost_m[emm, ii] = cost 

6687 

6688 # Update for all potential new births 

6689 for emm, z in enumerate(meas): 

6690 for ii, b_model in enumerate(self.birth_terms): 

6691 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii 

6692 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry( 

6693 z, b_model, timestep, filt_args 

6694 ) 

6695 if up_tab[s_to_ii] is not None: 

6696 up_tab[s_to_ii].meas_assoc_hist.append(emm) 

6697 up_tab[s_to_ii].label = (round(timestep, self.decimal_places), ii) 

6698 all_cost_m[emm, emm + num_pred] = cost 

6699 return up_tab, all_cost_m 

6700 

6701 def _update_extract_hist(self, idx_cmp): 

6702 used_meas_inds = [[] for ii in range(self._time_index_cntr)] 

6703 used_labels = [] 

6704 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set) 

6705 for ii, track in enumerate( 

6706 [ 

6707 self._track_tab[trk_ind] 

6708 for trk_ind in self._hypotheses[idx_cmp].track_set 

6709 ] 

6710 ): 

6711 new_extract_hists[ii] = self._ExtractHistHelper() 

6712 new_extract_hists[ii].label = track.label 

6713 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy() 

6714 new_extract_hists[ii].b_time_index = track.time_index 

6715 ( 

6716 new_extract_hists[ii].states, 

6717 new_extract_hists[ii].covs, 

6718 ) = self._extract_helper(track) 

6719 

6720 used_labels.append(track.label) 

6721 

6722 for t_inds_after_b, meas_ind in enumerate( 

6723 new_extract_hists[ii].meas_ind_hist 

6724 ): 

6725 tt = new_extract_hists[ii].b_time_index + t_inds_after_b 

6726 if meas_ind is not None and meas_ind not in used_meas_inds[tt]: 

6727 used_meas_inds[tt].append(meas_ind) 

6728 good_inds = [] 

6729 for ii, existing in enumerate(self._extractable_hists): 

6730 used = existing.label in used_labels 

6731 if used: 

6732 continue 

6733 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist): 

6734 tt = existing.b_time_index + t_inds_after_b 

6735 used = meas_ind is not None and meas_ind in used_meas_inds[tt] 

6736 if used: 

6737 break 

6738 if not used: 

6739 good_inds.append(ii) 

6740 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds] 

6741 self._extractable_hists.extend(new_extract_hists) 

6742 

6743 def extract_states(self, update=True, calc_states=True): 

6744 """Extracts the best state estimates. 

6745 

6746 This extracts the best states from the distribution. It should be 

6747 called once per time step after the correction function. This calls 

6748 both the inner filters predict and correct functions so the keyword 

6749 arguments must contain any additional variables needed by those 

6750 functions. 

6751 

6752 Parameters 

6753 ---------- 

6754 update : bool, optional 

6755 Flag indicating if the label history should be updated. This should 

6756 be done once per timestep and can be disabled if calculating states 

6757 after the final timestep. The default is True. 

6758 calc_states : bool, optional 

6759 Flag indicating if the states should be calculated based on the 

6760 label history. This only needs to be done before the states are used. 

6761 It can simply be called once after the end of the simulation. The 

6762 default is true. 

6763 

6764 Returns 

6765 ------- 

6766 idx_cmp : int 

6767 Index of the hypothesis table used when extracting states. 

6768 """ 

6769 card = np.argmax(self._card_dist) 

6770 tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses]) 

6771 weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses]) 

6772 

6773 self._states = [[] for ii in range(self._time_index_cntr)] 

6774 self._labels = [[] for ii in range(self._time_index_cntr)] 

6775 self._covs = [[] for ii in range(self._time_index_cntr)] 

6776 

6777 if len(tracks_per_hyp) == 0: 

6778 return None 

6779 idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card)) 

6780 if update: 

6781 self._update_extract_hist(idx_cmp) 

6782 if calc_states: 

6783 for existing in self._extractable_hists: 

6784 for t_inds_after_b, (s, c) in enumerate( 

6785 zip(existing.states, existing.covs) 

6786 ): 

6787 tt = existing.b_time_index + t_inds_after_b 

6788 # if len(self._labels[tt]) == 0: 

6789 # self._states[tt] = [s] 

6790 # self._labels[tt] = [existing.label] 

6791 # self._covs[tt] = [c] 

6792 # else: 

6793 self._states[tt].append(s) 

6794 self._labels[tt].append(existing.label) 

6795 self._covs[tt].append(c) 

6796 if not update and not calc_states: 

6797 warnings.warn("Extracting states performed no actions") 

6798 return idx_cmp 

6799 

6800 def _ospa_setup_emat(self, state_dim, state_inds): 

6801 # get sizes 

6802 num_timesteps = len(self.states) 

6803 num_objs = 0 

6804 lbl_to_ind = {} 

6805 

6806 for lst in self.labels: 

6807 for lbl in lst: 

6808 if lbl is None: 

6809 continue 

6810 key = str(lbl) 

6811 if key not in lbl_to_ind: 

6812 lbl_to_ind[key] = num_objs 

6813 num_objs += 1 

6814 # create matrices 

6815 est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs)) 

6816 est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs)) 

6817 

6818 for tt, (lbl_lst, s_lst) in enumerate(zip(self.labels, self.states)): 

6819 for lbl, s in zip(lbl_lst, s_lst): 

6820 if lbl is None: 

6821 continue 

6822 obj_num = lbl_to_ind[str(lbl)] 

6823 est_mat[:, tt, obj_num] = s.ravel()[state_inds] 

6824 if self.save_covs: 

6825 for tt, (lbl_lst, c_lst) in enumerate(zip(self.labels, self.covariances)): 

6826 for lbl, c in zip(lbl_lst, c_lst): 

6827 if lbl is None: 

6828 continue 

6829 est_cov_mat[:, :, tt, lbl_to_ind[str(lbl)]] = c[state_inds][ 

6830 :, state_inds 

6831 ] 

6832 return est_mat, est_cov_mat 

6833 

6834 def calculate_ospa2( 

6835 self, 

6836 truth, 

6837 c, 

6838 p, 

6839 win_len, 

6840 true_covs=None, 

6841 core_method=SingleObjectDistance.MANHATTAN, 

6842 state_inds=None, 

6843 ): 

6844 """Calculates the OSPA(2) distance between the truth at all timesteps. 

6845 

6846 Wrapper for :func:`serums.distances.calculate_ospa2`. 

6847 

6848 Parameters 

6849 ---------- 

6850 truth : list 

6851 Each element represents a timestep and is a list of N x 1 numpy array, 

6852 one per true agent in the swarm. 

6853 c : float 

6854 Distance cutoff for considering a point properly assigned. This 

6855 influences how cardinality errors are penalized. For :math:`p = 1` 

6856 it is the penalty given false point estimate. 

6857 p : int 

6858 The power of the distance term. Higher values penalize outliers 

6859 more. 

6860 win_len : int 

6861 Number of samples to include in window. 

6862 core_method : :class:`serums.enums.SingleObjectDistance`, Optional 

6863 The main distance measure to use for the localization component. 

6864 The default value is :attr:`.SingleObjectDistance.MANHATTAN`. 

6865 true_covs : list, Optional 

6866 Each element represents a timestep and is a list of N x N numpy arrays 

6867 corresonponding to the uncertainty about the true states. Note the 

6868 order must be consistent with the truth data given. This is only 

6869 needed for core methods :attr:`SingleObjectDistance.HELLINGER`. The defautl 

6870 value is None. 

6871 state_inds : list, optional 

6872 Indices in the state vector to use, will be applied to the truth 

6873 data as well. The default is None which means the full state is 

6874 used. 

6875 """ 

6876 # error checking on optional input arguments 

6877 core_method = self._ospa_input_check(core_method, truth, true_covs) 

6878 

6879 # setup data structures 

6880 if state_inds is None: 

6881 state_dim = self._ospa_find_s_dim(truth) 

6882 state_inds = range(state_dim) 

6883 else: 

6884 state_dim = len(state_inds) 

6885 if state_dim is None: 

6886 warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation") 

6887 

6888 nt = len(self._states) 

6889 self.ospa2 = np.zeros(nt) 

6890 self.ospa2_localization = np.zeros(nt) 

6891 self.ospa2_cardinality = np.zeros(nt) 

6892 self._ospa2_params["core"] = core_method 

6893 self._ospa2_params["cutoff"] = c 

6894 self._ospa2_params["power"] = p 

6895 self._ospa2_params["win_len"] = win_len 

6896 return 

6897 true_mat, true_cov_mat = self._ospa_setup_tmat( 

6898 truth, state_dim, true_covs, state_inds 

6899 ) 

6900 est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds) 

6901 

6902 # find OSPA 

6903 ( 

6904 self.ospa2, 

6905 self.ospa2_localization, 

6906 self.ospa2_cardinality, 

6907 self._ospa2_params["core"], 

6908 self._ospa2_params["cutoff"], 

6909 self._ospa2_params["power"], 

6910 self._ospa2_params["win_len"], 

6911 ) = calculate_ospa2( 

6912 est_mat, 

6913 true_mat, 

6914 c, 

6915 p, 

6916 win_len, 

6917 core_method=core_method, 

6918 true_cov_mat=true_cov_mat, 

6919 est_cov_mat=est_cov_mat, 

6920 ) 

6921 

6922 def plot_states_labels( 

6923 self, 

6924 plt_inds, 

6925 ttl="Labeled State Trajectories", 

6926 x_lbl=None, 

6927 y_lbl=None, 

6928 meas_tx_fnc=None, 

6929 **kwargs, 

6930 ): 

6931 """Plots the best estimate for the states and labels. 

6932 

6933 This assumes that the states have been extracted. It's designed to plot 

6934 two of the state variables (typically x/y position). The error ellipses 

6935 are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses` 

6936 

6937 Keywrod arguments are processed with 

6938 :meth:`gncpy.plotting.init_plotting_opts`. This function 

6939 implements 

6940 

6941 - f_hndl 

6942 - true_states 

6943 - sig_bnd 

6944 - rng 

6945 - meas_inds 

6946 - lgnd_loc 

6947 

6948 Parameters 

6949 ---------- 

6950 plt_inds : list 

6951 List of indices in the state vector to plot 

6952 ttl : string, optional 

6953 Title of the plot. 

6954 x_lbl : string, optional 

6955 X-axis label for the plot. 

6956 y_lbl : string, optional 

6957 Y-axis label for the plot. 

6958 meas_tx_fnc : callable, optional 

6959 Takes in the measurement vector as an Nm x 1 numpy array and 

6960 returns a numpy array representing the states to plot (size 2). The 

6961 default is None. 

6962 

6963 Returns 

6964 ------- 

6965 Matplotlib figure 

6966 Instance of the matplotlib figure used 

6967 """ 

6968 opts = pltUtil.init_plotting_opts(**kwargs) 

6969 f_hndl = opts["f_hndl"] 

6970 true_states = opts["true_states"] 

6971 sig_bnd = opts["sig_bnd"] 

6972 rng = opts["rng"] 

6973 meas_inds = opts["meas_inds"] 

6974 lgnd_loc = opts["lgnd_loc"] 

6975 mrkr = opts["marker"] 

6976 

6977 if rng is None: 

6978 rng = rnd.default_rng(1) 

6979 if x_lbl is None: 

6980 x_lbl = "x-position" 

6981 if y_lbl is None: 

6982 y_lbl = "y-position" 

6983 meas_specs_given = ( 

6984 meas_inds is not None and len(meas_inds) == 2 

6985 ) or meas_tx_fnc is not None 

6986 plt_meas = meas_specs_given and self.save_measurements 

6987 show_sig = sig_bnd is not None and self.save_covs 

6988 

6989 s_lst = deepcopy(self.states) 

6990 l_lst = deepcopy(self.labels) 

6991 x_dim = None 

6992 

6993 if f_hndl is None: 

6994 f_hndl = plt.figure() 

6995 f_hndl.add_subplot(1, 1, 1) 

6996 # get state dimension 

6997 for states in s_lst: 

6998 if states is not None and len(states) > 0: 

6999 x_dim = states[0].size 

7000 break 

7001 # get unique labels 

7002 u_lbls = [] 

7003 for lbls in l_lst: 

7004 if lbls is None: 

7005 continue 

7006 for lbl in lbls: 

7007 if lbl not in u_lbls: 

7008 u_lbls.append(lbl) 

7009 cmap = pltUtil.get_cmap(len(u_lbls)) 

7010 

7011 # get array of all state values for each label 

7012 added_sig_lbl = False 

7013 added_true_lbl = False 

7014 added_state_lbl = False 

7015 added_meas_lbl = False 

7016 for c_idx, lbl in enumerate(u_lbls): 

7017 x = np.nan * np.ones((x_dim, len(s_lst))) 

7018 if show_sig: 

7019 sigs = [None] * len(s_lst) 

7020 for tt, lbls in enumerate(l_lst): 

7021 if lbls is None: 

7022 continue 

7023 if lbl in lbls: 

7024 ii = lbls.index(lbl) 

7025 if s_lst[tt][ii] is not None: 

7026 x[:, [tt]] = s_lst[tt][ii].copy() 

7027 if show_sig: 

7028 sig = np.zeros((2, 2)) 

7029 if self._covs[tt][ii] is not None: 

7030 sig[0, 0] = self._covs[tt][ii][plt_inds[0], plt_inds[0]] 

7031 sig[0, 1] = self._covs[tt][ii][plt_inds[0], plt_inds[1]] 

7032 sig[1, 0] = self._covs[tt][ii][plt_inds[1], plt_inds[0]] 

7033 sig[1, 1] = self._covs[tt][ii][plt_inds[1], plt_inds[1]] 

7034 else: 

7035 sig = None 

7036 sigs[tt] = sig 

7037 # plot 

7038 color = cmap(c_idx) 

7039 

7040 if show_sig: 

7041 for tt, sig in enumerate(sigs): 

7042 if sig is None: 

7043 continue 

7044 w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd) 

7045 if not added_sig_lbl: 

7046 s = r"${}\sigma$ Error Ellipses".format(sig_bnd) 

7047 e = Ellipse( 

7048 xy=x[plt_inds, tt], 

7049 width=w, 

7050 height=h, 

7051 angle=a, 

7052 zorder=-10000, 

7053 label=s, 

7054 ) 

7055 added_sig_lbl = True 

7056 else: 

7057 e = Ellipse( 

7058 xy=x[plt_inds, tt], 

7059 width=w, 

7060 height=h, 

7061 angle=a, 

7062 zorder=-10000, 

7063 ) 

7064 e.set_clip_box(f_hndl.axes[0].bbox) 

7065 e.set_alpha(0.2) 

7066 e.set_facecolor(color) 

7067 f_hndl.axes[0].add_patch(e) 

7068 settings = { 

7069 "color": color, 

7070 "markeredgecolor": "k", 

7071 "marker": mrkr, 

7072 "ls": "--", 

7073 } 

7074 if not added_state_lbl: 

7075 settings["label"] = "States" 

7076 # f_hndl.axes[0].scatter(x[plt_inds[0], :], x[plt_inds[1], :], 

7077 # color=color, edgecolors='k', 

7078 # label='States') 

7079 added_state_lbl = True 

7080 # else: 

7081 f_hndl.axes[0].plot(x[plt_inds[0], :], x[plt_inds[1], :], **settings) 

7082 

7083 s = "({}, {})".format(lbl[0], lbl[1]) 

7084 tmp = x.copy() 

7085 tmp = tmp[:, ~np.any(np.isnan(tmp), axis=0)] 

7086 f_hndl.axes[0].text( 

7087 tmp[plt_inds[0], 0], tmp[plt_inds[1], 0], s, color=color 

7088 ) 

7089 # if true states are available then plot them 

7090 if true_states is not None and any([len(x) > 0 for x in true_states]): 

7091 if x_dim is None: 

7092 for states in true_states: 

7093 if len(states) > 0: 

7094 x_dim = states[0].size 

7095 break 

7096 max_true = max([len(x) for x in true_states]) 

7097 x = np.nan * np.ones((x_dim, len(true_states), max_true)) 

7098 for tt, states in enumerate(true_states): 

7099 for ii, state in enumerate(states): 

7100 if state is not None and state.size > 0: 

7101 x[:, [tt], ii] = state.copy() 

7102 for ii in range(0, max_true): 

7103 if not added_true_lbl: 

7104 f_hndl.axes[0].plot( 

7105 x[plt_inds[0], :, ii], 

7106 x[plt_inds[1], :, ii], 

7107 color="k", 

7108 marker=".", 

7109 label="True Trajectories", 

7110 ) 

7111 added_true_lbl = True 

7112 else: 

7113 f_hndl.axes[0].plot( 

7114 x[plt_inds[0], :, ii], 

7115 x[plt_inds[1], :, ii], 

7116 color="k", 

7117 marker=".", 

7118 ) 

7119 if plt_meas: 

7120 meas_x = [] 

7121 meas_y = [] 

7122 for meas_tt in self._meas_tab: 

7123 if meas_tx_fnc is not None: 

7124 tx_meas = [meas_tx_fnc(m) for m in meas_tt] 

7125 mx_ii = [tm[0].item() for tm in tx_meas] 

7126 my_ii = [tm[1].item() for tm in tx_meas] 

7127 else: 

7128 mx_ii = [m[meas_inds[0]].item() for m in meas_tt] 

7129 my_ii = [m[meas_inds[1]].item() for m in meas_tt] 

7130 meas_x.extend(mx_ii) 

7131 meas_y.extend(my_ii) 

7132 color = (128 / 255, 128 / 255, 128 / 255) 

7133 meas_x = np.asarray(meas_x) 

7134 meas_y = np.asarray(meas_y) 

7135 if meas_x.size > 0: 

7136 if not added_meas_lbl: 

7137 f_hndl.axes[0].scatter( 

7138 meas_x, 

7139 meas_y, 

7140 zorder=-1, 

7141 alpha=0.35, 

7142 color=color, 

7143 marker="^", 

7144 label="Measurements", 

7145 ) 

7146 else: 

7147 f_hndl.axes[0].scatter( 

7148 meas_x, meas_y, zorder=-1, alpha=0.35, color=color, marker="^" 

7149 ) 

7150 f_hndl.axes[0].grid(True) 

7151 pltUtil.set_title_label( 

7152 f_hndl, 0, opts, ttl=ttl, x_lbl="x-position", y_lbl="y-position" 

7153 ) 

7154 if lgnd_loc is not None: 

7155 plt.legend(loc=lgnd_loc) 

7156 plt.tight_layout() 

7157 

7158 return f_hndl 

7159 

7160 

7161class _STMPMBMBase: 

7162 def __init__(self, **kwargs): 

7163 super().__init__(**kwargs) 

7164 

7165 def _init_filt_states(self, distrib): 

7166 filt_states = [None] * len(distrib.means) 

7167 states = [m.copy() for m in distrib.means] 

7168 covs = [None] * len(distrib.means) 

7169 

7170 weights = distrib.weights.copy() 

7171 self._baseFilter.dof = distrib.dof 

7172 for ii, scale in enumerate(distrib.scalings): 

7173 self._baseFilter.scale = scale.copy() 

7174 filt_states[ii] = self._baseFilter.save_filter_state() 

7175 if self.save_covs: 

7176 # no need to copy because cov is already a new object for the student's t-fitler 

7177 covs[ii] = self.filter.cov 

7178 return filt_states, weights, states, covs 

7179 

7180 def _gate_meas(self, meas, means, covs, **kwargs): 

7181 # TODO: check this implementation 

7182 if len(meas) == 0: 

7183 return [] 

7184 scalings = [] 

7185 for ent in self._track_tab: 

7186 scalings.extend(ent.probDensity.scalings) 

7187 valid = [] 

7188 for m, p in zip(means, scalings): 

7189 meas_mat = self.filter.get_meas_mat(m, **kwargs) 

7190 est = self.filter.get_est_meas(m, **kwargs) 

7191 factor = ( 

7192 self.filter.meas_noise_dof 

7193 * (self.filter.dof - 2) 

7194 / (self.filter.dof * (self.filter.meas_noise_dof - 2)) 

7195 ) 

7196 P_zz = meas_mat @ p @ meas_mat.T + factor * self.filter.meas_noise 

7197 inv_P = la.inv(P_zz) 

7198 

7199 for ii, z in enumerate(meas): 

7200 if ii in valid: 

7201 continue 

7202 innov = z - est 

7203 dist = innov.T @ inv_P @ innov 

7204 if dist < self.inv_chi2_gate: 

7205 valid.append(ii) 

7206 valid.sort() 

7207 return [meas[ii] for ii in valid] 

7208 

7209 

7210class STMPoissonMultiBernoulliMixture(_STMPMBMBase, PoissonMultiBernoulliMixture): 

7211 """Implementation of a STM-PMBM filter.""" 

7212 

7213 def __init__(self, **kwargs): 

7214 super().__init__(**kwargs) 

7215 

7216 

7217class STMLabeledPoissonMultiBernoulliMixture( 

7218 _STMPMBMBase, LabeledPoissonMultiBernoulliMixture 

7219): 

7220 """Implementation of a STM-LPMBM filter.""" 

7221 

7222 def __init__(self, **kwargs): 

7223 super().__init__(**kwargs) 

7224 

7225 

7226class _SMCPMBMBase: 

7227 def __init__( 

7228 self, compute_prob_detection=None, compute_prob_survive=None, **kwargs 

7229 ): 

7230 self.compute_prob_detection = compute_prob_detection 

7231 self.compute_prob_survive = compute_prob_survive 

7232 

7233 # for wrappers for predict/correct function to handle extra args for private functions 

7234 self._prob_surv_args = () 

7235 self._prob_det_args = () 

7236 

7237 super().__init__(**kwargs) 

7238 

7239 def _init_filt_states(self, distrib): 

7240 self._baseFilter.init_from_dist(distrib, make_copy=True) 

7241 filt_states = [ 

7242 self._baseFilter.save_filter_state(), 

7243 ] 

7244 states = [distrib.mean] 

7245 if self.save_covs: 

7246 covs = [ 

7247 distrib.covariance, 

7248 ] 

7249 else: 

7250 covs = [] 

7251 weights = [ 

7252 1, 

7253 ] # not needed so set to 1 

7254 

7255 return filt_states, weights, states, covs 

7256 

7257 def _calc_avg_prob_surv_death(self): 

7258 avg_prob_survive = np.zeros(len(self._track_tab)) 

7259 for tabidx, ent in enumerate(self._track_tab): 

7260 # TODO: fix hack so not using "private" variable outside class 

7261 p_surv = self.compute_prob_survive( 

7262 ent.filt_states[0]["_particleDist"].particles, *self._prob_surv_args 

7263 ) 

7264 avg_prob_survive[tabidx] = np.sum( 

7265 np.array(ent.filt_states[0]["_particleDist"].weights) * p_surv 

7266 ) 

7267 avg_prob_death = 1 - avg_prob_survive 

7268 

7269 return avg_prob_survive, avg_prob_death 

7270 

7271 def _inner_predict(self, timestep, filt_state, state, filt_args): 

7272 self.filter.load_filter_state(filt_state) 

7273 if self.filter._particleDist.num_particles > 0: 

7274 new_s = self.filter.predict(timestep, **filt_args) 

7275 

7276 # manually update weights to account for prob survive 

7277 # TODO: fix hack so not using "private" variable outside class 

7278 ps = self.compute_prob_survive( 

7279 self.filter._particleDist.particles, *self._prob_surv_args 

7280 ) 

7281 new_weights = [ 

7282 w * ps[ii] for ii, (p, w) in enumerate(self.filter._particleDist) 

7283 ] 

7284 tot = sum(new_weights) 

7285 if np.abs(tot) == np.inf: 

7286 w_lst = [np.inf] * len(new_weights) 

7287 else: 

7288 w_lst = [w / tot for w in new_weights] 

7289 self.filter._particleDist.update_weights(w_lst) 

7290 

7291 new_f_state = self.filter.save_filter_state() 

7292 if self.save_covs: 

7293 new_cov = self.filter.cov.copy() 

7294 else: 

7295 new_cov = None 

7296 else: 

7297 new_f_state = self.filter.save_filter_state() 

7298 new_s = state 

7299 new_cov = self.filter.cov 

7300 return new_f_state, new_s, new_cov 

7301 

7302 def predict(self, timestep, prob_surv_args=(), **kwargs): 

7303 """Prediction step of the SMC-GLMB filter. 

7304 

7305 This is a wrapper for the parent class to allow for extra parameters. 

7306 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict` for 

7307 additional details. 

7308 

7309 Parameters 

7310 ---------- 

7311 timestep : float 

7312 Current timestep. 

7313 prob_surv_args : tuple, optional 

7314 Additional arguments for the `compute_prob_survive` function. 

7315 The default is (). 

7316 **kwargs : dict, optional 

7317 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict` 

7318 """ 

7319 self._prob_surv_args = prob_surv_args 

7320 return super().predict(timestep, **kwargs) 

7321 

7322 def _calc_avg_prob_det_mdet(self, cor_tab): 

7323 avg_prob_detect = np.zeros(len(cor_tab)) 

7324 for tabidx, ent in enumerate(cor_tab): 

7325 # TODO: fix hack so not using "private" variable outside class 

7326 p_detect = self.compute_prob_detection( 

7327 ent.filt_states[0]["_particleDist"].particles, *self._prob_det_args 

7328 ) 

7329 avg_prob_detect[tabidx] = np.sum( 

7330 np.array(ent.filt_states[0]["_particleDist"].weights) * p_detect 

7331 ) 

7332 avg_prob_miss_detect = 1 - avg_prob_detect 

7333 

7334 return avg_prob_detect, avg_prob_miss_detect 

7335 

7336 def _inner_correct( 

7337 self, timestep, meas, filt_state, distrib_weight, state, filt_args 

7338 ): 

7339 self.filter.load_filter_state(filt_state) 

7340 if self.filter._particleDist.num_particles > 0: 

7341 cor_state, likely = self.filter.correct(timestep, meas, **filt_args)[0:2] 

7342 

7343 # manually update the particle weights to account for probability of detection 

7344 # TODO: fix hack so not using "private" variable outside class 

7345 pd = self.compute_prob_detection( 

7346 self.filter._particleDist.particles, *self._prob_det_args 

7347 ) 

7348 pd_weight = ( 

7349 pd * np.array(self.filter._particleDist.weights) + np.finfo(float).eps 

7350 ) 

7351 self.filter._particleDist.update_weights( 

7352 (pd_weight / np.sum(pd_weight)).tolist() 

7353 ) 

7354 

7355 # determine the partial cost, the remainder is calculated later from 

7356 # the hypothesis 

7357 new_w = np.sum(likely * pd_weight) # same as cost in this case 

7358 

7359 new_f_state = self.filter.save_filter_state() 

7360 new_s = cor_state 

7361 if self.save_covs: 

7362 new_c = self.filter.cov 

7363 else: 

7364 new_c = None 

7365 else: 

7366 new_f_state = self.filter.save_filter_state() 

7367 new_s = state 

7368 new_c = self.filter.cov 

7369 new_w = 0 

7370 return new_f_state, new_s, new_c, new_w 

7371 

7372 def correct(self, timestep, meas, prob_det_args=(), **kwargs): 

7373 """Correction step of the SMC-GLMB filter. 

7374 

7375 This is a wrapper for the parent class to allow for extra parameters. 

7376 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct` for 

7377 additional details. 

7378 

7379 Parameters 

7380 ---------- 

7381 timestep : float 

7382 Current timestep. 

7383 prob_det_args : tuple, optional 

7384 Additional arguments for the `compute_prob_detection` function. 

7385 The default is (). 

7386 **kwargs : dict, optional 

7387 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct` 

7388 """ 

7389 self._prob_det_args = prob_det_args 

7390 return super().correct(timestep, meas, **kwargs) 

7391 

7392 def extract_most_prob_states(self, thresh, **kwargs): 

7393 """Extracts themost probable states. 

7394 

7395 .. todo:: 

7396 Implement this function for the SMC-GLMB filter 

7397 

7398 Raises 

7399 ------ 

7400 RuntimeWarning 

7401 Function must be implemented. 

7402 """ 

7403 warnings.warn("Not implemented for this class") 

7404 

7405 

7406class SMCPoissonMultiBernoulliMixture(_SMCPMBMBase, PoissonMultiBernoulliMixture): 

7407 """Implementation of a Sequential Monte Carlo PMBM filter. 

7408 

7409 This filter does not account for agents spawned from existing tracks, only agents 

7410 birthed from the given birth model. 

7411 

7412 Attributes 

7413 ---------- 

7414 compute_prob_detection : callable 

7415 Function that takes a list of particles as the first argument and `*args` 

7416 as the next. Returns the probability of detection for each particle as a list. 

7417 compute_prob_survive : callable 

7418 Function that takes a list of particles as the first argument and `*args` as 

7419 the next. Returns the average probability of survival for each particle as a list. 

7420 """ 

7421 

7422 def __init__(self, **kwargs): 

7423 super().__init__(**kwargs) 

7424 

7425 

7426class SMCLabeledPoissonMultiBernoulliMixture( 

7427 _SMCPMBMBase, LabeledPoissonMultiBernoulliMixture 

7428): 

7429 """Implementation of a Sequential Monte Carlo LPMBM filter. 

7430 

7431 This filter does not account for agents spawned from existing tracks, only agents 

7432 birthed from the given birth model. 

7433 

7434 Attributes 

7435 ---------- 

7436 compute_prob_detection : callable 

7437 Function that takes a list of particles as the first argument and `*args` 

7438 as the next. Returns the probability of detection for each particle as a list. 

7439 compute_prob_survive : callable 

7440 Function that takes a list of particles as the first argument and `*args` as 

7441 the next. Returns the average probability of survival for each particle as a list. 

7442 """ 

7443 

7444 def __init__(self, **kwargs): 

7445 super().__init__(**kwargs) 

7446 

7447 

class _IMMPMBMBase:
    """Mixin holding the IMM-specific filter-state helpers of the PMBM variants.

    NOTE(review): initialisation goes through ``self._baseFilter`` while the
    predict/correct helpers use ``self.filter``; presumably both refer to the
    same IMM filter - confirm in the base tracker.
    """

    def _init_filt_states(self, distrib):
        """Create one saved IMM filter state per component of ``distrib``.

        Every inner model of the IMM filter is initialised with the same
        component mean and covariance.

        Parameters
        ----------
        distrib : object
            Mixture exposing ``means``, ``covariances`` and ``weights``.

        Returns
        -------
        filt_states : list
            Saved filter state for each component (``None`` for any component
            left unpaired when ``covariances`` is shorter than ``means``).
        weights : object
            Copy of ``distrib.weights``.
        states : list
            Copies of the component means.
        covs : list
            Copies of the component covariances when ``save_covs`` is set,
            otherwise an empty list.
        """
        states = [mean.copy() for mean in distrib.means]
        covs = [cov.copy() for cov in distrib.covariances] if self.save_covs else []
        weights = distrib.weights.copy()

        filt_states = [None] * len(distrib.means)
        for idx, (mean, cov) in enumerate(zip(distrib.means, distrib.covariances)):
            n_inner = len(self._baseFilter.in_filt_list)
            self._baseFilter.initialize_states([mean] * n_inner, [cov] * n_inner)
            filt_states[idx] = self._baseFilter.save_filter_state()
        return filt_states, weights, states, covs

    def _inner_predict(self, timestep, filt_state, state, filt_args):
        """Run one prediction of the inner filter from a saved state.

        ``state`` is accepted for interface compatibility but not used.

        Returns
        -------
        tuple
            ``(saved filter state, predicted state, covariance copy or None)``.
        """
        self.filter.load_filter_state(filt_state)
        predicted = self.filter.predict(timestep, **filt_args)
        saved = self.filter.save_filter_state()
        cov = self.filter.cov.copy() if self.save_covs else None
        return saved, predicted, cov

    def _inner_correct(
        self, timestep, meas, filt_state, distrib_weight, state, filt_args
    ):
        """Run one correction of the inner filter from a saved state.

        ``state`` is accepted for interface compatibility but not used.

        Returns
        -------
        tuple
            ``(saved filter state, corrected state, covariance copy or None,
            distrib_weight * measurement likelihood)``.
        """
        self.filter.load_filter_state(filt_state)
        corrected, likelihood = self.filter.correct(timestep, meas, **filt_args)
        saved = self.filter.save_filter_state()
        cov = self.filter.cov.copy() if self.save_covs else None
        return saved, corrected, cov, distrib_weight * likelihood

7494 

class IMMPoissonMultiBernoulliMixture(_IMMPMBMBase, PoissonMultiBernoulliMixture):
    """An implementation of the IMM-PMBM algorithm.

    Combines the IMM filter-state helpers of ``_IMMPMBMBase`` with
    ``PoissonMultiBernoulliMixture``. All keyword arguments are forwarded
    unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

7501 

class IMMLabeledPoissonMultiBernoulliMixture(
    _IMMPMBMBase, LabeledPoissonMultiBernoulliMixture
):
    """An implementation of the IMM-LPMBM algorithm.

    Combines the IMM filter-state helpers of ``_IMMPMBMBase`` with
    ``LabeledPoissonMultiBernoulliMixture``. All keyword arguments are
    forwarded unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

7510 

class MSPoissonMultiBernoulliMixture(PoissonMultiBernoulliMixture):
    """An Implementation of the Multiple Sensor PMBM Filter.

    :meth:`correct` receives the measurements grouped per sensor. Every
    cross-sensor combination (one measurement from each sensor) is formed with
    :func:`itertools.product`. Each combination is then passed to the track and
    birth corrections as a single fused measurement.
    """

    # Need measurement association history to incorporate meas inds from each sensor
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args):
        """Build the corrected track table and the association cost matrix.

        Parameters
        ----------
        num_meas : int
            Number of fused measurement combinations (``len(meas)``).
        meas : list
            Fused measurement combinations, one entry per combination.
        timestep : float
            Current timestep.
        comb_inds : list
            For each combination, the measurement index used from each sensor.
            It is appended to the ``meas_assoc_hist`` of every corrected entry.
        filt_args : dict
            Keyword arguments forwarded to the entry-correction helpers.

        Returns
        -------
        up_tab : list
            Table laid out as follows:

            * ``[0, num_pred)`` holds the missed-detection copies of the
              existing tracks.
            * ``num_pred * (emm + 1) + ii`` holds track ``ii`` corrected by
              combination ``emm``.
            * ``(num_meas + 1) * num_pred + emm * num_birth + ii`` holds birth
              model ``ii`` corrected by combination ``emm``.

            Corrected entries may be ``None``.
        all_cost_m : numpy.ndarray
            ``num_meas x (num_pred + num_birth * num_meas)`` matrix. Column
            ``ii < num_pred`` holds track ``ii``'s cost and column
            ``num_pred + emm`` the birth cost of combination ``emm``. All other
            entries stay zero.
        """
        num_pred = len(self._track_tab)
        num_birth = len(self.birth_terms)
        up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)

        # Missed detection updates
        for ii, track in enumerate(self._track_tab):
            up_tab[ii] = self._TabEntry().setup(track)
            sum_non_exist_prob = (
                1
                - up_tab[ii].exist_prob
                + up_tab[ii].exist_prob * self.prob_miss_detection
            )
            up_tab[ii].distrib_weights_hist.append(
                [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
            )
            up_tab[ii].exist_prob = (
                up_tab[ii].exist_prob * self.prob_miss_detection
            ) / sum_non_exist_prob
            up_tab[ii].meas_assoc_hist.append(None)

        all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))

        # Update for all existing tracks
        for emm, z in enumerate(meas):
            for ii, ent in enumerate(self._track_tab):
                s_to_ii = num_pred * emm + ii + num_pred
                up_tab[s_to_ii], cost = self._correct_track_tab_entry(
                    z, ent, timestep, filt_args
                )
                if up_tab[s_to_ii] is not None:
                    up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
                all_cost_m[emm, ii] = cost

        # Update for all potential new births
        for emm, z in enumerate(meas):
            for ii, b_model in enumerate(self.birth_terms):
                s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
                up_tab[s_to_ii], cost = self._correct_birth_tab_entry(
                    z, b_model, timestep, filt_args
                )
                if up_tab[s_to_ii] is not None:
                    up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
                # NOTE(review): every birth model writes the same column, so with
                # more than one birth model only the last model's cost is kept.
                all_cost_m[emm, emm + num_pred] = cost
        return up_tab, all_cost_m

    @staticmethod
    def _assignment_to_track_list(assignment, p_hyp, row_meas, num_pred, num_meas):
        """Convert one Murty assignment into indices of the corrected track table.

        Parameters
        ----------
        assignment : iterable
            Entry ``ii`` is the cost-matrix column assigned to row ``ii``.
            Columns appear to be 1-based (see ``track_set[ms - 1]``). The first
            ``len(p_hyp.track_set)`` columns are the prior hypothesis's tracks;
            the rest are birth columns.
        p_hyp : object
            Prior hypothesis the assignment was generated from.
        row_meas : sequence
            ``row_meas[ii]`` is the fused-measurement index of cost-matrix row
            ``ii``.
        num_pred, num_meas : int
            Number of predicted tracks and of fused measurements.

        Returns
        -------
        list
            Indices into the table built by :meth:`_gen_cor_tab`.

        Notes
        -----
        NOTE(review): birth indices assume a single birth model, with each
        measurement assigned to its own birth column. The last two branches of
        the track case are carried over unchanged and do not obviously match the
        table layout; verify them.
        """
        birth_offset = (num_meas + 1) * num_pred
        if p_hyp.num_tracks == 0:  # all clutter / births
            return [birth_offset + row_meas[ii] for ii in range(len(assignment))]

        num_tracks = len(p_hyp.track_set)
        track_list = []
        for ii, ms in enumerate(assignment):
            if num_tracks >= ms:
                # Existing track corrected by this row's measurement
                track_list.append(
                    (row_meas[ii] + 1) * num_pred + p_hyp.track_set[ms - 1]
                )
            elif num_tracks == len(assignment):
                track_list.append(num_pred * ms - ii * (num_pred - 1))
            elif num_tracks < len(assignment):
                # New track born from this row's measurement
                track_list.append(birth_offset + row_meas[ii])
            else:
                track_list.append(birth_offset - (ms - num_tracks - 1))
        return track_list

    def _gen_cor_hyps(
        self,
        num_meas,
        avg_prob_detect,
        avg_prob_miss_detect,
        all_cost_m,
        meas_combs,
        cor_tab,
    ):
        """Generate the corrected global hypotheses.

        Parameters
        ----------
        num_meas : int
            Number of fused measurement combinations.
        avg_prob_detect : list
            Not used; kept for interface compatibility.
        avg_prob_miss_detect : list
            Average missed-detection probability per predicted track, indexed by
            track-table index.
        all_cost_m : numpy.ndarray
            Cost matrix produced by :meth:`_gen_cor_tab`.
        meas_combs : list
            Sets of fused-measurement indices that may be assigned together.
        cor_tab : list
            Not used; kept for interface compatibility.

        Returns
        -------
        list
            Corrected hypotheses. Their ``assoc_prob`` values are normalised with
            ``log_sum_exp``. When ``num_meas == 0`` the prior hypothesis objects
            are modified in place and reused.
        """
        num_pred = len(self._track_tab)
        up_hyps = []
        if not meas_combs:
            meas_combs = np.arange(0, np.shape(all_cost_m)[0]).tolist()

        if num_meas == 0:
            for hyp in self._hypotheses:
                pmd_log = np.sum(
                    [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
                )
                hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
                up_hyps.append(hyp)
        else:
            clutter = self.clutter_rate * self.clutter_den
            ss_w = 0
            for p_hyp in self._hypotheses:
                ss_w += np.sqrt(p_hyp.assoc_prob)
            for p_hyp in self._hypotheses:
                # Neither quantity depends on the measurement set; compute once.
                pmd_log = np.sum(
                    [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
                )
                # Number of assignments requested from Murty's algorithm
                m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
                m = int(m.item())

                for ind_lst in meas_combs:
                    if len(meas_combs) == 1:
                        # Only one candidate set: every fused measurement is a row.
                        row_meas = list(range(num_meas))
                        birth_cols = np.arange(num_pred, num_pred + num_meas).tolist()
                        if p_hyp.num_tracks == 0:  # all clutter
                            inds = birth_cols
                        else:
                            inds = p_hyp.track_set + birth_cols
                        cost_m = all_cost_m[:, inds]
                    else:
                        # The birth column of fused measurement emm is num_pred + emm.
                        row_meas = ind_lst
                        birth_cols = [x + num_pred for x in ind_lst]
                        if p_hyp.num_tracks == 0:  # all clutter
                            inds = birth_cols
                        else:
                            inds = p_hyp.track_set + birth_cols
                        cost_m = all_cost_m[:, inds][ind_lst, :]

                    # Clamp before taking -log: inf -> largest float, <= 0 -> eps.
                    cost_m[cost_m >= np.inf] = np.finfo(float).max
                    cost_m[cost_m <= 0.0] = np.finfo(float).eps
                    neg_log = -np.log(cost_m)

                    assigns, costs = murty_m_best_all_meas_assigned(neg_log, m)
                    for a, c in zip(assigns, costs):
                        new_hyp = self._HypothesisHelper()
                        new_hyp.assoc_prob = (
                            -self.clutter_rate
                            + num_meas * np.log(clutter)
                            + pmd_log
                            + np.log(p_hyp.assoc_prob)
                            - c
                        )
                        new_hyp.track_set = self._assignment_to_track_list(
                            a, p_hyp, row_meas, num_pred, num_meas
                        )
                        up_hyps.append(new_hyp)

        lse = log_sum_exp([x.assoc_prob for x in up_hyps])
        for hyp in up_hyps:
            hyp.assoc_prob = np.exp(hyp.assoc_prob - lse)
        return up_hyps

    @staticmethod
    def _gen_poss_meas_combs(comb_inds, num_per_set):
        """Find sets of combinations that never reuse a sensor's measurement.

        Parameters
        ----------
        comb_inds : list
            Per-sensor measurement indices of every fused combination.
        num_per_set : int
            Number of combinations in each returned set.

        Returns
        -------
        list
            Each entry is an ascending list of indices into ``comb_inds``. No
            two of the referenced combinations share a measurement index for
            the same sensor.
        """
        poss_meas_combs = []
        for idx_set in itertools.combinations(range(len(comb_inds)), num_per_set):
            conflict = any(
                any(a == b for a, b in zip(comb_inds[i], comb_inds[j]))
                for i, j in itertools.combinations(idx_set, 2)
            )
            if not conflict:
                poss_meas_combs.append(list(idx_set))
        return poss_meas_combs

    def correct(self, timestep, meas, filt_args=None):
        """Correction step of the MS-PMBM filter.

        Notes
        -----
        This step forms every cross-sensor measurement combination and corrects
        the track table against each one. It then generates the corrected
        hypotheses and updates the cardinality distribution. Gating is not
        implemented yet: if ``gating_on`` is set, a :class:`RuntimeWarning` is
        issued and gating is skipped.

        Parameters
        ----------
        timestep: float
            Current timestep.
        meas : list
            One entry per sensor. Each entry is a list of Nm x 1 numpy arrays,
            each representing a measurement from that sensor.
        filt_args : dict, optional
            keyword arguments to pass to the inner filters correct function.
            The default is None, which is treated as an empty dict.

        Returns
        -------
        None
        """
        if filt_args is None:
            filt_args = {}

        all_combs = list(itertools.product(*meas))
        # TODO: Add method for only single measurements to be assoc'd i.e.
        # all_combs needs to include single measurement options
        if self.gating_on:
            warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
        if self.save_measurements:
            self._meas_tab.append(deepcopy(meas))

        num_meas = len(all_combs)
        # Each set of jointly assignable combinations has as many members as
        # the sensor with the fewest measurements.
        min_meas_per_sens = min(len(x) for x in meas)
        # comb_inds[k] lists the per-sensor measurement indices of all_combs[k]
        comb_inds = [
            list(ele)
            for ele in itertools.product(*(np.arange(0, len(x)) for x in meas))
        ]
        poss_meas_combs = self._gen_poss_meas_combs(comb_inds, min_meas_per_sens)

        cor_tab, all_cost_m = self._gen_cor_tab(
            num_meas, all_combs, timestep, comb_inds, filt_args
        )

        avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab)

        cor_hyps = self._gen_cor_hyps(
            num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, poss_meas_combs, cor_tab
        )

        self._track_tab = cor_tab
        self._hypotheses = cor_hyps
        self._card_dist = self._calc_card_dist(self._hypotheses)
        self._clean_updates()

7796 

7797 

class MSLabeledPoissonMultiBernoulliMixture(LabeledPoissonMultiBernoulliMixture):
    """An implementation of the Multiple Sensor Labeled PMBM filter.

    :meth:`correct` receives the measurements grouped per sensor. Every
    cross-sensor combination (one measurement from each sensor) is formed with
    :func:`itertools.product`. Each combination is then passed to the track and
    birth corrections as a single fused measurement. Birth entries are labeled
    with ``(rounded timestep, birth model index)``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args):
        """Build the corrected track table and the association cost matrix.

        Parameters
        ----------
        num_meas : int
            Number of fused measurement combinations (``len(meas)``).
        meas : list
            Fused measurement combinations, one entry per combination.
        timestep : float
            Current timestep; also used, rounded to ``decimal_places``, in the
            label of new birth entries.
        comb_inds : list
            For each combination, the measurement index used from each sensor.
            It is appended to the ``meas_assoc_hist`` of every corrected entry.
        filt_args : dict
            Keyword arguments forwarded to the entry-correction helpers.

        Returns
        -------
        up_tab : list
            Table laid out as follows:

            * ``[0, num_pred)`` holds the missed-detection copies of the
              existing tracks.
            * ``num_pred * (emm + 1) + ii`` holds track ``ii`` corrected by
              combination ``emm``.
            * ``(num_meas + 1) * num_pred + emm * num_birth + ii`` holds birth
              model ``ii`` corrected by combination ``emm``.

            Corrected entries may be ``None``.
        all_cost_m : numpy.ndarray
            ``num_meas x (num_pred + num_birth * num_meas)`` matrix. Column
            ``ii < num_pred`` holds track ``ii``'s cost and column
            ``num_pred + emm`` the birth cost of combination ``emm``. All other
            entries stay zero.
        """
        num_pred = len(self._track_tab)
        num_birth = len(self.birth_terms)
        up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)

        # Missed detection updates
        for ii, track in enumerate(self._track_tab):
            up_tab[ii] = self._TabEntry().setup(track)
            sum_non_exist_prob = (
                1
                - up_tab[ii].exist_prob
                + up_tab[ii].exist_prob * self.prob_miss_detection
            )
            up_tab[ii].distrib_weights_hist.append(
                [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
            )
            up_tab[ii].exist_prob = (
                up_tab[ii].exist_prob * self.prob_miss_detection
            ) / sum_non_exist_prob
            up_tab[ii].meas_assoc_hist.append(None)

        all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))

        # Update for all existing tracks
        for emm, z in enumerate(meas):
            for ii, ent in enumerate(self._track_tab):
                s_to_ii = num_pred * emm + ii + num_pred
                up_tab[s_to_ii], cost = self._correct_track_tab_entry(
                    z, ent, timestep, filt_args
                )
                if up_tab[s_to_ii] is not None:
                    up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
                all_cost_m[emm, ii] = cost

        # Update for all potential new births
        for emm, z in enumerate(meas):
            for ii, b_model in enumerate(self.birth_terms):
                s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
                up_tab[s_to_ii], cost = self._correct_birth_tab_entry(
                    z, b_model, timestep, filt_args
                )
                if up_tab[s_to_ii] is not None:
                    up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
                    up_tab[s_to_ii].label = (round(timestep, self.decimal_places), ii)
                # NOTE(review): every birth model writes the same column, so with
                # more than one birth model only the last model's cost is kept.
                all_cost_m[emm, emm + num_pred] = cost
        return up_tab, all_cost_m

    @staticmethod
    def _assignment_to_track_list(assignment, p_hyp, row_meas, num_pred, num_meas):
        """Convert one Murty assignment into indices of the corrected track table.

        Parameters
        ----------
        assignment : iterable
            Entry ``ii`` is the cost-matrix column assigned to row ``ii``.
            Columns appear to be 1-based (see ``track_set[ms - 1]``). The first
            ``len(p_hyp.track_set)`` columns are the prior hypothesis's tracks;
            the rest are birth columns.
        p_hyp : object
            Prior hypothesis the assignment was generated from.
        row_meas : sequence
            ``row_meas[ii]`` is the fused-measurement index of cost-matrix row
            ``ii``.
        num_pred, num_meas : int
            Number of predicted tracks and of fused measurements.

        Returns
        -------
        list
            Indices into the table built by :meth:`_gen_cor_tab`.

        Notes
        -----
        NOTE(review): birth indices assume a single birth model, with each
        measurement assigned to its own birth column. The last two branches of
        the track case are carried over unchanged and do not obviously match the
        table layout; verify them.
        """
        birth_offset = (num_meas + 1) * num_pred
        if p_hyp.num_tracks == 0:  # all clutter / births
            # Same birth formula as the track case below; the previous
            # ind * (num_pred + 1) + num_pred * num_meas only matched the
            # table layout when num_pred == 0.
            return [birth_offset + row_meas[ii] for ii in range(len(assignment))]

        num_tracks = len(p_hyp.track_set)
        track_list = []
        for ii, ms in enumerate(assignment):
            if num_tracks >= ms:
                # Existing track corrected by this row's measurement
                track_list.append(
                    (row_meas[ii] + 1) * num_pred + p_hyp.track_set[ms - 1]
                )
            elif num_tracks == len(assignment):
                track_list.append(num_pred * ms - ii * (num_pred - 1))
            elif num_tracks < len(assignment):
                # New track born from this row's measurement
                track_list.append(birth_offset + row_meas[ii])
            else:
                track_list.append(birth_offset - (ms - num_tracks - 1))
        return track_list

    def _gen_cor_hyps(
        self,
        num_meas,
        avg_prob_detect,
        avg_prob_miss_detect,
        all_cost_m,
        meas_combs,
        cor_tab,
    ):
        """Generate the corrected global hypotheses.

        Parameters
        ----------
        num_meas : int
            Number of fused measurement combinations.
        avg_prob_detect : list
            Not used; kept for interface compatibility.
        avg_prob_miss_detect : list
            Average missed-detection probability per predicted track, indexed by
            track-table index.
        all_cost_m : numpy.ndarray
            Cost matrix produced by :meth:`_gen_cor_tab`.
        meas_combs : list
            Sets of fused-measurement indices that may be assigned together.
        cor_tab : list
            Not used; kept for interface compatibility.

        Returns
        -------
        list
            Corrected hypotheses. Their ``assoc_prob`` values are normalised with
            ``log_sum_exp``. When ``num_meas == 0`` the prior hypothesis objects
            are modified in place and reused.
        """
        num_pred = len(self._track_tab)
        up_hyps = []
        if not meas_combs:
            meas_combs = np.arange(0, np.shape(all_cost_m)[0]).tolist()

        if num_meas == 0:
            for hyp in self._hypotheses:
                pmd_log = np.sum(
                    [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
                )
                hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
                up_hyps.append(hyp)
        else:
            clutter = self.clutter_rate * self.clutter_den
            ss_w = 0
            for p_hyp in self._hypotheses:
                ss_w += np.sqrt(p_hyp.assoc_prob)
            for p_hyp in self._hypotheses:
                # Neither quantity depends on the measurement set; compute once.
                pmd_log = np.sum(
                    [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
                )
                # Number of assignments requested from Murty's algorithm
                m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
                m = int(m.item())

                for ind_lst in meas_combs:
                    if len(meas_combs) == 1:
                        # Only one candidate set: every fused measurement is a row.
                        row_meas = list(range(num_meas))
                        birth_cols = np.arange(num_pred, num_pred + num_meas).tolist()
                        if p_hyp.num_tracks == 0:  # all clutter
                            inds = birth_cols
                        else:
                            inds = p_hyp.track_set + birth_cols
                        cost_m = all_cost_m[:, inds]
                    else:
                        # The birth column of fused measurement emm is num_pred + emm.
                        row_meas = ind_lst
                        birth_cols = [x + num_pred for x in ind_lst]
                        if p_hyp.num_tracks == 0:  # all clutter
                            inds = birth_cols
                        else:
                            inds = p_hyp.track_set + birth_cols
                        cost_m = all_cost_m[:, inds][ind_lst, :]

                    # Clamp before taking -log: inf -> largest float, <= 0 -> eps.
                    cost_m[cost_m >= np.inf] = np.finfo(float).max
                    cost_m[cost_m <= 0.0] = np.finfo(float).eps
                    neg_log = -np.log(cost_m)

                    assigns, costs = murty_m_best_all_meas_assigned(neg_log, m)
                    for a, c in zip(assigns, costs):
                        new_hyp = self._HypothesisHelper()
                        new_hyp.assoc_prob = (
                            -self.clutter_rate
                            + num_meas * np.log(clutter)
                            + pmd_log
                            + np.log(p_hyp.assoc_prob)
                            - c
                        )
                        new_hyp.track_set = self._assignment_to_track_list(
                            a, p_hyp, row_meas, num_pred, num_meas
                        )
                        up_hyps.append(new_hyp)

        lse = log_sum_exp([x.assoc_prob for x in up_hyps])
        for hyp in up_hyps:
            hyp.assoc_prob = np.exp(hyp.assoc_prob - lse)
        return up_hyps

    @staticmethod
    def _gen_poss_meas_combs(comb_inds, num_per_set):
        """Find sets of combinations that never reuse a sensor's measurement.

        Parameters
        ----------
        comb_inds : list
            Per-sensor measurement indices of every fused combination.
        num_per_set : int
            Number of combinations in each returned set.

        Returns
        -------
        list
            Each entry is an ascending list of indices into ``comb_inds``. No
            two of the referenced combinations share a measurement index for
            the same sensor.
        """
        poss_meas_combs = []
        for idx_set in itertools.combinations(range(len(comb_inds)), num_per_set):
            conflict = any(
                any(a == b for a, b in zip(comb_inds[i], comb_inds[j]))
                for i, j in itertools.combinations(idx_set, 2)
            )
            if not conflict:
                poss_meas_combs.append(list(idx_set))
        return poss_meas_combs

    def correct(self, timestep, meas, filt_args=None):
        """Correction step of the MS-LPMBM filter.

        Notes
        -----
        This step forms every cross-sensor measurement combination and corrects
        the track table against each one. It then generates the corrected
        hypotheses and updates the cardinality distribution. Gating is not
        implemented yet: if ``gating_on`` is set, a :class:`RuntimeWarning` is
        issued and gating is skipped.

        Parameters
        ----------
        timestep: float
            Current timestep.
        meas : list
            One entry per sensor. Each entry is a list of Nm x 1 numpy arrays,
            each representing a measurement from that sensor.
        filt_args : dict, optional
            keyword arguments to pass to the inner filters correct function.
            The default is None, which is treated as an empty dict.

        Returns
        -------
        None
        """
        if filt_args is None:
            filt_args = {}

        # TODO: handle sensors that report no measurements; product() yields no
        # combinations at all in that case.
        all_combs = list(itertools.product(*meas))
        if self.gating_on:
            warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
        if self.save_measurements:
            self._meas_tab.append(deepcopy(meas))

        num_meas = len(all_combs)
        # TODO: iterate the set size from min_meas_per_sens down to 0 so that
        # sets with fewer jointly assigned combinations are also considered.
        min_meas_per_sens = min(len(x) for x in meas)
        # comb_inds[k] lists the per-sensor measurement indices of all_combs[k]
        comb_inds = [
            list(ele)
            for ele in itertools.product(*(np.arange(0, len(x)) for x in meas))
        ]
        poss_meas_combs = self._gen_poss_meas_combs(comb_inds, min_meas_per_sens)

        cor_tab, all_cost_m = self._gen_cor_tab(
            num_meas, all_combs, timestep, comb_inds, filt_args
        )

        avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab)

        cor_hyps = self._gen_cor_hyps(
            num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, poss_meas_combs, cor_tab
        )

        self._track_tab = cor_tab
        self._hypotheses = cor_hyps
        self._card_dist = self._calc_card_dist(self._hypotheses)
        self._clean_updates()

8126 

8127 

class MSIMMPoissonMultiBernoulliMixture(_IMMPMBMBase, MSPoissonMultiBernoulliMixture):
    """Multiple-sensor PMBM filter using IMM inner filters.

    Combines the IMM filter-state helpers of ``_IMMPMBMBase`` with
    ``MSPoissonMultiBernoulliMixture``. All keyword arguments are forwarded
    unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

8131 

8132 

class MSIMMLabeledPoissonMultiBernoulliMixture(
    _IMMPMBMBase, MSLabeledPoissonMultiBernoulliMixture
):
    """Multiple-sensor labeled PMBM filter using IMM inner filters.

    Combines the IMM filter-state helpers of ``_IMMPMBMBase`` with
    ``MSLabeledPoissonMultiBernoulliMixture``. All keyword arguments are
    forwarded unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)