Skip to content

Commit b1f8910

Browse files
committed
Implemented gen_special_linear
1 parent aa977e9 commit b1f8910

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

Util/Util.py

+26
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,32 @@ def gen_noisy_poly(size=10000, p=3, n_dim=100, n_valid=5, noise_scale=0.5, test_
252252
return (x_train_noise, y_train), (x_test, y_test)
253253
return (x_train_noise, DataUtil.get_one_hot(y_train, 2)), (x_test, DataUtil.get_one_hot(y_test, 2))
254254

255+
@staticmethod
def gen_special_linear(size=10000, n_dim=10, n_redundant=3, n_categorical=3,
                       cv_ratio=0.15, test_ratio=0.15, one_hot=True):
    """Generate a linearly separable binary dataset padded with nuisance columns.

    Every sample has ``n_dim + n_redundant + n_categorical`` columns, stacked
    in this order:

    * ``n_dim`` columns drawn from a standard normal distribution;
    * ``n_redundant`` columns that are constant within a split (one random
      integer per column, drawn from ``[0, 3)`` for train/cv and from
      ``[3, 6)`` for test);
    * ``n_categorical`` integer-valued columns drawn from ``[3, 8)`` for
      train/cv and from ``[0, 5)`` for test.

    Labels depend only on the first ``n_dim`` columns: a sample is labelled
    ``1`` when its projection onto a random weight vector is positive and
    ``0`` otherwise.

    :param size: number of rows generated for train + cv together.
    :param n_dim: number of informative (Gaussian) columns.
    :param n_redundant: number of constant columns.
    :param n_categorical: number of random integer columns.
    :param cv_ratio: fraction of ``size`` held out (from the tail) as cv data.
    :param test_ratio: the test set has ``int(size * test_ratio)`` rows; they
        are generated in addition to the ``size`` train/cv rows.
    :param one_hot: if true, labels are passed through
        ``DataUtil.get_one_hot(y, 2)``; otherwise they are returned as flat
        ``int8`` arrays.
    :return: ``((x_train, y_train), (x_cv, y_cv), (x_test, y_test))``.
    """
    # NOTE: the order of the np.random calls is kept stable so that results
    # obtained after np.random.seed(...) stay reproducible.
    x_train = np.random.randn(size, n_dim)
    x_train_redundant = np.ones([size, n_redundant]) * np.random.randint(0, 3, n_redundant)
    x_train_categorical = np.random.randint(3, 8, [size, n_categorical])
    x_train_stacked = np.hstack([x_train, x_train_redundant, x_train_categorical])

    # NOTE(review): the test split draws its redundant / categorical columns
    # from different ranges than the training split -- presumably on purpose,
    # to exercise models on shifted nuisance features; confirm with the author.
    n_test = int(size * test_ratio)
    x_test = np.random.randn(n_test, n_dim)
    x_test_redundant = np.ones([n_test, n_redundant]) * np.random.randint(3, 6, n_redundant)
    x_test_categorical = np.random.randint(0, 5, [n_test, n_categorical])
    x_test_stacked = np.hstack([x_test, x_test_redundant, x_test_categorical])

    # Labels come from the informative columns only.
    w = np.random.randn(n_dim, 1)
    y_train = (x_train.dot(w) > 0).astype(np.int8).ravel()
    y_test = (x_test.dot(w) > 0).astype(np.int8).ravel()

    # Split at an explicit index instead of slicing with ``-n_cv``: when
    # n_cv == 0, ``a[:-0]`` is empty and ``a[-0:]`` is the whole array, which
    # would hand all data to the cv split and leave the training set empty.
    n_cv = int(size * cv_ratio)
    n_train = max(size - n_cv, 0)
    x_train_stacked, x_cv_stacked = x_train_stacked[:n_train], x_train_stacked[n_train:]
    y_train, y_cv = y_train[:n_train], y_train[n_train:]

    if not one_hot:
        return (x_train_stacked, y_train), (x_cv_stacked, y_cv), (x_test_stacked, y_test)
    return (
        (x_train_stacked, DataUtil.get_one_hot(y_train, 2)),
        (x_cv_stacked, DataUtil.get_one_hot(y_cv, 2)),
        (x_test_stacked, DataUtil.get_one_hot(y_test, 2))
    )
280+
255281
@staticmethod
256282
def quantize_data(x, y, wc=None, continuous_rate=0.1, separate=False):
257283
if isinstance(x, list):

0 commit comments

Comments
 (0)