From e6223ce12d0cc4c7d30dddfeeba946414e7d5d0b Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Thu, 9 Sep 2021 20:30:56 +0530
Subject: [PATCH] remove list starmap init for agents

---
 flatland/envs/agent_utils.py | 69 ++++++++++++++++++++----------------
 flatland/envs/rail_env.py    |  1 -
 2 files changed, 39 insertions(+), 31 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 91c6d72f..bf926371 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -73,8 +73,6 @@ class EnvAgent:
     state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , 
                            type=TrainStateMachine)
     malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
-    
-    state = attrib(default=TrainState.WAITING, type=TrainState)
 
     position = attrib(default=None, type=Optional[Tuple[int, int]])
 
@@ -134,35 +132,42 @@ class EnvAgent:
     def from_line(cls, line: Line):
         """ Create a list of EnvAgent from lists of positions, directions and targets
         """
-        speed_datas = []
-        speed_counters = []
-        for i in range(len(line.agent_positions)):
-            speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
-            speed_datas.append({'position_fraction': 0.0,
-                                'speed': speed,
-                                'transition_action_on_cellexit': 0})
-            speed_counters.append( SpeedCounter(speed=speed) )
-
-        malfunction_datas = []
-        for i in range(len(line.agent_positions)):
-            malfunction_datas.append({'malfunction': 0,
-                                      'malfunction_rate': line.agent_malfunction_rates[
-                                          i] if line.agent_malfunction_rates is not None else 0.,
-                                      'next_malfunction': 0,
-                                      'nr_malfunctions': 0})
+        num_agents = len(line.agent_positions)
         
-        return list(starmap(EnvAgent, zip(line.agent_positions,  # TODO : Dipam - Really want to change this way of loading agents
-                                          line.agent_directions,
-                                          line.agent_directions,
-                                          line.agent_targets, 
-                                          [False] * len(line.agent_positions), 
-                                          [None] * len(line.agent_positions), # earliest_departure
-                                          [None] * len(line.agent_positions), # latest_arrival
-                                          speed_datas,
-                                          malfunction_datas,
-                                          range(len(line.agent_positions)),
-                                          speed_counters,
-                                          )))
+        agent_list = []
+        for i_agent in range(num_agents):
+            speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
+
+            speed_data = {'position_fraction': 0.0,
+                           'speed': speed,
+                           'transition_action_on_cellexit': 0
+                          }
+            
+            if line.agent_malfunction_rates is not None:
+                malfunction_rate = line.agent_malfunction_rates[i_agent]
+            else:
+                malfunction_rate = 0.
+            
+            malfunction_data = {'malfunction': 0,
+                                'malfunction_rate': malfunction_rate,
+                                'next_malfunction': 0,
+                                'nr_malfunctions': 0
+                               }
+            
+            agent = EnvAgent(initial_position = line.agent_positions[i_agent],
+                            initial_direction = line.agent_directions[i_agent],
+                            direction = line.agent_directions[i_agent],
+                            target = line.agent_targets[i_agent], 
+                            moving = False, 
+                            earliest_departure = None,
+                            latest_arrival = None,
+                            speed_data = speed_data,
+                            malfunction_data = malfunction_data,
+                            handle = i_agent,
+                            speed_counter = SpeedCounter(speed=speed))
+            agent_list.append(agent)
+
+        return agent_list
 
     @classmethod
     def load_legacy_static_agent(cls, static_agents_data: Tuple):
@@ -185,3 +190,7 @@ class EnvAgent:
                                 handle=i)
             agents.append(agent)
         return agents
+    
+    @property
+    def state(self):
+        return self.state_machine.state
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 4181482b..c8f75908 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -561,7 +561,6 @@ class RailEnv(Environment):
             state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
             agent.state_machine.set_transition_signals(state_transition_signals)
             agent.state_machine.step()
-            agent.state = agent.state_machine.state # TODO : Make this a property instead?
 
             # Remove agent is required
             if self.remove_agents_at_target and agent.state == TrainState.DONE:
-- 
GitLab