diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 66a37ad290ade376f682fcb2f40f1a02533537dc..3fc6468d7cd2d448b12de3faa1b21b9de0081ff9 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -10,6 +10,7 @@ from predictors.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file +from flatland.envs.schedule_generators import schedule_from_file from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -24,6 +25,7 @@ file_name = "./railway/simple_avoid.pkl" env = RailEnv(width=10, height=20, rail_generator=rail_from_file(file_name), + schedule_generator=schedule_from_file(file_name), obs_builder_object=observation_helper) x_dim = env.width y_dim = env.height diff --git a/torch_training/observation_builders/__init__.py b/torch_training/observation_builders/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/torch_training/predictors/__init__.py b/torch_training/predictors/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000