Commit 2c63e825 authored by Erik Nygren

updated observation and training to handle multi-speed

parent cebde3d8
@@ -38,7 +38,7 @@ min_dist = 5
 observation_builder = TreeObsForRailEnv(max_depth=2)
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
@@ -48,10 +48,10 @@ stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train
 env = RailEnv(width=x_dim,
               height=y_dim,
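
The speed_ration_map above maps each fractional per-step speed to the probability with which that speed class is assigned to an agent, so the values should sum to 1.0; this commit switches the profile to fast passenger trains only (probability 1.0 for speed 1). A minimal illustrative sketch of that convention, not part of the commit (the helper name check_speed_profile is made up here):

# Illustrative only: sanity-check a speed profile before building the RailEnv.
def check_speed_profile(speed_ration_map):
    total = sum(speed_ration_map.values())
    assert abs(total - 1.0) < 1e-9, "speed class probabilities must sum to 1.0"
    return speed_ration_map

# The profile after this commit: every agent is a fast passenger train (speed 1),
# which keeps the first multi-speed training runs easy to reason about.
check_speed_profile({1.: 1., 1. / 2.: 0.0, 1. / 3.: 0.0, 1. / 4.: 0.0})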
@@ -103,7 +103,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 record_images = False
@@ -37,7 +37,7 @@ def main(argv):
 min_dist = 5
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.1,  # Percentage of defective agents
+stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    'malfunction_rate': 30,  # Rate of malfunction occurence
                    'min_duration': 3,  # Minimal duration of malfunction
                    'max_duration': 20  # Max duration of malfunction
@@ -47,10 +47,10 @@ def main(argv):
 TreeObservation = TreeObsForRailEnv(max_depth=2)
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Fast freight train
-                    1. / 3.: 0.25,  # Slow commuter train
-                    1. / 4.: 0.25}  # Slow freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
+                    1. / 3.: 0.0,  # Slow commuter train
+                    1. / 4.: 0.0}  # Slow freight train
 env = RailEnv(width=x_dim,
               height=y_dim,
@@ -120,7 +120,7 @@ def main(argv):
 # Reset environment
 obs = env.reset(True, True)
+register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
 final_obs = agent_obs.copy()
 final_obs_next = agent_next_obs.copy()
@@ -138,6 +138,11 @@ def main(argv):
 # Action
 for a in range(env.get_num_agents()):
+    if env.agents[a].speed_data['position_fraction'] == 0.:
+        register_action_state[a] = True
+    else:
+        register_action_state[a] = False
     action = agent.act(agent_obs[a], eps=eps)
     action_prob[action] += 1
     action_dict.update({a: action})
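
The added register_action_state flags capture the core of the multi-speed handling: an agent with speed 1/n needs n environment steps to cross a cell, and a new action only takes effect at a cell boundary, i.e. while speed_data['position_fraction'] is 0. A condensed, standalone sketch of that gating (not part of the commit; it reuses the names from the loop above and assumes a DQN-style agent exposing act(obs, eps) as in this repo):

import numpy as np

# Illustrative restatement of the per-agent action gating under fractional speeds.
def select_actions(env, agent, agent_obs, eps):
    action_dict = {}
    register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
    for a in range(env.get_num_agents()):
        # Only at a cell boundary can a freshly chosen action actually be applied.
        register_action_state[a] = env.agents[a].speed_data['position_fraction'] == 0.
        action_dict[a] = agent.act(agent_obs[a], eps=eps)
    return action_dict, register_action_state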
@@ -155,7 +160,7 @@ def main(argv):
 final_obs[a] = agent_obs[a].copy()
 final_obs_next[a] = agent_next_obs[a].copy()
 final_action_dict.update({a: action_dict[a]})
-if not done[a]:
+if not done[a] and register_action_state[a]:
     agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
 score += all_rewards[a] / env.get_num_agents()
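
Downstream, the learning update is gated on the same flag: intermediate steps of a slow agent would only repeat the previous observation/action pair, so agent.step is called only when a fresh action was registered, while every agent's reward still flows into the episode score. Roughly, assuming the same variables as in the surrounding training loop:

# Only learn from transitions where the agent actually chose a new action this step.
for a in range(env.get_num_agents()):
    if not done[a] and register_action_state[a]:
        agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
    # The reward of every agent still contributes to the episode score each step.
    score += all_rewards[a] / env.get_num_agents()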