Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
2be1a6c4
Commit
2be1a6c4
authored
Jun 13, 2019
by
u214892
Browse files
fix master
parent
d77ecfd1
Changes
6
Hide whitespace changes
Inline
Side-by-side
flatland/core/env.py
View file @
2be1a6c4
...
...
@@ -84,21 +84,6 @@ class Environment:
"""
raise
NotImplementedError
()
def
predict
(
self
):
"""
Predictions step.
Returns predictions for the agents.
The returns are dicts mapping from agent_id strings to values.
Returns
-------
predictions : dict
New predictions for each ready agent.
"""
raise
NotImplementedError
()
def
get_agent_handles
(
self
):
"""
Returns a list of agents' handles to be used as keys in the step()
...
...
flatland/envs/observations.py
View file @
2be1a6c4
...
...
@@ -173,10 +173,6 @@ class TreeObsForRailEnv(ObservationBuilder):
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
if
self
.
predictor
:
print
(
self
.
predictor
.
get
(
0
))
observations
=
{}
for
h
in
handles
:
observations
[
h
]
=
self
.
get
(
h
)
...
...
flatland/envs/rail_env.py
View file @
2be1a6c4
...
...
@@ -292,7 +292,6 @@ class RailEnv(Environment):
np
.
equal
(
new_position
,
[
agent2
.
position
for
agent2
in
self
.
agents
]).
all
(
1
))
return
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
def
check_action
(
self
,
agent
,
action
):
transition_isValid
=
None
possible_transitions
=
self
.
rail
.
get_transitions
((
*
agent
.
position
,
agent
.
direction
))
...
...
@@ -324,7 +323,6 @@ class RailEnv(Environment):
self
.
obs_dict
=
self
.
obs_builder
.
get_many
(
list
(
range
(
self
.
get_num_agents
())))
return
self
.
obs_dict
def
get_full_state_msg
(
self
):
grid_data
=
self
.
rail
.
grid
.
tolist
()
agent_static_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents_static
]
...
...
flatland/utils/editor.py
View file @
2be1a6c4
...
...
@@ -323,7 +323,8 @@ class Controller(object):
def
restartAgents
(
self
,
event
):
self
.
log
(
"Restart Agents - nAgents:"
,
self
.
view
.
wRegenNAgents
.
value
)
if
self
.
model
.
init_agents_static
is
not
None
:
self
.
model
.
env
.
agents_static
=
[
EnvAgentStatic
(
d
[
0
],
d
[
1
],
d
[
2
],
moving
=
False
)
for
d
in
self
.
model
.
init_agents_static
]
self
.
model
.
env
.
agents_static
=
[
EnvAgentStatic
(
d
[
0
],
d
[
1
],
d
[
2
],
moving
=
False
)
for
d
in
self
.
model
.
init_agents_static
]
self
.
model
.
env
.
agents
=
None
self
.
model
.
init_agents_static
=
None
self
.
player
=
None
...
...
flatland/utils/graphics_pil.py
View file @
2be1a6c4
...
...
@@ -396,7 +396,7 @@ class PILSVG(PILGL):
}
# "paint" color of the train images we load - this is the color we will change.
# a3BaseColor = self.rgb_s2i("0091ea")
# a3BaseColor = self.rgb_s2i("0091ea")
\# noqa: E800
# temporary workaround for trains / agents renamed with different colour:
a3BaseColor
=
self
.
rgb_s2i
(
"d50000"
)
...
...
tests/test_env_prediction_builder.py
View file @
2be1a6c4
...
...
@@ -5,7 +5,7 @@ import numpy as np
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
Global
ObsForRailEnv
from
flatland.envs.observations
import
Tree
ObsForRailEnv
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
...
...
@@ -64,8 +64,7 @@ def test_predictions():
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
(),
prediction_builder_object
=
DummyPredictorForRailEnv
(
max_depth
=
20
)
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
20
,
predictor
=
DummyPredictorForRailEnv
(
max_depth
=
20
)),
)
env
.
reset
()
...
...
@@ -74,7 +73,7 @@ def test_predictions():
env
.
agents
[
0
].
position
=
(
5
,
6
)
env
.
agents
[
0
].
direction
=
0
predictions
=
env
.
predic
t
()
predictions
=
env
.
obs_builder
.
predictor
.
ge
t
()
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
1
],
prediction
[
2
]],
predictions
[
0
])))
directions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
3
]],
predictions
[
0
])))
time_offsets
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
0
]],
predictions
[
0
])))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment