Coverage for src/gncpy/games/SimpleUAV2d/__init__.py: 0%
514 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-13 06:15 +0000
1"""Implements the SimpleUAV2d game.
3Included configuration files for this game are:
5 * SimpleUAV2d.yaml
6 * SimpleUAVHazards2d.yaml
8"""
9import numpy as np
10import enum
12import gncpy.dynamics.basic as gdyn
13import gncpy.game_engine.physics2d as gphysics
14import gncpy.game_engine.components as gcomp
15from gncpy.game_engine.base_game import (
16 BaseGame2d,
17 Base2dParams,
18 Shape2dParams,
19 Collision2dParams,
20)
class BirthModelParams:
    """Parameters for the birth model to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    type : string
        Type of birth model to use. Options are defined by
        :class:`gncpy.game_engine.components.CBirth`.
    extra_params : dict
        Extra parameters for the birth model. Varies depending on the type.
    location : numpy array
        Either just x/y position or full state length. Location parameter for
        the birth distribution.
    scale : numpy array
        Must be of appropriate dimension for the location. Scale parameter for
        the birth distribution.
    randomize : bool
        Flag indicating if a random sample should be drawn from the birth
        distribution. This does not affect birth time.
    times : numpy array
        Birth times for the object; these do not need to align with a game
        step time. If no value is provided then a probability must be given.
    prob : float
        Probability of birth at every time step. Only used if no times are
        given. Must be in the range (0, 1].
    """

    def __init__(self):
        self.type = ""
        self.extra_params = {}

        self.location = np.array([])
        self.scale = np.array([])

        self.randomize = False
        self.times = np.array([])
        self.prob = 0
class ControlModelParams:
    """Parameters for the control model to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    type : string
        Type of control model to use. Options are dependent on the dynamics type.
        For :code:`'DoubleIntegrator'` dynamics, options are :code:`'velocity'`.
        For :code:`'CoordinatedTurn'` dynamics, options are :code:`'velocity_turn'`.
    max_vel : float
        Maximum velocity for x/y components. Must be set for coordinated turn
        :code:`'velocity_turn'` type. Only used by double integrator
        :code:`'velocity'` type if max_vel_x and max_vel_y are not set.
    max_vel_x : float
        Maximum velocity in the x direction. Used by double integrator
        :code:`'velocity'`.
    max_vel_y : float
        Maximum velocity in the y direction. Used by double integrator
        :code:`'velocity'`.
    max_turn_rate : float
        Maximum turn rate. Used by coordinated turn :code:`'velocity_turn'`.
    """

    def __init__(self):
        self.type = ""
        self.max_vel = None
        self.max_vel_x = None
        self.max_vel_y = None
        self.max_turn_rate = None
class StateConstraintParams:
    """Parameters for the state constraints to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    type : string
        Type of state constraint to use. Options are :code:`'none'` or
        :code:`'velocity'`.
    min_vels : numpy array
        Minimum x/y velocity.
    max_vels : numpy array
        Maximum x/y velocity.
    """

    def __init__(self):
        self.type = ""
        self.min_vels = np.array([])
        self.max_vels = np.array([])
class DynamicsParams:
    """Parameters for the dynamics model to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    type : string
        Type of dynamics model to use. Options are :code:`'DoubleIntegrator'`
        or :code:`'CoordinatedTurn'`. See :mod:`gncpy.dynamics.basic` for
        details on the dynamics.
    extra_params : dict
        Extra parameters for the dynamics object. Varies depending on the type.
    controlModel : :class:`.ControlModelParams`
        Parameters for the control model.
    stateConstraint : :class:`.StateConstraintParams`
        Parameters for the state constraints.
    """

    def __init__(self):
        self.type = ""
        self.extra_params = {}
        self.controlModel = ControlModelParams()
        self.stateConstraint = StateConstraintParams()
class PlayerParams:
    """Parameters for a player object to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    birth : :class:`.BirthModelParams`
        Parameters for the birth model.
    dynamics : :class:`.DynamicsParams`
        Parameters for the dynamics model.
    shape : :class:`gncpy.game_engine.base_game.Shape2dParams`
        Parameters for the shape.
    collision : :class:`gncpy.game_engine.base_game.Collision2dParams`
        Parameters for the collision bounding box.
    capabilities : list
        Each element is a string representing some capability of the player.
        Allows for modeling additional hardware some targets may require to
        get a successful "hit".
    """

    def __init__(self):
        self.birth = BirthModelParams()
        self.dynamics = DynamicsParams()
        self.shape = Shape2dParams()
        self.collision = Collision2dParams()
        self.capabilities = []
class ObstacleParams:
    """Parameters for an obstacle object to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    loc_x : float
        X location of the center in real units.
    loc_y : float
        Y location of the center in real units.
    shape : :class:`gncpy.game_engine.base_game.Shape2dParams`
        Parameters for the shape.
    collision : :class:`gncpy.game_engine.base_game.Collision2dParams`
        Parameters for the collision bounding box.
    """

    def __init__(self):
        self.loc_x = 0
        self.loc_y = 0
        self.shape = Shape2dParams()
        self.collision = Collision2dParams()
class TargetParams:
    """Parameters for a target object to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    loc_x : float
        X location of the center in real units.
    loc_y : float
        Y location of the center in real units.
    shape : :class:`gncpy.game_engine.base_game.Shape2dParams`
        Parameters for the shape.
    collision : :class:`gncpy.game_engine.base_game.Collision2dParams`
        Parameters for the collision bounding box.
    capabilities : list
        Each element is a string representing some capability required by the
        target. When a player collides with a target, the more capabilities that
        match the better the score.
    priority : float
        Relative importance of reaching this target.
    order : int
        Order to reach this target relative to other targets. Must be >= 0. All
        targets with the same order will be available at the same time.
    """

    def __init__(self):
        self.loc_x = 0
        self.loc_y = 0
        self.shape = Shape2dParams()
        self.collision = Collision2dParams()
        self.capabilities = []
        self.priority = 0
        self.order = 0
class HazardParams:
    """Parameters for a hazard object to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    loc_x : float
        X location of the center in real units.
    loc_y : float
        Y location of the center in real units.
    shape : :class:`gncpy.game_engine.base_game.Shape2dParams`
        Parameters for the shape.
    collision : :class:`gncpy.game_engine.base_game.Collision2dParams`
        Parameters for the collision bounding box.
    prob_of_death : float
        Probability of dying for each timestep in the hazard. Must be in the
        range [0, 1].
    """

    def __init__(self):
        self.loc_x = 0
        self.loc_y = 0
        self.shape = Shape2dParams()
        self.collision = Collision2dParams()
        self.prob_of_death = 0
class ScoreParams:
    """Parameters for the score system to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.
    See :meth:`.SimpleUAV2d.s_score` for details on what these terms mean for
    specific score systems.

    Attributes
    ----------
    type : string
        Type of score system to use. Options are :code:`'basic'`.
    hazard_multiplier : float
        Multiplier for the base hazard term.
    death_scale : float
        Scaling factor for the death term.
    death_decay : float
        Decay factor for the death term.
    death_penalty : float
        Additional penalty for death.
    time_penalty : float
        Time penalty.
    missed_multiplier : float
        Multiplier for missing targets.
    target_multiplier : float
        Multiplier for the target term of the score.
    wall_penalty : float
        Penalty for hitting obstacles.
    vel_penalty : float
        Penalty for having extreme velocity.
    min_vel_per : float
        Minimum velocity as a percentage of the magnitude. Must be in range
        [0, 1].
    """

    def __init__(self):
        self.type = "basic"
        self.hazard_multiplier = 2
        self.death_scale = 0
        self.death_decay = 0.05
        self.death_penalty = 100
        self.time_penalty = 1
        self.missed_multiplier = 5
        self.target_multiplier = 50
        self.wall_penalty = 2
        self.vel_penalty = 1
        self.min_vel_per = 0.2
class Params(Base2dParams):
    """Main set of parameters to be parsed by the yaml parser.

    The types defined in this class determine what type the parser uses.

    Attributes
    ----------
    players : list
        Each element is a :class:`.PlayerParams` object.
    targets : list
        Each element is a :class:`.TargetParams` object.
    obstacles : list
        Each element is a :class:`.ObstacleParams` object.
    hazards : list
        Each element is a :class:`.HazardParams` object.
    score : :class:`.ScoreParams`
        Parameters for the score system.
    """

    def __init__(self):
        super().__init__()

        self.players = []
        self.targets = []
        self.obstacles = []
        self.hazards = []
        self.score = ScoreParams()
@enum.unique
class EventType(enum.Enum):
    """Enumerate the kinds of events a player can experience in the game.

    Converting a member to a string yields only its bare name (for example
    ``"HAZARD"``) rather than the default ``"EventType.HAZARD"``.
    """

    HAZARD = enum.auto()
    DEATH = enum.auto()
    TARGET = enum.auto()
    OBSTACLE = enum.auto()
    COL_PLAYER = enum.auto()

    def __str__(self):
        """Give back just the member name."""
        member_name = self.name
        return member_name
352class SimpleUAV2d(BaseGame2d):
353 """Simple 2D UAV game.
355 Assumes obstacles and hazards are static, and all players have the same
356 state and action spaces.
358 Attributes
359 ----------
360 cur_target_seq : int
361 Current index into the target_seq list.
362 target_seq : list
363 Each element is an order value from the target parameters, only has
364 unique values and is sorted in ascending order.
365 all_capabilities : list
366 Each element is a string, holds all (unique) possible values from target
367 and player capabilities.
368 has_random_player_birth_times : bool
369 Flag indicating if any player can have random birth times.
370 max_player_birth_time : float
371 Maximum time a new player can be born.
372 """
374 def __init__(self, config_file, render_mode, **kwargs):
375 """Initialize an object.
377 Parameters
378 ----------
379 config_file : string
380 Full path of the configuratin file.
381 render_mode : string
382 Mode to render the game.
383 **kwargs : dict
384 Additional arguments for the parent classes.
385 """
386 super().__init__(config_file, render_mode, **kwargs)
388 self.cur_target_seq = None
389 self.target_seq = []
390 self.all_capabilities = []
392 self.has_random_player_birth_times = False
393 self.max_player_birth_time = -np.inf
395 self._performed_reset_spawn = False
397 def register_params(self, yaml):
398 """Register custom classes for this game with the yaml parser.
400 Parameters
401 ----------
402 yaml : ruamel.yaml YAML object
403 yaml parser to use.
404 """
405 super().register_params(yaml)
406 yaml.register_class(ScoreParams)
407 yaml.register_class(BirthModelParams)
408 yaml.register_class(ControlModelParams)
409 yaml.register_class(StateConstraintParams)
410 yaml.register_class(DynamicsParams)
411 yaml.register_class(PlayerParams)
412 yaml.register_class(ObstacleParams)
413 yaml.register_class(TargetParams)
414 yaml.register_class(HazardParams)
415 yaml.register_class(ScoreParams)
416 yaml.register_class(Params)
418 def create_obstacles(self):
419 """Creates all obstacles based on config values."""
420 for params in self.params.obstacles:
421 e = self.entityManager.add_entity("obstacle")
423 e.add_component(gcomp.CTransform)
424 c_transform = e.get_component(gcomp.CTransform)
425 c_transform.pos[0] = gphysics.dist_to_pixels(
426 params.loc_x,
427 self.dist_per_pix[0],
428 min_pos=self.params.physics.min_pos[0],
429 )
430 c_transform.pos[1] = gphysics.dist_to_pixels(
431 params.loc_y,
432 self.dist_per_pix[1],
433 min_pos=self.params.physics.min_pos[1],
434 )
435 c_transform.last_pos[0] = c_transform.pos[0]
436 c_transform.last_pos[1] = c_transform.pos[1]
438 e.add_component(
439 gcomp.CShape,
440 s_type=params.shape.type,
441 w=gphysics.dist_to_pixels(params.shape.width, self.dist_per_pix[0]),
442 h=gphysics.dist_to_pixels(params.shape.height, self.dist_per_pix[1]),
443 color=params.shape.color,
444 zorder=1000,
445 fpath=params.shape.file,
446 )
448 if params.collision.height > 0 and params.collision.width > 0:
449 e.add_component(
450 gcomp.CCollision,
451 w=gphysics.dist_to_pixels(
452 params.collision.width, self.dist_per_pix[0]
453 ),
454 h=gphysics.dist_to_pixels(
455 params.collision.height, self.dist_per_pix[1]
456 ),
457 )
459 def create_hazards(self):
460 """Creates all hazards based on config values."""
461 for params in self.params.hazards:
462 e = self.entityManager.add_entity("hazard")
464 e.add_component(gcomp.CTransform)
465 c_transform = e.get_component(gcomp.CTransform)
466 c_transform.pos[0] = gphysics.dist_to_pixels(
467 params.loc_x,
468 self.dist_per_pix[0],
469 min_pos=self.params.physics.min_pos[0],
470 )
471 c_transform.pos[1] = gphysics.dist_to_pixels(
472 params.loc_y,
473 self.dist_per_pix[1],
474 min_pos=self.params.physics.min_pos[1],
475 )
476 c_transform.last_pos[0] = c_transform.pos[0]
477 c_transform.last_pos[1] = c_transform.pos[1]
479 e.add_component(
480 gcomp.CShape,
481 s_type=params.shape.type,
482 w=gphysics.dist_to_pixels(params.shape.width, self.dist_per_pix[0]),
483 h=gphysics.dist_to_pixels(params.shape.height, self.dist_per_pix[1]),
484 color=params.shape.color,
485 zorder=-100,
486 fpath=params.shape.file,
487 )
489 if params.collision.height > 0 and params.collision.width > 0:
490 e.add_component(
491 gcomp.CCollision,
492 w=gphysics.dist_to_pixels(
493 params.collision.width, self.dist_per_pix[0]
494 ),
495 h=gphysics.dist_to_pixels(
496 params.collision.height, self.dist_per_pix[1]
497 ),
498 )
500 pd = float(params.prob_of_death)
501 if pd > 1:
502 pd = pd / 100.0
503 e.add_component(gcomp.CHazard, prob_of_death=pd)
505 def create_targets(self):
506 """Creates current targets based on config values.
508 Returns
509 -------
510 bool
511 Flag indicating if targets were generated
512 """
513 if len(self.entityManager.get_entities("target")) > 0:
514 return False
516 if self.cur_target_seq is None:
517 self.cur_target_seq = 0
518 else:
519 self.cur_target_seq += 1
521 if self.cur_target_seq >= len(self.target_seq):
522 return False
524 order = self.target_seq[self.cur_target_seq]
526 for params in self.params.targets:
527 if params.order != order:
528 continue
530 e = self.entityManager.add_entity("target")
532 e.add_component(gcomp.CTransform)
533 c_transform = e.get_component(gcomp.CTransform)
534 c_transform.pos[0] = gphysics.dist_to_pixels(
535 params.loc_x,
536 self.dist_per_pix[0],
537 min_pos=self.params.physics.min_pos[0],
538 )
539 c_transform.pos[1] = gphysics.dist_to_pixels(
540 params.loc_y,
541 self.dist_per_pix[1],
542 min_pos=self.params.physics.min_pos[1],
543 )
544 c_transform.last_pos[0] = c_transform.pos[0]
545 c_transform.last_pos[1] = c_transform.pos[1]
547 e.add_component(
548 gcomp.CShape,
549 s_type=params.shape.type,
550 w=gphysics.dist_to_pixels(params.shape.width, self.dist_per_pix[0]),
551 h=gphysics.dist_to_pixels(params.shape.height, self.dist_per_pix[1]),
552 color=params.shape.color,
553 zorder=1,
554 fpath=params.shape.file,
555 )
557 if params.collision.height > 0 and params.collision.width > 0:
558 e.add_component(
559 gcomp.CCollision,
560 w=gphysics.dist_to_pixels(
561 params.collision.width, self.dist_per_pix[0]
562 ),
563 h=gphysics.dist_to_pixels(
564 params.collision.height, self.dist_per_pix[1]
565 ),
566 )
568 e.add_component(gcomp.CCapabilities, capabilities=params.capabilities)
570 e.add_component(gcomp.CPriority, priority=params.priority)
572 return True
574 def get_player_pos_vel_inds(self, params=None):
575 """Determines the position and velocity indices in the state vector.
577 Note, this assumes all players have the same state bounds when params is
578 not specified.
580 Parameters
581 ----------
582 params : :class:`.DynamicsParams`, optional
583 Player parameter structure. The default is None which uses the first
584 player in the player's list.
586 Returns
587 -------
588 pos_inds : list
589 indices for the position components (x, y).
590 vel_inds : list
591 indices for the velocity components (x, y).
592 """
593 if params is None:
594 params = self.params.players[0].dynamics
595 if params.type == "DoubleIntegrator":
596 pos_inds = [0, 1]
597 vel_inds = [2, 3]
598 elif params.type == "CoordinatedTurn":
599 pos_inds = [0, 2]
600 vel_inds = [1, 3]
602 return pos_inds, vel_inds
604 def get_player_state_bounds(self, params=None):
605 """Calculate the bounds on the player state.
607 Note, this assumes all players have the same state bounds when params is
608 not specified.
610 Parameters
611 ----------
612 params : :class:`.DynamicsParams`, optional
613 Player parameter structure. The default is None which uses the first
614 player in the player's list.
616 Returns
617 -------
618 2 x N numpy array
619 minimum and maximum values of the player state.
620 """
621 if params is None:
622 params = self.params.players[0].dynamics
624 pos_inds, vel_inds = self.get_player_pos_vel_inds(params=params)
626 if params.type == "DoubleIntegrator":
627 state_low = np.hstack(
628 (self.params.physics.min_pos, np.array([-np.inf, -np.inf]))
629 )
630 state_high = np.hstack(
631 (
632 self.params.physics.min_pos
633 + np.array(
634 [
635 self.params.physics.dist_width,
636 self.params.physics.dist_height,
637 ]
638 ),
639 np.array([np.inf, np.inf]),
640 )
641 )
642 sParams = params.stateConstraint
643 if sParams.type.lower() != "none":
644 if sParams.type.lower() == "velocity":
645 state_low[vel_inds] = sParams.min_vels
646 state_high[vel_inds] = sParams.max_vels
648 elif params.type == "CoordinatedTurn":
649 state_low = np.hstack(
650 (
651 self.params.physics.min_pos[0],
652 np.array([-np.inf]),
653 self.params.physics.min_pos[1],
654 np.array([-np.inf, -2 * np.pi]),
655 )
656 )
657 state_high = np.hstack(
658 (
659 self.params.physics.min_pos[0] + self.params.physics.dist_width,
660 np.array([np.inf]),
661 self.params.physics.min_pos[1] + self.params.physics.dist_height,
662 np.array([np.inf, 2 * np.pi]),
663 )
664 )
665 if sParams.type.lower() != "none":
666 if sParams.type.lower() == "velocity":
667 state_low[vel_inds] = sParams.min_vel * np.ones(len(vel_inds))
668 state_high[vel_inds] = sParams.min_vel * np.ones(len(vel_inds))
670 else:
671 raise RuntimeError("Invalid dynamics type: {}".format(sParams.type))
673 return state_low, state_high
675 def get_players_state(self):
676 """Returns current dynamic state of all players.
678 Returns
679 -------
680 states : dict
681 Each key is the entity id and each value is a numpy array.
682 """
683 states = {}
684 for p in self.entityManager.get_entities("player"):
685 pTrans = p.get_component(gcomp.CDynamics)
686 states[p.id] = pTrans.state.copy().ravel()
688 return states
    def create_dynamics(self, params, cBirth):
        """Create the dynamics system for a player.

        Builds the dynamics object named by ``params.type`` (looked up as an
        attribute of :mod:`gncpy.dynamics.basic`), its control model and
        optional state constraint, the state bounds, and an initial state drawn
        from the birth component.

        Parameters
        ----------
        params : :class:`.DynamicsParams`
            Parameters to use when making the dynamics.
        cBirth : :class:`gncpy.game_engine.components.CBirth`
            Birth component associated with this player.

        Raises
        ------
        AttributeError
            ``params.type`` is not a name in :mod:`gncpy.dynamics.basic`.
        RuntimeError
            Incorrect parameters are set, e.g. the birth sample matches neither
            the position size nor the full state size. The missing max velocity
            error of the double integrator control model is raised lazily, when
            that control model is evaluated.
        NotImplementedError
            Wrong dynamics, control model and/or state constraint combo specified.

        Returns
        -------
        dynObj : :class:`gncpy.dynamics.basic.DynamicsBase`
            Dynamic object created
        pos_inds : list
            indices of the position variables in the state
        vel_inds : list
            indices of the velocity variables in the state
        state_args : tuple
            additional arguments for the dynObj propagate function
        ctrl_args : tuple
            additional arguments for the dynObj propagate function
        state_low : numpy array
            lower state bounds
        state_high : numpy array
            upper state bounds
        state0 : numpy array
            initial state of the dynObj, as an (N, 1) column
        """
        cls_type = getattr(gdyn, params.type)

        pos_inds, vel_inds = self.get_player_pos_vel_inds(params=params)
        kwargs = {}
        if params.type == "DoubleIntegrator":
            # the double integrator propagation receives the physics update step
            state_args = (self.params.physics.update_dt,)

            cParams = params.controlModel
            if cParams.type.lower() == "velocity":
                ctrl_args = ()

                def _ctrl_mod(t, x, *args):
                    # input matrix: zero rows for the position states (0-1) and a
                    # diagonal max-velocity scaling for the velocity states (2-3);
                    # per-axis limits take precedence over the shared max_vel
                    if cParams.max_vel_x and cParams.max_vel_y:
                        mat = np.diag(
                            (float(cParams.max_vel_x), float(cParams.max_vel_y))
                        )
                    elif cParams.max_vel:
                        mat = cParams.max_vel * np.eye(2)
                    else:
                        raise RuntimeError(
                            "Must set max_vel or max_vel_x and max_vel_y in control model."
                        )
                    return np.vstack((np.zeros((2, 2)), mat))

            else:
                msg = "Control model type {} not implemented for dynamics {}".format(
                    cParams.type, params.type
                )
                raise NotImplementedError(msg)
            kwargs["control_model"] = _ctrl_mod

            sParams = params.stateConstraint
            if sParams.type.lower() != "none":
                if sParams.type.lower() == "velocity":

                    def _state_constraint(t, x):
                        # clip each velocity component to [min_vels, max_vels];
                        # x is written back as an (n, 1) column slice
                        x[vel_inds] = np.min(
                            np.vstack((x[vel_inds].ravel(), sParams.max_vels)), axis=0,
                        ).reshape((len(vel_inds), 1))
                        x[vel_inds] = np.max(
                            np.vstack((x[vel_inds].ravel(), sParams.min_vels)), axis=0,
                        ).reshape((len(vel_inds), 1))
                        return x

                else:
                    msg = "State constraint type {} not implemented for dynamics {}".format(
                        sParams.type, params.type
                    )
                    raise NotImplementedError(msg)
                kwargs["state_constraint"] = _state_constraint

        elif params.type == "CoordinatedTurn":
            state_args = ()

            cParams = params.controlModel
            if cParams.type.lower() == "velocity_turn":
                ctrl_args = ()

                # one function per state element (list order below = state
                # index): u[0] scales the speed projected onto the x/y velocity
                # states (1, 3) through the angle x[4], and u[1] scales the turn
                # rate, which is converted from degrees to radians
                def _g1(t, x, u, *args):
                    return cParams.max_vel * np.cos(x[4].item()) * u[0].item()

                def _g0(t, x, u, *args):
                    return 0

                def _g3(t, x, u, *args):
                    return cParams.max_vel * np.sin(x[4].item()) * u[0].item()

                def _g2(t, x, u, *args):
                    return 0

                def _g4(t, x, u, *args):
                    return cParams.max_turn_rate * np.pi / 180 * u[1].item()

            else:
                msg = "Control model type {} not implemented for dynamics {}".format(
                    cParams.type, params.type
                )
                raise NotImplementedError(msg)
            kwargs["control_model"] = [_g0, _g1, _g2, _g3, _g4]

            sParams = params.stateConstraint
            if sParams.type.lower() != "none":
                if sParams.type.lower() == "velocity":

                    def _state_constraint(t, x):
                        # clip each velocity component to [min_vels, max_vels]
                        x[vel_inds] = np.min(
                            np.vstack((x[vel_inds].ravel(), sParams.max_vels)), axis=0,
                        ).reshape((-1, 1))
                        x[vel_inds] = np.max(
                            np.vstack((x[vel_inds].ravel(), sParams.min_vels)), axis=0,
                        ).reshape((-1, 1))
                        # wrap the angle while keeping its sign: negative values
                        # land in (-2*pi, 0], others in [0, 2*pi)
                        if x[4] < 0:
                            x[4] = np.mod(x[4], -2 * np.pi)
                        else:
                            x[4] = np.mod(x[4], 2 * np.pi)

                        return x

                else:
                    msg = "State constraint type {} not implemented for dynamics {}".format(
                        sParams.type, params.type
                    )
                    raise NotImplementedError(msg)
                kwargs["state_constraint"] = _state_constraint

        # user supplied extra parameters override anything set above
        kwargs.update(params.extra_params)

        state_low, state_high = self.get_player_state_bounds(params=params)

        dynObj = cls_type(**kwargs)
        state0 = np.zeros((state_low.size, 1))
        val = cBirth.sample()
        if val.size == len(pos_inds):
            # birth sample is only a position; all other states start at zero
            state0[pos_inds] = val.reshape(state0[pos_inds].shape)

            # randomized births also get a uniformly random initial angle
            if cBirth.randomize and params.type == "CoordinatedTurn":
                state0[4] = self.rng.random() * 2 * np.pi

        elif val.size == state0.size:
            state0 = val.reshape(state0.shape)
        else:
            raise RuntimeError("Birth location must match position size or full state.")

        return (
            dynObj,
            pos_inds,
            vel_inds,
            state_args,
            ctrl_args,
            state_low,
            state_high,
            state0,
        )
860 def create_player(self, params):
861 """Creates a player entity.
863 Parameters
864 ----------
865 params : :class:`gncpy.games.SimpleUAV2d.PlayerParams`
866 Parameters for the player being created.
868 Returns
869 -------
870 p : :class:`gncpy.game_engine.entities.Entity`
871 Reference to the player entity that was created.
872 """
873 # check if using random birth time
874 if params.birth.times.size == 0:
875 req_spawn = self.rng.uniform(0.0, 1.0) < params.birth.prob
876 else:
877 diff = self.current_time - np.sort(params.birth.times)
878 inds = np.where(diff >= -1e-8)[0]
879 if inds.size == 0:
880 return None
881 min_diff = diff[inds[-1]]
882 # birth times don't have to align with updates
883 req_spawn = min_diff < self.params.physics.update_dt - 1e-8
885 if not req_spawn:
886 return None
888 e = self.entityManager.add_entity("player")
890 e.add_component(
891 gcomp.CBirth,
892 b_type=params.birth.type,
893 loc=params.birth.location,
894 scale=params.birth.scale,
895 params=params.birth.extra_params,
896 rng=self.rng,
897 randomize=params.birth.randomize,
898 )
900 e.add_component(gcomp.CDynamics)
901 cDyn = e.get_component(gcomp.CDynamics)
902 (
903 cDyn.dynObj,
904 cDyn.pos_inds,
905 cDyn.vel_inds,
906 cDyn.state_args,
907 cDyn.ctrl_args,
908 cDyn.state_low,
909 cDyn.state_high,
910 cDyn.state,
911 ) = self.create_dynamics(params.dynamics, e.get_component(gcomp.CBirth))
913 e.add_component(gcomp.CTransform)
914 cTrans = e.get_component(gcomp.CTransform)
915 p_ii = cDyn.pos_inds
916 v_ii = cDyn.vel_inds
917 cTrans.pos = gphysics.dist_to_pixels(
918 cDyn.state[p_ii], self.dist_per_pix, min_pos=self.params.physics.min_pos
919 )
920 if v_ii is not None:
921 cTrans.vel = gphysics.dist_to_pixels(cDyn.state[v_ii], self.dist_per_pix)
923 e.add_component(gcomp.CEvents)
925 e.add_component(
926 gcomp.CShape,
927 s_type=params.shape.type,
928 w=gphysics.dist_to_pixels(params.shape.width, self.dist_per_pix[0]),
929 h=gphysics.dist_to_pixels(params.shape.height, self.dist_per_pix[1]),
930 color=tuple(params.shape.color),
931 zorder=100,
932 fpath=params.shape.file,
933 )
935 e.add_component(
936 gcomp.CCollision,
937 w=gphysics.dist_to_pixels(params.collision.width, self.dist_per_pix[0]),
938 h=gphysics.dist_to_pixels(params.collision.height, self.dist_per_pix[1]),
939 )
941 e.add_component(gcomp.CCapabilities, capabilities=params.capabilities)
943 return e
945 def spawn_players(self):
946 """Spawns a new player if needed."""
947 for params in self.params.players:
948 self.create_player(params)
950 def propagate_dynamics(self, eDyn, action):
951 """Propagates the dynamics with the given action.
953 Parameters
954 ----------
955 eDyn : :class:`gncpy.game_engine.components.CDynamics`
956 dynamics component to propagate.
957 action : numpy array
958 Control input for the dynamics object.
959 """
960 eDyn.state = eDyn.dynObj.propagate_state(
961 self.current_time,
962 eDyn.last_state.reshape((-1, 1)),
963 u=action.reshape((-1, 1)),
964 state_args=eDyn.state_args,
965 ctrl_args=eDyn.ctrl_args,
966 ).reshape((-1, 1))
968 def get_player_ids(self):
969 """Get the entity ids of all players.
971 Returns
972 -------
973 list
974 all entity ids of the players.
975 """
976 return self.entityManager.get_entity_ids(tag="player")
978 def reset(self, **kwargs):
979 """Resets the game to the base state.
981 Parameters
982 ----------
983 kwargs : dict
984 Additional arguments for the parent classes.
985 """
986 super().reset(**kwargs)
988 # find list of all possible order values, and all capabilities
989 self.cur_target_seq = None
990 self.target_seq = []
991 self.all_capabilities = []
992 for t in self.params.targets:
993 if t.order not in self.target_seq:
994 self.target_seq.append(t.order)
996 for c in t.capabilities:
997 if c not in self.all_capabilities:
998 self.all_capabilities.append(c)
1000 self.target_seq.sort()
1002 # make sure all players either have birth times or birth probability and update all capabilities
1003 self.has_random_player_birth_times = False
1004 self.max_player_birth_time = -np.inf
1005 for ii, p in enumerate(self.params.players):
1006 if p.birth.times.size == 0 and p.birth.prob <= 0:
1007 raise RuntimeError("Player {} has invalid birth settings.".format(ii))
1008 self.has_random_player_birth_times = (
1009 self.has_random_player_birth_times or p.birth.prob > 0
1010 )
1011 if np.max(p.birth.times) > self.max_player_birth_time:
1012 self.max_player_birth_time = np.max(p.birth.times)
1014 for c in p.capabilities:
1015 if c not in self.all_capabilities:
1016 self.all_capabilities.append(c)
1018 self.create_obstacles()
1019 self.create_hazards()
1020 self.create_targets()
1021 self.spawn_players()
1022 self._performed_reset_spawn = True
1024 self.entityManager.update()
    def s_collision(self):
        """Check for collisions between entities.

        This also handles player death if a hazard destroys a player, and
        updates the events.

        For each player, this method does the following, in order:

        - keeps the player inside the window bounds;
        - resolves overlaps with obstacles and other players;
        - rolls for death inside hazards;
        - records a hit on at most one target, destroying that target;
        - writes the resolved pixel position/velocity back into the player's
          dynamics state.

        Every interaction is appended to the player's events list as an
        ``(EventType, info)`` tuple.

        Returns
        -------
        bool
            Flag indicating if a target was hit.
        """
        hit_target = False

        # update all bounding boxes
        for e in self.entityManager.get_entities():
            if e.has_component(gcomp.CTransform) and e.has_component(gcomp.CCollision):
                c_collision = e.get_component(gcomp.CCollision)
                c_transform = e.get_component(gcomp.CTransform)
                c_collision.aabb.centerx = c_transform.pos[0].item()
                c_collision.aabb.centery = c_transform.pos[1].item()

        # check for collision of player
        for e in self.entityManager.get_entities("player"):
            p_aabb = e.get_component(gcomp.CCollision).aabb
            p_trans = e.get_component(gcomp.CTransform)
            p_events = e.get_component(gcomp.CEvents)

            # check for out of bounds, stop at out of bounds
            out_side, out_top = gphysics.clamp_window_bounds2d(
                p_aabb, p_trans, self.window.get_width(), self.window.get_height()
            )
            # leaving the window is reported as an obstacle hit, one event per
            # flag returned by the clamp
            if out_side:
                p_events.events.append((EventType.OBSTACLE, None))
            if out_top:
                p_events.events.append((EventType.OBSTACLE, None))

            # check for collision with obstacle
            for w in self.entityManager.get_entities("obstacle"):
                if not w.has_component(gcomp.CCollision):
                    continue
                w_aabb = w.get_component(gcomp.CCollision).aabb
                if gphysics.check_collision2d(p_aabb, w_aabb):
                    gphysics.resolve_collision2d(
                        p_aabb, w_aabb, p_trans, w.get_component(gcomp.CTransform)
                    )
                    p_events.events.append((EventType.OBSTACLE, None))

            # check for collision with other players
            for otherP in self.entityManager.get_entities("player"):
                if otherP.id == e.id:
                    continue
                fixedAAABB = otherP.get_component(gcomp.CCollision).aabb
                if gphysics.check_collision2d(p_aabb, fixedAAABB):
                    gphysics.resolve_collision2d(
                        p_aabb,
                        fixedAAABB,
                        p_trans,
                        otherP.get_component(gcomp.CTransform),
                    )
                    # with the 'basic' score type player-player collisions are
                    # reported as obstacle hits
                    if self.params.score.type.lower() == "basic":
                        p_events.events.append((EventType.OBSTACLE, None))
                    else:
                        p_events.events.append((EventType.COL_PLAYER, None))

            # check for collision with hazard
            # NOTE(review): this loop does not break after a death, so later
            # hazards can still roll and append events for a destroyed player
            for h in self.entityManager.get_entities("hazard"):
                if not h.has_component(gcomp.CCollision):
                    continue
                h_aabb = h.get_component(gcomp.CCollision).aabb
                c_hazard = h.get_component(gcomp.CHazard)
                if gphysics.check_collision2d(p_aabb, h_aabb):
                    # independent death roll for every call spent in the hazard
                    if self.rng.uniform(0.0, 1.0) < c_hazard.prob_of_death:
                        e.destroy()
                        p_events.events.append((EventType.DEATH, None))
                        if e.id in c_hazard.entrance_times:
                            del c_hazard.entrance_times[e.id]

                    else:
                        # remember when the player first entered this hazard
                        if e.id not in c_hazard.entrance_times:
                            c_hazard.entrance_times[e.id] = self.current_time
                        p_events.events.append(
                            (
                                EventType.HAZARD,
                                {
                                    "prob": c_hazard.prob_of_death,
                                    "t_ent": c_hazard.entrance_times[e.id],
                                },
                            )
                        )
                else:
                    # player is outside this hazard: forget any entrance time
                    if e.id in c_hazard.entrance_times:
                        del c_hazard.entrance_times[e.id]

            # NOTE(review): assumes e.destroy() clears e.active, so a player
            # killed above skips the target check and state write-back
            if not e.active:
                continue

            # check for collision with target
            for t in self.entityManager.get_entities("target"):
                if not t.active:
                    continue
                if not t.has_component(gcomp.CCollision):
                    continue

                if gphysics.check_collision2d(
                    p_aabb, t.get_component(gcomp.CCollision).aabb
                ):
                    hit_target = True
                    p_events.events.append((EventType.TARGET, {"target": t}))
                    t.destroy()
                    # a player collects at most one target per call
                    break

            # update state
            # copy the collision-resolved pixel position/velocity back into the
            # dynamics state via pixels_to_dist
            p_dynamics = e.get_component(gcomp.CDynamics)
            p_ii = p_dynamics.pos_inds
            v_ii = p_dynamics.vel_inds

            p_dynamics.state[p_ii] = gphysics.pixels_to_dist(
                p_trans.pos, self.dist_per_pix, min_pos=self.params.physics.min_pos
            )
            if v_ii is not None:
                p_dynamics.state[v_ii] = gphysics.pixels_to_dist(
                    p_trans.vel, self.dist_per_pix
                )

        return hit_target
def s_input(self, user_input):
    """Validate user input.

    Actions whose key is not the id of an entity currently known to the
    entity manager are dropped; every kept action is reshaped into a column
    vector.

    Parameters
    ----------
    user_input : dict
        Each key is an entity id. Each value is a numpy array representing
        the action for that entity.

    Returns
    -------
    out : dict
        Each key is an entity id. Each value is the action reshaped to a
        column vector of shape ``(-1, 1)``.
    """
    known_ids = self.entityManager.get_entity_ids()
    return {
        ent_id: act.reshape((-1, 1))
        for ent_id, act in user_input.items()
        if ent_id in known_ids
    }
def s_game_over(self):
    """Determine whether the game has met an end condition.

    The result is stored in ``self.game_over``. The game ends when any of
    the following holds:

    * the current time has reached ``params.max_time``;
    * the target sequence index has run past the end of ``target_seq``;
    * no players remain, there are no random player birth times, and the
      current time is past the last scheduled player birth time.
    """
    player_count = len(self.entityManager.get_entities("player"))

    # Players can only be "permanently" gone when no more births can occur.
    players_exhausted = (
        not self.has_random_player_birth_times
        and player_count == 0
        and self.current_time > self.max_player_birth_time
    )

    time_is_up = self.current_time >= self.params.max_time
    self.game_over = (
        time_is_up
        or self.cur_target_seq >= len(self.target_seq)
        or players_exhausted
    )
def s_movement(self, action):
    """Move entities according to their dynamics.

    For every entity with a ``CTransform`` the current pixel position is
    copied into ``last_pos``. Every entity that has a ``CDynamics`` component
    and an entry in ``action`` then has ``last_state`` set to a copy of its
    current state and is propagated with ``self.propagate_dynamics``. If that
    entity also has a transform, its pixel position (and pixel velocity, when
    the dynamics define velocity indices) is refreshed from the new state.

    Parameters
    ----------
    action : dict
        Each key is an entity id and its value is a 2 x 1 numpy array
        corresponding to control inputs for the given dynamics model. This
        comes from the input system.
    """
    for e in self.entityManager.get_entities():
        # Look the transform up per entity so a transform left over from a
        # previous iteration is never modified (and an entity without a
        # transform does not hit an unbound name).
        e_trans = None
        if e.has_component(gcomp.CTransform):
            e_trans = e.get_component(gcomp.CTransform)
            e_trans.last_pos[0] = e_trans.pos[0]
            e_trans.last_pos[1] = e_trans.pos[1]

        if not e.has_component(gcomp.CDynamics) or e.id not in action:
            continue

        e_dyn = e.get_component(gcomp.CDynamics)
        e_dyn.last_state = e_dyn.state.copy()
        self.propagate_dynamics(e_dyn, action[e.id])

        # Without a transform there is no pixel position to synchronise.
        if e_trans is None:
            continue

        p_ii = e_dyn.pos_inds
        v_ii = e_dyn.vel_inds
        e_trans.pos = gphysics.dist_to_pixels(
            e_dyn.state[p_ii],
            self.dist_per_pix,
            min_pos=self.params.physics.min_pos,
        )
        if v_ii is not None:
            e_trans.vel = gphysics.dist_to_pixels(
                e_dyn.state[v_ii], self.dist_per_pix
            )
def basic_reward(self):
    """Calculate the reward for a timestep for the basic reward type.

    Each current player contributes a sum of per-player terms built from its
    queued events and its velocity:

    * hazard: ``-hazard_multiplier * (prob * 100) * (t - t_ent)`` per
      hazard event;
    * target: ``target_multiplier * priority * capability match`` for a
      target event;
    * wall: ``-wall_penalty`` per obstacle event;
    * velocity: ``-vel_penalty`` when the speed, as a fraction of the larger
      of the norms of the velocity bounds, is below ``min_vel_per``;
    * death: a decaying death penalty. A death event also zeroes this
      player's other terms and stops processing its remaining events.

    A fixed ``-time_penalty`` is always added. When the game is over, a
    missed-target penalty is added as well. It is based on the priorities of
    targets later in the sequence plus still-active targets, scaled by
    ``missed_multiplier``.

    Returns
    -------
    reward : float
        reward for the timestep.
    info : dict
        Per-term totals useful for debugging, with keys ``"hazard"``,
        ``"target"``, ``"death"``, ``"wall"``, ``"missed"`` and
        ``"velocity"``.
    """

    def _match_function(test_cap, req_cap):
        # Count of entries of test_cap found in req_cap, divided by the
        # number of required capabilities; no requirements is a full match.
        if len(req_cap) > 0:
            return sum(1 for c in test_cap if c in req_cap) / len(req_cap)
        return 1

    score = self.params.score
    t = self.current_time

    reward = 0

    # accumulate rewards from all players
    r_haz_cumul = 0
    r_tar_cumul = 0.0
    r_death_cumul = 0
    r_wall_cumul = 0
    r_vel_cumul = 0
    for player in self.entityManager.get_entities("player"):
        # Every term is reset per player, including the velocity term, so one
        # player's penalty is never re-counted for the players after it.
        r_hazard = 0
        r_target = 0
        r_death = 0
        r_wall = 0
        r_vel = 0

        p_dynamics = player.get_component(gcomp.CDynamics)
        p_events = player.get_component(gcomp.CEvents)
        p_capabilities = player.get_component(gcomp.CCapabilities)

        if p_dynamics.vel_inds is not None and len(p_dynamics.vel_inds) > 0:
            max_vel = np.linalg.norm(p_dynamics.state_high[p_dynamics.vel_inds])
            min_vel = np.linalg.norm(p_dynamics.state_low[p_dynamics.vel_inds])
            vel = np.linalg.norm(p_dynamics.state[p_dynamics.vel_inds])

            vel_limit = np.max((max_vel, min_vel))
            # A zero limit would yield nan/inf (never below the threshold),
            # so skipping it keeps the same outcome without a divide warning.
            if vel_limit > 0 and vel / vel_limit < score.min_vel_per:
                r_vel -= score.vel_penalty

        for e_type, info in p_events.events:
            if e_type == EventType.HAZARD:
                r_hazard -= (
                    score.hazard_multiplier
                    * (info["prob"] * 100)
                    * (t - info["t_ent"])
                )

            elif e_type == EventType.DEATH:
                time_decay = score.death_scale * np.exp(-score.death_decay * t)
                r_death = -(
                    time_decay
                    * _match_function(
                        p_capabilities.capabilities, self.all_capabilities
                    )
                    + score.death_penalty
                )
                # a dead player only receives the death term this step
                r_hazard = 0
                r_target = 0
                r_wall = 0
                r_vel = 0
                break

            elif e_type == EventType.TARGET:
                target = info["target"]
                t_capabilities = target.get_component(gcomp.CCapabilities)
                t_priority = target.get_component(gcomp.CPriority)
                match_per = _match_function(
                    p_capabilities.capabilities, t_capabilities.capabilities
                )
                r_target = score.target_multiplier * t_priority.priority * match_per

            elif e_type == EventType.OBSTACLE:
                r_wall -= score.wall_penalty

        r_haz_cumul += r_hazard
        r_tar_cumul += r_target
        r_death_cumul += r_death
        r_wall_cumul += r_wall
        r_vel_cumul += r_vel
        reward += r_hazard + r_target + r_death + r_wall + r_vel

    # add fixed terms to reward
    r_missed = 0
    if self.game_over:
        # Both groups below are targets that were never reached, so both are
        # counted with the same (negative) sign before scaling.
        if self.cur_target_seq < len(self.target_seq):
            cur_order = self.target_seq[self.cur_target_seq]
            # targets later in the sequence
            for target in self.params.targets:
                if target.order > cur_order:
                    r_missed -= target.priority

        # remaining targets at the current point in the sequence
        for target in self.entityManager.get_entities("target"):
            if target.active:
                r_missed -= target.get_component(gcomp.CPriority).priority

        r_missed *= score.missed_multiplier

    reward += -score.time_penalty + r_missed

    info = {
        "hazard": r_haz_cumul,
        "target": r_tar_cumul,
        "death": r_death_cumul,
        "wall": r_wall_cumul,
        "missed": r_missed,
        "velocity": r_vel_cumul,
    }

    return reward, info
def s_score(self):
    """Determine the total score for the timestep.

    Dispatches on ``self.params.score.type`` (case-insensitive). Only the
    ``"basic"`` type is implemented, and it delegates to
    ``self.basic_reward``.

    Raises
    ------
    NotImplementedError
        An unsupported score type is specified.

    Returns
    -------
    score : float
        total score for this timestep.
    info : dict
        extra info from the score system
    """
    score_type = self.params.score.type
    if score_type.lower() != "basic":
        raise NotImplementedError(
            "Score system has no implementation for reward type {}".format(
                score_type
            )
        )
    return self.basic_reward()
def step(self, user_input):
    """Perform one iteration of the game loop.

    Multiple physics updates can be made between rendering calls. Before
    delegating to the base class step, this spawns players (unless the
    preceding reset already did so) and creates targets.

    Parameters
    ----------
    user_input : dict
        Each key is an integer representing the entity id that the action
        applies to. Each value is a numpy array for the action to take.

    Returns
    -------
    info : dict
        Extra information for debugging.
    """
    # reset handles the spawn, so the first step after a reset must not spawn
    # again; the flag is cleared afterwards in either case.
    spawned_by_reset = self._performed_reset_spawn
    if not spawned_by_reset:
        self.spawn_players()
    self._performed_reset_spawn = False

    self.create_targets()
    return super().step(user_input)