2024年2月

前言

  1. 关于Transformer原理与论文的介绍:
    详细了解Transformer:Attention Is All You Need

对于论文给出的模型架构,使用 PyTorch 分别实现各个部分。

引入的相关库函数:

import copy
import torch
import math
from torch import nn
from torch.nn.functional import log_softmax

# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
    return [copy.deepcopy(module) for _ in range(n)]

1. 编码器与解码器堆叠

Encoder 编码器

编码器由 N 个相同的编码层堆叠而成,每个编码层含两个子层:多头注意力层和前馈网络层。每个子层后跟着一层,用于残差连接与标准化。

Add & Norm 残差连接和标准化

对于上一层的结果:
\({\rm SubLayer}(x)\)
与输出上一层的变量:
\(x\)
做残差连接并进行标准化:
\({\rm LayerNorm}(x + {\rm Sublayer}(x))\)

# 层标准化
class LayerNorm(nn.Module):
    # 设置 features 形状的张量作为可学习的参数,初始化
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # 初始化两个参数,α为权重,β为偏置
        self.a_2 = nn.Parameter(torch.ones(features))  
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # 计算最后一个维度的均值、方差
        mean = x.mean(-1, keepdim=True)  
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

# 子层残差连接
class SublayerConnection(nn.Module):
    # size: 参数矩阵的shape, 
    # dropout_prob: dropout概率
    def __init__(self, size, dropout_prob):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
  • nn.Dropout()
    初始化参数
    p
    表示训练时,以概率 p 将输入张量的一些元素归零,对于没有归零的元素将乘以
    \(\frac{1}{1-p}\)
  • 输入为任意形状的张量,输出为与输入张量形状相同并经过处理的张量。[
    Source
    ]

Multi-Head Attention 多头注意力层

计算点乘注意力:$ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V$

# q, k, v: 表示公式中的 Q, K, V
# mask: 当输入存在掩码时,将 mask 对应位置设置为负无穷
# dropout: dropout层
# return: 注意力层的输出,以及注意力权重
def attention(q, k, v, mask=None, dropout=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    
    return torch.matmul(p_attn, v), p_attn 

# 多头注意力
class MultiHeadedAttention(nn.Module):
    # h: 多头注意力的头数
    # d_model: 嵌入词的维度
    def __init__(self, h, d_model, dropout_prob=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, q, k, v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1) # 相同的mask应用于所有的注意力头h
        batch_size = q.size(0)

        # 1) 执行线性变换,将 d_model 维度的 x 分割成 h 个 d_k 维度
        q, k, v = [
            # 通过 view 改变张量形状,并使用 transpose 方法交换张量维度
            lin(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (q, k, v))
        ]

        # 2) 将 attention 用于每个 batch 的投影向量上
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # 3) 通过线性层连接多头注意力计算完的向量
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.linears[-1](x)

关于
contiguous()

transpose()
不改变张量物理上的存储顺序,而是改变了查看时逻辑上的顺序,使得在内存上不连续(可以通过
is_contiguous()
查看张量是否是连续的)。

如果不是连续的,可以通过
contiguous()
方法返回内存上连续、数值上相同的张量。
view()
方法改变张量的形状需要张量是连续的。[
Source
]

Feed Forward 前馈网络层

由两个线性层组成,中间使用 ReLU 激活函数:
\(\mathrm{FFN}(x)=\max(0, xW_1 + b_1) W_2 + b_2\)

# 基于位置的前馈网络
class PositionwiseFeedForward(nn.Module):
    # d_model: 嵌入词的维度
    # d_ff: 前馈网络中间层的维度
    def __init__(self, d_model, d_ff, dropout_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))

编码层

每个编码层,含一个多头注意力层,一个前馈网络层,以及两个用于残差连接与标准化层分别跟在两个子层后面。N 个编码层组成编码器,每层的编码层的输出作为下一层的输入。

# 编码层
class EncoderLayer(nn.Module):
    # size: 参数矩阵的shape,
    # self_attn: 多头注意力层
    # feed_forward: 前馈网络层
    # dropout_prob: dropout概率
    def __init__(self, size, self_attn, feed_forward, dropout_prob):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout_prob), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda i: self.self_attn(i, i, i, mask))
        return self.sublayer[1](x, self.feed_forward)

# 编码器:由 N 个相同的层组成
class Encoder(nn.Module):
    def __init__(self, layer, n):
        super(Encoder, self).__init__()
        self.layers = clones(layer, n)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

EncoderLayer

forward()
内的
x = self.sublayer[0](x, lambda i: self.self_attn(i, i, i, mask))
,虽然此处输入的 q,k,v 均为 i 但在注意力层内,它们将分别与对应的 Q,K,V 矩阵(由线性层Linear实现)相乘,得到用于计算注意力的 q,k,v 。

Decoder 解码器

解码器由 N 层解码层组成。结构与编码层相似,由三个子层组成:带掩码的多头注意力层,多头注意力层和前馈网络层。每个子层后跟着一层,用于残差连接与标准化。

对于第二个子层,输入每一解码层的 K,V 为Encoder(第 N 层的编码层)的输出。为了区别输入Encoder和Decoder的嵌入词,分别用 src(Source,源) 和 tgt(Target,目标) 表示。

# 解码层:由多头注意力层、源-目标注意力层和前馈神经网络组成
class DecoderLayer(nn.Module):
    # size: 参数矩阵的shape,
    # self_attn: 多头注意力层
    # src_attn: 源-目标注意力层
    # feed_forward: 前馈网络层
    # dropout_prob: dropout概率
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout_prob):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout_prob), 3)

    # x: 解码曾输入
    # memory: 编码器的输出
    # src_mask: 源嵌入词掩码
    # tgt_mask: 目标嵌入词掩码
    # return: 解码层的输出
    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda i: self.self_attn(i, i, i, tgt_mask))
        x = self.sublayer[1](x, lambda i: self.src_attn(i, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)


# 解码器:由 N 个相同的层组成
class Decoder(nn.Module):
    def __init__(self, layer, n):
        super(Decoder, self).__init__()
        self.layers = clones(layer, n)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

2. Generator 生成器

生成器将解码器的输出映射到词汇表上,由一个线性层和一个 softmax 层组成,用于预测下一个token的概率。

# 生成器:线性层和 softmax 层
class Generator(nn.Module):
    # d_model: 解码器输出的(嵌入词)向量维度
    # vocab: 词汇表的维度大小
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return log_softmax(self.proj(x), dim=-1)  # 对最后一个维度进行 softmax

3. Embedding 嵌入层

使用
nn.Embedding
构建查找表(Look-Up Table, LUT)。[
Source
]

  • 初始化时,
    num_embedding
    表示嵌入字典大小;
    embedding_dim
    表示每个嵌入词向量的维度大小。

  • forward()
    中使用时,输入维度为
    \(d\)
    的张量,返回维度为
    \(d\times {\rm embedding\_dim}\)
    的张量。

文中,作者还将嵌入层返回的张量乘以
\(\sqrt{d_{model}}\)

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

4. Positional Encoding 位置编码

为了使模型学习文本的顺序信息,需要引入位置编码:

\[\begin{cases}
PE_{(pos,2i)} = \sin(pos / 10000^{2i/d_{\text{model}}}) \\
PE_{(pos,2i+1)} = \cos(pos / 10000^{2i/d_{\text{model}}})
\end{cases}
\]

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout_prob, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout_prob)

        # 计算位置编码
        pe = torch.zeros(max_len, d_model)  # Shape: max_len x d_model
        position = torch.arange(0, max_len).unsqueeze(1)  # Shape: max_len x 1
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
        res = position * div_term  # Shape: max_len x d_model/2
        pe[:, 0::2] = torch.sin(res)
        pe[:, 1::2] = torch.cos(res)
        pe = pe.unsqueeze(0)  # Shape: 1 x max_len x d_model
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)

self.register_buffer()
用于将模型训练参数之外的变量注册加缓存,通过register_buffer()登记过的张量,会自动成为模型中的参数,随着模型移动(gpu/cpu)而移动,但是不会随着梯度进行更新。

在PyTorch中,对于梯度更新的需求,有着不同的张量定义方式[2]。

5. 整体架构

class EncoderDecoder(nn.Module):
    # encoder: 编码器
    # decoder: 解码器
    # src_embed: 源嵌入层
    # tgt_embed: 目标嵌入层
    # generator: 生成器
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    # src: 源语言句子
    # src_mask: 源语言句子掩码
    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)  # 编码器

    # memory: 编码器的输出
    # src_mask: 源语言句子掩码
    # tgt: 目标语言句子
    # tgt_mask: 目标语言句子掩码
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        memory = self.encode(src, src_mask)
        res_dec = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(res_dec)


# src_vocab: 源语言词典大小
# tgt_vocab: 目标语言词典大小
# n: 编码器和解码器的层数
# d_model: 嵌入词的维度
# d_ff: 前馈网络中间层的维度
# h: 多头注意力的头数
# dropout_prb: dropout概率
# return: Transformer 模型
def make_model(src_vocab, tgt_vocab, n=6, d_model=512, d_ff=2048, h=8, dropout_prb=0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout_prb)
    position = PositionalEncoding(d_model, dropout_prb)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout_prb), n),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout_prb), n),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab),
    )
    # 初始化参数
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

参考文献

  1. The Annotated Transformer
  2. 实测!PyTorch 中 nn.Parameter、register_buffer 和直接把 Tensor 当属性有啥区别?

Java线程池实现多任务并发执行

1️⃣ 创建一些任务来落地多任务并发执行

每一个数组里面的数据可以看成任务,或者是需要并发的业务接口,
数组与数组之间,可以看作为他们之间有血缘关系,简单来说就是:
taskJksj里面的10个任务执行完之后,才可以执行taskJxdx里面的4个任务,执行完
taskJxdx之后,才可以执行taskNbzz里面的2个任务

2️⃣ 创建线程池

要将taskJksj、taskJxdx、taskNbzz这几个数组中里面定义的任务通过线程池并发执行

3️⃣ ThreadPoolExecutor源码分析以及为什么不用newFixedThreadPool()和newCachedThreadPool()

1.首先为什么不用
newFixedThreadPool()和newCachedThreadPool()

点进去查看这两个方法的源码

2.为什么用ThreadPoolExecutor创建

通过上面两个例子就能看到,这俩方法很不靠谱,如果你不明白他的原理,看到项目上以前创建线程的代码就是这样的,你想都不想就copy过来,那后面绝对就是在给自己挖坑;
通过发现这俩方法,他们都是return new ThreadPoolExecutor(),所以真正的大佬其实是
ThreadPoolExecutor,他俩只是调用了
ThreadPoolExecutor而已。

  • 核心线程数:初始定义的线程数量,是绝对会开启的固定的线程数量
  • 最大线程数:当前线程池支持的最大线程数量,如果超过了这个数量那么肯定就报错了
  • 阻塞队列   :当前进来线程池的线程大于核心线程数且小于最大线程数,那么就把当前线程池的线程-核心线程数的线程放在阻塞队列里,让他等着
    假如核心线程数 2 个,最大线程数5个,阻塞队列长度  3,当前进来了4个线程,那么就将 4 - 2 = 2 个线程放在阻塞队列里面,让他先等待
  • 默认工厂  :当前进来线程池的线程大于最大线程数且小于(最大线程数+阻塞队列长度),那么就需要开放剩下的三个线程通道,让另外的3个线程通道进行工作,
    核心线程数 2 个,最大线程数5个,    阻塞队列长度3,当前进来了8个,可以看到进来了8个线程,已经满足了最大线程和阻塞队列长度之和了,简单理解就是
    现在进来的线程把目前这个线程池所有能利用的空间都占满了,只有 2个线程工作不够,需要把另外的3个(5 - 2)赶快放开让他们也工作,这个打开另外
    三个线程的这个工作就需要工厂来做,让工厂把这三个线程打开
  • 拒绝策略 :当进入线程池的线程过多,远远超过了最大线程数+阻塞队列,那么就需要拒绝这些即将要进入线程池的线程。
  • 等待时间:在等待时间段中,当线程池里面的线程都执行的差不多了,又回到了"
    进来线程池的线程大于核心线程数且小于最大线程数"时,就没有必要把5个线程通道全部打开,浪费资源,所以就把
    其    他的三个线程关掉,留2个核心的就行
  • 等待时间单位:单位,时分秒

4️⃣ 执行任务

为三个任务编写对应的执行多线程方法,写法都是一样的,重复copy即可,最后执行的效果就是

import java.util.concurrent.*;/*** @Author : YuanXin
* @create 2024/2/1 11:11
* @Description :
*/ public classMain {public static voidmain(String[] args) {

taskListImpl taskList
= newtaskListImpl();

String taskJksj
=taskList.poolExecutorJksj();

String taskJxdx
= null;if (taskJksj.equals("taskJksjSuccess")) {

taskJxdx
=taskList.poolExecutorJxdx();

}
if (taskJxdx.equals("taskJxdxSuccess")) {

taskList.poolExecutorNbzz();

}

}
}
classtaskListImpl {//创建一些任务 int[] taskJksj = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 0};int[] taskJxdx = new int[]{15, 16, 17, 18};int[] taskNbzz = new int[]{101, 102};publicString poolExecutorJksj() {

ThreadPoolExecutor pool
= newThreadPoolExecutor(3,10,3,
TimeUnit.SECONDS,
new LinkedBlockingDeque<>(5),
Executors.defaultThreadFactory(),
newThreadPoolExecutor.AbortPolicy()
);
//ExecutorService cachedThreadPool = Executors.newCachedThreadPool();//ExecutorService fixedThreadPool = Executors.newFixedThreadPool(5); try{for (int i = 0; i < taskJksj.length; i++) {int num =i;


pool.execute(()
->{

taskJksjPool(num);

});

}
}
catch(Exception e) {throw newRuntimeException(e);
}
finally{
pool.shutdown();
}
return "taskJksjSuccess";

}
public void taskJksjPool(intnum) {

System.out.println(Thread.currentThread().getName()
+ " " +taskJksj[num]);

}
publicString poolExecutorJxdx() {

ThreadPoolExecutor pool
= new ThreadPoolExecutor(3, 10, 3, TimeUnit.SECONDS, new LinkedBlockingDeque<>(5), Executors.defaultThreadFactory(), newThreadPoolExecutor.AbortPolicy());try{for (int i = 0; i < taskJxdx.length; i++) {int num =i;

pool.execute(()
->{

taskJxdxPool(num);

});

}
}
catch(Exception e) {throw newRuntimeException(e);
}
finally{
pool.shutdown();
}
return "taskJxdxSuccess";

}
public void taskJxdxPool(intnum) {

System.out.println(Thread.currentThread().getName()
+ " " +taskJxdx[num]);

}
publicString poolExecutorNbzz() {

ThreadPoolExecutor pool
= new ThreadPoolExecutor(3, 10, 3, TimeUnit.SECONDS, new LinkedBlockingDeque<>(5), Executors.defaultThreadFactory(), newThreadPoolExecutor.AbortPolicy());try{for (int i = 0; i < taskNbzz.length; i++) {int num =i;

pool.execute(()
->{

taskNbzzPool(num);

});

}
}
catch(Exception e) {throw newRuntimeException(e);
}
finally{
pool.shutdown();
}
return "taskNbzzSuccess";

}
public void taskNbzzPool(intnum) {

System.out.println(Thread.currentThread().getName()
+ " " +taskNbzz[num]);

}

}

命令行程序是平时写一些小工具时最常用的方式。

为了让命令行程序更加灵活,我们常常会设置一些参数,根据参数让程序执行不同的功能。
这样就不用频繁的修改代码来执行不同的功能。

随着命令行程序功能的丰富,也就是参数多了以后,解析和管理参数之间的关系会变得越来越繁重。
而本次介绍的
Fire
库,正好可以解决这个问题。
使用
Fire
库,我们只要关心具体功能的实现,最后
Fire
会帮助我们自动把所有功能组织成一个命令行程序。

Fire
库在github上的地址:
https://github.com/google/python-fire

1. 一般命令

一般的命令,也就是带有几个参数的一段程序,比如:

# -*- coding:utf-8 -*-

def import_file(fp):
    print("import file from: {}".format(fp))

def export_file(fp):
    print("EXPORT file to: {}".format(fp))

这是模拟文件导出功能的两个函数。

使用
Fire
转换成命令行程序非常简单,下面介绍几种常用的方式。

1.1. 默认方式

# -*- coding:utf-8 -*-

import fire

def import_file(fp):
    print("IMPORT file from: {}".format(fp))

def export_file(fp):
    print("EXPORT file to: {}".format(fp))

if __name__ == "__main__":
    # fire默认会将所有函数转换成子命令
    fire.Fire()

然后,就可以通过子命令的方式执行导入导出功能。

$ python main.py import_file --fp ./path/xxx.csv
IMPORT file from: ./path/xxx.csv

$ python main.py export_file --fp ./path/xxx.csv
EXPORT file to: ./path/xxx.csv

函数的名称
自动变为
子命令的名称

函数的参数
自动变成
子命令的参数

1.2. Fire<fn>

Fire
库的默认方式会把所有的函数都转换为子命令,
如果只想导出一个函数的话,可以用
Fire<fn>
的方式。

if __name__ == "__main__":
    # 只导出 import_file 函数作为子命令
    fire.Fire(import_file)

只导出一个函数的时候,执行命令的时候不需要输入子命令的名称。

$ python main.py --fp ./path/xxx.csv
IMPORT file from: ./path/xxx.csv

1.3. Fire<dict>

导出多个函数作为子命令时,默认是使用函数名作为子命令名称的,函数名称有时候会非常长,输入很麻烦。
这时,可以用
Fire<dict>
方式。

if __name__ == "__main__":
    # 子命令的名称分别是:import 和 export
    fire.Fire({
        "import": import_file,
        "export": export_file,
    })

执行时,使用简化的子命令名称。

$ python main.py import --fp ./path/xxx.csv
IMPORT file from: ./path/xxx.csv

$ python main.py export --fp ./path/xxx.csv
EXPORT file to: ./path/xxx.csv

这种方式非常灵活,不仅可以设置子命令名称,还可以控制需要导出哪些函数。

1.4. Fire<object>

除了导出函数,
Fire<object>
方式也可以导出对象的公有方法。

import fire

class FileHandler():

    def __init__(self):
        pass

    def import_file(self, fp):
        print("IMPORT file from: {}".format(fp))

    def export_file(self, fp):
        print("EXPORT file to: {}".format(fp))

if __name__ == "__main__":
    fh = FileHandler()
    fire.Fire(fh)

使用方式如下:

$ python main.py import_file --fp ./path/xxx.csv
IMPORT file from: ./path/xxx.csv

$ python main.py export_file --fp ./path/xxx.csv
EXPORT file to: ./path/xxx.csv

使用对象的方式没有直接使用函数那么简单,但有个好处是可以在初始化时传入一些状态。

import fire
import os

class FileHandler():

    def __init__(self, folder=""):
        self.folder = folder

    def import_file(self, fp):
        print("IMPORT file from: {}".format(os.path.join(self.folder, fp)))

    def export_file(self, fp):
        print("EXPORT file to: {}".format(os.path.join(self.folder, fp)))

if __name__ == "__main__":
    # 设置了默认文件夹,使用时直接传入文件名即可
    fh = FileHandler("./default_path")
    fire.Fire(fh)
$ python main.py import_file --fp xxx.csv
IMPORT file from: ./default_path/xxx.csv

$ python main.py export_file --fp xxx.csv
EXPORT file to: ./default_path/xxx.csv

1.5. Fire<class>

Fire<class>
的方式也可以直接作用在类上,不用初始化对象。

if __name__ == "__main__":
    fire.Fire(FileHandler)


Fire<object>
不同的是,
__init__
方法的参数也变成了命令的参数,也可以在命令行中调整了。

$ python main.py import_file --fp xxx.csv
IMPORT file from: xxx.csv

$ python main.py import_file --fp xxx.csv --folder ./my_folder
IMPORT file from: ./my_folder/xxx.csv

2. 组合命令

当功能越来越多时,可能就会需要组合一些功能一起运行,省得输入一个一个的子命令。

class FileHandler():

    def __init__(self, folder="./defalut_dir"):
        self.folder = folder

    def import_file(self, fp):
        print("IMPORT file from: {}".format(os.path.join(self.folder, fp)))

    def export_file(self, fp):
        print("EXPORT file to: {}".format(os.path.join(self.folder, fp)))

class DatabaseHandler():

    def __init__(self, src="aliyun-mysql", dst="tecent-mysql"):
        self.src = src
        self.dst = dst

    def import_db(self):
        print("IMPORT data from: {} to: {}".format(self.src, self.dst))

    def export_db(self):
        print("EXPORT data from: {} to: {}".format(self.src, self.dst))

# 组合 FileHandler 和 DatabaseHandler
class ComposeHandler():
    def __init__(self):
        self.fh = FileHandler()
        self.dh = DatabaseHandler()

    def import_all(self, fp):
        self.fh.import_file(fp)
        self.dh.import_db()


    def export_all(self, fp):
        self.fh.export_file(fp)
        self.dh.export_db()

if __name__ == "__main__":
    fire.Fire(ComposeHandler)

导出组合命令之后,不仅可以执行组合命令,也可以只执行子命令。

$ python main.py import_all --fp xxx.csv
IMPORT file from: ./defalut_dir/xxx.csv
IMPORT data from: aliyun-mysql to: tecent-mysql

$ python main.py export_all --fp xxx.csv
EXPORT file to: ./defalut_dir/xxx.csv
EXPORT data from: aliyun-mysql to: tecent-mysql

$ python main.py fh export_file --fp xxx.csv
EXPORT file to: ./defalut_dir/xxx.csv

$ python main.py dh export_db
EXPORT data from: aliyun-mysql to: tecent-mysql

3. 链式命令

链式命令和组合命令不一样的地方在于:
组合命令中,每个命令之间一般是相互独立的,
而链式命令中,上一个命令的执行结果会对下一个命令造成影响。
比如:

class Stat():
    def __init__(self):
        self.total = 0
        self.avg = 0
        self.n = 0

    # 模拟统计合计值
    def sum(self, n):
        self.n += n
        for i in range(n):
            self.total += i

        return self

    # 模拟求平均值
    def average(self):
        if self.n == 0:
            self.avg = 0
        else:
            self.avg = self.total / self.n

        return self

    # 显示分析结果
    def show(self):
        print("SUM: {}, and AVERAGE: {}".format(self.total, self.avg))

if __name__ == "__main__":
    fire.Fire(Stat)

执行链式命令时,可以先求和,再求平均值,最后显示结果:

$ python main.py sum 10 average show
SUM: 45, and AVERAGE: 4.5

因为是链式命令,所以可以多次执行:

$ python main.py sum 10 sum 10 average show
SUM: 90, and AVERAGE: 4.5

$ python main.py sum 10 sum 20 sum 30 average show
SUM: 670, and AVERAGE: 11.166666666666666

4. 复杂命令参数

上面的示例中,参数都是简单的数据类型,比如字符串,数字之类的。

最后,介绍下复杂的命令参数如何使用,所谓复杂的参数,就是元组,列表,字典等等。

def hello(data):
    tp = type(data).__name__
    if  tp == "tuple" or tp == "list":
        for item in data:
            print("hello: {}".format(item))

    if tp == "dict":
        for k, v in data.items():
            print("hello: key {}, val {}".format(k, v))

if __name__ == "__main__":
    fire.Fire(hello)

python
是弱类型语言,函数的参数可以是任何类型。
主要看看命令行中如何传入复杂的类型:

$ python main.py "(aa, bb, cc)"
hello: aa
hello: bb
hello: cc

$ python main.py "[aa, bb, cc]"
hello: aa
hello: bb
hello: cc

$ python main.py "{aa: 11, bb: 22}"
hello: key aa, val 11
hello: key bb, val 22

5. 总结

Python

Fire
库是一个构思非常巧妙的命令行接口库,各种语言的命令行接口我接触过不少,还没有在其他编程语言中看到过类似的库。

Fire
库最方便的地方在于,你不用再关心命令行的参数(参数的解析一直是命令行程序中最让人头疼的地方),只要专注于自己要实现的功能。
此外,如果你已经有一些
python脚本
的话,通过这个库把它们改造成命令行程序也非常简单。

wmproxy

wmproxy
已用
Rust
实现
http/https
代理,
socks5
代理, 反向代理, 负载均衡, 静态文件服务器,
websocket
代理,四层TCP/UDP转发,内网穿透等,会将实现过程分享出来,感兴趣的可以一起造个轮子

项目地址

国内: https://gitee.com/tickbh/wmproxy

github: https://github.com/tickbh/wmproxy

设计目标

负载均衡时通过匹配规则匹配正确的location进行处理相关的操作。

设计方案变更

初始设计方案

初始方案以最快的方式进行支持,仅支持前缀匹配,即如果配置

[[http.server.location]]
rule = "/wmproxy"

那么当我们访问
/wmproxy/xx
时将会被分配到该location,此方案相对简单,但是当我们碰到复杂的需求时将无法被满足。

设计方案需求

除了前缀匹配外,我们将会有其它各种需求的匹配:

  • 后缀匹配
    比如以wmproxy结尾的path,如
    /api/update/wmproxy
    需要匹配成
    *wmproxy
  • 中间匹配
    比如常用的api中间转化成数据
    /api/<user_id>/get
    ,那么匹配为
    /api/*/get
  • 正则匹配
    当前的配置的为正则规则,需进行匹配
  • 请求方法匹配
    比如仅当请求方法为
    POST
    才进行转发
  • 客户端IP
    比如仅当客户端内网或者外网时区分请求
  • Host地址
    比如当前如果请求为ip则不进行转发,需要匹配host才进行转发
  • 协议
    比如某个网站不支持
    http
    当我们匹配到
    http
    时需强制转化成
    https
    实际配置中当仅仅只有前缀匹配时已经显然无法满足我们的需求

设计方案迭代

当前我们就必须将数据进行更迭,但是在通常情况下我们又不想将配置变得复杂,此时就需要我们支持更多的类的自定义化,首先我们定义类:

/// location匹配,将根据该类的匹配信息进行是否匹配
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Matcher {
    path: Option<String>,
    #[serde_as(as = "Option<DisplayFromStr>")]
    client_ip: Option<IpSets>,
    #[serde_as(as = "Option<DisplayFromStr>")]
    remote_ip: Option<IpSets>,
    host: Option<String>,
    #[serde_as(as = "Option<DisplayFromStr>")]
    method: Option<MatchMethod>,
    #[serde_as(as = "Option<DisplayFromStr>")]
    scheme: Option<MatchScheme>,
}

此时我们将location中的rule的类型从String变成了Matcher,那么此时我们首先遇到的一个问题他可能为一个String值或者可能为一个Map值,我们先得对这种情况进行处理。
我们根据serde的提供的解析方案进行如下函数,当前我们重写了
visit_str

visit_map
表示我们将只支持这两种源格式转化成
Matcher

pub fn string_or_struct<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
    T: Deserialize<'de> + FromStr<Err = WebError>,
    D: Deserializer<'de>,
{
    struct StringOrStruct<T>(PhantomData<fn() -> T>);

    impl<'de, T> Visitor<'de> for StringOrStruct<T>
    where
        T: Deserialize<'de> + FromStr<Err = WebError>,
    {
        type Value = T;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("string or map")
        }

        fn visit_str<E>(self, value: &str) -> Result<T, E>
        where
            E: de::Error,
        {
            Ok(FromStr::from_str(value).unwrap())
        }

        fn visit_map<M>(self, map: M) -> Result<T, M::Error>
        where
            M: MapAccess<'de>,
        {
            Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))
        }
    }
    deserializer.deserialize_any(StringOrStruct(PhantomData))
}

其次我们将在location中做处理

/// 负载均衡中的location匹配,将匹配合适的处理逻辑
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocationConfig {
    #[serde(deserialize_with = "string_or_struct")]
    pub rule: Matcher,
    //...
}

由于这种大类的匹配通常会在别处额外定义,我们通过以
@name

@
开头来表示索引的信息,来简化配置。通过初始化的时候来重新初始化
Matcher

处理匹配

我们初始化完Matcher之后,需要能正确的判断传入的数据是否当前能正确匹配。主要的复杂点在于path的匹配,主要为
正则匹配

前缀匹配

中间匹配

后缀匹配

对其进行细分,可确定分为两种

  1. 正则匹配

  2. *
    的路径匹配

    1. 前缀匹配可以看成
      /start*
      或者
      /start
    2. 中间匹配可以看成
      /start*end
    3. 后缀匹配可以看成
      *end

即当前我们只需处理两种匹配模式:

  • 正则匹配
    ,频繁调用时主要在于初始化正则时可能会消耗大量的算力。当前我们对我们的匹配规则的正则进行缓存
/// may memory leak
pub fn try_cache_regex(origin: &str) -> Option<Regex> {
    // 因为均是从配置中读取的数据, 在这里缓存正则表达示会在总量上受到配置的限制
    lazy_static! {
        static ref RE_CACHES: Mutex<HashMap<&'static str, Option<Regex>>> =
            Mutex::new(HashMap::new());
    };

    if origin.len() == 0 {
        return None;
    }

    if let Ok(mut guard) = RE_CACHES.lock() {
        if let Some(re) = guard.get(origin) {
            return re.clone();
        } else {
            if let Ok(re) = Regex::new(origin) {
                guard.insert(
                    Box::leak(origin.to_string().into_boxed_str()),
                    Some(re.clone()),
                );
                return Some(re);
            }
        }
    }
    return None;
}

此处我们用到了static变量,也就是将某部分数据进行了静态化处理,且此处我们将String转化成了
&'static str
可能存在一定的内存泄漏,大小值跟配置的数据有关,可以接受这空间换取时间。然后用正则的is_match进行匹配即可。

if let Some(re) = Helper::try_cache_regex(&p) {
    if !re.is_match(path) {
        return Ok(false);
    }
}

  • *
    的路径匹配

    主要将路径中的*进行前进字符串的匹配。
    在rust中的字符串切割主要由
    split
    或者
    strip_prefix
    或者
    strip_suffix
    来处理,相对其它语言中均存在的
    subString
    或者
    substr
    在rust中的则表示为引用,所以在rust中不存在substring函数
let src = "wmproxy is good";
let first = &src[..7];
let second = &src[3..8];
let end = &src[8..];
let vals = src.split(" ").collect::<Vec<&str>>();

以上各数据均引用src的资源,即在这过程中并没有创建内存对象。
那么匹配函数则先将
'*'
进行分割,数组的第一个则前缀匹配,最后一个则后缀匹配,若不存在
'*'
则数组数量为1,符合前缀匹配。

pub fn is_match(src: &str, pattern: &str) -> bool {
    let mut oper = src;
    let vals = pattern.split("*").collect::<Vec<&str>>();
    for i in 0..vals.len() {
        if i == 0 {
            if let Some(val) = oper.strip_prefix(vals[i]) {
                oper = val;
            } else {
                return false;
            }
        } else if i == vals.len() - 1 {
            if let Some(val) = oper.strip_suffix(vals[i]) {
                oper = val;
            } else {
                return false;
            }
        } else {
            if let Some(idx) = oper.find(vals[i]) {
                oper = &oper[idx + vals[i].len() .. ]
            } else {
                return false;
            }
        }
    }
    true
}

那么完整的匹配函数在
Matcher

/// 当本地限制方法时,优先匹配方法,在进行路径的匹配
pub fn is_match_rule(&self, path: &String, req: &RecvRequest) -> ProtResult<bool>  {
    if let Some(p) = &self.path {
        let mut is_match = false;
        if Helper::is_match(&path, p) {
            is_match = true;
        }
        if !is_match {
            if let Some(re) = Helper::try_cache_regex(&p) {
                if !re.is_match(path) {
                    return Ok(false);
                }
            } else {
                return Ok(false);
            }
        }
    }

    if let Some(m) = &self.method {
        if !m.0.contains(req.method()) {
            return Ok(false);
        }
    }

    if let Some(s) = &self.scheme {
        if !s.0.contains(req.scheme()) {
            return Ok(false);
        }
    }

    if let Some(h) = &self.host {
        match req.get_host() {
            Some(host) if &host == h => {},
            _ => return Ok(false),
        }
    }

    if let Some(c) = &self.client_ip {
        match req.headers().system_get("{client_ip}") {
            Some(ip) => {
                let ip = ip
                .parse::<IpAddr>()
                .map_err(|_| ProtError::Extension("client ip error"))?;
                if !c.contains(&ip) {
                    return Ok(false)
                }
            },
            None => return Ok(false),
        }
    }

    Ok(true)
}

小结

匹配规则在对于复杂匹配的时候尤为重要,我们可以轻松的将各个请求分配到合适的位置,此处我们着重介绍了正则匹配及带
*
的路径匹配。

点击
[关注]

[在看]

[点赞]
是对作者最大的支持

前言

最近有些小伙伴,希望我分享一些好用的工具类,帮他们提升开发效率。

今天这篇文章专门跟大家一起总结一下,Spring框架本身自带的一些好用的工具类,希望对你会有所帮助。

1 Assert

很多时候,我们需要在代码中做判断:如果不满足条件,则抛异常。

有没有统一的封装呢?

其实Spring给我们提供了
Assert
类,它表示断言。

1.1 断言参数是否为空

断言参数是否空,如果不满足条件,则直接抛异常。

String str = null;
Assert.isNull(str, "str必须为空");
Assert.isNull(str, () -> "str必须为空");
Assert.notNull(str, "str不能为空");

如果不满足条件就会抛出IllegalArgumentException异常。

1.2 断言集合是否为空

断言集合是否空,如果不满足条件,则直接抛异常。

List<String> list = null;
Map<String, String> map = null;
Assert.notEmpty(list, "list不能为空");
Assert.notEmpty(list, () -> "list不能为空");
Assert.notEmpty(map, "map不能为空");

如果不满足条件就会抛出IllegalArgumentException异常。

1.3 断言条件是否为空

断言是否满足某个条件,如果不满足条件,则直接抛异常。

List<String> list = null;
Assert.isTrue(CollectionUtils.isNotEmpty(list), "list不能为空");
Assert.isTrue(CollectionUtils.isNotEmpty(list), () -> "list不能为空");

当然Assert类还有一些其他的功能,这里就不多介绍了。

2 StringUtils

在我们日常开发过程中,对字符串的操作是非常频繁的,但JDK提供的对于字符串操作的方法,过于简单,无法满足我们开发中的需求。

其实Spring提供了工具类StringUtils,对JDK中一些字符串的操作进行了扩展。

2.1 判空

StringUtils类其实有个isEmpty()方法判断,不过已经被废弃了。

我们可以改成使用hasLength()方法判断,例如:

if (!StringUtils.hasLength("")) {
  System.out.println("字符串为空");
}

2.2 去掉空格

对于后端的很多接口,经常需要去掉前后空格,我们可以使用String类的trim(),但是如果要同时去掉中间的空格呢?

可以使用StringUtils类的trimAllWhitespace方法。

例如:

@Test
public void testEmpty() {
    System.out.println("1" + StringUtils.trimAllWhitespace(" 苏三说技术 测试 ") + "1");
}

这个方法执行接口:1苏三说技术测试1,会把中间的空格也去掉了。

2.3 判断开头或结尾字符串

要判断一个字符串,是不是以某个固定字符串开头或者结尾,是非常常见的需求。

我们可以使用StringUtils类的startsWithIgnoreCase和endsWithIgnoreCase,可以忽略大小写比较字符串。

例如:

@Test
public void testEmpty() {
    System.out.println(StringUtils.startsWithIgnoreCase("苏三说技术", "苏三"));
    System.out.println(StringUtils.endsWithIgnoreCase("苏三说技术", "技术"));
}

该方法的执行结果会返回两个true。

2.4 集合拼接字符串

有时候我们需要将某个字符串集合的所有元素,拼接成一个字符串,用逗号隔开。

这种场景可以使用StringUtils类的collectionToCommaDelimitedString方法。

例如:

@Test
public void testEmpty() {
    List<String> list = new ArrayList<>();
    list.add("a");
    list.add("b");
    list.add("c");
    System.out.println(StringUtils.collectionToCommaDelimitedString(list));
}

该方法的执行结果:a,b,c

这个工具类里面还有很多有用的方法:

3. CollectionUtils

在我们日常开发当中,经常会遇到集合,比如:list判空的情况。

其实Spring专门为我们提供了,给集合判空的工具类:
CollectionUtils
,它位于org.springframework.util包下。

对于一些简单的集合判断,集合中是否包含某个元素,集合转数组,用这个工具还是非常方便的。

3.1 集合判空

通过CollectionUtils工具类的isEmpty方法可以轻松判断集合是否为空。

例如:

List<Integer> list = new ArrayList<>();
list.add(2);
list.add(1);
list.add(3);

if (CollectionUtils.isEmpty(list)) {
    System.out.println("集合为空");
}

3.2 判断元素是否存在

通过CollectionUtils工具类的contains方法,可以判断元素在集合中是否存在。

例如:

List<Integer> list = new ArrayList<>();
list.add(2);
list.add(1);
list.add(3);

if (CollectionUtils.contains(list.iterator(), 3)) {
    System.out.println("元素存在");
}

在判断时需要先调用集合的iterator()方法。

4 ObjectUtils

Spring为我们专门提供了一个对象操作工具:
ObjectUtils
,也在org.springframework.util包下。

里面有很多非常有用的方法。

4.1 判空

之前已经介绍过字符串判空工具类StringUtils,和集合的判空工具类CollectionUtils。

而ObjectUtils工具的判空更强大,支持:对象、字符串、集合、数组、Optional、Map的判断。

例如:

 @Test
public void testEmpty() {
    String a = "123";
    Integer b = new Integer(1);
    List<String> c = new ArrayList<>();
    Integer[] d = new Integer[]{b};
    c.add(a);
    Map<String, String> e = new HashMap<>();
    e.put(a, a);
    Optional<String> f = Optional.of(a);
    if (!ObjectUtils.isEmpty(a)) {
        System.out.println("a不为空");
    }
    if (!ObjectUtils.isEmpty(b)) {
        System.out.println("b不为空");
    }
    if (!ObjectUtils.isEmpty(c)) {
        System.out.println("c不为空");
    }
    if (!ObjectUtils.isEmpty(d)) {
        System.out.println("d不为空");
    }
    if (!ObjectUtils.isEmpty(e)) {
        System.out.println("e不为空");
    }
    if (!ObjectUtils.isEmpty(f)) {
        System.out.println("f不为空");
    }
}

这6种对象的判空都支持,非常强大。

4.2 判断两个对象相等

之前我们用Objects.equals方法,判断两个对象是否相等,经常会出现空指针问题。

而ObjectUtils类提供了安全的判断两个对象相等的方法:nullSafeEquals。

例如:

@Test
public void testEquals() {
    String a = "123";
    String b = null;
    System.out.println(ObjectUtils.nullSafeEquals(a, b));
}

这个例子返回的是false,不会出现空指针的问题。

甚至可以判断两个数组是否相等。

例如:

@Test
public void testArrayEquals() {
    String[] a = new String[]{"123"};
    String[] b = new String[]{"123"};
    System.out.println(ObjectUtils.nullSafeEquals(a, b));
}

这个例子的执行结果返回的是true。

4.3 获取对象的hashCode

如果想要快速获取某个对象十六进制的hashCode,则可以调用getIdentityHexString方法。

例如:

@Test
public void testIdentityHex() {
    String a = "123";
    System.out.println(ObjectUtils.getIdentityHexString(a));
}

执行结果:2925bf5b

5 ClassUtils

Spring的org.springframework.util包下的
ClassUtils
类,它里面有很多让我们惊喜的功能。

它里面包含了类和对象相关的很多非常实用的方法。

5.1 获取对象的所有接口

如果你想获取某个对象的所有接口,可以使用ClassUtils的getAllInterfaces方法。例如:

Class<?>[] allInterfaces = ClassUtils.getAllInterfaces(new User());

5.2 获取某个类的包名

如果你想获取某个类的包名,可以使用ClassUtils的getPackageName方法。例如:

String packageName = ClassUtils.getPackageName(User.class);
System.out.println(packageName);

5.3 判断某个类是否内部类

如果你想判断某个类是否内部类,可以使用ClassUtils的isInnerClass方法。例如:

System.out.println(ClassUtils.isInnerClass(User.class));

5.4 判断对象是否代理对象

如果你想判断对象是否代理对象,可以使用ClassUtils的isCglibProxy方法。例如:

System.out.println(ClassUtils.isCglibProxy(new User()));

ClassUtils还有很多有用的方法,等待着你去发掘。感兴趣的朋友,可以看看下面内容:

最近就业形式比较困难,为了感谢各位小伙伴对苏三一直以来的支持,我特地创建了一些工作内推群, 看看能不能帮助到大家。

你可以在群里发布招聘信息,也可以内推工作,也可以在群里投递简历找工作,也可以在群里交流面试或者工作的话题。
添加苏三的私人微信:su_san_java,备注:博客园+所在城市,即可加入。

6 BeanUtils

Spring给我们提供了一个JavaBean的工具类,它在org.springframework.beans包下面,它的名字叫做:
BeanUtils

让我们一起看看这个工具可以带给我们哪些惊喜。

6.1 拷贝对象的属性

曾几何时,你有没有这样的需求:把某个对象中的所有属性,都拷贝到另外一个对象中。这时就能使用BeanUtils的copyProperties方法。例如:

User user1 = new User();
user1.setId(1L);
user1.setName("苏三说技术");
user1.setAddress("成都");

User user2 = new User();
BeanUtils.copyProperties(user1, user2);
System.out.println(user2);

6.2 实例化某个类

如果你想通过反射实例化一个类的对象,可以使用BeanUtils的instantiateClass方法。例如:

User user = BeanUtils.instantiateClass(User.class);
System.out.println(user);

6.3 获取指定类的指定方法

如果你想获取某个类的指定方法,可以使用BeanUtils的findDeclaredMethod方法。例如:

Method declaredMethod = BeanUtils.findDeclaredMethod(User.class, "getId");
System.out.println(declaredMethod.getName());

6.4 获取指定方法的参数

如果你想获取某个方法的参数,可以使用BeanUtils的findPropertyForMethod方法。例如:

Method declaredMethod = BeanUtils.findDeclaredMethod(User.class, "getId");
PropertyDescriptor propertyForMethod = BeanUtils.findPropertyForMethod(declaredMethod);
System.out.println(propertyForMethod.getName());

如果你对BeanUtils比较感兴趣,可以看看下面内容:

7 ReflectionUtils

有时候,我们需要在项目中使用反射功能,如果使用最原始的方法来开发,代码量会非常多,而且很麻烦,它需要处理一大堆异常以及访问权限等问题。

好消息是Spring给我们提供了一个
ReflectionUtils
工具,它在org.springframework.util包下面。

7.1 获取方法

如果你想获取某个类的某个方法,可以使用ReflectionUtils类的findMethod方法。例如:

Method method = ReflectionUtils.findMethod(User.class, "getId");

7.2 获取字段

如果你想获取某个类的某个字段,可以使用ReflectionUtils类的findField方法。例如:

Field field = ReflectionUtils.findField(User.class, "id");

7.3 执行方法

如果你想通过反射调用某个方法,传递参数,可以使用ReflectionUtils类的invokeMethod方法。例如:

ReflectionUtils.invokeMethod(method, springContextsUtil.getBean(beanName), param);

7.4 判断字段是否常量

如果你想判断某个字段是否常量,可以使用ReflectionUtils类的isPublicStaticFinal方法。例如:

Field field = ReflectionUtils.findField(User.class, "id");
System.out.println(ReflectionUtils.isPublicStaticFinal(field));

7.5 判断是否equals方法

如果你想判断某个方法是否equals方法,可以使用ReflectionUtils类的isEqualsMethod方法。例如:

Method method = ReflectionUtils.findMethod(User.class, "getId");
System.out.println(ReflectionUtils.isEqualsMethod(method));

当然这个类还有不少有趣的方法,感兴趣的朋友,可以看看下面内容:

最近就业形式比较困难,为了感谢各位小伙伴对苏三一直以来的支持,我特地创建了一些工作内推群, 看看能不能帮助到大家。

你可以在群里发布招聘信息,也可以内推工作,也可以在群里投递简历找工作,也可以在群里交流面试或者工作的话题。
添加苏三的私人微信:su_san_java,备注:博客园+所在城市,即可加入。

8 Base64Utils

有时候,为了安全考虑,需要将参数只用base64编码。

这时就能直接使用org.springframework.util包下的
Base64Utils
工具类。

它里面包含:
encode

decode
方法,用于对数据进行
编码

解码

例如:

String str = "abc";
String encode = new String(Base64Utils.encode(str.getBytes()));
System.out.println("编码后:" + encode);
try {
    String decode = new String(Base64Utils.decode(encode.getBytes()), "utf8");
    System.out.println("解码:" + decode);
} catch (UnsupportedEncodingException e) {
    e.printStackTrace();
}

执行结果:

编码后:YWJj
解码后:abc

9 SerializationUtils

有时候,我们需要把数据进行序列化和反序列化处理。

传统的做法是某个类实现Serializable接口,然后重新它的writeObject和readObject方法。

但如果使用org.springframework.util包下的
SerializationUtils
工具类,能更轻松实现序列化和反序列化功能。

例如:

Map<String, String> map = Maps.newHashMap();
map.put("a", "1");
map.put("b", "2");
map.put("c", "3");
byte[] serialize = SerializationUtils.serialize(map);
Object deserialize = SerializationUtils.deserialize(serialize);
System.out.println(deserialize);

10 HttpStatus

很多时候,我们会在代码中定义http的返回码,比如:接口正常返回200,异常返回500,接口找不到返回404,接口不可用返回502等。

private int SUCCESS_CODE = 200;
private int ERROR_CODE = 500;
private int NOT_FOUND_CODE = 404;

其实org.springframework.http包下的
HttpStatus
枚举,或者org.apache.http包下的HttpStatus接口,已经把常用的http返回码给我们定义好了,直接拿来用就可以了,真的不用再重复定义了。

11 HtmlUtils

有时候,用户输入的内容中包含了一些特殊的标签,比如<,如果不错处理程序可能会报错。

而且为了安全性,对用户输入的特色字符,也需要做转义,防止一些SQL注入,或者XSS攻击等。

其实Spring给我们提供了一个专门处理html的工具:HtmlUtils,我们可以直接用它来做转义,使用起来非常方便。

例如:

@Test
public void testHtml() {
    String specialStr = "<div id=\"testDiv\">test1;test2</div>";
    String str1 = HtmlUtils.htmlEscape(specialStr);
    System.out.println(str1);
}

执行结果:
&lt;div id=&quot;testDiv&quot;&gt;test1;test2&lt;/div&gt
;

最后说一句(求关注,别白嫖我)

如果这篇文章对您有所帮助,或者有所启发的话,帮忙扫描下发二维码关注一下,您的支持是我坚持写作最大的动力。
求一键三连:点赞、转发、在看。
关注公众号:【苏三说技术】,在公众号中回复:面试、代码神器、开发手册、时间管理有超赞的粉丝福利,另外回复:加群,可以跟很多BAT大厂的前辈交流和学习。