PyTorch团队重写「分割一切」模型,比原始实现快八倍( 二 )



PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
具体来说(参考上图更容易理解,出现的变量名都在代码中) , 该研究发现在 SAM 的图像编码器中,有充当坐标缩放器(coordinate scalers)的变量 q_coords 和 k_coords,这些变量都是在 CPU 上分配和处理的 。然而 , 一旦这些变量被用来在 rel_pos_resized 中建立索引,这些索引操作就会自动的将这些变量移动到 GPU 上,这种复制会导致 GPU 同步 。为了解决上述问题,该研究注意到可以使用 torch.where 重写这部分内容来解决问题,如上所示 。
内核跟踪
在应用了这些更改之后,本文注意到单个内核调用之间有着显著的时间间隔,尤其在小批量(这里为 1)时更为突出 。为了更深入的了解这一现象,本文开始对批大小为 8 的 SAM 推理进行性能分析:
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
在查看每个内核所花费的时间时,本文观察到 SAM 的大部分 GPU 时间都花费在逐元素内核(elementwise kernels)和 softmax 操作上 。
现在可以看到矩阵乘法的相对开销小了很多 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
将 GPU 同步和 bfloat16 优化结合在一起,SAM 性能提高了 3 倍 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
Torch.compile(+graph breaks 和 CUDA graphs)
本文发现在深入研究 SAM 的过程中有很多小的操作,他们认为使用编译器来融合操作有很大的好处,因而 PyTorch 对 torch.compile 做了以下优化:
  • 将 nn.LayerNorm 或 nn.GELU 等操作序列融合成一个单一的 GPU 内核;
  • 融合紧跟在矩阵乘法内核之后的操作,以减少 GPU 内核调用的数量 。
通过这些优化,该研究减少了 GPU 全局内存往返次数(roundtrips),从而加快了推理速度 。我们现在可以在 SAM 的图像编码器上尝试 torch.compile 。为了最大限度地提高性能,本文使用了一些高级编译技术:
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
内核跟踪
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
结果显示,torch.compile 工作得很好 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
可以观察到 softmax 占了很大一部分时间,然后是各种 GEMM 变体 。以下测量的是批大小为 8 及以上的变化 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
SDPA: scaled_dot_product_attention
接下来,本文又对 SDPA(scaled_dot_product_attention)进行了实验 , 研究的重点是注意力机制 。一般来讲,原生注意力机制在时间和内存上随序列长度呈二次方扩展 。PyTorch 的 SDPA 操作基于 Flash Attention、FlashAttentionV2 和 xFormer 的内存高效注意力原理构建 , 可以显着加快 GPU 注意力 。与 torch.compile 相结合,这个操作允许在 MultiheadAttention 的变体中表达和融合一个共同的模式 。经过一小部分更改后 , 现在模型可以使用 scaled_dot_product_attention 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
内核跟踪
现在可以看到内存高效的注意力内核占用了 GPU 上大量的计算时间:
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
使用 PyTorch 的原生 scaled_dot_product_attention,可以显著增加批处理大小 。下图为批大小为 32 及以上的变化 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
之后,该研究又实验了 Triton,NestedTensor 、批处理 Predict_torch ,  int8 量化,半结构化 (2:4) 稀疏性等操作 。
例如本文使用自定义 positional Triton 内核,观察到批大小为 32 的测量结果 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
使用 Nested Tensor,批大小为 32 及以上的变化 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
添加量化后,批大小为 32 及以上变化的测量结果 。
PyTorch团队重写「分割一切」模型,比原始实现快八倍

文章插图
文章的最后是半结构化稀疏性 。该研究表示,矩阵乘法仍然是需要面对的一个瓶颈 。解决的办法是使用稀疏化来近似矩阵乘法 。通过稀疏矩阵(即将值归零)可以使用更少的位来存储权重和激活张量 。该研究将张量中哪些权重设置为零的过程称为剪枝 。剪枝掉较小的权重可以潜在地减小模型大小,而不会显着损失准确率 。


推荐阅读