Coverage for src/gncpy/filters/interacting_multiple_model.py: 93%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1import numpy as np 

2from warnings import warn 

3from copy import deepcopy 

4from gncpy.filters.particle_filter import ParticleFilter 

5 

6 

class InteractingMultipleModel:
    """Implementation of an Interacting Multiple Model (IMM) filter.

    This filter combines several inner filters with different dynamic models.
    Before each prediction the inner filter estimates are mixed according to
    the model transition matrix, and after each measurement update the model
    weights are re-estimated from each inner filter's measurement fit
    probability. The output estimate is the weighted sum of the inner filter
    estimates.

    Attributes
    ----------
    in_filt_list : list
        Inner filters, one per dynamic model.
    cur_filt_ind : int
        Not read or written by any method of this class.
    model_trans_mat : numpy array
        Square model transition matrix. Entry ``[i][j]`` scales the weight of
        model ``j`` when forming the predicted weight of model ``i``.
    filt_weights : numpy array
        Current (normalized) model weights, one per inner filter.
    cur_out_state : numpy array
        Most recent combined state estimate.
    filt_weight_history : numpy array
        Initial model weights followed by one row per call to :meth:`predict`.
    mean_list : list
        Per-model state estimates.
    cov_list : list
        Per-model covariance matrices.
    """

    def __init__(self):
        self.in_filt_list = []
        self.cur_filt_ind = 0
        self.model_trans_mat = np.array([[]])
        self.filt_weights = np.array([])
        self.cur_out_state = np.array([])
        self.filt_weight_history = np.array([[]])
        self.mean_list = []
        self.cov_list = []

    @property
    def cov(self):
        """Covariance of the combined IMM estimate.

        Computed as ``sum_i w_i * (P_i + (x_i - x)(x_i - x)^T)`` where ``w_i``,
        ``x_i`` and ``P_i`` are the weight, mean and covariance of model ``i``
        and ``x`` is :attr:`cur_out_state`.
        """
        shape = np.shape(self.mean_list[0])
        out_state = self.cur_out_state.reshape(shape)
        cur_cov = np.zeros(np.shape(self.cov_list[0]))
        for weight, mean, cov in zip(self.filt_weights, self.mean_list, self.cov_list):
            diff = mean.reshape(shape) - out_state
            cur_cov = cur_cov + weight * (cov + diff @ diff.T)
        return cur_cov

    @cov.setter
    def cov(self, val):
        warn("Covariance is read only. SKIPPING")

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Note that to pickle the resulting dictionary the :code:`dill` package
        may need to be used due to potential pickling of functions.

        Returns
        -------
        dict
            Snapshot of this filter; inner filters are stored as
            ``(type, state_dict)`` tuples.
        """
        filt_state = {}
        filt_tup_list = []
        for filt in self.in_filt_list:
            filt_dict = filt.save_filter_state()
            filt_tup_list.append((type(filt), filt_dict))

        filt_state["in_filt_list"] = filt_tup_list.copy()
        filt_state["model_trans_mat"] = self.model_trans_mat.copy()
        filt_state["filt_weights"] = self.filt_weights.copy()
        filt_state["cur_out_state"] = self.cur_out_state.copy()
        filt_state["filt_weight_history"] = self.filt_weight_history.copy()
        filt_state["mean_list"] = deepcopy(self.mean_list)
        filt_state["cov_list"] = deepcopy(self.cov_list)

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        The saved values are copied, so later updates to this filter never
        modify ``filt_state`` (it can be loaded again unchanged).

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        self.in_filt_list = []
        for cls_type, inner_state in filt_state["in_filt_list"]:
            if cls_type is not None:
                filt = cls_type()
                filt.load_filter_state(inner_state)
            else:
                filt = None
            self.in_filt_list.append(filt)
        self.model_trans_mat = deepcopy(filt_state["model_trans_mat"])
        self.filt_weights = deepcopy(filt_state["filt_weights"])
        self.cur_out_state = deepcopy(filt_state["cur_out_state"])
        self.filt_weight_history = deepcopy(filt_state["filt_weight_history"])
        self.mean_list = deepcopy(filt_state["mean_list"])
        self.cov_list = deepcopy(filt_state["cov_list"])

    def initialize_states(self, init_means, init_covs, init_weights=None):
        """Set the initial per-model means, covariances, and weights.

        Call after :meth:`initialize_filters` so the inner filters are known.

        Parameters
        ----------
        init_means : list
            Initial state of each inner filter. Each entry is copied and
            reshaped to a column vector.
        init_covs : list
            Initial covariance of each inner filter (copied).
        init_weights : array like, optional
            Initial model weights, one per inner filter. Defaults to equal
            weights.

        Raises
        ------
        ValueError
            If the number of means, covariances, or weights does not match
            the number of inner filters, or there are no inner filters.
        """
        num_filts = len(self.in_filt_list)
        if len(init_means) != num_filts or len(init_covs) != num_filts:
            raise ValueError(
                "Number of means or covariances does not match number of inner filters"
            )
        if num_filts == 0:
            raise ValueError("At least one inner filter is required")
        if init_weights is None:
            weights = np.ones(num_filts) / num_filts
        else:
            weights = np.array(init_weights, dtype=float)
            if weights.size != num_filts:
                raise ValueError(
                    "Number of weights does not match number of inner filters"
                )

        # Private copies: correct/predict replace entries of these lists, which
        # must never leak back into the caller's data.
        self.mean_list = [np.array(mean).reshape((-1, 1)) for mean in init_means]
        self.cov_list = [np.array(cov) for cov in init_covs]
        for filt, mean, cov in zip(self.in_filt_list, self.mean_list, self.cov_list):
            filt.mean = mean.copy()
            filt.cov = cov.copy()

        self.filt_weights = weights
        self.filt_weight_history = np.array(self.filt_weights)
        self.cur_out_state = self._combine_means(self.filt_weights)

    def initialize_filters(self, filter_lst, model_trans):
        """Set the inner filters and the model transition matrix.

        Parameters
        ----------
        filter_lst : list
            Inner filters, one per dynamic model.
        model_trans : array like
            Square matrix with one row and column per inner filter.

        Raises
        ------
        ValueError
            If ``model_trans`` is not a square 2-D matrix of size
            ``len(filter_lst)``.
        """
        model_trans = np.asarray(model_trans)
        num_filts = len(filter_lst)
        if model_trans.ndim != 2 or model_trans.shape != (num_filts, num_filts):
            raise ValueError(
                "filter list must be same size as square matrix model_trans"
            )
        self.model_trans_mat = model_trans
        self.in_filt_list = list(filter_lst)

    def set_models(
        self, filter_lst, model_trans, init_means, init_covs, init_weights=None
    ):
        """Set different filters and dynamics models for an IMM filter.

        Convenience wrapper calling :meth:`initialize_filters` followed by
        :meth:`initialize_states`.
        """
        self.initialize_filters(filter_lst, model_trans)
        self.initialize_states(init_means, init_covs, init_weights=init_weights)

    def set_measurement_model(self, **kwargs):
        """Forward the keyword arguments to every inner filter's set_measurement_model."""
        for filt in self.in_filt_list:
            filt.set_measurement_model(**kwargs)

    def predict(self, timestep, *args, **kwargs):
        """Prediction step for the IMM filter.

        For each model ``i`` the inner filter is seeded with a mixed estimate
        built from every model's estimate weighted by
        ``model_trans_mat[i][j] * filt_weights[j]``, then its own prediction is
        run. The predicted model weights are normalized, appended to
        :attr:`filt_weight_history`, and used to form the combined output.

        Parameters
        ----------
        timestep : float
            Current timestep, forwarded to each inner filter's predict.
        *args, **kwargs
            Forwarded to each inner filter's predict.

        Returns
        -------
        numpy array
            Combined predicted state (also stored in :attr:`cur_out_state`).

        Raises
        ------
        ValueError
            If any inner filter is a ParticleFilter.
        """
        self._check_supported_filters()
        num_filts = len(self.in_filt_list)
        prior_weights = np.asarray(self.filt_weights, dtype=float)
        pred_weights = np.zeros(num_filts)
        new_means = []
        new_covs = []
        for ii, filt in enumerate(self.in_filt_list):
            mix_weights = [
                self.model_trans_mat[ii][jj] * prior_weights[jj]
                for jj in range(num_filts)
            ]
            pred_weights[ii] = sum(mix_weights)
            mixed_state, mixed_cov = self._mix_estimates(
                ii, mix_weights, pred_weights[ii]
            )
            filt.cov = mixed_cov
            new_means.append(filt.predict(timestep, mixed_state, *args, **kwargs))
            new_covs.append(filt.cov.copy())

        self.mean_list = new_means
        self.cov_list = new_covs
        # Store the normalized weights so filt_weights, the history, the
        # combined output, and the ``cov`` property all agree.
        self.filt_weights = self._normalize_weights(pred_weights)
        self.filt_weight_history = np.vstack(
            (self.filt_weight_history, self.filt_weights)
        )
        self.cur_out_state = self._combine_means(self.filt_weights)
        return self.cur_out_state

    def correct(self, timestep, meas, *args, **kwargs):
        """Measurement correction step for the IMM filter.

        Each inner filter is corrected with ``meas`` and its own mean; model
        weights are then updated in proportion to each filter's measurement
        fit probability. ``*args`` and ``**kwargs`` are accepted but are not
        forwarded to the inner filters.

        Parameters
        ----------
        timestep : float
            Current timestep, forwarded to each inner filter's correct.
        meas : numpy array
            Measurement, forwarded to each inner filter's correct.

        Returns
        -------
        tuple
            ``(combined_state, meas_fit_prob)`` where ``meas_fit_prob`` is the
            weight-averaged measurement fit probability of the inner filters.
            If it is zero the model weights are left at their prior values.

        Raises
        ------
        ValueError
            If any inner filter is a ParticleFilter.
        """
        self._check_supported_filters()
        meas_fit_probs = np.zeros(len(self.in_filt_list))
        new_means = []
        new_covs = []
        for ii, filt in enumerate(self.in_filt_list):
            mean, meas_fit_probs[ii] = filt.correct(
                timestep, meas, self.mean_list[ii].reshape((-1, 1))
            )
            new_means.append(mean)
            new_covs.append(filt.cov.copy())
        # Assign fresh lists instead of mutating the old ones in place.
        self.mean_list = new_means
        self.cov_list = new_covs

        prior_weights = np.asarray(self.filt_weights, dtype=float)
        unnorm_weights = meas_fit_probs * prior_weights
        out_meas_fit_prob = np.sum(unnorm_weights)
        if out_meas_fit_prob == 0:
            # Every model gives the measurement zero likelihood, so the
            # posterior weights are undefined (0/0). Keep the prior weights
            # instead of zeroing them, which would make the next predict
            # divide by zero.
            self.filt_weights = self._normalize_weights(prior_weights)
        else:
            self.filt_weights = unnorm_weights / out_meas_fit_prob

        self.cur_out_state = self._combine_means(self.filt_weights)
        return (self.cur_out_state, out_meas_fit_prob)

    def _check_supported_filters(self):
        """Raise ValueError if any inner filter is a ParticleFilter."""
        if any(isinstance(filt, ParticleFilter) for filt in self.in_filt_list):
            raise ValueError("Particle Filters not enabled with IMM")

    def _mix_estimates(self, ind, mix_weights, total):
        """Return the mixed (state, covariance) that seeds model ``ind``'s prediction.

        ``mix_weights[j]`` is the unnormalized contribution of model ``j`` and
        ``total`` is their sum.
        """
        if total == 0:
            # No probability flows into this model, so the mixing
            # probabilities are 0/0; fall back to the model's own estimate.
            return np.array(self.mean_list[ind]), np.array(self.cov_list[ind])

        mixed_state = np.zeros(self.mean_list[0].shape)
        for weight, mean in zip(mix_weights, self.mean_list):
            mixed_state += weight * mean
        mixed_state = mixed_state / total

        mixed_cov = np.zeros(np.shape(self.cov_list[0]))
        for weight, mean, cov in zip(mix_weights, self.mean_list, self.cov_list):
            diff = mean - mixed_state
            mixed_cov += weight * (cov + diff @ diff.T)
        mixed_cov = mixed_cov / total
        return mixed_state, mixed_cov

    def _combine_means(self, weights):
        """Return the weighted sum of the per-model means in ``mean_list``."""
        out_state = np.zeros(self.mean_list[0].shape)
        for weight, mean in zip(weights, self.mean_list):
            out_state += weight * mean
        return out_state

    @staticmethod
    def _normalize_weights(weights):
        """Return ``weights`` as a float array scaled to sum to one.

        All-zero weights are returned unchanged to avoid dividing by zero.
        """
        weights = np.asarray(weights, dtype=float)
        total = np.sum(weights)
        if total == 0:
            return weights.copy()
        return weights / total