From 0a7ebfd07b92710c355f105fffea4d9da12108cb Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 18 Sep 2019 17:51:33 +0200 Subject: [PATCH] #178 bugfix step function intial malfunction --- flatland/envs/agent_utils.py | 14 +- flatland/envs/rail_env.py | 361 ++++++++++++++++------------- tests/test_flatland_malfunction.py | 230 +++++++++++++++++- tests/test_multi_speed.py | 5 +- 4 files changed, 440 insertions(+), 170 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index b228e10b..f659ec84 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,8 +1,11 @@ from itertools import starmap +from typing import Tuple import numpy as np from attr import attrs, attrib, Factory +from flatland.core.grid.grid4 import Grid4TransitionsEnum + @attrs class EnvAgentStatic(object): @@ -11,10 +14,10 @@ class EnvAgentStatic(object): rather than where it is at the moment. The target should also be stored here. """ - position = attrib() - direction = attrib() - target = attrib() - moving = attrib(default=False) + position = attrib(type=Tuple[int, int]) + direction = attrib(type=Grid4TransitionsEnum) + target = attrib(type=Tuple[int, int]) + moving = attrib(default=False, type=bool) # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous @@ -27,7 +30,8 @@ class EnvAgentStatic(object): # number of time the agent had to stop, since the last time it broke down malfunction_data = attrib( default=Factory( - lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}))) + lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, + 'moving_before_malfunction': False}))) @classmethod def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index df50b813..11df8b20 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,7 +4,7 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum -from typing import List, Set, NamedTuple, Optional +from typing import List, Set, NamedTuple, Optional, Tuple, Dict import msgpack import msgpack_numpy as m @@ -122,7 +122,7 @@ class RailEnv(Environment): Environment init. Parameters - ------- + ---------- rail_generator : function The rail_generator function is a function that takes the width, height and agents handles of a rail environment, along with the number of times @@ -147,9 +147,6 @@ class RailEnv(Environment): ObservationBuilder-derived object that takes builds observation vectors for each agent. max_episode_steps : int or None - - file_name: you can load a pickle file. from previously saved *.pkl file - """ super().__init__() @@ -272,18 +269,10 @@ class RailEnv(Environment): agent.malfunction_data['malfunction'] = 0 - initial_malfunction = self._agent_new_malfunction(i_agent) + initial_malfunction = self._agent_malfunction(i_agent) + if initial_malfunction: - valid_actions = set(map(lambda x: x.action, self.get_valid_move_actions(agent))) - if RailEnvActions.MOVE_FORWARD in valid_actions: - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - elif RailEnvActions.MOVE_LEFT in valid_actions: - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_LEFT - elif RailEnvActions.MOVE_RIGHT in valid_actions: - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_RIGHT - else: - raise Exception( - "Agent {} cannot move forward/left/right from initial position".format(agent.handle)) + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING self.num_resets += 1 self._elapsed_steps = 0 @@ -299,7 +288,7 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def _agent_new_malfunction(self, i_agent) -> bool: + def _agent_malfunction(self, i_agent) -> bool: """ Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before). """ @@ -326,12 +315,28 @@ class RailEnv(Environment): num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving return True + else: + # The train was broken before... + if agent.malfunction_data['malfunction'] > 0: + + # Last step of malfunction --> Agent starts moving again after getting fixed + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + + # restore moving state before malfunction without further penalty + self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] + + else: + agent.malfunction_data['malfunction'] -= 1 + + # Nothing left to do with broken agent + return True return False - # TODO refactor to decrease length of this method! - def step(self, action_dict_): + def step(self, action_dict_: Dict[int, RailEnvActions]): self._elapsed_steps += 1 # Reset the step rewards @@ -349,126 +354,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict for i_agent in range(self.get_num_agents()): - - if self.dones[i_agent]: # this agent has already completed... - continue - - agent = self.agents[i_agent] - agent.old_direction = agent.direction - agent.old_position = agent.position - - # Check if agent breaks at this step - new_malfunction = self._agent_new_malfunction(i_agent) - - # Is the agent at the beginning of the cell? Then, it can take an action - # Design choice (Erik+Christian): - # as long as we're broken down at the beginning of the cell, we can choose other actions! - if agent.speed_data['position_fraction'] == 0.0: - # No action has been supplied for this agent -> set DO_NOTHING as default - if i_agent not in action_dict_: - action = RailEnvActions.DO_NOTHING - else: - action = action_dict_[i_agent] - - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING - - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD - - if action == RailEnvActions.STOP_MOVING and agent.moving: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += self.stop_penalty - - if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += self.start_penalty - - # Store the action - if agent.moving: - _action_stored = False - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(action, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = action - _action_stored = True - else: - # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, - # try to keep moving forward! - if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - _action_stored = True - - if not _action_stored: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - - # if we've just broken in this step, nothing else to do - if new_malfunction: - continue - - # The train was broken before... - if agent.malfunction_data['malfunction'] > 0: - - # Last step of malfunction --> Agent starts moving again after getting fixed - if agent.malfunction_data['malfunction'] < 2: - agent.malfunction_data['malfunction'] -= 1 - self.agents[i_agent].moving = True - - else: - agent.malfunction_data['malfunction'] -= 1 - - # Broken agents are stopped - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - self.agents[i_agent].moving = False - - # Nothing left to do with broken agent - continue - - # Now perform a movement. - # If agent.moving, increment the position_fraction by the speed of the agent - # If the new position fraction is >= 1, reset to 0, and perform the stored - # transition_action_on_cellexit if the cell is free. - if agent.moving: - - agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] >= 1.0: - # Perform stored action to transition to the next cell as soon as cell is free - # Notice that we've already check new_cell_valid and transition valid when we stored the action, - # so we only have to check cell_free now! - - # cell and transition validity was checked when we stored transition_action_on_cellexit! - cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( - agent.speed_data['transition_action_on_cellexit'], agent) - - # N.B. validity of new_cell and transition should have been verified before the action was stored! - assert new_cell_valid - assert transition_valid - if cell_free: - agent.position = new_position - agent.direction = new_direction - agent.speed_data['position_fraction'] = 0.0 - - if np.equal(agent.position, agent.target).all(): - self.dones[i_agent] = True - agent.moving = False - else: - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self._step_agent(i_agent, action_dict_) # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): @@ -496,17 +382,144 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict - def _check_action_on_agent(self, action, agent): + def _step_agent(self, i_agent, action_dict_: Dict[int, RailEnvActions]): + """ + Performs a step and step, start and stop penalty on a single agent in the following sub steps: + - malfunction + - action handling if at the beginning of cell + - movement + Parameters + ---------- + i_agent : int + action_dict_ : Dict[int,RailEnvActions] + + """ + if self.dones[i_agent]: # this agent has already completed... + return + + agent = self.agents[i_agent] + agent.old_direction = agent.direction + agent.old_position = agent.position + + # is the agent malfunctioning? + malfunction = self._agent_malfunction(i_agent) + + # if agent is broken, actions are ignored and agent does not move, + # the agent is not penalized in this step! + if malfunction: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + + # Is the agent at the beginning of the cell? Then, it can take an action. + if agent.speed_data['position_fraction'] == 0.0: + # No action has been supplied for this agent -> set DO_NOTHING as default + if i_agent not in action_dict_: + action = RailEnvActions.DO_NOTHING + else: + action = action_dict_[i_agent] + + if action < 0 or action > len(RailEnvActions): + print('ERROR: illegal action=', action, + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action = RailEnvActions.DO_NOTHING + + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty + + if not agent.moving and not ( + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty + + # Store the action if action is moving + # If not moving, the action will be stored when the agent starts moving again. + if agent.moving: + _action_stored = False + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(action, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = action + _action_stored = True + else: + # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, + # try to keep moving forward! + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + _action_stored = True + + if not _action_stored: + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.stop_penalty + agent.moving = False + + # Now perform a movement. + # If agent.moving, increment the position_fraction by the speed of the agent + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit if the cell is free. + if agent.moving: + agent.speed_data['position_fraction'] += agent.speed_data['speed'] + if agent.speed_data['position_fraction'] >= 1.0: + # Perform stored action to transition to the next cell as soon as cell is free + # Notice that we've already checked new_cell_valid and transition valid when we stored the action, + # so we only have to check cell_free now! + + # cell and transition validity was checked when we stored transition_action_on_cellexit! + cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( + agent.speed_data['transition_action_on_cellexit'], agent) + + # N.B. validity of new_cell and transition should have been verified before the action was stored! + assert new_cell_valid + assert transition_valid + if cell_free: + agent.position = new_position + agent.direction = new_direction + agent.speed_data['position_fraction'] = 0.0 + + # has the agent reached its target? + if np.equal(agent.position, agent.target).all(): + self.dones[i_agent] = True + agent.moving = False + else: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): + """ + + Parameters + ---------- + action : RailEnvActions + agent : EnvAgent + + Returns + ------- + bool + Is it a legal move? + 1) transition allows the new_direction in the cell, + 2) the new cell is not empty (case 0), + 3) the cell is free, i.e., no agent is currently in that cell + + """ # compute number of possible transitions in the current # cell used to check for invalid actions new_direction, transition_valid = self.check_action(agent, action) new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? - # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell new_cell_valid = ( np.array_equal( # Check the new position is still in the grid new_position, @@ -522,11 +535,24 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - cell_free = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def check_action(self, agent, action): + def check_action(self, agent: EnvAgent, action: RailEnvActions): + """ + + Parameters + ---------- + agent : EnvAgent + action : RailEnvActions + + Returns + ------- + Tuple[Grid4TransitionsEnum,Tuple[int,int]] + + + + """ transition_valid = None possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) num_transitions = np.count_nonzero(possible_transitions) @@ -544,26 +570,41 @@ class RailEnv(Environment): new_direction %= 4 - if action == RailEnvActions.MOVE_FORWARD: - if num_transitions == 1: - # - dead-end, straight line or curved line; - # new_direction will be the only valid transition - # - take only available transition - new_direction = np.argmax(possible_transitions) - transition_valid = True + if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + # - take only available transition + new_direction = np.argmax(possible_transitions) + transition_valid = True return new_direction, transition_valid - def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]: + @staticmethod + def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, + agent_position: Tuple[int, int], + rail: GridTransitionMap) -> Set[RailEnvNextAction]: + """ + Get the valid move actions (forward, left, right) for an agent. + + Parameters + ---------- + agent_direction : Grid4TransitionsEnum + agent_position: Tuple[int,int] + rail : GridTransitionMap + + + Returns + ------- + Set of `RailEnvNextAction` (tuples of (action,position,direction)) + Possible move actions (forward,left,right) and the next position/direction they lead to. + It is not checked that the next cell is free. + """ valid_actions: Set[RailEnvNextAction] = OrderedSet() - agent_position = agent.position - agent_direction = agent.direction - possible_transitions = self.rail.get_transitions(*agent_position, agent_direction) + possible_transitions = rail.get_transitions(*agent_position, agent_direction) num_transitions = np.count_nonzero(possible_transitions) - # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right], relative to the current orientation # If only one transition is possible, the forward branch is aligned with it. - if self.rail.is_dead_end(agent_position): + if rail.is_dead_end(agent_position): action = RailEnvActions.MOVE_FORWARD exit_direction = (agent_direction + 2) % 4 if possible_transitions[exit_direction]: diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 70be7a0e..fde9df58 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -3,6 +3,7 @@ import random import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -46,7 +47,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): min_distances = [] for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: if possible_transitions[direction]: - new_position = self._new_position(agent.position, direction) + new_position = get_new_position(agent.position, direction) min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) else: min_distances.append(np.inf) @@ -150,8 +151,7 @@ def test_malfunction_process_statistically(): env.step(action_dict) # check that generation of malfunctions works as expected - # results are different in py36 and py37, therefore no exact test on nb_malfunction - assert nb_malfunction == 149, "nb_malfunction={}".format(nb_malfunction) + assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction) def test_initial_malfunction(rendering=True): @@ -207,6 +207,8 @@ def test_initial_malfunction(rendering=True): action=RailEnvActions.MOVE_FORWARD, malfunction=2 ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell Replay( position=(28, 5), direction=Grid4TransitionsEnum.EAST, @@ -252,3 +254,225 @@ def test_initial_malfunction(rendering=True): if rendering: renderer.render_env(show=True, show_observations=True) + + +def test_initial_malfunction_stop_moving(rendering=True): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=3 + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=2 + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action DO_NOTHING, agent should restart without moving + # + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.STOP_MOVING, + malfunction=1 + ), + # we have stopped and do nothing --> should stand still + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=0 + ), + # we start to move forward --> should go to next cell now + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay in enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action is not None: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) + + +def test_initial_malfunction_do_nothing(rendering=True): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 70, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + + if rendering: + renderer = RenderTool(env) + renderer.render_env(show=True, frames=False, show_observations=False) + _action = dict() + + replay_steps = [ + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=3 + ), + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=2 + ), + # malfunction stops in the next step and we're still at the beginning of the cell + # --> if we take action DO_NOTHING, agent should restart without moving + # + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=1 + ), + # we haven't started moving yet --> stay here + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.DO_NOTHING, + malfunction=0 + ), + # we start to move forward --> should go to next cell now + Replay( + position=(28, 5), + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ), + Replay( + position=(28, 4), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=0 + ) + ] + + info_dict = { + 'action_required': [True] + } + + for i, replay in enumerate(replay_steps): + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + _assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') + + if replay.action is not None: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 4b14a5cd..1cf0c325 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -580,8 +580,9 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True): _assert(agent.position, replay.position, 'position') _assert(agent.direction, replay.direction, 'direction') - if replay.malfunction: - agent.malfunction_data['malfunction'] = 2 + if replay.malfunction > 0: + agent.malfunction_data['malfunction'] = replay.malfunction + agent.malfunction_data['moving_before_malfunction'] = agent.moving if replay.action is not None: assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) -- GitLab