RK3399PRO NPU Manual 提供了入门指导,里面的 RKNN toolkit 包含了样例和转换工具。我的 unet 版本 是参考 github 头像匹配的 pytorch 版本。
一、出师不利
样例提供了 torchvision 上 resnet_18 的
python 版本,能够成功转成 rknn 格式,代码如下: import numpy as np import cv2 from rknn.api import RKNN import torchvision.models as models import torch if __name__ == '__main__': net = models.resnet_18(pretrained=True) net.eval() trace_model = torch.jit.trace(net, torch.Tensor(1, 3, 224, 224)) trace_model.save('./resnet_18.pt') model = './resnet_18.pt' input_size_list = [[3,224,224]] # Create RKNN object rknn = RKNN() # pre-process config print('--> config model') rknn.config(channel_mean_value='123.675 116.28 103.53 58.395', reorder_channel='0 1 2') print('done') # Load pytorch model
print('--> Loading model') ret = rknn.load_pytorch(model=model, input_size_list=input_size_list) if ret != 0: print('Load pytorch model failed!') exit(ret) print('done') # Build model print('--> Building model') ret = rknn.build(do_quantization=False, dataset='./dataset.txt') if ret != 0: print('Build pytorch failed!') exit(ret) print('done') # Export rknn model print('--> Export RKNN model') ret = rknn.export_rknn('./resnet_18.rknn') if ret != 0: print('Export resnet_18.rknn failed!') exit(ret) print('done')
这里有个关键函数,torch.jit.trace,JIT 表示 Just In Time Compilation,即时编译。它是 Python 和 C++的 桥梁,我们可以使用 python 训练模型,然后通过 JIT 将模型转为与语言无关的静态图,供 C++调用,能非 常方便得部署到
树莓派、IOS、
Android 等设备。
静态图大概长这样:
我一开始先想部署自己训练的 VGG 模型,网络尽量接近 torchvision_models 的样例,定义如下: import torch.nn as nn import torch class VGG(nn.Module): def __init__(self): super(VGG, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, 1, 1,bias=False)
self.relu = nn.ReLU(inplace=True)
self.conv2_1 = nn.Conv2d(32, 64, 3, 1, 1,bias=False)
self.max_pool2d = nn.MaxPool2d(
kernel_size=2,stride=2)
self.conv2_2 = nn.Conv2d(64, 64, 3, 1, 1,bias=False)
self.conv3_1 = nn.Conv2d(64, 128, 3, 1, 1,bias=False)
self.conv3_2 = nn.Conv2d(128, 128, 3, 1, 1,bias=False)
self.conv3_3 = nn.Conv2d(128, 128, 3, 1, 1,bias=False)
self.fc1 = nn.Linear(2048, 128,bias=False)
self.fc2 = nn.Linear(128, 10,bias=False) def forward(self, x):
......
更多详细内容请下载附件查看