From df0b0ef116980c363e4f442eed7c376efbd8af96 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 6 Oct 2019 11:20:57 -0400
Subject: [PATCH] updated training file

---
 torch_training/render_agent_behavior.py |  8 ++++----
 torch_training/training_navigation.py   | 25 +++++++++++--------------
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 6802a22..93f9f12 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -28,8 +28,8 @@ y_dim = env.height
 """
 
 # Parameters for the Environment
-x_dim = 20
-y_dim = 20
+x_dim = 25
+y_dim = 25
 n_agents = 1
 n_goals = 5
 min_dist = 5
@@ -48,9 +48,9 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 1.,  # Fast passenger train
+speed_ration_map = {1.: 0.,  # Fast passenger train
                     1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 3.: 1.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3220b94..229c804 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -35,8 +35,8 @@ def main(argv):
     np.random.seed(1)
 
     # Parameters for the Environment
-    x_dim = 20
-    y_dim = 20
+    x_dim = 30
+    y_dim = 30
     n_agents = 1
 
@@ -63,7 +63,7 @@ def main(argv):
                                                        seed=1,  # Random seed
                                                        grid_mode=False,
                                                        max_rails_between_cities=2,
-                                                       max_rails_in_city=2),
+                                                       max_rails_in_city=3),
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=n_agents,
                   stochastic_data=stochastic_data,  # Malfunction data generator
@@ -105,7 +105,8 @@ def main(argv):
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
     agent_obs_buffer = [None] * env.get_num_agents()
-    agent_action_buffer = [None] * env.get_num_agents()
+    agent_action_buffer = [2] * env.get_num_agents()
+    agent_done_buffer = [False] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
 
     # Now we load a Double dueling DQN agent
@@ -115,7 +116,7 @@ def main(argv):
 
         # Reset environment
         obs, info = env.reset(True, True)
-
+        env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
             agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
@@ -132,36 +133,32 @@ def main(argv):
                 if info['action_required'][a]:
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
-                    if step == 0:
-                        agent_action_buffer[a] = action
                 else:
                     action = 0
                 action_dict.update({a: action})
 
             # Environment step
             next_obs, all_rewards, done, info = env.step(action_dict)
-
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
                 # Penalize waiting in order to get agent to move
                 if env.agents[a].status == 0:
                     all_rewards[a] -= 1
-
-                agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if info['action_required'][a]:
+                    agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if (agent_obs_buffer[a] is not None and info['action_required'][a] and env.agents[a].status != 3) or \
-                        env.agents[a].status == 2:
-                    agent_delayed_next = agent_obs[a].copy()
+                if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                               agent_delayed_next, done[a])
+                               agent_obs[a], agent_done_buffer[a])
                     cummulated_reward[a] = 0.
 
                 if info['action_required'][a]:
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
+                    agent_done_buffer[a] = done[a]
 
                 score += all_rewards[a] / env.get_num_agents()
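
A note on the replay-buffer hunk above: agents with fractional speeds only pick an action at cell boundaries, so the patched loop stores one transition per decision point instead of one per environment step, pairing the observation and action buffered at the previous decision with the reward and observation available at the next one. In flatland-rl's RailAgentStatus encoding, status 0 is READY_TO_DEPART, 2 is DONE, and 3 is DONE_REMOVED; action 2 is MOVE_FORWARD, which is why agent_action_buffer now defaults to [2]. The following is a minimal sketch of that pattern, assuming env, agent, eps, max_steps, tree_depth, and an initial obs/info from env.reset() as set up earlier in training_navigation.py; names and ordering mirror the patched loop but are illustrative, not the repo's exact code:

    # One stored transition per decision point (sketch of the patched loop).
    n_agents = env.get_num_agents()
    agent_next_obs = [None] * n_agents
    agent_obs_buffer = [None] * n_agents    # observation at last decision point
    agent_action_buffer = [2] * n_agents    # default action: 2 = MOVE_FORWARD
    agent_done_buffer = [False] * n_agents

    for step in range(max_steps):
        action_dict = {}
        for a in range(n_agents):
            # Only query the policy when the env actually needs a decision.
            if info['action_required'][a]:
                action_dict[a] = agent.act(agent_obs[a], eps=eps)
            else:
                action_dict[a] = 0          # DO_NOTHING while moving mid-cell
        next_obs, all_rewards, done, info = env.step(action_dict)

        for a in range(n_agents):
            # Normalize a fresh observation only at decision points.
            if info['action_required'][a]:
                agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth,
                                                          observation_radius=10)
            # Train at each new decision point (agent not yet removed) and
            # once more when the agent reaches its target (status 2 = DONE).
            if (info['action_required'][a] and env.agents[a].status != 3) or env.agents[a].status == 2:
                agent.step(agent_obs_buffer[a], agent_action_buffer[a],
                           all_rewards[a], agent_obs[a], agent_done_buffer[a])
            # Refresh the buffers so the next stored transition starts here.
            if info['action_required'][a]:
                agent_obs_buffer[a] = agent_obs[a].copy()
                agent_action_buffer[a] = action_dict[a]
                agent_done_buffer[a] = done[a]
                agent_obs[a] = agent_next_obs[a].copy()
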
--
GitLab
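
A note on the env_renderer.reset() line added after env.reset(True, True): RenderTool caches the rail grid, so once reset regenerates the rail and schedule, the renderer has to be re-synced before the next frame is drawn or it keeps showing the previous episode's track layout. A minimal sketch using flatland-rl's RenderTool, with constructor options omitted; treat the exact call pattern as an assumption against the flatland version this repo pins:

    from flatland.utils.rendertools import RenderTool

    env_renderer = RenderTool(env)
    obs, info = env.reset(True, True)  # regenerate rail and schedule
    env_renderer.reset()               # re-sync the renderer with the new grid
    env_renderer.render_env(show=True, show_observations=False)
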