2024年7月

本文介绍基于
Python
语言,以一个
大文件夹
作为标准,对另一个
大文件夹
所包含的
子文件夹

文件
加以
查漏补缺
,并将查漏补缺的结果输出的方法。

首先,来明确一下本文所需实现的具体需求。现有一个
大文件夹
,其中包含了大量
子文件夹
,如下图所示。

image

此外,我们还有另一个
大文件夹
,其中同样包含了大量的
子文件夹
,如下图所示;从上图与下图中的紫色框可以看出,这是两个不一样的
大文件夹
。但同时,我们还知道这两个
大文件夹
中的
子文件夹
数量,以及每一个子文件夹的名称,几乎是完全一致的——但是下图所示的大文件夹较之上图,缺少了一些
子文件夹

我们现在希望实现的是,以第一幅图所示的
大文件夹
为标准,对第二幅图所示的
大文件夹
中的
子文件夹
加以查漏补缺,找出第二个
大文件夹
中缺少的
子文件夹
的名称,以及缺少的
子文件夹
的数量。

了解了具体需求,我们就可以开始代码的撰写。这里需要注意,本文比较的是两个
大文件夹

子文件夹
的差异;如果大家希望比较两个
大文件夹

文件
的差异,整体思路也都是一样的,也可以用本文提供的代码。

本文所用到的具体代码如下所示。

# -*- coding: utf-8 -*-
"""
Created on Tue Feb 21 17:12:47 2023

@author: fkxxgis
"""

import os

template_folder = r"E:\02_Project\01_Chlorophyll\Fishnet\ResultFolder"
new_folder = r"E:\02_Project\01_Chlorophyll\Fishnet\ResultFolder_AI"

folder_list = os.listdir(template_folder)
new_list = os.listdir(new_folder)

num = 0
for folder in folder_list:
    if folder not in new_list:
        num += 1
        print(folder, "is not in new folder!")
print("\n", num, " folder(S) is(are) not in new folder!", sep = "")

可以看到,代码整体也是非常简单的。首先,
template_folder
是我们作为标准的
大文件夹
,也就是本文开头第一幅图所示的
文件夹
;而
new_folder
则是需要对其中
子文件夹
加以查漏补缺的
大文件夹
,也就是本文开头第二幅图所示的
文件夹

首先,介绍一下代码的整体思路。

首先,我们基于
os.listdir()
函数,遍历标准
大文件夹
中的每一个
子文件夹
,获取每一个
子文件夹
的名称,并将其存放在一个列表中;接下来,我们通过同样的方式,获取待查漏补缺的
大文件夹
中的
子文件夹
名称,同样存放在一个列表中。接下来,我们即可开始对比两个
大文件夹

子文件夹
的数量差异。首先,设置一个变量
num
,作为
子文件夹
数量差异的计算变量;随后,通过一个
for
循环,依次取出标准
大文件夹

子文件夹
的名称,并在待查漏补缺的
大文件夹
对应的
子文件夹
名称列表中加以搜索;如果找不到当前名称的
子文件夹
,说明在第二个
大文件夹
中就少了这一
子文件夹
,因此需要将其名称输出,并在变量
num
中增加
1
。完成上述循环后,我们即可获得第二个
大文件夹
,也就是待查漏补缺的
大文件夹
中,所缺少的
子文件夹
的名称以及其数量。

其次,代码详细的逐句介绍如下。

第一部分,我们需要导入所需的
Python
内置模块
os
,其用于与操作系统进行交互,在本文中就是进行读取文件列表等操作。

随后,我们指定了一个文件夹路径,存储在变量
template_folder
中;该文件夹是我们作为标准的
大文件夹
,即本文开头第一幅图所示的
文件夹
。接下来,我们继续指定另一个文件夹路径,存储在变量
new_folder
中。该文件夹就是需要对其中
子文件夹
加以查漏补缺的
大文件夹
,也就是本文开头第二幅图所示的
文件夹

随后,使用
os.listdir()
函数获取作为标准的
大文件夹
中,所有的文件和文件夹的列表,并将其存储在变量
folder_list
中;同样的方法,使用
os.listdir()
函数获取另一个文件夹中的所有文件和文件夹的列表,并将其存储在变量
new_list
中。

接下来,我们初始化一个变量
num
,用于计数在模板文件夹中存在,但在新文件夹中不存在的文件夹的数量。随后,即可开始循环,遍历模板文件夹中的每个文件夹,并使用条件判断语句检查这个文件夹是否存在于新文件夹中——如果文件夹不在新文件夹中,则执行以下操作:第一步,将变量
num
的值增加
1
,用于计数不存在于新文件夹中的文件夹的数量;第二步,打印当前文件夹的名称,以及附加的文本信息。

最后,我们打印最终的结果,显示不存在于新文件夹中的文件夹的数量。

运行上述代码,将会得到如下所示的结果。

代码非常简单,到这里就结束了;如果大家还有其他需求,可以自行再扩充代码。例如,如果希望将待查漏补缺的
大文件夹
中缺少的
子文件夹
复制过来,则可以参考文章
Python结合文件名称将多个文件复制到不同路径下
中所提到的代码思路加以实现。

至此,大功告成。

煤矿安全大模型————矿途智护者

使用煤矿历史事故案例,事故处理报告、安全规程规章制度、技术文档、煤矿从业人员入职考试题库等数据,微调internlm2模型实现针对煤矿事故和煤矿安全知识的智能问答。

本项目简介:

近年来,国家对煤矿安全生产的重视程度不断提升。为了确保煤矿作业的安全,提高从业人员的安全知识水平显得尤为重要。鉴于此,目前迫切需要一个高效、集成化的解决方案,该方案能够整合煤矿安全相关的各类知识,为煤矿企业负责人、安全管理人员、矿工提供一个精确、迅速的信息查询、学习与决策支持平台。
为实现这一目标,我们利用包括煤矿历史事故案例、事故处理报告、安全操作规程、规章制度、技术文档以及煤矿从业人员入职考试题库等在内的丰富数据资源,通过微调InternLM2模型,构建出一个专门针对煤矿事故和煤矿安全知识智能问答的煤矿安全大模型。

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

本项目的特点如下:

  • 支持煤矿安全领域常规题型解答,如:单选题、多选题、判断题、填空题等 (针对煤矿主要负责人及安管人员、煤矿各种作业人员)
  • 支持针对安全规程规章制度、技术等文档内容回答(如《中华人民共和国矿山安全法》、《煤矿建设安全规程》)
  • 支持煤矿历史事故案例,事故处理报告查询,提供事故原因详细分析、事故预防措施以及应急响应知识
类别 底座 名称 版本 下载链接 微调方法
对话模型 InternLM2-Chat-1_8B CoalMineLLM_InternLM2-Chat-1_8B V1.0 OpenXLab QLora
对话模型 InternLM2-Chat-7B CoalMineLLM_InternLM2-Chat-7B V1.0 OpenXLab QLora
对话模型 InternLM2-Math-7B CoalMineLLM_InternLM2-Math-7B V1.0 OpenXLab QLora
对话模型 InternLM2-Chat-20B CoalMineLLM_InternLM2-Chat-20B V1.0 OpenXLab QLora
对话模型 InternLM2-Chat-1_8B CoalMineLLM_InternLM2-Chat-1_8B-full V1.0 OpenXLab 全参微调
对话模型 CoalMineLLM_InternLM2-Chat-7B CoalMineLLM_InternLM2-Chat-7B-4bit V1.0 OpenXLab W4A16量化

1.架构图

2. 快速使用

2.1 本地Demo部署

git clone https://github.com/yaosenJ/CoalQA.git
cd CoalQA
conda create -n CoalQA python=3.10.0 -y
conda activate CoalQA
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt
cd web_app
streamlit run streamlit_app.py --server.address=127.0.0.1 --server.port 6006

2.2 在线体验

CoalMineLLM_InternLM2-Chat-1_8版体验地址:
https://openxlab.org.cn/apps/detail/milowang/CoalQAv1

3.详细指南

3.1 环境配置

git clone https://github.com/yaosenJ/CoalQA.git
cd CoalQA
conda create -n CoalQA python=3.10.0 -y
conda activate CoalQA
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt

3.2 数据构造

3.2.1 从规程规章制度等文件中生成InternLM2模型
Xtuner
训练数据格式

1.1 输入如下命令,会在data目录下生成CoalQA_data.json

python  Generate_QAdata.py  ../data/CoalQA_data.json

CoalQA_data.json内容结构形如:

[
  {
    "input": "请提供新建矿井立井井筒冻结法施工的具体要求。",
    "output": "新建矿井立井井筒冻结法施工需要遵守以下要求:冻结深度必须穿过风化带延深至稳定的基岩10m以上,第一个冻结孔必须全孔取芯,钻孔时必须测定钻孔的方向和偏斜度,偏斜度超过规定时必须及时纠正,冻结管必须采用无缝钢管并焊接或螺纹连接,开始冻结后必须经常观察水文观测孔的水位变化,并在确定冻结壁已交圈后才能进行试挖。"
  },
  ...
  
]

1.2 接上面生成json,处理成Xtuner可训练数据格式,运行下面命令即可

python format_internlm2.py  
[
    {
        "conversation": [
            {
                "system": "你是一个煤矿安全领域的知识达人,你对相关煤矿安全规章规程制度、技术等文档非常熟悉。请你专业正确地解答用户想问的煤矿安全相关问题。",
                "input": "请提供新建矿井立井井筒冻结法施工的具体要求。",
                "output": "新建矿井立井井筒冻结法施工需要遵守以下要求:冻结深度必须穿过风化带延深至稳定的基岩10m以上,第一个冻结孔必须全孔取芯,钻孔时必须测定钻孔的方向和偏斜度,偏斜度超过规定时必须及时纠正,冻结管必须采用无缝钢管并焊接或螺纹连接,开始冻结后必须经常观察水文观测孔的水位变化,并在确定冻结壁已交圈后才能进行试挖。"
            }
        ]
    },
    {
        "conversation": [
            {
                "system": "你是一个煤矿安全领域的知识达人,你对相关煤矿安全规章规程制度、技术等文档非常熟悉。请你专业正确地解答用户想问的煤矿安全相关问题。",
                "input": "煤矿安全规程中,对于新建矿井的立井井筒冻结法施工有哪些具体要求?",
                "output": "新建矿井立井井筒冻结法施工需要遵守以下要求:冻结深度必须穿过风化带延深至稳定的基岩10m以上,第一个冻结孔必须全孔取芯,钻孔时必须测定钻孔的方向和偏斜度,偏斜度超过规定时必须及时纠正,冻结管必须采用无缝钢管并焊接或螺纹连接,开始冻结后必须经常观察水文观测孔的水位变化,并在确定冻结壁已交圈后才能进行试挖。"
            }
        ]
    },
...
]

3.2.2 从整理好的题库生成InternLM2模型
Xtuner
训练数据格式

输入如下命令即可

python Generate_Question_bank.py

这里展示多选题生成

[
    {
        "conversation": [
            {
                "system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个多选题,你能从几个选项中,选择出多个正确选项。",
                "input": "掘进巷道在下列哪些情况下不能爆破()。\nA、掘进工作面或炮眼有突水预兆时\nB、探水孔超前距不够时\nC、空顶距超过规定时\nD、掘进工作面支架不牢固时",
                "output": "ABCD"
            }
        ]
    },
    ...
]

若想生成其他题目类型训练数据,请在相应位置替换成需要的内容

csv_filename = '../data/多选题.csv'
#csv_filename = '../data/单选题.csv'
#csv_filename = '../data/判断题.csv'
#csv_filename = '../data/填空题.csv'
#csv_filename = '../data/简答题.csv'

json_filename = '../data/multiple_choice.json'
#son_filename = '../data/single_choice.json'
#json_filename = '../data/true_or_false.json'
#json_filename = '../data/fill_in.json'
#json_filename = '../data/shot_answer.json'

 "system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个多选题,你能从几个选项中,选择出多个正确选项。"
#"system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个单选题,你能从几个选项中,选择出一个正确选项。"
#"system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个判断题,然后你作出判断,是正确还是错误。"
#"system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个填空题,然后你在题目中括号中,填写合适的答案"
#"system": "你是一个煤矿安全领域的知识达人,你对相关煤矿安全规章规程制度、技术等文档非常熟悉。请你专业正确地解答用户想问的煤矿安全相关问题。"

3.2.3 煤矿事故多轮对话生成

"""
""{accident}""
通过上面提供的事故调查事故报告,请你帮我生成多轮对话文本,格式为[
    {
        "conversation": [
            {
                "system": "你是一名煤矿安全领域的知识达人,提供有关煤矿安全规程、事故预防措施和应急响应的知识。",
                "input": "你好!",
                "output": "您好,我是一名煤矿安全领域的知识达人,请问有什么可以帮助您?"
            },
            {
                "input": " ",
                "output": " "
            },...
        ]
    }
]
"""

{
        "conversation": [
            {
                "system": "你是一名煤矿安全领域的知识达人,提供有关煤矿安全规程、事故预防措施和应急响应的知识。",
                "input": "你好!",
                "output": "您好,我是一名煤矿安全领域的知识达人,请问有什么可以帮助您?"
            },
            {
                "input": "2023年3月22日子长县中达焦家沟煤业有限公司发生了什么事故?",
                "output": "2023年3月22日2时56分左右,子长县中达焦家沟煤业有限公司5309轨道顺槽发生其他事故,造成1人死亡,直接经济损失175万元。"
            },
            {
                "input": "事故的直接原因是什么?",
                "output": "事故的直接原因是矿井将不符合公称压力强度要求的管路作为瓦斯抽放管使用,在对瓦斯抽放管进行气密性试验的过程中瓦斯抽放管连接法兰处漏气,现场作业人员违反安全技术措施在瓦斯抽放管未泄压完毕的情况下用梅花扳手紧固漏气的法兰螺丝,维修过程中管路受力失衡爆裂,作业人员被喷出的冲击波击中,导致事故发生。"
            },
            {
                "input": "事故的间接原因有哪些?",
                "output": "事故的间接原因包括:\n1. 安全生产责任落实不力,管理制度执行不严格。\n2. 现场安全管理差,违章冒险作业。\n3. 技术管理有漏洞,技术人员业务素质低。\n4. 安全教育培训效果差,职工安全意识淡薄。\n5. 安全监管存在不足。"
            },
            {
                "input": "如何改进企业的安全生产责任落实?",
                "output": "改进企业的安全生产责任落实的方法包括:\n1. 明确设备材料管理制度,确保采购把关严格,合同明确产品具体型号和使用范围。\n2. 制定并落实管路使用和回收的管理制度,明确责任主体。\n3. 加强日常维护保养,确保设备材料的安全使用。\n4. 严格执行《安全技术措施审批制度》等安全管理制度,确保安全生产责任落实到位。"
            },
            {
                "input": "如何加强现场安全管理,避免违章冒险作业?",
                "output": "加强现场安全管理的方法包括:\n1. 全面分析和辨识作业风险,制定详细的安全技术措施。\n2. 确保现场作业人员严格遵守安全操作规程,禁止违章冒险作业。\n3. 增强现场监督管理,确保有安全员和负责人进行现场安全监督。\n4. 实现专业队伍化,确保瓦斯抽放管路敷设和测试等工作由专业队伍执行。"
            },
            {
                "input": "如何提高技术管理水平?",
                "output": "提高技术管理水平的方法包括:\n1. 加强安全技术规范的学习和掌握,确保技术人员了解相关标准和要求。\n2. 科学编制安全技术措施,选择更安全的工艺和方法。\n3. 严格执行安全技术措施会审制度,确保措施科学合理。\n4. 确保作业人员严格执行安全技术措施,避免违章操作。\n5. 进行全面的风险分析和研判,辨识并防范可能的安全风险。"
            },
            {
                "input": "如何改进安全教育培训效果?",
                "output": "改进安全教育培训效果的方法包括:\n1. 制定详细的培训计划,确保职工全面掌握安全技术措施。\n2. 提高培训的实际操作性,通过实操训练提高职工的安全意识和技能。\n3. 加强对作业环境的安全风险辨识培训,提升职工的自保互保意识。\n4. 定期进行培训考核,检验培训效果,确保培训内容深入人心。"
            },
            {
                "input": "如何增强安全监管力度?",
                "output": "增强安全监管力度的方法包括:\n1. 明确驻矿安检员的职责,确保其全面掌握煤矿安全生产状况。\n2. 加强安全监管分工,确保各项监管工作责任到人。\n3. 进行定期和不定期的安全检查,发现并整改安全隐患。\n4. 建立健全安全监管考核制度,确保安全监管工作落实到位。"
            }
        ]
    }

使用GLM-4模型,构建煤矿事故知识图谱。暂时不开源

  • 补充细节
  • 1.合并两个json文件的脚本:merge_json.py

  • 2.格式化json文本的脚本:format_json.py

  • 3.打乱json中数据顺序的脚本:shuffle.py

相关数据请见data目录:
安全知识的智能问答

4. 模型微调

4.1 Internlm2微调

  • 环境配置

创建环境

conda create -n internlm2 python=3.10
conda activate internlm2
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia

环境包的安装

cd ~
git clone -b v0.1.18 https://github.com/InternLM/XTuner
cd XTuner
pip install [-e .]
  • 下载本项目仓库
git clone https://github.com/yaosenJ/CoalQA.git
  • 下载模型

进入finetune目录

cd CoalQA/finetune

执行如下命令,下载internlm2-chat-7b模型参数文件:

python download_model.py

4.2 模型微调过程详细

本文档提供了使用 XTuner 工具进行模型微调过程的详细指南。该过程包括转换、合并、训练以及为不同规模的模型(1.8B 和 20B)设置网络演示。

  • 要求
    • XTuner
    • DeepSpeed
    • Huggingface Transformers
    • 具备 SSH 和 Git 的使用权限

4.2.1 环境安装

#如果你是在 InternStudio 平台,则从本地 clone 一个已有 pytorch 的环境:
#pytorch    2.0.1   py3.10_cuda11.7_cudnn8.5.0_0

studio-conda xtuner0.1.17
#如果你是在其他平台:
#conda create --name xtuner0.1.17 python=3.10 -y

#激活环境
conda activate xtuner0.1.17
#进入家目录 (~的意思是 “当前用户的home路径”)
cd ~
#创建版本文件夹并进入,以跟随本教程
mkdir -p /root/xtuner0117 && cd /root/xtuner0117

#拉取 0.1.17 的版本源码
git clone -b v0.1.17  https://github.com/InternLM/xtuner
#无法访问github的用户请从 gitee 拉取:
#git clone -b v0.1.15 https://gitee.com/Internlm/xtuner

#进入源码目录
cd /root/xtuner0117/xtuner

#从源码安装 XTuner
pip install -e '.[all]'

4.2.2 1.8B 模型训练

  • 数据准备
#在ft这个文件夹里再创建一个存放数据的data文件夹,存储数据
mkdir -p /root/ft/data && cd /root/ft/data
  • 准备模型
#创建目标文件夹,确保它存在。
#-p选项意味着如果上级目录不存在也会一并创建,且如果目标文件夹已存在则不会报错。
mkdir -p /root/ft/model

#复制内容到目标文件夹。-r选项表示递归复制整个文件夹。
cp -r /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b/* /root/ft/model/

如果是需要自己下载,可以使用transformers库

from transformers import AutoModel

#指定模型名称
model_name = 'internlm/internlm2-chat-1_8b'

#加载模型
model = AutoModel.from_pretrained(model_name)

#指定保存模型的目录
model_save_path = '/root/ft/model'

#保存模型
model.save_pretrained(model_save_path)

将这段代码保存为
download_model.py
,然后在命令行中运行这个脚本:

python download_model.py

这个脚本会自动下载模型并将其保存到指定的
/root/ft/model
目录中。

  • 下载配置文件
#XTuner 提供多个开箱即用的配置文件,用户可以通过下列命令查看:
#列出所有内置配置文件
#xtuner list-cfg

#假如我们想找到 internlm2-1.8b 模型里支持的配置文件
xtuner list-cfg -p internlm2_1_8b

#创建一个存放 config 文件的文件夹
mkdir -p /root/ft/config

#使用 XTuner 中的 copy-cfg 功能将 config 文件复制到指定的位置
xtuner copy-cfg internlm2_1_8b_qlora_alpaca_e3 /root/ft/config
  • 修改配置参数
#修改模型地址(在第27行的位置)
- pretrained_model_name_or_path = 'internlm/internlm2-1_8b'
+ pretrained_model_name_or_path = '/root/ft/model'

#修改数据集地址为本地的json文件地址(在第31行的位置)
- alpaca_en_path = 'tatsu-lab/alpaca'
+ alpaca_en_path = '/root/ft/data/personal_assistant.json'

#修改max_length来降低显存的消耗(在第33行的位置)
- max_length = 2048
+ max_length = 1024

#减少训练的轮数(在第44行的位置)
- max_epochs = 3
+ max_epochs = 2

#增加保存权重文件的总数(在第54行的位置)
- save_total_limit = 2
+ save_total_limit = 3

#修改每多少轮进行一次评估(在第57行的位置)
- evaluation_freq = 500
+ evaluation_freq = 300

#修改具体评估的问题(在第59到61行的位置)

#把 OpenAI 格式的 map_fn 载入进来(在第15行的位置)
- from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
+ from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory

#将原本是 alpaca 的地址改为是 json 文件的地址(在第102行的位置)
- dataset=dict(type=load_dataset, path=alpaca_en_path),
+ dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),

#将 dataset_map_fn 改为通用的 OpenAI 数据集格式(在第105行的位置)
- dataset_map_fn=alpaca_map_fn,
+ dataset_map_fn=None,
  • 模型训练
#指定保存路径
xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train

#使用 deepspeed 来加速训练
xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train_deepspeed --deepspeed deepspeed_zero2
  • 转换到 Huggingface 格式
  1. 创建目录
    :为转换后的 Huggingface 模型创建一个存储目录:

    mkdir -p /root/ft/huggingface/i8000
    
  2. 模型转换
    :使用提供的配置和权重文件进行模型转换:

    xtuner convert pth_to_hf /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py /root/ft/train_deepspeed/iter_18000.pth /root/ft/huggingface/i8000 --fp32
    
  3. 合并模型
    :合并模型并解决依赖关系:

    mkdir -p /root/ft/final_model_8000
    export MKL_SERVICE_FORCE_INTEL=1
    xtuner convert merge /root/ft/model /root/ft/huggingface/1i8000 /root/ft/final_model_18000
    
  4. 测试模型
    :通过启动对话来测试模型:

    xtuner chat /root/ft/final_model_18000 --prompt-template internlm2_chat
    
  • 模型续训
xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train_deepspeed --resume /root/ft/train_deepspeed/iter_8500.pth  --deepspeed deepspeed_zero1
  • 网络演示设置
  1. 准备环境

    mkdir -p /root/ft/web_demo && cd /root/ft/web_demo
    git clone https://github.com/InternLM/InternLM.git
    cd /root/ft/web_demo/InternLM
    
  2. 运行演示
    使用 Streamlit:

    streamlit run /root/ft/web_demo/InternLM/chat/web_demo.py --server.address 127.0.0.1 --server.port 6006
    
  3. 通过 SSH 隧道访问演示

    ssh -CNg -L 6006:127.0.0.1:6006 root@ssh.intern-ai.org.cn -p 开发机端口号
    

4.2.3. 20B 模型训练

与1.8B模型训练过程类似,20B模型训练涉及到为配置、数据和最终模型创建相应的目录。此外,这一过程还包括使用多个GPU进行模型训练,并将模型转换为Huggingface格式。

  • 数据准备

为大规模的20B模型训练准备数据。

#创建一个专用于存放20B模型数据的目录
mkdir -p /root/ft20b/data && cd /root/ft20b/data
  • 准备模型

准备模型包括创建目标文件夹并将预训练的20B模型复制到指定位置。

#创建一个目录用来存放20B模型文件
mkdir -p /root/ft20b/model

#将预训练的模型复制到新创建的目录中
cp -r /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-20b/* /root/ft20b/model/
  • 下载配置文件

下载并准备20B模型的配置文件,以便进行训练。

#列出所有支持20B模型的配置文件
xtuner list-cfg -p internlm2_20b

#创建一个目录用于存放20B模型的配置文件
mkdir -p /root/ft20b/config

#复制所需的配置文件到新创建的目录中
xtuner copy-cfg internlm2_20b_qlora_alpaca_e3 /root/ft20b/config
  • 修改配置参数

根据训练需求调整配置文件,以优化20B模型的训练。

#修改模型路径和数据集路径等关键参数以适配20B模型
- pretrained_model_name_or_path = 'internlm/internlm2-20b'
+ pretrained_model_name_or_path = '/root/ft20b/model'

- alpaca_en_path = 'tatsu-lab/alpaca'
+ alpaca_en_path = '/root/ft20b/data/specific_dataset.json'

- max_length = 2048
+ max_length = 1024

- max_epochs = 3
+ max_epochs = 2

- save_total_limit = 2
+ save_total_limit = 3

- evaluation_freq = 500
+ evaluation_freq = 300
  • 模型训练

使用DeepSpeed和多GPU配置来加速20B模型的训练过程。

#指定保存路径并开始训练
xtuner train /root/ft20b/config/internlm2_20b_qlora_alpaca_e3_copy.py --work-dir /root/ft20b/train_deepspeed --deepspeed deepspeed_zero2
  • 转换到 Huggingface 格式

为转换后的Huggingface模型创建目录并执行转换。

#创建一个目录用于存放转换后的Huggingface模型
mkdir -p /root/ft20b/huggingface

#执行模型转换
xtuner convert pth_to_hf /root/ft20b/config/internlm2_20b_qlora_alpaca_e3_copy.py /root/ft20b/train_deepspeed/iter_2600.pth /root/ft20b/huggingface
  • 2.7 模型合并

合并转换后的模型并解决依赖关系。

#创建一个名为final_model的目录以存储合并后的模型文件
mkdir -p /root/ft20b/final_model

#合并模型
xtuner convert merge /root/ft20b/model /root/ft20b/huggingface /root/ft20b/final_model
  • 测试模型

通过启动对话来测试合并后的模型。

#启动与模型的对话测试
xtuner chat /root/ft20b/final_model --prompt-template

 internlm2_chat

这一部分提供了详细的指导,确保20B模型的训练过程得到妥善管理和执行。

4.2.4 微调20b配置样例

max_length = 4096
pack_to_max_length = True

#parallel
sequence_parallel_size = 1

#Scheduler & Optimizer
batch_size = 4  # per_device
accumulative_counts = 16
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 0
max_epochs = 50

=》

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:89:00.0 Off |                    0 |
| N/A   65C    P0             334W / 400W |  59119MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   66C    P0             358W / 400W |  59119MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

4.2.5 其他注意事项

  • 单卡训完的,不可以在双卡上续训

原因是:

问题的根源
:尝试加载的模型检查点是在数据并行(DP)世界大小为1(即单个GPU或单个训练进程)的环境中保存的。但当前尝试恢复训练的环境具有数据并行世界大小为2(即两个GPU或两个训练进程)。

ZeRO的限制
:DeepSpeed的ZeRO优化器分区(ZeRO-Optimizer State Partitioning)依赖于特定的世界大小配置,并且目前不支持自动调整新的世界大小。换句话说,如果你在一个GPU上训练并保存了检查点,那么在加载这个检查点进行恢复训练时,你必须在相同数量的GPU上进行。

  • 性能最优配置
    包括设置最大序列长度、批量大小及其他 DeepSpeed 特定设置。

5. 检索增强生成RAG

模块目的:根据用户的问题,检索对应信息以增强回答的专业性, 使CoalQA的回答更加专业可靠。检索内容包括但不限于以下几点:

  • 煤矿安全相关的问答对

  • 煤矿安全相关的案例

5.1 技术路线:

这个RAG系统各部分的技术路线分别是:

  • Embedding 模型:BAAI/bge-small-en-v1.5

  • LLM基座:InternLM2-Chat-1.8B InternLM2-Chat-7B InternLM2-Chat-20B InternLM2-Math-7B

  • 使用BAAI/bge-reranker-large做检索后精排

  • 向量数据库:


    • FAISS:是Facebook开源的一个高效的向量相似性搜索库。它为大规模向量检索提供了多种索引结构和搜索算法,能够快速查找与给定向量最相似的向量。FAISS擅长处理超高维度的向量数据,在图像检索、推荐系统等场景有广泛应用。

5.2 环境搭建:

git clone https://github.com/yaosenJ/CoalQA.git
cd rag
pip install -r requirements.txt

5.3 本地数据集向量化

PDF 数据

将需要向量化的pdf数据放到对应的目录下

JSON 数据

构建 QA 对并生成 JSON 文件(单轮/多轮对话),放入到对应目录下

数据格式如下

#单轮对话
{
  "conversation": [
      {
          "system": "你是一个煤矿安全领域的知识达人,你会解答很多题。用户给出一个判断题,然后你作出判断,是正确还是错误。",
          "input": "国家对在改善安全生产条件、防止生产安全事故、参加抢险救护等方面取得显著成绩的单位给予奖励;个人给予鼓励。",
          "output": "错误"
      }
  ]
},
#多轮对话
{
  "conversation": [
      {
          "system": "你是一名煤矿安全领域的知识达人,提供有关煤矿安全规程、事故预防措施和应急响应的知识。",
          "input": "你好!",
          "output": "您好,我是一名煤矿安全领域的知识达人,请问有什么可以帮助您?"
      },
      {
          "input": "2022年9月11日晋能控股集团山西三元煤业股份有限公司下霍煤矿井下无轨胶轮车大巷车场发生了什么事故?",
          "output": "2022年9月11日,晋能控股集团山西三元煤业股份有限公司下霍煤矿井下无轨胶轮车大巷车场发生了一起运输事故,造成1人死亡。"
      },
      {
          "input": "请告诉我这起事故的直接原因是什么?",
          "output": "事故直接原因是井下进行设备换装作业时,机电机运队跟班副队长池晶晶在未采取并确认有效防倒措施的情况下,违章指挥并参与卸车作业,且其本人处在竖装设备倾倒危险区域内,被突然倾倒的端头支架前梁和侧护板砸压受伤致死。"
      },
      {
          "input": "那么,这起事故的间接原因有哪些?",
          "output": "事故间接原因包括现场监督管理严重缺失、规程措施审批把关不严、职工安全风险意识淡薄、上级主体企业责任落实不到位、地方安全监管部门监管不到位。"
      },
      {
          "input": "为了避免类似事故,煤矿企业应该采取哪些措施?",
          "output": "为了避免类似事故,煤矿企业应严格执行安全规程,加强现场安全管理,确保安全技术措施到位,强化安全风险研判,建立完善的安全技术管理体系,落实安全生产主体责任,加强安全教育培训,提高职工安全意识和风险辨识能力,同时加强安全监管,确保监管工作细致认真。"
      },
      {
          "input": "感谢您的详细解答!",
          "output": "不客气,很高兴能帮助到您。如果您还有其他问题,欢迎继续提问。"
      }
  ]
},

5.4 构建向量数据库

  • 1.配置文件修改

根据需要改写 config.config 文件:

#llm存放位置
model_dir = os.path.join(base_dir, 'model')   

#向量化模型路径以及模型名称
embedding_path = os.path.join(model_dir, 'embedding_model')         # embedding
embedding_model_name = 'BAAI/bge-small-zh-v1.5'

#精排模型路径以及模型名称
rerank_path = os.path.join(model_dir, 'rerank_model')  	        	  # embedding
rerank_model_name = 'BAAI/bge-reranker-large'

#召回documents数量
retrieval_num = 3

#精排后最终选择留下的documents数量
select_num = 3

prompt_template = """
    你是一个乐于助人的问答代理人。\n
    你的任务是分析并综合检索回来的信息,从而提供有意义且高效的答案。
	{content}
	问题:{query}
"""
  • 2.本地调用

运行构建本地知识库脚本

python data_generate.py

向量化主要步骤如下:

  • 加载pdf数据集并提取文本

  • 利用RecursiveCharacterTextSplitter按照一定块的大小以及块之间的重叠大小对文本进行分割。

  • 加载 BAAI/bge-small-en-v1.5 模型

  • 根据文档集构建FAISS索引(即高性能向量数据库)

5.5 相关文本召回与精排

利用faiss找出与用户输入的问题最相关的文档,然后将召回出来的文本与用户原始输入拼接输入给llm。检索代码如下:

def get_retrieval_content(self, querys) -> str:
        """
            Input: 用户提问, 是否需要rerank
            ouput: 检索后的内容        
        """
        #print(querys)
        output = []
        content = []
        for query in querys:
            
            documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
            
            for doc in documents:
                content.append(doc.page_content)
            logger.info(f'Contexts length:{len(content)}')
            if self.rerank_flag:
                model = self.data_processing_obj.load_rerank_model()
                documents = self.data_processing_obj.rerank(model, query, content, self.select_num)

                for doc in documents:
                    output.append(doc)
                logger.info(f'Selected contexts length:{len(output)}')
                logger.info(f'Selected contexts: {output}')
            else:
                logger.info(f'Selected contexts: {content}')
        return output if self.rerank_flag else content

5.6 RAG具体流程小结

  • 根据数据集构建 vector DB

  • 对用户输入的问题进行 embedding

  • 基于 embedding 结果在向量数据库中进行检索

  • 对召回数据重排序

  • 依据用户问题和召回数据生成最后的结果

5.7 使用Neo4j和Langchain集成非结构化和图知识增强煤矿事故QA

6. 部署

6.1 本地部署

  • 直接使用pytorch原生加载streamlit应用
   cd CoalQA
   conda create -n CoalQA python=3.10.0 -y
   conda activate CoalQA
   conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
   pip install -r requirements.txt
   cd web_app
   streamlit run streamlit_app.py --server.address=127.0.0.1 --server.port 6006

6.2 openxlab部署:

直接使用pytorch原生加载streamlit应用

streamlit run app.py --server.address=127.0.0.1 --server.port 6006
  • 登陆 OpenXLab,创建 Streamlit 应用

  • 选择配置,创建应用,如果需要更多的硬件资源,在这里进行申请

6.3 基于
LMDeploy
的量化部署:

  • LMDeploy简介

LMDeploy 由
MMDeploy

MMRazor
团队联合开发,是涵盖了 LLM 任务的全套轻量化、部署和服务解决方案。 这个强大的工具箱提供以下核心功能:

  • 高效推理:LMDeploy 通过引入持久批处理(又称连续批处理)、阻塞式 KV 缓存、动态拆分与融合、张量并行、高性能 CUDA 内核等关键功能,将请求吞吐量提高到 vLLM 的 1.8 倍。

  • 有效量化:LMDeploy 支持只加权量化和 k/v 量化,4 位推理性能是 FP16 的 2.4 倍。量化质量已通过 OpenCompass 评估确认。

  • 轻松分发服务器:利用请求分发服务,LMDeploy 可在多台机器和卡上轻松高效地部署多模型服务。

  • 交互式推理模式:通过缓存多轮对话过程中的关注度 k/v,引擎可记住对话历史,从而避免重复处理历史会话。

6.3.1 环境安装

pip安装:

pip install lmdeploy

自 v0.3.0 起,默认预编译包在
CUDA 12
上编译。不过,如果需要
CUDA 11+
,可以通过以下方式安装 lmdeploy:

export LMDEPLOY_VERSION=0.3.0
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118

6.3.2 使用LMDeploy与模型对话

使用LMDeploy与模型进行对话,可以执行如下命令运行下载的1.8B模型

lmdeploy chat /group_share/internlm2_chat_1_8b_qlora_18000

6.3.3 LMDeploy模型量化(lite)

  • 设置最大KV Cache缓存大小

通过 --cache-max-entry-count参数,控制KV缓存占用剩余显存的最大比例为0.5

lmdeploy chat /group_share/internlm2_chat_1_8b_qlora_18000 --cache-max-entry-count 0.5
  • 使用W4A16量化

LMDeploy使用AWQ算法,实现模型4bit权重量化。推理引擎TurboMind提供了非常高效的4bit推理cuda kernel,性能是FP16的2.4倍以上。它支持以下NVIDIA显卡:

  • 图灵架构(sm75):20系列、T4
  • 安培架构(sm80,sm86):30系列、A10、A16、A30、A100
  • Ada Lovelace架构(sm90):40 系列

运行前,首先安装一个依赖库。

pip install einops==0.7.0

仅需执行一条命令,就可以完成模型量化工作。

lmdeploy lite auto_awq \
   /group_share/internlm2_chat_1_8b_qlora_18000  \
  --calib-dataset 'ptb' \
  --calib-samples 128 \
  --calib-seqlen 1024 \
  --w-bits 4 \
  --w-group-size 128 \
  --work-dir /group_share/internlm2_chat_1_8b_qlora_18000-4bit

6.3.4 LMDeploy服务(serve)

通过以下lmdeploy命令启动API服务器,推理模型:

lmdeploy serve api_server \
    /group_share/internlm2_chat_1_8b_qlora_18000-4bit \
    --model-format hf \
    --quant-policy 0 \
    --server-name 0.0.0.0 \
    --server-port 23333 \
    --tp 1

即可以得到FastAPI的接口

7.案例展示

项目代码:

安全知识的智能问答-安全大模型

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

1.新建项目,下载Nuget安装包

创建项目需要注意几点,如果是基于 .net framework 的项目 需要选择 相应版本的 EF, 如果是跨平台则选择EF Core版本。

我这里选择的是 .net framework 版本。红框里面是 实现EF Code First 需要的包。

对应的版本:

EntityFramework 6.3.0

MySql.Data 6.8.8

MySql.Data.Entities 6.8.3

如果是连接SqlServer 很简单,直接下载 EntityFramework 6.3.0 这个一个包就行了。程序集会引入这两个组件。然后编写代码即可。

针对MySQL 需要再下载这两个包

下载完成后设置App.config或者 Web.config 文件
,这一步下载包的时候一般会自动添加,如果没有的话就手动加一下

  <entityFramework>
    <providers>
    <provider invariantName="System.Data.SqlClient" type="System.Data.Entity.SqlServer.SqlProviderServices, EntityFramework.SqlServer" />
    <provider invariantName="MySql.Data.MySqlClient" type="MySql.Data.MySqlClient.MySqlProviderServices, MySql.Data.Entity.EF6, Version=6.8.3.0, Culture=neutral, PublicKeyToken=c5687fc88969c44d"></provider>
    </providers>
  </entityFramework>

2.创建EFModel

usingSystem;usingSystem.Collections.Generic;usingSystem.ComponentModel.DataAnnotations;usingSystem.ComponentModel.DataAnnotations.Schema;usingSystem.Data.Entity;usingSystem.Linq;usingSystem.Text;usingSystem.Threading.Tasks;namespaceConsoleWebSocket.Models
{
[Table(
"BaseDevice")]public classBaseDevice
{
[Key]
public int Id { get; set; }public string Name { get; set; }public string Description { get; set; }
}
public classBaseDeviceDbContext : DbContext
{
publicBaseDeviceDbContext()
:
base("myConn")
{
Database.SetInitializer(
new DropCreateDatabaseIfModelChanges<BaseDeviceDbContext>());
}
public DbSet<BaseDevice> BaseDevice { get; set; }
}
}

4.操作数据库 测试

        /// <summary>
        ///code first/// </summary>
        public static voidTestCodeFirst()
{
using (var context = newBaseDeviceDbContext())
{
//查询数据 List<BaseDevice> models =context.BaseDevice.ToList();//添加数据 context.BaseDevice.Add(new BaseDevice { Id = 1, Name = "New Model", Description= "Description"});
context.BaseDevice.Add(
new BaseDevice { Id = 3, Name = "New Model", Description = "Description"});
context.SaveChanges();
//// 更新数据 var model = context.BaseDevice.FirstOrDefault(m => m.Id == 1);if (model != null)
{
model.Name
= "Updated Name";
context.SaveChanges();
}
//删除数据 context.BaseDevice.Remove(model);
context.SaveChanges();
}
}

随机数对程序设计来说很重要,今天就从几方面探讨下一些常见的随机数相关的问题。

本文只讨论整数相关的随机数,另外需要你对概率论有最基本的了解(至少知道古典概型是什么)。

本文索引

如何从rand7生成rand5

首先是和某个知名算法题相反的问题。

这里给定一个可以概率均等地生成0到6的随机数生成器,要求用这个生成器创造出能概率均等地生成0到4的随机数生成器。

有人可能会立刻给出这样的答案:

func rand5() int {
    return rand7() % 5
}

然而这个答案只满足了输出的范围在0到4,不满足概率均等,所以不正确。这种时候列表法的作用就显现出来了:

rand7的输出 rand7 % 5
0 0
1 1
2 2
3 3
4 4
5 0
6 1

发现问题了吗,0和1出现了两次,它们出现的概率是其他数字的两倍,因此概率分布并不均等。

通过列表法我们其实也发现了这个问题真正的解法:除掉5和6的话剩下的输出不仅符合要求概率也是均等的,所以代码会变成这样:

func rand5() int {
    n := rand7()
    for n >= 5 {
        n = rand7()
    }
    return n
}

上面的代码其实就是随机采样算法中非常重要的一种:拒绝采样。同样上面的rand7生成rand5也可以归类成一大类问题:给定一组满足规律或者特征是
g(x)
的样本,现在需要从这些样本中筛选出或者生成另一组满足特征是
f(x)
的样本。解决这类问题的算法很多,而拒绝采样是比较直观的:判断
g(x)
的样本是否符合要求,不符合的就排除取下一个样本,符合要求的就归类到满足
f(x)
的样本集合中。

按照这个角度来看,上面的问题可以划分为几个基本元素:

  • g(x):rand7
  • f(x): rand5
  • 需要拒绝的样本:大于等于5的整数

拒绝采样在大多数时间都能获得理想的结果,但还有采样率需要注意。采样率就是
g(x)
的样本中有多少可以被接受,采样率太低的时候意味着算法的效率也会非常低。所以我们简单算算rand5的采样率,是七分之五,大约70%。这个概率不大不小,勉强合适。

当采样率过低的时候要么得改变拒绝采样的标准或范围,要么就不能再使用拒绝采样了。

go标准库的做法

标准库里当然不会有rand5和rand7,但它提供了一个叫
Int63n
的函数,它解决的问题是如何从均匀分布在
[0, 2⁶⁴)
范围上的随机整数中生成均匀分布在范围
[0, n)
上的随机整数。换句话说,虽然范围不一样了,但还是同一个问题。

我们肯定不能像上面那样把大于等于n的样本全丢了,因为2⁶⁴包含至少1844京(1E16)个整数样本,这会带来低得无法接受的采样率。

但因为我们是用mod来选择范围内的随机数的,因此我们可以选择n的倍数,这个证明太简单了,列表法加归纳就行。或者还可以这么想,有一个整数常数C,
x % n

Cx % n
能产生的输出的种类和它们的数量都是完全相同的,所以如果样本能均匀分布在
[0, n)
的范围上,那么范围
[0, C·n]
只不过是
[0, n)
重复了C次,所以样本在每一段上都是均匀分布的,整个合起来的区间上也是均匀的。

有了常数C,这样使得我们可以尽可能地让更多的样本被采样,这样能降低重试的次数。

C其实也很好选择,取一个2⁶⁴内的n的最大的倍数就行,如果C本身能被2⁶⁴整除,那么C就是
2⁶⁴/n

所以标准库是这样写的:

func (r *Rand) Int63n(n int64) int64 {
	if n <= 0 {
		panic("invalid argument to Int63n")
	}
	if n&(n-1) == 0 { // n is power of two, can mask
		return r.Int63() & (n - 1)
	}
	max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
	v := r.Int63()
	for v > max {
		v = r.Int63()
	}
	return v % n
}

代码还是很简单的,超过
C·n
的样本全部拒绝采样,剩下的样本就能保证在
mod n
的时候获得分布均匀的随机整数了。

采样率是多少?我们可以利用拒绝率来反推,这里拒绝率还挺好算的,就是
(1<<63)%uint64(n)
,算下来拒绝了少的时候是百亿分之一,多的时候是数千万分之一——都很低,基本上大多数时间最多两次重试就能获得想要的结果了。

但作为标准库,它的性能还不够,尤其是go的编译优化非常弱的现实下,更需要高效的算法弥补。问题出在哪?首先不是采样率,这个采样率是足够的,问题出在它需要两次64位除法,除法运算相比其他运算比如右移要慢很多,何况还是两次,别的语言中的标准库采用的算法只需要0到1次除法就够了。

好在go提供了
math/rand/v2
,采用了更高效的算法。

新算法依旧基于拒绝采样,但对采样的范围进行了变更,具体是这样的:

  1. 依然用概率均等的rand64生成一个随机整数x
  2. 现在把
    x*n
    ,这样生成的值的范围是
    [0, n·2⁶⁴)
  3. 因为是对已有范围的等比扩大,所以
    x*n

    [0, n·2⁶⁴)
    依旧是均匀分布的(不过要注意,范围扩展了,但样本的总量不变还是2⁶⁴个)
  4. [0, n·2⁶⁴)
    可以均等分成n个范围,分别是
    [0, 2⁶⁴)
    ,
    [2⁶⁴, 2*2⁶⁴)
    ,
    [2*2⁶⁴, 3*2⁶⁴)
    ...
    [(n-1)*2⁶⁴, n*2⁶⁴)
  5. 这样每一个均等分割的范围整除以2⁶⁴就可以得到一个整数k,k一定在
    [0, n)
  6. k可以当作符合要求的结果,而整除以2⁶⁴实际上可以转换成位运算,这样除法运算可以减少一次。

新算法有几个问题,第一个是
x*n
在大多数情况下会超过2⁶⁴,但这不用担心,因为go提供了高性能128位整数运算。

第二个是
x*n
虽然在
[0, n·2⁶⁴)
均匀分布,但我们怎么保证在均等分割的每个
[(k-1)*2⁶⁴, k*2⁶⁴)
上也是均等分布的呢?

答案是如果只有上面写的六个步骤,我们保证不了。原因是因为要想保证
x*n
均匀分布在每个
[(k-1)*2⁶⁴, k*2⁶⁴)
上,我们就要保证x本身要均匀分布在
[(k-1)*(2⁶⁴/n), k*(2⁶⁴/n))
上,换人话说,就是把2⁶⁴分割成n份,每份里的样本数量都要一致。因为我们的样本都是整数而不是实数,所以动动脚趾就能想到很多数是不能整除2⁶⁴的,因此会留下“余数”。但我们的新算法实际上假设了x均匀分布在2⁶⁴分割出来的均等的范围内。不能整除的情况下意味着即使按最均匀的办法分割,也会存在一部分范围比其他的范围多几个样本或者少几个样本,会多还是会少取决与你对
2⁶⁴/N
取整的方式。

但这问题不大,通常分段直接的数量差异对概率产生的误差非常小,比如我们把n取6,按尽可能均匀的分割,就存在4个分段比剩下的分段里的样本总数多1个,但每个分段的样本数量都有超过3E18个,多一个还是多两个带来的影响几乎可以忽略不计。

然而标准库最重要的是要保证结果的正确性,即使可能性是3E18分之一,它依旧不是0,函数的实现是不正确的,更何况根据n的选择,n越大分段的样本数量越少,分段之间数量差异带来的影响就会越来越大,总有一个n能让结果的误差大到无法忽略。

问题其实也好解决,因为我们知道始终会有一些分组的样本是多余的,我们只要保证分组里的样本数量一致就行,不需要关心具体剔除的样本是什么。假设我们采用向下取整的办法,那么会存在一些理论上应该在分段k上的样本跑到k+1的分组上,这些样本通常分布在分段的起始位置上,我们可以把这些样本拒绝采样,这样比较容易实现。这些样本乘以n之后会落在
[k*2⁶⁴, k*2⁶⁴+(2⁶⁴%n))
上。

剔除这些样本后,我们就能保证
x*n
在每个
[(k-1)*(2⁶⁴/n), k*(2⁶⁴/n))
上都是均匀分布的了。

思路理解了看代码也就不难了:

func (r *Rand) uint64n(n uint64) uint64 {
	if is32bit && uint64(uint32(n)) == n {
		return uint64(r.uint32n(uint32(n)))
	}
	if n&(n-1) == 0 { // n is power of two, can mask
		return r.Uint64() & (n - 1)
	}

    hi, lo := bits.Mul64(r.Uint64(), n)
	if lo < n {
		thresh := -n % n // 2⁶⁴ % n 的简化形式
		for lo < thresh {
			hi, lo = bits.Mul64(r.Uint64(), n)
		}
	}
	return hi
}

精髓在于利用
(x*n) >> 64
来避免了
x % n
带来的除法运算。而且新算法不用一开始就算余数,因此运气好的时候可以一次除法都不做。

还有一个小疑问,128位乘法够了吗?肯定够了,因为n最大也只能取到2⁶⁴,这意味这
x*n
的范围最大也只到
[0, 2⁶⁴·2⁶⁴)
,128位乘法刚好够用。

最后做下性能测试,标准库里已经提供了,在我的10代i5上旧算法一次调用需要18ns,新算法只需要5ns,两者使用的随机数发生器是一样的,因此可以说新算法快了3倍,提升还是很可观的。

从rand5生成rand7

上一节讨论了从更大的样本空间里筛选出特定特征的子集,这一节我们反过来:从范围更小的样本空间里派生出有某一特征的超集。

同时,这也是一道常见的中等难度的算法题。

首先要考虑的是如何把受限的样本空间尽量扩张。上一节我们用乘法来扩展了样本分布的范围,然而乘法尤其是乘以常数是没法增加样本数量的,因此这个做法只能pass。加法可以平移样本的范围,但也不能增加样本总量,而且我们需要样本空间是
[0, x)
平移之后起点都变了,因此也不行。

那剩下的可行的也最稳定的办法是
rand5() * rand5()
。它像乘法一样能扩张样本的范围,同时因为不是乘以常数因此还有机会产生新的样本。我们列个表看看:

rand5 rand5 rand5 * rand5
0 0 0
0 1 0
0 2 0
0 3 0
0 4 0
1 0 0
1 1 1
1 2 2
1 3 3
1 4 4
2 0 0
2 1 2
2 2 4
2 3 6
2 4 8
3 0 0
3 1 3
3 2 6
3 3 9
3 4 12
4 0 0
4 1 4
4 2 8
4 3 12
4 4 16

确实有新样本出现了,但不够连续,比如没有7和10。因此这条路是不通的。

这时候就要上原汁原味的拒绝采样算法了,我们使用
5 * rand5 + rand5

rand5 rand5 5 * rand5 + rand5
0 0 0
0 1 1
0 2 2
0 3 3
0 4 4
1 0 5
1 1 6
1 2 7
1 3 8
1 4 9
2 0 10
2 1 11
2 2 12
2 3 13
2 4 14
3 0 15
3 1 16
3 2 17
3 3 18
3 4 19
4 0 20
4 1 21
4 2 22
4 3 23
4 4 24

没错,正好产生了均等分布的0到24的整数。很神奇吧,其实想明白为什么不难。我们先看
5 * rand5
,这样或产生0、5、10、15、20这五个数字,我们要想有机会生成连续的整数,就一定需要把缺少的1到4,11到14这些数字补上。这时候正巧一个
+ rand5
就可以把这些缺的空洞全部填上。当然用进位来理解会更简单。

总结:
n * randn + randn
可以产生连续的范围在
[0, n*n)
的均匀分布的整数。注意这里没有结合律,因为randn每次的结果都是不一样的

这个样本空间是远超rand7的要求的,因此现在问题回到了第一节:如何从rand25生成rand7?现在大家都知道了:

func rand7() int {
    x := 5*rand5() + rand5()
    max := 25 - 25%7
    for x >= max {
        x = 5*rand5() + rand5()
    }
    return x % 7
}

你也可以改写成上一节说的更高效的做法。

充分利用每一个bit

rand.Uint64()
返回的随机数有足足64bits,而我们通常不需要这么大的随机数,举个例子,假如我们只需要0到15的随机数,这个数字只需要4bits就够了,如果用
rand.Uint64N(16)
来生成,我们会浪费掉60bits的数据。

为啥这么说?因为
rand.Uint64()
保证能概率均等的生成
[0, 2⁶⁴)
范围内的整数,这其实说明了两件事,第一是这个随机数的每一个bit也都是随机的,这个很明显;第二个是每个bits是0还是1的概率也是均等的,这个可以靠列表加归纳法得出。我们来看看第二点这么证明。

先假设我们有一个能均匀生成
[0, 8)
范围内数字的随机数发生器,现在我们看看所有可能的情况:

生成的数字 二进制表示 从左到右第一位 从左到右第二位 从左到右第三位
0 000 0 0 0
1 001 0 0 1
2 010 0 1 0
3 011 0 1 1
4 100 1 0 0
5 101 1 0 1
6 110 1 1 0
7 111 1 1 1

不难注意到,三个bit各有八种输出,0占四种,1占剩下四种,0和1的概率均等。这个结论能推广到任意的
[0, 2^n)
范围上。

同样,基于这个结论,我们还能得到这样一个结论,任意连续的n个bit,都能产生均匀分布在
[0, 2^n)
上的随机数,这个证明太简单了,所以我们集中注意力就行了。

现在回头看看为什么我说会浪费60bits,因为根据上面的结论,我们64位的随机整数完全可以按每四位划分一次,这样可以分成16组,而每组正好能产生
[0, 16)
范围内的随机数,且概率同样的均等的。也就是说一次
rand.Uint64()
理论上应该可以产生16个我们需要的随机数,但实际上我们只生成了一个。这里就有很大的提升空间了。

怎么做呢?我们需要一点位运算,把64位整数分组四位一组:

n := rand.Uint64()
mask := uint64(0xf) // 0b1111
for i := range 10 {
	a[i] = int(n & mask) // 只让最右边四位有效(书写顺序的右边,这里不讨论大小端了因为说明起来太麻烦)
	n >>= 4 // 把刚刚使用过的四位去掉
}

代码很简单,下面看看性能测试:

// 不要这样写代码
// 我这么做是为了避免内存分配和回收会对测试产生不必要的杂音
func getRand(a []int) {
	if len(a) < 10 {
		panic("length wrong")
	}
	for i := range 10 {
		a[i] = int(rand.Int32N(16))
	}
}

// 不要这样写代码
// 我这么做是为了避免内存分配和回收会对测试产生不必要的杂音
func getRandSplit(a []int) {
	if len(a) < 10 {
		panic("length wrong")
	}
	n := rand.Uint64()
	mask := uint64(0xf)
	for i := range 10 {
		// 这里不需要mod
		a[i] = int(n & mask)
		n >>= 4
	}
}

func BenchmarkGetRand(b *testing.B) {
	var a [10]int
	for range b.N {
		getRand(a[:])
	}
}

func BenchmarkGetRandSplit(b *testing.B) {
	var a [10]int
	for range b.N {
		getRandSplit(a[:])
	}
}

测试结果:

goos: windows
goarch: amd64
pkg: benchrand
cpu: Intel(R) Core(TM) i5-10200H CPU @ 2.40GHz
BenchmarkGetRand-8        	15623799	        79.31 ns/op	       0 B/op	       0 allocs/op
BenchmarkGetRandSplit-8   	100000000	        11.18 ns/op	       0 B/op	       0 allocs/op

充分利用每一个bit之后我们的性能提升了
整整6倍

到目前为止还不错,如果你不在乎生成的随机数的概率分布或者你只想生成
[0, 2^n)
范围的随机数且这个n可以整除64,那么可以直接跳到下一节继续看了。

接着往下看的人肯定是希望不管在什么范围内都能生成概率均匀的随机数且尽量多利用已生成的随机bits的。但事情往往不尽人意,比如,
[0, 13)

[0, 7)
就是两个例子。前者右边界不是2的幂,后者虽然是2的幂但3不能整除64。

我们先从简单的问题开始解决,
[0, 13)
。受先表示数字12至少得用4个bit,4能整除64,所以我们还可以每4个连续的bit分割成一组,但这时概率分布均匀的条件是满足不了的。无法保证的原因很简单,和我们第一节里说的“rand7生成rand5”的情况一样,每个均匀分割出来的一组连续的bits的组合里有我们不需要的样本存在。处理这个情况的方法在第一节里已经有表述了,那就是拒绝采样。确定拒绝的范围也使用第一节说到的办法,注意到每一组bits能生成
[0, 16)
的随机数,再考虑到13本身是素数,这里只需要简单地把≥13的样本全部剔除即可。

所以代码变成了下面这样:

func getRandSplit(a []int) {
	if len(a) < 10 {
		panic("length wrong")
	}
	mask := uint64(0xf)
	count := 0
	for {
		n := rand.Uint64()
		for i := 0; i < 16; i++ {
			sample := int(n & mask)
			n >>= 4
			// 不符合要求后直接跳到下一组去
			if sample >= 13 {
				continue
			}
			// 这里也不需要mod
			a[count] = sample
			count++
			if count >= 10 {
				return
			}
		}
	}
}

如果一组不满足采样要求,我们就跳过直接去下一组,因此有可能16组里无法获得足够的随机数,因此我们得重新获取一次64位的随机数,然后再次进入分割计算。这么做会对性能产生一点点负面影响,但依旧很快:

goos: linux
goarch: amd64
pkg: benchrand
cpu: Intel(R) Core(TM) i5-10200H CPU @ 2.40GHz
BenchmarkGetRand-8              16242730                72.22 ns/op            0 B/op          0 allocs/op
BenchmarkGetRandSplit-8         37794038                31.81 ns/op            0 B/op          0 allocs/op

这时候性能就只提升了一倍。

上面那种情况还是最简单的,但
[0, 7)
就不好办了。首先表示6需要至少3个bit,而3不能整除64,其次6也不是2的幂。这个怎么处理呢?

有两个办法,核心思想都是拒绝采样,但拒绝的范围有区别。

第一个想法是,既然3不能整除64,那我们选个能被3整除的,这里是63,也就是说超过
2⁶³-1
的样本全部丢弃,然后把符合要求的样本按每连续的3bits进行分割。这样我们先保证了3bits分割出来的每一组都能均等的生成
[0, 8)
范围内的随机整数。现在问题转化成了“rand8怎么生成rand7”,这题我们会做而且做了好多回了,最终代码会是这样:

func getRandSplit(a []int) {
	if len(a) < 10 {
		panic("length wrong")
	}
	// 注意,mask现在只要三位也就是0b111了
	mask := uint64(0x7)
	count := 0
	for {
		n := rand.Uint64()
		// 先拒绝大于2⁶³-1的样本
		if n > 1<<63-1 {
			continue
		}
		for i := 0; i < 21; i++ {
			sample := int(n & mask)
			n >>= 3
			// 一组bits如果组合出来的数大于等于7也拒绝采样
			if sample > 6 {
				continue
			}
			// 这里是不用mod的,因为产生的sample本身只会在0-6之间
			a[count] = sample
			count++
			if count >= 10 {
				return
			}
		}
	}
}

代码变得很长也很复杂,而且需要两步拒绝采样,相对的我们一次也能分割出21组,比4bits的时候多了5组,所以难说性能下降还是不变,因此我们看看测试:

goos: linux
goarch: amd64
pkg: benchrand
cpu: Intel(R) Core(TM) i5-10200H CPU @ 2.40GHz
BenchmarkGetRand-8              16500700                73.77 ns/op            0 B/op          0 allocs/op
BenchmarkGetRandSplit-8         31098928                39.54 ns/op            0 B/op          0 allocs/op

确实慢了一些,但总体上还是提速了85%。

第二种想法只需要一步拒绝采样,既然3不能整除64,那么就找到一个离3最近的可以整除64且大于3的整数。在这里我们可以直接注意到4符合条件,实际开发中如果要找到任意符合条件的数,可以依赖一下线性探测。现在我们按连续的4位把64位随机整数分割,这样分割出来的每一组可以生成均匀分布在
[0, 16)
上的整数。然后问题变成了“从rand16生成rand7”。代码这样写:

func getRandSplit2(a []int) {
	if len(a) < 10 {
		panic("length wrong")
	}
	mask := uint64(0xf)
	count := 0
	for {
		n := rand.Uint64()
		for i := 0; i < 16; i++ {
			sample := int(n & mask)
			n >>= 4
			if sample > 13 {
				continue
			}
			// mod不能漏了,因为我们会产生大于等于7的结果
			a[count] = sample % 7
			count++
			if count >= 10 {
				return
			}
		}
	}
}

代码简单了,也只需要一步拒绝采样就行,但问题在于每一组的生成范围变大导致我们不得不使用取模操作。看看性能怎么样:

goos: linux
goarch: amd64
pkg: benchrand
cpu: Intel(R) Core(TM) i5-10200H CPU @ 2.40GHz
BenchmarkGetRand-8              16451838                75.86 ns/op            0 B/op          0 allocs/op
BenchmarkGetRandSplit-8         30802065                39.15 ns/op            0 B/op          0 allocs/op
BenchmarkGetRandSplit2-8        38995390                30.75 ns/op            0 B/op          0 allocs/op

想法2比想法1快了近20%,看来两步拒绝采样成了硬伤,不仅仅是因为多获取几次64位随机数更慢,多出来的一个if还可能会影响分支预测,即便最后我们多了5组可以采样也无济于事了。

所以,当你需要的随机数范围比较有限的时候,充分利用每一个bit是非常理想的性能提升手段。

带有权重的随机数

讨论了半天概率均匀分布的情况,但业务中还有一种常见场景:一组样本进行采样,既要有随机性,又要样本之间在统计上尽量满足某些比例关系。

这个场景我想大家最先想到的应该是抽奖。是的没错,这是带权重随机数的常见应用。但还有一个场景,一个负载均衡器连接着一组权重不同的服务器硬件,现在想要尽量按权重来分配链接,这时候带权重随机数就有用武之地了。

假设我们有样本
(1, 2, 3, 4)
四个整数,然后要按
1:2:3:4
的比例来随机生成样本,该怎么做呢?

按比例,我们可以得到1,2,3,4生成的概率是0.1,0.2,0.3,0.4,这些概率加起来是一定等于1的,所以我们不妨来想象有一个数轴上的0到1的区间,然后我们把这些比例“塞进”数轴里:

0.0   0.1   0.2   0.3   0.4   0.5   0.6   0.7   0.8   0.9   1.0
|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|

|_______________________|
            4           
                        |_____|
                           1  
                              |_________________|
                                       3        
                                                |___________|
                                                      2

我故意打乱了顺序,实际上顺序并不影响结果。每个样本可以有一个数轴上的范围,范围的长度之间也是符合比重的,因此当存在一个可以均匀生成
[0, 1)
之间所有实数的随机数生成器时,这个生成器生成的数落在哪个范围里,我们就选择生成这个范围对应的样本,举个例子,如果生成的实数落在
[0.0, 0.4)
这个区间里,那么就生成样本“4”,如果落在
[0.8, 1.0)
这个区间,就生成“2”。这样带权重的随机数就生成了。这个看上面那个图还是挺好理解的,我就不费笔墨去证明了。

但我不是很想用这个方法。为啥呢,因为你看到了,区间的左边是闭合的,这意味着要做浮点数的等值比较,虽然很简单,但我很懒不想多写,而且浮点数没法精确表示所有情况的比例导致我们区间的两端都有精度损失,这就需要考虑误差率,虽然通常这点精度损失带来的误差可以忽略不记(尤其是在统计意义上)但只是考虑到这点我就浑身难受了。

所以我要找一种不需要浮点数的解决方案。想出来其实不难,假设我们有0到9的整数,正好10个,现在样本“1”按比例需要占用其中1个样本,样本“4”按比例要占用其中四个样本,现在我们获得了一个能均匀生成0到9的整数的随机数生成器,那只要生成的随机数正好是样本“4”占用的那几个随机数我们就生成“4”,生成的随机数是样本“1”占用的那就生成“1”。可以看到只要占够一定数量的不同的样本,那么我们一样能生成带权重的随机数。

下面有几个问题,一是样本总数怎么确定,这个简单,每个比例当成整数相加即可,比如
1:2:3:4
就是
1+2+3+4=10

2:2:3:5
就是
2+2+3+5=12
,依此类推。如果比例是实数呢?
2.3 : 3.4 : 4.7
怎么办?这就要用到比例的性质了,等比扩大后比例不变,所以我们每个实数都乘以10,然后去掉小数点后的0全部当成整数,所以
23+34+47=104
,理论上任意比例都能这么整,不过整数最终有大小限制的,你总不能生成个随机数还用
big.Int
吧,所以注意总和别超过整数范围限制。

二是样本的范围怎么算,虽然我们只需要不相同的满足1里总数的离散的点就行,但为了方便计算我们还是选择连续的整数比较好,所以范围限定为
[0, sum-1]
。这样我们能直接利用
rand.Uint64N()
来生成需要的随机数生成器。

最后我们只要让样本按比例随机占领一些连续的整数就行了。而且我们只需要记录右边界就够了,我们从范围是
[0, n]
的第一个样本开始比较,如果生成器给出的随机数小于等于某个右边界,那它一定落在边界代表的样本上(因为是从最小的边界开始比较的,所以随机数必然不可能落在前一个范围的样本上)。

其实就是把连续的实数换成了离散的整数点罢了,换汤不换药。

搞明白思路后代码写起来就是快:

type WeightRandom[T comparable] struct {
	entries       []entry[T]
	upperBoundary uint64
}

type entry[T comparable] struct {
	value T
	end   uint64
}

首先是数据结构,
WeightRandom
是我们的带权重随机数生成器,
upperBoundary
是样本数量的总和,
entries
则是各个样本和样本占领的连续整数的右边界。

接着来看构造
WeightRandom
对象的方法:

func NewWeightRandom[T comparable](rndMap map[T]uint64) *WeightRandom[T] {
	var lowerBoundary uint64
	entries := make([]entry[T], 0, len(rndMap))
	for k, v := range rndMap {
		if v == 0 {
			panic("weight cannot be zero")
		}
		if lowerBoundary+v < lowerBoundary {
			panic("overflow")
		}
		lowerBoundary += v
		entries = append(entries, entry[T]{
			value: k,
			end:   lowerBoundary,
		})
	}
	slices.SortFunc(entries, func(a, b entry[T]) int {
		return cmp.Compare(a.end, b.end)
	})

	if len(entries) == 0 {
		panic("no valid sample")
	}

	return &WeightRandom[T]{
		entries:       entries,
		upperBoundary: lowerBoundary,
	}
}

lowerBoundary
用来统计有多少样本,我们最终选择了左闭右开的区间,这样方便算。
rndMap
的key是样本,value则是比例。当样本的范围计算并保存结束之后,我们需要按照右边界从小到大排序这些样本,因为后面的查找范围到样本的对应关系需要右边界满足从小到大的顺序。

最后是查找函数:

func (w *WeightRandom[T]) RandomGet() T {
	x := rand.Uint64N(w.upperBoundary)
	for i := range w.entries {
		if x < w.entries[i].end {
			return w.entries[i].value
		}
	}
	panic("not possible")
}

查找时先生成一个范围在
[0, upperBoundary)
之间的随机数,然后我们从最小的边界开始逐一比较,一旦发现比自己大的边界,那么就说明需要生成边界对应的样本。底部那句panic如字面意思,理论上是执行不到的,但go不知道,我们要么返回个空值要么panic,考虑到能走到这里那说明我们的程序或者标准库的代码有重大bug,panic了去调试才是比较好的办法。

根据upperBoundary的大小,实际上我们还能复用上一节充分利用每一个bit的办法,不需要每次都生成新的随机数,等分割出来的分组都消耗完了再生成,这样可以大大加速这个函数。不过为了通用性和尽量简化代码,我就不这样写了。

最后附加一个用例:

func main() {
	w := NewWeightRandom(map[string]uint64{
		"a": 15,
		"b": 30,
		"c": 45,
		"d": 60,
	})
	m := make(map[string]int, 4)
	const limit = 100_0000_0000
	for range limit {
		m[w.RandomGet()]++
	}
	for k, v := range m {
		fmt.Printf("key: %s, count: %d, p: %g\n", k, v, float64(v)/limit)
	}
}

我们按权重生成“abcd”四个字母,比例是
15:30:45:60
,简化一下就是
1:2:3:4
,所以理论上概率应该接近10%,20%,30%和40%。不过统计上的概率总是有点误差的,只要大致趋势接近于这个比例就行了。我们运行100亿次来看看结果:

$ go run main.go

key: b, count: 2000011606, p: 0.2000011606
key: a, count: 1000058297, p: 0.1000058297
key: d, count: 3999943022, p: 0.3999943022
key: c, count: 2999987075, p: 0.2999987075

非常符合预期。作为一项优化措施,我们可以利用类似二分查找的办法来定位样本,因为右边界本身是有序的,这可以显著改善在有大量样本时的性能表现。不过如果你的样本不超过10个的话我觉得现在这样的线性查找就足够了。

怎么说呢,简单粗暴但有效解决了问题。只是和浮点数的方案比还是有缺点的:

  1. 因为整数有范围限制,这导致了样本总量会被限制在一个比浮点数方案更小的范围内
  2. 同样因为整数大小限制,比例的选择范围会比浮点数更小

因为限制更少,所以在通用的工具库里用浮点数方案的人更多,但业务场景和通用工具库是不一样的,很多时候选整数方案也没啥问题,最终你应该选择一个符合业务需求并且自己和同事都看得懂的方案。

至于性能怎么样,浮点数方案的查找过程和整数方案是一样的,性能需要做一次完整的测试才能看出孰高孰低,我不好在这凭空幻想。当然测试我就不做了,我偷个懒。

随机数种子

“种子”其实是指一些伪随机数生成算法需要的一些初始状态,这些算法会根据初始状态来生成一系列有随机性的数值序列。

所以相同的种子通常会带来相同的序列,这时候虽然每个序列都有随机性,但两个序列之间是没有随机性的——有了其中一个序列就可以精准预测另一个序列的排列。具体表现很可能会是你编写了一个游戏,其中敌人会随机采取一些行动,然而因为种子没设置好,导致每次见到这个敌人的时候它都会采取一模一样的行动步骤,这样的游戏是极其无聊的。

不仅如此,产生相同的数值序列后还会带来无数安全问题。

种子通常只用设置一次,并且在程序第一次需要随机数的地方设置——理想情况是这样的,然而总是有倒霉蛋忘记了这一点,于是随机算法经常只能使用默认的低质量且相同的种子。所以比较现代的语言,比如go1.22和python3都选择了在程序刚开始运行的时候帮你自动设置种子。

此外我们还得担心一下种子的质量。

“种子”的值需要每次获取都不一样,符合要求的常见种子来源有以下几种:

  1. 使用系统时间。这个确实每次获取都不一样,获取速度也很快,是很多开发者和库的默认选项。但是系统时间是可以修改的,而且世界上还有闰秒这个麻烦东西,你意外得到相同的系统时间的概率其实不低。
  2. 使用随机设备,这些是软件模拟出来的随机数生成器,以Linux为例,有
    random

    urandom
    两种,其中
    random
    以操作系统及外部环境的噪音为数据源产生随机值,要求不高时我们可以认为这是“真随机”,当然它的生成速率是比较低的;另一个是
    urandom
    ,它会用外部噪音生成一个初始状态,然后基于这个状态用伪随机数算法快速产生随机值,因此它的生成速率高但质量低。一般使用urandom生成种子是够用的。
  3. 使用产生真实随机数的硬件,原理和软件模拟的类似,和大多数软件实现转硬件实现后会性能提升不同,TRNG反而可能会有更低的性能,但相比软件实现它可以更精细地收集环境里的杂音从而生成实验证明的不可预测的高质量的随机数。常见的TRNG生成速率从数百k/s到数兆/s的都有,但像
    /dev/random
    通常速率可以有数兆到数十兆/s。除了性能上的不稳定,不是所有设备都包含TRNG,这导致了适用面受限,所以直接用它的人不多。不过很多依赖高质量随机数的场景就不得不考虑TRNG了。
  4. 利用地址空间布局随机化。现代操作系统在加载程序后都会给程序的内存地址加点随机的偏移量,所以程序每次运行的时候获取的变量地址基本都是不同的。这个是成本极低的获取随机值的方法,几乎不花费运行时代价,谷歌的abseil库里就有很多用这个方法获取随机数种子的代码。然而,使用这个方法的前提是你的系统要支持地址空间布局随机化,其次系统加的随机偏移量质量要尚可,这两个我们都控制不了,我们只能相信常用操作系统都做到这几点了。另外,高权限的程序和用户始终能把一些代码写进有固定地址的地方,虽然这种操作正变得越来越难,但还不是完全不可能,所以需要高质量种子的时候这个方案通常不会被考虑(另一个原因是有的系统可以配置关闭随机化布局甚至根本不支持随机化)。
  5. auxiliary vector。Linux特有的,可以通过
    /proc/<pid>/auxv
    或者glibc的函数
    getauxval
    来获取。这个数组包含一系列操作系统和当前进程的信息,全部是操作系统在程序加载时写入的,Windows有类似的东西。这些数据中有些是不变的比如硬件信息和平台信息,有些则是完全随机的,比如其中有程序的入口地址和vDSO的地址,这些因为ASLR的缘故都是完全随机的,另外auxv里还有专门的随机值字段,这些信息加一起比单纯依赖ASLR能带来更强的不可预测。

原则就是尽量让预测结果的难度增加,最好是能做到完全不可预测。

那作为开发者我用啥呢?一般来说系统时间是够用了,自己写着完或者做些简单工具可以用这个,不过要记住系统时间是可修改的不可靠的。如果是库或者对随机性依赖比较重的比如游戏,
/dev/urandom
是个比较理想的选择。追求极致性能并且对种子质量要求没那么高时,像谷歌那样利用ASLR带来的随机值也是可以的。

实在有选择困难症的话,我们来看看别人是怎么做的。

golang和python3选择了那些源作为种子

golang实际上是先利用auxv,如果系统不支持,就回退到从urandom之类的随机设备读取随机值,这个也出问题了就使用系统时间:

// runtime/rand.go

// OS-specific startup can set startupRand if the OS passes
// random data to the process at startup time.
// For example Linux passes 16 bytes in the auxv vector.
var startupRand []byte

func randinit() {
	lock(&globalRand.lock)
	if globalRand.init {
		fatal("randinit twice")
	}

	seed := &globalRand.seed
	// 查看是否有auxv信息被系统写入
	if startupRand != nil {
		for i, c := range startupRand {
			seed[i%len(seed)] ^= c
		}
		clear(startupRand)
		startupRand = nil
	} else {
		// 先从urandom读取
		if readRandom(seed[:]) != len(seed) {
			// readRandom should never fail, but if it does we'd rather
			// not make Go binaries completely unusable, so make up
			// some random data based on the current time.
			readRandomFailed = true
			readTimeRandom(seed[:])
		}
	}
	globalRand.state.Init(*seed)
	clear(seed[:])
	globalRand.init = true
	unlock(&globalRand.lock)
}

这个是全局函数的设置,go还能自己创建
rand.Source
,这个的种子只能显式传进去,这时候传什么go就没法管了,灵活的同时牺牲了一定的安全性。

Python3则是先读取urandom,失败后会结合系统时间加当前进程pid来生成种子,这样比单使用系统时间要强:

// https://github.com/python/cpython/blob/main/Modules/_randommodule.c#L293
static int
random_seed(RandomObject *self, PyObject *arg)
{
    int result = -1;  /* guilty until proved innocent */
    PyObject *n = NULL;
    uint32_t *key = NULL;
    size_t bits, keyused;
    int res;

    // 参数为空的时候
	if (arg == NULL || arg == Py_None) {
       if (random_seed_urandom(self) < 0) {
            PyErr_Clear();

            /* Reading system entropy failed, fall back on the worst entropy:
               use the current time and process identifier. */
            if (random_seed_time_pid(self) < 0) {
                return -1;
            }
        }
        return 0;
    }

    // 参数不为空的时候根据参数生成种子

Done:
    Py_XDECREF(n);
    PyMem_Free(key);
    return result;
}

然后这个函数会被Random对象的
__init__
方法调用,如果初始化一个Random对象但不传seed参数,那么就会进行默认设置。而random模块里所有的方法其实都是由一个全局的
Random()
对象提供的,因为没传seed进去,所以代码里会自动设置seed:

# https://github.com/python/cpython/blob/main/Lib/random.py#L924
_inst = Random()
seed = _inst.seed
random = _inst.random
uniform = _inst.uniform
triangular = _inst.triangular
randint = _inst.randint
choice = _inst.choice
randrange = _inst.randrange
sample = _inst.sample
shuffle = _inst.shuffle
choices = _inst.choices
normalvariate = _inst.normalvariate
lognormvariate = _inst.lognormvariate
expovariate = _inst.expovariate
vonmisesvariate = _inst.vonmisesvariate
gammavariate = _inst.gammavariate
gauss = _inst.gauss
betavariate = _inst.betavariate
binomialvariate = _inst.binomialvariate
paretovariate = _inst.paretovariate
weibullvariate = _inst.weibullvariate
getstate = _inst.getstate
setstate = _inst.setstate
getrandbits = _inst.getrandbits
randbytes = _inst.randbytes

这样python就防止了用户忘记设置seed的问题。

总结

关于随机数要讲的暂时就这么多了,除了一丁点数值算法之外都是些比较浅显易懂的东西。

概率和统计对程序设计的影响是很大的,所以我觉得与其花时间看最近比较火的微分方程,稍微抽出点时间看看概率统计对自己的帮助可能更大。

最后,其实标准库还有各种第三方库已经贴心准备了几乎全套功能了,看懂文档就能放心用,而且我也更推荐用这些库,开源的且久经检验的代码始终是比自己闭门造车来的强。

参考

https://stackoverflow.com/questions/18394733/generating-a-random-number-between-1-7-by-rand5

https://en.wikipedia.org/wiki/Rejection_sampling

https://www.linuxquestions.org/questions/linux-kernel-70/what-does-proc-pid-auxv-mean-exactly-4175421876/

论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错

来源:晓飞的算法工程笔记 公众号

论文: Not All Images are Worth 16x16 Words: Dynamic Transformers for Efficient Image Recognition

Introduction


Transformers是自然语言处理 (NLP) 中占主导地位的自注意的模型,最近很多研究将其成功适配到图像识别任务。这类模型不仅在ImageNet上取得了SOTA,而且性能还能随着数据集规模的增长而不断增长。这类模型一般都先将图像拆分为固定数量的图像块,然后转换为1D token作为输入,拆分更多的token有助于提高预测的准确性,但也会带来巨额的计算成本(与token数成二次增长)。为了权衡性能和准确率,现有的这类模型都采用14x14或16x16的token数量。

论文认为不同图片之间存在相当大的差异,使用相同数量的token处理所有图片并不是最优的。最理想的做法应为每个输入专门配置token数量,这也是模型计算效率的关键。以T2T-ViT-12为例,官方推荐的14x14 token数仅比4x4 token数增加了15.9%(76.7% 对 60.8%)的准确率,却增加了8.5倍的计算成本(1.78G 对 0.21G)。也就是说,对“简单”图片使用14x14 token数配置浪费了大量计算资源,使用4x4 token数配置就足够了。

受此启发,论文提出了一种动态Vision Transformer(DVT)框架,能够根据每个图片自动配置合适的token数,实现高效计算。训练时使用逐渐增多的token数训练级联Transformer,测试时从较少的token数开始依次推理,得到置信度足够的预测即终止推理过程。通过自动调整token数,“简单”样本和“困难”样本的计算消耗将会不一样,从而显着提高效率。

另外,论文还设计了基于特征和基于关系的两种复用机制,减少冗余的计算。前者允许下游模型在先前提取的深度特征上进行训练,而后者允许利用上游模型中的自注意力关系来学习更准确的注意力图。

DVT是一个通用框架,可集成到大多数图像识别的Transformer模型中。而且可以通过简单地调整提前终止标准,在线调整整体计算成本,适用于计算资源动态波动或需要以最小功耗来实现特定性能的情况。从ImageNet和CIFAR的实验结果来看,在精度相同的情况下,DVT能将T2T-ViT的计算成本降低1.6-3.6倍,而在NVIDIA 2080Ti上的真实推理速度也与理论结果一致。

Dynamic Vision Transformer


Overview

  • Inference

DVT的推理过程如图2所示。对于每张测试图片,先使用少量1D token序列对其进行粗略表示,可通过直接使用分割图像块或利用如tokens-to-token模块之类的技术来实现,然后通过Vision Transformer对这些token进行快速预测。由于Transformer的计算消耗与token数量成二次增长,所以这个过程很快。最后基于预设的终止标准对预测结果进行快速评估,确定是否足够可靠。

如果预测未能满足终止标准,原始输入图像将被拆分为更多token,再进行更准确、计算成本更高的推理。每个token embedding的维度保持不变,只增加token数量,从而实现更细粒度的表示。此时推理使用的Vision Transformer与上一级具有相同架构,但参数是不同的。根据设计,此阶段在某些“困难”测试图片上权衡计算量以获得更高的准确性。为了提高效率,新模型可以复用之前学习的特征和关系。在获得新的预测结果后,同样根据终止标准进行判断,不符合则继续上述过程,直到结果符合标准或已使用最终的Vision Transformer。

  • Training

训练时,需保证DVT中所有级联Vision Transformer输出正确的预测结果,其优化目标为:

其中,
\((x, y)\)
为训练集
\(D_{train}\)
中的一个样本及其对应的标签,采用标准的交叉熵损失函数
\(L_{CE}(·)\)
,而
\(p_i\)
表示第
\(i\)
个模型输出的softmax预测概率。

  • Transformer backbone

DVT是一个通用且灵活的框架,可以嵌入到大多数现有的Vision Transformer模型(如ViT、DeiT和T2T-ViT)之中,提高其性能。

Feature and Relationship Reuse

DVT的一个重要挑战是如何进行计算的复用。在使用的具有更多token的下游Vision Transformer时,直接忽略之前模型中的计算结果显然是低效的。虽然上游模型的token数量较少,但也提取了对预测有价值的信息。因此,论文提出了两种机制来复用学习到的深度特征和自注意力关系,仅增加少量的额外计算成本就能显着提高准确率。

  • Background

介绍前,先重温一下Vision Transformer的基本公式。Transformer encoder由交替堆叠的多头自注意力(MSA)和多层感知器 (MLP)块组成,每个块的之前和之后分别添加了层归一化(LN)和残差连接。定义
\(z_l\in R^{N\times D}\)
表示第
\(l\)
层的输出,其中
\(N\)
是样本的token数,
\(D\)
是token的维度。需要注意的是,
\(N=HW+1\)
,对应
\(H\times W\)
图像块和可学习的分类token。假设Transformer共
\(L\)
层,则整个模型的计算可表示为:

得到最终的结果
\(z_L\)
后,取其中的分类token通过LN层+全连接层进行最终预测。这里省略了position embedding的细节,论文没有对其进行修改。

  • Feature reuse

DVT中的所有Transformer都具有相同的目标,即提取关键特征进行准确识别。 因此,下游模型应该在上游模型计算的深度特征的基础上学习才是最高效的,而不是从头开始提取特征。为此,论文提出了图3的特征复用机制,利用上游Transformer最后输出的结果
\(z^{up}_L\)
来生成下游模型每层的辅助embedding输入
\(E_l\)

\(f_l:\mathbb{R}^{N\times D}\to \mathbb{R}^{N\times D^{'}}\)
由LN+MLP(
\(\mathbb{R}^{D}\to \mathbb{R}^{D^{'}}\)
)开头,对上游模型输出进行非线性转换。转换后将结果reshape到原始图像中的相应位置,然后上采样并展平来匹配下游模型的token数量。一般情况下,使用较小的
\(D^{'}\)
以便快速生成
\(f_l\)

之后将
\(E_l\)
拼接到下游模型对应层的中间特征作为预测的先验知识,也就是将公式3替换为:

\(E_l\)
与中间特征
\(z^{'}_l\)
拼接,LN 的维度和MLP的第一层从
\(D\)
增加到
\(D+D^{'}\)
。 由于
\(E_l\)
是基于上游输出
\(z^{up}_L\)
生成的,token数少于
\(z^{'}_l\)
,它实际上为
\(z^{'}_l\)
中的每个token总结了输入图像的上下文信息。 因此,将
\(E_l\)
命名为上下文embedding。此外,论文发现不复用分类token对性能有提升,因此在公式5中将其填充零。

公式4和5允许下游模型在每层灵活地利用
\(z^{up}_L\)
内的信息,从而最小化最终识别损失,这种特征重用方式也可以认为隐式地扩大了模型深度。

  • Relationship reuse

Vision Transformer的关键在于自注意力模块能够整合整个图像的信息,从而有效地模拟图像中的长距离关系。通常情况下,模型需要在每一层学习一组注意力图来描述token之间的关系。除了上面提到的特征复用,论文认为下游模型还可以复用上游模型产生的自注意力图来进行优化。

定义输入特征
\(z_l\)
,自注意力模块先通过线性变换得到query矩阵
\(Q_l\)
、key矩阵
\(K_l\)
和value矩阵
\(V_l\)

其中,
\(W^Q_l\)

\(W^K_l\)

\(W^V_l\)
为权重矩阵。然后通过一个带有softmax的缩放点乘矩阵运算得到注意力图,最后根据注意力图来计算所有token的值:

其中,
\(d\)

\(Q\)

\(K\)
的点积结果维度,
\(A_l\in \mathbb{R}^{N\times N}\)
为注意力图。为了清楚起见,这省略了多头注意力机制的细节,多头情况下
\(A_l\)
包含多个注意力图。

对于关系复用,先将上游模型所有层产生的注意力图(即
\(A^{up}_l, l\in \{1,\cdots , L\}\)
)拼接起来:

其中,
\(N^{up}\)

\(N^{Att}_{up}\)
分别为上游模型中的toekn数和注意力图数,通常
\(N^{Att}_{up} = N^H L\)

\(N^H\)
是多头注意力的head数,
\(L\)
是层数。

下游的模型同时利用自己的token和
\(A^{up}\)
来构成注意力图,也就是将公式7替换为:

其中
\(r_l(\cdot)\)
是一个转换网络,整合
\(A^{up}\)
提供的信息来细化下游注意力图
\(A_l\)

\(r_l(\cdot)\)
的架构如图5所示,先进行非线性MLP转换,然后上采样匹配下游模型的注意力图大小。

公式9虽然很简单,但很灵活。有两个可以魔改的地方:

  • 由于下游模型中的每个自注意力模块可以访问上游模型的所有浅层和深层的注意力头,可以尝试通过可学习的方式来对多层的注意力信息进行加权整合。
  • 新生成的注意力图和复用注意力图直接相加,可以尝试通过可学习的方式来对两者加权。

还需要注意的是,
\(r_l(\cdot)\)
不能直接使用常规上采样操作。如图5所示,假设需要将
\(HW\times HW\)
(
\(H =W = 2\)
)的注意力图映射上采样到
\(H^{'}W^{'}\times H^{'}W^{'}\)
(
\(H^{'} =W^{'} = 3\)
)的大小。由于每一行对应单个token与其他
\(H\times W\)
个token的关系,直接对注意力图上采样会引入混乱的数据。因此,需要先将行reshape为
\(H\times W\)
,然后再缩放到
\(H^{'}W^{'}\times H^{'}W^{'}\)
,最后再展平为
\(H^{'}W^{'}\)
向量。

  • Adaptive Infernece

如前面所述,DVT框架逐渐增加测试样本的token数量并执行提前终止,“简单”和“困难”图像可以使用不同的token数来处理,从而提高了整体效率。对于第
\(i\)
个模型产生的softmax预测
\(p_i\)
,将
\(p_i\)
的最大项
\(max_j p_{ij}\)
与阈值
\({\mu}_{i}\)
进行比较。如果
\(max_j p_{ij}\ge {\mu}_{i}\)
,则停止并采用
\(p_i\)
作为输出。否则,将使用更多token数更多的下游模型继续预测直到最后一个模型。

阈值
\(\{\mu_1, \mu_2, \cdots\}\)
需要在验证集上求解。假设一个计算资源有限的批量数据分类场景,DVT需要在给定的计算预算
\(B > 0\)
内识别一组样本
\(D_{val}\)
。定义
\(Acc(D_{val}, \{\mu_1, \mu_2, \cdots\})\)

\(FLOPs(D_{val}, \{\mu_1, \mu_2, \cdots\})\)
为数据集
\(D_{val}\)
上使用阈值
\(\{\mu_1, \mu_2, \cdots\}\)
时的准确度和计算成本,最优阈值可以通过求解以下优化问题得到:

由于公式10是不可微的,论文使用遗传算法解决了这个问题。

Experiment


ImageNet上的性能对比。

推理性能对比。

CIFAR上对比DVT在不同模型规模的性能。

在ImageNet上与SOTA vision transformer提升方法的性能对比。

基于DeiT的DVT性能对比。

复用机制的对比实验。

与类似的提前退出方法的性能对比。

复用机制提升的性能与计算量。

复用机制实现细节的对比实验。

难易样本的例子以及数量分布。

不同终止标准的性能对比。

与自适应深度方法进行性能对比,自适应方法是在模型的不同位置插入分类器。

Conclusion


论文主要处理Vision Transformer中的性能问题,采用推理速度不同的级联模型进行速度优化,搭配层级间的特征复用和自注意力关系复用来提升准确率。从实验结果来看,性能提升不错。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.