tf_nndct1.IterativePruningRunner
__init__(self,
model_name: str,
sess: SessionInterface,
input_specs: Mapping[str, tf.TensorSpec],
output_node_names: List[str],
excludes: List[str]=[])
实参:
- model_name
- 模型名称。
- sess
- TensorFlow 会话的实例包含计算图和初始化的变量。
- input_specs
- 此映射的键是基线模型的输入节点名称。
- output_node_names
- 目标输出节点名称。
- excludes
- 跳过剪枝的节点名称。
返回:IterativePruningRunner 的实例
ana(self,
eval_fn: Callable[[tf.compat.v1.GraphDef], float],
gpu_ids: List[str]=['/GPU:0'],
checkpoint_interval: int = 10) -> None:
实参
- eval_fn
- 此函数用于评估剪枝进程的中间结果。需返回浮点。
- gpu_ids
- 该字符串列表表示要运行评估的器件。
- checkpoint_interval
- 此方法会实现高速缓存机制,并保存每次执行 checkpoint_interval 评估的结果。
返回:无
prune(self,
sparsity: float=None,
threshold: float=None,
max_attemp: int=10) -> Tuple[Mapping[str, TensorProto], Mapping[str, np.ndarray]]:
有两种剪枝模式:基于 FLOP 的方法和基于精度的方法,对应实参稀疏度和阈值。这两个实参 不得 同时设为 None。实参稀疏度的优先级高于阈值。
实参:
- sparsity
- 此比率表示前向传递中模型的浮点计算量的缩减情况。
- threshold
- 范围为 [0, 1]。表示剪枝后的计算图与原始计算图之间最大可接受的相对差值。
- max_attemp
- 剪枝运行程序会以迭代方式查找最优剪枝策略,执行 max_attemp 个步骤后无论如何都会返回结果。
返回:
- shape_tensors
- 该字符串表示 NodeDef 映射。键为 graph_def 中 node_def 的名称,需更新后方可获取精简计算图。值为目标 node_def 内容掩码。
- masks
- 表示对应于变量的字符串到阵列映射。
get_slim_graph_def(self,
shape_tensors: Mapping[str, TensorProto]=None,
masks: Mapping[str, np.ndarray]=None) -> tf.compat.v1.GraphDef:
实参:
- shape_tensors
- 从 prune 方法返回的字符串到 NodeDef 映射。
- masks
- 表示对应于变量的字符串到阵列映射。该对象是从 prune 方法获取的。
返回:冻结的精简 graph_def。