diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
index 249c4c0ee12bb8a79c06842a59108bd4f3ce6c5c..b1a56d81839bff62f13a27753a935a19a8d05fe9 100644
--- a/flatland/action_plan/action_plan.py
+++ b/flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
+        minimum_cell_time = agent.speed_counter.max_count
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
 
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 4dee6dde0f5a938d81e5cd970332223a9f6b841b..6dff63e18e505d6ff7cb8280b53f63178c3f1921 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,5 +1,6 @@
 from flatland.envs.rail_trainrun_data_structures import Waypoint
 import numpy as np
+import warnings
 
 from typing import Tuple, Optional, NamedTuple, List
 
@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
                              ('moving', bool),
                              ('earliest_departure', int),
                              ('latest_arrival', int),
-                             ('speed_data', dict),
                              ('malfunction_data', dict),
                              ('handle', int),
                              ('position', Tuple[int, int]),
@@ -49,13 +49,6 @@ class EnvAgent:
     earliest_departure = attrib(default=None, type=int)  # default None during _from_line()
     latest_arrival = attrib(default=None, type=int)  # default None during _from_line()
 
-    # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
-    # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
-    # cell if speed=1, as default)
-    # N.B. we need to use factory since default arguments are not recreated on each call!
-    speed_data = attrib(
-        default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
-
     # if broken>0, the agent's actions are ignored for 'broken' steps
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
@@ -67,7 +60,7 @@ class EnvAgent:
     # INIT TILL HERE IN _from_line()
 
     # Env step facelift
-    speed_counter = attrib(default = None, type=SpeedCounter)
+    speed_counter = attrib(default = Factory(lambda: SpeedCounter(1.0)), type=SpeedCounter)
     action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver)
     state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
                            type=TrainStateMachine)
@@ -94,10 +87,6 @@ class EnvAgent:
         self.old_direction = None
         self.moving = False
 
-        # Reset agent values for speed
-        self.speed_data['position_fraction'] = 0.
-        self.speed_data['transition_action_on_cellexit'] = 0.
-
         # Reset agent malfunction values
         self.malfunction_data['malfunction'] = 0
         self.malfunction_data['nr_malfunctions'] = 0
@@ -115,7 +104,6 @@ class EnvAgent:
                      moving=self.moving,
                      earliest_departure=self.earliest_departure,
                      latest_arrival=self.latest_arrival,
-                     speed_data=self.speed_data,
                      malfunction_data=self.malfunction_data,
                      handle=self.handle,
                      state=self.state,
@@ -137,7 +125,7 @@ class EnvAgent:
             distance = len(shortest_path)
         else:
             distance = 0
-        speed = self.speed_data['speed']
+        speed = self.speed_counter.speed
         return int(np.ceil(distance / speed))
 
     def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int:
@@ -161,11 +149,6 @@ class EnvAgent:
         agent_list = []
         for i_agent in range(num_agents):
             speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
-
-            speed_data = {'position_fraction': 0.0,
-                          'speed': speed,
-                          'transition_action_on_cellexit': 0
-                          }
 
             if line.agent_malfunction_rates is not None:
                 malfunction_rate = line.agent_malfunction_rates[i_agent]
@@ -177,7 +160,6 @@ class EnvAgent:
                                  'next_malfunction': 0,
                                  'nr_malfunctions': 0
                                  }
-
             agent = EnvAgent(initial_position = line.agent_positions[i_agent],
                              initial_direction = line.agent_directions[i_agent],
                              direction = line.agent_directions[i_agent],
@@ -185,7 +167,6 @@ class EnvAgent:
                              moving = False,
                              earliest_departure = None,
                              latest_arrival = None,
-                             speed_data = speed_data,
                              malfunction_data = malfunction_data,
                              handle = i_agent,
                              speed_counter = SpeedCounter(speed=speed))
@@ -195,6 +176,7 @@ class EnvAgent:
 
     @classmethod
     def load_legacy_static_agent(cls, static_agents_data: Tuple):
+        raise NotImplementedError("Not implemented for Flatland 3")
         agents = []
         for i, static_agent in enumerate(static_agents_data):
             if len(static_agent) >= 6:
@@ -205,16 +187,31 @@ class EnvAgent:
                 agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
                                  direction=static_agent[1], target=static_agent[2], moving=False,
-                                 speed_data={"speed":1., "position_fraction":0., "transition_action_on_cell_exit":0.},
                                  malfunction_data={
                                      'malfunction': 0,
                                      'nr_malfunctions': 0,
                                      'moving_before_malfunction': False
                                  },
+                                 speed_counter=SpeedCounter(1.0),
                                  handle=i)
             agents.append(agent)
         return agents
 
+    def _set_state(self, state):
+        warnings.warn("Setting the state directly is not recommended; use the state machine transitions unless strictly necessary")
+        self.state_machine.set_state(state)
+
+    def __str__(self):
+        return f"\n \
+                handle(agent index): {self.handle} \n \
+                initial_position: {self.initial_position}  initial_direction: {self.initial_direction} \n \
+                position: {self.position}  direction: {self.direction}  target: {self.target} \n \
+                earliest_departure: {self.earliest_departure}  latest_arrival: {self.latest_arrival} \n \
+                state: {str(self.state)} \n \
+                malfunction_data: {self.malfunction_data} \n \
+                action_saver: {self.action_saver} \n \
+                speed_counter: {self.speed_counter}"
+
     @property
     def state(self):
         return self.state_machine.state
diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py
index 74d01e6f23856e9f14d2fbe70eb2bdbfb85175be..8b412783ca999c5383e102804928888d43aee32a 100644
--- a/flatland/envs/line_generators.py
+++ b/flatland/envs/line_generators.py
@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
     #agents_direction = [a.direction for a in agents]
     agents_direction = [a.initial_direction for a in agents]
     agents_target = [a.target for a in agents]
-    agents_speed = [a.speed_data['speed'] for a in agents]
+    agents_speed = [a.speed_counter.speed for a in agents]
 
     # Malfunctions from here are not used. They have their own generator.
     #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 1fc0a2e52faf3b228b46b2fd896852ba4c411f26..456d56a0c58fdbaefa5a2ff4c5e938b74618e1c1 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     _agent.position:
                 self.location_has_agent[tuple(_agent.position)] = 1
                 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
-                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
+                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
                 self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
                     'malfunction']
 
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                   agent.direction)],
                                                        num_agents_same_direction=0, num_agents_opposite_direction=0,
                                                        num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                                       speed_min_fractional=agent.speed_data['speed'],
+                                                       speed_min_fractional=agent.speed_counter.speed,
                                                        num_agents_ready_to_depart=0,
                                                        childs={})
         #print("root node type:", type(root_node_observation))
@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         visited = OrderedSet()
         agent = self.env.agents[handle]
 
-        time_per_cell = np.reciprocal(agent.speed_data["speed"])
+        time_per_cell = np.reciprocal(agent.speed_counter.speed)
         own_target_encountered = np.inf
         other_agent_encountered = np.inf
         other_target_encountered = np.inf
@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
             if i != handle:
                 obs_agents_state[other_agent.position][1] = other_agent.direction
                 obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
             # fifth channel: all ready to depart on this position
             if other_agent.state.is_off_map_state():
                 obs_agents_state[other_agent.initial_position][4] += 1
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 4b097083fa98c5121df644b8f9d34b27fdc34a4b..8f6a191a7eec5ba0dfb44b1f8671f9841b01ff5b 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 continue
 
             agent_virtual_direction = agent.direction
-            agent_speed = agent.speed_data["speed"]
+            agent_speed = agent.speed_counter.speed
             times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 46876ac953535f4c49b57036045b405c6b986cc3..2915e9be2b1d9537631f0639a6f20a9f05955d17 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -261,8 +261,7 @@ class RailEnv(Environment):
             False: Agent cannot provide an action
         """
         return agent.state == TrainState.READY_TO_DEPART or \
-               (agent.state.is_on_map_state() and \
-                fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
+               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry)
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
               random_seed: bool = None) -> Tuple[Dict, Dict]:
@@ -344,19 +343,6 @@ class RailEnv(Environment):
         # Reset agents to initial states
         self.reset_agents()
 
-        # for agent in self.agents:
-        #     # Induce malfunctions
-        #     if activate_agents:
-        #         self.set_agent_active(agent)
-
-        #     self._break_agent(agent)
-
-        #     if agent.malfunction_data["malfunction"] > 0:
-        #         agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
-
-        #     # Fix agents that finished their malfunction
-        #     self._fix_agent_after_malfunction(agent)
-
         self.num_resets += 1
         self._elapsed_steps = 0
 
@@ -369,14 +355,7 @@ class RailEnv(Environment):
         # Empty the episode store of agent positions
         self.cur_episode = []
 
-        info_dict: Dict = {
-            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
-            'malfunction': {
-                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
-            },
-            'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
-            'state': {i: agent.state for i, agent in enumerate(self.agents)}
-        }
+        info_dict = self.get_info_dict()
         # Return the new observation vectors for each agent
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
@@ -469,10 +448,11 @@ class RailEnv(Environment):
     def get_info_dict(self):
-        # TODO Important : Update this
         info_dict = {
-            "action_required": {},
-            "malfunction": {},
-            "speed": {},
-            "status": {},
+            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
+            'malfunction': {
+                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
+            },
+            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
+            'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict
 
diff --git a/flatland/envs/timetable_generators.py b/flatland/envs/timetable_generators.py
index b7876d742f61db830883f828faaf99a39a48bc65..d93c09199b315c488177febe4d1aa423b7a87894 100644
--- a/flatland/envs/timetable_generators.py
+++ b/flatland/envs/timetable_generators.py
@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
     shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()]
 
     # Find mean_shortest_path_time
-    agent_speeds = [agent.speed_data['speed'] for agent in agents]
+    agent_speeds = [agent.speed_counter.speed for agent in agents]
     agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
     mean_shortest_path_time = np.mean(agent_shortest_path_times)
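
Note on SpeedCounter: the patch reads three members of the class (speed, max_count, is_cell_entry) but does not include the class itself. Below is a minimal sketch consistent with those call sites, assuming max_count mirrors the old ceil(1/speed) formula replaced in action_plan.py and that the counter wraps once per traversed cell; the counter attribute and step() hook are illustrative assumptions, not the library's actual API.

import numpy as np

class SpeedCounter:
    """Per-agent step counter replacing the old speed_data dict:
    speed_data['speed'] -> speed, ceil(1/speed) -> max_count,
    position_fraction == 0.0 -> is_cell_entry."""

    def __init__(self, speed=1.0):
        self.speed = speed
        # Assumption: mirrors the old minimum_cell_time formula in action_plan.py.
        self.max_count = int(np.ceil(1.0 / speed))
        self.counter = 0  # steps spent in the current cell (hypothetical name)

    @property
    def is_cell_entry(self):
        # True only on the step the agent entered its current cell, which is
        # when action_required() can report True for an on-map agent.
        return self.counter == 0

    def step(self):
        # Hypothetical per-env-step update: wrap when the cell is fully traversed.
        self.counter = (self.counter + 1) % self.max_count

    def __repr__(self):
        return f"SpeedCounter(speed={self.speed}, counter={self.counter}/{self.max_count})"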
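Note on the new default: speed_counter moves from None to Factory(lambda: SpeedCounter(1.0)) for the same reason the removed speed_data comment gave ("we need to use factory since default arguments are not recreated on each call"). With attrs, a plain default value is evaluated once and shared by every instance, while Factory builds a fresh object per instance. A small self-contained illustration (Example is a made-up class, not part of the patch):

from attr import attrs, attrib, Factory

@attrs
class Example:
    shared = attrib(default=[])                   # one list shared by all instances
    per_instance = attrib(default=Factory(list))  # a new list per instance

a, b = Example(), Example()
a.shared.append(1)
assert b.shared == [1]        # b sees the mutation through the shared default
a.per_instance.append(1)
assert b.per_instance == []   # Factory kept the instances independent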
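Note on the info dict: reset() now delegates to get_info_dict(), so the dict has the same keys after reset and step ('action_required', 'malfunction', 'speed', 'state'). A sketch of a control loop driven by it, assuming an already-constructed RailEnv and using a random policy purely for illustration:

import numpy as np
from flatland.envs.rail_env import RailEnv, RailEnvActions

def run_episode(env: RailEnv, max_steps: int = 100):
    obs, info = env.reset()  # info now comes from env.get_info_dict()
    for _ in range(max_steps):
        # Act only where it can take effect: READY_TO_DEPART agents, or on-map
        # agents at a cell entry (agent.speed_counter.is_cell_entry).
        actions = {
            handle: int(np.random.randint(len(RailEnvActions)))
            for handle, required in info['action_required'].items()
            if required
        }
        obs, rewards, dones, info = env.step(actions)
        if dones['__all__']:
            break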