|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | + |
| 6 | +class CDBinEncoder(): |
| 7 | + def __init__(self, g, r): # g is the original input dimension, and r is the target dimension |
| 8 | + super(object, self).__init__() |
| 9 | + |
| 10 | + self.fix_seed(37) |
| 11 | + |
| 12 | + print('initia parameters... ...') |
| 13 | + self.g = g |
| 14 | + self.r = r |
| 15 | + |
| 16 | + self.V = torch.from_numpy(self.generate_V(g, g * 5)).float().cuda() |
| 17 | + self.normed_V = (self.V / torch.norm(self.V, dim=0).unsqueeze(0)).cuda() |
| 18 | + |
| 19 | + self.P = self.generate_P_svd(self.V, r).float().cuda() |
| 20 | + |
| 21 | + self.V_p = (self.P @ self.V * np.sqrt(r)).float().cuda() |
| 22 | + self.inverse_V_p = torch.pinverse(self.V_p).float().cuda() |
| 23 | + |
| 24 | + def fix_seed(self, seed): |
| 25 | + np.random.seed(seed) |
| 26 | + torch.manual_seed(seed) |
| 27 | + torch.cuda.manual_seed(seed) |
| 28 | + |
| 29 | + def generate_V(self, num_rows, num_cols): |
| 30 | + limit = np.sqrt(2. / (num_rows + num_cols)) |
| 31 | + random_matrix = np.random.normal(loc=0.0, scale=limit, size=(num_rows, num_cols)) |
| 32 | + |
| 33 | + emb_mean = np.mean(random_matrix, axis=0)[None, :] |
| 34 | + random_matrix -= emb_mean |
| 35 | + |
| 36 | + return random_matrix |
| 37 | + |
| 38 | + def generate_P_svd(self, V, r): |
| 39 | + u, sigma, v = torch.svd(V) |
| 40 | + return u[:r, :] |
| 41 | + |
| 42 | + def generate_P(self, g, r): |
| 43 | + limit = np.sqrt(6. / (g + r)) |
| 44 | + random_matrix = np.random.uniform(low=-limit, high=limit, size=(g, r)) |
| 45 | + |
| 46 | + u, sigma, v = np.linalg.svd(random_matrix) |
| 47 | + |
| 48 | + return u[:r, :] |
| 49 | + |
| 50 | + def dcd(self, S, U, V): |
| 51 | + L = U.shape[0] |
| 52 | + Q = (V @ S.t()).cuda() |
| 53 | + |
| 54 | + while True: |
| 55 | + is_update = False |
| 56 | + for i in range(L): |
| 57 | + U_b_prime = torch.cat((U[:i, :], U[i + 1:, :])) |
| 58 | + |
| 59 | + v_p = V[i, :] |
| 60 | + V_p_prime = torch.cat((V[:i, :], V[i + 1:, :])) |
| 61 | + |
| 62 | + q = Q[i, :] |
| 63 | + |
| 64 | + bracket_result = (q - U_b_prime.t() @ V_p_prime @ v_p).cuda() |
| 65 | + |
| 66 | + new_u = bracket_result.sign().cuda() |
| 67 | + new_u[torch.eq(new_u, 0.)] = 1. |
| 68 | + |
| 69 | + if torch.all(torch.eq(new_u, U[i, :])): |
| 70 | + continue |
| 71 | + U[i, :] = new_u |
| 72 | + is_update = True |
| 73 | + |
| 74 | + if not is_update: break |
| 75 | + |
| 76 | + return U.t().cpu().numpy() |
| 77 | + |
| 78 | + def encode(self, X): |
| 79 | + X = torch.from_numpy(X).cuda() |
| 80 | + |
| 81 | + normed_X = (X / torch.norm(X, dim=1).unsqueeze(1)).cuda() |
| 82 | + |
| 83 | + S = (normed_X @ self.normed_V * self.r).cuda() |
| 84 | + |
| 85 | + X_small_code = (S @ self.inverse_V_p).cuda() |
| 86 | + |
| 87 | + return self.dcd(S, X_small_code.t(), self.V_p) |