Coverage for src/gncpy/filters/max_corr_ent_upf.py: 96%
26 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1import numpy as np
3from gncpy.filters.unscented_particle_filter import UnscentedParticleFilter
4from gncpy.filters.max_corr_ent_ukf import MaxCorrEntUKF
class MaxCorrEntUPF(UnscentedParticleFilter):
    """Implements a Maximum Correntropy Unscented Particle Filter.

    The inner filter (``self._filt``) is a :class:`MaxCorrEntUKF`. Its
    correction step needs the state from before the prediction step in
    addition to the usual arguments. :meth:`correct` accepts that extra
    argument and stores it, and :meth:`_inner_correct` forwards it to the
    inner filter.

    Notes
    -----
    This is based on
    :cite:`Fan2018_MaximumCorrentropyBasedUnscentedParticleFilterforCooperativeNavigationwithHeavyTailedMeasurementNoises`
    """

    def __init__(self, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        **kwargs : dict
            Forwarded unchanged to the parent class constructor.
        """
        # Empty (1 x 0) placeholder until :meth:`correct` supplies the real
        # pre-prediction state.
        self._past_state = np.array([[]])

        super().__init__(**kwargs)

        # Any inner filter the parent constructor may have created is
        # overwritten here with a maximum correntropy UKF.
        # NOTE(review): if the parent configures its inner filter from
        # ``kwargs``, that configuration is discarded by this assignment --
        # confirm this is intended.
        self._filt = MaxCorrEntUKF()

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Adds the stored pre-prediction state to the dictionary produced by
        the parent class.

        Returns
        -------
        filt_state : dict
            Saved filter variables, suitable for passing to
            :meth:`load_filter_state`.
        """
        filt_state = super().save_filter_state()

        filt_state["_past_state"] = self._past_state

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`. It must
            contain the ``"_past_state"`` key in addition to whatever the
            parent class requires.
        """
        super().load_filter_state(filt_state)

        self._past_state = filt_state["_past_state"]

    def _inner_correct(self, timestep, meas, state, filt_kwargs):
        """Run the inner filter's correction step.

        This is a wrapper so child classes can override it. It passes the
        pre-prediction state stored by :meth:`correct` to the inner
        :class:`MaxCorrEntUKF`.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        state : numpy array
            State to correct.
        filt_kwargs : dict
            Extra keyword arguments unpacked into the inner filter's
            ``correct`` call.

        Returns
        -------
        Whatever the inner filter's ``correct`` method returns.
        """
        return self._filt.correct(
            timestep, meas, state, self._past_state, **filt_kwargs
        )

    def correct(self, timestep, meas, past_state, **kwargs):
        """Correction step of the MCUPF.

        This is a wrapper for the parent method to allow for an additional
        parameter. ``past_state`` is stored on the instance, so it is
        available to :meth:`_inner_correct` while the parent's correction
        runs.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        past_state : N x 1 numpy array
            State from before the prediction step.
        **kwargs : dict
            See the parent method.

        Returns
        -------
        tuple
            See the parent method.
        """
        self._past_state = past_state
        return super().correct(timestep, meas, **kwargs)

    @property
    def kernel_bandwidth(self):
        """Bandwidth for the Gaussian Kernel in the MCUKF.

        Reads and writes are delegated to the inner :class:`MaxCorrEntUKF`.

        Returns
        -------
        float
            bandwidth
        """
        return self._filt.kernel_bandwidth

    @kernel_bandwidth.setter
    def kernel_bandwidth(self, kernel_bandwidth):
        # Delegate to the inner filter so the value is used in its correction.
        self._filt.kernel_bandwidth = kernel_bandwidth