Coverage for src/gncpy/sampling.py: 0%

22 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-19 05:48 +0000

1"""Implements standard sampling algorithms in a standardized fashion. 

2 

3.. todo:: 

4 Determine a standardized implementation for algorithm interface 

5""" 

6import numpy as np 

7import numpy.random as rnd 

8 

9 

10# TODO: confirm/fix this implementation 

class MetropolisHastings:
    """Implements a Metropolis Hastings algorithm.

    .. todo::
        confirm/fix the implementation

    Attributes
    ----------
    proposal_sampling_fnc : callable
        Draws a candidate sample. Called as
        ``proposal_sampling_fnc(**kwargs)`` with the keyword arguments given
        to :meth:`sample`. Note it is not handed the current state.
    proposal_fnc : callable
        Evaluates the proposal density. Called as
        ``proposal_fnc(a, b, **kwargs)``. Presumably this is the density of
        ``a`` given ``b`` -- TODO confirm the argument convention.
    joint_density_fnc : callable
        Evaluates the target density. Called as
        ``joint_density_fnc(a, **kwargs)``.
    max_iters : int
        Number of propose/accept iterations run by each call to
        :meth:`sample`. The default is 1.
    """

    def __init__(self, **kwargs):
        """Initialize the sampler from keyword arguments.

        Parameters
        ----------
        **kwargs : dict
            Recognized keys are ``proposal_sampling_fnc``, ``proposal_fnc``,
            ``joint_density_fnc`` (each defaults to None) and ``max_iters``
            (defaults to 1). Other keys are ignored.
        """
        self.proposal_sampling_fnc = kwargs.get('proposal_sampling_fnc', None)
        self.proposal_fnc = kwargs.get('proposal_fnc', None)
        self.joint_density_fnc = kwargs.get('joint_density_fnc', None)
        self.max_iters = kwargs.get('max_iters', 1)

    def sample(self, x, **kwargs):
        """Run the propose/accept loop starting from ``x``.

        Parameters
        ----------
        x : object with a ``copy`` method
            Starting state. It is copied, so the caller's object is not
            rebound or modified by this method itself.
        **kwargs : dict
            Forwarded unchanged to ``proposal_sampling_fnc``,
            ``proposal_fnc`` and ``joint_density_fnc``. The optional key
            ``rng`` supplies the generator used for the acceptance draw
            (anything with a ``random()`` method); when it is missing or
            None a new ``numpy.random.default_rng()`` is created for this
            call only and is not forwarded to the callables.

        Returns
        -------
        out
            The last accepted candidate, or the copy of ``x`` when nothing
            was accepted.
        accepted : bool
            True if at least one candidate was accepted.
        """
        # Build the fallback generator lazily so that a caller-supplied rng
        # does not pay for creating (and seeding) a throw-away generator.
        rng = kwargs.get('rng', None)
        if rng is None:
            rng = rnd.default_rng()

        accepted = False
        out = x.copy()
        for _ in range(self.max_iters):
            # draw candidate sample
            cand = self.proposal_sampling_fnc(**kwargs)

            # determine acceptance probability
            prob_last = self.proposal_fnc(out, cand, **kwargs) \
                * self.joint_density_fnc(cand, **kwargs)

            prob_cand = self.proposal_fnc(cand, out, **kwargs) \
                * self.joint_density_fnc(out, **kwargs)

            if prob_cand == 0:
                # The ratio is undefined here: plain floats would raise
                # ZeroDivisionError and numpy floats would warn and give
                # inf/nan. Leave a zero-density state whenever the candidate
                # has positive density, otherwise stay put.
                accept_prob = 1.0 if prob_last > 0 else 0.0
            else:
                # np.min keeps a nan ratio as nan, which is then rejected
                # by the comparison below.
                accept_prob = np.min((1, prob_last / prob_cand))

            # check fit
            u = rng.random()
            if u < accept_prob:
                out = cand
                accepted = True

        return out, accepted