Coverage for src/gncpy/filters/max_corr_ent_ukf.py: 100%
50 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1import numpy as np
2import numpy.linalg as la
4import gncpy.math as gmath
5from gncpy.filters.unscented_kalman_filter import UnscentedKalmanFilter
class MaxCorrEntUKF(UnscentedKalmanFilter):
    """Implements a Maximum Correntropy Unscented Kalman filter.

    Notes
    -----
    This is based on
    :cite:`Hou2018_MaximumCorrentropyUnscentedKalmanFilterforBallisticMissileNavigationSystemBasedonSINSCNSDeeplyIntegratedMode`

    The measurement covariance produced by the parent UKF is augmented with a
    measurement-noise term that is rescaled by Gaussian-kernel weights of a
    whitened error vector. The states and measurement needed for that error
    vector are stashed by :meth:`correct` before it defers to the parent.

    Attributes
    ----------
    kernel_bandwidth : float, optional
        Bandwidth of the Gaussian Kernel. The default is 1.
    """

    def __init__(self, kernel_bandwidth=1, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        kernel_bandwidth : float, optional
            Bandwidth of the Gaussian Kernel. The default is 1.
        **kwargs : dict, optional
            Forwarded unchanged to the parent class constructor.
        """
        self.kernel_bandwidth = kernel_bandwidth

        # Values stashed by :meth:`correct` so that :meth:`_calc_meas_cov`
        # (whose signature is fixed by the parent class) can read them.
        self._past_state = np.array([[]])
        self._cur_state = np.array([[]])
        self._meas = np.array([[]])

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent's saved state extended with this class's kernel
            bandwidth and the stashed past/current state and measurement.
        """
        filt_state = super().save_filter_state()

        filt_state["kernel_bandwidth"] = self.kernel_bandwidth
        filt_state["_past_state"] = self._past_state
        filt_state["_cur_state"] = self._cur_state
        filt_state["_meas"] = self._meas

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.kernel_bandwidth = filt_state["kernel_bandwidth"]
        self._past_state = filt_state["_past_state"]
        self._cur_state = filt_state["_cur_state"]
        self._meas = filt_state["_meas"]

    def _calc_meas_cov(self, timestep, n_meas, meas_fun_args):
        """Calculate the measurement covariance with correntropy weighting.

        Calls the parent implementation, then adds the lower-right
        (measurement) block of ``S C^-1 S^T``. Here ``S`` is the Cholesky
        factor of ``blockdiag(self.cov, self.meas_noise)`` and ``C`` is the
        diagonal matrix of Gaussian-kernel weights of the whitened error
        vector.

        Parameters
        ----------
        timestep : float
            Current timestep.
        n_meas : int
            Number of measurements; forwarded to the parent method.
        meas_fun_args : tuple
            Extra measurement-function arguments; forwarded unchanged to the
            parent method and to ``self._est_meas``.

        Returns
        -------
        meas_cov : numpy array
            Parent measurement covariance plus the kernel-scaled measurement
            noise block.
        est_points : numpy array
            Returned unchanged from the parent method.
        est_meas : numpy array
            Returned unchanged from the parent method.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the combined covariance is not positive definite, or if any
            kernel weight is exactly zero (singular weight matrix).
        """
        meas_cov, est_points, est_meas = super()._calc_meas_cov(
            timestep, n_meas, meas_fun_args
        )

        # Square root (lower Cholesky factor) of the block-diagonal
        # [[P, 0], [0, R]] covariance.
        n_state = self.cov.shape[0]
        n_est_meas = est_meas.shape[0]
        z_12 = np.zeros((n_state, n_est_meas))
        z_21 = np.zeros((n_est_meas, n_state))
        comb_cov = np.vstack(
            (np.hstack((self.cov, z_12)), np.hstack((z_21, self.meas_noise)))
        )
        # Symmetrize so round-off cannot make the Cholesky factorization fail.
        comb_cov = (comb_cov + comb_cov.T) * 0.5
        sqrt_comb = la.cholesky(comb_cov)
        inv_sqrt_comb = la.inv(sqrt_comb)

        # Whitened error between the stacked [state; measurement] vectors
        # from before the prediction and those supplied to correct().
        pred_meas = self._est_meas(
            timestep, self._past_state, n_est_meas, meas_fun_args
        )[0]
        g = inv_sqrt_comb @ np.vstack((self._past_state, pred_meas))
        d = inv_sqrt_comb @ np.vstack((self._cur_state, self._meas))
        e = (d - g).ravel()

        # Gaussian-kernel weight for each error component. The weight matrix
        # C is diagonal, so its inverse is just the element-wise reciprocal.
        # A general la.inv would be O(n^3) for the same result.
        kern = np.array(
            [gmath.gaussian_kernel(e_ii, self.kernel_bandwidth) for e_ii in e]
        )
        if np.any(kern == 0):
            # Keep the exception la.inv(np.diag(kern)) would raise here.
            raise la.LinAlgError("Singular matrix")
        c_inv = np.diag(1.0 / kern)

        # comb_cov is block diagonal, so sqrt_comb is too. The lower-right
        # block of S C^-1 S^T therefore depends only on the measurement-noise
        # factor and the measurement-part kernel weights.
        scaled_mat = sqrt_comb @ c_inv @ sqrt_comb.T
        scaled_meas_noise = scaled_mat[n_state:, n_state:]
        meas_cov = meas_cov + scaled_meas_noise

        return meas_cov, est_points, est_meas

    def correct(self, timestep, meas, cur_state, past_state, **kwargs):
        """Correction function for the Max Correntropy UKF.

        This is a wrapper for the parent method to allow for additional
        parameters. It stores copies of ``meas``, ``cur_state`` and
        ``past_state`` for use by :meth:`_calc_meas_cov`, then calls the
        parent method with ``past_state`` omitted.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        past_state : N x 1 numpy array
            State from before the prediction step.
        **kwargs : dict, optional
            See the parent function for additional parameters.

        Returns
        -------
        tuple
            See the parent method.
        """
        # Copies, so later in-place changes by the caller do not leak into
        # the covariance calculation.
        self._past_state = past_state.copy()
        self._cur_state = cur_state.copy()
        self._meas = meas.copy()

        return super().correct(timestep, meas, cur_state, **kwargs)