Coverage for src/gncpy/filters/ 68%

234 statements  

« prev     ^ index     » next v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.random as rnd 

3import matplotlib.pyplot as plt 

4from copy import deepcopy 

5from warnings import warn 


7import gncpy.distributions as gdistrib 

8import gncpy.errors as gerr 

9import gncpy.plotting as pltUtil 

10from gncpy.filters.bayes_filter import BayesFilter 



class ParticleFilter(BayesFilter):
    """Implements a basic Particle Filter.

    Notes
    -----
    The implementation is based on
    :cite:`Simon2006_OptimalStateEstimationKalmanHInfinityandNonlinearApproaches`
    and uses Sampling-Importance Resampling (SIR) sampling. Other resampling
    methods can be added in derived classes.

    Attributes
    ----------
    require_copy_prop_parts : bool
        Flag indicating if the propagated particles need to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    require_copy_can_dist : bool
        Flag indicating if a candidate distribution needs to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    rng : numpy random generator
        Random number generator used for proposal sampling and resampling.
        Defaults to ``numpy.random.default_rng(1)`` so runs are reproducible.
    prop_parts : list
        States produced by the dynamics model for each particle during the
        most recent call to :meth:`predict` (before proposal sampling).
    """

    require_copy_prop_parts = True
    require_copy_can_dist = False

    def __init__(
        self,
        dyn_obj=None,
        dyn_fun=None,
        part_dist=None,
        transition_prob_fnc=None,
        rng=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamics object, see :meth:`set_state_model`. The default is None.
        dyn_fun : callable, optional
            Dynamics function, see :meth:`set_state_model`. The default is
            None.
        part_dist : :class:`gncpy.distributions.ParticleDistribution`, optional
            Initial particle distribution; a deep copy is stored. The default
            is None, which starts from an empty distribution.
        transition_prob_fnc : callable, optional
            Transition probability function, see :attr:`transition_prob_fnc`.
            The default is None.
        rng : numpy random generator, optional
            Random number generator. The default is None, which creates
            ``numpy.random.default_rng(1)``.
        **kwargs : dict
            Passed on to the base class constructor.
        """
        self.__meas_likelihood_fnc = None
        self.__proposal_sampling_fnc = None
        self.__proposal_fnc = None
        # Honour the constructor argument (it was previously discarded).
        self.__transition_prob_fnc = transition_prob_fnc

        if rng is None:
            # Fixed seed so results are reproducible when no generator is given.
            rng = rnd.default_rng(1)
        self.rng = rng

        self._dyn_fnc = None
        self._dyn_obj = None

        self._meas_mat = None
        self._meas_fnc = None

        if dyn_obj is not None or dyn_fun is not None:
            self.set_state_model(dyn_obj=dyn_obj, dyn_fun=dyn_fun)
        self._particleDist = gdistrib.ParticleDistribution()
        if part_dist is not None:
            self.init_from_dist(part_dist)
        self.prop_parts = []

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        The particle distribution and propagated particles are deep copied;
        the model functions/objects and the random generator are stored by
        reference.

        Returns
        -------
        filt_state : dict
            Base class state extended with this filter's variables.
        """
        filt_state = super().save_filter_state()

        filt_state["__meas_likelihood_fnc"] = self.__meas_likelihood_fnc
        filt_state["__proposal_sampling_fnc"] = self.__proposal_sampling_fnc
        filt_state["__proposal_fnc"] = self.__proposal_fnc
        filt_state["__transition_prob_fnc"] = self.__transition_prob_fnc

        filt_state["rng"] = self.rng

        filt_state["_dyn_fnc"] = self._dyn_fnc
        filt_state["_dyn_obj"] = self._dyn_obj

        filt_state["_meas_mat"] = self._meas_mat
        filt_state["_meas_fnc"] = self._meas_fnc

        filt_state["_particleDist"] = deepcopy(self._particleDist)
        filt_state["prop_parts"] = deepcopy(self.prop_parts)

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.__meas_likelihood_fnc = filt_state["__meas_likelihood_fnc"]
        self.__proposal_sampling_fnc = filt_state["__proposal_sampling_fnc"]
        self.__proposal_fnc = filt_state["__proposal_fnc"]
        self.__transition_prob_fnc = filt_state["__transition_prob_fnc"]

        self.rng = filt_state["rng"]

        self._dyn_fnc = filt_state["_dyn_fnc"]
        self._dyn_obj = filt_state["_dyn_obj"]

        self._meas_mat = filt_state["_meas_mat"]
        self._meas_fnc = filt_state["_meas_fnc"]

        self._particleDist = filt_state["_particleDist"]
        self.prop_parts = filt_state["prop_parts"]

    @property
    def meas_likelihood_fnc(self):
        r"""A function that returns the likelihood of the measurement.

        This must have the signature :code:`f(y, y_hat, *args)` where `y` is
        the measurement as an Nm x 1 numpy array, and `y_hat` is the estimated
        measurement. If unset, :meth:`correct` uses a likelihood of 1 for
        every particle.

        Notes
        -----
        This represents :math:`p(y_t \vert x_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the measurement likelihood.
        """
        return self.__meas_likelihood_fnc

    @meas_likelihood_fnc.setter
    def meas_likelihood_fnc(self, val):
        self.__meas_likelihood_fnc = val

    @property
    def proposal_fnc(self):
        r"""A function that returns the probability for the proposal distribution.

        This must have the signature :code:`f(x_hat, x, y, *args)` where
        `x_hat` is the sampled state of a particle (an element of the
        distribution's ``particles``), `x` is the propagated state it is
        conditioned on (the matching entry of :attr:`prop_parts`), and `y` is
        the measurement. If unset, :meth:`correct` uses a value of 1.

        Notes
        -----
        This represents :math:`q(x_t \vert x_{t-1}, y_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the proposal probability.
        """
        return self.__proposal_fnc

    @proposal_fnc.setter
    def proposal_fnc(self, val):
        self.__proposal_fnc = val

    @property
    def proposal_sampling_fnc(self):
        """A function that returns a random sample from the proposal distribution.

        This must have the signature :code:`f(x, rng, *args)` where `x` is a
        propagated particle state and `rng` is :attr:`rng`. It should be
        consistent with the PDF specified in
        :meth:`gncpy.filters.ParticleFilter.proposal_fnc`. It must be set
        before calling :meth:`predict`.

        Returns
        -------
        callable
            function to return a random sample.
        """
        return self.__proposal_sampling_fnc

    @proposal_sampling_fnc.setter
    def proposal_sampling_fnc(self, val):
        self.__proposal_sampling_fnc = val

    @property
    def transition_prob_fnc(self):
        r"""A function that returns the transition probability for the state.

        This must have the signature :code:`f(x_hat, x, *args)` where
        `x_hat` is an N x 1 numpy array representing the propagated state, and
        `x` is the state it is conditioned on. :meth:`predict` calls it with
        each propagated particle and the propagated distribution mean. If
        unset, the weights are left unchanged during prediction.

        Notes
        -----
        This represents :math:`p(x_t \vert x_{t-1})` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the transition probability.
        """
        return self.__transition_prob_fnc

    @transition_prob_fnc.setter
    def transition_prob_fnc(self, val):
        self.__transition_prob_fnc = val

    def set_state_model(self, dyn_obj=None, dyn_fun=None):
        """Sets the state model.

        Only one model is active at a time: setting a dynamics object clears
        any previously set dynamics function and vice versa. If both are
        given, the dynamics object is used.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamic object to use; a deep copy is stored. The default is None.
        dyn_fun : callable, optional
            function that returns the next state. It must have the signature
            `f(t, x, *args)` and return a N x 1 numpy array. The default is None.

        Raises
        ------
        RuntimeError
            If no model is specified.

        Returns
        -------
        None.
        """
        if dyn_obj is not None:
            self._dyn_obj = deepcopy(dyn_obj)
            # predict() prefers a dynamics object; clear the function so the
            # model just set is the one actually used.
            self._dyn_fnc = None
        elif dyn_fun is not None:
            self._dyn_fnc = dyn_fun
            # Otherwise a previously set dynamics object would shadow dyn_fun.
            self._dyn_obj = None
        else:
            msg = "Invalid state model specified. Check arguments"
            raise RuntimeError(msg)

    def set_measurement_model(self, meas_mat=None, meas_fun=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a
        non-linear function (potentially time varying) to map states to
        measurements. Only one model is active at a time: setting one clears
        the other. If both are given, the matrix is used.

        Notes
        -----
        The constant matrix assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        and the non-linear case assumes

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Non-linear function that returns the expected measurement for the
            given state. It must have the signature `h(t, x, *args)`.
            The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.
        """
        if meas_mat is not None:
            self._meas_mat = meas_mat
            # _est_meas() prefers the function; clear it so the matrix is used.
            self._meas_fnc = None
        elif meas_fun is not None:
            self._meas_fnc = meas_fun
            self._meas_mat = None
        else:
            raise RuntimeError("Invalid combination of inputs")

    @property
    def cov(self):
        """Read only covariance of the particles.

        Returns
        -------
        N x N numpy array
            covariance matrix.

        Raises
        ------
        RuntimeError
            If assignment is attempted.
        """
        return self._particleDist.covariance

    @cov.setter
    def cov(self, x):
        raise RuntimeError("Covariance is read only")

    @property
    def num_particles(self):
        """Read only number of particles used by the filter.

        Returns
        -------
        int
            Number of particles.
        """
        return self._particleDist.num_particles

    def init_from_dist(self, dist, make_copy=True):
        """Initialize the distribution from a distribution object.

        Parameters
        ----------
        dist : :class:`gncpy.distributions.ParticleDistribution`
            Distribution object to use.
        make_copy : bool, optional
            Flag indicating if a deepcopy of the input distribution should be
            performed. The default is True.

        Returns
        -------
        None.
        """
        self._particleDist = deepcopy(dist) if make_copy else dist

    def extract_dist(self, make_copy=True):
        """Extracts the particle distribution used by the filter.

        Parameters
        ----------
        make_copy : bool, optional
            Flag indicating if a deepcopy of the distribution should be
            performed. The default is True.

        Returns
        -------
        :class:`gncpy.distributions.ParticleDistribution`
            Particle distribution object used by the filter
        """
        if make_copy:
            return deepcopy(self._particleDist)
        return self._particleDist

    def init_particles(self, particle_lst):
        """Initializes the particle distribution with the given list of points.

        All particles receive equal weight. An empty list issues a warning and
        leaves the current distribution unchanged.

        Parameters
        ----------
        particle_lst : list
            List of numpy arrays, one for each particle.
        """
        num_parts = len(particle_lst)
        if num_parts <= 0:
            warn("No particles to initialize. SKIPPING")
            return
        self._particleDist.clear_particles()
        self._particleDist.add_particle(particle_lst, [1.0 / num_parts] * num_parts)

    def _calc_state(self):
        """Return the state estimate, i.e. the mean of the particle distribution."""
        return self._particleDist.mean

    def predict(
        self, timestep, dyn_fun_params=(), sampling_args=(), transition_args=()
    ):
        """Predicts the next state.

        Each particle is propagated through the state model (stored in
        :attr:`prop_parts`), its weight is scaled by
        :attr:`transition_prob_fnc` if set, and a new particle is drawn from
        :attr:`proposal_sampling_fnc` for each propagated state.

        Parameters
        ----------
        timestep : float
            Current timestep.
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function. The default
            is ().
        sampling_args : tuple, optional
            Extra arguments to be passed to the proposal sampling function.
            The default is ().
        transition_args : tuple, optional
            Extra arguments to be passed to the transition probability
            function. The default is ().

        Raises
        ------
        RuntimeError
            If no state model is set.

        Returns
        -------
        N x 1 numpy array
            predicted state.
        """
        if self._dyn_obj is not None:
            self.prop_parts = [
                self._dyn_obj.propagate_state(timestep, x, state_args=dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_obj.propagate_state(
                timestep, self._particleDist.mean, state_args=dyn_fun_params
            )
        elif self._dyn_fnc is not None:
            self.prop_parts = [
                self._dyn_fnc(timestep, x, *dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_fnc(timestep, self._particleDist.mean, *dyn_fun_params)
        else:
            raise RuntimeError("No state model set")

        # Look the (optional) transition function up once instead of per particle.
        trans_fnc = self.transition_prob_fnc
        weight_pairs = zip(self.prop_parts, self._particleDist.weights)
        if trans_fnc is None:
            new_weights = [w for _, w in weight_pairs]
        else:
            new_weights = [
                w * trans_fnc(x, mean, *transition_args) for x, w in weight_pairs
            ]

        sample = self.proposal_sampling_fnc
        new_parts = [sample(p, self.rng, *sampling_args) for p in self.prop_parts]

        self._particleDist.clear_particles()
        for p, w in zip(new_parts, new_weights):
            part = gdistrib.Particle(point=p)
            self._particleDist.add_particle(part, w)
        return self._calc_state()

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Estimate the measurement for a single state.

        The measurement function takes precedence over the measurement matrix.
        ``n_meas`` is not used here; it is kept in the signature (presumably
        for derived classes — verify before removing).

        Raises
        ------
        RuntimeError
            If no measurement model is set.
        """
        if self._meas_fnc is not None:
            est_meas = self._meas_fnc(timestep, cur_state, *meas_fun_args)
        elif self._meas_mat is not None:
            est_meas = self._meas_mat @ cur_state
        else:
            raise RuntimeError("No measurement model set")
        return est_meas

    def _selection(self, unnorm_weights, rel_likeli_in=None):
        """Resample the particles (multinomial resampling).

        One uniform random number is drawn per particle, and the first particle
        whose cumulative (normalized) weight reaches it is copied. Afterwards all
        particles have equal weight.

        Parameters
        ----------
        unnorm_weights : numpy array
            Unnormalized weights, reordered to match the selected particles.
        rel_likeli_in : list, optional
            Relative likelihoods, reordered to match the selected particles.

        Raises
        ------
        :class:`gncpy.errors.ParticleDepletionError`
            If a random draw could not be matched to any particle.

        Returns
        -------
        inds_removed : list
            Indices of particles that were not selected at all.
        old_weights : list
            Unnormalized weights of the selected particles.
        rel_likeli_out : list
            Relative likelihoods of the selected particles (all None if
            `rel_likeli_in` is None).
        """
        num_parts = self.num_particles
        new_parts = [None] * num_parts
        old_weights = [None] * num_parts
        rel_likeli_out = [None] * num_parts
        # Set membership keeps this O(n) overall instead of O(n^2).
        inds_kept = set()
        probs = self.rng.random(num_parts)
        cumulative_weight = np.cumsum(self._particleDist.weights)
        failed = False
        for ii, r in enumerate(probs):
            inds = np.where(cumulative_weight >= r)[0]
            if inds.size == 0:
                failed = True
                break
            sel = int(inds[0])
            new_parts[ii] = deepcopy(self._particleDist._particles[sel])
            old_weights[ii] = unnorm_weights[sel]
            if rel_likeli_in is not None:
                rel_likeli_out[ii] = rel_likeli_in[sel]
            inds_kept.add(sel)
        if failed:
            tot = np.sum(self._particleDist.weights)
            self._particleDist.clear_particles()
            msg = (
                "Failed to select enough particles, "
                + "check weights (sum = {})".format(tot)
            )
            raise gerr.ParticleDepletionError(msg)
        inds_removed = [ii for ii in range(num_parts) if ii not in inds_kept]

        self._particleDist.clear_particles()
        w = 1 / len(new_parts)
        self._particleDist.add_particle(new_parts, [w] * len(new_parts))

        return inds_removed, old_weights, rel_likeli_out

    def correct(
        self,
        timestep,
        meas,
        meas_fun_args=(),
        meas_likely_args=(),
        proposal_args=(),
        selection=True,
    ):
        """Corrects the state estimate.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().
        meas_likely_args : tuple, optional
            additional arguments for the measurement likelihood function.
            The default is ().
        proposal_args : tuple, optional
            Additional arguments for the proposal distribution function. The
            default is ().
        selection : bool, optional
            Flag indicating if the selection step should be performed. The
            default is True.

        Raises
        ------
        RuntimeError
            If no measurement model is set

        Returns
        -------
        state : N x 1 numpy array
            corrected state.
        rel_likeli : numpy array
            The unnormalized measurement likelihood of each particle. If
            selection is performed, these are reordered to match the selected
            particles.
        inds_removed : list
            each element is an int representing the index of any particles
            that were removed during the selection process.
        """
        # calculate weights
        est_meas = [
            self._est_meas(timestep, p, meas.size, meas_fun_args)
            for p in self._particleDist.particles
        ]
        if self.meas_likelihood_fnc is None:
            rel_likeli = np.ones(len(est_meas))
        else:
            rel_likeli = np.array(
                [self.meas_likelihood_fnc(meas, y, *meas_likely_args) for y in est_meas]
            ).ravel()
        if self.proposal_fnc is None or len(self.prop_parts) == 0:
            prop_fit = np.ones(len(self._particleDist.particles))
        else:
            prop_fit = np.array(
                [
                    self.proposal_fnc(x_hat, cond, meas, *proposal_args)
                    for x_hat, cond in zip(
                        self._particleDist.particles, self.prop_parts
                    )
                ]
            ).ravel()
        # Floor the proposal probability to avoid dividing by zero; np.maximum
        # also promotes integer results to float instead of truncating eps to 0.
        prop_fit = np.maximum(prop_fit, np.finfo(float).eps)
        unnorm_weights = rel_likeli / prop_fit * np.array(self._particleDist.weights)

        tot = np.sum(unnorm_weights)
        if tot > 0 and np.isfinite(tot):
            weights = unnorm_weights / tot
        else:
            # Degenerate total weight; left as inf as in the original design.
            weights = np.inf * np.ones(unnorm_weights.size)
        self._particleDist.update_weights(weights)

        # resample
        if selection:
            sel_out = self._selection(unnorm_weights, rel_likeli_in=rel_likeli.tolist())
            inds_removed = sel_out[0]
            # Keep the documented numpy-array return type on both branches.
            rel_likeli = np.array(sel_out[2])
        else:
            inds_removed = []
        return (self._calc_state(), rel_likeli, inds_removed)

    def plot_particles(
        self,
        inds,
        title="Particle Distribution",
        x_lbl="State",
        y_lbl="Probability",
        **kwargs
    ):
        """Plots the particle distribution.

        This will either plot a histogram for a single index, or plot a 2-d
        heatmap/histogram if a list of 2 indices are given. The 1-d case will
        have the counts normalized to represent the probability.

        Parameters
        ----------
        inds : int or list
            Index of the particle vector to plot. A tuple is accepted like a
            list.
        title : string, optional
            Title of the plot. The default is 'Particle Distribution'.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Probability'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        ax = f_hndl.axes[0]
        h_opts = {"histtype": "stepfilled", "bins": "auto", "density": True}
        is_seq = isinstance(inds, (list, tuple))
        if not is_seq or len(inds) == 1:
            ii = inds[0] if is_seq else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            ax.hist(x, **h_opts)
        else:
            x = [p[inds[0], 0] for p in self._particleDist.particles]
            y = [p[inds[1], 0] for p in self._particleDist.particles]
            ax.hist2d(x, y)
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on this figure/axes, not whichever figure pyplot considers current.
        if lgnd_loc is not None:
            ax.legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl

    def plot_weighted_particles(
        self,
        inds,
        x_lbl="State",
        y_lbl="Weight",
        title="Weighted Particle Distribution",
        **kwargs
    ):
        """Plots the weight vs state distribution of the particles.

        This generates a bar chart and only works for single indices; for more
        indices a warning is issued and only the labels are set.

        Parameters
        ----------
        inds : int
            Index of the particle vector to plot. A one-element list or tuple
            is also accepted.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Weight'.
        title : string, optional
            Title of the plot. The default is 'Weighted Particle Distribution'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        ax = f_hndl.axes[0]
        is_seq = isinstance(inds, (list, tuple))
        if not is_seq or len(inds) == 1:
            ii = inds[0] if is_seq else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            y = [w for p, w in self._particleDist]
            ax.bar(x, y)
        else:
            warn("Only 1 element supported for weighted particle distribution")
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on this figure/axes, not whichever figure pyplot considers current.
        if lgnd_loc is not None:
            ax.legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl