0
点赞
收藏
分享

微信扫一扫

Spring @Transactional 注解

在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。
在这里插入图片描述
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention的实现,一个是使用Pytorch官方的scaled_dot_product_attention接口:

import time
import torch
import torch.nn.functional as F


def main():
    repeat = 100
    device = torch.device("cuda:0")
    dtype = torch.float16

    query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    scale_factor = 0.125

    ori_time_list = []
    for _ in range(repeat):
        torch.cuda.synchronize(device=device)
        time_start = time.perf_counter()
        # 原始Self-Attention实现
        res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ value
        torch.cuda.synchronize(device=device)
        time_end = time.perf_counter()
        ori_time_list.append(time_end - time_start)

    fa_time_list = []
    for _ in range(repeat):
        torch.cuda.synchronize(device=device)
        time_start = time.perf_counter()
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            # 使用Pytorch官方提供的FA实现
            res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)
        torch.cuda.synchronize(device=device)
        time_end = time.perf_counter()
        fa_time_list.append(time_end - time_start)

    diff = (res - res_fa).abs().max()
    ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]
    avg_ratio = sum(ratio[1:]) / len(ratio[1:])
    print(f"max diff: {diff}")
    print(f"avg speed up ratio: {avg_ratio}")


if __name__ == '__main__':
    main()

执行以上代码,终端输出如下:

max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118

这里使用的设备是RTX4070,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的(后面找了个A100的GPU试了下,发现FP32也是有加速的)。

举报

相关推荐

0 条评论