Installationsbericht zu FlashAttention 2 für Modellinferenz auf DGX Spark

Hintergrund der Installation



Nach der Installation von MOSS-TTS gab es keine größeren Probleme mit der Inferenzgeschwindigkeit. Allerdings hörte ich, dass FlashAttention 2 die Inferenz beschleunigen und den GPU-Speicherverbrauch senken könnte. Ich war neugierig, wie groß die tatsächliche Verbesserung sein würde, und beschloss daher, FlashAttention 2 zu installieren.


Was ist FlashAttention 2 und wie erhöht es die Inferenz-Effizienz?

FlashAttention 2 verstehe ich als eine Implementierung, die darauf abzielt, Attention-Operationen in Transformer-Architekturen effizienter zu verarbeiten. Meine Vermutung ist, dass die Geschwindigkeits- und Speichereffizienz durch die Reduzierung von Speicherzugriffen und der Erzeugung von Zwischen-Tensoren während der Attention-Berechnung oder durch die Optimierung des Rechenflusses verbessert wird. Ich ging jedoch davon aus, dass dieser Effekt je nach Modellstruktur, Eingabelänge, dtype (torch.float16 / torch.bfloat16) und GPU-Architektur variieren kann.


FlashAttention 2 auf DGX Spark installieren



Die README von MOSS-TTS enthält den folgenden Hinweis:

FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16

Da es mit torch.float16 verwendet werden kann, ging ich davon aus, dass es auch auf DGX Spark funktionieren sollte, und versuchte die Installation.

flash-attn2-on-spark


1. Vorabprüfungen

  • DGX Spark verwendet CUDA 13.0, daher habe ich die Abhängigkeitspakete unter --extra-index-url https://download.pytorch.org/whl/cu130 gesucht und installiert. Zur Überprüfung habe ich die CUDA-Version zusätzlich mit nvidia-smi verifiziert.

  • Während der Installation musste ich das bereits im bestehenden venv installierte PyTorch weiterverwenden. Um zu verhindern, dass beim Bauen in der venv-Umgebung eine temporäre isolierte Umgebung erstellt wird, habe ich die Option --no-build-isolation hinzugefügt.

  • Die Wheel-Installation schlug in der Spark-Umgebung fehl, und im Installationsprotokoll wurde Folgendes ausgegeben. Dies liegt an der aarch64-Architektur. Da ich dies bei der Verwendung von Spark häufig erlebe, überrascht es mich nicht mehr. Es ärgert mich auch nicht mehr. Es ist eine vertraute Meldung.

Precompiled wheel not found. Building from source...

Es gab keine andere Wahl, als mit einem Source-Build fortzufahren. Da ninja für den Source-Build benötigt wird, installiere ich es in der venv.

bash pip install ninja

  • Auf dem Hostsystem werden Python 3.12 Entwicklungsbibliotheken benötigt. Falls Sie diese nicht haben, installieren Sie sie bitte.
sudo apt update
sudo apt install python3.12-dev -y

Flash-attn kompiliert C++- und CUDA-Code und verknüpft diesen mit Python. Dabei wird die Datei Python.h benötigt, die die interne Struktur von Python definiert. Da diese Datei in einer normalen Python-Laufzeitumgebung nicht enthalten ist, muss das Entwicklerpaket separat installiert werden.


2. Installationsbefehl

Dies ist der Kern dieses Beitrags. Unter Berücksichtigung aller oben genannten Punkte habe ich die Installation mit dem folgenden Befehl durchgeführt:

TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"

3. Begründung der Befehlskombination (Entscheidung durch Trial-and-Error)

Zuerst lief auf der Maschine auch gpt-oss-120b. Als ich in diesem Zustand pip install ... -e ".[flash-attn]" ausführte, stieg die CPU-Auslastung rapide an, das System fror ein und das Terminal reagierte nicht mehr. Ich musste einen erzwungenen Neustart über den physischen Schalter durchführen. Danach habe ich alle ressourcenintensiven Aufgaben beendet und mich ausschließlich auf die Installation konzentriert.

Nach mehreren Versuchen konnte ich die Installation mit dem oben genannten Befehl abschließen. Die gesamte Installationszeit betrug etwa 1 bis 2 Stunden. Ich habe es nicht genau gemessen, aber die Installation lief über eine Stunde, während ich nicht am Platz war, und war abgeschlossen, als ich nach dem Essen zurückkam.

Während der Installation schien der Speicher durchgehend etwa 24 GB zu belegen. Das Problem war die CPU, und es war stabiler, andere Aufgaben zu beenden, um sich auf die Installation konzentrieren zu können.

Die Optionen wurden aus folgenden Gründen hinzugefügt:

  • TORCH_CUDA_ARCH_LIST="12.0": Ziel war es, nur die Blackwell-Architektur explizit anzusprechen, um die Installationszeit zu verkürzen.
  • MAX_JOBS=1: Da das System zuvor abgestürzt war, habe ich diesen Wert konservativ auf 1 gesetzt. Dies führte zu einer Installationszeit von über 60 Minuten.

Verbesserungen der Inferenz nach der Installation

1. Geschwindigkeitsaspekt

Die Geschwindigkeit war ehrlich gesagt nicht spürbar verbessert. Da die Inferenz auch ohne Flash-attn bereits schnell genug war, war das Gefühl einer „deutlichen Beschleunigung“ nicht groß, selbst wenn eine Verkürzung um einige Sekunden stattfand.

  • Die Generierung eines 7-sekündigen Ergebnisses dauerte etwa 8-9 Sekunden.
  • Die Generierung eines 25-sekündigen Ergebnisses dauerte etwa 32 Sekunden.
  • Ein 16-sekündiges Ergebnis dauerte 21 Sekunden.

Das bedeutet, dass die Inferenzzeit ungefähr das 1,3-fache der Länge des generierten Ergebnisses betrug.

2. Speicherverbrauch

Auch beim Speicher war kaum eine Veränderung festzustellen. Ich habe die nvidia-smi-Werte während der Inferenz kontinuierlich beobachtet, aber die Speichernutzung hat sich nicht merklich erhöht oder verringert. Der Stromverbrauch lag während der Inferenz bei etwa 36W, und die Temperatur stieg von 46 Grad auf etwa 53 Grad.


Zusammenfassung

  • Die Wheel-Installation von FlashAttention 2 schlug in der DGX Spark-Umgebung fehl, daher erfolgte die Installation über einen Source-Build.
  • Die Installation selbst war erfolgreich, aber der Build dauerte lange und die CPU-Auslastung war beträchtlich.
  • Nach der Installation gab es in Bezug auf Geschwindigkeit und Speicher nicht die erwarteten spürbaren Verbesserungen.
  • Wenn ich das nächste Mal einen Build durchführe, könnte ich MAX_JOBS auf etwa 4 erhöhen. Die Build-Zeit würde sich dann theoretisch auf ein Viertel reduzieren.

Verwandter Beitrag