正文

SRAM(静态随机存取存储器)
HBM(显存)
FlashAttention算法核心思想:减少HBM(显存)的访问,将QKV切分为小块后放入SRAM中,计算完毕后_(矩阵乘法、mask、softmax、dropout)_,将计算结果从SRAM中写入到HBM中
核心方法:tiling, recomputation
1. tiling(平铺): 分块计算
因为Attention计算中涉及Softmax,所以不能简单的分块后直接计算。softmax操作是row-wise的,即每行都算一次softmax,所以需要用到
平铺算法来分块计算softmax。
【safe softmax】 原始softmax数值不稳定,为了数值稳定性,FlashAttention采用safe softmax。(也就是减去一个最大值再softmax)
2 recomputation(重新计算)
FlashAttention算法的目标:在计算中减少显存占用,从O(N²) 大小降低到线性,这样就可以把数据加载到SRAM中,提高IO速度。
解决方案:传统Attention在计算中需要用到Q,K,V去计算S,P两个矩阵,FlashAttention引入softmax中的统计量(m, l),结合output O和在SRAM中的Q,K,V块进行计算。