Skip to content
Snippets Groups Projects
Commit 1e1a1282 authored by Erik Nygren's avatar Erik Nygren
Browse files

fixed tree observation error. testing observation. replaced a few for loops with numpy functions

parent 65d7cf9e
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import *
from flatland.envs.generators import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
from flatland.baselines.dueling_double_dqn import Agent
......@@ -54,9 +55,9 @@ scores = []
dones_list = []
action_prob = [0] * 4
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
demo = True
demo = False
def max_lt(seq, val):
......
......@@ -236,6 +236,7 @@ class TreeObsForRailEnv(ObservationBuilder):
position = self.env.agents_position[handle]
orientation = self.env.agents_direction[handle]
possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
# Root node - current position
observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
......@@ -245,7 +246,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# organize them as [left, forward, right, back], relative to the current orientation
# TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
if possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
......@@ -308,11 +309,7 @@ class TreeObsForRailEnv(ObservationBuilder):
break
cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
num_transitions = 0
for i in range(4):
if cell_transitions[i]:
num_transitions += 1
num_transitions = np.count_nonzero(cell_transitions)
exploring = False
if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction
......@@ -328,13 +325,9 @@ class TreeObsForRailEnv(ObservationBuilder):
if not last_isDeadEnd:
# Keep walking through the tree along `direction'
exploring = True
# TODO: Remove below calculation, this is computed already above and could be reused
for i in range(4):
if cell_transitions[i]:
position = self._new_position(position, i)
direction = i
num_steps += 1
break
direction = np.argmax(cell_transitions)
position = self._new_position(position, direction)
num_steps += 1
elif num_transitions > 0:
# Switch detected
......@@ -383,13 +376,14 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
branch_observation = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
......@@ -397,10 +391,8 @@ class TreeObsForRailEnv(ObservationBuilder):
depth + 1)
observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
(branch_direction + 2) % 4):
elif last_isSwitch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle,
new_cell,
branch_direction,
......
......@@ -323,6 +323,7 @@ class RailEnv(Environment):
nbits += (tmp & 1)
tmp = tmp >> 1
movement = direction
#print(nbits,np.sum(possible_transitions))
if action == 1:
movement = direction - 1
if nbits <= 2 or np.sum(possible_transitions) <= 1:
......@@ -360,12 +361,14 @@ class RailEnv(Environment):
direction = reverse_direction
movement = reverse_direction
is_deadend = True
if np.sum(possible_transitions) == 1:
# Checking for curves
curv_dir = np.argmax(possible_transitions)
# valid_transition = self.rail.get_transition(
# (pos[0], pos[1], direction),
# movement)
movement = curv_dir
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment