Commit 30bd802 (parent f3f0690): 18 changed files with 2,327 additions and 0 deletions.
New file (+53 lines): argparse configuration for the data, training, and model settings.
import argparse

from utils.train_util import add_flags_from_config


config_args = {
    'data_config': {
        'seed': (2022, 'manual random seed'),
        'cuda': (0, 'which cuda device to use'),
        'dataset': ('20ng', 'which dataset to use'),
        'data-path': ('./data/20NG/20ng.pkl', 'path to load data'),
        'batch-size': (200, 'number of examples in a mini-batch'),
    },
    'training_config': {
        'epochs': (200, 'number of epochs to train'),
        'lr': (0.01, 'initial learning rate'),
        'lr-reduce-freq': (None, 'reduce lr every lr-reduce-freq epochs, or None to keep lr constant'),
        'gamma': (0.1, 'gamma for lr scheduler'),
        'dropout': (0, 'dropout probability'),
        'momentum': (0.999, 'momentum in optimizer'),
        'weight-decay': (1e-5, 'l2 regularization strength'),
        'optimizer': ('Adam', 'which optimizer to use, can be any of [Adam, RiemannianAdam]'),
        'grad-clip': (None, 'max norm for gradient clipping, or None for no gradient clipping'),
        'log-freq': (1, 'how often to compute and print train/val metrics (in epochs)'),
        'eval-freq': (50, 'how often to compute val metrics (in epochs)'),
        'save': (1, '1 to save model and logs and 0 otherwise'),
        'save-dir': (None, 'path to save training logs and model weights (defaults to logs/task/date/run/)'),
        'sweep-c': (0, ''),
        'print-epoch': (True, ''),
    },
    'model_config': {
        # topic model
        'vocab-size': (2000, 'vocabulary size'),  # 13368
        'embed-size': (2, 'dimensionality of word and topic embeddings'),  # (683, 366, 84, 11, 2) # (810, 408, 91, 13, 2)
        'num-topics-list': ([185, 66, 11, 2], 'number of topics in each latent layer'),  # (560, 325, 83, 12, 2)
        'num-hiddens-list': ([300, 300, 300, 300], 'number of units in each hidden layer'),
        'pretrained-embeddings': (False, 'whether to use pretrained embeddings to initialize words and topics'),
        'manifold': ('PoincareBall', 'which manifold to use, can be any of [Euclidean, Hyperboloid, PoincareBall]'),
        'c': (-1.0, 'hyperbolic radius, set to None for trainable curvature'),
        'clip_r': (8.0, 'radius for feature clipping, which helps avoid the vanishing-gradient problem'),
        # hyperbolic gcn
        'add-knowledge': (True, 'whether to inject prior knowledge into topic modeling'),
        'file-path': ('./data/20NG/20ng_wordnet_tree_4layers.pkl', 'path to load tree knowledge'),
        'gcn-layers': (2, 'number of hidden layers in graph encoder'),
        'bias': (1, 'whether to use bias (1) or not (0)'),
        'use-att': (1, 'whether to use hyperbolic attention or not'),
        'local-agg': (0, 'whether to use local tangent space aggregation or not'),
        'act': ('relu', 'which activation function to use (or None for no activation)'),
        'double-precision': ('0', 'whether to use double precision'),
    },
}


parser = argparse.ArgumentParser()
for _, config_dict in config_args.items():
    parser = add_flags_from_config(parser, config_dict)
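The entries above are (default, help) pairs that add_flags_from_config turns into command-line flags. That helper lives in utils/train_util.py and is not part of this excerpt; the following is a minimal hypothetical sketch of what it presumably does, with the handling of list and None defaults being assumptions:

import argparse

def add_flags_from_config(parser, config_dict):
    """Hypothetical sketch: add one --flag per (default, help) entry."""
    for name, (default, help_str) in config_dict.items():
        if isinstance(default, list):
            # List defaults become multi-value flags of the element type.
            elem_type = type(default[0]) if default else float
            parser.add_argument(f"--{name}", nargs='+', type=elem_type,
                                default=default, help=help_str)
        elif default is None:
            # No type information available: accept a raw string.
            parser.add_argument(f"--{name}", default=None, help=help_str)
        else:
            # NB: type=bool is a known argparse caveat; a real helper would
            # parse boolean strings explicitly.
            parser.add_argument(f"--{name}", type=type(default),
                                default=default, help=help_str)
    return parser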
New file (+3 lines): manifolds/__init__.py (package re-exports).
from .euclidean import Euclidean
from .hyperboloid import Hyperboloid
from .poincare import PoincareBall
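With these re-exports in place, a training script can resolve the class named by the --manifold flag directly from the package. A usage sketch (illustrative, not part of the commit):

import manifolds

# Resolve the manifold class by name, as the --manifold flag would supply it.
# Euclidean is shown because its constructor appears below; the others are
# assumed to take only optional keyword arguments as well.
manifold = getattr(manifolds, 'Euclidean')()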
New file (+170 lines): manifolds/base.py (abstract Manifold interface).
"""Base manifold.""" | ||
|
||
import torch | ||
|
||
|
||
class Manifold(object): | ||
"""Abstract class to define basic operations on a manifold. | ||
Attributes: | ||
clip (function): Clips tensor values to a specified min and max. | ||
dtype: The type of the variables. | ||
eps (float): A small constant value. | ||
max_norm (float): The maximum value for number clipping. | ||
min_norm (float): The minimum value for number clipping. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
"""Initialize a manifold. | ||
""" | ||
super(Manifold, self).__init__() | ||
|
||
self.min_norm = 1e-15 | ||
self.max_norm = 1e15 | ||
self.eps = 1e-5 | ||
|
||
self.dtype = kwargs["dtype"] if "dtype" in kwargs else torch.float32 | ||
self.clip = lambda x: torch.clamp(x, min=self.min_norm, max=self.max_norm) | ||
|
||
def proj(self, x, c): | ||
"""A projection function that prevents x from leaving the manifold. | ||
Args: | ||
x (tensor): A point should be on the manifold, but it may not meet the manifold constraints. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: A projected point, meeting the manifold constraints. | ||
""" | ||
raise NotImplementedError | ||
|
||
def proj_tan(self, v, x, c): | ||
"""A projection function that prevents v from leaving the tangent space of point x. | ||
Args: | ||
v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: A projected point, meeting the tangent space constraints. | ||
""" | ||
raise NotImplementedError | ||
|
||
def proj_tan0(self, v, c): | ||
"""A projection function that prevents v from leaving the tangent space of origin point. | ||
Args: | ||
v (tensor): A point should be on the tangent space, but it may not meet the manifold constraints. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: A projected point, meeting the tangent space constraints. | ||
""" | ||
raise NotImplementedError | ||
|
||
def expmap(self, v, x, c): | ||
"""Map a point v in the tangent space of point x to the manifold. | ||
Args: | ||
v (tensor): A point in the tangent space of point x. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of mapping tangent point v to the manifold. | ||
""" | ||
raise NotImplementedError | ||
|
||
def expmap0(self, v, c): | ||
"""Map a point v in the tangent space of origin point to the manifold. | ||
Args: | ||
v (tensor): A point in the tangent space of origin point. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of mapping tangent point v to the manifold. | ||
""" | ||
raise NotImplementedError | ||
|
||
def logmap(self, y, x, c): | ||
"""Map a point y on the manifold to the tangent space of x. | ||
Args: | ||
y (tensor): A point on the manifold. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of mapping y to the tangent space of x. | ||
""" | ||
raise NotImplementedError | ||
|
||
def logmap0(self, y, c): | ||
"""Map a point y on the manifold to the tangent space of origin point. | ||
Args: | ||
y (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of mapping y to the tangent space of origin point. | ||
""" | ||
raise NotImplementedError | ||
|
||
def ptransp(self, v, x, y, c): | ||
"""Parallel transport function, used to move point v in the tangent space of x to the tangent space of y. | ||
Args: | ||
v (tensor): A point in the tangent space of x. | ||
x (tensor): A point on the manifold. | ||
y (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of transporting v from the tangent space at x to the tangent space at y. | ||
""" | ||
raise NotImplementedError | ||
|
||
def ptransp0(self, v, x, c): | ||
"""Parallel transport function, used to move point v in the tangent space of origin point to the tangent space of y. | ||
Args: | ||
v (tensor): A point in the tangent space of origin point. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of transporting v from the tangent space at origin point to the tangent space at y. | ||
""" | ||
raise NotImplementedError | ||
|
||
def dist(self, x, y, c): | ||
"""Calculate the squared geodesic/distance between x and y. | ||
Args: | ||
x (tensor): A point on the manifold. | ||
y (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: the geodesic/distance between x and y. | ||
""" | ||
raise NotImplementedError | ||
|
||
def egrad2rgrad(self, grad, x, c): | ||
"""Computes Riemannian gradient from the Euclidean gradient, typically used in Riemannian optimizers. | ||
Args: | ||
grad (tensor): Euclidean gradient at x. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: Riemannian gradient at x. | ||
""" | ||
raise NotImplementedError | ||
|
||
def inner(self, v1, v2, x, c, keep_shape): | ||
"""Computes the inner product of a pair of tangent vectors v1 and v2 at x. | ||
Args: | ||
v1 (tensor): A tangent point at x. | ||
v2 (tensor): A tangent point at x. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
keep_shape (bool, optional): Whether the output tensor keeps shape or not. | ||
Returns: | ||
tensor: The inner product of v1 and v2 at x. | ||
""" | ||
raise NotImplementedError | ||
|
||
def retraction(self, v, x, c): | ||
"""Retraction is a continuous map function from tangent space to the manifold, typically used in Riemannian optimizers. | ||
The exp map is one of retraction functions. | ||
Args: | ||
v (tensor): A tangent point at x. | ||
x (tensor): A point on the manifold. | ||
c (tensor): The manifold curvature. | ||
Returns: | ||
tensor: The result of mapping tangent point v at x to the manifold. | ||
""" | ||
return self.proj(self.expmap(v, x, c), c) |
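The egrad2rgrad/retraction pair is exactly what a Riemannian optimizer needs for a descent step. A hypothetical sketch of such a step (not part of this commit; the helper name riemannian_sgd_step is illustrative):

import torch

def riemannian_sgd_step(manifold, param, lr, c):
    """One illustrative descent step on `manifold` for a parameter tensor.

    Converts the Euclidean gradient stored in param.grad into a Riemannian
    gradient, then retracts the scaled step back onto the manifold.
    """
    with torch.no_grad():
        rgrad = manifold.egrad2rgrad(param.grad, param, c)
        param.copy_(manifold.retraction(-lr * rgrad, param, c))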
New file (+61 lines): manifolds/euclidean.py.
"""Euclidean manifold.""" | ||
|
||
import torch | ||
from manifolds.base import Manifold | ||
|
||
|
||
class Euclidean(Manifold): | ||
"""Euclidean Manifold class. Usually we refer it as R^n. | ||
Attributes: | ||
name (str): The manifold name, and its value is "Euclidean". | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
"""Initialize an Euclidean manifold. | ||
Args: | ||
**kwargs: Description | ||
""" | ||
super(Euclidean, self).__init__(**kwargs) | ||
self.name = 'Euclidean' | ||
|
||
def proj(self, x, c): | ||
return x | ||
|
||
def proj_tan(self, v, x, c): | ||
return v | ||
|
||
def proj_tan0(self, v, c): | ||
return v | ||
|
||
def expmap(self, v, x, c): | ||
return v + x | ||
|
||
def expmap0(self, v, c): | ||
return v | ||
|
||
def logmap(self, y, x, c): | ||
return y - x | ||
|
||
def logmap0(self, y, c): | ||
return y | ||
|
||
def ptransp(self, v, x, y, c): | ||
return torch.ones_like(x) * v | ||
|
||
def ptransp0(self, v, x, c): | ||
return torch.ones_like(x) * v | ||
|
||
def dist(self, x, y, c): | ||
sqdis = torch.sum((x - y).pow(2), dim=-1) | ||
return sqdis.sqrt() | ||
|
||
def egrad2rgrad(self, grad, x, c): | ||
return grad | ||
|
||
def inner(self, v1, v2, x, c, keep_shape=False): | ||
if keep_shape: | ||
# In order to keep the same computation logic in Ada* Optimizer | ||
return v1 * v2 | ||
else: | ||
return torch.sum(v1 * v2, dim=-1, keepdim=False) |
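Because every operation above reduces to ordinary vector arithmetic, the Euclidean class is easy to sanity-check. For instance (an illustrative snippet, not part of the commit):

import torch
from manifolds.euclidean import Euclidean

m = Euclidean()
x, y = torch.randn(4, 3), torch.randn(4, 3)
c = None  # curvature is unused by the Euclidean manifold

# expmap and logmap are inverses: exp_x(log_x(y)) == y.
assert torch.allclose(m.expmap(m.logmap(y, x, c), x, c), y)
# dist is the ordinary L2 distance.
assert torch.allclose(m.dist(x, y, c), (x - y).norm(dim=-1))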