Coverage for src/gncpy/planning/reinforcement_learning/wrappers.py: 0%
150 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Implements some useful wrappers for gym environments."""
2import gym
3import numpy as np
4import cv2
6from collections import deque
class ResizeImage(gym.ObservationWrapper):
    """Wrapper for resizing an image.

    This can also work with resizing an image in a dict space. Assumes the
    image is H x W x 3 (rows x columns x channels); the resulting space is
    ``n_rows x n_cols x 3`` with dtype uint8.
    """

    def __init__(self, env=None, n_rows=84, n_cols=84, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment, optional
            Environment to wrap. The default is None.
        n_rows : int, optional
            Final number of rows. The default is 84.
        n_cols : int, optional
            Final number of columns. The default is 84.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a ``gym.spaces.Dict``. The default
            is 'img'.

        Raises
        ------
        AssertionError
            If the observation space is a Dict that does not contain `key`.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')

        self._nrows = n_rows
        self._ncols = n_cols
        self._key = key

        resized_space = gym.spaces.Box(low=0., high=255.,
                                       shape=(n_rows, n_cols, 3),
                                       dtype=np.uint8)
        if isinstance(self.observation_space, gym.spaces.Dict):
            # Fail fast here (like StackFrames/MaxFrames) instead of raising a
            # KeyError on the first observation.
            assert key in self.observation_space.spaces, \
                f"Key {key} not in observation space"
            spaces = {k: (resized_space if k == key else v)
                      for k, v in self.observation_space.spaces.items()}
            self.observation_space = gym.spaces.Dict(spaces)
        else:
            self.observation_space = resized_space

    def observation(self, obs):
        """Generates the proper observation.

        For dict observations only ``obs[key]`` is resized, and the dict is
        modified in place and returned; otherwise the resized image is
        returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            obs[self._key] = self._resize(obs[self._key])
            return obs
        return self._resize(obs)

    def _resize(self, img):
        """Resize a single image to (n_rows, n_cols) and cast it to uint8."""
        # cv2.resize takes the target size as (width, height), i.e. (cols, rows).
        resized_screen = cv2.resize(img, (self._ncols, self._nrows),
                                    interpolation=cv2.INTER_AREA)
        return resized_screen.astype(np.uint8)
class GrayScaleObservation(gym.ObservationWrapper):
    r"""Convert the image observation from RGB to gray scale.

    Mostly the same as the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation.
    """

    def __init__(self, env, keep_dim=True, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        keep_dim : bool, optional
            Flag indicating if the channel dimension should be kept. The
            default is True.
        key : string, optional
            Key in the dictionary corresponding to the image. Only used when
            the observation space is a ``gym.spaces.Dict``. The default is
            'img'.

        Raises
        ------
        AssertionError
            If `key` is missing from a Dict observation space, or if the image
            space is not H x W x 3.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        self.keep_dim = keep_dim
        self._key = key

        if isinstance(self.observation_space, gym.spaces.Dict):
            assert key in self.observation_space.spaces, \
                f"Key {key} not in observation space"
            spaces = {k: (self._gray_space(v) if k == key else v)
                      for k, v in self.observation_space.spaces.items()}
            self.observation_space = gym.spaces.Dict(spaces)
        else:
            self.observation_space = self._gray_space(self.observation_space)

    def _gray_space(self, space):
        """Return the uint8 grayscale Box space for an H x W x 3 image space."""
        assert len(space.shape) == 3 and space.shape[-1] == 3, \
            "image space must be H x W x 3"
        obs_shape = space.shape[:2]
        if self.keep_dim:
            shape = (obs_shape[0], obs_shape[1], 1)
        else:
            shape = obs_shape
        return gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)

    def observation(self, obs):
        """Generates the proper observation.

        For dict observations only ``obs[key]`` is converted, and the dict is
        modified in place and returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            obs[self._key] = self._to_gray(obs[self._key])
            return obs
        return self._to_gray(obs)

    def _to_gray(self, img):
        """Convert one RGB image to gray, optionally keeping a trailing axis."""
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            gray = np.expand_dims(gray, -1)
        return gray
class BufferFames(gym.ObservationWrapper):
    """Buffer frames.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.

    Frames are stacked along a new leading axis with the newest frame last.
    Slots not yet filled since the last reset hold the value 255. Each
    returned stack is a copy, so it is not changed by later steps.
    """

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Number of frames kept in the buffer.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a ``gym.spaces.Dict``. The default
            is 'img'.

        Raises
        ------
        AssertionError
            If the observation space is a Dict that does not contain `key`.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')

        self._key = key
        self._num_stack = num_stack

        if isinstance(self.observation_space, gym.spaces.Dict):
            assert key in self.observation_space.spaces, \
                f"Key {key} not in observation space"
            spaces = {k: (self._stacked_space(v) if k == key else v)
                      for k, v in self.observation_space.spaces.items()}
            self.observation_space = gym.spaces.Dict(spaces)
        else:
            self.observation_space = self._stacked_space(self.observation_space)
        # Built from the final (stacked) space so it has shape
        # (num_stack, *frame_shape).
        self.buffer = self._blank_buffer()

    def _stacked_space(self, space):
        """Return a Box whose bounds are `space`'s repeated num_stack times."""
        low = np.repeat(space.low[np.newaxis, ...], self._num_stack, axis=0)
        high = np.repeat(space.high[np.newaxis, ...], self._num_stack, axis=0)
        return gym.spaces.Box(low=low, high=high, dtype=space.dtype)

    def _blank_buffer(self):
        """Return a 255-filled buffer shaped like the stacked image space."""
        if isinstance(self.observation_space, gym.spaces.Dict):
            space = self.observation_space.spaces[self._key]
        else:
            space = self.observation_space
        return 255 * np.ones_like(space.low)

    def reset(self, **kwargs):
        """Wraps the reset function.

        The buffer is refilled with 255 (for both dict and non-dict spaces)
        so frames from the previous episode do not leak into the new one.
        Keyword arguments are forwarded to the wrapped environment's reset.
        """
        self.buffer = self._blank_buffer()
        return self.observation(self.env.reset(**kwargs))

    def observation(self, obs):
        """Generates the proper observation.

        For dict observations only ``obs[key]`` is replaced by the stack, and
        the dict is modified in place and returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            obs[self._key] = self._push(obs[self._key])
            return obs
        return self._push(obs)

    def _push(self, frame):
        """Shift the buffer forward one slot, append `frame`, return a copy."""
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = frame
        # The buffer is updated in place on every step. Handing it out
        # directly would make previously returned observations change.
        return self.buffer.copy()
class StackFrames(gym.ObservationWrapper):
    """Stack frames by taking the average of the buffer.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.
    """

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Maximum number of recent frames that are averaged. Must be > 0.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a ``gym.spaces.Dict``. The default
            is 'img'.

        Raises
        ------
        AssertionError
            If `num_stack` is not positive, or the observation space is a Dict
            that does not contain `key`.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        # A zero-length deque would average an empty buffer (0 / 0).
        assert num_stack > 0, "num_stack must be > 0"

        self._key = key
        if isinstance(self.observation_space, gym.spaces.Dict):
            assert self._key in self.observation_space.spaces.keys(), f"Key {self._key} not in observation space"

        self._num_stack = num_stack
        self.buffer = deque([], maxlen=num_stack)

    def reset(self, **kwargs):
        """Wraps the reset function.

        Clears the frame buffer and forwards keyword arguments to the wrapped
        environment's reset.
        """
        self.buffer.clear()
        return self.observation(self.env.reset(**kwargs))

    def observation(self, obs):
        """Generates the proper observation.

        For dict observations only ``obs[key]`` is replaced by the average,
        and the dict is modified in place and returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            obs[self._key] = self._average(obs[self._key])
            return obs
        return self._average(obs)

    def _average(self, frame):
        """Add `frame` to the buffer and return the buffer's mean as uint8."""
        self.buffer.appendleft(frame)
        # np.sum accumulates small integer dtypes in the platform integer,
        # so summing uint8 frames does not overflow.
        return (np.sum(self.buffer, axis=0) / len(self.buffer)).astype(np.uint8)
class MaxFrames(gym.ObservationWrapper):
    """Buffer frames by taking the max value of each pixel in the buffer.

    Similar to the open ai gym implementation except this allows
    for the wrapper to be applied to a single element inside a Dict observation,
    and does not use a LazyFrame wrapper.
    """

    def __init__(self, env, num_stack, key='img'):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        num_stack : int
            Maximum number of recent frames the pixel-wise max is taken over.
            Must be > 0.
        key : string, optional
            Key in the dictionary space corresponding to the image. Only used
            when the observation space is a ``gym.spaces.Dict``. The default
            is 'img'.

        Raises
        ------
        AssertionError
            If `num_stack` is not positive, or the observation space is a Dict
            that does not contain `key`.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        # A zero-length deque would make np.max fail on an empty buffer.
        assert num_stack > 0, "num_stack must be > 0"

        self._key = key
        if isinstance(self.observation_space, gym.spaces.Dict):
            assert self._key in self.observation_space.spaces.keys(), f"Key {self._key} not in observation space"

        self._num_stack = num_stack
        self.buffer = deque([], maxlen=num_stack)

    def reset(self, **kwargs):
        """Wraps the reset function.

        Clears the frame buffer and forwards keyword arguments to the wrapped
        environment's reset.
        """
        self.buffer.clear()
        return self.observation(self.env.reset(**kwargs))

    def observation(self, obs):
        """Generates the proper observation.

        For dict observations only ``obs[key]`` is replaced by the max, and
        the dict is modified in place and returned.
        """
        if isinstance(obs, (dict, gym.spaces.Dict)):
            obs[self._key] = self._max(obs[self._key])
            return obs
        return self._max(obs)

    def _max(self, frame):
        """Add `frame` to the buffer and return the pixel-wise max as uint8."""
        self.buffer.appendleft(frame)
        return np.max(self.buffer, axis=0).astype(np.uint8)
class SkipFrames(gym.Wrapper):
    """Repeat each chosen action over several consecutive environment steps.

    Rewards collected during the repeated steps are summed. Repetition stops
    early as soon as the wrapped environment reports that it is done.
    """

    def __init__(self, env, frame_skip):
        """Initialize an object.

        Parameters
        ----------
        env : gym environment
            Environment to wrap.
        frame_skip : int
            How many times each action is forwarded to the wrapped
            environment. Must be > 0.
        """
        super().__init__(env)
        print('Wrapping the env in a', repr(type(self).__name__), 'wrapper.')
        assert frame_skip > 0, "frame_skip must be > 0"

        self.frame_skip = frame_skip

    def step(self, action):
        """Wraps the step function.

        Returns
        -------
        tuple
            ``(obs, total_reward, done, info)`` where `obs`, `done` and `info`
            come from the last underlying step and `total_reward` is the sum
            of the rewards of all steps taken.
        """
        total_reward = 0.0

        for _repeat in range(self.frame_skip):
            last_obs, step_reward, is_done, last_info = self.env.step(action)
            total_reward += step_reward
            if is_done:
                break

        return last_obs, total_reward, is_done, last_info