2024年1月

errgroup
想必稍有经验的golang程序员都应该听说过,实际项目中用过的也应该不在少数。它和
sync.WaitGroup
类似,都可以发起执行并等待一组协程直到所有协程运行结束。除此之外errgroup还可以在协程出错时取消当前的context,以及它还能控制可运行的协程的数量。

但在日常的代码review时我注意到了几个比较常见的问题,这些问题有的无伤大雅最多只会造成一些性能损失,有的则会导致资源泄露甚至是死锁崩溃。

这里对这些比较典型的误用做下记录。

多余的context嵌套

先说个不是很常见但我还是遇到过两三次的不太妥当的用法。

我们知道errgroup在协程返回错误的时候会取消掉创建时传入的context,这是为了能让同组的其他协程知道有错误发生应该尽快退出执行。

所以errgroup使用的context应该是派生于当前上下文的新的context,这样才不会让可能的取消操作影响到errgroup之外的范围。

因此第一个常见误用出现了:

func DoWork(ctx context.Context) {
    errCtx, cancel := context.WithCancel(ctx)
    defer cancel()
    group, errCtx := errgroup.WithContext(ctx)
    ...
}

误用在哪呢?答案是context会自动帮我们派生出新的context,除了需要设置超时一般不需要再次额外封装,看源代码:

// https://github.com/golang/sync/blob/master/errgroup/errgroup.go

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
	ctx, cancel := withCancelCause(ctx)
	return &Group{cancel: cancel}, ctx
}

// https://github.com/golang/sync/blob/master/errgroup/go120.go

func withCancelCause(parent context.Context) (context.Context, func(error)) {
	return context.WithCancelCause(parent)
}

多于的嵌套会浪费内存,以及会对性能带来负面影响,尤其是需要从context里取出某些value的时候,因为取value是对一层层嵌套的context递归查找的,嵌套层数越多查找就有可能越慢。

不过前面也说到了,有一种情况是允许的,那就是对整个errgroup所有的协程设置超时:

func DoWork(ctx context.Context) {
    errCtx, cancel := context.WithTimeout(ctx, 10 * time.Second)
    defer cancel()
    group, errCtx := errgroup.WithContext(ctx)
    ...
}

目前想设置超时只能这样做,所以这种算是特例。

Wait返回的时机

第二种误用比第一种要常见些。主要是对errgroup的行为理解上有误解。

这种误解经常表现为:如果协程返回错误或者ctx的超时被触发,
Wait
方法就会立即返回。

这并不是事实。

先来看看
Wait
的文档怎么说的:

Wait blocks until all function calls from the Go method have returned, then returns the first non-nil error (if any) from them.

Wait
需要等到所有goroutine返回后它才会返回。哪怕超时了,context取消了也一样,需要先等所有协程退出。再来看代码:

// https://github.com/golang/sync/blob/master/errgroup/errgroup.go

func (g *Group) Wait() error {
	g.wg.Wait()
	if g.cancel != nil {
		g.cancel(g.err)
	}
	return g.err
}

可以看到确实需要先等所有协程返回。如果你观察比较敏锐的话,其实能发现errgroup会对协程做包装,会不会包装的代码里有什么办法提前中止协程的执行呢?还是来看代码:

// https://github.com/golang/sync/blob/master/errgroup/errgroup.go

func (g *Group) Go(f func() error) {
	// 检查当前协程是否可运行的代码,先忽略

	g.wg.Add(1)
	go func() {
		defer g.done()  // 重点在这

		if err := f(); err != nil {
			g.errOnce.Do(func() {
				g.err = err
				if g.cancel != nil {
					g.cancel(g.err)
				}
			})
		}
	}()
}

注意那个defer,这意味着done只有在包装的函数运行结束(在你自己的函数f运行完并设置了error以及取消了ctx之后)时才会执行。

如果你自己的函数里不检查超时和上下文是否被取消,那leak和卡死问题就要找上门来了,比如下面这样的:

func main() {
    errCtx, cancel := context.WithTimeout(context.Background(), 1 * time.Second)
    defer cancel()
    group, errCtx := errgroup.WithContext(errCtx)
    group.Go(func () error {
        time.Sleep(10 * time.Second)
        fmt.Println("running")
        return nil
    })
    group.Go(func () error {
        return errors.New("error")
    })
    fmt.Println(group.Wait())
}

猜猜运行结果和执行时间。答案是
running\nerror\n
,运行需要10秒以上。

这种误用也很好识别,只要传给
Go
方法的函数里没有好好处理
errCtx
,那多半是有问题的。

不过要说句公道话,
Go
的参数形式不符合一般使用context的惯例,
Wait
的行为和其他能自主取消线程执行的语言也不一样造成了误用,语言和接口设计得背一半锅不能全赖用它的程序员。

SetLimit和死锁

这种就更常见了,尤其发生在把errgroup当成普通协程池用的时候。

先来我最爱的猜谜游戏,下面的代码运行结果是什么?

func main() {
    group, _ := errgroup.WithContext(context.Background())
    group.SetLimit(2) // 想法:只允许2个协程同时运行,但多个任务提交到“协程池”
    group.Go(func () error {
        fmt.Println("running 1")
        // 运行子任务
        group.Go(func () error {
            fmt.Println("sub running 1")
            return nil
        })
        group.Go(func () error {
            fmt.Println("sub running 2")
            return nil
        })
        return nil
    })
    group.Go(func () error {
        fmt.Println("running 2")
        // 运行子任务
        group.Go(func () error {
            fmt.Println("sub running 3")
            return nil
        })
        group.Go(func () error {
            fmt.Println("sub running 4")
            return nil
        })
        return nil
    })
    fmt.Println(group.Wait())
}

答案是会死锁panic。而且是100%触发。

我会详细的解释这是为什么,但在之前我要说一个重要的知识点:

SetLimit
设置的不是同时在运行的协程数量,而是设置errgroup内最多同时能持有多少个协程,errgroup持有的协程可以在运行也可以在等待运行。

如果每个running的sub running只有一个,那么有小概率不会死锁,所以我特地每组创建了两个,原因没那么复杂,看来后面的解释之后可以自行推理。

下面来解释,首先看
SetLimit
的代码,一切是从这开始的:

// https://github.com/golang/sync/blob/master/errgroup/errgroup.go

// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
	if n < 0 {
		g.sem = nil
		return
	}
	if len(g.sem) != 0 {
		panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
	}
	g.sem = make(chan token, n)
}

g.sem

chan strcut{}
。做的事很简单,如果参数大于0就按参数初始化一个长度为n的chan给
g.sem
,小于0就清空
g.sem
。如果你经验比较丰富的话,已经可以看出来这是一个简单的
ticket pool
模式了,这个模式在grpc里也有应用。

ticket pool
模式的原理是设置一个固定大小为n的空chan,然后协程要运行的时候向这个chan写入数据,协程运行结束的时候从chan里把写入的数据读出(可能会读到别人写进去的,但只要遵循这个写入读出的顺序就没问题)。如果chan的写入阻塞了,就说明已经有n个协程在运行了,新的协程需要等到有协程执行完并读出数据后才能继续执行;正常情况下读出操作不会被阻塞。这个是限制goroutine数量的最常见的手段之一。根据写入操作实在协程内部还是发起协程的调用者那里进行,这个模式还能分别控制“最大同时运行的goroutine数量”或“goroutine总数量”。其中
goroutine的总数量 = 在运行的goroutine数量 + 其他等待运行goroutine的数量

而errgroup属于后者。还记得
Go
的代码里我注释掉的那部分吧,现在可以看了:

// https://github.com/golang/sync/blob/master/errgroup/errgroup.go

func (g *Group) Go(f func() error) {
	if g.sem != nil {
		g.sem <- token{} // token是struct{}
	}

	g.wg.Add(1)
	go func() {
		defer g.done()

		if err := f(); err != nil {
			// 设置错误值
		}
	}()
}

func (g *Group) done() {
	if g.sem != nil {
		<-g.sem // 从ticket pool里读出
	}
	g.wg.Done()
}

进入
Go
的时候并没有启动协程,而是先检查
sem
,如果有设置limit,就需要按操作ticket pool的流程先写入数据。写入成功才会创建协程,协程运行结束后把数据读出。这样限制了errgroup最大可以持有的协程数量,因为超过数量限制会阻塞住不创建新的协程。


Go
完成sem的写入并执行go语句之前,errgroup并没有“持有”go语句创建的这个协程。协程运行结束并把sem的数据读出后,group将不会继续“持有”这个协程。

问题就出在写入那里。假设调度器是这样运行我们的猜谜代码的:

  1. 先启动running 1的协程,sem空位有2个,正常运行,running 1运行结束后它写入的数据才会被读出
  2. 接着启动running 2,sem还剩一个空位,没问题,running 2运行结束后它写入的数据才会被读出
  3. running 2先被执行,于是准备创建sub running 3的协程
  4. 这时sem没空位了,创建sub running 3的
    Go
    阻塞
  5. 调度器发现running 2被阻塞了,于是让running 1执行(假设而已,多核处理器上很可能是同时运行的)
  6. running 1输出后准备创建sub running 1的协程
  7. sem还是满的,
    Go
    又阻塞了
  8. 调度器发现running 1和running 2都阻塞了,于是只能让main goroutine执行(这里忽略runtime自己的协程,因为不影响死锁检测结果)
  9. main阻塞在
    Wait
    上,所有其他协程执行完才能继续执行
  10. 没有能继续运行下去的协程,全都阻塞了(注意是阻塞不是sleep),死锁检测发现这种情况,panic

我知道实际执行顺序肯定不一样,但死锁的原因一样的:因为之前的协程没有让出ticket pool,后面的子任务需要向pool写入,而前面占有pool的协程需要等子任务执行完才会让出pool。
这是一个典型的循环依赖导致的死锁,诱因是同一个errgroup的嵌套使用

是什么导致了你踩坑呢?最大的可能是文档里那个“active”。这个词太模糊了,你可以发现它即能代指running又能代指runnable,还能两个同时代指。这里因为下面还有一段话,所以可以根据上下文估摸着猜出active想代指的是所有被创建出来的协程不管它们在不在运行。但如果你只看了第一段话就先入为主放心大胆用的话,坑就来了。这样的词缺少足够的上下文时连母语者都会觉得有二义性,更何况我们这些作为第二语言甚至第三语言的人。

而errgroup选择限制goroutine总数量也是有原因的:只限制同时运行的goroutine的数量就没法限制协程的总数量,协程虽然很轻量,但还是要占用内存以及花费cpu资源来调度的,不受控制很可能会产生灾难性后果,比如一个不当心在循环里创建了数百万个协程导致严重的内存占用和调度压力,控制了总数量这类问题就可以避免。

幸运的是,这个误用也很好识别,
但凡有嵌套使用同一个errgroup的时候,就要警报大作了

更幸运的是,如果你没有嵌套调用,那么这个
SetLimit
不管设置成哪个数字,都能正常限制顶层的goroutine的数量(或者不做限制),它不能限制的是从顶层协程里嵌套调用派生出的子协程,只要不嵌套调用同一个group,什么问题的不会有。

前面两种误用都是该避免的,然而嵌套的errgroup虽然不多见但确实有用处,所以我也会提供写简单的解决方案以供参考。

第一种是设置一个足够的limit数值,聪明人应该发现了,如果把limit设置成希望group里同时存在的协程的总数量(顶层+所有嵌套派生的),问题就能避免。这没错,但我不推荐,两点原因:

  1. 设置成总数后起不到限制同时运行的协程的数量,在go里控制同时运行的协程数量是个很麻烦的事,limit通常只能起到“上限”的作用,但如果上限设置大了就容易出现问题。比如你的系统只能同时运行3个协程,你还有别的任务占用了一个协程在运行,为了避免死锁你设置了limit为4,这时候资源抢占和协程调度延迟都会明显上升,出现这类情况你的系统就离崩溃只有一步之遥了。
  2. 算这个数量很麻烦,上面的例子你可以很简单算出是4,如果我再套一层或者加上几个可以跳过
    Go
    调用的条件分支呢?而且limit设置多了是起不到限制goroutine数量的作用的,设少了会死锁。
  3. limit多半是个写死的常量或者干脆是魔数,那么下次协程的逻辑改了这个数字多半得跟着改,如果你算错了或者忘记改了,那么你就惨了,死锁就像个地雷一样埋下了。

综上,你应该用第二种方法:永远不要嵌套使用同一个errgroup,真有嵌套需求也应该使用新的errgroup实例,这样可以避免死锁,也最符合当前需求的语义:

func main() {
    group, errCtx := errgroup.WithContext(context.Background())
    group.SetLimit(1) // 想法:只允许2个协程同时运行,但多个任务提交到“协程池”
    group.Go(func () error {
        fmt.Println("running 1")
        // 运行子任务
        // 新建一个errgroup,上下文使用外层group的
        subGroup, _ := errgroup.WithContext(errCtx)
        subGroup.SetLimit(1)
        subGroup.Go(func () error {
            fmt.Println("sub running 1")
            return nil
        })
        subGroup.Go(func () error {
            fmt.Println("sub running 2")
            return nil
        })
        fmt.Println(subGroup.Wait())
        return nil
    })
    group.Go(func () error {
        fmt.Println("running 2")
        // 运行子任务
        subGroup, _ := errgroup.WithContext(errCtx)
        subGroup.SetLimit(1)
        subGroup.Go(func () error {
            fmt.Println("sub running 3")
            return nil
        })
        subGroup.Go(func () error {
            fmt.Println("sub running 4")
            return nil
        })
        fmt.Println(subGroup.Wait())
        return nil
    })
    fmt.Println(group.Wait())
}

是的,现在所有limit设置成1也不会死锁。因为没有嵌套调用,因此也没有资源间的循环依赖了。

当然还有终极方案:别把errgroup当成协程池,如果你有复杂功能依赖于协程池找个功能全面的真正的协程池比如ants之类的用。

对了。你问
SetLimit
传0进去会发生什么,那当然是直接死锁了。这也符合语义,因为你的group里不能有任何协程,这时候再调
Go
当然是不对的,死锁panic也是应该的。所以传0进去导致死锁这不算坑,也算不上误用。

总结

总结下上面三个误用:

  1. 传递有多余嵌套的context给errgroup
  2. 在加入errgroup的协程里没有正确处理context取消和超时
  3. 嵌套使用同一个errgroup

已有的静态分析工具不是很能识别这类问题,要么自己写个能识别的,要么只能靠review把关了。

比较大众的观点认为go简单易用,但实际上并不总是如此,有句话叫“Simple is not Easy”,go的使用者需要时刻为“大道至简”付出相应的代价。

简介

WebSocket 是基于TCP/IP协议,独立于HTTP协议的通信协议。WebSocket 连接允许客户端和服务器之间的全双工通信,以便任何一方都可以通过已建立的连接将数据推送到另一方。

我们常用的HTTP是客户端通过「请求-响应」的方式与服务器建立通信的,必须是客户端主动触发的行为,服务端只是做好接口被动等待请求。而在某些场景下的动作,是需要服务端主动触发的,比如向客户端发送消息、实时通讯、远程控制等。客户端是不知道这些动作几时触发的,假如用HTTP的方式,那么设备端需要不断轮询服务端,这样的方式对服务器压力太大,同时产生很多无效请求,且具有延迟性。于是才采用可以建立双向通讯的长连接协议。通过握手建立连接后,服务端可以实时发送数据与指令到设备端,服务器压力小。

Spring WebSocket是Spring框架的一部分,提供了在Web应用程序中实现实时双向通信的能力。本教程将引导你通过一个简单的例子,演示如何使用Spring WebSocket建立一个实时通信应用。

准备工作

确保你的项目中已经引入了Spring框架的WebSocket模块。你可以通过Maven添加以下依赖:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

创建WebSocket配置类(实现WebSocketConfigurer接口)

首先,创建一个配置类,用于配置WebSocket的相关设置。

package com.ci.erp.human.config;

import com.ci.erp.human.handler.WebSocketHandler;
import com.ci.erp.human.interceptor.WebSocketHandleInterceptor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 *
 * Websocket配置类
 *
 * @author lucky_fd
 * @since 2024-01-17
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // 注册websocket处理器和拦截器
        registry.addHandler(webSocketHandler(), "/websocket/server")
                .addInterceptors(webSocketHandleInterceptor()).setAllowedOrigins("*");
        registry.addHandler(webSocketHandler(), "/sockjs/server").setAllowedOrigins("*")
                .addInterceptors(webSocketHandleInterceptor()).withSockJS();
    }

    @Bean
    public WebSocketHandler webSocketHandler() {
        return new WebSocketHandler();
    }

    @Bean
    public WebSocketHandleInterceptor webSocketHandleInterceptor() {
        return new WebSocketHandleInterceptor();
    }
}

上面的配置类使用@EnableWebSocket注解启用WebSocket,并通过registerWebSocketHandlers方法注册WebSocket处理器。

  • registerWebSocketHandlers:这个方法是向spring容器注册一个handler处理器及对应映射地址,可以理解成MVC的Handler(控制器方法),websocket客户端通过请求的url查找处理器进行处理

  • addInterceptors:拦截器,当建立websocket连接的时候,我们可以通过继承spring的HttpSessionHandshakeInterceptor来做一些事情。

  • setAllowedOrigins:跨域设置,
    *
    表示所有域名都可以,不限制, 域包括ip:port, 指定
    *
    可以是任意的域名,不加的话默认localhost+本服务端口

  • withSockJS: 这个是应对浏览器不支持websocket协议的时候降级为轮询的处理。

创建WebSocket消息处理器(实现TextWebSocketHandler 接口)

接下来,创建一个消息处理器,处理客户端发送的消息。

package com.ci.erp.human.handler;

import cn.hutool.core.util.ObjectUtil;
import com.ci.erp.common.core.utils.JsonUtils;
import com.ci.erp.human.domain.thirdVo.YYHeartbeat;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 *
 * websocket处理类
 * 实现WebSocketHandler接口
 *
 * - websocket建立连接后执行afterConnectionEstablished回调接口
 * - websocket关闭连接后执行afterConnectionClosed回调接口
 * - websocket接收客户端消息执行handleTextMessage接口
 * - websocket传输异常时执行handleTransportError接口
 *
 * @author lucky_fd
 * @since 2024-01-17
 */

public class WebSocketHandler extends TextWebSocketHandler {

    /**
     * 存储websocket客户端连接
     * */
    private static final Map<String, WebSocketSession> connections = new HashMap<>();

    /**
     * 建立连接后触发
     * */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        System.out.println("成功建立websocket连接");
        // 建立连接后将连接以键值对方式存储,便于后期向客户端发送消息
        // 以客户端连接的唯一标识为key,可以通过客户端发送唯一标识
        connections.put(session.getRemoteAddress().getHostName(), session);
        System.out.println("当前客户端连接数:" + connections.size());
    }

    /**
     * 接收消息
     * */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        System.out.println("收到消息: " + message.getPayload());
		
		// 收到客户端请求消息后进行相应业务处理,返回结果
        this.sendMessage(session.getRemoteAddress().getHostName(),new TextMessage("收到消息: " + message.getPayload()));
    }

    /**
     * 传输异常处理
     * */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        super.handleTransportError(session, exception);
    }

    /**
     * 关闭连接时触发
     * */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        System.out.println("触发关闭websocket连接");
        // 移除连接
        connections.remove(session.getRemoteAddress().getHostName());
    }

    @Override
    public boolean supportsPartialMessages() {
        return super.supportsPartialMessages();
    }

    /**
     * 向连接的客户端发送消息
     *
     * @author lucky_fd
     * @param clientId 客户端标识
     * @param message 消息体
     **/
    public void sendMessage(String clientId, TextMessage message) {
        for (String client : connections.keySet()) {
            if (client.equals(clientId)) {
                try {
                    WebSocketSession session = connections.get(client);
                    // 判断连接是否正常
                    if (session.isOpen()) {
                        session.sendMessage(message);
                    }
                } catch (IOException e) {
                    System.out.println(e.getMessage());
                }
                break;
            }
        }
    }
}

通过消息处理器,在开发中我们就可以实现向指定客户端或所有客户端发送消息,实现相应业务功能。

创建拦截器

拦截器会在握手时触发,可以用来进行权限验证

package com.ci.erp.human.interceptor;

import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import java.util.Map;

/**
 *
 * Websocket拦截器类
 *
 * @author lucky_fd
 * @since 2024-01-17
 */

public class WebSocketHandleInterceptor extends HttpSessionHandshakeInterceptor {

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        System.out.println("拦截器前置触发");
        return super.beforeHandshake(request, response, wsHandler, attributes);
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception ex) {
        System.out.println("拦截器后置触发");
        super.afterHandshake(request, response, wsHandler, ex);
    }
}

创建前端页面客户端

最后,创建一个简单的HTML页面,用于接收用户输入并显示实时聊天信息。

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Spring WebSocket Chat</title>
    <script src="https://code.jquery.com/jquery-3.6.4.min.js"></script>
    <script src="http://cdn.bootcss.com/sockjs-client/1.1.1/sockjs.js"></script>
</head>
<body>

请输入:<input type="text" id="message" placeholder="Type your message">
<button onclick="sendMessage()">Send</button>
<button onclick="websocketClose()">关闭连接</button>
<div id="chat"></div>

<script>
    var socket = null;
    if ('WebSocket' in window) {
    	// 后端服务port为22900
        socket = new WebSocket("ws://localhost:22900/websocket/server");
    } else if ('MozWebSocket' in window) {
        socket = new MozWebSocket("ws://localhost:22900/websocket/server");
    } else {
        socket = new SockJS("http://localhost:22900/sockjs/server");
    }

    // 接收消息触发
    socket.onmessage = function (event) {
        showMessage(event.data);
    };
    // 创建连接触发
    socket.onopen = function (event) {
        console.log(event.type);
    };
    // 连接异常触发
    socket.onerror = function (event) {
        console.log(event)
    };
    // 关闭连接触发
    socket.onclose = function (closeEvent) {
        console.log(closeEvent.reason);
    };

    //发送消息
    function sendMessage() {
        if (socket.readyState === socket.OPEN) {
            var message = document.getElementById('message').value;
            socket.send(message);
            console.log("发送成功!");
        } else {
            console.log("连接失败!");
        }

    }

    function showMessage(message) {
        document.getElementById('chat').innerHTML += '<p>' + message + '</p>';
    }

    function websocketClose() {
        socket.close();
        console.log("连接关闭");
    }

    window.close = function () {
        socket.onclose();
    };

</script>

</body>
</html>

这个页面使用了WebSocket对象来建立连接,并通过onmessage监听收到的消息。通过输入框发送消息,将会在页面上显示。

测试结果:

后端日志:

image

前端界面:

image

Java客户端

添加依赖

<dependency>
      <groupId>org.java-websocket</groupId>
      <artifactId>Java-WebSocket</artifactId>
      <version>1.4.0</version>
</dependency>

创建客户端类(继承WebsocketClient)

package com.river.websocket;
 
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
 
import java.net.URI;
import java.net.URISyntaxException;
 
public class MyWebSocketClient extends WebSocketClient {
 
    MyWebSocketClient(String url) throws URISyntaxException {
        super(new URI(url));
    }
 	// 建立连接
    @Override
    public void onOpen(ServerHandshake shake) {
        System.out.println(shake.getHttpStatusMessage());
    }
 	// 接收消息
    @Override
    public void onMessage(String paramString) {
        System.out.println(paramString);
    }
 	// 关闭连接
    @Override
    public void onClose(int paramInt, String paramString, boolean paramBoolean) {
        System.out.println("关闭");
    }
 	// 连接异常
    @Override
    public void onError(Exception e) {
        System.out.println("发生错误");
    }
}

测试websocket

package com.river.websocket;
 
import org.java_websocket.enums.ReadyState;
 
import java.net.URISyntaxException;
 
/**
 * @author lucky_fd
 * @date 2024-1-17
 */
public class Client {
    public static void main(String[] args) throws URISyntaxException, InterruptedException {
        MyWebSocketClient client = new MyWebSocketClient("ws://localhost:22900/websocket/server");
        client.connect();
        while (client.getReadyState() != ReadyState.OPEN) {
            System.out.println("连接状态:" + client.getReadyState());
            Thread.sleep(100);
        }
        client.send("测试数据!");
        client.close();
    }
}

参考链接:

Python Coroutine 池化实现

池化介绍

在当今计算机科学和软件工程的领域中,池化技术如线程池、连接池和对象池等已经成为优化资源利用率和提高软件性能的重要工具。然而,在 Python 的协程领域,我们却很少见到类似于 ThreadPoolExecutor 的 CoroutinePoolExecutor。为什么会这样呢?

首先,Python Coroutine 的特性使得池化技术在协程中的应用相对较少。与像 Golang 这样支持有栈协程的语言不同,Python Coroutine 是无栈的,无法跨核执行,从而限制了协程池发挥多核优势的可能性。

其次,Python Coroutine 的轻量级和快速创建销毁的特性,使得频繁创建和销毁协程并不会带来显著的性能损耗。这也解释了为什么 Python 官方一直没有引入 CoroutinePoolExecutor。

然而,作为开发者,我们仍然可以在特定场景下考虑协程的池化。虽然 Python Coroutine 轻量,但在一些需要大量协程协同工作的应用中,池化技术能够提供更方便、统一的调度子协程的方式。尤其是在涉及到异步操作的同时需要控制并发数量时,协程池的优势就显而易见了。

关于 Python 官方是否会在未来引入类似于 TaskGroup 的 CoroutinePoolExecutor,这或许是一个悬而未决的问题。考虑到 Python 在异步编程方面的快速发展,我们不能排除未来可能性的存在。或许有一天,我们会看到 TaskGroup 引入一个 max_workers 的形参,以更好地支持对协程池的需求。

在实际开发中,我们也可以尝试编写自己的 CoroutinePoolExecutor,以满足特定业务场景的需求。通过合理的设计架构和对数据流的全局考虑,我们可以最大程度地发挥协程池的优势,提高系统的性能和响应速度。

在接下来的文章中,我们将探讨如何设计和实现一个简单的 CoroutinePoolExecutor,以及在实际项目中的应用场景。通过深入理解协程池的工作原理,我们或许能更好地利用这一技术,使我们的异步应用更为高效。

如何开始编写

如何开始编写 CoroutinePoolExecutor,首先我们要明确出其适用范畴、考虑到使用方式和其潜在的风险点:

  • 它并不适用于 Mult Thread + Mult Event Loop 的场景,因此它并非线程安全的。
  • 应当保持和 ThreadPoolExecutor 相同的调用方式。
  • 不同于 Mult Thread 中子线程不依赖于主线程的运行,而在 Mult Coroutine 中子协程必须依赖于主协程,因此主协程在子协程没有全部运行完毕之前不能直接 done 掉。这也解释了为什么 TaskGroup 官方实现中没有提供类似于 shutdown 之类的方法,而是只提供上下文管理的运行方式。

有了上述 3 点的考量,我们决定将 ThreadPoolExecutor 平替成 CoroutinePoolExecutor。这样的好处在于,作为学习者一方面可以了解 ThreadPoolExecutor 的内部实现机制,另一方面站在巨人肩膀上的编程借鉴往往会事半功倍,对于自我的提升也是较为明显的。

在考虑这些因素的同时,我们将继续深入研究协程池的设计和实现。通过对适用范围和使用方式的明确,我们能更好地把握 CoroutinePoolExecutor 的潜在优势,为异步应用的性能提升做出更有针对性的贡献。

具体代码实现

在这里我先贴出完整的代码实现,其中着重点已经用注释标明。

以下是 CoroutinePoolExecutor 的代码实现:

import os
import asyncio
import weakref
import logging
import itertools


async def _worker(executor_reference: "CoroutinePoolExecutor", work_queue: asyncio.Queue):
    try:
        while True:
            work_item = await work_queue.get()

            if work_item is not None:
                await work_item.run()
                del work_item

                executor = executor_reference()
                if executor is not None:
                    # Notify available coroutines
                    executor._idle_semaphore.release()
                del executor
                continue

            # Notifies the next coroutine task that it is time to exit
            await work_queue.put(None)
            break

    except Exception as exc:
        logging.critical('Exception in worker', exc_info=True)


class _WorkItem:
    def __init__(self, future, coro):
        self.future = future
        self.coro = coro

    async def run(self):
        try:
            result = await self.coro
        except Exception as exc:
            self.future.set_exception(exc)
        else:
            self.future.set_result(result)


class CoroutinePoolExecutor:
    """
    Coroutine pool implemented based on ThreadPoolExecutor
    Different from ThreadPoolExecutor, because the running of sub-coroutine depends on the main coroutine
    So you must use the shutdown method to wait for all subtasks and wait for them to complete execution
    """

    # Used to assign unique thread names when coroutine_name_prefix is not supplied.
    _counter = itertools.count().__next__

    def __init__(self, max_workers, coroutine_name_prefix=""):

        if max_workers is None:
            max_workers = min(32, (os.cpu_count() or 1) + 4)
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")

        self._max_workers = max_workers
        self._work_queue = asyncio.Queue()
        self._idle_semaphore = asyncio.Semaphore(0)
        self._coroutines = set()
        self._shutdown = False
        self._shutdown_lock = asyncio.Lock()
        self._coroutine_name_prefix = (coroutine_name_prefix or (
            f"{__class__.__name__}-{self._counter()}"
        ))

    async def submit(self, coro):
        async with self._shutdown_lock:
            # When the executor is closed, new coroutine tasks should be rejected, otherwise it will cause the problem that the newly added tasks cannot be executed.
            # This is because after shutdown, all sub-coroutines will end their work
            # one after another. Even if there are new coroutine tasks, they will not
            # be reactivated.
            if self._shutdown:
                raise RuntimeError('cannot schedule new coroutine task after shutdown')

            f = asyncio.Future()
            w = _WorkItem(
                f,
                coro
            )
            await self._work_queue.put(w)
            await self._adjust_coroutine_count()
            return f

    async def _adjust_coroutine_count(self):

        try:
            # 2 functions:
            # - When there is an idle coroutine and the semaphore is not 0, there is no need to create a new sub-coroutine.
            # - Prevent exceptions from modifying self._coroutines members when the for loop self._coroutines and await task in shutdown are modified
            # Since the Semaphore provided by asyncio does not have a timeout
            # parameter, you can choose to use it with wait_for.
            if await asyncio.wait_for(
                    self._idle_semaphore.acquire(),
                    0
            ):
                return
        except TimeoutError:
            pass

        num_coroutines = len(self._coroutines)
        if num_coroutines < self._max_workers:
            coroutine_name = f"{self._coroutine_name_prefix or self}_{num_coroutines}"
            t = asyncio.create_task(
                coro=_worker(
                    weakref.ref(self),
                    self._work_queue
                ),
                name=coroutine_name
            )

            self._coroutines.add(t)

    async def shutdown(self, wait=True, *, cancel_futures=False):
        async with self._shutdown_lock:
            self._shutdown = True

            if cancel_futures:
                while True:
                    try:
                        work_item = self._work_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    if work_item is not None:
                        work_item.future.cancel()

            # None is an exit signal, given by the shutdown method, when the shutdown method is called
            # will notify the sub-coroutine to stop working and exit the loop
            await self._work_queue.put(None)

        if wait:
            for t in self._coroutines:
                await t

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.shutdown(wait=True)
        return False

以下是 CoroutinePoolExecutor 的使用方式:

import asyncio

from coroutinepoolexecutor import CoroutinePoolExecutor


async def task(i):
    await asyncio.sleep(1)
    print(f"task-{i}")


async def main():
    async with CoroutinePoolExecutor(2) as executor:
        for i in range(10):
            await executor.submit(task(i))

if __name__ == "__main__":
    asyncio.run(main())

我们知道,在线程池中,工作线程一旦创建会不断的领取新的任务并执行,除开 shutdown() 调用,否则对于静态的线程池来讲工作线程不会自己结束。

在上述协程池代码实现中,CoroutinePoolExecutor 类包含了主要的对外调用功能的接口、内部提供了存储 task 的 Queue、工作协程自动生成 name 的计数器、保障协程的信号量锁等等。

而 _worker 函数是工作协程的运行函数,其会在工作协程启动后,不断的从 CoroutinePoolExecutor 的 Queue 中得到 _WorkItem 并由 _WorkItem 具体执行 coro task。

剩下的 _WorkItem 是一个 future 对象与 coro task 的封装器,其功能是解耦 future 对象和 coro task、并在 coro task 运行时和运行后设置 future 的结果。

对于异步循环的思考

在此 CoroutinePoolExecutor 实现后,我其实又有了一个新的思考。Python 的 EventLoop 相较于 Node.js 的 EventLoop 来说其实更加的底层,它有感的暴露了出来。

具体体现在当 Python Event Loop 启动后,如果 main coroutine 停止运行,那么所有的 subtask coroutine 也会停止运行,尤其是对于一些需要清理资源的操作、如 aiohttp 的 close session、CoroutinePoolExecutor 的 shutdown 等都会在某些情况显得无措,说的更具体点就是不知道在什么时候调用。

对于这些问题,我们可以继承 BaseEventLoop 自己手动对 EventLoop 的功能进行扩展,如在事件循环关闭之前添加 hook function,甚至可以限制整个 EventLoop 的 max_workers 或者做成动态的可调节 coroutine 数量的 EventLoop 都行。

无论如何,只要心里有想法,就可以去将它实现 .. 学习本身就是一个不断挑战的过程。


引言


