Skip to content
Snippets Groups Projects
Commit 054bc93a authored by u214892's avatar u214892
Browse files

#42 run baselines in ci

parent 31b23f2c
No related branches found
No related tags found
1 merge request!242 run baselines in ci
Pipeline #1405 canceled
import random
from collections import deque
from sys import path
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
from dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from utils.observation_utils import norm_obs_clip, split_tree
random.seed(1)
......@@ -62,7 +63,8 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in))
demo = True
record_images = False
......
from sys import path
import random
from collections import deque
......@@ -87,7 +89,9 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
with path("torch_training/Nets", "avoid_checkpoint30000.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in))
demo = True
record_images = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment