Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
B
baselines
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Flatland
baselines
Commits
4c8f4d40
Commit
4c8f4d40
authored
5 years ago
by
u214892
Browse files
Options
Downloads
Patches
Plain Diff
update baselines to master of flatland
parent
9a7e2fc1
No related branches found
No related tags found
1 merge request
!6
Update baselines
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
MANIFEST.in
+1
-1
1 addition, 1 deletion
MANIFEST.in
torch_training/observation_builders/observations.py
+15
-28
15 additions, 28 deletions
torch_training/observation_builders/observations.py
torch_training/predictors/predictions.py
+3
-2
3 additions, 2 deletions
torch_training/predictors/predictions.py
with
19 additions
and
31 deletions
MANIFEST.in
+
1
−
1
View file @
4c8f4d40
...
...
@@ -12,4 +12,4 @@ recursive-include tests *
recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
recursive-include docs *.rst
*.md
conf.py Makefile make.bat *.jpg *.png *.gif
This diff is collapsed.
Click to expand it.
torch_training/observation_builders/observations.py
+
15
−
28
View file @
4c8f4d40
...
...
@@ -7,7 +7,7 @@ from collections import deque
import
numpy
as
np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.grid.grid4
import
Grid4Tran
sition
sEnum
from
flatland.core.grid.grid4
_utils
import
get_new_po
sition
from
flatland.core.grid.grid_utils
import
coordinate_to_position
...
...
@@ -86,10 +86,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction
'
# in cell (row,col) allows a movement in direction `direction
`
nodes_queue
=
deque
(
self
.
_get_and_update_neighbors
(
position
,
target_nr
,
0
,
enforce_target_direction
=-
1
))
# BFS from target `position
'
to all the reachable nodes in the grid
# BFS from target `position
`
to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited
=
{(
position
[
0
],
position
[
1
],
0
),
(
position
[
0
],
position
[
1
],
1
),
(
position
[
0
],
position
[
1
],
2
),
(
position
[
0
],
position
[
1
],
3
)}
...
...
@@ -125,12 +125,12 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_directions
=
[
0
,
1
,
2
,
3
]
if
enforce_target_direction
>=
0
:
# The agent must land into the current cell with orientation `enforce_target_direction
'
.
# The agent must land into the current cell with orientation `enforce_target_direction
`
.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions
=
[(
enforce_target_direction
+
2
)
%
4
]
for
neigh_direction
in
possible_directions
:
new_cell
=
self
.
_new_position
(
position
,
neigh_direction
)
new_cell
=
get
_new_position
(
position
,
neigh_direction
)
if
new_cell
[
0
]
>=
0
and
new_cell
[
0
]
<
self
.
env
.
height
and
new_cell
[
1
]
>=
0
and
new_cell
[
1
]
<
self
.
env
.
width
:
...
...
@@ -138,7 +138,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Check all possible transitions in new_cell
for
agent_orientation
in
range
(
4
):
# Is a transition along movement `desired_movement_from_new_cell
'
to the current cell possible?
# Is a transition along movement `desired_movement_from_new_cell
`
to the current cell possible?
is_valid
=
self
.
env
.
rail
.
get_transition
((
new_cell
[
0
],
new_cell
[
1
],
agent_orientation
),
desired_movement_from_new_cell
)
...
...
@@ -156,23 +156,10 @@ class TreeObsForRailEnv(ObservationBuilder):
return
neighbors
def
_new_position
(
self
,
position
,
movement
):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if
movement
==
Grid4TransitionsEnum
.
NORTH
:
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
Grid4TransitionsEnum
.
EAST
:
return
(
position
[
0
],
position
[
1
]
+
1
)
elif
movement
==
Grid4TransitionsEnum
.
SOUTH
:
return
(
position
[
0
]
+
1
,
position
[
1
])
elif
movement
==
Grid4TransitionsEnum
.
WEST
:
return
(
position
[
0
],
position
[
1
]
-
1
)
def
get_many
(
self
,
handles
=
None
):
"""
Called whenever an observation has to be computed for the `env
'
environment, for each agent with handle
in the `handles
'
list.
Called whenever an observation has to be computed for the `env
`
environment, for each agent with handle
in the `handles
`
list.
"""
if
handles
is
None
:
...
...
@@ -200,7 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder):
def
get
(
self
,
handle
):
"""
Computes the current observation for agent `handle
'
in env
Computes the current observation for agent `handle
`
in env
The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
...
...
@@ -280,7 +267,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for
branch_direction
in
[(
orientation
+
i
)
%
4
for
i
in
range
(
-
1
,
3
)]:
if
possible_transitions
[
branch_direction
]:
new_cell
=
self
.
_new_position
(
agent
.
position
,
branch_direction
)
new_cell
=
get
_new_position
(
agent
.
position
,
branch_direction
)
branch_observation
,
branch_visited
=
\
self
.
_explore_branch
(
handle
,
new_cell
,
branch_direction
,
1
,
1
)
observation
=
observation
+
branch_observation
...
...
@@ -428,11 +415,11 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_dead_end
=
True
if
not
last_is_dead_end
:
# Keep walking through the tree along `direction
'
# Keep walking through the tree along `direction
`
exploring
=
True
# convert one-hot encoding to 0,1,2,3
direction
=
np
.
argmax
(
cell_transitions
)
position
=
self
.
_new_position
(
position
,
direction
)
position
=
get
_new_position
(
position
,
direction
)
num_steps
+=
1
tot_dist
+=
1
elif
num_transitions
>
0
:
...
...
@@ -447,7 +434,7 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal
=
True
break
# `position
'
is either a terminal node or a switch
# `position
`
is either a terminal node or a switch
# #############################
# #############################
...
...
@@ -499,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(
branch_direction
+
2
)
%
4
):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell
=
self
.
_new_position
(
position
,
(
branch_direction
+
2
)
%
4
)
new_cell
=
get
_new_position
(
position
,
(
branch_direction
+
2
)
%
4
)
branch_observation
,
branch_visited
=
self
.
_explore_branch
(
handle
,
new_cell
,
(
branch_direction
+
2
)
%
4
,
...
...
@@ -509,7 +496,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if
len
(
branch_visited
)
!=
0
:
visited
=
visited
.
union
(
branch_visited
)
elif
last_is_switch
and
possible_transitions
[
branch_direction
]:
new_cell
=
self
.
_new_position
(
position
,
branch_direction
)
new_cell
=
get
_new_position
(
position
,
branch_direction
)
branch_observation
,
branch_visited
=
self
.
_explore_branch
(
handle
,
new_cell
,
branch_direction
,
...
...
This diff is collapsed.
Click to expand it.
torch_training/predictors/predictions.py
+
3
−
2
View file @
4c8f4d40
...
...
@@ -8,6 +8,7 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.envs.rail_env
import
RailEnvActions
class
ShortestPathPredictorForRailEnv
(
PredictionBuilder
):
"""
ShortestPathPredictorForRailEnv object.
...
...
@@ -25,10 +26,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
Requires distance_map to extract the shortest path.
Parameters
-------
-------
---
custom_args: dict
- distance_map : dict
handle : int
(
optional
)
handle : int
,
optional
Handle of the agent for which to compute the observation vector.
Returns
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment