大模型推理——MLA实现方案

news/2025/2/9 5:44:04 标签: python, 深度学习, pytorch, 语言模型

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn


# RMS normalization (no mean-centering), as used in LLaMA/DeepSeek-style models.
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Normalizes the last dimension by its root mean square and applies a
    learned per-channel scale. Unlike LayerNorm, no mean is subtracted
    and no bias is added.

    Args:
        hidden_size: size of the last (feature) dimension to normalize.
        eps: small constant added to the variance for numerical stability.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute in float32 for numerical stability, then cast back to the
        # caller's dtype. (The original ended with a redundant .float() --
        # the tensor was already float32 -- so half-precision inputs were
        # silently promoted to float32 on return.)
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)


def rotate_half(x):
    """RoPE helper: split the last dim into halves (x1, x2) and return (-x2, x1)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([-second_half, first_half], dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embeddings to query and key tensors.

    A broadcast axis is inserted into `cos`/`sin` at `unsqueeze_dim` (the
    head axis for a [bs, seq, heads, dim] layout) before applying the
    standard RoPE formula x*cos + rotate_half(x)*sin.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # x*cos + rotate_half(x)*sin, broadcast over the inserted axis.
        return t * cos + rotate_half(t) * sin

    return _rotate(q), _rotate(k)


# Rotary position embedding with precomputed cos/sin lookup tables.
class RotaryEmbedding(nn.Module):
    """Precomputes RoPE cos/sin tables for positions [0, max_seq_len)."""

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # Per-pair inverse frequencies: 1 / 10000^(2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).float().unsqueeze(1)
        angles = positions @ inv_freq.unsqueeze(0)    # [max_seq_len, dim/2]
        angles = torch.cat((angles, angles), dim=-1)  # duplicate to full dim

        self.register_buffer("cos_cached", angles.cos())
        self.register_buffer("sin_cached", angles.sin())

    def forward(self, q, k):
        # Slice the tables to the query's sequence length and add a batch axis.
        seq_len = q.shape[1]
        cos = self.cos_cached[:seq_len, :].unsqueeze(0)
        sin = self.sin_cached[:seq_len, :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    """Multi-head Latent Attention (MLA), as introduced by DeepSeek-V2.

    Queries and keys/values are first compressed to low-rank latents and then
    projected up to per-head dimensions.  Each per-head q/k vector is the
    concatenation of a "nope" part (no positional encoding) and a "rope" part
    (rotary positional encoding); the rope key is shared by all heads.

    Two execution modes:
      * mode == 'naive' -- decompress k/v for every head and cache the full
        per-head key/value tensors.
      * any other mode  -- the fused ("weight absorption") path: wkv_b is
        folded into the query projection and the output projection, so only
        the compressed kv latent and the shared rope key are cached.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 q_lora_rank,
                 kv_lora_rank,
                 qk_nope_head_dim,
                 qk_rope_head_dim,
                 v_head_dim,
                 max_seq_len,
                 max_batch_size,
                 mode):
        super().__init__()
        self.dim = dim  # model hidden size
        self.n_heads = n_heads  # total number of attention heads
        self.q_lora_rank = q_lora_rank  # rank the query is compressed down to
        self.kv_lora_rank = kv_lora_rank  # rank the key/value latent is compressed down to
        self.qk_nope_head_dim = qk_nope_head_dim    # per-head q/k dim without rotary encoding
        self.qk_rope_head_dim = qk_rope_head_dim    # per-head q/k dim that carries rotary encoding
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # total per-head q/k dim: nope part plus rope part
        self.v_head_dim = v_head_dim  # per-head value dim (equals the nope key dim here)
        self.mode = mode
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # query down-projection
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # query up-projection
        # Parameter count of the two-stage projection vs. a dense one (example sizes):
        # 4096*128 + 128*4864 = 524,288 + 622,592 = 1,146,880  vs  4096*4864 = 19,922,944

        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # joint k/v down-projection: kv latent + shared rope key
        # Equivalent to two separate projections:
        # nn.Linear(self.dim, self.kv_lora_rank)
        # nn.Linear(self.dim, self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v up-projection to per-head k_nope and v

        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)

        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # rotary position embedding
        # 'naive' mode (no weight absorption): cache the full per-head k/v.
        if self.mode == 'naive':
            self.register_buffer('k_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),
                                 persistent=False)
            self.register_buffer('v_cache',
                                 torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),
                                 persistent=False)
        # Fused mode (weight absorption): cache only the compressed kv latent
        # and the shared rope key -- the main memory saving of MLA.
        else:
            self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                                 persistent=False)
            self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                                 persistent=False)

    def forward(self, x, mask=None):
        """Compute MLA attention over a full prefix.

        Args:
            x: input tensor of shape [bs, seq_len, dim].
            mask: optional additive attention mask of shape
                [bs, seq_len, seq_len]; broadcast over the head axis.

        Returns:
            Tensor of shape [bs, seq_len, dim].

        NOTE(review): caches are always written at positions [0:seq_len],
        i.e. this forward recomputes the whole prefix rather than appending
        at a running offset -- confirm that incremental decoding is handled
        by the caller, if at all.
        """

        bs, seq_len, _ = x.shape

        q = self.wq_a(x)  # [bs, seq_len, q_lora_rank]
        q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]
        q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]
        q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                                   dim=-1)  # q_nope: [bs, seq_len, n_heads, qk_nope_head_dim]; q_pe: [bs, seq_len, n_heads, qk_rope_head_dim]

        kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],
                               dim=-1)  # kv: [bs, seq_len, kv_lora_rank]; k_pe: [bs, seq_len, qk_rope_head_dim]

        k_pe = k_pe.unsqueeze(2)  # [bs, seq_len, 1, qk_rope_head_dim] -- one rope key shared by every head
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)
        if self.mode == 'naive':

            q = torch.cat([q_nope, q_pe], dim=-1)  # [bs, seq_len, n_heads, qk_head_dim]

            kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank]
            kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]
            kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
            # k: [bs, seq_len, n_heads, qk_head_dim]
            self.k_cache[:bs, :seq_len, :, :] = k
            self.v_cache[:bs, :seq_len, :, :] = v
            # Equivalent einsum formulation of the scaled q.k^T below:
            # scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            scores = torch.matmul(q.transpose(1, 2),
                                  self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(
                                      self.qk_nope_head_dim + self.qk_rope_head_dim))
            scores = scores.transpose(1, 2)

        else:
            k_pe = k_pe.squeeze(2)
            wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
            wkv_b = wkv_b.view(self.n_heads, -1,
                               self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope,
                                  wkv_b[:, :self.qk_nope_head_dim])  # q_nope: [bs, seq_len, n_heads, kv_lora_rank]
            # Weight absorption: q*k^T = x*wq*(c*wkv_b[:, :qk_nope_head_dim])^T
            #                          = x*wq*wkv_b[:, :qk_nope_head_dim]^T * c^T, where c is the compressed kv latent.
            # wq*wkv_b[:, :qk_nope_head_dim]^T then acts as the effective query projection and c replaces the full
            # key, so attention is computed directly on the compressed latent and only c needs to be cached.
            kv = self.kv_norm(kv)
            self.kv_cache[:bs, :seq_len, :] = kv  # kv: [bs, seq_len, kv_lora_rank]
            self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe: [bs, seq_len, qk_rope_head_dim]
            scores_nope = torch.einsum("bshc,btc->bsht", q_nope,
                                       self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bsht
            scores_pe = torch.einsum("bshr,btr->bsht", q_pe,
                                     self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bsht
            scores = (scores_nope + scores_pe) / math.sqrt(
                self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]

        if mask is not None:
            # mask: [bs, seq_len, seq_len] -> unsqueeze broadcasts it over the head axis
            scores += mask.unsqueeze(2)

        scores = scores.softmax(dim=-1)

        if self.mode == 'naive':
            x = torch.einsum("bsht,bthd->bshd", scores,
                             self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshd
        else:

            # Second half of the absorption: scores * v = scores * c * wkv_b[:, -v_head_dim:]
            x = torch.einsum("bsht,btc->bshc", scores,
                             self.kv_cache[:bs, :seq_len])  # x: [bs, seq_len, n_heads, kv_lora_rank]
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshd

        x = x.contiguous().view(bs, seq_len, -1)
        x = self.wo(x)

        return x


if __name__ == '__main__':
    # Smoke test: run the fused (non-'naive') MLA path on a tiny random input.
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)

    # Draw the input BEFORE constructing the module so the RNG stream
    # (and therefore the printed output) is deterministic.
    x = torch.randn(1, 4, 16)

    config = dict(
        dim=16,
        n_heads=2,
        q_lora_rank=10,
        kv_lora_rank=6,
        qk_nope_head_dim=8,
        qk_rope_head_dim=4,
        v_head_dim=8,
        max_seq_len=10,
        max_batch_size=4,
        mode='none',  # any value other than 'naive' selects the fused path
    )
    mla = MLA(**config)

    print(mla(x))
    print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn


http://www.niftyadmin.cn/n/5845605.html

相关文章

Qt实现简易视频播放器

使用Qt6实现简易视频播放器,效果如下: github: Gabriel-gxb/VideoPlayer: qt6实现简易视频播放器 一、整体架构 该代码整体架构围绕着MainWindow类构建一个媒体播放器相关的应用程序。 主要组件 (一)界面组件…

携手AWS,零成本在EKS上体验AutoMQ企业版

01 前言 AutoMQ是一款贯彻云优先理念来设计的 Kafka 替代产品。AutoMQ 创新地对 Apache Kafka 的存储层进行了基于云的重新设计,在 100% 兼容 Kafka 的基础上通过将持久性分离至 EBS 和 S3 带来了 10x 的成本降低以及 100x 的弹性能力提升,并且相比 Apa…

springcloud gateway 负载均衡

Spring Cloud Gateway的负载均衡是Spring Cloud生态系统中一个非常重要的功能,它使得微服务架构中的服务调用能够更加高效和均衡。以下是关于Spring Cloud Gateway负载均衡的详细解析: 一、Spring Cloud Gateway简介 Spring Cloud Gateway是一个基于Sp…

使用 Apifox、Postman 测试 Dubbo 服务,Apache Dubbo OpenAPI 即将发布

作者:何亮,Apache Dubbo Contributor Apache Dubbo OpenAPI 简介 设计背景 在微服务体系中,RPC 服务的文档管理、测试、调用协作一直都是影响研发效能的关键一环,这些难题通常是由于 RPC 的特性所决定的:RPC 服务的…

SpringSecurity:授权服务器与客户端应用(入门案例)

文章目录 一、需求概述二、开发授权服务器1、pom依赖2、yml配置3、启动服务端 三、开发客户端应用1、pom依赖2、yml配置3、SecurityConfig4、接口5、测试 一、需求概述 maven需要3.6.0以上版本 二、开发授权服务器 1、pom依赖 <dependency><groupId>org.springfr…

leetcode_深度遍历和广度遍历 100. 相同的树

100. 相同的树 给你两棵二叉树的根节点 p 和 q ,编写一个函数来检验这两棵树是否相同。 如果两棵树在结构上相同,并且节点具有相同的值,则认为它们是相同的。 思路: (递归法) 返回True的情况: 两棵树都为空两棵树相同 返回False的情况: 两棵…

知识图谱智能应用系统:数据存储架构与流程解析

在当今数字化时代,知识图谱作为一种强大的知识表示和管理工具,正逐渐成为企业、科研机构以及各类智能应用的核心技术。知识图谱通过将数据转化为结构化的知识网络,不仅能够高效地存储和管理海量信息,还能通过复杂的查询和推理,为用户提供深度的知识洞察。然而,构建一个高…

RK3568上使用C++结合V4L2拉流,并RKMPP硬件编解码,并保存为MP4文件

在RK3568平台上使用C++结合V4L2捕获视频流,并通过RKMPP进行硬件编码后保存为MP4文件,可以按照以下步骤实现: 1. 环境准备 硬件: RK3568开发板、摄像头模块。软件依赖: Linux内核支持V4L2。Rockchip MPP库(RKM…