2023年3月

并发编程的难题和挑战

在并发编程的技术领域中,对于我们而言的难题主要有两个:

  1. 多线程之间如何进行通信和线程之间如何同步,通信是指线程之间以何种机制来交换信息。

多线程的线程通信机制

在命令式编程中,线程之间的通信机制有两种:
共享内存

消息传递

  • 共享内存的方式,
    多线程之间共享公共的状态
    (变量),那么线程之间通过写/读内存中的公共状态(变量)来隐式进行通信。在此模式下,同步实现是隐式进行的,由于消息的发送必须在消息的接收之前。
  • 消息传递的方式,
    多线程之间没有公共的状态
    (变量),那么线程之间必须通过明确的传递状态(变量)来显式进行通信。在此模式下,同步实现是显式进行的,必须显式指定某个方法或某段代码需要在线程之间互斥执行。

Java中的同步模式是什么?

同步机制
是指程序用于控制不同线程之间操作发生相对顺序的机制。

Java生态中的并发编程模型采用的是共享内存模型,因此在Java线程之间的通信总是隐式进行, 整个通信过程对开发者是黑盒的,如果编写多线程程序的开发者不深入理解这种隐式模式下的线程之间通信机制,就会会出现内存可见性和一致性的问题,我们统称为线程不安全问题。

存在内存可见问题

Java应用程序中, 所有实例域、静态域和数组元素存储在堆内存中, 堆内存在线程之间共享。会存在这内存可见性问题。

不存在内存可见问题

局部变量(Local variables) , 方法定义参数(java语言规范称之为formal method parameters) 和异常处理器参数(exception handler parameters) 不会在线程之间共享,它们不会有内存可见性问题,也不受内存模型的影响。

所以,我们在开发多线程场景下的程序的时候主要需要关注的就是内存可见问题变量,包含:实例域、静态域和数组元素。

而为了降低并发编程的难度和门槛,这些线程之间的数据同步和通信控制就交由一个特定的数据模型进行控制和管理,我们称之为Java内存模型(JMM)。

Java内存模型(JMM)

JMM决定在程序运行中,一个线程对共享变量的写入何时对另一个线程可见。

JMM定义了线程和主内存之间的抽象关系

线程之间的共享变量存储在
主内存
中,每个线程都有一个私有的
本地内存
, 本地内存中存储了该线程以读/写共享变量的副本。

本地内存是JMM的一个抽象概念, 并不真实存在。它涵盖了缓存, 写缓冲区, 寄存器以及其他的硬件和编译器优化。

Java 内存模型的抽象示意图如下:

由上图可见,线程A与线程B之间如要数据通信,需要有以下两个步骤:

  1. 线程A把本地内存A中更新过的共享变量刷新到
    主内存
    中去。
  2. 线程B到主内存中去读取线程A之前已更新过的
    共享变量

下面通过示意图来说明这两个步骤:

如上图所示,本地内存A和B有主内存中共享变量x的副本。假设初始时,这三个内存中的x值都为0。

  1. 线程A在执行时,把更新后的x值,临时存放在自己的本地内存A中。
  2. 线程A和线程B需要通信时,线程A首先会把自己本地内存中修改后的x值刷新到主内存中,此时主内存中的x值变了。
  3. 线程B到主内存中去读取线程A更新后的x值,此时线程B的本地内存的x值也变了。

总结一下就是,这两个步骤数据角度而言是线程A在向线程B发送消息,而且这个通信过程必须要经过主内存。JMM通过控制主内存与每个线程的本地内存之间的交互, 来为程序提供内存可见性保证。

线程不安全因素之一(指令重排序问题)

基于上述所说的场景之下,JVM为了在执行程序时为了提高性能,编译器和处理器常常会对指令做重排序。在此我们将按照
重排序
的执行时间前后分为重排序分三种类型,如下图所示。

  • 第一步属于
    编译器
    重排序:
    编译器优化的重排序
    ,编译器在不改变单线程程序语义的前提下,可以重新安排语句的执行顺序。

  • 第二步属于
    处理器
    重排序:
    指令级并行的重排序
    ,现代处理器采用了指令级并行技术(Instruction-Level Parallelism, ILP) 来将多条指令重叠执行。如果不存在数据依赖性, 处理器可以改变语句对应机器指令的执行顺序。

  • 第三步属于
    处理器
    重排序:
    内存系统的重排序。由于处理器使用缓存和读/写缓冲区,这使得加载和存储操作看上去可能是在乱序执行,此处特别是针对与本地内存和共享主存之间的更新操作的一致性和可见性

这些重排序都可能会导致多线程程序出现内存可见性问题。

JMM解决重排序的线程不安全问题

解决编译器级别重排序

  • JMM的编译器重排序规则
    会禁止特定类型的编译器重排序,此处注意:
    不是所有的编译器重排序都要禁止

解决处理器级别重排序

  • JMM的处理器重排序规则会要求java编译器在生成指令序列时, 插入特定类型的内存屏障(memory barriers, 也可以称之为memory fence)指令
    , 通过
    内存屏障
    指令来禁止
    特定类型的处理器重排序

    此处注意:不是所有的处理器重排序都要禁止)

总结一下,针对于JMM属于语言级的内存模型, 它确保在不同的编译器和不同的处理器平台之上,通过禁止特定类型的编译器重排序和处理器重排序,从而实现了内存的可见性以及一致性。

处理器重排序与内存屏障指令

上面说了其实是通过插入了内存屏障指令,从而控制住了对应的处理器级别的指令重排。

线程不安全因素之一(写缓存处理模式)

  • 现代的处理器使用写缓冲区来临时保存向内存写入的数据,写缓冲区可以保证指令流水线持续运行,它可以避免由于处理器停顿下来等待向内存写入数据而产生的延迟。

  • 通过以批处理的方式刷新写缓冲区,以及合并写缓冲区中对同一内存地址的多次写,可以减少对内存总线的占用。虽然写缓冲区有这么多好处,但每个处理器上的写缓冲区,仅仅对它所在的处理器可见。

这个特性会对内存操作的执行顺序产生重要的影响,处理器对内存的读/写操作的执行顺序,不一定与内存实际发生的读/写操作顺序一致。

  1. 处理器A

    处理器B
    可以同时把共享变量写入自己的写缓冲区(A1,B1)
  2. 从内存中读取另一个共享变量(A2,B2)
  3. 最后才把自己写缓存区中保存的脏数据刷新到内存中(A3,B3)。

从内存操作实际发生的顺序来看,直到处理器A执行A3来刷新自己的写缓存区,写操作A1才算真正执行了。虽然处理器A执行内存操作的顺序为:A1->A2,但内存操作实际发生的顺序却是:A2->A1。此时,处理器A的内存操作顺序被重排序了(处理器B的情况和处理器A一样)。

由于现代的处理器都会使用写缓冲区,因此现代的处理器都会允许对写-读操作重排序。常见的处理器都允许Store-Load重排序,常见的处理器都不允许对存在数据依赖的操作做重排序。

内存屏障指令

为了保证内存可见性, java编译器在生成指令序列的适当位置会插入内存屏障指令来禁止特定类型的处理器重排序。JMM把内存屏障指令分为下列四类:

内存屏障类型 指令示例 备注
LoadLoad Barries Load1\LoadLoad\Load2 确保Load1数据的装载,之前于Load2及所有后续装载指令的装载
StoreStore Barries Store1\StoreStore\Store2 确保Store1数据对其他处理器可见(刷新到内存),之前于Store2及所有后续存储指令的存储。
LoadStore Barriers Load1\ LoadStore\Store2 确保Load1数据装载, 之前于Store2及所有后续的存储指令刷新到内存
StoreLoad Barriers Store1\StoreLoad\Load2 确保Storel数据对其他处理器变得可见(指刷新到内存),之前于Load2及所有后续装载指令的装载。StoreLoad Barriers会使该屏障之前的所有内存访问指令(存储和装载指令)完成之后,才执行该屏障之后的内存访问指令。

**StoreLoad Barriers是一个“全能型”的屏障, 它同时具有其他三个屏障的效果。现代的多处理器大都支持该屏障(其他类型的屏障不一定被所有处理器支持)。执行该屏障开销会很昂贵,因为当前处理器通常要把写缓冲区中的数据全部刷新到内存中(buffer fully flush) **。

未完善待续!

操作系统 :Windows10_x64 、CentOS 7.6.1810_x64

wireshark版本:3.6.12

Python 版本  :  3.9.12

一、背景描述

工作中有时候会遇到需要从pcap抓包文件里面提取音频的情况,比如下面这些场景:

  • 从pcap文件里面导出wav文件

  • 从pcap文件里面导出mp3文件

...

本文以pcma音频编码为例,介绍下从pcap文件提取音频的流程。

二、具体实现

这里提供两种实现方式从pcap文件提取音频的流程,分别为Windows 10环境和CentOS 7环境。

1、Windows 10环境使用wireshark提取rtp音频

wireshark版本:3.6.12

1)打开pcap文件,选择 voip 通话;

2)选中需要提取的通话,然后点击播放;

3)在弹出的窗口,点击播放按钮可以实时听取音频流;

4)选中需要导出的声道,执行导出音频操作;

5)在弹出的窗口选择导出的音频格式,目前的版本支持wav格式、au格式;

这里以wav格式为例展示下导出效果:

2、CentOS 7环境使用python提取rtp音频

使用python导出rtp音频大概分为以下两个步骤:

1)使用libpcap从pcap文件中提取raw格式的音频;

libpcap的使用可以参考这篇文章:
python3使用libpcap库进行抓包及数据处理

2)使用ffmpeg将raw格式转换成需要的格式(比如wav)

可安装ffmpeg后直接使用,也可自行编译,centos下编译ffmpeg可以参考这篇文章:
CentOS7环境下编译FFmpeg

示例代码如下:

完整代码从如下途径获取:

关注微信公众号(聊聊博文,文末可扫码)后回复 2023032601 获取。

运行效果如下:


三、资源获取

本文涉及示例代码和文件,可从百度网盘获取:

https://pan.baidu.com/s/1NVo9TK5bJwo1CUk5gE9qmA

关注微信公众号(聊聊博文,文末可扫码)后回复
2023032601
获取。

目前为止,阿里云官方并没有
dart
版本的oss sdk,所以才开发了这个插件
flutter_oss_aliyun
提供对oss sdk的支持。

flutter_oss_aliyun

一个访问阿里云oss并且支持STS临时访问凭证访问OSS的flutter库,基本上涵盖阿里云oss sdk的所有功能。⭐

flutter pub
:
https://pub.dev/packages/flutter_oss_aliyun

github
:
https://github.com/huhx/flutter_oss_aliyun

oss sts document
:
https://help.aliyun.com/document_detail/100624.html

目录

  • 了解需求
  • 方案 1:数据库轮询
  • 方案 2:JDK 的延迟队列
  • 方案 3:时间轮算法
  • 方案 4:redis 缓存
  • 方案 5:使用消息队列

了解需求

在开发中,往往会遇到一些关于延时任务的需求。

例如

  • 生成订单 30 分钟未支付,则自动取消
  • 生成订单 60 秒后,给用户发短信

对上述的任务,我们给一个专业的名字来形容,那就是延时任务。那么这里就会产生一个问题,这个延时任务和定时任务的区别究竟在哪里呢?一共有如下几点区别

定时任务有明确的触发时间,延时任务没有

定时任务有执行周期,而延时任务在某事件触发后一段时间内执行,没有执行周期

定时任务一般执行的是批处理操作是多个任务,而延时任务一般是单个任务

下面,我们以判断订单是否超时为例,进行方案分析

本文已经收录到Github仓库,该仓库包含
计算机基础、Java基础、多线程、JVM、数据库、Redis、Spring、Mybatis、SpringMVC、SpringBoot、分布式、微服务、设计模式、架构、校招社招分享
等核心知识点,欢迎star~

Github地址

方案 1:数据库轮询

思路

该方案通常是在小型项目中使用,即通过一个线程定时的去扫描数据库,通过订单时间来判断是否有超时的订单,然后进行 update 或 delete 等操作

实现

可以用 quartz 来实现的,简单介绍一下

maven 项目引入一个依赖如下所示

<dependency>
    <groupId>org.quartz-scheduler</groupId>
    <artifactId>quartz</artifactId>
    <version>2.2.2</version>
</dependency>

调用 Demo 类 MyJob 如下所示

package com.rjzheng.delay1;

import org.quartz.*;
import org.quartz.impl.StdSchedulerFactory;

public class MyJob implements Job {

    public void execute(JobExecutionContext context) throws JobExecutionException {
        System.out.println("要去数据库扫描啦。。。");
    }

    public static void main(String[] args) throws Exception {
        // 创建任务
        JobDetail jobDetail = JobBuilder.newJob(MyJob.class)
                .withIdentity("job1", "group1").build();
        // 创建触发器 每3秒钟执行一次
        Trigger trigger = TriggerBuilder
                .newTrigger()
                .withIdentity("trigger1", "group3")
                .withSchedule(
                        SimpleScheduleBuilder
                                .simpleSchedule()
                                .withIntervalInSeconds(3).
                                repeatForever())
                .build();
        Scheduler scheduler = new StdSchedulerFactory().getScheduler();
        // 将任务及其触发器放入调度器
        scheduler.scheduleJob(jobDetail, trigger);
        // 调度器开始调度任务
        scheduler.start();
    }

}

运行代码,可发现每隔 3 秒,输出如下

要去数据库扫描啦。。。

优点

简单易行,支持集群操作

面试网站

缺点

  • 对服务器内存消耗大
  • 存在延迟,比如你每隔 3 分钟扫描一次,那最坏的延迟时间就是 3 分钟
  • 假设你的订单有几千万条,每隔几分钟这样扫描一次,数据库损耗极大

方案 2:JDK 的延迟队列

思路

该方案是利用 JDK 自带的 DelayQueue 来实现,这是一个无界阻塞队列,该队列只有在延迟期满的时候才能从中获取元素,放入 DelayQueue 中的对象,是必须实现 Delayed 接口的。

DelayedQueue 实现工作流程如下图所示

其中 Poll():获取并移除队列的超时元素,没有则返回空

take():获取并移除队列的超时元素,如果没有则 wait 当前线程,直到有元素满足超时条件,返回结果。

实现

定义一个类 OrderDelay 实现 Delayed,代码如下

package com.rjzheng.delay2;

import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

public class OrderDelay implements Delayed {

    private String orderId;

    private long timeout;

    OrderDelay(String orderId, long timeout) {
        this.orderId = orderId;
        this.timeout = timeout + System.nanoTime();
    }

    public int compareTo(Delayed other) {
        if (other == this) {
            return 0;
        }
        OrderDelay t = (OrderDelay) other;
        long d = (getDelay(TimeUnit.NANOSECONDS) - t.getDelay(TimeUnit.NANOSECONDS));
        return (d == 0) ? 0 : ((d < 0) ? -1 : 1);
    }

    // 返回距离你自定义的超时时间还有多少
    public long getDelay(TimeUnit unit) {
        return unit.convert(timeout - System.nanoTime(), TimeUnit.NANOSECONDS);
    }

    void print() {
        System.out.println(orderId + "编号的订单要删除啦。。。。");
    }

}

运行的测试 Demo 为,我们设定延迟时间为 3 秒

package com.rjzheng.delay2;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.DelayQueue;
import java.util.concurrent.TimeUnit;

public class DelayQueueDemo {

    public static void main(String[] args) {
        List<String> list = new ArrayList<String>();
        list.add("00000001");
        list.add("00000002");
        list.add("00000003");
        list.add("00000004");
        list.add("00000005");

        DelayQueue<OrderDelay> queue = newDelayQueue < OrderDelay > ();
        long start = System.currentTimeMillis();
        for (int i = 0; i < 5; i++) {
            //延迟三秒取出
            queue.put(new OrderDelay(list.get(i), TimeUnit.NANOSECONDS.convert(3, TimeUnit.SECONDS)));
            try {
                queue.take().print();
                System.out.println("After " + (System.currentTimeMillis() - start) + " MilliSeconds");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

}

输出如下

00000001编号的订单要删除啦。。。。
After 3003 MilliSeconds
00000002编号的订单要删除啦。。。。
After 6006 MilliSeconds
00000003编号的订单要删除啦。。。。
After 9006 MilliSeconds
00000004编号的订单要删除啦。。。。
After 12008 MilliSeconds
00000005编号的订单要删除啦。。。。
After 15009 MilliSeconds

可以看到都是延迟 3 秒,订单被删除

优点

效率高,任务触发时间延迟低。

缺点

  • 服务器重启后,数据全部消失,怕宕机
  • 集群扩展相当麻烦
  • 因为内存条件限制的原因,比如下单未付款的订单数太多,那么很容易就出现 OOM 异常
  • 代码复杂度较高

方案 3:时间轮算法

思路

先上一张时间轮的图(这图到处都是啦)

时间轮算法可以类比于时钟,如上图箭头(指针)按某一个方向按固定频率轮动,每一次跳动称为一个 tick。这样可以看出定时轮由个 3 个重要的属性参数,ticksPerWheel(一轮的 tick 数),tickDuration(一个 tick 的持续时间)以及 timeUnit(时间单位),例如当 ticksPerWheel=60,tickDuration=1,timeUnit=秒,这就和现实中的始终的秒针走动完全类似了。

如果当前指针指在 1 上面,我有一个任务需要 4 秒以后执行,那么这个执行的线程回调或者消息将会被放在 5 上。那如果需要在 20 秒之后执行怎么办,由于这个环形结构槽数只到 8,如果要 20 秒,指针需要多转 2 圈。位置是在 2 圈之后的 5 上面(20 % 8 + 1)

实现

我们用 Netty 的 HashedWheelTimer 来实现

给 Pom 加上下面的依赖

<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
    <version>4.1.24.Final</version>
</dependency>

测试代码 HashedWheelTimerTest 如下所示

package com.rjzheng.delay3;

import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.Timer;
import io.netty.util.TimerTask;

import java.util.concurrent.TimeUnit;

public class HashedWheelTimerTest {

    static class MyTimerTask implements TimerTask {

        boolean flag;

        public MyTimerTask(boolean flag) {
            this.flag = flag;
        }

        public void run(Timeout timeout) throws Exception {
            System.out.println("要去数据库删除订单了。。。。");
            this.flag = false;
        }
    }

    public static void main(String[] argv) {
        MyTimerTask timerTask = new MyTimerTask(true);
        Timer timer = new HashedWheelTimer();
        timer.newTimeout(timerTask, 5, TimeUnit.SECONDS);
        int i = 1;
        while (timerTask.flag) {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(i + "秒过去了");
            i++;
        }
    }

}

输出如下

1秒过去了
2秒过去了
3秒过去了
4秒过去了
5秒过去了
要去数据库删除订单了。。。。
6秒过去了

优点

效率高,任务触发时间延迟时间比 delayQueue 低,代码复杂度比 delayQueue 低。

缺点

  • 服务器重启后,数据全部消失,怕宕机
  • 集群扩展相当麻烦
  • 因为内存条件限制的原因,比如下单未付款的订单数太多,那么很容易就出现 OOM 异常

方案 4:redis 缓存

思路一

利用 redis 的 zset,zset 是一个有序集合,每一个元素(member)都关联了一个 score,通过 score 排序来取集合中的值

添加元素:ZADD key score member [[score member][score member] …]

按顺序查询元素:ZRANGE key start stop [WITHSCORES]

查询元素 score:ZSCORE key member

移除元素:ZREM key member [member …]

测试如下

添加单个元素
redis> ZADD page_rank 10 google.com
(integer) 1

添加多个元素
redis> ZADD page_rank 9 baidu.com 8 bing.com
(integer) 2

redis> ZRANGE page_rank 0 -1 WITHSCORES
1) "bing.com"
2) "8"
3) "baidu.com"
4) "9"
5) "google.com"
6) "10"

查询元素的score值
redis> ZSCORE page_rank bing.com
"8"

移除单个元素
redis> ZREM page_rank google.com
(integer) 1

redis> ZRANGE page_rank 0 -1 WITHSCORES
1) "bing.com"
2) "8"
3) "baidu.com"
4) "9"

那么如何实现呢?我们将订单超时时间戳与订单号分别设置为 score 和 member,系统扫描第一个元素判断是否超时,具体如下图所示

实现一

package com.rjzheng.delay4;

import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Tuple;

import java.util.Calendar;
import java.util.Set;

public class AppTest {

    private static final String ADDR = "127.0.0.1";

    private static final int PORT = 6379;

    private static JedisPool jedisPool = new JedisPool(ADDR, PORT);

    public static Jedis getJedis() {
        return jedisPool.getResource();
    }

    //生产者,生成5个订单放进去
    public void productionDelayMessage() {
        for (int i = 0; i < 5; i++) {
            //延迟3秒
            Calendar cal1 = Calendar.getInstance();
            cal1.add(Calendar.SECOND, 3);
            int second3later = (int) (cal1.getTimeInMillis() / 1000);
            AppTest.getJedis().zadd("OrderId", second3later, "OID0000001" + i);
            System.out.println(System.currentTimeMillis() + "ms:redis生成了一个订单任务:订单ID为" + "OID0000001" + i);
        }
    }

    //消费者,取订单

    public void consumerDelayMessage() {
        Jedis jedis = AppTest.getJedis();
        while (true) {
            Set<Tuple> items = jedis.zrangeWithScores("OrderId", 0, 1);
            if (items == null || items.isEmpty()) {
                System.out.println("当前没有等待的任务");
                try {
                    Thread.sleep(500);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                continue;
            }
            int score = (int) ((Tuple) items.toArray()[0]).getScore();
            Calendar cal = Calendar.getInstance();
            int nowSecond = (int) (cal.getTimeInMillis() / 1000);
            if (nowSecond >= score) {
                String orderId = ((Tuple) items.toArray()[0]).getElement();
                jedis.zrem("OrderId", orderId);
                System.out.println(System.currentTimeMillis() + "ms:redis消费了一个任务:消费的订单OrderId为" + orderId);
            }
        }
    }

    public static void main(String[] args) {
        AppTest appTest = new AppTest();
        appTest.productionDelayMessage();
        appTest.consumerDelayMessage();
    }

}

此时对应输出如下

可以看到,几乎都是 3 秒之后,消费订单。

然而,这一版存在一个致命的硬伤,在高并发条件下,多消费者会取到同一个订单号,我们上测试代码 ThreadTest

package com.rjzheng.delay4;

import java.util.concurrent.CountDownLatch;

public class ThreadTest {

    private static final int threadNum = 10;
    private static CountDownLatch cdl = newCountDownLatch(threadNum);

    static class DelayMessage implements Runnable {
        public void run() {
            try {
                cdl.await();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            AppTest appTest = new AppTest();
            appTest.consumerDelayMessage();
        }
    }

    public static void main(String[] args) {
        AppTest appTest = new AppTest();
        appTest.productionDelayMessage();
        for (int i = 0; i < threadNum; i++) {
            new Thread(new DelayMessage()).start();
            cdl.countDown();
        }
    }

}

输出如下所示

显然,出现了多个线程消费同一个资源的情况。

解决方案

(1)用分布式锁,但是用分布式锁,性能下降了,该方案不细说。

(2)对 ZREM 的返回值进行判断,只有大于 0 的时候,才消费数据,于是将 consumerDelayMessage()方法里的

if(nowSecond >= score){
    String orderId = ((Tuple)items.toArray()[0]).getElement();
    jedis.zrem("OrderId", orderId);
    System.out.println(System.currentTimeMillis()+"ms:redis消费了一个任务:消费的订单OrderId为"+orderId);
}

修改为

if (nowSecond >= score) {
    String orderId = ((Tuple) items.toArray()[0]).getElement();
    Long num = jedis.zrem("OrderId", orderId);
    if (num != null && num > 0) {
        System.out.println(System.currentTimeMillis() + "ms:redis消费了一个任务:消费的订单OrderId为" + orderId);
    }
}

在这种修改后,重新运行 ThreadTest 类,发现输出正常了

思路二

该方案使用 redis 的 Keyspace Notifications,中文翻译就是键空间机制,就是利用该机制可以在 key 失效之后,提供一个回调,实际上是 redis 会给客户端发送一个消息。是需要 redis 版本 2.8 以上。

实现二

在 redis.conf 中,加入一条配置

notify-keyspace-events Ex

运行代码如下

package com.rjzheng.delay5;

import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPubSub;

public class RedisTest {

    private static final String ADDR = "127.0.0.1";
    private static final int PORT = 6379;
    private static JedisPool jedis = new JedisPool(ADDR, PORT);
    private static RedisSub sub = new RedisSub();

    public static void init() {
        new Thread(new Runnable() {
            public void run() {
                jedis.getResource().subscribe(sub, "__keyevent@0__:expired");
            }
        }).start();
    }

    public static void main(String[] args) throws InterruptedException {
        init();
        for (int i = 0; i < 10; i++) {
            String orderId = "OID000000" + i;
            jedis.getResource().setex(orderId, 3, orderId);
            System.out.println(System.currentTimeMillis() + "ms:" + orderId + "订单生成");
        }
    }

    static class RedisSub extends JedisPubSub {
        @Override
        public void onMessage(String channel, String message) {
            System.out.println(System.currentTimeMillis() + "ms:" + message + "订单取消");

        }
    }
}

输出如下

可以明显看到 3 秒过后,订单取消了

ps:redis 的 pub/sub 机制存在一个硬伤,官网内容如下

原:Because Redis Pub/Sub is fire and forget currently there is no way to use this feature if your application demands reliable notification of events, that is, if your Pub/Sub client disconnects, and reconnects later, all the events delivered during the time the client was disconnected are lost.

翻: Redis 的发布/订阅目前是即发即弃(fire and forget)模式的,因此无法实现事件的可靠通知。也就是说,如果发布/订阅的客户端断链之后又重连,则在客户端断链期间的所有事件都丢失了。因此,方案二不是太推荐。当然,如果你对可靠性要求不高,可以使用。

优点

(1) 由于使用 Redis 作为消息通道,消息都存储在 Redis 中。如果发送程序或者任务处理程序挂了,重启之后,还有重新处理数据的可能性。

(2) 做集群扩展相当方便

(3) 时间准确度高

缺点

需要额外进行 redis 维护

方案 5:使用消息队列

思路

我们可以采用 rabbitMQ 的延时队列。RabbitMQ 具有以下两个特性,可以实现延迟队列

RabbitMQ 可以针对 Queue 和 Message 设置 x-message-tt,来控制消息的生存时间,如果超时,则消息变为 dead letter

lRabbitMQ 的 Queue 可以配置 x-dead-letter-exchange 和 x-dead-letter-routing-key(可选)两个参数,用来控制队列内出现了 deadletter,则按照这两个参数重新路由。结合以上两个特性,就可以模拟出延迟消息的功能,具体的,我改天再写一篇文章,这里再讲下去,篇幅太长。

优点

高效,可以利用 rabbitmq 的分布式特性轻易的进行横向扩展,消息支持持久化增加了可靠性。

缺点

本身的易用度要依赖于 rabbitMq 的运维.因为要引用 rabbitMq,所以复杂度和成本变高。

--end--

最后给大家分享一个Github仓库,上面有大彬整理的
300多本经典的计算机书籍PDF
,包括
C语言、C++、Java、Python、前端、数据库、操作系统、计算机网络、数据结构和算法、机器学习、编程人生
等,可以star一下,下次找书直接在上面搜索,仓库持续更新中~

Github地址

https://github.com/Tyson0314/java-books

1.机器学习算法(六)基于天气数据集的XGBoost分类预测

1.1 XGBoost的介绍与应用

XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统。严格意义上讲XGBoost并不是一种模型,而是一个可供用户轻松解决分类、回归或排序问题的软件包。它内部实现了梯度提升树(GBDT)模型,并对模型中的算法进行了诸多优化,在取得高精度的同时又保持了极快的速度,在一段时间内成为了国内外数据挖掘、机器学习领域中的大规模杀伤性武器。

更重要的是,XGBoost在系统优化和机器学习原理方面都进行了深入的考虑。毫不夸张的讲,XGBoost提供的可扩展性,可移植性与准确性推动了机器学习计算限制的上限,该系统在单台机器上运行速度比当时流行解决方案快十倍以上,甚至在分布式系统中可以处理十亿级的数据。

XGBoost在机器学习与数据挖掘领域有着极为广泛的应用。据统计在2015年Kaggle平台上29个获奖方案中,17只队伍使用了XGBoost;在2015年KDD-Cup中,前十名的队伍均使用了XGBoost,且集成其他模型比不上调节XGBoost的参数所带来的提升。这些实实在在的例子都表明,XGBoost在各种问题上都可以取得非常好的效果。

同时,XGBoost还被成功应用在工业界与学术界的各种问题中。例如商店销售额预测、高能物理事件分类、web文本分类;用户行为预测、运动检测、广告点击率预测、恶意软件分类、灾害风险预测、在线课程退学率预测。虽然领域相关的数据分析和特性工程在这些解决方案中也发挥了重要作用,但学习者与实践者对XGBoost的一致选择表明了这一软件包的影响力与重要性。

1.2 原理介绍

XGBoost底层实现了GBDT算法,并对GBDT算法做了一系列优化:

  1. 对目标函数进行了泰勒展示的二阶展开,可以更加高效拟合误差。
  2. 提出了一种估计分裂点的算法加速CART树的构建过程,同时可以处理稀疏数据。
  3. 提出了一种树的并行策略加速迭代。
  4. 为模型的分布式算法进行了底层优化。

XGBoost是基于CART树的集成模型,它的思想是串联多个决策树模型共同进行决策。

那么如何串联呢?XGBoost采用迭代预测误差的方法串联。举个通俗的例子,我们现在需要预测一辆车价值3000元。我们构建决策树1训练后预测为2600元,我们发现有400元的误差,那么决策树2的训练目标为400元,但决策树2的预测结果为350元,还存在50元的误差就交给第三棵树……以此类推,每一颗树用来估计之前所有树的误差,最后所有树预测结果的求和就是最终预测结果!

XGBoost的基模型是CART回归树,它有两个特点:(1)CART树,是一颗二叉树。(2)回归树,最后拟合结果是连续值。

XGBoost模型可以表示为以下形式,我们约定$f_t(x)$表示前$t$颗树的和,$h_t(x)$表示第$t$颗决策树,模型定义如下:

$f_{t}(x)=\sum_{t=1}^{T} h_{t}(x)$

由于模型递归生成,第$t$步的模型由第$t-1$步的模型形成,可以写成:

$f_{t}(x)=f_{t-1}(x)+h_{t}(x)$

每次需要加上的树$h_t(x)$是之前树求和的误差:

$r_{t, i}=y_{i}-f_{m-1}\left(x_{i}\right)$

我们每一步只要拟合一颗输出为$r_{t,i}$的CART树加到$f_{t-1}(x)$就可以了。

1.3 相关流程

  • 了解 XGBoost 的参数与相关知识
  • 掌握 XGBoost 的Python调用并将其运用到天气数据集预测

Part1 基于天气数据集的XGBoost分类实践

  • Step1: 库函数导入
  • Step2: 数据读取/载入
  • Step3: 数据信息简单查看
  • Step4: 可视化描述
  • Step5: 对离散变量进行编码
  • Step6: 利用 XGBoost 进行训练与预测
  • Step7: 利用 XGBoost 进行特征选择
  • Step8: 通过调整参数获得更好的效果

3.基于天气数据集的XGBoost分类实战

3.1 EDA探索性分析

在实践的最开始,我们首先需要导入一些基础的函数库包括:numpy (Python进行科学计算的基础软件包),pandas(pandas是一种快速,强大,灵活且易于使用的开源数据分析和处理工具),matplotlib和seaborn绘图。

#导入需要用到的数据集
!wget https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv
--2023-03-22 17:33:53--  https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv
正在解析主机 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)... 49.7.22.39
正在连接 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)|49.7.22.39|:443... 已连接。
已发出 HTTP 请求,正在等待回应... 200 OK
长度: 11476379 (11M) [text/csv]
正在保存至: “train.csv.2”

train.csv.2         100%[===================>]  10.94M  8.82MB/s    in 1.2s    

2023-03-22 17:33:55 (8.82 MB/s) - 已保存 “train.csv.2” [11476379/11476379])

Step1:函数库导入

##  基础函数库
import numpy as np 
import pandas as pd

## 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns

本次我们选择天气数据集进行方法的尝试训练,现在有一些由气象站提供的每日降雨数据,我们需要根据历史降雨数据来预测明天会下雨的概率。样例涉及到的测试集数据test.csv与train.csv的格式完全相同,但其RainTomorrow未给出,为预测变量。

数据的各个特征描述如下:

特征名称 意义 取值范围
Date 日期 字符串
Location 气象站的地址 字符串
MinTemp 最低温度 实数
MaxTemp 最高温度 实数
Rainfall 降雨量 实数
Evaporation 蒸发量 实数
Sunshine 光照时间 实数
WindGustDir 最强的风的方向 字符串
WindGustSpeed 最强的风的速度 实数
WindDir9am 早上9点的风向 字符串
WindDir3pm 下午3点的风向 字符串
WindSpeed9am 早上9点的风速 实数
WindSpeed3pm 下午3点的风速 实数
Humidity9am 早上9点的湿度 实数
Humidity3pm 下午3点的湿度 实数
Pressure9am 早上9点的大气压 实数
Pressure3pm 早上3点的大气压 实数
Cloud9am 早上9点的云指数 实数
Cloud3pm 早上3点的云指数 实数
Temp9am 早上9点的温度 实数
Temp3pm 早上3点的温度 实数
RainToday 今天是否下雨 No,Yes
RainTomorrow 明天是否下雨 No,Yes

Step2:数据读取/载入

## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式

data = pd.read_csv('train.csv')

Step3:数据信息简单查看

## 利用.info()查看数据的整体信息
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 106644 entries, 0 to 106643
Data columns (total 23 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           106644 non-null  object 
 1   Location       106644 non-null  object 
 2   MinTemp        106183 non-null  float64
 3   MaxTemp        106413 non-null  float64
 4   Rainfall       105610 non-null  float64
 5   Evaporation    60974 non-null   float64
 6   Sunshine       55718 non-null   float64
 7   WindGustDir    99660 non-null   object 
 8   WindGustSpeed  99702 non-null   float64
 9   WindDir9am     99166 non-null   object 
 10  WindDir3pm     103788 non-null  object 
 11  WindSpeed9am   105643 non-null  float64
 12  WindSpeed3pm   104653 non-null  float64
 13  Humidity9am    105327 non-null  float64
 14  Humidity3pm    103932 non-null  float64
 15  Pressure9am    96107 non-null   float64
 16  Pressure3pm    96123 non-null   float64
 17  Cloud9am       66303 non-null   float64
 18  Cloud3pm       63691 non-null   float64
 19  Temp9am        105983 non-null  float64
 20  Temp3pm        104599 non-null  float64
 21  RainToday      105610 non-null  object 
 22  RainTomorrow   106644 non-null  object 
dtypes: float64(16), object(7)
memory usage: 18.7+ MB
## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
data.head()
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am ... Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow
0 2012/1/19 MountGinini 12.1 23.1 0.0 NaN NaN W 30.0 N ... 60.0 54.0 NaN NaN NaN NaN 17.0 22.0 No No
1 2015/4/13 Nhil 10.2 24.7 0.0 NaN NaN E 39.0 E ... 63.0 33.0 1021.9 1017.9 NaN NaN 12.5 23.7 No Yes
2 2010/8/5 Nuriootpa -0.4 11.0 3.6 0.4 1.6 W 28.0 N ... 97.0 78.0 1025.9 1025.3 7.0 8.0 3.9 9.0 Yes No
3 2013/3/18 Adelaide 13.2 22.6 0.0 15.4 11.0 SE 44.0 E ... 47.0 34.0 1025.0 1022.2 NaN NaN 15.2 21.7 No No
4 2011/2/16 Sale 14.1 28.6 0.0 6.6 6.7 E 28.0 NE ... 92.0 42.0 1018.0 1014.1 4.0 7.0 19.1 28.2 No No

5 rows × 23 columns

这里我们发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值,可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补,还有其他例如“中位数填补、平均数填补”的缺失值处理方法有兴趣的同学也可以尝试。

data = data.fillna(-1)
data.tail()
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am ... Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow
106639 2011/5/23 Launceston 10.1 16.1 15.8 -1.0 -1.0 SE 31.0 NNW ... 99.0 86.0 999.2 995.2 -1.0 -1.0 13.0 15.6 Yes Yes
106640 2014/12/9 GoldCoast 19.3 31.7 36.0 -1.0 -1.0 SE 80.0 NNW ... 75.0 76.0 1013.8 1010.0 -1.0 -1.0 26.0 25.8 Yes Yes
106641 2014/10/7 Wollongong 17.5 22.2 1.2 -1.0 -1.0 WNW 65.0 WNW ... 61.0 56.0 1008.2 1008.2 -1.0 -1.0 17.8 21.4 Yes No
106642 2012/1/16 Newcastle 17.6 27.0 3.0 -1.0 -1.0 -1 -1.0 NE ... 68.0 88.0 -1.0 -1.0 6.0 5.0 22.6 26.4 Yes No
106643 2014/10/21 AliceSprings 16.3 37.9 0.0 14.2 12.2 ESE 41.0 NNE ... 8.0 6.0 1017.9 1014.0 0.0 1.0 32.2 35.7 No No

5 rows × 23 columns

## 利用value_counts函数查看训练集标签的数量
pd.Series(data['RainTomorrow']).value_counts()
No     82786
Yes    23858
Name: RainTomorrow, dtype: int64

我们发现数据集中的负样本数量远大于正样本数量,这种常见的问题叫做“数据不平衡”问题,在某些情况下需要进行一些特殊处理。

## 对于特征进行一些统计描述
data.describe()
MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustSpeed WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm
count 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000 106644.000000
mean 12.129147 23.183398 2.313912 2.704798 3.509008 37.305137 13.852200 18.265378 67.940353 50.104657 917.003689 914.995385 2.381231 2.285670 16.877842 21.257600
std 6.444358 7.208596 8.379145 4.519172 5.105696 16.585310 8.949659 9.118835 20.481579 22.136917 304.042528 303.120731 3.483751 3.419658 6.629811 7.549532
min -8.500000 -4.800000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -1.000000 -7.200000 -5.400000
25% 7.500000 17.900000 0.000000 -1.000000 -1.000000 30.000000 7.000000 11.000000 56.000000 35.000000 1011.000000 1008.500000 -1.000000 -1.000000 12.200000 16.300000
50% 12.000000 22.600000 0.000000 1.600000 0.200000 37.000000 13.000000 17.000000 70.000000 51.000000 1016.700000 1014.200000 1.000000 1.000000 16.700000 20.900000
75% 16.800000 28.300000 0.600000 5.400000 8.700000 46.000000 19.000000 24.000000 83.000000 65.000000 1021.800000 1019.400000 6.000000 6.000000 21.500000 26.300000
max 31.900000 48.100000 268.600000 145.000000 14.500000 135.000000 130.000000 87.000000 100.000000 100.000000 1041.000000 1039.600000 9.000000 9.000000 39.400000 46.200000

Step4:可视化描述

为了方便,我们先纪录数字特征与非数字特征:

numerical_features = [x for x in data.columns if data[x].dtype == np.float]
category_features = [x for x in data.columns if data[x].dtype != np.float and x != 'RainTomorrow']
## 选取三个特征与标签组合的散点可视化
sns.pairplot(data=data[['Rainfall',
'Evaporation',
'Sunshine'] + ['RainTomorrow']], diag_kind='hist', hue= 'RainTomorrow')
plt.show()

从上图可以发现,在2D情况下不同的特征组合对于第二天下雨与不下雨的散点分布,以及大概的区分能力。相对的Sunshine与其他特征的组合更具有区分能力

for col in data[numerical_features].columns:
    if col != 'RainTomorrow':
        sns.boxplot(x='RainTomorrow', y=col, saturation=0.5, palette='pastel', data=data)
        plt.title(col)
        plt.show()




利用箱型图我们也可以得到不同类别在不同特征上的分布差异情况。我们可以发现Sunshine,Humidity3pm,Cloud9am,Cloud3pm的区分能力较强

tlog = {}
for i in category_features:
    tlog[i] = data[data['RainTomorrow'] == 'Yes'][i].value_counts()
flog = {}
for i in category_features:
    flog[i] = data[data['RainTomorrow'] == 'No'][i].value_counts()

plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
plt.title('RainTomorrow')
sns.barplot(x = pd.DataFrame(tlog['Location']).sort_index()['Location'], y = pd.DataFrame(tlog['Location']).sort_index().index, color = "red")
plt.subplot(1,2,2)
plt.title('Not RainTomorrow')
sns.barplot(x = pd.DataFrame(flog['Location']).sort_index()['Location'], y = pd.DataFrame(flog['Location']).sort_index().index, color = "blue")
plt.show()

从上图可以发现不同地区降雨情况差别很大,有些地方明显更容易降雨

plt.figure(figsize=(10,2))
plt.subplot(1,2,1)
plt.title('RainTomorrow')
sns.barplot(x = pd.DataFrame(tlog['RainToday'][:2]).sort_index()['RainToday'], y = pd.DataFrame(tlog['RainToday'][:2]).sort_index().index, color = "red")
plt.subplot(1,2,2)
plt.title('Not RainTomorrow')
sns.barplot(x = pd.DataFrame(flog['RainToday'][:2]).sort_index()['RainToday'], y = pd.DataFrame(flog['RainToday'][:2]).sort_index().index, color = "blue")
plt.show()

上图我们可以发现,今天下雨明天不一定下雨,但今天不下雨,第二天大概率也不下雨。

3.2 特征向量编码

Step5:对离散变量进行编码

由于XGBoost无法处理字符串类型的数据,我们需要一些方法讲字符串数据转化为数据。一种最简单的方法是把所有的相同类别的特征编码成同一个值,例如女=0,男=1,狗狗=2,所以最后编码的特征值是在$[0, 特征数量-1]$之间的整数。除此之外,还有独热编码、求和编码、留一法编码等等方法可以获得更好的效果。

## 把所有的相同类别的特征编码为同一个值
def get_mapfunction(x):
    mapp = dict(zip(x.unique().tolist(),
         range(len(x.unique().tolist()))))
    def mapfunction(y):
        if y in mapp:
            return mapp[y]
        else:
            return -1
    return mapfunction
for i in category_features:
    data[i] = data[i].apply(get_mapfunction(data[i]))

## 编码后的字符串特征变成了数字

data['Location'].unique()
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48])

3.3 模型训练预测

Step6:利用 XGBoost 进行训练与预测

## 为了正确评估模型性能,将数据划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能。
from sklearn.model_selection import train_test_split

## 选择其类别为0和1的样本 (不包括类别为2的样本)
data_target_part = data['RainTomorrow']
data_features_part = data[[x for x in data.columns if x != 'RainTomorrow']]

## 测试集大小为20%, 80%/20%分
x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part, test_size = 0.2, random_state = 2020)


#查看标签数据
print(y_train[0:2],y_test[0:2])

# 替换Yes为1,No为0
y_train = y_train.replace({'Yes': 1, 'No': 0})
y_test  = y_test.replace({'Yes': 1, 'No': 0})

# 打印修改后的结果
print(y_train[0:2],y_test[0:2])
98173    No
33154    No
Name: RainTomorrow, dtype: object 10273    Yes
90769     No
Name: RainTomorrow, dtype: object
98173    0
33154    0
Name: RainTomorrow, dtype: int64 10273    1
90769    0
Name: RainTomorrow, dtype: int64
The label for xgboost must consist of integer labels of the form 0, 1, 2, ..., [num_class - 1]. This means that the labels must be sequential integers starting from 0 up to the total number of classes minus 1. For example, if there are 3 classes, the labels should be 0, 1, and 2. If the labels are not in this format, xgboost may not be able to train the model properly.
## 导入XGBoost模型
from xgboost.sklearn import XGBClassifier
## 定义 XGBoost模型 
clf = XGBClassifier(use_label_encoder=False)
# 在训练集上训练XGBoost模型
clf.fit(x_train, y_train)

#https://cloud.tencent.com/developer/ask/sof/913362/answer/1303557
[17:34:10] WARNING: ../src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.





XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,
              importance_type='gain', interaction_constraints='',
              learning_rate=0.300000012, max_delta_step=0, max_depth=6,
              min_child_weight=1, missing=nan, monotone_constraints='()',
              n_estimators=100, n_jobs=24, num_parallel_tree=1, random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', use_label_encoder=False,
              validate_parameters=1, verbosity=None)
## 在训练集和测试集上分布利用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)
from sklearn import metrics

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
The accuracy of the Logistic Regression is: 0.8982476703979371
The accuracy of the Logistic Regression is: 0.8575179333302076
The confusion matrix result:
 [[15656  2142]
 [  897  2634]]

我们可以发现共有15759 + 2306个样本预测正确,2470 + 794个样本预测错误。

3.3.1 特征选择

Step7: 利用 XGBoost 进行特征选择

XGBoost的特征选择属于特征选择中的嵌入式方法,在XGboost中可以用属性feature_importances_去查看特征的重要度。

? sns.barplot
sns.barplot(y=data_features_part.columns, x=clf.feature_importances_)

从图中我们可以发现下午3点的湿度与今天是否下雨是决定第二天是否下雨最重要的因素

初次之外,我们还可以使用XGBoost中的下列重要属性来评估特征的重要性。

  • weight:是以特征用到的次数来评价
  • gain:当利用特征做划分的时候的评价基尼指数
  • cover:利用一个覆盖样本的指标二阶导数(具体原理不清楚有待探究)平均值来划分。
  • total_gain:总基尼指数
  • total_cover:总覆盖
from sklearn.metrics import accuracy_score
from xgboost import plot_importance

def estimate(model,data):

    #sns.barplot(data.columns,model.feature_importances_)
    ax1=plot_importance(model,importance_type="gain")
    ax1.set_title('gain')
    ax2=plot_importance(model, importance_type="weight")
    ax2.set_title('weight')
    ax3 = plot_importance(model, importance_type="cover")
    ax3.set_title('cover')
    plt.show()
def classes(data,label,test):
    model=XGBClassifier()
    model.fit(data,label)
    ans=model.predict(test)
    estimate(model, data)
    return ans
 
ans=classes(x_train,y_train,x_test)
pre=accuracy_score(y_test, ans)
print('acc=',accuracy_score(y_test,ans))

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/xgboost/sklearn.py:888: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].
  warnings.warn(label_encoder_deprecation_msg, UserWarning)


[17:34:28] WARNING: ../src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.



acc= 0.8575179333302076

这些图同样可以帮助我们更好的了解其他重要特征。

Step8: 通过调整参数获得更好的效果

XGBoost中包括但不限于下列对模型影响较大的参数:

1. learning_rate: 有时也叫作eta,系统默认值为0.3。每一步迭代的步长,很重要。太大了运行准确率不高,太小了运行速度慢。
2. subsample:系统默认为1。这个参数控制对于每棵树,随机采样的比例。减小这个参数的值,算法会更加保守,避免过拟合, 取值范围零到一。
3. colsample_bytree:系统默认值为1。我们一般设置成0.8左右。用来控制每棵随机采样的列数的占比(每一列是一个特征)。
4. max_depth: 系统默认值为6,我们常用3-10之间的数字。这个值为树的最大深度。这个值是用来控制过拟合的。max_depth越大,模型学习的更加具体。

3.3.2 核心参数调优

1.
eta
[默认0.3]
通过为每一颗树增加权重,提高模型的鲁棒性。
典型值为0.01-0.2。

2.
min_child_weight
[默认1]
决定最小叶子节点样本权重和。
这个参数可以避免过拟合。当它的值较大时,可以避免模型学习到局部的特殊样本。
但是如果这个值过高,则会导致模型拟合不充分。

3.
max_depth
[默认6]
这个值也是用来避免过拟合的。max_depth越大,模型会学到更具体更局部的样本。
典型值:3-10

4.
max_leaf_nodes
树上最大的节点或叶子的数量。
可以替代max_depth的作用。
这个参数的定义会导致忽略max_depth参数。

5.
gamma
[默认0]
在节点分裂时,只有分裂后损失函数的值下降了,才会分裂这个节点。Gamma指定了节点分裂所需的最小损失函数下降值。
这个参数的值越大,算法越保守。这个参数的值和损失函数息息相关。

6.
max_delta_step
[默认0]
这参数限制每棵树权重改变的最大步长。如果这个参数的值为0,那就意味着没有约束。如果它被赋予了某个正值,那么它会让这个算法更加保守。
但是当各类别的样本十分不平衡时,它对分类问题是很有帮助的。

7.
subsample
[默认1]
这个参数控制对于每棵树,随机采样的比例。
减小这个参数的值,算法会更加保守,避免过拟合。但是,如果这个值设置得过小,它可能会导致欠拟合。
典型值:0.5-1

8.
colsample_bytree
[默认1]
用来控制每棵随机采样的列数的占比(每一列是一个特征)。
典型值:0.5-1

9.
colsample_bylevel
[默认1]
用来控制树的每一级的每一次分裂,对列数的采样的占比。
subsample参数和colsample_bytree参数可以起到相同的作用,一般用不到。

10.
lambda
[默认1]
权重的L2正则化项。(和Ridge regression类似)。
这个参数是用来控制XGBoost的正则化部分的。虽然大部分数据科学家很少用到这个参数,但是这个参数在减少过拟合上还是可以挖掘出更多用处的。

11.
alpha
[默认1]
权重的L1正则化项。(和Lasso regression类似)。
可以应用在很高维度的情况下,使得算法的速度更快。

12.
scale_pos_weight
[默认1]
在各类别样本十分不平衡时,把这个参数设定为一个正值,可以使算法更快收敛。

3.3.3 网格调参法

调节模型参数的方法有贪心算法、网格调参、贝叶斯调参等。这里我们采用网格调参,它的基本思想是穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果

## 从sklearn库中导入网格调参函数
from sklearn.model_selection import GridSearchCV

## 定义参数取值范围
learning_rate = [0.1, 0.3,]
subsample = [0.8]
colsample_bytree = [0.6, 0.8]
max_depth = [3,5]

parameters = { 'learning_rate': learning_rate,
              'subsample': subsample,
              'colsample_bytree':colsample_bytree,
              'max_depth': max_depth}
model = XGBClassifier(n_estimators = 20)

## 进行网格搜索
clf = GridSearchCV(model, parameters, cv=3, scoring='accuracy',verbose=1,n_jobs=-1)

clf = clf.fit(x_train, y_train)
## 网格搜索后的最好参数为

clf.best_params_
## 在训练集和测试集上分布利用最好的模型参数进行预测

## 定义带参数的 XGBoost模型 
clf = XGBClassifier(colsample_bytree = 0.6, learning_rate = 0.3, max_depth= 8, subsample = 0.9)
# 在训练集上训练XGBoost模型
clf.fit(x_train, y_train)

train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/xgboost/sklearn.py:888: UserWarning: The use of label encoder in XGBClassifier is deprecated and will be removed in a future release. To remove this warning, do the following: 1) Pass option use_label_encoder=False when constructing XGBClassifier object; and 2) Encode your labels (y) as integers starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].
  warnings.warn(label_encoder_deprecation_msg, UserWarning)


[17:55:25] WARNING: ../src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
The accuracy of the Logistic Regression is: 0.9382992439781984
The accuracy of the Logistic Regression is: 0.856674011908669
The confusion matrix result:
 [[15611  2115]
 [  942  2661]]

原本有2470 + 790个错误,现在有 2112 + 939个错误,带来了明显的正确率提升。

更多调参技巧请参考:
https://blog.csdn.net/weixin_62684026/article/details/126859262

4. 总结

XGBoost的主要优点:

  1. 简单易用。相对其他机器学习库,用户可以轻松使用XGBoost并获得相当不错的效果。
  2. 高效可扩展。在处理大规模数据集时速度快效果好,对内存等硬件资源要求不高。
  3. 鲁棒性强。相对于深度学习模型不需要精细调参便能取得接近的效果。
  4. XGBoost内部实现提升树模型,可以自动处理缺失值。

XGBoost的主要缺点:

  1. 相对于深度学习模型无法对时空位置建模,不能很好地捕获图像、语音、文本等高维数据。
  2. 在拥有海量训练数据,并能找到合适的深度学习模型时,深度学习的精度可以遥遥领先XGBoost。

本项目链接:
https://www.heywhale.com/home/column/64141d6b1c8c8b518ba97dcc

参考链接:
https://tianchi.aliyun.com/course/278/3423


本人最近打算整合ML、DRL、NLP等相关领域的体系化项目课程,方便入门同学快速掌握相关知识。声明:部分项目为网络经典项目方便大家快速学习,后续会不断增添实战环节(比赛、论文、现实应用等)。

  • 对于机器学习这块规划为:基础入门机器学习算法--->简单项目实战--->数据建模比赛----->相关现实中应用场景问题解决。一条路线帮助大家学习,快速实战。
  • 对于深度强化学习这块规划为:基础单智能算法教学(gym环境为主)---->主流多智能算法教学(gym环境为主)---->单智能多智能题实战(论文复现偏业务如:无人机优化调度、电力资源调度等项目应用)
  • 自然语言处理相关规划:除了单点算法技术外,主要围绕知识图谱构建进行:信息抽取相关技术(含智能标注)--->知识融合---->知识推理---->图谱应用

上述对于你掌握后的期许:

  1. 对于ML,希望你后续可以乱杀数学建模相关比赛(参加就获奖保底,top还是难的需要钻研)
  2. 可以实际解决现实中一些优化调度问题,而非停留在gym环境下的一些游戏demo玩玩。(更深层次可能需要自己钻研了,难度还是很大的)
  3. 掌握可知识图谱全流程构建其中各个重要环节算法,包含图数据库相关知识。

这三块领域耦合情况比较大,后续会通过比如:搜索推荐系统整个项目进行耦合,各项算法都会耦合在其中。举例:知识图谱就会用到(图算法、NLP、ML相关算法),搜索推荐系统(除了该领域召回粗排精排重排混排等算法外,还有强化学习、知识图谱等耦合在其中)。饼画的有点大,后面慢慢实现。