调用 sparse_model()
以获取稀疏模型。此方法会查找所有 nn.Conv2d
/ nn.ConvTranspose2d
和 nn.BatchNorm2d
模块,并将这些模块替换为 DynamicConv2d
/ DynamicConvTranspose2d
和 DynamicBatchNorm2d
。此方法会将满足稀疏度条件的 nn.Conv2d
/ nn.
linear
层替换为 SparseConv2d
/ SparseLinear
。
此方法支持 nn.Conv2d
/ nn. linear
权重和激活同时执行剪枝。激活的稀疏度可为 0 或 0.5。当激活的稀疏度为 0 时,权重的稀疏度可设为 0、0.5 或 0.75。当激活的稀疏度为 0.5 时,权重的稀疏度只能设为 0.75。block_size 是输入通道的连续元素数。通道根据权重/激活展开。它设为 4、8 或 16。因此,如果卷积的输入通道权重大于 16,就会被替换为稀疏卷积。
sparse_model = sparse_pruner.sparse_model(w_sparsity=0.5,a_sparsity=0,block_size=4)
重新训练稀疏模型与创建基线模型并无不同。知识蒸馏可以实现更好的精度。