Commit 0a7ebfd0 authored by u214892's avatar u214892
Browse files

#178 bugfix step function initial malfunction

parent acb3717c
Pipeline #2060 passed with stages
in 33 minutes and 23 seconds
from itertools import starmap
from typing import Tuple
import numpy as np
from attr import attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
@attrs
class EnvAgentStatic(object):
......@@ -11,10 +14,10 @@ class EnvAgentStatic(object):
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib()
direction = attrib()
target = attrib()
moving = attrib(default=False)
position = attrib(type=Tuple[int, int])
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
......@@ -27,7 +30,8 @@ class EnvAgentStatic(object):
# number of times the agent had to stop, since the last time it broke down
malfunction_data = attrib(
default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False})))
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
......
This diff is collapsed.
......@@ -3,6 +3,7 @@ import random
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
......@@ -46,7 +47,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
new_position = get_new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
......@@ -150,8 +151,7 @@ def test_malfunction_process_statistically():
env.step(action_dict)
# check that generation of malfunctions works as expected
# results are different in py36 and py37, therefore no exact test on nb_malfunction
assert nb_malfunction == 149, "nb_malfunction={}".format(nb_malfunction)
assert nb_malfunction == 156, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction(rendering=True):
......@@ -207,6 +207,8 @@ def test_initial_malfunction(rendering=True):
action=RailEnvActions.MOVE_FORWARD,
malfunction=2
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
......@@ -252,3 +254,225 @@ def test_initial_malfunction(rendering=True):
if rendering:
renderer.render_env(show=True, show_observations=True)
def test_initial_malfunction_stop_moving(rendering=True):
    """Replay an initial malfunction during which the agent is told to STOP_MOVING.

    A single-agent ``RailEnv`` is built from fixed seeds, then a fixed list of
    ``Replay`` steps is fed to ``env.step``. Before every step the agent's
    position, direction and ``malfunction_data['malfunction']`` counter are
    compared with the values recorded in the replay.

    :param rendering: when true, draw the environment with ``RenderTool``
        before the replay starts and after every step.
    """
    # Fixed seeds so that the generated malfunction and the rail layout are reproducible.
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurrence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )

    if rendering:
        renderer = RenderTool(env)
        renderer.render_env(show=True, frames=False, show_observations=False)

    # Each Replay holds the state expected *before* the step and the action taken in that step.
    replay_steps = [
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=3
        ),
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action STOP_MOVING, agent should restart without moving
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.STOP_MOVING,
            malfunction=1
        ),
        # we have stopped and do nothing --> should stand still
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0
        ),
        # we start to move forward --> should go to next cell now
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0
        ),
        Replay(
            position=(28, 4),
            direction=Grid4TransitionsEnum.WEST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0
        )
    ]

    # Before the first step there is no info from the env yet; an action is expected.
    info_dict = {
        'action_required': [True]
    }

    def _assert(step, actual, expected, msg):
        # Prefix the failure message with the replay step index to locate the mismatch.
        assert actual == expected, "[{}] {}: actual={}, expected={}".format(step, msg, actual, expected)

    for i, replay in enumerate(replay_steps):
        agent: EnvAgent = env.agents[0]

        _assert(i, agent.position, replay.position, 'position')
        _assert(i, agent.direction, replay.direction, 'direction')
        _assert(i, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')

        if replay.action is not None:
            assert info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, True)
            _, _, _, info_dict = env.step({0: replay.action})
        else:
            assert not info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, False)
            _, _, _, info_dict = env.step({})

        if rendering:
            renderer.render_env(show=True, show_observations=True)
def test_initial_malfunction_do_nothing(rendering=True):
    """Replay an initial malfunction during which the agent only receives DO_NOTHING.

    A single-agent ``RailEnv`` is built from fixed seeds, then a fixed list of
    ``Replay`` steps is fed to ``env.step``. Before every step the agent's
    position, direction and ``malfunction_data['malfunction']`` counter are
    compared with the values recorded in the replay.

    :param rendering: when true, draw the environment with ``RenderTool``
        before the replay starts and after every step.
    """
    # Fixed seeds so that the generated malfunction and the rail layout are reproducible.
    random.seed(0)
    np.random.seed(0)

    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
                       'malfunction_rate': 70,  # Rate of malfunction occurrence
                       'min_duration': 2,  # Minimal duration of malfunction
                       'max_duration': 5  # Max duration of malfunction
                       }

    speed_ration_map = {1.: 1.,  # Fast passenger train
                        1. / 2.: 0.,  # Fast freight train
                        1. / 3.: 0.,  # Slow commuter train
                        1. / 4.: 0.}  # Slow freight train

    env = RailEnv(width=25,
                  height=30,
                  rail_generator=sparse_rail_generator(num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       num_intersections=4,
                                                       # Number of intersections (no start / target)
                                                       num_trainstations=25,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities/intersections
                                                       seed=215545,  # Random seed
                                                       grid_mode=True,
                                                       enhance_intersection=False
                                                       ),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=1,
                  stochastic_data=stochastic_data,  # Malfunction data generator
                  )

    if rendering:
        renderer = RenderTool(env)
        renderer.render_env(show=True, frames=False, show_observations=False)

    # Each Replay holds the state expected *before* the step and the action taken in that step.
    replay_steps = [
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=3
        ),
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=2
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=1
        ),
        # we haven't started moving yet --> stay here
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.DO_NOTHING,
            malfunction=0
        ),
        # we start to move forward --> should go to next cell now
        Replay(
            position=(28, 5),
            direction=Grid4TransitionsEnum.EAST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0
        ),
        Replay(
            position=(28, 4),
            direction=Grid4TransitionsEnum.WEST,
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0
        )
    ]

    # Before the first step there is no info from the env yet; an action is expected.
    info_dict = {
        'action_required': [True]
    }

    def _assert(step, actual, expected, msg):
        # Prefix the failure message with the replay step index to locate the mismatch.
        assert actual == expected, "[{}] {}: actual={}, expected={}".format(step, msg, actual, expected)

    for i, replay in enumerate(replay_steps):
        agent: EnvAgent = env.agents[0]

        _assert(i, agent.position, replay.position, 'position')
        _assert(i, agent.direction, replay.direction, 'direction')
        _assert(i, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')

        if replay.action is not None:
            assert info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, True)
            _, _, _, info_dict = env.step({0: replay.action})
        else:
            assert not info_dict['action_required'][0], "[{}] expecting action_required={}".format(i, False)
            _, _, _, info_dict = env.step({})

        if rendering:
            renderer.render_env(show=True, show_observations=True)
......@@ -580,8 +580,9 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
_assert(agent.position, replay.position, 'position')
_assert(agent.direction, replay.direction, 'direction')
if replay.malfunction:
agent.malfunction_data['malfunction'] = 2
if replay.malfunction > 0:
agent.malfunction_data['malfunction'] = replay.malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
if replay.action is not None:
assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment