Coverage for src/gncpy/filters/imm_gci_filter.py: 95%

112 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1from copy import deepcopy 

2from gncpy.filters.interacting_multiple_model import InteractingMultipleModel 

3from gncpy.filters.gci_filter import GCIFilter 

4from warnings import warn 

5import numpy as np 

6import gncpy.data_fusion as gdf 

7 

8 

class IMMGCIFilter(InteractingMultipleModel, GCIFilter):
    """Interacting multiple model filter with GCI fusion of several measurements.

    Every inner filter (one per motion model) is corrected once per
    measurement, and the per-measurement results of a motion model are fused
    with :func:`gncpy.data_fusion.GeneralizedCovarianceIntersection`. The fused
    model estimates are then mixed using the updated model weights.

    Attributes
    ----------
    meas_model_list : list
        One measurement model per measurement source. An entry that is itself
        a list is passed to the inner filter as ``meas_fun_lst``, anything else
        as ``meas_mat``.
    meas_noise_list : list
        One measurement noise per measurement source, parallel to
        :attr:`meas_model_list`.
    weight_list : list
        GCI weight of each measurement source.
    optimizer : object
        Forwarded unchanged as the ``optimizer`` keyword of the GCI fusion.
    """

    def __init__(
        self,
        meas_model_list=None,
        meas_noise_list=None,
        weight_list=None,
        optimizer=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        meas_model_list : list, optional
            Measurement model per measurement source. Defaults to an empty
            list.
        meas_noise_list : list, optional
            Measurement noise per measurement source. Defaults to an empty
            list.
        weight_list : list, optional
            Initial GCI weights. When omitted, uniform weights over the
            measurement models are used.
        optimizer : object, optional
            Optimizer handed to the GCI fusion routine.
        **kwargs : dict
            Forwarded to the parent class constructors.
        """
        super().__init__(**kwargs)
        # None sentinels instead of ``[]`` defaults so that instances never
        # share (and accidentally mutate) the same default list object.
        self.meas_model_list = [] if meas_model_list is None else meas_model_list
        self.meas_noise_list = [] if meas_noise_list is None else meas_noise_list
        self.optimizer = optimizer
        if weight_list is not None:
            self.weight_list = weight_list
        else:
            n_sources = len(self.meas_model_list)
            # A comprehension (not ``[1 / n] * n``) so that an empty model list
            # gives an empty weight list rather than a ZeroDivisionError.
            self.weight_list = [1 / n_sources for _ in range(n_sources)]

    def save_filter_state(self):
        """Save the filter state to a dictionary.

        Returns
        -------
        filt_state : dict
            Everything :meth:`load_filter_state` needs to restore the filter.
            Each inner filter is stored as a ``(type, saved_state)`` tuple.
        """
        filt_state = {}
        filt_state["in_filt_list"] = [
            (type(filt), filt.save_filter_state()) for filt in self.in_filt_list
        ]
        filt_state["model_trans_mat"] = self.model_trans_mat.copy()
        filt_state["filt_weights"] = self.filt_weights.copy()
        filt_state["cur_out_state"] = self.cur_out_state.copy()
        filt_state["filt_weight_history"] = self.filt_weight_history.copy()
        filt_state["mean_list"] = deepcopy(self.mean_list)
        filt_state["cov_list"] = deepcopy(self.cov_list)
        # NOTE: the remaining entries are stored by reference, not copied.
        filt_state["weight_list"] = self.weight_list
        filt_state["optimizer"] = self.optimizer
        filt_state["meas_model_list"] = self.meas_model_list
        filt_state["meas_noise_list"] = self.meas_noise_list

        return filt_state

    def load_filter_state(self, filt_state):
        """Restore the filter from a previously saved state.

        Parameters
        ----------
        filt_state : dict
            Dictionary produced by :meth:`save_filter_state`.
        """
        self.in_filt_list = []
        for cls_type, inner_state in filt_state["in_filt_list"]:
            if cls_type is not None:
                # Inner filters are rebuilt with their no-argument constructor
                # and then populated from their own saved state.
                filt = cls_type()
                filt.load_filter_state(inner_state)
            else:
                filt = None
            self.in_filt_list.append(filt)
        self.model_trans_mat = filt_state["model_trans_mat"]
        self.filt_weights = filt_state["filt_weights"]
        self.cur_out_state = filt_state["cur_out_state"]
        self.filt_weight_history = filt_state["filt_weight_history"]
        self.mean_list = filt_state["mean_list"]
        self.cov_list = filt_state["cov_list"]
        self.weight_list = filt_state["weight_list"]
        self.optimizer = filt_state["optimizer"]
        self.meas_model_list = filt_state["meas_model_list"]
        self.meas_noise_list = filt_state["meas_noise_list"]

    def set_measurement_model(self, meas_model_list=None):
        """Replace the measurement model list.

        A warning is always emitted because the models are normally supplied
        to the constructor.

        Parameters
        ----------
        meas_model_list : list, optional
            New measurement models. When ``None`` the current list is kept.
        """
        warn(
            "Measurement models defined in the constructor,"
            " this function will overwrite initialized measurement model list."
        )
        if meas_model_list is not None:
            self.meas_model_list = meas_model_list

    def predict(self, timestep, **kwargs):
        """Run the prediction step by delegating to the parent class.

        Parameters
        ----------
        timestep : float
            Current timestep.
        **kwargs : dict
            Forwarded to the parent ``predict``.

        Returns
        -------
        object
            Whatever the parent ``predict`` returns.
        """
        return super().predict(timestep, **kwargs)

    def correct(self, timestep, meas_list, meas_fun_args=(), **kwargs):
        """Correct every motion model with all measurements and mix the result.

        For each inner filter the saved pre-correction state is reloaded before
        every measurement, so each measurement corrects the same prior. The
        per-measurement estimates are fused with GCI, and the fused estimates
        of all motion models are combined using the normalized model weights.

        Updates ``mean_list``, ``cov_list``, ``filt_weights``, ``weight_list``
        and ``cur_out_state``, as well as ``cov`` of every inner filter. Each
        inner filter is left configured with the measurement model and noise of
        the last measurement.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas_list : list
            One measurement per measurement source, in the same order as
            ``meas_model_list`` and ``meas_noise_list``.
        meas_fun_args : tuple, optional
            Accepted for interface compatibility; not used here.
        **kwargs : dict
            Accepted for interface compatibility; not used here.

        Returns
        -------
        out_state : N x 1 numpy array
            Weighted combination of the fused motion model estimates.
        out_meas_fit_prob : float
            Measurement fit probabilities weighted by the model-scaled GCI
            weights.
        """
        meas_fit_prob_list = []
        new_weight_list = []
        new_gci_weights = []
        all_weights = []
        n_dim = self.cov_list[0].shape[0]
        eye_list = [np.eye(n_dim) for _ in range(len(self.meas_model_list))]

        # loop over motion models
        for ii, filt in enumerate(self.in_filt_list):
            model_weight = self.filt_weights[ii]
            cur_est_list = []
            cur_cov_list = []
            cur_weight_list = []
            saved_state = filt.save_filter_state()

            # loop over measurements
            for jj, meas in enumerate(meas_list):
                # every measurement starts from the same predicted state
                filt.load_filter_state(saved_state)
                meas_model = self.meas_model_list[jj]
                if isinstance(meas_model, list):
                    filt.set_measurement_model(meas_fun_lst=meas_model)
                else:
                    filt.set_measurement_model(meas_mat=meas_model)
                # The noise belongs to the measurement source, so it is indexed
                # by the measurement index like the measurement model (it was
                # previously indexed by the motion model index).
                filt.meas_noise = self.meas_noise_list[jj]
                new_state, new_prob = filt.correct(
                    timestep, meas, self.mean_list[ii].reshape((-1, 1))
                )
                cur_est_list.append(new_state)
                cur_cov_list.append(filt.cov)
                meas_fit_prob_list.append(new_prob)
                cur_weight_list.append(new_prob * model_weight)

            # fuse measurements for each motion model
            model_est, model_cov, model_weights = gdf.GeneralizedCovarianceIntersection(
                cur_est_list,
                cur_cov_list,
                self.weight_list,
                eye_list,
                optimizer=self.optimizer,
            )
            filt.cov = model_cov.copy()
            self.mean_list[ii] = model_est.copy()
            self.cov_list[ii] = filt.cov.copy()

            # GCI weights scaled by the prior weight of this motion model
            scaled_gci_weights = [mw * model_weight for mw in model_weights]
            all_weights.extend(scaled_gci_weights)
            new_gci_weights.append(scaled_gci_weights)
            new_weight_list.append(np.sum(cur_weight_list))

        # Normalize the model weights. If no model explains the measurements
        # keep one zero weight per model; ``list * 0`` would give an empty list
        # and break the state combination below.
        weight_total = np.sum(new_weight_list)
        if weight_total == 0:
            new_weight_list = np.zeros(len(new_weight_list))
        else:
            new_weight_list = np.asarray(new_weight_list) / weight_total

        # sum the scaled GCI weights over the motion models, per measurement
        tngw = np.array(new_gci_weights)
        out_gci_weights = [np.sum(tngw[:, col]) for col in range(tngw.shape[1])]

        self.filt_weights = new_weight_list
        # NOTE(review): this still yields NaN if the scaled GCI weights sum to
        # zero (e.g. all model weights are zero) -- confirm desired behavior.
        self.weight_list = list(np.array(out_gci_weights) / np.sum(out_gci_weights))

        out_meas_fit_prob = 0
        for ii, prob in enumerate(meas_fit_prob_list):
            out_meas_fit_prob += all_weights[ii] * prob

        # compute output state from combined motion models
        out_state = np.zeros(np.shape(self.mean_list[0]))
        for ii in range(len(self.in_filt_list)):
            out_state = out_state + new_weight_list[ii] * self.mean_list[ii]
        out_state = out_state.reshape((np.shape(out_state)[0], 1))
        self.cur_out_state = out_state

        return (out_state, out_meas_fit_prob)

159 self.cur_out_state = out_state 

160 

161 return (out_state, out_meas_fit_prob)