Skip to content
Snippets Groups Projects
Commit d0bb6d8f authored by hagrid67
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents 027b7955 2e7eaa1b
No related branches found
No related tags found
No related merge requests found
...@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -116,7 +116,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs = env.reset() obs, _ = env.reset()
final_obs = obs.copy() final_obs = obs.copy()
final_obs_next = obs.copy() final_obs_next = obs.copy()
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
...@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1): ...@@ -148,7 +148,7 @@ for trials in range(1, n_trials + 1):
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
next_obs, all_rewards, done, _ = env.step(action_dict) (next_obs,_), all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
......
...@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -219,7 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]: if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction) new_cell = self._new_position(agent.position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) branch_observation, visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
1)
observation = observation + branch_observation observation = observation + branch_observation
else: else:
num_cells_to_fill_in = 0 num_cells_to_fill_in = 0
...@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -228,7 +229,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_cells_to_fill_in += pow4 num_cells_to_fill_in += pow4
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation return observation, visited
def _explore_branch(self, handle, position, direction, root_observation, depth): def _explore_branch(self, handle, position, direction, root_observation, depth):
""" """
...@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -236,7 +237,7 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
# [Recursive branch opened] # [Recursive branch opened]
if depth >= self.max_depth + 1: if depth >= self.max_depth + 1:
return [] return [], []
# Continue along direction until next switch or # Continue along direction until next switch or
# until no transitions are possible along the current direction (i.e., dead-ends) # until no transitions are possible along the current direction (i.e., dead-ends)
...@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -377,22 +378,24 @@ class TreeObsForRailEnv(ObservationBuilder):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back # it back
new_cell = self._new_position(position, (branch_direction + 2) % 4) new_cell = self._new_position(position, (branch_direction + 2) % 4)
branch_observation = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
(branch_direction + 2) % 4, (branch_direction + 2) % 4,
new_root_observation, new_root_observation,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
elif last_isSwitch and possible_transitions[branch_direction]: elif last_isSwitch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, branch_observation, branch_visited = self._explore_branch(handle,
new_cell, new_cell,
branch_direction, branch_direction,
new_root_observation, new_root_observation,
depth + 1) depth + 1)
observation = observation + branch_observation observation = observation + branch_observation
if len(branch_visited) != 0:
visited.union(branch_visited)
else: else:
num_cells_to_fill_in = 0 num_cells_to_fill_in = 0
pow4 = 1 pow4 = 1
...@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -401,7 +404,7 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4 pow4 *= 4
observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in
return observation return observation, visited
def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
""" """
......
...@@ -315,10 +315,11 @@ class RailEnv(Environment): ...@@ -315,10 +315,11 @@ class RailEnv(Environment):
def _get_observations(self): def _get_observations(self):
self.obs_dict = {} self.obs_dict = {}
self.debug_obs_dict = {}
# for handle in self.agents_handles: # for handle in self.agents_handles:
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
self.obs_dict[iAgent] = self.obs_builder.get(iAgent) self.obs_dict[iAgent], self.debug_obs_dict[iAgent] = self.obs_builder.get(iAgent)
return self.obs_dict return self.obs_dict, self.debug_obs_dict
def render(self): def render(self):
# TODO: # TODO:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment