Skip to content
Snippets Groups Projects
Commit 85f9e04b authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Normalization of observation updated to stay between -1 and 1

parent 366b4fe5
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,6 @@ from collections import deque ...@@ -3,7 +3,6 @@ from collections import deque
import numpy as np import numpy as np
import torch import torch
from dueling_double_dqn import Agent from dueling_double_dqn import Agent
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
...@@ -38,16 +37,16 @@ env = RailEnv(width=15, ...@@ -38,16 +37,16 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=1) number_of_agents=1)
"""
env = RailEnv(width=10, env = RailEnv(width=10,
height=20) height=20)
env.load("./railway/complex_scene.pkl") env.load("./railway/complex_scene.pkl")
"""
env = RailEnv(width=8, env = RailEnv(width=8,
height=8, height=8,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
number_of_agents=2) number_of_agents=1)
env.reset(True, True) env.reset(True, True)
env_renderer = RenderTool(env, gl="PILSVG") env_renderer = RenderTool(env, gl="PILSVG")
...@@ -133,7 +132,7 @@ for trials in range(1, n_trials + 1): ...@@ -133,7 +132,7 @@ for trials in range(1, n_trials + 1):
data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0) data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0)
data = norm_obs_clip(data) data = norm_obs_clip(data)
distance = norm_obs_clip(distance) distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
for i in range(2): for i in range(2):
time_obs.append(obs) time_obs.append(obs)
...@@ -146,13 +145,12 @@ for trials in range(1, n_trials + 1): ...@@ -146,13 +145,12 @@ for trials in range(1, n_trials + 1):
# Run episode # Run episode
for step in range(100): for step in range(100):
if demo: if demo:
env_renderer.renderEnv(show=True, show_observations=False)
env_renderer.renderEnv(show=True, show_observations=True)
# print(step) # print(step)
# Action # Action
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
if demo: if demo:
eps = 0 eps = 1
# action = agent.act(np.array(obs[a]), eps=eps) # action = agent.act(np.array(obs[a]), eps=eps)
action = agent.act(agent_obs[a], eps=eps) action = agent.act(agent_obs[a], eps=eps)
action_prob[action] += 1 action_prob[action] += 1
...@@ -165,6 +163,7 @@ for trials in range(1, n_trials + 1): ...@@ -165,6 +163,7 @@ for trials in range(1, n_trials + 1):
current_depth=0) current_depth=0)
data = norm_obs_clip(data) data = norm_obs_clip(data)
distance = norm_obs_clip(distance) distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
time_obs.append(next_obs) time_obs.append(next_obs)
......
0% loaded. Loading, or retry.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment