Coverage for src/gncpy/filters/quadrature_kalman_filter.py: 91%

107 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.linalg as la 

3 

4import gncpy.distributions as gdistrib 

5import gncpy.dynamics.basic as gdyn 

6import gncpy.math as gmath 

7from gncpy.filters.kalman_filter import KalmanFilter 

8from gncpy.filters.extended_kalman_filter import ExtendedKalmanFilter 

9 

10 

class QuadratureKalmanFilter(ExtendedKalmanFilter):
    """Implementation of a Quadrature Kalman Filter.

    Notes
    -----
    This implementation is based on
    :cite:`Arasaratnam2007_DiscreteTimeNonlinearFilteringAlgorithmsUsingGaussHermiteQuadrature`
    and uses Gauss-Hermite quadrature points. It inherits from EKF to allow
    for easier implementation of non-linear dynamics.

    Attributes
    ----------
    quadPoints : :class:`gncpy.distributions.QuadraturePoints`
        Quadrature points used by the filter.
    """

    def __init__(self, points_per_axis=None, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        points_per_axis : int, optional
            Number of quadrature points per axis, forwarded to
            :class:`gncpy.distributions.QuadraturePoints`. The default is None.
        **kwargs : dict
            Forwarded unchanged to the parent filter's constructor.
        """
        super().__init__(**kwargs)

        self.quadPoints = gdistrib.QuadraturePoints(points_per_axis=points_per_axis)
        # Cholesky factor of the covariance, refreshed by _factorize_cov
        self._sqrt_cov = np.array([[]])
        # which parent model (linear KF / non-linear EKF) set_state_model selected
        self._use_lin_dyn = False
        self._use_non_lin_dyn = False

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            The parent filter's saved state extended with this filter's
            quadrature points, covariance factor, and dynamics-mode flags.
        """
        filt_state = super().save_filter_state()

        filt_state["quadPoints"] = self.quadPoints

        filt_state["_sqrt_cov"] = self._sqrt_cov
        filt_state["_use_lin_dyn"] = self._use_lin_dyn
        filt_state["_use_non_lin_dyn"] = self._use_non_lin_dyn

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.quadPoints = filt_state["quadPoints"]
        self._sqrt_cov = filt_state["_sqrt_cov"]
        self._use_lin_dyn = filt_state["_use_lin_dyn"]
        self._use_non_lin_dyn = filt_state["_use_non_lin_dyn"]

    @property
    def points_per_axis(self):
        """Wrapper for the number of quadrature points per axis."""
        return self.quadPoints.points_per_axis

    @points_per_axis.setter
    def points_per_axis(self, val):
        self.quadPoints.points_per_axis = val

    def set_state_model(
        self,
        state_mat=None,
        input_mat=None,
        cont_time=False,
        state_mat_fun=None,
        input_mat_fun=None,
        dyn_obj=None,
        ode_lst=None,
    ):
        """Sets the state model for the filter.

        This can use either linear dynamics (by calling the kalman filters
        :meth:`gncpy.filters.KalmanFilter.set_state_model`) or non-linear dynamics
        (by calling :meth:`gncpy.filters.ExtendedKalmanFilter.set_state_model`).
        The linearity is automatically determined by the input arguments specified;
        linear inputs take precedence if both kinds are given.

        Parameters
        ----------
        state_mat : N x N numpy array, optional
            State matrix, continuous or discrete case. The default is None.
        input_mat : N x Nu numpy array, optional
            Input matrix, continuous or discrete case. The default is None.
        cont_time : bool, optional
            Flag indicating if the continuous model is provided. The default
            is False.
        state_mat_fun : callable, optional
            Function that returns the `state_mat`, must take timestep and
            `*args`. The default is None.
        input_mat_fun : callable, optional
            Function that returns the `input_mat`, must take timestep, and
            `*args`. The default is None.
        dyn_obj : :class:`gncpy.dynamics.LinearDynamicsBase` or :class:`gncpy.dynamics.NonlinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If an invalid state model or combination of inputs is specified.

        Returns
        -------
        None.
        """
        self._use_lin_dyn = (
            state_mat is not None
            or state_mat_fun is not None
            or isinstance(dyn_obj, gdyn.LinearDynamicsBase)
        )
        self._use_non_lin_dyn = (
            isinstance(dyn_obj, gdyn.NonlinearDynamicsBase) or ode_lst is not None
        ) and not self._use_lin_dyn

        # allow for linear or non linear dynamics by calling the appropriate parent
        if self._use_lin_dyn:
            KalmanFilter.set_state_model(
                self,
                state_mat=state_mat,
                input_mat=input_mat,
                cont_time=cont_time,
                state_mat_fun=state_mat_fun,
                input_mat_fun=input_mat_fun,
                dyn_obj=dyn_obj,
            )
        elif self._use_non_lin_dyn:
            ExtendedKalmanFilter.set_state_model(self, dyn_obj=dyn_obj, ode_lst=ode_lst)
        else:
            raise RuntimeError("Invalid state model.")

    def set_measurement_model(self, meas_mat=None, meas_fun=None):
        r"""Sets the measurement model for the filter.

        This can either set the constant measurement matrix, or a potentially
        non-linear function.

        Notes
        -----
        This assumes a measurement model of the form

        .. math::
            \tilde{y}_{k+1} = H x_{k+1}^-

        for the measurement matrix case. Or of the form

        .. math::
            \tilde{y}_{k+1} = h(t, x_{k+1}^-)

        for the potentially non-linear case.

        Parameters
        ----------
        meas_mat : Nm x N numpy array, optional
            Measurement matrix that transforms the state to estimated
            measurements. The default is None.
        meas_fun : callable, optional
            Function that transforms the state to estimated measurements. Must
            have the signature :code:`h(t, x, *args)` where `t` is the timestep,
            `x` is an N x 1 numpy array of the current state, and return an
            Nm x 1 numpy array of the estimated measurement. The default is None.

        Raises
        ------
        RuntimeError
            Raised if no arguments are specified.
        """
        # The single callable is forwarded through the parent's `meas_fun_lst`
        # keyword; _est_meas later calls it via self._meas_fnc.
        super().set_measurement_model(meas_mat=meas_mat, meas_fun_lst=meas_fun)

    def _factorize_cov(self, val=None):
        """Store the Cholesky factor of `val` (default: ``self.cov``) in ``self._sqrt_cov``."""
        if val is None:
            val = self.cov
        # numpy.linalg.cholesky returns the lower-triangular L with P = L @ L.T
        self._sqrt_cov = la.cholesky(val)

    def _pred_update_cov(self):
        """Set the predicted covariance to process noise plus the points' covariance."""
        self.cov = self.proc_noise + self.quadPoints.cov

    def _predict_next_state(
        self, timestep, cur_state, cur_input, state_mat_args, input_mat_args, dyn_fun_params
    ):
        """Propagate a single quadrature point through the selected dynamics model.

        Dispatches to the linear (KalmanFilter) or non-linear
        (ExtendedKalmanFilter) parent implementation based on the flags set in
        :meth:`set_state_model`, returning only the propagated state.

        Raises
        ------
        RuntimeError
            If no state model has been set.
        """
        if self._use_lin_dyn:
            return KalmanFilter._predict_next_state(
                self,
                timestep,
                cur_state.reshape((-1, 1)),
                cur_input,
                state_mat_args,
                input_mat_args,
            )[0]
        elif self._use_non_lin_dyn:
            return ExtendedKalmanFilter._predict_next_state(
                self, timestep, cur_state.reshape((-1, 1)), dyn_fun_params
            )[0]
        else:
            raise RuntimeError("State model not specified")

    def predict(
        self,
        timestep,
        cur_state,
        cur_input=None,
        state_mat_args=None,
        input_mat_args=None,
        dyn_fun_params=None,
    ):
        """Prediction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : Nu x 1 numpy array, optional
            Current input. The default is None.
        state_mat_args : tuple, optional
            Additional arguments for the get state matrix function if one has
            been specified, or for the propagate state function if using a
            dynamic object. The default is None, which is treated as ().
        input_mat_args : tuple, optional
            Additional arguments for the get input matrix function if one has
            been specified or the propagate state function if using a dynamic
            object. The default is None, which is treated as ().
        dyn_fun_params : tuple, optional
            Additional arguments to pass to the dynamics function if using
            non-linear dynamics. The default is None, which is treated as ().

        Raises
        ------
        RuntimeError
            If the state model has not been set

        Returns
        -------
        N x 1 numpy array
            The predicted state.
        """
        if state_mat_args is None:
            state_mat_args = ()
        if input_mat_args is None:
            input_mat_args = ()
        if dyn_fun_params is None:
            dyn_fun_params = ()

        # factorize covariance as P = sqrt(P) * sqrt(P)^T
        self._factorize_cov()

        # generate quadrature points as X_i = sqrt(P) * xi_i + x_hat for m points
        self.quadPoints.update_points(cur_state, self._sqrt_cov, have_sqrt=True)

        # predict each point using the dynamics; row ii is overwritten only after
        # it has been read, so in-place update during iteration is safe
        for ii, (point, _) in enumerate(self.quadPoints):
            pred_point = self._predict_next_state(
                timestep, point, cur_input, state_mat_args, input_mat_args, dyn_fun_params
            )
            self.quadPoints.points[ii, :] = pred_point.ravel()

        # update covariance as Q - m * x * x^T + sum(w_i * X_i * X_i^T)
        self._pred_update_cov()

        return self.quadPoints.mean

    def _corr_update_cov(self, gain, inov_cov):
        """Apply P = P - K P_zz K^T, then symmetrize to limit round-off drift."""
        self.cov = self.cov - gain @ inov_cov @ gain.T
        self.cov = 0.5 * (self.cov + self.cov.T)

    def _est_meas(self, timestep, cur_state, n_meas, meas_fun_args):
        """Return the estimated measurement for one state as a flattened array.

        Uses the user-supplied measurement function when one is set, otherwise
        falls back to the parent implementation.
        """
        if self._meas_fnc is not None:
            return self._meas_fnc(timestep, cur_state, *meas_fun_args).ravel()
        else:
            return (
                super()._est_meas(timestep, cur_state, n_meas, meas_fun_args)[0].ravel()
            )

    def _corr_core(self, timestep, cur_state, meas, meas_fun_args):
        """Generate quadrature points and their predicted measurements.

        Returns
        -------
        measQuads : :class:`gncpy.distributions.QuadraturePoints`
            Measurement-space points Z_i, sharing the state points' weights.
        est_meas : numpy array
            Weighted mean of the measurement points.
        """
        # factorize covariance as P = sqrt(P) * sqrt(P)^T
        self._factorize_cov()

        # generate quadrature points as X_i = sqrt(P) * xi_i + x_hat for m points
        self.quadPoints.update_points(cur_state, self._sqrt_cov, have_sqrt=True)

        # Estimate a measurement for each quad point, Z_i
        measQuads = gdistrib.QuadraturePoints(num_axes=meas.size)
        measQuads.points = np.full(
            (self.quadPoints.points.shape[0], meas.size), np.nan
        )
        measQuads.weights = self.quadPoints.weights
        for ii, (point, _) in enumerate(self.quadPoints):
            measQuads.points[ii, :] = self._est_meas(
                timestep, point, meas.size, meas_fun_args
            )

        # estimate predicted measurement as sum of est measurement quad points
        est_meas = measQuads.mean

        return measQuads, est_meas

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements the correction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.

        """
        measQuads, est_meas = self._corr_core(timestep, cur_state, meas, meas_fun_args)

        # estimate the measurement noise online if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, measQuads.cov)

        # estimate innovation cov as P_zz = R - m * z_hat * z_hat^T + sum(w_i * Z_i * Z_i^T)
        inov_cov = self.meas_noise + measQuads.cov

        if self.use_cholesky_inverse:
            # with P_zz = L L^T (L lower triangular): inv(P_zz) = inv(L)^T inv(L)
            sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
            inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
        else:
            inv_inov_cov = la.inv(inov_cov)

        # estimate cross cov as P_xz = sum(w_i * (X_i - x_hat) * (Z_i - z_hat)^T)
        cov_lst = [
            (qp[0] - cur_state) @ (mp[0] - est_meas).T
            for qp, mp in zip(self.quadPoints, measQuads)
        ]
        cross_cov = gmath.weighted_sum_mat(self.quadPoints.weights, cov_lst)

        # calc Kalman gain as K = P_xz * P_zz^-1
        gain = cross_cov @ inv_inov_cov

        # state is x_hat + K *(z - z_hat)
        innov = meas - est_meas
        cor_state = cur_state + gain @ innov

        # update covariance as P = P_k - K * P_zz * K^T
        self._corr_update_cov(gain, inov_cov)

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (cor_state, meas_fit_prob)

    def plot_quadrature(self, inds, **kwargs):
        """Wrapper function for :meth:`gncpy.distributions.QuadraturePoints.plot_points`."""
        return self.quadPoints.plot_points(inds, **kwargs)