Skip to content
Snippets Groups Projects
Commit 07285d5f authored by u214892's avatar u214892
Browse files

#42 run baselines in continuous integration

parent c167f87e
No related branches found
No related tags found
No related merge requests found
image: themattrix/tox
##########################################
##########################################
## We have to set the following env vars
## in the admin interface :
## - AWS_DEFAULT_REGION
## - BUCKET_NAME
## - AWS_ACCESS_KEY_ID
## - AWS_SECRET_ACCESS_KEY
stages:
- tests
- benchmarks_and_profiling
- deploy_docs
cache:
paths:
- .tox
before_script:
- echo "Setting Up...."
tests:
stage: tests
script:
- apt update
- apt install -y libgl1-mesa-glx xvfb graphviz xdg-utils libcairo2-dev libjpeg-dev libgif-dev
- pip install tox
- xvfb-run tox -v --recreate
torch==1.1.0
\ No newline at end of file
git+http://gitlab.aicrowd.com/flatland/flatland.git@master
torch>=1.1.0
\ No newline at end of file
......@@ -21,7 +21,7 @@ passenv =
deps =
-r{toxinidir}/requirements_torch_training.txt
commands =
python torch_training/training_navigation.py
python torch_training/multi_agent_training.py
[flake8]
max-line-length = 120
......
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
def max_lt(seq, val):
"""
......@@ -48,7 +50,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
def split_tree(tree, num_features_per_node=9, current_depth=0):
def split_tree(tree, current_depth=0):
"""
Splits the tree observation into different sub groups that need the same normalization.
This is necessary because the tree observation includes two different distance:
......@@ -64,6 +66,7 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
:param current_depth: Keeping track of the current depth in the tree
:return: Returns the three different groups of distance and binary values.
"""
num_features_per_node = TreeObsForRailEnv.observation_dim
if len(tree) < num_features_per_node:
return [], [], []
......@@ -88,7 +91,6 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
num_features_per_node,
current_depth=current_depth + 1)
if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment