Commit: Add Code files

BongjinKim authored May 23, 2021
1 parent 19d1f95 commit b9231c2
Showing 18 changed files with 7,865 additions and 0 deletions.
2,493 changes: 2,493 additions & 0 deletions BongjinKim/EDA.ipynb

Large diffs are not rendered by default.

64 changes: 64 additions & 0 deletions BongjinKim/README.md
@@ -0,0 +1,64 @@
## Baseline model of SUMBT
A Korean implementation of [SUMBT](https://arxiv.org/abs/1907.07421), an ontology-based DST model. <br>

- Instead of the original GloVe and char embeddings, the `token_embeddings` of `dsksd/bert-ko-small-minimal` are used as the pretrained subword embeddings.
- To save memory, a projection layer maps the Token Embedding (768) => Hidden Dimension (300); see the sketch after this list.
- `Parallel Decoding` is implemented for faster training.
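
As a rough illustration of the projection bullet above, here is a minimal sketch assuming the 768-dim token embedding noted there; the `proj` name and the dummy ids are illustrative, not the repo's actual code:

```python
import torch
import torch.nn as nn
from transformers import AutoModel

bert = AutoModel.from_pretrained("dsksd/bert-ko-small-minimal")
embedding = bert.get_input_embeddings()  # pretrained subword embedding (vocab, 768)
proj = nn.Linear(768, 300)               # Token Embedding (768) -> Hidden Dimension (300)

input_ids = torch.randint(0, embedding.num_embeddings, (1, 8))  # dummy token ids
hidden = proj(embedding(input_ids))      # (1, 8, 300)
```

Working in the 300-dim space instead of 768 shrinks every downstream layer that consumes the embeddings, which is where the memory saving comes from.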


### 1. Install the required libraries

`pip install -r requirements.txt`

### 2. Train the model

`python train_sumbt.py` <br>
The trained model is saved per epoch under `--model_dir {MODEL_DIR}`.<br>
The configurations needed for inference are saved to the same path as well.<br>
The best checkpoint path is printed at the end of training.<br>

### 3. Run inference

`python inference_sumbt.py`

### 4. Submit

Submit the `predictions.json` that step 3's `inference_sumbt.py` saves to `--output_dir {OUTPUT_DIR}`.

---
## Baseline model of TRADE


A Korean implementation of [TRADE](https://arxiv.org/abs/1905.08743), an open-vocab based DST model. (See Lectures 5 and 6.) <br>

- Instead of the original GloVe and char embeddings, the `token_embeddings` of `monologg/koelectra-base-v3-discriminator` are used as the pretrained subword embeddings.
- To save memory, a projection layer maps the Token Embedding (768) => Hidden Dimension (400).
- `Parallel Decoding` is implemented for faster training; a sketch follows this list.
- A model variant with a Bert encoder is also implemented.
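
The idea behind `Parallel Decoding` is to tile the slots into the batch dimension so that one GRU call advances every slot decoder at once, instead of looping over slots in Python. A minimal sketch under assumed shapes (`B` dialogues, `J` slots, hidden size `H`; the names and numbers are illustrative, not the repo's code):

```python
import torch
import torch.nn as nn

B, J, H = 4, 45, 400  # batch size, slot count, hidden size (illustrative)

gru = nn.GRU(H, H, batch_first=True)
slot_embeds = torch.randn(J, H)  # one embedding per slot
hidden = torch.randn(1, B, H)    # encoder summary per dialogue

# Index b * J + j holds slot j of dialogue b in both tensors.
tiled_input = slot_embeds.repeat(B, 1).unsqueeze(1)  # (B * J, 1, H)
tiled_hidden = hidden.repeat_interleave(J, dim=1)    # (1, B * J, H)

out, _ = gru(tiled_input, tiled_hidden)  # one step for all B * J decoders
print(out.shape)  # torch.Size([180, 1, 400])
```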


### 1. Install the required libraries

`pip install -r requirements.txt`

### 2. Train the model

`python train_trade.py` <br>
The trained model is saved per epoch as `SM_MODEL_DIR/model-{epoch}.bin`.<br>
The configurations needed for inference are saved to the same path as well.<br>
The best checkpoint path is printed at the end of training.<br>

`python train_tradeBert.py` <br>
Uses a Bert module as the utterance encoder.<br>

### 3. Run inference

`python inference.py`

### 4. Submit

Submit the `predictions.json` that step 3's `inference.py` saves to `--output_dir {OUTPUT_DIR}`.

---
## Baseline model of SOM-DST
A Korean implementation of [SOM-DST](https://arxiv.org/abs/1911.03906), an open-vocab based DST model.
626 changes: 626 additions & 0 deletions BongjinKim/TRADE_exercise.ipynb

Large diffs are not rendered by default.

278 changes: 278 additions & 0 deletions BongjinKim/data_utils.py
@@ -0,0 +1,278 @@
import dataclasses
import json
import random
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


@dataclass
class OntologyDSTFeature:
guid: str
input_ids: List[int]
segment_ids: List[int]
num_turn: int
target_ids: Optional[List[int]]


@dataclass
class OpenVocabDSTFeature:
guid: str
input_id: List[int]
segment_id: List[int]
gating_id: List[int]
target_ids: Optional[Union[List[int], List[List[int]]]]
    slot_positions: Optional[List[int]] = None
    domain_id: Optional[int] = None


class WOSDataset(Dataset):
def __init__(self, features):
self.features = features
self.length = len(self.features)

def __len__(self):
return self.length

def __getitem__(self, idx):
return self.features[idx]


def load_dataset(dataset_path, dev_split=0.1):
    data = json.load(open(dataset_path, "rt", encoding="UTF8"))
    num_data = len(data)
    num_dev = int(num_data * dev_split)
    if not num_dev:
        return data, [], {}  # no dev dataset

    # Group dialogue ids by how many domains each dialogue spans, so the dev
    # split stays balanced across 1-, 2-, and 3-domain dialogues.
    dom_mapper = defaultdict(list)
    for d in data:
        dom_mapper[len(d["domains"])].append(d["dialogue_idx"])

    num_per_domain_transition = num_dev // 3
    dev_idx = []
    for v in dom_mapper.values():
        if len(v) <= num_per_domain_transition:
            dev_idx.extend(v)
        else:
            idx = random.sample(v, num_per_domain_transition)
            dev_idx.extend(idx)

train_data, dev_data = [], []
for d in data:
if d["dialogue_idx"] in dev_idx:
dev_data.append(d)
else:
train_data.append(d)

dev_labels = {}
for dialogue in dev_data:
d_idx = 0
guid = dialogue["dialogue_idx"]
for idx, turn in enumerate(dialogue["dialogue"]):
if turn["role"] != "user":
continue

state = turn.pop("state")

guid_t = f"{guid}-{d_idx}"
d_idx += 1

dev_labels[guid_t] = state

return train_data, dev_data, dev_labels
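
# Example usage (hypothetical path):
#   train_data, dev_data, dev_labels = load_dataset("data/train_dials.json")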


def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
torch.cuda.manual_seed_all(seed)


def split_slot(dom_slot_value, get_domain_slot=False):
    try:
        dom, slot, value = dom_slot_value.split("-")
    except ValueError:
        tempo = dom_slot_value.split("-")
        if len(tempo) < 2:
            # Malformed entry: fall back to the raw string for every field,
            # keeping the return arity consistent with `get_domain_slot`.
            if get_domain_slot:
                return dom_slot_value, dom_slot_value
            return dom_slot_value, dom_slot_value, dom_slot_value
        dom, slot = tempo[0], tempo[1]
        # The value itself may contain "-", so strip only the "dom-slot-" prefix.
        value = dom_slot_value.replace(f"{dom}-{slot}-", "").strip()

    if get_domain_slot:
        return f"{dom}-{slot}", value
    return dom, slot, value
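
# e.g. (illustrative state string):
#   split_slot("숙소-종류-호텔")                       -> ("숙소", "종류", "호텔")
#   split_slot("숙소-종류-호텔", get_domain_slot=True) -> ("숙소-종류", "호텔")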


def build_slot_meta(data):
slot_meta = []
for dialog in data:
for turn in dialog["dialogue"]:
if not turn.get("state"):
continue

for dom_slot_value in turn["state"]:
domain_slot, _ = split_slot(dom_slot_value, get_domain_slot=True)
if domain_slot not in slot_meta:
slot_meta.append(domain_slot)
return sorted(slot_meta)


def convert_state_dict(state):
dic = {}
for slot in state:
s, v = split_slot(slot, get_domain_slot=True)
dic[s] = v
return dic


@dataclass
class DSTInputExample:
guid: str
context_turns: List[str]
current_turn: List[str]
label: Optional[List[str]] = None
    domains: Optional[List[str]] = None

def to_dict(self):
return dataclasses.asdict(self)

def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2) + "\n"


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""

# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()

def tokenize_ontology(ontology, tokenizer, max_seq_length=12):
    def encode_and_pad(text):
        # Truncate to max_seq_length, then right-pad so every row has the same
        # width and the rows can be stacked into a rectangular LongTensor.
        tokens = tokenizer.encode(text)[:max_seq_length]
        tokens.extend([tokenizer.pad_token_id] * (max_seq_length - len(tokens)))
        return tokens

    slot_types = [encode_and_pad(slot) for slot in ontology]
    slot_values = [
        torch.LongTensor([encode_and_pad(value) for value in values])
        for values in ontology.values()
    ]
    return torch.LongTensor(slot_types), slot_values
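
# Shapes: `slot_types` is (num_slots, max_seq_length); `slot_values` is a list
# with one (num_candidate_values, max_seq_length) LongTensor per slot.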

def get_examples_from_dialogue(dialogue, user_first=False):
guid = dialogue["dialogue_idx"]
examples = []
history = []
d_idx = 0
domains = dialogue["domains"]
for idx, turn in enumerate(dialogue["dialogue"]):
if turn["role"] != "user":
continue

if idx:
sys_utter = dialogue["dialogue"][idx - 1]["text"]
else:
sys_utter = ""

user_utter = turn["text"]
state = turn.get("state")
context = deepcopy(history)
if user_first:
current_turn = [user_utter, sys_utter]
else:
current_turn = [sys_utter, user_utter]
examples.append(
DSTInputExample(
guid=f"{guid}-{d_idx}",
context_turns=context,
current_turn=current_turn,
label=state,
domains=domains,
)
)
history.append(sys_utter)
history.append(user_utter)
d_idx += 1
return examples


def get_examples_from_dialogues(data, user_first=False, dialogue_level=False):
examples = []
for d in tqdm(data):
example = get_examples_from_dialogue(d, user_first=user_first)
if dialogue_level:
examples.append(example)
else:
examples.extend(example)
return examples
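
# With `dialogue_level=True` the examples stay grouped per dialogue (a list of
# lists, one inner list per dialogue); by default they are flattened into a
# single list of DSTInputExample objects.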


class DSTPreprocessor:
def __init__(self, slot_meta, src_tokenizer, trg_tokenizer=None, ontology=None):
self.slot_meta = slot_meta
self.src_tokenizer = src_tokenizer
self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
self.ontology = ontology

def pad_ids(self, arrays, pad_idx, max_length=-1):
if max_length < 0:
max_length = max(list(map(len, arrays)))

        # Right-pad each sequence; never emit a negative pad count.
        arrays = [
            array + [pad_idx] * max(0, max_length - len(array)) for array in arrays
        ]
return arrays

    def pad_id_of_matrix(self, arrays, padding, max_length=-1, left=False):
        # `left` is unused but kept for API compatibility.
        if max_length < 0:
            max_length = max([array.size(-1) for array in arrays])

        new_arrays = []
        for array in arrays:
            n, l = array.size()
            # Fill an (n, max_length - l) block with the padding id and
            # append it on the right.
            pad = torch.full((n, max_length - l), padding, dtype=torch.long)
            m = torch.cat([array, pad], -1)
            new_arrays.append(m.unsqueeze(0))

        return torch.cat(new_arrays, 0)
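
    # e.g. three (2, l_i) target matrices with l_i in {3, 5, 4} come back as
    # one (3, 2, 5) LongTensor (illustrative shapes).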

def _convert_example_to_feature(self):
raise NotImplementedError

def convert_examples_to_features(self):
raise NotImplementedError

def recover_state(self):
raise NotImplementedError