Coverage for src/gncpy/filters/kalman_filter.py: 84%
222 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.linalg as la
3import scipy.linalg as sla
4import scipy.stats as stats
5from copy import deepcopy
7import gncpy.errors as gerr
8import gncpy.filters._filters as cpp_bindings
9from gncpy.filters.bayes_filter import BayesFilter
class KalmanFilter(BayesFilter):
    """Implementation of a discrete time Kalman Filter.

    Notes
    -----
    This is loosely based on :cite:`Crassidis2011_OptimalEstimationofDynamicSystems`

    When a dynamics object with its ``allow_cpp`` flag set and a measurement
    object are both provided, the prediction and correction steps are
    delegated to a C++ backend; otherwise a pure python implementation based
    on the state/measurement matrices (or matrix functions) is used.

    Attributes
    ----------
    cov : N x N numpy array
        Covariance matrix
    meas_noise : Nm x Nm numpy array
        Measurement noise matrix
    proc_noise : N x N numpy array
        Process noise matrix
    dt : float
        Time difference between simulation steps. Required if not using a
        dynamic object for the state model.

    """

    def __init__(
        self, cov=np.array([[]]), meas_noise=np.array([[]]), dt=None, **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        cov : N x N numpy array, optional
            Initial covariance matrix.
        meas_noise : Nm x Nm numpy array, optional
            Measurement noise matrix.
        dt : float, optional
            Time difference between steps, needed for continuous time models.
        **kwargs
            Forwarded to :class:`BayesFilter`.
        """
        self._cov = cov
        self.meas_noise = meas_noise
        self.proc_noise = np.array([[]])
        self.dt = dt

        # state model sources, checked in this priority order by
        # _predict_next_state: dynamics object, matrix functions, matrices
        self._dyn_obj = None
        self._state_mat = np.array([[]])
        self._input_mat = np.array([[]])
        self._get_state_mat = None
        self._get_input_mat = None

        # measurement model sources
        self._meas_mat = np.array([[]])
        self._meas_fnc = None
        self._measObj = None

        self._est_meas_noise_fnc = None

        # C++ backend, lazily created by _init_model
        self.__model = None
        self.__predParams = None
        self.__corrParams = None

        super().__init__(**kwargs)

    def __repr__(self) -> str:
        if self.__model is not None:
            return self.__model.__repr__()
        else:
            return super().__repr__()

    def __str__(self) -> str:
        if self.__model is not None:
            return self.__model.__str__()
        else:
            return super().__str__()

    @property
    def cov(self):
        """Covariance matrix, read from the C++ backend when it is active."""
        if self.__model is not None:
            return self.__model.cov
        else:
            return self._cov

    @cov.setter
    def cov(self, val):
        if self.__model is not None:
            self.__model.cov = val
        else:
            self._cov = val

    def _reset_model(self):
        """Drop the C++ backend so it is rebuilt (or the python path is used).

        The covariance currently held by the backend is copied back into
        ``_cov`` first. Without this, :attr:`cov` would silently revert to
        the value it had before the backend was created.
        """
        if self.__model is not None:
            self._cov = np.array(self.__model.cov)
        self.__model = None

    def save_filter_state(self):
        """Saves filter variables so they can be restored later."""
        filt_state = super().save_filter_state()

        filt_state["cov"] = self.cov.copy()
        if self.meas_noise is not None:
            filt_state["meas_noise"] = self.meas_noise.copy()
        else:
            filt_state["meas_noise"] = self.meas_noise
        if self.proc_noise is not None:
            filt_state["proc_noise"] = self.proc_noise.copy()
        else:
            filt_state["proc_noise"] = self.proc_noise
        filt_state["dt"] = self.dt
        filt_state["_dyn_obj"] = deepcopy(self._dyn_obj)

        if self._state_mat is not None:
            filt_state["_state_mat"] = self._state_mat.copy()
        else:
            filt_state["_state_mat"] = self._state_mat
        if self._input_mat is not None:
            filt_state["_input_mat"] = self._input_mat.copy()
        else:
            filt_state["_input_mat"] = self._input_mat
        filt_state["_get_state_mat"] = self._get_state_mat
        filt_state["_get_input_mat"] = self._get_input_mat

        if self._meas_mat is not None:
            filt_state["_meas_mat"] = self._meas_mat.copy()
        else:
            filt_state["_meas_mat"] = self._meas_mat
        filt_state["_meas_fnc"] = self._meas_fnc
        filt_state["_est_meas_noise_fnc"] = self._est_meas_noise_fnc

        # NOTE: the measurement object and backend objects are stored by
        # reference, not copied
        filt_state["_measObj"] = self._measObj
        filt_state["__model"] = self.__model
        filt_state["__predParams"] = self.__predParams
        filt_state["__corrParams"] = self.__corrParams

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        # Restore the backend before the covariance: the cov setter writes
        # into whichever backend is active, so it must be the saved one.
        self.__model = filt_state["__model"]
        self.__predParams = filt_state["__predParams"]
        self.__corrParams = filt_state["__corrParams"]

        self.cov = filt_state["cov"]
        self.meas_noise = filt_state["meas_noise"]
        self.proc_noise = filt_state["proc_noise"]
        self.dt = filt_state["dt"]

        self._dyn_obj = filt_state["_dyn_obj"]
        self._state_mat = filt_state["_state_mat"]
        self._input_mat = filt_state["_input_mat"]
        self._get_state_mat = filt_state["_get_state_mat"]
        self._get_input_mat = filt_state["_get_input_mat"]
        self._meas_mat = filt_state["_meas_mat"]
        self._meas_fnc = filt_state["_meas_fnc"]
        self._est_meas_noise_fnc = filt_state["_est_meas_noise_fnc"]

        self._measObj = filt_state["_measObj"]

    def _discretize_cont_model(self, state_mat, input_mat):
        """Discretize a continuous time model using the matrix exponential.

        Returns the discrete state matrix and the discrete input matrix
        (``None`` when `input_mat` is ``None``).
        """
        n_rows, n_states = state_mat.shape
        if input_mat is None:
            # without an input the augmented exponential reduces to expm(A dt)
            return sla.expm(state_mat * self.dt), None

        n_inputs = input_mat.shape[1]
        big_mat = np.vstack(
            (
                np.hstack((state_mat, input_mat)),
                np.zeros((n_inputs, n_states + n_inputs)),
            )
        )
        res = sla.expm(big_mat * self.dt)
        return res[:n_rows, :n_states], res[:n_rows, n_states:]

    def set_state_model(
        self,
        state_mat=None,
        input_mat=None,
        cont_time=False,
        state_mat_fun=None,
        input_mat_fun=None,
        dyn_obj=None,
    ):
        r"""Sets the state model equation for the filter.

        If the continuous time model is used then a `dt` must be provided, see
        the note for algorithm details. Alternatively, if the system is time
        varying then functions can be specified to return the matrices at each
        time step. Setting matrices or functions replaces any previously set
        dynamics object or matrix functions.

        Note
        -----
        This can use a continuous or discrete model. The continuous model will
        be automatically discretized so standard matrix equations can be used.

        If the discrete model is used it is assumed to have the form

        .. math::
            x_{k+1} = F x_k + G u_k

        If the continuous model is used it is assumed to have the form

        .. math::
            \dot{x} = A x + B u

        and is discretized according to

        .. math::
            expm\left[\begin{bmatrix}
                A & B\\
                0 & 0
                \end{bmatrix}dt\right]=\begin{bmatrix}
                F & G\\
                0 & I
                \end{bmatrix}

        If no input matrix is given for the continuous model, :math:`F` is
        computed as :math:`expm(A dt)` and no input matrix is stored.

        Parameters
        ----------
        state_mat : N x N numpy array, optional
            State matrix, continuous or discrete case. The default is None.
        input_mat : N x Nu numpy array, optional
            Input matrix, continuous or discrete case. The default is None.
        cont_time : bool, optional
            Flag indicating if the continuous model is provided. The default
            is False.
        state_mat_fun : callable, optional
            Function that returns the `state_mat`, must take timestep and
            `*args`. The default is None.
        input_mat_fun : callable, optional
            Function that returns the `input_mat`, must take timestep, and
            `*args`. The default is None.
        dyn_obj : :class:`gncpy.dynamics.LinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.

        Raises
        ------
        RuntimeError
            If the improper combination of input arguments are specified, or
            if a continuous model is given without `dt` set.

        Returns
        -------
        None.

        """
        have_obj = dyn_obj is not None
        have_mats = state_mat is not None
        have_funs = state_mat_fun is not None
        if have_obj:
            # a new dynamics object requires rebuilding the C++ backend,
            # otherwise it keeps propagating with the previous one
            if dyn_obj is not self._dyn_obj:
                self._reset_model()
            self._dyn_obj = dyn_obj
        elif have_mats and not cont_time:
            self._reset_model()
            # clear higher priority sources so these matrices are used
            self._dyn_obj = None
            self._get_state_mat = None
            self._get_input_mat = None
            self._state_mat = state_mat
            self._input_mat = input_mat
        elif have_mats:
            if self.dt is None:
                msg = "dt must be specified when using continuous time model"
                raise RuntimeError(msg)
            self._reset_model()
            self._dyn_obj = None
            self._get_state_mat = None
            self._get_input_mat = None
            self._state_mat, self._input_mat = self._discretize_cont_model(
                state_mat, input_mat
            )
        elif have_funs:
            self._reset_model()
            # the dynamics object has priority over functions, so drop it
            self._dyn_obj = None
            self._get_state_mat = state_mat_fun
            self._get_input_mat = input_mat_fun
        else:
            raise RuntimeError("Invalid combination of inputs")

    def set_measurement_model(self, meas_mat=None, meas_fun=None, measObj=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or the matrix can
        be time varying.

        Notes
        -----
        This assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H_{k+1} x_{k+1}^-

        where :math:`H_{k+1}` can be constant over time.

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Function that returns the matrix for transforming the state to
            estimated measurements. Must take timestep, and `*args` as
            arguments. The default is None.
        measObj : class instance
            Measurement class instance. It is only used by the C++ backend.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.

        """
        if measObj is not None:
            # a new measurement object requires rebuilding the C++ backend
            if measObj is not self._measObj:
                self._reset_model()
            self._measObj = measObj
            self._meas_mat = None
            self._meas_fnc = None
        elif meas_mat is not None:
            self._reset_model()
            self._measObj = None
            self._meas_mat = meas_mat
            self._meas_fnc = None
        elif meas_fun is not None:
            self._reset_model()
            self._measObj = None
            self._meas_mat = None
            self._meas_fnc = meas_fun
        else:
            raise RuntimeError("Invalid combination of inputs")

    def set_measurement_noise_estimator(self, function):
        """Sets the model used for estimating the measurement noise parameters.

        This is an optional step and the filter will work properly if this is
        not called. If it is called, the measurement noise will be estimated
        during the filter's correction step and the measurement noise attribute
        will be overwritten with the estimate.

        Parameters
        ----------
        function : callable
            A function that implements the prediction and correction steps for
            an appropriate filter to estimate the measurement noise covariance
            matrix. It is called as `f(est_meas, inov_cov)` where `est_meas`
            is an Nm x 1 numpy array and `inov_cov` is the Nm x Nm matrix
            :math:`H P H^T` (before adding measurement noise). It must return
            an Nm x Nm numpy array representing the measurement noise
            covariance matrix.

        Returns
        -------
        None.
        """
        self._est_meas_noise_fnc = function

    def _predict_next_state(
        self, timestep, cur_state, cur_input, state_mat_args, input_mat_args
    ):
        """Propagate the state with the python implementation.

        Returns the next state and the state matrix used for propagation.
        """
        if self._dyn_obj is not None:
            next_state = self._dyn_obj.propagate_state(
                timestep,
                cur_state,
                u=cur_input,
                state_args=state_mat_args,
                ctrl_args=input_mat_args,
            )
            state_mat = self._dyn_obj.get_state_mat(timestep, *state_mat_args)
        else:
            if self._get_state_mat is not None:
                state_mat = self._get_state_mat(timestep, *state_mat_args)
            elif self._state_mat is not None:
                state_mat = self._state_mat
            else:
                raise RuntimeError("State model not set")
            if self._get_input_mat is not None:
                input_mat = self._get_input_mat(timestep, *input_mat_args)
            elif self._input_mat is not None:
                input_mat = self._input_mat
            else:
                input_mat = None
            next_state = state_mat @ cur_state

            if input_mat is not None and cur_input is not None:
                next_state += input_mat @ cur_input
        return next_state, state_mat

    def _init_model(self):
        """Lazily create the C++ backend when all of its requirements are set.

        The backend is only built when none exists yet, the dynamics object
        allows C++, and a measurement object is set. Its covariance, state
        model and measurement model are initialized from the python values.
        """
        self._cpp_needs_init = (
            self.__model is None
            and (self._dyn_obj is not None and self._dyn_obj.allow_cpp)
            and self._measObj is not None
        )
        if self._cpp_needs_init:
            self.__model = cpp_bindings.Kalman()
            self.__predParams = cpp_bindings.BayesPredictParams()
            self.__corrParams = cpp_bindings.BayesCorrectParams()

            # make sure the cpp filter has its values set based on what python user gave (init only)
            self.__model.cov = self._cov.astype(np.float64)
            self.__model.set_state_model(self._dyn_obj.model, self.proc_noise)
            self.__model.set_measurement_model(self._measObj, self.meas_noise)

    def predict(
        self, timestep, cur_state, cur_input=None, state_mat_args=(), input_mat_args=()
    ):
        """Implements a discrete time prediction step for a Kalman Filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : Nu x 1 numpy array, optional
            Current input. The default is None.
        state_mat_args : tuple, optional
            Positional arguments for the get state matrix function if one has
            been specified, or the state arguments of the dynamic object if
            one is used. The default is ().
        input_mat_args : tuple, optional
            Positional arguments for the get input matrix function if one has
            been specified, or the control arguments of the dynamic object if
            one is used. The default is ().

        Raises
        ------
        RuntimeError
            If the state model has not been set

        Returns
        -------
        next_state : N x 1 numpy array
            Next state.

        """
        self._init_model()

        if self.__model is not None:
            (
                self.__predParams.stateTransParams,
                self.__predParams.controlParams,
            ) = self._dyn_obj.args_to_params(state_mat_args, input_mat_args)[:2]
            return self.__model.predict(
                timestep, cur_state, cur_input, self.__predParams
            ).reshape((-1, 1))
        else:
            next_state, state_mat = self._predict_next_state(
                timestep, cur_state, cur_input, state_mat_args, input_mat_args
            )

            self.cov = state_mat @ self.cov @ state_mat.T + self.proc_noise
            # enforce symmetry against round-off
            self.cov = (self.cov + self.cov.T) * 0.5
            return next_state

    def _get_meas_mat(self, t, state, n_meas, meas_fun_args):
        """Get the measurement matrix for the python implementation.

        Raises
        ------
        RuntimeError
            If no measurement matrix is available, e.g. when only a
            measurement object was set (that is only used by the C++
            backend).
        """
        if self._meas_fnc is not None:
            # time varying matrix
            meas_mat = self._meas_fnc(t, *meas_fun_args)
        else:
            # constant matrix
            meas_mat = self._meas_mat
        if meas_mat is None:
            raise RuntimeError("Measurement model not set")
        return meas_mat

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Estimate the measurement from the state; also returns the matrix."""
        meas_mat = self._get_meas_mat(timestep, cur_state, n_meas, meas_fun_args)

        est_meas = meas_mat @ cur_state

        return est_meas, meas_mat

    def _meas_fit_pdf(self, meas, est_meas, meas_cov):
        """Gaussian pdf of the measurement given the estimate and covariance."""
        extra_args = {}
        # NOTE(review): allow_singular is only set when the determinant is
        # *not* tiny, so near-singular covariances still raise LinAlgError
        # (converted to ExtremeMeasurementNoiseError by _calc_meas_fit).
        # This looks inverted, but callers may rely on that error; confirm
        # intent before changing.
        if np.abs(np.linalg.det(meas_cov)) > np.finfo(float).eps:
            extra_args["allow_singular"] = True
        return stats.multivariate_normal.pdf(
            meas.ravel(), mean=est_meas.ravel(), cov=meas_cov, **extra_args
        )

    def _calc_meas_fit(self, meas, est_meas, meas_cov):
        """Measurement fit probability, translating linear algebra failures.

        Raises
        ------
        gncpy.errors.ExtremeMeasurementNoiseError
            If the pdf evaluation raises a ``LinAlgError``.
        """
        try:
            meas_fit_prob = self._meas_fit_pdf(meas, est_meas, meas_cov)
        except la.LinAlgError:
            msg = (
                "Inovation matrix is singular, likely from bad "
                + "measurement-state pairing for measurement noise estimation."
            )
            raise gerr.ExtremeMeasurementNoiseError(msg) from None
        return meas_fit_prob

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements a discrete time correction step for a Kalman Filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        gncpy.errors.ExtremeMeasurementNoiseError
            If the measurement fit probability calculation fails.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.

        """
        self._init_model()

        if self.__model is not None:
            self.__corrParams.measParams = self._measObj.args_to_params(meas_fun_args)
            out = self.__model.correct(timestep, meas, cur_state, self.__corrParams)
            return out[0].reshape((-1, 1)), out[1]

        else:
            est_meas, meas_mat = self._est_meas(
                timestep, cur_state, meas.size, meas_fun_args
            )

            # get the Kalman gain
            cov_meas_T = self.cov @ meas_mat.T
            inov_cov = meas_mat @ cov_meas_T

            # estimate the measurement noise online if applicable
            if self._est_meas_noise_fnc is not None:
                self.meas_noise = self._est_meas_noise_fnc(est_meas, inov_cov)
            inov_cov += self.meas_noise
            inov_cov = (inov_cov + inov_cov.T) * 0.5
            if self.use_cholesky_inverse:
                sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
                inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
            else:
                inv_inov_cov = la.inv(inov_cov)
            kalman_gain = cov_meas_T @ inv_inov_cov

            # update the state with measurement
            inov = meas - est_meas
            next_state = cur_state + kalman_gain @ inov

            # update the covariance
            n_states = cur_state.shape[0]
            self.cov = (np.eye(n_states) - kalman_gain @ meas_mat) @ self.cov

            # calculate the measurement fit probability assuming Gaussian
            meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

            return (next_state, meas_fit_prob)