Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • jack_bruck/baselines
  • rivesunder/baselines
  • xzhaoma/baselines
  • giulia_cantini/baselines
  • sfwatergit/baselines
  • jiaodaxiaozi/baselines
  • flatland/baselines
7 results
Show changes
Commits on Source (3)
*pycache* *pycache*
*ppo_policy* *ppo_policy*
torch_training/Nets/
...@@ -20,7 +20,7 @@ double_dqn = True # If using double dqn algorithm ...@@ -20,7 +20,7 @@ double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu") #device = torch.device("cpu")
print(device) print(device)
......
...@@ -4,6 +4,11 @@ import random ...@@ -4,6 +4,11 @@ import random
import sys import sys
from collections import deque from collections import deque
# make sure the root path is in system path
from pathlib import Path
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
......
...@@ -87,7 +87,7 @@ action_size = 5 ...@@ -87,7 +87,7 @@ action_size = 5
# We set the number of episodes we would like to train on # We set the number of episodes we would like to train on
if 'n_trials' not in locals(): if 'n_trials' not in locals():
n_trials = 60000 n_trials = 6000
max_steps = int(3 * (env.height + env.width)) max_steps = int(3 * (env.height + env.width))
eps = 1. eps = 1.
eps_end = 0.005 eps_end = 0.005
......
...@@ -3,6 +3,11 @@ import random ...@@ -3,6 +3,11 @@ import random
import sys import sys
from collections import deque from collections import deque
# make sure the root path is in system path
from pathlib import Path
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
......