Coverage for src/gncpy/filters/particle_filter.py: 68%

234 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.random as rnd 

3import matplotlib.pyplot as plt 

4from copy import deepcopy 

5from warnings import warn 

6 

7import gncpy.distributions as gdistrib 

8import gncpy.errors as gerr 

9import gncpy.plotting as pltUtil 

10from gncpy.filters.bayes_filter import BayesFilter 

11 

12 

class ParticleFilter(BayesFilter):
    """Implements a basic Particle Filter.

    Notes
    -----
    The implementation is based on
    :cite:`Simon2006_OptimalStateEstimationKalmanHInfinityandNonlinearApproaches`
    and uses Sampling-Importance Resampling (SIR) sampling. Other resampling
    methods can be added in derived classes.

    Attributes
    ----------
    require_copy_prop_parts : bool
        Flag indicating if the propagated particles need to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    require_copy_can_dist : bool
        Flag indicating if a candidate distribution needs to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    rng : numpy random generator
        Generator passed to the proposal sampling function and used for the
        uniform draws of the selection (resampling) step.
    prop_parts : list
        Particles produced by the state model during the last call to
        :meth:`predict`, before the proposal sampling function is applied.
    """

    require_copy_prop_parts = True
    require_copy_can_dist = False

    def __init__(
        self,
        dyn_obj=None,
        dyn_fun=None,
        part_dist=None,
        transition_prob_fnc=None,
        rng=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamics object used as the state model. The default is None.
        dyn_fun : callable, optional
            State model function, see :meth:`set_state_model`. The default is
            None.
        part_dist : :class:`gncpy.distributions.ParticleDistribution`, optional
            Initial particle distribution (deep copied). The default is None.
        transition_prob_fnc : callable, optional
            Transition probability function, see :attr:`transition_prob_fnc`.
            The default is None.
        rng : numpy random generator, optional
            Random generator to use. The default is
            :code:`numpy.random.default_rng(1)` (fixed seed).
        **kwargs : dict
            Passed to the base class constructor.
        """
        self.__meas_likelihood_fnc = None
        self.__proposal_sampling_fnc = None
        self.__proposal_fnc = None
        # Previously this argument was accepted but silently discarded.
        self.__transition_prob_fnc = transition_prob_fnc

        if rng is None:
            rng = rnd.default_rng(1)
        self.rng = rng

        self._dyn_fnc = None
        self._dyn_obj = None

        self._meas_mat = None
        self._meas_fnc = None

        if dyn_obj is not None or dyn_fun is not None:
            self.set_state_model(dyn_obj=dyn_obj, dyn_fun=dyn_fun)
        self._particleDist = gdistrib.ParticleDistribution()
        if part_dist is not None:
            self.init_from_dist(part_dist)
        self.prop_parts = []

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        dict
            Base-class state extended with this filter's models, callbacks,
            random generator, and deep copies of the particle distribution
            and propagated particles.
        """
        filt_state = super().save_filter_state()

        filt_state["__meas_likelihood_fnc"] = self.__meas_likelihood_fnc
        filt_state["__proposal_sampling_fnc"] = self.__proposal_sampling_fnc
        filt_state["__proposal_fnc"] = self.__proposal_fnc
        filt_state["__transition_prob_fnc"] = self.__transition_prob_fnc

        filt_state["rng"] = self.rng

        filt_state["_dyn_fnc"] = self._dyn_fnc
        filt_state["_dyn_obj"] = self._dyn_obj

        filt_state["_meas_mat"] = self._meas_mat
        filt_state["_meas_fnc"] = self._meas_fnc

        filt_state["_particleDist"] = deepcopy(self._particleDist)
        filt_state["prop_parts"] = deepcopy(self.prop_parts)

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.__meas_likelihood_fnc = filt_state["__meas_likelihood_fnc"]
        self.__proposal_sampling_fnc = filt_state["__proposal_sampling_fnc"]
        self.__proposal_fnc = filt_state["__proposal_fnc"]
        self.__transition_prob_fnc = filt_state["__transition_prob_fnc"]

        self.rng = filt_state["rng"]

        self._dyn_fnc = filt_state["_dyn_fnc"]
        self._dyn_obj = filt_state["_dyn_obj"]

        self._meas_mat = filt_state["_meas_mat"]
        self._meas_fnc = filt_state["_meas_fnc"]

        self._particleDist = filt_state["_particleDist"]
        self.prop_parts = filt_state["prop_parts"]

    @property
    def meas_likelihood_fnc(self):
        r"""A function that returns the likelihood of the measurement.

        This must have the signature :code:`f(y, y_hat, *args)` where `y` is
        the measurement as an Nm x 1 numpy array, and `y_hat` is the estimated
        measurement.

        Notes
        -----
        This represents :math:`p(y_t \vert x_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the measurement likelihood.
        """
        return self.__meas_likelihood_fnc

    @meas_likelihood_fnc.setter
    def meas_likelihood_fnc(self, val):
        self.__meas_likelihood_fnc = val

    @property
    def proposal_fnc(self):
        r"""A function that returns the probability for the proposal distribution.

        This must have the signature :code:`f(x_hat, x, y, *args)` where
        `x_hat` is a :class:`gncpy.distributions.Particle` of the estimated
        state, `x` is the particle it is conditioned on, and `y` is the
        measurement.

        Notes
        -----
        This represents :math:`q(x_t \vert x_{t-1}, y_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the proposal probability.
        """
        return self.__proposal_fnc

    @proposal_fnc.setter
    def proposal_fnc(self, val):
        self.__proposal_fnc = val

    @property
    def proposal_sampling_fnc(self):
        """A function that returns a random sample from the proposal distribution.

        It is called by :meth:`predict` as :code:`f(p, rng, *sampling_args)`
        where `p` is a propagated particle and `rng` is the filter's random
        generator. This should be consistent with the PDF specified in the
        :meth:`gncpy.filters.ParticleFilter.proposal_fnc`.

        Returns
        -------
        callable
            function to return a random sample.
        """
        return self.__proposal_sampling_fnc

    @proposal_sampling_fnc.setter
    def proposal_sampling_fnc(self, val):
        self.__proposal_sampling_fnc = val

    @property
    def transition_prob_fnc(self):
        r"""A function that returns the transition probability for the state.

        This must have the signature :code:`f(x_hat, x, *args)` where
        `x_hat` is an N x 1 numpy array representing the propagated state, and
        `x` is the state it is conditioned on.

        Notes
        -----
        This represents :math:`p(x_t \vert x_{t-1})` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the transition probability.
        """
        return self.__transition_prob_fnc

    @transition_prob_fnc.setter
    def transition_prob_fnc(self, val):
        self.__transition_prob_fnc = val

    def set_state_model(self, dyn_obj=None, dyn_fun=None):
        """Sets the state model.

        The most recently set model replaces any previous one. If both
        arguments are given, the dynamics object is used.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamic object to use (deep copied). The default is None.
        dyn_fun : callable, optional
            function that returns the next state. It must have the signature
            `f(t, x, *args)` and return a N x 1 numpy array. The default is None.

        Raises
        ------
        RuntimeError
            If no model is specified.

        Returns
        -------
        None.
        """
        if dyn_obj is not None:
            self._dyn_obj = deepcopy(dyn_obj)
            self._dyn_fnc = None
        elif dyn_fun is not None:
            self._dyn_fnc = dyn_fun
            # predict() checks _dyn_obj first, so a stale object would
            # otherwise silently override the new function.
            self._dyn_obj = None
        else:
            msg = "Invalid state model specified. Check arguments"
            raise RuntimeError(msg)

    def set_measurement_model(self, meas_mat=None, meas_fun=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a
        non-linear function (potentially time varying) to map states to
        measurements. The most recently set model replaces any previous one;
        if both arguments are given, the matrix is used.

        Notes
        -----
        The constant matrix assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        and the non-linear case assumes

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Non-linear function that returns the expected measurement for the
            given state. It must have the signature `h(t, x, *args)`.
            The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.
        """
        if meas_mat is not None:
            self._meas_mat = meas_mat
            # _est_meas() checks _meas_fnc first, so a stale function would
            # otherwise silently override the new matrix.
            self._meas_fnc = None
        elif meas_fun is not None:
            self._meas_fnc = meas_fun
            self._meas_mat = None
        else:
            raise RuntimeError("Invalid combination of inputs")

    @property
    def cov(self):
        """Read only covariance of the particles.

        Returns
        -------
        N x N numpy array
            covariance matrix.

        """
        return self._particleDist.covariance

    @cov.setter
    def cov(self, x):
        raise RuntimeError("Covariance is read only")

    @property
    def num_particles(self):
        """Read only number of particles used by the filter.

        Returns
        -------
        int
            Number of particles.
        """
        return self._particleDist.num_particles

    def init_from_dist(self, dist, make_copy=True):
        """Initialize the distribution from a distribution object.

        Parameters
        ----------
        dist : :class:`gncpy.distributions.ParticleDistribution`
            Distribution object to use.
        make_copy : bool, optional
            Flag indicating if a deepcopy of the input distribution should be
            performed. The default is True.

        Returns
        -------
        None.
        """
        self._particleDist = deepcopy(dist) if make_copy else dist

    def extract_dist(self, make_copy=True):
        """Extracts the particle distribution used by the filter.

        Parameters
        ----------
        make_copy : bool, optional
            Flag indicating if a deepcopy of the distribution should be
            performed. The default is True.

        Returns
        -------
        :class:`gncpy.distributions.ParticleDistribution`
            Particle distribution object used by the filter
        """
        if make_copy:
            return deepcopy(self._particleDist)
        return self._particleDist

    def init_particles(self, particle_lst):
        """Initializes the particle distribution with the given list of points.

        All particles receive equal weight. An empty list only issues a
        warning and leaves the current distribution untouched.

        Parameters
        ----------
        particle_lst : list
            List of numpy arrays, one for each particle.
        """
        num_parts = len(particle_lst)
        if num_parts <= 0:
            warn("No particles to initialize. SKIPPING")
            return
        self._particleDist.clear_particles()
        self._particleDist.add_particle(particle_lst, [1.0 / num_parts] * num_parts)

    def _calc_state(self):
        # The state estimate is the mean of the particle distribution.
        return self._particleDist.mean

    def predict(
        self, timestep, dyn_fun_params=(), sampling_args=(), transition_args=()
    ):
        """Predicts the next state.

        Each particle is propagated through the state model (stored in
        :attr:`prop_parts`), optionally reweighted by the transition
        probability, and then replaced by a sample from the proposal
        sampling function.

        Parameters
        ----------
        timestep : float
            Current timestep.
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function. The default
            is ().
        sampling_args : tuple, optional
            Extra arguments to be passed to the proposal sampling function.
            The default is ().
        transition_args : tuple, optional
            Extra arguments to be passed to the transition probability
            function. That function is evaluated for each propagated particle
            conditioned on the propagated distribution mean. The default is ().

        Raises
        ------
        RuntimeError
            If no state model or no proposal sampling function is set.

        Returns
        -------
        N x 1 numpy array
            predicted state.

        """
        if self._dyn_obj is not None:
            self.prop_parts = [
                self._dyn_obj.propagate_state(timestep, x, state_args=dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_obj.propagate_state(
                timestep, self._particleDist.mean, state_args=dyn_fun_params
            )
        elif self._dyn_fnc is not None:
            self.prop_parts = [
                self._dyn_fnc(timestep, x, *dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_fnc(timestep, self._particleDist.mean, *dyn_fun_params)
        else:
            raise RuntimeError("No state model set")

        sampling_fnc = self.proposal_sampling_fnc
        if sampling_fnc is None:
            # Fail with a clear message instead of "NoneType is not callable".
            raise RuntimeError("No proposal sampling function set")

        trans_fnc = self.transition_prob_fnc
        if trans_fnc is None:
            new_weights = [
                w for _, w in zip(self.prop_parts, self._particleDist.weights)
            ]
        else:
            new_weights = [
                w * trans_fnc(x, mean, *transition_args)
                for x, w in zip(self.prop_parts, self._particleDist.weights)
            ]

        new_parts = [sampling_fnc(p, self.rng, *sampling_args) for p in self.prop_parts]

        self._particleDist.clear_particles()
        for p, w in zip(new_parts, new_weights):
            part = gdistrib.Particle(point=p)
            self._particleDist.add_particle(part, w)
        return self._calc_state()

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        # Map a state to an expected measurement. A measurement function takes
        # precedence over a matrix; n_meas is unused here.
        if self._meas_fnc is not None:
            est_meas = self._meas_fnc(timestep, cur_state, *meas_fun_args)
        elif self._meas_mat is not None:
            est_meas = self._meas_mat @ cur_state
        else:
            raise RuntimeError("No measurement model set")
        return est_meas

    def _selection(self, unnorm_weights, rel_likeli_in=None):
        """Resample the particles with multinomial (SIR) selection.

        One uniform draw per particle is mapped through the cumulative
        weights. The selected particles replace the distribution with equal
        weights.

        Parameters
        ----------
        unnorm_weights : array like
            Unnormalized weight of each current particle.
        rel_likeli_in : list, optional
            Relative likelihood of each current particle, carried along with
            the selection. The default is None.

        Raises
        ------
        gerr.ParticleDepletionError
            If there are no particles, the weights are not finite, or a draw
            exceeds the total weight. The distribution is cleared first.

        Returns
        -------
        inds_removed : list
            Indices of particles that were never selected.
        old_weights : list
            Unnormalized weight of the source of each selected particle.
        rel_likeli_out : list
            Relative likelihood of each selected particle (None entries if
            `rel_likeli_in` is None).
        """
        num_parts = self.num_particles
        probs = self.rng.random(num_parts)
        cumulative_weight = np.cumsum(self._particleDist.weights)

        # correct() flags degenerate weights as inf. An all-inf cumulative sum
        # would satisfy every ">=" test and silently copy particle 0 N times,
        # so non-finite weights are treated as depletion.
        failed = num_parts == 0 or not np.all(np.isfinite(cumulative_weight))
        if not failed:
            # First index whose cumulative weight is >= each draw. This assumes
            # non-negative weights, so the cumulative sum is non-decreasing.
            sel_inds = np.searchsorted(cumulative_weight, probs, side="left")
            # A draw above the total weight has no valid source particle.
            failed = bool(np.any(sel_inds >= num_parts))

        if failed:
            tot = np.sum(self._particleDist.weights)
            self._particleDist.clear_particles()
            msg = (
                "Failed to select enough particles, "
                + "check weights (sum = {})".format(tot)
            )
            raise gerr.ParticleDepletionError(msg)

        new_parts = [deepcopy(self._particleDist._particles[ind]) for ind in sel_inds]
        old_weights = [unnorm_weights[ind] for ind in sel_inds]
        if rel_likeli_in is None:
            rel_likeli_out = [None] * num_parts
        else:
            rel_likeli_out = [rel_likeli_in[ind] for ind in sel_inds]

        inds_kept = set(sel_inds.tolist())
        inds_removed = [ii for ii in range(num_parts) if ii not in inds_kept]

        self._particleDist.clear_particles()
        w = 1 / len(new_parts)
        self._particleDist.add_particle(new_parts, [w] * len(new_parts))

        return inds_removed, old_weights, rel_likeli_out

    def correct(
        self,
        timestep,
        meas,
        meas_fun_args=(),
        meas_likely_args=(),
        proposal_args=(),
        selection=True,
    ):
        """Corrects the state estimate.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().
        meas_likely_args : tuple, optional
            additional arguments for the measurement likelihood function.
            The default is ().
        proposal_args : tuple, optional
            Additional arguments for the proposal distribution function. The
            default is ().
        selection : bool, optional
            Flag indicating if the selection step should be performed. The
            default is True.

        Raises
        ------
        RuntimeError
            If no measurement model is set
        gerr.ParticleDepletionError
            If selection is performed and the weights are degenerate.

        Returns
        -------
        state : N x 1 numpy array
            corrected state.
        rel_likeli : numpy array or list
            The unnormalized measurement likelihood of each particle. This is
            a numpy array when selection is skipped. When selection is
            performed, it is a list ordered like the resampled particles.
        inds_removed : list
            each element is an int representing the index of any particles
            that were removed during the selection process.

        """
        # calculate weights
        est_meas = [
            self._est_meas(timestep, p, meas.size, meas_fun_args)
            for p in self._particleDist.particles
        ]
        if self.meas_likelihood_fnc is None:
            rel_likeli = np.ones(len(est_meas))
        else:
            rel_likeli = np.array(
                [self.meas_likelihood_fnc(meas, y, *meas_likely_args) for y in est_meas]
            ).ravel()
        if self.proposal_fnc is None or len(self.prop_parts) == 0:
            prop_fit = np.ones(len(self._particleDist.particles))
        else:
            prop_fit = np.array(
                [
                    self.proposal_fnc(x_hat, cond, meas, *proposal_args)
                    for x_hat, cond in zip(
                        self._particleDist.particles, self.prop_parts
                    )
                ]
            ).ravel()
        # Floor the proposal density at machine epsilon to avoid dividing by
        # zero. np.where also upcasts integer arrays so eps is not truncated.
        eps = np.finfo(float).eps
        prop_fit = np.where(prop_fit < eps, eps, prop_fit)
        unnorm_weights = rel_likeli / prop_fit * np.array(self._particleDist.weights)

        tot = np.sum(unnorm_weights)
        if tot > 0 and tot != np.inf:
            weights = unnorm_weights / tot
        else:
            # Degenerate total weight: mark every weight as inf so the
            # selection step reports particle depletion.
            weights = np.inf * np.ones(unnorm_weights.size)
        self._particleDist.update_weights(weights)

        # resample
        if selection:
            inds_removed, _, rel_likeli = self._selection(
                unnorm_weights, rel_likeli_in=rel_likeli.tolist()
            )
        else:
            inds_removed = []
        return (self._calc_state(), rel_likeli, inds_removed)

    def plot_particles(
        self,
        inds,
        title="Particle Distribution",
        x_lbl="State",
        y_lbl="Probability",
        **kwargs
    ):
        """Plots the particle distribution.

        This will either plot a histogram for a single index, or plot a 2-d
        heatmap/histogram if a list of 2 indices are given. The 1-d case will
        have the counts normalized to represent the probability.

        Parameters
        ----------
        inds : int or list
            Index of the particle vector to plot.
        title : string, optional
            Title of the plot. The default is 'Particle Distribution'.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Probability'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        h_opts = {"histtype": "stepfilled", "bins": "auto", "density": True}
        if (not isinstance(inds, list)) or len(inds) == 1:
            ii = inds[0] if isinstance(inds, list) else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            f_hndl.axes[0].hist(x, **h_opts)
        else:
            x = [p[inds[0], 0] for p in self._particleDist.particles]
            y = [p[inds[1], 0] for p in self._particleDist.particles]
            f_hndl.axes[0].hist2d(x, y)
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on the plotted figure, not whichever figure pyplot considers
        # current (they differ when the caller supplies f_hndl).
        if lgnd_loc is not None:
            f_hndl.axes[0].legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl

    def plot_weighted_particles(
        self,
        inds,
        x_lbl="State",
        y_lbl="Weight",
        title="Weighted Particle Distribution",
        **kwargs
    ):
        """Plots the weight vs state distribution of the particles.

        This generates a bar chart and only works for single indices. For
        more than one index, a warning is issued and no data is plotted.

        Parameters
        ----------
        inds : int
            Index of the particle vector to plot.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Weight'.
        title : string, optional
            Title of the plot. The default is 'Weighted Particle Distribution'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        if (not isinstance(inds, list)) or len(inds) == 1:
            ii = inds[0] if isinstance(inds, list) else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            y = [w for p, w in self._particleDist]
            f_hndl.axes[0].bar(x, y)
        else:
            warn("Only 1 element supported for weighted particle distribution")
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on the plotted figure, not whichever figure pyplot considers
        # current (they differ when the caller supplies f_hndl).
        if lgnd_loc is not None:
            f_hndl.axes[0].legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl