Commit 64ca6836 authored by u214892's avatar u214892
Browse files

bugfix speeddate shared over all instances

parent d2c0ce1c
Pipeline #1204 passed with stage
in 7 minutes and 50 seconds
from itertools import starmap
import numpy as np
from attr import attrs, attrib
from attr import attrs, attrib, Factory
@attrs
......@@ -18,7 +18,9 @@ class EnvAgentStatic(object):
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
......
......@@ -12,8 +12,8 @@ import msgpack
import numpy as np
from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
......@@ -196,7 +196,7 @@ class RailEnv(Environment):
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
if iAgent % 2 == 0:
agent.speed_data["speed"] = 1./10.
agent.speed_data["speed"] = 1. / 10.
if self.dones[iAgent]: # this agent has already completed...
continue
......@@ -277,7 +277,6 @@ class RailEnv(Environment):
if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
......@@ -292,8 +291,6 @@ class RailEnv(Environment):
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment