Coverage for src/gncpy/filters/gci_filter.py: 90%
79 statements
Report generated by coverage.py v7.2.7 at 2023-07-19 05:48 +0000.
1import numpy as np
2import gncpy.data_fusion as gdf
4from gncpy.filters.bayes_filter import BayesFilter
5from warnings import warn
6from copy import deepcopy
class GCIFilter(BayesFilter):
    """Fuse per-measurement corrections with Generalized Covariance Intersection.

    A single ``base_filter`` does all prediction and correction work. During
    :meth:`correct` the base filter is re-run once per measurement, each time
    restored to the same saved state and configured with that measurement's
    model and noise. The resulting estimates and covariances are then fused
    with ``gncpy.data_fusion.GeneralizedCovarianceIntersection``.

    Attributes
    ----------
    base_filter : BayesFilter or None
        Filter that performs the predict/correct steps.
    meas_model_list : list
        One measurement model per measurement. An entry that is itself a
        ``list`` is given to the base filter as ``meas_fun_lst``; any other
        entry is given as ``meas_mat``.
    meas_noise_list : list
        Measurement noise per measurement, index-aligned with
        ``meas_model_list``.
    weight_list : list
        Fusion weights handed to GCI; replaced by the weights GCI returns
        after every :meth:`correct` call.
    optimizer
        Passed unchanged to the GCI fusion routine.
    """

    def __init__(
        self,
        base_filter=None,
        meas_model_list=None,
        meas_noise_list=None,
        weight_list=None,
        optimizer=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        base_filter : BayesFilter, optional
            Filter used for the individual predict/correct steps.
        meas_model_list : list, optional
            Measurement model for each measurement. Defaults to a new empty
            list for this instance.
        meas_noise_list : list, optional
            Measurement noise for each measurement. Defaults to a new empty
            list for this instance.
        weight_list : list, optional
            Initial fusion weights. Defaults to uniform weights
            ``1 / len(meas_model_list)``, one per measurement model.
        optimizer : optional
            Forwarded to the GCI fusion routine.
        **kwargs
            Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.base_filter = base_filter
        # Build fresh lists per instance: a ``[]`` default argument would be
        # one object shared by every filter constructed without these args.
        self.meas_model_list = [] if meas_model_list is None else meas_model_list
        self.meas_noise_list = [] if meas_noise_list is None else meas_noise_list
        self.optimizer = optimizer
        if weight_list is not None:
            self.weight_list = weight_list
        else:
            self.weight_list = self._uniform_weights(len(self.meas_model_list))

    @staticmethod
    def _uniform_weights(n_models):
        """Return ``n_models`` equal weights summing to 1 (empty list for 0)."""
        # Comprehension (not ``[1 / n] * n``) so 1 / 0 is never evaluated.
        return [1 / n_models for _ in range(n_models)]

    def save_filter_state(self):
        """Save the filter state, including the base filter's own state.

        Returns
        -------
        dict
            Parent class state plus:

            - ``"base_filter"``: ``(type, saved_state)`` of the base filter,
              or ``(None, None)`` if there is no base filter.
            - ``"weight_list"``, ``"meas_model_list"``, ``"meas_noise_list"``:
              deep copies of the current lists.
            - ``"optimizer"``: the optimizer object itself (not copied).
        """
        filt_state = super().save_filter_state()
        if self.base_filter is not None:
            filt_state["base_filter"] = (
                type(self.base_filter),
                self.base_filter.save_filter_state(),
            )
        else:
            filt_state["base_filter"] = (None, self.base_filter)
        filt_state["weight_list"] = deepcopy(self.weight_list)
        filt_state["optimizer"] = self.optimizer
        filt_state["meas_model_list"] = deepcopy(self.meas_model_list)
        filt_state["meas_noise_list"] = deepcopy(self.meas_noise_list)

        return filt_state

    def load_filter_state(self, filt_state):
        """Restore the filter from a dict produced by :meth:`save_filter_state`.

        The base filter is rebuilt by calling its saved class with no
        arguments, then loading its saved state into the new instance.

        Parameters
        ----------
        filt_state : dict
            State dictionary as returned by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)
        cls_type = filt_state["base_filter"][0]
        if cls_type is not None:
            self.base_filter = cls_type()
            self.base_filter.load_filter_state(filt_state["base_filter"][1])
        else:
            self.base_filter = None
        self.weight_list = filt_state["weight_list"]
        self.optimizer = filt_state["optimizer"]
        self.meas_model_list = filt_state["meas_model_list"]
        self.meas_noise_list = filt_state["meas_noise_list"]

    def set_state_model(self, dyn_obj=None):
        """Set the dynamics model on the base filter.

        Parameters
        ----------
        dyn_obj : optional
            Dynamics object forwarded to ``base_filter.set_state_model``.
        """
        self.base_filter.set_state_model(dyn_obj=dyn_obj)

    def set_measurement_model(self, meas_model_list=None):
        """Replace the list of per-measurement models.

        A warning is issued because the models are normally supplied to the
        constructor. If the new list has a different length than the current
        ``weight_list``, the weights are reset to uniform so that there is
        still one weight per model.

        Parameters
        ----------
        meas_model_list : list, optional
            New measurement models. If None, nothing is changed and no
            warning is issued.
        """
        if meas_model_list is None:
            return
        warn(
            "Measurement models defined in the constructor,"
            " this function will overwrite initialized measurement model list."
        )
        self.meas_model_list = meas_model_list
        # correct() hands weight_list to GCI alongside one estimate per model,
        # so a stale weight list of a different length would not line up.
        if len(self.weight_list) != len(meas_model_list):
            self.weight_list = self._uniform_weights(len(meas_model_list))

    @property
    def cov(self):
        """Covariance of the base filter."""
        return self.base_filter.cov

    @cov.setter
    def cov(self, val):
        self.base_filter.cov = val

    @property
    def proc_noise(self):
        """Process noise of the base filter."""
        return self.base_filter.proc_noise

    @proc_noise.setter
    def proc_noise(self, val):
        self.base_filter.proc_noise = val

    def predict(self, timestep, cur_state, **kwargs):
        """Run the base filter's prediction step and return its result."""
        return self.base_filter.predict(timestep, cur_state, **kwargs)

    def correct(self, timestep, meas_list, cur_state, meas_fun_args=()):
        """Correct with every measurement and fuse the results via GCI.

        For each ``meas_list[ii]`` the base filter is restored to the state it
        had on entry, configured with ``meas_model_list[ii]`` and
        ``meas_noise_list[ii]``, and corrected. The per-measurement estimates
        and covariances are then fused with
        ``gdf.GeneralizedCovarianceIntersection`` using ``weight_list``.

        Parameters
        ----------
        timestep : float
            Current timestep, forwarded to the base filter.
        meas_list : list
            Measurements; entry ``ii`` is paired with model/noise ``ii``.
            It must not be longer than ``meas_model_list`` (IndexError).
        cur_state : array_like
            Current state estimate; ``len(cur_state)`` is used as the state
            dimension for the GCI identity matrices.
        meas_fun_args : tuple, optional
            Extra arguments forwarded to the base filter's ``correct``.

        Returns
        -------
        new_est
            Fused state estimate returned by GCI.
        meas_fit_prob
            Per-measurement fit probabilities summed with the weights
            returned by GCI.

        Notes
        -----
        Side effects: ``base_filter.cov`` is set to the fused covariance,
        ``weight_list`` is replaced by the weights returned by GCI, and the
        base filter keeps the measurement model/noise of the last
        measurement processed.
        """
        n_est_list = []
        n_prob_list = []
        n_cov_list = []
        saved_state = self.base_filter.save_filter_state()
        for ii, meas in enumerate(meas_list):
            # Every measurement is applied to the same prior, not chained.
            self.base_filter.load_filter_state(saved_state)
            meas_model = self.meas_model_list[ii]
            if isinstance(meas_model, list):
                self.base_filter.set_measurement_model(meas_fun_lst=meas_model)
            else:
                self.base_filter.set_measurement_model(meas_mat=meas_model)
            self.base_filter.meas_noise = self.meas_noise_list[ii]
            n_est, n_prob = self.base_filter.correct(
                timestep, meas, cur_state, meas_fun_args
            )
            n_est_list.append(n_est)
            n_cov_list.append(self.base_filter.cov)
            n_prob_list.append(n_prob)

        # One identity per estimate actually produced, so the lists handed to
        # GCI always have matching lengths.
        state_dim = len(cur_state)
        eye_list = [np.eye(state_dim) for _ in n_est_list]
        new_est, new_cov, new_weight_list = gdf.GeneralizedCovarianceIntersection(
            n_est_list, n_cov_list, self.weight_list, eye_list, optimizer=self.optimizer
        )
        meas_fit_prob = 0
        for ii, prob in enumerate(n_prob_list):
            meas_fit_prob += new_weight_list[ii] * prob

        self.base_filter.cov = new_cov
        self.weight_list = new_weight_list

        return new_est, meas_fit_prob