From 29cc5cda3c7773112b46ab16f0a3e8fd0d6a17d6 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 16 Sep 2019 19:22:04 +0200 Subject: [PATCH] #178 bugfix initial malfunction --- flatland/envs/rail_env.py | 116 +++++++++++++++++++++-------- tests/test_flatland_malfunction.py | 111 ++++++++++++++++++++++++++- tests/test_multi_speed.py | 32 +++----- tests/test_utils.py | 21 ++++++ 4 files changed, 223 insertions(+), 57 deletions(-) create mode 100644 tests/test_utils.py diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index e4d69306..b7483b02 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,13 +4,14 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum -from typing import List +from typing import List, Set, NamedTuple import msgpack import msgpack_numpy as m import numpy as np from flatland.core.env import Environment +from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -39,6 +40,11 @@ class RailEnvActions(IntEnum): }[a] +RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) +RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos), + ('next_direction', Grid4TransitionsEnum)]) + + class RailEnv(Environment): """ RailEnv environment class. @@ -262,7 +268,18 @@ class RailEnv(Environment): agent.malfunction_data['malfunction'] = 0 - self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING) + initial_malfunction = self._agent_new_malfunction(i_agent) + if initial_malfunction: + valid_actions = set(map(lambda x: x.action, self.get_valid_move_actions(agent))) + if RailEnvActions.MOVE_FORWARD in valid_actions: + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + elif RailEnvActions.MOVE_LEFT in valid_actions: + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_LEFT + elif RailEnvActions.MOVE_RIGHT in valid_actions: + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_RIGHT + else: + raise Exception( + "Agent {} cannot move forward/left/right from initial position".format(agent.handle)) self.num_resets += 1 self._elapsed_steps = 0 @@ -277,7 +294,7 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def _agent_new_malfunction(self, i_agent, action) -> bool: + def _agent_new_malfunction(self, i_agent) -> bool: """ Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before). """ @@ -335,25 +352,25 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position - # No action has been supplied for this agent -> set DO_NOTHING as default - if i_agent not in action_dict_: - action = RailEnvActions.DO_NOTHING - else: - action = action_dict_[i_agent] - - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING - # Check if agent breaks at this step - new_malfunction = self._agent_new_malfunction(i_agent, action) + new_malfunction = self._agent_new_malfunction(i_agent) # Is the agent at the beginning of the cell? Then, it can take an action # Design choice (Erik+Christian): # as long as we're broken down at the beginning of the cell, we can choose other actions! if agent.speed_data['position_fraction'] == 0.0: + # No action has been supplied for this agent -> set DO_NOTHING as default + if i_agent not in action_dict_: + action = RailEnvActions.DO_NOTHING + else: + action = action_dict_[i_agent] + + if action < 0 or action > len(RailEnvActions): + print('ERROR: illegal action=', action, + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action = RailEnvActions.DO_NOTHING + if action == RailEnvActions.DO_NOTHING and agent.moving: # Keep moving action = RailEnvActions.MOVE_FORWARD @@ -370,12 +387,14 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.start_penalty # Store the action - if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]: + if agent.moving: + _action_stored = False _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = action + _action_stored = True else: # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! @@ -385,19 +404,14 @@ class RailEnv(Environment): if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - else: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - - else: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False + _action_stored = True + + if not _action_stored: + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.stop_penalty + agent.moving = False # if we've just broken in this step, nothing else to do if new_malfunction: @@ -410,7 +424,6 @@ class RailEnv(Environment): if agent.malfunction_data['malfunction'] < 2: agent.malfunction_data['malfunction'] -= 1 self.agents[i_agent].moving = True - action = RailEnvActions.DO_NOTHING else: agent.malfunction_data['malfunction'] -= 1 @@ -438,6 +451,9 @@ class RailEnv(Environment): cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( agent.speed_data['transition_action_on_cellexit'], agent) + # N.B. validity of new_cell and transition should have been verified before the action was stored! + assert new_cell_valid + assert transition_valid if cell_free: agent.position = new_position agent.direction = new_direction @@ -532,6 +548,44 @@ class RailEnv(Environment): transition_valid = True return new_direction, transition_valid + def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]: + valid_actions: Set[RailEnvNextAction] = set() + agent_position = agent.position + agent_direction = agent.direction + possible_transitions = self.rail.get_transitions(*agent_position, agent_direction) + num_transitions = np.count_nonzero(possible_transitions) + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if self.rail.is_dead_end(agent_position): + action = RailEnvActions.MOVE_FORWARD + exit_direction = (agent_direction + 2) % 4 + if possible_transitions[exit_direction]: + new_position = get_new_position(agent_position, exit_direction) + valid_actions.add(RailEnvNextAction(action, new_position, exit_direction)) + elif num_transitions == 1: + action = RailEnvActions.MOVE_FORWARD + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + else: + for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[new_direction]: + if new_direction == agent_direction: + action = RailEnvActions.MOVE_FORWARD + elif new_direction == (agent_direction + 1) % 4: + action = RailEnvActions.MOVE_RIGHT + elif new_direction == (agent_direction - 1) % 4: + action = RailEnvActions.MOVE_LEFT + else: + raise Exception("Illegal state") + + new_position = get_new_position(agent_position, new_direction) + valid_actions.add(RailEnvNextAction(action, new_position, new_direction)) + return valid_actions + def _get_observations(self): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index a63e9722..e74666e1 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -1,9 +1,15 @@ +import random + import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator +from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator +from flatland.utils.rendertools import RenderTool +from test_utils import Replay class SingleAgentNavigationObs(TreeObsForRailEnv): @@ -145,3 +151,102 @@ def test_malfunction_process_statistically(): # check that generation of malfunctions works as expected # results are different in py36 and py37, therefore no exact test on nb_malfunction assert nb_malfunction > 150 + + +# TODO test DO_NOTHING! +def test_initial_malfunction(rendering=True): + random.seed(0) + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(27, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=3 + ), + Replay( + position=(27, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=2 + ), + Replay( + position=(27, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=1 + ), + Replay( + position=(27, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(27, 3), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay in enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 86edc08c..529e9412 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,7 +1,6 @@ -from typing import List +import time import numpy as np -from attr import attrib, attrs from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic @@ -12,6 +11,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail +from test_utils import TestConfig, Replay np.random.seed(1) @@ -97,21 +97,6 @@ def test_multi_speed_init(): old_pos[i_agent] = env.agents[i_agent].position -@attrs -class Replay(object): - position = attrib() - direction = attrib() - action = attrib(type=RailEnvActions) - malfunction = attrib(default=0, type=int) - - -@attrs -class TestConfig(object): - replay = attrib(type=List[Replay]) - target = attrib() - speed = attrib(type=float) - - def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): """Test that actions are correctly performed on cell exit for a single agent.""" rail, rail_map = make_simple_rail() @@ -179,6 +164,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): direction=Grid4TransitionsEnum.SOUTH, action=RailEnvActions.STOP_MOVING ), + # Replay( position=(4, 6), direction=Grid4TransitionsEnum.SOUTH, @@ -438,13 +424,13 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True): _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') - - if replay.action: - assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True) + assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format( + step, a, True) action_dict[a] = replay.action else: - assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False) + assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format( + step, a, False) _, _, _, info_dict = env.step(action_dict) if rendering: @@ -493,7 +479,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, - malfunction=2 # recovers in two steps from now! + malfunction=2 # recovers in two steps from now! ), # agent recovers in this step Replay( @@ -515,7 +501,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, - malfunction=2 # recovers in two steps from now! + malfunction=2 # recovers in two steps from now! ), # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken! Replay( diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..4bd84e76 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,21 @@ +"""Test Utils.""" +from typing import List + +from attr import attrs, attrib + +from flatland.envs.rail_env import RailEnvActions + + +@attrs +class Replay(object): + position = attrib() + direction = attrib() + action = attrib(type=RailEnvActions) + malfunction = attrib(default=0, type=int) + + +@attrs +class TestConfig(object): + replay = attrib(type=List[Replay]) + target = attrib() + speed = attrib(type=float) -- GitLab