diff --git a/benchmarks/run_all_examples.py b/benchmarks/run_all_examples.py index 509e232416db525c28ce3d69ac84ad1657c533eb..1b3e3be066989e18af3f36e1dd73ded37c0bc6cf 100644 --- a/benchmarks/run_all_examples.py +++ b/benchmarks/run_all_examples.py @@ -18,6 +18,7 @@ for entry in [entry for entry in importlib_resources.contents('examples') if with path('examples', entry) as file_in: print("") print("") + print("") print("*****************************************************************") print("Running {}".format(entry)) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index b6027184c631c79c15f967e8867b8b5f3c8ba0f6..9238a2af4137e37e9d79bc3c1aaade2bb987403e 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -11,6 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.ordered_set import OrderedSet from flatland.utils.rendertools import RenderTool random.seed(100) @@ -82,7 +83,7 @@ class ObservePredictions(TreeObsForRailEnv): # We are going to track what cells where considered while building the obervation and make them accesible # For rendering - visited = set() + visited = OrderedSet() for _idx in range(10): # Check if any of the other prediction overlap with agents own predictions x_coord = self.predictions[handle][_idx][1] diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index feb72313f21b9ecc989688d63ba02ccf3a458107..d1652a38c7ecd45cbdd28522b6aeeb28683c4736 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,4 +1,5 @@ from flatland.core.grid.grid4_utils import validate_new_transition +from 
flatland.utils.ordered_set import OrderedSet class AStarNode(): @@ -33,12 +34,12 @@ def a_star(rail_trans, rail_array, start, end): rail_shape = rail_array.shape start_node = AStarNode(None, start) end_node = AStarNode(None, end) - open_nodes = set() - closed_nodes = set() + open_nodes = OrderedSet() + closed_nodes = OrderedSet() open_nodes.add(start_node) while len(open_nodes) > 0: - # get node with current shortest est. path (lowest f) + # get node with current shortest path (lowest f) current_node = None for item in open_nodes: if current_node is None: diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index 680e945316ab3a4876bd36fa8e6b001ea346cd26..db09fbd57b18d203c956742d4711973c986ca452 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -1,4 +1,5 @@ from flatland.core.grid.grid4 import Grid4Transitions +from flatland.utils.ordered_set import OrderedSet class RailEnvTransitions(Grid4Transitions): @@ -44,7 +45,7 @@ class RailEnvTransitions(Grid4Transitions): ) # create this to make validation faster - self.transitions_all = set() + self.transitions_all = OrderedSet() for index, trans in enumerate(self.transitions): self.transitions_all.add(trans) if index in (2, 4, 6, 7, 8, 9, 10): diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 232d6fdab02c57da95bf04c631e4905986c71327..105f1c90bc7201b4d8c9b17e184fc6d56ffb02b2 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -10,6 +10,7 @@ from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions +from flatland.utils.ordered_set import OrderedSet class TransitionMap: @@ -336,7 +337,7 @@ class GridTransitionMap(TransitionMap): tmp = self.get_full_transitions(rcPos[0], rcPos[1]) def is_simple_turn(trans): - 
all_simple_turns = set() + all_simple_turns = OrderedSet() for trans in [int('0100000000000010', 2), # Case 1b (8) - simple turn right int('0001001000000000', 2) # Case 1c (9) - simple turn left]: ]: @@ -351,7 +352,7 @@ class GridTransitionMap(TransitionMap): # print("_path_exists({},{},{}".format(start, direction, end)) # BFS - Check if a path exists between the 2 nodes - visited = set() + visited = OrderedSet() stack = [(start, direction)] while stack: node = stack.pop() diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index b228e10b6c146f5692166e179bb9f574a68c9134..f659ec8436a941606b6d649e24d2481e5be9b66d 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,8 +1,11 @@ from itertools import starmap +from typing import Tuple import numpy as np from attr import attrs, attrib, Factory +from flatland.core.grid.grid4 import Grid4TransitionsEnum + @attrs class EnvAgentStatic(object): @@ -11,10 +14,10 @@ class EnvAgentStatic(object): rather than where it is at the moment. The target should also be stored here. 
""" - position = attrib() - direction = attrib() - target = attrib() - moving = attrib(default=False) + position = attrib(type=Tuple[int, int]) + direction = attrib(type=Grid4TransitionsEnum) + target = attrib(type=Tuple[int, int]) + moving = attrib(default=False, type=bool) # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous @@ -27,7 +30,8 @@ class EnvAgentStatic(object): # number of time the agent had to stop, since the last time it broke down malfunction_data = attrib( default=Factory( - lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}))) + lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, + 'moving_before_malfunction': False}))) @classmethod def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a833fc01949d4184d5ca2442c6bb429d697318f3..85cd5fdc798fde4f1a5f2b75edeefce2dfa2104a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -9,6 +9,7 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import coordinate_to_position +from flatland.utils.ordered_set import OrderedSet class TreeObsForRailEnv(ObservationBuilder): @@ -154,7 +155,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed']] - visited = set() + visited = OrderedSet() # Start from the current orientation, and see which transitions are available; # organize them as 
[left, forward, right, back], relative to the current orientation @@ -170,7 +171,7 @@ class TreeObsForRailEnv(ObservationBuilder): branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation - visited = visited.union(branch_visited) + visited |= branch_visited else: # add cells filled with infinity if no transition is possible observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) @@ -207,7 +208,7 @@ class TreeObsForRailEnv(ObservationBuilder): last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here last_is_target = False - visited = set() + visited = OrderedSet() agent = self.env.agents[handle] time_per_cell = np.reciprocal(agent.speed_data["speed"]) own_target_encountered = np.inf @@ -420,7 +421,7 @@ class TreeObsForRailEnv(ObservationBuilder): depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: - visited = visited.union(branch_visited) + visited |= branch_visited elif last_is_switch and possible_transitions[branch_direction]: new_cell = get_new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, @@ -430,7 +431,7 @@ class TreeObsForRailEnv(ObservationBuilder): depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: - visited = visited.union(branch_visited) + visited |= branch_visited else: # no exploring possible, add just cells with infinity observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index c77b57871fcf30b231132c364c115a38d6de3889..7f03b5be43d98ba3b4d87d933c267457fd133ddd 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.env_prediction_builder import PredictionBuilder from 
flatland.core.grid.grid4_utils import get_new_position from flatland.envs.rail_env import RailEnvActions +from flatland.utils.ordered_set import OrderedSet class DummyPredictorForRailEnv(PredictionBuilder): @@ -125,7 +126,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] new_direction = _agent_initial_direction new_position = _agent_initial_position - visited = set() + visited = OrderedSet() for index in range(1, self.max_depth + 1): # if we're at the target, stop moving... if agent.position == agent.target: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f5881836072ec09a2fe4ee9a70482566867362c4..294ffab233458f1f3b98c18be50743ba65bd2d73 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,13 +4,14 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum -from typing import List +from typing import List, Set, NamedTuple, Optional, Tuple, Dict import msgpack import msgpack_numpy as m import numpy as np from flatland.core.env import Environment +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -18,6 +19,7 @@ from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator +from flatland.utils.ordered_set import OrderedSet m.patch() @@ -40,6 +42,11 @@ class RailEnvActions(IntEnum): }[a] +RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) +RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), 
('next_position', RailEnvGridPos), + ('next_direction', Grid4TransitionsEnum)]) + + class RailEnv(Environment): """ RailEnv environment class. @@ -115,7 +122,7 @@ class RailEnv(Environment): Environment init. Parameters - ------- + ---------- rail_generator : function The rail_generator function is a function that takes the width, height and agents handles of a rail environment, along with the number of times @@ -140,16 +147,13 @@ class RailEnv(Environment): ObservationBuilder-derived object that takes builds observation vectors for each agent. max_episode_steps : int or None - - file_name: you can load a pickle file. from previously saved *.pkl file - """ super().__init__() self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator self.rail_generator = rail_generator - self.rail: GridTransitionMap = None + self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height @@ -265,7 +269,10 @@ class RailEnv(Environment): agent.malfunction_data['malfunction'] = 0 - self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING) + initial_malfunction = self._agent_malfunction(i_agent) + + if initial_malfunction: + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING self.num_resets += 1 self._elapsed_steps = 0 @@ -281,7 +288,7 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def _agent_new_malfunction(self, i_agent, action) -> bool: + def _agent_malfunction(self, i_agent) -> bool: """ Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before). 
""" @@ -308,12 +315,28 @@ class RailEnv(Environment): num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving return True + else: + # The train was broken before... + if agent.malfunction_data['malfunction'] > 0: + + # Last step of malfunction --> Agent starts moving again after getting fixed + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + + # restore moving state before malfunction without further penalty + self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] + + else: + agent.malfunction_data['malfunction'] -= 1 + + # Nothing left to do with broken agent + return True return False - # TODO refactor to decrease length of this method! - def step(self, action_dict_): + def step(self, action_dict_: Dict[int, RailEnvActions]): self._elapsed_steps += 1 # Reset the step rewards @@ -321,8 +344,9 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.rewards_dict[i_agent] = 0 + # If we're done, set reward and info_dict and step() is done. if self.dones["__all__"]: - self.rewards_dict = {i: r + self.global_reward for i, r in self.rewards_dict.items()} + self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} info_dict = { 'action_required': {i: False for i in range(self.get_num_agents())}, 'malfunction': {i: 0 for i in range(self.get_num_agents())}, @@ -330,166 +354,173 @@ class RailEnv(Environment): } return self._get_observations(), self.rewards_dict, self.dones, info_dict + # Perform step on all agents for i_agent in range(self.get_num_agents()): + self._step_agent(i_agent, action_dict_.get(i_agent)) - if self.dones[i_agent]: # this agent has already completed... 
- continue - - agent = self.agents[i_agent] - agent.old_direction = agent.direction - agent.old_position = agent.position + # Check for end of episode + set global reward to all rewards! + if np.all([np.array_equal(agent.position, agent.target) for agent in self.agents]): + self.dones["__all__"] = True + self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} - # No action has been supplied for this agent -> set DO_NOTHING as default - if i_agent not in action_dict_: - action = RailEnvActions.DO_NOTHING - else: - action = action_dict_[i_agent] + if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): + self.dones["__all__"] = True + for k in self.dones.keys(): + self.dones[k] = True - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING + action_required_agents = { + i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents()) + } + malfunction_agents = { + i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) + } + speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())} - # Check if agent breaks at this step - new_malfunction = self._agent_new_malfunction(i_agent, action) + info_dict = { + 'action_required': action_required_agents, + 'malfunction': malfunction_agents, + 'speed': speed_agents + } - # Is the agent at the beginning of the cell? Then, it can take an action - # Design choice (Erik+Christian): - # as long as we're broken down at the beginning of the cell, we can choose other actions! 
- if agent.speed_data['position_fraction'] == 0.0: - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD + return self._get_observations(), self.rewards_dict, self.dones, info_dict - if action == RailEnvActions.STOP_MOVING and agent.moving: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += self.stop_penalty + def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): + """ + Performs a step on a single agent and applies its step, start and stop penalties, in the following sub steps: + - malfunction + - action handling if at the beginning of cell + - movement - if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += self.start_penalty + Parameters + ---------- + i_agent : int + action : Optional[RailEnvActions] - # Store the action - if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]: - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(action, agent) + """ + if self.dones[i_agent]: # this agent has already completed... + return - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = action - else: - # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, - # try to keep moving forward!
- if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - else: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - - else: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - - # if we've just broken in this step, nothing else to do - if new_malfunction: - continue + agent = self.agents[i_agent] + agent.old_direction = agent.direction + agent.old_position = agent.position - # The train was broken before... - if agent.malfunction_data['malfunction'] > 0: + # is the agent malfunctioning? + malfunction = self._agent_malfunction(i_agent) - # Last step of malfunction --> Agent starts moving again after getting fixed - if agent.malfunction_data['malfunction'] < 2: - agent.malfunction_data['malfunction'] -= 1 - self.agents[i_agent].moving = True - action = RailEnvActions.DO_NOTHING + # if agent is broken, actions are ignored and agent does not move, + # only the regular step penalty is applied in this step (no start/stop penalty)! + if malfunction: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return - else: - agent.malfunction_data['malfunction'] -= 1 + # Is the agent at the beginning of the cell? Then, it can take an action.
+ if agent.speed_data['position_fraction'] == 0.0: + # No action has been supplied for this agent -> set DO_NOTHING as default + if action is None: + action = RailEnvActions.DO_NOTHING - # Broken agents are stopped - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.agents[i_agent].moving = False + if action < 0 or action > len(RailEnvActions): + print('ERROR: illegal action=', action, + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action = RailEnvActions.DO_NOTHING - # Nothing left to do with broken agent - continue + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action = RailEnvActions.MOVE_FORWARD - # Now perform a movement. - # If agent.moving, increment the position_fraction by the speed of the agent - # If the new position fraction is >= 1, reset to 0, and perform the stored - # transition_action_on_cellexit if the cell is free. - if agent.moving: + if action == RailEnvActions.STOP_MOVING and agent.moving: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty - agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] >= 1.0: - # Perform stored action to transition to the next cell as soon as cell is free - # Notice that we've already check new_cell_valid and transition valid when we stored the action, - # so we only have to check cell_free now! + if not agent.moving and not ( + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty - # cell and transition validity was checked when we stored transition_action_on_cellexit! 
- cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( - agent.speed_data['transition_action_on_cellexit'], agent) + # Store the action if action is moving + # If not moving, the action will be stored when the agent starts moving again. + if agent.moving: + _action_stored = False + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(action, agent) - if cell_free: - agent.position = new_position - agent.direction = new_direction - agent.speed_data['position_fraction'] = 0.0 + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = action + _action_stored = True + else: + # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, + # try to keep moving forward! + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + _action_stored = True + + if not _action_stored: + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.stop_penalty + agent.moving = False + # Now perform a movement. + # If agent.moving, increment the position_fraction by the speed of the agent + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit if the cell is free. 
+ if agent.moving: + agent.speed_data['position_fraction'] += agent.speed_data['speed'] + if agent.speed_data['position_fraction'] >= 1.0: + # Perform stored action to transition to the next cell as soon as cell is free + # Notice that we've already checked new_cell_valid and transition valid when we stored the action, + # so we only have to check cell_free now! + + # cell and transition validity was checked when we stored transition_action_on_cellexit! + cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( + agent.speed_data['transition_action_on_cellexit'], agent) + + # N.B. validity of new_cell and transition should have been verified before the action was stored! + assert new_cell_valid + assert transition_valid + if cell_free: + agent.position = new_position + agent.direction = new_direction + agent.speed_data['position_fraction'] = 0.0 + + # has the agent reached its target? if np.equal(agent.position, agent.target).all(): self.dones[i_agent] = True agent.moving = False else: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - # Check for end of episode + add global reward to all rewards! 
- if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): - self.dones["__all__"] = True - self.rewards_dict = {i: 0 * r + self.global_reward for i, r in self.rewards_dict.items()} - - if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): - self.dones["__all__"] = True - for k in self.dones.keys(): - self.dones[k] = True - - action_required_agents = { - i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents()) - } - malfunction_agents = { - i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) - } - speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())} + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): + """ - info_dict = { - 'action_required': action_required_agents, - 'malfunction': malfunction_agents, - 'speed': speed_agents - } + Parameters + ---------- + action : RailEnvActions + agent : EnvAgent - return self._get_observations(), self.rewards_dict, self.dones, info_dict + Returns + ------- + bool + Is it a legal move? + 1) transition allows the new_direction in the cell, + 2) the new cell is not empty (case 0), + 3) the cell is free, i.e., no agent is currently in that cell - def _check_action_on_agent(self, action, agent): + """ # compute number of possible transitions in the current # cell used to check for invalid actions new_direction, transition_valid = self.check_action(agent, action) new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? 
- # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell new_cell_valid = ( np.array_equal( # Check the new position is still in the grid new_position, @@ -505,11 +536,24 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - cell_free = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def check_action(self, agent, action): + def check_action(self, agent: EnvAgent, action: RailEnvActions): + """ + + Parameters + ---------- + agent : EnvAgent + action : RailEnvActions + + Returns + ------- + Tuple[Grid4TransitionsEnum,Optional[bool]] + + + + """ transition_valid = None possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) @@ -527,15 +571,68 @@ class RailEnv(Environment): new_direction %= 4 - if action == RailEnvActions.MOVE_FORWARD: - if num_transitions == 1: - # - dead-end, straight line or curved line; - # new_direction will be the only valid transition - # - take only available transition - new_direction = np.argmax(possible_transitions) - transition_valid = True + if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + # - take only available transition + new_direction = np.argmax(possible_transitions) + transition_valid = True return new_direction, transition_valid + @staticmethod + def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, + agent_position: Tuple[int, int], + rail: GridTransitionMap) ->
Set[RailEnvNextAction]: + """ + Get the valid move actions (forward, left, right) for an agent. + + Parameters + ---------- + agent_direction : Grid4TransitionsEnum + agent_position: Tuple[int,int] + rail : GridTransitionMap + + + Returns + ------- + Set of `RailEnvNextAction` (tuples of (action,position,direction)) + Possible move actions (forward,left,right) and the next position/direction they lead to. + It is not checked that the next cell is free. + """ + valid_actions: Set[RailEnvNextAction] = OrderedSet() + possible_transitions = rail.get_transitions(*agent_position, agent_direction) + num_transitions = np.count_nonzero(possible_transitions) + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if rail.is_dead_end(agent_position): + action = RailEnvActions.MOVE_FORWARD + exit_direction = (agent_direction + 2) % 4 + if possible_transitions[exit_direction]: + new_position = get_new_position(agent_position, exit_direction) + valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) + elif num_transitions == 1: + action = RailEnvActions.MOVE_FORWARD + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + else: + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + if new_direction == agent_direction: + action = RailEnvActions.MOVE_FORWARD + elif new_direction == (agent_direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif new_direction == (agent_direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + raise Exception("Illegal state") + + new_position = 
get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + return valid_actions + def _get_observations(self): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict diff --git a/flatland/utils/ordered_set.py b/flatland/utils/ordered_set.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd1689488f566872445334ac6e3bb8362daa347 --- /dev/null +++ b/flatland/utils/ordered_set.py @@ -0,0 +1,49 @@ +# in order for enumeration to be deterministic for testing purposes +# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set +from collections import OrderedDict +from collections.abc import MutableSet + + +class OrderedSet(OrderedDict, MutableSet): + + def update(self, *args, **kwargs): + if kwargs: + raise TypeError("update() takes no keyword arguments") + + for s in args: + for e in s: + self.add(e) + + def add(self, elem): + self[elem] = None + + def discard(self, elem): + self.pop(elem, None) + + def __le__(self, other): + return all(e in other for e in self) + + def __lt__(self, other): + return self <= other and self != other + + def __ge__(self, other): + return all(e in self for e in other) + + def __gt__(self, other): + return self >= other and self != other + + def __repr__(self): + return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys()))) + + def __str__(self): + return '{%s}' % (', '.join(map(repr, self.keys()))) + + difference = property(lambda self: self.__sub__) + difference_update = property(lambda self: self.__isub__) + intersection = property(lambda self: self.__and__) + intersection_update = property(lambda self: self.__iand__) + issubset = property(lambda self: self.__le__) + issuperset = property(lambda self: self.__ge__) + symmetric_difference = property(lambda self: self.__xor__) + symmetric_difference_update = property(lambda self: self.__ixor__) + union = property(lambda self: self.__or__) 
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index a0e2b995b35bc2d6984bf6274170a28c18cada70..d363597e107a63cdfc0c8f6f429e0f023b0b7c38 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1,3 +1,5 @@ +import random + import numpy as np from flatland.envs.observations import GlobalObsForRailEnv @@ -29,6 +31,793 @@ def test_sparse_rail_generator(): # TODO test assertions! +def test_sparse_rail_generator_deterministic(): + """Check that sparse_rail_generator runs deterministic over different python versions!""" + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. 
/ 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + assert env.rail.get_full_transitions(0, 0) == 0, "[0][0]" + assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]" + assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]" + assert env.rail.get_full_transitions(0, 3) == 0, "[0][3]" + assert env.rail.get_full_transitions(0, 4) == 0, "[0][4]" + assert env.rail.get_full_transitions(0, 5) == 0, "[0][5]" + assert env.rail.get_full_transitions(0, 6) == 0, "[0][6]" + assert env.rail.get_full_transitions(0, 7) == 0, "[0][7]" + assert env.rail.get_full_transitions(0, 8) == 0, "[0][8]" + assert env.rail.get_full_transitions(0, 9) == 0, "[0][9]" + assert env.rail.get_full_transitions(0, 10) == 0, "[0][10]" + assert env.rail.get_full_transitions(0, 11) == 0, "[0][11]" + assert env.rail.get_full_transitions(0, 12) == 0, "[0][12]" + assert env.rail.get_full_transitions(0, 13) == 0, "[0][13]" + assert env.rail.get_full_transitions(0, 14) == 0, "[0][14]" + assert env.rail.get_full_transitions(0, 15) == 0, "[0][15]" + assert env.rail.get_full_transitions(0, 16) == 0, "[0][16]" + assert env.rail.get_full_transitions(0, 17) == 0, "[0][17]" + assert env.rail.get_full_transitions(0, 18) == 0, "[0][18]" + assert env.rail.get_full_transitions(0, 19) == 0, "[0][19]" + assert env.rail.get_full_transitions(0, 20) == 
0, "[0][20]" + assert env.rail.get_full_transitions(0, 21) == 0, "[0][21]" + assert env.rail.get_full_transitions(0, 22) == 0, "[0][22]" + assert env.rail.get_full_transitions(0, 23) == 0, "[0][23]" + assert env.rail.get_full_transitions(0, 24) == 0, "[0][24]" + assert env.rail.get_full_transitions(1, 0) == 0, "[1][0]" + assert env.rail.get_full_transitions(1, 1) == 0, "[1][1]" + assert env.rail.get_full_transitions(1, 2) == 0, "[1][2]" + assert env.rail.get_full_transitions(1, 3) == 0, "[1][3]" + assert env.rail.get_full_transitions(1, 4) == 0, "[1][4]" + assert env.rail.get_full_transitions(1, 5) == 0, "[1][5]" + assert env.rail.get_full_transitions(1, 6) == 0, "[1][6]" + assert env.rail.get_full_transitions(1, 7) == 0, "[1][7]" + assert env.rail.get_full_transitions(1, 8) == 0, "[1][8]" + assert env.rail.get_full_transitions(1, 9) == 0, "[1][9]" + assert env.rail.get_full_transitions(1, 10) == 0, "[1][10]" + assert env.rail.get_full_transitions(1, 11) == 0, "[1][11]" + assert env.rail.get_full_transitions(1, 12) == 0, "[1][12]" + assert env.rail.get_full_transitions(1, 13) == 0, "[1][13]" + assert env.rail.get_full_transitions(1, 14) == 0, "[1][14]" + assert env.rail.get_full_transitions(1, 15) == 0, "[1][15]" + assert env.rail.get_full_transitions(1, 16) == 0, "[1][16]" + assert env.rail.get_full_transitions(1, 17) == 0, "[1][17]" + assert env.rail.get_full_transitions(1, 18) == 0, "[1][18]" + assert env.rail.get_full_transitions(1, 19) == 0, "[1][19]" + assert env.rail.get_full_transitions(1, 20) == 0, "[1][20]" + assert env.rail.get_full_transitions(1, 21) == 0, "[1][21]" + assert env.rail.get_full_transitions(1, 22) == 0, "[1][22]" + assert env.rail.get_full_transitions(1, 23) == 0, "[1][23]" + assert env.rail.get_full_transitions(1, 24) == 0, "[1][24]" + assert env.rail.get_full_transitions(2, 0) == 0, "[2][0]" + assert env.rail.get_full_transitions(2, 1) == 0, "[2][1]" + assert env.rail.get_full_transitions(2, 2) == 0, "[2][2]" + assert 
env.rail.get_full_transitions(2, 3) == 0, "[2][3]" + assert env.rail.get_full_transitions(2, 4) == 0, "[2][4]" + assert env.rail.get_full_transitions(2, 5) == 0, "[2][5]" + assert env.rail.get_full_transitions(2, 6) == 0, "[2][6]" + assert env.rail.get_full_transitions(2, 7) == 0, "[2][7]" + assert env.rail.get_full_transitions(2, 8) == 0, "[2][8]" + assert env.rail.get_full_transitions(2, 9) == 0, "[2][9]" + assert env.rail.get_full_transitions(2, 10) == 0, "[2][10]" + assert env.rail.get_full_transitions(2, 11) == 0, "[2][11]" + assert env.rail.get_full_transitions(2, 12) == 0, "[2][12]" + assert env.rail.get_full_transitions(2, 13) == 0, "[2][13]" + assert env.rail.get_full_transitions(2, 14) == 0, "[2][14]" + assert env.rail.get_full_transitions(2, 15) == 0, "[2][15]" + assert env.rail.get_full_transitions(2, 16) == 0, "[2][16]" + assert env.rail.get_full_transitions(2, 17) == 0, "[2][17]" + assert env.rail.get_full_transitions(2, 18) == 0, "[2][18]" + assert env.rail.get_full_transitions(2, 19) == 0, "[2][19]" + assert env.rail.get_full_transitions(2, 20) == 0, "[2][20]" + assert env.rail.get_full_transitions(2, 21) == 0, "[2][21]" + assert env.rail.get_full_transitions(2, 22) == 0, "[2][22]" + assert env.rail.get_full_transitions(2, 23) == 0, "[2][23]" + assert env.rail.get_full_transitions(2, 24) == 0, "[2][24]" + assert env.rail.get_full_transitions(3, 0) == 0, "[3][0]" + assert env.rail.get_full_transitions(3, 1) == 0, "[3][1]" + assert env.rail.get_full_transitions(3, 2) == 0, "[3][2]" + assert env.rail.get_full_transitions(3, 3) == 16386, "[3][3]" + assert env.rail.get_full_transitions(3, 4) == 1025, "[3][4]" + assert env.rail.get_full_transitions(3, 5) == 1025, "[3][5]" + assert env.rail.get_full_transitions(3, 6) == 1025, "[3][6]" + assert env.rail.get_full_transitions(3, 7) == 1025, "[3][7]" + assert env.rail.get_full_transitions(3, 8) == 1025, "[3][8]" + assert env.rail.get_full_transitions(3, 9) == 1025, "[3][9]" + assert 
env.rail.get_full_transitions(3, 10) == 1025, "[3][10]" + assert env.rail.get_full_transitions(3, 11) == 1025, "[3][11]" + assert env.rail.get_full_transitions(3, 12) == 4608, "[3][12]" + assert env.rail.get_full_transitions(3, 13) == 0, "[3][13]" + assert env.rail.get_full_transitions(3, 14) == 0, "[3][14]" + assert env.rail.get_full_transitions(3, 15) == 0, "[3][15]" + assert env.rail.get_full_transitions(3, 16) == 0, "[3][16]" + assert env.rail.get_full_transitions(3, 17) == 0, "[3][17]" + assert env.rail.get_full_transitions(3, 18) == 0, "[3][18]" + assert env.rail.get_full_transitions(3, 19) == 0, "[3][19]" + assert env.rail.get_full_transitions(3, 20) == 0, "[3][20]" + assert env.rail.get_full_transitions(3, 21) == 0, "[3][21]" + assert env.rail.get_full_transitions(3, 22) == 8192, "[3][22]" + assert env.rail.get_full_transitions(3, 23) == 0, "[3][23]" + assert env.rail.get_full_transitions(3, 24) == 0, "[3][24]" + assert env.rail.get_full_transitions(4, 0) == 0, "[4][0]" + assert env.rail.get_full_transitions(4, 1) == 0, "[4][1]" + assert env.rail.get_full_transitions(4, 2) == 0, "[4][2]" + assert env.rail.get_full_transitions(4, 3) == 32800, "[4][3]" + assert env.rail.get_full_transitions(4, 4) == 0, "[4][4]" + assert env.rail.get_full_transitions(4, 5) == 0, "[4][5]" + assert env.rail.get_full_transitions(4, 6) == 0, "[4][6]" + assert env.rail.get_full_transitions(4, 7) == 0, "[4][7]" + assert env.rail.get_full_transitions(4, 8) == 0, "[4][8]" + assert env.rail.get_full_transitions(4, 9) == 0, "[4][9]" + assert env.rail.get_full_transitions(4, 10) == 0, "[4][10]" + assert env.rail.get_full_transitions(4, 11) == 0, "[4][11]" + assert env.rail.get_full_transitions(4, 12) == 32800, "[4][12]" + assert env.rail.get_full_transitions(4, 13) == 0, "[4][13]" + assert env.rail.get_full_transitions(4, 14) == 0, "[4][14]" + assert env.rail.get_full_transitions(4, 15) == 0, "[4][15]" + assert env.rail.get_full_transitions(4, 16) == 0, "[4][16]" + assert 
env.rail.get_full_transitions(4, 17) == 0, "[4][17]" + assert env.rail.get_full_transitions(4, 18) == 0, "[4][18]" + assert env.rail.get_full_transitions(4, 19) == 0, "[4][19]" + assert env.rail.get_full_transitions(4, 20) == 0, "[4][20]" + assert env.rail.get_full_transitions(4, 21) == 0, "[4][21]" + assert env.rail.get_full_transitions(4, 22) == 32800, "[4][22]" + assert env.rail.get_full_transitions(4, 23) == 0, "[4][23]" + assert env.rail.get_full_transitions(4, 24) == 0, "[4][24]" + assert env.rail.get_full_transitions(5, 0) == 0, "[5][0]" + assert env.rail.get_full_transitions(5, 1) == 0, "[5][1]" + assert env.rail.get_full_transitions(5, 2) == 0, "[5][2]" + assert env.rail.get_full_transitions(5, 3) == 32800, "[5][3]" + assert env.rail.get_full_transitions(5, 4) == 0, "[5][4]" + assert env.rail.get_full_transitions(5, 5) == 0, "[5][5]" + assert env.rail.get_full_transitions(5, 6) == 0, "[5][6]" + assert env.rail.get_full_transitions(5, 7) == 0, "[5][7]" + assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]" + assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]" + assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]" + assert env.rail.get_full_transitions(5, 11) == 0, "[5][11]" + assert env.rail.get_full_transitions(5, 12) == 32800, "[5][12]" + assert env.rail.get_full_transitions(5, 13) == 0, "[5][13]" + assert env.rail.get_full_transitions(5, 14) == 0, "[5][14]" + assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]" + assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]" + assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]" + assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]" + assert env.rail.get_full_transitions(5, 19) == 0, "[5][19]" + assert env.rail.get_full_transitions(5, 20) == 0, "[5][20]" + assert env.rail.get_full_transitions(5, 21) == 0, "[5][21]" + assert env.rail.get_full_transitions(5, 22) == 32800, "[5][22]" + assert env.rail.get_full_transitions(5, 23) == 0, "[5][23]" + assert 
env.rail.get_full_transitions(5, 24) == 0, "[5][24]" + assert env.rail.get_full_transitions(6, 0) == 0, "[6][0]" + assert env.rail.get_full_transitions(6, 1) == 0, "[6][1]" + assert env.rail.get_full_transitions(6, 2) == 0, "[6][2]" + assert env.rail.get_full_transitions(6, 3) == 32800, "[6][3]" + assert env.rail.get_full_transitions(6, 4) == 0, "[6][4]" + assert env.rail.get_full_transitions(6, 5) == 0, "[6][5]" + assert env.rail.get_full_transitions(6, 6) == 0, "[6][6]" + assert env.rail.get_full_transitions(6, 7) == 0, "[6][7]" + assert env.rail.get_full_transitions(6, 8) == 0, "[6][8]" + assert env.rail.get_full_transitions(6, 9) == 0, "[6][9]" + assert env.rail.get_full_transitions(6, 10) == 0, "[6][10]" + assert env.rail.get_full_transitions(6, 11) == 0, "[6][11]" + assert env.rail.get_full_transitions(6, 12) == 32800, "[6][12]" + assert env.rail.get_full_transitions(6, 13) == 0, "[6][13]" + assert env.rail.get_full_transitions(6, 14) == 0, "[6][14]" + assert env.rail.get_full_transitions(6, 15) == 0, "[6][15]" + assert env.rail.get_full_transitions(6, 16) == 0, "[6][16]" + assert env.rail.get_full_transitions(6, 17) == 0, "[6][17]" + assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]" + assert env.rail.get_full_transitions(6, 19) == 0, "[6][19]" + assert env.rail.get_full_transitions(6, 20) == 0, "[6][20]" + assert env.rail.get_full_transitions(6, 21) == 0, "[6][21]" + assert env.rail.get_full_transitions(6, 22) == 32800, "[6][22]" + assert env.rail.get_full_transitions(6, 23) == 0, "[6][23]" + assert env.rail.get_full_transitions(6, 24) == 0, "[6][24]" + assert env.rail.get_full_transitions(7, 0) == 0, "[7][0]" + assert env.rail.get_full_transitions(7, 1) == 0, "[7][1]" + assert env.rail.get_full_transitions(7, 2) == 0, "[7][2]" + assert env.rail.get_full_transitions(7, 3) == 32800, "[7][3]" + assert env.rail.get_full_transitions(7, 4) == 0, "[7][4]" + assert env.rail.get_full_transitions(7, 5) == 0, "[7][5]" + assert 
env.rail.get_full_transitions(7, 6) == 0, "[7][6]" + assert env.rail.get_full_transitions(7, 7) == 0, "[7][7]" + assert env.rail.get_full_transitions(7, 8) == 0, "[7][8]" + assert env.rail.get_full_transitions(7, 9) == 0, "[7][9]" + assert env.rail.get_full_transitions(7, 10) == 0, "[7][10]" + assert env.rail.get_full_transitions(7, 11) == 0, "[7][11]" + assert env.rail.get_full_transitions(7, 12) == 32800, "[7][12]" + assert env.rail.get_full_transitions(7, 13) == 0, "[7][13]" + assert env.rail.get_full_transitions(7, 14) == 0, "[7][14]" + assert env.rail.get_full_transitions(7, 15) == 0, "[7][15]" + assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]" + assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]" + assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]" + assert env.rail.get_full_transitions(7, 19) == 0, "[7][19]" + assert env.rail.get_full_transitions(7, 20) == 0, "[7][20]" + assert env.rail.get_full_transitions(7, 21) == 0, "[7][21]" + assert env.rail.get_full_transitions(7, 22) == 32800, "[7][22]" + assert env.rail.get_full_transitions(7, 23) == 0, "[7][23]" + assert env.rail.get_full_transitions(7, 24) == 0, "[7][24]" + assert env.rail.get_full_transitions(8, 0) == 0, "[8][0]" + assert env.rail.get_full_transitions(8, 1) == 0, "[8][1]" + assert env.rail.get_full_transitions(8, 2) == 0, "[8][2]" + assert env.rail.get_full_transitions(8, 3) == 32800, "[8][3]" + assert env.rail.get_full_transitions(8, 4) == 0, "[8][4]" + assert env.rail.get_full_transitions(8, 5) == 8192, "[8][5]" + assert env.rail.get_full_transitions(8, 6) == 0, "[8][6]" + assert env.rail.get_full_transitions(8, 7) == 0, "[8][7]" + assert env.rail.get_full_transitions(8, 8) == 0, "[8][8]" + assert env.rail.get_full_transitions(8, 9) == 8192, "[8][9]" + assert env.rail.get_full_transitions(8, 10) == 8192, "[8][10]" + assert env.rail.get_full_transitions(8, 11) == 0, "[8][11]" + assert env.rail.get_full_transitions(8, 12) == 32800, "[8][12]" + assert 
env.rail.get_full_transitions(8, 13) == 8192, "[8][13]" + assert env.rail.get_full_transitions(8, 14) == 0, "[8][14]" + assert env.rail.get_full_transitions(8, 15) == 0, "[8][15]" + assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]" + assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]" + assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]" + assert env.rail.get_full_transitions(8, 19) == 0, "[8][19]" + assert env.rail.get_full_transitions(8, 20) == 0, "[8][20]" + assert env.rail.get_full_transitions(8, 21) == 0, "[8][21]" + assert env.rail.get_full_transitions(8, 22) == 32800, "[8][22]" + assert env.rail.get_full_transitions(8, 23) == 0, "[8][23]" + assert env.rail.get_full_transitions(8, 24) == 0, "[8][24]" + assert env.rail.get_full_transitions(9, 0) == 8192, "[9][0]" + assert env.rail.get_full_transitions(9, 1) == 0, "[9][1]" + assert env.rail.get_full_transitions(9, 2) == 0, "[9][2]" + assert env.rail.get_full_transitions(9, 3) == 32800, "[9][3]" + assert env.rail.get_full_transitions(9, 4) == 8192, "[9][4]" + assert env.rail.get_full_transitions(9, 5) == 32800, "[9][5]" + assert env.rail.get_full_transitions(9, 6) == 0, "[9][6]" + assert env.rail.get_full_transitions(9, 7) == 0, "[9][7]" + assert env.rail.get_full_transitions(9, 8) == 0, "[9][8]" + assert env.rail.get_full_transitions(9, 9) == 72, "[9][9]" + assert env.rail.get_full_transitions(9, 10) == 37408, "[9][10]" + assert env.rail.get_full_transitions(9, 11) == 0, "[9][11]" + assert env.rail.get_full_transitions(9, 12) == 49186, "[9][12]" + assert env.rail.get_full_transitions(9, 13) == 3089, "[9][13]" + assert env.rail.get_full_transitions(9, 14) == 4608, "[9][14]" + assert env.rail.get_full_transitions(9, 15) == 0, "[9][15]" + assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]" + assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]" + assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]" + assert env.rail.get_full_transitions(9, 19) == 0, "[9][19]" + assert 
env.rail.get_full_transitions(9, 20) == 0, "[9][20]" + assert env.rail.get_full_transitions(9, 21) == 0, "[9][21]" + assert env.rail.get_full_transitions(9, 22) == 32800, "[9][22]" + assert env.rail.get_full_transitions(9, 23) == 0, "[9][23]" + assert env.rail.get_full_transitions(9, 24) == 0, "[9][24]" + assert env.rail.get_full_transitions(10, 0) == 32800, "[10][0]" + assert env.rail.get_full_transitions(10, 1) == 0, "[10][1]" + assert env.rail.get_full_transitions(10, 2) == 0, "[10][2]" + assert env.rail.get_full_transitions(10, 3) == 32800, "[10][3]" + assert env.rail.get_full_transitions(10, 4) == 32800, "[10][4]" + assert env.rail.get_full_transitions(10, 5) == 32800, "[10][5]" + assert env.rail.get_full_transitions(10, 6) == 0, "[10][6]" + assert env.rail.get_full_transitions(10, 7) == 0, "[10][7]" + assert env.rail.get_full_transitions(10, 8) == 0, "[10][8]" + assert env.rail.get_full_transitions(10, 9) == 4, "[10][9]" + assert env.rail.get_full_transitions(10, 10) == 1097, "[10][10]" + assert env.rail.get_full_transitions(10, 11) == 1025, "[10][11]" + assert env.rail.get_full_transitions(10, 12) == 37408, "[10][12]" + assert env.rail.get_full_transitions(10, 13) == 0, "[10][13]" + assert env.rail.get_full_transitions(10, 14) == 128, "[10][14]" + assert env.rail.get_full_transitions(10, 15) == 0, "[10][15]" + assert env.rail.get_full_transitions(10, 16) == 0, "[10][16]" + assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]" + assert env.rail.get_full_transitions(10, 18) == 0, "[10][18]" + assert env.rail.get_full_transitions(10, 19) == 0, "[10][19]" + assert env.rail.get_full_transitions(10, 20) == 0, "[10][20]" + assert env.rail.get_full_transitions(10, 21) == 0, "[10][21]" + assert env.rail.get_full_transitions(10, 22) == 32800, "[10][22]" + assert env.rail.get_full_transitions(10, 23) == 0, "[10][23]" + assert env.rail.get_full_transitions(10, 24) == 0, "[10][24]" + assert env.rail.get_full_transitions(11, 0) == 16458, "[11][0]" + assert 
env.rail.get_full_transitions(11, 1) == 17411, "[11][1]" + assert env.rail.get_full_transitions(11, 2) == 1025, "[11][2]" + assert env.rail.get_full_transitions(11, 3) == 52275, "[11][3]" + assert env.rail.get_full_transitions(11, 4) == 3089, "[11][4]" + assert env.rail.get_full_transitions(11, 5) == 2064, "[11][5]" + assert env.rail.get_full_transitions(11, 6) == 0, "[11][6]" + assert env.rail.get_full_transitions(11, 7) == 0, "[11][7]" + assert env.rail.get_full_transitions(11, 8) == 0, "[11][8]" + assert env.rail.get_full_transitions(11, 9) == 0, "[11][9]" + assert env.rail.get_full_transitions(11, 10) == 0, "[11][10]" + assert env.rail.get_full_transitions(11, 11) == 0, "[11][11]" + assert env.rail.get_full_transitions(11, 12) == 32800, "[11][12]" + assert env.rail.get_full_transitions(11, 13) == 0, "[11][13]" + assert env.rail.get_full_transitions(11, 14) == 0, "[11][14]" + assert env.rail.get_full_transitions(11, 15) == 0, "[11][15]" + assert env.rail.get_full_transitions(11, 16) == 0, "[11][16]" + assert env.rail.get_full_transitions(11, 17) == 0, "[11][17]" + assert env.rail.get_full_transitions(11, 18) == 0, "[11][18]" + assert env.rail.get_full_transitions(11, 19) == 0, "[11][19]" + assert env.rail.get_full_transitions(11, 20) == 0, "[11][20]" + assert env.rail.get_full_transitions(11, 21) == 0, "[11][21]" + assert env.rail.get_full_transitions(11, 22) == 32800, "[11][22]" + assert env.rail.get_full_transitions(11, 23) == 0, "[11][23]" + assert env.rail.get_full_transitions(11, 24) == 0, "[11][24]" + assert env.rail.get_full_transitions(12, 0) == 128, "[12][0]" + assert env.rail.get_full_transitions(12, 1) == 128, "[12][1]" + assert env.rail.get_full_transitions(12, 2) == 0, "[12][2]" + assert env.rail.get_full_transitions(12, 3) == 49186, "[12][3]" + assert env.rail.get_full_transitions(12, 4) == 1025, "[12][4]" + assert env.rail.get_full_transitions(12, 5) == 1025, "[12][5]" + assert env.rail.get_full_transitions(12, 6) == 1025, "[12][6]" + assert 
env.rail.get_full_transitions(12, 7) == 1025, "[12][7]" + assert env.rail.get_full_transitions(12, 8) == 1025, "[12][8]" + assert env.rail.get_full_transitions(12, 9) == 1025, "[12][9]" + assert env.rail.get_full_transitions(12, 10) == 1025, "[12][10]" + assert env.rail.get_full_transitions(12, 11) == 1025, "[12][11]" + assert env.rail.get_full_transitions(12, 12) == 34864, "[12][12]" + assert env.rail.get_full_transitions(12, 13) == 0, "[12][13]" + assert env.rail.get_full_transitions(12, 14) == 0, "[12][14]" + assert env.rail.get_full_transitions(12, 15) == 0, "[12][15]" + assert env.rail.get_full_transitions(12, 16) == 0, "[12][16]" + assert env.rail.get_full_transitions(12, 17) == 0, "[12][17]" + assert env.rail.get_full_transitions(12, 18) == 0, "[12][18]" + assert env.rail.get_full_transitions(12, 19) == 0, "[12][19]" + assert env.rail.get_full_transitions(12, 20) == 0, "[12][20]" + assert env.rail.get_full_transitions(12, 21) == 0, "[12][21]" + assert env.rail.get_full_transitions(12, 22) == 32800, "[12][22]" + assert env.rail.get_full_transitions(12, 23) == 0, "[12][23]" + assert env.rail.get_full_transitions(12, 24) == 0, "[12][24]" + assert env.rail.get_full_transitions(13, 0) == 0, "[13][0]" + assert env.rail.get_full_transitions(13, 1) == 0, "[13][1]" + assert env.rail.get_full_transitions(13, 2) == 0, "[13][2]" + assert env.rail.get_full_transitions(13, 3) == 32800, "[13][3]" + assert env.rail.get_full_transitions(13, 4) == 0, "[13][4]" + assert env.rail.get_full_transitions(13, 5) == 0, "[13][5]" + assert env.rail.get_full_transitions(13, 6) == 0, "[13][6]" + assert env.rail.get_full_transitions(13, 7) == 0, "[13][7]" + assert env.rail.get_full_transitions(13, 8) == 0, "[13][8]" + assert env.rail.get_full_transitions(13, 9) == 0, "[13][9]" + assert env.rail.get_full_transitions(13, 10) == 0, "[13][10]" + assert env.rail.get_full_transitions(13, 11) == 0, "[13][11]" + assert env.rail.get_full_transitions(13, 12) == 32800, "[13][12]" + assert 
env.rail.get_full_transitions(13, 13) == 0, "[13][13]" + assert env.rail.get_full_transitions(13, 14) == 0, "[13][14]" + assert env.rail.get_full_transitions(13, 15) == 0, "[13][15]" + assert env.rail.get_full_transitions(13, 16) == 0, "[13][16]" + assert env.rail.get_full_transitions(13, 17) == 0, "[13][17]" + assert env.rail.get_full_transitions(13, 18) == 0, "[13][18]" + assert env.rail.get_full_transitions(13, 19) == 0, "[13][19]" + assert env.rail.get_full_transitions(13, 20) == 0, "[13][20]" + assert env.rail.get_full_transitions(13, 21) == 0, "[13][21]" + assert env.rail.get_full_transitions(13, 22) == 32800, "[13][22]" + assert env.rail.get_full_transitions(13, 23) == 0, "[13][23]" + assert env.rail.get_full_transitions(13, 24) == 0, "[13][24]" + assert env.rail.get_full_transitions(14, 0) == 0, "[14][0]" + assert env.rail.get_full_transitions(14, 1) == 0, "[14][1]" + assert env.rail.get_full_transitions(14, 2) == 0, "[14][2]" + assert env.rail.get_full_transitions(14, 3) == 32800, "[14][3]" + assert env.rail.get_full_transitions(14, 4) == 0, "[14][4]" + assert env.rail.get_full_transitions(14, 5) == 0, "[14][5]" + assert env.rail.get_full_transitions(14, 6) == 0, "[14][6]" + assert env.rail.get_full_transitions(14, 7) == 0, "[14][7]" + assert env.rail.get_full_transitions(14, 8) == 0, "[14][8]" + assert env.rail.get_full_transitions(14, 9) == 0, "[14][9]" + assert env.rail.get_full_transitions(14, 10) == 0, "[14][10]" + assert env.rail.get_full_transitions(14, 11) == 0, "[14][11]" + assert env.rail.get_full_transitions(14, 12) == 32800, "[14][12]" + assert env.rail.get_full_transitions(14, 13) == 0, "[14][13]" + assert env.rail.get_full_transitions(14, 14) == 0, "[14][14]" + assert env.rail.get_full_transitions(14, 15) == 0, "[14][15]" + assert env.rail.get_full_transitions(14, 16) == 0, "[14][16]" + assert env.rail.get_full_transitions(14, 17) == 0, "[14][17]" + assert env.rail.get_full_transitions(14, 18) == 0, "[14][18]" + assert 
env.rail.get_full_transitions(14, 19) == 0, "[14][19]" + assert env.rail.get_full_transitions(14, 20) == 0, "[14][20]" + assert env.rail.get_full_transitions(14, 21) == 0, "[14][21]" + assert env.rail.get_full_transitions(14, 22) == 32800, "[14][22]" + assert env.rail.get_full_transitions(14, 23) == 0, "[14][23]" + assert env.rail.get_full_transitions(14, 24) == 0, "[14][24]" + assert env.rail.get_full_transitions(15, 0) == 0, "[15][0]" + assert env.rail.get_full_transitions(15, 1) == 0, "[15][1]" + assert env.rail.get_full_transitions(15, 2) == 0, "[15][2]" + assert env.rail.get_full_transitions(15, 3) == 32800, "[15][3]" + assert env.rail.get_full_transitions(15, 4) == 0, "[15][4]" + assert env.rail.get_full_transitions(15, 5) == 0, "[15][5]" + assert env.rail.get_full_transitions(15, 6) == 0, "[15][6]" + assert env.rail.get_full_transitions(15, 7) == 0, "[15][7]" + assert env.rail.get_full_transitions(15, 8) == 0, "[15][8]" + assert env.rail.get_full_transitions(15, 9) == 0, "[15][9]" + assert env.rail.get_full_transitions(15, 10) == 0, "[15][10]" + assert env.rail.get_full_transitions(15, 11) == 0, "[15][11]" + assert env.rail.get_full_transitions(15, 12) == 32800, "[15][12]" + assert env.rail.get_full_transitions(15, 13) == 0, "[15][13]" + assert env.rail.get_full_transitions(15, 14) == 0, "[15][14]" + assert env.rail.get_full_transitions(15, 15) == 0, "[15][15]" + assert env.rail.get_full_transitions(15, 16) == 0, "[15][16]" + assert env.rail.get_full_transitions(15, 17) == 0, "[15][17]" + assert env.rail.get_full_transitions(15, 18) == 0, "[15][18]" + assert env.rail.get_full_transitions(15, 19) == 0, "[15][19]" + assert env.rail.get_full_transitions(15, 20) == 0, "[15][20]" + assert env.rail.get_full_transitions(15, 21) == 0, "[15][21]" + assert env.rail.get_full_transitions(15, 22) == 32800, "[15][22]" + assert env.rail.get_full_transitions(15, 23) == 0, "[15][23]" + assert env.rail.get_full_transitions(15, 24) == 0, "[15][24]" + assert 
env.rail.get_full_transitions(16, 0) == 0, "[16][0]" + assert env.rail.get_full_transitions(16, 1) == 0, "[16][1]" + assert env.rail.get_full_transitions(16, 2) == 0, "[16][2]" + assert env.rail.get_full_transitions(16, 3) == 32800, "[16][3]" + assert env.rail.get_full_transitions(16, 4) == 0, "[16][4]" + assert env.rail.get_full_transitions(16, 5) == 0, "[16][5]" + assert env.rail.get_full_transitions(16, 6) == 0, "[16][6]" + assert env.rail.get_full_transitions(16, 7) == 0, "[16][7]" + assert env.rail.get_full_transitions(16, 8) == 0, "[16][8]" + assert env.rail.get_full_transitions(16, 9) == 0, "[16][9]" + assert env.rail.get_full_transitions(16, 10) == 0, "[16][10]" + assert env.rail.get_full_transitions(16, 11) == 0, "[16][11]" + assert env.rail.get_full_transitions(16, 12) == 32800, "[16][12]" + assert env.rail.get_full_transitions(16, 13) == 0, "[16][13]" + assert env.rail.get_full_transitions(16, 14) == 0, "[16][14]" + assert env.rail.get_full_transitions(16, 15) == 0, "[16][15]" + assert env.rail.get_full_transitions(16, 16) == 0, "[16][16]" + assert env.rail.get_full_transitions(16, 17) == 0, "[16][17]" + assert env.rail.get_full_transitions(16, 18) == 0, "[16][18]" + assert env.rail.get_full_transitions(16, 19) == 0, "[16][19]" + assert env.rail.get_full_transitions(16, 20) == 0, "[16][20]" + assert env.rail.get_full_transitions(16, 21) == 0, "[16][21]" + assert env.rail.get_full_transitions(16, 22) == 32800, "[16][22]" + assert env.rail.get_full_transitions(16, 23) == 0, "[16][23]" + assert env.rail.get_full_transitions(16, 24) == 0, "[16][24]" + assert env.rail.get_full_transitions(17, 0) == 0, "[17][0]" + assert env.rail.get_full_transitions(17, 1) == 0, "[17][1]" + assert env.rail.get_full_transitions(17, 2) == 0, "[17][2]" + assert env.rail.get_full_transitions(17, 3) == 32800, "[17][3]" + assert env.rail.get_full_transitions(17, 4) == 0, "[17][4]" + assert env.rail.get_full_transitions(17, 5) == 0, "[17][5]" + assert 
env.rail.get_full_transitions(17, 6) == 0, "[17][6]" + assert env.rail.get_full_transitions(17, 7) == 0, "[17][7]" + assert env.rail.get_full_transitions(17, 8) == 0, "[17][8]" + assert env.rail.get_full_transitions(17, 9) == 0, "[17][9]" + assert env.rail.get_full_transitions(17, 10) == 0, "[17][10]" + assert env.rail.get_full_transitions(17, 11) == 0, "[17][11]" + assert env.rail.get_full_transitions(17, 12) == 32800, "[17][12]" + assert env.rail.get_full_transitions(17, 13) == 0, "[17][13]" + assert env.rail.get_full_transitions(17, 14) == 0, "[17][14]" + assert env.rail.get_full_transitions(17, 15) == 0, "[17][15]" + assert env.rail.get_full_transitions(17, 16) == 0, "[17][16]" + assert env.rail.get_full_transitions(17, 17) == 0, "[17][17]" + assert env.rail.get_full_transitions(17, 18) == 0, "[17][18]" + assert env.rail.get_full_transitions(17, 19) == 0, "[17][19]" + assert env.rail.get_full_transitions(17, 20) == 0, "[17][20]" + assert env.rail.get_full_transitions(17, 21) == 0, "[17][21]" + assert env.rail.get_full_transitions(17, 22) == 32800, "[17][22]" + assert env.rail.get_full_transitions(17, 23) == 0, "[17][23]" + assert env.rail.get_full_transitions(17, 24) == 0, "[17][24]" + assert env.rail.get_full_transitions(18, 0) == 0, "[18][0]" + assert env.rail.get_full_transitions(18, 1) == 0, "[18][1]" + assert env.rail.get_full_transitions(18, 2) == 0, "[18][2]" + assert env.rail.get_full_transitions(18, 3) == 32800, "[18][3]" + assert env.rail.get_full_transitions(18, 4) == 0, "[18][4]" + assert env.rail.get_full_transitions(18, 5) == 0, "[18][5]" + assert env.rail.get_full_transitions(18, 6) == 0, "[18][6]" + assert env.rail.get_full_transitions(18, 7) == 0, "[18][7]" + assert env.rail.get_full_transitions(18, 8) == 0, "[18][8]" + assert env.rail.get_full_transitions(18, 9) == 0, "[18][9]" + assert env.rail.get_full_transitions(18, 10) == 0, "[18][10]" + assert env.rail.get_full_transitions(18, 11) == 0, "[18][11]" + assert 
env.rail.get_full_transitions(18, 12) == 32800, "[18][12]" + assert env.rail.get_full_transitions(18, 13) == 0, "[18][13]" + assert env.rail.get_full_transitions(18, 14) == 0, "[18][14]" + assert env.rail.get_full_transitions(18, 15) == 0, "[18][15]" + assert env.rail.get_full_transitions(18, 16) == 0, "[18][16]" + assert env.rail.get_full_transitions(18, 17) == 0, "[18][17]" + assert env.rail.get_full_transitions(18, 18) == 0, "[18][18]" + assert env.rail.get_full_transitions(18, 19) == 0, "[18][19]" + assert env.rail.get_full_transitions(18, 20) == 0, "[18][20]" + assert env.rail.get_full_transitions(18, 21) == 0, "[18][21]" + assert env.rail.get_full_transitions(18, 22) == 32800, "[18][22]" + assert env.rail.get_full_transitions(18, 23) == 0, "[18][23]" + assert env.rail.get_full_transitions(18, 24) == 0, "[18][24]" + assert env.rail.get_full_transitions(19, 0) == 0, "[19][0]" + assert env.rail.get_full_transitions(19, 1) == 0, "[19][1]" + assert env.rail.get_full_transitions(19, 2) == 0, "[19][2]" + assert env.rail.get_full_transitions(19, 3) == 32872, "[19][3]" + assert env.rail.get_full_transitions(19, 4) == 1025, "[19][4]" + assert env.rail.get_full_transitions(19, 5) == 1025, "[19][5]" + assert env.rail.get_full_transitions(19, 6) == 1025, "[19][6]" + assert env.rail.get_full_transitions(19, 7) == 1025, "[19][7]" + assert env.rail.get_full_transitions(19, 8) == 1025, "[19][8]" + assert env.rail.get_full_transitions(19, 9) == 1025, "[19][9]" + assert env.rail.get_full_transitions(19, 10) == 1025, "[19][10]" + assert env.rail.get_full_transitions(19, 11) == 1025, "[19][11]" + assert env.rail.get_full_transitions(19, 12) == 6672, "[19][12]" + assert env.rail.get_full_transitions(19, 13) == 0, "[19][13]" + assert env.rail.get_full_transitions(19, 14) == 0, "[19][14]" + assert env.rail.get_full_transitions(19, 15) == 0, "[19][15]" + assert env.rail.get_full_transitions(19, 16) == 0, "[19][16]" + assert env.rail.get_full_transitions(19, 17) == 0, "[19][17]" + 
assert env.rail.get_full_transitions(19, 18) == 0, "[19][18]" + assert env.rail.get_full_transitions(19, 19) == 0, "[19][19]" + assert env.rail.get_full_transitions(19, 20) == 0, "[19][20]" + assert env.rail.get_full_transitions(19, 21) == 0, "[19][21]" + assert env.rail.get_full_transitions(19, 22) == 32800, "[19][22]" + assert env.rail.get_full_transitions(19, 23) == 0, "[19][23]" + assert env.rail.get_full_transitions(19, 24) == 0, "[19][24]" + assert env.rail.get_full_transitions(20, 0) == 0, "[20][0]" + assert env.rail.get_full_transitions(20, 1) == 0, "[20][1]" + assert env.rail.get_full_transitions(20, 2) == 0, "[20][2]" + assert env.rail.get_full_transitions(20, 3) == 32800, "[20][3]" + assert env.rail.get_full_transitions(20, 4) == 0, "[20][4]" + assert env.rail.get_full_transitions(20, 5) == 0, "[20][5]" + assert env.rail.get_full_transitions(20, 6) == 0, "[20][6]" + assert env.rail.get_full_transitions(20, 7) == 0, "[20][7]" + assert env.rail.get_full_transitions(20, 8) == 0, "[20][8]" + assert env.rail.get_full_transitions(20, 9) == 0, "[20][9]" + assert env.rail.get_full_transitions(20, 10) == 0, "[20][10]" + assert env.rail.get_full_transitions(20, 11) == 0, "[20][11]" + assert env.rail.get_full_transitions(20, 12) == 32800, "[20][12]" + assert env.rail.get_full_transitions(20, 13) == 0, "[20][13]" + assert env.rail.get_full_transitions(20, 14) == 0, "[20][14]" + assert env.rail.get_full_transitions(20, 15) == 0, "[20][15]" + assert env.rail.get_full_transitions(20, 16) == 0, "[20][16]" + assert env.rail.get_full_transitions(20, 17) == 0, "[20][17]" + assert env.rail.get_full_transitions(20, 18) == 0, "[20][18]" + assert env.rail.get_full_transitions(20, 19) == 0, "[20][19]" + assert env.rail.get_full_transitions(20, 20) == 0, "[20][20]" + assert env.rail.get_full_transitions(20, 21) == 0, "[20][21]" + assert env.rail.get_full_transitions(20, 22) == 32800, "[20][22]" + assert env.rail.get_full_transitions(20, 23) == 0, "[20][23]" + assert 
env.rail.get_full_transitions(20, 24) == 0, "[20][24]" + assert env.rail.get_full_transitions(21, 0) == 0, "[21][0]" + assert env.rail.get_full_transitions(21, 1) == 0, "[21][1]" + assert env.rail.get_full_transitions(21, 2) == 0, "[21][2]" + assert env.rail.get_full_transitions(21, 3) == 32800, "[21][3]" + assert env.rail.get_full_transitions(21, 4) == 0, "[21][4]" + assert env.rail.get_full_transitions(21, 5) == 0, "[21][5]" + assert env.rail.get_full_transitions(21, 6) == 0, "[21][6]" + assert env.rail.get_full_transitions(21, 7) == 0, "[21][7]" + assert env.rail.get_full_transitions(21, 8) == 0, "[21][8]" + assert env.rail.get_full_transitions(21, 9) == 0, "[21][9]" + assert env.rail.get_full_transitions(21, 10) == 0, "[21][10]" + assert env.rail.get_full_transitions(21, 11) == 0, "[21][11]" + assert env.rail.get_full_transitions(21, 12) == 32800, "[21][12]" + assert env.rail.get_full_transitions(21, 13) == 0, "[21][13]" + assert env.rail.get_full_transitions(21, 14) == 0, "[21][14]" + assert env.rail.get_full_transitions(21, 15) == 0, "[21][15]" + assert env.rail.get_full_transitions(21, 16) == 0, "[21][16]" + assert env.rail.get_full_transitions(21, 17) == 0, "[21][17]" + assert env.rail.get_full_transitions(21, 18) == 0, "[21][18]" + assert env.rail.get_full_transitions(21, 19) == 0, "[21][19]" + assert env.rail.get_full_transitions(21, 20) == 0, "[21][20]" + assert env.rail.get_full_transitions(21, 21) == 0, "[21][21]" + assert env.rail.get_full_transitions(21, 22) == 32800, "[21][22]" + assert env.rail.get_full_transitions(21, 23) == 0, "[21][23]" + assert env.rail.get_full_transitions(21, 24) == 0, "[21][24]" + assert env.rail.get_full_transitions(22, 0) == 0, "[22][0]" + assert env.rail.get_full_transitions(22, 1) == 0, "[22][1]" + assert env.rail.get_full_transitions(22, 2) == 0, "[22][2]" + assert env.rail.get_full_transitions(22, 3) == 32800, "[22][3]" + assert env.rail.get_full_transitions(22, 4) == 0, "[22][4]" + assert 
env.rail.get_full_transitions(22, 5) == 0, "[22][5]" + assert env.rail.get_full_transitions(22, 6) == 0, "[22][6]" + assert env.rail.get_full_transitions(22, 7) == 0, "[22][7]" + assert env.rail.get_full_transitions(22, 8) == 0, "[22][8]" + assert env.rail.get_full_transitions(22, 9) == 0, "[22][9]" + assert env.rail.get_full_transitions(22, 10) == 0, "[22][10]" + assert env.rail.get_full_transitions(22, 11) == 0, "[22][11]" + assert env.rail.get_full_transitions(22, 12) == 32800, "[22][12]" + assert env.rail.get_full_transitions(22, 13) == 0, "[22][13]" + assert env.rail.get_full_transitions(22, 14) == 0, "[22][14]" + assert env.rail.get_full_transitions(22, 15) == 0, "[22][15]" + assert env.rail.get_full_transitions(22, 16) == 0, "[22][16]" + assert env.rail.get_full_transitions(22, 17) == 0, "[22][17]" + assert env.rail.get_full_transitions(22, 18) == 0, "[22][18]" + assert env.rail.get_full_transitions(22, 19) == 0, "[22][19]" + assert env.rail.get_full_transitions(22, 20) == 0, "[22][20]" + assert env.rail.get_full_transitions(22, 21) == 0, "[22][21]" + assert env.rail.get_full_transitions(22, 22) == 32800, "[22][22]" + assert env.rail.get_full_transitions(22, 23) == 0, "[22][23]" + assert env.rail.get_full_transitions(22, 24) == 0, "[22][24]" + assert env.rail.get_full_transitions(23, 0) == 0, "[23][0]" + assert env.rail.get_full_transitions(23, 1) == 0, "[23][1]" + assert env.rail.get_full_transitions(23, 2) == 0, "[23][2]" + assert env.rail.get_full_transitions(23, 3) == 32800, "[23][3]" + assert env.rail.get_full_transitions(23, 4) == 0, "[23][4]" + assert env.rail.get_full_transitions(23, 5) == 0, "[23][5]" + assert env.rail.get_full_transitions(23, 6) == 0, "[23][6]" + assert env.rail.get_full_transitions(23, 7) == 0, "[23][7]" + assert env.rail.get_full_transitions(23, 8) == 0, "[23][8]" + assert env.rail.get_full_transitions(23, 9) == 0, "[23][9]" + assert env.rail.get_full_transitions(23, 10) == 0, "[23][10]" + assert env.rail.get_full_transitions(23, 
11) == 0, "[23][11]" + assert env.rail.get_full_transitions(23, 12) == 32800, "[23][12]" + assert env.rail.get_full_transitions(23, 13) == 0, "[23][13]" + assert env.rail.get_full_transitions(23, 14) == 0, "[23][14]" + assert env.rail.get_full_transitions(23, 15) == 0, "[23][15]" + assert env.rail.get_full_transitions(23, 16) == 0, "[23][16]" + assert env.rail.get_full_transitions(23, 17) == 0, "[23][17]" + assert env.rail.get_full_transitions(23, 18) == 0, "[23][18]" + assert env.rail.get_full_transitions(23, 19) == 0, "[23][19]" + assert env.rail.get_full_transitions(23, 20) == 0, "[23][20]" + assert env.rail.get_full_transitions(23, 21) == 0, "[23][21]" + assert env.rail.get_full_transitions(23, 22) == 32800, "[23][22]" + assert env.rail.get_full_transitions(23, 23) == 0, "[23][23]" + assert env.rail.get_full_transitions(23, 24) == 0, "[23][24]" + assert env.rail.get_full_transitions(24, 0) == 0, "[24][0]" + assert env.rail.get_full_transitions(24, 1) == 0, "[24][1]" + assert env.rail.get_full_transitions(24, 2) == 0, "[24][2]" + assert env.rail.get_full_transitions(24, 3) == 32800, "[24][3]" + assert env.rail.get_full_transitions(24, 4) == 0, "[24][4]" + assert env.rail.get_full_transitions(24, 5) == 0, "[24][5]" + assert env.rail.get_full_transitions(24, 6) == 0, "[24][6]" + assert env.rail.get_full_transitions(24, 7) == 0, "[24][7]" + assert env.rail.get_full_transitions(24, 8) == 0, "[24][8]" + assert env.rail.get_full_transitions(24, 9) == 8192, "[24][9]" + assert env.rail.get_full_transitions(24, 10) == 0, "[24][10]" + assert env.rail.get_full_transitions(24, 11) == 0, "[24][11]" + assert env.rail.get_full_transitions(24, 12) == 32800, "[24][12]" + assert env.rail.get_full_transitions(24, 13) == 0, "[24][13]" + assert env.rail.get_full_transitions(24, 14) == 0, "[24][14]" + assert env.rail.get_full_transitions(24, 15) == 0, "[24][15]" + assert env.rail.get_full_transitions(24, 16) == 0, "[24][16]" + assert env.rail.get_full_transitions(24, 17) == 0, 
"[24][17]" + assert env.rail.get_full_transitions(24, 18) == 0, "[24][18]" + assert env.rail.get_full_transitions(24, 19) == 0, "[24][19]" + assert env.rail.get_full_transitions(24, 20) == 0, "[24][20]" + assert env.rail.get_full_transitions(24, 21) == 0, "[24][21]" + assert env.rail.get_full_transitions(24, 22) == 32800, "[24][22]" + assert env.rail.get_full_transitions(24, 23) == 0, "[24][23]" + assert env.rail.get_full_transitions(24, 24) == 0, "[24][24]" + assert env.rail.get_full_transitions(25, 0) == 0, "[25][0]" + assert env.rail.get_full_transitions(25, 1) == 0, "[25][1]" + assert env.rail.get_full_transitions(25, 2) == 0, "[25][2]" + assert env.rail.get_full_transitions(25, 3) == 32800, "[25][3]" + assert env.rail.get_full_transitions(25, 4) == 0, "[25][4]" + assert env.rail.get_full_transitions(25, 5) == 8192, "[25][5]" + assert env.rail.get_full_transitions(25, 6) == 0, "[25][6]" + assert env.rail.get_full_transitions(25, 7) == 0, "[25][7]" + assert env.rail.get_full_transitions(25, 8) == 0, "[25][8]" + assert env.rail.get_full_transitions(25, 9) == 32800, "[25][9]" + assert env.rail.get_full_transitions(25, 10) == 0, "[25][10]" + assert env.rail.get_full_transitions(25, 11) == 8192, "[25][11]" + assert env.rail.get_full_transitions(25, 12) == 32800, "[25][12]" + assert env.rail.get_full_transitions(25, 13) == 0, "[25][13]" + assert env.rail.get_full_transitions(25, 14) == 0, "[25][14]" + assert env.rail.get_full_transitions(25, 15) == 0, "[25][15]" + assert env.rail.get_full_transitions(25, 16) == 0, "[25][16]" + assert env.rail.get_full_transitions(25, 17) == 0, "[25][17]" + assert env.rail.get_full_transitions(25, 18) == 0, "[25][18]" + assert env.rail.get_full_transitions(25, 19) == 0, "[25][19]" + assert env.rail.get_full_transitions(25, 20) == 0, "[25][20]" + assert env.rail.get_full_transitions(25, 21) == 0, "[25][21]" + assert env.rail.get_full_transitions(25, 22) == 32800, "[25][22]" + assert env.rail.get_full_transitions(25, 23) == 0, 
"[25][23]" + assert env.rail.get_full_transitions(25, 24) == 0, "[25][24]" + assert env.rail.get_full_transitions(26, 0) == 8192, "[26][0]" + assert env.rail.get_full_transitions(26, 1) == 4, "[26][1]" + assert env.rail.get_full_transitions(26, 2) == 4608, "[26][2]" + assert env.rail.get_full_transitions(26, 3) == 32800, "[26][3]" + assert env.rail.get_full_transitions(26, 4) == 0, "[26][4]" + assert env.rail.get_full_transitions(26, 5) == 32800, "[26][5]" + assert env.rail.get_full_transitions(26, 6) == 0, "[26][6]" + assert env.rail.get_full_transitions(26, 7) == 0, "[26][7]" + assert env.rail.get_full_transitions(26, 8) == 0, "[26][8]" + assert env.rail.get_full_transitions(26, 9) == 32800, "[26][9]" + assert env.rail.get_full_transitions(26, 10) == 0, "[26][10]" + assert env.rail.get_full_transitions(26, 11) == 32800, "[26][11]" + assert env.rail.get_full_transitions(26, 12) == 32800, "[26][12]" + assert env.rail.get_full_transitions(26, 13) == 0, "[26][13]" + assert env.rail.get_full_transitions(26, 14) == 0, "[26][14]" + assert env.rail.get_full_transitions(26, 15) == 0, "[26][15]" + assert env.rail.get_full_transitions(26, 16) == 0, "[26][16]" + assert env.rail.get_full_transitions(26, 17) == 0, "[26][17]" + assert env.rail.get_full_transitions(26, 18) == 0, "[26][18]" + assert env.rail.get_full_transitions(26, 19) == 0, "[26][19]" + assert env.rail.get_full_transitions(26, 20) == 0, "[26][20]" + assert env.rail.get_full_transitions(26, 21) == 0, "[26][21]" + assert env.rail.get_full_transitions(26, 22) == 32800, "[26][22]" + assert env.rail.get_full_transitions(26, 23) == 0, "[26][23]" + assert env.rail.get_full_transitions(26, 24) == 0, "[26][24]" + assert env.rail.get_full_transitions(27, 0) == 72, "[27][0]" + assert env.rail.get_full_transitions(27, 1) == 17411, "[27][1]" + assert env.rail.get_full_transitions(27, 2) == 1097, "[27][2]" + assert env.rail.get_full_transitions(27, 3) == 1097, "[27][3]" + assert env.rail.get_full_transitions(27, 4) == 5633, 
"[27][4]" + assert env.rail.get_full_transitions(27, 5) == 3089, "[27][5]" + assert env.rail.get_full_transitions(27, 6) == 1025, "[27][6]" + assert env.rail.get_full_transitions(27, 7) == 1025, "[27][7]" + assert env.rail.get_full_transitions(27, 8) == 1025, "[27][8]" + assert env.rail.get_full_transitions(27, 9) == 1097, "[27][9]" + assert env.rail.get_full_transitions(27, 10) == 17411, "[27][10]" + assert env.rail.get_full_transitions(27, 11) == 1097, "[27][11]" + assert env.rail.get_full_transitions(27, 12) == 1097, "[27][12]" + assert env.rail.get_full_transitions(27, 13) == 5633, "[27][13]" + assert env.rail.get_full_transitions(27, 14) == 1025, "[27][14]" + assert env.rail.get_full_transitions(27, 15) == 1025, "[27][15]" + assert env.rail.get_full_transitions(27, 16) == 1025, "[27][16]" + assert env.rail.get_full_transitions(27, 17) == 1025, "[27][17]" + assert env.rail.get_full_transitions(27, 18) == 1025, "[27][18]" + assert env.rail.get_full_transitions(27, 19) == 1025, "[27][19]" + assert env.rail.get_full_transitions(27, 20) == 1025, "[27][20]" + assert env.rail.get_full_transitions(27, 21) == 1025, "[27][21]" + assert env.rail.get_full_transitions(27, 22) == 2064, "[27][22]" + assert env.rail.get_full_transitions(27, 23) == 0, "[27][23]" + assert env.rail.get_full_transitions(27, 24) == 0, "[27][24]" + assert env.rail.get_full_transitions(28, 0) == 0, "[28][0]" + assert env.rail.get_full_transitions(28, 1) == 32800, "[28][1]" + assert env.rail.get_full_transitions(28, 2) == 0, "[28][2]" + assert env.rail.get_full_transitions(28, 3) == 0, "[28][3]" + assert env.rail.get_full_transitions(28, 4) == 72, "[28][4]" + assert env.rail.get_full_transitions(28, 5) == 256, "[28][5]" + assert env.rail.get_full_transitions(28, 6) == 0, "[28][6]" + assert env.rail.get_full_transitions(28, 7) == 0, "[28][7]" + assert env.rail.get_full_transitions(28, 8) == 0, "[28][8]" + assert env.rail.get_full_transitions(28, 9) == 0, "[28][9]" + assert 
env.rail.get_full_transitions(28, 10) == 32800, "[28][10]" + assert env.rail.get_full_transitions(28, 11) == 0, "[28][11]" + assert env.rail.get_full_transitions(28, 12) == 16386, "[28][12]" + assert env.rail.get_full_transitions(28, 13) == 34864, "[28][13]" + assert env.rail.get_full_transitions(28, 14) == 0, "[28][14]" + assert env.rail.get_full_transitions(28, 15) == 0, "[28][15]" + assert env.rail.get_full_transitions(28, 16) == 0, "[28][16]" + assert env.rail.get_full_transitions(28, 17) == 0, "[28][17]" + assert env.rail.get_full_transitions(28, 18) == 0, "[28][18]" + assert env.rail.get_full_transitions(28, 19) == 0, "[28][19]" + assert env.rail.get_full_transitions(28, 20) == 0, "[28][20]" + assert env.rail.get_full_transitions(28, 21) == 0, "[28][21]" + assert env.rail.get_full_transitions(28, 22) == 0, "[28][22]" + assert env.rail.get_full_transitions(28, 23) == 0, "[28][23]" + assert env.rail.get_full_transitions(28, 24) == 0, "[28][24]" + assert env.rail.get_full_transitions(29, 0) == 0, "[29][0]" + assert env.rail.get_full_transitions(29, 1) == 128, "[29][1]" + assert env.rail.get_full_transitions(29, 2) == 0, "[29][2]" + assert env.rail.get_full_transitions(29, 3) == 0, "[29][3]" + assert env.rail.get_full_transitions(29, 4) == 0, "[29][4]" + assert env.rail.get_full_transitions(29, 5) == 0, "[29][5]" + assert env.rail.get_full_transitions(29, 6) == 0, "[29][6]" + assert env.rail.get_full_transitions(29, 7) == 0, "[29][7]" + assert env.rail.get_full_transitions(29, 8) == 0, "[29][8]" + assert env.rail.get_full_transitions(29, 9) == 0, "[29][9]" + assert env.rail.get_full_transitions(29, 10) == 128, "[29][10]" + assert env.rail.get_full_transitions(29, 11) == 0, "[29][11]" + assert env.rail.get_full_transitions(29, 12) == 128, "[29][12]" + assert env.rail.get_full_transitions(29, 13) == 128, "[29][13]" + assert env.rail.get_full_transitions(29, 14) == 0, "[29][14]" + assert env.rail.get_full_transitions(29, 15) == 0, "[29][15]" + assert 
env.rail.get_full_transitions(29, 16) == 0, "[29][16]" + assert env.rail.get_full_transitions(29, 17) == 0, "[29][17]" + assert env.rail.get_full_transitions(29, 18) == 0, "[29][18]" + assert env.rail.get_full_transitions(29, 19) == 0, "[29][19]" + assert env.rail.get_full_transitions(29, 20) == 0, "[29][20]" + assert env.rail.get_full_transitions(29, 21) == 0, "[29][21]" + assert env.rail.get_full_transitions(29, 22) == 0, "[29][22]" + assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]" + assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]" + + def test_rail_env_action_required_info(): np.random.seed(0) speed_ration_map = {1.: 0.25, # Fast passenger train diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 81b61381ed67d927cac44f4c9733d8a040903ef5..fde9df58663993ae170c4c1e3fea55637feb4282 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,10 +1,16 @@ +import random + import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator +from flatland.utils.rendertools import RenderTool +from test_utils import Replay class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -42,7 +48,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: new_position = 
get_new_position(agent.position, direction) - min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction]) + min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) @@ -121,6 +127,7 @@ def test_malfunction_process_statistically(): 'min_duration': 3, 'max_duration': 3} np.random.seed(5) + random.seed(0) env = RailEnv(width=20, height=20, @@ -144,5 +151,328 @@ def test_malfunction_process_statistically(): env.step(action_dict) # check that generation of malfunctions works as expected - # results are different in py36 and py37, therefore no exact test on nb_malfunction - assert nb_malfunction > 150 + assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction) + + +def test_initial_malfunction(rendering=True): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurrence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. 
/ 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=3 + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=2 + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=1 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(27, 4), + direction=Grid4TransitionsEnum.NORTH, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay in enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, 
replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action is not None: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) + + +def test_initial_malfunction_stop_moving(rendering=True): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurrence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. 
/ 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=3 + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=2 + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action STOP_MOVING, agent should recover but remain stopped + # + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.STOP_MOVING, + malfunction=1 + ), + # we have stopped and do nothing --> should stand still + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=0 + ), + # we start to move forward --> should go to next cell now + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay 
in enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action is not None: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) + + +def test_initial_malfunction_do_nothing(rendering=True): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurrence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. 
/ 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=3 + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=2 + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action DO_NOTHING, agent should restart without moving + # + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=1 + ), + # we haven't started moving yet --> stay here + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=0 + ), + # we start to move forward --> should go to next cell now + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay in 
enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action is not None: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 86edc08c07552488e72537ec9b1f3b0b7625efed..1cf0c325ac48e9e3d5ac04fb51b5f8462c867726 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,7 +1,4 @@ -from typing import List - import numpy as np -from attr import attrib, attrs from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic @@ -12,6 +9,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail +from test_utils import ReplayConfig, Replay np.random.seed(1) @@ -97,21 +95,8 @@ def test_multi_speed_init(): old_pos[i_agent] = env.agents[i_agent].position -@attrs -class Replay(object): - position = attrib() - direction = attrib() - action = attrib(type=RailEnvActions) - malfunction = attrib(default=0, type=int) - - -@attrs -class TestConfig(object): - replay = attrib(type=List[Replay]) - target = attrib() - speed = attrib(type=float) - 
- +# TODO test penalties! +# TODO test invalid actions! def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() @@ -132,7 +117,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): if rendering: renderer = RenderTool(env, gl="PILSVG") - test_config = TestConfig( + test_config = ReplayConfig( replay=[ Replay( position=(3, 9), # east dead-end @@ -179,6 +164,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): direction=Grid4TransitionsEnum.SOUTH, action=RailEnvActions.STOP_MOVING ), + # Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, @@ -205,7 +191,6 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): speed=0.5 ) - # TODO test penalties! agentStatic: EnvAgentStatic = env.agents_static[0] info_dict = { 'action_required': [True] @@ -230,7 +215,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): _assert(agent.position, replay.position, 'position') _assert(agent.direction, replay.direction, 'direction') - if replay.action: + if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) _, _, _, info_dict = env.step({0: replay.action}) @@ -263,7 +248,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): renderer = RenderTool(env, gl="PILSVG") test_configs = [ - TestConfig( + ReplayConfig( replay=[ Replay( position=(3, 8), @@ -331,7 +316,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): ], target=(3, 0), # west dead-end speed=1 / 3), - TestConfig( + ReplayConfig( replay=[ Replay( position=(3, 9), # east dead-end @@ -438,13 +423,13 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') - - - if 
replay.action: - assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True) + if replay.action is not None: + assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format( + step, a, True) action_dict[a] = replay.action else: - assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False) + assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format( + step, a, False) _, _, _, info_dict = env.step(action_dict) if rendering: @@ -471,7 +456,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): if rendering: renderer = RenderTool(env, gl="PILSVG") - test_config = TestConfig( + test_config = ReplayConfig( replay=[ Replay( position=(3, 9), # east dead-end @@ -493,7 +478,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, - malfunction=2 # recovers in two steps from now! + malfunction=2 # recovers in two steps from now! ), # agent recovers in this step Replay( @@ -515,7 +500,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, - malfunction=2 # recovers in two steps from now! + malfunction=2 # recovers in two steps from now! ), # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken! Replay( @@ -548,9 +533,20 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): direction=Grid4TransitionsEnum.SOUTH, action=None ), + # DO_NOTHING keeps moving! 
Replay( position=(5, 6), direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.DO_NOTHING + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(6, 6), + direction=Grid4TransitionsEnum.SOUTH, action=RailEnvActions.MOVE_FORWARD ), @@ -584,10 +580,11 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): _assert(agent.position, replay.position, 'position') _assert(agent.direction, replay.direction, 'direction') - if replay.malfunction: - agent.malfunction_data['malfunction'] = 2 + if replay.malfunction > 0: + agent.malfunction_data['malfunction'] = replay.malfunction + agent.malfunction_data['moving_before_malfunction'] = agent.moving - if replay.action: + if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) _, _, _, info_dict = env.step({0: replay.action}) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6347bd0f5048350c099ba2568dac7caba74baf2d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,22 @@ +"""Test Utils.""" +from typing import List, Tuple + +from attr import attrs, attrib + +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.rail_env import RailEnvActions + + +@attrs +class Replay(object): + position = attrib(type=Tuple[int, int]) + direction = attrib(type=Grid4TransitionsEnum) + action = attrib(type=RailEnvActions) + malfunction = attrib(default=0, type=int) + + +@attrs +class ReplayConfig(object): + replay = attrib(type=List[Replay]) + target = attrib(type=Tuple[int, int]) + speed = attrib(type=float)