diff --git a/environment.yaml b/environment.yaml deleted file mode 100644 index d9c6b475d424f4320098c9d5f66838d76e85f40f..0000000000000000000000000000000000000000 --- a/environment.yaml +++ /dev/null @@ -1,103 +0,0 @@ -name: zew -channels: - - pytorch - - defaults -dependencies: - - blas=1.0 - - bzip2=1.0.8 - - ca-certificates=2021.10.26 - - certifi=2021.10.8 - - ffmpeg=4.3 - - freetype=2.11.0 - - gettext=0.21.0 - - giflib=5.2.1 - - gmp=6.2.1 - - gnutls=3.6.15 - - icu=58.2 - - intel-openmp=2021.4.0 - - jpeg=9d - - lame=3.100 - - lcms2=2.12 - - libcxx=12.0.0 - - libffi=3.3 - - libiconv=1.16 - - libidn2=2.3.2 - - libpng=1.6.37 - - libtasn1=4.16.0 - - libtiff=4.2.0 - - libunistring=0.9.10 - - libuv=1.40.0 - - libwebp=1.2.0 - - libwebp-base=1.2.0 - - libxml2=2.9.12 - - llvm-openmp=12.0.0 - - lz4-c=1.9.3 - - mkl=2021.4.0 - - mkl-service=2.4.0 - - mkl_fft=1.3.1 - - mkl_random=1.2.2 - - ncurses=6.3 - - nettle=3.7.3 - - numpy=1.21.2 - - numpy-base=1.21.2 - - olefile=0.46 - - openh264=2.1.1 - - openssl=1.1.1m - - pillow=8.4.0 - - pip=21.2.4 - - python=3.8.12 - - pytorch=1.10.2 - - readline=8.1.2 - - setuptools=58.0.4 - - sqlite=3.37.0 - - tk=8.6.11 - - torchaudio=0.10.2 - - torchvision=0.11.3 - - typing_extensions=3.10.0.2 - - wheel=0.37.1 - - xz=5.2.5 - - zlib=1.2.11 - - zstd=1.4.9 - - pip: - - aicrowd-gym==0.0.3 - - aicrowd-gym-internal==0.0.4 - - arrow==1.2.2 - - black==22.1.0 - - click==8.0.3 - - cloudpickle==2.0.0 - - gym==0.21.0 - - imageio==2.14.1 - - jinja2==3.0.3 - - jinja2-time==0.2.0 - - joblib==1.1.0 - - loguru==0.5.3 - - markupsafe==2.0.1 - - minio==7.1.3 - - msgpack==1.0.2 - - msgpack-numpy==0.4.7.1 - - mypy-extensions==0.4.3 - - networkx==2.6.3 - - packaging==21.3 - - pandas==1.4.0 - - pathspec==0.9.0 - - platformdirs==2.4.1 - - pydantic==1.9.0 - - pyparsing==3.0.7 - - pyro5==5.13.1 - - python-dateutil==2.8.2 - - pytz==2021.3 - - pywavelets==1.2.0 - - pyzmq==22.0.3 - - scikit-image==0.19.1 - - scikit-learn==1.0.2 - - scipy==1.7.3 - - serpent==1.40 - - six==1.15.0 - - threadpoolctl==3.1.0 - - tifffile==2021.11.2 - - timeout-decorator==0.5.0 - - tomli==2.0.0 - - tqdm==4.60.0 - - urllib3==1.26.8 - - zmq==0.0.0 -prefix: /Users/mohanty/miniconda3/envs/zew diff --git a/evaluator/__init__.py b/evaluator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dataset.py b/evaluator/dataset.py similarity index 100% rename from dataset.py rename to evaluator/dataset.py diff --git a/evaluation_metrics.py b/evaluator/evaluation_metrics.py similarity index 99% rename from evaluation_metrics.py rename to evaluator/evaluation_metrics.py index 21de4043fd2cab088c756f00208b741206dbcd55..b8b782966383b5a271a077a643bd427e3fe219f0 100644 --- a/evaluation_metrics.py +++ b/evaluator/evaluation_metrics.py @@ -4,6 +4,7 @@ import numpy as np from sklearn.metrics import accuracy_score from sklearn.metrics import hamming_loss + def exact_match_ratio(y_true, y_pred): if type(y_pred) == torch.Tensor: y_pred = y_pred.numpy() diff --git a/exceptions.py b/evaluator/exceptions.py similarity index 100% rename from exceptions.py rename to evaluator/exceptions.py diff --git a/main.py b/main.py index eccc595f4e7b23fcfefb55e1f6e153024bb06d53..4ee17913045fb56327aaf72a81b82a304928f181 100644 --- a/main.py +++ b/main.py @@ -5,9 +5,8 @@ ## import numpy as np -from tqdm.auto import tqdm -from dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset +from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset from run import ZEWDPCBaseRun #################################################################################### @@ -88,7 +87,7 @@ assert predictions.shape == (len(val_dataset), 4) ## ## Phase 4 : Evaluation Phase #################################################################################### -from evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio +from evaluator.evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio y_true = val_dataset_gt._get_all_labels() y_pred = predictions diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..afa11bb5ebb52b198e132a2426b39282cd84f2f5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +click==8.0.3 +imageio==2.14.1 +jinja2==3.0.3 +timeout-decorator==0.5.0 +tqdm==4.60.0 +pandas +scikit-image +scikit-learn +scipy +torch +torchvision +torchaudio diff --git a/run.py b/run.py index 30515c480ca4e2d1a6f21b11ec66ad88fe66be03..e5f59c0b4e91938d0f3e2aefec20c0c1c5b9f58f 100644 --- a/run.py +++ b/run.py @@ -3,7 +3,7 @@ import numpy as np from tqdm.auto import tqdm -from dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset +from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset class ZEWDPCBaseRun: @@ -239,7 +239,7 @@ if __name__ == "__main__": ## ## Phase 4 : Evaluation Phase #################################################################################### - from evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio + from evaluator.evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio y_true = val_dataset_gt._get_all_labels() y_pred = predictions