Skip to content
Snippets Groups Projects
Commit 4c8f4d40 authored by u214892's avatar u214892
Browse files

update baselines to master of flatland

parent 9a7e2fc1
No related branches found
No related tags found
1 merge request!6Update baselines
......@@ -12,4 +12,4 @@ recursive-include tests *
recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
recursive-include docs *.rst *.md conf.py Makefile make.bat *.jpg *.png *.gif
......@@ -7,7 +7,7 @@ from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
......@@ -86,10 +86,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Fill in the (up to) 4 neighboring nodes
# direction is the direction of movement, meaning that at least a possible orientation of an agent
# in cell (row,col) allows a movement in direction `direction'
# in cell (row,col) allows a movement in direction `direction`
nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
# BFS from target `position' to all the reachable nodes in the grid
# BFS from target `position` to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
......@@ -125,12 +125,12 @@ class TreeObsForRailEnv(ObservationBuilder):
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# The agent must land into the current cell with orientation `enforce_target_direction`.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction + 2) % 4]
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
new_cell = get_new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
......@@ -138,7 +138,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
# Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
......@@ -156,23 +156,10 @@ class TreeObsForRailEnv(ObservationBuilder):
return neighbors
def _new_position(self, position, movement):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_many(self, handles=None):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
"""
if handles is None:
......@@ -200,7 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder):
def get(self, handle):
"""
Computes the current observation for agent `handle' in env
Computes the current observation for agent `handle` in env
The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
......@@ -280,7 +267,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
new_cell = get_new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
......@@ -428,11 +415,11 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_dead_end = True
if not last_is_dead_end:
# Keep walking through the tree along `direction'
# Keep walking through the tree along `direction`
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
position = get_new_position(position, direction)
num_steps += 1
tot_dist += 1
elif num_transitions > 0:
......@@ -447,7 +434,7 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal = True
break
# `position' is either a terminal node or a switch
# `position` is either a terminal node or a switch
# #############################
# #############################
......@@ -499,7 +486,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
new_cell = get_new_position(position, (branch_direction + 2) % 4)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
......@@ -509,7 +496,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
elif last_is_switch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
new_cell = get_new_position(position, branch_direction)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
......
......@@ -8,6 +8,7 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
ShortestPathPredictorForRailEnv object.
......@@ -25,10 +26,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
Requires distance_map to extract the shortest path.
Parameters
-------
----------
custom_args: dict
- distance_map : dict
handle : int (optional)
handle : int, optional
Handle of the agent for which to compute the observation vector.
Returns
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment