Skip to content
Snippets Groups Projects
Commit 9bebfe03 authored by u214892's avatar u214892
Browse files

#178 bugfix initial malfunction

parent 096fd933
No related branches found
No related tags found
No related merge requests found
import collections
from flatland.core.grid.grid4_utils import validate_new_transition from flatland.core.grid.grid4_utils import validate_new_transition
...@@ -25,6 +27,54 @@ class AStarNode(): ...@@ -25,6 +27,54 @@ class AStarNode():
self.f = other.f self.f = other.f
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
class OrderedSet(collections.OrderedDict, collections.MutableSet):
def update(self, *args, **kwargs):
if kwargs:
raise TypeError("update() takes no keyword arguments")
for s in args:
for e in s:
self.add(e)
def add(self, elem):
self[elem] = None
def discard(self, elem):
self.pop(elem, None)
def __le__(self, other):
return all(e in other for e in self)
def __lt__(self, other):
return self <= other and self != other
def __ge__(self, other):
return all(e in self for e in other)
def __gt__(self, other):
return self >= other and self != other
def __repr__(self):
return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys())))
def __str__(self):
return '{%s}' % (', '.join(map(repr, self.keys())))
difference = property(lambda self: self.__sub__)
difference_update = property(lambda self: self.__isub__)
intersection = property(lambda self: self.__and__)
intersection_update = property(lambda self: self.__iand__)
issubset = property(lambda self: self.__le__)
issuperset = property(lambda self: self.__ge__)
symmetric_difference = property(lambda self: self.__xor__)
symmetric_difference_update = property(lambda self: self.__ixor__)
union = property(lambda self: self.__or__)
def a_star(rail_trans, rail_array, start, end): def a_star(rail_trans, rail_array, start, end):
""" """
Returns a list of tuples as a path from the given start to end. Returns a list of tuples as a path from the given start to end.
...@@ -33,12 +83,12 @@ def a_star(rail_trans, rail_array, start, end): ...@@ -33,12 +83,12 @@ def a_star(rail_trans, rail_array, start, end):
rail_shape = rail_array.shape rail_shape = rail_array.shape
start_node = AStarNode(None, start) start_node = AStarNode(None, start)
end_node = AStarNode(None, end) end_node = AStarNode(None, end)
open_nodes = set() open_nodes = OrderedSet()
closed_nodes = set() closed_nodes = OrderedSet()
open_nodes.add(start_node) open_nodes.add(start_node)
while len(open_nodes) > 0: while len(open_nodes) > 0:
# get node with current shortest est. path (lowest f) # get node with current shortest path (lowest f)
current_node = None current_node = None
for item in open_nodes: for item in open_nodes:
if current_node is None: if current_node is None:
......
...@@ -126,6 +126,7 @@ def test_malfunction_process_statistically(): ...@@ -126,6 +126,7 @@ def test_malfunction_process_statistically():
'min_duration': 3, 'min_duration': 3,
'max_duration': 3} 'max_duration': 3}
np.random.seed(5) np.random.seed(5)
random.seed(0)
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
...@@ -150,11 +151,13 @@ def test_malfunction_process_statistically(): ...@@ -150,11 +151,13 @@ def test_malfunction_process_statistically():
# check that generation of malfunctions works as expected # check that generation of malfunctions works as expected
# results are different in py36 and py37, therefore no exact test on nb_malfunction # results are different in py36 and py37, therefore no exact test on nb_malfunction
assert nb_malfunction > 150 assert nb_malfunction == 149, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction(rendering=True): def test_initial_malfunction(rendering=True):
random.seed(0) random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence 'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction 'min_duration': 2, # Minimal duration of malfunction
...@@ -193,32 +196,32 @@ def test_initial_malfunction(rendering=True): ...@@ -193,32 +196,32 @@ def test_initial_malfunction(rendering=True):
replay_steps = [ replay_steps = [
Replay( Replay(
position=(27, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=3 malfunction=3
), ),
Replay( Replay(
position=(27, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=2 malfunction=2
), ),
Replay( Replay(
position=(27, 5), position=(28, 5),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=1 malfunction=1
), ),
Replay( Replay(
position=(27, 4), position=(28, 4),
direction=Grid4TransitionsEnum.WEST, direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0 malfunction=0
), ),
Replay( Replay(
position=(27, 3), position=(27, 4),
direction=Grid4TransitionsEnum.WEST, direction=Grid4TransitionsEnum.NORTH,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0 malfunction=0
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment