diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 9230659e500c0b292e3525eceb9280f77878c1a8..91c6d72faee99df7b2a6242de246f55bd5953f17 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,8 +1,6 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint import numpy as np -from enum import IntEnum - from itertools import starmap from typing import Tuple, Optional, NamedTuple, List diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a140da06c29916d1f05b01cb2df0bb437f27a5d6..e859173469fa60f81e86c0f6fad74d82905e21ad 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -94,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_ready_to_depart = {} for _agent in self.env.agents: - if not TrainState.off_map_state(_agent.state) and \ + if not _agent.state.is_off_map_state() and \ _agent.position: self.location_has_agent[tuple(_agent.position)] = 1 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction @@ -103,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder): 'malfunction'] # [NIMISH] WHAT IS THIS - if TrainState.off_map_state(_agent.state) and \ + if _agent.state.is_off_map_state() and \ _agent.initial_position: self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0) self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1 @@ -570,9 +570,9 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): agent = self.env.agents[handle] - if TrainState.off_map_state(agent.state): + if agent.state.is_off_map_state(): agent_virtual_position = agent.initial_position - elif TrainState.on_map_state(agent.state): + elif agent.state.is_on_map_state(): agent_virtual_position = agent.position elif agent.state == TrainState.DONE: agent_virtual_position = agent.target @@ -608,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder): obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] # fifth channel: all ready to depart on this position - if TrainState.off_map_state(other_agent.state): + if other_agent.state.is_off_map_state(): obs_agents_state[other_agent.initial_position][4] += 1 return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 68bb0de2677f64fa8ee2763455c3d5aeef33d203..2f19bddd8aeda0e8fe8bdc886fcf21db646a22c7 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -3,7 +3,6 @@ Definition of the RailEnv environment. """ import random # TODO: _ this is a global method --> utils or remove later -from enum import IntEnum from typing import List, NamedTuple, Optional, Dict, Tuple import numpy as np @@ -285,7 +284,7 @@ class RailEnv(Environment): False: Agent cannot provide an action """ return agent.state == TrainState.READY_TO_DEPART or \ - (TrainState.on_map_state(agent.state) and \ + (agent.state.is_on_map_state() and \ fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) ) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, @@ -406,7 +405,7 @@ class RailEnv(Environment): return observation_dict, info_dict def apply_action_independent(self, action, rail, position, direction): - if RailEnvActions.is_moving_action(action): + if action.is_moving_action(): new_direction, _ = check_action(action, position, direction, rail) new_position = get_new_position(position, new_direction) else: @@ -420,7 +419,7 @@ class RailEnv(Environment): st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING) - st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action) + st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action() st_signals['target_reached'] = fast_position_equal(agent.position, agent.target) st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information @@ -557,10 +556,6 @@ class RailEnv(Environment): else: final_new_position = agent.position final_new_direction = agent.direction - # if final_new_position and self.rail.grid[final_new_position] == 0: - # import pdb; pdb.set_trace() - # if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this - # import pdb; pdb.set_trace() agent.position = final_new_position agent.direction = final_new_direction diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py index a25cc8f0f37233f76b921ffc62c83818e8e7bb9b..f583eb72568f375e3a066cbf355771bb9b60e038 100644 --- a/flatland/envs/rail_env_action.py +++ b/flatland/envs/rail_env_action.py @@ -19,9 +19,12 @@ class RailEnvActions(IntEnum): 4: 'S', }[a] - @staticmethod - def is_moving_action(action): - return action in [1,2,3] + @classmethod + def check_valid_action(cls, action): + return action in cls._value2member_map_ + + def is_moving_action(self): + return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD] RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index 7d0632823e46432fb9d13be4af9af06d3bf6f873..e844390f7d4927476525da45196db28893145f7a 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -227,9 +227,9 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non shortest_paths = dict() def _shortest_path_for_agent(agent): - if TrainState.off_map_state(agent.state): + if agent.state.is_off_map_state(): position = agent.initial_position - elif TrainState.on_map_state(agent.state): + elif agent.state.is_on_map_state(): position = agent.position elif agent.state == TrainState.DONE: position = agent.target diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py index c777342d6d9cf20167d7f2a9df10a9a2a3a7922a..98c42b15961b515e8c07eb89b521453a581e82f2 100644 --- a/flatland/envs/step_utils/action_preprocessing.py +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -5,11 +5,10 @@ from flatland.envs.step_utils.transition_utils import check_valid_action def process_illegal_action(action: RailEnvActions): - # TODO - Dipam : This check is kind of weird, change this - if action is None or action not in RailEnvActions._value2member_map_: + if not RailEnvActions.check_valid_action(action): return RailEnvActions.DO_NOTHING else: - return action + return RailEnvActions(action) def process_do_nothing(state: TrainState): diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py index 56f7465af77de4a88ce6d010593bca92c8280759..a34778ed48a90a8dede28ba83181877c247deb96 100644 --- a/flatland/envs/step_utils/action_saver.py +++ b/flatland/envs/step_utils/action_saver.py @@ -15,8 +15,8 @@ class ActionSaver: def save_action_if_allowed(self, action, state): if not self.is_action_saved and \ - RailEnvActions.is_moving_action(action) and \ - not TrainState.is_malfunction_state(state): + action.is_moving_action() and \ + not state.is_malfunction_state(): self.saved_action = action def clear_saved_action(self): diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py index b8040939419e27987a06636c205e2f2cfef45166..4e612ae842fd9da2e206f010a72bca2e0b9f5608 100644 --- a/flatland/envs/step_utils/states.py +++ b/flatland/envs/step_utils/states.py @@ -13,17 +13,14 @@ class TrainState(IntEnum): def check_valid_state(cls, state): return state in cls._value2member_map_ - @staticmethod - def is_malfunction_state(state): - return state in [2, 5] # TODO: Can this be done with names instead? + def is_malfunction_state(self): + return self.value in [self.MALFUNCTION, self.MALFUNCTION_OFF_MAP] - @staticmethod - def off_map_state(state): - return state in [0, 1, 2] + def is_off_map_state(self): + return self.value in [self.WAITING, self.READY_TO_DEPART, self.MALFUNCTION_OFF_MAP] - @staticmethod - def on_map_state(state): - return state in [3, 4, 5] + def is_on_map_state(self): + return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION] diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 9499c5c4542bcb7c16231c2fd88a8b445ff3bce0..cd765cd19ba0c9510d301ac77a0782bccd6bd6b4 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -743,7 +743,7 @@ class RenderLocal(RenderBase): if show_inactive_agents: show_this_agent = True else: - show_this_agent = TrainState.on_map_state(agent.state) + show_this_agent = agent.state.is_on_map_state() if show_this_agent: self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,