Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
e23d8f00
Commit
e23d8f00
authored
Oct 07, 2021
by
nimishsantosh107
Browse files
wrapper fixes - incomplete
parent
6d0f42a7
Pipeline
#8697
canceled with stages
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/contrib/wrappers/flatland_wrappers.py
View file @
e23d8f00
...
...
@@ -16,7 +16,8 @@ from flatland.core.grid.grid4_utils import get_new_position
# First of all we import the Flatland rail environment
from
flatland.utils.rendertools
import
RenderTool
,
AgentRenderVariant
from
flatland.envs.agent_utils
import
EnvAgent
,
RailAgentStatus
from
flatland.envs.agent_utils
import
EnvAgent
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
...
...
@@ -24,20 +25,13 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent
=
env
.
agents
[
handle
]
if
agent
.
stat
us
==
RailAgent
Stat
us
.
READY_TO_DEPART
:
if
agent
.
stat
e
==
Train
Stat
e
.
READY_TO_DEPART
:
agent_virtual_position
=
agent
.
initial_position
elif
agent
.
stat
us
==
RailAgentStatus
.
ACTIVE
:
elif
agent
.
stat
e
.
is_on_map_state
()
:
agent_virtual_position
=
agent
.
position
elif
agent
.
status
==
RailAgentStatus
.
DONE
:
agent_virtual_position
=
agent
.
target
else
:
print
(
"no action possible!"
)
if
agent
.
status
==
RailAgentStatus
.
DONE_REMOVED
:
print
(
f
"agent status: DONE_REMOVED for agent
{
agent
.
handle
}
"
)
print
(
"to solve this problem, do not input actions for removed agents!"
)
return
[(
RailEnvActions
.
DO_NOTHING
,
0
)]
*
2
print
(
"agent status:"
)
print
(
RailAgentStatus
(
agent
.
status
))
print
(
"agent status: "
,
agent
.
state
)
#return None
# NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...)
...
...
@@ -58,25 +52,18 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
elif
movement
==
(
agent
.
direction
-
1
)
%
4
:
action
=
RailEnvActions
.
MOVE_LEFT
else
:
# MICHEL: prints for debugging
print
(
f
"An error occured. movement is:
{
movement
}
, agent direction is:
{
agent
.
direction
}
"
)
if
movement
==
(
agent
.
direction
+
2
)
%
4
or
(
movement
==
agent
.
direction
-
2
)
%
4
:
print
(
"it seems that we are turning by 180 degrees. Turning in a dead end?"
)
# MICHEL: can this happen when we turn 180 degrees in a dead end?
# i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)?
# TRY OUT: ASSIGN MOVE_FORWARD HERE...
action
=
RailEnvActions
.
MOVE_FORWARD
print
(
"Here we would have a ValueError..."
)
#raise ValueError("Wtf, debug this shit.")
distance
=
distance_map
[
get_new_position
(
agent_virtual_position
,
movement
)
+
(
movement
,)]
possible_steps
.
append
((
action
,
distance
))
possible_steps
=
sorted
(
possible_steps
,
key
=
lambda
step
:
step
[
1
])
# MICHEL: what is this doing?
# if there is only one path to target, this is both the shortest one and the second shortest path.
if
len
(
possible_steps
)
==
1
:
return
possible_steps
*
2
...
...
@@ -186,16 +173,9 @@ class ShortestPathActionWrapper(RailEnvWrapper):
super
().
__init__
(
env
)
#self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction
# MICHEL: we have to make sure that not agents with agent.stat
us
== DONE_REMOVED are in the action dict.
# MICHEL: we have to make sure that not agents with agent.stat
e
== DONE_REMOVED are in the action dict.
# otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
def
step
(
self
,
action_dict
:
Dict
[
int
,
RailEnvActions
])
->
Tuple
[
Dict
,
Dict
,
Dict
,
Dict
]:
########## MICHEL: NEW (just for debugging) ########
for
agent_id
,
action
in
action_dict
.
items
():
agent
=
self
.
agents
[
agent_id
]
# assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
# assert agent.status != RailAgentStatus.DONE # not sure about this one...
print
(
f
"agent:
{
agent
}
with status:
{
agent
.
status
}
"
)
######################################################
# input: action dict with actions in [0, 1, 2].
transformed_action_dict
=
{}
...
...
@@ -207,21 +187,14 @@ class ShortestPathActionWrapper(RailEnvWrapper):
# MICHEL: how exactly do the indices work here?
#transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
#print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
#assert agent.status != RailAgentStatus.DONE_REMOVED
# MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
assert
possible_actions_sorted_by_distance
(
self
.
env
,
agent_id
)
is
not
None
assert
possible_actions_sorted_by_distance
(
self
.
env
,
agent_id
)[
action
-
1
]
is
not
None
transformed_action_dict
[
agent_id
]
=
possible_actions_sorted_by_distance
(
self
.
env
,
agent_id
)[
action
-
1
][
0
]
obs
,
rewards
,
dones
,
info
=
self
.
env
.
step
(
transformed_action_dict
)
return
obs
,
rewards
,
dones
,
info
#def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
#return self.rail_env.reset(random_seed)
# MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
#def reset(self, **kwargs) -> Tuple[Dict, Dict]:
# obs, info = self.env.reset(**kwargs)
# return obs, info
def
find_all_cells_where_agent_can_choose
(
env
:
RailEnv
):
...
...
@@ -236,19 +209,11 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
for
h
in
range
(
env
.
height
):
for
w
in
range
(
env
.
width
):
# MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES.
# will not show up in quadratic environments.
# should be pos = (h, w)
#pos = (w, h)
# MICHEL: changed this
pos
=
(
h
,
w
)
is_switch
=
False
# Check for switch: if there is more than one outgoing transition
for
orientation
in
directions
:
#print(f"env is: {env}")
#print(f"env.rail is: {env.rail}")
possible_transitions
=
env
.
rail
.
get_transitions
(
*
pos
,
orientation
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
if
num_transitions
>
1
:
...
...
@@ -386,15 +351,12 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
self
.
skipper
.
reset_cells
()
# TODO: this is clunky..
# for easier access / checking
self
.
switches
=
self
.
skipper
.
switches
self
.
switches_neighbors
=
self
.
skipper
.
switches_neighbors
self
.
decision_cells
=
self
.
skipper
.
decision_cells
self
.
skipped_rewards
=
self
.
skipper
.
skipped_rewards
# MICHEL: trying to isolate the core part and put it into a separate method.
def
step
(
self
,
action_dict
:
Dict
[
int
,
RailEnvActions
])
->
Tuple
[
Dict
,
Dict
,
Dict
,
Dict
]:
obs
,
rewards
,
dones
,
info
=
self
.
skipper
.
no_choice_skip_step
(
action_dict
=
action_dict
)
return
obs
,
rewards
,
dones
,
info
...
...
@@ -409,4 +371,4 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset().
self
.
skipper
.
reset_cells
()
return
obs
,
info
\ No newline at end of file
return
obs
,
info
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment