Skip to content

[FEATURE] Add SKT model #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

[Jie Ouyang](https://github.com/0russwest0)

[Weizhe Huang](https://github.com/weizhehuang0827)

[Bihan Xu](https://github.com/xbh0720)

The starred name indicates the corresponding author
94 changes: 94 additions & 0 deletions EduKTM/SKT/SKT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# coding: utf-8
# 2023/3/17 @ weizhehuang0827

import logging
import numpy as np
import torch
from tqdm import tqdm
from EduKTM import KTM
from .SKTNet import SKTNet
from EduKTM.utils import SLMLoss, tensor2list, pick
from sklearn.metrics import roc_auc_score, accuracy_score


class SKT(KTM):
    """Structure-based Knowledge Tracing (SKT).

    Thin training / evaluation / persistence wrapper around
    :class:`SKTNet`, following the ``KTM`` interface used across EduKTM.
    """

    def __init__(self, ku_num, graph_params, hidden_num, net_params: dict = None, loss_params=None):
        """
        Parameters
        ----------
        ku_num:
            number of knowledge units (concepts)
        graph_params:
            graph configuration forwarded to ``SKTNet``
        hidden_num:
            hidden state dimensionality of the network
        net_params: dict, optional
            extra keyword arguments forwarded to ``SKTNet``
        loss_params: dict, optional
            keyword arguments forwarded to ``SLMLoss``
        """
        super(SKT, self).__init__()
        self.skt_model = SKTNet(
            ku_num,
            graph_params,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        """Train with Adam for ``epoch`` epochs.

        If ``test_data`` is given, AUC/accuracy are reported after every epoch.
        """
        loss_function = SLMLoss(**self.loss_params).to(device)
        self.skt_model = self.skt_model.to(device)
        trainer = torch.optim.Adam(self.skt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # move the whole mini-batch onto the target device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # forward pass
                predicted_response, _ = self.skt_model(
                    question, data, data_mask)

                loss = loss_function(predicted_response,
                                     pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            # BUG FIX: loss name was misspelled "SLMoss" in the log line
            print("[Epoch %d] SLMLoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data, device=device)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" %
                      (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        """Evaluate on ``test_data`` and return ``(auc, accuracy)``.

        Accuracy uses a 0.5 decision threshold on predicted probabilities.
        """
        self.skt_model.eval()
        y_true = []
        y_pred = []

        # BUG FIX: inference must not build an autograd graph
        # (saves memory and time; results are unchanged)
        with torch.no_grad():
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
                # move the whole mini-batch onto the target device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                output, _ = self.skt_model(question, data, data_mask)
                # drop the final time step, then select the prediction for
                # each picked (next) question via ``pick``
                output = output[:, :-1]
                output = pick(output, pick_index.to(output.device))
                pred = tensor2list(output)
                label = tensor2list(label)
                # only the first ``length`` positions of each row are valid
                for i, length in enumerate(label_mask.cpu().tolist()):
                    length = int(length)
                    y_true.extend(label[i][:length])
                    y_pred.extend(pred[i][:length])
        self.skt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        """Persist model parameters to ``filepath``."""
        torch.save(self.skt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        """Load model parameters previously stored with :meth:`save`."""
        self.skt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)
174 changes: 174 additions & 0 deletions EduKTM/SKT/SKTNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# coding: utf-8
# 2023/3/17 @ weizhehuang0827
__all__ = ["SKTNet"]


import torch
import torch.nn as nn
import torch.nn.functional as F
from EduKTM.utils import GRUCell, begin_states, get_states, expand_tensor, \
format_sequence, mask_sequence_variable_length
from .utils import Graph


class SKTNet(nn.Module):
    """Structure-based Knowledge Tracing network (the SKT backbone).

    Keeps one hidden vector per knowledge unit (KU). At each interaction the
    answered KU's state is updated by a GRU ("self" influence); the update is
    then spread to graph neighbors ("synchronization") and to graph
    successors ("propagation"). The two spread signals are blended with
    weight ``alpha``, aggregated, and fed to a GRU over all KU states; a
    sigmoid readout yields a per-KU prediction at every step.
    """

    def __init__(self, ku_num, graph_params=None,
                 alpha=0.5,
                 latent_dim=None, activation=None,
                 hidden_num=90, concept_dim=None,
                 # dropout=0.5, self_dropout=0.0,
                 dropout=0.0, self_dropout=0.5,
                 # dropout=0.0, self_dropout=0.0,
                 sync_dropout=0.0,
                 prop_dropout=0.0,
                 agg_dropout=0.0,
                 params=None):
        """
        Parameters
        ----------
        ku_num:
            number of knowledge units; also sizes the KU graph.
        graph_params:
            graph specification handed to ``Graph.from_file`` (defaults to
            an empty list, i.e. a graph without configured edges).
        alpha:
            blend weight between synchronization and propagation influence.
        latent_dim:
            response-embedding size (defaults to ``hidden_num``).
        activation:
            activation module shared by the sync/prop/agg heads
            (defaults to ``nn.ReLU()``).
        hidden_num:
            per-KU hidden state size.
        concept_dim:
            concept-embedding size (defaults to ``hidden_num``).
        dropout, self_dropout, sync_dropout, prop_dropout, agg_dropout:
            dropout rates for the readout and the respective heads.
        params:
            unused in this implementation; kept for interface compatibility.
        """
        super(SKTNet, self).__init__()
        self.ku_num = int(ku_num)
        self.hidden_num = self.ku_num if hidden_num is None else int(
            hidden_num)
        self.latent_dim = self.hidden_num if latent_dim is None else int(
            latent_dim)
        self.concept_dim = self.hidden_num if concept_dim is None else int(
            concept_dim)
        graph_params = graph_params if graph_params is not None else []
        self.graph = Graph.from_file(ku_num, graph_params)
        self.alpha = alpha

        sync_activation = nn.ReLU() if activation is None else activation
        prop_activation = nn.ReLU() if activation is None else activation
        agg_activation = nn.ReLU() if activation is None else activation

        self.rnn = GRUCell(self.hidden_num)
        # responses are encoded as question_id * 2 + correctness flag,
        # hence the 2 * ku_num embedding rows (see etl.transform)
        self.response_embedding = nn.Embedding(
            2 * self.ku_num, self.latent_dim)
        self.concept_embedding = nn.Embedding(self.ku_num, self.concept_dim)
        self.f_self = GRUCell(self.hidden_num)
        self.self_dropout = nn.Dropout(self_dropout)
        # propagation head: [state delta ; concept embedding] -> influence
        self.f_prop = nn.Sequential(
            nn.Linear(self.hidden_num * 2, self.hidden_num),
            prop_activation,
            nn.Dropout(prop_dropout),
        )
        # synchronization head: [states ; next self state ; concept emb] -> influence
        self.f_sync = nn.Sequential(
            nn.Linear(self.hidden_num * 3, self.hidden_num),
            sync_activation,
            nn.Dropout(sync_dropout),
        )
        # aggregation head applied to the blended influence
        self.f_agg = nn.Sequential(
            nn.Linear(self.hidden_num, self.hidden_num),
            agg_activation,
            nn.Dropout(agg_dropout),
        )
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(self.hidden_num, 1)
        self.sigmoid = nn.Sigmoid()

    def neighbors(self, x, ordinal=True):
        # undirected neighborhood of KU ``x`` in the concept graph
        return self.graph.neighbors(x, ordinal)

    def successors(self, x, ordinal=True):
        # directed successors of KU ``x`` in the concept graph
        return self.graph.successors(x, ordinal)

    def forward(self, questions, answers, valid_length=None, states=None, layout='NTC', compressed_out=True,
                *args, **kwargs):
        """Run the model over a batch of interaction sequences.

        Parameters
        ----------
        questions:
            per-step KU indices; with ``layout='NTC'`` assumed shape
            (batch, seq_len) -- TODO confirm against callers.
        answers:
            per-step encoded responses (``question * 2 + correct``, matching
            the ``2 * ku_num`` response embedding).
        valid_length:
            optional per-sequence lengths; when given, outputs are masked to
            the valid prefix via ``mask_sequence_variable_length``.
        states:
            NOTE(review): this argument is ignored -- the state is always
            re-initialized to fresh zeros below. Confirm whether resuming
            from a caller-provided state was intended.
        layout:
            sequence layout string passed to ``format_sequence``.
        compressed_out:
            when True and ``valid_length`` is given, the final states are
            dropped and ``None`` is returned in their place.

        Returns
        -------
        outputs, states:
            per-step sigmoid predictions over all KUs, and the final per-KU
            hidden states (or ``None``, see ``compressed_out``).
        """
        length = questions.shape[1]
        device = questions.device
        inputs, axis, batch_size = format_sequence(
            length, questions, layout, False)
        answers, _, _ = format_sequence(length, answers, layout, False)
        # fresh zero state of shape (batch, ku_num, hidden_num)
        states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0]
        states = states.to(device)
        outputs = []
        for i in range(length):
            inputs_i = inputs[i].reshape([batch_size, ])
            answer_i = answers[i].reshape([batch_size, ])

            # concept embedding, broadcast over the batch.
            # NOTE(review): ``.weight.data`` detaches the table from autograd
            # on this path; gradients reach concept_embedding only through
            # the propagation branch below -- confirm this is intended.
            concept_embeddings = self.concept_embedding.weight.data
            concept_embeddings = expand_tensor(
                concept_embeddings, 0, batch_size)
            # concept_embeddings = (_self_mask + _successors_mask + _neighbors_mask) * concept_embeddings

            # self-influence: GRU-update the state of the answered KU
            _self_state = get_states(inputs_i, states)
            # fc
            # _next_self_state = self.f_self(mx.nd.concat(_self_state, self.response_embedding(answers[i]), dim=-1))
            # gru
            _next_self_state, _ = self.f_self(
                self.response_embedding(answer_i), [_self_state])
            # _next_self_state = self.f_self(mx.nd.concat(_self_hidden_states, _self_state))
            # _next_self_state, _ = self.f_self(_self_hidden_states, [_self_state])
            _next_self_state = self.self_dropout(_next_self_state)

            # one-hot mask selecting the answered KU, broadcast to hidden dim
            _self_mask = torch.unsqueeze(F.one_hot(inputs_i, self.ku_num), -1)
            _self_mask = torch.broadcast_to(
                _self_mask, (-1, -1, self.hidden_num))

            # mask of graph neighbors of the answered KU
            _neighbors = self.neighbors(inputs_i)
            _neighbors_mask = torch.unsqueeze(
                torch.tensor(_neighbors, device=device), -1)
            _neighbors_mask = torch.broadcast_to(
                _neighbors_mask, (-1, -1, self.hidden_num))

            # synchronization: neighbors see [their state ; next self state ; concept]
            _broadcast_next_self_states = torch.unsqueeze(_next_self_state, 1)
            _broadcast_next_self_states = torch.broadcast_to(
                _broadcast_next_self_states, (-1, self.ku_num, -1))
            # _sync_diff = mx.nd.concat(states, _broadcast_next_self_states, concept_embeddings, dim=-1)
            _sync_diff = torch.concat(
                (states, _broadcast_next_self_states, concept_embeddings), dim=-1)
            _sync_inf = _neighbors_mask * self.f_sync(_sync_diff)

            # reflection: summed neighbor influence is fed back to the
            # answered KU itself
            _reflec_inf = torch.sum(_sync_inf, dim=1)
            _reflec_inf = torch.broadcast_to(
                torch.unsqueeze(_reflec_inf, 1), (-1, self.ku_num, -1))
            _sync_inf = _sync_inf + _self_mask * _reflec_inf

            # mask of graph successors of the answered KU
            _successors = self.successors(inputs_i)
            _successors_mask = torch.unsqueeze(
                torch.tensor(_successors, device=device), -1)
            _successors_mask = torch.broadcast_to(
                _successors_mask, (-1, -1, self.hidden_num))

            # propagation: successors see [state delta ; concept embedding]
            _prop_diff = torch.concat(
                (_next_self_state - _self_state, self.concept_embedding(inputs_i)), dim=-1)
            # _prop_diff = _next_self_state - _self_state

            # 1
            _prop_inf = self.f_prop(_prop_diff)
            _prop_inf = _successors_mask * \
                torch.broadcast_to(torch.unsqueeze(
                    _prop_inf, axis=1), (-1, self.ku_num, -1))
            # 2 (leftover MXNet reference variants kept for comparison)
            # _broadcast_diff = mx.nd.broadcast_to(mx.nd.expand_dims(_prop_diff, axis=1), (0, self.ku_num, 0))
            # _pro_inf = _successors_mask * self.f_prop(
            #     mx.nd.concat(_broadcast_diff, concept_embeddings, dim=-1)
            # )
            # _pro_inf = _successors_mask * self.f_prop(
            #     _broadcast_diff
            # )

            # aggregate: blend sync/prop with alpha and update all KU states
            _inf = self.f_agg(self.alpha * _sync_inf + (1 - self.alpha) * _prop_inf)
            next_states, _ = self.rnn(_inf, [states])
            # next_states, _ = self.rnn(torch.concat((_inf, concept_embeddings), dim=-1), [states])
            # states = (1 - _self_mask) * next_states + _self_mask * _broadcast_next_self_states
            states = next_states
            # readout: per-KU correctness probability for this step
            output = self.sigmoid(torch.squeeze(
                self.out(self.dropout(states)), axis=-1))
            outputs.append(output)
            # if valid_length is not None and not compressed_out:
            #     all_states.append([states])

        if valid_length is not None:
            if compressed_out:
                states = None
            outputs = mask_sequence_variable_length(torch, outputs, valid_length)

        return outputs, states
2 changes: 2 additions & 0 deletions EduKTM/SKT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .SKT import SKT
from .etl import etl
81 changes: 81 additions & 0 deletions EduKTM/SKT/etl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# coding: utf-8
# 2023/3/17 @ weizhehuang0827


import torch
import json
from tqdm import tqdm
from EduKTM.utils.torch import PadSequence, FixedBucketSampler


def extract(data_src, max_step=200):  # pragma: no cover
    """Read JSON-lines response sequences from ``data_src``.

    When ``max_step`` is an int, each sequence is chopped into chunks of at
    most ``max_step`` interactions and chunks shorter than 2 are dropped;
    when ``max_step`` is ``None``, sequences are kept whole.
    """
    sequences = []
    with open(data_src) as fin:
        for line in tqdm(fin, "reading data from %s" % data_src):
            seq = json.loads(line)
            if max_step is None:
                sequences.append(seq)
            else:
                for start in range(0, len(seq), max_step):
                    chunk = seq[start: start + max_step]
                    if len(chunk) >= 2:
                        sequences.append(chunk)

    return sequences


def transform(raw_data, batch_size, num_buckets=100):
    """Turn raw response sequences into padded mini-batches.

    Each emitted batch is a list of tensors:
    ``[questions, responses, data_mask, labels, pick_index, label_mask]``,
    where responses are encoded as ``question_id * 2 + correctness`` and
    labels/pick_index cover every interaction after the first.
    """
    sequences = raw_data

    sampler = FixedBucketSampler(
        [len(seq) for seq in sequences], batch_size, num_buckets=num_buckets)
    batches = []

    def encode(record):
        # question id * 2 + correctness flag (1 iff score > 0)
        return record[0] * 2 + (1 if record[1] > 0 else 0)

    for bucket in tqdm(sampler, "batchify"):
        questions = []
        responses = []
        picks = []
        labels = []
        for idx in bucket:
            seq = sequences[idx]
            questions.append([rec[0] for rec in seq])
            responses.append([encode(rec) for rec in seq])
            if len(seq) <= 1:  # pragma: no cover
                pick_index, label_seq = [], []
            else:
                # targets: the (question, correctness) of every step after the first
                pick_index, label_seq = zip(
                    *[(rec[0], 1 if rec[1] > 0 else 0) for rec in seq[1:]])
            picks.append(list(pick_index))
            labels.append(list(label_seq))

        # pad questions/responses to the longest sequence in this bucket
        seq_pad = PadSequence(max(len(rs) for rs in responses), pad_val=0)
        questions = [seq_pad(qs) for qs in questions]
        responses, data_mask = zip(*[(seq_pad(rs), len(rs)) for rs in responses])

        # pad labels/pick_index to the longest label sequence in this bucket
        label_pad = PadSequence(max(len(ls) for ls in labels), pad_val=0)
        labels, label_mask = zip(
            *[(label_pad(ls), len(ls)) for ls in labels])
        picks = [label_pad(ps) for ps in picks]

        # Load
        batches.append(
            [torch.tensor(questions), torch.tensor(responses), torch.tensor(data_mask), torch.tensor(labels),
             torch.tensor(picks),
             torch.tensor(label_mask)])

    return batches


def etl(data_src, cfg=None, batch_size=None, **kwargs):  # pragma: no cover
    """Extract-transform-load entry point: file path -> list of batches.

    Parameters
    ----------
    data_src:
        path to the JSON-lines response file.
    cfg:
        optional configuration object providing ``batch_size`` when the
        explicit ``batch_size`` argument is omitted.
    batch_size: int, optional
        mini-batch size; takes precedence over ``cfg.batch_size``.
    kwargs:
        extra keyword arguments forwarded to :func:`transform`.

    Raises
    ------
    ValueError
        if neither ``batch_size`` nor ``cfg`` is provided (the original
        code crashed here with an opaque AttributeError on ``None``).
    """
    if batch_size is None:
        if cfg is None:
            raise ValueError(
                "either batch_size or a cfg with a batch_size attribute must be provided")
        batch_size = cfg.batch_size
    raw_data = extract(data_src)
    return transform(raw_data, batch_size, **kwargs)
Loading