DGX Spark 上 FlashAttention 2 模型推理安装实录

安装背景



在安装 MOSS-TTS 后,其推理速度已足以满足日常需求。然而,听说安装 FlashAttention 2 可以显著提升推理速度并减少 GPU 内存占用,这让我非常好奇实际效果究竟如何。因此,我决定着手安装 FlashAttention 2。


FlashAttention 2 是什么?它如何提升推理效率?

据我理解,FlashAttention 2 是 Transformer 架构中用于更高效处理 Attention 运算的一种实现。我推测,它通过减少 Attention 计算过程中的内存访问和中间张量生成,或优化运算流程,从而提升速度和内存效率。不过,我判断这种效果会因模型结构、输入长度、数据类型 (torch.float16 / torch.bfloat16) 和 GPU 架构的不同而有所差异。


在 DGX Spark 上安装 FlashAttention 2



参考 MOSS-TTS 的 README 文件,其中有这样一段话:

FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16

既然可以在 torch.float16 环境中使用,我判断 DGX Spark 应该也适用,于是尝试进行安装。

flash-attn2-on-spark


1. 前期准备

  • DGX Spark 使用 CUDA 13.0,因此我从 --extra-index-url https://download.pytorch.org/whl/cu130 查找并安装了依赖包。为了确认,我还通过 nvidia-smi 检查了 CUDA 版本。

  • 安装过程中,由于 PyTorch 已经在现有的 venv 环境中安装,我需要继续使用它。在 venv 环境中进行构建时,为了避免创建临时的隔离环境,我加入了 --no-build-isolation 选项。

  • 在 Spark 环境中,wheel 包安装失败,安装日志中输出以下信息。这通常是由于 aarch64 架构造成的。在使用 Spark 时,这种情况屡见不鲜,我已经见怪不怪了,甚至不再感到恼火,这已是熟悉的提示。

    Precompiled wheel not found. Building from source...

    别无选择,只能通过源码构建。源码构建过程中需要 ninja,因此将其安装到 venv 中:

    bash pip install ninja

  • 主机系统需要 Python 3.12 的开发库,如果您的系统尚未安装,请执行以下命令:

    bash sudo apt update sudo apt install python3.12-dev -y

    flash-attn 在编译 C++ 和 CUDA 代码并将其与 Python 连接时,需要 Python.h 文件,该文件定义了 Python 的内部结构。通常的 Python 运行环境中不包含此文件,因此需要单独安装开发者包。


2. 安装命令

这是本文的核心部分。 综合考虑上述所有事项,我使用了以下命令进行安装:

TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"

3. 命令组合原因 (基于试错的决定)

最初,机器上还在运行 gpt-oss-120b。在这种情况下执行 pip install ... -e ".[flash-attn]" 后,CPU 使用率急剧飙升,系统随之卡死,终端也失去了响应,最终我不得不通过物理开关强制重启。此后,我暂停了所有占用资源的任务,专注于安装。

经过多次尝试,最终我通过上述命令成功完成了安装。整个过程大约耗时 1 到 2 小时。虽然没有精确测量,但安装持续了一个多小时,我离开去用餐,回来时发现安装已经完成。

安装期间,内存占用持续在 24GB 左右。主要问题在于 CPU,为了确保安装的稳定性,最好暂停其他任务。

加入这些选项的原因如下:

  • TORCH_CUDA_ARCH_LIST="12.0": 明确指定以 Blackwell 架构为目标,旨在缩短安装时间。
  • MAX_JOBS=1: 鉴于之前系统卡死的经验,我保守地将其设置为 1。结果导致安装时间超过 60 分钟。

安装后的推理效果提升

1. 速度方面

老实说,速度方面并没有显著的体感提升。在未安装 flash-attn 时,速度已经足够快,即使缩短了几秒钟,也未能给我“确实变快了”的强烈感受。

  • 生成约 7 秒的结果,耗时约 8-9 秒。
  • 生成约 25 秒的结果,耗时约 32 秒。
  • 生成约 16 秒的结果,耗时 21 秒。

也就是说,推理时间大约是生成结果长度的 1.3 倍左右。

2. 内存方面

内存方面也难以察觉到变化。在推理过程中,我持续观察 nvidia-smi 的数值,但内存使用量并没有明显增加或减少。推理期间,功耗维持在 36W 左右,温度从 46 度上升到 53 度左右。


总结

  • 在 DGX Spark 环境中,FlashAttention 2 的 wheel 包安装失败,最终通过源码构建完成。
  • 安装虽然成功,但构建时间较长,且 CPU 负载相当高。
  • 安装后,速度和内存方面的实际体验提升未达到预期。
  • 如果下次再次构建,我会考虑将 MAX_JOBS 提升到 4 左右。理论上,构建时间有望缩短至原先的四分之一。

相关文章