一些自定义R语言函数

简介

这里记录一些自己编写的R语言函数,以供日常使用参考。

代码部分

1号函数

library(tidyverse)

plot_cat_bin <- function(data, predictor, outcome, reorder = F) {
  data <- data %>%
    transmute(predictor = !!enquo(predictor),
              outcome = !!enquo(outcome))
  
  stopifnot(is.factor(data$predictor) | is.logical(data$predictor))
  stopifnot(all(data$outcome %in% c(0, 1)) | is.logical(data$outcome))
  
  data <- drop_na(data, outcome)
  prob_mean <- mean(pull(data, outcome))
  
  data <- data %>%
    mutate(predictor = forcats::fct_explicit_na(factor(predictor), "<NA>")) %>%
    group_by(predictor) %>%
    group_modify(function(.x, .y) {
      results <- prop.test(sum(pull(.x, outcome)), nrow(.x))
      tibble(prob = unname(results$estimate),
             lower = results$conf.int[1],
             upper = results$conf.int[2])
    }, keep = T) %>%
    mutate(prob_mean = !!prob_mean) %>%
    ungroup()
  
  if (reorder) {
    data <- data %>%
      mutate(predictor = forcats::fct_reorder(predictor, prob))
  }
  
  data %>%
    ggplot(aes(predictor, prob)) +
    geom_point() +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    geom_hline(yintercept = prob_mean, linetype = "dashed") +
    labs(x = as_label(enexpr(predictor)))
}

一些笔记:

计算率的置信区间时,最常见的计算算式是\(\hat{p} \pm z \sqrt{\frac{\hat{p}(1 - \hat{p})}{n}}\);由该算式算出来的区间叫Wald区间。当\(n\)很小,或\(\hat{p}\)接近0或1的时候,Wald区间值非常不稳定,而且有可能不处于\([0, 1]\)区间内。

因为上述原因,R语言的prop.test()在计算率的置信区间时,使用的是适用面更广的Wilson score区间,并且在必要时还会增加连续性校正;该方法的计算算式可以参阅Newcombe RG (1998). “Two-Sided Confidence Intervals for the Single Proportion: Comparison of Seven Methods.” 一文。prop.test()所使用的方法对应该文中的方法3(不含连续性校正)和方法4(含连续性校正)。

因此,举个例子,prop.test(12, 123, conf.level = 0.99, correct = FALSE)计算出来的置信区间\([0.0479415, 0.1883752]\)也可以通过以下方式计算出来:

p <- 12/123
n <- 123
z <- qnorm(0.995)
(2*n*p + z^2 + c(-1, 1)*z*sqrt(z^2 + 4*n*p*(1 - p))) /(2*(n + z^2))
## [1] 0.0479415 0.1883752

2号函数

library(tidyverse)

plot_num_bin <- function(data, predictor, outcome, augment = F) {
  data <- data %>%
    transmute(predictor = !!enquo(predictor),
              outcome = !!enquo(outcome))
  
  stopifnot(is.numeric(data$predictor))
  stopifnot(all(data$outcome %in% c(0, 1)) | is.logical(data$outcome))
  
  data <- drop_na(data, outcome) %>%
    mutate(outcome = as.double(outcome))
  prob_mean <- mean(pull(data, outcome))
  
  gam_fit <- mgcv::gam(outcome ~ s(predictor), data = data, family = binomial)
  
  if(augment) {
    data <- tibble(predictor = runif(1000, min(data$predictor), max(data$predictor)))
  } else {
    data <- distinct(data, predictor)
  }
  
  data <- data %>%
    mutate(link = predict(gam_fit, data, type = "link"),
           se = predict(gam_fit, data, type = "link", se.fit = T)$se.fit,
           upper = link + qnorm(0.975) * se,
           lower = link - qnorm(0.975) * se,
           lower = binomial()$linkinv(lower),
           upper = binomial()$linkinv(upper),
           prob = binomial()$linkinv(link),
           prob_mean = !!prob_mean)
  
  data %>%
    ggplot(aes(x = predictor)) +
    geom_line(aes(y = prob)) +
    geom_ribbon(aes(ymin = lower, ymax = upper), fill = "darkgrey", alpha = 0.5) +
    geom_hline(yintercept = prob_mean, linetype = "dashed") +
    labs(x = as_label(enexpr(predictor)))
}

笔记:待补充。

3号函数

此函数是对stats::nls()的封装,用来拟合非线性函数曲线。

fit_nls <- function(data, formula, params, ...,
                    max_try = 100, seed = 1) {
  # params用来设置曲线中各个参数的初始值及范围;其必须满足以下要求:
  # params必须是一个列表
  stopifnot(is.list(params))
  # params列表中的所有元素必须已被命名
  stopifnot(!is.null(names(params)) & all(nchar(names(params)) > 0))
  # params列表中的各个元素必须是类型为浮点数或整数、且长度大于0的向量
  stopifnot(all(vapply(params, function(x) {
    !is.null(x) & is.atomic(x) & typeof(x) %in% c("double", "integer")
  }, logical(1))))
  # params列表中的各个向量至少包含一个非缺失且非无穷大值
  stopifnot(all(vapply(params, function(x) {
    !all(is.na(x) | is.infinite(x))
  }, logical(1))))
  
  params <- params[sort(names(params))]
  
  params_range <- as.data.frame(
    do.call(rbind,
            Map(function(p, n) {
              prange <- range(p, na.rm = T)
              data.frame(name = n, lower = prange[[1]], upper = prange[[2]])
            }, p = params, n = names(params)))
  )
  
  params_init <- expand.grid(lapply(params, function(x) {x[!is.na(x) & !is.infinite(x) & !duplicated(x)]}))
  
  max_try <- min(max_try, nrow(params_init))
  
  set.seed(seed)
  params_init <- params_init[sample.int(nrow(params_init), size = max_try), ]
  
  for (i in seq_len(nrow(params_init))) {
    nls_fit <- tryCatch(nls(formula = formula, data = data,
                            start = unlist(params_init[i, ]),
                            algorithm = "port",
                            lower = params_range$lower,
                            upper = params_range$upper,
                            ...),
                        error = function(e) {e})

    if (is.element("nls", class(nls_fit))) {
      return(nls_fit)
    }
  }

  stop(nls_fit)
}

相关

下一页
上一页
comments powered by Disqus