Coverage for src/gncpy/filters/max_corr_ent_ukf.py: 100%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1import numpy as np 

2import numpy.linalg as la 

3 

4import gncpy.math as gmath 

5from gncpy.filters.unscented_kalman_filter import UnscentedKalmanFilter 

6 

7 

class MaxCorrEntUKF(UnscentedKalmanFilter):
    """Implements a Maximum Correntropy Unscented Kalman filter.

    Notes
    -----
    This is based on
    :cite:`Hou2018_MaximumCorrentropyUnscentedKalmanFilterforBallisticMissileNavigationSystemBasedonSINSCNSDeeplyIntegratedMode`

    The measurement covariance produced by the parent UKF is augmented with
    a noise term weighted by the inverse of Gaussian-kernel values of a
    whitened error vector (see :meth:`_calc_meas_cov`). That error vector
    needs the pre-prediction state, the current state, and the raw
    measurement, so :meth:`correct` takes the extra ``cur_state`` and
    ``past_state`` arguments and stores copies of them before delegating to
    the parent's correction.

    Attributes
    ----------
    kernel_bandwidth : float, optional
        Bandwidth of the Gaussian Kernel. The default is 1.
    """

    def __init__(self, kernel_bandwidth=1, **kwargs):
        """Initialize the filter.

        Parameters
        ----------
        kernel_bandwidth : float, optional
            Bandwidth of the Gaussian kernel. The default is 1.
        **kwargs : dict, optional
            Forwarded unchanged to the parent class constructor.
        """
        self.kernel_bandwidth = kernel_bandwidth

        # Values stashed by correct() so the _calc_meas_cov override can use
        # them; the parent's correct() is not given past_state/cur_state.
        # Empty placeholders until the first call to correct().
        self._past_state = np.array([[]])
        self._cur_state = np.array([[]])
        self._meas = np.array([[]])

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        dict
            The parent's saved state extended with ``kernel_bandwidth`` and
            the cached ``_past_state``, ``_cur_state`` and ``_meas`` arrays.
        """
        filt_state = super().save_filter_state()

        filt_state["kernel_bandwidth"] = self.kernel_bandwidth
        filt_state["_past_state"] = self._past_state
        filt_state["_cur_state"] = self._cur_state
        filt_state["_meas"] = self._meas

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.kernel_bandwidth = filt_state["kernel_bandwidth"]
        self._past_state = filt_state["_past_state"]
        self._cur_state = filt_state["_cur_state"]
        self._meas = filt_state["_meas"]

    def _calc_meas_cov(self, timestep, n_meas, meas_fun_args):
        """Calculate the measurement covariance with correntropy weighting.

        Adds to the parent's measurement covariance the measurement block of
        ``S C^-1 S^T``. Here ``S`` is the lower Cholesky factor of
        ``blkdiag(self.cov, self.meas_noise)`` and ``C`` is the diagonal
        matrix of Gaussian-kernel values of the whitened error vector.

        Parameters
        ----------
        timestep : float
            Current timestep.
        n_meas : int
            Number of measurements, passed to the parent. Internally the
            value is re-derived from the parent's estimated measurement.
        meas_fun_args
            Extra arguments for the measurement function, forwarded to the
            parent and to ``_est_meas``.

        Returns
        -------
        meas_cov : numpy array
            Parent measurement covariance plus the correntropy-scaled noise
            term.
        est_points : numpy array
            Passed through unchanged from the parent.
        est_meas : numpy array
            Passed through unchanged from the parent.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the combined covariance is not positive definite, or if any
            kernel value is exactly zero (its inverse weight is undefined).
        """
        meas_cov, est_points, est_meas = super()._calc_meas_cov(
            timestep, n_meas, meas_fun_args
        )

        # Lower Cholesky factor S of the block-diagonal combined covariance
        # blkdiag(P, R), symmetrized first to absorb round-off asymmetry.
        n_state = self.cov.shape[0]
        n_meas = est_meas.shape[0]
        comb_cov = np.block(
            [
                [self.cov, np.zeros((n_state, n_meas))],
                [np.zeros((n_meas, n_state)), self.meas_noise],
            ]
        )
        comb_cov = (comb_cov + comb_cov.T) * 0.5
        sqrt_comb = la.cholesky(comb_cov)

        # Whitened error e = S^-1 [cur_state; meas] - S^-1 [past_state; h(past)]
        # = S^-1 ([cur_state; meas] - [past_state; h(past)]). One linear solve
        # on the difference avoids forming S^-1 explicitly.
        pred_meas = self._est_meas(
            timestep, self._past_state, n_meas, meas_fun_args
        )[0]
        stacked_diff = np.vstack((self._cur_state, self._meas)) - np.vstack(
            (self._past_state, pred_meas)
        )
        err = la.solve(sqrt_comb, stacked_diff).ravel()

        # Gaussian kernel of each error component: the diagonal of C.
        kern = np.array(
            [gmath.gaussian_kernel(e_ii, self.kernel_bandwidth) for e_ii in err],
            dtype=float,
        )
        if np.any(kern == 0):
            # C would be singular, so C^-1 does not exist.
            raise la.LinAlgError(
                "Gaussian kernel evaluated to zero for at least one error "
                "component; C is singular (error too large relative to "
                "kernel_bandwidth?)."
            )

        # Only the measurement block of S C^-1 S^T is used. With C diagonal,
        # C^-1 is an element-wise reciprocal, applied as a column scaling of
        # the measurement rows of S.
        sqrt_meas_rows = sqrt_comb[n_state:, :]
        scaled_meas_noise = (sqrt_meas_rows / kern) @ sqrt_meas_rows.T

        # NOTE(review): if the parent's meas_cov already includes
        # self.meas_noise, R is counted twice here (the scaled copy is added
        # rather than substituted). Confirm against
        # UnscentedKalmanFilter._calc_meas_cov and the cited paper.
        meas_cov = meas_cov + scaled_meas_noise

        return meas_cov, est_points, est_meas

    def correct(self, timestep, meas, cur_state, past_state, **kwargs):
        """Correction function for the Max Correntropy UKF.

        This is a wrapper for the parent method to allow for additional
        parameters. Copies of ``meas``, ``cur_state`` and ``past_state`` are
        stored on the instance for use by :meth:`_calc_meas_cov`, then the
        parent's ``correct`` is called with ``cur_state`` (``past_state`` is
        not forwarded).

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        past_state : N x 1 numpy array
            State from before the prediction step.
        **kwargs : dict, optional
            See the parent function for additional parameters.

        Returns
        -------
        tuple
            See the parent method.
        """
        self._past_state = past_state.copy()
        self._cur_state = cur_state.copy()
        self._meas = meas.copy()

        return super().correct(timestep, meas, cur_state, **kwargs)