Commit eabdd117 authored by Erik Nygren's avatar Erik Nygren
Browse files

fixed prediction bug when agent is off grid

parent c1edfb03
Pipeline #2353 failed with stages
in 11 minutes and 6 seconds
......@@ -30,14 +30,14 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(max_num_cities=4,
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(max_num_cities=3,
# Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_in_city=4,
max_rails_in_city=2,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=10,
......
......@@ -2,8 +2,6 @@ from typing import Tuple, Callable, List, Type
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
......@@ -286,7 +284,11 @@ def coordinate_to_position(depth, coords):
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
position[idx] = int(t[1] * depth + t[0])
# Set None type coordinates off the grid
if np.isnan(t[0]):
position[idx] = -1
else:
position[idx] = int(t[1] * depth + t[0])
idx += 1
return position
......
......@@ -162,6 +162,7 @@ class TreeObsForRailEnv(ObservationBuilder):
In case the target node is reached, the values are [0, 0, 0, 0, 0].
"""
# Update local lookup table for all agents' positions
# ignore other agents not in the grid (only status active and done)
self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if
......
......@@ -132,9 +132,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
agent_virtual_position = agent.target
else:
prediction_dict[agent.handle] = None
prediction = np.zeros(shape=(self.max_depth + 1, 5))
for i in range(self.max_depth):
prediction[i] = [i, None, None, None, None]
prediction_dict[agent.handle] = prediction
continue
agent_virtual_direction = agent.direction
......
......@@ -557,7 +557,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
cell_vector_field = np.zeros(shape=(height, width), dtype=int) - 1
city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1
min_nr_rails_in_city = 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment