Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
94feb039
Commit
94feb039
authored
Sep 14, 2021
by
Dipam Chakraborty
Browse files
position update allowed on cell exit and stopped state
parent
1d63feb8
Changes
4
Hide whitespace changes
Inline
Side-by-side
flatland/envs/rail_env.py
View file @
94feb039
...
...
@@ -431,7 +431,7 @@ class RailEnv(Environment):
* Block all actions when in waiting state
* Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD
"""
action
=
action_preprocessing
.
preprocess_raw_action
(
action
,
agent
.
state
)
action
=
action_preprocessing
.
preprocess_raw_action
(
action
,
agent
.
state
,
agent
.
action_saver
.
saved_action
)
action
=
action_preprocessing
.
preprocess_action_when_waiting
(
action
,
agent
.
state
)
# Try moving actions on current position
...
...
@@ -440,7 +440,6 @@ class RailEnv(Environment):
current_position
,
current_direction
=
agent
.
initial_position
,
agent
.
initial_direction
action
=
action_preprocessing
.
preprocess_moving_action
(
action
,
self
.
rail
,
current_position
,
current_direction
)
return
action
def
clear_rewards_dict
(
self
):
...
...
@@ -513,6 +512,9 @@ class RailEnv(Environment):
# Save moving actions in not already saved
agent
.
action_saver
.
save_action_if_allowed
(
preprocessed_action
,
agent
.
state
)
# Train's next position can change if current stopped in a fractional speed or train is at cell's exit
position_update_allowed
=
(
agent
.
speed_counter
.
is_cell_exit
or
agent
.
state
==
TrainState
.
STOPPED
)
# Calculate new position
# Add agent to the map if not on it yet
if
agent
.
position
is
None
and
agent
.
action_saver
.
is_action_saved
:
...
...
@@ -520,7 +522,7 @@ class RailEnv(Environment):
new_direction
=
agent
.
initial_direction
# If movement is allowed apply saved action independent of other agents
elif
agent
.
action_saver
.
is_action_saved
:
elif
agent
.
action_saver
.
is_action_saved
and
position_update_allowed
:
saved_action
=
agent
.
action_saver
.
saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position
,
new_direction
=
self
.
apply_action_independent
(
saved_action
,
...
...
@@ -557,7 +559,7 @@ class RailEnv(Environment):
(
agent
.
speed_counter
.
is_cell_exit
or
agent
.
position
is
None
):
agent
.
position
=
agent_transition_data
.
position
agent
.
direction
=
agent_transition_data
.
direction
preprocessed_action
=
agent_transition_data
.
preprocessed_action
## Update states
...
...
@@ -565,9 +567,8 @@ class RailEnv(Environment):
agent
.
state_machine
.
set_transition_signals
(
state_transition_signals
)
agent
.
state_machine
.
step
()
if
agent
.
state
.
is_on_map_state
()
and
agent
.
position
is
None
:
raise
ValueError
(
"Agent ID {} Agent State {} not matching with Agent Position {} "
.
format
(
agent
.
handle
,
str
(
agent
.
state
),
str
(
agent
.
position
)
))
# Off map or on map state and position should match
state_position_sync_check
(
agent
.
state
,
agent
.
position
,
agent
.
handle
)
# Handle done state actions, optionally remove agents
self
.
handle_done_state
(
agent
)
...
...
@@ -583,7 +584,7 @@ class RailEnv(Environment):
agent
.
malfunction_handler
.
update_counter
()
# Clear old action when starting in new cell
if
agent
.
speed_counter
.
is_cell_entry
:
if
agent
.
speed_counter
.
is_cell_entry
and
agent
.
position
is
not
None
:
agent
.
action_saver
.
clear_saved_action
()
# Check if episode has ended and update rewards and dones
...
...
@@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
return
False
else
:
return
pos_1
[
0
]
==
pos_2
[
0
]
and
pos_1
[
1
]
==
pos_2
[
1
]
def
state_position_sync_check
(
state
,
position
,
i_agent
):
if
state
.
is_on_map_state
()
and
position
is
None
:
raise
ValueError
(
"Agent ID {} Agent State {} is on map Agent Position {} if off map "
.
format
(
i_agent
,
str
(
state
),
str
(
position
)
))
elif
state
.
is_off_map_state
()
and
position
is
not
None
:
raise
ValueError
(
"Agent ID {} Agent State {} is off map Agent Position {} if on map "
.
format
(
i_agent
,
str
(
state
),
str
(
position
)
))
flatland/envs/step_utils/action_preprocessing.py
View file @
94feb039
...
...
@@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions):
return
RailEnvActions
(
action
)
def
process_do_nothing
(
state
:
TrainState
):
def
process_do_nothing
(
state
:
TrainState
,
saved_action
:
RailEnvActions
):
if
state
==
TrainState
.
MOVING
:
action
=
RailEnvActions
.
MOVE_FORWARD
elif
saved_action
:
action
=
saved_action
else
:
action
=
RailEnvActions
.
STOP_MOVING
return
action
...
...
@@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state):
return
action
def
preprocess_raw_action
(
action
,
state
):
def
preprocess_raw_action
(
action
,
state
,
saved_action
):
"""
Preprocesses actions to handle different situations of usage of action based on context
- DO_NOTHING is converted to FORWARD if train is moving
...
...
@@ -43,7 +45,7 @@ def preprocess_raw_action(action, state):
action
=
process_illegal_action
(
action
)
if
action
==
RailEnvActions
.
DO_NOTHING
:
action
=
process_do_nothing
(
state
)
action
=
process_do_nothing
(
state
,
saved_action
)
return
action
...
...
@@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction):
if
action
in
[
RailEnvActions
.
MOVE_LEFT
,
RailEnvActions
.
MOVE_RIGHT
]:
action
=
process_left_right
(
action
,
rail
,
position
,
direction
)
if
not
check_valid_action
(
action
,
rail
,
position
,
direction
):
action
=
RailEnvActions
.
STOP_MOVING
return
action
\ No newline at end of file
flatland/envs/step_utils/malfunction_handler.py
View file @
94feb039
...
...
@@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random):
class
MalfunctionHandler
:
def
__init__
(
self
):
self
.
_malfunction_down_counter
=
0
self
.
num_malfunctions
=
0
@
property
def
in_malfunction
(
self
):
...
...
@@ -33,6 +34,7 @@ class MalfunctionHandler:
# Only set new malfunction value if old malfunction is completed
if
self
.
_malfunction_down_counter
==
0
:
self
.
_malfunction_down_counter
=
val
self
.
num_malfunctions
+=
1
def
generate_malfunction
(
self
,
malfunction_generator
,
np_random
):
num_broken_steps
=
get_number_of_steps_to_break
(
malfunction_generator
,
np_random
)
...
...
@@ -44,16 +46,20 @@ class MalfunctionHandler:
def
__repr__
(
self
):
return
f
"malfunction_down_counter:
{
self
.
_malfunction_down_counter
}
\
in_malfunction:
{
self
.
in_malfunction
}
"
in_malfunction:
{
self
.
in_malfunction
}
\
num_malfunctions:
{
self
.
num_malfunctions
}
"
def
to_dict
(
self
):
return
{
"malfunction_down_counter"
:
self
.
_malfunction_down_counter
}
return
{
"malfunction_down_counter"
:
self
.
_malfunction_down_counter
,
"num_malfunctions"
:
self
.
num_malfunctions
}
def
from_dict
(
self
,
load_dict
):
self
.
_malfunction_down_counter
=
load_dict
[
'malfunction_down_counter'
]
self
.
num_malfunctions
=
load_dict
[
'num_malfunctions'
]
def
__eq__
(
self
,
other
):
return
self
.
_malfunction_down_counter
==
other
.
_malfunction_down_counter
return
self
.
_malfunction_down_counter
==
other
.
_malfunction_down_counter
and
\
self
.
num_malfunctions
==
other
.
num_malfunctions
...
...
flatland/envs/step_utils/state_machine.py
View file @
94feb039
...
...
@@ -31,12 +31,21 @@ class TrainStateMachine:
def
_handle_malfunction_off_map
(
self
):
if
self
.
st_signals
.
malfunction_counter_complete
:
if
self
.
st_signals
.
earliest_departure_reached
:
self
.
next_state
=
TrainState
.
READY_TO_DEPART
if
self
.
st_signals
.
valid_movement_action_given
:
self
.
next_state
=
TrainState
.
MOVING
elif
self
.
st_signals
.
stop_action_given
:
self
.
next_state
=
TrainState
.
STOPPED
else
:
self
.
next_state
=
TrainState
.
READY_TO_DEPART
else
:
self
.
next_state
=
TrainState
.
STOPPED
self
.
next_state
=
TrainState
.
WAITING
else
:
self
.
next_state
=
TrainState
.
WAITING
self
.
next_state
=
TrainState
.
MALFUNCTION_OFF_MAP
def
_handle_moving
(
self
):
if
self
.
st_signals
.
in_malfunction
:
...
...
@@ -61,7 +70,7 @@ class TrainStateMachine:
self
.
st_signals
.
valid_movement_action_given
:
self
.
next_state
=
TrainState
.
MOVING
elif
self
.
st_signals
.
malfunction_counter_complete
and
\
(
self
.
st_signals
.
stop_action_given
or
self
.
st_signals
.
movement_conflict
):
(
self
.
st_signals
.
stop_action_given
or
self
.
st_signals
.
movement_conflict
):
self
.
next_state
=
TrainState
.
STOPPED
else
:
self
.
next_state
=
TrainState
.
MALFUNCTION
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment