Coverage for src/gncpy/distributions.py: 81%

252 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Standard distributions for use with the package classes.""" 

2import numpy as np 

3import numpy.linalg as la 

4import numpy.polynomial.hermite_e as herm_e 

5from warnings import warn 

6import itertools 

7import matplotlib.pyplot as plt 

8 

9import gncpy.math as gmath 

10import gncpy.plotting as pltUtil 

11 

12 

class _QuadPointIter:
    """Iterator over the (point, weight) pairs of a quadrature point set.

    Each step yields row ``ii`` of ``points`` reshaped to a ``num_axes x 1``
    column vector, together with ``weights[ii]``. Iteration stops when either
    array runs out of entries.
    """

    def __init__(self, quadPoints):
        self._quadPoints = quadPoints
        self.__index = 0

    def __iter__(self):
        # Iterator protocol: an iterator must return itself so it can be used
        # anywhere an iterable is expected (list(), zip(), nested for-loops).
        return self

    def __next__(self):
        try:
            point = self._quadPoints.points[self.__index, :]
            weight = self._quadPoints.weights[self.__index]
        except IndexError:
            raise StopIteration from None
        result = (point.reshape((self._quadPoints.num_axes, 1)), weight)
        self.__index += 1
        return result

27 

28 

class QuadraturePoints:
    r"""Helper class that defines quadrature points.

    Notes
    -----
    This implements the Probabilist's version of the Gauss-Hermite quadrature
    points. This consists of the Hermite polynomial

    .. math::
        H_{e_n}(x) = (-1)^n \exp\left(\frac{x^2}{2}\right) \frac{d^n}{dx^n} \exp\left(-\frac{x^2}{2}\right)

    and its associated weights. For details see
    :cite:`Press1992_NumericalRecipesinCtheArtofScientificComputing`,
    :cite:`Golub1969_CalculationofGaussQuadratureRules` and for a multi-variate
    extension :cite:`Jackel2005_ANoteonMultivariateGaussHermiteQuadrature`.

    Attributes
    ----------
    points_per_axis : int
        Number of points to use per axis
    num_axes : int
        Number of axes in each point. This can be set manually, but will be
        updated when :meth:`.update_points` is called to match the supplied
        mean.
    weights : numpy array
        Weight of each quadrature point
    points : M x N numpy array
        Each row corresponds to a quadrature point and the total number of rows
        is the total number of points.
    """

    def __init__(self, points_per_axis=None, num_axes=None):
        self.points_per_axis = points_per_axis
        self.num_axes = num_axes
        self.points = np.array([[]])
        self.weights = np.array([])

    @property
    def num_points(self):
        """Read only expected number of points, ``points_per_axis**num_axes``."""
        return int(self.points_per_axis**self.num_axes)

    @property
    def mean(self):
        """Mean of the points, accounting for the weights.

        Returns
        -------
        num_axes x 1 numpy array
            Weighted mean of the rows of `points`.
        """
        return gmath.weighted_sum_vec(self.weights,
                                      self.points).reshape((self.num_axes, 1))

    @property
    def cov(self):
        """Covariance of the points, accounting for the weights.

        Returns
        -------
        N x N numpy array
            Weighted sum of the outer products of each point's deviation
            from :attr:`mean`.
        """
        x_bar = self.mean
        num_pts = self.points.shape[0]
        # per-point column deviations, shape (num_pts, N, 1)
        diff = (self.points - x_bar.ravel()).reshape(num_pts, x_bar.size, 1)
        # stack of outer products diff_i @ diff_i.T, shape (num_pts, N, N)
        outer = diff @ diff.reshape(num_pts, 1, x_bar.size)
        return gmath.weighted_sum_mat(self.weights, outer)

    def _factor_scale_matrix(self, scale, have_sqrt):
        """Return a square root of `scale`.

        If `have_sqrt` is True, `scale` is returned unchanged. Otherwise its
        lower triangular Cholesky factor is computed.
        """
        if have_sqrt:
            return scale
        return la.cholesky(scale)

    def update_points(self, mean, scale, have_sqrt=False):
        """Updates the quadrature points given some initial point and scale.

        Each point is ``mean + S @ p``, where ``S`` is the square root of
        `scale` and ``p`` is a combination of the 1-D probabilist's
        Gauss-Hermite nodes (one per axis). Its weight is the product of the
        corresponding 1-D weights, normalized so all weights sum to 1.

        Parameters
        ----------
        mean : N x 1 numpy array
            Point to represent by quadrature points. A flat length N array is
            also accepted.
        scale : N x N numpy array
            Covariance or square root of the covariance matrix of the given point.
        have_sqrt : bool
            Optional flag indicating if the square root of the matrix was
            supplied. The default is False.

        Returns
        -------
        None.
        """
        def create_combos(points_per_ax, num_ax, tot_points):
            # every combination of 1-D node indices, one row per point
            combos = np.meshgrid(*itertools.repeat(range(points_per_ax), num_ax))
            return np.array(combos).reshape((num_ax, tot_points)).T

        # Flatten so both column vectors and 1-D arrays broadcast correctly
        # against the (num_points, N) point matrix below.
        mean_vec = np.asarray(mean).ravel()
        self.num_axes = mean_vec.size

        sqrt_cov = self._factor_scale_matrix(scale, have_sqrt)

        # get standard values for 1 axis case, use "Probabilist's" Hermite polynomials
        quad_points, weights = herm_e.hermegauss(self.points_per_axis)

        ind_combos = create_combos(self.points_per_axis, self.num_axes,
                                   self.num_points)

        # Row form of x = mean + S @ p applied to every unit point at once.
        unit_points = quad_points[ind_combos]
        self.points = mean_vec + unit_points @ sqrt_cov.T

        # multivariate weight is the product of the per-axis 1-D weights
        prod_weights = np.prod(weights[ind_combos], axis=1)
        self.weights = prod_weights / np.sum(prod_weights)

        # note: Sum[w * p @ p.T] over the unit points equals the identity matrix

    def __iter__(self):
        """Custom iterator for looping over the object.

        Returns
        -------
        N x 1 numpy array
            Current point in set.
        float
            Weight of the current point.
        """
        return _QuadPointIter(self)

    def plot_points(self, inds, x_lbl='X Position', y_lbl='Y Position',
                    ttl='Weighted Positions', size_factor=100**2, **kwargs):
        """Plots the weighted points.

        Keyword arguments are processed with
        :meth:`gncpy.plotting.init_plotting_opts`. This function
        implements

            - f_hndl
            - Any title/axis/text options

        Parameters
        ----------
        inds : list, tuple, or int
            Indices of the point vector to plot. Can be a list (or tuple) of at
            most 2 elements. If only 1 is given a bar chart is created.
        x_lbl : string, optional
            Label for the x-axis. The default is 'X Position'.
        y_lbl : string, optional
            Label for the y-axis. The default is 'Y Position'.
        ttl : string, optional
            Title of the plot. The default is 'Weighted Positions'.
        size_factor : int, optional
            Factor to multiply the weight by when determining the marker size.
            Only used if plotting 2 indices. The default is 100**2.
        **kwargs : dict
            Additional standard plotting options.

        Returns
        -------
        fig : matplotlib figure handle
            Handle to the figure used.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        fig = opts['f_hndl']

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)

        if isinstance(inds, (list, tuple)):
            if len(inds) >= 2:
                fig.axes[0].scatter(self.points[:, inds[0]],
                                    self.points[:, inds[1]],
                                    s=size_factor * self.weights, color='k')
                fig.axes[0].grid(True)

            elif len(inds) == 1:
                fig.axes[0].bar(self.points[:, inds[0]],
                                self.weights)
        else:
            fig.axes[0].bar(self.points[:, inds],
                            self.weights)

        pltUtil.set_title_label(fig, 0, opts, ttl=ttl, x_lbl=x_lbl,
                                y_lbl=y_lbl)

        fig.tight_layout()

        return fig

207 

208 

class SigmaPoints(QuadraturePoints):
    """Helper class that defines sigma points.

    Notes
    -----
    This can be interpreted as a special case of the quadrature points. See
    :cite:`Sarkka2015_OntheRelationbetweenGaussianProcessQuadraturesandSigmaPointMethods`
    for details.

    The inherited ``weights`` array stores two weight sets back to back: the
    first ``num_points`` entries are the mean weights
    (:attr:`weights_mean`) and the remaining entries are the covariance
    weights (:attr:`weights_cov`).

    Attributes
    ----------
    alpha : float, optional
        Tuning parameter, influences the spread of sigma points about the mean.
        In range (0, 1]. The default is 1.
    kappa : float, optional
        Tuning parameter, influences the spread of sigma points about the mean.
        In range [0, inf]. The default is 0.
    beta : float, optional
        Tuning parameter for distribution type. In range [0, Inf]. Defaults
        to 2 which is ideal for Gaussians.
    """

    def __init__(self, alpha=1, kappa=0, beta=2, **kwargs):
        super().__init__(**kwargs)
        self.alpha, self.kappa, self.beta = alpha, kappa, beta

    @property
    def lam(self):
        """Read only derived parameter, ``alpha**2 * (num_axes + kappa) - num_axes``."""
        n_ax = self.num_axes
        return self.alpha**2 * (n_ax + self.kappa) - n_ax

    @property
    def num_points(self):
        """Read only expected number of points, ``2 * num_axes + 1``."""
        return int(1 + 2 * self.num_axes)

    def _ensure_weight_storage(self):
        """Reallocate ``weights`` as NaNs unless it holds exactly 2 * num_points entries."""
        needed = 2 * self.num_points
        if self.weights.size != needed:
            self.weights = np.full(needed, np.nan)

    @property
    def weights_mean(self):
        """Weights for calculating the mean (a view into ``weights``)."""
        return self.weights[:self.num_points]

    @weights_mean.setter
    def weights_mean(self, val):
        self._ensure_weight_storage()
        self.weights[:self.num_points] = val

    @property
    def weights_cov(self):
        """Weights for calculating the covariance (a view into ``weights``)."""
        return self.weights[self.num_points:]

    @weights_cov.setter
    def weights_cov(self, val):
        self._ensure_weight_storage()
        self.weights[self.num_points:] = val

    @property
    def mean(self):
        """Mean of the points, accounting for the mean weights."""
        weighted = gmath.weighted_sum_vec(self.weights_mean, self.points)
        return weighted.reshape((self.points.shape[1], 1))

    @property
    def cov(self):
        """Covariance of the points, accounting for the covariance weights."""
        x_bar = self.mean
        n_pts, n_dim = self.points.shape[0], x_bar.size
        deltas = (self.points - x_bar.ravel()).reshape(n_pts, n_dim, 1)
        outer = deltas @ deltas.reshape(n_pts, 1, n_dim)
        return gmath.weighted_sum_mat(self.weights_cov, outer)

    def init_weights(self):
        """Initializes the weights based on other parameters.

        Call this after `alpha`, `kappa`, `beta`, and `num_axes` have been
        set so that both the mean and covariance weight vectors are filled.
        """
        lam = self.lam

        # Each assignment goes through the property setter, which makes sure
        # the shared backing array has room for both weight sets.
        self.weights_mean = np.full(self.num_points, np.nan)
        self.weights_cov = np.full(self.num_points, np.nan)

        w_m = self.weights_mean
        w_c = self.weights_cov
        spread = self.num_axes + lam

        w_m[0] = lam / spread
        w_c[0] = lam / spread + 1 - self.alpha**2 + self.beta

        outer_w = 1 / (2 * spread)
        w_m[1:] = outer_w
        w_c[1:] = outer_w

    def update_points(self, x, scale, have_sqrt=False):
        """Updates the sigma points given some initial point and covariance.

        The first point is `x` itself. The next ``num_axes`` points are `x`
        plus each column of the scaled square root matrix, and the last
        ``num_axes`` points are `x` minus those columns.

        Parameters
        ----------
        x : N x 1 numpy array
            Point to represent by sigma points.
        scale : N x N numpy array
            Covariance or square root of the covariance matrix of the given point.
        have_sqrt : bool
            Optional flag indicating if the square root of the matrix was
            supplied. The default is False.

        Returns
        -------
        None.
        """
        self.num_axes = x.size
        spread = self.num_axes + self.lam
        factor = np.sqrt(spread) if have_sqrt else spread
        root = self._factor_scale_matrix(factor * scale, have_sqrt)

        n_ax = self.num_axes
        center = x.ravel()
        pts = self.points = np.full((2 * n_ax + 1, n_ax), np.nan)
        pts[0, :] = x.flatten()
        pts[1:n_ax + 1, :] = center + root.T
        pts[n_ax + 1:, :] = center - root.T

330 

331 

class Particle:
    """Helper class for defining single particles in a particle distribution.

    Attributes
    ----------
    point : N x 1 numpy array
        The location of the particle. Defaults to a new empty (1 x 0) array.
    uncertainty : N x N numpy array, optional
        The uncertainty of the point, this does not always need to be specified.
        Check with the filter if this is needed. Defaults to a new empty
        (1 x 0) array.
    sigmaPoints : :class:`.distributions.SigmaPoints`, optional
        Sigma points used to represent the particle. This is not always needed,
        check with the filter to determine if this is necessary.
    """

    def __init__(self, point=None, uncertainty=None, sigmaPoints=None):
        # Build the empty defaults per instance. Array defaults written in the
        # signature are created once and would be shared by every Particle.
        self.point = np.array([[]]) if point is None else point
        self.uncertainty = (np.array([[]]) if uncertainty is None
                            else uncertainty)

        self.sigmaPoints = sigmaPoints

    @property
    def mean(self):
        """Read only mean value of the particle.

        If no sigma points are used then it is the same as `point`. Otherwise
        it is the mean of the sigma points.

        Returns
        -------
        N x 1 numpy array
            The mean value.
        """
        if self.sigmaPoints is not None:
            return self.sigmaPoints.mean
        return self.point

    @mean.setter
    def mean(self, x):
        # Assignment is ignored; only a warning is issued.
        warn('Particle mean is read only')

375 

376 

class _ParticleDistIter:
    """Iterator over the (particle, weight) pairs of a particle distribution.

    Yields ``(partDist._particles[ii], partDist.weights[ii])``, i.e. the
    stored particle objects (not their means) with their weights. Iteration
    stops when either list runs out of entries.
    """

    def __init__(self, partDist):
        self._partDist = partDist
        self.__index = 0

    def __iter__(self):
        # Iterator protocol: an iterator must return itself so it can be used
        # anywhere an iterable is expected (list(), zip(), nested for-loops).
        return self

    def __next__(self):
        try:
            result = (self._partDist._particles[self.__index],
                      self._partDist.weights[self.__index])
        except IndexError:
            raise StopIteration from None
        self.__index += 1
        return result

390 

391 

class SimpleParticleDistribution:
    """Particle distribution stored as particle rows plus duplicate counts.

    Attributes
    ----------
    num_parts_per_ind : numpy array
        Number of times each row of `particles` is counted in the
        distribution.
    particles : M x N numpy array
        Each row is a particle; row ``ii`` is counted
        ``num_parts_per_ind[ii]`` times.
    weights : numpy array
        Particle weights. Not used by the methods of this class.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted for interface compatibility but currently unused
        self.num_parts_per_ind = np.array([])
        self.particles = np.array([[]])
        self.weights = np.array([])

    @property
    def num_particles(self):
        """Read only total number of particles, the sum of `num_parts_per_ind`."""
        return int(np.sum(self.num_parts_per_ind))

    def _expanded_indices(self):
        """Row index of every particle, with row ``ii`` repeated ``num_parts_per_ind[ii]`` times."""
        counts = np.asarray(self.num_parts_per_ind).ravel().astype(int)
        # Non-positive counts contribute no copies.
        return np.repeat(np.arange(counts.size), np.maximum(counts, 0))

    def plot_particles(self, ind, title='Approximate PDF from Particle Distribution',
                       x_lbl='State', y_lbl='Probability', **kwargs):
        """Plots the approximate PDF represented by the particle distribution.

        A density-normalized histogram of column `ind` is drawn. Each particle
        row contributes ``num_parts_per_ind`` samples.

        Parameters
        ----------
        ind : int
            Index of the particle vector to plot.
        title : string, optional
            Title of the plot. The default is
            'Approximate PDF from Particle Distribution'.
        x_lbl : string, optional
            X-axis label. The default is 'State'.
        y_lbl : string, optional
            Y-axis label. The default is 'Probability'.
        **kwargs : dict
            Additional plotting options for :meth:`gncpy.plotting.init_plotting_opts`
            function. Values implemented here are `f_hndl`, `lgnd_loc`, and
            any values relating to title/axis text formatting.

        Returns
        -------
        fig : matplotlib figure
            Figure object the data was plotted on.
        """
        opts = pltUtil.init_plotting_opts(**kwargs)
        fig = opts['f_hndl']
        lgnd_loc = opts['lgnd_loc']

        if fig is None:
            fig = plt.figure()
            fig.add_subplot(1, 1, 1)

        x = self.particles[self._expanded_indices(), ind]
        h_opts = {"histtype": "stepfilled", "bins": 'auto', "density": True}
        fig.axes[0].hist(x, **h_opts)

        pltUtil.set_title_label(fig, 0, opts, ttl=title,
                                x_lbl=x_lbl, y_lbl=y_lbl)
        if lgnd_loc is not None:
            fig.legend(loc=lgnd_loc)
        fig.tight_layout()

        return fig

453 

454 

class ParticleDistribution:
    """Particle distribution object.

    Helper class for managing arbitrary distributions of particles. Particles
    and their weights are kept in parallel lists. The lists of particle means
    and uncertainties returned by :attr:`particles` and :attr:`uncertainties`
    are cached and rebuilt on the next access after any modifying call.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted for interface compatibility but currently unused
        self._particles = []
        self._weights = []

        # lazily rebuilt caches backing the particles/uncertainties properties
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True
        self.__means = []
        self.__uncertainties = []

    @property
    def particles(self):
        """Means of the particles in the distribution.

        Particles must be added with the
        :meth:`.distributions.ParticleDistribution.add_particle` method.

        Returns
        -------
        list
            Each element is the ``mean`` of one stored particle (see
            :attr:`.distributions.Particle.mean`). Falsy entries (e.g.
            ``None``) in the particle list are skipped.
        """
        if self.num_particles == 0:
            return []
        if self.__need_mean_lst_update:
            # rebuild before clearing the flag so a failure is retried next time
            self.__means = [x.mean for x in self._particles if x]
            self.__need_mean_lst_update = False
        return self.__means

    @particles.setter
    def particles(self, lst):
        raise RuntimeError('Use function to add particle.')

    @property
    def weights(self):
        """Weights of the particles.

        Must be set by the :meth:`.distributions.ParticleDistribution.add_particle`
        or :meth:`.distributions.ParticleDistribution.update_weights` methods.

        Returns
        -------
        list
            Each element is a float representing the weight of the particle.
        """
        return self._weights

    @weights.setter
    def weights(self, lst):
        raise RuntimeError('Use function to add weight.')

    @property
    def uncertainties(self):
        """Read only uncertainty of each particle.

        Returns
        -------
        list
            Each element is the ``uncertainty`` attribute of one stored
            particle (an N x N numpy array for :class:`.distributions.Particle`).
            Falsy entries in the particle list are skipped.
        """
        if self.num_particles == 0:
            return []
        if self.__need_uncert_lst_update:
            # rebuild before clearing the flag so a failure is retried next time
            self.__uncertainties = [x.uncertainty for x in self._particles
                                    if x]
            self.__need_uncert_lst_update = False
        return self.__uncertainties

    def add_particle(self, p, w):
        """Adds a particle and weight to the distribution.

        Parameters
        ----------
        p : :class:`.distributions.Particle` or list
            Particle to add or list of particles.
        w : float or list
            Weight of the particle or list of weights.

        Returns
        -------
        None.
        """
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True
        if isinstance(p, list):
            self._particles.extend(p)
        else:
            self._particles.append(p)

        if isinstance(w, list):
            self._weights.extend(w)
        else:
            self._weights.append(w)

    def clear_particles(self):
        """Clears the particle and weight lists."""
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True
        self._particles = []
        self._weights = []

    def update_weights(self, w_lst):
        """Updates the weights to match the given sequence.

        Checks that the length of the weights matches the number of particles.
        On a mismatch a warning is issued and the weights are left unchanged.
        The weights are copied into a new list so that later calls to
        :meth:`add_particle` can append to it, and so the caller's object is
        never modified.

        Parameters
        ----------
        w_lst : sequence of float
            New weight of each particle (e.g. a list or a 1-D numpy array).
        """
        # Invalidate the caches: callers may have modified particles in place
        # before updating the weights.
        self.__need_mean_lst_update = True
        self.__need_uncert_lst_update = True

        if len(w_lst) != self.num_particles:
            warn('Different number of weights than particles')
        else:
            self._weights = list(w_lst)

    @property
    def mean(self):
        """Weighted mean of the particles.

        Returns an empty ``1 x 0`` array if there are no particles or any
        weight is infinite.
        """
        if self.num_particles == 0 or np.any(np.isinf(self.weights)):
            return np.array([[]])
        return gmath.weighted_sum_vec(self.weights, self.particles)

    @property
    def covariance(self):
        """Covariance of the particles.

        Returns an empty ``1 x 0`` array if there are no particles.
        """
        if self.num_particles == 0:
            return np.array([[]])
        x_dim = self.particles[0].size
        # NOTE(review): the particle weights are not used here (unweighted
        # sample covariance), unlike ``mean``; confirm this is intended.
        cov = np.cov(np.hstack(self.particles)).reshape((x_dim, x_dim))
        # symmetrize to remove round-off asymmetry
        return (cov + cov.T) * 0.5

    @property
    def num_particles(self):
        """Number of particles."""
        return len(self._particles)

    def __iter__(self):
        """Custom iterator for looping over the object.

        Returns
        -------
        :class:`.distributions.Particle`
            Current particle in distribution.
        float
            Weight of the current particle.
        """
        return _ParticleDistIter(self)