DGX Spark 模型推論:FlashAttention 2 安裝實錄與效能評測

安裝背景



在安裝 MOSS-TTS 後,模型的推論速度並未造成顯著困擾。然而,由於有資料指出安裝 FlashAttention 2 能提升推論速度並減少 GPU 記憶體用量,我對實際的改善效果感到好奇。因此,我著手進行了 FlashAttention 2 的安裝。


FlashAttention 2 是什麼?其推論效率提升的原理為何?

據我了解,FlashAttention 2 是一個用於 Transformer 系列模型中,更有效率處理 Attention 運算的實作方案。我推測,它可能是透過減少 Attention 計算過程中的記憶體存取和中間張量生成,或是最佳化運算流程來提升速度和記憶體效率。不過,我認為這種效果會因模型架構、輸入長度、資料類型 (dtype,如 torch.float16 / torch.bfloat16) 以及 GPU 架構而有所差異。


在 DGX Spark 上安裝 FlashAttention 2



參考 MOSS-TTS 的 README 文件,其中有以下說明:

FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16

既然可以在 torch.float16 環境下使用,我判斷 DGX Spark 也應該能夠支援,於是便嘗試進行安裝。

flash-attn2-on-spark


1. 前置確認事項

  • DGX Spark 使用 CUDA 13.0,因此我從 --extra-index-url https://download.pytorch.org/whl/cu130 尋找並安裝了相依套件。為了再次確認,我還是用 nvidia-smi 檢查了 CUDA 版本。

  • 在安裝過程中尋找 PyTorch 時,我必須沿用現有 venv 環境中已安裝的 PyTorch。為了避免在 venv 環境中建置時產生臨時隔離環境,我加入了 --no-build-isolation 選項。

  • 在 Spark 環境中,wheel 安裝失敗,安裝日誌中顯示了以下訊息。這是因為 aarch64 架構的關係。在使用 Spark 時,這種情況屢見不鮮,我已經見怪不怪了,甚至不再感到惱火,這訊息對我來說很熟悉。

Precompiled wheel not found. Building from source...

別無選擇,只能進行原始碼建置。由於原始碼建置過程需要 ninja,因此將其安裝到 venv 中。

bash pip install ninja

  • 主機系統需要 Python 3.12 開發函式庫,如果尚未安裝,請進行安裝。
sudo apt update
sudo apt install python3.12-dev -y

flash-attn 會編譯 C++ 和 CUDA 程式碼,並將其連結到 Python。此過程中需要定義 Python 內部結構的 Python.h 文件。由於一般的 Python 執行環境不包含此文件,因此需要單獨安裝開發者套件。


2. 安裝指令

這是本文的重點。 綜合考量以上所有事項後,我使用以下指令進行安裝:

TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"

3. 指令組合的原因 (透過試錯決定的方案)

最初,機器上同時運行著 gpt-oss-120b。在這種情況下執行 pip install ... -e ".[flash-attn]" 後,CPU 使用率急劇飆升,導致系統停擺,終端機也失去響應,最終我只能透過實體開關強制重啟。

此後,我關閉了所有耗用資源的任務,專注於安裝。

經過多次嘗試後,我最終使用上述指令完成了安裝。總共耗時約 1 到 2 小時。雖然沒有精確測量,但安裝過程持續了一個多小時,我離開了一段時間,用餐回來時安裝已經完成。

安裝期間,記憶體似乎持續佔用約 24GB。真正的問題出在 CPU,為了確保安裝過程穩定,最好還是關閉其他正在運行的任務。

加入這些選項的原因如下:

  • TORCH_CUDA_ARCH_LIST="12.0":目的是明確指定 Blackwell 架構,以縮短安裝時間。
  • MAX_JOBS=1:由於之前系統曾當機,我保守地將其設定為 1。結果安裝時間超過 60 分鐘。

安裝後推論效果的提升

1. 速度方面

老實說,速度方面並沒有明顯的體感提升。因為即使沒有安裝 flash-attn,原本的速度也已經足夠快了,所以即使有幾秒鐘的縮短,也難以讓人感覺到「確實變快了」。

  • 生成約 7 秒的結果,大約需要 8~9 秒。
  • 生成約 25 秒的結果,大約需要 32 秒。
  • 生成約 16 秒的結果,大約需要 21 秒。

換句話說,推論時間大約是生成結果長度的 1.3 倍左右。

2. 記憶體方面

記憶體方面也沒有感受到變化。在推論過程中,我持續觀察 nvidia-smi 的數值,但記憶體用量並沒有顯著增加或減少。推論期間的功耗約在 36W 左右,溫度則從 46 度上升到 53 度左右。


總結

  • 在 DGX Spark 環境中,FlashAttention 2 的 wheel 安裝失敗,最終是透過原始碼建置完成安裝。
  • 雖然安裝本身成功了,但建置時間漫長,且 CPU 負載相當高。
  • 安裝後,速度和記憶體方面並未達到預期中顯著的體感改善。
  • 如果下次再次進行建置,我可能會將 MAX_JOBS 設定提高到 4 左右。理論上,建置時間應該可以縮短到原來的四分之一。

相關文章