Skip to content
Snippets Groups Projects
Commit f5022411 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '25-predictor-first-steps' into 'master'

Resolve "Predictor/planner API"

See merge request flatland/flatland!39
parents dfb25836 b95eb5f9
No related branches found
No related tags found
No related merge requests found
...@@ -84,6 +84,21 @@ class Environment: ...@@ -84,6 +84,21 @@ class Environment:
""" """
raise NotImplementedError() raise NotImplementedError()
def predict(self):
    """
    Compute predictions for the agents.

    Abstract hook: this base implementation always raises, so concrete
    environments must override it.

    Returns
    -------
    predictions : dict
        New predictions for each ready agent, as a dict mapping from
        agent_id strings to values.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError
def render(self): def render(self):
""" """
Perform rendering of the environment. Perform rendering of the environment.
......
...@@ -19,7 +19,6 @@ class ObservationBuilder: ...@@ -19,7 +19,6 @@ class ObservationBuilder:
def __init__(self): def __init__(self):
self.observation_space = () self.observation_space = ()
pass
def _set_env(self, env): def _set_env(self, env):
self.env = env self.env = env
......
"""
PredictionBuilder objects are objects that can be passed to environments designed for customizability.
The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
Predictions may not be required in every step or for all agents; therefore:
+ reset() is called after each environment reset, to allow for pre-computing relevant data.
+ get() is called whenever a prediction step has to be computed, potentially for each agent independently in
case of multi-agent environments.
"""
class PredictionBuilder:
    """
    PredictionBuilder base class.

    Holds the prediction horizon (`max_depth`) and a reference to the
    environment it is attached to. Subclasses must override get(); reset()
    is an optional hook.
    """

    def __init__(self, max_depth: int = 20):
        """
        Parameters
        ----------
        max_depth : int (optional)
            Value stored on the instance as `max_depth`.
        """
        self.max_depth = max_depth
        # Populated by _set_env(); initialised here so the attribute always
        # exists, even before the builder is attached to an environment.
        self.env = None

    def _set_env(self, env):
        """Attach the environment this builder computes predictions for."""
        self.env = env

    def reset(self):
        """
        Called after each environment reset.

        The base implementation does nothing.
        """

    def get(self, handle=0):
        """
        Compute a prediction; must be overridden by subclasses.

        Parameters
        ----------
        handle : int (optional)
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        A prediction structure, specific to the corresponding environment.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError()
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment (cell
    occupancy is ignored) and prefers the forward action, falling back to
    left and then right when forward is not possible.
    """

    def get(self, handle=None):
        """
        Roll each selected agent forward for `max_depth` steps.

        Parameters
        ----------
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            If None (the default), predictions are computed for all agents.

        Returns
        -------
        dict
            Dictionary indexed by agent handle. Each value is an array of
            shape (max_depth, 5) whose rows hold:
                - time_offset
                - position axis 0
                - position axis 1
                - direction
                - action taken to come here
            Rows for steps at which no action was possible are left at zero.
        """
        # A handle of 0 is a valid agent index, so compare against None
        # instead of relying on truthiness.
        if handle is None:
            agents = self.env.agents
        else:
            agents = [self.env.agents[handle]]

        # Actions are tried in this order until one is feasible:
        # 0: do nothing
        # 1: turn left and move to the next cell
        # 2: move to the next cell in front of the agent
        # 3: turn right and move to the next cell
        action_priorities = [2, 1, 3]

        prediction_dict = {}
        for agent in agents:
            initial_position = agent.position
            initial_direction = agent.direction

            prediction = np.zeros(shape=(self.max_depth, 5))
            prediction[0] = [0, initial_position[0], initial_position[1], initial_direction, 0]

            # The rollout moves the real agent object; the finally clause
            # guarantees it is put back even if the environment raises.
            try:
                for index in range(1, self.max_depth):
                    for action in action_priorities:
                        # cell_isFree (first item) is deliberately ignored:
                        # other agents are not taken into account.
                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the
                            # new_direction that was performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
                            break
                    else:
                        # None of the candidate actions was feasible.
                        print("Cannot move further.")
            finally:
                agent.position = initial_position
                agent.direction = initial_direction

            prediction_dict[agent.handle] = prediction

        return prediction_dict
...@@ -51,7 +51,9 @@ class RailEnv(Environment): ...@@ -51,7 +51,9 @@ class RailEnv(Environment):
height, height,
rail_generator=random_rail_generator(), rail_generator=random_rail_generator(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)): obs_builder_object=TreeObsForRailEnv(max_depth=2),
prediction_builder_object=None
):
""" """
Environment init. Environment init.
...@@ -94,6 +96,11 @@ class RailEnv(Environment): ...@@ -94,6 +96,11 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object self.obs_builder = obs_builder_object
self.obs_builder._set_env(self) self.obs_builder._set_env(self)
self.prediction_builder = prediction_builder_object
if self.prediction_builder:
self.prediction_builder._set_env(self)
self.action_space = [1] self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets? self.observation_space = self.obs_builder.observation_space # updated on resets?
...@@ -212,51 +219,9 @@ class RailEnv(Environment): ...@@ -212,51 +219,9 @@ class RailEnv(Environment):
return return
if action > 0: if action > 0:
# pos = agent.position # self.agents_position[i] cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action,
# direction = agent.direction # self.agents_direction[i] agent,
transition_isValid)
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
if all([new_cell_isValid, transition_isValid, cell_isFree]): if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the new_direction that was # move and change direction to face the new_direction that was
...@@ -296,6 +261,52 @@ class RailEnv(Environment): ...@@ -296,6 +261,52 @@ class RailEnv(Environment):
self.actions = [0] * self.get_num_agents() self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent):
    """
    Evaluate what would happen if `agent` performed `action`, without moving it.

    A move is legal when:
      1) the transition allows the new direction in the current cell,
      2) the new cell is inside the grid and not empty,
      3) the new cell is free, i.e. no agent is currently in that cell.
    This method only computes the individual flags; combining them is left
    to the caller.

    Returns
    -------
    tuple
        (cell_isFree, new_cell_isValid, new_direction, new_position,
        transition_isValid)
    """
    next_direction, transition_ok = self.check_action(agent, action)
    next_position = get_new_position(agent.position, next_direction)

    # Clipping to the grid bounds leaves the position unchanged only if it
    # is already inside the grid.
    lowest_cell = [0, 0]
    highest_cell = [self.height - 1, self.width - 1]
    inside_grid = np.array_equal(
        next_position,
        np.clip(next_position, lowest_cell, highest_cell))
    # The target cell must also have some transitions (i.e. not be empty);
    # `and` short-circuits, so the rail is only queried for in-grid positions.
    target_is_valid = inside_grid and self.rail.get_transitions(next_position) > 0

    # check_action may leave the transition validity undetermined (None).
    if transition_ok is None:
        current_state = (*agent.position, agent.direction)
        transition_ok = self.rail.get_transition(current_state, next_direction)

    # Compare against every agent's position (including the moving agent
    # itself, for simplicity, since it is moving).
    occupied_positions = [other.position for other in self.agents]
    target_is_free = not np.any(np.equal(next_position, occupied_positions).all(1))

    return target_is_free, target_is_valid, next_direction, next_position, transition_ok
def predict(self):
    """
    Return the predictions computed by the configured prediction builder.

    Returns
    -------
    dict
        An empty dict when no prediction builder is set; otherwise
        whatever the builder's get() returns.
    """
    # Explicit None check: a builder object that happens to evaluate as
    # falsy (e.g. one defining __len__ or __bool__) must still be used.
    if self.prediction_builder is None:
        return {}
    return self.prediction_builder.get()
def check_action(self, agent, action): def check_action(self, agent, action):
transition_isValid = None transition_isValid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
...@@ -332,6 +343,11 @@ class RailEnv(Environment): ...@@ -332,6 +343,11 @@ class RailEnv(Environment):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent) self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict return self.obs_dict
def _get_predictions(self):
if not self.prediction_builder:
return {}
return {}
def render(self): def render(self):
# TODO: # TODO:
pass pass
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
"""Test predictions for `flatland` package."""
def test_predictions():
    """Compare DummyPredictorForRailEnv output with a hand-computed 20-step trajectory."""
    # A very simple rail network on a 7x10 grid:
    #        |
    #        |
    #        |
    # _ _ _ /_\ _ _  _  _ _ _
    #               \ /
    #                |
    #                |
    #                |
    cells = [int('0000000000000000', 2),  # empty cell - Case 0
             int('1000000000100000', 2),  # Case 1 - straight
             int('1001001000100000', 2),  # Case 2 - simple switch
             int('1000010000100001', 2),  # Case 3 - diamond crossing
             int('1001011000100001', 2),  # Case 4 - single slip switch
             int('1100110000110011', 2),  # Case 5 - double slip switch
             int('0101001000000010', 2),  # Case 6 - symmetrical switch
             int('0010000000000000', 2)]  # Case 7 - dead end

    transitions = Grid4Transitions([])
    rotate = transitions.rotate_transition

    empty = cells[0]
    vertical_straight = cells[1]
    horizontal_straight = rotate(vertical_straight, 90)

    dead_end_from_south = cells[7]
    dead_end_from_west = rotate(dead_end_from_south, 90)
    dead_end_from_north = rotate(dead_end_from_south, 180)
    dead_end_from_east = rotate(dead_end_from_south, 270)

    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = rotate(double_switch_south_horizontal_straight, 180)

    # Assemble the grid row by row, top to bottom.
    top_dead_end_row = [empty] * 3 + [dead_end_from_south] + [empty] * 6
    upper_vertical_row = [empty] * 3 + [vertical_straight] + [empty] * 6
    horizontal_row = ([dead_end_from_east] + [horizontal_straight] * 2 +
                      [double_switch_north_horizontal_straight] +
                      [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
                      [horizontal_straight] * 2 + [dead_end_from_west])
    lower_vertical_row = [empty] * 6 + [vertical_straight] + [empty] * 3
    bottom_dead_end_row = [empty] * 6 + [dead_end_from_north] + [empty] * 3

    rail_map = np.array(
        [top_dead_end_row] +
        [upper_vertical_row] * 2 +
        [horizontal_row] +
        [lower_vertical_row] * 2 +
        [bottom_dead_end_row], dtype=np.uint16)
    grid_height, grid_width = rail_map.shape

    rail = GridTransitionMap(width=grid_width, height=grid_height, transitions=transitions)
    rail.grid = rail_map

    env = RailEnv(width=grid_width,
                  height=grid_height,
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv(),
                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
                  )
    env.reset()

    # set initial position and direction for testing...
    env.agents[0].position = (5, 6)
    env.agents[0].direction = 0

    predictions = env.predict()

    # Each row is [time_offset, position axis 0, position axis 1, direction, action];
    # width-1 slices keep the single-value fields as column vectors.
    trajectory = predictions[0]
    time_offsets = trajectory[:, 0:1]
    positions = trajectory[:, 1:3]
    directions = trajectory[:, 3:4]
    actions = trajectory[:, 4:5]

    # compare against expected values
    expected_positions = np.array([
        [5., 6.], [4., 6.], [3., 6.], [3., 5.], [3., 4.],
        [3., 3.], [3., 2.], [3., 1.], [3., 0.], [3., 1.],
        [3., 2.], [3., 3.], [3., 4.], [3., 5.], [3., 6.],
        [3., 7.], [3., 8.], [3., 9.], [3., 8.], [3., 7.],
    ])
    # direction 0 for 3 steps, 3 for 6 steps, 1 for 9 steps, then 3 for 2 steps
    expected_directions = np.array([0.] * 3 + [3.] * 6 + [1.] * 9 + [3.] * 2).reshape(-1, 1)
    # time offsets 0..19
    expected_time_offsets = np.arange(20, dtype=float).reshape(-1, 1)
    # initial row (0), two forward moves, one left turn, then forward for the remaining 16 steps
    expected_actions = np.array([0.] + [2.] * 2 + [1.] + [2.] * 16).reshape(-1, 1)

    assert np.array_equal(positions, expected_positions)
    assert np.array_equal(directions, expected_directions)
    assert np.array_equal(time_offsets, expected_time_offsets)
    assert np.array_equal(actions, expected_actions)
def main():
    """Run the prediction test directly, without a test runner."""
    test_predictions()


if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment