Skip to content

Commit

Permalink
Merge pull request campbio#52 from joshua-d-campbell/devel
Browse files Browse the repository at this point in the history
Incorporating updates to recursiveSplit
  • Loading branch information
zhewa authored Mar 19, 2019
2 parents bb0262e + f818100 commit 6597dd0
Show file tree
Hide file tree
Showing 18 changed files with 400 additions and 146 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ import(scales)
useDynLib(celda,"_colSumByGroup")
useDynLib(celda,"_colSumByGroupChange")
useDynLib(celda,"_colSumByGroup_numeric")
useDynLib(celda,"_perplexityG")
useDynLib(celda,"_rowSumByGroup")
useDynLib(celda,"_rowSumByGroupChange")
useDynLib(celda,"_rowSumByGroup_numeric")
2 changes: 1 addition & 1 deletion R/celdaGridSearch.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ selectBestModel = function(celda.list) {
if (!methods::is(celda.list, "celdaList")) stop("celda.list parameter was not of class celdaList.")

log_likelihood = NULL
group = setdiff(colnames(celda.list@run.params), c("index", "chain", "log_likelihood"))
group = setdiff(colnames(celda.list@run.params), c("index", "chain", "log_likelihood", "mean_perplexity"))
dt = data.table::as.data.table(celda.list@run.params)
new.run.params = as.data.frame(dt[,.SD[which.max(log_likelihood)], by=group])
new.run.params = new.run.params[,colnames(celda.list@run.params)]
Expand Down
12 changes: 9 additions & 3 deletions R/celda_C.R
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,18 @@ cC.calcLL = function(m.CP.by.S, n.G.by.CP, s, z, K, nS, nG, alpha, beta) {
#' @return Numeric. The log likelihood for the given cluster assignments
#' @seealso `celda_C()` for clustering cells
#' @examples
#' loglik = logLikelihood(celda.C.sim$counts, model="celda_C",
#' loglik = logLikelihood.celda_C(celda.C.sim$counts,
#' sample.label=celda.C.sim$sample.label,
#' z=celda.C.sim$z, K=celda.C.sim$K,
#' alpha=celda.C.sim$alpha, beta=celda.C.sim$beta)
#'
#' loglik = logLikelihood(celda.C.sim$counts, model="celda_C",
#' sample.label=celda.C.sim$sample.label,
#' z=celda.C.sim$z, K=celda.C.sim$K,
#' alpha=celda.C.sim$alpha, beta=celda.C.sim$beta)
#'
#' @export
logLikelihood.celda_C = function(counts, model, sample.label, z, K,
logLikelihood.celda_C = function(counts, sample.label, z, K,
alpha, beta) {
if (sum(z > K) > 0) stop("An entry in z contains a value greater than the provided K.")
sample.label = processSampleLabels(sample.label, ncol(counts))
Expand Down Expand Up @@ -533,7 +539,7 @@ setMethod("clusterProbability",
setMethod("perplexity",
signature(celda.mod = "celda_C"),
function(counts, celda.mod, new.counts=NULL) {
compareCountMatrix(counts, celda.mod)

if (!("celda_C" %in% class(celda.mod))) {
stop("The celda.mod provided was not of class celda_C.")
}
Expand Down
10 changes: 9 additions & 1 deletion R/celda_CG.R
Original file line number Diff line number Diff line change
Expand Up @@ -526,12 +526,20 @@ cCG.calcLL = function(K, L, m.CP.by.S, n.TS.by.CP, n.by.G, n.by.TS, nG.by.TS, nS
#' @return The log likelihood for the given cluster assignments
#' @seealso `celda_CG()` for clustering features and cells
#' @examples
#' loglik = logLikelihood(celda.CG.sim$counts, model="celda_CG",
#' loglik = logLikelihood.celda_CG(celda.CG.sim$counts,
#' sample.label=celda.CG.sim$sample.label,
#' z=celda.CG.sim$z, y=celda.CG.sim$y,
#' K=celda.CG.sim$K, L=celda.CG.sim$L,
#' alpha=celda.CG.sim$alpha, beta=celda.CG.sim$beta,
#' gamma=celda.CG.sim$gamma, delta=celda.CG.sim$delta)
#'
#' loglik = logLikelihood(celda.CG.sim$counts, model="celda_CG",
#' sample.label=celda.CG.sim$sample.label,
#' z=celda.CG.sim$z, y=celda.CG.sim$y,
#' K=celda.CG.sim$K, L=celda.CG.sim$L,
#' alpha=celda.CG.sim$alpha, beta=celda.CG.sim$beta,
#' gamma=celda.CG.sim$gamma, delta=celda.CG.sim$delta)
#'
#' @export
logLikelihood.celda_CG = function(counts, sample.label, z, y, K, L, alpha, beta, delta, gamma) {
if (sum(z > K) > 0) stop("An entry in z contains a value greater than the provided K.")
Expand Down
18 changes: 12 additions & 6 deletions R/celda_G.R
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,16 @@ cG.calcLL = function(n.TS.by.C, n.by.TS, n.by.G, nG.by.TS, nM, nG, L, beta, delt
#' @return The log-likelihood for the given cluster assignments
#' @seealso `celda_G()` for clustering features
#' @examples
#' loglik = logLikelihood(celda.G.sim$counts, model="celda_G",
#' loglik = logLikelihood.celda_G(celda.G.sim$counts,
#' y=celda.G.sim$y, L=celda.G.sim$L,
#' beta=celda.G.sim$beta, delta=celda.G.sim$delta,
#' gamma=celda.G.sim$gamma)
#'
#' loglik = logLikelihood(celda.G.sim$counts, model="celda_G",
#' y=celda.G.sim$y, L=celda.G.sim$L,
#' beta=celda.G.sim$beta, delta=celda.G.sim$delta,
#' gamma=celda.G.sim$gamma)
#'
#' @export
logLikelihood.celda_G = function(counts, y, L, beta, delta, gamma) {
if (sum(y > L) > 0) stop("An entry in y contains a value greater than the provided L.")
Expand Down Expand Up @@ -533,15 +539,15 @@ setMethod("perplexity",

factorized = factorizeMatrix(counts = counts, celda.mod = celda.mod,
type=c("posterior", "counts"))
phi <- factorized$posterior$module
psi <- factorized$posterior$cell
psi <- factorized$posterior$module
phi <- factorized$posterior$cell
eta <- factorized$posterior$gene.distribution
nG.by.TS = factorized$counts$gene.distribution

eta.prob = log(eta) * nG.by.TS
gene.by.cell.prob = log(phi %*% psi)
log.px = sum(gene.by.cell.prob * new.counts) # + sum(eta.prob)

# gene.by.cell.prob = log(psi %*% phi)
# log.px = sum(gene.by.cell.prob * new.counts) # + sum(eta.prob)
log.px = perplexityG_logPx(new.counts, phi, psi, celda.mod@clusters$y, celda.mod@params$L)# + sum(eta.prob)
perplexity = exp(-(log.px/sum(new.counts)))
return(perplexity)
})
Expand Down
9 changes: 7 additions & 2 deletions R/initialize_clusters.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ initialize.splitZ = function(counts, K, K.subcluster=NULL, alpha=1, beta=1, min.
z.ta = tabulate(overall.z, max(overall.z))
z.to.split = sample(which(z.ta > min.cell & z.ta > K.to.use))

if(length(z.to.split) == 0) break()

## Cycle through each splitable cluster and split it up into K.sublcusters
for(i in z.to.split) {

Expand Down Expand Up @@ -138,7 +140,8 @@ initialize.splitZ = function(counts, K, K.subcluster=NULL, alpha=1, beta=1, min.
m.CP.by.S = p$m.CP.by.S[-z.to.remove,,drop=FALSE]
overall.z = as.integer(as.factor(overall.z))
current.K = current.K - 1
}

}
return(overall.z)
}

Expand All @@ -165,7 +168,9 @@ initialize.splitY = function(counts, L, L.subcluster=NULL, temp.K=100, beta=1, d
## Determine which clusters are split-able
y.ta = tabulate(overall.y, max(overall.y))
y.to.split = sample(which(y.ta > min.feature & y.ta > L.subcluster))


if(length(y.to.split) == 0) break()

## Cycle through each splitable cluster and split it up into L.sublcusters
for(i in y.to.split) {

Expand Down
7 changes: 7 additions & 0 deletions R/matrixSums.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ colSumByGroup.numeric <- function(x, group, K) {
res <- .Call("_colSumByGroup_numeric", x, group)
return(res)
}

#' @useDynLib celda _perplexityG
perplexityG_logPx <- function(x, phi, psi, group, L) {
group <- factor(group, levels=1:L)
res <- .Call("_perplexityG", x, phi, psi, group)
return(res)
}
44 changes: 28 additions & 16 deletions R/model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@ resamplePerplexity <- function(counts, celda.list, resample=5, seed=12345) {
if (!isTRUE(is.numeric(resample))) stop("Provided resample parameter was not numeric.")

setSeed(seed)
countsList = lapply(1:resample,
function(i){
resampleCountMatrix(counts)
})


perp.res = matrix(NA, nrow=length(celda.list@res.list), ncol=resample)
for(i in 1:length(celda.list@res.list)) {
for(j in 1:resample) {
perp.res[i,j] = perplexity(counts, celda.list@res.list[[i]], countsList[[j]])
for(j in 1:resample) {
new.counts = resampleCountMatrix(counts)
for(i in 1:length(celda.list@res.list)) {
perp.res[i,j] = perplexity(counts, celda.list@res.list[[i]], new.counts)
}
}
celda.list@perplexity = perp.res

## Add mean perplexity to run.params
perp.mean = apply(perp.res, 1, mean)
celda.list@run.params$mean_perplexity=perp.mean

return(celda.list)
}

Expand Down Expand Up @@ -86,14 +87,25 @@ plotGridSearchPerplexity.celda_CG = function(celda.list) {
l.means.by.k$K = as.factor(l.means.by.k$K)
l.means.by.k$L = as.factor(l.means.by.k$L)

plot = ggplot2::ggplot(df, ggplot2::aes_string(x="K", y="perplexity")) +
ggplot2::geom_jitter(height=0, width=0.1, ggplot2::aes_string(color="L")) +
ggplot2::scale_color_discrete(name="L") +
ggplot2::geom_path(data=l.means.by.k,
ggplot2::aes_string(x="K", y="mean_perplexity", group="L", color="L")) +
ggplot2::ylab("Perplexity") +
ggplot2::xlab("K") +
ggplot2::theme_bw()
if(nlevels(df$K) > 1) {
plot = ggplot2::ggplot(df, ggplot2::aes_string(x="K", y="perplexity")) +
ggplot2::geom_jitter(height=0, width=0.1, ggplot2::aes_string(color="L")) +
ggplot2::scale_color_discrete(name="L") +
ggplot2::geom_path(data=l.means.by.k,
ggplot2::aes_string(x="K", y="mean_perplexity", group="L", color="L")) +
ggplot2::ylab("Perplexity") +
ggplot2::xlab("K") +
ggplot2::theme_bw()
} else {
plot = ggplot2::ggplot(df, ggplot2::aes_string(x="L", y="perplexity")) +
ggplot2::geom_jitter(height=0, width=0.1, ggplot2::aes_string(color="K")) +
ggplot2::scale_color_discrete(name="K") +
ggplot2::geom_path(data=l.means.by.k,
ggplot2::aes_string(x="L", y="mean_perplexity", group="K", color="K")) +
ggplot2::ylab("Perplexity") +
ggplot2::xlab("L") +
ggplot2::theme_bw()
}

return(plot)
}
Expand Down
Loading

0 comments on commit 6597dd0

Please sign in to comment.