コンテンツにスキップ

TensorRTへの変換

YOUTUBE

optimize_model

interactive_regressionで保存した学習済みモデルをTensorRTのフォーマットに変換します。

2つめのcellが、学習済みモデルの読み込みになるので、保存した名前に合わせて書き換えます。

1
model.load_state_dict(torch.load('model_08-06.pth'))

例えばmodel_11-13.pthという名前で保存した場合は、下記のように書き換えます。

1
model.load_state_dict(torch.load('model_11-13.pth'))

4つめのcellが、TensorRTフォーマットで出力するTensorRTフォーマットの学習済みモデルになっていますので、必要に応じて書き直します。

1
torch.save(model_trt.state_dict(), 'model_08-06_trt.pth')