diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 848ac15aab5f9b0c933751da8f278633bb1077b8..e4d693065aec0307bee8eaf3acd1c9e9e9df0e93 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,6 +4,7 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum +from typing import List import msgpack import msgpack_numpy as m @@ -89,6 +90,15 @@ class RailEnv(Environment): For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. """ + alpha = 1.0 + beta = 1.0 + # Epsilon to avoid rounding errors + epsilon = 0.01 + invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty + step_penalty = -1 * alpha + global_reward = 1 * beta + stop_penalty = 0 # penalty for stopping a moving agent + start_penalty = 0 # penalty for starting a stopped agent def __init__(self, width, @@ -156,8 +166,8 @@ class RailEnv(Environment): self.dev_obs_dict = {} self.dev_pred_dict = {} - self.agents = [None] * number_of_agents # live agents - self.agents_static = [None] * number_of_agents # static agent information + self.agents: List[EnvAgent] = [None] * number_of_agents # live agents + self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information self.num_resets = 0 self.action_space = [1] @@ -230,17 +240,17 @@ class RailEnv(Environment): self.height, self.width = self.rail.grid.shape for r in range(self.height): for c in range(self.width): - rcPos = (r, c) - check = self.rail.cell_neighbours_valid(rcPos, True) + rc_pos = (r, c) + check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: - warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) + warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) if replace_agents: agents_hints = None if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( - *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) + *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints)) self.restart_agents() for i_agent in range(self.get_num_agents()): @@ -252,7 +262,7 @@ class RailEnv(Environment): agent.malfunction_data['malfunction'] = 0 - self._agent_malfunction(agent) + self._agent_new_malfunction(i_agent, RailEnvActions.DO_NOTHING) self.num_resets += 1 self._elapsed_steps = 0 @@ -267,52 +277,48 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def _agent_malfunction(self, agent): + def _agent_new_malfunction(self, i_agent, action) -> bool: + """ + Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before). + """ + agent = self.agents[i_agent] + # Decrease counter for next event if agent.malfunction_data['malfunction_rate'] > 0: agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered - if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: - - # If counter has come to zero --> Agent has malfunction - # set next malfunction time and duration of current malfunction - if agent.malfunction_data['next_malfunction'] <= 0: - # Increase number of malfunctions - agent.malfunction_data['nr_malfunctions'] += 1 - - # Next malfunction in number of stops - next_breakdown = int( - np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) - agent.malfunction_data['next_malfunction'] = next_breakdown - - # Duration of current malfunction - num_broken_steps = np.random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - agent.malfunction_data['malfunction'] = num_broken_steps - + # If counter has come to zero --> Agent has malfunction + # set next malfunction time and duration of current malfunction + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \ + agent.malfunction_data['next_malfunction'] <= 0: + # Increase number of malfunctions + agent.malfunction_data['nr_malfunctions'] += 1 + + # Next malfunction in number of stops + next_breakdown = int( + np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown + + # Duration of current malfunction + num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + agent.malfunction_data['malfunction'] = num_broken_steps + + return True + return False + + # TODO refactor to decrease length of this method! def step(self, action_dict_): self._elapsed_steps += 1 - action_dict = action_dict_.copy() - - alpha = 1.0 - beta = 1.0 - # Epsilon to avoid rounding errors - epsilon = 0.01 - invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty - step_penalty = -1 * alpha - global_reward = 1 * beta - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent - # Reset the step rewards self.rewards_dict = dict() for i_agent in range(self.get_num_agents()): self.rewards_dict[i_agent] = 0 if self.dones["__all__"]: - self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} + self.rewards_dict = {i: r + self.global_reward for i, r in self.rewards_dict.items()} info_dict = { 'action_required': {i: False for i in range(self.get_num_agents())}, 'malfunction': {i: 0 for i in range(self.get_num_agents())}, @@ -321,136 +327,132 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict for i_agent in range(self.get_num_agents()): - agent = self.agents[i_agent] - agent.old_direction = agent.direction - agent.old_position = agent.position - - # Check if agent breaks at this step - self._agent_malfunction(agent) if self.dones[i_agent]: # this agent has already completed... continue - # No action has been supplied for this agent - if i_agent not in action_dict: - action_dict[i_agent] = RailEnvActions.DO_NOTHING - - # The train is broken - if agent.malfunction_data['malfunction'] > 0: - # Last step of malfunction --> Agent starts moving again after getting fixed - if agent.malfunction_data['malfunction'] < 2: - agent.malfunction_data['malfunction'] -= 1 - self.agents[i_agent].moving = True - action_dict[i_agent] = RailEnvActions.DO_NOTHING - - else: - agent.malfunction_data['malfunction'] -= 1 - - # Broken agents are stopped - self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - self.agents[i_agent].moving = False - action_dict[i_agent] = RailEnvActions.DO_NOTHING + agent = self.agents[i_agent] + agent.old_direction = agent.direction + agent.old_position = agent.position - # Nothing left to do with broken agent - continue + # No action has been supplied for this agent -> set DO_NOTHING as default + if i_agent not in action_dict_: + action = RailEnvActions.DO_NOTHING + else: + action = action_dict_[i_agent] - if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): - print('ERROR: illegal action=', action_dict[i_agent], + if action < 0 or action > len(RailEnvActions): + print('ERROR: illegal action=', action, 'for agent with index=', i_agent, '"DO NOTHING" will be executed instead') - action_dict[i_agent] = RailEnvActions.DO_NOTHING - - action = action_dict[i_agent] - - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD + action = RailEnvActions.DO_NOTHING - if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data[ - 'position_fraction'] <= epsilon: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += stop_penalty + # Check if agent breaks at this step + new_malfunction = self._agent_new_malfunction(i_agent, action) + + # Is the agent at the beginning of the cell? Then, it can take an action + # Design choice (Erik+Christian): + # as long as we're broken down at the beginning of the cell, we can choose other actions! + if agent.speed_data['position_fraction'] == 0.0: + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty - if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += start_penalty + if not agent.moving and not ( + action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty - # Now perform a movement. - # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) - # store the desired action in `transition_action_on_cellexit' (only if the desired transition is - # allowed! otherwise DO_NOTHING!) - # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the - # position_fraction by the speed of the agent (regardless of action taken, as long as no - # STOP_MOVING, but that makes agent.moving=False) - # If the new position fraction is >= 1, reset to 0, and perform the stored - # transition_action_on_cellexit - - # If the agent can make an action - action_selected = False - if agent.speed_data['position_fraction'] <= epsilon: - if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: - cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ + # Store the action + if agent.moving and action not in [RailEnvActions.DO_NOTHING, RailEnvActions.STOP_MOVING]: + _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = action - action_selected = True - else: # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! - if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: - cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): + _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) if all([new_cell_valid, transition_valid]): agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - action_selected = True - else: - # TODO: an invalid action was chosen after entering the cell. The agent cannot move. - self.rewards_dict[i_agent] += invalid_action_penalty - self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += stop_penalty + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.stop_penalty agent.moving = False - continue + else: - # TODO: an invalid action was chosen after entering the cell. The agent cannot move. - self.rewards_dict[i_agent] += invalid_action_penalty - self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - self.rewards_dict[i_agent] += stop_penalty + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.stop_penalty agent.moving = False - continue - if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0): - agent.speed_data['position_fraction'] += agent.speed_data['speed'] + # if we've just broken in this step, nothing else to do + if new_malfunction: + continue - if agent.speed_data['position_fraction'] >= 1.0: + # The train was broken before... + if agent.malfunction_data['malfunction'] > 0: - # Perform stored action to transition to the next cell as soon as cell is free - cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent) + # Last step of malfunction --> Agent starts moving again after getting fixed + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + self.agents[i_agent].moving = True + action = RailEnvActions.DO_NOTHING - if all([new_cell_valid, transition_valid, cell_free]) and agent.malfunction_data['malfunction'] == 0: - agent.position = new_position - agent.direction = new_direction - agent.speed_data['position_fraction'] = 0.0 - elif not transition_valid or not new_cell_valid: - # If the agent cannot move due to an invalid transition, we set its state to not moving - agent.moving = False + else: + agent.malfunction_data['malfunction'] -= 1 + + # Broken agents are stopped + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + self.agents[i_agent].moving = False + + # Nothing left to do with broken agent + continue + + # Now perform a movement. + # If agent.moving, increment the position_fraction by the speed of the agent + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit if the cell is free. + if agent.moving: + + agent.speed_data['position_fraction'] += agent.speed_data['speed'] + if agent.speed_data['position_fraction'] >= 1.0: + # Perform stored action to transition to the next cell as soon as cell is free + # Notice that we've already check new_cell_valid and transition valid when we stored the action, + # so we only have to check cell_free now! + + # cell and transition validity was checked when we stored transition_action_on_cellexit! + cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( + agent.speed_data['transition_action_on_cellexit'], agent) + + if cell_free: + agent.position = new_position + agent.direction = new_direction + agent.speed_data['position_fraction'] = 0.0 if np.equal(agent.position, agent.target).all(): self.dones[i_agent] = True agent.moving = False else: - self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True - self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()} + self.rewards_dict = {i: 0 * r + self.global_reward for i, r in self.rewards_dict.items()} if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): self.dones["__all__"] = True @@ -458,7 +460,7 @@ class RailEnv(Environment): self.dones[k] = True action_required_agents = { - i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents()) + i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents()) } malfunction_agents = { i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) @@ -474,6 +476,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict def _check_action_on_agent(self, action, agent): + # compute number of possible transitions in the current # cell used to check for invalid actions new_direction, transition_valid = self.check_action(agent, action) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8573c25c31f59d9ddb3d9341d891aa1eae231b8e..5af4a079b1b6d210a395896df5d577c1c7a16267 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,6 +1,6 @@ """Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" import warnings -from typing import Callable, Tuple, Any, Optional +from typing import Callable, Tuple, Optional, Dict, List, Any import msgpack import numpy as np @@ -11,7 +11,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes -RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -560,63 +560,43 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Generate a set of nodes for the sparse network # Try to connect cities to nodes first - node_positions = [] city_positions = [] intersection_positions = [] + # Evenly distribute cities and intersections + node_positions: List[Any] = None + nb_nodes = num_cities + num_intersections if grid_mode: - tot_num_node = num_intersections + num_cities nodes_ratio = height / width - nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio))) - nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) + nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) + nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - city_idx = np.random.choice(np.arange(tot_num_node), num_cities) + city_idx = np.random.choice(np.arange(nb_nodes), num_cities) - for node_idx in range(num_cities + num_intersections): - to_close = True - tries = 0 + node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, + nb_nodes, + nodes_per_row, x_positions, + y_positions) - if not grid_mode: - while to_close: - x_tmp = node_radius + np.random.randint(height - node_radius) - y_tmp = node_radius + np.random.randint(width - node_radius) - to_close = False - - # Check distance to cities - for node_pos in city_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - # Check distance to intersections - for node_pos in intersection_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - if not to_close: - node_positions.append((x_tmp, y_tmp)) - if node_idx < num_cities: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) - tries += 1 - if tries > 100: - warnings.warn("Could not set nodes, please change initial parameters!!!!") - break - else: - x_tmp = x_positions[node_idx % nodes_per_row] - y_tmp = y_positions[node_idx // nodes_per_row] - if node_idx in city_idx: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) - node_positions = city_positions + intersection_positions + + + else: + + node_positions = _generate_node_positions_not_grid_mode(city_positions, height, + intersection_positions, + nb_nodes, width) + + # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode + nb_nodes = len(node_positions) + _num_cities = len(city_positions) + _num_intersections = len(intersection_positions) # Chose node connection # Set up list of available nodes to connect to - available_nodes_full = np.arange(num_cities + num_intersections) - available_cities = np.arange(num_cities) - available_intersections = np.arange(num_cities, num_cities + num_intersections) + available_nodes_full = np.arange(nb_nodes) + available_cities = np.arange(_num_cities) + available_intersections = np.arange(_num_cities, nb_nodes) # Start at some node current_node = np.random.randint(len(available_nodes_full)) @@ -629,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) # Priority city to intersection connections - if current_node < num_cities and len(available_intersections) > 0: + if current_node < _num_cities and len(available_intersections) > 0: available_nodes = available_intersections delete_idx = np.where(available_cities == current_node) available_cities = np.delete(available_cities, delete_idx, 0) # Priority intersection to city connections - elif current_node >= num_cities and len(available_cities) > 0: + elif current_node >= _num_cities and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) available_intersections = np.delete(available_intersections, delete_idx, 0) @@ -669,15 +649,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_stack.pop(0) # Place train stations close to the node - # We currently place them uniformly distirbuted among all cities + # We currently place them uniformly distributed among all cities built_num_trainstation = 0 - train_stations = [[] for i in range(num_cities)] + train_stations = [[] for i in range(_num_cities)] - if num_cities > 1: + if _num_cities > 1: for station in range(num_trainstations): spot_found = True - trainstation_node = int(station / num_trainstations * num_cities) + trainstation_node = int(station / num_trainstations * _num_cities) station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0, @@ -702,6 +682,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if tries > 100: warnings.warn("Could not set trainstations, please change initial parameters!!!!") spot_found = False + break if spot_found: train_stations[trainstation_node].append((station_x, station_y)) @@ -725,7 +706,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # We currently place them uniformly distirbuted among all cities if enhance_intersection: - for intersection in range(num_intersections): + for intersection in range(_num_intersections): intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3), 1, height - 2) @@ -762,7 +743,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Slot availability in node node_available_start = [] node_available_target = [] - for node_idx in range(num_cities): + for node_idx in range(_num_cities): node_available_start.append(len(train_stations[node_idx])) node_available_target.append(len(train_stations[node_idx])) @@ -797,4 +778,57 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 'train_stations': train_stations }} + def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, + width): + + node_positions = [] + for node_idx in range(nb_nodes): + to_close = True + tries = 0 + + while to_close: + x_tmp = node_radius + np.random.randint(height - node_radius) + y_tmp = node_radius + np.random.randint(width - node_radius) + to_close = False + + # Check distance to cities + for node_pos in city_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + # Check distance to intersections + for node_pos in intersection_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + + if not to_close: + node_positions.append((x_tmp, y_tmp)) + if node_idx < num_cities: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn( + "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format( + len(node_positions), + tries, nb_nodes)) + break + + node_positions = city_positions + intersection_positions + return node_positions + + def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, + nodes_per_row, x_positions, y_positions): + for node_idx in range(nb_nodes): + + x_tmp = x_positions[node_idx % nodes_per_row] + y_tmp = y_positions[node_idx // nodes_per_row] + if node_idx in city_idx: + city_positions.append((x_tmp, y_tmp)) + else: + intersection_positions.append((x_tmp, y_tmp)) + node_positions = city_positions + intersection_positions + return node_positions + return generator diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 4645d80aadc3eb247d8b60b1c8456fc250d7feb8..a0e2b995b35bc2d6984bf6274170a28c18cada70 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -55,24 +55,25 @@ def test_rail_env_action_required_info(): obs_builder_object=GlobalObsForRailEnv()) np.random.seed(0) env_only_if_action_required = RailEnv(width=50, - height=50, - rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map - num_intersections=10, - # Number of interesections in map - num_trainstations=50, - # Number of possible start/targets on map - min_node_dist=6, # Minimal distance of nodes - node_radius=3, - # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities - seed=5, # Random seed - grid_mode=False - # Ordered distribution of nodes - ), - schedule_generator=sparse_schedule_generator(speed_ration_map), - number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map + num_intersections=10, + # Number of interesections in map + num_trainstations=50, + # Number of possible start/targets on map + min_node_dist=6, + # Minimal distance of nodes + node_radius=3, + # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities + seed=5, # Random seed + grid_mode=False + # Ordered distribution of nodes + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) for step in range(100): @@ -87,7 +88,8 @@ def test_rail_env_action_required_info(): if step == 0 or info_only_if_action_required['action_required'][a]: action_dict_only_if_action_required.update({a: action}) else: - print("[{}] not action_required {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data)) + print("[{}] not action_required {}, speed_data={}".format(step, a, + env_always_action.agents[a].speed_data)) obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step( action_dict_always_action) @@ -156,3 +158,23 @@ def test_rail_env_malfunction_speed_info(): if done['__all__']: break + + +def test_sparse_generator_with_too_man_cities_does_not_break_down(): + np.random.seed(0) + + RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator( + num_cities=100, # Number of cities in map + num_intersections=10, # Number of interesections in map + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, # Number of connections to other cities + seed=5, # Random seed + grid_mode=False # Ordered distribution of nodes + ), + schedule_generator=sparse_schedule_generator(), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index eaf782df3255ecfc6ebaa7078935f485497ed359..a63e97229457d606027c88e5189d0cb680f25c9b 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -110,3 +110,38 @@ def test_malfunction_process(): # Check that malfunctioning data was standing around assert total_down_time > 0 + + +def test_malfunction_process_statistically(): + """Tests hat malfunctions are produced by stochastic_data!""" + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 2, + 'min_duration': 3, + 'max_duration': 3} + np.random.seed(5) + + env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, + seed=0), + schedule_generator=complex_schedule_generator(), + number_of_agents=2, + obs_builder_object=SingleAgentNavigationObs(), + stochastic_data=stochastic_data) + + env.reset() + nb_malfunction = 0 + for step in range(100): + action_dict = {} + for agent in env.agents: + if agent.malfunction_data['malfunction'] > 0: + nb_malfunction += 1 + # We randomly select an action + action_dict[agent.handle] = np.random.randint(4) + + env.step(action_dict) + + # check that generation of malfunctions works as expected + # results are different in py36 and py37, therefore no exact test on nb_malfunction + assert nb_malfunction > 150 diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 8de36c81e4a13c0b7e7e5e556ad79234503ad31a..86edc08c07552488e72537ec9b1f3b0b7625efed 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -1,8 +1,17 @@ +from typing import List + import numpy as np +from attr import attrib, attrs -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map +from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator +from flatland.utils.rendertools import RenderTool +from flatland.utils.simple_rail import make_simple_rail np.random.seed(1) @@ -86,3 +95,505 @@ def test_multi_speed_init(): if (step + 1) % (i_agent + 1) == 0: print(step, i_agent, env.agents[i_agent].position) old_pos[i_agent] = env.agents[i_agent].position + + +@attrs +class Replay(object): + position = attrib() + direction = attrib() + action = attrib(type=RailEnvActions) + malfunction = attrib(default=0, type=int) + + +@attrs +class TestConfig(object): + replay = attrib(type=List[Replay]) + target = attrib() + speed = attrib(type=float) + + +def test_multispeed_actions_no_malfunction_no_blocking(rendering=True): + """Test that actions are correctly performed on cell exit for a single agent.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_config = TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + + # TODO test penalties! + agentStatic: EnvAgentStatic = env.agents_static[0] + info_dict = { + 'action_required': [True] + } + for i, replay in enumerate(test_config.replay): + if i == 0: + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True) + + +def test_multispeed_actions_no_malfunction_blocking(rendering=True): + """The second agent blocks the first because it is slower.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_configs = [ + TestConfig( + replay=[ + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 5), + direction=Grid4TransitionsEnum.WEST, + action=None + ) + ], + target=(3, 0), # west dead-end + speed=1 / 3), + TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + # blocked although fraction >= 1.0 + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + # blocked although fraction >= 1.0 + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + # blocked although fraction >= 1.0 + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + # not blocked, action required! + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + + ] + + # TODO test penalties! + info_dict = { + 'action_required': [True for _ in test_configs] + } + for step in range(len(test_configs[0].replay)): + if step == 0: + for a, test_config in enumerate(test_configs): + agentStatic: EnvAgentStatic = env.agents_static[a] + replay = test_config.replay[0] + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(a, actual, expected, msg): + assert actual == expected, "[{}] {} {}: actual={}, expected={}".format(step, a, msg, actual, expected) + + action_dict = {} + + for a, test_config in enumerate(test_configs): + agent: EnvAgent = env.agents[a] + replay = test_config.replay[step] + + _assert(a, agent.position, replay.position, 'position') + _assert(a, agent.direction, replay.direction, 'direction') + + + + if replay.action: + assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(step, a, True) + action_dict[a] = replay.action + else: + assert info_dict['action_required'][a] == False, "[{}] agent {} expecting action_required={}".format(step, a, False) + _, _, _, info_dict = env.step(action_dict) + + if rendering: + renderer.render_env(show=True, show_observations=True) + + +def test_multispeed_actions_malfunction_no_blocking(rendering=True): + """Test on a single agent whether action on cell exit work correctly despite malfunction.""" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + + # initialize agents_static + env.reset() + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + + test_config = TestConfig( + replay=[ + Replay( + position=(3, 9), # east dead-end + direction=Grid4TransitionsEnum.EAST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 9), + direction=Grid4TransitionsEnum.EAST, + action=None + ), + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + # add additional step in the cell + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None, + malfunction=2 # recovers in two steps from now! + ), + # agent recovers in this step + Replay( + position=(3, 8), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(3, 7), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_FORWARD, + malfunction=2 # recovers in two steps from now! + ), + # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken! + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=RailEnvActions.MOVE_LEFT, + ), + Replay( + position=(3, 6), + direction=Grid4TransitionsEnum.WEST, + action=None + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.STOP_MOVING + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + Replay( + position=(4, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=None + ), + Replay( + position=(5, 6), + direction=Grid4TransitionsEnum.SOUTH, + action=RailEnvActions.MOVE_FORWARD + ), + + ], + target=(3, 0), # west dead-end + speed=0.5 + ) + + # TODO test penalties! + agentStatic: EnvAgentStatic = env.agents_static[0] + info_dict = { + 'action_required': [True] + } + for i, replay in enumerate(test_config.replay): + if i == 0: + # set the initial position + agentStatic.position = replay.position + agentStatic.direction = replay.direction + agentStatic.target = test_config.target + agentStatic.moving = True + agentStatic.speed_data['speed'] = test_config.speed + + # reset to set agents from agents_static + env.reset(False, False) + + def _assert(actual, expected, msg): + assert actual == expected, "[{}] {}: actual={}, expected={}".format(i, msg, actual, expected) + + agent: EnvAgent = env.agents[0] + + _assert(agent.position, replay.position, 'position') + _assert(agent.direction, replay.direction, 'direction') + + if replay.malfunction: + agent.malfunction_data['malfunction'] = 2 + + if replay.action: + assert info_dict['action_required'][0] == True, "[{}] expecting action_required={}".format(i, True) + _, _, _, info_dict = env.step({0: replay.action}) + + else: + assert info_dict['action_required'][0] == False, "[{}] expecting action_required={}".format(i, False) + _, _, _, info_dict = env.step({}) + + if rendering: + renderer.render_env(show=True, show_observations=True)