Coverage for src/gncpy/filters/bootstrap_filter.py: 86%
58 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.random as rnd
3from copy import deepcopy
4from warnings import warn
6import gncpy.distributions as gdistrib
7import gncpy.errors as gerr
8from gncpy.filters.bayes_filter import BayesFilter
class BootstrapFilter(BayesFilter):
    """Stripped down version of the :class:`.ParticleFilter`.

    This is an alternative implementation of a basic Particle filter. This
    removes some of the quality of life features of the :class:`.ParticleFilter`
    class and can be more complicated to setup. But it may provide runtime
    improvements for simple cases. Most times it is advised to use the
    :class:`.ParticleFilter` instead of this class. Most other derived classes
    use the :class:`.ParticleFilter` class as a base.
    """

    def __init__(
        self,
        importance_dist_fnc=None,
        importance_weight_fnc=None,
        particleDistribution=None,
        rng=None,
        **kwargs
    ):
        """Initializes the object.

        Parameters
        ----------
        importance_dist_fnc : callable, optional
            Must have the signature `f(parts, rng)` where `parts` is an
            instance of :class:`.distributions.SimpleParticleDistribution`
            and `rng` is a numpy random generator. It must return a numpy array
            of new particles for a :class:`.distributions.SimpleParticleDistribution`.
            Any state transitions to a new timestep must happen within this
            function. The default is None.
        importance_weight_fnc : callable, optional
            Must have the signature `f(meas, parts)` where `meas` is an Nm x 1
            numpy array representing the measurement and `parts` is the
            numpy array of particles from a :class:`.distributions.SimpleParticleDistribution`
            object. It must return a numpy array of weights, one for each
            particle. The default is None.
        particleDistribution : :class:`.distributions.SimpleParticleDistribution`, optional
            Initial particle distribution to use. If None, an empty
            :class:`.distributions.SimpleParticleDistribution` is created.
            The default is None.
        rng : numpy random generator, optional
            Random number generator to use. If none supplied then
            :func:`numpy.random.default_rng` is used. The default is None.
        **kwargs : dict
            Additional arguments for the parent constructor.
        """
        super().__init__(**kwargs)
        self.importance_dist_fnc = importance_dist_fnc
        self.importance_weight_fnc = importance_weight_fnc
        if particleDistribution is None:
            particleDistribution = gdistrib.SimpleParticleDistribution()
        self.particleDistribution = particleDistribution
        self.rng = rnd.default_rng() if rng is None else rng

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent class state extended with this filter's importance
            functions, a deep copy of the particle distribution, and the
            random generator (stored by reference, not copied).
        """
        filt_state = super().save_filter_state()
        filt_state["importance_dist_fnc"] = self.importance_dist_fnc
        filt_state["importance_weight_fnc"] = self.importance_weight_fnc
        filt_state["particleDistribution"] = deepcopy(self.particleDistribution)
        filt_state["rng"] = self.rng

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.importance_dist_fnc = filt_state["importance_dist_fnc"]
        self.importance_weight_fnc = filt_state["importance_weight_fnc"]
        # Copy on load, mirroring the copy made on save: predict/correct
        # modify the distribution in place (e.g. ``weights *= ...``), which
        # would otherwise corrupt the saved state and prevent restoring it
        # a second time.
        self.particleDistribution = deepcopy(filt_state["particleDistribution"])
        self.rng = filt_state["rng"]

    def predict(self, timestep):
        """Prediction step of the filter.

        Calls the importance distribution function to generate new samples of
        particles.

        Parameters
        ----------
        timestep : float
            Current timestep.

        Returns
        -------
        N x 1 numpy array
            Unweighted mean of the particles returned by the importance
            distribution function.
        """
        self.particleDistribution.particles = self.importance_dist_fnc(
            self.particleDistribution, self.rng
        )

        shape = (self.particleDistribution.particles.shape[1], 1)
        return np.mean(self.particleDistribution.particles, axis=0).reshape(shape)

    def correct(self, timestep, meas):
        """Correction step of the filter.

        Multiplies the particle weights by the importance weights, normalizes
        them, and resamples `num_particles` particles with replacement. Only
        the distinct survivors are stored; how many times each was drawn is
        saved in `num_parts_per_ind`.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.

        Raises
        ------
        :class:`gerr.ParticleDepletionError`
            If the particle weights sum to zero or less, or to a non-finite
            value, or if at most one distinct particle survives resampling.

        Returns
        -------
        N x 1 numpy array
            Mean of the resampled particle set, with each distinct particle
            weighted by the number of times it was drawn.
        """
        distrib = self.particleDistribution
        distrib.weights *= self.importance_weight_fnc(meas, distrib.particles)
        tot = np.sum(distrib.weights)
        # A NaN sum slips past ``tot <= 0`` and would otherwise surface as an
        # unrelated ValueError from rng.choice; report it as depletion instead.
        if not np.isfinite(tot):
            raise gerr.ParticleDepletionError(
                "Importance weights sum to a non-finite value ({}).".format(tot)
            )
        if tot <= 0:
            raise gerr.ParticleDepletionError("Importance weights sum to 0.")
        distrib.weights /= tot

        # selection: draw indices with replacement according to the weights
        num_parts = distrib.num_particles
        keep_inds = self.rng.choice(
            distrib.weights.size,
            p=distrib.weights,
            size=num_parts,
        )
        unique_inds, counts = np.unique(keep_inds, return_counts=True)
        distrib.num_parts_per_ind = counts
        distrib.particles = distrib.particles[unique_inds, :]
        distrib.weights = 1 / num_parts * np.ones(distrib.particles.shape[0])

        if unique_inds.size <= 1:
            msg = "Only {:d} particles selected".format(unique_inds.size)
            raise gerr.ParticleDepletionError(msg)

        # Each stored particle stands for ``counts[i]`` of the resampled
        # particles, so the estimate must be weighted by those counts rather
        # than being a plain mean over the distinct particles.
        shape = (distrib.particles.shape[1], 1)
        return np.average(distrib.particles, axis=0, weights=counts).reshape(shape)

    def set_state_model(self, **kwargs):
        """Not used by the Bootstrap filter; only emits a RuntimeWarning."""
        warn(
            "Not used by BootstrapFilter, directly handled by importance_dist_fnc.",
            RuntimeWarning,
            stacklevel=2,
        )

    def set_measurement_model(self, **kwargs):
        """Not used by the Bootstrap filter; only emits a RuntimeWarning."""
        warn(
            "Not used by BootstrapFilter, directly handled by importance_weight_fnc.",
            RuntimeWarning,
            stacklevel=2,
        )

    def plot_particles(self, inds, **kwargs):
        """Wrapper for :class:`.distributions.SimpleParticleDistribution.plot_particles`."""
        return self.particleDistribution.plot_particles(inds, **kwargs)