Commit cc6b05f0 authored by Erik Nygren

Merge branch 'master' into 'stochasticbreaking'

# Conflicts:
#   examples/training_example.py
parents 48eb6c32 fc34b470
@@ -10,16 +10,14 @@ np.random.seed(1)
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 #
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
-env = RailEnv(width=50,
-              height=50,
-              rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=TreeObservation,
-              number_of_agents=10)
+              number_of_agents=3)
 env_renderer = RenderTool(env, gl="PILSVG", )
@@ -70,6 +68,9 @@ for trials in range(1, n_trials + 1):
     # Reset environment and get initial observations for all agents
     obs = env.reset()
+    for idx in range(env.get_num_agents()):
+        tmp_agent = env.agents[idx]
+        tmp_agent.speed_data["speed"] = 1 / (idx + 1)
     env_renderer.reset()
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
@@ -84,7 +85,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether they are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+        env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
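In examples/training_example.py the environment is shrunk from a 50x50 grid with 10 agents to a 20x20 grid with 3 agents, every agent gets its own fractional speed right after reset, and the renderer now shows observations instead of predictions. The snippet below is a minimal sketch of that multi-speed setup, assuming the import paths used elsewhere in this version of flatland; it is an illustration, not the full training script.

import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv

np.random.seed(1)

# Small environment, matching the updated example: 20x20 grid, 3 agents.
env = RailEnv(width=20,
              height=20,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8,
                                                    max_dist=99999, seed=0),
              obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                   predictor=ShortestPathPredictorForRailEnv()),
              number_of_agents=3)

obs = env.reset()

# Give each agent a different fractional speed: agent 0 runs at full speed,
# agent 1 at 1/2, agent 2 at 1/3, so slower agents need more steps per cell.
for idx in range(env.get_num_agents()):
    env.agents[idx].speed_data["speed"] = 1 / (idx + 1)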
@@ -324,6 +324,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         visited = set()
         agent = self.env.agents[handle]
+        time_per_cell = np.reciprocal(agent.speed_data["speed"])
         own_target_encountered = np.inf
         other_agent_encountered = np.inf
         other_target_encountered = np.inf
@@ -359,18 +360,21 @@ class TreeObsForRailEnv(ObservationBuilder):
                 crossing_found = True
             # Register possible future conflict
-            if self.predictor and num_steps < self.max_prediction_depth:
+            predicted_time = int(tot_dist * time_per_cell)
+            if self.predictor and predicted_time < self.max_prediction_depth:
                 int_position = coordinate_to_position(self.env.width, [position])
                 if tot_dist < self.max_prediction_depth:
-                    pre_step = max(0, tot_dist - 1)
-                    post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
+                    pre_step = max(0, predicted_time - 1)
+                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
                     # Look for conflicting paths at distance tot_dist
-                    if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
-                        conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
+                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
+                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
                         for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir(
-                                    self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict:
+                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
+                                self._reverse_dir(
+                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                             if self.env.dones[ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
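These TreeObsForRailEnv hunks make the conflict prediction speed-aware: a cell that is tot_dist cells ahead is only reached after roughly tot_dist * (1 / speed) time steps, so the predictor tables are now indexed with that predicted time instead of the raw cell distance. The helper below is a hypothetical illustration of the arithmetic, not flatland code.

import numpy as np

def cell_distance_to_prediction_step(tot_dist, speed, max_prediction_depth):
    """Map a distance in cells to the prediction row where the agent is
    expected to occupy that cell, given its fractional speed.

    Hypothetical helper mirroring the arithmetic in the diff above: an agent
    with speed 1/3 needs 3 time steps per cell, so a cell 5 cells ahead is
    only reached around prediction step 15.
    """
    time_per_cell = np.reciprocal(speed)          # steps needed to cross one cell
    predicted_time = int(tot_dist * time_per_cell)
    if predicted_time >= max_prediction_depth:    # beyond the prediction horizon
        return None
    return predicted_time

# Example: agent at half speed, conflict cell 4 cells ahead -> prediction row 8.
assert cell_distance_to_prediction_step(4, 0.5, 20) == 8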
@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
+            agent_speed = agent.speed_data["speed"]
+            times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+            new_direction = _agent_initial_direction
+            new_position = _agent_initial_position
             visited = set()
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-                new_position = None
-                new_direction = None
-                if np.sum(cell_transitions) == 1:
+                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
                     new_direction = np.argmax(cell_transitions)
                     new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1:
+                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
                     min_dist = np.inf
                     no_dist_found = True
                     for direction in range(4):
@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                                 new_direction = direction
                                 no_dist_found = False
                     new_position = get_new_position(agent.position, new_direction)
-                else:
+                elif index % times_per_cell == 0:
                     raise Exception("No transition possible {}".format(cell_transitions))
                 # update the agent's position and direction
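The ShortestPathPredictorForRailEnv hunks apply the same idea to the prediction itself: an agent advances along its shortest path only on time steps that are multiples of times_per_cell = int(1 / speed), and keeps its previous position and direction in between, which is why new_position and new_direction are now initialised before the loop. The stand-alone sketch below illustrates that stepping pattern under the assumption of a straight run of cells; it is not the predictor implementation.

import numpy as np

def predicted_cell_indices(speed, max_depth):
    """Hypothetical illustration of the stepping pattern introduced above:
    the agent moves to the next cell of its path only every times_per_cell
    time steps and repeats its position on the steps in between.
    """
    times_per_cell = int(np.reciprocal(speed))
    cell_index = 0
    indices = [cell_index]                     # row 0: initial position
    for index in range(1, max_depth + 1):
        if index % times_per_cell == 0:        # allowed to advance this step
            cell_index += 1
        indices.append(cell_index)
    return indices

# A half-speed agent occupies each cell of its path for two consecutive steps.
assert predicted_cell_indices(0.5, 6) == [0, 0, 1, 1, 2, 2, 3]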
import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv

np.random.seed(1)

# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#


class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        """
        :param state: input is the observation of the agent
        :return: returns an action
        """
        return np.random.choice([1, 2, 3])

    def step(self, memories):
        """
        Step function to improve agent by adjusting policy given the observations

        :param memories: SARS Tuple to be
        :return:
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return


def test_multi_speed_init():
    env = RailEnv(width=50,
                  height=50,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                        seed=0),
                  number_of_agents=5)

    # Initialize the agent with the parameters corresponding to the environment and observation_builder
    agent = RandomAgent(218, 4)

    # Empty dictionary for all agent action
    action_dict = dict()

    # Set all the different speeds
    # Reset environment and get initial observations for all agents
    env.reset()

    # Here you can also further enhance the provided observation by means of normalization
    # See training navigation example in the baseline repository
    old_pos = []
    for i_agent in range(env.get_num_agents()):
        env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
        old_pos.append(env.agents[i_agent].position)

    # Run episode
    for step in range(100):

        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = agent.act(0)
            action_dict.update({a: action})

            # Check that the agent did not move in between its speed updates
            assert old_pos[a] == env.agents[a].position

        # Environment step which returns the observations for all agents, their corresponding
        # reward and whether they are done
        _, _, _, _ = env.step(action_dict)

        # Update the old position whenever an agent was allowed to move
        for i_agent in range(env.get_num_agents()):
            if (step + 1) % (i_agent + 1) == 0:
                print(step, i_agent, env.agents[i_agent].position)
                old_pos[i_agent] = env.agents[i_agent].position
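The new test checks the central invariant of the feature: an agent running at speed 1 / (i + 1) may only change its position every i + 1 environment steps. It can be collected by pytest or run directly; the file path in the comment below is an assumption about where the test module lives.

# Assumed location; adjust to the actual test module path:
#   pytest tests/test_multi_speed.py
if __name__ == "__main__":
    test_multi_speed_init()
    print("test_multi_speed_init passed")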