Coverage for src/gncpy/planning/reinforcement_learning/envs/simple2d/simpleUAV2d.py: 0%
123 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Implements RL environments for the SimpleUAV2d game.
3This follows the new format for the OpenAI Gym environment API. They provide
4default wrappers for backwards compatibility with some learning libraries.
5"""
6import gym
7import numpy as np
8import matplotlib.pyplot as plt
9from gym import spaces
10from warnings import warn
12from gncpy.games.SimpleUAV2d import SimpleUAV2d as UAVGame
class SimpleUAV2d(gym.Env):
    """RL environment for the :class:`gncpy.games.SimpleUAV2d.SimpleUAV2d` game.

    Attributes
    ----------
    render_mode : string
        Mode to render in. See :attr:`.metadata` for available modes.
    game : :class:`gncpy.games.SimpleUAV2d.SimpleUAV2d`
        Main game to play.
    fig : matplotlib figure
        For legacy support of rendering function.
    obs_type : string
        Observation type to use. Options are :code:`'image'` or :code:`'player_state'`.
    aux_use_n_targets : bool
        Flag indicating if auxiliary state uses the number of targets.
    aux_use_time : bool
        Flag indicating if auxiliary state uses the current time.
    max_time : float
        Maximum time in real units for the environment. It is recommended to
        set the game to have unlimited time and use this instead as this allows
        the RL algorithms more visibility. Once this time is surpassed the
        episode is truncated and appropriate flags are set.
    observation_space : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict`
        Observation space. This depends on the observation type and auxiliary
        flags.
    """

    metadata = {"render_modes": ["human", "single_rgb_array"], "render_fps": 60}
    """Additional metadata for the class."""

    action_space = spaces.Box(
        low=-np.ones(2), high=np.ones(2), dtype=np.float32
    )
    """Space for available actions."""

    def __init__(
        self,
        config_file="SimpleUAV2d.yaml",
        render_mode="single_rgb_array",
        obs_type="player_state",
        max_time=10,
        aux_use_n_targets=False,
        aux_use_time=False,
    ):
        """Initialize an object.

        Parameters
        ----------
        config_file : string, optional
            Full path of the configuration file. The default is "SimpleUAV2d.yaml".
        render_mode : string, optional
            Render mode to use. Must be specified at initialization time, then
            the render function does not need to be called. The default is
            "single_rgb_array".
        obs_type : string, optional
            Observation type to use. The default is "player_state".
        max_time : float, optional
            Maximum time for an episode in game's real units. The default is 10.
        aux_use_n_targets : bool, optional
            Flag indicating if auxiliary state uses the number of targets. The
            default is False.
        aux_use_time : bool, optional
            Flag indicating if auxiliary state uses the current time. The default
            is False.
        """
        super().__init__()

        if render_mode in self.metadata["render_modes"]:
            self.render_mode = render_mode
        else:
            # fall back to the first registered mode rather than failing
            self.render_mode = self.metadata["render_modes"][0]
            warn(
                "Invalid render mode ({}) defaulting to {}".format(
                    render_mode, self.render_mode
                )
            )

        self.game = UAVGame(config_file, self.render_mode, rng=self.np_random)
        self.game.setup()
        # take one zero-action step so the game state/image is populated
        self.game.step(self.gen_act_map(np.zeros_like(self.action_space.low)))

        self.fig = None  # for legacy support of render function
        self._fig_ax = None  # axes within self.fig used by the legacy renderer
        self.obs_type = obs_type
        self.aux_use_n_targets = aux_use_n_targets
        self.aux_use_time = aux_use_time
        self.max_time = max_time

        self.observation_space = self.calc_obs_space()

        # Rebind an instance-level copy instead of mutating the class-level
        # dict, otherwise one instance's fps would leak into every instance.
        self.metadata = dict(self.metadata, render_fps=self.game.render_fps)

    def step(self, action):
        """Perform one iteration of the game loop.

        Parameters
        ----------
        action : numpy array
            Action to take in the game.

        Returns
        -------
        observation : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict`
            Current observation of the game.
        reward : float
            Reward from current step.
        done : bool
            Flag indicating if the episode met the end conditions.
        truncated : bool
            Flag indicating if the episode has ended due to time constraints.
        info : dict
            Extra debugging info.
        """
        info = self.game.step(self.gen_act_map(action))

        # truncation (time limit hit) is reported separately from "done"
        if self.max_time is None:
            truncated = False
        else:
            truncated = self.game.elapsed_time > self.max_time

        return (self._get_obs(), self.game.score, self.game.game_over, truncated, info)

    def render(self, mode=None):
        """Deprecated. Handles rendering a frame of the environment.

        This is deprecated and the render mode should instead be set at
        initialization.

        Parameters
        ----------
        mode : string, optional
            The rendering mode to use. The default is None, which does nothing.
        """
        if mode is None:
            return
        elif self.render_mode is not None:
            # actually emit the warning; constructing the exception alone
            # is a no-op
            warn(
                "Calling render directly is deprecated, "
                + "specify the render mode during initialization instead.",
                DeprecationWarning,
            )

        if mode == "single_rgb_array":
            return self.game.img.copy()
        elif self.render_mode == "human":
            # game already renders to its own window in human mode
            return
        elif mode == "human":
            if self.fig is None:
                px2in = 1 / plt.rcParams["figure.dpi"]  # pixel in inches
                orig = plt.rcParams["toolbar"]
                plt.rcParams["toolbar"] = "None"
                # figsize is (width, height) while img.shape is
                # (rows, cols) = (height, width), so the indices are swapped
                self.fig = plt.figure(
                    figsize=(
                        px2in * self.game.img.shape[1],
                        px2in * self.game.img.shape[0],
                    )
                )
                self._fig_ax = self.fig.add_axes(
                    [0, 0, 1, 1], frame_on=False, rasterized=True
                )
                plt.rcParams["toolbar"] = orig

            # draw on the axes (Figure has no imshow method); clearing only
            # the axes keeps the axes object alive for the next frame
            self._fig_ax.clear()
            self._fig_ax.imshow(self.game.img)
            plt.pause(1 / self.game.render_fps)

        else:
            warn("Invalid render mode: {}".format(mode))

    def reset(self, seed=None, return_info=False, options=None):
        """Resets the environment to an initial state.

        This method can reset the environment's random number generator(s)
        if seed is an integer or if the environment has not yet initialized a
        random number generator. If the environment already has a random number
        generator and reset() is called with seed=None, the RNG should not be
        reset. Moreover, reset() should (in the typical use case) be called
        with an integer seed right after initialization and then never again.

        Parameters
        ----------
        seed : int, optional
            The seed that is used to initialize the environment's PRNG. If the
            environment does not already have a PRNG and :code:`seed=None`
            (the default option) is passed, a seed will be chosen from some
            source of entropy (e.g. timestamp or /dev/urandom). However, if the
            environment already has a PRNG and :code:`seed=None` is passed, the
            PRNG will not be reset. If you pass an integer, the PRNG will be
            reset even if it already exists. Usually, you want to pass an
            integer right after the environment has been initialized and then
            never again. The default is None.
        return_info : bool, optional
            If true, return additional information along with initial
            observation. This info should be analogous to the info returned in
            :meth:`.step`. The default is False.
        options : dict, optional
            Not used by this environment. The default is None.

        Returns
        -------
        observation : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict`
            Initial observation of the environment.
        info : dict, optional
            Additonal debugging info, only returned if :code:`return_info=True`.
        """
        # seeds self.np_random when an integer seed is given; return value
        # is not meaningful so it is intentionally discarded
        super().reset(seed=seed)

        self.game.reset(rng=self.np_random)
        # zero-action step to repopulate the game state/image after reset
        info = self.game.step(self.gen_act_map(np.zeros_like(self.action_space.low)))
        observation = self._get_obs()

        return (observation, info) if return_info else observation

    def close(self):
        """Nicely shuts down the environment."""
        self.game.close()
        super().close()

    def calc_obs_space(self):
        """Determines the observation space based on specified options.

        If a dictionary space is used, the keys are :code:`'img'` for the image
        of the game screen and :code:`'aux'` for the auxiliary state vector. Both
        values are boxes.

        Raises
        ------
        RuntimeError
            Invalid observation type specified.

        Returns
        -------
        out : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict`
            Observation space.
        """
        # determine main state
        main_state = None
        if self.obs_type == "image":
            shape = (*self.game.get_image_size(), 3)
            main_state = spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8)

        elif self.obs_type == "player_state":
            state_bnds = self.game.get_player_state_bounds()
            main_state = spaces.Box(
                low=state_bnds[0], high=state_bnds[1], dtype=np.float32
            )

        else:
            raise RuntimeError("Invalid observation type ({})".format(self.obs_type))

        # create aux state if needed
        aux_state_low = np.array([])
        aux_state_high = np.array([])

        if self.aux_use_n_targets:
            aux_state_low = np.append(aux_state_low, 0)
            aux_state_high = np.append(aux_state_high, np.inf)

        if self.aux_use_time:
            aux_state_low = np.append(aux_state_low, 0)
            aux_state_high = np.append(aux_state_high, np.inf)

        # combine into final space: image obs cannot be concatenated with the
        # aux vector, so a Dict space is used in that case
        if self.obs_type == "image":
            if aux_state_low.size > 0:
                aux_state = spaces.Box(aux_state_low, aux_state_high, dtype=np.float32)
                out = spaces.Dict({"img": main_state, "aux": aux_state})

            else:
                out = main_state

        else:
            if aux_state_low.size > 0:
                low = np.concatenate((main_state.low, aux_state_low))
                high = np.concatenate((main_state.high, aux_state_high))
                out = spaces.Box(low, high, dtype=np.float32)

            else:
                out = main_state

        return out

    def gen_act_map(self, action):
        """Maps actions to entity ids for the game.

        This assumes there is only 1 player and if there are more then all
        players get the same action.

        Parameters
        ----------
        action : numpy array
            Action to take in the game.

        Returns
        -------
        act_map : dict
            Each key is an entity id and each value is a numpy array.
        """
        # Note: should only have 1 player
        ids = self.game.get_player_ids()
        if len(ids) > 1:
            warn(
                "Multi-player environment not supported, "
                + "all players using same action."
            )

        act_map = {_id: action for _id in ids}

        return act_map

    def _get_obs(self):
        """Generates an observation."""
        # get main state
        if self.obs_type == "image":
            main_state = self.game.img.copy()
        elif self.obs_type == "player_state":
            p_states = self.game.get_players_state()
            if len(p_states) == 0:
                raise RuntimeError("No players alive")
            main_state = p_states[list(p_states.keys())[0]]
        else:  # catch all in case a new case is forgotten
            msg = "Failed to generate observation for type {}".format(self.obs_type)
            raise NotImplementedError(msg)

        # get aux state, if any
        aux_state = np.array([])
        if self.aux_use_n_targets:
            aux_state = np.append(aux_state, self.game.get_num_targets())

        if self.aux_use_time:
            aux_state = np.append(aux_state, self.game.elapsed_time)

        # combine into output, mirroring the structure of observation_space
        if self.obs_type == "image":
            if aux_state.size > 0:
                return dict(img=main_state, aux=aux_state)
            else:
                return main_state
        else:
            if aux_state.size > 0:
                return np.concatenate((main_state, aux_state), dtype=np.float32)
            else:
                return main_state.astype(np.float32)
class SimpleUAVHazards2d(SimpleUAV2d):
    """2d UAV environment whose default world contains hazards.

    Gameplay is driven by the same underlying logic as the
    :class:`.SimpleUAV2d` environment; the only difference is the default
    configuration file, which adds hazards to the map.
    """

    def __init__(self, config_file="SimpleUAVHazards2d.yaml", **kwargs):
        """Initialize an object."""
        super().__init__(config_file=config_file, **kwargs)