2024年7月

论文提出了经典的Vision Transormer模型Swin Transformer,能够构建层级特征提高任务准确率,而且其计算复杂度经过各种加速设计,能够与输入图片大小成线性关系。从实验结果来看,Swin Transormer在各视觉任务上都有很不错的准确率,而且性能也很高

来源:晓飞的算法工程笔记 公众号

论文: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

Introduction


长期以来,计算机视觉建模一直由卷积神经网络(CNN)主导。从AlexNet在ImageNet中的革命性表现开始,通过更大的规模、更广泛的连接以及更复杂的卷积形式逐级演变出越来越强大的CNN架构。另一方面,自然语言处理(NLP)网络架构的演变则采取了不同的路径,如今最流行的就是Transformer架构。Transformer专为序列建模和转导任务而设计,以使用注意力来建模数据中的长距离关系而著称。
Transformer在语言领域的巨大成功促使研究人员研究其在计算机视觉的适应性,目前也取得了很不错的结果,特别是用于图像分类的ViT以及用于视觉语言联合建模的CLIP。
本文作者尝试扩展Transformer的适用性,将其用作计算机视觉的通用主干,就像Transformer在NLP和CNN在视觉中所做的那样。将Transformer在语言领域的高性能表现转移到视觉领域所面临的主要挑战,主要源自两个领域之间的差异:

  • 尺寸。token作为NLP Transformer中的基本元素,其尺寸是固定的,对应段落中的一个单词。但视觉目标的尺寸可能有较大的差异,这也是如物体检测等任务备受关注的问题,通常需要捕获多尺度特征来解决。而在现有的基于Transformer的模型中,token都是固定尺寸的,对应一个单词或固定的图片区域,显然不适用于当前的视觉应用任务。
  • 数量级。与文本段落中的单词数量相比,图像中的像素数量要多很多。在许多如语义分割的视觉任务中,需要进行像素级的密集预测。而Transformer在高分辨率图像上的处理是难以进行的,因为自注意力的计算复杂度与图像大小成二次方关系。

为了解决这些问题,论文提出了Swin Transformer,能够构建层级特征图并且计算复杂度与图像大小成线性关系。
基于层级特征图,Swin Transformer模型可以很方便地结合先进的密集预测技术,如特征金字塔网络(FPN)或U-Net。如图1a所示,Swin Transformer从小尺寸的图像块开始,逐渐合并相邻图像块来构建层级特征。线性计算复杂度则是通过只在局部非重叠窗口(图1a红色区域)计算自注意力来实现的。由于窗口大小是固定的,所以复杂度与图像大小成线性关系。
Swin Transformer还有一个关键设计元素,就是在连续的同尺度self-attention层使用移位窗口分区(shifted window partition)。类似于对分组卷积的分组间通信优化,移位窗口能够促进前一层的窗口之间的特征融合,从而显著提高建模能力。常见的基于滑动窗口(sliding window)的自注意力,由于每个
query
对应的
key
集不同,所以都要单独计算注意力矩阵然后输出,实现上很低效。而移位窗口由于仅在窗口内进行自注意力计算,同窗口内的
query
对应的
key
集相同,
key
集可在窗口内共享,可直接单次矩阵计算同时完成全部注意力计算然后输出,在实现上十分高效。
Swin Transformer在图像分类、目标检测和语义分割的识别任务上取得了很不错的结果。在速度相似的情况下,准确率显著优于ViT/DeiT和ResNe(X)t模型。在COCO test-dev数据集上达到的58.7 box AP和51.1 mask AP,分别比SOTA高2.7和2.6。在ADE20K val数据集集上获得了 53.5 mIoU,比SOTA高3.2。在ImageNet-1K数据集上达到了87.3%的top-1准确率。

Method


Overall Architecture

Swin Transformer整体架构如图3所示,该图是Tiny版本Swin-T,分为以下几个部分:

  • Patch Partition:输入图像的处理跟ViT类似,通过patch splitting模块将输入的RGB图像分割成不重叠的图像块,直接将每个图像块内的RGB值concate起来作为一个token。在实现时,每个图像块的大小为
    \(4\times 4\)
    ,因此每个图像块的特征维度为
    \(4\times 4\times 3 = 48\)
  • Linear Embedding:随后,Linear Embedding层对这个原始特征进行处理,将其映射到指定维度大小
    \(C\)
  • Swin Transformer block:在得到图像块token后,连续使用多个包含改进自注意力的Transformer模块(Swin Transformer block)进行特征提取。
  • Patch Merging:为了构建层级特征,随着网络变深,通过Patch Merging层减少token的数量。第一个Patch Merging层将每个维度的
    \(2\times 2\)
    的相邻图像块特征concate起来,并在得到的
    \(4C\)
    维特征上使用Linear Embedding层进行维度映射。这样,token量就减少了
    \(2\times 2 = 4\)
    的倍数(相当于两倍下采样)并且映射到指定维度大小
    \(2C\)
    ,最后同样使用Swin Transformer blocks进行特征变换。

Linear Embedding与后续的Swin Transformer blocks一起称为Stage 1,token的数量为
\(\frac{H}{4}\times \frac{W}{4}\)
。第一个Patch Merging和Swin Transformer blocks称为Stage 2,分辨率保持在
\(\frac{H}{8}\times \frac{W}{8}\)
。该过程重复两次,分别为Stage 3和Stage 4,输出分辨率分别为
\(\frac{H}{16}\times \frac{W}{16}\)

\(\frac{H}{32}\times \frac{W}{32}\)
。各Stage共同构建的层级特征,其特征分辨率与典型卷积网络相同,例如VGG和ResNet。因此,Swin Transformer架构可以方便地替换现有方法中的骨干网络,用于各种视觉任务。

  • Swin Transformer block

Swin Transformer模块将Transformer模块中的多头自注意力(MSA)替换为基于windows或shifted window的多头自注意力,其他层保持不变。如图3b所示,对于连续的Swin Transformer模块,前一个使用基于window的MSA模块,后一个使用基于shifted window的MSA模块,然后都是接一个带GELU非线性激活的两层MLP,每个MSA模块和每个MLP都有LayerNorm(LN)层和一个残差连接。

Shifted Window based Self-Attention

标准的Transformer架构及其在图像分类的应用都进行全局自注意力计算,计算每个token和所有其他token之间的关系。全局自注意力计算的复杂度是token数量的二次方,这显然不适用于许多需要大量token进行密集预测或产生高分辨率图像的视觉问题。

  • Self-attention in non-overlapped windows

为了高效计算,论文提出仅在局部窗口内计算自注意力,各窗口以不重叠的方式均匀地划分图像。假设每个窗口包含
\(M\times M\)
个图像块,在包含
\(h\times w\)
个图像块的特征图上,全局模式和窗口模式的计算复杂度分别为:

复杂度前面的部分应该是
Q

K

V
和最终输出的生成计算,后面部分是
Q

K
的矩阵相乘和权值与
V
的相乘。全局模式的计算复杂度与图像块数量
\(hw\)
成二次方,而当
\(M\)
固定时(默认设置为7),窗口模式的计算复杂度则是线性的。所以当
\(hw\)
很大时,全局自注意力计算通常是难以进行的,而基于窗口的自注意力则是可调整的。

  • Shifted window partitioning in successive blocks

类似于分组卷积的问题,基于窗口的自注意力缺乏跨窗口的连接,限制了建模能力。为了在保持高效计算的情况下引入跨窗口连接,论文提出了移位窗口分区(shifted window partitioning)方法,在连续的Swin Transformer模块交替使用两种不同分区逻辑。

如图2所示,第一个模块使用从左上角像素开始的常规窗口分区策略,将
\(8\times 8\)
特征图均匀地划分为4个
\(4\times 4\)
(M = 4)大小的窗口。然后,下一个模块采用与前一层不同的窗口分区策略,将常规窗口移动
\((\lfloor \frac{M}{2}\rfloor, \lfloor \frac{M}{2}\rfloor)\)
个像素。
基于移位窗口分区方法,连续的Swin Transformer模块的计算变为:

其中
\(\hat{z}^l\)

\(z^l\)
表示
\(l\)
层的(S)WMSA模块和MLP模块的输出特征,W-MSA和SW-MSA 分别表示使用常规窗口分区和移位窗口分区的窗口多头自注意。
移位窗口分区方法增加了上一层中相邻的非重叠窗口之间的联系,这在图像分类、物体检测和语义分割中是十分有效的。

  • Efficient batch computation for shifted configuration

移位窗口分区会导致窗口数变多,从
\((\lfloor \frac{M}{2}\rfloor, \lfloor \frac{M}{2}\rfloor)\)
个窗口变为
\((\lfloor \frac{h}{M}+1\rfloor, \lfloor \frac{w}{M}+1\rfloor)\)
个窗口,而且部分窗口的大小会小于
\(M\times M\)
。在计算窗口自注意力时,一般会将多个窗口拼接成矩阵进行矩阵计算,要求每个窗口的大小一致。
一个简单的移位窗口分区的兼容做法是将较小的窗口填充到
\(M\times M\)
的大小,然后在计算注意力时屏蔽掉填充的值。在常规分区中的窗口数量较少时,例如
\(2\times 2\)
,使用这种简单的解决方案增加的计算量是相当大的(
\(2\times 2 \to 3\times 3\)
,增加2.25倍)。

为此,论文提出了一种更高效的批处理计算方法,通过向左上方向循环移位进行小窗口的合并计算,如图4所示。在移位之后,单个窗口可能由几个原本不相邻的子窗口组成,因此需要采用掩码机制将自注意力计算限制在每个子窗口内,掩码机制主要是屏蔽掉计算出来的注意力矩阵。在循环移位后,由于窗口数量与常规窗口分区的数量相同,因此计算量也相当。

  • Relative position bias

在计算self-attention时,论文参考当前一些研究的做法,在进行相似度计算时为每个head加入相对位置偏置(relative position bias)
\(B\in \mathbb{R}^{M^2\times M^2}\)
,注意区别于常规相对位置编码的做法:

其中
\(d\)

Q

K

V
特征的维度,
\(M^2\)
是窗口中的图像块数。由于每个轴方向的相对位置均在
\([−M + 1, M −1]\)
范围内,论文设置了一个较小尺寸的可学习偏置矩阵
\(\hat{B}\in \mathbb{R}^{(2M−1)\times(2M−1)}\)
(对应二维相对位置组合数量),然后根据窗口中各位置的相对位置转换得到唯一索引编码,从
\(\hat{B}\)
取对应的值构成
\(B\)
矩阵。这样做的目的有两个,降低参数量(
\((2M−1)\times(2M−1)\)
vs
\((M^2\times M^2)\)
),同时让相同位置的使用相同偏置。
从实验结果来看,与没有此偏置项或使用绝对位置偏置的对比,相对位置偏置有显著的性能提升。ViT使用了绝对位置偏置,论文也尝试进一步叠加绝对位置偏置,但测试会略微降低性能,因此在实现中未采用它。
当要fine-tuning不同窗口大小的模式时,预训练到的相对位置偏置也可通过bi-cubic interpolation进行转换。

Architecture Variants

论文构建了基础模型Swin-B,跟ViTB/DeiT-B的模型大小和计算复杂度差不多。此外,论文还涉及了Swin-T、Swin-S和Swin-L版本,分别是基础模型的模型大小和计算复杂度的0.25倍、0.5倍和2倍的版本。其中,Swin-T和Swin-S的复杂度分别对标ResNet-50(DeiT-S)和ResNet-101。默认情况下,窗口大小设置为 M = 7。对于所有实验,自注意力计算每个head的特征维度
\(d = 32\)
,每个MLP的扩展层
\(α = 4\)

这些模型变体的架构超参数是:

  • Swin-T:C = 96, layer numbers =
  • Swin-S:C = 96, layer numbers =
  • Swin-B:C = 128, layer numbers =
  • Swin-L:C = 192, layer numbers =

其中
\(C\)
是Stage 1的维度数。

Experiment


直接训练和预训练在Image-1K数据集上的性能对比。

目标检测上对比嵌套多种检测算法和其它主干网络。

语义分割上对比其它SOTA模型。

移位窗口策略性能以及不同的position embedding组合的对比。

不同策略之间的推理性能对比。

Swin Transformer搭配不同自注意力计算方法的性能对比。

Conclusion


论文提出了经典的Vision Transormer模型Swin Transformer,能够构建层级特征提高任务准确率,而且其计算复杂度经过各种加速设计,能够与输入图片大小成线性关系。从实验结果来看,Swin Transormer在各视觉任务上都有很不错的准确率,而且性能也很高。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

事情要从一次不规范的代码开发开始说起

背景故事

时间

2024年某个风平浪静的周五晚上

地点

中国,北京,西二旗,某互联网大厂会议室

人物

小杰,小A,小B,老K

对话

老K:昨天提交的
代码被测试打回
来了!为什么
小B没开发完的内容
也一起提交上去了?

小B:啊?我不清楚啊,我在开发分支B开发完一部分就
提交到test分支进行联调
了啊

小A:额(''!),我把
test分支合并到发布分支提交给测试
了,因为我跟小杰最开始在各自的分支开发,但是中间联调的时候,为了
图方便直接在test分支上改
,改来改去就直接
在test分支上开发
了。。。

老K:什么?你怎么能直接在test分支开发?

老K:正确的开发流程规范应该是:现在
各自的开发分支
上开发,然和合并到
test分支上进行联调
,联调没有问题在
提交发布release分支进行测试和部署
,验证没问题在把各自的
开发分支合并到基线分支master

老K:现在要把代码回滚,谁做的事情谁负责。小杰,你去把test分支上的代码抽出来放到单独一个开发分支上

小杰:啊?


小杰接受这个任务,准备把test分支上
他跟小A多次提交
的内容跟转移到一个纯净的开发分支,小杰决定使用
git cherry pick
这个命令

git cherry-pick

介绍

git cherry-pick
是 Git 中的一个非常有用的命令,它允许你从一个分支中选择特定的提交(commit)应用到当前的分支。这个命令在需要引入某些特定功能或修复而不想进行完整的分支合并时特别有用。

使用示例

假设你有以下 Git 分支结构:

* 5a3d5f2 (feature) Add new feature
* c7e33a5 Fix bug B
* 1a2b3c4 Fix bug A
* 9d8e7f6 (main) Initial commit

现在你在
main
分支上,想要将
feature
分支中修复 bug A 的提交 (
1a2b3c4
)引入当前分支。你可以这样做:

  1. 切换到目标分支(假设是
    main
    分支):

    git checkout main
    
  2. 使用
    git cherry-pick
    命令:

    git cherry-pick 1a2b3c4
    

    执行上述命令后,提交
    1a2b3c4
    的更改会被应用到
    main
    分支上。

  3. 使用范围(range)来批量 cherry-pick
    . 假设你要 cherry-pick 从 commitA 到 commitB 之间的所有 commit(包含 commitA 但不包含 commitB),你可以使用以下命令:

    git cherry-pick commitA^..commitB
    
  4. 使用多个单独的 commit 来批量 cherry-pick
    . 假设你有一系列的 commit 哈希 commit1, commit2, commit3,你可以使用以下命令:

    git cherry-pick commit1 commit2 commit3
    
  5. 解决可能的冲突:
    在 cherry-pick 的过程中,如果遇到冲突,Git 会提示你。你需要手动解决这些冲突并继续 cherry-pick

    # 解决冲突后,添加解决后的文件
    git add <conflicted-file>
    # 继续 cherry-pick
    git cherry-pick --continue
    

注意事项

  1. 冲突处理
    :如果在
    cherry-pick
    的过程中,存在文件冲突,Git 会暂停操作,并提示冲突文件。你需要手动解决这些冲突,然后使用
    git add <file>
    添加解决后的文件,最后运行
    git cherry-pick --continue
    继续操作。如果你想中止
    cherry-pick
    ,可以使用
    git cherry-pick --abort
  2. 保持提交历史干净
    :频繁使用
    cherry-pick
    可能会导致提交历史变得复杂。在使用前,评估是否可以通过别的操作(如合并或重置)来实现相同的目标。
  3. 避免重复提交
    :如果你已经
    cherry-pick
    了一个提交,再次尝试
    cherry-pick
    同一个提交可能会引发问题。Git 会提示你已经包含了相同的更改。
  4. 顺序和依赖关系
    :如果一个提交依赖于之前的其他提交,
    cherry-pick
    这些提交时需要注意顺序,以避免破坏代码的完整性。


解决方案

第一次尝试

# Step 1: 创建新的分支 xsj_0701
git checkout master
git pull origin master
git checkout -b xsj_0701

# Step 2: 查看 stable_test 分支的 commit 历史
git log stable_test

# Step 3: 批量 cherry-pick commit
git checkout xsj_0701
git cherry-pick commitA^..commitB  # 使用范围
# 或者
git cherry-pick commit1 commit2 commit3  # 使用多个单独的 commit 哈希

# Step 4: 解决可能的冲突
# 解决冲突后
git add <conflicted-file>
git cherry-pick --continue

# Step 5: 推送新的分支
git push origin xsj_0701

小杰使用批量范围cherry-pick,这个范围大约包含了10个commit,正当小杰吭哧吭哧的解决几个冲突之后,cherry-pick突然停止,没有冲突,查看当前commit也只到
add cache
这个提交这里,如下图所示

image-20240702104831308

为什么cherry-pick会停止呢?

小杰经过观察发现,停止的位置是merge节点

当你尝试 cherry-pick 一个 merge commit 时,Git 需要更多信息来决定如何处理合并。默认情况下,Git 不会自动 cherry-pick merge commit,因为它无法确定你想要保留哪个分支的变更。

要解决这个问题,你可以使用以下方法:

方法 1: 跳过 merge commit

如果你不需要 cherry-pick 这个 merge commit,可以手动跳过它。你可以通过在失败后继续 cherry-pick 后续的 commit 来实现:

在发生停止后,手动跳过 merge commit 并继续 cherry-pick 后续的 commit:

git cherry-pick --skip
# 然后继续 cherry-pick 后续的 commit
git cherry-pick <remaining-commits>

方法 2: 使用 cherry-pick -m 选项

如果你确实需要 cherry-pick 这个 merge commit,可以使用
-m
选项。
-m
选项需要一个参数来指定父提交的索引,通常使用
1
表示第一父提交。

  • 继续 cherry-pick merge commit 并指定父提交索引:
git cherry-pick -m 1 5b30dd90
  • 继续后续的 cherry-pick:
git cherry-pick <remaining-commits>

第二次尝试

# Step 1: 创建新的分支 xsj_0701
git checkout master
git pull origin master
git checkout -b xsj_0701

# Step 2: 查看 stable_test 分支的 commit 历史
git log stable_test

# Step 3: 批量 cherry-pick commit
git checkout xsj_0701
git cherry-pick commitA^..commitB  # 使用范围
# 或者
git cherry-pick commit1 commit2 commit3  # 使用多个单独的 commit 哈希

# Step 4: 解决可能的冲突
# 解决冲突后
git add <conflicted-file>
git cherry-pick --continue

# Step 5: 跳过merge节点
git cherry-pick --skip

# Step 6: 推送新的分支
git push origin xsj_0701

经过一下午的奋战,小杰终于把test分支上的开发内容都迁移到纯净开发分支,然后屁颠屁颠去跟老K汇报了

一、简介

在实际的项目开发过程中,经常需要用到邮件通知功能。例如,通过邮箱注册,邮箱找回密码,邮箱推送报表等等,实际的应用场景非常的多。

早期的时候,为了能实现邮件的自动发送功能,通常会使用 JavaMail 相关的 api 来完成。后来 Spring 推出的 JavaMailSender 工具,进一步简化了邮件的自动发送过程,调用其 send 方法即可发送邮件。再之后, Spring Boot 针对邮件推送功能推出了
spring-boot-starter-mail
工具包,开发者可以通过它来快速实现邮件发送服务。

今天通过这篇文章,我们一起来学习如何在 Spring Boot 中快速实现一个自动发送邮件的功能。

二、环境准备

在介绍邮件推送实现之前,我们需要先准备一台邮件推送的服务器,以便实现相关功能。

这里以腾讯邮箱为例,将其作为邮件发送的中转平台。

2.1、开启 SMTP 服务

登陆腾讯邮箱,打开【设置】-》【收发信设置】,开启 SMTP 服务,最后点击【保存更改】。

2.2、生成客户端专用密码

点击【设置】-》【账户】,进入页面后点击【开启安全登陆】,点击【生成新密码】。

这个新密码会用于邮箱的自动发送,因此需要记录下来,最后点击【保存更改】。

2.3、相关扩展知识

  • 什么是 SMTP?

SMTP(simple mail transfer protocol),也被称为
简单邮件传输协议
,主要用于发送电子邮件的,通过它可以实现邮件的发送或者中转。遵循 SMTP 协议的服务器,通常称为发送邮件服务器。

  • 什么是 POP3?

POP3(Post Office Protocol),一种邮局通信协议。主要用于接受电子邮件的,POP3 允许用户从服务器上把邮件存储到自己的计算机上,同时删除保存在邮件服务器上的邮件。同理,遵循 POP3 协议的服务器,通常称为接收邮件服务器。

  • 什么是 IMAP?

IMAP(Internet Mail Access Protocol),一种交互式邮件存取协议。与 POP3 协议类似,主要用于接收电子邮件,稍有不同的是:IMAP 允许电子邮件客户端收取的邮件仍然保留在服务器上,同时在客户端上的操作都会反馈到服务器上,例如删除邮件,标记已读等,服务器上的邮件也会做相应的动作。所以无论从浏览器登录邮箱或者客户端软件登录邮箱,看到的邮件以及状态都是一致的。

总结下来就是:SMTP 负责发送邮件,POP3/IMAP 负责接收邮件。

常见邮箱发、收服务器如下!

三、邮件推送实现

用于发送邮件的服务器、账户和密码准备好了之后,就可以正式使用了。下面我们以 Spring Boot 的
2.1.0
版本为基础,实现过程如下。

2.1、添加依赖包


pom.xml
文件中,添加
spring-boot-starter-mail
依赖包。

<!--mail 支持-->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-mail</artifactId>
</dependency>

2.2、添加相关配置


application.properties
中添加邮箱相关配置。

# 配置邮件发送主机地址
spring.mail.host=smtp.exmail.qq.com
# 配置邮件发送服务端口号
spring.mail.port=465
# 配置邮件发送服务协议
spring.mail.protocol=smtp
# 配置邮件发送者用户名或者账户
spring.mail.username=xxx@qq.com
# 配置邮件发送者密码或者授权码
spring.mail.password=xxxxxxx
# 配置邮件默认编码
spring.mail.default-encoding=UTF-8
# 配置smtp相关属性
spring.mail.properties.mail.smtp.auth=true
spring.mail.properties.mail.smtp.ssl.enable=true
spring.mail.properties.mail.smtp.ssl.required=true

2.3、简单发送一封邮件

通过单元测试来实现一封简单邮件的发送,示例如下:

@RunWith(SpringRunner.class)
@SpringBootTest
public class MailSimpleTest {

    @Autowired
    private JavaMailSender mailSender;

    @Test
    public void sendSimpleMail() throws Exception {
        SimpleMailMessage message = new SimpleMailMessage();
        // 配置发送者邮箱
        message.setFrom("xxxx@qq.com");
        // 配置接受者邮箱
        message.setTo("xxxxxx@qq.com");
        // 配置邮件主题
        message.setSubject("主题:简单邮件");
        // 配置邮件内容
        message.setText("测试邮件内容");
        // 发送邮件
        mailSender.send(message);
    }
}

运行单元测试之后,如果不出意外的话,接受者会收到这样的一封邮件。

至此,邮件发送成功!

2.4、发送 HTML 格式邮件

在实际的业务开发中,邮件的内容通常会要求丰富,比如会发送一些带有图片的内容,包括字体大小,各种超链接等,这个时候如何实现呢?

实际上,邮件内容支持 HTML 格式,因此可以借助页面模板引擎来实现绚丽多彩的内容。

下面我们以
freemarker
模板引擎为例,发送一封内容为 HTML 格式的邮件。

2.4.1、引入 freemarker 依赖包

首先,在
pom.xml
文件中,添加
freemarker
依赖包。

<!--freemarker 支持-->
<dependency>
    <groupId>org.freemarker</groupId>
    <artifactId>freemarker</artifactId>
    <version>2.3.23</version>
</dependency>
2.4.2、编写邮件页面模板

然后,在
resources/templates
目录下,创建一个
demo.ftl
文件,示例如下!

<html>
<head>
	<meta charset="utf-8">
	<title></title>
</head>
<body>
<div>您好:${userName}</div>
<div>这是html文本内容</div>
<img src="https://rescdn.qqmail.com/zh_CN/htmledition/images/logo/logo_0_0@2X1f1937.png" />
</body>
</html>
2.4.3、编写一个邮件推送服务

虽然采用 Spring Boot 提供的自动配置属性来实现邮件推送,可以极大的简化开发过程。而实际开发的时候,通常更推荐自定义一个邮件统一推送服务,这样更便于灵活的控制代码实现以及排查相关问题。

邮件统一发送服务,示范如下。

@Component
public class MailPushService {

    private final Logger LOGGER = LoggerFactory.getLogger(MailPushService.class);

    @Value("${mail.host}")
    private String host;

    @Value("${mail.port}")
    private String port;

    @Value("${mail.protocol}")
    private String protocol;

    @Value("${mail.username}")
    private String username;

    @Value("${mail.password}")
    private String password;

    @Value("${mail.fromEmail}")
    private String fromEmail;

    @Value("${mail.fromPersonal}")
    private String fromPersonal;

    @Autowired
    private JavaMailSender mailSender;


    /**
     * 发送邮件(简单模式)
     * @param toEmail
     * @param subject
     * @param content
     */
    public void sendMail(String toEmail, String subject,String content)  {
        try {
            final Properties props = new Properties();
            //服务器
            props.put("mail.smtp.host", host);
            //端口
            props.put("mail.smtp.port", port);
            //协议
            props.setProperty("mail.transport.protocol", protocol);
            //用户名
            props.put("mail.user", username);
            //密码
            props.put("mail.password", password);
            //使用smtp身份验证
            props.put("mail.smtp.auth", "true");

            //开启安全协议
            MailSSLSocketFactory sf = new MailSSLSocketFactory();
            sf.setTrustAllHosts(true);
            props.put("mail.smtp.ssl.enable", "true");
            props.put("mail.smtp.ssl.socketFactory", sf);
            Authenticator authenticator = new Authenticator() {
                @Override
                protected PasswordAuthentication getPasswordAuthentication() {
                    return new PasswordAuthentication(props.getProperty("mail.user"),
                            props.getProperty("mail.password"));
                }
            };

            Session session = Session.getDefaultInstance(props, authenticator);
            session.setDebug(true);
            MimeMessage mimeMessage = new MimeMessage(session);
            mimeMessage.setFrom(new InternetAddress(fromEmail, MimeUtility.encodeText(fromPersonal)));
            mimeMessage.setRecipient(MimeMessage.RecipientType.TO, new InternetAddress(toEmail));
            mimeMessage.setSubject(subject);
            mimeMessage.setContent(content, "text/html;charset=UTF-8");

            //保存信息
            mimeMessage.saveChanges();
            //发送消息
            Transport.send(mimeMessage);
            LOGGER.info("简单邮件已经发送。");
        } catch (Exception e) {
            LOGGER.error("发送简单邮件时发生异常!", e);
        }
    }
}

代码中相关自定义的全局参数配置如下:

mail.host=smtp.exmail.qq.com
mail.port=465
mail.protocol=smtp
mail.username=xxx@qq.com
mail.password=xxxxxx
mail.fromEmail=xxxxxx@qq.com
mail.fromPersonal=发送者昵称
2.4.4、测试服务的正确性

最后,编写一个单元测试来验证服务的正确性,示例如下:

@RunWith(SpringRunner.class)
@SpringBootTest
public class MailTest {

    @Autowired
    private MailPushService mailPushService;

    @Test
    public void testSendHtmlMail() throws Exception {
        String sendHtml = buildHtmlContent("张三");
        mailPushService.sendMail("xxxxx@qq.com","简单标题", sendHtml);
    }

    /**
     * 封装html页面
     * @return
     * @throws Exception
     */
    private static String buildHtmlContent(String userName) throws Exception {
        Configuration configuration = new Configuration(Configuration.VERSION_2_3_23);
        configuration.setDefaultEncoding(Charset.forName("UTF-8").name());
        configuration.setClassForTemplateLoading(MailTest.class, "/templates");
        // 获取页面模版
        Template template = configuration.getTemplate("demo.ftl");
        // 动态变量替换
        Map<String,Object> map = new HashMap<>();
        map.put("userName", userName);
        String htmlStr = FreeMarkerTemplateUtils.processTemplateIntoString(template,map);
        return htmlStr;
    }

}

运行单元测试之后,如果没有报错,接受者会收到这样的一封邮件。

2.5、发送带附件的邮件

某些业务场景,用户希望发送的邮件中能带上附件,比如上文中,在发送 HTML 格式的邮件时,同时也带上文件附件,这个时候如何实现呢?

2.5.1、编写带附件的邮件发送

此时可以在邮件推送服务中,新增一个支持带附件的方法,实现逻辑如下。

/**
 * 发送邮件(复杂模式)
 * @param toEmail    接受者邮箱
 * @param subject    主题
 * @param sendHtml   内容
 * @param attachment 附件
 */
public void sendMail(String toEmail, String subject, String sendHtml, File attachment) {
    try {
        //设置了附件名过长问题
        System.setProperty("mail.mime.splitlongparameters", "false");
        final Properties props = new Properties();
        //服务器
        props.put("mail.smtp.host", host);
        //端口
        props.put("mail.smtp.port", port);
        //协议
        props.setProperty("mail.transport.protocol", protocol);
        //用户名
        props.put("mail.user", username);
        //密码
        props.put("mail.password", password);
        //使用smtp身份验证
        props.put("mail.smtp.auth", "true");

        //开启安全协议
        MailSSLSocketFactory sf = new MailSSLSocketFactory();
        sf.setTrustAllHosts(true);
        props.put("mail.smtp.ssl.enable", "true");
        props.put("mail.smtp.ssl.socketFactory", sf);
        Authenticator authenticator = new Authenticator() {
            @Override
            protected PasswordAuthentication getPasswordAuthentication() {
                return new PasswordAuthentication(props.getProperty("mail.user"),
                        props.getProperty("mail.password"));
            }
        };

        Session session = Session.getDefaultInstance(props, authenticator);
        session.setDebug(true);
        MimeMessage mimeMessage = new MimeMessage(session);
        // 发送者邮箱
        mimeMessage.setFrom(new InternetAddress(fromEmail, MimeUtility.encodeText(fromPersonal)));
        // 接受者邮箱
        mimeMessage.setRecipient(MimeMessage.RecipientType.TO, new InternetAddress(toEmail));
        // 邮件主题
        mimeMessage.setSubject(subject);
        // 定义邮件内容
        Multipart multipart = new MimeMultipart();

        // 添加邮件正文
        BodyPart contentPart = new MimeBodyPart();
        contentPart.setContent(sendHtml, "text/html;charset=UTF-8");
        multipart.addBodyPart(contentPart);

        // 添加附件
        if (attachment != null) {
            BodyPart attachmentBodyPart = new MimeBodyPart();
            // MimeUtility.encodeWord可以避免文件名乱码
            FileDataSource fds=new FileDataSource(attachment);
            attachmentBodyPart.setDataHandler(new DataHandler(fds));
            attachmentBodyPart.setFileName(MimeUtility.encodeText(fds.getName()));
            multipart.addBodyPart(attachmentBodyPart);
        }

        // 将multipart对象放到message中
        mimeMessage.setContent(multipart);

        //保存信息
        mimeMessage.saveChanges();
        //发送消息
        Transport.send(mimeMessage);
        LOGGER.info("邮件已经发送。");
    } catch (Exception e) {
        LOGGER.error("发送邮件时发生异常!", e);
    }
}
2.5.2、测试服务的正确性

最后,编写一个单元测试来验证服务的正确性,示例如下:

@Test
public void doSendHtmlEmail() throws Exception {
    // 获取正文内容
    String sendHtml = buildHtmlContent("张三");

    // 获取附件
    File file = new File( "~/doc/Java开发手册.pdf");
    // 发送邮件
    mailPushService.sendMail("xxxxx@qq.com","带附件的邮件推送", sendHtml, file);
}

运行单元测试之后,如果没有报错,接受者会收到这样的一封邮件。

三、小结

最后总结一下,邮件自动推送功能在实际的业务系统中应用非常广,在发送过程中也可能会因为网络问题出现各种失败现象,因此推荐采用异步的方式来发送邮件,例如采用异步编程或者消息队列来实现,以便加快主流程的执行速度。

想要获取项目源代码的小伙伴,可以访问如下地址获取!

https://gitee.com/pzblogs/spring-boot-example-demo

四、参考

1.
https://blog.csdn.net/qq_26383975/article/details/121957917

1.
http://www.ityouknow.com/springboot/2017/05/06/spring-boot-mail.html

在学习Transformer这个模型前对seq2seq架构有个了解时很有必要的

先上图

输入和输出

首先理解模型时第一眼应该理解输入和输出最开始我就非常纠结

有一个Inputs,一个Outputs(shift right)和一个Output Probabilities,首先需要借助这三个输入/输出来初步了解该模型的运行方式。这里以一个英译汉的任务来举例,在训练时Input端的输入就是英语,Outputs(shift right)端输入的就是对应的汉语译文,Output Probabilities输出的就是模型预测的下一个词语(汉语),首先要确定一点,输出的词是一个一个出来的,而不是一起出来的,第n个词的预测需要依赖前n-1个词,如果之前没有接触过seq2seq,这里就会有个疑问,Outputs(shift right)已经将答案给模型了,那这个训练有什么意义呢?这里就涉及到Masked Multi-Head Attention,这个Masked让模型在预测第n个词语的时候,会将第n个词语及之后的词语给盖住,让模型接触不到后面的内容,这样保证模型不去“抄答案”。

那对于Outputs(shift right)为什么论文中要加一个shift right,是因为模型在输出第n个词时候需要前n-1个,那要输出第一个词怎么办呢,这里人为定义了一个<start>,解码器中输入的所有句子都是以<start>开头的。

Input Embedding

可以看到,所有直接的数据在输入编/解码器之前都会经过一次Embedding和一次Positional Encoding,对于Embedding,计算机无法直接理解中文或者英文,需要将其编码为向量方便操作,one-hot就是用来做这个的,但是对于大模型来说,动辄几万个词,如果使用one-hot编码,词向量将是几万维的,这是不可接受的,因此有了更加高效的方法,例如Word2Vec、GloVe、FastText等,就拿Word2Vec来说,他能够捕捉单词语义的相似性,例如
大树
这个词语的词向量为
v1

树木
这个词的词向量为
v2
,蓝色这个词语的词向量为
v3
,那么
v1

v2
的点积就比
v1

v3
的点积要大,两个词的词向量点积越大,那这两个词的语义越相近,这是由Word2Vec这个算法所决定的。

Positional Encoding

位置编码,在seq2seq模型中,我们主要运用的是RNN模型作为主干,加上注意力机制来改善效果,RNN模型的输入不论是中文还是英文,不论是编码器还是解码器,都是一个词一个词给进去的,在处理序列数据时,每个时间步的计算依赖于前一个时间步的输出。这种顺序依赖性使得RNN很难利用并行计算来加速训练和推理过程,这就造成了性能瓶颈,同时过长的输入会导致梯度爆炸或者梯度消失。

而Transformer中,完全抛弃了RNN的基本结构,使用自注意力机制来处理输入序列,可以对输入序列中的所有元素同时进行处理,从而大大提高了计算速度和效率。而由于数据是一起被放入模型中,一起被处理,其位置信息就丢失了,seq2seq中的词是一个一个输入进去的,输入的先后顺序就隐藏了位置信息,因此为了保存数据的位置信息,我们就需要Positional Encoding(位置编码)。

在论文中的解决方案是这样的,通过一定的方法(后面会介绍)为句子中的每个词生成一个位置信息,然后将这个位置信息直接加到对应的词向量上面去,过程如下

Positional Encoding是如何生成的呢,论文中的方法为
Sine and Cosine Positional Encodings
,其思想是,为输入序列中的每个位置生成一个固定的向量,这个向量的构造方式是通过不同频率的正弦和余弦函数来实现的。具体的公式如下:

\[PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\]

\[PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\]

其中,
pos
是位置,
i
是向量的维度索引,
\(d_{\text{model}}\)
是模型的embedding维度。这种方法确保了每个位置的编码向量是唯一的,并且不同位置之间的距离可以通过这些编码向量进行区分。这个公式看起来很头大,其实不必过于纠结,知道它是干什么用的就行,这个方法也不是完美方法,在后来的bert中就没用使用这个方法了,说明还是存在一些问题的。

编码器

左边的一块叫做编码器,右边的一块叫做解码器,这个取名很形象,借用之前的例子,编码器将英语编码成一种只有计算机才能理解的语言(不是计算机直接理解的计算机语言),再通过解码器解码成目标语言。

Multi-Head Attention

首先,我们需要搞清楚什么是自注意力机制,也就是 self-Attention。

所谓注意力机制,我的理解就是将重点放在重要的地方。举个简单的例子,当我们翻译句子时,例如原句是 "A black man",当我们翻译完了 "A black" 时,到 "man" 时,我们脑子里还是有 "A black man" 这个句子,但是此时我们的注意力肯定是放在 "man" 这个单词上,而会忽略 "A black",就是对这三个单词的注意力的重视程度(权重)不一样,"man" 这个单词的权重很高,而其余的很低。

这里可能有人会质疑,翻译不就是一个词一个词的翻译吗?其实翻译也是要看上下文的。你翻译 "A black" 时,就是 "一个黑色的",但是你看到 "man" 时,你不能直接翻译成 "一个黑色的人",而要翻译成 "一个黑人"。所以翻译任务不是简单的逐词翻译,而要联系上下文。在翻译很长的信息时,如果我们对所有的上下文都同样对待,将会陷入信息的海洋中迷失自己。因此,在翻译过程中为每个词添加权重时是很有必要的。

我们在翻译过程中,经验和大脑与生俱来的抽象能力自动帮我们实现了这一注意力的过程,但是如何把这个思想传递给模型呢?这里就是注意力机制要做的事情。

首先我之前提到过使用现代词嵌入技术是可以在向量中反映出词之间的相似程度的,我把每个词向量和同一句话中的其它词向量求内积,就可以得到每个词向量和其它词向量的相似度(可以理解为关联的强弱,因为这个词向量也是模型从海量文本中学习来的,比如大量文本中都出现了红苹果,几乎没有黄苹果,那么显然红和苹果的关联性就远大于黄和苹果的关联性。)暂不考虑其细节,我们似乎可以用这种点积的大小来反映各个词之间的关联程度,既然关联程度不一样,我们可以量化这种不同从而得到一个权重矩阵也就是注意力矩阵。

先看这张图

这里的Q、K、V就是经过Encoding之后的词向量,将其记作X如果按照之前的思路,我们现在应该
\(X \cdot X^T\)
,然而如果是这样的话显然就限定了数据的分布,而且如果词嵌入向量没有训练好的的话这里会十分影响模型性能,而且好像没有什么可以训练的参数,我们应该在这里加入一点东西让网络复杂一点,因此我们引入了三个独立的线性层,分别记作W1,W2和W3,我们将
\(X \cdot W_1\)
记作Q,
\(X \cdot W_2\)
记作K,
\(X \cdot W_3\)
记作V,这就是自注意力机制中Q、K、V的来历。

那在我们使用
\(Q \cdot K^T\)
就能得到输入句子中每个词与其他词之间的相似度的矩阵,为保证梯度稳定性,我们进行一个常规的归一化,通过数学推导,
\(Q \cdot K^T\)
的均值为0,方差为
\(d\)
(词嵌入向量的维度),归一化之后的式子如下

\[\frac{Q \cdot K^T}{\sqrt{d}}
\]

我们已经得到了处理后的内积,离我们想要的权重矩阵就只有一步之遥了,将其映射到0-1之间的概率分布即可,这里我们可以选用Sigmoid函数或者Softmax函数,论文中作者使用了Softmax函数,那我们就可以得到注意力矩阵了

\[Softmax(\frac{Q \cdot K^T}{\sqrt{d}})
\]

得到了注意力矩阵(权重矩阵),我们将其乘进V矩阵(原始矩阵)中得到加权后的矩阵,也就是加了注意力之后的矩阵

\[Softmax(\frac{Q \cdot K^T}{\sqrt{d}}) \cdot V
\]

这就是论文里的公式,用流程图展示如下

以上我们解释了Attention,那么Multi-Head又该如何解释呢?

在向量X(经过Encoding之后的词向量)进入自注意力机制模块前将他”断开“,不同的维度进入不同的自注意力机制模块进行相同的的运算,例如”你“这个词,假设它的词嵌入向量Y是512维的[a0, a1, a2, ···, a510, a511],我们只使用两个头,也就是h=2,那么就将Y截断,[a0, a1, ···, a254, a255]进入第一个自注意力机制模块进行计算,[a256, a257, ···, a510, a511]进入第二个模块经行同样的计算,在各自计算完成后拼接(concat)起来,再通过一个全连接层增加模型的复杂度。事实上,这样做是很有必要的,这样可以训练多个注意力矩阵提取不同维度的信息,增加了模型的复杂度,同时通过拆分维度把计算量分成一小块一小块的了,提高了并行性。

至此,我们走完了Multi-Head Attention这个模块。

Add & Norm

对于Add,借鉴了残差结构,就是将Multi-Head Attention输出的结果与向量X加一下,这样可以保证梯度稳定,这又是如何实现的呢?

编码器旁边有个Nx,说明肯定不止一层,这里如果加一下的话输出的结果就从f(x)变成了f(x)+x,这样在求导时就由f'(x)变成了f'(x)+1,可以一定程度上缓解梯度消失的问题。

Multi-Head Attention输出的结果与向量X加完的结果进行一个Layer Normalizaiton

\[\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
\]

其中,
\(\gamma\)

\(\beta\)
是可训练的参数,其中为什么要用LN而不是BN主要是由于每个句子的含有词的长度不同,会导致数值不太稳定和一些其他考虑。注意,BN是对所有的batch的某个feature做归一化,LN是对某个batch的所有feature做归一化。

Feed Forward

基于位置的前馈网络(FFN)

过程:

\[FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
\]

两个全连接层中间夹着一个ReLU,为什么需要这个东西呢?首先分析一下输入到FFN的数据的形状应该是
(b, n, d)
的,那它跟CNN中的有什么区别呢?

首先对于CNN来说我们解释b、n、d的含义:

b
: 批处理大小,即一次处理的图像数量。例如,如果一次输入处理 32 张图片,则
b = 32

n
: 特征数量,这通常对应于卷积层和池化层之后的特征图数量。例如,如果最后一层的卷积层有 512 个输出通道,那么
n = 512

d
: 特征维度,指每个特征向量的长度。这可能是卷积层输出特征图的宽度和高度的乘积。例如,如果最后一层的特征图的大小为 7x7,那么
d = 7 * 7 = 49

那对于Transformer来说,b、n、d的含义又有所区别

b: Batch size,即一次输入网络的数据样本数。这个维度表示在一次前向传递中处理的序列数量。也就是指的句子数量。

n: Sequence length,序列的长度。这个维度表示每个输入序列中包含的词数量。也就是指句子里面词的数量。

d: Feature dimension,特征的维度。这个维度表示每个标记的嵌入向量的长度。也就是之前提到的d。

对于一个三维向量后面接上一个全连接层,CNN是如何处理的呢?

全连接层会将这个三维数据展平成二维数据
(b, n*d)
,然后再输入到全连接层中进行分类。例如,假设你有一个批量大小为 32 的输入,经过卷积和池化操作后得到 512 个 7x7 的特征图,那么输入到全连接层的数据形状将会是
(32, 512, 49)
,在展平后变成
(32, 512 * 49)
,即
(32, 25088)

那Transformer可不可以效仿这种做法呢?

显然是不行的因为每个句子的长度不可能都是相同的,那么n就是变化的,
n*d
的值就是变化的,由于这种变化的特性,我们无法确定一个
(n*d, out_dim)
的矩阵,做预测的时候,由于每个句子的n不同,也很难完成任务,因此不能完全效仿CNN的做法,只得另谋它路。那文中是如何实现这个全连接层的呢?

作者使用了一个
1*1
的卷积,通过改变通道数的方式来实现全连接层的效果。

先将输入数据的形状从
(b, n, d)
转换为
(b, d, n)
(这一步为了适应conv1函数的输入参数),然后对其使用一个
1*1
的卷积层,假设隐藏层大小为 512,将输出通道数设为 512,数据形状将变成
(b, 512, n)
。然后经过一次 ReLU 激活函数,再用同样的操作,将输出通道数设置为 d,将数据形状变回
(b, d, n)
,最后调整回原始形状
(b, n, d)

实现方法可能有点出入,但是思想就是利用1*1的卷积来实现全连接层的作用,在只需要大致理解模型时候不用过多关注这些细节,只需知道这里就是两个全连接层夹着一个ReLU即可

解码器

学完编码器各个组件后会发现编码器这边其实没有什么新东西了,只有唯一一个Masked Multi-Head Attention特别一点了。

Masked Multi-Head Attention

对于一个训练好的模型来说,假如我们要做英译汉,最理想的情况是这样的:

  1. 我们在
    Inputs
    端输入 "A red apple"。
  2. Outputs(shift right)
    端会自动输入一个
    <start>
    作为起始标记。
  3. 解码器依据输入在经过一系列的变化,
    Output Probabilities
    输出 "一个"。
  4. Outputs(shift right)
    端会自动输入
    <start> 一个
  5. 解码器依据输入在经过一系列的变化,
    Output Probabilities
    输出 "红色的"。
  6. Outputs(shift right)
    端会自动输入
    <start> 一个 红色的
  7. 解码器依据输入在经过一系列的变化,
    Output Probabilities
    输出 "苹果"。
  8. Outputs(shift right)
    端会自动输入
    <start> 一个 红色的 苹果
  9. 解码器依据输入在经过一系列的变化,
    Output Probabilities
    输出
    <end>
  10. 翻译完成。

可以看到,结果是一个一个输出的,第n个词的输出需要依赖前n-1个词的输入,训练过程也是一样

  1. 我们在
    Inputs
    端输入 "A red apple"。
  2. Outputs(shift right)
    端会自动输入一个
    <start>
    作为起始标记。
  3. 解码器依据输入在经过一系列的变化,但是实际情况下,如果训练的不够,
    Output Probabilities
    输出结果很可能不是 "一个",而是其他的,我们就用交叉熵损失函数来计算损失值(量化它的输出与标准答案“一个”的差异),根据这个来调整网络的参数。
  4. Outputs(shift right)
    端会自动输入
    <start> 一个
    ,注意,不是
    <start>
    加上
    Output Probabilities
    输出的不标准的答案,而是标准答案,这个方法叫
    Teacher forcing
    ,试想如果第一个输出就不对,用错的结果继续生成的也只能是错误的结果,最后随着训练的继续只能越错越多,十分不利于模型的收敛,因此我们的输入端是要求输入标准答案的。也正是因为有了这种机制,我们让模型去预测
    一个
    的同时,也能让模型去预测
    红色的
    ,因为训练过程中的输入不依赖上一步的输出,这也就为并行计算提供了可能。
  5. 一直重复3,4步骤直至句子结束

但是有个问题,假如我们现在需要
<start> 一个
,来看模型预测的结果与
红色的
的差距,我们该怎么从标准答案里把
一个
选出来呢,毕竟我们给模型的数据是整个句子
一个 红色的 苹果

我们先将整个句子经过Embedding之后传入Masked Multi-Head Attention块,再计算 $ Q \cdot K^{T} $后得到的矩阵做一个遮盖的处理

<start> 一个 红色的 苹果
<start> 0.36 -inf -inf -inf
一个 -0.28 0.13 -inf -inf
红色的 -0.9 0.42 1.17 -inf
苹果 -0.3 0.17 0.5 0.25

这样在生成注意力矩阵时,经过softmax时权重几乎会变为0,就不会考虑后面的内容了。

其余的内容与Multi-Head Attention一模一样。

接下来后面的内容在了解玩编码器后就很好理解了。

数据经过Masked Multi-Head Attention后经过一个Add & Norm,之后的结构可以看到是和编码器一模一样的,唯一的区别就是输入

它需要三个输入,分别是
Q、K、V
,其中
K、V
来自编码器最终的输出,
Q
来自刚刚处理完成的数据,经过编码器一模一样的操作之后得到最终输出。

输出

将最后得到的数据首先经过一次线性变换,然后Softmax得到输出的概率分布,然后通过词典,输出概率最大的对应的单词作为我们的预测输出。

IP重组

ip重组这部分 4.19内核与3.10内核有些差别,4.9.134以后内核中不使用低水位和工作队列了,同时使用了rhashtable 替代了 hash bucket的概念,在3.10内核中使用1024个hash bucket, 每个bucket中最多存放128个分片队列,在4.19内核中所有的分片队列都保存在可动态调整的rhashtable 中,同时不再使用低水位和工作队列对ip 分片进行回收

4.19内核中,在内存中会分配一个reassembly buffer用于IP分片的重组。同时,也定义了一系列的参数用于控制IP分片处理过程:
net.ipv4.ipfrag_high_thresh
: 用于IP分片重组的最大内存用量(默认为4194304 ,即4Mb)。
net.ipv4.ipfrag_time
: IP分片在内存中的保留时间(默认30,单位:秒)。
对应上述网络协议栈的内核参数,内核层定义了结构体netns_frags,包含分片重组功能需要的全局控制信息,其定义如下:

struct netns_frags {
struct percpu_counter   mem ____cacheline_aligned_in_smp;
        /* sysctls */
        int                     timeout;
        int                     high_thresh;
        int                     low_thresh;
int			max_dist;
struct inet_frags	*f;
        struct rhashtable       rhashtable ____cacheline_aligned_in_smp;
atomic_long_t		mem ____cacheline_aligned_in_smp;
};

其中rhashtable为分片队列(inet_frag_queue)所在的hash表,IP分片包在内核中根据IP报头的4个字段计算得到一个hash值(key值),每个hash值对应一个分片队列,在实现分片包重组功能时,IP层需要先缓存收到的所有分片包,等待同一个IP报文的所有分片包都到达后,把它们重组成一个大包再提交给L4(TCP/UDP... ...)协议。
当收到新的ip分片包时,将查找是否存在同一数据包的分片队列。首先检查当前内存中所有待重组分片包占用的内存(frag_mem_limit)是否高于高水位(net.ipv4.ipfrag_high_thresh),如果高于则丢弃分片包;否则接着对接收到的分片包与rhashtable表中缓存的分片队列进行匹配(即从rhashtable表查找分片队列)将属于同一数据包的分片包放在同一个分片队列中,如果一个数据包的所有分片包都接收完成,那么将进入数据包的重构流程;如果匹配失败,说明该分片属于一个新的数据包,那么进入分片队列新建流程。分片队列的接收查找函数inet_frag_find定义如下:

struct inet_frag_queue *inet_frag_find(struct netns_frags *nf, void *key)
{
    struct inet_frag_queue *fq = NULL, *prev;

     //①高水位判断
    if (!nf->high_thresh || frag_mem_limit(nf) > nf->high_thresh) 
        return NULL;

    rcu_read_lock();
    prev = rhashtable_lookup(&nf->rhashtable, key, nf->f->rhash_params); //② 查找rhashtable中的分片队列
    if (!prev)
        fq = inet_frag_create(nf, key, &prev); //③ 创建新分片队列

    if (prev && !IS_ERR(prev)) {
        fq = prev;
        if (!refcount_inc_not_zero(&fq->refcnt))
            fq = NULL;
    }   
    rcu_read_unlock();
    return fq; 
}

在分片队列的新建流程中,将从slab中分配一段空间,相应增加分片包占用的内存,同时设置定时器(超时时常为30秒)用来检查重组结果,如果定时器超时未重组成功,该分片包也将丢弃。分片包的新建函数inet_frag_alloc定义如下:

static struct inet_frag_queue *inet_frag_alloc(struct netns_frags *nf,
                                               struct inet_frags *f,
                                               void *arg)
{
        struct inet_frag_queue *q;
       
        q = kmem_cache_zalloc(f->frags_cachep, GFP_ATOMIC);
        if (!q)
                return NULL;
       ... ...
       add_frag_mem_limit(nf, f->qsize);          //①增加分片报文占用内存

       setup_timer(&q->timer,                     //②设置超时定时器
f->frag_expire, (unsigned long)q);        
        ... ...
        return q;
}

int ip_defrag(struct net *net, struct sk_buff *skb, u32 user)
{
	... ...

	qp = ip_find(net, ip_hdr(skb), user, vif); //①查找分片队列
	if (qp) {
... ...
		ret = ip_frag_queue(qp, skb); //②分片队列入队操作
    ... ...
		return ret;
	}

	kfree_skb(skb);
	return -ENOMEM;
}

如果一个数据包的所有分片包都已接收,则需将所有分片包整合获得原始数据包,并将整合后的数据包提交给高层协议。同时,处理与分片包相关的数据结构,譬如更新当前分片包占用的内存(frag_mem_limit),停止与分片包相关的定时器等。数据包的重构函数ip_frag_reasm定义如下:

static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb,
             struct sk_buff *prev_tail, struct net_device *dev)
{
     ... ...
     ipq_kill(qp);                   //①减少分片包引用计数
     ... ...
     sub_frag_mem_limit(qp->q.net,   //②减少分片包占用内存
head->truesize);
     ... ...
}

所以,一个分片包的接收通常经历了查找分片、缓存、重组、释放等阶段,下图是分片包的接收流程。
image

图1 4.19分片包接收流程

根据分析,内核中待重组的分片包占用内存量由高水位(net.ipv4.ipfrag_high_thresh)阈值和分片保留时间(net.ipv4.ipfrag_time)来控制,如果待重组分片包内存占用高于高水位(high_thresh),那么新收到的数据包分片将会直接丢弃, 如果分片包超过最大保留时间(ipfrag_time),那么已经收到的数据包也会被丢弃。

附3.10 ip重组

image