Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
482a73e5
Commit
482a73e5
authored
May 01, 2019
by
Erik Nygren
Browse files
updated curve calculation
parent
2f7ee5c1
Pipeline
#458
failed with stage
in 1 minute and 47 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/envs/rail_env.py
View file @
482a73e5
...
...
@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
node_pos
=
(
current_node
.
pos
[
0
]
+
new_pos
[
0
],
current_node
.
pos
[
1
]
+
new_pos
[
1
])
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
\
node_pos
[
0
]
<
0
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
<
0
:
node_pos
[
0
]
<
0
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
<
0
:
continue
# validate positions
...
...
@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
end_pos
=
path
[
-
1
]
for
index
in
range
(
len
(
path
)
-
1
):
current_pos
=
path
[
index
]
new_pos
=
path
[
index
+
1
]
new_pos
=
path
[
index
+
1
]
new_dir
=
get_direction
(
current_pos
,
new_pos
)
new_trans
=
rail_array
[
current_pos
]
...
...
@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# print("too close:", dist, sg_new[i], sg[j])
return
False
return
True
if
check_all_dist
(
sg_new
):
break
start_goal
.
append
([
start
,
goal
])
...
...
@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
...
...
@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
function
Generator function that always returns the given `rail_map' object.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
return
rail_map
...
...
@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
function
Generator function that always returns the given `rail_map' object.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
rail_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
...
...
@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
# add all rotations
for
rot
in
[
0
,
90
,
180
,
270
]:
transitions_templates_
.
append
((
template
,
t_utils
.
rotate_transition
(
t_utils
.
transitions
[
i
],
rot
)))
t_utils
.
rotate_transition
(
t_utils
.
transitions
[
i
],
rot
)))
transition_probabilities
.
append
(
transition_probability
[
i
])
template
=
[
template
[
-
1
]]
+
template
[:
-
1
]
...
...
@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
is_match
=
True
for
j
in
range
(
4
):
if
template
[
j
]
>=
0
and
\
template
[
j
]
!=
transitions_templates_
[
i
][
0
][
j
]:
template
[
j
]
!=
transitions_templates_
[
i
][
0
][
j
]:
is_match
=
False
break
if
is_match
:
...
...
@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
r
][
1
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
1
)
if
max_bit
:
rail
[
r
][
0
]
=
t_utils
.
rotate_transition
(
int
(
'0010000000000000'
,
2
),
270
)
...
...
@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
r
][
-
2
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
2
))
if
max_bit
:
rail
[
r
][
-
1
]
=
t_utils
.
rotate_transition
(
int
(
'0010000000000000'
,
2
),
...
...
@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
1
][
c
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
3
))
if
max_bit
:
rail
[
0
][
c
]
=
int
(
'0010000000000000'
,
2
)
...
...
@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
neigh_trans
=
rail
[
-
2
][
c
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
1
))
if
max_bit
:
rail
[
-
1
][
c
]
=
t_utils
.
rotate_transition
(
int
(
'0010000000000000'
,
2
),
180
)
...
...
@@ -840,9 +844,9 @@ class RailEnv(Environment):
def
check_agent_lists
(
self
):
for
lAgents
,
name
in
zip
(
[
self
.
agents_handles
,
self
.
agents_position
,
self
.
agents_direction
],
[
"handles"
,
"positions"
,
"directions"
]):
assert
self
.
number_of_agents
==
len
(
lAgents
),
"Inconsistent agent list:"
+
name
[
self
.
agents_handles
,
self
.
agents_position
,
self
.
agents_direction
],
[
"handles"
,
"positions"
,
"directions"
]):
assert
self
.
number_of_agents
==
len
(
lAgents
),
"Inconsistent agent list:"
+
name
def
check_agent_locdirpath
(
self
,
iAgent
):
valid_movements
=
[]
...
...
@@ -857,7 +861,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
self
.
agents_position
[
iAgent
],
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
iAgent
]):
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
iAgent
]):
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
...
...
@@ -876,7 +880,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
rcPos
,
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
rcTarget
):
self
.
_path_exists
(
new_position
,
m
[
0
],
rcTarget
):
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
...
...
@@ -891,7 +895,7 @@ class RailEnv(Environment):
rcPos
=
np
.
random
.
choice
(
len
(
self
.
valid_positions
))
iAgent
=
self
.
number_of_agents
self
.
agents_position
.
append
(
tuple
(
rcPos
))
# ensure it's a tuple not a list
self
.
agents_handles
.
append
(
max
(
self
.
agents_handles
+
[
-
1
])
+
1
)
# max(handles) + 1, starting at 0
...
...
@@ -902,7 +906,7 @@ class RailEnv(Environment):
self
.
number_of_agents
+=
1
self
.
check_agent_lists
()
return
iAgent
def
reset
(
self
,
regen_rail
=
True
,
replace_agents
=
True
):
if
regen_rail
or
self
.
rail
is
None
:
# TODO: Import not only rail information but also start and goal positions
...
...
@@ -961,7 +965,7 @@ class RailEnv(Environment):
for
m
in
valid_movements
:
new_position
=
self
.
_new_position
(
self
.
agents_position
[
i
],
m
[
1
])
if
m
[
0
]
not
in
valid_starting_directions
and
\
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
i
]):
self
.
_path_exists
(
new_position
,
m
[
0
],
self
.
agents_target
[
i
]):
valid_starting_directions
.
append
(
m
[
0
])
if
len
(
valid_starting_directions
)
==
0
:
...
...
@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
pos
=
self
.
agents_position
[
i
]
direction
=
self
.
agents_direction
[
i
]
# compute number of possible transitions in the current
# cell used to check for invalid actions
nbits
=
0
tmp
=
self
.
rail
.
get_transitions
((
pos
[
0
],
pos
[
1
]))
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
movement
=
direction
if
action
==
1
:
movement
=
direction
-
1
...
...
@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
is_deadend
=
False
if
action
==
2
:
# compute number of possible transitions in the current
# cell
nbits
=
0
tmp
=
self
.
rail
.
get_transitions
((
pos
[
0
],
pos
[
1
]))
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
if
nbits
==
1
:
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
...
...
@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if
new_position
[
1
]
>=
self
.
width
or
\
new_position
[
0
]
>=
self
.
height
or
\
new_position
[
0
]
<
0
or
new_position
[
1
]
<
0
:
if
new_position
[
1
]
>=
self
.
width
or
\
new_position
[
0
]
>=
self
.
height
or
\
new_position
[
0
]
<
0
or
new_position
[
1
]
<
0
:
new_cell_isValid
=
False
elif
self
.
rail
.
get_transitions
((
new_position
[
0
],
new_position
[
1
]))
>
0
:
...
...
@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
# if agent is not in target position, add step penalty
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
dones
[
handle
]
=
True
else
:
self
.
rewards_dict
[
handle
]
+=
step_penalty
...
...
@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
num_agents_in_target_position
=
0
for
i
in
range
(
self
.
number_of_agents
):
if
self
.
agents_position
[
i
][
0
]
==
self
.
agents_target
[
i
][
0
]
and
\
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
self
.
agents_position
[
i
][
1
]
==
self
.
agents_target
[
i
][
1
]:
num_agents_in_target_position
+=
1
if
num_agents_in_target_position
==
self
.
number_of_agents
:
...
...
@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
def
_new_position
(
self
,
position
,
movement
):
if
movement
==
0
:
# NORTH
if
movement
==
0
:
# NORTH
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
# EAST
return
(
position
[
0
],
position
[
1
]
+
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment