-
Notifications
You must be signed in to change notification settings - Fork 727
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
285 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,6 @@ | |
**/outputs | ||
**/data | ||
|
||
# Tests | ||
|
||
tests/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from simpletransformers.classification import ( | ||
ClassificationModel, | ||
MultiLabelClassificationModel, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"model_type, model_name", | ||
[ | ||
("bert", "bert-base-uncased"), | ||
("xlnet", "xlnet-base-cased"), | ||
("xlm", "xlm-mlm-17-1280"), | ||
("roberta", "roberta-base"), | ||
("distilbert", "distilbert-base-uncased"), | ||
("albert", "albert-base-v1"), | ||
("camembert", "camembert-base"), | ||
("xlmroberta", "xlm-roberta-base"), | ||
], | ||
) | ||
def test_binary_classification(model_type, model_name): | ||
# Train and Evaluation data needs to be in a Pandas Dataframe of two columns. | ||
# The first column is the text with type str, and the second column is the | ||
# label with type int. | ||
train_data = [ | ||
["Example sentence belonging to class 1", 1], | ||
["Example sentence belonging to class 0", 0], | ||
] | ||
train_df = pd.DataFrame(train_data) | ||
|
||
eval_data = [ | ||
["Example eval sentence belonging to class 1", 1], | ||
["Example eval sentence belonging to class 0", 0], | ||
] | ||
eval_df = pd.DataFrame(eval_data) | ||
|
||
# Create a ClassificationModel | ||
model = ClassificationModel( | ||
model_type, | ||
model_name, | ||
use_cuda=False, | ||
args={"reprocess_input_data": True, "overwrite_output_dir": True}, | ||
) | ||
|
||
# Train the model | ||
model.train_model(train_df) | ||
|
||
# Evaluate the model | ||
result, model_outputs, wrong_predictions = model.eval_model(eval_df) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"model_type, model_name", | ||
[ | ||
("bert", "bert-base-uncased"), | ||
("xlnet", "xlnet-base-cased"), | ||
("xlm", "xlm-mlm-17-1280"), | ||
("roberta", "roberta-base"), | ||
("distilbert", "distilbert-base-uncased"), | ||
("albert", "albert-base-v1"), | ||
("camembert", "camembert-base"), | ||
("xlmroberta", "xlm-roberta-base"), | ||
], | ||
) | ||
def test_multiclass_classification(model_type, model_name): | ||
# Train and Evaluation data needs to be in a Pandas Dataframe containing at | ||
# least two columns. If the Dataframe has a header, it should contain a 'text' | ||
# and a 'labels' column. If no header is present, the Dataframe should | ||
# contain at least two columns, with the first column is the text with | ||
# type str, and the second column in the label with type int. | ||
train_data = [ | ||
["Example sentence belonging to class 1", 1], | ||
["Example sentence belonging to class 0", 0], | ||
["Example eval senntence belonging to class 2", 2], | ||
] | ||
train_df = pd.DataFrame(train_data) | ||
|
||
eval_data = [ | ||
["Example eval sentence belonging to class 1", 1], | ||
["Example eval sentence belonging to class 0", 0], | ||
["Example eval senntence belonging to class 2", 2], | ||
] | ||
eval_df = pd.DataFrame(eval_data) | ||
|
||
# Create a ClassificationModel | ||
model = ClassificationModel( | ||
model_type, | ||
model_name, | ||
num_labels=3, | ||
args={"reprocess_input_data": True, "overwrite_output_dir": True}, | ||
use_cuda=False, | ||
) | ||
|
||
# Train the model | ||
model.train_model(train_df) | ||
|
||
# Evaluate the model | ||
result, model_outputs, wrong_predictions = model.eval_model(eval_df) | ||
|
||
predictions, raw_outputs = model.predict(["Some arbitary sentence"]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"model_type, model_name", | ||
[ | ||
("bert", "bert-base-uncased"), | ||
("xlnet", "xlnet-base-cased"), | ||
("xlm", "xlm-mlm-17-1280"), | ||
("roberta", "roberta-base"), | ||
("distilbert", "distilbert-base-uncased"), | ||
("albert", "albert-base-v1"), | ||
("xlmroberta", "xlm-roberta-base"), | ||
], | ||
) | ||
def test_multilabel_classification(model_type, model_name): | ||
# Train and Evaluation data needs to be in a Pandas Dataframe containing at | ||
# least two columns, a 'text' and a 'labels' column. The `labels` column | ||
# should contain multi-hot encoded lists. | ||
train_data = [ | ||
["Example sentence 1 for multilabel classification.", [1, 1, 1, 1, 0, 1]] | ||
] + [["This is another example sentence. ", [0, 1, 1, 0, 0, 0]]] | ||
train_df = pd.DataFrame(train_data, columns=["text", "labels"]) | ||
|
||
eval_data = [ | ||
["Example eval sentence for multilabel classification.", [1, 1, 1, 1, 0, 1]], | ||
["Example eval senntence belonging to class 2", [0, 1, 1, 0, 0, 0]], | ||
] | ||
eval_df = pd.DataFrame(eval_data) | ||
|
||
# Create a MultiLabelClassificationModel | ||
model = MultiLabelClassificationModel( | ||
model_type, | ||
model_name, | ||
num_labels=6, | ||
args={ | ||
"reprocess_input_data": True, | ||
"overwrite_output_dir": True, | ||
"num_train_epochs": 1, | ||
}, | ||
use_cuda=False, | ||
) | ||
|
||
# Train the model | ||
model.train_model(train_df) | ||
|
||
# Evaluate the model | ||
result, model_outputs, wrong_predictions = model.eval_model(eval_df) | ||
|
||
predictions, raw_outputs = model.predict( | ||
["This thing is entirely different from the other thing. "] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from simpletransformers.ner import NERModel | ||
import pandas as pd | ||
|
||
|
||
def test_named_entity_recognition(): | ||
# Creating train_df and eval_df for demonstration | ||
train_data = [ | ||
[0, "Simple", "B-MISC"], | ||
[0, "Transformers", "I-MISC"], | ||
[0, "started", "O"], | ||
[1, "with", "O"], | ||
[0, "text", "O"], | ||
[0, "classification", "B-MISC"], | ||
[1, "Simple", "B-MISC"], | ||
[1, "Transformers", "I-MISC"], | ||
[1, "can", "O"], | ||
[1, "now", "O"], | ||
[1, "perform", "O"], | ||
[1, "NER", "B-MISC"], | ||
] | ||
train_df = pd.DataFrame(train_data, columns=["sentence_id", "words", "labels"]) | ||
|
||
eval_data = [ | ||
[0, "Simple", "B-MISC"], | ||
[0, "Transformers", "I-MISC"], | ||
[0, "was", "O"], | ||
[1, "built", "O"], | ||
[1, "for", "O"], | ||
[0, "text", "O"], | ||
[0, "classification", "B-MISC"], | ||
[1, "Simple", "B-MISC"], | ||
[1, "Transformers", "I-MISC"], | ||
[1, "then", "O"], | ||
[1, "expanded", "O"], | ||
[1, "to", "O"], | ||
[1, "perform", "O"], | ||
[1, "NER", "B-MISC"], | ||
] | ||
eval_df = pd.DataFrame(eval_data, columns=["sentence_id", "words", "labels"]) | ||
|
||
# Create a NERModel | ||
model = NERModel( | ||
"bert", | ||
"bert-base-cased", | ||
args={"overwrite_output_dir": True, "reprocess_input_data": False}, | ||
use_cuda=False, | ||
) | ||
|
||
# Train the model | ||
model.train_model(train_df) | ||
|
||
# Evaluate the model | ||
result, model_outputs, predictions = model.eval_model(eval_df) | ||
|
||
# Predictions on arbitary text strings | ||
predictions, raw_outputs = model.predict(["Some arbitary sentence"]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from simpletransformers.question_answering import QuestionAnsweringModel | ||
import json | ||
import os | ||
|
||
|
||
def test_question_answering(): | ||
# Create dummy data to use for training. | ||
train_data = [ | ||
{ | ||
"context": "This is the first context", | ||
"qas": [ | ||
{ | ||
"id": "00001", | ||
"is_impossible": False, | ||
"question": "Which context is this?", | ||
"answers": [{"text": "the first", "answer_start": 8}], | ||
} | ||
], | ||
}, | ||
{ | ||
"context": "Other legislation followed, including the Migratory Bird" | ||
" Conservation Act of 1929, a 1937 treaty prohibiting the hunting of" | ||
" right and gray whales, and the Bald Eagle Protection Act of 1940." | ||
" These later laws had a low cost to society—the species were" | ||
" relatively rare—and little opposition was raised", | ||
"qas": [ | ||
{ | ||
"id": "00002", | ||
"is_impossible": False, | ||
"question": "What was the cost to society?", | ||
"answers": [{"text": "low cost", "answer_start": 225}], | ||
}, | ||
{ | ||
"id": "00003", | ||
"is_impossible": False, | ||
"question": "What was the name of the 1937 treaty?", | ||
"answers": [ | ||
{"text": "Bald Eagle Protection Act", "answer_start": 167} | ||
], | ||
}, | ||
], | ||
}, | ||
] | ||
|
||
# Save as a JSON file | ||
os.makedirs("data", exist_ok=True) | ||
with open("data/train.json", "w") as f: | ||
json.dump(train_data, f) | ||
|
||
# Create the QuestionAnsweringModel | ||
model = QuestionAnsweringModel( | ||
"distilbert", | ||
"distilbert-base-uncased-distilled-squad", | ||
args={"reprocess_input_data": True, "overwrite_output_dir": True}, | ||
use_cuda=False | ||
) | ||
|
||
# Train the model | ||
model.train_model("data/train.json") | ||
|
||
# Evaluate the model. (Being lazy and evaluating on the train data itself) | ||
result, text = model.eval_model("data/train.json") | ||
|
||
# Making predictions using the model. | ||
to_predict = [ | ||
{ | ||
"context": "This is the context used for demonstrating predictions.", | ||
"qas": [{"question": "What is this context?", "id": "0"}], | ||
} | ||
] | ||
|
||
model.predict(to_predict) |