Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
43cf00ff
Commit
43cf00ff
authored
Jun 06, 2019
by
u214892
Browse files
formatted everything with IntelliJ/PyCharm formatter, optimizing imports
parent
f296c0e3
Pipeline
#917
passed with stage
in 8 minutes and 16 seconds
Changes
18
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
docs/conf.py
View file @
43cf00ff
...
...
@@ -13,16 +13,17 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
import
os
import
sys
# If extensions (or modules to document with autodoc) are in another
# directory, add these directories to sys.path here. If the directory is
# relative to the documentation root, use os.path.abspath to make it
# absolute, like shown here.
#
import
flatland
import
os
import
sys
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'..'
))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
'..'
))
# -- General configuration ---------------------------------------------
...
...
@@ -78,7 +79,6 @@ pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos
=
False
# -- Options for HTML output -------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
...
...
@@ -86,7 +86,6 @@ todo_include_todos = False
#
html_theme
=
"sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a
# theme further. For a list of options available for each theme, see the
# documentation.
...
...
@@ -98,13 +97,11 @@ html_theme = "sphinx_rtd_theme"
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path
=
[
'_static'
]
# -- Options for HTMLHelp output ---------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename
=
'flatlanddoc'
# -- Options for LaTeX output ------------------------------------------
latex_elements
=
{
...
...
@@ -134,7 +131,6 @@ latex_documents = [
u
'S.P. Mohanty'
,
'manual'
),
]
# -- Options for manual page output ------------------------------------
# One entry per manual page. List of tuples
...
...
@@ -145,7 +141,6 @@ man_pages = [
[
author
],
1
)
]
# -- Options for Texinfo output ----------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
...
...
examples/custom_observation_example.py
View file @
43cf00ff
import
random
import
numpy
as
np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.envs.generators
import
random_rail_generator
from
flatland.envs.rail_env
import
RailEnv
from
flatland.core.env_observation_builder
import
ObservationBuilder
import
numpy
as
np
random
.
seed
(
100
)
np
.
random
.
seed
(
100
)
...
...
@@ -18,7 +18,7 @@ class CustomObs(ObservationBuilder):
return
def
get
(
self
,
handle
):
observation
=
handle
*
np
.
ones
((
5
,))
observation
=
handle
*
np
.
ones
((
5
,))
return
observation
...
...
examples/custom_railmap_example.py
View file @
43cf00ff
...
...
@@ -23,6 +23,7 @@ def custom_rail_generator():
agents_target
=
[]
return
grid_map
,
agents_positions
,
agents_direction
,
agents_target
return
generator
...
...
examples/tkplay.py
View file @
43cf00ff
from
examples.play_model
import
Player
from
flatland.envs.generators
import
complex_rail_generator
from
flatland.envs.rail_env
import
RailEnv
...
...
@@ -26,7 +25,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"):
env_renderer
.
renderEnv
(
show
=
True
,
frames
=
True
,
iEpisode
=
trials
,
iStep
=
step
,
action_dict
=
oPlayer
.
action_dict
)
env_renderer
.
close_window
()
env_renderer
.
close_window
()
if
__name__
==
"__main__"
:
...
...
flatland/cli.py
View file @
43cf00ff
...
...
@@ -2,6 +2,7 @@
"""Console script for flatland."""
import
sys
import
click
...
...
flatland/core/transition_map.py
View file @
43cf00ff
...
...
@@ -297,8 +297,8 @@ class GridTransitionMap(TransitionMap):
self
.
grid
=
np
.
zeros
((
self
.
height
,
self
.
width
),
dtype
=
np
.
uint64
)
self
.
grid
[
0
:
min
(
self
.
height
,
new_height
),
0
:
min
(
self
.
width
,
new_width
)]
=
new_grid
[
0
:
min
(
self
.
height
,
new_height
),
0
:
min
(
self
.
width
,
new_width
)]
0
:
min
(
self
.
width
,
new_width
)]
=
new_grid
[
0
:
min
(
self
.
height
,
new_height
),
0
:
min
(
self
.
width
,
new_width
)]
def
is_cell_valid
(
self
,
rcPos
):
cell_transition
=
self
.
grid
[
tuple
(
rcPos
)]
...
...
@@ -336,8 +336,8 @@ class GridTransitionMap(TransitionMap):
lnBinTrans
=
array
([
binTrans
>>
8
,
binTrans
&
0xff
],
dtype
=
np
.
uint8
)
# 2 x uint8
g2binTrans
=
np
.
unpackbits
(
lnBinTrans
).
reshape
(
4
,
4
)
# 4x4 x uint8 binary(0,1)
# gDirIn = g2binTrans.any(axis=1) # inbound directions as boolean array (4)
gDirOut
=
g2binTrans
.
any
(
axis
=
0
)
# outbound directions as boolean array (4)
giDirOut
=
np
.
argwhere
(
gDirOut
)[:,
0
]
# valid outbound directions as array of int
gDirOut
=
g2binTrans
.
any
(
axis
=
0
)
# outbound directions as boolean array (4)
giDirOut
=
np
.
argwhere
(
gDirOut
)[:,
0
]
# valid outbound directions as array of int
# loop over available outbound directions (indices) for rcPos
for
iDirOut
in
giDirOut
:
...
...
flatland/core/transitions.py
View file @
43cf00ff
...
...
@@ -319,7 +319,7 @@ class Grid4Transitions(Transitions):
value
=
self
.
set_transitions
(
value
,
i
,
block_tuple
)
# Rotate the 4-bits blocks
value
=
((
value
&
(
2
**
(
rotation
*
4
)
-
1
))
<<
((
4
-
rotation
)
*
4
))
|
(
value
>>
(
rotation
*
4
))
value
=
((
value
&
(
2
**
(
rotation
*
4
)
-
1
))
<<
((
4
-
rotation
)
*
4
))
|
(
value
>>
(
rotation
*
4
))
cell_transition
=
value
return
cell_transition
...
...
@@ -499,7 +499,7 @@ class Grid8Transitions(Transitions):
value
=
self
.
set_transitions
(
value
,
i
,
block_tuple
)
# Rotate the 8bits blocks
value
=
((
value
&
(
2
**
(
rotation
*
8
)
-
1
))
<<
((
8
-
rotation
)
*
8
))
|
(
value
>>
(
rotation
*
8
))
value
=
((
value
&
(
2
**
(
rotation
*
8
)
-
1
))
<<
((
8
-
rotation
)
*
8
))
|
(
value
>>
(
rotation
*
8
))
cell_transition
=
value
...
...
@@ -587,9 +587,9 @@ class RailEnvTransitions(Grid4Transitions):
sRepr
=
" "
.
join
([
"{}:{}"
.
format
(
sDir
,
sbinTrans
[
i
:(
i
+
4
)])
for
i
,
sDir
in
zip
(
range
(
0
,
len
(
sbinTrans
),
4
),
self
.
lsDirs
)])
# NESW
zip
(
range
(
0
,
len
(
sbinTrans
),
4
),
self
.
lsDirs
)])
# NESW
return
sRepr
if
version
==
1
:
...
...
flatland/envs/agent_utils.py
View file @
43cf00ff
from
attr
import
attrs
,
attrib
from
itertools
import
starmap
import
numpy
as
np
from
attr
import
attrs
,
attrib
# from flatland.envs.rail_env import RailEnv
...
...
@@ -16,7 +18,7 @@ class EnvDescription(object):
height
=
attrib
()
width
=
attrib
()
rail_generator
=
attrib
()
obs_builder
=
attrib
()
# not sure if this should closer to the agent than the env
obs_builder
=
attrib
()
# not sure if this should closer to the agent than the env
@
attrs
...
...
@@ -41,7 +43,7 @@ class EnvAgentStatic(object):
def
from_lists
(
cls
,
positions
,
directions
,
targets
):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
))))
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
))))
def
to_list
(
self
):
...
...
@@ -78,7 +80,7 @@ class EnvAgent(EnvAgentStatic):
def
to_list
(
self
):
return
[
self
.
position
,
self
.
direction
,
self
.
target
,
self
.
handle
,
self
.
position
,
self
.
direction
,
self
.
target
,
self
.
handle
,
self
.
old_direction
,
self
.
old_position
,
self
.
moving
]
@
classmethod
...
...
flatland/envs/generators.py
View file @
43cf00ff
import
numpy
as
np
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
from
flatland.core.transitions
import
Grid8Transitions
,
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.core.transitions
import
Grid8Transitions
,
RailEnvTransitions
from
flatland.envs.env_utils
import
distance_on_rail
,
connect_rail
,
get_direction
,
mirror
from
flatland.envs.env_utils
import
get_rnd_agents_pos_tgt_dir_on_rail
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
def
empty_rail_generator
():
"""
Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor
"""
def
generator
(
width
,
height
,
num_agents
=
0
,
num_resets
=
0
):
rail_trans
=
RailEnvTransitions
()
grid_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
...
...
@@ -21,6 +23,7 @@ def empty_rail_generator():
rail_array
.
fill
(
0
)
return
grid_map
,
[],
[],
[]
return
generator
...
...
@@ -41,7 +44,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
def
generator
(
width
,
height
,
num_agents
,
num_resets
=
0
):
if
num_agents
>
nr_start_goal
:
num_agents
=
nr_start_goal
num_agents
=
nr_start_goal
print
(
"complex_rail_generator: num_agents > nr_start_goal, changing num_agents"
)
rail_trans
=
RailEnvTransitions
()
grid_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
...
...
flatland/envs/rail_env.py
View file @
43cf00ff
...
...
@@ -7,9 +7,10 @@ a GridTransitionMap object.
# TODO: _ this is a global method --> utils or remove later
# from inspect import currentframe
from
enum
import
IntEnum
import
msgpack
import
numpy
as
np
from
enum
import
IntEnum
from
flatland.core.env
import
Environment
from
flatland.envs.agent_utils
import
EnvAgentStatic
,
EnvAgent
...
...
flatland/flatland.py
View file @
43cf00ff
# -*- coding: utf-8 -*-
"""Main module."""
flatland/utils/svg.py
View file @
43cf00ff
import
copy
import
re
import
svgutils
import
re
import
copy
from
flatland.core.transitions
import
RailEnvTransitions
...
...
@@ -60,7 +60,7 @@ class SVG(object):
sNewStyles
=
"
\n
"
for
sKey
,
sValue
in
self
.
dStyles
.
items
():
if
sKey
==
style_name
:
sValue
=
"fill:#"
+
""
.
join
([(
'{:#04x}'
.
format
(
int
(
255.0
*
col
))[
2
:
4
])
for
col
in
color
[
0
:
3
]])
+
";"
sValue
=
"fill:#"
+
""
.
join
([(
'{:#04x}'
.
format
(
int
(
255.0
*
col
))[
2
:
4
])
for
col
in
color
[
0
:
3
]])
+
";"
sNewStyle
=
"
\t
.st"
+
sKey
+
"{"
+
sValue
+
"}
\n
"
sNewStyles
+=
sNewStyle
...
...
@@ -111,6 +111,7 @@ class Track(object):
The directions and images are also rotated by 90, 180 & 270 degrees.
(There is some redundancy in this process, given the images provided)
"""
def
__init__
(
self
):
dFiles
=
{
""
:
"Background_#9CCB89.svg"
,
...
...
make_coverage.py
View file @
43cf00ff
#!/usr/bin/env python
import
os
import
webbrowser
import
subprocess
import
webbrowser
from
urllib.request
import
pathname2url
def
browser
(
pathname
):
webbrowser
.
open
(
"file:"
+
pathname2url
(
os
.
path
.
abspath
(
pathname
)))
subprocess
.
call
([
'coverage'
,
'run'
,
'--source'
,
'flatland'
,
'-m'
,
'pytest'
])
subprocess
.
call
([
'coverage'
,
'report'
,
'-m'
])
subprocess
.
call
([
'coverage'
,
'html'
])
...
...
make_docs.py
View file @
43cf00ff
#!/usr/bin/env python
import
os
import
webbrowser
import
subprocess
import
webbrowser
from
urllib.request
import
pathname2url
def
browser
(
pathname
):
webbrowser
.
open
(
"file:"
+
pathname2url
(
os
.
path
.
abspath
(
pathname
)))
def
remove_exists
(
filename
):
try
:
os
.
remove
(
filename
)
...
...
setup.py
View file @
43cf00ff
...
...
@@ -3,13 +3,10 @@
"""The setup script."""
import
os
from
setuptools
import
setup
,
find_packages
import
sys
import
os
import
platform
import
sys
from
setuptools
import
setup
,
find_packages
with
open
(
'README.rst'
)
as
readme_file
:
readme
=
readme_file
.
read
()
...
...
@@ -17,10 +14,6 @@ with open('README.rst') as readme_file:
with
open
(
'HISTORY.rst'
)
as
history_file
:
history
=
history_file
.
read
()
# install pycairo on Windows
if
os
.
name
==
'nt'
:
p
=
platform
.
architecture
()
...
...
@@ -51,13 +44,14 @@ if os.name == 'nt':
import
site
import
ctypes.util
default_os_path
=
os
.
environ
[
'PATH'
]
os
.
environ
[
'PATH'
]
=
''
for
s
in
site
.
getsitepackages
():
os
.
environ
[
'PATH'
]
=
os
.
environ
[
'PATH'
]
+
';'
+
s
+
'
\\
cairo'
os
.
environ
[
'PATH'
]
=
os
.
environ
[
'PATH'
]
+
';'
+
default_os_path
os
.
environ
[
'PATH'
]
=
os
.
environ
[
'PATH'
]
+
';'
+
s
+
'
\\
cairo'
os
.
environ
[
'PATH'
]
=
os
.
environ
[
'PATH'
]
+
';'
+
default_os_path
print
(
os
.
environ
[
'PATH'
])
if
ctypes
.
util
.
find_library
(
'cairo'
)
is
not
None
:
if
ctypes
.
util
.
find_library
(
'cairo'
)
is
not
None
:
print
(
"cairo installed: OK"
)
else
:
try
:
...
...
@@ -69,7 +63,7 @@ else:
def
get_all_svg_files
(
directory
=
'./svg/'
):
ret
=
[]
for
f
in
os
.
listdir
(
directory
):
ret
.
append
(
directory
+
f
)
ret
.
append
(
directory
+
f
)
return
ret
...
...
tests/test_env_edit.py
View file @
43cf00ff
from
flatland.envs.rail_env
import
RailEnv
# from flatland.envs.agent_utils import EnvAgent
from
flatland.envs.agent_utils
import
EnvAgentStatic
from
flatland.envs.rail_env
import
RailEnv
def
test_load_env
():
...
...
@@ -11,5 +10,3 @@ def test_load_env():
agent_static
=
EnvAgentStatic
((
0
,
0
),
2
,
(
5
,
5
),
False
)
env
.
add_agent_static
(
agent_static
)
assert
env
.
get_num_agents
()
==
1
tests/test_env_observation_builder.py
View file @
43cf00ff
...
...
@@ -3,10 +3,10 @@
import
numpy
as
np
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
"""Tests for `flatland` package."""
...
...
@@ -70,7 +70,7 @@ def test_global_obs():
# env_renderer.renderEnv(show=True)
# global_obs.reset()
assert
(
global_obs
[
0
][
0
].
shape
==
rail_map
.
shape
+
(
16
,))
assert
(
global_obs
[
0
][
0
].
shape
==
rail_map
.
shape
+
(
16
,))
rail_map_recons
=
np
.
zeros_like
(
rail_map
)
for
i
in
range
(
global_obs
[
0
][
0
].
shape
[
0
]):
...
...
@@ -78,11 +78,11 @@ def test_global_obs():
rail_map_recons
[
i
,
j
]
=
int
(
''
.
join
(
global_obs
[
0
][
0
][
i
,
j
].
astype
(
int
).
astype
(
str
)),
2
)
assert
(
rail_map_recons
.
all
()
==
rail_map
.
all
())
assert
(
rail_map_recons
.
all
()
==
rail_map
.
all
())
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert
(
np
.
sum
(
rail_map
*
global_obs
[
0
][
1
][:,
:,
:
4
].
sum
(
2
))
>
0
)
assert
(
np
.
sum
(
rail_map
*
global_obs
[
0
][
1
][:,
:,
:
4
].
sum
(
2
))
>
0
)
def
main
():
...
...
tests/test_transitions.py
View file @
43cf00ff
...
...
@@ -2,10 +2,11 @@
# -*- coding: utf-8 -*-
"""Tests for `flatland` package."""
import
numpy
as
np
from
flatland.core.transitions
import
RailEnvTransitions
,
Grid8Transitions
# from flatland.envs.rail_env import validate_new_transition
from
flatland.envs.env_utils
import
validate_new_transition
import
numpy
as
np
def
test_is_valid_railenv_transitions
():
...
...
@@ -13,14 +14,14 @@ def test_is_valid_railenv_transitions():
transition_list
=
rail_env_trans
.
transitions
for
t
in
transition_list
:
assert
(
rail_env_trans
.
is_valid
(
t
)
is
True
)
assert
(
rail_env_trans
.
is_valid
(
t
)
is
True
)
for
i
in
range
(
3
):
rot_trans
=
rail_env_trans
.
rotate_transition
(
t
,
90
*
i
)
assert
(
rail_env_trans
.
is_valid
(
rot_trans
)
is
True
)
assert
(
rail_env_trans
.
is_valid
(
rot_trans
)
is
True
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1111111111110010'
,
2
))
is
False
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1001111111110010'
,
2
))
is
False
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1001111001110110'
,
2
))
is
False
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1111111111110010'
,
2
))
is
False
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1001111111110010'
,
2
))
is
False
)
assert
(
rail_env_trans
.
is_valid
(
int
(
'1001111001110110'
,
2
))
is
False
)
def
test_adding_new_valid_transition
():
...
...
@@ -28,32 +29,32 @@ def test_adding_new_valid_transition():
rail_array
=
np
.
zeros
(
shape
=
(
15
,
15
),
dtype
=
np
.
uint16
)
# adding straight
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
6
,
5
),
(
10
,
10
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
6
,
5
),
(
10
,
10
))
is
True
)
# adding valid right turn
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
4
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
4
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
# adding valid left turn
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
6
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
6
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
# adding invalid turn
rail_array
[(
5
,
5
)]
=
rail_trans
.
transitions
[
2
]
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
False
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
False
)
# should create #4 -> valid
rail_array
[(
5
,
5
)]
=
rail_trans
.
transitions
[
3
]
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
# adding invalid turn
rail_array
[(
5
,
5
)]
=
rail_trans
.
transitions
[
7
]
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
False
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
4
,
5
),
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
False
)
# test path start condition
rail_array
[(
5
,
5
)]
=
rail_trans
.
transitions
[
0
]
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
None
,
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
None
,
(
5
,
5
),
(
5
,
6
),
(
10
,
10
))
is
True
)
# test path end condition
rail_array
[(
5
,
5
)]
=
rail_trans
.
transitions
[
0
]
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
4
),
(
5
,
5
),
(
6
,
5
),
(
6
,
5
))
is
True
)
assert
(
validate_new_transition
(
rail_trans
,
rail_array
,
(
5
,
4
),
(
5
,
5
),
(
6
,
5
),
(
6
,
5
))
is
True
)
def
test_valid_railenv_transitions
():
...
...
@@ -65,48 +66,48 @@ def test_valid_railenv_transitions():
# 'W': 3}
for
i
in
range
(
2
):
assert
(
rail_env_trans
.
get_transitions
(
int
(
'1100110000110011'
,
2
),
i
)
==
(
1
,
1
,
0
,
0
))
assert
(
rail_env_trans
.
get_transitions
(
int
(
'1100110000110011'
,
2
),
2
+
i
)
==
(
0
,
0
,
1
,
1
))
assert
(
rail_env_trans
.
get_transitions
(
int
(
'1100110000110011'
,
2
),
i
)
==
(
1
,
1
,
0
,
0
))
assert
(
rail_env_trans
.
get_transitions
(
int
(
'1100110000110011'
,
2
),
2
+
i
)
==
(
0
,
0
,
1
,
1
))
no_transition_cell
=
int
(
'0000000000000000'
,
2
)
for
i
in
range
(
4
):
assert
(
rail_env_trans
.
get_transitions
(
no_transition_cell
,
i
)
==
(
0
,
0
,
0
,
0
))
assert
(
rail_env_trans
.
get_transitions
(
no_transition_cell
,
i
)
==
(
0
,
0
,
0
,
0
))
# Facing south, going south
north_south_transition
=
rail_env_trans
.
set_transitions
(
no_transition_cell
,
2
,
(
0
,
0
,
1
,
0
))
assert
(
rail_env_trans
.
set_transition
(
north_south_transition
,
2
,
2
,
0
)
==
no_transition_cell
)
assert
(
rail_env_trans
.
get_transition
(
north_south_transition
,
2
,
2
))
assert
(
rail_env_trans
.
set_transition
(
north_south_transition
,
2
,
2
,
0
)
==
no_transition_cell
)
assert
(
rail_env_trans
.
get_transition
(
north_south_transition
,
2
,
2
))
# Facing north, going east
south_east_transition
=
\
rail_env_trans
.
set_transition
(
no_transition_cell
,
0
,
1
,
1
)
assert
(
rail_env_trans
.
get_transition
(
south_east_transition
,
0
,
1
))
assert
(
rail_env_trans
.
get_transition
(
south_east_transition
,
0
,
1
))
# The opposite transitions are not feasible
assert
(
not
rail_env_trans
.
get_transition
(
north_south_transition
,
2
,
0
))
assert
(
not
rail_env_trans
.
get_transition
(
south_east_transition
,
2
,
1
))
assert
(
not
rail_env_trans
.
get_transition
(
north_south_transition
,
2
,
0
))
assert
(
not
rail_env_trans
.
get_transition
(
south_east_transition
,
2
,
1
))
east_west_transition
=
rail_env_trans
.
rotate_transition
(
north_south_transition
,
90
)
north_west_transition
=
rail_env_trans
.
rotate_transition
(
south_east_transition
,
180
)
# Facing west, going west
assert
(
rail_env_trans
.
get_transition
(
east_west_transition
,
3
,
3
))
assert
(
rail_env_trans
.
get_transition
(
east_west_transition
,
3
,
3
))
# Facing south, going west
assert
(
rail_env_trans
.
get_transition
(
north_west_transition
,
2
,
3
))
assert
(
rail_env_trans
.
get_transition
(
north_west_transition
,
2
,
3
))
assert
(
south_east_transition
==
rail_env_trans
.
rotate_transition
(
south_east_transition
,
360
))
assert
(
south_east_transition
==
rail_env_trans
.
rotate_transition
(
south_east_transition
,
360
))
def
test_diagonal_transitions
():
...
...
@@ -114,12 +115,12 @@ def test_diagonal_transitions():
# Facing north, going north-east
south_northeast_transition
=
int
(
'01000000'
+
'0'
*
8
*
7
,
2
)
assert
(
diagonal_trans_env
.
get_transitions
(
south_northeast_transition
,
0
)
==
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
))
assert
(
diagonal_trans_env
.
get_transitions
(
south_northeast_transition
,
0
)
==
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
))
# Allowing transition from north to southwest: Facing south, going SW
north_southwest_transition
=
\
diagonal_trans_env
.
set_transitions
(
int
(
'0'
*
64
,
2
),
4
,
(
0
,
0
,
0
,
0
,
0
,
1
,
0
,
0
))
assert
(
diagonal_trans_env
.
rotate_transition
(
south_northeast_transition
,
180
)
==
north_southwest_transition
)
assert
(
diagonal_trans_env
.
rotate_transition
(
south_northeast_transition
,
180
)
==
north_southwest_transition
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment