Flatland · Commits

Commit 2ff77f0e
authored Oct 12, 2021 by mmarti

refactored SkipNoChoiceCellsWrapper by removing the Skipper class

parent bdfd4f09
Pipeline #8732 failed with stages in 5 minutes and 42 seconds
Changes: 1 · Pipelines: 1
flatland/contrib/wrappers/flatland_wrappers.py @ 2ff77f0e
@@ -182,22 +182,26 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
     return tuple(map(set, (switches, switches_neighbors, decision_cells)))


-class NoChoiceCellsSkipper:
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
     def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        self.env = env
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
         self.switches = None
         self.switches_neighbors = None
         self.decision_cells = None
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
         self.skipped_rewards = defaultdict(list)

-        # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-        #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-        # compute and initialize value for switches, switches_neighbors, and decision_cells.
+        # sets initial values for switches, decision_cells, etc.
         self.reset_cells()

     def on_decision_cell(self, agent: EnvAgent) -> bool:
         return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
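The on_decision_cell() predicate above treats an agent as "on a decision cell" in three cases: it has not been placed on the grid yet (position is None), it still sits on its initial position, or its position is in the precomputed decision_cells set. A minimal illustration of that check, using types.SimpleNamespace as a stand-in for a real flatland EnvAgent (positions and cells are invented):

    from types import SimpleNamespace

    decision_cells = {(3, 4)}
    agent = SimpleNamespace(position=None, initial_position=(0, 0))

    def on_decision_cell(agent):
        # same three-way check as in the wrapper above
        return (agent.position is None
                or agent.position == agent.initial_position
                or agent.position in decision_cells)

    print(on_decision_cell(agent))   # True: not placed on the grid yet
    agent.position = (3, 4)
    print(on_decision_cell(agent))   # True: on a computed decision cell
    agent.position = (5, 6)
    print(on_decision_cell(agent))   # False: plain corridor cell, can be skipped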
@@ -207,7 +211,12 @@ class NoChoiceCellsSkipper:
     def next_to_switch(self, agent: EnvAgent) -> bool:
         return agent.position in self.switches_neighbors

-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
         o, r, d, i = {}, {}, {}, {}
         # NEED TO INITIALIZE i["..."]
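The "# NEED TO INITIALIZE i[...]" note flags that the nested info dicts are written into below (i["action_required"][agent_id] = ...) before those inner dicts exist. One way to address it, assuming the same three keys the method later fills; this is a sketch, not necessarily the fix the author had in mind:

    # hypothetical initialization; the keys mirror the assignments further down
    o, r, d = {}, {}, {}
    i = {"action_required": {}, "malfunction": {}, "speed": {}}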
@@ -222,12 +231,10 @@ class NoChoiceCellsSkipper:
         for agent_id, agent_obs in obs.items():
             if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
                 o[agent_id] = agent_obs
                 r[agent_id] = reward[agent_id]
                 d[agent_id] = done[agent_id]
                 i["action_required"][agent_id] = info["action_required"][agent_id]
                 i["malfunction"][agent_id] = info["malfunction"][agent_id]
                 i["speed"][agent_id] = info["speed"][agent_id]
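This loop only forwards results for agents that are done or standing on a decision cell; all other agents are left out of the returned dicts for this step. A toy version of that filtering with plain dicts (no flatland objects, all values invented):

    obs = {0: "obs0", 1: "obs1"}
    done = {0: False, 1: False}
    on_decision = {0: True, 1: False}   # stand-in for self.on_decision_cell(...)

    o = {a: ob for a, ob in obs.items() if done[a] or on_decision[a]}
    print(o)   # {0: 'obs0'} -- agent 1 produces no output this step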
@@ -235,8 +242,10 @@ class NoChoiceCellsSkipper:
                 if self.accumulate_skipped_rewards:
                     discounted_skipped_reward = r[agent_id]
                     for skipped_reward in reversed(self.skipped_rewards[agent_id]):
                         discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
                     r[agent_id] = discounted_skipped_reward
                     self.skipped_rewards[agent_id] = []
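Folding the skipped rewards in reverse with discounted_skipped_reward = discounting * discounted_skipped_reward + skipped_reward produces ordinary geometric discounting from the perspective of the last decision point: with skipped rewards [r1, r2] and current reward r3, the agent is credited r1 + g*r2 + g^2*r3, where g is the discounting factor. A quick numeric check (values invented):

    discounting = 0.9
    skipped = [1.0, 2.0]   # rewards collected on skipped steps, oldest first
    current = 3.0          # reward at the decision cell

    d = current
    for s in reversed(skipped):
        d = discounting * d + s

    # 1.0 + 0.9 * 2.0 + 0.9**2 * 3.0 == 5.23
    assert abs(d - 5.23) < 1e-9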
@@ -251,40 +260,10 @@ class NoChoiceCellsSkipper:
         return o, r, d, i

-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env: RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
-        self.skipper.reset_cells()
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
-
-    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
-
     # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
         obs, info = self.env.reset(**kwargs)
         # resets decision cells, switches, etc. These can change with an env.reset(...)!
         # needs to be done after env.reset().
-        self.skipper.reset_cells()
+        self.reset_cells()
         return obs, info
\ No newline at end of file
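After this commit the wrapper is self-contained: construct it around a RailEnv and call reset()/step() directly, with no intermediate NoChoiceCellsSkipper object. A minimal usage sketch; rail_env stands for an already-constructed flatland RailEnv (its generator arguments vary across flatland-rl versions, so construction is omitted), and the discounting value and constant action 2 (RailEnvActions.MOVE_FORWARD) are arbitrary choices for illustration:

    from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper

    # rail_env: an already-constructed flatland RailEnv (construction omitted)
    env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=True, discounting=0.99)

    obs, info = env.reset()   # also recomputes switches / decision cells via reset_cells()
    for _ in range(50):
        action_dict = {agent_id: 2 for agent_id in obs}   # 2 == RailEnvActions.MOVE_FORWARD
        obs, rewards, dones, info = env.step(action_dict)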