Commit 96fad330 authored by manueth's avatar manueth

combining observations example with tree and local conflict obs

parent 79edf57b
import gym
from flatland.core.env_observation_builder import ObservationBuilder
from typing import Optional, List
from envs.flatland.observations import Observation, register_obs, make_obs
@register_obs("combined")
class CombinedObservation(Observation):
def __init__(self, config) -> None:
super().__init__(config)
self._observations = [
make_obs(obs_name, config[obs_name]) for obs_name in config.keys()
]
self._builder = CombinedObsForRailEnv([
o._builder for o in self._observations
])
def builder(self) -> ObservationBuilder:
return self._builder
def observation_space(self) -> gym.Space:
space = []
for o in self._observations:
space.append(o.observation_space())
return gym.spaces.Tuple(space)
class CombinedObsForRailEnv(ObservationBuilder):
def __init__(self, builders: [ObservationBuilder]):
super().__init__()
self._builders = builders
def reset(self):
for b in self._builders:
b.reset()
def get(self, handle: int = 0):
return None
def get_many(self, handles: Optional[List[int]] = None):
obs = {h: [] for h in handles}
for b in self._builders:
sub_obs = b.get_many(handles)
for h in handles:
obs[h].append(sub_obs[h])
return obs
def set_env(self, env):
for b in self._builders:
b.set_env(env)
flatland-sparse-small-tree-and-conflict-fc-apex:
run: APEX
env: flatland_sparse
stop:
timesteps_total: 15000000 # 1e8
checkpoint_freq: 10
checkpoint_at_end: True
keep_checkpoints_num: 5
checkpoint_score_attr: episode_reward_mean
config:
num_workers: 15
num_envs_per_worker: 5
num_gpus: 0
env_config:
observation: combined
observation_config:
tree:
max_depth: 2
shortest_path_max_depth: 30
localConflict:
max_depth: 2
shortest_path_max_depth: 30
n_local: 5
generator: sparse_rail_generator
generator_config: small_v0
resolve_deadlocks: False
deadlock_reward: 0
density_reward_factor: 0
wandb:
project: flatland
entity: masterscrat
tags: ["small_v0", "tree_and_local_conflict", "apex"] # TODO should be set programmatically
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
vf_share_layers: True # False
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment