From b8c300f929f02ba869ba2b255913a8eab0a34e91 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 19 Apr 2019 16:07:10 +0200
Subject: [PATCH] setting up training environment

---
 examples/temporary_example.py         |  2 +-
 examples/training_navigation.py       | 16 ++++++++--------
 flatland/agents/dueling_double_dqn.py |  2 +-
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index c015f614..67fa4616 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -26,7 +26,7 @@ transition_probability = [1.0,  # empty cell - Case 0
                           0.5,  # Case 4 - single slip
                           0.1,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
-                          0.01]  # Case 7 - dead end
+                          1.0]  # Case 7 - dead end
 
 # Example generate a random rail
 env = RailEnv(width=20,
diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index f81a50dd..33fe287d 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -32,9 +32,9 @@ env = RailEnv(width=20,
 env.reset()
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
+handle = env.get_agent_handles()
 
-state_size = 5
+state_size = 105
 action_size = 4
 agent = Agent(state_size, action_size, "FC", 0)
 
@@ -49,7 +49,6 @@ env = RailEnv(width=6,
               number_of_agents=1,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-handle = env.get_agent_handles()
 
 env.agents_position[0] = [1, 4]
 env.agents_target[0] = [1, 1]
@@ -62,14 +61,15 @@ env.obs_builder.reset()
 # print(env.obs_builder.distance_map[0, :, :, i])
 
 obs, all_rewards, done, _ = env.step({0:0})
-print(len(obs[0]))
-env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+#env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
+action_dict = {0: 0}
 
 for step in range(100):
     obs, all_rewards, done, _ = env.step(action_dict)
-    action_dict = {}
+    action = agent.act(np.array(obs[0]),eps=1)
+
+    action_dict = {0 :action}
     print("Rewards: ", all_rewards, " [done=", done, "]")
-    env_renderer.renderEnv(show=True)
+
diff --git a/flatland/agents/dueling_double_dqn.py b/flatland/agents/dueling_double_dqn.py
index 63a1badb..3eacf4c9 100644
--- a/flatland/agents/dueling_double_dqn.py
+++ b/flatland/agents/dueling_double_dqn.py
@@ -2,7 +2,7 @@ import numpy as np
 import random
 from collections import namedtuple, deque
 import os
-from agent.model import QNetwork, QNetwork2
+from flatland.agents.model import QNetwork, QNetwork2
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
--
GitLab