Commit 29cc5cda authored by u214892's avatar u214892
Browse files

#178 bugfix initial malfunction

parent 58d478b2
......@@ -4,13 +4,14 @@ Definition of the RailEnv environment.
# TODO: _ this is a global method --> utils or remove later
import warnings
from enum import IntEnum
from typing import List
from typing import List, Set, NamedTuple
import msgpack
import msgpack_numpy as m
import numpy as np
from flatland.core.env import Environment
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
......@@ -39,6 +40,11 @@ class RailEnvActions(IntEnum):
}[a]
RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)])
RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos),
('next_direction', Grid4TransitionsEnum)])
class RailEnv(Environment):
"""
RailEnv environment class.
......@@ -262,7 +268,18 @@ class RailEnv(Environment):
agent.malfunction_data['malfunction'] = 0
self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING)
initial_malfunction = self._agent_new_malfunction(i_agent)
if initial_malfunction:
valid_actions = set(map(lambda x: x.action, self.get_valid_move_actions(agent)))
if RailEnvActions.MOVE_FORWARD in valid_actions:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
elif RailEnvActions.MOVE_LEFT in valid_actions:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_LEFT
elif RailEnvActions.MOVE_RIGHT in valid_actions:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_RIGHT
else:
raise Exception(
"Agent {} cannot move forward/left/right from initial position".format(agent.handle))
self.num_resets += 1
self._elapsed_steps = 0
......@@ -277,7 +294,7 @@ class RailEnv(Environment):
# Return the new observation vectors for each agent
return self._get_observations()
def _agent_new_malfunction(self, i_agent, action) -> bool:
def _agent_new_malfunction(self, i_agent) -> bool:
"""
Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
"""
......@@ -335,25 +352,25 @@ class RailEnv(Environment):
agent.old_direction = agent.direction
agent.old_position = agent.position
# No action has been supplied for this agent -> set DO_NOTHING as default
if i_agent not in action_dict_:
action = RailEnvActions.DO_NOTHING
else:
action = action_dict_[i_agent]
if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action,
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action = RailEnvActions.DO_NOTHING
# Check if agent breaks at this step
new_malfunction = self._agent_new_malfunction(i_agent, action)
new_malfunction = self._agent_new_malfunction(i_agent)
# Is the agent at the beginning of the cell? Then, it can take an action
# Design choice (Erik+Christian):
# as long as we're broken down at the beginning of the cell, we can choose other actions!
if agent.speed_data['position_fraction'] == 0.0:
# No action has been supplied for this agent -> set DO_NOTHING as default
if i_agent not in action_dict_:
action = RailEnvActions.DO_NOTHING
else:
action = action_dict_[i_agent]
if action < 0 or action > len(RailEnvActions):
print('ERROR: illegal action=', action,
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action = RailEnvActions.DO_NOTHING
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
action = RailEnvActions.MOVE_FORWARD
......@@ -370,12 +387,14 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.start_penalty
# Store the action
if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]:
if agent.moving:
_action_stored = False
_, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent)
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = action
_action_stored = True
else:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
......@@ -385,19 +404,14 @@ class RailEnv(Environment):
if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
else:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
else:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
_action_stored = True
if not _action_stored:
# If the agent cannot move due to an invalid transition, we set its state to not moving
self.rewards_dict[i_agent] += self.invalid_action_penalty
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += self.stop_penalty
agent.moving = False
# if we've just broken in this step, nothing else to do
if new_malfunction:
......@@ -410,7 +424,6 @@ class RailEnv(Environment):
if agent.malfunction_data['malfunction'] < 2:
agent.malfunction_data['malfunction'] -= 1
self.agents[i_agent].moving = True
action = RailEnvActions.DO_NOTHING
else:
agent.malfunction_data['malfunction'] -= 1
......@@ -438,6 +451,9 @@ class RailEnv(Environment):
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent)
# N.B. validity of new_cell and transition should have been verified before the action was stored!
assert new_cell_valid
assert transition_valid
if cell_free:
agent.position = new_position
agent.direction = new_direction
......@@ -532,6 +548,44 @@ class RailEnv(Environment):
transition_valid = True
return new_direction, transition_valid
def get_valid_move_actions(self, agent: EnvAgent) -> Set[RailEnvNextAction]:
valid_actions: Set[RailEnvNextAction] = set()
agent_position = agent.position
agent_direction = agent.direction
possible_transitions = self.rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if self.rail.is_dead_end(agent_position):
action = RailEnvActions.MOVE_FORWARD
exit_direction = (agent_direction + 2) % 4
if possible_transitions[exit_direction]:
new_position = get_new_position(agent_position, exit_direction)
valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
elif num_transitions == 1:
action = RailEnvActions.MOVE_FORWARD
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
else:
for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[new_direction]:
if new_direction == agent_direction:
action = RailEnvActions.MOVE_FORWARD
elif new_direction == (agent_direction + 1) % 4:
action = RailEnvActions.MOVE_RIGHT
elif new_direction == (agent_direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT
else:
raise Exception("Illegal state")
new_position = get_new_position(agent_position, new_direction)
valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
return valid_actions
def _get_observations(self):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
......
import random
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from test_utils import Replay
class SingleAgentNavigationObs(TreeObsForRailEnv):
......@@ -145,3 +151,102 @@ def test_malfunction_process_statistically():
# check that generation of malfunctions works as expected
# results are different in py36 and py37, therefore no exact test on nb_malfunction
assert nb_malfunction > 150
# TODO test DO_NOTHING!
def test_initial_malfunction(rendering=True):
random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=25, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=215545, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
if rendering:
renderer = RenderTool(env)
renderer.render_env(show=True, frames=False, show_observations=False)
_action = dict()
replay_steps = [
Replay(
position=(27, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=3
),
Replay(
position=(27, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2
),
Replay(
position=(27, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1
),
Replay(
position=(27, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
),
Replay(
position=(27, 3),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0
)
]
info_dict = {
'action_required': [True]
}
for i, replay in enumerate(replay_steps):
def _assert(actual, expected, msg):
assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected)
agent: EnvAgent = env.agents[0]
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
_assert(agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
if replay.action:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
_, _, _, info_dict = env.step({0: replay.action})
else:
assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False)
_, _, _, info_dict = env.step({})
if rendering:
renderer.render_env(show=True, show_observations=True)
from typing import List
import time
import numpy as np
from attr import attrib, attrs
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
......@@ -12,6 +11,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from test_utils import TestConfig, Replay
np.random.seed(1)
......@@ -97,21 +97,6 @@ def test_multi_speed_init():
old_pos[i_agent] = env.agents[i_agent].position
@attrs
class Replay(object):
position = attrib()
direction = attrib()
action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int)
@attrs
class TestConfig(object):
replay = attrib(type=List[Replay])
target = attrib()
speed = attrib(type=float)
def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
"""Test that actions are correctly performed on cell exit for a single agent."""
rail, rail_map = make_simple_rail()
......@@ -179,6 +164,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
direction=Grid4TransitionsEnum.SOUTH,
action=RailEnvActions.STOP_MOVING
),
#
Replay(
position=(4, 6),
direction=Grid4TransitionsEnum.SOUTH,
......@@ -438,13 +424,13 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
if replay.action:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True)
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False)
assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(
step, a, False)
_, _, _, info_dict = env.step(action_dict)
if rendering:
......@@ -493,7 +479,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
malfunction=2 # recovers in two steps from now!
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step
Replay(
......@@ -515,7 +501,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2 # recovers in two steps from now!
malfunction=2 # recovers in two steps from now!
),
# agent recovers in this step; since we're at the beginning, we provide a different action although we're broken!
Replay(
......
"""Test Utils."""
from typing import List
from attr import attrs, attrib
from flatland.envs.rail_env import RailEnvActions
@attrs
class Replay(object):
position = attrib()
direction = attrib()
action = attrib(type=RailEnvActions)
malfunction = attrib(default=0, type=int)
@attrs
class TestConfig(object):
replay = attrib(type=List[Replay])
target = attrib()
speed = attrib(type=float)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment