Coverage for src/gncpy/filters/particle_filter.py: 68%
234 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.random as rnd
3import matplotlib.pyplot as plt
4from copy import deepcopy
5from warnings import warn
7import gncpy.distributions as gdistrib
8import gncpy.errors as gerr
9import gncpy.plotting as pltUtil
10from gncpy.filters.bayes_filter import BayesFilter
class ParticleFilter(BayesFilter):
    """Implements a basic Particle Filter.

    Notes
    -----
    The implementation is based on
    :cite:`Simon2006_OptimalStateEstimationKalmanHInfinityandNonlinearApproaches`
    and uses Sampling-Importance Resampling (SIR) sampling. Other resampling
    methods can be added in derived classes.

    Attributes
    ----------
    require_copy_prop_parts : bool
        Flag indicating if the propagated particles need to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    require_copy_can_dist : bool
        Flag indicating if a candidate distribution needs to be copied if this
        filter is being manipulated externally. This is a constant value that
        should not be modified outside of the class, but can be overridden by
        inherited classes.
    """

    require_copy_prop_parts = True
    require_copy_can_dist = False

    def __init__(
        self,
        dyn_obj=None,
        dyn_fun=None,
        part_dist=None,
        transition_prob_fnc=None,
        rng=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamics object used to propagate particles. The default is None.
        dyn_fun : callable, optional
            Dynamics function with signature `f(t, x, *args)`. Only used if
            `dyn_obj` is None. The default is None.
        part_dist : :class:`gncpy.distributions.ParticleDistribution`, optional
            Initial particle distribution (deep copied). The default is None.
        transition_prob_fnc : callable, optional
            Initial value of :attr:`transition_prob_fnc`. The default is None.
        rng : numpy random generator, optional
            Random generator used for sampling and resampling. If None, a
            generator seeded with 1 is created. The default is None.
        **kwargs : dict
            Passed to the parent class constructor.
        """
        self.__meas_likelihood_fnc = None
        self.__proposal_sampling_fnc = None
        self.__proposal_fnc = None
        # Previously this argument was accepted but silently discarded.
        self.__transition_prob_fnc = transition_prob_fnc

        if rng is None:
            rng = rnd.default_rng(1)
        self.rng = rng

        self._dyn_fnc = None
        self._dyn_obj = None

        self._meas_mat = None
        self._meas_fnc = None

        if dyn_obj is not None or dyn_fun is not None:
            self.set_state_model(dyn_obj=dyn_obj, dyn_fun=dyn_fun)
        self._particleDist = gdistrib.ParticleDistribution()
        if part_dist is not None:
            self.init_from_dist(part_dist)
        self.prop_parts = []

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        dict
            Filter state. The particle distribution and the propagated
            particles are deep copied; the other entries are stored by
            reference.
        """
        filt_state = super().save_filter_state()

        filt_state["__meas_likelihood_fnc"] = self.__meas_likelihood_fnc
        filt_state["__proposal_sampling_fnc"] = self.__proposal_sampling_fnc
        filt_state["__proposal_fnc"] = self.__proposal_fnc
        filt_state["__transition_prob_fnc"] = self.__transition_prob_fnc

        filt_state["rng"] = self.rng

        filt_state["_dyn_fnc"] = self._dyn_fnc
        filt_state["_dyn_obj"] = self._dyn_obj

        filt_state["_meas_mat"] = self._meas_mat
        filt_state["_meas_fnc"] = self._meas_fnc

        filt_state["_particleDist"] = deepcopy(self._particleDist)
        filt_state["prop_parts"] = deepcopy(self.prop_parts)

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.__meas_likelihood_fnc = filt_state["__meas_likelihood_fnc"]
        self.__proposal_sampling_fnc = filt_state["__proposal_sampling_fnc"]
        self.__proposal_fnc = filt_state["__proposal_fnc"]
        self.__transition_prob_fnc = filt_state["__transition_prob_fnc"]

        self.rng = filt_state["rng"]

        self._dyn_fnc = filt_state["_dyn_fnc"]
        self._dyn_obj = filt_state["_dyn_obj"]

        self._meas_mat = filt_state["_meas_mat"]
        self._meas_fnc = filt_state["_meas_fnc"]

        self._particleDist = filt_state["_particleDist"]
        self.prop_parts = filt_state["prop_parts"]

    @property
    def meas_likelihood_fnc(self):
        r"""A function that returns the likelihood of the measurement.

        This must have the signature :code:`f(y, y_hat, *args)` where `y` is
        the measurement as an Nm x 1 numpy array, and `y_hat` is the estimated
        measurement.

        Notes
        -----
        This represents :math:`p(y_t \vert x_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the measurement likelihood.
        """
        return self.__meas_likelihood_fnc

    @meas_likelihood_fnc.setter
    def meas_likelihood_fnc(self, val):
        self.__meas_likelihood_fnc = val

    @property
    def proposal_fnc(self):
        r"""A function that returns the probability for the proposal distribution.

        This must have the signature :code:`f(x_hat, x, y, *args)` where
        `x_hat` is a particle state from the current distribution, `x` is the
        corresponding propagated particle saved by :meth:`predict`, and `y`
        is the measurement.

        Notes
        -----
        This represents :math:`q(x_t \vert x_{t-1}, y_t)` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the proposal probability.
        """
        return self.__proposal_fnc

    @proposal_fnc.setter
    def proposal_fnc(self, val):
        self.__proposal_fnc = val

    @property
    def proposal_sampling_fnc(self):
        """A function that returns a random sample from the proposal distribution.

        This is called as :code:`f(x, rng, *args)` where `x` is a propagated
        particle and `rng` is the filter's random generator. It should be
        consistent with the PDF specified in the
        :meth:`gncpy.filters.ParticleFilter.proposal_fnc`.

        Returns
        -------
        callable
            function to return a random sample.
        """
        return self.__proposal_sampling_fnc

    @proposal_sampling_fnc.setter
    def proposal_sampling_fnc(self, val):
        self.__proposal_sampling_fnc = val

    @property
    def transition_prob_fnc(self):
        r"""A function that returns the transition probability for the state.

        This must have the signature :code:`f(x_hat, x, *args)` where
        `x_hat` is an N x 1 numpy array representing the propagated state, and
        `x` is the state it is conditioned on.

        Notes
        -----
        This represents :math:`p(x_t \vert x_{t-1})` in the importance
        weight

        .. math::

            w_t = w_{t-1} \frac{p(y_t \vert x_t) p(x_t \vert x_{t-1})}{q(x_t \vert x_{t-1}, y_t)}

        Returns
        -------
        callable
            function to return the transition probability.
        """
        return self.__transition_prob_fnc

    @transition_prob_fnc.setter
    def transition_prob_fnc(self, val):
        self.__transition_prob_fnc = val

    def set_state_model(self, dyn_obj=None, dyn_fun=None):
        """Sets the state model.

        If both arguments are given, `dyn_obj` is used. Setting one kind of
        model clears any previously set model of the other kind, so the most
        recent call always determines the model used by :meth:`predict`.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.DynamicsBase`, optional
            Dynamic object to use (deep copied). The default is None.
        dyn_fun : callable, optional
            function that returns the next state. It must have the signature
            `f(t, x, *args)` and return a N x 1 numpy array. The default is None.

        Raises
        ------
        RuntimeError
            If no model is specified.

        Returns
        -------
        None.
        """
        if dyn_obj is not None:
            self._dyn_obj = deepcopy(dyn_obj)
            # predict() checks _dyn_obj first, but clear the stale function
            # so the active model is unambiguous.
            self._dyn_fnc = None
        elif dyn_fun is not None:
            self._dyn_fnc = dyn_fun
            # Otherwise predict() would keep using the old dynamics object.
            self._dyn_obj = None
        else:
            msg = "Invalid state model specified. Check arguments"
            raise RuntimeError(msg)

    def set_measurement_model(self, meas_mat=None, meas_fun=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a
        non-linear function (potentially time varying) to map states to
        measurements. If both arguments are given, `meas_mat` is used.
        Setting one kind of model clears any previously set model of the
        other kind.

        Notes
        -----
        The constant matrix assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        and the non-linear case assumes

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Non-linear function that returns the expected measurement for the
            given state. It must have the signature `h(t, x, *args)`.
            The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.
        """
        if meas_mat is not None:
            self._meas_mat = meas_mat
            # _est_meas() checks _meas_fnc first, so a stale function would
            # otherwise shadow the new matrix.
            self._meas_fnc = None
        elif meas_fun is not None:
            self._meas_fnc = meas_fun
            self._meas_mat = None
        else:
            raise RuntimeError("Invalid combination of inputs")

    @property
    def cov(self):
        """Read only covariance of the particles.

        Returns
        -------
        N x N numpy array
            covariance matrix.

        """
        return self._particleDist.covariance

    @cov.setter
    def cov(self, x):
        raise RuntimeError("Covariance is read only")

    @property
    def num_particles(self):
        """Read only number of particles used by the filter.

        Returns
        -------
        int
            Number of particles.
        """
        return self._particleDist.num_particles

    def init_from_dist(self, dist, make_copy=True):
        """Initialize the distribution from a distribution object.

        Parameters
        ----------
        dist : :class:`gncpy.distributions.ParticleDistribution`
            Distribution object to use.
        make_copy : bool, optional
            Flag indicating if a deepcopy of the input distribution should be
            performed. The default is True.

        Returns
        -------
        None.
        """
        self._particleDist = deepcopy(dist) if make_copy else dist

    def extract_dist(self, make_copy=True):
        """Extracts the particle distribution used by the filter.

        Parameters
        ----------
        make_copy : bool, optional
            Flag indicating if a deepcopy of the distribution should be
            performed. The default is True.

        Returns
        -------
        :class:`gncpy.distributions.ParticleDistribution`
            Particle distribution object used by the filter
        """
        return deepcopy(self._particleDist) if make_copy else self._particleDist

    def init_particles(self, particle_lst):
        """Initializes the particle distribution with the given list of points.

        All particles receive equal weight. An empty list only issues a
        warning and leaves the current distribution untouched.

        Parameters
        ----------
        particle_lst : list
            List of numpy arrays, one for each particle.
        """
        num_parts = len(particle_lst)
        if num_parts <= 0:
            warn("No particles to initialize. SKIPPING")
            return
        self._particleDist.clear_particles()
        self._particleDist.add_particle(particle_lst, [1.0 / num_parts] * num_parts)

    def _calc_state(self):
        """Return the state estimate (the particle distribution mean)."""
        return self._particleDist.mean

    def predict(
        self, timestep, dyn_fun_params=(), sampling_args=(), transition_args=()
    ):
        """Predicts the next state.

        Every particle is propagated through the state model and the results
        are stored in :attr:`prop_parts`. If :attr:`transition_prob_fnc` is
        set, each weight is scaled by the transition probability of its
        propagated particle, conditioned on the propagated mean. The new
        particles are then drawn with :attr:`proposal_sampling_fnc`.

        Parameters
        ----------
        timestep : float
            Current timestep.
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function. The default
            is ().
        sampling_args : tuple, optional
            Extra arguments to be passed to the proposal sampling function.
            The default is ().
        transition_args : tuple, optional
            Extra arguments to be passed to the transition probability
            function. The default is ().

        Raises
        ------
        RuntimeError
            If no state model is set.

        Returns
        -------
        N x 1 numpy array
            predicted state.

        """
        if self._dyn_obj is not None:
            self.prop_parts = [
                self._dyn_obj.propagate_state(timestep, x, state_args=dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_obj.propagate_state(
                timestep, self._particleDist.mean, state_args=dyn_fun_params
            )
        elif self._dyn_fnc is not None:
            self.prop_parts = [
                self._dyn_fnc(timestep, x, *dyn_fun_params)
                for x in self._particleDist.particles
            ]
            mean = self._dyn_fnc(timestep, self._particleDist.mean, *dyn_fun_params)
        else:
            raise RuntimeError("No state model set")

        trans_fnc = self.transition_prob_fnc
        if trans_fnc is None:
            new_weights = list(self._particleDist.weights)
        else:
            new_weights = [
                w * trans_fnc(x, mean, *transition_args)
                for x, w in zip(self.prop_parts, self._particleDist.weights)
            ]

        # NOTE(review): raises TypeError if proposal_sampling_fnc was never set.
        new_parts = [
            self.proposal_sampling_fnc(p, self.rng, *sampling_args)
            for p in self.prop_parts
        ]

        self._particleDist.clear_particles()
        for p, w in zip(new_parts, new_weights):
            part = gdistrib.Particle(point=p)
            self._particleDist.add_particle(part, w)
        return self._calc_state()

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Map a state to an estimated measurement using the measurement model.

        The measurement function takes precedence over the matrix if both
        are set. `n_meas` is not used here.

        Raises
        ------
        RuntimeError
            If no measurement model is set.
        """
        if self._meas_fnc is not None:
            est_meas = self._meas_fnc(timestep, cur_state, *meas_fun_args)
        elif self._meas_mat is not None:
            est_meas = self._meas_mat @ cur_state
        else:
            raise RuntimeError("No measurement model set")
        return est_meas

    def _selection(self, unnorm_weights, rel_likeli_in=None):
        """Resample the particles according to their current weights.

        One uniform random number is drawn per particle. For each number, the
        first particle whose cumulative weight reaches it is copied. The
        resampled particles are given equal weights.

        Parameters
        ----------
        unnorm_weights : array like
            Unnormalized weights, indexed like the current particles.
        rel_likeli_in : list, optional
            Per-particle relative likelihoods to carry through resampling.

        Returns
        -------
        inds_removed : list
            Indices (ascending) of particles that were never selected.
        old_weights : list
            Unnormalized weight of the source particle of each new particle.
        rel_likeli_out : list
            Relative likelihood of the source particle of each new particle,
            or all None if `rel_likeli_in` is None.

        Raises
        ------
        gncpy.errors.ParticleDepletionError
            If a random draw could not be matched to any particle. The
            distribution is cleared before raising.
        """
        num_parts = self.num_particles
        new_parts = [None] * num_parts
        old_weights = [None] * num_parts
        rel_likeli_out = [None] * num_parts
        # A set keeps membership checks O(1) instead of O(n) per particle.
        inds_kept = set()
        probs = self.rng.random(num_parts)
        cumulative_weight = np.cumsum(self._particleDist.weights)
        failed = False
        for ii, r in enumerate(probs):
            inds = np.where(cumulative_weight >= r)[0]
            if inds.size > 0:
                sel = int(inds[0])
                new_parts[ii] = deepcopy(self._particleDist._particles[sel])
                old_weights[ii] = unnorm_weights[sel]
                if rel_likeli_in is not None:
                    rel_likeli_out[ii] = rel_likeli_in[sel]
                inds_kept.add(sel)
            else:
                failed = True
        if failed:
            tot = np.sum(self._particleDist.weights)
            self._particleDist.clear_particles()
            msg = (
                "Failed to select enough particles, "
                + "check weights (sum = {})".format(tot)
            )
            raise gerr.ParticleDepletionError(msg)
        inds_removed = [ii for ii in range(num_parts) if ii not in inds_kept]

        self._particleDist.clear_particles()
        w = 1 / len(new_parts)
        self._particleDist.add_particle(new_parts, [w] * len(new_parts))

        return inds_removed, old_weights, rel_likeli_out

    def correct(
        self,
        timestep,
        meas,
        meas_fun_args=(),
        meas_likely_args=(),
        proposal_args=(),
        selection=True,
    ):
        """Corrects the state estimate.

        Each particle's weight is multiplied by its measurement likelihood
        and divided by its proposal probability. Proposal probabilities are
        clamped below at machine epsilon. The weights are then normalized.
        If the weight sum is not finite and positive, every weight is set to
        inf. The particles are optionally resampled afterwards.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().
        meas_likely_args : tuple, optional
            additional arguments for the measurement likelihood function.
            The default is ().
        proposal_args : tuple, optional
            Additional arguments for the proposal distribution function. The
            default is ().
        selection : bool, optional
            Flag indicating if the selection step should be performed. The
            default is True.

        Raises
        ------
        RuntimeError
            If no measurement model is set
        gncpy.errors.ParticleDepletionError
            If selection is performed and fails to select particles.

        Returns
        -------
        state : N x 1 numpy array
            corrected state.
        rel_likeli : numpy array or list
            The unnormalized measurement likelihood of each particle. This is
            a numpy array if `selection` is False. Otherwise it is a list
            aligned with the resampled particles.
        inds_removed : list
            each element is an int representing the index of any particles
            that were removed during the selection process.

        """
        # calculate weights
        est_meas = [
            self._est_meas(timestep, p, meas.size, meas_fun_args)
            for p in self._particleDist.particles
        ]
        if self.meas_likelihood_fnc is None:
            rel_likeli = np.ones(len(est_meas))
        else:
            rel_likeli = np.array(
                [self.meas_likelihood_fnc(meas, y, *meas_likely_args) for y in est_meas]
            ).ravel()

        # Without a proposal function (or before any predict) the proposal is
        # treated as uniform.
        if self.proposal_fnc is None or len(self.prop_parts) == 0:
            prop_fit = np.ones(len(self._particleDist.particles))
        else:
            prop_fit = np.array(
                [
                    self.proposal_fnc(x_hat, cond, meas, *proposal_args)
                    for x_hat, cond in zip(
                        self._particleDist.particles, self.prop_parts
                    )
                ]
            ).ravel()
        # Clamp to avoid dividing by zero.
        eps = np.finfo(float).eps
        prop_fit[prop_fit < eps] = eps
        unnorm_weights = rel_likeli / prop_fit * np.array(self._particleDist.weights)

        tot = np.sum(unnorm_weights)
        if np.isfinite(tot) and tot > 0:
            weights = unnorm_weights / tot
        else:
            weights = np.inf * np.ones(unnorm_weights.size)
        self._particleDist.update_weights(weights)

        # resample
        if selection:
            inds_removed, _, rel_likeli = self._selection(
                unnorm_weights, rel_likeli_in=rel_likeli.tolist()
            )
        else:
            inds_removed = []
        return (self._calc_state(), rel_likeli, inds_removed)

    def plot_particles(
        self,
        inds,
        title="Particle Distribution",
        x_lbl="State",
        y_lbl="Probability",
        **kwargs
    ):
        """Plots the particle distribution.

        This will either plot a histogram for a single index, or plot a 2-d
        heatmap/histogram if a list of 2 indices are given. The 1-d case will
        have the counts normalized to represent the probability.

        Parameters
        ----------
        inds : int or list
            Index of the particle vector to plot.
        title : string, optional
            Title of the plot. The default is 'Particle Distribution'.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Probability'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        h_opts = {"histtype": "stepfilled", "bins": "auto", "density": True}
        if (not isinstance(inds, list)) or len(inds) == 1:
            ii = inds[0] if isinstance(inds, list) else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            f_hndl.axes[0].hist(x, **h_opts)
        else:
            x = [p[inds[0], 0] for p in self._particleDist.particles]
            y = [p[inds[1], 0] for p in self._particleDist.particles]
            f_hndl.axes[0].hist2d(x, y)
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on f_hndl explicitly; the pyplot "current figure" may be a
        # different figure when the caller supplied f_hndl.
        if lgnd_loc is not None:
            f_hndl.axes[0].legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl

    def plot_weighted_particles(
        self,
        inds,
        x_lbl="State",
        y_lbl="Weight",
        title="Weighted Particle Distribution",
        **kwargs
    ):
        """Plots the weight vs state distribution of the particles.

        This generates a bar chart and only works for single indices.

        Parameters
        ----------
        inds : int
            Index of the particle vector to plot.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Weight'.
        title : string, optional
            Title of the plot. The default is 'Weighted Particle Distribution'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        f_hndl : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        lgnd_loc = opts["lgnd_loc"]

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        if (not isinstance(inds, list)) or len(inds) == 1:
            ii = inds[0] if isinstance(inds, list) else inds
            x = [p[ii, 0] for p in self._particleDist.particles]
            y = [w for p, w in self._particleDist]
            f_hndl.axes[0].bar(x, y)
        else:
            warn("Only 1 element supported for weighted particle distribution")
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=title, x_lbl=x_lbl, y_lbl=y_lbl)
        # Act on f_hndl explicitly; the pyplot "current figure" may differ.
        if lgnd_loc is not None:
            f_hndl.axes[0].legend(loc=lgnd_loc)
        f_hndl.tight_layout()

        return f_hndl