Commit 3a17887c authored Sep 05, 2021 by nilabha

Flatland3 pettingzoo

Parent: a6923d7e
Changes: 7

flatland/contrib/interface/flatland_env.py (new file, mode 100644)

import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv

"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""


def parallel_wrapper_fn(env_fn):
    def par_fn(**kwargs):
        env = env_fn(**kwargs)
        env = custom_parallel_wrapper(env)
        return env
    return par_fn


def env(**kwargs):
    env = raw_env(**kwargs)
    # env = wrappers.AssertOutOfBoundsWrapper(env)
    # env = wrappers.OrderEnforcingWrapper(env)
    return env


parallel_env = parallel_wrapper_fn(env)


class custom_parallel_wrapper(to_parallel_wrapper):

    def step(self, actions):
        rewards = {a: 0 for a in self.aec_env.agents}
        dones = {}
        infos = {}
        observations = {}

        for agent in self.aec_env.agents:
            try:
                assert agent == self.aec_env.agent_selection, \
                    f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
            except Exception as e:
                # print(e)
                print(self.aec_env.dones.values())
                raise e
            obs, rew, done, info = self.aec_env.last()
            self.aec_env.step(actions.get(agent, 0))
            for agent in self.aec_env.agents:
                rewards[agent] += self.aec_env.rewards[agent]

        dones = dict(**self.aec_env.dones)
        infos = dict(**self.aec_env.infos)
        self.agents = self.aec_env.agents
        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
        return observations, rewards, dones, infos


class raw_env(AECEnv, gym.Env):

    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
                'video.frames_per_second': 10,
                'semantics.autoreset': False}

    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
        # EzPickle.__init__(self, *args, **kwargs)
        self._environment = environment
        self.use_renderer = use_renderer
        self.renderer = None
        if self.use_renderer:
            self.initialize_renderer()

        n_agents = self.num_agents
        self._agents = [get_agent_keys(i) for i in range(n_agents)]
        self._possible_agents = self.agents[:]
        self._reset_next_step = True

        self._agent_selector = agent_selector(self.agents)

        self.num_actions = 5

        self.action_spaces = {
            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
        }

        self.seed()
        # The preprocessor is meant for observation builders other than the global obs builder;
        # tree obs builders fall back to the default preprocessor if none is supplied.
        self.preprocessor = self._obtain_preprocessor(preprocessor)

        self._include_agent_info = agent_info

        # Observation space: flatland defines no observation space for an agent, so we infer it
        # from a returned observation. All agents are identical and share the same observation space.
        obs, _ = self._environment.reset(regenerate_rail=False, regenerate_schedule=False)
        obs = self.preprocessor(obs)
        self.observation_spaces = {
            i: infer_observation_space(ob) for i, ob in obs.items()
        }

    @property
    def environment(self) -> RailEnv:
        """Returns the wrapped environment."""
        return self._environment

    @property
    def dones(self):
        dones = self._environment.dones
        # remove_all = dones.pop("__all__", None)
        return {get_agent_keys(key): value for key, value in dones.items()}

    @property
    def obs_builder(self):
        return self._environment.obs_builder

    @property
    def width(self):
        return self._environment.width

    @property
    def height(self):
        return self._environment.height

    @property
    def agents_data(self):
        """Rail Env Agents data."""
        return self._environment.agents

    @property
    def num_agents(self) -> int:
        """Returns the number of trains/agents in the flatland environment."""
        return int(self._environment.number_of_agents)

    # def __getattr__(self, name):
    #     """Expose any other attributes of the underlying environment."""
    #     return getattr(self._environment, name)

    @property
    def agents(self):
        return self._agents

    @property
    def possible_agents(self):
        return self._possible_agents

    def env_done(self):
        return self._environment.dones["__all__"] or not self.agents

    def observe(self, agent):
        return self.obs.get(agent)

    def last(self, observe=True):
        '''
        Returns observation, reward, done, info for the current agent (specified by self.agent_selection).
        '''
        agent = self.agent_selection
        observation = self.observe(agent) if observe else None
        return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)

    def seed(self, seed: int = None) -> None:
        self._environment._seed(seed)

    def state(self):
        '''
        Returns an observation of the global environment.
        '''
        return None

    def _clear_rewards(self):
        '''
        Clears all items in .rewards.
        '''
        # pass
        for agent in self.rewards:
            self.rewards[agent] = 0

    def reset(self, *args, **kwargs):
        self._reset_next_step = False
        self._agents = self.possible_agents[:]
        if self.use_renderer:
            if self.renderer:
                # TODO: Errors with RLLib with renderer as None.
                self.renderer.reset()
        obs, info = self._environment.reset(*args, **kwargs)
        observations = self._collate_obs_and_info(obs, info)
        self._agent_selector.reinit(self.agents)
        self.agent_selection = self._agent_selector.next()
        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
        self.action_dict = {get_agent_handle(i): 0 for i in self.possible_agents}

        return observations

    def step(self, action):
        if self.env_done():
            self._agents = []
            self._reset_next_step = True
            return self.last()
        agent = self.agent_selection
        self.action_dict[get_agent_handle(agent)] = action

        if self.dones[agent]:
            # Disabled.. In case we want to remove agents once done
            # if self.remove_agents:
            #     self.agents.remove(agent)
            if self._agent_selector.is_last():
                observations, rewards, dones, infos = self._environment.step(self.action_dict)
                self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
                if observations:
                    observations = self._collate_obs_and_info(observations, infos)
                self._accumulate_rewards()
                obs, cumulative_reward, done, info = self.last()
                self.agent_selection = self._agent_selector.next()
            else:
                self._clear_rewards()
                obs, cumulative_reward, done, info = self.last()
                self.agent_selection = self._agent_selector.next()
            return obs, cumulative_reward, done, info

        if self._agent_selector.is_last():
            observations, rewards, dones, infos = self._environment.step(self.action_dict)
            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
            if observations:
                observations = self._collate_obs_and_info(observations, infos)
        else:
            self._clear_rewards()

        # self._cumulative_rewards[agent] = 0
        self._accumulate_rewards()
        obs, cumulative_reward, done, info = self.last()
        self.agent_selection = self._agent_selector.next()

        return obs, cumulative_reward, done, info

    # Collate agent info and observation into a tuple, making the agent's observation
    # a tuple of the observation from the env and the agent info.
    def _collate_obs_and_info(self, observes, info):
        observations = {}
        infos = {}
        observes = self.preprocessor(observes)
        for agent, obs in observes.items():
            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
            agent_info = np.array(list(all_infos.values()), dtype=np.float32)
            infos[agent] = all_infos
            obs = (obs, agent_info) if self._include_agent_info else obs
            observations[agent] = obs

        self.infos = infos
        self.obs = observations
        return observations

    def render(self, mode='human'):
        """
        This method provides the option to render the environment's behavior
        to a window which should be readable to the human eye if mode is set
        to 'human'.
        """
        if not self.use_renderer:
            return

        if not self.renderer:
            self.initialize_renderer(mode=mode)

        return self.update_renderer(mode=mode)

    def initialize_renderer(self, mode="human"):
        # Initiate the renderer
        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
                                   agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                                   show_debug=False,
                                   # Adjust these parameters to fit your resolution
                                   screen_height=600,
                                   screen_width=800)
        self.renderer.show = False

    def update_renderer(self, mode='human'):
        image = self.renderer.render_env(show=False, show_observations=False,
                                         show_predictions=False, return_image=True)
        return image[:, :, :3]

    def set_renderer(self, renderer):
        self.use_renderer = renderer
        if self.use_renderer:
            self.initialize_renderer(mode=self.use_renderer)

    def close(self):
        # self._environment.close()
        if self.renderer:
            try:
                if self.renderer.show:
                    self.renderer.close_window()
            except Exception as e:
                print("Could not close window due to:", e)
            self.renderer = None

    def _obtain_preprocessor(self, preprocessor):
        """Obtains the actual preprocessor to be used based on the supplied
        preprocessor and the env's obs_builder object."""
        if not isinstance(self.obs_builder, GlobalObsForRailEnv):
            _preprocessor = preprocessor if preprocessor else lambda x: x
            if isinstance(self.obs_builder, TreeObsForRailEnv):
                _preprocessor = (
                    partial(normalize_observation, tree_depth=self.obs_builder.max_depth)
                    if not preprocessor
                    else preprocessor
                )
            assert _preprocessor is not None
        else:
            def _preprocessor(x):
                return x

        def returned_preprocessor(obs):
            temp_obs = {}
            for agent_id, ob in obs.items():
                temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
            return temp_obs

        return returned_preprocessor


# Utility functions
def convert_np_type(dtype, value):
    return np.dtype(dtype).type(value)


def get_agent_handle(id):
    """Obtain an agent's handle (int) given its id."""
    return int(id)


def get_agent_keys(id):
    """Obtain an agent's key (str) given its id."""
    return str(id)
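
For reference, raw_env subclasses AECEnv, so the wrapper above can be driven with the standard PettingZoo agent_iter loop, the same pattern the stable-baselines script further down uses for inference. A minimal interaction sketch with a random placeholder policy; the small_v0 generator and the reset/step calls mirror the training scripts below, everything else is illustrative:

from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv

# Assumption: a small rail environment built the same way as in the training scripts below.
rail_env = env_generators.small_v0(10, TreeObsForRailEnv(max_depth=2))

env = flatland_env.env(environment=rail_env, use_renderer=False)
env.reset(random_seed=10)
for agent in env.agent_iter():
    obs, reward, done, info = env.last()
    # Random placeholder action; a trained policy would go here (see the training scripts below).
    action = env.action_spaces[agent].sample() if not done else None
    env.step(action)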

flatland/contrib/requirements_training.txt (new file, mode 100644)

id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
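
These training dependencies can presumably be installed in one step from the file above, for example:

pip install -r flatland/contrib/requirements_training.txt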

flatland/contrib/training/flatland_pettingzoo_rllib.py (new file, mode 100644)

from ray import tune
from ray.tune.registry import register_env
# from ray.rllib.utils import try_import_tf
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import numpy as np

from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"
rail_env = env_generators.small_v0(seed, observation_builder)

# __sphinx_doc_begin__


def env_creator(args):
    env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
    return env


if __name__ == "__main__":
    env_name = "flatland_pettyzoo"

    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))

    test_env = ParallelPettingZooEnv(env_creator({}))
    obs_space = test_env.observation_space
    act_space = test_env.action_space

    def gen_policy(i):
        config = {
            "gamma": 0.99,
        }
        return (None, obs_space, act_space, config)

    policies = {"policy_0": gen_policy(0)}

    policy_ids = list(policies.keys())

    tune.run(
        "PPO",
        name="PPO",
        stop={"timesteps_total": 5000000},
        checkpoint_freq=10,
        local_dir="~/ray_results/" + env_name,
        config={
            # Environment specific
            "env": env_name,
            # https://github.com/ray-project/ray/issues/10761
            "no_done_at_end": True,
            # "soft_horizon": True,
            "num_gpus": 0,
            "num_workers": 2,
            "num_envs_per_worker": 1,
            "compress_observations": False,
            "batch_mode": 'truncate_episodes',
            "clip_rewards": False,
            "vf_clip_param": 500.0,
            "entropy_coeff": 0.01,
            # Effective batch size: train_batch_size * num_agents_in_each_environment [5, 10],
            # see https://github.com/ray-project/ray/issues/4628
            "train_batch_size": 1000,  # 5000
            "rollout_fragment_length": 50,  # 100
            "sgd_minibatch_size": 100,  # 500
            "vf_share_layers": False
        },
    )
# __sphinx_doc_end__
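
Note that the policies dict built with gen_policy above is never passed to tune.run in this script, so RLlib falls back to its default single-policy setup. If the shared policy were to be wired in explicitly, a multiagent section roughly like the following could be merged into the config; this is a sketch only, not part of this commit, and the policy_mapping_fn signature varies across RLlib versions:

# Hypothetical addition: map every Flatland agent to the single shared policy "policy_0".
multiagent_config = {
    "multiagent": {
        "policies": policies,
        "policy_mapping_fn": lambda agent_id, *args, **kwargs: "policy_0",
    },
}
# This dict could then be merged into the `config` passed to tune.run above.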

flatland/contrib/training/flatland_pettingzoo_stable_baselines.py (new file, mode 100644)

import numpy as np
import os
import PIL
import PIL.Image  # PIL.Image is used below; import the submodule explicitly so the attribute is available
import shutil

from stable_baselines3.ppo import MlpPolicy
from stable_baselines3 import PPO
import supersuit as ss

from flatland.contrib.interface import flatland_env
from flatland.contrib.utils import env_generators
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import fnmatch
import wandb

"""
https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
"""

# Custom observation builder without predictor
# observation_builder = GlobalObsForRailEnv()

# Custom observation builder with predictor
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))

seed = 10
np.random.seed(seed)
wandb_log = False
experiment_name = "flatland_pettingzoo"

try:
    if os.path.isdir(experiment_name):
        shutil.rmtree(experiment_name)
    os.mkdir(experiment_name)
except OSError as e:
    print("Error: %s - %s." % (e.filename, e.strerror))

# rail_env = env_generators.sparse_env_small(seed, observation_builder)
rail_env = env_generators.small_v0(seed, observation_builder)

# __sphinx_doc_begin__

env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
# env = flatland_env.env(environment=rail_env, use_renderer=False)

if wandb_log:
    run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True,
                     config={}, name=experiment_name, save_code=True)

env_steps = 1000  # 2 * env.width * env.height  # Code uses 1.5 to calculate max_steps
rollout_fragment_length = 50
env = ss.pettingzoo_env_to_vec_env_v0(env)
# env.black_death = True
env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')

model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
            n_steps=rollout_fragment_length, ent_coef=0.01, learning_rate=5e-5, vf_coef=1,
            max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
            batch_size=150, seed=seed)

# wandb.watch(model.policy.action_net, log='all', log_freq=1)
# wandb.watch(model.policy.value_net, log='all', log_freq=1)

train_timesteps = 100000
model.learn(total_timesteps=train_timesteps)
model.save(f"policy_flatland_{train_timesteps}")
# __sphinx_doc_end__

model = PPO.load(f"policy_flatland_{train_timesteps}")

env = flatland_env.env(environment=rail_env, use_renderer=True)

if wandb_log:
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
    run.log_artifact(artifact)

# Model inference

seed = 100
env.reset(random_seed=seed)
step = 0
ep_no = 0
frame_list = []
while ep_no < 1:
    for agent in env.agent_iter():
        obs, reward, done, info = env.last()
        act = model.predict(obs, deterministic=True)[0] if not done else None
        env.step(act)
        frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
        step += 1
        if step % 100 == 0:
            print(f"env step: {step} and action taken: {act}")
            completion = env_generators.perc_completion(env)
            print("Agents Completed:", completion)

    completion = env_generators.perc_completion(env)
    print("Final Agents Completed:", completion)
    ep_no += 1
    frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True,
                       append_images=frame_list[1:], duration=3, loop=0)
    frame_list = []
    env.close()
    env.reset(random_seed=seed + ep_no)


def find(pattern, path):
    result = []
    for root, dirs, files in os.walk(path):
        for name in files:
            if fnmatch.fnmatch(name, pattern):
                result.append(os.path.join(root, name))
    return result


if wandb_log:
    extn = "gif"
    _video_file = f'*.{extn}'
    _found_videos = find(_video_file, experiment_name)
    print(_found_videos)
    for _found_video in _found_videos:
        wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
    run.join()

flatland/contrib/utils/env_generators.py (new file, mode 100644)

import logging
import random
import numpy as np

from typing import NamedTuple

from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, \
    ParamMalfunctionGen, no_malfunction_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.core.grid.grid4_utils import get_new_position

MalfunctionParameters = NamedTuple('MalfunctionParameters',
                                   [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])


def get_shortest_path_action(env, handle):
    distance_map = env.distance_map.get()

    agent = env.agents[handle]

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position