Skip to content
Snippets Groups Projects
Commit b8e41dc1 authored by Erik Nygren's avatar Erik Nygren
Browse files

using new level generator for training and inference

parent d5ea50cd
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@ double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)
......
......@@ -9,8 +9,8 @@ import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree
......@@ -27,20 +27,51 @@ x_dim = env.width
y_dim = env.height
"""
x_dim = np.random.randint(8, 20)
y_dim = np.random.randint(8, 20)
n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 1
n_goals = 5
min_dist = 5
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=n_agents)
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
env.reset(True, True)
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
......
......@@ -10,8 +10,8 @@ from dueling_double_dqn import Agent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from utils.observation_utils import norm_obs_clip, split_tree
......@@ -30,8 +30,8 @@ def main(argv):
np.random.seed(1)
# Parameters for the Environment
x_dim = 10
y_dim = 10
x_dim = 20
y_dim = 20
n_agents = 1
n_goals = 5
min_dist = 5
......@@ -39,15 +39,41 @@ def main(argv):
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Load the Environment
# Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_builder,
number_of_agents=n_agents)
rail_generator=sparse_rail_generator(num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
env.reset(True, True)
# After training we want to render the results so we also load a renderer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment