Skip to content
Snippets Groups Projects
Commit 960361f1 authored by u214892's avatar u214892
Browse files

25 predictor draft

parent 311b9814
No related branches found
No related tags found
No related merge requests found
......@@ -15,8 +15,8 @@ class PredictionBuilder:
PredictionBuilder base class.
"""
def __init__(self):
pass
def __init__(self, max_depth: int = 20):
# Prediction horizon: DummyPredictorForRailEnv.get allocates `max_depth` rows per agent and iterates range(1, max_depth).
self.max_depth = max_depth
def _set_env(self, env):
# Keep a reference to the environment; predictors read self.env.agents and call self.env._check_action_on_agent.
self.env = env
......@@ -25,12 +25,11 @@ class PredictionBuilder:
"""
Called after each environment reset.
"""
raise NotImplementedError()
pass
def get(self, handle=0):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Called whenever step_prediction is called on the environment.
Parameters
-------
......@@ -40,6 +39,6 @@ class PredictionBuilder:
Returns
-------
function
An prediction structure, specific to the corresponding environment.
A prediction structure, specific to the corresponding environment.
"""
raise NotImplementedError()
......@@ -2,6 +2,8 @@
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
......@@ -13,11 +15,58 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and prefers the forward action, falling back to turning left and then turning right when moving forward is not possible.
"""
def __init__(self):
pass
def get(self, handle=None):
    """
    Called whenever step_prediction is called on the environment.

    Each selected agent is moved virtually through the environment, ignoring
    whether the target cell is occupied by another agent. At every step the
    actions are tried in the order forward, left, right and the first one
    that leads to a valid cell through a valid transition is taken. The
    agent's real position and direction are restored afterwards.

    Parameters
    -------
    handle : int (optional)
        Handle of the agent for which to compute the prediction.
        If None, predictions are computed for all agents.

    Returns
    -------
    dict
        Dictionary indexed by the agent handle; for each agent a numpy array
        of shape (max_depth, 5) whose rows hold:
        - time_offset
        - position axis 0
        - position axis 1
        - direction
        - action taken to come here
        Rows for steps at which no action was possible are left as zeros.
    """
    agents = self.env.agents
    # Handle 0 is a valid agent, so test against None instead of truthiness.
    if handle is not None:
        agents = [self.env.agents[handle]]

    # 0: do nothing
    # 1: turn left and move to the next cell
    # 2: move to the next cell in front of the agent
    # 3: turn right and move to the next cell
    action_priorities = [2, 1, 3]

    prediction_dict = {}
    for agent in agents:
        initial_position = agent.position
        initial_direction = agent.direction
        prediction = np.zeros(shape=(self.max_depth, 5))
        prediction[0] = [0, initial_position[0], initial_position[1], initial_direction, 0]
        try:
            for index in range(1, self.max_depth):
                action_done = False
                for action in action_priorities:
                    # The occupancy flag (first element) is ignored on purpose:
                    # the predictor acts as if no other agent exists.
                    _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self.env._check_action_on_agent(action, agent)
                    if all([new_cell_isValid, transition_isValid]):
                        # Move the agent virtually so the next step starts from here.
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
                        action_done = True
                        break
                if not action_done:
                    print("Cannot move further.")
        finally:
            # Always undo the virtual moves, even if the environment raised.
            agent.position = initial_position
            agent.direction = initial_direction
        prediction_dict[agent.handle] = prediction
    return prediction_dict


def reset(self):
    """
    Called after each environment reset; nothing is done here.
    """
    pass
......@@ -219,51 +219,9 @@ class RailEnv(Environment):
return
if action > 0:
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action,
agent,
transition_isValid)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
# move and change direction to face the new_direction that was
......@@ -303,6 +261,46 @@ class RailEnv(Environment):
self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent):
    """
    Evaluate what would happen if `agent` performed `action`, without moving it.

    A move is legal when
    1) the transition allows the new direction in the current cell,
    2) the target cell lies inside the grid and is not empty (has transitions),
    3) the target cell is free, i.e. no agent is currently in that cell.
    This method only computes these flags; combining them is left to the caller.

    NOTE(review): a call site shown elsewhere in this changeset appears to pass
    a third positional argument (transition_isValid), which this signature does
    not accept - confirm which side is intended.

    Returns
    -------
    tuple
        (cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid)
    """
    new_direction, transition_isValid = self.check_action(agent, action)
    target = get_new_position(agent.position, new_direction)

    # The target must stay inside the grid (clipping leaves it unchanged) ...
    lower_bound = [0, 0]
    upper_bound = [self.height - 1, self.width - 1]
    inside_grid = np.array_equal(target, np.clip(target, lower_bound, upper_bound))
    # ... and must carry at least one transition, i.e. not be an empty cell.
    new_cell_isValid = inside_grid and self.rail.get_transitions(target) > 0

    # check_action may leave the transition validity undecided (None);
    # in that case look it up on the rail for the agent's current pose.
    if transition_isValid is None:
        current_pose = (*agent.position, agent.direction)
        transition_isValid = self.rail.get_transition(current_pose, new_direction)

    # Compare the target with every agent's position (including the moving
    # agent itself, for simplicity); any full match means the cell is occupied.
    occupied = np.equal(target, [other.position for other in self.agents]).all(1)
    cell_isFree = not np.any(occupied)

    return cell_isFree, new_cell_isValid, new_direction, target, transition_isValid
def predict(self):
if not self.prediction_builder:
return {}
......
......@@ -3,13 +3,13 @@
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
"""Tests for `flatland` package."""
"""Test predictions for `flatland` package."""
def test_predictions():
......@@ -65,12 +65,106 @@ def test_predictions():
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
prediction_builder_object=DummyPredictorForRailEnv()
prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
)
env.reset()
# set initial position and direction for testing...
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
predictions = env.predict()
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
# compare against expected values
expected_positions = np.array([[5., 6.],
[4., 6.],
[3., 6.],
[3., 5.],
[3., 4.],
[3., 3.],
[3., 2.],
[3., 1.],
[3., 0.],
[3., 1.],
[3., 2.],
[3., 3.],
[3., 4.],
[3., 5.],
[3., 6.],
[3., 7.],
[3., 8.],
[3., 9.],
[3., 8.],
[3., 7.]])
expected_directions = np.array([[0.],
[0.],
[0.],
[3.],
[3.],
[3.],
[3.],
[3.],
[3.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[3.],
[3.]])
expected_time_offsets = np.array([[0.],
[1.],
[2.],
[3.],
[4.],
[5.],
[6.],
[7.],
[8.],
[9.],
[10.],
[11.],
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.]])
expected_actions = np.array([[0.],
[2.],
[2.],
[1.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.]])
assert np.array_equal(positions, expected_positions)
assert np.array_equal(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets)
assert np.array_equal(actions, expected_actions)
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment