Here I implement the algorithm described in a previous note and compare results with FLASH.
Click “Code” to view the implementation.
# INITIALIZATION FUNCTIONS ------------------------------------------
# Convert the k-th factor/loading pair of a FLASH fit into an "altfl" list.
#
# Args:
#   data: data object, passed straight through to the flashr internals.
#   fl:   FLASH fit. The fields tau, gl, gf, EL, EL2, EF, EF2, KL_l and KL_f
#         are read.
#   k:    index of the factor/loading pair to extract.
#
# Returns:
#   A list with the fields tau, Rk, R2k, al, pi0l, af, pi0f, wl, mul, s2l,
#   wf, muf, s2f and KL (in that order).
fl_to_altfl <- function(data, fl, k) {
  altfl <- list()

  altfl$tau <- fl$tau
  altfl$Rk <- flashr:::flash_get_Rk(data, fl, k)
  altfl$R2k <- flashr:::flash_get_R2k(data, fl, k)

  # Prior parameters are copied over from the fitted priors of factor k.
  altfl$al <- fl$gl[[k]]$a
  altfl$pi0l <- fl$gl[[k]]$pi0
  altfl$af <- fl$gf[[k]]$a
  altfl$pi0f <- fl$gf[[k]]$pi0

  # Rk is needed for both halves below; reuse the value computed above
  # instead of calling flash_get_Rk again.
  Rk <- altfl$Rk

  # Loadings: build (x, s) from the factor moments, then ask ebnm for the
  # posterior weight, conditional mean and conditional variance.
  s2 <- 1 / (fl$EF2[, k] %*% t(fl$tau))
  s <- sqrt(s2)
  x <- fl$EF[, k] %*% t(Rk * fl$tau) * s2
  w <- 1 - fl$gl[[k]]$pi0
  a <- fl$gl[[k]]$a
  altfl$wl <- ebnm:::wpost_normal(x, s, w, a)
  altfl$mul <- ebnm:::pmean_cond_normal(x, s, a)
  altfl$s2l <- ebnm:::pvar_cond_normal(s, a)

  # Factor: the same computation with the roles of loadings and factor swapped.
  s2 <- 1 / (fl$EL2[, k] %*% fl$tau)
  s <- sqrt(s2)
  x <- fl$EL[, k] %*% (Rk * fl$tau) * s2
  w <- 1 - fl$gf[[k]]$pi0
  a <- fl$gf[[k]]$a
  altfl$wf <- ebnm:::wpost_normal(x, s, w, a)
  altfl$muf <- ebnm:::pmean_cond_normal(x, s, a)
  altfl$s2f <- ebnm:::pvar_cond_normal(s, a)

  # KL contributions of every factor/loading pair except the k-th.
  altfl$KL <- sum(unlist(fl$KL_l)[-k] + unlist(fl$KL_f)[-k])

  altfl
}
# Write an "altfl" list back into the k-th factor/loading pair of a FLASH fit.
#
# Args:
#   altfl: list holding wl, mul, s2l, wf, muf, s2f, al, pi0l, af, pi0f and tau.
#   fl:    FLASH fit to update.
#   k:     index of the factor/loading pair to overwrite.
#
# Returns:
#   fl with column k of EL, EL2, EF and EF2 replaced, gl[[k]] and gf[[k]]
#   replaced by list(pi0, a), the ebnm function names set to "alt", the ebnm
#   parameter lists emptied, and tau replaced by altfl$tau.
altfl_to_fl <- function(altfl, fl, k) {
  # First and second moments of the loadings and of the factor.
  fl$EL[, k] <- compute_EX(altfl$wl, altfl$mul)
  fl$EL2[, k] <- compute_EX2(altfl$wl, altfl$mul, altfl$s2l)
  fl$EF[, k] <- compute_EX(altfl$wf, altfl$muf)
  fl$EF2[, k] <- compute_EX2(altfl$wf, altfl$muf, altfl$s2f)

  # Prior parameters.
  fl$gl[[k]] <- list(pi0 = altfl$pi0l, a = altfl$al)
  fl$gf[[k]] <- list(pi0 = altfl$pi0f, a = altfl$af)

  # Note: these two assignments apply to the whole fit, not only to factor k.
  fl$ebnm_fn_l <- fl$ebnm_fn_f <- "alt"
  fl$ebnm_param_l <- fl$ebnm_param_f <- list()

  fl$tau <- altfl$tau

  fl
}
# OBJECTIVE FUNCTION ------------------------------------------------
# Evaluate the objective for an "altfl" list.
#
# Args:
#   altfl: list with tau, Rk, R2k, KL, the posterior summaries
#          (wl, mul, s2l, wf, muf, s2f) and the prior parameters
#          (al, pi0l, af, pi0f).
#
# Returns:
#   A single number: the sum of the four terms computed below plus altfl$KL.
compute_obj <- function(altfl) {
  with(altfl, {
    EL <- compute_EX(wl, mul)
    EL2 <- compute_EX2(wl, mul, s2l)
    EF <- compute_EX(wf, muf)
    EF2 <- compute_EX2(wf, muf, s2f)

    obj <- rep(0, 4)
    # Terms 1 and 2: log-normalizer and expected squared-residual term.
    obj[1] <- sum(0.5 * log(tau / (2 * pi)))
    obj[2] <- sum(-0.5 * (tau * (R2k - 2 * Rk * outer(EL, EF) + outer(EL2, EF2))))
    # Terms 3 and 4: compute_KL for the loadings and for the factor.
    obj[3] <- compute_KL(wl, mul, s2l, pi0l, al)
    obj[4] <- compute_KL(wf, muf, s2f, pi0f, af)

    sum(obj) + KL
  })
}
# Sum of the three KL-type terms used by compute_obj for one set of
# posterior summaries.
#
# Args:
#   w:   vector of weights in [0, 1].
#   mu:  vector of conditional means.
#   s2:  vector of conditional variances.
#   pi0: scalar prior weight used in the log(pi0) / log(1 - pi0) terms.
#   a:   scalar prior parameter (enters as log(a) and a * (mu^2 + s2)).
#
# Returns:
#   A single number, the sum of the three terms.
compute_KL <- function(w, mu, s2, pi0, a) {
  obj <- rep(0, 3)

  # When w is exactly 0 or 1 the product is 0 * Inf = NaN; those entries are
  # dropped, which treats them as contributing zero.
  tmp <- (1 - w) * (log(pi0) - log(1 - w))
  obj[1] <- sum(tmp[!is.nan(tmp)])

  tmp <- w * (log(1 - pi0) - log(w))
  obj[2] <- sum(tmp[!is.nan(tmp)])

  obj[3] <- sum(0.5 * w * (log(a) + log(s2) + 1 - a * (mu^2 + s2)))

  sum(obj)
}
# First moment w * mu, returned as a plain vector.
#
# Args:
#   w:  weights (vector or matrix).
#   mu: conditional means, conformable with w.
#
# Returns:
#   as.vector(w * mu); any dim attribute is dropped.
compute_EX <- function(w, mu) {
  as.vector(w * mu)
}
# Second moment w * (mu^2 + sigma2), returned as a plain vector.
#
# Args:
#   w:      weights (vector or matrix).
#   mu:     conditional means, conformable with w.
#   sigma2: conditional variances, conformable with w.
#
# Returns:
#   as.vector(w * (mu^2 + sigma2)); any dim attribute is dropped.
compute_EX2 <- function(w, mu, sigma2) {
  as.vector(w * (mu^2 + sigma2))
}
# UPDATE FUNCTIONS --------------------------------------------------
# Update for the prior parameter a: ratio of total weight to total second
# moment.
#
# Args:
#   w:   vector of weights.
#   EX2: vector of second moments.
#
# Returns:
#   sum(w) / sum(EX2), a single number.
update_a <- function(w, EX2) {
  sum(w) / sum(EX2)
}
# Update for pi0: the average of (1 - w).
#
# Args:
#   w: vector of weights.
#
# Returns:
#   sum(1 - w) / length(w), a single number.
update_pi0 <- function(w) {
  sum(1 - w) / length(w)
}
# Update for the conditional means of the loadings.
#
# Args:
#   a:   scalar prior parameter added to the denominator.
#   tau: n x p matrix.
#   Rk:  n x p matrix.
#   EF:  length-p vector of first moments of the factor.
#   EF2: length-p vector of second moments of the factor.
#
# Returns:
#   Length-n vector: rowSums(tau * Rk * EF) / (a + rowSums(tau * EF2)), with
#   EF and EF2 broadcast along the rows.
update_mul <- function(a, tau, Rk, EF, EF2) {
  n <- nrow(tau)
  p <- ncol(tau)

  # byrow = TRUE repeats the length-p vector in every row.
  EF_mat <- matrix(EF, nrow = n, ncol = p, byrow = TRUE)
  EF2_mat <- matrix(EF2, nrow = n, ncol = p, byrow = TRUE)

  numer <- rowSums(tau * Rk * EF_mat)
  denom <- a + rowSums(tau * EF2_mat)

  numer / denom
}
# Update for the conditional means of the factor.
#
# Args:
#   a:   scalar prior parameter added to the denominator.
#   tau: n x p matrix.
#   Rk:  n x p matrix.
#   EL:  length-n vector of first moments of the loadings.
#   EL2: length-n vector of second moments of the loadings.
#
# Returns:
#   Length-p vector: colSums(tau * Rk * EL) / (a + colSums(tau * EL2)), with
#   EL and EL2 broadcast down the columns.
update_muf <- function(a, tau, Rk, EL, EL2) {
  n <- nrow(tau)
  p <- ncol(tau)

  # byrow = FALSE repeats the length-n vector in every column.
  EL_mat <- matrix(EL, nrow = n, ncol = p, byrow = FALSE)
  EL2_mat <- matrix(EL2, nrow = n, ncol = p, byrow = FALSE)

  numer <- colSums(tau * Rk * EL_mat)
  denom <- a + colSums(tau * EL2_mat)

  numer / denom
}
# Update for the conditional variances of the loadings.
#
# Args:
#   a:   scalar prior parameter.
#   tau: n x p matrix.
#   EF2: length-p vector of second moments of the factor.
#
# Returns:
#   Length-n vector 1 / (a + rowSums(tau * EF2)), with EF2 broadcast along
#   the rows.
update_s2l <- function(a, tau, EF2) {
  n <- nrow(tau)
  p <- ncol(tau)

  EF2_mat <- matrix(EF2, nrow = n, ncol = p, byrow = TRUE)

  1 / (a + rowSums(tau * EF2_mat))
}
# Update for the conditional variances of the factor.
#
# Args:
#   a:   scalar prior parameter.
#   tau: n x p matrix.
#   EL2: length-n vector of second moments of the loadings.
#
# Returns:
#   Length-p vector 1 / (a + colSums(tau * EL2)), with EL2 broadcast down
#   the columns.
update_s2f <- function(a, tau, EL2) {
  n <- nrow(tau)
  p <- ncol(tau)

  EL2_mat <- matrix(EL2, nrow = n, ncol = p, byrow = FALSE)

  1 / (a + colSums(tau * EL2_mat))
}
# Update for the weights of the loadings.
#
# Args:
#   a:      scalar prior parameter.
#   pi0:    scalar prior weight.
#   mu:     length-n vector of conditional means of the loadings.
#   sigma2: length-n vector of conditional variances of the loadings.
#   tau:    n x p matrix.
#   Rk:     n x p matrix.
#   EF:     length-p vector of first moments of the factor.
#   EF2:    length-p vector of second moments of the factor.
#
# Returns:
#   Length-n vector 1 / (1 + exp(-C)), where C is the sum of the three terms
#   below.
update_wl <- function(a, pi0, mu, sigma2, tau, Rk, EF, EF2) {
  # Prior log-odds.
  C1 <- log(1 - pi0) - log(pi0)
  # Term depending only on the prior parameter and the conditional moments.
  C2 <- 0.5 * (log(a) + log(sigma2) - a * (mu^2 + sigma2) + 1)
  # Data-dependent term, summed over the columns of each row.
  C3 <- rowSums(tau * (Rk * outer(mu, EF) - 0.5 * outer(mu^2 + sigma2, EF2)))

  C <- C1 + C2 + C3

  1 / (1 + exp(-C))
}
# Update for the weights of the factor.
#
# Args:
#   a:      scalar prior parameter.
#   pi0:    scalar prior weight.
#   mu:     length-p vector of conditional means of the factor.
#   sigma2: length-p vector of conditional variances of the factor.
#   tau:    n x p matrix.
#   Rk:     n x p matrix.
#   EL:     length-n vector of first moments of the loadings.
#   EL2:    length-n vector of second moments of the loadings.
#
# Returns:
#   Length-p vector 1 / (1 + exp(-C)), where C is the sum of the three terms
#   below.
update_wf <- function(a, pi0, mu, sigma2, tau, Rk, EL, EL2) {
  # Prior log-odds.
  C1 <- log(1 - pi0) - log(pi0)
  # Term depending only on the prior parameter and the conditional moments.
  C2 <- 0.5 * (log(a) + log(sigma2) - a * (mu^2 + sigma2) + 1)
  # Data-dependent term, summed over the rows of each column.
  C3 <- colSums(tau * (Rk * outer(EL, mu) - 0.5 * outer(EL2, mu^2 + sigma2)))

  C <- C1 + C2 + C3

  1 / (1 + exp(-C))
}
# ALGORITHM ---------------------------------------------------------
# Update tau from the current expected squared residuals.
#
# Args:
#   altfl: list with tau, Rk, R2k and the posterior summaries
#          (wl, mul, s2l, wf, muf, s2f).
#
# Returns:
#   altfl with tau replaced. Because within() writes every variable assigned
#   in its expression back into the list, the temporaries EL, EL2, EF, EF2
#   and R2 are also stored as fields of the result.
update_tau <- function(altfl) {
  within(altfl, {
    EL <- compute_EX(wl, mul)
    EL2 <- compute_EX2(wl, mul, s2l)
    EF <- compute_EX(wf, muf)
    EF2 <- compute_EX2(wf, muf, s2f)

    R2 <- R2k - 2 * Rk * outer(EL, EF) + outer(EL2, EF2)

    # One value per column (1 / column mean of R2), repeated down the rows;
    # byrow = TRUE is what lines the length-ncol vector up with the columns.
    tau <- matrix(1 / colMeans(R2), nrow = nrow(tau), ncol = ncol(tau),
                  byrow = TRUE)
  })
}
# Update the posterior summaries of the loadings (mul, s2l, wl) given the
# current factor.
#
# Args:
#   altfl: list with al, pi0l, tau, Rk and the factor summaries
#          (wf, muf, s2f).
#
# Returns:
#   altfl with mul, s2l and wl replaced. within() also stores the temporaries
#   EF and EF2 as fields of the result.
update_loadings_post <- function(altfl) {
  within(altfl, {
    EF <- compute_EX(wf, muf)
    EF2 <- compute_EX2(wf, muf, s2f)

    # Order matters: wl is computed from the freshly updated mul and s2l.
    mul <- update_mul(al, tau, Rk, EF, EF2)
    s2l <- update_s2l(al, tau, EF2)
    wl <- update_wl(al, pi0l, mul, s2l, tau, Rk, EF, EF2)
  })
}
# Update the prior parameters of the loadings (al, pi0l) from the current
# posterior summaries.
#
# Args:
#   altfl: list with wl, mul and s2l.
#
# Returns:
#   altfl with al and pi0l replaced. within() also stores the temporary EL2
#   as a field of the result.
update_loadings_prior <- function(altfl) {
  within(altfl, {
    EL2 <- compute_EX2(wl, mul, s2l)

    al <- update_a(wl, EL2)
    pi0l <- update_pi0(wl)
  })
}
# Update the posterior summaries of the factor (muf, s2f, wf) given the
# current loadings.
#
# Args:
#   altfl: list with af, pi0f, tau, Rk and the loadings summaries
#          (wl, mul, s2l).
#
# Returns:
#   altfl with muf, s2f and wf replaced. within() also stores the temporaries
#   EL and EL2 as fields of the result.
update_factor_post <- function(altfl) {
  within(altfl, {
    EL <- compute_EX(wl, mul)
    EL2 <- compute_EX2(wl, mul, s2l)

    # Order matters: wf is computed from the freshly updated muf and s2f.
    muf <- update_muf(af, tau, Rk, EL, EL2)
    s2f <- update_s2f(af, tau, EL2)
    wf <- update_wf(af, pi0f, muf, s2f, tau, Rk, EL, EL2)
  })
}
# Update the prior parameters of the factor (af, pi0f) from the current
# posterior summaries.
#
# Args:
#   altfl: list with wf, muf and s2f.
#
# Returns:
#   altfl with af and pi0f replaced. within() also stores the temporary EF2
#   as a field of the result.
update_factor_prior <- function(altfl) {
  within(altfl, {
    EF2 <- compute_EX2(wf, muf, s2f)

    af <- update_a(wf, EF2)
    pi0f <- update_pi0(wf)
  })
}
# Run one full sweep of updates and record the objective after each step.
#
# Args:
#   altfl: "altfl" list to update.
#
# Returns:
#   list(altfl = <updated list>, obj = <length-5 numeric>), where obj[i] is
#   the objective after the i-th step: tau, loadings posterior, loadings
#   prior, factor posterior, factor prior.
do_one_update <- function(altfl) {
  obj <- rep(0, 5)

  altfl <- update_tau(altfl)
  obj[1] <- compute_obj(altfl)

  altfl <- update_loadings_post(altfl)
  obj[2] <- compute_obj(altfl)

  altfl <- update_loadings_prior(altfl)
  obj[3] <- compute_obj(altfl)

  altfl <- update_factor_post(altfl)
  obj[4] <- compute_obj(altfl)

  altfl <- update_factor_prior(altfl)
  obj[5] <- compute_obj(altfl)

  list(altfl = altfl, obj = obj)
}
# Repeat full update sweeps until the objective stops improving by more
# than tol.
#
# Args:
#   altfl:   "altfl" list to optimize.
#   tol:     the loop continues while the change in objective between
#            consecutive sweeps exceeds this value.
#   verbose: if TRUE, emit the objective after every sweep via message().
#
# Returns:
#   The altfl list produced by the last sweep. At least one sweep is always
#   run, and the result of the final sweep is kept even when it did not
#   improve the objective by more than tol (including when it decreased it).
optimize_alt_fl <- function(altfl, tol = .01, verbose = FALSE) {
  obj <- compute_obj(altfl)
  diff <- Inf

  while (diff > tol) {
    tmp <- do_one_update(altfl)

    # The objective after the last step of the sweep.
    new_obj <- tmp$obj[length(tmp$obj)]
    diff <- new_obj - obj
    obj <- new_obj

    if (verbose) {
      message(paste("Objective:", obj))
    }

    altfl <- tmp$altfl
  }

  altfl
}
Using the same dataset as in previous investigations, I fit a FLASH object with four factors (recall that it’s the fourth factor that was causing problems during loadings updates):
# devtools::install_github("stephenslab/flashr")
Loading flashr
# devtools::install_github("stephenslab/flashr")
Loading ebnm
fl <- flash_add_greedy(data, Kmax=4, verbose=FALSE)
fitting factor/loading 1
fitting factor/loading 2
fitting factor/loading 3
fitting factor/loading 4
The objective as computed by FLASH is:
flash_get_objective(data, fl)
[1] -1297135
I now convert the fourth factor to an “altfl” object. The objective as computed by the alternate method is:
altfl <- fl_to_altfl(data, fl, 4)
[1] -1297135
So the objective functions agree. This is a good thing.
Next, I attempt to optimize the altfl object:
altfl <- optimize_alt_fl(altfl, verbose=TRUE)
Objective: -1297135.42237706
Nothing happens. This is also a good thing!
That I get the same values in both cases confirms that a) the expressions in the previous note are correct; and b) the `KL_l` and `KL_f` values obtained using FLASH are in fact reliable.
The algorithm implemented here is, however, very prone to getting stuck in local maxima, and does not appear to be a suitable substitute for the existing FLASH algorithm.
