Commit 2be1a6c4 authored by u214892's avatar u214892
Browse files

fix master

parent d77ecfd1
......@@ -84,21 +84,6 @@ class Environment:
"""
raise NotImplementedError()
def predict(self):
"""
Predictions step.
Returns predictions for the agents.
The returns are dicts mapping from agent_id strings to values.
Returns
-------
predictions : dict
New predictions for each ready agent.
"""
raise NotImplementedError()
def get_agent_handles(self):
"""
Returns a list of agents' handles to be used as keys in the step()
......
......@@ -173,10 +173,6 @@ class TreeObsForRailEnv(ObservationBuilder):
Called whenever an observation has to be computed for the `env' environment, for each agent with handle
in the `handles' list.
"""
# TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object.
if self.predictor:
print(self.predictor.get(0))
observations = {}
for h in handles:
observations[h] = self.get(h)
......
......@@ -292,7 +292,6 @@ class RailEnv(Environment):
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def check_action(self, agent, action):
transition_isValid = None
possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
......@@ -324,7 +323,6 @@ class RailEnv(Environment):
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
return self.obs_dict
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
......
......@@ -323,7 +323,8 @@ class Controller(object):
def restartAgents(self, event):
self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
if self.model.init_agents_static is not None:
self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static]
self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
self.model.init_agents_static]
self.model.env.agents = None
self.model.init_agents_static = None
self.player = None
......
......@@ -396,7 +396,7 @@ class PILSVG(PILGL):
}
# "paint" color of the train images we load - this is the color we will change.
# a3BaseColor = self.rgb_s2i("0091ea")
# a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800
# temporary workaround for trains / agents renamed with different colour:
a3BaseColor = self.rgb_s2i("d50000")
......
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -64,8 +64,7 @@ def test_predictions():
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
obs_builder_object=TreeObsForRailEnv(max_depth=20, predictor=DummyPredictorForRailEnv(max_depth=20)),
)
env.reset()
......@@ -74,7 +73,7 @@ def test_predictions():
env.agents[0].position = (5, 6)
env.agents[0].direction = 0
predictions = env.predict()
predictions = env.obs_builder.predictor.get()
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment