Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
LogSemiring,
CheckpointSemiring,
CheckpointShardSemiring,
GumbelCRFSemiring,
KMaxSemiring,
SparseMaxSemiring,
MaxSemiring,
Expand Down Expand Up @@ -511,3 +512,16 @@ def test_lc_custom():
# s2 = struct.sum(vals)
# assert torch.isclose(s, s2).all()
# assert torch.isclose(marginals, marginals2).all()


@given(data())
def test_gumbel(data):
model = data.draw(sampled_from([LinearChain, SemiMarkov, DepTree]))
semiring = GumbelCRFSemiring(1.0)
test = test_lookup[model]()
struct = model(semiring)
vals, (batch, N) = test._rand()
vals.requires_grad_(True)
alpha = struct.marginals(vals)
print(alpha[0])
print(torch.autograd.grad(alpha, vals, alpha.detach())[0][0])
8 changes: 8 additions & 0 deletions torch_struct/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
GumbelCRFSemiring,
)


Expand Down Expand Up @@ -183,6 +184,13 @@ def count(self):
ones[self.log_potentials.eq(-float("inf"))] = 0
return self._struct(StdSemiring).sum(ones, self.lengths)

def gumbel_crf(self, temperature=1.0):
with torch.enable_grad():
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
self.log_potentials, self.lengths
)
return st_gumbel

# @constraints.dependent_property
# def support(self):
# pass
Expand Down
9 changes: 8 additions & 1 deletion torch_struct/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def backward(ctx, grad_v):

return DPManual.apply(edge)

def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
"""
Compute the marginals of a structured model.

Expand Down Expand Up @@ -135,6 +135,13 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
elif _combine:
obj = v.sum(dim=0).sum(dim=0)
marg = torch.autograd.grad(
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
a_m = self._arrange_marginals(marg)
return a_m
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
Expand Down
100 changes: 100 additions & 0 deletions torch_struct/semirings/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,106 @@ def sum(xs, dim=-1):
return _SampledLogSumExp.apply(xs, dim)


def GumbelMaxSemiring(temp):
class _GumbelMaxLogSumExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
ctx.save_for_backward(input, torch.tensor(dim))
return torch.logsumexp(input, dim=dim)

@staticmethod
def backward(ctx, grad_output):
logits, dim = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:

def sample(ls):
pre_shape = ls.shape
update = (
ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
) / temp
out = torch.nn.functional.one_hot(update.max(-1)[1], pre_shape[-1])
return out

if dim == -1:
s = sample(logits)
else:
dim = dim if dim >= 0 else logits.dim() + dim
perm = [i for i in range(logits.dim()) if i != dim] + [dim]
rev_perm = [
a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
]
s = sample(logits.permute(perm)).permute(rev_perm)

grad_input = grad_output.unsqueeze(dim).mul(s)
return grad_input, None

class _GumbelMaxSemiring(_BaseLog):
@staticmethod
def sum(xs, dim=-1):
return _GumbelMaxLogSumExp.apply(xs, dim)

return _GumbelMaxSemiring


def GumbelCRFSemiring(temp):
class ST(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, dim):
out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
out = out.type_as(logits)
ctx.save_for_backward(logits, out)
return out

@staticmethod
def backward(ctx, grad_output):
logits, out = ctx.saved_tensors
with torch.enable_grad():
ret = torch.autograd.grad(
logits.softmax(-1), logits, out * grad_output
)[0]
return ret, None

class _GumbelCRFLogSumExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
ctx.save_for_backward(input, torch.tensor(dim))
return torch.logsumexp(input, dim=dim)

@staticmethod
def backward(ctx, grad_output):
logits, dim = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:

def sample(ls):
update = (
ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1],))
) / temp
out = ST.apply(update, ls.shape[-1])
return out

if dim == -1:
s = sample(logits)
else:
dim = dim if dim >= 0 else logits.dim() + dim
perm = [i for i in range(logits.dim()) if i != dim] + [dim]
rev_perm = [
a for a, b in sorted(enumerate(perm), key=lambda a: a[1])
]
s = sample(logits.permute(perm)).permute(rev_perm)

grad_input = grad_output.unsqueeze(dim).mul(s)
return grad_input, None

class _GumbelCRFSemiring(_BaseLog):
@staticmethod
def sum(xs, dim=-1):
return _GumbelCRFLogSumExp.apply(xs, dim)

return _GumbelCRFSemiring


bits = torch.tensor([pow(2, i) for i in range(1, 18)])


Expand Down
7 changes: 5 additions & 2 deletions torch_struct/semirings/semirings.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,12 @@ class KLDivergenceSemiring(Semiring):

Based on descriptions in:

* Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
* Parameter estimation for probabilistic finite-state
transducers :cite:`eisner2002parameter`
* First-and second-order expectation semirings with applications to
minimumrisk training on translation forests :cite:`li2009first`
* Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`

"""

zero = 0
Expand Down