Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • jack_bruck/baselines
  • rivesunder/baselines
  • xzhaoma/baselines
  • giulia_cantini/baselines
  • sfwatergit/baselines
  • jiaodaxiaozi/baselines
  • flatland/baselines
7 results
Show changes
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,13 +3,13 @@ import time
from collections import deque
import numpy as np
from line_profiler import LineProfiler
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from utils.observation_utils import norm_obs_clip, split_tree
from line_profiler import LineProfiler
from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups
def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
......@@ -102,10 +102,9 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
# Reset the env
lp_reset(True, True)
obs = env.reset(True, True)
obs, info = env.reset(True, True)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]),
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......@@ -129,8 +128,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
next_obs, all_rewards, done, _ = lp_step(action_dict)
for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
current_depth=0)
data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
......
This diff is collapsed.