Commit 44b80e5: finish train

1710763616 committed Oct 18, 2022
1 parent 1f15022 commit 44b80e5
Showing 9 changed files with 417 additions and 42 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -125,7 +125,7 @@ and each would result in a tiny table (with polarity False as an example):
We collect training data based on two public sentence-pair datasets, MultiNLI [(Williams et al., 2018)](https://doi.org/10.18653/v1/n18-1101) and STS-B [(Cer et al., 2017)](http://arxiv.org/abs/1708.00055), in which each sample is comprised of a premise and a hypothesis. We perform counterfactual data augmentation (CDA) ([Zhao et al., 2018b)](https://arxiv.org/abs/1804.06876) on the sentences in MultiNLI and STS-B to construct a training set. Datasets you can download from the above link include `train.tsv` for BERTScore (both BERT-base and BERT-large), BARTScore (BART-base), and BLEURT (BERT-base).
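For intuition only, CDA replaces gendered terms in a sentence with their counterparts to produce a counterfactual copy. The sketch below is a simplified illustration under that assumption, not the exact pipeline used to build the released `train.tsv`:

```python
# Simplified CDA sketch (illustrative only): swap a few gendered word pairs
# to create a counterfactual copy of each training sentence.
GENDER_PAIRS = {"he": "she", "she": "he", "man": "woman", "woman": "man"}

def counterfactual(sentence):
    # Swap each token that has a listed counterpart; leave everything else unchanged.
    return " ".join(GENDER_PAIRS.get(tok.lower(), tok) for tok in sentence.split())

print(counterfactual("he thanked the woman"))  # -> "she thanked the man"
```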

### Train
The following example add and train a debias adapter in the BERT-large of BERTScore. A single 24GB GPU (RTX 3090) is used for the example so we recommend you to use similar or better equipments. Please note that you should download the corresponding [dataset](https://drive.google.com/drive/folders/1rqPw_h6_0CxgL4LY2LhBPMR2RODnMnv6?usp=sharing) described above first.
The following example adds and trains a debias adapter in the BERT-large model of BERTScore. A single 24GB GPU (RTX 3090) is used for the example, so we recommend using similar or better hardware. Please note that you should first download the corresponding [dataset](https://drive.google.com/drive/folders/1rqPw_h6_0CxgL4LY2LhBPMR2RODnMnv6?usp=sharing) described above.

```bash
cd Metric-Fairness/mitigating_bias/train/BERTScore
@@ -145,7 +145,7 @@ python train_BERTScore.py
--data_path ${INPUT_PATH}
```

When training finished, a debias adapter will be saved in `./adapter/`, and you can check more training details in `./logs` see [fitlog](https://fitlog.readthedocs.io/zh/latest/)
When training finishes, a debias adapter will be saved in `./adapter/`, and you can find more training details in `./logs`. See [fitlog](https://fitlog.readthedocs.io/zh/latest/) for how to inspect these logs.
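For reference, the saved adapter can later be re-attached to the backbone with adapter-transformers. This is only a hedged sketch: the checkpoint name `bert-large-uncased` is an assumption, and the repository's own test scripts may load the adapter differently.

```python
# Hedged sketch: re-attach the trained debias adapter at inference time.
# Assumes adapter_transformers==3.1.0; "bert-large-uncased" is an assumed backbone.
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-large-uncased")
adapter_name = model.load_adapter("./adapter")  # directory written by the training script
model.set_active_adapters(adapter_name)         # route the forward pass through the adapter
model.eval()
```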

### Test

37 changes: 37 additions & 0 deletions mitigating_bias/train/BARTScore/dataloader.py
@@ -0,0 +1,37 @@
from transformers import AutoTokenizer
from fastNLP import DataSet, Instance
from fastNLP.io import Loader, DataBundle


class DataLoader(Loader):
def __init__(self, max_seq_len=150):
super().__init__()
        self.max_seq_len = max_seq_len

def _load(self, path: str) -> DataSet:
print('Loading {}...'.format(path))
        total_samples, debias_samples, distillation_samples = 0, 0, 0
ds = DataSet()
with open(path, 'r') as fin:
lines = fin.readlines()
            for l in lines:
                if not l.strip():
                    continue  # skip blank lines such as a trailing newline at EOF
                items = l.split('\t')
refs = ' '.join(items[0].strip().split(' ')[:self.max_seq_len])
hyps = ' '.join(items[1].strip().split(' ')[:self.max_seq_len])
sample = {
'refs': refs,
'hyps': hyps,
'labels': float(items[2]),
                    'type': items[3].strip(),  # strip the trailing '\n' so the type check below matches
}
ds.append(Instance(**sample))
# statistics
                total_samples += 1
if sample['type'] == 'debias':
debias_samples += 1
else:
distillation_samples += 1

        print('Loaded {} samples ({} debias, {} distillation)'.format(
            total_samples, debias_samples, distillation_samples))
        ds.set_input("refs", "hyps", "labels")
ds.set_target("labels")
return ds
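Judging from `_load` above, each line of the training file carries four tab-separated fields: reference text, hypothesis text, a target score, and a sample type (`debias` or `distillation`). A hedged usage sketch, with a made-up sentence pair and score:

```python
# Hedged usage sketch for the loader above; the sentences and score are made up.
with open("toy_train.tsv", "w") as f:
    f.write("the doctor greeted her patient\tthe doctor greeted his patient\t0.87\tdebias\n")

bundle = DataLoader(max_seq_len=150).load({"train": "toy_train.tsv"})
print(bundle.get_dataset("train")[0])  # Instance with refs, hyps, labels, type
```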
6 changes: 6 additions & 0 deletions mitigating_bias/train/BARTScore/requirements.txt
@@ -0,0 +1,6 @@
adapter_transformers==3.1.0
fastNLP==1.0.0
fitlog==0.9.13
numpy==1.20.3
torch==1.12.1+cu116
transformers==4.23.1
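A plain `pip` line is usually enough to set this environment up; note that the `+cu116` torch build is served from the PyTorch wheel index rather than PyPI, so the extra index URL below is an assumption about how you install it:

```bash
pip install -r mitigating_bias/train/BARTScore/requirements.txt \
    --extra-index-url https://download.pytorch.org/whl/cu116
```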
180 changes: 180 additions & 0 deletions mitigating_bias/train/BARTScore/train_BARTScore.py
@@ -0,0 +1,180 @@
import argparse
import random

import fitlog
import numpy as np
import torch
import torch.nn as nn
from dataloader import DataLoader
from fastNLP import (AccuracyMetric, ClassifyFPreRecMetric, DataSet,
FitlogCallback, GradientClipCallback, Instance,
LossInForward, RandomSampler, Tester, Trainer,
WarmupCallback, cache_results)
from transformers import AdamW, BartForConditionalGeneration, BartTokenizer


def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)


def parse_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_type", default='facebook/bart-base', type=str, required=False)
parser.add_argument(
"--adapter_name", default='debiased-bartscore', type=str, required=False)
parser.add_argument("--lr", default=1e-3, type=float, required=False)
parser.add_argument("--warmup", default=0.0, type=float, required=False)
parser.add_argument("--batch_size", default=32, type=int, required=False)
parser.add_argument("--n_epochs", default=4, type=int, required=False)
parser.add_argument("--seed", default=42, type=int, required=False)
parser.add_argument("--device", default='cuda:0', type=str, required=False)
parser.add_argument("--logging_steps", default=100,
type=int, required=False)
parser.add_argument("--bart_batch_size", default=8,
type=int, required=False)
parser.add_argument("--max_length", default=1024, type=int, required=False)
parser.add_argument(
"--data_path", default='train.tsv', type=str, required=False)
return parser.parse_args()


class BARTScore(torch.nn.Module):
def __init__(self, args):
super(BARTScore, self).__init__()

self.tokenizer = BartTokenizer.from_pretrained(args.model_type)
self.model = BartForConditionalGeneration.from_pretrained(
args.model_type)
# print(self.model)
# print(type(self.model))
self.model.add_adapter(args.adapter_name)
# add adapter and freeze other parameters
self.model.train_adapter(args.adapter_name)
self.model.to(args.device)

self.loss_fct = nn.NLLLoss(
reduction='none', ignore_index=self.model.config.pad_token_id)
self.lsm = nn.LogSoftmax(dim=1)
self.batch_size = args.bart_batch_size
self.device = args.device
self.max_length = args.max_length

def save_adapter(self, adapter_name):
self.model.save_adapter('./adapter', adapter_name)

def get_bart_score(self, src, tgt):
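        # BARTScore for each (src, tgt) pair: average log-likelihood of the target
        # tokens given the source, negated back into a "higher is better" score.
        # Pairs are processed in chunks of bart_batch_size to bound GPU memory.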
for i in range(0, len(src), self.batch_size):
src_list = src[i: i + self.batch_size]
tgt_list = tgt[i: i + self.batch_size]

encoded_src = self.tokenizer(
src_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt'
)
encoded_tgt = self.tokenizer(
tgt_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt'
)
src_tokens = encoded_src['input_ids'].to(self.device)
src_mask = encoded_src['attention_mask'].to(self.device)
tgt_tokens = encoded_tgt['input_ids'].to(self.device)
tgt_mask = encoded_tgt['attention_mask'].to(self.device)
tgt_len = tgt_mask.sum(dim=1).to(self.device)
output = self.model(
input_ids=src_tokens,
attention_mask=src_mask,
labels=tgt_tokens
)
logits = output.logits.view(-1, self.model.config.vocab_size)
loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1))
loss = loss.view(tgt_tokens.shape[0], -1)
loss = loss.sum(dim=1) / tgt_len
if i == 0:
score = -loss
else:
score = torch.cat((score, -loss), 0)
return score

def forward(self, refs, hyps, labels):
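        # Score both directions: r uses hyps as source and refs as target (recall),
        # p uses refs as source and hyps as target (precision). Their mean F is
        # regressed onto the debiased/distilled target labels with an MSE loss.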
refs = refs.tolist()
hyps = hyps.tolist()
r = self.get_bart_score(hyps, refs)
p = self.get_bart_score(refs, hyps)
f = (r + p) / 2
loss_func = torch.nn.MSELoss()
loss = loss_func(f, labels)
return {
'p': p,
'r': r,
'f': f,
'loss': loss,
}


if __name__ == '__main__':
args = parse_args()
set_seed(args)

# static hyperparams
args.all_layers = False
args.lang = 'en'
args.verbose = False
args.adapter_name = args.model_type + args.adapter_name

log_dir = './logs'
fitlog.set_log_dir(log_dir)
# fitlog.commit(__file__)
fitlog.add_hyper(args)
fitlog.add_hyper_in_file(__file__)

model = BARTScore(args)

@cache_results('cached_data.bin', _refresh=False)
def get_data(path):
paths = {
'train': path,
}
data_bundle = DataLoader().load(paths)
return data_bundle

# load dataset
data_bundle = get_data(path=args.data_path)
train_data = data_bundle.get_dataset('train')
print('# samples: {}'.format(len(train_data)))
print('Example:')
print(train_data[0])

parameters = []
# print('Trainable params:')
for name, param in model.named_parameters():
if param.requires_grad:
parameters.append(param)
# print('{}: {}'.format(name, param.shape))
optimizer = AdamW(parameters, lr=args.lr)

callbacks = []
callbacks.append(GradientClipCallback(clip_value=1, clip_type='norm'))
callbacks.append(FitlogCallback(log_loss_every=args.logging_steps))
if args.warmup > 0:
callbacks.append(WarmupCallback(warmup=args.warmup, schedule='linear'))
trainer = Trainer(train_data=train_data, model=model, loss=LossInForward(), optimizer=optimizer,
batch_size=args.batch_size, sampler=RandomSampler(), drop_last=False, update_every=1,
num_workers=4, n_epochs=args.n_epochs, print_every=50, dev_data=None, metrics=None,
validate_every=args.logging_steps, save_path=None, use_tqdm=False, device=args.device,
callbacks=callbacks, dev_batch_size=None, metric_key=None)
trainer.train(load_best_model=False)
model.save_adapter(args.adapter_name)
fitlog.finish()
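Mirroring the BERTScore example in the README, a BARTScore training run can be launched as below. Every flag comes from `parse_args()` above (the main defaults are shown explicitly); `${INPUT_PATH}` stands for wherever you put the downloaded `train.tsv`:

```bash
cd Metric-Fairness/mitigating_bias/train/BARTScore
python train_BARTScore.py \
    --model_type facebook/bart-base \
    --lr 1e-3 \
    --batch_size 32 \
    --n_epochs 4 \
    --device cuda:0 \
    --data_path ${INPUT_PATH}
```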
40 changes: 0 additions & 40 deletions mitigating_bias/train/BERTScore/adapter/adapter_config.json

This file was deleted.

Binary file not shown.
37 changes: 37 additions & 0 deletions mitigating_bias/train/BLEURT/dataloader.py
@@ -0,0 +1,37 @@
from transformers import AutoTokenizer
from fastNLP import DataSet, Instance
from fastNLP.io import Loader, DataBundle
import os

class DataLoader(Loader):
def __init__(self, max_seq_len=150):
super().__init__()
        self.max_seq_len = max_seq_len

def _load(self, path: str) -> DataSet:
print('Loading {}...'.format(path))
        total_samples, debias_samples, distillation_samples = 0, 0, 0
ds = DataSet()
with open(path, 'r') as fin:
lines = fin.readlines()
            for l in lines:
                if not l.strip():
                    continue  # skip blank lines such as a trailing newline at EOF
                items = l.split('\t')
refs = ' '.join(items[0].strip().split(' ')[:self.max_seq_len])
hyps = ' '.join(items[1].strip().split(' ')[:self.max_seq_len])
sample = {
'refs': refs,
'hyps': hyps,
'labels': float(items[2]),
                    'type': items[3].strip(),  # strip the trailing '\n' so the type check below matches
}
ds.append(Instance(**sample))
# statistics
                total_samples += 1
if sample['type'] == 'debias':
debias_samples += 1
else:
distillation_samples += 1

        print('Loaded {} samples ({} debias, {} distillation)'.format(
            total_samples, debias_samples, distillation_samples))
        ds.set_input("refs", "hyps", "labels")
ds.set_target("labels")
return ds
6 changes: 6 additions & 0 deletions mitigating_bias/train/BLEURT/requirements.txt
@@ -0,0 +1,6 @@
adapter_transformers==3.1.0
fastNLP==1.0.0
fitlog==0.9.13
numpy==1.20.3
torch==1.12.1+cu116
transformers==4.23.1