Coverage for src/gncpy/filters/bootstrap_filter.py: 86%

58 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1import numpy as np 

2import numpy.random as rnd 

3from copy import deepcopy 

4from warnings import warn 

5 

6import gncpy.distributions as gdistrib 

7import gncpy.errors as gerr 

8from gncpy.filters.bayes_filter import BayesFilter 

9 

10 

class BootstrapFilter(BayesFilter):
    """Stripped down version of the :class:`.ParticleFilter`.

    This is an alternative implementation of a basic Particle filter. This
    removes some of the quality of life features of the :class:`.ParticleFilter`
    class and can be more complicated to setup. But it may provide runtime improvements
    for simple cases. Most times it is advised to use the :class:`.ParticleFilter`
    instead of this class. Most other derived classes use the :class:`.ParticleFilter`
    class as a base.
    """

    def __init__(
        self,
        importance_dist_fnc=None,
        importance_weight_fnc=None,
        particleDistribution=None,
        rng=None,
        **kwargs
    ):
        """Initializes the object.

        Parameters
        ----------
        importance_dist_fnc : callable, optional
            Must have the signature `f(parts, rng)` where `parts` is an
            instance of :class:`.distributions.SimpleParticleDistribution`
            and `rng` is a numpy random generator. It must return a numpy array
            of new particles for a :class:`.distributions.SimpleParticleDistribution`.
            Any state transitions to a new timestep must happen within this
            function. The default is None.
        importance_weight_fnc : callable, optional
            Must have the signature `f(meas, parts)` where `meas` is an Nm x 1
            numpy array representing the measurement and `parts` is the
            numpy array of particles from a :class:`.distributions.SimpleParticleDistribution`
            object. It must return a numpy array of weights, one for each
            particle. The default is None.
        particleDistribution : :class:`.distributions.SimpleParticleDistribution`, optional
            Initial particle distribution to use. The default is None, in
            which case an empty distribution is created.
        rng : numpy random generator, optional
            Random number generator to use. If none supplied then
            ``numpy.random.default_rng()`` is used. The default is None.
        **kwargs : dict
            Additional arguments for the parent constructor.
        """
        super().__init__(**kwargs)
        self.importance_dist_fnc = importance_dist_fnc
        self.importance_weight_fnc = importance_weight_fnc
        if particleDistribution is None:
            self.particleDistribution = gdistrib.SimpleParticleDistribution()
        else:
            self.particleDistribution = particleDistribution
        if rng is None:
            self.rng = rnd.default_rng()
        else:
            self.rng = rng

    def save_filter_state(self):
        """Saves filter variables so they can be restored later.

        Returns
        -------
        filt_state : dict
            Parent filter state extended with this filter's variables. The
            particle distribution is deep-copied; the two callables and the
            random generator are stored by reference.
        """
        filt_state = super().save_filter_state()
        filt_state["importance_dist_fnc"] = self.importance_dist_fnc
        filt_state["importance_weight_fnc"] = self.importance_weight_fnc
        filt_state["particleDistribution"] = deepcopy(self.particleDistribution)
        filt_state["rng"] = self.rng

        return filt_state

    def load_filter_state(self, filt_state):
        """Initializes filter using saved filter state.

        Parameters
        ----------
        filt_state : dict
            Dictionary generated by :meth:`save_filter_state`. The saved
            particle distribution is copied, so the dictionary can be loaded
            again later without being affected by subsequent filtering.
        """
        super().load_filter_state(filt_state)

        self.importance_dist_fnc = filt_state["importance_dist_fnc"]
        self.importance_weight_fnc = filt_state["importance_weight_fnc"]
        # correct() modifies the distribution's weights in place and
        # reassigns its attributes; copying keeps the saved snapshot intact
        # (mirrors the deepcopy done in save_filter_state).
        self.particleDistribution = deepcopy(filt_state["particleDistribution"])
        self.rng = filt_state["rng"]

    def predict(self, timestep):
        """Prediction step of the filter.

        Calls the importance distribution function to generate new samples of
        particles.

        Parameters
        ----------
        timestep : float
            Current timestep.

        Returns
        -------
        N x 1 numpy array
            Unweighted mean of the new particles, where N is the number of
            columns (state dimension) of the particle array.
        """
        dist = self.particleDistribution
        dist.particles = self.importance_dist_fnc(dist, self.rng)

        shape = (dist.particles.shape[1], 1)
        return np.mean(dist.particles, axis=0).reshape(shape)

    def correct(self, timestep, meas):
        """Correction step of the filter.

        Multiplies the particle weights by the importance weights, normalizes
        them, and resamples (with replacement) ``num_particles`` indices. Only
        the unique selected particles are kept, and the number of times each
        was drawn is stored in ``num_parts_per_ind``.

        Parameters
        ----------
        timestep : float
            Current timestep.
        meas : Nm x 1 numpy array
            Current measurement.

        Raises
        ------
        :class:`gerr.ParticleDepletionError`
            If the particle weights sum to zero or a non-finite value, or if
            at most one unique particle is selected during resampling.

        Returns
        -------
        N x 1 numpy array
            mean estimate of the resampled particles.
        """
        dist = self.particleDistribution
        dist.weights *= self.importance_weight_fnc(meas, dist.particles)
        tot = np.sum(dist.weights)
        if not np.isfinite(tot):
            # NaN would otherwise slip past the <= 0 check and fail later
            # inside rng.choice with an unrelated error.
            raise gerr.ParticleDepletionError(
                "Importance weights sum to a non-finite value."
            )
        if tot <= 0:
            raise gerr.ParticleDepletionError("Importance weights sum to 0.")
        dist.weights /= tot

        # selection: draw num_parts indices with replacement according to the
        # normalized weights (an int population samples from arange(n)).
        num_parts = dist.num_particles
        keep_inds = self.rng.choice(dist.weights.size, p=dist.weights, size=num_parts)
        unique_inds, counts = np.unique(keep_inds, return_counts=True)
        dist.num_parts_per_ind = counts
        dist.particles = dist.particles[unique_inds, :]
        dist.weights = 1 / num_parts * np.ones(dist.particles.shape[0])

        if unique_inds.size <= 1:
            msg = "Only {:d} particles selected".format(unique_inds.size)
            raise gerr.ParticleDepletionError(msg)

        # Only the unique particles are stored, so each row must be weighted
        # by how many times it was drawn to get the mean of the full
        # resampled set.
        shape = (dist.particles.shape[1], 1)
        return np.average(dist.particles, axis=0, weights=counts).reshape(shape)

    def set_state_model(self, **kwargs):
        """Not used by the Bootstrap filter."""
        warn(
            "Not used by BootstrapFilter, directly handled by importance_dist_fnc.",
            RuntimeWarning,
        )

    def set_measurement_model(self, **kwargs):
        """Not used by the Bootstrap filter."""
        warn(
            "Not used by BootstrapFilter, directly handled by importance_weight_fnc.",
            RuntimeWarning,
        )

    def plot_particles(self, inds, **kwargs):
        """Wrapper for :class:`.distributions.SimpleParticleDistribution.plot_particles`."""
        return self.particleDistribution.plot_particles(inds, **kwargs)