camera ready commit
yixinL7 committed Apr 11, 2021
1 parent bec1f53 commit 3ede43c
Showing 8 changed files with 343 additions and 1,504 deletions.
99 changes: 96 additions & 3 deletions README.md
@@ -1,7 +1,10 @@
# Refactoring-Summarization
Code for our paper:
"RefSum: Refactoring Neural Summarization"
"RefSum: Refactoring Neural Summarization", NAACL 2021.

<img src="https://github.com/yixinL7/Refactoring-Summarization/blob/main/intro-gap.png" width="500">

We present a model, Refactor, which can be used either as a base system or a meta system for text summarization.
## Outline
* ### [Install](https://github.com/yixinL7/Refactoring-Summarization#how-to-install)
* ### [Train your Refactor](https://github.com/yixinL7/Refactoring-Summarization#how-to-run)
@@ -24,7 +27,6 @@ Code for our paper:
- `model.py` -> Refactor model
- `data_utils.py` -> dataloader
- `utils.py` -> utility functions
-- `preprocess.py` -> data preprocessing
- `demo.py` -> off-the-shelf refactoring


@@ -64,7 +66,71 @@ We use four datasets for our experiments.
- PubMed -> https://github.com/armancohan/long-summarization
- WikiHow -> https://github.com/mahnazkoupaee/WikiHow-Dataset

-You can find the processed data for all of our experiments here [TODO: ADD LINK]. After downloading, you should put the data in the `./data` directory.
+You can find the processed data for all of our experiments [here](https://drive.google.com/drive/folders/1QvlxYVyEN1tGzzzNrfAcNIui56qdhezL?usp=sharing). After downloading, you should put the data in the `./data` directory.

<table>
<thead>
<tr>
<th>Dataset</th>
<th>Experiment</th>
<th>Link</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="6">CNNDM</td>
<td>Pre-train</td>
<td><a href="https://drive.google.com/file/d/1kcwR0PswyBXWGrNJBcg7Et65keSSsXoc/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>BART Reranking</td>
<td><a href="https://drive.google.com/file/d/1GfwqDpFBPV3jOaCUtGRt8FRlUak9YzyV/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>GSum Reranking</td>
<td><a href="https://drive.google.com/file/d/1hue7r7tU-9o1pnNuHC6wDV4bCFpwtK95/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>Two-system Combination (System-level)</td>
<td><a href="https://drive.google.com/file/d/1WIf9WvKX90fHxVCR5ywb0Kd5mZJgu9cz/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>Two-system Combination (Sentence-level)</td>
<td><a href="https://drive.google.com/file/d/1z0EFkOtTXriarv7tR3KY3D_Sssx4yHEQ/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>Three-system Combination (System-level)</td>
<td><a href="https://drive.google.com/file/d/1sklrdsA_UxNAYeK1helUJ_ZdhdcltRZz/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td rowspan="2">XSum</td>
<td>Pre-train</td>
<td><a href="https://drive.google.com/file/d/1fSPJDmkBakYcfOhAF_UlLCbThR6h1O74/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>PEGASUS Reranking</td>
<td><a href="https://drive.google.com/file/d/1ZqdooQ4YwwRg4qab3lEUu-Wr7NV11gKe/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td rowspan="2">PubMed</td>
<td>Pre-train</td>
<td><a href="https://drive.google.com/file/d/1l_LmeNPRTv_L9GPctFYNZVp5gp0t7DDG/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>BART Reranking</td>
<td><a href="https://drive.google.com/file/d/1lW3VefPnPs664qy5o4Qub9IpIH2YfWHt/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td rowspan="2">WikiHow</td>
<td>Pre-train</td>
<td><a href="https://drive.google.com/file/d/1p2Us8qvKqwgQcE6ZIUR5-umMtBxGJ2ef/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
<tr>
<td>BART Reranking</td>
<td><a href="https://drive.google.com/file/d/1HELUaZm4FpOXZ1hF5n4nqtsDyNHUygZL/view?usp=sharing" target="_blank" rel="noopener noreferrer">Download</a></td>
</tr>
</tbody>
</table>

## 5. Results

@@ -90,7 +156,34 @@ You can find the processed data for all of our experiments here [TODO: ADD LINK]
| Summary-Level Combination | 45.04 | 21.61 | 41.72 |
| Sentence-Level Combination | 44.93 | 21.48 | 41.42 |

#### System-Combination (BART, pre-trained Refactor and GSum)
| | ROUGE-1 | ROUGE-2 | ROUGE-L |
|----------------------------|---------|---------|---------|
| BART | 44.26 | 21.12 | 41.16 |
| pre-trained Refactor | 44.13 | 20.51 | 40.29 |
| GSum | 45.93 | 22.30 | 42.68 |
| Summary-Level Combination | 46.12 | 22.46 | 42.92 |

### XSum
#### Reranking PEGASUS
| | ROUGE-1 | ROUGE-2 | ROUGE-L |
|----------|---------|---------|---------|
| PEGASUS | 47.12 | 24.46 | 39.04 |
| Refactor | 47.45 | 24.55 | 39.41 |

### PubMed
#### Reranking BART
| | ROUGE-1 | ROUGE-2 | ROUGE-L |
|----------|---------|---------|---------|
| BART | 43.42 | 15.32 | 39.21 |
| Refactor | 43.72 | 15.41 | 39.51 |

### WikiHow
#### Reranking BART
| | ROUGE-1 | ROUGE-2 | ROUGE-L |
|----------|---------|---------|---------|
| BART | 41.98 | 18.09 | 40.53 |
| Refactor | 42.12 | 18.13 | 40.66 |



2 changes: 1 addition & 1 deletion demo.py
@@ -40,7 +40,7 @@ def padding(sents):

def scoring(data_pt, model_pt, result_pt):
model = Refactor('bert-base-uncased', num_layers=2).to(device)
-model.load_state_dict(torch.load(model_pt), map_location=device)
+model.load_state_dict(torch.load(model_pt, map_location=device))
model = model.eval()
rouge1, rouge2, rougeLsum = 0, 0, 0
num = 0
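A note on the fix above: `map_location` is a keyword of `torch.load`, not of `nn.Module.load_state_dict`, so the old call raised a `TypeError` as soon as a checkpoint was loaded. A minimal sketch of the corrected pattern (the module and checkpoint path below are stand-ins, not the repo's own):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                      # stand-in module for illustration
torch.save(model.state_dict(), "model.bin")  # hypothetical checkpoint path

# Correct: map_location belongs to torch.load, which remaps saved storages
# (e.g., CUDA tensors) onto the target device before deserialization.
state = torch.load("model.bin", map_location=torch.device("cpu"))
model.load_state_dict(state)

# Old, buggy form: load_state_dict(state_dict, strict=True) accepts no
# map_location keyword, so this raises TypeError at call time:
# model.load_state_dict(torch.load("model.bin"), map_location="cpu")
```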
Binary file added intro-gap.png
22 changes: 12 additions & 10 deletions main.py
@@ -25,6 +25,7 @@
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
logging.getLogger("transformers.tokenization_utils_fast").setLevel(logging.ERROR)


def base_setting(args):
args.batch_size = getattr(args, 'batch_size', 1)
args.num_layers = getattr(args, 'num_layers', 2) # transformer layers
@@ -44,11 +45,12 @@ def base_setting(args):
args.pretrained = getattr(args, "pretrained", None)
args.max_lr = getattr(args, "max_lr", 2e-3)
args.scale = getattr(args, "scale", 1)
-args.datatype = getattr(args, "datatype", "two")
+args.datatype = getattr(args, "datatype", "pre")
args.dataset = getattr(args, "dataset", "CNNDM")
-args.use_ids = getattr(args, "use_ids", False)
+args.use_ids = getattr(args, "use_ids", True) # set true for pretraining
args.max_len = getattr(args, "max_len", 512)
-args.max_num = getattr(args, "max_num", 4)
+args.max_num = getattr(args, "max_num", 4) # max number of candidates


def evaluation(args):
# load data
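`base_setting` relies on `getattr` with a fallback: any attribute already set on `args` (e.g., parsed from the command line) wins, and only missing attributes receive the defaults above. A minimal sketch of the pattern:

```python
from argparse import Namespace

def base_setting(args):
    # keep caller-supplied values; fill in defaults only where absent
    args.datatype = getattr(args, "datatype", "pre")
    args.use_ids = getattr(args, "use_ids", True)
    args.max_num = getattr(args, "max_num", 4)

args = Namespace(datatype="two")  # an explicit value supplied by the caller
base_setting(args)
print(args.datatype, args.use_ids, args.max_num)  # -> two True 4
```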
@@ -175,14 +177,13 @@ def run(rank, args):
is_mp = len(args.gpuid) > 1
world_size = len(args.gpuid)
if is_master:
-id = random.randint(0, 100000)
-recorder = Recorder(id, args.log)
+recorder = Recorder(args.log)
tok = BertTokenizer.from_pretrained(args.model_type)
if args.use_ids:
collate_fn = partial(collate_mp_ids, pad_token_id=tok.pad_token_id, is_test=False)
collate_fn_val = partial(collate_mp_ids, pad_token_id=tok.pad_token_id, is_test=True)
train_set = RefactoringIDsDataset(f"./{args.dataset}/{args.datatype}/train", args.model_type, maxlen=args.max_len, max_num=args.max_num)
-val_set = RefactoringDataset(f"./{args.dataset}/{args.datatype}/val", args.model_type, is_test=True, maxlen=512, is_sorted=False)
+val_set = RefactoringIDsDataset(f"./{args.dataset}/{args.datatype}/val", args.model_type, is_test=True, maxlen=512, is_sorted=False)
else:
collate_fn = partial(collate_mp, pad_token_id=tok.pad_token_id, is_test=False)
collate_fn_val = partial(collate_mp, pad_token_id=tok.pad_token_id, is_test=True)
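The `partial` calls above bind the tokenizer's `pad_token_id` and the test flag into the collate function before it is handed to a `DataLoader`. `collate_mp` itself is not shown in this diff; the sketch below illustrates only the binding pattern, with a toy padding collate:

```python
from functools import partial

import torch
from torch.utils.data import DataLoader

def collate(batch, pad_token_id, is_test=False):
    # pad variable-length id sequences to the longest sequence in the batch
    max_len = max(len(seq) for seq in batch)
    padded = [seq + [pad_token_id] * (max_len - len(seq)) for seq in batch]
    return torch.tensor(padded)

collate_fn = partial(collate, pad_token_id=0, is_test=False)
loader = DataLoader([[1, 2, 3], [4, 5]], batch_size=2, collate_fn=collate_fn)
print(next(iter(loader)))  # tensor([[1, 2, 3], [4, 5, 0]])
```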
@@ -203,15 +204,15 @@
model = Refactor(model_path, num_layers=args.num_layers)

if args.model_pt is not None:
-model.load_state_dict(torch.load(args.model_pt), map_location=f'cuda:{gpuid}')
+model.load_state_dict(torch.load(args.model_pt, map_location=f'cuda:{gpuid}'))
if args.cuda:
if len(args.gpuid) == 1:
model = model.cuda()
else:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
model = nn.parallel.DistributedDataParallel(model.to(gpuid), [gpuid], find_unused_parameters=True)
model.train()
-init_lr = args.sim_lr / args.warmup_steps
+init_lr = args.max_lr / args.warmup_steps
optimizer = optim.Adam(model.parameters(), lr=init_lr)
if is_master:
recorder.write_config(args, [model], __file__)
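The `init_lr` fix above swaps a leftover `args.sim_lr` for `args.max_lr`, so warmup now starts from a rate derived from the configured peak. The full schedule is not visible in this diff; assuming the usual linear warmup implied by `init_lr = max_lr / warmup_steps`, the rate grows by `init_lr` each step until it reaches `max_lr`:

```python
def warmup_lr(step: int, max_lr: float = 2e-3, warmup_steps: int = 10000) -> float:
    """Linear warmup: max_lr scaled by step / warmup_steps, capped at max_lr.

    max_lr mirrors the base_setting default above; warmup_steps is an
    assumed placeholder, not a value taken from this diff.
    """
    return max_lr * min(step / warmup_steps, 1.0)

print(warmup_lr(1))      # 2e-07 == max_lr / warmup_steps after one step
print(warmup_lr(10000))  # 0.002, the configured peak
```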
@@ -262,15 +263,16 @@
else:
recorder.save(model, "model.bin")
recorder.save(optimizer, "optimizer.bin")
-recorder.print("best - epoch: %d, batch: %d"%(epoch + 1, i / args.accumulate_step))
+recorder.print("best - epoch: %d, batch: %d"%(epoch + 1, i / args.accumulate_step + 1))
if is_master:
if is_mp:
recorder.save(model.module, "model_cur.bin")
else:
recorder.save(model, "model_cur.bin")
recorder.save(optimizer, "optimizer_cur.bin")
recorder.print("val score: %.6f"%(1 - loss))



def main(args):
# set env
if len(args.gpuid) > 1:
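The `+ 1` added to the "best" log line above makes the reported batch number 1-based: with gradient accumulation, one optimizer batch spans `accumulate_step` minibatches, so the 0-based minibatch index `i` maps to batch `i // accumulate_step + 1`. A small sketch of the arithmetic:

```python
accumulate_step = 8  # illustrative value; the repo's default is not in this diff
for i in (0, 7, 8, 15):
    # minibatches 0-7 form optimizer batch 1; minibatches 8-15 form batch 2
    print(i, "->", i // accumulate_step + 1)  # -> 1, 1, 2, 2
```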