Coverage for src/gncpy/filters/interacting_multiple_model.py: 93%
130 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1import numpy as np
2from warnings import warn
3from copy import deepcopy
4from gncpy.filters.particle_filter import ParticleFilter
class InteractingMultipleModel:
    """Implementation of an InteractingMultipleModel (IMM) filter.

    This filter combines several inner filters with different dynamic models
    and selects the best estimate based on prior predictions and measurement
    updates.

    The inner filters are driven through the following interface (as used by
    this class): ``predict(timestep, state, *args, **kwargs)`` returning the
    predicted mean, ``correct(timestep, meas, state)`` returning
    ``(mean, measurement_fit_probability)``, writable ``mean`` and ``cov``
    attributes, ``set_measurement_model(**kwargs)``, and
    ``save_filter_state()`` / ``load_filter_state(state)``.

    NOTE(review): the mixing step in :meth:`predict` weights model ``jj``'s
    estimate for model ``ii`` by ``model_trans_mat[ii][jj] * filt_weights[jj]``,
    i.e. it treats ``model_trans_mat[ii][jj]`` as the probability of switching
    from model ``jj`` to model ``ii``. Confirm this matches the convention
    callers use for the transition matrix.
    """

    def __init__(self):
        # Inner filters, one per dynamic model.
        self.in_filt_list = []
        # Not used by any method of this class.
        self.cur_filt_ind = 0
        # Square model transition matrix (one row/column per inner filter).
        self.model_trans_mat = np.array([[]])
        # Current model probabilities, one per inner filter.
        self.filt_weights = np.array([])
        # Weighted combination of the inner filter means.
        self.cur_out_state = np.array([])
        # One row of model weights per initialization / predict call.
        self.filt_weight_history = np.array([[]])
        # Per-model means and covariances.
        self.mean_list = []
        self.cov_list = []

    @property
    def cov(self):
        """Covariance for the IMM Filter.

        Computed as the weighted sum over inner filters of
        ``P_i + (m_i - x)(m_i - x)^T``, where ``P_i``/``m_i`` are the entries of
        :attr:`cov_list`/:attr:`mean_list`, the weights come from
        :attr:`filt_weights`, and ``x`` is :attr:`cur_out_state`.
        """
        # Treat every mean as a column vector so the spread term is a true
        # outer product even if a mean was supplied as a 1-D array.
        out_state = np.asarray(self.cur_out_state).reshape((-1, 1))
        cur_cov = np.zeros(np.shape(self.cov_list[0]))
        for ii, model_cov in enumerate(self.cov_list):
            diff = np.asarray(self.mean_list[ii]).reshape((-1, 1)) - out_state
            cur_cov = cur_cov + self.filt_weights[ii] * (model_cov + diff @ diff.T)
        return cur_cov

    @cov.setter
    def cov(self, val):
        warn("Covariance is read only. SKIPPING")

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Note that to pickle the resulting dictionary the :code:`dill` package
        may need to be used due to potential pickling of functions.

        Returns
        -------
        filt_state : dict
            State dictionary accepted by :meth:`load_filter_state`. Inner
            filters are stored as ``(type, saved_state)`` tuples; a ``None``
            entry is stored as ``(None, None)``.
        """
        filt_tup_list = []
        for filt in self.in_filt_list:
            if filt is None:
                # load_filter_state restores a None class type as None.
                filt_tup_list.append((None, None))
            else:
                filt_tup_list.append((type(filt), filt.save_filter_state()))

        filt_state = {}
        filt_state["in_filt_list"] = filt_tup_list
        filt_state["model_trans_mat"] = self.model_trans_mat.copy()
        filt_state["filt_weights"] = self.filt_weights.copy()
        filt_state["cur_out_state"] = self.cur_out_state.copy()
        filt_state["filt_weight_history"] = self.filt_weight_history.copy()
        filt_state["mean_list"] = deepcopy(self.mean_list)
        filt_state["cov_list"] = deepcopy(self.cov_list)

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        self.in_filt_list = []
        for cls_type, saved in filt_state["in_filt_list"]:
            if cls_type is not None:
                # Each inner filter class must be constructible without args.
                filt = cls_type()
                filt.load_filter_state(saved)
            else:
                filt = None
            self.in_filt_list.append(filt)
        self.model_trans_mat = filt_state["model_trans_mat"]
        self.filt_weights = filt_state["filt_weights"]
        self.cur_out_state = filt_state["cur_out_state"]
        self.filt_weight_history = filt_state["filt_weight_history"]
        self.mean_list = filt_state["mean_list"]
        self.cov_list = filt_state["cov_list"]

    def initialize_states(self, init_means, init_covs, init_weights=None):
        """Set the initial per-model means, covariances and weights.

        Must be called after :meth:`initialize_filters`.

        Parameters
        ----------
        init_means : sequence of numpy arrays
            One mean per inner filter; each is stored as an ``(n, 1)`` column.
        init_covs : sequence of numpy arrays
            One covariance per inner filter.
        init_weights : array like, optional
            Initial model probabilities. Defaults to a uniform distribution.

        Raises
        ------
        ValueError
            If the number of means or covariances does not match the number
            of inner filters. No state is modified in that case.
        """
        num_filts = len(self.in_filt_list)
        if len(init_means) != num_filts or len(init_covs) != num_filts:
            raise ValueError(
                "Number of means or covariances does not match number of inner filters"
            )
        # Build fresh lists: correct() assigns into these lists in place, which
        # must not mutate the caller's containers.
        self.mean_list = [np.asarray(mean).reshape((-1, 1)) for mean in init_means]
        self.cov_list = list(init_covs)
        for filt, mean, model_cov in zip(
            self.in_filt_list, self.mean_list, self.cov_list
        ):
            filt.mean = mean.reshape((-1, 1))
            filt.cov = model_cov

        if init_weights is not None:
            self.filt_weights = np.array(init_weights, dtype=float)
        else:
            self.filt_weights = np.ones(num_filts) / num_filts
        self.filt_weight_history = np.array(self.filt_weights)
        self.cur_out_state = self._combine_means(self.filt_weights)

    def initialize_filters(self, filter_lst, model_trans):
        """Set the inner filters and the model transition matrix.

        Parameters
        ----------
        filter_lst : list
            Inner filters, one per dynamic model.
        model_trans : array like
            Square matrix whose size equals ``len(filter_lst)``.

        Raises
        ------
        ValueError
            If ``model_trans`` is not a square 2-D matrix matching the number
            of filters.
        """
        num_filts = len(filter_lst)
        # A single shape comparison also rejects 1-D / 3-D inputs with a
        # ValueError instead of an IndexError.
        if np.shape(model_trans) != (num_filts, num_filts):
            raise ValueError(
                "filter list must be same size as square matrix model_trans"
            )
        self.model_trans_mat = model_trans
        self.in_filt_list = filter_lst

    def set_models(
        self, filter_lst, model_trans, init_means, init_covs, init_weights=None
    ):
        """Set different filters and dynamics models for an IMM filter.

        Convenience wrapper calling :meth:`initialize_filters` followed by
        :meth:`initialize_states`.
        """
        self.initialize_filters(filter_lst, model_trans)
        self.initialize_states(init_means, init_covs, init_weights=init_weights)

    def set_measurement_model(self, **kwargs):
        """Forward the measurement model keyword arguments to every inner filter."""
        for filt in self.in_filt_list:
            filt.set_measurement_model(**kwargs)

    def _combine_means(self, weights):
        """Return the weighted sum of :attr:`mean_list` using ``weights``."""
        out_state = np.zeros(np.shape(self.mean_list[0]))
        for weight, mean in zip(weights, self.mean_list):
            out_state += weight * mean
        return out_state

    def _mix_for_model(self, ii):
        """Return ``(predicted_weight, mixed_state, mixed_cov)`` for model ``ii``."""
        mix_probs = [
            self.model_trans_mat[ii][jj] * self.filt_weights[jj]
            for jj in range(len(self.in_filt_list))
        ]
        new_weight = sum(mix_probs)
        if new_weight == 0:
            # No probability mass flows into this model, so the mixing ratios
            # are 0 / 0. Reuse the model's own estimate instead of feeding NaNs
            # into the inner filter.
            return (
                new_weight,
                np.array(self.mean_list[ii], dtype=float),
                np.array(self.cov_list[ii], dtype=float),
            )

        mixed_state = (
            sum(prob * mean for prob, mean in zip(mix_probs, self.mean_list))
            / new_weight
        )
        mixed_cov = (
            sum(
                prob
                * (model_cov + (mean - mixed_state) @ (mean - mixed_state).T)
                for prob, mean, model_cov in zip(
                    mix_probs, self.mean_list, self.cov_list
                )
            )
            / new_weight
        )
        return new_weight, mixed_state, mixed_cov

    def predict(self, timestep, *args, **kwargs):
        """Prediction step for the IMM filter.

        For every inner filter the per-model estimates are first mixed
        according to the transition matrix and current weights. The filter's
        covariance is then set to the mixed covariance, and its ``predict`` is
        called with the mixed state.

        Parameters
        ----------
        timestep
            Current time, forwarded to every inner filter's ``predict``.
        *args, **kwargs
            Forwarded to every inner filter's ``predict``.

        Returns
        -------
        numpy array
            Weighted combination of the predicted means (also stored in
            :attr:`cur_out_state`).

        Raises
        ------
        ValueError
            If an inner filter is a particle filter.
        """
        new_weight_list = []
        new_mean_list = []
        new_cov_list = []

        for ii, filt in enumerate(self.in_filt_list):
            if isinstance(filt, ParticleFilter):
                raise ValueError("Particle Filters not enabled with IMM")
            new_weight, mixed_state, mixed_cov = self._mix_for_model(ii)

            filt.cov = mixed_cov.copy()
            new_mean_list.append(filt.predict(timestep, mixed_state, *args, **kwargs))
            new_cov_list.append(filt.cov.copy())
            new_weight_list.append(new_weight)

        # Normalize once and use the same weights everywhere, so that `cov`
        # stays consistent with `cur_out_state`. All-zero weights are kept as
        # zeros, matching correct(), instead of dividing by zero.
        new_weights = np.array(new_weight_list, dtype=float)
        total = np.sum(new_weights)
        if total != 0:
            new_weights = new_weights / total

        self.mean_list = new_mean_list
        self.cov_list = new_cov_list
        self.filt_weights = new_weights
        self.filt_weight_history = np.vstack((self.filt_weight_history, new_weights))

        out_state = self._combine_means(new_weights)
        self.cur_out_state = out_state
        return out_state

    def correct(self, timestep, meas, *args, **kwargs):
        """Measurement correction step for the IMM filter.

        Each inner filter is corrected with ``meas``. The model weights are
        then updated in proportion to ``measurement_fit_probability *
        prior_weight`` and normalized. Extra positional/keyword arguments are
        accepted but not forwarded to the inner filters.

        Returns
        -------
        tuple
            ``(out_state, out_meas_fit_prob)`` where ``out_state`` is the
            weighted combination of the corrected means and
            ``out_meas_fit_prob`` is the un-normalized sum of
            ``measurement_fit_probability * prior_weight`` over all models.

        Raises
        ------
        ValueError
            If an inner filter is a particle filter (consistent with
            :meth:`predict`).
        """
        new_weight_list = np.zeros(np.shape(self.filt_weights))
        meas_fit_prob_list = np.zeros(len(self.in_filt_list))
        for ii, filt in enumerate(self.in_filt_list):
            if isinstance(filt, ParticleFilter):
                raise ValueError("Particle Filters not enabled with IMM")
            self.mean_list[ii], meas_fit_prob_list[ii] = filt.correct(
                timestep, meas, self.mean_list[ii].reshape((-1, 1))
            )
            self.cov_list[ii] = filt.cov.copy()
            new_weight_list[ii] = meas_fit_prob_list[ii] * self.filt_weights[ii]

        out_meas_fit_prob = np.sum(new_weight_list)
        # If no model explains the measurement, the weights are left at zero
        # rather than dividing by zero.
        if out_meas_fit_prob != 0:
            new_weight_list = new_weight_list / out_meas_fit_prob
        self.filt_weights = new_weight_list

        out_state = self._combine_means(new_weight_list)
        self.cur_out_state = out_state
        return (out_state, out_meas_fit_prob)