Skip to content
Snippets Groups Projects
Unverified Commit 90096804 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add slurm support for distributed training (#508)

* implement _init_dist_slurm()

* add slurm train/test scripts

* fix linting error

* minor fix
parent 53c647ea
No related branches found
No related tags found
No related merge requests found
import logging
import os
import random
import subprocess
import numpy as np
import torch
......@@ -34,8 +35,19 @@ def _init_dist_mpi(backend, **kwargs):
raise NotImplementedError
def _init_dist_slurm(backend, **kwargs):
raise NotImplementedError
def _init_dist_slurm(backend, port=29500, **kwargs):
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = str(port)
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def set_random_seed(seed):
......
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
CHECKPOINT=$4
GPUS=${GPUS:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-32}
PY_ARGS=${@:5}
SRUN_ARGS=${SRUN_ARGS:-""}
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS} \
--ntasks=1 \
--ntasks-per-node=1 \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python tools/test.py ${CONFIG} ${CHECKPOINT} --gpus ${GPUS} ${PY_ARGS}
#!/usr/bin/env bash
set -x
PARTITION=$1
JOB_NAME=$2
CONFIG=$3
WORK_DIR=$4
GPUS=${5:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${PY_ARGS:-"--validate"}
srun -p ${PARTITION} \
--job-name=${JOB_NAME} \
--gres=gpu:${GPUS_PER_NODE} \
--ntasks=${GPUS} \
--ntasks-per-node=${GPUS_PER_NODE} \
--cpus-per-task=${CPUS_PER_TASK} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/train.py ${CONFIG} --work_dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment