Commit 6d9793bb authored by Dipam Chakraborty's avatar Dipam Chakraborty

Merge env-step-facelift to flatland-3-updates

parents 636faf72 57b15b9f
Pipeline #8496 failed with stages in 7 minutes and 53 seconds
......@@ -120,4 +120,6 @@ test_save.dat
playground/
**/tmp
-**/TEMP
\ No newline at end of file
+**/TEMP
+*.pkl
......@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
-minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
+minimum_cell_time = agent.speed_counter.max_count + 1
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
......
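The change above swaps the ceil-of-reciprocal computation for the new SpeedCounter API. A minimal sketch of why the two expressions agree, assuming SpeedCounter stores max_count = ceil(1/speed) - 1, the number of intra-cell ticks before a train of fractional speed may exit its cell (only .max_count itself appears in this diff; the internals are an assumption):

    import numpy as np

    def old_minimum_cell_time(speed: float) -> int:
        # pre-merge computation, taken from the removed line above
        return int(np.ceil(1.0 / speed))

    def new_minimum_cell_time(speed: float) -> int:
        # assumed SpeedCounter internals: max_count intra-cell ticks per cell
        max_count = int(np.ceil(1.0 / speed)) - 1
        return max_count + 1

    for speed in (1.0, 0.5, 0.25, 0.2):
        assert old_minimum_cell_time(speed) == new_minimum_cell_time(speed)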
......@@ -31,7 +31,6 @@ class ControllerFromTrainrunsReplayer():
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
print("actions for {}: {}".format(i, actions))
obs, all_rewards, done, _ = env.step(actions)
......
......@@ -218,21 +218,21 @@ class MotionCheck(object):
if "color" in dAttr:
sColor = dAttr["color"]
if sColor in [ "red", "purple" ]:
-return (False, rcPos)
+return False
dSucc = self.G.succ[rcPos]
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
-return (False, rcPos)
+return False
# This agent has a successor
rcNext = self.G.successors(rcPos).__next__()
if rcNext == rcPos: # the agent didn't want to move
-return (False, rcNext)
+return False
# The agent wanted to move, and it can
-return (True, rcNext)
+return True
......
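After this hunk the method returns a bare bool instead of a (moved, cell) tuple, so callers that still need the destination cell must read it from the motion graph themselves. A hedged caller-side sketch (the caller shown is an assumption; only the return-type change is visible in this diff):

    # Before: moved, rc_next = motion_check.check_motion(i_agent, rc_pos)
    # After:  a bare bool; the next cell comes from the graph when needed.
    if motion_check.check_motion(i_agent, rc_pos):           # agent may move
        rc_next = next(motion_check.G.successors(rc_pos))    # same lookup as above
    else:
        rc_next = rc_pos                                     # agent stays put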
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
+import warnings
-from enum import IntEnum
-from itertools import starmap
from typing import Tuple, Optional, NamedTuple, List
from attr import attr, attrs, attrib, Factory
......@@ -10,13 +9,11 @@ from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.timetable_utils import Line
-class RailAgentStatus(IntEnum):
-WAITING = 0
-READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position
-ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path
-DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever
-DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None
+from flatland.envs.step_utils.action_saver import ActionSaver
+from flatland.envs.step_utils.speed_counter import SpeedCounter
+from flatland.envs.step_utils.state_machine import TrainStateMachine
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('initial_direction', Grid4TransitionsEnum),
......@@ -25,15 +22,38 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
-('speed_data', dict),
('malfunction_data', dict),
('handle', int),
-('status', RailAgentStatus),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
-('old_position', Tuple[int, int])])
+('old_position', Tuple[int, int]),
+('speed_counter', SpeedCounter),
+('action_saver', ActionSaver),
+('state_machine', TrainStateMachine),
+('malfunction_handler', MalfunctionHandler),
+])
+def load_env_agent(agent_tuple: Agent):
+return EnvAgent(
+initial_position = agent_tuple.initial_position,
+initial_direction = agent_tuple.initial_direction,
+direction = agent_tuple.direction,
+target = agent_tuple.target,
+moving = agent_tuple.moving,
+earliest_departure = agent_tuple.earliest_departure,
+latest_arrival = agent_tuple.latest_arrival,
+handle = agent_tuple.handle,
+position = agent_tuple.position,
+arrival_time = agent_tuple.arrival_time,
+old_direction = agent_tuple.old_direction,
+old_position = agent_tuple.old_position,
+speed_counter = agent_tuple.speed_counter,
+action_saver = agent_tuple.action_saver,
+state_machine = agent_tuple.state_machine,
+malfunction_handler = agent_tuple.malfunction_handler,
+)
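load_env_agent is the inverse of the new EnvAgent.to_agent further down: together they round-trip an agent through the plain NamedTuple, which is what the pickle/msgpack persistence path needs. A hedged usage sketch (assumes a constructed env with at least one agent):

    agent_tuple = env.agents[0].to_agent()   # attrs object -> plain NamedTuple
    restored = load_env_agent(agent_tuple)   # NamedTuple -> attrs object
    assert restored.initial_position == env.agents[0].initial_position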
@attrs
class EnvAgent:
......@@ -48,13 +68,6 @@ class EnvAgent:
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
-# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
-# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
-# cell if speed=1, as default)
-# N.B. we need to use factory since default arguments are not recreated on each call!
-speed_data = attrib(
-default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib(
......@@ -65,7 +78,13 @@ class EnvAgent:
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
-status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus)
+# Env step facelift
+speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
+action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
+state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
+type=TrainStateMachine)
+malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
......@@ -75,6 +94,7 @@ class EnvAgent:
old_direction = attrib(default=None)
old_position = attrib(default=None)
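The Factory wrappers in the new defaults above matter: attrs evaluates a plain default once at class-definition time, so a bare SpeedCounter(1.0) would be shared by every agent (the removed "N.B. we need to use factory" comment made the same point for speed_data). A minimal self-contained sketch of the pitfall:

    from attr import attrs, attrib, Factory

    @attrs
    class Shared:
        data = attrib(default={})               # one dict shared by all instances

    @attrs
    class PerInstance:
        data = attrib(default=Factory(dict))    # fresh dict per instance

    a, b = Shared(), Shared()
    a.data["x"] = 1
    assert b.data == {"x": 1}                   # mutation leaks across instances

    c, d = PerInstance(), PerInstance()
    c.data["x"] = 1
    assert d.data == {}                         # independent state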
def reset(self):
"""
Resets the agents to their initial values of the episode. Called after ScheduleTime generation.
......@@ -82,28 +102,38 @@ class EnvAgent:
self.position = None
# TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280
self.direction = self.initial_direction
-if (self.earliest_departure == 0):
-self.status = RailAgentStatus.READY_TO_DEPART
-else:
-self.status = RailAgentStatus.WAITING
self.arrival_time = None
self.old_position = None
self.old_direction = None
self.moving = False
-# Reset agent values for speed
-self.speed_data['position_fraction'] = 0.
-self.speed_data['transition_action_on_cellexit'] = 0.
# Reset agent malfunction values
self.malfunction_data['malfunction'] = 0
self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False
+# NEW : Callables
+self.action_saver.clear_saved_action()
+self.speed_counter.reset_counter()
+self.state_machine.reset()
+def to_agent(self) -> Agent:
+return Agent(initial_position=self.initial_position,
+initial_direction=self.initial_direction,
+direction=self.direction,
+target=self.target,
+moving=self.moving,
+earliest_departure=self.earliest_departure,
+latest_arrival=self.latest_arrival,
+malfunction_data=self.malfunction_data,
+handle=self.handle,
+position=self.position,
+old_direction=self.old_direction,
+old_position=self.old_position,
+speed_counter=self.speed_counter,
+action_saver=self.action_saver,
+arrival_time=self.arrival_time,
+state_machine=self.state_machine,
+malfunction_handler=self.malfunction_handler)
def get_shortest_path(self, distance_map) -> List[Waypoint]:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
......@@ -114,7 +144,7 @@ class EnvAgent:
distance = len(shortest_path)
else:
distance = 0
-speed = self.speed_data['speed']
+speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
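A quick worked example of the estimate above: a train with fractional speed 0.25 needs four env steps per cell, so 10 cells take 40 steps.

    import numpy as np
    distance, speed = 10, 0.25
    assert int(np.ceil(distance / speed)) == 40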
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
......@@ -128,42 +158,40 @@ class EnvAgent:
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
-def to_agent(self) -> Agent:
-return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
-direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure,
-latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data,
-handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time,
-old_direction=self.old_direction, old_position=self.old_position)
@classmethod
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
-speed_datas = []
-for i in range(len(line.agent_positions)):
-speed_datas.append({'position_fraction': 0.0,
-'speed': line.agent_speeds[i] if line.agent_speeds is not None else 1.0,
-'transition_action_on_cellexit': 0})
-malfunction_datas = []
-for i in range(len(line.agent_positions)):
-malfunction_datas.append({'malfunction': 0,
-'malfunction_rate': line.agent_malfunction_rates[
-i] if line.agent_malfunction_rates is not None else 0.,
-'next_malfunction': 0,
-'nr_malfunctions': 0})
-return list(starmap(EnvAgent, zip(line.agent_positions,
-line.agent_directions,
-line.agent_directions,
-line.agent_targets,
-[False] * len(line.agent_positions),
-[None] * len(line.agent_positions), # earliest_departure
-[None] * len(line.agent_positions), # latest_arrival
-speed_datas,
-malfunction_datas,
-range(len(line.agent_positions)))))
+num_agents = len(line.agent_positions)
+agent_list = []
+for i_agent in range(num_agents):
+speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
+if line.agent_malfunction_rates is not None:
+malfunction_rate = line.agent_malfunction_rates[i_agent]
+else:
+malfunction_rate = 0.
+malfunction_data = {'malfunction': 0,
+'malfunction_rate': malfunction_rate,
+'next_malfunction': 0,
+'nr_malfunctions': 0
+}
+agent = EnvAgent(initial_position = line.agent_positions[i_agent],
+initial_direction = line.agent_directions[i_agent],
+direction = line.agent_directions[i_agent],
+target = line.agent_targets[i_agent],
+moving = False,
+earliest_departure = None,
+latest_arrival = None,
+malfunction_data = malfunction_data,
+handle = i_agent,
+speed_counter = SpeedCounter(speed=speed))
+agent_list.append(agent)
+return agent_list
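A hedged sketch of calling the rewritten from_line with a hand-built Line. The field names (agent_positions, agent_directions, agent_targets, agent_speeds, agent_malfunction_rates) come from the attribute accesses above; treating Line as a keyword-constructible NamedTuple is an assumption.

    line = Line(agent_positions=[(0, 0), (5, 5)],
                agent_directions=[1, 3],        # Grid4TransitionsEnum values
                agent_targets=[(9, 9), (0, 3)],
                agent_speeds=[1.0, 0.5],
                agent_malfunction_rates=None)

    agents = EnvAgent.from_line(line)
    assert agents[1].speed_counter.speed == 0.5
    assert agents[0].earliest_departure is None  # filled in later by scheduling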
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
......@@ -172,17 +200,46 @@ class EnvAgent:
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
-speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
+speed_counter=SpeedCounter(static_agent[4]['speed']), malfunction_data=static_agent[5],
+handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
malfunction_data={
'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False
},
+speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
+def __str__(self):
+return f"\n \
+handle(agent index): {self.handle} \n \
+initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
+position: {self.position} direction: {self.direction} target: {self.target} \n \
+old_position: {self.old_position} old_direction {self.old_direction} \n \
+earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
+state: {str(self.state)} \n \
+malfunction_handler: {self.malfunction_handler} \n \
+action_saver: {self.action_saver} \n \
+speed_counter: {self.speed_counter}"
+@property
+def state(self):
+return self.state_machine.state
+@state.setter
+def state(self, state):
+self._set_state(state)
+def _set_state(self, state):
+warnings.warn("Not recommended to set the state with this function unless completely required")
+self.state_machine.set_state(state)
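The property makes agent.state a read view onto the state machine; assignment still works through the setter but deliberately warns, nudging callers toward letting the state machine drive transitions. A short hedged usage sketch:

    agent = env.agents[0]                  # assumes a running environment
    if agent.state == TrainState.MOVING:   # read is delegated to state_machine
        print("train is moving")

    agent.state = TrainState.STOPPED       # allowed, but emits a UserWarning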
......@@ -84,11 +84,6 @@ class SparseLineGen(BaseLineGen):
train_stations = hints['train_stations']
city_positions = hints['city_positions']
city_orientation = hints['city_orientations']
-max_num_agents = hints['num_agents']
-city_orientations = hints['city_orientations']
-if num_agents > max_num_agents:
-num_agents = max_num_agents
-warnings.warn("Too many agents! Changes number of agents.")
# Place agents and targets within available train stations
agents_position = []
agents_target = []
......@@ -189,7 +184,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
#agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
-agents_speed = [a.speed_data['speed'] for a in agents]
+agents_speed = [a.speed_counter.speed for a in agents]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
......
......@@ -5,7 +5,8 @@ from typing import Callable, NamedTuple, Optional, Tuple
import numpy as np
from numpy.random.mtrand import RandomState
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.step_utils.states import TrainState
from flatland.envs import persistence
......@@ -18,7 +19,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData',
Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
# Why is the return value Optional? We always return a Malfunction.
-MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]]
+MalfunctionGenerator = Callable[[RandomState, bool], Malfunction]
def _malfunction_prob(rate: float) -> float:
"""
......@@ -42,21 +43,14 @@ class ParamMalfunctionGen(object):
#self.max_number_of_steps_broken = parameters.max_duration
self.MFP = parameters
-def generate(self,
-agent: EnvAgent = None,
-np_random: RandomState = None,
-reset=False) -> Optional[Malfunction]:
-# Dummy reset function as we don't implement specific seeding here
-if reset:
-return Malfunction(0)
+def generate(self, np_random: RandomState) -> Malfunction:
-if agent.malfunction_data['malfunction'] < 1:
-if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
-num_broken_steps = np_random.randint(self.MFP.min_duration,
-self.MFP.max_duration + 1) + 1
-return Malfunction(num_broken_steps)
-return Malfunction(0)
+if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate):
+num_broken_steps = np_random.randint(self.MFP.min_duration,
+self.MFP.max_duration + 1) + 1
+else:
+num_broken_steps = 0
+return Malfunction(num_broken_steps)
def get_process_data(self):
return MalfunctionProcessData(*self.MFP)
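The rewritten generate draws a fresh malfunction each step from _malfunction_prob(rate) and no longer consults the agent's current malfunction counter (that gating moves into MalfunctionHandler). The body of _malfunction_prob is elided from this diff; a plausible reading, assuming Poisson-distributed malfunction arrivals, is:

    import numpy as np

    def _malfunction_prob(rate: float) -> float:
        # Per-step probability of at least one Poisson arrival with the given
        # rate (an assumption: the function body is not shown in this diff).
        if rate <= 0.0:
            return 0.0
        return 1.0 - np.exp(-rate)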
......@@ -103,7 +97,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess
min_number_of_steps_broken = 0
max_number_of_steps_broken = 0
-def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+def generator(np_random: RandomState = None) -> Malfunction:
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
......@@ -162,7 +156,8 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
malfunction_calls[agent.handle] = 1
# Break an agent that is active at the time of the malfunction
-if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction:
+if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \
+and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
global_nr_malfunctions += 1
return Malfunction(malfunction_duration)
else:
......@@ -258,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
min_number_of_steps_broken = parameters.min_duration
max_number_of_steps_broken = parameters.max_duration
-def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
+def generator(np_random: RandomState = None, reset=False) -> Optional[Malfunction]:
"""
Generate malfunctions for agents
Parameters
......@@ -275,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
if reset:
return Malfunction(0)
-if agent.malfunction_data['malfunction'] < 1:
-if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
-num_broken_steps = np_random.randint(min_number_of_steps_broken,
-max_number_of_steps_broken + 1) + 1
-return Malfunction(num_broken_steps)
+if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
+num_broken_steps = np_random.randint(min_number_of_steps_broken,
+max_number_of_steps_broken + 1)
+return Malfunction(num_broken_steps)
return Malfunction(0)
return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken,
......
......@@ -11,7 +11,8 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
......@@ -93,16 +94,16 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_ready_to_depart = {}
for _agent in self.env.agents:
-if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
+if not _agent.state.is_off_map_state() and \
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
-self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
+self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
# [NIMISH] WHAT IS THIS
-if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \
+if _agent.state.is_off_map_state() and \
_agent.initial_position:
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
......@@ -195,14 +196,12 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
-if agent.status == RailAgentStatus.WAITING:
-agent_virtual_position = agent.initial_position
-elif agent.status == RailAgentStatus.READY_TO_DEPART:
+if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
-elif agent.status == RailAgentStatus.ACTIVE:
+elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
-elif agent.status == RailAgentStatus.DONE:
+elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
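The is_off_map_state/is_on_map_state helpers replace the explicit status enumerations used before. A hedged sketch of their semantics, consistent with every call site in this diff (the exact TrainState members are an assumption taken from flatland 3):

    # Assumed TrainState members: WAITING, READY_TO_DEPART, MALFUNCTION_OFF_MAP,
    # MOVING, STOPPED, MALFUNCTION, DONE.
    def is_off_map_state(state) -> bool:
        # not yet placed on the grid: position is None, use initial_position
        return state in (TrainState.WAITING, TrainState.READY_TO_DEPART,
                         TrainState.MALFUNCTION_OFF_MAP)

    def is_on_map_state(state) -> bool:
        # placed on the grid and not done: position is a valid cell
        return state in (TrainState.MOVING, TrainState.STOPPED,
                         TrainState.MALFUNCTION)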
......@@ -222,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-speed_min_fractional=agent.speed_data['speed'],
+speed_min_fractional=agent.speed_counter.speed,
num_agents_ready_to_depart=0,
childs={})
#print("root node type:", type(root_node_observation))
......@@ -276,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited = OrderedSet()
agent = self.env.agents[handle]
-time_per_cell = np.reciprocal(agent.speed_data["speed"])
+time_per_cell = np.reciprocal(agent.speed_counter.speed)
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
......@@ -342,7 +341,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
-if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
+if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step-1
......@@ -353,7 +352,7 @@ class TreeObsForRailEnv(ObservationBuilder):
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
-if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
+if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step+1
......@@ -364,7 +363,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
-if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
+if self.env.agents[ca].state == TrainState.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
......@@ -569,13 +568,11 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
-if agent.status == RailAgentStatus.WAITING:
-agent_virtual_position = agent.initial_position
-elif agent.status == RailAgentStatus.READY_TO_DEPART:
+if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
-elif agent.status == RailAgentStatus.ACTIVE:
+elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
-elif agent.status == RailAgentStatus.DONE:
+elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -596,7 +593,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
-if other_agent.status == RailAgentStatus.DONE_REMOVED:
+if other_agent.state == TrainState.DONE:
continue
obs_targets[other_agent.target][1] = 1
......@@ -607,9 +604,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position
-if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING:
+if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -2,28 +2,21 @@
import pickle
import msgpack
import msgpack_numpy
import numpy as np