Skip to content
Snippets Groups Projects
Commit 3cc3a0c6 authored by hagrid67's avatar hagrid67
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents 40b72b67 06c520f0
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,6 @@ import numpy as np
from collections import deque
# TODO: add docstrings, pylint, etc...
class ObservationBuilder:
"""
......@@ -127,53 +125,6 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
neighbors = []
for direction in range(4):
new_cell = self._new_position(position, (direction+2) % 4)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition
transitionValid = False
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
transitionValid = True
break
if not transitionValid:
continue
# Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
# If transition is found to invalid, check if perhaps it
# is a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agents and makes it step in the previous cell)
if not directionMatch:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
......@@ -263,7 +214,7 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: 1 if another agent is detected between the previous node and the current one.
#4:
#4: distance of agent to the current branch node
#5: minimum distance from node to the agent's target (when landing to the node following the corresponding
branch.
......@@ -286,6 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
root_observation = observation[:]
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
......@@ -293,7 +245,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
observation = observation + branch_observation
else:
num_cells_to_fill_in = 0
......@@ -305,7 +257,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation
def _explore_branch(self, handle, position, direction, depth):
def _explore_branch(self, handle, position, direction, root_observation, depth):
"""
Utility function to compute tree-based observations.
"""
......@@ -319,14 +271,14 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = True
last_isSwitch = False
last_isDeadEnd = False
last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent
# to land here
last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
last_isTarget = False
visited = set([position[0], position[1], direction])
visited = set()
other_agent_encountered = False
other_target_encountered = False
num_steps = 1
while exploring:
# #############################
# #############################
......@@ -345,6 +297,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if (position[0], position[1], direction) in visited:
last_isTerminal = True
break
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
......@@ -377,6 +330,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if cell_transitions[i]:
position = self._new_position(position, i)
direction = i
num_steps += 1
break
elif num_transitions > 0:
......@@ -386,11 +340,10 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
last_isTerminal = True
break
visited.add((position[0], position[1], direction))
# `position' is either a terminal node or a switch
observation = []
......@@ -403,25 +356,27 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
root_observation[3]+num_steps,
0]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
np.inf,
np.inf]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
root_observation[3]+num_steps,
self.distance_map[handle, position[0], position[1], direction]]
# #############################
# #############################
new_root_observation = observation[:]
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
......@@ -431,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder):
# it back
new_cell = self._new_position(position, (branch_direction+2) % 4)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1)
branch_observation = self._explore_branch(handle,
new_cell,
(branch_direction+2) % 4,
new_root_observation,
depth+1)
observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
branch_direction):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
branch_observation = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
depth+1)
observation = observation + branch_observation
else:
......
......@@ -486,6 +486,8 @@ class RailEnv(Environment):
for handle in self.agents_handles:
self.dones[handle] = False
# Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
# agent's orientations that allow a valid solution.
re_generate = True
while re_generate:
valid_positions = []
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.utils.rendertools import *
"""Tests for `flatland` package."""
......@@ -45,14 +46,14 @@ def test_global_obs():
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
......@@ -81,17 +82,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
test_global_obs()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment