Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
9301e30b
Commit
9301e30b
authored
Sep 13, 2021
by
Dipam Chakraborty
Browse files
malfunction fix to let previous malfunction finish
parent
6db05ca9
Changes
6
Hide whitespace changes
Inline
Side-by-side
flatland/envs/rail_env.py
View file @
9301e30b
...
...
@@ -372,7 +372,7 @@ class RailEnv(Environment):
""" Generate State Transitions Signals used in the state machine """
st_signals
=
StateTransitionSignals
()
# Malfunction
onset - Malfunction starts
# Malfunction
starts when in_malfunction is set to true
st_signals
.
malfunction_onset
=
agent
.
malfunction_handler
.
in_malfunction
# Malfunction counter complete - Malfunction ends next timestep
...
...
@@ -563,7 +563,8 @@ class RailEnv(Environment):
agent
.
state_machine
.
step
()
if
agent
.
state
.
is_on_map_state
()
and
agent
.
position
is
None
:
import
pdb
;
pdb
.
set_trace
()
raise
ValueError
(
"Agent ID {} Agent State {} not matching with Agent Position {} "
.
format
(
agent
.
handle
,
str
(
agent
.
state
),
str
(
agent
.
position
)
))
# Handle done state actions, optionally remove agents
self
.
handle_done_state
(
agent
)
...
...
flatland/envs/step_utils/malfunction_handler.py
View file @
9301e30b
...
...
@@ -30,7 +30,9 @@ class MalfunctionHandler:
def
_set_malfunction_down_counter
(
self
,
val
):
if
val
<
0
:
raise
ValueError
(
"Cannot set a negative value to malfunction down counter"
)
self
.
_malfunction_down_counter
=
val
# Only set new malfunction value if old malfunction is completed
if
self
.
_malfunction_down_counter
==
0
:
self
.
_malfunction_down_counter
=
val
def
generate_malfunction
(
self
,
malfunction_generator
,
np_random
):
num_broken_steps
=
get_number_of_steps_to_break
(
malfunction_generator
,
np_random
)
...
...
@@ -40,6 +42,10 @@ class MalfunctionHandler:
if
self
.
_malfunction_down_counter
>
0
:
self
.
_malfunction_down_counter
-=
1
def
__repr__
(
self
):
return
f
"malfunction_down_counter:
{
self
.
_malfunction_down_counter
}
\
in_malfunction:
{
self
.
in_malfunction
}
"
def
to_dict
(
self
):
return
{
"malfunction_down_counter"
:
self
.
_malfunction_down_counter
}
...
...
flatland/envs/step_utils/state_machine.py
View file @
9301e30b
...
...
@@ -13,7 +13,7 @@ class TrainStateMachine:
# TODO: Important - The malfunction handling is not like a proper state machine
# Both transition signals can happen at the same time
# At least mention it in the diagram
if
self
.
st_signals
.
malfunction
_onset
:
if
self
.
st_signals
.
in_
malfunction
:
self
.
next_state
=
TrainState
.
MALFUNCTION_OFF_MAP
elif
self
.
st_signals
.
earliest_departure_reached
:
self
.
next_state
=
TrainState
.
READY_TO_DEPART
...
...
@@ -22,7 +22,7 @@ class TrainStateMachine:
def
_handle_ready_to_depart
(
self
):
""" Can only go to MOVING if a valid action is provided """
if
self
.
st_signals
.
malfunction
_onset
:
if
self
.
st_signals
.
in_
malfunction
:
self
.
next_state
=
TrainState
.
MALFUNCTION_OFF_MAP
elif
self
.
st_signals
.
valid_movement_action_given
:
self
.
next_state
=
TrainState
.
MOVING
...
...
@@ -39,7 +39,7 @@ class TrainStateMachine:
self
.
next_state
=
TrainState
.
WAITING
def
_handle_moving
(
self
):
if
self
.
st_signals
.
malfunction
_onset
:
if
self
.
st_signals
.
in_
malfunction
:
self
.
next_state
=
TrainState
.
MALFUNCTION
elif
self
.
st_signals
.
target_reached
:
self
.
next_state
=
TrainState
.
DONE
...
...
@@ -49,7 +49,7 @@ class TrainStateMachine:
self
.
next_state
=
TrainState
.
MOVING
def
_handle_stopped
(
self
):
if
self
.
st_signals
.
malfunction
_onset
:
if
self
.
st_signals
.
in_
malfunction
:
self
.
next_state
=
TrainState
.
MALFUNCTION
elif
self
.
st_signals
.
valid_movement_action_given
:
self
.
next_state
=
TrainState
.
MOVING
...
...
flatland/envs/step_utils/states.py
View file @
9301e30b
...
...
@@ -27,7 +27,7 @@ class TrainState(IntEnum):
@
dataclass
(
repr
=
True
)
class
StateTransitionSignals
:
malfunction
_onset
:
bool
=
False
in_
malfunction
:
bool
=
False
malfunction_counter_complete
:
bool
=
False
earliest_departure_reached
:
bool
=
False
stop_action_given
:
bool
=
False
...
...
requirements_dev.txt
View file @
9301e30b
...
...
@@ -23,3 +23,4 @@ networkx
ipycanvas
graphviz
imageio
dataclasses
tests/test_utils.py
View file @
9301e30b
...
...
@@ -107,11 +107,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for
a
,
test_config
in
enumerate
(
test_configs
):
agent
:
EnvAgent
=
env
.
agents
[
a
]
replay
=
test_config
.
replay
[
step
]
print
(
agent
.
position
,
replay
.
position
,
agent
.
state
,
agent
.
speed_counter
)
# import pdb; pdb.set_trace()
# _assert(a, agent.position, replay.position, 'position')
# _assert(a, agent.direction, replay.direction, 'direction')
_assert
(
a
,
agent
.
position
,
replay
.
position
,
'position'
)
_assert
(
a
,
agent
.
direction
,
replay
.
direction
,
'direction'
)
if
replay
.
state
is
not
None
:
_assert
(
a
,
agent
.
state
,
replay
.
state
,
'state'
)
...
...
@@ -129,10 +127,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# As we force malfunctions on the agents, we have to set a positive rate so that the env
# recognizes the agent as potentially malfunctioning
# We also set next malfunction to infinity to avoid interference with our tests
agent
.
malfunction_data
[
'malfunction'
]
=
replay
.
set_malfunction
agent
.
malfunction_data
[
'moving_before_malfunction'
]
=
agent
.
moving
agent
.
malfunction_data
[
'fixed'
]
=
False
# _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
env
.
agents
[
a
].
malfunction_handler
.
_set_malfunction_down_counter
(
replay
.
set_malfunction
)
_assert
(
a
,
agent
.
malfunction_handler
.
malfunction_down_counter
,
replay
.
malfunction
,
'malfunction'
)
print
(
step
)
_
,
rewards_dict
,
_
,
info_dict
=
env
.
step
(
action_dict
)
if
rendering
:
...
...
@@ -143,8 +139,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
if
not
skip_reward_check
:
_assert
(
a
,
rewards_dict
[
a
],
replay
.
reward
,
'reward'
)
assert
False
def
create_and_save_env
(
file_name
:
str
,
line_generator
:
LineGenerator
,
rail_generator
:
RailGenerator
):
stochastic_data
=
MalfunctionParameters
(
malfunction_rate
=
1000
,
# Rate of malfunction occurrence
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment