此 API 具有下列方法:
-
__init__(model, inputs)
- model
- 要进行剪枝的
torch.nn.Module
对象。 - inputs
- 单个 torch 或 torch 列表。张量用作为模型推断的输入。输入无需采用真实数据。可以采用随机生成的张量,与真实数据的形状和数据类型相同即可。
-
ofa_model(expand_ratio, channel_divisble=8, excludes=None, auto_add_excludes=True, save_search_space=False)
- expand_ratio
- 每个卷积层的剪枝率列表。OFA 模型中每个卷积层的输出通道均可使用任意的剪枝率。
此列表中的最大值和最小值分别表示该模型的最大和最小压缩率。其他值则表示要最优化的子网络。剪枝率默认设为 [0.5, 0.75, 0.1]。
- channel_divisible
- 可除以给定除数的通道数量。
- excludes
- 要从剪枝中排除的模块列表。
- auto_add_excludes
- 布尔值。如果该值为 True,那么此方法会自动识别第一个卷积和最后一个卷积,并将其置于排除列表中。如果为 False,则跳过创建排除列表。默认值为 True。
- save_search_space
- 布尔值。如果该值为 True,则将模型的搜索空间保存为“searchspace.config”文件。您可以检查每个层的搜索空间。默认值为 False。
-
sample_subnet(model, mode)
返回子网络及其给定模式的配置。该子网络可以使用来自 OFA 模型及其设置的部分权重执行前向/后向传递进程。
- model
- OFA 模型。
- mode
- 下列值之一:['random', 'max', 'min']。
-
reset_bn_running_stats_for_calibration(model)
复位 Batch Normalization 层的运行统计数据。
- model
- OFA 模型。
-
run_evolutionary_search(model, calibration_fn, calib_args, eval_fn, eval_args, evaluation_metric, min_or_max_metric, min_macs, max_macs, macs_step=10, parent_popu_size=16, iteration=10, mutate_size=8, mutate_prob=0.2, crossover_size=4)
运行进化搜索,查找最佳子网络,该子网络的 MAC 在给定范围内。
- model
- OFA 模型。
- calibration_fn
- BatchNormalization 校准函数。所有子网络共享 OFA 模型中的权重,但在训练 OFA 模型时不存储批量归一化统计数据(平均值和方差)。训练完成后,必须使用训练数据为用于评估的每个已采样的子网络重新校准批量归一化统计数据。
- calib_args
- calibration_fn 的实参。
- eval_fn
- 用于评估模型的函数。
- eval_args
- eval_fn 的实参。
- evaluation_metric
- 用于记录结果的 evaluation_metric 字符串。
- min_or_max_metric
- 下列值之一:['max', 'min']。要记录在进化搜索中的评估指标的最大值或最小值。例如,如果评估指标精度为 top1,则在进化搜索中记录每次迭代的最大值。但如果评估指标为平均平方误差 (mse) 或平均绝对误差 (mae),则记录最小值。
- min_macs
- 已搜索的子网络的最小 MAC 数。
- max_macs
- 已搜索的子网络的最大 MAC 数。
- macs_step
- 搜索 MAC 的步骤。将间隔 [min_macs, max macs] 除以 macs_step 即可对其进行分割。对于每个分割段,搜索已达成 MAC 数与精度的最佳平衡的子网络。
- parent_popu_size
- 对于所含 MAC 数在给定范围内的随机子网络,此项表示对其进行采样时的初始父填充数。该数值越大,搜索时间越长,并且获取最佳结果的可能性越高。
- iteration
- 搜索的迭代次数或者整个算法的周期数。
- mutate_size
- 突变的大小。子网络设置的每个值都被替换为候选值列表的另一个值(概率为 mutate_prob)。
- mutate_prob
- 突变的概率。
- crossover_size
- 交叉的大小。对两项子网络设置进行采样并对这两项子网络设置中的任意值进行随机交换。
-
save_subnet_config(setting_config, file_name)
利用 JSON 保存动态/静态子网络设置。
- setting_config
- 动态子网络设置的配置。
- file_name
- 用于保存子网络设置的文件路径。
-
load_subnet_config(file_name)
- file_name
- 用于加载子网络设置的文件路径。