コンテンツにスキップ

PyTorch の推論を高速化する

PyTorch の推論を高速化する方法について調べてみた。

torch モデルに対しての処理

  • torch.autocast()の使用
  • 畳み込みモデルに Channels-Last メモリ形式を使用
  • 畳み込みモデルに cuDNN ベンチマークを使用
  • torch.inference_mode()の使用
  • PyTorch2 系のtorch.compile()の使用

これらを組み合わせると 3~4 倍程度の高速化が期待できる。link

compile 時に backend を指定でき、hidetを使っただとレイテンシが 1/5 になっている。

torch モデルの変換

torch で生成したモデルを TensorRT や ONNX、OpenVINO、Deepspeed に変換することで高速化する方法もある。変換する前と後のモデルで推論のパフォーマンスを計測した記事がある。

参考