本文介绍了蓄水池抽样算法[Reservoir Sampling]

问题

给定一个数据流,数据流长度$N$很大,且$N$直到处理完所有数据之前都不可知,请问如何在只遍历一遍数据$(O(N))$的情况下,能够随机选取出$m$个不重复的数据。

这个问题强调了3件事:

  • 数据流长度$N$很大且不可预知,所以无法一次性存入内存,这也直接导致不能直接取$N$内的$m$个随机数,然后按索引取出数据
  • 时间复杂度不能超过$O(N)$,所以不能先遍历一遍,然后分块存储数据,再随机选取
  • 随机选取$m$个数,每个数被选中的概率为$\frac{m}{N}$,这要求数据选取绝对随机

蓄水池抽样算法

int[] reservoir = new int[m];

for (int i = 0; i < reservoir.length; i++)
{
    // 这里使用已知长度的数组dataStream来表示未知长度的数据流
    // 并假设数据流长度大于蓄水池容量m
    reservoir[i] = dataStream[i];
}

for (int i = m; i < dataStream.length; i++)
{
    // 随机获得一个[0, i]内的随机整数
    int d = rand.nextInt(i + 1);
    // 如果随机整数落在[0, m-1]范围内,则替换蓄水池中的元素
    if (d < m)
    {
        reservoir[d] = dataStream[i];
    }
}

算法思路大致如下

  1. 如果接收的数据量小于$m$,则依次放入蓄水池
  2. 当接收到第$i$个数据时,如果$i >= m$,在$[0, i]$范围内取以随机数$d$,若$d$的落在$[0, m-1]$范围内,则用接收到的第$i$个数据替换蓄水池中的第$d$个数据
  3. 重复步骤2

当处理完所有的数据时,蓄水池中的每个数据都是以$\frac{m}{N}$的概率获得的。

下面用白话文推导验证该算法,假设数据开始编号为1。

第$i$个数据最后能够留在蓄水池中的概率=第$i$个数据进入过蓄水池的概率$\times$之后第i个数据不被替换的概率[第i+1到第N次处理数据都不会被替换]。

  1. 当$i<=$m时,数据直接放进蓄水池,所以第$i$个数据进入过蓄水池的概率为$1$
  2. 当$i>m$时,在$[1,i]$内选取随机数$d$,如果$d \leq m$,则使用第$i$个数据替换蓄水池中第$d$个数据,因此第$i$个数据进入过蓄水池的概率为$\frac{m}{i}$
  3. 当$i \leq m$时,程序从接收到第$m+1$个数据时开始执行替换操作,第$m+1$次处理会替换池中数据的为$\frac{m}{m+1}$,会替换掉第$i$个数据的概率为$\frac{1}{m}$,则第$m+1$次处理替换掉第$i$个数据的概率为$\frac{m}{m+1} \times \frac{1}{m} = \frac{1}{m+1}$,不被替换的概率为$1- \frac{1}{m+1}= \frac{m}{m+1}$。依次,第$m+2$次处理不替换掉第$i$个数据概率为$\frac{m+1}{m+2}$...第$N$次处理不替换掉第$i$个数据的概率为$\frac{N-1}{N}$。所以,之后第$i$个数据不被替换的概率=$\frac{m}{m+1} \times \frac{m+1}{m+2} \times \dots \times \frac{N-1}{N}=\frac{m}{N}$
  4. 当$i>m$时,程序从接收到第$i+1$个数据时开始有可能替换第$i$个数据。则参考上述第3点,之后第$i$个数据不被替换的概率=$\frac{i}{N}$
  5. 结合第1点和第3点可知,当$i \leq m$时,第$i$个接收到的数据最后留在蓄水池中的概率=$1 \times \frac{m}{N} = \frac{m}{N}$。结合第2点和第4点可知,当$i>m$时,第$i$个接收到的数据留在蓄水池中的概率=$\frac{m}{i} \times \frac{i}{N} = \frac{m}{N}$。综上可知,每个数据最后被选中留在蓄水池中的概率为$\frac{m}{N}$。

这个算法建立在统计学基础上,很巧妙地获得了$\frac{m}{N}$这个概率。

分布式蓄水池抽样

如果遇到超大的数据量,即使是$O(N)$的时间复杂度,蓄水池抽样程序完成抽样任务也将耗时很久。因此分布式的蓄水池抽样算法应运而生。运作原理如下:

  1. 假设有$K$台机器,将大数据集分成$K$个数据流,每台机器使用单机版蓄水池抽样处理一个数据流,抽样$m$个数据,并最后记录处理的数据量为$N_1, N_2, \dots, N_k, \dots, N_K$(假设$m < N_k$)。$N_1 + N_2 + \dots + N_K = N$。
  2. 取$[1, N]$一个随机数$d$,若$d < N_1$,则在第一台机器的蓄水池中等概率不放回地$\frac{1}{m}$选取一个数据,若$N_1 \leq d < (N_1+N_2)$,则在第二台机器的蓄水池中等概率不放回地选取一个数据;一次类推,重复$m$次,则最终从$N$大数据集中选出$m$个数据。

$\frac{m}{N}$的概率验证如下:

  1. 第$k$台机器中的蓄水池数据被选取的概率为$\frac{m}{N_k}$
  2. 从第$k$台机器的蓄水池中选取一个数据放进最终蓄水池的概率为$\frac{N_k}{N}$
  3. 第$k$台机器蓄水池的一个数据被选中的概率为$\frac{1}{m}$[不放回选取时等概率的]
  4. 重复$m$次选取,则每个数据被选中的概率为$m \times (\frac{m}{N_k} \times \frac{N_k}{N} \times \frac{1}{m} = \frac{m}{N})$

本文部分内容参考https://www.jianshu.com/p/7a9ea6ece2af