Skip to content
Snippets Groups Projects
Commit eb71fa51 authored by u214892's avatar u214892
Browse files

#42 run baselines in ci

parent 054bc93a
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import random
import torch
from dueling_double_dqn import Agent
import torch_training
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
......@@ -63,7 +64,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in))
demo = True
......
......@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from dueling_double_dqn import Agent
import torch_training
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
......@@ -89,7 +91,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment