import getopt
import sys
import time
import numpy as np
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return 2  # always MOVE_FORWARD; alternatively sample randomly: np.random.choice(np.arange(self.action_size))
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS tuple to be stored for learning
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
def create_env():
# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
# Use the malfunction generator to break agents from time to time
stochastic_data = MalfunctionParameters(malfunction_rate=30, # Rate of malfunction occurrence
min_duration=3, # Minimal duration of malfunction
max_duration=20 # Max duration of malfunction
)
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=20,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
)
return env
def flatland_3_0_example(sleep_for_animation, do_rendering):
np.random.seed(1)
env = create_env()
env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=True,
screen_height=1000,
screen_width=1000)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
# Set action space to 4 to remove stop action
agent = RandomAgent(218, 4)
# Empty dictionary for all agent action
action_dict = dict()
print("Start episode...")
# Reset environment and get initial observations for all agents
start_reset = time.time()
obs, info = env.reset()
end_reset = time.time()
print(end_reset - start_reset)
print(env.get_num_agents(), )
# Reset the rendering system
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
frame_step = 0
for step in range(500):
# Choose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# rewards and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
if env_renderer is not None:
env_renderer.close_window()
print('Episode: Steps {}\t Score = {}'.format(step, score))
RailEnvPersister.save(env, "saved_episode_2.pkl")
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
flatland_3_0_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import cProfile
import pstats
import numpy as np
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
class RandomAgent:
def __init__(self, action_size):
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice(np.arange(self.action_size))
def get_rail_env(nAgents=70, use_dummy_obs=False, width=300, height=300):
# Rail Generator:
num_cities = 5 # Number of cities to place on the map
seed = 1 # Random seed
max_rails_between_cities = 2 # Maximum number of rails connecting 2 cities
max_rail_pairs_in_cities = 2 # Maximum number of pairs of tracks within a city
# (Even tracks are used as start points, odd tracks are used as endpoints)
rail_generator = sparse_rail_generator(
max_num_cities=num_cities,
seed=seed,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_pairs_in_cities,
)
# Line Generator
# sparse_line_generator accepts a dictionary which maps speeds to probabilities.
# Different agent types (trains) with different speeds.
speed_probability_map = {
1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25 # Slow freight train
}
line_generator = sparse_line_generator(speed_probability_map)
# Malfunction Generator:
stochastic_data = MalfunctionParameters(
malfunction_rate=1 / 10000,  # Rate of malfunction occurrence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
malfunction_generator = ParamMalfunctionGen(stochastic_data)
# Observation Builder
# tree observation returns a tree of possible paths from the current position.
max_depth = 3 # Max depth of the tree
predictor = ShortestPathPredictorForRailEnv(
max_depth=50) # (Specific to Tree Observation - read code)
observation_builder = TreeObsForRailEnv(
max_depth=max_depth,
predictor=predictor
)
if use_dummy_obs:
observation_builder = DummyObservationBuilder()
number_of_agents = nAgents # Number of trains to create
seed = 1 # Random seed
env = RailEnv(
width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=number_of_agents,
random_seed=seed,
obs_builder_object=observation_builder,
malfunction_generator=malfunction_generator
)
return env
def run_simulation(env_fast: RailEnv, do_rendering):
agent = RandomAgent(action_size=5)
max_steps = 200
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env_fast,
gl="PGL",
show_debug=True,
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)
env_renderer.set_new_rail()
env_renderer.reset()
action_dict = dict()
for step in range(max_steps):
# Choose an action for each agent in the environment
for handle in range(env_fast.get_num_agents()):
action = agent.act(handle)
action_dict.update({handle: action})
next_obs, all_rewards, done, _ = env_fast.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(
show=True,
frames=False,
show_observations=True,
show_predictions=False
)
if env_renderer is not None:
env_renderer.close_window()
USE_PROFILER = True
PROFILE_CREATE = False
PROFILE_RESET = False
PROFILE_STEP = True
PROFILE_OBSERVATION = False
RUN_SIMULATION = False
DO_RENDERING = False
if __name__ == "__main__":
print("Start ...")
if USE_PROFILER:
profiler = cProfile.Profile()
print("Create env ... ")
if PROFILE_CREATE:
profiler.enable()
env_fast = get_rail_env(nAgents=200, use_dummy_obs=False, width=100, height=100)
if PROFILE_CREATE:
profiler.disable()
print("Reset env ... ")
if PROFILE_RESET:
profiler.enable()
env_fast.reset(random_seed=1)
if PROFILE_RESET:
profiler.disable()
print("Make actions ... ")
action_dict = {agent.handle: 0 for agent in env_fast.agents}
print("Step env ... ")
if PROFILE_STEP:
profiler.enable()
for i in range(1):
env_fast.step(action_dict)
if PROFILE_STEP:
profiler.disable()
if PROFILE_OBSERVATION:
profiler.enable()
print("get observation ... ")
obs = env_fast._get_observations()
if PROFILE_OBSERVATION:
profiler.disable()
if USE_PROFILER:
if False:
print("---- tottime")
stats = pstats.Stats(profiler).sort_stats('tottime') # ncalls, 'cumtime'...
stats.print_stats(20)
if True:
print("---- cumtime")
stats = pstats.Stats(profiler).sort_stats('cumtime') # ncalls, 'cumtime'...
stats.print_stats(200)
if False:
print("---- ncalls")
stats = pstats.Stats(profiler).sort_stats('ncalls') # ncalls, 'cumtime'...
stats.print_stats(200)
print("... end ")
if RUN_SIMULATION:
run_simulation(env_fast, DO_RENDERING)
import os
import numpy as np
from flatland.envs.line_generators import sparse_line_generator
# In Flatland you can use custom observation builders and predictors
# Observation builders generate the observation needed by the controller
# Predictors can be used to make short-term predictions, which can help to avoid conflicts in the network
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
# We also include a renderer because we want to visualize what is going on in the environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# This is an introduction example for the Flatland 2.1.* version.
# Changes and highlights of this version include
# - Stochastic events (malfunctions)
# - Different travel speeds for different agents
# - Levels are generated using a novel generator to reflect more realistic railway networks
# - Agents start outside of the environment and enter at their own time
# - Agents leave the environment after they have reached their goal
# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
# We start by importing the necessary rail and schedule generators
# The rail generator will generate the railway infrastructure
# The schedule generator will assign tasks to all the agents within the railway network
# The railway infrastructure can be built using any of the provided generators in env/rail_generators.py
# Here we use the sparse_rail_generator with the following parameters
DO_RENDERING = False
width = 16 * 7  # Width of the map
height = 9 * 7 # Height of map
nr_trains = 50 # Number of trains that have an assigned task in the env
cities_in_map = 20 # Number of cities where agents can start or end
seed = 14 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is the number of entry points to a city
max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# rail_generator = SparseRailGen(max_num_cities=cities_in_map,
# seed=seed,
# grid_mode=grid_distribution_of_cities,
# max_rails_between_cities=max_rails_between_cities,
# max_rails_in_city=max_rail_in_cities,
# )
# The schedule generator can make very basic schedules with a start point, end point and a speed profile for each agent.
# The speed profiles can be adjusted directly as well as shown later on. We start by introducing a statistical
# distribution of speed profiles
# Different agent types (trains) with different speeds.
speed_ratio_map = {1.: 0.25,  # Fast passenger train
1. / 2.: 0.25,  # Fast freight train
1. / 3.: 0.25,  # Slow commuter train
1. / 4.: 0.25}  # Slow freight train
# We can now initiate the schedule generator with the given speed profiles
line_generator = sparse_line_generator(speed_ratio_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(malfunction_rate=1 / 10000,  # Rate of malfunction occurrence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor, uncomment line below if you want to try this one
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Construct the environment with the given observation builder, generators, predictors, and stochastic data
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
malfunction_generator=ParamMalfunctionGen(stochastic_data),
remove_agents_at_target=True)
env.reset()
# Initiate the renderer
env_renderer = None
if DO_RENDERING:
env_renderer = RenderTool(env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
# The first thing we notice is that some agents don't have feasible paths to their target.
# We first look at the map we have created
# env_renderer.render_env(show=True)
# time.sleep(2)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice([RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT,
RailEnvActions.STOP_MOVING])
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS tuple to be stored for learning
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
controller = RandomAgent(218, env.action_space[0])
# We start by looking at the information of each agent
# We can see the task assigned to the agent by looking at
print("\n Agents in the environment have to solve the following tasks: \n")
for agent_idx, agent in enumerate(env.agents):
print(
"The agent with index {} has the task to go from its initial position {}, facing in the direction {} to its target at {}.".format(
agent_idx, agent.initial_position, agent.direction, agent.target))
# Each agent has a state indicating whether it is waiting to enter the environment, active, or done
# For example, we can see that the agent with index 0 is currently not active
print("\n Their current statuses are:")
print("============================")
for agent_idx, agent in enumerate(env.agents):
print("Agent {} status is: {} with its current position being {}".format(agent_idx, str(agent.state),
str(agent.position)))
# An agent needs to take a movement action (1, 2 or 3), i.e. anything except do_nothing or stop, to enter the level
# If its starting cell is free, the agent will enter the level
# If multiple agents want to enter the same cell at the same time, the agent with the lower index enters first
# Let's check if there are any agents with the same start location
agents_with_same_start = set()
print("\n The following agents have the same initial position:")
print("=====================================================")
for agent_idx, agent in enumerate(env.agents):
for agent_2_idx, agent2 in enumerate(env.agents):
if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
print("Agent {} as the same initial position as agent {}".format(agent_idx, agent_2_idx))
agents_with_same_start.add(agent_idx)
# Let's try to enter with all of these agents at the same time
action_dict = dict()
for agent_id in agents_with_same_start:
action_dict[agent_id] = 1 # Try to move with the agents
# Do a step in the environment to see what agents entered:
env.step(action_dict)
# Current state and position of the agents after all agents with same start position tried to move
print("\n This happened when all tried to enter at the same time:")
print("========================================================")
for agent_id in agents_with_same_start:
print(
"Agent {} status is: {} with the current position being {}.".format(
agent_id, str(env.agents[agent_id].state),
str(env.agents[agent_id].position)))
# As you can see, only the agents with lower indices moved. As soon as the cell is free again, the remaining agents
# can attempt to start again.
# You will also notice that the agents move at different speeds once they are on the rail.
# The agents will always move at full speed when moving, never at a speed in between.
# The fastest an agent can go is 1, meaning that it moves to the next cell at every time step
# All slower speeds indicate the fraction of a cell that is moved at each time step
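# For example, an agent with speed 1/3 needs three time steps to traverse a single cell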
# Let's look at the current speed data of the agents:
print("\n The speed information of the agents is:")
print("=========================================")
for agent_idx, agent in enumerate(env.agents):
print(
"Agent {} speed is: {:.2f} with the current fractional position being {}/{}".format(
agent_idx, agent.speed_counter.speed, agent.speed_counter.counter, agent.speed_counter.max_count))
# Now the agents can also have stochastic malfunctions, which will leave them unable to move
# for a certain number of time steps. The malfunction data of the agents can easily be accessed as follows
print("\n The malfunction data of the agents are:")
print("========================================")
for agent_idx, agent in enumerate(env.agents):
print(
"Agent {} is OK = {}".format(
agent_idx, agent.malfunction_handler.in_malfunction))
# Now that you have seen these new concepts, you will realize that agents don't need to take
# an action at every time step, as an action only changes the outcome when it is chosen at cell entry.
# Therefore the environment provides information about which agents need to provide an action in the next step.
# You can access this in the following way.
# Choose an action for each agent
for a in range(env.get_num_agents()):
action = controller.act(0)
action_dict.update({a: action})
# Do the environment step
observations, rewards, dones, information = env.step(action_dict)
print("\n The following agents can register an action:")
print("========================================")
for info in information['action_required']:
print("Agent {} needs to submit an action.".format(info))
# We recommend that you monitor the malfunction data and the action_required information in order to optimize your
# training and control code.
# Let us now look at an episode playing out with random actions performed
print("\nStart episode...")
# Reset the rendering system
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
frame_step = 0
os.makedirs("tmp/frames", exist_ok=True)
for step in range(200):
# Choose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = controller.act(observations[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# rewards and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
env_renderer.gl.save_image('tmp/frames/flatland_frame_{:04d}.png'.format(step))
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
controller.step((observations[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
observations = next_obs.copy()
if done['__all__']:
break
print('Episode: Steps {}\t Score = {}'.format(step, score))
# close the renderer / rendering window
if env_renderer is not None:
env_renderer.close_window()
# Making Videos from Env
In order to generate videos or GIFs, it is easiest to save image files and then run ffmpeg to turn them into a video.
## 1. Generating Images from Env
Start by importing the renderer and instantiating it
```
from flatland.utils.rendertools import RenderTool
env_renderer = RenderTool(env, gl="PILSVG", )
```
If the environment changes, don't forget to reset the renderer
```
env_renderer.reset()
```
You can now record an image after every step. It is best to use a format similar to the one below, where `frame_step` counts the number of steps.
```
env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
```
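Putting these pieces together, a minimal frame-saving loop might look like the sketch below. This is only an illustration: it assumes an already-constructed `env` and a `controller` object with an `act(observation)` method, and the output folder `./Images/Avoiding` and the episode length of 200 steps are arbitrary choices you should adapt to your setup.
```
import os

from flatland.utils.rendertools import RenderTool

# Assumption: `env` is an existing RailEnv and `controller` provides act(observation).
env_renderer = RenderTool(env, gl="PILSVG")
os.makedirs("./Images/Avoiding", exist_ok=True)

obs, info = env.reset()
env_renderer.reset()

for frame_step in range(200):
    # Choose an action for every agent and advance the environment by one step
    action_dict = {a: controller.act(obs[a]) for a in range(env.get_num_agents())}
    obs, all_rewards, done, _ = env.step(action_dict)

    # Render the current state and save it as a numbered image file
    env_renderer.render_env(show=False, show_observations=False, show_predictions=False)
    env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))

    if done['__all__']:
        break
```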
Once the images have been saved, open a shell in that folder and run the following commands.
Generate an mp4 out of the images:
```
ffmpeg -y -framerate 12 -i flatland_frame_%04d.bmp -hide_banner -c:v libx264 -pix_fmt yuv420p test.mp4
```
Generate a palette from the video, which is needed to produce good-looking GIFs:
```
ffmpeg -i test.mp4 -filter_complex "[0:v] palettegen" palette.png
```
and finally generate the GIF
```
ffmpeg -i test.mp4 -i palette.png -filter_complex "[0:v][1:v] paletteuse" single_agent_navigation.gif
```
from flatland.envs.rail_env import RailEnv, random_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import time
def main(render=True, delay=0.0):
random.seed(1)
np.random.seed(1)
# Relative proportions of the different cell types used by the random rail generator
transition_probability = [0.5, # empty cell - Case 0
1.0, # Case 1 - straight
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond crossing
0.5, # Case 4 - single slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0] # Case 7 - dead end
# Example generate a random rail
env = RailEnv(width=15,
height=15,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=5)
if render:
env_renderer = RenderTool(env, gl="QT")
plt.figure(figsize=(5,5))
# fRedis = redis.Redis()
handle = env.get_agent_handles()
state_size = 105
action_size = 4
n_trials = 9999
eps = 1.
eps_end = 0.005
eps_decay = 0.998
action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
scores = []
dones_list = []
action_prob = [0]*4
agent = Agent(state_size, action_size, "FC", 0)
# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
def max_lt(seq, val):
"""
Return greatest item in seq for which item < val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
idx = len(seq)-1
while idx >= 0:
if seq[idx] < val and seq[idx] >= 0:
return seq[idx]
idx -= 1
return None
iFrame = 0
tStart = time.time()
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset()
for a in range(env.number_of_agents):
norm = max(1, max_lt(obs[a],np.inf))
obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
score = 0
env_done = 0
# Run episode
for step in range(50):
#if trials > 114:
#env_renderer.renderEnv(show=True)
#print(step)
# Action
for a in range(env.number_of_agents):
action = agent.act(np.array(obs[a]), eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.number_of_agents):
norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
# Update replay buffer and train agent
for a in range(env.number_of_agents):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
score += all_rewards[a]
if render:
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
if delay > 0:
time.sleep(delay)
iFrame += 1
obs = next_obs.copy()
if done['__all__']:
env_done = 1
break
# Epsilon decay
eps = max(eps_end, eps_decay * eps) # decrease epsilon
done_window.append(env_done)
scores_window.append(score) # save most recent score
scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window)))
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, action_prob/np.sum(action_prob)),
end=" ")
if trials % 100 == 0:
tNow = time.time()
rFps = iFrame / (tNow - tStart)
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1]*4
if __name__ == "__main__":
main()
import random
import numpy as np
import matplotlib.pyplot as plt
from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
random.seed(0)
np.random.seed(0)
transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond crossing
0.5, # Case 4 - single slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0] # Case 7 - dead end
"""
# Example generate a random rail
env = RailEnv(width=20,
height=20,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=10)
# env = RailEnv(width=20,
# height=20,
# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
# number_of_agents=10)
env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
"""
"""
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
env = RailEnv(width=6,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles()
env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
"""
"""
# INFINITE-LOOP TEST
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (2, 270), (2, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles()
env.agents_position[0] = [1, 3]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
"""
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=2)
# Print the distance map of each cell to the target of the first agent
# for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0:0})
for i in range(env.number_of_agents):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)")
for step in range(100):
cmd = input(">> ")
cmds = cmd.split(" ")
action_dict = {}
i = 0
while i < len(cmds):
if cmds[i] == 'q':
import sys
sys.exit()
elif cmds[i] == 's':
obs, all_rewards, done, _ = env.step(action_dict)
action_dict = {}
print("Rewards: ", all_rewards, " [done=", done, "]")
else:
agent_id = int(cmds[i])
action = int(cmds[i+1])
action_dict[agent_id] = action
i = i+1
i += 1
env_renderer.renderEnv(show=True)
import getopt
import sys
import numpy as np
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool
def create_env():
nAgents = 1
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=30,
height=40,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
)
return env
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice(np.arange(self.action_size))
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS tuple to be stored for learning
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
def training_example(sleep_for_animation, do_rendering):
np.random.seed(1)
# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
env = create_env()
env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 5)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs, info = env.reset()
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
for step in range(500):
# Choose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# rewards and whether they are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials, score))
if env_renderer is not None:
env_renderer.close_window()
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
training_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch, random
random.seed(1)
np.random.seed(1)
# Relative proportions of the different cell types used by the random rail generator
transition_probability = [5, # empty cell - Case 0
15, # Case 1 - straight
5, # Case 2 - simple switch
1, # Case 3 - diamond crossing
1, # Case 4 - single slip
1, # Case 5 - double slip
1, # Case 6 - symmetrical
0] # Case 7 - dead end
# Example generate a random rail
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3)
env_renderer = RenderTool(env, gl="QT")
handle = env.get_agent_handles()
state_size = 105
action_size = 4
n_trials = 15000
eps = 1.
eps_end = 0.005
eps_decay = 0.998
action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
scores = []
dones_list = []
action_prob = [0]*4
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
demo = False
def max_lt(seq, val):
"""
Return the greatest non-negative item in seq for which item < val applies.
0 is returned if seq was empty or no such item exists.
"""
max = 0
idx = len(seq)-1
while idx >= 0:
if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
max = seq[idx]
idx -= 1
return max
def min_lt(seq, val):
"""
Return the smallest item in seq for which item > val applies.
np.inf is returned if seq was empty or no such item exists.
"""
min = np.inf
idx = len(seq)-1
while idx >= 0:
if seq[idx] > val and seq[idx] < min:
min = seq[idx]
idx -= 1
return min
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset()
for a in range(env.number_of_agents):
norm = max(1, max_lt(obs[a], np.inf))
obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
score = 0
env_done = 0
# Run episode
for step in range(100):
if demo:
env_renderer.renderEnv(show=True)
#print(step)
# Action
for a in range(env.number_of_agents):
if demo:
eps = 0
action = agent.act(np.array(obs[a]), eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.number_of_agents):
norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
# Update replay buffer and train agent
for a in range(env.number_of_agents):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
env_done = 1
break
# Epsilon decay
eps = max(eps_end, eps_decay * eps) # decrease epsilon
done_window.append(env_done)
scores_window.append(score) # save most recent score
scores.append(np.mean(scores_window))
dones_list.append((np.mean(done_window)))
print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.number_of_agents,
trials,
np.mean(
scores_window),
100 * np.mean(
done_window),
eps, action_prob/np.sum(action_prob)),
end=" ")
if trials % 100 == 0:
print(
'\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
env.number_of_agents,
trials,
np.mean(
scores_window),
100 * np.mean(
done_window),
eps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1]*4
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
__version__ = '0.1.1'
__version__ = '3.0.15'
import pprint
from typing import Dict, List, Optional, NamedTuple
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
# ---- ActionPlan ---------------
# an action plan element represents the action to be taken by an agent at the given time step
ActionPlanElement = NamedTuple('ActionPlanElement', [
('scheduled_at', int),
('action', RailEnvActions)
])
# an action plan gathers all the actions to be taken by a single agent at the corresponding time steps
ActionPlan = List[ActionPlanElement]
# An action plan dict gathers the action plans of all agents, keyed by agent_handle
ActionPlanDict = Dict[int, ActionPlan]
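# Illustration only (hypothetical values): an ActionPlanDict for two agents might look like
#   {0: [ActionPlanElement(scheduled_at=0, action=RailEnvActions.MOVE_FORWARD),
#        ActionPlanElement(scheduled_at=5, action=RailEnvActions.STOP_MOVING)],
#    1: [ActionPlanElement(scheduled_at=2, action=RailEnvActions.MOVE_FORWARD)]}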
class ControllerFromTrainruns():
"""Takes train runs, derives the actions from it and re-acts them."""
pp = pprint.PrettyPrinter(indent=4)
def __init__(self,
env: RailEnv,
trainrun_dict: Dict[int, Trainrun]):
self.env: RailEnv = env
self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict
self.action_plan: ActionPlanDict = [self._create_action_plan_for_agent(agent_id, chosen_path)
for agent_id, chosen_path in trainrun_dict.items()]
def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
"""
Get the waypoint from which the current position can be extracted.
Parameters
----------
agent_id
step
Returns
-------
Waypoint
"""
trainrun = self.trainrun_dict[agent_id]
entry_time_step = trainrun[0].scheduled_at
# the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid)
if step <= entry_time_step:
return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)
# the agent has no position as soon as the target is reached
exit_time_step = trainrun[-1].scheduled_at
if step >= exit_time_step:
# agent loses position as soon as target cell is reached
return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)
waypoint = None
for trainrun_waypoint in trainrun:
if step < trainrun_waypoint.scheduled_at:
return waypoint
if step >= trainrun_waypoint.scheduled_at:
waypoint = trainrun_waypoint.waypoint
assert waypoint is not None
return waypoint
def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
"""
Get the current action if any is defined in the `ActionPlan`.
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
agent_id
current_step
Returns
-------
RailEnvActions, optional
"""
for action_plan_element in self.action_plan[agent_id]:
scheduled_at = action_plan_element.scheduled_at
if scheduled_at > current_step:
return None
elif current_step == scheduled_at:
return action_plan_element.action
return None
def act(self, current_step: int) -> Dict[int, RailEnvActions]:
"""
Get the action dictionary to be replayed at the current step.
Returns only actions where required (no action for agents that are done or not at the beginning of a cell).
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
current_step: int
Returns
-------
Dict[int, RailEnvActions]
"""
action_dict = {}
for agent_id in range(len(self.env.agents)):
action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
if action is not None:
action_dict[agent_id] = action
return action_dict
def print_action_plan(self):
"""Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns` to stdout."""
self.__class__.print_action_plan_dict(self.action_plan)
@staticmethod
def print_action_plan_dict(action_plan: ActionPlanDict):
"""Pretty-prints `ActionPlanDict` to stdout."""
for agent_id, plan in enumerate(action_plan):
print("{}: ".format(agent_id))
for step in plan:
print(" {}".format(step))
@staticmethod
def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
assert len(expected_action_plan) == len(actual_action_plan)
for k in range(len(expected_action_plan)):
assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
"len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format(
k,
len(expected_action_plan[k]),
ControllerFromTrainruns.pp.pformat(expected_action_plan[k]),
len(actual_action_plan[k]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k]))
for i in range(len(expected_action_plan[k])):
assert expected_action_plan[k][i] == actual_action_plan[k][i], \
"not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format(
k, i,
ControllerFromTrainruns.pp.pformat(expected_action_plan[k][i]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k][i]))
assert expected_action_plan == actual_action_plan, \
"expected {}, found {}".format(expected_action_plan, actual_action_plan)
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
minimum_cell_time = agent.speed_counter.max_count + 1
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
position = trainrun_waypoint.waypoint.position
if Vec2d.is_equal(agent.target, position):
break
next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
next_position = next_trainrun_waypoint.waypoint.position
if path_loop == 0:
self._add_action_plan_elements_for_first_path_element_of_agent(
action_plan,
trainrun_waypoint,
next_trainrun_waypoint,
minimum_cell_time
)
continue
just_before_target = Vec2d.is_equal(agent.target, next_position)
self._add_action_plan_elements_for_current_path_element(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
# add a final element
if just_before_target:
self._add_action_plan_elements_for_target_at_path_element_just_before_target(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
return action_plan
def _add_action_plan_elements_for_current_path_element(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
next_entry_value = next_trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the next entry is more than minimum_cell_time after the current entry, then stop here and
# start moving again exactly minimum_cell_time steps before the entry into the next cell
# we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
if next_entry_value > scheduled_at + minimum_cell_time:
action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING)
action_plan.append(action)
action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action)
action_plan.append(action)
else:
action = ActionPlanElement(scheduled_at, next_action)
action_plan.append(action)
def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
action_plan.append(action)
def _add_action_plan_elements_for_first_path_element_of_agent(self,
action_plan: ActionPlan,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint,
minimum_cell_time: int):
scheduled_at = trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
# add an initial DO_NOTHING if we do not enter immediately (not strictly necessary)
if scheduled_at > 0:
action = ActionPlanElement(0, RailEnvActions.DO_NOTHING)
action_plan.append(action)
# add action to enter the grid
action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD)
action_plan.append(action)
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the agent is blocked in the cell, we have to call stop upon entering!
if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING)
action_plan.append(action)
# execute the action exactly minimum_cell_time before the entry into the next cell
action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
action_plan.append(action)
from typing import Callable
from flatland.action_plan.action_plan import ControllerFromTrainruns
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_trainrun_data_structures import Waypoint
ControllerFromTrainrunsReplayerRenderCallback = Callable[[RailEnv], None]
class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible.
Parameters
----------
ctl
env
call_back
Called before/after each step() call. The env is passed to it.
"""
call_back(env)
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
obs, all_rewards, done, _ = env.step(actions)
call_back(env)
i += 1
import numpy as np
import random
from collections import namedtuple, deque
import os
from flatland.baselines.model import QNetwork, QNetwork2
import torch
import torch.nn.functional as F
import torch.optim as optim
import copy
BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size
GAMMA = 0.99 # discount factor 0.99
TAU = 1e-3 # for soft update of target parameters
LR = 0.5e-4 # learning rate
UPDATE_EVERY = 10 # how often to update the network
double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)
class Agent:
"""Interacts with and learns from the environment."""
def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
"""Initialize an Agent object.
Params
======
state_size (int): dimension of each state
action_size (int): dimension of each action
seed (int): random seed
"""
self.state_size = state_size
self.action_size = action_size
self.seed = random.seed(seed)
self.version = net_type
self.double_dqn = double_dqn
# Q-Network
if self.version == "Conv":
self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
else:
self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
# Replay memory
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
# Initialize time step (for updating every UPDATE_EVERY steps)
self.t_step = 0
def save(self, filename):
torch.save(self.qnetwork_local.state_dict(), filename + ".local")
torch.save(self.qnetwork_target.state_dict(), filename + ".target")
def load(self, filename):
if os.path.exists(filename + ".local"):
self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
if os.path.exists(filename + ".target"):
self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
def step(self, state, action, reward, next_state, done):
# Save experience in replay memory
self.memory.add(state, action, reward, next_state, done)
# Learn every UPDATE_EVERY time steps.
self.t_step = (self.t_step + 1) % UPDATE_EVERY
if self.t_step == 0:
# If enough samples are available in memory, get random subset and learn
if len(self.memory) > BATCH_SIZE:
experiences = self.memory.sample()
self.learn(experiences, GAMMA)
def act(self, state, eps=0.):
"""Returns actions for given state as per current policy.
Params
======
state (array_like): current state
eps (float): epsilon, for epsilon-greedy action selection
"""
state = torch.from_numpy(state).float().unsqueeze(0).to(device)
self.qnetwork_local.eval()
with torch.no_grad():
action_values = self.qnetwork_local(state)
self.qnetwork_local.train()
# Epsilon-greedy action selection
if random.random() > eps:
return np.argmax(action_values.cpu().data.numpy())
else:
return random.choice(np.arange(self.action_size))
def learn(self, experiences, gamma):
"""Update value parameters using given batch of experience tuples.
Params
======
experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
gamma (float): discount factor
"""
states, actions, rewards, next_states, dones = experiences
# Get expected Q values from local model
Q_expected = self.qnetwork_local(states).gather(1, actions)
if self.double_dqn:
# Double DQN
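# Select the best next action with the local network, but evaluate it with the target network to reduce Q-value overestimation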
q_best_action = self.qnetwork_local(next_states).max(1)[1]
Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
else:
# DQN
Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
# Compute Q targets for current states
Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
# Compute loss
loss = F.mse_loss(Q_expected, Q_targets)
# Minimize the loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# ------------------- update target network ------------------- #
self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
def soft_update(self, local_model, target_model, tau):
"""Soft update model parameters.
θ_target = τ*θ_local + (1 - τ)*θ_target
Params
======
local_model (PyTorch model): weights will be copied from
target_model (PyTorch model): weights will be copied to
tau (float): interpolation parameter
"""
for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
class ReplayBuffer:
"""Fixed-size buffer to store experience tuples."""
def __init__(self, action_size, buffer_size, batch_size, seed):
"""Initialize a ReplayBuffer object.
Params
======
action_size (int): dimension of each action
buffer_size (int): maximum size of buffer
batch_size (int): size of each training batch
seed (int): random seed
"""
self.action_size = action_size
self.memory = deque(maxlen=buffer_size)
self.batch_size = batch_size
self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
self.seed = random.seed(seed)
def add(self, state, action, reward, next_state, done):
"""Add a new experience to memory."""
e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
self.memory.append(e)
def sample(self):
"""Randomly sample a batch of experiences from memory."""
experiences = random.sample(self.memory, k=self.batch_size)
states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
device)
dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
device)
return (states, actions, rewards, next_states, dones)
def __len__(self):
"""Return the current size of internal memory."""
return len(self.memory)
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
super(QNetwork, self).__init__()
self.fc1_val = nn.Linear(state_size, hidsize1)
self.fc2_val = nn.Linear(hidsize1, hidsize2)
self.fc3_val = nn.Linear(hidsize2, 1)
self.fc1_adv = nn.Linear(state_size, hidsize1)
self.fc2_adv = nn.Linear(hidsize1, hidsize2)
self.fc3_adv = nn.Linear(hidsize2, action_size)
def forward(self, x):
val = F.relu(self.fc1_val(x))
val = F.relu(self.fc2_val(val))
val = self.fc3_val(val)
# advantage calculation
adv = F.relu(self.fc1_adv(x))
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
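# Dueling aggregation: combine the state value with mean-centred advantages so the value/advantage split stays identifiable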
return val + adv - adv.mean()
class QNetwork2(nn.Module):
def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
super(QNetwork2, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
self.bn3 = nn.BatchNorm2d(64)
self.fc1_val = nn.Linear(6400, hidsize1)
self.fc2_val = nn.Linear(hidsize1, hidsize2)
self.fc3_val = nn.Linear(hidsize2, 1)
self.fc1_adv = nn.Linear(6400, hidsize1)
self.fc2_adv = nn.Linear(hidsize1, hidsize2)
self.fc3_adv = nn.Linear(hidsize2, action_size)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# value function approximation
val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
val = F.relu(self.fc2_val(val))
val = self.fc3_val(val)
# advantage calculation
adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
return val + adv - adv.mean()
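# Hedged sanity check (not part of the original model file): runs a forward pass
# through the dueling QNetwork and applies the soft-update interpolation rule
# target <- tau * local + (1 - tau) * target parameter-wise, mirroring the
# soft_update helper shown earlier. State/action sizes and tau are illustrative.
if __name__ == "__main__":
    import torch
    state_size, action_size = 11, 5
    local_net = QNetwork(state_size, action_size, seed=0)
    target_net = QNetwork(state_size, action_size, seed=0)
    q_values = local_net(torch.rand(4, state_size))
    print(q_values.shape)  # torch.Size([4, 5]) -> one Q-value per action
    tau = 1e-3
    with torch.no_grad():
        for target_param, local_param in zip(target_net.parameters(), local_net.parameters()):
            target_param.copy_(tau * local_param + (1.0 - tau) * target_param)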
"""Console script for flatland."""
import sys
import time
import click
import numpy as np
import redis
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService, FLATLAND_RL_SERVICE_ID
from flatland.utils.rendertools import RenderTool
@click.command()
def main(args=None):
"""Console script for flatland."""
click.echo("Replace this message by putting your code into "
"flatland.cli.main")
click.echo("See click documentation at http://click.pocoo.org/")
@click.command()
def demo(args=None):
"""Demo script to check installation"""
env = RailEnv(
width=30,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=3,
grid_mode=False,
max_rails_between_cities=4,
max_rail_pairs_in_city=2,
seed=0
),
line_generator=sparse_line_generator(),
number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height))
env_renderer = RenderTool(env)
obs, info = env.reset()
_done = False
# Run a single episode here
step = 0
while not _done:
# Compute Action
_action = {}
for _idx, _ in enumerate(env.agents):
_action[_idx] = np.random.randint(0, 5)
obs, all_rewards, done, _ = env.step(_action)
_done = done['__all__']
step += 1
env_renderer.render_env(
show=True,
frames=False,
show_observations=False,
show_predictions=False
)
time.sleep(0.1)
return 0
@click.command()
@click.option('--tests',
type=click.Path(exists=True),
help="Path to folder containing Flatland tests",
required=True
)
@click.option('--service_id',
default=FLATLAND_RL_SERVICE_ID,
help="Evaluation Service ID. This has to match the service id on the client.",
required=False
)
@click.option('--shuffle',
type=bool,
default=False,
help="Shuffle the environments before starting evaluation.",
required=False
)
@click.option('--disable_timeouts',
default=False,
help="Disable all evaluation timeouts.",
required=False
)
@click.option('--results_path',
type=click.Path(exists=False),
default=None,
help="Path where the evaluator should write the results metadata.",
required=False
)
def evaluator(tests, service_id, shuffle, disable_timeouts, results_path):
try:
redis_connection = redis.Redis()
redis_connection.ping()
    except redis.exceptions.ConnectionError as e:
        raise Exception(
            "\nRedis server does not seem to be running on localhost.\n"
            "Please ensure that a Redis server is running on localhost before starting the evaluator."
        ) from e
grader = FlatlandRemoteEvaluationService(
test_env_folder=tests,
flatland_rl_service_id=service_id,
visualize=False,
result_output_path=results_path,
verbose=False,
shuffle=shuffle,
disable_timeouts=disable_timeouts
)
grader.run()
if __name__ == "__main__":
    sys.exit(demo())  # pragma: no cover
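# Hedged usage sketch (not part of the original CLI module): a separate script can
# drive the `evaluator` command in-process with click's test runner. The module
# path flatland.cli and the test-env folder are assumptions, and a local Redis
# server must already be running.
from click.testing import CliRunner
from flatland.cli import evaluator
runner = CliRunner()
result = runner.invoke(evaluator, ["--tests", "./scratch/test-envs", "--shuffle", "False"])
print(result.exit_code)
print(result.output)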
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
# preprocessor must be for observation builders other than global obs
# treeobs builders would use the default preprocessor if none is
# supplied
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
# observation space:
# flatland defines no observation space for an agent. Here we try
# to define the observation space. All agents are identical and would
# have the same observation space.
# Infer observation space based on returned observation
obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
    def observe(self, agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
    def state(self):
        '''
        A global environment state is not implemented for this wrapper; returns None.
        '''
        return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
    # Collate agent info and observation into a tuple, making each agent's observation
    # a tuple of the env observation and the agent info.
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def set_probs(self, probs):
self.probs = probs
def render(self, mode='rgb_array'):
"""
        This method provides the option to render the
environment's behavior as an image or to a window.
"""
if mode == "rgb_array":
env_rgb_array = self._environment.render(mode)
if not hasattr(self, "image_shape "):
self.image_shape = env_rgb_array.shape
if not hasattr(self, "probs "):
self.probs = [[0., 0., 0., 0.]]
fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
constrained_layout=True, dpi=100)
df = pd.DataFrame(np.array(self.probs).T)
sns.barplot(x=df.index, y=0, data=df, ax=ax)
ax.set(xlabel='actions', ylabel='probs')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())
rgb_image = np.array(Image.fromarray(X).convert('RGB'))
plt.close(fig)
q_value_rgb_array = rgb_image
return np.append(env_rgb_array, q_value_rgb_array, axis=1)
else:
return self._environment.render(mode)
def close(self):
self._environment.close()
def _obtain_preprocessor(self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
    """Obtain an agent's integer handle given its id."""
    return int(id)
def get_agent_keys(id):
    """Obtain an agent's string key given its id."""
    return str(id)
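# Hedged usage sketch (not part of the wrapper module): a short random rollout
# through the parallel PettingZoo interface. It assumes a RailEnv built with
# env_generators.small_v0 as in the training scripts below; the step budget and
# random policy are illustrative.
if __name__ == "__main__":
    import numpy as np
    from flatland.contrib.interface import flatland_env
    from flatland.contrib.utils import env_generators
    from flatland.envs.observations import TreeObsForRailEnv
    from flatland.envs.predictions import ShortestPathPredictorForRailEnv
    obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
    rail_env = env_generators.small_v0(10, obs_builder)
    penv = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
    observations = penv.reset()
    for _ in range(25):
        actions = {agent: np.random.randint(0, 5) for agent in penv.agents}
        observations, rewards, dones, infos = penv.step(actions)
        if all(dones.values()):
            break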
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
seaborn
matplotlib
pandas
from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
def env_creator(args):
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
return env
if __name__ == "__main__":
env_name = "flatland_pettyzoo"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
test_env = ParallelPettingZooEnv(env_creator({}))
obs_space = test_env.observation_space
act_space = test_env.action_space
def gen_policy(i):
config = {
"gamma": 0.99,
}
return (None, obs_space, act_space, config)
policies = {"policy_0": gen_policy(0)}
policy_ids = list(policies.keys())
tune.run(
"PPO",
name="PPO",
stop={"timesteps_total": 5000000},
checkpoint_freq=10,
local_dir="~/ray_results/"+env_name,
config={
# Environment specific
"env": env_name,
# https://github.com/ray-project/ray/issues/10761
"no_done_at_end": True,
# "soft_horizon" : True,
"num_gpus": 0,
"num_workers": 2,
"num_envs_per_worker": 1,
"compress_observations": False,
"batch_mode": 'truncate_episodes',
"clip_rewards": False,
"vf_clip_param": 500.0,
"entropy_coeff": 0.01,
# effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
# see https://github.com/ray-project/ray/issues/4628
"train_batch_size": 1000, # 5000
"rollout_fragment_length": 50, # 100
"sgd_minibatch_size": 100, # 500
"vf_share_layers": False
},
)
# __sphinx_doc_end__
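# Hedged sketch (not part of the original script): the `policies` dict above is
# defined but never passed to tune.run, so RLlib falls back to a single default
# policy shared by all trains. To wire it in explicitly, RLlib's multi-agent API
# (ray 1.5.x conventions) expects a "multiagent" section such as the one below.
def make_multiagent_section(policies):
    return {
        "multiagent": {
            "policies": policies,
            # every Flatland agent handle is mapped to the single shared policy
            "policy_mapping_fn": lambda agent_id: "policy_0",
        }
    }
# usage: pass config={..., **make_multiagent_section(policies)} to tune.run above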
import numpy as np
import os
import PIL
import shutil
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss
from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb
"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""
# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
try:
if os.path.isdir(experiment_name):
shutil.rmtree(experiment_name)
os.mkdir(experiment_name)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)
# __sphinx_doc_begin__
env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment = rail_env, use_renderer = False)
if wandb_log:
run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
config={}, name=experiment_name, save_code=True)
env_steps = 1000 # 2 * env.width * env.height # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
n_steps=rollout_fragment_length, ent_coef=0.01,
learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
batch_size=150, seed=seed)
# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__
model = PPO.load(f"policy_flatland_{train_timesteps}")
env = flatland_env.env(environment=rail_env, use_renderer=True)
if wandb_log:
artifact = wandb.Artifact('model', type='model')
artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
run.log_artifact(artifact)
# Model inference
seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
for agent in env.agent_iter():
obs, reward, done, info = env.last()
act = model.predict(obs, deterministic=True)[0] if not done else None
env.step(act)
frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
step += 1
if step % 100 == 0:
print(f"env step:{step} and action taken:{act}")
completion = env_generators.perc_completion(env)
print("Agents Completed:", completion)
completion = env_generators.perc_completion(env)
print("Final Agents Completed:", completion)
ep_no += 1
frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env.close()
env.reset(random_seed=seed+ep_no)
def find(pattern, path):
result = []
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result.append(os.path.join(root, name))
return result
if wandb_log:
extn = "gif"
_video_file = f'*.{extn}'
_found_videos = find(_video_file, experiment_name)
print(_found_videos)
for _found_video in _found_videos:
wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
run.join()
from typing import List
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import env_utils
class Deadlock_Checker:
def __init__(self, env):
self.env = env
self.deadlocked_agents = []
self.immediate_deadlocked = []
def reset(self) -> None:
self.deadlocked_agents = []
self.immediate_deadlocked = []
    # An immediate deadlock consists of two trains "trying to pass through each other".
    # An agent may still have a free transition available, but if it takes a bad action and "runs into another train",
    # that already constitutes a deadlock: the remaining free direction can no longer be chosen.
def check_immediate_deadlocks(self, action_dict) -> List[EnvAgent]:
"""
output: list of agents who are in immediate deadlocks
"""
env = self.env
newly_deadlocked_agents = []
# TODO: check restrictions to relevant agents (status ACTIVE, etc.)
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
for agent in relevant_agents:
other_agents = [other_agent for other_agent in env.agents if other_agent != agent] # check if this is a good test for inequality. Maybe use handles...
# get the transitions the agent can take from his current position and orientation
# an indicator array of the form e.g. (0,1,1,0) meaning that he can only go to east and south, not to north and west.
possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0(north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
################### only consider direction for actually chosen action ###############################
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position] # TODO: check that agent_positions now works correctly in flatland V3 (i.e. gets correctly updated...)
# agent_positions[cell] is an agent_id if an agent is there, otherwise -1.
if opposed_agent_id != -1:
opposed_agent = env.agents[opposed_agent_id]
# other agent with opposing direction is in the way --> deadlock
                # an opposing direction means a direction different from the one our agent would have after moving to the new cell (180 or 90 degrees relative to our agent)
if opposed_agent.direction != new_direction:
if agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(agent)
if opposed_agent not in newly_deadlocked_agents: # to avoid duplicates
newly_deadlocked_agents.append(opposed_agent)
self.immediate_deadlocked = newly_deadlocked_agents
return newly_deadlocked_agents
# main method to check for all deadlocks
def check_deadlocks(self, action_dict) -> List[EnvAgent]:
env = self.env
relevant_agents = [agent for agent in env.agents if agent.state != TrainState.DONE and agent.position is not None]
immediate_deadlocked = self.check_immediate_deadlocks(action_dict)
self.immediate_deadlocked = immediate_deadlocked
deadlocked = immediate_deadlocked[:]
# now we have to "close": each train which is blocked by another deadlocked train becomes deadlocked itself.
still_changing = True
while still_changing:
still_changing = False # will be overwritten below if a change did occur
# check if for any agent, there is a new deadlock found
for agent in relevant_agents:
#possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
#print(f"possible transitions: {possible_transitions}")
# the directions are: 0 (north), 1(east), 2(south) and 3(west)
#possible_directions = [direction for direction, flag in enumerate(possible_transitions) if flag == 1]
#print(f"possible directions: {possible_directions}")
new_position, new_direction = env_utils.apply_action_independent(action=action_dict[agent.handle], rail=env.rail, position=agent.position, direction=agent.direction)
#assert new_direction in possible_directions, "Error, action leads to impossible direction"
assert new_position == get_new_position(agent.position, new_direction), "Error, something is wrong with new position"
opposed_agent_id = env.agent_positions[new_position]
if opposed_agent_id != -1: # there is an opposed agent there
opposed_agent = env.agents[opposed_agent_id]
if opposed_agent in deadlocked:
if agent not in deadlocked: # to avoid duplicates
deadlocked.append(agent)
still_changing = True
self.deadlocked_agents = deadlocked
return deadlocked
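# Hedged usage sketch (not part of the original module): wires the checker into a
# random rollout. The environment factory (env_generators.small_v0, as used in the
# PettingZoo scripts above) and the random action policy are illustrative; the
# checker expects the action dict that was just applied in env.step.
if __name__ == "__main__":
    import numpy as np
    from flatland.contrib.utils import env_generators
    from flatland.envs.observations import TreeObsForRailEnv
    env = env_generators.small_v0(10, TreeObsForRailEnv(max_depth=2))
    env.reset()
    checker = Deadlock_Checker(env)
    for _ in range(50):
        action_dict = {agent.handle: np.random.randint(0, 5) for agent in env.agents}
        _, _, dones, _ = env.step(action_dict)
        deadlocked = checker.check_deadlocks(action_dict)
        if deadlocked:
            print("deadlocked agents:", [agent.handle for agent in deadlocked])
            break
        if dones["__all__"]:
            break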