import random
from collections import deque

import numpy as np
import torch
from importlib_resources import path

import torch_training.Nets
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation

random.seed(1)
np.random.seed(1)
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
              rail_generator=rail_from_file(file_name),
              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
"""

# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 1
n_goals = 5
min_dist = 5

# The agent was trained using the tree observation with depth 2
# (the env below receives the equivalent TreeObservation built further down)
observation_builder = TreeObsForRailEnv(max_depth=2)

# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                   'malfunction_rate': 30,  # Rate of malfunction occurrence
                   'min_duration': 3,  # Minimal duration of a malfunction
                   'max_duration': 20  # Maximal duration of a malfunction
                   }
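# With prop_malfunction = 0.0 no agent is flagged as defective, so this run is
# effectively malfunction-free; raise it to see stochastic breakdowns.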

# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2)

# Different agent types (trains) with different speeds.
speed_ratio_map = {1.: 0.,  # Fast passenger train
                   1. / 2.: 1.0,  # Fast freight train
                   1. / 3.: 0.0,  # Slow commuter train
                   1. / 4.: 0.0}  # Slow freight train
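# The values are sampling probabilities for the agents' speeds and are expected
# to sum to 1.0; with this map every train runs at speed 1/2.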

env = RailEnv(width=x_dim,
              height=y_dim,
              rail_generator=sparse_rail_generator(num_cities=5,  # Cities on the map (with train stations)
                                                   num_intersections=4,  # Intersections (no start/target)
                                                   num_trainstations=10,  # Possible start/target points
                                                   min_node_dist=3,  # Minimal distance between nodes
                                                   node_radius=2,  # Proximity of stations to the city center
                                                   num_neighb=3,  # Connections to other cities/intersections
                                                   seed=15,  # Random seed
                                                   grid_mode=True,
                                                   enhance_intersection=False),
              schedule_generator=sparse_schedule_generator(speed_ratio_map),
              number_of_agents=n_agents,
              stochastic_data=stochastic_data,  # Malfunction data generator
              obs_builder_object=TreeObservation)
env.reset(True, True)
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
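# NB: this deeper observation helper (depth 3, shortest-path predictor) is
# built but never used below; the env's own TreeObservation feeds the network.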
env_renderer = RenderTool(env, gl="PILSVG")
num_features_per_node = env.obs_builder.observation_dim

# The tree observation branches four ways at every node (one branch per
# movement direction), so a depth-d tree has sum(4**i for i in range(d + 1)) nodes.
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
    nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
action_size = 5  # Flatland actions: do nothing, left, forward, right, stop

# Number of episodes to run (unless n_trials was already set by the caller)
if 'n_trials' not in locals():
    n_trials = 60000
# Cap each episode at roughly three traversals of the map
max_steps = int(3 * (env.height + env.width))
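# Epsilon-greedy parameters carried over from the training script; the loop
# below acts greedily (agent.act(..., eps=0.)), so these remain unused here.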
eps = 1.
eps_end = 0.005
eps_decay = 0.9995
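# Bookkeeping containers; the score/done windows are likewise carried over from
# the training script and stay unused in this rendering loop.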
action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
scores = []
dones_list = []
action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)  # Dueling Double DQN agent (see torch_training.dueling_double_dqn)
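# Load the pretrained checkpoint bundled with torch_training.Nets into the local Q-network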
with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in:
    agent.qnetwork_local.load_state_dict(torch.load(file_in))

record_images = False
frame_step = 0
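# Flags for optional frame capture; not exercised in the loop below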

for trials in range(1, n_trials + 1):

    # Reset environment
    obs = env.reset(True, True)
    env_renderer.reset()
    # Build agent specific observations
    for a in range(env.get_num_agents()):
        agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
    # Reset score and done
    score = 0
    env_done = 0

    # Run episode
    for step in range(max_steps):

        # Choose an action for every agent (greedy, eps=0)
        for a in range(env.get_num_agents()):
            action = agent.act(agent_obs[a], eps=0.)
            action_prob[action] += 1
            action_dict.update({a: action})
        # Environment step
        obs, all_rewards, done, _ = env.step(action_dict)
        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
        # Build agent specific observations and normalize
        for a in range(env.get_num_agents()):
            agent_obs[a] = normalize_observation(obs[a], observation_radius=10)


        if done['__all__']:
            break