From 60b730d019313dade48abc04a784bff3fb3e37ef Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 9 Jul 2019 11:50:19 +0200
Subject: [PATCH] #42 run baselines in ci

---
 requirements_torch_training.txt        |  2 ++
 torch_training/multi_agent_training.py |  2 +-
 torch_training/training_navigation.py  | 21 ++++++++-------------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/requirements_torch_training.txt b/requirements_torch_training.txt
index 03f6383..d8c4a46 100644
--- a/requirements_torch_training.txt
+++ b/requirements_torch_training.txt
@@ -1,2 +1,4 @@
 git+https://gitlab.aicrowd.com/flatland/flatland.git@master
+importlib-metadata>=0.17
+importlib_resources>=1.0.2
 torch>=1.1.0
\ No newline at end of file
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ba9f144..0037f95 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,11 +1,11 @@
 from collections import deque
-from sys import path
 
 import matplotlib.pyplot as plt
 import numpy as np
 import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 1857b67..c52a891 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,12 +1,11 @@
-from sys import path
-
-import random
 from collections import deque
 
 import matplotlib.pyplot as plt
 import numpy as np
+import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
@@ -14,7 +13,6 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -70,7 +68,7 @@ file_load = False
 """
 
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-env_renderer = RenderTool(env, gl="PILSVG",)
+env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 features_per_node = 9
 state_size = features_per_node * 85 * 2
@@ -94,11 +92,9 @@
 agent = Agent(state_size, action_size, "FC", 0)
 with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
-
 demo = True
 record_images = False
-
 
 for trials in range(1, n_trials + 1):
 
     if trials % 50 == 0 and not demo:
@@ -136,7 +132,7 @@ for trials in range(1, n_trials + 1):
             agent_data = np.clip(agent_data, -1, 1)
             obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
             agent_data = env.agents[a]
-            speed = 1 #np.random.randint(1,5)
+            speed = 1  # np.random.randint(1,5)
             agent_data.speed_data['speed'] = 1. / speed
 
     for i in range(2):
@@ -145,7 +141,6 @@ for trials in range(1, n_trials + 1):
     for a in range(env.get_num_agents()):
         agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-
     score = 0
     env_done = 0
 
     # Run episode
@@ -206,10 +201,10 @@ for trials in range(1, n_trials + 1):
     print(
         '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
             env.get_num_agents(), x_dim, y_dim,
-        trials,
-        np.mean(scores_window),
-        100 * np.mean(done_window),
-        eps, action_prob / np.sum(action_prob)), end=" ")
+            trials,
+            np.mean(scores_window),
+            100 * np.mean(done_window),
+            eps, action_prob / np.sum(action_prob)), end=" ")
 
     if trials % 100 == 0:
         print(
-- 
GitLab