Skip to content
Snippets Groups Projects
Commit e4399082 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Change speed data to speed counter

parent 8a3a043c
No related branches found
No related tags found
No related merge requests found
...@@ -150,7 +150,7 @@ class ControllerFromTrainruns(): ...@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan: def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = [] action_plan = []
agent = self.env.agents[agent_id] agent = self.env.agents[agent_id]
minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed'])) minimum_cell_time = agent.speed_counter.max_count
for path_loop, trainrun_waypoint in enumerate(trainrun): for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
......
...@@ -30,6 +30,8 @@ class ControllerFromTrainrunsReplayer(): ...@@ -30,6 +30,8 @@ class ControllerFromTrainrunsReplayer():
assert agent.position == waypoint.position, \ assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position, "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position) waypoint.position)
if agent_id == 1:
print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
actions = ctl.act(i) actions = ctl.act(i)
print("actions for {}: {}".format(i, actions)) print("actions for {}: {}".format(i, actions))
......
from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np import numpy as np
import warnings
from typing import Tuple, Optional, NamedTuple, List from typing import Tuple, Optional, NamedTuple, List
...@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ...@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('moving', bool), ('moving', bool),
('earliest_departure', int), ('earliest_departure', int),
('latest_arrival', int), ('latest_arrival', int),
('speed_data', dict),
('malfunction_data', dict), ('malfunction_data', dict),
('handle', int), ('handle', int),
('position', Tuple[int, int]), ('position', Tuple[int, int]),
...@@ -49,13 +49,6 @@ class EnvAgent: ...@@ -49,13 +49,6 @@ class EnvAgent:
earliest_departure = attrib(default=None, type=int) # default None during _from_line() earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line() latest_arrival = attrib(default=None, type=int) # default None during _from_line()
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps # if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down # number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib( malfunction_data = attrib(
...@@ -67,7 +60,7 @@ class EnvAgent: ...@@ -67,7 +60,7 @@ class EnvAgent:
# INIT TILL HERE IN _from_line() # INIT TILL HERE IN _from_line()
# Env step facelift # Env step facelift
speed_counter = attrib(default = None, type=SpeedCounter) speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver) action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine) type=TrainStateMachine)
...@@ -94,10 +87,6 @@ class EnvAgent: ...@@ -94,10 +87,6 @@ class EnvAgent:
self.old_direction = None self.old_direction = None
self.moving = False self.moving = False
# Reset agent values for speed
self.speed_data['position_fraction'] = 0.
self.speed_data['transition_action_on_cellexit'] = 0.
# Reset agent malfunction values # Reset agent malfunction values
self.malfunction_data['malfunction'] = 0 self.malfunction_data['malfunction'] = 0
self.malfunction_data['nr_malfunctions'] = 0 self.malfunction_data['nr_malfunctions'] = 0
...@@ -115,7 +104,6 @@ class EnvAgent: ...@@ -115,7 +104,6 @@ class EnvAgent:
moving=self.moving, moving=self.moving,
earliest_departure=self.earliest_departure, earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival, latest_arrival=self.latest_arrival,
speed_data=self.speed_data,
malfunction_data=self.malfunction_data, malfunction_data=self.malfunction_data,
handle=self.handle, handle=self.handle,
state=self.state, state=self.state,
...@@ -137,7 +125,7 @@ class EnvAgent: ...@@ -137,7 +125,7 @@ class EnvAgent:
distance = len(shortest_path) distance = len(shortest_path)
else: else:
distance = 0 distance = 0
speed = self.speed_data['speed'] speed = self.speed_counter.speed
return int(np.ceil(distance / speed)) return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int: def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
...@@ -161,11 +149,6 @@ class EnvAgent: ...@@ -161,11 +149,6 @@ class EnvAgent:
agent_list = [] agent_list = []
for i_agent in range(num_agents): for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0 speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
speed_data = {'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0
}
if line.agent_malfunction_rates is not None: if line.agent_malfunction_rates is not None:
malfunction_rate = line.agent_malfunction_rates[i_agent] malfunction_rate = line.agent_malfunction_rates[i_agent]
...@@ -177,7 +160,6 @@ class EnvAgent: ...@@ -177,7 +160,6 @@ class EnvAgent:
'next_malfunction': 0, 'next_malfunction': 0,
'nr_malfunctions': 0 'nr_malfunctions': 0
} }
agent = EnvAgent(initial_position = line.agent_positions[i_agent], agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent], initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent], direction = line.agent_directions[i_agent],
...@@ -185,7 +167,6 @@ class EnvAgent: ...@@ -185,7 +167,6 @@ class EnvAgent:
moving = False, moving = False,
earliest_departure = None, earliest_departure = None,
latest_arrival = None, latest_arrival = None,
speed_data = speed_data,
malfunction_data = malfunction_data, malfunction_data = malfunction_data,
handle = i_agent, handle = i_agent,
speed_counter = SpeedCounter(speed=speed)) speed_counter = SpeedCounter(speed=speed))
...@@ -195,6 +176,7 @@ class EnvAgent: ...@@ -195,6 +176,7 @@ class EnvAgent:
@classmethod @classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple): def load_legacy_static_agent(cls, static_agents_data: Tuple):
raise NotImplementedError("Not implemented for Flatland 3")
agents = [] agents = []
for i, static_agent in enumerate(static_agents_data): for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6: if len(static_agent) >= 6:
...@@ -205,16 +187,35 @@ class EnvAgent: ...@@ -205,16 +187,35 @@ class EnvAgent:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1], agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], direction=static_agent[1], target=static_agent[2],
moving=False, moving=False,
speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
malfunction_data={ malfunction_data={
'malfunction': 0, 'malfunction': 0,
'nr_malfunctions': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False 'moving_before_malfunction': False
}, },
speed_counter=SpeedCounter(1.0),
handle=i) handle=i)
agents.append(agent) agents.append(agent)
return agents return agents
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
position: {self.position} direction: {self.position} target: {self.target} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_data: {self.malfunction_data} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property @property
def state(self): def state(self):
return self.state_machine.state return self.state_machine.state
...@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator: ...@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
#agents_direction = [a.direction for a in agents] #agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents] agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents] agents_target = [a.target for a in agents]
agents_speed = [a.speed_data['speed'] for a in agents] agents_speed = [a.speed_counter.speed for a in agents]
# Malfunctions from here are not used. They have their own generator. # Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents] #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
......
...@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder):
_agent.position: _agent.position:
self.location_has_agent[tuple(_agent.position)] = 1 self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction'] 'malfunction']
...@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent.direction)], agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0, num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'], num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'], speed_min_fractional=agent.speed_counter.speed
num_agents_ready_to_depart=0, num_agents_ready_to_depart=0,
childs={}) childs={})
#print("root node type:", type(root_node_observation)) #print("root node type:", type(root_node_observation))
...@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited = OrderedSet() visited = OrderedSet()
agent = self.env.agents[handle] agent = self.env.agents[handle]
time_per_cell = np.reciprocal(agent.speed_data["speed"]) time_per_cell = np.reciprocal(agent.speed_counter.speed)
own_target_encountered = np.inf own_target_encountered = np.inf
other_agent_encountered = np.inf other_agent_encountered = np.inf
other_target_encountered = np.inf other_target_encountered = np.inf
...@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
if i != handle: if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position # fifth channel: all ready to depart on this position
if other_agent.state.is_off_map_state(): if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1 obs_agents_state[other_agent.initial_position][4] += 1
......
...@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
continue continue
agent_virtual_direction = agent.direction agent_virtual_direction = agent.direction
agent_speed = agent.speed_data["speed"] agent_speed = agent.speed_counter.speed
times_per_cell = int(np.reciprocal(agent_speed)) times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0] prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
......
...@@ -261,8 +261,7 @@ class RailEnv(Environment): ...@@ -261,8 +261,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action False: Agent cannot provide an action
""" """
return agent.state == TrainState.READY_TO_DEPART or \ return agent.state == TrainState.READY_TO_DEPART or \
(agent.state.is_on_map_state() and \ (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]: random_seed: bool = None) -> Tuple[Dict, Dict]:
...@@ -344,19 +343,6 @@ class RailEnv(Environment): ...@@ -344,19 +343,6 @@ class RailEnv(Environment):
# Reset agents to initial states # Reset agents to initial states
self.reset_agents() self.reset_agents()
# for agent in self.agents:
# # Induce malfunctions
# if activate_agents:
# self.set_agent_active(agent)
# self._break_agent(agent)
# if agent.malfunction_data["malfunction"] > 0:
# agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
# # Fix agents that finished their malfunction
# self._fix_agent_after_malfunction(agent)
self.num_resets += 1 self.num_resets += 1
self._elapsed_steps = 0 self._elapsed_steps = 0
...@@ -369,14 +355,7 @@ class RailEnv(Environment): ...@@ -369,14 +355,7 @@ class RailEnv(Environment):
# Empty the episode store of agent positions # Empty the episode store of agent positions
self.cur_episode = [] self.cur_episode = []
info_dict: Dict = { info_dict = self.get_info_dict()
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
# Return the new observation vectors for each agent # Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations() observation_dict: Dict = self._get_observations()
return observation_dict, info_dict return observation_dict, info_dict
...@@ -469,10 +448,12 @@ class RailEnv(Environment): ...@@ -469,10 +448,12 @@ class RailEnv(Environment):
def get_info_dict(self): # TODO Important : Update this def get_info_dict(self): # TODO Important : Update this
info_dict = { info_dict = {
"action_required": {}, 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
"malfunction": {}, 'malfunction': {
"speed": {}, i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
"status": {}, },
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
} }
return info_dict return info_dict
......
...@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap, ...@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()] shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()]
# Find mean_shortest_path_time # Find mean_shortest_path_time
agent_speeds = [agent.speed_data['speed'] for agent in agents] agent_speeds = [agent.speed_counter.speed for agent in agents]
agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds) agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
mean_shortest_path_time = np.mean(agent_shortest_path_times) mean_shortest_path_time = np.mean(agent_shortest_path_times)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment