1"""Implements RFS tracking algorithms.
3This module contains the classes and data structures
4for RFS tracking related algorithms.
5"""
import gncpy.filters
import numpy as np
import numpy.linalg as la
import numpy.random as rnd
import matplotlib.pyplot as plt
from typing import Iterable
from matplotlib.patches import Ellipse
import matplotlib.animation as animation
import abc
from copy import deepcopy
import warnings
import itertools

from carbs.utilities.graphs import (
    k_shortest,
    murty_m_best,
    murty_m_best_all_meas_assigned,
)
from carbs.utilities.sampling import gibbs, mm_gibbs

from gncpy.math import log_sum_exp, get_elem_sym_fnc
import gncpy.plotting as pltUtil
import gncpy.filters as gfilts
import gncpy.errors as gerr

import serums.models as smodels
from serums.enums import SingleObjectDistance
from serums.distances import calculate_ospa, calculate_ospa2, calculate_gospa


class RandomFiniteSetBase(metaclass=abc.ABCMeta):
    """Generic base class for RFS based filters.

    Attributes
    ----------
    filter : gncpy.filters.BayesFilter
        Filter handling dynamics.
    prob_detection : float
        Modeled probability an object is detected.
    prob_survive : float
        Modeled probability of object survival.
    birth_terms : list
        List of terms in the birth model.
    clutter_rate : float
        Rate of clutter.
    clutter_den : float
        Density of the clutter distribution.
    inv_chi2_gate : float
        Chi squared threshold for gating the measurements.
    save_covs : bool
        Save covariance matrix for each state during state extraction.
    debug_plots : bool
        Saves data needed for extra debugging plots.
    ospa : numpy array
        Calculated OSPA value for the given truth data. Must be manually
        updated by a function call.
    ospa_localization : numpy array
        Calculated OSPA localization component for the given truth data. Must
        be manually updated by a function call.
    ospa_cardinality : numpy array
        Calculated OSPA cardinality component for the given truth data. Must
        be manually updated by a function call.
    enable_spawning : bool
        Flag for enabling spawning.
    spawn_cov : N x N numpy array
        Covariance for spawned targets.
    spawn_weight : float
        Weight for spawned targets.
    """

    def __init__(
        self,
        in_filter: gncpy.filters.BayesFilter = None,
        prob_detection: float = 1,
        prob_survive: float = 1,
        birth_terms: list = None,
        clutter_rate: float = 0,
        clutter_den: float = 0,
        inv_chi2_gate: float = 0,
        save_covs: bool = False,
        debug_plots: bool = False,
        enable_spawning: bool = False,
        spawn_cov: np.ndarray = None,
        spawn_weight: float = None,
    ):
        """Initialize an object.

        Parameters
        ----------
        in_filter
            Inner filter object.
        prob_detection
            Probability of detection.
        prob_survive
            Probability of survival.
        birth_terms
            Birth model.
        clutter_rate
            Clutter rate per scan.
        clutter_den
            Clutter density.
        inv_chi2_gate
            Inverse Chi^2 gating threshold.
        save_covs
            Flag for saving covariances.
        debug_plots
            Flag for enabling debug plots.
        enable_spawning
            Flag for enabling spawning.
        spawn_cov
            Covariance for spawned targets.
        spawn_weight
            Weight for spawned targets.
        """
        if birth_terms is None:
            birth_terms = []
        self.filter = deepcopy(in_filter)
        self.prob_detection = prob_detection
        self.prob_survive = prob_survive
        self.birth_terms = deepcopy(birth_terms)
        self.clutter_rate = clutter_rate
        if isinstance(clutter_den, np.ndarray):
            clutter_den = clutter_den.item()
        self.clutter_den = clutter_den

        self.inv_chi2_gate = inv_chi2_gate

        self.save_covs = save_covs
        self.debug_plots = debug_plots

        self.ospa = None
        self.ospa_localization = None
        self.ospa_cardinality = None
        # initialized here so plot_gospa_history can warn instead of raising
        # an AttributeError when GOSPA was never calculated
        self.gospa = None
        self.gospa_localization = None
        self.gospa_cardinality = None
        self._ospa_params = {}
        self._gospa_params = {}

        self._states = []  # local copy for internal modification
        self._meas_tab = (
            []
        )  # list of lists, one per timestep, inner is all meas at time
        self._covs = []  # local copy for internal modification
        self.enable_spawning = enable_spawning
        self.spawn_cov = spawn_cov
        self.spawn_weight = spawn_weight

        super().__init__()

    @property
    def ospa_method(self):
        """The distance metric used in the OSPA calculation (read only)."""
        if "core" in self._ospa_params:
            return self._ospa_params["core"]
        else:
            return None

    @ospa_method.setter
    def ospa_method(self, val):
        warnings.warn("OSPA method is read only. SKIPPING")

    @abc.abstractmethod
    def save_filter_state(self):
        """Generic method for saving key filter variables.

        This must be overridden in the inherited class. It is recommended to
        keep the signature the same to allow for standardized implementation
        of wrapper classes. This should return a single variable that can be
        passed to the loading function to set up a filter with the same
        internal state as the current instance when this function was called.
        """
        filt_state = {}
        if self.filter is not None:
            filt_state["filter"] = (type(self.filter), self.filter.save_filter_state())
        else:
            filt_state["filter"] = (None, self.filter)
        filt_state["prob_detection"] = self.prob_detection
        filt_state["prob_survive"] = self.prob_survive
        filt_state["birth_terms"] = self.birth_terms
        filt_state["clutter_rate"] = self.clutter_rate
        filt_state["clutter_den"] = self.clutter_den
        filt_state["inv_chi2_gate"] = self.inv_chi2_gate
        filt_state["save_covs"] = self.save_covs
        filt_state["debug_plots"] = self.debug_plots
        filt_state["ospa"] = self.ospa
        filt_state["ospa_localization"] = self.ospa_localization
        filt_state["ospa_cardinality"] = self.ospa_cardinality

        filt_state["_states"] = self._states
        filt_state["_meas_tab"] = self._meas_tab
        filt_state["_covs"] = self._covs
        filt_state["_ospa_params"] = self._ospa_params

        return filt_state

    @abc.abstractmethod
    def load_filter_state(self, filt_state):
        """Generic method for loading key filter variables.

        This must be overridden in the inherited class. It is recommended to
        keep the signature the same to allow for standardized implementation
        of wrapper classes. This initializes all internal variables saved by
        the filter save function such that a new instance would generate the
        same output as the original instance that called the save function.
        """
        cls_type = filt_state["filter"][0]
        if cls_type is not None:
            self.filter = cls_type()
            self.filter.load_filter_state(filt_state["filter"][1])
        else:
            # the save function stores a (None, filter) tuple in this case
            self.filter = filt_state["filter"][1]
        self.prob_detection = filt_state["prob_detection"]
        self.prob_survive = filt_state["prob_survive"]
        self.birth_terms = filt_state["birth_terms"]
        self.clutter_rate = filt_state["clutter_rate"]
        self.clutter_den = filt_state["clutter_den"]
        self.inv_chi2_gate = filt_state["inv_chi2_gate"]
        self.save_covs = filt_state["save_covs"]
        self.debug_plots = filt_state["debug_plots"]
        self.ospa = filt_state["ospa"]
        self.ospa_localization = filt_state["ospa_localization"]
        self.ospa_cardinality = filt_state["ospa_cardinality"]

        self._states = filt_state["_states"]
        self._meas_tab = filt_state["_meas_tab"]
        self._covs = filt_state["_covs"]
        self._ospa_params = filt_state["_ospa_params"]
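
    # A minimal sketch of the intended save/load round trip, assuming a
    # concrete subclass instance ``tracker`` (an illustrative name) that has
    # already been run:
    #
    #   snapshot = tracker.save_filter_state()
    #   restored = type(tracker)()
    #   restored.load_filter_state(snapshot)
    #   # restored now reproduces the output of the original instance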

    @property
    def prob_miss_detection(self):
        """Complement of :py:attr:`.swarm_estimator.RandomFiniteSetBase.prob_detection`."""
        return 1 - self.prob_detection

    @property
    def prob_death(self):
        """Complement of :attr:`carbs.swarm_estimator.RandomFiniteSetBase.prob_survive`."""
        return 1 - self.prob_survive

    @property
    def num_birth_terms(self):
        """Number of terms in the birth model."""
        return len(self.birth_terms)

    @abc.abstractmethod
    def predict(self, t, **kwargs):
        """Abstract method for the prediction step.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    @abc.abstractmethod
    def correct(self, t, m, **kwargs):
        """Abstract method for the correction step.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    @abc.abstractmethod
    def extract_states(self, **kwargs):
        """Abstract method for extracting states."""
        pass

    @abc.abstractmethod
    def cleanup(self, **kwargs):
        """Abstract method that performs the cleanup step of the filter.

        This must be overridden in the inherited class. It is recommended to
        keep the same structure/order for the arguments for consistency
        between the inherited classes.
        """
        pass

    def _gate_meas(self, meas, means, covs, meas_mat_args={}, est_meas_args={}):
        """Gates measurements based on current estimates.

        Notes
        -----
        Gating is performed based on a Gaussian noise model.
        See :cite:`Cox1993_AReviewofStatisticalDataAssociationTechniquesforMotionCorrespondence`
        for details on the chi squared test used.

        Parameters
        ----------
        meas : list
            2d numpy arrays of each measurement.
        means : list
            2d numpy arrays of each mean.
        covs : list
            2d numpy array of each covariance.
        meas_mat_args : dict, optional
            keyword arguments to pass to the inner filter's get measurement
            matrix function. The default is {}.
        est_meas_args : dict, optional
            keyword arguments to pass to the inner filter's get estimated
            measurement function. The default is {}.

        Returns
        -------
        list
            2d numpy arrays of valid measurements.
        """
        if len(meas) == 0:
            return []
        valid = []
        for m, p in zip(means, covs):
            meas_mat = self.filter.get_meas_mat(m, **meas_mat_args)
            est = self.filter.get_est_meas(m, **est_meas_args)
            meas_pred_cov = meas_mat @ p @ meas_mat.T + self.filter.meas_noise
            meas_pred_cov = (meas_pred_cov + meas_pred_cov.T) / 2
            v_s = la.cholesky(meas_pred_cov.T)
            inv_sqrt_m_cov = la.inv(v_s)

            for ii, z in enumerate(meas):
                if ii in valid:
                    continue
                inov = z - est
                dist = np.sum((inv_sqrt_m_cov.T @ inov) ** 2)
                if dist < self.inv_chi2_gate:
                    valid.append(ii)
        valid.sort()
        return [meas[ii] for ii in valid]
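
    # The gate above is the standard ellipsoidal chi squared gate: a
    # measurement z is kept if (z - z_hat)^T S^{-1} (z - z_hat) < inv_chi2_gate
    # with S = H P H^T + R. A hedged sketch of picking the threshold from a
    # desired gate probability (assumes scipy is available; it is not a
    # dependency of this module):
    #
    #   from scipy.stats import chi2
    #   meas_dim = 2  # illustrative measurement dimension
    #   tracker.inv_chi2_gate = chi2.ppf(0.99, meas_dim)  # 99% gate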

    def _ospa_setup_tmat(self, truth, state_dim, true_covs, state_inds):
        # get sizes
        num_timesteps = len(truth)
        num_objs = 0

        for lst in truth:
            num_objs = np.max(
                [
                    num_objs,
                    np.sum([_x is not None and _x.size > 0 for _x in lst]).astype(int),
                ]
            )
        # create matrices
        true_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
        true_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))

        for tt, lst in enumerate(truth):
            obj_num = 0
            for s in lst:
                if s is not None and s.size > 0:
                    true_mat[:, tt, obj_num] = s.ravel()[state_inds]
                    obj_num += 1
        if true_covs is not None:
            for tt, lst in enumerate(true_covs):
                obj_num = 0
                for c in lst:
                    if c is not None and truth[tt][obj_num].size > 0:
                        true_cov_mat[:, :, tt, obj_num] = c[state_inds][:, state_inds]
                        obj_num += 1
        return true_mat, true_cov_mat

    def _ospa_setup_emat(self, state_dim, state_inds):
        # get sizes
        num_timesteps = len(self._states)
        num_objs = 0

        for lst in self._states:
            num_objs = np.max(
                [num_objs, np.sum([_x is not None for _x in lst]).astype(int)]
            )
        # create matrices
        est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
        est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))

        for tt, lst in enumerate(self._states):
            for obj_num, s in enumerate(lst):
                if s is not None and s.size > 0:
                    est_mat[:, tt, obj_num] = s.ravel()[state_inds]
        if self.save_covs:
            for tt, lst in enumerate(self._covs):
                for obj_num, c in enumerate(lst):
                    if c is not None and self._states[tt][obj_num].size > 0:
                        est_cov_mat[:, :, tt, obj_num] = c[state_inds][:, state_inds]
        return est_mat, est_cov_mat

    def _ospa_input_check(self, core_method, truth, true_covs):
        if core_method is None:
            core_method = SingleObjectDistance.EUCLIDEAN
        elif core_method is SingleObjectDistance.MAHALANOBIS and not self.save_covs:
            msg = "Must save covariances to calculate {:s} OSPA. Using {:s} instead"
            warnings.warn(msg.format(core_method, SingleObjectDistance.EUCLIDEAN))
            core_method = SingleObjectDistance.EUCLIDEAN
        elif core_method is SingleObjectDistance.HELLINGER and true_covs is None:
            msg = "Must provide true covariances to calculate {:s} OSPA. Using {:s} instead"
            warnings.warn(msg.format(core_method, SingleObjectDistance.EUCLIDEAN))
            core_method = SingleObjectDistance.EUCLIDEAN
        return core_method

    def _ospa_find_s_dim(self, truth):
        state_dim = None
        for lst in truth:
            for _x in lst:
                if _x is not None:
                    state_dim = _x.size
                    break
            if state_dim is not None:
                break
        if state_dim is None:
            for lst in self._states:
                for _x in lst:
                    if _x is not None:
                        state_dim = _x.size
                        break
                if state_dim is not None:
                    break
        return state_dim

    def calculate_ospa(
        self,
        truth: Iterable[Iterable[np.ndarray]],
        c: float,
        p: float,
        core_method: SingleObjectDistance = None,
        true_covs: Iterable[Iterable[np.ndarray]] = None,
        state_inds: Iterable[int] = None,
    ):
        """Calculates the OSPA distance between the truth at all timesteps.

        Wrapper for :func:`serums.distances.calculate_ospa`.

        Parameters
        ----------
        truth : list
            Each element represents a timestep and is a list of N x 1 numpy
            arrays, one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty given to a false point estimate.
        p : int
            The power of the distance term. Higher values penalize outliers
            more.
        core_method : :class:`serums.enums.SingleObjectDistance`, optional
            The main distance measure to use for the localization component.
            The default value of None implies :attr:`.SingleObjectDistance.EUCLIDEAN`.
        true_covs : list, optional
            Each element represents a timestep and is a list of N x N numpy
            arrays corresponding to the uncertainty about the true states.
            Note the order must be consistent with the truth data given. This
            is only needed for the core method
            :attr:`SingleObjectDistance.HELLINGER`. The default value is None.
        state_inds : list, optional
            Indices in the state vector to use, will be applied to the truth
            data as well. The default is None which means the full state is
            used.
        """
        # error checking on optional input arguments
        core_method = self._ospa_input_check(core_method, truth, true_covs)

        # setup data structures
        if state_inds is None:
            state_dim = self._ospa_find_s_dim(truth)
            if state_dim is not None:
                state_inds = range(state_dim)
        else:
            state_dim = len(state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING OSPA calculation")

            nt = len(self._states)
            self.ospa = np.zeros(nt)
            self.ospa_localization = np.zeros(nt)
            self.ospa_cardinality = np.zeros(nt)
            self._ospa_params["core"] = core_method
            self._ospa_params["cutoff"] = c
            self._ospa_params["power"] = p
            return
        true_mat, true_cov_mat = self._ospa_setup_tmat(
            truth, state_dim, true_covs, state_inds
        )
        est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)

        # find OSPA
        (
            self.ospa,
            self.ospa_localization,
            self.ospa_cardinality,
            self._ospa_params["core"],
            self._ospa_params["cutoff"],
            self._ospa_params["power"],
        ) = calculate_ospa(
            est_mat,
            true_mat,
            c,
            p,
            use_empty=True,
            core_method=core_method,
            true_cov_mat=true_cov_mat,
            est_cov_mat=est_cov_mat,
        )[0:6]
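
    # A minimal usage sketch, assuming ``truth`` is a per-timestep list of
    # lists of N x 1 numpy arrays consistent with the extracted states:
    #
    #   tracker.calculate_ospa(truth, c=10.0, p=1)
    #   print(tracker.ospa, tracker.ospa_localization, tracker.ospa_cardinality)
    #
    # Restricting the metric to, e.g., position states only (indices are
    # illustrative):
    #
    #   tracker.calculate_ospa(truth, c=10.0, p=1, state_inds=[0, 1])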

    def calculate_gospa(
        self,
        truth: Iterable[Iterable[np.ndarray]],
        c: float,
        p: float,
        a: int,
        core_method: SingleObjectDistance = None,
        true_covs: Iterable[Iterable[np.ndarray]] = None,
        state_inds: Iterable[int] = None,
    ):
        """Calculates the GOSPA distance between the truth at all timesteps.

        Wrapper for :func:`serums.distances.calculate_gospa`.

        Parameters
        ----------
        truth : list
            Each element represents a timestep and is a list of N x 1 numpy
            arrays, one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty given to a false point estimate.
        p : int
            The power of the distance term. Higher values penalize outliers
            more.
        a : int
            The normalization factor of the distance term. Appropriately penalizes missed
            or false detection of tracks rather than normalizing by the total maximum
            cardinality.
        core_method : :class:`serums.enums.SingleObjectDistance`, optional
            The main distance measure to use for the localization component.
            The default value of None implies :attr:`.SingleObjectDistance.EUCLIDEAN`.
        true_covs : list, optional
            Each element represents a timestep and is a list of N x N numpy
            arrays corresponding to the uncertainty about the true states.
            Note the order must be consistent with the truth data given. This
            is only needed for the core method
            :attr:`SingleObjectDistance.HELLINGER`. The default value is None.
        state_inds : list, optional
            Indices in the state vector to use, will be applied to the truth
            data as well. The default is None which means the full state is
            used.
        """
        # error checking on optional input arguments
        core_method = self._ospa_input_check(core_method, truth, true_covs)

        # setup data structures
        if state_inds is None:
            state_dim = self._ospa_find_s_dim(truth)
            if state_dim is not None:
                state_inds = range(state_dim)
        else:
            state_dim = len(state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING OSPA calculation")

            nt = len(self._states)
            self.gospa = np.zeros(nt)
            self.gospa_localization = np.zeros(nt)
            self.gospa_cardinality = np.zeros(nt)
            self._gospa_params["core"] = core_method
            self._gospa_params["cutoff"] = c
            self._gospa_params["power"] = p
            self._gospa_params["normalization"] = a
            return
        true_mat, true_cov_mat = self._ospa_setup_tmat(
            truth, state_dim, true_covs, state_inds
        )
        est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)

        # find GOSPA
        (
            self.gospa,
            self.gospa_localization,
            self.gospa_cardinality,
            self._gospa_params["core"],
            self._gospa_params["cutoff"],
            self._gospa_params["power"],
            self._gospa_params["normalization"],
        ) = calculate_gospa(
            est_mat,
            true_mat,
            c,
            p,
            a,
            use_empty=True,
            core_method=core_method,
            true_cov_mat=true_cov_mat,
            est_cov_mat=est_cov_mat,
        )[0:7]

    def _plt_ospa_hist(self, y_val, time_units, time, ttl, y_lbl, opts):
        fig = opts["f_hndl"]

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)
        if time is None:
            time = np.arange(y_val.size, dtype=int)
        fig.axes[0].grid(True)
        fig.axes[0].ticklabel_format(useOffset=False)
        fig.axes[0].plot(time, y_val)

        pltUtil.set_title_label(
            fig, 0, opts, ttl=ttl, x_lbl="Time ({})".format(time_units), y_lbl=y_lbl
        )
        fig.tight_layout()

        return fig

    def _plt_ospa_hist_subs(self, y_vals, time_units, time, ttl, y_lbls, opts):
        fig = opts["f_hndl"]
        new_plot = fig is None
        num_subs = len(y_vals)

        if new_plot:
            fig = plt.figure()
            pltUtil.set_title_label(fig, 0, opts, ttl=ttl)
        for ax, (y_val, y_lbl) in enumerate(zip(y_vals, y_lbls)):
            if new_plot:
                if ax > 0:
                    kwargs = {"sharex": fig.axes[0]}
                else:
                    kwargs = {}
                fig.add_subplot(num_subs, 1, ax + 1, **kwargs)
            fig.axes[ax].grid(True)
            fig.axes[ax].ticklabel_format(useOffset=False)
            kwargs = {"y_lbl": y_lbl}
            if ax == len(y_vals) - 1:
                kwargs["x_lbl"] = "Time ({})".format(time_units)
            pltUtil.set_title_label(fig, ax, opts, **kwargs)
            if time is None:
                time = np.arange(y_val.size, dtype=int)
            fig.axes[ax].plot(time, y_val)
        if new_plot:
            fig.tight_layout()
        return fig

    def plot_ospa_history(
        self,
        time_units="index",
        time=None,
        main_opts=None,
        sub_opts=None,
        plot_subs=True,
    ):
        """Plots the OSPA history.

        This requires that the OSPA has been calculated by the appropriate
        function first.

        Parameters
        ----------
        time_units : string, optional
            Text representing the units of time in the plot. The default is
            'index'.
        time : numpy array, optional
            Vector to use for the x-axis of the plot. If none is given then
            vector indices are used. The default is None.
        main_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the main plot.
        sub_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the sub plot.
        plot_subs : bool, optional
            Flag indicating if the component statistics (cardinality and
            localization) should also be plotted.

        Returns
        -------
        figs : dict
            Dictionary of matplotlib figure objects the data was plotted on.
        """
        if self.ospa is None:
            warnings.warn("OSPA must be calculated before plotting")
            return
        if main_opts is None:
            main_opts = pltUtil.init_plotting_opts()
        if sub_opts is None and plot_subs:
            sub_opts = pltUtil.init_plotting_opts()
        fmt = "{:s} OSPA (c = {:.1f}, p = {:d})"
        ttl = fmt.format(
            self._ospa_params["core"],
            self._ospa_params["cutoff"],
            self._ospa_params["power"],
        )
        y_lbl = "OSPA"

        figs = {}
        figs["OSPA"] = self._plt_ospa_hist(
            self.ospa, time_units, time, ttl, y_lbl, main_opts
        )

        if plot_subs:
            fmt = "{:s} OSPA Components (c = {:.1f}, p = {:d})"
            ttl = fmt.format(
                self._ospa_params["core"],
                self._ospa_params["cutoff"],
                self._ospa_params["power"],
            )
            y_lbls = ["Localization", "Cardinality"]
            figs["OSPA_subs"] = self._plt_ospa_hist_subs(
                [self.ospa_localization, self.ospa_cardinality],
                time_units,
                time,
                ttl,
                y_lbls,
                sub_opts,
            )
        return figs

    def plot_gospa_history(
        self,
        time_units="index",
        time=None,
        main_opts=None,
        sub_opts=None,
        plot_subs=True,
    ):
        """Plots the GOSPA history.

        This requires that the GOSPA has been calculated by the appropriate
        function first.

        Parameters
        ----------
        time_units : string, optional
            Text representing the units of time in the plot. The default is
            'index'.
        time : numpy array, optional
            Vector to use for the x-axis of the plot. If none is given then
            vector indices are used. The default is None.
        main_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the main plot.
        sub_opts : dict, optional
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, and any values
            relating to title/axis text formatting. The default of None implies
            the default options are used for the sub plot.
        plot_subs : bool, optional
            Flag indicating if the component statistics (cardinality and
            localization) should also be plotted.

        Returns
        -------
        figs : dict
            Dictionary of matplotlib figure objects the data was plotted on.
        """
        if self.gospa is None:
            warnings.warn("GOSPA must be calculated before plotting")
            return
        if main_opts is None:
            main_opts = pltUtil.init_plotting_opts()
        if sub_opts is None and plot_subs:
            sub_opts = pltUtil.init_plotting_opts()
        fmt = "{:s} GOSPA (c = {:.1f}, p = {:d}, a = {:d})"
        ttl = fmt.format(
            self._gospa_params["core"],
            self._gospa_params["cutoff"],
            self._gospa_params["power"],
            self._gospa_params["normalization"],
        )
        y_lbl = "GOSPA"

        figs = {}
        figs["GOSPA"] = self._plt_ospa_hist(
            self.gospa, time_units, time, ttl, y_lbl, main_opts
        )

        if plot_subs:
            fmt = "{:s} GOSPA Components (c = {:.1f}, p = {:d}, a = {:d})"
            ttl = fmt.format(
                self._gospa_params["core"],
                self._gospa_params["cutoff"],
                self._gospa_params["power"],
                self._gospa_params["normalization"],
            )
            y_lbls = ["Localization", "Cardinality"]
            figs["GOSPA_subs"] = self._plt_ospa_hist_subs(
                [self.gospa_localization, self.gospa_cardinality],
                time_units,
                time,
                ttl,
                y_lbls,
                sub_opts,
            )
        return figs


class ProbabilityHypothesisDensity(RandomFiniteSetBase):
    """Implements the Probability Hypothesis Density filter.

    The kwargs in the constructor are passed through to the parent constructor.

    Notes
    -----
    The filter implementation is based on
    :cite:`Vo2006_TheGaussianMixtureProbabilityHypothesisDensityFilter`.

    Attributes
    ----------
    gating_on : bool
        flag indicating if measurement gating should be performed. The
        default is False.
    inv_chi2_gate : float
        threshold for the chi squared test in the measurement gating. The
        default is 0.
    extract_threshold : float
        threshold for extracting the state. The default is 0.5.
    prune_threshold : float
        threshold for removing hypotheses. The default is 10**-5.
    merge_threshold : float
        threshold for merging hypotheses. The default is 4.
    max_gauss : int
        max number of Gaussians to use. The default is 100.
    """

    def __init__(
        self,
        gating_on=False,
        inv_chi2_gate=0,
        extract_threshold=0.5,
        prune_threshold=1e-5,
        merge_threshold=4,
        max_gauss=100,
        **kwargs,
    ):
        self.gating_on = gating_on
        self.inv_chi2_gate = inv_chi2_gate
        self.extract_threshold = extract_threshold
        self.prune_threshold = prune_threshold
        self.merge_threshold = merge_threshold
        self.max_gauss = max_gauss

        self._gaussMix = smodels.GaussianMixture()

        super().__init__(**kwargs)

    def save_filter_state(self):
        """Saves filter variables so they can be restored later."""
        filt_state = super().save_filter_state()

        raise RuntimeError("Not implemented yet")
        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`.
        """
        super().load_filter_state(filt_state)

        raise RuntimeError("Not implemented yet")

    @property
    def states(self):
        """Read only list of extracted states.

        This is a list with 1 element per timestep, and each element is a list
        of the best states extracted at that timestep. The order of each
        element corresponds to the label order.
        """
        if len(self._states) > 0:
            return self._states[-1]
        else:
            return []

    @property
    def covariances(self):
        """Read only list of extracted covariances.

        This is a list with 1 element per timestep, and each element is a list
        of the best covariances extracted at that timestep. The order of each
        element corresponds to the state order.

        Warns
        -----
        RuntimeWarning
            If the class is not saving the covariances; an empty list is
            returned in that case.
        """
        if not self.save_covs:
            warnings.warn("Not saving covariances")
            return []
        if len(self._covs) > 0:
            return self._covs[-1]
        else:
            return []

    @property
    def cardinality(self):
        """Read only cardinality of the RFS."""
        if len(self._states) == 0:
            return 0
        else:
            return len(self._states[-1])

    def _gen_spawned_targets(self, gaussMix):
        if self.spawn_cov is not None and self.spawn_weight is not None:
            gauss_list = [
                smodels.Gaussian(mean=m.copy(), covariance=self.spawn_cov.copy())
                for m in gaussMix.means
            ]
            return smodels.GaussianMixture(
                distributions=gauss_list,
                weights=[self.spawn_weight for ii in range(len(gauss_list))],
            )
        else:
            raise RuntimeError(
                "self.spawn_cov and self.spawn_weight must be specified."
            )

    def predict(self, timestep, filt_args={}):
        """Prediction step of the PHD filter.

        This predicts new hypotheses and propagates them to the next time
        step. Because this calls the inner filter's predict function, the
        keyword arguments must contain any information needed by that
        function.

        Parameters
        ----------
        timestep: float
            current timestep
        filt_args : dict, optional
            Passed to the inner filter. The default is {}.

        Returns
        -------
        None.
        """
        if self.enable_spawning:
            spawn_mix = self._gen_spawned_targets(self._gaussMix)

        self._gaussMix = self._predict_prob_density(timestep, self._gaussMix, filt_args)

        if self.enable_spawning:
            self._gaussMix.add_components(
                spawn_mix.means, spawn_mix.covariances, spawn_mix.weights
            )

        for gm in self.birth_terms:
            self._gaussMix.add_components(gm.means, gm.covariances, gm.weights)

    def _predict_prob_density(self, timestep, probDensity, filt_args):
        """Predicts the probability density.

        Loops over all elements in a probability distribution and performs
        the filter prediction.

        Parameters
        ----------
        timestep: float
            current timestep
        probDensity : :class:`serums.models.GaussianMixture`
            Probability density to perform prediction on.
        filt_args : dict
            Passed directly to the inner filter.

        Returns
        -------
        gm : :class:`serums.models.GaussianMixture`
            predicted Gaussian mixture.
        """
        weights = [self.prob_survive * x for x in probDensity.weights.copy()]
        n_terms = len(probDensity.means)
        covariances = [None] * n_terms
        means = [None] * n_terms
        for ii, (m, P) in enumerate(zip(probDensity.means, probDensity.covariances)):
            self.filter.cov = P
            n_mean = self.filter.predict(timestep, m, **filt_args)
            covariances[ii] = self.filter.cov.copy()
            means[ii] = n_mean
        return smodels.GaussianMixture(
            means=means, covariances=covariances, weights=weights
        )
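
    # For a linear-Gaussian inner filter, the prediction above reduces to the
    # standard GM-PHD recursion applied to each component (a sketch, with F
    # and Q denoting the inner filter's state transition and process noise):
    #
    #   w' = prob_survive * w
    #   m' = F @ m
    #   P' = F @ P @ F.T + Q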

    def correct(
        self, timestep, meas_in, meas_mat_args={}, est_meas_args={}, filt_args={}
    ):
        """Correction step of the PHD filter.

        This corrects the hypotheses based on the measurements and gates the
        measurements according to the class settings.

        Parameters
        ----------
        timestep: float
            current timestep
        meas_in : list
            2d numpy arrays representing a measurement.
        meas_mat_args : dict, optional
            keyword arguments to pass to the inner filter's get measurement
            matrix function. Only used if gating is on. The default is {}.
        est_meas_args : dict, optional
            keyword arguments to pass to the inner filter's estimate
            measurements function. Only used if gating is on. The default is {}.
        filt_args : dict, optional
            keyword arguments to pass to the inner filter's correct function.
            The default is {}.

        Todo
        ----
        Fix the measurement gating

        Returns
        -------
        None.
        """
        meas = deepcopy(meas_in)

        if self.gating_on:
            meas = self._gate_meas(
                meas,
                self._gaussMix.means,
                self._gaussMix.covariances,
                meas_mat_args,
                est_meas_args,
            )
        self._meas_tab.append(meas)

        gmix = deepcopy(self._gaussMix)
        gmix.weights = [self.prob_miss_detection * x for x in gmix.weights]
        gm = self._correct_prob_density(timestep, meas, self._gaussMix, filt_args)
        gm.add_components(gmix.means, gmix.covariances, gmix.weights)

        self._gaussMix = gm

    def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
        """Corrects the probability densities.

        Loops over all elements in a probability distribution and performs
        the filter correction.

        Parameters
        ----------
        timestep : float
            current timestep
        meas : list
            2d numpy arrays of each measurement.
        probDensity : :py:class:`serums.models.GaussianMixture`
            probability density to run correction on.
        filt_args : dict
            arguments to pass to the inner filter correct function.

        Returns
        -------
        gm : :py:class:`serums.models.GaussianMixture`
            corrected probability density.
        """
        means = []
        covariances = []
        weights = []
        det_weights = [self.prob_detection * x for x in probDensity.weights]
        for z in meas:
            w_lst = []
            for jj in range(0, len(probDensity.means)):
                self.filter.cov = probDensity.covariances[jj]
                state = probDensity.means[jj]
                (mean, qz) = self.filter.correct(timestep, z, state, **filt_args)
                cov = self.filter.cov
                w = qz * det_weights[jj]
                means.append(mean)
                covariances.append(cov)
                w_lst.append(w)
            weights.extend(
                [x / (self.clutter_rate * self.clutter_den + sum(w_lst)) for x in w_lst]
            )
        return smodels.GaussianMixture(
            means=means, covariances=covariances, weights=weights
        )
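
    # The weight update above is the GM-PHD measurement update of
    # :cite:`Vo2006_TheGaussianMixtureProbabilityHypothesisDensityFilter`:
    # for measurement z and component j,
    #
    #   w_j(z) = p_D * w_j * q_j(z)
    #            / (clutter_rate * clutter_den + sum_i p_D * w_i * q_i(z))
    #
    # where q_j(z) is the measurement likelihood returned by the inner
    # filter's correct function.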

    def _prune(self):
        """Removes hypotheses below a threshold.

        This should be called once per time step after the correction and
        before the state extraction.
        """
        inds = np.where(np.asarray(self._gaussMix.weights) < self.prune_threshold)[0]
        self._gaussMix.remove_components(inds.flatten().tolist())
        return inds

    def _merge(self):
        """Merges nearby hypotheses."""
        loop_inds = set(range(0, len(self._gaussMix.means)))

        w_lst = []
        m_lst = []
        p_lst = []
        while len(loop_inds) > 0:
            jj = int(np.argmax(self._gaussMix.weights))
            comp_inds = []
            inv_cov = la.inv(self._gaussMix.covariances[jj])
            for ii in loop_inds:
                diff = self._gaussMix.means[ii] - self._gaussMix.means[jj]
                val = diff.T @ inv_cov @ diff
                if val <= self.merge_threshold:
                    comp_inds.append(ii)
            w_new = sum([self._gaussMix.weights[ii] for ii in comp_inds])
            m_new = (
                sum(
                    [
                        self._gaussMix.weights[ii] * self._gaussMix.means[ii]
                        for ii in comp_inds
                    ]
                )
                / w_new
            )
            p_new = (
                sum(
                    [
                        self._gaussMix.weights[ii] * self._gaussMix.covariances[ii]
                        for ii in comp_inds
                    ]
                )
                / w_new
            )

            w_lst.append(w_new)
            m_lst.append(m_new)
            p_lst.append(p_new)

            loop_inds = loop_inds.symmetric_difference(comp_inds)
            for ii in comp_inds:
                self._gaussMix.weights[ii] = -1
        self._gaussMix = smodels.GaussianMixture(
            means=m_lst, covariances=p_lst, weights=w_lst
        )
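
    # Components are merged with the highest weighted component jj when
    # (m_i - m_jj)^T P_jj^{-1} (m_i - m_jj) <= merge_threshold, and replaced
    # by the weighted average
    #
    #   w = sum_i w_i,  m = (1 / w) sum_i w_i m_i,  P = (1 / w) sum_i w_i P_i
    #
    # Note the covariance here omits the spread-of-means term
    # (1 / w) sum_i w_i (m - m_i)(m - m_i)^T that some GM reduction schemes
    # include.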

    def _cap(self):
        """Removes least likely hypotheses until a maximum number is reached.

        This should be called once per time step after pruning and
        before the state extraction.
        """
        if len(self._gaussMix.weights) > self.max_gauss:
            idx = np.argsort(self._gaussMix.weights)
            w = sum(self._gaussMix.weights)
            self._gaussMix.remove_components(idx[0 : -self.max_gauss])
            self._gaussMix.weights = [
                x * (w / sum(self._gaussMix.weights)) for x in self._gaussMix.weights
            ]
            return idx[0 : -self.max_gauss].tolist()
        return []

    def extract_states(self):
        """Extracts the best state estimates.

        This extracts the best states from the distribution. It should be
        called once per time step after the correction function.
        """
        inds = np.where(np.asarray(self._gaussMix.weights) >= self.extract_threshold)
        inds = np.ndarray.flatten(inds[0])
        s_lst = []
        c_lst = []
        for jj in inds:
            jj = int(jj)
            num_reps = round(self._gaussMix.weights[jj])
            s_lst.extend([self._gaussMix.means[jj]] * num_reps)
            if self.save_covs:
                c_lst.extend([self._gaussMix.covariances[jj]] * num_reps)
        self._states.append(s_lst)
        if self.save_covs:
            self._covs.append(c_lst)

    def cleanup(
        self,
        enable_prune=True,
        enable_cap=True,
        enable_merge=True,
        enable_extract=True,
        extract_kwargs=None,
    ):
        """Performs the cleanup step of the filter.

        This can prune, merge, cap, and extract states. It must be called once
        per timestep. If this is called with `enable_extract` set to true then
        the extract states method does not need to be called separately. It is
        recommended to call this function instead of
        :meth:`carbs.swarm_estimator.tracker.ProbabilityHypothesisDensity.extract_states`
        directly.

        Parameters
        ----------
        enable_prune : bool, optional
            Flag indicating if pruning should be performed. The default is True.
        enable_cap : bool, optional
            Flag indicating if capping should be performed. The default is True.
        enable_merge : bool, optional
            Flag indicating if merging should be performed. The default is True.
        enable_extract : bool, optional
            Flag indicating if state extraction should be performed. The default is True.
        extract_kwargs : dict, optional
            Extra arguments to pass to the extract function.
        """
        if enable_prune:
            self._prune()
        if enable_merge:
            self._merge()
        if enable_cap:
            self._cap()
        if enable_extract:
            if extract_kwargs is None:
                extract_kwargs = {}
            self.extract_states(**extract_kwargs)
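
    # A sketch of selectively disabling steps (values are illustrative):
    #
    #   phd.cleanup(enable_merge=False)    # prune, cap, and extract only
    #   phd.cleanup(enable_extract=False)  # defer extraction...
    #   phd.extract_states()               # ...and run it manually later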

    def __ani_state_plotting(
        self,
        f_hndl,
        tt,
        states,
        show_sig,
        plt_inds,
        sig_bnd,
        color,
        marker,
        state_lbl,
        added_sig_lbl,
        added_state_lbl,
        scat=None,
    ):
        if scat is None:
            if not added_state_lbl:
                scat = f_hndl.axes[0].scatter(
                    [], [], color=color, edgecolors=(0, 0, 0), marker=marker
                )
            else:
                scat = f_hndl.axes[0].scatter(
                    [],
                    [],
                    color=color,
                    edgecolors=(0, 0, 0),
                    marker=marker,
                    label=state_lbl,
                )
        if len(states) == 0:
            return scat
        x = np.concatenate(states, axis=1)
        if show_sig:
            sigs = [None] * len(states)
            for ii, cov in enumerate(self._covs[tt]):
                sig = np.zeros((2, 2))
                sig[0, 0] = cov[plt_inds[0], plt_inds[0]]
                sig[0, 1] = cov[plt_inds[0], plt_inds[1]]
                sig[1, 0] = cov[plt_inds[1], plt_inds[0]]
                sig[1, 1] = cov[plt_inds[1], plt_inds[1]]
                sigs[ii] = sig
            # plot
            for ii, sig in enumerate(sigs):
                if sig is None:
                    continue
                w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
                if not added_sig_lbl:
                    s = r"${}\sigma$ Error Ellipses".format(sig_bnd)
                    e = Ellipse(
                        xy=x[plt_inds, ii],
                        width=w,
                        height=h,
                        angle=a,
                        zorder=-10000,
                        animated=True,
                        label=s,
                    )
                else:
                    e = Ellipse(
                        xy=x[plt_inds, ii],
                        width=w,
                        height=h,
                        angle=a,
                        zorder=-10000,
                        animated=True,
                    )
                e.set_clip_box(f_hndl.axes[0].bbox)
                e.set_alpha(0.15)
                e.set_facecolor(color)
                f_hndl.axes[0].add_patch(e)
        scat.set_offsets(x[plt_inds[0:2], :].T)
        return scat

    def plot_states(
        self,
        plt_inds,
        state_lbl="States",
        ttl=None,
        state_color=None,
        x_lbl=None,
        y_lbl=None,
        **kwargs,
    ):
        """Plots the best estimate for the states.

        This assumes that the states have been extracted. It's designed to plot
        two of the state variables (typically x/y position). The error ellipses
        are calculated according to
        :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`.

        Keyword arguments are processed with
        :meth:`gncpy.plotting.init_plotting_opts`. This function
        implements

        - f_hndl
        - true_states
        - sig_bnd
        - rng
        - meas_inds
        - lgnd_loc
        - marker

        Parameters
        ----------
        plt_inds : list
            List of indices in the state vector to plot.
        state_lbl : string
            Value to appear in legend for the states. Only appears if the
            legend is shown.
        ttl : string, optional
            Title for the plot, if None a default title is generated. The
            default is None.
        state_color : tuple, optional
            RGB tuple for the state markers. The default is None, which picks
            a random color.
        x_lbl : string
            Label for the x-axis.
        y_lbl : string
            Label for the y-axis.

        Returns
        -------
        Matplotlib figure
            Instance of the matplotlib figure used
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        true_states = opts["true_states"]
        sig_bnd = opts["sig_bnd"]
        rng = opts["rng"]
        meas_inds = opts["meas_inds"]
        lgnd_loc = opts["lgnd_loc"]
        marker = opts["marker"]
        if ttl is None:
            ttl = "State Estimates"
        if rng is None:
            rng = rnd.default_rng(1)
        if x_lbl is None:
            x_lbl = "x-position"
        if y_lbl is None:
            y_lbl = "y-position"
        plt_meas = meas_inds is not None
        show_sig = sig_bnd is not None and self.save_covs

        s_lst = deepcopy(self._states)
        x_dim = None

        if f_hndl is None:
            f_hndl = plt.figure()
            f_hndl.add_subplot(1, 1, 1)
        # get state dimension
        for states in s_lst:
            if len(states) > 0:
                x_dim = states[0].size
                break
        # get array of all state values for each label
        added_sig_lbl = False
        added_true_lbl = False
        added_state_lbl = False
        added_meas_lbl = False
        r = rng.random()
        b = rng.random()
        g = rng.random()
        if state_color is None:
            color = (r, g, b)
        else:
            color = state_color
        for tt, states in enumerate(s_lst):
            if len(states) == 0:
                continue
            x = np.concatenate(states, axis=1)
            if show_sig:
                sigs = [None] * len(states)
                for ii, cov in enumerate(self._covs[tt]):
                    sig = np.zeros((2, 2))
                    sig[0, 0] = cov[plt_inds[0], plt_inds[0]]
                    sig[0, 1] = cov[plt_inds[0], plt_inds[1]]
                    sig[1, 0] = cov[plt_inds[1], plt_inds[0]]
                    sig[1, 1] = cov[plt_inds[1], plt_inds[1]]
                    sigs[ii] = sig
                # plot
                for ii, sig in enumerate(sigs):
                    if sig is None:
                        continue
                    w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
                    if not added_sig_lbl:
                        s = r"${}\sigma$ Error Ellipses".format(sig_bnd)
                        e = Ellipse(
                            xy=x[plt_inds, ii],
                            width=w,
                            height=h,
                            angle=a,
                            zorder=-10000,
                            label=s,
                        )
                        added_sig_lbl = True
                    else:
                        e = Ellipse(
                            xy=x[plt_inds, ii],
                            width=w,
                            height=h,
                            angle=a,
                            zorder=-10000,
                        )
                    e.set_clip_box(f_hndl.axes[0].bbox)
                    e.set_alpha(0.15)
                    e.set_facecolor(color)
                    f_hndl.axes[0].add_patch(e)
            if not added_state_lbl:
                f_hndl.axes[0].scatter(
                    x[plt_inds[0], :],
                    x[plt_inds[1], :],
                    color=color,
                    edgecolors=(0, 0, 0),
                    marker=marker,
                    label=state_lbl,
                )
                added_state_lbl = True
            else:
                f_hndl.axes[0].scatter(
                    x[plt_inds[0], :],
                    x[plt_inds[1], :],
                    color=color,
                    edgecolors=(0, 0, 0),
                    marker=marker,
                )
        # if true states are available then plot them
        if true_states is not None:
            if x_dim is None:
                for states in true_states:
                    if len(states) > 0:
                        x_dim = states[0].size
                        break
            max_true = max([len(x) for x in true_states])
            x = np.nan * np.ones((x_dim, len(true_states), max_true))
            for tt, states in enumerate(true_states):
                for ii, state in enumerate(states):
                    x[:, [tt], ii] = state.copy()
            for ii in range(0, max_true):
                if not added_true_lbl:
                    f_hndl.axes[0].plot(
                        x[plt_inds[0], :, ii],
                        x[plt_inds[1], :, ii],
                        color="k",
                        marker=".",
                        label="True Trajectories",
                    )
                    added_true_lbl = True
                else:
                    f_hndl.axes[0].plot(
                        x[plt_inds[0], :, ii],
                        x[plt_inds[1], :, ii],
                        color="k",
                        marker=".",
                    )
        if plt_meas:
            meas_x = []
            meas_y = []
            for meas_tt in self._meas_tab:
                mx_ii = [m[meas_inds[0]].item() for m in meas_tt]
                my_ii = [m[meas_inds[1]].item() for m in meas_tt]
                meas_x.extend(mx_ii)
                meas_y.extend(my_ii)
            color = (128 / 255, 128 / 255, 128 / 255)
            meas_x = np.asarray(meas_x)
            meas_y = np.asarray(meas_y)
            if not added_meas_lbl:
                f_hndl.axes[0].scatter(
                    meas_x,
                    meas_y,
                    zorder=-1,
                    alpha=0.35,
                    color=color,
                    marker="^",
                    edgecolors=(0, 0, 0),
                    label="Measurements",
                )
            else:
                f_hndl.axes[0].scatter(
                    meas_x,
                    meas_y,
                    zorder=-1,
                    alpha=0.35,
                    color=color,
                    marker="^",
                    edgecolors=(0, 0, 0),
                )
        f_hndl.axes[0].grid(True)
        pltUtil.set_title_label(f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl)

        if lgnd_loc is not None:
            plt.legend(loc=lgnd_loc)
        plt.tight_layout()

        return f_hndl

    def animate_state_plot(
        self,
        plt_inds,
        state_lbl="States",
        state_color=None,
        interval=250,
        repeat=True,
        repeat_delay=1000,
        save_path=None,
        **kwargs,
    ):
        """Creates an animated plot of the states.

        Parameters
        ----------
        plt_inds : list
            indices of the state vector to plot.
        state_lbl : string, optional
            label for the states. The default is 'States'.
        state_color : tuple, optional
            3-tuple for rgb value. The default is None.
        interval : int, optional
            interval of the animation in ms. The default is 250.
        repeat : bool, optional
            flag indicating if the animation loops. The default is True.
        repeat_delay : int, optional
            delay between loops in ms. The default is 1000.
        save_path : string, optional
            file path and name to save the gif, does not save if not given.
            The default is None.
        **kwargs : dict, optional
            Standard plotting options for
            :meth:`gncpy.plotting.init_plotting_opts`. This function
            implements

            - f_hndl
            - sig_bnd
            - rng
            - meas_inds
            - lgnd_loc
            - marker

        Returns
        -------
        anim :
            handle to the animation.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        f_hndl = opts["f_hndl"]
        sig_bnd = opts["sig_bnd"]
        rng = opts["rng"]
        meas_inds = opts["meas_inds"]
        lgnd_loc = opts["lgnd_loc"]
        marker = opts["marker"]

        plt_meas = meas_inds is not None
        show_sig = sig_bnd is not None and self.save_covs

        f_hndl.axes[0].grid(True)
        pltUtil.set_title_label(
            f_hndl,
            0,
            opts,
            ttl="State Estimates",
            x_lbl="x-position",
            y_lbl="y-position",
        )

        fr_number = f_hndl.axes[0].annotate(
            "0",
            (0, 1),
            xycoords="axes fraction",
            xytext=(10, -10),
            textcoords="offset points",
            ha="left",
            va="top",
            animated=False,
        )

        added_sig_lbl = False
        added_state_lbl = False
        added_meas_lbl = False
        r = rng.random()
        b = rng.random()
        g = rng.random()
        if state_color is None:
            s_color = (r, g, b)
        else:
            s_color = state_color
        state_scat = f_hndl.axes[0].scatter(
            [], [], color=s_color, edgecolors=(0, 0, 0), marker=marker, label=state_lbl
        )
        meas_scat = None
        if plt_meas:
            m_color = (128 / 255, 128 / 255, 128 / 255)

            if meas_scat is None:
                if not added_meas_lbl:
                    lbl = "Measurements"
                    meas_scat = f_hndl.axes[0].scatter(
                        [],
                        [],
                        zorder=-1,
                        alpha=0.35,
                        color=m_color,
                        marker="^",
                        edgecolors="k",
                        label=lbl,
                    )
                    added_meas_lbl = True
                else:
                    meas_scat = f_hndl.axes[0].scatter(
                        [],
                        [],
                        zorder=-1,
                        alpha=0.35,
                        color=m_color,
                        marker="^",
                        edgecolors="k",
                    )

        def update(tt, *fargs):
            nonlocal added_sig_lbl
            nonlocal added_state_lbl
            nonlocal added_meas_lbl
            nonlocal state_scat
            nonlocal meas_scat
            nonlocal fr_number

            fr_number.set_text("Timestep: {j}".format(j=tt))

            states = self._states[tt]
            state_scat = self.__ani_state_plotting(
                f_hndl,
                tt,
                states,
                show_sig,
                plt_inds,
                sig_bnd,
                s_color,
                marker,
                state_lbl,
                added_sig_lbl,
                added_state_lbl,
                scat=state_scat,
            )
            added_sig_lbl = True
            added_state_lbl = True

            if plt_meas:
                meas_tt = self._meas_tab[tt]

                meas_x = [m[meas_inds[0]].item() for m in meas_tt]
                meas_y = [m[meas_inds[1]].item() for m in meas_tt]

                meas_x = np.asarray(meas_x)
                meas_y = np.asarray(meas_y)
                meas_scat.set_offsets(np.array([meas_x, meas_y]).T)

        # plt.figure(f_hndl.number)
        anim = animation.FuncAnimation(
            f_hndl,
            update,
            frames=len(self._states),
            interval=interval,
            repeat_delay=repeat_delay,
            repeat=repeat,
        )

        if lgnd_loc is not None:
            plt.legend(loc=lgnd_loc)
        if save_path is not None:
            writer = animation.PillowWriter(fps=30)
            anim.save(save_path, writer=writer)
        return anim
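
    # A hedged usage sketch; a figure must be supplied through ``f_hndl``
    # because this method draws into an existing set of axes (names are
    # illustrative):
    #
    #   fig = plt.figure()
    #   fig.add_subplot(1, 1, 1)
    #   anim = phd.animate_state_plot(
    #       [0, 1], f_hndl=fig, meas_inds=[0, 1], save_path="phd.gif"
    #   )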


class CardinalizedPHD(ProbabilityHypothesisDensity):
    """Implements the Cardinalized Probability Hypothesis Density filter.

    The kwargs in the constructor are passed through to the parent constructor.

    Notes
    -----
    The filter implementation is based on
    :cite:`Vo2006_TheCardinalizedProbabilityHypothesisDensityFilterforLinearGaussianMultiTargetModels`
    and :cite:`Vo2007_AnalyticImplementationsoftheCardinalizedProbabilityHypothesisDensityFilter`.

    Attributes
    ----------
    agents_per_state : list, optional
        number of agents per state. The default is [].
    """

    def __init__(self, max_expected_card=10, **kwargs):
        self.agents_per_state = []
        self._max_expected_card = max_expected_card

        self._card_dist = np.zeros(
            self.max_expected_card + 1
        )  # local copy for internal modification
        self._card_dist[0] = 1
        self._card_time_hist = []  # local copy for internal modification
        self._n_states_per_time = []

        super().__init__(**kwargs)

    @property
    def max_expected_card(self):
        """Maximum expected cardinality. The default is 10."""
        return self._max_expected_card

    @max_expected_card.setter
    def max_expected_card(self, x):
        self._card_dist = np.zeros(x + 1)
        self._card_dist[0] = 1
        self._max_expected_card = x

    @property
    def cardinality(self):
        """Cardinality of the RFS."""
        return np.argmax(self._card_dist)

    def predict(self, timestep, **kwargs):
        """Prediction step of the CPHD filter.

        This predicts new hypotheses and propagates them to the next time
        step. It also updates the cardinality distribution.

        Parameters
        ----------
        timestep: float
            current timestep
        **kwargs : dict, optional
            See :meth:`carbs.swarm_estimator.tracker.ProbabilityHypothesisDensity.predict`
            for the available arguments.

        Returns
        -------
        None.
        """
        super().predict(timestep, **kwargs)

        survive_cdn_predict = np.zeros(self.max_expected_card + 1)
        for j in range(0, self.max_expected_card + 1):
            terms = np.zeros(self.max_expected_card + 1)
            for i in range(j, self.max_expected_card + 1):
                temp = np.array(
                    [
                        np.sum(np.log(range(1, i + 1))),
                        -np.sum(np.log(range(1, j + 1))),
                        -np.sum(np.log(range(1, i - j + 1))),
                        j * np.log(self.prob_survive),
                        (i - j) * np.log(self.prob_death),
                    ]
                )
                terms[i] = np.exp(np.sum(temp)) * self._card_dist[i]
            survive_cdn_predict[j] = np.sum(terms)
        cdn_predict = np.zeros(self.max_expected_card + 1)
        if len(self.birth_terms) != 1:
            warnings.warn("Only using the first birth term in cardinality update")
        birth = np.sum(
            np.array([w for w in self.birth_terms[0].weights])
        )  # NOTE: assumes 1 GM for the birth model
        log_birth = np.log(birth)
        for n in range(0, self.max_expected_card + 1):
            terms = np.zeros(self.max_expected_card + 1)
            for j in range(0, n + 1):
                temp = np.array(
                    [birth, (n - j) * log_birth, -np.sum(np.log(range(1, n - j + 1)))]
                )
                terms[j] = np.exp(np.sum(temp)) * survive_cdn_predict[j]
            cdn_predict[n] = np.sum(terms)
        self._card_dist = (cdn_predict / np.sum(cdn_predict)).copy()

        self._card_time_hist.append(
            (np.argmax(self._card_dist).item(), np.std(self._card_dist))
        )
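
    # The loops above evaluate, in log space for numerical stability, the
    # cardinality prediction
    #
    #   p_pred(n) = sum_j p_birth(n - j)
    #               * sum_{i >= j} C(i, j) p_s^j (1 - p_s)^(i - j) p(i)
    #
    # with a Poisson birth cardinality; constant factors in n drop out in the
    # final normalization by np.sum(cdn_predict).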

    def correct(
        self, timestep, meas_in, meas_mat_args={}, est_meas_args={}, filt_args={}
    ):
        """Correction step of the CPHD filter.

        This corrects the hypotheses based on the measurements and gates the
        measurements according to the class settings. It also updates the
        cardinality distribution.

        Parameters
        ----------
        timestep: float
            current timestep
        meas_in : list
            2d numpy arrays representing a measurement.
        meas_mat_args : dict, optional
            keyword arguments to pass to the inner filter's get measurement
            matrix function. Only used if gating is on. The default is {}.
        est_meas_args : dict, optional
            keyword arguments to pass to the inner filter's estimate
            measurements function. Only used if gating is on. The default is {}.
        filt_args : dict, optional
            keyword arguments to pass to the inner filter's correct function.
            The default is {}.

        Returns
        -------
        None.
        """
        meas = deepcopy(meas_in)

        if self.gating_on:
            meas = self._gate_meas(
                meas,
                self._gaussMix.means,
                self._gaussMix.covariances,
                meas_mat_args,
                est_meas_args,
            )
        self._meas_tab.append(meas)

        gmix = deepcopy(self._gaussMix)  # predicted gm

        self._gaussMix = self._correct_prob_density(timestep, meas, gmix, filt_args)
1851 def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
1852 """Helper function for correction step.
1854 Loops over all elements in a probability distribution and preforms
1855 the filter correction.
1856 """
1857 w_pred = np.zeros((len(probDensity.weights), 1))
1858 for i in range(0, len(probDensity.weights)):
1859 w_pred[i] = probDensity.weights[i]
1860 xdim = len(probDensity.means[0])
1862 plen = len(probDensity.means)
1863 zlen = len(meas)
1865 qz_temp = np.zeros((plen, zlen))
1866 mean_temp = np.zeros((zlen, xdim, plen))
1867 cov_temp = np.zeros((zlen, plen, xdim, xdim))
1869 for z_ind in range(0, zlen):
1870 for p_ind in range(0, plen):
1871 self.filter.cov = probDensity.covariances[p_ind]
1872 state = probDensity.means[p_ind]
1874 (mean, qz) = self.filter.correct(
1875 timestep, meas[z_ind], state, **filt_args
1876 )
1877 qz_temp[p_ind, z_ind] = qz
1878 mean_temp[z_ind, :, p_ind] = np.ndarray.flatten(mean)
1879 cov_temp[z_ind, p_ind, :, :] = self.filter.cov.copy()
1880 xivals = np.zeros(zlen)
1881 pdc = self.prob_detection / self.clutter_den
1882 for e in range(0, zlen):
1883 xivals[e] = pdc * np.dot(w_pred.T, qz_temp[:, [e]])
1884 esfvals_E = get_elem_sym_fnc(xivals)
1885 esfvals_D = np.zeros((zlen, zlen))
1887 for j in range(0, zlen):
1888 xi_temp = xivals.copy()
1889 xi_temp = np.delete(xi_temp, j)
1890 esfvals_D[:, [j]] = get_elem_sym_fnc(xi_temp)
1891 ups0_E = np.zeros((self.max_expected_card + 1, 1))
1892 ups1_E = np.zeros((self.max_expected_card + 1, 1))
1893 ups1_D = np.zeros((self.max_expected_card + 1, zlen))
1895 tot_w_pred = sum(w_pred)
1896 for nn in range(0, self.max_expected_card + 1):
1897 terms0_E = np.zeros((min(zlen, nn) + 1))
1898 for jj in range(0, min(zlen, nn) + 1):
1899 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate)
1900 t2 = sum([np.log(x) for x in range(1, nn + 1)])
1901 t3 = -1 * sum([np.log(x) for x in range(1, nn - jj + 1)])
1902 t4 = (nn - jj) * np.log(self.prob_death)
1903 t5 = -jj * np.log(tot_w_pred)
1904 terms0_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj]
1905 ups0_E[nn] = np.sum(terms0_E)
1907 terms1_E = np.zeros((min(zlen, nn) + 1))
1908 for jj in range(0, min(zlen, nn) + 1):
1909 if nn >= jj + 1:
1910 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate)
1911 t2 = sum([np.log(x) for x in range(1, nn + 1)])
1912 t3 = -1 * sum([np.log(x) for x in range(1, nn - (jj + 1) + 1)])
1913 t4 = (nn - (jj + 1)) * np.log(self.prob_death)
1914 t5 = -(jj + 1) * np.log(tot_w_pred)
1915 terms1_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj]
1916 ups1_E[nn] = np.sum(terms1_E)
1918 if zlen != 0:
1919 terms1_D = np.zeros((min(zlen - 1, nn) + 1, zlen))
1920 for ell in range(1, zlen + 1):
1921 for jj in range(0, min((zlen - 1), nn) + 1):
1922 if nn >= jj + 1:
1923 t1 = -self.clutter_rate + ((zlen - 1) - jj) * np.log(
1924 self.clutter_rate
1925 )
1926 t2 = sum([np.log(x) for x in range(1, nn + 1)])
1927 t3 = -1 * sum(
1928 [np.log(x) for x in range(1, nn - (jj + 1) + 1)]
1929 )
1930 t4 = (nn - (jj + 1)) * np.log(self.prob_death)
1931 t5 = -(jj + 1) * np.log(tot_w_pred)
1932 terms1_D[jj, ell - 1] = (
1933 np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_D[jj, ell - 1]
1934 )
1935 ups1_D[nn, :] = np.sum(terms1_D, axis=0)
1936 gmix = deepcopy(probDensity)
1937 w_update = (
1938 ((ups1_E.T @ self._card_dist) / (ups0_E.T @ self._card_dist))
1939 * self.prob_miss_detection
1940 * w_pred
1941 )
1943 gmix.weights = [x.item() for x in w_update]
1945 for ee in range(0, zlen):
1946 wt_1 = (
1947 (ups1_D[:, [ee]].T @ self._card_dist) / (ups0_E.T @ self._card_dist)
1948 ).reshape((1, 1))
1949 wt_2 = self.prob_detection * qz_temp[:, [ee]] / self.clutter_den * w_pred
1950 w_temp = wt_1 * wt_2
1951 for ww in range(0, w_temp.shape[0]):
1952 gmix.add_components(
1953 mean_temp[ee, :, ww].reshape((xdim, 1)),
1954 cov_temp[ee, ww, :, :],
1955 w_temp[ww].item(),
1956 )
1957 cdn_update = self._card_dist.copy()
1958 for ii in range(0, len(cdn_update)):
1959 cdn_update[ii] = ups0_E[ii] * self._card_dist[ii]
1960 self._card_dist = cdn_update / np.sum(cdn_update)
1961 # assumes predict is called before correct
1962 self._card_time_hist[-1] = (
1963 np.argmax(self._card_dist).item(),
1964 np.std(self._card_dist),
1965 )
1967 return gmix
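# Editor's sketch of the update above: each missed-detection weight follows
# the Gaussian-mixture CPHD corrector,
#     w = (<ups1_E, rho> / <ups0_E, rho>) * q_D * w_pred,
# with rho the predicted cardinality distribution and q_D the miss
# probability, while every measurement `ee` spawns plen new components
# weighted by the detection terms wt_1 * wt_2 computed in the loop above.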
1969 def extract_states(self, allow_multiple=True):
1970 """Extracts the best state estimates.
1972 This extracts the best states from the distribution. It should be
1973 called once per time step after the correction function.
1975 Parameters
1976 ----------
1977 allow_multiple : bool
1978 Flag indicating if extraction is allowed to map a single Gaussian
1979 to multiple states. The default is True.
1980 """
1981 s_weights = np.argsort(self._gaussMix.weights)[::-1]
1982 s_lst = []
1983 c_lst = []
1984 self.agents_per_state = []
1985 ii = 0
1986 tot_agents = 0
1987 while ii < s_weights.size and tot_agents < self.cardinality:
1988 idx = int(s_weights[ii])
1990 if allow_multiple:
1991 n_agents = np.ceil(self._gaussMix.weights[idx])
1992 if n_agents <= 0:
1993 msg = "Gaussian weights are 0 before reaching cardinality"
1994 warnings.warn(msg, RuntimeWarning)
1995 break
1996 if tot_agents + n_agents > self.cardinality:
1997 n_agents = self.cardinality - tot_agents
1998 else:
1999 n_agents = 1
2000 tot_agents += n_agents
2001 self.agents_per_state.append(n_agents)
2003 s_lst.append(self._gaussMix.means[idx])
2004 if self.save_covs:
2005 c_lst.append(self._gaussMix.covariances[idx])
2006 ii += 1
2007 if tot_agents != self.cardinality:
2008 warnings.warn("Failed to meet estimated cardinality when extracting!")
2009 self._states.append(s_lst)
2010 if self.save_covs:
2011 self._covs.append(c_lst)
2012 if self.debug_plots:
2013 self._n_states_per_time.append(ii)
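# Usage sketch (hypothetical objects): `cphd` is an instance of this class
# and `meas_k` is a list of 2d numpy array measurements at time `tk`.
#
#     cphd.predict(tk)
#     cphd.correct(tk, meas_k)
#     cphd.extract_states(allow_multiple=True)
#
# extract_states should run once per timestep, after correct, so the saved
# states stay aligned with the cardinality history.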
2015 def plot_card_dist(self, **kwargs):
2016 """Plots the current cardinality distribution.
2018 This assumes that the cardinality distribution has been calculated by
2019 the class.
2021 Parameters
2022 ----------
2023 **kwargs : dict, optional
2024 Keyword arguments are processed with
2025 :meth:`gncpy.plotting.init_plotting_opts`. This function
2026 implements
2028 - f_hndl
2030 Returns
2031 -------
2032 Matplotlib figure
2033 Instance of the matplotlib figure used
2035 Warns
2036 -----
2037 RuntimeWarning
2038 If the cardinality distribution is empty.
2039 """
2040 opts = pltUtil.init_plotting_opts(**kwargs)
2041 f_hndl = opts["f_hndl"]
2043 if len(self._card_dist) == 0:
2044 warnings.warn("Empty Cardinality", RuntimeWarning)
2045 return f_hndl
2046 if f_hndl is None:
2047 f_hndl = plt.figure()
2048 f_hndl.add_subplot(1, 1, 1)
2049 x_vals = np.arange(0, len(self._card_dist))
2050 f_hndl.axes[0].bar(x_vals, self._card_dist)
2052 pltUtil.set_title_label(
2053 f_hndl,
2054 0,
2055 opts,
2056 ttl="Cardinality Distribution",
2057 x_lbl="Cardinality",
2058 y_lbl="Probability",
2059 )
2060 plt.tight_layout()
2062 return f_hndl
2064 def plot_card_history(
2065 self, ttl=None, true_card=None, time_units="index", time=None, **kwargs
2066 ):
2067 """Plots the current cardinality time history.
2069 This assumes that the cardinality distribution has been calculated by
2070 the class.
2072 Parameters
2073 ----------
2074 ttl : string
2075 String for the title, if None a default is created. The default is
2076 None.
2077 true_card : array like
2078 List of the true cardinality at each time
2079 time_units : string, optional
2080 Text representing the units of time in the plot. The default is
2081 'index'.
2082 time : numpy array, optional
2083 Vector to use for the x-axis of the plot. If none is given then
2084 vector indices are used. The default is None.
2085 **kwargs : dict, optional
2086 Keyword arguments are processed with
2087 :meth:`gncpy.plotting.init_plotting_opts`. This function
2088 implements
2090 - f_hndl
2091 - sig_bnd
2092 - time_vec
2093 - lgnd_loc
2095 Returns
2096 -------
2097 Matplotlib figure
2098 Instance of the matplotlib figure used
2099 """
2100 opts = pltUtil.init_plotting_opts(**kwargs)
2101 f_hndl = opts["f_hndl"]
2102 sig_bnd = opts["sig_bnd"]
2103 # time_vec = opts["time_vec"]
2104 lgnd_loc = opts["lgnd_loc"]
2105 if ttl is None:
2106 ttl = "Cardinality History"
2107 if len(self._card_time_hist) == 0:
2108 warnings.warn("Empty Cardinality", RuntimeWarning)
2109 return f_hndl
2110 if sig_bnd is not None:
2111 stds = [sig_bnd * x[1] for x in self._card_time_hist]
2112 card = [x[0] for x in self._card_time_hist]
2114 if f_hndl is None:
2115 f_hndl = plt.figure()
2116 f_hndl.add_subplot(1, 1, 1)
2117 if time is None:
2118 x_vals = [ii for ii in range(0, len(card))]
2119 else:
2120 x_vals = time
2121 f_hndl.axes[0].step(
2122 x_vals,
2123 card,
2124 label="Cardinality",
2125 color="k",
2126 linestyle="-",
2127 where="post",
2128 )
2130 if true_card is not None:
2131 if len(true_card) != len(x_vals):
2132 c_len = len(true_card)
2133 t_len = len(x_vals)
2134 msg = "True Cardinality vector length ({})".format(
2135 c_len
2136 ) + " does not match time vector length ({})".format(t_len)
2137 warnings.warn(msg)
2138 else:
2139 f_hndl.axes[0].step(
2140 x_vals,
2141 true_card,
2142 color="g",
2143 label="True Cardinality",
2144 linestyle="--",
2145 where="post",
2146 )
2147 if sig_bnd is not None:
2148 lbl = r"${}\sigma$ Bound".format(sig_bnd)
2149 f_hndl.axes[0].plot(
2150 x_vals,
2151 [x + s for (x, s) in zip(card, stds)],
2152 linestyle="--",
2153 color="r",
2154 label=lbl,
2155 )
2156 f_hndl.axes[0].plot(
2157 x_vals, [x - s for (x, s) in zip(card, stds)], linestyle="--", color="r"
2158 )
2159 f_hndl.axes[0].ticklabel_format(useOffset=False)
2161 if lgnd_loc is not None:
2162 plt.legend(loc=lgnd_loc)
2163 plt.grid(True)
2164 pltUtil.set_title_label(
2165 f_hndl,
2166 0,
2167 opts,
2168 ttl=ttl,
2169 x_lbl="Time ({})".format(time_units),
2170 y_lbl="Cardinality",
2171 )
2173 plt.tight_layout()
2175 return f_hndl
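# Plotting sketch (assumes the filter has been run and `true_card` is a list
# of the true cardinality at each step, `t_vec` the matching time vector):
#
#     fig1 = cphd.plot_card_dist()
#     fig2 = cphd.plot_card_history(
#         true_card=true_card, sig_bnd=3, time_units="s", time=t_vec
#     )
#
# Both return the matplotlib figure handle so styling can be adjusted after.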
2177 def plot_number_states_per_time(self, **kwargs):
2178 """Plots the number of states per timestep.
2180 This is a debug plot for the case where the GM contains zero weights
2181 before the estimated cardinality is reached. Debug plots must be
2182 turned on prior to running the filter.
2185 Parameters
2186 ----------
2187 **kwargs : dict, optional
2188 Keyword arguments are processed with
2189 :meth:`gncpy.plotting.init_plotting_opts`. This function
2190 implements
2192 - f_hndl
2193 - lgnd_loc
2195 Returns
2196 -------
2197 f_hndl : matplotlib figure
2198 handle to the current figure.
2200 """
2201 opts = pltUtil.init_plotting_opts(**kwargs)
2202 f_hndl = opts["f_hndl"]
2203 lgnd_loc = opts["lgnd_loc"]
2205 if not self.debug_plots:
2206 msg = "Debug plots turned off"
2207 warnings.warn(msg)
2208 return f_hndl
2209 if f_hndl is None:
2210 f_hndl = plt.figure()
2211 f_hndl.add_subplot(1, 1, 1)
2212 x_vals = [ii for ii in range(0, len(self._n_states_per_time))]
2213 f_hndl.axes[0].plot(x_vals, self._n_states_per_time)
2214 # legend placed after plotting so labeled artists exist
2215 if lgnd_loc is not None:
2216 plt.legend(loc=lgnd_loc)
2217 plt.grid(True)
2218 pltUtil.set_title_label(
2219 f_hndl,
2220 0,
2221 opts,
2222 ttl="Gaussians per Timestep",
2223 x_lbl="Time",
2224 y_lbl="Number of Gaussians",
2225 )
2227 return f_hndl
2230class _IMMPHDBase:
2231 def __init__(
2232 self,
2233 filter_lst=None,
2234 model_trans=None,
2235 init_weights=None,
2236 init_means=None,
2237 init_covs=None,
2238 **kwargs,
2239 ):
2240 super().__init__(**kwargs)
2241 if not isinstance(self.filter, gfilts.InteractingMultipleModel):
2242 raise TypeError("Filter must be InteractingMultipleModel")
2243 if filter_lst is not None and model_trans is not None:
2244 self.filter.initialize_filter(filter_lst, model_trans)
2245 if init_means is not None and init_covs is not None:
2246 self.filter.initialize_states(
2247 init_means, init_covs, init_weights=init_weights
2248 )
2249 self._filter_states = []
2251 def _predict_prob_density(self, timestep, probDensity, filt_args):
2252 """Predicts the probability density.
2254 Loops over all elements in a probability distribution and performs
2255 the filter prediction.
2257 Parameters
2258 ----------
2259 timestep: float
2260 current timestep
2261 probDensity : :class:`serums.models.GaussianMixture`
2262 Probability density to perform prediction on.
2263 filt_args : dict
2264 Passed directly to the inner filter.
2266 Returns
2267 -------
2268 gm : :class:`serums.models.GaussianMixture`
2269 predicted Gaussian mixture.
2271 """
2272 weights = [self.prob_survive * x for x in probDensity.weights.copy()]
2273 n_terms = len(probDensity.means)
2274 covariances = [None] * n_terms
2275 means = [None] * n_terms
2276 for ii, (m, P) in enumerate(zip(probDensity.means, probDensity.covariances)):
2277 self.filter.load_filter_state(self._filter_states[ii])
2278 n_mean = self.filter.predict(timestep, **filt_args).reshape((m.shape[0], 1))
2279 covariances[ii] = self.filter.cov.copy()
2280 means[ii] = n_mean
2281 self._filter_states[ii] = self.filter.save_filter_state()
2282 return smodels.GaussianMixture(
2283 means=means, covariances=covariances, weights=weights
2284 )
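# Worked example of the survival weighting above: with prob_survive = 0.9 and
# predicted component weights [0.5, 0.3], the propagated mixture keeps the
# IMM-predicted means/covariances but carries weights [0.45, 0.27]; the mass
# removed by (1 - prob_survive) models target death.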
2286 def predict(self, timestep, filt_args={}):
2287 super().predict(timestep, filt_args=filt_args)
2288 for gm in self.birth_terms:
2289 for m, c in zip(gm.means, gm.covariances):
2290 # if len(m) != 1 or len(c) != 1:
2291 # raise ValueError("only one mean and covariance per filter is supported")
2292 init_means = []
2293 init_covs = []
2294 for ii in range(0, len(self.filter.in_filt_list)):
2295 init_means.append(m)
2296 init_covs.append(c)
2297 self.filter.initialize_states(init_means, init_covs)
2298 self._filter_states.append(self.filter.save_filter_state())
2299 # new imm filter state to represent new means
2301 def _prune(self):
2302 inds = super()._prune()
2303 inds = sorted(inds, reverse=True)
2304 for ind in inds:
2305 if ind < len(self._filter_states):
2306 self._filter_states.pop(ind)
2307 else:
2308 raise RuntimeError("Pruned index is greater than filter state length")
2310 # remove pruned indices from filter state indices
2312 def _merge(self):
2313 """Merges nearby hypotheses."""
2314 loop_inds = set(range(0, len(self._gaussMix.means)))
2316 w_lst = []
2317 m_lst = []
2318 p_lst = []
2319 fs_lst = []
2320 while len(loop_inds) > 0:
2321 jj = int(np.argmax(self._gaussMix.weights))
2322 comp_inds = []
2323 inv_cov = la.inv(self._gaussMix.covariances[jj])
2324 for ii in loop_inds:
2325 diff = self._gaussMix.means[ii] - self._gaussMix.means[jj]
2326 val = diff.T @ inv_cov @ diff
2327 if val <= self.merge_threshold:
2328 comp_inds.append(ii)
2329 w_new = sum([self._gaussMix.weights[ii] for ii in comp_inds])
2330 m_new = (
2331 sum(
2332 [
2333 self._gaussMix.weights[ii] * self._gaussMix.means[ii]
2334 for ii in comp_inds
2335 ]
2336 )
2337 / w_new
2338 )
2339 p_new = (
2340 sum(
2341 [
2342 self._gaussMix.weights[ii] * self._gaussMix.covariances[ii]
2343 for ii in comp_inds
2344 ]
2345 )
2346 / w_new
2347 )
2349 new_mean_list = []
2350 new_cov_list = []
2351 new_filt_weights = []
2352 for kk in range(0, len(self.filter.in_filt_list)):
2353 # ml_new = ( sum([self._gaussMix.weights[ii] * self._filter_states[ii]["mean_list"][kk]]))
2354 ml_new = 0
2355 cl_new = 0
2356 fw_new = 0
2358 for ii in comp_inds:
2359 ml_new = (
2360 ml_new
2361 + self._gaussMix.weights[ii]
2362 * self._filter_states[ii]["mean_list"][kk]
2363 )
2364 cl_new = (
2365 cl_new
2366 + self._gaussMix.weights[ii]
2367 * self._filter_states[ii]["cov_list"][kk]
2368 )
2369 fw_new = (
2370 fw_new
2371 + self._gaussMix.weights[ii]
2372 * self._filter_states[ii]["filt_weights"][kk]
2373 )
2374 new_mean_list.append(ml_new / w_new)
2375 new_cov_list.append(cl_new / w_new)
2376 new_filt_weights.append(fw_new / w_new)
2377 self.filter.initialize_states(
2378 new_mean_list, new_cov_list, init_weights=new_filt_weights
2379 )
2380 fs_lst.append(self.filter.save_filter_state())
2381 w_lst.append(w_new)
2382 m_lst.append(m_new)
2383 p_lst.append(p_new)
2385 loop_inds = loop_inds.symmetric_difference(comp_inds)
2386 for ii in comp_inds:
2387 self._gaussMix.weights[ii] = -1
2388 self._filter_states = fs_lst
2389 self._gaussMix = smodels.GaussianMixture(
2390 means=m_lst, covariances=p_lst, weights=w_lst
2391 )
2392 # probably need to overwrite, do this later
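# Merge summary (editor's note): components within merge_threshold
# Mahalanobis distance of the current heaviest component are pooled. The
# pooled weight is the sum, and the pooled mean, covariance, and per-model
# IMM filter states are weight-averaged. Note this averages covariances
# directly, without the spread-of-means term some GM reduction schemes add.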
2394 def _cap(self):
2395 inds = super()._cap()
2396 inds = sorted(inds, reverse=True)
2397 for ind in inds:
2398 if ind < len(self._filter_states):
2399 self._filter_states.pop(ind)
2400 else:
2401 raise RuntimeError("Capped index is greater than filter state length")
2402 # remove capped indices from filter state indices
2405class IMMProbabilityHypothesisDensity(_IMMPHDBase, ProbabilityHypothesisDensity):
2406 def __init__(self, **kwargs):
2407 super().__init__(**kwargs)
2408 # TODO: init_filter_states_for_imm
2410 def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
2411 """Corrects the probability densities.
2413 Loops over all elements in a probability distribution and performs
2414 the filter correction.
2416 Parameters
2417 ----------
2418 meas : list
2419 List of 2d numpy arrays, one per measurement.
2420 probDensity : :py:class:`serums.models.GaussianMixture`
2421 probability density to run correction on.
2422 filt_args : dict
2423 arguments to pass to the inner filter correct function.
2425 Returns
2426 -------
2427 gm : :py:class:`serums.models.GaussianMixture`
2428 corrected probability density.
2430 """
2431 means = []
2432 covariances = []
2433 weights = []
2434 # corr_filt_weights = np.zeros(np.shape(self.filter.filt_weights))
2435 new_filter_states = []
2436 det_weights = [self.prob_detection * x for x in probDensity.weights]
2437 for z in meas:
2438 w_lst = []
2439 for jj in range(0, len(probDensity.means)):
2440 self.filter.load_filter_state(self._filter_states[jj])
2441 (mean, qz) = self.filter.correct(timestep, z, **filt_args)
2442 cov = self.filter.cov
2443 w = qz * det_weights[jj]
2444 means.append(mean.reshape((-1, 1)))  # column vector; avoids hard-coding the state dimension
2445 covariances.append(cov)
2446 w_lst.append(w)
2447 new_filter_states.append(self.filter.save_filter_state())
2448 weights.extend(
2449 [x / (self.clutter_rate * self.clutter_den + sum(w_lst)) for x in w_lst]
2450 )
2451 self._filter_states = new_filter_states
2452 return smodels.GaussianMixture(
2453 means=means, covariances=covariances, weights=weights
2454 )
2456 def correct(
2457 self, timestep, meas_in, meas_mat_args={}, est_meas_args={}, filt_args={}
2458 ):
2459 meas = deepcopy(meas_in)
2461 if self.gating_on:
2462 meas = self._gate_meas(
2463 meas,
2464 self._gaussMix.means,
2465 self._gaussMix.covariances,
2466 meas_mat_args,
2467 est_meas_args,
2468 )
2469 self._meas_tab.append(meas)
2471 gmix = deepcopy(self._gaussMix)
2472 gmix.weights = [self.prob_miss_detection * x for x in gmix.weights]
2473 saved_filt_weights = []
2474 for filt_state in self._filter_states:
2475 saved_filt_weights.append(filt_state["filt_weights"].copy())
2476 gm = self._correct_prob_density(timestep, meas, self._gaussMix, filt_args)
2477 gm.add_components(gmix.means, gmix.covariances, gmix.weights)
2479 for jj, (m, c) in enumerate(zip(gmix.means, gmix.covariances)):
2480 # for m, c in zip(m_list, c_list):
2481 m_list = []
2482 c_list = []
2483 for ii in range(0, len(self.filter.in_filt_list)):
2484 m_list.append(m)
2485 c_list.append(c)
2486 self.filter.initialize_states(
2487 m_list, c_list, init_weights=saved_filt_weights[jj]
2488 )
2489 self._filter_states.append(self.filter.save_filter_state())
2491 self._gaussMix = gm
2494class IMMCardinalizedPHD(_IMMPHDBase, CardinalizedPHD):
2495 def __init__(self, **kwargs):
2496 super().__init__(**kwargs)
2498 def _correct_prob_density(self, timestep, meas, probDensity, filt_args):
2499 """Helper function for correction step.
2501 Loops over all elements in a probability distribution and performs
2502 the filter correction.
2503 """
2504 w_pred = np.zeros((len(probDensity.weights), 1))
2505 for i in range(0, len(probDensity.weights)):
2506 w_pred[i] = probDensity.weights[i]
2507 xdim = len(probDensity.means[0])
2509 plen = len(probDensity.means)
2510 zlen = len(meas)
2512 qz_temp = np.zeros((plen, zlen))
2513 mean_temp = np.zeros((zlen, xdim, plen))
2514 cov_temp = np.zeros((zlen, plen, xdim, xdim))
2515 saved_filt_weights = []
2516 for filt_state in self._filter_states:
2517 saved_filt_weights.append(filt_state["filt_weights"].copy())
2518 new_filter_states = []
2520 for z_ind in range(0, zlen):
2521 for p_ind in range(0, plen):
2522 state = probDensity.means[p_ind]
2523 self.filter.load_filter_state(self._filter_states[p_ind])
2524 # self.filter.initialize_states(probDensity.means[p_ind], probDensity.covariances[p_ind],
2525 # init_weights=self.weight_list[p_ind])
2527 (mean, qz) = self.filter.correct(timestep, meas[z_ind], **filt_args)
2528 qz_temp[p_ind, z_ind] = qz
2529 mean_temp[z_ind, :, p_ind] = np.ndarray.flatten(mean)
2530 cov_temp[z_ind, p_ind, :, :] = self.filter.cov.copy()
2531 new_filter_states.append(self.filter.save_filter_state())
2532 # self._filter_states[p_ind] = self.filter.save_filter_state()
2533 xivals = np.zeros(zlen)
2534 pdc = self.prob_detection / self.clutter_den
2535 for e in range(0, zlen):
2536 xivals[e] = (pdc * np.dot(w_pred.T, qz_temp[:, [e]])).item()  # scalar from the 1x1 result
2537 esfvals_E = get_elem_sym_fnc(xivals)
2538 esfvals_D = np.zeros((zlen, zlen))
2540 for j in range(0, zlen):
2541 xi_temp = xivals.copy()
2542 xi_temp = np.delete(xi_temp, j)
2543 esfvals_D[:, [j]] = get_elem_sym_fnc(xi_temp)
2544 ups0_E = np.zeros((self.max_expected_card + 1, 1))
2545 ups1_E = np.zeros((self.max_expected_card + 1, 1))
2546 ups1_D = np.zeros((self.max_expected_card + 1, zlen))
2548 tot_w_pred = sum(w_pred)
2549 for nn in range(0, self.max_expected_card + 1):
2550 terms0_E = np.zeros((min(zlen, nn) + 1))
2551 for jj in range(0, min(zlen, nn) + 1):
2552 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate)
2553 t2 = sum([np.log(x) for x in range(1, nn + 1)])
2554 t3 = -1 * sum([np.log(x) for x in range(1, nn - jj + 1)])
2555 t4 = (nn - jj) * np.log(self.prob_death)
2556 t5 = -jj * np.log(tot_w_pred)
2557 terms0_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj]
2558 ups0_E[nn] = np.sum(terms0_E)
2560 terms1_E = np.zeros((min(zlen, nn) + 1))
2561 for jj in range(0, min(zlen, nn) + 1):
2562 if nn >= jj + 1:
2563 t1 = -self.clutter_rate + (zlen - jj) * np.log(self.clutter_rate)
2564 t2 = sum([np.log(x) for x in range(1, nn + 1)])
2565 t3 = -1 * sum([np.log(x) for x in range(1, nn - (jj + 1) + 1)])
2566 t4 = (nn - (jj + 1)) * np.log(self.prob_death)
2567 t5 = -(jj + 1) * np.log(tot_w_pred)
2568 terms1_E[jj] = np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_E[jj]
2569 ups1_E[nn] = np.sum(terms1_E)
2571 if zlen != 0:
2572 terms1_D = np.zeros((min(zlen - 1, nn) + 1, zlen))
2573 for ell in range(1, zlen + 1):
2574 for jj in range(0, min((zlen - 1), nn) + 1):
2575 if nn >= jj + 1:
2576 t1 = -self.clutter_rate + ((zlen - 1) - jj) * np.log(
2577 self.clutter_rate
2578 )
2579 t2 = sum([np.log(x) for x in range(1, nn + 1)])
2580 t3 = -1 * sum(
2581 [np.log(x) for x in range(1, nn - (jj + 1) + 1)]
2582 )
2583 t4 = (nn - (jj + 1)) * np.log(self.prob_death)
2584 t5 = -(jj + 1) * np.log(tot_w_pred)
2585 terms1_D[jj, ell - 1] = (
2586 np.exp(t1 + t2 + t3 + t4 + t5) * esfvals_D[jj, ell - 1]
2587 )
2588 ups1_D[nn, :] = np.sum(terms1_D, axis=0)
2589 gmix = deepcopy(probDensity)
2590 w_update = (
2591 ((ups1_E.T @ self._card_dist) / (ups0_E.T @ self._card_dist))
2592 * self.prob_miss_detection
2593 * w_pred
2594 )
2596 old_filt_states = []
2597 gmix.weights = [x.item() for x in w_update]
2598 for jj, (m, c) in enumerate(zip(gmix.means, gmix.covariances)):
2599 m_list = []
2600 c_list = []
2601 for ff in range(0, len(self.filter.in_filt_list)):
2602 m_list.append(m)
2603 c_list.append(c)
2604 self.filter.initialize_states(
2605 m_list, c_list, init_weights=saved_filt_weights[jj]
2606 )
2607 old_filt_states.append(self.filter.save_filter_state())
2609 for ee in range(0, zlen):
2610 wt_1 = (
2611 (ups1_D[:, [ee]].T @ self._card_dist) / (ups0_E.T @ self._card_dist)
2612 ).reshape((1, 1))
2613 wt_2 = self.prob_detection * qz_temp[:, [ee]] / self.clutter_den * w_pred
2614 w_temp = wt_1 * wt_2
2615 for ww in range(0, w_temp.shape[0]):
2616 gmix.add_components(
2617 mean_temp[ee, :, ww].reshape((xdim, 1)),
2618 cov_temp[ee, ww, :, :],
2619 w_temp[ww].item(),
2620 )
2621 cdn_update = self._card_dist.copy()
2622 for ii in range(0, len(cdn_update)):
2623 cdn_update[ii] = ups0_E[ii] * self._card_dist[ii]
2624 self._card_dist = cdn_update / np.sum(cdn_update)
2625 # assumes predict is called before correct
2626 self._card_time_hist[-1] = (
2627 np.argmax(self._card_dist).item(),
2628 np.std(self._card_dist),
2629 )
2630 for filt_state in new_filter_states:
2631 old_filt_states.append(filt_state)
2632 self._filter_states = old_filt_states
2633 return gmix
2636class GeneralizedLabeledMultiBernoulli(RandomFiniteSetBase):
2637 """Delta-Generalized Labeled Multi-Bernoulli filter.
2639 Notes
2640 -----
2641 This is based on :cite:`Vo2013_LabeledRandomFiniteSetsandMultiObjectConjugatePriors`
2642 and :cite:`Vo2014_LabeledRandomFiniteSetsandtheBayesMultiTargetTrackingFilter`
2643 It does not account for agents spawned from existing tracks, only agents
2644 birthed from the given birth model.
2646 Attributes
2647 ----------
2648 req_births : int
2649 Number of requested birth hypotheses
2650 req_surv : int
2651 Number of requested surviving hypotheses
2652 req_upd : int
2653 Number of requested updated hypotheses
2654 gating_on : bool
2655 Determines if measurements are gated
2656 birth_terms : list
2657 List of tuples where the first element is a
2658 :py:class:`serums.models.GaussianMixture` and
2659 the second is the birth probability for that term
2660 prune_threshold : float
2661 Minimum association probability to keep when pruning
2662 max_hyps : int
2663 Maximum number of hypotheses to keep when capping
2664 decimal_places : int
2665 Number of decimal places to keep in label. The default is 2.
2666 save_measurements : bool
2667 Flag indicating if measurements should be saved. Useful for some extra
2668 plots.
2669 """
2671 class _TabEntry:
2672 def __init__(self):
2673 self.label = () # time step born, index of birth model born from
2674 self.distrib_weights_hist = [] # list of weights of the probDensity
2675 self.filt_states = [] # list of dictionaries from filters save function
2676 self.meas_assoc_hist = []  # list of indices into measurement list per time step
2680 self.state_hist = [] # list of lists of numpy arrays for each timestep
2681 self.cov_hist = []  # list of lists of numpy arrays for each timestep (or None)
2685 """ linear index corresponding to timestep, manually updated. Used
2686 to index things since timestep in label can have decimals."""
2687 self.time_index = None
2689 def setup(self, tab):
2690 """Use to avoid expensive deepcopy."""
2691 self.label = tab.label
2692 self.distrib_weights_hist = tab.distrib_weights_hist.copy()
2693 self.filt_states = deepcopy(tab.filt_states)
2694 self.meas_assoc_hist = tab.meas_assoc_hist.copy()
2696 self.state_hist = [s_lst.copy() for s_lst in tab.state_hist]
2698 self.cov_hist = [c_lst.copy() if c_lst else [] for c_lst in tab.cov_hist]
2702 self.time_index = tab.time_index
2704 return self
2706 class _HypothesisHelper:
2707 def __init__(self):
2708 self.assoc_prob = 0
2709 self.track_set = [] # indices in lookup table
2711 @property
2712 def num_tracks(self):
2713 return len(self.track_set)
2715 class _ExtractHistHelper:
2716 def __init__(self):
2717 self.label = ()
2718 self.meas_ind_hist = []
2719 self.b_time_index = None
2720 self.states = []
2721 self.covs = []
2723 def __init__(
2724 self,
2725 req_births=None,
2726 req_surv=None,
2727 req_upd=None,
2728 gating_on=False,
2729 prune_threshold=10**-15,
2730 max_hyps=3000,
2731 decimal_places=2,
2732 save_measurements=False,
2733 **kwargs,
2734 ):
2735 self.req_births = req_births
2736 self.req_surv = req_surv
2737 self.req_upd = req_upd
2738 self.gating_on = gating_on
2739 self.prune_threshold = prune_threshold
2740 self.max_hyps = max_hyps
2741 self.decimal_places = decimal_places
2742 self.save_measurements = save_measurements
2744 self._track_tab = [] # list of all possible tracks
2745 self._labels = [] # local copy for internal modification
2746 self._extractable_hists = []
2748 self._filter = None
2749 self._baseFilter = None
2751 hyp0 = self._HypothesisHelper()
2752 hyp0.assoc_prob = 1
2753 hyp0.track_set = []
2754 self._hypotheses = [hyp0] # list of _HypothesisHelper objects
2756 self._card_dist = [] # probability of having index # as cardinality
2758 """ linear index corresponding to timestep, manually updated. Used
2759 to index things since timestep in label can have decimals. Must
2760 be updated once per time step."""
2761 self._time_index_cntr = 0
2763 self.ospa2 = None
2764 self.ospa2_localization = None
2765 self.ospa2_cardinality = None
2766 self._ospa2_params = {}
2768 super().__init__(**kwargs)
2769 self._states = [[]]
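# Construction sketch (hypothetical setup): `kf` is a configured
# gncpy.filters.KalmanFilter and `birth_gm` a serums.models.GaussianMixture.
#
#     glmb = GeneralizedLabeledMultiBernoulli(
#         req_births=5, req_surv=1000, req_upd=800,
#         in_filter=kf, birth_terms=[(birth_gm, 0.02)],
#         prob_detection=0.98, prob_survive=0.99,
#         clutter_rate=10, clutter_den=1 / 4e6,
#     )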
2771 def save_filter_state(self):
2772 """Saves filter variables so they can be restored later.
2774 Note that to pickle the resulting dictionary the :code:`dill` package
2775 may need to be used due to potential pickling of functions.
2776 """
2777 filt_state = super().save_filter_state()
2779 filt_state["req_births"] = self.req_births
2780 filt_state["req_surv"] = self.req_surv
2781 filt_state["req_upd"] = self.req_upd
2782 filt_state["gating_on"] = self.gating_on
2783 filt_state["prune_threshold"] = self.prune_threshold
2784 filt_state["max_hyps"] = self.max_hyps
2785 filt_state["decimal_places"] = self.decimal_places
2786 filt_state["save_measurements"] = self.save_measurements
2788 filt_state["_track_tab"] = self._track_tab
2789 filt_state["_labels"] = self._labels
2790 filt_state["_extractable_hists"] = self._extractable_hists
2792 if self._baseFilter is not None:
2793 filt_state["_baseFilter"] = (
2794 type(self._baseFilter),
2795 self._baseFilter.save_filter_state(),
2796 )
2797 else:
2798 filt_state["_baseFilter"] = (None, self._baseFilter)
2799 filt_state["_hypotheses"] = self._hypotheses
2800 filt_state["_card_dist"] = self._card_dist
2801 filt_state["_time_index_cntr"] = self._time_index_cntr
2803 filt_state["ospa2"] = self.ospa2
2804 filt_state["ospa2_localization"] = self.ospa2_localization
2805 filt_state["ospa2_cardinality"] = self.ospa2_cardinality
2806 filt_state["_ospa2_params"] = self._ospa_params
2808 return filt_state
2810 def load_filter_state(self, filt_state):
2811 """Initializes filter using saved filter state.
2813 Parameters
2814 ----------
2815 filt_state : dict
2816 Dictionary generated by :meth:`save_filter_state`.
2817 """
2818 super().load_filter_state(filt_state)
2820 self.req_births = filt_state["req_births"]
2821 self.req_surv = filt_state["req_surv"]
2822 self.req_upd = filt_state["req_upd"]
2823 self.gating_on = filt_state["gating_on"]
2824 self.prune_threshold = filt_state["prune_threshold"]
2825 self.max_hyps = filt_state["max_hyps"]
2826 self.decimal_places = filt_state["decimal_places"]
2827 self.save_measurements = filt_state["save_measurements"]
2829 self._track_tab = filt_state["_track_tab"]
2830 self._labels = filt_state["_labels"]
2831 self._extractable_hists = filt_state["_extractable_hists"]
2833 cls_type = filt_state["_baseFilter"][0]
2834 if cls_type is not None:
2835 self._baseFilter = cls_type()
2836 self._baseFilter.load_filter_state(filt_state["_baseFilter"][1])
2837 else:
2838 self._baseFilter = None
2839 self._hypotheses = filt_state["_hypotheses"]
2840 self._card_dist = filt_state["_card_dist"]
2841 self._time_index_cntr = filt_state["_time_index_cntr"]
2843 self.ospa2 = filt_state["ospa2"]
2844 self.ospa2_localization = filt_state["ospa2_localization"]
2845 self.ospa2_cardinality = filt_state["ospa2_cardinality"]
2846 self._ospa2_params = filt_state["_ospa2_params"]
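# Save/restore sketch: the saved dictionary can hold functions, so dill
# rather than pickle may be needed when persisting to disk (see
# save_filter_state). File name is illustrative only.
#
#     import dill
#     with open("glmb_ckpt.bin", "wb") as fout:
#         dill.dump(glmb.save_filter_state(), fout)
#     restored = GeneralizedLabeledMultiBernoulli()
#     with open("glmb_ckpt.bin", "rb") as fin:
#         restored.load_filter_state(dill.load(fin))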
2848 @property
2849 def states(self):
2850 """Read only list of extracted states.
2852 This is a list with 1 element per timestep, and each element is a list
2853 of the best states extracted at that timestep. The order of each
2854 element corresponds to the label order.
2855 """
2856 return self._states
2858 @property
2859 def labels(self):
2860 """Read only list of extracted labels.
2862 This is a list with 1 element per timestep, and each element is a list
2863 of the best labels extracted at that timestep. The order of each
2864 element corresponds to the state order.
2865 """
2866 return self._labels
2868 @property
2869 def covariances(self):
2870 """Read only list of extracted covariances.
2872 This is a list with 1 element per timestep, and each element is a list
2873 of the best covariances extracted at that timestep. The order of each
2874 element corresponds to the state order.
2876 Warns
2877 -----
2878 RuntimeWarning
2879 If the class is not saving the covariances; an empty list is returned.
2880 """
2881 if not self.save_covs:
2882 warnings.warn("Not saving covariances", RuntimeWarning)
2883 return []
2884 return self._covs
2886 @property
2887 def filter(self):
2888 """Inner filter handling dynamics, must be a gncpy.filters.BayesFilter."""
2889 return self._filter
2891 @filter.setter
2892 def filter(self, val):
2893 self._baseFilter = deepcopy(val)
2894 self._filter = val
2896 @property
2897 def cardinality(self):
2898 """Cardinality estimate."""
2899 return np.argmax(self._card_dist)
2901 def _init_filt_states(self, distrib):
2902 filt_states = [None] * len(distrib.means)
2903 states = [m.copy() for m in distrib.means]
2904 if self.save_covs:
2905 covs = [c.copy() for c in distrib.covariances]
2906 else:
2907 covs = []
2908 weights = distrib.weights.copy()
2909 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)):
2910 self._baseFilter.cov = cov.copy()
2911 if isinstance(self._baseFilter, gfilts.UnscentedKalmanFilter) or isinstance(
2912 self._baseFilter, gfilts.UKFGaussianScaleMixtureFilter
2913 ):
2914 self._baseFilter.init_sigma_points(m)
2915 filt_states[ii] = self._baseFilter.save_filter_state()
2916 return filt_states, weights, states, covs
2918 def _gen_birth_tab(self, timestep):
2919 log_cost = []
2920 birth_tab = []
2921 for ii, (distrib, p) in enumerate(self.birth_terms):
2922 cost = p / (1 - p)
2923 log_cost.append(-np.log(cost))
2924 entry = self._TabEntry()
2925 entry.state_hist = [None]
2926 entry.cov_hist = [None]
2927 entry.distrib_weights_hist = [None]
2928 (
2929 entry.filt_states,
2930 entry.distrib_weights_hist[0],
2931 entry.state_hist[0],
2932 entry.cov_hist[0],
2933 ) = self._init_filt_states(distrib)
2934 entry.label = (round(timestep, self.decimal_places), ii)
2935 entry.time_index = self._time_index_cntr
2936 birth_tab.append(entry)
2937 return birth_tab, log_cost
2939 def _gen_birth_hyps(self, paths, hyp_costs):
2940 birth_hyps = []
2941 tot_b_prob = sum([np.log(1 - x[1]) for x in self.birth_terms])
2942 for p, c in zip(paths, hyp_costs):
2943 hyp = self._HypothesisHelper()
2944 # NOTE: this may suffer from underflow and can be improved
2945 hyp.assoc_prob = tot_b_prob - c.item()
2946 hyp.track_set = p
2947 birth_hyps.append(hyp)
2948 lse = log_sum_exp([x.assoc_prob for x in birth_hyps])
2949 for ii in range(0, len(birth_hyps)):
2950 birth_hyps[ii].assoc_prob = np.exp(birth_hyps[ii].assoc_prob - lse)
2951 return birth_hyps
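# Normalization note: exp(a_i - log_sum_exp(a)) equals
# exp(a_i) / sum_j exp(a_j) but stays finite for large negative
# log-probabilities, e.g. a = [-1000, -1001] normalizes to roughly
# [0.731, 0.269] instead of underflowing to 0 / 0.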
2953 def _inner_predict(self, timestep, filt_state, state, filt_args):
2954 self.filter.load_filter_state(filt_state)
2955 new_s = self.filter.predict(timestep, state, **filt_args)
2956 new_f_state = self.filter.save_filter_state()
2957 if self.save_covs:
2958 new_cov = self.filter.cov.copy()
2959 else:
2960 new_cov = None
2961 return new_f_state, new_s, new_cov
2963 def _predict_track_tab_entry(self, tab, timestep, filt_args):
2964 """Updates table entries probability density."""
2965 newTab = self._TabEntry().setup(tab)
2966 new_f_states = [None] * len(newTab.filt_states)
2967 new_s_hist = [None] * len(newTab.filt_states)
2968 new_c_hist = [None] * len(newTab.filt_states)
2969 for ii, (f_state, state) in enumerate(
2970 zip(newTab.filt_states, newTab.state_hist[-1])
2971 ):
2972 (new_f_states[ii], new_s_hist[ii], new_c_hist[ii]) = self._inner_predict(
2973 timestep, f_state, state, filt_args
2974 )
2975 newTab.filt_states = new_f_states
2976 newTab.state_hist.append(new_s_hist)
2977 newTab.cov_hist.append(new_c_hist)
2978 newTab.distrib_weights_hist.append(newTab.distrib_weights_hist[-1].copy())
2979 return newTab
2981 def _gen_surv_tab(self, timestep, filt_args):
2982 surv_tab = []
2983 for ii, track in enumerate(self._track_tab):
2984 entry = self._predict_track_tab_entry(track, timestep, filt_args)
2986 surv_tab.append(entry)
2987 return surv_tab
2989 def _gen_surv_hyps(self, avg_prob_survive, avg_prob_death):
2990 surv_hyps = []
2991 sum_sqrt_w = 0
2992 # avg_prob_mm =
2993 for hyp in self._hypotheses:
2994 sum_sqrt_w = sum_sqrt_w + np.sqrt(hyp.assoc_prob)
2995 for hyp in self._hypotheses:
2996 if hyp.num_tracks == 0:
2997 new_hyp = self._HypothesisHelper()
2998 new_hyp.assoc_prob = np.log(hyp.assoc_prob)
2999 new_hyp.track_set = hyp.track_set
3000 surv_hyps.append(new_hyp)
3001 else:
3002 cost = avg_prob_survive[hyp.track_set] / avg_prob_death[hyp.track_set]
3003 log_cost = -np.log(cost) # this is length hyp.num_tracks
3004 k = np.round(self.req_surv * np.sqrt(hyp.assoc_prob) / sum_sqrt_w)
3005 (paths, hyp_cost) = k_shortest(np.array(log_cost), k)
3007 pdeath_log = np.sum(
3008 [np.log(avg_prob_death[ii]) for ii in hyp.track_set]
3009 )
3011 for p, c in zip(paths, hyp_cost):
3012 new_hyp = self._HypothesisHelper()
3013 new_hyp.assoc_prob = pdeath_log + np.log(hyp.assoc_prob) - c.item()
3014 if len(p) > 0:
3015 new_hyp.track_set = [hyp.track_set[ii] for ii in p]
3016 else:
3017 new_hyp.track_set = []
3018 surv_hyps.append(new_hyp)
3019 lse = log_sum_exp([x.assoc_prob for x in surv_hyps])
3020 for ii in range(0, len(surv_hyps)):
3021 surv_hyps[ii].assoc_prob = np.exp(surv_hyps[ii].assoc_prob - lse)
3022 return surv_hyps
3024 def _calc_avg_prob_surv_death(self):
3025 avg_prob_survive = self.prob_survive * np.ones(len(self._track_tab))
3026 avg_prob_death = 1 - avg_prob_survive
3028 return avg_prob_survive, avg_prob_death
3030 def _set_pred_hyps(self, birth_tab, birth_hyps, surv_hyps):
3031 self._hypotheses = []
3032 tot_w = 0
3033 for b_hyp in birth_hyps:
3034 for s_hyp in surv_hyps:
3035 new_hyp = self._HypothesisHelper()
3036 new_hyp.assoc_prob = b_hyp.assoc_prob * s_hyp.assoc_prob
3037 tot_w = tot_w + new_hyp.assoc_prob
3038 surv_lst = []
3039 for x in s_hyp.track_set:
3040 surv_lst.append(x + len(birth_tab))
3041 new_hyp.track_set = b_hyp.track_set + surv_lst
3042 self._hypotheses.append(new_hyp)
3043 for ii in range(0, len(self._hypotheses)):
3044 n_val = self._hypotheses[ii].assoc_prob / tot_w
3045 self._hypotheses[ii].assoc_prob = n_val
3047 def _calc_card_dist(self, hyp_lst):
3048 """Calucaltes the cardinality distribution."""
3049 if len(hyp_lst) == 0:
3050 return [
3051 1,
3052 ]
3053 card_dist = []
3054 for ii in range(0, max(map(lambda x: x.num_tracks, hyp_lst)) + 1):
3055 card = 0
3056 for hyp in hyp_lst:
3057 if hyp.num_tracks == ii:
3058 card = card + hyp.assoc_prob
3059 card_dist.append(card)
3060 return card_dist
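# Worked example: hypotheses with (num_tracks, assoc_prob) of (0, 0.2),
# (1, 0.5), and (1, 0.3) give card_dist = [0.2, 0.8], so the MAP cardinality
# estimate (argmax) is 1.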
3062 def _clean_predictions(self):
3063 hash_lst = []
3064 for hyp in self._hypotheses:
3065 if len(hyp.track_set) == 0:
3066 lst = []
3067 else:
3068 sorted_inds = hyp.track_set.copy()
3069 sorted_inds.sort()
3070 lst = [int(x) for x in sorted_inds]
3071 h = hash("*".join(map(str, lst)))
3072 hash_lst.append(h)
3073 new_hyps = []
3074 used_hash = []
3075 for ii, h in enumerate(hash_lst):
3076 if h not in used_hash:
3077 used_hash.append(h)
3078 new_hyps.append(self._hypotheses[ii])
3079 else:
3080 new_ii = used_hash.index(h)
3081 new_hyps[new_ii].assoc_prob += self._hypotheses[ii].assoc_prob
3082 self._hypotheses = new_hyps
3084 def predict(self, timestep, filt_args={}):
3085 """Prediction step of the GLMB filter.
3087 This predicts new hypotheses and propagates them to the next time
3088 step. It also updates the cardinality distribution.
3090 Parameters
3091 ----------
3092 timestep: float
3093 Current timestep.
3094 filt_args : dict, optional
3095 Passed to the inner filter. The default is {}.
3097 Returns
3098 -------
3099 None.
3100 """
3101 # Find cost for each birth track, and setup lookup table
3102 birth_tab, log_cost = self._gen_birth_tab(timestep)
3104 # get K best hypotheses, and their indices in the lookup table
3105 (paths, hyp_costs) = k_shortest(np.array(log_cost), self.req_births)
3107 # calculate association probabilities for birth hypotheses
3108 birth_hyps = self._gen_birth_hyps(paths, hyp_costs)
3110 # Init and propagate surviving track table
3111 surv_tab = self._gen_surv_tab(timestep, filt_args)
3113 # Calculation for average survival/death probabilities
3114 (avg_prob_survive, avg_prob_death) = self._calc_avg_prob_surv_death()
3116 # loop over posterior components
3117 surv_hyps = self._gen_surv_hyps(avg_prob_survive, avg_prob_death)
3119 self._card_dist = self._calc_card_dist(surv_hyps)
3121 # Get predicted hypotheses by convolution
3122 self._track_tab = birth_tab + surv_tab
3123 self._set_pred_hyps(birth_tab, birth_hyps, surv_hyps)
3125 self._clean_predictions()
3127 def _inner_correct(
3128 self, timestep, meas, filt_state, distrib_weight, state, filt_args
3129 ):
3130 self.filter.load_filter_state(filt_state)
3131 cor_state, likely = self.filter.correct(timestep, meas, state, **filt_args)
3132 new_f_state = self.filter.save_filter_state()
3133 new_s = cor_state
3134 if self.save_covs:
3135 new_c = self.filter.cov.copy()
3136 else:
3137 new_c = None
3138 new_w = distrib_weight * likely
3140 return new_f_state, new_s, new_c, new_w
3142 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args):
3143 newTab = self._TabEntry().setup(tab)
3144 new_f_states = [None] * len(newTab.filt_states)
3145 new_s_hist = [None] * len(newTab.filt_states)
3146 new_c_hist = [None] * len(newTab.filt_states)
3147 new_w = [None] * len(newTab.filt_states)
3148 depleted = False
3149 for ii, (f_state, state, w) in enumerate(
3150 zip(
3151 newTab.filt_states,
3152 newTab.state_hist[-1],
3153 newTab.distrib_weights_hist[-1],
3154 )
3155 ):
3156 try:
3157 (
3158 new_f_states[ii],
3159 new_s_hist[ii],
3160 new_c_hist[ii],
3161 new_w[ii],
3162 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args)
3163 except (
3164 gerr.ParticleDepletionError,
3165 gerr.ParticleEstimationDomainError,
3166 gerr.ExtremeMeasurementNoiseError,
3167 ):
3168 return None, 0
3169 newTab.filt_states = new_f_states
3170 newTab.state_hist[-1] = new_s_hist
3171 newTab.cov_hist[-1] = new_c_hist
3172 new_w = [w + np.finfo(float).eps for w in new_w]
3173 if not depleted:
3174 cost = np.sum(new_w).item()
3175 newTab.distrib_weights_hist[-1] = [w / cost for w in new_w]
3176 else:
3177 cost = 0
3178 return newTab, cost
3180 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args):
3181 num_pred = len(self._track_tab)
3182 up_tab = [None] * (num_meas + 1) * num_pred
3184 for ii, track in enumerate(self._track_tab):
3185 up_tab[ii] = self._TabEntry().setup(track)
3186 up_tab[ii].meas_assoc_hist.append(None)
3187 # measurement updated tracks
3188 all_cost_m = np.zeros((num_pred, num_meas))
3189 for emm, z in enumerate(meas):
3190 for ii, ent in enumerate(self._track_tab):
3191 s_to_ii = num_pred * emm + ii + num_pred
3192 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
3193 z, ent, timestep, filt_args
3194 )
3196 # update association history with current measurement index
3197 if up_tab[s_to_ii] is not None:
3198 up_tab[s_to_ii].meas_assoc_hist.append(emm)
3199 all_cost_m[ii, emm] = cost
3200 return up_tab, all_cost_m
3202 def _gen_cor_hyps(
3203 self, num_meas, avg_prob_detect, avg_prob_miss_detect, all_cost_m
3204 ):
3205 num_pred = len(self._track_tab)
3206 up_hyps = []
3207 if num_meas == 0:
3208 for hyp in self._hypotheses:
3209 pmd_log = np.sum(
3210 [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
3211 )
3212 hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
3213 up_hyps.append(hyp)
3214 else:
3215 clutter = self.clutter_rate * self.clutter_den
3216 ss_w = 0
3217 for p_hyp in self._hypotheses:
3218 ss_w += np.sqrt(p_hyp.assoc_prob)
3219 for p_hyp in self._hypotheses:
3220 if p_hyp.num_tracks == 0: # all clutter
3221 new_hyp = self._HypothesisHelper()
3222 new_hyp.assoc_prob = (
3223 -self.clutter_rate
3224 + num_meas * np.log(clutter)
3225 + np.log(p_hyp.assoc_prob)
3226 )
3227 new_hyp.track_set = p_hyp.track_set.copy()
3228 up_hyps.append(new_hyp)
3229 else:
3230 pd = np.array([avg_prob_detect[ii] for ii in p_hyp.track_set])
3231 pmd = np.array([avg_prob_miss_detect[ii] for ii in p_hyp.track_set])
3232 ratio = pd / pmd
3234 ratio = ratio.reshape((ratio.size, 1))
3235 ratio = np.tile(ratio, (1, num_meas))
3237 cost_m = np.zeros(all_cost_m[p_hyp.track_set, :].shape)
3238 for ii, ts in enumerate(p_hyp.track_set):
3239 cost_m[ii, :] = ratio[ii] * all_cost_m[ts, :] / clutter
3240 max_row_inds, max_col_inds = np.where(cost_m >= np.inf)
3241 if max_row_inds.size > 0:
3242 cost_m[max_row_inds, max_col_inds] = np.finfo(float).max
3243 min_row_inds, min_col_inds = np.where(cost_m <= 0.0)
3244 if min_row_inds.size > 0:
3245 cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps # 1
3246 neg_log = -np.log(cost_m)
3247 # if max_row_inds.size > 0:
3248 # neg_log[max_row_inds, max_col_inds] = -np.inf
3249 # if min_row_inds.size > 0:
3250 # neg_log[min_row_inds, min_col_inds] = np.inf
3252 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
3253 m = int(m.item())
3254 [assigns, costs] = murty_m_best(neg_log, m)
3256 pmd_log = np.sum(
3257 [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
3258 )
3259 for a, c in zip(assigns, costs):
3260 new_hyp = self._HypothesisHelper()
3261 new_hyp.assoc_prob = (
3262 -self.clutter_rate
3263 + num_meas * np.log(clutter)
3264 + pmd_log
3265 + np.log(p_hyp.assoc_prob)
3266 - c
3267 )
3268 new_hyp.track_set = list(
3269 np.array(p_hyp.track_set) + num_pred * a
3270 )
3271 up_hyps.append(new_hyp)
3272 lse = log_sum_exp([x.assoc_prob for x in up_hyps])
3273 for ii in range(0, len(up_hyps)):
3274 up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse)
3275 return up_hyps
3277 def _calc_avg_prob_det_mdet(self):
3278 avg_prob_detect = self.prob_detection * np.ones(len(self._track_tab))
3279 avg_prob_miss_detect = 1 - avg_prob_detect
3281 return avg_prob_detect, avg_prob_miss_detect
3283 def _clean_updates(self):
3284 used = [0] * len(self._track_tab)
3285 for hyp in self._hypotheses:
3286 for ii in hyp.track_set:
3287 if self._track_tab[ii] is not None:
3288 used[ii] += 1
3289 nnz_inds = [idx for idx, val in enumerate(used) if val != 0]
3290 track_cnt = len(nnz_inds)
3292 new_inds = [None] * len(self._track_tab)
3293 for ii, v in zip(nnz_inds, [ii for ii in range(0, track_cnt)]):
3294 new_inds[ii] = v
3295 # new_tab = [self._TabEntry().setup(self._track_tab[ii]) for ii in nnz_inds]
3296 new_tab = [self._track_tab[ii] for ii in nnz_inds]
3297 new_hyps = []
3298 for ii, hyp in enumerate(self._hypotheses):
3299 if len(hyp.track_set) > 0:
3300 track_set = [new_inds[ii] for ii in hyp.track_set]
3301 if None in track_set:
3302 continue
3303 hyp.track_set = track_set
3304 new_hyps.append(hyp)
3305 self._track_tab = new_tab
3306 self._hypotheses = new_hyps
3308 def correct(self, timestep, meas, filt_args={}):
3309 """Correction step of the GLMB filter.
3311 Notes
3312 -----
3313 This corrects the hypotheses based on the measurements and gates the
3314 measurements according to the class settings. It also updates the
3315 cardinality distribution.
3317 Parameters
3318 ----------
3319 timestep: float
3320 Current timestep.
3321 meas : list
3322 List of Nm x 1 numpy arrays, each representing a measurement.
3323 filt_args : dict, optional
3324 keyword arguments to pass to the inner filters correct function.
3325 The default is {}.
3327 .. todo::
3328 Fix the measurement gating
3330 Returns
3331 -------
3332 None
3333 """
3334 # gate measurements by tracks
3335 if self.gating_on:
3336 warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
3337 # means = []
3338 # covs = []
3339 # for ent in self._track_tab:
3340 # means.extend(ent.probDensity.means)
3341 # covs.extend(ent.probDensity.covariances)
3342 # meas = self._gate_meas(meas, means, covs)
3343 if self.save_measurements:
3344 self._meas_tab.append(deepcopy(meas))
3345 num_meas = len(meas)
3347 # missed detection tracks
3348 cor_tab, all_cost_m = self._gen_cor_tab(num_meas, meas, timestep, filt_args)
3350 # Calculation for average detection/missed probabilities
3351 avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet()
3353 # component updates
3354 cor_hyps = self._gen_cor_hyps(num_meas, avg_prob_det, avg_prob_mdet, all_cost_m)
3356 # save values and cleanup
3357 self._track_tab = cor_tab
3358 self._hypotheses = cor_hyps
3359 self._card_dist = self._calc_card_dist(self._hypotheses)
3360 self._clean_updates()
3362 def _extract_helper(self, track):
3363 states = [None] * len(track.state_hist)
3364 covs = [None] * len(track.state_hist)
3365 for ii, (w_lst, s_lst, c_lst) in enumerate(
3366 zip(track.distrib_weights_hist, track.state_hist, track.cov_hist)
3367 ):
3368 idx = np.argmax(w_lst)
3369 states[ii] = s_lst[idx]
3370 if self.save_covs:
3371 covs[ii] = c_lst[idx]
3372 return states, covs
3374 def _update_extract_hist(self, idx_cmp):
3375 used_meas_inds = [[] for ii in range(self._time_index_cntr)]
3376 used_labels = []
3377 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set)
3378 for ii, track in enumerate(
3379 [
3380 self._track_tab[trk_ind]
3381 for trk_ind in self._hypotheses[idx_cmp].track_set
3382 ]
3383 ):
3384 new_extract_hists[ii] = self._ExtractHistHelper()
3385 new_extract_hists[ii].label = track.label
3386 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy()
3387 new_extract_hists[ii].b_time_index = track.time_index
3388 (
3389 new_extract_hists[ii].states,
3390 new_extract_hists[ii].covs,
3391 ) = self._extract_helper(track)
3393 used_labels.append(track.label)
3395 for t_inds_after_b, meas_ind in enumerate(
3396 new_extract_hists[ii].meas_ind_hist
3397 ):
3398 tt = new_extract_hists[ii].b_time_index + t_inds_after_b
3399 if meas_ind is not None and meas_ind not in used_meas_inds[tt]:
3400 used_meas_inds[tt].append(meas_ind)
3401 good_inds = []
3402 for ii, existing in enumerate(self._extractable_hists):
3403 used = existing.label in used_labels
3404 if used:
3405 continue
3406 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist):
3407 tt = existing.b_time_index + t_inds_after_b
3408 used = meas_ind is not None and meas_ind in used_meas_inds[tt]
3409 if used:
3410 break
3411 if not used:
3412 good_inds.append(ii)
3413 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds]
3414 self._extractable_hists.extend(new_extract_hists)
3416 def extract_states(self, update=True, calc_states=True):
3417 """Extracts the best state estimates.
3419 This extracts the best states from the distribution. It should be
3420 called once per time step after the correction function.
3425 Parameters
3426 ----------
3427 update : bool, optional
3428 Flag indicating if the label history should be updated. This should
3429 be done once per timestep and can be disabled if calculating states
3430 after the final timestep. The default is True.
3431 calc_states : bool, optional
3432 Flag indicating if the states should be calculated based on the
3433 label history. This only needs to be done before the states are used.
3434 It can simply be called once after the end of the simulation. The
3435 default is True.
3437 Returns
3438 -------
3439 idx_cmp : int
3440 Index of the hypothesis table used when extracting states.
3441 """
3442 card = np.argmax(self._card_dist)
3443 tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses])
3444 weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses])
3446 self._states = [[] for ii in range(self._time_index_cntr)]
3447 self._labels = [[] for ii in range(self._time_index_cntr)]
3448 self._covs = [[] for ii in range(self._time_index_cntr)]
3450 if len(tracks_per_hyp) == 0:
3451 return None
3452 idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card))
3453 if update:
3454 self._update_extract_hist(idx_cmp)
3455 if calc_states:
3456 for existing in self._extractable_hists:
3457 for t_inds_after_b, (s, c) in enumerate(
3458 zip(existing.states, existing.covs)
3459 ):
3460 tt = existing.b_time_index + t_inds_after_b
3461 # if len(self._labels[tt]) == 0:
3462 # self._states[tt] = [s]
3463 # self._labels[tt] = [existing.label]
3464 # self._covs[tt] = [c]
3465 # else:
3466 self._states[tt].append(s)
3467 self._labels[tt].append(existing.label)
3468 self._covs[tt].append(c)
3469 if not update and not calc_states:
3470 warnings.warn("Extracting states performed no actions")
3471 return idx_cmp
3473 def extract_most_prob_states(self, thresh):
3474 """Extracts the most probable hypotheses up to a threshold.
3476 Parameters
3477 ----------
3478 thresh : float
3479 Minimum association probability to extract.
3481 Returns
3482 -------
3483 state_sets : list
3484 Each element is the state list from the normal
3485 :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`.
3486 label_sets : list
3487 Each element is the label list from the normal
3488 :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
3489 cov_sets : list
3490 Each element is the covariance list from the normal
3491 :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
3492 if the covariances are saved.
3493 probs : list
3494 Each element is the association probability for the extracted states.
3495 """
3496 loc_self = deepcopy(self)
3497 state_sets = []
3498 cov_sets = []
3499 label_sets = []
3500 probs = []
3502 idx = loc_self.extract_states()
3503 if idx is None:
3504 return (state_sets, label_sets, cov_sets, probs)
3505 state_sets.append(loc_self.states.copy())
3506 label_sets.append(loc_self.labels.copy())
3507 if loc_self.save_covs:
3508 cov_sets.append(loc_self.covariances.copy())
3509 probs.append(loc_self._hypotheses[idx].assoc_prob)
3510 loc_self._hypotheses[idx].assoc_prob = 0
3511 while True:
3512 idx = loc_self.extract_states()
3513 if idx is None:
3514 break
3515 if loc_self._hypotheses[idx].assoc_prob >= thresh:
3516 state_sets.append(loc_self.states.copy())
3517 label_sets.append(loc_self.labels.copy())
3518 if loc_self.save_covs:
3519 cov_sets.append(loc_self.covariances.copy())
3520 probs.append(loc_self._hypotheses[idx].assoc_prob)
3521 loc_self._hypotheses[idx].assoc_prob = 0
3522 else:
3523 break
3524 return (state_sets, label_sets, cov_sets, probs)
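# Usage sketch: extract every hypothesis with at least 1% association
# probability. The method works on a deepcopy, so the filter state itself
# is untouched.
#
#     state_sets, label_sets, cov_sets, probs = (
#         glmb.extract_most_prob_states(0.01)
#     )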
3526 def _prune(self):
3527 """Removes hypotheses below a threshold.
3529 This should be called once per time step after the correction and
3530 before the state extraction.
3531 """
3532 # Find hypotheses with low association probabilities
3533 temp_assoc_probs = np.array([])
3534 for ii in range(0, len(self._hypotheses)):
3535 temp_assoc_probs = np.append(
3536 temp_assoc_probs, self._hypotheses[ii].assoc_prob
3537 )
3538 keep_indices = np.argwhere(temp_assoc_probs > self.prune_threshold).T
3539 keep_indices = keep_indices.flatten()
3541 # For re-weighing association probabilities
3542 new_sum = np.sum(temp_assoc_probs[keep_indices])
3543 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices]
3544 for ii in range(0, len(keep_indices)):
3545 self._hypotheses[ii].assoc_prob = self._hypotheses[ii].assoc_prob / new_sum
3546 # Re-calculate cardinality
3547 self._card_dist = self._calc_card_dist(self._hypotheses)
3549 def _cap(self):
3550 """Removes least likely hypotheses until a maximum number is reached.
3552 This should be called once per time step after pruning and
3553 before the state extraction.
3554 """
3555 # Determine if there are too many hypotheses
3556 if len(self._hypotheses) > self.max_hyps:
3557 temp_assoc_probs = np.array([])
3558 for ii in range(0, len(self._hypotheses)):
3559 temp_assoc_probs = np.append(
3560 temp_assoc_probs, self._hypotheses[ii].assoc_prob
3561 )
3562 sorted_indices = np.argsort(temp_assoc_probs)
3564 # Reverse order to get descending array
3565 sorted_indices = sorted_indices[::-1]
3567 # Take the top n assoc_probs, where n = max_hyps
3568 keep_indices = np.array([], dtype=np.int64)
3569 for ii in range(0, self.max_hyps):
3570 keep_indices = np.append(keep_indices, int(sorted_indices[ii]))
3571 # Assign to class
3572 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices]
3574 # Normalize association probabilities
3575 new_sum = 0
3576 for ii in range(0, len(self._hypotheses)):
3577 new_sum = new_sum + self._hypotheses[ii].assoc_prob
3578 for ii in range(0, len(self._hypotheses)):
3579 self._hypotheses[ii].assoc_prob = (
3580 self._hypotheses[ii].assoc_prob / new_sum
3581 )
3582 # Re-calculate cardinality
3583 self._card_dist = self._calc_card_dist(self._hypotheses)
3585 def cleanup(
3586 self,
3587 enable_prune=True,
3588 enable_cap=True,
3589 enable_extract=True,
3590 extract_kwargs=None,
3591 ):
3592 """Performs the cleanup step of the filter.
3594 This can prune, cap, and extract states. It must be called once per
3595 timestep, even if all three functions are disabled. This is to ensure
3596 that internal counters for tracking linear timestep indices are properly
3597 incremented. If this is called with `enable_extract` set to true then
3598 the extract states method does not need to be called separately. It is
3599 recommended to call this function instead of
3600 :meth:`carbs.swarm_estimator.tracker.GeneralizedLabeledMultiBernoulli.extract_states`
3601 directly.
3603 Parameters
3604 ----------
3605 enable_prune : bool, optional
3606 Flag indicating if pruning should be performed. The default is True.
3607 enable_cap : bool, optional
3608 Flag indicating if capping should be performed. The default is True.
3609 enable_extract : bool, optional
3610 Flag indicating if state extraction should be performed. The default is True.
3611 extract_kwargs : dict, optional
3612 Additional arguments to pass to :meth:`.extract_states`. The
3613 default is None. Only used if extracting states.
3615 Returns
3616 -------
3617 None.
3619 """
3620 self._time_index_cntr += 1
3622 if enable_prune:
3623 self._prune()
3624 if enable_cap:
3625 self._cap()
3626 if enable_extract:
3627 if extract_kwargs is None:
3628 extract_kwargs = {}
3629 self.extract_states(**extract_kwargs)
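# Typical per-timestep loop (sketch; `times` and `meas_lists` are assumed
# inputs). cleanup must run every step so the internal time index counter
# stays in sync, even if pruning, capping, and extraction are all disabled.
#
#     for tk, meas_k in zip(times, meas_lists):
#         glmb.predict(tk)
#         glmb.correct(tk, meas_k)
#         glmb.cleanup()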
3631 def _ospa_setup_emat(self, state_dim, state_inds):
3632 # get sizes
3633 num_timesteps = len(self.states)
3634 num_objs = 0
3635 lbl_to_ind = {}
3637 for lst in self.labels:
3638 for lbl in lst:
3639 if lbl is None:
3640 continue
3641 key = str(lbl)
3642 if key not in lbl_to_ind:
3643 lbl_to_ind[key] = num_objs
3644 num_objs += 1
3645 # create matrices
3646 est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
3647 est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))
3649 for tt, (lbl_lst, s_lst) in enumerate(zip(self.labels, self.states)):
3650 for lbl, s in zip(lbl_lst, s_lst):
3651 if lbl is None:
3652 continue
3653 obj_num = lbl_to_ind[str(lbl)]
3654 est_mat[:, tt, obj_num] = s.ravel()[state_inds]
3655 if self.save_covs:
3656 for tt, (lbl_lst, c_lst) in enumerate(zip(self.labels, self.covariances)):
3657 for lbl, c in zip(lbl_lst, c_lst):
3658 if lbl is None:
3659 continue
3660 est_cov_mat[:, :, tt, lbl_to_ind[str(lbl)]] = c[state_inds][
3661 :, state_inds
3662 ]
3663 return est_mat, est_cov_mat
3665 def calculate_ospa2(
3666 self,
3667 truth,
3668 c,
3669 p,
3670 win_len,
3671 true_covs=None,
3672 core_method=SingleObjectDistance.MANHATTAN,
3673 state_inds=None,
3674 ):
3675 """Calculates the OSPA(2) distance between the truth at all timesteps.
3677 Wrapper for :func:`serums.distances.calculate_ospa2`.
3679 Parameters
3680 ----------
3681 truth : list
3682 Each element represents a timestep and is a list of N x 1 numpy array,
3683 one per true agent in the swarm.
3684 c : float
3685 Distance cutoff for considering a point properly assigned. This
3686 influences how cardinality errors are penalized. For :math:`p = 1`
3687 it is the penalty given to a false point estimate.
3688 p : int
3689 The power of the distance term. Higher values penalize outliers
3690 more.
3691 win_len : int
3692 Number of samples to include in window.
3693 core_method : :class:`serums.enums.SingleObjectDistance`, optional
3694 The main distance measure to use for the localization component.
3695 The default value is :attr:`.SingleObjectDistance.MANHATTAN`.
3696 true_covs : list, optional
3697 Each element represents a timestep and is a list of N x N numpy arrays
3698 corresponding to the uncertainty about the true states. Note the
3699 order must be consistent with the truth data given. This is only
3700 needed for the core method :attr:`.SingleObjectDistance.HELLINGER`. The default
3701 value is None.
3702 state_inds : list, optional
3703 Indices in the state vector to use, will be applied to the truth
3704 data as well. The default is None which means the full state is
3705 used.
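Examples
--------
Illustrative sketch; ``glmb`` is assumed to have already run and
``truth`` to be a per-timestep list of lists of N x 1 numpy arrays,
with the cutoff, power, and window length as placeholder values:

>>> glmb.calculate_ospa2(truth, 10.0, 1, 5)  # doctest: +SKIP
>>> glmb.ospa2  # per-timestep OSPA(2) values  # doctest: +SKIP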
3706 """
3707 # error checking on optional input arguments
3708 core_method = self._ospa_input_check(core_method, truth, true_covs)
3710 # setup data structures
3711 if state_inds is None:
3712 state_dim = self._ospa_find_s_dim(truth)
3713 state_inds = range(state_dim)
3714 else:
3715 state_dim = len(state_inds)
3716 if state_dim is None:
3717 warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation")
3719 nt = len(self._states)
3720 self.ospa2 = np.zeros(nt)
3721 self.ospa2_localization = np.zeros(nt)
3722 self.ospa2_cardinality = np.zeros(nt)
3723 self._ospa2_params["core"] = core_method
3724 self._ospa2_params["cutoff"] = c
3725 self._ospa2_params["power"] = p
3726 self._ospa2_params["win_len"] = win_len
3727 return
3728 true_mat, true_cov_mat = self._ospa_setup_tmat(
3729 truth, state_dim, true_covs, state_inds
3730 )
3731 est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)
3733 # find OSPA
3734 (
3735 self.ospa2,
3736 self.ospa2_localization,
3737 self.ospa2_cardinality,
3738 self._ospa2_params["core"],
3739 self._ospa2_params["cutoff"],
3740 self._ospa2_params["power"],
3741 self._ospa2_params["win_len"],
3742 ) = calculate_ospa2(
3743 est_mat,
3744 true_mat,
3745 c,
3746 p,
3747 win_len,
3748 core_method=core_method,
3749 true_cov_mat=true_cov_mat,
3750 est_cov_mat=est_cov_mat,
3751 )
3753 def plot_states_labels(
3754 self,
3755 plt_inds,
3756 ttl="Labeled State Trajectories",
3757 x_lbl=None,
3758 y_lbl=None,
3759 meas_tx_fnc=None,
3760 **kwargs,
3761 ):
3762 """Plots the best estimate for the states and labels.
3764 This assumes that the states have been extracted. It's designed to plot
3765 two of the state variables (typically x/y position). The error ellipses
3766 are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`.
3768 Keyword arguments are processed with
3769 :meth:`gncpy.plotting.init_plotting_opts`. This function
3770 implements
3772 - f_hndl
3773 - true_states
3774 - sig_bnd
3775 - rng
3776 - meas_inds
3777 - lgnd_loc
3779 Parameters
3780 ----------
3781 plt_inds : list
3782 List of indices in the state vector to plot
3783 ttl : string, optional
3784 Title of the plot.
3785 x_lbl : string, optional
3786 X-axis label for the plot.
3787 y_lbl : string, optional
3788 Y-axis label for the plot.
3789 meas_tx_fnc : callable, optional
3790 Takes in the measurement vector as an Nm x 1 numpy array and
3791 returns a numpy array representing the states to plot (size 2). The
3792 default is None.
3794 Returns
3795 -------
3796 Matplotlib figure
3797 Instance of the matplotlib figure used
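Examples
--------
Hypothetical call plotting the first two state elements and overlaying
the saved measurements (``glmb`` is assumed to be configured with
``save_measurements=True``):

>>> fig = glmb.plot_states_labels([0, 1], meas_inds=[0, 1])  # doctest: +SKIP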
3798 """
3799 opts = pltUtil.init_plotting_opts(**kwargs)
3800 f_hndl = opts["f_hndl"]
3801 true_states = opts["true_states"]
3802 sig_bnd = opts["sig_bnd"]
3803 rng = opts["rng"]
3804 meas_inds = opts["meas_inds"]
3805 lgnd_loc = opts["lgnd_loc"]
3806 mrkr = opts["marker"]
3808 if rng is None:
3809 rng = rnd.default_rng(1)
3810 if x_lbl is None:
3811 x_lbl = "x-position"
3812 if y_lbl is None:
3813 y_lbl = "y-position"
3814 meas_specs_given = (
3815 meas_inds is not None and len(meas_inds) == 2
3816 ) or meas_tx_fnc is not None
3817 plt_meas = meas_specs_given and self.save_measurements
3818 show_sig = sig_bnd is not None and self.save_covs
3820 s_lst = deepcopy(self.states)
3821 l_lst = deepcopy(self.labels)
3822 x_dim = None
3824 if f_hndl is None:
3825 f_hndl = plt.figure()
3826 f_hndl.add_subplot(1, 1, 1)
3827 # get state dimension
3828 for states in s_lst:
3829 if states is not None and len(states) > 0:
3830 x_dim = states[0].size
3831 break
3832 # get unique labels
3833 u_lbls = []
3834 for lbls in l_lst:
3835 if lbls is None:
3836 continue
3837 for lbl in lbls:
3838 if lbl not in u_lbls:
3839 u_lbls.append(lbl)
3840 cmap = pltUtil.get_cmap(len(u_lbls))
3842 # get array of all state values for each label
3843 added_sig_lbl = False
3844 added_true_lbl = False
3845 added_state_lbl = False
3846 added_meas_lbl = False
3847 for c_idx, lbl in enumerate(u_lbls):
3848 x = np.nan * np.ones((x_dim, len(s_lst)))
3849 if show_sig:
3850 sigs = [None] * len(s_lst)
3851 for tt, lbls in enumerate(l_lst):
3852 if lbls is None:
3853 continue
3854 if lbl in lbls:
3855 ii = lbls.index(lbl)
3856 if s_lst[tt][ii] is not None:
3857 x[:, [tt]] = s_lst[tt][ii].copy()
3858 if show_sig:
3859 sig = np.zeros((2, 2))
3860 if self._covs[tt][ii] is not None:
3861 sig[0, 0] = self._covs[tt][ii][plt_inds[0], plt_inds[0]]
3862 sig[0, 1] = self._covs[tt][ii][plt_inds[0], plt_inds[1]]
3863 sig[1, 0] = self._covs[tt][ii][plt_inds[1], plt_inds[0]]
3864 sig[1, 1] = self._covs[tt][ii][plt_inds[1], plt_inds[1]]
3865 else:
3866 sig = None
3867 sigs[tt] = sig
3868 # plot
3869 color = cmap(c_idx)
3871 if show_sig:
3872 for tt, sig in enumerate(sigs):
3873 if sig is None:
3874 continue
3875 w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
3876 if not added_sig_lbl:
3877 s = r"${}\sigma$ Error Ellipses".format(sig_bnd)
3878 e = Ellipse(
3879 xy=x[plt_inds, tt],
3880 width=w,
3881 height=h,
3882 angle=a,
3883 zorder=-10000,
3884 label=s,
3885 )
3886 added_sig_lbl = True
3887 else:
3888 e = Ellipse(
3889 xy=x[plt_inds, tt],
3890 width=w,
3891 height=h,
3892 angle=a,
3893 zorder=-10000,
3894 )
3895 e.set_clip_box(f_hndl.axes[0].bbox)
3896 e.set_alpha(0.2)
3897 e.set_facecolor(color)
3898 f_hndl.axes[0].add_patch(e)
3899 settings = {
3900 "color": color,
3901 "markeredgecolor": "k",
3902 "marker": mrkr,
3903 "ls": "--",
3904 }
3905 if not added_state_lbl:
3906 settings["label"] = "States"
3910 added_state_lbl = True
3912 f_hndl.axes[0].plot(x[plt_inds[0], :], x[plt_inds[1], :], **settings)
3914 s = "({}, {})".format(lbl[0], lbl[1])
3915 tmp = x.copy()
3916 tmp = tmp[:, ~np.any(np.isnan(tmp), axis=0)]
3917 f_hndl.axes[0].text(
3918 tmp[plt_inds[0], 0], tmp[plt_inds[1], 0], s, color=color
3919 )
3920 # if true states are available then plot them
3921 if true_states is not None and any([len(x) > 0 for x in true_states]):
3922 if x_dim is None:
3923 for states in true_states:
3924 if len(states) > 0:
3925 x_dim = states[0].size
3926 break
3927 max_true = max([len(x) for x in true_states])
3928 x = np.nan * np.ones((x_dim, len(true_states), max_true))
3929 for tt, states in enumerate(true_states):
3930 for ii, state in enumerate(states):
3931 if state is not None and state.size > 0:
3932 x[:, [tt], ii] = state.copy()
3933 for ii in range(0, max_true):
3934 if not added_true_lbl:
3935 f_hndl.axes[0].plot(
3936 x[plt_inds[0], :, ii],
3937 x[plt_inds[1], :, ii],
3938 color="k",
3939 marker=".",
3940 label="True Trajectories",
3941 )
3942 added_true_lbl = True
3943 else:
3944 f_hndl.axes[0].plot(
3945 x[plt_inds[0], :, ii],
3946 x[plt_inds[1], :, ii],
3947 color="k",
3948 marker=".",
3949 )
3950 if plt_meas:
3951 meas_x = []
3952 meas_y = []
3953 for meas_tt in self._meas_tab:
3954 if meas_tx_fnc is not None:
3955 tx_meas = [meas_tx_fnc(m) for m in meas_tt]
3956 mx_ii = [tm[0].item() for tm in tx_meas]
3957 my_ii = [tm[1].item() for tm in tx_meas]
3958 else:
3959 mx_ii = [m[meas_inds[0]].item() for m in meas_tt]
3960 my_ii = [m[meas_inds[1]].item() for m in meas_tt]
3961 meas_x.extend(mx_ii)
3962 meas_y.extend(my_ii)
3963 color = (128 / 255, 128 / 255, 128 / 255)
3964 meas_x = np.asarray(meas_x)
3965 meas_y = np.asarray(meas_y)
3966 if meas_x.size > 0:
3967 if not added_meas_lbl:
3968 f_hndl.axes[0].scatter(
3969 meas_x,
3970 meas_y,
3971 zorder=-1,
3972 alpha=0.35,
3973 color=color,
3974 marker="^",
3975 label="Measurements",
3976 )
3977 else:
3978 f_hndl.axes[0].scatter(
3979 meas_x, meas_y, zorder=-1, alpha=0.35, color=color, marker="^"
3980 )
3981 f_hndl.axes[0].grid(True)
3982 pltUtil.set_title_label(
3983 f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl
3984 )
3985 if lgnd_loc is not None:
3986 plt.legend(loc=lgnd_loc)
3987 plt.tight_layout()
3989 return f_hndl
3991 def plot_card_dist(self, ttl=None, **kwargs):
3992 """Plots the current cardinality distribution.
3994 This assumes that the cardinality distribution has been calculated by
3995 the class.
3997 Keyword arguments are processed with
3998 :meth:`gncpy.plotting.init_plotting_opts`. This function
3999 implements
4001 - f_hndl
4003 Parameters
4004 ----------
4005 ttl : string, optional
4006 Title of the plot, if None a default title is generated. The default
4007 is None.
4009 Returns
4010 -------
4011 Matplotlib figure
4012 Instance of the matplotlib figure used
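Examples
--------
Minimal sketch, assuming the tracker has processed at least one timestep:

>>> fig = glmb.plot_card_dist(ttl="Estimated Cardinality")  # doctest: +SKIP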
4013 """
4014 opts = pltUtil.init_plotting_opts(**kwargs)
4015 f_hndl = opts["f_hndl"]
4016 if ttl is None:
4017 ttl = "Cardinality Distribution"
4018 if len(self._card_dist) == 0:
4019 warnings.warn("Empty Cardinality", RuntimeWarning)
4020 return f_hndl
4021 if f_hndl is None:
4022 f_hndl = plt.figure()
4023 f_hndl.add_subplot(1, 1, 1)
4024 x_vals = np.arange(0, len(self._card_dist))
4025 f_hndl.axes[0].bar(x_vals, self._card_dist)
4027 pltUtil.set_title_label(
4028 f_hndl, 0, opts, ttl=ttl, x_lbl="Cardinality", y_lbl="Probability"
4029 )
4030 plt.tight_layout()
4032 return f_hndl
4034 def plot_card_history(
4035 self, time_units="index", time=None, ttl="Cardinality History", **kwargs
4036 ):
4037 """Plots the cardinality history.
4039 Parameters
4040 ----------
4041 time_units : string, optional
4042 Text representing the units of time in the plot. The default is
4043 'index'.
4044 time : numpy array, optional
4045 Vector to use for the x-axis of the plot. If none is given then
4046 vector indices are used. The default is None.
4047 ttl : string, optional
4048 Title of the plot.
4049 **kwargs : dict
4050 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
4051 function. Values implemented here are `f_hndl`, and any values
4052 relating to title/axis text formatting.
4054 Returns
4055 -------
4056 fig : matplotlib figure
4057 Figure object the data was plotted on.
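Examples
--------
Sketch with an explicit time vector; ``dt`` is a placeholder sample period:

>>> import numpy as np
>>> t = dt * np.arange(len(glmb.states))  # doctest: +SKIP
>>> fig = glmb.plot_card_history(time_units="s", time=t)  # doctest: +SKIP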
4058 """
4059 card_history = np.array([len(state_set) for state_set in self.states])
4061 opts = pltUtil.init_plotting_opts(**kwargs)
4062 fig = opts["f_hndl"]
4064 if fig is None:
4065 fig = plt.figure()
4066 fig.add_subplot(1, 1, 1)
4067 if time is None:
4068 time = np.arange(card_history.size, dtype=int)
4069 fig.axes[0].grid(True)
4070 fig.axes[0].step(time, card_history, where="post", label="estimated", color="k")
4071 fig.axes[0].ticklabel_format(useOffset=False)
4073 pltUtil.set_title_label(
4074 fig,
4075 0,
4076 opts,
4077 ttl=ttl,
4078 x_lbl="Time ({})".format(time_units),
4079 y_lbl="Cardinality",
4080 )
4081 fig.tight_layout()
4083 return fig
4085 def plot_ospa2_history(
4086 self,
4087 time_units="index",
4088 time=None,
4089 main_opts=None,
4090 sub_opts=None,
4091 plot_subs=True,
4092 ):
4093 """Plots the OSPA2 history.
4095 This requires that the OSPA(2) has been calculated by the appropriate
4096 function first.
4098 Parameters
4099 ----------
4100 time_units : string, optional
4101 Text representing the units of time in the plot. The default is
4102 'index'.
4103 time : numpy array, optional
4104 Vector to use for the x-axis of the plot. If none is given then
4105 vector indices are used. The default is None.
4106 main_opts : dict, optional
4107 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
4108 function. Values implemented here are `f_hndl`, and any values
4109 relating to title/axis text formatting. The default of None implies
4110 the default options are used for the main plot.
4111 sub_opts : dict, optional
4112 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
4113 function. Values implemented here are `f_hndl`, and any values
4114 relating to title/axis text formatting. The default of None implies
4115 the default options are used for the sub plot.
4116 plot_subs : bool, optional
4117 Flag indicating if the component statistics (cardinality and
4118 localization) should also be plotted.
4120 Returns
4121 -------
4122 figs : dict
4123 Dictionary of matplotlib figure objects the data was plotted on.
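Examples
--------
Assuming :meth:`calculate_ospa2` has already been called:

>>> figs = glmb.plot_ospa2_history(time_units="s")  # doctest: +SKIP
>>> figs.keys()  # "OSPA2" and, when plot_subs is True, "OSPA2_subs"  # doctest: +SKIP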
4124 """
4125 if self.ospa2 is None:
4126 warnings.warn("OSPA must be calculated before plotting")
4127 return
4128 if main_opts is None:
4129 main_opts = pltUtil.init_plotting_opts()
4130 if sub_opts is None and plot_subs:
4131 sub_opts = pltUtil.init_plotting_opts()
4132 fmt = "{:s} OSPA2 (c = {:.1f}, p = {:d}, w={:d})"
4133 ttl = fmt.format(
4134 self._ospa2_params["core"],
4135 self._ospa2_params["cutoff"],
4136 self._ospa2_params["power"],
4137 self._ospa2_params["win_len"],
4138 )
4139 y_lbl = "OSPA2"
4141 figs = {}
4142 figs["OSPA2"] = self._plt_ospa_hist(
4143 self.ospa2, time_units, time, ttl, y_lbl, main_opts
4144 )
4146 if plot_subs:
4147 fmt = "{:s} OSPA2 Components (c = {:.1f}, p = {:d}, w={:d})"
4148 ttl = fmt.format(
4149 self._ospa2_params["core"],
4150 self._ospa2_params["cutoff"],
4151 self._ospa2_params["power"],
4152 self._ospa2_params["win_len"],
4153 )
4154 y_lbls = ["Localization", "Cardinality"]
4155 figs["OSPA2_subs"] = self._plt_ospa_hist_subs(
4156 [self.ospa2_localization, self.ospa2_cardinality],
4157 time_units,
4158 time,
4159 ttl,
4160 y_lbls,
4161 main_opts,
4162 )
4163 return figs
4166class _STMGLMBBase:
4167 def __init__(self, **kwargs):
4168 super().__init__(**kwargs)
4170 def _init_filt_states(self, distrib):
4171 filt_states = [None] * len(distrib.means)
4172 states = [m.copy() for m in distrib.means]
4173 covs = [None] * len(distrib.means)
4175 weights = distrib.weights.copy()
4176 self._baseFilter.dof = distrib.dof
4177 for ii, scale in enumerate(distrib.scalings):
4178 self._baseFilter.scale = scale.copy()
4179 filt_states[ii] = self._baseFilter.save_filter_state()
4180 if self.save_covs:
4181 # no need to copy because cov is already a new object for the Student's t-filter
4182 covs[ii] = self.filter.cov
4183 return filt_states, weights, states, covs
4185 def _gate_meas(self, meas, means, covs, **kwargs):
4186 # TODO: check this implementation
4187 if len(meas) == 0:
4188 return []
4189 scalings = []
4190 for ent in self._track_tab:
4191 scalings.extend(ent.probDensity.scalings)
4192 valid = []
4193 for m, p in zip(means, scalings):
4194 meas_mat = self.filter.get_meas_mat(m, **kwargs)
4195 est = self.filter.get_est_meas(m, **kwargs)
4196 factor = (
4197 self.filter.meas_noise_dof
4198 * (self.filter.dof - 2)
4199 / (self.filter.dof * (self.filter.meas_noise_dof - 2))
4200 )
4201 P_zz = meas_mat @ p @ meas_mat.T + factor * self.filter.meas_noise
4202 inv_P = la.inv(P_zz)
4204 for ii, z in enumerate(meas):
4205 if ii in valid:
4206 continue
4207 innov = z - est
4208 dist = innov.T @ inv_P @ innov
4209 if dist < self.inv_chi2_gate:
4210 valid.append(ii)
4211 valid.sort()
4212 return [meas[ii] for ii in valid]
4215# Note: need inherited classes in this order for proper MRO
4216class STMGeneralizedLabeledMultiBernoulli(
4217 _STMGLMBBase, GeneralizedLabeledMultiBernoulli
4218):
4219 """Implementation of a STM-GLMB filter."""
4221 def __init__(self, **kwargs):
4222 super().__init__(**kwargs)
4225class _SMCGLMBBase:
4226 def __init__(
4227 self, compute_prob_detection=None, compute_prob_survive=None, **kwargs
4228 ):
4229 self.compute_prob_detection = compute_prob_detection
4230 self.compute_prob_survive = compute_prob_survive
4232 # used by the predict/correct wrappers to pass extra args to the probability functions
4233 self._prob_surv_args = ()
4234 self._prob_det_args = ()
4236 super().__init__(**kwargs)
4238 def _init_filt_states(self, distrib):
4239 self._baseFilter.init_from_dist(distrib, make_copy=True)
4240 filt_states = [
4241 self._baseFilter.save_filter_state(),
4242 ]
4243 states = [distrib.mean]
4244 if self.save_covs:
4245 covs = [
4246 distrib.covariance,
4247 ]
4248 else:
4249 covs = []
4250 weights = [
4251 1,
4252 ] # not needed so set to 1
4254 return filt_states, weights, states, covs
4256 def _calc_avg_prob_surv_death(self):
4257 avg_prob_survive = np.zeros(len(self._track_tab))
4258 for tabidx, ent in enumerate(self._track_tab):
4259 # TODO: fix hack so not using "private" variable outside class
4260 p_surv = self.compute_prob_survive(
4261 ent.filt_states[0]["_particleDist"].particles, *self._prob_surv_args
4262 )
4263 avg_prob_survive[tabidx] = np.sum(
4264 np.array(ent.filt_states[0]["_particleDist"].weights) * p_surv
4265 )
4266 avg_prob_death = 1 - avg_prob_survive
4268 return avg_prob_survive, avg_prob_death
4270 def _inner_predict(self, timestep, filt_state, state, filt_args):
4271 self.filter.load_filter_state(filt_state)
4272 if self.filter._particleDist.num_particles > 0:
4273 new_s = self.filter.predict(timestep, **filt_args)
4275 # manually update weights to account for prob survive
4276 # TODO: fix hack so not using "private" variable outside class
4277 ps = self.compute_prob_survive(
4278 self.filter._particleDist.particles, *self._prob_surv_args
4279 )
4280 new_weights = [
4281 w * ps[ii] for ii, (p, w) in enumerate(self.filter._particleDist)
4282 ]
4283 tot = sum(new_weights)
4284 if np.abs(tot) == np.inf:
4285 w_lst = [np.inf] * len(new_weights)
4286 else:
4287 w_lst = [w / tot for w in new_weights]
4288 self.filter._particleDist.update_weights(w_lst)
4290 new_f_state = self.filter.save_filter_state()
4291 if self.save_covs:
4292 new_cov = self.filter.cov.copy()
4293 else:
4294 new_cov = None
4295 else:
4296 new_f_state = self.filter.save_filter_state()
4297 new_s = state
4298 new_cov = self.filter.cov
4299 return new_f_state, new_s, new_cov
4301 def predict(self, timestep, prob_surv_args=(), **kwargs):
4302 """Prediction step of the SMC-GLMB filter.
4304 This is a wrapper for the parent class to allow for extra parameters.
4305 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict` for
4306 additional details.
4308 Parameters
4309 ----------
4310 timestep : float
4311 Current timestep.
4312 prob_surv_args : tuple, optional
4313 Additional arguments for the `compute_prob_survive` function.
4314 The default is ().
4315 **kwargs : dict, optional
4316 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict`
4317 """
4318 self._prob_surv_args = prob_surv_args
4319 return super().predict(timestep, **kwargs)
4321 def _calc_avg_prob_det_mdet(self):
4322 avg_prob_detect = np.zeros(len(self._track_tab))
4323 for tabidx, ent in enumerate(self._track_tab):
4324 # TODO: fix hack so not using "private" variable outside class
4325 p_detect = self.compute_prob_detection(
4326 ent.filt_states[0]["_particleDist"].particles, *self._prob_det_args
4327 )
4328 avg_prob_detect[tabidx] = np.sum(
4329 np.array(ent.filt_states[0]["_particleDist"].weights) * p_detect
4330 )
4331 avg_prob_miss_detect = 1 - avg_prob_detect
4333 return avg_prob_detect, avg_prob_miss_detect
4335 def _inner_correct(
4336 self, timestep, meas, filt_state, distrib_weight, state, filt_args
4337 ):
4338 self.filter.load_filter_state(filt_state)
4339 if self.filter._particleDist.num_particles > 0:
4340 cor_state, likely = self.filter.correct(timestep, meas, **filt_args)[0:2]
4342 # manually update the particle weights to account for probability of detection
4343 # TODO: fix hack so not using "private" variable outside class
4344 pd = self.compute_prob_detection(
4345 self.filter._particleDist.particles, *self._prob_det_args
4346 )
4347 pd_weight = (
4348 pd * np.array(self.filter._particleDist.weights) + np.finfo(float).eps
4349 )
4350 self.filter._particleDist.update_weights(
4351 (pd_weight / np.sum(pd_weight)).tolist()
4352 )
4354 # determine the partial cost, the remainder is calculated later from
4355 # the hypothesis
4356 new_w = np.sum(likely * pd_weight) # same as cost in this case
4358 new_f_state = self.filter.save_filter_state()
4359 new_s = cor_state
4360 if self.save_covs:
4361 new_c = self.filter.cov
4362 else:
4363 new_c = None
4364 else:
4365 new_f_state = self.filter.save_filter_state()
4366 new_s = state
4367 new_c = self.filter.cov
4368 new_w = 0
4369 return new_f_state, new_s, new_c, new_w
4371 def correct(self, timestep, meas, prob_det_args=(), **kwargs):
4372 """Correction step of the SMC-GLMB filter.
4374 This is a wrapper for the parent class to allow for extra parameters.
4375 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct` for
4376 additional details.
4378 Parameters
4379 ----------
4380 timestep : float
4381 Current timestep.
4382 prob_det_args : tuple, optional
4383 Additional arguments for the `compute_prob_detection` function.
4384 The default is ().
4385 **kwargs : dict, optional
4386 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct`
4387 """
4388 self._prob_det_args = prob_det_args
4389 return super().correct(timestep, meas, **kwargs)
4391 def extract_most_prob_states(self, thresh, **kwargs):
4392 """Extracts themost probable states.
4394 .. todo::
4395 Implement this function for the SMC-GLMB filter
4397 Raises
4398 ------
4399 RuntimeWarning
4400 Function must be implemented.
4401 """
4402 warnings.warn("Not implemented for this class")
4405# Note: need inherited classes in this order for proper MRO
4406class SMCGeneralizedLabeledMultiBernoulli(
4407 _SMCGLMBBase, GeneralizedLabeledMultiBernoulli
4408):
4409 """Implementation of a Sequential Monte Carlo GLMB filter.
4411 This is based on :cite:`Vo2014_LabeledRandomFiniteSetsandtheBayesMultiTargetTrackingFilter`.
4412 It does not account for agents spawned from existing tracks, only agents
4413 birthed from the given birth model.
4415 Attributes
4416 ----------
4417 compute_prob_detection : callable
4418 Function that takes a list of particles as the first argument and `*args`
4419 as the next. Returns the probability of detection for each particle as a list.
4420 compute_prob_survive : callable
4421 Function that takes a list of particles as the first argument and `*args` as
4422 the next. Returns the probability of survival for each particle as a list.
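Examples
--------
Illustrative sketch of the required callable signatures; the constant
probabilities are placeholders, not part of the original source:

>>> import numpy as np
>>> def compute_prob_detection(particles, *args):
...     return 0.95 * np.ones(len(particles))
>>> def compute_prob_survive(particles, *args):
...     return 0.99 * np.ones(len(particles))
>>> smc = SMCGeneralizedLabeledMultiBernoulli(
...     compute_prob_detection=compute_prob_detection,
...     compute_prob_survive=compute_prob_survive,
... )  # doctest: +SKIP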
4423 """
4425 def __init__(self, **kwargs):
4426 super().__init__(**kwargs)
4429class GSMGeneralizedLabeledMultiBernoulli(GeneralizedLabeledMultiBernoulli):
4430 """Implementation of a GSM-GLMB filter.
4432 The implementation of the GSM-GLMB filter does not change for different core
4433 filters (e.g. QKF GSM, SQKF GSM, UKF GSM, etc.) so this class can use any
4434 of the GSM inner filters from gncpy.filters.
4435 """
4437 def __init__(self, **kwargs):
4438 super().__init__(**kwargs)
4441class _IMMGLMBBase:
4442 def __init__(self, **kwargs):
4443 super().__init__(**kwargs)
4445 def _init_filt_states(self, distrib):
4446 filt_states = [None] * len(distrib.means)
4447 states = [m.copy() for m in distrib.means]
4448 if self.save_covs:
4449 covs = [c.copy() for c in distrib.covariances]
4450 else:
4451 covs = []
4452 weights = distrib.weights.copy()
4453 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)):
4454 # if len(m) != 1 or len(cov) != 1:
4455 # raise ValueError("Only one mean can be passed to IMM filters for initialization")
4456 m_list = []
4457 c_list = []
4458 for jj in range(0, len(self._baseFilter.in_filt_list)):
4459 m_list.append(m)
4460 c_list.append(cov)
4461 self._baseFilter.initialize_states(m_list, c_list)
4462 filt_states[ii] = self._baseFilter.save_filter_state()
4463 return filt_states, weights, states, covs
4465 def _inner_predict(self, timestep, filt_state, state, filt_args):
4466 self.filter.load_filter_state(filt_state)
4467 new_s = self.filter.predict(timestep, **filt_args)
4468 new_f_state = self.filter.save_filter_state()
4469 if self.save_covs:
4470 new_cov = self.filter.cov.copy()
4471 else:
4472 new_cov = None
4473 return new_f_state, new_s, new_cov
4475 def _inner_correct(
4476 self, timestep, meas, filt_state, distrib_weight, state, filt_args
4477 ):
4478 self.filter.load_filter_state(filt_state)
4479 cor_state, likely = self.filter.correct(timestep, meas, **filt_args)
4480 new_f_state = self.filter.save_filter_state()
4481 new_s = cor_state
4482 if self.save_covs:
4483 new_c = self.filter.cov.copy()
4484 else:
4485 new_c = None
4486 new_w = distrib_weight * likely
4488 return new_f_state, new_s, new_c, new_w
4491class IMMGeneralizedLabeledMultiBernoulli(
4492 _IMMGLMBBase, GeneralizedLabeledMultiBernoulli
4493):
4494 """An implementation of the IMM-GLMB algorithm."""
4496 def __init__(self, **kwargs):
4497 super().__init__(**kwargs)
4500class JointGeneralizedLabeledMultiBernoulli(GeneralizedLabeledMultiBernoulli):
4501 """Implements a Joint Generalized Labeled Multi-Bernoulli Filter.
4503 The Joint GLMB is designed to call predict and correct simultaneously,
4504 as a single joint prediction-correction step.
4505 Calling them asynchronously may cause poor performance.
4507 Notes
4508 -----
4509 This is based on :cite:`Vo2017_AnEfficientImplementationoftheGeneralizedLabeledMultiBernoulliFilter`.
4510 It does not account for agents spawned from existing tracks, only agents
4511 birthed from the given birth model.
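Examples
--------
Sketch of the intended joint stepping pattern; ``jglmb`` and
``meas_per_timestep`` are assumed to exist:

>>> for kk, meas in enumerate(meas_per_timestep):  # doctest: +SKIP
...     jglmb.predict(kk)
...     jglmb.correct(kk, meas)
...     jglmb.cleanup()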
4512 """
4514 def __init__(self, rng=None, **kwargs):
4515 super().__init__(**kwargs)
4516 self._old_track_tab_len = len(self._track_tab)
4517 # True once correct() has run; used to warn when predict() is called twice in a row
4518 self._update_has_been_called = True
4520 if rng is None:
4521 self._rng = np.random.default_rng()
4522 else:
4523 self._rng = rng
4525 def save_filter_state(self):
4526 """Saves filter variables so they can be restored later.
4528 Note that to pickle the resulting dictionary the :code:`dill` package
4529 may need to be used due to potential pickling of functions.
4530 """
4531 filt_state = super().save_filter_state()
4533 filt_state["_old_track_tab_len"] = self._old_track_tab_len
4535 return filt_state
4537 def load_filter_state(self, filt_state):
4538 """Initializes filter using saved filter state.
4540 Parameters
4541 ----------
4542 filt_state : dict
4543 Dictionary generated by :meth:`save_filter_state`.
4544 """
4545 super().load_filter_state(filt_state)
4547 self._old_track_tab_len = filt_state["_old_track_tab_len"]
4549 def predict(self, timestep, filt_args={}):
4550 """Prediction step of the JGLMB filter.
4552 This predicts new hypotheses and propagates them to the next time
4553 step. Because this calls the inner filter's predict function, the
4554 keyword arguments must contain any information needed by that
4555 function.
4557 Parameters
4558 ----------
4559 timestep: float
4560 Current timestep.
4561 filt_args : dict, optional
4562 Passed to the inner filter. The default is {}.
4564 Returns
4565 -------
4566 None.
4567 """
4568 if self._update_has_been_called:
4569 # Birth Track Table
4570 birth_tab = self._gen_birth_tab(timestep)[0]
4571 else:
4572 birth_tab = []
4573 warnings.warn("Joint GLMB should call predict and correct simultaneously")
4574 self._update_has_been_called = False
4576 # Survival Track Table
4577 surv_tab = self._gen_surv_tab(timestep, filt_args)
4579 # Prediction Track Table
4581 self._track_tab = birth_tab + surv_tab
4583 def _unique_faster(self, keys):
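# Vectorized unique extraction for an already sorted key array: append a
# NaN so the final element always differs from its successor, then keep
# only the entries whose forward difference is nonzero (i.e. the last
# occurrence of each run of duplicates).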
4584 difference = np.diff(np.append(keys, np.nan), n=1, axis=0)
4585 keyind = np.not_equal(difference, 0)
4586 mindices = (keys[0][np.where(keyind)]).astype(int)
4587 return mindices
4589 def _calc_avg_prob_surv_death(self):
4590 avg_surv = np.zeros(len(self.birth_terms) + self._old_track_tab_len)
4591 for ii in range(0, avg_surv.shape[0]):
4592 if ii <= len(self.birth_terms) - 1:
4593 avg_surv[ii] = self.birth_terms[ii][1]
4594 else:
4595 avg_surv[ii] = self.prob_survive
4596 # avg_surv = np.array([avg_surv]).T
4597 avg_death = 1 - avg_surv
4598 return avg_surv, avg_death
4600 def _calc_avg_prob_det_mdet(self):
4601 avg_detect = self.prob_detection * np.ones(len(self._track_tab))
4602 # avg_detect = np.array([avg_detect]).T
4603 avg_miss = 1 - avg_detect
4604 return avg_detect, avg_miss
4606 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args):
4607 num_pred = len(self._track_tab)
4608 up_tab = [None] * (num_meas + 1) * num_pred
4610 for ii, track in enumerate(self._track_tab):
4611 up_tab[ii] = self._TabEntry().setup(track)
4612 up_tab[ii].meas_assoc_hist.append(None)
4613 # measurement updated tracks
4614 all_cost_m = np.zeros((num_pred, num_meas))
4616 for ii, ent in enumerate(self._track_tab):
4617 for emm, z in enumerate(meas):
4618 s_to_ii = num_pred * emm + ii + num_pred
4619 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
4620 z, ent, timestep, filt_args
4621 )
4623 # update association history with current measurement index
4624 if up_tab[s_to_ii] is not None:
4625 up_tab[s_to_ii].meas_assoc_hist.append(emm)
4626 all_cost_m[ii, emm] = cost
4627 return up_tab, all_cost_m
4629 def _gen_cor_hyps(
4630 self,
4631 num_meas,
4632 avg_prob_detect,
4633 avg_prob_miss_detect,
4634 avg_prob_surv,
4635 avg_prob_death,
4636 all_cost_m,
4637 ):
4638 # Define clutter intensity (rate times spatial density)
4639 clutter = self.clutter_rate * self.clutter_den
4642 # Joint Cost Matrix
4643 joint_cost = np.concatenate(
4644 [
4645 np.diag(avg_prob_death.ravel()),
4646 np.diag(avg_prob_surv.ravel() * avg_prob_miss_detect.ravel()),
4647 ],
4648 axis=1,
4649 )
4651 other_jc_terms = (
4652 np.tile((avg_prob_surv * avg_prob_detect).reshape((-1, 1)), (1, num_meas))
4653 * all_cost_m
4654 / (clutter)
4655 )
4657 # Full joint cost matrix
4658 joint_cost = np.append(joint_cost, other_jc_terms, axis=1)
4660 # Gated Measurement index matrix
4661 gate_meas_indices = np.zeros((len(self._track_tab), num_meas))
4662 for ii in range(0, len(self._track_tab)):
4663 for jj in range(0, len(self._track_tab[ii].gatemeas)):
4664 gate_meas_indices[ii][jj] = self._track_tab[ii].gatemeas[jj]
4665 gate_meas_indc = gate_meas_indices >= 0
4667 # Component updates
4668 ss_w = 0
4669 up_hyp = []
4670 for p_hyp in self._hypotheses:
4671 ss_w += np.sqrt(p_hyp.assoc_prob)
4672 for p_hyp in self._hypotheses:
4673 cpreds = len(self._track_tab)
4674 num_births = len(self.birth_terms)
4675 num_exists = len(p_hyp.track_set)
4676 num_tracks = num_births + num_exists
4678 # Hypothesis index masking
4679 tindices = np.concatenate(
4680 (np.arange(0, num_births), num_births + np.array(p_hyp.track_set))
4681 ).astype(int)
4682 lselmask = np.zeros((len(self._track_tab), num_meas), dtype="bool")
4683 lselmask[tindices,] = gate_meas_indc[tindices,]
4685 keys = np.array([np.sort(gate_meas_indices[lselmask])])
4686 mindices = self._unique_faster(keys)
4688 comb_tind_cpred = np.append(
4689 np.append(tindices, cpreds + tindices), [2 * cpreds + mindices]
4690 )
4694 cost_m = joint_cost[tindices][:, comb_tind_cpred]
4701 with warnings.catch_warnings():
4702 warnings.simplefilter("ignore", RuntimeWarning)
4703 neg_log = -np.log(cost_m)
4705 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
4706 m = int(m.item()) + 1
4708 # Gibbs Sampler
4709 [assigns, costs] = gibbs(neg_log, m, rng=self._rng)
4711 # Process unique assignments from gibbs sampler
4712 assigns[assigns < num_tracks] = -np.inf
4713 for ii in range(np.shape(assigns)[0]):
4714 if len(np.shape(assigns)) < 2:
4715 if assigns[ii] >= num_tracks and assigns[ii] < 2 * num_tracks:
4716 assigns[ii] = -1
4717 else:
4718 for jj in range(np.shape(assigns)[1]):
4719 if (
4720 assigns[ii][jj] >= num_tracks
4721 and assigns[ii][jj] < 2 * num_tracks
4722 ):
4723 assigns[ii][jj] = -1
4724 assigns[assigns >= 2 * num_tracks] -= 2 * num_tracks
4725 if assigns[assigns >= 0].size != 0:
4726 assigns[assigns >= 0] = mindices[
4727 assigns[assigns >= 0].astype(int)[
4728 assigns[assigns >= 0].astype(int) >= 0
4729 ]
4730 ]
4731 # Assign updated hypotheses from gibbs sampler
4732 for c, cst in enumerate(costs.flatten()):
4733 update_hyp_cmp_temp = assigns[c,]
4734 update_hyp_cmp_idx = cpreds * (update_hyp_cmp_temp + 1) + np.append(
4735 np.array([np.arange(0, num_births)]),
4736 num_births + np.array([p_hyp.track_set]),
4737 )
4738 new_hyp = self._HypothesisHelper()
4739 new_hyp.assoc_prob = (
4740 -self.clutter_rate
4741 + num_meas * np.log(clutter)
4742 + np.log(p_hyp.assoc_prob)
4743 - cst
4744 )
4745 new_hyp.track_set = update_hyp_cmp_idx[update_hyp_cmp_idx >= 0].astype(
4746 int
4747 )
4748 up_hyp.append(new_hyp)
4749 lse = log_sum_exp([x.assoc_prob for x in up_hyp])
4751 for ii in range(0, len(up_hyp)):
4752 up_hyp[ii].assoc_prob = np.exp(up_hyp[ii].assoc_prob - lse)
4753 return up_hyp
4755 def correct(self, timestep, meas, filt_args={}):
4756 """Correction step of the JGLMB filter.
4758 This corrects the hypotheses based on the measurements and gates the
4759 measurements according to the class settings. It also updates the
4760 cardinality distribution. Because this calls the inner filter's correct
4761 function, the keyword arguments must contain any information needed by
4762 that function.
4764 Parameters
4765 ----------
4766 timestep: float
4767 Current timestep.
4768 meas : list
4769 List of Nm x 1 numpy arrays, each representing a measurement.
4770 filt_args : dict, optional
4771 keyword arguments to pass to the inner filters correct function.
4772 The default is {}.
4774 Todo
4775 ----
4776 Fix the measurement gating
4778 Returns
4779 -------
4780 None
4781 """
4782 # gating by tracks
4783 if self.gating_on:
4784 raise RuntimeError("Gating not implemented yet. PLEASE TURN OFF GATING")
4785 # for ent in self._track_tab:
4786 # ent.gatemeas = self._gate_meas(meas, ent.probDensity.means,
4787 # ent.probDensity.covariances)
4788 else:
4789 for ent in self._track_tab:
4790 ent.gatemeas = np.arange(0, len(meas))
4791 # Pre-calculation of average survival/death probabilities
4792 avg_prob_surv, avg_prob_death = self._calc_avg_prob_surv_death()
4794 # Pre-calculation of average detection/missed probabilities
4795 avg_prob_detect, avg_prob_miss_detect = self._calc_avg_prob_det_mdet()
4797 if self.save_measurements:
4798 self._meas_tab.append(deepcopy(meas))
4799 num_meas = len(meas)
4801 # missed detection tracks
4802 [up_tab, all_cost_m] = self._gen_cor_tab(num_meas, meas, timestep, filt_args)
4804 up_hyp = self._gen_cor_hyps(
4805 num_meas,
4806 avg_prob_detect,
4807 avg_prob_miss_detect,
4808 avg_prob_surv,
4809 avg_prob_death,
4810 all_cost_m,
4811 )
4813 self._track_tab = up_tab
4814 self._hypotheses = up_hyp
4815 self._card_dist = self._calc_card_dist(self._hypotheses)
4816 self._clean_predictions()
4817 self._clean_updates()
4818 self._update_has_been_called = True
4819 self._old_track_tab_len = len(self._track_tab)
4822class STMJointGeneralizedLabeledMultiBernoulli(
4823 _STMGLMBBase, JointGeneralizedLabeledMultiBernoulli
4824):
4825 """Implementation of a STM-JGLMB class."""
4827 def __init__(self, **kwargs):
4828 super().__init__(**kwargs)
4831class SMCJointGeneralizedLabeledMultiBernoulli(
4832 _SMCGLMBBase, JointGeneralizedLabeledMultiBernoulli
4833):
4834 """Implementation of a SMC-JGLMB filter."""
4836 def __init__(self, **kwargs):
4837 super().__init__(**kwargs)
4840class GSMJointGeneralizedLabeledMultiBernoulli(JointGeneralizedLabeledMultiBernoulli):
4841 """Implementation of a GSM-JGLMB filter.
4843 The implementation of the GSM-JGLMB filter does not change for different
4844 core filters (e.g. QKF GSM, SQKF GSM, UKF GSM, etc.) so this class can use
4845 any of the GSM inner filters from gncpy.filters.
4846 """
4848 def __init__(self, **kwargs):
4849 super().__init__(**kwargs)
4852class IMMJointGeneralizedLabeledMultiBernoulli(
4853 _IMMGLMBBase, JointGeneralizedLabeledMultiBernoulli
4854):
4855 """Implementation of an IMM-JGLMB filter."""
4857 def __init__(self, **kwargs):
4858 super().__init__(**kwargs)
4861class MSJointGeneralizedLabeledMultiBernoulli(JointGeneralizedLabeledMultiBernoulli):
4862 """Implementation of the Multiple Sensor JGLMB Filter"""
4864 def __init__(self, **kwargs):
4865 super().__init__(**kwargs)
4867 def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args):
4868 num_pred = len(self._track_tab)
4869 num_sens = len(meas)
4870 # if len(meas) != len(self.filter.meas_model_list):
4871 # raise ValueError("measurement lists must match number of measurement models")
4872 up_tab = [None] * (num_meas + 1) * num_pred
4874 for ii, track in enumerate(self._track_tab):
4875 up_tab[ii] = self._TabEntry().setup(track)
4876 up_tab[ii].meas_assoc_hist.append(None)
4877 # measurement updated tracks
4878 all_cost_m = np.zeros((num_pred, num_meas))
4879 for ii, ent in enumerate(self._track_tab):
4880 for emm, z in enumerate(meas):
4881 s_to_ii = num_pred * emm + ii + num_pred
4882 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
4883 z, ent, timestep, filt_args
4884 )
4885 if up_tab[s_to_ii] is not None:
4886 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
4887 all_cost_m[ii, emm] = cost
4889 return up_tab, all_cost_m
4891 def _gen_cor_hyps(
4892 self,
4893 num_meas,
4894 avg_prob_detect,
4895 avg_prob_miss_detect,
4896 avg_prob_surv,
4897 avg_prob_death,
4898 all_cost_m,
4899 meas_combs,
4900 cor_tab=None,
4901 ):
4902 # Define clutter
4903 clutter = self.clutter_rate * self.clutter_den
4905 # Joint Cost Matrix
4906 joint_cost = np.concatenate(
4907 [
4908 np.diag(avg_prob_death.ravel()),
4909 np.diag(avg_prob_surv.ravel() * avg_prob_miss_detect.ravel()),
4910 ],
4911 axis=1,
4912 )
4914 other_jc_terms = (
4915 np.tile((avg_prob_surv * avg_prob_detect).reshape((-1, 1)), (1, num_meas))
4916 * all_cost_m
4917 / (clutter)
4918 )
4920 # Full joint cost matrix for sensor s
4921 joint_cost = np.append(joint_cost, other_jc_terms, axis=1)
4923 gate_meas_indices = np.zeros((len(self._track_tab), num_meas))
4924 for ii in range(0, len(self._track_tab)):
4925 for jj in range(0, len(self._track_tab[ii].gatemeas)):
4926 gate_meas_indices[ii][jj] = self._track_tab[ii].gatemeas[jj]
4927 gate_meas_indc = gate_meas_indices >= 0
4929 # Component updates
4930 ss_w = 0
4931 up_hyp = []
4932 for p_hyp in self._hypotheses:
4933 ss_w += np.sqrt(p_hyp.assoc_prob)
4934 for p_hyp in self._hypotheses:
4935 for ind_lst in meas_combs:
4936 cpreds = len(self._track_tab) # num_pred
4937 num_births = len(self.birth_terms) # num_birth_terms
4938 num_exists = len(p_hyp.track_set) # num_existing_tracks
4939 num_tracks = num_births + num_exists # num_possible_tracks
4941 # Hypothesis index masking
4942 # all birth terms and tracks included in p_hyp.track_set
4943 tindices = np.concatenate(
4944 (np.arange(0, num_births), num_births + np.array(p_hyp.track_set))
4945 ).astype(int)
4947 lselmask = np.zeros((len(self._track_tab), num_meas), dtype="bool")
4949 for ii, index in enumerate(ind_lst):
4950 lselmask[tindices, index] = gate_meas_indc[tindices, index]
4952 # TODO: verify sort works for 3d arrays like 2d arrays; may have to do this list-wise
4953 keys = np.array([np.sort(gate_meas_indices[lselmask])])
4956 mindices = self._unique_faster(keys)
4958 comb_tind_cpred = np.append(
4959 np.append(tindices, cpreds + tindices), [2 * cpreds + mindices]
4960 )
4965 cost_m = joint_cost[tindices][:, comb_tind_cpred]
4967 with warnings.catch_warnings():
4968 warnings.simplefilter("ignore", RuntimeWarning)
4969 neg_log = -np.log(cost_m)
4971 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
4972 m = int(m.item()) + 1
4974 # Gibbs Sampler
4975 [assigns, costs] = gibbs(neg_log, m, rng=self._rng)
4977 # Process unique assignments from gibbs sampler
4978 assigns[assigns < num_tracks] = -np.inf
4979 for ii in range(np.shape(assigns)[0]):
4980 if len(np.shape(assigns)) < 2:
4981 if assigns[ii] >= num_tracks and assigns[ii] < 2 * num_tracks:
4982 assigns[ii] = -1
4983 else:
4984 for jj in range(np.shape(assigns)[1]):
4985 if (
4986 assigns[ii][jj] >= num_tracks
4987 and assigns[ii][jj] < 2 * num_tracks
4988 ):
4989 assigns[ii][jj] = -1
4990 assigns[assigns >= 2 * num_tracks] -= 2 * num_tracks
4991 if assigns[assigns >= 0].size != 0:
4992 assigns[assigns >= 0] = mindices[
4993 assigns[assigns >= 0].astype(int)[
4994 assigns[assigns >= 0].astype(int) >= 0
4995 ]
4996 ]
4997 # Assign updated hypotheses from gibbs sampler
4998 for c, cst in enumerate(costs.flatten()):
4999 update_hyp_cmp_temp = assigns[c,]
5000 update_hyp_cmp_idx = cpreds * (update_hyp_cmp_temp + 1) + np.append(
5001 np.array([np.arange(0, num_births)]),
5002 num_births + np.array([p_hyp.track_set]),
5003 )
5004 new_hyp = self._HypothesisHelper()
5005 new_hyp.assoc_prob = (
5006 -self.clutter_rate
5007 + num_meas * np.log(clutter)
5008 + np.log(p_hyp.assoc_prob)
5009 - cst
5010 )
5011 new_hyp.track_set = update_hyp_cmp_idx[
5012 update_hyp_cmp_idx >= 0
5013 ].astype(int)
5014 up_hyp.append(new_hyp)
5015 lse = log_sum_exp([x.assoc_prob for x in up_hyp])
5017 for ii in range(0, len(up_hyp)):
5018 up_hyp[ii].assoc_prob = np.exp(up_hyp[ii].assoc_prob - lse)
5019 return up_hyp
5021 def correct(self, timestep, meas, filt_args={}):
5022 """Correction step of the MS-JGLMB filter.
5024 This corrects the hypotheses based on the measurements and gates the
5025 measurements according to the class settings. It also updates the
5026 cardinality distribution. Because this calls the inner filter's correct
5027 function, the keyword arguments must contain any information needed by
5028 that function.
5030 Parameters
5031 ----------
5032 timestep: float
5033 Current timestep.
5034 meas : list
5035 List of per-sensor measurement lists, each containing Nm x 1 numpy arrays that represent single measurements.
5036 filt_args : dict, optional
5037 keyword arguments to pass to the inner filters correct function.
5038 The default is {}.
5040 Todo
5041 ----
5042 Fix the measurement gating
5044 Returns
5045 -------
5046 None
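Examples
--------
Sketch of the expected nested measurement structure for two sensors;
all values are placeholders:

>>> import numpy as np
>>> sens0 = [np.array([[1.0], [2.0]]), np.array([[3.0], [4.0]])]
>>> sens1 = [np.array([[1.1], [2.1]])]
>>> ms_jglmb.correct(timestep, [sens0, sens1])  # doctest: +SKIP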
5047 """
5048 all_combs = list(itertools.product(*meas))
5049 num_meas_per_sens = [len(x) for x in meas]
5050 num_meas = len(all_combs)
5051 num_sens = len(meas)
5052 mnmps = min(num_meas_per_sens)  # minimum number of measurements across all sensors
5054 comb_inds = list(itertools.product(*list(np.arange(0, len(x)) for x in meas)))
5055 comb_inds = [list(ele) for ele in comb_inds]
5057 all_meas_combs = list(itertools.combinations(comb_inds, mnmps))
5058 all_meas_combs = [list(ele) for ele in all_meas_combs]
5060 poss_meas_combs = []
5062 for ii in range(0, len(all_meas_combs)):
5063 break_flag = False
5064 cur_comb = []
5065 for jj, lst1 in enumerate(all_meas_combs[ii]):
5066 for kk, lst2 in enumerate(all_meas_combs[ii]):
5067 if jj == kk:
5068 continue
5069 else:
5070 out = (np.array(lst1) == np.array(lst2)).tolist()
5071 if any(out):
5072 break_flag = True
5073 break
5074 if break_flag:
5075 break
5076 if break_flag:
5077 pass
5078 else:
5079 for lst1 in all_meas_combs[ii]:
5080 for mm, lst2 in enumerate(comb_inds):
5081 if lst1 == lst2:
5082 cur_comb.append(mm)
5083 poss_meas_combs.append(cur_comb)
5085 # gating by tracks
5086 if self.gating_on:
5087 raise RuntimeError("Gating not implemented yet. PLEASE TURN OFF GATING")
5088 # for ent in self._track_tab:
5089 # ent.gatemeas = self._gate_meas(meas, ent.probDensity.means,
5090 # ent.probDensity.covariances)
5091 else:
5092 for ent in self._track_tab:
5093 ent.gatemeas = np.arange(0, len(all_combs))
5094 # ent.gatemeas = np.arange(0, len(poss_meas_combs))
5095 # Pre-calculation of average survival/death probabilities
5096 avg_prob_surv, avg_prob_death = self._calc_avg_prob_surv_death()
5098 # Pre-calculation of average detection/missed probabilities
5099 avg_prob_detect, avg_prob_miss_detect = self._calc_avg_prob_det_mdet()
5101 if self.save_measurements:
5102 self._meas_tab.append(deepcopy(meas))
5103 # all_combs = list(itertools.product(*meas))
5105 # missed detection tracks
5106 [up_tab, all_cost_m] = self._gen_cor_tab(
5107 num_meas, all_combs, timestep, comb_inds, filt_args
5108 )
5110 up_hyp = self._gen_cor_hyps(
5111 num_meas,
5112 avg_prob_detect,
5113 avg_prob_miss_detect,
5114 avg_prob_surv,
5115 avg_prob_death,
5116 all_cost_m,
5117 poss_meas_combs,
5118 cor_tab=up_tab,
5119 )
5121 self._track_tab = up_tab
5122 self._hypotheses = up_hyp
5123 self._card_dist = self._calc_card_dist(self._hypotheses)
5124 self._clean_predictions()
5125 self._clean_updates()
5126 self._update_has_been_called = True
5127 self._old_track_tab_len = len(self._track_tab)
5130class MSIMMJointGeneralizedLabeledMultiBernoulli(
5131 _IMMGLMBBase, MSJointGeneralizedLabeledMultiBernoulli
5132):
5133 """An implementation of the Multi-Sensor IMM-JGLMB algorithm."""
5135 def __init__(self, **kwargs):
5136 super().__init__(**kwargs)
5138 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args):
5139 newTab = self._TabEntry().setup(tab)
5140 new_f_states = [None] * len(newTab.filt_states)
5141 new_s_hist = [None] * len(newTab.filt_states)
5142 new_c_hist = [None] * len(newTab.filt_states)
5143 new_w = [None] * len(newTab.filt_states)
5144 depleted = False
5145 for ii, (f_state, state, w) in enumerate(
5146 zip(
5147 newTab.filt_states,
5148 newTab.state_hist[-1],
5149 newTab.distrib_weights_hist[-1],
5150 )
5151 ):
5152 try:
5153 (
5154 new_f_states[ii],
5155 new_s_hist[ii],
5156 new_c_hist[ii],
5157 new_w[ii],
5158 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args)
5159 except (
5160 gerr.ParticleDepletionError,
5161 gerr.ParticleEstimationDomainError,
5162 gerr.ExtremeMeasurementNoiseError,
5163 ):
5164 return None, 0
5165 newTab.filt_states = new_f_states
5166 newTab.state_hist[-1] = new_s_hist
5167 newTab.cov_hist[-1] = new_c_hist
5168 new_w = [w + np.finfo(float).eps for w in new_w]
5169 if not depleted:
5170 cost = np.sum(new_w).item()
5171 newTab.distrib_weights_hist[-1] = [w / cost for w in new_w]
5172 else:
5173 cost = 0
5174 return newTab, cost
5177class PoissonMultiBernoulliMixture(RandomFiniteSetBase):
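"""Implements a Poisson Multi-Bernoulli Mixture (PMBM) filter."""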
5178 class _TabEntry:
5179 def __init__(self):
5180 self.label = () # time step born, index of birth model born from
5181 self.distrib_weights_hist = [] # list of weights of the probDensity
5182 self.exist_prob = None # existence probability of the probDensity
5183 self.filt_states = [] # list of dictionaries from filters save function
5184 self.meas_assoc_hist = (
5185 []
5186 ) # list of indices into the measurement list per time step
5188 self.state_hist = [] # list of lists of numpy arrays for each timestep
5189 self.cov_hist = (
5190 []
5191 ) # list of lists of numpy arrays for each timestep (or None)
5193 """ linear index corresponding to timestep, manually updated. Used
5194 to index things since timestep in label can have decimals."""
5195 self.time_index = None
5197 def setup(self, tab):
5198 """Use to avoid expensive deepcopy."""
5199 self.label = tab.label
5200 self.distrib_weights_hist = tab.distrib_weights_hist.copy()
5201 self.exist_prob = tab.exist_prob
5202 self.filt_states = deepcopy(tab.filt_states)
5203 self.meas_assoc_hist = tab.meas_assoc_hist.copy()
5206 self.state_hist = [s_lst.copy() for s_lst in tab.state_hist]
5207 self.cov_hist = [c_lst.copy() if c_lst else [] for c_lst in tab.cov_hist]
5211 self.time_index = tab.time_index
5213 return self
5215 class _HypothesisHelper:
5216 def __init__(self):
5217 self.assoc_prob = 0
5218 self.track_set = [] # indices in lookup table
5220 @property
5221 def num_tracks(self):
5222 return len(self.track_set)
5224 class _ExtractHistHelper:
5225 def __init__(self):
5226 self.label = ()
5227 self.meas_ind_hist = []
5228 self.b_time_index = None
5229 self.states = []
5230 self.covs = []
5232 def __init__(
5233 self,
5234 req_upd=None,
5235 gating_on=False,
5236 prune_threshold=10**-15,
5237 exist_threshold=10**-15,
5238 max_hyps=3000,
5239 decimal_places=2,
5240 save_measurements=False,
5241 **kwargs,
5242 ):
5243 self.req_upd = req_upd
5244 self.gating_on = gating_on
5245 self.prune_threshold = prune_threshold
5246 self.exist_threshold = exist_threshold
5247 self.max_hyps = max_hyps
5248 self.decimal_places = decimal_places
5249 self.save_measurements = save_measurements
5251 self._track_tab = [] # list of all possible tracks
5252 self._extractable_hists = []
5254 self._filter = None
5255 self._baseFilter = None
5257 hyp0 = self._HypothesisHelper()
5258 hyp0.assoc_prob = 1
5259 hyp0.track_set = []
5260 self._hypotheses = [hyp0] # list of _HypothesisHelper objects
5262 self._card_dist = [] # probability of having index # as cardinality
5264 """ linear index corresponding to timestep, manually updated. Used
5265 to index things since timestep in label can have decimals. Must
5266 be updated once per time step."""
5267 self._time_index_cntr = 0
5269 self.ospa2 = None
5270 self.ospa2_localization = None
5271 self.ospa2_cardinality = None
5272 self._ospa2_params = {}
5274 super().__init__(**kwargs)
5275 self._states = [[]]
5277 def save_filter_state(self):
5278 """Saves filter variables so they can be restored later.
5280 Note that to pickle the resulting dictionary the :code:`dill` package
5281 may need to be used due to potential pickling of functions.
5282 """
5283 filt_state = super().save_filter_state()
5285 filt_state["req_upd"] = self.req_upd
5286 filt_state["gating_on"] = self.gating_on
5287 filt_state["prune_threshold"] = self.prune_threshold
5288 filt_state["exist_threshold"] = self.exist_threshold
5289 filt_state["max_hyps"] = self.max_hyps
5290 filt_state["decimal_places"] = self.decimal_places
5291 filt_state["save_measurements"] = self.save_measurements
5293 filt_state["_track_tab"] = self._track_tab
5294 filt_state["_extractable_hists"] = self._extractable_hists
5296 if self._baseFilter is not None:
5297 filt_state["_baseFilter"] = (
5298 type(self._baseFilter),
5299 self._baseFilter.save_filter_state(),
5300 )
5301 else:
5302 filt_state["_baseFilter"] = (None, self._baseFilter)
5303 filt_state["_hypotheses"] = self._hypotheses
5304 filt_state["_card_dist"] = self._card_dist
5305 filt_state["_time_index_cntr"] = self._time_index_cntr
5307 filt_state["ospa2"] = self.ospa2
5308 filt_state["ospa2_localization"] = self.ospa2_localization
5309 filt_state["ospa2_cardinality"] = self.ospa2_cardinality
5310 filt_state["_ospa2_params"] = self._ospa2_params
5312 return filt_state
5314 def load_filter_state(self, filt_state):
5315 """Initializes filter using saved filter state.
5317 Parameters
5318 ----------
5319 filt_state : dict
5320 Dictionary generated by :meth:`save_filter_state`.
5321 """
5322 super().load_filter_state(filt_state)
5324 self.req_upd = filt_state["req_upd"]
5325 self.gating_on = filt_state["gating_on"]
5326 self.prune_threshold = filt_state["prune_threshold"]
5327 self.exist_threshold = filt_state["exist_threshold"]
5328 self.max_hyps = filt_state["max_hyps"]
5329 self.decimal_places = filt_state["decimal_places"]
5330 self.save_measurements = filt_state["save_measurements"]
5332 self._track_tab = filt_state["_track_tab"]
5333 self._extractable_hists = filt_state["_extractable_hists"]
5335 cls_type = filt_state["_baseFilter"][0]
5336 if cls_type is not None:
5337 self._baseFilter = cls_type()
5338 self._baseFilter.load_filter_state(filt_state["_baseFilter"][1])
5339 else:
5340 self._baseFilter = None
5341 self._hypotheses = filt_state["_hypotheses"]
5342 self._card_dist = filt_state["_card_dist"]
5343 self._time_index_cntr = filt_state["_time_index_cntr"]
5345 self.ospa2 = filt_state["ospa2"]
5346 self.ospa2_localization = filt_state["ospa2_localization"]
5347 self.ospa2_cardinality = filt_state["ospa2_cardinality"]
5348 self._ospa2_params = filt_state["_ospa2_params"]
5350 @property
5351 def states(self):
5352 """Read only list of extracted states.
5354 This is a list with 1 element per timestep, and each element is a list
5355 of the best states extracted at that timestep. The order of each
5356 element corresponds to the label order.
5357 """
5358 return self._states
5360 @property
5361 def covariances(self):
5362 """Read only list of extracted covariances.
5364 This is a list with 1 element per timestep, and each element is a list
5365 of the best covariances extracted at that timestep. The order of each
5366 element corresponds to the state order.
5368 Warns
5369 -----
5370 RuntimeWarning
5371 If the class is not saving the covariances; an empty list is returned.
5372 """
5373 if not self.save_covs:
5374 warnings.warn("Not saving covariances", RuntimeWarning)
5375 return []
5376 return self._covs
5378 @property
5379 def filter(self):
5380 """Inner filter handling dynamics, must be a gncpy.filters.BayesFilter."""
5381 return self._filter
5383 @filter.setter
5384 def filter(self, val):
5385 self._baseFilter = deepcopy(val)
5386 self._filter = val
5388 @property
5389 def cardinality(self):
5390 """Cardinality estimate."""
5391 return np.argmax(self._card_dist)
5393 def _init_filt_states(self, distrib):
5394 filt_states = [None] * len(distrib.means)
5395 states = [m.copy() for m in distrib.means]
5396 if self.save_covs:
5397 covs = [c.copy() for c in distrib.covariances]
5398 else:
5399 covs = []
5400 weights = distrib.weights.copy()
5401 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)):
5402 self._baseFilter.cov = cov.copy()
5403 if isinstance(self._baseFilter, gfilts.UnscentedKalmanFilter) or isinstance(
5404 self._baseFilter, gfilts.UKFGaussianScaleMixtureFilter
5405 ):
5406 self._baseFilter.init_sigma_points(m)
5407 filt_states[ii] = self._baseFilter.save_filter_state()
5408 return filt_states, weights, states, covs
5410 def _inner_predict(self, timestep, filt_state, state, filt_args):
5411 self.filter.load_filter_state(filt_state)
5412 new_s = self.filter.predict(timestep, state, **filt_args)
5413 new_f_state = self.filter.save_filter_state()
5414 if self.save_covs:
5415 new_cov = self.filter.cov.copy()
5416 else:
5417 new_cov = None
5418 return new_f_state, new_s, new_cov
5420 def _predict_det_tab_entry(self, tab, timestep, filt_args):
5421 new_tab = self._TabEntry().setup(tab)
5422 new_f_states = [None] * len(new_tab.filt_states)
5423 new_s_hist = [None] * len(new_tab.filt_states)
5424 new_c_hist = [None] * len(new_tab.filt_states)
5425 for ii, (f_state, state) in enumerate(
5426 zip(new_tab.filt_states, new_tab.state_hist[-1])
5427 ):
5428 (new_f_states[ii], new_s_hist[ii], new_c_hist[ii]) = self._inner_predict(
5429 timestep, f_state, state, filt_args
5430 )
5431 new_tab.filt_states = new_f_states
5432 new_tab.state_hist.append(new_s_hist)
5433 new_tab.cov_hist.append(new_c_hist)
5434 new_tab.distrib_weights_hist.append(new_tab.distrib_weights_hist[-1].copy())
5435 new_tab.exist_prob = new_tab.exist_prob * self.prob_survive
5436 return new_tab
5438 def _gen_pred_tab(self, timestep, filt_args):
5439 pred_tab = []
5441 for ii, ent in enumerate(self._track_tab):
5442 entry = self._predict_det_tab_entry(ent, timestep, filt_args)
5443 pred_tab.append(entry)
5445 return pred_tab
5447 def predict(self, timestep, filt_args={}):
5448 # all objects are propagated forward regardless of previous associations.
5449 self._track_tab = self._gen_pred_tab(timestep, filt_args)
5451 def _calc_avg_prob_det_mdet(self, cor_tab):
5452 avg_prob_detect = self.prob_detection * np.ones(len(cor_tab))
5453 avg_prob_miss_detect = 1 - avg_prob_detect
5455 return avg_prob_detect, avg_prob_miss_detect
5457 def _inner_correct(
5458 self, timestep, meas, filt_state, distrib_weight, state, filt_args
5459 ):
5460 self.filter.load_filter_state(filt_state)
5461 cor_state, likely = self.filter.correct(timestep, meas, state, **filt_args)
5462 new_f_state = self.filter.save_filter_state()
5463 new_s = cor_state
5464 if self.save_covs:
5465 new_c = self.filter.cov.copy()
5466 else:
5467 new_c = None
5468 new_w = distrib_weight * likely
5470 return new_f_state, new_s, new_c, new_w
5472 def _correct_track_tab_entry(self, meas, tab, timestep, filt_args):
5473 new_tab = self._TabEntry().setup(tab)
5474 new_f_states = [None] * len(new_tab.filt_states)
5475 new_s_hist = [None] * len(new_tab.filt_states)
5476 new_c_hist = [None] * len(new_tab.filt_states)
5477 new_w = [None] * len(new_tab.filt_states)
5478 depleted = False
5479 for ii, (f_state, state, w) in enumerate(
5480 zip(
5481 new_tab.filt_states,
5482 new_tab.state_hist[-1],
5483 new_tab.distrib_weights_hist[-1],
5484 )
5485 ):
5486 try:
5487 (
5488 new_f_states[ii],
5489 new_s_hist[ii],
5490 new_c_hist[ii],
5491 new_w[ii],
5492 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args)
5493 except (
5494 gerr.ParticleDepletionError,
5495 gerr.ParticleEstimationDomainError,
5496 gerr.ExtremeMeasurementNoiseError,
5497 ):
5498 return None, 0
5499 new_tab.filt_states = new_f_states
5500 new_tab.state_hist[-1] = new_s_hist
5501 new_tab.cov_hist[-1] = new_c_hist
5502 new_w = [w + np.finfo(float).eps for w in new_w]
5503 if not depleted:
5504 cost = (new_tab.exist_prob * self.prob_detection * np.sum(new_w).item()) / (
5505 (1 - new_tab.exist_prob + new_tab.exist_prob * self.prob_miss_detection)
5506 * np.sum(tab.distrib_weights_hist[-1]).item()
5507 )
5509 nw_list = [w * new_tab.exist_prob * self.prob_detection for w in new_w]
5510 new_tab.distrib_weights_hist[-1] = [
5511 w / np.sum(nw_list).item() for w in nw_list
5512 ]
5513 new_tab.exist_prob = 1
5514 else:
5515 cost = 0
5516 return new_tab, cost
5518 def _correct_birth_tab_entry(self, meas, distrib, timestep, filt_args):
5519 new_tab = self._TabEntry()
5520 (filt_states, weights, states, covs) = self._init_filt_states(distrib)
5522 new_f_states = [None] * len(filt_states)
5523 new_s_hist = [None] * len(filt_states)
5524 new_c_hist = [None] * len(filt_states)
5525 new_w = [None] * len(filt_states)
5526 depleted = False
5527 for ii, (f_state, state, w) in enumerate(zip(filt_states, states, weights)):
5528 try:
5529 (
5530 new_f_states[ii],
5531 new_s_hist[ii],
5532 new_c_hist[ii],
5533 new_w[ii],
5534 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args)
5535 except (
5536 gerr.ParticleDepletionError,
5537 gerr.ParticleEstimationDomainError,
5538 gerr.ExtremeMeasurementNoiseError,
5539 ):
5540 return None, 0
5541 new_tab.filt_states = new_f_states
5542 new_tab.state_hist = [new_s_hist]
5543 new_tab.cov_hist = [new_c_hist]
5544 new_tab.distrib_weights_hist = []
5545 new_w = [w + np.finfo(float).eps for w in new_w]
5546 if not depleted:
5547 cost = (
5548 np.sum(new_w).item() * self.prob_detection
5549 + self.clutter_rate * self.clutter_den
5550 )
5551 new_tab.distrib_weights_hist.append(
5552 [w / np.sum(new_w).item() for w in new_w]
5553 )
5554 new_tab.exist_prob = (
5555 self.prob_detection
5556 * cost
5557 / (self.clutter_rate * self.clutter_den + self.prob_detection * cost)
5558 )
5559 else:
5560 cost = 0
5561 new_tab.time_index = self._time_index_cntr
5562 return new_tab, cost
5564 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args):
5565 num_pred = len(self._track_tab)
5566 num_birth = len(self.birth_terms)
5567 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)
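        # Layout of up_tab, as a worked sketch (the concrete numbers here are
        # illustrative assumptions, not from the source): with num_pred = 2
        # tracks, num_birth = 1 birth term, and num_meas = 2 measurements,
        #   entries [0, 1] hold the missed-detection copies of tracks 0 and 1,
        #   entries [2, 3] hold tracks 0 and 1 corrected with measurement 0,
        #   entries [4, 5] hold tracks 0 and 1 corrected with measurement 1,
        #   entries [6, 7] hold births initialized from measurements 0 and 1.
        # In general, entry num_pred * (emm + 1) + ii is track ii corrected
        # with measurement emm, and entries from (num_meas + 1) * num_pred
        # onward are the measurement-driven births.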
5569 # Missed Detection Updates
5570 for ii, track in enumerate(self._track_tab):
5571 up_tab[ii] = self._TabEntry().setup(track)
5572 sum_non_exist_prob = (
5573 1
5574 - up_tab[ii].exist_prob
5575 + up_tab[ii].exist_prob * self.prob_miss_detection
5576 )
5577 up_tab[ii].distrib_weights_hist.append(
5578 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
5579 )
5580 up_tab[ii].exist_prob = (
5581 up_tab[ii].exist_prob * self.prob_miss_detection
5582 ) / (sum_non_exist_prob)
5583 up_tab[ii].meas_assoc_hist.append(None)
5586 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))
5588 # Update for all existing tracks
5589 for emm, z in enumerate(meas):
5590 for ii, ent in enumerate(self._track_tab):
5591 s_to_ii = num_pred * emm + ii + num_pred
5592 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
5593 z, ent, timestep, filt_args
5594 )
5595 if up_tab[s_to_ii] is not None:
5596 up_tab[s_to_ii].meas_assoc_hist.append(emm)
5597 all_cost_m[emm, ii] = cost
5599 # Update for all potential new births
5600 for emm, z in enumerate(meas):
5601 for ii, b_model in enumerate(self.birth_terms):
5602 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
5603 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry(
5604 z, b_model, timestep, filt_args
5605 )
5606 if up_tab[s_to_ii] is not None:
5607 up_tab[s_to_ii].meas_assoc_hist.append(emm)
5608 all_cost_m[emm, emm + num_pred] = cost
5609 return up_tab, all_cost_m
5611 # TODO: find some way to cherry pick the appropriate all_combs to ensure that
5612 # measurements are not duplicated before being passed to assignment
5613 def _gen_cor_hyps(
5614 self, num_meas, avg_prob_detect, avg_prob_miss_detect, all_cost_m, cor_tab
5615 ):
5616 num_pred = len(self._track_tab)
5617 up_hyps = []
5618 if num_meas == 0:
5619 for hyp in self._hypotheses:
5620 pmd_log = np.sum(
5621 [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
5622 )
5623 hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
5624 up_hyps.append(hyp)
5625 else:
5626 clutter = self.clutter_rate * self.clutter_den
5627 ss_w = 0
5628 for p_hyp in self._hypotheses:
5629 ss_w += np.sqrt(p_hyp.assoc_prob)
5630 for p_hyp in self._hypotheses:
5631 if p_hyp.num_tracks == 0: # all clutter
5632 inds = np.arange(num_pred, num_pred + num_meas).tolist()
5633 else:
5634 inds = (
5635 p_hyp.track_set
5636 + np.arange(num_pred, num_pred + num_meas).tolist()
5637 )
5639 cost_m = all_cost_m[:, inds]
5640 max_row_inds, max_col_inds = np.where(cost_m >= np.inf)
5641 if max_row_inds.size > 0:
5642 cost_m[max_row_inds, max_col_inds] = np.finfo(float).max
5643 min_row_inds, min_col_inds = np.where(cost_m <= 0.0)
5644 if min_row_inds.size > 0:
5645 cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps # 1
5646 neg_log = -np.log(cost_m)
5652 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
5653 m = int(m.item())
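                # Worked example of the assignment budget split (numbers are
                # illustrative, not from the source): with req_upd = 100 and
                # parent hypotheses of assoc_prob 0.81 and 0.09, ss_w = 1.2,
                # so the parents request m = round(100 * 0.9 / 1.2) = 75 and
                # m = round(100 * 0.3 / 1.2) = 25 assignments respectively.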
5656 [assigns, costs] = murty_m_best_all_meas_assigned(neg_log, m)
5657 """assignment matrix consisting of 0 or 1 entries such that each column sums
5658 to one and each row sums to zero or one""" # (transposed from the paper)
5659 # assigns = assigns.T
5660 # assigns = np.delete(assigns, 1, axis=0)
5661 # costs = np.delete(costs, 1, axis=0)
5663 pmd_log = np.sum(
5664 [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
5665 )
5666 for a, c in zip(assigns, costs):
5667 new_hyp = self._HypothesisHelper()
5668 new_hyp.assoc_prob = (
5669 -self.clutter_rate
5670 + num_meas * np.log(clutter)
5671 + pmd_log
5672 + np.log(p_hyp.assoc_prob)
5673 - c
5674 )
5675 if p_hyp.num_tracks == 0:
5676 new_track_list = list(num_pred * a + num_pred * num_meas)
5677 else:
5678 # track_inds = np.argwhere(a==1)
5679 new_track_list = []
5681 for ii, ms in enumerate(a):
5682 if len(p_hyp.track_set) >= ms:
5683 new_track_list.append(
5684 (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)]
5685 )
5686 elif len(p_hyp.track_set) == len(a):
5687 new_track_list.append(
5688 num_pred * ms - ii * (num_pred - 1)
5689 )
5690 elif len(p_hyp.track_set) < len(a):
5691 new_track_list.append(
5692 num_pred * (num_meas + 1) + (ms - num_meas)
5693 )
5694 else:
5695 new_track_list.append(
5696 (num_meas + 1) * num_pred
5697 - (ms - len(p_hyp.track_set) - 1)
5698 )
5757 new_hyp.track_set = new_track_list
5758 up_hyps.append(new_hyp)
5760 lse = log_sum_exp([x.assoc_prob for x in up_hyps])
5761 for ii in range(0, len(up_hyps)):
5762 up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse)
5763 return up_hyps
5765 def _clean_updates(self):
5766 used = [0] * len(self._track_tab)
5767 for hyp in self._hypotheses:
5768 for ii in hyp.track_set:
5769 if self._track_tab[ii] is not None:
5770 used[ii] += 1
        nnz_inds = [idx for idx, val in enumerate(used) if val != 0]
        new_inds = [None] * len(self._track_tab)
        for new_idx, old_idx in enumerate(nnz_inds):
            new_inds[old_idx] = new_idx
5778 new_tab = [self._track_tab[ii] for ii in nnz_inds]
5779 new_hyps = []
5780 for ii, hyp in enumerate(self._hypotheses):
5781 if len(hyp.track_set) > 0:
5782 track_set = [new_inds[ii] for ii in hyp.track_set]
5783 if None in track_set:
5784 continue
5785 hyp.track_set = track_set
5786 new_hyps.append(hyp)
5787 self._track_tab = new_tab
5788 self._hypotheses = new_hyps
    def _calc_card_dist(self, hyp_lst):
        """Calculates the cardinality distribution."""
        if len(hyp_lst) == 0:
            return [1]
5796 card_dist = []
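        # Worked example (illustrative numbers, not from the source): two
        # hypotheses with 1 track each (assoc_prob 0.5 and 0.2) and one with
        # 2 tracks (assoc_prob 0.3) yield card_dist = [0, 0.7, 0.3], so the
        # MAP cardinality is 1.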
5797 for ii in range(0, max(map(lambda x: x.num_tracks, hyp_lst)) + 1):
5798 card = 0
5799 for hyp in hyp_lst:
5800 if hyp.num_tracks == ii:
5801 card = card + hyp.assoc_prob
5802 card_dist.append(card)
5803 return card_dist
5805 def correct(self, timestep, meas, filt_args={}):
5806 """Correction step of the PMBM filter.
5808 Notes
5809 -----
5810 This corrects the hypotheses based on the measurements and gates the
5811 measurements according to the class settings. It also updates the
5812 cardinality distribution.
5814 Parameters
5815 ----------
5816 timestep: float
5817 Current timestep.
        meas : list
            List of Nm x 1 numpy arrays, each representing a measurement.
        filt_args : dict, optional
            Keyword arguments to pass to the inner filter's correct function.
            The default is {}.
5824 Returns
5825 -------
5826 None
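
        Example
        -------
        A minimal sketch of one filter cycle (assumes ``pmbm`` is a configured
        instance and ``meas`` is a list of Nm x 1 numpy arrays)::

            pmbm.predict(timestep)
            pmbm.correct(timestep, meas)
            pmbm.cleanup()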
5827 """
5828 if self.gating_on:
5829 warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
5830 # means = []
5831 # covs = []
5832 # for ent in self._track_tab:
5833 # means.extend(ent.probDensity.means)
5834 # covs.extend(ent.probDensity.covariances)
5835 # meas = self._gate_meas(meas, means, covs)
5836 if self.save_measurements:
5837 self._meas_tab.append(deepcopy(meas))
5838 num_meas = len(meas)
5839 cor_tab, all_cost_m = self._gen_cor_tab(num_meas, meas, timestep, filt_args)
5841 # self._add_birth_hyps(num_meas)
5843 avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab)
5845 cor_hyps = self._gen_cor_hyps(
5846 num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, cor_tab
5847 )
5849 self._track_tab = cor_tab
5850 self._hypotheses = cor_hyps
5851 self._card_dist = self._calc_card_dist(self._hypotheses)
5852 self._clean_updates()
5854 def _extract_helper(self, track):
5855 states = [None] * len(track.state_hist)
5856 covs = [None] * len(track.state_hist)
5857 for ii, (w_lst, s_lst, c_lst) in enumerate(
5858 zip(track.distrib_weights_hist, track.state_hist, track.cov_hist)
5859 ):
5860 idx = np.argmax(w_lst)
5861 states[ii] = s_lst[idx]
5862 if self.save_covs:
5863 covs[ii] = c_lst[idx]
5864 return states, covs
5866 def _update_extract_hist(self, idx_cmp):
5867 used_meas_inds = [[] for ii in range(self._time_index_cntr)]
5868 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set)
5869 for ii, track in enumerate(
5870 [
5871 self._track_tab[trk_ind]
5872 for trk_ind in self._hypotheses[idx_cmp].track_set
5873 ]
5874 ):
5875 new_extract_hists[ii] = self._ExtractHistHelper()
5876 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy()
5877 new_extract_hists[ii].b_time_index = track.time_index
5878 (
5879 new_extract_hists[ii].states,
5880 new_extract_hists[ii].covs,
5881 ) = self._extract_helper(track)
5883 for t_inds_after_b, meas_ind in enumerate(
5884 new_extract_hists[ii].meas_ind_hist
5885 ):
5886 tt = new_extract_hists[ii].b_time_index + t_inds_after_b
5887 if meas_ind is not None and meas_ind not in used_meas_inds[tt]:
5888 used_meas_inds[tt].append(meas_ind)
5889 good_inds = []
5890 for ii, existing in enumerate(self._extractable_hists):
5891 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist):
5892 tt = existing.b_time_index + t_inds_after_b
5893 used = meas_ind is not None and meas_ind in used_meas_inds[tt]
5894 if used:
5895 break
5896 if not used:
5897 good_inds.append(ii)
5898 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds]
5899 self._extractable_hists.extend(new_extract_hists)
5901 def extract_states(self, update=True, calc_states=True):
5902 """Extracts the best state estimates.
5904 This extracts the best states from the distribution. It should be
5905 called once per time step after the correction function. This calls
5906 both the inner filters predict and correct functions so the keyword
5907 arguments must contain any additional variables needed by those
5908 functions.
5910 Parameters
5911 ----------
5912 update : bool, optional
5913 Flag indicating if the label history should be updated. This should
5914 be done once per timestep and can be disabled if calculating states
5915 after the final timestep. The default is True.
5916 calc_states : bool, optional
5917 Flag indicating if the states should be calculated based on the
5918 label history. This only needs to be done before the states are used.
5919 It can simply be called once after the end of the simulation. The
5920 default is true.
5922 Returns
5923 -------
5924 idx_cmp : int
5925 Index of the hypothesis table used when extracting states.
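
        Example
        -------
        A hedged sketch of the two usage patterns (``pmbm`` is assumed to be
        a configured instance)::

            pmbm.extract_states(update=True, calc_states=False)  # each step
            # ...
            pmbm.extract_states(update=False, calc_states=True)  # once at end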
5926 """
5927 card = np.argmax(self._card_dist)
5928 tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses])
5929 weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses])
5931 self._states = [[] for ii in range(self._time_index_cntr)]
5932 self._covs = [[] for ii in range(self._time_index_cntr)]
5934 if len(tracks_per_hyp) == 0:
5935 return None
5936 idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card))
5937 if update:
5938 self._update_extract_hist(idx_cmp)
5939 if calc_states:
5940 for existing in self._extractable_hists:
5941 for t_inds_after_b, (s, c) in enumerate(
5942 zip(existing.states, existing.covs)
5943 ):
5944 tt = existing.b_time_index + t_inds_after_b
5945 self._states[tt].append(s)
5946 self._covs[tt].append(c)
5947 if not update and not calc_states:
5948 warnings.warn("Extracting states performed no actions")
5949 return idx_cmp
5951 def _prune(self):
5952 """Removes hypotheses below a threshold.
5954 This should be called once per time step after the correction and
5955 before the state extraction.
5956 """
        # Find hypotheses with low association probabilities
        temp_assoc_probs = np.array([hyp.assoc_prob for hyp in self._hypotheses])
        keep_indices = np.argwhere(temp_assoc_probs > self.prune_threshold).flatten()
5966 # For re-weighing association probabilities
5967 new_sum = np.sum(temp_assoc_probs[keep_indices])
5968 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices]
5969 for ii in range(0, len(keep_indices)):
5970 self._hypotheses[ii].assoc_prob = self._hypotheses[ii].assoc_prob / new_sum
5971 # Re-calculate cardinality
5972 self._card_dist = self._calc_card_dist(self._hypotheses)
5974 def _cap(self):
5975 """Removes least likely hypotheses until a maximum number is reached.
5977 This should be called once per time step after pruning and
5978 before the state extraction.
5979 """
        # Determine if there are too many hypotheses
        if len(self._hypotheses) > self.max_hyps:
            temp_assoc_probs = np.array([hyp.assoc_prob for hyp in self._hypotheses])

            # Take the top n assoc_probs in descending order, where n = max_hyps
            sorted_indices = np.argsort(temp_assoc_probs)[::-1]
            keep_indices = sorted_indices[: self.max_hyps]
5996 # Assign to class
5997 self._hypotheses = [self._hypotheses[ii] for ii in keep_indices]
5999 # Normalize association probabilities
6000 new_sum = 0
6001 for ii in range(0, len(self._hypotheses)):
6002 new_sum = new_sum + self._hypotheses[ii].assoc_prob
6003 for ii in range(0, len(self._hypotheses)):
6004 self._hypotheses[ii].assoc_prob = (
6005 self._hypotheses[ii].assoc_prob / new_sum
6006 )
6007 # Re-calculate cardinality
6008 self._card_dist = self._calc_card_dist(self._hypotheses)
6010 def _bern_prune(self):
6011 """Removes track table entries below a threshold.
6013 This should be called once per time step after the correction and
6014 before the state extraction.
6015 """
6016 used = [0] * len(self._track_tab)
6017 for ii in range(0, len(self._track_tab)):
6018 if self._track_tab[ii].exist_prob > self.exist_threshold:
6019 used[ii] += 1
        keep_inds = [idx for idx, val in enumerate(used) if val != 0]
        new_inds = [None] * len(self._track_tab)
        for new_idx, old_idx in enumerate(keep_inds):
            new_inds[old_idx] = new_idx
6028 # loop over track table and remove pruned entries
6029 new_tab = [self._track_tab[ii] for ii in keep_inds]
6030 new_hyps = []
6031 for ii, hyp in enumerate(self._hypotheses):
6032 if len(hyp.track_set) > 0:
6033 track_set = [new_inds[track_ind] for track_ind in hyp.track_set]
                if None in track_set:
                    track_set = [item for item in track_set if item is not None]
6036 hyp.track_set = track_set
6037 new_hyps.append(hyp)
6039 del_inds = []
6040 # TODO: ADD CASE FOR NO MORE TRACKS SO THAT WE DON'T REMOVE ALL HYPOTHESES
6041 # AND OR THE HYPOTHESES HAVE AN EMPTY TRACK SET RATHER THAN A TRACK SET OF NONES
6043 for ii in range(0, len(new_hyps)):
6044 same_inds = []
6045 for jj in range(ii, len(new_hyps)):
6046 if ii == jj or any(jj == x for x in del_inds):
6047 continue
6048 if new_hyps[ii].track_set == new_hyps[jj].track_set:
6049 same_inds.append(jj)
6050 for jj in same_inds:
6051 new_hyps[ii].assoc_prob += new_hyps[jj].assoc_prob
6052 del_inds.append(jj)
6053 del_inds.sort(reverse=True)
6054 for ind in del_inds:
6055 new_hyps.pop(ind)
6057 self._track_tab = new_tab
6058 self._hypotheses = new_hyps
6060 def cleanup(
6061 self,
6062 enable_prune=True,
6063 enable_cap=True,
6064 enable_bern_prune=True,
6065 enable_extract=True,
6066 extract_kwargs=None,
6067 ):
6068 """Performs the cleanup step of the filter.
6070 This can prune, cap, and extract states. It must be called once per
6071 timestep, even if all three functions are disabled. This is to ensure
6072 that internal counters for tracking linear timestep indices are properly
6073 incremented. If this is called with `enable_extract` set to true then
6074 the extract states method does not need to be called separately. It is
6075 recommended to call this function instead of
6076 :meth:`carbs.swarm_estimator.tracker.PoissonMultiBernoulliMixture.extract_states`
6077 directly.
6079 Parameters
6080 ----------
        enable_prune : bool, optional
            Flag indicating if pruning should be performed. The default is True.
        enable_cap : bool, optional
            Flag indicating if capping should be performed. The default is True.
        enable_bern_prune : bool, optional
            Flag indicating if Bernoulli pruning should be performed. The default is True.
6087 enable_extract : bool, optional
6088 Flag indicating if state extraction should be performed. The default is True.
6089 extract_kwargs : dict, optional
6090 Additional arguments to pass to :meth:`.extract_states`. The
6091 default is None. Only used if extracting states.
6093 Returns
6094 -------
6095 None.
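
        Example
        -------
        Typical per-timestep usage (a sketch; ``pmbm`` and ``meas`` are
        assumed to exist)::

            pmbm.predict(timestep)
            pmbm.correct(timestep, meas)
            pmbm.cleanup(extract_kwargs={"update": True, "calc_states": False})
            # after the final timestep, compute the trajectories once
            pmbm.extract_states(update=False, calc_states=True)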
6097 """
6098 self._time_index_cntr += 1
6100 if enable_prune:
6101 self._prune()
6102 if enable_cap:
6103 self._cap()
6104 if enable_bern_prune:
6105 self._bern_prune()
6106 if enable_extract:
6107 if extract_kwargs is None:
6108 extract_kwargs = {}
6109 self.extract_states(**extract_kwargs)
6111 def calculate_ospa2(
6112 self,
6113 truth,
6114 c,
6115 p,
6116 win_len,
6117 true_covs=None,
6118 core_method=SingleObjectDistance.MANHATTAN,
6119 state_inds=None,
6120 ):
6121 """Calculates the OSPA(2) distance between the truth at all timesteps.
6123 Wrapper for :func:`serums.distances.calculate_ospa2`.
6125 Parameters
6126 ----------
6127 truth : list
6128 Each element represents a timestep and is a list of N x 1 numpy array,
6129 one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty assigned to a false point estimate.
6134 p : int
6135 The power of the distance term. Higher values penalize outliers
6136 more.
6137 win_len : int
6138 Number of samples to include in window.
6139 core_method : :class:`serums.enums.SingleObjectDistance`, Optional
6140 The main distance measure to use for the localization component.
6141 The default value is :attr:`.SingleObjectDistance.MANHATTAN`.
        true_covs : list, Optional
            Each element represents a timestep and is a list of N x N numpy arrays
            corresponding to the uncertainty about the true states. Note the
            order must be consistent with the truth data given. This is only
            needed for the :attr:`SingleObjectDistance.HELLINGER` core method.
            The default value is None.
6148 state_inds : list, optional
6149 Indices in the state vector to use, will be applied to the truth
6150 data as well. The default is None which means the full state is
6151 used.
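
        Example
        -------
        A hedged sketch (assumes ``pmbm`` has already extracted states and
        ``truth`` is a list of per-timestep lists of N x 1 numpy arrays)::

            pmbm.calculate_ospa2(truth, 10.0, 1, 5)
            print(pmbm.ospa2)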
6152 """
6153 # error checking on optional input arguments
6154 core_method = self._ospa_input_check(core_method, truth, true_covs)
6156 # setup data structures
6157 if state_inds is None:
6158 state_dim = self._ospa_find_s_dim(truth)
6159 state_inds = range(state_dim)
6160 else:
6161 state_dim = len(state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation")
            nt = len(self._states)
            self.ospa2 = np.zeros(nt)
            self.ospa2_localization = np.zeros(nt)
            self.ospa2_cardinality = np.zeros(nt)
            self._ospa2_params["core"] = core_method
            self._ospa2_params["cutoff"] = c
            self._ospa2_params["power"] = p
            self._ospa2_params["win_len"] = win_len
            return
6174 true_mat, true_cov_mat = self._ospa_setup_tmat(
6175 truth, state_dim, true_covs, state_inds
6176 )
6177 est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)
6179 # find OSPA
6180 (
6181 self.ospa2,
6182 self.ospa2_localization,
6183 self.ospa2_cardinality,
6184 self._ospa2_params["core"],
6185 self._ospa2_params["cutoff"],
6186 self._ospa2_params["power"],
6187 self._ospa2_params["win_len"],
6188 ) = calculate_ospa2(
6189 est_mat,
6190 true_mat,
6191 c,
6192 p,
6193 win_len,
6194 core_method=core_method,
6195 true_cov_mat=true_cov_mat,
6196 est_cov_mat=est_cov_mat,
6197 )
6199 def plot_states(
6200 self,
6201 plt_inds,
6202 state_lbl="States",
6203 ttl=None,
6204 state_color=None,
6205 x_lbl=None,
6206 y_lbl=None,
6207 **kwargs,
6208 ):
6209 """Plots the best estimate for the states.
6211 This assumes that the states have been extracted. It's designed to plot
6212 two of the state variables (typically x/y position). The error ellipses
6213 are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`
6215 Keyword arguments are processed with
6216 :meth:`gncpy.plotting.init_plotting_opts`. This function
6217 implements
6219 - f_hndl
6220 - true_states
6221 - sig_bnd
6222 - rng
6223 - meas_inds
6224 - lgnd_loc
6225 - marker
6227 Parameters
6228 ----------
6229 plt_inds : list
6230 List of indices in the state vector to plot
6231 state_lbl : string
6232 Value to appear in legend for the states. Only appears if the
6233 legend is shown
6234 ttl : string, optional
6235 Title for the plot, if None a default title is generated. The default
6236 is None.
6237 x_lbl : string
6238 Label for the x-axis.
6239 y_lbl : string
6240 Label for the y-axis.
6242 Returns
6243 -------
6244 Matplotlib figure
6245 Instance of the matplotlib figure used
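
        Example
        -------
        A sketch plotting x/y position with 2-sigma ellipses (assumes state
        indices 0 and 1 are positions and ``save_covs`` was enabled)::

            fig = pmbm.plot_states([0, 1], sig_bnd=2, lgnd_loc="upper right")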
6246 """
6247 opts = pltUtil.init_plotting_opts(**kwargs)
6248 f_hndl = opts["f_hndl"]
6249 true_states = opts["true_states"]
6250 sig_bnd = opts["sig_bnd"]
6251 rng = opts["rng"]
6252 meas_inds = opts["meas_inds"]
6253 lgnd_loc = opts["lgnd_loc"]
6254 marker = opts["marker"]
6255 if ttl is None:
6256 ttl = "State Estimates"
6257 if rng is None:
6258 rng = rnd.default_rng(1)
6259 if x_lbl is None:
6260 x_lbl = "x-position"
6261 if y_lbl is None:
6262 y_lbl = "y-position"
6263 plt_meas = meas_inds is not None
6264 show_sig = sig_bnd is not None and self.save_covs
6266 s_lst = deepcopy(self._states)
6267 x_dim = None
6269 if f_hndl is None:
6270 f_hndl = plt.figure()
6271 f_hndl.add_subplot(1, 1, 1)
6272 # get state dimension
6273 for states in s_lst:
6274 if len(states) > 0:
6275 x_dim = states[0].size
6276 break
6277 # get array of all state values for each label
6278 added_sig_lbl = False
6279 added_true_lbl = False
6280 added_state_lbl = False
6281 added_meas_lbl = False
6282 r = rng.random()
6283 b = rng.random()
6284 g = rng.random()
6285 if state_color is None:
6286 color = (r, g, b)
6287 else:
6288 color = state_color
6289 for tt, states in enumerate(s_lst):
6290 if len(states) == 0:
6291 continue
6292 x = np.concatenate(states, axis=1)
6293 if show_sig:
6294 sigs = [None] * len(states)
6295 for ii, cov in enumerate(self._covs[tt]):
6296 sig = np.zeros((2, 2))
6297 sig[0, 0] = cov[plt_inds[0], plt_inds[0]]
6298 sig[0, 1] = cov[plt_inds[0], plt_inds[1]]
6299 sig[1, 0] = cov[plt_inds[1], plt_inds[0]]
6300 sig[1, 1] = cov[plt_inds[1], plt_inds[1]]
6301 sigs[ii] = sig
6302 # plot
6303 for ii, sig in enumerate(sigs):
6304 if sig is None:
6305 continue
6306 w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
6307 if not added_sig_lbl:
6308 s = r"${}\sigma$ Error Ellipses".format(sig_bnd)
6309 e = Ellipse(
6310 xy=x[plt_inds, ii],
6311 width=w,
6312 height=h,
6313 angle=a,
6314 zorder=-10000,
6315 label=s,
6316 )
6317 added_sig_lbl = True
6318 else:
6319 e = Ellipse(
6320 xy=x[plt_inds, ii],
6321 width=w,
6322 height=h,
6323 angle=a,
6324 zorder=-10000,
6325 )
6326 e.set_clip_box(f_hndl.axes[0].bbox)
6327 e.set_alpha(0.15)
6328 e.set_facecolor(color)
6329 f_hndl.axes[0].add_patch(e)
6330 if not added_state_lbl:
6331 f_hndl.axes[0].scatter(
6332 x[plt_inds[0], :],
6333 x[plt_inds[1], :],
6334 color=color,
6335 edgecolors=(0, 0, 0),
6336 marker=marker,
6337 label=state_lbl,
6338 )
6339 added_state_lbl = True
6340 else:
6341 f_hndl.axes[0].scatter(
6342 x[plt_inds[0], :],
6343 x[plt_inds[1], :],
6344 color=color,
6345 edgecolors=(0, 0, 0),
6346 marker=marker,
6347 )
6348 # if true states are available then plot them
6349 if true_states is not None:
6350 if x_dim is None:
6351 for states in true_states:
6352 if len(states) > 0:
6353 x_dim = states[0].size
6354 break
6355 max_true = max([len(x) for x in true_states])
6356 x = np.nan * np.ones((x_dim, len(true_states), max_true))
6357 for tt, states in enumerate(true_states):
6358 for ii, state in enumerate(states):
6359 x[:, [tt], ii] = state.copy()
6360 for ii in range(0, max_true):
6361 if not added_true_lbl:
6362 f_hndl.axes[0].plot(
6363 x[plt_inds[0], :, ii],
6364 x[plt_inds[1], :, ii],
6365 color="k",
6366 marker=".",
6367 label="True Trajectories",
6368 )
6369 added_true_lbl = True
6370 else:
6371 f_hndl.axes[0].plot(
6372 x[plt_inds[0], :, ii],
6373 x[plt_inds[1], :, ii],
6374 color="k",
6375 marker=".",
6376 )
6377 if plt_meas:
6378 meas_x = []
6379 meas_y = []
6380 for meas_tt in self._meas_tab:
6381 mx_ii = [m[meas_inds[0]].item() for m in meas_tt]
6382 my_ii = [m[meas_inds[1]].item() for m in meas_tt]
6383 meas_x.extend(mx_ii)
6384 meas_y.extend(my_ii)
6385 color = (128 / 255, 128 / 255, 128 / 255)
6386 meas_x = np.asarray(meas_x)
6387 meas_y = np.asarray(meas_y)
6388 if not added_meas_lbl:
6389 f_hndl.axes[0].scatter(
6390 meas_x,
6391 meas_y,
6392 zorder=-1,
6393 alpha=0.35,
6394 color=color,
6395 marker="^",
6396 edgecolors=(0, 0, 0),
6397 label="Measurements",
6398 )
6399 else:
6400 f_hndl.axes[0].scatter(
6401 meas_x,
6402 meas_y,
6403 zorder=-1,
6404 alpha=0.35,
6405 color=color,
6406 marker="^",
6407 edgecolors=(0, 0, 0),
6408 )
6409 f_hndl.axes[0].grid(True)
6410 pltUtil.set_title_label(f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl)
6412 if lgnd_loc is not None:
6413 plt.legend(loc=lgnd_loc)
6414 plt.tight_layout()
6416 return f_hndl
6418 def plot_card_dist(self, ttl=None, **kwargs):
6419 """Plots the current cardinality distribution.
6421 This assumes that the cardinality distribution has been calculated by
6422 the class.
        Keyword arguments are processed with
6425 :meth:`gncpy.plotting.init_plotting_opts`. This function
6426 implements
6428 - f_hndl
6430 Parameters
6431 ----------
6432 ttl : string
6433 Title of the plot, if None a default title is generated. The default
6434 is None.
6436 Returns
6437 -------
6438 Matplotlib figure
6439 Instance of the matplotlib figure used
6440 """
6441 opts = pltUtil.init_plotting_opts(**kwargs)
6442 f_hndl = opts["f_hndl"]
6443 if ttl is None:
6444 ttl = "Cardinality Distribution"
        if len(self._card_dist) == 0:
            warnings.warn("Empty cardinality distribution", RuntimeWarning)
            return f_hndl
6448 if f_hndl is None:
6449 f_hndl = plt.figure()
6450 f_hndl.add_subplot(1, 1, 1)
6451 x_vals = np.arange(0, len(self._card_dist))
6452 f_hndl.axes[0].bar(x_vals, self._card_dist)
6454 pltUtil.set_title_label(
6455 f_hndl, 0, opts, ttl=ttl, x_lbl="Cardinality", y_lbl="Probability"
6456 )
6457 plt.tight_layout()
6459 return f_hndl
6461 def plot_card_history(
6462 self, time_units="index", time=None, ttl="Cardinality History", **kwargs
6463 ):
6464 """Plots the cardinality history.
6466 Parameters
6467 ----------
6468 time_units : string, optional
6469 Text representing the units of time in the plot. The default is
6470 'index'.
6471 time : numpy array, optional
6472 Vector to use for the x-axis of the plot. If none is given then
6473 vector indices are used. The default is None.
6474 ttl : string, optional
6475 Title of the plot.
6476 **kwargs : dict
6477 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
6478 function. Values implemented here are `f_hndl`, and any values
6479 relating to title/axis text formatting.
6481 Returns
6482 -------
6483 fig : matplotlib figure
6484 Figure object the data was plotted on.
6485 """
6486 card_history = np.array([len(state_set) for state_set in self.states])
6488 opts = pltUtil.init_plotting_opts(**kwargs)
6489 fig = opts["f_hndl"]
6491 if fig is None:
6492 fig = plt.figure()
6493 fig.add_subplot(1, 1, 1)
6494 if time is None:
6495 time = np.arange(card_history.size, dtype=int)
6496 fig.axes[0].grid(True)
6497 fig.axes[0].step(time, card_history, where="post", label="estimated", color="k")
6498 fig.axes[0].ticklabel_format(useOffset=False)
6500 pltUtil.set_title_label(
6501 fig,
6502 0,
6503 opts,
6504 ttl=ttl,
6505 x_lbl="Time ({})".format(time_units),
6506 y_lbl="Cardinality",
6507 )
6508 fig.tight_layout()
6510 return fig
6512 def plot_ospa2_history(
6513 self,
6514 time_units="index",
6515 time=None,
6516 main_opts=None,
6517 sub_opts=None,
6518 plot_subs=True,
6519 ):
6520 """Plots the OSPA2 history.
        This requires that the OSPA2 has been calculated by the appropriate
        function first.
6525 Parameters
6526 ----------
6527 time_units : string, optional
6528 Text representing the units of time in the plot. The default is
6529 'index'.
6530 time : numpy array, optional
6531 Vector to use for the x-axis of the plot. If none is given then
6532 vector indices are used. The default is None.
6533 main_opts : dict, optional
6534 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
6535 function. Values implemented here are `f_hndl`, and any values
6536 relating to title/axis text formatting. The default of None implies
6537 the default options are used for the main plot.
6538 sub_opts : dict, optional
6539 Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
6540 function. Values implemented here are `f_hndl`, and any values
6541 relating to title/axis text formatting. The default of None implies
6542 the default options are used for the sub plot.
6543 plot_subs : bool, optional
6544 Flag indicating if the component statistics (cardinality and
6545 localization) should also be plotted.
6547 Returns
6548 -------
6549 figs : dict
6550 Dictionary of matplotlib figure objects the data was plotted on.
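
        Example
        -------
        A sketch (assumes :meth:`calculate_ospa2` has been called first)::

            figs = pmbm.plot_ospa2_history(time_units="s")
            figs["OSPA2"].savefig("ospa2.png")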
6551 """
6552 if self.ospa2 is None:
6553 warnings.warn("OSPA must be calculated before plotting")
6554 return
6555 if main_opts is None:
6556 main_opts = pltUtil.init_plotting_opts()
6557 if sub_opts is None and plot_subs:
6558 sub_opts = pltUtil.init_plotting_opts()
6559 fmt = "{:s} OSPA2 (c = {:.1f}, p = {:d}, w={:d})"
6560 ttl = fmt.format(
6561 self._ospa2_params["core"],
6562 self._ospa2_params["cutoff"],
6563 self._ospa2_params["power"],
6564 self._ospa2_params["win_len"],
6565 )
6566 y_lbl = "OSPA2"
6568 figs = {}
6569 figs["OSPA2"] = self._plt_ospa_hist(
6570 self.ospa2, time_units, time, ttl, y_lbl, main_opts
6571 )
6573 if plot_subs:
6574 fmt = "{:s} OSPA2 Components (c = {:.1f}, p = {:d}, w={:d})"
6575 ttl = fmt.format(
6576 self._ospa2_params["core"],
6577 self._ospa2_params["cutoff"],
6578 self._ospa2_params["power"],
6579 self._ospa2_params["win_len"],
6580 )
6581 y_lbls = ["Localiztion", "Cardinality"]
6582 figs["OSPA2_subs"] = self._plt_ospa_hist_subs(
6583 [self.ospa2_localization, self.ospa2_cardinality],
6584 time_units,
6585 time,
6586 ttl,
6587 y_lbls,
6588 main_opts,
6589 )
6590 return figs
class LabeledPoissonMultiBernoulliMixture(PoissonMultiBernoulliMixture):
    """Implementation of a labeled PMBM filter."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
6597 @property
6598 def labels(self):
6599 """Read only list of extracted labels.
6601 This is a list with 1 element per timestep, and each element is a list
6602 of the best labels extracted at that timestep. The order of each
6603 element corresponds to the state order.
6604 """
6605 return self._labels
6607 def _correct_birth_tab_entry(self, meas, distrib, timestep, filt_args):
6608 new_tab = self._TabEntry()
6609 filt_states, weights, states, covs = self._init_filt_states(distrib)
6611 new_f_states = [None] * len(filt_states)
6612 new_s_hist = [None] * len(filt_states)
6613 new_c_hist = [None] * len(filt_states)
6614 new_w = [None] * len(filt_states)
6615 depleted = False
6616 for ii, (f_state, state, w) in enumerate(zip(filt_states, states, weights)):
6617 try:
6618 (
6619 new_f_states[ii],
6620 new_s_hist[ii],
6621 new_c_hist[ii],
6622 new_w[ii],
6623 ) = self._inner_correct(timestep, meas, f_state, w, state, filt_args)
6624 except (
6625 gerr.ParticleDepletionError,
6626 gerr.ParticleEstimationDomainError,
6627 gerr.ExtremeMeasurementNoiseError,
6628 ):
6629 return None, 0
6630 new_tab.filt_states = new_f_states
6631 new_tab.state_hist = [new_s_hist]
6632 new_tab.cov_hist = [new_c_hist]
6633 new_tab.distrib_weights_hist = []
6634 new_w = [w + np.finfo(float).eps for w in new_w]
6635 if not depleted:
6636 cost = (
6637 np.sum(new_w).item() * self.prob_detection
6638 + self.clutter_rate * self.clutter_den
6639 )
6640 new_tab.distrib_weights_hist.append(
6641 [w / np.sum(new_w).item() for w in new_w]
6642 )
6643 new_tab.exist_prob = (
6644 self.prob_detection
6645 * cost
6646 / (self.clutter_rate * self.clutter_den + self.prob_detection * cost)
6647 )
6648 else:
6649 cost = 0
6650 new_tab.time_index = self._time_index_cntr
6651 return new_tab, cost
6653 def _gen_cor_tab(self, num_meas, meas, timestep, filt_args):
6654 num_pred = len(self._track_tab)
6655 num_birth = len(self.birth_terms)
6656 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)
6658 # Missed Detection Updates
6659 for ii, track in enumerate(self._track_tab):
6660 up_tab[ii] = self._TabEntry().setup(track)
6661 sum_non_exist_prob = (
6662 1
6663 - up_tab[ii].exist_prob
6664 + up_tab[ii].exist_prob * self.prob_miss_detection
6665 )
6666 up_tab[ii].distrib_weights_hist.append(
6667 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
6668 )
6669 up_tab[ii].exist_prob = (
6670 up_tab[ii].exist_prob * self.prob_miss_detection
6671 ) / (sum_non_exist_prob)
6672 up_tab[ii].meas_assoc_hist.append(None)
6675 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))
6677 # Update for all existing tracks
6678 for emm, z in enumerate(meas):
6679 for ii, ent in enumerate(self._track_tab):
6680 s_to_ii = num_pred * emm + ii + num_pred
6681 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
6682 z, ent, timestep, filt_args
6683 )
6684 if up_tab[s_to_ii] is not None:
6685 up_tab[s_to_ii].meas_assoc_hist.append(emm)
6686 all_cost_m[emm, ii] = cost
6688 # Update for all potential new births
6689 for emm, z in enumerate(meas):
6690 for ii, b_model in enumerate(self.birth_terms):
6691 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
6692 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry(
6693 z, b_model, timestep, filt_args
6694 )
6695 if up_tab[s_to_ii] is not None:
6696 up_tab[s_to_ii].meas_assoc_hist.append(emm)
6697 up_tab[s_to_ii].label = (round(timestep, self.decimal_places), ii)
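                    # labels are (birth time, birth model index) tuples, e.g.
                    # a track born from birth term 0 at timestep 1.25 gets
                    # the label (1.25, 0)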
6698 all_cost_m[emm, emm + num_pred] = cost
6699 return up_tab, all_cost_m
6701 def _update_extract_hist(self, idx_cmp):
6702 used_meas_inds = [[] for ii in range(self._time_index_cntr)]
6703 used_labels = []
6704 new_extract_hists = [None] * len(self._hypotheses[idx_cmp].track_set)
6705 for ii, track in enumerate(
6706 [
6707 self._track_tab[trk_ind]
6708 for trk_ind in self._hypotheses[idx_cmp].track_set
6709 ]
6710 ):
6711 new_extract_hists[ii] = self._ExtractHistHelper()
6712 new_extract_hists[ii].label = track.label
6713 new_extract_hists[ii].meas_ind_hist = track.meas_assoc_hist.copy()
6714 new_extract_hists[ii].b_time_index = track.time_index
6715 (
6716 new_extract_hists[ii].states,
6717 new_extract_hists[ii].covs,
6718 ) = self._extract_helper(track)
6720 used_labels.append(track.label)
6722 for t_inds_after_b, meas_ind in enumerate(
6723 new_extract_hists[ii].meas_ind_hist
6724 ):
6725 tt = new_extract_hists[ii].b_time_index + t_inds_after_b
6726 if meas_ind is not None and meas_ind not in used_meas_inds[tt]:
6727 used_meas_inds[tt].append(meas_ind)
6728 good_inds = []
6729 for ii, existing in enumerate(self._extractable_hists):
6730 used = existing.label in used_labels
6731 if used:
6732 continue
6733 for t_inds_after_b, meas_ind in enumerate(existing.meas_ind_hist):
6734 tt = existing.b_time_index + t_inds_after_b
6735 used = meas_ind is not None and meas_ind in used_meas_inds[tt]
6736 if used:
6737 break
6738 if not used:
6739 good_inds.append(ii)
6740 self._extractable_hists = [self._extractable_hists[ii] for ii in good_inds]
6741 self._extractable_hists.extend(new_extract_hists)
6743 def extract_states(self, update=True, calc_states=True):
6744 """Extracts the best state estimates.
6746 This extracts the best states from the distribution. It should be
6747 called once per time step after the correction function. This calls
6748 both the inner filters predict and correct functions so the keyword
6749 arguments must contain any additional variables needed by those
6750 functions.
6752 Parameters
6753 ----------
6754 update : bool, optional
6755 Flag indicating if the label history should be updated. This should
6756 be done once per timestep and can be disabled if calculating states
6757 after the final timestep. The default is True.
6758 calc_states : bool, optional
6759 Flag indicating if the states should be calculated based on the
6760 label history. This only needs to be done before the states are used.
6761 It can simply be called once after the end of the simulation. The
6762 default is true.
6764 Returns
6765 -------
6766 idx_cmp : int
6767 Index of the hypothesis table used when extracting states.
6768 """
6769 card = np.argmax(self._card_dist)
6770 tracks_per_hyp = np.array([x.num_tracks for x in self._hypotheses])
6771 weight_per_hyp = np.array([x.assoc_prob for x in self._hypotheses])
6773 self._states = [[] for ii in range(self._time_index_cntr)]
6774 self._labels = [[] for ii in range(self._time_index_cntr)]
6775 self._covs = [[] for ii in range(self._time_index_cntr)]
6777 if len(tracks_per_hyp) == 0:
6778 return None
6779 idx_cmp = np.argmax(weight_per_hyp * (tracks_per_hyp == card))
6780 if update:
6781 self._update_extract_hist(idx_cmp)
6782 if calc_states:
6783 for existing in self._extractable_hists:
6784 for t_inds_after_b, (s, c) in enumerate(
6785 zip(existing.states, existing.covs)
6786 ):
6787 tt = existing.b_time_index + t_inds_after_b
6793 self._states[tt].append(s)
6794 self._labels[tt].append(existing.label)
6795 self._covs[tt].append(c)
6796 if not update and not calc_states:
6797 warnings.warn("Extracting states performed no actions")
6798 return idx_cmp
6800 def _ospa_setup_emat(self, state_dim, state_inds):
6801 # get sizes
6802 num_timesteps = len(self.states)
6803 num_objs = 0
6804 lbl_to_ind = {}
6806 for lst in self.labels:
6807 for lbl in lst:
6808 if lbl is None:
6809 continue
6810 key = str(lbl)
6811 if key not in lbl_to_ind:
6812 lbl_to_ind[key] = num_objs
6813 num_objs += 1
6814 # create matrices
6815 est_mat = np.nan * np.ones((state_dim, num_timesteps, num_objs))
6816 est_cov_mat = np.nan * np.ones((state_dim, state_dim, num_timesteps, num_objs))
6818 for tt, (lbl_lst, s_lst) in enumerate(zip(self.labels, self.states)):
6819 for lbl, s in zip(lbl_lst, s_lst):
6820 if lbl is None:
6821 continue
6822 obj_num = lbl_to_ind[str(lbl)]
6823 est_mat[:, tt, obj_num] = s.ravel()[state_inds]
6824 if self.save_covs:
6825 for tt, (lbl_lst, c_lst) in enumerate(zip(self.labels, self.covariances)):
6826 for lbl, c in zip(lbl_lst, c_lst):
6827 if lbl is None:
6828 continue
6829 est_cov_mat[:, :, tt, lbl_to_ind[str(lbl)]] = c[state_inds][
6830 :, state_inds
6831 ]
6832 return est_mat, est_cov_mat
6834 def calculate_ospa2(
6835 self,
6836 truth,
6837 c,
6838 p,
6839 win_len,
6840 true_covs=None,
6841 core_method=SingleObjectDistance.MANHATTAN,
6842 state_inds=None,
6843 ):
6844 """Calculates the OSPA(2) distance between the truth at all timesteps.
6846 Wrapper for :func:`serums.distances.calculate_ospa2`.
6848 Parameters
6849 ----------
6850 truth : list
6851 Each element represents a timestep and is a list of N x 1 numpy array,
6852 one per true agent in the swarm.
        c : float
            Distance cutoff for considering a point properly assigned. This
            influences how cardinality errors are penalized. For :math:`p = 1`
            it is the penalty assigned to a false point estimate.
6857 p : int
6858 The power of the distance term. Higher values penalize outliers
6859 more.
6860 win_len : int
6861 Number of samples to include in window.
6862 core_method : :class:`serums.enums.SingleObjectDistance`, Optional
6863 The main distance measure to use for the localization component.
6864 The default value is :attr:`.SingleObjectDistance.MANHATTAN`.
        true_covs : list, Optional
            Each element represents a timestep and is a list of N x N numpy arrays
            corresponding to the uncertainty about the true states. Note the
            order must be consistent with the truth data given. This is only
            needed for the :attr:`SingleObjectDistance.HELLINGER` core method.
            The default value is None.
6871 state_inds : list, optional
6872 Indices in the state vector to use, will be applied to the truth
6873 data as well. The default is None which means the full state is
6874 used.
6875 """
6876 # error checking on optional input arguments
6877 core_method = self._ospa_input_check(core_method, truth, true_covs)
6879 # setup data structures
6880 if state_inds is None:
6881 state_dim = self._ospa_find_s_dim(truth)
6882 state_inds = range(state_dim)
6883 else:
6884 state_dim = len(state_inds)
        if state_dim is None:
            warnings.warn("Failed to get state dimension. SKIPPING OSPA(2) calculation")
            nt = len(self._states)
            self.ospa2 = np.zeros(nt)
            self.ospa2_localization = np.zeros(nt)
            self.ospa2_cardinality = np.zeros(nt)
            self._ospa2_params["core"] = core_method
            self._ospa2_params["cutoff"] = c
            self._ospa2_params["power"] = p
            self._ospa2_params["win_len"] = win_len
            return
6897 true_mat, true_cov_mat = self._ospa_setup_tmat(
6898 truth, state_dim, true_covs, state_inds
6899 )
6900 est_mat, est_cov_mat = self._ospa_setup_emat(state_dim, state_inds)
6902 # find OSPA
6903 (
6904 self.ospa2,
6905 self.ospa2_localization,
6906 self.ospa2_cardinality,
6907 self._ospa2_params["core"],
6908 self._ospa2_params["cutoff"],
6909 self._ospa2_params["power"],
6910 self._ospa2_params["win_len"],
6911 ) = calculate_ospa2(
6912 est_mat,
6913 true_mat,
6914 c,
6915 p,
6916 win_len,
6917 core_method=core_method,
6918 true_cov_mat=true_cov_mat,
6919 est_cov_mat=est_cov_mat,
6920 )
6922 def plot_states_labels(
6923 self,
6924 plt_inds,
6925 ttl="Labeled State Trajectories",
6926 x_lbl=None,
6927 y_lbl=None,
6928 meas_tx_fnc=None,
6929 **kwargs,
6930 ):
6931 """Plots the best estimate for the states and labels.
6933 This assumes that the states have been extracted. It's designed to plot
6934 two of the state variables (typically x/y position). The error ellipses
6935 are calculated according to :cite:`Hoover1984_AlgorithmsforConfidenceCirclesandEllipses`
        Keyword arguments are processed with
6938 :meth:`gncpy.plotting.init_plotting_opts`. This function
6939 implements
6941 - f_hndl
6942 - true_states
6943 - sig_bnd
6944 - rng
6945 - meas_inds
6946 - lgnd_loc
6948 Parameters
6949 ----------
6950 plt_inds : list
6951 List of indices in the state vector to plot
6952 ttl : string, optional
6953 Title of the plot.
6954 x_lbl : string, optional
6955 X-axis label for the plot.
6956 y_lbl : string, optional
6957 Y-axis label for the plot.
6958 meas_tx_fnc : callable, optional
6959 Takes in the measurement vector as an Nm x 1 numpy array and
6960 returns a numpy array representing the states to plot (size 2). The
6961 default is None.
6963 Returns
6964 -------
6965 Matplotlib figure
6966 Instance of the matplotlib figure used
6967 """
6968 opts = pltUtil.init_plotting_opts(**kwargs)
6969 f_hndl = opts["f_hndl"]
6970 true_states = opts["true_states"]
6971 sig_bnd = opts["sig_bnd"]
6972 rng = opts["rng"]
6973 meas_inds = opts["meas_inds"]
6974 lgnd_loc = opts["lgnd_loc"]
6975 mrkr = opts["marker"]
6977 if rng is None:
6978 rng = rnd.default_rng(1)
6979 if x_lbl is None:
6980 x_lbl = "x-position"
6981 if y_lbl is None:
6982 y_lbl = "y-position"
6983 meas_specs_given = (
6984 meas_inds is not None and len(meas_inds) == 2
6985 ) or meas_tx_fnc is not None
6986 plt_meas = meas_specs_given and self.save_measurements
6987 show_sig = sig_bnd is not None and self.save_covs
6989 s_lst = deepcopy(self.states)
6990 l_lst = deepcopy(self.labels)
6991 x_dim = None
6993 if f_hndl is None:
6994 f_hndl = plt.figure()
6995 f_hndl.add_subplot(1, 1, 1)
6996 # get state dimension
6997 for states in s_lst:
6998 if states is not None and len(states) > 0:
6999 x_dim = states[0].size
7000 break
7001 # get unique labels
7002 u_lbls = []
7003 for lbls in l_lst:
7004 if lbls is None:
7005 continue
7006 for lbl in lbls:
7007 if lbl not in u_lbls:
7008 u_lbls.append(lbl)
7009 cmap = pltUtil.get_cmap(len(u_lbls))
7011 # get array of all state values for each label
7012 added_sig_lbl = False
7013 added_true_lbl = False
7014 added_state_lbl = False
7015 added_meas_lbl = False
7016 for c_idx, lbl in enumerate(u_lbls):
7017 x = np.nan * np.ones((x_dim, len(s_lst)))
7018 if show_sig:
7019 sigs = [None] * len(s_lst)
7020 for tt, lbls in enumerate(l_lst):
7021 if lbls is None:
7022 continue
7023 if lbl in lbls:
7024 ii = lbls.index(lbl)
7025 if s_lst[tt][ii] is not None:
7026 x[:, [tt]] = s_lst[tt][ii].copy()
7027 if show_sig:
7028 sig = np.zeros((2, 2))
7029 if self._covs[tt][ii] is not None:
7030 sig[0, 0] = self._covs[tt][ii][plt_inds[0], plt_inds[0]]
7031 sig[0, 1] = self._covs[tt][ii][plt_inds[0], plt_inds[1]]
7032 sig[1, 0] = self._covs[tt][ii][plt_inds[1], plt_inds[0]]
7033 sig[1, 1] = self._covs[tt][ii][plt_inds[1], plt_inds[1]]
7034 else:
7035 sig = None
7036 sigs[tt] = sig
7037 # plot
7038 color = cmap(c_idx)
7040 if show_sig:
7041 for tt, sig in enumerate(sigs):
7042 if sig is None:
7043 continue
7044 w, h, a = pltUtil.calc_error_ellipse(sig, sig_bnd)
7045 if not added_sig_lbl:
7046 s = r"${}\sigma$ Error Ellipses".format(sig_bnd)
7047 e = Ellipse(
7048 xy=x[plt_inds, tt],
7049 width=w,
7050 height=h,
7051 angle=a,
7052 zorder=-10000,
7053 label=s,
7054 )
7055 added_sig_lbl = True
7056 else:
7057 e = Ellipse(
7058 xy=x[plt_inds, tt],
7059 width=w,
7060 height=h,
7061 angle=a,
7062 zorder=-10000,
7063 )
7064 e.set_clip_box(f_hndl.axes[0].bbox)
7065 e.set_alpha(0.2)
7066 e.set_facecolor(color)
7067 f_hndl.axes[0].add_patch(e)
7068 settings = {
7069 "color": color,
7070 "markeredgecolor": "k",
7071 "marker": mrkr,
7072 "ls": "--",
7073 }
            if not added_state_lbl:
                settings["label"] = "States"
                added_state_lbl = True
            f_hndl.axes[0].plot(x[plt_inds[0], :], x[plt_inds[1], :], **settings)
7083 s = "({}, {})".format(lbl[0], lbl[1])
7084 tmp = x.copy()
7085 tmp = tmp[:, ~np.any(np.isnan(tmp), axis=0)]
7086 f_hndl.axes[0].text(
7087 tmp[plt_inds[0], 0], tmp[plt_inds[1], 0], s, color=color
7088 )
7089 # if true states are available then plot them
7090 if true_states is not None and any([len(x) > 0 for x in true_states]):
7091 if x_dim is None:
7092 for states in true_states:
7093 if len(states) > 0:
7094 x_dim = states[0].size
7095 break
7096 max_true = max([len(x) for x in true_states])
7097 x = np.nan * np.ones((x_dim, len(true_states), max_true))
7098 for tt, states in enumerate(true_states):
7099 for ii, state in enumerate(states):
7100 if state is not None and state.size > 0:
7101 x[:, [tt], ii] = state.copy()
7102 for ii in range(0, max_true):
7103 if not added_true_lbl:
7104 f_hndl.axes[0].plot(
7105 x[plt_inds[0], :, ii],
7106 x[plt_inds[1], :, ii],
7107 color="k",
7108 marker=".",
7109 label="True Trajectories",
7110 )
7111 added_true_lbl = True
7112 else:
7113 f_hndl.axes[0].plot(
7114 x[plt_inds[0], :, ii],
7115 x[plt_inds[1], :, ii],
7116 color="k",
7117 marker=".",
7118 )
7119 if plt_meas:
7120 meas_x = []
7121 meas_y = []
7122 for meas_tt in self._meas_tab:
7123 if meas_tx_fnc is not None:
7124 tx_meas = [meas_tx_fnc(m) for m in meas_tt]
7125 mx_ii = [tm[0].item() for tm in tx_meas]
7126 my_ii = [tm[1].item() for tm in tx_meas]
7127 else:
7128 mx_ii = [m[meas_inds[0]].item() for m in meas_tt]
7129 my_ii = [m[meas_inds[1]].item() for m in meas_tt]
7130 meas_x.extend(mx_ii)
7131 meas_y.extend(my_ii)
7132 color = (128 / 255, 128 / 255, 128 / 255)
7133 meas_x = np.asarray(meas_x)
7134 meas_y = np.asarray(meas_y)
7135 if meas_x.size > 0:
7136 if not added_meas_lbl:
7137 f_hndl.axes[0].scatter(
7138 meas_x,
7139 meas_y,
7140 zorder=-1,
7141 alpha=0.35,
7142 color=color,
7143 marker="^",
7144 label="Measurements",
7145 )
7146 else:
7147 f_hndl.axes[0].scatter(
7148 meas_x, meas_y, zorder=-1, alpha=0.35, color=color, marker="^"
7149 )
7150 f_hndl.axes[0].grid(True)
        pltUtil.set_title_label(
            f_hndl, 0, opts, ttl=ttl, x_lbl=x_lbl, y_lbl=y_lbl
        )
7154 if lgnd_loc is not None:
7155 plt.legend(loc=lgnd_loc)
7156 plt.tight_layout()
7158 return f_hndl
7161class _STMPMBMBase:
7162 def __init__(self, **kwargs):
7163 super().__init__(**kwargs)
7165 def _init_filt_states(self, distrib):
7166 filt_states = [None] * len(distrib.means)
7167 states = [m.copy() for m in distrib.means]
7168 covs = [None] * len(distrib.means)
7170 weights = distrib.weights.copy()
7171 self._baseFilter.dof = distrib.dof
7172 for ii, scale in enumerate(distrib.scalings):
7173 self._baseFilter.scale = scale.copy()
7174 filt_states[ii] = self._baseFilter.save_filter_state()
        if self.save_covs:
            # no need to copy because cov is already a new object for the Student's t-filter
            covs[ii] = self.filter.cov
7178 return filt_states, weights, states, covs
7180 def _gate_meas(self, meas, means, covs, **kwargs):
7181 # TODO: check this implementation
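        # Illustrative note (values assumed, not from the source): for a 2-D
        # measurement with a 99% gate, inv_chi2_gate would be set to
        # scipy.stats.chi2.ppf(0.99, df=2) ~= 9.21, and any measurement whose
        # squared Mahalanobis distance exceeds it is discarded.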
7182 if len(meas) == 0:
7183 return []
7184 scalings = []
7185 for ent in self._track_tab:
7186 scalings.extend(ent.probDensity.scalings)
7187 valid = []
7188 for m, p in zip(means, scalings):
7189 meas_mat = self.filter.get_meas_mat(m, **kwargs)
7190 est = self.filter.get_est_meas(m, **kwargs)
7191 factor = (
7192 self.filter.meas_noise_dof
7193 * (self.filter.dof - 2)
7194 / (self.filter.dof * (self.filter.meas_noise_dof - 2))
7195 )
7196 P_zz = meas_mat @ p @ meas_mat.T + factor * self.filter.meas_noise
7197 inv_P = la.inv(P_zz)
7199 for ii, z in enumerate(meas):
7200 if ii in valid:
7201 continue
7202 innov = z - est
7203 dist = innov.T @ inv_P @ innov
7204 if dist < self.inv_chi2_gate:
7205 valid.append(ii)
7206 valid.sort()
7207 return [meas[ii] for ii in valid]
7210class STMPoissonMultiBernoulliMixture(_STMPMBMBase, PoissonMultiBernoulliMixture):
7211 """Implementation of a STM-PMBM filter."""
7213 def __init__(self, **kwargs):
7214 super().__init__(**kwargs)
7217class STMLabeledPoissonMultiBernoulliMixture(
7218 _STMPMBMBase, LabeledPoissonMultiBernoulliMixture
7219):
7220 """Implementation of a STM-LPMBM filter."""
7222 def __init__(self, **kwargs):
7223 super().__init__(**kwargs)
7226class _SMCPMBMBase:
7227 def __init__(
7228 self, compute_prob_detection=None, compute_prob_survive=None, **kwargs
7229 ):
7230 self.compute_prob_detection = compute_prob_detection
7231 self.compute_prob_survive = compute_prob_survive
7233 # used by the predict/correct wrappers to pass extra args to the private helper functions
7234 self._prob_surv_args = ()
7235 self._prob_det_args = ()
7237 super().__init__(**kwargs)
7239 def _init_filt_states(self, distrib):
7240 self._baseFilter.init_from_dist(distrib, make_copy=True)
7241 filt_states = [
7242 self._baseFilter.save_filter_state(),
7243 ]
7244 states = [distrib.mean]
7245 if self.save_covs:
7246 covs = [
7247 distrib.covariance,
7248 ]
7249 else:
7250 covs = []
7251 weights = [
7252 1,
7253 ] # not needed so set to 1
7255 return filt_states, weights, states, covs
7257 def _calc_avg_prob_surv_death(self):
7258 avg_prob_survive = np.zeros(len(self._track_tab))
7259 for tabidx, ent in enumerate(self._track_tab):
7260 # TODO: fix hack so not using "private" variable outside class
7261 p_surv = self.compute_prob_survive(
7262 ent.filt_states[0]["_particleDist"].particles, *self._prob_surv_args
7263 )
7264 avg_prob_survive[tabidx] = np.sum(
7265 np.array(ent.filt_states[0]["_particleDist"].weights) * p_surv
7266 )
7267 avg_prob_death = 1 - avg_prob_survive
7269 return avg_prob_survive, avg_prob_death
7271 def _inner_predict(self, timestep, filt_state, state, filt_args):
7272 self.filter.load_filter_state(filt_state)
7273 if self.filter._particleDist.num_particles > 0:
7274 new_s = self.filter.predict(timestep, **filt_args)
7276 # manually update weights to account for prob survive
7277 # TODO: fix hack so not using "private" variable outside class
7278 ps = self.compute_prob_survive(
7279 self.filter._particleDist.particles, *self._prob_surv_args
7280 )
7281 new_weights = [
7282 w * ps[ii] for ii, (p, w) in enumerate(self.filter._particleDist)
7283 ]
7284 tot = sum(new_weights)
7285 if np.abs(tot) == np.inf:
7286 w_lst = [np.inf] * len(new_weights)
7287 else:
7288 w_lst = [w / tot for w in new_weights]
7289 self.filter._particleDist.update_weights(w_lst)
7291 new_f_state = self.filter.save_filter_state()
7292 if self.save_covs:
7293 new_cov = self.filter.cov.copy()
7294 else:
7295 new_cov = None
7296 else:
7297 new_f_state = self.filter.save_filter_state()
7298 new_s = state
7299 new_cov = self.filter.cov
7300 return new_f_state, new_s, new_cov
7302 def predict(self, timestep, prob_surv_args=(), **kwargs):
7303 """Prediction step of the SMC-GLMB filter.
7305 This is a wrapper for the parent class to allow for extra parameters.
7306 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict` for
7307 additional details.
7309 Parameters
7310 ----------
7311 timestep : float
7312 Current timestep.
7313 prob_surv_args : tuple, optional
7314 Additional arguments for the `compute_prob_survive` function.
7315 The default is ().
7316 **kwargs : dict, optional
7317 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.predict`
7318 """
7319 self._prob_surv_args = prob_surv_args
7320 return super().predict(timestep, **kwargs)
7322 def _calc_avg_prob_det_mdet(self, cor_tab):
7323 avg_prob_detect = np.zeros(len(cor_tab))
7324 for tabidx, ent in enumerate(cor_tab):
7325 # TODO: fix hack so not using "private" variable outside class
7326 p_detect = self.compute_prob_detection(
7327 ent.filt_states[0]["_particleDist"].particles, *self._prob_det_args
7328 )
7329 avg_prob_detect[tabidx] = np.sum(
7330 np.array(ent.filt_states[0]["_particleDist"].weights) * p_detect
7331 )
7332 avg_prob_miss_detect = 1 - avg_prob_detect
7334 return avg_prob_detect, avg_prob_miss_detect
7336 def _inner_correct(
7337 self, timestep, meas, filt_state, distrib_weight, state, filt_args
7338 ):
7339 self.filter.load_filter_state(filt_state)
7340 if self.filter._particleDist.num_particles > 0:
7341 cor_state, likely = self.filter.correct(timestep, meas, **filt_args)[0:2]
7343 # manually update the particle weights to account for probability of detection
7344 # TODO: fix hack so not using "private" variable outside class
7345 pd = self.compute_prob_detection(
7346 self.filter._particleDist.particles, *self._prob_det_args
7347 )
7348 pd_weight = (
7349 pd * np.array(self.filter._particleDist.weights) + np.finfo(float).eps
7350 )
7351 self.filter._particleDist.update_weights(
7352 (pd_weight / np.sum(pd_weight)).tolist()
7353 )
7355 # determine the partial cost, the remainder is calculated later from
7356 # the hypothesis
7357 new_w = np.sum(likely * pd_weight) # same as cost in this case
7359 new_f_state = self.filter.save_filter_state()
7360 new_s = cor_state
7361 if self.save_covs:
7362 new_c = self.filter.cov
7363 else:
7364 new_c = None
7365 else:
7366 new_f_state = self.filter.save_filter_state()
7367 new_s = state
7368 new_c = self.filter.cov
7369 new_w = 0
7370 return new_f_state, new_s, new_c, new_w
7372 def correct(self, timestep, meas, prob_det_args=(), **kwargs):
7373 """Correction step of the SMC-GLMB filter.
7375 This is a wrapper for the parent class to allow for extra parameters.
7376 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct` for
7377 additional details.
7379 Parameters
7380 ----------
7381 timestep : float
7382 Current timestep.
7383 prob_det_args : tuple, optional
7384 Additional arguments for the `compute_prob_detection` function.
7385 The default is ().
7386 **kwargs : dict, optional
7387 See :meth:`.tracker.GeneralizedLabeledMultiBernoulli.correct`
7388 """
7389 self._prob_det_args = prob_det_args
7390 return super().correct(timestep, meas, **kwargs)
7392 def extract_most_prob_states(self, thresh, **kwargs):
7393 """Extracts themost probable states.
7395 .. todo::
7396 Implement this function for the SMC-PMBM filter.
7398 Raises
7399 ------
7400 RuntimeWarning
7401 Function must be implemented.
7402 """
7403 warnings.warn("Not implemented for this class")
7406class SMCPoissonMultiBernoulliMixture(_SMCPMBMBase, PoissonMultiBernoulliMixture):
7407 """Implementation of a Sequential Monte Carlo PMBM filter.
7409 This filter does not account for agents spawned from existing tracks, only agents
7410 birthed from the given birth model.
7412 Attributes
7413 ----------
7414 compute_prob_detection : callable
7415 Function that takes a list of particles as the first argument and `*args`
7416 as the next. Returns the probability of detection for each particle as a list.
7417 compute_prob_survive : callable
7418 Function that takes a list of particles as the first argument and `*args` as
7419 the next. Returns the probability of survival for each particle as a list.
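
Examples
--------
A minimal usage sketch; the particle filter, birth terms, and the
``sensor_pos`` argument are hypothetical placeholders::

    def compute_prob_detection(particles, sensor_pos):
        # per-particle probability of detection (placeholder model)
        return np.array([0.95] * len(particles))

    def compute_prob_survive(particles):
        # per-particle probability of survival (placeholder model)
        return np.array([0.99] * len(particles))

    smc_pmbm = SMCPoissonMultiBernoulliMixture(
        in_filter=particle_filter,  # a configured gncpy particle filter
        birth_terms=birth_terms,
        compute_prob_detection=compute_prob_detection,
        compute_prob_survive=compute_prob_survive,
    )
    smc_pmbm.predict(timestep, prob_surv_args=())
    smc_pmbm.correct(timestep, meas_list, prob_det_args=(sensor_pos,))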
7420 """
7422 def __init__(self, **kwargs):
7423 super().__init__(**kwargs)
7426class SMCLabeledPoissonMultiBernoulliMixture(
7427 _SMCPMBMBase, LabeledPoissonMultiBernoulliMixture
7428):
7429 """Implementation of a Sequential Monte Carlo LPMBM filter.
7431 This filter does not account for agents spawned from existing tracks, only agents
7432 birthed from the given birth model.
7434 Attributes
7435 ----------
7436 compute_prob_detection : callable
7437 Function that takes a list of particles as the first argument and `*args`
7438 as the next. Returns the probability of detection for each particle as a list.
7439 compute_prob_survive : callable
7440 Function that takes a list of particles as the first argument and `*args` as
7441 the next. Returns the probability of survival for each particle as a list.
7442 """
7444 def __init__(self, **kwargs):
7445 super().__init__(**kwargs)
7448class _IMMPMBMBase:
7449 def _init_filt_states(self, distrib):
7450 filt_states = [None] * len(distrib.means)
7451 states = [m.copy() for m in distrib.means]
7452 if self.save_covs:
7453 covs = [c.copy() for c in distrib.covariances]
7454 else:
7455 covs = []
7456 weights = distrib.weights.copy()
7457 for ii, (m, cov) in enumerate(zip(distrib.means, distrib.covariances)):
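# replicate the mean and covariance across every model in the IMM bank so
# each inner filter starts from the same initial state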
7460 m_list = []
7461 c_list = []
7462 for jj in range(0, len(self._baseFilter.in_filt_list)):
7463 m_list.append(m)
7464 c_list.append(cov)
7465 self._baseFilter.initialize_states(m_list, c_list)
7466 filt_states[ii] = self._baseFilter.save_filter_state()
7467 return filt_states, weights, states, covs
7469 def _inner_predict(self, timestep, filt_state, state, filt_args):
7470 self.filter.load_filter_state(filt_state)
7471 new_s = self.filter.predict(timestep, **filt_args)
7472 new_f_state = self.filter.save_filter_state()
7473 if self.save_covs:
7474 new_cov = self.filter.cov.copy()
7475 else:
7476 new_cov = None
7477 return new_f_state, new_s, new_cov
7479 def _inner_correct(
7480 self, timestep, meas, filt_state, distrib_weight, state, filt_args
7481 ):
7482 self.filter.load_filter_state(filt_state)
7483 cor_state, likely = self.filter.correct(timestep, meas, **filt_args)
7484 new_f_state = self.filter.save_filter_state()
7485 new_s = cor_state
7486 if self.save_covs:
7487 new_c = self.filter.cov.copy()
7488 else:
7489 new_c = None
7490 new_w = distrib_weight * likely
7492 return new_f_state, new_s, new_c, new_w
7495class IMMPoissonMultiBernoulliMixture(_IMMPMBMBase, PoissonMultiBernoulliMixture):
7496 """An implementation of the IMM-PMBM algorithm."""
7498 def __init__(self, **kwargs):
7499 super().__init__(**kwargs)
7502class IMMLabeledPoissonMultiBernoulliMixture(
7503 _IMMPMBMBase, LabeledPoissonMultiBernoulliMixture
7504):
7505 """An implementation of the IMM-LPMBM algorithm."""
7507 def __init__(self, **kwargs):
7508 super().__init__(**kwargs)
7511class MSPoissonMultiBernoulliMixture(PoissonMultiBernoulliMixture):
7512 """An Implementation of the Multiple Sensor PMBM Filter."""
7514 # Need measurement association history to incorporate meas inds from each sensor
7515 def __init__(self, **kwargs):
7516 super().__init__(**kwargs)
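# Hypothetical sketch of the multi-sensor measurement structure expected by
# correct(): one list per sensor, each holding that sensor's Nm x 1 arrays.
#
#     meas = [
#         [np.array([[1.0], [2.0]]), np.array([[1.1], [2.1]])],  # sensor 0
#         [np.array([[0.9], [1.9]])],  # sensor 1
#     ]
#     ms_pmbm.correct(timestep, meas)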
7518 def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args):
7519 num_pred = len(self._track_tab)
7520 num_birth = len(self.birth_terms)
7521 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)
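# up_tab layout: entries [0, num_pred) hold the miss-detection updates,
# [num_pred, (num_meas + 1) * num_pred) hold the measurement-updated tracks
# grouped by measurement, and the last num_meas * num_birth entries hold the
# measurement-updated birth terms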
7523 # Missed Detection Updates
7524 for ii, track in enumerate(self._track_tab):
7525 up_tab[ii] = self._TabEntry().setup(track)
7526 sum_non_exist_prob = (
7527 1
7528 - up_tab[ii].exist_prob
7529 + up_tab[ii].exist_prob * self.prob_miss_detection
7530 )
7531 up_tab[ii].distrib_weights_hist.append(
7532 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
7533 )
7534 up_tab[ii].exist_prob = (
7535 up_tab[ii].exist_prob * self.prob_miss_detection
7536 ) / (sum_non_exist_prob)
7537 up_tab[ii].meas_assoc_hist.append(None)
7540 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))
7542 # Update for all existing tracks
7543 for emm, z in enumerate(meas):
7544 for ii, ent in enumerate(self._track_tab):
7545 s_to_ii = num_pred * emm + ii + num_pred
7546 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
7547 z, ent, timestep, filt_args
7548 )
7549 if up_tab[s_to_ii] is not None:
7550 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
7551 all_cost_m[emm, ii] = cost
7553 # Update for all potential new births
7554 for emm, z in enumerate(meas):
7555 for ii, b_model in enumerate(self.birth_terms):
7556 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
7557 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry(
7558 z, b_model, timestep, filt_args
7559 )
7560 if up_tab[s_to_ii] is not None:
7561 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
7562 all_cost_m[emm, emm + num_pred] = cost
7563 return up_tab, all_cost_m
7565 def _gen_cor_hyps(
7566 self,
7567 num_meas,
7568 avg_prob_detect,
7569 avg_prob_miss_detect,
7570 all_cost_m,
7571 meas_combs,
7572 cor_tab,
7573 ):
7574 num_pred = len(self._track_tab)
7575 up_hyps = []
7576 if not meas_combs:
7577 meas_combs = np.arange(0, np.shape(all_cost_m)[0]).tolist()
7579 if num_meas == 0:
7580 for hyp in self._hypotheses:
7581 pmd_log = np.sum(
7582 [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
7583 )
7584 hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
7585 up_hyps.append(hyp)
7586 else:
7587 clutter = self.clutter_rate * self.clutter_den
7588 ss_w = 0
7589 for p_hyp in self._hypotheses:
7590 ss_w += np.sqrt(p_hyp.assoc_prob)
7591 for p_hyp in self._hypotheses:
7592 for ind_lst in meas_combs:
7593 if len(meas_combs) == 1:
7594 if p_hyp.num_tracks == 0: # all clutter
7595 inds = np.arange(num_pred, num_pred + num_meas).tolist()
7596 else:
7597 inds = (
7598 p_hyp.track_set
7599 + np.arange(num_pred, num_pred + num_meas).tolist()
7600 )
7601 cost_m = all_cost_m[:, inds]
7602 else:
7603 if p_hyp.num_tracks == 0: # all clutter
7604 inds = np.arange(num_pred, num_pred + len(ind_lst)).tolist()
7605 else:
7606 inds = p_hyp.track_set + [x + num_pred for x in ind_lst]
7607 tcm = all_cost_m[
7608 :, inds
7609 ] # TODO: verify this indexing; it is a suspected source of errors
7610 cost_m = tcm[ind_lst, :]
7611 max_row_inds, max_col_inds = np.where(cost_m >= np.inf)
7612 if max_row_inds.size > 0:
7613 cost_m[max_row_inds, max_col_inds] = np.finfo(float).max
7614 min_row_inds, min_col_inds = np.where(cost_m <= 0.0)
7615 if min_row_inds.size > 0:
7616 cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps
7617 neg_log = -np.log(cost_m)
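# request a number of assignments from Murty's algorithm proportional to
# the square root of this hypothesis' weight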
7619 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
7620 m = int(m.item())
7622 [assigns, costs] = murty_m_best_all_meas_assigned(neg_log, m)
7624 pmd_log = np.sum(
7625 [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
7626 )
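# log hypothesis weight: clutter terms, miss-detection terms, the parent
# hypothesis weight, and minus the assignment cost from Murty's algorithm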
7627 for a, c in zip(assigns, costs):
7628 new_hyp = self._HypothesisHelper()
7629 new_hyp.assoc_prob = (
7630 -self.clutter_rate
7631 + num_meas * np.log(clutter)
7632 + pmd_log
7633 + np.log(p_hyp.assoc_prob)
7634 - c
7635 )
7636 if p_hyp.num_tracks == 0:
7637 new_track_list = list(num_pred * a + num_pred * num_meas)
7638 else:
7640 new_track_list = []
7642 for ii, ms in enumerate(a):
7643 if len(p_hyp.track_set) >= ms:
7644 new_track_list.append(
7645 (ii + 1) * num_pred + p_hyp.track_set[(ms - 1)]
7646 )
7647 elif len(p_hyp.track_set) == len(a):
7648 new_track_list.append(
7649 num_pred * ms - ii * (num_pred - 1)
7650 )
7651 elif len(p_hyp.track_set) < len(a):
7652 new_track_list.append(
7653 num_pred * (num_meas + 1) + (ms - num_meas)
7654 )
7655 else:
7656 new_track_list.append(
7657 (num_meas + 1) * num_pred
7658 - (ms - len(p_hyp.track_set) - 1)
7659 )
7699 new_hyp.track_set = new_track_list
7700 up_hyps.append(new_hyp)
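# normalize hypothesis weights in log space to avoid numerical underflow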
7702 lse = log_sum_exp([x.assoc_prob for x in up_hyps])
7703 for ii in range(0, len(up_hyps)):
7704 up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse)
7705 return up_hyps
7707 def correct(self, timestep, meas, filt_args={}):
7708 """Correction step of the MS-PMBM filter.
7710 Notes
7711 -----
7712 This corrects the hypotheses based on the measurements and updates the
7713 cardinality distribution. Gating is not yet implemented for this class
7714 and is skipped with a warning if enabled.
7716 Parameters
7717 ----------
7718 timestep: float
7719 Current timestep.
7720 meas : list
7721 List of lists, one per sensor, each containing Nm x 1 numpy arrays representing that sensor's measurements.
7722 filt_args : dict, optional
7723 Keyword arguments to pass to the inner filter's correct function.
7724 The default is {}.
7726 Returns
7727 -------
7728 None
7729 """
7730 all_combs = list(itertools.product(*meas))
7731 # TODO: add a method for associating single measurements, i.e., all_combs needs to include single-measurement options
7732 if self.gating_on:
7733 warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
7734 # means = []
7735 # covs = []
7736 # for ent in self._track_tab:
7737 # means.extend(ent.probDensity.means)
7738 # covs.extend(ent.probDensity.covariances)
7739 # meas = self._gate_meas(meas, means, covs)
7740 if self.save_measurements:
7741 self._meas_tab.append(deepcopy(meas))
7743 # get matrix of indices in all_combs
7744 num_meas_per_sens = [len(x) for x in meas]
7745 num_meas = len(all_combs)
7746 num_sens = len(meas)
7747 mnmps = min(num_meas_per_sens)
7748 comb_inds = list(itertools.product(*list(np.arange(0, len(x)) for x in meas)))
7749 comb_inds = [list(ele) for ele in comb_inds]
7750 min_meas_in_sens = np.min([len(x) for x in meas])
7752 all_meas_combs = list(itertools.combinations(comb_inds, mnmps))
7753 all_meas_combs = [list(ele) for ele in all_meas_combs]
7755 poss_meas_combs = []
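# keep only pairings whose index tuples never assign the same measurement
# from a sensor twice; e.g. with 2 sensors of 2 measurements each, comb_inds
# is [[0, 0], [0, 1], [1, 0], [1, 1]] and the valid pairings are
# ([0, 0], [1, 1]) and ([0, 1], [1, 0]), stored as indices [0, 3] and [1, 2]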
7757 for ii in range(0, len(all_meas_combs)):
7758 break_flag = False
7759 cur_comb = []
7760 for jj, lst1 in enumerate(all_meas_combs[ii]):
7761 for kk, lst2 in enumerate(all_meas_combs[ii]):
7762 if jj == kk:
7763 continue
7764 else:
7765 out = (np.array(lst1) == np.array(lst2)).tolist()
7766 if any(out):
7767 break_flag = True
7768 break
7769 if break_flag:
7770 break
7771 if not break_flag:
7774 for lst1 in all_meas_combs[ii]:
7775 for mm, lst2 in enumerate(comb_inds):  # mm avoids shadowing the outer loop index
7776 if lst1 == lst2:
7777 cur_comb.append(mm)
7778 poss_meas_combs.append(cur_comb)
7780 cor_tab, all_cost_m = self._gen_cor_tab(
7781 num_meas, all_combs, timestep, comb_inds, filt_args
7782 )
7784 # self._add_birth_hyps(num_meas)
7786 avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab)
7788 cor_hyps = self._gen_cor_hyps(
7789 num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, poss_meas_combs, cor_tab
7790 )
7792 self._track_tab = cor_tab
7793 self._hypotheses = cor_hyps
7794 self._card_dist = self._calc_card_dist(self._hypotheses)
7795 self._clean_updates()
7798class MSLabeledPoissonMultiBernoulliMixture(LabeledPoissonMultiBernoulliMixture):
"""An implementation of the multiple-sensor labeled PMBM (MS-LPMBM) filter."""
7799 def __init__(self, **kwargs):
7800 super().__init__(**kwargs)
7802 def _gen_cor_tab(self, num_meas, meas, timestep, comb_inds, filt_args):
7803 num_pred = len(self._track_tab)
7804 num_birth = len(self.birth_terms)
7805 up_tab = [None] * ((num_meas + 1) * num_pred + num_meas * num_birth)
7807 # Missed Detection Updates
7808 for ii, track in enumerate(self._track_tab):
7809 up_tab[ii] = self._TabEntry().setup(track)
7810 sum_non_exist_prob = (
7811 1
7812 - up_tab[ii].exist_prob
7813 + up_tab[ii].exist_prob * self.prob_miss_detection
7814 )
7815 up_tab[ii].distrib_weights_hist.append(
7816 [w * sum_non_exist_prob for w in up_tab[ii].distrib_weights_hist[-1]]
7817 )
7818 up_tab[ii].exist_prob = (
7819 up_tab[ii].exist_prob * self.prob_miss_detection
7820 ) / (sum_non_exist_prob)
7821 up_tab[ii].meas_assoc_hist.append(None)
7822 all_cost_m = np.zeros((num_meas, num_pred + num_birth * num_meas))
7824 # Update for all existing tracks
7825 for emm, z in enumerate(meas):
7826 for ii, ent in enumerate(self._track_tab):
7827 s_to_ii = num_pred * emm + ii + num_pred
7828 (up_tab[s_to_ii], cost) = self._correct_track_tab_entry(
7829 z, ent, timestep, filt_args
7830 )
7831 if up_tab[s_to_ii] is not None:
7832 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
7833 all_cost_m[emm, ii] = cost
7835 # Update for all potential new births
7836 for emm, z in enumerate(meas):
7837 for ii, b_model in enumerate(self.birth_terms):
7838 s_to_ii = ((num_meas + 1) * num_pred) + emm * num_birth + ii
7839 (up_tab[s_to_ii], cost) = self._correct_birth_tab_entry(
7840 z, b_model, timestep, filt_args
7841 )
7842 if up_tab[s_to_ii] is not None:
7843 up_tab[s_to_ii].meas_assoc_hist.append(comb_inds[emm])
7844 up_tab[s_to_ii].label = (round(timestep, self.decimal_places), ii)
7845 all_cost_m[emm, emm + num_pred] = cost
7846 return up_tab, all_cost_m
7848 def _gen_cor_hyps(
7849 self,
7850 num_meas,
7851 avg_prob_detect,
7852 avg_prob_miss_detect,
7853 all_cost_m,
7854 meas_combs,
7855 cor_tab,
7856 ):
7857 num_pred = len(self._track_tab)
7858 up_hyps = []
7859 if not meas_combs:
7860 meas_combs = np.arange(0, np.shape(all_cost_m)[0]).tolist()
7861 # n_obj_for_tracks =
7862 if num_meas == 0:
7863 for hyp in self._hypotheses:
7864 pmd_log = np.sum(
7865 [np.log(avg_prob_miss_detect[ii]) for ii in hyp.track_set]
7866 )
7867 hyp.assoc_prob = -self.clutter_rate + pmd_log + np.log(hyp.assoc_prob)
7868 up_hyps.append(hyp)
7869 else:
7870 clutter = self.clutter_rate * self.clutter_den
7871 ss_w = 0
7872 for p_hyp in self._hypotheses:
7873 ss_w += np.sqrt(p_hyp.assoc_prob)
7874 for p_hyp in self._hypotheses:
7875 for ind_lst in meas_combs:
7876 if len(meas_combs) == 1:
7877 if p_hyp.num_tracks == 0: # all clutter
7878 inds = np.arange(num_pred, num_pred + num_meas).tolist()
7879 else:
7880 inds = (
7881 p_hyp.track_set
7882 + np.arange(num_pred, num_pred + num_meas).tolist()
7883 )
7884 cost_m = all_cost_m[:, inds]
7885 else:
7886 # TODO: change this process so it works when we can't just arange through to the end; num_pred may need to be added to the indices in ind_lst
7887 if p_hyp.num_tracks == 0: # all clutter
7888 # inds = np.arange(num_pred, num_pred + len(ind_lst)).tolist()
7889 inds = list(num_pred + np.array(ind_lst))
7890 else:
7891 inds = p_hyp.track_set + [x + num_pred for x in ind_lst]
7892 tcm = all_cost_m[
7893 :, inds
7894 ] # TODO: verify this indexing; it is a suspected source of errors
7895 cost_m = tcm[ind_lst, :]
7896 max_row_inds, max_col_inds = np.where(cost_m >= np.inf)
7897 if max_row_inds.size > 0:
7898 cost_m[max_row_inds, max_col_inds] = np.finfo(float).max
7899 min_row_inds, min_col_inds = np.where(cost_m <= 0.0)
7900 if min_row_inds.size > 0:
7901 cost_m[min_row_inds, min_col_inds] = np.finfo(float).eps
7902 neg_log = -np.log(cost_m)
7904 m = np.round(self.req_upd * np.sqrt(p_hyp.assoc_prob) / ss_w)
7905 m = int(m.item())
7907 [assigns, costs] = murty_m_best_all_meas_assigned(neg_log, m)
7909 pmd_log = np.sum(
7910 [np.log(avg_prob_miss_detect[ii]) for ii in p_hyp.track_set]
7911 )
7912 for a, c in zip(assigns, costs):
7913 new_hyp = self._HypothesisHelper()
7914 new_hyp.assoc_prob = (
7915 -self.clutter_rate
7916 + num_meas * np.log(clutter)
7917 + pmd_log
7918 + np.log(p_hyp.assoc_prob)
7919 - c
7920 )
7921 if p_hyp.num_tracks == 0:
7924 new_track_list = []
7925 for ii, ms in enumerate(a):
7926 new_track_list.append(
7927 (ind_lst[ii] * (num_pred + 1) + num_pred * num_meas)
7928 )
7929 else:
7931 new_track_list = []
7932 for ii, ms in enumerate(a):
7933 if len(p_hyp.track_set) >= ms:
7934 new_track_list.append(
7935 (ind_lst[ii] + 1) * num_pred
7936 + p_hyp.track_set[(ms - 1)]
7937 )
7938 elif len(p_hyp.track_set) == len(a):
7939 new_track_list.append(
7940 num_pred * ms - ii * (num_pred - 1)
7941 )
7942 elif len(p_hyp.track_set) < len(a):
7943 new_track_list.append(
7944 num_pred * (num_meas + 1) + (ind_lst[ii])
7945 )
7946 else:
7947 new_track_list.append(
7948 num_pred * (num_meas + 1)
7949 - (ms - len(p_hyp.track_set) - 1)
7950 )
7997 new_hyp.track_set = new_track_list
7998 up_hyps.append(new_hyp)
8000 lse = log_sum_exp([x.assoc_prob for x in up_hyps])
8001 for ii in range(0, len(up_hyps)):
8002 up_hyps[ii].assoc_prob = np.exp(up_hyps[ii].assoc_prob - lse)
8003 return up_hyps
8005 def correct(self, timestep, meas, filt_args={}):
8006 """Correction step of the MS-PMBM filter.
8008 Notes
8009 -----
8010 This corrects the hypotheses based on the measurements and updates the
8011 cardinality distribution. Gating is not yet implemented for this class
8012 and is skipped with a warning if enabled.
8014 Parameters
8015 ----------
8016 timestep: float
8017 Current timestep.
8018 meas : list
8019 List of lists, one per sensor, each containing Nm x 1 numpy arrays representing that sensor's measurements.
8020 filt_args : dict, optional
8021 Keyword arguments to pass to the inner filter's correct function.
8022 The default is {}.
8024 Returns
8025 -------
8026 None
8027 """
8037 all_combs = list(itertools.product(*meas))
8045 if self.gating_on:
8046 warnings.warn("Gating not implemented yet. SKIPPING", RuntimeWarning)
8047 # means = []
8048 # covs = []
8049 # for ent in self._track_tab:
8050 # means.extend(ent.probDensity.means)
8051 # covs.extend(ent.probDensity.covariances)
8052 # meas = self._gate_meas(meas, means, covs)
8053 if self.save_measurements:
8054 self._meas_tab.append(deepcopy(meas))
8056 # get matrix of indices in all_combs
8057 num_meas_per_sens = [len(x) for x in meas]
8058 num_meas = len(all_combs)
8059 num_sens = len(meas)
8060 # TODO: iterate this, going from the current mnmps and stopping at 0, to get
8061 # combinations of all possible measurements for all possible permutations
8063 mnmps = min(num_meas_per_sens)
8065 # convert the tuples from itertools.product into lists
8066 comb_inds = list(itertools.product(*list(np.arange(0, len(x)) for x in meas)))
8067 comb_inds = [list(ele) for ele in comb_inds]
8077 min_meas_in_sens = np.min([len(x) for x in meas])
8079 all_meas_combs = list(itertools.combinations(comb_inds, mnmps))
8080 all_meas_combs = [list(ele) for ele in all_meas_combs]
8082 poss_meas_combs = []
8087 for ii in range(0, len(all_meas_combs)):
8088 break_flag = False
8089 cur_comb = []
8090 for jj, lst1 in enumerate(all_meas_combs[ii]):
8091 for kk, lst2 in enumerate(all_meas_combs[ii]):
8092 if jj == kk:
8093 continue
8094 else:
8095 out = (np.array(lst1) == np.array(lst2)).tolist()
8096 if any(out):
8097 break_flag = True
8098 break
8099 if break_flag:
8100 break
8101 if not break_flag:
8104 for lst1 in all_meas_combs[ii]:
8105 for mm, lst2 in enumerate(comb_inds):  # mm avoids shadowing the outer loop index
8106 if lst1 == lst2:
8107 cur_comb.append(mm)
8108 poss_meas_combs.append(cur_comb)
8110 cor_tab, all_cost_m = self._gen_cor_tab(
8111 num_meas, all_combs, timestep, comb_inds, filt_args
8112 )
8114 # self._add_birth_hyps(num_meas)
8116 avg_prob_det, avg_prob_mdet = self._calc_avg_prob_det_mdet(cor_tab)
8118 cor_hyps = self._gen_cor_hyps(
8119 num_meas, avg_prob_det, avg_prob_mdet, all_cost_m, poss_meas_combs, cor_tab
8120 )
8122 self._track_tab = cor_tab
8123 self._hypotheses = cor_hyps
8124 self._card_dist = self._calc_card_dist(self._hypotheses)
8125 self._clean_updates()
8128class MSIMMPoissonMultiBernoulliMixture(_IMMPMBMBase, MSPoissonMultiBernoulliMixture):
"""An implementation of the MS-IMM-PMBM algorithm."""
8129 def __init__(self, **kwargs):
8130 super().__init__(**kwargs)
8133class MSIMMLabeledPoissonMultiBernoulliMixture(
8134 _IMMPMBMBase, MSLabeledPoissonMultiBernoulliMixture
8135):
"""An implementation of the MS-IMM-LPMBM algorithm."""
8136 def __init__(self, **kwargs):
8137 super().__init__(**kwargs)