Coverage for src/gncpy/filters/gsm_filter_base.py: 71%

273 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.random as rnd 

3import scipy.stats as stats 

4from queue import deque 

5from warnings import warn 

6 

7import gncpy.distributions as gdistrib 

8import gncpy.errors as gerr 

9from serums.enums import GSMTypes 

10from gncpy.filters.bayes_filter import BayesFilter 

11from gncpy.filters.bootstrap_filter import BootstrapFilter 

12 

13 

14class _GSMProcNoiseEstimator: 

15 """Helper class for estimating proc noise in the GSM filters.""" 

16 

17 def __init__(self): 

18 self.q_fifo = deque([], None) 

19 self.b_fifo = deque([], None) 

20 

21 self.startup_delay = 1 

22 

23 self._last_q_hat = None 

24 self._call_count = 0 

25 

26 @property 

27 def maxlen(self): 

28 return self.q_fifo.maxlen 

29 

30 @maxlen.setter 

31 def maxlen(self, val): 

32 self.q_fifo = deque([], val) 

33 self.b_fifo = deque([], val) 

34 

35 @property 

36 def win_len(self): 

37 return len(self.q_fifo) 

38 

39 def estimate_next(self, cur_est, pred_state, pred_cov, cor_state, cor_cov): 

40 self._call_count += 1 

41 

42 # update bk 

43 bk = pred_cov - cur_est - cor_cov 

44 if self.b_fifo.maxlen is None or len(self.b_fifo) < self.b_fifo.maxlen: 

45 bk_last = np.zeros(bk.shape) 

46 else: 

47 bk_last = self.b_fifo.pop() 

48 self.b_fifo.appendleft(bk) 

49 

50 # update qk and qk_hat 

51 qk = cor_state - pred_state 

52 if self._last_q_hat is None: 

53 self._last_q_hat = np.zeros(qk.shape) 

54 if self.q_fifo.maxlen is None or len(self.q_fifo) < self.q_fifo.maxlen: 

55 qk_last = np.zeros(qk.shape) 

56 else: 

57 qk_last = self.q_fifo.pop() 

58 self.q_fifo.appendleft(qk) 

59 

60 qk_klast_diff = qk - qk_last 

61 win_len = self.win_len 

62 inv_win_len = 1 / win_len 

63 win_len_m1 = win_len - 1 

64 if self._call_count <= np.max([1, self.startup_delay]): 

65 return cur_est 

66 self._last_q_hat += inv_win_len * qk_klast_diff 

67 

68 # estimate cov 

69 qk_khat_diff = qk - self._last_q_hat 

70 qklast_khat_diff = qk_last - self._last_q_hat 

71 bk_diff = bk_last - bk 

72 next_est = cur_est + 1 / win_len_m1 * ( 

73 qk_khat_diff @ qk_khat_diff.T 

74 - qklast_khat_diff @ qklast_khat_diff.T 

75 + inv_win_len * qk_klast_diff @ qk_klast_diff.T 

76 + win_len_m1 * inv_win_len * bk_diff 

77 ) 

78 

79 # for numerical reasons 

80 for ii in range(next_est.shape[0]): 

81 next_est[ii, ii] = np.abs(next_est[ii, ii]) 

82 return next_est 

83 

84 

