Skip to content
Snippets Groups Projects
Commit 38cf83a1 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '62-unit-test-coverage-and-code-cleanup' into 'master'

Resolve "increase unit test coverage (at least 80%)"

See merge request flatland/flatland!95
parents 62dbb589 3383b56b
No related branches found
No related tags found
No related merge requests found
......@@ -336,8 +336,4 @@ class GridTransitionMap(TransitionMap):
return True
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps.
# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
......@@ -23,6 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation_dim = 9
def __init__(self, max_depth, predictor=None):
super().__init__()
self.max_depth = max_depth
# Compute the size of the returned observation vector
......@@ -41,15 +42,14 @@ class TreeObsForRailEnv(ObservationBuilder):
def reset(self):
agents = self.env.agents
nAgents = len(agents)
nb_agents = len(agents)
compute_distance_map = True
if self.agents_previous_reset is not None:
if nAgents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nAgents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
compute_distance_map = False
for i in range(nb_agents):
if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True
self.agents_previous_reset = agents
if compute_distance_map:
......@@ -57,12 +57,12 @@ class TreeObsForRailEnv(ObservationBuilder):
def _compute_distance_map(self):
agents = self.env.agents
nAgents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents,
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.env.height,
self.env.width,
4))
self.max_dist = np.zeros(nAgents)
self.max_dist = np.zeros(nb_agents)
self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
# Update local lookup table for all agents' target locations
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
......@@ -83,10 +83,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# BFS from target `position' to all the reachable nodes in the grid
# Stop the search if the target position is re-visited, in any direction
visited = set([(position[0], position[1], 0),
(position[0], position[1], 1),
(position[0], position[1], 2),
(position[0], position[1], 3)])
visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
(position[0], position[1], 3)}
max_distance = 0
......@@ -133,10 +131,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Check all possible transitions in new_cell
for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
desired_movement_from_new_cell)
if isValid:
if is_valid:
"""
# TODO: check that it works with deadends! -- still bugged!
movement = desired_movement_from_new_cell
......@@ -163,12 +161,14 @@ class TreeObsForRailEnv(ObservationBuilder):
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_many(self, handles=[]):
def get_many(self, handles=None):
"""
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
"""
if handles is None:
handles = []
if self.predictor:
self.predicted_pos = {}
self.predicted_dir = {}
......@@ -259,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# Root node - current position
observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
root_observation = observation[:]
visited = set()
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
......@@ -273,7 +272,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if possible_transitions[branch_direction]:
new_cell = self._new_position(agent.position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
observation = observation + branch_observation
visited = visited.union(branch_visited)
else:
......@@ -291,7 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
pow4 *= 4
return num_observations * self.observation_dim
def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
def _explore_branch(self, handle, position, direction, tot_dist, depth):
"""
Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
......@@ -305,10 +304,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# until no transitions are possible along the current direction (i.e., dead-ends)
# We treat dead-ends as nodes, instead of going back, to avoid loops
exploring = True
last_isSwitch = False
last_isDeadEnd = False
last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
last_isTarget = False
last_is_switch = False
last_is_dead_end = False
last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
last_is_target = False
visited = set()
agent = self.env.agents[handle]
......@@ -369,21 +368,19 @@ class TreeObsForRailEnv(ObservationBuilder):
if tot_dist < other_target_encountered:
other_target_encountered = tot_dist
if position == agent.target:
if tot_dist < own_target_encountered:
own_target_encountered = tot_dist
if position == agent.target and tot_dist < own_target_encountered:
own_target_encountered = tot_dist
# #############################
# #############################
if (position[0], position[1], direction) in visited:
last_isTerminal = True
last_is_terminal = True
break
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if np.array_equal(position, self.env.agents[handle].target):
last_isTarget = True
last_is_target = True
break
cell_transitions = self.env.rail.get_transitions((*position, direction))
......@@ -403,9 +400,9 @@ class TreeObsForRailEnv(ObservationBuilder):
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
last_isDeadEnd = True
last_is_dead_end = True
if not last_isDeadEnd:
if not last_is_dead_end:
# Keep walking through the tree along `direction'
exploring = True
# convert one-hot encoding to 0,1,2,3
......@@ -415,14 +412,14 @@ class TreeObsForRailEnv(ObservationBuilder):
tot_dist += 1
elif num_transitions > 0:
# Switch detected
last_isSwitch = True
last_is_switch = True
break
elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
position[1], direction)
last_isTerminal = True
last_is_terminal = True
break
# `position' is either a terminal node or a switch
......@@ -433,7 +430,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# Modify here to append new / different features for each visited cell!
if last_isTarget:
if last_is_target:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
......@@ -445,7 +442,7 @@ class TreeObsForRailEnv(ObservationBuilder):
other_agent_opposite_direction
]
elif last_isTerminal:
elif last_is_terminal:
observation = [own_target_encountered,
other_target_encountered,
other_agent_encountered,
......@@ -469,32 +466,30 @@ class TreeObsForRailEnv(ObservationBuilder):
]
# #############################
# #############################
new_root_observation = observation[:]
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# Get the possible transitions
possible_transitions = self.env.rail.get_transitions((*position, direction))
for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
if last_is_dead_end and self.env.rail.get_transition((*position, direction),
(branch_direction + 2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction + 2) % 4)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
(branch_direction + 2) % 4,
new_root_observation, tot_dist + 1,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
visited = visited.union(branch_visited)
elif last_isSwitch and possible_transitions[branch_direction]:
elif last_is_switch and possible_transitions[branch_direction]:
new_cell = self._new_position(position, branch_direction)
branch_observation, branch_visited = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation, tot_dist + 1,
tot_dist + 1,
depth + 1)
observation = observation + branch_observation
if len(branch_visited) != 0:
......
......@@ -4,7 +4,6 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
# TODO: _ this is a global method --> utils or remove later
from enum import IntEnum
......@@ -85,7 +84,6 @@ class RailEnv(Environment):
a GridTransitionMap object
rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
......@@ -109,7 +107,7 @@ class RailEnv(Environment):
self.obs_builder._set_env(self)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
self.observation_space = self.obs_builder.observation_space
self.rewards = [0] * number_of_agents
self.done = False
......@@ -195,31 +193,29 @@ class RailEnv(Environment):
# Reset the step rewards
self.rewards_dict = dict()
for iAgent in range(self.get_num_agents()):
self.rewards_dict[iAgent] = 0
for i_agent in range(self.get_num_agents()):
self.rewards_dict[i_agent] = 0
if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {}
# for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
for i_agent, agent in enumerate(self.agents):
agent.old_direction = agent.direction
agent.old_position = agent.position
if self.dones[iAgent]: # this agent has already completed...
if self.dones[i_agent]: # this agent has already completed...
continue
if iAgent not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if i_agent not in action_dict: # no action has been supplied for this agent
action_dict[i_agent] = RailEnvActions.DO_NOTHING
if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[iAgent],
'for agent with index=', iAgent,
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent],
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action_dict[i_agent] = RailEnvActions.DO_NOTHING
action = action_dict[iAgent]
action = action_dict[i_agent]
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
......@@ -228,12 +224,12 @@ class RailEnv(Environment):
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agent] += stop_penalty
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Only allow agent to start moving by pressing forward.
agent.moving = True
self.rewards_dict[iAgent] += start_penalty
self.rewards_dict[i_agent] += start_penalty
# Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
......@@ -269,16 +265,16 @@ class RailEnv(Environment):
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty
self.rewards_dict[i_agent] += invalid_action_penalty
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agent] += stop_penalty
continue
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty
self.rewards_dict[i_agent] += invalid_action_penalty
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agent] += stop_penalty
continue
......@@ -300,9 +296,9 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
self.dones[i_agent] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
# Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
......
......@@ -3,7 +3,7 @@ from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
from flatland.core.transition_map import GridTransitionMap
def test_grid4_set_transitions():
def test_grid4_get_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
......@@ -19,3 +19,5 @@ def test_grid8_set_transitions():
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
# TODO GridTransitionMap
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment