From 4d5624ef47b4377d686b63a2fa8b6802f0b9e00e Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Wed, 19 Jun 2019 18:50:47 +0200
Subject: [PATCH] added speeds in generators

---
 flatland/envs/agent_utils.py |  6 ++++--
 flatland/envs/generators.py  | 10 +++++-----
 flatland/envs/rail_env.py    |  4 ++--
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 2d07eee..aa46aec 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -47,12 +47,14 @@ class EnvAgentStatic(object):
         self.speed_data = speed_data
 
     @classmethod
-    def from_lists(cls, positions, directions, targets):
+    def from_lists(cls, positions, directions, targets, speeds=None):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
         speed_datas = []
         for i in range(len(positions)):
-            speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})
+            speed_datas.append({'position_fraction': 0.0,
+                                'speed': speeds[i] if speeds is not None else 1.0,
+                                'transition_action_on_cellexit': 0})
         return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
 
     def to_list(self):
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index f644bc1..085d6fd 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -18,7 +18,7 @@ def empty_rail_generator():
         rail_array = grid_map.grid
         rail_array.fill(0)
 
-        return grid_map, [], [], []
+        return grid_map, [], [], [], []
 
     return generator
 
@@ -139,7 +139,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
         agents_target = [sg[1] for sg in start_goal[:num_agents]]
         agents_direction = start_dir[:num_agents]
 
-        return grid_map, agents_position, agents_direction, agents_target
+        return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
 
     return generator
 
@@ -183,7 +183,7 @@ def rail_from_manual_specifications_generator(rail_spec):
             rail,
             num_agents)
 
-        return rail, agents_position, agents_direction, agents_target
+        return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
 
     return generator
 
@@ -209,7 +209,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
             rail_map,
             num_agents)
 
-        return rail_map, agents_position, agents_direction, agents_target
+        return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
 
     return generator
 
@@ -482,6 +482,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
             return_rail,
             num_agents)
 
-        return return_rail, agents_position, agents_direction, agents_target
+        return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position)
 
     return generator
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2621308..58df3a1 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -151,7 +151,7 @@ class RailEnv(Environment):
             self.rail = tRailAgents[0]
 
         if replace_agents:
-            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:4])
+            self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
 
         self.restart_agents()
 
@@ -191,7 +191,7 @@ class RailEnv(Environment):
         # for i in range(len(self.agents_handles)):
         for iAgent in range(self.get_num_agents()):
             agent = self.agents[iAgent]
-            agent.speed_data['speed']=0.5
+            print(agent.speed_data['speed'])
 
             if self.dones[iAgent]:  # this agent has already completed...
                 continue
-- 
GitLab