From ca36e40eeb553c6ffecdb837da6e6635f191634d Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Thu, 9 Sep 2021 18:34:21 +0530 Subject: [PATCH] change all RailEnvStatus to TrainState --- flatland/envs/agent_utils.py | 25 ++++--- flatland/envs/malfunction_generators.py | 6 +- flatland/envs/observations.py | 19 +++-- flatland/envs/persistence.py | 2 +- flatland/envs/rail_env.py | 74 +++++-------------- flatland/envs/rail_env_shortest_paths.py | 10 +-- .../envs/step_utils/action_preprocessing.py | 6 +- flatland/envs/step_utils/states.py | 9 +++ flatland/envs/step_utils/transition_utils.py | 5 +- flatland/utils/rendertools.py | 6 +- tests/test_env_step_utils.py | 6 +- 11 files changed, 70 insertions(+), 98 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 632caeea..9230659e 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -2,18 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint import numpy as np from enum import IntEnum -from flatland.envs.step_utils.states import TrainState + from itertools import starmap from typing import Tuple, Optional, NamedTuple, List from attr import attr, attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.schedule_utils import Schedule +from flatland.envs.timetable_utils import Line from flatland.envs.step_utils.action_saver import ActionSaver from flatland.envs.step_utils.speed_counter import SpeedCounter from flatland.envs.step_utils.state_machine import TrainStateMachine +from flatland.envs.step_utils.states import TrainState from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), @@ -137,8 +138,8 @@ class EnvAgent: """ speed_datas = [] speed_counters = [] - for i in range(len(schedule.agent_positions)): - speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0 + for i in range(len(line.agent_positions)): + speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0 speed_datas.append({'position_fraction': 0.0, 'speed': speed, 'transition_action_on_cellexit': 0}) @@ -152,16 +153,16 @@ class EnvAgent: 'next_malfunction': 0, 'nr_malfunctions': 0}) - return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents - schedule.agent_directions, - schedule.agent_directions, - schedule.agent_targets, - [False] * len(schedule.agent_positions), - [None] * len(schedule.agent_positions), # earliest_departure - [None] * len(schedule.agent_positions), # latest_arrival + return list(starmap(EnvAgent, zip(line.agent_positions, # TODO : Dipam - Really want to change this way of loading agents + line.agent_directions, + line.agent_directions, + line.agent_targets, + [False] * len(line.agent_positions), + [None] * len(line.agent_positions), # earliest_departure + [None] * len(line.agent_positions), # latest_arrival speed_datas, malfunction_datas, - range(len(schedule.agent_positions)), + range(len(line.agent_positions)), speed_counters, ))) diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 2fecddf1..0dfafb36 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -5,7 +5,8 @@ from typing import Callable, NamedTuple, Optional, Tuple import numpy as np from numpy.random.mtrand import RandomState -from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import EnvAgent +from flatland.envs.step_utils.states import TrainState from flatland.envs import persistence @@ -155,7 +156,8 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio malfunction_calls[agent.handle] = 1 # Break an agent that is active at the time of the malfunction - if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed? + if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \ + and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed? global_nr_malfunctions += 1 return Malfunction(malfunction_duration) else: diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 4de36060..a140da06 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -11,7 +11,8 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.agent_utils import RailAgentStatus, EnvAgent +from flatland.envs.agent_utils import EnvAgent +from flatland.envs.step_utils.states import TrainState from flatland.utils.ordered_set import OrderedSet @@ -93,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_ready_to_depart = {} for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + if not TrainState.off_map_state(_agent.state) and \ _agent.position: self.location_has_agent[tuple(_agent.position)] = 1 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction @@ -102,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder): 'malfunction'] # [NIMISH] WHAT IS THIS - if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \ + if TrainState.off_map_state(_agent.state) and \ _agent.initial_position: self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0) self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1 @@ -569,13 +570,11 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): agent = self.env.agents[handle] - if agent.status == RailAgentStatus.WAITING: - agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.READY_TO_DEPART: + if TrainState.off_map_state(agent.state): agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: + elif TrainState.on_map_state(agent.state): agent_virtual_position = agent.position - elif agent.status == RailAgentStatus.DONE: + elif agent.state == TrainState.DONE: agent_virtual_position = agent.target else: return None @@ -596,7 +595,7 @@ class GlobalObsForRailEnv(ObservationBuilder): other_agent: EnvAgent = self.env.agents[i] # ignore other agents not in the grid any more - if other_agent.status == RailAgentStatus.DONE_REMOVED: + if other_agent.state == TrainState.DONE: continue obs_targets[other_agent.target][1] = 1 @@ -609,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder): obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] # fifth channel: all ready to depart on this position - if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING: + if TrainState.off_map_state(other_agent.state): obs_agents_state[other_agent.initial_position][4] += 1 return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 41f352e7..c5ec8f33 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -13,7 +13,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder #from flatland.core.grid.grid4_utils import get_new_position #from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import Agent, EnvAgent from flatland.envs.distance_map import DistanceMap #from flatland.envs.observations import GlobalObsForRailEnv diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1dc332d9..68bb0de2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -17,7 +17,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import Agent, EnvAgent from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env_action import RailEnvActions @@ -39,8 +39,7 @@ from gym.utils import seeding # from flatland.envs.line_generators import random_line_generator, LineGenerator -# NEW : Imports -from flatland.envs.schedule_time_generators import schedule_time_generator +from flatland.envs.timetable_generators import timetable_generator from flatland.envs.step_utils.states import TrainState from flatland.envs.step_utils.transition_utils import check_action @@ -285,9 +284,9 @@ class RailEnv(Environment): True: Agent needs to provide an action False: Agent cannot provide an action """ - return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + return agent.state == TrainState.READY_TO_DEPART or \ + (TrainState.on_map_state(agent.state) and \ + fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) ) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, random_seed: bool = None) -> Tuple[Dict, Dict]: @@ -400,7 +399,7 @@ class RailEnv(Environment): i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) }, 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, - 'status': {i: agent.status for i, agent in enumerate(self.agents)} + 'state': {i: agent.state for i, agent in enumerate(self.agents)} } # Return the new observation vectors for each agent observation_dict: Dict = self._get_observations() @@ -425,6 +424,8 @@ class RailEnv(Environment): st_signals['target_reached'] = fast_position_equal(agent.position, agent.target) st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information + return st_signals + def _handle_end_reward(self, agent: EnvAgent) -> int: ''' Handles end-of-episode reward for a particular agent. @@ -456,8 +457,7 @@ class RailEnv(Environment): def step(self, action_dict_: Dict[int, RailEnvActions]): """ Updates rewards for the agents at a step. - - def step(self, action_dict): + """ self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. @@ -497,7 +497,7 @@ class RailEnv(Environment): agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random) # Get action for the agent - action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING) + action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING) # TODO: Add the bottom stuff to separate function(s) # Preprocess action @@ -509,7 +509,7 @@ class RailEnv(Environment): agent_not_on_map = current_position is None if agent_not_on_map: # Agent not added on map yet current_position, current_direction = agent.initial_position, agent.initial_direction - action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction) + action = preprocess_moving_action(action, self.rail, current_position, current_direction) # Save moving actions in not already saved agent.action_saver.save_action_if_allowed(action, agent.state) @@ -519,6 +519,7 @@ class RailEnv(Environment): if agent_not_on_map and agent.action_saver.is_action_saved: temp_new_position = agent.initial_position temp_new_direction = agent.initial_direction + preprocessed_action = action # When cell exit occurs apply saved action independent of other agents elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved: @@ -526,11 +527,13 @@ class RailEnv(Environment): # Apply action independent of other agents and get temporary new position and direction temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction) temp_new_position, temp_new_direction = temp_pd + preprocessed_action = saved_action else: temp_new_position, temp_new_direction = agent.position, agent.direction + preprocessed_action = action # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1) - temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action + temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action self.motionCheck.addAgent(i_agent, agent.position, temp_new_position) # Find conflicts @@ -554,6 +557,10 @@ class RailEnv(Environment): else: final_new_position = agent.position final_new_direction = agent.direction + # if final_new_position and self.rail.grid[final_new_position] == 0: + # import pdb; pdb.set_trace() + # if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this + # import pdb; pdb.set_trace() agent.position = final_new_position agent.direction = final_new_direction @@ -581,49 +588,6 @@ class RailEnv(Environment): self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes? - def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): - """ - Sets the agent to its initial position. Updates the agent object and the position - of the agent inside the global agent_position numpy array - - Parameters - ------- - agent: EnvAgent object - new_position: IntVector2D - """ - agent.position = new_position - self.agent_positions[agent.position] = agent.handle - - def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): - """ - Move the agent to the a new position. Updates the agent object and the position - of the agent inside the global agent_position numpy array - - Parameters - ------- - agent: EnvAgent object - new_position: IntVector2D - """ - agent.position = new_position - self.agent_positions[agent.old_position] = -1 - self.agent_positions[agent.position] = agent.handle - - def _remove_agent_from_scene(self, agent: EnvAgent): - """ - Remove the agent from the scene. Updates the agent object and the position - of the agent inside the global agent_position numpy array - - Parameters - ------- - agent: EnvAgent object - """ - self.agent_positions[agent.position] = -1 - if self.remove_agents_at_target: - agent.position = None - # setting old_position to None here stops the DONE agents from appearing in the rendered image - agent.old_position = None - agent.status = RailAgentStatus.DONE_REMOVED - def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode ''' diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index 8c981778..7d063282 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -7,7 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.step_utils.states import TrainState from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction from flatland.envs.rail_trainrun_data_structures import Waypoint @@ -227,13 +227,11 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non shortest_paths = dict() def _shortest_path_for_agent(agent): - if agent.status == RailAgentStatus.WAITING: + if TrainState.off_map_state(agent.state): position = agent.initial_position - elif agent.status == RailAgentStatus.READY_TO_DEPART: - position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: + elif TrainState.on_map_state(agent.state): position = agent.position - elif agent.status == RailAgentStatus.DONE: + elif agent.state == TrainState.DONE: position = agent.target else: shortest_paths[agent.handle] = None diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py index e8ad1d79..c777342d 100644 --- a/flatland/envs/step_utils/action_preprocessing.py +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -20,8 +20,8 @@ def process_do_nothing(state: TrainState): return action -def process_left_right(action, state, rail, position, direction): - if not check_valid_action(action, state, rail, position, direction): +def process_left_right(action, rail, position, direction): + if not check_valid_action(action, rail, position, direction): action = RailEnvActions.MOVE_FORWARD return action @@ -48,7 +48,7 @@ def preprocess_raw_action(action, state): return action -def preprocess_moving_action(action, state, rail, position, direction): +def preprocess_moving_action(action, rail, position, direction): """ LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving FORWARD is converted to STOP_MOVING if leading to dead end? diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py index 4c991b66..b8040939 100644 --- a/flatland/envs/step_utils/states.py +++ b/flatland/envs/step_utils/states.py @@ -16,6 +16,15 @@ class TrainState(IntEnum): @staticmethod def is_malfunction_state(state): return state in [2, 5] # TODO: Can this be done with names instead? + + @staticmethod + def off_map_state(state): + return state in [0, 1, 2] + + @staticmethod + def on_map_state(state): + return state in [3, 4, 5] + diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py index 2d58d21e..157db5ac 100644 --- a/flatland/envs/step_utils/transition_utils.py +++ b/flatland/envs/step_utils/transition_utils.py @@ -66,9 +66,8 @@ def check_action_on_agent(action, rail, position, direction): new_direction, transition_valid = check_action(action, position, direction, rail) new_position = get_new_position(position, new_direction) - cell_inside_grid = check_bounds(new_position, rail.height, rail.width) - cell_not_empty = rail.get_full_transitions(*new_position) > 0 - new_cell_valid = cell_inside_grid and cell_not_empty + new_cell_valid = check_bounds(new_position, rail.height, rail.width) and \ + rail.get_full_transitions(*new_position) > 0 # If transition validity hasn't been checked yet. if transition_valid is None: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 910dec32..9499c5c4 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -7,7 +7,7 @@ import numpy as np from numpy import array from recordtype import recordtype -from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.step_utils.states import TrainState from flatland.utils.graphics_pil import PILGL, PILSVG from flatland.utils.graphics_pgl import PGLGL @@ -741,9 +741,9 @@ class RenderLocal(RenderBase): self.gl.set_cell_occupied(agent_idx, *(agent.position)) if show_inactive_agents: - show_this_agent=True + show_this_agent = True else: - show_this_agent = agent.status == RailAgentStatus.ACTIVE + show_this_agent = TrainState.on_map_state(agent.state) if show_this_agent: self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py index 739d3d06..4c249de3 100644 --- a/tests/test_env_step_utils.py +++ b/tests/test_env_step_utils.py @@ -10,7 +10,7 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator #from flatland.envs.sparse_rail_gen import SparseRailGen -from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.envs.line_generators import sparse_line_generator def get_small_two_agent_env(): @@ -35,7 +35,7 @@ def get_small_two_agent_env(): 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train - schedule_generator = sparse_schedule_generator(speed_ration_map) + line_generator = sparse_line_generator(speed_ration_map) stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence @@ -48,7 +48,7 @@ def get_small_two_agent_env(): env = RailEnv(width=width, height=height, rail_generator=rail_generator, - schedule_generator=schedule_generator, + line_generator=line_generator, number_of_agents=nr_trains, obs_builder_object=observation_builder, #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), -- GitLab