Tensorrt实现solov2加速

solo系列网络是由Xinlong Wang提出的单阶段实例分割网络。其搭建在mmdetection库中。solov2主干网络如下图所示:

faststyle1

solov2实例分割网络

Tensorrt实现solov2加速步骤如下所示:

1、修改solo中tensorrt或onnx不支持的层。solo原生代码中采用的group normalization层在onnx和tensorrt中支持性不高,会导致模型转换错误,因此将模型中所有group normalization层替换为batch normalization层。因为具体替换了模型的结构,因此还需对模型进行重新训练。

faststyle1

onnx生成的gn层

faststyle1

未指定缩放size情况下onnx生成的upsample层

2、进行pth模型到onnx模型的转换。pytorch1.3对应的onnx版本为1.6,其upsample层需指定具体的转换尺寸,不然会在转tensorrt的过程中报错。此外,solo采用了较为特殊的coordconv,onnx不支持其采用的torch.linspace操作,因此我们将coord中指明方向的两层保存为具体参数直接读取使用,具体代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

# Spatial sizes of the five FPN levels for an 800x1344 network input
# (strides 8, 8, 16, 32, 32 -> feature heights / widths).
y_list = [100, 100, 50, 25, 25]
x_list = [168, 168, 84, 42, 42]

# Pre-compute the two CoordConv direction planes for every level and dump
# them to disk, because ONNX export cannot trace torch.linspace.
for idx, (x_size, y_size) in enumerate(zip(x_list, y_list), start=1):
    x_range = torch.linspace(-1, 1, x_size)
    y_range = torch.linspace(-1, 1, y_size)
    y, x = torch.meshgrid(y_range, x_range)
    y = y.expand([1, 1, -1, -1])
    x = x.expand([1, 1, -1, -1])
    coord_feat = torch.cat([x, y], 1)
    torch.save(coord_feat, "coord{}.pth".format(idx))

转换onnx的具体代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import argparse
import mmcv
import torch
import torch.nn.functional as F
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
import cv2
import torch.nn as nn
from types import MethodType
import torch.onnx as onnx
from torchvision.transforms import Normalize


class Norm(nn.Module):
    """ImageNet mean/std normalisation applied to a 0-255 NCHW tensor,
    baked into the exported graph so TRT input is the raw image blob."""

    def __init__(self):
        super(Norm, self).__init__()
        # Standard ImageNet statistics (RGB channel order).
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.normal = Normalize(self.mean, self.std)

    def forward(self, x):
        # Drop the batch dim, scale to [0, 1], normalise, restore batch dim.
        scaled = x.squeeze(0) / 255.
        return self.normal(scaled).unsqueeze(0)


def points_nms(heat, kernel=2):
    """Keep only local maxima of the score map: every point that is not the
    maximum of its (kernel x kernel) neighbourhood is zeroed."""
    pooled = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=1)
    # The padded pooling output is one pixel larger; the top-left crop
    # aligns each location with the max of its own 2x2 window.
    mask = (pooled[:, :, :-1, :-1] == heat).float()
    return heat * mask


def fpn_forward(self, inputs):
    """ONNX-friendly FPN forward: the top-down upsamples use an explicit
    doubled output size (instead of scale_factor) so the exported graph
    contains fixed-size Upsample nodes.

    Args:
        inputs (list[Tensor]): backbone feature maps, low to high level.

    Returns:
        tuple[Tensor]: `self.num_outs` FPN feature maps.
    """
    assert len(inputs) == len(self.in_channels)
    # 1x1 lateral convs.
    # BUG FIX: the original iterated `self.lateral_convs` without
    # enumerate(), so unpacking `i, lateral_conv` raised at runtime.
    laterals = [
        lateral_conv(inputs[i + self.start_level])
        for i, lateral_conv in enumerate(self.lateral_convs)
    ]
    used_backbone_levels = len(laterals)
    # Top-down pathway with explicit target sizes.
    for i in range(used_backbone_levels - 1, 0, -1):
        sh = torch.tensor(laterals[i].shape)
        laterals[i - 1] += F.interpolate(
            laterals[i], size=(sh[2] * 2, sh[3] * 2), mode="nearest"
        )
    outs = [
        self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
    ]

    # Extra output levels, mirroring mmdet's FPN options.
    if self.num_outs > len(outs):
        if not self.add_extra_convs:
            for i in range(self.num_outs - used_backbone_levels):
                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        else:
            if self.extra_convs_on_inputs:
                orig = inputs[self.backbone_end_level - 1]
                outs.append(self.fpn_convs[used_backbone_levels](orig))
            else:
                outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
            for i in range(used_backbone_levels + 1, self.num_outs):
                if self.relu_before_extra_convs:
                    outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                else:
                    outs.append(self.fpn_convs[i](outs[-1]))
    return tuple(outs)


def forward_single(self, x, idx, eval=False, upsampled_size=None):
    """Kernel/category head for one FPN level, using the pre-saved
    CoordConv planes (self.coord) instead of torch.linspace so the
    graph is traceable for ONNX export."""
    grid = self.seg_num_grids[idx]

    # Category branch input: plain features resized to the S x S grid.
    cate_feat = F.interpolate(x, size=grid, mode="bilinear")

    # Kernel branch input: features + fixed coordinate planes, resized.
    kernel_feat = torch.cat([x, self.coord[idx]], 1)
    kernel_feat = F.interpolate(kernel_feat, size=grid, mode="bilinear")
    kernel_feat = kernel_feat.contiguous()
    for kernel_layer in self.kernel_convs:
        kernel_feat = kernel_layer(kernel_feat)
    kernel_pred = self.solo_kernel(kernel_feat)  # dynamic-conv kernels

    cate_feat = cate_feat.contiguous()
    for cate_layer in self.cate_convs:
        cate_feat = cate_layer(cate_feat)
    cate_pred = self.solo_cate(cate_feat)  # B x C x S x S class scores

    if eval:
        # Grid-level NMS plus channel-last layout for post-processing.
        cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
    return cate_pred, kernel_pred


def split_feats(self, feats):
    """Resize the 5 FPN maps to SOLO's stride-(8, 8, 16, 32, 32) pyramid:
    level 0 is downsampled by 2 and level 4 is upsampled to level 3's size.

    BUG FIX: the last interpolate used feats[0]'s width (sh1[3]) instead of
    feats[3]'s (sh2[3]), producing a wildly wrong feature size for level 4.
    """
    sh1 = torch.tensor(feats[0].shape)
    sh2 = torch.tensor(feats[3].shape)
    return (F.interpolate(feats[0], size=(int(sh1[2] * 0.5), int(sh1[3] * 0.5)),
                          mode='bilinear'),
            feats[1],
            feats[2],
            feats[3],
            F.interpolate(feats[4], size=(sh2[2], sh2[3]), mode='bilinear'))


def forward(self, inputs):
    """Mask-feature head forward with every upsample given an explicit
    target size (2x the current shape) so the exported ONNX graph is
    static. Level 3 gets the pre-saved CoordConv planes appended first."""
    assert len(inputs) == (self.end_level - self.start_level + 1)

    def _up2(t):
        # Double the spatial size of `t` with an explicit output size.
        s = torch.tensor(t.shape)
        return F.interpolate(t, size=(s[2] * 2, s[3] * 2), mode="bilinear")

    # Level 0 is used at full resolution.
    merged = self.convs_all_levels[0](inputs[0])

    # Level 1: one conv, one 2x upsample.
    merged += _up2(self.convs_all_levels[1].conv0(inputs[1]))

    # Level 2: conv0 -> 2x -> conv1 -> 2x.
    lvl2 = _up2(self.convs_all_levels[2].conv0(inputs[2]))
    merged += _up2(self.convs_all_levels[2].conv1(lvl2))

    # Level 3: append fixed coordinate planes, then conv/upsample x3.
    lvl3 = torch.cat([inputs[3], self.coord], 1)
    lvl3 = _up2(self.convs_all_levels[3].conv0(lvl3))
    lvl3 = _up2(self.convs_all_levels[3].conv1(lvl3))
    merged += _up2(self.convs_all_levels[3].conv2(lvl3))

    return self.conv_pred(merged)


def main_forward(self, img):
    """Export-time forward: normalise, run backbone + neck + heads, and
    flatten the per-level predictions into three ONNX graph outputs:
    (cate_preds, kernel_preds, mask_feats).

    BUG FIX: the originals used `.view()` after `.permute()`; a permuted
    tensor is non-contiguous, so `.view()` raises at trace time — use
    `.reshape()` (identical result, handles non-contiguous layouts).
    """
    x = self.normal(img)
    x = self.extract_feat(x)
    outs = self.bbox_head(x)
    mask_feat_pred = self.mask_feat_head(
        x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1]
    )
    # Flatten the five category maps to (sum(S*S), num_classes).
    cate_class = int(outs[0][0].shape[-1])
    cate_pred_list = [
        outs[0][i].reshape(-1, cate_class) for i in range(5)
    ]
    # Flatten the five kernel maps to (sum(S*S), kernel_channels).
    kernel_shape = int(outs[1][0].shape[1])
    kernel_pred_list = [
        outs[1][i].squeeze(0).permute(1, 2, 0).reshape(-1, kernel_shape)
        for i in range(5)
    ]
    cate_pred_list = torch.cat(cate_pred_list, dim=0)
    kernel_pred_list = torch.cat(kernel_pred_list, dim=0)

    return (cate_pred_list, kernel_pred_list, mask_feat_pred)


def parse_args():
    """CLI: config file, checkpoint and output ONNX path (all positional)."""
    parser = argparse.ArgumentParser(description='MMDet onnx model')
    for name, desc in (('config', 'test config file path'),
                       ('checkpoint', 'checkpoint file'),
                       ('output', 'output onnx file')):
        parser.add_argument(name, help=desc)
    return parser.parse_args()


def main():
    """Build the SOLOv2 model, monkey-patch every ONNX-unfriendly forward
    with the tracing-safe versions defined above, then export to ONNX."""
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # Build the detector; weights are loaded from the checkpoint below.
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)

    # Pre-computed CoordConv planes (see the companion generation script).
    coords = [torch.load("coord{}.pth".format(i)) for i in range(1, 6)]
    model.bbox_head.coord = coords
    model.mask_feat_head.coord = coords[-1]

    # Swap in the export-friendly forwards.
    model.bbox_head.forward_single = MethodType(forward_single, model.bbox_head)
    model.bbox_head.split_feats = MethodType(split_feats, model.bbox_head)
    model.mask_feat_head.forward = MethodType(forward, model.mask_feat_head)
    model.neck.forward = MethodType(fpn_forward, model.neck)
    model.normal = Norm()
    model.forward = MethodType(main_forward, model)

    # Fixed export resolution: H=800, W=1344.
    dummy = torch.randn(1, 3, 800, 1344)

    load_checkpoint(model, args.checkpoint, map_location='cpu')
    onnx.export(model, dummy, args.output, verbose=True, opset_version=10)


if __name__ == '__main__':
    main()

3、进行onnx模型到tensorrt模型的转换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
import pycuda.driver as cuda
import pycuda.autoinit
import cv2
import os
import numpy as np
import tensorrt as trt
import time
import argparse
import torch
import torch.nn.functional as F
import numpy as np


# Post-processing configuration — must match the exported SOLOv2 model.
seg_num_grids = [40,36,24,16,12]  # grid size S per FPN level (S*S cells each; sums to 3872)
self_strides = [8,8,16,32,32]  # feature stride per level, used to filter tiny masks
score_thr = 0.1  # minimum category score for a candidate mask
mask_thr = 0.5  # binarisation threshold for the predicted soft masks
max_per_img = 100  # keep at most this many candidates before Matrix NMS
class_names = [] # fill in the class names your model predicts


class HostDeviceMem(object):
    """Pairs a pagelocked host buffer with its CUDA device allocation."""

    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n{}\nDevice:\n{}".format(self.host, self.device)

    # repr mirrors str for readable debug dumps of buffer lists.
    __repr__ = __str__


class Preprocessimage(object):
    """Loads an image file and converts it to the contiguous float32 NCHW
    blob the TensorRT engine expects (normalisation happens in-graph)."""

    def __init__(self, inszie):
        # (width, height) target resolution (original argument name kept).
        self.inszie = inszie

    def process(self, image_path):
        """Read `image_path`; return (blob, resized BGR image, img_metas)."""
        start = time.time()
        bgr = cv2.imread(image_path)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        H, W, _ = rgb.shape  # original resolution (not used downstream)

        rgb = cv2.resize(rgb, self.inszie)
        img_metas = dict()
        img_metas["img_shape"] = rgb.shape
        image_raw = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # for visualisation

        blob = rgb.transpose([2, 0, 1])                  # HWC -> CHW
        blob = np.expand_dims(blob, axis=0)              # CHW -> NCHW
        blob = np.array(blob, dtype=np.float32, order="C")
        print("preprocess time {:.3f} ms".format((time.time() - start) * 1000))
        return blob, image_raw, img_metas


def get_engine(onnx_path, engine_path, TRT_LOGGER, mode="fp16"):
    """Deserialize a cached TensorRT engine if `engine_path` exists,
    otherwise build (and cache) one from the ONNX file."""

    def build_engine():
        # Explicit-batch network definition is required for ONNX parsing.
        EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = 1 << 30  # 1 GiB
            builder.max_batch_size = 1
            if mode == "fp16":
                builder.fp16_mode = True
            if not os.path.exists(onnx_path):
                print("onnx file {} not found".format(onnx_path))
                exit(0)
            print("loading onnx file {} .....".format(onnx_path))
            with open(onnx_path, 'rb') as model:
                print("Begining parsing....")
                parser.parse(model.read())
                print("completed parsing")
            print("Building an engine from file {}".format(onnx_path))
            # Pin the input to the fixed export resolution.
            network.get_input(0).shape = [1, 3, 800, 1344]
            engine = builder.build_cuda_engine(network)

            print("completed build engine")
            # Cache the serialized engine for subsequent runs.
            with open(engine_path, "wb") as f:
                f.write(engine.serialize())
            return engine

    if not os.path.exists(engine_path):
        return build_engine()
    print("loading engine file {} ...".format(engine_path))
    with open(engine_path, "rb") as f, \
            trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())


def allocate_buffers(engine):
    """Allocate pagelocked host + device buffers for every engine binding.

    Returns (inputs, outputs, bindings, stream) where inputs/outputs are
    HostDeviceMem lists and bindings the raw device pointers.
    """
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()

    for binding in engine:
        # Total element count for this binding at max batch size.
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))

        pair = HostDeviceMem(host_mem, device_mem)
        (inputs if engine.binding_is_input(binding) else outputs).append(pair)

    return inputs, outputs, bindings, stream


def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Asynchronously push inputs, execute the engine, and pull outputs back.

    Returns the list of host-side output buffers (flat numpy arrays).
    """
    for inp in inputs:
        cuda.memcpy_htod_async(inp.device, inp.host, stream)

    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)

    for out in outputs:
        cuda.memcpy_dtoh_async(out.host, out.device, stream)

    # Block until all async copies and kernels on this stream are done.
    stream.synchronize()

    return [out.host for out in outputs]


def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None):
    """Matrix NMS for multi-class masks.

    Args:
        seg_masks (Tensor): shape (n, h, w)
        cate_labels (Tensor): shape (n), mask labels in descending order
        cate_scores (Tensor): shape (n), mask scores in descending order
        kernel (str): 'gaussian' or 'linear'
        sigma (float): std in gaussian method
        sum_masks (Tensor): optional precomputed per-mask areas, shape (n)

    Returns:
        Tensor: decayed scores, shape (n); [] when no masks are given.
    """
    n = len(cate_labels)
    if n == 0:
        return []
    if sum_masks is None:
        sum_masks = seg_masks.sum((1, 2)).float()

    flat = seg_masks.reshape(n, -1).float()
    # Pairwise intersection / union -> upper-triangular IoU matrix.
    inter = torch.mm(flat, flat.transpose(1, 0))
    areas = sum_masks.expand(n, n)
    iou = (inter / (areas + areas.transpose(1, 0) - inter)).triu(diagonal=1)

    # Only pairs sharing the same class label interact.
    labels = cate_labels.expand(n, n)
    same_label = (labels == labels.transpose(1, 0)).float().triu(diagonal=1)

    # For each mask: the largest IoU it has with any higher-scoring mask.
    compensate, _ = (iou * same_label).max(0)
    compensate = compensate.expand(n, n).transpose(1, 0)

    decay = iou * same_label
    if kernel == 'gaussian':
        num = torch.exp(-1 * sigma * (decay ** 2))
        den = torch.exp(-1 * sigma * (compensate ** 2))
        coeff, _ = (num / den).min(0)
    elif kernel == 'linear':
        coeff, _ = ((1 - decay) / (1 - compensate)).min(0)
    else:
        raise NotImplementedError

    # Decay each score by its worst-case overlap penalty.
    return cate_scores * coeff


def get_seg_single(cate_preds,
                   seg_preds,
                   kernel_preds,
                   img_metas):
    """Decode one image's raw TensorRT outputs into final instance masks.

    Args:
        cate_preds (Tensor): (sum(S*S), num_classes) category scores.
        seg_preds (Tensor): (1, E, h, w) mask feature map.
        kernel_preds (Tensor): (sum(S*S), E) dynamic conv kernels.
        img_metas (dict): must contain 'img_shape' of the resized input.

    Returns:
        (seg_masks, cate_labels, cate_scores) or None when nothing survives
        the thresholds; seg_masks is a (k, 4h, 4w) boolean numpy array.
    """
    img_shape = img_metas['img_shape']

    # overall info.
    h, w, _ = img_shape

    featmap_size = seg_preds.size()[-2:]
    # Mask features sit at 1/4 of the network input resolution.
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)

    # Keep grid cells whose class score clears the threshold.
    inds = (cate_preds > score_thr)
    cate_scores = cate_preds[inds]
    if len(cate_scores) == 0:
        return None

    # cate_labels & the kernels of the kept cells.
    inds = inds.nonzero()
    cate_labels = inds[:, 1]
    kernel_preds = kernel_preds[inds[:, 0]]

    # Map every flat grid index to the stride of its FPN level.
    size_trans = cate_labels.new_tensor(seg_num_grids).pow(2).cumsum(0)  # e.g. [1600, 2896, 3472, 3728, 3872]
    strides = kernel_preds.new_ones(size_trans[-1])

    n_stage = len(seg_num_grids)
    strides[:size_trans[0]] *= self_strides[0]
    for ind_ in range(1, n_stage):
        strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self_strides[ind_]
    strides = strides[inds[:, 0]]

    # Mask generation: 1x1 dynamic convolution of the feature map.
    I, N = kernel_preds.shape
    kernel_preds = kernel_preds.view(I, N, 1, 1)  # (out_ch, in_ch, 1, 1)
    seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()

    seg_masks = seg_preds > mask_thr
    sum_masks = seg_masks.sum((1, 2)).float()

    # Discard masks smaller than their level's stride.
    keep = sum_masks > strides
    if keep.sum() == 0:
        return None

    seg_masks = seg_masks[keep, ...]
    seg_preds = seg_preds[keep, ...]
    sum_masks = sum_masks[keep]
    cate_scores = cate_scores[keep]
    cate_labels = cate_labels[keep]

    # Mask scoring: weight the class score by the mean mask confidence.
    seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_scores

    # Keep the top max_per_img candidates by score.
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > max_per_img:
        sort_inds = sort_inds[:max_per_img]
    seg_masks = seg_masks[sort_inds, :, :]
    seg_preds = seg_preds[sort_inds, :, :]
    sum_masks = sum_masks[sort_inds]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]

    # Matrix NMS decays the scores of overlapping same-class masks.
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel='gaussian', sigma=2., sum_masks=sum_masks)

    # Upsample the soft masks to input resolution (cv2.resize wants HWC).
    if seg_preds.shape[0] == 1:
        seg_preds = cv2.resize(seg_preds.permute(1, 2, 0).numpy(),
                               (upsampled_size_out[1], upsampled_size_out[0]))[:, :, None].transpose(2, 0, 1)
    else:
        seg_preds = cv2.resize(seg_preds.permute(1, 2, 0).numpy(),
                               (upsampled_size_out[1], upsampled_size_out[0])).transpose(2, 0, 1)
    # BUG FIX: threshold the *upsampled* soft masks. The original line read
    # `seg_masks = seg_masks > mask_thr`, which re-thresholded the old
    # low-resolution boolean masks — the returned masks never matched the
    # visualisation resolution (and vis_seg expects a numpy array).
    seg_masks = seg_preds > mask_thr
    return seg_masks, cate_labels, cate_scores


def vis_seg(image_raw, result, score_thresh, output):
    """Draw predicted masks/labels on `image_raw` and write it to `output`.

    Args:
        image_raw (ndarray): BGR image at the network input resolution.
        result: (seg_masks, cate_labels, cate_scores) from get_seg_single,
            or None when nothing was detected.
        score_thresh (float): minimum score for a mask to be drawn.
        output (str): path of the image file to write.
    """
    img_show = image_raw
    seg_show1 = img_show.copy()
    seg_show = img_show.copy()
    # BUG FIX: `is None` instead of `== None` — equality against a tuple of
    # arrays is un-idiomatic and fragile.
    if result is None:
        cv2.imwrite(output, seg_show1)
        return

    seg_label = result[0].astype(np.uint8)
    cate_label = result[1].numpy()
    score = result[2].numpy()

    # Filter by final (post-NMS) score.
    vis_inds = score > score_thresh
    seg_label = seg_label[vis_inds]
    num_mask = seg_label.shape[0]
    cate_label = cate_label[vis_inds]
    cate_score = score[vis_inds]

    # Sort by mask area so larger masks are drawn first (small on top).
    mask_density = []
    for idx in range(num_mask):
        mask_density.append(seg_label[idx, :, :].sum())
    orders = np.argsort(mask_density)
    seg_label = seg_label[orders]
    cate_label = cate_label[orders]
    cate_score = cate_score[orders]

    for idx in range(num_mask):
        idx = -(idx + 1)  # iterate from largest area to smallest
        cur_mask = seg_label[idx, :, :]
        if cur_mask.sum() == 0:
            continue
        color_mask = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
        contours, _ = cv2.findContours(cur_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        cv2.drawContours(seg_show, contours, -1, color_mask, -1)
        cur_cate = cate_label[idx]
        label_text = class_names[cur_cate]
        # Box + label from the mask's bounding rectangle.
        x1, y1, w, h = cv2.boundingRect(cur_mask)
        x2 = x1 + w
        y2 = y1 + h
        vis_pos = (max(int(x1) - 10, 0), int(y1))
        cv2.rectangle(seg_show, (x1, y1), (x2, y2), (0, 0, 0), thickness=2)
        cv2.putText(seg_show, label_text, vis_pos, cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 0))
    # Blend the painted masks with the original image.
    seg_show1 = cv2.addWeighted(seg_show, 0.7, img_show, 0.5, 0)
    cv2.imwrite(output, seg_show1)


def main():
    """Run a single image through the TensorRT SOLOv2 engine and save the
    visualised instance segmentation."""
    args = argparse.ArgumentParser(description="trt pose predict")
    args.add_argument("--onnx_path", type=str, default="dense121.onnx")
    args.add_argument("--engine_path", type=str, default="dense121fp16.trt")
    args.add_argument("--image_path", type=str)
    args.add_argument("--mode", type=str, default="fp16")
    args.add_argument("--output", type=str, default="result.png")
    args.add_argument("--classes", type=int, default=80)
    args.add_argument("--score_thr", type=float, default=0.3)
    opt = args.parse_args()

    insize = (1344, 800)  # (W, H) of the exported network input

    # Binding shapes: mask features, flattened cate scores, flattened kernels.
    output_shape = [(1, 256, 200, 336), (3872, opt.classes), (3872, 256)]
    TRT_LOGGER = trt.Logger()
    preprocesser = Preprocessimage(insize)

    image, image_raw, img_metas = preprocesser.process(opt.image_path)

    with get_engine(opt.onnx_path, opt.engine_path, TRT_LOGGER, opt.mode) as engine, \
            engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = allocate_buffers(engine)

        inputs[0].host = image
        start = time.time()
        trt_outputs = do_inference(context, bindings, inputs, outputs, stream)
        end = time.time()
        print("inference time {:.3f} ms".format((end - start) * 1000))
        start = time.time()
        trt_outputs = [output.reshape(shape) for output, shape in zip(trt_outputs, output_shape)]
        trt_outputs = [torch.tensor(output) for output in trt_outputs]

        seg_pred = trt_outputs[0]     # (1, 256, 200, 336) mask feature map
        cate_pred = trt_outputs[1]    # (3872, classes) category scores
        kernel_pred = trt_outputs[2]  # (3872, 256) dynamic kernels

        with torch.no_grad():
            # BUG FIX: get_seg_single's signature is (cate, seg, kernel, metas);
            # the original call passed kernel_pred and seg_pred swapped.
            result = get_seg_single(cate_pred, seg_pred, kernel_pred, img_metas)
            vis_seg(image_raw, result, opt.score_thr, opt.output)
        # BUG FIX: measure post-processing time with a fresh timestamp; the
        # original reused the stale inference `end`, printing a bogus value.
        print("post time {:.3f} ms".format((time.time() - start) * 1000))


if __name__ == "__main__":
    main()