Skip to content
Snippets Groups Projects
Commit 6e2c1c9d authored by Erik Nygren's avatar Erik Nygren
Browse files

updating local obs for rail env to be better suited for the task

parent c9c5e411
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
......@@ -12,12 +13,14 @@ np.random.seed(1)
#
# Observation builders: a tree observation with a shortest-path predictor, and
# the local-grid observation that the environment below actually uses.
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
# BUGFIX: the diff left two `obs_builder_object=` keyword arguments (a
# duplicate-keyword syntax error); keep the new one, LocalGridObs.
env = RailEnv(width=20,
              height=20,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8,
                                                    max_dist=99999, seed=0),
              obs_builder_object=LocalGridObs,
              number_of_agents=2)
env_renderer = RenderTool(env, gl="PILSVG", )
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here
......@@ -66,6 +69,7 @@ for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs = env.reset()
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
......@@ -80,6 +84,8 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
......@@ -74,6 +74,7 @@ class ObservationBuilder:
direction[agent.direction] = 1
return direction
class DummyObservationBuilder(ObservationBuilder):
"""
DummyObservationBuilder class which returns dummy observations
......
......@@ -698,71 +698,115 @@ class LocalObsForRailEnv(ObservationBuilder):
The observation is composed of the following elements:

- transition map array of the local environment around the given agent,
  with dimensions (view_height, 2 * view_width + 1, 16),
  assuming 16 bits encoding of transitions.

- Two 3D arrays (view_height, 2 * view_width + 1, 2) containing respectively,
  if they are in the agent's vision range, its target position and the positions of the other agents' targets.

- A 3D array (view_height, 2 * view_width + 1, 4) containing the one-hot encoding of the directions
  of the other agents at their position coordinates, if they are in the agent's vision range.

- A 4-element array with one-hot encoding of the agent's direction.
"""
def __init__(self, view_width, view_height, center):
    """
    Build a local-grid observation builder.

    :param view_width: half-width of the window; the window spans
        2 * view_width + 1 cells across the agent's heading.
    :param view_height: number of cells the window extends along the heading.
    :param center: offset of the agent within the window along the heading axis.
    """
    super(LocalObsForRailEnv, self).__init__()
    self.view_width = view_width
    self.view_height = view_height
    self.center = center
    # Padding needed so a window anchored anywhere on the map stays in bounds.
    # NOTE(review): reset() recomputes this WITHOUT subtracting self.center —
    # confirm which formula is intended.
    self.max_padding = max(self.view_width, self.view_height - self.center)
def reset(self):
    """
    Rebuild the padded transition map from the current rail grid.

    The map is zero-padded on every side so a local window near a border can
    be sliced without bounds checks.
    """
    # NOTE(review): overrides the value set in __init__ (which subtracts
    # self.center) — confirm which padding formula is intended.
    self.max_padding = max(self.view_width, self.view_height)
    self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding,
                              self.env.width + 2 * self.max_padding, 16))
    for i in range(self.env.height):
        for j in range(self.env.width):
            # 16-bit transition code of cell (i, j), MSB first, left-padded
            # with zeros to exactly 16 entries.
            bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
            bitlist = [0] * (16 - len(bitlist)) + bitlist
            # NOTE(review): the write offset is (view_height, view_width)
            # while the padding above is max_padding on each side — verify
            # these are meant to differ.
            self.rail_obs[i + self.view_height, j + self.view_width] = np.array(bitlist)
def get(self, handle):
    """
    Compute the local observation for the agent with the given handle.

    :param handle: index of the agent in self.env.agents.
    :return: tuple (local_rail_obs, obs_map_state, obs_other_agents_state, direction)
        where local_rail_obs is the local transition-map window, the two
        zero arrays are the target/other-agent channels, and direction is
        the agent's one-hot heading.
    """
    agents = self.env.agents
    agent = agents[handle]

    # Correct the agent's position for the padding of the transition map.
    agent_rel_pos = [0, 0]
    agent_rel_pos[0] = agent.position[0] + self.max_padding
    agent_rel_pos[1] = agent.position[1] + self.max_padding

    # Collect the rail information in the local field of view.
    local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs)

    # Channels for targets and other agents.
    # NOTE(review): nothing fills these arrays in this version — they are
    # returned all-zero; confirm population was intentionally deferred.
    # NOTE(review): first dimension is view_height + 1 while field_of_view
    # produces view_height rows — verify the intended shape.
    obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2))
    obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4))

    # Collect visible cells (in the unpadded frame) so the renderer can
    # highlight the agent's field of view.
    visited = self.field_of_view(agent.position, agent.direction)
    self.env.dev_obs_dict[handle] = visited

    direction = self._get_one_hot_for_agent_direction(agent)

    return local_rail_obs, obs_map_state, obs_other_agents_state, direction
def get_many(self, handles=None):
    """
    Called whenever an observation has to be computed for the `env' environment,
    for each agent with handle in the `handles' list.

    :param handles: iterable of agent handles; defaults to no agents.
    :return: dict mapping each handle to the observation from self.get(handle).
    """
    # BUGFIX: the default handles=None was iterated directly, raising
    # TypeError; treat None as "no agents".
    if handles is None:
        handles = []
    observations = {}
    for h in handles:
        observations[h] = self.get(h)
    return observations
def field_of_view(self, position, direction, state=None):
    """
    Compute the local field of view for an agent in the environment.

    :param position: (row, col) anchor of the window. Callers pass the raw
        agent position for the visited-cells use, or the padded position
        when `state` is the padded transition map.
    :param direction: agent heading, 0..3 (N, E, S, W).
    :param state: optional (H, W, 16) array to sample; when given, the
        sampled (view_height, 2 * view_width + 1, 16) window is returned,
        otherwise the set of in-bounds visible cells is returned.
    """
    data_collection = state is not None
    if data_collection:
        temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))

    # Window origin, rotated with the heading; `center` shifts the window
    # along the heading axis so the agent sits `center` cells from its edge.
    if direction == 0:    # North
        origin = (position[0] + self.center, position[1] - self.view_width)
    elif direction == 1:  # East
        origin = (position[0] - self.view_width, position[1] - self.center)
    elif direction == 2:  # South
        origin = (position[0] - self.center, position[1] + self.view_width)
    else:                 # West
        origin = (position[0] + self.view_width, position[1] + self.center)

    visible = set()
    for h in range(self.view_height):
        for w in range(2 * self.view_width + 1):
            # Grid cell covered by window coordinate (h, w) for this heading.
            if direction == 0:
                cell = (origin[0] - h, origin[1] + w)
            elif direction == 1:
                cell = (origin[0] + w, origin[1] + h)
            elif direction == 2:
                cell = (origin[0] + h, origin[1] - w)
            else:
                cell = (origin[0] - w, origin[1] - h)
            # BUGFIX: bounds are now tested on the same cell that is recorded.
            # Directions 2 and 3 previously checked (origin[0] - h,
            # origin[1] + w) while storing/reading a different coordinate,
            # dropping valid cells and admitting out-of-window ones near
            # the map borders.
            # NOTE(review): when `state` is given, `position` is in the
            # padded frame but the bounds still use the raw env dimensions —
            # confirm this is intended.
            if 0 <= cell[0] < self.env.height and 0 <= cell[1] < self.env.width:
                visible.add(cell)
                if data_collection:
                    temp_visible_data[h, w, :] = state[cell[0], cell[1], :]

    if data_collection:
        return temp_visible_data
    return visible
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment