vai_p_tensorflow API - 3.5 简体中文

Vitis AI 用户指南 (UG1414)

Document ID
UG1414
Release Date
2023-09-28
Version
3.5 简体中文

tf_nndct1.IterativePruningRunner

__init__(self,
         model_name: str, 
         sess: SessionInterface, 
         input_specs: Mapping[str, tf.TensorSpec], 
         output_node_names: List[str], 
         excludes: List[str]=[])

实参:

model_name
 模型名称。
sess
TensorFlow 会话的实例包含计算图和初始化的变量。
input_specs
 此映射的键是基线模型的输入节点名称。
output_node_names
目标输出节点名称。
excludes
跳过剪枝的节点名称。

返回:IterativePruningRunner 的实例

ana(self,
    eval_fn: Callable[[tf.compat.v1.GraphDef], float], 
    gpu_ids: List[str]=['/GPU:0'], 
    checkpoint_interval: int = 10) -> None:

实参

eval_fn
此函数用于评估剪枝进程的中间结果。需返回浮点。
gpu_ids
该字符串列表表示要运行评估的器件。
checkpoint_interval
此方法会实现高速缓存机制,并保存每次执行 checkpoint_interval 评估的结果。

返回:无

prune(self,
      sparsity: float=None, 
      threshold: float=None, 
      max_attemp: int=10) -> Tuple[Mapping[str, TensorProto], Mapping[str, np.ndarray]]:

有两种剪枝模式:基于 FLOP 的方法和基于精度的方法,对应实参稀疏度和阈值。这两个实参 不得 同时设为 None。实参稀疏度的优先级高于阈值。

实参:

sparsity
此比率表示前向传递中模型的浮点计算量的缩减情况。
threshold
范围为 [0, 1]。表示剪枝后的计算图与原始计算图之间最大可接受的相对差值。
max_attemp
剪枝运行程序会以迭代方式查找最优剪枝策略,执行 max_attemp 个步骤后无论如何都会返回结果。

返回:

shape_tensors
该字符串表示 NodeDef 映射。键为 graph_def 中 node_def 的名称,需更新后方可获取精简计算图。值为目标 node_def 内容掩码。
masks
表示对应于变量的字符串到阵列映射。
get_slim_graph_def(self, 
                   shape_tensors: Mapping[str, TensorProto]=None, 
                   masks: Mapping[str, np.ndarray]=None) -> tf.compat.v1.GraphDef:

实参:

shape_tensors
prune 方法返回的字符串到 NodeDef 映射。
masks
表示对应于变量的字符串到阵列映射。该对象是从 prune 方法获取的。

返回:冻结的精简 graph_def。