Commit e4399082 authored by Dipam Chakraborty

Change speed data to speed counter

parent 8a3a043c
Pipeline #8453 failed with stages in 5 minutes and 31 seconds
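For context: this commit replaces the per-agent `speed_data` dict with a `SpeedCounter` object. The counter class itself is not shown in this diff; the minimal sketch below is inferred solely from the usages that follow (`speed`, `max_count`, `is_cell_entry`, construction via `SpeedCounter(speed=...)`) and may differ from the real implementation in flatland:

```python
import numpy as np
from attr import attrs, attrib


@attrs
class SpeedCounter:
    """Sketch of the counter this commit migrates to (assumed semantics)."""
    speed = attrib(default=1.0, type=float)  # fractional speed, e.g. 1.0, 0.5, 0.25
    counter = attrib(default=0, type=int)    # steps spent inside the current cell

    @property
    def max_count(self):
        # Steps needed to traverse one cell; mirrors the expression
        # int(np.ceil(1.0 / speed_data['speed'])) it replaces below.
        return int(np.ceil(1.0 / self.speed))

    @property
    def is_cell_entry(self):
        # True exactly on entering a cell; replaces the tolerance-based
        # fast_isclose(position_fraction, 0.0, rtol=1e-03) check.
        return self.counter == 0

    def step(self):
        # Advance within the cell, wrapping back to 0 on cell exit.
        self.counter = (self.counter + 1) % self.max_count
```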
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
minimum_cell_time = agent.speed_counter.max_count
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
......
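Assuming `max_count` mirrors the `np.ceil` expression it replaces (as in the sketch above), the hunk above is a pure refactor for the standard fractional speeds:

```python
import numpy as np

for speed in (1.0, 0.5, 0.25):
    minimum_cell_time_old = int(np.ceil(1.0 / speed))
    minimum_cell_time_new = SpeedCounter(speed=speed).max_count  # sketched class
    assert minimum_cell_time_old == minimum_cell_time_new        # 1, 2, 4 steps
```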
@@ -30,6 +30,8 @@ class ControllerFromTrainrunsReplayer():
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
if agent_id == 1:
print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
actions = ctl.act(i)
print("actions for {}: {}".format(i, actions))
......
from flatland.envs.rail_trainrun_data_structures import Waypoint
import numpy as np
import warnings
from typing import Tuple, Optional, NamedTuple, List
@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('speed_data', dict),
('malfunction_data', dict),
('handle', int),
('position', Tuple[int, int]),
@@ -49,13 +49,6 @@ class EnvAgent:
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib(
@@ -67,7 +60,7 @@ class EnvAgent:
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter = attrib(default = None, type=SpeedCounter)
speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
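The removed note above ("N.B. we need to use factory since default arguments are not recreated on each call!") still applies to the new attributes: `Factory` gives each agent its own `SpeedCounter`, `ActionSaver`, and `TrainStateMachine` instead of one default shared across instances. A minimal illustration with a hypothetical `Holder` class:

```python
from attr import attrs, attrib, Factory


@attrs
class Holder:
    shared = attrib(default=[])            # evaluated once: aliased by every instance
    fresh = attrib(default=Factory(list))  # called per instance: independent lists


a, b = Holder(), Holder()
a.shared.append(1)
assert b.shared == [1]  # the shared default leaks state between instances
a.fresh.append(1)
assert b.fresh == []    # Factory avoids the aliasing
```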
@@ -94,10 +87,6 @@ class EnvAgent:
self.old_direction = None
self.moving = False
# Reset agent values for speed
self.speed_data['position_fraction'] = 0.
self.speed_data['transition_action_on_cellexit'] = 0.
# Reset agent malfunction values
self.malfunction_data['malfunction'] = 0
self.malfunction_data['nr_malfunctions'] = 0
@@ -115,7 +104,6 @@ class EnvAgent:
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
speed_data=self.speed_data,
malfunction_data=self.malfunction_data,
handle=self.handle,
state=self.state,
@@ -137,7 +125,7 @@ class EnvAgent:
distance = len(shortest_path)
else:
distance = 0
speed = self.speed_data['speed']
speed = self.speed_counter.speed
return int(np.ceil(distance / speed))
def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
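The hunk above changes only where the speed is read; the estimate itself is still cells remaining divided by fractional speed, rounded up to whole steps:

```python
import numpy as np

# e.g. 10 cells at speed 0.25 cost 40 env steps
distance, speed = 10, 0.25
assert int(np.ceil(distance / speed)) == 40
```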
@@ -161,11 +149,6 @@ class EnvAgent:
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
speed_data = {'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0
}
if line.agent_malfunction_rates is not None:
malfunction_rate = line.agent_malfunction_rates[i_agent]
@@ -177,7 +160,6 @@ class EnvAgent:
'next_malfunction': 0,
'nr_malfunctions': 0
}
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
@@ -185,7 +167,6 @@ class EnvAgent:
moving = False,
earliest_departure = None,
latest_arrival = None,
speed_data = speed_data,
malfunction_data = malfunction_data,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
@@ -195,6 +176,7 @@ class EnvAgent:
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
raise NotImplementedError("Not implemented for Flatland 3")
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
@@ -205,16 +187,35 @@ class EnvAgent:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
malfunction_data={
'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False
},
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
return agents
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
position: {self.position} direction: {self.direction} target: {self.target} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_data: {self.malfunction_data} \n \
action_saver: {self.action_saver} \n \
speed_counter: {self.speed_counter}"
@property
def state(self):
return self.state_machine.state
@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
#agents_direction = [a.direction for a in agents]
agents_direction = [a.initial_direction for a in agents]
agents_target = [a.target for a in agents]
agents_speed = [a.speed_data['speed'] for a in agents]
agents_speed = [a.speed_counter.speed for a in agents]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
......
@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder):
_agent.position:
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
speed_min_fractional=agent.speed_data['speed'],
speed_min_fractional=agent.speed_counter.speed,
num_agents_ready_to_depart=0,
childs={})
#print("root node type:", type(root_node_observation))
@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited = OrderedSet()
agent = self.env.agents[handle]
time_per_cell = np.reciprocal(agent.speed_data["speed"])
time_per_cell = np.reciprocal(agent.speed_counter.speed)
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position
if other_agent.state.is_off_map_state():
obs_agents_state[other_agent.initial_position][4] += 1
......
@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
continue
agent_virtual_direction = agent.direction
agent_speed = agent.speed_data["speed"]
agent_speed = agent.speed_counter.speed
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
......
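Note that the predictor truncates its per-cell time (`int(np.reciprocal(speed))`) while the controller earlier rounds up (`np.ceil`); for speeds that do not evenly divide 1 the two disagree, so predictions can run ahead of actual travel times:

```python
import numpy as np

for speed in (0.5, 0.4, 1 / 3):
    predictor = int(np.reciprocal(speed))   # truncates: 2, 2, 3
    controller = int(np.ceil(1.0 / speed))  # rounds up: 2, 3, 4
    print(speed, predictor, controller)
```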
@@ -261,8 +261,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return agent.state == TrainState.READY_TO_DEPART or \
(agent.state.is_on_map_state() and \
fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
(agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]:
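`is_cell_entry` swaps the tolerance-based float comparison on `position_fraction` for an exact integer check. Under the counter semantics sketched at the top of this diff, an agent is only asked for an action on the step it enters a cell:

```python
counter = SpeedCounter(speed=0.5)  # max_count == 2 under the sketch above
assert counter.is_cell_entry       # just entered a cell: action required
counter.step()
assert not counter.is_cell_entry   # mid-cell: no action required
counter.step()
assert counter.is_cell_entry       # wrapped to 0: entered the next cell
```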
@@ -344,19 +343,6 @@ class RailEnv(Environment):
# Reset agents to initial states
self.reset_agents()
# for agent in self.agents:
# # Induce malfunctions
# if activate_agents:
# self.set_agent_active(agent)
# self._break_agent(agent)
# if agent.malfunction_data["malfunction"] > 0:
# agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
# # Fix agents that finished their malfunction
# self._fix_agent_after_malfunction(agent)
self.num_resets += 1
self._elapsed_steps = 0
@@ -369,14 +355,7 @@ class RailEnv(Environment):
# Empty the episode store of agent positions
self.cur_episode = []
info_dict: Dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
info_dict = self.get_info_dict()
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
return observation_dict, info_dict
@@ -469,10 +448,12 @@ class RailEnv(Environment):
def get_info_dict(self): # TODO Important : Update this
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
return info_dict
......
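Consumers of the info dict are unaffected by the hunk above; `info['speed']` is simply sourced from `speed_counter.speed` now. A usage sketch, assuming the Flatland 3 generator module names:

```python
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator

env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(),
              line_generator=sparse_line_generator(),
              number_of_agents=2)
obs, info = env.reset()
for handle in info['speed']:
    # 'speed' now read from agent.speed_counter.speed
    print(handle, info['speed'][handle], info['state'][handle])
```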
@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()]
# Find mean_shortest_path_time
agent_speeds = [agent.speed_data['speed'] for agent in agents]
agent_speeds = [agent.speed_counter.speed for agent in agents]
agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
mean_shortest_path_time = np.mean(agent_shortest_path_times)
......