Cancel

微观基准测试

Language

·

October 10, 2024

前言

这里讨论的是微观基准测试(micro-benchmarking)的实现,关心的主要问题是:

在一次测试里到底要重复运行几次目标函数?

我们希望在尽可能少的运行次数下,得到一个相对稳定、准确的测量统计。

前车之辙

不妨来看一下几个语言内置的微观测试框架做了什么。

Python timeit

这个模块就是一个连同命令行接口和文档注解都不超过 400 代码的单文件。

从语法上看,就有将近 10 年的未被修改过1,而从 API 的设计理念上2,核心逻辑可能写在上个世纪。

它的命令行接口让用户指定作为时间观测次数的大轮次次数 repeat 和作为每次观测时运行的小轮次次数 number ,分别提供了 default_repeat = 5 和default_number = 1000_000 的默认值。

这种指定一个固定轮次的做法非常原始、根本就没有意义。

唯一值得一提的是 class Timer 里面提供的 autorange 方法,这个方法也仍然简单:

从 $1$ 开始,累积交替地以 $\times 2,\ \times 5$ 的次数运行测试目标,直到总运行时间超过 $0.2$ s 。

另外在测试开始运行的时候 opt-in 地禁止 GC。3

# In class Timer

def autorange(self, callback=None):
    """Return the number of loops and time taken so that total time >= 0.2.

	Calls the timeit method with increasing numbers from the sequence
    1, 2, 5, 10, 20, 50, ... until the time taken is at least 0.2
    second.  Returns (number, time_taken).

    If *callback* is given and is not None, it will be called after
    each trial with two arguments: ``callback(number, time_taken)``.
    """
    i = 1
    while True:
        for j in 1, 2, 5:
            number = i * j
            time_taken = self.timeit(number)
            if callback:
                callback(number, time_taken)
                if time_taken >= 0.2:
                    return (number, time_taken)
                i *= 10

Rust test

相比之下,Rust 内置的 test crate 实现的 micro-benchmarking 逻辑就体面得多。

首先确定初始的每轮运行次数 n4:

  1. 运行一次目标,得到测量时间 ns_single ns ,如果 $\lt 1$ 就补足为 $1$ ;5
  2. n = 1000_000 / ns_single 得到 $1$ ms 内的运行次数,如果 $\lt 1$ 就补足为 $1$ 。

然后逐步统计学取样,直到稳定地得到相对稳定的6观测统计或者超时7:

以 $50$ 为一轮的取样数,8

每一轮的迭代的运行次数仍然遵循交替 $\times 2,\ \times 5$ 的规律,

每次取样完后,通过 stats::winsorize(samples, 5.0) 去掉前 5% 和后 5%的取样结果9 ,

这样如果相邻两轮的观测统计 summ 和 summ5 满足以下条件:

  1. 前一轮 summ 的绝对偏差的中位数10 占这个中位数本身的百分比 $\lt 1\%$ ;
  2. 前一轮 summ 和这一轮 summ5 的两个中位数的差小于 summ5 的绝对偏差的中位数

则可以返回最新一轮 summ5 的统计结果,

否则检测总运行时间是否超时,超时也返回统计结果,

否则进行下一轮11。

pub fn iter<T, F>(inner: &mut F) -> stats::Summary
where
    F: FnMut() -> T,
{
    // Initial bench run to get ballpark figure.
    let ns_single = ns_iter_inner(inner, 1);

    // Try to estimate iter count for 1ms falling back to 1m
    // iterations if first run took < 1ns.
    let ns_target_total = 1_000_000; // 1ms
    let mut n: u64 = ns_target_total / cmp::max(1, ns_single);

    // if the first run took more than 1ms we don't want to just
    // be left doing 0 iterations on every loop. The unfortunate
    // side effect of not being able to do as many runs is
    // automatically handled by the statistical analysis below
    // (i.e., larger error bars).
    n = cmp::max(1, n);

    let mut total_run = Duration::new(0, 0);
    let samples: &mut [f64] = &mut [0.0_f64; 50];
    loop {
        let loop_start = Instant::now();

        for p in &mut *samples {
            *p = ns_iter_inner(inner, n) as f64 / n as f64;
        }

        stats::winsorize(samples, 5.0);
        let summ = stats::Summary::new(samples);

        for p in &mut *samples {
            let ns = ns_iter_inner(inner, 5 * n);
            *p = ns as f64 / (5 * n) as f64;
        }

        stats::winsorize(samples, 5.0);
        let summ5 = stats::Summary::new(samples);

        let loop_run = loop_start.elapsed();

        // If we've run for 100ms and seem to have converged to a
        // stable median.
        if loop_run > Duration::from_millis(100)
            && summ.median_abs_dev_pct < 1.0
            && summ.median - summ5.median < summ5.median_abs_dev
        {
            return summ5;
        }

        total_run += loop_run;
        // Longest we ever run for is 3s.
        if total_run > Duration::from_secs(3) {
            return summ5;
        }

        // If we overflow here just return the results so far. We check a
        // multiplier of 10 because we're about to multiply by 2 and the
        // next iteration of the loop will also multiply by 5 (to calculate
        // the summ5 result)
        n = match n.checked_mul(10) {
            Some(_) => n * 2,
            None => {
                return summ5;
            }
        };
    }
}

重新发明 [Python]

发明动机

一谈起为 Python Code 做基准测试,有人洋洋洒洒能列出很多种方法,什么使用 time 模块,使用 timeit 模块,使用 profile 或 cProfile ,甚至使用 Unix-like 系统上的 time 命令,但是真正能用的还是 timeit 。

可正如前面所讲,“你什么 timeit 都在做基准测试,它能做吗,做不了,没那个能力知道吧”。

又找不到一个主流的、大小合适的、做微观基准测试的 Python 第三方库,因此有必要为 Python 写一个可用的微观基准测试框架,这个动机也是本文的起源。

实操

直接参照 Rust test 里的实现做一个 Python 的版本12。

好消息是 Python 内置的 statistics 模块把 Rust test 里面花两百行代码实现的统计方法13都实现了,这为我们省去了很多功夫。

这里只列举核心的方法,完整框架参考项目所在仓库 。

数据结构

首先定义一个测试方法应该返回的基本数据结构 CaseStats :

class CaseStats(NamedTuple):
    """case -> module -> benchmark -(archiving)> records"""

    fname: str
    ave: int
    iters: int
    # diviation in nanos
    dev: int
    raw_samples: list[Number]

一个被测试的目标函数函数应该是 type BenchCase = Callable[[], None] ,

这样测试方法的类型就是 Callable[[BenchCase], CaseStats] 。

辅助方法

首先是需要的一系列统计方法:

from numbers import Number
from statistics import quantiles, mean, median, StatisticsError


class SampleStats:
    def __init__(self, samples: list[Number], n: int):
        self.samples = samples
        self.n = n

    @property
    @cache
    def max(self) -> Number:
        return max(self.samples)

    @property
    @cache
    def min(self) -> Number:
        return min(self.samples)

    @property
    @cache
    def median(self) -> Number:
        return median(self.samples)

    @property
    @cache
    def median_abs_dev(self) -> Number:
        return median_abs_dev(self.samples)

    @property
    @cache
    def median_abs_dev_pct(self) -> Number:
        """ 0.01 stand for 1% """

        return self.median_abs_dev / self.median

def winsoring(samples: list[Number], n: int, e: int):
    """ [Winsoring](https://en.wikipedia.org/wiki/Winsorizing) high e/n and low e/n

    in-place
    """

    if not 2 * e < n:
        raise StatisticsError('must satisfy 2 * e < n')

    # percentiles
    pcts = quantiles(samples, n=n)

    lo = pcts[e - 1]
    hi = pcts[n - e - 1]

    for i in range(len(samples)):
        if samples[i] > hi:
            samples[i] = hi

        elif samples[i] < lo:
            samples[i] = lo


def median_abs_dev(samples: list[Number]) -> Number:
    x0 = median(samples)

    return median(list(map(lambda x: abs(x - x0), samples)))

一个便于观测时间的方法,利用了 Python 上文管理器的 with 语法:

class Watch:
    def __init__(self, fmt=".2f") -> None:
        """
        :fmt: format float number of (s/ms/us)
        """
        self.fmt = fmt

    def __enter__(self):
        self.start = time.perf_counter_ns()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.perf_counter_ns()
        self._nanos = self.end - self.start

    @property
    def nanos(self) -> int:
        return self._nanos

    @property
    def micros(self) -> float:
        return self._nanos / 1000

    @property
    def millis(self) -> float:
        return self._nanos / 1000_000

    @property
    def secs(self) -> float:
        return self._nanos / 1000_000_000

然后我们可以写性能统计取样的方法:

SAMPLE_SIZE: int = 50
WINSORING_PERCENTILE: dict[str, int] = {'n': 100, 'e': 5}

def sampling(f: BenchCase, n: int) -> SampleStats:
    samples = [0] * SAMPLE_SIZE

    for i in range(SAMPLE_SIZE):
        with Watch() as w:
            for _ in range(n):
                f()

        samples[i] = w.nanos // n

    winsoring(samples, **WINSORING_PERCENTILE)

    return SampleStats(samples, n)

主过程

def run_a_benchmark(f: BenchCase) -> SampleStats:
    with Watch() as w:
        f()

    tot = w.nanos

    if w.nanos == 0:
        raise RuntimeError("It's just impossible for CPython")

    n = int(max(1000_000 / w.nanos, 1))

    while True:
        with Watch() as w:
            summ = sampling(f, n=n)
            summ5 = sampling(f, n=5 * n)

        if (w.millis > 200
            and summ.median_abs_dev_pct < 1.0
            and summ.median - summ5.median < summ5.median_abs_dev
           ):

            return summ5

        tot += w.nanos

        if tot > 3_000_000_000:  # 3 seconds
            return summ5

        n *= 10  # x2 x5

注解

  1. PEP484 Was accepted in Python 3.5 which was released in September 2015 ↩

  2. 居然是传一个代码片段的字符串,和一个上下文环境的字典,非常地不正规,完全是 Python 2 那个时代的流行时尚 ↩

  3. 禁止 GC 可以排除由于 GC 带来的观测指标的波动,但长时间不 GC 会降低本身就很有限的性能。基本上这种操作仍然没有太大意义,外部的干扰因素有很多,应该通过统计学办法来排除干扰。 ↩

  4. 相当于 timeit 里的 number ↩

  5. 这是 CPython 永远不用担心的问题。 ↩

  6. 这不是中文语法错误,“稳定地”指相邻迭代的两轮观测统计的“平均值”相近,“稳定的”指一轮观测统计内部的“方差”小。 ↩

  7. 3 s ↩

  8. 相当于 timeit 里 repeat=50 ↩

  9. 通过分别把 $\lt 5\%$ 和 $\gt 95\%$ 的样本值修改为 $=5\%$ 和 $= 95\%$ 的样本值 ↩

  10. 绝对偏差的中位数,Median absolute deviation ,$\texttt{MAD} = \texttt{median}(\vert X_i-\texttt{median}(X) \vert)$,方法分类上属于 Average absolute deviation ,是一种被认为比样本方差和标准差更有鲁棒性的估算标量。 ↩

  11. 这里还细心检查了 loop 代码的下一次执行,也就是 n * 2 * 5 后是否会溢出,如果会溢出也提前返回。 ↩

  12. 假设环境是 Python 3.12 ↩

  13. Rust 的问题,或者说这是个组织结构问题,在于我们知道它的源代码结构非常地迭床架屋,里面充斥着形形色色重新发明的轮子和仅限内部使用的范用工具,从 DRY 的角度看这是双重的折磨,一方面轮子不应该重新发明,而应该使用一个范用库;另一方面,发明的范用工具,应该开放为范用库。而这二百多行代码就属于前者–重新造的轮子。 ↩

© minghu6

·

theme

Simplex theme logo

by golas