Coverage for src/gncpy/filters/square_root_qkf.py: 99%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1import numpy as np 

2import numpy.linalg as la 

3import scipy.linalg as sla 

4 

5 

6from gncpy.filters.quadrature_kalman_filter import QuadratureKalmanFilter 

7 

8 

class SquareRootQKF(QuadratureKalmanFilter):
    """Implementation of a Square root Quadrature Kalman Filter (SQKF).

    The state covariance, process noise, and measurement noise are stored as
    square-root factors ``S`` (with ``P = S @ S.T``) and the covariance is
    propagated through QR decompositions rather than full matrix updates.

    Notes
    -----
    This is based on :cite:`Arasaratnam2008_SquareRootQuadratureKalmanFiltering`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # NOTE(review): _meas_noise is not read by this class (the meas_noise
        # property is backed by _sqrt_m_noise); it is kept because it is part
        # of the saved/loaded filter state.
        self._meas_noise = np.array([[]])
        # square-root factors of the process and measurement noise covariances
        self._sqrt_p_noise = np.array([[]])
        self._sqrt_m_noise = np.array([[]])

    def save_filter_state(self):
        """Saves filter variables so they can be restored later."""
        filt_state = super().save_filter_state()

        filt_state["_meas_noise"] = self._meas_noise
        filt_state["_sqrt_p_noise"] = self._sqrt_p_noise
        filt_state["_sqrt_m_noise"] = self._sqrt_m_noise

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        self._meas_noise = filt_state["_meas_noise"]
        self._sqrt_p_noise = filt_state["_sqrt_p_noise"]
        self._sqrt_m_noise = filt_state["_sqrt_m_noise"]

    def set_measurement_noise_estimator(self, function):
        """Sets the model used for estimating the measurement noise parameters.

        This is an optional step and the filter will work properly if this is
        not called. If it is called, the measurement noise will be estimated
        during the filter's correction step and the measurement noise attribute
        will not be used.

        Parameters
        ----------
        function : callable
            A function that implements the prediction and correction steps for
            an appropriate filter to estimate the measurement noise covariance
            matrix. It must have the signature `f(est_meas, meas_cov)` where
            `est_meas` is the Nm x 1 estimated measurement and `meas_cov` is
            the Nm x Nm weighted covariance of the measurement quadrature
            points (not including measurement noise). It must return an
            Nm x Nm numpy array representing the measurement noise covariance
            matrix.

        Returns
        -------
        None.
        """
        self._est_meas_noise_fnc = function

    @property
    def cov(self):
        """Covariance of the filter, rebuilt from its square-root factor."""
        # sqrt cov is lower triangular
        return self._sqrt_cov @ self._sqrt_cov.T

    @cov.setter
    def cov(self, val):
        if val.size == 0:
            self._sqrt_cov = val
        else:
            # this class disables its own _factorize_cov hook, so the parent's
            # implementation is called explicitly to factor the new value
            super()._factorize_cov(val=val)

    @property
    def proc_noise(self):
        """Process noise of the filter, rebuilt from its square-root factor.

        Setting this stores the lower-triangular Cholesky factor. Empty or
        all-zero matrices are stored as-is because Cholesky requires a
        positive-definite input.
        """
        return self._sqrt_p_noise @ self._sqrt_p_noise.T

    @proc_noise.setter
    def proc_noise(self, val):
        if val.size == 0 or np.all(val == 0):
            self._sqrt_p_noise = val
        else:
            self._sqrt_p_noise = la.cholesky(val)

    @property
    def meas_noise(self):
        """Measurement noise of the filter, rebuilt from its square-root factor.

        Setting this stores the lower-triangular Cholesky factor. Empty or
        all-zero matrices are stored as-is because Cholesky requires a
        positive-definite input.
        """
        return self._sqrt_m_noise @ self._sqrt_m_noise.T

    @meas_noise.setter
    def meas_noise(self, val):
        if val.size == 0 or np.all(val == 0):
            self._sqrt_m_noise = val
        else:
            self._sqrt_m_noise = la.cholesky(val)

    def _factorize_cov(self, val=None):
        """No-op: the square-root covariance is maintained directly.

        The ``val`` keyword is accepted (and ignored) so this override keeps
        the same call signature as the parent implementation.
        """

    def _pred_update_cov(self):
        """Time-update of the square-root covariance.

        Computes ``S = tria([X @ W, S_Q])`` where ``X`` holds the quadrature
        point deviations from their mean, ``W`` is the diagonal matrix of
        square-rooted weights, and ``S_Q`` is the process noise factor.
        """
        weight_mat = np.diag(np.sqrt(self.quadPoints.weights))
        x_hat = self.quadPoints.mean
        state_mat = np.concatenate(
            [x.reshape((x.size, 1)) - x_hat for x in self.quadPoints.points], axis=1
        )

        # QR of the stacked transpose yields upper-triangular R with
        # R.T @ R = A @ A.T, so R.T is a lower-triangular square root
        self._sqrt_cov = la.qr(
            np.concatenate((state_mat @ weight_mat, self._sqrt_p_noise.T), axis=1).T,
            mode="r",
        ).T

    def _corr_update_cov(self, gain, state_mat, meas_mat):
        """Measurement-update of the square-root covariance.

        Computes ``S = tria([X - K @ Z, K @ S_R])`` with ``X``/``Z`` the
        weighted state/measurement deviation matrices, ``K`` the gain, and
        ``S_R`` the measurement noise factor.
        """
        self._sqrt_cov = la.qr(
            np.concatenate(
                (state_mat - gain @ meas_mat, gain @ self._sqrt_m_noise), axis=1
            ).T,
            mode="r",
        ).T

    def correct(self, timestep, meas, cur_state, meas_fun_args=()):
        """Implements the correction step of the filter.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.
        cur_state : N x 1 numpy array
            Current state.
        meas_fun_args : tuple, optional
            Arguments for the measurement matrix function if one has
            been specified. The default is ().

        Raises
        ------
        :class:`.errors.ExtremeMeasurementNoiseError`
            If estimating the measurement noise and the measurement fit calculation fails.
        LinAlgError
            Numpy exception raised if not estimating noise and measurement fit fails.

        Returns
        -------
        next_state : N x 1 numpy array
            The corrected state.
        meas_fit_prob : float
            Goodness of fit of the measurement based on the state and
            covariance assuming Gaussian noise.

        """
        measQuads, est_meas = self._corr_core(timestep, cur_state, meas, meas_fun_args)

        # diag(sqrt(w)) so that (D @ W) @ (D @ W).T is the weighted covariance
        # of the deviation matrix D
        weight_mat = np.diag(np.sqrt(self.quadPoints.weights))

        # weighted, centered measurement deviations (Nm x num points)
        meas_mat = (
            np.concatenate(
                [z.reshape((z.size, 1)) - est_meas for z in measQuads.points], axis=1
            )
            @ weight_mat
        )
        if self._est_meas_noise_fnc is not None:
            self.meas_noise = self._est_meas_noise_fnc(est_meas, meas_mat @ meas_mat.T)

        # lower-triangular square root of the innovation covariance:
        # S_zz @ S_zz.T = meas_mat @ meas_mat.T + R
        sqrt_inov_cov = la.qr(
            np.concatenate((meas_mat, self._sqrt_m_noise), axis=1).T, mode="r"
        ).T

        # weighted, centered state deviations (N x num points) and the
        # state/measurement cross covariance
        x_hat = self.quadPoints.mean
        state_mat = (
            np.concatenate(
                [x.reshape((x.size, 1)) - x_hat for x in self.quadPoints.points], axis=1
            )
            @ weight_mat
        )
        cross_cov = state_mat @ meas_mat.T

        # K = P_xz @ inv(S_zz @ S_zz.T) = P_xz @ S_zz^-T @ S_zz^-1, hence
        # K^T = S_zz^-T @ (S_zz^-1 @ P_xz^T): solve with the lower factor
        # first, then with its (upper-triangular) transpose.
        inter = sla.solve_triangular(sqrt_inov_cov, cross_cov.T, lower=True)
        gain = sla.solve_triangular(sqrt_inov_cov.T, inter, lower=False).T

        # the above gain is equivalent to
        # inv_sqrt_inov_cov = la.inv(sqrt_inov_cov)
        # gain = cross_cov @ (inv_sqrt_inov_cov.T @ inv_sqrt_inov_cov)

        # state is x_hat + K *(z - z_hat)
        innov = meas - est_meas
        cor_state = cur_state + gain @ innov

        # update covariance
        inov_cov = sqrt_inov_cov @ sqrt_inov_cov.T

        self._corr_update_cov(gain, state_mat, meas_mat)

        meas_fit_prob = self._calc_meas_fit(meas, est_meas, inov_cov)

        return (cor_state, meas_fit_prob)