Mi Experiencia Instalando FlashAttention 2 para Inferencia de Modelos en DGX Spark
Contexto de la Instalación
Después de instalar MOSS-TTS, no experimenté mayores inconvenientes con la velocidad de inferencia. Sin embargo, al enterarme de que FlashAttention 2 podría acelerar la inferencia y reducir el uso de memoria de la GPU, sentí curiosidad por ver cuánto mejoraría realmente. Por ello, decidí proceder con su instalación.
¿Qué es FlashAttention 2 y cuál es el principio detrás de su eficiencia en inferencia?
Entiendo que FlashAttention 2 es una implementación diseñada para procesar las operaciones de atención de manera más eficiente en arquitecturas Transformer. Mi hipótesis es que mejora la velocidad y la eficiencia de la memoria al reducir el acceso a la memoria y la creación de tensores intermedios durante el cálculo de la atención, o al optimizar el flujo de las operaciones. Sin embargo, consideré que este efecto podría variar según la estructura del modelo, la longitud de la entrada, el tipo de dato (torch.float16 / torch.bfloat16) y la arquitectura de la GPU.
Instalando FlashAttention 2 en DGX Spark
El README de MOSS-TTS incluye la siguiente nota:
FlashAttention 2 is only available on supported GPUs and is typically used with torch.float16 or torch.bfloat16
Asumiendo que si se podía usar con torch.float16, también sería compatible con DGX Spark, decidí intentar la instalación.

1. Verificaciones previas
-
Dado que DGX Spark utiliza CUDA 13.0, busqué e instalé los paquetes de dependencia desde
--extra-index-url https://download.pytorch.org/whl/cu130. Aun así, verifiqué la versión de CUDA connvidia-smi. -
Durante la instalación, al buscar PyTorch, tuve que usar la versión ya instalada en mi entorno
venv. Para evitar la creación de un entorno de aislamiento temporal durante la construcción envenv, incluí la opción--no-build-isolation. -
La instalación del wheel falló en el entorno Spark, y el registro de instalación mostró lo siguiente. Esto se debe a la arquitectura aarch64. Es algo que me sucede a menudo al usar Spark, así que ya no me sorprende ni me molesta. Es un mensaje familiar.
Precompiled wheel not found. Building from source...
No tuve más remedio que proceder con la compilación desde el código fuente. Como es necesario para el proceso de compilación, instalé ninja en el venv.
bash
pip install ninja
- Si no lo tienen, necesitarán las bibliotecas de desarrollo de Python 3.12 en el sistema host. Instálenlas:
sudo apt update
sudo apt install python3.12-dev -y
flash-attn compila código C++ y CUDA para vincularlo con Python, y para ello se necesita el archivo Python.h, donde se define la estructura interna de Python. Este archivo no suele estar incluido en los entornos de ejecución de Python estándar, por lo que es necesario instalar un paquete para desarrolladores por separado.
2. Comando de instalación
Este es el punto clave de esta publicación. Teniendo en cuenta todos los aspectos anteriores, la instalación se realizó con el siguiente comando:
TORCH_CUDA_ARCH_LIST="12.0" MAX_JOBS=1 pip install --no-build-isolation --extra-index-url https://download.pytorch.org/whl/cu130 -e ".[flash-attn]"
3. Razones de la combinación de comandos (decisión por ensayo y error)
Al principio, tenía gpt-oss-120b también en ejecución en la máquina. Al ejecutar pip install ... -e ".[flash-attn]" en esas condiciones, el uso de la CPU se disparó abruptamente, el sistema se bloqueó y la terminal dejó de responder, lo que me obligó a reiniciar forzadamente con el interruptor físico. Después de eso, detuve todas las tareas que consumían recursos y me concentré únicamente en la instalación.
Tras varios intentos y errores, logré completar la instalación con el comando anterior. El tiempo total estimado fue de 1 a 2 horas. Aunque no lo medí con precisión, la instalación duró más de una hora, así que me ausenté y, al regresar después de comer, ya había terminado.
Durante la instalación, la memoria se mantuvo ocupada en unos 24GB aproximadamente. El problema principal fue la CPU, y fue más estable dejar de lado otras tareas para que la instalación pudiera concentrarse en ella.
Las razones para incluir las opciones fueron las siguientes:
TORCH_CUDA_ARCH_LIST="12.0": El objetivo era acortar el tiempo de instalación al apuntar explícitamente solo a la arquitectura Blackwell.MAX_JOBS=1: Debido a la experiencia previa de bloqueo del sistema, lo configuré de forma conservadora a 1. Como resultado, la instalación tardó más de 60 minutos.
Efectos de mejora de la inferencia tras la instalación
1. En cuanto a la velocidad
Sinceramente, la velocidad no se sintió significativamente diferente. Incluso sin FlashAttention, ya era bastante rápido, así que, aunque hubiera una reducción de unos pocos segundos, la sensación de que “definitivamente era más rápido” no fue muy marcada.
- Generar un resultado de unos 7 segundos tomó entre 8 y 9 segundos.
- Generar un resultado de 25 segundos tomó unos 32 segundos.
- Un resultado de 16 segundos tomó 21 segundos.
Es decir, la inferencia tomó aproximadamente 1.3 veces la duración del resultado generado.
2. En cuanto a la memoria
Respecto a la memoria, tampoco noté grandes cambios. Aunque monitoreé constantemente los valores de nvidia-smi durante la inferencia, el uso de memoria no aumentó ni disminuyó de manera perceptible. Durante la inferencia, el consumo de energía rondaba los 36W, y la temperatura subía de 46 a 53 grados Celsius.
Resumen
- En el entorno DGX Spark, la instalación del wheel de FlashAttention 2 falló, por lo que tuve que instalarlo desde el código fuente.
- La instalación fue exitosa, pero el tiempo de compilación fue largo y la carga de la CPU, considerable.
- Después de la instalación, la mejora percibida en términos de velocidad y memoria no fue tan significativa como esperaba.
- Si tuviera que volver a compilarlo, probablemente aumentaría
MAX_JOBSa 4. En teoría, esto debería reducir el tiempo de compilación a una cuarta parte.
Artículos relacionados
No hay comentarios.