扫二维码与项目经理沟通
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流
这篇文章主要为大家展示了“Pytorch转ONNX中tracing机制有什么用”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Pytorch转ONNX中tracing机制有什么用”这篇文章吧。
为调兵山等地区用户提供了全套网页设计制作服务,及调兵山网站建设行业解决方案。主营业务为成都网站设计、网站制作、调兵山网站设计,以传统方式定制建设网站,并提供域名空间备案等一条龙服务,秉承以专业、用心的态度为用户提供真诚的服务。我们深信只要达到每一位用户的要求,就会得到认可,从而选择与我们长期合作。这样,我们也可以走得更远!
(1)tracing的机制
上文提到过,Pytorch转ONNX的方式是基于tracing(追踪),通俗来说,就是ONNX的相关代码在一旁看着Pytorch跑一遍,运行了什么内容就把什么记录下来。但是在这里并不是所有Python的运行内容都会被记录。举个例子,下面的代码中,
c = torch.matmul(a, b)
print("Blabla")
e = torch.matmul(c, d)
其中只有第1,3行相关的内容会被记录,因为只有他们是和Pytorch相关的,而第二行只是普通的python语句。
具体来说,只有ATen操作会被记录下来。ATen可以被理解为一个Pytorch的基本操作库,一切的Pytorch函数都是基于这些零部件构造出来的(比如ATen就是加减乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根据加减乘除构造出来)
*之前说的ONNX无法记录if语句的问题也是因为if并不是Aten中的操作
虽然ONNX可以记录所有Pytorch的执行(即记录所有ATen操作),但是在输出的时候会做一个剪枝,把没用的操作剪掉
举个例子,下面的程序,显而易见第一句话是没有用的。
t1 = torch.matmul(a, b)
t2 = torch.matmul(c, d)
return t2
ONNX会在得到全部的操作以及他们之间的输入输出关系后(以DAG作为表示),根据DAG的输出往前推,做遍历,所有可以被遍历到的节点被保留,其他节点直接扔掉。
在MMDetection(https://github.com/open-mmlab/mmdetection)中,在NMS(non-Maximumnon maximum suppression)中有如下代码:
if bboxes.numel() == 0:
bboxes = multibboxes.newzeros((0, 5))
labels = multibboxes.newzeros((0, ), dtype=torch.long)
if torch.onnx.isinonnxexport():
raise RuntimeError('[ONNX Error] Can not record NMS '
'as it has not been executed this time')
return bboxes, labels
dets, keep = batchednms(bboxes, scores, labels, nmscfg)
代码逻辑很简单,如果之前的网络根本没有输出任何合法的bbox(第一行的分支判断),那么显然nms的结果就是一堆0,所以没必要运行nms直接返回0就可以。
如果我们想将这段代码转换到ONNX,之前我们提到过ONNX不能处理分支逻辑,因此只能选择一条路去走,记录那条路转换得到的模型。很显然,正常情况下我们自然期待会有较多的bbox,并且将这些bbox作为参数调用nms。
所以如果我们发现模型执行的路径触发了if分支,我们必须要进行一个判断,看看是不是在转ONNX,如果是的话我们就需要直接报错,因为显然转出来的ONNX不是我们想要的。
假设什么都不做,在这种情况下我们转出来的模型是什么样呢?思考一下不难发现,假设函数的返回值就是网络的最终输出,那么我们只会得到一个2个节点的DAG,即第2,3行的两个操作。之前说过ONNX拿到所有的DAG之后会做剪枝,在这里ONNX拿到返回值(bboxes, labels)做回溯,发现最头上就是第2,3行的两个操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都会被扔掉。
因此,在进行MMDet模型的转换的时候,必须用真实的数据和训练好的参数来做转换,否则基本不会得到有效的bbox,于是就会触发第6行的error
(2)利用tracing机制做优化
在MMSeg中有一个很巧妙的利用tracing机制做优化的例子。
在slide inference时,我们需要计算一个count mat矩阵,这个矩阵在h, w以及对应的stride都固定的情况下会是一个常量。
不过在训练时,往往这些都是我们要调的参数,所有MMSeg没有选择把这些常数保存下来,而是每次都算一遍
countmat = img.newzeros((batchsize, 1, himg, wimg))
for hidx in range(hgrids):
for widx in range(wgrids):
y1 = hidx * hstride
x1 = widx * wstride
y2 = min(y1 + hcrop, himg)
x2 = min(x1 + wcrop, wimg)
y1 = max(y2 - hcrop, 0)
x1 = max(x2 - wcrop, 0)
cropimg = img[:, :, y1:y2, x1:x2]
cropseglogit = self.encodedecode(cropimg, imgmeta)
preds += F.pad(cropseglogit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
countmat[:, :, y1:y2, x1:x2] += 1
assert (countmat == 0).sum() == 0
if torch.onnx.isinonnxexport():
# cast countmat to constant while exporting to ONNX
countmat = torch.fromnumpy(
countmat.cpu().detach().numpy()).to(device=img.device)
不过在部署时,这些参数往往是固定的,因此我们没必要把它算一遍。因此在倒数第4行的if分支里,我们做了一件看似很没用的事
countmat = torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)
即我们把算出来的countmat从tensor转换成numpy,再转回tensor。
其实我们的目的是切断tracing。
之前提到过,ONNX只能记录ATen相关的操作,但是很显然,tensor和numpy的互转肯定不是ATen操作。因此在回溯的时候,当访问到count mat,ONNX并不能发现它是被谁运算出来的,所以countmat就会被看作一个常数被保存下来,之前计算countmat的部分都会被扔掉
以上是“Pytorch转ONNX中tracing机制有什么用”这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注创新互联行业资讯频道!
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流