diff --git a/models/lgb-es-jp.txt b/models/lgb-es-jp-task-2.txt
similarity index 100%
rename from models/lgb-es-jp.txt
rename to models/lgb-es-jp-task-2.txt
diff --git a/models/lgb-us.txt b/models/lgb-us-task-2.txt
similarity index 100%
rename from models/lgb-us.txt
rename to models/lgb-us-task-2.txt
diff --git a/utils/__init__.py b/utils/__init__.py
deleted file mode 100755
index be3f3b861a757a4f91df98f0cfe461562e9d5424..0000000000000000000000000000000000000000
--- a/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .run import Task2Predictor
\ No newline at end of file
diff --git a/utils/base_predictor.py b/utils/base_predictor.py
deleted file mode 100755
index cf58d5a654105f90af3ef08dac811dabc469be1c..0000000000000000000000000000000000000000
--- a/utils/base_predictor.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import pathlib
-import typing
-
-"""
-NOTE: Any changes to this file will be overwritten by the evaluators.
-"""
-
-PathType = typing.Union[str, pathlib.Path]
-
-
-class BasePredictor:
-    def prediction_setup(self):
-        """To be implemented by the participants.
-
-        Participants can add the steps needed to initialize their models,
-        and or any other setup related things here.
-        """
-        raise NotImplementedError
-
-    def predict(self, test_set_path: PathType, predictions_output_path: PathType, register_progress):
-        """To be implemented by the participants.
-
-        Participants need to consume the test set (the path of which is passed)
-        and write the final predictions as a CSV file to `predictions_output_path`.
-
-        Args:
-            test_set_path: Path to the Test Set for the specific task.
-            predictions_output_path: Output Path to write the predictions as a CSV file.
-            register_progress: A helper callable to register progress. Accepts a value [0, 1].
-        """
-        raise NotImplementedError
diff --git a/utils/clean.py b/utils/clean.py
deleted file mode 100755
index 665de573a10a42c789eacedcc06c48dc8d6569a0..0000000000000000000000000000000000000000
--- a/utils/clean.py
+++ /dev/null
@@ -1,640 +0,0 @@
-import re
-import emoji
-from collections import Counter
-import unicodedata
-
-class BaseClean:
-    clean_fns = [
-        "to_lower",
-        "to_symbol",
-        "remove_emoji",
-        "clean_contractions",
-        "common_us_word",
-        "query_clean_v1",
-        "remove_control_char",
-        "remove_duplicate",
-        "remove_ending_underscore",
-        "remove_starting_underscore",
-        "clean_multiple_form",
-        "leet_clean",
-    ]
-
-    def __init__(self, clean_fns=None):
-        if clean_fns:
-            self.clean_fns = clean_fns
-
-    def __call__(self, input_texts):
-
-        if type(input_texts) == list:
-            for fn in self.clean_fns:
-                fn = eval(fn)
-                input_texts = fn(input_texts)
-
-        elif type(input_texts) == str:
-            input_texts = [input_texts]
-            input_texts = self(input_texts)
-            input_texts = input_texts[0]
-
-        return input_texts
-
-class DeBertaCleanV2(BaseClean):
-    clean_fns = [
-        "to_lower",
-        "to_symbol",
-        "remove_emoji",
-        "clean_contractions",
-        "common_us_word",
-        "query_clean_v1",
-        "remove_control_char",
-        "remove_duplicate",
-        "remove_ending_underscore",
-        "remove_starting_underscore",
-        "clean_multiple_form",
-        "leet_clean",
-    ]
-
-class ESclean(BaseClean):
-    clean_fns = [
-        "to_lower",
-        "to_symbol",
-        "remove_emoji",
-        "common_es_word",
-        "query_clean_v1",
-        "remove_control_char",
-        "remove_duplicate",
-        "remove_ending_underscore",
-        "remove_starting_underscore",
-        "clean_multiple_form",
-        "leet_clean",
-    ]
-
-
-def common_es_word(data):
-    if type(data) == list:
-        return [common_us_word(d) for d in data]
-    else:
-        text = data
-        text = re.sub("''", '"', text)
-        text = re.sub("â€|“",
'"', text) - text = re.sub("‘|′", "'", text) - exps = re.findall("[0-9] {0,1}'", text) - for exp in exps: - text = text.replace(exp, exp[0] + "pie") - - exps = re.findall('[0-9] {0,1}"', text) - for exp in exps: - text = text.replace(exp, exp.replace('"', "pulgada")) - - return text - - -class JSClean(BaseClean): - clean_fns = [ - "to_lower", - "to_symbol", - "remove_emoji", - "query_clean_v1", - "remove_control_char", - "remove_duplicate", - "remove_ending_underscore", - "remove_starting_underscore", - "clean_multiple_form", - "leet_clean", - ] - - -def to_symbol(data): - - if type(data) == list: - return [to_symbol(d) for d in data] - else: - text = data - text = re.sub(""", '"', text) - text = re.sub("'", "'", text) - return text - -def to_lower(data): - if verbose: - print("#" * 10, "Step - Lowering everything:") - data = list(map(lambda x: x.lower(), data)) - return data - -def common_us_word(data): - if type(data) == list: - return [common_us_word(d) for d in data] - else: - text = data - text = re.sub("''", '"', text) - text = re.sub("a/c", "ac", text) - text = re.sub("0z", "oz", text) - text = re.sub("â€|“", '"', text) - text = re.sub("‘|′", "'", text) - exps = re.findall("[0-9] {0,1}'", text) - - for exp in exps: - text = text.replace(exp, exp[0] + "feet") - exps = re.findall('[0-9] {0,1}"', text) - - for exp in exps: - text = text.replace(exp, exp.replace('"', "inch")) - - text = re.sub("men'{0,1} {0,1}s|mens' s", "men", text) - - return text - - -def remove_emoji(data): - if type(data) == list: - return [remove_emoji(d) for d in data] - elif type(data) == str: - return emoji.get_emoji_regexp().sub("", data) - else: - raise - - -# TODO check spell for some words -def query_clean_v1(data): - - if type(data) == list: - return [query_clean_v1(d) for d in data] - - elif type(data) == str: - text = data - product_ids = re.findall("b0[0-9a-z]{8}", text) - if product_ids: - for i, exp in enumerate(product_ids): - text = text.replace(exp, f"placehold{chr(97+i)}") - - exps = re.findall("[a-zA-Z]'s|s'", text) - for exp in exps: - text = text.replace(exp, exp[0]) - - text = re.sub("\(|\)|\*|---|\+|'|,|\[|\]| -|- |\. |/ |:", " ", text) # ignore - text = text.strip() - - exps = re.findall("[a-zA-Z]\.", text) - for exp in exps: - text = text.replace(exp, exp[0]) - - # ! 
-> l for words - exps = re.findall("![a-zA-Z]{2}", text) - for exp in exps: - text = text.replace(exp, exp.replace("!", "l")) - - # a/b -> a b - exps = re.findall("[a-zA-Z]/[a-zA-Z]", text) - for exp in exps: - text = text.replace(exp, exp.replace("/", " ")) - - # remove " - text = re.sub('"', " ", text) - - # remove " - text = re.sub("'", " ", text) - - # # + [sep] + [num] -> # + [num] - exps = re.findall("# {1}[0-9]", text) - for exp in exps: - text = text.replace(exp, exp.replace(" ", "")) - - # remove # without - exps = re.findall("#[a-zA-Z]", text) - for exp in exps: - text = text.replace(exp, exp.replace("#", "")) - - if product_ids: - for i, exp in enumerate(product_ids): - text = text.replace(f"placehold{chr(97+i)}", exp) - - text = text.strip() - - return text - -def clean_contractions(data): - - helper_contractions = { - "aren't": "are not", - "Aren't": "Are not", - "AREN'T": "ARE NOT", - "C'est": "C'est", - "C'mon": "C'mon", - "c'mon": "c'mon", - "can't": "cannot", - "Can't": "Cannot", - "CAN'T": "CANNOT", - "con't": "continued", - "cont'd": "continued", - "could've": "could have", - "couldn't": "could not", - "Couldn't": "Could not", - "didn't": "did not", - "Didn't": "Did not", - "DIDN'T": "DID NOT", - "don't": "do not", - "Don't": "Do not", - "DON'T": "DO NOT", - "doesn't": "does not", - "Doesn't": "Does not", - "else's": "else", - "gov's": "government", - "Gov's": "government", - "gov't": "government", - "Gov't": "government", - "govt's": "government", - "gov'ts": "governments", - "hadn't": "had not", - "hasn't": "has not", - "Hasn't": "Has not", - "haven't": "have not", - "Haven't": "Have not", - "he's": "he is", - "He's": "He is", - "he'll": "he will", - "He'll": "He will", - "he'd": "he would", - "He'd": "He would", - "Here's": "Here is", - "here's": "here is", - "I'm": "I am", - "i'm": "i am", - "I'M": "I am", - "I've": "I have", - "i've": "i have", - "I'll": "I will", - "i'll": "i will", - "I'd": "I would", - "i'd": "i would", - "ain't": "is not", - "isn't": "is not", - "Isn't": "Is not", - "ISN'T": "IS NOT", - "it's": "it is", - "It's": "It is", - "IT'S": "IT IS", - "I's": "It is", - "i's": "it is", - "it'll": "it will", - "It'll": "It will", - "it'd": "it would", - "It'd": "It would", - "Let's": "Let's", - "let's": "let us", - "ma'am": "madam", - "Ma'am": "Madam", - "she's": "she is", - "She's": "She is", - "she'll": "she will", - "She'll": "She will", - "she'd": "she would", - "She'd": "She would", - "shouldn't": "should not", - "that's": "that is", - "That's": "That is", - "THAT'S": "THAT IS", - "THAT's": "THAT IS", - "that'll": "that will", - "That'll": "That will", - "there's": "there is", - "There's": "There is", - "there'll": "there will", - "There'll": "There will", - "there'd": "there would", - "they're": "they are", - "They're": "They are", - "they've": "they have", - "They've": "They Have", - "they'll": "they will", - "They'll": "They will", - "they'd": "they would", - "They'd": "They would", - "wasn't": "was not", - "we're": "we are", - "We're": "We are", - "we've": "we have", - "We've": "We have", - "we'll": "we will", - "We'll": "We will", - "we'd": "we would", - "We'd": "We would", - "What'll": "What will", - "weren't": "were not", - "Weren't": "Were not", - "what's": "what is", - "What's": "What is", - "When's": "When is", - "Where's": "Where is", - "where's": "where is", - "Where'd": "Where would", - "who're": "who are", - "who've": "who have", - "who's": "who is", - "Who's": "Who is", - "who'll": "who will", - "who'd": "Who would", - "Who'd": "Who 
would", - "won't": "will not", - "Won't": "will not", - "WON'T": "WILL NOT", - "would've": "would have", - "wouldn't": "would not", - "Wouldn't": "Would not", - "would't": "would not", - "Would't": "Would not", - "y'all": "you all", - "Y'all": "You all", - "you're": "you are", - "You're": "You are", - "YOU'RE": "YOU ARE", - "you've": "you have", - "You've": "You have", - "y'know": "you know", - "Y'know": "You know", - "ya'll": "you will", - "you'll": "you will", - "You'll": "You will", - "you'd": "you would", - "You'd": "You would", - "Y'got": "You got", - "cause": "because", - "had'nt": "had not", - "Had'nt": "Had not", - "how'd": "how did", - "how'd'y": "how do you", - "how'll": "how will", - "how's": "how is", - "I'd've": "I would have", - "I'll've": "I will have", - "i'd've": "i would have", - "i'll've": "i will have", - "it'd've": "it would have", - "it'll've": "it will have", - "mayn't": "may not", - "might've": "might have", - "mightn't": "might not", - "mightn't've": "might not have", - "must've": "must have", - "mustn't": "must not", - "mustn't've": "must not have", - "needn't": "need not", - "needn't've": "need not have", - "o'clock": "of the clock", - "oughtn't": "ought not", - "oughtn't've": "ought not have", - "shan't": "shall not", - "sha'n't": "shall not", - "shan't've": "shall not have", - "she'd've": "she would have", - "she'll've": "she will have", - "should've": "should have", - "shouldn't've": "should not have", - "so've": "so have", - "so's": "so as", - "this's": "this is", - "that'd": "that would", - "that'd've": "that would have", - "there'd've": "there would have", - "they'd've": "they would have", - "they'll've": "they will have", - "to've": "to have", - "we'd've": "we would have", - "we'll've": "we will have", - "what'll": "what will", - "what'll've": "what will have", - "what're": "what are", - "what've": "what have", - "when's": "when is", - "when've": "when have", - "where'd": "where did", - "where've": "where have", - "who'll've": "who will have", - "why's": "why is", - "why've": "why have", - "will've": "will have", - "won't've": "will not have", - "wouldn't've": "would not have", - "y'all'd": "you all would", - "y'all'd've": "you all would have", - "y'all're": "you all are", - "y'all've": "you all have", - "you'd've": "you would have", - "you'll've": "you will have", - } - if verbose: - print("#" * 10, "Step - Contractions:") - # Apply spellchecker for contractions - # Local (only unknown words) - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if (_check_replace(k)) and ("'" in k)] - temp_dict = {} - for word in temp_vocab: - if word in helper_contractions: - temp_dict[word] = helper_contractions[word] - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in x.split()]), - data, - ) - ) - if verbose: - _print_dict(temp_dict) - return data - -def remove_control_char(data): - if verbose: - print("#" * 10, "Step - Control Chars:") - global_chars_list = list(set([c for line in data for c in line])) - chars_dict = {c: "" for c in global_chars_list if unicodedata.category(c)[0] == "C"} - data = list( - map( - lambda x: " ".join([_make_cleaning(i, chars_dict) for i in x.split()]), data - ) - ) - - return data - -verbose = False -global_lower = True -WPLACEHOLDER = "[PLS]" - -def _check_replace(w): - return not bool(re.search(WPLACEHOLDER, w)) - - -def _make_cleaning(s, c_dict): - if _check_replace(s): - s = s.translate(c_dict) - return s - - -def 
_check_vocab(c_list, vocabulary, response="default"): - try: - words = set([w for line in c_list for w in line.split()]) - # print('Total Words :',len(words)) - u_list = words.difference(set(vocabulary)) - k_list = words.difference(u_list) - - if response == "default": - print("Unknown words:", len(u_list), "| Known words:", len(k_list)) - elif response == "unknown_list": - return list(u_list) - elif response == "known_list": - return list(k_list) - except: - return [] - - -def _make_dict_cleaning(s, w_dict): - if _check_replace(s): - s = w_dict.get(s, s) - return s - - -def _print_dict(temp_dict, n_items=10): - run = 0 - for k, v in temp_dict.items(): - print(k, "---", v) - run += 1 - if run == n_items: - break - #################### Main Function ################# - -def remove_duplicate(data): - # Duplicated dots, question marks and exclamations - # Locallocal_vocab - if verbose: - print("#" * 10, "Step - Duplicated Chars:") - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if _check_replace(k)] - temp_dict = {} - - for word in temp_vocab: - new_word = word - if ( - (Counter(word)["."] > 1) - or (Counter(word)["!"] > 1) - or (Counter(word)["?"] > 1) - or (Counter(word)[","] > 1) - ): - if Counter(word)["."] > 1: - new_word = re.sub("\.\.+", " . . . ", new_word) - if Counter(word)["!"] > 1: - new_word = re.sub("\!\!+", " ! ! ! ", new_word) - if Counter(word)["?"] > 1: - new_word = re.sub("\?\?+", " ? ? ? ", new_word) - if Counter(word)[","] > 1: - new_word = re.sub("\,\,+", " , , , ", new_word) - temp_dict[word] = new_word - - temp_dict = {k: v for k, v in temp_dict.items() if k != v} - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in x.split()]), - data, - ) - ) - return data - - -def remove_ending_underscore(data): - if verbose: - print("#" * 10, "Step - Remove ending underscore:") - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if (_check_replace(k)) and ("_" in k)] - temp_dict = {} - for word in temp_vocab: - new_word = word - if word[len(word) - 1] == "_": - for i in range(len(word), 0, -1): - if word[i - 1] != "_": - new_word = word[:i] - temp_dict[word] = new_word - break - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in x.split()]), - data, - ) - ) - if verbose: - _print_dict(temp_dict) - return data - - -def remove_starting_underscore(data): - if verbose: - print("#" * 10, "Step - Remove starting underscore:") - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if (_check_replace(k)) and ("_" in k)] - temp_dict = {} - for word in temp_vocab: - new_word = word - if word[len(word) - 1] == "_": - for i in range(len(word)): - if word[i] != "_": - new_word = word[:i] - temp_dict[word] = new_word - break - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in x.split()]), - data, - ) - ) - if verbose: - _print_dict(temp_dict) - return data - -def clean_multiple_form(data): - if verbose: - print("#" * 10, "Step - Multiple form:") - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if (k[-1:] == "s") and (len(k) > 4)] - temp_dict = {k: k[:-1] for k in temp_vocab if (k[:-1] in local_vocab)} - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in 
x.split()]), - data, - ) - ) - if verbose: - _print_dict(temp_dict) - return data - -def leet_clean(data): - def __convert_leet(word): - # basic conversion - word = re.sub("0", "o", word) - word = re.sub("1", "i", word) - word = re.sub("3", "e", word) - word = re.sub("\$", "s", word) - word = re.sub("\@", "a", word) - return word - - if verbose: - print("#" * 10, "Step - L33T (with vocab check):") - local_vocab = {} - temp_vocab = _check_vocab(data, local_vocab, response="unknown_list") - temp_vocab = [k for k in temp_vocab if _check_replace(k)] - - temp_dict = {} - for word in temp_vocab: - new_word = __convert_leet(word) - - if new_word != word: - if (len(word) > 2) and (new_word in local_vocab): - temp_dict[word] = new_word - - data = list( - map( - lambda x: " ".join([_make_dict_cleaning(i, temp_dict) for i in x.split()]), - data, - ) - ) - if verbose: - _print_dict(temp_dict) - return data \ No newline at end of file diff --git a/utils/cocolm_tokenizer.py b/utils/cocolm_tokenizer.py deleted file mode 100755 index 1b49b2d054b5f9a887fc1f9a5b520e8ac73e89cb..0000000000000000000000000000000000000000 --- a/utils/cocolm_tokenizer.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. -# The script is largely adapted from the huggingface transformers library - -import os -import logging -from collections import Counter -import torch -import re -import os -import unicodedata - -from transformers.tokenization_utils import PreTrainedTokenizer -from typing import Union -import swifter -import pandas as pd - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Dictionary: - """A mapping from symbols to consecutive integers""" - - def __init__( - self, - *, # begin keyword-only arguments - bos="<s>", - pad="<pad>", - eos="</s>", - unk="<unk>", - extra_special_symbols=None, - ): - self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos - self.symbols = [] - self.count = [] - self.indices = {} - self.alias_mapper = { - "<s>": "[CLS]", - "<pad>": "[PAD]", - "</s>": "[SEP]", - "<unk>": "[UNK]", - "<mask>": "[MASK]", - "[CLS]": "[CLS]", - "[PAD]": "[PAD]", - "[SEP]": "[SEP]", - "[UNK]": "[UNK]", - "[MASK]": "[MASK]", - } - self.bos_index = self.add_symbol(bos) - self.pad_index = self.add_symbol(pad) - self.eos_index = self.add_symbol(eos) - self.unk_index = self.add_symbol(unk) - if extra_special_symbols: - for s in extra_special_symbols: - self.add_symbol(s) - self.nspecial = len(self.symbols) - - def __eq__(self, other): - return self.indices == other.indices - - def __getitem__(self, idx): - if idx < len(self.symbols): - return self.symbols[idx] - return self.unk_word - - def __len__(self): - """Returns the number of symbols in the dictionary""" - return len(self.symbols) - - def __contains__(self, sym): - return sym in self.indices - - def index(self, sym): - """Returns the index of the specified symbol""" - assert isinstance(sym, str) - if sym in self.alias_mapper: - sym = self.alias_mapper[sym] - if sym in self.indices: - return self.indices[sym] - return self.unk_index - - def unk_string(self, escape=False): - """Return unknown string, optionally escaped as: <<unk>>""" - if escape: - return "<{}>".format(self.unk_word) - else: - return self.unk_word - - def add_symbol(self, word, n=1, overwrite=False): - """Adds a word to the dictionary""" - if word in self.alias_mapper: - word = self.alias_mapper[word] - if word in self.indices and not overwrite: - idx = self.indices[word] - self.count[idx] 
= self.count[idx] + n - return idx - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(n) - return idx - - def update(self, new_dict, word): - """Updates counts from new dictionary.""" - if word in self.alias_mapper: - word = self.alias_mapper[word] - for word in new_dict.symbols: - idx2 = new_dict.indices[word] - if word in self.indices: - idx = self.indices[word] - self.count[idx] = self.count[idx] + new_dict.count[idx2] - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(new_dict.count[idx2]) - - def pad_to_multiple_(self, padding_factor): - """Pad Dictionary size to be a multiple of *padding_factor*.""" - if padding_factor > 1: - i = 0 - while len(self) % padding_factor != 0: - symbol = "madeupword{:04d}".format(i) - self.add_symbol(symbol, n=0) - i += 1 - - def bos(self): - """Helper to get index of beginning-of-sentence symbol""" - return self.bos_index - - def pad(self): - """Helper to get index of pad symbol""" - return self.pad_index - - def eos(self): - """Helper to get index of end-of-sentence symbol""" - return self.eos_index - - def unk(self): - """Helper to get index of unk symbol""" - return self.unk_index - - @classmethod - def load(cls, f): - """Loads the dictionary from a text file with the format: - - ``` - <symbol0> <count0> - <symbol1> <count1> - ... - ``` - """ - d = cls() - d.add_from_file(f) - return d - - def add_from_file(self, f): - """ - Loads a pre-existing dictionary from a text file and adds its symbols - to this instance. - """ - if isinstance(f, str): - try: - # with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: - with open(f, "r", encoding="utf-8") as fd: - self.add_from_file(fd) - except FileNotFoundError as fnfe: - raise fnfe - except UnicodeError: - raise Exception( - "Incorrect encoding detected in {}, please " - "rebuild the dataset".format(f) - ) - return - - lines = f.readlines() - indices_start_line = self._load_meta(lines) - - for line_idx, line in enumerate(lines[indices_start_line:]): - try: - splits = line.rstrip().rsplit(" ", 1) - line = splits[0] - field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx) - if field == "#fairseq:overwrite": - overwrite = True - line, field = line.rsplit(" ", 1) - else: - overwrite = False - count = int(field) - word = line - if word in self and not overwrite: - logger.info( - "Duplicate word found when loading Dictionary: '{}', index is {}.".format( - word, self.indices[word] - ) - ) - else: - self.add_symbol(word, n=count, overwrite=overwrite) - except ValueError: - raise ValueError( - "Incorrect dictionary format, expected '<token> <cnt> [flags]'" - ) - - def _get_meta(self): - return [], [] - - def _load_meta(self, lines): - return 0 - - def save(self, f): - """Stores dictionary into a text file""" - ex_keys, ex_vals = self._get_meta() - self._save( - f, - zip( - ex_keys + self.symbols[self.nspecial :], - ex_vals + self.count[self.nspecial :], - ), - ) - - def dummy_sentence(self, length): - t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() - t[-1] = self.eos() - return t - - -def _is_punctuation(char): - """Checks whether `chars` is a punctuation character.""" - cp = ord(char) - # We treat all non-letter/number ASCII as punctuation. - # Characters such as "^", "$", and "`" are not in the Unicode - # Punctuation class but we treat them as punctuation anyways, for - # consistency. 
- if ( - (cp >= 33 and cp <= 47) - or (cp >= 58 and cp <= 64) - or (cp >= 91 and cp <= 96) - or (cp >= 123 and cp <= 126) - ): - return True - cat = unicodedata.category(char) - if cat.startswith("P"): - return True - return False - - -def _truncate_seq_pair(tokens_a, tokens_b, max_length): - """Truncates a sequence pair in place to the maximum length.""" - - # This is a simple heuristic which will always truncate the longer sequence - # one token at a time. This makes more sense than truncating an equal percent - # of tokens from each, since if one sequence is very short then each token - # that's truncated likely contains more information than a longer sequence. - while True: - total_length = len(tokens_a) + len(tokens_b) - if total_length <= max_length: - break - if len(tokens_a) > len(tokens_b): - tokens_a.pop() - else: - tokens_b.pop() - - -class SentencepiecePreTokenizer(object): - def __init__(self): - self.transl_table = dict( - [(ord(x), ord(y)) for x, y in zip("‘’´“â€â€”–-", "'''\"\"---")] - ) - - def handle_single_quote(self, tokens): - line = " ".join(tokens) - line = re.sub(r"' ([smdSMDtT])\b", r"'\1", line) - line = re.sub(r"' ll\b", "'ll", line) - line = re.sub(r"' re\b", "'re", line) - line = re.sub(r"' ve\b", "'ve", line) - line = re.sub(r"' LL\b", "'LL ", line) - line = re.sub(r"' RE\b", "'RE ", line) - line = re.sub(r"' VE\b", "'VE ", line) - return line.split() - - def split_on_cont_punc(self, tokens): - new_tokens = [] - for token in tokens: - if len(token) > 1: - last_j = 0 - pre_is_punc = _is_punctuation(token[0]) - for j, ch in enumerate(token): - is_punc = _is_punctuation(ch) - if is_punc != pre_is_punc: - new_tokens.append(token[last_j:j]) - last_j = j - pre_is_punc = is_punc - if last_j < len(token): - new_tokens.append(token[last_j:]) - else: - new_tokens.append(token) - return new_tokens - - def split_pre_and_post_punc(self, tokens): - def pre_punc(token): - last_j = 0 - for j in range(1, len(token)): - if not _is_punctuation(token[j]): - last_j = j - break - return token[:last_j], token[last_j:] - - def post_punc(token): - last_j = len(token) - for j in range(len(token) - 2, -1, -1): - is_punc = _is_punctuation(token[j]) - if not _is_punctuation(token[j]): - last_j = j + 1 - break - return token[:last_j], token[last_j:] - - new_tokens = [] - for token in tokens: - if len(token) > 1 and _is_punctuation(token[0]): - a, b = pre_punc(token) - if a: - new_tokens.append(a) - if b: - if _is_punctuation(b[-1]): - c, d = post_punc(b) - if c: - new_tokens.append(c) - if d: - new_tokens.append(d) - else: - new_tokens.append(b) - elif len(token) > 1 and _is_punctuation(token[-1]): - a, b = post_punc(token) - if a: - new_tokens.append(a) - if b: - new_tokens.append(b) - else: - new_tokens.append(token) - return new_tokens - - def tokenize(self, line): - line = line.strip() - line = line.replace("``", '"').replace("''", '"') - line = line.translate(self.transl_table) - tokens = line.split() - tokens = self.split_pre_and_post_punc(tokens) - tokens = self.handle_single_quote(tokens) - return tokens - - -class COCOLMTokenizer(PreTrainedTokenizer): - vocab_files_names = {"vocab_file": "sp.model", "dict_file": "dict.txt"} - pretrained_vocab_files_map = { - "vocab_file": { - "cocolm-cased": "/models/cocolm-sp.model", - }, - "dict_file": {"cocolm-cased": "/models/cocolm-dict.txt"}, - } - max_model_input_sizes = {"cocolm-cased": 512} - - def __init__( - self, - vocab_file="/models/cocolm-sp.model", - dict_file="/models/cocolm-dict.txt", - **kwargs, - ): - super(COCOLMTokenizer, 
self).__init__(**kwargs) - if not os.path.exists(vocab_file): - raise EnvironmentError("file {} not found".format(vocab_file)) - try: - import sentencepiece as spm - - self.sp = spm.SentencePieceProcessor() - self.sp.Load(vocab_file) - self.pre_tokenizer = SentencepiecePreTokenizer() - self.dictionary = Dictionary.load(dict_file) - except ImportError: - raise ImportError( - "Please install sentencepiece with: pip install sentencepiece" - ) - self.dictionary.add_symbol("<mask>") - - @property - def cls_token(self): - return self.dictionary.alias_mapper[self.dictionary.bos_word] - - @property - def sep_token(self): - return self.dictionary.alias_mapper[self.dictionary.eos_word] - - @property - def pad_token(self): - return self.dictionary.alias_mapper[self.dictionary.pad_word] - - @property - def unk_token(self): - return self.dictionary.alias_mapper[self.dictionary.unk_word] - - @property - def cls_token_id(self): - return self.dictionary.bos_index - - @property - def sep_token_id(self): - return self.dictionary.eos_index - - @property - def pad_token_id(self): - return self.dictionary.pad_index - - @property - def mask_token_id(self): - return self.dictionary.index("<mask>") - - @property - def unk_token_id(self): - return self.dictionary.unk_index - - def encode_plus(self, text_a, text_b=None, add_special_tokens=True, max_length=512): - tokens_a = self.tokenize(text_a) - if text_b is not None: - tokens_b = self.tokenize(text_b) - _truncate_seq_pair(tokens_a, tokens_b, max_length - 4) - else: - if len(tokens_a) > max_length - 2: - tokens_a = tokens_a[: max_length - 2] - - if add_special_tokens: - tokens = [self.dictionary.bos_word] + tokens_a + [self.dictionary.eos_word] - if text_b is not None: - tokens += ( - [self.dictionary.eos_word] + tokens_b + [self.dictionary.eos_word] - ) - else: - tokens = tokens_a + tokens_b - - ids = self.convert_tokens_to_ids(tokens) - return {"input_ids": ids} - - def encode(self, x: str, add_special_tokens=False) -> str: - tokens = self.tokenize(x) - if add_special_tokens: - tokens = [self.dictionary.bos_word] + tokens + [self.dictionary.eos_word] - return self.convert_tokens_to_ids(tokens) - - def decode(self, ids: list) -> str: - x = "".join([self._convert_id_to_token(token_id) for token_id in ids]) - return x.replace(" ", "").replace("\u2581", " ").strip() - - def skip_space(self, tokens): - new_tokens = [] - for i, token in enumerate(tokens): - skip = False - # skip single space, to reduce total length - if token == "\u2581": - if i == len(tokens) - 1 or _is_punctuation(tokens[i + 1][0]): - skip = True - if not skip: - new_tokens.append(token) - return new_tokens - - def tokenize(self, x): - x = " ".join(self.pre_tokenizer.tokenize(x)) - tokens = self.sp.EncodeAsPieces(x) - tokens = self.skip_space(tokens) - return tokens - - def convert_tokens_to_ids(self, tokens: list): - ret = [] - if isinstance(tokens, str): - return self.dictionary.index(tokens) - for token in tokens: - ret.append(self.dictionary.index(token)) - return ret - - def _convert_id_to_token(self, index): - """Converts a token (str) in an id using the vocab.""" - token = self.dictionary[index] - return token - - def convert_tokens_to_string(self, tokens: list): - x = " ".join(tokens) - return x.replace(" ", "").replace("\u2581", " ").strip() - - def is_beginning_of_word(self, x: str) -> bool: - if x in ["<unk>", "<s>", "</s>", "<pad>", "[CLS]", "[PAD]", "[SEP]", "[UNK]"]: - # special elements are always considered beginnings - # HACK: this logic is already present in 
fairseq/tasks/masked_lm.py - # but these special tokens are also contained in the sentencepiece - # vocabulary which causes duplicate special tokens. This hack makes - # sure that they are all taken into account. - return True - return x.startswith("\u2581") - - def __call__(self, token: Union[str, list], return_dict=True, **kwargs) -> Union[str, dict]: # type: ignore - - if type(token) == str: - input_ids = self.encode(token, add_special_tokens=True) # type: ignore - if return_dict: - return {"input_ids": input_ids} - else: - return input_ids - - elif type(token) == list: - input_ids = pd.DataFrame(token) - input_ids = ( - input_ids[0] - .swifter.allow_dask_on_strings(enable=True) - .apply(lambda x: self(x, return_dict=False)) - ) - input_ids = input_ids.to_list() - if return_dict: - return {"input_ids": input_ids} - else: - return input_ids diff --git a/utils/dataset.py b/utils/dataset.py deleted file mode 100755 index d2043c77f90278e34b5dd3cc577d558ac447c3dd..0000000000000000000000000000000000000000 --- a/utils/dataset.py +++ /dev/null @@ -1,143 +0,0 @@ -import atexit -from typing import List, Tuple -import h5py -import torch -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import Dataset, get_worker_info -import pandas as pd -import sentencepiece as spm -#from .clean import DeBertaCleanV2, ESclean, JSClean - - - -TOKERIZER_SETTING = { - "deberta": {"cls":1}, - "distilbart": {'cls':0}, - 'cocolm':{'cls':0}, - 'mdeberta':{'cls':1}, -} - - -class BaseDataset(Dataset): - def __init__( - self, df: pd.DataFrame, config_dict: dict = {} - ) -> None: - self.max_length = 512 - filename = config_dict['product'] - self.filename = filename - self.cls_encode = TOKERIZER_SETTING[config_dict['type']]['cls'] - self.model_type = config_dict['type'] - self.df = df - self.key = config_dict['key'] - self.config_dict = config_dict - self.clean = config_dict['clean'] - self.clean = self.clean() - - def __len__(self) -> int: - return len(self.df) - - def cleanup(self) -> None: - self.database.close() - - def h5py_worker_init(self) -> None: - self.database = h5py.File(self.filename, "r", libver="latest", swmr=True) - if self.model_type == 'cocolm': - from .cocolm_tokenizer import COCOLMTokenizer # type: ignore - self.tokenizer = COCOLMTokenizer(vocab_file="/models/cocolm-sp.model", dict_file="/models/cocolm-dict.txt",) - elif self.model_type == 'distilbart': - from transformers import BartTokenizerFast # type: ignore - self.tokenizer = BartTokenizerFast(vocab_file='/models/distilbart-vocab.json', merges_file='/models/distilbart-merges.txt') - self.tokenizer.model_max_length = 1024 - else: - self.tokenizer = spm.SentencePieceProcessor(model_file=self.config_dict["tokenizer"]) # type: ignore - atexit.register(self.cleanup) - - @staticmethod - def worker_init_fn(worker_id): - worker_info = get_worker_info() - dataset = worker_info.dataset # type: ignore - dataset.h5py_worker_init() - - @staticmethod - def collate_fn(batch): - ... 
- -class Task2Dataset(BaseDataset): - def __getitem__(self, index) -> Tuple: - row = self.df.loc[index] - query = self.tokenizer.encode(self.clean(row["query"])) # type: ignore - if self.model_type == 'distilbart': query = query[1:-1] - product_title = list(self.database[self.key][row.product_id]["product_title"][:]) # type: ignore - product_brand = list(self.database[self.key][row.product_id]["product_brand"][:]) # type: ignore - product_color_name = list(self.database[self.key][row.product_id]["product_color_name"][:]) # type: ignore - product_bullet_point = list(self.database[self.key][row.product_id]["product_bullet_point"][:]) # type: ignore - index = self.tokenizer.encode(" ".join(str(row["index"]))) # type: ignore - if self.model_type == 'distilbart': index = index[1:-1] - product_id = self.tokenizer.encode(" ".join(str(row["product_id"]))) # type: ignore - if self.model_type == 'distilbart': product_id = product_id[1:-1] - input_ids_pos = [1] - input_ids = [self.cls_encode] + [self.config_dict["encode"]["query"]] + query + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["product_title"]] + product_title[1:-1] + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["product_id"]] + product_id + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["index"]] + index + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["product_brand"]] + product_brand[1:-1] + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["product_color_name"]] + product_color_name[1:-1] + [2] - input_ids_pos.append(len(input_ids)) - input_ids += [self.config_dict["encode"]["product_bullet_point"]] + product_bullet_point[1:-1] - input_ids = input_ids[:self.max_length-1] + [2] - for i in range(len(input_ids_pos)): - if input_ids_pos[i] >= self.max_length: - if input_ids_pos[-2] < self.max_length: - input_ids_pos[i] = input_ids_pos[-2] - elif input_ids_pos[1] < self.max_length: - input_ids_pos[i] = input_ids_pos[1] - else: - input_ids_pos[i] = self.max_length - 1 - - input_ids = torch.tensor(input_ids, dtype=torch.long) - input_ids_pos = torch.tensor(input_ids_pos, dtype=torch.long)[None] - #token_type_ids = torch.zeros_like(input_ids) - attention_mask = torch.ones_like(input_ids) - - feature = { - "input_ids": input_ids, - #"token_type_ids": token_type_ids, - "attention_mask": attention_mask, - "speical_token_pos": input_ids_pos, - } - - meta = { - "product_id": row['product_id'], - "example_id": row['example_id'], - } - return feature, meta - - @staticmethod - def collate_fn(batch: List) -> dict: - features = {} - features["input_ids"] = pad_sequence( - [x[0]["input_ids"] for x in batch], - batch_first=True, - padding_value=0, - ).numpy() -# features["token_type_ids"] = pad_sequence( -# [x[0]["token_type_ids"] for x in batch], -# batch_first=True, -# ).numpy() - features["attention_mask"] = pad_sequence( - [x[0]["attention_mask"] for x in batch], - batch_first=True, - ).numpy() - features["speical_token_pos"] = torch.cat( - [x[0]["speical_token_pos"] for x in batch] - ).numpy() - meta = {} - meta["product_id"] = [x[1]["product_id"] for x in batch] - meta["example_id"] = [x[1]["example_id"] for x in batch] - return {"features": features, "meta": meta} - diff --git a/utils/lgb_predict.py b/utils/lgb_predict.py deleted file mode 100755 index 010cab89c233691611b40dbcb9f4584dd3ae807d..0000000000000000000000000000000000000000 --- 
a/utils/lgb_predict.py +++ /dev/null @@ -1,78 +0,0 @@ -import lightgbm as lgb -import pandas as pd - -__MAP__ = ["irrelevant", "complement", "substitute", "exact"] - - -LGB_CONFIG = { - "us": { - "product_feat": "/models/us-product-feat.csv", - "model_file": "/models/lgb-us.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold'] - }, - "jp": { - "product_feat": "/models/jp-product-feat.csv", - "model_file": "/models/lgb-es-jp.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold','locale'] - }, - "es": { - "product_feat": "/models/es-product-feat.csv", - "model_file": "/models/lgb-es-jp.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold','locale'] - }, -} - - -MAP = ["irrelevant", "complement", "substitute", "exact"] -LOCALE_MAP = {'jp':0, 'es':1, 'us':2} -COL_NAME = "esci_label" -WEIGHT_MAP = {0:0.4, 1:0.4, 2:0.4, 3:0.4, 4:0.25, 5:0.2} - -def lgb_predict(df, locale): - df = df.reset_index(drop=True) - model_file = LGB_CONFIG[locale]["model_file"] - product_feat = pd.read_csv(LGB_CONFIG[locale]["product_feat"]) - for i in range(4): - t = ( - product_feat[product_feat["label"] == i] - .set_index("product_id")["example_id"] - .to_dict() - ) - df[f"label_{i}"] = df.product_id.apply( - lambda x: t.get(x, 0) - ) # label_0 : label_3 - temp = ( - df.groupby("query")["example_id"] - .count() - .reset_index() - .rename({"example_id": "query_count"}, axis=1) - ) - temp["query_count"] = temp["query_count"] // df["fold"].nunique() - df = pd.merge(left=df, right=temp, on="query", how="left") # query_count - df["is_isbn"] = df["product_id"].apply(lambda x: int(x[0] != "B")) # is_isbn - temp = ( - (df.groupby("query").is_isbn.sum() > 0) - .astype(int) - .reset_index() - .rename({"is_isbn": "has_isbn"}, axis=1) - ) - df = pd.merge(left=df, right=temp, on="query") # has_isbn - df['locale'] = df['query_locale'].apply(lambda x: LOCALE_MAP[x]) - model = lgb.Booster(model_file=model_file) - pred = model.predict(df[LGB_CONFIG[locale]["features"]]) - sub = pd.DataFrame() - sub["example_id"] = df.example_id.values - sub['fold'] = df.fold.values - sub['weight'] = sub['fold'].apply(lambda x: WEIGHT_MAP[x]) - sub[list(range(len(MAP)))] = pred*sub['weight'].values.reshape(-1, 1) # type: ignore - sub = sub.groupby("example_id").mean().reset_index() - sub[COL_NAME] = sub[list(range(len(MAP)))].values.argmax(1) # type: ignore - sub[COL_NAME] = sub[COL_NAME].apply(lambda x: MAP[x]) - sub = sub[["example_id", COL_NAME]] - return sub diff --git a/utils/lgb_predict_task3.py b/utils/lgb_predict_task3.py deleted file mode 100644 index 9ab883e7c7b8acdba5fa6729fa26e776c5aac9e2..0000000000000000000000000000000000000000 --- a/utils/lgb_predict_task3.py +++ /dev/null @@ -1,72 +0,0 @@ -import lightgbm as lgb -import pandas as pd - -LGB_CONFIG = { - "us": { - "product_feat": "/models/us-product-feat.csv", - "model_file": "/models/lgb-us-task-3.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold'] - }, - "jp": { - "product_feat": "/models/jp-product-feat.csv", - "model_file": "/models/lgb-es-jp-task-3.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold','locale'] - }, - "es": { - 
"product_feat": "/models/es-product-feat.csv", - "model_file": "/models/lgb-es-jp-task-3.txt", - "features": ['pred_0','pred_1','pred_2','pred_3', - 'label_0','label_1','label_2','label_3', - 'query_count','is_isbn','has_isbn','fold','locale'] - }, -} - - -MAP = ["no_substitute", "substitute"] -LOCALE_MAP = {'jp':0, 'es':1, 'us':2} -COL_NAME = "substitute_label" - -def lgb_predict_task3(df, locale): - df = df.reset_index(drop=True) - model_file = LGB_CONFIG[locale]["model_file"] - product_feat = pd.read_csv(LGB_CONFIG[locale]["product_feat"]) - for i in range(4): - t = ( - product_feat[product_feat["label"] == i] - .set_index("product_id")["example_id"] - .to_dict() - ) - df[f"label_{i}"] = df.product_id.apply( - lambda x: t.get(x, 0) - ) # label_0 : label_3 - temp = ( - df.groupby("query")["example_id"] - .count() - .reset_index() - .rename({"example_id": "query_count"}, axis=1) - ) - temp["query_count"] = temp["query_count"] // df["fold"].nunique() - df = pd.merge(left=df, right=temp, on="query", how="left") # query_count - df["is_isbn"] = df["product_id"].apply(lambda x: int(x[0] != "B")) # is_isbn - temp = ( - (df.groupby("query").is_isbn.sum() > 0) - .astype(int) - .reset_index() - .rename({"is_isbn": "has_isbn"}, axis=1) - ) - df = pd.merge(left=df, right=temp, on="query") # has_isbn - df['locale'] = df['query_locale'].apply(lambda x: LOCALE_MAP[x]) - model = lgb.Booster(model_file=model_file) - pred = model.predict(df[LGB_CONFIG[locale]["features"]]) - sub = pd.DataFrame() - sub["example_id"] = df.example_id.values - sub[list(range(len(MAP)))] = pred - sub = sub.groupby("example_id").mean().reset_index() - sub[COL_NAME] = sub[list(range(len(MAP)))].values.argmax(1) # type: ignore - sub[COL_NAME] = sub[COL_NAME].apply(lambda x: MAP[x]) - sub = sub[["example_id", COL_NAME]] - return sub diff --git a/utils/onnx_predict.py b/utils/onnx_predict.py deleted file mode 100755 index 6d2ca1aac5393b815c2c997af208862494364f21..0000000000000000000000000000000000000000 --- a/utils/onnx_predict.py +++ /dev/null @@ -1,59 +0,0 @@ -from .dataset import Task2Dataset -from torch.utils.data import DataLoader -import onnxruntime as ort -import pandas as pd -from scipy.special import softmax -from tqdm import tqdm -from collections import defaultdict - -BATCH_SIZE = 4 -NUM_WORKERS = 8 - - -def onnx_predict(sub_test, config): - sub_test = sub_test.sort_values("length", ascending=False) - - if type(config["model"]) == str: - session = [ort.InferenceSession(config["model"], providers=["CUDAExecutionProvider"])] - else: - session = [] - for model in config["model"]: - session.append(ort.InferenceSession(model, providers=["CUDAExecutionProvider"])) - dataset = Task2Dataset(sub_test, config) - loader = DataLoader( - dataset, - batch_size=BATCH_SIZE, - num_workers=NUM_WORKERS, - drop_last=False, - pin_memory=True, - worker_init_fn=dataset.worker_init_fn, - collate_fn=dataset.collate_fn, - persistent_workers=False, - ) - all_output = defaultdict(list) - all_example = [] - - for data in tqdm(loader): - inputs = { - "input_ids": data["features"]["input_ids"], - "attention_mask": data["features"]["attention_mask"], - "speical_token_pos": data["features"]["speical_token_pos"], - } - for i, s in enumerate(session): - all_output[i] += list( - s.run(output_names=["output"], input_feed=dict(inputs))[0] # type: ignore - ) - all_example += data["meta"]["example_id"] - - all_df = [] - for i in range(len(session)): - df = pd.DataFrame() - df["example_id"] = all_example - df['fold'] = i - df[[0, 1, 2, 3]] = all_output[i] - 
all_df.append(df.copy()) - add_df = pd.concat(all_df) - add_df['fold'] = add_df['fold'] + config['fold_offset'] - del session - del loader - return add_df diff --git a/utils/run.py b/utils/run.py deleted file mode 100755 index 2bac4919649e4b71ed35aa7b9fd20268380b91b1..0000000000000000000000000000000000000000 --- a/utils/run.py +++ /dev/null @@ -1,219 +0,0 @@ -import pandas as pd -from .base_predictor import BasePredictor, PathType -from .clean import DeBertaCleanV2, ESclean, JSClean # type: ignore -from .onnx_predict import onnx_predict -from .lgb_predict import lgb_predict -from .lgb_predict_task3 import lgb_predict_task3 -from scipy.special import softmax - -CONFIG = { - "us": [ - { - "clean": DeBertaCleanV2, - "encode": { - "query": 7236, - "product_title": 1650, - "product_id": 77340, - "index": 3884, - "product_description": 3175, - "product_bullet_point": 5310, - "product_brand": 1224, - "product_color_name": 1163, - }, - "tokenizer": "/models/spm-us.model", - "model": ["/models/us-kd-v2-2-fold-0-fp16.onnx", "/models/us-kd-v2-2-fold-1-fp16.onnx"], - "product": "/models/product.h5", - "key": "us", - "type": "deberta", - "fold_offset": 0, - }, - { - "clean": DeBertaCleanV2, - "encode": { - "query": 14890, - "product_title": 10253, - "product_id": 5763, - "index": 6554, - "product_description": 29172, - "product_bullet_point": 32261, - "product_brand": 10643, - "product_color_name": 11890, - }, - "model": ["/models/us-cocolm-kd-2-fold-0-fp16.onnx", "/models/us-cocolm-kd-2-fold-1-fp16.onnx"], - "product": "/models/cocolm.h5", - "key": "us", - "type": "cocolm", - "fold_offset": 2, - }, - { - "clean": DeBertaCleanV2, - "encode": { - "query": 48360, - "product_title": 14691, - "product_id": 2688, - "index": 45673, - "product_description": 41602, - "product_bullet_point": 47977, - "product_brand": 42452, - "product_color_name": 44287, - }, - "model": ["/models/us-distilbart-two-fold-0-fp16.onnx"], - #"model": ["/models/us-distilbart-two-fold-0-fp16.onnx" ,"/models/us-distilbart-two-fold-1-fp16.onnx"], - "product": "/models/distilbart.h5", - "key": "us", - "type": "distilbart", - "fold_offset": 4, - }, - ], - "jp": [ - { - "clean": JSClean, - "encode": { - "query": 25245, - "product_title": 81171, - "product_id": 2531, - "index": 43696, - "product_description": 36173, - "product_bullet_point": 115653, - "product_brand": 51403, - "product_color_name": 2551, - }, - "tokenizer": "/models/spm-jp-es.model", - "model": ['/models/us-es-jp-mdeberta-0-fp16.onnx', '/models/us-es-jp-mdeberta-1-fp16.onnx'], - "product": "/models/product.h5", - "key": "jp", - "type": "mdeberta", - "fold_offset": 0, - }, - ], - "es": [ - { - "clean": ESclean, - "encode": { - "query": 22446, - "product_title": 10332, - "product_id": 2531, - "index": 39269, - "product_description": 80482, - "product_bullet_point": 22504, - "product_brand": 5504, - "product_color_name": 6776, - }, - "tokenizer": "/models/spm-jp-es.model", - "model": ['/models/us-es-jp-mdeberta-0-fp16.onnx', '/models/us-es-jp-mdeberta-1-fp16.onnx'], - "product": "/models/product.h5", - "key": "es", - "type": "mdeberta", - "fold_offset": 0, - }, - ], -} - -class Task2Predictor(BasePredictor): - def prediction_setup(self): - """To be implemented by the participants. - - Participants can add the steps needed to initialize their models, - and/or any other setup related things here. 
-        """
-        pass
-    def predict(
-        self,
-        test_set_path: PathType,
-        product_catalogue_path: PathType,
-        predictions_output_path: PathType,
-        register_progress=lambda x: print("Progress : ", x),
-        **kwargs,
-    ):
-        test = pd.read_csv(test_set_path)
-        progress = 0.1
-        register_progress(progress)
-        test = test.fillna("")
-        product = pd.read_csv("/models/product.csv")
-        all_output = []
-        for locale in ["us", "jp", "es"]:
-            sub_test = test[test.query_locale == locale].reset_index(drop=True)
-            sub_product = product[
-                (product.product_locale == locale)
-                & (product.product_id.isin(test.product_id))
-            ].reset_index(drop=True)
-            sub_test = pd.merge(
-                left=sub_test,
-                right=sub_product[["product_id", "index", "length"]],
-                on="product_id",
-            )
-            onnx_pred = []
-            for i in range(len(CONFIG[locale])): # type: ignore
-                config = CONFIG[locale][i]
-                pred = onnx_predict(sub_test, config)
-                pred[list(range(4))] = softmax(pred[list(range(4))].values, 1) # type: ignore
-                pred.columns = ["example_id", "fold"] + ["pred_0",'pred_1','pred_2','pred_3']
-                onnx_pred.append(pred.copy())
-                progress += 0.1
-                register_progress(progress)
-
-            df = pd.concat(onnx_pred)
-            df = pd.merge(left=df, right=sub_test, on="example_id", how="left")
-            all_output.append(lgb_predict(df, locale).copy())
-            progress += 0.1
-            register_progress(progress)
-        predictions_df = pd.concat(all_output).reset_index(drop=True)
-        print("Writing Task-2 Predictions to : ", predictions_output_path)
-        predictions_df.to_csv(predictions_output_path, index=False, header=True)
-        register_progress(1)
-
-
-class Task3Predictor(BasePredictor):
-    def prediction_setup(self):
-        """To be implemented by the participants.
-
-        Participants can add the steps needed to initialize their models,
-        and/or any other setup related things here.
-        """
-        pass
-    def predict(
-        self,
-        test_set_path: PathType,
-        product_catalogue_path: PathType,
-        predictions_output_path: PathType,
-        register_progress=lambda x: print("Progress : ", x),
-        **kwargs,
-    ):
-        test = pd.read_csv(test_set_path)
-        progress = 0.1
-        register_progress(progress)
-        test = test.fillna("")
-        product = pd.read_csv("/models/product.csv")
-        all_output = []
-        for locale in ["us", "jp", "es"]:
-            sub_test = test[test.query_locale == locale].reset_index(drop=True)
-            sub_product = product[
-                (product.product_locale == locale)
-                & (product.product_id.isin(test.product_id))
-            ].reset_index(drop=True)
-            sub_test = pd.merge(
-                left=sub_test,
-                right=sub_product[["product_id", "index", "length"]],
-                on="product_id",
-            )
-            onnx_pred = []
-            for i in range(len(CONFIG[locale])): # type: ignore
-                config = CONFIG[locale][i]
-                pred = onnx_predict(sub_test, config)
-                pred[list(range(4))] = softmax(pred[list(range(4))].values, 1) # type: ignore
-                pred.columns = ["example_id", "fold"] + ["pred_0",'pred_1','pred_2','pred_3']
-                onnx_pred.append(pred.copy())
-                progress += 0.1
-                register_progress(progress)
-
-            df = pd.concat(onnx_pred)
-            df = pd.merge(left=df, right=sub_test, on="example_id", how="left")
-            all_output.append(lgb_predict_task3(df, locale).copy())
-            progress += 0.1
-            register_progress(progress)
-        predictions_df = pd.concat(all_output).reset_index(drop=True)
-        print("Writing Task-3 Predictions to : ", predictions_output_path)
-        predictions_df.to_csv(predictions_output_path, index=False, header=True)
-        register_progress(1)