Coverage for src/gncpy/filters/imm_gci_filter.py: 95%
112 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-19 05:48 +0000
1from copy import deepcopy
2from gncpy.filters.interacting_multiple_model import InteractingMultipleModel
3from gncpy.filters.gci_filter import GCIFilter
4from warnings import warn
5import numpy as np
6import gncpy.data_fusion as gdf
class IMMGCIFilter(InteractingMultipleModel, GCIFilter):
    """Interacting multiple model filter that fuses several measurements.

    Every inner (motion model) filter is corrected once per measurement, the
    per-measurement results are fused with
    ``gdf.GeneralizedCovarianceIntersection``, and the fused per-model
    estimates are then blended using the updated model weights.

    Attributes
    ----------
    meas_model_list : list
        One measurement model per measurement source. An entry that is itself
        a list is passed to the inner filter as ``meas_fun_lst``; any other
        entry is passed as ``meas_mat``.
    meas_noise_list : list
        Measurement noise for each entry of ``meas_model_list``.
    weight_list : list
        Fusion weights handed to the covariance intersection routine.
    optimizer : object
        Forwarded unchanged as the ``optimizer`` keyword of the covariance
        intersection routine.
    """

    def __init__(
        self,
        meas_model_list=None,
        meas_noise_list=None,
        weight_list=None,
        optimizer=None,
        **kwargs
    ):
        """Initialize the filter.

        Parameters
        ----------
        meas_model_list : list, optional
            Measurement models, one per measurement source. Defaults to a new
            empty list.
        meas_noise_list : list, optional
            Measurement noises matching ``meas_model_list``. Defaults to a new
            empty list.
        weight_list : list, optional
            Fusion weights. When omitted, uniform weights are created, one per
            measurement model.
        optimizer : object, optional
            Stored and later forwarded to the covariance intersection call.
        **kwargs : dict
            Forwarded to the parent class initializers.
        """
        super().__init__(**kwargs)
        # ``None`` sentinels avoid sharing one default list between instances.
        self.meas_model_list = [] if meas_model_list is None else meas_model_list
        self.meas_noise_list = [] if meas_noise_list is None else meas_noise_list
        self.optimizer = optimizer
        if weight_list is not None:
            self.weight_list = weight_list
        else:
            n_models = len(self.meas_model_list)
            # The comprehension keeps the division from running when
            # n_models is zero.
            self.weight_list = [1 / n_models for _ in range(n_models)]

    def save_filter_state(self):
        """Collect the filter state into a dictionary.

        Returns
        -------
        filt_state : dict
            The saved state. Each inner filter is stored as a
            ``(class, state dict)`` tuple under ``"in_filt_list"``. The
            entries ``"weight_list"``, ``"optimizer"``, ``"meas_model_list"``
            and ``"meas_noise_list"`` are stored by reference, not copied.
        """
        filt_state = {}
        filt_state["in_filt_list"] = [
            (type(filt), filt.save_filter_state()) for filt in self.in_filt_list
        ]
        filt_state["model_trans_mat"] = self.model_trans_mat.copy()
        filt_state["filt_weights"] = self.filt_weights.copy()
        filt_state["cur_out_state"] = self.cur_out_state.copy()
        filt_state["filt_weight_history"] = self.filt_weight_history.copy()
        filt_state["mean_list"] = deepcopy(self.mean_list)
        filt_state["cov_list"] = deepcopy(self.cov_list)
        filt_state["weight_list"] = self.weight_list
        filt_state["optimizer"] = self.optimizer
        filt_state["meas_model_list"] = self.meas_model_list
        filt_state["meas_noise_list"] = self.meas_noise_list

        return filt_state

    def load_filter_state(self, filt_state):
        """Restore the filter from a previously saved state.

        Parameters
        ----------
        filt_state : dict
            Dictionary with the keys written by :meth:`save_filter_state`.
        """
        self.in_filt_list = []
        for cls_type, inner_state in filt_state["in_filt_list"]:
            if cls_type is not None:
                # Inner filters are rebuilt with a no-argument constructor
                # and then populated from their own saved state.
                filt = cls_type()
                filt.load_filter_state(inner_state)
            else:
                filt = None
            self.in_filt_list.append(filt)

        self.model_trans_mat = filt_state["model_trans_mat"]
        self.filt_weights = filt_state["filt_weights"]
        self.cur_out_state = filt_state["cur_out_state"]
        self.filt_weight_history = filt_state["filt_weight_history"]
        self.mean_list = filt_state["mean_list"]
        self.cov_list = filt_state["cov_list"]
        self.weight_list = filt_state["weight_list"]
        self.optimizer = filt_state["optimizer"]
        self.meas_model_list = filt_state["meas_model_list"]
        self.meas_noise_list = filt_state["meas_noise_list"]

    def set_measurement_model(self, meas_model_list=None):
        """Replace the list of measurement models.

        A warning is always emitted because the measurement models are
        normally supplied through the constructor.

        Parameters
        ----------
        meas_model_list : list, optional
            New measurement models. When ``None`` the current list is kept.
        """
        warn(
            "Measurement models defined in the constructor,"
            " this function will overwrite initialized measurement model list."
        )
        if meas_model_list is not None:
            self.meas_model_list = meas_model_list

    def predict(self, timestep, **kwargs):
        """Run the prediction step by delegating to the parent class.

        Parameters
        ----------
        timestep : float
            Current timestep.
        **kwargs : dict
            Forwarded to the parent ``predict``.

        Returns
        -------
        object
            Whatever the parent ``predict`` returns.
        """
        return super().predict(timestep, **kwargs)

    def correct(self, timestep, meas_list, meas_fun_args=(), **kwargs):
        """Correct every motion model with all measurements and fuse them.

        Each inner filter is restored to its pre-correction state before every
        measurement, so the measurements are applied independently rather than
        sequentially. The per-measurement results of one motion model are
        fused with generalized covariance intersection, and the fused means
        are then combined using the normalized model weights.

        Parameters
        ----------
        timestep : float
            Current timestep, passed to the inner filters.
        meas_list : list
            Measurements; entry ``jj`` is paired with ``meas_model_list[jj]``
            and ``meas_noise_list[jj]``.
        meas_fun_args : tuple, optional
            Accepted for interface compatibility; not forwarded to the inner
            filters.
        **kwargs : dict
            Accepted for interface compatibility; not used.

        Returns
        -------
        out_state : numpy array
            Weighted combination of the fused model means, reshaped to a
            column vector.
        out_meas_fit_prob : float
            Sum of the per-measurement fit probabilities weighted by the
            model-scaled fusion weights.
        """
        eye_list = [
            np.eye(self.cov_list[0].shape[0]) for _ in self.meas_model_list
        ]
        meas_fit_prob_list = []
        new_weight_list = []
        new_gci_weights = []
        all_weights = []

        # loop over motion models
        for ii, filt in enumerate(self.in_filt_list):
            cur_est_list = []
            cur_cov_list = []
            cur_weight_list = []
            saved_state = filt.save_filter_state()

            # loop over measurements
            for jj, meas in enumerate(meas_list):
                # Every measurement starts from the same predicted state.
                filt.load_filter_state(saved_state)
                meas_model = self.meas_model_list[jj]
                if isinstance(meas_model, list):
                    filt.set_measurement_model(meas_fun_lst=meas_model)
                else:
                    filt.set_measurement_model(meas_mat=meas_model)
                # The noise list parallels the measurement model list, so it
                # is indexed by the measurement index, not the model index.
                filt.meas_noise = self.meas_noise_list[jj]
                new_state, new_prob = filt.correct(
                    timestep, meas, self.mean_list[ii].reshape((-1, 1))
                )
                cur_est_list.append(new_state)
                cur_cov_list.append(filt.cov)
                meas_fit_prob_list.append(new_prob)
                cur_weight_list.append(new_prob * self.filt_weights[ii])

            # fuse measurements for each motion model
            model_est, model_cov, model_weights = (
                gdf.GeneralizedCovarianceIntersection(
                    cur_est_list,
                    cur_cov_list,
                    self.weight_list,
                    eye_list,
                    optimizer=self.optimizer,
                )
            )
            cur_gci_weights = [mw * self.filt_weights[ii] for mw in model_weights]
            filt.cov = model_cov.copy()
            self.mean_list[ii] = model_est.copy()
            self.cov_list[ii] = filt.cov.copy()
            all_weights.extend(cur_gci_weights)
            new_weight_list.append(np.sum(cur_weight_list))
            new_gci_weights.append(cur_gci_weights)

        # Work on an array: multiplying a Python list by 0 would produce an
        # empty list and break the indexing further down.
        new_weight_list = np.array(new_weight_list, dtype=float)
        weight_total = np.sum(new_weight_list)
        if weight_total == 0:
            new_weight_list = np.zeros_like(new_weight_list)
        else:
            new_weight_list = new_weight_list / weight_total

        # Sum the model-scaled fusion weights per measurement (column-wise).
        tngw = np.array(new_gci_weights)
        out_gci_weights = [np.sum(tngw[:, col]) for col in range(tngw.shape[1])]

        self.filt_weights = new_weight_list
        self.weight_list = list(np.array(out_gci_weights) / np.sum(out_gci_weights))

        out_meas_fit_prob = 0
        for kk, prob in enumerate(meas_fit_prob_list):
            out_meas_fit_prob += all_weights[kk] * prob

        # compute output state from combined motion models
        out_state = np.zeros(np.shape(self.mean_list[0]))
        for ii in range(len(self.in_filt_list)):
            out_state = out_state + new_weight_list[ii] * self.mean_list[ii]
        out_state = out_state.reshape((np.shape(out_state)[0], 1))
        self.cur_out_state = out_state

        return (out_state, out_meas_fit_prob)