From 2c63e82561ae5337b47a244cbf429ca0253734cd Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 11:17:31 -0400
Subject: [PATCH] updated observation and training to handle multi-speed

---
 torch_training/render_agent_behavior.py | 12 ++++++------
 torch_training/training_navigation.py   | 19 ++++++++++++-------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 8264db6..fc0e067 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -38,7 +38,7 @@ min_dist = 5
 observation_builder = TreeObsForRailEnv(max_depth=2)
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
@@ -48,10 +48,10 @@ stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train
 
 env = RailEnv(width=x_dim,
               height=y_dim,
@@ -103,7 +103,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index d324b7c..25b8c14 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -37,7 +37,7 @@ def main(argv):
     min_dist = 5
 
     # Use a the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                        'malfunction_rate': 30,  # Rate of malfunction occurence
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 20  # Max duration of malfunction
@@ -47,10 +47,10 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.25,  # Fast passenger train
-                        1. / 2.: 0.25,  # Fast freight train
-                        1. / 3.: 0.25,  # Slow commuter train
-                        1. / 4.: 0.25}  # Slow freight train
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.0,  # Fast freight train
+                        1. / 3.: 0.0,  # Slow commuter train
+                        1. / 4.: 0.0}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -120,7 +120,7 @@ def main(argv):
 
         # Reset environment
         obs = env.reset(True, True)
-
+        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
         final_obs = agent_obs.copy()
         final_obs_next = agent_next_obs.copy()
 
@@ -138,6 +138,11 @@ def main(argv):
 
             # Action
             for a in range(env.get_num_agents()):
+                if env.agents[a].speed_data['position_fraction'] == 0.:
+                    register_action_state[a] = True
+                else:
+                    register_action_state[a] = False
+
                 action = agent.act(agent_obs[a], eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
@@ -155,7 +160,7 @@ def main(argv):
                     final_obs[a] = agent_obs[a].copy()
                     final_obs_next[a] = agent_next_obs[a].copy()
                     final_action_dict.update({a: action_dict[a]})
-                if not done[a]:
+                if not done[a] and register_action_state[a]:
                     agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a],
                                done[a])
                 score += all_rewards[a] / env.get_num_agents()
-- 
GitLab
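
Note: the core of this change is that a multi-speed agent only enters a new cell when speed_data['position_fraction'] wraps back to 0, so the training loop now registers an action and pushes a transition via agent.step() only at those ticks. The toy sketch below illustrates that gating in isolation under stated assumptions: DummySpeedState and the loop are hypothetical stand-ins, not Flatland or project code; only the position_fraction check mirrors the patch.

# Minimal sketch of the gating introduced above (hypothetical stand-ins,
# not project code): an agent registers a new action only when its
# position_fraction has returned to 0, i.e. on cell entry.
class DummySpeedState:
    def __init__(self, speed):
        # Mirrors the speed_data dict queried in the patch.
        self.speed_data = {'speed': speed, 'position_fraction': 0.0}

    def advance(self):
        # A slower agent needs several env steps to cross a cell; the
        # fraction only returns to 0 when a cell boundary is reached.
        frac = self.speed_data['position_fraction'] + self.speed_data['speed']
        self.speed_data['position_fraction'] = 0.0 if frac >= 1.0 else frac

agents = [DummySpeedState(1.), DummySpeedState(1. / 2.)]
for step in range(4):
    for a, state in enumerate(agents):
        # Same condition as in training_navigation.py above.
        register_action_state = state.speed_data['position_fraction'] == 0.
        print(f"step {step}, agent {a}: register={register_action_state}")
        state.advance()

With this gate, the full-speed train contributes a transition on every environment step while the half-speed train contributes one every second step, which is what the changed agent.step() condition achieves in the patch.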