Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
960361f1
Commit
960361f1
authored
Jun 04, 2019
by
u214892
Browse files
25 predictor draft
parent
311b9814
Pipeline
#849
failed with stage
in 4 minutes and 35 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/core/env_prediction_builder.py
View file @
960361f1
...
...
@@ -15,8 +15,8 @@ class PredictionBuilder:
PredictionBuilder base class.
"""
def
__init__
(
self
):
pass
def __init__(self, max_depth: int = 20):
    """
    Create the prediction builder and remember how far ahead it may predict.

    Parameters
    -------
    max_depth : int (optional)
        Stored unchanged on `self.max_depth`; defaults to 20.
        NOTE(review): no validation is done here — subclasses presumably
        expect a positive integer; confirm against callers.
    """
    self.max_depth = max_depth
def _set_env(self, env):
    """
    Attach the environment this builder works on.

    Parameters
    -------
    env
        Object stored unchanged on `self.env`.
    """
    self.env = env
...
...
@@ -25,12 +25,11 @@ class PredictionBuilder:
"""
Called after each environment reset.
"""
raise
NotImplementedError
()
pass
def
get
(
self
,
handle
=
0
):
"""
Called whenever an observation has to be computed for the `env' environment, possibly
for each agent independently (agent id `handle').
Called whenever step_prediction is called on the environment.
Parameters
-------
...
...
@@ -40,6 +39,6 @@ class PredictionBuilder:
Returns
-------
function
A
n
prediction structure, specific to the corresponding environment.
A prediction structure, specific to the corresponding environment.
"""
raise
NotImplementedError
()
flatland/envs/predictions.py
View file @
960361f1
...
...
@@ -2,6 +2,8 @@
Collection of environment-specific PredictionBuilder.
"""
import
numpy
as
np
from
flatland.core.env_prediction_builder
import
PredictionBuilder
...
...
@@ -13,11 +15,58 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def
__init__
(
self
):
pass
def get(self, handle=None):
    """
    Called whenever step_prediction is called on the environment.

    The prediction acts as if no other agent is in the environment: for every
    step the first action of `[forward, left, right]` whose target cell and
    transition are both valid is taken. Agents are moved temporarily while
    rolling out the prediction and are always put back to their initial
    position and direction before returning, even if the rollout raises.

    Parameters
    -------
    handle : int (optional)
        Handle of the agent for which to compute the prediction.
        If None (the default), predictions are computed for all agents.

    Returns
    -------
    dict
        Returns a dictionary indexed by the agent handle and for each agent an
        array of shape (max_depth, 5) whose rows hold:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
        Rows for steps at which no action was possible are left as zeros.
    """
    # `handle` may legitimately be 0, so test against None rather than truthiness.
    if handle is None:
        agents = self.env.agents
    else:
        agents = [self.env.agents[handle]]

    # 0: do nothing
    # 1: turn left and move to the next cell
    # 2: move to the next cell in front of the agent
    # 3: turn right and move to the next cell
    action_priorities = [2, 1, 3]

    prediction_dict = {}
    for agent in agents:
        _agent_initial_position = agent.position
        _agent_initial_direction = agent.direction

        prediction = np.zeros(shape=(self.max_depth, 5))
        prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]

        try:
            for index in range(1, self.max_depth):
                action_done = False
                for action in action_priorities:
                    # Occupancy of the target cell is deliberately ignored: the
                    # dummy predictor behaves as if the agent were alone.
                    _cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self.env._check_action_on_agent(action, agent)
                    if all([new_cell_isValid, transition_isValid]):
                        # move and change direction to face the new_direction that was
                        # performed
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
                        action_done = True
                        break
                if not action_done:
                    print("Cannot move further.")
        finally:
            # The rollout moved the real agent object; undo that unconditionally.
            agent.position = _agent_initial_position
            agent.direction = _agent_initial_direction

        prediction_dict[agent.handle] = prediction

    return prediction_dict
flatland/envs/rail_env.py
View file @
960361f1
...
...
@@ -219,51 +219,9 @@ class RailEnv(Environment):
return
if
action
>
0
:
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction
,
transition_isValid
=
self
.
check_action
(
agent
,
action
)
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid
=
(
np
.
array_equal
(
# Check the new position is still in the grid
new_position
,
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
and
# check the new position has some transitions (ie is not an empty cell)
self
.
rail
.
get_transitions
(
new_position
)
>
0
)
# If transition validity hasn't been checked yet.
if
transition_isValid
is
None
:
transition_isValid
=
self
.
rail
.
get_transition
(
(
*
agent
.
position
,
agent
.
direction
),
new_direction
)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree
=
not
np
.
any
(
np
.
equal
(
new_position
,
[
agent2
.
position
for
agent2
in
self
.
agents
]).
all
(
1
))
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
self
.
_check_action_on_agent
(
action
,
agent
,
transition_isValid
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
# move and change direction to face the new_direction that was
...
...
@@ -303,6 +261,46 @@ class RailEnv(Environment):
self
.
actions
=
[
0
]
*
self
.
get_num_agents
()
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
def _check_action_on_agent(self, action, agent):
    """
    Work out where `action` would take `agent` and whether that move is legal.

    Nothing is assigned to `agent` in this method; the caller decides whether
    to apply the move based on the returned flags.

    Parameters
    -------
    action
        Action to evaluate; forwarded to `self.check_action`.
    agent
        Agent whose `position` and `direction` are used as the starting state.

    Returns
    -------
    tuple
        (cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid)
    """
    new_direction, transition_isValid = self.check_action(agent, action)
    new_position = get_new_position(agent.position, new_direction)

    # Is it a legal move?
    # 1) transition allows the new_direction in the cell,
    # 2) the new cell is not empty (case 0),
    # 3) the cell is free, i.e., no agent is currently in that cell

    # Clipping to the grid bounds leaves the position unchanged only if it is
    # already inside the grid.
    lower_corner = [0, 0]
    upper_corner = [self.height - 1, self.width - 1]
    stays_in_grid = np.array_equal(new_position, np.clip(new_position, lower_corner, upper_corner))
    # The rail lookup only happens for in-grid positions (short-circuit `and`);
    # a cell with no transitions at all is an empty cell.
    new_cell_isValid = stays_in_grid and self.rail.get_transitions(new_position) > 0

    # If transition validity hasn't been checked yet.
    if transition_isValid is None:
        current_state = (*agent.position, agent.direction)
        transition_isValid = self.rail.get_transition(current_state, new_direction)

    # Check the new position is not the same as any of the existing agent positions
    # (including itself, for simplicity, since it is moving)
    occupied_positions = [agent2.position for agent2 in self.agents]
    matches_per_agent = np.equal(new_position, occupied_positions).all(1)
    cell_isFree = not np.any(matches_per_agent)

    return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def
predict
(
self
):
if
not
self
.
prediction_builder
:
return
{}
...
...
tests/test_env_prediction_builder.py
View file @
960361f1
...
...
@@ -3,13 +3,13 @@
import
numpy
as
np
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
"""Tests for `flatland` package."""
"""Test
prediction
s for `flatland` package."""
def
test_predictions
():
...
...
@@ -65,12 +65,106 @@ def test_predictions():
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
(),
prediction_builder_object
=
DummyPredictorForRailEnv
()
prediction_builder_object
=
DummyPredictorForRailEnv
(
max_depth
=
20
)
)
env
.
reset
()
# set initial position and direction for testing...
env
.
agents
[
0
].
position
=
(
5
,
6
)
env
.
agents
[
0
].
direction
=
0
predictions
=
env
.
predict
()
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
1
],
prediction
[
2
]],
predictions
[
0
])))
directions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
3
]],
predictions
[
0
])))
time_offsets
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
0
]],
predictions
[
0
])))
actions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
4
]],
predictions
[
0
])))
# compare against expected values
expected_positions
=
np
.
array
([[
5.
,
6.
],
[
4.
,
6.
],
[
3.
,
6.
],
[
3.
,
5.
],
[
3.
,
4.
],
[
3.
,
3.
],
[
3.
,
2.
],
[
3.
,
1.
],
[
3.
,
0.
],
[
3.
,
1.
],
[
3.
,
2.
],
[
3.
,
3.
],
[
3.
,
4.
],
[
3.
,
5.
],
[
3.
,
6.
],
[
3.
,
7.
],
[
3.
,
8.
],
[
3.
,
9.
],
[
3.
,
8.
],
[
3.
,
7.
]])
expected_directions
=
np
.
array
([[
0.
],
[
0.
],
[
0.
],
[
3.
],
[
3.
],
[
3.
],
[
3.
],
[
3.
],
[
3.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
1.
],
[
3.
],
[
3.
]])
expected_time_offsets
=
np
.
array
([[
0.
],
[
1.
],
[
2.
],
[
3.
],
[
4.
],
[
5.
],
[
6.
],
[
7.
],
[
8.
],
[
9.
],
[
10.
],
[
11.
],
[
12.
],
[
13.
],
[
14.
],
[
15.
],
[
16.
],
[
17.
],
[
18.
],
[
19.
]])
expected_actions
=
np
.
array
([[
0.
],
[
2.
],
[
2.
],
[
1.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
],
[
2.
]])
assert
np
.
array_equal
(
positions
,
expected_positions
)
assert
np
.
array_equal
(
directions
,
expected_directions
)
assert
np
.
array_equal
(
time_offsets
,
expected_time_offsets
)
assert
np
.
array_equal
(
actions
,
expected_actions
)
def
main
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment