Commit f80e77a6 authored by Christian Eichenberger

Merge branch '57-access-resources-through-importlib_resources' into 'master'

57 access resources through importlib resources

See merge request !1
parents 3f36d20a 14893138
Showing changed files with 197 additions and 119 deletions
include AUTHORS.rst
include CONTRIBUTING.rst
include HISTORY.rst
include LICENSE
include README.rst
include requirements_torch_training.txt
include requirements_RLLib_training.txt
recursive-include tests *
recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
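
The requirements files listed above are shipped with the source distribution. The rest of this merge request loads packaged data files (pickled railway scenes, gin configs) through importlib_resources instead of hard-coded paths; a minimal, hypothetical sanity check for one such resource (package and file name taken from the torch_training diff further below) would be:

from importlib_resources import is_resource

# Returns True only if 'complex_scene.pkl' is visible as package data of torch_training.railway,
# i.e. it was actually included when the package was built and installed.
assert is_resource('torch_training.railway', 'complex_scene.pkl')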
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.seed import seed as set_seed

from flatland.envs.generators import complex_rail_generator, random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv


class RailEnvRLLibWrapper(MultiAgentEnv):
@@ -20,24 +20,25 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
        if config['rail_generator'] == "complex_rail_generator":
            self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
                                                         nr_extra=config['nr_extra'],
                                                         seed=config['seed'] * (1 + vector_index))
        elif config['rail_generator'] == "random_rail_generator":
            self.rail_generator = random_rail_generator()
        elif config['rail_generator'] == "load_env":
            self.predefined_env = True
        else:
            raise ValueError(f'Unknown rail generator: {config["rail_generator"]}')

        set_seed(config['seed'] * (1 + vector_index))
        self.env = RailEnv(width=config["width"], height=config["height"],
                           number_of_agents=config["number_of_agents"],
                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
                           prediction_builder_object=config['predictor'])

        if self.predefined_env:
            self.env.load(config['load_env_path'])
            self.env.load_resource('torch_training.railway', config['load_env_path'])

        self.width = self.env.width
        self.height = self.env.height
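
env.load_resource() is introduced by this merge request, but its implementation is not part of this file's diff. A minimal sketch of the behaviour it presumably wraps, using importlib_resources (the helper name and the delegation to env.load() are assumptions):

from importlib_resources import path

def load_env_resource(env, package, resource):
    # Hypothetical helper mirroring what env.load_resource() presumably does: resolve a pickled
    # scene shipped as package data (e.g. 'torch_training.railway' / 'complex_scene.pkl') to a
    # real filesystem path via importlib_resources, then delegate to the existing env.load().
    with path(package, resource) as resource_path:
        env.load(str(resource_path))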
@@ -60,7 +61,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
        o = dict()

        for i_agent in range(len(self.env.agents)):
            if predictions != {}:
                pred_obs = self.get_prediction_as_observation(pred_pos, pred_dir, i_agent)

@@ -71,13 +72,13 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
            o[i_agent] = obs[i_agent]

        # needed for the renderer
        self.rail = self.env.rail
        self.agents = self.env.agents
        self.agents_static = self.env.agents_static
        self.dev_obs_dict = self.env.dev_obs_dict

        if self.step_memory < 2:
            return o
        else:
            self.old_obs = o
            oo = dict()

@@ -121,9 +122,9 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
        for i_agent in range(len(self.env.agents)):
            if i_agent not in self.agents_done:
                oo[i_agent] = [o[i_agent], self.old_obs[i_agent]]

        self.old_obs = o

        for agent, done in dones.items():
            if done and agent != '__all__':
                self.agents_done.append(agent)
@@ -190,8 +191,8 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                elif collision_info[1] == 0:
                    # In this case, the other agent (agent 2) was on the same cell at t-1.
                    # There is a collision if agent 2 is, at t, on the cell where agent 1 was at t-1.
                    coord_agent_1_t_minus_1 = pred_pos[agent_handle, time_offset - 1, 0] + \
                                              1000 * pred_pos[agent_handle, time_offset, 1]
                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                    if coord_agent_1_t_minus_1 == coord_agent_2_t:
                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1

@@ -200,7 +201,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
                    # In this case, the other agent (agent 2) will be on the same cell at t+1.
                    # There is a collision if agent 2 is, at t, on the cell where agent 1 will be at t+1.
                    coord_agent_1_t_plus_1 = pred_pos[agent_handle, time_offset + 1, 0] + \
                                             1000 * pred_pos[agent_handle, time_offset, 1]
                    coord_agent_2_t = coord_other_agents[collision_info[0], 1]
                    if coord_agent_1_t_plus_1 == coord_agent_2_t:
                        pred_obs[time_offset, collision_info[0] + 1 * (collision_info[0] >= agent_handle)] = 1
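
In both branches above, a 2-D cell is packed into a single number as first coordinate + 1000 * second coordinate, so predicted positions can be compared against coord_other_agents with plain scalar equality; this assumes grid dimensions well below 1000. A small illustration of the packing:

# Illustrative only: packing a (row, col)-style pair into one comparable scalar.
row, col = 12, 7
packed = row + 1000 * col  # 7012 uniquely identifies the cell as long as the grid is narrower than 1000
assert packed == 12 + 1000 * 7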
......
import random

import gym
import numpy as np
import ray
import ray.rllib.agents.ppo.ppo as ppo
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print

from RailEnvRLLibWrapper import RailEnvRLLibWrapper
from RLLib_training.custom_preprocessors import CustomPreprocessor
from flatland.envs.generators import complex_rail_generator

# RailEnv.__bases__ = (RailEnv.__bases__[0], MultiAgentEnv)

ModelCatalog.register_custom_preprocessor("my_prep", CustomPreprocessor)
ray.init()


def train(config):
    print('Init Env')
    random.seed(1)
@@ -52,28 +35,10 @@ def train(config):
                              1]  # Case 2b (10) - simple switch mirrored

    # Example generate a random rail
    env_config = {"width": 20,
                  "height": 20,
                  "rail_generator": complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0),
                  "number_of_agents": 5}

    obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
    act_space = gym.spaces.Discrete(4)
@@ -94,13 +59,13 @@ def train(config):
    agent_config["horizon"] = 50
    agent_config["num_workers"] = 0
    # agent_config["sample_batch_size"]: 1000
    # agent_config["num_cpus_per_worker"] = 40
    # agent_config["num_gpus"] = 2.0
    # agent_config["num_gpus_per_worker"] = 2.0
    # agent_config["num_cpus_for_driver"] = 5
    # agent_config["num_envs_per_worker"] = 15
    agent_config["env_config"] = env_config
    # agent_config["batch_mode"] = "complete_episodes"

    ppo_trainer = PPOTrainer(env=RailEnvRLLibWrapper, config=agent_config)

@@ -114,10 +79,5 @@ def train(config):
    # checkpoint = ppo_trainer.save()
    # print("checkpoint saved at", checkpoint)


train({})
import os
import tempfile

import gin
import gym
import ray
from importlib_resources import path

# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
from ray.rllib.models import ModelCatalog
from ray.tune.logger import UnifiedLogger
from ray.tune.logger import pretty_print
from ray import tune
from ray.rllib.utils.seed import seed as set_seed

from RailEnvRLLibWrapper import RailEnvRLLibWrapper
from custom_models import ConvModelGlobalObs
from custom_preprocessors import CustomPreprocessor, ConvModelPreprocessor

from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv, \
    LocalObsForRailEnv, GlobalObsForRailEnvDirectionDependent

gin.external_configurable(DummyPredictorForRailEnv)
gin.external_configurable(TreeObsForRailEnv)
gin.external_configurable(GlobalObsForRailEnv)
@@ -45,7 +40,9 @@ ModelCatalog.register_custom_preprocessor("tree_obs_prep", CustomPreprocessor)
ModelCatalog.register_custom_preprocessor("global_obs_prep", TupleFlatteningPreprocessor)
ModelCatalog.register_custom_preprocessor("conv_obs_prep", ConvModelPreprocessor)
ModelCatalog.register_custom_model("conv_model", ConvModelGlobalObs)

ray.init()  # object_store_memory=150000000000, redis_max_memory=30000000000)

__file_dirname__ = os.path.dirname(os.path.realpath(__file__))


def train(config, reporter):
@@ -67,11 +64,13 @@ def train(config, reporter):
    # Observation space and action space definitions
    if isinstance(config["obs_builder"], TreeObsForRailEnv):
        if config['predictor'] is None:
            obs_space = gym.spaces.Tuple(
                (gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),) * config['step_memory'])
        else:
            obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(147,)),
                                          gym.spaces.Box(low=0, high=1, shape=(config['n_agents'],)),
                                          gym.spaces.Box(low=0, high=1, shape=(20, config['n_agents'])),) * config[
                                             'step_memory'])
        preprocessor = "tree_obs_prep"

    elif isinstance(config["obs_builder"], GlobalObsForRailEnv):

@@ -106,7 +105,6 @@ def train(config, reporter):
    else:
        raise ValueError("Undefined observation space")

    act_space = gym.spaces.Discrete(5)

    # Dict with the different policies to train

@@ -117,7 +115,6 @@ def train(config, reporter):
    def policy_mapping_fn(agent_id):
        return config['policy_folder_name'].format(**locals())
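    # Illustration (hypothetical template; the real one comes from config.gin): with
    # config['policy_folder_name'] = 'ppo_policy_{agent_id}', policy_mapping_fn(0) returns
    # 'ppo_policy_0', i.e. each agent id is formatted into the name of a policy in policy_graphs.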
    # Trainer configuration
    trainer_config = DEFAULT_CONFIG.copy()

    if config['conv_model']:

@@ -126,8 +123,8 @@ def train(config, reporter):
        trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}

    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                    "policy_mapping_fn": policy_mapping_fn,
                                    "policies_to_train": list(policy_graphs.keys())}
    trainer_config["horizon"] = config['horizon']
    trainer_config["num_workers"] = 0
@@ -177,7 +174,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                   map_width, map_height, horizon, policy_folder_name, local_dir, obs_builder,
                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
                   predictor, step_memory):
    tune.run(
        train,
        name=name,

@@ -205,12 +201,15 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
            "cpu": 5,
            "gpu": 0.2
        },
        verbose=2,
        local_dir=local_dir
    )


if __name__ == '__main__':
    gin.external_configurable(tune.grid_search)

    with path('RLLib_training.experiment_configs.experiment_agent_memory', 'config.gin') as f:
        gin.parse_config_file(f)

    dir = os.path.join(__file_dirname__, 'experiment_configs', 'experiment_agent_memory')
    run_experiment(local_dir=dir)
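
This is the core change of the merge request: the gin config is no longer located through a hard-coded absolute path but through importlib_resources, so it resolves wherever the package happens to be installed. The same mechanism can also hand back a packaged resource's contents directly; a minimal sketch of that pattern (package and file names reused from the code above, purely for illustration):

from importlib_resources import read_text

# Reads a text resource shipped inside an installed package without assuming any on-disk layout.
config_text = read_text('RLLib_training.experiment_configs.experiment_agent_memory', 'config.gin')
print(config_text[:80])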
#ray==0.7.0
gym ==0.12.5
opencv-python==4.1.0.25
#tensorflow==1.13.1
lz4==2.1.10
gin-config==0.1.4
\ No newline at end of file
torch==1.1.0
\ No newline at end of file
setup.py 0 → 100644
import os
from setuptools import setup, find_packages
# TODO: setup does not support installation from url, move to requirements*.txt
# TODO: @master as soon as mr is merged on flatland.
os.system('pip install git+https://gitlab.aicrowd.com/flatland/flatland.git@57-access-resources-through-importlib_resources')
install_reqs = []
# TODO: include requirements_RLLib_training.txt
requirements_paths = ['requirements_torch_training.txt'] #, 'requirements_RLLib_training.txt']
for requirements_path in requirements_paths:
with open(requirements_path, 'r') as f:
install_reqs += [
s for s in [
line.strip(' \n') for line in f
] if not s.startswith('#') and s != ''
]
requirements = install_reqs
setup_requirements = install_reqs
test_requirements = install_reqs
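# Illustration of what the parsing loop above yields for a hypothetical requirements file with
# the lines ['#ray==0.7.0', 'gym==0.12.5', '', 'torch==1.1.0']:
# install_reqs becomes ['gym==0.12.5', 'torch==1.1.0'] (comments and blank lines are skipped).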
setup(
author="S.P. Mohanty",
author_email='mohanty@aicrowd.com',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'Natural Language :: English',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
],
description="Multi Agent Reinforcement Learning on Trains",
entry_points={
'console_scripts': [
'flatland=flatland.cli:main',
],
},
install_requires=requirements,
long_description='',
include_package_data=True,
keywords='flatland-baselines',
name='flatland-rl-baselines',
packages=find_packages('.'),
data_files=[],
setup_requires=setup_requirements,
test_suite='tests',
tests_require=test_requirements,
url='https://gitlab.aicrowd.com/flatland/baselines',
version='0.1.1',
zip_safe=False,
)
@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim

from torch_training.model import QNetwork, QNetwork2

BUFFER_SIZE = int(1e5)  # replay buffer size
BATCH_SIZE = 512  # minibatch size
......
import os
import random
from collections import deque

import numpy as np
import torch

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent

random.seed(1)
np.random.seed(1)

__file_dirname__ = os.path.dirname(os.path.realpath(__file__))

# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
transition_probability = [15,  # empty cell - Case 0
@@ -43,7 +47,7 @@ env = RailEnv(width=15,
env = RailEnv(width=10,
              height=20)
env.load_resource('torch_training.railway', "complex_scene.pkl")

"""
env = RailEnv(width=20,
@@ -73,10 +77,11 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load(os.path.join(__file_dirname__, 'Nets', 'avoid_checkpoint15000.pth')))
demo = True
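# Resolving the checkpoint relative to this file via __file_dirname__ (rather than the current
# working directory, as the old './Nets/...' path did) lets the script be launched from any
# directory, in the same spirit as the importlib_resources change used for the pickled scenes.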

def max_lt(seq, val):
    """
    Return greatest item in seq for which item < val applies.
@@ -133,7 +138,8 @@ for trials in range(1, n_trials + 1):
        final_obs_next = obs.copy()

    for a in range(env.get_num_agents()):
        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7,
                                                                current_depth=0)
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        agent_data = np.clip(agent_data, -1, 1)

@@ -154,7 +160,7 @@ for trials in range(1, n_trials + 1):
    # Action
    for a in range(env.get_num_agents()):
        if demo:
            eps = 1
        # action = agent.act(np.array(obs[a]), eps=eps)
        action = agent.act(agent_obs[a], eps=eps)
        action_prob[action] += 1

@@ -164,7 +170,7 @@ for trials in range(1, n_trials + 1):
    next_obs, all_rewards, done, _ = env.step(action_dict)

    for a in range(env.get_num_agents()):
        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7,
                                                                current_depth=0)
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        agent_data = np.clip(agent_data, -1, 1)

@@ -197,12 +203,13 @@ for trials in range(1, n_trials + 1):
    scores.append(np.mean(scores_window))
    dones_list.append((np.mean(done_window)))

    print(
        '\rTraining {} Agents.\t Episode {}\t Average Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
            env.get_num_agents(),
            trials,
            np.mean(scores_window),
            100 * np.mean(done_window),
            eps, action_prob / np.sum(action_prob)), end=" ")

    if trials % 100 == 0:
        print(

@@ -214,5 +221,5 @@ for trials in range(1, n_trials + 1):
            eps,
            action_prob / np.sum(action_prob)))
        torch.save(agent.qnetwork_local.state_dict(),
                   os.path.join(__file_dirname__, 'Nets', 'avoid_checkpoint' + str(trials) + '.pth'))
        action_prob = [1] * action_size
tox.ini 0 → 100644
[tox]
; TODO py36, flake8
envlist = py37
[travis]
python =
; TODO: py36
3.7: py37
[testenv]
whitelist_externals = sh
pip
python
setenv =
PYTHONPATH = {toxinidir}
passenv =
DISPLAY
; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
HTTP_PROXY
HTTPS_PROXY
deps =
-r{toxinidir}/requirements_torch_training.txt
commands =
python torch_training/training_navigation.py
[flake8]
max-line-length = 120
ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W505
[testenv:flake8]
basepython = python
passenv = DISPLAY
deps =
-r{toxinidir}/requirements_torch_training.txt
commands =
flake8 torch_training