tf_nndct1.IterativePruningRunner
__init__(self,
model_name: str,
sess: SessionInterface,
input_specs: Mapping[str, tf.TensorSpec],
output_node_names: List[str],
excludes: List[str]=[])
Arguments:
- model_name
- The name of the model.
- sess
- An instance of a TensorFlow session containing a graph and initialized variables.
- input_specs
- The keys of this mapping are the input node names of the baseline model. The values are tf.TensorSpec objects describing the shape and dtype of each input.
- output_node_names
- Target output node names.
- excludes
- Names of nodes to exclude from pruning.
Returns: An instance of IterativePruningRunner.
ana(self,
eval_fn: Callable[[tf.compat.v1.GraphDef], float],
gpu_ids: List[str]=['/GPU:0'],
checkpoint_interval: int = 10) -> None:
Arguments:
- eval_fn
- A function that evaluates the intermediate results of the pruning process. It takes a tf.compat.v1.GraphDef and must return a float.
- gpu_ids
- A list of strings naming the devices on which to run evaluations.
- checkpoint_interval
- This method caches its results and saves them every checkpoint_interval evaluations.
Returns: None
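eval_fn only has to map a candidate graph to a single float score. A minimal sketch of such a function follows, assuming you already have some way to run inference on a GraphDef; `make_eval_fn`, `run_inference`, and the labeled batches are hypothetical placeholders, not part of the pruning API. In practice, `run_inference` would import the GraphDef into a session and fetch the output nodes.

```python
from typing import Any, Callable, Iterable, List, Sequence, Tuple

# Hypothetical: runs a graph on one batch of inputs and returns predicted class ids.
InferenceFn = Callable[[Any, Sequence[Any]], List[int]]


def make_eval_fn(
    run_inference: InferenceFn,
    batches: Iterable[Tuple[Sequence[Any], Sequence[int]]],
) -> Callable[[Any], float]:
    """Build an eval_fn that returns top-1 accuracy in [0, 1] as a float."""
    batches = list(batches)  # materialize so eval_fn can be called many times

    def eval_fn(graph_def: Any) -> float:
        correct = 0
        total = 0
        for inputs, labels in batches:
            predictions = run_inference(graph_def, inputs)
            correct += sum(int(p == y) for p, y in zip(predictions, labels))
            total += len(labels)
        return correct / total if total else 0.0

    return eval_fn


# Toy usage: the "model" predicts x % 2, standing in for real inference.
toy_eval = make_eval_fn(lambda graph_def, xs: [x % 2 for x in xs],
                        [([1, 2, 3], [1, 0, 0]), ([4, 5], [0, 1])])
print(toy_eval(None))  # 0.8
```

Returning a plain Python float (rather than a tensor) keeps the function compatible with the `Callable[[tf.compat.v1.GraphDef], float]` annotation.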
prune(self,
sparsity: float=None,
threshold: float=None,
max_attemp: int=10) -> Tuple[Mapping[str, TensorProto], Mapping[str, np.ndarray]]:
There are two pruning modes, FLOPs-based and accuracy-based, selected by the sparsity and threshold arguments respectively. At least one of the two must be set; they cannot both be None. If both are set, sparsity takes precedence over threshold.
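The precedence rule can be restated as a small helper. `resolve_pruning_mode` is hypothetical and not part of the API; it only encodes the rules in the paragraph above.

```python
from typing import Optional, Tuple


def resolve_pruning_mode(
    sparsity: Optional[float] = None,
    threshold: Optional[float] = None,
) -> Tuple[str, float]:
    """Return ("flops", sparsity) or ("accuracy", threshold) per the documented rules."""
    if sparsity is None and threshold is None:
        raise ValueError("At least one of sparsity or threshold must be set.")
    if sparsity is not None:
        # sparsity takes precedence when both are given.
        return ("flops", sparsity)
    return ("accuracy", threshold)


print(resolve_pruning_mode(sparsity=0.5, threshold=0.01))  # ('flops', 0.5)
print(resolve_pruning_mode(threshold=0.01))                # ('accuracy', 0.01)
```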
Arguments:
- sparsity
- The fraction by which to reduce the floating-point operations (FLOPs) of the model's forward pass.
- threshold
- A value in the range [0, 1] indicating the maximum acceptable relative difference in accuracy between the pruned graph and the original graph.
- max_attemp
- The pruning runner searches iteratively for the optimal pruning strategy and always returns after at most max_attemp steps.
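To make the two targets concrete, the sketch below computes the quantity each argument refers to: the fraction of forward-pass FLOPs removed (sparsity) and the relative accuracy difference (threshold). It also illustrates a search capped at max_attemp steps by bisecting over a pruning ratio. The bisection strategy, the helper names, and the `evaluate_at_ratio` callback are assumptions made for illustration only; the runner's actual search algorithm is not described in this section.

```python
from typing import Callable


def flops_reduction(base_flops: float, pruned_flops: float) -> float:
    """Fraction of forward-pass FLOPs removed by pruning (what `sparsity` targets)."""
    return 1.0 - pruned_flops / base_flops


def relative_accuracy_diff(base_acc: float, pruned_acc: float) -> float:
    """Relative accuracy difference (what `threshold` bounds); assumes base_acc > 0."""
    return abs(base_acc - pruned_acc) / base_acc


def search_ratio(
    evaluate_at_ratio: Callable[[float], float],  # hypothetical: accuracy at a pruning ratio
    base_acc: float,
    threshold: float,
    max_attemp: int = 10,
) -> float:
    """Bisect for the largest pruning ratio whose accuracy stays within `threshold`.

    Illustrative only: stops after `max_attemp` evaluations even if not converged.
    """
    lo, hi = 0.0, 1.0  # lo is always acceptable; hi bounds the search from above
    for _ in range(max_attemp):
        mid = (lo + hi) / 2
        if relative_accuracy_diff(base_acc, evaluate_at_ratio(mid)) <= threshold:
            lo = mid  # acceptable: try pruning more
        else:
            hi = mid  # too lossy: prune less
    return lo


print(flops_reduction(4.0e9, 1.0e9))  # 0.75
```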
Returns:
- shape_tensors
- A string-to-TensorProto mapping. The keys are the names of the node_defs in graph_def that must be updated to obtain a slim graph; the values are the target contents for those nodes.
- masks
- A string-to-array (np.ndarray) mapping of masks corresponding to the model's variables.
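The masks mapping pairs variables with NumPy arrays. Assuming each mask is a 0/1 array with the same shape as its variable and keyed by the variable's name (an assumption; the mask format is not specified here), the sketch below applies the masks to a dictionary of weights and counts the surviving entries. `apply_masks` is hypothetical; get_slim_graph_def is the documented way to consume these results.

```python
from typing import Dict

import numpy as np


def apply_masks(
    weights: Dict[str, np.ndarray],
    masks: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
    """Zero out masked entries; variables without a mask are returned unchanged."""
    masked = {}
    for name, value in weights.items():
        mask = masks.get(name)
        if mask is None:
            masked[name] = value
            continue
        if mask.shape != value.shape:
            raise ValueError(
                f"mask shape {mask.shape} != variable shape {value.shape} for {name}")
        masked[name] = value * mask  # elementwise: keeps entries where mask == 1
    return masked


weights = {"conv1/kernel": np.ones((2, 3)), "dense/bias": np.arange(3.0)}
masks = {"conv1/kernel": np.array([[1, 0, 1], [1, 0, 1]])}
pruned = apply_masks(weights, masks)
print(int(np.count_nonzero(pruned["conv1/kernel"])))  # 4
```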
get_slim_graph_def(self,
shape_tensors: Mapping[str, TensorProto]=None,
masks: Mapping[str, np.ndarray]=None) -> tf.compat.v1.GraphDef:
Arguments:
- shape_tensors
- The string-to-TensorProto mapping returned by the prune method.
- masks
- The string-to-array mapping of masks corresponding to variables, also returned by the prune method.
Returns: A frozen slim graph_def.