diff --git a/examples/temporary_example.py b/examples/temporary_example.py index c015f6140617c31fa020bc9a73dcdb3c9c55cc3e..67fa46162aa26b56d8e875c37b16e5c3648ad66b 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -26,7 +26,7 @@ transition_probability = [1.0, # empty cell - Case 0 0.5, # Case 4 - single slip 0.1, # Case 5 - double slip 0.2, # Case 6 - symmetrical - 0.01] # Case 7 - dead end + 1.0] # Case 7 - dead end # Example generate a random rail env = RailEnv(width=20, diff --git a/examples/training_navigation.py b/examples/training_navigation.py index f81a50ddce1837ed69145e2ad3fd33f0c4dc03f0..33fe287d6ef70ebda52c57e8b3a110541d28244e 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -32,9 +32,9 @@ env = RailEnv(width=20, env.reset() env_renderer = RenderTool(env) -env_renderer.renderEnv(show=True) +handle = env.get_agent_handles() -state_size = 5 +state_size = 105 action_size = 4 agent = Agent(state_size, action_size, "FC", 0) @@ -49,7 +49,6 @@ env = RailEnv(width=6, number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2)) -handle = env.get_agent_handles() env.agents_position[0] = [1, 4] env.agents_target[0] = [1, 1] @@ -62,14 +61,15 @@ env.obs_builder.reset() # print(env.obs_builder.distance_map[0, :, :, i]) obs, all_rewards, done, _ = env.step({0:0}) -print(len(obs[0])) -env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) +#env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) env_renderer = RenderTool(env) -env_renderer.renderEnv(show=True) +action_dict = {0: 0} for step in range(100): obs, all_rewards, done, _ = env.step(action_dict) - action_dict = {} + action = agent.act(np.array(obs[0]),eps=1) + + action_dict = {0 :action} print("Rewards: ", all_rewards, " [done=", done, "]") - env_renderer.renderEnv(show=True) + diff --git a/flatland/agents/dueling_double_dqn.py b/flatland/agents/dueling_double_dqn.py index 63a1badb7dfcfca0aae0f5b34b8766418bf2cecb..3eacf4c9a66612c87f64e4ae65b7714313ffcf64 100644 --- a/flatland/agents/dueling_double_dqn.py +++ b/flatland/agents/dueling_double_dqn.py @@ -2,7 +2,7 @@ import numpy as np import random from collections import namedtuple, deque import os -from agent.model import QNetwork, QNetwork2 +from flatland.agents.model import QNetwork, QNetwork2 import torch import torch.nn.functional as F import torch.optim as optim