Commit: Add Kaggle normalization example (autogluon#2129)

* data preprocess with normalization
* add training script
* add kaggle submission script
* add README
* add examples of normalization to README

4 changed files with 387 additions and 0 deletions.

# Using MultiModalPredictor with Text Normalization

Take [Feedback Prize - Predicting Effective Arguments](https://www.kaggle.com/competitions/feedback-prize-effectiveness) as an example to show how text normalization can be helpful.

## 1. Preprocess data with text normalization

Text normalization is the task of mapping non-canonical language, typical of speech transcription and computer-mediated communication, to a standardized written form.
It is an upstream task necessary to enable the subsequent direct use of standard natural language processing tools, and it is indispensable for languages such as Swiss German,
which have strong regional variation and no written standard.
Even though the competition dataset is English only, we found that applying text normalization can reduce `log loss`, the metric this competition is evaluated on.

### 1.1 Define error handlers for codecs

```python
import codecs
from typing import Tuple


def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
    # On an encoding error, re-encode the offending span as UTF-8 bytes.
    return error.object[error.start : error.end].encode("utf-8"), error.end


def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
    # On a decoding error, fall back to interpreting the offending bytes as cp1252.
    return error.object[error.start : error.end].decode("cp1252"), error.end
```

### 1.2 Register error handlers for codecs

```python
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)
```

### 1.3 Apply a series of decodings and encodings for normalization

```python
from text_unidecode import unidecode


def resolve_encodings_and_normalize(text: str) -> str:
    text = (
        text.encode("raw_unicode_escape")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
        .encode("cp1252", errors="replace_encoding_with_utf8")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
    )
    text = unidecode(text)
    return text
```
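
As a quick illustration, a mojibake string round-trips to plain ASCII (the sample input below is made up for illustration, not taken from the competition data):

```python
# "\x93", "\x94" and "\xe9" are cp1252 bytes (smart quotes and an accented "e")
# that leaked into the text verbatim.
sample = "\x93Hello\x94 caf\xe9"
print(resolve_encodings_and_normalize(sample))
# '"Hello" cafe'
```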

### 1.4 Apply normalization to the feature columns

```python
import os

import pandas as pd


def read_and_process_data_with_norm(path: str, file: str, is_train: bool) -> pd.DataFrame:
    df = pd.read_csv(os.path.join(path, file))
    df["discourse_text"] = df["discourse_text"].apply(resolve_encodings_and_normalize)
    return df
```

### 1.5 A few examples of normalized texts

```python
# Example-1 pre-normalization
'The same technology can make computer-animated faces more expressive\x97for video games or video surgery. \x93Most human communication is nonverbal, including emotional communication,\x94 notes Dr. Huang. \x93So computers need to understand that, too.\x94Eckman has classified six basic emotions\x97happiness, surprise, anger, disgust, fear, and sadness\x97and then associated each with characteristic movements of the facial muscles. For example, your frontalis pars lateralis muscle (above your eyes) raises your eyebrows when you\x92re surprised; your orbicularis oris (around your mouth) tightens your lips to show anger. '

# Example-1 post-normalization
'The same technology can make computer-animated faces more expressive--for video games or video surgery. "Most human communication is nonverbal, including emotional communication," notes Dr. Huang. "So computers need to understand that, too."Eckman has classified six basic emotions--happiness, surprise, anger, disgust, fear, and sadness--and then associated each with characteristic movements of the facial muscles. For example, your frontalis pars lateralis muscle (above your eyes) raises your eyebrows when you\'re surprised; your orbicularis oris (around your mouth) tightens your lips to show anger. '

# Example-2 pre-normalization
'"Congestion was down\xa0 60\xa0 percent\xa0 in\xa0 the\xa0 capital\xa0 of\xa0 france\xa0 after\xa0 fivedays\xa0 of\xa0 intensifying\xa0 smog."smog\xa0 by\xa0 meaning\xa0 pollution\xa0 went\xa0 down\xa0 just\xa0 60\xa0 percent\xa0 in\xa0 five\xa0 days. Thats\xa0 a\xa0 great\xa0 adavantage\xa0 just\xa0 by\xa0 limting\xa0 car\xa0 usuage. In\xa0 source\xa0 number 1\xa0 explains that " Passenger\xa0 cars\xa0 are\xa0 responsible\xa0 for\xa0 12\xa0 percent\xa0 of\xa0 greenouse\xa0 gas emissions\xa0 in\xa0 Europe.. and up\xa0 to\xa0 50\xa0 perecnt\xa0 in\xa0 some\xa0 car-intensive\xa0 areas\xa0 in\xa0 the\xa0 Untied States." We\xa0 as\xa0 a\xa0 country\xa0 shoud\xa0 lower\xa0 that\xa0 and\xa0 the\xa0 best\xa0 way\xa0 is\xa0 to\xa0 limting\xa0 car usuage . Limting car usage is one of te some advantges to lowering pollution ( greenhouse gas, smog). '

# Example-2 post-normalization
'"Congestion was down 60 percent in the capital of france after fivedays of intensifying smog."smog by meaning pollution went down just 60 percent in five days. Thats a great adavantage just by limting car usuage. In source number 1 explains that " Passenger cars are responsible for 12 percent of greenouse gas emissions in Europe.. and up to 50 perecnt in some car-intensive areas in the Untied States." We as a country shoud lower that and the best way is to limting car usuage . Limting car usage is one of te some advantges to lowering pollution ( greenhouse gas, smog). '
```

For details, please refer to
[`kaggle_feedback_prize_preprocess.py`](./kaggle_feedback_prize_preprocess.py).

## 2. MultiModalPredictor for Training

MultiModalPredictor can automatically build deep learning models for multimodal datasets.
The tabular data in this Kaggle competition is a perfect example of how easily we can build models with just a few lines of code using MultiModalPredictor.
For details, please refer to [`kaggle_feedback_prize_train.py`](./kaggle_feedback_prize_train.py).

### 2.1 Build the MultiModalPredictor

You can build the predictor as follows.

```python
from autogluon.multimodal import MultiModalPredictor

predictor = MultiModalPredictor(
    label="discourse_effectiveness",
    problem_type="multiclass",
    eval_metric="log_loss",
    path=save_path,
    verbosity=3,
)
```

- `label` indicates the target column in the training data.
- `problem_type` indicates the type of the problem. It can be "multiclass", "binary" or "regression".
- `eval_metric` indicates the evaluation metric of the model, which should match the competition's evaluation metric.
- `path` indicates the path where the MultiModalPredictor models are saved.
- `verbosity` controls how much information is printed.

### 2.2 Train the MultiModalPredictor

Then, you can train the MultiModalPredictor with `.fit()`.

```python
predictor.fit(
    train_data=train_df,
    tuning_data=val_df,
    presets="best_quality",
    hyperparameters={
        "model.hf_text.checkpoint_name": "microsoft/deberta-v3-large",
        "optimization.learning_rate": 5e-5,
        "optimization.max_epochs": 7,
    },
)
```

- `train_data` is the data used for training.
- `tuning_data` is the data used for validation. If it is empty, the tuning data will be split from the training data automatically.
- `presets` sets a number of parameters depending on the model quality one prefers. For details, please refer to the [presets section](https://auto.gluon.ai/stable/tutorials/tabular_prediction/tabular-quickstart.html#presets).
- `hyperparameters` is a dict that overrides the default training configs. The configs contain five different types; the two used here are described below (see the sketch after this list).
  - `model` contains the parameters that control the models used in the predictor. You can select the model you need and adjust the details. By default, the models are selected automatically based on the dataset.
  - `optimization` contains the configs for the optimization process, including but not limited to max training epochs, learning rate, and warm-up.
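
Other config groups can be overridden through the same dotted keys. A minimal sketch, assuming `env.per_gpu_batch_size_evaluation` can be addressed this way (it mirrors the `_config.env` field adjusted in [`kaggle_feedback_prize_submit.py`](./kaggle_feedback_prize_submit.py)):

```python
# Sketch of a larger override dict; the "env" key below is an assumption.
hyperparameters = {
    "model.hf_text.checkpoint_name": "roberta-large",  # swap the text backbone
    "optimization.learning_rate": 5e-5,
    "optimization.max_epochs": 7,
    "env.per_gpu_batch_size_evaluation": 2,  # shrink eval batches to fit GPU memory
}
predictor.fit(train_data=train_df, tuning_data=val_df, hyperparameters=hyperparameters)
```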

### 2.3 Save Standalone Model

Models should be saved for offline deployment in Kaggle competitions and uploaded to Kaggle as `datasets` after training is done. You can instruct the MultiModalPredictor to save a "standalone" model that can be loaded without internet access.

```python
predictor.save(path=save_standalone_path, standalone=True)
```

## 3. Kaggle Kernel-only Competition with AutoGluon

In a Kaggle competition, especially a code competition, users cannot obtain AutoGluon resources through the network.
To solve this problem, there are two key points:

- Loading AutoGluon and its related libraries through datasets.
- Using standalone models to avoid downloading models at submission time.

AutoGluon and its dependencies are currently packaged in a zip file, available for download as data in this [Kaggle notebook](https://www.kaggle.com/code/linuxdex/get-autogluon-standalone/data).
You can download `autogluon_standalone.zip`, unzip it, and upload the folder as a [Kaggle Dataset](https://www.kaggle.com/datasets).

Use the following code to install AutoGluon without network access in a Kaggle notebook.

```python
import sys

sys.path.append("../input/autogluon-standalone/antlr4-python3-runtime-4.8/antlr4-python3-runtime-4.8/src/")
!pip install --no-deps --no-index --quiet ../input/autogluon-standalone/autogluon_standalone/*.whl
```
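
A quick sanity check that the offline install succeeded (a minimal sketch; the import should complete without any downloads):

```python
from autogluon.multimodal import MultiModalPredictor

print(MultiModalPredictor)  # import works without network access
```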

Using the saved standalone model avoids downloading models during submission. You can refer to [Section 2.3](#23-save-standalone-model) to save the standalone model.

## 4. Prediction in Kaggle Competitions

Next, let's upload the predictor to Kaggle and use it to generate probabilities on the test set. You can upload the MultiModalPredictor standalone models as datasets to Kaggle directly from a notebook,
or via the [Kaggle API](https://www.kaggle.com/docs/api#interacting-with-datasets).
Make sure that the models are present under the `Input` section of your notebook.

You can then load the MultiModalPredictor using the following code.

```python
pretrained_model = MultiModalPredictor.load(path=save_standalone_path)
```

You can upload the [preprocessing script](./kaggle_feedback_prize_preprocess.py) to Kaggle following these [instructions](https://www.kaggle.com/product-feedback/91185) or paste it directly into a notebook code block.

Preprocess the test data with text normalization.

```python
test_df = kaggle_feedback_prize_preprocess.read_and_process_data_with_norm(data_path, "test.csv", is_train=False)
```

With `.predict_proba()`, you can get the probabilities of all classes.

```python
test_pred = pretrained_model.predict_proba(test_df)
```
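
`test_pred` is a pandas DataFrame with one probability column per class, so it can be written out directly as the submission file, as done at the end of the submission script:

```python
test_pred.to_csv("submission.csv", index=False)
```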

For the complete code, please refer to [`kaggle_feedback_prize_submit.py`](./kaggle_feedback_prize_submit.py).

## 5. Benchmarking model performance with text normalization

We benchmarked the effect of text normalization across different models and hyperparameters.
For model evaluation, we held out a fixed, stratified 20% sample of the training data, and we also submitted a couple of large models to the Kaggle competition for leaderboard scores.
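
The stratified 20% holdout can be reproduced with a standard split; a minimal sketch (the exact split code used for the benchmark is not shown in this example):

```python
from sklearn.model_selection import train_test_split

# Hold out a fixed, stratified 20% of the training data for evaluation.
train_part, eval_part = train_test_split(
    train_df,
    test_size=0.2,
    stratify=train_df["discourse_effectiveness"],
    random_state=42,
)
```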

| model | lr | lr_decay | cv_k | normalized_text | local_log_loss | kaggle_private | kaggle_public |
| --- | --- | --- | --- | --- | --- | --- | --- |
| microsoft/deberta-v3-base | 5e-5 | 0.9 | 3 | Y | 0.5692 | | |
| microsoft/deberta-v3-base | 5e-5 | 0.9 | 3 | N | 0.5835 | | |
| microsoft/deberta-v3-base | 5e-5 | 0.9 | 5 | Y | 0.5694 | | |
| microsoft/deberta-v3-base | 5e-5 | 0.9 | 5 | N | 0.5750 | | |
| microsoft/deberta-v3-large | 5e-5 | 0.9 | 3 | Y | 0.5848 | | |
| microsoft/deberta-v3-large | 5e-5 | 0.9 | 3 | N | 0.5779 | | |
| microsoft/deberta-v3-large | 5e-5 | 0.9 | 5 | Y | 0.5552 | 0.621 | 0.6267 |
| microsoft/deberta-v3-large | 5e-5 | 0.9 | 5 | N | 0.5703 | 0.6228 | 0.6296 |
| roberta-base | 5e-5 | 0.9 | 3 | Y | 0.5969 | | |
| roberta-base | 5e-5 | 0.9 | 3 | N | 0.5944 | | |
| roberta-base | 5e-5 | 0.9 | 5 | Y | 0.5741 | | |
| roberta-base | 5e-5 | 0.9 | 5 | N | 0.5781 | | |
| roberta-large | 5e-5 | 0.9 | 3 | Y | 0.5739 | | |
| roberta-large | 5e-5 | 0.9 | 3 | N | 0.5850 | | |
| roberta-large | 5e-5 | 0.9 | 5 | Y | 0.5635 | 0.6419 | 0.6399 |
| roberta-large | 5e-5 | 0.9 | 5 | N | 0.5657 | 0.6439 | 0.6404 |

The results of the benchmark are shown in the table above. It is evident that text normalization is effective in the majority of cases.

`examples/automm/kaggle_feedback_prize/kaggle_feedback_prize_preprocess.py` (52 additions)

```python
import codecs
import os
from typing import Tuple

import pandas as pd
from text_unidecode import unidecode


def get_essay(essay_id: str, input_dir: str, is_train: bool = True) -> str:
    parent_path = input_dir + "train" if is_train else input_dir + "test"
    essay_path = os.path.join(parent_path, f"{essay_id}.txt")
    with open(essay_path, "r") as f:
        essay_text = f.read()
    return essay_text


def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
    return error.object[error.start : error.end].encode("utf-8"), error.end


def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
    return error.object[error.start : error.end].decode("cp1252"), error.end


# Register the encoding and decoding error handlers for `utf-8` and `cp1252`.
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)


def resolve_encodings_and_normalize(text: str) -> str:
    """Resolve the encoding problems and normalize the abnormal characters."""
    text = (
        text.encode("raw_unicode_escape")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
        .encode("cp1252", errors="replace_encoding_with_utf8")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
    )
    text = unidecode(text)
    return text


def read_and_process_data(path: str, file: str, is_train: bool) -> pd.DataFrame:
    df = pd.read_csv(os.path.join(path, file))
    df["essay_text"] = df["essay_id"].apply(lambda x: get_essay(x, path, is_train=is_train))
    return df


def read_and_process_data_with_norm(path: str, file: str, is_train: bool) -> pd.DataFrame:
    df = pd.read_csv(os.path.join(path, file))
    df["essay_text"] = df["essay_id"].apply(lambda x: get_essay(x, path, is_train=is_train))
    df["discourse_text"] = df["discourse_text"].apply(resolve_encodings_and_normalize)
    df["essay_text"] = df["essay_text"].apply(resolve_encodings_and_normalize)
    return df
```

`examples/automm/kaggle_feedback_prize/kaggle_feedback_prize_submit.py` (57 additions)

```python
import gc
import sys
import warnings

import kaggle_feedback_prize_preprocess
import pandas as pd
import torch

from autogluon.multimodal import MultiModalPredictor

warnings.filterwarnings("ignore")
sys.path.append("../input/autogluon-standalone/antlr4-python3-runtime-4.8/antlr4-python3-runtime-4.8/src/")
# Notebook-only line: install AutoGluon offline when running in a Kaggle notebook cell.
!pip install --no-deps --no-index --quiet ../input/autogluon-standalone/*.whl


data_path = "../input/feedback-prize-effectiveness/"

config_1 = {
    "save_path": "../input/feedback_microsoft-deberta-v3-large/microsoft-deberta-v3-large-cv5-lr-5e-05-mepoch-7",
    "per_gpu_batch_size_evaluation": 2,
    "N_fold": 5,
}
config_2 = {
    "save_path": "../input/roberta-large/roberta-large-cv5-lr-5e-05-mepoch-7",
    "per_gpu_batch_size_evaluation": 2,
    "N_fold": 5,
}


if __name__ == "__main__":
    test_df = kaggle_feedback_prize_preprocess.read_and_process_data_with_norm(data_path, "test.csv", is_train=False)

    configs = [config_1, config_2]
    weights = [0.6, 0.4]

    all_proba = []
    for config in configs:
        print(config)
        model_proba = []
        for fold in range(config["N_fold"]):
            pretrained_model = MultiModalPredictor.load(path=config["save_path"] + f"_{fold}")
            pretrained_model._config.env.per_gpu_batch_size_evaluation = config["per_gpu_batch_size_evaluation"]
            test_proba = pretrained_model.predict_proba(test_df)
            model_proba.append(test_proba)

            # free up GPU and CPU memory before loading the next fold's model
            del pretrained_model
            torch.cuda.empty_cache()
            gc.collect()

        # Average the per-fold probabilities row by row.
        proba_concat = pd.concat(model_proba)
        mean_proba = proba_concat.groupby(level=0).mean()
        all_proba.append(mean_proba)

    # Weighted ensemble of the two models' fold-averaged probabilities.
    result = sum([all_proba[i] * weights[i] for i in range(len(configs))])
    result.to_csv("submission.csv", index=False)
    print(result)
```
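
In the script above, `pd.concat(model_proba).groupby(level=0).mean()` averages the per-fold probability frames row by row, since every fold's frame shares the same integer index. A minimal illustration:

```python
import pandas as pd

a = pd.DataFrame({"p": [0.2, 0.8]})
b = pd.DataFrame({"p": [0.4, 0.6]})
print(pd.concat([a, b]).groupby(level=0).mean())
#      p
# 0  0.3
# 1  0.7
```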

`examples/automm/kaggle_feedback_prize/kaggle_feedback_prize_train.py` (102 additions)

```python
import argparse
import random

import numpy as np
import pandas as pd
import torch as th
from kaggle_feedback_prize_preprocess import read_and_process_data
from sklearn.model_selection import StratifiedKFold

from autogluon.multimodal import MultiModalPredictor


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="The example for Kaggle competition Feedback Prize - Predicting Effective Arguments."
    )
    parser.add_argument("--data_path", type=str, help="The path of the competition dataset.", default="./data/")
    parser.add_argument("--model_path", type=str, help="The path of the model artifacts.", default="./model/")
    parser.add_argument(
        "--label_name", type=str, help="The column name of the predictive label.", default="discourse_effectiveness"
    )
    parser.add_argument("--problem_type", type=str, help="The problem type.", default="multiclass")
    parser.add_argument("--eval_metric", type=str, help="The evaluation metric.", default="log_loss")
    parser.add_argument("--learning_rate", type=float, help="The learning rate for training.", default=5e-5)
    parser.add_argument("--max_epochs", type=int, help="The max training epochs.", default=7)
    parser.add_argument(
        "--text_backbone", type=str, help="Pretrained backbone for finetuning.", default="microsoft/deberta-v3-large"
    )
    parser.add_argument("--folds", type=int, help="The number of cross-validation folds.", default=5)
    parser.add_argument("--seed", type=int, help="The random seed.", default=42)
    args = parser.parse_args()

    backbone_model = args.text_backbone.replace("/", "-")
    args.save_path = args.model_path + "feedback-{}/{}-cv{}-lr-{}-mepoch-{}".format(
        backbone_model,
        backbone_model,
        args.folds,
        args.learning_rate,
        args.max_epochs,
    )
    return args


def get_hparams(args: argparse.Namespace) -> dict:
    hparams = {
        "model.hf_text.checkpoint_name": args.text_backbone,
        "optimization.learning_rate": args.learning_rate,
        "optimization.max_epochs": args.max_epochs,
    }

    return hparams


def set_seed(seed: int) -> None:
    th.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def train(
    train_df: pd.DataFrame, val_df: pd.DataFrame, args: argparse.Namespace, path: str
) -> MultiModalPredictor:
    hparams = get_hparams(args)

    predictor = MultiModalPredictor(
        label=args.label_name,
        problem_type=args.problem_type,
        eval_metric=args.eval_metric,
        path=path,
        verbosity=3,
    ).fit(
        train_data=train_df,
        tuning_data=val_df,
        presets="best_quality",
        hyperparameters=hparams,
    )

    return predictor


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)

    train_df = read_and_process_data(args.data_path, "train.csv", is_train=True)

    y_train = train_df[args.label_name]
    X_train = train_df.drop(args.label_name, axis=1)

    # K-fold cross-validation: train and save one standalone predictor per fold.
    skf = StratifiedKFold(n_splits=args.folds, shuffle=True)
    for i, (train_idx, val_idx) in enumerate(skf.split(X_train, y_train)):
        X_t, X_v = X_train.iloc[train_idx], X_train.iloc[val_idx]
        y_t, y_v = y_train.iloc[train_idx], y_train.iloc[val_idx]

        train_df = pd.concat([X_t, y_t], axis=1)
        val_df = pd.concat([X_v, y_v], axis=1)
        path = args.save_path + f"_{i}"

        predictor = train(train_df, val_df, args, path)
        predictor.save(path, standalone=True)
```
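
The training script can then be launched from the command line; for instance (paths are placeholders, flags as defined in `get_args`):

```
python kaggle_feedback_prize_train.py \
    --data_path ./data/ \
    --text_backbone microsoft/deberta-v3-large \
    --folds 5 \
    --learning_rate 5e-5 \
    --max_epochs 7
```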