Skip to content

Commit 3c9911c

Browse files
committed
Merge branch 'main' of https://github.com/bigdata-ustc/EduCDM into GDIRT
� Conflicts: � CHANGE.txt � setup.py
2 parents 802750e + e3c563c commit 3c9911c

File tree

18 files changed

+657
-5
lines changed

18 files changed

+657
-5
lines changed

.github/ISSUE_TEMPLATE/feature_request.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@ labels: 'Feature request'
88

99
## Description
1010
(A clear and concise description of what the feature is.)
11-
- If the proposal is about a new dataset, provide description of what the dataset is and
12-
attach the basic data analysis with it.
13-
- If the proposal is about an API, provide mock examples if possible.
11+
- If the proposal is about an algorithm or a model, provide mock examples if possible. In addition, you may need to carefully follow the [guidance](https://github.com/bigdata-ustc/EduCDM/blob/main/CONTRIBUTE.md)
1412

1513
## References
1614
- list reference and related literature

CHANGE.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
v0.0.6:
2+
* add Item Response Theory with Expectation Maximization Optimization (EMIRT)
3+
14
v0.0.5:
2-
* fix potential ModuleNotFoundError
5+
* add Item Response Theory with Gradient Descent Optimization (GDIRT)
36

47
v0.0.4:
58
* add NeuralCDM (NCDM)

EduCDM/IRT/EM/IRT.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# coding: utf-8
2+
# 2021/5/2 @ liujiayu
3+
4+
import logging
5+
import numpy as np
6+
import pickle
7+
from tqdm import tqdm
8+
from scipy import stats
9+
from ..irt import irt3pl
10+
from EduCDM import CDM
11+
12+
13+
def init_parameters(prob_num, dim):
14+
alpha = stats.norm.rvs(loc=0.75, scale=0.01, size=(prob_num, dim))
15+
beta = stats.norm.rvs(size=(prob_num, dim))
16+
gamma = stats.uniform.rvs(size=prob_num)
17+
return alpha, beta, gamma
18+
19+
20+
def init_prior_prof_distribution(dim):
21+
prof = stats.uniform.rvs(loc=-4, scale=8, size=(100, dim)) # shape = (100,dim)
22+
dis = stats.multivariate_normal.pdf(prof, mean=np.zeros(dim), cov=np.identity(dim))
23+
norm_dis = dis / np.sum(dis) # shape = (100,)
24+
return prof, norm_dis
25+
26+
27+
def get_Likelihood(a, b, c, prof, R):
28+
stu_num, prob_num = R.shape[0], R.shape[1]
29+
prof_prob = irt3pl(np.sum(a * (np.expand_dims(prof, axis=1) - b), axis=-1), 1, 0, c) # shape = (100, prob_num)
30+
tmp1, tmp2 = np.zeros(shape=(prob_num, stu_num)), np.zeros(shape=(prob_num, stu_num))
31+
tmp1[np.where(R == 1)[1], np.where(R == 1)[0]] = 1
32+
tmp2[np.where(R == 0)[1], np.where(R == 0)[0]] = 1
33+
prob_stu = np.exp(np.dot(np.log(prof_prob + 1e-9), tmp1) + np.dot(np.log(1 - prof_prob + 1e-9), tmp2))
34+
return prof_prob, prob_stu
35+
36+
37+
def update_prior(prior_dis, prof_stu_like):
38+
dis_like = prof_stu_like * np.expand_dims(prior_dis, axis=1)
39+
norm_dis_like = dis_like / np.sum(dis_like, axis=0)
40+
update_prior_dis = np.sum(norm_dis_like, axis=1) / np.sum(norm_dis_like)
41+
return update_prior_dis, norm_dis_like
42+
43+
44+
def update_irt(a, b, c, D, prof, R, r_ek, s_ek, lr, epoch=10, epsilon=1e-3):
45+
for iteration in range(epoch):
46+
a_tmp, b_tmp, c_tmp = np.copy(a), np.copy(b), np.copy(c)
47+
prof_prob, _ = get_Likelihood(a, b, c, prof, R)
48+
common_term = (r_ek - s_ek * prof_prob) / prof_prob / (1 - c + 1e-9) # shape = (100, prob_num)
49+
a_1 = np.transpose(
50+
D * common_term * (prof_prob - c) * np.transpose(np.expand_dims(prof, axis=1) - b, (2, 0, 1)), (1, 2, 0))
51+
b_1 = D * common_term * (c - prof_prob)
52+
a_grad = np.sum(a_1, axis=0)
53+
b_grad = a * np.expand_dims(np.sum(b_1, axis=0), axis=1)
54+
c_grad = np.sum(common_term, axis=0)
55+
a = a + lr * a_grad
56+
b = b + lr * b_grad
57+
c = np.clip(c + lr * c_grad, 0, 1)
58+
change = max(np.max(np.abs(a - a_tmp)), np.max(np.abs(b - b_tmp)), np.max(np.abs(c - c_tmp)))
59+
if iteration > 5 and change < epsilon:
60+
break
61+
return a, b, c
62+
63+
64+
class IRT(CDM):
65+
def __init__(self, R, stu_num, prob_num, dim=1, skip_value=-1):
66+
super(IRT, self).__init__()
67+
self.R, self.skip_value = R, skip_value
68+
self.stu_num, self.prob_num, self.dim = stu_num, prob_num, dim
69+
self.a, self.b, self.c = init_parameters(prob_num, dim) # IRT parameters
70+
self.D = 1.702
71+
self.prof, self.prior_dis = init_prior_prof_distribution(dim)
72+
self.stu_prof = np.zeros(shape=(stu_num, dim))
73+
74+
def train(self, lr, epoch, epoch_m=10, epsilon=1e-3):
75+
a, b, c = np.copy(self.a), np.copy(self.b), np.copy(self.c)
76+
prior_dis = np.copy(self.prior_dis)
77+
for iteration in range(epoch):
78+
a_tmp, b_tmp, c_tmp, prior_dis_tmp = np.copy(a), np.copy(b), np.copy(c), np.copy(prior_dis)
79+
prof_prob_like, prof_stu_like = get_Likelihood(a, b, c, self.prof, self.R)
80+
prior_dis, norm_dis_like = update_prior(prior_dis, prof_stu_like)
81+
82+
r_1 = np.zeros(shape=(self.stu_num, self.prob_num))
83+
r_1[np.where(self.R == 1)[0], np.where(self.R == 1)[1]] = 1
84+
r_ek = np.dot(norm_dis_like, r_1) # shape = (100, prob_num)
85+
r_1[np.where(self.R != self.skip_value)[0], np.where(self.R != self.skip_value)[1]] = 1
86+
s_ek = np.dot(norm_dis_like, r_1) # shape = (100, prob_num)
87+
a, b, c = update_irt(a, b, c, self.D, self.prof, self.R, r_ek, s_ek, lr, epoch_m, epsilon)
88+
change = max(np.max(np.abs(a - a_tmp)), np.max(np.abs(b - b_tmp)), np.max(np.abs(c - c_tmp)),
89+
np.max(np.abs(prior_dis_tmp - prior_dis_tmp)))
90+
if iteration > 20 and change < epsilon:
91+
break
92+
self.a, self.b, self.c, self.prior_dis = a, b, c, prior_dis
93+
self.stu_prof = self.transform(self.R)
94+
95+
def eval(self, test_data) -> tuple:
96+
pred_score = irt3pl(np.sum(self.a * (np.expand_dims(self.stu_prof, axis=1) - self.b), axis=-1), 1, 0, self.c)
97+
test_rmse, test_mae = [], []
98+
for i in tqdm(test_data, "evaluating"):
99+
stu, test_id, true_score = i['user_id'], i['item_id'], i['score']
100+
test_rmse.append((pred_score[stu, test_id] - true_score) ** 2)
101+
test_mae.append(abs(pred_score[stu, test_id] - true_score))
102+
return np.sqrt(np.average(test_rmse)), np.average(test_mae)
103+
104+
def save(self, filepath):
105+
with open(filepath, 'wb') as file:
106+
pickle.dump({"a": self.a, "b": self.b, "c": self.c, "prof": self.stu_prof}, file)
107+
logging.info("save parameters to %s" % filepath)
108+
109+
def load(self, filepath):
110+
with open(filepath, 'rb') as file:
111+
self.a, self.b, self.c, self.stu_prof = pickle.load(file).values()
112+
logging.info("load parameters from %s" % filepath)
113+
114+
def inc_train(self, inc_train_data, lr=1e-3, epoch=10, epsilon=1e-3): # incremental training
115+
for i in inc_train_data:
116+
stu, test_id, true_score = i['user_id'], i['item_id'], i['score']
117+
self.R[stu, test_id] = true_score
118+
self.train(lr, epoch, epsilon=epsilon)
119+
120+
def transform(self, records, lr=1e-3, epoch=10, epsilon=1e-3): # MLE for evaluating students' state
121+
# can evaluate multiple students' states simultaneously, thus output shape = (stu_num, dim)
122+
# initialization stu_prof, shape = (stu_num, dim)
123+
if len(records.shape) == 1: # one student
124+
records = np.expand_dims(records, axis=0)
125+
_, prof_stu_like = get_Likelihood(self.a, self.b, self.c, self.prof, records)
126+
stu_prof = self.prof[np.argmax(prof_stu_like, axis=0)]
127+
128+
for iteration in range(epoch):
129+
prof_tmp = np.copy(stu_prof)
130+
ans_prob = irt3pl(np.sum(self.a * (np.expand_dims(stu_prof, axis=1) - self.b), axis=-1), 1, 0, self.c)
131+
ans_1 = self.D * (records - ans_prob) / ans_prob * (ans_prob - self.c) / (1 - self.c + 1e-9)
132+
ans_1[np.where(records == self.skip_value)[0], np.where(records == self.skip_value)[1]] = 0
133+
prof_grad = np.dot(ans_1, self.a)
134+
stu_prof = stu_prof - lr * prof_grad
135+
change = np.max(np.abs(stu_prof - prof_tmp))
136+
if iteration > 5 and change < epsilon:
137+
break
138+
return stu_prof # shape = (stu_num, dim)

EduCDM/IRT/EM/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# coding: utf-8
2+
# 2021/5/2 @ liujiayu
3+
4+
from .IRT import IRT

EduCDM/IRT/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44

55
from .GD import IRT as GDIRT
6+
from .EM import IRT as EMIRT

EduCDM/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .FuzzyCDF import FuzzyCDF
99
from .NCDM import NCDM
1010
from .IRT import GDIRT
11+
from .IRT import EMIRT

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,30 @@ More recent researches about CDMs:
3636
* [NCDM](EduCDM/NCDM) [[doc]](docs/NCDM.md) [[example]](examples/NCDM)
3737
* [FuzzyCDF](EduCDM/FuzzyCDF) [[doc]](docs/FuzzyCDF.md) [[example]](examples/FuzzyCDF)
3838
* [DINA](EduCDM/DINA) [[doc]](docs/DINA.md) [[example]](examples/DINA)
39+
* [IRT](EduCDM/IRT) [[doc]](docs/IRT.md) [[example]](examples/IRT)
40+
* Eexpectation Maximization ([EMIRT](EduCDM/IRT/EM)) [[example]](examples/IRT/EM)
41+
* Gradient Descent ([GDIRT](EduCDM/IRT/GD)) [[example]](examples/IRT/GD)
3942
* [MCD](EduCDM/MCD) [[doc]](docs/MCD.md) [[example]](examples/MCD)
43+
* [IRR](EduCDM/IRR)[[doc]](docs/IRR.md)[[example]](examples/IRR)
44+
* IRR-NCDM
45+
* IRR-DINA
46+
* IRR-IRT
4047

48+
## Installation
49+
50+
Git and install with `pip`:
51+
52+
```
53+
git clone https://github.com/bigdata-ustc/EduCDM.git
54+
cd path/to/code
55+
pip install .
56+
```
57+
58+
Or directly install from pypi:
59+
60+
```
61+
pip install EduCDM
62+
```
4163

4264

4365
## Contribute

docs/IRT.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Item response theory
2+
3+
If the reader wants to know the details of EMIRT, please refer to the paper: *[Estimation for Item Response Models using the EM Algorithm for Finite Mixtures](https://files.eric.ed.gov/fulltext/ED405356.pdf)*.
4+
```bibtex
5+
@article{woodruff1996estimation,
6+
title={Estimation of Item Response Models Using the EM Algorithm for Finite Mixtures.},
7+
author={Woodruff, David J and Hanson, Bradley A},
8+
year={1996},
9+
publisher={ERIC}
10+
}
11+
```

0 commit comments

Comments
 (0)