diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index dd67b4f0d73ffe1b3f4ad3e947debf18508e78b0..cf2f7d512b99aafb9fe0477bf048441efa0bff9e 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -20,7 +20,7 @@ double_dqn = True # If using double dqn algorithm input_channels = 5 # Number of Input channels device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -device = torch.device("cpu") +#device = torch.device("cpu") print(device) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index ad42e0a0ec60ba56d270cfea950da32b982527d1..222430dd0af0f239ddd99d127b21349a13a2e892 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -4,6 +4,11 @@ import random import sys from collections import deque +# make sure the root path is in system path +from pathlib import Path +base_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(base_dir)) + import matplotlib.pyplot as plt import numpy as np import torch diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index c97f1f5df2e171410f05f482a594d2b840c42dbc..f69929f65accc53101ba28d8904cdf76b7e1cfca 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -3,6 +3,11 @@ import random import sys from collections import deque +# make sure the root path is in system path +from pathlib import Path +base_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(base_dir)) + import matplotlib.pyplot as plt import numpy as np import torch