Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
NoviceStone authored Oct 16, 2022
1 parent f3f0690 commit 30bd802
Show file tree
Hide file tree
Showing 18 changed files with 2,327 additions and 0 deletions.
53 changes: 53 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import argparse
from utils.train_util import add_flags_from_config

config_args = {
'data_config': {
'seed': (2022, 'manual random seed'),
'cuda': (0, 'which cuda device to use'),
'dataset': ('20ng', 'which dataset to use'),
'data-path': ('./data/20NG/20ng.pkl', 'path to load data'),
'batch-size': (200, 'number of examples in a mini-batch'),
},
'training_config': {
'epochs': (200, 'number of epochs to train'),
'lr': (0.01, 'initial learning rate'),
'lr-reduce-freq': (None, 'reduce lr every lr-reduce-freq or None to keep lr constant'),
'gamma': (0.1, 'gamma for lr scheduler'),
'dropout': (0, 'dropout probability'),
'momentum': (0.999, 'momentum in optimizer'),
'weight-decay': (1e-5, 'l2 regularization strength'),
'optimizer': ('Adam', 'which optimizer to use, can be any of [Adam, RiemannianAdam]'),
'grad-clip': (None, 'max norm for gradient clipping, or None for no gradient clipping'),
'log-freq': (1, 'how often to compute print train/val metrics (in epochs)'),
'eval-freq': (50, 'how often to compute val metrics (in epochs)'),
'save': (1, '1 to save model and logs and 0 otherwise'),
'save-dir': (None, 'path to save training logs and model weights (defaults to logs/task/date/run/)'),
'sweep-c': (0, ''),
'print-epoch': (True, ''),
},
'model_config': {
# topic model
'vocab-size': (2000, 'vocabulary size'), # 13368
'embed-size': (2, 'dimensionality of word and topic embeddings'), # (683, 366, 84, 11, 2) # (810, 408, 91, 13, 2)
'num-topics-list': ([185, 66, 11, 2], 'number of topics in each latent layer'), # (560, 325, 83, 12, 2)
'num-hiddens-list': ([300, 300, 300, 300], 'number of units in each hidden layer'),
'pretrained-embeddings': (False, 'whether to use pretrained embeddings to initialize words and topics'),
'manifold': ('PoincareBall', 'which manifold to use, can be any of [Euclidean, Hyperboloid, PoincareBall]'),
'c': (-1.0, 'hyperbolic radius, set to None for trainable curvature'),
'clip_r': (8.0, 'avoid the vanishing gradients problem'),
# hyperbolic gcn
'add-knowledge': (True, 'whether inject prior knowledge to topic modeling'),
'file-path': ('./data/20NG/20ng_wordnet_tree_4layers.pkl', 'path to load tree knowledge'),
'gcn-layers': (2, 'number of hidden layers in graph encoder'),
'bias': (1, 'whether to use bias (1) or not (0)'),
'use-att': (1, 'whether to use hyperbolic attention or not'),
'local-agg': (0, 'whether to local tangent space aggregation or not'),
'act': ('relu', 'which activation function to use (or None for no activation)'),
'double-precision': ('0', 'whether to use double precision'),
},
}

parser = argparse.ArgumentParser()
for _, config_dict in config_args.items():
parser = add_flags_from_config(parser, config_dict)
3 changes: 3 additions & 0 deletions manifolds/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .euclidean import Euclidean
from .hyperboloid import Hyperboloid
from .poincare import PoincareBall
170 changes: 170 additions & 0 deletions manifolds/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Base manifold."""

import torch


class Manifold(object):
"""Abstract class to define basic operations on a manifold.
Attributes:
clip (function): Clips tensor values to a specified min and max.
dtype: The type of the variables.
eps (float): A small constant value.
max_norm (float): The maximum value for number clipping.
min_norm (float): The minimum value for number clipping.
"""

def __init__(self, **kwargs):
"""Initialize a manifold.
"""
super(Manifold, self).__init__()

self.min_norm = 1e-15
self.max_norm = 1e15
self.eps = 1e-5

self.dtype = kwargs["dtype"] if "dtype" in kwargs else torch.float32
self.clip = lambda x: torch.clamp(x, min=self.min_norm, max=self.max_norm)

def proj(self, x, c):
"""A projection function that prevents x from leaving the manifold.
Args:
x (tensor): A point should be on the manifold, but it may not meet the manifold constraints.
c (tensor): The manifold curvature.
Returns:
tensor: A projected point, meeting the manifold constraints.
"""
raise NotImplementedError

def proj_tan(self, v, x, c):
"""A projection function that prevents v from leaving the tangent space of point x.
Args:
v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: A projected point, meeting the tangent space constraints.
"""
raise NotImplementedError

def proj_tan0(self, v, c):
"""A projection function that prevents v from leaving the tangent space of origin point.
Args:
v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints.
c (tensor): The manifold curvature.
Returns:
tensor: A projected point, meeting the tangent space constraints.
"""
raise NotImplementedError

def expmap(self, v, x, c):
"""Map a point v in the tangent space of point x to the manifold.
Args:
v (tensor): A point in the tangent space of point x.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of mapping tangent point v to the manifold.
"""
raise NotImplementedError

def expmap0(self, v, c):
"""Map a point v in the tangent space of origin point to the manifold.
Args:
v (tensor): A point in the tangent space of origin point.
c (tensor): The manifold curvature.
Returns:
tensor: The result of mapping tangent point v to the manifold.
"""
raise NotImplementedError

def logmap(self, y, x, c):
"""Map a point y on the manifold to the tangent space of x.
Args:
y (tensor): A point on the manifold.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of mapping y to the tangent space of x.
"""
raise NotImplementedError

def logmap0(self, y, c):
"""Map a point y on the manifold to the tangent space of origin point.
Args:
y (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of mapping y to the tangent space of origin point.
"""
raise NotImplementedError

def ptransp(self, v, x, y, c):
"""Parallel transport function, used to move point v in the tangent space of x to the tangent space of y.
Args:
v (tensor): A point in the tangent space of x.
x (tensor): A point on the manifold.
y (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of transporting v from the tangent space at x to the tangent space at y.
"""
raise NotImplementedError

def ptransp0(self, v, x, c):
"""Parallel transport function, used to move point v in the tangent space of origin point to the tangent space of y.
Args:
v (tensor): A point in the tangent space of origin point.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of transporting v from the tangent space at origin point to the tangent space at y.
"""
raise NotImplementedError

def dist(self, x, y, c):
"""Calculate the squared geodesic/distance between x and y.
Args:
x (tensor): A point on the manifold.
y (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: the geodesic/distance between x and y.
"""
raise NotImplementedError

def egrad2rgrad(self, grad, x, c):
"""Computes Riemannian gradient from the Euclidean gradient, typically used in Riemannian optimizers.
Args:
grad (tensor): Euclidean gradient at x.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: Riemannian gradient at x.
"""
raise NotImplementedError

def inner(self, v1, v2, x, c, keep_shape):
"""Computes the inner product of a pair of tangent vectors v1 and v2 at x.
Args:
v1 (tensor): A tangent point at x.
v2 (tensor): A tangent point at x.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
keep_shape (bool, optional): Whether the output tensor keeps shape or not.
Returns:
tensor: The inner product of v1 and v2 at x.
"""
raise NotImplementedError

def retraction(self, v, x, c):
"""Retraction is a continuous map function from tangent space to the manifold, typically used in Riemannian optimizers.
The exp map is one of retraction functions.
Args:
v (tensor): A tangent point at x.
x (tensor): A point on the manifold.
c (tensor): The manifold curvature.
Returns:
tensor: The result of mapping tangent point v at x to the manifold.
"""
return self.proj(self.expmap(v, x, c), c)
61 changes: 61 additions & 0 deletions manifolds/euclidean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Euclidean manifold."""

import torch
from manifolds.base import Manifold


class Euclidean(Manifold):
"""Euclidean Manifold class. Usually we refer it as R^n.
Attributes:
name (str): The manifold name, and its value is "Euclidean".
"""

def __init__(self, **kwargs):
"""Initialize an Euclidean manifold.
Args:
**kwargs: Description
"""
super(Euclidean, self).__init__(**kwargs)
self.name = 'Euclidean'

def proj(self, x, c):
return x

def proj_tan(self, v, x, c):
return v

def proj_tan0(self, v, c):
return v

def expmap(self, v, x, c):
return v + x

def expmap0(self, v, c):
return v

def logmap(self, y, x, c):
return y - x

def logmap0(self, y, c):
return y

def ptransp(self, v, x, y, c):
return torch.ones_like(x) * v

def ptransp0(self, v, x, c):
return torch.ones_like(x) * v

def dist(self, x, y, c):
sqdis = torch.sum((x - y).pow(2), dim=-1)
return sqdis.sqrt()

def egrad2rgrad(self, grad, x, c):
return grad

def inner(self, v1, v2, x, c, keep_shape=False):
if keep_shape:
# In order to keep the same computation logic in Ada* Optimizer
return v1 * v2
else:
return torch.sum(v1 * v2, dim=-1, keepdim=False)
Loading

0 comments on commit 30bd802

Please sign in to comment.