Flatland / Flatland / Commits / 4169a0f1

Commit 4169a0f1 authored Sep 10, 2021 by Dipam Chakraborty

fixes to env.step() direction update

parent e4399082
Pipeline #8454 failed with stages in 5 minutes and 5 seconds
Changes 6 / Pipelines 1
flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = agent.speed_counter.max_count
+        minimum_cell_time = agent.speed_counter.max_count + 1
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
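The + 1 mirrors the SpeedCounter change later in this commit: max_count is now int(1/speed) - 1 and counts from zero, so a train occupies a cell for max_count + 1 steps, and the action plan has to budget for that total. The arithmetic as a plain-Python sketch (not the Flatland API):

    # Minimum dwell time per cell under the new SpeedCounter semantics.
    def minimum_cell_time(speed):
        max_count = int(1 / speed) - 1   # new SpeedCounter.max_count
        return max_count + 1             # steps a train spends in one cell

    assert minimum_cell_time(1.0) == 1   # full-speed train: 1 step per cell
    assert minimum_cell_time(0.5) == 2   # half-speed train: 2 steps per cell
    assert minimum_cell_time(0.25) == 4  # quarter-speed train: 4 steps per cell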
flatland/action_plan/action_plan_player.py
@@ -30,10 +30,7 @@ class ControllerFromTrainrunsReplayer():
                 assert agent.position == waypoint.position, \
                     "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                     waypoint.position)
-                if agent_id == 1:
-                    print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
-
             actions = ctl.act(i)
             print("actions for {}: {}".format(i, actions))

             obs, all_rewards, done, _ = env.step(actions)
flatland/envs/observations.py
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                            agent.direction)],
                                              num_agents_same_direction=0, num_agents_opposite_direction=0,
                                              num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                             speed_min_fractional=agent.speed_counter.speed
+                                             speed_min_fractional=agent.speed_counter.speed,
                                              num_agents_ready_to_depart=0,
                                              childs={})
         #print("root node type:", type(root_node_observation))
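The only change is the trailing comma: without it, the speed_min_fractional keyword argument runs straight into num_agents_ready_to_depart on the next line and the call no longer parses. A minimal reproduction with a hypothetical function f (not the real tree-obs Node):

    def f(a, b):
        return a + b

    # f(a=1
    #   b=2)       # SyntaxError -- missing comma between keyword arguments
    print(f(a=1,
            b=2))  # with the comma restored, this prints 3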
flatland/envs/rail_env.py
@@ -366,9 +366,10 @@ class RailEnv(Environment):
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
-        return new_position, direction
+        return new_position, new_direction

     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
         """ Generate State Transitions Signals used in the state machine """
         st_signals = StateTransitionSignals()

         # Malfunction onset - Malfunction starts
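This is the fix the commit message refers to: the helper computed new_direction for the move but returned the stale direction, so a train taking a curve kept its old heading. A stripped-down stand-in (not the RailEnv method; directions follow Flatland's Grid4 convention, 0=N, 1=E, 2=S, 3=W):

    def turn_old(position, direction, new_direction):
        return position, direction       # bug: heading never updates

    def turn_fixed(position, direction, new_direction):
        return position, new_direction   # fix: heading follows the curve

    assert turn_old((3, 8), 1, 2)[1] == 1    # still heading East after a curve
    assert turn_fixed((3, 8), 1, 2)[1] == 2  # now correctly heading South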
@@ -442,9 +443,8 @@ class RailEnv(Environment):
         return action

     def clear_rewards_dict(self):
-        """ Reset the step rewards """
-        self.rewards_dict = dict()
+        """ Reset the rewards dictionary """
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

     def get_info_dict(self):
         # TODO Important : Update this
         info_dict = {
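clear_rewards_dict now pre-seeds a zero reward for every agent instead of returning an empty dict, presumably so that the later rewards_dict[i_agent] += reward never hits a missing key. For example:

    agents = ["a", "b", "c"]   # stand-in for self.agents
    rewards_dict = {i_agent: 0 for i_agent in range(len(agents))}
    assert rewards_dict == {0: 0, 1: 0, 2: 0}
    rewards_dict[1] += -1      # safe: the key already exists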
@@ -456,6 +456,22 @@ class RailEnv(Environment):
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict

+    def update_step_rewards(self, i_agent):
+        pass
+
+    def end_of_episode_update(self, have_all_agents_ended):
+        if have_all_agents_ended or \
+           ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
+            for i_agent, agent in enumerate(self.agents):
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+                self.dones[i_agent] = True
+            self.dones["__all__"] = True
+
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -520,6 +536,8 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]

+            old_position = agent.position
+
             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -544,30 +562,18 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.state == TrainState.DONE)

             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
+            self.update_step_rewards(i_agent)

             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state)
+            agent.speed_counter.update_counter(agent.state, old_position)
             agent.malfunction_handler.update_counter()

             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()

-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
-
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended:
-            for i_agent, agent in enumerate(self.agents):
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
-                self.dones[i_agent] = True
-            self.dones["__all__"] = True
+        # Check if episode has ended and update rewards and dones
+        self.end_of_episode_update(have_all_agents_ended)

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
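Taken together: old_position is captured in hunk @@ -520 before any movement, then handed to update_counter here so a train still entering the map is not ticked, and the inline end-of-episode block is replaced by the end_of_episode_update method added in hunk @@ -456. A compact check of that method's termination condition (plain Python mirroring the expression, not the RailEnv code):

    def episode_over(have_all_agents_ended, elapsed_steps, max_episode_steps):
        # Mirrors: have_all_agents_ended or
        #          ((max_episode_steps is not None) and (elapsed_steps >= max_episode_steps))
        return have_all_agents_ended or \
            ((max_episode_steps is not None) and (elapsed_steps >= max_episode_steps))

    assert episode_over(True, 10, None)        # all trains DONE ends the episode
    assert episode_over(False, 100, 100)       # so does hitting the step limit
    assert not episode_over(False, 99, 100)    # otherwise keep stepping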
flatland/envs/step_utils/speed_counter.py
@@ -4,12 +4,13 @@ from flatland.envs.step_utils.states import TrainState
 class SpeedCounter:
     def __init__(self, speed):
         self.speed = speed
-        self.max_count = int(1/speed)
+        self.max_count = int(1/speed) - 1

-    def update_counter(self, state):
-        if state == TrainState.MOVING:
+    def update_counter(self, state, old_position):
+        # When coming onto the map, do not update speed counter
+        if state == TrainState.MOVING and old_position is not None:
             self.counter += 1
-            self.counter = self.counter % self.max_count
+            self.counter = self.counter % (self.max_count + 1)

     def __repr__(self):
         return f"speed: {self.speed} \
@@ -27,5 +28,5 @@ class SpeedCounter:
     @property
     def is_cell_exit(self):
-        return self.counter == self.max_count - 1
+        return self.counter == self.max_count
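With both changes, a counter built for speed=0.5 has max_count = int(1/0.5) - 1 = 1 and cycles through {0, 1}: two steps per cell, with is_cell_exit true on the last one. Emulated in plain Python (the counter's initialisation is outside this diff, so starting at 0 is an assumption):

    speed = 0.5
    max_count = int(1 / speed) - 1               # = 1 after this commit
    counter = 0

    positions = [None, (3, 8), (3, 9), (3, 10)]  # None = still off the map
    for old_position in positions:
        # Assume TrainState.MOVING throughout; the spawn step (None) is skipped.
        if old_position is not None:
            counter = (counter + 1) % (max_count + 1)
        print(old_position, counter, counter == max_count)  # ..., is_cell_exit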
tests/test_action_plan.py
@@ -9,6 +9,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.speed_counter import SpeedCounter


 def test_action_plan(rendering: bool = False):

@@ -29,7 +30,7 @@ def test_action_plan(rendering: bool = False):
     env.agents[1].initial_position = (3, 8)
     env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
     env.agents[1].target = (0, 3)
-    env.agents[1].speed_data['speed'] = 0.5  # two
+    env.agents[1].speed_counter = SpeedCounter(speed=0.5)  # two
     env.reset(False, False)
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
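The test now pins the agent's fractional speed through the new SpeedCounter object instead of the retired speed_data dict, which is what the added import supports. The same pattern applies to any test that sets a speed; as a small helper (the helper name is mine, not from the repo):

    from flatland.envs.step_utils.speed_counter import SpeedCounter

    def set_agent_speed(env, handle, speed):
        # Replaces the retired: env.agents[handle].speed_data['speed'] = speed
        env.agents[handle].speed_counter = SpeedCounter(speed=speed)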