diff --git a/evaluator/model.py b/evaluator/model.py index 483b332e1d073473ff5173b14a76cd12c5ec3cad..74bf5ed9d32abeade5f33cf33a33cb049e71e4ca 100644 --- a/evaluator/model.py +++ b/evaluator/model.py @@ -29,7 +29,7 @@ class ZEWDPCModel(torch.nn.Module): def init_network(self): # Setup Base Model - EfficientNet_b4 - self.base_model = torchvision.models.efficientnet_b4( + self.base_model = torchvision.models.efficientnet_b0( pretrained=self.use_pretrained, ) diff --git a/local_evaluation.py b/local_evaluation.py index 4999c37341f520df7764bd8e895f48173b34ad7c..e3d85e110f5815bd171274837c8f0f10efa3f346 100644 --- a/local_evaluation.py +++ b/local_evaluation.py @@ -131,7 +131,7 @@ else: trainer = ZEWDPCTrainer(num_classes=6, use_pretrained=True) trainer.train( - aggregated_dataset, num_epochs=10, validation_percentage=0.1, batch_size=32 + aggregated_dataset, num_epochs=25, validation_percentage=0.1, batch_size=64 ) y_pred = trainer.predict(val_dataset)