Coverage for src/gncpy/filters/extended_kalman_filter.py: 91%

160 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.linalg as la 

3import scipy.linalg as sla 

4import scipy.integrate as s_integrate 

5from copy import deepcopy 

6 

7import gncpy.dynamics.basic as gdyn 

8import gncpy.math as gmath 

9import gncpy.filters._filters as cpp_bindings 

10from gncpy.filters.kalman_filter import KalmanFilter 

11 

12 

class ExtendedKalmanFilter(KalmanFilter):
    """Implementation of a continuous-discrete time Extended Kalman Filter.

    This is loosely based on :cite:`Crassidis2011_OptimalEstimationofDynamicSystems`

    If the dynamics object reports ``allow_cpp`` and a measurement object has
    been set, :meth:`predict` and :meth:`correct` delegate to a C++
    implementation; otherwise the pure Python implementation is used.

    Attributes
    ----------
    cont_cov : bool, optional
        Flag indicating if a continuous model of the covariance matrix should
        be used in the filter prediction step. The default is True.
    integrator_type : string, optional
        integrator type as defined by scipy's integrate.ode function. The
        default is `dopri5`. Used to propagate the state when an ODE list is
        the state model, and to integrate the covariance when
        :attr:`cont_cov` is True (Python path only).
    integrator_params : dict, optional
        additional parameters for the integrator. The default is {}. Used
        wherever :attr:`integrator_type` is used.
    """

    # Class-level defaults for the cached C++ backend objects so they can be
    # read safely before ``__init__`` assigns them (e.g. if a setter that
    # discards them runs during base-class construction). Name mangling turns
    # these into ``_ExtendedKalmanFilter__model`` etc., the same names the
    # instance attributes below use.
    __model = None
    __predParams = None
    __corrParams = None

    def __init__(self, cont_cov=True, dyn_obj=None, ode_lst=None, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        cont_cov : bool, optional
            Initial value of :attr:`cont_cov`. The default is True.
        dyn_obj : :class:`gncpy.dynamics.NonlinearDynamicsBase` or :class:`gncpy.dynamics.LinearDynamicsBase`, optional
            Passed to :meth:`set_state_model`. The default is None.
        ode_lst : list, optional
            Passed to :meth:`set_state_model`. The default is None.
        **kwargs
            Forwarded to :class:`KalmanFilter`.
        """
        super().__init__(**kwargs)

        self.cont_cov = cont_cov
        self.integrator_type = "dopri5"
        self.integrator_params = {}

        self._ode_lst = None
        self._integrator = None

        # Cached C++ backend, built lazily by _init_model.
        self.__model = None
        self.__predParams = None
        self.__corrParams = None

        if dyn_obj is not None or ode_lst is not None:
            self.set_state_model(dyn_obj=dyn_obj, ode_lst=ode_lst)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        ``integrator_params`` is deep-copied. The integrator and the cached
        C++ objects are stored by reference (not copied).

        Returns
        -------
        filt_state : dict
            Base-class state extended with this filter's variables.
        """
        filt_state = super().save_filter_state()

        filt_state["cont_cov"] = self.cont_cov
        filt_state["integrator_type"] = self.integrator_type
        filt_state["integrator_params"] = deepcopy(self.integrator_params)
        filt_state["_ode_lst"] = self._ode_lst
        filt_state["_integrator"] = self._integrator

        filt_state["__model"] = self.__model
        filt_state["__predParams"] = self.__predParams
        filt_state["__corrParams"] = self.__corrParams

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.cont_cov = filt_state["cont_cov"]
        self.integrator_type = filt_state["integrator_type"]
        # Copy so later edits to the filter do not mutate the saved state
        # (mirrors the deepcopy in save_filter_state).
        self.integrator_params = deepcopy(filt_state["integrator_params"])
        self._ode_lst = filt_state["_ode_lst"]
        self._integrator = filt_state["_integrator"]

        self.__model = filt_state["__model"]
        self.__predParams = filt_state["__predParams"]
        self.__corrParams = filt_state["__corrParams"]

    def _discard_cpp_filter(self):
        """Drop the cached C++ filter so it is rebuilt from the current models.

        The C++ filter copies the state and measurement models only when it is
        constructed (see :meth:`_init_model`), so it must be discarded whenever
        either model changes. The next :meth:`predict`/:meth:`correct` call
        rebuilds it if the configuration still allows it.
        """
        if self.__model is not None:
            # The C++ filter was seeded from self._cov and, presumably, has
            # evolved its own covariance since; carry it over so the rebuilt
            # filter (seeded from self._cov again) continues from it.
            self._cov = np.array(self.__model.cov, dtype=np.float64)
        self.__model = None
        self.__predParams = None
        self.__corrParams = None

    def set_state_model(self, dyn_obj=None, ode_lst=None):
        r"""Sets the state model equations.

        This allows for setting the differential equations directly

        .. math::
            \dot{x} = f(t, x, u)

        or setting a :class:`gncpy.dynamics.NonlinearDynamicsBase` object. If
        the object is specified then a local copy is created. A
        :class:`gncpy.dynamics.LinearDynamicsBase` can also be used in which
        case the dynamics follow the same form as the KF. If a linear dynamics
        object is used then it is recommended to set the filters dt manually so
        a continuous covariance model can be used in the prediction step.

        A valid dynamics object takes precedence over ``ode_lst``. Whichever
        model is chosen replaces any previously set model of the other kind,
        and any cached C++ filter is discarded so it is rebuilt with the new
        dynamics.

        Parameters
        ----------
        dyn_obj : :class:`gncpy.dynamics.NonlinearDynamicsBase` or :class:`gncpy.dynamics.LinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If neither argument specifies a valid model.

        Returns
        -------
        None.
        """
        if isinstance(
            dyn_obj, (gdyn.NonlinearDynamicsBase, gdyn.LinearDynamicsBase)
        ):
            self._dyn_obj = deepcopy(dyn_obj)
            # _predict_next_state prefers _dyn_obj, but clear the ODE list so
            # the active model is unambiguous.
            self._ode_lst = None
        elif ode_lst is not None and len(ode_lst) > 0:
            self._ode_lst = ode_lst
            # Must clear the dynamics object, otherwise _predict_next_state
            # would keep using it and the new ODE list would be ignored.
            self._dyn_obj = None
        else:
            msg = "Invalid state model specified. Check arguments"
            raise RuntimeError(msg)

        self._discard_cpp_filter()

    def _cont_dyn(self, t, x, *args):
        """Used in integrator if an ode list is specified.

        Evaluates each function of the ODE list to build :math:`\\dot{x}`.
        """
        out = np.zeros(x.shape)

        for ii, f in enumerate(self._ode_lst):
            out[ii] = f(t, x, *args)
        return out

    def _predict_next_state(self, timestep, cur_state, dyn_fun_params):
        """Propagate the state one step with the Python state model.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        dyn_fun_params : tuple
            Extra arguments for the dynamics.

        Returns
        -------
        next_state : numpy array
            Propagated state.
        state_mat : numpy array
            State matrix used to propagate the covariance. It comes from the
            dynamics object, or is the Jacobian of the ODE list.
        dt : float or None
            Step size for the continuous covariance integration.

        Raises
        ------
        RuntimeError
            If no state model is set, ``dt`` is missing for an ODE list, or
            the state integration fails.
        NotImplementedError
            If an ODE list is used with ``cont_cov`` False.
        """
        if self._dyn_obj is not None:
            next_state = self._dyn_obj.propagate_state(
                timestep, cur_state, state_args=dyn_fun_params
            )
            if isinstance(self._dyn_obj, gdyn.LinearDynamicsBase):
                state_mat = self._dyn_obj.get_state_mat(timestep, *dyn_fun_params)
                dt = self.dt
            else:
                state_mat = self._dyn_obj.get_state_mat(
                    timestep, cur_state, *dyn_fun_params, use_continuous=self.cont_cov
                )
                dt = self._dyn_obj.dt
        elif self._ode_lst is not None:
            # Validate before building the integrator.
            if self.dt is None:
                raise RuntimeError("dt must be set when using an ODE list")

            self._integrator = s_integrate.ode(self._cont_dyn)
            self._integrator.set_integrator(
                self.integrator_type, **self.integrator_params
            )
            self._integrator.set_initial_value(cur_state, timestep)
            self._integrator.set_f_params(*dyn_fun_params)

            next_time = timestep + self.dt
            next_state = self._integrator.integrate(next_time).reshape(cur_state.shape)
            if not self._integrator.successful():
                msg = "Integration failed at time {}".format(timestep)
                raise RuntimeError(msg)
            if self.cont_cov:
                state_mat = gmath.get_state_jacobian(
                    timestep, cur_state, self._ode_lst, dyn_fun_params
                )
            else:
                raise NotImplementedError(
                    "Non-continous covariance is not implemented yet for ode list"
                )
            dt = self.dt
        else:
            raise RuntimeError("State model not set")
        return next_state, state_mat, dt

    def _init_model(self):
        """Lazily build the C++ filter when the configuration supports it.

        A C++ filter is built only if none is cached, the dynamics object
        allows C++ (``allow_cpp``), and a measurement object is set.
        ``self._cpp_needs_init`` records whether this call built one.
        """
        self._cpp_needs_init = (
            self.__model is None
            and (self._dyn_obj is not None and self._dyn_obj.allow_cpp)
            and self._measObj is not None
        )
        if self._cpp_needs_init:
            self.__model = cpp_bindings.ExtendedKalman()
            self.__predParams = cpp_bindings.BayesPredictParams()
            self.__corrParams = cpp_bindings.BayesCorrectParams()

            # make sure the cpp filter has its values set based on what python user gave (init only)
            self.__model.cov = self._cov.astype(np.float64)
            self.__model.set_state_model(self._dyn_obj.model, self.proc_noise)
            self.__model.set_measurement_model(self._measObj, self.meas_noise)

    def predict(
        self,
        timestep,
        cur_state,
        cur_input=None,
        dyn_fun_params=None,
        control_fun_params=None,
    ):
        r"""Prediction step of the EKF.

        This assumes continuous time dynamics and integrates the ode's to get
        the next state.

        .. math::
            x_{k+1} = \int_t^{t+dt} f(t, x, \phi) dt

        for arbitrary parameters :math:`\phi`

        On the Python path the covariance is also propagated. With
        :attr:`cont_cov` it integrates
        :math:`\dot{P} = F P + P F^T + Q` over ``dt``; otherwise it uses
        :math:`P = F P F^T + Q`.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : numpy array, optional
            Control input. Only forwarded to the C++ backend; the Python path
            does not use it. The default is None.
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function. The default
            is None.
        control_fun_params : tuple, optional
            Extra arguments to be passed to the control input function. Only
            used by the C++ backend. The default is None.

        Raises
        ------
        RuntimeError
            Integration fails, state model not set, or dt is missing for a
            continuous covariance model.

        Returns
        -------
        next_state : N x 1 numpy array
            The predicted state.
        """
        self._init_model()
        if self.__model is not None:
            if control_fun_params is None:
                control_fun_params = ()
            # NOTE(review): dyn_fun_params is forwarded as-is (possibly None);
            # presumably args_to_params accepts None - confirm.
            (
                self.__predParams.stateTransParams,
                self.__predParams.controlParams,
            ) = self._dyn_obj.args_to_params(dyn_fun_params, control_fun_params)[:2]
            return self.__model.predict(
                timestep, cur_state, cur_input, self.__predParams
            ).reshape((-1, 1))

        if dyn_fun_params is None:
            dyn_fun_params = ()
        next_state, state_mat, dt = self._predict_next_state(
            timestep, cur_state, dyn_fun_params
        )

        if self.cont_cov:
            if dt is None:
                raise RuntimeError(
                    "dt can not be None when using a continuous covariance model"
                )

            def ode(t, x, n_states, F, proc_noise):
                # Covariance differential equation on the flattened matrix.
                P = x.reshape((n_states, n_states))
                P_dot = F @ P + P @ F.T + proc_noise
                return P_dot.ravel()

            integrator = s_integrate.ode(ode)
            integrator.set_integrator(self.integrator_type, **self.integrator_params)
            integrator.set_initial_value(self.cov.flatten(), timestep)
            integrator.set_f_params(cur_state.size, state_mat, self.proc_noise)
            tmp = integrator.integrate(timestep + dt)
            if not integrator.successful():
                msg = "Failed to integrate covariance at {}".format(timestep)
                raise RuntimeError(msg)
            self.cov = tmp.reshape(self.cov.shape)
        else:
            self.cov = state_mat @ self.cov @ state_mat.T + self.proc_noise
        return next_state

    def _get_meas_mat(self, t, state, n_meas, meas_fun_args):
        """Return the measurement matrix.

        For a list of measurement functions, the Jacobian is computed
        numerically, one row per function. Otherwise the stored constant
        matrix is returned.
        """
        # non-linear mapping, potentially time varying
        if self._meas_fnc is not None:
            # calculate partial derivatives
            meas_mat = np.zeros((n_meas, state.size))
            for ii, h in enumerate(self._meas_fnc):
                res = gmath.get_jacobian(
                    state.copy(),
                    lambda _x, *_f_args: h(t, _x, *_f_args),
                    f_args=meas_fun_args,
                )
                meas_mat[[ii], :] = res.T
        else:
            # constant matrix
            meas_mat = self._meas_mat
        return meas_mat

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Return the expected measurement (n_meas x 1) and the measurement matrix."""
        meas_mat = self._get_meas_mat(timestep, cur_state, n_meas, meas_fun_args)

        if self._meas_fnc is not None:
            est_meas = np.nan * np.ones((n_meas, 1))
            for ii, h in enumerate(self._meas_fnc):
                est_meas[ii] = h(timestep, cur_state, *meas_fun_args)
        else:
            est_meas = meas_mat @ cur_state
        return est_meas, meas_mat

    def set_measurement_model(self, meas_mat=None, meas_fun_lst=None, measObj=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a set of
        non-linear functions (potentially time varying) to map states to
        measurements. Any cached C++ filter is discarded so it is rebuilt with
        the new measurement model.

        Notes
        -----
        The constant matrix assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        and the non-linear case assumes

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun_lst : list, optional
            Non-linear functions that return the expected measurement for the
            given state. Each function must have the signature `h(t, x, *args)`.
            The default is None.
        measObj : object, optional
            Measurement model object, forwarded to the base class. It is what
            the C++ backend uses, and its ``args_to_params`` is called in
            :meth:`correct`. The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.

        Returns
        -------
        None.
        """
        super().set_measurement_model(
            meas_mat=meas_mat, meas_fun=meas_fun_lst, measObj=measObj
        )
        # The C++ filter copies the measurement model only at construction.
        self._discard_cpp_filter()

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements a discrete time correction step for a Kalman Filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        gncpy.errors.ExtremeMeasurementNoiseError
            If the measurement fit probability calculation fails.
        numpy.linalg.LinAlgError
            If the innovation covariance cannot be inverted (or, with
            ``use_cholesky_inverse``, is not positive definite).

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.
        """
        self._init_model()

        if self.__model is not None:
            self.__corrParams.measParams = self._measObj.args_to_params(meas_fun_args)
            out = self.__model.correct(timestep, meas, cur_state, self.__corrParams)
            return out[0].reshape((-1, 1)), out[1]

        est_meas, meas_mat = self._est_meas(
            timestep, cur_state, meas.size, meas_fun_args
        )

        # get the Kalman gain
        cov_meas_T = self.cov @ meas_mat.T
        inov_cov = meas_mat @ cov_meas_T

        # estimate the measurement noise online if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, inov_cov)
        inov_cov += self.meas_noise
        # enforce symmetry against round-off before inverting
        inov_cov = (inov_cov + inov_cov.T) * 0.5
        if self.use_cholesky_inverse:
            # S = L L^T  =>  S^-1 = L^-T L^-1
            sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
            inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
        else:
            inv_inov_cov = la.inv(inov_cov)
        kalman_gain = cov_meas_T @ inv_inov_cov

        # update the state with measurement
        inov = meas - est_meas
        next_state = cur_state + kalman_gain @ inov

        # update the covariance
        n_states = cur_state.shape[0]
        self.cov = (np.eye(n_states) - kalman_gain @ meas_mat) @ self.cov

        # calculate the measurement fit probability assuming Gaussian
        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (next_state, meas_fit_prob)