Skip to content
Snippets Groups Projects
Commit 550fef31 authored by Erik Nygren's avatar Erik Nygren
Browse files

minor bugfixes. Simple distance map test (thanks Christian). This test will be enhanced soon.

parent b7e98ab0
No related branches found
No related tags found
No related merge requests found
...@@ -209,8 +209,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -209,8 +209,7 @@ class TreeObsForRailEnv(ObservationBuilder):
#1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
#2: if another agent's target is detected the distance in number of cells from the agent's current location #2: if another agent's target is detected the distance in number of cells from the agent's current location
is stored is stored
#3: if another agent is detected the distance in number of cells from current agent position is stored. #3: if another agent is detected the distance in number of cells from current agent position is stored.
......
...@@ -140,16 +140,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -140,16 +140,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1: elif np.sum(cell_transitions) > 1:
min_dist = np.inf min_dist = np.inf
no_dist_found = True
for direction in range(4): for direction in range(4):
if cell_transitions[direction] == 1: if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction) neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist: if target_dist < min_dist or no_dist_found:
min_dist = target_dist min_dist = target_dist
new_direction = direction new_direction = direction
if new_direction == None: no_dist_found = False
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
continue
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
else: else:
raise Exception("No transition possible {}".format(cell_transitions)) raise Exception("No transition possible {}".format(cell_transitions))
......
...@@ -249,10 +249,10 @@ class RailEnv(Environment): ...@@ -249,10 +249,10 @@ class RailEnv(Environment):
action_selected = False action_selected = False
if agent.speed_data['position_fraction'] == 0.: if agent.speed_data['position_fraction'] == 0.:
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent) self._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid]): if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = action agent.speed_data['transition_action_on_cellexit'] = action
action_selected = True action_selected = True
...@@ -260,10 +260,10 @@ class RailEnv(Environment): ...@@ -260,10 +260,10 @@ class RailEnv(Environment):
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward! # try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_isValid, transition_isValid]): if all([new_cell_valid, transition_valid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
action_selected = True action_selected = True
...@@ -271,17 +271,15 @@ class RailEnv(Environment): ...@@ -271,17 +271,15 @@ class RailEnv(Environment):
# TODO: an invalid action was chosen after entering the cell. The agent cannot move. # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
agent.moving = False
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
agent.moving = False
continue continue
else: else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move. # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
agent.moving = False
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
agent.moving = False
continue continue
if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0): if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
...@@ -293,10 +291,10 @@ class RailEnv(Environment): ...@@ -293,10 +291,10 @@ class RailEnv(Environment):
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
# the cell # the cell
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent) self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]): if all([new_cell_valid, transition_valid, cell_free]):
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
...@@ -316,14 +314,14 @@ class RailEnv(Environment): ...@@ -316,14 +314,14 @@ class RailEnv(Environment):
def _check_action_on_agent(self, action, agent): def _check_action_on_agent(self, action, agent):
# compute number of possible transitions in the current # compute number of possible transitions in the current
# cell used to check for invalid actions # cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action) new_direction, transition_valid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
# Is it a legal move? # Is it a legal move?
# 1) transition allows the new_direction in the cell, # 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0), # 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell # 3) the cell is free, i.e., no agent is currently in that cell
new_cell_isValid = ( new_cell_valid = (
np.array_equal( # Check the new position is still in the grid np.array_equal( # Check the new position is still in the grid
new_position, new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
...@@ -331,19 +329,19 @@ class RailEnv(Environment): ...@@ -331,19 +329,19 @@ class RailEnv(Environment):
self.rail.get_transitions(new_position) > 0) self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet. # If transition validity hasn't been checked yet.
if transition_isValid is None: if transition_valid is None:
transition_isValid = self.rail.get_transition( transition_valid = self.rail.get_transition(
(*agent.position, agent.direction), (*agent.position, agent.direction),
new_direction) new_direction)
# Check the new position is not the same as any of the existing agent positions # Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving) # (including itself, for simplicity, since it is moving)
cell_isFree = not np.any( cell_free = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def check_action(self, agent, action): def check_action(self, agent, action):
transition_isValid = None transition_valid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
...@@ -351,12 +349,12 @@ class RailEnv(Environment): ...@@ -351,12 +349,12 @@ class RailEnv(Environment):
if action == RailEnvActions.MOVE_LEFT: if action == RailEnvActions.MOVE_LEFT:
new_direction = agent.direction - 1 new_direction = agent.direction - 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_valid = False
elif action == RailEnvActions.MOVE_RIGHT: elif action == RailEnvActions.MOVE_RIGHT:
new_direction = agent.direction + 1 new_direction = agent.direction + 1
if num_transitions <= 1: if num_transitions <= 1:
transition_isValid = False transition_valid = False
new_direction %= 4 new_direction %= 4
...@@ -366,8 +364,8 @@ class RailEnv(Environment): ...@@ -366,8 +364,8 @@ class RailEnv(Environment):
# new_direction will be the only valid transition # new_direction will be the only valid transition
# - take only available transition # - take only available transition
new_direction = np.argmax(possible_transitions) new_direction = np.argmax(possible_transitions)
transition_isValid = True transition_valid = True
return new_direction, transition_isValid return new_direction, transition_valid
def _get_observations(self): def _get_observations(self):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
......
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
def test_walker():
    """Probe the distance map on a minimal one-row, three-cell rail.

    The track is laid out as ``_ _ _``: a dead end, a horizontal straight
    and another dead end. A single agent is placed on the middle cell
    ``(0, 1)`` with direction ``1`` and target ``(0, 0)``. The distance
    map built by the environment's observation builder is then checked
    at two ``(agent, row, column, direction)`` entries.
    """
    # 16-bit transition masks of the elementary cell types, by case number.
    cells = [
        int('0000000000000000', 2),  # Case 0 - empty cell
        int('1000000000100000', 2),  # Case 1 - straight
        int('1001001000100000', 2),  # Case 2 - simple switch
        int('1000010000100001', 2),  # Case 3 - diamond crossing
        int('1001011000100001', 2),  # Case 4 - single slip switch
        int('1100110000110011', 2),  # Case 5 - double slip switch
        int('0101001000000010', 2),  # Case 6 - symmetrical switch
        int('0010000000000000', 2),  # Case 7 - dead end
    ]

    transitions = Grid4Transitions([])

    # Rotate the base cells to obtain the pieces of a horizontal track.
    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    # A single row: dead end, straight, dead end.
    rail_map = np.array(
        [[dead_end_from_east, horizontal_straight, dead_end_from_west]],
        dtype=np.uint16,
    )
    n_rows, n_cols = rail_map.shape

    rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=transitions)
    rail.grid = rail_map

    rail_generator = rail_from_GridTransitionMap_generator(rail)
    predictor = ShortestPathPredictorForRailEnv(max_depth=10)
    tree_obs = TreeObsForRailEnv(max_depth=2, predictor=predictor)
    env = RailEnv(
        width=n_cols,
        height=n_rows,
        rail_generator=rail_generator,
        number_of_agents=1,
        obs_builder_object=tree_obs,
    )

    # First reset initialises agents_static.
    env.reset()

    # Pin the start configuration of the only agent for this test.
    static_agent = env.agents_static[0]
    static_agent.position = (0, 1)
    static_agent.direction = 1
    static_agent.target = (0, 0)

    # Second reset rebuilds the agents from agents_static.
    env.reset(False, False)

    obs_builder: TreeObsForRailEnv = env.obs_builder
    distance_map = obs_builder.distance_map

    # Agent 0, cell (0, 1), direction 1.
    start_distance = distance_map[0, 0, 1, 1]
    print(start_distance)
    assert start_distance == 3

    # NOTE(review): the value printed here is for direction 3, while the
    # assertion below checks direction 1 - confirm which entry is intended.
    print(distance_map[0, 0, 2, 3])
    assert distance_map[0, 0, 2, 1] == 2  # does not work yet, Erik's proposal.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment