Commit 3729c9d2 authored by Erik Nygren

Merge branch '92-reward-function-test' into 'master'

Resolve "Reward function test"

Closes #92

See merge request flatland/flatland!99
parents 5721f868 2ed3eb6a
+from typing import Tuple
 import numpy as np
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
-def make_simple_rail():
+def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
     # We instantiate a very simple rail network on a 7x10 grid:
     # |
     # |
......
@@ -53,4 +53,4 @@ def test_walker():
     print(obs_builder.distance_map[(0, *[0, 1], 1)])
     assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
     print(obs_builder.distance_map[(0, *[0, 2], 3)])
-    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2  # does not work yet, Erik's proposal.
+    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2
@@ -3,9 +3,13 @@
 import numpy as np
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
-from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.utils.rendertools import RenderTool
 from tests.simple_rail import make_simple_rail
 """Tests for `flatland` package."""
@@ -35,3 +39,255 @@ def test_global_obs():
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
     assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
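The new tests added below build their environment from make_simple_rail() with a TreeObsForRailEnv observation builder. The following is a minimal standalone sketch of that setup, assuming the flatland API used in this merge request; the final print only illustrates how the distance map consulted by test_walker above and by _step_along_shortest_path below is keyed.

# Illustrative sketch, not part of the diff: environment setup as used by the
# reward-function tests below (flatland API at the time of this merge request).
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail

rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
              height=rail_map.shape[0],
              rail_generator=rail_from_GridTransitionMap_generator(rail),
              number_of_agents=1,
              obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                   predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# The distance map is keyed by (agent handle, row, column, heading 0..3).
agent = env.agents[0]
print(env.obs_builder.distance_map[(agent.handle, *agent.position, agent.direction)])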
def _step_along_shortest_path(env, obs_builder, rail):
    actions = {}
    expected_next_position = {}
    for agent in env.agents:
        agent: EnvAgent
        shortest_distance = np.inf

        for exit_direction in range(4):
            neighbour = obs_builder._new_position(agent.position, exit_direction)
            if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
                desired_movement_from_new_cell = (exit_direction + 2) % 4

                # Check all possible transitions in the neighbour cell: is there any orientation in which
                # the neighbour connects back to the agent's cell, i.e. is movement along
                # `desired_movement_from_new_cell` allowed?
                for agent_orientation in range(4):
                    is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
                                                                   desired_movement_from_new_cell)
                    if is_valid:
                        distance_to_target = obs_builder.distance_map[
                            (agent.handle, *agent.position, exit_direction)]
                        print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
                                                                                      agent.direction,
                                                                                      exit_direction,
                                                                                      distance_to_target))

                        if distance_to_target < shortest_distance:
                            shortest_distance = distance_to_target
                            actions_to_be_taken_when_facing_north = {
                                Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD,
                                Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT,
                                Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT,
                                Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING,
                            }
                            print(" improved (direction) -> {}".format(exit_direction))
                            actions[agent.handle] = actions_to_be_taken_when_facing_north[
                                (exit_direction - agent.direction) % len(rail.transitions.get_direction_enum())]
                            expected_next_position[agent.handle] = neighbour
                            print(" improved (action) -> {}".format(actions[agent.handle]))

    _, rewards, dones, _ = env.step(actions)
    return rewards
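The action lookup at the end of the helper turns an absolute exit direction into an action relative to the agent's heading: (exit_direction - agent.direction) % 4 is 0 when the chosen exit is straight ahead, 1 when it lies to the right and 3 when it lies to the left, which is exactly how the actions_to_be_taken_when_facing_north table is indexed. Below is a small self-contained sketch of that arithmetic; the direction values follow the Grid4 convention NORTH=0, EAST=1, SOUTH=2, WEST=3, and the action names are plain strings rather than RailEnvActions.

# Sketch of the relative-direction arithmetic used in _step_along_shortest_path.
MOVE_FORWARD, MOVE_RIGHT, MOVE_LEFT, DO_NOTHING = "forward", "right", "left", "do_nothing"
relative_to_action = {0: MOVE_FORWARD, 1: MOVE_RIGHT, 2: DO_NOTHING, 3: MOVE_LEFT}

agent_direction = 1  # agent currently heading EAST
exit_direction = 2   # shortest path leaves the cell towards SOUTH
relative = (exit_direction - agent_direction) % 4
assert relative == 1                               # SOUTH lies to the right of EAST
assert relative_to_action[relative] == MOVE_RIGHT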
def test_reward_function_conflict(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    obs_builder: TreeObsForRailEnv = env.obs_builder
    # initialize agents_static
    env.reset()

    # set the initial position
    agent = env.agents_static[0]
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.target = (3, 9)  # east dead-end
    agent.moving = True

    agent = env.agents_static[1]
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.target = (6, 6)  # south dead-end
    agent.moving = True

    # reset to set agents from agents_static
    env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.renderEnv(show=True, show_observations=True)

    iteration = 0
    expected_positions = {
        0: {
            0: (5, 6),
            1: (3, 8)
        },
        # both can move
        1: {
            0: (4, 6),
            1: (3, 7)
        },
        # first can move, second stuck
        2: {
            0: (3, 6),
            1: (3, 7)
        },
        # both stuck from now on
        3: {
            0: (3, 6),
            1: (3, 7)
        },
        4: {
            0: (3, 6),
            1: (3, 7)
        },
        5: {
            0: (3, 6),
            1: (3, 7)
        },
    }
    while iteration < 5:
        rewards = _step_along_shortest_path(env, obs_builder, rail)
        for agent in env.agents:
            assert rewards[agent.handle] == -1
            expected_position = expected_positions[iteration + 1][agent.handle]
            assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                                                                  agent.handle,
                                                                                                  agent.position,
                                                                                                  expected_position)
        if rendering:
            renderer.renderEnv(show=True, show_observations=True)
        iteration += 1
def test_reward_function_waiting(rendering=False):
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    obs_builder: TreeObsForRailEnv = env.obs_builder
    # initialize agents_static
    env.reset()

    # set the initial position
    agent = env.agents_static[0]
    agent.position = (3, 8)  # east dead-end
    agent.direction = 3  # west
    agent.target = (3, 1)  # west dead-end
    agent.moving = True

    agent = env.agents_static[1]
    agent.position = (5, 6)  # south dead-end
    agent.direction = 0  # north
    agent.target = (3, 8)  # east dead-end
    agent.moving = True

    # reset to set agents from agents_static
    env.reset(False, False)

    if rendering:
        renderer = RenderTool(env, gl="PILSVG")
        renderer.renderEnv(show=True, show_observations=True)

    iteration = 0
    expectations = {
        0: {
            'positions': {
                0: (3, 8),
                1: (5, 6),
            },
            'rewards': [-1, -1],
        },
        1: {
            'positions': {
                0: (3, 7),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # second agent has to wait for first, first can continue
        2: {
            'positions': {
                0: (3, 6),
                1: (4, 6),
            },
            'rewards': [-1, -1],
        },
        # both can move again
        3: {
            'positions': {
                0: (3, 5),
                1: (3, 6),
            },
            'rewards': [-1, -1],
        },
        4: {
            'positions': {
                0: (3, 4),
                1: (3, 7),
            },
            'rewards': [-1, -1],
        },
        # second reached target
        5: {
            'positions': {
                0: (3, 3),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        6: {
            'positions': {
                0: (3, 2),
                1: (3, 8),
            },
            'rewards': [-1, 0],
        },
        # first reaches its target, too
        7: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
        8: {
            'positions': {
                0: (3, 1),
                1: (3, 8),
            },
            'rewards': [1, 1],
        },
    }
    while iteration < 7:
        rewards = _step_along_shortest_path(env, obs_builder, rail)
        if rendering:
            renderer.renderEnv(show=True, show_observations=True)

        print(env.dones["__all__"])
        for agent in env.agents:
            agent: EnvAgent
            print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
        print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))

        for agent in env.agents:
            expected_position = expectations[iteration + 1]['positions'][agent.handle]
            assert agent.position == expected_position, \
                "[{}] agent {} at {}, expected {}".format(iteration + 1,
                                                          agent.handle,
                                                          agent.position,
                                                          expected_position)
            expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
            actual_reward = rewards[agent.handle]
            assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
                                                                                                   agent.handle,
                                                                                                   actual_reward,
                                                                                                   expected_reward)
        iteration += 1
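The reward expectations in both tests follow a simple pattern: an agent that has not yet reached its target collects a step penalty of -1, an agent sitting on its target collects 0, and once every agent is done, each agent additionally receives a global reward of +1. The snippet below is a standalone sketch of that rule as encoded by the test data above; it illustrates the expectations, not the RailEnv implementation itself.

# Sketch of the per-step reward pattern the expectations above encode.
step_penalty = -1   # charged to every agent still on its way to its target
global_reward = 1   # paid to every agent once all agents have arrived

def sketched_rewards(at_target):
    # at_target: one boolean per agent, evaluated after the step
    all_done = all(at_target)
    return [(0 if done else step_penalty) + (global_reward if all_done else 0)
            for done in at_target]

assert sketched_rewards([False, False]) == [-1, -1]  # iterations 1-4 of the waiting test
assert sketched_rewards([False, True]) == [-1, 0]    # iterations 5-6: second agent has arrived
assert sketched_rewards([True, True]) == [1, 1]      # iterations 7-8: everyone has arrived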