模型压缩
DP 模型的 Embedding net 网络数目是原子类型数目N的N2倍,随着原子类型增多,Embedding net 数目会快速增加,导致用于反向传播求导的计算图的规模会增加,成为 DP 模型做推理的瓶颈之一。如下我们对于一个五元合金系统在 DP 模型的推理过程的时间统计所示,对于 Embedding net 计算以及梯度计算的时间占比超过 90%,这存在大量的优化空间。Embedding net 的输入为一个Sij的单值,输出为m个值(m为 Embedding net 最后一层神经元数目)。因此,可以将 Embedding net 通过m个单值函数代替。我们在这里实现论文DP Compress中使用的五阶多项式压缩方法,同时我们也提供了基于 Hermite 插值方法的三阶多项式压缩方法供用户自由选择。在我们的测试中,当网格大小 dx=0.001 时,三阶多项式与五阶多项式能够达到相同的精度,详细测试数据见性能测试。
![proportion_time](/assets/images/proportion_time_inference-25985d56f499de5e8f3bc35fd68e53fe.png)
对于一个训练后 DP 模型做模型压缩,完整的模型压缩指令如下:
PWMLFF compress dp_model.ckpt -d 0.01 -o 3 -s cmp_dp_model
- compress 是压缩命令
- dp_model.ckpt为待压缩模型文件名称,为必须要提供的参数
- -d 为S_ij 的网格划分大小,默认值为0.01
- -o 为模型压缩阶数,3为三阶模型压缩,5为五阶模型压缩,默认值为3
- -s 为压缩后的模型名称,默认名称为“cmp_dp_model”
模型压缩之后,在 lammps 中做分子动力学模拟使用方式与标准的DP 模型相同。
我们在 Bulk 铜和五元合金体系上对 DP 模型 做了模型压缩,并在测试集上分别做了测试。结果如下图中所示,对于铜体系,我们加入了对二阶插值方法的精度对比,相比于三阶和五阶方法,二阶方法的精度达不到要求。
![cu_compress_dp_valid_abs_error](/assets/images/cu_compress_dp_valid_abs_error-8cb76d7bc3d8a93a2ca812bf8b60bcb9.png) 图1: Bulk铜体系DP模型二阶、三阶与五阶多项式压缩对比 | ![alloy_compress_dp_valid_abs_error](/assets/images/alloy_compress_dp_valid_abs_error-fdf339f257dedd29f312e29e63cc0cfc.png) 图2: 五元合金体系DP模型三阶与五阶多项式压缩对比 |
我们统计了五元合金体系下 DP 模型三阶多项式压缩以及未压缩时,在整个测试集上的推理时间。经过多项式压缩后明显减少了反向求导(autograd)时间,这是因为多项式方法能够显著减少 Embedding net 在 pytorch 自动求导时的计算图大小。
![alloy_compress_forward_time]() 图1: 五元合金体系三阶多项式压缩(dx=0.01)与未压缩对比 |
我们扫描全部训练集,得到sij的最大值,由于sij是原子i和j的三维坐标距离rij函数,当rij = rcut时取最小值。根据sij取值范围按照dx值等分为L份,则共有l+1个插值点,分别记为x1,x2,⋯,xl+1。在实际的使用中,由于训练集的不完备,可能存在一些sij值超出训练集之外,这里我们在上述网格之外,继续增加了sij到10×sij的网格,网格大小设置为10×dx。
对于每个[xl,xl+1)区间,采用如下的三阶多项式替代 Embedding net:
gml(x)=amlx3+bmlx2+cmlx+dml
这里m为 Embedding net 最后一层神经元数量,即 Embedding net 输出值数目,多项式的自变量x值应为sij−xl。在每个网格点上,都需要满足如下两个限定条件。
在每个网格点上限制如下条件。
多项式值与 Embedding net 输出值一致:
yl=Gm(xl)
多项式一阶导数与 Embedding net 对Sij的一阶导一致:
yl′=Gm′(xl)
解得对应系数为
aml=Δt31[(yl+1′+yl′)Δt−2h]
bml=Δt21[−(yl+1′+2yl′)Δt+3h]
cml=yl′
dml=yl
我们也实现了DP Compress中的五阶多项式压缩方法。
对于五阶多项式,对Sij的划分方法与五阶方法相同,采用如下的多项式代替 Embedding net:
gml(x)=amlx5+bmlx4+cmlx3+dmlx2+emlx+fml
注意:此时多项式的自变量x值应为sij−xl。在每个网格点上,都需要满足如下三个限定条件。
多项式值与 Embedding net 输出值一致:
yl=Gm(xl)
多项式一阶导数与 Embedding net 对Sij的一阶导一致:
yl′=Gm′(xl)
多项式二阶导数与 Embedding net 对Sij的二阶导一致:
yl′′=Gm′′(xl)
由此可得六个系数值分别为:
aml=2Δt51[12h−6(yl+1′+yl′)Δt+(yl+1′′−yl′′)Δt2]
bml=2Δt41[−30h+(14yl+1′+16yl′)Δt+(−2yl+1′′+3yl′′)Δt2]
cml=2Δt31[20h−(8yl+1′+12yl′)Δt+(yl+1′′−3yl′′)Δt2]
dml=21yl′′
eml=yl′
fml=yl
其中 h=yl+1−yl,Δt=xl+1−xl