Coverage for src/gncpy/filters/students_t_filter.py: 94%

68 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.linalg as la 

3import scipy.stats as stats 

4 

5from gncpy.filters.kalman_filter import KalmanFilter 

6 

7 

class StudentsTFilter(KalmanFilter):
    r"""Implementation of a Students T filter.

    This is based on :cite:`Liu2018_AStudentsTMixtureProbabilityHypothesisDensityFilterforMultiTargetTrackingwithOutliers`
    and :cite:`Roth2013_AStudentsTFilterforHeavyTailedProcessandMeasurementNoise`
    and uses moment matching to limit the degree of freedom growth.

    Notes
    -----
    This models the multi-variate Student's t-distribution of a
    :math:`d`-dimensional variable :math:`x` as

    .. math::
        \begin{align}
            p(x) &= \frac{\Gamma(\frac{\nu + d}{2})}{\Gamma(\frac{\nu}{2})}
                \frac{1}{(\nu \pi)^{d/2}}
                \frac{1}{\sqrt{\vert \Sigma \vert}}\left( 1 +
                \frac{\Delta^2}{\nu}\right)^{-\frac{\nu + d}{2}} \\
            \Delta^2 &= (x - m)^T \Sigma^{-1} (x - m)
        \end{align}

    or compactly as :math:`St(x; m,\Sigma, \nu) = p(x)` for scale matrix
    :math:`\Sigma` and degree of freedom :math:`\nu`

    Attributes
    ----------
    scale : N x N numpy array, optional
        Scaling matrix of the Students T distribution. If not given, an
        empty ``np.array([[]])`` is created for the instance.
    dof : int , optional
        Degree of freedom for the state distribution. The default is 3.
    proc_noise_dof : int, optional
        Degree of freedom for the process noise model. The default is 3.
    meas_noise_dof : int, optional
        Degree of freedom for the measurement noise model. The default is 3.
    use_moment_matching : bool, optional
        Flag indicating if moment matching is used to maintain the heavy tail
        property as the filter propagates over time. The default is True.
    """

    def __init__(
        self,
        scale=None,
        dof=3,
        proc_noise_dof=3,
        meas_noise_dof=3,
        use_moment_matching=True,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        scale : N x N numpy array, optional
            Scaling matrix. The default is None, which is replaced by a new
            empty ``np.array([[]])``.
        dof : int, optional
            Degree of freedom for the state distribution. The default is 3.
        proc_noise_dof : int, optional
            Degree of freedom for the process noise model. The default is 3.
        meas_noise_dof : int, optional
            Degree of freedom for the measurement noise model. The default is 3.
        use_moment_matching : bool, optional
            Flag indicating if moment matching is used. The default is True.
        **kwargs : dict
            Forwarded unchanged to the parent class constructor.
        """
        # Build the empty default per instance instead of in the signature,
        # where a single array object would be shared by every filter.
        self.scale = np.array([[]]) if scale is None else scale
        self.dof = dof
        self.proc_noise_dof = proc_noise_dof
        self.meas_noise_dof = meas_noise_dof
        self.use_moment_matching = use_moment_matching

        # NOTE(review): the attributes above are set before the parent
        # constructor runs; presumably the parent may touch members that are
        # overridden here (e.g. ``cov``) -- confirm before reordering.
        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            State returned by the parent class with the ``scale``, ``dof``,
            ``proc_noise_dof``, ``meas_noise_dof`` and ``use_moment_matching``
            entries added.
        """
        filt_state = super().save_filter_state()

        filt_state["scale"] = self.scale
        filt_state["dof"] = self.dof
        filt_state["proc_noise_dof"] = self.proc_noise_dof
        filt_state["meas_noise_dof"] = self.meas_noise_dof
        filt_state["use_moment_matching"] = self.use_moment_matching

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self.scale = filt_state["scale"]
        self.dof = filt_state["dof"]
        self.proc_noise_dof = filt_state["proc_noise_dof"]
        self.meas_noise_dof = filt_state["meas_noise_dof"]
        self.use_moment_matching = filt_state["use_moment_matching"]

    @property
    def cov(self):
        """Read only covariance matrix.

        This is calculated from the scale matrix and degree of freedom as
        ``dof / (dof - 2) * scale``.

        Raises
        ------
        RuntimeError
            If the degree of freedom is less than or equal to 2

        Returns
        -------
        N x N numpy array
            Calculated covariance matrix
        """
        if self.dof <= 2:
            msg = "Degrees of freedom ({}) must be > 2"
            raise RuntimeError(msg.format(self.dof))
        return self.dof / (self.dof - 2) * self.scale

    @cov.setter
    def cov(self, cov):
        # Assignments are accepted but discarded: the covariance is always
        # derived from ``scale`` and ``dof`` in the getter.
        pass

    def predict(
        self, timestep, cur_state, cur_input=None, state_mat_args=(), input_mat_args=()
    ):
        """Implements the prediction step of the Students T filter.

        The scale matrix stored in :attr:`scale` is propagated in place.

        Parameters
        ----------
        timestep : float
            Current timestep.
        cur_state : N x 1 numpy array
            Current state.
        cur_input : N x Nu numpy array, optional
            Current input. The default is None.
        state_mat_args : tuple, optional
            Arguments for the get state matrix function if one has
            been specified or the propagate state function if using a dynamic
            object. The default is ().
        input_mat_args : tuple, optional
            Arguments for the get input matrix function if one has
            been specified or the propagate state function if using a dynamic
            object. The default is ().

        Returns
        -------
        next_state : N x 1 numpy array
            Next state.
        """
        next_state, state_mat = self._predict_next_state(
            timestep, cur_state, cur_input, state_mat_args, input_mat_args
        )

        # Rescale the process noise so that, when read with the state dof, it
        # keeps the covariance it has under proc_noise_dof
        # (cov = dof / (dof - 2) * scale).
        factor = (
            self.proc_noise_dof
            * (self.dof - 2)
            / (self.dof * (self.proc_noise_dof - 2))
        )
        self.scale = state_mat @ self.scale @ state_mat.T + factor * self.proc_noise
        # average with the transpose to remove any asymmetry from round-off
        self.scale = (self.scale + self.scale.T) * 0.5

        return next_state

    def _meas_fit_pdf(self, meas, est_meas, meas_cov):
        """Evaluates the multivariate Student's t pdf at the measurement.

        Parameters
        ----------
        meas : numpy array
            Measurement, flattened before evaluation.
        est_meas : numpy array
            Estimated measurement used as the location, flattened before
            evaluation.
        meas_cov : numpy array
            Matrix passed as the ``shape`` parameter of the distribution.

        Returns
        -------
        float
            Value of the pdf using :attr:`meas_noise_dof` degrees of freedom.
        """
        return stats.multivariate_t.pdf(
            meas.ravel(), loc=est_meas.ravel(), shape=meas_cov, df=self.meas_noise_dof
        )

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements the correction step of the students T filter.

        This also performs the moment matching, and updates :attr:`scale`
        in place.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            scale assuming Student's t noise.
        """
        est_meas, meas_mat = self._est_meas(
            timestep, cur_state, meas.size, meas_fun_args
        )

        # get gain
        scale_meas_T = self.scale @ meas_mat.T
        # Rescale the measurement noise so that, when read with the state dof,
        # it keeps the covariance it has under meas_noise_dof.
        factor = (
            self.meas_noise_dof
            * (self.dof - 2)
            / (self.dof * (self.meas_noise_dof - 2))
        )
        inov_cov = meas_mat @ scale_meas_T + factor * self.meas_noise
        inov_cov = (inov_cov + inov_cov.T) * 0.5
        if self.use_cholesky_inverse:
            # with inov_cov = L @ L.T, inv(inov_cov) = inv(L).T @ inv(L)
            sqrt_inv_inov_cov = la.inv(la.cholesky(inov_cov))
            inv_inov_cov = sqrt_inv_inov_cov.T @ sqrt_inv_inov_cov
        else:
            inv_inov_cov = la.inv(inov_cov)
        gain = scale_meas_T @ inv_inov_cov
        P_kk = (np.eye(cur_state.shape[0]) - gain @ meas_mat) @ self.scale

        # update state
        innov = meas - est_meas
        # squared Mahalanobis distance of the innovation
        delta_2 = innov.T @ inv_inov_cov @ innov
        next_state = cur_state + gain @ innov

        # moment matching
        if self.use_moment_matching:
            # the update raises the degree of freedom by the measurement size
            dof_p = self.dof + meas.size
            factor = (self.dof + delta_2) / dof_p
            P_k = factor * P_kk

            # Map the scale back to the original dof so the covariance
            # implied by (dof_p, P_k) is kept and the dof does not grow.
            factor = dof_p * (self.dof - 2) / (self.dof * (dof_p - 2))
            self.scale = factor * P_k
        else:
            self.scale = P_kk
        # get measurement fit
        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return next_state, meas_fit_prob