The Pruner class requires two arguments:
- The model to be pruned
- The inference inputs
Note: It is not necessary for the input to be real data. It can be randomly
generated dummy data as long as it has the same shape and type as the real data.
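import torch
from pytorch_nndct import Pruner

# `model` is the torch.nn.Module to be pruned (defined elsewhere).
# Dummy input with the same shape and dtype as the real data (N, C, H, W).
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = Pruner(model, inputs)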
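If a real sample is at hand, a matching dummy input can be derived from it directly, so its shape and dtype cannot drift from the real data. This is a minimal sketch; the helper name `make_dummy_like` is hypothetical and is not part of pytorch_nndct.

```python
import torch


def make_dummy_like(sample: torch.Tensor) -> torch.Tensor:
    """Return a dummy tensor with the same shape and dtype as `sample`."""
    if sample.is_floating_point():
        # Random values are fine for continuous inputs such as images.
        return torch.randn_like(sample)
    # randn_like only supports floating-point dtypes; zeros are valid
    # values for integer inputs such as token ids or class indices.
    return torch.zeros_like(sample)


real_batch = torch.rand(1, 3, 224, 224)  # stand-in for one real input
dummy = make_dummy_like(real_batch)
print(dummy.shape, dummy.dtype)  # torch.Size([1, 3, 224, 224]) torch.float32
```

Only the shape and dtype are copied; the values are discarded, which matches the note that the pruner does not need real data.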
For models with multiple inputs, pass the inputs as a list or a tuple when initializing the pruner.
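As a sketch of the multi-input case, the snippet below defines a hypothetical two-input model (`TwoInputNet`, an illustration only) and builds a tuple of dummy tensors whose shapes and dtypes match what its `forward()` expects. It assumes the tuple should follow the order of the `forward()` arguments, and it runs a forward pass with the same tuple as a cheap check that the dummy inputs are valid. The `Pruner(model, inputs)` call is left commented out so the snippet runs without pytorch_nndct installed.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical model that takes an image and a feature vector."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8 + 16, 10)

    def forward(self, image, features):
        x = self.pool(torch.relu(self.conv(image))).flatten(1)  # (N, 8)
        return self.fc(torch.cat([x, features], dim=1))  # (N, 10)


model = TwoInputNet().eval()

# One dummy tensor per model input, in the order forward() takes them.
inputs = (
    torch.randn([1, 3, 224, 224], dtype=torch.float32),
    torch.randn([1, 16], dtype=torch.float32),
)

# Sanity check: the model accepts the dummy inputs as positional arguments.
with torch.no_grad():
    out = model(*inputs)
print(out.shape)  # torch.Size([1, 10])

# With pytorch_nndct installed, the tuple is passed as the second argument:
# from pytorch_nndct import Pruner
# pruner = Pruner(model, inputs)
```

A list of the same tensors would work equally well for the forward-pass check; a tuple is used here simply because the inputs are fixed once created.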