class GSMFilterBase(BayesFilter):
    """Base implementation of a Gaussian Scale Mixture (GSM) filter.

    This should be inherited from to include specific implementations of the
    core filter by exposing the necessary core filter attributes.

    Notes
    -----
    This is based on a generic version of
    :cite:`VilaValls2012_NonlinearBayesianFilteringintheGaussianScaleMixtureContext`
    which extends :cite:`VilaValls2011_BayesianFilteringforNonlinearStateSpaceModelsinSymmetricStableMeasurementNoise`.
    This class does not implement a specific core filter, that is up to the
    child classes.

    Attributes
    ----------
    enable_proc_noise_estimation : bool
        Flag indicating if the process noise should be estimated.
    """

    def __init__(self, enable_proc_noise_estimation=False, **kwargs):
        """Initializes the class.

        Parameters
        ----------
        enable_proc_noise_estimation : bool, optional
            Flag indicating if the process noise should be estimated. The
            default is False.
        **kwargs : dict
            Additional keyword arguments for the parent class.
        """
        super().__init__(**kwargs)

        self.enable_proc_noise_estimation = enable_proc_noise_estimation

        self._coreFilter = None
        self._meas_noise_filters = []
        self._import_w_factory_lst = None
        self._procNoiseEstimator = _GSMProcNoiseEstimator()

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Note that to pickle the resulting dictionary the :code:`dill` package
        may need to be used due to potential pickling of functions.

        Returns
        -------
        filt_state : dict
            Filter variables, including those saved by the parent class.
        """
        filt_state = super().save_filter_state()

        filt_state["enable_proc_noise_estimation"] = self.enable_proc_noise_estimation

        # store the class alongside its state so loading can re-instantiate it
        if self._coreFilter is not None:
            filt_state["_coreFilter"] = (
                type(self._coreFilter),
                self._coreFilter.save_filter_state(),
            )
        else:
            filt_state["_coreFilter"] = (None, self._coreFilter)
        filt_state["_meas_noise_filters"] = [
            (type(f), f.save_filter_state()) for f in self._meas_noise_filters
        ]

        filt_state["_import_w_factory_lst"] = self._import_w_factory_lst
        filt_state["_procNoiseEstimator"] = self._procNoiseEstimator

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.enable_proc_noise_estimation = filt_state["enable_proc_noise_estimation"]

        cls_type = filt_state["_coreFilter"][0]
        if cls_type is not None:
            self._coreFilter = cls_type()
            self._coreFilter.load_filter_state(filt_state["_coreFilter"][1])
        else:
            self._coreFilter = None

        num_m_filts = len(filt_state["_meas_noise_filters"])
        self._meas_noise_filters = [None] * num_m_filts
        for ii, (cls_type, vals) in enumerate(filt_state["_meas_noise_filters"]):
            if cls_type is not None:
                self._meas_noise_filters[ii] = cls_type()
                self._meas_noise_filters[ii].load_filter_state(vals)
        self._import_w_factory_lst = filt_state["_import_w_factory_lst"]
        self._procNoiseEstimator = filt_state["_procNoiseEstimator"]

    def set_state_model(self, **kwargs):
        """Wrapper for the core filters set state model function."""
        if self._coreFilter is not None:
            self._coreFilter.set_state_model(**kwargs)
        else:
            warn("Core filter is not set, use an inherited class.", RuntimeWarning)

    def set_measurement_model(self, **kwargs):
        """Wrapper for the core filters set measurement model function."""
        if self._coreFilter is not None:
            self._coreFilter.set_measurement_model(**kwargs)
        else:
            warn("Core filter is not set, use an inherited class.", RuntimeWarning)

    def _define_student_t_pf(self, gsm, rng, num_parts):
        """Build a bootstrap filter for Student's t measurement noise.

        Each particle is a row of (degrees of freedom, scale, z), where the
        initial degrees of freedom and scale are drawn uniformly from
        ``gsm.df_range`` and ``gsm.scale_range``.

        Returns
        -------
        pf : :class:`.BootstrapFilter`
            Initialized particle filter.
        import_w_factory : callable
            Takes the innovation variance and returns the importance weight
            function for the filter.
        """

        def import_w_factory(inov_cov):
            def import_w_fnc(meas, parts):
                # effective std per particle: sqrt(z * sig^2 + innovation var)
                stds = np.sqrt(parts[:, 2] * parts[:, 1] ** 2 + inov_cov)
                return np.asarray(stats.norm.pdf(meas.item(), scale=stds))

            return import_w_fnc

        def gsm_import_dist_factory():
            def import_dist_fnc(parts, _rng):
                new_parts = np.nan * np.ones(parts.particles.shape)

                disc = 0.99
                a = (3 * disc - 1) / (2 * disc)
                h = np.sqrt(1 - a ** 2)
                last_means = np.mean(parts.particles, axis=0)
                means = a * parts.particles[:, 0:2] + (1 - a) * last_means[0:2]

                # df, sig
                for ind in range(means.shape[1]):
                    std = np.sqrt(h ** 2 * np.cov(parts.particles[:, ind]))

                    for ii, m in enumerate(means):
                        samp = stats.norm.rvs(loc=m[ind], scale=std, random_state=_rng)
                        new_parts[ii, ind] = samp
                df = np.mean(new_parts[:, 0])
                # df == 0 would divide by zero below, so it is invalid too
                if df <= 0:
                    msg = "Degree of freedom must be > 0 {:.4f}".format(df)
                    raise gerr.ParticleEstimationDomainError(msg)
                new_parts[:, 2] = stats.invgamma.rvs(
                    df / 2,
                    scale=1 / (2 / df),
                    random_state=_rng,
                    size=new_parts.shape[0],
                )
                return new_parts

            return import_dist_fnc

        pf = BootstrapFilter()
        pf.importance_dist_fnc = gsm_import_dist_factory()
        pf.particleDistribution = gdistrib.SimpleParticleDistribution()

        df_scale = gsm.df_range[1] - gsm.df_range[0]
        df_loc = gsm.df_range[0]
        df_particles = stats.uniform.rvs(
            loc=df_loc, scale=df_scale, size=num_parts, random_state=rng
        )

        sig_scale = gsm.scale_range[1] - gsm.scale_range[0]
        sig_loc = gsm.scale_range[0]
        sig_particles = stats.uniform.rvs(
            loc=sig_loc, scale=sig_scale, size=num_parts, random_state=rng
        )

        z_particles = np.nan * np.ones(num_parts)
        for ii, v in enumerate(df_particles):
            z_particles[ii] = stats.invgamma.rvs(
                v / 2, scale=1 / (2 / v), random_state=rng
            )
        pf.particleDistribution.particles = np.stack(
            (df_particles, sig_particles, z_particles), axis=1
        )
        pf.particleDistribution.num_parts_per_ind = np.ones(num_parts)
        pf.particleDistribution.weights = 1 / num_parts * np.ones(num_parts)
        pf.rng = rng

        return pf, import_w_factory

    def _define_cauchy_pf(self, gsm, rng, num_parts):
        """Build a bootstrap filter for Cauchy measurement noise.

        Each particle is a row of (scale, z), where the initial scale is drawn
        uniformly from ``gsm.scale_range``.

        Returns
        -------
        pf : :class:`.BootstrapFilter`
            Initialized particle filter.
        import_w_factory : callable
            Takes the innovation variance and returns the importance weight
            function for the filter.
        """

        def import_w_factory(inov_cov):
            def import_w_fnc(meas, parts):
                # effective std per particle: sqrt(z * sig^2 + innovation var)
                stds = np.sqrt(parts[:, 1] * parts[:, 0] ** 2 + inov_cov)
                return np.asarray(stats.norm.pdf(meas.item(), scale=stds))

            return import_w_fnc

        def gsm_import_dist_factory():
            def import_dist_fnc(parts, _rng):
                new_parts = np.nan * np.ones(parts.particles.shape)

                disc = 0.99
                a = (3 * disc - 1) / (2 * disc)
                h = np.sqrt(1 - a ** 2)
                last_means = np.mean(parts.particles, axis=0)
                means = a * parts.particles[:, [0]] + (1 - a) * last_means[0]

                # sig
                for ind in range(means.shape[1]):
                    std = np.sqrt(h ** 2 * np.cov(parts.particles[:, ind]))

                    for ii, m in enumerate(means):
                        samp = stats.norm.rvs(loc=m[ind], scale=std, random_state=_rng)
                        new_parts[ii, ind] = samp
                new_parts[:, 1] = stats.invgamma.rvs(
                    1 / 2, scale=1 / 2, random_state=_rng, size=new_parts.shape[0]
                )
                return new_parts

            return import_dist_fnc

        pf = BootstrapFilter()
        pf.importance_dist_fnc = gsm_import_dist_factory()
        pf.particleDistribution = gdistrib.SimpleParticleDistribution()

        sig_scale = gsm.scale_range[1] - gsm.scale_range[0]
        sig_loc = gsm.scale_range[0]
        sig_particles = stats.uniform.rvs(
            loc=sig_loc, scale=sig_scale, size=num_parts, random_state=rng
        )

        z_particles = stats.invgamma.rvs(
            1 / 2, scale=1 / 2, random_state=rng, size=num_parts
        )

        pf.particleDistribution.particles = np.stack(
            (sig_particles, z_particles), axis=1
        )
        pf.particleDistribution.num_parts_per_ind = np.ones(num_parts)
        pf.particleDistribution.weights = 1 / num_parts * np.ones(num_parts)
        pf.rng = rng

        return pf, import_w_factory

    def set_meas_noise_model(
        self,
        bootstrap_lst=None,
        importance_weight_factory_lst=None,
        gsm_lst=None,
        num_parts=None,
        rng=None,
    ):
        """Initializes the measurement noise estimators.

        Either the filters and importance weight factories can be provided,
        or a list of :class:`serums.models.GaussianScaleMixture` objects, the
        number of particles, and optionally a random number generator. If the
        latter set is given then bootstrap filters are constructed
        automatically. The recommended way of specifying the filters is to
        provide GSM objects.

        Notes
        -----
        This uses independent bootstrap particle filters for each measurement
        based on the provided model information.

        Parameters
        ----------
        bootstrap_lst : list, optional
            List of :class:`.BootstrapFilter` objects that have already been
            initialized. If given then importance weight factory list must also
            be given and of the same length. The default is None.
        importance_weight_factory_lst : list, optional
            List of callables, each takes a 1 x 1 numpy array or float as input
            and returns a callable with the signature `f(z, parts)` where
            `z` is a float that is the difference between the estimated measurement
            and actual measurement (so the distribution is 0 mean) and `parts` is
            a numpy array of all the particles from the bootstrap filter. `f` must
            return a numpy array of weights for each particle. See the
            :attr:`.BootstrapFilter.importance_weight_fnc` for more details. The
            default is None.
        gsm_lst : list, optional
            List of :class:`serums.models.GaussianScaleMixture` objects, one
            per measurement. Requires `num_parts` also be specified and optionally
            `rng`. The default is None.
        num_parts : list or int, optional
            Number of particles to use in each automatically constructed filter.
            If only one number is supplied all filters will use the same number
            of particles. If a list/tuple is given it must have one entry per
            GSM object. The default is None.
        rng : numpy random generator, optional
            Random number generator to use in constructed filters. If supplied,
            each filter uses the same instance, otherwise a new generator is
            created for each filter using numpy's default initialization routine
            with no supplied seed. Only used if a `gsm_lst` is supplied.

        Raises
        ------
        RuntimeError
            If an incorrect combination of input arguments is provided, or a
            GSM type is not supported for automatic setup.

        Todo
        ----
        Allow for GSM Object to use location parameter when specifying noise models
        """
        if bootstrap_lst is not None:
            # validate before touching any filter state
            if importance_weight_factory_lst is None:
                msg = (
                    "Must supply an importance weight factory when "
                    + "specifying the bootstrap filters."
                )
                raise RuntimeError(msg)
            if len(importance_weight_factory_lst) != len(bootstrap_lst):
                msg = (
                    "Importance weight factory list "
                    + "length ({:d}) ".format(len(importance_weight_factory_lst))
                    + "does not match the number of bootstrap filters ({:d})".format(
                        len(bootstrap_lst)
                    )
                )
                raise RuntimeError(msg)
            self._meas_noise_filters = bootstrap_lst
            self._import_w_factory_lst = importance_weight_factory_lst

        elif gsm_lst is not None:
            num_filts = len(gsm_lst)
            if not isinstance(num_parts, (list, tuple)):
                if num_parts is None:
                    msg = "Must specify number of particles when giving a list of GSM objects."
                    raise RuntimeError(msg)
                num_parts = [num_parts] * num_filts
            elif len(num_parts) != num_filts:
                msg = (
                    "Number of particle counts ({:d}) ".format(len(num_parts))
                    + "does not match the number of GSM objects ({:d})".format(
                        num_filts
                    )
                )
                raise RuntimeError(msg)

            meas_filters = [None] * num_filts
            import_w_facts = [None] * num_filts
            for ii, gsm in enumerate(gsm_lst):
                # as documented: a fresh generator per filter when none is given
                filt_rng = rnd.default_rng() if rng is None else rng
                if gsm.type is GSMTypes.STUDENTS_T:
                    meas_filters[ii], import_w_facts[ii] = self._define_student_t_pf(
                        gsm, filt_rng, num_parts[ii]
                    )
                elif gsm.type is GSMTypes.CAUCHY:
                    meas_filters[ii], import_w_facts[ii] = self._define_cauchy_pf(
                        gsm, filt_rng, num_parts[ii]
                    )
                else:
                    msg = (
                        "GSM filter can not automatically setup Bootstrap "
                        + "filter for GSM Type {:s}. ".format(str(gsm.type))
                        + "Update implementation."
                    )
                    raise RuntimeError(msg)
            self._meas_noise_filters = meas_filters
            self._import_w_factory_lst = import_w_facts

        else:
            msg = (
                "Incorrect input arguement combination. See documentation for details."
            )
            raise RuntimeError(msg)

    def set_process_noise_model(
        self, initial_est=None, filter_length=None, startup_delay=None
    ):
        """Sets the filter parameters for estimating process noise.

        This assumes the same process noise model as in
        :cite:`VilaValls2011_BayesianFilteringforNonlinearStateSpaceModelsinSymmetricStableMeasurementNoise`.
        Calling this enables process noise estimation.

        Parameters
        ----------
        initial_est : N x N numpy array, optional
            Initial estimate of the covariance, can either be set here or by
            manually setting the process noise variable. The default is None,
            this assumes process noise was manually set.
        filter_length : int, optional
            Number of past samples to use in the filter, must be > 1. The
            default is None, this implies all past samples are used or the
            previously set value is maintained.
        startup_delay : int, optional
            Number of samples to delay before running the filter, used to fill
            the FIFO buffers. Must be >= 1. The default is None, this means
            either the previous value is maintained or a value of 1 is used.

        Returns
        -------
        None.
        """
        self.enable_proc_noise_estimation = True

        if filter_length is not None:
            self._procNoiseEstimator.maxlen = filter_length
        if initial_est is None and (
            self.proc_noise is None or self.proc_noise.size <= 0
        ):
            msg = (
                "Please manually set the initial process noise "
                + "or specify a value here."
            )
            warn(msg)
        elif initial_est is not None:
            self.proc_noise = initial_est
        if startup_delay is not None:
            self._procNoiseEstimator.startup_delay = startup_delay

    @property
    def cov(self):
        """Covariance of the filter (empty array if no core filter is set)."""
        if self._coreFilter is None:
            return np.array([[]])
        return self._coreFilter.cov

    @cov.setter
    def cov(self, val):
        self._coreFilter.cov = val

    @property
    def proc_noise(self):
        """Wrapper for the process noise covariance of the core filter."""
        if self._coreFilter is None:
            return np.array([[]])
        return self._coreFilter.proc_noise

    @proc_noise.setter
    def proc_noise(self, val):
        self._coreFilter.proc_noise = val

    @property
    def meas_noise(self):
        """Measurement noise of the core filter, estimated online and does not need to be set."""
        if self._coreFilter is None:
            return np.array([[]])
        return self._coreFilter.meas_noise

    @meas_noise.setter
    def meas_noise(self, val):
        warn(
            "Measurement noise is estimated online. NOT SETTING VALUE HERE.",
            RuntimeWarning,
        )

    def predict(self, timestep, cur_state, **kwargs):
        """Prediction step of the GSM filter.

        This calls the core filter's prediction function directly; process
        noise estimation (if enabled) happens in :meth:`correct`.
        """
        return self._coreFilter.predict(timestep, cur_state, **kwargs)

    def correct(self, timestep, meas, cur_state, core_filt_kwargs=None):
        """Correction step of the GSM filter.

        This estimates the measurement noise (via the per-measurement
        bootstrap filters) while calling the core filter's correction
        function, then optionally updates the process noise estimate.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : numpy array
            Measurement vector.
        cur_state : numpy array
            Predicted state to be corrected.
        core_filt_kwargs : dict, optional
            Additional keyword arguments for the core filter's correct
            function. The default is None (no extra arguments).

        Returns
        -------
        tuple
            The corrected state and measurement fit probability returned by
            the core filter.
        """
        if core_filt_kwargs is None:
            core_filt_kwargs = {}

        # setup core filter for estimating measurement noise during
        # correction function call
        def est_meas_noise(est_meas, inov_cov):
            m_diag = np.full(est_meas.size, np.nan)
            f_meas = (meas - est_meas).ravel()
            inov_vars = np.diag(inov_cov)
            for ii, filt in enumerate(self._meas_noise_filters):
                filt.importance_weight_fnc = self._import_w_factory_lst[ii](
                    inov_vars[ii]
                )
                filt.predict(timestep)
                state = filt.correct(timestep, f_meas[ii].reshape((1, 1)))
                # Student's t particles are (df, sig, z), Cauchy are (sig, z)
                if state.size == 3:
                    m_diag[ii] = state[2] * state[1] ** 2  # z * sig^2
                else:
                    m_diag[ii] = state[1] * state[0] ** 2  # z * sig^2
            return np.diag(m_diag)

        self._coreFilter.set_measurement_noise_estimator(est_meas_noise)

        pred_cov = self.cov.copy()
        cor_state, meas_fit_prob = self._coreFilter.correct(
            timestep, meas, cur_state, **core_filt_kwargs
        )

        # update process noise estimate (if applicable)
        if self.enable_proc_noise_estimation:
            self.proc_noise = self._procNoiseEstimator.estimate_next(
                self.proc_noise, cur_state, pred_cov, cor_state, self.cov
            )
        return cor_state, meas_fit_prob

    def plot_particles(self, filt_inds, dist_inds, **kwargs):
        """Plots the particle distribution for every given measurement index.

        Parameters
        ----------
        filt_inds : int or list
            Index of the measurement index/indices to plot the particle distribution
            for.
        dist_inds : int or list
            Index of the particle(s) in the given filter(s) to plot. See the
            :meth:`.BootstrapFilter.plot_particles` for details.
        **kwargs : dict
            Additional keyword arguments. See :meth:`.BootstrapFilter.plot_particles`.

        Returns
        -------
        figs : dict
            Each value in the dictionary is a matplotlib figure handle.
        keys : list
            Each value is a string corresponding to a key in the resulting dictionary
        """
        # "{:02d}" cannot format a list, so build the distribution label
        # separately; a single int gives the same key as before.
        if isinstance(dist_inds, (list, tuple)):
            dist_lbl = "_".join("{:02d}".format(d) for d in dist_inds)
        else:
            dist_lbl = "{:02d}".format(dist_inds)
        key_base = "meas_noise_particles_F{:02d}_D{:s}"

        if isinstance(filt_inds, (list, tuple)):
            inds = filt_inds
        else:
            inds = [filt_inds]

        figs = {}
        keys = []
        for ii in inds:
            key = key_base.format(ii, dist_lbl)
            figs[key] = self._meas_noise_filters[ii].plot_particles(
                dist_inds, **kwargs
            )
            keys.append(key)
        return figs, keys