diff --git a/aicrowd.json b/aicrowd.json
index 531fadef45ec66bb59b7afa1883c6e08176d6ef9..864be108d9e800710998f4a2a6bc739365d79249 100755
--- a/aicrowd.json
+++ b/aicrowd.json
@@ -1,8 +1,8 @@
 {
     "challenge_id": "esci-challenge-for-improving-product-search",
-    "task": "task_1_query-product_ranking",
+    "task": "task_2_multiclass_product_classification",
     "gpu": true,
     "authors": ["aicrowd-username"],
-    "description": "task_2_multiclass_product_classification, task_3_product_substitute_identification",
+    "description": "task_1_query-product_ranking, task_3_product_substitute_identification",
     "license": "MIT"
 }
\ No newline at end of file
diff --git a/utils/dataset.py b/utils/dataset.py
index 9d8ee662ccb147025710bd3bdbb08336cccbe09c..bc40ed3de10368f39c4fa618ae4cf9e4c1db2b19 100755
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -11,10 +11,11 @@ import sentencepiece as spm
 
 TOKERIZER_SETTING = {
-    "deberta": {"cls":1},
-    "distilbart": {'cls':0},
-    'cocolm':{'cls':0},
-    'mdeberta':{'cls':1},
+    "deberta": {"cls":1, 'sep':2},
+    "distilbart": {'cls':0, 'sep':2},
+    'cocolm':{'cls':0, 'sep':2},
+    'mdeberta':{'cls':1, 'sep':2},
+    "bigbird":{'cls':65, 'sep':66},
 }
@@ -26,6 +27,7 @@ class BaseDataset(Dataset):
         filename = config_dict['product']
         self.filename = filename
         self.cls_encode = TOKERIZER_SETTING[config_dict['type']]['cls']
+        self.sep_encode = TOKERIZER_SETTING[config_dict['type']]['sep']
         self.model_type = config_dict['type']
         self.df = df
         self.key = config_dict['key']
@@ -70,26 +72,26 @@ class Task2Dataset(BaseDataset):
         product_title = list(self.database[self.key][row.product_id]["product_title"][:]) # type: ignore
         product_brand = list(self.database[self.key][row.product_id]["product_brand"][:]) # type: ignore
         product_color_name = list(self.database[self.key][row.product_id]["product_color_name"][:]) # type: ignore
-        product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore
+        product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore
         index = self.tokenizer.encode(" ".join(str(row["index"]))) # type: ignore
         if self.model_type == 'distilbart': index = index[1:-1]
         product_id = self.tokenizer.encode(" ".join(str(row["product_id"]))) # type: ignore
         if self.model_type == 'distilbart': product_id = product_id[1:-1]
         input_ids_pos = [1]
-        input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [2]
+        input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [2]
+        input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["index"]] + index + [2]
+        input_ids += [self.config_dict["encode"]["index"]] + index + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [self.sep_encode]
[self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [self.sep_encode] input_ids_pos.append(len(input_ids)) input_ids += [self.config_dict["encode"]["product_bullet_point"]] + product_bullet_point[1:-1] - input_ids = input_ids[:self.max_length-1] + [2] + input_ids = input_ids[:self.max_length-1] + [self.sep_encode] for i in range(len(input_ids_pos)): if input_ids_pos[i] >= self.max_length: if input_ids_pos[-2] < self.max_length: @@ -111,6 +113,9 @@ class Task2Dataset(BaseDataset): "speical_token_pos": input_ids_pos, } + if self.model_type == 'bigbird': + feature['token_type_ids'] = torch.zeros_like(input_ids) + meta = { "product_id": row['product_id'], "example_id": row['example_id'], @@ -125,10 +130,6 @@ class Task2Dataset(BaseDataset): batch_first=True, padding_value=0, ).numpy() -# features["token_type_ids"] = pad_sequence( -# [x[0]["token_type_ids"] for x in batch], -# batch_first=True, -# ).numpy() features["attention_mask"] = pad_sequence( [x[0]["attention_mask"] for x in batch], batch_first=True, @@ -136,12 +137,77 @@ class Task2Dataset(BaseDataset): features["speical_token_pos"] = torch.cat( [x[0]["speical_token_pos"] for x in batch] ).numpy() + + if 'token_type_ids' in batch[0][0]: + features["token_type_ids"] = pad_sequence( + [x[0]["token_type_ids"] for x in batch], + batch_first=True, + ).numpy() + meta = {} meta["product_id"] = [x[1]["product_id"] for x in batch] meta["example_id"] = [x[1]["example_id"] for x in batch] return {"features": features, "meta": meta} +class Task2DatasetWithDescription(Task2Dataset): + def __getitem__(self, index) -> Tuple: + row = self.df.loc[index] + query = self.tokenizer.encode(self.clean(row["query"])) # type: ignore + if self.model_type == 'distilbart': query = query[1:-1] + product_title = list(self.database[self.key][row.product_id]["product_title"][:]) # type: ignore + product_brand = list(self.database[self.key][row.product_id]["product_brand"][:]) # type: ignore + product_color_name = list(self.database[self.key][row.product_id]["product_color_name"][:]) # type: ignore + product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore + product_description = list(self.database[self.key][row.product_id]["product_description"][:]) # type: ignore + index = self.tokenizer.encode(" ".join(str(row["index"]))) # type: ignore + if self.model_type == 'distilbart': index = index[1:-1] + product_id = self.tokenizer.encode(" ".join(str(row["product_id"]))) # type: ignore + if self.model_type == 'distilbart': product_id = product_id[1:-1] + input_ids_pos = [1] + input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["index"]] + index + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [self.sep_encode] + input_ids_pos.append(len(input_ids)) + input_ids += [self.config_dict["encode"]["product_bullet_point"]] + product_bullet_point[1:-1] + 
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_description"]] + product_description[1:-1]
+        input_ids = input_ids[:self.max_length-1] + [self.sep_encode]
+        for i in range(len(input_ids_pos)):
+            if input_ids_pos[i] >= self.max_length:
+                if input_ids_pos[-2] < self.max_length:
+                    input_ids_pos[i] = input_ids_pos[-2]
+                elif input_ids_pos[1] < self.max_length:
+                    input_ids_pos[i] = input_ids_pos[1]
+                else:
+                    input_ids_pos[i] = self.max_length - 1
+        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        input_ids_pos = torch.tensor(input_ids_pos, dtype=torch.long)[None]
+        attention_mask = torch.ones_like(input_ids)
+
+        feature = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "speical_token_pos": input_ids_pos,
+        }
+        if self.model_type == 'bigbird':
+            feature['token_type_ids'] = torch.zeros_like(input_ids)
+        meta = {
+            "product_id": row['product_id'],
+            "example_id": row['example_id'],
+        }
+        return feature, meta
+
+
+
 class Task1Dataset(BaseDataset):
     def __getitem__(self, index) -> Tuple:
         row = self.df.loc[index]
@@ -151,25 +217,26 @@ class Task1Dataset(BaseDataset):
         product_brand = list(self.database[self.key][row.product_id]["product_brand"][:]) # type: ignore
         product_color_name = list(self.database[self.key][row.product_id]["product_color_name"][:]) # type: ignore
         product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore
+        product_description = list(self.database[self.key][row.product_id]["product_description"][:]) # type: ignore
         index = self.tokenizer.encode(" ".join(str(row["index"]))) # type: ignore
         if self.model_type == 'distilbart': index = index[1:-1]
         product_id = self.tokenizer.encode(" ".join(str(row["product_id"]))) # type: ignore
         if self.model_type == 'distilbart': product_id = product_id[1:-1]
         input_ids_pos = [1]
-        input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [2]
+        input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [2]
+        input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["index"]] + index + [2]
+        input_ids += [self.config_dict["encode"]["index"]] + index + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
-        input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [2]
+        input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [self.sep_encode]
         input_ids_pos.append(len(input_ids))
         input_ids += [self.config_dict["encode"]["product_bullet_point"]] + product_bullet_point[1:-1]
-        input_ids = input_ids[:self.max_length-1] + [2]
+        input_ids = input_ids[:self.max_length-1] + [self.sep_encode]
         for i in range(len(input_ids_pos)):
             if input_ids_pos[i] >= self.max_length:
                 if input_ids_pos[-2] < self.max_length:
@@ -190,6 +257,8 @@ class Task1Dataset(BaseDataset):
             "attention_mask": attention_mask,
             "speical_token_pos": input_ids_pos,
         }
+        if self.model_type == 'bigbird':
+            feature['token_type_ids'] = torch.zeros_like(input_ids)
 
         meta = {
             "product_id": row['product_id'],
@@ -212,7 +281,93 @@ class Task1Dataset(BaseDataset):
         features["speical_token_pos"] = torch.cat(
             [x[0]["speical_token_pos"] for x in batch]
         ).numpy()
+        if 'token_type_ids' in batch[0][0]:
+            features["token_type_ids"] = pad_sequence(
+                [x[0]["token_type_ids"] for x in batch],
+                batch_first=True,
+            ).numpy()
         meta = {}
         meta["product_id"] = [x[1]["product_id"] for x in batch]
         meta["query_id"] = [x[1]["query_id"] for x in batch]
-        return {"features": features, "meta": meta}
\ No newline at end of file
+        return {"features": features, "meta": meta}
+
+class Task1DatasetWithDescription(Task1Dataset):
+    def __getitem__(self, index) -> Tuple:
+        row = self.df.loc[index]
+        query = self.tokenizer.encode(self.clean(row["query"])) # type: ignore
+        if self.model_type == 'distilbart': query = query[1:-1]
+        product_title = list(self.database[self.key][row.product_id]["product_title"][:]) # type: ignore
+        product_brand = list(self.database[self.key][row.product_id]["product_brand"][:]) # type: ignore
+        product_color_name = list(self.database[self.key][row.product_id]["product_color_name"][:]) # type: ignore
+        product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore
+        product_description = list(self.database[self.key][row.product_id]["product_description"][:]) # type: ignore
+        index = self.tokenizer.encode(" ".join(str(row["index"]))) # type: ignore
+        if self.model_type == 'distilbart': index = index[1:-1]
+        product_id = self.tokenizer.encode(" ".join(str(row["product_id"]))) # type: ignore
+        if self.model_type == 'distilbart': product_id = product_id[1:-1]
+        input_ids_pos = [1]
+        input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["index"]] + index + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_bullet_point"]] + product_bullet_point[1:-1] + [self.sep_encode]
+        input_ids_pos.append(len(input_ids))
+        input_ids += [self.config_dict["encode"]["product_description"]] + product_description[1:-1]
+        input_ids = input_ids[:self.max_length-1] + [self.sep_encode]
+        for i in range(len(input_ids_pos)):
+            if input_ids_pos[i] >= self.max_length:
+                if input_ids_pos[-2] < self.max_length:
+                    input_ids_pos[i] = input_ids_pos[-2]
+                elif input_ids_pos[1] < self.max_length:
+                    input_ids_pos[i] = input_ids_pos[1]
+                else:
+                    input_ids_pos[i] = self.max_length - 1
+
+        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        input_ids_pos = torch.tensor(input_ids_pos, dtype=torch.long)[None]
+        #token_type_ids = torch.zeros_like(input_ids)
+        attention_mask = torch.ones_like(input_ids)
+        feature = {
"input_ids": input_ids, + "attention_mask": attention_mask, + "speical_token_pos": input_ids_pos, + } + + meta = { + "product_id": row['product_id'], + "query_id": row['query_id'] + } + return feature, meta + + @staticmethod + def collate_fn(batch: List) -> dict: + features = {} + features["input_ids"] = pad_sequence( + [x[0]["input_ids"] for x in batch], + batch_first=True, + padding_value=0, + ).numpy() + features["attention_mask"] = pad_sequence( + [x[0]["attention_mask"] for x in batch], + batch_first=True, + ).numpy() + features["speical_token_pos"] = torch.cat( + [x[0]["speical_token_pos"] for x in batch] + ).numpy() + if 'token_type_ids' in batch[0][0]: + features["token_type_ids"] = pad_sequence( + [x[0]["token_type_ids"] for x in batch], + batch_first=True, + ).numpy() + meta = {} + meta["product_id"] = [x[1]["product_id"] for x in batch] + meta["query_id"] = [x[1]["query_id"] for x in batch] + return {"features": features, "meta": meta} diff --git a/utils/lgb_predict_task2.py b/utils/lgb_predict_task2.py index b133c90e0f138d21f198c95fc3f2dbc7a1ff9f74..ad66fcf95f62e9f14c459f39053072aec2867822 100755 --- a/utils/lgb_predict_task2.py +++ b/utils/lgb_predict_task2.py @@ -6,21 +6,21 @@ __MAP__ = ["irrelevant", "complement", "substitute", "exact"] LGB_CONFIG = { "us": { - "product_feat": "/models/us-product-feat.csv", - "model_file": "/models/lgb-us-task-2-only.txt", + "product_feat": "/models/us-product-feat-remove-intersection.csv", + "model_file": "/models/lgb-us-task-2.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold'] }, "jp": { - "product_feat": "/models/jp-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-remove-intersection.csv", "model_file": "/models/lgb-es-jp-task-2.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold','locale'] }, "es": { - "product_feat": "/models/es-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-remove-intersection.csv", "model_file": "/models/lgb-es-jp-task-2.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', @@ -28,11 +28,10 @@ LGB_CONFIG = { }, } - MAP = ["irrelevant", "complement", "substitute", "exact"] LOCALE_MAP = {'jp':0, 'es':1, 'us':2} COL_NAME = "esci_label" -WEIGHT_MAP = {0:1, 1:1, 2:1, 3:1} +WEIGHT_MAP = {0:1, 1:1, 2:1, 3:1, 4:1, 5:1} def lgb_predict_task2(df, locale): df = df.reset_index(drop=True) diff --git a/utils/lgb_predict_task3.py b/utils/lgb_predict_task3.py index 16d877b885e8866545c75e44b421f2a83423060c..26e79b3de8ce3f2ab899d9156d782cd21adf11ae 100755 --- a/utils/lgb_predict_task3.py +++ b/utils/lgb_predict_task3.py @@ -3,21 +3,21 @@ import pandas as pd LGB_CONFIG = { "us": { - "product_feat": "/models/us-product-feat.csv", - "model_file": "/models/lgb-us-task-3-only.txt", + "product_feat": "/models/us-product-feat-remove-intersection.csv", + "model_file": "/models/lgb-us-task-3.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold'] }, "jp": { - "product_feat": "/models/jp-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-remove-intersection.csv", "model_file": "/models/lgb-es-jp-task-3.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', 'query_count','is_isbn','has_isbn','fold','locale'] }, "es": { - "product_feat": 
"/models/es-product-feat.csv", + "product_feat": "/models/es-jp-product-feat-remove-intersection.csv", "model_file": "/models/lgb-es-jp-task-3.txt", "features": ['pred_0','pred_1','pred_2','pred_3', 'label_0','label_1','label_2','label_3', @@ -29,7 +29,7 @@ LGB_CONFIG = { MAP = ["no_substitute", "substitute"] LOCALE_MAP = {'jp':0, 'es':1, 'us':2} COL_NAME = "substitute_label" -WEIGHT_MAP = {0:1, 1:1, 2:1, 3:1} +WEIGHT_MAP = {0:1, 1:1, 2:1, 3:1, 4:1, 5:1} def lgb_predict_task3(df, locale): df = df.reset_index(drop=True) diff --git a/utils/onnx_predict.py b/utils/onnx_predict.py index 4378284fbd513b32d8f3308d3e71d3d97ea99ff3..40ccda576922adb3c838ba2123319864e6ebd551 100755 --- a/utils/onnx_predict.py +++ b/utils/onnx_predict.py @@ -1,4 +1,4 @@ -from .dataset import Task2Dataset,Task1Dataset +from .dataset import Task2Dataset, Task1Dataset, Task2DatasetWithDescription, Task1DatasetWithDescription from torch.utils.data import DataLoader import onnxruntime as ort import pandas as pd @@ -22,7 +22,12 @@ def onnx_predict(sub_test, config): session = [] for model in config["model"]: session.append(ort.InferenceSession(model, providers=["CUDAExecutionProvider"])) - dataset = Task2Dataset(sub_test, config) + + if config['description']: + dataset = Task2DatasetWithDescription(sub_test, config) + else: + dataset = Task2Dataset(sub_test, config) + loader = DataLoader( dataset, batch_size=BATCH_SIZE, @@ -33,18 +38,19 @@ def onnx_predict(sub_test, config): collate_fn=dataset.collate_fn, persistent_workers=False, ) + all_output = defaultdict(list) all_example = [] for data in tqdm(loader): - inputs = { - "input_ids": data["features"]["input_ids"], - "attention_mask": data["features"]["attention_mask"], - "speical_token_pos": data["features"]["speical_token_pos"], - } + # inputs = { + # "input_ids": data["features"]["input_ids"], + # "attention_mask": data["features"]["attention_mask"], + # "speical_token_pos": data["features"]["speical_token_pos"], + # } for i, s in enumerate(session): all_output[i] += list( - s.run(output_names=["output"], input_feed=dict(inputs))[0] # type: ignore + s.run(output_names=["output"], input_feed=dict(data['features']))[0] # type: ignore ) all_example += data["meta"]["example_id"] @@ -70,7 +76,12 @@ def onnx_predict_task1(sub_test, config): session = [] for model in config["model"]: session.append(ort.InferenceSession(model, providers=["CUDAExecutionProvider"])) - dataset = Task1Dataset(sub_test, config) + + if config['description']: + dataset = Task1DatasetWithDescription(sub_test, config) + else: + dataset = Task1Dataset(sub_test, config) + loader = DataLoader( dataset, batch_size=BATCH_SIZE, diff --git a/utils/run_task1.py b/utils/run_task1.py index 1d3c43f625d054a1ce7a978d6c41eae4ee5a4d51..578c37ef4ae6b700d19947c2aebc137ea6a9262b 100755 --- a/utils/run_task1.py +++ b/utils/run_task1.py @@ -25,6 +25,7 @@ CONFIG = { "key": "us", "type": "deberta", "fold_offset": 0, + "description": False, }, { "clean": DeBertaCleanV2, @@ -43,6 +44,7 @@ CONFIG = { "key": "us", "type": "cocolm", "fold_offset": 2, + "description": False, }, { "clean": DeBertaCleanV2, @@ -61,6 +63,7 @@ CONFIG = { "key": "us", "type": "distilbart", "fold_offset": 4, + "description": False, }, ], "jp": [ diff --git a/utils/run_task2.py b/utils/run_task2.py index 57fe652a8fef762b0a62b1642aa5caebade49c0e..a1dc6abb045938cf18cc0df63f45c4aae11839a0 100755 --- a/utils/run_task2.py +++ b/utils/run_task2.py @@ -26,6 +26,7 @@ CONFIG = { "key": "us", "type": "deberta", "fold_offset": 0, + "description":False, }, { "clean": 
@@ -39,11 +40,32 @@ CONFIG = {
                 "product_brand": 10643,
                 "product_color_name": 11890,
             },
-            "model": ["/models/us-cocolm-new-seed-0-fp16.onnx", "/models/us-cocolm-new-seed-1-fp16.onnx"],
+            "model": ["/models/us-cocolm-des-0-fp16.onnx", "/models/us-cocolm-des-0-fp16.onnx"],
             "product": "/models/cocolm.h5",
             "key": "us",
             "type": "cocolm",
             "fold_offset": 2,
+            "description":True,
+        },
+        {
+            "clean": DeBertaCleanV2,
+            "encode": {
+                "query": 12506,
+                "product_title": 3771,
+                "product_id": 4787,
+                "index": 6477,
+                "product_description": 6865,
+                "product_bullet_point": 10739,
+                "product_brand": 4609,
+                "product_color_name": 3225,
+            },
+            "model": ["/models/us-bigbird-des-0-fp16.onnx", "/models/us-bigbird-des-1-fp16.onnx"],
+            "product": "/models/bigbird.h5",
+            "key": "us",
+            "type": "bigbird",
+            "tokenizer": "/models/bigbird.model",
+            "fold_offset": 4,
+            "description":True,
         },
     ],
     "jp": [
@@ -65,6 +87,7 @@ CONFIG = {
             "key": "jp",
             "type": "mdeberta",
             "fold_offset": 0,
+            "description":False,
         },
     ],
     "es": [
@@ -86,6 +109,7 @@ CONFIG = {
             "key": "es",
             "type": "mdeberta",
             "fold_offset": 0,
+            "description": False,
         },
     ],
 }
diff --git a/utils/run_task3.py b/utils/run_task3.py
index b1a85a1f5408a8571bbc5dd204bc3615631b3379..52ac662ac32c99739c5429feee054bb5fe04c100 100755
--- a/utils/run_task3.py
+++ b/utils/run_task3.py
@@ -25,6 +25,7 @@ CONFIG = {
             "key": "us",
             "type": "deberta",
             "fold_offset": 0,
+            "description":False,
         },
         {
             "clean": DeBertaCleanV2,
@@ -38,11 +39,32 @@ CONFIG = {
                 "product_brand": 10643,
                 "product_color_name": 11890,
             },
-            "model": ["/models/us-cocolm-new-seed-0-fp16.onnx", "/models/us-cocolm-new-seed-1-fp16.onnx"],
+            "model": ["/models/us-cocolm-des-0-fp16.onnx", "/models/us-cocolm-des-0-fp16.onnx"],
             "product": "/models/cocolm.h5",
             "key": "us",
             "type": "cocolm",
             "fold_offset": 2,
+            "description":True,
+        },
+        {
+            "clean": DeBertaCleanV2,
+            "encode": {
+                "query": 12506,
+                "product_title": 3771,
+                "product_id": 4787,
+                "index": 6477,
+                "product_description": 6865,
+                "product_bullet_point": 10739,
+                "product_brand": 4609,
+                "product_color_name": 3225,
+            },
+            "model": ["/models/us-bigbird-des-0-fp16.onnx", "/models/us-bigbird-des-1-fp16.onnx"],
+            "product": "/models/bigbird.h5",
+            "key": "us",
+            "type": "bigbird",
+            "tokenizer": "/models/bigbird.model",
+            "fold_offset": 4,
+            "description":True,
         },
     ],
     "jp": [
@@ -64,6 +86,7 @@ CONFIG = {
             "key": "jp",
             "type": "mdeberta",
             "fold_offset": 0,
+            "description":False,
         },
     ],
     "es": [
@@ -85,6 +108,7 @@ CONFIG = {
             "key": "es",
             "type": "mdeberta",
             "fold_offset": 0,
+            "description": False,
         },
     ],
 }