一些自定义R语言函数
数学与统计学
摘要
列出一些常用的自定义函数。
简介
这里记录一些自己编写的R语言函数,以供日常使用参考。
代码部分
1号函数
library(tidyverse)
#' Plot the rate of a binary outcome for each level of a categorical predictor
#'
#' For every level of `predictor` (missing values are shown as their own
#' "<NA>" level) the observed proportion of positive outcomes is drawn as a
#' point with the 95% confidence interval returned by `prop.test()` as an
#' error bar. A dashed horizontal line marks the overall outcome rate.
#'
#' @param data A data frame.
#' @param predictor Unquoted column of `data`; must be a factor or logical.
#' @param outcome Unquoted column of `data`; must be logical or contain only
#'   0/1. Rows with a missing outcome are dropped.
#' @param reorder If `TRUE`, order the predictor levels by estimated rate.
#' @return A ggplot object.
plot_cat_bin <- function(data, predictor, outcome, reorder = FALSE) {
  data <- data %>%
    transmute(predictor = !!enquo(predictor),
              outcome = !!enquo(outcome))

  stopifnot(is.factor(data$predictor) || is.logical(data$predictor))
  # NA is accepted here because rows with a missing outcome are removed just
  # below; `NA %in% c(0, 1)` is FALSE and would otherwise reject numeric
  # outcomes that contain missing values.
  stopifnot(is.logical(data$outcome) || all(data$outcome %in% c(0, 1, NA)))

  data <- drop_na(data, outcome)
  prob_mean <- mean(pull(data, outcome))

  # NOTE(review): forcats >= 1.0.0 deprecates fct_explicit_na() in favour of
  # fct_na_value_to_level(); kept as-is so older forcats versions still work.
  data <- data %>%
    mutate(predictor = forcats::fct_explicit_na(factor(predictor), "<NA>")) %>%
    group_by(predictor) %>%
    # The callback only reads `outcome`, so the grouping column does not need
    # to be kept in `.x` (the old `keep =` argument is deprecated in dplyr).
    group_modify(function(.x, .y) {
      results <- prop.test(sum(pull(.x, outcome)), nrow(.x))
      tibble(prob = unname(results$estimate),
             lower = results$conf.int[1],
             upper = results$conf.int[2])
    }) %>%
    mutate(prob_mean = !!prob_mean) %>%
    ungroup()

  if (reorder) {
    data <- data %>%
      mutate(predictor = forcats::fct_reorder(predictor, prob))
  }

  data %>%
    ggplot(aes(predictor, prob)) +
    geom_point() +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    geom_hline(yintercept = prob_mean, linetype = "dashed") +
    # Label the x axis with the expression the caller passed as `predictor`.
    labs(x = as_label(enexpr(predictor)))
}
一些笔记:
计算率的置信区间时,最常见的计算算式是\(\hat{p} \pm z \sqrt{\frac{\hat{p}(1 - \hat{p})}{n}}\);由该算式算出来的区间叫Wald区间。当\(n\)很小,或\(\hat{p}\)接近0或1的时候,Wald区间值非常不稳定,而且有可能不处于\([0, 1]\)区间内。
因为上述原因,R语言的`prop.test()`在计算率的置信区间时,使用的是适用面更广的Wilson score区间,并且默认(`correct = TRUE`)还会在可行时加上连续性校正;该方法的计算算式可以参阅Newcombe RG (1998). “Two-Sided Confidence Intervals for the Single Proportion: Comparison of Seven Methods.” Statistics in Medicine, 17(8), 857–872 一文。`prop.test()`所使用的方法对应该文中的方法3(不含连续性校正)和方法4(含连续性校正)。
因此,举个例子,`prop.test(12, 123, conf.level = 0.99, correct = FALSE)`计算出来的置信区间\([0.0479415, 0.1883752]\)也可以通过以下方式计算出来:
# Wilson score interval (no continuity correction) computed by hand for
# 12 successes out of 123 trials at the 99% confidence level.
p <- 12 / 123        # observed proportion
n <- 123             # number of trials
z <- qnorm(0.995)    # two-sided 99% normal quantile
(2 * n * p + z^2 + c(-1, 1) * z * sqrt(z^2 + 4 * n * p * (1 - p))) /
  (2 * (n + z^2))
## [1] 0.0479415 0.1883752
2号函数
library(tidyverse)
#' Plot the smoothed rate of a binary outcome against a numeric predictor
#'
#' Fits a binomial GAM (`mgcv::gam(outcome ~ s(predictor))`), predicts on the
#' link scale, builds a 95% normal-approximation band there and maps fit and
#' band back to the probability scale. A dashed horizontal line marks the
#' overall outcome rate.
#'
#' @param data A data frame.
#' @param predictor Unquoted column of `data`; must be numeric.
#' @param outcome Unquoted column of `data`; must be logical or contain only
#'   0/1. Rows with a missing outcome are dropped.
#' @param augment If `TRUE`, predict at 1000 points drawn uniformly over the
#'   observed predictor range (uses the random number generator); otherwise
#'   predict at the distinct observed predictor values.
#' @return A ggplot object.
plot_num_bin <- function(data, predictor, outcome, augment = FALSE) {
  data <- data %>%
    transmute(predictor = !!enquo(predictor),
              outcome = !!enquo(outcome))

  stopifnot(is.numeric(data$predictor))
  # NA is accepted here because rows with a missing outcome are removed just
  # below; `NA %in% c(0, 1)` is FALSE and would otherwise reject numeric
  # outcomes that contain missing values.
  stopifnot(is.logical(data$outcome) || all(data$outcome %in% c(0, 1, NA)))

  data <- drop_na(data, outcome) %>%
    mutate(outcome = as.double(outcome))
  prob_mean <- mean(pull(data, outcome))

  gam_fit <- mgcv::gam(outcome ~ s(predictor), data = data, family = binomial)

  if (augment) {
    # na.rm = TRUE: a missing predictor value must not turn the sampling
    # bounds into NA (runif() would then return only NaN).
    predictor_range <- range(data$predictor, na.rm = TRUE)
    data <- tibble(
      predictor = runif(1000, predictor_range[1], predictor_range[2])
    )
  } else {
    data <- distinct(data, predictor)
  }

  # One predict() call yields both the linear predictor and its standard error.
  pred <- predict(gam_fit, newdata = data, type = "link", se.fit = TRUE)
  inv_link <- binomial()$linkinv
  z_crit <- qnorm(0.975)

  data <- data %>%
    mutate(link = as.vector(pred$fit),
           se = as.vector(pred$se.fit),
           # The band is built on the link scale and then back-transformed,
           # so it always stays inside [0, 1].
           lower = inv_link(link - !!z_crit * se),
           upper = inv_link(link + !!z_crit * se),
           prob = inv_link(link),
           prob_mean = !!prob_mean)

  data %>%
    ggplot(aes(x = predictor)) +
    geom_line(aes(y = prob)) +
    geom_ribbon(aes(ymin = lower, ymax = upper),
                fill = "darkgrey", alpha = 0.5) +
    geom_hline(yintercept = prob_mean, linetype = "dashed") +
    # Label the x axis with the expression the caller passed as `predictor`.
    labs(x = as_label(enexpr(predictor)))
}
笔记:待补充。
3号函数
此函数是对`stats::nls()`的封装,用来拟合非线性函数曲线。
#' Fit a nonlinear least-squares model, retrying from several starting values
#'
#' Wrapper around `stats::nls()` using the bounded "port" algorithm. Each
#' element of `params` supplies candidate starting values for one parameter;
#' the range of those values (infinite values allowed) is used as that
#' parameter's lower/upper bound. The grid of all finite starting-value
#' combinations is sampled at random and tried in turn until one fit succeeds.
#'
#' Note: calls `set.seed(seed)`, so the caller's random number stream is reset.
#'
#' @param data Data frame passed to `nls()`.
#' @param formula Model formula passed to `nls()`.
#' @param params Named list; one numeric vector per model parameter, each with
#'   at least one finite, non-missing value.
#' @param ... Further arguments forwarded to `nls()`.
#' @param max_try Maximum number of starting-value combinations to try.
#' @param seed Seed used when sampling the starting-value combinations.
#' @return The first successful `nls` fit. If every attempt fails, the error
#'   from the last attempt is re-signalled.
fit_nls <- function(data, formula, params, ...,
                    max_try = 100, seed = 1) {
  # `params` sets the starting values and bounds of the curve parameters and
  # must satisfy the following requirements.
  # It must be a list ...
  stopifnot(is.list(params))
  # ... whose elements are all named, with unique names ...
  stopifnot(!is.null(names(params)) && all(nchar(names(params)) > 0))
  stopifnot(!anyDuplicated(names(params)))
  # ... each element being a double or integer vector ...
  stopifnot(all(vapply(params, function(x) {
    !is.null(x) && is.atomic(x) && typeof(x) %in% c("double", "integer")
  }, logical(1))))
  # ... that contains at least one value that is neither missing nor infinite.
  stopifnot(all(vapply(params, function(x) {
    !all(is.na(x) | is.infinite(x))
  }, logical(1))))

  # Fix the parameter order so starting values and bounds line up by position.
  params <- params[sort(names(params))]

  # Bounds come from the full range of the supplied values (may be +/-Inf).
  lower <- unname(vapply(params, function(p) {
    as.numeric(min(p, na.rm = TRUE))
  }, numeric(1)))
  upper <- unname(vapply(params, function(p) {
    as.numeric(max(p, na.rm = TRUE))
  }, numeric(1)))

  # Starting values: every combination of the distinct finite values.
  params_init <- expand.grid(lapply(params, function(x) {
    x[!is.na(x) & !is.infinite(x) & !duplicated(x)]
  }))

  max_try <- min(max_try, nrow(params_init))
  if (max_try < 1) {
    stop("no starting values to try: 'max_try' must be at least 1",
         call. = FALSE)
  }

  set.seed(seed)
  # drop = FALSE keeps a one-parameter grid as a data frame; without it the
  # subset collapses to a bare vector and the loop below cannot run.
  params_init <- params_init[sample.int(nrow(params_init), size = max_try), ,
                             drop = FALSE]

  nls_fit <- NULL
  for (i in seq_len(nrow(params_init))) {
    # drop = FALSE also preserves the parameter name on the start vector.
    start <- unlist(params_init[i, , drop = FALSE])
    nls_fit <- tryCatch(
      nls(formula = formula, data = data,
          start = start,
          algorithm = "port",
          lower = lower,
          upper = upper,
          ...),
      error = function(e) e
    )
    if (inherits(nls_fit, "nls")) {
      return(nls_fit)
    }
  }
  # Every attempt failed: re-signal the last error that was captured.
  stop(nls_fit)
}