From 11bc4a6584e86e2c4609c0570b9deb1988a79410 Mon Sep 17 00:00:00 2001 From: wufanyou <fanyou.wu@outlook.com> Date: Tue, 12 Jul 2022 18:37:57 +0800 Subject: [PATCH] update --- aicrowd.json | 4 ++-- utils/dataset.py | 3 +++ utils/lgb_predict_task1.py | 8 ++++---- utils/onnx_predict.py | 12 +----------- utils/run_task1.py | 28 +++++++++++++++++++++++++--- 5 files changed, 35 insertions(+), 20 deletions(-) diff --git a/aicrowd.json b/aicrowd.json index 19a8ab9..3fe7d2e 100755 --- a/aicrowd.json +++ b/aicrowd.json @@ -1,8 +1,8 @@ { "challenge_id": "esci-challenge-for-improving-product-search", - "task": "task_3_product_substitute_identification", + "task": "task_1_query-product_ranking", "gpu": true, "authors": ["aicrowd-username"], - "description": "task_1_query-product_ranking, task_2_multiclass_product_classification", + "description": "task_3_product_substitute_identification, task_2_multiclass_product_classification", "license": "MIT" } \ No newline at end of file diff --git a/utils/dataset.py b/utils/dataset.py index bc40ed3..e7f68aa 100755 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -340,6 +340,9 @@ class Task1DatasetWithDescription(Task1Dataset): "attention_mask": attention_mask, "speical_token_pos": input_ids_pos, } + + if self.model_type == 'bigbird': + feature['token_type_ids'] = torch.zeros_like(input_ids) meta = { "product_id": row['product_id'], diff --git a/utils/lgb_predict_task1.py b/utils/lgb_predict_task1.py index bd0a340..9c3f556 100755 --- a/utils/lgb_predict_task1.py +++ b/utils/lgb_predict_task1.py @@ -6,21 +6,21 @@ __MAP__ = ["irrelevant", "complement", "substitute", "exact"] LGB_CONFIG = { "us": { - "product_feat": "/models/us-product-feat.csv", + "product_feat": "/models/us-product-feat-task-1-only.csv", "model_file": "/models/lgb-us-task-1.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold'] }, "jp": { - "product_feat": "/models/jp-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-task-1-only.csv", "model_file": "/models/lgb-es-jp-task-1.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold','locale'] }, "es": { - "product_feat": "/models/es-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-task-1-only.csv", "model_file": "/models/lgb-es-jp-task-1.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', @@ -30,7 +30,7 @@ LGB_CONFIG = { LOCALE_MAP = {'jp':0, 'es':1, 'us':2} COL_NAME = "product_id" -WEIGHT_MAP = {0:0.5, 1:0.5, 2:0.5, 3:0.5, 4:0.25, 5:0.25} +WEIGHT_MAP = {0:0.5, 1:0.5, 2:0.5, 3:0.5, 4:0.25, 5:0.25, 6:0.5, 7:0.5} # Need modification def lgb_predict_task1(df, locale): df = df.reset_index(drop=True) diff --git a/utils/onnx_predict.py b/utils/onnx_predict.py index 40ccda5..af4b8b3 100755 --- a/utils/onnx_predict.py +++ b/utils/onnx_predict.py @@ -43,11 +43,6 @@ def onnx_predict(sub_test, config): all_example = [] for data in tqdm(loader): - # inputs = { - # "input_ids": data["features"]["input_ids"], - # "attention_mask": data["features"]["attention_mask"], - # "speical_token_pos": data["features"]["speical_token_pos"], - # } for i, s in enumerate(session): all_output[i] += list( s.run(output_names=["output"], input_feed=dict(data['features']))[0] # type: ignore @@ -97,14 +92,9 @@ def onnx_predict_task1(sub_test, config): all_product = [] for data in tqdm(loader): - inputs = { - "input_ids": data["features"]["input_ids"], - "attention_mask": data["features"]["attention_mask"], - "speical_token_pos": data["features"]["speical_token_pos"], - } for i, s in enumerate(session): all_output[i] += list( - s.run(output_names=["output"], input_feed=dict(inputs))[0] # type: ignore + s.run(output_names=["output"], input_feed=dict(data['features']))[0] # type: ignore ) all_query += data["meta"]["query_id"] all_product += data['meta']['product_id'] diff --git a/utils/run_task1.py b/utils/run_task1.py index 578c37e..09f9bf8 100755 --- a/utils/run_task1.py +++ b/utils/run_task1.py @@ -24,7 +24,7 @@ CONFIG = { "product": "/models/product.h5", "key": "us", "type": "deberta", - "fold_offset": 0, + "fold_offset": 0, # fold 0, 1 "description": False, }, { @@ -43,7 +43,7 @@ CONFIG = { "product": "/models/cocolm.h5", "key": "us", "type": "cocolm", - "fold_offset": 2, + "fold_offset": 2, # fold 2, 3 "description": False, }, { @@ -62,9 +62,29 @@ CONFIG = { "product": "/models/distilbart.h5", "key": "us", "type": "distilbart", - "fold_offset": 4, + "fold_offset": 4, # fold 4, 5 "description": False, }, + { + "clean": DeBertaCleanV2, + "encode": { + "query": 12506, + "product_title": 3771, + "product_id": 4787, + "index": 6477, + "product_description": 6865, + "product_bullet_point": 10739, + "product_brand": 4609, + "product_color_name": 3225, + }, + "model": ["/models/us-bigbird-des-0-fp16.onnx", "/models/us-bigbird-des-1-fp16.onnx"], + "product": "/models/bigbird.h5", + "key": "us", + "type": "bigbird", + "tokenizer": "/models/bigbird.model", + "fold_offset": 6, # 6, 7 + "description":True, + }, ], "jp": [ { @@ -85,6 +105,7 @@ CONFIG = { "key": "jp", "type": "mdeberta", "fold_offset": 0, + "description":False, }, ], "es": [ @@ -106,6 +127,7 @@ CONFIG = { "key": "es", "type": "mdeberta", "fold_offset": 0, + "description":False, }, ], } -- GitLab