diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index feb72313f21b9ecc989688d63ba02ccf3a458107..c9c94fb9661b9a46cb8e5eae6268ef4a9a2974db 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,3 +1,5 @@ +import collections + from flatland.core.grid.grid4_utils import validate_new_transition @@ -25,6 +27,54 @@ class AStarNode(): self.f = other.f + +# in order for enumeration to be deterministic for testing purposes +# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set +class OrderedSet(collections.OrderedDict, collections.MutableSet): + + def update(self, *args, **kwargs): + if kwargs: + raise TypeError("update() takes no keyword arguments") + + for s in args: + for e in s: + self.add(e) + + def add(self, elem): + self[elem] = None + + def discard(self, elem): + self.pop(elem, None) + + def __le__(self, other): + return all(e in other for e in self) + + def __lt__(self, other): + return self <= other and self != other + + def __ge__(self, other): + return all(e in self for e in other) + + def __gt__(self, other): + return self >= other and self != other + + def __repr__(self): + return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys()))) + + def __str__(self): + return '{%s}' % (', '.join(map(repr, self.keys()))) + + difference = property(lambda self: self.__sub__) + difference_update = property(lambda self: self.__isub__) + intersection = property(lambda self: self.__and__) + intersection_update = property(lambda self: self.__iand__) + issubset = property(lambda self: self.__le__) + issuperset = property(lambda self: self.__ge__) + symmetric_difference = property(lambda self: self.__xor__) + symmetric_difference_update = property(lambda self: self.__ixor__) + union = property(lambda self: self.__or__) + + def a_star(rail_trans, rail_array, start, end): """ Returns a list of tuples as a path from the given start to end. @@ -33,12 +83,12 @@ def a_star(rail_trans, rail_array, start, end): rail_shape = rail_array.shape start_node = AStarNode(None, start) end_node = AStarNode(None, end) - open_nodes = set() - closed_nodes = set() + open_nodes = OrderedSet() + closed_nodes = OrderedSet() open_nodes.add(start_node) while len(open_nodes) > 0: - # get node with current shortest est. path (lowest f) + # get node with current shortest path (lowest f) current_node = None for item in open_nodes: if current_node is None: diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index fb191fd9cb10b86ee4b5e66c8c52b123648833cd..70be7a0e86efe2db64025d551a4d91b4dfcc0f13 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -126,6 +126,7 @@ def test_malfunction_process_statistically(): 'min_duration': 3, 'max_duration': 3} np.random.seed(5) + random.seed(0) env = RailEnv(width=20, height=20, @@ -150,11 +151,13 @@ def test_malfunction_process_statistically(): # check that generation of malfunctions works as expected # results are different in py36 and py37, therefore no exact test on nb_malfunction - assert nb_malfunction > 150 + assert nb_malfunction == 149, "nb_malfunction={}".format(nb_malfunction) def test_initial_malfunction(rendering=True): random.seed(0) + np.random.seed(0) + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 70, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -193,32 +196,32 @@ def test_initial_malfunction(rendering=True): replay_steps = [ Replay( - position=(27, 5), + position=(28, 5), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=3 ), Replay( - position=(27, 5), + position=(28, 5), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=2 ), Replay( - position=(27, 5), + position=(28, 5), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=1 ), Replay( - position=(27, 4), + position=(28, 4), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, malfunction=0 ), Replay( - position=(27, 3), - direction=Grid4TransitionsEnum.WEST, + position=(27, 4), + direction=Grid4TransitionsEnum.NORTH, action=RailEnvActions.MOVE_FORWARD, malfunction=0 )