Coverage for src/gncpy/filters/quadrature_kalman_filter.py: 91%
107 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.linalg as la
4import gncpy.distributions as gdistrib
5import gncpy.dynamics.basic as gdyn
6import gncpy.math as gmath
7from gncpy.filters.kalman_filter import KalmanFilter
8from gncpy.filters.extended_kalman_filter import ExtendedKalmanFilter
class QuadratureKalmanFilter(ExtendedKalmanFilter):
    """Implementation of a Quadrature Kalman Filter.

    Notes
    -----
    This implementation is based on
    :cite:`Arasaratnam2007_DiscreteTimeNonlinearFilteringAlgorithmsUsingGaussHermiteQuadrature`
    and uses Gauss-Hermite quadrature points. It inherits from EKF to allow
    for easier implementation of non-linear dynamics.

    Attributes
    ----------
    quadPoints : :class:`gncpy.distributions.QuadraturePoints`
        Quadrature points used by the filter.
    """

    def __init__(self, points_per_axis=None, **kwargs):
        super().__init__(**kwargs)

        self.quadPoints = gdistrib.QuadraturePoints(points_per_axis=points_per_axis)
        # Cholesky factor of the covariance, refreshed by _factorize_cov before
        # the quadrature points are regenerated.
        self._sqrt_cov = np.array([[]])
        # Which parent's state model is active; set by set_state_model and
        # used by _predict_next_state to dispatch to the right parent.
        self._use_lin_dyn = False
        self._use_non_lin_dyn = False

    def save_filter_state(self):
        """Saves filter variables so they can be restored later."""
        filt_state = super().save_filter_state()

        filt_state["quadPoints"] = self.quadPoints

        filt_state["_sqrt_cov"] = self._sqrt_cov
        filt_state["_use_lin_dyn"] = self._use_lin_dyn
        filt_state["_use_non_lin_dyn"] = self._use_non_lin_dyn

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.quadPoints = filt_state["quadPoints"]
        self._sqrt_cov = filt_state["_sqrt_cov"]
        self._use_lin_dyn = filt_state["_use_lin_dyn"]
        self._use_non_lin_dyn = filt_state["_use_non_lin_dyn"]

    @property
    def points_per_axis(self):
        """Wrapper for the number of quadrature points per axis."""
        return self.quadPoints.points_per_axis

    @points_per_axis.setter
    def points_per_axis(self, val):
        self.quadPoints.points_per_axis = val

    def set_state_model(
        self,
        state_mat=None,
        input_mat=None,
        cont_time=False,
        state_mat_fun=None,
        input_mat_fun=None,
        dyn_obj=None,
        ode_lst=None,
    ):
        """Sets the state model for the filter.

        This can use either linear dynamics (by calling the Kalman filter's
        :meth:`gncpy.filters.KalmanFilter.set_state_model`) or non-linear dynamics
        (by calling :meth:`gncpy.filters.ExtendedKalmanFilter.set_state_model`).
        Linearity is determined automatically from the input arguments
        specified. If linear arguments are given, they take precedence over
        non-linear ones.

        Parameters
        ----------
        state_mat : N x N numpy array, optional
            State matrix, continuous or discrete case. The default is None.
        input_mat : N x Nu numpy array, optional
            Input matrix, continuous or discrete case. The default is None.
        cont_time : bool, optional
            Flag indicating if the continuous model is provided. The default
            is False.
        state_mat_fun : callable, optional
            Function that returns the `state_mat`, must take timestep and
            `*args`. The default is None.
        input_mat_fun : callable, optional
            Function that returns the `input_mat`, must take timestep, and
            `*args`. The default is None.
        dyn_obj : :class:`gncpy.dynamics.LinearDynamicsBase` or :class:`gncpy.dynamics.NonlinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            Callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If an invalid state model or combination of inputs is specified.

        Returns
        -------
        None.
        """
        self._use_lin_dyn = (
            state_mat is not None
            or state_mat_fun is not None
            or isinstance(dyn_obj, gdyn.LinearDynamicsBase)
        )
        # Non-linear dynamics are only used when no linear model was given.
        self._use_non_lin_dyn = (
            isinstance(dyn_obj, gdyn.NonlinearDynamicsBase) or ode_lst is not None
        ) and not self._use_lin_dyn

        # allow for linear or non linear dynamics by calling the appropriate parent
        if self._use_lin_dyn:
            KalmanFilter.set_state_model(
                self,
                state_mat=state_mat,
                input_mat=input_mat,
                cont_time=cont_time,
                state_mat_fun=state_mat_fun,
                input_mat_fun=input_mat_fun,
                dyn_obj=dyn_obj,
            )
        elif self._use_non_lin_dyn:
            ExtendedKalmanFilter.set_state_model(self, dyn_obj=dyn_obj, ode_lst=ode_lst)
        else:
            raise RuntimeError("Invalid state model.")

    def set_measurement_model(self, meas_mat=None, meas_fun=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a potentially
        non-linear function.

        Notes
        -----
        This assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        for the measurement matrix case. Or of the form

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        for the potentially non-linear case.

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Function that transforms the state to estimated measurements. Must
            have the signature :code:`h(t, x, *args)` where `t` is the timestep,
            `x` is an N x 1 numpy array of the current state, and return an
            Nm x 1 numpy array of the estimated measurement. The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.
        """
        # The parent names this argument meas_fun_lst; a single callable is
        # forwarded and later invoked directly by _est_meas.
        super().set_measurement_model(meas_mat=meas_mat, meas_fun_lst=meas_fun)

    def _factorize_cov(self, val=None):
        """Store the Cholesky factor of `val` (default: the filter covariance)."""
        if val is None:
            val = self.cov
        # numpy linalg is lower triangular
        self._sqrt_cov = la.cholesky(val)

    def _pred_update_cov(self):
        """Set the predicted covariance from process noise and the propagated points."""
        self.cov = self.proc_noise + self.quadPoints.cov

    def _predict_next_state(
        self, timestep, cur_state, cur_input, state_mat_args, input_mat_args, dyn_fun_params
    ):
        """Propagate a single state through whichever dynamics model is active."""
        if self._use_lin_dyn:
            return KalmanFilter._predict_next_state(
                self,
                timestep,
                cur_state.reshape((-1, 1)),
                cur_input,
                state_mat_args,
                input_mat_args,
            )[0]
        elif self._use_non_lin_dyn:
            return ExtendedKalmanFilter._predict_next_state(
                self, timestep, cur_state.reshape((-1, 1)), dyn_fun_params
            )[0]
        else:
            raise RuntimeError("State model not specified")

    def predict(
        self, timestep, cur_state, cur_input=None, state_mat_args=None, input_mat_args=None, dyn_fun_params=None,
    ):
        """Prediction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : Nu x 1 numpy array, optional
            Current input. The default is None.
        state_mat_args : tuple, optional
            Additional arguments for the get state matrix function if one has
            been specified, the propagate state function if using a dynamic
            object, or the dynamics function if a non-linear model is used.
            The default is ().
        input_mat_args : tuple, optional
            Additional arguments for the get input matrix function if one has
            been specified or the propagate state function if using a dynamic
            object. The default is ().
        dyn_fun_params : tuple, optional
            Additional arguments to pass to the dynamics function if using non-linear
            dynamics. The default is ().

        Raises
        ------
        RuntimeError
            If the state model has not been set

        Returns
        -------
        N x 1 numpy array
            The predicted state.
        """
        if state_mat_args is None:
            state_mat_args = ()
        if input_mat_args is None:
            input_mat_args = ()
        if dyn_fun_params is None:
            dyn_fun_params = ()

        # factorize covariance as P = sqrt(P) * sqrt(P)^T
        self._factorize_cov()

        # generate quadrature points as X_i = sqrt(P) * xi_i + x_hat for m points
        self.quadPoints.update_points(cur_state, self._sqrt_cov, have_sqrt=True)

        # predict each point using the dynamics, overwriting the points in place
        for ii, (point, _) in enumerate(self.quadPoints):
            pred_point = self._predict_next_state(
                timestep, point, cur_input, state_mat_args, input_mat_args, dyn_fun_params
            )
            self.quadPoints.points[ii, :] = pred_point.ravel()

        # update covariance as Q - m * x * x^T + sum(w_i * X_i * X_i^T)
        self._pred_update_cov()

        return self.quadPoints.mean

    def _corr_update_cov(self, gain, inov_cov):
        """Apply P = P - K P_zz K^T and re-symmetrize to limit round-off drift."""
        self.cov = self.cov - gain @ inov_cov @ gain.T
        self.cov = 0.5 * (self.cov + self.cov.T)

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Return the flattened estimated measurement for a single state."""
        if self._meas_fnc is not None:
            return self._meas_fnc(timestep, cur_state, *meas_fun_args).ravel()
        else:
            return (
                super()._est_meas(timestep, cur_state, n_meas, meas_fun_args)[0].ravel()
            )

    def _corr_core(self, timestep, cur_state, meas, meas_fun_args):
        """Map regenerated state quadrature points into measurement space.

        Returns the measurement quadrature points and their weighted mean
        (the predicted measurement).
        """
        # factorize covariance as P = sqrt(P) * sqrt(P)^T
        self._factorize_cov()

        # generate quadrature points as X_i = sqrt(P) * xi_i + x_hat for m points
        self.quadPoints.update_points(cur_state, self._sqrt_cov, have_sqrt=True)

        # Estimate a measurement for each quad point, Z_i; the measurement
        # points share the state points' weights.
        measQuads = gdistrib.QuadraturePoints(num_axes=meas.size)
        measQuads.points = np.nan * np.ones(
            (self.quadPoints.points.shape[0], meas.size)
        )
        measQuads.weights = self.quadPoints.weights
        for ii, (point, _) in enumerate(self.quadPoints):
            measQuads.points[ii, :] = self._est_meas(
                timestep, point, meas.size, meas_fun_args
            )

        # estimate predicted measurement as sum of est measurement quad points
        est_meas = measQuads.mean

        return measQuads, est_meas

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements the correction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement. A 1-D array of length Nm is also accepted and
            treated as a column vector.
        cur_state : N x 1 numpy array
            Current state. A 1-D array of length N is also accepted and
            treated as a column vector.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.

        """
        # Coerce inputs to the documented column-vector shapes. A 1-D
        # measurement would otherwise broadcast against the (Nm x 1) estimate
        # and silently yield a matrix-valued innovation and corrected state.
        meas = meas.reshape((-1, 1))
        cur_state = cur_state.reshape((-1, 1))

        measQuads, est_meas = self._corr_core(timestep, cur_state, meas, meas_fun_args)

        # estimate the measurement noise online if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, measQuads.cov)

        # estimate innovation cov as P_zz = R - m * z_hat * z_hat^T + sum(w_i * Z_i * Z_i^T)
        inov_cov = self.meas_noise + measQuads.cov

        if self.use_cholesky_inverse:
            # P_zz = L L^T  =>  P_zz^-1 = L^-T L^-1
            sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
            inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
        else:
            inv_inov_cov = la.inv(inov_cov)

        # estimate cross cov as P_xz = sum(w_i * X_i * Z_i^T) - m * x * z_hat^T
        cov_lst = [
            (state_pt - cur_state) @ (meas_pt - est_meas).T
            for (state_pt, _), (meas_pt, _) in zip(self.quadPoints, measQuads)
        ]
        cross_cov = gmath.weighted_sum_mat(self.quadPoints.weights, cov_lst)

        # calc Kalman gain as K = P_xz * P_zz^-1
        gain = cross_cov @ inv_inov_cov

        # state is x_hat + K *(z - z_hat)
        innov = meas - est_meas
        cor_state = cur_state + gain @ innov

        # update covariance as P = P_k - K * P_zz * K^T
        self._corr_update_cov(gain, inov_cov)

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (cor_state, meas_fit_prob)

    def plot_quadrature(self, inds, **kwargs):
        """Wrapper function for :meth:`gncpy.distributions.QuadraturePoints.plot_points`."""
        return self.quadPoints.plot_points(inds, **kwargs)