diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5c98b428844d9f7d529e2b6fb918d15bf072f3df --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml \ No newline at end of file diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 718ce3a63aec6d13d7b2d48cd222d09b4a3ff604..b310e95f19e6284cfd6fed6f195d9807551875b8 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,13 +3,13 @@ from collections import deque import numpy as np import torch -from flatland.envs.generators import rail_from_file, complex_rail_generator -from observation_builders.observations import TreeObsForRailEnv -from predictors.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.generators import rail_from_file from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool from importlib_resources import path -import time +from observation_builders.observations import TreeObsForRailEnv +from predictors.predictions import ShortestPathPredictorForRailEnv + import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -73,7 +73,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path(torch_training.Nets, "avoid_checkpoint59900.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint60000.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 2b836087611f5a5a0b01e47d6d91b78c16da9d42..5cc7305563076e54e25a3fb2d890e276bcbf9a38 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -118,7 +118,7 @@ def main(argv): # Only render when not triaing if not Training: - env_renderer.renderEnv(show=True, show_observations=True) + env_renderer.render_env(show=True, show_observations=True) # Chose the actions for a in range(env.get_num_agents()): @@ -210,7 +210,7 @@ def main(argv): # Run episode for step in range(max_steps): - env_renderer.renderEnv(show=True, show_observations=False) + env_renderer.render_env(show=True, show_observations=False) # Chose the actions for a in range(env.get_num_agents()):