Ditto TalkingHead (DGX Spark / ARM64) の TensorRT 移植作業記録

目的



ditto-talkinghead プロジェクトを DGX Spark (ARM64 / aarch64) 環境で動作させるため、既存の ONNX モデルを現在の環境で推論可能なTensorRTエンジンに変換する作業を行いました。

特に warp_network.onnxGridSample3D カスタムプラグインに依存しているため、単純変換ではなく TensorRT カスタムプラグイン (.so) のロード問題の解決 が鍵となった。

image of the porting process on DGX-Spark


なぜこの作業が必要だったか

Ditto のチェックポイントには libgrid_sample_3d_plugin.so が含まれていたが、これは x86‑64 用にビルドされたバイナリ だった。

私の環境は:

  • DGX Spark
  • ARM64 (aarch64)
  • TensorRT 10.14.1
  • CUDA 13.1

つまり、x86用のプラグインである.soファイルをARM版TensorRTでロードできず、warp_network.onnxのパース時に失敗し、TRTエンジンの生成に至りませんでした。


症状(失敗ログの要点)



TRT 変換中に warp_network で次のエラーが発生しました:

  • Unable to load library: libgrid_sample_3d_plugin.so
  • Plugin not found ... GridSample3D
  • Fail parsing warp_network.onnx
  • 最終的に Network must have at least one output

根本原因は GridSample3D プラグインのロード失敗でした。


原因分析の流れ

1) プラグインファイルの有無確認

ファイルは実際に存在しました。

  • ./checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so

したがって単なるパス問題ではありませんでした。


2) ldd の結果

ldd ./checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so の出力:

  • not a dynamic executable

当初は不自然に思えましたが、追加の解析により アーキテクチャ不一致の可能性 が示唆されました。


3) file / readelf でアーキテクチャを確定

決定的証拠:

  • file の結果: ELF 64-bit LSB shared object, x86-64
  • readelf -h の結果: Machine: Advanced Micro Devices X86-64

つまり、Ditto が提供した .sox86‑64 用バイナリ で、現在の ARM64 環境ではロードできませんでした。


解決策(要点)

結論

GridSample3D TensorRT プラグインを ARM64 環境で再ビルド する必要がありました。


解決手順

1) プラグインソースの入手

ditto-talkinghead リポジトリにはプラグインソースがなく、バイナリのみが提供されていました。

別リポジリから GridSample3D TensorRT プラグインのソースを取得しました:

  • grid-sample3d-trt-plugin

ソースファイル:

  • grid_sample_3d_plugin.cpp
  • grid_sample_3d_plugin.h
  • grid_sample_3d.cu
  • grid_sample_3d.cuh

2) TensorRT / CUDA 環境の確認

確認結果は次のとおりです:

  • TensorRT Python: 10.14.1.48
  • TensorRT ライブラリ: /usr/lib/aarch64-linux-gnu/libnvinfer.so
  • CUDA: 13.1
  • GPU Capability: (12, 1) → CUDA arch 121

3) CMake ビルド設定の修正

既存のCMake設定はx86向けの設定が中心で、古いGPUアーキテクチャがハードコードされていたため、以下の問題がありました。

  • compute_70 が強制 → 現在の nvcc では未対応
  • cuda_fp16.h のインクルードパス問題
  • TensorRT lib ディレクトリ指定不足

修正ポイント

  • CUDA_ARCHITECTURES のハードコードを除去
  • TensorRT の include / lib パスを明示
  • CUDA include パスを指定 (/usr/local/cuda/targets/sbsa-linux/include)
  • テストサブディレクトリのビルドを無効化(任意)

4) ARM64 用プラグインのビルド成功

cd /workspace/grid-sample3d-trt-plugin
rm -rf build
mkdir build && cd build
cmake .. \
  -DCMAKE_BUILD_TYPE=Release \
  -DTensorRT_ROOT=/usr \
  -DTensorRT_INCLUDE_DIR=/usr/include/aarch64-linux-gnu \
  -DTensorRT_LIB_DIR=/usr/lib/aarch64-linux-gnu \
  -DCMAKE_CUDA_ARCHITECTURES=121

cmake --build . -j"$(nproc)"

ビルド後に生成されたファイル:

  • build/libgrid_sample_3d_plugin.so

このファイルが ARM64 かどうかを確認:

file build/libgrid_sample_3d_plugin.so

期待される出力例:

  • ELF 64-bit LSB shared object, ARM aarch64, …

5) Ditto チェックポイントの x86 プラグインと置換

cp /workspace/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so \
   /workspace/ditto-talkinghead/checkpoints/ditto_onnx/libgrid_sample_3d_plugin.so

6) TensorRT プラグインロードの確認

Python から TRT プラグインレジストリに直接ロードし、正常に動作することをテストしました。


7) ONNX → TensorRT 変換の再実行

cvt_onnx_to_trt.py を実行すると、warp_network.onnx を含むすべての変換が成功しました。

最終的に 推論も成功 しました。


現在の最終状態

GridSample3D カスタム TensorRT プラグインが ARM64 で正常動作 ✅ warp_network.onnx のパースに成功 ✅ ONNX → TensorRT エンジン変換に成功 ✅ Ditto TalkingHead の推論が DGX Spark / ARM64 環境で成功

この作業により ditto-talkinghead を DGX Spark 環境向けに完全に移植できました。


トラブルシューティングメモ

1) チェックポイントに含まれる .so はプラットフォーム依存

チェックポイントに含まれているからといって、必ずしも自身の環境でそのまま動作するとは限りません。特にARM環境では x86 用バイナリ が混在していることが多いため、DGX‑SparkやmacOSユーザーは頻繁に遭遇する問題です。初めてこのエラーに直面した場合は、大いに戸惑うかもしれません。自身の環境がaarch64 (ARM)である可能性を、まず疑ってみるべきでしょう。


2) TensorRT バージョンアップ時はプラグインの再ビルドが必要になることがある

今回はTRT 10.14でビルドに成功しましたが、メジャーバージョンが変わると、プラグインAPI/ABIの変更により再ビルドが必要になる場合があります。


3) CUDA アーキテクチャのハードコーディングに注意

CMakeLists.txt70;80;86;89 などが書かれていると最新 GPU で失敗する可能性があります。現在の GPU Capability を確認し、major*10 + minor の形式で CMAKE_CUDA_ARCHITECTURES を指定してください。

python - <<'PY'
import torch
print(torch.cuda.get_device_capability())
PY

参考(自環境サマリ)

  • Platform: DGX Spark
  • Arch: aarch64 (ARM64)
  • CUDA: 13.1
  • TensorRT: 10.14.1
  • Python: 3.12
  • GPU capability: 12.1 (CMake CUDA arch = 121)

関連記事