#' Assigning clonality status to every SNV
#' @description
#' Assigns clonality status to every SNV based on the variant allele frequency
#' distribution. The function uses maximum a posteriori assignment of single
#' variants to either "subclonal", "clonal", "early clonal" or "late clonal"
#' (if distinguishable). The likelihood is computed according to a binomial
#' distribution; prior probabilities are empirically determined based on the
#' relative SNV burden per clonal state.
#' @details
#' The `Clonality` column of the result uses the labels `"SC"` (subclonal),
#' `"C"` (clonal), `"Precnv"` (early clonal, prior `p_ec`), `"Postcnv"`
#' (late clonal, prior `p_lc`) and `"n.d."` (not determined). `"n.d."` is
#' assigned when the total copy number, VAF or read depth is missing, when
#' the total copy number or read depth is zero, or when no posterior
#' probability can be computed (e.g. missing allele-specific copy numbers).
#' Missing prior probabilities are replaced by 0.01 before normalization.
#' @param nbObj combined SNV and CNV information as generated by
#' \code{\link{nbImport}}.
#' @param mrcaObj clonal SNV counts stratified by copy number as generated by
#' \code{\link{MRCA}}.
#' @param ID sample name.
#' @param purity tumor cell content.
#' @param driver.file optional, path to file with "chrom", "snv_start", "ref",
#' "alt", "gene" column containing known driver SNVs.
#' @param ref.build Reference genome. Default `hg19`. Can be `hg18`, `hg19` or
#' `hg38`. If no `driver.file` is given, the packaged hotspot table matching
#' the build is used; no table is available for `hg18`, in which case
#' `known_driver_gene` is `NA` for every SNV.
#' @return a data.table with per-SNV clonality assignment, containing the
#' columns `chrom`, `snv_start`, `ref`, `alt`, `Sample`, `TCN`, `A`, `B`,
#' `cn_start`, `cn_end`, `t_vaf`, `Signature`, `A_time`, `B_time`,
#' `Clonality` and `known_driver_gene`.
#' @examples
#' # Example using variants associated with specific SBS mutational signatures
#' # from vcf file
#' snvs <- system.file("extdata", "NBE15",
#'     "snvs_NBE15_somatic_snvs_conf_8_to_10.vcf",
#'     package = "LACHESIS"
#' )
#' s_data <- readVCF(vcf = snvs, vcf.source = "dkfz")
#' aceseq_cn <- system.file("extdata", "NBE15",
#'     "NBE15_comb_pro_extra2.51_1.txt",
#'     package = "LACHESIS"
#' )
#' c_data <- readCNV(aceseq_cn)
#' sig.filepath <- system.file("extdata",
#'     "NBE15_Decomposed_MutationType_Probabilities.txt",
#'     package = "LACHESIS"
#' )
#' nb <- nbImport(
#'     cnv = c_data, snv = s_data, purity = 1, ploidy = 2.51,
#'     sig.assign = TRUE, ID = "NBE15", sig.file = sig.filepath
#' )
#' cl_muts <- clonalMutationCounter(nb)
#' norm_muts <- normalizeCounts(cl_muts)
#' mrca <- MRCA(norm_muts)
#' estimateClonality(nbObj = nb, mrcaObj = mrca, ID = "NBE15", purity = 1)
#'
#' @export

estimateClonality <- function(nbObj = NULL, mrcaObj = NULL, ID = NULL,
                              purity = NULL, driver.file = NULL,
                              ref.build = "hg19") {
    # Column names used through data.table non-standard evaluation are bound
    # locally so that static code checks do not report undefined symbols.
    A <- A_time <- B <- B_time <- Clonality <- Sample <- Signature <- TCN <-
        alt <- chrom <- cn_end <- cn_start <- known_driver_gene <- p_c <-
        p_ec <- p_lc <- p_sc <- ref <- snv_start <- t_alt_count <- t_depth <-
        t_vaf <- . <- NULL

    if (is.null(nbObj) || is.null(mrcaObj)) {
        stop(
            "Missing input. Please provide the output generated by nbImport ",
            "and MRCA."
        )
    }

    if (is.null(purity)) {
        stop("Please specify tumor purity.")
    }

    ref.build <- match.arg(
        arg = ref.build, choices = c("hg19", "hg18", "hg38"),
        several.ok = FALSE
    )

    # Attach the per-segment priors and timing columns to every SNV. SNVs on
    # segments absent from mrcaObj are kept (all.x) with NA priors.
    snvClonality <- merge(
        nbObj,
        mrcaObj[, .(
            chrom, TCN, A, B, p_sc, p_lc, p_ec,
            p_c, A_time, B_time
        )],
        by = c("chrom", "TCN", "A", "B"), all.x = TRUE
    )

    # Missing priors fall back to a small constant before normalization.
    default_prior <- 0.01
    prior_or_default <- function(p) if (is.na(p)) default_prior else p

    # Row-wise maximum a posteriori classification (one group per SNV).
    snvClonality[, Clonality := {
        CN <- as.numeric(TCN)
        VAF <- as.numeric(t_vaf)
        depth <- as.numeric(t_depth)
        n_alt <- as.numeric(t_alt_count)

        # anyNA() guards the scalar comparisons below: a missing depth would
        # otherwise make the condition NA and abort the whole assignment.
        if (anyNA(c(CN, VAF, depth)) || CN == 0 || depth == 0) {
            "n.d."
        } else {
            expectedVAFs <- .expectedClVAFAB(A, B, purity)

            # The first expected VAF belongs to a mutation on a single copy;
            # a subclonal mutation is modelled at half that frequency. The
            # remaining (distinct) VAFs belong to multi-copy mutations.
            expected_single <- expectedVAFs[1]
            expected_SC <- expected_single * 0.5
            expected_precnv <- expectedVAFs[-1]

            # "C" and "Postcnv" share the same likelihood and are therefore
            # separated by their priors only.
            lik_single <- dbinom(n_alt, size = depth, prob = expected_single)
            likelihoods <- c(
                SC = dbinom(n_alt, size = depth, prob = expected_SC),
                Postcnv = lik_single,
                Precnv = sum(
                    dbinom(n_alt, size = depth, prob = expected_precnv)
                ),
                C = lik_single
            )

            priors <- c(
                SC = prior_or_default(p_sc),
                Postcnv = prior_or_default(p_lc),
                Precnv = prior_or_default(p_ec),
                C = prior_or_default(p_c)
            )
            priors <- priors / sum(priors)

            posteriors <- priors * likelihoods
            best <- which.max(posteriors)
            # which.max() is empty when every posterior is NA/NaN (e.g.
            # missing A/B or alt count); report such SNVs as undetermined
            # instead of returning a zero-length value.
            if (length(best) == 0L) "n.d." else names(posteriors)[best]
        }
    }, by = seq_len(nrow(snvClonality))]

    snvClonality <- data.table(Sample = ID, snvClonality)

    if (is.null(driver.file)) {
        # Packaged hotspot tables exist for hg19 and hg38 only.
        hotspot_file <- switch(ref.build,
            hg19 = "cancerhotspots_v2_GRCh37_adapted.tsv",
            hg38 = "cancerhotspots_v2_GRCh38_adapted.tsv",
            NULL
        )
        driverMutations <- if (is.null(hotspot_file)) {
            NULL
        } else {
            data.table::fread(
                system.file("extdata", hotspot_file, package = "LACHESIS")
            )
        }
    } else {
        driverMutations <- data.table::fread(driver.file)
    }

    if (!is.null(driverMutations)) {
        snvClonality <- merge(
            snvClonality,
            driverMutations,
            by = c("chrom", "snv_start", "ref", "alt"),
            all.x = TRUE
        )
    }

    if (!"Signature" %in% colnames(snvClonality)) {
        snvClonality[, Signature := NA_character_]
    }

    # Without a driver table (e.g. hg18 and no driver.file) there is no
    # "gene" column to rename; provide an all-NA column so the output schema
    # is identical in every case.
    if ("gene" %in% colnames(snvClonality)) {
        data.table::setnames(
            snvClonality,
            old = "gene", new = "known_driver_gene"
        )
    } else {
        snvClonality[, known_driver_gene := NA_character_]
    }

    snvClonality <- snvClonality[, .(
        chrom, snv_start, ref, alt, Sample, TCN, A, B,
        cn_start, cn_end, t_vaf, Signature, A_time,
        B_time, Clonality, known_driver_gene
    )]

    snvClonality
}

# Internal helper: distinct expected variant allele frequencies for a segment.
#
# For each multiplicity m in (1, B, A) the value
#   m * purity / (purity * (A + B) + 2 * (1 - purity))
# is computed and duplicates are dropped. The multiplicity-1 value is always
# the first element of the result, because unique() keeps first occurrences.
# A and B are presumably the allele-specific copy numbers of the segment and
# purity the tumor cell fraction (not verifiable from this helper alone).
.expectedClVAFAB <- function(A, B, purity) {
    multiplicity_signal <- c(1, B, A) * purity
    total_signal <- purity * (A + B) + 2 * (1 - purity)
    unique(multiplicity_signal / total_signal)
}

#' Plotting assigned clonality status for every SNV by chromosome
#' @description
#' Visualizes results from  \code{\link{estimateClonality}}. The first page
#' shows the number of SNVs per clonality class for every chromosomal segment
#' (split over two panels). If `sig.assign = TRUE`, additional pages show the
#' clonality classes stratified by mutational signature, four segments per
#' page. SNVs without a clonality class (e.g. `"n.d."`) are not plotted.
#' @param snvClonality output generated from \code{\link{estimateClonality}}.
#' The object passed in is not modified.
#' @param nbObj output generated from \code{\link{nbImport}}. Its
#' `sig.colors` attribute provides the signature colors.
#' @param sig.assign Logical. If TRUE, clonality status distribution will be
#' plotted for each SBS signature.
#' @param output.file optional, path to a pdf file in which all plots will be
#' saved.
#' @param ... further arguments and parameters passed to other LACHESIS
#' functions (currently not used by this function).
#' @return graphs with clonality status of SNVs per chromosome and if specified,
#'  stratified by signature. The function is called for its plots and returns
#'  `NULL` invisibly.
#' @examples
#' # Example using variants associated with specific SBS mutational signatures
#' # from vcf file
#' snvs <- system.file("extdata", "NBE15",
#'     "snvs_NBE15_somatic_snvs_conf_8_to_10.vcf",
#'     package = "LACHESIS"
#' )
#' s_data <- readVCF(vcf = snvs, vcf.source = "dkfz")
#' aceseq_cn <- system.file("extdata", "NBE15",
#'     "NBE15_comb_pro_extra2.51_1.txt",
#'     package = "LACHESIS"
#' )
#' c_data <- readCNV(aceseq_cn)
#' sig.filepath <- system.file("extdata",
#'     "NBE15_Decomposed_MutationType_Probabilities.txt",
#'     package = "LACHESIS"
#' )
#' nb <- nbImport(
#'     cnv = c_data, snv = s_data, purity = 1, ploidy = 2.51,
#'     sig.assign = TRUE, ID = "NBE15", sig.file = sig.filepath
#' )
#' cl_muts <- clonalMutationCounter(nb)
#' norm_muts <- normalizeCounts(cl_muts)
#' mrca <- MRCA(norm_muts)
#' snvClonality <- estimateClonality(
#'     nbObj = nb, mrcaObj = mrca,
#'     ID = "NBE15", purity = 1
#' )
#' plotClonality(snvClonality, nbObj = nb, sig.assign = TRUE)
#'
#' @import ggplot2
#' @importFrom stats setNames
#' @export

plotClonality <- function(snvClonality = NULL, nbObj = NULL,
                          sig.assign = FALSE, output.file = NULL, ...) {
    # Column names used through data.table / ggplot2 non-standard evaluation.
    A <- B <- Clonality <- Signature <- chrom <- chrom_AB <- . <- NULL

    if (is.null(snvClonality)) {
        stop(
            "Missing input. Please provide output generated by estimateClonality."
        )
    }
    if (is.null(nbObj)) {
        stop(
            "Missing input. Please provide output generated by nbImport."
        )
    }

    # `:=` modifies data.tables by reference; work on a copy so the helper
    # column `chrom_AB` does not leak into the caller's object.
    snvClonality <- data.table::copy(snvClonality)

    # Segment label: chromosome plus allele-specific copy numbers if known.
    snvClonality[, chrom_AB := ifelse(
        !is.na(A) & !is.na(B),
        paste0("chr", chrom, " (", A, ":", B, ")"),
        paste0("chr", chrom)
    )]

    # Order labels by chromosome number. Non-numeric chromosomes (e.g. X, Y)
    # become NA and are sorted last; the coercion warning is expected.
    chrom_number <- suppressWarnings(
        as.numeric(sub("chr", "", snvClonality$chrom))
    )
    chrom_levels <- unique(snvClonality$chrom_AB[order(chrom_number)])
    if (length(chrom_levels) == 0L) {
        stop("No SNVs available for plotting.")
    }

    clonality_order <- c("Precnv", "Postcnv", "C", "SC")
    clonality_labels <- c(
        "Precnv" = "Clonal\n- Pre-CNV",
        "Postcnv" = "Clonal\n- Post-CNV",
        "C" = "Clonal\n-NOS",
        "SC" = "Subclonal"
    )
    clonality_colors <- c(
        "Precnv" = "#66c2a5",
        "Postcnv" = "#fc8d62",
        "C" = "#8da0cb",
        "SC" = "#e78ac3"
    )

    snvClonality <- snvClonality[Clonality %in% clonality_order &
        !is.na(chrom_AB)]
    snvClonality[, chrom_AB := factor(chrom_AB, levels = chrom_levels)]

    if (!is.null(output.file)) {
        pdf(file = output.file, width = 8, height = 6)
        # Close the device even if one of the plotting steps fails.
        on.exit(dev.off(), add = TRUE)
    }

    # Dodged bar chart of SNV counts per clonality class for a set of segments.
    plot_segment_bars <- function(page_segments, panel_title) {
        ggplot(
            snvClonality[chrom_AB %in% page_segments],
            aes(
                x = chrom_AB,
                fill = factor(Clonality, levels = clonality_order)
            )
        ) +
            geom_bar(position = position_dodge(width = 0.8), width = 0.7) +
            labs(
                title = panel_title,
                x = "Chromosome",
                y = "Number of SNVs",
                fill = "Clonality"
            ) +
            scale_fill_manual(
                values = clonality_colors,
                labels = clonality_labels
            ) +
            theme_classic() +
            theme(
                axis.text.x = element_text(angle = 45, hjust = 1),
                legend.key.height = unit(1.5, "lines")
            )
    }

    # Split the segments into two halves (top / bottom panel). With a single
    # segment only one group exists, so only one panel is drawn.
    chrom_halves <- split(chrom_levels, ceiling(seq_along(chrom_levels) /
        (length(chrom_levels) / 2)))
    panels <- list(
        plot_segment_bars(
            chrom_halves[[1]],
            "SNV Clonality per Chromosomal Segment"
        )
    )
    if (length(chrom_halves) > 1L) {
        panels[[2]] <- plot_segment_bars(chrom_halves[[2]], NULL)
    }
    gridExtra::grid.arrange(grobs = panels, ncol = 1)

    if (isTRUE(sig.assign)) {
        sig.colors <- attr(nbObj, "sig.colors")
        # Four segments per page, arranged as a 2 x 2 facet grid.
        chrom_pages <- split(
            chrom_levels,
            ceiling(seq_along(chrom_levels) / 4)
        )

        for (chrom_page in chrom_pages) {
            p2 <- ggplot(
                snvClonality[chrom_AB %in% chrom_page],
                aes(
                    x = factor(Clonality, levels = clonality_order),
                    fill = Signature
                )
            ) +
                geom_bar(position = "stack", width = 0.7) +
                facet_wrap(~chrom_AB, scales = "free_y", ncol = 2, nrow = 2) +
                labs(
                    title = "SNV Timing Stratified by Mutational Signature",
                    x = "Clonality",
                    y = "Number of SNVs",
                    fill = "Signature"
                ) +
                scale_x_discrete(labels = clonality_labels) +
                scale_fill_manual(values = sig.colors) +
                theme_classic() +
                theme(
                    axis.text.x = element_text(angle = 45, hjust = 1),
                    strip.text = element_text(face = "bold"),
                    strip.background = element_blank()
                )

            print(p2)
        }
    }

    invisible(NULL)
}
