Skip to content
Snippets Groups Projects
Commit 49627b9e authored by hagrid67's avatar hagrid67
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents fcc163ab dcc279a2
No related branches found
No related tags found
No related merge requests found
...@@ -30,13 +30,13 @@ env = RailEnv(width=20, ...@@ -30,13 +30,13 @@ env = RailEnv(width=20,
rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0),
number_of_agents=5) number_of_agents=5)
"""
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
['../env-data/tests/circle.npy']), ['../env-data/tests/circle.npy']),
number_of_agents=1) number_of_agents=1)
"""
env_renderer = RenderTool(env, gl="QT") env_renderer = RenderTool(env, gl="QT")
handle = env.get_agent_handles() handle = env.get_agent_handles()
...@@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1): ...@@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1):
for a in range(env.number_of_agents): for a in range(env.number_of_agents):
if demo: if demo:
eps = 0 eps = 0
action = 2 #agent.act(np.array(obs[a]), eps=eps) action = agent.act(np.array(obs[a]), eps=eps)
action_prob[action] += 1 action_prob[action] += 1
action_dict.update({a: action}) action_dict.update({a: action})
#env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)
......
...@@ -243,6 +243,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -243,6 +243,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available; # Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation # organize them as [left, forward, right, back], relative to the current orientation
# TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
......
...@@ -1040,18 +1040,21 @@ class RailEnv(Environment): ...@@ -1040,18 +1040,21 @@ class RailEnv(Environment):
nbits = 0 nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1])) tmp = self.rail.get_transitions((pos[0], pos[1]))
possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
# print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction))
while tmp > 0: while tmp > 0:
nbits += (tmp & 1) nbits += (tmp & 1)
tmp = tmp >> 1 tmp = tmp >> 1
movement = direction movement = direction
if action == 1: if action == 1:
movement = direction - 1 movement = direction - 1
if nbits <= 2: if nbits <= 2 or np.sum(possible_transitions) <= 1:
transition_isValid = False transition_isValid = False
elif action == 3: elif action == 3:
movement = direction + 1 movement = direction + 1
if nbits <= 2: if nbits <= 2 or np.sum(possible_transitions) <= 1:
transition_isValid = False transition_isValid = False
if movement < 0: if movement < 0:
movement += 4 movement += 4
...@@ -1081,12 +1084,14 @@ class RailEnv(Environment): ...@@ -1081,12 +1084,14 @@ class RailEnv(Environment):
direction = reverse_direction direction = reverse_direction
movement = reverse_direction movement = reverse_direction
is_deadend = True is_deadend = True
if nbits == 2: if np.sum(possible_transitions) == 1:
# Checking for curves # Checking for curves
curv_dir = np.argmax(possible_transitions)
valid_transition = self.rail.get_transition( #valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction), # (pos[0], pos[1], direction),
movement) # movement)
movement = curv_dir
"""
reverse_direction = (direction + 2) % 4 reverse_direction = (direction + 2) % 4
curv_dir = (movement + 1) % 4 curv_dir = (movement + 1) % 4
while not valid_transition: while not valid_transition:
...@@ -1097,7 +1102,7 @@ class RailEnv(Environment): ...@@ -1097,7 +1102,7 @@ class RailEnv(Environment):
if valid_transition: if valid_transition:
movement = curv_dir movement = curv_dir
curv_dir = (curv_dir + 1) % 4 curv_dir = (curv_dir + 1) % 4
"""
new_position = self._new_position(pos, movement) new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the # Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is # cell, 2) the new cell is not empty (case 0), 3) the cell is
......
...@@ -30,8 +30,8 @@ def checkFrozenImage(sFileImage): ...@@ -30,8 +30,8 @@ def checkFrozenImage(sFileImage):
if bytesFrozenImage is None: if bytesFrozenImage is None:
bytesFrozenImage = bytesImage bytesFrozenImage = bytesImage
else: else:
assert(bytesFrozenImage.shape == bytesImage.shape) assert (bytesFrozenImage.shape == bytesImage.shape)
assert((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3) assert ((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3)
def test_render_env(): def test_render_env():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment