Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
cc04d63f
Commit
cc04d63f
authored
Sep 14, 2021
by
Dipam Chakraborty
Browse files
malfunction fix: can save actions during malfunction
parent
d4667187
Pipeline
#8481
failed with stages
in 4 minutes and 39 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/envs/malfunction_generators.py
View file @
cc04d63f
...
...
@@ -253,7 +253,7 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
min_number_of_steps_broken
=
parameters
.
min_duration
max_number_of_steps_broken
=
parameters
.
max_duration
def
generator
(
agent
:
EnvAgent
=
None
,
np_random
:
RandomState
=
None
,
reset
=
False
)
->
Optional
[
Malfunction
]:
def
generator
(
np_random
:
RandomState
=
None
,
reset
=
False
)
->
Optional
[
Malfunction
]:
"""
Generate malfunctions for agents
Parameters
...
...
@@ -270,11 +270,10 @@ def malfunction_from_params(parameters: MalfunctionParameters) -> Tuple[Malfunct
if
reset
:
return
Malfunction
(
0
)
if
agent
.
malfunction_data
[
'malfunction'
]
<
1
:
if
np_random
.
rand
()
<
_malfunction_prob
(
mean_malfunction_rate
):
num_broken_steps
=
np_random
.
randint
(
min_number_of_steps_broken
,
max_number_of_steps_broken
+
1
)
+
1
return
Malfunction
(
num_broken_steps
)
if
np_random
.
rand
()
<
_malfunction_prob
(
mean_malfunction_rate
):
num_broken_steps
=
np_random
.
randint
(
min_number_of_steps_broken
,
max_number_of_steps_broken
+
1
)
return
Malfunction
(
num_broken_steps
)
return
Malfunction
(
0
)
return
generator
,
MalfunctionProcessData
(
mean_malfunction_rate
,
min_number_of_steps_broken
,
...
...
flatland/envs/rail_env.py
View file @
cc04d63f
...
...
@@ -549,7 +549,7 @@ class RailEnv(Environment):
if
agent
.
malfunction_handler
.
in_malfunction
:
movement_allowed
=
False
else
:
movement_allowed
=
self
.
motionCheck
.
check_motion
(
i_agent
,
agent
.
position
)
# TODO: Remove final_new_position from motioncheck
movement_allowed
=
self
.
motionCheck
.
check_motion
(
i_agent
,
agent
.
position
)
# Position can be changed only if other cell is empty
# And either the speed counter completes or agent is being added to map
...
...
flatland/envs/step_utils/action_saver.py
View file @
cc04d63f
...
...
@@ -17,14 +17,10 @@ class ActionSaver:
"""
Save the action if all conditions are met
1. It is a movement based action -> Forward, Left, Right
2. Action is not already saved
3. Not in a malfunction state
4. Agent is not already done
2. Action is not already saved
3. Agent is not already done
"""
if
action
.
is_moving_action
()
and
\
not
self
.
is_action_saved
and
\
not
state
.
is_malfunction_state
()
and
\
not
state
==
TrainState
.
DONE
:
if
action
.
is_moving_action
()
and
not
self
.
is_action_saved
and
not
state
==
TrainState
.
DONE
:
self
.
saved_action
=
action
def
clear_saved_action
(
self
):
...
...
tests/test_flatland_malfunction.py
View file @
cc04d63f
...
...
@@ -190,8 +190,9 @@ def test_malfunction_before_entry():
# Test initial malfunction values for all agents
# we want some agents to be malfunctioning already and some to be working
# we want different next_malfunction values for the agents
assert
env
.
agents
[
0
].
malfunction_data
[
'malfunction'
]
==
0
assert
env
.
agents
[
1
].
malfunction_data
[
'malfunction'
]
==
10
malfunction_values
=
[
env
.
malfunction_generator
(
env
.
np_random
).
num_broken_steps
for
_
in
range
(
1000
)]
expected_value
=
(
1
-
np
.
exp
(
-
0.5
))
*
10
assert
np
.
allclose
(
np
.
mean
(
malfunction_values
),
expected_value
,
rtol
=
0.1
),
"Mean values of malfunction don't match rate"
def
test_malfunction_values_and_behavior
():
...
...
@@ -257,7 +258,7 @@ def test_initial_malfunction():
set_penalties_for_replay
(
env
)
replay_config
=
ReplayConfig
(
replay
=
[
Replay
(
Replay
(
# 0
position
=
(
3
,
2
),
direction
=
Grid4TransitionsEnum
.
EAST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
...
...
@@ -265,7 +266,7 @@ def test_initial_malfunction():
malfunction
=
3
,
reward
=
env
.
step_penalty
# full step penalty when malfunctioning
),
Replay
(
Replay
(
# 1
position
=
(
3
,
2
),
direction
=
Grid4TransitionsEnum
.
EAST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
...
...
@@ -274,7 +275,7 @@ def test_initial_malfunction():
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay
(
Replay
(
# 2
position
=
(
3
,
2
),
direction
=
Grid4TransitionsEnum
.
EAST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
...
...
@@ -282,14 +283,14 @@ def test_initial_malfunction():
reward
=
env
.
step_penalty
),
# malfunctioning ends: starting and running at speed 1.0
Replay
(
Replay
(
# 3
position
=
(
3
,
2
),
direction
=
Grid4TransitionsEnum
.
EAST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
1.0
# running at speed 1.0
),
Replay
(
Replay
(
# 4
position
=
(
3
,
3
),
direction
=
Grid4TransitionsEnum
.
EAST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
...
...
@@ -420,7 +421,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
2
,
reward
=
env
.
step_penalty
,
# full step penalty while malfunctioning
state
=
TrainState
.
ACTIVE
state
=
TrainState
.
MOVING
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
...
...
@@ -431,7 +432,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
1
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
state
=
TrainState
.
ACTIVE
state
=
TrainState
.
MOVING
),
# we haven't started moving yet --> stay here
Replay
(
...
...
@@ -440,7 +441,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
0
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
state
=
TrainState
.
ACTIVE
state
=
TrainState
.
MOVING
),
Replay
(
...
...
@@ -449,7 +450,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
1.0
,
# start penalty + step penalty for speed 1.0
state
=
TrainState
.
ACTIVE
state
=
TrainState
.
MOVING
),
# we start to move forward --> should go to next cell now
Replay
(
position
=
(
3
,
3
),
...
...
@@ -457,7 +458,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
step_penalty
*
1.0
,
# step penalty for speed 1.0
state
=
TrainState
.
ACTIVE
state
=
TrainState
.
MOVING
)
],
speed
=
env
.
agents
[
0
].
speed_counter
.
speed
,
...
...
@@ -546,7 +547,7 @@ def test_last_malfunction_step():
env
.
reset
(
False
,
False
)
for
a_idx
in
range
(
len
(
env
.
agents
)):
env
.
agents
[
a_idx
].
position
=
env
.
agents
[
a_idx
].
initial_position
env
.
agents
[
a_idx
].
state
=
TrainState
.
ACTIVE
env
.
agents
[
a_idx
].
state
=
TrainState
.
MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env
.
agents
[
0
].
malfunction_data
[
'next_malfunction'
]
=
2
env
.
agents
[
0
].
malfunction_data
[
'malfunction'
]
=
0
...
...
tests/test_utils.py
View file @
cc04d63f
...
...
@@ -107,7 +107,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for
a
,
test_config
in
enumerate
(
test_configs
):
agent
:
EnvAgent
=
env
.
agents
[
a
]
replay
=
test_config
.
replay
[
step
]
print
(
agent
.
position
,
replay
.
position
,
agent
.
state
,
agent
.
speed_counter
)
_assert
(
a
,
agent
.
position
,
replay
.
position
,
'position'
)
_assert
(
a
,
agent
.
direction
,
replay
.
direction
,
'direction'
)
if
replay
.
state
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment