From 1e1a128291e1c60a2363304af0b4929378c0d7aa Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 3 May 2019 11:32:45 +0200 Subject: [PATCH] fixed tree observation error. testing observation. replaced a few for loops with numpy functions --- examples/training_navigation.py | 5 +++-- flatland/core/env_observation_builder.py | 26 ++++++++---------------- flatland/envs/rail_env.py | 3 +++ 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 9fc83242..1111e0bb 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -1,4 +1,5 @@ from flatland.envs.rail_env import * +from flatland.envs.generators import * from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * from flatland.baselines.dueling_double_dqn import Agent @@ -54,9 +55,9 @@ scores = [] dones_list = [] action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) +#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) -demo = True +demo = False def max_lt(seq, val): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 6c27afdc..edd9cae7 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -236,6 +236,7 @@ class TreeObsForRailEnv(ObservationBuilder): position = self.env.agents_position[handle] orientation = self.env.agents_direction[handle] + possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation)) # Root node - current position observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] @@ -245,7 +246,7 @@ class TreeObsForRailEnv(ObservationBuilder): # organize them as [left, forward, right, back], relative to the current orientation # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible. for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: - if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): + if possible_transitions[branch_direction]: new_cell = self._new_position(position, branch_direction) branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) @@ -308,11 +309,7 @@ class TreeObsForRailEnv(ObservationBuilder): break cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) - num_transitions = 0 - for i in range(4): - if cell_transitions[i]: - num_transitions += 1 - + num_transitions = np.count_nonzero(cell_transitions) exploring = False if num_transitions == 1: # Check if dead-end, or if we can go forward along direction @@ -328,13 +325,9 @@ class TreeObsForRailEnv(ObservationBuilder): if not last_isDeadEnd: # Keep walking through the tree along `direction' exploring = True - # TODO: Remove below calculation, this is computed already above and could be reused - for i in range(4): - if cell_transitions[i]: - position = self._new_position(position, i) - direction = i - num_steps += 1 - break + direction = np.argmax(cell_transitions) + position = self._new_position(position, direction) + num_steps += 1 elif num_transitions > 0: # Switch detected @@ -383,13 +376,14 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation + # Get the possible transitions + possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction), (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back new_cell = self._new_position(position, (branch_direction + 2) % 4) - branch_observation = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, @@ -397,10 +391,8 @@ class TreeObsForRailEnv(ObservationBuilder): depth + 1) observation = observation + branch_observation - elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), - (branch_direction + 2) % 4): + elif last_isSwitch and possible_transitions[branch_direction]: new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 779ec058..9d67f830 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -323,6 +323,7 @@ class RailEnv(Environment): nbits += (tmp & 1) tmp = tmp >> 1 movement = direction + #print(nbits,np.sum(possible_transitions)) if action == 1: movement = direction - 1 if nbits <= 2 or np.sum(possible_transitions) <= 1: @@ -360,12 +361,14 @@ class RailEnv(Environment): direction = reverse_direction movement = reverse_direction is_deadend = True + if np.sum(possible_transitions) == 1: # Checking for curves curv_dir = np.argmax(possible_transitions) # valid_transition = self.rail.get_transition( # (pos[0], pos[1], direction), # movement) + movement = curv_dir new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the -- GitLab