Commit 9eec6c35, authored 5 years ago by Erik Nygren
Merge branch 'redis-opts' into 'master'

Redis opts

See merge request flatland/flatland!239

Parents: a796ca0d d22472fb
Showing 3 changed files with 171 additions and 118 deletions:

flatland/envs/rail_env.py          +42 −30
flatland/evaluators/client.py      +76 −55
flatland/evaluators/service.py     +53 −33
flatland/envs/rail_env.py (+42 −30)

@@ -321,7 +321,7 @@ class RailEnv(Environment):
         # todo change self.agents_static[0] with the refactoring for agents_static -> issue nr. 185
         # https://gitlab.aicrowd.com/flatland/flatland/issues/185
-        if regenerate_schedule or self.agents_static[0] is None:
+        if regenerate_schedule or regenerate_rail or self.agents_static[0] is None:
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']

@@ -436,48 +436,60 @@ class RailEnv(Environment):
         self._elapsed_steps += 1

-        # Reset the step rewards
-        self.rewards_dict = dict()
-        for i_agent in range(self.get_num_agents()):
-            self.rewards_dict[i_agent] = 0
-
         # If we're done, set reward and info_dict and step() is done.
         if self.dones["__all__"]:
-            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
+            self.rewards_dict = {}
             info_dict = {
-                'action_required': {i: False for i in range(self.get_num_agents())},
-                'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())},
-                'status': {i: agent.status for i, agent in enumerate(self.agents)}
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
             }
+            for i_agent, agent in enumerate(self.agents):
+                self.rewards_dict[i_agent] = self.global_reward
+                info_dict["action_required"][i_agent] = False
+                info_dict["malfunction"][i_agent] = 0
+                info_dict["speed"][i_agent] = 0
+                info_dict["status"][i_agent] = agent.status
+
             return self._get_observations(), self.rewards_dict, self.dones, info_dict

-        # Perform step on all agents
-        for i_agent in range(self.get_num_agents()):
-            self._step_agent(i_agent, action_dict_.get(i_agent))
+        # Reset the step rewards
+        self.rewards_dict = dict()
+        info_dict = {
+            "action_required": {},
+            "malfunction": {},
+            "speed": {},
+            "status": {},
+        }
+        have_all_agents_ended = True  # boolean flag to check if all agents are done
+        for i_agent, agent in enumerate(self.agents):
+            # Reset the step rewards
+            self.rewards_dict[i_agent] = 0
+
+            # Perform step on the agent
+            self._step_agent(i_agent, action_dict_.get(i_agent))
+
+            # manage the boolean flag to check if all agents are indeed done (or done_removed)
+            have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+
+            # Build info dict
+            info_dict["action_required"][i_agent] = \
+                (agent.status == RailAgentStatus.READY_TO_DEPART or (
+                    agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03)))
+            info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
+            info_dict["speed"][i_agent] = agent.speed_data['speed']
+            info_dict["status"][i_agent] = agent.status

         # Check for end of episode + set global reward to all rewards!
-        if np.all([agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED] for agent in self.agents]):
+        if have_all_agents_ended:
             self.dones["__all__"] = True
             self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}

         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for i in range(self.get_num_agents()):
-                self.agents[i].status = RailAgentStatus.DONE
-                self.dones[i] = True
-
-        info_dict = {
-            'action_required': {i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03)))
-                                for i, agent in enumerate(self.agents)},
-            'malfunction': {i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())},
-            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
-            'status': {i: agent.status for i, agent in enumerate(self.agents)}
-        }
+            for i_agent in range(self.get_num_agents()):
+                self.dones[i_agent] = True

         return self._get_observations(), self.rewards_dict, self.dones, info_dict
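The step() rewrite above folds what used to be three passes over the agents (resetting rewards, stepping, and an np.all() scan for termination) into a single loop that accumulates a boolean flag. A minimal sketch of that accumulation idiom in plain Python (illustrative stand-in statuses, not Flatland code):

# Start optimistic; AND in each agent's "done" state during the same
# loop that already performs the per-agent work.
agents = [{"status": "DONE"}, {"status": "ACTIVE"}, {"status": "DONE_REMOVED"}]

have_all_agents_ended = True
for agent in agents:
    # ... per-agent stepping and info bookkeeping would happen here ...
    have_all_agents_ended &= agent["status"] in ("DONE", "DONE_REMOVED")

print(have_all_agents_ended)  # False: one agent is still ACTIVE

This gives the same answer as the old np.all([...]) check without building an intermediate list on every step.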
flatland/evaluators/client.py (+76 −55)

@@ -23,14 +23,6 @@ logger.setLevel(logging.INFO)
 m.patch()

-def are_dicts_equal(d1, d2):
-    """ return True if all keys and values are the same """
-    return all(k in d2 and np.isclose(d1[k], d2[k]) for k in d1) \
-        and all(k in d1 and np.isclose(d1[k], d2[k]) for k in d2)
-
 class FlatlandRemoteClient(object):
     """
         Redis client to interface with flatland-rl remote-evaluation-service

@@ -64,6 +56,8 @@ class FlatlandRemoteClient(object):
             port=remote_port,
             db=remote_db,
             password=remote_password)
+        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)
+
         self.namespace = "flatland-rl"
         self.service_id = os.getenv(
             'FLATLAND_RL_SERVICE_ID',

@@ -87,8 +81,26 @@ class FlatlandRemoteClient(object):
         self.env = None
         self.ping_pong()

+        self.env_step_times = []
+        self.stats = {}
+
+    def update_running_mean_stats(self, key, scalar):
+        """
+        Computes the running mean for certain params
+        """
+        mean_key = "{}_mean".format(key)
+        counter_key = "{}_counter".format(key)
+        try:
+            self.stats[mean_key] = \
+                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
+            self.stats[counter_key] += 1
+        except KeyError:
+            self.stats[mean_key] = 0
+            self.stats[counter_key] = 0
+
     def get_redis_connection(self):
-        return redis.Redis(connection_pool=self.redis_pool)
+        return self.redis_conn

     def _generate_response_channel(self):
         random_hash = hashlib.md5(

@@ -100,7 +112,7 @@ class FlatlandRemoteClient(object):
             random_hash)
         return response_channel

-    def _blocking_request(self, _request):
+    def _remote_request(self, _request, blocking=True):
         """
             request:
                 -command_type

@@ -114,6 +126,7 @@ class FlatlandRemoteClient(object):
         """
         assert isinstance(_request, dict)
         _request['response_channel'] = self._generate_response_channel()
+        _request['timestamp'] = time.time()
         _redis = self.get_redis_connection()
         """

@@ -126,18 +139,20 @@ class FlatlandRemoteClient(object):
         # Note: The patched msgpack supports numpy arrays
         payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
         _redis.lpush(self.command_channel, payload)
-        # Wait with a blocking pop for the response
-        _response = _redis.blpop(_request['response_channel'])[1]
-        if self.verbose:
-            print("Response : ", _response)
-        _response = msgpack.unpackb(
-            _response,
-            object_hook=m.decode,
-            encoding="utf8")
-        if _response['type'] == messages.FLATLAND_RL.ERROR:
-            raise Exception(str(_response["payload"]))
-        else:
-            return _response
+        if blocking:
+            # Wait with a blocking pop for the response
+            _response = _redis.blpop(_request['response_channel'])[1]
+            if self.verbose:
+                print("Response : ", _response)
+            _response = msgpack.unpackb(
+                _response,
+                object_hook=m.decode,
+                encoding="utf8")
+            if _response['type'] == messages.FLATLAND_RL.ERROR:
+                raise Exception(str(_response["payload"]))
+            else:
+                return _response

     def ping_pong(self):
         """

@@ -151,7 +166,7 @@ class FlatlandRemoteClient(object):
         _request['payload'] = {
             "version": flatland.__version__
         }
-        _response = self._blocking_request(_request)
+        _response = self._remote_request(_request)
         if _response['type'] != messages.FLATLAND_RL.PONG:
             raise Exception(
                 "Unable to perform handshake with the evaluation service. \

@@ -166,13 +181,17 @@ class FlatlandRemoteClient(object):
             The observation builder is only used in the local env
             and the remote env uses a DummyObservationBuilder
         """
+        time_start = time.time()
         _request = {}
         _request['type'] = messages.FLATLAND_RL.ENV_CREATE
         _request['payload'] = {}
-        _response = self._blocking_request(_request)
+        _response = self._remote_request(_request)
         observation = _response['payload']['observation']
         info = _response['payload']['info']
         random_seed = _response['payload']['random_seed']
+        test_env_file_path = _response['payload']['env_file_path']
+        time_diff = time.time() - time_start
+        self.update_running_mean_stats("env_creation_wait_time", time_diff)

         if not observation:
             # If the observation is False,

@@ -180,7 +199,6 @@ class FlatlandRemoteClient(object):
             # hence return false
             return observation, info

-        test_env_file_path = _response['payload']['env_file_path']
         if self.verbose:
             print("Received Env : ", test_env_file_path)

@@ -207,13 +225,15 @@ class FlatlandRemoteClient(object):
             obs_builder_object=obs_builder_object
         )
+        time_start = time.time()
         local_observation, info = self.env.reset(
-            regenerate_rail=False,
-            regenerate_schedule=False,
+            regenerate_rail=True,
+            regenerate_schedule=True,
             activate_agents=False,
             random_seed=random_seed
         )
+        time_diff = time.time() - time_start
+        self.update_running_mean_stats("internal_env_reset_time", time_diff)

         # Use the local observation
         # as the remote server uses a dummy observation builder
         return local_observation, info

@@ -226,39 +246,38 @@ class FlatlandRemoteClient(object):
         _request['type'] = messages.FLATLAND_RL.ENV_STEP
         _request['payload'] = {}
         _request['payload']['action'] = action
-        _response = self._blocking_request(_request)
-        _payload = _response['payload']

-        # remote_observation = _payload['observation']  # noqa
-        remote_reward = _payload['reward']
-        remote_done = _payload['done']
-        remote_info = _payload['info']
+        # Relay the action in a non-blocking way to the server
+        # so that it can start doing an env.step on it in ~ parallel
+        self._remote_request(_request, blocking=False)

-        # Replicate the action in the local env
+        # Apply the action in the local env
+        time_start = time.time()
         local_observation, local_reward, local_done, local_info = \
             self.env.step(action)
+        time_diff = time.time() - time_start
+        # Compute a running mean of env step times
+        self.update_running_mean_stats("internal_env_step_time", time_diff)

-        if self.verbose:
-            print(local_reward)
-        if not are_dicts_equal(remote_reward, local_reward):
-            print("Remote Reward : ", remote_reward, "Local Reward : ", local_reward)
-            raise Exception("local and remote `reward` are diverging")
-        if not are_dicts_equal(remote_done, local_done):
-            print("Remote Done : ", remote_done, "Local Done : ", local_done)
-            raise Exception("local and remote `done` are diverging")
-
-        # Return local_observation instead of remote_observation
-        # as the remote_observation is build using a dummy observation
-        # builder
-        # We return the remote rewards and done as they are the
-        # once used by the evaluator
-        return [local_observation, remote_reward, remote_done, remote_info]
+        return [local_observation, local_reward, local_done, local_info]

     def submit(self):
         _request = {}
         _request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
         _request['payload'] = {}
-        _response = self._blocking_request(_request)
+        _response = self._remote_request(_request)
+
+        ######################################################################
+        # Print Local Stats
+        ######################################################################
+        print("=" * 100)
+        print("=" * 100)
+        print("## Client Performance Stats")
+        print("=" * 100)
+        for _key in self.stats:
+            if _key.endswith("_mean"):
+                print("\t - {}\t:{}".format(_key, self.stats[_key]))
+        print("=" * 100)
         if os.getenv("AICROWD_BLOCKING_SUBMIT"):
             """
             If the submission is supposed to happen as a blocking submit,

@@ -279,13 +298,12 @@ if __name__ == "__main__":
             _action[_idx] = np.random.randint(0, 5)
         return _action

-    my_observation_builder = TreeObsForRailEnv(max_depth=3,
-                                               predictor=ShortestPathPredictorForRailEnv())
+    my_observation_builder = DummyObservationBuilder()

     episode = 0
     obs = True
     while obs:
-        obs = remote_client.env_create(
+        obs, info = remote_client.env_create(
             obs_builder_object=my_observation_builder
         )
         if not obs:

@@ -301,7 +319,10 @@ if __name__ == "__main__":
     while True:
         action = my_controller(obs, remote_client.env)
+        time_start = time.time()
         observation, all_rewards, done, info = remote_client.env_step(action)
+        time_diff = time.time() - time_start
+        print("Step Time : ", time_diff)
         if done['__all__']:
             print("Current Episode : ", episode)
             print("Episode Done")
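The rename of _blocking_request to _remote_request(..., blocking=True) builds on plain Redis list operations: lpush enqueues the command, and blpop is only issued when a reply is actually needed. A minimal sketch of the pattern, assuming a redis-server reachable on localhost:6379 (the channel names here are illustrative, not Flatland's):

import redis

conn = redis.Redis(host="localhost", port=6379)

def remote_request(payload, blocking=True):
    # Enqueue the request for the service to pick up.
    conn.lpush("command_channel", payload)
    if blocking:
        # blpop blocks until the service pushes a reply onto this list;
        # it returns a (list_name, value) pair and we want the value.
        _, response = conn.blpop("response_channel")
        return response
    # Fire-and-forget: the caller can overlap local work with the remote step.
    return None

env_step takes the non-blocking path, so the remote env.step runs roughly in parallel with the step of the local replica; that is also why the local/remote divergence checks (and are_dicts_equal) could be dropped.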
flatland/evaluators/service.py (+53 −33)

@@ -18,6 +18,7 @@ import timeout_decorator
 from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
 from flatland.evaluators import aicrowd_helpers

@@ -123,6 +124,7 @@ class FlatlandRemoteEvaluationService:
                 "normalized_reward": 0.0
             }
         }
+        self.stats = {}

         # RailEnv specific variables
         self.env = False

@@ -134,6 +136,7 @@ class FlatlandRemoteEvaluationService:
         self.simulation_percentage_complete = []
         self.simulation_steps = []
         self.simulation_times = []
+        self.env_step_times = []
         self.begin_simulation = False
         self.current_step = 0
         self.visualize = visualize

@@ -148,6 +151,21 @@ class FlatlandRemoteEvaluationService:
                 shutil.rmtree(self.vizualization_folder_name)
             os.mkdir(self.vizualization_folder_name)

+    def update_running_mean_stats(self, key, scalar):
+        """
+        Computes the running mean for certain params
+        """
+        mean_key = "{}_mean".format(key)
+        counter_key = "{}_counter".format(key)
+        try:
+            self.stats[mean_key] = \
+                ((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
+            self.stats[counter_key] += 1
+        except KeyError:
+            self.stats[mean_key] = 0
+            self.stats[counter_key] = 0
+
     def get_env_filepaths(self):
         """
         Gathers a list of all available rail env files to be used

@@ -198,25 +216,14 @@ class FlatlandRemoteEvaluationService:
             db=self.remote_db,
             password=self.remote_password
         )
+        self.redis_conn = redis.Redis(connection_pool=self.redis_pool)

     def get_redis_connection(self):
         """
         Obtains a new redis connection from a previously instantiated
         redis connection pool
         """
-        redis_conn = redis.Redis(connection_pool=self.redis_pool)
-        try:
-            redis_conn.ping()
-        except Exception:
-            raise Exception(
-                "Unable to connect to redis server at {}:{} ."
-                "Are you sure there is a redis-server running at the "
-                "specified location ?".format(
-                    self.remote_host,
-                    self.remote_port
-                )
-            )
-        return redis_conn
+        return self.redis_conn

     def _error_template(self, payload):
         """

@@ -266,7 +273,9 @@ class FlatlandRemoteEvaluationService:
         )
         if self.verbose:
             print("Received Request : ", command)

+        message_queue_latency = time.time() - command["timestamp"]
+        self.update_running_mean_stats("message_queue_latency", message_queue_latency)
         return command

     def send_response(self, _command_response, command, suppress_logs=False):

@@ -319,7 +328,6 @@ class FlatlandRemoteEvaluationService:
         """
         There are still test envs left that are yet to be evaluated
         """
-
         test_env_file_path = self.env_file_paths[self.simulation_count]
         print("Evaluating : {}".format(test_env_file_path))
         test_env_file_path = os.path.join(

@@ -334,10 +342,6 @@ class FlatlandRemoteEvaluationService:
             schedule_generator=schedule_from_file(test_env_file_path),
             obs_builder_object=DummyObservationBuilder()
         )
-        if self.visualize:
-            if self.env_renderer:
-                del self.env_renderer
-            self.env_renderer = RenderTool(self.env, gl="PILSVG", )

         if self.begin_simulation:
             # If begin simulation has already been initialized

@@ -353,12 +357,17 @@ class FlatlandRemoteEvaluationService:
         self.current_step = 0

         _observation, _info = self.env.reset(
-            regenerate_rail=False,
-            regenerate_schedule=False,
+            regenerate_rail=True,
+            regenerate_schedule=True,
             activate_agents=False,
             random_seed=RANDOM_SEED
         )

+        if self.visualize:
+            if self.env_renderer:
+                del self.env_renderer
+            self.env_renderer = RenderTool(self.env, gl="PILSVG", )
+
         _command_response = {}
         _command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
         _command_response['payload'] = {}

@@ -412,9 +421,12 @@ class FlatlandRemoteEvaluationService:
                 has done['__all__']==True")

         action = _payload['action']
+        time_start = time.time()
         _observation, all_rewards, done, info = self.env.step(action)
+        time_diff = time.time() - time_start
+        self.update_running_mean_stats("internal_env_step_time", time_diff)

-        cumulative_reward = np.sum(list(all_rewards.values()))
+        cumulative_reward = sum(all_rewards.values())
         self.simulation_rewards[-1] += cumulative_reward
         self.simulation_steps[-1] += 1
         """

@@ -434,7 +446,7 @@ class FlatlandRemoteEvaluationService:
             complete = 0
             for i_agent in range(self.env.get_num_agents()):
                 agent = self.env.agents[i_agent]
-                if agent.position == agent.target:
+                if agent.status in [RailAgentStatus.DONE_REMOVED]:
                     complete += 1
             percentage_complete = complete * 1.0 / self.env.get_num_agents()
             self.simulation_percentage_complete[-1] = percentage_complete

@@ -459,16 +471,6 @@ class FlatlandRemoteEvaluationService:
                 ))
                 self.record_frame_step += 1

-        # Build and send response
-        _command_response = {}
-        _command_response['type'] = messages.FLATLAND_RL.ENV_STEP_RESPONSE
-        _command_response['payload'] = {}
-        _command_response['payload']['observation'] = _observation
-        _command_response['payload']['reward'] = all_rewards
-        _command_response['payload']['done'] = done
-        _command_response['payload']['info'] = info
-        self.send_response(_command_response, command)
-
     def handle_env_submit(self, command):
         """
         Handles a ENV_SUBMIT command from the client

@@ -476,6 +478,18 @@ class FlatlandRemoteEvaluationService:
         """
         _payload = command['payload']

+        ######################################################################
+        # Print Local Stats
+        ######################################################################
+        print("=" * 100)
+        print("=" * 100)
+        print("## Server Performance Stats")
+        print("=" * 100)
+        for _key in self.stats:
+            if _key.endswith("_mean"):
+                print("\t - {}\t:{}".format(_key, self.stats[_key]))
+        print("=" * 100)
+
         # Register simulation time of the last episode
         self.simulation_times.append(time.time() - self.begin_simulation)

@@ -594,8 +608,12 @@ class FlatlandRemoteEvaluationService:
         and acts accordingly.
         """
         print("Listening at : ", self.command_channel)
+        MESSAGE_QUEUE_LATENCY = []
         while True:
             command = self.get_next_command()
+            if "timestamp" in command.keys():
+                latency = time.time() - command["timestamp"]
+                MESSAGE_QUEUE_LATENCY.append(latency)

             if self.verbose:
                 print("Self.Reward : ", self.reward)

@@ -633,6 +651,8 @@ class FlatlandRemoteEvaluationService:
                 Submit the final cumulative reward
                 """
+                print("Overall Message Queue Latency : ", np.array(MESSAGE_QUEUE_LATENCY).mean())
+
                 self.handle_env_submit(command)
             else:
                 _error = self._error_template(
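Both the client and the service now collect timings through the same constant-memory running mean, mean_new = (mean_old * n + scalar) / (n + 1). A plain-Python sketch of the recurrence, mirroring update_running_mean_stats from the diffs above (including the quirk that the very first sample only initialises the counters):

stats = {}

def update_running_mean(key, scalar):
    mean_key, counter_key = key + "_mean", key + "_counter"
    try:
        n = stats[counter_key]
        # Incremental mean: no need to store the full history of samples.
        stats[mean_key] = (stats[mean_key] * n + scalar) / (n + 1)
        stats[counter_key] = n + 1
    except KeyError:
        # First call only creates the slots, so the first sample is dropped.
        stats[mean_key] = 0
        stats[counter_key] = 0

The message-queue latency stat relies on the 'timestamp' field that _remote_request now attaches to every command before it is pushed to Redis.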