Commit b7f90b08 authored by Egli Adrian (IT-SCI-API-PFI)

resolved issues

parent 2503a421
@@ -8,7 +8,7 @@ from flatland.utils.rendertools import RenderTool
 np.random.seed(1)
-# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
+# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 # Use a the malfunction generator to break agents from time to time
@@ -22,7 +22,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
 env = RailEnv(width=20,
               height=20,
               rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
-                                                   num_intersections=1,  # Number of interesections (no start / target)
+                                                   num_intersections=1,  # Number of intersections (no start / target)
                                                    num_trainstations=15,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=3,  # Proximity of stations to city center
@@ -32,16 +32,14 @@ env = RailEnv(width=20,
                                                    enhance_intersection=True
                                                    ),
               number_of_agents=5,
-              stochastic_data=stochastic_data,  # Malfunction generator data
+              stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
 env_renderer = RenderTool(env, gl="PILSVG", )
 # Import your own Agent or use RLlib to train agents on Flatland
-# As an example we use a random agent here
+# As an example we use a random agent instead
 class RandomAgent:
     def __init__(self, state_size, action_size):
@@ -76,48 +74,46 @@ class RandomAgent:
 # Initialize the agent with the parameters corresponding to the environment and observation_builder
 # Set action space to 4 to remove stop action
 agent = RandomAgent(218, 4)
-n_trials = 1
 # Empty dictionary for all agent action
 action_dict = dict()
-print("Starting Training...")
-
-for trials in range(1, n_trials + 1):
-
-    # Reset environment and get initial observations for all agents
-    obs = env.reset()
-    for idx in range(env.get_num_agents()):
-        tmp_agent = env.agents[idx]
-        speed = (idx % 5) + 1
-        tmp_agent.speed_data["speed"] = 1 / speed
-    env_renderer.reset()
-    # Here you can also further enhance the provided observation by means of normalization
-    # See training navigation example in the baseline repository
-
-    score = 0
-    # Run episode
-    frame_step = 0
-    for step in range(500):
-        # Chose an action for each agent in the environment
-        for a in range(env.get_num_agents()):
-            action = agent.act(obs[a])
-            action_dict.update({a: action})
-
-        # Environment step which returns the observations for all agents, their corresponding
-        # reward and whether their are done
-        next_obs, all_rewards, done, _ = env.step(action_dict)
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
-        try:
-            env_renderer.gl.save_image("./../rendering/flatland_2_0_frame_{:04d}.bmp".format(frame_step))
-        except:
-            print("Path not found: ./../rendering/")
-        frame_step += 1
-        # Update replay buffer and train agent
-        for a in range(env.get_num_agents()):
-            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
-            score += all_rewards[a]
-
-        obs = next_obs.copy()
-        if done['__all__']:
-            break
-    print('Episode Nr. {}\t Score = {}'.format(trials, score))
+print("Start episode...")
+# Reset environment and get initial observations for all agents
+obs = env.reset()
+
+# Update/Set agent's speed
+for idx in range(env.get_num_agents()):
+    speed = 1.0 / ((idx % 5) + 1.0)
+    env.agents[idx].speed_data["speed"] = speed
+
+# Reset the rendering sytem
+env_renderer.reset()
+
+# Here you can also further enhance the provided observation by means of normalization
+# See training navigation example in the baseline repository
+
+score = 0
+# Run episode
+frame_step = 0
+for step in range(500):
+    # Chose an action for each agent in the environment
+    for a in range(env.get_num_agents()):
+        action = agent.act(obs[a])
+        action_dict.update({a: action})
+
+    # Environment step which returns the observations for all agents, their corresponding
+    # reward and whether their are done
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    frame_step += 1
+    # Update replay buffer and train agent
+    for a in range(env.get_num_agents()):
+        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+        score += all_rewards[a]
+
+    obs = next_obs.copy()
+    if done['__all__']:
+        break
+
+print('Episode: Steps {}\t Score = {}'.format(step, score))
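
The diff shows only the hunks around the changed lines, so the imports, the stochastic_data malfunction parameters and the TreeObservation builder that the snippet depends on are not visible here. The following preamble is a minimal sketch of what those elided parts typically look like in Flatland 2.0 example scripts of this period; the module path of sparse_rail_generator, the dictionary keys and all parameter values are assumptions, not taken from the committed file:

import numpy as np

# NOTE: assumed module paths; later Flatland releases expose sparse_rail_generator
# from flatland.envs.rail_generators instead.
from flatland.envs.generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool

np.random.seed(1)

# Illustrative malfunction settings (assumed keys and values): a fraction of agents
# can break down, at a given rate, for a bounded number of steps.
stochastic_data = {'prop_malfunction': 0.5,
                   'malfunction_rate': 30,
                   'min_duration': 3,
                   'max_duration': 10}

# Depth-2 tree observation with a shortest-path predictor, matching the hunk header above.
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())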
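
The body of RandomAgent is likewise cut off by the diff; the episode loop only requires an act() method that returns an action index and a step() method that accepts a (state, action, reward, next_state, done) tuple. A minimal sketch of such an agent, consistent with the calls made in the loop (the save/load stubs are illustrative):

class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        # Ignore the observation and sample one of the available actions uniformly.
        return np.random.choice(np.arange(self.action_size))

    def step(self, memories):
        # A learning agent would push (state, action, reward, next_state, done)
        # into a replay buffer and train here; the random agent learns nothing.
        return

    def save(self, filename):
        return

    def load(self, filename):
        return

With agent = RandomAgent(218, 4) the action space is restricted to 4 actions, which, as the comment in the diff notes, removes the stop action.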