Skip to content
Snippets Groups Projects
Commit e4c43e71 authored by Erik Nygren's avatar Erik Nygren
Browse files

Added new observation and prediction builders to deviate from standard implementation in flatland

parent d83a358c
No related branches found
No related tags found
No related merge requests found
......@@ -4,8 +4,8 @@ from collections import deque
import numpy as np
import torch
from flatland.envs.generators import rail_from_file, complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from observation_builders.observations import TreeObsForRailEnv
from predictors.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from importlib_resources import path
......@@ -17,7 +17,7 @@ from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
file_name = "./railway/testing_stuff.pkl"
file_name = "./railway/simple_avoid.pkl"
env = RailEnv(width=10,
height=20,
rail_generator=rail_from_file(file_name),
......@@ -94,7 +94,7 @@ for trials in range(1, n_trials + 1):
if record_images:
env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
frame_step += 1
time.sleep(1.5)
# time.sleep(1.5)
# Action
for a in range(env.get_num_agents()):
action = agent.act(agent_obs[a], eps=0)
......
This diff is collapsed.
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
ShortestPathPredictorForRailEnv object.
This object returns shortest-path predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, custom_args=None, handle=None):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args is not None
distance_map = custom_args.get('distance_map')
assert distance_map is not None
prediction_dict = {}
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
if not agent.moving:
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist or no_dist_found:
min_dist = target_dist
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
else:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
# prediction is ready
prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction))
self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
......@@ -97,7 +97,6 @@ def split_tree(tree, num_features_per_node, current_depth=0):
agent_data.extend(tmp_agent_data)
return tree_data, distance_data, agent_data
def normalize_observation(observation, num_features_per_node=9, observation_radius=0):
data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node,
current_depth=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment