Skip to content
Snippets Groups Projects
Commit 359633c4 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

fix enum handling

parent ca36e40e
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List
......
......@@ -94,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if not TrainState.off_map_state(_agent.state) and \
if not _agent.state.is_off_map_state() and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
......@@ -103,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder):
'malfunction']
# [NIMISH] WHAT IS THIS
if TrainState.off_map_state(_agent.state) and \
if _agent.state.is_off_map_state() and \
_agent.initial_position:
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
......@@ -570,9 +570,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if TrainState.off_map_state(agent.state):
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif TrainState.on_map_state(agent.state):
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
......@@ -608,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if TrainState.off_map_state(other_agent.state):
if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -3,7 +3,6 @@ Definition of the RailEnv environment.
"""
import random
# TODO: _ this is a global method --> utils or remove later
from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
......@@ -285,7 +284,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return agent.state == TrainState.READY_TO_DEPART or \
(TrainState.on_map_state(agent.state) and \
(agent.state.is_on_map_state() and \
fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
......@@ -406,7 +405,7 @@ class RailEnv(Environment):
return observation_dict, info_dict
def apply_action_independent(self, action, rail, position, direction):
if RailEnvActions.is_moving_action(action):
if action.is_moving_action():
new_direction, _ = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
......@@ -420,7 +419,7 @@ class RailEnv(Environment):
st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action)
st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action()
st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
......@@ -557,10 +556,6 @@ class RailEnv(Environment):
else:
final_new_position = agent.position
final_new_direction = agent.direction
# if final_new_position and self.rail.grid[final_new_position] == 0:
# import pdb; pdb.set_trace()
# if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this
# import pdb; pdb.set_trace()
agent.position = final_new_position
agent.direction = final_new_direction
......
......@@ -19,9 +19,12 @@ class RailEnvActions(IntEnum):
4: 'S',
}[a]
@staticmethod
def is_moving_action(action):
return action in [1,2,3]
@classmethod
def check_valid_action(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
......
......@@ -227,9 +227,9 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
shortest_paths = dict()
def _shortest_path_for_agent(agent):
if TrainState.off_map_state(agent.state):
if agent.state.is_off_map_state():
position = agent.initial_position
elif TrainState.on_map_state(agent.state):
elif agent.state.is_on_map_state():
position = agent.position
elif agent.state == TrainState.DONE:
position = agent.target
......
......@@ -5,11 +5,10 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
# TODO - Dipam : This check is kind of weird, change this
if action is None or action not in RailEnvActions._value2member_map_:
if not RailEnvActions.check_valid_action(action):
return RailEnvActions.DO_NOTHING
else:
return action
return RailEnvActions(action)
def process_do_nothing(state: TrainState):
......
......@@ -15,8 +15,8 @@ class ActionSaver:
def save_action_if_allowed(self, action, state):
if not self.is_action_saved and \
RailEnvActions.is_moving_action(action) and \
not TrainState.is_malfunction_state(state):
action.is_moving_action() and \
not state.is_malfunction_state():
self.saved_action = action
def clear_saved_action(self):
......
......@@ -13,17 +13,14 @@ class TrainState(IntEnum):
def check_valid_state(cls, state):
return state in cls._value2member_map_
@staticmethod
def is_malfunction_state(state):
return state in [2, 5] # TODO: Can this be done with names instead?
def is_malfunction_state(self):
return self.value in [self.MALFUNCTION, self.MALFUNCTION_OFF_MAP]
@staticmethod
def off_map_state(state):
return state in [0, 1, 2]
def is_off_map_state(self):
return self.value in [self.WAITING, self.READY_TO_DEPART, self.MALFUNCTION_OFF_MAP]
@staticmethod
def on_map_state(state):
return state in [3, 4, 5]
def is_on_map_state(self):
return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
......
......@@ -743,7 +743,7 @@ class RenderLocal(RenderBase):
if show_inactive_agents:
show_this_agent = True
else:
show_this_agent = TrainState.on_map_state(agent.state)
show_this_agent = agent.state.is_on_map_state()
if show_this_agent:
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment