From 64ca683633e258cc1158b2bb78c7913532ec2beb Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Sat, 22 Jun 2019 19:48:14 +0200
Subject: [PATCH] bugfix speeddate shared over all instances

---
 flatland/envs/agent_utils.py | 6 ++++--
 flatland/envs/rail_env.py    | 7 ++-----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 5eadb933..e353af29 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,7 +1,7 @@
 from itertools import starmap
 
 import numpy as np
-from attr import attrs, attrib
+from attr import attrs, attrib, Factory
 
 
 @attrs
@@ -18,7 +18,9 @@ class EnvAgentStatic(object):
     # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
     # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
     # cell if speed=1, as default)
-    speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))
+    # N.B. we need to use factory since default arguments are not recreated on each call!
+    speed_data = attrib(
+        default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
 
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d6a7cfac..b35865a1 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -12,8 +12,8 @@ import msgpack
 import numpy as np
 
 from flatland.core.env import Environment
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 
@@ -196,7 +196,7 @@ class RailEnv(Environment):
         for iAgent in range(self.get_num_agents()):
             agent = self.agents[iAgent]
             if iAgent % 2 == 0:
-                agent.speed_data["speed"] = 1./10.
+                agent.speed_data["speed"] = 1. / 10.
             if self.dones[iAgent]:  # this agent has already completed...
                 continue
 
@@ -277,7 +277,6 @@ class RailEnv(Environment):
 
             if agent.speed_data['position_fraction'] >= 1.0:
 
-
                 # Perform stored action to transition to the next cell
 
                 # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
@@ -292,8 +291,6 @@ class RailEnv(Environment):
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
 
-
-
             if np.equal(agent.position, agent.target).all():
                 self.dones[iAgent] = True
             else:
-- 
GitLab