Commit 1e8ee7fd authored by Erik Nygren

updated multi/agent inference

parent 5befd0e4
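The core change to the inference loop is that the policy network is now only queried for agents that the environment flags as needing a decision; every other agent receives action 0 for that step. The sketch below is a hypothetical helper, not part of the committed file, restating that gating pattern; it assumes an agent object exposing the act(observation, eps) method of the repository's dueling double-DQN Agent and the info['action_required'] dict returned by the Flatland environment, as seen in the diff below.

from typing import Any, Dict, List


def select_actions(info: Dict[str, Dict[int, bool]],
                   agent_obs: List[Any],
                   agent: Any) -> Dict[int, int]:
    """Greedy action selection that queries the network only where required."""
    action_dict: Dict[int, int] = {}
    for a, obs in enumerate(agent_obs):
        if info['action_required'][a]:
            # The agent is at a decision point: ask the policy network, acting greedily (eps=0).
            action_dict[a] = agent.act(obs, eps=0.)
        else:
            # No decision is needed this step; action 0 leaves the agent's behaviour unchanged.
            action_dict[a] = 0
    return action_dict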
@@ -15,63 +15,70 @@ import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation

-random.seed(3)
-np.random.seed(2)
+random.seed(1)
+np.random.seed(1)
+
+"""
+file_name = "./railway/complex_scene.pkl"
+env = RailEnv(width=10,
+              height=20,
+              rail_generator=rail_from_file(file_name),
+              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+x_dim = env.width
+y_dim = env.height
+"""

 # Parameters for the Environment
-x_dim = 20
-y_dim = 20
-n_agents = 5
-tree_depth = 2
+x_dim = 25
+y_dim = 25
+n_agents = 1
+
+# We are training an Agent using the Tree Observation with depth 2
+observation_builder = TreeObsForRailEnv(max_depth=2)

 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
                    }

 # Custom observation builder
-predictor = ShortestPathPredictorForRailEnv()
-observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
+TreeObservation = TreeObsForRailEnv(max_depth=2)

 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train

 env = RailEnv(width=x_dim,
               height=y_dim,
-              rail_generator=sparse_rail_generator(num_cities=5,
-                                                   # Number of cities in map (where train stations are)
-                                                   num_intersections=4,
-                                                   # Number of intersections (no start / target)
-                                                   num_trainstations=10,  # Number of possible start/targets on map
-                                                   min_node_dist=3,  # Minimal distance of nodes
-                                                   node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=3,
-                                                   # Number of connections to other cities/intersections
-                                                   seed=15,  # Random seed
-                                                   grid_mode=True,
-                                                   enhance_intersection=False
-                                                   ),
+              rail_generator=sparse_rail_generator(max_num_cities=3,
+                                                   # Number of cities in map (where train stations are)
+                                                   seed=1,  # Random seed
+                                                   grid_mode=False,
+                                                   max_rails_between_cities=2,
+                                                   max_rails_in_city=2),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=n_agents,
               stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=observation_helper)
+              obs_builder_object=TreeObservation)
 env.reset(True, True)

-observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
-handle = env.get_agent_handles()
 num_features_per_node = env.obs_builder.observation_dim
+tree_depth = 2

 nr_nodes = 0
 for i in range(tree_depth + 1):
     nr_nodes += np.power(4, i)
 state_size = num_features_per_node * nr_nodes
 action_size = 5

-n_trials = 10
-observation_radius = 10
+# We set the number of episodes we would like to train on
+if 'n_trials' not in locals():
+    n_trials = 60000
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
 eps_end = 0.005
@@ -80,14 +87,13 @@ action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
 done_window = deque(maxlen=100)
-time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()

 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
+with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False
@@ -97,29 +103,35 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs, info = env.reset(True, True)
     env_renderer.reset()
+    # Build agent specific observations
     for a in range(env.get_num_agents()):
-        agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+
+    # Reset score and done
+    score = 0
+    env_done = 0

     # Run episode
     for step in range(max_steps):
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
-        if record_images:
-            env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
-            frame_step += 1
-        # time.sleep(1.5)

         # Action
         for a in range(env.get_num_agents()):
-            action = agent.act(agent_obs[a], eps=0)
+            if info['action_required'][a]:
+                action = agent.act(agent_obs[a], eps=0.)
+            else:
+                action = 0
+            action_prob[action] += 1
             action_dict.update({a: action})

         # Environment step
-        next_obs, all_rewards, done, _ = env.step(action_dict)
+        obs, all_rewards, done, _ = env.step(action_dict)
+        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)

+        # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
+            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)

         if done['__all__']:
             break
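For reference, the state_size passed to the Agent above follows directly from the tree observation: a tree of depth d has 1 + 4 + ... + 4^d nodes (the root plus up to four branches per node at every deeper level), each contributing observation_dim features. Below is a small, self-contained illustration; the per-node feature count of 11 is only an assumed example value, not taken from this commit.

def tree_obs_state_size(num_features_per_node: int, tree_depth: int) -> int:
    # A tree observation of depth d contains sum(4 ** i for i in range(d + 1)) nodes.
    nr_nodes = sum(4 ** i for i in range(tree_depth + 1))
    return num_features_per_node * nr_nodes


# Assumed example: with 11 features per node and tree_depth = 2,
# state_size = 11 * (1 + 4 + 16) = 231.
print(tree_obs_state_size(11, 2))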