Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
B
baselines
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
xzhaoma
baselines
Commits
507f0e86
Commit
507f0e86
authored
5 years ago
by
gmollard
Browse files
Options
Downloads
Patches
Plain Diff
added simple conflict detection
parent
b9836597
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
RLLib_training/RailEnvRLLibWrapper.py
+102
-32
102 additions, 32 deletions
RLLib_training/RailEnvRLLibWrapper.py
with
102 additions
and
32 deletions
RLLib_training/RailEnvRLLibWrapper.py
+
102
−
32
View file @
507f0e86
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_env
import
RailEnv
from
ray.rllib.env.multi_agent_env
import
MultiAgentEnv
from
ray.rllib.env.multi_agent_env
import
MultiAgentEnv
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.generators
import
random_rail_generator
from
ray.rllib.utils.seed
import
seed
as
set_seed
from
ray.rllib.utils.seed
import
seed
as
set_seed
from
flatland.envs.generators
import
complex_rail_generator
,
random_rail_generator
from
flatland.envs.generators
import
complex_rail_generator
,
random_rail_generator
import
numpy
as
np
import
numpy
as
np
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
class
RailEnvRLLibWrapper
(
MultiAgentEnv
):
class
RailEnvRLLibWrapper
(
MultiAgentEnv
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
# width,
# height,
# rail_generator=random_rail_generator(),
# number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super
(
MultiAgentEnv
,
self
).
__init__
()
super
(
MultiAgentEnv
,
self
).
__init__
()
if
hasattr
(
config
,
"
vector_index
"
):
if
hasattr
(
config
,
"
vector_index
"
):
vector_index
=
config
.
vector_index
vector_index
=
config
.
vector_index
else
:
else
:
vector_index
=
1
vector_index
=
1
self
.
predefined_env
=
False
if
config
[
'
rail_generator
'
]
==
"
complex_rail_generator
"
:
if
config
[
'
rail_generator
'
]
==
"
complex_rail_generator
"
:
self
.
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
config
[
'
number_of_agents
'
],
min_dist
=
5
,
self
.
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
config
[
'
number_of_agents
'
],
min_dist
=
5
,
nr_extra
=
config
[
'
nr_extra
'
],
seed
=
config
[
'
seed
'
]
*
(
1
+
vector_index
))
nr_extra
=
config
[
'
nr_extra
'
],
seed
=
config
[
'
seed
'
]
*
(
1
+
vector_index
))
else
:
elif
config
[
'
rail_generator
'
]
==
"
random_rail_generator
"
:
raise
(
Error
)
self
.
rail_generator
=
random_rail_generator
()
self
.
rail_generator
=
random_rail_generator
()
elif
config
[
'
rail_generator
'
]
==
"
load_env
"
:
self
.
predefined_env
=
True
else
:
raise
(
ValueError
,
f
'
Unknown rail generator:
{
config
[
"
rail_generator
"
]
}
'
)
set_seed
(
config
[
'
seed
'
]
*
(
1
+
vector_index
))
set_seed
(
config
[
'
seed
'
]
*
(
1
+
vector_index
))
self
.
env
=
RailEnv
(
width
=
config
[
"
width
"
],
height
=
config
[
"
height
"
],
self
.
env
=
RailEnv
(
width
=
config
[
"
width
"
],
height
=
config
[
"
height
"
],
number_of_agents
=
config
[
"
number_of_agents
"
],
number_of_agents
=
config
[
"
number_of_agents
"
],
obs_builder_object
=
config
[
'
obs_builder
'
],
rail_generator
=
self
.
rail_generator
)
obs_builder_object
=
config
[
'
obs_builder
'
],
rail_generator
=
self
.
rail_generator
,
prediction_builder_object
=
DummyPredictorForRailEnv
())
# self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
if
self
.
predefined_env
:
self
.
env
.
load
(
config
[
'
load_env_path
'
])
# '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
self
.
width
=
self
.
env
.
width
self
.
width
=
self
.
env
.
width
self
.
height
=
self
.
env
.
height
self
.
height
=
self
.
env
.
height
...
@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
...
@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
def
reset
(
self
):
def
reset
(
self
):
self
.
agents_done
=
[]
self
.
agents_done
=
[]
obs
=
self
.
env
.
reset
()
if
self
.
predefined_env
:
obs
=
self
.
env
.
reset
(
False
,
False
)
else
:
obs
=
self
.
env
.
reset
()
predictions
=
self
.
env
.
predict
()
pred_pos
=
np
.
concatenate
([[
x
[:,
1
:
3
]]
for
x
in
list
(
predictions
.
values
())],
axis
=
0
)
o
=
dict
()
o
=
dict
()
for
i_agent
in
range
(
len
(
self
.
env
.
agents
)):
#for agent, _ in obs.items():
#o[agent] = obs[agent]
# prediction of collision that will be added to the observation
# one_hot_agent_encoding = np.zeros(len(self.env.agents))
# Allows the agent to know which other train it is about to meet (maybe will come
# one_hot_agent_encoding[agent] += 1
# up with a priority order of trains).
# o[agent] = np.append(obs[agent], one_hot_agent_encoding)
pred_obs
=
np
.
zeros
((
len
(
predictions
[
0
]),
len
(
self
.
env
.
agents
)))
# o['agents'] = obs
for
time_offset
in
range
(
len
(
predictions
[
0
])):
# obs[0] = [obs[0], np.ones((17, 17)) * 17]
# obs['global_obs'] = np.ones((17, 17)) * 17
# We consider a time window of [t-1, t+1] (inclusive, clipped to the prediction horizon) to find a collision
collision_window
=
list
(
range
(
max
(
time_offset
-
1
,
0
),
min
(
time_offset
+
2
,
len
(
predictions
[
0
]))))
coord_agent
=
pred_pos
[
i_agent
,
time_offset
,
0
]
+
1000
*
pred_pos
[
i_agent
,
time_offset
,
1
]
# x coordinates of all other trains in the time window
x_coord_other_agents
=
pred_pos
[
list
(
range
(
i_agent
))
+
list
(
range
(
i_agent
+
1
,
len
(
self
.
env
.
agents
)))][
:,
collision_window
,
0
]
# y coordinates of all other trains in the time window
y_coord_other_agents
=
pred_pos
[
list
(
range
(
i_agent
))
+
list
(
range
(
i_agent
+
1
,
len
(
self
.
env
.
agents
)))][
:,
collision_window
,
1
]
coord_other_agents
=
x_coord_other_agents
+
1000
*
y_coord_other_agents
# collision_info here contains the index of the agent colliding with the current agent
for
collision_info
in
np
.
argwhere
(
coord_agent
==
coord_other_agents
)[:,
0
]:
pred_obs
[
time_offset
,
collision_info
+
1
*
(
collision_info
>=
i_agent
)]
=
1
agent_id_one_hot
=
np
.
zeros
(
len
(
self
.
env
.
agents
))
agent_id_one_hot
[
i_agent
]
=
1
o
[
i_agent
]
=
[
obs
[
i_agent
],
agent_id_one_hot
,
pred_obs
]
self
.
rail
=
self
.
env
.
rail
self
.
rail
=
self
.
env
.
rail
...
@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
...
@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
o
=
dict
()
o
=
dict
()
# print(self.agents_done)
# print(self.agents_done)
# print(dones)
# print(dones)
for
agent
,
done
in
dones
.
items
():
if
agent
not
in
self
.
agents_done
:
for
i_agent
in
range
(
len
(
self
.
env
.
agents
)):
if
agent
!=
'
__all__
'
:
if
i_agent
not
in
self
.
agents_done
:
# o[agent] = obs[agent]
# prediction of collision that will be added to the observation
#one_hot_agent_encoding = np.zeros(len(self.env.agents))
# Allows the agent to know which other train it is about to meet (maybe will come
#one_hot_agent_encoding[agent] += 1
# up with a priority order of trains).
o
[
agent
]
=
obs
[
agent
]
#np.append(obs[agent], one_hot_agent_encoding)
pred_obs
=
np
.
zeros
((
len
(
predictions
[
0
]),
len
(
self
.
env
.
agents
)))
r
[
agent
]
=
rewards
[
agent
]
for
time_offset
in
range
(
len
(
predictions
[
0
])):
d
[
agent
]
=
dones
[
agent
]
# We consider a time window of [t-1, t+1] (inclusive, clipped to the prediction horizon) to find a collision
collision_window
=
list
(
range
(
max
(
time_offset
-
1
,
0
),
min
(
time_offset
+
2
,
len
(
predictions
[
0
]))))
coord_agent
=
pred_pos
[
i_agent
,
time_offset
,
0
]
+
1000
*
pred_pos
[
i_agent
,
time_offset
,
1
]
# x coordinates of all other trains in the time window
x_coord_other_agents
=
pred_pos
[
list
(
range
(
i_agent
))
+
list
(
range
(
i_agent
+
1
,
len
(
self
.
env
.
agents
)))][
:,
collision_window
,
0
]
# y coordinates of all other trains in the time window
y_coord_other_agents
=
pred_pos
[
list
(
range
(
i_agent
))
+
list
(
range
(
i_agent
+
1
,
len
(
self
.
env
.
agents
)))][
:,
collision_window
,
1
]
coord_other_agents
=
x_coord_other_agents
+
1000
*
y_coord_other_agents
# collision_info here contains the index of the agent colliding with the current agent
for
collision_info
in
np
.
argwhere
(
coord_agent
==
coord_other_agents
)[:,
0
]:
pred_obs
[
time_offset
,
collision_info
+
1
*
(
collision_info
>=
i_agent
)]
=
1
agent_id_one_hot
=
np
.
zeros
(
len
(
self
.
env
.
agents
))
agent_id_one_hot
[
i_agent
]
=
1
o
[
i_agent
]
=
[
obs
[
i_agent
],
agent_id_one_hot
,
pred_obs
]
r
[
i_agent
]
=
rewards
[
i_agent
]
d
[
i_agent
]
=
dones
[
i_agent
]
d
[
'
__all__
'
]
=
dones
[
'
__all__
'
]
# for agent, done in dones.items():
# if agent not in self.agents_done:
# if agent != '__all__':
# # o[agent] = obs[agent]
# #one_hot_agent_encoding = np.zeros(len(self.env.agents))
# #one_hot_agent_encoding[agent] += 1
# o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding)
#
#
# d[agent] = dones[agent]
for
agent
,
done
in
dones
.
items
():
for
agent
,
done
in
dones
.
items
():
if
done
and
agent
!=
'
__all__
'
:
if
done
and
agent
!=
'
__all__
'
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment