# sketch_plate ----------------------------------------


#' Make an overview of plate layout, with colored groups
#'
#' @description A heatmap style ggplot figure with each well
#' labeled with a color for each group
#'
#' @param xfplate This is the `raw_data` or the `rate_data` tibble
#' that is generated by the `revive_xfplate()` function
#' @param reorder_legend either `TRUE` or `FALSE`. When `TRUE` the
#' groups are ordered based on the number in the character string of
#' the group. Group names that contain no digit get a `"00_"` prefix
#' to make `forcats::fct_reorder(group, parse_number(group))` work; the
#' `"Background"` group keeps its original name in the legend.
#'
#' @return a ggplot object of a 96 well plate with the group
#' layout
#' @export
#' @importFrom ggplot2 ggplot aes geom_tile scale_fill_manual
#' scale_x_continuous scale_y_discrete %+replace% theme_bw theme
#' element_blank element_line element_text
#' @examples
#' system.file("extdata",
#'     "20191219_SciRep_PBMCs_donor_A.xlsx",
#'     package = "seahtrue"
#' ) |>
#'     revive_xfplate() |>
#'     purrr::pluck("raw_data", 1) |>
#'     sketch_plate(reorder_legend = TRUE)
sketch_plate <- function(xfplate, reorder_legend = FALSE) {
    # Heatmap theme: no grid lines, border or axis/legend titles.
    theme_htmp <- function() {
        theme_bw(base_size = 15) %+replace%
            theme(
                panel.grid.minor.x = element_blank(),
                panel.grid.major.x = element_blank(),
                panel.grid.minor.y = element_blank(),
                panel.grid.major.y = element_blank(),
                panel.border = element_blank(),
                axis.ticks.x = element_line(),
                axis.ticks.y = element_line(),
                axis.title.x = element_blank(),
                axis.title.y = element_blank(),
                axis.text.x = element_text(
                    size = rel(1.3),
                    hjust = 0.5,
                    vjust = 0
                ),
                axis.text.y = element_text(
                    size = rel(1.3),
                    hjust = 0.5,
                    vjust = 0.5
                ),
                legend.title = element_blank(),
                legend.justification = c(0, 1)
            )
    }

    # Keep one row per well and split the well id after its first
    # character into `row` and `column` (`convert = TRUE` turns the column
    # part into a number when possible; presumably ids look like "A01").
    # Columns are given as strings because `.data` inside tidyselect
    # arguments such as `.by` is deprecated since tidyselect 1.2.0.
    xfplate_sep <- xfplate %>%
        slice(1, .by = "well") %>%
        tidyr::separate("well",
            into = c("row", "column"),
            sep = 1,
            convert = TRUE
        )

    number_of_groups <- nlevels(as.factor(xfplate_sep$group))

    if (number_of_groups > 12) {
        cli::cli_alert_info(
            glue::glue("There are more than 12 groups, which likely
                 messes up the output plot."),
            wrap = TRUE
        )
    }

    # One interpolated colour per group, based on the 8-colour BrBG palette.
    groupColors3 <-
        grDevices::colorRampPalette(
            RColorBrewer::brewer.pal(8, "BrBG")
        )(number_of_groups)

    # relevel workaround
    if (reorder_legend) {
        xfplate_sep <- relevel_group_names(xfplate_sep)
    }

    # Rows are reversed so that the first row letter ends up at the top.
    row_limits <- rev(levels(as.factor(xfplate_sep$row)))

    ggplot(xfplate_sep, aes(x = .data$column, y = .data$row)) +
        geom_tile(aes(fill = .data$group),
            color = "grey50", show.legend = TRUE
        ) +
        scale_fill_manual(values = groupColors3) +
        scale_x_continuous(
            limits = c(0.5, 13),
            breaks = seq_len(12),
            position = "top"
        ) +
        scale_y_discrete(limits = row_limits) +
        theme_htmp()
}


# sketch_rate -----------------------------------------

#' Generate a plot for the rate data
#'
#' @description The sketch_rate() function uses the rate_data
#' from the generated output from the revive_xfplate() function.
#' The injection info is annotated in the plot, using the information
#' from the injections provided in the original experiment. Several
#' options are available to plot either ECAR/OCR or normalize the
#' data with the values from the normalization cells in the .xlsx file.
#'
#' A number of validations are performed to check whether the data
#' can be plotted and whether the layout of the plot will not be
#' ruined...
#'
#' @param xf_rate The `rate_data` tibble as generated by `revive_xfplate()`
#' @param param Either "OCR" or "ECAR"
#' @param normalize Either TRUE or FALSE
#' @param normalize_unit any string that will be pasted in the y-axis label
#' when normalize = TRUE
#' @param take_group_mean  Either TRUE or FALSE
#' @param reorder_legend  Either TRUE or FALSE. When `TRUE` the
#' groups are ordered based on the number in the character string of
#' the group. Group names that contain no digit get a `"00_"` prefix
#' to make `forcats::fct_reorder(group, parse_number(group))` work; the
#' `"Background"` group keeps its original name in the legend.
#'
#' @return a ggplot object
#' @export
#' @importFrom ggplot2 ggplot aes geom_line geom_ribbon geom_hline geom_vline 
#' annotate scale_x_continuous scale_y_discrete labs %+replace% 
#' theme_light theme rel scale_y_continuous
#' @importFrom rlang .data
#' @examples
#' system.file("extdata",
#'     "20191219_SciRep_PBMCs_donor_A.xlsx",
#'     package = "seahtrue"
#' ) |>
#'     revive_xfplate() |>
#'     purrr::pluck("rate_data", 1) |>
#'     sketch_rate(
#'         param = "OCR",
#'         reorder_legend = TRUE
#'     )
#'
#' system.file("extdata",
#'     "20191219_SciRep_PBMCs_donor_A.xlsx",
#'     package = "seahtrue"
#' ) |>
#'     revive_xfplate() |>
#'     purrr::pluck("rate_data", 1) |>
#'     sketch_rate(
#'         param = "OCR",
#'         take_group_mean = FALSE,
#'         reorder_legend = TRUE
#'     )
#'
#' system.file("extdata",
#'     "20191219_SciRep_PBMCs_donor_A.xlsx",
#'     package = "seahtrue"
#' ) |>
#'     revive_xfplate() |>
#'     purrr::pluck("rate_data", 1) |>
#'     sketch_rate(
#'         param = "ECAR",
#'         normalize = TRUE,
#'         take_group_mean = TRUE,
#'         reorder_legend = TRUE
#'     )
#'
#' system.file("extdata",
#'     "20191219_SciRep_PBMCs_donor_A.xlsx",
#'     package = "seahtrue"
#' ) |>
#'     revive_xfplate() |>
#'     purrr::pluck("rate_data", 1) |>
#'     sketch_rate(
#'         param = "ECAR",
#'         normalize = TRUE,
#'         take_group_mean = FALSE,
#'         reorder_legend = TRUE
#'     )
sketch_rate <- function(xf_rate,
                        param = "OCR",
                        normalize = FALSE,
                        normalize_unit = "10000 cells",
                        take_group_mean = TRUE,
                        reorder_legend = FALSE) {
    # validate data: a missing `was_background_corrected` attribute is
    # taken as a sign that the input was not made by revive_xfplate()
    was_background_corrected <-
        attr(xf_rate, "was_background_corrected", exact = TRUE)

    if (is.null(was_background_corrected)) {
        cli::cli_abort(
            glue::glue("The input data is likely not generated by the
                 seahtrue::revive_xfplate() function. Please check the
                 input data that was given to the argument xf_rate."),
            wrap = TRUE
        )
    }

    group_names <- xf_rate %>%
        pull("group") %>%
        unique()

    number_of_groups <- length(group_names)
    largest_group_string_size <- max(stringr::str_count(group_names))

    if (!param %in% c("OCR", "ECAR")) {
        cli::cli_abort(
            glue::glue("The argument, {param}, is not correct. It
                 should either be OCR or ECAR (and type should be
                 character)"),
            wrap = TRUE
        )
    }

    # layout warnings only; the plot is still generated
    if (largest_group_string_size > 40) {
        cli::cli_alert_info(
            glue::glue("At least one of the group names is huge, which likely
                 messes up the output plot."),
            wrap = TRUE
        )
    }

    if (number_of_groups > 12) {
        cli::cli_alert_info(
            glue::glue("There are more than 12 groups, which likely
                 messes up the output plot."),
            wrap = TRUE
        )
    }

    if (!validate::all_complete(xf_rate$cell_n)) {
        cli::cli_abort(
            glue::glue("There are NAs in the cell_n column"),
            wrap = TRUE
        )
    }

    # Background wells are dropped when background-corrected data is
    # normalized.
    if ("Background" %in% xf_rate$group &&
        was_background_corrected && normalize) {
        xf_rate <- xf_rate %>%
            filter(.data$group != "Background")
    }

    # Normalization divides by cell_n, so zero counts cannot be plotted.
    if (normalize && any(xf_rate$cell_n == 0)) {
        wells_with_zero <- xf_rate %>%
            filter(.data$cell_n == 0) %>%
            pull("well") %>%
            unique()

        # Too many wells to list individually in the error message.
        if (length(wells_with_zero) > 15) {
            cli::cli_abort(
                glue::glue("There are more than 15 wells where the
                     cell_n value is zero."),
                wrap = TRUE
            )
        }

        cli::cli_abort(
            glue::glue("The following wells have a cell_n value of
                 zero: {glue::glue_collapse(wells_with_zero, sep = ', ')}.
                 In those cases the normalization cannot be performed and
                 the plot cannot be generated unfortunately. Please make sure
                 those wells have cell_n values or set normalize to FALSE."),
            wrap = TRUE
        )
    }

    # evaluate reorder_legend
    if (reorder_legend) {
        xf_rate <- xf_rate %>% relevel_group_names()
    }

    # evaluate param: column to plot and its unit
    param_to_plot <- paste0(param, "_wave_bc")
    rate_unit <- if (param == "OCR") "pmol/min" else "mpH/min"

    # expose the plotted column under the fixed name `my_param`
    xf_rate <- xf_rate %>%
        select(everything(), my_param = all_of(param_to_plot))

    # evaluate normalize and unit
    if (normalize) {
        xf_rate <- xf_rate %>%
            mutate(my_param = .data$my_param / .data$cell_n)
        y_label <- paste0(param, " (", rate_unit, "/", normalize_unit, ")")
    } else {
        y_label <- paste0(param, " (", rate_unit, ")")
    }

    # make plot
    if (take_group_mean) {
        plot_ribbon_per_meas_and_group(xf_rate, "my_param", y_label)
    } else {
        plot_line_per_well(xf_rate, "my_param", y_label)
    }
}



#' Combine multiple revived xf plates into one plot for raw data
#'
#' @description In this plot the O2, pH, or its emission
#' value at the very first measurement point is plotted for all
#' wells from all xfplates that are provided to the function.
#'
#'
#' @param my_df a tibble generated by glue_xfplates() with
#' for each row representing a single xf experiment
#' @param param either "O2_mmHg", "pH", "O2_em_corr" or
#' "pH_em_corr"
#'
#' @return a ggplot object
#' @export
#' @importFrom ggplot2 geom_jitter scale_discrete_manual
#' facet_wrap
#'
#' @examples
#' suppressMessages(
#'     c(
#'         system.file("extdata",
#'             "20191219_SciRep_PBMCs_donor_A.xlsx",
#'             package = "seahtrue"
#'         ),
#'         system.file("extdata",
#'             "20191219_SciRep_PBMCs_donor_A.xlsx",
#'             package = "seahtrue"
#'         )
#'     ) |>
#'         glue_xfplates(arg_is_folder = FALSE) |>
#'         sketch_assimilate_raw(param = "O2_mmHg")
#' )
sketch_assimilate_raw <- function(my_df, param = "O2_mmHg") {
    # Shared look: axis lines instead of a panel border, no major grid.
    theme_BF <- function() {
        theme_bw(base_size = 18) %+replace%
            theme( # panel.grid.minor.x = element_blank(),
                panel.grid.major.x = element_blank(),
                # panel.grid.minor.y = element_blank(),
                panel.grid.major.y = element_blank(),
                panel.border = element_blank(),
                axis.ticks.x = element_line(),
                axis.ticks.y = element_line(),
                axis.line.y = element_line(),
                axis.line.x = element_line(),
                plot.title = element_text(size = rel(0.9)),
                plot.subtitle = element_text(size = rel(0.85)),
                legend.text = element_text(size = rel(0.7)),
                legend.title = element_text(size = rel(0.7))
            )
    }

    if (!param %in% c(
        "O2_em_corr", "pH_em_corr",
        "O2_mmHg", "pH"
    )) {
        cli::cli_abort(
            glue::glue("The input parameter is not generated by the
                 seahtrue::revive_xfplate() function or is not
                 O2_em_corr, pH_em_corr, O2_mmHg or pH. Please check the
                 input data that was given to the argument param"),
            wrap = TRUE
        )
    }

    # Reference value per parameter, drawn as a dashed vertical line.
    targetEMS <- switch(param,
        O2_em_corr = 12500,
        pH_em_corr = 30000,
        O2_mmHg = 151.7,
        pH = 7.4
    )

    x_title <- switch(param,
        O2_em_corr = ,
        pH_em_corr = "emission (AU)",
        O2_mmHg = "O2 (mmHg)",
        pH = "pH"
    )

    # Columns are selected by name (strings): `.data` inside tidyselect
    # arguments is deprecated since tidyselect 1.2.0.
    first_measurements <- my_df %>%
        dplyr::select("plate_id", "raw_data") %>%
        tidyr::unnest("raw_data") %>%
        dplyr::select(
            "plate_id", "measurement", "well", "group",
            emission = all_of(param)
        ) %>%
        # keep, per plate and well, the row with the lowest measurement
        dplyr::slice(
            which.min(.data$measurement),
            .by = c("plate_id", "well")
        ) %>%
        dplyr::mutate(group_id = dplyr::case_when(
            .data$group == "Background" ~ "background",
            .default = "sample"
        ))

    ggplot(
        first_measurements,
        aes(
            x = .data$emission,
            y = .data$plate_id,
            group = .data$plate_id
        )
    ) +
        ggridges::geom_density_ridges(aes(point_color = .data$group_id),
            jittered_points = TRUE,
            position = ggridges::position_points_jitter(
                width = 0.05,
                height = 0
            ),
            point_shape = "|",
            point_size = 4, point_alpha = 0.8, alpha = 0.7
        ) +
        geom_vline(
            xintercept = targetEMS, linetype = "dashed",
            color = "#D16103"
        ) +
        # Named values keep the colours fixed even when a plate set has
        # no "Background" wells.
        scale_discrete_manual(
            aesthetics = "point_color",
            values = c(background = "darkred", sample = "darkblue")
        ) +
        labs(x = x_title) +
        theme_BF()
}

#' Combine multiple revived xf plates into one plot for
#' rate data
#' @description In this plot the OCR or ECAR is plotted per
#' group for each plate in a faceted ggplot
#'
#' @param my_df a tibble generated by glue_xfplates() with
#' for each row representing a single xf experiment
#' @param param either "OCR" or "ECAR"
#' @param my_measurements the measurements that need to be
#' in the plot. For example, c(3,6,7,12) for a typical
#' mito stress test.
#'
#' @return a ggplot object
#' @export
#' @importFrom ggplot2 scale_color_brewer
#' @importFrom stats quantile
#' @importFrom utils head tail
#' @examples
#' suppressMessages(
#'     c(
#'         system.file("extdata",
#'             "20191219_SciRep_PBMCs_donor_A.xlsx",
#'             package = "seahtrue"
#'         ),
#'         system.file("extdata",
#'             "20191219_SciRep_PBMCs_donor_A.xlsx",
#'             package = "seahtrue"
#'         )
#'     ) |>
#'         glue_xfplates(arg_is_folder = FALSE) |>
#'         sketch_assimilate_rate(
#'             param = "OCR",
#'             my_measurements = c(3, 4, 9, 12)
#'         )
#' )
sketch_assimilate_rate <- function(my_df,
                                   param = "OCR",
                                   my_measurements = c(3, 6, 7, 12)) {
    # Shared look: axis lines instead of a panel border, no major grid.
    theme_BF <- function() {
        theme_bw(base_size = 18) %+replace%
            theme( # panel.grid.minor.x = element_blank(),
                panel.grid.major.x = element_blank(),
                # panel.grid.minor.y = element_blank(),
                panel.grid.major.y = element_blank(),
                panel.border = element_blank(),
                axis.ticks.x = element_line(),
                axis.ticks.y = element_line(),
                axis.line.y = element_line(),
                axis.line.x = element_line(),
                plot.title = element_text(size = rel(0.9)),
                plot.subtitle = element_text(size = rel(0.85)),
                legend.text = element_text(size = rel(0.7)),
                legend.title = element_text(size = rel(0.7))
            )
    }

    if (!param %in% c("OCR", "ECAR")) {
        cli::cli_abort(
            glue::glue("The input parameter is not OCR or ECAR.
                 Please check the input data that was
                 given to the argument param"),
            wrap = TRUE
        )
    }

    # Axis title and source column are derived from `param` without
    # overwriting the argument itself.
    y_title <- switch(param,
        OCR = "OCR (pmol/min)",
        ECAR = "ECAR (mpH/min)"
    )
    rate_column <- paste0(param, "_wave_bc")

    # Columns are selected by name (strings): `.data` inside tidyselect
    # arguments is deprecated since tidyselect 1.2.0.
    rates <- my_df %>%
        dplyr::select("plate_id", "rate_data") %>%
        tidyr::unnest("rate_data") %>%
        dplyr::select(
            "plate_id", "measurement", "injection",
            "interval", "well", "group", "cell_n",
            "flagged_well",
            rate = all_of(rate_column)
        ) %>%
        # drop flagged wells and background; keep requested measurements
        dplyr::filter(
            !.data$flagged_well,
            .data$group != "Background",
            .data$measurement %in% my_measurements
        ) %>%
        # group_id is kept in the plot data; no layer below maps it
        dplyr::mutate(
            group_id = paste0(.data$plate_id, "_", .data$group)
        ) %>%
        relevel_group_names()

    ggplot(
        rates,
        aes(
            x = .data$group,
            y = .data$rate,
            color = as.factor(.data$interval)
        )
    ) +
        scale_color_brewer(palette = "Set1") +
        geom_jitter() +
        labs(
            x = "",
            y = y_title,
            color = "interval"
        ) +
        theme_BF() +
        theme(
            axis.text.x =
                element_text(
                    size = rel(0.8),
                    angle = 50,
                    vjust = 0.9, hjust = 1
                )
        ) +
        facet_wrap(~plate_id)
}




# relevel function ------------------------------------

#' Order the `group` column by the number contained in each group name
#'
#' Group names without any digit are prefixed with "00_" so that
#' `readr::parse_number()` can extract a number from them. `group` is then
#' turned into a factor ordered by that number, and the level
#' "00_Background", when present, is renamed back to "Background".
#'
#' @param my_df a data frame with a `group` column
#'
#' @return `my_df` with `group` converted to an ordered-by-number factor
#' @noRd
relevel_group_names <- function(my_df) {
    releveled <- my_df %>%
        mutate(
            # step 1: make every name parseable as a number
            group = dplyr::case_when(
                !stringr::str_detect(.data$group, ".*[0-9].*") ~
                    paste0("00_", .data$group),
                .default = .data$group
            ),
            # step 2: order the levels by the parsed number
            group = forcats::fct_reorder(
                .data$group,
                readr::parse_number(.data$group)
            )
        )

    # Only recode when the prefixed background level actually exists.
    if (!("00_Background" %in% releveled$group)) {
        return(releveled)
    }

    releveled %>%
        mutate(
            group = forcats::fct_recode(
                .data$group,
                "Background" = "00_Background"
            )
        )
}

# ggplot functions ------------------------------------

#' Plot the median and interquartile ribbon per group over time
#'
#' For every `time_wave` and `group` combination the 25%, 50% and 75%
#' quantiles of the column `var` are computed. The median is drawn as a
#' line and the 25%-75% range as a ribbon. Dashed vertical lines with the
#' injection names are placed halfway between consecutive intervals.
#'
#' @param df a data frame with the columns `measurement`, `interval`,
#' `time_wave`, `injection`, `group` and the column named in `var`
#' @param var name (string) of the column to plot
#' @param y_title y-axis title
#'
#' @return a ggplot object
#' @noRd
plot_ribbon_per_meas_and_group <- function(df, var, y_title) {
    theme_ribbon <- function() {
        theme_light(base_size = 20) %+replace%
            theme(
                panel.grid.minor.x = element_blank(),
                panel.grid.major.x = element_blank(),
                panel.grid.minor.y = element_blank(),
                # `linewidth` replaces the deprecated `size` for lines
                panel.grid.major.y = element_line(
                    linewidth = 0.5,
                    linetype = "dashed"
                ),
                panel.border = element_blank(),
                # axis.ticks.x = element_blank(),
                axis.ticks.y = element_blank(),
                axis.line = element_line(linewidth = 0.6),
                plot.title = element_text(
                    hjust = 0, color = "black",
                    size = rel(1)
                ),
                legend.text = element_text(size = rel(0.6)),
                # hjust = 0 replaces the deprecated `legend.title.align`
                legend.title = element_text(size = rel(0.8), hjust = 0)
            )
    }

    # Height at which the injection labels are anchored.
    max_y <- max(df[[var]])

    # One row per measurement; the last/first row within each interval
    # gives the time at which that interval ends/starts (this relies on
    # the row order of `df`).
    one_row_per_measurement <- df %>%
        slice(1, .by = "measurement")

    interval_end <- one_row_per_measurement %>%
        slice_tail(n = 1, by = "interval") %>%
        pull("time_wave") %>%
        head(-1)

    interval_start <- one_row_per_measurement %>%
        slice_head(n = 1, by = "interval") %>%
        pull("time_wave") %>%
        tail(-1)

    # Injection markers sit halfway between two consecutive intervals.
    xintercept_inj <- (interval_start - interval_end) / 2 + interval_end

    # The first unique injection value gets no marker.
    label_inj <- df %>%
        pull("injection") %>%
        unique() %>%
        tail(-1)

    quantiles_df <- df %>%
        summarise(
            X25 = quantile(.data[[var]], probs = 0.25, names = FALSE),
            X50 = quantile(.data[[var]], probs = 0.50, names = FALSE),
            X75 = quantile(.data[[var]], probs = 0.75, names = FALSE),
            .by = c("time_wave", "group")
        )

    ggplot(quantiles_df, aes(x = .data$time_wave, group = .data$group)) +
        geom_ribbon(
            aes(
                ymin = .data$X25,
                ymax = .data$X75,
                fill = .data$group
            ),
            alpha = 0.5
        ) +
        geom_line(aes(y = .data$X50, color = .data$group),
            linewidth = 1.5, linetype = "solid"
        ) +
        geom_hline(
            yintercept = 0, linetype = "dashed",
            linewidth = 0.4, color = "#D16103"
        ) +
        geom_vline(
            xintercept = xintercept_inj,
            color = "grey40",
            linetype = "dashed"
        ) +
        annotate("text",
            x = xintercept_inj, y = max_y,
            label = label_inj, color = "grey40",
            hjust = 1, vjust = -0.4, size = 4, angle = 90
        ) +
        colorspace::scale_color_discrete_divergingx(
            palette = "Geyser", rev = TRUE
        ) +
        colorspace::scale_fill_discrete_divergingx(
            palette = "Geyser", rev = TRUE
        ) +
        labs(
            y = y_title,
            x = "time (minutes)"
        ) +
        scale_y_continuous(labels = scales::label_scientific()) +
        theme_ribbon()
}

#' Plot one line per well over time, colored by group
#'
#' Draws the column `var` against `time_wave` for every well. Dashed
#' vertical lines with the injection names are placed halfway between
#' consecutive intervals.
#'
#' @param df a data frame with the columns `measurement`, `interval`,
#' `time_wave`, `injection`, `well`, `group` and the column named in `var`
#' @param var name (string) of the column to plot
#' @param y_title y-axis title
#'
#' @return a ggplot object
#' @noRd
plot_line_per_well <- function(df, var, y_title) {
    theme_line <- function() {
        theme_light(base_size = 20) %+replace%
            theme(
                panel.grid.minor.x = element_blank(),
                panel.grid.major.x = element_blank(),
                panel.grid.minor.y = element_blank(),
                # `linewidth` replaces the deprecated `size` for lines
                panel.grid.major.y = element_line(
                    linewidth = 0.5,
                    linetype = "dashed"
                ),
                panel.border = element_blank(),
                # axis.ticks.x = element_blank(),
                axis.ticks.y = element_blank(),
                axis.line = element_line(linewidth = 0.6),
                plot.title = element_text(
                    hjust = 0,
                    color = "black", size = rel(1)
                ),
                legend.text = element_text(size = rel(0.6)),
                # hjust = 0 replaces the deprecated `legend.title.align`
                legend.title = element_text(size = rel(0.8), hjust = 0)
            )
    }

    # Height at which the injection labels are anchored.
    max_y <- max(df[[var]])

    # One row per measurement; the last/first row within each interval
    # gives the time at which that interval ends/starts (this relies on
    # the row order of `df`).
    one_row_per_measurement <- df %>%
        slice(1, .by = "measurement")

    interval_end <- one_row_per_measurement %>%
        slice_tail(n = 1, by = "interval") %>%
        pull("time_wave") %>%
        head(-1)

    interval_start <- one_row_per_measurement %>%
        slice_head(n = 1, by = "interval") %>%
        pull("time_wave") %>%
        tail(-1)

    # Injection markers sit halfway between two consecutive intervals.
    xintercept_inj <- (interval_start - interval_end) / 2 + interval_end

    # The first unique injection value gets no marker.
    label_inj <- df %>%
        pull("injection") %>%
        unique() %>%
        tail(-1)

    df %>%
        select(everything(), param = all_of(var)) %>%
        ggplot(aes(
            x = .data$time_wave, y = .data$param,
            group = .data$well,
            color = .data$group
        )) +
        geom_line(linewidth = 0.5) +
        geom_vline(
            xintercept = xintercept_inj,
            color = "grey40",
            linetype = "dashed"
        ) +
        annotate("text",
            x = xintercept_inj, y = max_y,
            label = label_inj, color = "grey40",
            hjust = 1, vjust = -0.4, size = 4, angle = 90
        ) +
        colorspace::scale_color_discrete_divergingx(
            palette = "Geyser", rev = TRUE
        ) +
        labs(
            y = y_title,
            x = "time (minutes)"
        ) +
        scale_y_continuous(labels = scales::label_scientific()) +
        theme_line()
}
