diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 5eadb9332a4a084ad030a46dc18bab81f33ac4e0..e353af29ddbee16c208e2059767c18fa7880cb64 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,7 +1,7 @@ from itertools import starmap import numpy as np -from attr import attrs, attrib +from attr import attrs, attrib, Factory @attrs @@ -18,7 +18,9 @@ class EnvAgentStatic(object): # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous # cell if speed=1, as default) - speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})) + # N.B. we need to use factory since default arguments are not recreated on each call! + speed_data = attrib( + default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))) @classmethod def from_lists(cls, positions, directions, targets, speeds=None): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d6a7cfaca4a034303427688d9ab116a4121e9992..b35865a115af385cb708ac4f06ffd32d555dcbe2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -12,8 +12,8 @@ import msgpack import numpy as np from flatland.core.env import Environment -from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv @@ -196,7 +196,7 @@ class RailEnv(Environment): for iAgent in range(self.get_num_agents()): agent = self.agents[iAgent] if iAgent % 2 == 0: - agent.speed_data["speed"] = 1./10. + agent.speed_data["speed"] = 1. / 10. if self.dones[iAgent]: # this agent has already completed... continue @@ -277,7 +277,6 @@ class RailEnv(Environment): if agent.speed_data['position_fraction'] >= 1.0: - # Perform stored action to transition to the next cell # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering @@ -292,8 +291,6 @@ class RailEnv(Environment): agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 - - if np.equal(agent.position, agent.target).all(): self.dones[iAgent] = True else: