Coverage for src/gncpy/planning/reinforcement_learning/wrappers.py: 0%

150 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Implements some useful wrappers for gym environments.""" 

2import gym 

3import numpy as np 

4import cv2 

5 

6from collections import deque 

7 

8 

class ResizeImage(gym.ObservationWrapper):
    """Wrapper for resizing an image observation.

    The observation may either be the image itself or a Dict observation in
    which only the entry under ``key`` is resized. The image is assumed to be
    H x W x 3 (rows x columns x 3 channels); the resulting space is
    ``n_rows x n_cols x 3`` with dtype ``np.uint8``.
    """

    def __init__(self, env=None, n_rows=84, n_cols=84, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment, optional
            Environment to wrap. The default is None.
        n_rows : int, optional
            Number of rows (height) of the resized image. The default is 84.
        n_cols : int, optional
            Number of columns (width) of the resized image. The default is 84.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a Dict. The default is 'img'.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')

        self._nrows = n_rows
        self._ncols = n_cols
        self._key = key

        def make_resized_space():
            # A fresh Box is built for every place the resized image appears.
            return gym.spaces.Box(low=0., high=255.,
                                  shape=(n_rows, n_cols, 3),
                                  dtype=np.uint8)

        current = self.observation_space
        if not isinstance(current, gym.spaces.Dict):
            self.observation_space = make_resized_space()
            return

        # Only the sub-space under ``key`` changes; all others pass through.
        self.observation_space = gym.spaces.Dict({
            name: make_resized_space() if name == key else sub_space
            for name, sub_space in current.spaces.items()
        })

    def observation(self, obs):
        """Return the observation with its image resized.

        For dict observations the entry under the configured key is replaced
        in place and the same dict is returned; otherwise the resized image
        is returned.
        """
        nested = isinstance(obs, (dict, gym.spaces.Dict))
        image = obs[self._key] if nested else obs
        # cv2.resize takes the target size as (width, height), i.e. (cols, rows).
        resized = cv2.resize(image, (self._ncols, self._nrows),
                             interpolation=cv2.INTER_AREA).astype(np.uint8)
        if nested:
            obs[self._key] = resized
            return obs
        return resized

65 

66 

class GrayScaleObservation(gym.ObservationWrapper):
    r"""Convert the image observation from RGB to gray scale.

    Mostly the same as the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation.
    The source image must be H x W x 3; the result is H x W x 1 when
    ``keep_dim`` is True and H x W otherwise, with dtype ``np.uint8``.
    """

    def __init__(self, env, keep_dim=True, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        keep_dim : bool, optional
            Flag indicating if the channel dimension should be kept. The
            default is True.
        key : string, optional
            Key in the dictionary corresponding to the image. Only used when
            the observation space is a Dict. The default is 'img'.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        self.keep_dim = keep_dim
        self._key = key

        def gray_space(rgb_shape):
            # Drop the 3 colour channels; optionally keep a singleton channel.
            rows, cols = rgb_shape[:2]
            target = (rows, cols, 1) if self.keep_dim else (rows, cols)
            return gym.spaces.Box(low=0, high=255, shape=target, dtype=np.uint8)

        if not isinstance(self.observation_space, gym.spaces.Dict):
            assert (
                len(env.observation_space.shape) == 3
                and env.observation_space.shape[-1] == 3
            )
            self.observation_space = gray_space(self.observation_space.shape)
            return

        converted = {}
        for name, sub_space in self.observation_space.spaces.items():
            if name != key:
                converted[name] = sub_space
                continue
            assert (len(sub_space.shape) == 3 and sub_space.shape[-1] == 3)
            converted[name] = gray_space(sub_space.shape)
        self.observation_space = gym.spaces.Dict(converted)

    def _to_gray(self, image):
        """Convert one RGB image to gray scale, honouring ``keep_dim``."""
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return np.expand_dims(gray, -1) if self.keep_dim else gray

    def observation(self, obs):
        """Return the observation with its image converted to gray scale.

        Dict observations are updated in place under the configured key and
        the same dict is returned.
        """
        if not isinstance(obs, (dict, gym.spaces.Dict)):
            return self._to_gray(obs)
        obs[self._key] = self._to_gray(obs[self._key])
        return obs

142 

143 

class BufferFames(gym.ObservationWrapper):
    """Buffer frames.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.

    The stacked observation has shape ``(num_stack, *frame_shape)`` with the
    newest frame last. Slots that have not received a frame since the last
    reset hold the value 255. Each returned stack is an independent copy, so
    it is not altered when later frames are pushed.
    """

    # Value stored in buffer slots that have not received a frame yet.
    _FILL_VALUE = 255

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Number of frames kept in the buffer.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a Dict. The default is 'img'.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')

        self._key = key
        self._num_stack = num_stack

        if isinstance(self.observation_space, gym.spaces.Dict):
            assert self._key in self.observation_space.spaces.keys(), f"Key {self._key} not in observation space"
            spaces = {}
            for k, v in self.observation_space.spaces.items():
                spaces[k] = self._stacked_space(v) if k == key else v
            self.observation_space = gym.spaces.Dict(spaces)
        else:
            self.observation_space = self._stacked_space(self.observation_space)

        # Build the buffer from the stacked space so its shape always matches
        # (num_stack, *frame_shape).
        self._reset_buffer()

    def _stacked_space(self, space):
        """Return a Box holding ``num_stack`` copies of ``space`` on a new leading axis."""
        low = np.repeat(space.low[np.newaxis, ...], self._num_stack, axis=0)
        high = np.repeat(space.high[np.newaxis, ...], self._num_stack, axis=0)
        return gym.spaces.Box(low=low, high=high, dtype=space.dtype)

    def _reset_buffer(self):
        """Refill the buffer with the fill value, shaped like the stacked space."""
        if isinstance(self.observation_space, gym.spaces.Dict):
            low = self.observation_space[self._key].low
        else:
            low = self.observation_space.low
        self.buffer = self._FILL_VALUE * np.ones_like(low)

    def reset(self, **kwargs):
        """Wraps the reset function.

        The buffer is refilled with the fill value before the first frame is
        pushed. Keyword arguments (e.g. ``seed``) are forwarded to the wrapped
        environment's ``reset``.
        """
        self._reset_buffer()
        return self.observation(self.env.reset(**kwargs))

    def observation(self, obs):
        """Push ``obs`` into the buffer and return a copy of the stacked frames.

        For dict observations the entry under the configured key is replaced
        in place and the same dict is returned.
        """
        is_dict = isinstance(obs, (dict, gym.spaces.Dict))
        frame = obs[self._key] if is_dict else obs

        # Shift every frame one slot toward the front; the newest goes last.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = frame

        # Copy so previously returned observations are not mutated by the
        # in-place shift on the next step.
        stacked = self.buffer.copy()
        if is_dict:
            obs[self._key] = stacked
            return obs
        return stacked

206 

207 

class StackFrames(gym.ObservationWrapper):
    """Stack frames by taking the average of the buffer.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.

    Each observation is the element-wise mean of the (up to) ``num_stack``
    most recent frames since the last reset, cast to ``np.uint8``. The
    observation space itself is not modified.
    """

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Maximum number of recent frames to average. Must be > 0.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a Dict. The default is 'img'.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        # With an empty buffer the average would divide by zero.
        assert num_stack > 0, "num_stack must be > 0"

        self._key = key
        if isinstance(self.observation_space, gym.spaces.Dict):
            assert self._key in self.observation_space.spaces.keys(), f"Key {self._key} not in observation space"

        self._num_stack = num_stack
        # Newest frame sits at index 0 (appendleft); maxlen discards the oldest.
        self.buffer = deque([], maxlen=num_stack)

    def reset(self, **kwargs):
        """Wraps the reset function.

        Clears the frame buffer. Keyword arguments (e.g. ``seed``) are
        forwarded to the wrapped environment's ``reset``.
        """
        self.buffer.clear()
        return self.observation(self.env.reset(**kwargs))

    def _average(self):
        """Return the element-wise mean of the buffered frames as uint8."""
        # np.sum promotes small integer dtypes (e.g. uint8) to the platform
        # integer, so the sum does not wrap before the division.
        return (np.sum(self.buffer, axis=0) / len(self.buffer)).astype(np.uint8)

    def observation(self, obs):
        """Push the newest frame and return the average of the buffer.

        For dict observations the entry under the configured key is replaced
        in place and the same dict is returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            self.buffer.appendleft(obs[self._key])
            obs[self._key] = self._average()
            return obs
        self.buffer.appendleft(obs)
        return self._average()

242 

243 

class MaxFrames(gym.ObservationWrapper):
    """Buffer frames by taking the max value of each pixel in the buffer.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.

    Each observation is the element-wise maximum of the (up to) ``num_stack``
    most recent frames since the last reset, cast to ``np.uint8``. The
    observation space itself is not modified.
    """

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Maximum number of recent frames to take the maximum over. Must
            be > 0.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a Dict. The default is 'img'.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        # With an empty buffer np.max would have nothing to reduce over.
        assert num_stack > 0, "num_stack must be > 0"

        self._key = key
        if isinstance(self.observation_space, gym.spaces.Dict):
            assert self._key in self.observation_space.spaces.keys(), f"Key {self._key} not in observation space"

        self._num_stack = num_stack
        # Newest frame sits at index 0 (appendleft); maxlen discards the oldest.
        self.buffer = deque([], maxlen=num_stack)

    def reset(self, **kwargs):
        """Wraps the reset function.

        Clears the frame buffer. Keyword arguments (e.g. ``seed``) are
        forwarded to the wrapped environment's ``reset``.
        """
        self.buffer.clear()
        return self.observation(self.env.reset(**kwargs))

    def _pixel_max(self):
        """Return the element-wise maximum of the buffered frames as uint8."""
        return np.max(self.buffer, axis=0).astype(np.uint8)

    def observation(self, obs):
        """Push the newest frame and return the pixel-wise max of the buffer.

        For dict observations the entry under the configured key is replaced
        in place and the same dict is returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            self.buffer.appendleft(obs[self._key])
            obs[self._key] = self._pixel_max()
            return obs
        self.buffer.appendleft(obs)
        return self._pixel_max()

278 

279 

class SkipFrames(gym.Wrapper):
    """Repeat each action for several environment steps, summing the reward.

    Stepping stops early as soon as the wrapped environment reports done.
    """

    def __init__(self, env, frame_skip):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        frame_skip : int
            Number of times each action is applied to the wrapped environment
            per call to ``step``. Must be > 0.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        assert frame_skip > 0, "frame_skip must be > 0"

        self.frame_skip = frame_skip

    def step(self, action):
        """Apply ``action`` up to ``frame_skip`` times.

        Returns the observation, done flag, and info from the last executed
        step, together with the sum of the rewards of all executed steps.
        """
        accumulated = 0.0
        repeats = range(self.frame_skip)

        for _ in repeats:
            transition = self.env.step(action)
            obs, reward, done, info = transition
            accumulated += reward
            if done:
                break

        return obs, accumulated, done, info