Skip to content
Snippets Groups Projects
Commit 44844ffa authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added shortest path predictor to multi agent training and inference

parent 1e8ee7fd
No related branches found
No related tags found
No related merge requests found
...@@ -30,7 +30,7 @@ y_dim = env.height ...@@ -30,7 +30,7 @@ y_dim = env.height
# Parameters for the Environment # Parameters for the Environment
x_dim = 25 x_dim = 25
y_dim = 25 y_dim = 25
n_agents = 1 n_agents = 10
# We are training an Agent using the Tree Observation with depth 2 # We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2) observation_builder = TreeObsForRailEnv(max_depth=2)
...@@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents ...@@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents
} }
# Custom observation builder # Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Different agent types (trains) with different speeds. # Different agent types (trains) with different speeds.
speed_ration_map = {1.: 1., # Fast passenger train speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.0, # Fast freight train 1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.0, # Slow commuter train 1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.0} # Slow freight train 1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
...@@ -93,7 +93,7 @@ action_prob = [0] * action_size ...@@ -93,7 +93,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents() agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size) agent = Agent(state_size, action_size)
with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in: with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in)) agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False record_images = False
......
...@@ -14,6 +14,8 @@ import torch ...@@ -14,6 +14,8 @@ import torch
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
...@@ -48,7 +50,7 @@ def main(argv): ...@@ -48,7 +50,7 @@ def main(argv):
} }
# Custom observation builder # Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2) TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Different agent types (trains) with different speeds. # Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train speed_ration_map = {1.: 0.25, # Fast passenger train
......
...@@ -4,7 +4,6 @@ from collections import deque ...@@ -4,7 +4,6 @@ from collections import deque
import numpy as np import numpy as np
import torch import torch
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
...@@ -67,7 +66,6 @@ env = RailEnv(width=x_dim, ...@@ -67,7 +66,6 @@ env = RailEnv(width=x_dim,
obs_builder_object=TreeObservation) obs_builder_object=TreeObservation)
env.reset(True, True) env.reset(True, True)
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
num_features_per_node = env.obs_builder.observation_dim num_features_per_node = env.obs_builder.observation_dim
......
...@@ -13,8 +13,8 @@ import numpy as np ...@@ -13,8 +13,8 @@ import numpy as np
import torch import torch
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from flatland.envs.observations import TreeObsForRailEnv flatland.envs.rail_env
from flatland.envs.rail_env import RailEnv import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment