プルーナーには、次の 2 つの引数が必要です。
- プルーニングするモデル
- モデルによる推論に必要な入力
import torch
from pytorch_nndct import SparsePruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = SparsePruner(model, inputs)
プルーナーには、次の 2 つの引数が必要です。
import torch
from pytorch_nndct import SparsePruner
inputs = torch.randn([1, 3, 224, 224], dtype=torch.float32)
pruner = SparsePruner(model, inputs)