Installation de FlashAttention 2 sur DGX Spark pour l'inférence de modèles : Mon retour d'expérience
Contexte de l'installation
Après l'installation de MOSS-TTS, la vitesse d'inférence ne posait pas de problème majeur. Cependant, j'avais entendu dire que FlashAttention 2 pouvait accélérer l'inférence et réduire la consommation de mémoire GPU. Curieux de savoir l'ampleur réelle de cette amélioration, j'ai décidé de procéder à son installation.
Qu'est-ce que FlashAttention 2 et comment améliore-t-il l'efficacité d'inférence ?
Je perçois FlashAttention 2 comme une implémentation conçue pour traiter les opérations d'Attention de manière plus efficace au sein des architectures Transformer. Mon hypothèse est qu'il améliore la vitesse et l'efficacité de la mémoire en réduisant les accès mémoire et la création de tenseurs intermédiaires pendant le calcul de l'Attention, ou en optimisant le flux des opérations. Il est important de noter que cet effet peut varier considérablement en fonction de l'architecture du modèle, de la longueur de l'entrée, du dtype (torch.float16 / torch.bfloat16) et de l'architecture GPU.
Installation de FlashAttention 2 sur DGX Spark
Le fichier README de MOSS-TTS contient la mention suivante :
FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16
Étant donné qu'il est utilisable avec torch.float16, j'ai supposé qu'il pourrait fonctionner sur DGX Spark et j'ai tenté l'installation.

1. Vérifications préalables
-
Le DGX Spark utilisant CUDA 13.0, j'ai recherché et installé les paquets de dépendances via
--extra-index-url https://download.pytorch.org/whl/cu130. Pour m'en assurer, j'ai vérifié la version de CUDA avecnvidia-smi. -
Lors de l'installation, il a été nécessaire d'utiliser la version de PyTorch déjà présente dans l'environnement venv existant. J'ai ajouté l'option
--no-build-isolationpour éviter la création d'un environnement d'isolation temporaire lors de la compilation dans le venv. -
L'installation via wheel a échoué dans l'environnement Spark, et le journal d'installation affichait le message suivant. C'est dû à l'architecture aarch64. C'est un scénario fréquent avec Spark, donc cela ne me surprend plus, ni ne m'irrite. C'est un message familier.
Precompiled wheel not found. Building from source...
Il n'y avait pas d'autre choix que de procéder à une compilation à partir des sources. Pour cela, j'ai installé ninja dans le venv, car il est nécessaire pour le processus de compilation.
bash
pip install ninja
- Les bibliothèques de développement Python 3.12 sont requises sur le système hôte ; si vous ne les avez pas, installez-les.
sudo apt update
sudo apt install python3.12-dev -y
Flash-attn compile du code C++ et CUDA pour le lier à Python, ce qui nécessite le fichier Python.h qui définit la structure interne de Python. Cet en-tête n'est généralement pas inclus dans un environnement d'exécution Python standard, d'où la nécessité d'installer un paquet de développement séparé.
2. Commande d'installation
C'est le point central de cet article. En tenant compte de toutes les considérations précédentes, j'ai procédé à l'installation avec la commande suivante :
TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"
3. Justification des options de commande (décision par essais et erreurs)
Initialement, le modèle gpt-oss-120b tournait également sur la machine. En exécutant pip install ... -e ".[flash-attn]" dans cet état, l'utilisation du CPU a grimpé en flèche, le système s'est figé et le terminal est devenu inutilisable, m'obligeant à un redémarrage forcé via l'interrupteur physique. Par la suite, j'ai arrêté toutes les tâches gourmandes en ressources pour me concentrer uniquement sur l'installation.
Après plusieurs tentatives et ajustements, j'ai finalement réussi l'installation avec la commande mentionnée ci-dessus. Le temps total d'installation a été d'environ 1 à 2 heures. Bien que je n'aie pas mesuré précisément, l'installation a duré plus d'une heure ; je suis parti, et à mon retour après le repas, elle était terminée.
Pendant l'installation, la consommation de mémoire s'est maintenue autour de 24 Go. Le problème principal était le CPU, et il était plus stable de suspendre les autres tâches pour se concentrer sur l'installation.
Les options ont été ajoutées pour les raisons suivantes :
TORCH_CUDA_ARCH_LIST="12.0": L'objectif était de cibler explicitement l'architecture Blackwell afin de réduire le temps d'installation.MAX_JOBS=1: Suite à l'expérience précédente où le système s'était bloqué, j'ai opté pour une approche conservatrice en réglant cette valeur à 1. En conséquence, l'installation a pris plus de 60 minutes.
Effets de l'amélioration de l'inférence après installation
1. En termes de vitesse
Honnêtement, la vitesse n'a pas été significativement améliorée. Le système était déjà assez rapide sans FlashAttention, donc même avec un gain de quelques secondes, la sensation d'une accélération "nette" n'était pas flagrante.
- La génération d'un résultat d'environ 7 secondes prenait 8 à 9 secondes.
- La génération d'un résultat de 25 secondes prenait environ 32 secondes.
- Un résultat de 16 secondes prenait 21 secondes.
En d'autres termes, le temps d'inférence semblait être environ 1,3 fois la durée du résultat généré.
2. En termes de mémoire
En ce qui concerne la mémoire, il a également été difficile de percevoir un changement. J'ai surveillé en permanence les valeurs de nvidia-smi pendant l'inférence, mais l'utilisation de la mémoire n'a pas augmenté ni diminué de manière notable. Pendant l'inférence, la consommation électrique était d'environ 36W, et la température augmentait de 46°C à environ 53°C.
Conclusion
- Dans l'environnement DGX Spark, l'installation de FlashAttention 2 via wheel ayant échoué, j'ai procédé à une compilation à partir des sources.
- L'installation a été un succès, mais le temps de compilation a été long et la charge CPU considérable.
- Après l'installation, les améliorations perceptibles en termes de vitesse et de mémoire n'ont pas été aussi significatives que prévu.
- Si je devais refaire la compilation, j'augmenterais probablement
MAX_JOBSà 4. Le temps de compilation devrait théoriquement être réduit d'environ un quart.
Articles connexes