Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
yoogottamk
Flatland
Commits
c07e281f
Commit
c07e281f
authored
5 years ago
by
gmollard
Browse files
Options
Downloads
Plain Diff
basic test work
parents
493b8a20
c794eb4e
No related branches found
Branches containing commit
No related tags found
Tags containing commit
Loading
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
flatland/core/env_observation_builder.py
+347
-6
347 additions, 6 deletions
flatland/core/env_observation_builder.py
tests/test_env_observation_builder.py
+38
-9
38 additions, 9 deletions
tests/test_env_observation_builder.py
with
385 additions
and
15 deletions
flatland/core/env_observation_builder.py
+
347
−
6
View file @
c07e281f
import
numpy
as
np
import
numpy
as
np
from
collections
import
deque
# TODO: add docstrings, pylint, etc...
# TODO: add docstrings, pylint, etc...
...
@@ -15,15 +17,131 @@ class ObservationBuilder:
...
@@ -15,15 +17,131 @@ class ObservationBuilder:
class
TreeObsForRailEnv
(
ObservationBuilder
):
class
TreeObsForRailEnv
(
ObservationBuilder
):
def
__init__
(
self
,
env
):
self
.
env
=
env
def
reset
(
self
):
def
reset
(
self
):
# TODO: precompute distances, etc...
self
.
distance_map
=
np
.
inf
*
np
.
ones
(
shape
=
(
self
.
env
.
number_of_agents
,
# raise NotImplementedError()
self
.
env
.
height
,
pass
self
.
env
.
width
))
self
.
max_dist
=
np
.
zeros
(
self
.
env
.
number_of_agents
)
for
i
in
range
(
self
.
env
.
number_of_agents
):
self
.
max_dist
[
i
]
=
self
.
_distance_map_walker
(
self
.
env
.
agents_target
[
i
],
i
)
def
_distance_map_walker
(
self
,
position
,
target_nr
):
# Returns max distance to target, from the farthest away node, while filling in distance_map
for
ori
in
range
(
4
):
self
.
distance_map
[
target_nr
,
position
[
0
],
position
[
1
]]
=
0
# Fill in the (up to) 4 neighboring nodes
# nodes_queue = [] # list of tuples (row, col, direction, distance);
# direction is the direction of movement, meaning that at least a possible orientation
# of an agent in cell (row,col) allows a movement in direction `direction'
nodes_queue
=
deque
(
self
.
_get_and_update_neighbors
(
position
,
target_nr
,
0
,
enforce_target_direction
=-
1
))
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited
=
set
([(
position
[
0
],
position
[
1
],
0
),
(
position
[
0
],
position
[
1
],
1
),
(
position
[
0
],
position
[
1
],
2
),
(
position
[
0
],
position
[
1
],
3
)])
max_distance
=
0
while
nodes_queue
:
node
=
nodes_queue
.
popleft
()
node_id
=
(
node
[
0
],
node
[
1
],
node
[
2
])
#print(node_id, visited, (node_id in visited))
#print(nodes_queue)
if
node_id
not
in
visited
:
visited
.
add
(
node_id
)
# From the list of possible neighbors that have at least a path to the
# current node, only keep those whose new orientation in the current cell
# would allow a transition to direction node[2]
valid_neighbors
=
self
.
_get_and_update_neighbors
(
(
node
[
0
],
node
[
1
]),
target_nr
,
node
[
3
],
node
[
2
])
for
n
in
valid_neighbors
:
nodes_queue
.
append
(
n
)
if
len
(
valid_neighbors
)
>
0
:
max_distance
=
max
(
max_distance
,
node
[
3
]
+
1
)
return
max_distance
def
_get_and_update_neighbors
(
self
,
position
,
target_nr
,
current_distance
,
enforce_target_direction
=-
1
):
neighbors
=
[]
for
direction
in
range
(
4
):
new_cell
=
self
.
_new_position
(
position
,
(
direction
+
2
)
%
4
)
if
new_cell
[
0
]
>=
0
and
new_cell
[
0
]
<
self
.
env
.
height
and
\
new_cell
[
1
]
>=
0
and
new_cell
[
1
]
<
self
.
env
.
width
:
# Check if the two cells are connected by a valid transition
transitionValid
=
False
for
orientation
in
range
(
4
):
moves
=
self
.
env
.
rail
.
get_transitions
((
new_cell
[
0
],
new_cell
[
1
],
orientation
))
if
moves
[
direction
]:
transitionValid
=
True
break
if
not
transitionValid
:
continue
# Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch
=
True
if
enforce_target_direction
>=
0
:
directionMatch
=
self
.
env
.
rail
.
get_transition
(
(
new_cell
[
0
],
new_cell
[
1
],
direction
),
enforce_target_direction
)
# If transition is found to invalid, check if perhaps it
# is a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agents and makes it step in the previous cell)
if
not
directionMatch
:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits
=
0
tmp
=
self
.
env
.
rail
.
get_transitions
((
new_cell
[
0
],
new_cell
[
1
]))
while
tmp
>
0
:
nbits
+=
(
tmp
&
1
)
tmp
=
tmp
>>
1
if
nbits
==
1
:
# Dead-end!
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch
=
directionMatch
or
self
.
env
.
rail
.
get_transition
(
(
new_cell
[
0
],
new_cell
[
1
],
(
direction
+
2
)
%
4
),
direction
)
if
transitionValid
and
directionMatch
:
new_distance
=
min
(
self
.
distance_map
[
target_nr
,
new_cell
[
0
],
new_cell
[
1
]],
current_distance
+
1
)
neighbors
.
append
((
new_cell
[
0
],
new_cell
[
1
],
direction
,
new_distance
))
self
.
distance_map
[
target_nr
,
new_cell
[
0
],
new_cell
[
1
]]
=
new_distance
return
neighbors
def
_new_position
(
self
,
position
,
movement
):
if
movement
==
0
:
# NORTH
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
# EAST
return
(
position
[
0
],
position
[
1
]
+
1
)
elif
movement
==
2
:
# SOUTH
return
(
position
[
0
]
+
1
,
position
[
1
])
elif
movement
==
3
:
# WEST
return
(
position
[
0
],
position
[
1
]
-
1
)
def
get
(
self
,
handle
):
def
get
(
self
,
handle
):
# TODO: compute the observation for agent `handle'
# TODO: compute the observation for agent `handle'
# raise NotImplementedError()
return
[]
return
[]
...
@@ -38,12 +156,235 @@ class GlobalObsForRailEnv(ObservationBuilder):
...
@@ -38,12 +156,235 @@ class GlobalObsForRailEnv(ObservationBuilder):
- Four 2D arrays containing respectively the position of the given agent,
- Four 2D arrays containing respectively the position of the given agent,
the position of its target, the positions of the other agents and of
the position of its target, the positions of the other agents and of
their target.
their target.
- A 4 elements array with one of encoding of the direction.
"""
"""
def
__init__
(
self
,
env
):
def
__init__
(
self
,
env
):
super
(
GlobalObsForRailEnv
,
self
).
__init__
(
env
)
super
(
GlobalObsForRailEnv
,
self
).
__init__
(
env
)
def
reset
(
self
):
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
,
self
.
env
.
width
,
16
))
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
,
self
.
env
.
width
,
16
))
for
i
in
range
(
self
.
rail_obs
.
shape
[
0
]):
for
i
in
range
(
self
.
rail_obs
.
shape
[
0
]):
for
j
in
range
(
self
.
rail_obs
.
shape
[
1
]):
for
j
in
range
(
self
.
rail_obs
.
shape
[
1
]):
self
.
rail_obs
[
i
,
j
]
=
self
.
env
.
rail
.
get_transitions
((
i
,
j
))
self
.
rail_obs
[
i
,
j
]
=
np
.
array
(
list
(
f
'
{
self
.
env
.
rail
.
get_transitions
((
i
,
j
))
:
016
b
}
'
)).
astype
(
int
)
# self.targets = np.zeros(self.env.height, self.env.width)
# for target_pos in self.env.agents_target:
# self.targets[target_pos] += 1
def
get
(
self
,
handle
):
obs_agents_targets_pos
=
np
.
zeros
((
4
,
self
.
env
.
height
,
self
.
env
.
width
))
agent_pos
=
self
.
env
.
agents_position
[
handle
]
obs_agents_targets_pos
[
0
][
agent_pos
]
+=
1
for
i
in
range
(
len
(
self
.
env
.
agents_position
)):
if
i
!=
handle
:
obs_agents_targets_pos
[
3
][
self
.
env
.
agents_position
[
i
]]
+=
1
agent_target_pos
=
self
.
env
.
agents_target
[
handle
]
obs_agents_targets_pos
[
1
][
agent_target_pos
]
+=
1
for
i
in
range
(
len
(
self
.
env
.
agents_target
)):
if
i
!=
handle
:
obs_agents_targets_pos
[
2
][
self
.
env
.
agents_target
[
i
]]
+=
1
direction
=
np
.
zeros
(
4
)
direction
[
self
.
env
.
agents_direction
[
handle
]]
=
1
return
self
.
rail_obs
,
obs_agents_targets_pos
,
direction
"""
def get_observation(self, agent):
# Get the current observation for an agent
current_position = self.internal_position[agent]
#target_heading = self._compass(agent).tolist()
coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
agent_distance = self.distance_map[agent][coordinate][0]
# Start tree search
if current_position == self.target[agent]:
agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
else:
agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])
initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
self._tree_search(initial_tree_state, agent_tree, agent)
observation = []
distance_data = []
self._flatten_tree(agent_tree, observation, distance_data, self.max_depth+1)
# This is probably very slow!!!!
#max_obs = np.max([i for i in observation if i < np.inf])
#if max_obs != 0:
# observation = np.array(observation)/ max_obs
#print([i for i in distance_data if i >= 0])
observation = np.concatenate((observation, distance_data))
#observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
#return np.clip(observation / self.max_dist[agent], -1, 1)
return np.clip(observation / 15., -1, 1)
def _tree_search(self, in_tree_state, parent_node, agent):
if in_tree_state.depth >= self.max_depth:
return
target_distance = np.inf
other_target = np.inf
other_agent = np.inf
coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
curr_target_dist = self.distance_map[agent][coordinate][0]
forbidden_action = (in_tree_state.direction + 2) % 4
# Update the position
failed_move = 0
leaf_distance = in_tree_state.distance
for child_idx in range(4):
if child_idx != forbidden_action or in_tree_state.direction == -1:
tree_state = copy.deepcopy(in_tree_state)
tree_state.direction = child_idx
current_position, invalid_move = self._detect_path(
tree_state.position, tree_state.direction)
if tree_state.initial_direction == None:
tree_state.initial_direction = child_idx
if not invalid_move:
coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
curr_target_dist = self.distance_map[agent][coordinate][0]
#if tree_state.initial_direction == None:
# tree_state.initial_direction = child_idx
tree_state.position = current_position
tree_state.distance += 1
# Collect information at the current position
detection_distance = tree_state.distance
if current_position == self.target[tree_state.agent]:
target_distance = detection_distance
elif current_position in self.target:
other_target = detection_distance
if current_position in self.internal_position:
other_agent = detection_distance
tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
tree_state.data[3] = tree_state.distance
tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
if self._switch_detection(tree_state.position):
tree_state.depth += 1
new_tree_state = copy.deepcopy(tree_state)
new_node = parent_node.insert(tree_state.position,
tree_state.data, tree_state.initial_direction)
new_tree_state.initial_direction = None
new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
self._tree_search(new_tree_state, new_node, agent)
else:
self._tree_search(tree_state, parent_node, agent)
else:
failed_move += 1
if failed_move == 3 and in_tree_state.direction != -1:
tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
return
return
def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
if depth <= 0:
return
if node != None:
observation_vector.extend(node.data[:-1])
distance_sensor.extend([node.data[-1]])
else:
observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
distance_sensor.extend([-np.inf])
for child_idx in range(4):
if node != None:
child = node.children[child_idx]
else:
child = None
self._flatten_tree(child, observation_vector, distance_sensor, depth -1)
def _switch_detection(self, position):
# Hack to detect switches
# This can later directly be derived from the transition matrix
paths = 0
for i in range(4):
_, invalid_move = self._detect_path(position, i)
if not invalid_move:
paths +=1
if paths >= 3:
return True
return False
def _min_greater_zero(self, x, y):
if x <= 0 and y <= 0:
return 0
if x < 0:
return y
if y < 0:
return x
return min(x, y)
"""
class
Tree_State
:
"""
Keep track of the current state while building the tree
"""
def
__init__
(
self
,
agent
,
position
,
direction
,
depth
,
distance
):
self
.
agent
=
agent
self
.
position
=
position
self
.
direction
=
direction
self
.
depth
=
depth
self
.
initial_direction
=
None
self
.
distance
=
distance
self
.
data
=
[
np
.
inf
,
np
.
inf
,
np
.
inf
,
np
.
inf
,
np
.
inf
]
class
Node
():
"""
Define a tree node to get populated during search
"""
def
__init__
(
self
,
position
,
data
):
self
.
n_children
=
4
self
.
children
=
[
None
]
*
self
.
n_children
self
.
data
=
data
self
.
position
=
position
def
insert
(
self
,
position
,
data
,
child_idx
):
"""
Insert new node with data
@param data node data object to insert
"""
new_node
=
Node
(
position
,
data
)
self
.
children
[
child_idx
]
=
new_node
return
new_node
def
print_tree
(
self
,
i
=
0
,
depth
=
0
):
"""
Print tree content inorder
"""
current_i
=
i
curr_depth
=
depth
+
1
if
i
<
self
.
n_children
:
if
self
.
children
[
i
]
!=
None
:
self
.
children
[
i
].
print_tree
(
depth
=
curr_depth
)
current_i
+=
1
if
self
.
children
[
i
]
!=
None
:
self
.
children
[
i
].
print_tree
(
i
,
depth
=
curr_depth
)
This diff is collapsed.
Click to expand it.
tests/test_env_observation_builder.py
+
38
−
9
View file @
c07e281f
...
@@ -2,8 +2,11 @@
...
@@ -2,8 +2,11 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
flatland.core.env_observation_builder
import
GlobalObsForRailEnv
from
flatland.core.env_observation_builder
import
GlobalObsForRailEnv
from
flatland.core.transitions
import
Grid4Transitions
# from flatland.core.transitions import Grid4Transitions
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.core.env
import
RailEnv
import
numpy
as
np
import
numpy
as
np
from
flatland.utils.rendertools
import
*
"""
Tests for `flatland` package.
"""
"""
Tests for `flatland` package.
"""
...
@@ -43,18 +46,44 @@ def test_global_obs():
...
@@ -43,18 +46,44 @@ def test_global_obs():
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_south_horizontal_straight
,
180
)
double_switch_south_horizontal_straight
,
180
)
rail_map
=
np
.
array
(
rail_map
=
np
.
array
(
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
horizontal_straight
]
*
3
+
[
double_switch_north_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_south_horizontal_straight
]
+
[
horizontal_straight
]
*
3
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
],
dtype
=
np
.
uint16
)
[[
dead_end_from_east
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_north_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_south_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
dead_end_from_west
]]
+
[[
empty
]
*
6
+
[
vertical_straight
]
+
[
empty
]
*
3
]
*
2
+
[[
empty
]
*
6
+
[
dead_end_from_north
]
+
[
empty
]
*
3
],
dtype
=
np
.
uint16
)
rail
=
GridTransitionMap
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
transitions
=
transitions
)
rail
.
grid
=
rail_map
env
=
RailEnv
(
rail
,
number_of_agents
=
1
)
env
.
reset
()
# env_renderer = RenderTool(env)
# env_renderer.renderEnv(show=True)
global_obs
=
GlobalObsForRailEnv
(
env
)
global_obs
.
reset
()
assert
(
global_obs
.
rail_obs
.
shape
==
rail_map
.
shape
+
(
16
,))
rail_map_recons
=
np
.
zeros_like
(
rail_map
)
for
i
in
range
(
global_obs
.
rail_obs
.
shape
[
0
]):
for
j
in
range
(
global_obs
.
rail_obs
.
shape
[
1
]):
rail_map_recons
[
i
,
j
]
=
int
(
''
.
join
(
global_obs
.
rail_obs
[
i
,
j
].
astype
(
int
).
astype
(
str
)),
2
)
assert
(
rail_map_recons
.
all
()
==
rail_map
.
all
())
obs
=
global_obs
.
get
(
0
)
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert
(
np
.
sum
(
rail_map
*
obs
[
1
][
0
])
>
0
)
print
(
rail_map
.
shape
)
test_global_obs
()
test_global_obs
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment