Commit ca36e40e authored by Dipam Chakraborty
Browse files

Change all RailAgentStatus to TrainState

parent bbb0eff6
......@@ -2,18 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
from enum import IntEnum
from flatland.envs.step_utils.states import TrainState
from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule
from flatland.envs.timetable_utils import Line
from flatland.envs.step_utils.action_saver import ActionSaver
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.state_machine import TrainStateMachine
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
......@@ -137,8 +138,8 @@ class EnvAgent:
"""
speed_datas = []
speed_counters = []
for i in range(len(schedule.agent_positions)):
speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0
for i in range(len(line.agent_positions)):
speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0})
......@@ -152,16 +153,16 @@ class EnvAgent:
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
schedule.agent_directions,
schedule.agent_directions,
schedule.agent_targets,
[False] * len(schedule.agent_positions),
[None] * len(schedule.agent_positions), # earliest_departure
[None] * len(schedule.agent_positions), # latest_arrival
return list(starmap(EnvAgent, zip(line.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
line.agent_directions,
line.agent_directions,
line.agent_targets,
[False] * len(line.agent_positions),
[None] * len(line.agent_positions), # earliest_departure
[None] * len(line.agent_positions), # latest_arrival
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)),
range(len(line.agent_positions)),
speed_counters,
)))
......
......@@ -5,7 +5,8 @@ from typing import Callable, NamedTuple, Optional, Tuple
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs import persistence
......@@ -155,7 +156,8 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \
and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
......
......@@ -11,7 +11,8 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
......@@ -93,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
if not TrainState.off_map_state(_agent.state) and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
......@@ -102,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder):
'malfunction']
# [NIMISH] WHAT IS THIS
if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \
if TrainState.off_map_state(_agent.state) and \
_agent.initial_position:
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
......@@ -569,13 +570,11 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.WAITING:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
if TrainState.off_map_state(agent.state):
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif TrainState.on_map_state(agent.state):
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -596,7 +595,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
if other_agent.status == RailAgentStatus.DONE_REMOVED:
if other_agent.state == TrainState.DONE:
continue
obs_targets[other_agent.target][1] = 1
......@@ -609,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING:
if TrainState.off_map_state(other_agent.state):
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -13,7 +13,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
from flatland.envs.agent_utils import Agent, EnvAgent
from flatland.envs.distance_map import DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
......
......@@ -17,7 +17,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
from flatland.envs.agent_utils import Agent, EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
......@@ -39,8 +39,7 @@ from gym.utils import seeding
# from flatland.envs.line_generators import random_line_generator, LineGenerator
# NEW : Imports
from flatland.envs.schedule_time_generators import schedule_time_generator
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.transition_utils import check_action
......@@ -285,9 +284,9 @@ class RailEnv(Environment):
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return (agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
return agent.state == TrainState.READY_TO_DEPART or \
(TrainState.on_map_state(agent.state) and \
fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]:
......@@ -400,7 +399,7 @@ class RailEnv(Environment):
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
'status': {i: agent.status for i, agent in enumerate(self.agents)}
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
......@@ -425,6 +424,8 @@ class RailEnv(Environment):
st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
return st_signals
def _handle_end_reward(self, agent: EnvAgent) -> int:
'''
Handles end-of-episode reward for a particular agent.
......@@ -456,8 +457,7 @@ class RailEnv(Environment):
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
def step(self, action_dict):
"""
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
......@@ -497,7 +497,7 @@ class RailEnv(Environment):
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
# Get action for the agent
action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)
action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
# TODO: Add the bottom stuff to separate function(s)
# Preprocess action
......@@ -509,7 +509,7 @@ class RailEnv(Environment):
agent_not_on_map = current_position is None
if agent_not_on_map: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction)
action = preprocess_moving_action(action, self.rail, current_position, current_direction)
# Save moving actions if not already saved
agent.action_saver.save_action_if_allowed(action, agent.state)
......@@ -519,6 +519,7 @@ class RailEnv(Environment):
if agent_not_on_map and agent.action_saver.is_action_saved:
temp_new_position = agent.initial_position
temp_new_direction = agent.initial_direction
preprocessed_action = action
# When cell exit occurs apply saved action independent of other agents
elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
......@@ -526,11 +527,13 @@ class RailEnv(Environment):
# Apply action independent of other agents and get temporary new position and direction
temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
temp_new_position, temp_new_direction = temp_pd
preprocessed_action = saved_action
else:
temp_new_position, temp_new_direction = agent.position, agent.direction
preprocessed_action = action
# TODO: Saving temporary position shouldn't be needed if recheck of position is not needed later (see TAG#1)
temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action
temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action
self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
# Find conflicts
......@@ -554,6 +557,10 @@ class RailEnv(Environment):
else:
final_new_position = agent.position
final_new_direction = agent.direction
# if final_new_position and self.rail.grid[final_new_position] == 0:
# import pdb; pdb.set_trace()
# if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this
# import pdb; pdb.set_trace()
agent.position = final_new_position
agent.direction = final_new_direction
......@@ -581,49 +588,6 @@ class RailEnv(Environment):
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this
return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
    """
    Place the agent on its initial cell.

    Stores ``new_position`` on the agent and writes the agent's handle at
    that cell in ``self.agent_positions``.

    Parameters
    -------
    agent: EnvAgent object
    new_position: IntVector2D
    """
    agent.position = new_position
    # Read the position back from the agent so the index used is exactly
    # what the agent now reports.
    occupied_cell = agent.position
    self.agent_positions[occupied_cell] = agent.handle
def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
    """
    Move the agent to a new position.

    Stores ``new_position`` on the agent, marks the cell at
    ``agent.old_position`` as free (-1) in ``self.agent_positions`` and then
    writes the agent's handle at the new cell.

    Parameters
    -------
    agent: EnvAgent object
    new_position: IntVector2D
    """
    agent.position = new_position
    occupancy = self.agent_positions
    # Free the old cell first: if old and new cell coincide, the handle
    # written afterwards must win.
    occupancy[agent.old_position] = -1
    occupancy[agent.position] = agent.handle
def _remove_agent_from_scene(self, agent: EnvAgent):
"""
Remove the agent from the scene.

Marks the agent's current cell as free (-1) in ``self.agent_positions``.
When ``self.remove_agents_at_target`` is set, the agent's position is
cleared as well.

Parameters
-------
agent: EnvAgent object
"""
# Free the cell the agent currently occupies.
self.agent_positions[agent.position] = -1
# NOTE(review): indentation was lost in this view. Presumably the three
# assignments below all belong to this `if` - confirm against the repository.
if self.remove_agents_at_target:
agent.position = None
# setting old_position to None here stops the DONE agents from appearing in the rendered image
agent.old_position = None
agent.status = RailAgentStatus.DONE_REMOVED
def record_timestep(self, dActions):
''' Record the positions and orientations of all agents in memory, in the cur_episode
'''
......
......@@ -7,7 +7,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.step_utils.states import TrainState
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction
from flatland.envs.rail_trainrun_data_structures import Waypoint
......@@ -227,13 +227,11 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
shortest_paths = dict()
def _shortest_path_for_agent(agent):
if agent.status == RailAgentStatus.WAITING:
if TrainState.off_map_state(agent.state):
position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif TrainState.on_map_state(agent.state):
position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
position = agent.target
else:
shortest_paths[agent.handle] = None
......
......@@ -20,8 +20,8 @@ def process_do_nothing(state: TrainState):
return action
def process_left_right(action, state, rail, position, direction):
if not check_valid_action(action, state, rail, position, direction):
def process_left_right(action, rail, position, direction):
    """
    Return ``action`` if ``check_valid_action`` accepts it for the given
    rail, position and direction; otherwise fall back to
    ``RailEnvActions.MOVE_FORWARD``.
    """
    if check_valid_action(action, rail, position, direction):
        return action
    return RailEnvActions.MOVE_FORWARD
......@@ -48,7 +48,7 @@ def preprocess_raw_action(action, state):
return action
def preprocess_moving_action(action, state, rail, position, direction):
def preprocess_moving_action(action, rail, position, direction):
"""
LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
FORWARD is converted to STOP_MOVING if leading to dead end?
......
......@@ -16,6 +16,15 @@ class TrainState(IntEnum):
@staticmethod
def is_malfunction_state(state):
    """Return True when ``state`` equals 2 or 5, the values treated here as malfunction states."""
    # TODO: Can this be done with names instead?
    malfunction_values = (2, 5)
    return state in malfunction_values
@staticmethod
def off_map_state(state):
    """Return True when ``state`` equals 0, 1 or 2, the values treated here as off-map states."""
    off_map_values = (0, 1, 2)
    return state in off_map_values
@staticmethod
def on_map_state(state):
    """Return True when ``state`` equals 3, 4 or 5, the values treated here as on-map states."""
    on_map_values = (3, 4, 5)
    return state in on_map_values
......@@ -66,9 +66,8 @@ def check_action_on_agent(action, rail, position, direction):
new_direction, transition_valid = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
cell_inside_grid = check_bounds(new_position, rail.height, rail.width)
cell_not_empty = rail.get_full_transitions(*new_position) > 0
new_cell_valid = cell_inside_grid and cell_not_empty
new_cell_valid = check_bounds(new_position, rail.height, rail.width) and \
rail.get_full_transitions(*new_position) > 0
# If transition validity hasn't been checked yet.
if transition_valid is None:
......
......@@ -7,7 +7,7 @@ import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.step_utils.states import TrainState
from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_pgl import PGLGL
......@@ -741,9 +741,9 @@ class RenderLocal(RenderBase):
self.gl.set_cell_occupied(agent_idx, *(agent.position))
if show_inactive_agents:
show_this_agent=True
show_this_agent = True
else:
show_this_agent = agent.status == RailAgentStatus.ACTIVE
show_this_agent = TrainState.on_map_state(agent.state)
if show_this_agent:
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
......
......@@ -10,7 +10,7 @@ from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
#from flatland.envs.sparse_rail_gen import SparseRailGen
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.line_generators import sparse_line_generator
def get_small_two_agent_env():
......@@ -35,7 +35,7 @@ def get_small_two_agent_env():
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
schedule_generator = sparse_schedule_generator(speed_ration_map)
line_generator = sparse_line_generator(speed_ration_map)
stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence
......@@ -48,7 +48,7 @@ def get_small_two_agent_env():
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
schedule_generator=schedule_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
#malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment