diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3a61d1f8fb07783e0cc5bb792c0cac5887f200a8..607206ec8f423c421d2e57397cce3a2a1900679e 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -51,8 +51,8 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2)
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.,  # Fast passenger train
-                        1. / 2.: 1.0,  # Fast freight train
+    speed_ration_map = {1.: 1.,  # Fast passenger train
+                        1. / 2.: 0.0,  # Fast freight train
                         1. / 3.: 0.0,  # Slow commuter train
                         1. / 4.: 0.0}  # Slow freight train
 
@@ -114,7 +114,7 @@ def main(argv):
     for trials in range(1, n_trials + 1):
 
         # Reset environment
-        obs = env.reset(True, True)
+        obs, info = env.reset(True, True)
         register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
         final_obs = agent_obs.copy()
         final_obs_next = agent_next_obs.copy()
@@ -132,7 +132,7 @@ def main(argv):
         for step in range(max_steps):
             # Action
             for a in range(env.get_num_agents()):
-                if env.agents[a].speed_data['position_fraction'] < 0.001:
+                if info['action_required'][a]:
                     register_action_state[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
@@ -144,10 +144,13 @@ def main(argv):
                 action_dict.update({a: action})
 
             # Environment step
-            next_obs, all_rewards, done, _ = env.step(action_dict)
+            next_obs, all_rewards, done, info = env.step(action_dict)
 
             # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
+                # Penalize waiting in order to get agent to move
+                if env.agents[a].status == 0:
+                    all_rewards[a] -= 1
                 agent_next_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
 
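
For reference, the pattern this patch introduces is: env.reset() and env.step() now also return an info dict, info['action_required'][a] gates when the policy is queried for agent a, and agents that have not yet departed (status == 0, i.e. READY_TO_DEPART) get an extra negative reward so they learn to enter the grid. The sketch below illustrates that loop under these assumptions; the run_episode helper, its arguments, and the normalize_observation import path are illustrative guesses, not part of this patch.

# Minimal sketch of the action-required gating and departure penalty,
# assuming a Flatland-style RailEnv whose reset()/step() return an info
# dict and a DQN-style agent exposing act(state, eps).
from utils.observation_utils import normalize_observation  # import path assumed from the baselines repo


def run_episode(env, agent, max_steps, eps, tree_depth=2):
    n_agents = env.get_num_agents()
    obs, info = env.reset(True, True)
    agent_obs = [normalize_observation(obs[a], tree_depth, observation_radius=10)
                 for a in range(n_agents)]
    score = 0

    for _ in range(max_steps):
        action_dict = {}
        for a in range(n_agents):
            if info['action_required'][a]:
                # Query the policy only when the agent can actually choose
                # (at a cell boundary or ready to depart).
                action = agent.act(agent_obs[a], eps=eps)
            else:
                # Agent is still moving between cells; no decision is needed.
                action = 0
            action_dict[a] = action

        next_obs, all_rewards, done, info = env.step(action_dict)

        for a in range(n_agents):
            # Extra penalty while the agent has not departed yet
            # (status 0 == READY_TO_DEPART) so it learns to enter the grid.
            if env.agents[a].status == 0:
                all_rewards[a] -= 1
            agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
            score += all_rewards[a]

        if done['__all__']:
            break

    return score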