Skip to content
Snippets Groups Projects
Commit 6bc8fcd3 authored by Christian Eichenberger :badminton:
Browse files

Merge branch '141-different-agent-classes' into 'master'

#141 different agent classes

See merge request !4
parents 9a7c3401 9a7e2fc1
No related branches found
No related tags found
1 merge request: !4 "#141 different agent classes"
...@@ -2,8 +2,9 @@ import numpy as np ...@@ -2,8 +2,9 @@ import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.seed import seed as set_seed from ray.rllib.utils.seed import seed as set_seed
from flatland.envs.generators import complex_rail_generator, random_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
class RailEnvRLLibWrapper(MultiAgentEnv): class RailEnvRLLibWrapper(MultiAgentEnv):
...@@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -25,19 +26,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
min_dist=config['min_dist'], min_dist=config['min_dist'],
nr_extra=config['nr_extra'], nr_extra=config['nr_extra'],
seed=config['seed'] * (1 + vector_index)) seed=config['seed'] * (1 + vector_index))
self.schedule_generator = complex_schedule_generator()
elif config['rail_generator'] == "random_rail_generator": elif config['rail_generator'] == "random_rail_generator":
self.rail_generator = random_rail_generator() self.rail_generator = random_rail_generator()
self.schedule_generator = random_schedule_generator()
elif config['rail_generator'] == "load_env": elif config['rail_generator'] == "load_env":
self.predefined_env = True self.predefined_env = True
self.rail_generator = random_rail_generator() self.rail_generator = random_rail_generator()
self.schedule_generator = random_schedule_generator()
else: else:
raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}') raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}')
set_seed(config['seed'] * (1 + vector_index)) set_seed(config['seed'] * (1 + vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], self.env = RailEnv(width=config["width"], height=config["height"],
number_of_agents=config["number_of_agents"], number_of_agents=config["number_of_agents"],
obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) obs_builder_object=config['obs_builder'],
rail_generator=self.rail_generator,
schedule_generator=self.schedule_generator
)
if self.predefined_env: if self.predefined_env:
self.env.load_resource('torch_training.railway', 'complex_scene.pkl') self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
......
import random import random
import time import time
from collections import deque
import numpy as np import numpy as np
from flatland.envs.generators import complex_rail_generator, rail_from_file
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
from utils.observation_utils import norm_obs_clip, split_tree, max_lt from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
# Time factor to test the max time allowed for an env. # Time factor to test the max time allowed for an env.
max_time_factor = 1 max_time_factor = 1
def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'): def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
""" """
Call in a loop to create terminal progress bar Call in a loop to create terminal progress bar
...@@ -54,7 +54,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No ...@@ -54,7 +54,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No
random.seed(parameters[3]) random.seed(parameters[3])
np.random.seed(parameters[3]) np.random.seed(parameters[3])
printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
for trial in range(nr_trials_per_test): for trial in range(nr_trials_per_test):
# Reset the env # Reset the env
...@@ -73,7 +72,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No ...@@ -73,7 +72,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
obs[a] = observation_wrapper(obs[a]) obs[a] = observation_wrapper(obs[a])
# Run episode # Run episode
trial_score = 0 trial_score = 0
max_steps = int(max_time_factor * (env.height + env.width)) max_steps = int(max_time_factor * (env.height + env.width))
...@@ -115,6 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100): ...@@ -115,6 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100):
rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=parameters[3]), seed=parameters[3]),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
number_of_agents=parameters[2]) number_of_agents=parameters[2])
printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
...@@ -151,6 +150,7 @@ def render_test(parameters, test_nr=0, nr_examples=5): ...@@ -151,6 +150,7 @@ def render_test(parameters, test_nr=0, nr_examples=5):
env_renderer.close_window() env_renderer.close_window()
return return
def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3):
# Parameter initialization # Parameter initialization
features_per_node = 9 features_per_node = 9
...@@ -168,7 +168,6 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): ...@@ -168,7 +168,6 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3):
random.seed(parameters[3]) random.seed(parameters[3])
np.random.seed(parameters[3]) np.random.seed(parameters[3])
printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20)
for trial in range(nr_trials_per_test): for trial in range(nr_trials_per_test):
# Reset the env # Reset the env
...@@ -177,7 +176,8 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): ...@@ -177,7 +176,8 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3):
env = RailEnv(width=3, env = RailEnv(width=3,
height=3, height=3,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth,
predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=1, number_of_agents=1,
) )
...@@ -185,7 +185,7 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): ...@@ -185,7 +185,7 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3):
done = env.dones done = env.dones
# Run episode # Run episode
trial_score = 0 trial_score = 0
max_steps = int(max_time_factor* (env.height + env.width)) max_steps = int(max_time_factor * (env.height + env.width))
for step in range(max_steps): for step in range(max_steps):
# Action # Action
......
from sequential_agent.simple_order_agent import OrderedAgent import numpy as np
from flatland.envs.generators import rail_from_file, complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
import numpy as np from sequential_agent.simple_order_agent import OrderedAgent
np.random.seed(2) np.random.seed(2)
""" """
...@@ -29,6 +31,7 @@ env = RailEnv(width=x_dim, ...@@ -29,6 +31,7 @@ env = RailEnv(width=x_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
......
...@@ -3,14 +3,14 @@ from collections import deque ...@@ -3,14 +3,14 @@ from collections import deque
import numpy as np import numpy as np
import torch import torch
from flatland.envs.generators import rail_from_file
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from importlib_resources import path from importlib_resources import path
from observation_builders.observations import TreeObsForRailEnv from observation_builders.observations import TreeObsForRailEnv
from predictors.predictions import ShortestPathPredictorForRailEnv from predictors.predictions import ShortestPathPredictorForRailEnv
import torch_training.Nets import torch_training.Nets
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation from utils.observation_utils import normalize_observation
...@@ -41,6 +41,7 @@ env = RailEnv(width=x_dim, ...@@ -41,6 +41,7 @@ env = RailEnv(width=x_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper, obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
......
...@@ -7,15 +7,16 @@ from collections import deque ...@@ -7,15 +7,16 @@ from collections import deque
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
# Import Flatland/ Observations and Predictors
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from importlib_resources import path from importlib_resources import path
# Import Torch and utility functions to normalize observation # Import Torch and utility functions to normalize observation
import torch_training.Nets import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
# Import Flatland/ Observations and Predictors
from flatland.envs.schedule_generators import complex_schedule_generator
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation from utils.observation_utils import normalize_observation
...@@ -56,6 +57,7 @@ def main(argv): ...@@ -56,6 +57,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper, obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
...@@ -113,6 +115,7 @@ def main(argv): ...@@ -113,6 +115,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper, obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
......
...@@ -7,15 +7,16 @@ from collections import deque ...@@ -7,15 +7,16 @@ from collections import deque
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
# Import Flatland/ Observations and Predictors
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from importlib_resources import path from importlib_resources import path
# Import Torch and utility functions to normalize observation # Import Torch and utility functions to normalize observation
import torch_training.Nets import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
# Import Flatland/ Observations and Predictors
from flatland.envs.schedule_generators import complex_schedule_generator
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree
...@@ -52,6 +53,7 @@ def main(argv): ...@@ -52,6 +53,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper, obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
...@@ -109,6 +111,7 @@ def main(argv): ...@@ -109,6 +111,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=3, obs_builder_object=TreeObsForRailEnv(max_depth=3,
predictor=ShortestPathPredictorForRailEnv()), predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=n_agents) number_of_agents=n_agents)
......
...@@ -3,14 +3,15 @@ from collections import deque ...@@ -3,14 +3,15 @@ from collections import deque
import numpy as np import numpy as np
import torch import torch
from flatland.envs.generators import complex_rail_generator from importlib_resources import path
import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from importlib_resources import path
import torch_training.Nets
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree
...@@ -37,6 +38,7 @@ env = RailEnv(width=x_dim, ...@@ -37,6 +38,7 @@ env = RailEnv(width=x_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
......
...@@ -7,11 +7,12 @@ import matplotlib.pyplot as plt ...@@ -7,11 +7,12 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
from dueling_double_dqn import Agent from dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree
...@@ -44,6 +45,7 @@ def main(argv): ...@@ -44,6 +45,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_builder, obs_builder_object=observation_builder,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
......
...@@ -3,15 +3,16 @@ import time ...@@ -3,15 +3,16 @@ import time
from collections import deque from collections import deque
import numpy as np import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from line_profiler import LineProfiler from line_profiler import LineProfiler
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree
def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '*'): def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
""" """
Call in a loop to create terminal progress bar Call in a loop to create terminal progress bar
@params: @params:
...@@ -31,13 +32,14 @@ def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, ...@@ -31,13 +32,14 @@ def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1,
if iteration == total: if iteration == total:
print('') print('')
class RandomAgent: class RandomAgent:
def __init__(self, state_size, action_size): def __init__(self, state_size, action_size):
self.state_size = state_size self.state_size = state_size
self.action_size = action_size self.action_size = action_size
def act(self, state, eps = 0): def act(self, state, eps=0):
""" """
:param state: input is the observation of the agent :param state: input is the observation of the agent
:return: returns an action :return: returns an action
...@@ -87,6 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): ...@@ -87,6 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=parameters[3]), seed=parameters[3]),
schedule_generator=complex_schedule_generator(),
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=GlobalObsForRailEnv(),
number_of_agents=parameters[2]) number_of_agents=parameters[2])
max_steps = int(3 * (env.height + env.width)) max_steps = int(3 * (env.height + env.width))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment