2024年1月

使用Docker部署Tomcat

1. 获取镜像

docker pull tomcat:8.5.38

docker images

2. 第一次启动tomcat

该步骤作用:是为了拷贝容器中Tomcat中的conf下配置文件和webapps下的管理页面,用于后面自定义Tomcat服务器配置或者部署应用。

第一次启动:

docker run -d --name tomcat8 -p 8081:8080 tomcat:8.5.38

在宿主机创建文件夹:

mkdir -p /home/tomcat8

从容器中拷贝配置文件和应用到宿主机中:

docker cp tomcat8:/usr/local/tomcat/conf/ /home/tomcat8/
docker cp tomcat8:/usr/local/tomcat/webapps/ /home/tomcat8/

3.带参数启动

此时,如果直接带参数启动时,会报有重名的tomcat8容器冲突,报以下类似错误:

Error response from daemon: Conflict. The container name "/tomcat8" is already in use by container "f087d304d5bffa1becc20b9c3668d634caf7bc001fc7ce89bdf5c5b43e3e869e". You have to remove (or rename) that container to be able to reuse that name.

这时候,需要先将容器tomcat8先停止,再删除:

docker stop tomcat8
docker remove tomcat8

最后再使用带参数的命令启动:

docker run --name tomcat8 -p 1808:8080  \
-v /home/tomcat8/conf:/usr/local/tomcat/conf \
-v /home/tomcat8/webapps:/usr/local/tomcat/webapps \
-v /home/tomcat8/logs:/usr/local/tomcat/logs \
-v /etc/localtime:/etc/localtime:ro \
-v /serverdata/conf/tsmsg/:/serverdata/conf/tsmsg \
-e TZ=Asia/Shanghai \
-d tomcat:8.5.38
-v /serverdata/conf/tsmsg/:/serverdata/conf/tsmsg 

这个为应用本身的配置文件映射,根据实际需求来设置配置文件地址。

-v /etc/localtime:/etc/localtime:ro 

这个命令的作用是将宿主机上的时区设置文件(
/etc/localtime
)挂载到Docker容器中相同的位置,使容器能够使用与宿主机相同的时区设置。

具体来说,这个命令的各部分意义如下:

  • -v
    :这是Docker命令用来指定挂载卷的标志(Volume)。
  • /etc/localtime
    :这是宿主机上时区文件的路径。该文件包含了当前时区的信息。
  • :/etc/localtime
    :这是容器内部的挂载点,即容器内与宿主机
    /etc/localtime
    文件对应的路径。
  • :ro
    :这表示以只读方式挂载(Read-Only)。容器可以读取该文件,但无法修改它,这样可以防止容器的操作影响宿主机的时区设置。

这样做的好处是保证容器内的应用程序可以正确处理与时区相关的操作,例如记录日志的时间戳、执行定时任务等。这对于需要时区一致性的应用程序来说非常重要。例如,如果你在
上海
运行宿主机,而你的容器也应当使用东京的时区,通过这种方式挂载
/etc/localtime
,你的容器将会自动采用
上海
的时区,而无需在容器内单独配置时区。

4.查看tomcat日志

  • 通过看宿主机
    /home/tomcat8/logs
    下的日志

  • 通过
    docker logs -f tomcat8
    查看日志

5.时区问题

如果你的应用用的是Oracle数据,可能会遇到这样的错:

Caused by: java.sql.SQLException: ORA-00604: error occurred at recursive SQL level 1
ORA-01882: timezone region not found

这个错误信息是由Oracle数据库返回的,表示有两个错误:

  1. ORA-00604
    : 这个错误表明在递归SQL级别(即Oracle内部操作)发生了错误。递归SQL是Oracle在处理用户SQL语句时内部自动执行的SQL语句,常见于触发器、登录逻辑等。
  2. ORA-01882
    : 这个错误表明找不到指定的时区地区。当数据库或JDBC驱动试图访问特定的时区信息,而该信息在数据库的时区文件中不存在或未被识别时,就会出现这个错误。

这两个错误一起出现可能意味着在执行初始化会话时,比如在用户登录时设置会话的时间区域,Oracle发现它不能识别或找到该时区。

这时候你会搜索到要同步时区,需先设置好宿主机的时区:

timedatectl set-timezone Asia/Shanghai

再在容器启动时添加

-v /etc/localtime:/etc/localtime:ro 

或许,会发现宿主机的时间不对,可用

date -s "20240131 11:31:00"

更改一下宿主机的时间

前面的文章(
飞桨paddlespeech语音唤醒推理C定点实现
)讲了INT16的定点实现。因为目前商用的语音唤醒方案推理几乎都是INT8的定点实现,于是我又做了INT8的定点实现。

实现前做了一番调研。量化主要包括权重值量化和激活值量化。权重值由于较小且均匀,还是用最大值非饱和量化。最大值法已不适合8比特激活值量化,用的话误差会很大,识别率等指标会大幅度的降低。激活值量化好多方案用的是NVIDIA提出的基于KL散度(Kullback-Leibler divergence)的方法。我也用了这个方法做了激活值的量化。这个方法用的是饱和量化。下图给出了最大值非饱和量化和饱和量化的区别。

从上图看出,最大值非饱和量化时,把绝对值的最大值|MAX|量化成127,|MAX|/127就是量化scale。激活值的分布范围一般都比较广, 这种情况下如果直接使用最大值非饱和量化, 就会把离散点噪声给放大从而影响模型的精度,最好是找到合适的阈值|T|,将|T|/127作为量化scale,把识别率等指标的降幅控制在一个较小的范围内,这就是饱和量化。KL散度法就是找到这个阈值|T|的一种方法,已广泛应用于8比特量化的激活值量化中。

KL散度又称为相对熵(relative entropy),是描述两个概率分布P和Q差异的一种方法。 KL散度值越小,代表两种分布越相似,量化误差越小;反之,KL散度值越大,代表两种分布差异越大,量化误差越大。 把KL散度用在激活值的量化上就是来衡量不同的INT8分布与原来的FP32分布之间的差异程度。KL散度的公式如下:

其中P,Q分别称为实际分布和量化分布, KL散度越小, 说明两个分布越接近。

使用KL散度方法前需要做如下准备工作:

1,从验证集选取一个子集。这个子集应该具有代表性,多样性。

2,把这个子集输入到模型进行前向推理, 并收集模型中各个Layer的激活值。

对于每层激活值,寻找阈值的步骤如下:

1,  用直方图将激活值分成N个bin(NVIDIA用的是2048), 每个bin内的值表示在此bin内激活值的个数,从而得到参考样本。

2,  不断地截断参考样本,长度从128开始到N, 截断区外的值加到截断样本的最后一个值之上,从而得到分布P。求得分布P的概率分布。

3,  创建分布Q,其元素的值为截断样本P的int8量化值, 将Q样本长度拓展到和原样本P具有相同长度。求得Q的概率分布 并计算P、Q的KL散度值。

4,  循环步骤2和3, 就能不断地构造P和Q并计算相对熵,最后找到最小(截断长度为M)的相对熵,阈值|T|就等于(M + 0.5)*一个bin的长度。|T|/127就是量化scale,根据这个量化scale得到激活值的量化值。

实现前读了腾讯ncnn的INT8定点实现,看有什么可借鉴的。 发现它不是一个纯定点的实现,即里面有部分是float的,当时觉得里面最关键的权重和激活值都是定点运算了,部分浮点运算可以接受, 我也先做一个非纯定点的实现,把参数个数较少的bias用浮点表示。 接下来就开始做INT8的定点实现了,还是基于不带BN的浮点实现(
飞桨paddlespeech语音唤醒推理C浮点实现
)。依旧像INT16定点实现时那样,一层一层的去调,评估指标还是欧氏距离。调试时还是用一个音频文件去调。方便调试出问题时找到原因以及稳妥起见,我将INT8的定点化分成3步来做。

1,depthwise以及pointwise等卷积函数的激活值数据以及参数等均是用float的(即函数参数相对浮点实现不变),在函数内部根据激活值和权重参数量化scale将激活值和权重量化为INT8,然后做定点运算。做完定点运算后再根据激活值和权重参数量化scale将输出的激活值反量化为float值。每层算完后结果都会去跟浮点实现做比较,用欧氏距离去评估。只有欧氏距离较小才算OK。

2,权重参数的量化事先做好。将上面第一步函数的参数中权重参数从float变为int8。在函数里根据激活值的量化scale只做激活值的量化。做完定点运算后再根据激活值和权重参数量化scale将输出的激活值反量化为float值。每层算完后结果都会去跟浮点实现做比较,用欧氏距离去评估。只有欧氏距离较小才算OK。

3,将上面第二步函数的参数中激活值参数也从float变为int8,这样激活值参数和权重参数就都是INT8。函数中权重和激活值就没有量化过程只有定点运算了。激活值得到后再根据当前层和下一层的激活值量化scale重量化为下一层需要的INT8值。需要注意的是在用欧氏距离评估每一层时要把激活值的INT8值转换为float值,因为评估时是与浮点实现作比较。

经过上面三步后一个不是纯的INT8的定点实现就完成了。以depthwise卷积函数为例来看看卷积层的处理:

从函数实现可以看出,偏置bias未做量化,是浮点参与运算的,权重和激活值做完定点乘累加后结果再转回浮点与bias做加法运算,最后做重量化把激活值结果变成INT8的值给下层使用。Input_scale/output_scale/weight_scale都是事先算好保存在数组里,当前层的output_scale就是下一层的Input_scale。

等模型调试完成后依旧是在INT16实现用的那个大的数据集(有两万五千多音频文件)上对INT8定点实现做全面的评估,看唤醒率和误唤醒率的变化。跟INT16实现比,唤醒率下降了0.9%,误唤醒率上升了0.6%。说明INT8定点化后性能没有出现明显的下降。

INT8定点实现是在PC上调试的,但我们最终是要用在audio DSP(ADSP,主频只有200M)上,我就在ADSP上搭了个KWS的DEMO,重点关注在模型上。试验下来发现运行一次模型推理(上面的INT8实现)需要近1.2秒,这是没办法部署的,需要优化。调查后发现很少的浮点运算却花了很长的时间。我们用的ADSP没有FPU(浮点运算单元),全是用软件来做浮点运算的,因此要把上面实现里的浮点运算全部改成定点的,主要包括bias以及各种scale的量化。考虑到模型中bias参数个数较少以及保证精度,我用INT32对bias以及scale做量化。看了这几种值的绝对值最大值后,简单起见,确定Q格式均为Q6.25。在卷积函数中,input_scale和weight_scale总是相乘后使用,因此可以看成一个值,相乘后再去做量化。最终一个纯定点的depthwise 卷积函数如下:

再去用那个大数据集(有两万五千多音频文件)上对INT8纯定点实现做全面的评估,看唤醒率和误唤醒率的变化。跟不是纯的INT8实现比,唤醒率和误唤醒率均没什么变化。再把这个纯定点的模型在ADSP上跑,做完一次推理用了不到400ms的时间。这样一个纯定点的INT8实现就完成了。然而这只是一个base,后面还需要继续优化,把运行时间降下来。事后想想如果模型运行在主频高的处理器上(如ARM),推理中有少部分浮点运算是可以的,如果运行在主频低的处理器上(如我上面说的ADSP,只有200M),且没有FPU,模型推理一定要是全定点的实现。

element的el-image组件支持大图预览模式,但需要和小图模式配合使用,项目中刚好有需求需要直接使用大图预览并且需要支持图片的动态加载,研究了一下el-image组件的源码发现el-image组件是通过引用image-viewer组件实现的大图预览的,刚好可以利用一下!

image

image-viewer属性

urlList
: 图片列表,数组类型
onSwitch
: 图片切换事件
onClose
: 关闭事件
initialIndex
: 图片预览初始图片index
zIndex
:设置图片预览的 z-index

源码

  props: {
    urlList: {
      type: Array,
      default: () => []
    },
    zIndex: {
      type: Number,
      default: 2000
    },
    onSwitch: {
      type: Function,
      default: () => {}
    },
    onClose: {
      type: Function,
      default: () => {}
    },
    initialIndex: {
      type: Number,
      default: 0
    },
    appendToBody: {
      type: Boolean,
      default: true
    },
    maskClosable: {
      type: Boolean,
      default: true
    }
  }

项目使用

页面需要单独引入
image-viewer
组件,打开大图预览时通过接口默认从后台获取2张图片,然后再利用图片切换事件
onSwitch
在浏览第二张图片时从后台获取第3张图片,依此类推以实现图片的动态加载。
小贴士
: 项目中也要求图片要通过 token 访问,于是后台接口读取图片返回图片流,前端使用
URL.createObjectURL
生成图片临时地址即可。

代码

<template>
  <basic-container>
    <el-image-viewer
      v-if="imgViewerVisible"
      :on-close="closeImgViewer"
      :onSwitch="switchImage"
      :url-list="imageList" />
  </basic-container>
</template>

<script>

export default {
  data () {
    return {
      imageList: [],
      imgViewerVisible: false
    }
  },
  components: {
    'el-image-viewer': () => import('element-ui/packages/image/src/image-viewer')
  },
  methods: {
   //打开大图预览
    view (row) {
      this.imageList = []
      this.image(row, 0)
    },
	//通过接口获取图片列表,默认获取两张
    image (row, index) {
      let query = {}
      query.page = index + 1
      preview(query).then((response) => {
	    //后台返回图片流,利用createObjectURL生成临时对象地址
        const blob = new Blob([response], {
          type: 'image/jpeg'
        })
        this.imgViewerVisible = true
        this.imageList.push(window.URL.createObjectURL(blob))
        if (index === 0 && row.pageCount > 1) {
          this.image(row, 1)
        }
        console.log(this.imageList.length)
      }).catch((response) => {
        console.error('预览出错', response)
      })
    },
	//关闭大图预览
    closeImgViewer () {
      this.imgViewerVisible = false
    },
	//图片切换事件,浏览第二张时获取第三张,依此类推
    switchImage (index) {
      if (index > this.index && this.imageList.length - 1 === index && this.imageList.length < this.currentRow.pageCount) {
        this.image(this.currentRow, this.imageList.length)
      }
      this.index = index
    }
  }
}
</script>

进阶优化

image-viewer
组件有个小问题,大图预览模式图片缩放时页面如果有滚动条也会跟着滚动体验不太好,可以在打开大图模式时禁用页面滚动,关闭大图模式再启用页面滚动

代码

    //打开大图预览
    view (row) {
	  this.disableMove()
      this.imageList = []
      this.image(row, 0)
    },
    //关闭大图预览
    closeImgViewer () {
      this.imgViewerVisible = false
	  this.enableMove()
    },
    disableMove () {
      const m = (e) => { e.preventDefault() }
      document.body.style.overflow = 'hidden'
      document.addEventListener('touchmove', m, false)
    },
    enableMove () {
      const m = (e) => { e.preventDefault() }
      document.body.style.overflow = 'auto'
      document.removeEventListener('touchmove', m, true)
    }

前言

OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库,它具有C++,Python,Java和MATLAB接口,并支持Windows,Linux,Android和Mac OS。OpenCvSharp是一个OpenCV的 .Net wrapper,应用最新的OpenCV库开发,使用习惯比EmguCV更接近原始的OpenCV,该库采用LGPL发行,对商业应用友好。

@

1. 项目环境

  • 编码环境:Visual Studio Code
  • 程序框架:.NET 6.0

目前在Linux上使用C#语言官方提供了
Visual Studio Code
平台,所以在此处我们演示使用
Visual Studio Code
进行演示。而代码的运行与配置使用
dotnet
指令实现。

关于
Visual Studio Code
以及
.NET
的安装方式可以参考一下官方教程:

在 Linux 上安装 .NET
:由于Linux系统环境类型较多,所以可以根据官方提供的教程并根据自己的系统安装即可;

Visual Studio Code on Linux
:大家可以根据自己的环境进行安装。

2. 创建控制台项目

此处使用
dotnet
指令创建新项目,在
Visual Studio Code
的终端中输入一下指令:

dotnet new console --framework net6.0 --use-program-main -o test_opencvsharp

如下图所示,在终端中输入以下指令后,会自动创建新的项目以及项目文件夹。
image

在创建好项目后,我们使用vscode打开,输入以下指令,如下图所示:

test_opencvsharp
code .

image

3. 添加 Nuget Package 程序包

OpenCvSharp4是一个可以跨平台使用的程序包,并且官方也提供了编译好的程序包,用户可以根据自己的平台进行安装。在Linux上,主要需要安装一下两个包,分别是OpenCvSharp4的官方程序包以及OpenCvSharp4的运行依赖包。

dotnet add package OpenCvSharp4
dotnet add package OpenCvSharp4_.runtime.ubuntu.20.04-x64

依次输入指令后输出如下图所示:
image
image

安装完上面两个安装包后,项目的配置的文件中会增加下面两个配置。

<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net6.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="OpenCvSharp4" Version="4.9.0.20240103" />
    <PackageReference Include="OpenCvSharp4_.runtime.ubuntu.20.04-x64" Version="4.9.0.20240103" />
  </ItemGroup>

</Project>

接下来运行
dotnet run
,检验项目中是否包含所需要的配置文件:
OpenCvSharp.dll

runtimes/ubuntu.20.04-x64/native/
。打开项目运行生成的文件夹
bin/{build_config}/{dotnet_version}/
,在本项目中是
bin/Debug/net6.0/
文件夹,如下图所示:

image

可以看出,在程序运行后,安装的程序包中所有项目都已经加载到当前项目中,如果出现缺失,就需要找到程序包位置,将该文件复制到指定路径。

5. 安装依赖项目

在上面的测试中,并为使用到安装的
OpenCvSharp4
,因此运行并未出现其他错误,如果主机电脑之前没有安装使用过
OpenCV
,所以第一次使用需要配置依赖项目。
首先第一步检查一下缺少什么依赖项,在终端中输入以下指令:

ldd libOpenCvSharpExtern.so

image

如上图所示,经过
ldd
检测后,发现存在未安装的依赖,接下爱就是安装相应的依赖项,首先是解决
tesseract
缺少,在终端输入以下指令:

sudo apt install tesseract-ocr

安装完成后再进行依赖项检测,如下图所示:
image

可以看出,经过安装后,该依赖项已经可以检测到,接下来就是安装其他依赖项,依次输入以下指令即可:

sudo apt install libdc1394-dev
sudo apt install libavcodec-dev 
sudo apt install libavformat-dev
sudo apt install libswscale-dev
sudo apt install libopenexr-dev

最后,安装完成后,在进行检测,如下图所示,可以看出,目前已经成功检测到所有依赖项,程序就可以正常使用了。
image

4. 测试应用

最后我们编写项目代码进行测试,如下面代码所示:

using System;
using OpenCvSharp;
namespace test_opencvsharp 
{
    internal class Program
    {
        static void Main(string[] args)
        {
            Mat image = Cv2.ImRead("image.jpg");
            Mat image2=new Mat();
            if (image!=null)
            {
                Console.WriteLine("srcImg is OK!");
            }
            Console.WriteLine("图像的宽度是:{0}",image.Rows);
            Console.WriteLine("图像的高度是:{0}", image.Cols);
            Console.WriteLine("图像的通道数是:{0}", image.Channels());
            Cv2.ImShow("src", image);
            Cv2.CvtColor(image, image2, ColorConversionCodes.RGB2GRAY);//转为灰度图像
            Cv2.ImShow("src1", image2);
            Cv2.WaitKey(0);
            Cv2.DestroyAllWindows();//销毁所有窗口
        }
    }
}

项目代码运行后,最后呈现效果如下图所示:

image

5. 总结

在本次项目中,我们成功实现了在Linux上使用OpenCvSharp,并成功配置了OpenCvSharp依赖库,实现了在.NET 6.0环境下使用C#语言调用OpenCvSharp库,实现的图片数据的读取以及图像色彩转换,并进行了图像展示。

1. 背景

最近本qiang~老看到一些关于大语言模型的DPO、RLHF算法,但都有些云里雾里,因此静下心来收集资料、研读论文,并执行了下开源代码,以便加深印象。

此文是本qiang~
针对大语言模型的DPO算法的整理,包括原理、流程及部分源码

2. DPO vs RLHF

上图左边是RLHF算法,右边为DPO算法,两图的差异对比即可体现出DPO的改进之处。

1. RLHF算法包含奖励模型(reward model)和策略模型(policy model,也称为演员模型,actor model),基于偏好数据以及强化学习不断迭代优化策略模型的过程。

2. DPO算法不包含奖励模型和强化学习过程,直接通过偏好数据进行微调,将强化学习过程直接转换为SFT过程,因此整个训练过程简单、高效,主要的改进之处体现在于损失函数。

PS:

1. 偏好数据,可以表示为三元组(提示语prompt, 良好回答chosen, 一般回答rejected)。论文中的chosen表示为下标w(即win),rejected表示为下标l(即lose)

2. RLHF常使用PPO作为基础算法,整体流程包含了4个模型,且通常训练过程中需要针对训练的actor model进行采样,因此训练起来,稳定性、效率、效果不易控制。

1) actor model/policy model: 待训练的模型,通常是SFT训练后的模型作为初始化

2) reference model: 参考模型,也是经SFT训练后的模型进行初始化,且通常与actor model是同一个模型,且模型冻结,不参与训练,其作用是在强化学习过程中,保障actor model与reference model的分布差异不宜过大。

3) reward model: 奖励模型,用于提供每个状态或状态动作对的即时奖励信号。

4) Critic model: 作用是估计状态或状态动作对的长期价值,也称为状态值函数或动作值函数。

3. DPO算法仅包含RLHF中的两个模型,即演员模型(actor model)以及参考(reference model),且训练过程中不需要进行数据采样。

4. RLHF可以参考附件中的引文

3. DPO的损失函数

如何将RLHF的Reward model过程简化为上式,作者花了大量篇幅进行了推导,感兴趣的读者可以参考附件DPO的论文。

DPO算法的目的是最大化奖励模型(此处的奖励模型即为训练的策略),使得奖励模型对chosen和rejected数据的差值最大,进而学到人类偏好

上式的后半部分通过对数函数运算规则,可以进行如下转化。

转化后的公式和源代码中的计算函数中的公式是一致的。

其中左半部分是训练的policy模型选择chosen优先于rejected,右半部分是冻结的reference模型选择chosen优先于rejected,二者的差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异。

4. 微调流程

上图展示了DPO微调的大致流程,其中Trained LM即为策略模型,Frozen LM即为参考模型,二者均是先进行SFT微调得到的模型进行初始化,其中Trained LM需要进行训练,Frozen LM不参与训练。

两个模型分别针对chosen和rejected进行预测获取对应的得分,再通过DPO的损失函数进行损失计算,进而不断的迭代优化。

5. 源码

源码参考代码:
https://github.com/eric-mitchell/direct-preference-optimization

5.1 DPO损失函数

1 defpreference_loss(policy_chosen_logps: torch.FloatTensor,2 policy_rejected_logps: torch.FloatTensor,3 reference_chosen_logps: torch.FloatTensor,4 reference_rejected_logps: torch.FloatTensor,5 beta: float,6                     label_smoothing: float = 0.0,7                     ipo: bool =False,8                     reference_free: bool = False) ->Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:9     #policy_chosen_logps: 训练模型对于chosen经过log后logits
10     #policy_rejected_logps: 训练模型对于rejected经过log后logits
11     #reference_chosen_logps: 训练模型对于chosen经过log后logits
12     #reference_rejected_logps: 训练模型对于rejected经过log后logits
13     #beta: policy和reference的差异性控制参数
14     
15     #actor模型选择chosen优先于rejected
16     pi_logratios = policy_chosen_logps -policy_rejected_logps17     #reference模型选择chosen优先于rejected
18     ref_logratios = reference_chosen_logps -reference_rejected_logps19 
20     ifreference_free:21         ref_logratios =022     
23     #差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异
24     logits = pi_logratios - ref_logratios  #also known as h_{\pi_\theta}^{y_w,y_l}
25 
26     ifipo:27         losses = (logits - 1/(2 * beta)) ** 2  #Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
28     else:29         #Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
30         #label_smoothing为0,对应的DPO论文的算法
31         losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) *label_smoothing32     
33     #chosen和rejected的奖励
34     chosen_rewards = beta * (policy_chosen_logps -reference_chosen_logps).detach()35     rejected_rewards = beta * (policy_rejected_logps -reference_rejected_logps).detach()36 
37     return losses, chosen_rewards, rejected_rewards

5.2 批次训练过程

1 def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):2     """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""
3 
4     if loss_config.name in {'dpo', 'ipo'}:5         #policy模型针对chosen和rejected进行预测
6         policy_chosen_logps, policy_rejected_logps =self.concatenated_forward(self.policy, batch)7 with torch.no_grad():8             #reference模型针对chosen和rejected进行预测
9             reference_chosen_logps, reference_rejected_logps =self.concatenated_forward(self.reference_model, batch)10 
11         if loss_config.name == 'dpo':12             loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free, 'label_smoothing': loss_config.label_smoothing, 'ipo': False}13         elif loss_config.name == 'ipo':14             loss_kwargs = {'beta': loss_config.beta, 'ipo': True}15         else:16             raise ValueError(f'unknown loss {loss_config.name}')17         #损失计算
18         losses, chosen_rewards, rejected_rewards =preference_loss(19             policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)20 
21         reward_accuracies = (chosen_rewards >rejected_rewards).float()22 
23     elif loss_config.name == 'sft':24         policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)25         policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)26 
27         losses = -policy_chosen_logps28 
29     return losses.mean()

5.3 LM的交叉熵计算

1 def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) ->torch.FloatTensor:2     #经模型后的logits进行批量计算logps
3     
4     assert logits.shape[:-1] ==labels.shape5     
6     #基于先前的token预测下一个token
7     labels = labels[:, 1:].clone()8     logits = logits[:, :-1, :]9     loss_mask = (labels != -100)10 
11     #dummy token; we'll ignore the losses on these tokens later
12     labels[labels == -100] =013     
14     #交叉熵函数
15     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)16 
17     ifaverage_log_prob:18         return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)19     else:20         return (per_token_logps * loss_mask).sum(-1)

5.4 其他注意

1. hugging face设置代理

源码会从hugging face中下载英文语料和模型,由于网络限制,因此设置代理映射,将HF_ENDPOINT设置为https://hf-mirror.com,即设置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

2. 如果仅想要熟悉DPO整体流程,可以下载较小的生成式模型,如BLOOM 560M,GPT2等

6. 总结

一句话足矣~

本文主要针对大语言模型的DPO算法的整理,包括原理、流程及部分源码。

此外,建议大家可以针对源码进行运行,源码的欢迎大家一块交流。

7. 参考

(1) RLHF:
https://blog.csdn.net/v_JULY_v/article/details/128579457

(2) DPO论文: https://arxiv.org/pdf/2305.18290v2.pdf

(3) DPO代码: https://github.com/eric-mitchell/direct-preference-optimization

(4) DPO理解1:
https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707

(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918