diff --git a/tools/test_robustness.py b/tools/test_robustness.py index e2632151777b32a83aedf198c3db4796b57e65ca..c0489f3ebaafac310079f522399a7f26bd22d24c 100644 --- a/tools/test_robustness.py +++ b/tools/test_robustness.py @@ -350,9 +350,9 @@ def main(): aggregated_results[corruptions[0]][0] continue + test_data_cfg = copy.deepcopy(cfg.data.test) # assign corruption and severity if corruption_severity > 0: - test_data_cfg = copy.deepcopy(cfg.data.test) corruption_trans = dict( type='Corrupt', corruption=corruption, @@ -368,7 +368,7 @@ def main(): # build the dataloader # TODO: support multiple images per gpu # (only minor changes are needed) - dataset = build_dataset(cfg.data.test) + dataset = build_dataset(test_data_cfg) data_loader = build_dataloader( dataset, imgs_per_gpu=1,