Coverage for src/gncpy/planning/reinforcement_learning/envs/simple2d/simpleUAV2d.py: 0%

123 statements  

coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Implements RL environments for the SimpleUAV2d game. 

2 

3This follows the new format for the OpenAI Gym environment API. They provide 

4default wrappers for backwards compatibility with some learning libraries. 

5""" 

6import gym 

7import numpy as np 

8import matplotlib.pyplot as plt 

9from gym import spaces 

10from warnings import warn 

11 

12from gncpy.games.SimpleUAV2d import SimpleUAV2d as UAVGame 

13 

14 

15class SimpleUAV2d(gym.Env): 

16 """RL environment for the :class:`gncpy.games.SimpleUAV2d.SimpleUAV2d` game. 

17 

18 Attributes 

19 ---------- 

20 render_mode : string 

21 Mode to render in. See :attr:`.metadata` for available modes. 

22 game : :class:`gncpy.games.SimpleUAV2d.SimpleUAV2d` 

23 Main game to play. 

24 fig : matplotlib figure 

25 For legacy support of rendering function 

26 obs_type : string 

27 Observation type to use. Options are :code:`'image'` or :code:`'player_state'`. 

28 aux_use_n_targets : bool 

29 Flag indicating if auxilary state uses the number of targets. 

30 aux_use_time : bool 

31 Flag indicating if auxilary state uses the current time. 

32 max_time : float 

33 Maximum time in real units for the environment. It is recommended to 

34 set the game to have unlimited time and use this instead as this allows 

35 the RL algorithms more visibility. Once this time is surpassed the 

36 episode is truncated and appropriate flags are set. 

37 observation_space : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict` 

38 Observation space. This depends on the observation type and auxilary 

39 flags. 
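
    Examples
    --------
    A minimal interaction sketch; it assumes the default "SimpleUAV2d.yaml"
    configuration can be found by the underlying game::

        env = SimpleUAV2d(obs_type="player_state", max_time=10)
        obs = env.reset(seed=0)
        done = truncated = False
        while not (done or truncated):
            action = env.action_space.sample()
            obs, reward, done, truncated, info = env.step(action)
        env.close()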

40 """ 

41 

42 metadata = {"render_modes": ["human", "single_rgb_array"], "render_fps": 60} 

43 """Additional metadata for the class.""" 

44 

45 action_space = spaces.Box( 

46 low=-np.ones(2), high=np.ones(2), dtype=np.float32 

47 ) 

48 """Space for available actions.""" 

49 

50 def __init__( 

51 self, 

52 config_file="SimpleUAV2d.yaml", 

53 render_mode="single_rgb_array", 

54 obs_type="player_state", 

55 max_time=10, 

56 aux_use_n_targets=False, 

57 aux_use_time=False, 

58 ): 

59 """Initialize an object. 

60 

61 Parameters 

62 ---------- 

63 config_file : string, optional 

64 Full path of the configuration file. The default is "SimpleUAV2d.yaml". 

65 render_mode : string, optional 

66 Render mode to use. Must be specified at initialization time, then 

67 the render function does not need to be called. The default is 

68 "single_rgb_array". 

69 obs_type : string, optional 

70 Observation type to use. The default is "player_state". 

71 max_time : float, optional 

72 Maximum time for an episode in game's real units. The default is 10. 

73 aux_use_n_targets : bool, optional 

74 Flag indicating if auxilary state uses the number of targets. The 

75 default is False. 

76 aux_use_time : bool, optional 

77 Flag indicating if auxilary state uses the current time. The default 

78 is False. 

79 """ 

        super().__init__()

        if render_mode in self.metadata["render_modes"]:
            self.render_mode = render_mode
        else:
            self.render_mode = self.metadata["render_modes"][0]
            warn(
                "Invalid render mode ({}) defaulting to {}".format(
                    render_mode, self.render_mode
                )
            )

        self.game = UAVGame(config_file, self.render_mode, rng=self.np_random)
        self.game.setup()
        self.game.step(self.gen_act_map(np.zeros_like(self.action_space.low)))

        self.fig = None  # for legacy support of render function
        self.obs_type = obs_type
        self.aux_use_n_targets = aux_use_n_targets
        self.aux_use_time = aux_use_time
        self.max_time = max_time

        self.observation_space = self.calc_obs_space()

        self.metadata["render_fps"] = self.game.render_fps

    def step(self, action):
        """Perform one iteration of the game loop.

        Parameters
        ----------
        action : numpy array
            Action to take in the game.

        Returns
        -------
        observation : numpy array or dict
            Current observation of the game; an element of :attr:`.observation_space`.
        reward : float
            Reward from current step.
        done : bool
            Flag indicating if the episode met the end conditions.
        truncated : bool
            Flag indicating if the episode has ended due to time constraints.
        info : dict
            Extra debugging info.
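
        Examples
        --------
        A sketch of stepping the environment and handling the two end flags
        (assumes an existing ``env``)::

            obs, reward, done, truncated, info = env.step(
                np.zeros_like(env.action_space.low)
            )
            if done or truncated:
                obs = env.reset()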

130 """ 

131 truncated = False 

132 info = self.game.step(self.gen_act_map(action)) 

133 

134 if self.max_time is None: 

135 truncated = False 

136 else: 

137 truncated = self.game.elapsed_time > self.max_time 

138 

139 return (self._get_obs(), self.game.score, self.game.game_over, truncated, info) 

140 

    def render(self, mode=None):
        """Deprecated. Handles rendering a frame of the environment.

        This is deprecated and the render mode should instead be set at
        initialization.

        Parameters
        ----------
        mode : string, optional
            The rendering mode to use. The default is None, which does nothing.
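
        Examples
        --------
        The legacy call can still return a copy of the current frame (a
        sketch, assuming an existing ``env``)::

            frame = env.render(mode="single_rgb_array")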

151 """ 

        if mode is None:
            return
        elif self.render_mode is not None:
            warn(
                "Calling render directly is deprecated, "
                + "specify the render mode during initialization instead.",
                DeprecationWarning,
            )

        if mode == "single_rgb_array":
            return self.game.img.copy()
        elif self.render_mode == "human":
            return
        elif mode == "human":
            if self.fig is None:
                px2in = 1 / plt.rcParams["figure.dpi"]  # pixel in inches
                orig = plt.rcParams["toolbar"]
                plt.rcParams["toolbar"] = "None"
                self.fig = plt.figure(
                    figsize=(
                        px2in * self.game.img.shape[0],
                        px2in * self.game.img.shape[1],
                    )
                )
                self.fig.add_axes([0, 0, 1, 1], frame_on=False, rasterized=True)
                plt.rcParams["toolbar"] = orig

            # update the existing axes with the latest frame
            ax = self.fig.axes[0]
            ax.clear()
            ax.imshow(self.game.img)
            plt.pause(1 / self.game.render_fps)

        else:
            warn("Invalid render mode: {}".format(mode))

    def reset(self, seed=None, return_info=False, options=None):
        """Resets the environment to an initial state.

        This method can reset the environment's random number generator(s)
        if seed is an integer or if the environment has not yet initialized a
        random number generator. If the environment already has a random number
        generator and reset() is called with seed=None, the RNG should not be
        reset. Moreover, reset() should (in the typical use case) be called
        with an integer seed right after initialization and then never again.

        Parameters
        ----------
        seed : int, optional
            The seed that is used to initialize the environment's PRNG. If the
            environment does not already have a PRNG and :code:`seed=None`
            (the default option) is passed, a seed will be chosen from some
            source of entropy (e.g. timestamp or /dev/urandom). However, if the
            environment already has a PRNG and :code:`seed=None` is passed, the
            PRNG will not be reset. If you pass an integer, the PRNG will be
            reset even if it already exists. Usually, you want to pass an
            integer right after the environment has been initialized and then
            never again. The default is None.
        return_info : bool, optional
            If true, return additional information along with the initial
            observation. This info should be analogous to the info returned by
            :meth:`.step`. The default is False.
        options : dict, optional
            Not used by this environment. The default is None.

        Returns
        -------
        observation : numpy array or dict
            Initial observation of the environment; an element of
            :attr:`.observation_space`.
        info : dict, optional
            Additional debugging info, only returned if :code:`return_info=True`.
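
        Examples
        --------
        Typical use seeds the PRNG once, right after construction (a sketch
        assuming an existing ``env``)::

            obs = env.reset(seed=42)
            obs, info = env.reset(return_info=True)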

220 """ 

221 seed = super().reset(seed=seed) 

222 

223 self.game.reset(rng=self.np_random) 

224 info = self.game.step(self.gen_act_map(np.zeros_like(self.action_space.low))) 

225 observation = self._get_obs() 

226 

227 return (observation, info) if return_info else observation 

228 

229 def close(self): 

230 """Nicely shuts down the environment.""" 

231 self.game.close() 

232 super().close() 

233 

    def calc_obs_space(self):
        """Determines the observation space based on the specified options.

        If a dictionary space is used, the keys are :code:`'img'` for the image
        of the game screen and :code:`'aux'` for the auxiliary state vector.
        Both values are boxes.

        Raises
        ------
        RuntimeError
            Invalid observation type specified.

        Returns
        -------
        out : :class:`gym.spaces.Box` or :class:`gym.spaces.Dict`
            Observation space.
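
        Examples
        --------
        A sketch of how the resulting space depends on the options; image
        observations plus any auxiliary flag yield a dictionary space::

            env = SimpleUAV2d(obs_type="image", aux_use_time=True)
            sorted(env.observation_space.spaces.keys())  # ['aux', 'img']

            env = SimpleUAV2d(obs_type="player_state")
            isinstance(env.observation_space, spaces.Box)  # True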

250 """ 

251 # determine main state 

252 main_state = None 

253 if self.obs_type == "image": 

254 shape = (*self.game.get_image_size(), 3) 

255 main_state = spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) 

256 

257 elif self.obs_type == "player_state": 

258 state_bnds = self.game.get_player_state_bounds() 

259 main_state = spaces.Box( 

260 low=state_bnds[0], high=state_bnds[1], dtype=np.float32 

261 ) 

262 

263 else: 

264 raise RuntimeError("Invalid observation type ({})".format(self.obs_type)) 

265 

266 # create aux state if needed 

267 aux_state_low = np.array([]) 

268 aux_state_high = np.array([]) 

269 

270 if self.aux_use_n_targets: 

271 aux_state_low = np.append(aux_state_low, 0) 

272 aux_state_high = np.append(aux_state_high, np.inf) 

273 

274 if self.aux_use_time: 

275 aux_state_low = np.append(aux_state_low, 0) 

276 aux_state_high = np.append(aux_state_high, np.inf) 

277 

278 # combine into final space 

279 if self.obs_type == "image": 

280 if aux_state_low.size > 0: 

281 aux_state = spaces.Box(aux_state_low, aux_state_high, dtype=np.float32) 

282 out = spaces.Dict({"img": main_state, "aux": aux_state}) 

283 

284 else: 

285 out = main_state 

286 

287 else: 

288 if aux_state_low.size > 0: 

289 low = np.concatenate((main_state.low, aux_state_low)) 

290 high = np.concatenate((main_state.high, aux_state_high)) 

291 out = spaces.Box(low, high, dtype=np.float32) 

292 

293 else: 

294 out = main_state 

295 

296 return out 

297 

    def gen_act_map(self, action):
        """Maps actions to entity ids for the game.

        This assumes there is only one player; if there are more, all players
        get the same action.

        Parameters
        ----------
        action : numpy array
            Action to take in the game.

        Returns
        -------
        act_map : dict
            Each key is an entity id and each value is a numpy array.
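
        Examples
        --------
        A sketch of the resulting mapping (assumes an existing ``env``); every
        player id is mapped to the same action array::

            act_map = env.gen_act_map(np.zeros_like(env.action_space.low))
            # -> {<player id>: array([0., 0.], dtype=float32)}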

313 """ 

314 # Note: should only have 1 player 

315 ids = self.game.get_player_ids() 

316 if len(ids) > 1: 

317 warn( 

318 "Multi-player environment not supported, " 

319 + "all players using same action." 

320 ) 

321 

322 act_map = {} 

323 for _id in ids: 

324 act_map[_id] = action 

325 return act_map 

326 

    def _get_obs(self):
        """Generates an observation."""
        # get main state
        if self.obs_type == "image":
            main_state = self.game.img.copy()
        elif self.obs_type == "player_state":
            p_states = self.game.get_players_state()
            if len(p_states) == 0:
                raise RuntimeError("No players alive")
            main_state = p_states[list(p_states.keys())[0]]
        else:  # catch all in case a new case is forgotten
            msg = "Failed to generate observation for type {}".format(self.obs_type)
            raise NotImplementedError(msg)

        # get aux state, if any
        aux_state = np.array([])
        if self.aux_use_n_targets:
            aux_state = np.append(aux_state, self.game.get_num_targets())

        if self.aux_use_time:
            aux_state = np.append(aux_state, self.game.elapsed_time)

        # combine into output
        if self.obs_type == "image":
            if aux_state.size > 0:
                return dict(img=main_state, aux=aux_state)
            else:
                return main_state
        else:
            if aux_state.size > 0:
                return np.concatenate((main_state, aux_state), dtype=np.float32)
            else:
                return main_state.astype(np.float32)


class SimpleUAVHazards2d(SimpleUAV2d):
    """Simple 2d UAV environment with hazards.

    This follows the same underlying game logic as the :class:`.SimpleUAV2d`
    environment but has some hazards added to its default configuration.
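
    Examples
    --------
    Constructed the same way as the base environment, just with the hazards
    configuration file used by default::

        env = SimpleUAVHazards2d(obs_type="player_state", max_time=10)
        obs = env.reset(seed=0)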

367 """ 

368 

369 def __init__(self, config_file="SimpleUAVHazards2d.yaml", **kwargs): 

370 """Initialize an object.""" 

371 super().__init__(config_file=config_file, **kwargs)