Skip to content
Snippets Groups Projects
Commit f436c714 authored by u214892's avatar u214892
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into...

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into 51-run-examples-in-ci-including-notebooks
parents 6e78cdc7 2de92fce
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,10 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth + 1):
size += pow4
pow4 *= 4
self.observation_space = [size * 6]
self.observation_dim = 7
self.observation_space = [size * self.observation_dim]
self.location_has_agent = {}
self.location_has_agent_direction = {}
def reset(self):
agents = self.env.agents
......@@ -181,8 +184,15 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent's target (when landing on the node following the corresponding
branch).
#6: agent direction
#6: agent in the same direction
1 = agent present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#7: agent in the opposite direction
1 = agent present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
Missing/padding nodes are filled in with -inf (truncated).
......@@ -195,13 +205,15 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions)
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], agent.direction]
observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
# Start from the current orientation, and see which transitions are available;
......@@ -212,7 +224,6 @@ class TreeObsForRailEnv(ObservationBuilder):
if num_transitions == 1:
orientation == np.argmax(possible_transitions)
# for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
......@@ -227,7 +238,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
self.env.dev_obs_dict[handle] = visited
return observation
......@@ -275,7 +286,6 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
# if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
if np.array_equal(position, self.env.agents[handle].target):
last_isTarget = True
break
......@@ -297,6 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if not last_isDeadEnd:
# Keep walking through the tree along `direction'
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
......@@ -321,36 +332,53 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction]
other_agent_same_direction,
other_agent_opposite_direction
]
"""
other_agent_same_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction = \
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
......@@ -359,7 +387,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
np.inf,
np.inf,
direction
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
......@@ -367,7 +396,8 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
direction
other_agent_same_direction,
other_agent_opposite_direction
]
# #############################
# #############################
......@@ -407,8 +437,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.max_depth - depth):
num_cells_to_fill_in += pow4
pow4 *= 4
observation = \
observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
return observation, visited
......
......@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.rail_env import RailEnvActions
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -41,12 +42,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
# 0: do nothing
# 1: turn left and move to the next cell
# 2: move to the next cell in front of the agent
# 3: turn right and move to the next cell
action_priorities = [2, 1, 3]
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth, 5))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment