pranjal_dhole / Flatland / Commits / e3c821e5

Commit e3c821e5, authored 3 years ago by Dipam Chakraborty
    agent_positions and docstrings and other cleanups

parent 08a30dbb. No related branches, tags, or merge requests found.
Showing 2 changed files with 100 additions and 80 deletions:

  flatland/envs/rail_env.py               +48 −80
  flatland/envs/step_utils/env_utils.py   +52 −0
flatland/envs/rail_env.py  +48 −80
```diff
@@ -7,13 +7,11 @@ from typing import List, Optional, Dict, Tuple
 import numpy as np
 from gym.utils import seeding
-from dataclasses import dataclass
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4Transitions
-from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.distance_map import DistanceMap
```
```diff
@@ -30,8 +28,8 @@ from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
 from flatland.envs.step_utils import transition_utils
 from flatland.envs.step_utils import action_preprocessing
+from flatland.envs.step_utils import env_utils


 class RailEnv(Environment):
     """
```
```diff
@@ -110,7 +108,6 @@ class RailEnv(Environment):
                  remove_agents_at_target=True,
                  random_seed=1,
                  record_steps=False,
                  close_following=True):
         """
         Environment init.
```
```diff
@@ -178,16 +175,12 @@ class RailEnv(Environment):
         self.remove_agents_at_target = remove_agents_at_target

         self.rewards = [0] * number_of_agents
         self.done = False
         self.obs_builder = obs_builder_object
         self.obs_builder.set_env(self)

         self._max_episode_steps: Optional[int] = None
         self._elapsed_steps = 0

         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)

         self.obs_dict = {}
         self.rewards_dict = {}
         self.dev_obs_dict = {}
```
```diff
@@ -205,10 +198,7 @@ class RailEnv(Environment):
         if self.random_seed:
             self._seed(seed=random_seed)

         self.valid_positions = None

-        # global numpy array of agents position, True means that there is an agent at that cell
-        self.agent_positions: np.ndarray = np.full((height, width), False)
+        self.agent_positions = None

         # save episode timesteps ie agent positions, orientations. (not yet actions / observations)
         self.record_steps = record_steps  # whether to save timesteps
```
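Note on the change above: the boolean occupancy grid that used to be allocated in `__init__` is now just initialized to `None`; as the reset() hunks below show, the map is reallocated each episode as an integer grid of agent handles, with -1 marking a free cell. A minimal sketch of that convention (invented values, not flatland code):

```python
import numpy as np

# Minimal sketch of the new agent_positions convention (invented values).
# Each cell stores the handle of the occupying agent, or -1 when free,
# so a lookup tells you *which* agent sits there, not just whether one does.
height, width = 4, 5
agent_positions = np.zeros((height, width), dtype=int) - 1  # every cell free

agent_handle, position = 2, (1, 3)        # hypothetical agent at row 1, col 3
agent_positions[position] = agent_handle  # occupy the cell

assert agent_positions[1, 3] == 2    # cell occupied by agent 2
assert agent_positions[0, 0] == -1   # free cell
```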
```diff
@@ -216,11 +206,8 @@ class RailEnv(Environment):
         self.cur_episode = []
         self.list_actions = []  # save actions in here

         self.close_following = close_following  # use close following logic
         self.motionCheck = ac.MotionCheck()
         self.agent_helpers = {}

     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         random.seed(seed)
```
```diff
@@ -229,7 +216,7 @@ class RailEnv(Environment):
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())

     def get_num_agents(self) -> int:
         return len(self.agents)
```
```diff
@@ -337,9 +324,6 @@ class RailEnv(Environment):
                 agent.latest_arrival = timetable.latest_arrivals[agent_i]
         else:
             self.distance_map.reset(self.agents, self.rail)

-        # Agent Positions Map
-        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1

         # Reset agents to initial states
         self.reset_agents()
```
```diff
@@ -347,7 +331,10 @@ class RailEnv(Environment):
         self.num_resets += 1
         self._elapsed_steps = 0

         # TODO perhaps dones should be part of each agent.

+        # Agent positions map
+        self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
+        self._update_agent_positions_map(ignore_old_positions=False)

         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)

         # Reset the state of the observation builder with the new environment
```
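For reference, `dict.fromkeys` here produces one done flag per agent handle plus the global `"__all__"` flag; a self-contained sketch of the resulting structure for three agents:

```python
# Sketch: the dones bookkeeping built in reset() for 3 agents.
number_of_agents = 3
dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
assert dones == {0: False, 1: False, 2: False, "__all__": False}
```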
```diff
@@ -362,14 +349,16 @@ class RailEnv(Environment):
         if hasattr(self, "renderer") and self.renderer is not None:
             self.renderer = None

         return observation_dict, info_dict

-    def apply_action_independent(self, action, rail, position, direction):
-        if action.is_moving_action():
-            new_direction, _ = transition_utils.check_action(action, position, direction, rail)
-            new_position = get_new_position(position, new_direction)
-        else:
-            new_position, new_direction = position, direction
-        return new_position, new_direction
+    def _update_agent_positions_map(self, ignore_old_positions=True):
+        """
+        Update the agent_positions array for agents that changed positions
+        """
+        for agent in self.agents:
+            if not ignore_old_positions or agent.old_position != agent.position:
+                self.agent_positions[agent.position] = agent.handle
+                if agent.old_position is not None:
+                    self.agent_positions[agent.old_position] = -1

     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
         """
         Generate State Transitions Signals used in the state machine
         """
```
```diff
@@ -391,7 +380,7 @@ class RailEnv(Environment):
         st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed

         # Target Reached
-        st_signals.target_reached = fast_position_equal(agent.position, agent.target)
+        st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)

         # Movement conflict - Multiple trains trying to move into same cell
         # If speed counter is not in cell exit, the train can enter the cell
```
```diff
@@ -449,11 +438,18 @@ class RailEnv(Environment):
         """
         Reset the rewards dictionary
         """
         self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

-    def get_info_dict(self):  # TODO Important : Update this
+    def get_info_dict(self):
+        """
+        Returns dictionary of infos for all agents
+        dict_keys : action_required -
+                    malfunction - Counter value for malfunction > 0 means train is in malfunction
+                    speed - Speed of the train
+                    state - State from the trains's state machine
+        """
         info_dict = {
             'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
-            'malfunction': {i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)},
+            'malfunction': {i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)},
             'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
```
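For orientation, the rewritten `get_info_dict` now reads malfunction info from `malfunction_handler.malfunction_down_counter` rather than the old `malfunction_data` dict. A self-contained sketch of the returned shape for two hypothetical agents (values invented; real `state` entries are `TrainState` enum members):

```python
# Invented example of get_info_dict()'s shape for two agents.
info_dict = {
    'action_required': {0: True, 1: False},
    'malfunction':     {0: 0, 1: 3},      # counter > 0 means the train is broken down
    'speed':           {0: 1.0, 1: 0.5},
    'state':           {0: 'MOVING', 1: 'MALFUNCTION'},  # really TrainState enums
}

# Typical consumption: only query a policy where action_required is set.
acting = [h for h, req in info_dict['action_required'].items() if req]
assert acting == [0]
```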
```diff
@@ -461,9 +457,16 @@ class RailEnv(Environment):
         return info_dict

+    def update_step_rewards(self, i_agent):
+        """
+        Update the rewards dict for agent id i_agent for every timestep
+        """
+        pass

     def end_of_episode_update(self, have_all_agents_ended):
         """
         Updates made when episode ends
         Parameters: have_all_agents_ended - Indicates if all agents have reached done state
         """
         if have_all_agents_ended or \
            ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
```
```diff
@@ -477,6 +480,7 @@ class RailEnv(Environment):
             self.dones["__all__"] = True

     def handle_done_state(self, agent):
         """
         Any updates to agent to be made in Done state
         """
         if agent.state == TrainState.DONE:
             agent.arrival_time = self._elapsed_steps
             if self.remove_agents_at_target:
```
```diff
@@ -528,7 +532,7 @@ class RailEnv(Environment):
             elif agent.action_saver.is_action_saved and position_update_allowed:
                 saved_action = agent.action_saver.saved_action
                 # Apply action independent of other agents and get temporary new position and direction
-                new_position, new_direction = self.apply_action_independent(saved_action,
+                new_position, new_direction = env_utils.apply_action_independent(saved_action,
                                                                              self.rail,
                                                                              agent.position,
                                                                              agent.direction)
```
```diff
@@ -536,7 +540,7 @@ class RailEnv(Environment):
             else:
                 new_position, new_direction = agent.position, agent.direction

-            temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
+            temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
                                                                 direction=new_direction,
                                                                 preprocessed_action=preprocessed_action)
```
```diff
@@ -571,7 +575,7 @@ class RailEnv(Environment):
             agent.state_machine.step()

             # Off map or on map state and position should match
-            state_position_sync_check(agent.state, agent.position, agent.handle)
+            env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)

             # Handle done state actions, optionally remove agents
             self.handle_done_state(agent)
```
```diff
@@ -593,11 +597,14 @@ class RailEnv(Environment):
         # Check if episode has ended and update rewards and dones
         self.end_of_episode_update(have_all_agents_ended)

+        self._update_agent_positions_map()

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

     def record_timestep(self, dActions):
-        ''' Record the positions and orientations of all agents in memory, in the cur_episode
-        '''
+        """
+        Record the positions and orientations of all agents in memory, in the cur_episode
+        """
         list_agents_state = []
         for i_agent in range(self.get_num_agents()):
             agent = self.agents[i_agent]
```
```diff
@@ -610,7 +617,7 @@ class RailEnv(Environment):
             # print("pos:", pos, type(pos[0]))
             list_agents_state.append([
                     *pos, int(agent.direction),
-                    agent.malfunction_data["malfunction"],
+                    agent.malfunction_handler.malfunction_down_counter,
                     int(agent.status),
                     int(agent.position in self.motionCheck.svDeadlocked)
                     ])
```
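The same `malfunction_data` to `malfunction_handler` migration is applied to episode recording above; one recorded row now looks like this (invented values):

```python
# One per-agent row appended to list_agents_state (values invented):
# [row, col, direction, malfunction_down_counter, status, in_deadlock]
pos = (4, 7)
row = [*pos, 2, 0, 1, 0]
assert row == [4, 7, 2, 0, 1, 0]
```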
```diff
@@ -620,11 +627,7 @@ class RailEnv(Environment):
     def _get_observations(self):
-        """
-        Utility which returns the observations for an agent with respect to environment
-        Returns
-        ------
-        Dict object
-        """
+        """
+        Utility which returns the dictionary of observations for an agent with respect to environment
+        """
         # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
```
```diff
@@ -633,15 +636,6 @@ class RailEnv(Environment):
     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
         """ Returns directions in which the agent can move
-        Parameters:
-        ---------
-        row : int
-        col : int
-
-        Returns:
-        -------
-        List[int]
         """
         return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
```
```diff
@@ -669,9 +663,10 @@ class RailEnv(Environment):
         """
         return agent.malfunction_handler.in_malfunction

     def save(self, filename):
-        print("deprecated call to env.save() - pls call RailEnvPersister.save()")
+        print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
         persistence.RailEnvPersister.save(self, filename)

     def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
```
```diff
@@ -747,31 +742,4 @@ class RailEnv(Environment):
                 self.renderer.close_window()
             except Exception as e:
                 print("Could Not close window due to:", e)
-            self.renderer = None
-
-
-@dataclass(repr=True)
-class AgentTransitionData:
-    """ Class for keeping track of temporary agent data for position update """
-    position : Tuple[int, int]
-    direction : Grid4Transitions
-    preprocessed_action : RailEnvActions
-
-
-# Adrian Egli performance fix (the fast methods brings more than 50%)
-def fast_isclose(a, b, rtol):
-    return (a < (b + rtol)) or (a < (b - rtol))
-
-
-def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
-    if pos_1 is None:
-        # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
-        return False
-    else:
-        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
-
-
-def state_position_sync_check(state, position, i_agent):
-    if state.is_on_map_state() and position is None:
-        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
-                         i_agent, str(state), str(position)))
-    elif state.is_off_map_state() and position is not None:
-        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
-                         i_agent, str(state), str(position)))
+            self.renderer = None
\ No newline at end of file
```
flatland/envs/step_utils/env_utils.py  0 → 100644  +52 −0
```python
from dataclasses import dataclass
from typing import Tuple

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import transition_utils
from flatland.envs.rail_env_action import RailEnvActions
from flatland.core.grid.grid4 import Grid4Transitions


@dataclass(repr=True)
class AgentTransitionData:
    """ Class for keeping track of temporary agent data for position update """
    position : Tuple[int, int]
    direction : Grid4Transitions
    preprocessed_action : RailEnvActions


# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
    return (a < (b + rtol)) or (a < (b - rtol))


def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
    if pos_1 is None:
        return False
    else:
        return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def apply_action_independent(action, rail, position, direction):
    """ Apply the action on the train regardless of locations of other trains
        Checks for valid cells to move and valid rail transitions
        ---------------------------------------------------------------------
        Parameters: action - Action to execute
                    rail - Flatland env.rail object
                    position - current position of the train
                    direction - current direction of the train
        ---------------------------------------------------------------------
        Returns: new_position - New position after applying the action
                 new_direction - New direction after applying the action
    """
    if action.is_moving_action():
        new_direction, _ = transition_utils.check_action(action, position, direction, rail)
        new_position = get_new_position(position, new_direction)
    else:
        new_position, new_direction = position, direction
    return new_position, new_direction


def state_position_sync_check(state, position, i_agent):
    """ Check for whether on map and off map states are matching with position """
    if state.is_on_map_state() and position is None:
        raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
                         i_agent, str(state), str(position)))
    elif state.is_off_map_state() and position is not None:
        raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
                         i_agent, str(state), str(position)))
```
\ No newline at end of file
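To make the new module concrete, a hedged usage sketch follows (assumes this commit's flatland package is importable; values invented). One observation: `fast_isclose` as committed reduces to `a < b + rtol` for non-negative `rtol`, since `a < b - rtol` already implies `a < b + rtol`, so it is an asymmetric threshold rather than an `abs(a - b) <= rtol` style closeness test.

```python
# Hedged usage sketch for the new env_utils helpers (invented values;
# assumes this commit's flatland package is on the path).
from flatland.envs.step_utils import env_utils

# fast_position_equal: None means "off map" and never equals a real cell.
assert env_utils.fast_position_equal((3, 4), (3, 4))
assert not env_utils.fast_position_equal(None, (3, 4))

# fast_isclose is asymmetric: it effectively tests a < b + rtol.
assert env_utils.fast_isclose(1.0, 1.05, rtol=0.1)        # a below b + rtol
assert not env_utils.fast_isclose(1.05, 1.0, rtol=0.01)   # reversed args differ

# AgentTransitionData is a plain record; dataclass does not enforce the
# type annotations, so toy stand-in values work for illustration.
atd = env_utils.AgentTransitionData(position=(1, 2), direction=0,
                                    preprocessed_action=None)
assert atd.position == (1, 2)
```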