Skip to content
Snippets Groups Projects
Commit 82c81129 authored by u229589's avatar u229589
Browse files

directly inherit from ObservationBuilder and not from TreeObsForRailEnv when it is not necessary

parent f54d6864
No related branches found
No related tags found
No related merge requests found
......@@ -6,8 +6,8 @@ from typing import List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
......@@ -18,24 +18,20 @@ random.seed(100)
np.random.seed(100)
class SingleAgentNavigationObs(TreeObsForRailEnv):
class SingleAgentNavigationObs(ObservationBuilder):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
super().__init__()
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
......
......@@ -6,8 +6,9 @@ from typing import Optional, List, Dict
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
......@@ -20,26 +21,18 @@ random.seed(100)
np.random.seed(100)
class ObservePredictions(TreeObsForRailEnv):
class ObservePredictions(ObservationBuilder):
"""
We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
This is necessary so that we can pass the distance map to the ShortestPathPredictor
Here we also want to highlight how you can visualize your observation
"""
def __init__(self, predictor):
super().__init__(max_depth=0)
super().__init__()
self.observation_space = [10]
self.predictor = predictor
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
pass
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
'''
......@@ -106,6 +99,11 @@ class ObservePredictions(TreeObsForRailEnv):
return observation
def _set_env(self, env: Environment):
self.env = env
if self.predictor:
self.predictor._set_env(self.env)
def main(args):
try:
......
......@@ -4,8 +4,8 @@ from typing import List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
......@@ -15,24 +15,20 @@ random.seed(1)
np.random.seed(1)
class SingleAgentNavigationObs(TreeObsForRailEnv):
class SingleAgentNavigationObs(ObservationBuilder):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
super().__init__()
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
......
......@@ -2,34 +2,30 @@ import random
from typing import Dict, List
import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
class SingleAgentNavigationObs(TreeObsForRailEnv):
class SingleAgentNavigationObs(ObservationBuilder):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
super().__init__()
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment