38 changes: 22 additions & 16 deletions permutect/architecture/artifact_spectra.py
@@ -54,25 +54,29 @@ def __init__(self, num_components: int):
     n and k are 1D tensors, the only dimension being batch.
     '''
     def forward(self, types_one_hot_bv, depths_b, alt_counts_b):
+        device = next(self.parameters()).device
+        depths_b = depths_b.to(device)
+        types_one_hot_bv = types_one_hot_bv.to(device)
+        alt_counts_b = alt_counts_b.to(device)
         alt_counts_bk = torch.unsqueeze(alt_counts_b, dim=1).expand(-1, self.K - 1)
         depths_bk = torch.unsqueeze(depths_b, dim=1).expand(-1, self.K - 1)
         depths_bvk = depths_bk[:, None, :]

-        self.alpha0_pre_exp_vk
-        eta_vk = torch.exp(self.eta_pre_exp_vk)
-        delta_vk = torch.exp(self.delta_pre_exp_vk)
-        alpha0_pre_exp_bvk, eta_bvk, delta_bvk = self.alpha0_pre_exp_vk[None, :, :], eta_vk[None, :, :], delta_vk[None, :, :]
+        self.alpha0_pre_exp_vk.to(device)
+        eta_vk = torch.exp(self.eta_pre_exp_vk).to(device)
+        delta_vk = torch.exp(self.delta_pre_exp_vk).to(device)
+        alpha0_pre_exp_bvk, eta_bvk, delta_bvk = self.alpha0_pre_exp_vk[None, :, :].to(device), eta_vk[None, :, :].to(device), delta_vk[None, :, :].to(device)
         weights0_pre_softmax_bvk, gamma_bvk, kappa_bvk = self.weights0_pre_softmax_vk[None, :, :], self.gamma_vk[None, :, :], self.kappa_vk[None, :, :]

-        alpha_bvk = torch.exp(alpha0_pre_exp_bvk - eta_bvk * torch.sigmoid(depths_bvk * delta_bvk))
+        alpha_bvk = torch.exp(alpha0_pre_exp_bvk - eta_bvk * torch.sigmoid(depths_bvk * delta_bvk).to(device))

         types_one_hot_bvk = torch.unsqueeze(types_one_hot_bv, dim=-1)  # gives it broadcastable length-1 component dimension
         alpha_bk = torch.sum(types_one_hot_bvk * alpha_bvk, dim=1)  # due to one-hotness only one v contributes to the sum
         beta_bk = self.beta * torch.ones_like(alpha_bk)

         beta_binomial_likelihoods_bk = beta_binomial(depths_bk, alt_counts_bk, alpha_bk, beta_bk)

-        weights_pre_softmax_bvk = weights0_pre_softmax_bvk + gamma_bvk * torch.sigmoid(depths_bvk * kappa_bvk)
+        weights_pre_softmax_bvk = weights0_pre_softmax_bvk + gamma_bvk * torch.sigmoid(depths_bvk * kappa_bvk).to(device)

         log_weights_bvk = torch.log_softmax(weights_pre_softmax_bvk, dim=-1)  # softmax over component dimension
         log_weights_bk = torch.sum(types_one_hot_bvk * log_weights_bvk, dim=1)  # same idea as above
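
One subtlety in this hunk: `Tensor.to(device)` is out-of-place, so the bare `self.alpha0_pre_exp_vk.to(device)` statement has no effect on its own; the later `[None, :, :].to(device)` expressions are what actually move the values. A minimal sketch of the input-coercion idiom the hunk applies, using a hypothetical module rather than the repo's class:

```python
import torch
from torch import nn

class DeviceSafeSpectrum(nn.Module):
    # Hypothetical stand-in illustrating the pattern in forward() above.
    def __init__(self):
        super().__init__()
        self.alpha0_pre_exp = nn.Parameter(torch.zeros(3))

    def forward(self, depths_b: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device
        depths_b = depths_b.to(device)  # assign the result: .to() is not in-place
        # parameters already live on `device`, so they need no further moves
        return torch.exp(self.alpha0_pre_exp - torch.sigmoid(depths_b[:, None]))
```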
@@ -99,24 +103,26 @@ def fit(self, num_epochs, types_one_hot_2d, depths_1d_tensor, alt_counts_1d_tens
     get raw data for a spectrum plot of probability density vs allele fraction for a particular variant type
     '''
     def spectrum_density_vs_fraction(self, variant_type: Variation, depth: int):
-        fractions_f = torch.arange(0.01, 0.99, 0.001)  # 1D tensor
+        device = self.weights0_pre_softmax_vk.device  # Get device from model parameter
+
+        fractions_f = torch.arange(0.01, 0.99, 0.001, device=device)  # 1D tensor

-        weights0_pre_softmax_k = self.weights0_pre_softmax_vk[variant_type]
-        gamma_k = self.gamma_vk[variant_type]
-        kappa_k = self.kappa_vk[variant_type]
-        alpha0_pre_exp_k = self.alpha0_pre_exp_vk[variant_type]
-        eta_k = torch.exp(self.eta_pre_exp_vk[variant_type])
-        delta_k = torch.exp(self.delta_pre_exp_vk[variant_type])
+        weights0_pre_softmax_k = self.weights0_pre_softmax_vk[variant_type].to(device)
+        gamma_k = self.gamma_vk[variant_type].to(device)
+        kappa_k = self.kappa_vk[variant_type].to(device)
+        alpha0_pre_exp_k = self.alpha0_pre_exp_vk[variant_type].to(device)
+        eta_k = torch.exp(self.eta_pre_exp_vk[variant_type].to(device))
+        delta_k = torch.exp(self.delta_pre_exp_vk[variant_type].to(device))

         alpha_k = torch.exp(alpha0_pre_exp_k - eta_k * torch.sigmoid(depth * delta_k))
         weights_pre_softmax_k = weights0_pre_softmax_k + gamma_k * torch.sigmoid(depth * kappa_k)

         log_weights_k = torch.log_softmax(weights_pre_softmax_k, dim=0)
-        beta_k = self.beta * torch.ones_like(alpha_k)
+        beta_k = self.beta * torch.ones_like(alpha_k, device=device)

-        dist = torch.distributions.beta.Beta(alpha_k, beta_k)
+        dist = torch.distributions.beta.Beta(alpha_k.to(device), beta_k.to(device))

-        log_densities_fk = dist.log_prob(fractions_f.unsqueeze(dim=-1))
+        log_densities_fk = dist.log_prob(fractions_f.unsqueeze(dim=-1).to(device))
         log_weights_fk = log_weights_k.unsqueeze(dim=0)

         log_weighted_densities_fk = log_weights_fk + log_densities_fk
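
For context, the `beta_binomial` call above evaluates a per-datum beta-binomial log-likelihood. A sketch of the standard lgamma-based form, assuming the repo's helper uses this parameterization (argument order and normalization may differ):

```python
import torch

def beta_binomial_log_pmf(n: torch.Tensor, k: torch.Tensor,
                          alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta)
    log_choose = torch.lgamma(n + 1) - torch.lgamma(k + 1) - torch.lgamma(n - k + 1)
    log_beta_num = torch.lgamma(k + alpha) + torch.lgamma(n - k + beta) - torch.lgamma(n + alpha + beta)
    log_beta_den = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    return log_choose + log_beta_num - log_beta_den
```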
6 changes: 6 additions & 0 deletions permutect/architecture/mlp.py
@@ -57,6 +57,12 @@ def __init__(self, layer_sizes: List[int], batch_normalize: bool = False, dropou
         self._output_dim = input_dim
         self._model = nn.Sequential(*layers)

+    def forward(self, x: Tensor) -> Tensor:
+        # Ensure input is on same device as model
+        device = next(self.parameters()).device
+        x = x.to(device)
+        return self._model.forward(x)
+
     def input_dimension(self) -> int:
         return self._input_dim

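
With the new guard, an `MLP` moved to an accelerator tolerates CPU-resident batches. A hedged usage sketch, assuming CUDA is available and that `layer_sizes[0]` is the input dimension:

```python
import torch

mlp = MLP(layer_sizes=[10, 20, 1]).to("cuda")  # parameters now on the GPU
x_cpu = torch.randn(4, 10)                     # batch left on the CPU
y = mlp(x_cpu)                                 # forward() moves the batch before nn.Sequential runs
```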
17 changes: 11 additions & 6 deletions permutect/architecture/normal_seq_error_spectrum.py
@@ -30,26 +30,31 @@ def __init__(self, num_samples: int, max_mean: float):
         self.mean_pre_sigmoid = torch.nn.Parameter(torch.tensor(0.0))

     def forward(self, alt_counts_1d: torch.Tensor, ref_counts_1d: torch.Tensor):
+        device = alt_counts_1d.device
+
         batch_size = len(alt_counts_1d)
-        fractions_2d = self.get_fractions(batch_size, self.num_samples)
+        fractions_2d = self.get_fractions(batch_size, self.num_samples).to(device)

-        log_likelihoods_2d = torch.reshape(alt_counts_1d, (batch_size, 1)) * torch.log(fractions_2d) \
-            + torch.reshape(ref_counts_1d, (batch_size, 1)) * torch.log(1 - fractions_2d)
+        log_likelihoods_2d = torch.reshape(alt_counts_1d, (batch_size, 1)) * torch.log(fractions_2d).to(device) \
+            + torch.reshape(ref_counts_1d.to(device), (batch_size, 1)) * torch.log(1 - fractions_2d).to(device)

         # average over sample dimension
-        log_likelihoods_1d = torch.logsumexp(log_likelihoods_2d, dim=1) - math.log(self.num_samples)
+        log_likelihoods_1d = torch.logsumexp(log_likelihoods_2d, dim=1).to(device) - math.log(self.num_samples)

-        combinatorial_term = torch.lgamma(alt_counts_1d + ref_counts_1d + 1) - torch.lgamma(alt_counts_1d + 1) - torch.lgamma(ref_counts_1d + 1)
+        combinatorial_term = torch.lgamma(alt_counts_1d + ref_counts_1d.to(device) + 1).to(device) - torch.lgamma(alt_counts_1d + 1).to(device) - torch.lgamma(ref_counts_1d.to(device) + 1).to(device)

         return combinatorial_term + log_likelihoods_1d

     def get_mean(self):
         return torch.sigmoid(self.mean_pre_sigmoid) * self.max_mean

     def get_fractions(self, batch_size, num_samples):
+        # Get the device from the model parameter
+        device = self.mean_pre_sigmoid.device
+
         actual_mean = torch.sigmoid(self.mean_pre_sigmoid) * self.max_mean
         actual_sigma = SQRT_PI_OVER_2 * actual_mean
-        normal_samples = torch.randn(batch_size, num_samples)
+        normal_samples = torch.randn(batch_size, num_samples, device=device)
         half_normal_samples = torch.abs(normal_samples)
         fractions_2d_unbounded = actual_sigma * half_normal_samples
         # apply tanh to constrain fractions to [0, 1), and then to [EPSILON, 1 - EPSILON] for numerical stability
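
`get_fractions` now draws its Gaussian samples directly on the parameter's device instead of sampling on the CPU and copying. A standalone sketch of the half-normal draw, assuming `SQRT_PI_OVER_2` encodes the usual half-normal identity (mean = sigma * sqrt(2/pi), so sigma = mean * sqrt(pi/2)):

```python
import math
import torch

def half_normal_fractions(batch_size: int, num_samples: int, mean: torch.Tensor) -> torch.Tensor:
    sigma = math.sqrt(math.pi / 2) * mean  # invert the half-normal mean identity
    # sample on the device where `mean` lives to avoid CPU/GPU mixing
    normal_samples = torch.randn(batch_size, num_samples, device=mean.device)
    return sigma * torch.abs(normal_samples)
```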
16 changes: 10 additions & 6 deletions permutect/architecture/overdispersed_binomial_mixture.py
@@ -91,17 +91,19 @@ def forward(self, x, n, k):
         assert n.size() == k.size()
         assert len(x) == len(n)
         assert x.size()[1] == self.input_size
+        device = x.device

-        log_weights = log_softmax(self.weights_pre_softmax(x), dim=1)
+        log_weights = log_softmax(self.weights_pre_softmax(x), dim=1).to(device)

         # we make them 2D, with 1st dim batch, to match alpha and beta. A single column is OK because the single value of
         # n/k are broadcast over all mixture components
-        n_2d = unsqueeze(n, dim=1)
-        k_2d = unsqueeze(k, dim=1)
+        n_2d = unsqueeze(n, dim=1).to(device)
+        k_2d = unsqueeze(k, dim=1).to(device)

         # 2D tensors -- 1st dim batch, 2nd dim mixture component
         means = self.max_mean * torch.sigmoid(self.mean_pre_sigmoid(x))
-        concentrations = self.get_concentration(x)
+        means = means.to(device)
+        concentrations = self.get_concentration(x).to(device)

         if self.mode == 'beta':
             alphas = means * concentrations
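
For reference, 'beta' mode reparameterizes each mixture component by a mean in (0, max_mean) and a concentration. The sketch below assumes the conventional counterpart `betas = (1 - means) * concentrations`, which lies past the truncated edge of this hunk:

```python
import torch

means = torch.tensor([[0.05, 0.4]])           # per-component means, one batch row
concentrations = torch.tensor([[50.0, 8.0]])  # larger concentration = less overdispersion
alphas = means * concentrations               # as in the hunk above
betas = (1.0 - means) * concentrations        # assumed counterpart (not shown in the diff)
```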
@@ -212,8 +214,10 @@ def fit(self, num_epochs, inputs_2d_tensor, depths_1d_tensor, alt_counts_1d_tens
     here x is a 1D tensor, a single datum/row of the 2D tensors as above
     '''
     def spectrum_density_vs_fraction(self, variant_type: Variation, depth: int):
-        fractions = torch.arange(0.01, 0.99, 0.001)  # 1D tensor
-        x = torch.from_numpy(variant_type.one_hot_tensor()).float()
+        device = next(self.weights_pre_softmax.parameters()).device  # Get device from model parameter
+
+        fractions = torch.arange(0.01, 0.99, 0.001, device=device)  # 1D tensor
+        x = torch.from_numpy(variant_type.one_hot_tensor()).float().to(device)

         unsqueezed = x.unsqueeze(dim=0)  # this and the three following tensors are 2D tensors with one row
         log_weights = log_softmax(self.weights_pre_softmax(unsqueezed).detach(), dim=1)
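
The same read-the-device-off-a-parameter idiom anchors `spectrum_density_vs_fraction` here as well; a minimal illustration with a stand-in linear layer in place of the model's mixture-weight head:

```python
import torch
from torch import nn

weights_pre_softmax = nn.Linear(4, 3)  # stand-in for the model's weights_pre_softmax module
device = next(weights_pre_softmax.parameters()).device
fractions = torch.arange(0.01, 0.99, 0.001, device=device)  # allocated where the layer lives
```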