wenmo8 发布的文章

题目:给你一个 32 位的有符号整数 x ,返回将 x 中的数字部分反转后的结果。如果反转后整数超过 32 位的有符号整数的范围 [−231,  231 − 1] ,就返回 0。

假设环境不允许存储 64 位整数(有符号或无符号)。

01
、long类型字符串转换法

虽然题目要求不允许使用64位整数,但是我们还是先使用最简单最直观的最符合思维逻辑的方式实现,然后以此打开思维寻找出更好的方法。

看到此题我第一反应就是直接把整数x转为字符串,然后直接调用字符串方法反转字符串,最后把反转后的字符串转换为long类型,同时比较是否再有效范围内即可。

当然这是大致解题思路,其中还有许多细节需要处理。

首先需要处理负号问题,如果是负数我们需要取其绝对值,然后再反转绝对值,而在取绝对值时需要注意int的最小值int.MinValue为-2147483648,而int.MaxValue最大值为2147483647,因此我们不能直接对int整数x直接取绝对值,而需要先把x转为long类型整数,不然会报错。

然后把绝对值反转成字符数组,同时判断正负号,如果是负数则需要在字符数组前加上负号’-’。

最后直接使用long.TryParse方法对字符数组进行转换,同时判断其有效范围,并输出结果,具体代码如下:

//字符串long
public static int ReverseLongString(int x)
{
    //是否为负数
    var isNegative = x < 0;
    //取绝对值,必须要先转为long类型
    //否则int.MinValue -2147483648会报错
    var abs = Math.Abs((long)x);
    //把值转为字符串并反转,得到字符集合
    var reversedChars = abs.ToString().Reverse();
    if (isNegative)
    {
        //如果是负数则在字节数组前加入负号'-'
        reversedChars = reversedChars.Prepend('-');
    }
    //转换为long类型,并且是有效的int值,则返回结果
    if (long.TryParse(reversedChars.ToArray(), out long result) && result >= int.MinValue && result <= int.MaxValue)
    {
        return (int)result;
    }
    return 0;
}

02
、int类型字符串转换法

既然题目中要求我们不能使用64位整数long类型,那么我们是否可以直接int类型进行转换呢?

根据上一个解法,同样的思路,我们先把int整数x转为字符串,然后直接使用int.TryParse进行转换即可。这样可以利用转换失败过滤掉所有溢出的值。

具体实现代码如下:

//字符串int
public static int ReverseIntString(int x)
{
    //把值转为字符串,并去掉负号'-',最后反转,得到字符集合
    var reversed = x.ToString().TrimStart('-').Reverse();
    //转换为int,成功则返回
    if (int.TryParse(reversed.ToArray(), out int result))
    {
        //根据原始符号,返回结果
        return x < 0 ? -result : result;
    }
    return 0;
}

03
、数学方法

上面两个方法本质还是通过字符串转换,就效率来说是比较低的,因此我们可以通过数学计算的方式来实现其转换。

如上图,我们以把12345反转为例,详解讲解反转过程。

(1)通过12345%10,获取到尾数字5,而尾数字又将作为新数值的首数字,新数值5;

(2)通过1234%10,获取到尾数字4,新数值为5*10+4=54;

(3)通过123%10,获取到尾数字3,新数值为54*10+3=543;

(4)通过12%10,获取到尾数字2,新数值为543*10+2=5432;

(5)通过1%10,获取到尾数字1,新数值为5432*10+1=54321;

其中还有一个重点是关于溢出的判断,前两个方法本质都是通过转换方法触发异常来拦截溢出,而在此方法中我们可以在实时计算的过程中直接判断出是否溢出。

因为32位有符号整数x的取值范围是-2147483648<=x<=2147483647,如果要保证反转过来不溢出,则在处理到第九位的时候整个值应该在(-214748364,214748364)之间,不然结果肯定会溢出,而有效的int值首位数字最大为2,即使反转过来也不可能大于7或小于-8,因此只需要判断第九位数字是否合法即可完成溢出判断。

具体实现代码如下:

//数学方法
public static int ReverseMath(int x)
{
    var result = 0;
    while (x != 0)
    {
        //判断溢出,因为输入的是32位的有符号整数 x
        //即输入的 -2147483648<=x<=2147483647
        //所以翻转后的最后一位是1或2并不会导致溢出
        //因此只需判断九位数 > int.MaxValue / 10 或者 < int.MinValue / 10
        if (result < int.MinValue / 10 || result > int.MaxValue / 10)
        {
            return 0;
        }
        //获取当前末尾的数字
        var digit = x % 10;
        //去掉末尾数字
        x /= 10;
        //反转并累积结果
        result = result * 10 + digit;
    }
    return result;
}

04
、基准测试

我们做个简单的基准测试,分别对三种方法进行100万次随机生成整数值在范围-2147483648至2147483647之间的值进行测试,得到如下结果。

通过上图不难发现数学方法在整体性能方面远远高于字符串处理方式。


:测试方法代码以及示例源码都已经上传至代码库,有兴趣的可以看看。
https://gitee.com/hugogoos/Planner

【引言】

本文将介绍一个名为“颜文字搜索器”的开发案例,该应用是基于鸿蒙NEXT平台构建的,旨在帮助用户快速查找和使用各种风格的表情符号。通过本案例的学习,读者可以了解如何在鸿蒙平台上进行数据处理、UI设计以及交互逻辑的实现。

【环境准备】

• 操作系统:Windows 10

• 开发工具:DevEco Studio NEXT Beta1 Build Version: 5.0.3.806

• 目标设备:华为Mate60 Pro

• 开发语言:ArkTS

• 框架:ArkUI

• API版本:API 12

【开发思路】

1. 数据模型设计

为了表示单个表情符号的信息,我们定义了一个 EmoticonBean 类,它包含了表情符号的风格(style)、类型(type)、表情符号本身(emoticon)及其含义(meaning)。此外,还添加了一个布尔属性 isShown 来追踪表情符号是否应该显示给用户,这有助于在搜索时动态更新列表。

2. UI 组件与布局

应用的主界面由一个 Index 组件构成,它负责整体布局的设计。界面上部是一个搜索框,允许用户输入关键词来过滤表情符号列表。下方则以表格形式展示了所有符合条件的表情符号,每一行包括四个部分:风格、类型、表情符号和含义。为了提升用户体验,当用户点击某个表情符号或其含义中的高亮文本时,会触发相应的点击事件,并输出日志信息。

3. 数据加载与处理

表情符号的数据来源于一个本地 JSON 文件 (emoticons.json),该文件在组件初次渲染之前被读取并解析为 EmoticonBean 对象数组。每次用户修改搜索框中的内容时,都会调用 splitAndHighlight 方法对每个表情符号的含义进行分割,并检查是否存在匹配的关键字。如果存在,则设置 isShown 属性为 true,否则为 false,以此控制表情符号是否在界面上显示。

4. 搜索与高亮

splitAndHighlight 函数用于将表情符号的含义按关键字分割成多个片段,并返回这些片段组成的数组。对于包含关键字的片段,会在界面上以不同的颜色高亮显示,从而直观地指出匹配的部分。此外,此函数还会根据是否有匹配项来决定表情符号是否可见,确保只有相关的表情符号才会展示给用户。

5. 用户交互

为了让用户有更好的操作体验,我们在界面上实现了触摸事件监听,当用户点击非输入区域时,自动关闭键盘。这样既保证了界面整洁,又简化了用户的操作流程。

【完整代码】

数据源:src/main/resources/rawfile/emoticons.json

https://download.csdn.net/download/zhongcongxu01/90126325

代码

// 引入必要的工具库 util 用于文本解码等操作
import { util } from '@kit.ArkTS'
// 引入 BusinessError 类用于处理业务逻辑错误
import { BusinessError } from '@kit.BasicServicesKit'
// 引入 inputMethod 模块用于管理输入法行为
import { inputMethod } from '@kit.IMEKit'
 
// 定义一个可以被观察的数据模型 EmoticonBean 表示单个表情符号的信息
@ObservedV2
class EmoticonBean {
  // 定义风格属性,并初始化为空字符串
  style: string = ""
  // 定义类型属性,并初始化为空字符串
  type: string = ""
  // 定义表情符号本身,并初始化为空字符串
  emoticon: string = ""
  // 定义含义属性,并初始化为空字符串
  meaning: string = ""
 
  // 构造函数,允许在创建对象时设置上述属性
  constructor(style: string, type: string, emoticon: string, meaning: string) {
    this.style = style
    this.type = type
    this.emoticon = emoticon
    this.meaning = meaning
  }
 
  // 定义是否显示的表情符号状态标记,默认为 true,使用 @Trace 装饰器使其可追踪变化
  @Trace isShown: boolean = true
}
 
// 使用 @Entry 和 @Component 装饰器定义 Index 组件作为应用入口
@Entry
@Component
struct Index {
  // 定义一个状态变量 textInput 用于存储搜索框中的文本内容,默认为空字符串
  @State private textInput: string = ''
  // 定义一个状态变量 emoticonList 用于存储表情符号列表,默认为空数组
  @State private emoticonList: EmoticonBean[] = []
 
  // 定义线条颜色属性
  private lineColor: string = "#e6e6e6"
  // 定义标题背景色属性
  private titleBackground: string = "#f8f8f8"
  // 定义文本颜色属性
  private textColor: string = "#333333"
  // 定义基础填充大小
  private basePadding: number = 4
  // 定义线条宽度
  private lineWidth: number = 2
  // 定义单元格高度
  private cellHeight: number = 50
  // 定义列权重比例
  private weightRatio: number[] = [1, 1, 5, 4]
  // 定义基础字体大小
  private baseFontSize: number = 14
 
  // 定义一个方法 splitAndHighlight 用于分割并高亮表情符号含义中的关键词
  private splitAndHighlight(item: EmoticonBean, keyword: string): string[] {
    let text = item.meaning // 获取表情符号的含义文本
    if (!keyword) { // 如果没有关键词,则直接返回整个文本,并显示该表情符号
      item.isShown = true
      return [text]
    }
    let segments: string[] = []; // 用于存储分割后的文本片段
    let lastMatchEnd: number = 0; // 记录上一次匹配结束的位置
    while (true) { // 循环查找关键词在文本中的位置
      const matchIndex = text.indexOf(keyword, lastMatchEnd); // 查找关键词出现的位置
      if (matchIndex === -1) { // 如果找不到关键词,将剩余文本加入到segments中并退出循环
        segments.push(text.slice(lastMatchEnd));
        break;
      } else { // 如果找到关键词,将非关键词部分和关键词部分分别加入到segments中
        segments.push(text.slice(lastMatchEnd, matchIndex)); // 非关键词部分
        segments.push(text.slice(matchIndex, matchIndex + keyword.length)); // 关键词部分
        lastMatchEnd = matchIndex + keyword.length;
      }
    }
    // 如果有关键词出现,则设置表情符号为显示状态
    item.isShown = (segments.indexOf(keyword) != -1)
    return segments;
  }
 
  // 当组件即将出现在屏幕上时调用此方法,用于加载表情符号数据
  aboutToAppear() {
    // 从资源管理器中读取本地文件 emoticons.json 的内容
    getContext().resourceManager.getRawFileContent("emoticons.json", (err: BusinessError, data) => {
      if (err) { // 如果读取失败,打印错误信息
        console.error('getRawFileContent error: ' + JSON.stringify(err))
        return
      }
      // 创建一个文本解码器来将二进制数据转换为字符串
      let textDecoder = util.TextDecoder.create('utf-8', { ignoreBOM: true })
      let jsonString = textDecoder.decodeToString(data, { stream: false })
      let jsonObjectArray: object[] = JSON.parse(jsonString) // 将 JSON 字符串解析为对象数组
      for (let i = 0; i < jsonObjectArray.length; i++) { // 遍历对象数组,填充 emoticonList
        let item = jsonObjectArray[i]
        this.emoticonList.push(new EmoticonBean(item['s'], item['t'], item['e'], item['m']))
      }
      try {
        // 打印 emoticonList 到控制台以供调试
        console.info(`this.emoticonList:${JSON.stringify(this.emoticonList, null, '\u00a0\u00a0')}`)
      } catch (err) {
        console.error('parse error: ' + JSON.stringify(err))
      }
    })
  }
 
  // 定义 build 方法构建组件的UI结构
  build() {
    Column({ space: 0 }) { // 创建一个列容器,内部元素之间没有间距
      // 搜索框组件,绑定到 textInput 状态变量
      Search({ value: $$this.textInput })
        .margin(this.basePadding) // 设置外边距
        .fontFeature("\"ss01\" on") // 设置字体特征
      // 创建一个列容器用于表头
      Column() {
        Row() { // 创建一行用于放置表头
          // 表头文字:风格
          Text('风格')
            .height('100%') // 设置高度为父容器的100%
            .layoutWeight(this.weightRatio[0]) // 根据权重分配宽度
            .textAlign(TextAlign.Center) // 文本居中对齐
            .fontSize(this.baseFontSize) // 设置字体大小
            .fontWeight(600) // 设置字体粗细
            .fontColor(this.textColor) // 设置文本颜色
          // 分割线
          Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
          // 表头文字:类型
          Text('类型')
            .height('100%')
            .layoutWeight(this.weightRatio[1])
            .textAlign(TextAlign.Center)
            .fontSize(this.baseFontSize)
            .fontWeight(600)
            .fontColor(this.textColor)
          // 分割线
          Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
          // 表头文字:表情
          Text('表情')
            .height('100%')
            .layoutWeight(this.weightRatio[2])
            .textAlign(TextAlign.Center)
            .fontSize(this.baseFontSize)
            .fontWeight(600)
            .fontColor(this.textColor)
          // 分割线
          Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
          // 表头文字:含义
          Text('含义')
            .height('100%')
            .layoutWeight(this.weightRatio[3])
            .textAlign(TextAlign.Center)
            .fontSize(this.baseFontSize)
            .fontWeight(600)
            .fontColor(this.textColor)
        }.height(this.cellHeight).borderWidth(this.lineWidth).borderColor(this.lineColor)
        .backgroundColor(this.titleBackground) // 设置背景颜色
      }.width(`100%`).padding({ left: this.basePadding, right: this.basePadding })
 
      // 创建一个滚动容器 Scroll 包含表情符号列表
      Scroll() {
        Column() {
          // ForEach 循环遍历 emoticonList 数组,创建每一行代表一个表情符号条目
          ForEach(this.emoticonList, (item: EmoticonBean) => {
            Row() {
              // 显示表情符号的风格
              Text(item.style)
                .height('100%')
                .layoutWeight(this.weightRatio[0])
                .textAlign(TextAlign.Center)
                .fontSize(this.baseFontSize)
                .fontColor(this.textColor)
              // 分割线
              Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
              // 显示表情符号的类型
              Text(item.type)
                .height('100%')
                .layoutWeight(this.weightRatio[1])
                .textAlign(TextAlign.Center)
                .fontSize(this.baseFontSize)
                .fontColor(this.textColor)
              // 分割线
              Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
              // 显示表情符号
              Text(item.emoticon)
                .height('100%')
                .layoutWeight(this.weightRatio[2])
                .textAlign(TextAlign.Center)
                .fontSize(this.baseFontSize)
                .fontColor(this.textColor)
                .copyOption(CopyOptions.LocalDevice) // 允许复制到剪贴板
              // 分割线
              Line().height('100%').width(this.lineWidth).backgroundColor(this.lineColor)
              // 显示表情符号的含义,支持关键字高亮
              Text() {
                ForEach(this.splitAndHighlight(item, this.textInput), (segment: string, index: number) => {
                  ContainerSpan() {
                    Span(segment)
                      .fontColor(segment === this.textInput ? Color.White : Color.Black) // 根据是否是关键词设置字体颜色
                      .onClick(() => { // 设置点击事件监听器
                        console.info(`Highlighted text clicked: ${segment}`); // 打印点击的文本信息
                        console.info(`Click index: ${index}`); // 打印点击的索引信息
                      });
                  }.textBackgroundStyle({
                    color: segment === this.textInput ? Color.Red : Color.Transparent // 根据是否是关键词设置背景颜色
                  });
                });
              }
              .height('100%')
              .layoutWeight(this.weightRatio[3])
              .textAlign(TextAlign.Center)
              .fontSize(this.baseFontSize)
              .fontColor(this.textColor)
              .padding({ left: this.basePadding, right: this.basePadding })
            }
            .height(this.cellHeight)
            .borderWidth({ left: this.lineWidth, right: this.lineWidth, bottom: this.lineWidth })
            .borderColor(this.lineColor)
            // 根据表情符号的状态(是否显示)来决定其可见性
            .visibility(item.isShown ? Visibility.Visible : Visibility.None)
          })
        }.width(`100%`).padding({ left: this.basePadding, right: this.basePadding })
      }.width('100%').layoutWeight(1).align(Alignment.Top)
      // 触摸事件处理,当用户点击空白区域时,关闭键盘输入
      .onTouch((event) => {
        if (event.type == TouchType.Down) { // 如果是按下事件
          inputMethod.getController().stopInputSession() // 停止当前的输入会话
        }
      })
    }.width('100%').height('100%').backgroundColor(Color.White); // 设置容器的宽高和背景颜色
  }
}

前一阵多步RAG的风吹入了工业界,kimi推出了探索版本,各应用都推出了深度搜索,You.COM更是早就有了Genius的多步模式。其实都是类似multi-hop RAG的实现。之前学术界在讨论multi-hop RAG的时候总是给一些基于历史知识类的问题,什么某年诺贝尔奖的获奖人在哪读的大学呀,给人一种错觉就是这类问题现实世界里真的有人这么提问么?其实还真有!

这里举几个单步RAG效果可能不好的case,在碰到的很多场景里,多步RAG其实主要针对模糊指代的问题,包括

  1. 偏动态信息的主体指代:例如事件,产品,政策,现象
  • 华为最新型号的手机市场怎么看:需要先获取华为最新的手机型号
  • 最新出台的房地产政策专家解读:需要先获取最新的房地产政策
  • 请根据当前的市场成交情况,分析当前市场情绪:需要获取当前市场成交数据
  1. 偏静态知识的主体指代:例如分类,话题,主题,相关主体
  • 近三年全国各大电影节的获奖名单:需要先获取有哪些电影节,再逐个获取每个电影节的获奖名单
  • 光伏上下游产业链近期有哪些利好政策:需要先获取光伏上下游产业链节点
  1. 抽象指代:需要先验知识,专家经验,对前置问题的回答
  • 请依据美林时钟,判断从2013年到2023年间都经历哪些经济周期阶段:需要先获获取美林时钟的定义
  • 探讨我国粮食价格的形成机制,从市场供求关系、政策调控等多个因素分析价格变动的原因:需要先搜索粮食价格形成机制
  1. 多条件指代:因为条件过于复杂,需要多步缩小筛选范围
  • 哪些国家在内部制造业不景气的时候,通过设立贸易壁垒来解决失业率过高的问题
  • 哪些中央银行,曾经因为长时间通货膨胀居高不下,在经济不景气的时候还不得不提高提率水平
  1. 时间指代
  • 贵州茅台最新季报的资产负债情况:需要先获取当前时间,定位季报
  • 近一周异动板块有哪些:需要获取当前时间,并定位时间窗口

前面说单步RAG
可能
解决不好,因为以上的场景当你幸运的召回了正确的数据时,包括但不限于query改写拆解引入相关信息,检索部分解决了时效性问题,模型自身压缩知识的辅助等等,其实是有可能解决的,当然需要碰运气哈哈哈~~

例如华为手机刚发布大热,那你不需要获取华为最新的手机号,直接使用搜索引擎搜索“华为最新型号的手机市场怎么看”,这时搜索引擎已经帮你处理了热点的时效性问题,大概率你就能获得正确的答案。

再比如"光伏上下游产业链近期有哪些利好政策",可能把query拆解为光伏上游产业链利好政策+光伏中游产业链利好政策+光伏下游产业链利好政策,你不需要知道上中下游具体是啥,也是能检索到部分有效信息的。

但是!我们需要的是可以稳定解决复杂,多条件,模糊指代问题的方案!

所以下面我们会给出多步RAG的几个核心步骤和对比,再讨论几篇论文大致的实现方案,论文细节大家感兴趣可以自己去看。更多RAG query改写,召回,粗排,精排的多个步骤我们在前面的章节已经说过,这里就不提了~

方案对比

懒得看具体方案的小伙伴直接看对比吧,这里总结对比下多步RAG的几个核心模块,和几种实现方案

模块 方案1 方案2 方案3 对比
规划模块/COT 每次只基于上一步的检索规划下一步的局部规划方案 全局预规划 先全局预规划再进行修订 局部规划方案容易歪楼,有时会缺少整体视角;全局规划是否需要修订其实部分取决于拆分步骤的方式,如果拆分过程不和query耦合其实多数场景不用修订也是可以的
子Query生成 和规划等同(规划本身就是子Q) 和规划一起全局生成 依赖前面的的检索结果生成(全部or上一步) 方案2和3结合的方式更常见,依赖检索结果的Query提供更加specific的检索视角,而只依赖主Q拆分的子Q提供更加general的检索视角
推理 每一步独立推理拼接最后润色 获得所有检索结果后一体生成 每一步基于上一步的推理和新获得的检索内容向后续写 连贯性最好幻觉较低,但对模型长文本有更高要求的肯定是一体化生成方案类似Kimi;You.COM的生成结果更类似于多步推理再拼接润色;而对于一些超长文本生成续写的方案使用更多

在尝试过You.COM的Genius模式,Kimi的探索模式,智谱的深度搜索后,发现除了以上的核心模块,多步RAG还有以下几个可以进一步提升效果的方向

  1. Reflection: 复杂问题往往很难一步到位,通过对推理结果进行反思,发现遗漏的方向,然后新增一步检索推理进行补充,或者对原始推理修订

image

  1. 针对多跳问题设计更适配的Query生成方案:复杂query往往涉及多条件,或者多主体,因此在Query拆解上需要有更多的视角。几种Query生成视角包括
  • 相似改写:最常见的单步RAG主要使用
  • 角度拆解:问题独立的多个视角(多步RAG的问题拆解是串行的这里是并行的),单步RAG主要使用,例如

image

  • 抽象、系统思维:类似step-back prompt,通过对问题抽象,获取领域概念或先验知识,多步RAG的第一步query拆解常用,例如
    image
  • 简化假设:通过放宽问题假设来补充信息召回,多步RAG中多条件的复杂问题常用,例如

image

  • 实体遍历: 通过对第一步检索得到的实体进行遍历生成第二步检索的query,通过明确主体获得更丰富的主体信息,多步RAG的第二步常见,例如
    image
  • 历史视角:分析影响类问题,事件解读类问题,多引入历史视角来补充观点

几种实现

局部 + 生成:IRCOT

image

这里IRCOT的实现最为简单,我们那这篇论文作为基准论文。IRCOT的整个流程是

  1. Retrieve:用户Query进来直接去检索
  2. COT:检索内容作为上文,使用以下prompt进行COT推理,只保留COT推理的第一个段落
  3. Retrieve:使用上一步推理的句子直接作为Query取进行搜索
  4. COT:使用当前检索到的全部上文,之前COT推理完成的段落,再继续进行推理并保留第一个句子。
  5. 不断重复Retrieve,COT直到模型给出"The answer is",或者超过最大迭代步骤

IRCOT的几个特点包括

  • 没有全局的规划:每一步都是先检索再向前推理一步
  • 没有Query生成:Query直接来自上一步推理的COT,没有额外生成
  • 直接推理生成:COT是直接基于检索,和前面的推理结果进行继续生成
  • 每一步都独立推理: 不会修改前面的生成结果

以下是基于检索内容生成COT推理的prompt格式

Wikipedia Title: <Page Title>
<Paragraph Text>
...
Wikipedia Title: <Page Title>
<Paragraph Text>
Q: <Question>
A: <CoT-Sent-1> ... <CoT-Sent-n>

全局 + 生成:Search in the chain

image

对比前面的IRCOT,Search in the chain会预先生成全局规划,并且规划的步骤中增加了子问题的拆解生成。每一步检索后,都根据检索重新生成新的全局规划。以下是Search in the Chain的实现步骤

  1. SearchChain:针对问题生成全局的SearchChain,包括问题拆解的多个子query,并且每个子问题会让模型判断能否直接回答
  • 能:直接给出[Answer]
  • 不能: 给出[Unsolved Query]
  1. Retrieve:把以上的全局SearchChain按结构拆分成多个步骤,第一步的query先去进行检索
  2. Revise or Generate:根据第一个节点的检索结果,如果当初模型判断这个节点子问题无法回答(Unsolved Query)则根据检索进行生成,如果当初模型判断可以回答但答案和检索不一致,则根据检索进行修订,两种逻辑的prompt不同,会根据searhChain选择不同的prompt拼接到模型的上文
  3. Revise SearchChain: 然后新的上文,让模型重新生成新的SearchChain,再接着遍历下一个节点进行步骤2~3,直到结束

SearchChain的几个特点包括

  • 全局预规划
  • Query预生成:Query就是全局规划中拆分的子问题
  • 允许基于前一步的检索结果调整全局规划,但其实之所以需要调整,就是因为query和规划是一起做的,所以当中间推理答案错误的时候,下一步的query生成也会存在问题,所以需要重新生成

以下是Search in the chain构建search chain的prompt

image

全局 + 修订:RAT

image

对比Search Chain,RAT也是生成全局规划,但是把query生成的步骤,放到了检索之后,同时把基于检索重新生成的方案,替换成了基于检索对初始回答进行修正。RAT的整体流程是

  1. COT:用户query进来先让模型进行COT推理,然后把COT按\n\n分成多个段落(step)
  2. ToQuery:针对第1个段落的推理生成检索Query
  3. Retrieve:基于检索query发起检索
  4. Revise: 使用检索结果,对第1个段落进行修订
  5. TOQuery:针对第1个修订段落+第2个原始段落生成检索Query
  6. Retrieve:基于检索query发起检索
  7. Revise: 使用检索结果,对前面的所有段落进行修订
  8. 遍历各个段落重复ToQuery,Retrieve,Revise的步骤直到结束

RAT的几个特点包括

  • 全局预规划
  • 串行Query生成,并使用全部历史信息:每一步Query生成,都是用前面生成的全部信息。更不容易丢失核心主体信息,但同时可能会让Query过于宽泛不够具体。
  • 修订而非生成:使用检索结果修改原始推理而非直接生成,能更多保留模型压缩的知识效果,但是存在内容检索不全,修订后的答案还是有错误存在的可能。
  • 每一步都修订前面的所有内容:成本更高,但可能会有更好的连贯性,但也有可能因为多次修订而引入幻觉

以下分别为第一步COT回答的Prompt,query生成的prompt和基于检索内容的COT进行

image

image

image

想看更全的大模型论文·微调预训练数据·开源框架·AIGC应用 >>
DecryPrompt

ChatGPT生成的文章摘要

这篇博客记录了作者在家中使用Pascal显卡运行大型模型时遇到的挑战和解决方案。随着本地大型模型性能的提升,作者选择使用vllm库进行推理。然而,作者遇到了多个技术难题,需要自行编译vllm和PyTorch,以支持Pascal架构的显卡。编译过程中,作者深入研究了显卡不支持的问题,特别是在量化矩阵乘法计算中发现性能瓶颈。最终,解决了性能问题,让性能提升了43倍。这次技术探索不仅解决了具体问题,还为作者提供了深入学习和扩展其他相关技术的机会,同时也展示了LLM在整个过程中提供的帮助。文章结尾,作者总结了经验并提出了进一步研究的方向。

背景

家里有张Pascal架构的显卡【划重点,后面要考】,最近发现本地大模型的性能在蹭蹭往上涨,于是开始研究下是否能在本地跑大模型。

之前我就了解vllm库,vllm的推理速度还是很快的,并且我之前还给vllm提交过一个PR,对vllm比较熟悉,所以我选择了使用vllm来进行推理。

选择结束之后就开始了漫长的抗争之路,期间着实遇到了很多问题,也学到了很多知识,故写此文以记录。

第一关:下载安装

当时无知的我以为安装是一件很简单的事情,以前使用vllm,直接
pip install vllm
,不仅会帮忙安装好vllm,pytorch,还会帮忙下载对应的cuda库,自己啥都不用操心。

这次的安装也如以前一样顺利,

安装完后就是选择模型了,选择模型的话,对于消费级显卡来说,显存占用是一个主要的考量因素,你得先跑起来。获取模型的显存占用的方式有两种:

  1. 计算模型需要占用的显存大小,比如一个7B的模型,它的参数量是7,000M个,一个float16的参数占2个字节,所以需要
    7,000M *2B=14GB
    的显存,除了参数外,还要考虑存储KV缓存,以及样本在中间传输时的值,量化元信息(如果涉及量化的话),所以需要留一些buffer。

  2. 另外一个获取显存占用的方式是直接用这个
    工具
    [1]
    ,输入模型在huggingface上的名称,然后选择精度,就可以看到模型占用的显存大小了。

    model memory usage

    需要注意的是,这里同样需要预留buffer,这上面的显存大小是纯模型本身的大小,量化的模型尤其要注意,需要考虑量化元数据带来的显存占用。

这样看下来,我这张12G显存的显卡,顶多只能跑一个7B-int8的模型,为了能跑稍微大一点的上下文,我最终选择了Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4的模型(经过项目的实测,Qwen模型现在在中文开源领域确实很不错)。

兴奋地下载完的模型后,
噩梦
在启动vllm server的时候开始了。

迎面而来的是第一个错误是:

RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241211-200011.pkl): CUDA error: no kernel image is available for execution on the device

这个问题去stackoverflow
[2]
了一下,大概率是vllm编译的时候没有支持对应的显卡架构,还记得重点么,没错,大概率就是不支持Pascal架构,我去官方文档
[3]
上看了一下,确实没有发现Pascal的显卡支持,支持矩阵长这样,没有Pascal架构呀:

vllm support matrix

没办法了,那就尝试自己编译vllm,看看能不能解决这个问题。

第二关:vllm编译

编译vllm

一开始编译的时候感觉还挺简单的,直接照着vllm的文档来,文档就只有一行命令
pip install -e .
,事情肯定没有这么简单,编译出错了:

CMake Error at CMakeLists.txt:252 (cuda_archs_loose_intersection):
        cuda_archs_loose_intersection Function invoked with incorrect arguments for
        function named: cuda_archs_loose_intersection

252行是这么写的:

cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})

中间经过了大量时间的定位,最终找到了问题所在,主要就是vllm设置了一个支持的显卡架构(其实它使用了算力来表示架构,算力和架构有对应关系
[4]
):

set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")

只支持到7.0算力,而Pascal架构是6.1算力,所以最终CUDA_ARCHS就为空,所以就报错了。

那简单呀,我直接给CUDA_SUPPORTED_ARCHS加上6.1就行了,然后重新编译...

这次编译很顺利,编译完成之后,我就继续兴奋地启动vllm了,不幸的是,又一次报了
Cuda error: no kernel image is available for execution on the device
错误。

于是我继续Google,找到了这么一个github的issue
[5]
,issue说这种情况是显卡不受支持了,需要自己编译(后面我自己测试了一下pytorch,其实我的pytorch是可以使用的,至于这里为什么报错,后续再研究研究吧),于是我就屁颠屁颠地去开始编译pytorch了。

编译pytorch

pytorch的编译就复杂很多了,不像vllm的编译命令,pytorch分了很多步。

先是要安装一堆前置工具:
- CuDNN
- cmake, ninja
- requirements.txt
- mkl-static, mkl-include
- magma-cuda121
- triton

这些工具安装都还算顺利,要么照着说明安装,要么就是conda或pip安装,最后的triton就是一个make。

这里有几个坑:

  1. pytorch要求先export CMAKE_PREFIX_PATH,并且给了个命令,检查一下执行完后的命令,有可能conda的路径没有找对,需要自己手动指定一下。
  2. cmake一开始会找不到cudnn,需要将cudnn-version.h(直接用find找一下自己安装的cudnn-version.h在哪)文件拷贝或link到cuda的include目录下。

编译完成!

下面重新编译一次vllm,由于我们需要使用自己编译的pytorch,所以需要执行一下
python use_existing_torch.py
,vllm会帮我们把pytorch从依赖里删除掉,然后执行
pip install -r requirements-build.txt
,安装一下依赖,最后执行
pip install -e . --no-build-isolation
,这样安装的时候,vllm就不会再去安装这部分依赖了。

中间如果出现
version 'GLIBCXX_3.4.30' not found
的错误,我是把我安装的gcc的libstdc++.so.6软链到conda的lib目录就行了。

strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX_3.4.30

检查一下libstdc++.so.6是否包含GLIBCXX_3.4.30,如果包含,则软链到conda的lib目录下。

ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_PREFIX_1}/lib/libstdc++.so.6

编译完成!

再次满怀期待地启动vllm server,不出意外地又报错了,这次报错是没找到xformers,这个是因为vllm默认是不带注意力后端的,因为它也不知道你用什么注意力后端,所以需要自己安装一下。安装的时候发现它依赖了pytorch并且去下载了pytorch,那要不还是自己编译一把吧。

xformers页面介绍中支持Pascal架构,所以安装起来很丝滑,一行命令即可:

pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers

启动vllm server!

终于告一段落了,vllm server终于启动了,没有任何报错,我成功地看到了
Loading model weights took 5.2035 GB
。(这里可以印证我之前说的,量化的模型在考虑上量化元数据后,显存占用变大了很多,从计算得到的3.5GB,变成了5.2GB)

你以为故事到这就结束了?不不不,现在才是故事的开始。日志到了
Loading model weights took 5.2035 GB
就卡住了,我等了很久,发现它一直在卡在这。

第三关:定位性能问题的根因

初见端倪

出现这样的状况后,我是一点头绪都没有,只能像无头苍蝇一样,在vllm的Python代码里多打一些断点试试看了,在疯狂打了几十个断点之后,终于定位到卡哪了,vllm默认会先做一次profile run,来告诉你一些基本信息:

Memory profiling results: duration=11.82 seconds, total_gpu_memory=11.88GiB, initial_memory_usage=6.15GiB, peak_torch_memory=6.54GiB, memory_usage_post_profile=6.20GiB, non_torch_memory=1.05GiB, kv_cache_size=2.50GiB, gpu_memory_utilization=0.85.

因为这里需要进行模型推理,所以卡住了,这时候我才意识到,看一下nvidia-smi看看显卡是否在工作其实就能知道它确实是在跑模型代码(虽然我一开始也有点意识到,却一直没往这个方面上想,毕竟再慢也不至于这么慢)。事实证明,卡住的时候,显卡确实在工作,所以问题很明显了,就是因为我的显卡推理速度“太慢”导致的。于是我就把
max-model-len
设置成了100,看看是否能够跑出结果来。等待了很长的时候后,服务真的启动了。

速度这么慢我是万万没有想到的,只能先换台机器测一下看怎么样,用了一台A6000的机器,发现人家一瞬间就启动了,那很明显了,问题就是只有我这边很慢。

初步定位问题

有了方向之后,那要做的事情就比较简单了,因为我自己编译了pytorch、xformers以及vllm,所以我需要一个个地排查。

先在pytorch官网上找到了跑
benchmark
[6]
的文档,分别在A6000机器、我的机器上自己编译的pytorch以及直接用pip install的pytorch上跑了一下,发现pytorch基础的性能是不差的。

然后使用xformers的
benchmark
[7]
,同样测试了一下,发现xformers的性能也是ok的。

那问题多半就出在vllm了,由于我不确定到底问题出在什么地方,以及我大概率确定基础库是没啥问题的,所以我打算把整个模型推理的各个步骤都记录一下执行时间,来看看具体是什么地方出问题了,按照28原则,问题大概率出在20%的地方。

接下来就是想办法记录时间了,我自己没有特别好的思路,所以就请教了一下LLM,LLM给了我一个思路,可以使用pytorch的
register_forward_pre_hook

register_forward_hook
来记录时间。它给的代码很粗糙直接使用time库来记录时间,而且只能记录一层模型。所以我就“稍”作修改,改成了递归地访问每一层模型,并且用cuda的
Event
(当然这个也是从LLM那问出来的)来记录时间。

时间记录的代码写完了,接下来就是运行一下,看看问题出在哪了。下面是我运行后跑出来的结果,各位来找找看觉得哪里有问题?

model: 134811.72338464856 ms
  model.embed_tokens: 37.62428665161133 ms
  model.layers: 134773.90933799744 ms
    model.layers.0: 4777.431374847889 ms
      model.layers.0.input_layernorm: 1.620192050933838 ms
      model.layers.0.self_attn: 673.7694255411625 ms
        model.layers.0.self_attn.qkv_proj: 411.43023681640625 ms
        model.layers.0.self_attn.rotary_emb: 0.1632319986820221 ms
        model.layers.0.self_attn.attn: 4.729087829589844 ms
        model.layers.0.self_attn.o_proj: 257.4468688964844 ms
      model.layers.0.post_attention_layernorm: 0.1900160014629364 ms
      model.layers.0.mlp: 4101.85174125433 ms
        model.layers.0.mlp.gate_up_proj: 2740.14697265625 ms
        model.layers.0.mlp.act_fn: 0.8391680121421814 ms
        model.layers.0.mlp.down_proj: 1360.8656005859375 ms

不得不说134s才跑完profile真的是离谱,然后确实就是28原则,问题就出在了4个地方,分别是:

  • model.layers.0.self_attn.qkv_proj
  • model.layers.0.self_attn.o_proj
  • model.layers.0.mlp.gate_up_proj
  • model.layers.0.mlp.down_proj

这几个地方耗时都明显不正常,人家attention的计算才花了4ms,怎么这些操作要花几百甚至上千ms。

作为对比,我去查看了一下A6000机器上的结果:

model: 7459.573736906052 ms
  model.embed_tokens: 265.0838928222656 ms
  model.layers: 7192.4459400177 ms
    model.layers.0: 259.46213555336 ms
      model.layers.0.input_layernorm: 1.3496320247650146 ms
      model.layers.0.self_attn: 145.4847927093506 ms
        model.layers.0.self_attn.qkv_proj: 129.69778442382812 ms
        model.layers.0.self_attn.rotary_emb: 1.3486080169677734 ms
        model.layers.0.self_attn.attn: 3.180543899536133 ms
        model.layers.0.self_attn.o_proj: 11.257856369018555 ms
      model.layers.0.post_attention_layernorm: 2.0490241050720215 ms
      model.layers.0.mlp: 110.57868671417236 ms
        model.layers.0.mlp.gate_up_proj: 69.62483215332031 ms
        model.layers.0.mlp.act_fn: 4.104191780090332 ms
        model.layers.0.mlp.down_proj: 36.84966278076172 ms

结果很明显了,确实就是刚刚那几个地方的问题,其他地方的耗时基本上都差不多,有些甚至有领先(这个感觉应该属于误差)。

ok,知道问题了就去看看代码吧。

通过Python源码定位问题

经过一番研究,最终我把问题锁定到了量化计算上面,因为所有出问题的点都执行了量化的矩阵乘法计算。从网上搜了一张Qwen的架构图
[8]
,我把耗时长的点都用红框标出来了。

Qwen architecture

从中我们可以看到,这些地方都执行了没有量化的输入和量化后的weight之间的矩阵乘法计算。

vllm的代码里则对应了:

class ColumnParallelLinear(LinearBase):
    ...
    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias) ## 就是这行进行了量化矩阵乘法
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

class RowParallelLinear(LinearBase):
    ...
    def forward(self, input_):
        ...
        assert self.quant_method is not None
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,               ## 就是这行进行了量化矩阵乘法
                                                  input_parallel,
                                                  bias=bias_)
        ...

        return output, output_bias

由于我使用的是GPTQ量化模型,所以继续跟进需要去找的quant_method是GPTQ相关的。
跟进到
self.quant_method.apply
:

class GPTQLinearMethod(LinearMethodBase):
    ...
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])

        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                               layer.scales, layer.g_idx,
                               layer.exllama_state == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)

这里很明显问题就是gptq_gemm的计算(GEMM表示General Matrix Multiplication,通用矩阵乘法),继续:

def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)

哎,最终还是得去看cuda代码么!!!

小插曲

这里想说一下GPTQ的名字,大家一看到可能会觉得它和GPT有关系,其实不是的,它这算是蹭GPT的热度,GPTQ的全称是Post-Training Quantization for Generative pre-trained transformers,确实是硬蹭的。Post-Training Quantization,指的是训练后量化,所以它是一种在模型训练完之后,不再继续训练,单纯对权重和/或激活值进行量化的方法,而GPTQ是对PTQ的一种。

由于要去看cuda的源码,我对此没有很强的信心,我一没看过cuda源码,二不了解量化计算是什么样的,所以我就去紧急补课了一下,在网上找了个量化计算的视频
[9]
来看,这个视频讲得很详细,对量化感兴趣的同学可以去看一下。看完视频过后我还不过瘾,我想弄清楚GPTQ的量化数学原理(GPTQ有一套完善的数学推理),只看了它的前身OBS、OBC、OBQ,在看GPTQ本身的时候,想到,我已经了解得足够多了,再看下去有点浪费时间了,还是回归主线先把。

感兴趣的同学可以参考下面2个链接,OBC/OBQ的论文本身写得也挺友好的,也可以看看:

  1. https://readpaper.feishu.cn/docx/OPP2dTuXAoaO0oxWhQAcC05Wnpc
  2. https://zhuanlan.zhihu.com/p/646210009
  3. https://arxiv.org/abs/2208.11580

通过cuda源码定位问题

接下来就是跟踪cuda源码了,通过搜索gptq_gemm找到对应的cuda源码:

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
  at::Tensor temp_dq = torch::empty(
      {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);

  vllm::gptq::gemm_half_q_half_cuda(
      at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
      (const uint32_t*)b_q_weight.data_ptr(),
      (const uint32_t*)b_gptq_qzeros.data_ptr(),
      (const half*)b_gptq_scales.data_ptr(),
      b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(),
      (half*)c.data_ptr(), (half*)temp_dq.data_ptr(),
      c.size(0),              // m
      c.size(1),              // n
      a.size(1),              // k
      b_gptq_qzeros.size(0),  // group number
      use_exllama, bit);
  return c;
}

主要就是gemm_half_q_half_cuda这个函数,这个函数是GPTQ的量化矩阵乘法计算,a是输入,b_q_weight是量化后的权重,b_gptq_qzeros是公式里的Z,b_gptq_scales是公式里的S,然后use_exllama是是否使用exllama库。

由于use_exllama后续会影响到分支逻辑,所以先检查一下use_exllama是否为true。从这里的代码一直往上翻查,可以看到use_exllama是从config中读取的,qwen2.5的config中设置的是true。

继续跟进代码:


void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
                           const uint32_t* b_q_weight,
                           const uint32_t* b_gptq_qzeros,
                           const half* b_gptq_scales, const int* b_g_idx,
                           half* c, half* temp_dq, int size_m, int size_n,
                           int size_k, int groups, bool use_exllama, int bit) {
  bool use_reconstruct;
  if (use_exllama) {
    use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
                       (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
  } else {
    // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so
    // we disabled them for now.
    use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
  }
  if (use_reconstruct) {
    // Reconstruct FP16 matrix, then cuBLAS
    if (use_exllama) {
      reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                          temp_dq, size_k, size_n, groups, bit);
    } else {
      reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                       temp_dq, size_k, size_n, groups, bit);
    }

    const half alpha = __float2half(1.0f);
    const half beta = __float2half(0.0f);
    cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k,
                &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n);
} else if (use_exllama) {
    // Quantized matmul
    int max_chunks = size_m / BLOCK_M_SIZE_MAX;
    int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
    int last_chunk_size = size_m - last_chunk;

    if (max_chunks) {
      gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                 b_g_idx, c, last_chunk, size_n, size_k,
                                 BLOCK_M_SIZE_MAX, groups, bit);
    }

    if (last_chunk_size) {
      gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight,
                                 b_gptq_qzeros, b_gptq_scales, b_g_idx,
                                 c + last_chunk * size_n, last_chunk_size,
                                 size_n, size_k, last_chunk_size, groups, bit);
    }
  } else {
    gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                         c, size_m, size_n, size_k, bit);
  }

就是这部分代码,虽然现在看来比较明确它主要是走了use_reconstruct=True的分支,但是当时着实看了我很久的时间,要搞清楚走了哪个分支,得先知道这里的size_m代表着什么,它其实表示着输入a的行数,也就是seq_len*batch_size,而vllm在profile的时候,会使用到max_token_len大的seq_len。

大部分应该都是大于MAX_Q_GEMM_ROWS(=50)的,所以大部分是走了use_reconstruct=True的分支。这里我并没有深入研究reconstruct_exllama和reconstruct_gptq之间的差异点在哪,之后可以研究一下。

通过Nvidia的工具包定位问题

虽然代码大概看完了,但是我还是不知道到底是什么函数出问题了呀,那就只能用老法子了,要么打印,要么用profile工具。所以我就问了问GPT,它给我推荐了Nsight Compute,这是Nvidia出的一个工具,可以用来分析cuda程序的性能。吭哧吭哧学习了一下怎么用,然后现实给了我一顿暴击,Nsight Compute不支持Pascal架构,它的2019的版本才支持,但是2019的版本和现在的cuda版本又不兼容,尴尬。。。

不过幸运的是,在学习使用Nsight Compute的时候,我发现了Nsight System,这个也是Nvidia出的一个工具,可以用来分析cuda程序,看CPU和GPU联动的时候,问题出在哪,虽然不会像Nsight Compute那样详细地分析GPU的各个执行过程,但它能简单地分析cuda内核函数的耗时,这个正好是我现在需要的。

上结果:

alt text

图中有两个关键信息:

  1. 大部分的耗时都在2个内核函数上,就是
    maxwell_hgemm_128x128

    maxwell_hgemm_128x64
  2. 在执行这俩函数前,都在执行reconstruct_exllama内核函数。

这样的话就比较容易定位了,就是看reconstruct_exllama后面执行了什么,那不就是
cublasHgemm
么。

和cublasHgemm较劲

经过一番搜索后,我了解了这个函数是啥,然后我就有点楞住了,啊?凭啥?这个是CuBLAS的函数,是Nvidia写的专门用来做向量和矩阵计算的,这怎么会有问题呢?这怎么能有问题呢?

为了验证它,我让GPT帮我写了个比较大的矩阵乘法并计算1000次来验证,结果确实是它的问题,执行起来很慢很慢,在A6000的机器上结果几乎是秒出,而我这边就会卡很久很久。

在这里我卡壳了好久,不知道这种情况下该咋办,感觉Pascal显卡就是该入土了,甚至想放弃了。后面想到,pytorch和xformers的性能不是没啥问题么,那肯定是有法子解决的。

于是我想了一个尝试的路子,我能不能换个库?我就去搜索了一下有没有CuBLAS的替代库。问了下GPT,还真就让我找到了,它就是CUTLASS,一个开源的CuBLAS库。

于是我就吭哧吭哧地又去编译了一下CUTLASS,3.0版本开始的CUTLASS就不支持PASCAL了,所以我只能用2.11版本。编译起来倒是异常丝滑,没有任何问题,和最新的cuda也能兼容。

编译完成后,我还是按照老思路,先找找看它的profile工具,确实有这个工具,于是我就进行了一次profile,就是CUTLASS的这次profile,帮我找到了问题的根因,官方的profile示例给的是用sgemm kernel:
./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096
,我这边测试下来很快5s左右就执行完了,性能指标看着也还行:

Runtime: 15.7136  ms
Memory: 12.4296 GiB/s
Math: 9295.45 GFLOP/s

当时我并不知道sgemm kernel的s表示什么,但我猜到了和精度相关,我一开始还猜是small(其实它表示单精度single-precision),就是精度很低,我就想,之前不是都是hgemm(半精度)么,我也来试试看它的profile是不是有这个kernel,这里纯属手贱,并不是想到了什么。但是就是这么一个意外,帮我找到了本次问题的根因。测试的结果是极其慢:

Runtime: 739.977  ms
Memory: 0.131972 GiB/s
Math: 197.391 GFLOP/s

我当时就在想,这差距也太大了吧,就算是small,也不应该small得这么厉害,能差这么多呀。我就又测了一下dgemm(双精度),结果和hgemm基本类似。

然后我就去确认了一下,sgemm表示的是单精度的运算。到这,我基本上能知道怎么回事了,大概率是Pascal架构不支持半精度的运算,导致计算效率很低。为了验证我这个想法,顺便作为学习,我去翻了Nvidia的官网,找了各个时期的架构白皮书,看了一下里面主要的显卡性能介绍。为了方便比较我先是让LLM帮我从各个白皮书里提取了性能信息,然后让它帮我输出json,我再用pandas将json转成了html方便我直观地对比。

这里给熟悉游戏显卡的同学稍微科普一下Nvidia的架构历史,从Maxwell开始:

  • Maxwell 架构
    • 发布时间:2014年
    • 游戏卡命名:GTX 9xx 系列,如 GTX 970, GTX 980
    • 数据卡命名:Tesla Mxx 系列,如 Tesla M40, Tesla M60
  • Pascal 架构
    • 发布时间:2016年
    • 游戏卡命名:GTX 10xx 系列,如 GTX 1070, GTX 1080, GTX 1080 Ti
    • 数据卡命名:Tesla Pxx 系列,如 Tesla P100
  • Volta 架构
    • 发布时间:2017年
    • 游戏卡:N/A
    • 数据卡命名:Tesla Vxx系列,如 Tesla V100
  • Turing 架构
    • 发布时间:2018年
    • 游戏卡命名:RTX 20xx 系列,如 RTX 2070, RTX 2080, RTX 2080 Ti; GTX 16xx 系列如 GTX 1660, GTX 1660 Ti(不包含RT核的变体)
    • 数据卡命名:Tesla Txx 系列,如 Tesla T4
  • Ampere 架构
    • 发布时间:2020年
    • 游戏卡命名:RTX 30xx 系列,如 RTX 3070, RTX 3080, RTX 3090
    • 数据卡命名:A100, A30
  • Ada Lovelace 架构
    • 发布时间:2022年
    • 游戏卡命名:RTX 40xx 系列,如 RTX 4070, RTX 4080, RTX 4090
    • 数据卡命名:L4
  • Hopper 架构
    • 发布时间:2022年
    • 游戏卡命名:N/A
    • 数据卡命名:H100
  • Blackwell 架构
    • 发布时间:2024年
    • 游戏卡命名:N/A
    • 数据卡命名:B100

alt text

可以看到,Pascal架构的P100并没有fp16的支持, 而要有fp16支持的前提也是tensor core,Pascal架构是没有tensor core,只有cuda core的。然后也能发现,为什么说4090的推理性能能强过A100,因为它的各个算力指标都好于A100,A100强的是它显存大,显存带宽大,有SXM的支持,显卡之间的互联带宽高,所以在训练上有巨大的优势。

这下百分百确定问题所在了,没有fp16的支持,计算能力自然就很弱了。

第四关:优化性能

接下来就是改代码了,我的第一个想法是直接改成fp32的计算,这样计算速度就有保障了。但我还是决定去问一下LLM,看它有什么好的建议。它给我的建议是使用cublasGemmEx函数,这个函数也是CuBLAS的函数,它允许我们的输入输出矩阵都是fp16的,但是在计算的时候,转换成fp32来进行计算。

最后的改动就是这样:

    // cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k,
    //             &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n);
    cublasGemmEx(
      cublas_handle,                // Handle
      CUBLAS_OP_N,                  // transa
      CUBLAS_OP_N,                  // transb
      size_n,                       // m
      size_m,                       // n
      size_k,                       // k
      &alpha,                       // alpha
      temp_dq,                      // A
      CUDA_R_16F,                   // A type
      size_n,                       // lda
      a,                            // B
      CUDA_R_16F,                   // B type
      size_k,                       // ldb
      &beta,                        // beta
      c,                            // C
      CUDA_R_16F,                   // C type
      size_n,                       // ldc
      CUDA_R_32F,                   // computeType (FP32 for accumulation)
      CUBLAS_GEMM_DFALT_TENSOR_OP   // algo (default with potential Tensor Core usage)
    );

结果就如标题所说,这一行代码的更改,让性能提升了43倍,现在再来看一下我之前的pytorch的耗时日志:

model: 3098.3325251191854 ms
  model.embed_tokens: 33.70710372924805 ms
  model.layers: 3064.419405385852 ms
    model.layers.0: 131.46515500545502 ms
      model.layers.0.input_layernorm: 0.6445760130882263 ms
      model.layers.0.self_attn: 30.52022334933281 ms
        model.layers.0.self_attn.qkv_proj: 20.16111946105957 ms
        model.layers.0.self_attn.rotary_emb: 0.16473600268363953 ms
        model.layers.0.self_attn.attn: 3.8500161170959473 ms
        model.layers.0.self_attn.o_proj: 6.344351768493652 ms
      model.layers.0.post_attention_layernorm: 0.22275200486183167 ms
      model.layers.0.mlp: 100.07760363817215 ms
        model.layers.0.mlp.gate_up_proj: 65.92633819580078 ms
        model.layers.0.mlp.act_fn: 0.9378560185432434 ms
        model.layers.0.mlp.down_proj: 33.213409423828125 ms
    model.layers.1: 115.83395344018936 ms
      model.layers.1.self_attn: 17.98700802028179 ms
        model.layers.1.self_attn.rotary_emb: 0.16617600619792938 ms

可以看到,vllm的profile的耗时,从134s降到了3s,性能整整提升了43倍呀!!!

终于可以用我的Pascal显卡来推理了,爽!!

总结

第1点

对于一些程序员新人来说,希望这次的经历能给你一个参考,我们可以从一个问题点(一个好的问题从哪来确实也挺看运气的,我这次的问题刚好就是一个很深的问题,但是有时候我们可以刻意去创造一个问题,比如之前我看spark源码的时候,就是想搞清楚一个job的启动过程到底是怎么样的,这样也算是自己提出的一个好问题了)开始,然后一直深挖下去,这样你就熟悉了从表面一直到内核的整个过程,然后你就可以选择在任意感兴趣的地方开枝散叶,就能熟悉一整个框架乃至领域了。

对于我自己来说,我接下来能研究的就有:

  • 再去研究一下GPTQ的量化过程,把数据原理完全搞懂,有机会的话自己可以跑一遍模型量化
  • 看看GGUF的量化是怎么做的
  • 看看GEMM具体是怎么计算的,有哪些点可以做来进行优化
  • 去看看xformers的注意力计算是怎么做的
  • 去看看vllm的kv cache是怎么做的
  • 也可以去学学cuda编程
  • ...

第2点

LLM在整个过程中起到了很大的作用,包括不限于:

  1. 解释一些源码
  2. 帮忙写部分测试用的代码
  3. 帮忙澄清一些概念
  4. 帮忙解释一些bug
  5. ...

所以,赶紧用起来吧!

第3点

没事别瞎折腾别人不支持的东西,人家不支持是有原因的,除非你有折腾的觉悟和兴趣。

参考资料

  1. https://stackoverflow.com/questions/75682385/runtimeerror-cuda-error-no-kernel-image-is-available-for-execution-on-the-devi
  2. https://docs.vllm.ai/en/latest/usage/compatibility_matrix.html
  3. https://developer.nvidia.com/cuda-gpus
  4. https://github.com/pytorch/pytorch/issues/31285
  5. https://pytorch.org/tutorials/recipes/recipes/benchmark.html
  6. https://github.com/facebookresearch/xformers/blob/main/BENCHMARKS.md
  7. https://blog.csdn.net/fan_fan_feng/article/details/138978901
  8. https://www.bilibili.com/video/BV17m411f7Cm?spm_id_from=333.788.videopod.sections&vd_source=68452628e4137592ea9efa4793a102a6

背景

由于 Natasha 及周边项目发版任务多,文件结构也不简单,之前一直使用基于 Github 管道脚本和 XUnit 来发版。这个方案对于发版环境与条件依赖性较强,且不够灵活,因此萌生出做一个本地管理 Nuget 发版工具的想法,取名为 Jester.

下载地址:
https://github.com/NMSAzulX/Jester.Tools.Nuget/releases/tag/1.0.0.0
若出现问题可在本篇文章下留言,或在仓储地址中
提交 ISSUE
.

运行环境

  1. 采用独立打包。
  2. Win64。

安全说明

目前不打算公开这种工具类的源代码,因此列出相关依赖项以保证用户自主评估安全风险。

依赖项列表:

  1. Microsoft.CodeAnalysis.CSharp
    Roslyn 的 C# 语言构建库, 作者:Microsoft
  2. NuGet.Versioning
    Semantic Versioning 的操作库,作者: Microsoft
  3. Spectre.Console
    控制台界面,作者: Patrik Svensson, Phil Scott, Nils Andresen, Cedric Luthi, Frank Ray.

Spectre.Console 项目地址:
https://github.com/spectreconsole/spectre.console
个人觉得不太好用,还有 BUG.

网络使用:

  1. 使用 "pl" 命令查看打包信息列表时,向 NUGET 请求包信息。
  2. 使用 "pack" 命令打包时,向 NUGET 请求包信息以做校验本地包与服务器包上的版本信息。
  3. 使用 "push" 命令推送包时,向用户指定源服务器或 NUGET 官方源推送包。

文件改变:

  1. 使用 "pack" 命令时,会更新或创建 csproj 文件的版本信息。
  2. 使用 "pack" 命令时,会更新或创建 csproj 文件中的 None Include 节点信息。

工具使用

界面与输入端

命令支持

输入命令 {?} 可以查看改工具支持的命令。

Nuget Key 操作

#查看当前 nuget key 列表
nl

#添加一个 nuget key 发布方案
na publishTest oy2nabcdefghiojxv http://公司的nuget地址
#不指定 source 使用 nuget 官方的地址
na publishTest oy2nabcdefghiojxv 

#删除发布方案
nd publishTest

Solution 操作

该工具以 sln 解决方案为单位,支持添加、删除、更新 NUGET KEY 等操作。

以下为命令使用案例:

#查看当前存在的解决方案
sl 

#增加一个解决方案
sa c://mysolution

#删除列表中第一个方案
sd 0

#让该方案在发布时使用某个 NUGET KEY
#锁定第一个方案
sc 0

检查与打包

该工具以 CHANGELOG.MD 文件为打包发版依据。

#查看当前解决方案列表
sl

#锁定你要操作的解决方案
#锁定列表中第一个解决方案 
#若没有 CHANGELOG 文件工具将自动生成一个文件
#CHANGELOG 中有相应的版本说明和使用案例
sc 0

#按照文件中的案例格式编辑被打开的 CHANGELOG 文件

#使用 pl 查看当前 CHANGELOG 文件与当前解决方案的匹配信息【非必要】
pl

#确认无误,打包
pack

#若不发布隐式 using 文件,则
pack --no-using

#若不合并 targets 文件,则
pack --no-combine

#若有多个 targets 文件,则
#1. 去掉文件中的 <project></project> 节点标签,以便 Jester 合并内容到同一个 <project> 节点下。
#2. 不要使用 Jester.Usings.targets 来命名文件,该文件预留给了 Jester 输出隐式命名空间用。
#3. 不要使用 Jester.Combine.targets 来命名文件,该文件预留给了 Jester 合并所有 targets 文件的输出文件使用。

#后续会根据反馈对 pack 策略做出调整。

pl 命令会自动拉去 nuget 官方记录的版本信息,local 为本地工程文件中的版本信息,next 为 CHANGELOG 被打包的版本信息

发布

#如果之前已经锁定解决方案则不用重新锁定了
#若被锁定的解决方案没有关联发布策略
su publishTest

#查看打包好的包 在 your_sln_folder/nugets 文件夹中

#发布
push

退出

按 Q 退出