Skip to content
Snippets Groups Projects
Commit 07285d5f authored by u214892's avatar u214892
Browse files

#42 run baselines in continuous integration

parent c167f87e
No related branches found
No related tags found
1 merge request!242 run baselines in ci
Pipeline #1376 canceled
image: themattrix/tox
##########################################
##########################################
## We have to set the following env vars
## in the admin interface :
## - AWS_DEFAULT_REGION
## - BUCKET_NAME
## - AWS_ACCESS_KEY_ID
## - AWS_SECRET_ACCESS_KEY
stages:
- tests
- benchmarks_and_profiling
- deploy_docs
cache:
paths:
- .tox
before_script:
- echo "Setting Up...."
tests:
stage: tests
script:
- apt update
- apt install -y libgl1-mesa-glx xvfb graphviz xdg-utils libcairo2-dev libjpeg-dev libgif-dev
- pip install tox
- xvfb-run tox -v --recreate
torch==1.1.0 git+http://gitlab.aicrowd.com/flatland/flatland.git@master
\ No newline at end of file torch>=1.1.0
\ No newline at end of file
...@@ -21,7 +21,7 @@ passenv = ...@@ -21,7 +21,7 @@ passenv =
deps = deps =
-r{toxinidir}/requirements_torch_training.txt -r{toxinidir}/requirements_torch_training.txt
commands = commands =
python torch_training/training_navigation.py python torch_training/multi_agent_training.py
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
......
import numpy as np import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
def max_lt(seq, val): def max_lt(seq, val):
""" """
...@@ -48,7 +50,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -48,7 +50,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max) return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
def split_tree(tree, num_features_per_node=9, current_depth=0): def split_tree(tree, current_depth=0):
""" """
Splits the tree observation into different sub groups that need the same normalization. Splits the tree observation into different sub groups that need the same normalization.
This is necessary because the tree observation includes two different distance: This is necessary because the tree observation includes two different distance:
...@@ -64,6 +66,7 @@ def split_tree(tree, num_features_per_node=9, current_depth=0): ...@@ -64,6 +66,7 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
:param current_depth: Keeping track of the current depth in the tree :param current_depth: Keeping track of the current depth in the tree
:return: Returns the three different groups of distance and binary values. :return: Returns the three different groups of distance and binary values.
""" """
num_features_per_node = TreeObsForRailEnv.observation_dim
if len(tree) < num_features_per_node: if len(tree) < num_features_per_node:
return [], [], [] return [], [], []
...@@ -88,7 +91,6 @@ def split_tree(tree, num_features_per_node=9, current_depth=0): ...@@ -88,7 +91,6 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
child_tree = tree[(num_features_per_node + children * child_size): child_tree = tree[(num_features_per_node + children * child_size):
(num_features_per_node + (children + 1) * child_size)] (num_features_per_node + (children + 1) * child_size)]
tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
num_features_per_node,
current_depth=current_depth + 1) current_depth=current_depth + 1)
if len(tmp_tree_data) > 0: if len(tmp_tree_data) > 0:
tree_data.extend(tmp_tree_data) tree_data.extend(tmp_tree_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment