Skip to content
Snippets Groups Projects
Commit 9139e2de authored by u214892's avatar u214892
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines into...

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines into 57-access-resources-through-importlib_resources
parents 1f8dfa71 3f36d20a
No related branches found
No related tags found
No related merge requests found
No preview for this file type
......@@ -6,6 +6,8 @@ import numpy as np
import torch
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
......@@ -47,10 +49,12 @@ env = RailEnv(width=10,
height=20)
env.load_resource('torch_training.railway', "complex_scene.pkl")
"""
env = RailEnv(width=8,
height=8,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
number_of_agents=1)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()),
number_of_agents=10)
env.reset(True, True)
env_renderer = RenderTool(env, gl="PILSVG")
......@@ -73,9 +77,9 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
demo = False
demo = True
def max_lt(seq, val):
......@@ -149,7 +153,7 @@ for trials in range(1, n_trials + 1):
score = 0
env_done = 0
# Run episode
for step in range(100):
for step in range(env.height * env.width):
if demo:
env_renderer.renderEnv(show=True, show_observations=False)
# print(step)
......@@ -157,6 +161,7 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
if demo:
eps = 1
# action = agent.act(np.array(obs[a]), eps=eps)
action = agent.act(agent_obs[a], eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment