Commit 359633c4 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

fix enum handling

parent ca36e40e
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List
......
......@@ -94,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if not TrainState.off_map_state(_agent.state) and \
if not _agent.state.is_off_map_state() and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
......@@ -103,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder):
'malfunction']
# [NIMISH] WHAT IS THIS
if TrainState.off_map_state(_agent.state) and \
if _agent.state.is_off_map_state() and \
_agent.initial_position:
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
......@@ -570,9 +570,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if TrainState.off_map_state(agent.state):
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif TrainState.on_map_state(agent.state):
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
......@@ -608,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if TrainState.off_map_state(other_agent.state):
if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -3,7 +3,6 @@ Definition of the RailEnv environment.
"""
import random
# TODO: _ this is a global method --> utils or remove later
from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
......@@ -285,7 +284,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return agent.state == TrainState.READY_TO_DEPART or \
(TrainState.on_map_state(agent.state) and \
(agent.state.is_on_map_state() and \
fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
......@@ -406,7 +405,7 @@ class RailEnv(Environment):
return observation_dict, info_dict
def apply_action_independent(self, action, rail, position, direction):
if RailEnvActions.is_moving_action(action):
if action.is_moving_action():
new_direction, _ = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
......@@ -420,7 +419,7 @@ class RailEnv(Environment):
st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete
st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure
st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING)
st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action)
st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action()
st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
......@@ -557,10 +556,6 @@ class RailEnv(Environment):
else:
final_new_position = agent.position
final_new_direction = agent.direction
# if final_new_position and self.rail.grid[final_new_position] == 0:
# import pdb; pdb.set_trace()
# if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this
# import pdb; pdb.set_trace()
agent.position = final_new_position
agent.direction = final_new_direction
......
......@@ -19,9 +19,12 @@ class RailEnvActions(IntEnum):
4: 'S',
}[a]
@staticmethod
def is_moving_action(action):
return action in [1,2,3]
@classmethod
def check_valid_action(cls, action):
return action in cls._value2member_map_
def is_moving_action(self):
return self.value in [self.MOVE_RIGHT, self.MOVE_LEFT, self.MOVE_FORWARD]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
......
......@@ -227,9 +227,9 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
shortest_paths = dict()
def _shortest_path_for_agent(agent):
if TrainState.off_map_state(agent.state):
if agent.state.is_off_map_state():
position = agent.initial_position
elif TrainState.on_map_state(agent.state):
elif agent.state.is_on_map_state():
position = agent.position
elif agent.state == TrainState.DONE:
position = agent.target
......
......@@ -5,11 +5,10 @@ from flatland.envs.step_utils.transition_utils import check_valid_action
def process_illegal_action(action: RailEnvActions):
# TODO - Dipam : This check is kind of weird, change this
if action is None or action not in RailEnvActions._value2member_map_:
if not RailEnvActions.check_valid_action(action):
return RailEnvActions.DO_NOTHING
else:
return action
return RailEnvActions(action)
def process_do_nothing(state: TrainState):
......
......@@ -15,8 +15,8 @@ class ActionSaver:
def save_action_if_allowed(self, action, state):
if not self.is_action_saved and \
RailEnvActions.is_moving_action(action) and \
not TrainState.is_malfunction_state(state):
action.is_moving_action() and \
not state.is_malfunction_state():
self.saved_action = action
def clear_saved_action(self):
......
......@@ -13,17 +13,14 @@ class TrainState(IntEnum):
def check_valid_state(cls, state):
return state in cls._value2member_map_
@staticmethod
def is_malfunction_state(state):
return state in [2, 5] # TODO: Can this be done with names instead?
def is_malfunction_state(self):
return self.value in [self.MALFUNCTION, self.MALFUNCTION_OFF_MAP]
@staticmethod
def off_map_state(state):
return state in [0, 1, 2]
def is_off_map_state(self):
return self.value in [self.WAITING, self.READY_TO_DEPART, self.MALFUNCTION_OFF_MAP]
@staticmethod
def on_map_state(state):
return state in [3, 4, 5]
def is_on_map_state(self):
return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
......
......@@ -743,7 +743,7 @@ class RenderLocal(RenderBase):
if show_inactive_agents:
show_this_agent = True
else:
show_this_agent = TrainState.on_map_state(agent.state)
show_this_agent = agent.state.is_on_map_state()
if show_this_agent:
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment