diff --git a/requirements_torch_training.txt b/requirements_torch_training.txt
index 03f63837b289808620da54561c631649b1d88d55..d8c4a46b356cde45be5563ff77cac2042cfa3ff1 100644
--- a/requirements_torch_training.txt
+++ b/requirements_torch_training.txt
@@ -1,2 +1,4 @@
 git+https://gitlab.aicrowd.com/flatland/flatland.git@master
+importlib-metadata>=0.17
+importlib_resources>=1.0.2
 torch>=1.1.0
\ No newline at end of file
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ba9f144ac2edbebdd24706129bb339f1614b3c67..0037f95697ed9ec6099e597211b5aacd31ede567 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,11 +1,11 @@
 from collections import deque
-from sys import path
 
 import matplotlib.pyplot as plt
 import numpy as np
 import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 1857b676f0fd4fd65f90ed6e37930988519c6cda..c52a891163ab538014efd44c7c3d9b8df5530a4d 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,12 +1,11 @@
-from sys import path
-
-import random
 from collections import deque
 
 import matplotlib.pyplot as plt
 import numpy as np
+import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
 
 import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
@@ -14,7 +13,6 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -70,7 +68,7 @@ file_load = False
 """
 
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-env_renderer = RenderTool(env, gl="PILSVG",)
+env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 features_per_node = 9
 state_size = features_per_node * 85 * 2
@@ -94,11 +92,9 @@ agent = Agent(state_size, action_size, "FC", 0)
 with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
-
 demo = True
 record_images = False
 
-
 for trials in range(1, n_trials + 1):
 
     if trials % 50 == 0 and not demo:
@@ -136,7 +132,7 @@ for trials in range(1, n_trials + 1):
         agent_data = np.clip(agent_data, -1, 1)
         obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
         agent_data = env.agents[a]
-        speed = 1 #np.random.randint(1,5)
+        speed = 1  # np.random.randint(1,5)
         agent_data.speed_data['speed'] = 1. / speed
 
     for i in range(2):
@@ -145,7 +141,6 @@ for trials in range(1, n_trials + 1):
     for a in range(env.get_num_agents()):
         agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
 
-
     score = 0
     env_done = 0
     # Run episode
@@ -206,10 +201,10 @@ for trials in range(1, n_trials + 1):
     print(
         '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
             env.get_num_agents(), x_dim, y_dim,
-                trials,
-                np.mean(scores_window),
-                100 * np.mean(done_window),
-                eps, action_prob / np.sum(action_prob)), end=" ")
+            trials,
+            np.mean(scores_window),
+            100 * np.mean(done_window),
+            eps, action_prob / np.sum(action_prob)), end=" ")
 
     if trials % 100 == 0:
         print(
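
Why `path` now comes from `importlib_resources`: both training scripts use `path` as a context manager to resolve the `avoid_checkpoint30000.pth` file bundled inside the `torch_training.Nets` package, something the `path` list imported from `sys` cannot do. `importlib_resources.path(package, resource)` yields a real filesystem path for a packaged resource (the API that entered the standard library as `importlib.resources` in Python 3.7), hence the new `importlib_resources>=1.0.2` and `importlib-metadata>=0.17` pins in the requirements file. Below is a minimal sketch of the checkpoint-loading pattern, assuming only the package layout and file name shown in the diff; it is an illustration, not part of the patch.

    import torch
    from importlib_resources import path

    import torch_training.Nets  # package that ships the .pth checkpoints as package data

    # path() materialises the packaged resource as a real file on disk (even for
    # zipped installs) and yields a pathlib.Path that torch.load() accepts.
    with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
        state_dict = torch.load(file_in)

    # The scripts then restore the DQN weights with
    # agent.qnetwork_local.load_state_dict(state_dict), as seen in training_navigation.py above.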