SNIP: SINGLE-SHOT NETWORK PRUNING BASED ON CONNECTION SENSITIVITY
Contact me
- Blog -> https://cugtyt.github.io/blog/index
- Email -> cugtyt@qq.com
- GitHub -> Cugtyt@GitHub
本系列博客主页及相关见此处
ABSTRACT
现有的剪枝方法是在迭代中,通过启发式剪枝策略或者增加超参数完成,削弱了可用性。我们展示了一个新方法,可以在训练前的初始化阶段一次完成剪枝。我们引入了基于连接敏感的显著性指标来定位当前任务中的重要连接。减少了预训练和剪枝策略的复杂性,对于不同的网络结构也比较鲁棒。在剪枝后,稀疏的网络按照常规训练。我们的方法可以获得极为稀疏的网络,同时准确率得到了保持。我们的方法证明了保留的连接与给定的任务相关。
1 INTRODUCTION
由于剪枝是迭代优化的,现有的方法需要高昂的剪枝-再训练的循环,需要额外的超参数,实用性不好。
我们引入了显著性指标在训练前定位重要的连接,它是依赖于数据的。我们把不同尺度初始化下重要连接对损失函数的影响称为连接显著度。
我们的方法:
- Simplicity 在训练前做一次即可,没有额外的参数。
- Versatility 由于显著性指标选择的是结构性重要连接,所以可以用于不同的网络结构上。
- Interpretability 我们的方法在一个小批量上就可以确定那些重要的连接。通过变化剪枝需要的小批量,我们的方法可以验证保留的连接的确是重要的。
2 RELATED WORK
【略】
3 NEURAL NETWORK PRUNING
目标是找到一个稀疏的网络但是准确率要保留。
给定一个数据集$\mathcal{D}=\left\{\left(\mathbf{x}_{i}, \mathbf{y}_{i}\right)\right\}_{i=1}^{n}$
,和一个稀疏级别$\kappa$(非0的权重),可以得到如下优化问题:
$\ell(\cdot)$是损失函数,$\mathbf{W}$是参数,$m$是参数数量,$\vert\cdot \vert_{0}$
是$L_{0}$
正则。
传统的优化以上问题的方法是加入一个稀疏惩罚项,最近有人尝试使用随机版的映射梯度下降来解决上述受限优化问题。但是这些方法不如基于显著性的方法,因为需要大量的超参数调整。
另一方面,基于显著性的方法把上述问题看做选择性的移除荣誉参数。所以必须想到一个好的指标来定位冗余的连接。流行的指标包括权重的梯度,loss对于权重的Hessian:
\[s_{j}=\left\{\begin{array}{ll}{\left|w_{j}\right|}\qquad \text{for magnitude based} \\ {\frac{w_{j}^{2} H_{j j}}{2}} {\text { or } \frac{w_{j}^{2}}{2 H_{j j}^{-1}} \qquad \text{for Hessian based} }\end{array}\right. \qquad (2)\]这里对于连接$j$而言,$\mathcal{S}_{j}$
是显著性分数,$w_{j}$
是权重,$H_{j j}$
是Hessian矩阵的值,$\mathbf{H}=\partial^{2} L / \partial \mathbf{w}^{2} \in \mathbb{R}^{m \times m}$。Hessian矩阵通常不是对角化的也不是正定的,不适用于计算大型网络。
这些指标依赖于权重的大小范围,需要预训练,而且对结构选择敏感。例如不同的归一化层对权重的范围影响不同,而且剪枝和优化步骤交替很多次导致了高昂的剪枝-再训练循环。
我们设计了一个指标可以直接衡量与数据相关的连接重要性。
4 SINGLE-SHOT NETWORK PRUNING BASED ON CONNECTION SENSITIVITY
我们希望能衡量每个连接的重要性(与权重无关),因此引入了一个指示变量$\mathbf{c} \in{0,1}^{m}$来表示参数$\mathbf{W}$的连接性。给定稀疏级别$\kappa$,公式1可以改写为:
\[\min _{\mathbf{c}, \mathbf{w}} L(\mathbf{c} \odot \mathbf{w} ; \mathcal{D})=\min _{\mathbf{c}, \mathbf{w}} \frac{1}{n} \sum_{i=1}^{n} \ell\left(\mathbf{c} \odot \mathbf{w} ;\left(\mathbf{x}_{i}, \mathbf{y}_{i}\right)\right) \\ \text { s.t. } \quad \mathbf{w} \in \mathbb{R}^{m}, \mathbf{c} \in\{0,1\}^{m}, \quad\|\mathbf{c}\|_{0} \leq \kappa \qquad (3)\]$\odot$表示Hadamard乘积。与公式1相比,我们加倍了网络要学习的参数数量,导致优化问题更加困难。但是由于我们区分了$\mathbf{W}$中要还是不要($\mathbf{c}$控制),我们也许可以通过损失函数衡量连接的重要性。
例如,$C_{j}$表示连接$j$保留$c_{j}=1$,还是减去$c_{j}=0$。移除连接$j$的影响可以通过以下来衡量:
\[\Delta L_{j}(\mathbf{w} ; \mathcal{D})=L(\mathbf{1} \odot \mathbf{w} ; \mathcal{D})-L\left(\left(\mathbf{1}-\mathbf{e}_{j}\right) \odot \mathbf{w} ; \mathcal{D}\right) \qquad (4)\]$\mathbf{e}_{j}$是指示向量的元素。
注意对每个$j \in{1 \ldots m}$计算$\Delta L_{j}$很昂贵,因为它需要m+1次前向传播。实际上,由于$\mathbf{C}$是二元的,$L$对于$\mathbf{C}$不是可微的,可以容易得到$\Delta L_{j}$尝试着去衡量损失函数对于连接$j$的影响。因此,放松$\mathbf{C}$的二元限制,$\Delta L_{j}$可以近似为$L$对于$c_j$的导数,记作$g_{j}(\mathbf{w} ; \mathcal{D})$。因此,连接$j$的影响可以写作:
\[\Delta L_{j}(\mathbf{w} ; \mathcal{D}) \approx g_{j}(\mathbf{w} ; \mathcal{D})=\left.\frac{\partial L(\mathbf{c} \odot \mathbf{w} ; \mathcal{D})}{\partial c_{j}}\right|_{\mathbf{c}=1}=\lim _{\delta \rightarrow 0}\left.\frac{L(\mathbf{c} \odot \mathbf{w} ; \mathcal{D})-L\left(\left(\mathbf{c}-\delta \mathbf{e}_{j}\right) \odot \mathbf{w} ; \mathcal{D}\right)}{\delta}\right|_{\mathbf{c}=\mathbf{1}} \qquad (5)\]【TODO】
6 DISCUSSION AND FUTURE WORK
SNIP方法简单可解释,可以在训练之前一次减去不相关的连接,可以在很多网络中使用无需修改。虽然SNIP产生了极其稀疏的模型,但是我们发现连接显著性衡量本身也很值得关注,因为它能在完全未训练的网络中定位出重要的连接。可以进一步探索。