Coverage for src/gncpy/filters/max_corr_ent_upf.py: 96%

26 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2 

3from gncpy.filters.unscented_particle_filter import UnscentedParticleFilter 

4from gncpy.filters.max_corr_ent_ukf import MaxCorrEntUKF 

5 

6 

class MaxCorrEntUPF(UnscentedParticleFilter):
    """Implements a Maximum Correntropy Unscented Particle Filter.

    Notes
    -----
    This is based on
    :cite:`Fan2018_MaximumCorrentropyBasedUnscentedParticleFilterforCooperativeNavigationwithHeavyTailedMeasurementNoises`

    The inner filter is a :class:`MaxCorrEntUKF`. Its correction step takes
    an extra argument: the state from before the prediction step. That state
    is passed to :meth:`correct`, cached on the instance, and forwarded to
    the inner filter by :meth:`_inner_correct`.
    """

    def __init__(self, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        **kwargs : dict
            Passed unchanged to the parent constructor.
        """
        # Set before the parent constructor runs. Presumably this ensures the
        # attribute exists if parent initialization touches it
        # (NOTE(review): confirm). The empty 2-D array is only a placeholder
        # until :meth:`correct` supplies a real past state.
        self._past_state = np.array([[]])

        super().__init__(**kwargs)
        # Use the maximum correntropy UKF as the inner filter. This runs after
        # the parent constructor, so it takes precedence over any ``_filt``
        # the parent may have assigned.
        self._filt = MaxCorrEntUKF()

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            State saved by the parent class, with the cached past state added
            under the ``"_past_state"`` key.
        """
        filt_state = super().save_filter_state()

        filt_state["_past_state"] = self._past_state

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`. Must contain
            the ``"_past_state"`` key in addition to the keys the parent class
            expects.
        """
        super().load_filter_state(filt_state)

        self._past_state = filt_state["_past_state"]

    def _inner_correct(self, timestep, meas, state, filt_kwargs):
        """Run the inner MCUKF correction step.

        Wrapper so child classes can override. Unlike the plain UKF call, it
        also passes the past state cached by :meth:`correct` to
        ``MaxCorrEntUKF.correct``.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        state : numpy array
            State to correct. Presumably N x 1, like ``past_state`` in
            :meth:`correct` (confirm against the parent class).
        filt_kwargs : dict
            Extra keyword arguments forwarded to the inner filter's
            ``correct`` method.

        Returns
        -------
        tuple
            Whatever ``MaxCorrEntUKF.correct`` returns.
        """
        return self._filt.correct(
            timestep, meas, state, self._past_state, **filt_kwargs
        )

    def correct(self, timestep, meas, past_state, **kwargs):
        """Correction step of the MCUPF.

        This is a wrapper for the parent method to allow for an additional
        parameter. ``past_state`` is stored on the instance and stays there
        until the next call (or :meth:`load_filter_state`). It is what
        :meth:`_inner_correct` forwards to the inner filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        past_state : N x 1 numpy array
            State from before the prediction step.
        **kwargs : dict
            See the parent method.

        Returns
        -------
        tuple
            See the parent method.
        """
        # Cache first: the parent's correction presumably dispatches to
        # _inner_correct, which reads this attribute.
        self._past_state = past_state
        return super().correct(timestep, meas, **kwargs)

    @property
    def kernel_bandwidth(self):
        """Bandwidth for the Gaussian Kernel in the MCUKF.

        Reads and writes are forwarded to the inner filter's
        ``kernel_bandwidth`` attribute.

        Returns
        -------
        float
            bandwidth
        """
        return self._filt.kernel_bandwidth

    @kernel_bandwidth.setter
    def kernel_bandwidth(self, kernel_bandwidth):
        # Forward to the inner MaxCorrEntUKF; this class keeps no copy.
        self._filt.kernel_bandwidth = kernel_bandwidth