From e6223ce12d0cc4c7d30dddfeeba946414e7d5d0b Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Thu, 9 Sep 2021 20:30:56 +0530 Subject: [PATCH] remove list starmap init for agents --- flatland/envs/agent_utils.py | 69 ++++++++++++++++++++---------------- flatland/envs/rail_env.py | 1 - 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 91c6d72f..bf926371 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -73,8 +73,6 @@ class EnvAgent: state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , type=TrainStateMachine) malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler) - - state = attrib(default=TrainState.WAITING, type=TrainState) position = attrib(default=None, type=Optional[Tuple[int, int]]) @@ -134,35 +132,42 @@ class EnvAgent: def from_line(cls, line: Line): """ Create a list of EnvAgent from lists of positions, directions and targets """ - speed_datas = [] - speed_counters = [] - for i in range(len(line.agent_positions)): - speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0 - speed_datas.append({'position_fraction': 0.0, - 'speed': speed, - 'transition_action_on_cellexit': 0}) - speed_counters.append( SpeedCounter(speed=speed) ) - - malfunction_datas = [] - for i in range(len(line.agent_positions)): - malfunction_datas.append({'malfunction': 0, - 'malfunction_rate': line.agent_malfunction_rates[ - i] if line.agent_malfunction_rates is not None else 0., - 'next_malfunction': 0, - 'nr_malfunctions': 0}) + num_agents = len(line.agent_positions) - return list(starmap(EnvAgent, zip(line.agent_positions, # TODO : Dipam - Really want to change this way of loading agents - line.agent_directions, - line.agent_directions, - line.agent_targets, - [False] * len(line.agent_positions), - [None] * len(line.agent_positions), # earliest_departure - [None] * len(line.agent_positions), # latest_arrival - speed_datas, - malfunction_datas, - range(len(line.agent_positions)), - speed_counters, - ))) + agent_list = [] + for i_agent in range(num_agents): + speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0 + + speed_data = {'position_fraction': 0.0, + 'speed': speed, + 'transition_action_on_cellexit': 0 + } + + if line.agent_malfunction_rates is not None: + malfunction_rate = line.agent_malfunction_rates[i_agent] + else: + malfunction_rate = 0. + + malfunction_data = {'malfunction': 0, + 'malfunction_rate': malfunction_rate, + 'next_malfunction': 0, + 'nr_malfunctions': 0 + } + + agent = EnvAgent(initial_position = line.agent_positions[i_agent], + initial_direction = line.agent_directions[i_agent], + direction = line.agent_directions[i_agent], + target = line.agent_targets[i_agent], + moving = False, + earliest_departure = None, + latest_arrival = None, + speed_data = speed_data, + malfunction_data = malfunction_data, + handle = i_agent, + speed_counter = SpeedCounter(speed=speed)) + agent_list.append(agent) + + return agent_list @classmethod def load_legacy_static_agent(cls, static_agents_data: Tuple): @@ -185,3 +190,7 @@ class EnvAgent: handle=i) agents.append(agent) return agents + + @property + def state(self): + return self.state_machine.state diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4181482b..c8f75908 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -561,7 +561,6 @@ class RailEnv(Environment): state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed) agent.state_machine.set_transition_signals(state_transition_signals) agent.state_machine.step() - agent.state = agent.state_machine.state # TODO : Make this a property instead? # Remove agent is required if self.remove_agents_at_target and agent.state == TrainState.DONE: -- GitLab