More than code

More Than Code
The efficiency of your iteration of reading, practicing and thinking decides your understanding of the world.
  1. 首页
  2. 未分类
  3. 正文

cs336 lec10 inference

2025年11月9日 6点热度 0人点赞 0条评论

Inference workload

  • 首先需要知道的是,inference和train的区别是train只需要执行一次,而模型部署之后,inference执行的次数会非常多。所以我们需要让inference变的高效,才能让模型更加“经济”

  • 优化inference的一些应用点:

    • 日常使用:chatbots, code completion
    • 模型评估:llm as judge等
    • test-time compute,更多的thinking
    • RL,做sample generation和打分
  • Metrics:
    • TTFT,首token延迟
    • Latency/Throughput,生成token的速度/延迟

单独的inference也能养活一些公司,可以专门做开源模型的推理,比如Together, Fireworks, DeepInfra

  • 拓展阅读
    • vLLM: https://www.youtube.com/watch?v=8BaEwoTk8XI
    • TensorRT: https://nvidia.github.io/TensorRT-LLM/overview.html
    • TGI(HF): https://huggingface.co/docs/text-generation-inference/en/index

DeepMind讲Scale model的一本书:https://jax-ml.github.io/scaling-book/
里面有这样一幅图,画了Transformer核心的两个计算模块的计算开销

后续的分析来自于这样的指导原则:
* 分析出模型的计算强度(每传输一个byte,需要计算多少次)。
* 分析加速器的计算强度,比如用GPU的FLOPS除以带宽。
* 当模型的计算强度高于加速器的计算强度时,就是compute-bounded,否则则是memory-bounded

Arithmetic intensity of inference

在没有任何优化的情况下,Decode的推理开销是正比于T的3次方的。因为每次计算都有一个T方,然后需要推理T次。

因为causal mask的原因,AttentionBlock中的KV值只依赖过去的内容,可以在后续的Decode阶段被复用。从而可以大幅度降低计算的强度。所以后续的讨论中都会带上KV Cache
* 怎么有种RNN传递hidden state的感觉。只不过是transformer可以做并行计算

推理根据workload,一般分为两个阶段:
* Prefill,给定prompt,输入到Transformer中,把KV值都计算好。关键点在于可以并行,和训练阶段是一样的
* Decode,串行的生成每一个token

根据上面的计算公式,S是我们已经有的KV值的长度(也就是Prompt的长度),而T是生成的token的长度。
* 对于Prefill来说,T=S
* 对于Decode来说,T=1

MLP:


Attention:

上面数字的一个直观的理解是:
* 对于MLP,不同的输入可以共享参数矩阵,同时token都是独立计算。所以计算强度和batch size,以及输入token的数量是相关的。
* 对于Attention,不同的输入有不同的KV值,每次把KV从Cache拉出来都是针对这一个请求计算的,无法和batch内的其他请求共用。
* 对于prefill阶段,还可以用大的seq length,来在attention计算时候的一个大的矩阵

推理阶段计算强度总结:
* Prefill is compute-limited, generation is memory-limited
* MLP intensity is B (requires concurrent requests), attention intensity is 1 (impossible to improve)

然后来看一下整个Transformer的吞吐/延迟

可以通过这里的公式来简单的预估,这里内存只考虑了KV Cache的大小,并没有考虑一些中间结果的写入和写回。不过Decode阶段,应该KV Cache的传输是主导的。

带入Batch size,可以发现:
* batch size越大,吞吐越高。但是延迟也就越高
* batch size过大还需要考虑数据是否可以放入内存

Lossy methods

然后来看一些优化的手段,这一节主要讲的是一些Lossy的方法,也就是可能对效果有影响的方法。这里基本都是在模型层面做,而非系统层面(一般来说系统层面也不会做成lossy的)

Reduce KV Cache size

既然推理阶段的带宽瓶颈主要在KV Cache的传输上,那减少KV Cache的大小,就可以减少对带宽的压力,从而提升吞吐。


* 首先是GQA,Query head不变,多个query head共享一个KV head,分成若干组。
* 这里因为KV的数量少了,所以对应的投影矩阵的数量也需要变少



然后是DeepSeek的MLA,核心思路是把KV Cache投影到一个压缩的向量中,减少缓存的大小。有点LoRA的感觉
* 有一个缺点是无法和RoPE兼容,所以需要单独处理
* 更详细的单独出一篇文章,这里两句话说不太清楚。
实验效果上看,效果比MHA更好,同时推理速度更快


CLA,核心思路是在多个层之间共享KV值。图中所示,当前层计算出来的KV下一层可以继续复用,减少计算量和需要缓存的KV的数量


Local Attention:
* 关注一个窗口内的token,而不是关注全局的token,减少kv cache的大小
* 全部使用local attention会影响效果,所以一般方法是穿插local attention和global attention
* 同时可以和CLA混合使用,如上面的右图,隔6层是一个full attention(global attention),中间则是复用kv的local attention

State-space models

具体什么是state-space我还没太看懂。不过这里的核心思想是,既然这种自回归的模型有效率问题,那可以尝试换一些其他架构的模型

这块我还没看太懂,就不多说了。比如linear attention就是这一类。

还有diffusion model

quantization

降低精度来提高效率,也是一个比较常见的手段了。这里有一篇相关的文章:https://apxml.com/posts/llm-quantization-techniques-explained

这块展开也是有很多,也不详细说了,能知道quantization的核心点就是:
* 低精度的类型有更高的计算能力和更低的内存使用
* 需要保证准确度
* 一般的方式是把模型参数重新放缩到对应的精度上,但是需要小心异常的数值影响整体放缩的效果。
* 对部分异常数值可以单独用高精度的方式处理

model pruning


核心思路是把评估模型的重要性,把一部分模型裁剪掉(比如剪掉某些层),然后再来修复他。
* 评估模型各个组件的重要性(比如裁剪head,layer)
* 裁剪不重要的组件
* 使用原始的模型去蒸馏小的模型(可能是去拟合每一层的输出?)

详细的方法还需要再看看,这里有一篇paper:Compact Language Models via Pruning and Knowledge Distillation

Lossless methods

lossy的方法基本上都是对模型做更改,然后这里还有非lossy的,偏系统一些

speculative_sampling

Fast Inference from Transformers via Speculative Decoding
Accelerating Large Language Model Decoding with Speculative Sampling

核心思想是:
* 先使用小模型去生成一些token
* 然后用大模型做check,验证这些token生成的是否正确。这一步相当于prefill,并行度更高

核心原理是拒绝采样,我们有两个分布P/Q,希望在P上采样,来逼近Q:
* 对于生成的一个token xi,计算Q(xi) / P(xi),然后做一个正态分布的sample,如果小这个计算值,则接受这个token。
* 这里的直观理解就是,如果小模型认为概率高,但是大模型不这么认为。那么就很难接受这个token。如果小模型认为概率低,但是大模型认为概率高,则同样会接受这个token

这块拓展起来能做的还有很多,比如让小模型和大模型协作的更好,避免接受率很低。
还可以让小模型来吃大模型输出的feature

Dynamic workloads

最后一块是推理阶段的一些额外的问题。

https://www.usenix.org/system/files/osdi22-yu.pdf

首先是continuous batching,这里说的是,在推理阶段,用户的请求到达时间是不固定的,同时结束位置也是不固定的。

这里细抠还得看看上面的论文,简单讲就是,对于非attention层,输入和sequence无关,就把并发的用户请求都拼接到一起变成一个大矩阵做运算。对于attention层,每一个用户请求单独计算

第二块是paged attention


核心问题是GPU上没有os这种虚拟内存,分配用户的KV cache等在GPU上的内存时,会遇到碎片的问题。
* 所以这里相当于是在GPU上实现了一些内存分配的算法,搬过来一些os的东西
* 同时可以做类似COW的操作,相同的prompt输出不同的内容的时候,前面一段的prompt对应的kv cache就可以被复用

标签: 暂无
最后更新:2025年11月9日

sheep

think again

点赞
< 上一篇

文章评论

取消回复

COPYRIGHT © 2021 heavensheep.xyz. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS