admin 管理员组

文章数量: 1184232

MobileNetV3基于NNI剪枝操作

NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝

1、背景

本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法。

该Backbone最后的输出格式如下:

假如out = model(x),则out[-1]['hm']即为heatmap,通过out[-1]['hm'].shape可获得其shape。

2、直接添加nni操作

直接添加的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup


class hswish(nn.Module):
    """Hard-swish activation, computed as ``x * relu6(x + 3) / 6``."""

    def __init__(self):
        super().__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        # x + 3 is a fresh tensor, so the in-place ReLU6 never touches x.
        gate = self.relu6(x + 3)
        return x * gate / 6


class hsigmoid(nn.Module):
    """Hard-sigmoid activation, computed as ``relu6(x + 3) / 6``."""

    def __init__(self):
        super().__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        return self.relu6(x + 3) / 6


# Attention mechanism (squeeze-and-excitation), implemented by SE below.
class SE(nn.Module):
    """Squeeze-and-excitation attention: rescale each channel by a learned gate.

    Args:
        in_channels: number of input (and output) channels.
        reduce: reduction factor of the hidden 1x1 bottleneck conv.
    """

    def __init__(self, in_channels, reduce=4):
        super(SE, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),
            # NOTE: BatchNorm on the pooled 1x1 map sees only one value per
            # channel when batch size is 1, which fails in train mode.
            nn.BatchNorm2d(in_channels // reduce),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            hsigmoid(),
        )

    def forward(self, x):
        out = self.se(x)
        out = x * out
        return out


class Block(nn.Module):
    """Inverted-residual block: 1x1 expand -> depthwise -> (SE) -> 1x1 project.

    Args:
        kernel_size: depthwise kernel size; padding ``kernel_size // 2`` keeps
            the spatial size at stride 1 for odd kernels.
        in_channels, expand_size, out_channels: channel widths of the stages.
        stride: depthwise stride; the residual is added only when stride == 1.
        se: insert an SE module after the depthwise conv.
        nolinear: 'RE' for ReLU6 or 'HS' for hard-swish.

    Raises:
        ValueError: if ``nolinear`` is neither 'RE' nor 'HS'.
    """

    def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):
        super(Block, self).__init__()
        self.se = nn.Sequential()  # identity unless SE is requested
        if se:
            self.se = SE(expand_size)
        if nolinear == 'RE':
            self.nolinear = nn.ReLU6(inplace=True)
        elif nolinear == 'HS':
            self.nolinear = hswish()
        else:
            # Fail fast instead of leaving self.nolinear undefined.
            raise ValueError("nolinear must be 'RE' or 'HS', got {!r}".format(nolinear))
        # The same parameter-free activation instance is used at both positions.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand_size),
            self.nolinear,
            nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=expand_size, bias=False),
            nn.BatchNorm2d(expand_size),
            self.se,
            self.nolinear,
            nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Identity shortcut by default; 1x1 projection when the widths differ.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.stride = stride

    def forward(self, x):
        out = self.block(x)
        if self.stride == 1:
            out += self.shortcut(x)
        return out


class MobileNetV3(nn.Module):
    """MobileNetV3 backbone with CenterNet-style heads (version that NNI cannot trace).

    Args:
        class_num: number of heatmap classes (output channels of ``hm``).
    """

    def __init__(self, class_num):
        super(MobileNetV3, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish(),
        )
        self.neck = nn.Sequential(
            Block(3, 16, 16, 16, 2, se=True),
            Block(3, 16, 72, 24, 2),
            Block(3, 24, 88, 24, 1),
            Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 576, 1, bias=False),
            nn.BatchNorm2d(576),
            hswish(),
        )
        # NOTE(review): avgpool is defined but never used in forward.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(576, 1280, 2, bias=False),
            nn.BatchNorm2d(1280),
            hswish(),
        )
        # The heads read conv3's output viewed as (N, -1, 128, 128). For a
        # 516x516 input conv3 yields 1280x16x16, i.e. 20 channels after the view.
        self.hm = nn.Conv2d(20, class_num, kernel_size=1)
        self.wh = nn.Conv2d(20, 2, kernel_size=1)
        self.reg = nn.Conv2d(20, 2, kernel_size=1)

    def forward(self, x):
        """Return ``[{'hm': ..., 'wh': ..., 'reg': ...}]``.

        torch.jit.trace (called by NNI's ModelSpeedup) rejects this
        list-wrapping-a-dict output with "Only tensors, lists, tuples of
        tensors, or dictionary of tensors can be output from traced functions".
        """
        x = self.conv1(x)
        x = self.neck(x)
        x = self.conv2(x)
        x = self.conv3(x)
        y = x.view(x.shape[0], -1, 128, 128)
        z = {}
        z['hm'] = self.hm(y)
        z['wh'] = self.wh(y)
        z['reg'] = self.reg(y)
        return [z]


if __name__ == '__main__':
    model = MobileNetV3(10)
    print('-----------raw model------------')
    print(model)
    config_list = [{'sparsity_per_layer': 0.8, 'op_types': ['Conv2d']}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    for name, mask in masks.items():
        # NNI masks hold 1 for kept weights, so sum/numel is the kept ratio;
        # the sparsity is its complement.
        kept = mask['weight'].sum() / mask['weight'].numel()
        print(name, ' sparsity: ', '{:.2f}'.format(1 - kept))
    pruner._unwrap_model()
    # Expected to fail here with the torch.jit.trace RuntimeError described above.
    ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()
    print('------------after speedup------------')
    print(model)

如果参考nni入门直接添加nni压缩的代码,则会报如下错误:
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions。

  File "D:\programs\python37\lib\site-packages\nni\common\graph_utils.py", line 78, in _trace
    self.trace = torch.jit.trace(model, dummy_input, **kw_args)
  File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 742, in trace
    _module_class,
  File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 940, in trace_module
    _force_outplace,
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions

原因:返回的数据不符合torch.jit.trace的要求。示例model的forward返回的是[z],即一个包裹着dict的list,这种嵌套结构不属于tensors | lists | tuples of tensors | dictionary of tensors中的任何一种。

所以需要对MobileNetv3进行改造,以满足torch.jit.trace的返回要求。

3、MobileNetV3针对NNI的改造

改造方法:

(1)将输出从dict修改为tuple形式

(2)hm、wh、reg的定义从__init__()函数移到forward中。因为hm中conv的in_channel是会变化的,未剪枝前是A,剪枝后是B,所以在__init__()中定义没法动态修改in_channel值,只能放到forward中进行处理。

以下代码只适用于CPU模式下,不适用GPU上运行。

改造后的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup


class hswish(nn.Module):
    """Hard-swish activation, computed as ``x * relu6(x + 3) / 6``."""

    def __init__(self):
        super().__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        # x + 3 is a fresh tensor, so the in-place ReLU6 never touches x.
        gate = self.relu6(x + 3)
        return x * gate / 6


class hsigmoid(nn.Module):
    """Hard-sigmoid activation, computed as ``relu6(x + 3) / 6``."""

    def __init__(self):
        super().__init__()
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        return self.relu6(x + 3) / 6


# Attention mechanism (squeeze-and-excitation), implemented by SE below.
class SE(nn.Module):
    """Squeeze-and-excitation attention: rescale each channel by a learned gate.

    Args:
        in_channels: number of input (and output) channels.
        reduce: reduction factor of the hidden 1x1 bottleneck conv.
    """

    def __init__(self, in_channels, reduce=4):
        super(SE, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),
            # NOTE: BatchNorm on the pooled 1x1 map sees only one value per
            # channel when batch size is 1, which fails in train mode.
            nn.BatchNorm2d(in_channels // reduce),
            nn.ReLU6(inplace=True),
            nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            hsigmoid(),
        )

    def forward(self, x):
        out = self.se(x)
        out = x * out
        return out


class Block(nn.Module):
    """Inverted-residual block: 1x1 expand -> depthwise -> (SE) -> 1x1 project.

    Args:
        kernel_size: depthwise kernel size; padding ``kernel_size // 2`` keeps
            the spatial size at stride 1 for odd kernels.
        in_channels, expand_size, out_channels: channel widths of the stages.
        stride: depthwise stride; the residual is added only when stride == 1.
        se: insert an SE module after the depthwise conv.
        nolinear: 'RE' for ReLU6 or 'HS' for hard-swish.

    Raises:
        ValueError: if ``nolinear`` is neither 'RE' nor 'HS'.
    """

    def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):
        super(Block, self).__init__()
        self.se = nn.Sequential()  # identity unless SE is requested
        if se:
            self.se = SE(expand_size)
        if nolinear == 'RE':
            self.nolinear = nn.ReLU6(inplace=True)
        elif nolinear == 'HS':
            self.nolinear = hswish()
        else:
            # Fail fast instead of leaving self.nolinear undefined.
            raise ValueError("nolinear must be 'RE' or 'HS', got {!r}".format(nolinear))
        # The same parameter-free activation instance is used at both positions.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expand_size),
            self.nolinear,
            nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=expand_size, bias=False),
            nn.BatchNorm2d(expand_size),
            self.se,
            self.nolinear,
            nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # Identity shortcut by default; 1x1 projection when the widths differ.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.stride = stride

    def forward(self, x):
        out = self.block(x)
        if self.stride == 1:
            out += self.shortcut(x)
        return out


class MobileNetV3(nn.Module):
    """MobileNetV3 backbone with CenterNet-style heads, adapted for NNI.

    ``forward`` returns a tuple ``(hm, wh, reg)`` so that torch.jit.trace
    accepts it. The heads' input width is ``1280 * H' * W' / (128 * 128)``.
    That value depends on the input resolution and on how many conv3 channels
    survive pruning, so the 1x1 heads are built lazily in ``forward`` (see
    ``_head``).

    NOTE: the heads are rebuilt with fresh weights whenever their input width
    changes (e.g. after ModelSpeedup). Create the optimizer only after one
    forward pass on the final (sped-up) model, so that it sees the heads'
    parameters.

    Args:
        class_num: number of heatmap classes (output channels of ``hm``).
        sparsity_ratio: pruning sparsity, stored as ``self.sparsity_ratio``;
            the model itself does not read it.
    """

    def __init__(self, class_num, sparsity_ratio):
        super(MobileNetV3, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            hswish(),
        )
        self.neck = nn.Sequential(
            Block(3, 16, 16, 16, 2, se=True),
            Block(3, 16, 72, 24, 2),
            Block(3, 24, 88, 24, 1),
            Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),
            Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),
            Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
            Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 576, 1, bias=False),
            nn.BatchNorm2d(576),
            hswish(),
        )
        # NOTE(review): avgpool is defined but never used in forward.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(576, 1280, 2, bias=False),
            nn.BatchNorm2d(1280),
            hswish(),
        )
        self.class_num = class_num
        self.sparsity_ratio = sparsity_ratio

    def _head(self, name, in_channel, out_channels, ref):
        """Return the 1x1 output conv registered as ``name`` for ``in_channel`` inputs.

        Outside tracing, the conv is created on ``ref``'s device/dtype and
        registered as a submodule. This means it is trained, saved, and moved
        by ``.cuda()``. It is reused until ``in_channel`` changes.

        While tracing (e.g. inside NNI's ModelSpeedup), the module tree is not
        mutated. A missing or stale head is replaced by a throwaway conv, as
        the original code did, so repeated traces of the same model agree.
        """
        conv = getattr(self, name, None)
        if conv is not None and conv.in_channels == in_channel:
            return conv
        conv = nn.Conv2d(in_channel, out_channels, kernel_size=1)
        conv = conv.to(device=ref.device, dtype=ref.dtype)
        if not torch.jit.is_tracing():
            setattr(self, name, conv)
        return conv

    def forward(self, x):
        """Return ``(hm, wh, reg)``: heatmap, box-size and offset maps."""
        x = self.conv1(x)
        x = self.neck(x)
        x = self.conv2(x)
        x = self.conv3(x)
        y = x.view(x.shape[0], -1, 128, 128)
        in_channel = y.shape[1]
        # wh and reg are 2-channel outputs, as in the un-adapted model.
        hm = self._head('hm', in_channel, self.class_num, y)
        wh = self._head('wh', in_channel, 2, y)
        reg = self._head('reg', in_channel, 2, y)
        return (hm(y), wh(y), reg(y))


if __name__ == '__main__':
    model = MobileNetV3(10, 0.2)
    print('-----------raw model------------')
    print(model)
    config_list = [{'sparsity_per_layer': model.sparsity_ratio, 'op_types': ['Conv2d']}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    for name, mask in masks.items():
        # NNI masks hold 1 for kept weights, so sum/numel is the kept ratio;
        # the sparsity is its complement.
        kept = mask['weight'].sum() / mask['weight'].numel()
        print(name, ' sparsity: ', '{:.2f}'.format(1 - kept))
    pruner._unwrap_model()
    ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()
    print('------------after speedup------------')
    print(model)
    # batch_size=1 was reported to fail here; presumably because, in train
    # mode, SE's BatchNorm2d sees a single value per channel after pooling.
    inputs = torch.randn(2, 3, 516, 516)
    out = model(inputs)
    print(out[0].shape)

4、cuda模式下适配CenterNet的MobileNetv3无法剪枝

上面第3段提到的方法只针对cpu,但是在gpu下是运行不成功的。

对于适配CenterNet的MobileNetV3,如果不进行剪枝,只是在forward中定义hm、wh、reg的卷积层,那么只需要改动3个地方,核心改动点如下:

但是一旦再加上NNI的代码,就会报错,报错信息为:“RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same”。

报错原因参考:Pytorch出现RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) - 水果+麦片 - 博客园

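也就是说,该错误表示输入张量在GPU上,而卷积权重仍在CPU上。在forward()中新建的nn.Conv2d默认创建在CPU上,而且不是模型注册的子模块,model.cuda()不会移动它们。因此,要么把网络层定义在__init__()方法中,要么在forward()中把新建的层显式移动到输入所在的设备上(例如.to(y.device))。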

所以,若hm、wh、reg的卷积层在forward()中新建且不指定设备,mobileNetv3针对CenterNet的剪枝+再训练就只能在CPU环境下进行。

5、cpu下训练效果

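可以看到loss始终无法收敛,所以nni剪枝暂告失败。需要注意的是:若hm、wh、reg卷积层在每次forward时都重新创建(且未注册为模型的子模块),它们的参数不会出现在model.parameters()中,优化器无法更新它们,而且每次前向使用的都是新的随机权重。这很可能才是loss无法收敛的直接原因,而不一定是剪枝本身导致的。另外,CenterNet中wh和reg分支应各输出2个通道(与第2节代码一致),不应设为class_num。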

本文标签: MobileNetV3基于NNI剪枝操作