Flatland / Flatland / Commits

Commit e3c821e5 authored Sep 15, 2021 by Dipam Chakraborty

agent_positions and docstrings and other cleanups

parent 08a30dbb

Changes 2
flatland/envs/rail_env.py
...
@@ -7,13 +7,11 @@ from typing import List, Optional, Dict, Tuple
import numpy as np
from gym.utils import seeding
from dataclasses import dataclass

from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
...
@@ -30,8 +28,8 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import transition_utils
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils


class RailEnv(Environment):
    """
...
@@ -110,7 +108,6 @@ class RailEnv(Environment):
                 remove_agents_at_target=True,
                 random_seed=1,
                 record_steps=False,
                 close_following=True):
        """
        Environment init.
...
@@ -178,16 +175,12 @@ class RailEnv(Environment):
        self.remove_agents_at_target = remove_agents_at_target

        self.rewards = [0] * number_of_agents
        self.done = False

        self.obs_builder = obs_builder_object
        self.obs_builder.set_env(self)

        self._max_episode_steps: Optional[int] = None
        self._elapsed_steps = 0

        self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

        self.obs_dict = {}
        self.rewards_dict = {}
        self.dev_obs_dict = {}
...
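Aside (not part of the diff): the dict.fromkeys pattern above creates one done flag per agent handle plus the global "__all__" key. A minimal standalone illustration:

# Standalone illustration of the dones initialisation used in __init__ above.
number_of_agents = 3
dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
print(dones)  # {0: False, 1: False, 2: False, '__all__': False}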
@@ -205,10 +198,7 @@ class RailEnv(Environment):
        if self.random_seed:
            self._seed(seed=random_seed)

        self.valid_positions = None

-        # global numpy array of agents position, True means that there is an agent at that cell
-        self.agent_positions: np.ndarray = np.full((height, width), False)
+        self.agent_positions = None

        # save episode timesteps ie agent positions, orientations. (not yet actions / observations)
        self.record_steps = record_steps  # whether to save timesteps
...
@@ -216,11 +206,8 @@ class RailEnv(Environment):
        self.cur_episode = []
        self.list_actions = []  # save actions in here

        self.close_following = close_following  # use close following logic
        self.motionCheck = ac.MotionCheck()

        self.agent_helpers = {}

    def _seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        random.seed(seed)
...
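Aside (not part of the diff): _seed above defers to gym's seeding helper and also seeds the stdlib RNG. A minimal sketch of the same calls, assuming the classic gym API in which seeding.np_random returns the RNG together with the seed actually used:

from gym.utils import seeding
import random

np_random, seed = seeding.np_random(42)  # RNG seeded with 42, plus the seed that was used
random.seed(seed)                        # mirrors RailEnv._seed: seed the stdlib RNG as well
print(seed, np_random.uniform())         # reproducible draw for a fixed seed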
@@ -229,7 +216,7 @@ class RailEnv(Environment):
    # no more agent_handles
    def get_agent_handles(self):
        return range(self.get_num_agents())

    def get_num_agents(self) -> int:
        return len(self.agents)
...
@@ -337,9 +324,6 @@ class RailEnv(Environment):
                agent.latest_arrival = timetable.latest_arrivals[agent_i]
        else:
            self.distance_map.reset(self.agents, self.rail)

-        # Agent Positions Map
-        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1

        # Reset agents to initial states
        self.reset_agents()
...
@@ -347,7 +331,10 @@ class RailEnv(Environment):
        self.num_resets += 1
        self._elapsed_steps = 0

        # TODO perhaps dones should be part of each agent.

+        # Agent positions map
+        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
+        self._update_agent_positions_map(ignore_old_positions=False)

        self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

        # Reset the state of the observation builder with the new environment
...
@@ -362,14 +349,16 @@ class RailEnv(Environment):
        if hasattr(self, "renderer") and self.renderer is not None:
            self.renderer = None
        return observation_dict, info_dict

-    def apply_action_independent(self, action, rail, position, direction):
-        if action.is_moving_action():
-            new_direction, _ = transition_utils.check_action(action, position, direction, rail)
-            new_position = get_new_position(position, new_direction)
-        else:
-            new_position, new_direction = position, direction
-        return new_position, new_direction
+    def _update_agent_positions_map(self, ignore_old_positions=True):
+        """ Update the agent_positions array for agents that changed positions """
+        for agent in self.agents:
+            if not ignore_old_positions or agent.old_position != agent.position:
+                self.agent_positions[agent.position] = agent.handle
+                if agent.old_position is not None:
+                    self.agent_positions[agent.old_position] = -1

    def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
        """ Generate State Transitions Signals used in the state machine """
...
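Aside (not part of the diff): a numpy-only sketch of the bookkeeping that _update_agent_positions_map performs, assuming the convention visible in this commit (-1 marks an empty cell, any other value is the handle of the occupying agent). The grid size and agent movement below are made-up examples:

import numpy as np

height, width = 4, 5
agent_positions = np.zeros((height, width), dtype=int) - 1   # -1 everywhere: no agent on any cell

# A hypothetical agent with handle 0 occupies cell (2, 3) ...
agent_positions[(2, 3)] = 0

# ... and then moves to (2, 4): write the new cell, clear the old one,
# which is what the per-agent loop in _update_agent_positions_map does.
old_position, new_position, handle = (2, 3), (2, 4), 0
agent_positions[new_position] = handle
if old_position is not None:
    agent_positions[old_position] = -1

print(agent_positions[(2, 4)], agent_positions[(2, 3)])  # 0 -1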
@@ -391,7 +380,7 @@ class RailEnv(Environment):
        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

        # Target Reached
-        st_signals.target_reached = fast_position_equal(agent.position, agent.target)
+        st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)

        # Movement conflict - Multiple trains trying to move into same cell
        # If speed counter is not in cell exit, the train can enter the cell
...
@@ -449,11 +438,18 @@ class RailEnv(Environment):
        """ Reset the rewards dictionary """
        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

-    def get_info_dict(self): # TODO Important : Update this
+    def get_info_dict(self):
+        """
+        Returns dictionary of infos for all agents
+        dict_keys : action_required -
+                    malfunction - Counter value for malfunction > 0 means train is in malfunction
+                    speed - Speed of the train
+                    state - State from the trains's state machine
+        """
        info_dict = {
            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
-            'malfunction': {i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)},
+            'malfunction': {i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)},
            'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
            'state': {i: agent.state for i, agent in enumerate(self.agents)}
...
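Aside (not part of the diff): a small sketch of consuming the per-agent fields documented in the get_info_dict docstring above. The dictionary here is a hand-built stand-in with made-up values, not output captured from the environment:

# Hand-built stand-in mirroring the keys documented above (values are invented).
info_dict = {
    'action_required': {0: True, 1: False},
    'malfunction':     {0: 0, 1: 3},        # > 0 means the train is currently in malfunction
    'speed':           {0: 1.0, 1: 0.5},
    'state':           {0: 'MOVING', 1: 'MALFUNCTION'},  # the real values are TrainState members
}

for handle in info_dict['action_required']:
    if info_dict['action_required'][handle] and info_dict['malfunction'][handle] == 0:
        print(f"agent {handle} needs an action this step (speed {info_dict['speed'][handle]})")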
@@ -461,9 +457,16 @@ class RailEnv(Environment):
        return info_dict

    def update_step_rewards(self, i_agent):
+        """
+        Update the rewards dict for agent id i_agent for every timestep
+        """
        pass

    def end_of_episode_update(self, have_all_agents_ended):
+        """
+        Updates made when episode ends
+        Parameters: have_all_agents_ended - Indicates if all agents have reached done state
+        """
        if have_all_agents_ended or \
           ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
...
@@ -477,6 +480,7 @@ class RailEnv(Environment):
            self.dones["__all__"] = True

    def handle_done_state(self, agent):
+        """ Any updates to agent to be made in Done state """
        if agent.state == TrainState.DONE:
            agent.arrival_time = self._elapsed_steps
            if self.remove_agents_at_target:
...
@@ -528,7 +532,7 @@ class RailEnv(Environment):
            elif agent.action_saver.is_action_saved and position_update_allowed:
                saved_action = agent.action_saver.saved_action
                # Apply action independent of other agents and get temporary new position and direction
-                new_position, new_direction = self.apply_action_independent(saved_action,
+                new_position, new_direction = env_utils.apply_action_independent(saved_action,
                                                                             self.rail,
                                                                             agent.position,
                                                                             agent.direction)
...
@@ -536,7 +540,7 @@ class RailEnv(Environment):
            else:
                new_position, new_direction = agent.position, agent.direction

-            temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
+            temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
                                                                 direction=new_direction,
                                                                 preprocessed_action=preprocessed_action)
...
@@ -571,7 +575,7 @@ class RailEnv(Environment):
            agent.state_machine.step()

            # Off map or on map state and position should match
-            state_position_sync_check(agent.state, agent.position, agent.handle)
+            env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)

            # Handle done state actions, optionally remove agents
            self.handle_done_state(agent)
...
@@ -593,11 +597,14 @@ class RailEnv(Environment):
        # Check if episode has ended and update rewards and dones
        self.end_of_episode_update(have_all_agents_ended)

+        self._update_agent_positions_map
        return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

    def record_timestep(self, dActions):
-        ''' Record the positions and orientations of all agents in memory, in the cur_episode
-        '''
+        """
+        Record the positions and orientations of all agents in memory, in the cur_episode
+        """
        list_agents_state = []
        for i_agent in range(self.get_num_agents()):
            agent = self.agents[i_agent]
...
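Aside (not part of the diff): step() above returns the (observations, rewards, dones, info) tuple and reset() returns (observation_dict, info_dict), so a driver loop can be written as below. This is a hedged sketch that assumes an already constructed RailEnv instance env; it only uses calls visible in this file plus RailEnvActions.MOVE_FORWARD:

from flatland.envs.rail_env_action import RailEnvActions

def run_episode(env, max_steps=100):
    """Drive every agent forward until all are done or max_steps is reached."""
    observations, info = env.reset()
    rewards, dones = {}, {"__all__": False}
    for _ in range(max_steps):
        action_dict = {handle: RailEnvActions.MOVE_FORWARD
                       for handle in env.get_agent_handles()}
        observations, rewards, dones, info = env.step(action_dict)
        if dones["__all__"]:              # set by end_of_episode_update()
            break
    return observations, rewards, dones, info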
@@ -610,7 +617,7 @@ class RailEnv(Environment):
            # print("pos:", pos, type(pos[0]))
-            list_agents_state.append([*pos, int(agent.direction), agent.malfunction_data["malfunction"], int(agent.status), int(agent.position in self.motionCheck.svDeadlocked)])
+            list_agents_state.append([*pos, int(agent.direction), agent.malfunction_handler.malfunction_down_counter, int(agent.status), int(agent.position in self.motionCheck.svDeadlocked)])
...
@@ -620,11 +627,7 @@ class RailEnv(Environment):
    def _get_observations(self):
        """
-        Utility which returns the observations for an agent with respect to environment
-        Returns
-        ------
-        Dict object
+        Utility which returns the dictionary of observations for an agent with respect to environment
        """
        # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
        self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
...
@@ -633,15 +636,6 @@ class RailEnv(Environment):
    def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
        """ Returns directions in which the agent can move
            Parameters:
            ---------
            row : int
            col : int
            Returns:
            -------
            List[int]
        """
        return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
...
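Aside (not part of the diff): a hedged usage sketch for the helper above, assuming an already constructed RailEnv instance env. The cell coordinates are made-up example values; the result is the List[int] of valid movement directions documented in the docstring:

# Hypothetical query against an already constructed RailEnv `env`.
row, col = 5, 7   # assumed example cell
valid_directions = env.get_valid_directions_on_grid(row, col)
print(valid_directions)   # directions in which an agent can move at this cell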
@@ -669,9 +663,10 @@ class RailEnv(Environment):
        """
        return agent.malfunction_handler.in_malfunction

    def save(self, filename):
-        print("deprecated call to env.save() - pls call RailEnvPersister.save()")
+        print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
        persistence.RailEnvPersister.save(self, filename)

    def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
...
@@ -747,31 +742,4 @@ class RailEnv(Environment):
                self.renderer.close_window()
            except Exception as e:
                print("Could Not close window due to:", e)
-            self.renderer = None
-
-@dataclass(repr=True)
-class AgentTransitionData:
-    """ Class for keeping track of temporary agent data for position update """
-    position : Tuple[int, int]
-    direction : Grid4Transitions
-    preprocessed_action : RailEnvActions
-
-# Adrian Egli performance fix (the fast methods brings more than 50%)
-def fast_isclose(a, b, rtol):
-    return (a < (b + rtol)) or (a < (b - rtol))
-
-def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
-    if pos_1 is None:
-        # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
-        return False
-    else:
-        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
-
-def state_position_sync_check(state, position, i_agent):
-    if state.is_on_map_state() and position is None:
-        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
-                         i_agent, str(state), str(position)))
-    elif state.is_off_map_state() and position is not None:
-        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
-                         i_agent, str(state), str(position)))
+            self.renderer = None
\ No newline at end of file
flatland/envs/step_utils/env_utils.py
0 → 100644
from dataclasses import dataclass
from typing import Tuple

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import transition_utils
from flatland.envs.rail_env_action import RailEnvActions
from flatland.core.grid.grid4 import Grid4Transitions


@dataclass(repr=True)
class AgentTransitionData:
    """ Class for keeping track of temporary agent data for position update """
    position : Tuple[int, int]
    direction : Grid4Transitions
    preprocessed_action : RailEnvActions


# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
    return (a < (b + rtol)) or (a < (b - rtol))


def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    if pos_1 is None:
        return False
    else:
        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def apply_action_independent(action, rail, position, direction):
    """ Apply the action on the train regardless of locations of other trains
        Checks for valid cells to move and valid rail transitions
        ---------------------------------------------------------------------
        Parameters: action - Action to execute
                    rail - Flatland env.rail object
                    position - current position of the train
                    direction - current direction of the train
        ---------------------------------------------------------------------
        Returns: new_position - New position after applying the action
                 new_direction - New direction after applying the action
    """
    if action.is_moving_action():
        new_direction, _ = transition_utils.check_action(action, position, direction, rail)
        new_position = get_new_position(position, new_direction)
    else:
        new_position, new_direction = position, direction
    return new_position, new_direction


def state_position_sync_check(state, position, i_agent):
    """ Check for whether on map and off map states are matching with position """
    if state.is_on_map_state() and position is None:
        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
                         i_agent, str(state), str(position)))
    elif state.is_off_map_state() and position is not None:
        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
                         i_agent, str(state), str(position)))
\ No newline at end of file
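Aside (not part of the diff): a hedged usage sketch for the new module, assuming the flatland package at this revision is importable; env and agent in the last call are placeholders for an already constructed RailEnv and one of its agents:

from flatland.envs.step_utils import env_utils
from flatland.envs.rail_env_action import RailEnvActions

# Position comparison treats an off-map (None) position as unequal to everything.
print(env_utils.fast_position_equal((3, 4), (3, 4)))   # True
print(env_utils.fast_position_equal(None, (3, 4)))     # False

# Hypothetical single-agent lookahead, ignoring every other train on the grid
# (env and agent are assumed to exist; see the calls in rail_env.py above).
new_position, new_direction = env_utils.apply_action_independent(
    RailEnvActions.MOVE_FORWARD, env.rail, agent.position, agent.direction)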