diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 9c78a72b6dd3515de9f34ab2c24f5ac76cb7a34b..fd85fdc8e77530b0223b4f71a952887b47eb762f 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -1,23 +1,22 @@
-from datetime import datetime
 import os
 import random
 import sys
 from argparse import ArgumentParser, Namespace
+from collections import deque
+from datetime import datetime
 from pathlib import Path
 from pprint import pprint
 
-import psutil
-from flatland.utils.rendertools import RenderTool
-from torch.utils.tensorboard import SummaryWriter
-import numpy as np
-import torch
-from collections import deque
+import numpy as np
+import psutil
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.utils.rendertools import RenderTool
+from torch.utils.tensorboard import SummaryWriter
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -250,6 +249,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         # Run episode
         for step in range(max_steps - 1):
             inference_timer.start()
+            policy.start_step()
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
@@ -264,6 +264,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     update_values[agent] = False
                     action = 0
                 action_dict.update({agent: action})
+            policy.end_step()
             inference_timer.end()
 
             # Environment step
@@ -285,7 +286,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -434,6 +436,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         final_step = 0
 
         for step in range(max_steps - 1):
+            policy.start_step()
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
@@ -444,7 +447,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     action = policy.act(agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
-
+            policy.end_step()
             obs, all_rewards, done, info = env.step(action_dict)
 
             for agent in env.get_agent_handles():
@@ -495,7 +498,8 @@ if __name__ == "__main__":
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
     parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", action='store_true')
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
+                        action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
     training_params = parser.parse_args()
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index d5c6a3c23dacbb182dceda52617a0be12d1acf7b..b8714d1a2fd8e085d4e9e00c48c7362846e8ed87 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -10,3 +10,9 @@
 
     def load(self, filename):
         raise NotImplementedError
+
+    def start_step(self):
+        pass
+
+    def end_step(self):
+        pass
diff --git a/run.py b/run.py
index f11c94168e9ba1856665273e716361d1a2077b50..09dc04dae944bc23aaeeeb70e0789a605eb31a4a 100644
--- a/run.py
+++ b/run.py
@@ -114,6 +114,7 @@ while True:
     if not check_if_all_blocked(env=local_env):
         time_start = time.time()
         action_dict = {}
+        policy.start_step()
         for agent in range(nb_agents):
             if info['action_required'][agent]:
                 if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
@@ -128,6 +129,7 @@ while True:
                 if USE_ACTION_CACHE:
                     agent_last_obs[agent] = observation[agent]
                     agent_last_action[agent] = action
+        policy.end_step()
 
         agent_time = time.time() - time_start
         time_taken_by_controller.append(agent_time)
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 7e4c934212449d4e5712b01456f2bb8a7d6a1a8a..8fe6caf2535f448de561ae270ef6022d1795f12a 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -21,7 +21,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 26
+        self.observation_dim = 27
 
     def build_data(self):
         if self.env is not None: