diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..2dbeb9ded265bec755eb7405eef7bae68dd239be
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,38 @@
+FROM python:3.10.12-bullseye 
+
+ARG DEBIAN_FRONTEND=noninteractive 
+
+RUN apt -qq update && apt install -qq -y software-properties-common
+
+RUN apt -qq -y install \
+  build-essential \
+  curl \
+  rsync \
+  git \
+  ffmpeg \
+  libsm6 \
+  libxext6
+
+ENV USER_NAME aicrowd
+ENV HOME_DIR /home/$USER_NAME
+ENV HOST_UID 1001
+ENV HOST_GID 1001
+
+RUN export uid=${HOST_UID} gid=${HOST_GID} && \
+    mkdir -p ${HOME_DIR} && \
+    echo "$USER_NAME:x:${uid}:${gid}:$USER_NAME,,,:$HOME_DIR:/bin/bash" >> /etc/passwd && \
+    echo "$USER_NAME:x:${uid}:" >> /etc/group
+
+RUN chown -R 1001:1001 ${HOME_DIR}
+
+COPY --chown=1001:1001 . ${HOME_DIR}
+WORKDIR ${HOME_DIR}
+
+RUN apt -qq update && apt -qq install -y `cat ${HOME_DIR}/apt.txt`
+
+RUN pip3 install -r ${HOME_DIR}/requirements.txt && rm -rf /root/.cache
+
+# Add your custom commands here
+
+USER ${USER_NAME}
diff --git a/README.md b/README.md
index 8e5b0eb63060cd49c31b4f50ed3c34827a8ee486..239ae214567bac96904999f429cd6bd40b7df16b 100644
--- a/README.md
+++ b/README.md
@@ -1,92 +1,470 @@
-# mosquitoalert-2023
+>Thanks a lot to the organizers, sponsors and AICrowd team for this quite interesting challenge!
 
-3rd place solution for AICrowd MosquitoAlert 2023 challenge
+<h1>3rd place solution detailed</h1>
 
-## Getting started
+My solution, which reached **3rd place** on the private LB, is based on the two stages detailed below:
 
-To make it easy for you to get started with GitLab, here's a list of recommended next steps.
 
-Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!
+* **Stage1**: A mosquito bounding-box detector based on Yolo.
+* **Stage2**: An ensemble of two mosquito classifiers: the first based on a ViT architecture and the second on an EfficientNetV2 architecture.
 
-## Add your files
+<h2>Inference</h2>
 
-- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
-- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command:
 
-```
-cd existing_repo
-git remote add origin http://gitlab.aicrowd.com/MPWARE/mosquitoalert-2023.git
-git branch -M main
-git push -uf origin main
-```
+![Inference](images/inference.png "Inference")
 
-## Integrate with your tools
 
-- [ ] [Set up project integrations](http://gitlab.aicrowd.com/MPWARE/mosquitoalert-2023/-/settings/integrations)
+The main issue in this competition was the 2-second limit per image to run the full pipeline on CPU only. Such a constraint limits the total number of models we could ensemble and the TTA options. To speed up the inference I’ve decided to convert the YoloV8 model to OpenVINO, as the mAP metric on cross-validation was not altered (it was even a bit better). The 4-fold inference runtime was acceptable (around 400-500ms). The final bounding boxes are computed through a [weighted boxes fusion](https://github.com/ZFTurbo/Weighted-Boxes-Fusion) ensemble. The image crops matching the bounding boxes become the input for both classifiers, executed through PyTorch without any runtime optimization (around 600-700ms). Finally, the logits of each classifier are averaged before computing the argmax to get the final prediction.
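+
+As an illustration, here is a minimal sketch of that final ensembling step. Function and variable names are hypothetical (the actual pipeline lives in the inference script linked below); each crop is resized to the resolution its classifier was trained at:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+@torch.no_grad()
+def classify_crop(crop, vit_model, effnet_model):
+    # crop: (1, 3, H, W) float tensor of the fused bounding-box crop
+    x_vit = F.interpolate(crop, size=(384, 384), mode="bilinear", align_corners=False)  # TinyViT input size
+    x_eff = F.interpolate(crop, size=(512, 512), mode="bilinear", align_corners=False)  # EffNetV2s input size
+    logits = (vit_model(x_vit) + effnet_model(x_eff)) / 2.0  # average the logits
+    return logits.argmax(dim=1)  # final class prediction
+```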
 
-## Collaborate with your team
 
-- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
-- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
-- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
-- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
-- [ ] [Automatically merge when pipeline succeeds](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html)
+<table>
+  <tr>
+   <td>Metric
+   </td>
+   <td>Public LB
+   </td>
+   <td>Private LB
+   </td>
+  </tr>
+  <tr>
+   <td>mIoU
+   </td>
+   <td>0.830
+   </td>
+   <td>0.835
+   </td>
+  </tr>
+  <tr>
+   <td>F1
+   </td>
+   <td>0.904
+   </td>
+   <td>0.915
+   </td>
+  </tr>
+  <tr>
+   <td>F1 filtered
+   </td>
+   <td>0.819
+   </td>
+   <td>0.832
+   </td>
+  </tr>
+</table>
 
-## Test and Deploy
 
-Use the built-in continuous integration in GitLab.
+My runtime [environment](https://gitlab.aicrowd.com/MPWARE/mosquitoalert-2023-phase2-starter-kit/-/issues/234) for inference:
 
-- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html)
-- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing(SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
-- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
-- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
-- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)
 
-***
 
-# Editing this README
+* Python 3.10.12
+* Numpy 1.26.1
+* Pandas 2.1.1
+* Pytorch 2.1.0+cu121
+* Pytorch Lightning 2.0.9
+* Ultralytics 8.0.200
+* Timm 0.9.8
+* Albumentations 1.3.1
+* OpenVINO runtime 2023.1.0
 
-When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!).  Thank you to [makeareadme.com](https://www.makeareadme.com/) for this template.
+The inference script is available [here](https://gitlab.aicrowd.com/MPWARE/mosquitoalert-2023-phase2-starter-kit/-/blob/submission-v0.235/my_models/yolo/combo.py).
 
-## Suggestions for a good README
-Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.
+<h2>Data</h2>
 
-## Name
-Choose a self-explaining name for your project.
+I’ve used the initial train dataset (10k images) only until the last week of the challenge, when the external datasets thread was opened. Then I downloaded and added the provided [external datasets](https://discourse.aicrowd.com/t/external-datasets-used-by-participants/9217) to my training and got a nice boost on both CV and LB scores. The new data provided more examples for the rare classes: the initial dataset was severely imbalanced for a few classes, and even a balanced sampler strategy did not help; only new samples improved the training. The external dataset contained around 39k images with no bounding boxes, so I generated pseudo-label bounding boxes with my initially trained Yolo model. I noticed a lot of noisy images (e.g. labels available but no mosquito), so I decided to keep only pseudo labels with high Yolo confidence, which left 23k images at the end. During the initial training I noticed many similar backgrounds around the mosquitoes (blood, finger, …), so to keep the model from overfitting on background context I created an additional background-only dataset (1k images) and a 7th class (no mosquito) for my classifiers.
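+
+A minimal sketch of this confidence filtering, assuming a hypothetical `pseudo_labels` mapping from image path to the best Yolo box and its confidence (the threshold value is illustrative, not the exact cut-off used):
+
+```python
+def filter_pseudo_labels(pseudo_labels, conf_threshold=0.80):
+    # pseudo_labels: {image_path: (box_xyxy, yolo_confidence)} -- hypothetical structure
+    kept = {path: box for path, (box, conf) in pseudo_labels.items()
+            if conf >= conf_threshold}
+    print(f"Kept {len(kept)} / {len(pseudo_labels)} pseudo-labeled images")
+    return kept
+```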
 
-## Description
-Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.
+The 10k images (BB from Yolo OOF) + 23k images (Pseudo BB with high confidence) are available [here](https://mosquitoalert.s3.amazonaws.com/stage2-images_and_external_boxes_yolov8n_v3_768_seed_42_1.4_oof.zip). The 1k background images are available [here](https://mosquitoalert.s3.amazonaws.com/background.zip).
+
+<h2>Training</h2>
+
+As said previously, due to the runtime limit there was no real room for multiple models, nor for large ones, and a large image size was not an option either. So I needed relatively small and fast models able to learn with a limited image size.
+
+Training pipeline overview:
+
+
+![Training](images/training.png "Training")
+
+
+**Stage1**: Bounding boxes detector
+
+I’ve selected a YoloV8 nano with a single class. Such a task is quite easy and does not need a large Yolo model: metrics between YoloV8 nano and small were comparable, but nano is much faster at inference.
+
+Training procedure and hyper-parameters (a minimal Ultralytics call is sketched after the list):
+
+Cross-validation is based on 4 folds stratified by species and by the mosquito's surface area in the original image. Training was limited to the initial images only, to best match the provided BB labels.
+
+
+
+* 768x768 image size
+* Initial 10k images (no external dataset)
+* Around 1k additional background images
+* Baseline augmentations + degrees=10.0, shear=2.0, flipud=0.5, mixup=0.25, mosaic closed for the last 10 epochs
+* 128 epochs, batch size = 16
+* SGD optimizer, Linear scheduler from lr=0.01 to lr=0.0001
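+
+A minimal sketch of the corresponding Ultralytics call, where the dataset YAML path is a placeholder (in Ultralytics 8.0.x the linear scheduler decays from `lr0` to `lr0 * lrf`, so `lrf=0.01` matches the 0.01 → 0.0001 range above):
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n.pt")  # pretrained nano weights, single "mosquito" class in the YAML
+model.train(
+    data="mqt.yaml",        # placeholder dataset config
+    imgsz=768, epochs=128, batch=16,
+    optimizer="SGD", lr0=0.01, lrf=0.01,
+    degrees=10.0, shear=2.0, flipud=0.5, mixup=0.25,
+    close_mosaic=10,        # disable mosaic augmentation for the last 10 epochs
+)
+```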
+
+Metrics (per fold):
+
+
+
+* AP@0.5: 0.988/0.991/0.987/0.991
+* AP@[0.5-0.95]: 0.768/0.769/0.766/0.774
+* mIoU (OpenVINO-converted): 0.851/0.853/0.852/0.854
+
+The 4 fold models are ensembled at inference with weighted boxes fusion to select the final boxes, as sketched below. As the final metric was “filtered F1”, it was important to minimize the number of images filtered out for low IoU. This ensemble achieved mIoU=0.83 with only 225 removed images on the public LB. I’ve generated the pseudo bounding boxes for the external dataset the same way, to keep consistency.
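+
+A minimal sketch of the fusion step with the [Weighted-Boxes-Fusion](https://github.com/ZFTurbo/Weighted-Boxes-Fusion) package (thresholds are illustrative; boxes must be normalized to [0, 1]):
+
+```python
+import numpy as np
+from ensemble_boxes import weighted_boxes_fusion
+
+
+def fuse_fold_boxes(boxes_list, scores_list, labels_list):
+    # One entry per fold model: that model's normalized [x1, y1, x2, y2]
+    # boxes, confidences and class ids for a single image.
+    boxes, scores, labels = weighted_boxes_fusion(
+        boxes_list, scores_list, labels_list,
+        weights=[1] * len(boxes_list),  # equal weight per fold
+        iou_thr=0.55,                   # illustrative threshold
+        skip_box_thr=0.0001,            # illustrative threshold
+    )
+    return boxes[np.argmax(scores)]     # keep a single mosquito box
+```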
+
+**Stage2**: Classifiers
+
+I’ve experimented with a few backbones from [Timm](https://github.com/huggingface/pytorch-image-models) such as EfficientNet, MaxViT, TinyViT, ResNet, MobileNet and DenseNet, with image sizes from 224 to 512. The best cross-validation F1 was obtained with TinyViT at 384 and EfficientNetV2s at 512.
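+
+Both backbones come from timm; a minimal instantiation sketch, where the exact checkpoint names are assumptions based on the reported sizes and resolutions (7 classes = 6 species + the extra background class):
+
+```python
+import timm
+
+vit = timm.create_model("tiny_vit_21m_384.dist_in22k_ft_in1k",
+                        pretrained=True, num_classes=7)
+eff = timm.create_model("tf_efficientnetv2_s.in21k_ft_in1k",
+                        pretrained=True, num_classes=7)
+```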
+
+Training procedure and hyper-parameters:
+
+Cross-validation is based on the same 4 folds as for Yolo to keep end-to-end consistency.
+
+<table style="text-align: center;">
+  <tr>
+   <td>
+   </td>
+   <td style="background-color: #EEEEEE;">Tiny ViT 
+   </td>
+   <td style="background-color: #EEEEEE;">EfficientNet V2 small 
+   </td>
+  </tr>
+  <tr>
+   <td>Parameters
+   </td>
+   <td>22M
+   </td>
+   <td>20M
+   </td>
+  </tr>
+  <tr>
+   <td>Image size (BB resized)
+   </td>
+   <td>384x384
+   </td>
+   <td>512x512
+   </td>
+  </tr>
+  <tr>
+   <td>Total train images
+   </td>
+   <td>28k (10k initial + 17k external + 1k background)
+   </td>
+   <td>34k (10k initial + 23k external + 1k background)
+   </td>
+  </tr>
+  <tr>
+   <td>Normalization
+   </td>
+   <td>ImageNet
+<p>
+Mean: [0.485, 0.456, 0.406]
+<p>
+Std: [0.229, 0.224, 0.225]
+   </td>
+   <td>ImageNet
+<p>
+Mean: [0.485, 0.456, 0.406]
+<p>
+Std: [0.229, 0.224, 0.225]
+   </td>
+  </tr>
+  <tr>
+   <td>Augmentations
+   </td>
+   <td>Hard <p>
+(H/V, Rot90, Scale, Shift, Rotate, Colors, Blur/Noise, MixUp, CutMix, …)
+   </td>
+   <td>Hard <p>
+(H/V, Rot90, Scale, Shift, Rotate, Colors, Blur/Noise, MixUp, CutMix, …)
+   </td>
+  </tr>
+  <tr>
+   <td>Loss
+   </td>
+   <td>CrossEntropyLoss <p>
+with label smoothing=0.1
+   </td>
+   <td>CrossEntropyLoss
+<p>
+with label smoothing=0.1
+   </td>
+  </tr>
+  <tr>
+   <td>Optimizer
+   </td>
+   <td>AdamW
+   </td>
+   <td>AdamW
+   </td>
+  </tr>
+  <tr>
+   <td>Scheduler
+   </td>
+   <td>CosineAnnealingLR
+<p>
+From 1e-4 to 0.0
+   </td>
+   <td>CosineAnnealingLR
+<p>
+From 1e-3 to 0.0
+   </td>
+  </tr>
+  <tr>
+   <td>Batch size
+   </td>
+   <td>32
+   </td>
+   <td>32
+   </td>
+  </tr>
+  <tr>
+   <td>Epochs
+   </td>
+   <td>96
+   </td>
+   <td>96
+   </td>
+  </tr>
+  <tr>
+   <td>Weights averaging
+   </td>
+   <td>EMA
+   </td>
+   <td>EMA
+   </td>
+  </tr>
+  <tr>
+   <td>Mixed precision
+   </td>
+   <td>16-mixed
+   </td>
+   <td>16-mixed
+   </td>
+  </tr>
+  <tr>
+   <td>CV4 scores
+   </td>
+   <td>Precision: 0.877
+<p>
+Recall: 0.865
+<p>
+F1: <strong>0.870</strong>
+<p>
+(0.844/0.886/0.862/0.891)
+   </td>
+   <td>Precision: 0.867
+<p>
+Recall: 0.854
+<p>
+F1: <strong>0.860</strong> <p>
+(0.862/0.847/0.855/0.885)
+   </td>
+  </tr>
+  <tr>
+   <td>Public/Private LB
+   </td>
+   <td>
+   </td>
+   <td>
+   </td>
+  </tr>
+  <tr>
+   <td>Best single fold <p>
+(with HFlip TTA)
+   </td>
+   <td>F1: 0.886/0.861 <p>
+F1 filtered: 0.822/0.798
+   </td>
+   <td>F1: 0.877/0.868
+<p>
+F1 filtered: 0.803/0.800
+   </td>
+  </tr>
+  <tr>
+   <td>Full fit <p>
+(with HFlip TTA) 
+   </td>
+   <td>F1: 0.865/0.858 <p>
+F1 filtered: 0.816/0.806
+   </td>
+   <td>F1: 0.878/0.889 <p>
+F1 filtered: 0.809/0.810
+   </td>
+  </tr>
+  <tr>
+   <td>
+   </td>
+   <td colspan="2" style="background-color: #EEEEEE;">Ensemble
+   </td>
+  </tr>
+  <tr>
+   <td>Best single fold ensemble <p>
+(no TTA) 
+   </td>
+   <td colspan="2" >F1 0.892/0.891
+<p>
+F1 filtered: 0.816/0.814
+   </td>
+  </tr>
+  <tr>
+   <td>Full fit ensemble <p>
+(no TTA)
+   </td>
+   <td colspan="2" >F1: <strong>0.904</strong>/0.915
+<p>
+F1 filtered: 0.819/0.832
+   </td>
+  </tr>
+</table>
+
+The training procedure has been improved over time, and the main boosts came from mixup augmentations, Exponential Moving Average (EMA) and label smoothing. Some labels are noisy, and label smoothing helps to reduce overconfidence.
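+
+The EMA implementation here comes from NeMo's callback (see `train_classifier.py`); conceptually, EMA keeps a smoothed shadow copy of the weights that is used for evaluation, as in this sketch (the decay value is illustrative):
+
+```python
+import torch
+
+
+@torch.no_grad()
+def ema_update(ema_model, model, decay=0.9999):
+    # Shadow weights track the training weights with exponential decay;
+    # the smoothed copy is what gets checkpointed and evaluated.
+    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
+        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
+```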
+
+The confusion matrices:
+
+
+<table style="text-align: center;">
+  <tr>
+   <td>
+
+
+<img src="images/tinyvit_cm.png" width="" alt="tinyvit_cm" title="tinyvit_cm">
+
+<p>
+Tiny ViT
+   </td>
+   <td>
+
+
+<img src="images/effnetv2s_cm.png" width="" alt="effnetv2s_cm" title="effnetv2s_cm">
+
+<p>
+EfficientNetv2s
+   </td>
+  </tr>
+</table>
+
+
+Hereafter is the training chart for the Tiny ViT and EfficientNetV2s folds.
+
+
+
+![Training chart](images/training_chart.png "Training chart")
+
+
+We can see that the F1 metric improves steadily and remains quite stable up to the end of training.
+
+Inference with 4 folds was over the 2-second time limit, so I’ve decided to switch to full-fit training. A full fit leads to a single model trained on all the data, with weights taken from the last checkpoint instead of the best checkpoint. It’s a common trick used by grandmasters to benefit from all the data while limiting overfitting. It can be applied here using exactly the same hyper-parameters as the CV training procedure.
+
+**More feedback**
+
+Hereafter, some feedback about other approaches I’ve experimented with.
+
+What did not work (or did not improve):
+
+
+
+* Increasing the image size to 1024x1024 did not bring more information to the models
+* Forcing the aspect ratio to 1.0 on resize did not help
+* Single-stage, 6-class Yolo model (both regular and RTDETR)
+* Single-stage, 6-class EfficientDet model
+* SWA (last 5 epochs or best 5)
+* Model pruning (to speed up inference)
+* Post-processing: adding/removing a margin to the predicted BB
+* Training on mosquito bodies only. I’ve created a mosquito body extractor, but my conclusion is that relying on the head, thorax, and dorsal view is not enough: there is predictive power in the mosquito legs and wings. It might be obvious to the experts, but from past experience on other species it was worth trying.
+
+
+
+<h1>Training procedure</h1>
+
+The files and metadata needed for training are under the resources folder:
+
+
+<table>
+  <tr>
+   <td style="background-color: #EEEEEE;">File
+   </td>
+   <td style="background-color: #EEEEEE;">Description
+   </td>
+  </tr>
+  <tr>
+   <td>files_10k_bb_4folds.csv
+   </td>
+   <td>Files from the <a href="https://www.aicrowd.com/challenges/mosquitoalert-challenge-2023/dataset_files">initial dataset</a> with bounding boxes recomputed from YoloV8n OOF predictions. CV4 split provided in the sgkf_fold_s42 column.
+   </td>
+  </tr>
+  <tr>
+   <td>background_1k.csv
+   </td>
+   <td>Background crops (no mosquito) taken from the initial dataset.
+   </td>
+  </tr>
+  <tr>
+   <td>files_external_17k_inaturalist_kaggle_bb_4folds.csv
+   </td>
+   <td>INaturalist <a href="https://drive.google.com/file/d/1xrz2qMmzd2ut12g_EXkOSxzXGUSRSh8s/view?usp=drive_link">files</a> and Kaggle <a href="https://drive.google.com/file/d/1aXVaowHDaoDRK4PeqJN25lFcujQMJHUx/view?usp=drive_link">files </a>coming from external datasets declared on <a href="https://discourse.aicrowd.com/t/external-datasets-used-by-participants/9217">AICrowd thread</a> (CC-BY-NC and CC BY-SA licenses). They are filtered on high Yolov8N confidence. CV4 split provided in sgkf_fold_s42 column.
+   </td>
+  </tr>
+  <tr>
+   <td>files_external_6k_inaturalist_s3_bb_4folds.csv
+   </td>
+   <td>INaturalist <a href="https://discourse.aicrowd.com/t/external-datasets-used-by-participants/9217">files</a> (<a href="https://drive.google.com/file/d/1WxJByUzYtquscUl0XEsUrKs7Sqj5h6wf/view?usp=sharing">inat_images</a>, CC-BY-NC and CC BY-SA licenses) filtered on high Yolov8N confidence. Link to original image in “url” column. CV4 split provided in sgkf_fold_s42 column.
+   </td>
+  </tr>
+  <tr>
+   <td>effnetv2s.json
+   </td>
+   <td>EfficientNetV2s model training hyper-parameters.
+   </td>
+  </tr>
+  <tr>
+   <td>tinyvit.json
+   </td>
+   <td>Tiny ViT training hyper-parameters.
+   </td>
+  </tr>
+  <tr>
+   <td>yolov8n.json
+   </td>
+   <td>YoloV8 nano training hyper-parameters.
+   </td>
+  </tr>
+</table>
 
-## Badges
-On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.
 
-## Visuals
-Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.
+Initial images are available on AICrowd [here](https://www.aicrowd.com/challenges/mosquitoalert-challenge-2023/dataset_files) (9.6GB).
 
-## Installation
-Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.
+Background images, already cropped, are available [here](https://mosquitoalert.s3.amazonaws.com/background.zip) (175MB).
 
-## Usage
-Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.
+Stage2 images, already downloaded and cropped, are available [here](https://mosquitoalert.s3.amazonaws.com/stage2-images_and_external_boxes_yolov8n_v3_768_seed_42_1.4_oof.zip) (14GB).
 
-## Support
-Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc.
+**1/ Create the following folder structure:**
 
-## Roadmap
-If you have ideas for releases in the future, it is a good idea to list them in the README.
 
-## Contributing
-State if you are open to contributions and what your requirements are for accepting them.
+    data/
+       images/ (copy initial images here)
+           train_00000.jpeg
+           …
+           train_10356.jpeg
 
-For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.
+           background/ (copy background images here)
+                 background_00013.jpeg
+                 …
+                 background_10355.jpeg
 
-You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.
-
-## Authors and acknowledgment
-Show your appreciation to those who have contributed to the project.
-
-## License
-For open source projects, say how it is licensed.
-
-## Project status
-If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
+       stage2_images_bb/ (copy cropped images here, both initial and external)
+           0.5038_86891872.png
+           …
+           Vivo-Y21_Ae-albopictus_s10_l2_t4_na.png.png
+
+**2/ Run the classifier training scripts:**
+
+You can remove the --wandb_project option if you don’t have a [wandb](https://wandb.ai/) account, and remove the --full_train option if you want to train CV models to monitor the F1/Precision/Recall scores.
+
+<span style="text-decoration:underline;">TinyVit</span>: Around 8h per fold on 1 GPU with 24GB VRAM.
+
+    python train_classifier.py --config ./resources/tinyvit.json --images_bb_path ./data/stage2_images_bb --background_path ./data/images/background --initial_cv ./resources/files_10k_bb_4folds.csv --background ./resources/background_1k.csv --external_cv ./resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv --wandb_project MosquitoClassifier --full_train
+
+<span style="text-decoration:underline;">EfficientNetV2s</span>: Around 8h per fold on 1 GPU with 24GB VRAM.
+
+    python train_classifier.py --config ./resources/effnetv2s.json --images_bb_path ./data/stage2_images_bb --background_path ./data/images/background --initial_cv ./resources/files_10k_bb_4folds.csv --background ./resources/background_1k.csv --external_cv ./resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv,./resources/files_external_6k_inaturalist_s3_bb_4folds.csv --wandb_project MosquitoClassifier --full_train
diff --git a/apt.txt b/apt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2d50363365906af2d82f13e778e25af7dbd99d22
--- /dev/null
+++ b/apt.txt
@@ -0,0 +1,4 @@
+git
+ffmpeg
+libsm6
+libxext6
diff --git a/images/effnetv2s_cm.png b/images/effnetv2s_cm.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a967293bf719ef467af454373f2696732327f20
--- /dev/null
+++ b/images/effnetv2s_cm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12b02a42449464f44b8f19ad2ecc251cd135f212ac0fe3f1987b8fabf567153e
+size 52238
diff --git a/images/inference.png b/images/inference.png
new file mode 100644
index 0000000000000000000000000000000000000000..5541e5fd142320135c9ac58d4a7fb285865e779c
--- /dev/null
+++ b/images/inference.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b1f690cf7c13266729b0de510be225cde7dec047adf6bd527c8d668fee4fe97
+size 50288
diff --git a/images/tinyvit_cm.png b/images/tinyvit_cm.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa5924265b3943ca1156ff84dddb76fd8ada2d0d
--- /dev/null
+++ b/images/tinyvit_cm.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6c9984f0caabbfe7ca01428f3afd67ba4fcbeddfe902ea9490f072c9649a9769
+size 52481
diff --git a/images/training.png b/images/training.png
new file mode 100644
index 0000000000000000000000000000000000000000..221b87f9284122d1987799d8920c8fba09bd01eb
--- /dev/null
+++ b/images/training.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e60422ae7b42fe20115f13fe7afbb544d9405c44dc4c01c6bd3cf45a39225dea
+size 103023
diff --git a/images/training_chart.png b/images/training_chart.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f48cdeadadec7884395ae673795ab1c72839330
--- /dev/null
+++ b/images/training_chart.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afaf1fa9990fdb2e42be73bcf343f4818e10511108a6880d51bc03021197f50d
+size 377994
diff --git a/mqt/__init__.py b/mqt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mqt/models/__init__.py b/mqt/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mqt/models/classifier.py b/mqt/models/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca9243deb69db82762d07bdf0a1a1fba5a9d2dc
--- /dev/null
+++ b/mqt/models/classifier.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch
+import timm
+from torchmetrics import F1Score, Precision, Recall
+import pytorch_lightning as L
+from torchmetrics.functional.classification import multiclass_f1_score, multiclass_recall, multiclass_precision
+from torch.nn import functional as F
+
+from mqt.training.mix import cutmix_data, mixup_data, mixup_cross_entropy
+
+
+# Flip a (*, C, H, W) tensor along its last (width) axis
+def hflip(data):
+    w = data.shape[-1]
+    return data[..., torch.arange(w - 1, -1, -1, device=data.device)]
+
+
+# Flip a (*, C, H, W) tensor along its height axis
+def vflip(data):
+    h = data.shape[-2]
+    return data[..., torch.arange(h - 1, -1, -1, device=data.device), :]
+
+
+class MosquitoModel(L.LightningModule):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+        self.tta = False
+
+        self.backbone = timm.create_model(self.config.backbone, pretrained=config.pretrained,
+                                          num_classes=config.num_classes, global_pool=config.global_pool)
+        self.head = None
+
+        self.valid_f1 = F1Score(task="multiclass", num_classes=config.num_classes, average="macro")
+        self.valid_precision = Precision(task="multiclass", num_classes=config.num_classes, average="macro")
+        self.valid_recall = Recall(task="multiclass", num_classes=config.num_classes, average="macro")
+
+        # self.save_hyperparameters()
+
+    def forward(self, x):
+        # Features
+        x = self.backbone(x)
+        # Classifier
+        x = self.head(x) if self.head is not None else x
+        # return logits
+        return x
+
+    def _get_preds_loss_metrics(self, batch, is_valid=False):
+        '''convenience function since train/valid/test steps are similar'''
+        x, y = batch["image"], batch["label"]  # (BS, C, H, W), (BS, C)
+
+        mixup_batch = False
+        # CutMix + MixUp
+        if (self.config.cutmix_prob is not None) and (self.config.mixup_prob is not None) and (not is_valid):
+            if np.random.random() > (self.config.mixup_prob + self.config.cutmix_prob) * 0.5:
+                if np.random.random() > 0.50:
+                    x, y_a, y_b, lam = cutmix_data(x, y, alpha=self.config.cutmix_alpha)
+                else:
+                    x, y_a, y_b, lam = mixup_data(x, y, alpha=self.config.mixup_alpha)
+                mixup_batch = True
+        else:
+            # MixUp
+            if (self.config.mixup_prob is not None) and (not is_valid):
+                if np.random.random() > self.config.mixup_prob:
+                    # y_a is the original target, y_b the permuted one, lam the mixing factor
+                    x, y_a, y_b, lam = mixup_data(x, y, alpha=self.config.mixup_alpha)
+                    mixup_batch = True
+            # CutMix
+            if (self.config.cutmix_prob is not None) and (not is_valid):
+                if np.random.random() > self.config.cutmix_prob:
+                    # y_a is the original target, y_b the permuted one, lam the mixing factor
+                    x, y_a, y_b, lam = cutmix_data(x, y, alpha=self.config.cutmix_alpha)
+                    mixup_batch = True
+
+        if mixup_batch:
+            logits = self(x)  # (BS, NC)
+            preds = torch.argmax(logits, dim=1)  # (BS)
+            loss = mixup_cross_entropy(logits, y_a, y_b, lam, label_smoothing=self.config.label_smoothing)
+            # loss = mixup_focal_loss(logits, y_a, y_b, lam, label_smoothing=self.config.label_smoothing)
+            y_true_mixed = (lam * y_a + (1 - lam) * y_b)  # (BS, C) Two class enabled with prob
+            y_true = torch.argmax(y_true_mixed, dim=1)  # (BS) # the highest one for training
+        else:
+            logits = self(x)  # (BS, NC)
+            if is_valid and (self.config.num_classes > 6):  # drop the extra background class from validation metrics
+                logits = logits[:, 0:6]
+                y = y[:, 0:6]
+            preds = torch.argmax(logits, dim=1)  # (BS)
+            loss = F.cross_entropy(logits, y, label_smoothing=self.config.label_smoothing)
+            # loss = focal_loss(logits, y, label_smoothing=self.config.label_smoothing)
+            y_true = torch.argmax(y, dim=1)  # (BS)
+
+        f1 = multiclass_f1_score(preds, y_true, num_classes=self.config.num_classes, average="macro")
+        recall = multiclass_recall(preds, y_true, num_classes=self.config.num_classes, average="macro")
+        precision = multiclass_precision(preds, y_true, num_classes=self.config.num_classes, average="macro")
+        # Accumulate predictions/ground truth
+        if is_valid:
+            self.valid_f1.update(preds, y_true)
+            self.valid_precision.update(preds, y_true)
+            self.valid_recall.update(preds, y_true)
+        return preds, loss, f1, recall, precision
+
+    def training_step(self, batch, batch_idx):
+        # training_step defines the train loop. It is independent of forward
+        _, loss, f1, recall, precision = self._get_preds_loss_metrics(batch)
+        # Log loss and metric
+        self.log('train_step_loss', loss, batch_size=self.config.batch_size)
+        self.log('train_step_f1', f1, batch_size=self.config.batch_size)
+        self.log('train_step_recall', recall, batch_size=self.config.batch_size)
+        self.log('train_step_precision', precision, batch_size=self.config.batch_size)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        preds, loss, f1, recall, precision = self._get_preds_loss_metrics(batch, is_valid=True)
+        # Log loss and metric
+        self.log('val_step_loss', loss, batch_size=self.config.val_batch_size)
+        self.log('val_step_f1', f1, batch_size=self.config.val_batch_size)
+        self.log('val_step_recall', recall, batch_size=self.config.val_batch_size)
+        self.log('val_step_precision', precision, batch_size=self.config.val_batch_size)
+        return preds
+
+    def on_validation_epoch_end(self):
+        self.log('val_f1', self.valid_f1.compute())
+        self.log('val_precision', self.valid_precision.compute())
+        self.log('val_recall', self.valid_recall.compute())
+        self.valid_f1.reset()
+        self.valid_precision.reset()
+        self.valid_recall.reset()
+
+    def predict_step(self, batch, batch_idx):
+        x = batch["image"]
+        logits = self(x)
+        logits = logits[:, 0:6]
+        if self.tta:
+            logits_tta = self(hflip(x))
+            logits_tta = logits_tta[:, 0:6]
+            logits = torch.mean(torch.stack([logits, logits_tta]), dim=0)
+        preds = torch.argmax(logits, dim=1)
+        return preds
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.lr0)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.config.epochs, eta_min=self.config.lrf)
+        return [optimizer], [scheduler]
diff --git a/mqt/training/__init__.py b/mqt/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mqt/training/mix.py b/mqt/training/mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..04d5111ec6fd518dc856f891c4a468c48e37754c
--- /dev/null
+++ b/mqt/training/mix.py
@@ -0,0 +1,62 @@
+import torch
+import numpy as np
+from torch.nn import functional as F
+
+
+def mixup_data(x, y, alpha=1.0, use_cuda=True):
+    if alpha > 0:
+        lam = np.random.beta(alpha, alpha)
+    else:
+        lam = 1
+
+    batch_size = x.size()[0]
+    if use_cuda:
+        index = torch.randperm(batch_size).cuda()
+    else:
+        index = torch.randperm(batch_size)
+        print("index", index)
+
+    mixed_x = lam * x + (1 - lam) * x[index, :]
+    y_a, y_b = y, y[index]
+
+    return mixed_x, y_a, y_b, lam
+
+
+def mixup_cross_entropy(pred, y_a, y_b, lam, label_smoothing=0.0):
+    return lam * F.cross_entropy(pred, y_a, label_smoothing=label_smoothing) + (1 - lam) * F.cross_entropy(pred, y_b, label_smoothing=label_smoothing)
+
+
+def rand_bbox(size, lam):
+    # Follows the official CutMix implementation (W = size[2], H = size[3]);
+    # the two are interchangeable for the square crops used here.
+    W = size[2]
+    H = size[3]
+    cut_rat = np.sqrt(1.0 - lam)
+    cut_w = np.int32(W * cut_rat)
+    cut_h = np.int32(H * cut_rat)
+
+    # uniform
+    cx = np.random.randint(W)
+    cy = np.random.randint(H)
+
+    bbx1 = np.clip(cx - cut_w // 2, 0, W)
+    bby1 = np.clip(cy - cut_h // 2, 0, H)
+    bbx2 = np.clip(cx + cut_w // 2, 0, W)
+    bby2 = np.clip(cy + cut_h // 2, 0, H)
+
+    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
+    return bbx1, bby1, bbx2, bby2, lam
+
+
+def cutmix_data(x, y, alpha=1.0, device="cuda"):
+    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
+
+    index = torch.randperm(x.size()[0]).to(device)
+
+    bbx1, bby1, bbx2, bby2, lam = rand_bbox(x.size(), lam)
+
+    mixed_x = x.clone()
+
+    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
+    y_a, y_b = y, y[index]
+
+    return mixed_x, y_a, y_b, lam
+
diff --git a/mqt/training/train_bb_detector.py b/mqt/training/train_bb_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a3872e39385443a0ccdb42d53c82ed882e41c8
--- /dev/null
+++ b/mqt/training/train_bb_detector.py
@@ -0,0 +1,353 @@
+import numpy as np
+import pandas as pd
+import glob
+import sys
+
+from my_models.utils.torch import SEEDS, save_config
+
+pd.set_option('display.max_rows', 100)
+pd.set_option('display.max_columns', 100)
+pd.set_option('display.max_colwidth', None)
+import ultralytics
+
+import os
+from tqdm.auto import tqdm
+import torch
+import cv2
+import wandb
+from PIL import Image
+
+from ultralytics import YOLO
+from ultralytics import RTDETR
+
+print("Python", sys.version)
+print("Numpy", np.__version__)
+print("Pandas", pd.__version__)
+print("Pytorch", torch.__version__)
+print("Ultralytics", ultralytics.__version__)
+print("WandB", wandb.__version__)
+
+NOISY = [
+    # Hard cases
+    "train_10156",  # Bad label, bbox out of image
+    "train_10160",  # Bad label, bbox out of image
+    "train_07228",  # Broken label width=1
+
+    # False positives
+    "train_05448",  # Bad label, wood instead of mosquito
+    "train_07395",  # Bad label, wood instead of mosquito
+    "train_06988",  # Bad label, skin instead of mosquito
+    "train_07454",  # Bad label, skin instead of mosquito
+    "train_10213",  # Bad label, skin instead of mosquito
+    "train_09453",  # Bad label, skin instead of mosquito
+    "train_08742",  # Bad label, skin instead of mosquito
+    "train_07478",  # Bad label, bottle instead of mosquito
+    "train_08904",  # Bad label, bottle instead of mosquito
+    "train_08090",  # Bad label, keyboard instead of mosquito
+    "train_08107",  # Bad label, belt instead of mosquito
+    "train_08201",  # Bad label, ground instead of mosquito
+    "train_09330",  # Bad label, ground instead of mosquito
+    "train_08553",  # Bad label, finger instead of mosquito
+    "train_06737",  # Bad label, finger instead of mosquito
+    "train_09065",  # Bad label, shape instead of mosquito
+    "train_10209",  # Bad label, feet instead of mosquito
+
+    # Discovered later, used in 1.7+
+    "train_07512",  # Bad label, wall instead of mosquito
+]
+
+SINGLE_CLASS = True
+FOLDS = 4
+VERSION = "v3"
+IMAGE_SIZE = 768
+DEVICE = "cpu"
+
+
+def generate_yolo_labels(train_pd, data_home):
+
+    # Generate YOLO labels
+    for idx, row in train_pd.iterrows():
+
+        xmin_, ymin_, xmax_, ymax_ = row["bbx_xtl"], row["bbx_ytl"], row["bbx_xbr"], row["bbx_ybr"]
+        w = row["img_w"]
+        h = row["img_h"]
+        label = row["label"]
+        if SINGLE_CLASS:
+            label = 0
+
+        if xmin_ < 0:
+            xmin_ = 0
+        if xmin_ > w:
+            print("w", row["uid"])
+            xmin_ = 0  # w
+        if ymin_ < 0:
+            ymin_ = 0
+        if ymin_ > h:
+            print("h", row["uid"])
+            ymin_ = 0  # h
+
+        if xmax_ < 0:
+            xmax_ = 0
+        if xmax_ > w:
+            print("w", row["uid"])
+            xmax_ = w
+        if ymax_ < 0:
+            ymax_ = 0
+        if ymax_ > h:
+            print("h", row["uid"])
+            ymax_ = h
+
+        filename = row["img_fName"].replace(".jpeg", ".txt")
+
+        bbw = xmax_ - xmin_
+        assert bbw > 0, "Bad width"
+        bbh = ymax_ - ymin_
+        assert bbh > 0, "Bad height"
+
+        # class x_center y_center width height
+        bbw = bbw / w
+        bbh = bbh / h
+        bbx = ((xmax_ + xmin_) / 2.0) / w
+        bby = ((ymax_ + ymin_) / 2.0) / h
+
+        with open(os.path.join(data_home, "labels", filename), "w") as f:
+            f.write(str(label) + " " + str(bbx) + " " + str(bby) + " " + str(bbw) + " " + str(bbh))
+
+
+def generate_yolo_list(train_home, train_pd, seed=SEEDS[0], exclude_noisy=NOISY, folds=FOLDS, single_class=SINGLE_CLASS):
+
+    # Generate YOLO lists
+    for fold_ in range(folds):
+        x_valid = train_pd[train_pd["sgkf_fold_s%d" % seed] == fold_]
+        x_train = train_pd[train_pd["sgkf_fold_s%d" % seed] != fold_]
+
+        x_valid_cleaned = x_valid[~(x_valid["uid"].isin(exclude_noisy))]
+        x_train_cleaned = x_train[~(x_train["uid"].isin(exclude_noisy))]
+        print("Hard noise removed", x_valid.shape, x_valid_cleaned.shape, x_train.shape, x_train_cleaned.shape)
+
+        x_valid_files = [os.path.join(".", c) for c in x_valid_cleaned["img_fName"].values]
+        x_train_files = [os.path.join(".", c) for c in x_train_cleaned["img_fName"].values]
+
+        # Add some background images
+        bgs = glob.glob(os.path.join(train_home, "background") + "/*.jpeg")
+        bgs = [c.replace(train_home, ".") for c in bgs]
+        x_train_files.extend(bgs)
+
+        with open(os.path.join(train_home, "train_%s_%d_%d.txt" % (VERSION, seed, fold_)), "w") as f:
+            for c in x_train_files:
+                f.write(c + "\n")
+
+        with open(os.path.join(train_home, "valid_%s_%d_%d.txt" % (VERSION, seed, fold_)), "w") as f:
+            for c in x_valid_files:
+                f.write(c + "\n")
+
+        # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
+        if single_class:
+            yaml = f'''
+    # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+    path: data  # dataset root dir
+    train: images/train_{VERSION}_{seed}_{fold_}.txt
+    val: images/valid_{VERSION}_{seed}_{fold_}.txt
+    test:  # test images (optional)
+
+    # Classes
+    names:
+      0: mosquito
+
+    '''
+        else:
+            yaml = f'''
+    # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+    path: data  # dataset root dir
+    train: images/train_{VERSION}_{seed}_{fold_}.txt
+    val: images/valid_{VERSION}_{seed}_{fold_}.txt
+    test:  # test images (optional)
+
+    # Classes
+    names:
+      0: aegypti
+      1: albopictus
+      2: anopheles
+      3: culex
+      4: culiseta
+      5: japonicus/koreicus
+
+    '''
+
+        with open("mqt_%s_%d_%d.yaml" % (VERSION, seed, fold_), "w") as f:
+            f.write(yaml)
+        # break
+
+
+def get_model(architecture, weights):
+    if architecture == "YOLO":
+        model = YOLO(weights, task='detect')
+    elif architecture == "RTDETR":
+        model = RTDETR(weights)
+    else:
+        raise Exception("Architecture not found", architecture)
+    return model
+
+
+def train_yolo(config, architecture="YOLO", model_backbone="yolov8n", image_size=IMAGE_SIZE,
+               models_folder="mosquito_yolo_models", seed=SEEDS[0]):
+
+    resume_fold = 0
+    for fold_ in range(FOLDS):
+        torch.cuda.empty_cache()
+        if fold_ < resume_fold:
+            continue
+
+        model_name = "%s_%s_%d_seed_%d_fold%d_%s" % (model_backbone, VERSION, image_size, seed, fold_, "1.4")
+        os.makedirs(os.path.join("./" + models_folder, model_name), exist_ok=True)
+        d = save_config(config, os.path.join("./" + models_folder, model_name, "config.json"))
+
+        resume = False
+        weights_file = os.path.join("./" + models_folder, model_name, "weights", "last.pt")
+        if os.path.exists(weights_file):
+            print("Weights available, trying to resume from", weights_file)
+            model = get_model(architecture, weights_file)
+            resume = True
+        else:
+            # New pretrained model
+            model = get_model(architecture, "%s.pt" % model_backbone)
+
+        results = model.train(data="mqt_%s_%d_%d.yaml" % (VERSION, seed, fold_), project=models_folder, name=model_name,
+                              exist_ok=True, resume=resume, **d)
+
+        wandb.finish()
+
+
+def export_to_openvino(architecture="YOLO", model_backbone="yolov8n", image_size=IMAGE_SIZE, folds=FOLDS,
+                       models_folder="mosquito_yolo_models", seed=SEEDS[0]):
+    for fold_ in range(folds):
+        model_name = "%s_%s_%d_seed_%d_fold%d_%s" % (model_backbone, VERSION, image_size, seed, fold_, "1.4")
+        model = get_model(architecture, "%s/%s/weights/best.pt" % (models_folder, model_name))
+        model.export(format='openvino', imgsz=image_size, half=False)
+        model = get_model(architecture, "%s/%s/weights/last.pt" % (models_folder, model_name))
+        model.export(format='openvino', imgsz=image_size, half=False)
+
+
+def predict_image(model, img, imgsz=IMAGE_SIZE, uid="no-uid"):
+    # Inference
+    outputs = model.predict(source=img, imgsz=imgsz, max_det=1, conf=0.00001, iou=0.7, augment=False,
+                            device=DEVICE, verbose=False)
+    # Extract BB
+    best_box = None
+    best_score = None
+    best_label = None
+    for r in outputs:
+        boxes = r.boxes.cpu().numpy()
+        for bbox in boxes:
+            box = bbox.xyxy[0]  # box coordinates in (x1, y1, x2, y2) format
+            score = bbox.conf[0]
+            label = bbox.cls[0]
+            # print(box, score, label)
+            best_box = box if best_box is None else best_box
+            best_score = score if best_score is None else best_score
+            best_label = label if best_label is None else best_label
+            if score > best_score:
+                best_score = score
+                best_box = box
+                best_label = label
+    xmin_, ymin_, xmax_, ymax_ = None, None, None, None
+    if best_box is not None:
+        xmin_, ymin_, xmax_, ymax_ = best_box[0], best_box[1], best_box[2], best_box[3]
+    return (uid, xmin_, ymin_, xmax_, ymax_, best_score, best_label)
+
+
+def compute_bb_oof(train_pd, train_home, architecture="YOLO", model_backbone="yolov8n", image_size=IMAGE_SIZE, folds=FOLDS,
+                   models_folder="mosquito_yolo_models", seed=SEEDS[0], openvino=True):
+    results_ = []
+    for fold_ in range(folds):
+        # fold_ = 0
+        model_name = "%s_%s_%d_seed_%d_fold%d_%s" % (model_backbone, VERSION, image_size, seed, fold_, "1.4")
+        if openvino:
+            filename = "%s/%s/weights/%s_openvino_model/" % (models_folder, model_name, "best")
+        else:
+            filename = "%s/%s/weights/%s.pt" % (models_folder, model_name, "best")
+
+        print("Loading", filename)
+        model = get_model(architecture, filename)
+
+        x_valid = train_pd[train_pd["sgkf_fold_s%d" % seed] == fold_]
+
+        for idx, row in tqdm(x_valid.iterrows(), total=len(x_valid)):
+            uid = row["uid"]
+            filename = row["img_fName"]
+            img = Image.open(os.path.join(train_home, filename))
+
+            # BGR format expected for np.array (like CV2)
+            img = np.array(img)
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+            uid, xmin_, ymin_, xmax_, ymax_, best_score, best_label = predict_image(model, img, uid=uid)
+            results_.append((uid, xmin_, ymin_, xmax_, ymax_, best_score, best_label))
+
+    results_pd = pd.DataFrame(results_, columns=["uid", "bbx_xtl", "bbx_ytl", "bbx_xbr", "bbx_ybr", "score", "label"])
+    MODEL = "oof_ov"
+    results_pd.rename(
+        columns={"bbx_xtl": "%s_bbx_xtl" % MODEL, "bbx_ytl": "%s_bbx_ytl" % MODEL, "bbx_xbr": "%s_bbx_xbr" % MODEL,
+                 "bbx_ybr": "%s_bbx_ybr" % MODEL, "score": "%s_score" % MODEL, "label": "%s_label" % MODEL},
+        inplace=True)
+
+    return results_pd
+
+
+def get_iou(bb1, bb2):
+    """
+    Calculate the Intersection over Union (IoU) of two bounding boxes.
+
+    Parameters
+    ----------
+    bb1 : dict
+        Keys: {'x1', 'x2', 'y1', 'y2'}
+        The (x1, y1) position is at the top left corner,
+        the (x2, y2) position is at the bottom right corner
+    bb2 : dict
+        Keys: {'x1', 'x2', 'y1', 'y2'}
+        The (x1, y1) position is at the top left corner,
+        the (x2, y2) position is at the bottom right corner
+
+    Returns
+    -------
+    float
+        in [0, 1]
+    """
+    assert bb1["x1"] < bb1["x2"]
+    assert bb1["y1"] < bb1["y2"]
+    assert bb2["x1"] < bb2["x2"]
+    assert bb2["y1"] < bb2["y2"]
+
+    # determine the coordinates of the intersection rectangle
+    x_left = max(bb1["x1"], bb2["x1"])
+    y_top = max(bb1["y1"], bb2["y1"])
+    x_right = min(bb1["x2"], bb2["x2"])
+    y_bottom = min(bb1["y2"], bb2["y2"])
+
+    if x_right < x_left or y_bottom < y_top:
+        return 0.0
+
+    # The intersection of two axis-aligned bounding boxes is always an
+    # axis-aligned bounding box
+    intersection_area = (x_right - x_left) * (y_bottom - y_top)
+
+    # compute the area of both AABBs
+    bb1_area = (bb1["x2"] - bb1["x1"]) * (bb1["y2"] - bb1["y1"])
+    bb2_area = (bb2["x2"] - bb2["x1"]) * (bb2["y2"] - bb2["y1"])
+
+    # compute the intersection over union by taking the intersection
+    # area and dividing it by the sum of prediction + ground-truth
+    # areas - the intersection area
+    iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
+    assert iou >= 0.0
+    assert iou <= 1.0
+    return iou
+
+
+def compute_iou(x):
+    MODEL = "oof_ov"
+    bb1 = {'x1':x["%s_bbx_xtl"%MODEL], 'x2':x["%s_bbx_xbr"%MODEL], 'y1':x["%s_bbx_ytl"%MODEL], 'y2':x["%s_bbx_ybr"%MODEL]}
+    bb2 = {'x1':x["bbx_xtl"], 'x2':x["bbx_xbr"], 'y1':x["bbx_ytl"], 'y2':x["bbx_ybr"]}
+    return get_iou(bb1, bb2)
+
diff --git a/mqt/training/train_classifier.py b/mqt/training/train_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..31fb3247432144e013d91fe76888a90d71c392b0
--- /dev/null
+++ b/mqt/training/train_classifier.py
@@ -0,0 +1,381 @@
+import torch
+import os
+import random
+import numpy as np
+import pandas as pd
+from PIL import Image
+import cv2
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+import torch.utils.data as data
+import wandb
+
+import pytorch_lightning as L
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.loggers import CSVLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import StochasticWeightAveraging
+from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning import seed_everything
+
+# Apache2 license
+from nemo.collections.common.callbacks.ema import EMA
+from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint
+
+from mqt.models.classifier import MosquitoModel
+from my_models.utils.torch import save_config, Config
+
+
+def resize(new_size, conf, p=1.0):
+    if conf.ar is None:
+        return A.Compose([
+            A.Resize(new_size, new_size, interpolation=cv2.INTER_LINEAR, p=1.0, always_apply=True),
+        ], p=p)
+    elif conf.ar == 1.0:
+        return A.Compose([
+            A.LongestMaxSize(max_size=new_size, interpolation=cv2.INTER_LINEAR, p=1.0, always_apply=True),
+            A.PadIfNeeded(min_height=new_size, min_width=new_size, border_mode=cv2.BORDER_CONSTANT,
+                          value=(114, 114, 114), p=1.0, always_apply=True),
+        ], p=p)
+    elif conf.ar == 0.0:
+        return A.Compose([
+            A.PadIfNeeded(min_height=new_size, min_width=new_size, border_mode=cv2.BORDER_CONSTANT,
+                          value=(114, 114, 114), p=1.0, always_apply=True),
+            A.LongestMaxSize(max_size=new_size, interpolation=cv2.INTER_LINEAR, p=1.0, always_apply=True),
+            A.PadIfNeeded(min_height=new_size, min_width=new_size, border_mode=cv2.BORDER_CONSTANT,
+                          value=(114, 114, 114), p=1.0, always_apply=True),
+        ], p=p)
+
+
+def random_crop(new_size, p=1.0):
+    return A.Compose([
+        A.NoOp(p=1.0, always_apply=True),
+        # A.PadIfNeeded(min_height=new_size, min_width=new_size, border_mode=cv2.BORDER_CONSTANT, value=(114, 114, 114), p=1.0, always_apply=True),
+        # A.RandomCrop(new_size, new_size, p=1.0, always_apply=True),
+    ], p=p)
+
+
+def normalize(mean, std, max_pixel, p=1.0):
+    return A.Compose([
+
+        A.Normalize(mean=mean, std=std, max_pixel_value=max_pixel, p=1.0, always_apply=True),
+        ToTensorV2(p=1.0, always_apply=True)
+
+    ], p=p)
+
+
+class Rotate90_270(A.RandomRotate90):
+    def get_params(self):
+        return {"factor": random.choice([1, 3])}
+
+
+class Dataset(torch.utils.data.Dataset):
+
+    def __init__(self, data_path, df, conf, subset='train', preprocess=None, augment=None, prepare=None):
+        self.data_path = data_path
+        self.df = df
+        self.conf = conf
+        self.subset = subset
+        self.preprocess = preprocess
+        self.augment = augment
+        self.prepare = prepare
+
+        self.generate_background = random_crop(self.conf.imgsz * 2, p=1.0)
+
+    def read_image_(self, record_id, bg_path):
+        if "background_" in record_id:
+            img = self.generate_background(image=np.array(Image.open(bg_path)))["image"]
+        else:
+            img = np.array(Image.open(os.path.join(self.data_path, record_id + ".png")))
+        return img
+
+    def read_image(self, row):
+        record_id = row["uid"]
+        return self.read_image_(record_id, bg_path=row["img_fName"])
+
+    def get_data(self, row, idx):
+        img = self.read_image(row)
+
+        sample = {
+            'image': img,
+            'weight': 1,
+            "uid": row["uid"],
+        }
+
+        # Optional preprocessing on RGB image (float)
+        if self.preprocess:
+            tmp = self.preprocess(image=sample['image'])
+            sample['image'] = tmp["image"]  # Apply on full image
+
+        # Optional augmentation on RGB image (float)
+        if self.augment:
+            tmp = self.augment(image=sample['image'])
+            sample['image'] = tmp["image"]  # Apply on full image
+
+        # Mandatory to feed model (normalization, convert to CHW)
+        if self.prepare:
+            tmp = self.prepare(image=sample['image'])
+            sample['image'] = tmp["image"]  # Apply on full image
+
+        if self.subset != "test":
+            class_ = int(row["label"])
+            label = np.zeros(self.conf.num_classes, dtype=np.float32)
+            label[class_] = 1.
+            sample["label"] = label
+
+        return sample
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+        row = self.df.iloc[idx]
+        sample = self.get_data(row, idx)
+        return sample
+
+
+def hardm_augmentation_train(conf, p=1.0):
+    return A.Compose([
+
+        # Flips/Rotate
+        A.OneOf([
+            A.HorizontalFlip(p=0.25),
+            A.VerticalFlip(p=0.25),
+            Rotate90_270(p=0.50)],
+            p=0.75),
+
+        # Random crop resize
+        A.RandomResizedCrop(conf.image_size, conf.image_size, scale=(0.75, 1.25), ratio=(3. / 4., 4. / 3.),
+                            interpolation=cv2.INTER_LINEAR, p=0.5),
+
+        # ShiftScaleRotate
+        A.ShiftScaleRotate(shift_limit=0.10, scale_limit=0.2, rotate_limit=30, p=0.75),
+
+        # Colors/Channels
+        A.OneOf([
+            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.2, p=0.3),
+            A.RandomGamma(gamma_limit=(80, 120), p=0.2),
+            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
+            A.CLAHE(p=0.25),
+        ], p=0.5),
+
+        # Misc
+        A.OneOf([
+            A.OpticalDistortion(distort_limit=0.2, shift_limit=0.05, p=0.25),
+            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.25),
+            A.JpegCompression(quality_lower=8, quality_upper=70, p=0.25),
+            A.ToGray(p=0.025),
+        ], p=0.5),
+
+        # Blur/Noise
+        A.OneOf([
+            A.GaussianBlur(blur_limit=(3, 5), p=0.25),
+            A.MotionBlur(blur_limit=(3, 5), p=0.25),
+            A.GaussNoise(var_limit=(5.0, 30.0), p=0.25),
+            A.CoarseDropout(max_holes=1, max_height=72, max_width=72, p=0.25),
+        ], p=0.5),
+
+    ], p=p)
+
+
+# Noisy bounding boxes from mosquito detector OOF
+BAD_YOLO8N_768_OOF = [
+    "train_09648",
+    "train_09387",
+    "train_08941",
+    "train_01447",
+]
+
+
+def train_cls(train_boxes_home, config, train_pd, train_background_pd=None, external_pd=None,
+              wandb_project=None, models_home="./mosquito_models", full_train=False,
+              exclude_noisy=BAD_YOLO8N_768_OOF):
+
+    seed_everything(config.seed, workers=True)
+
+    preprocess_image = resize(config.imgsz, config, p=1.0) if config.imgsz is not None else None
+    image_augmentation_train = hardm_augmentation_train(config, p=1.0)
+    prepare_feed = normalize(config.IMG_MEAN, config.IMG_STD, config.max_pixel, p=1.0)
+
+    if full_train:
+        fold_ = 99
+        x_train = train_pd.copy()
+        x_train_cleaned = x_train[~(x_train["uid"].isin(exclude_noisy))]
+        print("Hard noise removed", x_train.shape, x_train_cleaned.shape)
+
+        if train_background_pd is not None:
+            x_train_cleaned = pd.concat([x_train_cleaned, train_background_pd], ignore_index=True)
+            print("Background class added to x_train_cleaned", x_train_cleaned.shape)
+
+        if external_pd is not None:
+            x_train_cleaned = pd.concat([x_train_cleaned, external_pd], ignore_index=True)
+            print("External added to x_train_cleaned", x_train_cleaned.shape)
+
+        train_dataset = Dataset(train_boxes_home, x_train_cleaned, config, subset="train", preprocess=preprocess_image,
+                                augment=image_augmentation_train, prepare=prepare_feed)
+        train_sampler = None # get_sampler(config.sampler, ds=train_dataset)
+        # print("Train sampler", train_sampler)
+        train_batch_sampler = None  # get_batch_sampler(config, ds=train_dataset)
+        # print("Train batch sampler", train_batch_sampler)
+
+        train_dataloader = data.DataLoader(train_dataset,
+                                           batch_size=config.batch_size if train_batch_sampler is None else 1,
+                                           sampler=train_sampler, batch_sampler=train_batch_sampler, drop_last=False,
+                                           num_workers=config.num_workers,
+                                           shuffle=(train_sampler is None) and (train_batch_sampler is None),
+                                           pin_memory=True)
+
+        model = MosquitoModel(config)
+
+        model_path = "classifier_%s_%s_%s/seed%s/%s/%s" % (
+        config.imgsz, config.backbone, config.version, config.seed, "fold%d" % fold_, "stage1")
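+        # -> classifier_<imgsz>_<backbone>_<version>/seed<seed>/fold99/stage1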
+        default_root_dir = os.path.join(models_home, model_path)
+        os.makedirs(default_root_dir, exist_ok=True)
+        save_config(config, os.path.join(default_root_dir, "config.json"))
+
+        logger_wandb = None
+        if wandb_project is not None:
+            logger_wandb = WandbLogger(project=wandb_project, name=model_path.replace("/", "_"))
+        logger_csv = CSVLogger("./logs", name=model_path.replace("/", "_"))
+
+        checkpoint_callback = ModelCheckpoint(dirpath=default_root_dir, monitor="train_step_f1", mode='max',
+                                              save_weights_only=True, filename='best_{epoch}-{train_step_f1:.3f}',
+                                              save_top_k=config.save_top_k, save_last=True)
+        lr_monitor = LearningRateMonitor(logging_interval='epoch')
+        callbacks = [lr_monitor]
+
+        if config.ema is not None:
+            print("EMA enabled:", config.ema)
+            ema = EMA(config.ema)
+            callbacks.extend([ema])
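+            # Use NeMo's checkpoint callback so an "-EMA" weight copy is saved with each checkpoint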
+            checkpoint_callback = NeMoModelCheckpoint(dirpath=default_root_dir, monitor="train_step_f1", mode='max',
+                                                      filename='best_{epoch}-{train_step_f1:.4f}',
+                                                      save_top_k=config.save_top_k, save_last=True, save_weights_only=True,
+                                                      save_nemo_on_train_end=False)
+
+        callbacks.extend([checkpoint_callback])
+
+        trainer = L.Trainer(
+            default_root_dir=default_root_dir,
+            max_epochs=config.epochs,
+            accelerator=config.device,
+            accumulate_grad_batches=config.accumulate_grad_batches,
+            gradient_clip_val=config.gradient_clip_val,
+            deterministic=config.deterministic,
+            precision=config.precision,
+            logger=logger_wandb if logger_wandb is not None else logger_csv,
+            callbacks=callbacks,
+            check_val_every_n_epoch=1,
+            enable_progress_bar=True,
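+            # No hold-out data in full-train mode: validation is disabled below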
+            limit_val_batches=0,
+            num_sanity_val_steps=0,
+        )
+
+        trainer.fit(model, train_dataloader, None)
+
+        if logger_wandb is not None:
+            wandb.finish()
+
+    else:
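+        # Bump resume_fold to restart an interrupted cross-validation run from a later fold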
+        resume_fold = 0
+        for fold_ in range(config.folds):
+            if fold_ < resume_fold:
+                continue
+
+            x_valid = train_pd[train_pd["sgkf_fold_s%d" % config.folds_seed] == fold_]
+            x_train = train_pd[train_pd["sgkf_fold_s%d" % config.folds_seed] != fold_]
+            x_valid_cleaned = x_valid[~(x_valid["uid"].isin(exclude_noisy))]
+            x_train_cleaned = x_train[~(x_train["uid"].isin(exclude_noisy))]
+            print("Hard noise removed", x_valid.shape, x_valid_cleaned.shape, x_train.shape, x_train_cleaned.shape)
+
+            if train_background_pd is not None:
+                x_train_cleaned = pd.concat([x_train_cleaned, train_background_pd], ignore_index=True)
+                print("Background class added to x_train_cleaned", x_valid_cleaned.shape, x_train_cleaned.shape)
+
+            if external_pd is not None:
+                x_train_ext = external_pd[external_pd["sgkf_fold_s%d" % config.folds_seed] != fold_]
+                x_train_cleaned = pd.concat([x_train_cleaned, x_train_ext], ignore_index=True)
+                print("External added to x_train_cleaned", x_train_cleaned.shape)
+
+            train_dataset = Dataset(train_boxes_home, x_train_cleaned, config, subset="train", preprocess=preprocess_image,
+                                    augment=image_augmentation_train, prepare=prepare_feed)
+            valid_dataset = Dataset(train_boxes_home, x_valid_cleaned, config, subset="valid", preprocess=preprocess_image,
+                                    augment=None, prepare=prepare_feed)
+
+            train_sampler = None  # get_sampler(config.sampler, ds=train_dataset)
+            # print("Train sampler", train_sampler)
+
+            train_batch_sampler = None  # get_batch_sampler(config, ds=train_dataset)
+            # print("Train batch sampler", train_batch_sampler)
+
+            train_dataloader = data.DataLoader(train_dataset,
+                                               batch_size=config.batch_size if train_batch_sampler is None else 1,
+                                               sampler=train_sampler, batch_sampler=train_batch_sampler, drop_last=False,
+                                               num_workers=config.num_workers,
+                                               shuffle=(train_sampler is None) and (train_batch_sampler is None),
+                                               pin_memory=True)
+            valid_dataloader = data.DataLoader(valid_dataset, batch_size=config.val_batch_size, drop_last=False,
+                                               num_workers=config.num_workers, pin_memory=True)
+
+            model = MosquitoModel(config)
+
+            model_path = "classifier_%s_%s_%s/seed%s/%s/%s" % (
+            config.imgsz, config.backbone, config.version, config.seed, "fold%d" % fold_, "stage1")
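+            # -> classifier_<imgsz>_<backbone>_<version>/seed<seed>/fold<fold>/stage1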
+            default_root_dir = os.path.join(models_home, model_path)
+            os.makedirs(default_root_dir, exist_ok=True)
+            save_config(config, os.path.join(default_root_dir, "config.json"))
+
+            logger_wandb = None
+            if wandb_project is not None:
+                logger_wandb = WandbLogger(project=wandb_project, name=model_path.replace("/", "_"))
+            logger_csv = CSVLogger("./logs", name=model_path.replace("/", "_"))
+
+            lr_monitor = LearningRateMonitor(logging_interval='epoch')
+            checkpoint_callback = ModelCheckpoint(dirpath=default_root_dir, monitor="val_f1", mode='max',
+                                                  filename='best_{epoch}-{val_f1:.4f}', save_top_k=config.save_top_k,
+                                                  save_last=True, save_weights_only=True)
+
+            callbacks = [lr_monitor]
+
+            if config.ema is not None:
+                print("EMA enabled:", config.ema)
+                ema = EMA(config.ema)
+                callbacks.extend([ema])
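+                # NeMo's checkpoint callback also saves an "-EMA" weight copy per checkpoint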
+                checkpoint_callback = NeMoModelCheckpoint(dirpath=default_root_dir, monitor="val_f1", mode='max',
+                                                          filename='best_{epoch}-{val_f1:.4f}',
+                                                          save_top_k=config.save_top_k, save_last=True,
+                                                          save_weights_only=True, save_nemo_on_train_end=False)
+
+            callbacks.extend([checkpoint_callback])
+
+            if config.swa_lrs is not None:
+                print("SWA enabled")
+                swa = StochasticWeightAveraging(swa_lrs=config.swa_lrs)
+                callbacks.extend([swa])
+            if config.pruning is not None:
+                print("Pruning enabled")
+                pruning = ModelPruning("l1_unstructured", amount=config.pruning)
+                callbacks.extend([pruning])
+
+            print()
+            trainer = L.Trainer(
+                default_root_dir=default_root_dir,
+                max_epochs=config.epochs,
+                accelerator=config.device,
+                accumulate_grad_batches=config.accumulate_grad_batches,
+                gradient_clip_val=config.gradient_clip_val,
+                deterministic=config.deterministic,
+                precision=config.precision,
+                logger=logger_wandb if logger_wandb is not None else logger_csv,
+                callbacks=callbacks,
+                check_val_every_n_epoch=1,
+                enable_progress_bar=True,
+            )
+
+            trainer.fit(model, train_dataloader, valid_dataloader)
+
+            if logger_wandb is not None:
+                wandb.finish()
+
+            # break
diff --git a/nemo/__init__.py b/nemo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9fedbd1cc691ff5b16a78420965131787f37dd
--- /dev/null
+++ b/nemo/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.package_info import (
+    __contact_emails__,
+    __contact_names__,
+    __description__,
+    __download_url__,
+    __homepage__,
+    __keywords__,
+    __license__,
+    __package_name__,
+    __repository_url__,
+    __shortversion__,
+    __version__,
+)
diff --git a/nemo/collections/__init__.py b/nemo/collections/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nemo/collections/common/__init__.py b/nemo/collections/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nemo/collections/common/callbacks/__init__.py b/nemo/collections/common/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf495d946960af72276cde51d1f546385356a1d
--- /dev/null
+++ b/nemo/collections/common/callbacks/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback
+from nemo.collections.common.callbacks.ema import EMA
diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a6c011c38dfaf292108d9ca903b2f133b991f92
--- /dev/null
+++ b/nemo/collections/common/callbacks/callbacks.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_only
+
+# from sacrebleu import corpus_bleu
+
+
+class LogEpochTimeCallback(Callback):
+    """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log
+    """
+
+    @rank_zero_only
+    def on_train_epoch_start(self, trainer, pl_module):
+        self.epoch_start = time.time()
+
+    @rank_zero_only
+    def on_train_epoch_end(self, trainer, pl_module):
+        curr_time = time.time()
+        duration = curr_time - self.epoch_start
+        trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
+
+
+# class MachineTranslationLogEvalCallback(Callback):
+#     def _on_eval_end(self, trainer, pl_module, mode):
+#         counts = np.array(self._non_pad_tokens)
+#         eval_loss = np.sum(np.array(self._losses) * counts) / np.sum(counts)
+#         sacre_bleu = corpus_bleu(self._translations, [self._ground_truths], tokenize="13a")
+#         print(f"{mode} results for process with global rank {pl_module.global_rank}".upper())
+#         for i in range(pl_module.num_examples[mode]):
+#             print('\u0332'.join(f"EXAMPLE {i}:"))  # Underline output
+#             sent_id = np.random.randint(len(self._translations))
+#             print(f"Ground truth: {self._ground_truths[sent_id]}\n")
+#             print(f"Translation: {self._translations[sent_id]}\n")
+#             print()
+#         print("-" * 50)
+#         print(f"loss: {eval_loss:.3f}")
+#         print(f"SacreBLEU: {sacre_bleu}")
+#         print("-" * 50)
+
+#     @rank_zero_only
+#     def on_test_end(self, trainer, pl_module):
+#         self._on_eval_end(trainer, pl_module, "test")
+
+#     @rank_zero_only
+#     def on_validation_end(self, trainer, pl_module):
+#         self._on_eval_end(trainer, pl_module, "val")
+
+#     @rank_zero_only
+#     def on_sanity_check_end(self, trainer, pl_module):
+#         self._on_eval_end(trainer, pl_module, "val")
+
+#     def _on_eval_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, mode):
+#         self._translations.extend(outputs['translations'])
+#         self._ground_truths.extend(outputs['ground_truths'])
+#         self._non_pad_tokens.append(outputs['num_non_pad_tokens'])
+#         self._losses.append(outputs[f'{mode}_loss'])
+
+#     @rank_zero_only
+#     def on_test_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
+#         self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'test')
+
+#     @rank_zero_only
+#     def on_validation_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
+#         self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'val')
+
+#     def _on_eval_start(self, trainer, pl_module):
+#         self._translations = []
+#         self._ground_truths = []
+#         self._losses = []
+#         self._non_pad_tokens = []
+
+#     @rank_zero_only
+#     def on_test_start(self, trainer, pl_module):
+#         self._on_eval_start(trainer, pl_module)
+
+#     @rank_zero_only
+#     def on_validation_start(self, trainer, pl_module):
+#         self._on_eval_start(trainer, pl_module)
+
+#     @rank_zero_only
+#     def on_sanity_check_start(self, trainer, pl_module):
+#         self._on_eval_start(trainer, pl_module)
diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec53b61e17c0d8de4d90e6afe14dd2ee6535a8b9
--- /dev/null
+++ b/nemo/collections/common/callbacks/ema.py
@@ -0,0 +1,347 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import copy
+import os
+import threading
+from typing import Any, Dict, Iterable
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning import Callback
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
+
+
+class EMA(Callback):
+    """
+    Implements Exponential Moving Averaging (EMA).
+
+    When training a model, this callback will maintain moving averages of the trained parameters.
+    When evaluating, we use the moving averages copy of the trained parameters.
+    When saving, we save an additional set of parameters with the prefix `ema`.
+
+    Args:
+        decay: The exponential decay used when calculating the moving average. Has to be between 0-1.
+        validate_original_weights: Validate the original weights, as opposed to the EMA weights.
+        every_n_steps: Apply EMA every N steps.
+        cpu_offload: Offload weights to CPU.
+    """
+
+    def __init__(
+        self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False,
+    ):
+        if not (0 <= decay <= 1):
+            raise MisconfigurationException("EMA decay value must be between 0 and 1")
+        self.decay = decay
+        self.validate_original_weights = validate_original_weights
+        self.every_n_steps = every_n_steps
+        self.cpu_offload = cpu_offload
+
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        device = pl_module.device if not self.cpu_offload else torch.device('cpu')
+        trainer.optimizers = [
+            EMAOptimizer(
+                optim,
+                device=device,
+                decay=self.decay,
+                every_n_steps=self.every_n_steps,
+                current_step=trainer.global_step,
+            )
+            for optim in trainer.optimizers
+            if not isinstance(optim, EMAOptimizer)
+        ]
+
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        if self._should_validate_ema_weights(trainer):
+            self.swap_model_weights(trainer)
+
+    def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool:
+        return not self.validate_original_weights and self._ema_initialized(trainer)
+
+    def _ema_initialized(self, trainer: "pl.Trainer") -> bool:
+        return any(isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers)
+
+    def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False):
+        for optimizer in trainer.optimizers:
+            assert isinstance(optimizer, EMAOptimizer)
+            optimizer.switch_main_parameter_weights(saving_ema_model)
+
+    @contextlib.contextmanager
+    def save_ema_model(self, trainer: "pl.Trainer"):
+        """
+        Saves an EMA copy of the model + EMA optimizer states for resume.
+        """
+        self.swap_model_weights(trainer, saving_ema_model=True)
+        try:
+            yield
+        finally:
+            self.swap_model_weights(trainer, saving_ema_model=False)
+
+    @contextlib.contextmanager
+    def save_original_optimizer_state(self, trainer: "pl.Trainer"):
+        for optimizer in trainer.optimizers:
+            assert isinstance(optimizer, EMAOptimizer)
+            optimizer.save_original_optimizer_state = True
+        try:
+            yield
+        finally:
+            for optimizer in trainer.optimizers:
+                optimizer.save_original_optimizer_state = False
+
+    def on_load_checkpoint(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
+    ) -> None:
+        checkpoint_callback = trainer.checkpoint_callback
+
+        # use the connector as NeMo calls the connector directly in the exp_manager when restoring.
+        connector = trainer._checkpoint_connector
+        # Replace connector._ckpt_path with below to avoid calling into lightning's protected API
+        ckpt_path = trainer.ckpt_path
+
+        if ckpt_path and checkpoint_callback is not None and 'NeMo' in type(checkpoint_callback).__name__:
+            ext = checkpoint_callback.FILE_EXTENSION
+            if ckpt_path.endswith(f'-EMA{ext}'):
+                rank_zero_info(
+                    "loading EMA based weights. "
+                    "The callback will treat the loaded EMA weights as the main weights"
+                    " and create a new EMA copy when training."
+                )
+                return
+            ema_path = ckpt_path.replace(ext, f'-EMA{ext}')
+            if os.path.exists(ema_path):
+                ema_state_dict = torch.load(ema_path, map_location=torch.device('cpu'))
+
+                checkpoint['optimizer_states'] = ema_state_dict['optimizer_states']
+                del ema_state_dict
+                rank_zero_info("EMA state has been restored.")
+            else:
+                raise MisconfigurationException(
+                    "Unable to find the associated EMA weights when re-loading, "
+                    f"training will start with new EMA weights. Expected them to be at: {ema_path}",
+                )
+
+
+@torch.no_grad()
+def ema_update(ema_model_tuple, current_model_tuple, decay):
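+    # Fused in-place update over all tensors: ema = decay * ema + (1 - decay) * current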
+    torch._foreach_mul_(ema_model_tuple, decay)
+    torch._foreach_add_(
+        ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
+    )
+
+
+def run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None):
+    if pre_sync_stream is not None:
+        pre_sync_stream.synchronize()
+
+    ema_update(ema_model_tuple, current_model_tuple, decay)
+
+
+class EMAOptimizer(torch.optim.Optimizer):
+    r"""
+    EMAOptimizer is a wrapper for torch.optim.Optimizer that computes
+    Exponential Moving Average of parameters registered in the optimizer.
+
+    EMA parameters are automatically updated after every step of the optimizer
+    with the following formula:
+
+        ema_weight = decay * ema_weight + (1 - decay) * training_weight
+
+    To access EMA parameters, use ``swap_ema_weights()`` context manager to
+    perform a temporary in-place swap of regular parameters with EMA
+    parameters.
+
+    Notes:
+        - EMAOptimizer is not compatible with APEX AMP O2.
+
+    Args:
+        optimizer (torch.optim.Optimizer): optimizer to wrap
+        device (torch.device): device for EMA parameters
+        decay (float): decay factor
+
+    Returns:
+        returns an instance of torch.optim.Optimizer that computes EMA of
+        parameters
+
+    Example:
+        model = Model().to(device)
+        opt = torch.optim.Adam(model.parameters())
+
+        opt = EMAOptimizer(opt, device, 0.9999)
+
+        for epoch in range(epochs):
+            training_loop(model, opt)
+
+            regular_eval_accuracy = evaluate(model)
+
+            with opt.swap_ema_weights():
+                ema_eval_accuracy = evaluate(model)
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        decay: float = 0.9999,
+        every_n_steps: int = 1,
+        current_step: int = 0,
+    ):
+        self.optimizer = optimizer
+        self.decay = decay
+        self.device = device
+        self.current_step = current_step
+        self.every_n_steps = every_n_steps
+        self.save_original_optimizer_state = False
+
+        self.first_iteration = True
+        self.rebuild_ema_params = True
+        self.stream = None
+        self.thread = None
+
+        self.ema_params = ()
+        self.in_saving_ema_model_context = False
+
+    def all_parameters(self) -> Iterable[torch.Tensor]:
+        return (param for group in self.param_groups for param in group['params'])
+
+    def step(self, closure=None, **kwargs):
+        self.join()
+
+        if self.first_iteration:
+            if any(p.is_cuda for p in self.all_parameters()):
+                self.stream = torch.cuda.Stream()
+
+            self.first_iteration = False
+
+        if self.rebuild_ema_params:
+            opt_params = list(self.all_parameters())
+
+            self.ema_params += tuple(
+                copy.deepcopy(param.data.detach()).to(self.device) for param in opt_params[len(self.ema_params) :]
+            )
+            self.rebuild_ema_params = False
+
+        loss = self.optimizer.step(closure)
+
+        if self._should_update_at_step():
+            self.update()
+        self.current_step += 1
+        return loss
+
+    def _should_update_at_step(self) -> bool:
+        return self.current_step % self.every_n_steps == 0
+
+    @torch.no_grad()
+    def update(self):
+        if self.stream is not None:
+            self.stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(self.stream):
+            current_model_state = tuple(
+                param.data.to(self.device, non_blocking=True) for param in self.all_parameters()
+            )
+
+            if self.device.type == 'cuda':
+                ema_update(self.ema_params, current_model_state, self.decay)
+
+        if self.device.type == 'cpu':
+            self.thread = threading.Thread(
+                target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,),
+            )
+            self.thread.start()
+
+    def swap_tensors(self, tensor1, tensor2):
+        tmp = torch.empty_like(tensor1)
+        tmp.copy_(tensor1)
+        tensor1.copy_(tensor2)
+        tensor2.copy_(tmp)
+
+    def switch_main_parameter_weights(self, saving_ema_model: bool = False):
+        self.join()
+        self.in_saving_ema_model_context = saving_ema_model
+        for param, ema_param in zip(self.all_parameters(), self.ema_params):
+            self.swap_tensors(param.data, ema_param)
+
+    @contextlib.contextmanager
+    def swap_ema_weights(self, enabled: bool = True):
+        r"""
+        A context manager to in-place swap regular parameters with EMA
+        parameters.
+        It swaps back to the original regular parameters on context manager
+        exit.
+
+        Args:
+            enabled (bool): whether the swap should be performed
+        """
+
+        if enabled:
+            self.switch_main_parameter_weights()
+        try:
+            yield
+        finally:
+            if enabled:
+                self.switch_main_parameter_weights()
+
+    def __getattr__(self, name):
+        return getattr(self.optimizer, name)
+
+    def join(self):
+        if self.stream is not None:
+            self.stream.synchronize()
+
+        if self.thread is not None:
+            self.thread.join()
+
+    def state_dict(self):
+        self.join()
+
+        if self.save_original_optimizer_state:
+            return self.optimizer.state_dict()
+
+        # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights
+        ema_params = self.ema_params if not self.in_saving_ema_model_context else list(self.all_parameters())
+        state_dict = {
+            'opt': self.optimizer.state_dict(),
+            'ema': ema_params,
+            'current_step': self.current_step,
+            'decay': self.decay,
+            'every_n_steps': self.every_n_steps,
+        }
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self.join()
+
+        self.optimizer.load_state_dict(state_dict['opt'])
+        self.ema_params = tuple(param.to(self.device) for param in copy.deepcopy(state_dict['ema']))
+        self.current_step = state_dict['current_step']
+        self.decay = state_dict['decay']
+        self.every_n_steps = state_dict['every_n_steps']
+        self.rebuild_ema_params = False
+
+    def add_param_group(self, param_group):
+        self.optimizer.add_param_group(param_group)
+        self.rebuild_ema_params = True
diff --git a/nemo/constants.py b/nemo/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2afbf882277abcb8e183071fe14bdfffe07335e
--- /dev/null
+++ b/nemo/constants.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+NEMO_ENV_VARNAME_ENABLE_COLORING = "NEMO_ENABLE_COLORING"
+NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR = "NEMO_REDIRECT_LOGS_TO_STDERR"
+NEMO_ENV_VARNAME_TESTING = "NEMO_TESTING"  # Set to True to enable nemo.utils.logging's debug mode
+NEMO_ENV_VARNAME_VERSION = "NEMO_EXPM_VERSION"  # Used for nemo.utils.exp_manager versioning
+NEMO_ENV_CACHE_DIR = "NEMO_CACHE_DIR"  # Used to change default nemo cache directory
+NEMO_ENV_DATA_STORE_CACHE_DIR = "NEMO_DATA_STORE_CACHE_DIR"  # Used to change default nemo data store cache directory
+NEMO_ENV_DATA_STORE_CACHE_SHARED = "NEMO_DATA_STORE_CACHE_SHARED"  # Shared among nodes (1) or not shared (0)
diff --git a/nemo/package_info.py b/nemo/package_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..981bb1f6b090f992f090124a6f932d9b1b2359dd
--- /dev/null
+++ b/nemo/package_info.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+MAJOR = 1
+MINOR = 21
+PATCH = 0
+PRE_RELEASE = 'rc0'
+
+# Use the following formatting: (major, minor, patch, pre-release)
+VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
+
+__shortversion__ = '.'.join(map(str, VERSION[:3]))
+__version__ = '.'.join(map(str, VERSION[:3])) + ''.join(VERSION[3:])
+
+__package_name__ = 'nemo_toolkit'
+__contact_names__ = 'NVIDIA'
+__contact_emails__ = 'nemo-toolkit@nvidia.com'
+__homepage__ = 'https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/'
+__repository_url__ = 'https://github.com/nvidia/nemo'
+__download_url__ = 'https://github.com/NVIDIA/NeMo/releases'
+__description__ = 'NeMo - a toolkit for Conversational AI'
+__license__ = 'Apache2'
+__keywords__ = 'deep learning, machine learning, gpu, NLP, NeMo, nvidia, pytorch, torch, tts, speech, language'
diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c424f72e411a4c4739e0d76c4b41d1e06cb89c9
--- /dev/null
+++ b/nemo/utils/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nemo.utils.app_state import AppState
+from nemo.utils.cast_utils import (
+    CastToFloat,
+    CastToFloatAll,
+    avoid_bfloat16_autocast_context,
+    avoid_float16_autocast_context,
+    cast_all,
+    cast_tensor,
+)
+from nemo.utils.nemo_logging import Logger as _Logger
+from nemo.utils.nemo_logging import LogMode as logging_mode
+
+logging = _Logger()
+try:
+    from nemo.utils.lightning_logger_patch import add_memory_handlers_to_pl_logger
+
+    add_memory_handlers_to_pl_logger()
+except ModuleNotFoundError:
+    pass
diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..d06e1ac32e3698c230aec75105c516537d53cf04
--- /dev/null
+++ b/nemo/utils/app_state.py
@@ -0,0 +1,548 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from threading import Lock
+from typing import Dict, Optional
+
+from nemo.utils.metaclasses import Singleton
+
+
+@dataclass()
+class ModelMetadataRegistry:
+    guid: str
+    gidx: int
+    restoration_path: Optional[str] = None
+
+
+class AppState(metaclass=Singleton):
+    def __init__(self):
+        # method call lock
+        self.__lock = Lock()
+
+        # TODO: should we store global config in hydra_runner?
+        self._app_cfg = None
+
+        # World info
+        self._device_id = None
+        self._local_rank = None
+        self._global_rank = None
+        self._tensor_model_parallel_rank = None
+        self._pipeline_model_parallel_rank = None
+        self._data_parallel_rank = None
+
+        self._world_size = None
+        self._model_parallel_size = None
+        self._tensor_model_parallel_size = None
+        self._tensor_model_parallel_group = None
+        self._pipeline_model_parallel_size = None
+        self._virtual_pipeline_model_parallel_size = None
+        self._pipeline_model_parallel_group = None
+        self._pipeline_model_parallel_split_rank = None
+        self._is_megatron_initialized = False
+        self._data_parallel_size = None
+        self._data_parallel_group = None
+        self._megatron_checkpoint_version = None
+        self._use_fp8 = False
+        self._init_mpi_proc_group = False
+
+        self._random_seed = None
+
+        # Logging info
+        self._log_dir = None
+        self._exp_dir = None
+        self._name = None
+        self._checkpoint_name = None
+        self._version = None
+        self._create_checkpoint_callback = None
+        self._checkpoint_callback_params = None
+
+        # Save and Restore (.nemo)
+        self._tmpdir_name = None
+        self._is_model_being_restored = False
+        self._nemo_file_folder = None
+        self._model_restore_path = None
+        self._all_model_restore_paths = []
+        self._model_guid_map = {}  # type: Dict[str, ModelMetadataRegistry]
+
+    @property
+    def device_id(self):
+        """ Property returns the device_id
+            Returns:
+                device_id
+        """
+        return self._device_id
+
+    @device_id.setter
+    def device_id(self, id):
+        """ Property sets the device_id.
+            Args:
+                id (int): The device id.
+        """
+        self._device_id = id
+
+    @property
+    def world_size(self):
+        """ Property returns the total number of GPUs.
+            Returns:
+                Total number of GPUs.
+        """
+        return self._world_size
+
+    @world_size.setter
+    def world_size(self, size):
+        """ Property sets the total number of GPUs.
+            Args:
+                size (int):  Total number of GPUs.
+        """
+        self._world_size = size
+
+    @property
+    def model_parallel_size(self):
+        """ Property returns the number of GPUs in each model parallel group.
+            Returns:
+                Number of GPUs in each model parallel group.
+        """
+        return self._model_parallel_size
+
+    @model_parallel_size.setter
+    def model_parallel_size(self, size):
+        """ Property sets the number of GPUs in each model parallel group.
+            Args:
+                size (int):  Number of GPUs in each model parallel group.
+        """
+        self._model_parallel_size = size
+
+    @property
+    def tensor_model_parallel_size(self):
+        """ Property returns the number of GPUs in each model parallel group.
+            Returns:
+                Number of GPUs in each model parallel group.
+        """
+        return self._tensor_model_parallel_size
+
+    @tensor_model_parallel_size.setter
+    def tensor_model_parallel_size(self, size):
+        """ Property sets the number of GPUs in each model parallel group.
+            Args:
+                size (int):  Number of GPUs in each model parallel group.
+        """
+        self._tensor_model_parallel_size = size
+
+    @property
+    def pipeline_model_parallel_size(self):
+        """ Property returns the number of GPUs in each model parallel group.
+            Returns:
+                Number of GPUs in each model parallel group.
+        """
+        return self._pipeline_model_parallel_size
+
+    @pipeline_model_parallel_size.setter
+    def pipeline_model_parallel_size(self, size):
+        """ Property sets the number of GPUs in each model parallel group.
+            Args:
+                size (int):  Number of GPUs in each model parallel group.
+        """
+        self._pipeline_model_parallel_size = size
+
+    @property
+    def virtual_pipeline_model_parallel_size(self):
+        """ Property returns the number of GPUs in each model parallel group.
+            Returns:
+                Number of GPUs in each model parallel group.
+        """
+        return self._virtual_pipeline_model_parallel_size
+
+    @virtual_pipeline_model_parallel_size.setter
+    def virtual_pipeline_model_parallel_size(self, size):
+        """ Property sets the size of the virtual pipeline parallel model.
+            Args:
+                size (int):  Number of modules in each pipeline parallel model.
+        """
+        self._virtual_pipeline_model_parallel_size = size
+
+    @property
+    def data_parallel_size(self):
+        """ Property returns the number of GPUs in each data parallel group.
+            Returns:
+                Number of GPUs in each data parallel group.
+        """
+        return self._data_parallel_size
+
+    @data_parallel_size.setter
+    def data_parallel_size(self, size):
+        """ Property sets the number of GPUs in each data parallel group.
+            Args:
+                size (int):  Number of GPUs in each data parallel group.
+        """
+        self._data_parallel_size = size
+
+    @property
+    def local_rank(self):
+        """ Property returns the local rank.
+            Returns:
+                Local rank.
+        """
+        return self._local_rank
+
+    @local_rank.setter
+    def local_rank(self, rank):
+        """ Property sets the local rank.
+            Args:
+                rank (int):  Local rank.
+        """
+        self._local_rank = rank
+
+    @property
+    def global_rank(self):
+        """ Property returns the global rank.
+            Returns:
+                Global rank.
+        """
+        return self._global_rank
+
+    @global_rank.setter
+    def global_rank(self, rank):
+        """ Property sets the global rank.
+            Args:
+                rank (int):  Global rank.
+        """
+        self._global_rank = rank
+
+    @property
+    def tensor_model_parallel_rank(self):
+        """ Property returns the tensor model parallel rank.
+            Returns:
+                Tensor model parallel rank.
+        """
+        return self._tensor_model_parallel_rank
+
+    @tensor_model_parallel_rank.setter
+    def tensor_model_parallel_rank(self, rank):
+        """ Property sets the tensor model parallel rank.
+            Args:
+                rank (int):  Tensor model parallel rank.
+        """
+        self._tensor_model_parallel_rank = rank
+
+    @property
+    def tensor_model_parallel_group(self):
+        """ Property returns the tensor model parallel group.
+            Returns:
+                Tensor model parallel group.
+        """
+        return self._tensor_model_parallel_group
+
+    @tensor_model_parallel_group.setter
+    def tensor_model_parallel_group(self, group):
+        """ Property sets the tensor model parallel group.
+            Args:
+                group:  Tensor model parallel group.
+        """
+        self._tensor_model_parallel_group = group
+
+    @property
+    def pipeline_model_parallel_rank(self):
+        """ Property returns the pipeline model parallel rank.
+            Returns:
+                Pipeline model parallel rank.
+        """
+        return self._pipeline_model_parallel_rank
+
+    @pipeline_model_parallel_rank.setter
+    def pipeline_model_parallel_rank(self, rank):
+        """ Property sets the pipeline model parallel rank.
+            Args:
+                rank (int):  Pipeline model parallel rank.
+        """
+        self._pipeline_model_parallel_rank = rank
+
+    @property
+    def virtual_pipeline_model_parallel_rank(self):
+        """ Property returns the virtual pipeline parallel rank.
+            Returns:
+                Model parallel rank.
+        """
+        return self._virtual_pipeline_model_parallel_rank
+
+    @virtual_pipeline_model_parallel_rank.setter
+    def virtual_pipeline_model_parallel_rank(self, rank):
+        """ Property sets the virtual pipeline parallel rank.
+            Args:
+                rank (int):  Virtual pipeline parallel rank.
+        """
+        self._virtual_pipeline_model_parallel_rank = rank
+
+    @property
+    def pipeline_model_parallel_split_rank(self):
+        """ Property returns the rank at which Encoder and Decoder are split into different pipelines for Megatrron Encoder-Decoder models.
+            Returns:
+                Pipeline model parallel split rank.
+        """
+        return self._pipeline_model_parallel_split_rank
+
+    @pipeline_model_parallel_split_rank.setter
+    def pipeline_model_parallel_split_rank(self, rank):
+        """ Property sets the rank at which Encoder and Decoder are split into different pipelines for Megatrron Encoder-Decoder models.
+            Args:
+                rank (int): Model parallel split rank.
+        """
+        self._pipeline_model_parallel_split_rank = rank
+
+    @property
+    def pipeline_model_parallel_group(self):
+        """ Property returns the pipeline model parallel group.
+            Returns:
+                Pipeline model parallel group.
+        """
+        return self._pipeline_model_parallel_group
+
+    @pipeline_model_parallel_group.setter
+    def pipeline_model_parallel_group(self, group):
+        """ Property sets the pipeline model parallel group.
+            Args:
+                group:  Pipeline model parallel group.
+        """
+        self._pipeline_model_parallel_group = group
+
+    @property
+    def data_parallel_rank(self):
+        """ Property returns the data parallel rank.
+            Returns:
+                Data parallel rank.
+        """
+        return self._data_parallel_rank
+
+    @data_parallel_rank.setter
+    def data_parallel_rank(self, rank):
+        """ Property sets the data parallel rank.
+            Args:
+                rank (int):  Data parallel rank.
+        """
+        self._data_parallel_rank = rank
+
+    @property
+    def data_parallel_group(self):
+        """ Property returns the data parallel group.
+            Returns:
+                Data parallel group.
+        """
+        return self._data_parallel_group
+
+    @data_parallel_group.setter
+    def data_parallel_group(self, group):
+        """ Property sets the data parallel group.
+            Args:
+                group:  Data parallel group.
+        """
+        self._data_parallel_group = group
+
+    @property
+    def use_fp8(self):
+        """ Property returns the use of fp8 precision.
+            Returns:
+                Use of FP8.
+        """
+        return self._use_fp8
+
+    @use_fp8.setter
+    def use_fp8(self, use_fp8):
+        """ Property sets the use of fp8 precision.
+            Args:
+                use_fp8:  Use of FP8.
+        """
+        self._use_fp8 = use_fp8
+
+    @property
+    def init_mpi_proc_group(self):
+        """ Property sets the initialization of mpi process group.
+            Returns:
+                Initialize mpi process group.
+        """
+        return self._init_mpi_proc_group
+
+    @init_mpi_proc_group.setter
+    def init_mpi_proc_group(self, init_mpi_proc_group):
+        """ Property sets the initialization of mpi process group.
+            Args:
+                init_mpi_proc_group:  Initialize mpi process group.
+        """
+        self._init_mpi_proc_group = init_mpi_proc_group
+
+    @property
+    def random_seed(self):
+        """ Property returns the random seed.
+            Returns:
+                Random seed.
+        """
+        return self._random_seed
+
+    @random_seed.setter
+    def random_seed(self, seed):
+        """ Property sets the random seed.
+            Args:
+                seed (int):  Random seed.
+        """
+        self._random_seed = seed
+
+    @property
+    def log_dir(self):
+        """Returns the log_dir set by exp_manager.
+        """
+        return self._log_dir
+
+    @log_dir.setter
+    def log_dir(self, dir):
+        """Sets the log_dir property.
+
+        Args:
+            dir (str): Log_dir set by exp_manager.
+        """
+        self._log_dir = dir
+
+    @property
+    def exp_dir(self):
+        """Returns the exp_dir set by exp_manager.
+        """
+        return self._exp_dir
+
+    @exp_dir.setter
+    def exp_dir(self, dir):
+        """Sets the log_dir property.
+
+        Args:
+            dir (str): Log_dir set by exp_manager.
+        """
+        self._exp_dir = dir
+
+    @property
+    def name(self):
+        """Returns the name set by exp_manager.
+        """
+        return self._name
+
+    @name.setter
+    def name(self, name):
+        """Sets the name property.
+
+        Args:
+            name (str): Name set by exp_manager.
+        """
+        self._name = name
+
+    @property
+    def checkpoint_name(self):
+        """Returns the name set by exp_manager.
+        """
+        return self._checkpoint_name
+
+    @checkpoint_name.setter
+    def checkpoint_name(self, name):
+        """Sets the name property.
+
+        Args:
+            dir (str): name set by exp_manager.
+        """
+        self._checkpoint_name = name
+
+    @property
+    def version(self):
+        """Returns the version set by exp_manager.
+        """
+        return self._version
+
+    @version.setter
+    def version(self, version):
+        """Sets the version property.
+
+        Args:
+            version (str): Version set by exp_manager.
+        """
+        self._version = version
+
+    @property
+    def create_checkpoint_callback(self):
+        """Returns the create_checkpoint_callback set by exp_manager.
+        """
+        return self._create_checkpoint_callback
+
+    @create_checkpoint_callback.setter
+    def create_checkpoint_callback(self, create_checkpoint_callback):
+        """Sets the create_checkpoint_callback property.
+
+        Args:
+            create_checkpoint_callback (bool): create_checkpoint_callback set by exp_manager.
+        """
+        self._create_checkpoint_callback = create_checkpoint_callback
+
+    @property
+    def checkpoint_callback_params(self):
+        """Returns the version set by exp_manager.
+        """
+        return self._checkpoint_callback_params
+
+    @checkpoint_callback_params.setter
+    def checkpoint_callback_params(self, params):
+        """Sets the name property.
+
+        Args:
+            params (dict): checkpoint_callback_params set by exp_manager.
+        """
+        self._checkpoint_callback_params = params
+
+    @property
+    def model_restore_path(self):
+        restore_path = self._all_model_restore_paths[-1] if len(self._all_model_restore_paths) > 0 else None
+        return restore_path
+
+    @model_restore_path.setter
+    def model_restore_path(self, path):
+        with self.__lock:
+            self._model_restore_path = path
+            self._all_model_restore_paths.append(path)
+
+    def register_model_guid(self, guid: str, restoration_path: Optional[str] = None):
+        # Maps a guid to its restore path (None or last absolute path)
+        with self.__lock:
+            if guid in self._model_guid_map:
+                idx = self._model_guid_map[guid].gidx
+            else:
+                idx = len(self._model_guid_map)
+            self._model_guid_map[guid] = ModelMetadataRegistry(guid, idx, restoration_path=restoration_path)
+
+    def reset_model_guid_registry(self):
+        # Reset the guid mapping
+        with self.__lock:
+            self._model_guid_map.clear()
+
+    def get_model_metadata_from_guid(self, guid) -> ModelMetadataRegistry:
+        # Returns the global model idx and restoration path
+        metadata = self._model_guid_map[guid]
+        return metadata
+
+    @property
+    def is_model_being_restored(self) -> bool:
+        return self._is_model_being_restored
+
+    @is_model_being_restored.setter
+    def is_model_being_restored(self, is_restored: bool):
+        self._is_model_being_restored = is_restored
+
+    @property
+    def nemo_file_folder(self) -> str:
+        return self._nemo_file_folder
+
+    @nemo_file_folder.setter
+    def nemo_file_folder(self, path: str):
+        self._nemo_file_folder = path
diff --git a/nemo/utils/callbacks/__init__.py b/nemo/utils/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6623657a2dc2544dffa99cb9606ea79d88992337
--- /dev/null
+++ b/nemo/utils/callbacks/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint
+from nemo.utils.callbacks.preemption import PreemptionCallback
diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4759ecf5949d1053e553c396b293955a3c39c4d
--- /dev/null
+++ b/nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+import shutil
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import pytorch_lightning
+import torch
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities import rank_zero_info
+
+from nemo.collections.common.callbacks import EMA
+from nemo.utils import logging
+from nemo.utils.app_state import AppState
+from nemo.utils.get_rank import is_global_rank_zero
+from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank
+
+
+class NeMoModelCheckpoint(ModelCheckpoint):
+    """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
+    Extends Lightning's on_save_checkpoint func to save the .nemo file. Saves the .nemo file based 
+    on the best checkpoint saved (according to the monitor value).
+    Also contains func to save the EMA copy of the model.
+    """
+
+    def __init__(
+        self,
+        always_save_nemo: bool = False,
+        save_nemo_on_train_end: bool = True,
+        save_best_model: bool = False,
+        postfix: str = ".nemo",
+        n_resume: bool = False,
+        model_parallel_size: int = None,
+        **kwargs,
+    ):
+        # Parse and store "extended" parameters: save_best model and postfix.
+        self.always_save_nemo = always_save_nemo
+        self.save_nemo_on_train_end = save_nemo_on_train_end
+        self.save_best_model = save_best_model
+        if self.save_best_model and not self.save_nemo_on_train_end:
+            logging.warning(
+                (
+                    "Found save_best_model is True and save_nemo_on_train_end is False. "
+                    "Set save_nemo_on_train_end to True to automatically save the best model."
+                )
+            )
+        self.postfix = postfix
+        self.previous_best_path = ""
+        self.model_parallel_size = model_parallel_size
+
+        # `prefix` is deprecated
+        if 'prefix' in kwargs:
+            self.prefix = kwargs.pop('prefix')
+        else:
+            self.prefix = ""
+
+        # Call the parent class constructor with the remaining kwargs.
+        super().__init__(**kwargs)
+
+        if self.save_top_k != -1 and n_resume:
+            logging.debug("Checking previous runs")
+            self.nemo_topk_check_previous_run()
+
+    def nemo_topk_check_previous_run(self):
+        try:
+            self.best_k_models
+            self.kth_best_model_path
+            self.best_model_score
+            self.best_model_path
+        except AttributeError:
+            raise AttributeError("Lightning's ModelCheckpoint was updated. NeMoModelCheckpoint will need an update.")
+        self.best_k_models = {}
+        self.kth_best_model_path = ""
+        self.best_model_score = None
+        self.best_model_path = ""
+
+        checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path))
+        for checkpoint in checkpoints:
+            if 'mp_rank' in str(checkpoint) or 'tp_rank' in str(checkpoint):
+                checkpoint = uninject_model_parallel_rank(checkpoint)
+            checkpoint = str(checkpoint)
+            # second case is for distributed checkpoints, since they are a directory there's no extension
+            if checkpoint[-10:] == '-last.ckpt' or checkpoint[-5:] == '-last':
+                continue
+            index = checkpoint.find(self.monitor) + len(self.monitor) + 1  # Find monitor in str + 1 for '='
+            if index != len(self.monitor):
+                match = re.search('[A-z]', checkpoint[index:])
+                if match:
+                    value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hypen
+                    self.best_k_models[checkpoint] = float(value)
+        if len(self.best_k_models) < 1:
+            return  # No saved checkpoints yet
+
+        _reverse = False if self.mode == "min" else True
+
+        best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)
+
+        # This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
+        # instantiated after rank zero. models_to_delete should be 0 for all other ranks.
+        if self.model_parallel_size is not None:
+            # check for distributed checkpoint
+            if checkpoints[0].is_dir():
+                models_to_delete = len(best_k_models) - self.save_top_k
+            else:
+                models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k
+        else:
+            models_to_delete = len(best_k_models) - self.save_top_k
+
+        models_to_delete = max(0, models_to_delete)
+        logging.debug(f'Number of models to delete: {models_to_delete}')
+
+        # If EMA enabled, delete the additional EMA weights
+        ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths)
+
+        for _ in range(models_to_delete):
+            model = best_k_models.pop(-1)
+            self.best_k_models.pop(model)
+            self._del_model_without_trainer(model)
+            if ema_enabled and self._fs.exists(self._ema_format_filepath(model)):
+                self._del_model_without_trainer(self._ema_format_filepath(model))
+            logging.debug(f"Removed checkpoint: {model}")
+
+        self.kth_best_model_path = best_k_models[-1]
+        self.best_model_path = best_k_models[0]
+        self.best_model_score = self.best_k_models[self.best_model_path]
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+        output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
+        if not self.always_save_nemo:
+            return output
+        # Load the best model and then re-save it
+        app_state = AppState()
+        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+            logging.warning('always_save_nemo will slow down training for model_parallel > 1.')
+        # since we are creating tarfile artifacts we need to update .nemo path
+        app_state.model_restore_path = os.path.abspath(
+            os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix))
+        )
+        if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+            maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path)
+        else:
+            maybe_injected_best_model_path = self.best_model_path
+
+        if self.save_best_model:
+            if not os.path.exists(maybe_injected_best_model_path):
+                return
+
+            if self.best_model_path == self.previous_best_path:
+                return output
+
+            self.previous_best_path = self.best_model_path
+            old_state_dict = deepcopy(pl_module.state_dict())
+            checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu')
+            if 'state_dict' in checkpoint:
+                checkpoint = checkpoint['state_dict']
+            # load the best checkpoint's weights into the model
+            pl_module.load_state_dict(checkpoint, strict=True)
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+            pl_module.save_to(save_path=app_state.model_restore_path)
+            logging.info(f"New best .nemo model saved to: {app_state.model_restore_path}")
+            pl_module.load_state_dict(old_state_dict, strict=True)
+        else:
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+            pl_module.save_to(save_path=app_state.model_restore_path)
+            logging.info(f"New .nemo model saved to: {app_state.model_restore_path}")
+        return output
+
+    def on_train_end(self, trainer, pl_module):
+        if trainer.fast_dev_run:
+            return None
+
+        # check if we need to save a last checkpoint manually as validation isn't always run based on the interval
+        if self.save_last and trainer.val_check_interval != 0:
+            should_save_last_checkpoint = False
+            if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0:
+                should_save_last_checkpoint = True
+            if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0:
+                should_save_last_checkpoint = True
+            if should_save_last_checkpoint:
+                monitor_candidates = self._monitor_candidates(trainer)
+                super()._save_last_checkpoint(trainer, monitor_candidates)
+        # Call parent on_train_end() to save the -last checkpoint
+        super().on_train_end(trainer, pl_module)
+
+        # Load the best model and then re-save it
+        if self.save_best_model:
+            # wait for all processes
+            trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end")
+            if self.best_model_path == "":
+                logging.warning(
+                    f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
+                    "were found. Saving latest model instead."
+                )
+            else:
+                self.best_model_path = trainer.strategy.broadcast(self.best_model_path)
+                trainer._checkpoint_connector.restore(self.best_model_path)
+
+        if self.save_nemo_on_train_end:
+            pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix))
+
+    def _del_model_without_trainer(self, filepath: str) -> None:
+
+        filepath = Path(filepath)
+
+        # check if filepath is a distributed checkpoint
+        if ckpt_to_dir(filepath).is_dir():
+            if is_global_rank_zero():
+                dist_ckpt = ckpt_to_dir(filepath)
+                try:
+                    shutil.rmtree(dist_ckpt)
+                    logging.info(f"Removed distributed checkpoint: {dist_ckpt}")
+                except OSError:
+                    logging.info(f"Tried to remove distributed checkpoint: {dist_ckpt} but failed.")
+
+        else:
+            app_state = AppState()
+
+            # legacy model parallel checkpoint
+            if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+                # filepath needs to be updated to include mp_rank
+                filepath = inject_model_parallel_rank(filepath)
+
+            # each model parallel rank needs to remove its model
+            if is_global_rank_zero() or (
+                app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0
+            ):
+                try:
+                    self._fs.rm(filepath)
+                    logging.info(f"Removed checkpoint: {filepath}")
+                except Exception:
+                    logging.info(f"Tried to remove checkpoint: {filepath} but failed.")
+
+    def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]:
+        ema_callback = None
+        for callback in trainer.callbacks:
+            if isinstance(callback, EMA):
+                ema_callback = callback
+        return ema_callback
+
+    def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
+        ema_callback = self._ema_callback(trainer)
+        if ema_callback is not None:
+            with ema_callback.save_original_optimizer_state(trainer):
+                super()._save_checkpoint(trainer, filepath)
+
+            # save EMA copy of the model as well.
+            with ema_callback.save_ema_model(trainer):
+                filepath = self._ema_format_filepath(filepath)
+                if self.verbose:
+                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
+                super()._save_checkpoint(trainer, filepath)
+        else:
+            super()._save_checkpoint(trainer, filepath)
+
+    def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None:
+        super()._remove_checkpoint(trainer, filepath)
+        ema_callback = self._ema_callback(trainer)
+        if ema_callback is not None:
+            # remove EMA copy of the state dict as well.
+            filepath = self._ema_format_filepath(filepath)
+            super()._remove_checkpoint(trainer, filepath)
+
+    def _ema_format_filepath(self, filepath: str) -> str:
+        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')
+
+    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
+        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)
+
+    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
+        return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}')
+
+    @property
+    def _saved_checkpoint_paths(self) -> Iterable[Path]:
+        # distributed checkpoints are directories so we check for them here
+        dist_checkpoints = [d for d in list(Path(self.dirpath).glob("*")) if d.is_dir()]
+        if dist_checkpoints:
+            return dist_checkpoints
+        else:
+            return Path(self.dirpath).rglob("*.ckpt")
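+
+
+# A minimal construction sketch (not part of the upstream file): in practice this
+# callback is created by exp_manager, but it can also be passed to a Trainer
+# directly. The dirpath/monitor values below are illustrative assumptions.
+if __name__ == "__main__":
+    import pytorch_lightning
+
+    checkpoint_callback = NeMoModelCheckpoint(
+        dirpath="./checkpoints",  # assumed output directory
+        monitor="val_loss",       # assumed monitored metric
+        save_top_k=3,
+        save_best_model=True,     # re-save the best checkpoint as a .nemo file
+    )
+    trainer = pytorch_lightning.Trainer(callbacks=[checkpoint_callback])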
diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9b5f95022f3192257e5dba25c067fa6e2a73b8a
--- /dev/null
+++ b/nemo/utils/callbacks/preemption.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import signal
+import sys
+
+import torch
+from pytorch_lightning.callbacks import Callback
+
+from nemo.utils import logging
+
+
+class PreemptionCallback(Callback):
+    """
+    PreemptionCallback checks for preemption during training at the end of every step.
+    Upon preemption, it gracefully exits training immediately and saves the current state
+    in a checkpoint as *last.ckpt, so that training can resume from the same step without
+    wasting any compute.
+
+    PreemptionCallback is enabled by default via the arg create_preemption_callback under
+    ExpManagerConfig. To disable it, pass create_preemption_callback: False in your config file.
+    """
+
+    def __init__(self, checkpoint_callback, sig=None):
+        self.sig = sig
+        if self.sig is None:
+            self.sig = signal.SIGTERM
+        self.checkpoint_callback = checkpoint_callback
+        self.preemption_enabled = False
+
+    @property
+    def interrupted(self):
+        interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32)
+        torch.distributed.broadcast(interrupted, 0)
+        interrupted = bool(interrupted.item())
+        return interrupted
+
+    def on_train_start(self, trainer, pl_module):
+        """
+        Defines custom handlers at the beginning of training to be executed when the 
+        preemption signal is received.
+        """
+
+        # Check if torch distributed is initialized, as it's needed for broadcasting the preemption signal to all ranks
+        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
+            logging.info("Preemption requires torch distributed to be initialized, disabling preemption")
+        else:
+            self.preemption_enabled = True
+            # Bool var that's initialized to False and set to True upon receiving the preemption signal
+            self._interrupted = False
+            self.released = False
+            self.original_handler = signal.getsignal(self.sig)
+
+            # Master handler, executed only by rank 0 when the preemption signal is received, to avoid deadlock conditions
+            def master_handler(signum, frame):
+                self.release()
+                self._interrupted = True
+
+            # Handler executed by the non-zero ranks
+            def ignoring_handler(signum, frame):
+                self.release()
+
+            self.private_rank = torch.distributed.get_rank()
+            if self.private_rank == 0:
+                signal.signal(self.sig, master_handler)
+            else:
+                signal.signal(self.sig, ignoring_handler)
+
+        return self
+
+    def on_train_end(self, trainer, pl_module):
+        if self.preemption_enabled:
+            self.release()
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
+        if self.preemption_enabled:
+            # check if the job was preempted at the end of every training step/iteration
+            # NOTE: "self.interrupted" is a property which triggers a
+            # distributed broadcast of "_interrupted" flag from rank 0 to all other
+            # ranks, to avoid performance overheads it's best to store the result in
+            # a regular local variable
+            interrupted = self.interrupted
+            if interrupted:
+                logging.info("Received SIGTERM, saving checkpoint and exiting")
+                monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer)
+                self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
+                sys.exit(0)
+
+    def release(self):
+        if self.released:
+            return False
+
+        signal.signal(self.sig, self.original_handler)
+        self.released = True
+        return True
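+
+
+# A minimal wiring sketch (illustrative, not part of the upstream file): the callback
+# is normally created by exp_manager; it needs the checkpoint callback so it can save
+# a *last.ckpt before exiting when SIGTERM is received.
+if __name__ == "__main__":
+    from pytorch_lightning.callbacks import ModelCheckpoint
+
+    checkpoint_callback = ModelCheckpoint(dirpath="./checkpoints", save_last=True)
+    preemption_callback = PreemptionCallback(checkpoint_callback=checkpoint_callback)
+    # Both callbacks would then be passed to the Trainer:
+    # trainer = pytorch_lightning.Trainer(callbacks=[checkpoint_callback, preemption_callback])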
diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e977ec494d858bb8b489f7c77e4770c54d45f5
--- /dev/null
+++ b/nemo/utils/cast_utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import nullcontext
+
+import torch
+
+
+def avoid_bfloat16_autocast_context():
+    """
+    If the current autocast context is bfloat16,
+    cast it to float32
+    """
+
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.bfloat16:
+        return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def avoid_float16_autocast_context():
+    """
+    If the current autocast context is float16, cast it to bfloat16
+    if available (unless we're in jit) or float32
+    """
+
+    if torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16:
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+
+        if torch.cuda.is_bf16_supported():
+            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
+        else:
+            return torch.cuda.amp.autocast(dtype=torch.float32)
+    else:
+        return nullcontext()
+
+
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    if isinstance(x, torch.Tensor):
+        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+    else:
+        if isinstance(x, dict):
+            new_dict = {}
+            for k in x.keys():
+                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+            return new_dict
+        elif isinstance(x, tuple):
+            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+        else:
+            # Leave unsupported container types (e.g. lists) unchanged instead of silently returning None
+            return x
+
+
+class CastToFloat(torch.nn.Module):
+    def __init__(self, mod):
+        super(CastToFloat, self).__init__()
+        self.mod = mod
+
+    def forward(self, x):
+        if torch.is_autocast_enabled() and x.dtype != torch.float32:
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
+        else:
+            ret = self.mod.forward(x)
+        return ret
+
+
+class CastToFloatAll(torch.nn.Module):
+    def __init__(self, mod):
+        super(CastToFloatAll, self).__init__()
+        self.mod = mod
+
+    def forward(self, *args):
+        if torch.is_autocast_enabled():
+            from_dtype = args[0].dtype
+            with torch.cuda.amp.autocast(enabled=False):
+                ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+                return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+        else:
+            return self.mod.forward(*args)
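+
+
+# A small self-contained sketch (illustrative values): cast a nested structure of
+# tensors with cast_all, and wrap a module with CastToFloat so it runs in float32
+# whenever autocast is active.
+if __name__ == "__main__":
+    batch = {"x": torch.ones(2, dtype=torch.float16), "y": (torch.zeros(1, dtype=torch.float16),)}
+    batch_fp32 = cast_all(batch, from_dtype=torch.float16, to_dtype=torch.float32)
+    assert batch_fp32["x"].dtype == torch.float32
+
+    layer = CastToFloat(torch.nn.LayerNorm(8))
+    out = layer(torch.randn(4, 8))  # outside autocast this is a plain forward pass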
diff --git a/nemo/utils/data_utils.py b/nemo/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6479a65f11287e6494c298dac7bb4dbea5b931fc
--- /dev/null
+++ b/nemo/utils/data_utils.py
@@ -0,0 +1,318 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pathlib
+import shutil
+import subprocess
+from typing import Tuple
+
+from nemo import __version__ as NEMO_VERSION
+from nemo import constants
+from nemo.utils import logging
+
+
+def resolve_cache_dir() -> pathlib.Path:
+    """
+    Utility method to resolve a cache directory for NeMo that can be overridden by an environment variable.
+
+    Example:
+        NEMO_CACHE_DIR="~/nemo_cache_dir/" python nemo_example_script.py
+
+    Returns:
+        A Path object, resolved to the absolute path of the cache directory. If no override is provided,
+        uses a built-in default which adapts to NeMo version strings.
+    """
+    override_dir = os.environ.get(constants.NEMO_ENV_CACHE_DIR, "")
+    if override_dir == "":
+        path = pathlib.Path.joinpath(pathlib.Path.home(), f'.cache/torch/NeMo/NeMo_{NEMO_VERSION}')
+    else:
+        path = pathlib.Path(override_dir).resolve()
+    return path
+
+
+def is_datastore_path(path) -> bool:
+    """Check if a path is from a data object store.
+    Currently, only AIStore is supported.
+    """
+    return path.startswith('ais://')
+
+
+def is_tarred_path(path) -> bool:
+    """Check if a path is for a tarred file.
+    """
+    return path.endswith('.tar')
+
+
+def is_datastore_cache_shared() -> bool:
+    """Check if store cache is shared.
+    """
+    # Assume cache is shared by default, e.g., as in resolve_cache_dir (~/.cache)
+    cache_shared = int(os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_SHARED, 1))
+
+    if cache_shared == 0:
+        return False
+    elif cache_shared == 1:
+        return True
+    else:
+        raise ValueError(f'Unexpected value of env {constants.NEMO_ENV_DATA_STORE_CACHE_SHARED}')
+
+
+def ais_cache_base() -> str:
+    """Return path to local cache for AIS.
+    """
+    override_dir = os.environ.get(constants.NEMO_ENV_DATA_STORE_CACHE_DIR, "")
+    if override_dir == "":
+        cache_dir = resolve_cache_dir().as_posix()
+    else:
+        cache_dir = pathlib.Path(override_dir).resolve().as_posix()
+
+    if cache_dir.endswith(NEMO_VERSION):
+        # Prevent re-caching dataset after upgrading NeMo
+        cache_dir = os.path.dirname(cache_dir)
+    return os.path.join(cache_dir, 'ais')
+
+
+def ais_endpoint() -> str:
+    """Get configured AIS endpoint.
+    """
+    return os.getenv('AIS_ENDPOINT')
+
+
+def bucket_and_object_from_uri(uri: str) -> Tuple[str, str]:
+    """Parse a path to determine bucket and object path.
+
+    Args:
+        uri: Full path to an object on an object store
+
+    Returns:
+        Tuple of strings (bucket_name, object_path)
+    """
+    if not is_datastore_path(uri):
+        raise ValueError(f'Provided URI is not a valid store path: {uri}')
+    uri_parts = pathlib.PurePath(uri).parts
+    bucket = uri_parts[1]
+    object_path = pathlib.PurePath(*uri_parts[2:])
+
+    return str(bucket), str(object_path)
+
+
+def ais_endpoint_to_dir(endpoint: str) -> str:
+    """Convert AIS endpoint to a valid dir name.
+    Used to build cache location.
+
+    Args:
+        endpoint: AIStore endpoint in the format http://host:port
+    
+    Returns:
+        Directory formed as `host/port`.
+    """
+    if not endpoint.startswith('http://'):
+        raise ValueError(f'Unexpected format for ais endpoint: {endpoint}')
+
+    endpoint = endpoint.replace('http://', '')
+    host, port = endpoint.split(':')
+    return os.path.join(host, port)
+
+
+def ais_binary() -> str:
+    """Return location of `ais` binary.
+    """
+    path = shutil.which('ais')
+
+    if path is not None:
+        logging.debug('Found AIS binary at %s', path)
+        return path
+
+    logging.warning('AIS binary not found with `which ais`.')
+
+    # Double-check if it exists at the default path
+    default_path = '/usr/local/bin/ais'
+    if os.path.isfile(default_path):
+        logging.info('ais available at the default path: %s', default_path)
+        return default_path
+    else:
+        raise RuntimeError('AIS binary not found.')
+
+
+def datastore_path_to_local_path(store_path: str) -> str:
+    """Convert a data store path to a path in a local cache.
+
+    Args:
+        store_path: a path to an object on an object store
+
+    Returns:
+        Path to the same object in local cache.
+    """
+    if store_path.startswith('ais://'):
+        endpoint = ais_endpoint()
+        if endpoint is None:
+            raise RuntimeError(f'AIS endpoint not set, cannot resolve {store_path}')
+
+        local_ais_cache = os.path.join(ais_cache_base(), ais_endpoint_to_dir(endpoint))
+        store_bucket, store_object = bucket_and_object_from_uri(store_path)
+        local_path = os.path.join(local_ais_cache, store_bucket, store_object)
+    else:
+        raise ValueError(f'Unexpected store path format: {store_path}')
+
+    return local_path
+
+
+def get_datastore_object(path: str, force: bool = False, num_retries: int = 5) -> str:
+    """Download an object from a store path and return the local path.
+    If the input `path` is a local path, then nothing will be done, and
+    the original path will be returned.
+
+    Args:
+        path: path to an object
+        force: force download, even if a local file exists
+        num_retries: number of retries if the get command fails
+    
+    Returns:
+        Local path of the object.
+    """
+    if path.startswith('ais://'):
+        endpoint = ais_endpoint()
+        if endpoint is None:
+            raise RuntimeError(f'AIS endpoint not set, cannot resolve {path}')
+
+        local_path = datastore_path_to_local_path(store_path=path)
+
+        if not os.path.isfile(local_path) or force:
+            # Either we don't have the file in cache or we force download it
+            # Enhancement: if local file is present, check some tag and compare against remote
+            local_dir = os.path.dirname(local_path)
+            if not os.path.isdir(local_dir):
+                os.makedirs(local_dir, exist_ok=True)
+
+            cmd = [ais_binary(), 'get', path, local_path]
+
+            logging.debug('Downloading from AIS')
+            logging.debug('\tendpoint    %s', endpoint)
+            logging.debug('\tpath:       %s', path)
+            logging.debug('\tlocal path: %s', local_path)
+            logging.debug('\tcmd:        %s', subprocess.list2cmdline(cmd))
+
+            done = False
+            for n in range(num_retries):
+                if not done:
+                    try:
+                        # Use stdout=subprocess.DEVNULL to prevent showing AIS command on each line
+                        subprocess.check_call(cmd, stdout=subprocess.DEVNULL)
+                        done = True
+                    except subprocess.CalledProcessError as err:
+                        logging.warning('Attempt %d of %d failed with: %s', n + 1, num_retries, str(err))
+
+            if not done:
+                raise RuntimeError(f'Download failed: {subprocess.list2cmdline(cmd)}')
+
+        return local_path
+
+    else:
+        # Assume the file is local
+        return path
+
+
+class DataStoreObject:
+    """A simple class for handling objects in a data store.
+    Currently, this class supports objects on AIStore.
+
+    Args:
+        store_path: path to a store object
+        local_path: path to a local object, may be used to upload local object to store
+        get: get the object from a store
+    """
+
+    def __init__(self, store_path: str, local_path: str = None, get: bool = False):
+        if local_path is not None:
+            raise NotImplementedError('Specifying a local path is currently not supported.')
+
+        self._store_path = store_path
+        self._local_path = local_path
+
+        if get:
+            self.get()
+
+    @property
+    def store_path(self) -> str:
+        """Return store path of the object.
+        """
+        return self._store_path
+
+    @property
+    def local_path(self) -> str:
+        """Return local path of the object.
+        """
+        return self._local_path
+
+    def get(self, force: bool = False) -> str:
+        """Get an object from the store to local cache and return the local path.
+
+        Args:
+            force: force download, even if a local file exists
+
+        Returns:
+            Path to a local object.
+        """
+        if not self.local_path:
+            # Assume the object needs to be downloaded
+            self._local_path = get_datastore_object(self.store_path, force=force)
+        return self.local_path
+
+    def put(self, force: bool = False) -> str:
+        """Push to remote and return the store path
+
+        Args:
+            force: force the upload, even if the object already exists in the store
+
+        Returns:
+            Path to the object on the (remote) object store.
+        """
+        raise NotImplementedError()
+
+    def __str__(self):
+        """Return a human-readable description of the object.
+        """
+        description = f'{type(self)}: store_path={self.store_path}, local_path={self.local_path}'
+        return description
+
+
+def datastore_path_to_webdataset_url(store_path: str):
+    """Convert store_path to a WebDataset URL.
+
+    Args:
+        store_path: path to buckets on store
+
+    Returns:
+        URL which can be directly used with WebDataset.
+    """
+    if store_path.startswith('ais://'):
+        url = f'pipe:ais get {store_path} - || true'
+    else:
+        raise ValueError(f'Unknown store path format: {store_path}')
+
+    return url
+
+
+def datastore_object_get(store_object: DataStoreObject) -> bool:
+    """A convenience wrapper for multiprocessing.imap.
+
+    Args:
+        store_object: An instance of DataStoreObject
+
+    Returns:
+        True if get() returned a path.
+    """
+    return store_object.get() is not None
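+
+
+# A minimal sketch of the path helpers (the ais:// URI below is made up):
+if __name__ == "__main__":
+    uri = 'ais://my-bucket/datasets/train.tar'
+    bucket, obj = bucket_and_object_from_uri(uri)
+    print(bucket, obj)  # my-bucket datasets/train.tar
+    # Actually downloading requires a configured endpoint, e.g. AIS_ENDPOINT=http://host:port:
+    # local_path = DataStoreObject(uri).get()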
diff --git a/nemo/utils/env_var_parsing.py b/nemo/utils/env_var_parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd37f2b18e9e782f8899a6fce6e49fb5b701a411
--- /dev/null
+++ b/nemo/utils/env_var_parsing.py
@@ -0,0 +1,207 @@
+# The MIT Licence (MIT)
+#
+# Copyright (c) 2016 YunoJuno Ltd
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+# Vendored dependency from : https://github.com/yunojuno/python-env-utils/blob/master/env_utils/utils.py
+#
+#
+# Modified by NVIDIA
+#
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import decimal
+import json
+import os
+
+from dateutil import parser
+
+__all__ = [
+    "get_env",
+    "get_envbool",
+    "get_envint",
+    "get_envfloat",
+    "get_envdecimal",
+    "get_envdate",
+    "get_envdatetime",
+    "get_envlist",
+    "get_envdict",
+    "CoercionError",
+    "RequiredSettingMissingError",
+]
+
+
+class CoercionError(Exception):
+    """Custom error raised when a value cannot be coerced."""
+
+    def __init__(self, key, value, func):
+        msg = "Unable to coerce '{}={}' using {}.".format(key, value, func.__name__)
+        super(CoercionError, self).__init__(msg)
+
+
+class RequiredSettingMissingError(Exception):
+    """Custom error raised when a required env var is missing."""
+
+    def __init__(self, key):
+        msg = "Required env var '{}' is missing.".format(key)
+        super(RequiredSettingMissingError, self).__init__(msg)
+
+
+def _get_env(key, default=None, coerce=lambda x: x, required=False):
+    """
+    Return env var coerced into a type other than string.
+    This function extends the standard os.getenv function to enable
+    the coercion of values into data types other than string (all env
+    vars are strings by default).
+    Args:
+        key: string, the name of the env var to look up
+    Kwargs:
+        default: the default value to return if the env var does not exist. NB the
+            default value is **not** coerced, and is assumed to be of the correct type.
+        coerce: a function that is used to coerce the value returned into
+            another type
+        required: bool, if True, then a RequiredSettingMissingError error is raised
+            if the env var does not exist.
+    Returns the env var, passed through the coerce function
+    """
+    try:
+        value = os.environ[key]
+    except KeyError:
+        if required is True:
+            raise RequiredSettingMissingError(key)
+        else:
+            return default
+
+    try:
+        return coerce(value)
+    except Exception:
+        raise CoercionError(key, value, coerce)
+
+
+# standard type coercion functions
+def _bool(value):
+    if isinstance(value, bool):
+        return value
+
+    return not (value is None or value.lower() in ("false", "0", "no", "n", "f", "none"))
+
+
+def _int(value):
+    return int(value)
+
+
+def _float(value):
+    return float(value)
+
+
+def _decimal(value):
+    return decimal.Decimal(value)
+
+
+def _dict(value):
+    return json.loads(value)
+
+
+def _datetime(value):
+    return parser.parse(value)
+
+
+def _date(value):
+    return parser.parse(value).date()
+
+
+def get_env(key, *default, **kwargs):
+    """
+    Return env var.
+    This is the parent function of all other get_foo functions,
+    and is responsible for unpacking args/kwargs into the values
+    that _get_env expects (it is the root function that actually
+    interacts with environ).
+    Args:
+        key: string, the env var name to look up.
+        default: (optional) the value to use if the env var does not
+            exist. If this value is not supplied, then the env var is
+            considered to be required, and a RequiredSettingMissingError
+            error will be raised if it does not exist.
+    Kwargs:
+        coerce: a func that may be supplied to coerce the value into
+            something else. This is used by the default get_foo functions
+            to cast strings to builtin types, but could be a function that
+            returns a custom class.
+    Returns the env var, coerced if required, and a default if supplied.
+    """
+    assert len(default) in (0, 1), "Too many args supplied."
+    func = kwargs.get('coerce', lambda x: x)
+    required = len(default) == 0
+    default = default[0] if not required else None
+    return _get_env(key, default=default, coerce=func, required=required)
+
+
+def get_envbool(key, *default):
+    """Return env var cast as boolean."""
+    return get_env(key, *default, coerce=_bool)
+
+
+def get_envint(key, *default):
+    """Return env var cast as integer."""
+    return get_env(key, *default, coerce=_int)
+
+
+def get_envfloat(key, *default):
+    """Return env var cast as float."""
+    return get_env(key, *default, coerce=_float)
+
+
+def get_envdecimal(key, *default):
+    """Return env var cast as Decimal."""
+    return get_env(key, *default, coerce=_decimal)
+
+
+def get_envdate(key, *default):
+    """Return env var as a date."""
+    return get_env(key, *default, coerce=_date)
+
+
+def get_envdatetime(key, *default):
+    """Return env var as a datetime."""
+    return get_env(key, *default, coerce=_datetime)
+
+
+def get_envlist(key, *default, **kwargs):
+    """Return env var as a list."""
+    separator = kwargs.get('separator', ' ')
+    return get_env(key, *default, coerce=lambda x: x.split(separator))
+
+
+def get_envdict(key, *default):
+    """Return env var as a dict."""
+    return get_env(key, *default, coerce=_dict)
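+
+
+# A quick self-contained sketch of the getters above (env var names and values are made up):
+if __name__ == "__main__":
+    os.environ["MY_WORKERS"] = "4"
+    os.environ["MY_DEBUG"] = "true"
+    assert get_envint("MY_WORKERS") == 4
+    assert get_envbool("MY_DEBUG") is True
+    # A missing variable falls back to the (uncoerced) default:
+    assert get_envlist("MY_PATHS", ["a", "b"], separator=":") == ["a", "b"]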
diff --git a/nemo/utils/formatters/__init__.py b/nemo/utils/formatters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e3250071955216f6abc505e6181fb59931baa8d
--- /dev/null
+++ b/nemo/utils/formatters/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nemo/utils/formatters/base.py b/nemo/utils/formatters/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4b6fef881e504bf3fd5363da422fcce8d2a36b2
--- /dev/null
+++ b/nemo/utils/formatters/base.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+from nemo.utils.formatters.colors import Fore as ForegroundColors
+from nemo.utils.formatters.utils import check_color_support, to_unicode
+
+__all__ = ["BaseNeMoFormatter"]
+
+
+class BaseFormatter(logging.Formatter):
+    """
+    Log formatter used in Tornado. Key features of this formatter are:
+    * Color support when logging to a terminal that supports it.
+    * Timestamps on every log line.
+    * Robust against str/bytes encoding problems.
+    """
+
+    DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s"
+
+    DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+    DEFAULT_COLORS = {
+        logging.DEBUG: ForegroundColors.CYAN,
+        logging.INFO: ForegroundColors.GREEN,
+        logging.WARNING: ForegroundColors.YELLOW,
+        logging.ERROR: ForegroundColors.MAGENTA,
+        logging.CRITICAL: ForegroundColors.RED,
+    }
+
+    def __init__(self, color=True, fmt=None, datefmt=None, colors=None):
+        r"""
+        :arg bool color: Enables color support.
+        :arg string fmt: Log message format.
+          It will be applied to the attributes dict of log records. The
+          text between ``%(color)s`` and ``%(end_color)s`` will be colored
+          depending on the level if color support is on.
+        :arg dict colors: color mappings from logging level to terminal color
+          code
+        :arg string datefmt: Datetime format.
+          Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``.
+        .. versionchanged:: 3.2
+           Added ``fmt`` and ``datefmt`` arguments.
+        """
+
+        if fmt is None:
+            fmt = self.DEFAULT_FORMAT
+
+        if datefmt is None:
+            datefmt = self.DEFAULT_DATE_FORMAT
+
+        if colors is None:
+            colors = self.DEFAULT_COLORS
+
+        logging.Formatter.__init__(self, datefmt=datefmt)
+
+        self._fmt = fmt
+        self._colors = {}
+        self._normal = ""
+
+        if color and check_color_support():
+            self._colors = colors
+            self._normal = ForegroundColors.RESET
+
+    def format(self, record):
+        try:
+            message = record.getMessage()
+            assert isinstance(message, str)  # guaranteed by logging
+            # Encoding notes:  The logging module prefers to work with character
+            # strings, but only enforces that log messages are instances of
+            # basestring.  In python 2, non-ascii bytestrings will make
+            # their way through the logging framework until they blow up with
+            # an unhelpful decoding error (with this formatter it happens
+            # when we attach the prefix, but there are other opportunities for
+            # exceptions further along in the framework).
+            #
+            # If a byte string makes it this far, convert it to unicode to
+            # ensure it will make it out to the logs.  Use repr() as a fallback
+            # to ensure that all byte strings can be converted successfully,
+            # but don't do it by default so we don't add extra quotes to ascii
+            # bytestrings.  This is a bit of a hacky place to do this, but
+            # it's worth it since the encoding errors that would otherwise
+            # result are so useless (and tornado is fond of using utf8-encoded
+            # byte strings wherever possible).
+            record.message = to_unicode(message)
+
+        except Exception as e:
+            record.message = "Bad message (%r): %r" % (e, record.__dict__)
+
+        record.asctime = self.formatTime(record, self.datefmt)
+
+        if record.levelno in self._colors:
+            record.color = self._colors[record.levelno]
+            record.end_color = self._normal
+        else:
+            record.color = record.end_color = ""
+
+        formatted = self._fmt % record.__dict__
+
+        if record.exc_info:
+            if not record.exc_text:
+                record.exc_text = self.formatException(record.exc_info)
+
+        if record.exc_text:
+            # exc_text contains multiple lines.  We need to _safe_unicode
+            # each line separately so that non-utf8 bytes don't cause
+            # all the newlines to turn into '\n'.
+            lines = [formatted.rstrip()]
+            lines.extend(to_unicode(ln) for ln in record.exc_text.split("\n"))
+
+            formatted = "\n".join(lines)
+        return formatted.replace("\n", "\n    ")
+
+
+class BaseNeMoFormatter(BaseFormatter):
+    DEFAULT_FORMAT = "%(color)s[NeMo %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s"
+
+
+class DebugNeMoFormatter(BaseFormatter):
+    DEFAULT_FORMAT = (
+        "%(color)s[NeMo %(levelname)1.1s %(asctime)s %(module)s:%(lineno)d rank:%(rank)s]%(end_color)s %(message)s"
+    )
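+
+
+# A minimal sketch attaching the formatter to a plain logger (handler choice and
+# logger name are illustrative):
+if __name__ == "__main__":
+    handler = logging.StreamHandler()
+    handler.setFormatter(BaseNeMoFormatter(color=False))
+    demo_logger = logging.getLogger("nemo_formatter_demo")
+    demo_logger.addHandler(handler)
+    demo_logger.warning("formatted message")  # [NeMo W <timestamp> <module>:<line>] formatted message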
diff --git a/nemo/utils/formatters/colors.py b/nemo/utils/formatters/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8eb616b3316479e3d094516155cd6aaae8981e0
--- /dev/null
+++ b/nemo/utils/formatters/colors.py
@@ -0,0 +1,121 @@
+# Source: https://github.com/tartley/colorama/blob/master/colorama/ansi.py
+# Copyright: Jonathan Hartley 2013. BSD 3-Clause license.
+#
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+CSI = "\033["
+OSC = "\033]"
+BEL = "\007"
+
+
+def code_to_chars(code):
+    return CSI + str(code) + "m"
+
+
+def set_title(title):
+    return OSC + "2;" + title + BEL
+
+
+def clear_screen(mode=2):
+    return CSI + str(mode) + "J"
+
+
+def clear_line(mode=2):
+    return CSI + str(mode) + "K"
+
+
+class AnsiCodes(object):
+    def __init__(self):
+        # the subclasses declare class attributes which are numbers.
+        # Upon instantiation we define instance attributes, which are the same
+        # as the class attributes but wrapped with the ANSI escape sequence
+        for name in dir(self):
+            if not name.startswith("_"):
+                value = getattr(self, name)
+                setattr(self, name, code_to_chars(value))
+
+
+class AnsiCursor(object):
+    def UP(self, n=1):
+        return CSI + str(n) + "A"
+
+    def DOWN(self, n=1):
+        return CSI + str(n) + "B"
+
+    def FORWARD(self, n=1):
+        return CSI + str(n) + "C"
+
+    def BACK(self, n=1):
+        return CSI + str(n) + "D"
+
+    def POS(self, x=1, y=1):
+        return CSI + str(y) + ";" + str(x) + "H"
+
+
+class AnsiFore(AnsiCodes):
+    BLACK = 30
+    RED = 31
+    GREEN = 32
+    YELLOW = 33
+    BLUE = 34
+    MAGENTA = 35
+    CYAN = 36
+    WHITE = 37
+    RESET = 39
+
+    # These are fairly well supported, but not part of the standard.
+    LIGHTBLACK_EX = 90
+    LIGHTRED_EX = 91
+    LIGHTGREEN_EX = 92
+    LIGHTYELLOW_EX = 93
+    LIGHTBLUE_EX = 94
+    LIGHTMAGENTA_EX = 95
+    LIGHTCYAN_EX = 96
+    LIGHTWHITE_EX = 97
+
+
+class AnsiBack(AnsiCodes):
+    BLACK = 40
+    RED = 41
+    GREEN = 42
+    YELLOW = 43
+    BLUE = 44
+    MAGENTA = 45
+    CYAN = 46
+    WHITE = 47
+    RESET = 49
+
+    # These are fairly well supported, but not part of the standard.
+    LIGHTBLACK_EX = 100
+    LIGHTRED_EX = 101
+    LIGHTGREEN_EX = 102
+    LIGHTYELLOW_EX = 103
+    LIGHTBLUE_EX = 104
+    LIGHTMAGENTA_EX = 105
+    LIGHTCYAN_EX = 106
+    LIGHTWHITE_EX = 107
+
+
+class AnsiStyle(AnsiCodes):
+    BRIGHT = 1
+    DIM = 2
+    NORMAL = 22
+    RESET_ALL = 0
+
+
+Fore = AnsiFore()
+Back = AnsiBack()
+Style = AnsiStyle()
+Cursor = AnsiCursor()
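+
+
+# A tiny usage sketch (requires a terminal that understands ANSI escapes):
+if __name__ == "__main__":
+    print(Fore.GREEN + "ok" + Fore.RESET)
+    print(Style.BRIGHT + "emphasized" + Style.RESET_ALL)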
diff --git a/nemo/utils/formatters/utils.py b/nemo/utils/formatters/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..570f75a7406651722899c7f538f856696df71ce6
--- /dev/null
+++ b/nemo/utils/formatters/utils.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+
+from nemo.constants import NEMO_ENV_VARNAME_ENABLE_COLORING
+from nemo.utils.env_var_parsing import get_envbool
+
+__all__ = ["check_color_support", "to_unicode"]
+
+
+def check_color_support():
+    # Colors can be forced with an env variable
+    if not sys.platform.lower().startswith("win") and get_envbool(NEMO_ENV_VARNAME_ENABLE_COLORING, False):
+        return True
+    return False
+
+
+def to_unicode(value):
+    """
+    Converts a string argument to a unicode string.
+    If the argument is already a unicode string or None, it is returned
+    unchanged.  Otherwise it must be a byte string and is decoded as utf8.
+    """
+    try:
+        if isinstance(value, (str, type(None))):
+            return value
+
+        if not isinstance(value, bytes):
+            raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
+
+        return value.decode("utf-8")
+
+    except UnicodeDecodeError:
+        return repr(value)
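+
+
+# A quick sketch of to_unicode's behavior on the three input kinds it handles:
+if __name__ == "__main__":
+    assert to_unicode("text") == "text"          # str passes through
+    assert to_unicode(b"bytes") == "bytes"       # bytes are decoded as utf-8
+    assert to_unicode(b"\xff").startswith("b'")  # undecodable bytes fall back to repr()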
diff --git a/nemo/utils/get_rank.py b/nemo/utils/get_rank.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b36eed6246b95754bdab3391f2a523fcd6e159a
--- /dev/null
+++ b/nemo/utils/get_rank.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nemo.utils.env_var_parsing import get_envint
+
+
+def is_global_rank_zero():
+    """ Helper function to determine if the current process is global_rank 0 (the main process)
+    """
+    # Try to get the pytorch RANK env var
+    # RANK is set by torch.distributed.launch
+    rank = get_envint("RANK", None)
+    if rank is not None:
+        return rank == 0
+
+    # Try to get the SLURM global rank env var
+    # SLURM_PROCID is set by SLURM
+    slurm_rank = get_envint("SLURM_PROCID", None)
+    if slurm_rank is not None:
+        return slurm_rank == 0
+
+    # if neither the pytorch nor the SLURM env vars are set,
+    # check the NODE_RANK/GROUP_RANK and LOCAL_RANK env vars;
+    # assume global_rank is zero if undefined
+    node_rank = get_envint("NODE_RANK", get_envint("GROUP_RANK", 0))
+    local_rank = get_envint("LOCAL_RANK", 0)
+    return node_rank == 0 and local_rank == 0
+
+
+def get_rank():
+    """ Helper function that returns torch.distributed.get_rank() if DDP has been initialized otherwise it returns 0.
+    """
+
+    if is_global_rank_zero():
+        return 0
+    else:
+        return torch.distributed.get_rank()
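+
+
+# A minimal sketch: outside any launcher, RANK/SLURM_PROCID/NODE_RANK/LOCAL_RANK are
+# unset, so a single process is treated as global rank zero.
+if __name__ == "__main__":
+    print(is_global_rank_zero())  # True in the single-process case
+    print(get_rank())             # 0 in the same case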
diff --git a/nemo/utils/lightning_logger_patch.py b/nemo/utils/lightning_logger_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b21ce3b1ae5132a20296c44e090e3e4dc40fd26
--- /dev/null
+++ b/nemo/utils/lightning_logger_patch.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging as _logging
+from logging.handlers import MemoryHandler
+
+import pytorch_lightning as pl
+
+HANDLERS = {}
+PATCHED = False
+
+
+def add_memory_handlers_to_pl_logger():
+    """
+    Adds two MemoryHandlers to pytorch_lightning's logger. These two handlers are essentially message buffers. This
+    function is called in nemo.utils.__init__.py. These handlers are used in add_filehandlers_to_pl_logger to flush
+    buffered messages to files.
+    """
+    if not HANDLERS:
+        HANDLERS["memory_err"] = MemoryHandler(-1)
+        HANDLERS["memory_err"].addFilter(lambda record: record.levelno > _logging.INFO)
+        HANDLERS["memory_all"] = MemoryHandler(-1)
+        pl._logger.addHandler(HANDLERS["memory_err"])
+        pl._logger.addHandler(HANDLERS["memory_all"])
+
+
+def add_filehandlers_to_pl_logger(all_log_file, err_log_file):
+    """
+    Adds two filehandlers to pytorch_lightning's logger. Called in nemo.utils.exp_manager(). The first filehandler
+    logs all messages to all_log_file while the second filehandler logs all WARNING and higher messages to err_log_file.
+    If "memory_err" and "memory_all" exist in HANDLERS, then those buffers are flushed to err_log_file and all_log_file
+    respectively, and then closed.
+    """
+    HANDLERS["file"] = _logging.FileHandler(all_log_file)
+    pl._logger.addHandler(HANDLERS["file"])
+    HANDLERS["file_err"] = _logging.FileHandler(err_log_file)
+    HANDLERS["file_err"].addFilter(lambda record: record.levelno > _logging.INFO)
+    pl._logger.addHandler(HANDLERS["file_err"])
+
+    if HANDLERS.get("memory_all", None):
+        HANDLERS["memory_all"].setTarget(HANDLERS["file"])
+        HANDLERS["memory_all"].close()
+        del HANDLERS["memory_all"]
+    if HANDLERS.get("memory_err", None):
+        HANDLERS["memory_err"].setTarget(HANDLERS["file_err"])
+        HANDLERS["memory_err"].close()
+        del HANDLERS["memory_err"]
diff --git a/nemo/utils/metaclasses.py b/nemo/utils/metaclasses.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fad7cb1501306426c1ab30c0367ddd49583ab8f
--- /dev/null
+++ b/nemo/utils/metaclasses.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading
+
+
+class Singleton(type):
+    """ Implementation of a generic, tread-safe singleton meta-class.
+        Can be used as meta-class, i.e. will create 
+    """
+
+    # Dictionary of instances - one per class.
+    __instances = {}
+    # Lock used for accessing the instance.
+    __lock = threading.Lock()
+
+    def __call__(cls, *args, **kwargs):
+        """ Returns singleton instance. A thread safe implementation. """
+        if cls not in cls.__instances:
+            # Enter critical section.
+            with cls.__lock:
+                # Check once again.
+                if cls not in cls.__instances:
+                    # Create a new object instance - one per class.
+                    cls.__instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        # Return the instance.
+        return cls.__instances[cls]
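+
+
+# A short self-contained sketch of the metaclass in use (the Config class is made up):
+if __name__ == "__main__":
+
+    class Config(metaclass=Singleton):
+        def __init__(self):
+            self.value = 42
+
+    assert Config() is Config()  # one shared instance per class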
diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a13c95631e74eaa1ccd68a90f9771f4b955cf82
--- /dev/null
+++ b/nemo/utils/model_utils.py
@@ -0,0 +1,633 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import importlib
+import os
+from dataclasses import dataclass, is_dataclass
+from enum import Enum
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import wrapt
+
+from nemo.utils import AppState, logging
+from nemo.utils.data_utils import resolve_cache_dir  # imported for compatibility: model_utils.resolve_cache_dir()
+from nemo.utils.data_utils import is_datastore_path
+
+# TODO @blisc: Perhaps refactor instead of import guarding
+
+_HAS_HYDRA = True
+
+try:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+    from omegaconf import errors as omegaconf_errors
+    from packaging import version
+except ModuleNotFoundError:
+    _HAS_HYDRA = False
+
+
+_VAL_TEST_FASTPATH_KEY = 'ds_item'
+
+
+class ArtifactPathType(Enum):
+    """
+    ArtifactPathType refers to the type of the path that the artifact is located at.
+
+    LOCAL_PATH: A user local filepath that exists on the file system.
+    TAR_PATH: A (generally flattened) filepath that exists inside of an archive (that may have its own full path).
+    """
+
+    LOCAL_PATH = 0
+    TAR_PATH = 1
+
+
+@dataclass(init=False)
+class ArtifactItem:
+    path: str
+    path_type: ArtifactPathType
+    hashed_path: Optional[str] = None
+
+
+def resolve_dataset_name_from_cfg(cfg: 'DictConfig') -> Optional[str]:
+    """
+    Parses items of the provided sub-config to find the first potential key that
+    resolves to an existing file or directory.
+
+    # Fast-path Resolution
+    In order to handle cases where we need to resolve items that are not paths, a fastpath
+    key can be provided as defined in the global `_VAL_TEST_FASTPATH_KEY`.
+
+    This key can be used in two ways :
+
+    ## _VAL_TEST_FASTPATH_KEY points to another key in the config
+
+    If this _VAL_TEST_FASTPATH_KEY points to another key in this config itself,
+    then we assume we want to loop through the values of that key.
+
+    This allows for any key in the config to become a fastpath key.
+
+    Example:
+    validation_ds:
+        splits: "val"
+        ...
+        <_VAL_TEST_FASTPATH_KEY>: "splits"  <-- this points to the key name "splits"
+
+    Then we can write the following when overriding in hydra:
+    ```python
+    python train_file.py ... \
+        model.validation_ds.splits=[val1, val2, dev1, dev2] ...
+    ```
+
+    ## _VAL_TEST_FASTPATH_KEY itself acts as the resolved key
+
+    If this _VAL_TEST_FASTPATH_KEY does not point to another key in the config, then
+    it is assumed that the items of this key itself are used for resolution.
+
+    Example:
+    validation_ds:
+        ...
+        <_VAL_TEST_FASTPATH_KEY>: "val"  <-- this points to the key name "splits"
+
+    Then we can write the following when overriding in hydra:
+    ```python
+    python train_file.py ... \
+        model.validation_ds.<_VAL_TEST_FASTPATH_KEY>=[val1, val2, dev1, dev2] ...
+    ```
+
+    # IMPORTANT NOTE:
+    It *can* potentially mismatch if more than one valid path exists, and the
+    first key does *not* resolve to the path of the data file (but does resolve
+    to some other valid path).
+
+    To avoid this side-effect, place the data path as the first item on the config file.
+
+    Args:
+        cfg: DictConfig (Sub-config) that should be parsed.
+
+    Returns:
+        A str representing the `key` of the config which hosts the filepath(s),
+        or None in case path could not be resolved.
+    """
+    if _VAL_TEST_FASTPATH_KEY in cfg and cfg[_VAL_TEST_FASTPATH_KEY] is not None:
+        fastpath_key = cfg[_VAL_TEST_FASTPATH_KEY]
+
+        if isinstance(fastpath_key, str) and fastpath_key in cfg:
+            return cfg[fastpath_key]
+        else:
+            return _VAL_TEST_FASTPATH_KEY
+
+    for key, value in cfg.items():
+        if type(value) in [list, tuple, ListConfig]:
+            # Count the number of valid paths in the list
+            values_are_paths = 0
+            for val_i in value:
+                val_i = str(val_i)
+                if os.path.exists(val_i) or os.path.isdir(val_i) or is_datastore_path(val_i):
+                    values_are_paths += 1
+                else:
+                    # reset counter and break inner loop
+                    break
+
+            if values_are_paths == len(value):
+                return key
+
+        else:
+            if os.path.exists(str(value)) or os.path.isdir(str(value)) or is_datastore_path(str(value)):
+                return key
+
+    return None
+
+
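+# Illustrative example (hypothetical paths; assumes /data/val_manifest.json exists):
+#
+#   cfg = OmegaConf.create({"manifest_filepath": "/data/val_manifest.json", "batch_size": 32})
+#   resolve_dataset_name_from_cfg(cfg)  # -> "manifest_filepath"
+
+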
+def parse_dataset_as_name(name: str) -> str:
+    """
+    Constructs a valid prefix-name from a provided file path.
+
+    Args:
+        name: str path to some valid data/manifest file or a python object that
+            will be used as a name for the data loader (via str() cast).
+
+    Returns:
+        str prefix used to identify uniquely this data/manifest file.
+    """
+    if os.path.exists(str(name)) or os.path.isdir(str(name)) or is_datastore_path(str(name)):
+        name = Path(name).stem
+    else:
+        name = str(name)
+
+    # cleanup name
+    name = name.replace('-', '_')
+
+    if 'manifest' in name:
+        name = name.replace('manifest', '')
+
+    if 'dataset' in name:
+        name = name.replace('dataset', '')
+
+    # Test if the manifest/dataset name was simply `manifest.json` or `dataset.json`: invalid names.
+    if name == '':
+        raise ValueError(
+            "Provided dataset / manifest filename was `manifest.json` or `dataset.json`.\n"
+            "Such a name is invalid, since multiple datasets/manifests can share the same name,\n"
+            "thereby overriding their results during logging. Please pick a more discriptive filename \n"
+            "for the provided dataset / manifest file."
+        )
+
+    if '_' != name[-1]:
+        name = name + '_'
+
+    return name
+
+
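+# Illustrative example (hypothetical path; assumes the file exists on disk):
+#
+#   parse_dataset_as_name("/data/val_manifest.json")  # -> "val_"
+
+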
+def unique_names_check(name_list: Optional[List[str]]):
+    """
+    Performs a uniqueness check on the name list resolved, so that it can warn users
+    about non-unique keys.
+
+    Args:
+        name_list: List of strings resolved for data loaders.
+    """
+    if name_list is None:
+        return
+
+    # Name uniqueness checks
+    names = set()
+    for name in name_list:
+        if name in names:
+            logging.warning(
+                "Name resolution has found more than one data loader having the same name !\n"
+                "In such cases, logs will nor be properly generated. "
+                "Please rename the item to have unique names.\n"
+                f"Resolved name : {name}"
+            )
+        else:
+            names.add(name)  # we need just hash key check, value is just a placeholder
+
+
+def resolve_validation_dataloaders(model: 'ModelPT'):
+    """
+    Helper method that operates on the ModelPT class to automatically support
+    multiple dataloaders for the validation set.
+
+    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`.
+    If this resolution fails, it assumes the data loader is prepared to manually support / not support
+    multiple data loaders and simply calls the appropriate setup method.
+
+    If resolution succeeds:
+        Checks if provided path is to a single file or a list of files.
+        If a single file is provided, simply tags that file as such and loads it via the setup method.
+        If multiple files are provided:
+            Injects a new manifest path at index "i" into the resolved key.
+            Calls the appropriate setup method to set the data loader.
+            Collects the initialized data loader in a list and preserves it.
+            Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
+            Finally assigns a list of unique names resolved from the file paths to the ModelPT.
+
+    Args:
+        model: ModelPT subclass, which requires >=1 Validation Dataloaders to be setup.
+    """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
+        exit(1)
+    cfg = copy.deepcopy(model._cfg)
+    dataloaders = []
+
+    # Process val_dl_idx
+    if 'val_dl_idx' in cfg.validation_ds:
+        cfg = OmegaConf.to_container(cfg)
+        val_dl_idx = cfg['validation_ds'].pop('val_dl_idx')
+        cfg = OmegaConf.create(cfg)
+    else:
+        val_dl_idx = 0
+
+    # Set val_dl_idx
+    model._val_dl_idx = val_dl_idx
+
+    ds_key = resolve_dataset_name_from_cfg(cfg.validation_ds)
+
+    if ds_key is None or val_dl_idx < 0:
+        logging.debug(
+            "Could not resolve file path from provided config - {}. "
+            "Disabling support for multi-dataloaders.".format(cfg.validation_ds)
+        )
+
+        model.setup_validation_data(cfg.validation_ds)
+        return
+
+    ds_values = cfg.validation_ds[ds_key]
+
+    if isinstance(ds_values, (list, tuple, ListConfig)):
+
+        for ds_value in ds_values:
+            if isinstance(ds_value, (dict, DictConfig)):
+                # this is a nested dataset
+                cfg.validation_ds = ds_value
+            else:
+                cfg.validation_ds[ds_key] = ds_value
+
+            model.setup_validation_data(cfg.validation_ds)
+            dataloaders.append(model._validation_dl)
+
+        model._validation_dl = dataloaders
+        if len(ds_values) > 0 and isinstance(ds_values[0], (dict, DictConfig)):
+            # use the name of each nested dataset
+            model._validation_names = [ds.name for ds in ds_values]
+        else:
+            model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values]
+        unique_names_check(name_list=model._validation_names)
+        return
+
+    else:
+        model.setup_validation_data(cfg.validation_ds)
+        model._validation_names = [parse_dataset_as_name(ds_values)]
+        unique_names_check(name_list=model._validation_names)
+
+
+def resolve_test_dataloaders(model: 'ModelPT'):
+    """
+    Helper method that operates on the ModelPT class to automatically support
+    multiple dataloaders for the test set.
+
+    It does so by first resolving the path to one/more data files via `resolve_dataset_name_from_cfg()`.
+    If this resolution fails, it assumes the data loader is prepared to manually support / not support
+    multiple data loaders and simply calls the appropriate setup method.
+
+    If resolution succeeds:
+        Checks if provided path is to a single file or a list of files.
+        If a single file is provided, simply tags that file as such and loads it via the setup method.
+        If multiple files are provided:
+            Injects a new manifest path at index "i" into the resolved key.
+            Calls the appropriate setup method to set the data loader.
+            Collects the initialized data loader in a list and preserves it.
+            Once all data loaders are processed, assigns the list of loaded loaders to the ModelPT.
+            Finally assigns a list of unique names resolved from the file paths to the ModelPT.
+
+    Args:
+        model: ModelPT subclass, which requires >=1 Test Dataloaders to be setup.
+    """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
+        exit(1)
+    cfg = copy.deepcopy(model._cfg)
+    dataloaders = []
+
+    # process test_loss_idx
+    if 'test_dl_idx' in cfg.test_ds:
+        cfg = OmegaConf.to_container(cfg)
+        test_dl_idx = cfg['test_ds'].pop('test_dl_idx')
+        cfg = OmegaConf.create(cfg)
+    else:
+        test_dl_idx = 0
+
+    # Set test_dl_idx
+    model._test_dl_idx = test_dl_idx
+
+    ds_key = resolve_dataset_name_from_cfg(cfg.test_ds)
+
+    if ds_key is None:
+        logging.debug(
+            "Could not resolve file path from provided config - {}. "
+            "Disabling support for multi-dataloaders.".format(cfg.test_ds)
+        )
+
+        model.setup_test_data(cfg.test_ds)
+        return
+
+    ds_values = cfg.test_ds[ds_key]
+
+    if isinstance(ds_values, (list, tuple, ListConfig)):
+
+        for ds_value in ds_values:
+            if isinstance(ds_value, (dict, DictConfig)):
+                # this is a nested dataset
+                cfg.test_ds = ds_value
+            else:
+                cfg.test_ds[ds_key] = ds_value
+
+            model.setup_test_data(cfg.test_ds)
+            dataloaders.append(model._test_dl)
+
+        model._test_dl = dataloaders
+        if len(ds_values) > 0 and isinstance(ds_values[0], (dict, DictConfig)):
+            # use the name of each nested dataset
+            model._test_names = [ds.name for ds in ds_values]
+        else:
+            model._test_names = [parse_dataset_as_name(ds) for ds in ds_values]
+
+        unique_names_check(name_list=model._test_names)
+        return
+
+    else:
+        model.setup_test_data(cfg.test_ds)
+        model._test_names = [parse_dataset_as_name(ds_values)]
+
+        unique_names_check(name_list=model._test_names)
+
+
+@wrapt.decorator
+def wrap_training_step(wrapped, instance: 'pl.LightningModule', args, kwargs):
+    output_dict = wrapped(*args, **kwargs)
+
+    if isinstance(output_dict, dict) and output_dict is not None and 'log' in output_dict:
+        log_dict = output_dict.pop('log')
+        instance.log_dict(log_dict, on_step=True)
+
+    return output_dict
+
+
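+# Illustrative usage (hypothetical LightningModule): `wrap_training_step` pops the
+# legacy `log` dict from the step output and forwards it to `self.log_dict(...)`:
+#
+#   class MyModel(pl.LightningModule):
+#       @wrap_training_step
+#       def training_step(self, batch, batch_idx):
+#           loss = ...
+#           return {"loss": loss, "log": {"train_loss": loss}}
+
+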
+def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig']) -> 'DictConfig':
+    """
+    Converts its input into a standard DictConfig.
+    Possible input values are:
+    -   DictConfig
+    -   A dataclass which is a subclass of NemoConfig
+
+    Args:
+        cfg: A dict-like object.
+
+    Returns:
+        The equivalent DictConfig
+    """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
+        exit(1)
+    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
+        cfg = OmegaConf.structured(cfg)
+
+    if not isinstance(cfg, DictConfig):
+        raise ValueError(f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead.")
+
+    config = OmegaConf.to_container(cfg, resolve=True)
+    config = OmegaConf.create(config)
+    return config
+
+
+def _convert_config(cfg: 'OmegaConf'):
+    """ Recursive function convertint the configuration from old hydra format to the new one. """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
+        exit(1)
+
+    # Get rid of cls -> _target_.
+    if 'cls' in cfg and '_target_' not in cfg:
+        cfg._target_ = cfg.pop('cls')
+
+    # Get rid of params.
+    if 'params' in cfg:
+        params = cfg.pop('params')
+        for param_key, param_val in params.items():
+            cfg[param_key] = param_val
+
+    # Recursion.
+    try:
+        for _, sub_cfg in cfg.items():
+            if isinstance(sub_cfg, DictConfig):
+                _convert_config(sub_cfg)
+    except omegaconf_errors.OmegaConfBaseException as e:
+        logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
+
+
+def maybe_update_config_version(cfg: 'DictConfig'):
+    """
+    Recursively convert Hydra 0.x configs to Hydra 1.x configs.
+
+    Changes include:
+    -   `cls` -> `_target_`.
+    -   `params` -> drop params and shift all arguments to parent.
+    -   `target` -> `_target_` cannot be performed due to ModelPT injecting `target` inside class.
+
+    Args:
+        cfg: Any Hydra compatible DictConfig
+
+    Returns:
+        An updated DictConfig that conforms to Hydra 1.x format.
+    """
+    if not _HAS_HYDRA:
+        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
+        exit(1)
+    if cfg is not None and not isinstance(cfg, DictConfig):
+        try:
+            temp_cfg = OmegaConf.create(cfg)
+            cfg = temp_cfg
+        except omegaconf_errors.OmegaConfBaseException:
+            # Cannot be cast to DictConfig, skip updating.
+            return cfg
+
+    # Make a copy of model config.
+    cfg = copy.deepcopy(cfg)
+    OmegaConf.set_struct(cfg, False)
+
+    # Convert config.
+    _convert_config(cfg)
+
+    # Update model config.
+    OmegaConf.set_struct(cfg, True)
+
+    return cfg
+
+
+@lru_cache(maxsize=1024)
+def import_class_by_path(path: str):
+    """
+    Imports a class given its dotted path string (results are cached via lru_cache).
+    """
+    paths = path.split('.')
+    path = ".".join(paths[:-1])
+    class_name = paths[-1]
+    mod = __import__(path, fromlist=[class_name])
+    mod = getattr(mod, class_name)
+    return mod
+
+
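+# Illustrative example (standard-library class, for clarity):
+#
+#   cls = import_class_by_path("collections.OrderedDict")
+#   assert cls is __import__("collections").OrderedDict
+
+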
+def resolve_subclass_pretrained_model_info(base_class) -> List['PretrainedModelInfo']:
+    """
+    Recursively traverses the inheritance graph of subclasses to extract all pretrained model info.
+    First constructs a set of unique pretrained model info by performing DFS over the inheritance graph.
+    All model info belonging to the same class is added together.
+
+    Args:
+        base_class: The root class, whose subclass graph will be traversed.
+
+    Returns:
+        A list of unique pretrained model infos belonging to all of the inherited subclasses of
+        this baseclass.
+    """
+    list_of_models = set()
+
+    def recursive_subclass_walk(cls):
+        for subclass in cls.__subclasses__():
+            # step into its immediate subclass
+            recursive_subclass_walk(subclass)
+
+            subclass_models = subclass.list_available_models()
+
+            if subclass_models is not None and len(subclass_models) > 0:
+                # Inject subclass info into pretrained model info
+                # if not already overridden by the subclass
+                for model_info in subclass_models:
+                    # If the subclass manually injects class_, don't override it.
+                    if model_info.class_ is None:
+                        model_info.class_ = subclass
+
+                for model_info in subclass_models:
+                    list_of_models.add(model_info)
+
+    recursive_subclass_walk(base_class)
+
+    list_of_models = list(sorted(list_of_models))
+    return list_of_models
+
+
+def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Optional[bool], str]:
+    """
+    Checks if a library is installed, and if it is, evaluates operator(lib.__version__, checked_version).
+    That boolean result is returned along with a string analysis of the check.
+
+    If the library is not installed at all, returns None instead, along with a string explaining
+    that the library is not installed.
+
+    Args:
+        lib_name: lower case str name of the library that must be imported.
+        checked_version: semver string that is compared against lib.__version__.
+        operator: binary callable function func(a, b) -> bool; that compares lib.__version__ against version in
+            some manner. Must return a boolean.
+
+    Returns:
+        A tuple of results:
+        -   Bool or None. Bool if the library could be imported, and the result of
+            operator(lib.__version__, checked_version) or False if __version__ is not implemented in lib.
+            None is passed if the library is not installed at all.
+        -   A string analysis of the check.
+    """
+    try:
+        if '.' in lib_name:
+            mod = import_class_by_path(lib_name)
+        else:
+            mod = importlib.import_module(lib_name)
+
+        if hasattr(mod, '__version__'):
+            lib_ver = version.Version(mod.__version__)
+            match_ver = version.Version(checked_version)
+
+            if operator(lib_ver, match_ver):
+                msg = f"Lib {lib_name} version is satisfied !"
+                return True, msg
+            else:
+                msg = (
+                    f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} than required version {checked_version}.\n"
+                    f"Please upgrade the lib using either pip or conda to the latest version."
+                )
+                return False, msg
+        else:
+            msg = (
+                f"Lib {lib_name} does not implement __version__ in its init file. "
+                f"Could not check version compatibility."
+            )
+            return False, msg
+    except (AttributeError, ImportError, ModuleNotFoundError):
+        pass
+
+    msg = f"Lib {lib_name} has not been installed. Please use pip or conda to install this package."
+    return None, msg
+
+
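+# Illustrative usage (assumes numpy is installed):
+#
+#   import operator
+#   ok, msg = check_lib_version("numpy", "1.18.0", operator=operator.ge)
+#   # ok: True if numpy.__version__ >= 1.18.0, False otherwise, None if numpy is absent
+
+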
+def uninject_model_parallel_rank(filepath):
+    filepath = str(filepath)
+    if 'mp_rank' in filepath or 'tp_rank' in filepath:
+        dirname = os.path.dirname(os.path.dirname(filepath))
+        basename = os.path.basename(filepath)
+        filepath = os.path.join(dirname, basename)
+        return filepath
+    else:
+        return filepath
+
+
+def inject_model_parallel_rank(filepath):
+    """
+    Injects tensor/pipeline model parallel ranks into the filepath.
+    Does nothing if not using model parallelism.
+    """
+    # first make sure filepath does not have rank
+    filepath = uninject_model_parallel_rank(filepath)
+
+    app_state = AppState()
+    if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
+        # filepath needs to be updated to include mp_rank
+        dirname = os.path.dirname(filepath)
+        basename = os.path.basename(filepath)
+        if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1:
+            filepath = f'{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}/{basename}'
+        else:
+            filepath = f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}'
+        return filepath
+    else:
+        return filepath
+
+
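+# Illustrative example (hypothetical paths; assumes model parallelism is enabled
+# with tensor_model_parallel_rank=1 and no pipeline parallelism):
+#
+#   inject_model_parallel_rank("/ckpts/model.ckpt")                # -> "/ckpts/mp_rank_01/model.ckpt"
+#   uninject_model_parallel_rank("/ckpts/mp_rank_01/model.ckpt")   # -> "/ckpts/model.ckpt"
+
+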
+def ckpt_to_dir(filepath: Union[str, Path]) -> Path:
+    """ PTL considers checkpoints as .ckpt files.
+        This method removes the extension and returns a path
+        to be used as a directory for distributed checkpoints
+    """
+
+    filepath = Path(filepath)
+
+    # adding this assert because we will later remove directories based on the return value of this method
+    assert filepath.suffix == ".ckpt", f'filepath: {filepath} must have .ckpt extension'
+
+    # create a new path whose name is the original filepath without the .ckpt extension
+    checkpoint_dir = filepath.with_name(filepath.stem)
+
+    return checkpoint_dir
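+
+
+# Illustrative example (hypothetical path):
+#
+#   ckpt_to_dir("/ckpts/step100.ckpt")  # -> Path("/ckpts/step100")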
diff --git a/nemo/utils/nemo_logging.py b/nemo/utils/nemo_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e17e5c5f6c43716c7f40d5cbf7b07c9efb420f
--- /dev/null
+++ b/nemo/utils/nemo_logging.py
@@ -0,0 +1,421 @@
+# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import enum
+import logging as _logging
+import sys
+import threading
+import warnings
+from contextlib import contextmanager
+from logging.handlers import MemoryHandler
+
+from nemo.constants import NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, NEMO_ENV_VARNAME_TESTING
+from nemo.utils.env_var_parsing import get_envbool
+from nemo.utils.formatters.base import BaseNeMoFormatter, DebugNeMoFormatter
+from nemo.utils.get_rank import is_global_rank_zero
+from nemo.utils.metaclasses import Singleton
+
+__all__ = ["Logger", "LogMode"]
+
+
+class LogMode(enum.IntEnum):
+    EACH = 0  # Log the message each time
+    ONCE = 1  # Log the message only once. The same message will not be logged again.
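+
+
+# Illustrative usage (hypothetical message): the helper methods below accept
+# `mode=LogMode.ONCE`, e.g. `logging.warning("slow path", mode=LogMode.ONCE)`
+# emits the message a single time per process.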
+
+
+class Logger(metaclass=Singleton):
+
+    # Level 0
+    NOTSET = _logging.NOTSET
+
+    # Level 10
+    DEBUG = _logging.DEBUG
+
+    # Level 20
+    INFO = _logging.INFO
+
+    # Level 30
+    WARNING = _logging.WARNING
+
+    # Level 40
+    ERROR = _logging.ERROR
+
+    # Level 50
+    CRITICAL = _logging.CRITICAL
+
+    _level_names = {
+        0: "NOTSET",
+        10: "DEBUG",
+        20: "INFO",
+        30: "WARNING",
+        40: "ERROR",
+        50: "CRITICAL",
+    }
+
+    def __init__(self, capture_warnings=True):
+
+        self._logger = None
+        # Multi-GPU runs run in separate processes, so thread locks shouldn't be needed
+        self._logger_lock = threading.Lock()
+        self._handlers = dict()
+        self.old_warnings_showwarning = None
+        self._define_logger(capture_warnings)
+        self.once_logged = set()
+        self.rank = 0 if is_global_rank_zero() else "UNK"
+
+    def _define_logger(self, capture_warnings=True):
+        """ Creates the logger if not already created. Called in init"""
+
+        # Use double-checked locking to avoid taking lock unnecessarily.
+        if self._logger is not None:
+            return self._logger
+
+        with self._logger_lock:
+            try:
+                self._logger = _logging.getLogger("nemo_logger")
+                # By default, silence all loggers except the logger for rank 0
+                self.remove_stream_handlers()
+                # If NEMO_TESTING is set, add a streamhandler to all ranks
+                if get_envbool(NEMO_ENV_VARNAME_TESTING, False):
+                    old_factory = _logging.getLogRecordFactory()
+
+                    def record_factory(*args, **kwargs):
+                        record = old_factory(*args, **kwargs)
+                        record.rank = self.rank
+                        return record
+
+                    _logging.setLogRecordFactory(record_factory)
+                    self.add_stream_handlers(formatter=DebugNeMoFormatter)
+                elif is_global_rank_zero():
+                    self.add_stream_handlers()
+
+                # Add memoryhandlers, essentially buffers. They are used to save messages that we will flush to file
+                # once the appropriate file handlers are added.
+                if is_global_rank_zero():
+                    # Add a memoryhandler for error messages. Only logged on rank 0
+                    self._handlers["memory_err"] = MemoryHandler(-1)
+                    self._handlers["memory_err"].addFilter(lambda record: record.levelno > _logging.INFO)
+                    formatter = BaseNeMoFormatter
+                    self._handlers["memory_err"].setFormatter(formatter())
+                    self._logger.addHandler(self._handlers["memory_err"])
+                # Add a memoryhandler for all messages on all ranks
+                self._handlers["memory_all"] = MemoryHandler(-1)
+                formatter = BaseNeMoFormatter
+                self._handlers["memory_all"].setFormatter(formatter())
+                self._logger.addHandler(self._handlers["memory_all"])
+
+            finally:
+                level = Logger.INFO
+                if get_envbool(NEMO_ENV_VARNAME_TESTING, False):
+                    level = Logger.DEBUG
+                self.set_verbosity(verbosity_level=level)
+                self.captureWarnings(capture_warnings)
+
+        self._logger.propagate = False
+
+    def remove_stream_handlers(self):
+        """ Removes StreamHandler that log to stdout and stderr from the logger."""
+        if self._logger is None:
+            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")
+
+        # ======== Remove Handler if already existing ========
+
+        try:
+            self._logger.removeHandler(self._handlers["stream_stdout"])
+            del self._handlers["stream_stdout"]
+        except KeyError:
+            pass
+
+        try:
+            self._logger.removeHandler(self._handlers["stream_stderr"])
+            del self._handlers["stream_stderr"]
+        except KeyError:
+            pass
+
+    def add_stream_handlers(self, formatter=BaseNeMoFormatter):
+        """Add StreamHandler that log to stdout and stderr to the logger. INFO and lower logs are streamed to stdout
+        while WARNING and higher are streamed to stderr. If the NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR environment
+        variable is set, all logs are sent to stderr instead.
+        """
+        if self._logger is None:
+            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")
+
+        # Add the output handler.
+        if get_envbool(NEMO_ENV_VARNAME_REDIRECT_LOGS_TO_STDERR, False):
+            self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stderr)
+
+        else:
+            self._handlers["stream_stdout"] = _logging.StreamHandler(sys.stdout)
+            self._handlers["stream_stdout"].addFilter(lambda record: record.levelno <= _logging.INFO)
+
+            self._handlers["stream_stderr"] = _logging.StreamHandler(sys.stderr)
+            self._handlers["stream_stderr"].addFilter(lambda record: record.levelno > _logging.INFO)
+
+        self._handlers["stream_stdout"].setFormatter(formatter())
+        self._logger.addHandler(self._handlers["stream_stdout"])
+
+        try:
+            self._handlers["stream_stderr"].setFormatter(formatter())
+            self._logger.addHandler(self._handlers["stream_stderr"])
+        except KeyError:
+            pass
+
+    def reset_stream_handler(self, formatter=BaseNeMoFormatter):
+        """Removes then adds stream handlers."""
+        self.remove_stream_handlers()
+        self.add_stream_handlers(formatter=formatter)
+
+    def add_file_handler(self, log_file):
+        """Add a FileHandler to logger that logs all messages to a file. If the logger had a MemoryHandler at
+        self._handlers["memory_all"], those buffered messages are flushed to the new file, and the MemoryHandler is
+        closed."""
+        if self._logger is None:
+            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")
+
+        self._handlers["file"] = _logging.FileHandler(log_file)
+        formatter = BaseNeMoFormatter
+        self._handlers["file"].setFormatter(formatter())
+        self._logger.addHandler(self._handlers["file"])
+
+        if self._handlers.get("memory_all", None):
+            self._handlers["memory_all"].setTarget(self._handlers["file"])
+            self._handlers["memory_all"].close()  # flush and remove
+            del self._handlers["memory_all"]
+
+    def add_err_file_handler(self, log_file):
+        """Add a FileHandler to logger that logs all WARNING and higher messages to a file. If the logger had a
+        MemoryHandler at self._handlers["memory_err"], those buffered messages are flushed to the new file, and the
+        MemoryHandler is closed."""
+        if self._logger is None:
+            raise RuntimeError("Impossible to set handlers if the Logger is not predefined")
+
+        self._handlers["file_err"] = _logging.FileHandler(log_file)
+        self._handlers["file_err"].addFilter(lambda record: record.levelno > _logging.INFO)
+
+        formatter = BaseNeMoFormatter
+        self._handlers["file_err"].setFormatter(formatter())
+        self._logger.addHandler(self._handlers["file_err"])
+
+        if self._handlers.get("memory_err", None):
+            self._handlers["memory_err"].setTarget(self._handlers["file_err"])
+            self._handlers["memory_err"].close()  # flush and remove
+            del self._handlers["memory_err"]
+
+    def getEffectiveLevel(self):
+        """Return how much logging output will be produced."""
+        if self._logger is not None:
+            return self._logger.getEffectiveLevel()
+
+    def get_verbosity(self):
+        """See getEffectiveLevel"""
+        return self.getEffectiveLevel()
+
+    def setLevel(self, verbosity_level):
+        """Sets the threshold for what messages will be logged."""
+        if self._logger is not None:
+            self._logger.setLevel(verbosity_level)
+
+            for handler in self._logger.handlers:
+                handler.setLevel(verbosity_level)
+
+    def set_verbosity(self, verbosity_level):
+        """See setLevel"""
+        self.setLevel(verbosity_level)
+
+    @contextmanager
+    def patch_stderr_handler(self, stream):
+        """ Sends messages that should log to stderr to stream instead. Useful for unittests """
+        if self._logger is not None:
+            try:
+                old_stream = self._handlers["stream_stderr"].stream
+                if old_stream is None:
+                    raise ValueError
+
+                # Port backwards set_stream() from python 3.7
+                self._handlers["stream_stderr"].acquire()
+                try:
+                    self._handlers["stream_stderr"].flush()
+                    self._handlers["stream_stderr"].stream = stream
+                finally:
+                    self._handlers["stream_stderr"].release()
+
+                yield stream
+            except (KeyError, ValueError):
+                raise RuntimeError("Impossible to patch logging handlers if handler does not exist")
+            finally:
+                # Port backwards set_stream() from python 3.7
+                self._handlers["stream_stderr"].acquire()
+                try:
+                    self._handlers["stream_stderr"].flush()
+                    self._handlers["stream_stderr"].stream = old_stream
+                finally:
+                    self._handlers["stream_stderr"].release()
+
+        else:
+            raise RuntimeError("Impossible to patch logging handlers if handler does not exist")
+
+    @contextmanager
+    def patch_stdout_handler(self, stream):
+        """ Sends messages that should log to stdout to stream instead. Useful for unittests """
+        if self._logger is not None:
+            try:
+                old_stream = self._handlers["stream_stdout"].stream
+                if old_stream is None:
+                    raise ValueError
+
+                # Port backwards set_stream() from python 3.7
+                self._handlers["stream_stdout"].acquire()
+                try:
+                    self._handlers["stream_stdout"].flush()
+                    self._handlers["stream_stdout"].stream = stream
+                finally:
+                    self._handlers["stream_stdout"].release()
+
+                yield stream
+            except (KeyError, ValueError):
+                raise RuntimeError("Impossible to patch logging handlers if handler does not exist")
+            finally:
+                # Port backwards set_stream() from python 3.7
+                self._handlers["stream_stdout"].acquire()
+                try:
+                    self._handlers["stream_stdout"].flush()
+                    self._handlers["stream_stdout"].stream = old_stream
+                finally:
+                    self._handlers["stream_stdout"].release()
+
+        else:
+            raise RuntimeError("Impossible to patch logging handlers if handler does not exist")
+
+    @contextmanager
+    def temp_verbosity(self, verbosity_level):
+        """Sets the a temporary threshold for what messages will be logged."""
+
+        if self._logger is not None:
+
+            old_verbosity = self.get_verbosity()
+
+            try:
+                self.set_verbosity(verbosity_level)
+                yield
+
+            finally:
+                self.set_verbosity(old_verbosity)
+
+        else:
+            try:
+                yield
+
+            finally:
+                pass
+
+    def captureWarnings(self, capture):
+        """
+        If capture is true, redirect all warnings to the logging package.
+        If capture is False, ensure that warnings are not redirected to logging
+        but to their original destinations.
+        """
+
+        if self._logger is not None:
+
+            if capture and self.old_warnings_showwarning is None:
+                # Backup Method
+                self.old_warnings_showwarning = warnings.showwarning
+                warnings.showwarning = self._showwarning
+
+            elif not capture and self.old_warnings_showwarning is not None:
+                # Restore Method
+                warnings.showwarning = self.old_warnings_showwarning
+                self.old_warnings_showwarning = None
+
+    def _showwarning(self, message, category, filename, lineno, file=None, line=None):
+        """
+        Implementation of showwarnings which redirects to logging.
+        It will call warnings.formatwarning and will log the resulting string
+        with level logging.WARNING.
+        """
+        s = warnings.formatwarning(message, category, filename, lineno, line)
+        self.warning("%s", s)
+
+    def _logged_once(self, msg, mode):
+        """Returns True if `mode` is ONCE and `msg` has already been logged.
+
+        The first PREFIX_LEN characters are stripped before the lookup, so
+        messages differing only in a fixed-width prefix deduplicate together.
+        """
+        PREFIX_LEN = 12
+        if mode == LogMode.ONCE:
+            if msg[PREFIX_LEN:] in self.once_logged:
+                return True
+            self.once_logged.add(msg[PREFIX_LEN:])
+        return False
+
+    def debug(self, msg, *args, mode=LogMode.EACH, **kwargs):
+        """
+        Log 'msg % args' with severity 'DEBUG'.
+
+        To pass exception information, use the keyword argument exc_info with
+        a true value, e.g.
+
+        logger.debug("Houston, we have a %s", "thorny problem", exc_info=1)
+        """
+        if self._logger is not None and self._logger.isEnabledFor(Logger.DEBUG) and not self._logged_once(msg, mode):
+            self._logger._log(Logger.DEBUG, msg, args, **kwargs)
+
+    def info(self, msg, *args, mode=LogMode.EACH, **kwargs):
+        """
+        Log 'msg % args' with severity 'INFO'.
+
+        To pass exception information, use the keyword argument exc_info with
+        a true value, e.g.
+
+        logger.info("Houston, we have a %s", "interesting problem", exc_info=1)
+        """
+        if self._logger is not None and self._logger.isEnabledFor(Logger.INFO) and not self._logged_once(msg, mode):
+            self._logger._log(Logger.INFO, msg, args, **kwargs)
+
+    def warning(self, msg, *args, mode=LogMode.EACH, **kwargs):
+        """
+        Log 'msg % args' with severity 'WARNING'.
+
+        To pass exception information, use the keyword argument exc_info with
+        a true value, e.g.
+
+        logger.warning("Houston, we have a %s", "bit of a problem", exc_info=1)
+        """
+        if self._logger is not None and self._logger.isEnabledFor(Logger.WARNING) and not self._logged_once(msg, mode):
+            self._logger._log(Logger.WARNING, msg, args, **kwargs)
+
+    def error(self, msg, *args, mode=LogMode.EACH, **kwargs):
+        """
+        Log 'msg % args' with severity 'ERROR'.
+
+        To pass exception information, use the keyword argument exc_info with
+        a true value, e.g.
+
+        logger.error("Houston, we have a %s", "major problem", exc_info=1)
+        """
+        if self._logger is not None and self._logger.isEnabledFor(Logger.ERROR) and not self._logged_once(msg, mode):
+            self._logger._log(Logger.ERROR, msg, args, **kwargs)
+
+    def critical(self, msg, *args, mode=LogMode.EACH, **kwargs):
+        """
+        Log 'msg % args' with severity 'CRITICAL'.
+
+        To pass exception information, use the keyword argument exc_info with
+        a true value, e.g.
+
+        logger.critical("Houston, we have a %s", "major disaster", exc_info=1)
+        """
+        if (
+            self._logger is not None
+            and self._logger.isEnabledFor(Logger.CRITICAL)
+            and not self._logged_once(msg, mode)
+        ):
+            self._logger._log(Logger.CRITICAL, msg, args, **kwargs)
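+
+
+# Illustrative usage (hypothetical messages; the singleton instance is exposed
+# as `nemo.utils.logging`, as imported elsewhere in this repository):
+#
+#   from nemo.utils import logging
+#   logging.info("Loaded %d folds", 4)
+#   logging.warning("Deprecated path", mode=LogMode.ONCE)  # logged once per process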
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8037bbf24abe4e2d7367528fba2b548fc09d6f34
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+numpy # Don't remove - Needed for reading images
+Pillow # Don't remove - Needed for reading images
+scikit-learn # Used in local evaluation
+
+pandas
+torch
+torchvision
+PyYAML
+ultralytics
+gitpython
+ensemble-boxes
+onnxruntime
+openvino-dev>=2023.0
+
+timm
+albumentations
+wandb
+lightning
+torchmetrics
diff --git a/resources/background_1k.csv b/resources/background_1k.csv
new file mode 100644
index 0000000000000000000000000000000000000000..91deee14fc5b703ec16fcb15c878738fb12a0857
--- /dev/null
+++ b/resources/background_1k.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9100495c4e76065fc71b188affc346dfd9bc6eba9ae1d5773879b29006edc3fb
+size 63446
diff --git a/resources/effnetv2s.json b/resources/effnetv2s.json
new file mode 100644
index 0000000000000000000000000000000000000000..3b9c1c69d35675bad694b142f7a3e2260260ce29
--- /dev/null
+++ b/resources/effnetv2s.json
@@ -0,0 +1 @@
+{"version": "4.1.0", "seed": 42, "folds": 4, "folds_seed": 42, "imgsz": 512, "ar": null, "center_crop": null, "crop_ratio": null, "image_size": 512, "backbone": "tf_efficientnetv2_s", "global_pool": "avg", "num_classes": 7, "pretrained": true, "max_pixel": 255.0, "IMG_MEAN": [0.485, 0.456, 0.406], "IMG_STD": [0.229, 0.224, 0.225], "epochs": 96, "batch_size": 32, "val_batch_size": 32, "accumulate_grad_batches": 1, "gradient_clip_val": null, "cutmix_prob": 0.5, "cutmix_alpha": 1.0, "mixup_prob": 0.5, "mixup_alpha": 0.2, "optimizer": "AdamW", "lr0": 0.001, "lrf": 0.0, "scheduler": "cos_lr", "dropout": 0.0, "swa_lrs": null, "ema": 0.999, "save_top_k": 5, "label_smoothing": 0.1, "sampler": null, "batch_sampler": null, "batch_sampler_alpha": 0.25, "precision": "16-mixed", "device": "gpu", "deterministic": true, "num_workers": 8, "pruning": null}
\ No newline at end of file
diff --git a/resources/files_10k_bb_4folds.csv b/resources/files_10k_bb_4folds.csv
new file mode 100644
index 0000000000000000000000000000000000000000..fb50c46c8738cd12b4a90f121d222c0ff237cc2a
--- /dev/null
+++ b/resources/files_10k_bb_4folds.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2d4530fc58a5c41ab37249e8205f712e814ccf22b051b07534154613d8b6f1f
+size 1846433
diff --git a/resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv b/resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv
new file mode 100644
index 0000000000000000000000000000000000000000..f61a8fb1ec9fea95c0f8fb2c0ff6f729c5e48b87
--- /dev/null
+++ b/resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e468e669863d54b77c850db1986066df86510052804aef9525b1c1a0b48a450
+size 3920445
diff --git a/resources/files_external_6k_inaturalist_s3_bb_4folds.csv b/resources/files_external_6k_inaturalist_s3_bb_4folds.csv
new file mode 100644
index 0000000000000000000000000000000000000000..77e9aa14d2418e8b33ca7dfe322ceffc3bef6453
--- /dev/null
+++ b/resources/files_external_6k_inaturalist_s3_bb_4folds.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f39c5cb814644c42d5341ebaccc2b62e02741d641482b6419a14d72822719bc
+size 1291961
diff --git a/resources/tinyvit.json b/resources/tinyvit.json
new file mode 100644
index 0000000000000000000000000000000000000000..f7e380299db8d03a10ca77df9ab3c63f7b889648
--- /dev/null
+++ b/resources/tinyvit.json
@@ -0,0 +1 @@
+{"version": "4.0.0", "seed": 42, "folds": 4, "folds_seed": 42, "imgsz": 384, "ar": null, "center_crop": null, "crop_ratio": null, "image_size": 384, "backbone": "tiny_vit_21m_384", "global_pool": "avg", "num_classes": 7, "pretrained": true, "max_pixel": 255.0, "IMG_MEAN": [0.485, 0.456, 0.406], "IMG_STD": [0.229, 0.224, 0.225], "epochs": 96, "batch_size": 32, "val_batch_size": 32, "accumulate_grad_batches": 1, "gradient_clip_val": null, "cutmix_prob": 0.5, "cutmix_alpha": 1.0, "mixup_prob": 0.5, "mixup_alpha": 0.2, "optimizer": "AdamW", "lr0": 0.0001, "lrf": 0.0, "scheduler": "cos_lr", "dropout": 0.0, "swa_lrs": null, "ema": 0.999, "save_top_k": 5, "label_smoothing": 0.1, "sampler": null, "batch_sampler": null, "batch_sampler_alpha": 0.25, "precision": "16-mixed", "device": "gpu", "deterministic": true, "num_workers": 8, "pruning": null}
\ No newline at end of file
diff --git a/resources/yolov8n.json b/resources/yolov8n.json
new file mode 100644
index 0000000000000000000000000000000000000000..33b677ee9f3d5ade906f97802eeaf9ae0bbf3e8b
--- /dev/null
+++ b/resources/yolov8n.json
@@ -0,0 +1 @@
+{"seed": 42, "imgsz": 768, "patience": 128, "epochs": 128, "batch": 16, "val": true, "pretrained": true, "single_cls": true, "optimizer": "auto", "degrees": 10.0, "shear": 2.0, "flipud": 0.5, "mixup": 0.25, "dropout": 0.0, "box": 7.5, "cls": 0.5, "dfl": 1.5, "label_smoothing": 0.0, "amp": true, "device": 0}
\ No newline at end of file
diff --git a/train_bb_detector.py b/train_bb_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad733c4cc2f4c010491ae8a1e62009f73300804
--- /dev/null
+++ b/train_bb_detector.py
@@ -0,0 +1,30 @@
+import json
+import pandas as pd
+import os
+import argparse
+
+from mqt.training.train_bb_detector import train_yolo
+from my_models.utils.torch import Config
+
+
+def check_path(path):
+    if os.path.exists(path):
+        return True
+    raise FileNotFoundError("Path not found: %s" % path)
+
+
+if __name__ == '__main__':
+
+    # python train_bb_detector.py --config ./resources/yolov8n.json --images_path ./data/images --initial_cv ./resources/files_10k_bb_4folds.csv
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', type=str, default='./resources/yolov8n.json', help='Path to yolo configuration file')
+    parser.add_argument('--images_path', type=str, default="./data/images", help='Path to images folder')
+    parser.add_argument('--initial_cv', type=str, default="./resources/files_10k_bb_4folds.csv", help='Path to initial files with CV split')
+    parser.add_argument('--models_home', type=str, default="./mosquito_yolo_models", help='Yolo model home')
+    parser.add_argument('--image_size', type=int, default=768, help='Image size')
+    parser.add_argument('--epochs', type=int, default=128, help='Epochs')
+
+    args = parser.parse_args()
+
+    train_file = args.initial_cv
+    check_path(train_file)
+    train_pd = pd.read_csv(train_file)
+    print("Initial data loaded:", train_pd.shape)
+
+    images_path = args.images_path
+    check_path(images_path)
+
+    config_path = args.config
+    check_path(config_path)
+    config = Config(json.load(open(config_path, "r")))
+    print("Config loaded:", config.__dict__)
+
+    config.imgsz = args.image_size
+    config.epochs = args.epochs
+
+    # NOTE: hedged sketch -- the exact signature of `train_yolo` is not shown in
+    # this repository, so the call below mirrors `train_cls` in train_classifier.py
+    # and may need adjusting.
+    train_yolo(images_path, config, train_pd, models_home=args.models_home)
diff --git a/train_classifier.py b/train_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..093a513d4f4e44ccd8f3b2d37069e2b1b7e8ac10
--- /dev/null
+++ b/train_classifier.py
@@ -0,0 +1,80 @@
+import json
+import pandas as pd
+import os
+import argparse
+
+from mqt.training.train_classifier import train_cls
+from my_models.utils.torch import Config
+
+
+def check_path(path):
+    if os.path.exists(path):
+        return True
+    raise FileNotFoundError("Path not found: %s" % path)
+
+
+if __name__ == '__main__':
+
+    # python train_classifier.py --config ./resources/effnetv2s.json --images_bb_path ./data/stage2_images_bb --background_path ./data/images/background --initial_cv ./resources/files_10k_bb_4folds.csv --background ./resources/background_1k.csv --external_cv ./resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv,./resources/files_external_6k_inaturalist_s3_bb_4folds.csv --wandb_project MosquitoClassifier --full_train
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', type=str, default='./resources/tinyvit.json', help='Path to classifier configuration file')
+    parser.add_argument('--images_bb_path', type=str, default="./data/stage2_images_bb", help='Path to bounding boxes images folder')
+    parser.add_argument('--background_path', type=str, default="./data/images/background", help='Path to background images folder')
+    parser.add_argument('--initial_cv', type=str, default="./resources/files_10k_bb_4folds.csv", help='Path to initial files with CV split')
+    parser.add_argument('--background', type=str, default="./resources/background_1k.csv", help='Path to background files')
+    parser.add_argument('--external_cv', type=str, default="./resources/files_external_17k_inaturalist_kaggle_bb_4folds.csv", help='Path to external files with CV split')
+    parser.add_argument('--full_train', default=False, action="store_true", help='Enable full train')
+    parser.add_argument('--wandb_project', type=str, default=None, help='Optional Wandb project')
+    parser.add_argument('--models_home', type=str, default="./mosquito_models", help='Model home')
+    parser.add_argument('--epochs', type=int, default=96, help='Epochs')
+    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
+    parser.add_argument('--num_workers', type=int, default=8, help='CPU workers')
+
+    args = parser.parse_args()
+
+    train_file = args.initial_cv
+    external_files = args.external_cv
+    background_file = args.background
+    train_boxes_home = args.images_bb_path
+    check_path(train_boxes_home)
+    background_home = args.background_path
+    check_path(background_home)
+    config_path = args.config
+    full_train = args.full_train
+    wandb_project = args.wandb_project
+
+    if wandb_project is not None:
+        print("Wandb project:", wandb_project)
+
+    check_path(train_file)
+    train_pd = pd.read_csv(train_file)
+    print("Initial data loaded:", train_pd.shape)
+
+    check_path(background_file)
+    train_background_pd = pd.read_csv(background_file)
+    print("Background data loaded:", train_background_pd.shape)
+
+    external_pd = None
+    external_files = external_files.split(",")
+    for external_file in external_files:
+        check_path(external_file)
+        external_pd_ = pd.read_csv(external_file)
+        external_pd = pd.concat([external_pd, external_pd_]) if external_pd is not None else external_pd_
+    print("External data loaded:", external_pd.shape)
+
+    check_path(config_path)
+    config = Config(json.load(open(config_path, "r")))
+    print("Config loaded:", config.__dict__)
+
+    config.epochs = args.epochs
+    config.batch_size = args.batch_size
+    config.val_batch_size = config.batch_size
+    config.num_workers = args.num_workers
+
+    print("Full train:", full_train)
+    print()
+
+    train_cls(train_boxes_home, config, train_pd, train_background_pd=train_background_pd, external_pd=external_pd,
+              wandb_project=wandb_project, models_home=args.models_home, full_train=full_train)