diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 1f02d518a9714de91d8910b8cb1408f25eb3fe88..18af8a0c88067ee00f9d9b2fb6f8bbbecc486909 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -209,8 +209,7 @@ class TreeObsForRailEnv(ObservationBuilder): #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. #2: if another agents target is detected the distance in number of cells from the agents current locaiton - is stored - + is stored #3: if another agent is detected the distance in number of cells from current agent position is stored. diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 671b349a4794f565b40e0d085393af6b92a08989..3dbd163ca2d67ffa698a5fe2c931ce98c2dd8bbc 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -140,16 +140,15 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_position = get_new_position(agent.position, new_direction) elif np.sum(cell_transitions) > 1: min_dist = np.inf + no_dist_found = True for direction in range(4): if cell_transitions[direction] == 1: neighbour_cell = get_new_position(agent.position, direction) target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist: + if target_dist < min_dist or no_dist_found: min_dist = target_dist new_direction = direction - if new_direction == None: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - continue + no_dist_found = False new_position = get_new_position(agent.position, new_direction) else: raise Exception("No transition possible {}".format(cell_transitions)) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f5f46408875e6235493682014d8bc4313ad5ea34..8abfd1b3b295bbe347ff5a25a5cd685314e1621a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -249,10 +249,10 @@ class RailEnv(Environment): action_selected = False if agent.speed_data['position_fraction'] == 0.: if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) - if all([new_cell_isValid, transition_isValid]): + if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = action action_selected = True @@ -260,10 +260,10 @@ class RailEnv(Environment): # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - if all([new_cell_isValid, transition_isValid]): + if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD action_selected = True @@ -271,17 +271,15 @@ class RailEnv(Environment): # TODO: an invalid action was chosen after entering the cell. The agent cannot move. self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - agent.moving = False self.rewards_dict[i_agent] += stop_penalty - + agent.moving = False continue else: # TODO: an invalid action was chosen after entering the cell. The agent cannot move. self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - agent.moving = False self.rewards_dict[i_agent] += stop_penalty - + agent.moving = False continue if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0): @@ -293,10 +291,10 @@ class RailEnv(Environment): # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering # the cell - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent) - if all([new_cell_isValid, transition_isValid, cell_isFree]): + if all([new_cell_valid, transition_valid, cell_free]): agent.position = new_position agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 @@ -316,14 +314,14 @@ class RailEnv(Environment): def _check_action_on_agent(self, action, agent): # compute number of possible transitions in the current # cell used to check for invalid actions - new_direction, transition_isValid = self.check_action(agent, action) + new_direction, transition_valid = self.check_action(agent, action) new_position = get_new_position(agent.position, new_direction) # Is it a legal move? # 1) transition allows the new_direction in the cell, # 2) the new cell is not empty (case 0), # 3) the cell is free, i.e., no agent is currently in that cell - new_cell_isValid = ( + new_cell_valid = ( np.array_equal( # Check the new position is still in the grid new_position, np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) @@ -331,19 +329,19 @@ class RailEnv(Environment): self.rail.get_transitions(new_position) > 0) # If transition validity hasn't been checked yet. - if transition_isValid is None: - transition_isValid = self.rail.get_transition( + if transition_valid is None: + transition_valid = self.rail.get_transition( (*agent.position, agent.direction), new_direction) # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - cell_isFree = not np.any( + cell_free = not np.any( np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) - return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid + return cell_free, new_cell_valid, new_direction, new_position, transition_valid def check_action(self, agent, action): - transition_isValid = None + transition_valid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) @@ -351,12 +349,12 @@ class RailEnv(Environment): if action == RailEnvActions.MOVE_LEFT: new_direction = agent.direction - 1 if num_transitions <= 1: - transition_isValid = False + transition_valid = False elif action == RailEnvActions.MOVE_RIGHT: new_direction = agent.direction + 1 if num_transitions <= 1: - transition_isValid = False + transition_valid = False new_direction %= 4 @@ -366,8 +364,8 @@ class RailEnv(Environment): # new_direction will be the only valid transition # - take only available transition new_direction = np.argmax(possible_transitions) - transition_isValid = True - return new_direction, transition_isValid + transition_valid = True + return new_direction, transition_valid def _get_observations(self): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py new file mode 100644 index 0000000000000000000000000000000000000000..79f4bab164312f757ad584bd3708a7af3fb7a97e --- /dev/null +++ b/tests/test_distance_map.py @@ -0,0 +1,56 @@ +import numpy as np + +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv + + +def test_walker(): + # _ _ _ + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + transitions = Grid4Transitions([]) + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + + rail_map = np.array( + [[dead_end_from_east] + [horizontal_straight] + [dead_end_from_west]], dtype=np.uint16) + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, + predictor=ShortestPathPredictorForRailEnv(max_depth=10)), + ) + # reset to initialize agents_static + env.reset() + + # set initial position and direction for testing... + env.agents_static[0].position = (0, 1) + env.agents_static[0].direction = 1 + env.agents_static[0].target = (0, 0) + + # reset to set agents from agents_static + env.reset(False, False) + obs_builder: TreeObsForRailEnv = env.obs_builder + + print(obs_builder.distance_map[(0, *[0, 1], 1)]) + assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3 + print(obs_builder.distance_map[(0, *[0, 2], 3)]) + assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2 # does not work yet, Erik's proposal.