Skip to content
Snippets Groups Projects
Commit 77e447fe authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch 'bugfix-speeddata-shared' into 'master'

bugfix speeddate shared over all instances

See merge request flatland/flatland!79
parents d2c0ce1c 64ca6836
No related branches found
No related tags found
No related merge requests found
from itertools import starmap from itertools import starmap
import numpy as np import numpy as np
from attr import attrs, attrib from attr import attrs, attrib, Factory
@attrs @attrs
...@@ -18,7 +18,9 @@ class EnvAgentStatic(object): ...@@ -18,7 +18,9 @@ class EnvAgentStatic(object):
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default) # cell if speed=1, as default)
speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})) # N.B. we need to use factory since default arguments are not recreated on each call!
speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
@classmethod @classmethod
def from_lists(cls, positions, directions, targets, speeds=None): def from_lists(cls, positions, directions, targets, speeds=None):
......
...@@ -12,8 +12,8 @@ import msgpack ...@@ -12,8 +12,8 @@ import msgpack
import numpy as np import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
...@@ -196,7 +196,7 @@ class RailEnv(Environment): ...@@ -196,7 +196,7 @@ class RailEnv(Environment):
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[iAgent]
if iAgent % 2 == 0: if iAgent % 2 == 0:
agent.speed_data["speed"] = 1./10. agent.speed_data["speed"] = 1. / 10.
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
continue continue
...@@ -277,7 +277,6 @@ class RailEnv(Environment): ...@@ -277,7 +277,6 @@ class RailEnv(Environment):
if agent.speed_data['position_fraction'] >= 1.0: if agent.speed_data['position_fraction'] >= 1.0:
# Perform stored action to transition to the next cell # Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
...@@ -292,8 +291,6 @@ class RailEnv(Environment): ...@@ -292,8 +291,6 @@ class RailEnv(Environment):
agent.direction = new_direction agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all(): if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True self.dones[iAgent] = True
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment