diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..ec4a626fbb7799f6a25b45fb86344b2bf7b37e64
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index a3e8f58d6bc6e7887e5a924e7e0f33b3f4fc0e11..bedb5871a53a9dfee919ffe3e5556f1ba622fe29 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,3 +129,5 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+# Editor
+.vscode/
diff --git a/README.md b/README.md
index cfd9e1eb079965aee61c55a17fc911352842af7d..ab9cefc3de6b59586b3b37c005d5acf16c3513b9 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ Clone the repository to compete now!
 *  **Documentation** on how to submit your agent to the leaderboard
 *  **The procedure** for best practices and information on how we evaluate your agent, etc.
 *  **Starter code** for you to get started!
-
+*  **SiamMOT**: Siamese Multi-Object Tracking baseline
 
 
 # Table of Contents
@@ -26,8 +26,9 @@ Clone the repository to compete now!
 4. [How do I specify my software runtime / dependencies?](#how-do-i-specify-my-software-runtime-dependencies-)
 5. [What should my code structure be like ?](#what-should-my-code-structure-be-like-)
 6. [How to make submission](#how-to-make-submission)
-7. [Other concepts and FAQs](#other-concepts)
-8. [Important links](#-important-links)
+7. [:star: SiamMOT baseline](#siammot-baseline)
+8. [Other concepts and FAQs](#other-concepts)
+9. [Important links](#-important-links)
 
 
 <p style="text-align:center"><img style="text-align:center" src="https://images.aicrowd.com/dataset_files/challenge_753/493d98aa-b7e5-45f8-aed1-640e4768f647_video.gif"  width="1024"></p>
@@ -127,9 +128,24 @@ Please specify if your code will use a GPU or not for the evaluation of your mod
 
 👉 [SUBMISSION.md](/docs/SUBMISSION.md)
 
-
 **Best of Luck** :tada: :tada:
 
+# SiamMOT baseline
+
+This repository contains a [SiamMOT](https://github.com/amazon-research/siam-mot) baseline interface that you can submit and improve upon.
+
+SiamMOT is a region-based Siamese Multi-Object Tracking network that detects and associates object instances simultaneously.
+
+## Additional Steps
+
+1. Change your entrypoint, i.e. `run.sh`, from `python test.py` to `python siam_mot_test.py`.
+2. Copy the Dockerfile at `siam-mot/Dockerfile` to the repository root (see the command below).
+3. Follow the common steps described in [SUBMISSION.md](/docs/SUBMISSION.md).
+
+```
+#> cp siam-mot/Dockerfile Dockerfile
+```
+
 # Other Concepts
 
 ## Time constraints
diff --git a/siam-mot/README.md b/siam-mot/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..05353c6588570525c44165b8beb29ba5c80c42e9
--- /dev/null
+++ b/siam-mot/README.md
@@ -0,0 +1,9 @@
+This folder is from SiamMOT: Siamese Multi-Object Tracking:
+https://github.com/amazon-research/siam-mot
+
+The model checkpoint is taken from [model_zoo.md](https://github.com/amazon-research/siam-mot/blob/main/readme/model_zoo.md) and saved in `models/`.
+
+Additional files:
+
+- `siam_mot_tracker.py`: interface for the Airprime Challenge submission (contributed by [Dipam Chakraborty](https://github.com/Dipamc77))
+- `Dockerfile`: Dockerfile for plug-and-play use (contributed by [Yoogottam Khandelwal](https://github.com/YoogottamK))
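+
+A minimal usage sketch of the tracker interface (a non-authoritative example: it assumes a CUDA GPU, the repository's dependencies installed, that it is run from this `siam-mot/` folder so the config and checkpoint paths below resolve, and a hypothetical test frame on disk):
+
+```python
+import cv2
+
+from siam_mot_tracker import SiamMOTTracker
+
+# Config and checkpoint that ship with this folder
+tracker = SiamMOTTracker(
+    config_file="configs/dla/DLA_34_FPN_AOT.yaml",
+    model_path="models/DLA-34-FPN_box_track_aot_d4.pth",
+)
+
+frame_bgr = cv2.imread("frame_0001.png")                # hypothetical test frame
+frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # process() expects an RGB array
+tracks = tracker.process(frame_rgb)                     # BoxList in xywh at the original resolution
+print(tracks.bbox)
+```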
diff --git a/siam-mot/configs/dla/DLA_34_FPN.yaml b/siam-mot/configs/dla/DLA_34_FPN.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4a1d409fa6b5647f50ee26c50cb1509f080ceca
--- /dev/null
+++ b/siam-mot/configs/dla/DLA_34_FPN.yaml
@@ -0,0 +1,72 @@
+INPUT:
+  MIN_SIZE_TRAIN: (640, 720, 800, 880, 960)
+  MAX_SIZE_TRAIN: 1500
+  MIN_SIZE_TEST: 800
+  MAX_SIZE_TEST: 1500
+  PIXEL_MEAN: [0.485, 0.456, 0.406]
+  PIXEL_STD: [0.229, 0.224, 0.225]
+  TO_BGR255: False
+  # This augmentation is intended for
+  # datasets with small motion
+  MOTION_LIMIT: 0.05
+  MOTION_BLUR_PROB: 1.0
+  COMPRESSION_LIMIT: 50
+  BRIGHTNESS: 0.1
+  CONTRAST: 0.1
+  SATURATION: 0.1
+  HUE: 0.1
+  AMODAL: True
+VIDEO:
+   TEMPORAL_WINDOW: 1000
+   TEMPORAL_SAMPLING: 250
+   RANDOM_FRAMES_PER_CLIP: 2
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  BOX_ON: True
+  TRACK_ON: True
+
+  BACKBONE:
+    CONV_BODY: "DLA-34-FPN"
+
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    ANCHOR_SIZES: (32, 64, 128, 256, 512)
+    PRE_NMS_TOP_N_TRAIN: 2000
+    PRE_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 300
+    FPN_POST_NMS_TOP_N_TEST: 300
+  ROI_HEADS:
+    USE_FPN: True
+    BATCH_SIZE_PER_IMAGE: 256
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 2
+    MLP_HEAD_DIM: 1024
+
+  TRACK_HEAD:
+    MAX_DORMANT_FRAMES: 1
+
+SOLVER:
+  # Assume 8 GPUs
+  BASE_LR: 0.02
+  WEIGHT_DECAY: 0.0001
+  STEPS: (15000, 20000)
+  MAX_ITER: 25000
+  VIDEO_CLIPS_PER_BATCH: 16
+
+DATASETS:
+  # SET THIS PATH CORRECTLY
+  ROOT_DIR: ""
+  TRAIN: ("crowdhuman_train_fbox", "COCO17_train")
+  TEST: ("MOT17_50_50", )
+
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+  NUM_WORKERS: 0
+
+DTYPE: "float32"
diff --git a/siam-mot/configs/dla/DLA_34_FPN_AOT.yaml b/siam-mot/configs/dla/DLA_34_FPN_AOT.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0e925be5f0cb584c678d07f8953fb218fb2b203c
--- /dev/null
+++ b/siam-mot/configs/dla/DLA_34_FPN_AOT.yaml
@@ -0,0 +1,83 @@
+INPUT:
+  MIN_SIZE_TRAIN: (2048,)
+  MAX_SIZE_TRAIN: 2480
+  MIN_SIZE_TEST: 2048
+  MAX_SIZE_TEST: 2480
+  PIXEL_MEAN: [0.485, 0.456, 0.406]
+  PIXEL_STD: [0.229, 0.224, 0.225]
+  TO_BGR255: False
+  # This augmentation is intended for
+  # datasets with small motion
+  MOTION_LIMIT: 0.005
+  MOTION_BLUR_PROB: 0.
+  COMPRESSION_LIMIT: 0
+  BRIGHTNESS: 0.
+  CONTRAST: 0.
+  SATURATION: 0.
+  HUE: 0.
+  AMODAL: False
+VIDEO:
+   TEMPORAL_WINDOW: 1000
+   TEMPORAL_SAMPLING: 1000
+   RANDOM_FRAMES_PER_CLIP: 2
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  BOX_ON: True
+  TRACK_ON: True
+
+  BACKBONE:
+    CONV_BODY: "DLA-34-FPN"
+
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    ANCHOR_SIZES: (6, 12, 24, 48, 96)
+    PRE_NMS_TOP_N_TRAIN: 2000
+    PRE_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 300
+    FPN_POST_NMS_TOP_N_TEST: 300
+  ROI_HEADS:
+    USE_FPN: True
+    BATCH_SIZE_PER_IMAGE: 256
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 2
+    MLP_HEAD_DIM: 1024
+
+  # Track branch configuration
+  TRACK_HEAD:
+    MODEL: "EMM"
+    PAD_PIXELS: 256
+    POOLER_RESOLUTION: 7
+    SEARCH_REGION: 5.0
+    # For inference
+    TRACK_THRESH: 0.6
+    START_TRACK_THRESH: 0.95
+    EMM:
+      USE_CENTERNESS: False
+      COSINE_WINDOW_WEIGHT: 0.1
+      TRACK_LOSS_WEIGHT: 0.1
+
+SOLVER:
+  # Assume 8 GPUs
+  BASE_LR: 0.02
+  WEIGHT_DECAY: 0.0001
+  STEPS: (15000, 20000)
+  MAX_ITER: 25000
+  VIDEO_CLIPS_PER_BATCH: 16
+
+DATASETS:
+  # SET THIS PATH CORRECTLY
+  ROOT_DIR: ""
+  TRAIN: ("prime_air", )
+  TEST: ("prime_air", )
+
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+  NUM_WORKERS: 2
+
+DTYPE: "float32"
diff --git a/siam-mot/models/DLA-34-FPN_box_track_aot_d4.pth b/siam-mot/models/DLA-34-FPN_box_track_aot_d4.pth
new file mode 100644
index 0000000000000000000000000000000000000000..6d9fa86759db9c8756ca196aab630ce4f39519c9
--- /dev/null
+++ b/siam-mot/models/DLA-34-FPN_box_track_aot_d4.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34afba9240f69a0ae2a68b14fa4775a34d47ef3a3598a1093677f7c9c2a049be
+size 95833603
diff --git a/siam-mot/siam_mot_tracker.py b/siam-mot/siam_mot_tracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..edcb00456433aa9554b03f4323ab5ecf77dde7bb
--- /dev/null
+++ b/siam-mot/siam_mot_tracker.py
@@ -0,0 +1,76 @@
+import os
+import logging
+import torch
+from PIL import Image
+from pathlib import Path
+from tqdm import tqdm
+import urllib
+import zipfile
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
+
+from siammot.configs.defaults import cfg
+from siammot.modelling.rcnn import build_siammot
+from siammot.data.adapters.augmentation.build_augmentation import build_siam_augmentation
+
+import cv2
+
+class SiamMOTTracker:
+    """
+    Implement a wrapper to call tracker
+    """
+
+    def __init__(self,
+                 config_file,
+                 model_path,
+                 gpu_id=0):
+
+        self.device = torch.device("cuda:{}".format(gpu_id))
+
+        cfg.merge_from_file(config_file)
+        self.cfg = cfg
+        self.model_path = model_path
+
+        self.transform = build_siam_augmentation(cfg, is_train=False)
+        self.tracker = self._build_and_load_tracker()
+        self.tracker.eval()
+        self.tracker.reset_siammot_status()
+
+    def _preprocess(self, frame):
+
+        # frame is RGB-Channel
+        frame = Image.fromarray(frame, 'RGB')
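+        # the augmentation transform expects an (image, target) pair, so a
+        # dummy box is passed here at inference time and discarded afterwards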
+        dummy_bbox = torch.tensor([[0, 0, 1, 1]])
+        dummy_boxlist = BoxList(dummy_bbox, frame.size, mode='xywh')
+        frame, _ = self.transform(frame, dummy_boxlist)
+
+        return frame
+
+    def _build_and_load_tracker(self):
+        tracker = build_siammot(self.cfg)
+        tracker.to(self.device)
+        checkpointer = DetectronCheckpointer(cfg, tracker,
+                                              save_dir=self.model_path)
+        if os.path.isfile(self.model_path):
+            _ = checkpointer.load(self.model_path)
+        elif os.path.isdir(self.model_path):
+            _ = checkpointer.load(use_latest=True)
+        else:
+            raise ValueError("No model parameters are loaded.")
+
+        return tracker
+
+    def process(self, frame):
+        orig_h, orig_w, _ = frame.shape
+        # frame should be RGB image
+        frame = self._preprocess(frame)
+
+        with torch.no_grad():
+            results = self.tracker(frame.to(self.device))
+
+        assert (len(results) == 1)
+        results = results[0].to('cpu')
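+        # map boxes back to the original frame resolution and convert to xywh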
+        results = results.resize([orig_w, orig_h]).convert('xywh')
+
+        return results
diff --git a/siam-mot/siammot/configs/defaults.py b/siam-mot/siammot/configs/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89f66f9fc55c79acf61d6c6b0817fbeddb255b2
--- /dev/null
+++ b/siam-mot/siammot/configs/defaults.py
@@ -0,0 +1,87 @@
+from yacs.config import CfgNode as CN
+
+from maskrcnn_benchmark.config import cfg
+
+# Root directory of datasets
+cfg.DATASETS.ROOT_DIR = ''
+
+# all video-related parameters
+cfg.VIDEO = CN()
+# the length of video clip for training/testing
+cfg.VIDEO.TEMPORAL_WINDOW = 8
+# the temporal sampling frequency for training
+cfg.VIDEO.TEMPORAL_SAMPLING = 4
+cfg.VIDEO.RANDOM_FRAMES_PER_CLIP = 2
+
+cfg.MODEL.BOX_ON = True
+
+# DLA
+cfg.MODEL.DLA = CN()
+cfg.MODEL.DLA.DLA_STAGE2_OUT_CHANNELS = 64
+cfg.MODEL.DLA.DLA_STAGE3_OUT_CHANNELS = 128
+cfg.MODEL.DLA.DLA_STAGE4_OUT_CHANNELS = 256
+cfg.MODEL.DLA.DLA_STAGE5_OUT_CHANNELS = 512
+cfg.MODEL.DLA.BACKBONE_OUT_CHANNELS = 128
+cfg.MODEL.DLA.STAGE_WITH_DCN = (False, False, False, False, False, False)
+
+# TRACK branch
+cfg.MODEL.TRACK_ON = False
+cfg.MODEL.TRACK_HEAD = CN()
+cfg.MODEL.TRACK_HEAD.TRACKTOR = False
+cfg.MODEL.TRACK_HEAD.POOLER_SCALES = (0.25, 0.125, 0.0625, 0.03125)
+cfg.MODEL.TRACK_HEAD.POOLER_RESOLUTION = 15
+cfg.MODEL.TRACK_HEAD.POOLER_SAMPLING_RATIO = 2
+
+cfg.MODEL.TRACK_HEAD.PAD_PIXELS = 512
+# the ratio of the search region's width/height to that of the original bounding box
+cfg.MODEL.TRACK_HEAD.SEARCH_REGION = 2.0
+# the minimal width / height of the search region
+cfg.MODEL.TRACK_HEAD.MINIMUM_SREACH_REGION = 0
+cfg.MODEL.TRACK_HEAD.MODEL = 'EMM'
+
+# solver params
+cfg.MODEL.TRACK_HEAD.TRACK_THRESH = 0.4
+cfg.MODEL.TRACK_HEAD.START_TRACK_THRESH = 0.6
+cfg.MODEL.TRACK_HEAD.RESUME_TRACK_THRESH = 0.4
+# maximum number of frames that a track can be dormant
+cfg.MODEL.TRACK_HEAD.MAX_DORMANT_FRAMES = 1
+
+# track proposal sampling
+cfg.MODEL.TRACK_HEAD.PROPOSAL_PER_IMAGE = 256
+cfg.MODEL.TRACK_HEAD.FG_IOU_THRESHOLD = 0.65
+cfg.MODEL.TRACK_HEAD.BG_IOU_THRESHOLD = 0.35
+
+cfg.MODEL.TRACK_HEAD.IMM = CN()
+# the feature dimension of search region (after fc layer)
+# in comparison to that of target region (after fc layer)
+cfg.MODEL.TRACK_HEAD.IMM.FC_HEAD_DIM_MULTIPLIER = 2
+cfg.MODEL.TRACK_HEAD.IMM.FC_HEAD_DIM = 256
+
+cfg.MODEL.TRACK_HEAD.EMM = CN()
+# Use_centerness flag only activates during inference
+cfg.MODEL.TRACK_HEAD.EMM.USE_CENTERNESS = True
+cfg.MODEL.TRACK_HEAD.EMM.POS_RATIO = 0.25
+cfg.MODEL.TRACK_HEAD.EMM.HN_RATIO = 0.25
+cfg.MODEL.TRACK_HEAD.EMM.TRACK_LOSS_WEIGHT = 1.
+# The ratio of center region to be positive positions
+cfg.MODEL.TRACK_HEAD.EMM.CLS_POS_REGION = 0.8
+# The lower this weight, the larger the motion offset allowed during inference.
+# Set this param to a small value (e.g. 0.1) for datasets with fast motion,
+# such as Caltech Roadside Pedestrians (CRP)
+cfg.MODEL.TRACK_HEAD.EMM.COSINE_WINDOW_WEIGHT = 0.4
+
+#Inference
+cfg.INFERENCE = CN()
+cfg.INFERENCE.USE_GIVEN_DETECTIONS = False
+# The length of clip per forward pass
+cfg.INFERENCE.CLIP_LEN = 1
+
+#Solver
+cfg.SOLVER.CHECKPOINT_PERIOD = 5000
+cfg.SOLVER.VIDEO_CLIPS_PER_BATCH = 16
+
+#Input
+cfg.INPUT.MOTION_LIMIT = 0.05
+cfg.INPUT.COMPRESSION_LIMIT = 50
+cfg.INPUT.MOTION_BLUR_PROB = 0.5
+cfg.INPUT.AMODAL = False
diff --git a/siam-mot/siammot/engine/inferencer.py b/siam-mot/siammot/engine/inferencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..60861f976336382e64af84ca55022878a0bd2b32
--- /dev/null
+++ b/siam-mot/siammot/engine/inferencer.py
@@ -0,0 +1,154 @@
+import os
+import logging
+import time
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from gluoncv.torch.data.gluoncv_motion_dataset.dataset import DataSample
+
+from ..data.build_inference_data_loader import build_video_loader
+from ..data.adapters.augmentation.build_augmentation import build_siam_augmentation
+from ..utils.boxlists_to_entities import boxlists_to_entities
+from ..eval.eval_clears_mot import eval_clears_mot
+
+
+def do_inference(cfg, model, sample: DataSample, transforms=None,
+                 given_detection: DataSample = None) -> DataSample:
+    """
+    Do inference on a specific video (sample)
+    :param cfg: configuration file of the model
+    :param model: a pytorch model
+    :param sample: a testing video
+    :param transforms: image-wise transform that prepares
+           video frames for processing
+    :param given_detection: cached detections from another model;
+           when provided, the detection branch is disabled in the
+           model forward pass
+    :return: the detection results in the format of DataSample
+    """
+    logger = logging.getLogger(__name__)
+    model.eval()
+    gpu_device = torch.device('cuda')
+
+    video_loader = build_video_loader(cfg, sample, transforms)
+
+    sample_result = DataSample(sample.id, raw_info=None, metadata=sample.metadata)
+    network_time = 0
+    for (video_clip, frame_id, timestamps) in tqdm(video_loader):
+        frame_id = frame_id.item()
+        timestamps = torch.squeeze(timestamps, dim=0).tolist()
+        video_clip = torch.squeeze(video_clip, dim=0)
+
+        with torch.no_grad():
+            video_clip = video_clip.to(gpu_device)
+            torch.cuda.synchronize()
+            network_start_time = time.time()
+            output_boxlists = model(video_clip)
+            torch.cuda.synchronize()
+            network_time += time.time() - network_start_time
+
+        # Resize to original image size and to xywh mode
+        output_boxlists = [o.resize([sample.width, sample.height]).convert('xywh')
+                           for o in output_boxlists]
+        output_boxlists = [o.to(torch.device("cpu")) for o in output_boxlists]
+        output_entities = boxlists_to_entities(output_boxlists, frame_id, timestamps)
+        for entity in output_entities:
+            sample_result.add_entity(entity)
+
+    logger.info('Sample_id {} / Speed {} fps'.format(sample.id, len(sample) / (network_time)))
+
+    return sample_result
+
+
+class DatasetInference(object):
+    def __init__(self, cfg, model, dataset, output_dir, data_filter_fn=None,
+                 distributed=False):
+
+        self._cfg = cfg
+
+        self._transform = build_siam_augmentation(cfg, is_train=False)
+        self._model = model
+        self._dataset = dataset
+        self._output_dir = output_dir
+        self._distributed = distributed
+        self._data_filter_fn = data_filter_fn
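+        # tracks with fewer than _track_len boxes or with mean confidence below
+        # _track_conf are dropped in _postprocess_tracks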
+        self._track_conf = 0.7
+        self._track_len = 5
+        self._logger = logging.getLogger(__name__)
+
+        self.results = dict()
+
+    def _eval_det_ap(self):
+        from ..eval.eval_det_ap import eval_det_ap
+        iou_threshold = np.arange(0.5, 0.95, 0.05).tolist()
+        ap_matrix = eval_det_ap(self._dataset, self.results,
+                                data_filter_fn=self._data_filter_fn,
+                                iou_threshold=iou_threshold)
+        ap = np.mean(ap_matrix, axis=0)
+
+        ap_str_summary = "\n"
+        ap_str_summary += 'Detection AP @[ IoU=0.50:0.95 ] = {:.2f}\n'.format(np.mean(ap) * 100)
+        ap_str_summary += 'Detection AP @[ IoU=0.50 ] = {:.2f}\n'.format(ap[0] * 100)
+        ap_str_summary += 'Detection AP @[ IoU=0.75 ] = {:.2f}\n'.format(ap[5] * 100)
+
+        return ap, ap_str_summary
+
+    def _eval_clear_mot(self):
+
+        motmetric, motstrsummary = eval_clears_mot(self._dataset, self.results,
+                                                   data_filter_fn=self._data_filter_fn)
+        return motmetric, motstrsummary
+
+    def _inference_on_video(self, sample):
+        cache_path = os.path.join(self._output_dir, '{}.json'.format(sample.id))
+        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
+
+        if os.path.exists(cache_path):
+            sample_result = DataSample.load(cache_path)
+        else:
+            sample_result = do_inference(self._cfg, self._model, sample,
+                                         transforms=self._transform,
+                                         )
+            sample_result.dump(cache_path)
+        return sample_result
+
+    def _postprocess_tracks(self, tracks: DataSample):
+        """
+        post_process the tracks to filter out short and non-confident tracks
+        :param tracks: un-filtered tracks
+        :return: filtered tracks that would be used for evaluation
+        """
+        track_ids = set()
+        for _entity in tracks.entities:
+            if _entity.id not in track_ids and _entity.id >= 0:
+                track_ids.add(_entity.id)
+
+        filter_tracks = tracks.get_copy_without_entities()
+        for _id in track_ids:
+            _id_entities = tracks.get_entities_with_id(_id)
+            _track_conf = np.mean([_e.confidence for _e in _id_entities])
+            if len(_id_entities) >= self._track_len \
+                    and _track_conf >= self._track_conf:
+                for _entity in _id_entities:
+                    filter_tracks.add_entity(_entity)
+        return filter_tracks
+
+    def __call__(self):
+        # todo: enable the inference in an efficient distributed framework
+        for (sample_id, sample) in tqdm(self._dataset):
+            # clean up the memory
+            self._model.reset_siammot_status()
+
+            sample_result = self._inference_on_video(sample)
+
+            sample_result = self._postprocess_tracks(sample_result)
+            self.results.update({sample.id: sample_result})
+
+        self._logger.info("\n---------------- Start evaluating ----------------\n")
+        motmetric, motstrsummary = self._eval_clear_mot()
+        self._logger.info(motstrsummary)
+
+        # ap, ap_str_summary = self._eval_det_ap()
+        # self._logger.info(ap_str_summary)
+        self._logger.info("\n---------------- Finish evaluating ----------------\n")
diff --git a/siam-mot/siammot/engine/tensorboard_writer.py b/siam-mot/siammot/engine/tensorboard_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc0685c5f96c4195d70d9dd544bb0b74fbba86d9
--- /dev/null
+++ b/siam-mot/siammot/engine/tensorboard_writer.py
@@ -0,0 +1,86 @@
+import torch
+import itertools
+import numpy as np
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+
+from maskrcnn_benchmark.utils.comm import get_world_size
+
+
+class TensorboardWriter(SummaryWriter):
+    def __init__(self, cfg, train_dir):
+        if get_world_size() < 2 or dist.get_rank() == 0:
+            super(TensorboardWriter, self).__init__(log_dir=train_dir)
+
+        device = torch.device(cfg.MODEL.DEVICE)
+        self.model_mean = torch.as_tensor(cfg.INPUT.PIXEL_MEAN, device=device)
+        self.model_std = torch.as_tensor(cfg.INPUT.PIXEL_STD, device=device)
+
+        self.image_to_bgr255 = cfg.INPUT.TO_BGR255
+
+        # number of images per row during visualization
+        self.num_col = cfg.VIDEO.RANDOM_FRAMES_PER_CLIP
+
+    def __call__(self, iter, loss, loss_dict, images, targets):
+        """
+        Log the losses every iteration and, every 500 iterations, a batch of
+        images with their ground-truth boxes.
+        :param iter: current training iteration
+        :param loss: total (summed) loss for this iteration
+        :param loss_dict: dict of the individual loss terms
+        :param images: batched images (ImageList) for this iteration
+        :param targets: ground-truth BoxLists, one per image
+        """
+        if get_world_size() < 2 or dist.get_rank() == 0:
+            self.add_scalar('loss', loss.detach().cpu().numpy(), iter)
+            for (_loss_key, _val) in loss_dict.items():
+                self.add_scalar(_loss_key, _val.detach().cpu().numpy(), iter)
+
+            # write images / ground truths to tensorboard every 500 iterations
+            if iter == 1 or iter % 500 == 0:
+                show_images = images.tensors
+                show_images = show_images.mul_(self.model_std[None, :, None, None]).\
+                    add_(self.model_mean[None, :, None, None])
+
+                # From BGR255 back to RGB in [0, 1] for visualization
+                if self.image_to_bgr255:
+                    show_images = show_images[:, [2, 1, 0], :, :] / 255.
+
+                # Detection ground truth
+                merged_image, bbox_in_merged_image = self.images_with_boxes(show_images, targets)
+                self.add_image_with_boxes('ground truth', merged_image, bbox_in_merged_image, iter)
+
+    def images_with_boxes(self, images, boxes):
+        """
+        Get images overlaid with bounding boxes
+        :param images: A batch of images packed in a torch tensor BxCxHxW
+        :param boxes:  A list of bounding boxes for the corresponding images
+        :return: the stitched image and the boxes shifted into its coordinates
+        """
+        # To numpy array
+        images = images.detach().cpu().numpy()
+        # new stitched image
+        batch, channels, height, width = images.shape
+        assert batch % self.num_col == 0
+        nrows = batch // self.num_col
+
+        new_height = height * nrows
+        new_width = width * self.num_col
+
+        merged_image = np.zeros([channels, new_height, new_width])
+        bbox_in_merged_image = []
+
+        for img_idx in range(batch):
+            row = img_idx // self.num_col
+            col = img_idx % self.num_col
+            merged_image[:, row * height:(row + 1) * height, col * width:(col + 1) * width] = \
+                images[img_idx, :, :, :]
+            box = boxes[img_idx].bbox.detach().cpu().numpy()
+            if box.size > 0:
+                box[:, 0] += col * width
+                box[:, 1] += row * height
+                box[:, 2] += col * width
+                box[:, 3] += row * height
+                bbox_in_merged_image.append(box)
+
+        bbox_in_merged_image = np.array(list(itertools.chain(*bbox_in_merged_image)))
+
+        return merged_image, bbox_in_merged_image
diff --git a/siam-mot/siammot/engine/trainer.py b/siam-mot/siammot/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..393985a4b02c9c61fc80dc41a5babd01c8554316
--- /dev/null
+++ b/siam-mot/siammot/engine/trainer.py
@@ -0,0 +1,104 @@
+import datetime
+import logging
+import time
+from apex import amp
+import torch.distributed as dist
+
+from maskrcnn_benchmark.utils.metric_logger import MetricLogger
+from maskrcnn_benchmark.engine.trainer import reduce_loss_dict
+from maskrcnn_benchmark.utils.comm import get_world_size
+
+from .tensorboard_writer import TensorboardWriter
+
+
+def do_train(
+        model,
+        data_loader,
+        optimizer,
+        scheduler,
+        checkpointer,
+        device,
+        checkpoint_period,
+        arguments,
+        tensorboard_writer: TensorboardWriter = None
+):
+    logger = logging.getLogger(__name__)
+    logger.info("Start training")
+    meters = MetricLogger(delimiter="  ")
+    max_iter = len(data_loader)
+    start_iter = arguments["iteration"]
+    model.train()
+    start_training_time = time.time()
+    end = time.time()
+
+    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
+
+        if any(len(target) < 1 for target in targets):
+            logger.error(
+                f"Iteration={iteration + 1} || Image Ids used for training {_} || "
+                f"targets Length={[len(target) for target in targets]}")
+            continue
+
+        data_time = time.time() - end
+        iteration = iteration + 1
+        arguments["iteration"] = iteration
+
+        scheduler.step()
+
+        images = images.to(device)
+        targets = [target.to(device) for target in targets]
+
+        result, loss_dict = model(images, targets)
+
+        losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = reduce_loss_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+        meters.update(loss=losses_reduced, **loss_dict_reduced)
+
+        optimizer.zero_grad()
+        # Note: If mixed precision is not used, this ends up doing nothing
+        # Otherwise apply loss scaling for mixed-precision recipe
+        with amp.scale_loss(losses, optimizer) as scaled_losses:
+            scaled_losses.backward()
+        optimizer.step()
+
+        # write images / ground truth / evaluation metrics to tensorboard
+        tensorboard_writer(iteration, losses_reduced, loss_dict_reduced, images, targets)
+
+        batch_time = time.time() - end
+        end = time.time()
+        meters.update(time=batch_time, data=data_time)
+        eta_seconds = meters.time.global_avg * (max_iter - iteration)
+        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+        if get_world_size() < 2 or dist.get_rank() == 0:
+            if iteration % 20 == 0 or iteration == max_iter:
+                logger.info(
+                    meters.delimiter.join(
+                        [
+                            "eta: {eta}",
+                            "iter: {iter}",
+                            "{meters}",
+                            "lr: {lr:.6f}",
+                        ]
+                    ).format(
+                        eta=eta_string,
+                        iter=iteration,
+                        meters=str(meters),
+                        lr=optimizer.param_groups[0]["lr"],
+                    )
+                )
+        if iteration % checkpoint_period == 0:
+            checkpointer.save("model_{:07d}".format(iteration), **arguments)
+        if iteration == max_iter:
+            checkpointer.save("model_final", **arguments)
+
+    total_training_time = time.time() - start_training_time
+    total_time_str = str(datetime.timedelta(seconds=total_training_time))
+    logger.info(
+        "Total training time: {} ({:.4f} s / it)".format(
+            total_time_str, total_training_time / (max_iter)
+        )
+    )
diff --git a/siam-mot/siammot/eval/eval_clears_mot.py b/siam-mot/siammot/eval/eval_clears_mot.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7cf1ea06a594525576e5fd27ec3ed209b4c150
--- /dev/null
+++ b/siam-mot/siammot/eval/eval_clears_mot.py
@@ -0,0 +1,84 @@
+from tqdm import tqdm
+import motmetrics as mm
+
+
+def eval_clears_mot(samples, predicted_samples, data_filter_fn=None,
+                    iou_thresh=0.5):
+    """
+    :param samples: a list of (sample_id, sample:DataSample)
+    :param predicted_samples: a dict with (sample_id: predicted_tracks:DataSample)
+    :param data_filter_fn: a callable function to filter entities
+    :param iou_thresh: the IoU (between a predicted bounding box and gt) threshold
+                       that determines whether a predicted bounding box is a true positive
+    """
+
+    assert 0 < iou_thresh <= 1
+
+    all_accumulators = []
+    sample_ids = []
+
+    metrics_host = mm.metrics.create()
+
+    for (sample_id, sample) in tqdm(samples):
+
+        predicted_tracks = predicted_samples[sample_id]
+        num_frames = len(sample)
+
+        def get_id_and_bbox(entities):
+            ids = [entity.id for entity in entities]
+            bboxes = [entity.bbox for entity in entities]
+            return ids, bboxes
+
+        accumulator = mm.MOTAccumulator(auto_id=True)
+        for i in range(num_frames):
+            valid_gt = sample.get_entities_for_frame_num(i)
+            ignore_gt = []
+
+            # If data filter function is available
+            if data_filter_fn is not None:
+                valid_gt, ignore_gt = data_filter_fn(valid_gt,
+                                                     meta_data=sample.metadata)
+            gt_ids, gt_bboxes = get_id_and_bbox(valid_gt)
+
+            out_ids = []
+            out_bboxes = []
+
+            # if there is no annotation for a particular frame, we don't evaluate on it
+            # this happens for low-fps annotation such as in CRP
+            # if len(gt_bboxes) > 0:
+            predicted_entities = predicted_tracks.get_entities_for_frame_num(i)
+
+            # If data filter function is available
+            if data_filter_fn is not None:
+                valid_pred, ignore_pred = data_filter_fn(predicted_entities, ignore_gt)
+            else:
+                valid_pred = predicted_entities
+
+            out_ids, out_bboxes = get_id_and_bbox(valid_pred)
+
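+            # mm.distances.iou_matrix returns an IoU *distance* (1 - IoU); passing
+            # max_iou = 1 - iou_thresh masks pairs whose IoU falls below iou_thresh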
+            bbox_distances = mm.distances.iou_matrix(gt_bboxes, out_bboxes, max_iou=1-iou_thresh)
+            accumulator.update(gt_ids, out_ids, bbox_distances)
+
+        all_accumulators.append(accumulator)
+        sample_ids.append(sample_id)
+
+    # Make sure to update to the latest version of motmetrics via pip, or the idf1 calculation might be very slow
+    metrics = ['num_frames', 'mostly_tracked', 'partially_tracked', 'mostly_lost', 'num_switches',
+               'num_false_positives', 'num_misses', 'mota', 'motp', 'idf1']
+
+    strsummary = ""
+    if len(all_accumulators):
+        summary = metrics_host.compute_many(
+            all_accumulators,
+            metrics=metrics,
+            names=sample_ids,
+            generate_overall=True
+        )
+
+        strsummary = mm.io.render_summary(
+            summary,
+            formatters=metrics_host.formatters,
+            namemap=mm.io.motchallenge_metric_names
+        )
+
+    return all_accumulators, "\n\n"+strsummary+"\n\n"
diff --git a/siam-mot/siammot/eval/eval_det_ap.py b/siam-mot/siammot/eval/eval_det_ap.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac097a8fbed53d51eae2a13dfb39b6b0e441b3f
--- /dev/null
+++ b/siam-mot/siammot/eval/eval_det_ap.py
@@ -0,0 +1,107 @@
+import numpy as np
+import copy
+from tqdm import tqdm
+
+from gluoncv.torch.data.gluoncv_motion_dataset.dataset import DataSample
+
+from .eval_utils import greedy_matching, compute_AP, bbs_iou
+
+
+def eval_det_ap(gt: list, pred: dict, class_table=None, data_filter_fn=None, iou_threshold=[0.5]):
+    """
+    Evaluate the detection performance (COCO-style AP) on the PoseTrack dataset
+    :param gt: ground truth annotations for all videos
+    :type gt: an iterable of (vid_id, DataSample) pairs
+    :param pred: predictions for all videos
+    :type pred: dict(vid_id: DataSample)
+    :param data_filter_fn: a callable function that filters out detections that are not considered during evaluation
+    :param class_table: class table specifying the class order
+    :param iou_threshold: a list of IoU thresholds used for the AP computation
+    :return: Average Precision (AP) over different thresholds
+    """
+    if class_table is None:
+        class_table = ["person"]
+    num_classes = len(class_table)
+
+    all_scores = [[[] for _ in range(len(iou_threshold))] for _ in range(num_classes)]
+    all_pr_ious = [[[] for _ in range(len(iou_threshold))] for _ in range(num_classes)]
+    all_gt_ious = [[[] for _ in range(len(iou_threshold))] for _ in range(num_classes)]
+
+    for (vid_id, vid_gt) in tqdm(gt):
+        vid_pred = pred[vid_id]
+
+        eval_frame_idxs = vid_gt.get_non_empty_frames()
+
+        # Loop over all classes
+        for class_id in range(0, num_classes):
+            gt_class_entities = vid_gt.entities
+            # gt_class_entities = vid_gt.get_entities_with_label(class_table[class_id])
+            pred_class_entities = vid_pred.get_entities_with_label(class_table[class_id])
+
+            # Wrap entities to a DataSample
+            vid_class_gt = DataSample(vid_id, metadata=vid_gt.metadata)
+            vid_class_pred = DataSample(vid_id, metadata=vid_pred.metadata)
+            for _entity in gt_class_entities:
+                vid_class_gt.add_entity(_entity)
+            for _entity in pred_class_entities:
+                vid_class_pred.add_entity(_entity)
+
+            # Get AP for this class and video
+            vid_class_scores, vid_class_pr_ious, vid_class_gt_ious = \
+                get_ap(vid_class_gt, vid_class_pred, data_filter_fn, eval_frame_idxs, iou_threshold)
+
+            for iou_id in range(len(iou_threshold)):
+                all_scores[class_id][iou_id] += vid_class_scores[iou_id]
+                all_pr_ious[class_id][iou_id] += vid_class_pr_ious[iou_id]
+                all_gt_ious[class_id][iou_id] += vid_class_gt_ious[iou_id]
+
+    class_ap_matrix = np.zeros((num_classes, len(iou_threshold)))
+    for class_id in range(num_classes):
+        class_ap_matrix[class_id, :] = compute_AP(all_scores[class_id],
+                                                  all_pr_ious[class_id],
+                                                  all_gt_ious[class_id])
+
+    return class_ap_matrix
+
+
+def get_ap(vid_class_gt: DataSample, vid_class_pred: DataSample, filter_fn, eval_frame_idxs, iou_thresh=[0.5]):
+    """
+    :param vid_class_gt: the ground truths for a specific class, in DataSample format
+    :param vid_class_pred: the predictions for a specific class, in DataSample format
+    :param filter_fn: a callable function to filter out detections
+    :param eval_frame_idxs: the frame indices where evaluation happens
+    :param iou_thresh: the list of IoU thresholds that determine whether a detection is a TP
+    :returns
+           vid_scores: the confidence for every predicted entity (a Python list)
+           vid_pr_ious: the iou between the predicted entity and its matching gt entity (a Python list)
+           vid_gt_ious: the iou between the gt entity and its matching predicted entity (a Python list)
+    """
+    if not isinstance(iou_thresh, list):
+        iou_thresh = [iou_thresh]
+    vid_scores = [[] for _ in iou_thresh]
+    vid_pr_ious = [[] for _ in iou_thresh]
+    vid_gt_ious = [[] for _ in iou_thresh]
+    for frame_idx in eval_frame_idxs:
+
+        gt_entities = vid_class_gt.get_entities_for_frame_num(frame_idx)
+        pred_entities = vid_class_pred.get_entities_for_frame_num(frame_idx)
+
+        # Remove detections for evaluation that are within ignore regions
+        if filter_fn is not None:
+            # Filter out ignored gt entities
+            gt_entities, ignore_gt_entities = filter_fn(gt_entities, meta_data=vid_class_gt.metadata)
+            # Filter out predicted entities that overlaps with ignored gt entities
+            pred_entities, ignore_pred_entities = filter_fn(pred_entities, ignore_gt_entities)
+
+        # sort the entity based on confidence scores
+        pred_entities = sorted(pred_entities, key=lambda x: x.confidence, reverse=True)
+        iou_matrix = bbs_iou(pred_entities, gt_entities)
+        scores = [entity.confidence for entity in pred_entities]
+        for i, _iou in enumerate(iou_thresh):
+            # pred_ious, gt_ious = target_matching(pred_entities, gt_entities)
+            pred_ious, gt_ious = greedy_matching(copy.deepcopy(iou_matrix), _iou)
+            vid_scores[i] += scores
+            vid_pr_ious[i] += pred_ious
+            vid_gt_ious[i] += gt_ious
+
+    return vid_scores, vid_pr_ious, vid_gt_ious
diff --git a/siam-mot/siammot/eval/eval_utils.py b/siam-mot/siammot/eval/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a1b3c5898ba7d83f1357871107924ef009c9f98
--- /dev/null
+++ b/siam-mot/siammot/eval/eval_utils.py
@@ -0,0 +1,103 @@
+import numpy as np
+from sklearn.metrics import auc
+
+from gluoncv.torch.data.gluoncv_motion_dataset.dataset import AnnoEntity
+from ..utils.entity_utils import bbs_iou
+
+
+def evaluate_recall(gt: [AnnoEntity], pred: [AnnoEntity], iou_thresh=0.5):
+    """
+    :param gt: groundtruth entities for a frame
+    :param pred: prediction entities for a frame
+    :param iou_thresh:
+
+    """
+    iou_matrix = bbs_iou(pred, gt)
+    pred_ious, gt_ious = greedy_matching(iou_matrix, iou_thresh=iou_thresh)
+
+    tp = 0
+    fn = len(gt)
+
+    for pred_iou in pred_ious:
+        if pred_iou == 1:
+            tp += 1
+            fn -= 1
+
+    assert(tp+fn == len(gt))
+
+    return tp, fn
+
+
+def precision_recall_curve(scores, pred_ious, gt_ious, iou_threshold=0.5):
+    """
+    Return a list of precision/recall based on different confidence thresholds
+    """
+    precisions = []
+    recalls = []
+    sorted_ = sorted(zip(scores, pred_ious), reverse=True)
+
+    tp = 0
+    fp = 0
+    fn = len(gt_ious)
+    for (score, pred_iou) in sorted_:
+        if pred_iou >= iou_threshold:
+            tp += 1
+            fn -= 1
+        else:
+            fp += 1
+        precisions.append(float(tp)/float(tp+fp+1e-4))
+        recalls.append(float(tp)/float(tp+fn+1e-4))
+
+    return precisions, recalls
+
+
+def greedy_matching(iou_matrix, iou_thresh=0.5):
+    """
+    Do greedy matching between predictions and ground-truth annotations.
+    Returns a 0/1 matching indicator for every prediction and every ground truth.
+    Each row of iou_matrix holds one prediction's IoU against all ground truths.
+    """
+    (num_pred, num_gt) = iou_matrix.shape
+    gt_ious = np.zeros(num_gt)
+    pred_ious = np.zeros(num_pred)
+    if iou_matrix.size > 0:
+        for i in range(num_pred):
+            max_iou = np.amax(iou_matrix[i, :])
+            if max_iou >= iou_thresh:
+                _id = np.where(iou_matrix[i, :] == max_iou)[0][0]
+                pred_ious[i] = 1
+                gt_ious[_id] = 1
+                iou_matrix[:, _id] = 0
+    return pred_ious.tolist(), gt_ious.tolist()
+
+
+def compute_AP(scores, pred_ious, gt_ious):
+    """
+    Compute Average Precision (AP) given a list of scores
+    :param scores: A list of confidence scores of detections
+    :param pred_ious: A list of iou of detections w.r.t the most matching ground truth bounding boxes
+    :param gt_ious: A list of iou of ground truth bounding boxes w.r.t the most matching detections
+    :return: Average Precision (AP)
+    """
+    if not isinstance(scores[0], list):
+        scores = [scores]
+        pred_ious = [pred_ious]
+        gt_ious = [gt_ious]
+    assert (len(scores) == len(pred_ious))
+    assert (len(scores) == len(gt_ious))
+
+    ap_list = np.zeros((len(scores), ))
+    precisions = []
+    recalls = []
+    for i in range(len(scores)):
+        precision, recall = precision_recall_curve(scores[i],
+                                                     pred_ious[i],
+                                                     gt_ious[i])
+
+        if len(recall) >= 2:
+            ap_list[i] = auc(recall, precision)
+            precisions.append(precision)
+            recalls.append(recall)
+
+    return ap_list
+
diff --git a/siam-mot/siammot/modelling/backbone/backbone_ext.py b/siam-mot/siammot/modelling/backbone/backbone_ext.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bd0e1d2308047f8e110faff5dbb6a109bf2d9c7
--- /dev/null
+++ b/siam-mot/siammot/modelling/backbone/backbone_ext.py
@@ -0,0 +1,48 @@
+from torch import nn
+from collections import OrderedDict
+
+from maskrcnn_benchmark.modeling import registry
+from maskrcnn_benchmark.modeling.make_layers import conv_with_kaiming_uniform
+
+from .dla import dla
+from . import fpn as fpn_module
+
+
+@registry.BACKBONES.register("DLA-34-FPN")
+@registry.BACKBONES.register("DLA-46-C-FPN")
+@registry.BACKBONES.register("DLA-60-FPN")
+@registry.BACKBONES.register("DLA-102-FPN")
+@registry.BACKBONES.register("DLA-169-FPN")
+def build_dla_fpn_backbone(cfg):
+    body = dla(cfg)
+    in_channels_stage2 = cfg.MODEL.DLA.DLA_STAGE2_OUT_CHANNELS
+    in_channels_stage3 = cfg.MODEL.DLA.DLA_STAGE3_OUT_CHANNELS
+    in_channels_stage4 = cfg.MODEL.DLA.DLA_STAGE4_OUT_CHANNELS
+    in_channels_stage5 = cfg.MODEL.DLA.DLA_STAGE5_OUT_CHANNELS
+    out_channels = cfg.MODEL.DLA.BACKBONE_OUT_CHANNELS
+
+    fpn = fpn_module.FPN(
+        in_channels_list=[
+            in_channels_stage2,
+            in_channels_stage3,
+            in_channels_stage4,
+            in_channels_stage5
+        ],
+        out_channels=out_channels,
+        conv_block=conv_with_kaiming_uniform(
+            cfg.MODEL.FPN.USE_GN, cfg.MODEL.FPN.USE_RELU
+        ),
+        top_blocks=fpn_module.LastLevelMaxPool(),
+    )
+    model = nn.Sequential(OrderedDict([("body", body), ("fpn", fpn)]))
+    model.out_channels = out_channels
+    return model
+
+
+def build_backbone(cfg):
+    assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
+        "cfg.MODEL.BACKBONE.CONV_BODY: {} is not registered in registry".format(
+            cfg.MODEL.BACKBONE.CONV_BODY
+        )
+    return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)
+
diff --git a/siam-mot/siammot/modelling/backbone/dla.py b/siam-mot/siammot/modelling/backbone/dla.py
new file mode 100644
index 0000000000000000000000000000000000000000..e45031d36c87341b763a3163a870254deef940ac
--- /dev/null
+++ b/siam-mot/siammot/modelling/backbone/dla.py
@@ -0,0 +1,409 @@
+""" Deep Layer Aggregation Backbone
+"""
+import os
+import math
+import torch
+import torch.nn as nn
+from maskrcnn_benchmark.layers import Conv2d
+from maskrcnn_benchmark.layers import DFConv2d
+from maskrcnn_benchmark.layers import FrozenBatchNorm2d
+
+from torchvision.models.utils import load_state_dict_from_url
+
+from timm.models.layers import SelectAdaptivePool2d
+
+from maskrcnn_benchmark.utils.registry import Registry
+from maskrcnn_benchmark.utils.model_serialization import load_state_dict
+
+model_urls = {
+    'dla34': 'http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth',
+    'dla46_c': 'http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth',
+    'dla46x_c': 'http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth',
+    'dla60': 'http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth',
+    'dla102': 'http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth',
+    'dla169': 'http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth',
+    'dla60_res2net':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth',
+}
+
+
+class DlaBasic(nn.Module):
+    """DLA Basic"""
+    def __init__(self, inplanes, planes, stride=1, dilation=1, batch_norm=FrozenBatchNorm2d, **_):
+        super(DlaBasic, self).__init__()
+        self.conv1 = Conv2d(
+            inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
+        self.bn1 = batch_norm(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = Conv2d(
+            planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation)
+        self.bn2 = batch_norm(planes)
+        self.stride = stride
+
+    def forward(self, x, residual=None):
+        if residual is None:
+            residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class DlaBottleneck(nn.Module):
+    """DLA/DLA-X Bottleneck"""
+    expansion = 2
+
+    def __init__(self, inplanes, outplanes, stride=1, dilation=1,
+                 cardinality=1, base_width=64, batch_norm=FrozenBatchNorm2d,
+                 with_dcn=False):
+        super(DlaBottleneck, self).__init__()
+        self.stride = stride
+        mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+        mid_planes = mid_planes // self.expansion
+
+        self.conv1 = Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
+        self.bn1 = batch_norm(mid_planes)
+        if with_dcn:
+            self.conv2 = DFConv2d(mid_planes, mid_planes, with_modulated_dcn=False,
+                                  kernel_size=3, stride=stride, bias=False,
+                                  dilation=dilation, groups=cardinality)
+        else:
+            self.conv2 = Conv2d(
+                mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation,
+                bias=False, dilation=dilation, groups=cardinality)
+        self.bn2 = batch_norm(mid_planes)
+        self.conv3 = Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
+        self.bn3 = batch_norm(outplanes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x, residual=None):
+        if residual is None:
+            residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class DlaBottle2neck(nn.Module):
+    """ Res2Net/Res2NeXT DLA Bottleneck
+    Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
+    """
+    expansion = 2
+
+    def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4,
+                 cardinality=8, base_width=4, batch_norm=FrozenBatchNorm2d):
+        super(DlaBottle2neck, self).__init__()
+        self.is_first = stride > 1
+        self.scale = scale
+        mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+        mid_planes = mid_planes // self.expansion
+        self.width = mid_planes
+
+        self.conv1 = Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
+        self.bn1 = batch_norm(mid_planes * scale)
+
+        num_scale_convs = max(1, scale - 1)
+        convs = []
+        bns = []
+        for _ in range(num_scale_convs):
+            convs.append(Conv2d(
+                mid_planes, mid_planes, kernel_size=3, stride=stride,
+                padding=dilation, dilation=dilation, groups=cardinality, bias=False))
+            bns.append(batch_norm(mid_planes))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        if self.is_first:
+            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
+
+        self.conv3 = Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
+        self.bn3 = batch_norm(outplanes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x, residual=None):
+        if residual is None:
+            residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        spx = torch.split(out, self.width, 1)
+        spo = []
+        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
+            sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
+            sp = conv(sp)
+            sp = bn(sp)
+            sp = self.relu(sp)
+            spo.append(sp)
+        if self.scale > 1 :
+            spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
+        out = torch.cat(spo, 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class DlaRoot(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, residual, batch_norm=FrozenBatchNorm2d):
+        super(DlaRoot, self).__init__()
+        self.conv = Conv2d(
+            in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
+        self.bn = batch_norm(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.residual = residual
+
+    def forward(self, *x):
+        children = x
+        x = self.conv(torch.cat(x, 1))
+        x = self.bn(x)
+        if self.residual:
+            x += children[0]
+        x = self.relu(x)
+
+        return x
+
+
+class DlaTree(nn.Module):
+    def __init__(self, levels, block, in_channels, out_channels, stride=1,
+                 dilation=1, cardinality=1, base_width=64,
+                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False,
+                 batch_norm=FrozenBatchNorm2d, with_dcn=False):
+        super(DlaTree, self).__init__()
+        if root_dim == 0:
+            root_dim = 2 * out_channels
+        if level_root:
+            root_dim += in_channels
+        cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width, batch_norm=batch_norm, with_dcn=with_dcn)
+        if levels == 1:
+            self.tree1 = block(in_channels, out_channels, stride, **cargs)
+            self.tree2 = block(out_channels, out_channels, 1, **cargs)
+        else:
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            self.tree1 = DlaTree(
+                levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
+            self.tree2 = DlaTree(
+                levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
+        if levels == 1:
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual, batch_norm=batch_norm)
+        self.level_root = level_root
+        self.root_dim = root_dim
+        self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else None
+        self.project = None
+        if in_channels != out_channels:
+            self.project = nn.Sequential(
+                Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+                batch_norm(out_channels)
+            )
+        self.levels = levels
+
+    def forward(self, x, residual=None, children=None):
+        children = [] if children is None else children
+        bottom = self.downsample(x) if self.downsample else x
+        residual = self.project(bottom) if self.project else bottom
+        if self.level_root:
+            children.append(bottom)
+        x1 = self.tree1(x, residual)
+        if self.levels == 1:
+            x2 = self.tree2(x1)
+            x = self.root(x2, x1, *children)
+        else:
+            children.append(x1)
+            x = self.tree2(x1, children=children)
+        return x
+
+
+class DLA(nn.Module):
+    def __init__(self, levels, channels, num_classes=1000, in_chans=3, cardinality=1, base_width=64,
+                 block=DlaBottle2neck, residual_root=False, linear_root=False, batch_norm=FrozenBatchNorm2d,
+                 drop_rate=0.0, global_pool='avg', feature_only=True, dcn_config=(False,)):
+        super(DLA, self).__init__()
+        self.channels = channels
+        self.num_classes = num_classes
+        self.cardinality = cardinality
+        self.base_width = base_width
+        self.drop_rate = drop_rate
+
+        # check whether deformable conv config is right
+        if len(dcn_config) != 6:
+            raise ValueError("Deformable convolution configuration is not correct; "
+                             "every level should specify a configuration.")
+
+        self.base_layer = nn.Sequential(
+            Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
+            batch_norm(channels[0]),
+            nn.ReLU(inplace=True))
+        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0], batch_norm=batch_norm)
+        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2, batch_norm=batch_norm)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root, batch_norm=batch_norm)
+        self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False,
+                              with_dcn=dcn_config[2], **cargs)
+        self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True,
+                              with_dcn=dcn_config[3], **cargs)
+        self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True,
+                              with_dcn=dcn_config[4], **cargs)
+        self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True,
+                              with_dcn=dcn_config[5], **cargs)
+
+        if not feature_only:
+            self.num_features = channels[-1]
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
+
+    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1, batch_norm=FrozenBatchNorm2d):
+        modules = []
+        for i in range(convs):
+            modules.extend([
+                Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
+                          padding=dilation, bias=False, dilation=dilation),
+                batch_norm(planes),
+                nn.ReLU(inplace=True)])
+            inplanes = planes
+        return nn.Sequential(*modules)
+
+    def forward(self, x):
+        features = []
+        x = self.base_layer(x)
+        x0 = self.level0(x)
+        x1 = self.level1(x0)
+        x2 = self.level2(x1)
+        x3 = self.level3(x2)
+        x4 = self.level4(x3)
+        x5 = self.level5(x4)
+
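+        # return the level-2..5 feature maps (strides 4, 8, 16, 32) for the FPN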
+        features.append(x2)
+        features.append(x3)
+        features.append(x4)
+        features.append(x5)
+
+        return features
+
+
+def dla_34(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512],
+                block=DlaBasic,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla_46_c(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
+                block=DlaBottleneck,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla_46_xc(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 64, 128, 256],
+                block=DlaBottleneck,
+                cardinality=32,
+                base_width=4,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla_60(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024],
+                block=DlaBottleneck,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla60_res2net(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA(levels=(1, 1, 1, 2, 3, 1),
+                channels=(16, 32, 128, 256, 512, 1024),
+                block=DlaBottle2neck,
+                cardinality=1,
+                base_width=28,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla_102(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024],
+                block=DlaBottleneck,
+                residual_root=True,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+def dla_169(dcn_config, feature_only=True, batch_norm=FrozenBatchNorm2d):
+    model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024],
+                block=DlaBottleneck,
+                residual_root=True,
+                batch_norm=batch_norm,
+                feature_only=feature_only,
+                dcn_config=dcn_config)
+    return model
+
+
+BACKBONE = Registry({
+    "DLA-34-FPN": dla_34,
+    "DLA-46-C-FPN": dla_46_c,
+    "DLA-46-XC-FPN": dla_46_xc,
+    "DLA-60-FPN": dla_60,
+    "DLA-60-RES2NET-FPN": dla60_res2net,
+    "DLA-102-FPN": dla_102,
+    "DLA-169-FPN": dla_169
+})
+
+BACKBONE_ARCH = {
+    "DLA-34-FPN": "dla34",
+    "DLA-46-C-FPN": "dla_46_c",
+    "DLA-46-XC-FPN": "dla_46_xc",
+    "DLA-60-FPN": "dla_60",
+    "DLA-60-RES2NET-FPN": "dla60_res2net",
+    "DLA-102-FPN": "dla_102",
+    "DLA-169-FPN": "dla_169"
+}
+
+
+def dla(cfg):
+    model = BACKBONE[cfg.MODEL.BACKBONE.CONV_BODY](cfg.MODEL.DLA.STAGE_WITH_DCN)
+
+    # Load the ImageNet pretrained backbone if no valid pre-trained model weights are given
+    if not os.path.exists(cfg.MODEL.WEIGHT):
+        state_dict = load_state_dict_from_url(model_urls[BACKBONE_ARCH[cfg.MODEL.BACKBONE.CONV_BODY]],
+                                              progress=True)
+        load_state_dict(model, state_dict)
+
+    return model
+
+
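+# Illustrative usage sketch (the values below are assumptions, not taken from any
+# config): every DLA variant expects a 6-element dcn_config, one deformable-conv
+# flag per level, and with feature_only=True returns the four feature maps at
+# strides 4, 8, 16 and 32 (level2..level5).
+#
+#   backbone = dla_34(dcn_config=(False,) * 6)
+#   feats = backbone(torch.randn(1, 3, 256, 256))   # assumes a 3-channel input
+#   # len(feats) == 4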
diff --git a/siam-mot/siammot/modelling/backbone/fpn.py b/siam-mot/siammot/modelling/backbone/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1602750098542e323c2d36f03a8af4e99aa2ca85
--- /dev/null
+++ b/siam-mot/siammot/modelling/backbone/fpn.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class FPN(nn.Module):
+    """
+    Module that adds FPN on top of a list of feature maps.
+    The feature maps are currently supposed to be in increasing depth
+    order, and must be consecutive
+    """
+
+    def __init__(
+        self, in_channels_list, out_channels, conv_block, top_blocks=None
+    ):
+        """
+        Arguments:
+            in_channels_list (list[int]): number of channels for each feature map that
+                will be fed
+            out_channels (int): number of channels of the FPN representation
+            top_blocks (nn.Module or None): if provided, an extra operation will
+                be performed on the output of the last (smallest resolution)
+                FPN output, and the result will extend the result list
+        """
+        super(FPN, self).__init__()
+        self.inner_blocks = []
+        self.layer_blocks = []
+        for idx, in_channels in enumerate(in_channels_list, 1):
+            inner_block = "fpn_inner{}".format(idx)
+            layer_block = "fpn_layer{}".format(idx)
+
+            if in_channels == 0:
+                continue
+            inner_block_module = conv_block(in_channels, out_channels, 1)
+            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
+            self.add_module(inner_block, inner_block_module)
+            self.add_module(layer_block, layer_block_module)
+            self.inner_blocks.append(inner_block)
+            self.layer_blocks.append(layer_block)
+        self.top_blocks = top_blocks
+
+    def forward(self, x):
+        """
+        Arguments:
+            x (list[Tensor]): feature maps for each feature level.
+        Returns:
+            results (tuple[Tensor]): feature maps after FPN layers.
+                They are ordered from highest resolution first.
+        """
+        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
+        results = []
+        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
+        for feature, inner_block, layer_block in zip(
+            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
+        ):
+            if not inner_block:
+                continue
+            # inner_top_down = F.interpolate(last_inner, scale_factor=2, mode="nearest")
+            inner_lateral = getattr(self, inner_block)(feature)
+            # use size-based interpolation so that image sizes not divisible by 32 are supported
+            inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:],
+                mode='bilinear', align_corners=False)
+            last_inner = inner_lateral + inner_top_down
+            results.insert(0, getattr(self, layer_block)(last_inner))
+
+        if isinstance(self.top_blocks, LastLevelP6P7):
+            last_results = self.top_blocks(x[-1], results[-1])
+            results.extend(last_results)
+        elif isinstance(self.top_blocks, LastLevelMaxPool):
+            last_results = self.top_blocks(results[-1])
+            results.extend(last_results)
+
+        return tuple(results)
+
+
+class LastLevelMaxPool(nn.Module):
+    def forward(self, x):
+        return [F.max_pool2d(x, 1, 2, 0)]
+
+
+class LastLevelP6P7(nn.Module):
+    """
+    This module is used in RetinaNet to generate extra layers, P6 and P7.
+    """
+    def __init__(self, in_channels, out_channels):
+        super(LastLevelP6P7, self).__init__()
+        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
+        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
+        for module in [self.p6, self.p7]:
+            nn.init.kaiming_uniform_(module.weight, a=1)
+            nn.init.constant_(module.bias, 0)
+        self.use_P5 = in_channels == out_channels
+
+    def forward(self, c5, p5):
+        x = p5 if self.use_P5 else c5
+        p6 = self.p6(x)
+        p7 = self.p7(F.relu(p6))
+        return [p6, p7]
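+
+# Illustrative wiring sketch; the conv_block factory below is an assumption
+# standing in for the helper supplied by the backbone builder:
+#
+#   def conv_block(in_ch, out_ch, kernel, stride=1):
+#       return nn.Conv2d(in_ch, out_ch, kernel, stride, padding=kernel // 2)
+#
+#   fpn = FPN(in_channels_list=[64, 128, 256, 512], out_channels=256,
+#             conv_block=conv_block, top_blocks=LastLevelMaxPool())
+#   # forward() consumes the backbone features (highest resolution first) and
+#   # returns a tuple in the same order, plus one extra map from top_blocks.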
diff --git a/siam-mot/siammot/modelling/box_head/box_head.py b/siam-mot/siammot/modelling/box_head/box_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a80db8a0ab6aa7bf22aad389802aaec6ad7f92
--- /dev/null
+++ b/siam-mot/siammot/modelling/box_head/box_head.py
@@ -0,0 +1,70 @@
+import torch
+
+from maskrcnn_benchmark.modeling.roi_heads.box_head.roi_box_feature_extractors import make_roi_box_feature_extractor
+from maskrcnn_benchmark.modeling.roi_heads.box_head.roi_box_predictors import make_roi_box_predictor
+from maskrcnn_benchmark.modeling.roi_heads.box_head.loss import make_roi_box_loss_evaluator
+
+from .inference import make_roi_box_post_processor
+
+
+class ROIBoxHead(torch.nn.Module):
+    """
+    Generic Box Head class.
+    """
+
+    def __init__(self, cfg, in_channels):
+        super(ROIBoxHead, self).__init__()
+        self.feature_extractor = make_roi_box_feature_extractor(cfg, in_channels)
+        self.predictor = make_roi_box_predictor(
+            cfg, self.feature_extractor.out_channels)
+        self.post_processor = make_roi_box_post_processor(cfg)
+        self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
+
+    def forward(self, features, proposals, targets=None):
+        """
+        Arguments:
+            features (list[Tensor]): feature-maps from possibly several levels
+            proposals (list[BoxList]): proposal boxes
+            targets (list[BoxList], optional): the ground-truth targets.
+
+        Returns:
+            x (Tensor): the result of the feature extractor
+            proposals (list[BoxList]): during training, the subsampled proposals
+                are returned. During testing, the predicted boxlists are returned
+            losses (dict[Tensor]): During training, returns the losses for the
+                head. During testing, returns an empty dict.
+        """
+
+        if self.training:
+            # Faster R-CNN subsamples during training the proposals with a fixed
+            # positive / negative ratio
+            with torch.no_grad():
+                proposals = self.loss_evaluator.subsample(proposals, targets)
+
+        # extract features that will be fed to the final classifier. The
+        # feature_extractor generally corresponds to the pooler + heads
+        x = self.feature_extractor(features, proposals)
+        # final classifier that converts the features into predictions
+        class_logits, box_regression = self.predictor(x)
+
+        if not self.training:
+            result = self.post_processor((class_logits, box_regression), proposals)
+            return x, result, {}
+
+        loss_classifier, loss_box_reg = self.loss_evaluator(
+            [class_logits], [box_regression]
+        )
+        return (
+            x,
+            proposals,
+            dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),
+        )
+
+
+def build_roi_box_head(cfg, in_channels):
+    """
+    Constructs a new box head.
+    By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new class
+    and make it a parameter in the config
+    """
+    return ROIBoxHead(cfg, in_channels)
diff --git a/siam-mot/siammot/modelling/box_head/inference.py b/siam-mot/siammot/modelling/box_head/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db792ad3a6dc63cdc5b035667eba54790a2a6f7
--- /dev/null
+++ b/siam-mot/siammot/modelling/box_head/inference.py
@@ -0,0 +1,217 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+from maskrcnn_benchmark.modeling.box_coder import BoxCoder
+
+
+class PostProcessor(nn.Module):
+    """
+    From a set of classification scores, box regression and proposals,
+    computes the post-processed boxes, and applies NMS to obtain the
+    final results
+    """
+
+    def __init__(
+        self,
+        score_thresh=0.05,
+        nms=0.5,
+        detections_per_img=100,
+        box_coder=None,
+        cls_agnostic_bbox_reg=False,
+        bbox_aug_enabled=False,
+        amodal_inference=False
+    ):
+        """
+        Arguments:
+            score_thresh (float)
+            nms (float)
+            detections_per_img (int)
+            box_coder (BoxCoder)
+        """
+        super(PostProcessor, self).__init__()
+        self.score_thresh = score_thresh
+        self.nms = nms
+        self.detections_per_img = detections_per_img
+        if box_coder is None:
+            box_coder = BoxCoder(weights=(10., 10., 5., 5.))
+        self.box_coder = box_coder
+        self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
+        self.bbox_aug_enabled = bbox_aug_enabled
+        self.amodal_inference = amodal_inference
+
+    def forward(self, x, boxes):
+        """
+        Arguments:
+            x (tuple[tensor, tensor]): x contains the class logits
+                and the box_regression from the model.
+            boxes (list[BoxList]): bounding boxes that are used as
+                reference, one for each image
+
+        Returns:
+            results (list[BoxList]): one BoxList for each image, containing
+                the extra fields labels and scores
+        """
+        class_logits, box_regression = x
+        class_prob = F.softmax(class_logits, -1)
+        device = class_logits.device
+
+        # TODO think about a representation of batch of boxes
+        image_shapes = [box.size for box in boxes]
+        boxes_per_image = [len(box) for box in boxes]
+        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
+
+        if self.cls_agnostic_bbox_reg:
+            box_regression = box_regression[:, -4:]
+        proposals = self.box_coder.decode(
+            box_regression.view(sum(boxes_per_image), -1), concat_boxes
+        )
+        if self.cls_agnostic_bbox_reg:
+            proposals = proposals.repeat(1, class_prob.shape[1])
+
+        num_classes = class_prob.shape[1]
+
+        proposals = proposals.split(boxes_per_image, dim=0)
+        class_prob = class_prob.split(boxes_per_image, dim=0)
+
+        results = [self.create_empty_boxlist(device) for _ in boxes]
+
+        for i, (prob, boxes_per_img, image_shape, _box) in enumerate(zip(
+                class_prob, proposals, image_shapes, boxes)):
+
+            # get ids for each bbox
+            if _box.has_field('ids'):
+                ids = _box.get_field('ids')
+            else:
+                # default id is -1
+                ids = torch.zeros((len(_box),), dtype=torch.int64, device=device) - 1
+
+            # this only happens for tracks
+            if _box.has_field('labels'):
+                labels = _box.get_field('labels')
+
+                # tracks
+                track_inds = torch.squeeze(torch.nonzero(ids >= 0))
+
+                # boost track box scores so they are not suppressed during NMS
+                if track_inds.numel() > 0:
+                    prob_cp = prob.clone()
+                    prob[track_inds, :] = 0.
+                    prob[track_inds, labels] = prob_cp[track_inds, labels] + 1.
+
+            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, ids)
+            if not self.amodal_inference:
+                boxlist = boxlist.clip_to_image(remove_empty=False)
+            boxlist = self.filter_results(boxlist, num_classes)
+
+            results[i] = boxlist
+        return results
+
+    @staticmethod
+    def prepare_boxlist(boxes, scores, image_shape, ids):
+        """
+        Returns BoxList from `boxes` and adds probability scores information
+        as an extra field
+        `boxes` has shape (#detections, 4 * #classes), where each row represents
+        a list of predicted bounding boxes for each of the object classes in the
+        dataset (including the background class). The detections in each row
+        originate from the same object proposal.
+        `scores` has shape (#detections, #classes), where each row represents a list
+        of object detection confidence scores for each of the object classes in the
+        dataset (including the background class). `scores[i, j]` corresponds to the
+        box at `boxes[i, j * 4:(j + 1) * 4]`.
+        """
+        boxes = boxes.reshape(-1, 4)
+        scores = scores.reshape(-1)
+        boxlist = BoxList(boxes, image_shape, mode="xyxy")
+        boxlist.add_field("scores", scores)
+        boxlist.add_field("ids", ids)
+        return boxlist
+
+    def create_empty_boxlist(self, device="cpu"):
+
+        init_bbox = torch.zeros(([0, 4]), dtype=torch.float32, device=device)
+        init_score = torch.zeros([0, ], dtype=torch.float32, device=device)
+        init_ids = torch.zeros(([0, ]), dtype=torch.int64, device=device)
+        empty_boxlist = self.prepare_boxlist(init_bbox, init_score, [0, 0], init_ids)
+        return empty_boxlist
+
+    def filter_results(self, boxlist, num_classes):
+        """Returns bounding-box detection results by thresholding on scores and
+        applying non-maximum suppression (NMS).
+        """
+        # unwrap the boxlist to avoid additional overhead.
+        # if we had multi-class NMS, we could perform this directly on the boxlist
+        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
+        scores = boxlist.get_field("scores").reshape(-1, num_classes)
+        device = scores.device
+
+        assert (boxlist.has_field('ids'))
+        ids = boxlist.get_field('ids')
+
+        result = [self.create_empty_boxlist(device=device)
+                  for _ in range(1, num_classes)]
+
+        # Apply threshold on detection probabilities and apply NMS
+        # Skip j = 0, because it's the background class
+        inds_all = scores > self.score_thresh
+        for j in range(1, num_classes):
+            inds = inds_all[:, j].nonzero().squeeze(1)
+            scores_j = scores[inds, j]
+            boxes_j = boxes[inds, j * 4: (j + 1) * 4]
+            ids_j = ids[inds]
+
+            det_idx = ids_j < 0
+            det_boxlist = BoxList(boxes_j[det_idx, :], boxlist.size, mode="xyxy")
+            det_boxlist.add_field("scores", scores_j[det_idx])
+            det_boxlist.add_field("ids", ids_j[det_idx])
+            det_boxlist = boxlist_nms(det_boxlist, self.nms)
+
+            track_idx = ids_j >= 0
+            # track_box is available
+            if torch.any(track_idx > 0):
+                track_boxlist = BoxList(boxes_j[track_idx, :], boxlist.size, mode="xyxy")
+                track_boxlist.add_field("scores", scores_j[track_idx])
+                track_boxlist.add_field("ids", ids_j[track_idx])
+                det_boxlist = cat_boxlist([det_boxlist, track_boxlist])
+
+            num_labels = len(det_boxlist)
+            det_boxlist.add_field(
+                "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device)
+            )
+            result[j-1] = det_boxlist
+
+        result = cat_boxlist(result)
+        return result
+
+
+def make_roi_box_post_processor(cfg):
+    use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN
+
+    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
+    box_coder = BoxCoder(weights=bbox_reg_weights)
+
+    score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
+    nms_thresh = cfg.MODEL.ROI_HEADS.NMS
+    detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG
+    cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG
+    bbox_aug_enabled = cfg.TEST.BBOX_AUG.ENABLED
+
+    amodal_inference = cfg.INPUT.AMODAL
+
+    postprocessor = PostProcessor(
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        box_coder,
+        cls_agnostic_bbox_reg,
+        bbox_aug_enabled,
+        amodal_inference
+    )
+    return postprocessor
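+
+# Shape sketch for filter_results (illustrative numbers): with num_classes = 2
+# (background + one foreground class) and N boxes, `boxes` is viewed as (N, 8)
+# and `scores` as (N, 2); scores[:, j] pairs with boxes[:, 4*j : 4*(j+1)].
+# Detections (ids == -1) are score-thresholded and run through per-class NMS,
+# while track boxes (ids >= 0) bypass NMS; their class probability was boosted
+# by +1 in forward(), so the score threshold never removes them.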
diff --git a/siam-mot/siammot/modelling/rcnn.py b/siam-mot/siammot/modelling/rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8292f1b9c97efc86fc855eac8018c470c9784fe7
--- /dev/null
+++ b/siam-mot/siammot/modelling/rcnn.py
@@ -0,0 +1,71 @@
+"""
+Implements the Generalized R-CNN for SiamMOT
+"""
+from torch import nn
+
+from maskrcnn_benchmark.structures.image_list import to_image_list
+from maskrcnn_benchmark.modeling.rpn.rpn import build_rpn
+
+from .roi_heads import build_roi_heads
+from .backbone.backbone_ext import build_backbone
+
+
+class SiamMOT(nn.Module):
+    """
+    Main class for R-CNN. Currently supports boxes and tracks.
+    It consists of three main parts:
+    - backbone
+    - rpn
+    - heads: takes the features + the proposals from the RPN and
+             computes detections / tracks from it.
+    """
+
+    def __init__(self, cfg):
+        super(SiamMOT, self).__init__()
+
+        self.backbone = build_backbone(cfg)
+        self.rpn = build_rpn(cfg, self.backbone.out_channels)
+        self.roi_heads = build_roi_heads(cfg, self.backbone.out_channels)
+
+        self.track_memory = None
+
+    def flush_memory(self, cache=None):
+        self.track_memory = cache
+
+    def reset_siammot_status(self):
+        self.flush_memory()
+        self.roi_heads.reset_roi_status()
+
+    def forward(self, images, targets=None, given_detection=None):
+
+        if self.training and targets is None:
+            raise ValueError("In training mode, targets should be passed")
+
+        images = to_image_list(images)
+        features = self.backbone(images.tensors)
+        proposals, proposal_losses = self.rpn(images, features, targets)
+
+        if self.roi_heads:
+            x, result, roi_losses = self.roi_heads(features,
+                                                   proposals,
+                                                   targets,
+                                                   self.track_memory,
+                                                   given_detection)
+            if not self.training:
+                self.flush_memory(cache=x)
+
+        else:
+            raise NotImplementedError
+
+        if self.training:
+            losses = {}
+            losses.update(roi_losses)
+            losses.update(proposal_losses)
+            return result, losses
+
+        return result
+
+
+def build_siammot(cfg):
+    siammot = SiamMOT(cfg)
+    return siammot
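+
+# Illustrative inference loop (the video/frame iteration is an assumption; the
+# method names come from this file):
+#
+#   model = build_siammot(cfg).eval().to(device)
+#   for video in videos:
+#       model.reset_siammot_status()      # clear track memory between videos
+#       for frame in video:
+#           with torch.no_grad():
+#               result = model(frame)     # BoxList(s) with ids, labels, scores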
diff --git a/siam-mot/siammot/modelling/roi_heads.py b/siam-mot/siammot/modelling/roi_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f1c2d30a79ccef9ce77148d9f6ae06e3fc3ad6
--- /dev/null
+++ b/siam-mot/siammot/modelling/roi_heads.py
@@ -0,0 +1,93 @@
+import torch
+from .box_head.box_head import build_roi_box_head
+from .track_head.track_head import build_track_head
+from .track_head.track_utils import build_track_utils
+from .track_head.track_solver import builder_tracker_solver
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+
+
+class CombinedROIHeads(torch.nn.ModuleDict):
+    """
+    Combines a set of individual heads (for box prediction or masks) into a single
+    head.
+    """
+
+    def __init__(self, cfg, heads):
+        super(CombinedROIHeads, self).__init__(heads)
+        self.cfg = cfg.clone()
+
+    def forward(self, features, proposals, targets=None, track_memory=None, given_detection=None):
+        losses = {}
+
+        x, detections, loss_box = self.box(features, proposals, targets)
+        losses.update(loss_box)
+
+        if self.cfg.MODEL.TRACK_ON:
+            y, tracks, loss_track = self.track(features, proposals, targets, track_memory)
+            losses.update(loss_track)
+
+            # solver is only needed during inference
+            if not self.training:
+                if tracks is not None:
+                    tracks = self._refine_tracks(features, tracks)
+                    detections = [cat_boxlist(detections + tracks)]
+
+                detections = self.solver(detections)
+
+                # get the current state for tracking
+                x = self.track.get_track_memory(features, detections)
+
+        return x, detections, losses
+
+    def reset_roi_status(self):
+        """
+        Reset the status of ROI Heads
+        """
+        if self.cfg.MODEL.TRACK_ON:
+            self.track.reset_track_pool()
+
+    def _refine_tracks(self, features, tracks):
+        """
+        Use the box head to refine the bounding box location.
+        The final score is an average of the detection (appearance) score and the matching (track) score
+        """
+        if len(tracks[0]) == 0:
+            return tracks[0]
+        track_scores = tracks[0].get_field('scores') + 1.
+        # track_boxes = tracks[0].bbox
+        _, tracks, _ = self.box(features, tracks)
+        det_scores = tracks[0].get_field('scores')
+        det_boxes = tracks[0].bbox
+
+        if self.cfg.MODEL.TRACK_HEAD.TRACKTOR:
+            scores = det_scores
+        else:
+            scores = (det_scores + track_scores) / 2.
+        boxes = det_boxes
+
+        r_tracks = BoxList(boxes, image_size=tracks[0].size, mode=tracks[0].mode)
+        r_tracks.add_field('scores', scores)
+        r_tracks.add_field('ids', tracks[0].get_field('ids'))
+        r_tracks.add_field('labels', tracks[0].get_field('labels'))
+
+        return [r_tracks]
+
+
+def build_roi_heads(cfg, in_channels):
+    # individually create the heads, that will be combined together
+    roi_heads = []
+    if not cfg.MODEL.RPN_ONLY:
+        roi_heads.append(("box", build_roi_box_head(cfg, in_channels)))
+    if cfg.MODEL.TRACK_ON:
+        track_utils, track_pool = build_track_utils(cfg)
+        roi_heads.append(("track", build_track_head(cfg, track_utils, track_pool)))
+        # solver is a non-learnable layer that would only be used during inference
+        roi_heads.append(("solver", builder_tracker_solver(cfg, track_pool)))
+
+    # combine individual heads in a single module
+    if roi_heads:
+        roi_heads = CombinedROIHeads(cfg, roi_heads)
+
+    return roi_heads
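+
+# Inference-flow sketch for CombinedROIHeads.forward (eval mode): the box head
+# proposes new detections, the track head propagates existing tracks, both are
+# merged into a single BoxList, the non-learnable solver resolves duplicates,
+# and the surviving boxes become the track memory for the next frame via
+# track.get_track_memory(features, detections).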
diff --git a/siam-mot/siammot/modelling/track_head/EMM/feature_extractor.py b/siam-mot/siammot/modelling/track_head/EMM/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee903120d99bfba7d6e4040ea9ef7f5bc18dd06
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/feature_extractor.py
@@ -0,0 +1,69 @@
+from torch import nn
+from torch.nn import functional as F
+
+from maskrcnn_benchmark.modeling.make_layers import make_conv3x3
+
+from .sr_pool import SRPooler
+
+
+class EMMFeatureExtractor(nn.Module):
+    """
+    Feature extraction differs for the template and the search region: the
+    search region is pooled at r times the template resolution
+    """
+
+    def __init__(self, cfg):
+        super(EMMFeatureExtractor, self).__init__()
+
+        resolution = cfg.MODEL.TRACK_HEAD.POOLER_RESOLUTION
+        scales = cfg.MODEL.TRACK_HEAD.POOLER_SCALES
+        sampling_ratio = cfg.MODEL.TRACK_HEAD.POOLER_SAMPLING_RATIO
+        r = cfg.MODEL.TRACK_HEAD.SEARCH_REGION
+
+        pooler_z = SRPooler(
+            output_size=(resolution, resolution),
+            scales=scales,
+            sampling_ratio=sampling_ratio)
+        pooler_x = SRPooler(
+            output_size=(int(resolution*r), int(resolution*r)),
+            scales=scales,
+            sampling_ratio=sampling_ratio)
+
+        self.pooler_x = pooler_x
+        self.pooler_z = pooler_z
+
+    def forward(self, x, proposals, sr=None):
+        if sr is not None:
+            x = self.pooler_x(x, proposals, sr)
+        else:
+            x = self.pooler_z(x, proposals)
+
+        return x
+
+
+class EMMPredictor(nn.Module):
+    def __init__(self, cfg):
+        super(EMMPredictor, self).__init__()
+
+        if cfg.MODEL.BACKBONE.CONV_BODY.startswith("DLA"):
+            in_channels = cfg.MODEL.DLA.BACKBONE_OUT_CHANNELS
+        elif cfg.MODEL.BACKBONE.CONV_BODY.startswith("R-"):
+            in_channels = cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS
+        else:
+            in_channels = 128
+
+        self.cls_tower = make_conv3x3(in_channels=in_channels, out_channels=in_channels,
+                                      use_gn=True, use_relu=True, kaiming_init=False)
+        self.reg_tower = make_conv3x3(in_channels=in_channels, out_channels=in_channels,
+                                      use_gn=True, use_relu=True, kaiming_init=False)
+        self.cls = make_conv3x3(in_channels=in_channels, out_channels=2, kaiming_init=False)
+        self.center = make_conv3x3(in_channels=in_channels, out_channels=1, kaiming_init=False)
+        self.reg = make_conv3x3(in_channels=in_channels, out_channels=4, kaiming_init=False)
+
+    def forward(self, x):
+        cls_x = self.cls_tower(x)
+        reg_x = self.reg_tower(x)
+        cls_logits = self.cls(cls_x)
+        center_logits = self.center(cls_x)
+        reg_logits = F.relu(self.reg(reg_x))
+
+        return cls_logits, center_logits, reg_logits
\ No newline at end of file
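+# Pooling-size sketch (the numbers are assumptions, the real values come from
+# cfg.MODEL.TRACK_HEAD.*): with POOLER_RESOLUTION = 15 and SEARCH_REGION = 2.0,
+# templates are pooled to 15x15 (pooler_z) and search regions to 30x30
+# (pooler_x); their depth-wise cross-correlation gives a (30 - 15 + 1) = 16x16
+# response map, which EMMPredictor maps to 2-channel cls, 1-channel centerness
+# and 4-channel regression outputs of the same spatial size.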
diff --git a/siam-mot/siammot/modelling/track_head/EMM/sr_pool.py b/siam-mot/siammot/modelling/track_head/EMM/sr_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..c630959458088e86b6267e5e5816640813b0bb3e
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/sr_pool.py
@@ -0,0 +1,91 @@
+import torch
+from torch import nn
+
+from maskrcnn_benchmark.modeling.poolers import LevelMapper
+from maskrcnn_benchmark.modeling.utils import cat
+from maskrcnn_benchmark.layers import ROIAlign
+
+
+class SRPooler(nn.Module):
+    """
+    SRPooler for Detection with or without FPN.
+    Also, the requirement of passing the scales is not strictly necessary, as they
+    can be inferred from the size of the feature map / size of original image,
+    which is available thanks to the BoxList.
+    """
+
+    def __init__(self, output_size, scales, sampling_ratio):
+        """
+        Arguments:
+            output_size (list[tuple[int]] or list[int]): output size for the pooled region
+            scales (list[float]): scales for each Pooler
+            sampling_ratio (int): sampling ratio for ROIAlign
+        """
+        super(SRPooler, self).__init__()
+        poolers = []
+        for scale in scales:
+            poolers.append(
+                ROIAlign(
+                    output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
+                )
+            )
+        self.poolers = nn.ModuleList(poolers)
+        self.output_size = output_size
+        # get the levels in the feature map by leveraging the fact that the network always
+        # downsamples by a factor of 2 at each level.
+        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
+        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
+        self.map_levels = LevelMapper(lvl_min, lvl_max)
+
+    def convert_to_roi_format(self, boxes):
+        concat_boxes = cat([b.bbox for b in boxes], dim=0)
+        device, dtype = concat_boxes.device, concat_boxes.dtype
+        ids = cat(
+            [
+                torch.full((len(b), 1), i, dtype=dtype, device=device)
+                for i, b in enumerate(boxes)
+            ],
+            dim=0,
+        )
+        rois = torch.cat([ids, concat_boxes], dim=1)
+        return rois
+
+    def forward(self, x, boxes, sr=None):
+        """
+        Arguments:
+            x (list[Tensor]): feature maps for each level
+            boxes (list[BoxList]): boxes to be used to perform the pooling operation.
+            sr(list([BoxList])): search region boxes.
+        Returns:
+            result (Tensor)
+        """
+        num_levels = len(self.poolers)
+
+        if sr is None:
+            rois = self.convert_to_roi_format(boxes)
+        else:
+            # pool from the search-region boxes when they are provided
+            rois = self.convert_to_roi_format(sr)
+
+        if num_levels == 1:
+            return self.poolers[0](x[0], rois)
+
+        # Always use the template box to get the feature level
+        levels = self.map_levels(boxes)
+
+        num_rois = len(rois)
+        num_channels = x[0].shape[1]
+        output_size = self.output_size[0]
+
+        dtype, device = x[0].dtype, x[0].device
+        result = torch.zeros(
+            (num_rois, num_channels, output_size, output_size),
+            dtype=dtype,
+            device=device,
+        )
+        for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
+            idx_in_level = torch.nonzero(levels == level).squeeze(1)
+            rois_per_level = rois[idx_in_level]
+            result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)
+
+        return result
\ No newline at end of file
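+# ROI-format sketch: convert_to_roi_format prepends the image index to every
+# box, so two BoxLists with 3 and 2 boxes become a (5, 5) tensor whose rows are
+# [batch_idx, x1, y1, x2, y2]; note that the template boxes (not the search
+# regions) decide which feature level each ROI is pooled from.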
diff --git a/siam-mot/siammot/modelling/track_head/EMM/target_sampler.py b/siam-mot/siammot/modelling/track_head/EMM/target_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..62aef13dbad7920bc94c66d5d47e21f68117fee2
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/target_sampler.py
@@ -0,0 +1,304 @@
+import torch
+import copy
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist, boxlist_iou
+from maskrcnn_benchmark.modeling.matcher import Matcher
+
+from siammot.utils import registry
+
+
+class EMMTargetSampler(object):
+    """
+    Sample track targets for SiamMOT.
+    It samples track pairs from the RPN proposals
+    """
+    def __init__(self, track_utils, matcher, propsals_per_image=256,
+                 pos_ratio=0.25, hn_ratio=0.25):
+        self.track_utils = track_utils
+        self.proposal_iou_matcher = matcher
+        self.proposals_per_image = propsals_per_image
+        self.hn_ratio = hn_ratio
+        self.pos_ratio = pos_ratio
+
+    def match_targets_with_iou(self, proposal: BoxList, gt: BoxList):
+        match_quality_matrix = boxlist_iou(gt, proposal)
+        matched_idxs = self.proposal_iou_matcher(match_quality_matrix)
+
+        target = gt.copy_with_fields(("ids", "labels"))
+        matched_target = target[torch.clamp_min(matched_idxs, -1)]
+        proposal_ids = matched_target.get_field('ids')
+        proposal_labels = matched_target.get_field('labels')
+
+        # id = -1 for background
+        # id = -2 for ignore proposals
+        proposal_ids[matched_idxs == -1] = -1
+        proposal_ids[matched_idxs == -2] = -2
+        proposal_labels[matched_idxs < 0] = 0
+
+        return proposal_ids.type(torch.float), proposal_labels.type(torch.float)
+
+    def assign_matched_ids_to_proposals(self, proposals: BoxList, gts: BoxList):
+        """
+        Assign each proposal the id of the gt box it matches;
+        proposals without a match are assigned -1
+        """
+        for proposal, gt in zip(proposals, gts):
+            proposal_ids, proposal_labels = self.match_targets_with_iou(proposal, gt)
+            proposal.add_field('ids', proposal_ids)
+            proposal.add_field('labels', proposal_labels)
+
+    def duplicate_boxlist(self, boxlist, num_duplicates):
+        """
+        Duplicate samples in box list by concatenating multiple times.
+        """
+        if num_duplicates == 0:
+            return self.get_dummy_boxlist(boxlist)
+        list_to_join = []
+        for _ in range(num_duplicates):
+            dup = boxlist.copy_with_fields(list(boxlist.extra_fields.keys()))
+            list_to_join.append(dup)
+
+        return cat_boxlist(list_to_join)
+
+    def get_dummy_boxlist(self, boxlist:BoxList, num_boxes=0):
+        """
+        Create a dummy boxlist with bbox [-1, -1, -1, -1],
+        id -1 and label -1;
+        when num_boxes = 0, an empty BoxList is returned
+        """
+        boxes = torch.zeros((num_boxes, 4)) - 1.
+        dummy_boxlist = self.get_default_boxlist(boxlist, boxes)
+
+        return dummy_boxlist
+
+    @staticmethod
+    def get_default_boxlist(boxlist:BoxList, bboxes, ids=None, labels=None):
+        """
+        Construct a boxlist whose bbox field is bboxes and whose
+        remaining fields default to id -1 and label -1
+        """
+        device = boxlist.bbox.device
+        num_boxes = bboxes.shape[0]
+        if ids is None:
+            ids = torch.zeros((num_boxes,)) - 1.
+        if labels is None:
+            labels = torch.zeros((num_boxes,)) - 1.
+
+        default_boxlist = BoxList(bboxes, image_size=boxlist.size, mode='xyxy')
+        default_boxlist.add_field('labels', labels)
+        default_boxlist.add_field('ids', ids)
+
+        return default_boxlist.to(device)
+
+    @staticmethod
+    def sample_examples(src_box: [BoxList], pair_box: [BoxList],
+                        tar_box: [BoxList], num_samples):
+        """
+        Sample up to num_samples aligned (source, pair, target) triplets
+        """
+        src_box = cat_boxlist(src_box)
+        pair_box = cat_boxlist(pair_box)
+        tar_box = cat_boxlist(tar_box)
+
+        assert (len(src_box) == len(pair_box) and len(src_box) == len(tar_box))
+
+        if len(src_box) <= num_samples:
+            return [src_box, pair_box, tar_box]
+        else:
+            indices = torch.zeros((len(src_box), ), dtype=torch.bool)
+            permuted_idxs = torch.randperm(len(src_box))
+            sampled_idxs = permuted_idxs[: num_samples]
+            indices[sampled_idxs] = 1
+
+            sampled_src_box = src_box[indices]
+            sampled_pair_box = pair_box[indices]
+            sampled_tar_box = tar_box[indices]
+            return [sampled_src_box, sampled_pair_box, sampled_tar_box]
+
+    def sample_boxlist(self, boxlist: BoxList, indices, num_samples):
+        assert (num_samples <= indices.numel())
+
+        if num_samples == 0:
+            sampled_boxlist = self.get_dummy_boxlist(boxlist, num_boxes=0)
+        else:
+            permuted_idxs = torch.randperm(indices.numel())
+            sampled_idxs = indices[permuted_idxs[: num_samples], 0]
+            sampled_boxes = boxlist.bbox[sampled_idxs, :]
+            sampled_ids = None
+            sampled_labels = None
+            if 'ids' in boxlist.fields():
+                sampled_ids = boxlist.get_field('ids')[sampled_idxs]
+            if 'labels' in boxlist.fields():
+                sampled_labels = boxlist.get_field('labels')[sampled_idxs]
+
+            sampled_boxlist = self.get_default_boxlist(boxlist, sampled_boxes,
+                                                       sampled_ids, sampled_labels)
+        return sampled_boxlist
+
+    def get_target_box(self, target_gt, indices):
+        """
+        Get the corresponding target box selected by the boolean indices;
+        if there is no matching target box, a dummy box is returned
+        """
+        tar_box = target_gt[indices]
+        # the assertion makes sure that different boxes have different ids
+        assert (len(tar_box) <= 1)
+        if len(tar_box) == 0:
+            # dummy bounding boxes
+            tar_box = self.get_dummy_boxlist(target_gt, num_boxes=1)
+
+        return tar_box
+
+    def generate_hn_pair(self, src_gt: BoxList, proposal: BoxList,
+                         src_h=None, proposal_h=None):
+        """
+        Generate hard negative pair by sampling non-negative proposals
+        """
+        proposal_ids = proposal.get_field('ids')
+        src_id = src_gt.get_field('ids')
+
+        scales = torch.ones_like(proposal_ids)
+        if (src_h is not None) and (proposal_h is not None):
+            scales = src_h / proposal_h
+
+        # sample proposals with similar scales
+        # and non-negative proposals
+        hard_bb_idxs = ((proposal_ids >= 0) & (proposal_ids != src_id))
+        scale_idxs = (scales >= 0.5) & (scales <= 2)
+        indices = (hard_bb_idxs & scale_idxs)
+        unique_ids = torch.unique(proposal_ids[indices])
+        idxs = indices.nonzero()
+
+        # avoid sampling redundant samples
+        num_hn = min(idxs.numel(), unique_ids.numel())
+        sampled_hn_boxes = self.sample_boxlist(proposal, idxs, num_hn)
+
+        return sampled_hn_boxes
+
+    def generate_pos(self, src_gt: BoxList, proposal: BoxList):
+        assert (src_gt.mode == 'xyxy' and len(src_gt) == 1)
+        proposal_ids = proposal.get_field('ids')
+        src_id = src_gt.get_field('ids')
+
+        pos_indices = (proposal_ids == src_id)
+        pos_boxes = proposal[pos_indices]
+        pos_boxes = pos_boxes.copy_with_fields(('ids', 'labels'))
+
+        return pos_boxes
+
+    def generate_pos_hn_example(self, proposals, gts):
+        """
+        Generate positive and hard negative training examples
+        """
+        src_gts = copy.deepcopy(gts)
+        tar_gts = self.track_utils.swap_pairs(copy.deepcopy(gts))
+
+        track_source = []
+        track_target = []
+        track_pair = []
+        for src_gt, tar_gt, proposal in zip(src_gts, tar_gts, proposals):
+            pos_src_boxlist, pos_pair_boxlist, pos_tar_boxlist = ([] for _ in range(3))
+            hn_src_boxlist, hn_pair_boxlist, hn_tar_boxlist = ([] for _ in range(3))
+
+            proposal_h = proposal.bbox[:, 3] - proposal.bbox[:, 1]
+            src_h = src_gt.bbox[:, 3] - src_gt.bbox[:, 1]
+            src_ids = src_gt.get_field('ids')
+            tar_ids = tar_gt.get_field('ids')
+
+            for i, src_id in enumerate(src_ids):
+                _src_box = src_gt[src_ids == src_id]
+                _tar_box = self.get_target_box(tar_gt, tar_ids == src_id)
+
+                pos_src_boxes = self.generate_pos(_src_box, proposal)
+                pos_pair_boxes = copy.deepcopy(pos_src_boxes)
+                pos_tar_boxes = self.duplicate_boxlist(_tar_box, len(pos_src_boxes))
+
+                hn_pair_boxes = self.generate_hn_pair(_src_box, proposal, src_h[i], proposal_h)
+                hn_src_boxes = self.duplicate_boxlist(_src_box, len(hn_pair_boxes))
+                hn_tar_boxes = self.duplicate_boxlist(_tar_box, len(hn_pair_boxes))
+
+                pos_src_boxlist.append(pos_src_boxes)
+                pos_pair_boxlist.append(pos_pair_boxes)
+                pos_tar_boxlist.append(pos_tar_boxes)
+
+                hn_src_boxlist.append(hn_src_boxes)
+                hn_pair_boxlist.append(hn_pair_boxes)
+                hn_tar_boxlist.append(hn_tar_boxes)
+
+            num_pos = int(self.proposals_per_image * self.pos_ratio)
+            num_hn = int(self.proposals_per_image * self.hn_ratio)
+            sampled_pos = self.sample_examples(pos_src_boxlist, pos_pair_boxlist,
+                                               pos_tar_boxlist, num_pos)
+            sampled_hn = self.sample_examples(hn_src_boxlist, hn_pair_boxlist,
+                                              hn_tar_boxlist, num_hn)
+            track_source.append(cat_boxlist([sampled_pos[0], sampled_hn[0]]))
+            track_pair.append(cat_boxlist([sampled_pos[1], sampled_hn[1]]))
+            track_target.append(cat_boxlist([sampled_pos[2], sampled_hn[2]]))
+
+        return track_source, track_pair, track_target
+
+    def generate_neg_examples(self, proposals: [BoxList], gts: [BoxList], pos_hn_boxes: [BoxList]):
+        """
+        Generate negative training examples
+        """
+        track_source = []
+        track_pair = []
+        track_target = []
+        for proposal, gt, pos_hn_box in zip(proposals, gts, pos_hn_boxes):
+            proposal_ids = proposal.get_field('ids')
+            objectness = proposal.get_field('objectness')
+
+            proposal_h = proposal.bbox[:, 3] - proposal.bbox[:, 1]
+            proposal_w = proposal.bbox[:, 2] - proposal.bbox[:, 0]
+
+            neg_indices = ((proposal_ids == -1) & (objectness >= 0.3) &
+                           (proposal_h >= 5) & (proposal_w >= 5))
+            idxs = neg_indices.nonzero()
+
+            neg_samples = min(idxs.numel(), self.proposals_per_image - len(pos_hn_box))
+            neg_samples = max(0, neg_samples)
+
+            sampled_neg_boxes = self.sample_boxlist(proposal, idxs, neg_samples)
+            # for target box
+            sampled_tar_boxes = self.get_dummy_boxlist(proposal, neg_samples)
+
+            track_source.append(sampled_neg_boxes)
+            track_pair.append(sampled_neg_boxes)
+            track_target.append(sampled_tar_boxes)
+        return track_source, track_pair, track_target
+
+    def __call__(self, proposals: [BoxList], gts: [BoxList]):
+
+        self.assign_matched_ids_to_proposals(proposals, gts)
+
+        pos_hn_src, pos_hn_pair, pos_hn_tar = self.generate_pos_hn_example(proposals, gts)
+        neg_src, neg_pair, neg_tar = self.generate_neg_examples(proposals, gts, pos_hn_src)
+
+        track_source = [cat_boxlist([pos_hn, neg]) for (pos_hn, neg) in zip(pos_hn_src, neg_src)]
+        track_pair = [cat_boxlist([pos_hn, neg]) for (pos_hn, neg) in zip(pos_hn_pair, neg_pair)]
+        track_target = [cat_boxlist([pos_hn, neg]) for (pos_hn, neg) in zip(pos_hn_tar, neg_tar)]
+
+        sr = self.track_utils.update_boxes_in_pad_images(track_pair)
+        sr = self.track_utils.extend_bbox(sr)
+
+        return track_source, sr, track_target
+
+
+@registry.TRACKER_SAMPLER.register("EMM")
+def make_emm_target_sampler(cfg,
+                            track_utils
+                            ):
+    matcher = Matcher(
+        cfg.MODEL.TRACK_HEAD.FG_IOU_THRESHOLD,
+        cfg.MODEL.TRACK_HEAD.BG_IOU_THRESHOLD,
+        allow_low_quality_matches=False,
+    )
+
+    track_sampler = EMMTargetSampler(track_utils, matcher,
+                                     propsals_per_image=cfg.MODEL.TRACK_HEAD.PROPOSAL_PER_IMAGE,
+                                     pos_ratio=cfg.MODEL.TRACK_HEAD.EMM.POS_RATIO,
+                                     hn_ratio=cfg.MODEL.TRACK_HEAD.EMM.HN_RATIO,
+                                     )
+    return track_sampler
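+
+# Budget-arithmetic sketch (assuming these config values): with
+# PROPOSAL_PER_IMAGE = 256, POS_RATIO = 0.25 and HN_RATIO = 0.25 each image
+# contributes at most 64 positive pairs and 64 hard-negative pairs; the rest of
+# the 256-pair budget is filled with negative proposals (ids == -1,
+# objectness >= 0.3, at least 5 px high and wide).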
diff --git a/siam-mot/siammot/modelling/track_head/EMM/track_core.py b/siam-mot/siammot/modelling/track_head/EMM/track_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d7d0a9a77baae08b624aad33261ebb94e476f0b
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/track_core.py
@@ -0,0 +1,225 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.modeling.utils import cat
+
+from siammot.utils import registry
+from .xcorr import xcorr_depthwise
+from .feature_extractor import EMMFeatureExtractor, EMMPredictor
+from .track_loss import EMMLossComputation
+
+
+@registry.SIAMESE_TRACKER.register("EMM")
+class EMM(torch.nn.Module):
+    def __init__(self, cfg, track_utils):
+        super(EMM, self).__init__()
+        self.feature_extractor = EMMFeatureExtractor(cfg)
+        self.predictor = EMMPredictor(cfg)
+        self.loss = EMMLossComputation(cfg)
+
+        self.track_utils = track_utils
+        self.amodal = cfg.INPUT.AMODAL
+        self.use_centerness = cfg.MODEL.TRACK_HEAD.EMM.USE_CENTERNESS
+        self.pad_pixels = cfg.MODEL.TRACK_HEAD.PAD_PIXELS
+        self.sigma = cfg.MODEL.TRACK_HEAD.EMM.COSINE_WINDOW_WEIGHT
+
+    def forward(self, features, boxes, sr, targets=None, template_features=None):
+        """
+        forward function of the tracker
+        :param features: raw FPN feature maps from feature backbone
+        :param boxes: template bounding boxes
+        :param sr: search region bounding boxes
+        :param targets: ground-truth target boxes for the tracked instances (training only)
+        :param template_features: features of the template bounding boxes
+
+        the number of track boxes should be the same as that of
+        search region and template_features
+        """
+
+        # x, y shifting due to feature padding
+        shift_x = self.pad_pixels
+        shift_y = self.pad_pixels
+
+        if self.training:
+            template_features = self.feature_extractor(features, boxes)
+            features = self.track_utils.shuffle_feature(features)
+
+        features = self.track_utils.pad_feature(features)
+
+        sr_features = self.feature_extractor(features, boxes, sr)
+
+        response_map = xcorr_depthwise(sr_features, template_features)
+        cls_logits, center_logits, reg_logits = self.predictor(response_map)
+
+        if self.training:
+            locations = get_locations(sr_features, template_features, sr, shift_xy=(shift_x, shift_y))
+            src_bboxes = cat([b.bbox for b in boxes], dim=0)
+            gt_bboxes = cat([b.bbox for b in targets], dim=0)
+            cls_loss, reg_loss, centerness_loss = self.loss(
+                locations, cls_logits, reg_logits, center_logits, src_bboxes, gt_bboxes)
+
+            loss = dict(loss_tracker_class=cls_loss,
+                        loss_tracker_motion=reg_loss,
+                        loss_tracker_center=centerness_loss)
+
+            return {}, {}, loss
+        else:
+            cls_logits = F.interpolate(cls_logits, scale_factor=16, mode='bicubic')
+            center_logits = F.interpolate(center_logits, scale_factor=16, mode='bicubic')
+            reg_logits = F.interpolate(reg_logits, scale_factor=16, mode='bicubic')
+
+            locations = get_locations(sr_features, template_features, sr, shift_xy=(shift_x, shift_y), up_scale=16)
+
+            assert len(boxes) == 1
+            bb, bb_conf = decode_response(cls_logits, center_logits, reg_logits, locations, boxes[0],
+                                          use_centerness=self.use_centerness, sigma=self.sigma)
+            track_result = wrap_results_to_boxlist(bb, bb_conf, boxes, amodal=self.amodal)
+            return {}, track_result, {}
+
+    def extract_cache(self, features, detection):
+        """
+        Get the cache (state) that is necessary for tracking
+        output: (features for tracking targets,
+                 search region,
+                 detection bounding boxes)
+        """
+
+        # get cache features for search region
+        # FPN features
+        detection = [detection]
+        x = self.feature_extractor(features, detection)
+
+        sr = self.track_utils.update_boxes_in_pad_images(detection)
+        sr = self.track_utils.extend_bbox(sr)
+
+        cache = (x, sr, detection)
+        return cache
+
+
+def decode_response(cls_logits, center_logits, reg_logits, locations, boxes,
+                    use_centerness=True, sigma=0.4):
+    cls_logits = F.softmax(cls_logits, dim=1)
+
+    cls_logits = cls_logits[:, 1:2, :, :]
+    if use_centerness:
+        centerness = torch.sigmoid(center_logits)
+        obj_confidence = cls_logits * centerness
+    else:
+        obj_confidence = cls_logits
+
+    num_track_objects = obj_confidence.shape[0]
+    obj_confidence = obj_confidence.reshape((num_track_objects, -1))
+    tlbr = reg_logits.reshape((num_track_objects, 4, -1))
+
+    scale_penalty = get_scale_penalty(tlbr, boxes)
+    cos_window = get_cosine_window_penalty(tlbr)
+    p_obj_confidence = (obj_confidence * scale_penalty)*(1-sigma) + sigma*cos_window
+
+    idxs = torch.argmax(p_obj_confidence, dim=1)
+
+    target_ids = torch.arange(num_track_objects)
+    bb_c = locations[target_ids, idxs, :]
+    shift_tlbr = tlbr[target_ids, :, idxs]
+
+    bb_tl_x = bb_c[:, 0:1] - shift_tlbr[:, 0:1]
+    bb_tl_y = bb_c[:, 1:2] - shift_tlbr[:, 1:2]
+    bb_br_x = bb_c[:, 0:1] + shift_tlbr[:, 2:3]
+    bb_br_y = bb_c[:, 1:2] + shift_tlbr[:, 3:4]
+    bb = torch.cat((bb_tl_x, bb_tl_y, bb_br_x, bb_br_y), dim=1)
+
+    cls_logits = cls_logits.reshape((num_track_objects, -1))
+    bb_conf = cls_logits[target_ids, idxs]
+
+    return bb, bb_conf
+
+
+def get_scale_penalty(tlbr: torch.Tensor, boxes: BoxList):
+    box_w = boxes.bbox[:, 2] - boxes.bbox[:, 0]
+    box_h = boxes.bbox[:, 3] - boxes.bbox[:, 1]
+
+    r_w = tlbr[:, 2] + tlbr[:, 0]
+    r_h = tlbr[:, 3] + tlbr[:, 1]
+
+    scale_w = r_w / box_w[:, None]
+    scale_h = r_h / box_h[:, None]
+    scale_w = torch.max(scale_w, 1 / scale_w)
+    scale_h = torch.max(scale_h, 1 / scale_h)
+
+    scale_penalty = torch.exp((-scale_w * scale_h + 1) * 0.1)
+
+    return scale_penalty
+
+
+def get_cosine_window_penalty(tlbr: torch.Tensor):
+    num_boxes, _, num_elements = tlbr.shape
+    h_w = int(np.sqrt(num_elements))
+    hanning = torch.hann_window(h_w, dtype=torch.float, device=tlbr.device)
+    window = torch.ger(hanning, hanning)
+    window = window.reshape(-1)
+
+    return window[None, :]
+
+
+def wrap_results_to_boxlist(bb, bb_conf, boxes: [BoxList], amodal=False):
+    num_boxes_per_image = [len(box) for box in boxes]
+    bb = bb.split(num_boxes_per_image, dim=0)
+    bb_conf = bb_conf.split(num_boxes_per_image, dim=0)
+
+    track_boxes = []
+    for _bb, _bb_conf, _boxes in zip(bb, bb_conf, boxes):
+        _bb = _bb.reshape(-1, 4)
+        track_box = BoxList(_bb, _boxes.size, mode="xyxy")
+        track_box.add_field("ids", _boxes.get_field('ids'))
+        track_box.add_field("labels", _boxes.get_field('labels'))
+        track_box.add_field("scores", _bb_conf)
+        if not amodal:
+            track_box.clip_to_image(remove_empty=True)
+        track_boxes.append(track_box)
+
+    return track_boxes
+
+
+def get_locations(fmap: torch.Tensor, template_fmap: torch.Tensor,
+                  sr_boxes: [BoxList], shift_xy, up_scale=1):
+    """
+    Compute the image-space (x, y) coordinate of every response-map location
+    inside each search region, cropping away the template border and shifting
+    the coordinates back to the original (un-padded) image space.
+    """
+    h, w = fmap.size()[-2:]
+    h, w = h*up_scale, w*up_scale
+    concat_boxes = cat([b.bbox for b in sr_boxes], dim=0)
+    box_w = concat_boxes[:, 2] - concat_boxes[:, 0]
+    box_h = concat_boxes[:, 3] - concat_boxes[:, 1]
+    stride_h = box_h / (h - 1)
+    stride_w = box_w / (w - 1)
+
+    device = concat_boxes.device
+    delta_x = torch.arange(0, w, dtype=torch.float32, device=device)
+    delta_y = torch.arange(0, h, dtype=torch.float32, device=device)
+
+    delta_x = (concat_boxes[:, 0])[:, None] + delta_x[None, :] * stride_w[:, None]
+    delta_y = (concat_boxes[:, 1])[:, None] + delta_y[None, :] * stride_h[:, None]
+
+    h0, w0 = template_fmap.shape[-2:]
+    assert (h0 == w0)
+    border = int(np.floor(h0 / 2))
+    st_end_idx = int(border * up_scale)
+    delta_x = delta_x[:, st_end_idx:-st_end_idx]
+    delta_y = delta_y[:, st_end_idx:-st_end_idx]
+
+    locations = []
+    num_boxes = delta_x.shape[0]
+    for i in range(num_boxes):
+        _y, _x = torch.meshgrid((delta_y[i, :], delta_x[i, :]))
+        _y = _y.reshape(-1)
+        _x = _x.reshape(-1)
+        _xy = torch.stack((_x, _y), dim=1)
+        locations.append(_xy)
+    locations = torch.stack(locations)
+
+    # shift the coordinates w.r.t the original image space (before padding)
+    locations[:, :, 0] -= shift_xy[0]
+    locations[:, :, 1] -= shift_xy[1]
+
+    return locations
\ No newline at end of file
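+# Decoding sketch: decode_response picks, per tracked object, the location with
+# the highest penalised confidence and converts the (l, t, r, b) offsets
+# predicted at that location back to image coordinates, i.e. for a peak at
+# (x, y):
+#
+#   x1, y1, x2, y2 = x - l, y - t, x + r, y + b
+#
+# The score attached to the track is the raw class probability at the peak,
+# before the scale and cosine-window penalties are applied.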
diff --git a/siam-mot/siammot/modelling/track_head/EMM/track_loss.py b/siam-mot/siammot/modelling/track_head/EMM/track_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4529f7dabe9c0406575486841fd912c31e59ee6
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/track_loss.py
@@ -0,0 +1,160 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def get_cls_loss(pred, label, select):
+    if len(select.size()) == 0 or \
+            select.size() == torch.Size([0]):
+        return 0
+    pred = torch.index_select(pred, 0, select)
+    label = torch.index_select(label, 0, select)
+    return F.nll_loss(pred, label)
+
+
+def select_cross_entropy_loss(pred, label):
+    pred = pred.view(-1, 2)
+    label = label.view(-1)
+    pos = label.data.eq(1).nonzero().squeeze().cuda()
+    neg = label.data.eq(0).nonzero().squeeze().cuda()
+    loss_pos = get_cls_loss(pred, label, pos)
+    loss_neg = get_cls_loss(pred, label, neg)
+    return loss_pos * 0.5 + loss_neg * 0.5
+
+
+def log_softmax(cls_logits):
+    b, a2, h, w = cls_logits.size()
+    cls_logits = cls_logits.view(b, 2, a2 // 2, h, w)
+    cls_logits = cls_logits.permute(0, 2, 3, 4, 1).contiguous()
+    cls_logits = F.log_softmax(cls_logits, dim=4)
+    return cls_logits
+
+
+class IOULoss(nn.Module):
+    def forward(self, pred, target, weight=None):
+        pred_l = pred[:, 0]
+        pred_t = pred[:, 1]
+        pred_r = pred[:, 2]
+        pred_b = pred[:, 3]
+
+        target_l = target[:, 0]
+        target_t = target[:, 1]
+        target_r = target[:, 2]
+        target_b = target[:, 3]
+
+        target_area = (target_l + target_r) * (target_t + target_b)
+        pred_area = (pred_l + pred_r) * (pred_t + pred_b)
+
+        w_intersect = torch.min(pred_l, target_l) + torch.min(pred_r, target_r)
+        h_intersect = torch.min(pred_b, target_b) + torch.min(pred_t, target_t)
+
+        area_intersect = w_intersect * h_intersect
+        area_union = target_area + pred_area - area_intersect
+
+        losses = -torch.log((area_intersect + 1.) / (area_union + 1.))
+
+        if weight is not None and weight.sum() > 0:
+            return (losses * weight).sum() / weight.sum()
+        else:
+            return losses.mean()
+
+
+class EMMLossComputation(object):
+    def __init__(self, cfg):
+        self.box_reg_loss_func = IOULoss()
+        self.centerness_loss_func = nn.BCEWithLogitsLoss()
+        self.cfg = cfg
+        self.pos_ratio = cfg.MODEL.TRACK_HEAD.EMM.CLS_POS_REGION
+        self.loss_weight = cfg.MODEL.TRACK_HEAD.EMM.TRACK_LOSS_WEIGHT
+
+    def prepare_targets(self, points, src_bbox, gt_bbox):
+
+        cls_labels, reg_targets = self.compute_targets(points, src_bbox, gt_bbox)
+
+        return cls_labels, reg_targets
+
+    def compute_targets(self, locations, src_bbox, tar_bbox):
+        xs, ys = locations[:, :, 0], locations[:, :, 1]
+
+        num_boxes, num_locations, _ = locations.shape
+        cls_labels = torch.zeros((num_boxes, num_locations),
+                                 dtype=torch.int64, device=locations.device)
+
+        _l = xs - tar_bbox[:, 0:1].float()
+        _t = ys - tar_bbox[:, 1:2].float()
+        _r = tar_bbox[:, 2:3].float() - xs
+        _b = tar_bbox[:, 3:4].float() - ys
+
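+        # a location is labelled positive only if it falls inside the central
+        # region of the target box, whose extent is controlled by self.pos_ratio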
+        s1 = _l > self.pos_ratio * ((tar_bbox[:, 2:3] - tar_bbox[:, 0:1]) / 2).float()
+        s2 = _r > self.pos_ratio * ((tar_bbox[:, 2:3] - tar_bbox[:, 0:1]) / 2).float()
+        s3 = _t > self.pos_ratio * ((tar_bbox[:, 3:4] - tar_bbox[:, 1:2]) / 2).float()
+        s4 = _b > self.pos_ratio * ((tar_bbox[:, 3:4] - tar_bbox[:, 1:2]) / 2).float()
+
+        is_in_pos_boxes = s1 * s2 * s3 * s4
+        cls_labels[is_in_pos_boxes == 1] = 1
+
+        reg_targets = torch.stack([_l, _t, _r, _b], dim=2)
+
+        return cls_labels.contiguous(), reg_targets.contiguous()
+
+    @staticmethod
+    def compute_centerness_targets(reg_targets):
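+        # FCOS-style centerness: sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))),
+        # equal to 1 at the box centre and decaying towards the borders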
+        left_right = reg_targets[:, [0, 2]]
+        top_bottom = reg_targets[:, [1, 3]]
+        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
+                     (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+        return torch.sqrt(centerness)
+
+    @staticmethod
+    def normalize_regression_outputs(src_bbox, regression_outputs):
+        # express the regression outputs in units of half the source box size,
+        # scaled by a constant factor of 128
+        half_src_box_w = (src_bbox[:, 2:3] - src_bbox[:, 0:1]) / 2. + 1e-10
+        half_src_box_h = (src_bbox[:, 3:4] - src_bbox[:, 1:2]) / 2. + 1e-10
+        assert (all(half_src_box_w > 0))
+        assert (all(half_src_box_h > 0))
+
+        regression_outputs[:, :, 0] = (regression_outputs[:, :, 0] / half_src_box_w) * 128
+        regression_outputs[:, :, 1] = (regression_outputs[:, :, 1] / half_src_box_h) * 128
+        regression_outputs[:, :, 2] = (regression_outputs[:, :, 2] / half_src_box_w) * 128
+        regression_outputs[:, :, 3] = (regression_outputs[:, :, 3] / half_src_box_h) * 128
+
+        return regression_outputs
+
+    def __call__(self, locations, box_cls, box_regression, centerness, src, targets):
+        """
+        """
+
+        cls_labels, reg_targets = self.prepare_targets(locations, src, targets)
+
+        box_regression = box_regression.permute(0, 2, 3, 1).contiguous()
+        box_regression_flatten = box_regression.view(-1, 4)
+        reg_targets_flatten = reg_targets.view(-1, 4)
+        cls_labels_flatten = cls_labels.view(-1)
+        centerness_flatten = centerness.view(-1)
+
+        in_box_inds = torch.nonzero(cls_labels_flatten > 0).squeeze(1)
+        box_regression_flatten = box_regression_flatten[in_box_inds]
+        reg_targets_flatten = reg_targets_flatten[in_box_inds]
+        centerness_flatten = centerness_flatten[in_box_inds]
+
+        box_cls = log_softmax(box_cls)
+        cls_loss = select_cross_entropy_loss(box_cls, cls_labels_flatten)
+
+        if in_box_inds.numel() > 0:
+            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
+            reg_loss = self.box_reg_loss_func(
+                box_regression_flatten,
+                reg_targets_flatten,
+                centerness_targets
+            )
+            centerness_loss = self.centerness_loss_func(
+                centerness_flatten,
+                centerness_targets
+            )
+        else:
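+            # no positive locations: multiply by zero so the losses stay attached
+            # to the computation graph (and the right device) without contributing gradients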
+            reg_loss = 0. * box_regression_flatten.sum()
+            centerness_loss = 0. * centerness_flatten.sum()
+
+        return self.loss_weight*cls_loss, self.loss_weight*reg_loss, self.loss_weight*centerness_loss
+
+
diff --git a/siam-mot/siammot/modelling/track_head/EMM/xcorr.py b/siam-mot/siammot/modelling/track_head/EMM/xcorr.py
new file mode 100644
index 0000000000000000000000000000000000000000..088c98732be8a70c09a3611f36a22cfe868300f4
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/EMM/xcorr.py
@@ -0,0 +1,46 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import torch
+import torch.nn.functional as F
+
+
+def xcorr_slow(x, kernel):
+    """for loop to calculate cross correlation, slow version
+    """
+    batch = x.size()[0]
+    out = []
+    for i in range(batch):
+        px = x[i]
+        pk = kernel[i]
+        px = px.view(1, -1, px.size()[1], px.size()[2])
+        pk = pk.view(1, -1, pk.size()[1], pk.size()[2])
+        po = F.conv2d(px, pk)
+        out.append(po)
+    out = torch.cat(out, 0)
+    return out
+
+
+def xcorr_fast(x, kernel):
+    """group conv2d to calculate cross correlation, fast version
+    """
+    batch = kernel.size()[0]
+    pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3])
+    px = x.view(1, -1, x.size()[2], x.size()[3])
+    po = F.conv2d(px, pk, groups=batch)
+    po = po.view(batch, -1, po.size()[2], po.size()[3])
+    return po
+
+
+def xcorr_depthwise(x, kernel):
+    """depthwise cross correlation
+    """
+    batch = kernel.size(0)
+    channel = kernel.size(1)
+    x = x.view(1, batch*channel, x.size(2), x.size(3))
+    kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))
+    out = F.conv2d(x, kernel, groups=batch*channel)
+    out = out.view(batch, channel, out.size(2), out.size(3))
+    return out
\ No newline at end of file
diff --git a/siam-mot/siammot/modelling/track_head/track_head.py b/siam-mot/siammot/modelling/track_head/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6585287913004a1d054bdf487b8b2dcd1f1e63c0
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/track_head.py
@@ -0,0 +1,126 @@
+import torch
+
+from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
+
+from siammot.utils import registry
+
+
+class TrackHead(torch.nn.Module):
+    def __init__(self, tracker, tracker_sampler, track_utils, track_pool):
+        super(TrackHead, self).__init__()
+
+        self.tracker = tracker
+        self.sampler = tracker_sampler
+
+        self.track_utils = track_utils
+        self.track_pool = track_pool
+
+    def forward(self, features, proposals=None, targets=None, track_memory=None):
+        if self.training:
+            return self.forward_train(features, proposals, targets)
+        else:
+            return self.forward_inference(features, track_memory)
+
+    def forward_train(self, features, proposals=None, targets=None):
+        """
+        Perform correlation on feature maps and regress the location of the object in the other frame
+        :param features: a list of feature maps from different intermediary layers of feature backbone
+        :param proposals:
+        :param targets:
+        """
+
+        with torch.no_grad():
+            track_proposals, sr, track_targets = self.sampler(proposals, targets)
+
+        return self.tracker(features, track_proposals, sr=sr, targets=track_targets)
+
+    def forward_inference(self, features, track_memory=None):
+        track_boxes = None
+        if track_memory is None:
+            self.track_pool.reset()
+        else:
+            (template_features, sr, template_boxes) = track_memory
+            if template_features.numel() > 0:
+                return self.tracker(features, template_boxes, sr=sr,
+                                    template_features=template_features)
+        return {}, track_boxes, {}
+
+    def reset_track_pool(self):
+        """
+        Reset the track pool
+        """
+        self.track_pool.reset()
+
+    def get_track_memory(self, features, tracks):
+        assert (len(tracks) == 1)
+        active_tracks = self._get_track_targets(tracks[0])
+
+        # no need to extract search-region features if the tracker
+        # is tracktor, or if there are no trackable instances
+        if len(active_tracks) == 0:
+            import copy
+            template_features = torch.tensor([], device=features[0].device)
+            sr = copy.deepcopy(active_tracks)
+            sr.size = [active_tracks.size[0] + self.track_utils.pad_pixels * 2,
+                       active_tracks.size[1] + self.track_utils.pad_pixels * 2]
+            track_memory = (template_features, [sr], [active_tracks])
+
+        else:
+            track_memory = self.tracker.extract_cache(features, active_tracks)
+
+        track_memory = self._update_memory_with_dormant_track(track_memory)
+
+        self.track_pool.update_cache(track_memory)
+
+        return track_memory
+
+    def _update_memory_with_dormant_track(self, track_memory):
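+        """
+        Append the cached template features / search regions / boxes of dormant
+        tracks to the current track memory so that dormant tracks can be re-associated.
+        """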
+        cache = self.track_pool.get_cache()
+        if not cache or track_memory is None:
+            return track_memory
+
+        dormant_caches = []
+        for dormant_id in self.track_pool.get_dormant_ids():
+            if dormant_id in cache:
+                dormant_caches.append(cache[dormant_id])
+        cached_features = [x[0][None, ...] for x in dormant_caches]
+        if track_memory[0] is None:
+            if track_memory[1][0] or track_memory[2][0]:
+                raise Exception("Unexpected cache state")
+            track_memory = [[]] * 3
+            buffer_feat = []
+        else:
+            buffer_feat = [track_memory[0]]
+        features = torch.cat(buffer_feat + cached_features)
+        sr = cat_boxlist(track_memory[1] + [x[1] for x in dormant_caches])
+        boxes = cat_boxlist(track_memory[2] + [x[2] for x in dormant_caches])
+        return features, [sr], [boxes]
+
+    def _get_track_targets(self, target):
+        if len(target) == 0:
+            return target
+        active_ids = self.track_pool.get_active_ids()
+
+        ids = target.get_field('ids').tolist()
+        idxs = torch.zeros((len(ids), ), dtype=torch.bool, device=target.bbox.device)
+        for _i, _id in enumerate(ids):
+            if _id in active_ids:
+                idxs[_i] = True
+
+        return target[idxs]
+
+
+def build_track_head(cfg, track_utils, track_pool):
+
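+    # importing these modules registers the EMM tracker and its target sampler
+    # in the SIAMESE_TRACKER / TRACKER_SAMPLER registries used below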
+    import siammot.modelling.track_head.EMM.track_core
+    import siammot.modelling.track_head.EMM.target_sampler
+
+    tracker = registry.SIAMESE_TRACKER[
+        cfg.MODEL.TRACK_HEAD.MODEL
+    ](cfg, track_utils)
+
+    tracker_sampler = registry.TRACKER_SAMPLER[
+        cfg.MODEL.TRACK_HEAD.MODEL
+    ](cfg, track_utils)
+
+    return TrackHead(tracker, tracker_sampler, track_utils, track_pool)
diff --git a/siam-mot/siammot/modelling/track_head/track_solver.py b/siam-mot/siammot/modelling/track_head/track_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c640bc560acb1851f76530ebdd5a1f0ce786211
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/track_solver.py
@@ -0,0 +1,115 @@
+import torch
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
+
+
+class TrackSolver(torch.nn.Module):
+    def __init__(self,
+                 track_pool,
+                 track_thresh=0.3,
+                 start_track_thresh=0.5,
+                 resume_track_thresh=0.4,
+                 ):
+        super(TrackSolver, self).__init__()
+
+        self.track_pool = track_pool
+        self.track_thresh = track_thresh
+        self.start_thresh = start_track_thresh
+        self.resume_track_thresh = resume_track_thresh
+
+    def get_nms_boxes(self, detection):
+        detection = boxlist_nms(detection, nms_thresh=0.5)
+
+        _ids = detection.get_field('ids')
+        _scores = detection.get_field('scores')
+
+        # adjust the scores to the right range
+        # _scores -= torch.floor(_scores) * (_ids >= 0) * (torch.floor(_scores) != _scores)
+        # _scores[_scores >= 1.] = 1.
+
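+        # scores of track predictions were shifted by +1 (and by +1 again for active
+        # tracks) before NMS; undo the offsets so scores are back in the [0, 1] range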
+        _scores[_scores >= 2.] = _scores[_scores >= 2.] - 2.
+        _scores[_scores >= 1.] = _scores[_scores >= 1.] - 1.
+
+        return detection, _ids, _scores
+
+    def forward(self, detection: [BoxList]):
+        """
+        The solver merges predictions from the detection branch and the track branch.
+        The goal is to assign a unique track id to bounding boxes that are deemed tracked.
+        :param detection: includes three distinct sets of predictions:
+        predictions propagated from active tracks (2 >= score > 1, id >= 0),
+        predictions propagated from dormant tracks (2 >= score > 1, id >= 0),
+        predictions from the detector (1 > score > 0, id = -1).
+        :return:
+        """
+
+        # only process one frame at a time
+        assert len(detection) == 1
+        detection = detection[0]
+
+        if len(detection) == 0:
+            return [detection]
+
+        track_pool = self.track_pool
+
+        all_ids = detection.get_field('ids')
+        all_scores = detection.get_field('scores')
+        active_ids = track_pool.get_active_ids()
+        dormant_ids = track_pool.get_dormant_ids()
+        device = all_ids.device
+
+        active_mask = torch.tensor([int(x) in active_ids for x in all_ids], device=device)
+
+        # differentiate active tracks from dormant tracks with scores
+        # active tracks, (3 >= score > 2, id >= 0),
+        # dormant tracks, (2 >= score > 1, id >= 0),
+        # By doing this, dormant tracks will be merged to active tracks during nms,
+        # if they highly overlap
+        all_scores[active_mask] += 1.
+
+        nms_detection, nms_ids, nms_scores = self.get_nms_boxes(detection)
+
+        combined_detection = nms_detection
+        _ids = combined_detection.get_field('ids')
+        _scores = combined_detection.get_field('scores')
+
+        # start track ids
+        start_idxs = ((_ids < 0) & (_scores >= self.start_thresh)).nonzero()
+
+        # inactive track ids
+        inactive_idxs = ((_ids >= 0) & (_scores < self.track_thresh))
+        nms_track_ids = set(_ids[_ids >= 0].tolist())
+        all_track_ids = set(all_ids[all_ids >= 0].tolist())
+        # active tracks that are removed by nms
+        nms_removed_ids = all_track_ids - nms_track_ids
+        inactive_ids = set(_ids[inactive_idxs].tolist()) | nms_removed_ids
+
+        # resume dormant mask, if needed
+        dormant_mask = torch.tensor([int(x) in dormant_ids for x in _ids], device=device)
+        resume_ids = _ids[dormant_mask & (_scores >= self.resume_track_thresh)]
+        for _id in resume_ids.tolist():
+            track_pool.resume_track(_id)
+
+        for _idx in start_idxs:
+            _ids[_idx] = track_pool.start_track()
+
+        active_ids = track_pool.get_active_ids()
+        for _id in inactive_ids:
+            if _id in active_ids:
+                track_pool.suspend_track(_id)
+
+        # make sure that the ids for inactive tracks in current frame are meaningless (< 0)
+        _ids[inactive_idxs] = -1
+
+        track_pool.expire_tracks()
+        track_pool.increment_frame()
+
+        return [combined_detection]
+
+
+def builder_tracker_solver(cfg, track_pool):
+    return TrackSolver(track_pool,
+                       cfg.MODEL.TRACK_HEAD.TRACK_THRESH,
+                       cfg.MODEL.TRACK_HEAD.START_TRACK_THRESH,
+                       cfg.MODEL.TRACK_HEAD.RESUME_TRACK_THRESH)
\ No newline at end of file
diff --git a/siam-mot/siammot/modelling/track_head/track_utils.py b/siam-mot/siammot/modelling/track_head/track_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0e1d9e3def9b61f21a8d78edafca1274e23c97
--- /dev/null
+++ b/siam-mot/siammot/modelling/track_head/track_utils.py
@@ -0,0 +1,272 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+
+
+class TrackUtils(object):
+    """
+    A class that includes utility functions unique to track branch
+    """
+    def __init__(self, search_expansion=1.0, min_search_wh=128, pad_pixels=256):
+        """
+        :param search_expansion: expansion ratio (of the search region)
+        w.r.t the size of tracking targets
+        :param min_search_wh: minimum size (width and height) of the search region
+        :param pad_pixels: the padding (in pixels) that is necessary to keep the
+        feature map of the search region and that of the template target at the same scale
+        """
+        self.search_expansion = search_expansion
+        self.min_search_wh = min_search_wh
+        self.pad_pixels = pad_pixels
+
+    @staticmethod
+    def swap_pairs(entity_list):
+        assert len(entity_list) % 2 == 0
+        # Take the targets of the other frame (in a tracking pair) as input during training, thus swap order
+        for xx in range(0, len(entity_list), 2):
+            entity_list[xx], entity_list[xx + 1] = entity_list[xx + 1], entity_list[xx]
+        return entity_list
+
+    @staticmethod
+    def shuffle_feature(f):
+        """
+        odd-even order swap of the feature tensor in the batch dimension
+        """
+
+        def shuffle_feature_tensor(x):
+            batch_size = x.shape[0]
+            assert batch_size % 2 == 0
+
+            # get batch swap order [1, 0, 3, 2, ...]
+            odd_idx = range(1, batch_size, 2)
+            even_idx = range(0, batch_size, 2)
+            idxs = np.arange(0, batch_size)
+            idxs[even_idx] = idxs[even_idx] + 1
+            idxs[odd_idx] = idxs[odd_idx] - 1
+            idxs = torch.tensor(idxs)
+
+            return x[idxs]
+
+        if isinstance(f, tuple):
+            shuffle_f = []
+            for i, _f in enumerate(f):
+                shuffle_f.append(shuffle_feature_tensor(_f))
+            shuffle_f = tuple(shuffle_f)
+        else:
+            shuffle_f = shuffle_feature_tensor(f)
+
+        return shuffle_f
+
+    def extend_bbox(self, in_box: [BoxList]):
+        """
+        Extend the bounding box to define the search region
+        :param in_box: a set of bounding boxes in the previous frame
+        (self.min_search_wh gives the minimum width/height of the search region)
+        """
+        for i, _track in enumerate(in_box):
+            bbox_w = _track.bbox[:, 2] - _track.bbox[:, 0] + 1
+            bbox_h = _track.bbox[:, 3] - _track.bbox[:, 1] + 1
+            w_ext = bbox_w * (self.search_expansion / 2.)
+            h_ext = bbox_h * (self.search_expansion / 2.)
+
+            # todo: need to check the equation later
+            min_w_ext = (self.min_search_wh - bbox_w) / (self.search_expansion * 2.)
+            min_h_ext = (self.min_search_wh - bbox_h) / (self.search_expansion * 2.)
+
+            w_ext = torch.max(min_w_ext, w_ext)
+            h_ext = torch.max(min_h_ext, h_ext)
+            in_box[i].bbox[:, 0] -= w_ext
+            in_box[i].bbox[:, 1] -= h_ext
+            in_box[i].bbox[:, 2] += w_ext
+            in_box[i].bbox[:, 3] += h_ext
+            # in_box[i].clip_to_image()
+        return in_box
+
+    def pad_feature(self, f):
+        """
+        Pad the feature maps with 0
+        :param f: [torch.tensor] or torch.tensor
+        """
+
+        if isinstance(f, (list, tuple)):
+            pad_f = []
+            for i, _f in enumerate(f):
+                # todo fix this hack, should read from cfg file
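+                # scale the image-space padding to this feature level,
+                # assuming FPN strides of 4 * 2**i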
+                pad_pixels = int(self.pad_pixels / ((2 ** i) * 4))
+                x = F.pad(_f, [pad_pixels, pad_pixels, pad_pixels, pad_pixels],
+                          mode='constant', value=0)
+                pad_f.append(x)
+            pad_f = tuple(pad_f)
+        else:
+            pad_f = F.pad(f, [self.pad_pixels, self.pad_pixels,
+                              self.pad_pixels, self.pad_pixels],
+                          mode='constant', value=0)
+
+        return pad_f
+
+    def update_boxes_in_pad_images(self, boxlists: [BoxList]):
+        """
+        Update the coordinates of bounding boxes in the padded image
+        """
+
+        pad_width = self.pad_pixels
+        pad_height = self.pad_pixels
+
+        pad_boxes = []
+        for _boxlist in boxlists:
+            im_width, im_height = _boxlist.size
+            new_width = int(im_width + pad_width*2)
+            new_height = int(im_height + pad_height*2)
+
+            assert (_boxlist.mode == 'xyxy')
+            xmin, ymin, xmax, ymax = _boxlist.bbox.split(1, dim=-1)
+            new_xmin = xmin + pad_width
+            new_ymin = ymin + pad_height
+            new_xmax = xmax + pad_width
+            new_ymax = ymax + pad_height
+            bbox = torch.cat((new_xmin, new_ymin, new_xmax, new_ymax), dim=-1)
+            bbox = BoxList(bbox, [new_width, new_height], mode='xyxy')
+            for _field in _boxlist.fields():
+                bbox.add_field(_field, _boxlist.get_field(_field))
+            pad_boxes.append(bbox)
+
+        return pad_boxes
+
+
+class TrackPool(object):
+    """
+    A class to manage the track id distribution (initiate/kill a track)
+    """
+    def __init__(self, active_ids=None, max_entangle_length=10, max_dormant_frames=1):
+        if active_ids is None:
+            self._active_ids = set()
+            # dormant track ids, mapped to the frame index at which they were suspended
+            self._dormant_ids = {}
+            # track ids that have been permanently killed / expired
+            self._kill_ids = set()
+            self._max_id = -1
+        self._embedding = None
+        self._cache = {}
+        self._frame_idx = 0
+        self._max_dormant_frames = max_dormant_frames
+        self._max_entangle_length = max_entangle_length
+
+    def suspend_track(self, track_id):
+        """
+        Suspend an active track, and add it to dormant track pools
+        """
+        if track_id not in self._active_ids:
+            raise ValueError
+
+        self._active_ids.remove(track_id)
+        self._dormant_ids[track_id] = self._frame_idx - 1
+
+    def expire_tracks(self):
+        """
+        Expire the suspended tracks after they are inactive
+        for a consecutive self._max_dormant_frames frames
+        """
+        for track_id, last_active in list(self._dormant_ids.items()):
+            if self._frame_idx - last_active >= self._max_dormant_frames:
+                self._dormant_ids.pop(track_id)
+                self._kill_ids.add(track_id)
+                self._cache.pop(track_id, None)
+
+    def increment_frame(self, value=1):
+        self._frame_idx += value
+
+    def update_cache(self, cache):
+        """
+        Update the latest position (bbox) / search region / template feature
+        for each track in the cache
+        """
+        template_features, sr, template_boxes = cache
+        sr = sr[0]
+        template_boxes = template_boxes[0]
+        for idx in range(len(template_boxes)):
+            if len(template_features) > 0:
+                assert len(template_features) == len(sr)
+                features = template_features[idx]
+            else:
+                features = template_features
+            search_region = sr[idx: idx+1]
+            box = template_boxes[idx: idx+1]
+            track_id = box.get_field("ids").item()
+            self._cache[track_id] = (features, search_region, box)
+
+    def resume_track(self, track_id):
+        """
+        Resume a dormant track
+        """
+        if track_id not in self._dormant_ids or \
+                track_id in self._active_ids:
+            raise ValueError
+
+        self._active_ids.add(track_id)
+        self._dormant_ids.pop(track_id)
+
+    def kill_track(self, track_id):
+        """
+        Kill a track
+        """
+        if track_id not in self._active_ids:
+            raise ValueError
+
+        self._active_ids.remove(track_id)
+        self._kill_ids.add(track_id)
+        self._cache.pop(track_id, None)
+
+    def start_track(self):
+        """
+        Return a new track id when starting a new track
+        """
+        new_id = self._max_id + 1
+        self._max_id = new_id
+        self._active_ids.add(new_id)
+
+        return new_id
+
+    def get_active_ids(self):
+        return self._active_ids
+
+    def get_dormant_ids(self):
+        return set(self._dormant_ids.keys())
+
+    def get_cache(self):
+        return self._cache
+
+    def activate_tracks(self, track_id):
+        if track_id in self._active_ids or \
+           track_id not in self._dormant_ids:
+            raise ValueError
+
+        self._active_ids.add(track_id)
+        self._dormant_ids.pop(track_id)
+
+    def reset(self):
+        self._active_ids = set()
+        self._kill_ids = set()
+        self._dormant_ids = {}
+        self._embedding = None
+        self._cache = {}
+        self._max_id = -1
+        self._frame_idx = 0
+
+
+def build_track_utils(cfg):
+
+    search_expansion = cfg.MODEL.TRACK_HEAD.SEARCH_REGION - 1.
+    pad_pixels = cfg.MODEL.TRACK_HEAD.PAD_PIXELS
+    min_search_wh = cfg.MODEL.TRACK_HEAD.MINIMUM_SREACH_REGION
+
+    track_utils = TrackUtils(search_expansion=search_expansion,
+                             min_search_wh=min_search_wh,
+                             pad_pixels=pad_pixels)
+    track_pool = TrackPool(max_dormant_frames=cfg.MODEL.TRACK_HEAD.MAX_DORMANT_FRAMES)
+
+    return track_utils, track_pool
+
+
+
diff --git a/siam-mot/siammot/utils/boxlists_to_entities.py b/siam-mot/siammot/utils/boxlists_to_entities.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fdb469723297fe8af0c36c8615a804909d04386
--- /dev/null
+++ b/siam-mot/siammot/utils/boxlists_to_entities.py
@@ -0,0 +1,36 @@
+from maskrcnn_benchmark.structures.bounding_box import BoxList
+from gluoncv.torch.data.gluoncv_motion_dataset.dataset import AnnoEntity
+
+
+def boxlists_to_entities(boxlists, firstframe_idx, timestamps, class_table=None):
+    """
+    Convert a list of boxlist to entities
+    :return:
+    """
+
+    if isinstance(boxlists, BoxList):
+        boxlists = [boxlists]
+
+    # default class is person only
+    if class_table is None:
+        class_table = ["person"]
+
+    assert isinstance(boxlists, list), "The input has to be a list"
+
+    entities = []
+    for i, boxlist in enumerate(boxlists):
+        for j in range(len(boxlist)):
+            entity = AnnoEntity()
+            entity.bbox = boxlist.bbox[j].tolist()
+            entity.confidence = boxlist.get_field('scores')[j].item()
+            _label = boxlist.get_field('labels')[j].item()
+            entity.labels = {class_table[_label - 1]: entity.confidence}
+            # the default id is -1
+            entity.id = -1
+            if boxlist.has_field('ids'):
+                entity.id = boxlist.get_field('ids')[j].item()
+            entity.frame_num = firstframe_idx + i
+            entity.time = timestamps[i]
+            entities.append(entity)
+
+    return entities
diff --git a/siam-mot/siammot/utils/entity_utils.py b/siam-mot/siammot/utils/entity_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab9d11db767f434b9c7fcd5b7497cf26c0c88a9
--- /dev/null
+++ b/siam-mot/siammot/utils/entity_utils.py
@@ -0,0 +1,46 @@
+import numpy as np
+from gluoncv.torch.data.gluoncv_motion_dataset.dataset import AnnoEntity
+
+
+def bbs_iou(entities_1: [AnnoEntity], entities_2: [AnnoEntity]):
+    """
+    Compute the IoU matrix between two lists of AnnoEntity
+    (bbox in AnnoEntity is in xywh format)
+
+    Differs from boxlist_iou in that TO_REMOVE is not added to the width/height
+    """
+
+    if not isinstance(entities_1, list):
+        entities_1 = [entities_1]
+    if not isinstance(entities_2, list):
+        entities_2 = [entities_2]
+
+    if len(entities_1) == 0 or len(entities_2) == 0:
+        return np.zeros((len(entities_1), len(entities_2)))
+
+    box_xywh_1 = np.array([entity.bbox for entity in entities_1])
+    box_xywh_2 = np.array([entity.bbox for entity in entities_2])
+
+    # compute the area of union regions
+    area1 = box_xywh_1[:, 2] * box_xywh_1[:, 3]
+    area2 = box_xywh_2[:, 2] * box_xywh_2[:, 3]
+
+    # to xyxy
+    box_xyxy_1 = np.zeros_like(box_xywh_1)
+    box_xyxy_2 = np.zeros_like(box_xywh_2)
+    box_xyxy_1[:, :2] = box_xywh_1[:, 0:2]
+    box_xyxy_2[:, :2] = box_xywh_2[:, 0:2]
+    box_xyxy_1[:, 2:] = box_xywh_1[:, :2] + box_xywh_1[:, 2:]
+    box_xyxy_2[:, 2:] = box_xywh_2[:, :2] + box_xywh_2[:, 2:]
+
+    lt = np.maximum(box_xyxy_1[:, None, :2], box_xyxy_2[:, :2])  # [N,M,2]
+    rb = np.minimum(box_xyxy_1[:, None, 2:], box_xyxy_2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clip(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    iou = inter / (area1[:, None] + area2 - inter)
+
+    return iou
\ No newline at end of file
diff --git a/siam-mot/siammot/utils/get_model_name.py b/siam-mot/siammot/utils/get_model_name.py
new file mode 100644
index 0000000000000000000000000000000000000000..2351ed95e64d32252e9203837e440e5f7b5ef6f8
--- /dev/null
+++ b/siam-mot/siammot/utils/get_model_name.py
@@ -0,0 +1,49 @@
+def get_model_name(cfg,
+                   model_suffix=None,
+                   is_train=True,
+                  ):
+    """
+    Automatically generate a model name that carries key information about the configuration;
+    this information includes the backbone, the functionality (detection / tracking),
+    the training dataset(s) and any manually attached experiment identifier.
+
+    :param cfg:  experiment configuration file
+    :param model_suffix: manually attached experiment identifier
+    :param is_train: whether it is for training
+    :return:
+    """
+    backbone = cfg.MODEL.BACKBONE.CONV_BODY
+    branch_suffix = _get_branch_suffix(cfg)
+
+    assert is_train is True, "This function is called only during training."
+    dataset_list = cfg.DATASETS.TRAIN
+    dataset_suffix = _get_dataset_suffix(dataset_list)
+
+    output_dir = ""
+    output_dir += backbone
+    output_dir += branch_suffix
+    output_dir += dataset_suffix
+    if model_suffix is not None:
+        if len(model_suffix) > 0:
+            output_dir += ('_' + model_suffix)
+
+    return output_dir
+
+
+def _get_branch_suffix(cfg):
+    suffix = ""
+    if cfg.MODEL.BOX_ON:
+        suffix += '_box'
+    if cfg.MODEL.TRACK_ON:
+        suffix += '_track'
+    return suffix
+
+
+def _get_dataset_suffix(dataset_list):
+    suffix = ""
+
+    if not isinstance(dataset_list, (list, tuple)):
+        raise RuntimeError("dataset_list should be a list of strings, got {}".format(dataset_list))
+    for dataset_key in dataset_list:
+        suffix += ("_"+dataset_key)
+    return suffix
diff --git a/siam-mot/siammot/utils/registry.py b/siam-mot/siammot/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..11138ff5f5b94bba13ea72f552dcea371f5a1f09
--- /dev/null
+++ b/siam-mot/siammot/utils/registry.py
@@ -0,0 +1,4 @@
+from maskrcnn_benchmark.utils.registry import Registry
+
+SIAMESE_TRACKER = Registry()
+TRACKER_SAMPLER = Registry()
\ No newline at end of file
diff --git a/siam_mot_test.py b/siam_mot_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..346e58161074ca28b33aef71810ff9dbc58ffc5a
--- /dev/null
+++ b/siam_mot_test.py
@@ -0,0 +1,78 @@
+# This file is the entrypoint for your submission.
+# You can modify this file to include your code or directly call your functions/modules from here.
+import random
+import cv2
+from PIL import Image
+from evaluator.airborne_detection import AirbornePredictor
+
+import numpy as np
+from tqdm import tqdm
+
+import os
+from os import listdir
+from os.path import isfile, join
+
+from siam_mot_tracker import SiamMOTTracker
+
+class SiamMOTPredictor(AirbornePredictor):
+    """
+    PARTICIPANT_TODO: You can name your implementation as you like. `SiamMOTPredictor` is just an example.
+    The paths below will be preloaded for you; you can read them as you like.
+    """
+    training_data_path = None
+    test_data_path = None
+    vocabulary_path = None
+
+    """
+    PARTICIPANT_TODO:
+    You can do any preprocessing required for your codebase here like loading up models into memory, etc.
+    """
+    def inference_setup(self):
+        current_path = os.getcwd()
+        config_file = os.path.join(current_path, 'siam-mot/configs/dla/DLA_34_FPN_AOT.yaml')
+        model_path = os.path.join(current_path, 'siam-mot/models/DLA-34-FPN_box_track_aot_d4.pth')
+        self.siammottracker = SiamMOTTracker(config_file, model_path)
+
+    def get_all_frame_images(self, flight_id):
+        frames = []
+        flight_folder = join(self.test_data_path, flight_id)
+        for frame in sorted(listdir(flight_folder)):
+            if isfile(join(flight_folder, frame)):
+                frames.append(frame)
+        return frames
+
+    """
+    PARTICIPANT_TODO:
+    During the evaluation, every flight_id (with its flight folder path) will be provided one by one.
+    """
+    def inference(self, flight_id):
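+        # reset the tracker state so track ids from a previous flight do not carry over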
+        self.siammottracker.tracker.reset_siammot_status()
+
+        for frame_image in tqdm(self.get_all_frame_images(flight_id)):
+            frame_image_path = self.get_frame_image_location(flight_id, frame_image)
+            frame = cv2.imread(frame_image_path)
+            
+            results = self.siammottracker.process(frame)
+
+            class_name = 'airborne'
+            for idx in range(len(results.bbox)):
+                confidence = results.get_field('scores')[idx]
+                if confidence < 0.3: # filter low confidence predictions
+                    continue
+
+                bbox_xywh = results.bbox.cpu().numpy()[idx]
+
+                # the required bbox format is [x0, y0, x1, y1] (left, top, right, bottom)
+                bbox = [ float(bbox_xywh[0]), float(bbox_xywh[1]),
+                         float(bbox_xywh[0] + bbox_xywh[2]),
+                         float(bbox_xywh[1] + bbox_xywh[3])]
+
+                track_id = results.get_field('ids')[idx]
+
+                self.register_object_and_location(
+                    class_name, int(track_id), bbox, float(confidence), frame_image
+                )
+
+if __name__ == "__main__":
+    submission = SiamMOTPredictor()
+    submission.run()