Coverage for src/gncpy/filters/unscented_kalman_filter.py: 88%

105 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.linalg as la 

3from copy import deepcopy 

4 

5import gncpy.dynamics.basic as gdyn 

6import gncpy.distributions as gdistrib 

7import gncpy.math as gmath 

8from gncpy.filters.kalman_filter import KalmanFilter 

9from gncpy.filters.extended_kalman_filter import ExtendedKalmanFilter 

10 

11 

class UnscentedKalmanFilter(ExtendedKalmanFilter):
    """Implements an unscented Kalman filter.

    This allows for linear or non-linear dynamics by utilizing either the
    underlying KF or EKF functions where appropriate. It utilizes the same
    constraints on the measurement model as the EKF.

    Notes
    -----
    For details on the filter see
    :cite:`Wan2000_TheUnscentedKalmanFilterforNonlinearEstimation`. This
    implementation assumes that the noise is purely additive and as such,
    appropriate simplifications have been made. These simplifications include
    not using sigma points to track the noise, and using the fixed process and
    measurement noise covariance matrices in the filter's covariance updates
    (unless a measurement noise estimator has been set, see
    :meth:`.set_measurement_noise_estimator`).

    Attributes
    ----------
    alpha : float
        Tuning parameter for sigma points, influences the spread of sigma
        points about the mean. In range (0, 1]. If specified then a value does
        not need to be given to the :meth:`.init_sigma_points` function.
        Defaults to 1.
    kappa : float
        Tuning parameter for sigma points, influences the spread of sigma
        points about the mean. In range [0, inf]. If specified then a value
        does not need to be given to the :meth:`.init_sigma_points` function.
        Defaults to 0.
    beta : float
        Tuning parameter for sigma points. In range [0, Inf]. If specified then
        a value does not need to be given to the :meth:`.init_sigma_points`
        function. Defaults to 2 (ideal for Gaussians).
    """

    def __init__(self, sigmaPoints=None, **kwargs):
        """Initialize an instance.

        Parameters
        ----------
        sigmaPoints : :class:`.distributions.SigmaPoints`, optional
            Set of initialized sigma points to use. If given, the filter's
            `alpha`, `beta`, and `kappa` are copied from it. The default is
            None, in which case :meth:`.init_sigma_points` must be called
            before filtering.
        **kwargs : dict, optional
            Additional arguments for parent constructors.
        """
        self.alpha = 1
        self.kappa = 0
        self.beta = 2

        self._stateSigmaPoints = None
        if isinstance(sigmaPoints, gdistrib.SigmaPoints):
            self._stateSigmaPoints = sigmaPoints
            self.alpha = sigmaPoints.alpha
            self.beta = sigmaPoints.beta
            self.kappa = sigmaPoints.kappa
        self._use_lin_dyn = False
        self._use_non_lin_dyn = False
        self._est_meas_noise_fnc = None

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            Parent filter state extended with the UKF specific variables. The
            sigma points are deep copied.
        """
        filt_state = super().save_filter_state()

        filt_state["alpha"] = self.alpha
        filt_state["kappa"] = self.kappa
        filt_state["beta"] = self.beta

        filt_state["_stateSigmaPoints"] = deepcopy(self._stateSigmaPoints)
        filt_state["_use_lin_dyn"] = self._use_lin_dyn
        filt_state["_use_non_lin_dyn"] = self._use_non_lin_dyn
        filt_state["_est_meas_noise_fnc"] = self._est_meas_noise_fnc

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.alpha = filt_state["alpha"]
        self.kappa = filt_state["kappa"]
        self.beta = filt_state["beta"]

        # deep copy so the restored filter does not share sigma points with
        # the saved dictionary (which may be loaded again later)
        self._stateSigmaPoints = deepcopy(filt_state["_stateSigmaPoints"])
        self._use_lin_dyn = filt_state["_use_lin_dyn"]
        self._use_non_lin_dyn = filt_state["_use_non_lin_dyn"]
        self._est_meas_noise_fnc = filt_state["_est_meas_noise_fnc"]

    def _check_sigma_points(self):
        """Raise a descriptive error if the sigma points are not initialized."""
        if self._stateSigmaPoints is None:
            raise RuntimeError(
                "Sigma points not initialized, call init_sigma_points or "
                "provide sigmaPoints to the constructor."
            )

    def init_sigma_points(self, state0, alpha=None, kappa=None, beta=None):
        """Initializes the sigma points used by the filter.

        Parameters
        ----------
        state0 : N x 1 numpy array
            Initial state.
        alpha : float, optional
            Tuning parameter, influences the spread of sigma points about the
            mean. In range (0, 1]. If not supplied the class value will be used.
            If a value is given here then the class value will be updated.
        kappa : float, optional
            Tuning parameter, influences the spread of sigma points about the
            mean. In range [0, inf]. If not supplied the class value will be used.
            If a value is given here then the class value will be updated.
        beta : float, optional
            Tuning parameter for distribution type. In range [0, Inf]. If not
            supplied the class value will be used. If a value is given here
            then the class value will be updated.
            Defaults to 2 for Gaussians.

        Notes
        -----
        The current filter covariance (`self.cov`) is used to place the
        points, so it must be set before calling this function.
        """
        num_axes = state0.size
        if alpha is None:
            alpha = self.alpha
        else:
            self.alpha = alpha
        if kappa is None:
            kappa = self.kappa
        else:
            self.kappa = kappa
        if beta is None:
            beta = self.beta
        else:
            self.beta = beta
        self._stateSigmaPoints = gdistrib.SigmaPoints(
            alpha=alpha, kappa=kappa, beta=beta, num_axes=num_axes
        )
        self._stateSigmaPoints.init_weights()
        self._stateSigmaPoints.update_points(state0, self.cov)

    def set_state_model(
        self,
        state_mat=None,
        input_mat=None,
        cont_time=False,
        state_mat_fun=None,
        input_mat_fun=None,
        dyn_obj=None,
        ode_lst=None,
    ):
        """Sets the state model for the filter.

        This can use either linear dynamics (by calling the Kalman filter's
        :meth:`gncpy.filters.KalmanFilter.set_state_model`) or non-linear dynamics
        (by calling :meth:`gncpy.filters.ExtendedKalmanFilter.set_state_model`).
        The linearity is automatically determined by the input arguments
        specified; linear inputs take precedence if both kinds are given.

        Parameters
        ----------
        state_mat : N x N numpy array, optional
            State matrix, continuous or discrete case. The default is None.
        input_mat : N x Nu numpy array, optional
            Input matrix, continuous or discrete case. The default is None.
        cont_time : bool, optional
            Flag indicating if the continuous model is provided. The default
            is False.
        state_mat_fun : callable, optional
            Function that returns the `state_mat`, must take timestep and
            `*args`. The default is None.
        input_mat_fun : callable, optional
            Function that returns the `input_mat`, must take timestep, and
            `*args`. The default is None.
        dyn_obj : :class:`gncpy.dynamics.LinearDynamicsBase` or :class:`gncpy.dynamics.NonlinearDynamicsBase`, optional
            Sets the dynamics according to the class. The default is None.
        ode_lst : list, optional
            callable functions, 1 per ode/state. The callable must have the
            signature `f(t, x, *f_args)` just like scipy.integrate's ode
            function. The default is None.

        Raises
        ------
        RuntimeError
            If an invalid state model or combination of inputs is specified.

        Returns
        -------
        None.
        """
        self._use_lin_dyn = (
            state_mat is not None
            or state_mat_fun is not None
            or isinstance(dyn_obj, gdyn.LinearDynamicsBase)
        )
        self._use_non_lin_dyn = (
            isinstance(dyn_obj, gdyn.NonlinearDynamicsBase) or ode_lst is not None
        ) and not self._use_lin_dyn

        # allow for linear or non linear dynamics by calling the appropriate parent
        if self._use_lin_dyn:
            KalmanFilter.set_state_model(
                self,
                state_mat=state_mat,
                input_mat=input_mat,
                cont_time=cont_time,
                state_mat_fun=state_mat_fun,
                input_mat_fun=input_mat_fun,
                dyn_obj=dyn_obj,
            )
        elif self._use_non_lin_dyn:
            ExtendedKalmanFilter.set_state_model(self, dyn_obj=dyn_obj, ode_lst=ode_lst)
        else:
            raise RuntimeError("Invalid state model.")

    def set_measurement_noise_estimator(self, function):
        """Sets the model used for estimating the measurement noise parameters.

        This is an optional step and the filter will work properly if this is
        not called. If it is called, the measurement noise will be estimated
        during the filter's correction step, the result overwrites the
        measurement noise attribute, and the previously set value will not be
        used.

        Parameters
        ----------
        function : callable
            A function that implements the prediction and correction steps for
            an appropriate filter to estimate the measurement noise covariance
            matrix. It must have the signature `f(est_meas, partial_cov)` where
            `est_meas` is the Nm x 1 sigma point weighted measurement estimate
            and `partial_cov` is the Nm x Nm measurement covariance computed
            from the sigma points alone (without measurement noise). It must
            return an Nm x Nm numpy array representing the measurement noise
            covariance matrix.

        Returns
        -------
        None.
        """
        self._est_meas_noise_fnc = function

    def predict(
        self,
        timestep,
        cur_state,
        cur_input=None,
        state_mat_args=(),
        input_mat_args=(),
        dyn_fun_params=(),
    ):
        """Prediction step of the UKF.

        Automatically calls the state propagation method from either
        :meth:`gncpy.KalmanFilter.predict` or :meth:`gncpy.ExtendedKalmanFilter.predict`
        depending on if a linear or non-linear state model was specified.
        If a linear model is used only the parameters that can be passed to
        :meth:`gncpy.KalmanFilter.predict` will be used by this function.
        Otherwise the parameters for :meth:`gncpy.ExtendedKalmanFilter.predict`
        will be used.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : numpy array, optional
            Current input for linear models. The default is None.
        state_mat_args : tuple, optional
            Extra arguments for the get state matrix function if one has
            been specified or the propagate state function if using a linear
            dynamic object. The default is ().
        input_mat_args : tuple, optional
            Extra arguments for the get input matrix function if one has
            been specified or the propagate state function if using a linear
            dynamic object. The default is ().
        dyn_fun_params : tuple, optional
            Extra arguments to be passed to the dynamics function for non-linear
            models. The default is ().

        Raises
        ------
        RuntimeError
            If a state model has not been set, or the sigma points have not
            been initialized.

        Returns
        -------
        next_state : N x 1 numpy array
            The next state.

        """
        self._check_sigma_points()
        # redraw the sigma points around the current estimate
        self._stateSigmaPoints.update_points(cur_state, self.cov)

        # propagate points
        if self._use_lin_dyn:
            new_points = np.array(
                [
                    KalmanFilter._predict_next_state(
                        self,
                        timestep,
                        x.reshape((x.size, 1)),
                        cur_input,
                        state_mat_args,
                        input_mat_args,
                    )[0].ravel()
                    for x in self._stateSigmaPoints.points
                ]
            )
        elif self._use_non_lin_dyn:
            new_points = np.array(
                [
                    ExtendedKalmanFilter._predict_next_state(
                        self, timestep, x.reshape((x.size, 1)), dyn_fun_params
                    )[0].ravel()
                    for x in self._stateSigmaPoints.points
                ]
            )
        else:
            raise RuntimeError("State model not specified")
        self._stateSigmaPoints.points = new_points

        # update covariance, additive process noise assumption means the fixed
        # noise matrix is simply added; then force exact symmetry
        self.cov = self._stateSigmaPoints.cov + self.proc_noise
        self.cov = (self.cov + self.cov.T) * 0.5

        # estimate weighted state output
        next_state = self._stateSigmaPoints.mean

        return next_state

    def _calc_meas_cov(self, timestep, n_meas, meas_fun_args):
        """Compute the measurement covariance from the current sigma points.

        Parameters
        ----------
        timestep : float
            Current timestep.
        n_meas : int
            Number of measurement elements (Nm).
        meas_fun_args : tuple
            Arguments forwarded to the measurement estimation function.

        Returns
        -------
        meas_cov : Nm x Nm numpy array
            Measurement covariance including the measurement noise.
        est_points : numpy array
            Estimated measurement for each sigma point (first axis indexes
            the sigma points).
        est_meas : numpy array
            Sigma point weighted measurement estimate.
        """
        est_points = np.array(
            [
                self._est_meas(timestep, x.reshape((x.size, 1)), n_meas, meas_fun_args)[
                    0
                ]
                for x in self._stateSigmaPoints.points
            ]
        )
        est_meas = gmath.weighted_sum_vec(
            self._stateSigmaPoints.weights_mean, est_points
        )
        diff = est_points - est_meas
        # batched outer product: one diff @ diff^T per sigma point
        meas_cov_lst = diff @ diff.reshape((est_points.shape[0], 1, est_meas.size))

        partial_cov = gmath.weighted_sum_mat(
            self._stateSigmaPoints.weights_cov, meas_cov_lst
        )
        # estimate the measurement noise if applicable
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, partial_cov)
        meas_cov = self.meas_noise + partial_cov

        return meas_cov, est_points, est_meas

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Correction step of the UKF.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.
        RuntimeError
            If the sigma points have not been initialized.

        Returns
        -------
        next_state : N x 1 numpy array
            corrected state.
        meas_fit_prob : float
            measurement fit probability assuming a Gaussian distribution.

        """
        self._check_sigma_points()
        meas_cov, est_points, est_meas = self._calc_meas_cov(
            timestep, meas.size, meas_fun_args
        )

        state_diff = self._stateSigmaPoints.points - cur_state.ravel()
        meas_diff = (est_points - est_meas).reshape(
            (est_points.shape[0], 1, est_meas.size)
        )
        # batched outer product of state and measurement deviations
        cross_cov_lst = state_diff.reshape((*state_diff.shape, 1)) @ meas_diff
        cross_cov = gmath.weighted_sum_mat(
            self._stateSigmaPoints.weights_cov, cross_cov_lst
        )

        if self.use_cholesky_inverse:
            # with S = L L^T, inv(S) = inv(L)^T inv(L)
            sqrt_inv_meas_cov = la.inv(la.cholesky(meas_cov))
            inv_meas_cov = sqrt_inv_meas_cov.T @ sqrt_inv_meas_cov
        else:
            inv_meas_cov = la.inv(meas_cov)
        gain = cross_cov @ inv_meas_cov
        inov = meas - est_meas

        self.cov = self.cov - gain @ meas_cov @ gain.T
        self.cov = (self.cov + self.cov.T) * 0.5
        next_state = cur_state + gain @ inov

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, meas_cov)

        return next_state, meas_fit_prob