diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 216ddc66c439027fdf4a9603f5dc7a4091c246bf..ba9f144ac2edbebdd24706129bb339f1614b3c67 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -7,7 +7,7 @@ import random import torch from dueling_double_dqn import Agent -import torch_training +import torch_training.Nets from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 095f20b4a91c0cab1e97fdec55cf419d3cd77be7..1857b676f0fd4fd65f90ed6e37930988519c6cda 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -8,7 +8,7 @@ import numpy as np import torch from dueling_double_dqn import Agent -import torch_training +import torch_training.Nets from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv