Installatie-ervaring met FlashAttention 2 voor modelinferentie op DGX Spark

Achtergrond van de installatie

Na de installatie van MOSS-TTS ondervonden we geen grote problemen met de inferentiesnelheid. Echter, aangezien FlashAttention 2 de inferentiesnelheid zou kunnen verhogen en het GPU-geheugengebruik zou kunnen verminderen, waren we benieuwd naar de daadwerkelijke verbetering. Daarom zijn we overgegaan tot de installatie van FlashAttention 2.


Wat is FlashAttention 2 en hoe verhoogt het de inferentie-efficiëntie?

FlashAttention 2 wordt gezien als een implementatie die Attention-bewerkingen in Transformer-architecturen efficiënter verwerkt. Mijn vermoeden is dat de snelheid en geheugenefficiëntie verbeteren door geheugentoegang en de aanmaak van tussentijdse tensors tijdens de Attention-berekening te verminderen, of door de operationele stroom te optimaliseren. Ik realiseerde me echter dat dit effect kan variëren afhankelijk van de modelstructuur, invoerlengte, dtype (torch.float16 / torch.bfloat16) en GPU-architectuur.


FlashAttention 2 installeren op DGX Spark

De README van MOSS-TTS vermeldt het volgende:

FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16

Als het beschikbaar is voor torch.float16, ging ik ervan uit dat het ook op DGX Spark gebruikt kon worden en heb ik geprobeerd het te installeren.

flash-attn2-on-spark


1. Voorafgaande controles

  • Aangezien DGX Spark CUDA 13.0 gebruikt, heb ik de afhankelijkheidspakketten geïnstalleerd via --extra-index-url https://download.pytorch.org/whl/cu130. Voor de zekerheid heb ik de CUDA-versie gecontroleerd met nvidia-smi.

  • Tijdens de installatie moest ik de bestaande PyTorch-installatie in de venv-omgeving gebruiken. Om te voorkomen dat er tijdens het bouwen in de venv-omgeving een tijdelijke geïsoleerde omgeving werd aangemaakt, heb ik de optie --no-build-isolation toegevoegd.

  • De installatie van het wheel-bestand mislukte in de Spark-omgeving, en de installatielog toonde het volgende. Dit komt door de aarch64-architectuur. Het is iets wat ik vaak tegenkom bij het gebruik van Spark, dus het verbaast me niet meer. Ik erger me er ook niet meer aan. Het is een bekende boodschap.

Precompiled wheel not found. Building from source...

Er was geen andere keuze dan door te gaan met een broncode-build. Omdat dit nodig is tijdens het broncode-bouwproces, installeer ik ninja in de venv.

bash pip install ninja

  • De Python 3.12 ontwikkelbibliotheken zijn vereist op het hostsyteem; installeer deze als u ze nog niet heeft.
sudo apt update
sudo apt install python3.12-dev -y

Flash-attn compileert C++- en CUDA-code en koppelt deze aan Python. Hiervoor is het Python.h-bestand nodig, dat de interne structuur van Python definieert. Dit bestand is meestal niet inbegrepen in een standaard Python-runtimeomgeving, dus een apart ontwikkelaarspakket moet worden geïnstalleerd.


2. Installatiecommando

Dit is de kern van deze post. Met alle bovenstaande overwegingen heb ik de installatie uitgevoerd met het volgende commando:

TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"

3. Reden voor de commandocombinatie (besluit genomen door vallen en opstaan)

Aanvankelijk draaide gpt-oss-120b ook op de machine. Toen ik pip install ... -e ".[flash-attn]" uitvoerde, schoot het CPU-gebruik omhoog, het systeem liep vast en de terminal reageerde niet meer, waardoor ik een geforceerde herstart via de fysieke schakelaar moest uitvoeren. Daarna heb ik alle resource-intensieve taken gestopt en me volledig gericht op de installatie.

Na verschillende pogingen en fouten heb ik de installatie met het bovenstaande commando voltooid. De totale tijd bedroeg ongeveer 1 tot 2 uur. Hoewel niet precies gemeten, was de installatie langer dan een uur bezig toen ik wegging, en was deze voltooid toen ik na het eten terugkwam.

Tijdens de installatie leek het geheugen constant ongeveer 24 GB te verbruiken. Het probleem lag bij de CPU, en het was stabieler om andere taken te stoppen zodat de installatie ongehinderd kon doorgaan.

De redenen voor het gebruik van de opties zijn als volgt:

  • TORCH_CUDA_ARCH_LIST="12.0": Het doel was om de installatietijd te verkorten door expliciet de Blackwell-architectuur te targeten.
  • MAX_JOBS=1: Aangezien het systeem eerder was vastgelopen, heb ik dit conservatief op 1 ingesteld. Als gevolg hiervan duurde de installatie meer dan 60 minuten.

Verbetering van inferentie na installatie

1. Snelheid

Eerlijk gezegd merkte ik geen significante snelheidsverbetering. Zelfs zonder FlashAttention was het al snel genoeg, dus zelfs als er een paar seconden tijdwinst was, voelde het niet als een 'duidelijke versnelling'.

  • Het genereren van een resultaat van ongeveer 7 seconden duurde 8-9 seconden.
  • Het genereren van een resultaat van 25 seconden duurde ongeveer 32 seconden.
  • Een resultaat van 16 seconden duurde 21 seconden.

Dit betekent dat de inferentietijd ongeveer 1,3 keer de lengte van het gegenereerde resultaat bedroeg.

2. Geheugen

Ook op het gebied van geheugen was er weinig verandering merkbaar. Hoewel ik de nvidia-smi-waarden tijdens de inferentie continu in de gaten hield, nam het geheugengebruik niet significant toe of af. Tijdens de inferentie lag het stroomverbruik rond de 36W en steeg de temperatuur van 46 naar ongeveer 53 graden Celsius.


Samenvatting

  • In de DGX Spark-omgeving mislukte de wheel-installatie van FlashAttention 2, dus moest het vanuit de bron worden gebouwd.
  • De installatie zelf was succesvol, maar het bouwproces duurde lang en er was een aanzienlijke CPU-belasting.
  • Na de installatie was de waargenomen verbetering in snelheid en geheugen niet zo groot als verwacht.
  • Als ik de volgende keer opnieuw moet bouwen, zou ik MAX_JOBS=4 kunnen overwegen. Dit zou de bouwtijd theoretisch met een kwart kunnen verkorten.

Gerelateerde post