diff --git a/run.py b/run.py index 55e31ea1f40ce3f45b54d919ac7bfc9218310bf8..33bc56ce761a43d13b6bf167efe576b260be0a30 100644 --- a/run.py +++ b/run.py @@ -45,7 +45,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy # Print per-step logs VERBOSE = True -USE_FAST_TREE_OBS = True +USE_FAST_TREEOBS = True # Checkpoint to use (remember to push it!) checkpoint = "./checkpoints/201124171810-7800.pth" # 18.249244799876152 DEPTH=2 AGENTS=10 @@ -74,9 +74,11 @@ if USE_FAST_TREEOBS: def check_is_observation_valid(observation): return True + def get_normalized_observation(observation, tree_depth: int, observation_radius=0): return observation + tree_observation = FastTreeObs(max_depth=observation_tree_depth) state_size = tree_observation.observation_dim else: