Coverage for src/gncpy/filters/gsm_filter_base.py: 71%
273 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.random as rnd
3import scipy.stats as stats
4from queue import deque
5from warnings import warn
7import gncpy.distributions as gdistrib
8import gncpy.errors as gerr
9from serums.enums import GSMTypes
10from gncpy.filters.bayes_filter import BayesFilter
11from gncpy.filters.bootstrap_filter import BootstrapFilter
class _GSMProcNoiseEstimator:
    """Helper class for estimating proc noise in the GSM filters.

    Keeps two FIFO windows, newest entry on the left:

    * ``q_fifo`` holds state corrections ``cor_state - pred_state``.
    * ``b_fifo`` holds covariance residuals ``pred_cov - cur_est - cor_cov``.

    Each call to :meth:`estimate_next` recursively updates the process noise
    covariance estimate from these windows.

    Attributes
    ----------
    q_fifo : deque
        Window of recent state corrections.
    b_fifo : deque
        Window of recent covariance residuals.
    startup_delay : int
        Number of initial calls during which the current estimate is returned
        unchanged while the windows fill. Values below 1 behave like 1.
    """

    def __init__(self):
        self.q_fifo = deque([], None)
        self.b_fifo = deque([], None)

        self.startup_delay = 1

        # running mean of the state corrections, lazily sized on first call
        self._last_q_hat = None
        self._call_count = 0

    @property
    def maxlen(self):
        """Maximum window length, ``None`` means the window is unbounded."""
        return self.q_fifo.maxlen

    @maxlen.setter
    def maxlen(self, val):
        # The recursive update divides by (window length - 1), so a bounded
        # window must be able to hold at least two samples. Setting a new
        # length also clears both windows.
        if val is not None and val < 2:
            raise ValueError(
                "Filter length must be > 1 (or None), got {}".format(val)
            )
        self.q_fifo = deque([], val)
        self.b_fifo = deque([], val)

    @property
    def win_len(self):
        """Current number of samples held in the window."""
        return len(self.q_fifo)

    @staticmethod
    def _push(fifo, val):
        """Push ``val`` on the left of ``fifo`` and return the evicted entry.

        Returns a zero array of ``val``'s shape when nothing is evicted, that
        is when the FIFO is unbounded or not yet full.
        """
        if fifo.maxlen is None or len(fifo) < fifo.maxlen:
            evicted = np.zeros(val.shape)
        else:
            evicted = fifo.pop()
        fifo.appendleft(val)
        return evicted

    def estimate_next(self, cur_est, pred_state, pred_cov, cor_state, cor_cov):
        """Compute the next process noise covariance estimate.

        Parameters
        ----------
        cur_est : numpy array
            Current process noise covariance estimate.
        pred_state : numpy array
            Predicted state, as a column vector (its outer product is taken).
        pred_cov : numpy array
            Predicted covariance.
        cor_state : numpy array
            Corrected state, same shape as `pred_state`.
        cor_cov : numpy array
            Corrected covariance.

        Returns
        -------
        numpy array
            The new estimate, with its diagonal forced non-negative. During
            the startup delay, or while the window holds fewer than two
            samples, `cur_est` is returned unchanged.
        """
        self._call_count += 1

        # update bk
        bk = pred_cov - cur_est - cor_cov
        bk_last = self._push(self.b_fifo, bk)

        # update qk
        qk = cor_state - pred_state
        if self._last_q_hat is None:
            self._last_q_hat = np.zeros(qk.shape)
        qk_last = self._push(self.q_fifo, qk)

        win_len = self.win_len
        # The window can be shorter than 2 after maxlen is re-set mid-run,
        # which would otherwise divide by zero below.
        if self._call_count <= max(1, self.startup_delay) or win_len < 2:
            return cur_est

        inv_win_len = 1 / win_len
        win_len_m1 = win_len - 1

        # sliding-window update of the mean correction
        qk_klast_diff = qk - qk_last
        self._last_q_hat += inv_win_len * qk_klast_diff

        # estimate cov
        qk_khat_diff = qk - self._last_q_hat
        qklast_khat_diff = qk_last - self._last_q_hat
        bk_diff = bk_last - bk
        next_est = cur_est + 1 / win_len_m1 * (
            qk_khat_diff @ qk_khat_diff.T
            - qklast_khat_diff @ qklast_khat_diff.T
            + inv_win_len * qk_klast_diff @ qk_klast_diff.T
            + win_len_m1 * inv_win_len * bk_diff
        )

        # for numerical reasons keep the variances non-negative
        np.fill_diagonal(next_est, np.abs(np.diag(next_est)))
        return next_est
class GSMFilterBase(BayesFilter):
    """Base implementation of a Gaussian Scale Mixture (GSM) filter.

    This should be inherited from to include specific implementations of the
    core filter by exposing the necessary core filter attributes.

    Notes
    -----
    This is based on a generic version of
    :cite:`VilaValls2012_NonlinearBayesianFilteringintheGaussianScaleMixtureContext`
    which extends :cite:`VilaValls2011_BayesianFilteringforNonlinearStateSpaceModelsinSymmetricStableMeasurementNoise`.
    This class does not implement a specific core filter, that is up to the
    child classes.

    Attributes
    ----------
    enable_proc_noise_estimation : bool
        Flag indicating if the process noise should be estimated.
    """

    def __init__(self, enable_proc_noise_estimation=False, **kwargs):
        """Initializes the class.

        Parameters
        ----------
        enable_proc_noise_estimation : bool, optional
            Flag indicating if the process noise should be estimated. The
            default is False.
        **kwargs : dict
            Additional keyword arguments for the parent class.
        """
        super().__init__(**kwargs)

        self.enable_proc_noise_estimation = enable_proc_noise_estimation

        # set by child classes
        self._coreFilter = None
        # one bootstrap filter per measurement element
        self._meas_noise_filters = []
        # one importance weight factory per measurement noise filter
        self._import_w_factory_lst = None
        self._procNoiseEstimator = _GSMProcNoiseEstimator()

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Note that to pickle the resulting dictionary the :code:`dill` package
        may need to be used due to potential pickling of functions.

        Returns
        -------
        dict
            Filter state; sub-filters are stored as ``(type, state)`` tuples.
        """
        filt_state = super().save_filter_state()

        filt_state["enable_proc_noise_estimation"] = self.enable_proc_noise_estimation

        if self._coreFilter is not None:
            filt_state["_coreFilter"] = (
                type(self._coreFilter),
                self._coreFilter.save_filter_state(),
            )
        else:
            filt_state["_coreFilter"] = (None, self._coreFilter)
        filt_state["_meas_noise_filters"] = [
            (type(f), f.save_filter_state()) for f in self._meas_noise_filters
        ]

        filt_state["_import_w_factory_lst"] = self._import_w_factory_lst
        filt_state["_procNoiseEstimator"] = self._procNoiseEstimator

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.enable_proc_noise_estimation = filt_state["enable_proc_noise_estimation"]

        # sub-filters are rebuilt from their stored type, then restored
        cls_type = filt_state["_coreFilter"][0]
        if cls_type is not None:
            self._coreFilter = cls_type()
            self._coreFilter.load_filter_state(filt_state["_coreFilter"][1])
        else:
            self._coreFilter = None
        num_m_filts = len(filt_state["_meas_noise_filters"])
        self._meas_noise_filters = [None] * num_m_filts
        for ii, (cls_type, vals) in enumerate(filt_state["_meas_noise_filters"]):
            if cls_type is not None:
                self._meas_noise_filters[ii] = cls_type()
                self._meas_noise_filters[ii].load_filter_state(vals)
        self._import_w_factory_lst = filt_state["_import_w_factory_lst"]
        self._procNoiseEstimator = filt_state["_procNoiseEstimator"]

    def set_state_model(self, **kwargs):
        """Wrapper for the core filters set state model function."""
        if self._coreFilter is not None:
            self._coreFilter.set_state_model(**kwargs)
        else:
            warn("Core filter is not set, use an inherited class.", RuntimeWarning)

    def set_measurement_model(self, **kwargs):
        """Wrapper for the core filters set measurement model function."""
        if self._coreFilter is not None:
            self._coreFilter.set_measurement_model(**kwargs)
        else:
            warn("Core filter is not set, use an inherited class.", RuntimeWarning)

    def _define_student_t_pf(self, gsm, rng, num_parts):
        """Build a bootstrap filter for a Student's t measurement noise model.

        Each particle is ``(df, sig, z)``: degrees of freedom, scale, and the
        mixing variable drawn from an inverse gamma distribution.

        Returns
        -------
        tuple
            The bootstrap filter and its importance weight factory.
        """

        def import_w_factory(inov_cov):
            def import_w_fnc(meas, parts):
                # zero mean normal with variance z * sig^2 + innovation variance
                stds = np.sqrt(parts[:, 2] * parts[:, 1] ** 2 + inov_cov)
                return np.array(
                    [stats.norm.pdf(meas.item(), scale=scale) for scale in stds]
                )

            return import_w_fnc

        def gsm_import_dist_factory():
            def import_dist_fnc(parts, _rng):
                new_parts = np.nan * np.ones(parts.particles.shape)

                # shrink particles toward their mean by `a`, then jitter with
                # variance h^2 * (sample variance)
                disc = 0.99
                a = (3 * disc - 1) / (2 * disc)
                h = np.sqrt(1 - a ** 2)
                last_means = np.mean(parts.particles, axis=0)
                means = a * parts.particles[:, 0:2] + (1 - a) * last_means[0:2]

                # df, sig
                for ind in range(means.shape[1]):
                    std = np.sqrt(h ** 2 * np.cov(parts.particles[:, ind]))

                    for ii, m in enumerate(means):
                        samp = stats.norm.rvs(loc=m[ind], scale=std, random_state=_rng)
                        new_parts[ii, ind] = samp
                df = np.mean(new_parts[:, 0])
                # df == 0 would also divide by zero in the scale below
                if df <= 0:
                    msg = "Degree of freedom must be > 0 {:.4f}".format(df)
                    raise gerr.ParticleEstimationDomainError(msg)
                new_parts[:, 2] = stats.invgamma.rvs(
                    df / 2,
                    scale=1 / (2 / df),
                    random_state=_rng,
                    size=new_parts.shape[0],
                )
                return new_parts

            return import_dist_fnc

        pf = BootstrapFilter()
        pf.importance_dist_fnc = gsm_import_dist_factory()
        pf.particleDistribution = gdistrib.SimpleParticleDistribution()

        df_scale = gsm.df_range[1] - gsm.df_range[0]
        df_loc = gsm.df_range[0]
        df_particles = stats.uniform.rvs(
            loc=df_loc, scale=df_scale, size=num_parts, random_state=rng
        )

        sig_scale = gsm.scale_range[1] - gsm.scale_range[0]
        sig_loc = gsm.scale_range[0]
        sig_particles = stats.uniform.rvs(
            loc=sig_loc, scale=sig_scale, size=num_parts, random_state=rng
        )

        z_particles = np.nan * np.ones(num_parts)
        for ii, v in enumerate(df_particles):
            z_particles[ii] = stats.invgamma.rvs(
                v / 2, scale=1 / (2 / v), random_state=rng
            )
        pf.particleDistribution.particles = np.stack(
            (df_particles, sig_particles, z_particles), axis=1
        )
        pf.particleDistribution.num_parts_per_ind = np.ones(num_parts)
        pf.particleDistribution.weights = 1 / num_parts * np.ones(num_parts)
        pf.rng = rng

        return pf, import_w_factory

    def _define_cauchy_pf(self, gsm, rng, num_parts):
        """Build a bootstrap filter for a Cauchy measurement noise model.

        Each particle is ``(sig, z)``: scale and the mixing variable drawn
        from an inverse gamma distribution with shape 1/2 and scale 1/2.

        Returns
        -------
        tuple
            The bootstrap filter and its importance weight factory.
        """

        def import_w_factory(inov_cov):
            def import_w_fnc(meas, parts):
                # zero mean normal with variance z * sig^2 + innovation variance
                stds = np.sqrt(parts[:, 1] * parts[:, 0] ** 2 + inov_cov)
                return np.array(
                    [stats.norm.pdf(meas.item(), scale=scale) for scale in stds]
                )

            return import_w_fnc

        def gsm_import_dist_factory():
            def import_dist_fnc(parts, _rng):
                new_parts = np.nan * np.ones(parts.particles.shape)

                disc = 0.99
                a = (3 * disc - 1) / (2 * disc)
                h = np.sqrt(1 - a ** 2)
                last_means = np.mean(parts.particles, axis=0)
                means = a * parts.particles[:, [0]] + (1 - a) * last_means[0]

                # sig
                for ind in range(means.shape[1]):
                    std = np.sqrt(h ** 2 * np.cov(parts.particles[:, ind]))

                    for ii, m in enumerate(means):
                        samp = stats.norm.rvs(loc=m[ind], scale=std, random_state=_rng)
                        new_parts[ii, ind] = samp
                new_parts[:, 1] = stats.invgamma.rvs(
                    1 / 2, scale=1 / 2, random_state=_rng, size=new_parts.shape[0]
                )
                return new_parts

            return import_dist_fnc

        pf = BootstrapFilter()
        pf.importance_dist_fnc = gsm_import_dist_factory()
        pf.particleDistribution = gdistrib.SimpleParticleDistribution()

        sig_scale = gsm.scale_range[1] - gsm.scale_range[0]
        sig_loc = gsm.scale_range[0]
        sig_particles = stats.uniform.rvs(
            loc=sig_loc, scale=sig_scale, size=num_parts, random_state=rng
        )

        z_particles = stats.invgamma.rvs(
            1 / 2, scale=1 / 2, random_state=rng, size=num_parts
        )

        pf.particleDistribution.particles = np.stack(
            (sig_particles, z_particles), axis=1
        )
        pf.particleDistribution.num_parts_per_ind = np.ones(num_parts)
        pf.particleDistribution.weights = 1 / num_parts * np.ones(num_parts)
        pf.rng = rng

        return pf, import_w_factory

    def set_meas_noise_model(
        self,
        bootstrap_lst=None,
        importance_weight_factory_lst=None,
        gsm_lst=None,
        num_parts=None,
        rng=None,
    ):
        """Initializes the measurement noise estimators.

        The filters and importance weight factories can be provided directly.
        Alternatively, give a list of
        :class:`serums.models.GaussianScaleMixture` objects, the number of
        particles, and optionally a random number generator; bootstrap filters
        are then constructed automatically. The recommended way of specifying
        the filters is to provide GSM objects.

        Notes
        -----
        This uses independent bootstrap particle filters for each measurement
        based on the provided model information.

        Parameters
        ----------
        bootstrap_lst : list, optional
            List of :class:`.BootstrapFilter` objects that have already been
            initialized. If given then importance weight factory list must also
            be given and of the same length. The default is None.
        importance_weight_factory_lst : list, optional
            List of callables. Each takes a 1 x 1 numpy array or float as input
            and returns a callable with the signature `f(z, parts)`, where
            `z` is a float that is the difference between the estimated
            measurement and actual measurement (so the distribution is 0 mean)
            and `parts` is a numpy array of all the particles from the
            bootstrap filter. `f` must return a numpy array of weights for each
            particle. See the :attr:`.BootstrapFilter.importance_weight_fnc`
            for more details. The default is None.
        gsm_lst : list, optional
            List of :class:`serums.models.GaussianScaleMixture` objects, one
            per measurement. Requires `num_parts` also be specified and
            optionally `rng`. The default is None.
        num_parts : list or int, optional
            Number of particles to use in each automatically constructed filter.
            If only one number is supplied all filters will use the same number
            of particles. A list must have one entry per GSM object. The
            default is None.
        rng : numpy random generator, optional
            Random number generator to use in constructed filters. If supplied,
            each filter uses the same instance, otherwise a new generator is
            created for each filter using numpy's default initialization routine
            with no supplied seed. Only used if a `gsm_lst` is supplied.

        Raises
        ------
        RuntimeError
            If an incorrect combination of input arguments is provided, or a
            GSM type is not supported for automatic setup.

        Todo
        ----
        Allow for GSM Object to use location parameter when specifying noise models
        """
        if bootstrap_lst is not None:
            # validate before touching any state
            if importance_weight_factory_lst is None:
                msg = (
                    "Must supply an importance weight factory when "
                    + "specifying the bootstrap filters."
                )
                raise RuntimeError(msg)
            if len(importance_weight_factory_lst) != len(bootstrap_lst):
                msg = (
                    "Importance weight factory list "
                    + "length ({:d}) ".format(len(importance_weight_factory_lst))
                    + "does not match the number of bootstrap filters ({:d})".format(
                        len(bootstrap_lst)
                    )
                )
                raise RuntimeError(msg)
            self._meas_noise_filters = bootstrap_lst
            self._import_w_factory_lst = importance_weight_factory_lst

        elif gsm_lst is not None:
            num_filts = len(gsm_lst)
            if num_parts is None:
                msg = "Must specify number of particles when giving a list of GSM objects."
                raise RuntimeError(msg)
            if not isinstance(num_parts, (list, tuple)):
                num_parts = [num_parts] * num_filts
            elif len(num_parts) != num_filts:
                msg = (
                    "Number of particle counts ({:d}) ".format(len(num_parts))
                    + "does not match the number of GSM objects ({:d})".format(
                        num_filts
                    )
                )
                raise RuntimeError(msg)

            filters = [None] * num_filts
            factories = [None] * num_filts
            for ii, (gsm, n_parts) in enumerate(zip(gsm_lst, num_parts)):
                # documented behavior: a fresh generator per filter unless one
                # was supplied
                filt_rng = rnd.default_rng() if rng is None else rng
                if gsm.type is GSMTypes.STUDENTS_T:
                    filters[ii], factories[ii] = self._define_student_t_pf(
                        gsm, filt_rng, n_parts
                    )
                elif gsm.type is GSMTypes.CAUCHY:
                    filters[ii], factories[ii] = self._define_cauchy_pf(
                        gsm, filt_rng, n_parts
                    )
                else:
                    msg = (
                        "GSM filter can not automatically setup Bootstrap "
                        + "filter for GSM Type {}. ".format(gsm.type)
                        + "Update implementation."
                    )
                    raise RuntimeError(msg)
            self._meas_noise_filters = filters
            self._import_w_factory_lst = factories

        else:
            msg = (
                "Incorrect input arguement combination. See documentation for details."
            )
            raise RuntimeError(msg)

    def set_process_noise_model(
        self, initial_est=None, filter_length=None, startup_delay=None
    ):
        """Sets the filter parameters for estimating process noise.

        This also enables process noise estimation. It assumes the same
        process noise model as in
        :cite:`VilaValls2011_BayesianFilteringforNonlinearStateSpaceModelsinSymmetricStableMeasurementNoise`.

        Parameters
        ----------
        initial_est : N x N numpy array, optional
            Initial estimate of the covariance, can either be set here or by
            manually setting the process noise variable. The default is None,
            this assumes process noise was manually set.
        filter_length : int, optional
            Number of past samples to use in the filter, must be > 1. The
            default is None, this implies all past samples are used or the
            previously set value is maintained.
        startup_delay : int, optional
            Number of samples to delay before running the filter, used to fill
            the FIFO buffers. Must be >= 1. The default is None, this means
            either the previous value is maintained or a value of 1 is used.

        Returns
        -------
        None.
        """
        self.enable_proc_noise_estimation = True

        if filter_length is not None:
            self._procNoiseEstimator.maxlen = filter_length
        if initial_est is None and (
            self.proc_noise is None or self.proc_noise.size <= 0
        ):
            msg = (
                "Please manually set the initial process noise "
                + "or specify a value here."
            )
            warn(msg)
        elif initial_est is not None:
            self.proc_noise = initial_est
        if startup_delay is not None:
            self._procNoiseEstimator.startup_delay = startup_delay

    @property
    def cov(self):
        """Covariance of the filter."""
        if self._coreFilter is None:
            return np.array([[]])
        else:
            return self._coreFilter.cov

    @cov.setter
    def cov(self, val):
        self._coreFilter.cov = val

    @property
    def proc_noise(self):
        """Wrapper for the process noise covariance of the core filter."""
        if self._coreFilter is None:
            return np.array([[]])
        else:
            return self._coreFilter.proc_noise

    @proc_noise.setter
    def proc_noise(self, val):
        self._coreFilter.proc_noise = val

    @property
    def meas_noise(self):
        """Measurement noise of the core filter, estimated online and does not need to be set."""
        if self._coreFilter is None:
            return np.array([[]])
        else:
            return self._coreFilter.meas_noise

    @meas_noise.setter
    def meas_noise(self, val):
        warn(
            "Measurement noise is estimated online. NOT SETTING VALUE HERE.",
            RuntimeWarning,
        )

    def predict(self, timestep, cur_state, **kwargs):
        """Prediction step of the GSM filter.

        Delegates to the core filter's prediction function. Process noise
        estimation (if enabled) is performed in :meth:`correct`.
        """
        return self._coreFilter.predict(timestep, cur_state, **kwargs)

    def correct(self, timestep, meas, cur_state, core_filt_kwargs=None):
        """Correction step of the GSM filter.

        Installs a measurement noise estimator on the core filter, calls the
        core filter's correction function, then optionally updates the
        process noise estimate.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : numpy array
            Measurement.
        cur_state : numpy array
            Current (predicted) state.
        core_filt_kwargs : dict, optional
            Extra keyword arguments forwarded to the core filter's `correct`.
            The default is None (no extra arguments).

        Returns
        -------
        tuple
            Corrected state and measurement fit probability from the core
            filter.
        """
        if core_filt_kwargs is None:
            core_filt_kwargs = {}

        # setup core filter for estimating measurement noise during
        # correction function call
        def est_meas_noise(est_meas, inov_cov):
            m_diag = np.nan * np.ones(est_meas.size)
            f_meas = (meas - est_meas).ravel()
            inov_cov = np.diag(inov_cov)
            # one independent bootstrap filter per measurement element
            for ii, filt in enumerate(self._meas_noise_filters):
                filt.importance_weight_fnc = self._import_w_factory_lst[ii](
                    inov_cov[ii]
                )
                filt.predict(timestep)
                state = filt.correct(timestep, f_meas[ii].reshape((1, 1)))
                if state.size == 3:
                    # Student's t particles: (df, sig, z)
                    m_diag[ii] = state[2] * state[1] ** 2  # z * sig^2
                else:
                    # Cauchy particles: (sig, z)
                    m_diag[ii] = state[1] * state[0] ** 2  # z * sig^2
            return np.diag(m_diag)

        self._coreFilter.set_measurement_noise_estimator(est_meas_noise)

        pred_cov = self.cov.copy()
        cor_state, meas_fit_prob = self._coreFilter.correct(
            timestep, meas, cur_state, **core_filt_kwargs
        )

        # update process noise estimate (if applicable)
        if self.enable_proc_noise_estimation:
            self.proc_noise = self._procNoiseEstimator.estimate_next(
                self.proc_noise, cur_state, pred_cov, cor_state, self.cov
            )
        return cor_state, meas_fit_prob

    def plot_particles(self, filt_inds, dist_inds, **kwargs):
        """Plots the particle distribution for every given measurement index.

        Parameters
        ----------
        filt_inds : int or list
            Index of the measurement index/indices to plot the particle distribution
            for.
        dist_inds : int or list
            Index of the particle(s) in the given filter(s) to plot. See the
            :meth:`.BootstrapFilter.plot_particles` for details.
        **kwargs : dict
            Additional keyword arguments. See :meth:`.BootstrapFilter.plot_particles`.

        Returns
        -------
        figs : dict
            Each value in the dictionary is a matplotlib figure handle.
        keys : list
            Each value is a string corresponding to a key in the resulting dictionary
        """

        def _fmt_inds(inds):
            # "{:02d}" cannot format a list, so join multiple indices
            if isinstance(inds, (list, tuple)):
                return "-".join("{:02d}".format(i) for i in inds)
            return "{:02d}".format(inds)

        if isinstance(filt_inds, (list, tuple)):
            filt_lst = list(filt_inds)
        else:
            filt_lst = [filt_inds]

        dist_str = _fmt_inds(dist_inds)
        figs = {}
        keys = []
        for ii in filt_lst:
            key = "meas_noise_particles_F{:s}_D{:s}".format(_fmt_inds(ii), dist_str)
            figs[key] = self._meas_noise_filters[ii].plot_particles(
                dist_inds, **kwargs
            )
            keys.append(key)
        return figs, keys