pranjal_dhole / Flatland / Commits

Commit 2ff77f0e
authored 3 years ago by mmarti

refactored SkipNoChoiceCellsWrapper by removing the Skipper class

parent bdfd4f09
Changes: 1 changed file

flatland/contrib/wrappers/flatland_wrappers.py (+64 additions, −85 deletions)
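In short: the commit removes the NoChoiceCellsSkipper helper class and folds its functionality (on_decision_cell, on_switch, next_to_switch, reset_cells, and the skip-stepping loop formerly in no_choice_skip_step) directly into SkipNoChoiceCellsWrapper. The wrapper's step() now runs the skipping loop itself, and reset() calls self.reset_cells() instead of delegating to self.skipper.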
@@ -181,110 +181,89 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
     decision_cells = switches + switches_neighbors
     return tuple(map(set, (switches, switches_neighbors, decision_cells)))


-class NoChoiceCellsSkipper:
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
     def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        self.env = env
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
         self.switches = None
         self.switches_neighbors = None
         self.decision_cells = None
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
         self.skipped_rewards = defaultdict(list)

-        # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-        #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+        # sets initial values for switches, decision_cells, etc.
+        # compute and initialize value for switches, switches_neighbors, and decision_cells.
         self.reset_cells()

     def on_decision_cell(self, agent: EnvAgent) -> bool:
         return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells

     def on_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches

     def next_to_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches_neighbors

-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
         o, r, d, i = {}, {}, {}, {}

         # NEED TO INITIALIZE i["..."]
         # as we will access i["..."][agent_id]
         i["action_required"] = dict()
         i["malfunction"] = dict()
         i["speed"] = dict()
         i["state"] = dict()

         while len(o) == 0:
             obs, reward, done, info = self.env.step(action_dict)

             for agent_id, agent_obs in obs.items():
                 if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
                     o[agent_id] = agent_obs
                     r[agent_id] = reward[agent_id]
                     d[agent_id] = done[agent_id]
                     i["action_required"][agent_id] = info["action_required"][agent_id]
                     i["malfunction"][agent_id] = info["malfunction"][agent_id]
                     i["speed"][agent_id] = info["speed"][agent_id]
                     i["state"][agent_id] = info["state"][agent_id]

                     if self.accumulate_skipped_rewards:
                         discounted_skipped_reward = r[agent_id]
                         for skipped_reward in reversed(self.skipped_rewards[agent_id]):
                             discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
                         r[agent_id] = discounted_skipped_reward
                         self.skipped_rewards[agent_id] = []

                 elif self.accumulate_skipped_rewards:
                     self.skipped_rewards[agent_id].append(reward[agent_id])
             # end of for-loop

             d['__all__'] = done['__all__']
             action_dict = {}
         # end of while-loop

         return o, r, d, i

-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
-
-        self.skipper.reset_cells()
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
-
-    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
-
     # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
         obs, info = self.env.reset(**kwargs)
         # resets decision cells, switches, etc. These can change with an env.reset(...)!
         # needs to be done after env.reset().
-        self.skipper.reset_cells()
+        self.reset_cells()
         return obs, info
\ No newline at end of file
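A minimal usage sketch of the refactored wrapper follows. It is not part of the commit; the module paths, generator imports, and RailEnv constructor arguments are assumptions based on flatland 3.x and may need adjusting for your installed version.

# Minimal usage sketch (not part of this commit). Module paths and generator
# arguments below are assumptions based on flatland 3.x.
from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator

# Build a small environment; width, height, and agent count are placeholders.
env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              line_generator=sparse_line_generator(),
              number_of_agents=2)

# After this commit the wrapper computes switches / switch neighbors /
# decision cells itself: once in __init__ and again on every reset().
wrapped = SkipNoChoiceCellsWrapper(env, accumulate_skipped_rewards=True, discounting=0.99)
obs, info = wrapped.reset()

# step() forwards the action dict to the inner env once, then keeps stepping
# with an empty action dict until some agent is done or stands on a decision
# cell; only those agents appear in the returned dicts.
actions = {agent_id: RailEnvActions.MOVE_FORWARD for agent_id in range(env.get_num_agents())}
obs, rewards, dones, info = wrapped.step(actions)

With accumulate_skipped_rewards enabled, rewards earned on skipped cells are not lost: for skipped rewards s_1, ..., s_n followed by the reward r observed on the decision cell, the reversed() loop in step() reports s_1 + γ·s_2 + ... + γ^(n-1)·s_n + γ^n·r, where γ is the discounting argument.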