Coverage for src/gncpy/filters/gci_filter.py: 90%

79 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import gncpy.data_fusion as gdf 

3 

4from gncpy.filters.bayes_filter import BayesFilter 

5from warnings import warn 

6from copy import deepcopy 

7 

8 

class GCIFilter(BayesFilter):
    """Generalized Covariance Intersection (GCI) fusion filter.

    Wraps a single ``base_filter``. Prediction is delegated directly to it.
    On correction, the base filter is run once per measurement, each time
    with that measurement's own model and noise and starting from the same
    saved prior. The resulting estimates and covariances are then fused with
    ``gncpy.data_fusion.GeneralizedCovarianceIntersection``.

    Attributes
    ----------
    base_filter : BayesFilter or None
        Filter used for prediction and for each per-measurement correction.
    meas_model_list : list
        One entry per measurement source. An entry that is a ``list`` is given
        to the base filter as ``meas_fun_lst``; any other entry as
        ``meas_mat``.
    meas_noise_list : list
        Measurement noise per source, indexed in step with
        ``meas_model_list``.
    weight_list : list
        Fusion weights passed to the GCI routine. Replaced after every
        correction by the weights that routine returns.
    optimizer
        Forwarded unchanged to the GCI routine.
    """

    def __init__(
        self,
        base_filter=None,
        meas_model_list=None,
        meas_noise_list=None,
        weight_list=None,
        optimizer=None,
        **kwargs
    ):
        """Initialize the object.

        Parameters
        ----------
        base_filter : BayesFilter, optional
            Filter wrapped for prediction and per-measurement correction.
        meas_model_list : list, optional
            Measurement model for each source. Defaults to a new empty list.
        meas_noise_list : list, optional
            Measurement noise for each source. Defaults to a new empty list.
        weight_list : list, optional
            Initial fusion weights. Defaults to equal weights
            ``1 / len(meas_model_list)`` for each model (an empty list when
            there are no models).
        optimizer : optional
            Forwarded to the GCI routine during correction.
        **kwargs
            Passed to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.base_filter = base_filter
        # Create fresh lists here instead of using ``[]`` as a default
        # argument, which would be one list object shared by every instance.
        self.meas_model_list = [] if meas_model_list is None else meas_model_list
        self.meas_noise_list = [] if meas_noise_list is None else meas_noise_list
        self.optimizer = optimizer
        if weight_list is not None:
            self.weight_list = weight_list
        else:
            n_models = len(self.meas_model_list)
            self.weight_list = [1 / n_models for _ in range(n_models)]

    def save_filter_state(self):
        """Save the filter state, including the wrapped base filter.

        Returns
        -------
        dict
            The parent class state plus:

            - ``base_filter``: ``(type, saved_state)`` of the base filter, or
              ``(None, None)`` when no base filter is set.
            - Deep copies of ``weight_list``, ``meas_model_list`` and
              ``meas_noise_list``.
            - ``optimizer``: stored by reference, not copied.
        """
        filt_state = super().save_filter_state()
        if self.base_filter is not None:
            filt_state["base_filter"] = (
                type(self.base_filter),
                self.base_filter.save_filter_state(),
            )
        else:
            filt_state["base_filter"] = (None, self.base_filter)
        filt_state["weight_list"] = deepcopy(self.weight_list)
        filt_state["optimizer"] = self.optimizer
        filt_state["meas_model_list"] = deepcopy(self.meas_model_list)
        filt_state["meas_noise_list"] = deepcopy(self.meas_noise_list)

        return filt_state

    def load_filter_state(self, filt_state):
        """Restore a state produced by :meth:`save_filter_state`.

        The base filter is re-created by calling its saved type with no
        arguments and then loading its saved state into the new instance.
        The base filter's class must therefore be constructible without
        arguments.

        Parameters
        ----------
        filt_state : dict
            State dictionary from :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)
        cls_type = filt_state["base_filter"][0]
        if cls_type is not None:
            self.base_filter = cls_type()
            self.base_filter.load_filter_state(filt_state["base_filter"][1])
        else:
            self.base_filter = None
        self.weight_list = filt_state["weight_list"]
        self.optimizer = filt_state["optimizer"]
        self.meas_model_list = filt_state["meas_model_list"]
        self.meas_noise_list = filt_state["meas_noise_list"]

    def set_state_model(self, dyn_obj=None):
        """Set the dynamics model on the wrapped base filter."""
        self.base_filter.set_state_model(dyn_obj=dyn_obj)

    def set_measurement_model(self, meas_model_list=None):
        """Replace the measurement model list.

        A warning is always issued, because models are normally supplied
        through the constructor. The list is replaced only when
        ``meas_model_list`` is not ``None``.
        """
        warn(
            "Measurement models defined in the constructor,"
            " this function will overwrite initialized measurement model list."
        )
        if meas_model_list is not None:
            self.meas_model_list = meas_model_list

    @property
    def cov(self):
        """Covariance of the wrapped base filter."""
        return self.base_filter.cov

    @cov.setter
    def cov(self, val):
        self.base_filter.cov = val

    @property
    def proc_noise(self):
        """Process noise of the wrapped base filter."""
        return self.base_filter.proc_noise

    @proc_noise.setter
    def proc_noise(self, val):
        self.base_filter.proc_noise = val

    def predict(self, timestep, cur_state, **kwargs):
        """Run the base filter's prediction step and return its result."""
        return self.base_filter.predict(timestep, cur_state, **kwargs)

    def correct(self, timestep, meas_list, cur_state, meas_fun_args=()):
        """Correct the state with several measurements fused by GCI.

        The base filter is reset to its pre-correction state before each
        measurement, so every measurement is processed independently from
        the same prior. The per-measurement estimates and covariances are
        then fused with ``GeneralizedCovarianceIntersection``.

        Parameters
        ----------
        timestep : float
            Current time, forwarded to the base filter.
        meas_list : list
            Measurements. Entry ``ii`` is processed with
            ``meas_model_list[ii]`` and ``meas_noise_list[ii]``.
        cur_state : numpy array
            Current state. ``len(cur_state)`` sets the size of the identity
            matrices passed to GCI.
        meas_fun_args : tuple, optional
            Forwarded to the base filter's ``correct``.

        Returns
        -------
        new_est
            Fused state estimate returned by the GCI routine.
        meas_fit_prob : float
            Per-measurement fit probabilities weighted by the updated GCI
            weights.
        """
        n_est_list = []
        n_prob_list = []
        n_cov_list = []
        saved_state = self.base_filter.save_filter_state()
        for ii, meas in enumerate(meas_list):
            # Start every measurement from the same prior.
            self.base_filter.load_filter_state(saved_state)
            meas_model = self.meas_model_list[ii]
            if isinstance(meas_model, list):
                self.base_filter.set_measurement_model(meas_fun_lst=meas_model)
            else:
                self.base_filter.set_measurement_model(meas_mat=meas_model)
            self.base_filter.meas_noise = self.meas_noise_list[ii]
            n_est, n_prob = self.base_filter.correct(
                timestep, meas, cur_state, meas_fun_args
            )
            n_est_list.append(n_est)
            n_cov_list.append(self.base_filter.cov)
            n_prob_list.append(n_prob)

        # Build one identity transform per estimate actually produced.
        # The original sized this from the number of models, which
        # mismatches n_est_list when fewer measurements are passed.
        # NOTE(review): self.weight_list is still assumed to hold one weight
        # per measurement in meas_list -- confirm against callers.
        eye_list = [np.eye(len(cur_state)) for _ in n_est_list]
        new_est, new_cov, new_weight_list = gdf.GeneralizedCovarianceIntersection(
            n_est_list, n_cov_list, self.weight_list, eye_list, optimizer=self.optimizer
        )
        meas_fit_prob = 0
        for ii, prob in enumerate(n_prob_list):
            meas_fit_prob += new_weight_list[ii] * prob

        # The base filter now holds the state from the last measurement's
        # correction; overwrite its covariance with the fused result.
        self.base_filter.cov = new_cov
        self.weight_list = new_weight_list

        return new_est, meas_fit_prob