Skip to content
Snippets Groups Projects
Commit ce7297b0 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

merged stochastic breaking branch for testing.

parent abc1ac6e
No related branches found
No related tags found
No related merge requests found
...@@ -29,7 +29,7 @@ env = RailEnv(width=50, ...@@ -29,7 +29,7 @@ env = RailEnv(width=50,
num_neighb=4, # Number of connections to other cities num_neighb=4, # Number of connections to other cities
seed=15, # Random seed seed=15, # Random seed
), ),
number_of_agents=35, number_of_agents=10,
stochastic_data=stochastic_data, # Malfunction generator data stochastic_data=stochastic_data, # Malfunction generator data
obs_builder_object=TreeObservation) obs_builder_object=TreeObservation)
...@@ -74,7 +74,7 @@ class RandomAgent: ...@@ -74,7 +74,7 @@ class RandomAgent:
# Initialize the agent with the parameters corresponding to the environment and observation_builder # Initialize the agent with the parameters corresponding to the environment and observation_builder
# Set action space to 4 to remove stop action # Set action space to 4 to remove stop action
agent = RandomAgent(218, 4) agent = RandomAgent(218, 4)
n_trials = 5 n_trials = 1
# Empty dictionary for all agent action # Empty dictionary for all agent action
action_dict = dict() action_dict = dict()
...@@ -94,6 +94,7 @@ for trials in range(1, n_trials + 1): ...@@ -94,6 +94,7 @@ for trials in range(1, n_trials + 1):
score = 0 score = 0
# Run episode # Run episode
frame_step = 0
for step in range(500): for step in range(500):
# Chose an action for each agent in the environment # Chose an action for each agent in the environment
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
...@@ -104,7 +105,8 @@ for trials in range(1, n_trials + 1): ...@@ -104,7 +105,8 @@ for trials in range(1, n_trials + 1):
# reward and whether their are done # reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False) env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
env_renderer.gl.save_image("./Images/flatland_2_0_frame_{:04d}.bmp".format(frame_step))
frame_step += 1
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment