Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
e1e0947e
Commit
e1e0947e
authored
Jul 05, 2019
by
u214892
Browse files
refactoring transitions_map
parent
3383b56b
Pipeline
#1356
passed with stage
in 6 minutes and 55 seconds
Changes
9
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/core/grid/grid4.py
View file @
e1e0947e
from
enum
import
IntEnum
from
typing
import
Type
import
numpy
as
np
...
...
@@ -218,7 +219,7 @@ class Grid4Transitions(Transitions):
cell_transition
=
value
return
cell_transition
def
get_direction_enum
(
self
)
->
Int
Enum
:
def
get_direction_enum
(
self
)
->
Type
[
Grid4Transitions
Enum
]
:
return
Grid4TransitionsEnum
def
has_deadend
(
self
,
cell_transition
):
...
...
flatland/core/transition_map.py
View file @
e1e0947e
...
...
@@ -7,6 +7,7 @@ from importlib_resources import path
from
numpy
import
array
from
flatland.core.grid.grid4
import
Grid4Transitions
from
flatland.core.transitions
import
Transitions
class
TransitionMap
:
...
...
@@ -110,7 +111,7 @@ class GridTransitionMap(TransitionMap):
GridTransitionMap implements utility functions.
"""
def
__init__
(
self
,
width
,
height
,
transitions
=
Grid4Transitions
([])):
def
__init__
(
self
,
width
,
height
,
transitions
:
Transitions
=
Grid4Transitions
([])):
"""
Builder for GridTransitionMap object.
...
...
@@ -132,7 +133,25 @@ class GridTransitionMap(TransitionMap):
self
.
grid
=
np
.
zeros
((
height
,
width
),
dtype
=
self
.
transitions
.
get_type
())
def
get_transitions
(
self
,
cell_id
):
def
get_full_transitions
(
self
,
row
,
column
):
"""
Returns the full transitions for the cell at (row, column) in the format transition_map's transitions.
Parameters
----------
row: int
column: int
(row,column) specifies the cell in this transition map.
Returns
-------
self.transitions.get_type()
The cell content int the format of this map's Transitions.
"""
return
self
.
grid
[
row
][
column
]
def
get_transitions
(
self
,
row
,
column
,
orientation
):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
...
...
@@ -150,15 +169,10 @@ class GridTransitionMap(TransitionMap):
Returns
-------
tuple
List of the validity of transitions in the cell.
List of the validity of transitions in the cell
as given by the maps transitions
.
"""
assert
len
(
cell_id
)
in
(
2
,
3
),
\
'GridTransitionMap.get_transitions() ERROR: cell_id tuple must have length 2 or 3.'
if
len
(
cell_id
)
==
3
:
return
self
.
transitions
.
get_transitions
(
self
.
grid
[
cell_id
[
0
]][
cell_id
[
1
]],
cell_id
[
2
])
elif
len
(
cell_id
)
==
2
:
return
self
.
grid
[
cell_id
[
0
]][
cell_id
[
1
]]
return
self
.
transitions
.
get_transitions
(
self
.
grid
[
row
][
column
],
orientation
)
def
set_transitions
(
self
,
cell_id
,
new_transitions
):
"""
...
...
@@ -308,7 +322,7 @@ class GridTransitionMap(TransitionMap):
grcPos
=
array
(
rcPos
)
grcMax
=
self
.
grid
.
shape
binTrans
=
self
.
get_transitions
(
rcPos
)
# 16bit integer - all trans in/out
binTrans
=
self
.
get_
full_
transitions
(
*
rcPos
)
# 16bit integer - all trans in/out
lnBinTrans
=
array
([
binTrans
>>
8
,
binTrans
&
0xff
],
dtype
=
np
.
uint8
)
# 2 x uint8
g2binTrans
=
np
.
unpackbits
(
lnBinTrans
).
reshape
(
4
,
4
)
# 4x4 x uint8 binary(0,1)
gDirOut
=
g2binTrans
.
any
(
axis
=
0
)
# outbound directions as boolean array (4)
...
...
@@ -328,7 +342,7 @@ class GridTransitionMap(TransitionMap):
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2
=
self
.
get_transitions
(
(
*
gPos2
,
iDirOut
)
)
t4Trans2
=
self
.
get_transitions
(
*
gPos2
,
iDirOut
)
if
any
(
t4Trans2
):
continue
else
:
...
...
flatland/envs/grid4_generators_utils.py
View file @
e1e0947e
...
...
@@ -75,7 +75,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
return
1
if
node
not
in
visited
:
visited
.
add
(
node
)
moves
=
rail
.
get_transitions
(
(
node
[
0
][
0
],
node
[
0
][
1
],
node
[
1
])
)
moves
=
rail
.
get_transitions
(
node
[
0
][
0
],
node
[
0
][
1
],
node
[
1
])
for
move_index
in
range
(
4
):
if
moves
[
move_index
]:
stack
.
append
((
get_new_position
(
node
[
0
],
move_index
),
...
...
@@ -84,7 +84,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits
=
0
tmp
=
rail
.
get_transitions
(
(
node
[
0
][
0
],
node
[
0
][
1
])
)
tmp
=
rail
.
get_
full_
transitions
(
node
[
0
][
0
],
node
[
0
][
1
])
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
...
...
@@ -96,7 +96,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_positions
=
[]
for
r
in
range
(
rail
.
height
):
for
c
in
range
(
rail
.
width
):
if
rail
.
get_transitions
(
(
r
,
c
)
)
>
0
:
if
rail
.
get_
full_
transitions
(
r
,
c
)
>
0
:
valid_positions
.
append
((
r
,
c
))
re_generate
=
True
...
...
@@ -116,7 +116,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_movements
=
[]
for
direction
in
range
(
4
):
position
=
agents_position
[
i
]
moves
=
rail
.
get_transitions
(
(
position
[
0
],
position
[
1
],
direction
)
)
moves
=
rail
.
get_transitions
(
position
[
0
],
position
[
1
],
direction
)
for
move_index
in
range
(
4
):
if
moves
[
move_index
]:
valid_movements
.
append
((
direction
,
move_index
))
...
...
flatland/envs/observations.py
View file @
e1e0947e
...
...
@@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if
handle
>
len
(
self
.
env
.
agents
):
print
(
"ERROR: obs _get - handle "
,
handle
,
" len(agents)"
,
len
(
self
.
env
.
agents
))
agent
=
self
.
env
.
agents
[
handle
]
# TODO: handle being treated as index
possible_transitions
=
self
.
env
.
rail
.
get_transitions
(
(
*
agent
.
position
,
agent
.
direction
)
)
possible_transitions
=
self
.
env
.
rail
.
get_transitions
(
*
agent
.
position
,
agent
.
direction
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
# Root node - current position
...
...
@@ -383,8 +383,8 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_target
=
True
break
cell_transitions
=
self
.
env
.
rail
.
get_transitions
(
(
*
position
,
direction
)
)
total_transitions
=
bin
(
self
.
env
.
rail
.
get_transitions
(
position
)).
count
(
"1"
)
cell_transitions
=
self
.
env
.
rail
.
get_transitions
(
*
position
,
direction
)
total_transitions
=
bin
(
self
.
env
.
rail
.
get_
full_
transitions
(
*
position
)).
count
(
"1"
)
num_transitions
=
np
.
count_nonzero
(
cell_transitions
)
exploring
=
False
# Detect Switches that can only be used by other agents.
...
...
@@ -394,7 +394,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if
num_transitions
==
1
:
# Check if dead-end, or if we can go forward along direction
nbits
=
0
tmp
=
self
.
env
.
rail
.
get_transitions
(
tuple
(
position
)
)
tmp
=
self
.
env
.
rail
.
get_
full_
transitions
(
*
position
)
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
...
...
@@ -469,7 +469,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions
=
self
.
env
.
rail
.
get_transitions
(
(
*
position
,
direction
)
)
possible_transitions
=
self
.
env
.
rail
.
get_transitions
(
*
position
,
direction
)
for
branch_direction
in
[(
direction
+
4
+
i
)
%
4
for
i
in
range
(
-
1
,
3
)]:
if
last_is_dead_end
and
self
.
env
.
rail
.
get_transition
((
*
position
,
direction
),
(
branch_direction
+
2
)
%
4
):
...
...
@@ -572,7 +572,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
,
self
.
env
.
width
,
16
))
for
i
in
range
(
self
.
rail_obs
.
shape
[
0
]):
for
j
in
range
(
self
.
rail_obs
.
shape
[
1
]):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
(
(
i
,
j
))
)
[
2
:]]
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_
full_
transitions
(
i
,
j
))[
2
:]]
bitlist
=
[
0
]
*
(
16
-
len
(
bitlist
))
+
bitlist
self
.
rail_obs
[
i
,
j
]
=
np
.
array
(
bitlist
)
...
...
@@ -630,7 +630,7 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
,
self
.
env
.
width
,
16
))
for
i
in
range
(
self
.
rail_obs
.
shape
[
0
]):
for
j
in
range
(
self
.
rail_obs
.
shape
[
1
]):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
(
(
i
,
j
))
)
[
2
:]]
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_
full_
transitions
(
i
,
j
))[
2
:]]
bitlist
=
[
0
]
*
(
16
-
len
(
bitlist
))
+
bitlist
self
.
rail_obs
[
i
,
j
]
=
np
.
array
(
bitlist
)
...
...
@@ -701,7 +701,7 @@ class LocalObsForRailEnv(ObservationBuilder):
self
.
env
.
width
+
2
*
self
.
view_radius
,
16
))
for
i
in
range
(
self
.
env
.
height
):
for
j
in
range
(
self
.
env
.
width
):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
(
(
i
,
j
))
)
[
2
:]]
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_
full_
transitions
(
i
,
j
))[
2
:]]
bitlist
=
[
0
]
*
(
16
-
len
(
bitlist
))
+
bitlist
self
.
rail_obs
[
i
+
self
.
view_radius
,
j
+
self
.
view_radius
]
=
np
.
array
(
bitlist
)
...
...
flatland/envs/predictions.py
View file @
e1e0947e
...
...
@@ -131,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction
[
index
]
=
[
index
,
*
agent
.
position
,
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
continue
# Take shortest possible path
cell_transitions
=
self
.
env
.
rail
.
get_transitions
(
(
*
agent
.
position
,
agent
.
direction
)
)
cell_transitions
=
self
.
env
.
rail
.
get_transitions
(
*
agent
.
position
,
agent
.
direction
)
new_position
=
None
new_direction
=
None
...
...
flatland/envs/rail_env.py
View file @
e1e0947e
...
...
@@ -322,7 +322,7 @@ class RailEnv(Environment):
new_position
,
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
and
# check the new position has some transitions (ie is not an empty cell)
self
.
rail
.
get_transitions
(
new_position
)
>
0
)
self
.
rail
.
get_
full_
transitions
(
*
new_position
)
>
0
)
# If transition validity hasn't been checked yet.
if
transition_isValid
is
None
:
...
...
@@ -338,7 +338,7 @@ class RailEnv(Environment):
def
check_action
(
self
,
agent
,
action
):
transition_isValid
=
None
possible_transitions
=
self
.
rail
.
get_transitions
(
(
*
agent
.
position
,
agent
.
direction
)
)
possible_transitions
=
self
.
rail
.
get_transitions
(
*
agent
.
position
,
agent
.
direction
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
new_direction
=
agent
.
direction
...
...
flatland/utils/editor.py
View file @
e1e0947e
...
...
@@ -494,7 +494,7 @@ class EditorModel(object):
if
len
(
lrcStroke
)
>=
2
:
# If the first cell in a stroke is empty, add a deadend to cell 0
if
self
.
env
.
rail
.
get_transitions
(
lrcStroke
[
0
])
==
0
:
if
self
.
env
.
rail
.
get_
full_
transitions
(
*
lrcStroke
[
0
])
==
0
:
self
.
mod_rail_2cells
(
lrcStroke
,
bAddRemove
,
iCellToMod
=
0
)
# Add transitions for groups of 3 cells
...
...
@@ -504,7 +504,7 @@ class EditorModel(object):
# If final cell empty, insert deadend:
if
len
(
lrcStroke
)
==
2
:
if
self
.
env
.
rail
.
get_transitions
(
lrcStroke
[
1
])
==
0
:
if
self
.
env
.
rail
.
get_
full_
transitions
(
*
lrcStroke
[
1
])
==
0
:
self
.
mod_rail_2cells
(
lrcStroke
,
bAddRemove
,
iCellToMod
=
1
)
# now empty out the final two cells from the queue
...
...
@@ -752,7 +752,7 @@ class EditorModel(object):
self
.
log
(
*
args
,
**
kwargs
)
def
debug_cell
(
self
,
rcCell
):
binTrans
=
self
.
env
.
rail
.
get_transitions
(
rcCell
)
binTrans
=
self
.
env
.
rail
.
get_
full_
transitions
(
*
rcCell
)
sbinTrans
=
format
(
binTrans
,
"#018b"
)[
2
:]
self
.
debug
(
"cell "
,
rcCell
,
...
...
flatland/utils/rendertools.py
View file @
e1e0947e
...
...
@@ -86,7 +86,7 @@ class RenderTool(object):
for
visit
in
lVisits
:
# transition for next cell
tbTrans
=
self
.
env
.
rail
.
get_transitions
(
(
*
visit
.
rc
,
visit
.
iDir
)
)
tbTrans
=
self
.
env
.
rail
.
get_transitions
(
*
visit
.
rc
,
visit
.
iDir
)
giTrans
=
np
.
where
(
tbTrans
)[
0
]
# RC list of transitions
gTransRCAg
=
rt
.
gTransRC
[
giTrans
]
self
.
plotTrans
(
visit
.
rc
,
gTransRCAg
,
depth
=
str
(
visit
.
iDepth
),
color
=
color
)
...
...
@@ -125,7 +125,7 @@ class RenderTool(object):
)
"""
tbTrans
=
self
.
env
.
rail
.
get_transitions
(
(
*
rcPos
,
iDir
)
)
tbTrans
=
self
.
env
.
rail
.
get_transitions
(
*
rcPos
,
iDir
)
giTrans
=
np
.
where
(
tbTrans
)[
0
]
# RC list of transitions
# HACK: workaround dead-end transitions
...
...
@@ -459,7 +459,7 @@ class RenderTool(object):
xyCentre
=
array
([
x0
,
y1
])
+
cell_size
/
2
# cell transition values
oCell
=
env
.
rail
.
get_transitions
(
(
r
,
c
)
)
oCell
=
env
.
rail
.
get_
full_
transitions
(
r
,
c
)
bCellValid
=
env
.
rail
.
cell_neighbours_valid
((
r
,
c
),
check_this_cell
=
True
)
...
...
@@ -482,7 +482,7 @@ class RenderTool(object):
from_ori
=
(
orientation
+
2
)
%
4
# 0123=NESW -> 2301=SWNE
from_xy
=
coords
[
from_ori
]
tMoves
=
env
.
rail
.
get_transitions
(
(
r
,
c
,
orientation
)
)
tMoves
=
env
.
rail
.
get_transitions
(
r
,
c
,
orientation
)
for
to_ori
in
range
(
4
):
to_xy
=
coords
[
to_ori
]
...
...
tests/test_flatland_core_transition_map.py
View file @
e1e0947e
...
...
@@ -5,19 +5,42 @@ from flatland.core.transition_map import GridTransitionMap
def
test_grid4_get_transitions
():
grid4_map
=
GridTransitionMap
(
2
,
2
,
Grid4Transitions
([]))
assert
grid4_map
.
get_transitions
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
))
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
NORTH
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
EAST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
SOUTH
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
WEST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_full_transitions
(
0
,
0
)
==
0
grid4_map
.
set_transition
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
),
Grid4TransitionsEnum
.
NORTH
,
1
)
assert
grid4_map
.
get_transitions
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
))
==
(
1
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
NORTH
)
==
(
1
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
EAST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
SOUTH
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
WEST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_full_transitions
(
0
,
0
)
==
pow
(
2
,
15
)
# the most significant bit is on
grid4_map
.
set_transition
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
),
Grid4TransitionsEnum
.
WEST
,
1
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
NORTH
)
==
(
1
,
0
,
0
,
1
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
EAST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
SOUTH
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
WEST
)
==
(
0
,
0
,
0
,
0
)
# the most significant and the fourth most significant bits are on
assert
grid4_map
.
get_full_transitions
(
0
,
0
)
==
pow
(
2
,
15
)
+
pow
(
2
,
12
)
grid4_map
.
set_transition
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
),
Grid4TransitionsEnum
.
NORTH
,
0
)
assert
grid4_map
.
get_transitions
((
0
,
0
,
Grid4TransitionsEnum
.
NORTH
))
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
NORTH
)
==
(
0
,
0
,
0
,
1
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
EAST
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
SOUTH
)
==
(
0
,
0
,
0
,
0
)
assert
grid4_map
.
get_transitions
(
0
,
0
,
Grid4TransitionsEnum
.
WEST
)
==
(
0
,
0
,
0
,
0
)
# the fourth most significant bits are on
assert
grid4_map
.
get_full_transitions
(
0
,
0
)
==
pow
(
2
,
12
)
def
test_grid8_set_transitions
():
grid8_map
=
GridTransitionMap
(
2
,
2
,
Grid8Transitions
([]))
assert
grid8_map
.
get_transitions
(
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
)
==
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
assert
grid8_map
.
get_transitions
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
==
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
grid8_map
.
set_transition
((
0
,
0
,
Grid8TransitionsEnum
.
NORTH
),
Grid8TransitionsEnum
.
NORTH
,
1
)
assert
grid8_map
.
get_transitions
(
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
)
==
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
assert
grid8_map
.
get_transitions
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
==
(
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
grid8_map
.
set_transition
((
0
,
0
,
Grid8TransitionsEnum
.
NORTH
),
Grid8TransitionsEnum
.
NORTH
,
0
)
assert
grid8_map
.
get_transitions
(
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
)
==
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
assert
grid8_map
.
get_transitions
(
0
,
0
,
Grid8TransitionsEnum
.
NORTH
)
==
(
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
)
# TODO GridTransitionMap
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment