开发者
鲲鹏数学库FFT与 NumPy的兼容性适配及实践总结

鲲鹏数学库FFT与 NumPy的兼容性适配及实践总结

HPC

发表于 2026/05/12

0

1. 引言

NumPy 是 Python 科学计算生态的基石,其 numpy.fft 模块提供了高效的离散傅里叶变换接口。在 NumPy 1.22.4 版本中,默认的底层实现为 PocketFFT——一个轻量级、纯 C99 编写的 FFT 库,具有良好的跨平台可移植性。然而,在鲲鹏处理器平台上,PocketFFT 无法充分利用架构特有的 SIMD 指令集,导致大规模数据处理场景下未能发挥全部硬件潜能。

鲲鹏数学库(Kunpeng Math Library,简称 KML)是华为针对鲲鹏处理器深度优化的高性能数学函数库,其中的 KML_FFT 组件专门对一维及多维复/实 FFT 进行了 SVE/NEON 向量化实现,并融合了缓存友好算法与自适应线程调度,能够显著提升计算吞吐量。本文档旨在介绍如何将 NumPy(v1.22.4)的离散傅里叶变换(FFT)后端从默认的 PocketFFT 迁移至鲲鹏数学库 ​KML_FFT​,以充分发挥鲲鹏处理器的 SVE/NEON 指令集优势,提升大规模数据处理的计算效能。

2. Numpy FFT架构

所有面向用户的函数(fftifftrfftirffthfft 等)均定义在 numpy/fft/_pocketfft.py 中。该层负责处理参数标准化,包括:

  • 轴归一化​:将用户指定的 axis 转换为内部使用的正整数索引,并支持负数轴。
  • 长度调整​:若 n 大于输入数组在目标轴的长度,自动进行零填充(zero-padding);若 n 小于原长度,则执行截断。
  • 规范化​:根据 norm 参数计算前向/反向变换的缩放因子,统一在变换后施加。
  • 数据类型路由​:根据输入数组是否为单精度或复数,调度到不同的内部执行路径。

numpy.fft.fft 为例,其核心调用逻辑为:

@array_function_dispatch(_fft_dispatcher)
def fft(a, n=None, axis=-1, norm=None):
    """
    Compute the one-dimensional discrete Fourier Transform.

    This function computes the one-dimensional *n*-point discrete Fourier
    Transform (DFT) with the efficient Fast Fourier Transform (FFT)
    algorithm [CT].

    Parameters

    """
    a = asarray(a)
    if n is None:
        n = a.shape[axis]
    inv_norm = _get_forward_norm(n, norm) # 获取反向归一化因子
    output = _raw_fft(a, n, axis, False, True, inv_norm)
    return output

2.2. 核心调度层

_raw_fft 函数,同样位于_pocketfft.py中,它完成:

  • 调用 normalize_axis_index 确保 axis 合法。
  • 对输入数组进行切片或填充,使目标轴长度恰好为 n。
  • 将目标轴交换至最后一维(swapaxes),保证 C 层总是对最内维进行连续一维 FFT。
  • 根据数据类型调用对应的 C 扩展函数。
  • 单精度浮点/复数:pocketfft_internal 或 pocketfftf_internal(经函数指针 pfif.execute / pfi.execute)。
  • 双精度浮点/复数:同理。将结果轴交换回原始位置。

该层巧妙地屏蔽了多维变换的复杂性,使得 C 扩展模块仅需处理批量的一维连续序列。

# `inv_norm` is a float by which the result of the transform needs to be
# divided. This replaces the original, more intuitive 'fct` parameter to avoid
# divisions by zero (or alternatively additional checks) in the case of
# zero-length axes during its computation.
def _raw_fft(a, n, axis, is_real, is_forward, inv_norm):
    axis = normalize_axis_index(axis, a.ndim)
    if n is None:
        n = a.shape[axis]inv_n

    fct = 1/inv_norm

    if a.shape[axis] != n:
        s = list(a.shape)
        index = [slice(None)]*len(s)
        if s[axis] > n:
            index[axis] = slice(0, n)
            a = a[tuple(index)]
        else:
            index[axis] = slice(0, s[axis])
            s[axis] = n
            z = zeros(s, a.dtype.char)
            z[tuple(index)] = a
            a = z

    if a.dtype == float32 or a.dtype == complex64:
        if axis == a.ndim-1:
            r = pfif.execute(a, is_real, is_forward, fct)
        else:
            a = swapaxes(a, axis, -1)
            r = pfif.execute(a, is_real, is_forward, fct)
            r = swapaxes(r, axis, -1)
    else:
        if axis == a.ndim-1:
            r = pfi.execute(a, is_real, is_forward, fct)
        else:
            a = swapaxes(a, axis, -1)
            r = pfi.execute(a, is_real, is_forward, fct)
            r = swapaxes(r, axis, -1)

    return r

2.3. C扩展桥阶层

通过 Cython 生成的 _pocketfft_internal_pocketfftf_internal 扩展模块提供 Python 对象到 C 数组的转换。其中关键的入口函数为:

  • execute_complex(data, is_forward, fct):处理复数域的 FFT。
  • execute_real(data, is_forward, fct):处理实数域的 RFFT。

这些函数负责将传入的 NumPy 数组强制转换为连续的 C 复合类型数组(实部/虚部交错存储),并遍历其中的批量序列(nrepeats 为批量大小),对每个一维序列执行 FFT。适配的核心工作正是在这些底层执行函数中引入 KML_FFT 调用分支。

3. KML_FFT接入

3.1. 设计原则

  1. 接口透明​:不改变任何 Python 级别的 API 签名、参数语义和返回结果格式。
  2. 条件式替换​:通过编译宏 HAVE_HUAWEI_KML 控制后端选择,同一份源码可在标准 x86 环境与鲲鹏环境下分别使用 PocketFFT 或 KML_FFT,避免代码分叉。
  3. 性能优先​:对大批量变换优先使用 KML_FFT,充分利用其向量化和并行能力;对小批量变换仍保留 PocketFFT,避免计划(plan)创建开销。
  4. 线程安全与全局状态​:KML_FFT 内部管理线程池,需确保其与 Python GIL(全局解释器锁)的正确协调。

3.2 核心适配代码剖析

当前修改主要集中在 numpy/fft/_pocketfft.cexecute_complex 函数。以下为完整适配代码:

static PyObject *
execute_complex(PyObject *a1, int is_forward, double fct)
{
    // 将输入强制转换为 C 连续的复数双精度数组
    PyArrayObject *data = (PyArrayObject *)PyArray_FromAny(a1,
            PyArray_DescrFromType(NPY_CDOUBLE), 1, 0,
            NPY_ARRAY_ENSURECOPY | NPY_ARRAY_DEFAULT |
            NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST,
            NULL);
    if (!data) return NULL;
    int npts = PyArray_DIM(data, PyArray_NDIM(data) - 1);
    int nrepeats = PyArray_SIZE(data)/npts;
    double *dptr = (double *)PyArray_DATA(data);
    int fail=0;
    // 小批量场景:沿用 PocketFFT,避免 KML plan 开销    
    if(nrepeats < 256){
      cfft_plan plan=NULL;
      Py_BEGIN_ALLOW_THREADS;
      plan = make_cfft_plan(npts);
      if (!plan) fail=1;
      if (!fail)
        for (int i = 0; i < nrepeats; i++) {
            int res = is_forward ?
              cfft_forward(plan, dptr, fct) : cfft_backward(plan, dptr, fct);
            if (res!=0) { fail=1; break; }
            dptr += npts*2;
        }
      if (plan) destroy_cfft_plan(plan);
      Py_END_ALLOW_THREADS;
    } else { // 大批量场景:启用 KML_FFT 加速路径 

#ifndef HAVE_HUAWEI_KML
    // 创建 KML_FFT plan        
    cfft_plan plan=NULL;
    Py_BEGIN_ALLOW_THREADS;
    plan = make_cfft_plan(npts);

    if (!plan) fail=1;
    if (!fail)
      for (int i = 0; i < nrepeats; i++) {
          int res = is_forward ?
            cfft_forward(plan, dptr, fct) : cfft_backward(plan, dptr, fct);
          if (res!=0) { fail=1; break; }
          dptr += npts*2;
      }
    if (plan) destroy_cfft_plan(plan);
    Py_END_ALLOW_THREADS;

#else

    kml_fft_plan plan; 
    plan = is_forward?
        kml_fft_plan_dft_1d(npts, dptr , dptr, KML_FFT_FORWARD, KML_FFT_ESTIMATE) :
        kml_fft_plan_dft_1d(npts, dptr , dptr, KML_FFT_BACKWARD, KML_FFT_ESTIMATE);

    for (int i = 0; i < nrepeats; i++) {
      kml_fft_execute_dft(plan, dptr, dptr); 
      // ifft normalization
      if(is_forward == 0){
        int i;
        for (i = 0; i < npts*2 - 3; i+=4) {
            dptr[i] /= npts;
            dptr[i+1] /= npts;
            dptr[i+2] /= npts;
            dptr[i+3] /= npts;
        }
        for(;i < npts*2; ++i) {
            dptr[i] /= npts;
        }
      }      
      dptr += npts*2;
    }
    kml_fft_destroy_plan(plan); 
#endif // HAVE_HUAWEI_KML
    }
    if (fail) {
        Py_XDECREF(data);
        return PyErr_NoMemory();
    }
    return (PyObject *)data;
}

代码说明:

  1. 内存布局:代码使用了 NPY_ARRAY_ENSURECOPY 和 PyArray_DATA。 KML 要求输入数据在内存中是连续的。通过 PyArray_FromAny 强制转换为 NPY_CDOUBLE(复数双精度),确保了数据符合 kml_fft_plan_dft_1d 的接口要求。在 kml_fft_execute_dft(plan, dptr, dptr) 中,输入和输出指针相同,节省了内存带宽。
  2. 逆变换归一化 (Normalization):FFT 算法在执行逆变换(IFFT)时,标准定义通常需要除以信号长度 N。代码中通过手动循环对 dptr 进行缩放:dptr[i] /= npts。采用了 循环展开一次处理 4 个元素,有助于编译器更好地进行指令级并行优化。
  3. 条件编译:使用 #ifndef HAVE_HUAWEI_KML 保证了代码的兼容性。如果编译环境中没有安装或未指定使用 KML,代码会自动回退到通用的 cfft 实现,不会导致编译失败。

NumPy FFT 后端迁移至 KML_FFT的代码分析说明至此结束,具体的构建使用方法请参见鲲鹏社区HPCKit迁移指南:Numpy使用KML

本页内容