大家好,我是你们的老伙计秀才!今天带来的是[深入浅出Java多线程]系列的第二篇内容:Java多线程类和接口。大家觉得有用请点赞,喜欢请关注!秀才在此谢过大家了!!!

在现代计算机系统中,多线程技术是提升程序性能、优化资源利用和实现并发处理的重要手段。特别是在Java编程语言中,多线程机制被深度集成并广泛应用于高并发场景,如服务器响应多个客户端请求、大规模数据处理以及用户界面的实时更新等。理解并熟练掌握Java中的多线程创建与管理方式,不仅能帮助开发者充分利用硬件资源,还能有效避免竞态条件、死锁等并发问题,确保应用程序在多核处理器架构下运行得更为高效且稳定。

本文将深入探讨Java多线程编程的基本概念和技术细节。首先从最基础的Thread类入手,介绍如何通过继承Thread类或实现Runnable接口来定义并启动一个线程,强调start()方法对于激活线程执行的关键作用,并对比两种实现方式的优劣。同时,我们将揭开Thread类构造方法背后的秘密,详述各个参数的意义及初始化过程。

进一步地,文档将阐述Thread类中的一系列常用方法,包括获取当前线程引用的currentThread()方法、启动线程执行逻辑的start()方法、释放CPU时间片的yield()方法、控制线程暂停执行的sleep()方法,以及用于同步等待其他线程完成的join()方法。通过对这些方法的详细解读,读者能够更好地掌握Java线程间的协作和调度原理。

此外,为了满足异步任务执行和结果返回的需求,Java提供了Callable接口及其配套的Future和FutureTask类。Callable允许我们在新的线程中执行有返回值的任务,而Future作为异步计算的结果容器,可以用来查询任务是否完成、取消正在执行的任务以及获取计算结果。FutureTask则是对Future接口和Runnable接口功能的完美融合,它不仅封装了任务的执行逻辑,还提供了一种便捷的方式来管理和跟踪异步操作的状态。

综上所述,本文旨在引导逐步了解和掌握Java多线程编程的核心类与接口,并通过实际示例解析它们的工作机制和应用场景,为开发高性能、高并发的Java应用程序奠定坚实的基础。接下来的内容将逐一展开对上述关键知识点的详细讲解。


Java中创建与启动线程


Java中创建与启动线程(约800字)

在Java中,我们可以通过继承Thread类或实现Runnable接口来创建自定义的线程对象,并通过调用start()方法启动执行。这两种方式分别具有不同的应用场景和特点。


继承Thread类

通过直接继承Thread类并重写run()方法,可以便捷地创建一个具备特定任务逻辑的线程。以下是一个简单的示例:

public class MyCustomThread extends Thread {
    @Override
    public void run() {
        System.out.println("Inheriting from Thread class: " + Thread.currentThread().getName());
    }

    public static void main(String[] args) {
        MyCustomThread myThread = new MyCustomThread();
        myThread.start(); // 启动线程
    }
}

在这个例子中,
MyCustomThread
继承了Thread类并覆盖了run()方法,当调用start()方法时,JVM会为该线程分配资源并安排它在适当的时候执行run()方法中的代码。


注意

:每个线程只能调用一次start()方法。如果试图再次调用start(),将会抛出IllegalThreadStateException异常。这是因为一旦线程开始运行后,其生命周期已经进入执行阶段,不能重复初始化和启动。


实现Runnable接口

相较于继承Thread类,实现Runnable接口更为灵活,因为Java语言遵循单继承原则,而接口可以多重实现。这使得我们的类可以在继承其他类的同时实现多线程功能。以下是使用Runnable接口创建线程的示例:

public class RunnableTask implements Runnable {
    @Override
    public void run() {
        System.out.println("Implementing Runnable interface: " + Thread.currentThread().getName());
    }

    public static void main(String[] args) {
        RunnableTask task = new RunnableTask();
        Thread thread = new Thread(task, "MyRunnableThread");
        thread.start();
    }
}

在上述代码中,RunnableTask实现了Runnable接口并提供了run()方法的具体实现。然后,我们将RunnableTask实例传给Thread类的构造函数,创建了一个新的线程,并通过thread.start()来启动它。

此外,从Thread类的源码分析可知,Thread类是Runnable接口的一个实现类,其构造方法接收Runnable类型的参数target,并通过内部的init方法对其进行初始化。这样,无论我们是继承Thread还是实现Runnable,最终都是为了提供一个Runnable实例给Thread来执行具体的任务逻辑。

总结来说,Java提供了两种途径创建线程,各有优劣。继承Thread类的方式直观简洁,适用于轻量级的线程封装;而实现Runnable接口则更符合面向对象设计原则,避免了类层次结构的限制,提高了代码的可复用性和灵活性。在实际编程中,推荐优先考虑实现Runnable接口以保持代码结构清晰、易扩展。


Thread类构造方法详解

在Java中,Thread类的构造方法是创建线程对象并为其设置属性的核心途径。Thread类提供了多个构造函数以满足不同场景下的初始化需求,但它们最终都会调用到一个私有的
init
方法来完成线程对象的初始化。

// Thread类的部分源码片段:
private void init(ThreadGroup g, Runnable target, String name,
                        long stackSize, AccessControlContext acc,
                        boolean inheritThreadLocals)
 
{...}

public Thread(Runnable target) {
    init(null, target, "Thread-" + nextThreadNum(), 0);
}

上述代码揭示了Thread类的一个重要构造方法:接受一个Runnable类型的target参数,用于指定线程要执行的任务;同时为新创建的线程生成一个默认名称,并分配默认的栈大小。当通过这个构造器实例化Thread对象时,会调用内部的
init
方法进行详细的初始化操作:


  • g: ThreadGroup
    - 线程组,若不指定,默认值为null,表示线程将加入到当前应用程序的主要线程组中。

  • target: Runnable
    - 这个参数至关重要,它定义了线程执行体,即run()方法的具体内容。当我们实现Runnable接口或者继承Thread类重写run()方法时,实际就是给target赋值。

  • name: String
    - 指定线程的名字,如果没有提供名字,则系统会自动为其生成一个唯一的线程名。

  • stackSize: long
    - 栈的大小,通常情况下我们不会显式设置线程栈大小,这里使用默认值0,由JVM自行决定合适的栈空间大小。

  • acc: AccessControlContext
    - 安全控制上下文,用于控制线程执行权限,这是一个相对复杂且较少直接使用的概念,主要用于安全管理框架,如Java安全模型中的访问控制列表等。

  • inheritThreadLocals: boolean
    - 控制线程是否从父线程继承ThreadLocal变量。在多线程环境下,ThreadLocal可以为每个线程维护一个独立的变量副本,此处涉及到线程局部变量的传递问题。

此外,Thread类内部还包含了与ThreadLocal相关的两个私有属性
threadLocals

inheritableThreadLocals
,它们用于支持线程间的数据隔离以及特定情况下的线程本地变量继承。

总之,通过Thread类的构造方法,我们可以灵活地定制线程的各种属性,包括任务目标、线程名以及其他可能影响线程行为的因素。这些构造方法的设计充分体现了Java对线程管理的灵活性和可配置性。


Thread类常用方法


在Java中,Thread类提供了多种方法来管理和控制线程的生命周期及行为。以下将详细解析Thread类的几个核心方法,并对比使用Runnable接口与继承Thread类创建线程的方式。


Thread类的常用方法


  • currentThread()

    : 这是一个静态方法,返回对当前正在执行的线程对象的引用。例如:

    Thread currentThread = Thread.currentThread();
    System.out.println(currentThread.getName());

    通过这个方法可以获取当前线程信息并进行相应的操作。


  • start()

    : 用于启动一个线程,使其从新建状态进入就绪状态,然后等待操作系统调度执行。调用start()方法后,虚拟机内部会调用该线程的run()方法。多次调用start()会导致异常,因此确保只调用一次。


  • yield()

    : 表示当前线程愿意放弃CPU时间片,使其他同等优先级的线程有机会运行。但这并不是强制性的,实际调度结果取决于JVM和操作系统的实现。

    Thread t1 = new Thread(() -> {
        for (int i = 0; i < 5; i++) {
            Thread.yield();
            System.out.println(Thread.currentThread().getName() + ": " + i);
        }
    });
    t1.start();


  • sleep(long millis)

    : 让当前线程暂停指定毫秒数的时间,交出CPU使用权给其他线程。此方法会抛出InterruptedException,需妥善处理。


  • join()

    : 使当前线程等待另一个线程结束。当在一个线程上调用
    t.join()
    时,当前线程将被阻塞直到线程
    t
    完成其任务。

    Thread threadA = new Thread(() -> {
        // 执行耗时任务
    });
    Thread threadB = new Thread(() -> {
        try {
            threadA.join(); // 等待threadA执行完毕
            System.out.println("Thread B continues after A");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    });

    threadA.start();
    threadB.start();


Thread类与Runnable接口的比较

  • 继承Thread类的方式可以直接访问Thread类中的诸多方法,但受到Java单继承的限制,若需要扩展已有类则不适用。
  • 实现Runnable接口更符合面向对象原则,因为Runnable是接口,可以实现多继承,降低了线程对象和线程任务之间的耦合度。并且,采用Runnable时,可以通过灵活组合Thread类的各种构造方法来创建线程实例。

例如,考虑以下两个实现方式:

// 继承Thread类方式
public class MyThread extends Thread {
    @Override
    public void run() {
        // 线程执行逻辑
    }

    public static void main(String[] args) {
        MyThread myThread = new MyThread();
        myThread.start();
    }
}

// 实现Runnable接口方式
public class RunnableTask implements Runnable {
    @Override
    public void run() {
        // 线程执行逻辑
    }

    public static void main(String[] args) {
        RunnableTask task = new RunnableTask();
        Thread thread = new Thread(task, "MyRunnableThread");
        thread.start();
    }
}

总结来说,尽管两种方式都可以创建并启动线程,但在复杂应用中,由于其灵活性和设计原则上的优势,通常推荐优先采用实现Runnable接口的方式来定义线程任务。同时结合Thread类提供的各种方法,可以更好地控制线程的行为和状态。


异步模型与Future接口



异步模型与Future接口

在Java多线程编程中,为了支持执行有返回值的任务并获取其结果,JDK引入了Callable接口和Future接口。这两种接口为开发者提供了处理异步任务的强大工具。


Callable接口

Callable接口提供了一个call()方法,它具有返回类型并且可以抛出异常,这使得线程能够执行一个可能需要较长时间且有明确结果的计算任务。例如:

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class CallableExample {
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        // 创建可缓存线程池
        ExecutorService executor = Executors.newCachedThreadPool();

        // 自定义实现Callable接口的任务类
        Callable<Integer> task = new Callable<Integer>() {
            @Override
            public Integer call() throws Exception {
                Thread.sleep(1000); // 模拟耗时操作
                return 42// 返回计算结果
            }
        };

        // 提交任务到线程池并获取Future对象
        Future<Integer> futureResult = executor.submit(task);

        // 使用get方法阻塞等待结果,并打印出来
        System.out.println("Future result: " + futureResult.get());

        // 关闭线程池
        executor.shutdown();
    }
}

在这个例子中,我们创建了一个实现了Callable接口的匿名内部类,该类在call()方法中模拟了一个耗时计算过程,并返回一个整数值作为结果。通过ExecutorService提交这个任务后,得到一个Future对象,然后调用其get()方法来阻塞等待计算结果。


Future接口

Future接口代表了一个异步计算的结果,它提供了几个关键方法:

  • cancel(boolean mayInterruptIfRunning) :试图取消正在执行的任务,如果mayInterruptIfRunning为true,则会尝试中断线程。
  • isCancelled() :检查是否已经取消了此任务。
  • isDone() :判断任务是否已完成,无论正常结束还是被取消。
  • get() get(long timeout, TimeUnit unit) :阻塞等待直到计算完成或超时,然后获取计算结果。如果不希望无限期等待,可以选择带有超时参数的方法。

使用Future接口的一个重要优势在于,它可以让我们以同步或异步的方式控制任务的执行和结果的获取。同时,由于Future提供了取消任务的能力,因此相比Runnable更适合那些需要随时中止的任务场景。

此外,JDK还提供了FutureTask类,它是Future接口和Runnable接口的实现类,既可以作为一个Runnable对象交给Thread或者ExecutorService执行,又能持有并管理计算结果。通过FutureTask,我们可以更方便地进行异步计算以及状态跟踪。


FutureTask类与用途

FutureTask类在Java多线程编程中扮演着关键角色,它是对Runnable接口和Future接口的融合实现。作为一个可运行且具有未来结果的任务封装器,FutureTask可以将任务提交给线程执行,并通过Future接口提供对任务状态、取消操作以及获取计算结果的支持。

import java.util.concurrent.Callable;
import java.util.concurrent.FutureTask;

public class FutureTaskExample {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // 创建一个FutureTask实例,使用Callable实现的任务
        Callable<Integer> callable = () -> {
            Thread.sleep(1000); // 模拟耗时计算
            return 123// 返回计算结果
        };

        FutureTask<Integer> futureTask = new FutureTask<>(callable);

        // 将FutureTask作为Runnable对象提交到线程池或直接创建新线程启动
        Thread thread = new Thread(futureTask);
        thread.start();

        // 当前线程等待FutureTask完成并获取结果
        Integer result = futureTask.get();
        System.out.println("FutureTask计算结果: " + result);
    }
}

在上述示例中,我们首先定义了一个实现了Callable接口的任务对象,然后通过FutureTask来包装这个任务。FutureTask既可以直接传递给Thread对象使其成为一个可运行的任务,也可以提交给ExecutorService进行异步执行。当调用其get()方法时,当前线程会阻塞直到任务完成,之后返回计算得到的结果。

FutureTask的主要优势在于它为异步任务提供了生命周期管理功能,包括:


  • 任务调度
    :FutureTask可以被多个线程安全地调度,确保任务仅被执行一次。

  • 任务状态跟踪
    :内部维护了任务的状态机,可以通过isDone()等方法检查任务是否已完成或已被取消。

  • 结果获取
    :get()方法允许在任务完成后获取计算结果,支持阻塞等待和超时机制。

  • 任务取消
    :调用cancel方法可以尝试中断正在执行的任务,或者防止尚未开始的任务执行。

总之,FutureTask是Java并发框架中的重要组件,它结合了Runnable和Future的优点,使得异步任务的管理和控制更为灵活便捷,极大地提高了程序设计的效率和代码的可读性。



FutureTask的状态变迁


在Java多线程编程中,FutureTask类作为实现RunnableFuture接口的实例,不仅封装了任务执行逻辑,还负责管理任务状态。FutureTask内部维护了一个volatile的int型变量state来表示其生命周期中的不同状态。


  1. NEW
    :初始状态,表示FutureTask尚未开始执行。

  2. COMPLETING
    :瞬态状态,表示任务正在完成,即call()方法正在运行或者结果已经设置,等待后续的完成处理过程。

  3. NORMAL
    :正常结束状态,任务已成功执行并设置了结果。

  4. EXCEPTIONAL
    :异常结束状态,任务在执行过程中抛出了未捕获的异常,结果被设置为该异常对象。

  5. CANCELLED
    :取消状态,通过调用cancel方法且成功取消了任务,此时任务不会继续执行。

  6. INTERRUPTING
    :中断中状态,也是瞬态状态,表明正在进行取消操作,并尝试中断底层的任务执行线程。

  7. INTERRUPTED
    :已中断状态,意味着任务在取消过程中已被成功中断。

这些状态之间的转换路径如下:

  • NEW -> COMPLETING -> NORMAL 或 EXCEPTIONAL
  • NEW -> CANCELLED
  • NEW -> INTERRUPTING -> INTERRUPTED

FutureTask的设计确保了任务只执行一次,即使在并发环境下也能正确地管理状态变迁和结果返回。例如,在高并发场景下,如果有多个线程同时尝试启动一个FutureTask,它会保证仅有一个线程实际执行任务,其余线程等待结果。

以下是一个简单的FutureTask状态变迁的示例代码片段,但请注意,由于FutureTask内部对状态变更做了严格控制和同步处理,我们无法直接模拟所有状态变迁的过程:

public class FutureTaskStateExample {
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        FutureTask<Integer> futureTask = new FutureTask<>(new Callable<Integer>() {
            @Override
            public Integer call() throws Exception {
                Thread.sleep(1000); // 模拟耗时计算
                return 42// 正常返回结果
            }
        });

        Thread t = new Thread(futureTask);
        t.start();

        // 在任务执行期间尝试取消
        futureTask.cancel(true);

        // 判断任务是否已取消或已完成
        if (futureTask.isCancelled()) {
            System.out.println("任务已取消");
        } else if (futureTask.isDone()) {
            System.out.println("任务已完成,结果:" + futureTask.get());
        }

        // 根据实际情况,这里可能输出"任务已取消"或"任务已完成"
    }
}

这段代码创建了一个FutureTask实例并在新线程中执行。在任务执行过程中尝试取消,根据最终状态判断任务是已取消还是已完成。真实情况下,FutureTask会确保按照预定义的状态变迁规则进行切换。


总结


Java多线程编程提供了丰富的类与接口,便于开发者高效地创建、管理和控制线程。在实际应用中,我们可以通过以下几种方式来实现:


  • 继承Thread类或实现Runnable接口
    :前者通过重写run()方法定义线程任务;后者更符合面向对象原则且不受单继承限制,允许通过构造函数传递Runnable实例给Thread类以启动新线程。示例代码展示了如何通过这两种途径创建并运行线程。
// 继承Thread类
public class MyThread extends Thread {
    @Override
    public void run() {
        System.out.println("Thread running");
    }

    public static void main(String[] args) {
        MyThread myThread = new MyThread();
        myThread.start();
    }
}

// 实现Runnable接口
public class RunnableTask implements Runnable {
    @Override
    public void run() {
        System.out.println("Runnable Task running");
    }

    public static void main(String[] args) {
        RunnableTask task = new RunnableTask();
        Thread thread = new Thread(task);
        thread.start();
    }
}


  • 使用Future和Callable进行异步计算
    :当需要获取线程执行结果时,可以结合Callable和Future接口实现异步模型。FutureTask作为这两个接口的实现,兼顾了任务执行和结果返回的功能。例如:
import java.util.concurrent.*;

public class FutureTaskExample {
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        Callable<Integer> callable = () -> { return calculate(); };
        FutureTask<Integer> futureTask = new FutureTask<>(callable);

        Thread executor = new Thread(futureTask);
        executor.start();

        Integer result = futureTask.get(); // 阻塞等待结果
        System.out.println("FutureTask returned: " + result);
    }

    private static Integer calculate() {
        try {
            Thread.sleep(1000);
            return 42;
        } catch (InterruptedException e) {
            return -1;
        }
    }
}


  • FutureTask状态变迁
    :FutureTask内部维护了多种状态,如NEW、COMPLETING、NORMAL等,用于准确反映任务从初始化到完成或取消的全过程,确保并发环境下的正确性。

综上所述,在Java多线程编程中,通过灵活运用Thread、Runnable、Callable以及Future/FutureTask等工具,开发者能够更好地设计和管理复杂的并发场景,并利用异步编程提高系统性能与响应速度。深入理解这些类与接口的工作机制及应用场景,是构建高效稳定多线程应用程序的关键所在。同时,学习线程组、线程优先级等相关概念,将有助于进一步提升对Java多线程编程的全面掌控能力。

问题

最近跑师兄21年的论文代码,代码里使用了Pytorch分布式训练,在单机8卡的情况下,运行代码,出现如下问题。
image
也就是说GPU(1..7)上的进程占用了GPU0,这导致GPU0占的显存太多,以至于我的batchsize不能和原论文保持一致。

解决方法

我一点一点进行debug。
首先,在数据加载部分,由于没有将
local_rank

world_size
传入
get_cifar_iter
函数,导致后续使用DALI创建pipeline时使用了默认的
local_rank=0
,因此会在GPU0上多出该GPU下的进程
image

其次,在使用
torch.load
加载模型权重时,没有设置
map_location
,于是会默认加载到GPU0上,下图我选择将模型权重加载到cpu。虽然,这会使训练速度变慢,但为了和论文的batchsize保持一致也不得不这样做了。
-.-
image

参考文献

  1. nn.parallel.DistributedDataParallel多卡训练,第一张卡会多出进程?