From 9ce6b22101a3c903fcc20dcb06b5343bbaf5986a Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 09:24:44 -0400 Subject: [PATCH] using new level generator for training and inference --- torch_training/multi_agent_inference.py | 72 +++++++++++--------- torch_training/multi_agent_training.py | 84 ++++++++++++++++------- torch_training/predictors/predictions.py | 86 ++++++++++++++++++++++-- torch_training/training_navigation.py | 3 - 4 files changed, 179 insertions(+), 66 deletions(-) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 3fc6468..8c1cbd0 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -9,47 +9,59 @@ from predictors.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import rail_from_file -from flatland.envs.schedule_generators import schedule_from_file +from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator +from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator + from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) - -tree_depth = 3 -observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv(10)) - -file_name = "./railway/simple_avoid.pkl" -env = RailEnv(width=10, - height=20, - rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), - obs_builder_object=observation_helper) -x_dim = env.width -y_dim = env.height - -""" - -x_dim = 10 # np.random.randint(8, 20) -y_dim = 10 # np.random.randint(8, 20) -n_agents = 5 # np.random.randint(3, 8) -n_goals = n_agents + np.random.randint(0, 3) -min_dist = int(0.75 * min(x_dim, y_dim)) +# Parameters for the Environment +x_dim = 20 +y_dim = 20 +n_agents = 5 +tree_depth = 2 + +# Use a the malfunction generator to break agents from time to time +stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + +# Custom observation builder +predictor = ShortestPathPredictorForRailEnv() +observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor) + +# Different agent types (trains) with different speeds. +speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, - max_dist=99999, - seed=0), - schedule_generator=complex_schedule_generator(), - obs_builder_object=observation_helper, - number_of_agents=n_agents) + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=10, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=2, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=n_agents, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=observation_helper) env.reset(True, True) -""" - env_renderer = RenderTool(env, gl="PILSVG", ) handle = env.get_agent_handles() num_features_per_node = env.obs_builder.observation_dim diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 7da12e5..4822704 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -14,9 +14,9 @@ import torch_training.Nets from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator +from flatland.envs.rail_generators import sparse_rail_generator # Import Flatland/ Observations and Predictors -from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.envs.schedule_generators import sparse_schedule_generator from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -36,30 +36,55 @@ def main(argv): np.random.seed(1) # Initialize a random map with a random number of agents - x_dim = np.random.randint(8, 15) - y_dim = np.random.randint(8, 15) - n_agents = np.random.randint(3, 8) - n_goals = n_agents + np.random.randint(0, 3) - min_dist = int(0.75 * min(x_dim, y_dim)) - tree_depth = 3 - print("main2") """ Get an observation builder and predictor: The predictor will always predict the shortest path from the current location of the agent. This is used to warn for potential conflicts --> Should be enhanced to get better performance! """ + + # Parameters for the Environment + x_dim = 20 + y_dim = 20 + n_agents = 5 + tree_depth = 2 + + # Use a the malfunction generator to break agents from time to time + stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + # Custom observation builder predictor = ShortestPathPredictorForRailEnv() observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor) + # Different agent types (trains) with different speeds. + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, - max_dist=99999, - seed=0), - schedule_generator=complex_schedule_generator(), - obs_builder_object=observation_helper, - number_of_agents=n_agents) + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=10, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=2, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=n_agents, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=observation_helper) env.reset(True, True) handle = env.get_agent_handles() @@ -105,19 +130,26 @@ def main(argv): and the size of the levels every 50 episodes. """ if episodes % 50 == 1: - x_dim = np.random.randint(8, 15) - y_dim = np.random.randint(8, 15) - n_agents = np.random.randint(3, 8) - n_goals = n_agents + np.random.randint(0, 3) - min_dist = int(0.75 * min(x_dim, y_dim)) env = RailEnv(width=x_dim, height=y_dim, - rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, - max_dist=99999, - seed=0), - schedule_generator=complex_schedule_generator(), - obs_builder_object=observation_helper, - number_of_agents=n_agents) + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=10, + # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=2, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=n_agents, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=observation_helper) # Adjust the parameters according to the new env. max_steps = int((env.height + env.width)) diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py index 10abcef..4718ad9 100644 --- a/torch_training/predictors/predictions.py +++ b/torch_training/predictors/predictions.py @@ -8,6 +8,76 @@ from flatland.core.env_prediction_builder import PredictionBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.rail_env import RailEnvActions + +class DummyPredictorForRailEnv(PredictionBuilder): + """ + DummyPredictorForRailEnv object. + + This object returns predictions for agents in the RailEnv environment. + The prediction acts as if no other agent is in the environment and always takes the forward action. + """ + + def get(self, custom_args=None, handle=None): + """ + Called whenever get_many in the observation build is called. + + Parameters + ------- + custom_args: dict + Not used in this dummy implementation. + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + np.array + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: + - time_offset + - position axis 0 + - position axis 1 + - direction + - action taken to come here + The prediction at 0 is the current position, direction etc. + + """ + agents = self.env.agents + if handle: + agents = [self.env.agents[handle]] + + prediction_dict = {} + + for agent in agents: + action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] + _agent_initial_position = agent.position + _agent_initial_direction = agent.direction + prediction = np.zeros(shape=(self.max_depth + 1, 5)) + prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + for index in range(1, self.max_depth + 1): + action_done = False + # if we're at the target, stop moving... + if agent.position == agent.target: + prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] + + continue + for action in action_priorities: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self.env._check_action_on_agent(action, agent) + if all([new_cell_isValid, transition_isValid]): + # move and change direction to face the new_direction that was + # performed + agent.position = new_position + agent.direction = new_direction + prediction[index] = [index, *new_position, new_direction, action] + action_done = True + break + if not action_done: + raise Exception("Cannot move further. Something is wrong") + prediction_dict[agent.handle] = prediction + agent.position = _agent_initial_position + agent.direction = _agent_initial_direction + return prediction_dict + + class ShortestPathPredictorForRailEnv(PredictionBuilder): """ ShortestPathPredictorForRailEnv object. @@ -16,7 +86,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): The prediction acts as if no other agent is in the environment and always takes the forward action. """ - def __init__(self, max_depth): + def __init__(self, max_depth=20): + # Initialize with depth 20 self.max_depth = max_depth def get(self, custom_args=None, handle=None): @@ -53,10 +124,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): for agent in agents: _agent_initial_position = agent.position _agent_initial_direction = agent.direction + agent_speed = agent.speed_data["speed"] + times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + new_direction = _agent_initial_direction + new_position = _agent_initial_position visited = set() - for index in range(1, self.max_depth + 1): # if we're at the target, stop moving... if agent.position == agent.target: @@ -70,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Take shortest possible path cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - new_position = None - new_direction = None - if np.sum(cell_transitions) == 1: + if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: new_direction = np.argmax(cell_transitions) new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1: + elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: min_dist = np.inf no_dist_found = True for direction in range(4): @@ -87,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_direction = direction no_dist_found = False new_position = get_new_position(agent.position, new_direction) - else: + elif index % times_per_cell == 0: raise Exception("No transition possible {}".format(cell_transitions)) # update the agent's position and direction diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 3110590..87ec597 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -36,9 +36,6 @@ def main(argv): n_goals = 5 min_dist = 5 - # We are training an Agent using the Tree Observation with depth 2 - observation_builder = TreeObsForRailEnv(max_depth=2) - # Use a the malfunction generator to break agents from time to time stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence -- GitLab