Coverage for src/gncpy/filters/extended_kalman_filter.py: 91%
160 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.linalg as la
3import scipy.linalg as sla
4import scipy.integrate as s_integrate
5from copy import deepcopy
7import gncpy.dynamics.basic as gdyn
8import gncpy.math as gmath
9import gncpy.filters._filters as cpp_bindings
10from gncpy.filters.kalman_filter import KalmanFilter
class ExtendedKalmanFilter(KalmanFilter):
    """Implementation of a continuous-discrete time Extended Kalman Filter.

    This is loosely based on :cite:`Crassidis2011_OptimalEstimationofDynamicSystems`

    Attributes
    ----------
    cont_cov : bool, optional
        Flag indicating if a continuous model of the covariance matrix should
        be used when propagating the covariance in the prediction step. The
        default is True.
    integrator_type : string, optional
        Integrator type as defined by scipy's integrate.ode function. The
        default is `dopri5`. It is used to propagate the state when an ODE
        list (instead of a dynamics object) is specified, and to integrate
        the covariance in the Python prediction step whenever
        :attr:`cont_cov` is True.
    integrator_params : dict, optional
        Additional parameters for the integrator. The default is {}. Used
        wherever :attr:`integrator_type` is used.
    """

    def __init__(self, cont_cov=True, dyn_obj=None, ode_lst=None, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        cont_cov : bool, optional
            Initial value of :attr:`cont_cov`. The default is True.
        dyn_obj : :class:`gncpy.dynamics.NonlinearDynamicsBase` or :class:`gncpy.dynamics.LinearDynamicsBase`, optional
            Dynamics object forwarded to :meth:`set_state_model`. The default
            is None.
        ode_lst : list, optional
            ODE callables forwarded to :meth:`set_state_model`. The default is
            None.
        **kwargs
            Forwarded to the :class:`KalmanFilter` constructor.
        """
        super().__init__(**kwargs)

        self.cont_cov = cont_cov
        self.integrator_type = "dopri5"
        self.integrator_params = {}

        self._ode_lst = None
        self._integrator = None

        # Handles to the optional C++ backend; created lazily by _init_model.
        self.__model = None
        self.__predParams = None
        self.__corrParams = None

        if dyn_obj is not None or ode_lst is not None:
            self.set_state_model(dyn_obj=dyn_obj, ode_lst=ode_lst)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent class state extended with this filter's variables.
        """
        filt_state = super().save_filter_state()

        filt_state["cont_cov"] = self.cont_cov
        filt_state["integrator_type"] = self.integrator_type
        # copied so later edits to the live dict do not alter the snapshot
        filt_state["integrator_params"] = deepcopy(self.integrator_params)
        filt_state["_ode_lst"] = self._ode_lst
        filt_state["_integrator"] = self._integrator

        filt_state["__model"] = self.__model
        filt_state["__predParams"] = self.__predParams
        filt_state["__corrParams"] = self.__corrParams

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Attributes
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.cont_cov = filt_state["cont_cov"]
        self.integrator_type = filt_state["integrator_type"]
        # Copy to mirror save_filter_state: without it the filter and the
        # saved snapshot would share one dict, so mutating the filter's
        # parameters afterwards would silently corrupt the snapshot.
        self.integrator_params = deepcopy(filt_state["integrator_params"])
        self._ode_lst = filt_state["_ode_lst"]
        self._integrator = filt_state["_integrator"]

        self.__model = filt_state["__model"]
        self.__predParams = filt_state["__predParams"]
        self.__corrParams = filt_state["__corrParams"]

    def _reset_cpp_model(self):
        """Discard the cached C++ filter so :meth:`_init_model` rebuilds it.

        The C++ filter receives the dynamics and measurement models only when
        it is created, so it must be dropped whenever either model changes.
        NOTE(review): the rebuilt model is seeded from ``self._cov``; confirm
        that value is kept in sync with the C++ covariance by the base class.
        """
        self.__model = None
        self.__predParams = None
        self.__corrParams = None

    def set_state_model(self, dyn_obj=None, ode_lst=None):
        r"""Sets the state model equations.

        This allows for setting the differential equations directly

        .. math::
            \dot{x} = f(t, x, u)

        or setting a :class:`gncpy.dynamics.NonlinearDynamicsBase` object. If
        the object is specified then a local copy is created. A
        :class:`gncpy.dynamics.LinearDynamicsBase` can also be used in which
        case the dynamics follow the same form as the KF. If a linear dynamics
        object is used then it is recommended to set the filters dt manually so
        a continuous covariance model can be used in the prediction step.

        A valid dynamics object takes precedence over an ODE list when both are
        given. Each call replaces any previously set state model.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.NonlinearDynamicsBase` or :class:`gncpy.dynamics.LinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If neither argument is specified.

        Returns
        -------
        None.
        """
        if isinstance(dyn_obj, (gdyn.NonlinearDynamicsBase, gdyn.LinearDynamicsBase)):
            self._dyn_obj = deepcopy(dyn_obj)
            self._ode_lst = None
        elif ode_lst is not None and len(ode_lst) > 0:
            self._ode_lst = ode_lst
            # _predict_next_state checks _dyn_obj first, so a previously set
            # dynamics object would otherwise shadow the new ODE list.
            self._dyn_obj = None
        else:
            msg = "Invalid state model specified. Check arguments"
            raise RuntimeError(msg)

        # the cached C++ filter (if any) was built from the old dynamics
        self._reset_cpp_model()

    def _cont_dyn(self, t, x, *args):
        """Stack the ODE list outputs; used by the integrator for ODE models.

        Entry ``ii`` of the result is ``self._ode_lst[ii](t, x, *args)``.
        """
        out = np.zeros(x.shape)
        for ii, f in enumerate(self._ode_lst):
            out[ii] = f(t, x, *args)
        return out

    def _predict_next_state(self, timestep, cur_state, dyn_fun_params):
        """Propagate the state one step and return the matching state matrix.

        Returns
        -------
        next_state : numpy array
            Propagated state.
        state_mat : numpy array
            State matrix (or Jacobian) used for covariance propagation.
        dt : float or None
            Step length to use for continuous covariance propagation.

        Raises
        ------
        RuntimeError
            If no state model is set, dt is missing for an ODE list, or the
            state integration fails.
        NotImplementedError
            If an ODE list is used with a discrete covariance model.
        """
        if self._dyn_obj is not None:
            next_state = self._dyn_obj.propagate_state(
                timestep, cur_state, state_args=dyn_fun_params
            )
            if isinstance(self._dyn_obj, gdyn.LinearDynamicsBase):
                state_mat = self._dyn_obj.get_state_mat(timestep, *dyn_fun_params)
                dt = self.dt
            else:
                state_mat = self._dyn_obj.get_state_mat(
                    timestep, cur_state, *dyn_fun_params, use_continuous=self.cont_cov
                )
                dt = self._dyn_obj.dt
            return next_state, state_mat, dt

        if self._ode_lst is None:
            raise RuntimeError("State model not set")

        # fail fast before building the integrator
        if self.dt is None:
            raise RuntimeError("dt must be set when using an ODE list")

        self._integrator = s_integrate.ode(self._cont_dyn)
        self._integrator.set_integrator(self.integrator_type, **self.integrator_params)
        self._integrator.set_initial_value(cur_state, timestep)
        self._integrator.set_f_params(*dyn_fun_params)

        next_state = self._integrator.integrate(timestep + self.dt).reshape(
            cur_state.shape
        )
        if not self._integrator.successful():
            msg = "Integration failed at time {}".format(timestep)
            raise RuntimeError(msg)

        if not self.cont_cov:
            raise NotImplementedError(
                "Non-continous covariance is not implemented yet for ode list"
            )
        state_mat = gmath.get_state_jacobian(
            timestep, cur_state, self._ode_lst, dyn_fun_params
        )
        return next_state, state_mat, self.dt

    def _init_model(self):
        """Create the C++ filter on first use when the backend can be used.

        The backend is used only when the dynamics object allows it and a
        measurement object is set. The C++ filter is seeded from the current
        Python covariance, dynamics, and measurement settings.
        """
        self._cpp_needs_init = (
            self.__model is None
            and (self._dyn_obj is not None and self._dyn_obj.allow_cpp)
            and self._measObj is not None
        )
        if self._cpp_needs_init:
            self.__model = cpp_bindings.ExtendedKalman()
            self.__predParams = cpp_bindings.BayesPredictParams()
            self.__corrParams = cpp_bindings.BayesCorrectParams()

            # make sure the cpp filter has its values set based on what python user gave (init only)
            self.__model.cov = self._cov.astype(np.float64)
            self.__model.set_state_model(self._dyn_obj.model, self.proc_noise)
            self.__model.set_measurement_model(self._measObj, self.meas_noise)

    def _propagate_cont_cov(self, timestep, dt, state_mat, n_states):
        """Integrate the continuous covariance model over one step.

        Solves :math:`\\dot{P} = F P + P F^T + Q` from ``timestep`` to
        ``timestep + dt``, starting at the current covariance.

        Raises
        ------
        RuntimeError
            If the integration fails.
        """

        def cov_ode(t, x, n_states, F, proc_noise):
            P = x.reshape((n_states, n_states))
            P_dot = F @ P + P @ F.T + proc_noise
            return P_dot.ravel()

        integrator = s_integrate.ode(cov_ode)
        integrator.set_integrator(self.integrator_type, **self.integrator_params)
        integrator.set_initial_value(self.cov.flatten(), timestep)
        integrator.set_f_params(n_states, state_mat, self.proc_noise)
        cov_flat = integrator.integrate(timestep + dt)
        if not integrator.successful():
            msg = "Failed to integrate covariance at {}".format(timestep)
            raise RuntimeError(msg)
        return cov_flat.reshape(self.cov.shape)

    def predict(
        self,
        timestep,
        cur_state,
        cur_input=None,
        dyn_fun_params=None,
        control_fun_params=None,
    ):
        r"""Prediction step of the EKF.

        This assumes continuous time dynamics and integrates the ode's to get
        the next state.

        .. math::
            x_{k+1} = \int_t^{t+dt} f(t, x, \phi) dt

        for arbitrary parameters :math:`\phi`


        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : numpy array, optional
            Control input. It is forwarded only to the C++ backend; the Python
            path currently ignores it. The default is None.
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function. The default
            is None, which is treated as ``()``.
        control_fun_params : tuple, optional
            Extra arguments to be passed to the control input function. Only
            used by the C++ backend. The default is None, which is treated as
            ``()``.

        Raises
        ------
        RuntimeError
            Integration fails, or state model not set.

        Returns
        -------
        next_state : N x 1 numpy array
            The predicted state.
        """
        # Normalise once so the C++ and Python paths see the same defaults.
        if dyn_fun_params is None:
            dyn_fun_params = ()
        if control_fun_params is None:
            control_fun_params = ()

        self._init_model()
        if self.__model is not None:
            (
                self.__predParams.stateTransParams,
                self.__predParams.controlParams,
            ) = self._dyn_obj.args_to_params(dyn_fun_params, control_fun_params)[:2]
            return self.__model.predict(
                timestep, cur_state, cur_input, self.__predParams
            ).reshape((-1, 1))

        next_state, state_mat, dt = self._predict_next_state(
            timestep, cur_state, dyn_fun_params
        )

        if self.cont_cov:
            if dt is None:
                raise RuntimeError(
                    "dt can not be None when using a continuous covariance model"
                )
            self.cov = self._propagate_cont_cov(timestep, dt, state_mat, cur_state.size)
        else:
            self.cov = state_mat @ self.cov @ state_mat.T + self.proc_noise
        return next_state

    def _get_meas_mat(self, t, state, n_meas, meas_fun_args):
        """Return the measurement matrix, linearising h numerically if needed.

        With measurement functions set, row ``ii`` is the Jacobian of the
        ``ii``-th function at ``state``; otherwise the constant matrix is used.
        """
        if self._meas_fnc is None:
            # constant matrix
            return self._meas_mat

        # non-linear mapping, potentially time varying: build partials row by row
        meas_mat = np.zeros((n_meas, state.size))
        for ii, h in enumerate(self._meas_fnc):
            # the lambda is consumed within this iteration, so capturing h is safe
            res = gmath.get_jacobian(
                state.copy(),
                lambda _x, *_f_args: h(t, _x, *_f_args),
                f_args=meas_fun_args,
            )
            meas_mat[[ii], :] = res.T
        return meas_mat

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Estimate the measurement for ``cur_state``.

        Returns
        -------
        est_meas : Nm x 1 numpy array
            Estimated measurement.
        meas_mat : Nm x N numpy array
            Measurement matrix (or its linearisation).
        """
        meas_mat = self._get_meas_mat(timestep, cur_state, n_meas, meas_fun_args)

        if self._meas_fnc is not None:
            est_meas = np.nan * np.ones((n_meas, 1))
            for ii, h in enumerate(self._meas_fnc):
                est_meas[ii] = h(timestep, cur_state, *meas_fun_args)
        else:
            est_meas = meas_mat @ cur_state
        return est_meas, meas_mat

    def set_measurement_model(self, meas_mat=None, meas_fun_lst=None, measObj=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a set of
        non-linear functions (potentially time varying) to map states to
        measurements.

        Notes
        -----
        The constant matrix assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        and the non-linear case assumes

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun_lst : list, optional
            Non-linear functions that return the expected measurement for the
            given state. Each function must have the signature `h(t, x, *args)`.
            The default is None.
        measObj : object, optional
            Measurement model object forwarded to the parent class; it is also
            what the C++ backend is built from. The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.
        """
        super().set_measurement_model(
            meas_mat=meas_mat, meas_fun=meas_fun_lst, measObj=measObj
        )
        # the cached C++ filter (if any) was built from the old measurement model
        self._reset_cpp_model()

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements a discrete time correction step for a Kalman Filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        gncpy.errors.ExtremeMeasurementNoiseError
            If the measurement fit probability calculation fails.
        numpy.linalg.LinAlgError
            If the innovation covariance cannot be factored or inverted.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.
        """
        self._init_model()

        if self.__model is not None:
            self.__corrParams.measParams = self._measObj.args_to_params(meas_fun_args)
            out = self.__model.correct(timestep, meas, cur_state, self.__corrParams)
            return out[0].reshape((-1, 1)), out[1]

        est_meas, meas_mat = self._est_meas(timestep, cur_state, meas.size, meas_fun_args)

        # get the Kalman gain
        cov_meas_T = self.cov @ meas_mat.T
        inov_cov = meas_mat @ cov_meas_T

        # estimate the measurement noise online if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, inov_cov)
        inov_cov += self.meas_noise
        # enforce symmetry against round-off before factoring/inverting
        inov_cov = (inov_cov + inov_cov.T) * 0.5
        if self.use_cholesky_inverse:
            # S = L L^T  =>  S^-1 = (L^-1)^T (L^-1)
            sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
            inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
        else:
            inv_inov_cov = la.inv(inov_cov)
        kalman_gain = cov_meas_T @ inv_inov_cov

        # update the state with measurement
        inov = meas - est_meas
        next_state = cur_state + kalman_gain @ inov

        # update the covariance
        n_states = cur_state.shape[0]
        self.cov = (np.eye(n_states) - kalman_gain @ meas_mat) @ self.cov

        # calculate the measurement fit probability assuming Gaussian
        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (next_state, meas_fit_prob)