2024年3月

技术背景

在前面的几篇
博客
中,我们介绍了MindSpore框架下使用CUDA来定义本地算子的基本方法,以及配合反向传播函数的使用,这里主要探讨一下MindSpore框架对于CUDA本地算子的输入输出的规范化形式。

测试思路

MindSpore使用的CUDA算子规范化接口形式为:

extern "C" int CustomOps(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                         void *stream, void *extra)

也就是说,我们在一个
.cu
文件中按照这种形式写好函数接口,其中主要是规范化输入输出的形式,然后再将各项输入传给写好的CUDA Kernel函数进行计算并获得返回值。
我们可以使用一个Kernel打印函数的测试案例来说明MindSpore对于输入输出的处理:

#include <iostream>
#define THREADS 1024

__global__ void OpsKernel(const int shape0, const int *input){
    auto i = blockIdx.x * THREADS + threadIdx.x;
    if (i < shape0){
        printf("%d\n", input[i]);
    }
}

在这个函数体内,会把指定大小范围内的input的内容打印出来。

常数输入

首先我们来看一下最简单的常数输入,可以用一个最简单的整数来测试,对应的CUDA算子代码为:

// nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu

#include <iostream>
#define THREADS 1024

__global__ void OpsKernel(const int shape0, const int *input){
    auto i = blockIdx.x * THREADS + threadIdx.x;
    if (i < shape0){
        printf("%d\n", input[i]);
    }
}

extern "C" int CustomOps(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                         void *stream, void *extra){
    int *input = static_cast<int*>(params[0]);
    OpsKernel<<<1, THREADS>>>(shapes[0][0], input);
    return 0;
}

调用CUDA算子的Python代码为:

import os
import numpy as np
import mindspore as ms
from mindspore import ops, Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

CURRENT_PATH = os.path.abspath(__file__)
CustomOps = ops.Custom(CURRENT_PATH.replace(".py", ".so:CustomOps"),
                       out_shape=lambda x:x,
                       out_dtype=ms.int32,
                       func_type="aot")
T0 = Tensor([7], ms.int32)
print (T0)
CustomOps(T0)

运行的指令为:

$ nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu && python3 test_shape.py 
[7]
7

需要注意的是,这里只能给MindSpore内置的几种Tensor变量,如果是直接调用
CustomOps(7)
会报一个段错误。

高维张量输入

这里一维的张量输入我们就不做讨论了,因为跟前面用到的常数输入本质上是一样的形式。这里我们用一个二维的张量来做一个测试,CUDA代码保持不动,只修改Python代码中的输入:

import os
import numpy as np
import mindspore as ms
from mindspore import ops, Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

CURRENT_PATH = os.path.abspath(__file__)
CustomOps = ops.Custom(CURRENT_PATH.replace(".py", ".so:CustomOps"),
                       out_shape=lambda x:x,
                       out_dtype=ms.int32,
                       func_type="aot")
T0 = Tensor(np.arange(12).reshape((4, 3)), ms.int32)
print (T0)
CustomOps(T0)

运行结果为:

$ nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu && python3 test_shape.py 
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
0
1
2
3

需要注意的是,我们在CUDA的打印函数中设置的打印输出大小是输入张量的第一个维度的大小,我们给的是一个(4,3)大小的张量,因此会顺序打印4个数出来。这里我们也能够发现MindSpore在进行输入的规范化的时候,会自动压平输入的张量变成一个维度。因此这里的调用代码等价于先对输入张量做一个reshape,然后再把第一个维度对应大小的张量元素打印出来。如果要打印所有的元素也很简单,可以修改一下CUDA代码:

// nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu

#include <iostream>
#define THREADS 1024

__global__ void OpsKernel(const int shape0, const int *input){
    auto i = blockIdx.x * THREADS + threadIdx.x;
    if (i < shape0){
        printf("%d\n", input[i]);
    }
}

extern "C" int CustomOps(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                         void *stream, void *extra){
    int *input = static_cast<int*>(params[0]);
    int elements = 1;
    for (int i=0; i<ndims[0]; i++){
        elements *= shapes[0][i];
    }
    OpsKernel<<<1, THREADS>>>(elements, input);
    return 0;
}

通过定义一个elements变量用于存储对应张量的元素数量,然后再逐一打印出来即可,执行结果为:

$ nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu && python3 test_shape.py 
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
0
1
2
3
4
5
6
7
8
9
10
11

输出规范化

当我们使用
ops.Custom
算子时,如果指定了out_dtype和out_shape,那么算子会自动帮我们分配好相应的device memory空间。那么我们在CUDA计算的时候可以直接修改对应的内存空间:

// nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu

#include <iostream>
#define THREADS 1024

__global__ void OpsKernel(const int shape0, const int *input, float *output){
    auto i = blockIdx.x * THREADS + threadIdx.x;
    if (i < shape0){
        output[i] = input[i] * 0.5;
    }
}

extern "C" int CustomOps(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                         void *stream, void *extra){
    int *input = static_cast<int*>(params[0]);
    float *output = static_cast<float*>(params[1]);
    int elements = 1;
    for (int i=0; i<ndims[0]; i++){
        elements *= shapes[0][i];
    }
    OpsKernel<<<1, THREADS>>>(elements, input, output);
    return 0;
}

这里我们对算子的功能做了一点调整,我们输出的结果是整个张量的元素值乘以0.5,同时也把一个整形变量转化成了一个浮点型变量。其运行Python代码也要做一点调整:

import os
import numpy as np
import mindspore as ms
from mindspore import ops, Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

CURRENT_PATH = os.path.abspath(__file__)
CustomOps = ops.Custom(CURRENT_PATH.replace(".py", ".so:CustomOps"),
                       out_shape=lambda x:x,
                       out_dtype=ms.float32,
                       func_type="aot")
T0 = Tensor(np.arange(12).reshape((4, 3)), ms.int32)
print (T0)
output = CustomOps(T0)
print (output)

这里主要是修改了out_dtype为浮点型,这里如果写错了,会直接导致内存溢出。上述代码的运行结果为:

$ nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu && python3 test_shape.py 
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
[[0.  0.5 1. ]
 [1.5 2.  2.5]
 [3.  3.5 4. ]
 [4.5 5.  5.5]]

可以看到这里输出的张量形状是跟输入保持一致的,即时这个输入张量在经过MindSpore的Custom算子接口时已经被压平成一个一维张量,但是因为我们设置了
out_shape=lambda x:x
,这表示输出的张量shape跟输入的张量shape一致,当然,直接用Python的列表来给
out_shape
赋值也是可以的。例如我们写一个输入输出不同shape的案例:

// nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu

#include <iostream>
#define THREADS 1024

__global__ void OpsKernel(const int shape0, const int *input, int *output){
    auto i = blockIdx.x * THREADS + threadIdx.x;
    if (i < shape0){
        atomicAdd(&output[0], input[i]);
    }
}

extern "C" int CustomOps(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes,
                         void *stream, void *extra){
    int *input = static_cast<int*>(params[0]);
    int *output = static_cast<int*>(params[1]);
    int elements = 1;
    for (int i=0; i<ndims[0]; i++){
        elements *= shapes[0][i];
    }
    OpsKernel<<<1, THREADS>>>(elements, input, output);
    return 0;
}

这个Kernel函数的主要功能是通过一个atomicAdd函数,把输入张量的所有元素做一个求和,这样输出的张量的shape只有[1],对应的Python调用形式也要做一定的调整:

import os
import numpy as np
import mindspore as ms
from mindspore import ops, Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

CURRENT_PATH = os.path.abspath(__file__)
CustomOps = ops.Custom(CURRENT_PATH.replace(".py", ".so:CustomOps"),
                       out_shape=[1],
                       out_dtype=ms.int32,
                       func_type="aot")
T0 = Tensor(np.arange(12).reshape((4, 3)), ms.int32)
print (T0)
output = CustomOps(T0)
print (output)

由于
atomicAdd(addr, element)
原子操作要求输入输出的类型要一致,因此这里我们还是使用的int类型的output,输出结果如下所示:

$ nvcc --shared -Xcompiler -fPIC -o test_shape.so test_shape.cu && python3 test_shape.py 
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]
[66]

总结概要

当我们使用GPU进行快速运算时,虽然可以用成熟的深度学习框架如MindSpore和PyTorch等进行实现,但其实从速度上来说,最快不过直接使用C/C++的CUDA来实现。也正是因为如此,在MindSpore框架中支持了对CUDA实现的算子的直接调用,只是在格式规范上有一定的要求。本文主要介绍MindSpore调用本地CUDA算子的一些规范化和技巧。

版权声明

本文首发链接为:
https://www.cnblogs.com/dechinphy/p/custom-ops-shape.html

作者ID:DechinPhy

更多原著文章:
https://www.cnblogs.com/dechinphy/

请博主喝咖啡:
https://www.cnblogs.com/dechinphy/gallery/image/379634.html

引言

在复杂的应用程序设计中,尤其是那些涉及多个状态变迁和业务流程控制的场景,有限状态机(
Finite State Machine, FSM
)是一种强大而有效的建模工具。Spring框架为此提供了Spring状态机(
Spring State Machine
)这一组件,它允许开发者以一种声明式且结构清晰的方式来管理和控制对象的状态流转。

提起Spring状态机,可能有些小伙伴还比较陌生。当你听到状态机时,一定会联想到状态设计模式。确实,状态机是状态模式的一种实际运用,在工作流引擎、订单系统等领域有大量的应用。在介绍状态机之前,我们先来回顾一下状态模式,以便更好地理解Spring状态机的概念和应用。

状态模式

状态模式是一种行为设计模式,用于管理对象的状态以及状态之间的转换。在状态模式中,对象在不同的状态下表现出不同的行为,而状态的转换是由外部条件触发的。状态模式将每个状态封装成一个独立的类,并将状态转换的逻辑分散在这些状态类中,从而使得状态的管理和转换变得简单和灵活。

状态模式通常由以下几个要素组成:

  1. 上下文(Context)
    :上下文是包含了状态的对象,它定义了当前的状态以及可以触发状态转换的接口。上下文对象在不同的状态下会调用相应状态对象的方法来执行具体的行为。

  2. 抽象状态(State)
    :抽象状态是一个接口或抽象类,定义了状态对象的通用行为接口。具体的状态类需要实现这个接口,并根据不同的状态来实现具体的行为。

  3. 具体状态(Concrete State)
    :具体状态是实现了抽象状态接口的具体类,它实现了在特定状态下对象的行为。每个具体状态类负责管理该状态下的行为和状态转换规则。

状态模式结构图

状态模式使得对象在不同状态下的行为更加清晰和可维护,同时也使得对象的状态转换逻辑更加灵活和可扩展。状态模式常见于需要对象根据外部条件改变行为的场景,例如订单状态(如待提交,待发货,已发货,已签收,已完结等状态)的管理、工作流引擎中的状态(例如提交,审核中,驳回,审核通过,审核失败等)管理。

我们以订单状态的流转为例:

  • 首先我们定义一个订单抽象状态的接口
public interface OrderState {  
  
    void handlerOrder();  
}
  • 在定义具体的订单状态,以及对应的订单状态的行为
public class OrderSubmitState implements OrderState{  
    @Override  
    public void handlerOrder() {  
        System.out.println("订单已提交");  
    }  
}

public class OrderOutboundState implements OrderState{  
  
    @Override  
    public void handlerOrder() {  
        System.out.println("订单已出库");  
    }  
}

public class OrderSignedState implements OrderState{  
    @Override  
    public void handlerOrder() {  
        System.out.println("订单已签收");  
    }  
}
  • 在定义一个状态的上下文,用于维护当前状态对象,以及提供状态流转的方法
public class OrderContext {  
  
    private OrderState orderState;  
  
    public void setOrderState(OrderState orderState){  
        this.orderState = orderState;  
    }  
  
    public void handleOrder(){  
        orderState.handlerOrder();  
    }  
}
  • 编写具体业务,测试订单状态流转
public class OrderStateTest {  
  
    public static void main(String[] args) {  
        OrderSubmitState orderSubmitState = new OrderSubmitState();  
        OrderContext orderContext = new OrderContext();  
        orderContext.setOrderState(orderSubmitState);  
        orderContext.handleOrder();  
  
        OrderOutboundState orderOutboundState = new OrderOutboundState();  
        orderContext.setOrderState(orderOutboundState);  
        orderContext.handleOrder();  
  
        OrderSignedState orderSignedState = new OrderSignedState();  
        orderContext.setOrderState(orderSignedState);  
        orderContext.handleOrder();  
    }  
}

执行结果如下:

image.png
使用状态模式中的状态类不仅能消除if-else逻辑校验,在一定程度上也增强了代码的可读性和可维护性。类似策略模式,但是状态机模式跟策略模式还有很大的区别的。

  1. 状态模式
    :


    • 关注对象在不同状态下的行为和状态之间的转换。
    • 通过封装每个状态为单独的类来实现状态切换,使得每个状态对象都能处理自己的行为。
    • 状态之间的转换通常是通过条件判断或外部事件触发的。
  2. 策略模式
    :


    • 关注对象在不同策略下的行为差异。
    • 将不同的算法或策略封装成单独的类,使得它们可以互相替换,并且在运行时动态地选择不同的策略。
    • 不涉及状态转换,而是更多地关注于执行特定行为时选择合适的策略。

虽然两种模式都涉及对象行为的管理,但它们的关注点和应用场景略有不同。

关于消除if-else的方案请参考:
代码整洁之道(一)之优化if-else的8种方案

什么是状态机

状态机,顾名思义,是一种数学模型,它通过定义一系列有限的状态以及状态之间的转换规则来模拟现实世界或抽象系统的动态行为。每个状态代表系统可能存在的条件或阶段,而状态间的转换则是由特定的输入(即事件)触发的。例如,在电商应用中,订单状态可能会经历创建、支付、打包、发货、完成等多个状态,每个状态之间的转变都由对应的业务动作触发。

在状态机中,有以下几个基本概念:

  1. 状态(State)
    :系统处于的特定状态,可以是任何抽象的状态,如有限状态机中的“开”、“关”状态,或是更具体的状态如“运行”、“暂停”、“停止”等。

  2. 事件(Event)
    :导致状态转换发生的触发器或输入,例如用户的输入、外部事件等。事件触发状态之间的转换。

  3. 转移(Transition)
    :描述状态之间的变化或转换,即从一个状态到另一个状态的过程。转移通常由特定的事件触发,触发特定的转移规则。

  4. 动作(Action)
    :在状态转换发生时执行的动作或操作,可以是一些逻辑处理、计算、输出等。动作可以与状态转移相关联。

  5. 初始状态(Initial State)
    :系统的初始状态,即系统启动时所处的状态。

  6. 终止状态(Final State)
    :状态机执行完成后所达到的状态,表示整个状态机的结束。

状态机可以分为有限状态机(
Finite State Machine,FSM
)和无限状态机(
Infinite State Machine
)两种。有限状态机是指状态的数量是有限的,而无限状态机则可以有无限多个状态。在系统设计中,有限状态机比较常见。

Spring状态机原理

Spring状态机建立在有限状态机(FSM)的概念之上,提供了一种简洁且灵活的方式来定义、管理和执行状态机。它将状态定义为Java对象,并通过配置来定义状态之间的转换规则。状态转换通常由外部事件触发,我们可以根据业务逻辑定义不同的事件类型,并与状态转换关联。Spring状态机还提供了状态监听器,用于在状态变化时执行特定的逻辑。同时,状态机的状态可以持久化到数据库或其他存储介质中,以便在系统重启或故障恢复时保持状态的一致性。

Spring状态机核心主要包括以下三个关键元素:

  1. 状态(State)
    :定义了系统可能处于的各个状态,如订单状态中的待支付、已支付等。

  2. 转换(Transition)
    :描述了在何种条件下,当接收到特定事件时,系统可以从一个状态转移到另一个状态。例如,接收到“支付成功”事件时,订单状态从“待支付”转变为“已支付”。

  3. 事件(Event)
    :触发状态转换的动作或者消息,它是引起状态机从当前状态迁移到新状态的原因。

接下来,我们将上述状态模式中关于订单状态的示例转换为状态机实现。

Spring状态机的使用

对于状态机,Spring中封装了一个组件
spring-statemachine
,直接引入即可。

引入依赖

<dependency>
	<groupId>org.springframework.statemachine</groupId>
	<artifactId>spring-statemachine-starter</artifactId>
	<version>2.2.1.RELEASE</version>
</dependency>

定义状态机的状态以及事件类型

在状态机(
Finite State Machine, FSM
)的设计中,“定义状态”和“定义转换”是构建状态机模型的基础元素。

定义状态(States)
: 状态是状态机的核心组成单元,代表了系统或对象在某一时刻可能存在的条件或模式。在状态机中,每一个状态都是系统可能处于的一种明确的条件或阶段。例如,在一个简单的咖啡机状态机中,可能有的状态包括“待机”、“磨豆”、“冲泡”和“完成”。每个状态都是独一无二的,且在任何给定时间,系统只能处于其中一个状态。

定义转换(Transitions)
: 转换则是指状态之间的转变过程,它是状态机模型动态性的体现。当一个外部事件(如用户按下按钮、接收到信号、满足特定条件等)触发时,状态机会从当前状态转移到另一个状态。在定义转换时,需要指出触发转换的事件(
Event
)以及事件发生时系统的响应,即从哪个状态(
Source State
)转到哪个状态(
Target State
)。

/**
*订单状态
*/
public enum OrderStatusEnum {
    /**待提交*/
    DRAFT,
    /**待出库*/
    SUBMITTED,
    /**已出库*/
    DELIVERING,
    /**已签收*/
    SIGNED,
    /**已完成*/
    FINISHED,
    ;
}

/**
* 订单状态流转事件
*/
public enum OrderStatusOperateEventEnum {
    /**确认,已提交*/
    CONFIRMED,
    /**发货*/
    DELIVERY,
    /**签收*/
    RECEIVED,
    /**完成*/
    CONFIRMED_FINISH,
    ;
}

定义状态机以及状态流转规则

状态机配置类是在使用
Spring State Machine
或其他状态机框架时的一个重要步骤,这个类主要用于定义状态机的核心结构,包括状态(
states
)、事件(
events
)、状态之间的转换规则(
transitions
),以及可能的状态迁移动作和决策逻辑。


Spring State Machine
中,创建状态机配置类通常是通过继承
StateMachineConfigurerAdapter
类来实现的。这个适配器类提供了几个模板方法,允许开发者重写它们来配置状态机的各种组成部分:

  1. 配置状态

    configureStates(StateMachineStateConfigurer)
    ): 在这个方法中,开发者定义状态机中所有的状态,包括初始状态(
    initial state
    )和结束状态(
    final/terminal states
    )。例如,定义状态A、B、C,并指定状态A作为初始状态。

  2. 配置转换

    configureTransitions(StateMachineTransitionConfigurer)
    ): 在这里,开发者描述状态之间的转换规则,也就是当某个事件(
    event
    )发生时,状态机应如何从一个状态转移到另一个状态。例如,当事件X发生时,状态机从状态A转移到状态B。

  3. 配置初始状态

    configureInitialState(ConfigurableStateMachineInitializer)
    ): 如果需要显式指定状态机启动时的初始状态,可以在该方法中设置。

@Configuration
@EnableStateMachine(name = "orderStateMachine")
public class OrderStatusMachineConfig extends StateMachineConfigurerAdapter<OrderStatusEnum, OrderStatusOperateEventEnum> {

    /**
     * 设置状态机的状态
     * StateMachineStateConfigurer 即 状态机状态配置
     * @param states 状态机状态
     * @throws Exception 异常
     */
    @Override
    public void configure(StateMachineStateConfigurer<OrderStatusEnum, OrderStatusOperateEventEnum> states) throws Exception {
        states.withStates()
                .initial(OrderStatusEnum.DRAFT)
                .end(OrderStatusEnum.FINISHED)
                .states(EnumSet.allOf(OrderStatusEnum.class));
    }

    /**
     * 设置状态机与订单状态操作事件绑定
     * StateMachineTransitionConfigurer
     * @param transitions
     * @throws Exception
     */
    @Override
    public void configure(StateMachineTransitionConfigurer<OrderStatusEnum, OrderStatusOperateEventEnum> transitions) throws Exception {
        transitions.withExternal().source(OrderStatusEnum.DRAFT).target(OrderStatusEnum.SUBMITTED)
                .event(OrderStatusOperateEventEnum.CONFIRMED)
                .and()
                .withExternal().source(OrderStatusEnum.SUBMITTED).target(OrderStatusEnum.DELIVERING)
                .event(OrderStatusOperateEventEnum.DELIVERY)
                .and()
                .withExternal().source(OrderStatusEnum.DELIVERING).target(OrderStatusEnum.SIGNED)
                .event(OrderStatusOperateEventEnum.RECEIVED)
                .and()
                .withExternal().source(OrderStatusEnum.SIGNED).target(OrderStatusEnum.FINISHED)
                .event(OrderStatusOperateEventEnum.CONFIRMED_FINISH);

    }
}

配置状态机持久化

状态机持久化是指将状态机在某一时刻的状态信息存储到数据库、缓存系统等中,使得即使在系统重启、网络故障或进程终止等情况下,状态机仍能从先前保存的状态继续执行,而不是从初始状态重新开始。

在业务场景中,例如订单处理、工作流引擎、游戏进度跟踪等,状态机通常用于表示某个实体在其生命周期内的状态变迁。如果没有持久化机制,一旦发生意外情况导致系统宕机或重启,未完成的状态变迁将会丢失,这对于业务连续性和一致性是非常不利的。

状态机持久化通常涉及以下几个方面:

  1. 状态记录
    :记录当前状态机实例处于哪个状态。
  2. 上下文数据
    :除了状态外,可能还需要持久化与状态关联的上下文数据,例如触发状态变迁的事件参数、额外的状态属性等。
  3. 历史轨迹
    :某些复杂场景下可能需要记录状态机的历史变迁轨迹,以便于审计、回溯分析或错误恢复。
  4. 并发控制
    :在多线程或多节点环境下,状态机的持久化还要考虑并发访问和同步的问题。

Spring Statemachine
提供了与
Redis

MongoDB
等数据存储结合的持久化方案,可以将状态机的状态信息序列化后存储到Redis中。当状态机需要恢复时,可以从存储中读取状态信息并重新构造状态机实例,使其能够从上次中断的地方继续执行流程。

@Configuration
public class OrderPersist {


    /**
     * 持久化配置
     * 在实际使用中,可以配合数据库或者Redis等进行持久化操作
     * @return
     */
    @Bean
    public DefaultStateMachinePersister<OrderStatusEnum, OrderStatusOperateEventEnum, OrderDO> stateMachinePersister(){
        Map<OrderDO, StateMachineContext<OrderStatusEnum, OrderStatusOperateEventEnum>> map = new HashMap();
        return new DefaultStateMachinePersister<>(new StateMachinePersist<OrderStatusEnum, OrderStatusOperateEventEnum, OrderDO>() {
            @Override
            public void write(StateMachineContext<OrderStatusEnum, OrderStatusOperateEventEnum> context, OrderDO order) throws Exception {
                //持久化操作
                map.put(order, context);
            }

            @Override
            public StateMachineContext<OrderStatusEnum, OrderStatusOperateEventEnum> read(OrderDO order) throws Exception {
                //从库中或者redis中读取order的状态信息
                return map.get(order);
            }
        });
    }
}    

定义状态机监听器

状态机监听器(
State Machine Listener
)是一种组件,它可以监听并响应状态机在运行过程中的各种事件,例如状态变迁、进入或退出状态、转换被拒绝等。


Spring Statemachine
中,监听器可以通过实现
StateMachineListener
接口来定义。该接口提供了一系列回调方法,如
transitionTriggered

stateEntered

stateExited
等,当状态机触发转换、进入新状态或离开旧状态时,这些方法会被调用。同时,我们也可以通过注解实现监听器。注解方式可以在类的方法上直接声明该方法应该在何种状态下被调用,简化监听器的编写和配置。例如
@OnTransition

@OnTransitionEnd

@OnTransitionStart

@Component
@WithStateMachine(name = "orderStateMachine")
public class OrderStatusListener {

    @OnTransition(source = "DRAFT", target = "SUBMITTED")
    public boolean payTransition(Message<OrderStatusOperateEventEnum> message) {
        OrderDO order = (OrderDO) message.getHeaders().get("order");
        order.setOrderStatusEnum(OrderStatusEnum.SUBMITTED);
        System.out.println(String.format("出库订单[%s]确认,状态机信息:%s", order.getOrderNo(), message.getHeaders()));
        return true;
    }

    @OnTransition(source = "SUBMITTED", target = "DELIVERING")
    public boolean deliverTransition(Message<OrderStatusOperateEventEnum> message) {
        OrderDO order = (OrderDO) message.getHeaders().get("order");
        order.setOrderStatusEnum(OrderStatusEnum.DELIVERING);
        System.out.println(String.format("出库订单[%s]发货出库,状态机信息:%s", order.getOrderNo(), message.getHeaders()));
        return true;
    }

    @OnTransition(source = "DELIVERING", target = "SIGNED")
    public boolean receiveTransition(Message<OrderStatusOperateEventEnum> message){
        OrderDO order = (OrderDO) message.getHeaders().get("order");
        order.setOrderStatusEnum(OrderStatusEnum.SIGNED);
        System.out.println(String.format("出库订单[%s]签收,状态机信息:%s", order.getOrderNo(), message.getHeaders()));
        return true;
    }

    @OnTransition(source = "SIGNED", target = "FINISHED")
    public boolean finishTransition(Message<OrderStatusOperateEventEnum> message){
        OrderDO order = (OrderDO) message.getHeaders().get("order");
        order.setOrderStatusEnum(OrderStatusEnum.FINISHED);
        System.out.println(String.format("出库订单[%s]完成,状态机信息:%s", order.getOrderNo(), message.getHeaders()));
        return true;
    }
}

而监听器需要监听到状态流转的事件才会发挥他的作用,才能监听到某个状态事件之后,完成状态的变更。

@Component
public class StateEventUtil {

    private StateMachine<OrderStatusEnum, OrderStatusOperateEventEnum> orderStateMachine;

    private StateMachinePersister<OrderStatusEnum, OrderStatusOperateEventEnum, OrderDO> stateMachinePersister;

    /**
     * 发送状态转换事件
     *  synchronized修饰保证这个方法是线程安全的
     * @param message
     * @return
     */
    public synchronized boolean sendEvent(Message<OrderStatusOperateEventEnum> message) {
        boolean result = false;
        try {
            //启动状态机
            orderStateMachine.start();
            OrderDO order = (OrderDO) message.getHeaders().get("order");
            //尝试恢复状态机状态
            stateMachinePersister.restore(orderStateMachine, order);
            result = orderStateMachine.sendEvent(message);
            //持久化状态机状态
            stateMachinePersister.persist(orderStateMachine, order);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (Objects.nonNull(message)) {
                OrderDO order = (OrderDO) message.getHeaders().get("order");
                if (Objects.nonNull(order) && Objects.equals(order.getOrderStatusEnum(), OrderStatusEnum.FINISHED)) {
                    orderStateMachine.stop();
                }
            }
        }
        return result;
    }

    @Autowired
    public void setOrderStateMachine(StateMachine<OrderStatusEnum, OrderStatusOperateEventEnum> orderStateMachine) {
        this.orderStateMachine = orderStateMachine;
    }

    @Autowired
    public void setStateMachinePersister(StateMachinePersister<OrderStatusEnum, OrderStatusOperateEventEnum, OrderDO> stateMachinePersister) {
        this.stateMachinePersister = stateMachinePersister;
    }
}

到这里,我们的状态机就定义好了,下面我们就可以在业务代码中使用状态机完成的订单状态的流转。

业务代码使用

@Service
public class OrderServiceImpl implements IOrderService {

    private StateEventUtil stateEventUtil;

    private static final AtomicInteger ID_COUNTER = new AtomicInteger(0);

    private static final Map<Long, OrderDO> ORDER_MAP = new ConcurrentHashMap<>();

    /**
     * 创建新订单
     *
     * @param orderDO
     */
    @Override
    public Long createOrder(OrderDO orderDO) {
        long orderId = ID_COUNTER.incrementAndGet();
        orderDO.setOrderId(orderId);
        orderDO.setOrderNo("OC20240306" + orderId);
        orderDO.setOrderStatusEnum(OrderStatusEnum.DRAFT);
        ORDER_MAP.put(orderId, orderDO);
        System.out.println(String.format("订单[%s]创建成功:", orderDO.getOrderNo()));
        return orderId;
    }

    /**
     * 确认订单
     *
     * @param orderId
     */
    @Override
    public void confirmOrder(Long orderId) {
        OrderDO order = ORDER_MAP.get(orderId);
        System.out.println("确认订单,订单号:" + order.getOrderNo());
        Message message = MessageBuilder.withPayload(OrderStatusOperateEventEnum.CONFIRMED).
                setHeader("order", order).build();
        if (!stateEventUtil.sendEvent(message)) {
            System.out.println(" 确认订单失败, 状态异常,订单号:" + order.getOrderNo());
        }
    }

    /**
     * 订单发货
     *
     * @param orderId
     */
    @Override
    public void deliver(Long orderId) {
        OrderDO order = ORDER_MAP.get(orderId);
        System.out.println("订单出库,订单号:" + order.getOrderNo());
        Message message = MessageBuilder.withPayload(OrderStatusOperateEventEnum.DELIVERY).
                setHeader("order", order).build();
        if (!stateEventUtil.sendEvent(message)) {
            System.out.println(" 订单出库失败, 状态异常,订单号:" + order.getOrderNo());
        }
    }

    /**
     * 签收订单
     *
     * @param orderId
     */
    @Override
    public void signOrder(Long orderId) {
        OrderDO order = ORDER_MAP.get(orderId);
        System.out.println("订单签收,订单号:" + order.getOrderNo());
        Message message = MessageBuilder.withPayload(OrderStatusOperateEventEnum.RECEIVED).
                setHeader("order", order).build();
        if (!stateEventUtil.sendEvent(message)) {
            System.out.println(" 订单签收失败, 状态异常,订单号:" + order.getOrderNo());
        }
    }

    /**
     * 确认完成
     *
     * @param orderId
     */
    @Override
    public void finishOrder(Long orderId) {
        OrderDO order = ORDER_MAP.get(orderId);
        System.out.println("订单完成,订单号:" + order.getOrderNo());
        Message message = MessageBuilder.withPayload(OrderStatusOperateEventEnum.CONFIRMED_FINISH).
                setHeader("order", order).build();
        if (!stateEventUtil.sendEvent(message)) {
            System.out.println(" 订单完成失败, 状态异常,订单号:" + order.getOrderNo());
        }
    }

    /**
     * 获取所有订单信息
     */
    @Override
    public List<OrderDO> listOrders() {
        return new ArrayList<>(ORDER_MAP.values());
    }

    @Autowired
    public void setStateEventUtil(StateEventUtil stateEventUtil) {
        this.stateEventUtil = stateEventUtil;
    }
}

我们在定义一个接口,模拟订单的状态流转:

@RestController
public class OrderController {

    private IOrderService orderService;

    @GetMapping("testOrderStatusMachine")
    public void testOrderStatusMachine(){
        Long orderId1 = orderService.createOrder(new OrderDO());
        Long orderId2 = orderService.createOrder(new OrderDO());

        orderService.confirmOrder(orderId1);
        new Thread("客户线程"){
            @Override
            public void run() {
                orderService.deliver(orderId1);
                orderService.signOrder(orderId1);
                orderService.finishOrder(orderId1);
            }
        }.start();

        orderService.confirmOrder(orderId2);
        orderService.deliver(orderId2);
        orderService.signOrder(orderId2);
        orderService.finishOrder(orderId2);

        System.out.println("全部订单状态:" + orderService.listOrders());


    }

    @Autowired
    public void setOrderService(IOrderService orderService) {
        this.orderService = orderService;
    }
}

我们调用接口:
image.png

我们在日志中可以看到订单状态在状态机的控制下,流转的很丝滑。。。

注意事项

  • 一致性保证
    :确保状态机的配置正确反映了业务逻辑,并保持其在并发环境下的状态一致性。

  • 异常处理
    :在状态转换过程中可能出现异常情况,需要适当地捕获和处理这些异常,防止状态机进入无效状态。

  • 监控与审计
    :在实际应用中,为了便于调试和追溯,可以考虑集成日志记录或事件监听器来记录状态机的每一次状态变迁。

  • 扩展性与维护性
    :随着业务的发展,状态机的设计应当具有足够的灵活性,以便于新增状态或调整转换规则。

一点思考

除了直接使用如Spring状态机这样的专门状态管理工具外,还可以使用其他的哪些方法实现状态机的功能呢?比如:

  1. 消息队列方式
    状态的变更通过发布和消费消息来驱动。每当发生状态变更所需的事件时,生产者将事件作为一个消息发布到特定的消息队列(Topic),而消费者则监听这些消息,根据消息内容和业务规则对订单状态进行更新。这种方式有利于解耦各个服务,实现异步处理,同时增强系统的伸缩性和容错能力。

  2. 定时任务驱动
    使用定时任务定期检查系统中的订单状态,根据预设的业务规则判断是否满足状态变迁条件。比如,每隔一段时间执行一次Job,查询数据库中处于特定状态的订单,并决定是否进行状态更新。这种方法适用于具有一定时效性的状态变迁,但实时性相对较低,对于瞬时响应要求高的场景不太适用。

有关SpringBoot下几种定时任务的实现方式请参考:
玩转SpringBoot:SpringBoot的几种定时任务实现方式

  1. 规则引擎方式
    利用规则引擎(如
    Drools

    LiteFlow
    等)实现状态机,业务团队可以直接在规则引擎中定义状态及状态之间的转换规则,当新的事实数据(如订单信息)输入到规则引擎时,引擎会自动匹配并执行相应的规则,触发状态改变。这种方式的优点在于业务规则高度集中,易于管理和修改,同时也具备较高的灵活性,能够快速应对业务规则的变化。

SpringBoot下使用LiteFlow规则引擎请参考:
轻松应对复杂业务逻辑:LiteFlow-编排式规则引擎框架的优势

总结

Spring状态机提供了一种强大的工具,使得在Java应用中实现复杂的业务流程变得更为简洁和规范。不仅可以提升代码的可读性和可维护性,还能有效降低不同模块之间的耦合度,提高系统的整体稳定性与健壮性。

本文已收录于我的个人博客:
码农Academy的博客,专注分享Java技术干货,包括Java基础、Spring Boot、Spring Cloud、Mysql、Redis、Elasticsearch、中间件、架构设计、面试题、程序员攻略等

零售商家为什么要建设线上商城

传统的实体门店服务范围有限,只能吸引周边500米内的消费者。因此,如何拓展服务范围,吸引更多消费者到店,成为了店家迫切需要解决的问题。

缺乏忠实顾客,客户基础不稳,往往是一次性购物,门店无法形成有效的顾客回流。在当前的市场环境下,构建并维护粉丝群体,成为了商家的核心竞争力。

运营成本不断增长,包括租金和人工成本的上涨,但是广告投放、宣传又成本高昂,且难以追踪效果,达不到预期目标。如何有效吸引新客和提升销售业绩,变得至关重要。

电商不断挤压生存空间,随着网购成为人们的一种生活习惯,由于其方便和价格优势,再加上退换货几乎不产生成本,电商对于实体店构成了巨大的竞争压力。

系统定位

面向新零售连锁商家的线上商城系统,定位包括以下几个方面:

  • 拓宽门店服务半径。通过开发商城小程序的功能,对线下门店来说极为有利,能让5公里范围内的潜在顾客,通过搜索小程序的方式发现商家,有效扩大了潜在顾客群体。
  • 支持多渠道引流、多种业务模式。商城小程序支持众多入口方式,如小程序二维码、微信搜索、朋友推荐、社交媒体分享、微信公众号链接等,为商家提供了丰富的引流手段。商家可依据自身需求,在小程序中开展多种业务模式,例如电商购物、O2O购物、卡券核销、预约服务等多种业务。
  • 提升用户使用体验,促进交易转化。商城小程序的设计无需下载安装,即用即走,带来流畅的交互体验,方便用户随时接触产品和服务。在线客服系统能够实时解答客户疑问,为用户提供全天候服务,显著提升消费体验。
  • 构建私域客户群。商家可利用线上多种场景吸引客户访问小程序,并引导客户关注公众号或加群,通过营销活动促进粉丝转化。通过会员制和积分系统等策略,有效积累和管理私域客户资源。

业务分析

新零售线上商城系统需要满足两种核心业务模式:电商购物模式、O2O购物模式。

我们以瑞幸咖啡为例,下图为瑞幸小程序首页,有到店取、幸运送、电商购的购物入口,其中到店取、幸运送为O2O购物模式,电商购为电商购物模式。

Untitled

电商购物流程

Untitled

O2O购物流程

Untitled

两种业务模式的差异

消费场所的差异:

  • 电商购物模式:完全在线上进行,从进店、选择商品、下单、支付到收货,消费者在线上即可完成购物全过程。
  • O2O购物模式:结合了线上和线下的消费场景。消费者可能在线上选购商品或服务,但可能在实体店进行自提或体验服务。

服务范围的差异

  • 电商购物模式:通常覆盖全国地区,不太受地理位置的限制。
  • O2O购物模式:服务范围受限于实体店的位置,更侧重于本地化服务。

物流配送的差异

  • 电商购物模式:依赖于第三方物流或自建物流进行商品配送,消费者通常在家中等待收取快递。
  • O2O购物模式:消费者可以到店自提商品,或者通过骑手配送商品。

售后服务的差异

  • 电商购物模式:售后服务主要通过线上进行沟通和处理,包括退货、换货、维修等。
  • O2O购物模式:售后服务可以在线上进行,也可以提供线下服务点,让消费者有更多选择。

写在最后

商家建设线上商城的主要原因包括拓宽服务半径,支持多渠道引流和多种业务模式,提升用户体验,构建私域客户群。

新零售线上商城系统需要满足电商购物模式和O2O购物模式,这两种模式在消费场所、服务范围、物流配送和售后服务等方面有所不同。

前言

关于动态代理的一些知识,以及cglib与jdk动态代理的区别,在
这一篇
已经介绍过,不熟悉的可以先看下。
本篇我们来学习一下cglib的FastClass机制,这是cglib与jdk动态代理的一个主要区别,也是一个面试考点。
我们知道jdk动态代理是使用InvocationHandler接口,在invoke方法内,可以使用Method方法对象进行反射调用,反射的一个最大问题是性能较低,cglib就是通过使用FastClass来优化反射调用,提升性能,接下来我们就看下它是如何实现的。

示例

我们先写一个hello world,让代码跑起来。如下:

public class HelloWorld {

	public void print() {
		System.out.println("hello world");
	}
}

public class HelloWorldInterceptor implements MethodInterceptor {
	public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
		System.out.println("before hello world");
		methodProxy.invokeSuper(o, objects);
		System.out.println("after hello world");
		return null;
	}
}

非常简单,就是使用MethodInterceptor在HelloWorld类print方法前后打印一句话,模拟对一个方法前后织入自定义逻辑。
接着使用cglib Enhancer类,创建动态代理对象,设置MethodInterceptor,调用方法。
为了方便观察源码,我们将cglib生成的动态代理类保存下来。


//将生成的动态代理类保存下来
System.setProperty(DebuggingClassWriter.DEBUG_LOCATION_PROPERTY, "D:\\");

Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(HelloWorld.class);
enhancer.setCallback(new HelloWorldInterceptor());

HelloWorld target = (HelloWorld) enhancer.create();
target.print();

输出

before hello world
hello world
after hello world

FastClass机制

我们知道cglib是通过继承实现的,动态代理类会继承被代理类,并重写它的方法,所以它不需要像jdk动态代理一样要求被代理对象有实现接口,因此比较灵活。
既然是通过继承实现的,那应该生成一个类就可以了,但是通过上面的路径观察,可以看到生成了3个文件,其中两个带有FastClass关键字。
这三个类分别是:动态代理类,动态代理类的FastClass,被代理对象的FastClass,从名称上也可以看出它们的关系。

其中动态代理类继承了被代理类,并重写了父类的所有方法,包括父类的父类的方法,包括Object类的equals方法和toString方法等。

public class HelloWorld$$EnhancerByCGLIB$$49f9f9c8 extends HelloWorld implements Factory {
}

这里我们只关注print方法,如下:

第一个直接调用父类方法,也就是被代理对象的方法;第二个会先判断有没有拦截器,如果没有也是直接调用父类方法,否则调用MethodInterceptor的intercept方法,对于我们这里就是HelloWorldInterceptor。
看下intercept的几个参数分别是什么,这几个参数的初始化在动态代理类的静态代码块中都可以找到。
第1个表示动态代理对象。
第2个是被代理对象方法的Method,就是HelloWorld.print。
第3个表示方法参数。
第4个是MethodProxy对象,通过名字我们可以知道它是方法的代理,每一个方法都会有一个对应的MethodProxy,它包含被代理对象、代理对象、以及对应的方法元信息。

这里我们重点关注MethodProxy,它的初始化如下:

CGLIB$print$0$Proxy = MethodProxy.create(var1, var0, "()V", "print", "CGLIB$print$0");       

第1个参数表示被代理对象的Class。
第2个参数表示动态代理对象的Class。
第3个参数是方法的返回值。
第4个参数表示被代理对象的方法名称。
第5个参数表示对应动态代理对象的方法名称。

MethodProxy对象创建好后,我们上面就是通过它进行调用的

methodProxy.invokeSuper(o, objects);

invokeSuper主要源码如下:

public Object invokeSuper(Object obj, Object[] args) throws Throwable {
    init();
    FastClassInfo fci = fastClassInfo;
    return fci.f2.invoke(fci.i2, obj, args);
}

private void init()
{
    if (fastClassInfo == null)
    {
        synchronized (initLock)
        {
            if (fastClassInfo == null)
            {
                CreateInfo ci = createInfo;

                FastClassInfo fci = new FastClassInfo();
                fci.f1 = helper(ci, ci.c1); //被代理对象的FastClass
                fci.f2 = helper(ci, ci.c2); //动态代理对象的FastClass
                fci.i1 = fci.f1.getIndex(sig1); //被代理对象方法的索引下标
                fci.i2 = fci.f2.getIndex(sig2); //动态代理对象方法的索引下标,这里是:CGLIB$print$0 
                fastClassInfo = fci;
                createInfo = null;
            }
        }
    }
}

init方法使用加锁+双检查的方式,只会初始化一次fastClassInfo变量,它用volatile关键字进行修饰,这里涉及到java字节码重排问题,具体可以参考我们之前的分析:
happend before原则

接着回到invokeSuper方法,fci.f2.invoke(fci.i2, obj, args); 实际就是调用动态代理对象的FastClass的invoke方法,并把要调用方法的索引下标i2传过去。
至于方法的索引下标是怎么找到的,可以看动态代理对象的FastClass的getIndex方法,其实就是通过方法的名称、参数个数、参数类型,完全匹配,点到源码文件可以看到有大量的switch分支判断。
这里我们可以看到print方法的索引下标就是18。

public int getIndex(String var1, Class[] var2) {
    switch (var1.hashCode()) {
        case -1295482945:
            if (var1.equals("equals")) {
                switch (var2.length) {
                    case 1:
                        if (var2[0].getName().equals("java.lang.Object")) {
                            return 0;
                        }
                }
            }
        break;
        case 770871766:
            if (var1.equals("CGLIB$print$0")) {
                switch (var2.length) {
                    case 0:
                        return 18;
                }
            }
        break;
    }
}
 public Object invoke(int var1, Object var2, Object[] var3) throws InvocationTargetException {
    HelloWorld..EnhancerByCGLIB..49f9f9c8 var10000 = (HelloWorld..EnhancerByCGLIB..49f9f9c8)var2;
    int var10001 = var1;

    //...
    switch (var10001) {                
        //...
        case 18:
            var10000.CGLIB$print$0();
            return null;
    }
 }    

可以看到最终调用到动态代理类的CGLIB$print$0方法,也就是:

    final void CGLIB$print$0() {
        super.print();
    }

最终调用的就是父类的方法。我们画张图总结一下,有兴趣的同学跟着图和代码逻辑应该可以快速理解。

总结

经过上面的分析,我们可以看到cglib在整个调用过程并没有用到反射,而是使用FastClass对每个方法进行索引,通过方法名称,参数长度,参数类型就可以找到具体的方法,因此性能较好。但也有缺点,首次调用需要生成3个类,会比较慢。在我们实际开发中,特别是一些框架开发,如果有类似的场景也可以借助FastClass对反射进行优化,如:

MyClass cs = new MyCase();
FastClass fastClass = FastClass.create(Case.class);
int index = fastClass.getIndex("test", new Class[]{Integer.class});
Object invoke = fastClass.invoke(index, cs, new Object[1]);

另外MethodProxy还有一个invoke方法,如果我们换一下调用这个方法会发生?留给大家自己尝试。

methodProxy.invokeSuper(o, objects);
//换成 methodProxy.invoke(o, objects);

更多分享,欢迎关注我的github:
https://github.com/jmilktea/jtea

一、问题

今天在将之前的STM32 LwIP1.4.1版本程序移植到2.1.2版本上时,发现ping不同,但是开发板有ICMP回复包,黄颜色警告checksum为0x0000。说明LwIP移植应该是没问题,数据处理这一块出错了。

在网上找了下相关的错误,ST论坛有个问题和我这个一样。

Hardware IPv4 checksum on an STM32F407 is not working

意思就是使用软件校验和能正常使用,但是使用硬件校验和时ICMP数据包的校验为0x0000。问题原因lwipopts.h文件中硬件校验和宏定义下是没有添加以下宏定义

  /*CHECKSUM_CHECK_ICMP==0: Check checksums by hardware for incoming ICMP packets.*/  
  #define CHECKSUM_GEN_ICMP               0

二、解决方法

按照上文的思路,我也看了一下我自己工程文件的lwipopts.h代码果然没有添加这个宏,在添加相关宏定义后就能够正常ping通了。

以下是之前lwip1.4.1 lwipopts.h使用的宏定义:

下面的是
STSW-STM32070
官方例程中的。

三、思考

先说结论:正点原子lwip例程使用的lwip源码是修改过的
以上虽然解决了问题,但是之前lwip1.4.1的例程为什么也能正常跑通呢?
通过查看之前的lwip1.4.1源码发现
不定义CHECKSUM_GEN_ICMP 0的话CHECKSUM_GEN_ICMP 默认为1,也就是ICMP使用软件校验和,但是如果STM32开启硬件校验和的话,STM32会丢弃该帧也就出错了。
这个是lwip1.4.1 opt.h中的宏定义,那为什么不会出错呢?

查找了ICMP check相关的代码发现icmp.c文件是被修改过了,下面的代码也就说定义了CHECKSUM_BY_HARDWARE,在处理ICMP包时就会使用硬件校验,而不需要宏定义定义 CHECK_GEN_ICMP。注释也说明了是ST修改。

单从这一块代码看,修改改后和之前没什么区别,可能是有什么原因CHECK_GEN_ICMP要置1,这个也不纠结了,但是回过头看ST的
STSW-STM32070
官方例程中的ICMP.c文件。

这一段代码也没有修改过,应该是之后例程优化改回来了,而正点原子的例程代码是参考了之前的修改过的ST官方源码。所以在我是用lwip2.1.2进行移植的时候使用硬件校验和就需要定义CHECK_GEN_ICMP宏。