Coverage for src/gncpy/filters/unscented_kalman_filter.py: 88%
105 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1import numpy as np
2import numpy.linalg as la
3from copy import deepcopy
5import gncpy.dynamics.basic as gdyn
6import gncpy.distributions as gdistrib
7import gncpy.math as gmath
8from gncpy.filters.kalman_filter import KalmanFilter
9from gncpy.filters.extended_kalman_filter import ExtendedKalmanFilter
class UnscentedKalmanFilter(ExtendedKalmanFilter):
    """Implements an unscented kalman filter.

    This allows for linear or non-linear dynamics by utilizing either the
    underlying KF or EKF functions where appropriate. It utilizes the same
    constraints on the measurement model as the EKF.

    Notes
    -----
    For details on the filter see
    :cite:`Wan2000_TheUnscentedKalmanFilterforNonlinearEstimation`. This
    implementation assumes that the noise is purely additive and as such,
    appropriate simplifications have been made. These simplifications include
    not using sigma points to track the noise, and using the fixed process and
    measurement noise covariance matrices in the filter's covariance updates.

    Attributes
    ----------
    alpha : float
        Tuning parameter for sigma points, influences the spread of sigma
        points about the mean. In range (0, 1]. If specified then a value
        does not need to be given to the :meth:`.init_sigma_points` function.
        Defaults to 1.
    kappa : float
        Tuning parameter for sigma points, influences the spread of sigma
        points about the mean. In range [0, inf]. If specified then a value
        does not need to be given to the :meth:`.init_sigma_points` function.
        Defaults to 0.
    beta : float
        Tuning parameter for sigma points. In range [0, Inf]. If specified then
        a value does not need to be given to the :meth:`.init_sigma_points`
        function. Defaults to 2 (ideal for gaussians).
    """

    def __init__(self, sigmaPoints=None, **kwargs):
        """Initialize an instance.

        Parameters
        ----------
        sigmaPoints : :class:`.distributions.SigmaPoints`, optional
            Set of initialized sigma points to use. When given, the filter's
            `alpha`, `beta` and `kappa` are copied from it. A value that is
            not a :class:`.distributions.SigmaPoints` instance is ignored.
            The default is None.
        **kwargs : dict, optional
            Additional arguments for parent constructors.
        """
        self.alpha = 1
        self.kappa = 0
        self.beta = 2

        self._stateSigmaPoints = None
        if isinstance(sigmaPoints, gdistrib.SigmaPoints):
            self._stateSigmaPoints = sigmaPoints
            # keep the filter level tuning values in sync with the given points
            self.alpha = sigmaPoints.alpha
            self.beta = sigmaPoints.beta
            self.kappa = sigmaPoints.kappa
        self._use_lin_dyn = False
        self._use_non_lin_dyn = False
        self._est_meas_noise_fnc = None

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent's saved state extended with the UKF specific
            variables. The sigma points are deep copied.
        """
        filt_state = super().save_filter_state()

        filt_state["alpha"] = self.alpha
        filt_state["kappa"] = self.kappa
        filt_state["beta"] = self.beta

        filt_state["_stateSigmaPoints"] = deepcopy(self._stateSigmaPoints)
        filt_state["_use_lin_dyn"] = self._use_lin_dyn
        filt_state["_use_non_lin_dyn"] = self._use_non_lin_dyn
        filt_state["_est_meas_noise_fnc"] = self._est_meas_noise_fnc

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.alpha = filt_state["alpha"]
        self.kappa = filt_state["kappa"]
        self.beta = filt_state["beta"]

        # copy so later filtering does not modify the saved dictionary
        self._stateSigmaPoints = deepcopy(filt_state["_stateSigmaPoints"])
        self._use_lin_dyn = filt_state["_use_lin_dyn"]
        self._use_non_lin_dyn = filt_state["_use_non_lin_dyn"]
        self._est_meas_noise_fnc = filt_state["_est_meas_noise_fnc"]

    def init_sigma_points(self, state0, alpha=None, kappa=None, beta=None):
        """Initializes the sigma points used by the filter.

        The filter's covariance (`cov`) is used to spread the points about
        `state0`, so it must be set before calling this.

        Parameters
        ----------
        state0 : N x 1 numpy array
            Initial state.
        alpha : float, optional
            Tuning parameter, influences the spread of sigma points about the
            mean. In range (0, 1]. If not supplied the class value will be used.
            If a value is given here then the class value will be updated.
        kappa : float, optional
            Tuning parameter, influences the spread of sigma points about the
            mean. In range [0, inf]. If not supplied the class value will be used.
            If a value is given here then the class value will be updated.
        beta : float, optional
            Tuning parameter for distribution type. In range [0, Inf]. If not
            supplied the class value will be used. If a value is given here
            then the class value will be updated.
            Defaults to 2 for gaussians.
        """
        num_axes = state0.size
        if alpha is None:
            alpha = self.alpha
        else:
            self.alpha = alpha
        if kappa is None:
            kappa = self.kappa
        else:
            self.kappa = kappa
        if beta is None:
            beta = self.beta
        else:
            self.beta = beta
        self._stateSigmaPoints = gdistrib.SigmaPoints(
            alpha=alpha, kappa=kappa, beta=beta, num_axes=num_axes
        )
        self._stateSigmaPoints.init_weights()
        self._stateSigmaPoints.update_points(state0, self.cov)

    def set_state_model(
        self,
        state_mat=None,
        input_mat=None,
        cont_time=False,
        state_mat_fun=None,
        input_mat_fun=None,
        dyn_obj=None,
        ode_lst=None,
    ):
        """Sets the state model for the filter.

        This can use either linear dynamics (by calling the kalman filters
        :meth:`gncpy.filters.KalmanFilter.set_state_model`) or non-linear dynamics
        (by calling :meth:`gncpy.filters.ExtendedKalmanFilter.set_state_model`).
        The linearness is automatically determined by the input arguments
        specified. If arguments for both a linear and a non-linear model are
        given, the linear model is used.

        Parameters
        ----------
        state_mat : N x N numpy array, optional
            State matrix, continuous or discrete case. The default is None.
        input_mat : N x Nu numpy array, optional
            Input matrix, continuous or discrete case. The default is None.
        cont_time : bool, optional
            Flag indicating if the continuous model is provided. The default
            is False.
        state_mat_fun : callable, optional
            Function that returns the `state_mat`, must take timestep and
            `*args`. The default is None.
        input_mat_fun : callable, optional
            Function that returns the `input_mat`, must take timestep, and
            `*args`. The default is None.
        dyn_obj : :class:`gncpy.dynamics.LinearDynamicsBase` or :class:`gncpy.dynamics.NonlinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If an invalid state model or combination of inputs is specified.

        Returns
        -------
        None.
        """
        self._use_lin_dyn = (
            state_mat is not None
            or state_mat_fun is not None
            or isinstance(dyn_obj, gdyn.LinearDynamicsBase)
        )
        # the linear model takes priority when both kinds are specified
        self._use_non_lin_dyn = (
            isinstance(dyn_obj, gdyn.NonlinearDynamicsBase) or ode_lst is not None
        ) and not self._use_lin_dyn

        # allow for linear or non linear dynamics by calling the appropriate parent
        if self._use_lin_dyn:
            KalmanFilter.set_state_model(
                self,
                state_mat=state_mat,
                input_mat=input_mat,
                cont_time=cont_time,
                state_mat_fun=state_mat_fun,
                input_mat_fun=input_mat_fun,
                dyn_obj=dyn_obj,
            )
        elif self._use_non_lin_dyn:
            ExtendedKalmanFilter.set_state_model(self, dyn_obj=dyn_obj, ode_lst=ode_lst)
        else:
            raise RuntimeError("Invalid state model.")

    def set_measurement_noise_estimator(self, function):
        """Sets the model used for estimating the measurement noise parameters.

        This is an optional step and the filter will work properly if this is
        not called. If it is called, the measurement noise will be estimated
        during the filter's correction step and the measurement noise
        attribute will be overwritten with that estimate.

        Parameters
        ----------
        function : callable
            A function that implements the prediction and correction steps for
            an appropriate filter to estimate the measurement noise covariance
            matrix. It must have the signature `f(est_meas, meas_cov)` where
            `est_meas` is the weighted mean of the sigma point measurements
            and `meas_cov` is the weighted covariance of the sigma point
            measurements before any measurement noise is added. It must
            return an Nm x Nm numpy array representing the measurement noise
            covariance matrix.

        Returns
        -------
        None.
        """
        self._est_meas_noise_fnc = function

    def _require_sigma_points(self):
        """Raise a descriptive error if the sigma points were never set."""
        if self._stateSigmaPoints is None:
            raise RuntimeError(
                "Sigma points not initialized, call init_sigma_points first."
            )

    def predict(
        self,
        timestep,
        cur_state,
        cur_input=None,
        state_mat_args=(),
        input_mat_args=(),
        dyn_fun_params=(),
    ):
        """Prediction step of the UKF.

        Automatically calls the state propagation method from either
        :meth:`gncpy.filters.KalmanFilter.predict` or
        :meth:`gncpy.filters.ExtendedKalmanFilter.predict`
        depending on if a linear or non-linear state model was specified.
        If a linear model is used only the parameters that can be passed to
        :meth:`gncpy.filters.KalmanFilter.predict` will be used by this function.
        Otherwise the parameters for
        :meth:`gncpy.filters.ExtendedKalmanFilter.predict` will be used.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : N x Nu numpy array, optional
            Current input for linear models. The default is None.
        state_mat_args : tuple, optional
            keyword arguments for the get state matrix function if one has
            been specified or the propagate state function if using a linear
            dynamic object. The default is ().
        input_mat_args : tuple, optional
            keyword arguments for the get input matrix function if one has
            been specified or the propagate state function if using a linear
            dynamic object. The default is ().
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function for non-linear
            models. The default is ().

        Raises
        ------
        RuntimeError
            If a state model has not been set or the sigma points have not
            been initialized.

        Returns
        -------
        next_state : N x 1 numpy array
            The next state.

        """
        # validate before touching any filter state so a failed call has no
        # side effects
        if not (self._use_lin_dyn or self._use_non_lin_dyn):
            raise RuntimeError("State model not specified")
        self._require_sigma_points()

        # re-draw the points about the current estimate
        self._stateSigmaPoints.update_points(cur_state, self.cov)

        # propagate points, each point is passed to the parent as a column
        if self._use_lin_dyn:
            new_points = np.array(
                [
                    KalmanFilter._predict_next_state(
                        self,
                        timestep,
                        x.reshape((x.size, 1)),
                        cur_input,
                        state_mat_args,
                        input_mat_args,
                    )[0].ravel()
                    for x in self._stateSigmaPoints.points
                ]
            )
        else:
            new_points = np.array(
                [
                    ExtendedKalmanFilter._predict_next_state(
                        self, timestep, x.reshape((x.size, 1)), dyn_fun_params
                    )[0].ravel()
                    for x in self._stateSigmaPoints.points
                ]
            )
        self._stateSigmaPoints.points = new_points

        # update covariance, additive process noise then force symmetry
        self.cov = self._stateSigmaPoints.cov + self.proc_noise
        self.cov = (self.cov + self.cov.T) * 0.5

        # estimate weighted state output
        next_state = self._stateSigmaPoints.mean

        return next_state

    def _calc_meas_cov(self, timestep, n_meas, meas_fun_args):
        """Estimates the measurement and its covariance from the sigma points.

        If a measurement noise estimator has been set, the filter's
        measurement noise attribute is overwritten with its output.

        Parameters
        ----------
        timestep : float
            Current timestep.
        n_meas : int
            Number of measurements, passed through to the measurement
            estimation function.
        meas_fun_args : tuple
            Arguments passed through to the measurement estimation function.

        Raises
        ------
        RuntimeError
            If the sigma points have not been initialized.

        Returns
        -------
        meas_cov : numpy array
            Weighted covariance of the sigma point measurements plus the
            measurement noise.
        est_points : numpy array
            Estimated measurement for each sigma point, stacked along the
            first axis.
        est_meas : numpy array
            Weighted mean of `est_points`.
        """
        self._require_sigma_points()

        est_points = np.array(
            [
                self._est_meas(timestep, x.reshape((x.size, 1)), n_meas, meas_fun_args)[
                    0
                ]
                for x in self._stateSigmaPoints.points
            ]
        )
        est_meas = gmath.weighted_sum_vec(
            self._stateSigmaPoints.weights_mean, est_points
        )
        diff = est_points - est_meas
        # batched outer product, one matrix per sigma point
        meas_cov_lst = diff @ diff.reshape((est_points.shape[0], 1, est_meas.size))

        partial_cov = gmath.weighted_sum_mat(
            self._stateSigmaPoints.weights_cov, meas_cov_lst
        )
        # estimate the measurement noise if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, partial_cov)
        meas_cov = self.meas_noise + partial_cov

        return meas_cov, est_points, est_meas

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Correction step of the UKF.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.
        RuntimeError
            If the sigma points have not been initialized.

        Returns
        -------
        next_state : N x 1 numpy array
            corrected state.
        meas_fit_prob : float
            measurement fit probability assuming a Gaussian distribution.

        """
        meas_cov, est_points, est_meas = self._calc_meas_cov(
            timestep, meas.size, meas_fun_args
        )

        # weighted cross covariance between the state and measurement points
        state_diff = self._stateSigmaPoints.points - cur_state.ravel()
        meas_diff = (est_points - est_meas).reshape(
            (est_points.shape[0], 1, est_meas.size)
        )
        cross_cov_lst = state_diff.reshape((*state_diff.shape, 1)) @ meas_diff
        cross_cov = gmath.weighted_sum_mat(
            self._stateSigmaPoints.weights_cov, cross_cov_lst
        )

        if self.use_cholesky_inverse:
            # meas_cov = L @ L.T so its inverse is inv(L).T @ inv(L)
            sqrt_inv_meas_cov = la.inv(la.cholesky(meas_cov))
            inv_meas_cov = sqrt_inv_meas_cov.T @ sqrt_inv_meas_cov
        else:
            inv_meas_cov = la.inv(meas_cov)
        gain = cross_cov @ inv_meas_cov
        innovation = meas - est_meas

        self.cov = self.cov - gain @ meas_cov @ gain.T
        # force symmetry to limit numerical drift
        self.cov = (self.cov + self.cov.T) * 0.5
        next_state = cur_state + gain @ innovation

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, meas_cov)

        return next_state, meas_fit_prob