Coverage for src/gncpy/filters/square_root_qkf.py: 99%
74 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1import numpy as np
2import numpy.linalg as la
3import scipy.linalg as sla
6from gncpy.filters.quadrature_kalman_filter import QuadratureKalmanFilter
class SquareRootQKF(QuadratureKalmanFilter):
    """Implementation of a Square root Quadrature Kalman Filter (SQKF).

    The covariance, process noise, and measurement noise are stored as
    lower-triangular square-root factors ``S`` such that ``S @ S.T`` gives the
    full matrix. The public ``cov``, ``proc_noise``, and ``meas_noise``
    properties convert to and from the full matrices.

    Notes
    -----
    This is based on :cite:`Arasaratnam2008_SquareRootQuadratureKalmanFiltering`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._meas_noise = np.array([[]])
        self._sqrt_p_noise = np.array([[]])
        self._sqrt_m_noise = np.array([[]])

        # ``correct`` reads this attribute even when no estimator was ever
        # registered, so make sure it exists without clobbering a value the
        # parent class may already have set.
        if not hasattr(self, "_est_meas_noise_fnc"):
            self._est_meas_noise_fnc = None

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent's saved state extended with the square-root noise
            factors held by this class.
        """
        filt_state = super().save_filter_state()

        filt_state["_meas_noise"] = self._meas_noise
        filt_state["_sqrt_p_noise"] = self._sqrt_p_noise
        filt_state["_sqrt_m_noise"] = self._sqrt_m_noise

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self._meas_noise = filt_state["_meas_noise"]
        self._sqrt_p_noise = filt_state["_sqrt_p_noise"]
        self._sqrt_m_noise = filt_state["_sqrt_m_noise"]

    def set_measurement_noise_estimator(self, function):
        """Sets the model used for estimating the measurement noise parameters.

        This is an optional step and the filter will work properly if this is
        not called. If it is called, the measurement noise will be estimated
        during the filter's correction step and the measurement noise attribute
        will not be used.

        Parameters
        ----------
        function : callable
            A function that implements the prediction and correction steps for
            an appropriate filter to estimate the measurement noise covariance
            matrix. It is called as `f(est_meas, meas_cov)` where `est_meas`
            is an Nm x 1 numpy array holding the estimated measurement and
            `meas_cov` is the Nm x Nm numpy array formed from the weighted
            measurement quadrature point deviations (before any measurement
            noise is added). It must return an Nm x Nm numpy array
            representing the measurement noise covariance matrix.

        Returns
        -------
        None.
        """
        self._est_meas_noise_fnc = function

    @property
    def cov(self):
        """Covariance of the filter."""
        # sqrt cov is lower triangular
        return self._sqrt_cov @ self._sqrt_cov.T

    @cov.setter
    def cov(self, val):
        if val.size == 0:
            # nothing to factorize, store the empty placeholder directly
            self._sqrt_cov = val
        else:
            # the parent's implementation performs the factorization; this
            # class overrides ``_factorize_cov`` with a no-op, hence ``super()``
            super()._factorize_cov(val=val)

    @property
    def proc_noise(self):
        """Process noise of the filter."""
        return self._sqrt_p_noise @ self._sqrt_p_noise.T

    @proc_noise.setter
    def proc_noise(self, val):
        if val.size == 0 or np.all(val == 0):
            # empty or all-zero matrices have no Cholesky factor, store as is
            self._sqrt_p_noise = val
        else:
            # lower-triangular factor L with L @ L.T == val
            self._sqrt_p_noise = la.cholesky(val)

    @property
    def meas_noise(self):
        """Measurement noise of the filter."""
        return self._sqrt_m_noise @ self._sqrt_m_noise.T

    @meas_noise.setter
    def meas_noise(self, val):
        if val.size == 0 or np.all(val == 0):
            # empty or all-zero matrices have no Cholesky factor, store as is
            self._sqrt_m_noise = val
        else:
            # lower-triangular factor L with L @ L.T == val
            self._sqrt_m_noise = la.cholesky(val)

    def _factorize_cov(self):
        """Does nothing; the ``cov`` setter factorizes via the parent class."""
        pass

    def _pred_update_cov(self):
        """Updates the square-root covariance after the prediction step.

        Stacks the weighted quadrature point deviations with the square-root
        process noise and triangularizes the result with a QR decomposition,
        so that ``cov`` equals the weighted sample covariance of the points
        plus ``proc_noise``.
        """
        weight_mat = np.diag(np.sqrt(self.quadPoints.weights))
        x_hat = self.quadPoints.mean
        state_mat = np.concatenate(
            [x.reshape((x.size, 1)) - x_hat for x in self.quadPoints.points], axis=1
        )

        # The noise factor must be stacked untransposed: with A = [X W, L] the
        # product A @ A.T contains L @ L.T, which is exactly ``proc_noise``.
        # Stacking L.T would add L.T @ L instead, which only matches for
        # diagonal process noise.
        self._sqrt_cov = la.qr(
            np.concatenate((state_mat @ weight_mat, self._sqrt_p_noise), axis=1).T,
            mode="r",
        ).T

    def _corr_update_cov(self, gain, state_mat, meas_mat):
        """Updates the square-root covariance after the correction step.

        Parameters
        ----------
        gain : numpy array
            Kalman gain.
        state_mat : numpy array
            Weighted state quadrature point deviations.
        meas_mat : numpy array
            Weighted measurement quadrature point deviations.
        """
        self._sqrt_cov = la.qr(
            np.concatenate(
                (state_mat - gain @ meas_mat, gain @ self._sqrt_m_noise), axis=1
            ).T,
            mode="r",
        ).T

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements the correction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.

        """
        measQuads, est_meas = self._corr_core(timestep, cur_state, meas, meas_fun_args)

        weight_mat = np.diag(np.sqrt(self.quadPoints.weights))

        # calculate sqrt of the measurement covariance
        meas_mat = (
            np.concatenate(
                [z.reshape((z.size, 1)) - est_meas for z in measQuads.points], axis=1
            )
            @ weight_mat
        )
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, meas_mat @ meas_mat.T)
        sqrt_inov_cov = la.qr(
            np.concatenate((meas_mat, self._sqrt_m_noise), axis=1).T, mode="r"
        ).T

        # calculate cross covariance
        x_hat = self.quadPoints.mean
        state_mat = (
            np.concatenate(
                [x.reshape((x.size, 1)) - x_hat for x in self.quadPoints.points], axis=1
            )
            @ weight_mat
        )
        cross_cov = state_mat @ meas_mat.T

        # calculate gain: K = P_xz @ inv(S @ S.T) with S = sqrt_inov_cov lower
        # triangular. Transposed, this is S.T \ (S \ P_xz.T), so the forward
        # solve with S must come first, followed by the backward solve with S.T.
        inter = sla.solve_triangular(sqrt_inov_cov, cross_cov.T, lower=True)
        gain = sla.solve_triangular(sqrt_inov_cov.T, inter, lower=False).T

        # the above gain is equivalent to
        # cross_cov @ la.inv(sqrt_inov_cov @ sqrt_inov_cov.T)
        # but avoids forming an explicit inverse

        # state is x_hat + K *(z - z_hat)
        innov = meas - est_meas
        cor_state = cur_state + gain @ innov

        # update covariance
        inov_cov = sqrt_inov_cov @ sqrt_inov_cov.T

        self._corr_update_cov(gain, state_mat, meas_mat)

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (cor_state, meas_fit_prob)