Skip to content
Snippets Groups Projects
Commit b935770a authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch 'master' into '140_predictor_multi_speed'

# Conflicts:
#   examples/training_example.py
parents 30b4ca61 8775f09c
No related branches found
No related tags found
No related merge requests found
......@@ -13,11 +13,11 @@ np.random.seed(1)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=50,
height=50,
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObservation,
number_of_agents=5)
number_of_agents=3)
env_renderer = RenderTool(env, gl="PILSVG", )
......@@ -68,6 +68,9 @@ for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs = env.reset()
for idx in range(env.get_num_agents()):
tmp_agent = env.agents[idx]
tmp_agent.speed_data["speed"] = 1 / (idx + 1)
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
......@@ -83,7 +86,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -324,6 +324,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited = set()
agent = self.env.agents[handle]
time_per_cell = np.reciprocal(agent.speed_data["speed"])
own_target_encountered = np.inf
other_agent_encountered = np.inf
other_target_encountered = np.inf
......@@ -359,18 +360,21 @@ class TreeObsForRailEnv(ObservationBuilder):
crossing_found = True
# Register possible future conflict
if self.predictor and num_steps < self.max_prediction_depth:
predicted_time = int(tot_dist * time_per_cell)
if self.predictor and predicted_time < self.max_prediction_depth:
int_position = coordinate_to_position(self.env.width, [position])
if tot_dist < self.max_prediction_depth:
pre_step = max(0, tot_dist - 1)
post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
pre_step = max(0, predicted_time - 1)
post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
# Look for conflicting paths at distance tot_dist
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and cell_transitions[self._reverse_dir(
self.predicted_dir[tot_dist][ca])] == 1 and tot_dist < potential_conflict:
if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment