Coverage for src/gncpy/distributions.py: 81%
252 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Standard distributions for use with the package classes."""
2import numpy as np
3import numpy.linalg as la
4import numpy.polynomial.hermite_e as herm_e
5from warnings import warn
6import itertools
7import matplotlib.pyplot as plt
9import gncpy.math as gmath
10import gncpy.plotting as pltUtil
class _QuadPointIter:
    """Iterator over the points of a :class:`.QuadraturePoints` object.

    Each step yields a tuple ``(point, weight)`` where ``point`` is the current
    row of ``points`` reshaped to a ``num_axes x 1`` column and ``weight`` is
    the matching entry of ``weights``.
    """

    def __init__(self, quadPoints):
        self._quadPoints = quadPoints
        self.__index = 0

    def __iter__(self):
        # Iterators must return themselves so they also work with iter(),
        # zip(), list(), etc. when handed the iterator directly.
        return self

    def __next__(self):
        try:
            point = self._quadPoints.points[self.__index, :]
            weight = self._quadPoints.weights[self.__index]
        except IndexError:
            # Running past the last row ends the iteration; drop the
            # IndexError context so it does not leak into tracebacks.
            raise StopIteration from None
        self.__index += 1
        return point.reshape((self._quadPoints.num_axes, 1)), weight
class QuadraturePoints:
    r"""Helper class that defines quadrature points.

    Notes
    -----
    This implements the Probabilist's version of the Gauss-Hermite quadrature
    points. This consists of the Hermite polynomial

    .. math::
        H_{e_n}(x) = (-1)^n \exp{\frac{x^2}{2}} \frac{\partial^n}{\partial x^n} \exp{-\frac{x^2}{2}}

    and its associated weights. For details see
    :cite:`Press1992_NumericalRecipesinCtheArtofScientificComputing`,
    :cite:`Golub1969_CalculationofGaussQuadratureRules` and for a multi-variate
    extension :cite:`Jackel2005_ANoteonMultivariateGaussHermiteQuadrature`.

    Attributes
    ----------
    points_per_axis : int
        Number of points to use per axis
    num_axes : int
        Number of axes in each point. This can be set manually, but will be
        updated when :meth:`.update_points` is called to match the supplied
        mean.
    weights : numpy array
        Weight of each quadrature point
    points : M x N numpy array
        Each row corresponds to a quadrature point and the total number of rows
        is the total number of points.
    """

    def __init__(self, points_per_axis=None, num_axes=None):
        self.points_per_axis = points_per_axis
        self.num_axes = num_axes
        self.points = np.array([[]])
        self.weights = np.array([])

    @property
    def num_points(self):
        """Read only expected number of points (``points_per_axis**num_axes``)."""
        return int(self.points_per_axis**self.num_axes)

    @property
    def mean(self):
        """Mean of the points, accounting for the weights."""
        return gmath.weighted_sum_vec(self.weights,
                                      self.points).reshape((self.num_axes, 1))

    @property
    def cov(self):
        """Covariance of the points, accounting for the weights."""
        x_bar = self.mean
        num_pts = self.points.shape[0]
        diff = (self.points - x_bar.ravel()).reshape(num_pts, x_bar.size, 1)
        return gmath.weighted_sum_mat(self.weights,
                                      diff @ diff.reshape(num_pts,
                                                          1, self.num_axes))

    def _factor_scale_matrix(self, scale, have_sqrt):
        """Return a square root factor of `scale`.

        If `have_sqrt` is True, `scale` is returned unchanged. Otherwise its
        lower triangular Cholesky factor is returned.
        """
        if have_sqrt:
            return scale
        else:
            return la.cholesky(scale)

    def update_points(self, mean, scale, have_sqrt=False):
        """Updates the quadrature points given some initial point and scale.

        Parameters
        ----------
        mean : N x 1 numpy array
            Point to represent by quadrature points. A 1-D array of length N
            (or a 1 x N row) is also accepted and treated as a column vector.
        scale : N x N numpy array
            Covariance or square root of the covariance matrix of the given point.
        have_sqrt : bool
            Optional flag indicating if the square root of the matrix was
            supplied. The default is False.

        Returns
        -------
        None.
        """
        def create_combos(points_per_ax, num_ax, tot_points):
            # Every combination of 1-D point indices, one row per N-D point.
            combos = np.meshgrid(*itertools.repeat(range(points_per_ax), num_ax))

            return np.array(combos).reshape((num_ax, tot_points)).T

        # Normalize to a column so 1-D inputs do not broadcast into N x N.
        mean = np.asarray(mean).reshape((-1, 1))
        self.num_axes = mean.size

        sqrt_cov = self._factor_scale_matrix(scale, have_sqrt)

        # get standard values for 1 axis case, use "Probabilist's" Hermite polynomials
        quad_points, weights = herm_e.hermegauss(self.points_per_axis)

        ind_combos = create_combos(self.points_per_axis, self.num_axes,
                                   self.num_points)

        # Row i is mean + sqrt_cov @ z_i for the standard point z_i.
        std_points = quad_points[ind_combos]
        self.points = mean.T + std_points @ sqrt_cov.T

        # Tensor-product weights, normalized to sum to one.
        raw_weights = np.prod(weights[ind_combos], axis=1)
        self.weights = raw_weights / np.sum(raw_weights)

        # note: Sum[w * z @ z.T] over the standard points equals the identity

    def __iter__(self):
        """Custom iterator for looping over the object.

        Returns
        -------
        N x 1 numpy array
            Current point in set.
        float
            Weight of the current point.
        """
        return _QuadPointIter(self)

    def plot_points(self, inds, x_lbl='X Position', y_lbl='Y Position',
                    ttl='Weighted Positions', size_factor=100**2, **kwargs):
        """Plots the weighted points.

        Keyword arguments are processed with
        :meth:`gncpy.plotting.init_plotting_opts`. This function
        implements

            - f_hndl
            - Any title/axis/text options

        Parameters
        ----------
        inds : list, tuple, or int
            Indices of the point vector to plot. Can be a list (or tuple) of
            at most 2 elements. If only 1 index is given (or an int), a bar
            chart is created.
        x_lbl : string, optional
            Label for the x-axis. The default is 'X Position'.
        y_lbl : string, optional
            Label for the y-axis. The default is 'Y Position'.
        ttl : string, optional
            Title of the plot. The default is 'Weighted Positions'.
        size_factor : int, optional
            Factor to multiply the weight by when determining the marker size.
            Only used if plotting 2 indices. The default is 100**2.
        **kwargs : dict
            Additional standard plotting options.

        Returns
        -------
        fig : matplotlib figure handle
            Handle to the figure used.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        fig = opts['f_hndl']

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)

        if isinstance(inds, (list, tuple)):
            if len(inds) >= 2:
                fig.axes[0].scatter(self.points[:, inds[0]],
                                    self.points[:, inds[1]],
                                    s=size_factor * self.weights, color='k')
                fig.axes[0].grid(True)

            elif len(inds) == 1:
                fig.axes[0].bar(self.points[:, inds[0]],
                                self.weights)
        else:
            fig.axes[0].bar(self.points[:, inds],
                            self.weights)

        pltUtil.set_title_label(fig, 0, opts, ttl=ttl, x_lbl=x_lbl,
                                y_lbl=y_lbl)

        fig.tight_layout()

        return fig
