Tidyverse包中的循环和迭代-purrr

purrr
iteration
Author

Lee

Published

September 13, 2023

本章中,我们主要使用purrrdplyr两个包来实现循环和迭代。在R中,迭代通常与其他编程语言看起来截然不同,例如,如果您想要将数值向量x中的每个元素都加倍,在R中,您只需编写2 * x即可。而在大多数其他编程语言中,您需要使用某种形式的for循环来显式地将x的每个元素加倍。类似的情况在使用purrr的时候尤为明显。

1 修改多列

1.1 选择多列并调用单个函数

library(tidyverse)
df <- tibble(
  a = rnorm(10),
  b = rnorm(10),
  c = rnorm(10),
  d = rnorm(10)
)

df |> summarise(
  n = n(),
  a = median(a),
  b = median(b),
  c = median(c),
  d = median(d),
)
# A tibble: 1 × 5
      n      a       b      c     d
  <int>  <dbl>   <dbl>  <dbl> <dbl>
1    10 -0.682 -0.0659 -0.278 0.256

上述代码在计算df各列时,median()函数被重复执行了四次,意味着我们进行了多次的复制粘贴操作,这与“在代码中永远不要复制粘贴超过2次”的原则相违背。试想如果数据集中有10个以上的列,多次进行复制粘贴不仅繁琐,而且容易产生错误。幸好,在tidyverse中的acorss()函数就排上了用场。

across()是一个功能异常强大的函数,它有3个关键的参数:

  1. .cols,操作的目标列,可以是多列,where()which()everything()starts_with()等在此处同样适用。
  2. .fns,对目标列进行操作的函数,同样可以是多个函数,如果是多个函数,则需要以列表的形式传入。
  3. .names,控制输出列的名称,在与mutate()连用时非常有用。

那么上述代码可以使用across()进行改写。

df |> summarise(
  across(a:d, median)
)
# A tibble: 1 × 4
       a       b      c     d
   <dbl>   <dbl>  <dbl> <dbl>
1 -0.682 -0.0659 -0.278 0.256
Important

across()函数有两个使用注意点一定要牢记:

  1. across()不能单独使用,它必须在summarise()mutate()等函数中使用。
  2. across().fns传入函数时,函数后不需要加括号。如果函数中需要传入参数,则需要使用匿名函数。有关匿名函数,在下节中会有描述。

1.2 选择多列并调用多个函数

# 生成一个可操作的数据集
rnorm_na <- function(n, n_na, mean = 0, sd = 1) {
  sample(c(rnorm(n - n_na, mean = mean, sd = sd), rep(NA, n_na)))
} # 生成包含缺失值的正态分布数列,项数为n-n_na,其中有n-n_na项缺失值
df_miss <- tibble(
  a = rnorm_na(5, 1),
  b = rnorm_na(5, 1),
  c = rnorm_na(5, 2),
  d = rnorm(5)
)

# 计算df_miss各列的中数,计算时需要去掉缺失值
df_miss |>
  summarise(
    across(a:d, \(x) median(x, na.rm = TRUE))
  )
# A tibble: 1 × 4
       a     b     c      d
   <dbl> <dbl> <dbl>  <dbl>
1 -0.224 0.497 0.382 -0.252
Warning

上述代码的匿名函数写作了\(x) median(x, na.rm = TRUE),而我们更熟悉的写法是~median(.x, na.rm = TRUE)。Hadley之所以修改写法基于以下两个方面:

  1. .x的写法只适用于tidyverse函数内。
  2. .x有时指代比较抽象,不便理解。

基于此,Hadley建议使用\(x)的写法,例如~.x + 1建议写为\(x) x + 1

如果我们系统同时计算数据集中每列的中数并求和呢?这就需要调用两个函数。

df_miss |> summarise(
  across(
    a:d, # 一定要指定列
    list(
      median = \(x) median(x, na.rm = TRUE),
      sum = \(x) sum(x, na.rm = TRUE)
    )
  )
)
# A tibble: 1 × 8
  a_median a_sum b_median b_sum c_median c_sum d_median d_sum
     <dbl> <dbl>    <dbl> <dbl>    <dbl> <dbl>    <dbl> <dbl>
1   -0.224 -1.08    0.497  1.28    0.382 0.755   -0.252  1.17

上述代码输出的结果列的名称是使用一个类似于 {.col}_{.fn}glue 规范命名的,其中 .col 是原始列的名称,.fn 是函数的名称。我们可以使用 .names 参数来提供您自己的 glue 规范

df_miss |> summarise(
  across(a:d, # 一定要指定列
    list(
      median = \(x) median(x, na.rm = TRUE),
      sum = \(x) sum(x, na.rm = TRUE)
    ),
    .names = "{.fn}-{.col}"
  )
)
# A tibble: 1 × 8
  `median-a` `sum-a` `median-b` `sum-b` `median-c` `sum-c` `median-d` `sum-d`
       <dbl>   <dbl>      <dbl>   <dbl>      <dbl>   <dbl>      <dbl>   <dbl>
1     -0.224   -1.08      0.497    1.28      0.382   0.755     -0.252    1.17

.names参数在与mutate()连用时非常有用,通常用来对比新生成的列与原有列的不同。

# 求数据集的绝对值
df_miss |>
  mutate(
    across(a:d, \(x) abs(x), .names = "{.col}_abs")
  )
# A tibble: 5 × 8
       a      b      c      d  a_abs  b_abs  c_abs d_abs
   <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl>
1 -1.59  NA      0.382 -0.807  1.59  NA      0.382 0.807
2 NA     -0.376 NA     -0.252 NA      0.376 NA     0.252
3 -0.153  0.666 NA      1.09   0.153  0.666 NA     1.09 
4  0.952  0.507  0.888  2.70   0.952  0.507  0.888 2.70 
5 -0.294  0.486 -0.515 -1.56   0.294  0.486  0.515 1.56 

2 行操作

across()summarize()mutate()匹配度很好,但都是针对列的操作,有没有针对filter()的行的操作呢?

dplyrif_any()if_all()两个函数功能就是across()针对行操作的变种。我们具体来看:

# 选择包含缺失值的所有行
df_miss |>
  filter(
    if_any(a:d, is.na)
  )
# A tibble: 3 × 4
       a      b      c      d
   <dbl>  <dbl>  <dbl>  <dbl>
1 -1.59  NA      0.382 -0.807
2 NA     -0.376 NA     -0.252
3 -0.153  0.666 NA      1.09 
# 选择全是缺失值的所有行
df_miss |>
  filter(
    if_all(a:d, is.na)
  )
# A tibble: 0 × 4
# ℹ 4 variables: a <dbl>, b <dbl>, c <dbl>, d <dbl>

这部分更详细的说明可参看:https://bookdown.org/wangminjie/R4DS/tidyverse-beauty-of-across2.html#dplyr-1.0.4-if_any-and-if_all。

3 在自定义函数中使用across()

expand_dates <- function(df) {
  df |>
    mutate(across(where(is.Date), list(year = year, month = month, day = mday)))
}

上述函数的作用为:将所有日期列扩展为年、月和日列:

df_date <- tibble(
  name = c("Amy", "Bob"),
  date = ymd(c("2009-08-03", "2010-01-16"))
)
df_date
# A tibble: 2 × 2
  name  date      
  <chr> <date>    
1 Amy   2009-08-03
2 Bob   2010-01-16
df_date |>
  expand_dates()
# A tibble: 2 × 5
  name  date       date_year date_month date_day
  <chr> <date>         <dbl>      <dbl>    <int>
1 Amy   2009-08-03      2009          8        3
2 Bob   2010-01-16      2010          1       16
Important

across()还可以与非标准性评估(tidy eval)结合,实现给一个参数传入多个列的操作,只要注意将这些传入多个列的参数用两个大括号括起来。关于非标准性评估(tidy eval),可参看这里这里。我们看一个例子:

summarise_means <- function(df, summary_vars = where(is.numeric)) { # 参数summary_vars传入多个列
  df |>
    summarize(
      across({{ summary_vars }}, \(x) mean(x, na.rm = TRUE)), # summary_vars用两个大括号框起来
      n = n()
    )
}

diamonds |>
  group_by(cut) |>
  summarise_means()
# A tibble: 5 × 9
  cut       carat depth table price     x     y     z     n
  <ord>     <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1 Fair      1.05   64.0  59.1 4359.  6.25  6.18  3.98  1610
2 Good      0.849  62.4  58.7 3929.  5.84  5.85  3.64  4906
3 Very Good 0.806  61.8  58.0 3982.  5.74  5.77  3.56 12082
4 Premium   0.892  61.3  58.7 4584.  5.97  5.94  3.65 13791
5 Ideal     0.703  61.7  56.0 3458.  5.51  5.52  3.40 21551

4 与pivot_longer()连用

pivot_longer()across()间有许多有趣的联系,它们可以实现相同的操作并得到一致的结果。我们看一个例子,例如我们需要计算df数据集每列的中数和众数。

  • 使用across():我们可以看到结果是一个宽表格,每列的名称的形式为”列名_函数名”。
df |>
  summarise(
    across(a:d, list(median = median, mean = mean))
  )
# A tibble: 1 × 8
  a_median a_mean b_median  b_mean c_median c_mean d_median d_mean
     <dbl>  <dbl>    <dbl>   <dbl>    <dbl>  <dbl>    <dbl>  <dbl>
1   -0.682 -0.633  -0.0659 0.00853   -0.278 -0.340    0.256  0.245
long <- df |>
  pivot_longer(a:d) |>
  group_by(name) |>
  summarise(
    median = median(value),
    mean = mean(value)
  )
long
# A tibble: 4 × 3
  name   median     mean
  <chr>   <dbl>    <dbl>
1 a     -0.682  -0.633  
2 b     -0.0659  0.00853
3 c     -0.278  -0.340  
4 d      0.256   0.245  
long |>
  pivot_wider(
    names_from = name,
    values_from = c(median, mean),
    names_vary = "slowest",
    names_glue = "{name}_{.value}"
  )
# A tibble: 1 × 8
  a_median a_mean b_median  b_mean c_median c_mean d_median d_mean
     <dbl>  <dbl>    <dbl>   <dbl>    <dbl>  <dbl>    <dbl>  <dbl>
1   -0.682 -0.633  -0.0659 0.00853   -0.278 -0.340    0.256  0.245

这个操作在我们无法使用across()的时候非常有用:即当我们需要同时计算一组多列的时候。例如,数据框同时包含了数值列和权重列,我们需要计算加权平均值时(我理解其实就是数据不整洁的时候):

df_paired <- tibble(
  a_val = rnorm(10),
  a_wts = runif(10),
  b_val = rnorm(10),
  b_wts = runif(10),
  c_val = rnorm(10),
  c_wts = runif(10),
  d_val = rnorm(10),
  d_wts = runif(10)
)
df_paired
# A tibble: 10 × 8
     a_val  a_wts  b_val  b_wts   c_val  c_wts  d_val d_wts
     <dbl>  <dbl>  <dbl>  <dbl>   <dbl>  <dbl>  <dbl> <dbl>
 1  0.0489 0.212  -0.110 0.555   0.0358 0.982   2.39  0.299
 2 -1.43   0.364  -0.435 0.0414  0.0912 0.130  -0.427 0.179
 3 -0.935  0.142  -0.214 0.0101 -0.320  0.758  -0.477 0.849
 4  1.14   0.0530 -1.28  0.305   0.399  0.577   1.05  0.374
 5 -1.12   0.799  -0.287 0.853  -0.289  0.195  -1.42  0.523
 6 -0.117  0.412   0.426 0.855  -0.0595 0.364   0.762 0.116
 7 -0.167  0.853  -0.463 0.152   0.206  0.645  -0.357 0.847
 8  2.52   0.0483 -0.339 0.0982 -0.374  0.0437 -0.644 0.652
 9  1.20   0.602  -0.679 0.882  -1.43   0.280  -1.43  0.613
10  0.189  0.929   0.684 0.517  -0.494  0.984  -0.166 0.290

df_paired数据框是不整洁的,我们无法使用across()函数,此时我们需要进行转换:

df_long <- df_paired |>
  pivot_longer(
    everything(),
    names_to = c("group", ".value"),
    names_sep = "_"
  )
df_long
# A tibble: 40 × 3
   group     val    wts
   <chr>   <dbl>  <dbl>
 1 a      0.0489 0.212 
 2 b     -0.110  0.555 
 3 c      0.0358 0.982 
 4 d      2.39   0.299 
 5 a     -1.43   0.364 
 6 b     -0.435  0.0414
 7 c      0.0912 0.130 
 8 d     -0.427  0.179 
 9 a     -0.935  0.142 
10 b     -0.214  0.0101
# ℹ 30 more rows
df_long |>
  group_by(group) |>
  summarise(mean = weighted.mean(val, wts)) |>
  pivot_wider(
    names_from = group,
    values_from = mean,
    names_glue = "{group}_{.value}"
  )
# A tibble: 1 × 4
  a_mean b_mean c_mean d_mean
   <dbl>  <dbl>  <dbl>  <dbl>
1 -0.147 -0.164 -0.164 -0.352

5 读取多个文件

这是一个重要的功能。实际应用中,我们经常会遇到需要读取多个表格文件的情况。如果一个一个读取,需要重复写多行代码(多次复制粘贴操作),与代码书写原则相悖(不要重复书写相同的代码两次以上)。这时我们需要用到purrr中的map()系列函数。

读取多个文件有如下几个基本步骤:

5.1 列出目录中需要读取的文档-list.files()

list.files()函数有是三个主要参数:

  • 第一个参数path,为文件的路径。
  • 第二个参数pattern,为一个正则表达式,用来选择需要读取的文档名。通常我们用到的"[.]xlsx$"[.]csv$,用于查找具有特定扩展名的所有文件。
  • 第三个参数full.names,为一个逻辑值,决定了在输出的结果中是否包含完整的目录名(通常我们设置为TRUE)。
paths <- list.files(
  path = "D:/Myblog/datas/gapminder",
  pattern = "[.]xlsx$", 
  full.names = TRUE
)
paths
 [1] "D:/Myblog/datas/gapminder/1952.xlsx" "D:/Myblog/datas/gapminder/1957.xlsx"
 [3] "D:/Myblog/datas/gapminder/1962.xlsx" "D:/Myblog/datas/gapminder/1967.xlsx"
 [5] "D:/Myblog/datas/gapminder/1972.xlsx" "D:/Myblog/datas/gapminder/1977.xlsx"
 [7] "D:/Myblog/datas/gapminder/1982.xlsx" "D:/Myblog/datas/gapminder/1987.xlsx"
 [9] "D:/Myblog/datas/gapminder/1992.xlsx" "D:/Myblog/datas/gapminder/1997.xlsx"
[11] "D:/Myblog/datas/gapminder/2002.xlsx" "D:/Myblog/datas/gapminder/2007.xlsx"

5.2 map()和list_rbind()

  • 对于多个表格的读取,首先将其保存至一个列表中,然后使用list_rbind()将其连接起来是比较简便且符合逻辑的做法。这两个函数都来自purrr
files <- map(paths, readxl::read_excel)  # map()的作用是对path同时执行后面函数的操作
length(files)  # 查看列表中共有多少个数据表
[1] 12
files[[1]]  # 可以通过index访问每一个数据表
# A tibble: 142 × 5
   country     continent lifeExp      pop gdpPercap
   <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1 Afghanistan Asia         28.8  8425333      779.
 2 Albania     Europe       55.2  1282697     1601.
 3 Algeria     Africa       43.1  9279525     2449.
 4 Angola      Africa       30.0  4232095     3521.
 5 Argentina   Americas     62.5 17876956     5911.
 6 Australia   Oceania      69.1  8691212    10040.
 7 Austria     Europe       66.8  6927772     6137.
 8 Bahrain     Asia         50.9   120447     9867.
 9 Bangladesh  Asia         37.5 46886859      684.
10 Belgium     Europe       68    8730405     8343.
# ℹ 132 more rows
list_rbind(files)  # 根据行将列表中所有数据表连接
# A tibble: 1,704 × 5
   country     continent lifeExp      pop gdpPercap
   <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1 Afghanistan Asia         28.8  8425333      779.
 2 Albania     Europe       55.2  1282697     1601.
 3 Algeria     Africa       43.1  9279525     2449.
 4 Angola      Africa       30.0  4232095     3521.
 5 Argentina   Americas     62.5 17876956     5911.
 6 Australia   Oceania      69.1  8691212    10040.
 7 Austria     Europe       66.8  6927772     6137.
 8 Bahrain     Asia         50.9   120447     9867.
 9 Bangladesh  Asia         37.5 46886859      684.
10 Belgium     Europe       68    8730405     8343.
# ℹ 1,694 more rows
  • 如果我们需要在函数中传入参数,则需使用之前提到的匿名函数。
paths |> 
  # 读取列表中每个数据表的一行数据
  map(\(path) readxl::read_excel(path, n_max = 1)) |> 
  list_rbind()
# A tibble: 12 × 5
   country     continent lifeExp      pop gdpPercap
   <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1 Afghanistan Asia         28.8  8425333      779.
 2 Afghanistan Asia         30.3  9240934      821.
 3 Afghanistan Asia         32.0 10267083      853.
 4 Afghanistan Asia         34.0 11537966      836.
 5 Afghanistan Asia         36.1 13079460      740.
 6 Afghanistan Asia         38.4 14880372      786.
 7 Afghanistan Asia         39.9 12881816      978.
 8 Afghanistan Asia         40.8 13867957      852.
 9 Afghanistan Asia         41.7 16317921      649.
10 Afghanistan Asia         41.8 22227415      635.
11 Afghanistan Asia         42.1 25268405      727.
12 Afghanistan Asia         43.8 31889923      975.

仔细观察,我们发现通过以上代码输出的结果还缺少一个重要的信息,即这些数据是那个年份的(也就是从哪个源文件得来的)。

这个信息非常重要,我们在下一节梳理读取多个文件的完整步骤,并解决以上问题。

5.3 数据的路径来源

  1. 得到数据来源的向量。set_names()
files <- paths |> 
  set_names(basename)
files
                            1952.xlsx                             1957.xlsx 
"D:/Myblog/datas/gapminder/1952.xlsx" "D:/Myblog/datas/gapminder/1957.xlsx" 
                            1962.xlsx                             1967.xlsx 
"D:/Myblog/datas/gapminder/1962.xlsx" "D:/Myblog/datas/gapminder/1967.xlsx" 
                            1972.xlsx                             1977.xlsx 
"D:/Myblog/datas/gapminder/1972.xlsx" "D:/Myblog/datas/gapminder/1977.xlsx" 
                            1982.xlsx                             1987.xlsx 
"D:/Myblog/datas/gapminder/1982.xlsx" "D:/Myblog/datas/gapminder/1987.xlsx" 
                            1992.xlsx                             1997.xlsx 
"D:/Myblog/datas/gapminder/1992.xlsx" "D:/Myblog/datas/gapminder/1997.xlsx" 
                            2002.xlsx                             2007.xlsx 
"D:/Myblog/datas/gapminder/2002.xlsx" "D:/Myblog/datas/gapminder/2007.xlsx" 
  1. 读取所有数据
files <- files |> 
  map(readxl::read_excel)
files[[1]]
# A tibble: 142 × 5
   country     continent lifeExp      pop gdpPercap
   <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1 Afghanistan Asia         28.8  8425333      779.
 2 Albania     Europe       55.2  1282697     1601.
 3 Algeria     Africa       43.1  9279525     2449.
 4 Angola      Africa       30.0  4232095     3521.
 5 Argentina   Americas     62.5 17876956     5911.
 6 Australia   Oceania      69.1  8691212    10040.
 7 Austria     Europe       66.8  6927772     6137.
 8 Bahrain     Asia         50.9   120447     9867.
 9 Bangladesh  Asia         37.5 46886859      684.
10 Belgium     Europe       68    8730405     8343.
# ℹ 132 more rows
  1. 将列表中数据连接起来
paths |> 
  set_names(basename) |> 
  map(readxl::read_excel) |> 
  list_rbind(names_to = "year") |>  # 将文件名设为新列并指定列名
  mutate(year = parse_number(year))
# A tibble: 1,704 × 6
    year country     continent lifeExp      pop gdpPercap
   <dbl> <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1  1952 Afghanistan Asia         28.8  8425333      779.
 2  1952 Albania     Europe       55.2  1282697     1601.
 3  1952 Algeria     Africa       43.1  9279525     2449.
 4  1952 Angola      Africa       30.0  4232095     3521.
 5  1952 Argentina   Americas     62.5 17876956     5911.
 6  1952 Australia   Oceania      69.1  8691212    10040.
 7  1952 Austria     Europe       66.8  6927772     6137.
 8  1952 Bahrain     Asia         50.9   120447     9867.
 9  1952 Bangladesh  Asia         37.5 46886859      684.
10  1952 Belgium     Europe       68    8730405     8343.
# ℹ 1,694 more rows
  1. 在更复杂的情况下,目录名称中可能存储有其他变量,或者文件名可能包含多个数据位。在这种情况下,可以使用 set_names()(不带任何参数)记录完整路径,然后使用 tidyr::separate_wider_delim() 等函数将它们转换为有用的列。这些函数可帮助我们将文件名或目录名称中的不同数据分隔为单独的列,以便更好地进行数据分析。
paths |> 
  set_names() |> 
  map(readxl::read_excel) |> 
  list_rbind(names_to = "year") |> 
  separate_wider_delim(year, delim = "/", names = c(NA, NA, NA, "dir", "file")) |>
  separate_wider_delim(file, delim = ".", names = c("file", "ext"))
# A tibble: 1,704 × 8
   dir       file  ext   country     continent lifeExp      pop gdpPercap
   <chr>     <chr> <chr> <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1 gapminder 1952  xlsx  Afghanistan Asia         28.8  8425333      779.
 2 gapminder 1952  xlsx  Albania     Europe       55.2  1282697     1601.
 3 gapminder 1952  xlsx  Algeria     Africa       43.1  9279525     2449.
 4 gapminder 1952  xlsx  Angola      Africa       30.0  4232095     3521.
 5 gapminder 1952  xlsx  Argentina   Americas     62.5 17876956     5911.
 6 gapminder 1952  xlsx  Australia   Oceania      69.1  8691212    10040.
 7 gapminder 1952  xlsx  Austria     Europe       66.8  6927772     6137.
 8 gapminder 1952  xlsx  Bahrain     Asia         50.9   120447     9867.
 9 gapminder 1952  xlsx  Bangladesh  Asia         37.5 46886859      684.
10 gapminder 1952  xlsx  Belgium     Europe       68    8730405     8343.
# ℹ 1,694 more rows
  1. 保存数据:有时我们需要将处理完成后的数据保存在excel中。
gapminder <- paths |> 
  set_names(basename) |> 
  map(readxl::read_excel) |> 
  list_rbind(names_to = "year") |> 
  mutate(year = parse_number(year))

gapminder
# A tibble: 1,704 × 6
    year country     continent lifeExp      pop gdpPercap
   <dbl> <chr>       <chr>       <dbl>    <dbl>     <dbl>
 1  1952 Afghanistan Asia         28.8  8425333      779.
 2  1952 Albania     Europe       55.2  1282697     1601.
 3  1952 Algeria     Africa       43.1  9279525     2449.
 4  1952 Angola      Africa       30.0  4232095     3521.
 5  1952 Argentina   Americas     62.5 17876956     5911.
 6  1952 Australia   Oceania      69.1  8691212    10040.
 7  1952 Austria     Europe       66.8  6927772     6137.
 8  1952 Bahrain     Asia         50.9   120447     9867.
 9  1952 Bangladesh  Asia         37.5 46886859      684.
10  1952 Belgium     Europe       68    8730405     8343.
# ℹ 1,694 more rows
write_csv(gapminder, "gapminer.csv")

5.4 多个简单迭代连用

前文讨论的例子中,我们直接从磁盘加载了数据,并且很幸运地得到了一个整洁的数据集。

然而在大多数情况下,我们需要进行一些额外的数据整理工作。面对这种情况,我们通常有两种基本的选择:

  1. 使用一个复杂的函数进行一轮迭代;
  2. 使用简单的函数进行多轮迭代。

根据我们的经验,大多数人首先选择使用一个复杂的迭代,但通常通过使用多个简单的迭代来获得更好的结果。