Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
4ccccc1e
There was an error fetching the commit references. Please try again later.
Commit
4ccccc1e
authored
5 years ago
by
spmohanty
Browse files
Options
Downloads
Patches
Plain Diff
Addresses #117 - Add ability to pass in custom observation builder
parent
93c99c05
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
flatland/core/env_observation_builder.py
+21
-0
21 additions, 0 deletions
flatland/core/env_observation_builder.py
flatland/evaluators/client.py
+53
-18
53 additions, 18 deletions
flatland/evaluators/client.py
flatland/evaluators/service.py
+2
-12
2 additions, 12 deletions
flatland/evaluators/service.py
with
76 additions
and
30 deletions
flatland/core/env_observation_builder.py
+
21
−
0
View file @
4ccccc1e
...
@@ -73,3 +73,24 @@ class ObservationBuilder:
...
@@ -73,3 +73,24 @@ class ObservationBuilder:
direction
=
np
.
zeros
(
4
)
direction
=
np
.
zeros
(
4
)
direction
[
agent
.
direction
]
=
1
direction
[
agent
.
direction
]
=
1
return
direction
return
direction
class
DummyObservationBuilder
(
ObservationBuilder
):
"""
DummyObservationBuilder class which returns dummy observations
This is used in the evaluation service
"""
def
__init__
(
self
):
self
.
observation_space
=
()
def
_set_env
(
self
,
env
):
self
.
env
=
env
def
reset
(
self
):
pass
def
get_many
(
self
,
handles
=
[]):
return
True
def
get
(
self
,
handle
=
0
):
return
True
This diff is collapsed.
Click to expand it.
flatland/evaluators/client.py
+
53
−
18
View file @
4ccccc1e
...
@@ -18,6 +18,14 @@ logger.setLevel(logging.INFO)
...
@@ -18,6 +18,14 @@ logger.setLevel(logging.INFO)
m
.
patch
()
m
.
patch
()
def
are_dicts_equal
(
d1
,
d2
):
"""
return True if all keys and values are the same
"""
return
all
(
k
in
d2
and
d1
[
k
]
==
d2
[
k
]
for
k
in
d1
)
\
and
all
(
k
in
d1
and
d1
[
k
]
==
d2
[
k
]
for
k
in
d2
)
class
FlatlandRemoteClient
(
object
):
class
FlatlandRemoteClient
(
object
):
"""
"""
Redis client to interface with flatland-rl remote-evaluation-service
Redis client to interface with flatland-rl remote-evaluation-service
...
@@ -133,10 +141,16 @@ class FlatlandRemoteClient(object):
...
@@ -133,10 +141,16 @@ class FlatlandRemoteClient(object):
else
:
else
:
return
True
return
True
def
env_create
(
self
,
params
=
{}):
def
env_create
(
self
,
obs_builder_object
):
"""
Create a local env and remote env on which the
local agent can operate.
The observation builder is only used in the local env
and the remote env uses a DummyObservationBuilder
"""
_request
=
{}
_request
=
{}
_request
[
'
type
'
]
=
messages
.
FLATLAND_RL
.
ENV_CREATE
_request
[
'
type
'
]
=
messages
.
FLATLAND_RL
.
ENV_CREATE
_request
[
'
payload
'
]
=
params
_request
[
'
payload
'
]
=
{}
_response
=
self
.
_blocking_request
(
_request
)
_response
=
self
.
_blocking_request
(
_request
)
observation
=
_response
[
'
payload
'
][
'
observation
'
]
observation
=
_response
[
'
payload
'
][
'
observation
'
]
...
@@ -151,17 +165,15 @@ class FlatlandRemoteClient(object):
...
@@ -151,17 +165,15 @@ class FlatlandRemoteClient(object):
width
=
1
,
width
=
1
,
height
=
1
,
height
=
1
,
rail_generator
=
rail_from_file
(
test_env_file_path
),
rail_generator
=
rail_from_file
(
test_env_file_path
),
obs_builder_object
=
TreeObsForRailEnv
(
obs_builder_object
=
obs_builder_object
max_depth
=
3
,
predictor
=
ShortestPathPredictorForRailEnv
()
)
)
)
self
.
env
.
_max_episode_steps
=
\
self
.
env
.
_max_episode_steps
=
\
int
(
1.5
*
(
self
.
env
.
width
+
self
.
env
.
height
))
int
(
1.5
*
(
self
.
env
.
width
+
self
.
env
.
height
))
self
.
env
.
reset
()
local_observation
=
self
.
env
.
reset
()
# Use the observation from the remote service instead
# Use the local observation
return
observation
# as the remote server uses a dummy observation builder
return
local_observation
def
env_step
(
self
,
action
,
render
=
False
):
def
env_step
(
self
,
action
,
render
=
False
):
"""
"""
...
@@ -173,11 +185,25 @@ class FlatlandRemoteClient(object):
...
@@ -173,11 +185,25 @@ class FlatlandRemoteClient(object):
_request
[
'
payload
'
][
'
action
'
]
=
action
_request
[
'
payload
'
][
'
action
'
]
=
action
_response
=
self
.
_blocking_request
(
_request
)
_response
=
self
.
_blocking_request
(
_request
)
_payload
=
_response
[
'
payload
'
]
_payload
=
_response
[
'
payload
'
]
observation
=
_payload
[
'
observation
'
]
# remote_observation = _payload['observation']
reward
=
_payload
[
'
reward
'
]
reward
=
_payload
[
'
reward
'
]
done
=
_payload
[
'
done
'
]
done
=
_payload
[
'
done
'
]
info
=
_payload
[
'
info
'
]
info
=
_payload
[
'
info
'
]
return
[
observation
,
reward
,
done
,
info
]
# Replicate the action in the local env
local_observation
,
local_rewards
,
local_done
,
local_info
=
\
self
.
env
.
step
(
action
)
assert
are_dicts_equal
(
reward
,
local_rewards
)
assert
are_dicts_equal
(
done
,
local_done
)
# Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation
# builder
# We return the remote rewards and done as they are the
# once used by the evaluator
return
[
local_observation
,
reward
,
done
,
info
]
def
submit
(
self
):
def
submit
(
self
):
_request
=
{}
_request
=
{}
...
@@ -196,28 +222,37 @@ class FlatlandRemoteClient(object):
...
@@ -196,28 +222,37 @@ class FlatlandRemoteClient(object):
if
__name__
==
"
__main__
"
:
if
__name__
==
"
__main__
"
:
env
_client
=
FlatlandRemoteClient
()
remote
_client
=
FlatlandRemoteClient
()
def
my_controller
(
obs
,
_env
):
def
my_controller
(
obs
,
_env
):
_action
=
{}
_action
=
{}
for
_idx
,
_
in
enumerate
(
_env
.
agents
):
for
_idx
,
_
in
enumerate
(
_env
.
agents
):
_action
[
_idx
]
=
np
.
random
.
randint
(
0
,
5
)
_action
[
_idx
]
=
np
.
random
.
randint
(
0
,
5
)
return
_action
return
_action
my_observation_builder
=
TreeObsForRailEnv
(
max_depth
=
3
,
predictor
=
ShortestPathPredictorForRailEnv
())
episode
=
0
episode
=
0
obs
=
True
obs
=
True
while
obs
:
while
obs
:
obs
=
env_client
.
env_create
()
obs
=
remote_client
.
env_create
(
obs_builder_object
=
my_observation_builder
)
if
not
obs
:
if
not
obs
:
"""
The remote env returns False as the first obs
when it is done evaluating all the individual episodes
"""
break
break
print
(
"
Episode : {}
"
.
format
(
episode
))
print
(
"
Episode : {}
"
.
format
(
episode
))
episode
+=
1
episode
+=
1
print
(
env
_client
.
env
.
dones
[
'
__all__
'
])
print
(
remote
_client
.
env
.
dones
[
'
__all__
'
])
while
True
:
while
True
:
action
=
my_controller
(
obs
,
env
_client
.
env
)
action
=
my_controller
(
obs
,
remote
_client
.
env
)
observation
,
all_rewards
,
done
,
info
=
env
_client
.
env_step
(
action
)
observation
,
all_rewards
,
done
,
info
=
remote
_client
.
env_step
(
action
)
if
done
[
'
__all__
'
]:
if
done
[
'
__all__
'
]:
print
(
"
Current Episode :
"
,
episode
)
print
(
"
Current Episode :
"
,
episode
)
print
(
"
Episode Done
"
)
print
(
"
Episode Done
"
)
...
@@ -225,6 +260,6 @@ if __name__ == "__main__":
...
@@ -225,6 +260,6 @@ if __name__ == "__main__":
break
break
print
(
"
Evaluation Complete...
"
)
print
(
"
Evaluation Complete...
"
)
print
(
env
_client
.
submit
())
print
(
remote
_client
.
submit
())
This diff is collapsed.
Click to expand it.
flatland/evaluators/service.py
+
2
−
12
View file @
4ccccc1e
...
@@ -3,8 +3,7 @@ from __future__ import print_function
...
@@ -3,8 +3,7 @@ from __future__ import print_function
import
redis
import
redis
from
flatland.envs.generators
import
rail_from_file
from
flatland.envs.generators
import
rail_from_file
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.evaluators
import
messages
from
flatland.evaluators
import
messages
import
numpy
as
np
import
numpy
as
np
import
msgpack
import
msgpack
...
@@ -235,7 +234,6 @@ class FlatlandRemoteEvaluationService:
...
@@ -235,7 +234,6 @@ class FlatlandRemoteEvaluationService:
Add a high level summary of everything thats
Add a high level summary of everything thats
hapenning here.
hapenning here.
"""
"""
env_params
=
command
[
"
payload
"
]
# noqa F841
if
self
.
simulation_count
<
len
(
self
.
env_file_paths
):
if
self
.
simulation_count
<
len
(
self
.
env_file_paths
):
"""
"""
...
@@ -244,19 +242,11 @@ class FlatlandRemoteEvaluationService:
...
@@ -244,19 +242,11 @@ class FlatlandRemoteEvaluationService:
test_env_file_path
=
self
.
env_file_paths
[
self
.
simulation_count
]
test_env_file_path
=
self
.
env_file_paths
[
self
.
simulation_count
]
del
self
.
env
del
self
.
env
# TODO : Use env_params dictionary to instantiate
# the RailEnv
# Maybe use a gin-like interface ?
# Needs discussion with Erik + Giacomo
# -Mohanty
self
.
env
=
RailEnv
(
self
.
env
=
RailEnv
(
width
=
1
,
width
=
1
,
height
=
1
,
height
=
1
,
rail_generator
=
rail_from_file
(
test_env_file_path
),
rail_generator
=
rail_from_file
(
test_env_file_path
),
obs_builder_object
=
TreeObsForRailEnv
(
obs_builder_object
=
DummyObservationBuilder
()
max_depth
=
3
,
predictor
=
ShortestPathPredictorForRailEnv
()
)
)
)
# Set max episode steps allowed
# Set max episode steps allowed
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment