DGX SparkへのFlashAttention 2導入体験記:モデル推論の高速化を目指して

インストール背景



MOSS-TTSのインストール後、推論速度に大きな不満はありませんでした。しかし、FlashAttention 2を導入すれば推論速度が向上し、GPUメモリ使用量も削減される可能性があるとの情報があり、実際にどの程度改善されるのか興味が湧きました。そこで、FlashAttention 2のインストールを進めることにしました。


FlashAttention 2とは何か、推論効率が向上する原理は?

FlashAttention 2は、Transformer系列におけるAttention演算をより効率的に処理するための実装であると理解しています。私の推測では、Attention計算プロセスにおけるメモリアクセスや中間テンソルの生成を削減したり、演算フローを最適化したりすることで、速度とメモリ効率が改善される仕組みだと考えました。ただし、この効果はモデル構造、入力長、dtype(torch.float16 / torch.bfloat16)、GPUアーキテクチャによって差が出ると判断しました。


DGX SparkにFlashAttention 2をインストールする



MOSS-TTSのREADMEを参照すると、以下の記述があります。

FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16

torch.float16で利用可能であれば、DGX Sparkでも使用できると判断し、インストールを試みました。

flash-attn2-on-spark


1. 事前確認事項

  • DGX SparkはCUDA 13.0を使用しているため、--extra-index-url https://download.pytorch.org/whl/cu130から依存パッケージを検索してインストールしました。念のため、nvidia-smiでCUDAバージョンを確認しました。

  • インストール中にPyTorchを検索する際、既存のvenv環境にすでにインストールされているPyTorchをそのまま使用する必要がありました。venv環境でのビルド時に一時的な分離環境が生成されないよう、--no-build-isolationオプションを付加しました。

  • Spark環境ではwheelインストールに失敗し、インストールログに以下が出力されました。これはaarch64アーキテクチャが原因です。Sparkを使っているとよくあることなので、もはや驚きも苛立ちもなく、慣れたメッセージです。

Precompiled wheel not found. Building from source...

選択の余地なく、ソースビルドで進める必要があります。ソースビルドの過程で必要となるため、ninjaをvenvにインストールします。

bash pip install ninja

  • ホストシステムにはPython 3.12開発用ライブラリが必要です。お持ちでない方はインストールしてください。
sudo apt update
sudo apt install python3.12-dev -y

flash-attnはC++とCUDAコードをコンパイルしてPythonと連携させる過程を経ますが、この際にPythonの内部構造が定義されたPython.hファイルが必要となります。一般的なPython実行環境にはこのファイルは含まれていないため、開発者用パッケージを別途インストールする必要があります。


2. インストールコマンド

本稿の核心です。 上記の事項をすべて考慮し、次のコマンドでインストールしました。

TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"

3. コマンド組み合わせの理由(試行錯誤による決定)

当初、マシンではgpt-oss-120bも稼働している状態でした。その状態でpip install ... -e ".[flash-attn]"を実行したところ、CPU使用量が急激に高騰し、システムが停止。ターミナルがフリーズしたため、物理スイッチで強制再起動する羽目になりました。その後は、リソースを使用するタスクをすべて停止し、インストールに専念しました。

何度かの試行錯誤の末、上記のコマンドでインストールを完了しました。総所要時間は1~2時間程度でした。正確には計測していませんが、1時間以上のインストールが続いたため席を外し、食事後に戻ってきたときには完了していました。

インストール中、メモリは継続的に約24GB程度を占有しているようでした。問題はCPUで、インストールに集中できるよう、他の作業は停止しておく方が安定していました。

オプションを付加した理由は以下の通りです。

  • TORCH_CUDA_ARCH_LIST="12.0": Blackwellアーキテクチャのみを明示的にターゲットとすることで、インストール時間の短縮を図る目的でした。
  • MAX_JOBS=1: 直前にシステムが停止した経験があったため、保守的に1に設定しました。その結果、インストールには60分以上を要しました。

インストール後の推論の向上効果

1. 速度面

速度については、正直なところ大きな体感はありませんでした。FlashAttention未導入時でも十分に高速だったため、数秒の短縮があったとしても、体感として「確実に速くなった」という感覚はあまりありませんでした。

  • 約7秒の生成結果に対して8~9秒ほどかかりました。
  • 約25秒の生成結果に対して32秒ほどかかりました。
  • 約16秒の生成結果は21秒ほどかかりました。

つまり、生成された結果物の長さに対して約1.3倍程度の推論時間がかかるという感触でした。

2. メモリ面

メモリについても、変化を感じ取るのは困難でした。推論中にnvidia-smiの数値を継続的に観察しましたが、メモリ使用量が目立って増減することはありませんでした。推論中の電力は36W前後で、温度は46度から53度程度に上昇するレベルでした。


まとめ

  • DGX Spark環境でのFlashAttention 2はwheelインストールに失敗したため、ソースビルドでインストールしました。
  • インストール自体は成功しましたが、ビルドに時間がかかり、CPU負荷が相当高かったです。
  • インストール後、速度とメモリの面で期待したほどの体感的な改善は大きくありませんでした。
  • 次回もしビルドをやり直す機会があれば、MAX_JOBS=4程度に上げても良いかもしれません。ビルド時間が理論上1/4程度に短縮される可能性があります。

関連記事