Skip to content
Snippets Groups Projects
Commit 9ce6b221 authored by Erik Nygren's avatar Erik Nygren
Browse files

using new level generator for training and inference

parent b8e41dc1
No related branches found
No related tags found
No related merge requests found
...@@ -9,47 +9,59 @@ from predictors.predictions import ShortestPathPredictorForRailEnv ...@@ -9,47 +9,59 @@ from predictors.predictions import ShortestPathPredictorForRailEnv
import torch_training.Nets import torch_training.Nets
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator
from flatland.envs.schedule_generators import schedule_from_file from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation from utils.observation_utils import normalize_observation
random.seed(3) random.seed(3)
np.random.seed(2) np.random.seed(2)
# Parameters for the Environment
tree_depth = 3 x_dim = 20
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv(10)) y_dim = 20
n_agents = 5
file_name = "./railway/simple_avoid.pkl" tree_depth = 2
env = RailEnv(width=10,
height=20, # Use a the malfunction generator to break agents from time to time
rail_generator=rail_from_file(file_name), stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
schedule_generator=schedule_from_file(file_name), 'malfunction_rate': 30, # Rate of malfunction occurence
obs_builder_object=observation_helper) 'min_duration': 3, # Minimal duration of malfunction
x_dim = env.width 'max_duration': 20 # Max duration of malfunction
y_dim = env.height }
""" # Custom observation builder
predictor = ShortestPathPredictorForRailEnv()
x_dim = 10 # np.random.randint(8, 20) observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
y_dim = 10 # np.random.randint(8, 20)
n_agents = 5 # np.random.randint(3, 8) # Different agent types (trains) with different speeds.
n_goals = n_agents + np.random.randint(0, 3) speed_ration_map = {1.: 0.25, # Fast passenger train
min_dist = int(0.75 * min(x_dim, y_dim)) 1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=sparse_rail_generator(num_cities=5,
max_dist=99999, # Number of cities in map (where train stations are)
seed=0), num_intersections=4,
schedule_generator=complex_schedule_generator(), # Number of intersections (no start / target)
obs_builder_object=observation_helper, num_trainstations=10, # Number of possible start/targets on map
number_of_agents=n_agents) min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
env.reset(True, True) env.reset(True, True)
"""
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles() handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim num_features_per_node = env.obs_builder.observation_dim
......
...@@ -14,9 +14,9 @@ import torch_training.Nets ...@@ -14,9 +14,9 @@ import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
# Import Flatland/ Observations and Predictors # Import Flatland/ Observations and Predictors
from flatland.envs.schedule_generators import complex_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation from utils.observation_utils import normalize_observation
...@@ -36,30 +36,55 @@ def main(argv): ...@@ -36,30 +36,55 @@ def main(argv):
np.random.seed(1) np.random.seed(1)
# Initialize a random map with a random number of agents # Initialize a random map with a random number of agents
x_dim = np.random.randint(8, 15)
y_dim = np.random.randint(8, 15)
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
tree_depth = 3
print("main2")
""" """
Get an observation builder and predictor: Get an observation builder and predictor:
The predictor will always predict the shortest path from the current location of the agent. The predictor will always predict the shortest path from the current location of the agent.
This is used to warn for potential conflicts --> Should be enhanced to get better performance! This is used to warn for potential conflicts --> Should be enhanced to get better performance!
""" """
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 5
tree_depth = 2
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder
predictor = ShortestPathPredictorForRailEnv() predictor = ShortestPathPredictorForRailEnv()
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor) observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=sparse_rail_generator(num_cities=5,
max_dist=99999, # Number of cities in map (where train stations are)
seed=0), num_intersections=4,
schedule_generator=complex_schedule_generator(), # Number of intersections (no start / target)
obs_builder_object=observation_helper, num_trainstations=10, # Number of possible start/targets on map
number_of_agents=n_agents) min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
env.reset(True, True) env.reset(True, True)
handle = env.get_agent_handles() handle = env.get_agent_handles()
...@@ -105,19 +130,26 @@ def main(argv): ...@@ -105,19 +130,26 @@ def main(argv):
and the size of the levels every 50 episodes. and the size of the levels every 50 episodes.
""" """
if episodes % 50 == 1: if episodes % 50 == 1:
x_dim = np.random.randint(8, 15)
y_dim = np.random.randint(8, 15)
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=sparse_rail_generator(num_cities=5,
max_dist=99999, # Number of cities in map (where train stations are)
seed=0), num_intersections=4,
schedule_generator=complex_schedule_generator(), # Number of intersections (no start / target)
obs_builder_object=observation_helper, num_trainstations=10,
number_of_agents=n_agents) # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
# Adjust the parameters according to the new env. # Adjust the parameters according to the new env.
max_steps = int((env.height + env.width)) max_steps = int((env.height + env.width))
......
...@@ -8,6 +8,76 @@ from flatland.core.env_prediction_builder import PredictionBuilder ...@@ -8,6 +8,76 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env import RailEnvActions
class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and, at
    every step, takes the first valid action among forward, left and right.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        -------
        custom_args: dict
            Not used in this dummy implementation.
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            When None, predictions are computed for every agent.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent an array of
            (max_depth + 1) x 5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        -------
        Exception
            If an agent that has not reached its target has no valid
            forward/left/right move.
        """
        # Compare against None explicitly: handle 0 is a valid agent handle,
        # but it is falsy and would otherwise select every agent.
        if handle is not None:
            agents = [self.env.agents[handle]]
        else:
            agents = self.env.agents

        # Actions are tried in this order; the first valid one is taken.
        action_priorities = (RailEnvActions.MOVE_FORWARD,
                             RailEnvActions.MOVE_LEFT,
                             RailEnvActions.MOVE_RIGHT)

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction

            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The walk below moves the live agent object, so its state is
            # restored in `finally` even if the walk raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction,
                                             RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the
                            # new_direction that was performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # None of the prioritized actions was valid.
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction

        return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder): class ShortestPathPredictorForRailEnv(PredictionBuilder):
""" """
ShortestPathPredictorForRailEnv object. ShortestPathPredictorForRailEnv object.
...@@ -16,7 +86,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -16,7 +86,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action. The prediction acts as if no other agent is in the environment and always takes the forward action.
""" """
def __init__(self, max_depth): def __init__(self, max_depth=20):
# Initialize with depth 20
self.max_depth = max_depth self.max_depth = max_depth
def get(self, custom_args=None, handle=None): def get(self, custom_args=None, handle=None):
...@@ -53,10 +124,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -53,10 +124,13 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for agent in agents: for agent in agents:
_agent_initial_position = agent.position _agent_initial_position = agent.position
_agent_initial_direction = agent.direction _agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set() visited = set()
for index in range(1, self.max_depth + 1): for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving... # if we're at the target, stop moving...
if agent.position == agent.target: if agent.position == agent.target:
...@@ -70,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -70,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Take shortest possible path # Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = None
if np.sum(cell_transitions) == 1:
new_direction = np.argmax(cell_transitions) new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1: elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf min_dist = np.inf
no_dist_found = True no_dist_found = True
for direction in range(4): for direction in range(4):
...@@ -87,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -87,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_direction = direction new_direction = direction
no_dist_found = False no_dist_found = False
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
else: elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions)) raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction # update the agent's position and direction
......
...@@ -36,9 +36,6 @@ def main(argv): ...@@ -36,9 +36,6 @@ def main(argv):
n_goals = 5 n_goals = 5
min_dist = 5 min_dist = 5
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use a the malfunction generator to break agents from time to time # Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence 'malfunction_rate': 30, # Rate of malfunction occurence
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment