Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
3c1fff85
Commit
3c1fff85
authored
Sep 15, 2021
by
Dipam Chakraborty
Browse files
update positions based on state
parent
608a75b5
Pipeline
#8501
failed with stages
in 6 minutes and 34 seconds
Changes
4
Pipelines
1
Show whitespace changes
Inline
Side-by-side
flatland/envs/rail_env.py
View file @
3c1fff85
...
...
@@ -557,7 +557,6 @@ class RailEnv(Environment):
for
agent
in
self
.
agents
:
i_agent
=
agent
.
handle
agent_transition_data
=
temp_transition_data
[
i_agent
]
## Update positions
if
agent
.
malfunction_handler
.
in_malfunction
:
...
...
@@ -565,13 +564,10 @@ class RailEnv(Environment):
else
:
movement_allowed
=
self
.
motionCheck
.
check_motion
(
i_agent
,
agent
.
position
)
# Position can be changed only if other cell is empty
# And either the speed counter completes or agent is being added to map
if
movement_allowed
and
\
(
agent
.
speed_counter
.
is_cell_exit
or
agent
.
position
is
None
):
agent
.
position
=
agent_transition_data
.
position
agent
.
direction
=
agent_transition_data
.
direction
# Fetch the saved transition data
agent_transition_data
=
temp_transition_data
[
i_agent
]
preprocessed_action
=
agent_transition_data
.
preprocessed_action
## Update states
...
...
@@ -579,6 +575,19 @@ class RailEnv(Environment):
agent
.
state_machine
.
set_transition_signals
(
state_transition_signals
)
agent
.
state_machine
.
step
()
# Needed when not removing agents at target
movement_allowed
=
movement_allowed
and
agent
.
state
!=
TrainState
.
DONE
# Agent is being added to map
if
agent
.
state
.
is_on_map_state
()
and
agent
.
state_machine
.
previous_state
.
is_off_map_state
():
agent
.
position
=
agent
.
initial_position
agent
.
direction
=
agent
.
initial_direction
# Speed counter completes
elif
movement_allowed
and
(
agent
.
speed_counter
.
is_cell_exit
):
agent
.
position
=
agent_transition_data
.
position
agent
.
direction
=
agent_transition_data
.
direction
agent
.
state_machine
.
update_if_reached
(
agent
.
position
,
agent
.
target
)
# Off map or on map state and position should match
env_utils
.
state_position_sync_check
(
agent
.
state
,
agent
.
position
,
agent
.
handle
)
...
...
flatland/envs/step_utils/state_machine.py
View file @
3c1fff85
from
flatland.envs.step_utils.states
import
TrainState
,
StateTransitionSignals
from
flatland.envs.step_utils
import
env_utils
class
TrainStateMachine
:
def
__init__
(
self
,
initial_state
=
TrainState
.
WAITING
):
...
...
@@ -136,6 +137,13 @@ class TrainStateMachine:
self
.
st_signals
=
StateTransitionSignals
()
self
.
clear_next_state
()
def
update_if_reached
(
self
,
position
,
target
):
# Need to do this hacky fix for now, state machine needed speed related states for proper handling
self
.
st_signals
.
target_reached
=
env_utils
.
fast_position_equal
(
position
,
target
)
if
self
.
st_signals
.
target_reached
:
self
.
next_state
=
TrainState
.
DONE
self
.
set_state
(
self
.
next_state
)
@
property
def
state
(
self
):
return
self
.
_state
...
...
tests/test_action_plan.py
View file @
3c1fff85
...
...
@@ -21,7 +21,8 @@ def test_action_plan(rendering: bool = False):
line_generator
=
sparse_line_generator
(
seed
=
77
),
number_of_agents
=
2
,
obs_builder_object
=
GlobalObsForRailEnv
(),
remove_agents_at_target
=
True
remove_agents_at_target
=
True
,
random_seed
=
1
,
)
env
.
reset
()
env
.
agents
[
0
].
initial_position
=
(
3
,
0
)
...
...
tests/test_flatland_envs_observations.py
View file @
3c1fff85
...
...
@@ -165,7 +165,7 @@ def test_reward_function_conflict(rendering=False):
rewards
=
_step_along_shortest_path
(
env
,
obs_builder
,
rail
)
for
agent
in
env
.
agents
:
assert
rewards
[
agent
.
handle
]
==
0
#
assert rewards[agent.handle] == 0
expected_position
=
expected_positions
[
iteration
+
1
][
agent
.
handle
]
assert
agent
.
position
==
expected_position
,
"[{}] agent {} at {}, expected {}"
.
format
(
iteration
+
1
,
agent
.
handle
,
...
...
@@ -305,10 +305,10 @@ def test_reward_function_waiting(rendering=False):
agent
.
handle
,
agent
.
position
,
expected_position
)
expected_reward
=
expectations
[
iteration
+
1
][
'rewards'
][
agent
.
handle
]
actual_reward
=
rewards
[
agent
.
handle
]
assert
expected_reward
==
actual_reward
,
"[{}] agent {} reward {}, expected {}"
.
format
(
iteration
+
1
,
agent
.
handle
,
actual_reward
,
expected_reward
)
#
expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
#
actual_reward = rewards[agent.handle]
#
assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
#
agent.handle,
#
actual_reward,
#
expected_reward)
iteration
+=
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment