Skip to content

TensorRT化

train_model修正箇所

RESNETの読み込み。

1
2
3
4
device = torch.device('cuda')
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 2)
model = model.cuda().eval().half()

すでに学習済みのモデルに対して、学習が終わった後でTensorRTに変換する。

新規で追加

1
2
3
from troch2trt import troch2trt 
data = torch.randn((1, 3, 224, 224)).cuda().half()
model_trt = torch2trt(model, [data], fp16_mode=True)

変換が正しく行われたか、変換前と変換後の比較をおこなう。

1
2
3
4
5
6
output_trt = model_trt(data)
output = model(data)

print(output.flatten()[0:10])
print(output_trt.flatten()[0:10])
print('max error: %f' % float(torch.max(torch.abs(output - output_trt))))

tensor([ 0.6694, 2.5645, 2.9473, 2.6445, 3.9434, 3.5391, 3.5820, 0.3401, -1.0713, -0.3494], device='cuda:0', dtype=torch.float16, grad_fn=) tensor([ 0.6763, 2.5742, 2.9551, 2.6562, 3.9531, 3.5449, 3.5840, 0.3359, -1.0742, -0.3525], device='cuda:0', dtype=torch.float16) max error: 0.020508

最後に、TensorRTのモデルとして保存する。

1
torch.save(model_trt.state_dict(), 'road_following_model_xy_trt.pth')

live_demo 修正箇所

TensorRTでモデルを読み込む。

1
2
3
4
5
6
7
8
9
import torchvision
import torch
from torch2trt import TRTModule 

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('road_following_model_xy_trt.pth'))
device = torch.device('cuda')
model = model_trt.to(device)
model = model_trt.eval()