Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
sfwatergit
Flatland
Commits
482a73e5
Commit
482a73e5
authored
6 years ago
by
Erik Nygren
Browse files
Options
Downloads
Patches
Plain Diff
updated curve calculation
parent
2f7ee5c1
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
flatland/envs/rail_env.py
+39
-34
39 additions, 34 deletions
flatland/envs/rail_env.py
with
39 additions
and
34 deletions
flatland/envs/rail_env.py
+
39
−
34
View file @
482a73e5
...
@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
...
@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
node_pos
=
(
current_node
.
pos
[
0
]
+
new_pos
[
0
],
current_node
.
pos
[
1
]
+
new_pos
[
1
])
node_pos
=
(
current_node
.
pos
[
0
]
+
new_pos
[
0
],
current_node
.
pos
[
1
]
+
new_pos
[
1
])
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
\
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
\
node_pos
[
0
]
<
0
or
\
node_pos
[
0
]
<
0
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
<
0
:
node_pos
[
1
]
<
0
:
continue
continue
# validate positions
# validate positions
...
@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
...
@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
end_pos
=
path
[
-
1
]
end_pos
=
path
[
-
1
]
for
index
in
range
(
len
(
path
)
-
1
):
for
index
in
range
(
len
(
path
)
-
1
):
current_pos
=
path
[
index
]
current_pos
=
path
[
index
]
new_pos
=
path
[
index
+
1
]
new_pos
=
path
[
index
+
1
]
new_dir
=
get_direction
(
current_pos
,
new_pos
)
new_dir
=
get_direction
(
current_pos
,
new_pos
)
new_trans
=
rail_array
[
current_pos
]
new_trans
=
rail_array
[
current_pos
]
...
@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
...
@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# print("too close:", dist, sg_new[i], sg[j])
# print("too close:", dist, sg_new[i], sg[j])
return
False
return
False
return
True
return
True
if
check_all_dist
(
sg_new
):
if
check_all_dist
(
sg_new
):
break
break
start_goal
.
append
([
start
,
goal
])
start_goal
.
append
([
start
,
goal
])
...
@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
...
@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
Generator function that always returns a GridTransitionMap object with
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
the matrix of correct 16-bit bitmaps for each cell.
"""
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
t_utils
=
RailEnvTransitions
()
...
@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
...
@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
function
function
Generator function that always returns the given `rail_map
'
object.
Generator function that always returns the given `rail_map
'
object.
"""
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
def
generator
(
width
,
height
,
num_resets
=
0
):
return
rail_map
return
rail_map
...
@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
...
@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
function
function
Generator function that always returns the given `rail_map
'
object.
Generator function that always returns the given `rail_map
'
object.
"""
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
t_utils
=
RailEnvTransitions
()
rail_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
rail_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
...
@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
# add all rotations
# add all rotations
for
rot
in
[
0
,
90
,
180
,
270
]:
for
rot
in
[
0
,
90
,
180
,
270
]:
transitions_templates_
.
append
((
template
,
transitions_templates_
.
append
((
template
,
t_utils
.
rotate_transition
(
t_utils
.
rotate_transition
(
t_utils
.
transitions
[
i
],
t_utils
.
transitions
[
i
],
rot
)))
rot
)))
transition_probabilities
.
append
(
transition_probability
[
i
])
transition_probabilities
.
append
(
transition_probability
[
i
])
template
=
[
template
[
-
1
]]
+
template
[:
-
1
]
template
=
[
template
[
-
1
]]
+
template
[:
-
1
]
...
@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
is_match
=
True
is_match
=
True
for
j
in
range
(
4
):
for
j
in
range
(
4
):
if
template
[
j
]
>=
0
and
\
if
template
[
j
]
>=
0
and
\
template
[
j
]
!=
transitions_templates_
[
i
][
0
][
j
]:
template
[
j
]
!=
transitions_templates_
[
i
][
0
][
j
]:
is_match
=
False
is_match
=
False
break
break
if
is_match
:
if
is_match
:
...
@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
r
][
1
]
neigh_trans
=
rail
[
r
][
1
]
if
neigh_trans
is
not
None
:
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
1
)
if
max_bit
:
if
max_bit
:
rail
[
r
][
0
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
270
)
rail
[
r
][
0
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
270
)
...
@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
r
][
-
2
]
neigh_trans
=
rail
[
r
][
-
2
]
if
neigh_trans
is
not
None
:
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
2
))
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
2
))
if
max_bit
:
if
max_bit
:
rail
[
r
][
-
1
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
rail
[
r
][
-
1
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
...
@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
1
][
c
]
neigh_trans
=
rail
[
1
][
c
]
if
neigh_trans
is
not
None
:
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
3
))
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
3
))
if
max_bit
:
if
max_bit
:
rail
[
0
][
c
]
=
int
(
'
0010000000000000
'
,
2
)
rail
[
0
][
c
]
=
int
(
'
0010000000000000
'
,
2
)
...
@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
...
@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
-
2
][
c
]
neigh_trans
=
rail
[
-
2
][
c
]
if
neigh_trans
is
not
None
:
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
1
))
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
1
))
if
max_bit
:
if
max_bit
:
rail
[
-
1
][
c
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
180
)
rail
[
-
1
][
c
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
180
)
...
@@ -840,9 +844,9 @@ class RailEnv(Environment):
...
@@ -840,9 +844,9 @@ class RailEnv(Environment):
def
check_agent_lists
(
self
):
def
check_agent_lists
(
self
):
for
lAgents
,
name
in
zip
(
for
lAgents
,
name
in
zip
(
[
self
.
agents_handles
,
self
.
agents_position
,
self
.
agents_direction
],
[
self
.
agents_handles
,
self
.
agents_position
,
self
.
agents_direction
],
[
"
handles
"
,
"
positions
"
,
"
directions
"
]):
[
"
handles
"
,
"
positions
"
,
"
directions
"
]):
assert
self
.
number_of_agents
==
len
(
lAgents
),
"
Inconsistent agent list:
"
+
name
assert
self
.
number_of_agents
==
len
(
lAgents
),
"
Inconsistent agent list:
"
+
name
def
check_agent_locdirpath
(
self
,
iAgent
):
def
check_agent_locdirpath
(
self
,
iAgent
):
valid_movements
=
[]
valid_movements
=
[]
...
@@ -857,7 +861,7 @@ class RailEnv(Environment):
...
@@ -857,7 +861,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
self
.
agents_position
[
iAgent
],
m
[
1
])
new_position
=
self
.
_new_position
(
self
.
agents_position
[
iAgent
],
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
iAgent
]):
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
iAgent
]):
valid_starting_directions
.
append
(
m
[
0
])
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
if
len
(
valid_starting_directions
)
==
0
:
...
@@ -876,7 +880,7 @@ class RailEnv(Environment):
...
@@ -876,7 +880,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
rcPos
,
m
[
1
])
new_position
=
self
.
_new_position
(
rcPos
,
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
rcTarget
):
self
.
_path_exists
(
new_position
,
m
[
0
],
rcTarget
):
valid_starting_directions
.
append
(
m
[
0
])
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
if
len
(
valid_starting_directions
)
==
0
:
...
@@ -891,7 +895,7 @@ class RailEnv(Environment):
...
@@ -891,7 +895,7 @@ class RailEnv(Environment):
rcPos
=
np
.
random
.
choice
(
len
(
self
.
valid_positions
))
rcPos
=
np
.
random
.
choice
(
len
(
self
.
valid_positions
))
iAgent
=
self
.
number_of_agents
iAgent
=
self
.
number_of_agents
self
.
agents_position
.
append
(
tuple
(
rcPos
))
# ensure it's a tuple not a list
self
.
agents_position
.
append
(
tuple
(
rcPos
))
# ensure it's a tuple not a list
self
.
agents_handles
.
append
(
max
(
self
.
agents_handles
+
[
-
1
])
+
1
)
# max(handles) + 1, starting at 0
self
.
agents_handles
.
append
(
max
(
self
.
agents_handles
+
[
-
1
])
+
1
)
# max(handles) + 1, starting at 0
...
@@ -902,7 +906,7 @@ class RailEnv(Environment):
...
@@ -902,7 +906,7 @@ class RailEnv(Environment):
self
.
number_of_agents
+=
1
self
.
number_of_agents
+=
1
self
.
check_agent_lists
()
self
.
check_agent_lists
()
return
iAgent
return
iAgent
def
reset
(
self
,
regen_rail
=
True
,
replace_agents
=
True
):
def
reset
(
self
,
regen_rail
=
True
,
replace_agents
=
True
):
if
regen_rail
or
self
.
rail
is
None
:
if
regen_rail
or
self
.
rail
is
None
:
# TODO: Import not only rail information but also start and goal positions
# TODO: Import not only rail information but also start and goal positions
...
@@ -961,7 +965,7 @@ class RailEnv(Environment):
...
@@ -961,7 +965,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
self
.
agents_position
[
i
],
m
[
1
])
new_position
=
self
.
_new_position
(
self
.
agents_position
[
i
],
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
i
]):
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
i
]):
valid_starting_directions
.
append
(
m
[
0
])
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
if
len
(
valid_starting_directions
)
==
0
:
...
@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
...
@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
pos
=
self
.
agents_position
[
i
]
pos
=
self
.
agents_position
[
i
]
direction
=
self
.
agents_direction
[
i
]
direction
=
self
.
agents_direction
[
i
]
# compute number of possible transitions in the current
# cell used to check for invalid actions
nbits
=
0
tmp
=
self
.
rail
.
get_transitions
((
pos
[
0
],
pos
[
1
]))
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
movement
=
direction
movement
=
direction
if
action
==
1
:
if
action
==
1
:
movement
=
direction
-
1
movement
=
direction
-
1
...
@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
...
@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
is_deadend
=
False
is_deadend
=
False
if
action
==
2
:
if
action
==
2
:
# compute number of possible transitions in the current
# cell
nbits
=
0
tmp
=
self
.
rail
.
get_transitions
((
pos
[
0
],
pos
[
1
]))
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
if
nbits
==
1
:
if
nbits
==
1
:
# dead-end; assuming the rail network is consistent,
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
# this should match the direction the agent has come
...
@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
...
@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
# Is it a legal move? 1) transition allows the movement in the
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
# free, i.e., no agent is currently in that cell
if
new_position
[
1
]
>=
self
.
width
or
\
if
new_position
[
1
]
>=
self
.
width
or
\
new_position
[
0
]
>=
self
.
height
or
\
new_position
[
0
]
>=
self
.
height
or
\
new_position
[
0
]
<
0
or
new_position
[
1
]
<
0
:
new_position
[
0
]
<
0
or
new_position
[
1
]
<
0
:
new_cell_isValid
=
False
new_cell_isValid
=
False
elif
self
.
rail
.
get_transitions
((
new_position
[
0
],
new_position
[
1
]))
>
0
:
elif
self
.
rail
.
get_transitions
((
new_position
[
0
],
new_position
[
1
]))
>
0
:
...
@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
...
@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
# if agent is not in target position, add step penalty
# if agent is not in target position, add step penalty
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
dones
[
handle
]
=
True
self
.
dones
[
handle
]
=
True
else
:
else
:
self
.
rewards_dict
[
handle
]
+=
step_penalty
self
.
rewards_dict
[
handle
]
+=
step_penalty
...
@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
...
@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
num_agents_in_target_position
=
0
num_agents_in_target_position
=
0
for
i
in
range
(
self
.
number_of_agents
):
for
i
in
range
(
self
.
number_of_agents
):
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
num_agents_in_target_position
+=
1
num_agents_in_target_position
+=
1
if
num_agents_in_target_position
==
self
.
number_of_agents
:
if
num_agents_in_target_position
==
self
.
number_of_agents
:
...
@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
...
@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
def
_new_position
(
self
,
position
,
movement
):
def
_new_position
(
self
,
position
,
movement
):
if
movement
==
0
:
# NORTH
if
movement
==
0
:
# NORTH
return
(
position
[
0
]
-
1
,
position
[
1
])
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
# EAST
elif
movement
==
1
:
# EAST
return
(
position
[
0
],
position
[
1
]
+
1
)
return
(
position
[
0
],
position
[
1
]
+
1
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment