当前位置: 首页>后端>正文

神经网络两步模型剪枝框架Pruning PyTorch代码实现

神经网络两步模型剪枝框架Pruning PyTorch代码实现,第1张

随着这些年深度学习的发展,我们身边出现了越来越多的大模型,各种GPT模型,我们设计的模型也越来越大,那么相对应的模型的速度也随之降低,我们的模型需要大量的内存,大量的算力。

然而在我们的神经网络中,有大量的参数是不起作用的,神经网络中真正起作用的参数只占少量。那么我们是不是可以将模型中那些不重要的参数、卷积核、网络层去掉呢?那么这个方法就叫做模型剪枝。

论文地址:http://network.ee.tsinghua.edu.cn/niulab/wp-content/uploads/2019/12/%E6%96%BD%E6%96%87%E7%90%A6INFOCOM.pdf
Git 代码地址:https://github.com/hou-yz/pytorch-pruning-2step/blob/master/main.py

神经网络两步模型剪枝框架Pruning PyTorch代码实现,第2张

模型剪枝的分类较多,有的是减去卷积核里的一个一个参数,有的是减去卷积核中的一块区域,有的将卷积核直接去掉,甚至是将一层网络直接去掉。

今天我们介绍的这个两步剪枝框架,就是进行卷积核的裁剪,也可以被叫做是通道裁剪。也就是将不重要的卷积核直接去掉。

一、两步剪枝框架

神经网络两步模型剪枝框架Pruning PyTorch代码实现,第3张

这个框架主要是分为两步:

  • 第一步:对整个模型进行裁剪
  • 第二步:在上一步的基础上,对每一层再次进行裁剪
(1)第一步:对整个模型裁剪

首先,我们介绍一下这个剪枝的结束条件:

  1. 剪枝后的模型不低于原始模型精度的95%.
  2. 模型剪枝后的参数量不能低于原模型的20%

然后在每一轮简直过程中,会减去 num_filters_to_prune_per_iteration = math.ceil(number_of_filters / 16)个卷积核。

框架中在进行前向传播的时候,给每一个输出值都注册了一个钩子函数,用来将卷积核进行排名。

框架在获取所有的卷积核之后,在每一次反向传播的时候都会对卷积核计算一个value作为评价这个卷积核重要不重要的标准。这个标准在代码中说的是一阶泰勒值。原文的注释:compute the total 1st order taylor for each filters in a given layer

在经过反向传播之后,我们遍获得了所有卷积核的value,接下来,我们将整个网络的卷积核从小到大进行一个排名。将排在前面的num_filters_to_prune_per_iteration个卷积核去掉。一直循环,直到不满足条件。

(1)第二步:对模型中的每一层进行裁剪

其剪枝结束条件和上一步是一样的,剪枝策略也一样。代码中会将每一层剪枝过后的模型保存在checkpoint文件夹中。

这里我根据论文作者给出的代码进行参考,修改了一部分,并做了一些注释。可直接替换掉GitHub中的同名文件,只是如果需要剪枝自己的模型,可能需要重新写一下模型的结构文件。

  • main.py
'''
Train & Pruning with PyTorch by hou-yz.
首先运行: 安装缺少的库, 对应版本在requirement.txt里面
- 第一: 训练模型
- 第二: 一步剪枝
- 第三: 二步剪枝
- 第四: 测试所有模型
'''
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import os
import math
from heapq import nsmallest
from operator import itemgetter
import json
import numpy as np
import argparse
from models import *
from model_refactor import *

# (1)选择需要剪枝的模型
net = AlexNet_() # 原alexnet
# net = RegNet() # 原regnet_x_400mf
# 如果使用下面模型,需要调整数据集90行到115行,因为下面的网路的输入是32*32,上面的网络是224*224
# net = VGG('VGG16')
# net = ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
# net = DPN92()
# net = ShuffleNetG2()
# net = SENet18()

# (2)配置参数
def get_args():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--epoch', default=20, type=int, help='epoch')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
    # 将defaut置为True就可以训练模型
    parser.add_argument("--train", default=True, dest="train", action="store_true") 
    # 将defaut置为True就可以进行第一步剪枝
    parser.add_argument("--prune", default=False, dest="prune", action="store_true")
    # 将defaut置为True就可以进行第二步剪枝
    parser.add_argument("--prune_layer", default=False, dest="prune_layer", action="store_true") 
    # 将defaut置为True就可以进行模型测试,分别包括原始模型,第一步剪枝后的模型,第三步剪枝后的模型,结果将会保存在根目录的三个json文件中
    parser.add_argument("--test_pruned", default=False, dest="test_pruned", action="store_true") 
    args = parser.parse_args()
    return args

# 训练好的模型保存在checkpoint里面
# ckpt.train:原始模型权重文件
# ckpt.pure:第一步剪枝后的模型权重文件
# ckpt.pure_layer_x:第二步剪枝后的模型,第x层的权重文件

# 测试模型的json文件
# log_original.json:原始模型的测试结果,包括每一层的卷积核数量、准确率、每一层的计算时延、中间层特征变量带宽等
# log_prune.json:第一步剪枝后的模型测试结果,包括每一层的卷积核数量、准确率、每一层的计算时延、中间层特征变量带宽等
# log_log_prune_layer.json:第二步剪枝后模型的测试结果,包括每一层的卷积核数量、准确率、每一层的计算时延、中间层特征变量带宽等

acc_thre = 0.98 # 剪枝结束条件,剪枝后的模型准确率小于原始模型精度的98%

# (3)运行代码 python ./main.py 
# 也可以使用命令行命令 python ./main.py --train 开启训练模式
# 也可以使用命令行命令 python ./main.py --prune 开启第一步剪枝模式
# 也可以使用命令行命令 python ./main.py --prune_layer 开启第二步剪枝模式

###################################################################################################################################
###################################################下面的代码无需修改###############################################################
#############################################剪枝结束条件:小于原始模型精度的98%#####################################################
###################################################################################################################################

# 这里确认是否使用多线程
if os.name == 'nt':  # windows
    num_workers = 0
else:  # linux
    num_workers = 8
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

use_cuda = torch.cuda.is_available() # 训练设备
start_epoch = 1  # 训练从0开始还是从最后一轮的check point开始
total_filter_num_pre_prune = 0 # 每轮剪枝数量
batch_size = 32 # 批量大小

# 数据集准备
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if not os.path.exists("./data"):
    os.mkdir("./data")

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # cifar10 的十个分类


# 训练函数
def train(optimizer=None, rankfilters=False, net = net):
    if optimizer is None:
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    net.train() # 开启训练模式
    train_loss = 0 # 训练损失
    correct = 0 # 正确数量
    total = 0 # 总数量
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # 将数据拷贝到设备上
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad() # 清空梯度
        inputs, targets = Variable(inputs), Variable(targets)
        if rankfilters:
            outputs = pruner.forward(inputs) # 使用剪枝后的模型训练
            loss = criterion(outputs, targets) # 这里使用的就是交叉熵损失函数
            loss.backward() # 这里使用反向传播,会进入到pruner的给每个变量注册的钩子函数中去
        else:
            outputs = net(inputs) # 原始模型训练
            loss = criterion(outputs, targets) # 这里使用的就是交叉熵损失函数
            loss.backward()
            optimizer.step()

        train_loss += loss.data.item()  # item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    print('Train Loss: %.3f | Acc: %.3f%% (%d/%d)'
          % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))


# 测试函数
def test(log_index=-1):
    net.eval() # 开启评估模式
    test_loss = 0
    correct = 0
    total = 0
    if log_index == -1 or use_cuda:
        # 原始模型的测试
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            inputs, targets = Variable(inputs,requires_grad=True), Variable(targets)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()  # loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

        print('Test  Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        acc = 100. * correct / total

    if log_index != -1:
        # 剪枝模型的测试
        (inputs, targets) = list(testloader)[0]
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # get profile
        # Profiler用于分析CPU、GPU端Op执行时间
        # 这里就是查看网络中每一个部分的时延
        with torch.autograd.profiler.profile() as prof:
            net(Variable(inputs))
            # print(next(net.parameters()).is_cuda)
        pruner.forward_n_track(Variable(inputs), log_index) # pruner是FilterPruner的实例对象
        cfg = pruner.get_cfg() # 返回模型的配置数组,包含卷积的输出通道和最大池化层

        # get log for time/bandwidth
        delta_ts = []
        bandwidths = []
        for i in range(len(cfg)): # 读出每一个卷积层的输出通道和最大池化层
            delta_ts.append(
                sum(item.cpu_time for item in prof.function_events[:pruner.conv_n_pool_to_layer[i]]) /
                np.power(10, 6) / batch_size) # 获取计算时延
            if isinstance(cfg[i], int): # 卷积层的带宽
                bandwidths.append(
                    int(cfg[i] * (inputs.shape[2] * inputs.shape[3]) / np.power(4, cfg[:i + 1].count('M'))))
            else:
                bandwidths.append( # 最大池化层使用的是前一层的卷积层的输出
                    int(cfg[i - 1] * (inputs.shape[2] * inputs.shape[3]) / np.power(4, cfg[:i + 1].count('M'))))

        data = {
            'acc': acc if use_cuda else -1,
            'index': log_index,
            'delta_t_prof': delta_ts[log_index],
            'delta_ts': delta_ts,
            'bandwidth': bandwidths[log_index],
            'bandwidths': bandwidths,
            'layer_cfg': cfg[log_index],
            'config': cfg
        }
        return data

    return acc


# 模型保存函数
def save(acc, conv_index=-1, epoch=-1):
    print('Saving..')
    try:
        # save the cpu model 保存CPU版本的模型
        model = net.module if isinstance(net, torch.nn.DataParallel) else net # 判断是否使用多GPU运算
        state = {
            'net': model.cpu() if use_cuda else model, # 将模型搬运到cpu上
            'acc': acc,
            'conv_index': conv_index,
            'epoch': epoch,
        }
    except:
        pass
    # 将模型保存到checkpoint里面
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    if args.prune: # 如果是step1
        torch.save(state, './checkpoint/ckpt.prune')
    elif args.prune_layer and conv_index != -1: # 如果是step2
        torch.save(state, './checkpoint/ckpt.prune_layer_%d' % conv_index)
    elif epoch != -1: # 如果是epoch != -1的情况
        torch.save(state, './checkpoint/ckpt.train.epoch_' + str(epoch))
    else: # 其他情况下
        torch.save(state, './checkpoint/ckpt.train')

    # restore the cuda or cpu model
    if use_cuda:
        net.cuda()

# filter 裁剪器
class FilterPruner:
    def __init__(self, model):
        self.model = model
        self.reset() # 定义实例变量 filter_rank

    def reset(self):
        self.filter_ranks = {}

    # forward method that gives "compute_rank" a hook
    # 模型前向计算
    def forward(self, x):
        self.activations = [] # 用于存放卷积层的输出结果
        self.gradients = []
        self.grad_index = 0
        self.activation_to_layer = {} # 将对应的卷积层的索引加入字典

        conv_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()): # 枚举出每一层的名称和神经网络层
            x = module(x)
            if isinstance(module, torch.nn.modules.Conv2d): # 如果是卷积层
                # register_hook是给可以计算梯度的张量注册一个钩子函数,这个函数会在进行反向传播的时候执行
                x.register_hook(self.compute_rank)
                self.activations.append(x) # 将输出结果加入到activations变量中
                self.activation_to_layer[conv_index] = layer # 将activateions对应的层索引加入到activation_to_layer字典中
                conv_index += 1

        if self.model.avgpool != "":
            x = self.model.avgpool(x)

        return self.model.classifier(x.view(x.size(0), -1)) # 调用模型最后一层。

    # forward method that tracks computation info
    # 前向方法追踪前向信息
    # log_index是为了记录0-log_index层的时间延迟和卷积计算量
    def forward_n_track(self, x, log_index=-1):
        self.conv_n_pool_to_layer = {}

        index = 0
        delta_t_computations = 0
        all_conv_computations = 0  # 给定层的卷积计算数
        t0 = time.time() # 记录时间
        for layer, (name, module) in enumerate(self.model.features._modules.items()): # 枚举出每一层的名称和神经网络层
            x = module(x) # 前向计算 输出结果 N,C,H,W
            # 如果是relu层或者最大池化层
            if isinstance(module, torch.nn.modules.ReLU) or isinstance(module, torch.nn.modules.MaxPool2d):
                all_conv_computations += np.prod(x.data.shape[1:]) # prod是计算给定元素的乘积,C*H*W
                self.conv_n_pool_to_layer[index] = layer
                if log_index == index: # 如果log_index 和 index相等
                    delta_t = time.time() - t0 # 计算一次时间间隔
                    delta_t_computations = all_conv_computations # 记录一次卷积计算量
                    bandwidth = np.prod(x.data.shape[1:]) # 带宽等于当前层输出的卷积计算量, 即:C*H*W
                index += 1

        return delta_t, delta_t_computations, bandwidth, all_conv_computations

    # for all the conv layers
    # 获取最后一个卷积层的索引
    def get_conv_index_max(self):
        conv_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            if isinstance(module, torch.nn.modules.Conv2d):
                conv_index += 1
        return conv_index

    # for all the conv layers and pool2d layers 循环所有的卷积层和池化层
    # 返回配置文件,cfg数组,如果是卷积层就加入输出通道,如果是最大池化层就加入"M" 
    def get_cfg(self):
        cfg = []
        for layer, (name, module) in enumerate(self.model.features._modules.items()): # 循环枚举所有的神经网络层
            if isinstance(module, torch.nn.modules.Conv2d): # 如果是卷积层就加入
                cfg.append(module.out_channels)
            elif isinstance(module, torch.nn.modules.MaxPool2d):
                cfg.append('M')
        return cfg

    # 计算排名,输入为梯度,这个函数会被递归的调用。
    # 使用一阶泰勒作为卷积核的值作为这个卷积核的作用能力
    def compute_rank(self, grad):
        conv_index = len(self.activations) - self.grad_index - 1 # 卷积层输出数量 - 0 - 1
        activation = self.activations[conv_index] # 最后一层输出的结果
        # print(activation.shape) # torch.Size([32, 512, 2, 2]) N C H W vgg16最后一层的通道数是512
        values = torch.sum((activation * grad), dim=0, keepdim=True).sum(dim=2, keepdim=True).sum(dim=3, keepdim=True)[
                 0, :, 0, 0].data  # compute the total 1st order taylor for each filters in a given layer 计算给定层中每个滤波器的总一阶泰勒,计算出来的是一个值

        # Normalize the rank by the filter dimensions 按筛选器维度规范化排名
        values = values / (activation.size(0) * activation.size(2) * activation.size(3))

        if conv_index not in self.filter_ranks:  # set self.filter_ranks[conv_index], 如果不在的话就添加
            self.filter_ranks[conv_index] = torch.FloatTensor(activation.size(1)).zero_() # 初始化每一个卷积核,初始化为0
            if use_cuda:
                self.filter_ranks[conv_index] = self.filter_ranks[conv_index].cuda() # 搬到cuda上去

        self.filter_ranks[conv_index] += values # 更新值, 所以self.filter_ranks是一个字典,len(self.filter_ranks[12]) = 512
        self.grad_index += 1 # 梯度向前更新

    def lowest_ranking_filters(self, num, conv_index):
        # print(len(self.filter_ranks.items())) # 13
        data = []
        if conv_index == -1:
            for i in sorted(self.filter_ranks.keys()):
                for j in range(self.filter_ranks[i].size(0)): # j -> N
                    data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j])) # 第i层,第j条数据
        else:
            for j in range(self.filter_ranks[conv_index].size(0)):
                data.append((self.activation_to_layer[conv_index], j, self.filter_ranks[conv_index][j])) # 第i层,第j条数据

        # print(len(data)) # 4224 因为它
        # print(self.filter_ranks[0].shape) # torch.Size([64])
        # print(data[0]) # (0, 0, tensor(0.0844)) 第一个数是卷积层层数,第二个数是该层的第j个卷积核,第三个值是filter_ranks的数值,代表这个卷积核的重要程度,或者作用能力
        # exit()
        # 下面的代码等价于:sorted(x, key = itemgetter(2))[:num],寻找最小的num个数据,根据data下标为2的一列数据进行排序
        return nsmallest(num, data, itemgetter(2))  # find the minimum of data[_][2], aka, self.filter_ranks[i][j]

    # 归一化每一层的V值
    def normalize_ranks_per_layer(self):
        for i in self.filter_ranks:
            v = torch.abs(self.filter_ranks[i])
            v = v.cpu().numpy() / np.sqrt(torch.sum(v * v).cpu().numpy())
            self.filter_ranks[i] = torch.tensor(v).cpu() # 将归一化的值返回
    
    # 获取剪枝策略
    def get_pruning_plan(self, num_filters_to_prune, conv_index):
        filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune, conv_index) # 获取num_filters_to_prune个作用程度不高的卷积核

        # print(filters_to_prune[0]) # (30, 54, tensor(1.5585e-05)) 这是一个数组
        # exit()
        # After each of the k filters are pruned,在k个卷积核被剪枝过后
        # the filter index of the next filters change since the model is smaller.下一个卷积核的索引会更改,因为模型变小了
        # 总体就是说,每次剪枝的卷积核不一样
        filters_to_prune_per_layer = {}
        for (l, f, _) in filters_to_prune: # l表示层标号,f表示卷积核标号
            if l not in filters_to_prune_per_layer: # 如果不在filters_to_prune_per_layer字典中,添加
                filters_to_prune_per_layer[l] = [] # 空数组
            filters_to_prune_per_layer[l].append(f) # 往数组中添加元素

        for l in filters_to_prune_per_layer: # 如果该层存在字典中
            filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l]) # 对该层的卷积核标号进行排序
            for i in range(len(filters_to_prune_per_layer[l])): # 遍历所有要被裁减的卷积核
                filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i # 不知道为什么要减去i ????????????

        filters_to_prune = []
        for l in filters_to_prune_per_layer: # 遍历该字典
            for i in filters_to_prune_per_layer[l]:
                filters_to_prune.append((l, i))
        # print(filters_to_prune[0]) # (30, 12) 30对应原始网络的位置
        return filters_to_prune


    def get_candidates_to_prune(self, num_filters_to_prune, conv_index):
        self.reset() # 将卷积核排名置为空
        train(rankfilters=True) # 前向传播
        # print(len(self.filter_ranks.items())) # 13, 也就是卷积层的数量
        self.normalize_ranks_per_layer() # 这就是做filter_ranks的归一化

        return self.get_pruning_plan(num_filters_to_prune, conv_index) # 输出要被裁减的num_filters_to_prune个卷积核在原始网络中的坐标

    # 计算总filter量,等价于计算卷积核的个数
    def total_num_filters(self, conv_index):
        filters = 0
        i = 0
        for name, module in list(self.model.features._modules.items()): # 枚举每一层的网络名称和神经网络层
            if isinstance(module, torch.nn.modules.Conv2d): # 判断是否为卷积层,仅计算卷积层
                if conv_index == -1:
                    filters = filters + module.out_channels # filters就是卷积核的个数,也是输出通道数
                elif conv_index == i:
                    filters = filters + module.out_channels
                i = i + 1
        return filters

    # 剪枝操作
    def prune(self, conv_index=-1):
        # 在剪枝之前先测试原始神经网络的准确度
        acc_pre_prune = test()
        acc = acc_pre_prune

        # train(rankfilters=True)

        # 将所有的特征参数的requires_grad参数
        for param in self.model.features.parameters():
            param.requires_grad = True

        number_of_filters = pruner.total_num_filters(conv_index) # 计算卷积核的总量 4224
        num_filters_to_prune_per_iteration = math.ceil(number_of_filters / 16) # 每个迭代轮次,剪枝的数量 264

        # 剪枝结束条件:
        # (1)剪枝后的精度不小于,原始模型精度的98%
        # (2)剪枝后的模型卷积和数量不能少于总量的20%
        while acc > acc_pre_prune * acc_thre and pruner.total_num_filters(conv_index) / number_of_filters > 0.2:
            # print("Ranking filters.. ")

            # conv_index = -1
            # 获取会被减去的num_filters_to_prune_per_iteration个卷积核,返回值是一个二维,(num,2)
            prune_targets = pruner.get_candidates_to_prune(num_filters_to_prune_per_iteration, conv_index)
            num_layers_pruned = {}  # filters to be pruned in each layer 一个字典存放每一层被剪枝的卷积核数量
            for layer_index, filter_index in prune_targets: # 枚举被剪枝的层索引和卷积核索引
                if layer_index not in num_layers_pruned: # 如果其不在字典中,进行添加
                    num_layers_pruned[layer_index] = 0
                num_layers_pruned[layer_index] = num_layers_pruned[layer_index] + 1 # 计数器+1

            print("Layers that will be pruned", num_layers_pruned) # 输出会被剪枝的所有层和卷积核数量
            print("..............Pruning filters............. ")
            if use_cuda:
                self.model.cpu()

            for layer_index, filter_index in prune_targets: # 一个卷积核一个卷积核的裁剪
                prune_conv_layer(self.model, layer_index, filter_index) # 进行卷积核裁剪

            if use_cuda: # 将模型搬到cuda上
                self.model.cuda()
                # self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
                # cudnn.benchmark = True

            optimizer = torch.optim.SGD(self.model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

            # 每剪完一次就会进行微调和评估一次
            print("%d / %d Filters remain." % (pruner.total_num_filters(conv_index), number_of_filters))
            # test()
            print("Fine tuning to recover from pruning iteration.") # 剪枝完之后进行微调,训练两轮
            for epoch in range(2):
                train(optimizer, net=self.model)
            acc = test() # 测试准确度
            pass
            if acc <= acc_pre_prune * acc_thre:
                pass

        print("Finished. Going to fine tune the model a bit more")
        for epoch in range(5):
            train(optimizer)
        test()
        pass

if __name__ == '__main__':
    args = get_args() # 获取配置参数
    # prune = FilterPruner(net)
    # print(prune.total_num_filters(-1))
    # exit()

    # 构建模型
    if args.train:
        # 直接构建模型
        print('==> Building model..')
    else:
        # Load checkpoint.
        # 从checkpoint处加载模型
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.train')
        net = checkpoint['net']
        acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1

    # 判断是否适用cuda,是否使用多GPU训练
    if use_cuda:
        net.cuda()
        # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        # cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss() # 交叉熵
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) # 随机梯度下降损失函数

    # torch.nn.DataParallel是用于多卡/多GPU训练的,构建FilterPruner,传入模型
    pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) # 
    total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1)

    if args.prune:
        # 模型剪枝,step 1
        pruner.prune()
        acc = test() # 剪完之后测试
        save(acc) # 保存模型
        pass
    elif args.prune_layer:
        # 层剪枝,step 2
        # this is after --prune the whole model
        conv_index_max = pruner.get_conv_index_max() # 获取最后一个卷积层的索引
        for conv_index in range(conv_index_max):
            print('==> Resuming from checkpoint..')
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            checkpoint = torch.load('./checkpoint/ckpt.prune') # 加载step1保存的模型权重文件和准确率
            net = checkpoint['net'] # 模型+权重
            acc = checkpoint['acc'] # 准确率
            if use_cuda: # 将模型搬到cuda上
                net.cuda()
                # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                # cudnn.benchmark = True
            # create new pruner in each iteration
            pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) # 构建FilterPruner实例
            total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1)
            # prune given layer
            pruner.prune(conv_index) # 一层一层的剪枝,剪第conv_index层
            acc = test() # 测试函数
            save(acc, conv_index) # 保存模型
            pass
    elif args.train or args.resume:
        # 训练或者从检查点恢复
        # 从start_epoch到epoch
        for epoch in range(start_epoch, start_epoch + args.epoch):
            print('\nEpoch: %d' % epoch)
            train() # 训练
            acc = test() # 测试
            if epoch % 10 == 0:
                save(acc, -1, epoch) # 每隔十轮进行一次保存
                pass
        save(acc)
    elif args.test_pruned:
        # 测试剪枝后的模型
        use_cuda = 0
        cfg = pruner.get_cfg() # 获取卷积层输出通道和池化层
        conv_index_max = pruner.get_conv_index_max() # 获取最后一层的卷积层的索引
        original_data = [] # 原始数据
        prune_data = [] # step1数据
        prune_layer_data = [] # step2数据

        last_conv_index = 0  # log for checkpoint restoring, nearest conv layer, 检查点还原的日志,最近的 conv 层
        for index in range(len(cfg)):
            # original 原始模型, 从检查点开始恢复
            print('==> Resuming from checkpoint..')
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            checkpoint = torch.load('./checkpoint/ckpt.train')
            net = checkpoint['net'] # 获取网络模型
            acc = checkpoint['acc'] # 获取模型准确度
            if use_cuda: # 将模型搬到cuda上
                net.cuda()
                # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                # cudnn.benchmark = True
            # create new pruner in each iteration
            pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net) # 构建FilterPruner实例
            total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) # 获取卷积核总数
            data = test(index) # 测试原始模型的准确度
            if data['acc'] == -1:
                data['acc'] = acc
            original_data.append(data) # 原始模型的准确度加入到数组中

            # prune,恢复剪枝后的模型
            print('==> Resuming from checkpoint..')
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            checkpoint = torch.load('./checkpoint/ckpt.prune')
            net = checkpoint['net']
            acc = checkpoint['acc']
            if use_cuda:
                net.cuda()
                # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                # cudnn.benchmark = True
            # create new pruner in each iteration
            pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net)
            total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1) # 剪枝后的卷积核总数
            data = test(index)
            if data['acc'] == -1:
                data['acc'] = acc
            prune_data.append(data)

            # prune_layer 从step2进行恢复
            print('==> Resuming from checkpoint..')
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            checkpoint = torch.load('./checkpoint/ckpt.prune_layer_' + str(last_conv_index))
            # checkpoint = torch.load('./checkpoint/ckpt.prune')
            net = checkpoint['net']
            acc = checkpoint['acc']
            if use_cuda:
                net.cuda()
                # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
                # cudnn.benchmark = True
            # create new pruner in each iteration
            pruner = FilterPruner(net.module if isinstance(net, torch.nn.DataParallel) else net)
            total_filter_num_pre_prune = pruner.total_num_filters(conv_index=-1)
            data = test(index)
            if data['acc'] == -1:
                data['acc'] = acc
                prune_layer_data.append(data)

            if index + 1 < len(cfg):
                if not isinstance(cfg[index + 1], str):
                    last_conv_index += 1

        # 将原模型、step1、step2的信息写入json文件
        with open('./log_original.json', 'w') as fp:
            json.dump(original_data, fp, indent=2) 
        with open('./log_prune.json', 'w') as fp:
            json.dump(prune_data, fp, indent=2)
        with open('./log_prune_layer.json', 'w') as fp:
            json.dump(prune_layer_data, fp, indent=2)

  • model_refactor.py
import torch
from torch.autograd import Variable
from torchvision import models
import sys
import numpy as np
import os
import time


def replace_layers(model, i, indexes, layers):
    if i in indexes:
        return layers[indexes.index(i)]
    return model[i]


def prune_conv_layer(model, layer_index, filter_index):
    _, conv = list(model.features._modules.items())[layer_index] # 选出所有需要被裁减的卷积层
    batchnorm = None 
    next_conv = None 
    offset = 1 

    while layer_index + offset < len(list(model.features._modules.items())):  # get next conv
        res = list(model.features._modules.items())[layer_index + offset]
        if isinstance(res[1], torch.nn.modules.conv.Conv2d):
            _, next_conv = res
            break
        offset = offset + 1

    if conv.in_channels % conv.groups != 0 or (conv.out_channels-1) % conv.groups != 0:
        return
    
    if next_conv is not None:
        if conv.in_channels % conv.groups != 0 or (conv.out_channels-1) % conv.groups != 0 or \
            (next_conv.in_channels-1) % next_conv.groups != 0 or next_conv.out_channels % next_conv.groups != 0:
            return 

    res = list(model.features._modules.items())[layer_index + 1]
    if isinstance(res[1], torch.nn.modules.BatchNorm2d):
        _, batchnorm = res

    is_bias_present = False
    if conv.bias is not None:
        is_bias_present = True

    new_conv = \
        torch.nn.Conv2d(in_channels=conv.in_channels,
                        out_channels=conv.out_channels - 1,
                        kernel_size=conv.kernel_size,
                        stride=conv.stride,
                        padding=conv.padding,
                        dilation=conv.dilation,
                        groups=conv.groups,
                        bias=is_bias_present)

    old_weights = conv.weight.data.cpu().numpy()
    new_weights = new_conv.weight.data.cpu().numpy()

    new_weights[: filter_index, :, :, :] = old_weights[: filter_index, :, :, :]
    new_weights[filter_index:, :, :, :] = old_weights[filter_index + 1:, :, :, :]
    new_conv.weight.data = torch.from_numpy(new_weights).cuda()

    if is_bias_present:
        bias_numpy = conv.bias.data.cpu().numpy()

        bias = np.zeros(shape=(bias_numpy.shape[0] - 1), dtype=np.float32)
        bias[:filter_index] = bias_numpy[:filter_index]
        bias[filter_index:] = bias_numpy[filter_index + 1:]
        new_conv.bias.data = torch.from_numpy(bias).cuda()

    if next_conv is not None:
        is_bias_present = False
        if next_conv.bias is not None:
            is_bias_present = True
        next_new_conv = \
            torch.nn.Conv2d(in_channels=next_conv.in_channels - 1,
                            out_channels=next_conv.out_channels,
                            kernel_size=next_conv.kernel_size,
                            stride=next_conv.stride,
                            padding=next_conv.padding,
                            dilation=next_conv.dilation,
                            groups=next_conv.groups,
                            bias=is_bias_present)

        old_weights = next_conv.weight.data.cpu().numpy()
        new_weights = next_new_conv.weight.data.cpu().numpy()

        new_weights[:, : filter_index, :, :] = old_weights[:, : filter_index, :, :]
        new_weights[:, filter_index:, :, :] = old_weights[:, filter_index + 1:, :, :]
        next_new_conv.weight.data = torch.from_numpy(new_weights).cuda()
        if is_bias_present: next_new_conv.bias.data = next_conv.bias.data

    if batchnorm is not None:
        new_batchnorm = \
            torch.nn.BatchNorm2d(conv.out_channels - 1)

        try:
            old_weights = batchnorm.weight.data.cpu().numpy()
            new_weights = new_batchnorm.weight.data.cpu().numpy()
            new_weights[:filter_index] = old_weights[:filter_index]
            new_weights[filter_index:] = old_weights[filter_index + 1:]
            new_batchnorm.weight.data = torch.from_numpy(new_weights).cuda()

            bias_numpy = batchnorm.bias.data.cpu().numpy()
            bias = np.zeros(shape=(bias_numpy.shape[0] - 1), dtype=np.float32)
            bias[:filter_index] = bias_numpy[:filter_index]
            bias[filter_index:] = bias_numpy[filter_index + 1:]
            new_batchnorm.bias.data = torch.from_numpy(bias).cuda()
        except ValueError:
            pass


    if batchnorm is not None:
        features = torch.nn.Sequential(
            *(replace_layers(model.features, i, [layer_index + 1],
                         [new_batchnorm]) for i, _ in enumerate(model.features)))
        del model.features
        model.features = features


    if next_conv is not None:
        features = torch.nn.Sequential(
                *(replace_layers(model.features, i, [layer_index, layer_index + offset],
                                 [new_conv, next_new_conv]) for i, _ in enumerate(model.features)))

        del model.features
        del conv
        model.features = features

    else:
        # Prunning the last conv layer. This affects the first linear layer of the classifier.
        model.features = torch.nn.Sequential(
            *(replace_layers(model.features, i, [layer_index],
                             [new_conv]) for i, _ in enumerate(model.features)))
        layer_index = 0
        old_linear_layer = None
        one_layer_classifier = False
        for _, module in list(model.classifier._modules.items()):
            if isinstance(module, torch.nn.Linear):
                old_linear_layer = module
                break
            layer_index = layer_index + 1

        if isinstance(model.classifier, torch.nn.Linear):
            old_linear_layer = model.classifier
            one_layer_classifier = True
            layer_index = layer_index + 1

        if old_linear_layer is None:
            raise BaseException("No linear layer found in classifier")
        params_per_input_channel = round(old_linear_layer.in_features / conv.out_channels)

        new_linear_layer = \
            torch.nn.Linear(old_linear_layer.in_features - params_per_input_channel,
                            old_linear_layer.out_features)

        old_weights = old_linear_layer.weight.data.cpu().numpy()
        new_weights = new_linear_layer.weight.data.cpu().numpy()

        new_weights[:, : filter_index * params_per_input_channel] = \
            old_weights[:, : filter_index * params_per_input_channel]
        new_weights[:, filter_index * params_per_input_channel:] = \
            old_weights[:, (filter_index + 1) * params_per_input_channel:]

        new_linear_layer.bias.data = old_linear_layer.bias.data

        new_linear_layer.weight.data = torch.from_numpy(new_weights).cuda()

        if one_layer_classifier:
            classifier = new_linear_layer
        else:
            classifier = torch.nn.Sequential(
                *(replace_layers(model.classifier, i, [layer_index],
                                 [new_linear_layer]) for i, _ in enumerate(model.classifier)))

        del model.classifier
        del next_conv
        del conv
        model.classifier = classifier

    return # model


https://www.xamrdz.com/backend/3x71896423.html

相关文章: