Commit aafa76af authored by u229589

clean up Agent and dueling_double_dqn

parent 43686bfe
1 merge request: !7 remove unused observation and prediction files
@@ -22,7 +22,7 @@ test_times = []
 test_dones = []
 # Load agent
 # agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint1700.pth'))
+# agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint500.pth'))
 agent = RandomAgent(state_size, action_size)
 start_time_scoring = time.time()
 test_idx = 0
@@ -28,8 +28,8 @@ test_dones = []
 sequential_agent_test = False

 # Load your agent
-agent = Agent(state_size, action_size, 0)
-agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint60000.pth'))
+agent = Agent(state_size, action_size)
+agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint500.pth'))

 # Load the necessary Observation Builder and Predictor
 predictor = ShortestPathPredictorForRailEnv()
File deleted
File deleted
File deleted
@@ -16,38 +16,33 @@ GAMMA = 0.99  # discount factor 0.99
 TAU = 1e-3  # for soft update of target parameters
 LR = 0.5e-4  # learning rate 0.5e-4 works
 UPDATE_EVERY = 10  # how often to update the network
-double_dqn = True  # If using double dqn algorithm
-input_channels = 5  # Number of Input channels

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# device = torch.device("cpu")
 print(device)


 class Agent:
     """Interacts with and learns from the environment."""

-    def __init__(self, state_size, action_size, seed, double_dqn=True, input_channels=5):
+    def __init__(self, state_size, action_size, double_dqn=True):
         """Initialize an Agent object.

         Params
         ======
             state_size (int): dimension of each state
             action_size (int): dimension of each action
-            seed (int): random seed
         """
         self.state_size = state_size
         self.action_size = action_size
-        self.seed = random.seed(seed)
         self.double_dqn = double_dqn

         # Q-Network
-        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
+        self.qnetwork_local = QNetwork(state_size, action_size).to(device)
         self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
         self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

         # Replay memory
-        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
         # Initialize time step (for updating every UPDATE_EVERY steps)
         self.t_step = 0
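The `step()` and `soft_update()` bodies that use `UPDATE_EVERY` and `TAU` are collapsed in this diff. As a hedged sketch of how these hyperparameters are typically applied in an agent of this shape (the function names and signatures below are illustrative assumptions, not lines from this commit):

```python
# Hedged sketch (assumed; not shown in this diff) of the usual role of
# UPDATE_EVERY and TAU in a DQN-style agent.
UPDATE_EVERY = 10
TAU = 1e-3

def should_learn(t_step: int, memory_len: int, batch_size: int) -> bool:
    # Learn only every UPDATE_EVERY environment steps, once the replay
    # buffer holds more than one batch of experiences.
    return t_step % UPDATE_EVERY == 0 and memory_len > batch_size

def soft_update(local_model, target_model, tau: float = TAU) -> None:
    # theta_target <- tau * theta_local + (1 - tau) * theta_target
    for target_param, local_param in zip(target_model.parameters(),
                                         local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
```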
@@ -147,7 +142,7 @@ class Agent:

 class ReplayBuffer:
     """Fixed-size buffer to store experience tuples."""

-    def __init__(self, action_size, buffer_size, batch_size, seed):
+    def __init__(self, action_size, buffer_size, batch_size):
         """Initialize a ReplayBuffer object.

         Params
@@ -155,13 +150,11 @@ class ReplayBuffer:
             action_size (int): dimension of each action
             buffer_size (int): maximum size of buffer
             batch_size (int): size of each training batch
-            seed (int): random seed
         """
         self.action_size = action_size
         self.memory = deque(maxlen=buffer_size)
         self.batch_size = batch_size
         self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-        self.seed = random.seed(seed)

     def add(self, state, action, reward, next_state, done):
         """Add a new experience to memory."""
@@ -183,7 +176,7 @@ class ReplayBuffer:
         dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
             .float().to(device)

-        return (states, actions, rewards, next_states, dones)
+        return states, actions, rewards, next_states, dones

     def __len__(self):
         """Return the current size of internal memory."""
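The `learn()` method that consumes these sampled batches is not part of this diff. As a hedged sketch, the standard double-DQN update it presumably performs (the local network selects the greedy action, the target network evaluates it) looks roughly like the following; all names and shapes here are assumptions:

```python
# Hedged sketch of a standard double-DQN update, written as a free function
# for self-containment (the commit's actual learn() body is not shown).
import torch
import torch.nn.functional as F

def double_dqn_update(qnetwork_local, qnetwork_target, optimizer, experiences,
                      gamma=0.99, double_dqn=True):
    # Assumes rewards/dones were stacked to shape (batch, 1) by the sampler.
    states, actions, rewards, next_states, dones = experiences

    if double_dqn:
        # Local network picks the greedy next action ...
        best_actions = qnetwork_local(next_states).detach().argmax(1).unsqueeze(1)
        # ... target network evaluates it, decoupling selection from evaluation.
        q_targets_next = qnetwork_target(next_states).detach().gather(1, best_actions)
    else:
        q_targets_next = qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

    q_targets = rewards + gamma * q_targets_next * (1 - dones)
    q_expected = qnetwork_local(states).gather(1, actions)

    loss = F.mse_loss(q_expected, q_targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```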
@@ -3,7 +3,7 @@ import torch.nn.functional as F

 class QNetwork(nn.Module):
-    def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
         super(QNetwork, self).__init__()

         self.fc1_val = nn.Linear(state_size, hidsize1)
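Only the first value-stream layer (`fc1_val`) is visible in this hunk. As a hedged sketch of what the rest of a dueling Q-network of this shape usually looks like (the advantage stream and the aggregation below are assumptions based on the standard dueling architecture, not lines from this file):

```python
# Hedged sketch of a dueling Q-network; only fc1_val is confirmed by the diff.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DuelingQNetworkSketch(nn.Module):
    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
        super().__init__()
        # Value stream: estimates V(s)
        self.fc1_val = nn.Linear(state_size, hidsize1)
        self.fc2_val = nn.Linear(hidsize1, hidsize2)
        self.fc3_val = nn.Linear(hidsize2, 1)
        # Advantage stream: estimates A(s, a)
        self.fc1_adv = nn.Linear(state_size, hidsize1)
        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
        self.fc3_adv = nn.Linear(hidsize2, action_size)

    def forward(self, x):
        val = self.fc3_val(F.relu(self.fc2_val(F.relu(self.fc1_val(x)))))
        adv = self.fc3_adv(F.relu(self.fc2_adv(F.relu(self.fc1_adv(x)))))
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return val + adv - adv.mean(dim=1, keepdim=True)
```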
@@ -87,8 +87,8 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, 0)
-with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in:
+agent = Agent(state_size, action_size)
+with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False
@@ -121,11 +121,11 @@ def main(argv):
     observation_radius = 10

     # Initialize the agent
-    agent = Agent(state_size, action_size, 0)
+    agent = Agent(state_size, action_size)

     # Here you can pre-load an agent
     if False:
-        with path(torch_training.Nets, "avoid_checkpoint2400.pth") as file_in:
+        with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
             agent.qnetwork_local.load_state_dict(torch.load(file_in))

     # Do training over n_episodes
@@ -7,16 +7,16 @@ from collections import deque

 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from importlib_resources import path
-
-# Import Torch and utility functions to normalize observation
-import torch_training.Nets
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 # Import Flatland/ Observations and Predictors
 from flatland.envs.schedule_generators import complex_schedule_generator
+from importlib_resources import path
+
+# Import Torch and utility functions to normalize observation
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
@@ -41,7 +41,7 @@ def main(argv):
     n_agents = np.random.randint(3, 8)
     n_goals = n_agents + np.random.randint(0, 3)
     min_dist = int(0.75 * min(x_dim, y_dim))
-    tree_depth = 3
+    tree_depth = 2
     print("main2")
     demo = False
@@ -60,7 +60,6 @@ def main(argv):
     handle = env.get_agent_handles()
     features_per_node = env.obs_builder.observation_dim
-    tree_depth = 2
     nr_nodes = 0
     for i in range(tree_depth + 1):
         nr_nodes += np.power(4, i)
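A quick worked example of the node-count loop above: a depth-2 tree observation has 4^0 + 4^1 + 4^2 = 21 nodes. How `state_size` is then derived is not shown in this hunk, so the last step below is an assumption:

```python
# Worked example for tree_depth = 2:
nr_nodes = sum(4 ** i for i in range(2 + 1))   # 1 + 4 + 16 = 21
# Assumed (not shown in this hunk): state_size = features_per_node * nr_nodes,
# e.g. 11 * 21 = 231 for an 11-feature tree observation.
```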
@@ -87,11 +86,11 @@ def main(argv):
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()

     # Initialize the agent
-    agent = Agent(state_size, action_size, 0)
+    agent = Agent(state_size, action_size)

     # Here you can pre-load an agent
     if False:
-        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+        with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
             agent.qnetwork_local.load_state_dict(torch.load(file_in))

     # Do training over n_episodes
@@ -132,7 +131,7 @@ def main(argv):
         # Build agent specific observations
         for a in range(env.get_num_agents()):
             data, distance, agent_data = split_tree(tree=np.array(obs[a]),
-                                                    num_features_per_node=11,
+                                                    num_features_per_node=features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
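The lines that follow this hunk are collapsed. As a hedged sketch of the usual continuation (a hypothetical helper, not code from this commit), the three normalized parts are flattened into the per-agent observation vector:

```python
# Hypothetical helper illustrating the assumed continuation of the loop above.
import numpy as np

def build_agent_obs(data, distance, agent_data):
    # Clip the agent channel and concatenate the normalized tree parts
    # into one flat observation vector for the network.
    agent_data = np.clip(agent_data, -1, 1)
    return np.concatenate((np.concatenate((data, distance)), agent_data))
```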
@@ -165,6 +164,7 @@ def main(argv):
             next_obs, all_rewards, done, _ = env.step(action_dict)
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                        num_features_per_node=features_per_node,
                                                         current_depth=0)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
@@ -101,8 +101,8 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, 0)
-with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in:
+agent = Agent(state_size, action_size)
+with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False
@@ -117,7 +117,7 @@ def main(argv):
     cummulated_reward = np.zeros(env.get_num_agents())

     # Now we load a Double dueling DQN agent
-    agent = Agent(state_size, action_size, "FC", 0)
+    agent = Agent(state_size, action_size)

     for trials in range(1, n_trials + 1):