Coverage for src/gncpy/game_engine/base_game.py: 0%

209 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-13 06:15 +0000

1"""Defines base game engine classes. 

2 

3These define common functions, properties, and interfaces for all games. 

4""" 

5from abc import ABC, abstractmethod 

6import os 

7import pathlib 

8import numpy as np 

9from ruamel.yaml import YAML 

10 

11import gncpy.game_engine.rendering2d as grender2d 

12import gncpy.game_engine.components as gcomp 

13from gncpy.game_engine.rendering2d import Shape2dParams 

14from gncpy.game_engine.physics2d import Physics2dParams, Collision2dParams 

15from gncpy.game_engine.entities import EntityManager 

16 

17 

yaml = YAML()
"""Global yaml interpreter, should be used when parsing any game configs."""

20 

21 

def ndarray_representer(dumper, data):
    """Represent a numpy array as a plain YAML sequence.

    Parameters
    ----------
    dumper : ruamel.yaml representer
        Object providing ``represent_list``.
    data : numpy array
        Array to serialize; it is converted to (possibly nested) Python lists.

    Returns
    -------
    Whatever ``dumper.represent_list`` returns for the converted list.
    """
    as_python_list = data.tolist()
    return dumper.represent_list(as_python_list)

24 

25 

class WindowParams:
    """Window settings read from the game's yaml config.

    The type of each default value below decides how the config parser
    merges the value loaded from file, so keep them as ints.

    Attributes
    ----------
    width : int
        Window width, in pixels.
    height : int
        Window height, in pixels.
    """

    def __init__(self):
        super().__init__()
        # pixel dimensions; 0 until overridden by the config file
        self.width, self.height = 0, 0

43 

44 

class BaseParams:
    """Top-level game parameters; subclass this to add game-specific fields.

    Attributes
    ----------
    window : :class:`.WindowParams`
        Settings for the game window.
    """

    def __init__(self):
        super().__init__()
        window_params = WindowParams()
        self.window = window_params

57 

58 

class Base2dParams(BaseParams):
    """Default parameters shared by 2d games; subclass to add more fields.

    Attributes
    ----------
    physics : :class:`.physics2d.Physics2dParams`
        Settings for the physics system.
    start_time : float
        In-game time at which the game begins.
    max_time : float
        Upper limit on in-game run time. It is infinity when the config file
        does not set it. A negative value in the config file also means
        unlimited time.
    """

    def __init__(self):
        super().__init__()
        # timing defaults
        self.start_time = 0.0
        self.max_time = float("inf")
        # physics sub-parameters
        self.physics = Physics2dParams()

79 

80 

class BaseGame(ABC):
    """Base class for defining games.

    This should implement all the necessary systems (i.e. game logic) by operating
    on entities. It must be subclassed and defines the expected interface of games.

    Attributes
    ----------
    config_file : string
        Full file path of the configuration file.
    entityManager : :class:`.entities.EntityManager`
        Factory class for making and managing game entities.
    render_mode : string
        Method of rendering the game.
    render_fps : int
        Frame rate to render the game at.
    current_frame : int
        Number of the current frame.
    score : float
        Score accumulated in the game.
    game_over : bool
        Flag indicating if the game has ended.
    params : :class:`BaseParams`
        Parameters for the game, read from yaml file.
    rng : numpy random generator
        Random number generator.
    seed_val : int
        Optional seed used in the random generator.
    library_config_dir : string
        Default directory to look for config files defined by the package.
        Taken from the ``library_dir`` constructor argument, or the package's
        ``games/configs`` directory when that argument is None.
    """

    # Field types whose loaded values are copied directly, without merging
    # them onto a default-constructed object.
    _PARAM_BUILTIN_TYPES = (int, float, bool, str, dict)

    def __init__(
        self,
        config_file,
        render_mode,
        render_fps=None,
        use_library_config=False,
        seed=None,
        rng=None,
        library_dir=None,
    ):
        """Initialize the object.

        Parameters
        ----------
        config_file : string
            Full path of the configuration file. This can be the name of one
            of the library's default files but the use_library_config flag
            should also be set.
        render_mode : string
            Mode to render the game.
        render_fps : int, optional
            FPS to render the game at. The default is None.
        use_library_config : bool, optional
            Flag indicating if the config file is in the default library location.
            The default is False.
        seed : int, optional
            Seed for the random number generator. The default is None.
        rng : numpy random generator, optional
            Instance of the random generator to use. The default is None and will
            cause one to be created.
        library_dir : string, optional
            Default directory to look for config files defined by the package.
            This is useful when extending this class outside of the gncpy
            package to provide a new default search location. The default of
            None is good for most other cases.

        Raises
        ------
        FileNotFoundError
            If the config file cannot be found (see :meth:`validate_config_file`).
        """
        super().__init__()
        if library_dir is None:
            self.library_config_dir = os.path.join(
                pathlib.Path(__file__).parent.parent.resolve(), "games", "configs"
            )
        else:
            self.library_config_dir = library_dir

        self.config_file = self.validate_config_file(
            config_file, use_library=use_library_config
        )
        self.entityManager = EntityManager()
        self.render_mode = render_mode
        self.render_fps = render_fps
        self.current_frame = -1
        self.score = 0
        self.game_over = False
        self.params = None
        # an explicit generator wins over a seed
        if rng is not None:
            self.rng = rng
        elif seed is not None:
            self.rng = np.random.default_rng(seed)
        else:
            self.rng = np.random.default_rng()

        self.seed_val = seed

    def setup(self):
        """Sets up the game by parsing the config file and checking max time.

        This should be called before any other game functions. It should be
        extended by child classes. If overridden then the developer is responsible
        for registering the parameters for the yaml parsing and parsing the config
        file.
        """
        self.register_params(yaml)

        self.parse_config_file()
        # BaseParams itself has no max_time (Base2dParams adds it), so only
        # normalize a negative "unlimited" value when the field exists.
        max_time = getattr(self.params, "max_time", None)
        if max_time is not None and max_time < 0:
            self.params.max_time = float("inf")

    def register_params(self, yaml):
        """Register classes with the yaml parser.

        This should be extended by inherited classes.

        Parameters
        ----------
        yaml : ruamel.yaml YAML object
            yaml parser to use, should be the global parser.
        """
        yaml.representer.add_representer(np.ndarray, ndarray_representer)
        yaml.register_class(WindowParams)
        yaml.register_class(BaseParams)

    def validate_config_file(self, config_file, use_library):
        """Validate that the config file exists.

        If ``use_library`` is False and the file exists as provided, that path
        is used. Otherwise the file is looked up relative to
        ``library_config_dir``.

        Parameters
        ----------
        config_file : string
            Full path to the config file.
        use_library : bool
            Flag indicating if the library directory will be checked.

        Raises
        ------
        FileNotFoundError
            If the file cannot be found.

        Returns
        -------
        cf : string
            full path to the config file that was found.
        """
        succ = os.path.isfile(config_file)
        if use_library or not succ:
            cf = os.path.join(self.library_config_dir, config_file)
            succ = os.path.isfile(cf)
        else:
            cf = config_file

        if not succ:
            raise FileNotFoundError("Failed to find config file {}".format(config_file))

        return cf

    @classmethod
    def _apply_param_defaults(cls, item):
        """Rebuild a loaded config object on top of a default-constructed one.

        Fields missing from the loaded object keep the defaults set by the
        class constructor. Each field's default type decides how the loaded
        value is merged: builtin scalars and dicts are copied, lists have
        each element rebuilt recursively, tuples are converted, numpy array
        fields become float arrays, and any other type is rebuilt recursively.

        Parameters
        ----------
        item : object
            Value produced by the yaml loader (or an element of one).

        Returns
        -------
        object
            The merged value.

        Raises
        ------
        RuntimeError
            If a value for a numpy array field cannot be converted to floats.
        """
        # None, builtin scalars/dicts (including subclasses of them) and tuples
        # are already plain values; rebuilding them would discard their content.
        if item is None or isinstance(item, cls._PARAM_BUILTIN_TYPES + (tuple,)):
            return item
        # Lists (e.g. nested list elements) keep their contents element-wise.
        if isinstance(item, list):
            return [cls._apply_param_defaults(sub) for sub in item]

        item_cls = type(item)
        true_item = item_cls()
        for field in dir(item):
            if field.startswith("_"):
                continue
            # skip properties, assume they are read only
            if isinstance(getattr(item_cls, field, None), property):
                continue
            val = getattr(item, field)
            if callable(val):
                continue
            val_type = type(getattr(true_item, field))
            if val_type in cls._PARAM_BUILTIN_TYPES:
                setattr(true_item, field, val)
            elif val_type == list:
                if not isinstance(val, list):
                    val = [val]
                setattr(
                    true_item,
                    field,
                    [cls._apply_param_defaults(lst_item) for lst_item in val],
                )
            elif val_type == tuple:
                if not isinstance(val, tuple):
                    val = tuple(val)
                setattr(true_item, field, val)
            elif val_type == np.ndarray:
                if not isinstance(val, list):
                    val = [val]
                try:
                    arr_val = np.array(val, dtype=float)
                except ValueError as err:
                    raise RuntimeError(
                        "Failed to convert {:s} to numpy array ({}).".format(
                            field, val
                        )
                    ) from err
                setattr(true_item, field, arr_val)
            else:
                setattr(true_item, field, cls._apply_param_defaults(val))

        return true_item

    def parse_config_file(self):
        """Parses the config file and saves the parameters to ``self.params``.

        The file is loaded with the module-level ``yaml`` parser and the result
        is merged onto default-constructed parameter objects, so fields not
        present in the file keep their defaults.
        """
        with open(self.config_file, "r") as fin:
            self.params = self._apply_param_defaults(yaml.load(fin))

    def reset(self, seed=None, rng=None):
        """Resets to the base state.

        If a random generator is provided then that is used. Otherwise the seed
        value is used to create a new generator. If neither is provided, but a
        seed had previously been provided then the old seed is used to recreate
        the generator. If all else fails, then a new default generator is initialized.

        Parameters
        ----------
        seed : int, optional
            seed for the random number generator. The default is None.
        rng : numpy random generator, optional
            Instance of the random number generator to use. The default is None.
        """
        self.entityManager = EntityManager()
        self.current_frame = 0
        self.game_over = False
        self.score = 0

        if rng is not None:
            self.rng = rng
        elif seed is None:
            if self.seed_val is None:
                self.rng = np.random.default_rng()
            else:
                self.rng = np.random.default_rng(self.seed_val)
        else:
            self.rng = np.random.default_rng(seed)
            # remember the new seed so later resets reproduce it
            self.seed_val = seed

    @abstractmethod
    def s_movement(self, action):
        """Abstract method for moving entities according to their dynamics.

        Parameters
        ----------
        action : numpy array, int, bool, dict, etc.
            action to take in the game.
        """
        raise NotImplementedError()

    @abstractmethod
    def s_collision(self):
        """Abstract method to check for collisions between entities.

        May return extra info useful by the step function.
        """
        raise NotImplementedError()

    @abstractmethod
    def s_game_over(self):
        """Abstract method to check for game over conditions."""
        raise NotImplementedError()

    @abstractmethod
    def s_score(self):
        """Abstract method to calculate the score.

        Returns
        -------
        info : dict
            Extra info for debugging.
        """
        raise NotImplementedError()

    @abstractmethod
    def s_input(self, *args):
        """Abstract method to turn user inputs into game actions."""
        raise NotImplementedError()

    @abstractmethod
    def step(self, *args):
        """Abstract method defining what to do each frame.

        Returns
        -------
        info : dict
            Extra information for debugging.
        """
        raise NotImplementedError

381 

382 

class BaseGame2d(BaseGame):
    """Base class for defining 2d games.

    This should implement all the necessary systems (i.e. game logic) by operating
    on entities. It assumes the rendering will be done by pygame.

    Attributes
    ----------
    clock : pygame clock
        Clock for the rendering system.
    window : pygame window
        Window for drawing to.
    img : H x W x 3 numpy array
        Pixel values of the screen image.
    current_update_count : int
        Number of update steps taken, this may be more than the frame count as
        multiple physics updates may be made per frame.
    dist_per_pix : numpy array
        Real distance per pixel in x (width) and y (height) order.
    """

    def __init__(self, config_file, render_mode, **kwargs):
        super().__init__(config_file, render_mode, **kwargs)
        self.clock = None
        self.window = None
        self.img = np.array([])
        self.current_update_count = -1
        self.dist_per_pix = np.array([])  # width, height

    @property
    def current_time(self):
        """Current time in real units."""
        return self.elapsed_time + self.params.start_time

    @property
    def elapsed_time(self):
        """Amount of time elapsed in game in real units."""
        return self.params.physics.update_dt * self.current_update_count

    def setup(self):
        """Sets up the game and should be called before any game functions.

        Sets up the physics and rendering system and resets to the base state.
        """
        super().setup()

        # update render fps if not set
        if self.render_fps is None:
            self.render_fps = 1 / self.params.physics.dt

        self.clock = grender2d.init_rendering_system()
        self.window = grender2d.init_window(
            self.render_mode, self.params.window.width, self.params.window.height
        )

        self.dist_per_pix = np.array(
            [
                self.params.physics.dist_width / self.window.get_width(),
                self.params.physics.dist_height / self.window.get_height(),
            ]
        )

        self.reset()

    def register_params(self, yaml):
        """Register custom classes for this game with the yaml parser.

        Parameters
        ----------
        yaml : ruamel.yaml YAML object
            yaml parser to use, should be the global parser.
        """
        super().register_params(yaml)
        yaml.register_class(Shape2dParams)
        yaml.register_class(Collision2dParams)
        yaml.register_class(Physics2dParams)
        yaml.register_class(Base2dParams)

    def get_image_size(self):
        """Gets the size of the window.

        Returns
        -------
        tuple
            first is the height next is the width, in pixels.
        """
        sz = self.window.get_size()
        return sz[1], sz[0]

    def append_name_to_keys(self, in_dict, prefix):
        """Append a prefix to every key in a dictionary.

        A dot is placed between the prefix and the original key. Keys (and the
        prefix) of any type are converted with ``str``, so non-string keys
        such as integer ids are supported.

        Parameters
        ----------
        in_dict : dict
            Original dictionary.
        prefix : string
            string to prepend.

        Returns
        -------
        out : dict
            updated dictionary.
        """
        return {f"{prefix}.{key}": val for key, val in in_dict.items()}

    def reset(self, **kwargs):
        """Resets to the base state.

        Keyword arguments are forwarded to :meth:`BaseGame.reset`. The screen
        image is reset to an all-white (255) uint8 array sized from the window.
        """
        super().reset(**kwargs)
        self.img = 255 * np.ones((*self.get_image_size(), 3), dtype=np.uint8)
        self.current_update_count = 0

    def step(self, user_input):
        """Perform one iteration of the game loop.

        Multiple physics updates can be made between rendering calls.

        Parameters
        ----------
        user_input : dict
            Each key is an integer representing the entity id that the action
            applies to. Each value is an action to take.

        Returns
        -------
        info : dict
            Extra information for debugging. Contains ``"reached_target"`` and
            the score info from the last physics sub-step, with keys prefixed
            by ``"Reward."``.
        """
        info = {}
        self.current_frame += 1

        self.score = 0
        reached_target = False
        for _ in range(self.params.physics.step_factor):
            self.current_update_count += 1
            self.entityManager.update()

            # clear events for entities
            for e in self.entityManager.get_entities():
                if e.has_component(gcomp.CEvents):
                    e.get_component(gcomp.CEvents).events = []

            action = self.s_input(user_input)
            self.s_movement(action)
            hit_tar = self.s_collision()
            reached_target = reached_target or hit_tar
            self.s_game_over()
            score, s_info = self.s_score()
            self.score += score
        # report the average score over this frame's physics sub-steps
        self.score /= self.params.physics.step_factor

        info["reached_target"] = reached_target

        self.s_render()

        info.update(self.append_name_to_keys(s_info, "Reward"))

        return info

    def s_render(self):
        """Render a frame of the game."""
        self.img = grender2d.render(
            grender2d.get_drawable_entities(self.entityManager.get_entities()),
            self.window,
            self.clock,
            self.render_mode,
            self.render_fps,
        )

    def close(self):
        """Shutdown the rendering system."""
        grender2d.shutdown(self.window)