Commit 8e8f91a5 authored by u214892's avatar u214892
Browse files

#178 bugfix initial malfunction

parent 78b1f9ee
......@@ -18,6 +18,7 @@ for entry in [entry for entry in importlib_resources.contents('examples') if
with path('examples', entry) as file_in:
print("")
print("")
print("")
print("*****************************************************************")
print("Running {}".format(entry))
......
import collections
from flatland.core.grid.grid4_utils import validate_new_transition
from flatland.utils.ordered_set import OrderedSet
class AStarNode():
......@@ -27,54 +26,6 @@ class AStarNode():
self.f = other.f
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
class OrderedSet(collections.OrderedDict, collections.MutableSet):
def update(self, *args, **kwargs):
if kwargs:
raise TypeError("update() takes no keyword arguments")
for s in args:
for e in s:
self.add(e)
def add(self, elem):
self[elem] = None
def discard(self, elem):
self.pop(elem, None)
def __le__(self, other):
return all(e in other for e in self)
def __lt__(self, other):
return self <= other and self != other
def __ge__(self, other):
return all(e in self for e in other)
def __gt__(self, other):
return self >= other and self != other
def __repr__(self):
return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys())))
def __str__(self):
return '{%s}' % (', '.join(map(repr, self.keys())))
difference = property(lambda self: self.__sub__)
difference_update = property(lambda self: self.__isub__)
intersection = property(lambda self: self.__and__)
intersection_update = property(lambda self: self.__iand__)
issubset = property(lambda self: self.__le__)
issuperset = property(lambda self: self.__ge__)
symmetric_difference = property(lambda self: self.__xor__)
symmetric_difference_update = property(lambda self: self.__ixor__)
union = property(lambda self: self.__or__)
def a_star(rail_trans, rail_array, start, end):
"""
Returns a list of tuples as a path from the given start to end.
......
......@@ -10,6 +10,7 @@ from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions
from flatland.utils.ordered_set import OrderedSet
class TransitionMap:
......@@ -336,7 +337,7 @@ class GridTransitionMap(TransitionMap):
tmp = self.get_full_transitions(rcPos[0], rcPos[1])
def is_simple_turn(trans):
all_simple_turns = set()
all_simple_turns = OrderedSet()
for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right
int('0001001000000000', 2) # Case 1c (9) - simple turn left]:
]:
......@@ -351,7 +352,7 @@ class GridTransitionMap(TransitionMap):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
visited = set()
visited = OrderedSet()
stack = [(start, direction)]
while stack:
node = stack.pop()
......
......@@ -9,6 +9,7 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
......@@ -279,7 +280,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
agent.malfunction_data['malfunction'], agent.speed_data['speed']]
visited = set()
visited = OrderedSet()
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
......@@ -295,7 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
visited |= branch_visited
else:
# add cells filled with infinity if no transition is possible
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
......@@ -332,7 +333,7 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
last_is_target = False
visited = set()
visited = OrderedSet()
agent = self.env.agents[handle]
time_per_cell = np.reciprocal(agent.speed_data["speed"])
own_target_encountered = np.inf
......@@ -545,7 +546,7 @@ class TreeObsForRailEnv(ObservationBuilder):
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
visited |= branch_visited
elif last_is_switch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
branch_observation, branch_visited = self._explore_branch(handle,
......@@ -555,7 +556,7 @@ class TreeObsForRailEnv(ObservationBuilder):
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
visited |= branch_visited
else:
# no exploring possible, add just cells with infinity
observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
......
......@@ -7,6 +7,7 @@ import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.ordered_set import OrderedSet
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -130,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set()
visited = OrderedSet()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
......
......@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
# TODO: _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
from typing import List, Set, NamedTuple
from typing import List, Set, NamedTuple, Optional
import msgpack
import msgpack_numpy as m
......@@ -18,6 +18,7 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
from flatland.utils.ordered_set import OrderedSet
m.patch()
......@@ -153,7 +154,7 @@ class RailEnv(Environment):
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
self.rail_generator = rail_generator
self.rail: GridTransitionMap = None
self.rail: Optional[GridTransitionMap] = None
self.width = width
self.height = height
......@@ -549,7 +550,7 @@ class RailEnv(Environment):
return new_direction, transition_valid
def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]:
valid_actions: Set[RailEnvNextAction] = set()
valid_actions: Set[RailEnvNextAction] = OrderedSet()
agent_position = agent.position
agent_direction = agent.direction
possible_transitions = self.rail.get_transitions(*agent_position, agent_direction)
......
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
from collections import OrderedDict
from collections.abc import MutableSet
class OrderedSet(OrderedDict, MutableSet):
def update(self, *args, **kwargs):
if kwargs:
raise TypeError("update() takes no keyword arguments")
for s in args:
for e in s:
self.add(e)
def add(self, elem):
self[elem] = None
def discard(self, elem):
self.pop(elem, None)
def __le__(self, other):
return all(e in other for e in self)
def __lt__(self, other):
return self <= other and self != other
def __ge__(self, other):
return all(e in self for e in other)
def __gt__(self, other):
return self >= other and self != other
def __repr__(self):
return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys())))
def __str__(self):
return '{%s}' % (', '.join(map(repr, self.keys())))
difference = property(lambda self: self.__sub__)
difference_update = property(lambda self: self.__isub__)
intersection = property(lambda self: self.__and__)
intersection_update = property(lambda self: self.__iand__)
issubset = property(lambda self: self.__le__)
issuperset = property(lambda self: self.__ge__)
symmetric_difference = property(lambda self: self.__xor__)
symmetric_difference_update = property(lambda self: self.__ixor__)
union = property(lambda self: self.__or__)
......@@ -9,7 +9,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from test_utils import TestConfig, Replay
from test_utils import ReplayConfig, Replay
np.random.seed(1)
......@@ -117,7 +117,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
......@@ -248,7 +248,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
renderer = RenderTool(env, gl="PILSVG")
test_configs = [
TestConfig(
ReplayConfig(
replay=[
Replay(
position=(3, 8),
......@@ -316,7 +316,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
],
target=(3, 0), # west dead-end
speed=1 / 3),
TestConfig(
ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
......@@ -456,7 +456,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
if rendering:
renderer = RenderTool(env, gl="PILSVG")
test_config = TestConfig(
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
......
......@@ -15,7 +15,7 @@ class Replay(object):
@attrs
class TestConfig(object):
class ReplayConfig(object):
replay = attrib(type=List[Replay])
target = attrib()
speed = attrib(type=float)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment