diff --git a/config.py b/config.py new file mode 100644 index 0000000..170ef5f --- /dev/null +++ b/config.py @@ -0,0 +1,53 @@ +import argparse +from utils.train_util import add_flags_from_config + +config_args = { + 'data_config': { + 'seed': (2022, 'manual random seed'), + 'cuda': (0, 'which cuda device to use'), + 'dataset': ('20ng', 'which dataset to use'), + 'data-path': ('./data/20NG/20ng.pkl', 'path to load data'), + 'batch-size': (200, 'number of examples in a mini-batch'), + }, + 'training_config': { + 'epochs': (200, 'number of epochs to train'), + 'lr': (0.01, 'initial learning rate'), + 'lr-reduce-freq': (None, 'reduce lr every lr-reduce-freq or None to keep lr constant'), + 'gamma': (0.1, 'gamma for lr scheduler'), + 'dropout': (0, 'dropout probability'), + 'momentum': (0.999, 'momentum in optimizer'), + 'weight-decay': (1e-5, 'l2 regularization strength'), + 'optimizer': ('Adam', 'which optimizer to use, can be any of [Adam, RiemannianAdam]'), + 'grad-clip': (None, 'max norm for gradient clipping, or None for no gradient clipping'), + 'log-freq': (1, 'how often to compute print train/val metrics (in epochs)'), + 'eval-freq': (50, 'how often to compute val metrics (in epochs)'), + 'save': (1, '1 to save model and logs and 0 otherwise'), + 'save-dir': (None, 'path to save training logs and model weights (defaults to logs/task/date/run/)'), + 'sweep-c': (0, ''), + 'print-epoch': (True, ''), + }, + 'model_config': { + # topic model + 'vocab-size': (2000, 'vocabulary size'), # 13368 + 'embed-size': (2, 'dimensionality of word and topic embeddings'), # (683, 366, 84, 11, 2) # (810, 408, 91, 13, 2) + 'num-topics-list': ([185, 66, 11, 2], 'number of topics in each latent layer'), # (560, 325, 83, 12, 2) + 'num-hiddens-list': ([300, 300, 300, 300], 'number of units in each hidden layer'), + 'pretrained-embeddings': (False, 'whether to use pretrained embeddings to initialize words and topics'), + 'manifold': ('PoincareBall', 'which manifold to use, can be any of [Euclidean, Hyperboloid, PoincareBall]'), + 'c': (-1.0, 'hyperbolic radius, set to None for trainable curvature'), + 'clip_r': (8.0, 'avoid the vanishing gradients problem'), + # hyperbolic gcn + 'add-knowledge': (True, 'whether inject prior knowledge to topic modeling'), + 'file-path': ('./data/20NG/20ng_wordnet_tree_4layers.pkl', 'path to load tree knowledge'), + 'gcn-layers': (2, 'number of hidden layers in graph encoder'), + 'bias': (1, 'whether to use bias (1) or not (0)'), + 'use-att': (1, 'whether to use hyperbolic attention or not'), + 'local-agg': (0, 'whether to local tangent space aggregation or not'), + 'act': ('relu', 'which activation function to use (or None for no activation)'), + 'double-precision': ('0', 'whether to use double precision'), + }, +} + +parser = argparse.ArgumentParser() +for _, config_dict in config_args.items(): + parser = add_flags_from_config(parser, config_dict) diff --git a/manifolds/__init__.py b/manifolds/__init__.py new file mode 100644 index 0000000..fa7e785 --- /dev/null +++ b/manifolds/__init__.py @@ -0,0 +1,3 @@ +from .euclidean import Euclidean +from .hyperboloid import Hyperboloid +from .poincare import PoincareBall diff --git a/manifolds/base.py b/manifolds/base.py new file mode 100644 index 0000000..292904f --- /dev/null +++ b/manifolds/base.py @@ -0,0 +1,170 @@ +"""Base manifold.""" + +import torch + + +class Manifold(object): + """Abstract class to define basic operations on a manifold. + + Attributes: + clip (function): Clips tensor values to a specified min and max. 
+ dtype: The type of the variables. + eps (float): A small constant value. + max_norm (float): The maximum value for number clipping. + min_norm (float): The minimum value for number clipping. + """ + + def __init__(self, **kwargs): + """Initialize a manifold. + """ + super(Manifold, self).__init__() + + self.min_norm = 1e-15 + self.max_norm = 1e15 + self.eps = 1e-5 + + self.dtype = kwargs["dtype"] if "dtype" in kwargs else torch.float32 + self.clip = lambda x: torch.clamp(x, min=self.min_norm, max=self.max_norm) + + def proj(self, x, c): + """A projection function that prevents x from leaving the manifold. + Args: + x (tensor): A point should be on the manifold, but it may not meet the manifold constraints. + c (tensor): The manifold curvature. + Returns: + tensor: A projected point, meeting the manifold constraints. + """ + raise NotImplementedError + + def proj_tan(self, v, x, c): + """A projection function that prevents v from leaving the tangent space of point x. + Args: + v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: A projected point, meeting the tangent space constraints. + """ + raise NotImplementedError + + def proj_tan0(self, v, c): + """A projection function that prevents v from leaving the tangent space of origin point. + Args: + v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints. + c (tensor): The manifold curvature. + Returns: + tensor: A projected point, meeting the tangent space constraints. + """ + raise NotImplementedError + + def expmap(self, v, x, c): + """Map a point v in the tangent space of point x to the manifold. + Args: + v (tensor): A point in the tangent space of point x. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of mapping tangent point v to the manifold. + """ + raise NotImplementedError + + def expmap0(self, v, c): + """Map a point v in the tangent space of origin point to the manifold. + Args: + v (tensor): A point in the tangent space of origin point. + c (tensor): The manifold curvature. + Returns: + tensor: The result of mapping tangent point v to the manifold. + """ + raise NotImplementedError + + def logmap(self, y, x, c): + """Map a point y on the manifold to the tangent space of x. + Args: + y (tensor): A point on the manifold. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of mapping y to the tangent space of x. + """ + raise NotImplementedError + + def logmap0(self, y, c): + """Map a point y on the manifold to the tangent space of origin point. + Args: + y (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of mapping y to the tangent space of origin point. + """ + raise NotImplementedError + + def ptransp(self, v, x, y, c): + """Parallel transport function, used to move point v in the tangent space of x to the tangent space of y. + Args: + v (tensor): A point in the tangent space of x. + x (tensor): A point on the manifold. + y (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of transporting v from the tangent space at x to the tangent space at y. 
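+            (On the Poincare ball this is realized with the gyration operator:
+            PoincareBall.ptransp returns gyr[y, -x]v scaled by lambda_x / lambda_y,
+            the ratio of the conformal factors at x and y.)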
+ """ + raise NotImplementedError + + def ptransp0(self, v, x, c): + """Parallel transport function, used to move point v in the tangent space of origin point to the tangent space of y. + Args: + v (tensor): A point in the tangent space of origin point. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of transporting v from the tangent space at origin point to the tangent space at y. + """ + raise NotImplementedError + + def dist(self, x, y, c): + """Calculate the squared geodesic/distance between x and y. + Args: + x (tensor): A point on the manifold. + y (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: the geodesic/distance between x and y. + """ + raise NotImplementedError + + def egrad2rgrad(self, grad, x, c): + """Computes Riemannian gradient from the Euclidean gradient, typically used in Riemannian optimizers. + Args: + grad (tensor): Euclidean gradient at x. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: Riemannian gradient at x. + """ + raise NotImplementedError + + def inner(self, v1, v2, x, c, keep_shape): + """Computes the inner product of a pair of tangent vectors v1 and v2 at x. + Args: + v1 (tensor): A tangent point at x. + v2 (tensor): A tangent point at x. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + keep_shape (bool, optional): Whether the output tensor keeps shape or not. + Returns: + tensor: The inner product of v1 and v2 at x. + """ + raise NotImplementedError + + def retraction(self, v, x, c): + """Retraction is a continuous map function from tangent space to the manifold, typically used in Riemannian optimizers. + The exp map is one of retraction functions. + Args: + v (tensor): A tangent point at x. + x (tensor): A point on the manifold. + c (tensor): The manifold curvature. + Returns: + tensor: The result of mapping tangent point v at x to the manifold. + """ + return self.proj(self.expmap(v, x, c), c) diff --git a/manifolds/euclidean.py b/manifolds/euclidean.py new file mode 100644 index 0000000..13e5fb0 --- /dev/null +++ b/manifolds/euclidean.py @@ -0,0 +1,61 @@ +"""Euclidean manifold.""" + +import torch +from manifolds.base import Manifold + + +class Euclidean(Manifold): + """Euclidean Manifold class. Usually we refer it as R^n. + + Attributes: + name (str): The manifold name, and its value is "Euclidean". + """ + + def __init__(self, **kwargs): + """Initialize an Euclidean manifold. 
+ Args: + **kwargs: Description + """ + super(Euclidean, self).__init__(**kwargs) + self.name = 'Euclidean' + + def proj(self, x, c): + return x + + def proj_tan(self, v, x, c): + return v + + def proj_tan0(self, v, c): + return v + + def expmap(self, v, x, c): + return v + x + + def expmap0(self, v, c): + return v + + def logmap(self, y, x, c): + return y - x + + def logmap0(self, y, c): + return y + + def ptransp(self, v, x, y, c): + return torch.ones_like(x) * v + + def ptransp0(self, v, x, c): + return torch.ones_like(x) * v + + def dist(self, x, y, c): + sqdis = torch.sum((x - y).pow(2), dim=-1) + return sqdis.sqrt() + + def egrad2rgrad(self, grad, x, c): + return grad + + def inner(self, v1, v2, x, c, keep_shape=False): + if keep_shape: + # In order to keep the same computation logic in Ada* Optimizer + return v1 * v2 + else: + return torch.sum(v1 * v2, dim=-1, keepdim=False) diff --git a/manifolds/poincare.py b/manifolds/poincare.py new file mode 100644 index 0000000..3d2df8d --- /dev/null +++ b/manifolds/poincare.py @@ -0,0 +1,129 @@ +"""Poincare ball manifold.""" + +import torch +from manifolds.base import Manifold +from utils.math_util import TanC, ArTanC + + +class PoincareBall(Manifold): + """PoicareBall Manifold class. + + We use the following convention: x0^2 + x1^2 + ... + xd^2 < 1 / c. (c < 0) + + So that the Poincare ball radius will be 1 / sqrt(-c). + Notice that the more close c is to 0, the more flat space will be. + """ + + def __init__(self, ): + super(PoincareBall, self).__init__() + self.name = 'PoincareBall' + self.truncate_c = lambda x: torch.clamp(x, min=-1e5, max=-1e-5) + + def proj(self, x, c): + c = self.truncate_c(c) + x_norm = self.clip(x.norm(dim=-1, keepdim=True, p=2)) + max_norm = (1 - self.eps) / c.abs().sqrt() + cond = x_norm > max_norm + projected = x / x_norm * max_norm + return torch.where(cond, projected, x) + + def proj_tan(self, v, x, c): + return v + + def proj_tan0(self, v, c): + return v + + def expmap(self, v, x, c): + c = self.truncate_c(c) + v_norm = self.clip(v.norm(p=2, dim=-1, keepdim=True)) + second_term = TanC(self._lambda_x(x, c) * v_norm / 2.0, c) * v / v_norm + gamma = self._mobius_add(x, second_term, c) + return gamma + + def expmap0(self, v, c): + c = self.truncate_c(c) + v_norm = self.clip(v.norm(p=2, dim=-1, keepdim=True)) + gamma = TanC(v_norm, c) * v / v_norm + return gamma + + def logmap(self, y, x, c): + c = self.truncate_c(c) + sub = self._mobius_add(-x, y, c) + sub_norm = self.clip(sub.norm(p=2, dim=-1, keepdim=True)) + lam = self._lambda_x(x, c) + return 2.0 / lam * ArTanC(sub_norm, c) * sub / sub_norm + + def logmap0(self, y, c): + c = self.truncate_c(c) + y_norm = self.clip(y.norm(p=2, axis=-1, keepdim=True)) + return ArTanC(y_norm, c) * y / y_norm + + def ptransp(self, v, x, y, c): + c = self.truncate_c(c) + lambda_x = self._lambda_x(x, c) + lambda_y = self._lambda_x(y, c) + return self._gyration(y, -x, v, c) * lambda_x / lambda_y + + def ptransp0(self, v, x, c): + c = self.truncate_c(c) + lambda_x = self._lambda_x(x, c) + return torch.tensor(2.0, dtype=self.dtype) * v / lambda_x + + def dist(self, x, y, c): + c = self.truncate_c(c) + return 2.0 * ArTanC(self._mobius_add(-x, y, c).norm(p=2, dim=-1), c) + + def egrad2rgrad(self, grad, x, c): + c = self.truncate_c(c) + metric = torch.square(self._lambda_x(x, c)) + return grad / metric + + def inner(self, v1, v2, x, c, keep_shape=False): + c = self.truncate_c(c) + metric = torch.square(self._lambda_x(x, c)) + product = v1 * metric * v2 + res = product.sum(dim=-1, 
keepdim=True) + if keep_shape: + # return tf.broadcast_to(res, x.shape) + last_dim = x.shape.as_list()[-1] + return torch.cat([res for _ in range(last_dim)], dim=-1) + return torch.squeeze(res, dim=-1) + + def retraction(self, v, x, c): + c = self.truncate_c(c) + new_v = self.expmap(v, x, c) + return self.proj(new_v, c) + + def _mobius_add(self, x, y, c): + c = self.truncate_c(c) + x2 = x.pow(2).sum(dim=-1, keepdim=True) + y2 = y.pow(2).sum(dim=-1, keepdim=True) + xy = (x * y).sum(dim=-1, keepdim=True) + num = (1 - 2 * c * xy - c * y2) * x + (1 + c * x2) * y + denom = 1 - 2 * c * xy + (c ** 2) * x2 * y2 + return num / self.clip(denom) + + def _mobius_mul(self, x, a, c): + c = self.truncate_c(c) + x_norm = self.clip(x.norm(p=2, dim=-1, keepdim=True)) + scale = TanC(a * ArTanC(x_norm, c), c) / x_norm + return scale * x + + def _mobius_matvec(self, x, a, c): + c = self.truncate_c(c) + x_norm = self.clip(x.norm(p=2, dim=-1, keepdim=True)) + mx = torch.matmul(x, a) + mx_norm = self.clip(mx.norm(p=2, dim=-1, keepdim=True)) + res = TanC(mx_norm / x_norm * ArTanC(x_norm, c), c) * mx / mx_norm + return res + + def _lambda_x(self, x, c): + c = self.truncate_c(c) + x_sqnorm = x.pow(2).sum(dim=-1, keepdim=True) + return self.clip(2.0 / (1.0 + c * x_sqnorm)) + + def _gyration(self, x, y, v, c): + xy = self._mobius_add(x, y, c) + yv = self._mobius_add(y, v, c) + xyv = self._mobius_add(x, yv, c) + return self._mobius_add(-xy, xyv, c) diff --git a/models/block.py b/models/block.py new file mode 100644 index 0000000..cd51c89 --- /dev/null +++ b/models/block.py @@ -0,0 +1,41 @@ +import torch.nn as nn + + +def _get_activation_fn(activation): + if activation == "relu": + return nn.ReLU() + elif activation == "softplus": + return nn.Softplus() + elif activation == "tanh": + return nn.Tanh() + else: + raise RuntimeError("activation should be relu/tanh/softplus, not {}".format(activation)) + + +class ResBlock(nn.Module): + """Simple MLP block with residual connection. + + Args: + in_features: the feature dimension of each output sample. + out_features: the feature dimension of each output sample. + activation: the activation function of intermediate layer, relu or gelu. + """ + + def __init__(self, in_features, out_features, activation="relu"): + super(ResBlock, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.fc1 = nn.Linear(in_features, out_features) + self.fc2 = nn.Linear(out_features, out_features) + + self.bn = nn.BatchNorm1d(out_features) + self.activation = _get_activation_fn(activation) + + def forward(self, x): + if self.in_features == self.out_features: + out = self.fc2(self.activation(self.fc1(x))) + return self.activation(self.bn(x + out)) + else: + x = self.fc1(x) + out = self.fc2(self.activation(x)) + return self.activation(self.bn(x + out)) diff --git a/models/etm.py b/models/etm.py new file mode 100644 index 0000000..de7c559 --- /dev/null +++ b/models/etm.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +from models.block import ResBlock, _get_activation_fn + + +class ETM(nn.Module): + """Simple implementation of the <> + + Args + args: the set of arguments used to characterize the hierarchical neural topic model. + device: the physical hardware that the model is trained on. + pretrained_embeddings: if not None, initialize each word embedding in the vocabulary with pretrained Glove embeddings. 
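+        The decoder computes the topic-word matrix as phi = softmax(rho @ alpha^T, dim=0)
+        and reconstructs a document from theta @ phi^T, where theta is the softmax of a
+        reparameterized Gaussian sample produced by the encoder.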
+ """ + + def __init__(self, args, device, word_embeddings): + super(ETM, self).__init__() + self.device = device + + # hyper-parameters + self.embed_size = args.embed_size + self.vocab_size = args.vocab_size + self.num_topics = args.num_topics_list[0] + self.num_hiddens = args.num_hiddens_list[0] + + # learnable word embeddings + if word_embeddings is not None: + self.rho = nn.Parameter(torch.from_numpy(word_embeddings).float()) + else: + self.rho = nn.Parameter( + torch.empty(args.vocab_size, args.embed_size).normal_(std=0.02)) + + # topic embeddings for different latent layers + self.alpha = nn.Parameter( + torch.empty(self.num_topics, args.embed_size).normal_(std=0.02)) + + # deterministic mapping to obtain hidden features + self.h_encoder = ResBlock(self.vocab_size, self.num_hiddens, args.act) + + # variational encoder to obtain posterior parameters + self.q_theta = nn.Linear(self.num_hiddens, 2 * self.num_topics) + + def reparameterize(self, mu, logvar): + """Returns a sample from a Gaussian distribution via reparameterization. + """ + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul_(std).add_(mu) + + def kl_normal_normal(self, mu_theta, logsigma_theta): + """Returns the Kullback-Leibler divergence between a normal distribution and a standard normal distribution. + """ + kl_div = -0.5 * torch.sum( + 1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1 + ) + return kl_div.mean() + + def get_phi(self): + """Derives the topic-word matrix by computing the inner product. + """ + dist = torch.mm(self.rho, self.alpha.transpose(0, 1)) + phi = torch.softmax(dist, dim=0) + return phi + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. + """ + denorm = torch.where(x.sum(dim=1, keepdims=True) > 0, x.sum(dim=1, keepdims=True), torch.tensor([1.]).cuda()) + norm_x = x / denorm + + hidden_feats = self.h_encoder(norm_x) + mu, logvar = torch.chunk(self.q_theta(hidden_feats), 2, dim=1) + kl_loss = self.kl_normal_normal(mu, logvar) + if is_training: + theta = torch.softmax(self.reparameterize(mu, logvar), dim=1) + else: + theta = torch.softmax(mu, dim=1) + + # ================================================================================= + phi = self.get_phi() + logit = torch.mm(theta, phi.t()) + almost_zeros = torch.full_like(logit, 1e-6) + results_without_zeros = logit.add(almost_zeros) + predictions = torch.log(results_without_zeros) + recon_loss = -(predictions * x).sum(1).mean() + + nelbo = recon_loss + kl_loss + return nelbo, recon_loss, kl_loss, theta + + +class VanillaETM(nn.Module): + + def __init__(self, args, device, word_embeddings): + super(VanillaETM, self).__init__() + self.device = device + + # hyper-parameters + self.embed_size = args.embed_size + self.vocab_size = args.vocab_size + self.num_topics = args.num_topics_list[0] + self.num_hiddens = args.num_hiddens_list[0] + + # learnable word embeddings + if word_embeddings is not None: + self.rho = nn.Parameter(torch.from_numpy(word_embeddings).float()) + else: + self.rho = nn.Linear(args.embed_size, args.vocab_size, bias=False) + + # topic embeddings for different latent layers + self.alpha = nn.Linear(args.embed_size, self.num_topics, bias=False) + + # deterministic mapping to obtain hierarchical features + self.activation = _get_activation_fn(args.act) + self.h_encoder = nn.Sequential( + nn.Linear(self.vocab_size, self.num_hiddens), + self.activation, + nn.Linear(self.num_hiddens, self.num_hiddens), + self.activation, + 
nn.Dropout(args.dropout) + ) + + # variational encoder to obtain posterior parameters + self.q_theta = nn.Linear(self.num_hiddens, 2 * self.num_topics) + + def reparameterize(self, mu, logvar): + """Returns a sample from a Gaussian distribution via reparameterization. + """ + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul_(std).add_(mu) + + def kl_normal_normal(self, mu_theta, logsigma_theta): + """Returns the Kullback-Leibler divergence between a normal distribution and a standard normal distribution. + """ + kl_div = -0.5 * torch.sum( + 1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1 + ) + return kl_div.mean() + + def get_phi(self): + """Derives the topic-word matrix by computing the inner product. + """ + dist = self.alpha(self.rho.weight) + phi = torch.softmax(dist, dim=0) + return phi + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. + """ + denorm = torch.where(x.sum(dim=1, keepdims=True) > 0, x.sum(dim=1, keepdims=True), torch.tensor([1.]).cuda()) + norm_x = x / denorm + + hidden_feats = self.h_encoder(norm_x) + mu, logvar = torch.chunk(self.q_theta(hidden_feats), 2, dim=1) + kl_loss = self.kl_normal_normal(mu, logvar) + if is_training: + theta = torch.softmax(self.reparameterize(mu, logvar), dim=1) + else: + theta = torch.softmax(mu, dim=1) + + # ================================================================================= + phi = self.get_phi() + logit = torch.mm(theta, phi.t()) + almost_zeros = torch.full_like(logit, 1e-6) + results_without_zeros = logit.add(almost_zeros) + predictions = torch.log(results_without_zeros) + recon_loss = -(predictions * x).sum(1).mean() + + nelbo = recon_loss + kl_loss + return nelbo, recon_loss, kl_loss, theta diff --git a/models/hyperminer.py b/models/hyperminer.py new file mode 100644 index 0000000..7f9fe53 --- /dev/null +++ b/models/hyperminer.py @@ -0,0 +1,295 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import manifolds +from models.etm import ETM +from models.sawetm import SawETM + + +class HyperETM(ETM): + """A variant of ETM that embeds words and topic into hyperbolic space to measure their distance + """ + + def __init__(self, args, device, word_embeddings): + super(HyperETM, self).__init__(args, device, word_embeddings) + self.manifold = getattr(manifolds, args.manifold)() + if args.c is not None: + self.curvature = torch.tensor([args.c]) + self.curvature = self.curvature.to(device) + else: + self.curvature = nn.Parameter(torch.Tensor([-1.])) + + # the effective radius used to clip Euclidean features + self.clip_r = args.clip_r + + def feat_clip(self, x): + """Use feature clipping technique to avoid the gradient vanishing problem""" + x_norm = x.norm(p=2, dim=-1, keepdim=True) + cond = x_norm > self.clip_r + projected = x / x_norm * self.clip_r + return torch.where(cond, projected, x) + + def get_phi(self): + """Derives the topic-word matrix by the distance in hyperbolic space. 
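+        Words and topics are lifted onto the manifold via expmap0, the (vocab_size, num_topics)
+        matrix of pairwise hyperbolic distances is computed, and phi is obtained with a softmax
+        over the vocabulary dimension (dim=0), i.e. phi[:, k] = softmax_v(-dist(rho_v, alpha_k)).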
+ """ + hyp_rho = self.manifold.proj( + # self.manifold.expmap0(self.feat_clip(self.rho), self.curvature), + self.manifold.expmap0(self.rho, self.curvature), + self.curvature + ) + hyp_alpha = self.manifold.proj( + # self.manifold.expmap0(self.feat_clip(self.alpha), self.curvature), + self.manifold.expmap0(self.alpha, self.curvature), + self.curvature + ) + return torch.softmax(-self.manifold.dist( + hyp_rho.unsqueeze(1), hyp_alpha.unsqueeze(0), self.curvature + ), dim=0) + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. + """ + denorm = torch.where(x.sum(dim=1, keepdims=True) > 0, x.sum(dim=1, keepdims=True), torch.tensor([1.]).cuda()) + norm_x = x / denorm + + hidden_feats = self.h_encoder(norm_x) + mu, logvar = torch.chunk(self.q_theta(hidden_feats), 2, dim=1) + kl_loss = self.kl_normal_normal(mu, logvar) + if is_training: + theta = torch.softmax(self.reparameterize(mu, logvar), dim=1) + else: + theta = torch.softmax(mu, dim=1) + + # ================================================================================= + phi = self.get_phi() + logit = torch.mm(theta, phi.t()) + almost_zeros = torch.full_like(logit, 1e-6) + results_without_zeros = logit.add(almost_zeros) + predictions = torch.log(results_without_zeros) + recon_loss = -(predictions * x).sum(1).mean() + + nelbo = recon_loss + 0.5 * kl_loss + return nelbo, recon_loss, kl_loss, theta + + +class HyperMiner(SawETM): + """A variant of SawETM that embeds words and topic into hyperbolic space to measure their distance + """ + + def __init__(self, args, device, word_embeddings): + super(HyperMiner, self).__init__(args, device, word_embeddings) + self.manifold = getattr(manifolds, args.manifold)() + if args.c is not None: + self.curvature = torch.tensor([args.c]) + self.curvature = self.curvature.to(device) + else: + self.curvature = nn.Parameter(torch.Tensor([-1.])) + + self.clip_r = args.clip_r + + def feat_clip(self, x): + x_norm = x.norm(p=2, dim=-1, keepdim=True) + cond = x_norm > self.clip_r + projected = x / x_norm * self.clip_r + return torch.where(cond, projected, x) + + def get_phi(self): + """Returns the factor loading matrix by utilizing sawtooth connection. + """ + phis = [] + for n in range(self.num_layers): + if n == 0: + hyp_rho = self.manifold.proj( + self.manifold.expmap0(self.rho, self.curvature), self.curvature) + hyp_alpha = self.manifold.proj( + self.manifold.expmap0(self.alpha[n], self.curvature), self.curvature) + phi = torch.softmax(-self.manifold.dist( + hyp_rho.unsqueeze(1), hyp_alpha.unsqueeze(0), self.curvature), dim=0) + else: + hyp_alpha1 = self.manifold.proj( + self.manifold.expmap0(self.alpha[n - 1], self.curvature), self.curvature) + hyp_alpha2 = self.manifold.proj( + self.manifold.expmap0(self.alpha[n], self.curvature), self.curvature) + phi = torch.softmax(-self.manifold.dist( + hyp_alpha1.unsqueeze(1).detach(), hyp_alpha2.unsqueeze(0), self.curvature), dim=0) + phis.append(phi) + return phis + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. 
+ """ + hidden_feats = [] + for n in range(self.num_layers): + if n == 0: + hidden_feats.append(self.h_encoder[n](x)) + else: + hidden_feats.append(self.h_encoder[n](hidden_feats[-1])) + + # ================================================================================= + phis = self.get_phi() + + ks = [] + lambs = [] + thetas = [] + phi_by_theta_list = [] + for n in range(self.num_layers - 1, -1, -1): + if n == self.num_layers - 1: + joint_feat = hidden_feats[n] + else: + joint_feat = torch.cat((hidden_feats[n], phi_by_theta_list[0]), dim=1) + + k, lamb = torch.chunk(F.softplus(self.q_theta[n](joint_feat)), 2, dim=1) + k = torch.clamp(k, self.wei_shape_min.item(), self.wei_shape_max.item()) + lamb = torch.clamp(lamb, self.real_min.item()) + + if is_training: + lamb = lamb / torch.exp(torch.lgamma(1 + 1 / k)) + theta = self.reparameterize(k, lamb, sample_num=3) if n == 0 else self.reparameterize(k, lamb) + else: + theta = torch.min(lamb, self.theta_max) + + phi_by_theta = torch.mm(theta, phis[n].t()) + phi_by_theta_list.insert(0, phi_by_theta) + thetas.insert(0, theta) + lambs.insert(0, lamb) + ks.insert(0, k) + + # ================================================================================= + nll = self.get_nll(x, phi_by_theta_list[0]) + + kl_loss = [] + for n in range(self.num_layers): + if n == self.num_layers - 1: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], self.gam_prior, self.gam_prior)) + else: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], phi_by_theta_list[n + 1], self.gam_prior)) + + nelbo = nll + sum(kl_loss) + return nelbo, nll, sum(kl_loss), thetas + + +class HyperMinerKG(HyperMiner): + """An improved version of HyperMiner that injects external knowledge to guide the learning of topic taxonomy + """ + + def __init__(self, args, device, word_embeddings, adjacent_mat): + super(HyperMinerKG, self).__init__(args, device, word_embeddings) + self.manifold = getattr(manifolds, args.manifold)() + if args.c is not None: + self.curvature = torch.tensor([args.c]) + self.curvature = self.curvature.to(device) + else: + self.curvature = nn.Parameter(torch.Tensor([-1.])) + + self.clip_r = args.clip_r + self.adj = adjacent_mat.to(device) + self.split_sections = self.num_topics_list[:: -1] + [self.vocab_size] + self.inp_embeddings = nn.Parameter( + torch.empty(sum(self.split_sections), args.embed_size).normal_(std=0.02)) + + del self.rho, self.alpha + self.temp = 1.0 + + def get_phi(self): + """Returns the factor loading matrix according to hyperbolic distance. + """ + hyp_embeddings = self.manifold.proj( + self.manifold.expmap0(self.inp_embeddings, self.curvature), self.curvature) + hyp_embeddings[0] = torch.tensor([0.35, 0.35]).to(self.device) + hyp_embeddings[1] = torch.tensor([-0.35, -0.35]).to(self.device) + + N = sum(self.num_topics_list) + dist_mat = self.manifold.dist( + hyp_embeddings[: N].unsqueeze(1), + hyp_embeddings.unsqueeze(0), + self.curvature + ) + + phis = [] + for n in range(self.num_layers): + x_start = sum(self.split_sections[: self.num_layers - n - 1]) + x_end = sum(self.split_sections[: self.num_layers - n]) + y_start = sum(self.split_sections[: self.num_layers - n]) + y_end = sum(self.split_sections[: self.num_layers - n + 1]) + phi = torch.softmax( + -dist_mat[x_start: x_end, y_start: y_end].t(), dim=0) + phis.append(phi) + return phis, dist_mat + + def contrastive_loss(self, dist_mat, K=256): + """Hyperbolic contrastive loss to maintain the semantic structure information. 
+ """ + N = sum(self.num_topics_list) + adj_dense = self.adj.to_dense() + neg_adj = torch.ones_like(adj_dense) - adj_dense + pos_loss = torch.exp(-(adj_dense[: N] * dist_mat).max(1)[0] / self.temp) + # pos_loss = torch.exp(-(adj_dense[: N] * dist_mat) / self.temp).sum(1) + + # 1. sampling K negatives from the set of non-first-order neighbors + neg_dist = (neg_adj[: N] * dist_mat) + neg_dist = torch.where(neg_dist > 1e-6, neg_dist, torch.tensor(1000, dtype=torch.float32).to(self.device)) + neg_loss = torch.exp(-neg_dist.topk(K, dim=1, largest=False)[0] / self.temp).sum(1) + + # 2. consider all non-first-order neighbors as negatives + # neg_loss = (neg_adj[: N] * torch.exp(-dist_mat / self.temp)).sum(1) + + nce_loss = torch.log(pos_loss + neg_loss + self.real_min) - torch.log(pos_loss + self.real_min) + return nce_loss.mean() + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. + """ + hidden_feats = [] + for n in range(self.num_layers): + if n == 0: + hidden_feats.append(self.h_encoder[n](x)) + else: + hidden_feats.append(self.h_encoder[n](hidden_feats[-1])) + + # ================================================================================= + phis, dist_mat = self.get_phi() + contrast_loss = self.contrastive_loss(dist_mat) + + ks = [] + lambs = [] + thetas = [] + phi_by_theta_list = [] + for n in range(self.num_layers-1, -1, -1): + if n == self.num_layers - 1: + joint_feat = hidden_feats[n] + else: + joint_feat = torch.cat((hidden_feats[n], phi_by_theta_list[0]), dim=1) + + k, lamb = torch.chunk(F.softplus(self.q_theta[n](joint_feat)), 2, dim=1) + k = torch.clamp(k, self.wei_shape_min.item(), self.wei_shape_max.item()) + lamb = torch.clamp(lamb, self.real_min.item()) + + if is_training: + lamb = lamb / torch.exp(torch.lgamma(1 + 1 / k)) + theta = self.reparameterize(k, lamb, sample_num=5) if n == 0 else self.reparameterize(k, lamb) + else: + theta = torch.min(lamb, self.theta_max) + + phi_by_theta = torch.mm(theta, phis[n].t()) + phi_by_theta_list.insert(0, phi_by_theta) + thetas.insert(0, theta) + lambs.insert(0, lamb) + ks.insert(0, k) + + # ================================================================================= + nll = self.get_nll(x, phi_by_theta_list[0]) + + kl_loss = [] + for n in range(self.num_layers): + if n == self.num_layers - 1: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], self.gam_prior, self.gam_prior)) + else: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], phi_by_theta_list[n+1], self.gam_prior)) + + nelbo = nll + 0.2 * sum(kl_loss) + 5 * contrast_loss + return nelbo, nll, contrast_loss, thetas diff --git a/models/sawetm.py b/models/sawetm.py new file mode 100644 index 0000000..b509d17 --- /dev/null +++ b/models/sawetm.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.block import ResBlock + + +class SawETM(nn.Module): + """Simple implementation of the <> + + Args + args: the set of arguments used to characterize the hierarchical neural topic model. + device: the physical hardware that the model is trained on. + pretrained_embeddings: if not None, initialize each word embedding in the vocabulary with pretrained Glove embeddings. 
+ """ + + def __init__(self, args, device, word_embeddings): + super(SawETM, self).__init__() + # constants + self.device = device + self.gam_prior = torch.tensor(1.0, dtype=torch.float, device=device) + self.real_min = torch.tensor(1e-30, dtype=torch.float, device=device) + self.theta_max = torch.tensor(1000.0, dtype=torch.float, device=device) + self.wei_shape_min = torch.tensor(1e-1, dtype=torch.float, device=device) + self.wei_shape_max = torch.tensor(100.0, dtype=torch.float, device=device) + + # hyper-parameters + self.embed_size = args.embed_size + self.vocab_size = args.vocab_size + self.num_topics_list = args.num_topics_list + self.num_hiddens_list = args.num_hiddens_list + assert len(args.num_topics_list) == len(args.num_hiddens_list) + self.num_layers = len(args.num_topics_list) + + # learnable word embeddings + if word_embeddings is not None: + self.rho = nn.Parameter(torch.from_numpy(word_embeddings).float()) + else: + self.rho = nn.Parameter( + torch.empty(args.vocab_size, args.embed_size).normal_(std=0.02)) + + # topic embeddings for different latent layers + self.alpha = nn.ParameterList([]) + for n in range(self.num_layers): + self.alpha.append(nn.Parameter( + torch.empty(args.num_topics_list[n], args.embed_size).normal_(std=0.02))) + + # deterministic mapping to obtain hierarchical features + self.h_encoder = nn.ModuleList([]) + for n in range(self.num_layers): + if n == 0: + self.h_encoder.append( + ResBlock(args.vocab_size, args.num_hiddens_list[n])) + else: + self.h_encoder.append( + ResBlock(args.num_hiddens_list[n - 1], args.num_hiddens_list[n])) + + # variational encoder to obtain posterior parameters + self.q_theta = nn.ModuleList([]) + for n in range(self.num_layers): + if n == self.num_layers - 1: + self.q_theta.append( + nn.Linear(args.num_hiddens_list[n], 2 * args.num_topics_list[n])) + else: + self.q_theta.append(nn.Linear( + args.num_hiddens_list[n] + args.num_topics_list[n], 2 * args.num_topics_list[n])) + + def log_max(self, x): + return torch.log(torch.max(x, self.real_min)) + + def reparameterize(self, shape, scale, sample_num=50): + """Returns a sample from a Weibull distribution via reparameterization. + """ + shape = shape.unsqueeze(0).repeat(sample_num, 1, 1) + scale = scale.unsqueeze(0).repeat(sample_num, 1, 1) + eps = torch.rand_like(shape, dtype=torch.float, device=self.device) + samples = scale * torch.pow(- self.log_max(1 - eps), 1 / shape) + return torch.clamp(samples.mean(0), self.real_min.item(), self.theta_max.item()) + + def kl_weibull_gamma(self, wei_shape, wei_scale, gam_shape, gam_scale): + """Returns the Kullback-Leibler divergence between a Weibull distribution and a Gamma distribution. + """ + euler_mascheroni_c = torch.tensor(0.5772, dtype=torch.float, device=self.device) + t1 = torch.log(wei_shape) + torch.lgamma(gam_shape) + t2 = - gam_shape * torch.log(wei_scale * gam_scale) + t3 = euler_mascheroni_c * (gam_shape / wei_shape - 1) - 1 + t4 = gam_scale * wei_scale * torch.exp(torch.lgamma(1 + 1 / wei_shape)) + return (t1 + t2 + t3 + t4).sum(1).mean() + + def get_nll(self, x, x_reconstruct): + """Returns the negative Poisson likelihood of observational count data. + """ + log_likelihood = self.log_max(x_reconstruct) * x - torch.lgamma(1.0 + x) - x_reconstruct + neg_log_likelihood = - torch.sum(log_likelihood, dim=1, keepdim=False).mean() + return neg_log_likelihood + + def get_phi(self): + """Returns the factor loading matrix by utilizing sawtooth connection. 
+ """ + phis = [] + for n in range(self.num_layers): + if n == 0: + phi = torch.softmax(torch.mm( + self.rho, self.alpha[n].transpose(0, 1)), dim=0) + else: + phi = torch.softmax(torch.mm( + self.alpha[n - 1].detach(), self.alpha[n].transpose(0, 1)), dim=0) + phis.append(phi) + return phis + + def forward(self, x, is_training=True): + """Forward pass: compute the kl loss and data likelihood. + """ + hidden_feats = [] + for n in range(self.num_layers): + if n == 0: + hidden_feats.append(self.h_encoder[n](x)) + else: + hidden_feats.append(self.h_encoder[n](hidden_feats[-1])) + + # ================================================================================= + phis = self.get_phi() + + ks = [] + lambs = [] + thetas = [] + phi_by_theta_list = [] + for n in range(self.num_layers - 1, -1, -1): + if n == self.num_layers - 1: + joint_feat = hidden_feats[n] + else: + joint_feat = torch.cat((hidden_feats[n], phi_by_theta_list[0]), dim=1) + + k, lamb = torch.chunk(F.softplus(self.q_theta[n](joint_feat)), 2, dim=1) + k = torch.clamp(k, self.wei_shape_min.item(), self.wei_shape_max.item()) + lamb = torch.clamp(lamb, self.real_min.item()) + + if is_training: + lamb = lamb / torch.exp(torch.lgamma(1 + 1 / k)) + theta = self.reparameterize(k, lamb, sample_num=3) if n == 0 else self.reparameterize(k, lamb) + else: + theta = torch.min(lamb, self.theta_max) + + phi_by_theta = torch.mm(theta, phis[n].t()) + phi_by_theta_list.insert(0, phi_by_theta) + thetas.insert(0, theta) + lambs.insert(0, lamb) + ks.insert(0, k) + + # ================================================================================= + nll = self.get_nll(x, phi_by_theta_list[0]) + + kl_loss = [] + for n in range(self.num_layers): + if n == self.num_layers - 1: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], self.gam_prior, self.gam_prior)) + else: + kl_loss.append(self.kl_weibull_gamma( + ks[n], lambs[n], phi_by_theta_list[n + 1], self.gam_prior)) + + nelbo = nll + sum(kl_loss) + return nelbo, nll, sum(kl_loss), thetas diff --git a/train_clustering.py b/train_clustering.py new file mode 100644 index 0000000..b23f471 --- /dev/null +++ b/train_clustering.py @@ -0,0 +1,207 @@ +from __future__ import division +from __future__ import print_function + +import datetime +import logging +import numpy as np +import os +import pickle +import time + +import torch +from config import parser +from models.etm import ETM +from models.sawetm import SawETM +from models.hyperminer import HyperETM, HyperMiner, HyperMinerKG +from utils.data_util import get_data_loader +from utils.train_util import get_dir_name, convert_to_coo_adj, load_glove_embeddings, visualize_topics +from utils.eval_util import text_clustering + + +def main(args): + global save_dir + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if int(args.cuda) >= 0 and torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + device = torch.device('cuda:' + str(args.cuda)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) + else: + device = torch.device('cpu') + + if int(args.double_precision): + torch.set_default_dtype(torch.float64) + + logging.getLogger().setLevel(logging.INFO) + if args.save: + if not args.save_dir: + dt = datetime.datetime.now() + date = f"{dt.year}_{dt.month}_{dt.day}" + models_dir = os.path.join('logs', args.dataset, date) + save_dir = get_dir_name(models_dir) + else: + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + logging.basicConfig(level=logging.INFO, + handlers=[ + 
logging.FileHandler(os.path.join(save_dir, 'log.txt')), + logging.StreamHandler() + ]) + logging.info('Clustering experiment') + logging.info(f'Using {device}') + logging.info(f'Using seed {args.seed}') + + train_loader, vocab = get_data_loader(args.dataset, args.data_path, 'train', args.batch_size) + test_loader, _ = get_data_loader(args.dataset, args.data_path, 'test', args.batch_size, shuffle=False, drop_last=False) + args.vocab_size = len(vocab) + logging.info(f'Using dataset {args.dataset}') + logging.info(f'{len(vocab)} words as vocabulary') + logging.info(f'{len(train_loader.dataset)} training docs') + logging.info(f'{len(test_loader.dataset)} test docs') + + if args.pretrained_embeddings: + logging.info('Using pretrained glove embeddings') + initial_embeddings = load_glove_embeddings(args.embed_size, vocab) + else: + initial_embeddings = None + + if args.add_knowledge: + with open(args.file_path, 'rb') as f: + adj, num_topics_list, concept_names = pickle.load(f) + num_layers = len(num_topics_list) + args.num_topics_list = num_topics_list + args.num_hiddens_list = args.num_hiddens_list[: num_layers] + + sparse_adj = convert_to_coo_adj(adj) + model = HyperMinerKG(args, device, initial_embeddings, sparse_adj) + else: + if args.manifold == 'Euclidean': + model = SawETM(args, device, initial_embeddings) + # model = ETM(args, device, initial_embeddings) + else: + model = HyperMiner(args, device, initial_embeddings) + # model = HyperETM(args, device, initial_embeddings) + model = model.to(device) + logging.info(str(model)) + total_params = sum([np.prod(p.size()) for p in model.parameters()]) + logging.info(f"Total number of parameters:{total_params}") + + optimizer = torch.optim.Adam( + params=model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay + ) + logging.info(f'Using {args.optimizer} optimizer') + logging.info(f'Initial learning rate {args.lr}') + + if not args.lr_reduce_freq: + args.lr_reduce_freq = args.epochs + else: + logging.info(f'Decay rate {args.gamma}') + logging.info(f'Step size {args.lr_reduce_freq}') + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=int(args.lr_reduce_freq), + gamma=float(args.gamma) + ) + + best_purity_epoch = 0 + best_purity = 0 + best_nmi_epoch = 0 + best_nmi = 0 + + # Train model + t_total = time.time() + for epoch in range(args.epochs): + model.train() + t = time.time() + + lowerbound = [] + likelihood = [] + kl_div = [] + for idx, (batch_data, _) in enumerate(train_loader): + batch_data = batch_data.float().to(device) + nelbo, nll, kl_loss, _ = model(batch_data) + lowerbound.append(-nelbo.item()) + likelihood.append(-nll.item()) + kl_div.append(kl_loss.item()) + + flag = 0 + for p in model.parameters(): + flag += torch.sum(torch.isnan(p)) + if flag == 0: + optimizer.zero_grad() + nelbo.backward() + if args.grad_clip is not None: + max_norm = float(args.grad_clip) + all_params = list(model.parameters()) + for param in all_params: + torch.nn.utils.clip_grad_norm_(param, max_norm) + optimizer.step() + + if (idx + 1) % 10 == 0: + print('Epoch: [{}/{}]\t elbo: {}\t likelihood: {}\t kl: {}'.format( + idx + 1, epoch + 1, np.mean(lowerbound), np.mean(likelihood), np.mean(kl_div))) + + if (epoch + 1) % args.log_freq == 0: + logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), + '\tlr: {}'.format(lr_scheduler.get_last_lr()[0]), + '\telbo: {:.8f}'.format(np.mean(lowerbound)), + '\tlikelihood: {:.8f}'.format(np.mean(likelihood)), + '\ttime: {:.4f}s'.format(time.time() - t) + ])) + + if (epoch + 1) % args.eval_freq == 
0: + model.eval() + test_feats = [] + test_labels = [] + for idx, (batch_data, batch_labels) in enumerate(test_loader): + batch_data = batch_data.float().to(device) + with torch.no_grad(): + _, _, _, thetas = model(batch_data, is_training=False) + # test_feats.append(thetas.cpu().numpy()) + test_feats.append(thetas[0].cpu().numpy()) + test_labels.append(batch_labels.numpy()) + test_feats = np.concatenate(test_feats, axis=0) + test_labels = np.concatenate(test_labels) + + purity, nmi = text_clustering(test_feats, test_labels) + logging.info("Epoch: {:04d}\t Purity: {:.6f}\t NMI: {:.6f}".format(epoch + 1, purity, nmi)) + + if purity > best_purity: + best_purity = purity + best_purity_epoch = epoch + 1 + torch.save( + model.state_dict(), + os.path.join(save_dir, 'ckpt_best_purity.pth') + ) + if nmi > best_nmi: + best_nmi = nmi + best_nmi_epoch = epoch + 1 + torch.save( + model.state_dict(), + os.path.join(save_dir, 'ckpt_best_nmi.pth') + ) + + lr_scheduler.step() + + logging.info("Optimization Finished!") + # logging.info("Best epoch: {}".format(best_epoch)) + logging.info("Best clustering purity: {:.6f} at epoch {}".format(best_purity, best_purity_epoch)) + logging.info("Best clustering nmi: {:.6f} at epoch {}".format(best_nmi, best_nmi_epoch)) + logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + + # model.load_state_dict(torch.load( + # os.path.join(save_dir, 'ckpt_best.pth'), + # map_location=device + # )) + # model.eval() + # with torch.no_grad(): + # phis = model.get_phi() + # visualize_topics(phis, save_dir, vocab) + + +if __name__ == '__main__': + args = parser.parse_args() + main(args) diff --git a/train_quality.py b/train_quality.py new file mode 100644 index 0000000..5ca1d8a --- /dev/null +++ b/train_quality.py @@ -0,0 +1,224 @@ +from __future__ import division +from __future__ import print_function + +import datetime +# import json +import logging +import numpy as np +import os +import pickle +import time + +import torch +from config import parser +from models.etm import ETM +from models.sawetm import SawETM +from models.hyperminer import HyperETM, HyperMiner, HyperMinerKG +from utils.data_util import get_data_loader +from utils.train_util import get_dir_name, convert_to_coo_adj, load_glove_embeddings, visualize_topics +from utils.eval_util import topic_diversity, topic_coherence + + +def main(args): + global save_dir + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if int(args.cuda) >= 0 and torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + device = torch.device('cuda:' + str(args.cuda)) + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) + else: + device = torch.device('cpu') + + if int(args.double_precision): + torch.set_default_dtype(torch.float64) + + logging.getLogger().setLevel(logging.INFO) + if args.save: + if not args.save_dir: + dt = datetime.datetime.now() + date = f"{dt.year}_{dt.month}_{dt.day}" + models_dir = os.path.join('logs', args.dataset, date) + save_dir = get_dir_name(models_dir) + else: + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + logging.basicConfig(level=logging.INFO, + handlers=[ + logging.FileHandler(os.path.join(save_dir, 'log.txt')), + logging.StreamHandler() + ]) + logging.info('Topic quality experiment') + logging.info(f'Using {device}') + logging.info(f'Using seed {args.seed}') + + train_loader, vocab = get_data_loader(args.dataset, args.data_path, 'train', args.batch_size) + test_loader, _ = get_data_loader(args.dataset, args.data_path, 'test', 
args.batch_size, shuffle=False, drop_last=False) + args.vocab_size = len(vocab) + logging.info(f'Using dataset {args.dataset}') + logging.info(f'{len(vocab)} words as vocabulary') + logging.info(f'{len(train_loader.dataset)} training docs') + logging.info(f'{len(test_loader.dataset)} test docs') + + if args.pretrained_embeddings: + logging.info('Using pretrained glove embeddings') + initial_embeddings = load_glove_embeddings(args.embed_size, vocab) + else: + initial_embeddings = None + + if args.add_knowledge: + with open(args.file_path, 'rb') as f: + adj, num_topics_list, concept_names = pickle.load(f) + num_layers = len(num_topics_list) + args.num_topics_list = num_topics_list + args.num_hiddens_list = args.num_hiddens_list[: num_layers] + + sparse_adj = convert_to_coo_adj(adj) + model = HyperMinerKG(args, device, initial_embeddings, sparse_adj) + else: + if args.manifold == 'Euclidean': + model = SawETM(args, device, initial_embeddings) + # model = ETM(args, device, initial_embeddings) + else: + model = HyperMiner(args, device, initial_embeddings) + # model = HyperETM(args, device, initial_embeddings) + model = model.to(device) + logging.info(str(model)) + total_params = sum([np.prod(p.size()) for p in model.parameters()]) + logging.info(f"Total number of parameters:{total_params}") + + optimizer = torch.optim.Adam( + params=model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay + ) + logging.info(f'Using {args.optimizer} optimizer') + logging.info(f'Initial learning rate {args.lr}') + + if not args.lr_reduce_freq: + args.lr_reduce_freq = args.epochs + else: + logging.info(f'Decay rate {args.gamma}') + logging.info(f'Step size {args.lr_reduce_freq}') + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=int(args.lr_reduce_freq), + gamma=float(args.gamma) + ) + + best_diversity = [0, 0] + best_coherence = [0, 0] + best_diversity_epoch = 0 + best_coherence_epoch = 0 + + # Train model + t_total = time.time() + for epoch in range(args.epochs): + model.train() + t = time.time() + + lowerbound = [] + likelihood = [] + kl_div = [] + for idx, (batch_data, _) in enumerate(train_loader): + batch_data = batch_data.float().to(device) + nelbo, nll, kl_loss, _ = model(batch_data) + lowerbound.append(-nelbo.item()) + likelihood.append(-nll.item()) + kl_div.append(kl_loss.item()) + + flag = 0 + for p in model.parameters(): + flag += torch.sum(torch.isnan(p)) + if flag == 0: + optimizer.zero_grad() + nelbo.backward() + if args.grad_clip is not None: + max_norm = float(args.grad_clip) + all_params = list(model.parameters()) + for param in all_params: + torch.nn.utils.clip_grad_norm_(param, max_norm) + optimizer.step() + + if (idx + 1) % 10 == 0: + print('Epoch: [{}/{}]\t elbo: {}\t likelihood: {}\t kl: {}'.format( + idx + 1, epoch + 1, np.mean(lowerbound), np.mean(likelihood), np.mean(kl_div))) + + if (epoch + 1) % args.log_freq == 0: + logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1), + '\tlr: {}'.format(lr_scheduler.get_last_lr()[0]), + '\telbo: {:.8f}'.format(np.mean(lowerbound)), + '\tlikelihood: {:.8f}'.format(np.mean(likelihood)), + '\ttime: {:.4f}s'.format(time.time() - t) + ])) + + if (epoch + 1) % args.eval_freq == 0: + torch.save( + model.state_dict(), + os.path.join(save_dir, 'ckpt_epoch{}.pth'.format(epoch + 1)) + ) + # model.eval() + # with torch.no_grad(): + # # phis = model.get_phi() + # phis, _ = model.get_phi() + # # phis = [phis.cpu().numpy()] + # phis = [phi.cpu().numpy() for phi in phis] + # corpus = train_loader.dataset.data.toarray() + # + # 
factorial_phi = 1 + # td_all_layers = [] + # tc_all_layers = [] + # for layer_id, phi in enumerate(phis): + # factorial_phi = np.dot(factorial_phi, phi) + # cur_td = topic_diversity(factorial_phi.T, top_k=25) + # cur_tc = topic_coherence(corpus, None, factorial_phi.T, top_k=10) + # td_all_layers.append(cur_td) + # tc_all_layers.append(cur_tc) + # print('Layer {}, \tTD: {:.6f}, \tTC: {:.6f}'.format(layer_id, cur_td, cur_tc)) + # + # logging.info("Epoch: {:04d}\t Diversity: {}".format(epoch + 1, td_all_layers)) + # logging.info("Epoch: {:04d}\t Coherence: {}".format(epoch + 1, tc_all_layers)) + # + # if np.mean(td_all_layers) > np.mean(best_diversity): + # best_diversity = td_all_layers + # best_diversity_epoch = epoch + 1 + # torch.save( + # model.state_dict(), + # os.path.join(save_dir, 'ckpt_best_diversity.pth') + # ) + # if np.mean(tc_all_layers) > np.mean(best_coherence): + # best_coherence = tc_all_layers + # best_coherence_epoch = epoch + 1 + # torch.save( + # model.state_dict(), + # os.path.join(save_dir, 'ckpt_best_coherence.pth') + # ) + + lr_scheduler.step() + + torch.save( + model.state_dict(), + os.path.join(save_dir, 'ckpt_last.pth') + ) + + logging.info("Optimization Finished!") + logging.info("Best diversity: {} at epoch {}".format(best_diversity, best_diversity_epoch)) + logging.info("Best coherence: {} at epoch {}".format(best_coherence, best_coherence_epoch)) + logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total)) + + # save_dir = './logs/20ng/2022_10_3/21' + # model.load_state_dict(torch.load( + # os.path.join(save_dir, 'ckpt_best_diversity.pth'), + # map_location=device + # )) + model.eval() + with torch.no_grad(): + # phis = model.get_phi() + phis, _ = model.get_phi() + visualize_topics(phis, save_dir, vocab, concepts=concept_names) + + +if __name__ == '__main__': + args = parser.parse_args() + main(args) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/build_graph.py b/utils/build_graph.py new file mode 100644 index 0000000..c16f4ee --- /dev/null +++ b/utils/build_graph.py @@ -0,0 +1,93 @@ +import numpy as np +import pickle + +from nltk.corpus import wordnet as wn +from time import time +from utils.data_util import TextDataset + + +datafile = '../data/20NG/20news_groups.pkl' +dataname = '20ng' +mode = 'train' + +print('==> Loading dataset...') +t0 = time() +entire_dataset = TextDataset(dataname, datafile, mode) +vocabulary = entire_dataset.vocab +vocab_size = len(vocabulary) +del entire_dataset +print("Done in %0.3fs." % (time() - t0)) + + +print('\n==> Extracting a subgraph from WordNet...') +t0 = time() +max_depth = 5 +graph_from_wordnet = dict() +for word in vocabulary: + try: + leaf_node = wn.synset(word + '.n.01') + path_to_root = leaf_node.hypernym_paths()[0] + + if len(path_to_root) > max_depth: + end_idx = max_depth + else: + end_idx = -1 + + for layer_id, concept in enumerate(path_to_root[: end_idx]): + layer_name = 'layer_' + str(layer_id) + if layer_name not in graph_from_wordnet.keys(): + graph_from_wordnet[layer_name] = [] + if concept not in graph_from_wordnet[layer_name]: + graph_from_wordnet[layer_name].append(concept) + except: + pass +print("Done in %0.3fs." 
% (time() - t0)) + + +num_topics_list = [] +for (k, v) in graph_from_wordnet.items(): + num_topics_list.append(len(v)) + + +print('\n==> Generating the adjacency matrix...') +t0 = time() +total_nodes = sum(num_topics_list) + vocab_size +adj_mat = np.eye(total_nodes).astype('float32') +all_concepts = list(graph_from_wordnet.values())[0] +for concept_group in list(graph_from_wordnet.values())[1:]: + all_concepts = all_concepts + concept_group +for m, concept in enumerate(all_concepts): + if concept.hypernyms(): + n = all_concepts.index(concept.hypernyms()[0]) + adj_mat[m, n] = 1.0 + adj_mat[n, m] = 1.0 +for word in vocabulary: + m = m + 1 + try: + leaf_node = wn.synset(word + '.n.01') + path_to_root = leaf_node.hypernym_paths()[0] + if len(path_to_root) > max_depth: + parent_node = path_to_root[max_depth - 1] + else: + parent_node = path_to_root[-2] + + n = all_concepts.index(parent_node) + adj_mat[m, n] = 1.0 + adj_mat[n, m] = 1.0 + except: + pass +print("Done in %0.3fs." % (time() - t0)) + + +taxonomy = dict() +for (k, v) in graph_from_wordnet.items(): + new_k = k.split('_')[-1] + new_v = [concept.name() for concept in v] + taxonomy[new_k] = new_v + + +print('\n==> Saving knowledge graph to .pkl file...') +t0 = time() +with open('../data/20NG/{}_wordnet_tree_3layer.pkl'.format(dataname), 'wb') as f: + pickle.dump([adj_mat[3:, 3:], num_topics_list[::-1][:-2], taxonomy], f) +print("Done in %0.3fs." % (time() - t0)) diff --git a/utils/data_util.py b/utils/data_util.py new file mode 100644 index 0000000..8b06ba5 --- /dev/null +++ b/utils/data_util.py @@ -0,0 +1,130 @@ +"""Data utils functions for pre-processing and data loading.""" + +import numpy as np +import pickle +import random +import torch.utils.data + + +class TextDataset(torch.utils.data.Dataset): + def __init__(self, name, path, mode='train'): + super(TextDataset, self).__init__() + with open(path, 'rb') as f: + data = pickle.load(f) + + if name in ['20ng', 'tmn', 'webs']: + # vocab = data['vocab'] + # train_bows = data['train_data'] + # train_labels = data['train_labels'] + # test_bows = data['test_data'] + # test_labels = data['test_labels'] + train_id = data['train_id'] + test_id = data['test_id'] + label = np.squeeze(np.array(data['label'])) + + train_bows = data['data_2000'][train_id] + train_labels = label[train_id] + test_bows = data['data_2000'][test_id] + test_labels = label[test_id] + vocab = data['voc2000'] + elif name == 'wiki': + vocab = data['vocab'] + train_bows = data['data'] + train_labels = None + test_bows = None + test_labels = None + elif name == 'rcv': + vocab = data['rcv2_voc'] + train_bows = data['rcv2_bow'] + train_labels = None + test_bows = None + test_labels = None + else: + raise NotImplementedError(f'unknown dataset: {name}') + + if mode == 'train': + self.data = train_bows + self.labels = train_labels + elif mode == 'test': + self.data = test_bows + self.labels = test_labels + else: + raise ValueError("argument 'mode' must be either train or test") + self.vocab = vocab + + if self.labels is not None: + assert self.data.shape[0] == len(self.labels) + + def __getitem__(self, index): + if self.labels is not None: + return self.data[index].toarray().squeeze(), self.labels[index] + else: + return self.data[index].toarray().squeeze(), 0 + + def __len__(self): + return self.data.shape[0] + + +class PPLDataset(torch.utils.data.Dataset): + def __init__(self, name, data_path='./data/20NG/20news_groups.pkl'): + super(PPLDataset, self).__init__() + with open(data_path, 'rb') as f: + data = pickle.load(f) + + if 
name in ['20ng', 'tmn']:
+            # np.concatenate takes a sequence of arrays as its first argument;
+            # passing the second array positionally would be read as `axis`.
+            bows_matrix = np.concatenate(
+                (data['train_data'].toarray(), data['test_data'].toarray()), axis=0
+            )
+        else:
+            bows_matrix = data['data'].toarray()
+
+        # Split every document into a context part and a held-out part for
+        # perplexity evaluation.
+        context_ratio = 0.7
+        context_bows = np.zeros_like(bows_matrix)
+        mask_bows = np.zeros_like(bows_matrix)
+        for doc_id, bow in enumerate(bows_matrix):
+            indices = np.nonzero(bow)[0]
+            indices_rep = []
+            for idx in indices:
+                # repeat each word index by its count (cast to int, since the
+                # bag-of-words counts may be stored as floats)
+                indices_rep += int(bow[idx]) * [idx]
+
+            random.seed(2022)
+            random.shuffle(indices_rep)
+            temp1 = indices_rep[:int(len(indices_rep) * context_ratio)]
+            temp2 = indices_rep[int(len(indices_rep) * context_ratio):]
+            for word_id in temp1:
+                context_bows[doc_id][word_id] += 1
+            for word_id in temp2:
+                mask_bows[doc_id][word_id] += 1
+
+        self.context_data = context_bows
+        self.mask_data = mask_bows
+        self.vocab = data['vocab']
+
+    def __getitem__(self, index):
+        return torch.from_numpy(self.context_data[index].squeeze()).float(), \
+               torch.from_numpy(self.mask_data[index].squeeze()).float()
+
+    def __len__(self):
+        return self.context_data.shape[0]
+
+
+def get_data_loader(data_name, data_path, mode='train', batch_size=200, shuffle=True, drop_last=True, num_workers=4):
+    dataset = TextDataset(name=data_name, path=data_path, mode=mode)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        num_workers=num_workers
+    ), dataset.vocab
+
+
+def get_ppl_dataloader(data_name, data_path, batch_size=200, shuffle=True, drop_last=True, num_workers=4):
+    dataset = PPLDataset(data_name, data_path)
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        num_workers=num_workers
+    ), dataset.vocab
diff --git a/utils/eval_util.py b/utils/eval_util.py
new file mode 100644
index 0000000..b006bda
--- /dev/null
+++ b/utils/eval_util.py
@@ -0,0 +1,147 @@
+import numpy as np
+from sklearn.cluster import KMeans
+from sklearn.svm import LinearSVC, SVC
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics.cluster import normalized_mutual_info_score
+from sklearn.metrics import average_precision_score, accuracy_score, f1_score
+
+
+def topic_diversity(topic_matrix, top_k=25):
+    """ Topic Diversity (TD) measures how diverse the discovered topics are.
+
+    We define topic diversity to be the percentage of unique words in the top 25 words (Dieng et al., 2020)
+    of the selected topics. TD close to 0 indicates redundant topics, TD close to 1 indicates more varied topics.
+
+    Args:
+        topic_matrix: A (num_topics, vocab_size) array of per-topic word weights.
+        top_k: Number of top words per topic used to compute the score.
+    """
+    num_topics = topic_matrix.shape[0]
+    top_words_idx = np.zeros((num_topics, top_k))
+    for k in range(num_topics):
+        idx = np.argsort(topic_matrix[k, :])[::-1][:top_k]
+        top_words_idx[k, :] = idx
+    num_unique = len(np.unique(top_words_idx))
+    num_total = num_topics * top_k
+    td = num_unique / num_total
+    # print('Topic diversity is: {}'.format(td))
+    return td
+
+
+def compute_npmi(corpus, word_i, word_j):
+    """ Pointwise Mutual Information (PMI) measures the association of a pair of outcomes x and y.
+
+    PMI is defined as log[p(x, y)/p(x)p(y)], which can be further normalized between [-1, +1], resulting in
+    -1 (in the limit) for never occurring together, 0 for independence, and +1 for complete co-occurrence.
+    The Normalized PMI is computed by PMI(x, y) / [-log p(x, y)].
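+
+    A minimal worked example (toy corpus, invented purely for illustration):
+    with 4 bag-of-words documents in which word 0 appears in 2 documents,
+    word 1 appears in 2 documents, and the two co-occur in exactly 1 document,
+    PMI = log(4) + log(1) - log(2) - log(2) = 0 and NPMI = 0 / (log 4 - log 1) = 0,
+    i.e. the pair looks independent under this estimate:
+
+        corpus = [np.array(doc) for doc in ([1, 1, 0], [1, 0, 1], [0, 1, 0], [0, 0, 1])]
+        compute_npmi(corpus, word_i=0, word_j=1)  # -> 0.0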
+ """ + num_docs = len(corpus) + num_docs_appear_wi = 0 + num_docs_appear_wj = 0 + num_docs_appear_both = 0 + for n in range(num_docs): + doc = corpus[n] + # doc = corpus[n].squeeze(0) + # doc = [doc.squeeze()] if len(doc) == 1 else doc.squeeze() + + # if word_i in doc: + if doc[word_i] > 0: + num_docs_appear_wi += 1 + # if word_j in doc: + if doc[word_j] > 0: + num_docs_appear_wj += 1 + # if [word_i, word_j] in doc: + if doc[word_i] > 0 and doc[word_j] > 0: + num_docs_appear_both += 1 + + if num_docs_appear_both == 0: + return -1 + else: + pmi = np.log(num_docs) + np.log(num_docs_appear_both) - \ + np.log(num_docs_appear_wi) - np.log(num_docs_appear_wj) + return pmi / (np.log(num_docs) - np.log(num_docs_appear_both)) + + +def topic_coherence(corpus, vocab, topic_matrix, top_k=10): + """ Topic Coherence measures the semantic coherence of top words in the discovered topics. + + We apply the widely-used Normalized Pointwise Mutual Information (NPMI) (Aletras & Stevenson, 2013; Lau et al., 2014) + computed over the top 10 words of each topic, by the Palmetto package (Röder et al., 2015). + + Args: + corpus: + vocab: + topic_matrix: + top_k: + """ + num_docs = len(corpus) + # print('Number of documents: ', num_docs) + + tc_list = [] + num_topics = topic_matrix.shape[0] + print('Number of topics: ', num_topics) + for k in range(num_topics): + # print('Topic Index: {}/{}'.format(k, num_topics)) + top_words_idx = np.argsort(topic_matrix[k, :])[::-1][:top_k] + # top_words = [vocab[idx] for idx in list(top_words_idx)] + + pairs_count = 0 + tc_k = 0 + # for i, word in enumerate(top_words): + for i, word in enumerate(top_words_idx): + for j in range(i + 1, top_k): + # tc_list.append(compute_npmi(corpus, word, top_words[j])) + # tc_list.append(compute_npmi(corpus, word, top_words_idx[j])) + tc_k += compute_npmi(corpus, word, top_words_idx[j]) + pairs_count += 1 + tc_list.append(tc_k / pairs_count) + + # tc = sum(tc_list) / (num_topics * pairs_count) + # print('Topic coherence is: {}'.format(tc)) + tc_list.sort(reverse=True) + half_num = int(num_topics/2) + return sum(tc_list[: half_num]) / half_num + + +def text_classification(train_data, train_labels, test_data, test_labels, algorithm='LR'): + if algorithm == 'LR': + clf = LogisticRegression(random_state=0, solver='liblinear', multi_class='ovr') + elif algorithm == 'SVM': + clf = SVC(random_state=0) + else: + raise NotImplementedError + + clf.fit(train_data, train_labels) + test_acc = accuracy_score(clf.predict(test_data), test_labels) + # print('Accuracy on the test set: {}'.format(test_acc)) + return test_acc + + +def text_clustering(data, labels_true, num_clusters=30): + # standardization + mu = np.mean(data, axis=-1, keepdims=True) + sigma = np.std(data, axis=-1, keepdims=True) + sigma = np.where(sigma > 0, sigma, 1) + data = (data - mu) / sigma + + # clustering based on Euclidean distance + estimator = KMeans(n_clusters=num_clusters, random_state=0) + estimator.fit(data) + labels_pred = estimator.labels_ + + purity_score = purity(labels_true, labels_pred) + nmi_score = normalized_mutual_info_score(labels_true, labels_pred) + # print('Clustering purity: {}'.format(purity_score)) + # print('Clustering nmi: {}'.format(nmi_score)) + return purity_score, nmi_score + + +def purity(labels_true, labels_pred): + clusters = np.unique(labels_pred) + counts = [] + for c in clusters: + indices = np.where(labels_pred == c)[0] + max_votes = np.bincount(labels_true[indices]).max() + counts.append(max_votes) + return sum(counts) / labels_true.shape[0] + diff 
--git a/utils/hyperbolicity.py b/utils/hyperbolicity.py new file mode 100644 index 0000000..5aebdd3 --- /dev/null +++ b/utils/hyperbolicity.py @@ -0,0 +1,46 @@ +import networkx as nx +import numpy as np +import pickle +import time +from tqdm import tqdm + + +def hyperbolicity_sample(G, num_samples=50000): + curr_time = time.time() + hyps = [] + for _ in tqdm(range(num_samples)): + curr_time = time.time() + node_tuple = np.random.choice(G.nodes(), 4, replace=False) + s = [] + try: + d01 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[1], weight=None) + d23 = nx.shortest_path_length(G, source=node_tuple[2], target=node_tuple[3], weight=None) + d02 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[2], weight=None) + d13 = nx.shortest_path_length(G, source=node_tuple[1], target=node_tuple[3], weight=None) + d03 = nx.shortest_path_length(G, source=node_tuple[0], target=node_tuple[3], weight=None) + d12 = nx.shortest_path_length(G, source=node_tuple[1], target=node_tuple[2], weight=None) + s.append(d01 + d23) + s.append(d02 + d13) + s.append(d03 + d12) + s.sort() + hyps.append((s[-1] - s[-2]) / 2) + except Exception as e: + continue + print('Time for hyp:', time.time() - curr_time) + return max(hyps) + + +if __name__ == '__main__': + kg_path = '../data/20news/20ng_wordnet_tree.pkl' + + print('==> Loading graph...') + with open(kg_path, 'rb') as f: + adj, n_nodes_per_layer, concepts = pickle.load(f) + adj = adj - np.eye(adj.shape[0]) + + graph = nx.from_numpy_array(adj) + print(f'Number of nodes: {graph.number_of_nodes()}') + print(f'Number of edges: {graph.number_of_edges()}') + print('\n==> Computing hyperbolicity...') + hyp = hyperbolicity_sample(graph) + print('Hyp:', hyp) diff --git a/utils/math_util.py b/utils/math_util.py new file mode 100644 index 0000000..23509e6 --- /dev/null +++ b/utils/math_util.py @@ -0,0 +1,221 @@ +"""Math utils functions.""" + +import torch +from torch import tan, atan, cos, acos, sin, asin +from torch import tanh, atanh, cosh, acosh, sinh, asinh + + +def Tan(x): + """Computes tangent of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor. Has the same type as x. + """ + return tan(x) + + +def Tanh(x): + """Computes hyperbolic tangent of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type as x. + """ + return tanh(torch.clamp(x, min=-15, max=15)) + + +def TanC(x, c): + """A unified tangent and inverse tangent function for different signs of curvatures. + + This function is used in k-Stereographic model, a unification of constant curvature manifolds. + Please refer to https://arxiv.org/abs/2007.07698 for more details. + + First-order expansion is used in order to calculate gradients correctly when c is zero. + + Args: + x (tensor): A tensor. + c (tensor): Manifold curvature. + + Returns: + A tensor: Has the same type of x. + """ + return 1 / c.abs().sqrt() * Tanh(x * c.abs().sqrt()) # c < 0 + + +def ArTan(x): + """Computes inverse tangent of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type as x. + """ + return atan(torch.clamp(x, min=-15, max=15)) + + +def ArTanh(x): + """Computes inverse hyperbolic tangent of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type as x. + """ + return atanh(torch.clamp(x, min=-1 + 1e-7, max=1 - 1e-7)) + + +def ArTanC(x, c): + """A unified hyperbolic tangent and inverse hyperbolic tangent function for different signs of curvatures. 
+ + This function is used in k-Stereographic model, a unification of constant curvature manifolds. + Please refer to https://arxiv.org/abs/2007.07698 for more details. + + First-order expansion is used in order to calculate gradients correctly when c is zero. + + Args: + x (tensor): A tensor. + c (tensor): Manifold curvature. + + Returns: + A tensor: Has the same type of x. + """ + return 1 / c.abs().sqrt() * ArTanh(x * c.abs().sqrt()) # c < 0 + + +def Cos(x): + """Computes cosine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return cos(x) + + +def Cosh(x): + """Computes hyperbolic cosine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return cosh(torch.clamp(x, min=-15, max=15)) + + +def ArCos(x): + """Computes inverse cosine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return acos(torch.clamp(x, min=-1 + 1e-7, max=1 - 1e-7)) + + +def ArCosh(x): + """Computes inverse hyperbolic cosine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return acosh(torch.clamp(x, min=1 + 1e-7, max=1e15)) + + +def Sin(x): + """Computes sine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return sin(x) + + +def Sinh(x): + """Computes hyperbolic sine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return sinh(torch.clamp(x, min=-15, max=15)) + + +def SinC(x, c): + """A unified sine and inverse sine function for different signs of curvatures. + + This function is used in k-Stereographic model, a unification of constant curvature manifolds. + Please refer to https://arxiv.org/abs/2007.07698 for more details. + + First-order expansion is used in order to calculate gradients correctly when c is zero. + + Args: + x (tensor): A tensor. + c (tensor): Manifold curvature. + + Returns: + A tensor: Has the same type of x. + """ + return 1 / c.abs().sqrt() * Sinh(x * c.abs().sqrt()) # c < 0 + + +def ArSin(x): + """Computes inverse sine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return asin(torch.clamp(x, min=-1 + 1e-7, max=1 - 1e-7)) + + +def ArSinh(x): + """Computes inverse hyperbolic sine of x element-wise. + + Args: + x (tensor): A tensor. + + Returns: + A tensor: Has the same type of x. + """ + return asinh(torch.clamp(x, min=-15, max=15)) + + +def ArSinC(x, c): + """A unified hyperbolic sine and inverse hyperbolic sine function for different signs of curvatures. + + This function is used in k-Stereographic model, a unification of constant curvature manifolds. + Please refer to https://arxiv.org/abs/2007.07698 for more details. + + First-order expansion is used in order to calculate gradients correctly when c is zero. + + Args: + x (tensor): A tensor. + c (tensor): Manifold curvature. + + Returns: + A tensor: Has the same type of x. 
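+
+    Example:
+        A quick round-trip sanity check (an illustrative sketch, assuming
+        `torch` is imported, SinC/ArSinC come from this module, and curvature
+        is negative as elsewhere in this codebase):
+
+            c = torch.tensor(-1.0)
+            x = torch.tensor([0.5])
+            torch.allclose(ArSinC(SinC(x, c), c), x)  # True: ArSinC inverts SinC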
+ """ + return 1 / c.abs().sqrt() * ArSinh(x * c.abs().sqrt()) # c < 0 diff --git a/utils/train_util.py b/utils/train_util.py new file mode 100644 index 0000000..5df3738 --- /dev/null +++ b/utils/train_util.py @@ -0,0 +1,176 @@ +import argparse +import numpy as np +import os +from time import time +import scipy.sparse as sp +# import torch +# import torch.nn.functional as F +# import torch.nn.modules.loss +import torch + +from .eval_util import topic_diversity + + +def format_metrics(metrics, split): + """Format metric in metric dict for logging.""" + return " ".join([ + "{}_{}: {:.4f}".format(split, metric_name, metric_val) for metric_name, metric_val in metrics.items() + ]) + + +def get_dir_name(models_dir): + """Gets a directory to save the model. + + If the directory already exists, then append a new integer to the end of + it. This method is useful so that we don't overwrite existing models + when launching new jobs. + + Args: + models_dir: The directory where all the models are. + + Returns: + The name of a new directory to save the training logs and model weights. + """ + if not os.path.exists(models_dir): + save_dir = os.path.join(models_dir, '0') + os.makedirs(save_dir) + else: + existing_dirs = np.array( + [ + d + for d in os.listdir(models_dir) + if os.path.isdir(os.path.join(models_dir, d)) + ] + ).astype(np.int) + if len(existing_dirs) > 0: + dir_id = str(existing_dirs.max() + 1) + else: + dir_id = "1" + save_dir = os.path.join(models_dir, dir_id) + os.makedirs(save_dir) + return save_dir + + +def add_flags_from_config(parser, config_dict): + """Adds a flag (and default value) to an ArgumentParser for each parameter in a config. + """ + + def OrNone(default): + def func(x): + # Convert "none" to proper None object + if x.lower() == "none": + return None + # If default is None (and x is not None), return x without conversion as str + elif default is None: + return str(x) + # Otherwise, default has non-None type; convert x to that type + else: + return type(default)(x) + + return func + + for param in config_dict: + default, description = config_dict[param] + try: + if isinstance(default, dict): + parser = add_flags_from_config(parser, default) + elif isinstance(default, list): + if len(default) > 0: + # pass a list as argument + parser.add_argument( + f"--{param}", + action="append", + type=type(default[0]), + default=default, + help=description + ) + else: + pass + parser.add_argument(f"--{param}", action="append", default=default, help=description) + else: + pass + parser.add_argument(f"--{param}", type=OrNone(default), default=default, help=description) + except argparse.ArgumentError: + print( + f"Could not add flag for param {param} because it was already present." + ) + return parser + + +def convert_to_coo_adj(dense_adj): + """convert the dense adjacent matrix (numpy array) to sparse_coo matrix (torch Tensor). + """ + dense_adj = dense_adj - np.eye(dense_adj.shape[0]) + coo_mat = sp.coo_matrix(dense_adj) + edge_weights = coo_mat.data + edge_indices = np.vstack((coo_mat.row, coo_mat.col)) + return torch.sparse_coo_tensor( + indices=torch.from_numpy(edge_indices).long(), + values=torch.from_numpy(edge_weights).float(), + size=coo_mat.shape + ) + + +def load_glove_embeddings(embed_size, vocab): + """Initial word embeddings with pretrained glove embeddings if necessary. 
+ """ + glove_path = './data/glove/glove.6B.{}d.txt'.format(embed_size) + + print('\n==> Loading pretrained glove embeddings...') + t0 = time() + embeddings_dict = dict() + with open(glove_path, 'r', encoding='utf-8') as f: + for line in f.readlines(): + values = line.split() + word = values[0] + embedding = np.asarray(values[1:], dtype=np.float32) + embeddings_dict[word] = embedding + print("Done in %0.3fs." % (time() - t0)) + + print('==> Initialize word embeddings with glove embeddings...') + t0 = time() + vocab_embeddings = list() + for word in vocab: + try: + vocab_embeddings.append(embeddings_dict[word]) + except: + vocab_embeddings.append(0.02 * np.random.randn(embed_size)) + print("Done in %0.3fs." % (time() - t0)) + + return np.array(vocab_embeddings, dtype=np.float32) + + +def get_top_n(phi_column, vocab, top_n=25): + top_n_words = '' + indices = np.argsort(-phi_column) + for n in range(top_n): + top_n_words += vocab[indices[n]] + top_n_words += ' ' + return top_n_words + + +def visualize_topics(phis, save_dir, vocab, top_n=25, concepts=None): + if isinstance(phis, list): + phis = [phi.cpu().numpy() for phi in phis] + else: + phis = [phis.cpu().numpy()] + + num_layers = len(phis) + factorial_phi = 1 + for layer_id, phi in enumerate(phis): + factorial_phi = np.dot(factorial_phi, phi) + cur_td = topic_diversity(factorial_phi.T, top_n) + + num_topics = factorial_phi.shape[1] + path = os.path.join(save_dir, 'phi_' + str(layer_id) + '.txt') + f = open(path, 'w') + for k in range(num_topics): + top_n_words = get_top_n( + factorial_phi[:, k], vocab, top_n) + if concepts is not None: + f.write('({})'.format(concepts[str(num_layers - layer_id)][k])) + f.write(top_n_words) + f.write('\n') + f.write('Topic diversity:{}'.format(cur_td)) + f.write('\n') + f.close()