class SigmaPoints(QuadraturePoints):
    """Helper class that defines sigma points.

    Notes
    -----
    Sigma points can be interpreted as a special case of quadrature points. See
    :cite:`Sarkka2015_OntheRelationbetweenGaussianProcessQuadraturesandSigmaPointMethods`
    for details.

    The mean weights and covariance weights are stored back to back in the
    single ``weights`` array inherited from the parent class: the first
    ``num_points`` entries are the mean weights, the rest the covariance weights.

    Attributes
    ----------
    alpha : float, optional
        Tuning parameter, influences the spread of sigma points about the mean.
        In range (0, 1]. The default is 1.
    kappa : float, optional
        Tuning parameter, influences the spread of sigma points about the mean.
        In range [0, inf]. The default is 0.
    beta : float, optional
        Tuning parameter for distribution type. In range [0, Inf]. Defaults
        to 2 which is ideal for Gaussians.
    """

    def __init__(self, alpha=1, kappa=0, beta=2, **kwargs):
        super().__init__(**kwargs)

        self.alpha, self.kappa, self.beta = alpha, kappa, beta

    @property
    def lam(self):
        """Read only scaling parameter ``alpha**2 * (n + kappa) - n``."""
        n = self.num_axes
        return self.alpha**2 * (n + self.kappa) - n

    @property
    def num_points(self):
        """Read only number of sigma points, ``2 * num_axes + 1``."""
        return int(1 + 2 * self.num_axes)

    def _ensure_weight_storage(self):
        # (Re)allocate the shared NaN-filled buffer when its size is not
        # exactly twice the number of sigma points.
        needed = 2 * self.num_points
        if self.weights.size != needed:
            self.weights = np.nan * np.ones(needed)

    @property
    def weights_mean(self):
        """Weights for calculating the mean (view into the first half of ``weights``)."""
        return self.weights[:self.num_points]

    @weights_mean.setter
    def weights_mean(self, val):
        self._ensure_weight_storage()
        self.weights[:self.num_points] = val

    @property
    def weights_cov(self):
        """Weights for calculating the covariance (view into the second half of ``weights``)."""
        return self.weights[self.num_points:]

    @weights_cov.setter
    def weights_cov(self, val):
        self._ensure_weight_storage()
        self.weights[self.num_points:] = val

    @property
    def mean(self):
        """Mean of the points, computed with the mean weights."""
        n_cols = self.points.shape[1]
        weighted = gmath.weighted_sum_vec(self.weights_mean, self.points)
        return weighted.reshape((n_cols, 1))

    @property
    def cov(self):
        """Covariance of the points, computed with the covariance weights."""
        center = self.mean
        n_pts, dim = self.points.shape[0], center.size
        dev = (self.points - center.ravel()).reshape(n_pts, dim, 1)
        outer = dev @ dev.reshape(n_pts, 1, dim)
        return gmath.weighted_sum_mat(self.weights_cov, outer)

    def init_weights(self):
        """Initializes the weights based on other parameters.

        Call this to set up the weight vectors after setting `alpha`, `kappa`,
        `beta`, and `num_axes`.
        """
        lam = self.lam
        # Allocate (or reuse) the shared buffer, with both halves set to NaN.
        self.weights_mean = np.nan * np.ones(self.num_points)
        self.weights_cov = np.nan * np.ones(self.num_points)

        spread = self.num_axes + lam
        w_m, w_c = self.weights_mean, self.weights_cov  # views into self.weights
        w_m[0] = lam / spread
        w_c[0] = lam / spread + 1 - self.alpha**2 + self.beta

        w_m[1:] = w_c[1:] = 1 / (2 * spread)

    def update_points(self, x, scale, have_sqrt=False):
        """Updates the sigma points given some initial point and covariance.

        The weights are not recomputed here; see :meth:`init_weights`.

        Parameters
        ----------
        x : N x 1 numpy array
            Point to represent by sigma points.
        scale : N x N numpy array
            Covariance or square root of the covariance matrix of the given point.
        have_sqrt : bool
            Optional flag indicating if the square root of the matrix was
            supplied. The default is False.

        Returns
        -------
        None.
        """
        self.num_axes = x.size
        spread = self.num_axes + self.lam
        factor = np.sqrt(spread) if have_sqrt else spread
        sqrt_mat = self._factor_scale_matrix(factor * scale, have_sqrt)

        n = self.num_axes
        center = x.ravel()
        self.points = np.nan * np.ones((2 * n + 1, n))
        self.points[0, :] = x.flatten()
        # Rows 1..n shift the center by each column of the factor; rows
        # n+1..2n shift by the negated columns.
        self.points[1:n + 1, :] = center + sqrt_mat.T
        self.points[n + 1:, :] = center - sqrt_mat.T
class Particle:
    """Helper class for defining single particles in a particle distribution.

    Attributes
    ----------
    point : N x 1 numpy array
        The location of the particle. Defaults to a new empty ``1 x 0`` array.
    uncertainty : N x N numpy array, optional
        The uncertainty of the point, this does not always need to be specified.
        Check with the filter if this is needed. Defaults to a new empty
        ``1 x 0`` array.
    sigmaPoints : :class:`.distributions.SigmaPoints`, optional
        Sigma points used to represent the particle. This is not always needed,
        check with the filter to determine if this is necessary.
    """

    # Private sentinel marking "argument not supplied". An explicit None passed
    # by a caller is therefore still stored as-is, as it was before.
    _UNSET = object()

    def __init__(self, point=_UNSET, uncertainty=_UNSET, sigmaPoints=None):
        # Build fresh empty arrays per instance. Array defaults in the
        # signature would be a single object shared by every Particle.
        self.point = np.array([[]]) if point is self._UNSET else point
        self.uncertainty = (np.array([[]]) if uncertainty is self._UNSET
                            else uncertainty)

        self.sigmaPoints = sigmaPoints

    @property
    def mean(self):
        """Read only mean value of the particle.

        If no sigma points are used then it is the same as `point`. Otherwise
        it is the `mean` of the sigma points object.

        Returns
        -------
        N x 1 numpy array
            The mean value.
        """
        if self.sigmaPoints is not None:
            return self.sigmaPoints.mean
        return self.point

    @mean.setter
    def mean(self, x):
        # Assignment is ignored with a warning instead of raising.
        warn('Particle mean is read only')
class _ParticleDistIter:
    """Iterator over a :class:`.ParticleDistribution`.

    Each step yields a tuple ``(particle, weight)`` taken from the
    distribution's ``_particles`` list and ``weights`` at the same index.
    """

    def __init__(self, partDist):
        self._partDist = partDist
        self.__index = 0

    def __iter__(self):
        # Iterators must return themselves so they also work with iter(),
        # zip(), list(), etc. when handed the iterator directly.
        return self

    def __next__(self):
        try:
            result = (self._partDist._particles[self.__index],
                      self._partDist.weights[self.__index])
        except IndexError:
            # Past the end of either list: stop without chaining the IndexError.
            raise StopIteration from None
        self.__index += 1
        return result
class SimpleParticleDistribution:
    """Particle distribution stored as NumPy arrays with duplicate counts.

    Attributes
    ----------
    num_parts_per_ind : numpy array
        Number of copies of each row of `particles` in the distribution.
    particles : numpy array
        Each row is a particle and each column is an element of the state.
    weights : numpy array
        Particle weights. These are not used by the methods of this class.
    """

    def __init__(self, **kwargs):
        self.num_parts_per_ind = np.array([])
        self.particles = np.array([[]])
        self.weights = np.array([])

    @property
    def num_particles(self):
        """Total number of particles, counting every duplicate."""
        return int(np.sum(self.num_parts_per_ind))

    def plot_particles(self, ind, title='Approximate PDF from Particle Distribution',
                       x_lbl='State', y_lbl='Probability', **kwargs):
        """Plots the approximate PDF represented by the particle distribution.

        Each row of `particles` is repeated according to `num_parts_per_ind`.
        The chosen state element is then shown as a density-normalized
        histogram.

        Parameters
        ----------
        ind : int
            Index of the particle vector to plot.
        title : string, optional
            Title of the plot. The default is
            'Approximate PDF from Particle Distribution'.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Probability'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        fig : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        fig = opts['f_hndl']
        lgnd_loc = opts['lgnd_loc']

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)

        # Row index ii repeated int(num_dups) times. Non-positive counts
        # contribute nothing.
        all_inds = [ii for ii, num_dups in enumerate(self.num_parts_per_ind)
                    for _ in range(int(num_dups))]

        x = self.particles[all_inds, ind]
        h_opts = {"histtype": "stepfilled", "bins": 'auto', "density": True}
        fig.axes[0].hist(x, **h_opts)

        pltUtil.set_title_label(fig, 0, opts, ttl=title,
                                x_lbl=x_lbl, y_lbl=y_lbl)
        if lgnd_loc is not None:
            fig.legend(loc=lgnd_loc)
        fig.tight_layout()

        return fig
class ParticleDistribution:
    """Particle distribution object.

    Manages an arbitrary distribution of particles. Particles and their weights
    are kept in parallel lists that can only be changed through
    :meth:`add_particle`, :meth:`update_weights`, and :meth:`clear_particles`.
    The per-particle means and uncertainties are cached and rebuilt lazily
    after any of those calls.
    """

    def __init__(self, **kwargs):
        self._particles = []
        self._weights = []

        # Lazily rebuilt caches of particle means and uncertainties.
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True
        self.__means = []
        self.__uncertianties = []

        self.__index = 0

    def _mark_caches_stale(self):
        # Force both cached lists to be rebuilt on their next access.
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True

    @property
    def particles(self):
        """Particles in the distribution.

        Must be set by the :meth:`.distributions.ParticleDistribution.add_particle`
        method.

        Returns
        -------
        list
            The `mean` of each stored :class:`.distributions.Particle`; empty
            when there are no particles.
        """
        if not self.num_particles:
            return []
        if self.__need_mean_lst_update:
            self.__need_mean_lst_update = False
            self.__means = [part.mean for part in self._particles if part]
        return self.__means

    @particles.setter
    def particles(self, lst):
        raise RuntimeError('Use function to add particle.')

    @property
    def weights(self):
        """Weights of the particles.

        Must be set by the :meth:`.distributions.ParticleDistribution.add_particle`
        or :meth:`.distributions.ParticleDistribution.update_weights` methods.

        Returns
        -------
        list
            Each element is a float representing the weight of the particle.
        """
        return self._weights

    @weights.setter
    def weights(self, lst):
        raise RuntimeError('Use function to add weight.')

    @property
    def uncertainties(self):
        """Read only uncertainty of each particle.

        Returns
        -------
        list
            The `uncertainty` attribute of each stored particle; empty when
            there are no particles.
        """
        if not self.num_particles:
            return []
        if self.__need_uncert_lst_update:
            self.__need_uncert_lst_update = False
            self.__uncertianties = [part.uncertainty
                                    for part in self._particles if part]
        return self.__uncertianties

    def add_particle(self, p, w):
        """Adds a particle and weight to the distribution.

        Parameters
        ----------
        p : :class:`.distributions.Particle` or list
            Particle to add or list of particles.
        w : float or list
            Weight of the particle or list of weights.

        Returns
        -------
        None.
        """
        self._mark_caches_stale()

        new_parts = p if isinstance(p, list) else [p]
        self._particles.extend(new_parts)

        new_weights = w if isinstance(w, list) else [w]
        self._weights.extend(new_weights)

    def clear_particles(self):
        """Clears the particle and weight lists."""
        self._mark_caches_stale()
        self._particles, self._weights = [], []

    def update_weights(self, w_lst):
        """Replaces the weights with the given list.

        If the length of `w_lst` differs from the number of particles, a
        warning is issued and the weights are left unchanged.
        """
        self._mark_caches_stale()

        if len(w_lst) == self.num_particles:
            self._weights = w_lst
        else:
            warn('Different number of weights than particles')

    @property
    def mean(self):
        """Weighted mean of the particles.

        An empty ``1 x 0`` array is returned when there are no particles or
        any weight is infinite.
        """
        if self.num_particles == 0 or any(np.abs(self.weights) == np.inf):
            return np.array([[]])
        return gmath.weighted_sum_vec(self.weights, self.particles)

    @property
    def covariance(self):
        """Covariance of the particles.

        This is ``numpy.cov`` over the particle means without the weights,
        symmetrized. An empty ``1 x 0`` array is returned when there are no
        particles.
        """
        if not self.num_particles:
            return np.array([[]])
        means = self.particles
        dim = means[0].size
        raw = np.cov(np.hstack(means)).reshape((dim, dim))
        return 0.5 * (raw + raw.T)

    @property
    def num_particles(self):
        """Number of particles."""
        return len(self._particles)

    def __iter__(self):
        """Custom iterator for looping over the object.

        Returns
        -------
        :class:`.distributions.Particle`
            Current particle in distribution.
        float
            Weight of the current particle.
        """
        return _ParticleDistIter(self)