diff --git a/flatland/core/env.py b/flatland/core/env.py index 3618d965a39b5a71fd1cf24aa81f2f876d5c6365..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,21 +84,6 @@ class Environment: """ raise NotImplementedError() - def predict(self): - """ - Predictions step. - - Returns predictions for the agents. - The returns are dicts mapping from agent_id strings to values. - - Returns - ------- - predictions : dict - New predictions for each ready agent. - - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a3d88d773db9edaa0777e2aee94593a0392a956c..541f8ad592d1481afb8eb6da2eb7b887aacae419 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor - self.agents_previous_reset = None def reset(self): @@ -174,9 +173,10 @@ class TreeObsForRailEnv(ObservationBuilder): in the `handles' list. """ - # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. + self.predictions = [] if self.predictor: - print(self.predictor.get(0)) + for a in range(len(handles)): + self.predictions.append(self.predictor.get(a)) observations = {} for h in handles: observations[h] = self.get(h) @@ -222,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself + #8: possible conflict detected + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -256,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 5d20a5d9f38230f353b0a9616c49ede333206c49..7773f86c1407153f649c972398c9a58067a38947 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -292,7 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -324,7 +323,6 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 5b0e830bbe14b081c53ce8a31ef2f9db270b62e3..ae910a369da849a86d39721d83b78ba777086d00 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -64,8 +64,7 @@ def test_predictions(): height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=10) + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=20)), ) env.reset() @@ -75,7 +74,7 @@ def test_predictions(): env.agents[0].direction = 0 env.agents[0].target = (3., 0.) - predictions = env.predict() + predictions = env.obs_builder.predictor.get() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))