# lgb_predict_task2.py: LightGBM-based ESCI label prediction for task 2.
import lightgbm as lgb
import pandas as pd

# Class index -> ESCI label name; same contents as ``MAP``. Not referenced by
# the code visible in this module -- kept in case other modules import it.
__MAP__ = "irrelevant complement substitute exact".split()


# Feature columns common to every locale's LightGBM model, in model order.
_TASK2_BASE_FEATURES = [
    'pred_0', 'pred_1', 'pred_2', 'pred_3',
    'label_0', 'label_1', 'label_2', 'label_3',
    'query_count', 'is_isbn', 'has_isbn', 'fold',
]

# Per-locale inference settings:
#   product_feat -- CSV of per-product label features,
#   model_file   -- LightGBM model text file,
#   features     -- ordered columns passed to ``Booster.predict``.
# "jp" and "es" share one model file and additionally pass the encoded
# ``locale`` column; each locale gets its own list object.
LGB_CONFIG = {
    "us": {
        "product_feat": "/models/us-product-feat.csv",
        "model_file": "/models/lgb-us-task-2.txt",
        "features": list(_TASK2_BASE_FEATURES),
    },
    "jp": {
        "product_feat": "/models/jp-product-feat.csv",
        "model_file": "/models/lgb-es-jp-task-2.txt",
        "features": _TASK2_BASE_FEATURES + ['locale'],
    },
    "es": {
        "product_feat": "/models/es-product-feat.csv",
        "model_file": "/models/lgb-es-jp-task-2.txt",
        "features": _TASK2_BASE_FEATURES + ['locale'],
    },
}


# Class index -> ESCI label name; index i labels prediction column i.
MAP = ["irrelevant", "complement", "substitute", "exact"]
# Integer encoding of ``query_locale`` used for the ``locale`` model feature.
LOCALE_MAP = {code: idx for idx, code in enumerate(("jp", "es", "us"))}
# Name of the predicted-label column in the returned frame.
COL_NAME = "esci_label"
# Weight applied to each fold's predictions before averaging (uniform here).
WEIGHT_MAP = dict.fromkeys(range(4), 1)

def _build_features(df, product_feat_path):
    """Return a copy of *df* with the engineered feature columns added.

    Adds ``label_0``..``label_3``, ``query_count``, ``is_isbn``,
    ``has_isbn`` and ``locale``. The caller's frame is not modified.

    Args:
        df: Candidate rows with at least ``product_id``, ``query``,
            ``example_id``, ``fold`` and ``query_locale`` columns.
        product_feat_path: CSV with ``product_id``, ``label`` and
            ``example_id`` columns.

    Raises:
        KeyError: a ``query_locale`` value is missing from ``LOCALE_MAP``.
    """
    df = df.reset_index(drop=True)
    product_feat = pd.read_csv(product_feat_path)

    # label_i: the CSV's "example_id" value for (product, label == i), 0 when
    # the product is absent. NOTE(review): presumably a per-product count of
    # examples carrying label i -- confirm against the CSV producer.
    for label in range(len(MAP)):
        lookup = (
            product_feat[product_feat["label"] == label]
            .set_index("product_id")["example_id"]
            .to_dict()
        )
        # Bind ``lookup`` explicitly so the lambda never sees a later dict.
        df[f"label_{label}"] = df["product_id"].map(
            lambda pid, lookup=lookup: lookup.get(pid, 0)
        )

    # Rows per query divided by the number of folds. Rows are presumably
    # repeated once per fold, so this approximates examples per query.
    n_folds = df["fold"].nunique()
    query_count = (
        df.groupby("query")["example_id"]
        .count()
        .floordiv(n_folds)
        .rename("query_count")
        .reset_index()
    )
    # Left join: rows whose query has no group (e.g. NaN query) are kept.
    df = df.merge(query_count, on="query", how="left")

    # Product ids that do not start with "B" are flagged as ISBNs.
    df["is_isbn"] = df["product_id"].map(lambda pid: int(pid[0] != "B"))
    has_isbn = (
        (df.groupby("query")["is_isbn"].sum() > 0)
        .astype(int)
        .rename("has_isbn")
        .reset_index()
    )
    # Left join (previously an inner join): groupby drops NaN query keys, so
    # an inner join here silently removed those rows from the predictions.
    df = df.merge(has_isbn, on="query", how="left")

    df["locale"] = df["query_locale"].map(lambda loc: LOCALE_MAP[loc])
    return df


def _ensemble_by_example(example_ids, folds, pred):
    """Average weighted per-fold class scores per example and pick the argmax.

    Args:
        example_ids: Example id of each prediction row.
        folds: Fold of each prediction row (keys of ``WEIGHT_MAP``).
        pred: Class scores; assumed shape (n_rows, len(MAP)) -- TODO confirm
            the model is multiclass with len(MAP) outputs.

    Returns:
        DataFrame with ``example_id`` and ``COL_NAME`` columns, one row per
        distinct example id.

    Raises:
        KeyError: a fold is missing from ``WEIGHT_MAP``.
    """
    class_cols = list(range(len(MAP)))
    weights = pd.Series(folds).map(lambda fold: WEIGHT_MAP[fold]).to_numpy()
    scores = pd.DataFrame(pred * weights.reshape(-1, 1), columns=class_cols)
    scores["example_id"] = example_ids
    # mean(w * p) differs from the weighted average only by a positive
    # per-example factor, so the argmax below is unaffected.
    mean_scores = scores.groupby("example_id")[class_cols].mean().reset_index()
    best = mean_scores[class_cols].to_numpy().argmax(axis=1)
    mean_scores[COL_NAME] = [MAP[i] for i in best]
    return mean_scores[["example_id", COL_NAME]]


def lgb_predict_task2(df, locale):
    """Predict an ESCI label per example with the locale's LightGBM model.

    Args:
        df: Candidate rows; must contain ``example_id``, ``query``,
            ``product_id``, ``query_locale``, ``fold`` and
            ``pred_0``..``pred_3`` (the latter are model inputs not computed
            here).
        locale: Key into ``LGB_CONFIG`` (``"us"``, ``"jp"`` or ``"es"``).

    Returns:
        DataFrame with ``example_id`` and ``COL_NAME`` columns, one row per
        distinct example id, labels drawn from ``MAP``. An empty *df* yields
        an empty frame with those columns.

    Raises:
        KeyError: unknown *locale*, a ``query_locale`` not in ``LOCALE_MAP``,
            or a ``fold`` not in ``WEIGHT_MAP``.
    """
    config = LGB_CONFIG[locale]
    if df.empty:
        # Nothing to score; avoid predicting/reshaping a zero-row matrix.
        return pd.DataFrame(columns=["example_id", COL_NAME])

    features = _build_features(df, config["product_feat"])
    model = lgb.Booster(model_file=config["model_file"])
    pred = model.predict(features[config["features"]])
    return _ensemble_by_example(
        features["example_id"].to_numpy(),
        features["fold"].to_numpy(),
        pred,
    )