剪枝器需要两个实参:
- 要剪枝的模型
- 模型推断所需的输入
import torch
from pytorch_nndct import SparsePruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = SparsePruner(model, inputs)
剪枝器需要两个实参:
import torch
from pytorch_nndct import SparsePruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = SparsePruner(model, inputs)