## Baseline model of SUMBT
A Korean implementation of [SUMBT](https://arxiv.org/abs/1907.07421), an ontology-based DST model. <br>

- Instead of GloVe and character embeddings, the `token_embeddings` of `dsksd/bert-ko-small-minimal` are used as pretrained subword embeddings (a minimal sketch follows this list).
- To save memory, a projection layer maps the token embeddings (768) down to the hidden dimension (300).
- `Parallel Decoding` is implemented for faster training.

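The embedding setup described above can be sketched roughly as follows. This is only an illustration under assumptions, not the repository's actual module: the class and attribute names (`SubwordEmbedding`, `proj`) are made up, and only the reuse of the pretrained token embeddings plus the 768 -> 300 projection comes from the description.

```python
import torch.nn as nn
from transformers import AutoModel

class SubwordEmbedding(nn.Module):
    """Illustrative only: pretrained subword embeddings + 768 -> 300 projection."""

    def __init__(self, hidden_dim=300):
        super().__init__()
        bert = AutoModel.from_pretrained("dsksd/bert-ko-small-minimal")
        # Reuse the word-piece embedding table of the pretrained encoder.
        self.token_embeddings = bert.get_input_embeddings()  # nn.Embedding(vocab_size, 768)
        # Project down to the model's hidden dimension to save memory.
        self.proj = nn.Linear(self.token_embeddings.embedding_dim, hidden_dim)

    def forward(self, input_ids):
        # (batch, seq_len) -> (batch, seq_len, hidden_dim)
        return self.proj(self.token_embeddings(input_ids))
```
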
### 1. Install the required libraries

`pip install -r requirements.txt`

### 2. Train the model

`python train_sumbt.py` <br>
Trained models are saved per epoch to the directory given by `--model_dir {MODEL_DIR}`.<br>
The configuration files needed for inference are saved to the same path.<br>
The path of the best checkpoint is printed at the end of training.<br>

### 3. Run inference

`python inference_sumbt.py`

### 4. Submit

Submit the `predictions.json` written to `--output_dir {OUTPUT_DIR}` by `inference_sumbt.py` in step 3.

---
## Baseline model of TRADE

A Korean implementation of [TRADE](https://arxiv.org/abs/1905.08743), an open-vocabulary DST model (see lectures 5 and 6). <br>

- Instead of GloVe and character embeddings, the `token_embeddings` of `monologg/koelectra-base-v3-discriminator` are used as pretrained subword embeddings (a short sketch follows this list).
- To save memory, a projection layer maps the token embeddings (768) down to the hidden dimension (400).
- `Parallel Decoding` is implemented for faster training.
- A variant whose encoder is BERT is also provided.

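As with SUMBT above, the embedding reuse can be sketched as follows. This is not the repository's code, only an illustration of initializing an embedding table from the KoELECTRA checkpoint and projecting 768 -> 400; the variable names are placeholders.

```python
import torch.nn as nn
from transformers import AutoModel

# Illustrative sketch: copy the KoELECTRA word-piece embedding table into a
# plain nn.Embedding and project 768 -> 400 (names are placeholders).
electra = AutoModel.from_pretrained("monologg/koelectra-base-v3-discriminator")
pretrained_weight = electra.get_input_embeddings().weight.detach().clone()  # (vocab_size, 768)

token_embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=False)
projection = nn.Linear(pretrained_weight.size(-1), 400)

def embed(input_ids):
    # (batch, seq_len) -> (batch, seq_len, 400)
    return projection(token_embeddings(input_ids))
```
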
### 1. Install the required libraries

`pip install -r requirements.txt`

### 2. Train the model

`python train_trade.py` <br>
Trained models are saved per epoch as `SM_MODEL_DIR/model-{epoch}.bin`.<br>
The configuration files needed for inference are saved to the same path.<br>
The path of the best checkpoint is printed at the end of training.<br>

`python train_tradeBert.py` <br>
This variant uses a BERT module as the utterance encoder.<br>

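Given the per-epoch files described above, a checkpoint can be reloaded roughly like this. The directory and epoch number are placeholders, and the sketch assumes each `.bin` file stores a plain PyTorch `state_dict`, which should be checked against `train_trade.py`.

```python
import torch

model_dir = "./results"   # placeholder for SM_MODEL_DIR
epoch = 9                 # the epoch reported as the best checkpoint

state_dict = torch.load(f"{model_dir}/model-{epoch}.bin", map_location="cpu")
# model = ...                      # build the model exactly as in train_trade.py
# model.load_state_dict(state_dict)
# model.eval()
```
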
### 3. Run inference

`python inference.py`

### 4. Submit

Submit the `predictions.json` written to `--output_dir {OUTPUT_DIR}` by `inference.py` in step 3.

---
## Baseline model of SOM-DST
A Korean implementation of [SOM-DST](https://arxiv.org/abs/1911.03906), an open-vocabulary DST model.

import dataclasses
import json
import random
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


@dataclass
class OntologyDSTFeature:
    """Feature container for the ontology-based model (e.g. SUMBT)."""

    guid: str
    input_ids: List[int]
    segment_ids: List[int]
    num_turn: int
    target_ids: Optional[List[int]]


@dataclass
class OpenVocabDSTFeature:
    """Feature container for the open-vocabulary model (e.g. TRADE)."""

    guid: str
    input_id: List[int]
    segment_id: List[int]
    gating_id: List[int]
    target_ids: Optional[Union[List[int], List[List[int]]]]
    slot_positions: Optional[List[int]] = None
    domain_id: Optional[int] = None


class WOSDataset(Dataset):
    def __init__(self, features):
        self.features = features
        self.length = len(self.features)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.features[idx]


def load_dataset(dataset_path, dev_split=0.1):
    with open(dataset_path, "rt", encoding="UTF8") as f:
        data = json.load(f)
    num_data = len(data)
    num_dev = int(num_data * dev_split)
    if not num_dev:
        return data, [], {}  # no dev dataset

    dom_mapper = defaultdict(list)
    for d in data:
        dom_mapper[len(d["domains"])].append(d["dialogue_idx"])

    num_per_domain_transition = num_dev // 3
    dev_idx = []
    for v in dom_mapper.values():
        if len(v) <= num_per_domain_transition:
            dev_idx.extend(v)
        else:
            idx = random.sample(v, num_per_domain_transition)
            dev_idx.extend(idx)

    train_data, dev_data = [], []
    for d in data:
        if d["dialogue_idx"] in dev_idx:
            dev_data.append(d)
        else:
            train_data.append(d)

    dev_labels = {}
    for dialogue in dev_data:
        d_idx = 0
        guid = dialogue["dialogue_idx"]
        for idx, turn in enumerate(dialogue["dialogue"]):
            if turn["role"] != "user":
                continue

            state = turn.pop("state")

            guid_t = f"{guid}-{d_idx}"
            d_idx += 1

            dev_labels[guid_t] = state

    return train_data, dev_data, dev_labels


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(seed)


def split_slot(dom_slot_value, get_domain_slot=False):
    try:
        dom, slot, value = dom_slot_value.split("-")
    except ValueError:
        tempo = dom_slot_value.split("-")
        if len(tempo) < 2:
            return dom_slot_value, dom_slot_value, dom_slot_value
        dom, slot = tempo[0], tempo[1]
        value = dom_slot_value.replace(f"{dom}-{slot}-", "").strip()

    if get_domain_slot:
        return f"{dom}-{slot}", value
    return dom, slot, value


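# Example of the parsing above (hypothetical WOS-style triples, not taken from
# the actual dataset):
#   split_slot("restaurant-price range-cheap", get_domain_slot=True)
#     -> ("restaurant-price range", "cheap")
#   A value containing extra hyphens, e.g. "taxi-departure-Seoul-Station",
#   falls into the except branch and is recovered as
#     ("taxi-departure", "Seoul-Station").

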
def build_slot_meta(data):
    slot_meta = []
    for dialog in data:
        for turn in dialog["dialogue"]:
            if not turn.get("state"):
                continue

            for dom_slot_value in turn["state"]:
                domain_slot, _ = split_slot(dom_slot_value, get_domain_slot=True)
                if domain_slot not in slot_meta:
                    slot_meta.append(domain_slot)
    return sorted(slot_meta)


def convert_state_dict(state):
    dic = {}
    for slot in state:
        s, v = split_slot(slot, get_domain_slot=True)
        dic[s] = v
    return dic


@dataclass
class DSTInputExample:
    guid: str
    context_turns: List[str]
    current_turn: List[str]
    label: Optional[List[str]] = None
    domains: Optional[List[str]] = None

    def to_dict(self):
        return dataclasses.asdict(self)

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2) + "\n"


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def tokenize_ontology(ontology, tokenizer, max_seq_length=12):
    slot_types = []
    slot_values = []
    for k, v in ontology.items():
        tokens = tokenizer.encode(k)
        if len(tokens) < max_seq_length:
            gap = max_seq_length - len(tokens)
            tokens.extend([tokenizer.pad_token_id] * gap)
        slot_types.append(tokens)
        slot_value = []
        for vv in v:
            tokens = tokenizer.encode(vv)
            if len(tokens) < max_seq_length:
                gap = max_seq_length - len(tokens)
                tokens.extend([tokenizer.pad_token_id] * gap)
            slot_value.append(tokens)
        slot_values.append(torch.LongTensor(slot_value))
    return torch.LongTensor(slot_types), slot_values


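# Usage sketch for the ontology-based (SUMBT-style) pipeline. `ontology` is a
# dict mapping each "domain-slot" string to its list of candidate values and
# `tokenizer` is a Hugging Face tokenizer; both names are placeholders:
#
#   slot_type_ids, slot_value_ids = tokenize_ontology(ontology, tokenizer, max_seq_length=12)
#   # slot_type_ids: LongTensor of shape (num_slots, max_seq_length)
#   # slot_value_ids: list of LongTensors, one (num_values_i, max_seq_length) per slot

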
def get_examples_from_dialogue(dialogue, user_first=False):
    guid = dialogue["dialogue_idx"]
    examples = []
    history = []
    d_idx = 0
    domains = dialogue["domains"]
    for idx, turn in enumerate(dialogue["dialogue"]):
        if turn["role"] != "user":
            continue

        if idx:
            sys_utter = dialogue["dialogue"][idx - 1]["text"]
        else:
            sys_utter = ""

        user_utter = turn["text"]
        state = turn.get("state")
        context = deepcopy(history)
        if user_first:
            current_turn = [user_utter, sys_utter]
        else:
            current_turn = [sys_utter, user_utter]
        examples.append(
            DSTInputExample(
                guid=f"{guid}-{d_idx}",
                context_turns=context,
                current_turn=current_turn,
                label=state,
                domains=domains,
            )
        )
        history.append(sys_utter)
        history.append(user_utter)
        d_idx += 1
    return examples


def get_examples_from_dialogues(data, user_first=False, dialogue_level=False):
    examples = []
    for d in tqdm(data):
        example = get_examples_from_dialogue(d, user_first=user_first)
        if dialogue_level:
            examples.append(example)
        else:
            examples.extend(example)
    return examples


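# Typical preprocessing flow built from the helpers above (the file path is a
# placeholder):
#
#   train_data, dev_data, dev_labels = load_dataset("data/train_dials.json")
#   slot_meta = build_slot_meta(train_data)
#   train_examples = get_examples_from_dialogues(train_data, user_first=False)
#   dev_examples = get_examples_from_dialogues(dev_data, user_first=False)

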
class DSTPreprocessor:
    def __init__(self, slot_meta, src_tokenizer, trg_tokenizer=None, ontology=None):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology

    def pad_ids(self, arrays, pad_idx, max_length=-1):
        if max_length < 0:
            max_length = max(list(map(len, arrays)))

        # Pad each id list up to max_length; assumes sequences are at most
        # 512 tokens long (the min() caps only the padding count).
        arrays = [
            array + [pad_idx] * (max_length - min(len(array), 512)) for array in arrays
        ]
        return arrays

    def pad_id_of_matrix(self, arrays, padding, max_length=-1, left=False):
        if max_length < 0:
            max_length = max([array.size(-1) for array in arrays])

        new_arrays = []
        for array in arrays:
            n, l = array.size()
            # Pad each (n, l) matrix on the right up to (n, max_length).
            pad = torch.full((n, max_length - l), padding, dtype=torch.long)
            m = torch.cat([array, pad], -1)
            new_arrays.append(m.unsqueeze(0))

        return torch.cat(new_arrays, 0)

    # The three methods below are implemented by model-specific preprocessors.
    def _convert_example_to_feature(self):
        raise NotImplementedError

    def convert_examples_to_features(self):
        raise NotImplementedError

    def recover_state(self):
        raise NotImplementedError