Ditto TalkingHead(DGX Spark / ARM64)TensorRT 移植工作记录

目的

为了在 DGX Spark(ARM64 / aarch64)平台上运行 ditto-talkinghead 项目, 我们将已有的 ONNX 模型转换为对应环境的 TensorRT 引擎并完成推理。 其中 warp_network.onnx 依赖自定义插件 GridSample3D, 因此 解决 TensorRT 自定义插件(.so)加载问题 是关键。

image of the porting process on DGX-Spark


为什么需要这项工作

Ditto 检查点中包含 libgrid_sample_3d_plugin.so,但该文件是 x86‑64 编译的二进制。

我的环境是:

  • DGX Spark
  • ARM64 (aarch64)
  • TensorRT 10.14.1
  • CUDA 13.1

因此 ARM TensorRT 无法加载 x86 插件 .so,导致在解析 warp_network.onnx 时失败,TRT 引擎无法生成。


症状(关键错误日志)

在 TRT 转换 warp_network 时出现如下错误:

  • Unable to load library: libgrid_sample_3d_plugin.so
  • Plugin not found ... GridSample3D
  • Fail parsing warp_network.onnx
  • 最终报 Network must have at least one output

根本原因是 GridSample3D 插件加载失败。


原因分析过程

1)插件文件是否存在

文件确实在:

  • ./checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so

所以不是路径问题。


2)ldd 检查

运行 ldd ./checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so 得到:

  • not a dynamic executable

这提示可能是 架构不匹配


3)使用 file / readelf 确认架构

关键证据:

  • file 输出:ELF 64-bit LSB shared object, x86-64
  • readelf -h 输出:Machine: Advanced Micro Devices X86-64

因此插件是 x86‑64 二进制,无法在 ARM64 上加载。


解决方案(核心)

结论

需要 在 ARM64 环境下重新编译 GridSample3D TensorRT 插件。


实施步骤

1)获取插件源码

ditto-talkinghead 仓库只提供二进制,没有源码。 我们在另一个仓库中找到了 grid-sample3d-trt-plugin 的源码:

  • grid_sample_3d_plugin.cpp
  • grid_sample_3d_plugin.h
  • grid_sample_3d.cu
  • grid_sample_3d.cuh

2)确认 TensorRT / CUDA 环境

  • TensorRT Python: 10.14.1.48
  • TensorRT 库: /usr/lib/aarch64-linux-gnu/libnvinfer.so
  • CUDA: 13.1
  • GPU capability: (12, 1) → CUDA 架构 121

3)修改 CMake 构建配置

原有 CMake 针对 x86 并硬编码了旧的 GPU 架构,存在以下问题:

  • 强制 compute_70 → 在当前 nvcc 中不受支持
  • cuda_fp16.h 包含路径错误
  • 缺少 TensorRT 库目录声明

关键修改

  • 移除 CUDA_ARCHITECTURES 硬编码
  • 明确指定 TensorRT include/lib 路径
  • 添加 CUDA include 路径(/usr/local/cuda/targets/sbsa-linux/include
  • 可选关闭测试子目录编译

4)成功编译 ARM64 插件

cd /workspace/grid-sample3d-trt-plugin
rm -rf build
mkdir build && cd build
cmake .. \
  -DCMAKE_BUILD_TYPE=Release \
  -DTensorRT_ROOT=/usr \
  -DTensorRT_INCLUDE_DIR=/usr/include/aarch64-linux-gnu \
  -DTensorRT_LIB_DIR=/usr/lib/aarch64-linux-gnu \
  -DCMAKE_CUDA_ARCHITECTURES=121

cmake --build . -j"$(nproc)"

编译完成后生成的文件:

  • build/libgrid_sample_3d_plugin.so

使用 file 再次确认是 ARM64:

file build/libgrid_sample_3d_plugin.so

预期输出类似:ELF 64-bit LSB shared object, ARM aarch64, ...


5)替换 Ditto 检查点中的 x86 插件

cp /workspace/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so \
   /workspace/ditto-talkinghead/checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so

6)验证 TensorRT 插件加载

在 Python 中直接使用 TRT 插件注册表加载,确认插件能够被成功注册并使用。


7)重新执行 ONNX → TensorRT 转换

再次运行 cvt_onnx_to_trt.pywarp_network.onnx 以及其他模型全部转换成功,推理也顺利完成。


当前状态

GridSample3D 自定义 TensorRT 插件在 ARM64 上正常工作 ✅ warp_network.onnx 成功解析 ✅ ONNX → TensorRT 引擎转换成功 ✅ Ditto TalkingHead 在 DGX Spark / ARM64 环境下推理成功

即完成了对 ditto-talkinghead 的 DGX Spark 移植。


故障排查笔记

1)检查点中的 .so 与平台绑定

检查点里的二进制并非跨平台通用,尤其在 ARM 环境经常出现 x86 二进制。遇到类似错误时,先确认机器架构是否为 aarch64。


2)TensorRT 版本升级可能需要重新编译插件

本次在 TRT 10.14 上编译成功,但若升级到新主版本,插件的 API/ABI 可能不兼容,需要重新构建。


3)注意 CUDA 架构硬编码

CMakeLists.txt 中若写死 70;80;86;89 等旧架构,在新 GPU 上会编译失败。 使用 torch.cuda.get_device_capability() 获取实际 capability, 然后把 major*10+minor 作为 CMAKE_CUDA_ARCHITECTURES(如 (12,1)121)。


环境概述

  • 平台:DGX Spark
  • 架构:aarch64 (ARM64)
  • CUDA:13.1
  • TensorRT:10.14.1
  • Python:3.12
  • GPU capability:12.1(对应 CMake CUDA arch = 121)

相关链接