#' Causal mediation analysis for a multinomial (unordered) outcome
#'
#' Fits a multinomial logit model for the outcome (`nnet::multinom`) and one to
#' three mediator models (`lm` for continuous, logit `glm` for binary
#' mediators) on each data set, draws `S` parameter vectors per data set from a
#' multivariate normal centred on the estimates (`MASS::mvrnorm`), and
#' summarises the simulated direct, mediator-specific indirect and total
#' effects of the treatment on every outcome category's probability.
#'
#' Effects are evaluated on the design matrices averaged over imputations and
#' averaged over observations with the case weights. Indirect effects hold the
#' treatment at its observed values and move one mediator between its two
#' counterfactual predictions while the other mediators stay at their
#' predictions under the observed treatment. Only the main-effect columns of
#' the treatment and of the mediators are manipulated; NOTE(review): terms
#' combining them with other variables (e.g. interactions) are not updated.
#'
#' @param formula.y Outcome formula. The response must be a factor with at
#'   least three levels; the right-hand side must contain the treatment and
#'   every mediator.
#' @param formula.m1,formula.m2,formula.m3 Mediator formulas (`formula.m2` and
#'   `formula.m3` are optional). Each right-hand side must contain the
#'   treatment.
#' @param type.m1,type.m2,type.m3 `"continuous"` (linear model) or `"binary"`
#'   (logit model; factor with two levels or numeric coded 0/1).
#' @param type.t Treatment type: `"continuous"`, `"binary"` or `"categorical"`
#'   (factor with >= 3 levels). The legacy misspelling `"continous"` is
#'   accepted.
#' @param t.name Name of the treatment variable.
#' @param t.shift Continuous treatments only: width of the contrast, `"unit"`
#'   (1) or `"sd"` (one standard deviation of the treatment column).
#' @param data A data frame, or a list of imputed data frames if `imp = TRUE`.
#'   Rows with missing values in model variables are not supported.
#' @param weights Optional name of a column of `data` holding case weights.
#' @param imp Logical; is `data` a list of imputed data frames?
#' @param S Number of parameter draws per data set.
#' @param seed Passed to `set.seed()`; this resets the caller's global RNG
#'   state.
#' @return A list with `te`, `de`, `me1`, `me2`, `me3` (median and 2.5% /
#'   97.5% quantiles of the simulated total, direct and indirect effects;
#'   `NULL` for absent mediators) and `m1.ame`, `m2.ame`, `m3.ame` (the
#'   treatment's weighted average effect on each mediator). For binary and
#'   continuous treatments the effects are `3 x J` matrices (J outcome
#'   categories); for continuous treatments the outcome effects (not the
#'   mediator AMEs) are divided by the contrast width. For categorical
#'   treatments the effects are `3 x D x J` arrays over all pairwise level
#'   contrasts, where `"a vs b"` means level a minus level b.
mediate_mnl <- function(formula.y,
                        formula.m1,
                        formula.m2 = NULL,
                        formula.m3 = NULL,
                        type.m1 = c("continuous", "binary"),
                        type.m2 = c("continuous", "binary"),
                        type.m3 = c("continuous", "binary"),
                        type.t = c("continuous", "binary", "categorical"),
                        t.name,
                        t.shift = c("unit", "sd"),
                        data,
                        weights = NULL,
                        imp = FALSE,
                        S = 1000L,
                        seed = 19890213) {
  ## Arguments ----
  # Enumerated arguments must be scalars: comparing the whole default vector
  # inside `if ()` is an error in R >= 4.2.
  if (identical(type.t, "continous")) {
    type.t <- "continuous"  # misspelled default of earlier versions
  }
  type.t <- match.arg(type.t, c("continuous", "binary", "categorical"))
  t.shift <- match.arg(t.shift, c("unit", "sd"))
  if (!is.numeric(S) || length(S) != 1L || is.na(S) || S < 1) {
    stop("S must be a single positive number of draws", call. = FALSE)
  }
  S <- as.integer(S)

  # Mediators live in fixed slots 1..3 (NULL when absent) so the output names
  # me1/me2/me3 and m1.ame/... always match the formula arguments.
  med.formulas <- list(formula.m1, formula.m2, formula.m3)
  slots <- which(!vapply(med.formulas, is.null, logical(1)))
  if (!(1L %in% slots)) {
    stop("formula.m1 must be supplied", call. = FALSE)
  }
  med.types <- list(type.m1, type.m2, type.m3)
  med.names <- rep(NA_character_, 3L)
  for (k in slots) {
    med.types[[k]] <- match.arg(med.types[[k]], c("continuous", "binary"))
    med.names[k] <- all.vars(med.formulas[[k]])[1L]
  }

  set.seed(seed)

  ## Data checks ----
  if (imp) {
    ok <- is.list(data) && !is.data.frame(data) && length(data) > 0L &&
      all(vapply(data, is.data.frame, logical(1)))
    if (!ok) stop("data must be a list of data frames")
    data.list <- data
  } else {
    if (!is.data.frame(data)) stop("data must be a data frame")
    data.list <- list(data)
  }
  M <- length(data.list)

  var.names <- c(all.vars(formula.y)[1L], t.name, med.names[slots])
  var.types <- c("categorical", type.t, unlist(med.types[slots]))
  for (n in seq_along(var.names)) {
    .check_mnl_variable(lapply(data.list, `[[`, var.names[n]),
                        var.types[n], var.names[n])
  }

  ## 1) Model estimation ----
  message("1 Estimating models")
  fits <- lapply(data.list, .fit_mnl_models,
                 formula.y = formula.y, med.formulas = med.formulas,
                 med.types = med.types, weights.col = weights)
  last <- fits[[M]]
  w <- last$w
  j.names <- last$y$lab
  J <- nrow(stats::coef(last$y)) + 1L  # + reference category
  K <- ncol(stats::coef(last$y))
  terms.y <- stats::delete.response(stats::terms(last$y))
  terms.m <- vector("list", 3L)
  for (k in slots) {
    terms.m[[k]] <- stats::delete.response(stats::terms(last$m[[k]]))
  }

  ## 2) Parameter simulation ----
  message("2 Simulating Model Parameters")
  # Draw order per data set (outcome, then m1, m2, m3) matches earlier
  # versions, so a given seed yields the same parameter draws.
  draws <- lapply(fits, function(fit) {
    b <- .draw_mvn(S, as.vector(t(stats::coef(fit$y))), stats::vcov(fit$y))
    m <- lapply(fit$m, function(mod) {
      if (is.null(mod)) NULL else
        .draw_mvn(S, stats::coef(mod), stats::vcov(mod))
    })
    # leading zero block = coefficients of the reference outcome category
    list(b = cbind(matrix(0, S, K), b), m = m)
  })
  b.sim <- do.call(rbind, lapply(draws, `[[`, "b"))  # (M*S) x (J*K)
  m.sim <- vector("list", 3L)
  for (k in slots) {
    m.sim[[k]] <- do.call(rbind, lapply(draws, function(dr) dr$m[[k]]))
  }
  rm(draws)

  ## Design matrices (averaged over imputations) ----
  mm.y <- lapply(data.list, function(d) stats::model.matrix(terms.y, data = d))
  t.which <- .term_columns(mm.y[[1L]], terms.y, t.name)
  if (length(t.which) == 0L) {
    stop("Treatment '", t.name, "' not found in formula.y", call. = FALSE)
  }
  x.which <- Z <- z.which <- vector("list", 3L)
  for (k in slots) {
    mm.m <- lapply(data.list,
                   function(d) stats::model.matrix(terms.m[[k]], data = d))
    x.which[[k]] <- .term_columns(mm.y[[1L]], terms.y, med.names[k])
    z.which[[k]] <- .term_columns(mm.m[[1L]], terms.m[[k]], t.name)
    if (length(x.which[[k]]) == 0L) {
      stop("Mediator '", med.names[k], "' not found in formula.y",
           call. = FALSE)
    }
    if (length(z.which[[k]]) == 0L) {
      stop("Treatment '", t.name, "' not found in formula.m", k,
           call. = FALSE)
    }
    Z[[k]] <- Reduce(`+`, mm.m) / M
  }
  X <- Reduce(`+`, mm.y) / M
  rm(mm.y)
  N <- nrow(X)
  if (length(w) != N || any(vapply(Z[slots], nrow, integer(1)) != N)) {
    stop("Model matrices and weights have different numbers of rows; ",
         "remove incomplete cases before calling mediate_mnl()",
         call. = FALSE)
  }

  ## 3) Predictions ----
  message("3 Calculating requested quantities")
  quants <- c(.5, .025, .975)
  shift <- if (type.t == "continuous" && t.shift == "sd") {
    stats::sd(X[, t.which])
  } else {
    1
  }
  # Mediator predictions under the observed treatment: (M*S) x N per slot
  m.obs <- vector("list", 3L)
  for (k in slots) {
    m.obs[[k]] <- .predict_mediator(m.sim[[k]], Z[[k]], med.types[[k]])
  }
  me <- m.ame <- vector("list", 3L)

  if (type.t != "categorical") {
    # cond 1 = low, 2 = high counterfactual treatment value
    set.t <- function(mat, cols, cond) {
      if (type.t == "continuous") {
        mat[, cols] <- mat[, cols] + (cond - 1.5) * shift
      } else {
        mat[, cols] <- cond - 1
      }
      mat
    }
    m.lo <- m.hi <- vector("list", 3L)
    for (k in slots) {
      m.lo[[k]] <- .predict_mediator(
        m.sim[[k]], set.t(Z[[k]], z.which[[k]], 1L), med.types[[k]])
      m.hi[[k]] <- .predict_mediator(
        m.sim[[k]], set.t(Z[[k]], z.which[[k]], 2L), med.types[[k]])
      m.ame[[k]] <- stats::quantile(.wmean_rows(m.hi[[k]] - m.lo[[k]], w),
                                    quants)
    }
    # Direct effect: move the treatment, keep mediators at observed-T values
    de.sim <- .mnl_effect(b.sim, J, w, x.which,
                          set.t(X, t.which, 1L), m.obs,
                          set.t(X, t.which, 2L), m.obs) / abs(shift)
    te.sim <- de.sim
    for (k in slots) {
      from <- to <- m.obs
      from[[k]] <- m.lo[[k]]
      to[[k]] <- m.hi[[k]]
      me.sim <- .mnl_effect(b.sim, J, w, x.which, X, from, X, to) /
        abs(shift)
      me[[k]] <- .summarise_draws(me.sim, list(j.names))
      te.sim <- te.sim + me.sim
    }
    de <- .summarise_draws(de.sim, list(j.names))
    te <- .summarise_draws(te.sim, list(j.names))
  } else {
    n.cat <- length(t.which) + 1L
    t.levels <- levels(data.list[[1L]][[t.name]])
    d.perm <- utils::combn(n.cat, 2L)
    d.levels <- paste(t.levels[d.perm[1L, ]], t.levels[d.perm[2L, ]],
                      sep = " vs ")
    D <- ncol(d.perm)
    # level indexes t.levels; level 1 is the reference (all dummies zero)
    set.level <- function(mat, cols, level) {
      mat[, cols] <- 0
      mat[, cols[level - 1L]] <- 1
      mat
    }
    template <- array(NA_real_, c(D, J, nrow(b.sim)))
    de.sim <- template
    me.sim <- vector("list", 3L)
    for (k in slots) {
      me.sim[[k]] <- template
      m.ame[[k]] <- matrix(NA_real_, D, 3L,
                           dimnames = list(d.levels,
                                           c("Est.", "2.5%", "97.5%")))
    }
    for (d in seq_len(D)) {
      lev.a <- d.perm[1L, d]
      lev.b <- d.perm[2L, d]
      de.sim[d, , ] <- .mnl_effect(b.sim, J, w, x.which,
                                   set.level(X, t.which, lev.b), m.obs,
                                   set.level(X, t.which, lev.a), m.obs)
      for (k in slots) {
        hat.a <- .predict_mediator(
          m.sim[[k]], set.level(Z[[k]], z.which[[k]], lev.a), med.types[[k]])
        hat.b <- .predict_mediator(
          m.sim[[k]], set.level(Z[[k]], z.which[[k]], lev.b), med.types[[k]])
        m.ame[[k]][d, ] <- stats::quantile(.wmean_rows(hat.a - hat.b, w),
                                           quants)
        from <- to <- m.obs
        from[[k]] <- hat.b
        to[[k]] <- hat.a
        me.sim[[k]][d, , ] <- .mnl_effect(b.sim, J, w, x.which,
                                          X, from, X, to)
      }
    }
    dn <- list(d.levels, j.names)
    de <- .summarise_draws(de.sim, dn)
    te.sim <- de.sim
    for (k in slots) {
      me[[k]] <- .summarise_draws(me.sim[[k]], dn)
      te.sim <- te.sim + me.sim[[k]]
    }
    te <- .summarise_draws(te.sim, dn)
  }

  ## Value ----
  message("4 Returning Output")
  list(
    te = te,
    de = de,
    me1 = me[[1L]],
    me2 = me[[2L]],
    me3 = me[[3L]],
    m1.ame = m.ame[[1L]],
    m2.ame = m.ame[[2L]],
    m3.ame = m.ame[[3L]]
  )
}

# Validate one variable across all data frames (`values` is a list holding
# the column from each data frame); stops or warns with a descriptive message.
.check_mnl_variable <- function(values, type, name) {
  where <- paste0("Problem found for: ", name)
  if (any(vapply(values, is.null, logical(1)))) {
    stop(paste0("Variable not found in data. ", where), call. = FALSE)
  }
  if (type == "categorical") {
    if (!all(vapply(values, is.factor, logical(1)))) {
      stop(paste0("Categorical variable is not a factor. ", where),
           call. = FALSE)
    }
    if (!all(vapply(values, nlevels, integer(1)) > 2L)) {
      stop(paste0("Categorical variable needs >= 3 categories. ", where),
           call. = FALSE)
    }
  } else if (type == "continuous") {
    if (!all(vapply(values, is.numeric, logical(1)))) {
      stop(paste0("Continuous variables must be numeric. ", where),
           call. = FALSE)
    }
    n.unique <- vapply(values, function(v) length(unique(v)), integer(1))
    if (any(n.unique < 5L)) {
      warning(paste0("Some variables declared as continuous ",
                     "have fewer than 5 unique values. ", where),
              call. = FALSE)
    }
  } else if (type == "binary") {
    if (all(vapply(values, is.factor, logical(1)))) {
      if (any(vapply(values, nlevels, integer(1)) != 2L)) {
        stop(paste0("Binary factor variables must have exactly two levels. ",
                    where), call. = FALSE)
      }
    } else {
      coded01 <- vapply(values, function(v) all(unique(v) %in% c(0, 1)),
                        logical(1))
      if (!all(coded01)) {
        stop(paste0("Binary numerical variables must be coded 0 and 1. ",
                    where), call. = FALSE)
      }
    }
  }
  invisible(TRUE)
}

# Fit the outcome and mediator models on one data frame.
.fit_mnl_models <- function(d, formula.y, med.formulas, med.types,
                            weights.col) {
  if (!is.null(weights.col) && is.null(d[[weights.col]])) {
    stop("weights column '", weights.col, "' not found in data",
         call. = FALSE)
  }
  # The model functions evaluate `weights` inside `data` first, so the case
  # weights are stored in a column literally named "weights".
  d$weights <- if (is.null(weights.col)) rep(1L, nrow(d)) else
    d[[weights.col]]
  mod.y <- nnet::multinom(formula.y, data = d, weights = weights,
                          Hess = TRUE, maxit = 500L)
  mods <- vector("list", length(med.formulas))
  for (k in seq_along(med.formulas)) {
    if (!is.null(med.formulas[[k]])) {
      mods[[k]] <- .fit_mnl_mediator(med.formulas[[k]], med.types[[k]], d)
    }
  }
  list(y = mod.y, m = mods, w = d$weights)
}

# Fit one mediator model: linear for continuous, logit for binary mediators.
# `d` must contain a "weights" column (see .fit_mnl_models()).
.fit_mnl_mediator <- function(formula, type, d) {
  if (type == "binary") {
    stats::glm(formula, data = d, weights = weights,
               family = stats::binomial(link = "logit"))
  } else {
    stats::lm(formula, data = d, weights = weights)
  }
}

# Draw S multivariate-normal vectors; always returns an S-row matrix
# (MASS::mvrnorm drops to a vector when S == 1).
.draw_mvn <- function(S, mu, Sigma) {
  matrix(MASS::mvrnorm(S, mu, Sigma), nrow = S)
}

# Model-matrix columns belonging to terms whose only variable is `var`
# (e.g. `t`, `factor(t)`), located via the "assign" attribute rather than by
# regex on column names.
.term_columns <- function(mm, tt, var) {
  labels <- attr(tt, "term.labels")
  own <- vapply(labels,
                function(lab) identical(all.vars(str2lang(lab)),
                                        as.character(var)),
                logical(1), USE.NAMES = FALSE)
  which(attr(mm, "assign") %in% which(own))
}

# Mediator predictions for every draw: (draws) x (observations) matrix on
# the response scale (probabilities for binary mediators).
.predict_mediator <- function(draws, Z, type) {
  eta <- draws %*% t(Z)
  if (type == "binary") {
    eta[] <- stats::plogis(eta)
  }
  eta
}

# Weighted mean of each row of `mat` with observation weights `w`.
.wmean_rows <- function(mat, w) {
  as.vector(mat %*% w) / sum(w)
}

# Multinomial-logit probabilities (N x J) for design X (N x K) and
# coefficients beta (J x K, first row the zero reference category).
.mnl_probs <- function(X, beta) {
  eta <- X %*% t(beta)
  # Subtract each row's maximum so exp() cannot overflow to Inf (NaN probs).
  row.max <- eta[cbind(seq_len(nrow(eta)),
                       max.col(eta, ties.method = "first"))]
  e <- exp(eta - row.max)
  e / rowSums(e)
}

# Replace the mediator columns of X with draw s of the given predictions
# (slots whose prediction is NULL are left untouched).
.plug_mediators <- function(X, hats, x.which, s) {
  for (k in seq_along(hats)) {
    if (!is.null(hats[[k]])) {
      X[, x.which[[k]]] <- hats[[k]][s, ]
    }
  }
  X
}

# Per draw, the weighted average change in each category's probability when
# moving from scenario (X.from, hats.from) to (X.to, hats.to); J x draws.
.mnl_effect <- function(b.sim, J, w, x.which, X.from, hats.from,
                        X.to, hats.to) {
  out <- matrix(NA_real_, J, nrow(b.sim))
  sw <- sum(w)
  for (s in seq_len(nrow(b.sim))) {
    beta <- matrix(b.sim[s, ], nrow = J, byrow = TRUE)
    p.from <- .mnl_probs(.plug_mediators(X.from, hats.from, x.which, s), beta)
    p.to <- .mnl_probs(.plug_mediators(X.to, hats.to, x.which, s), beta)
    out[, s] <- colSums((p.to - p.from) * w) / sw
  }
  out
}

# Median and 2.5% / 97.5% quantiles over the last (draw) dimension of `sim`;
# `dn` names the remaining dimensions.
.summarise_draws <- function(sim, dn) {
  margins <- seq_len(length(dim(sim)) - 1L)
  out <- apply(sim, margins, stats::quantile, probs = c(.5, .025, .975))
  dimnames(out) <- c(list(c("Est.", "2.5%", "97.5%")), dn)
  out
}