2024年10月

大家好,我是每天分享AI应用的萤火君!

经常接触机器学习的同学可能都接触过Gradio这个框架,Gradio是一个基于Python的专门为机器学习项目创建的快速开发框架,可以让开发者快速发布自己的模型给用户测试,目前Huggingface上的机器学习项目都是基于Gradio对外提供服务的。

不过Gradio的目标是机器学习模型的快速演示,真正为用户提供服务时,我们还有很多需要关注的方面,比如用户的鉴权授权、消息通知、静态页面、SEO优化等等,这些使用Gradio有点捉襟见肘,我们还需要使用更加成熟的Web开发框架,比如Django这种。

但是我们初期可能已经用Gradio做了很多的功能,不想重写这些东西,这时候就产生了集成Gradio到其它框架的需求。这篇文章就来分享如何将Gradio集成到成熟的Web框架Django,以方便后来者。

创建Django项目

这里假设我们已经有了一个Gradio的项目,将在这个项目中继续创建一个Django项目。

创建 Django 项目

首先通过 pip 安装 Django

pip install django

然后在程序的根目录初始化Django项目
的一些基础文件:

django-admin startproject myproject
cd myproject

这里的 myproject 需要替换成你的 Django 项目名。

然后我们还要继续创建 Django 应用
,应用可以理解为模块,比如项目下有管理模块、用户模块、支付模块和具体的业务单元模块。每个应用都有自己的模型、视图、模板和 URL 路由。

python manage.py startapp myapp

请将myapp改为你的应用名称。

执行完这些命令之后,项目中将会增加一些Django的框架脚本。

创建 Django 页面

有了Django的基础脚本,然后就可以开发Web页面了。

1个页面涉及三个方面:视图、路由和HTML模板,还是以 myapp 为例:

在 myapp/views.py 中创建一个视图:

from django.shortcuts import render

def index(request):
    return render(request, 'index.html')

在 myapp/urls.py 中设置 URL 路由到这个视图:

from django.urls import path
from .views import index

urlpatterns = [
    path('', index, name='index'),
]

在 myapp/templates/index.html 创建 HTML 模板:

<!DOCTYPE html>
<html>
<head>
    <title>Gradio in Django</title>
</head>
<body>
    <h1>Welcome to My App</h1>
</body>
</html>

然后我们就可以启动程序,在浏览器访问这个页面了:

uvicorn myproject.wsgi:application --reload

启动程序使用的是 uvicorn工具,myproject是项目的名称,wsgi对应到myproject文件夹下的 wsgi.py。

集成Gradio到Django

准备一个Gradio项目

为了演示,这里准备一个Gradio的程序。

假设文件路径为:gradio/app.py

import gradio as gr

def greet(name):
    return f"Hello {name}!"

# 定义 Gradio 接口
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

整合 Gradio 和 Django

现在我们把 Gradio 集成到 Django 中,它们将在同一个进程中运行,对外使用一个端口号。Django 默认通过根目录 / 进行访问,Gradio则通过 /gradio 进行访问。

这里走过一些弯路,有问题的方法就不讲了,直接给出我的方案。

这里还要引入一个框架 FastAPI,我们将使用 FastAPI 来代理对 Gradio 和 Django 的访问,所以其实不是将Gradio集成到Django,这个方法本质上是将 Gradio 和 Django 整合到一起。

打开 myproject/wsgi.py,这是 Django 项目的主文件:

import os
from django.core.wsgi import get_wsgi_application
from fastapi import Request, Response
from starlette.middleware.wsgi import WSGIMiddleware
import gradio as gr
from gradio.app import demo

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'myproject.settings')

# 创建 FastAPI 应用
app = FastAPI()

# 挂载 Gradio 到FastAPI,注意这个path要和下边中间件中的一致
app = gr.mount_gradio_app(app, demo, path="/gradio")

# 获取 Django 的 WSGI 应用
django_app = get_wsgi_application()

# 注册一个FastAPI中间件,实现
@app.middleware("http")
async def route_middleware(request: Request, call_next):
   
    # 如果路径是 /gradio,则调用call_next,FastAPI框架会交给已经注册的 Gradio程序 处理
    if request.url.path.startswith("/gradio"):
        return await call_next(request)
    
    # 否则交给Django处理
    response = Response()
    
    async def send(message):
        if message['type'] == 'http.response.start':
            response.status_code = message['status']
            response.headers.update({k.decode(): v.decode() for k, v in message['headers']})
        elif message['type'] == 'http.response.body':
            response.body += message.get('body', b'')  # 注意这里用 += 来累积响应体
            
    await WSGIMiddleware(django_app)(request.scope, request.receive, send)
    
    response.headers["content-length"] = str(len(response.body))
    return response

这段代码的逻辑也比较简单,先创建FastAPI应用,然后将Gradio程序挂载到FastAPI,这里使用的是Gradio自带的mount_gradio_app方法,然后创建了一个FastAPI的中间件,对不同的路由使用不同的处理。

重点就在这个FastAPI中间件,它可以保证通过 /gradio 访问到Gradio程序,通过 / 访问到 Django 程序。

如果我们使用下面的这种方式来代理 Django,实测将不能通过 /gradio 访问到Gradio程序,无论 Gradio 和 Django 谁先注册。如果你的环境可以,欢迎留下你的各个 package 的版本。

app.mount("/", WSGIMiddleware(django_app))

静态文件的访问

因为静态文件是每个Web程序几乎避不开的,比如图片、css、js等,所以这里特别提下。

在上边的路由中间件中,除了 /gradio 会路由到Gradio程序,其它都会走Django进行处理,静态文件也不例外。

这里假设静态文件放在 static 目录下。

打开 myproject/settings.py,这是 Django 项目的基础设置文件,修改其中静态文件的部分:

STATIC_URL = '/static/'
if DEBUG:
    STATICFILES_DIRS = [
        os.path.join(BASE_DIR, "static"),
    ]
else:
    STATIC_ROOT = os.path.join(BASE_DIR, 'static')

打开 myproject/urls.py,修改其中的路由定义,增加 re_path 这一行。

urlpatterns = [
    re_path('^static/(?P<path>.*)', serve, {'document_root': settings.STATIC_ROOT}),
    path('', include('myapp.urls')),  # 包含 myapp 的 URL 配置
]

这样可以在调测和生产环境都能正常访问 static 目录下的静态文件,而不用再进行不同的设置。

总结

本文分享了一种整合 Gradio 和 Django 程序的方法,在这种方法下,Gradio 和 Django 可以使用同一个进程,使用相同的端口号对外服务,同时Gradio程序使用子目录 /gradio 进行访问,Django 程序使用根目录 / 进行访问。

因本人对 Django 和 Gradio 的了解有限,文中介绍的方法可能存在瑕疵,请谨慎使用。

关注萤火架构,加速技术提升!

这段时间比较热门的莫过于华为推出的自主研发的面向全场景分布式操作系统
HarmonyOS

https://developer.huawei.com/

最新一直潜心学习鸿蒙os开发,于是基于
HarmonyOS NEXT 5.0 API12 Release
开发了一款自定义多功能导航条组件。

HMNavBar组件支持
自定义返回键、标题/副标题、标题居中、背景色/背景图片/背景渐变色、标题颜色、搜索、右侧操作区
等功能。

如下图:组件结构非常简单。

组件参数配置

@Component
export struct HMNavBar {
//是否隐藏左侧的返回键 @Prop hideBackButton: boolean //标题(支持字符串|自定义组件) @BuilderParam title: ResourceStr | CustomBuilder =BuilderFunction//副标题 @BuilderParam subtitle: ResourceStr | CustomBuilder =BuilderFunction//返回键图标 @Prop backButtonIcon: Resource | undefined = $r('sys.symbol.chevron_left')//返回键标题 @Prop backButtonTitle: ResourceStr//背景色 @Prop bgColor: ResourceColor = $r('sys.color.background_primary')//渐变背景色 @Prop bgLinearGradient: LinearGradient//图片背景 @Prop bgImage: ResourceStr |PixelMap//标题颜色 @Prop fontColor: ResourceColor//标题是否居中 @Prop centerTitle: boolean //右侧按钮区域 @BuilderParam actions: Array<ActionMenuItem> | CustomBuilder =BuilderFunction//导航条高度 @Prop navbarHeight: number = 56 //... }

调用方式非常简单易上手。

  • 基础用法
HMNavBar({
backButtonIcon: $r(
'sys.symbol.arrow_left'),
title:
'鸿蒙自定义导航栏',
subtitle:
'HarmonyOS Next 5.0自定义导航栏',
})
  • 自定义返回图标/文字、标题、背景色、右键操作按钮

@Builder customImgTitle() {
Image(
'https://developer.huawei.com/allianceCmsResource/resource/HUAWEI_Developer_VUE/images/logo-developer-header.svg').height(24).objectFit(ImageFit.Contain)
}

HMNavBar({
backButtonIcon: $r(
'sys.symbol.chevron_left'),
backButtonTitle:
'返回',
title:
this.customImgTitle,
subtitle:
'鸿蒙5.0 api 12',
centerTitle:
true,
bgColor:
'#f3f6f9',
fontColor:
'#0a59f7',
actions: [
{
icon: $r(
'app.media.app_icon'),
action: ()
=> promptAction.showToast({ message: "show toast index 1"})
},
{
//icon: $r('sys.symbol.plus'), label: '更多>',
color:
'#bd43ff',
action: ()
=> promptAction.showToast({ message: "show toast index 2"})
}
]
})
//自定义渐变背景、背景图片、右侧操作区
HMNavBar({
hideBackButton:
true,
title:
'HarmonyOS',
subtitle:
'harmonyos next 5.0 api 12',
bgLinearGradient: {
angle:
135,
colors: [[
'#42d392 ',0.2], ['#647eff',1]]
},
//bgImage: 'pages/assets/nav_bg.png', //bgImage: 'https://developer.huawei.com/allianceCmsResource/resource/HUAWEI_Developer_VUE/images/1025-pc-banner.jpeg', fontColor: '#fff',
actions: [
{
icon:
'https://developer.huawei.com/allianceCmsResource/resource/HUAWEI_Developer_VUE/images/yuanfuwuicon.png',
action: ()
=> promptAction.showToast({ message: "show toast index 1"})
},
{
icon:
'https://developer.huawei.com/allianceCmsResource/resource/HUAWEI_Developer_VUE/images/0620logo4.png',
action: ()
=> promptAction.showToast({ message: "show toast index 2"})
},
{
icon: $r(
'sys.symbol.person_crop_circle_fill_1'),
action: ()
=> promptAction.showToast({ message: "show toast index 3"})
}
],
navbarHeight:
70})

如上图:还支持自定义
导航搜索
功能。

HMNavBar({
title:
this.customSearchTitle1,
actions:
this.customSearchAction
})

HMNavBar({
hideBackButton:
true,
title:
this.customSearchTitle2,
bgColor:
'#0051ff',
})

HMNavBar({
backButtonIcon: $r(
'sys.symbol.arrow_left'),
backButtonTitle:
'鸿蒙',
title:
this.customSearchTitle3,
bgColor:
'#c543fd',
fontColor:
'#fff',
actions: [
{
icon: $r(
'sys.symbol.mic_fill'),
action: ()
=>promptAction.showToast({ ... })
}
]
})

HMNavBar导航组件布局结构如下。

build() {
Flex() {
//左侧模块 Stack({ alignContent: Alignment.Start }) {
Flex(){
if(!this.hideBackButton) {this.backBuilder()
}
if(!this.centerTitle) {this.contentBuilder()
}
}
.height(
'100%')
}
.height(
'100%')
.layoutWeight(
1)//中间模块 if(this.centerTitle) {
Stack() {
this.contentBuilder()
}
.height(
'100%')
.layoutWeight(
1)
}
//右键操作模块 Stack({ alignContent: Alignment.End }) {this.actionBuilder()
}
.padding({ right:
16})
.height(
'100%')
.layoutWeight(
this.centerTitle ? 1 : 0)
}
.backgroundColor(
this.bgColor)
.linearGradient(
this.bgLinearGradient)
.backgroundImage(
this.bgImage, ImageRepeat.NoRepeat).backgroundImageSize(ImageSize.FILL)
.height(
this.navbarHeight)
.width(
'100%')
}

支持悬浮在背景图上面。

最后,附上几个最新研发的跨平台实例项目。

Tauri2.0+Vite5聊天室|vue3+tauri2+element-plus仿微信|tauri聊天应用

tauri2.0-admin桌面端后台系统|tauri2+vite5+element-plus管理后台EXE程序

Electron32-ViteOS桌面版os系统|vue3+electron+arco客户端OS管理模板

Vite5+Electron聊天室|electron31跨平台仿微信EXE客户端|vue3聊天程序

Ok,今天的分享就到这里,希望以上的分享对大家有所帮助!

---  好的方法很多,我们先掌握一种  ---

【背景】

对于网页信息的采集,静态页面我们通常都可以通过python的request.get()库就能获取到整个页面的信息。

但是对于动态生成的网页信息来说,我们通过request.get()是获取不到。

【方法】

可以通过python第三方库selenium来配合实现信息获取,采取方案:python + request + selenium + BeautifulSoup

我们拿纵横中文网的小说采集举例(注意:请查看网站的robots协议找到可以爬取的内容,所谓盗亦有道):

思路整理:

1.通过selenium 定位元素的方式找到小说章节信息

2.通过BeautifulSoup加工后提取章节标题和对应的各章节的链接信息

3.通过request +BeautifulSoup 按章节链接提取小说内容,并将内容存储下来

【上代码】

1.先在开发者工具中,调试定位所需元素对应的xpath命令编写方式

2.通过selenium 中find_elements()定位元素的方式找到所有小说章节,我们这里定义一个方法接受参数来使用

def Get_novel_chapters_info(url:str,xpath:str,skip_num=None,chapters_num=None):#skip_num 需要跳过的采集章节(默认不跳过),chapters_num需要采集的章节数(默认全部章节)
        #创建Chrome选项(禁用图形界面)
        chrome_options =Options()
chrome_options.add_argument(
"--headless")
driver
= webdriver.Chrome(options=chrome_options)
driver.get(url)
driver.maximize_window()
time.sleep(
3)#采集小说的章节元素 catalogues_list =[]try:
catalogues
=driver.find_elements(By.XPATH,xpath)if skip_num isNone:for catalogue incatalogues:
catalogues_list.append(catalogue.get_attribute(
'outerHTML'))
driver.quit()
if chapters_num isNone:returncatalogues_listelse:returncatalogues_list[:chapters_num]else:for catalogue incatalogues[skip_num:]:
catalogues_list.append(catalogue.get_attribute(
'outerHTML'))
driver.quit()
if chapters_num isNone:returncatalogues_listelse:returncatalogues_list[:chapters_num]exceptException:
driver.quit()

3.把采集到的信息通过beautifulsoup加工后,提取章节标题和链接内容

        #获取章节标题和对应的链接信息
        title_link ={}for each incatalogues_list:
bs
= BeautifulSoup(each,'html.parser')
chapter
= bs.find('a')
title
=chapter.text
link
= 'https:' + chapter.get('href')
title_link[title]
= link

4.通过request+BeautifulSoup 按章节链接提取小说内容,并保存到一个文件中

        #按章节保存小说内容
        novel_path = '小说存放的路径/小说名称.txt'with open(novel_path,'a') as f:for title,url intitle_link.items():
response
= requests.get(url,headers={'user-agent':'Mozilla/5.0'})
html
= response.content.decode('utf-8')
soup
= BeautifulSoup(html,'html.parser')
content
= soup.find('div',class_='content').text#先写章节标题,再写小说内容 f.write('---小西瓜免费小说---' + '\n'*2)
f.write(title
+ '\n')
f.write(content
+'\n'*3)
                

Java 的 IO(输入/输出)操作是处理数据流的关键部分,涉及到文件、网络等多种数据源。以下将深入探讨 Java IO 的不同类型、底层实现原理、使用场景以及性能优化策略。

1. Java IO 的分类

Java IO 包括两大主要包:
java.io

java.nio

1.1 java.io 包

  • 字节流:用于处理二进制数据,主要有 InputStream 和 OutputStream,如
    FileInputStream

    FileOutputStream
  • 字符流:用于处理字符数据,主要有 Reader 和 Writer,如
    FileReader

    FileWriter

示例代码

// 字节流示例
try (FileInputStream fis = new FileInputStream("input.txt");
     FileOutputStream fos = new FileOutputStream("output.txt")) {
    int byteData;
    while ((byteData = fis.read()) != -1) {
        fos.write(byteData);
    }
}

// 字符流示例
try (FileReader fr = new FileReader("input.txt");
     FileWriter fw = new FileWriter("output.txt")) {
    int charData;
    while ((charData = fr.read()) != -1) {
        fw.write(charData);
    }
}

1.2 java.nio包

  • 通道和缓冲区:NIO 引入了通道(Channel)和缓冲区(Buffer)的概念,支持非阻塞 IO 和选择器(Selector)。如
    FileChannel

    ByteBuffer

示例代码

try (FileChannel fileChannel = new FileInputStream("input.txt").getChannel()) {
    ByteBuffer buffer = ByteBuffer.allocate(1024);
    while (fileChannel.read(buffer) > 0) {
        buffer.flip(); // 切换读模式
        while (buffer.hasRemaining()) {
            System.out.print((char) buffer.get());
        }
        buffer.clear(); // 清空缓冲区
    }
}

2. Java IO 的设计考虑

2.1 面向流的抽象

Java IO 的核心在于“流”的概念。流允许程序以统一的方式处理数据,无论数据来自文件、网络还是其他源。流的抽象设计使得开发者能够轻松地进行数据读写操作。

  • 输入流与输出流

    InputStream

    OutputStream
    是所有字节流的超类,而
    Reader

    Writer
    则是字符流的超类。这样的设计确保了所有流都有统一的接口,使得代码可读性和可维护性增强。
  • 流的链式调用
    :通过使用装饰器模式,开发者可以将多个流组合在一起,例如将
    BufferedInputStream
    包装在
    FileInputStream
    外部,增加缓冲功能。

2.2 装饰器模式

Java IO 大量使用装饰器模式来增强流的功能。例如:

  • 缓冲流

    BufferedInputStream

    BufferedOutputStream
    可以提高读取和写入的效率,减少对底层系统调用的频繁访问。
  • 数据流

    DataInputStream

    DataOutputStream
    允许以原始 Java 数据类型读写数据,提供了一种简单的方式来处理二进制数据。

3. 底层原理

3.1 字节流与字符流的实现

  • 字节流的实现
    :Java 字节流通过
    FileDescriptor
    直接与操作系统的文件描述符交互。每当你调用
    read()

    write()
    方法时,Java 实际上是在调用系统级别的 IO 操作。这涉及用户态和内核态的切换,可能会导致性能下降。
  • 字符流的实现
    :字符流需要在底层进行字符编码和解码。
    InputStreamReader

    OutputStreamWriter
    是将字节转换为字符的桥梁。Java 使用不同的编码(如 UTF-8、UTF-16 等)来处理不同语言的字符,确保在全球范围内的兼容性。

3.2 NIO 的底层实现

  • 通道(Channel)
    :NIO 的
    Channel
    是双向的,允许同时读写。它直接与操作系统的 IO 操作交互,底层依赖于文件描述符。在高性能应用中,通道能够有效地传输数据。
  • 缓冲区(Buffer)
    :NIO 的
    Buffer
    是一个连续的内存区域,提供了读写操作的基本单元。缓冲区的实现底层使用 Java 的数组,但增加了指针管理(position、limit 和 capacity)以优化数据传输。
  • 选择器(Selector)
    :Selector 是 NIO 的核心组件之一,它允许单个线程监控多个通道的事件。底层依赖于操作系统提供的高效事件通知机制(如 Linux 的
    epoll
    和 BSD 的
    kqueue
    ),使得处理成千上万的并发连接成为可能。

4. 使用场景

4.1 文件处理

  • 大文件读取
    :在处理大文件时,NIO 的
    FileChannel

    ByteBuffer
    可以有效地减少内存使用和提高读写速度。例如,使用映射文件(Memory-Mapped Files)可以将文件直接映射到内存,从而实现高效的数据访问。
try (FileChannel fileChannel = FileChannel.open(Paths.get("largefile.txt"), StandardOpenOption.READ)) {
    MappedByteBuffer mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
    // 直接在内存中处理数据
}

4.2 网络编程

  • 高并发服务器
    :在高并发场景下,使用 NIO 的非阻塞 IO 模型可以显著提高性能。例如,构建一个聊天服务器时,使用选择器能够处理大量的用户连接而不占用过多线程资源。
Selector selector = Selector.open();
ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.configureBlocking(false);
serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);

4.3 数据流处理

  • 对象序列化与反序列化
    :在分布式系统中,使用
    ObjectInputStream

    ObjectOutputStream
    可以方便地进行对象的传输。这在 RMI 和其他需要对象共享的场景中非常常见。

5. 常见问题

5.1 IO 阻塞

传统的
java.io
操作是阻塞的,当 IO 操作未完成时,线程会被阻塞。这可能导致性能瓶颈,尤其在高并发情况下。

解决方案
:使用 NIO 的非阻塞 IO,结合选择器,可以让线程在等待 IO 操作时处理其他任务,从而提高吞吐量。

5.2 资源泄露

未正确关闭流会导致资源泄露,尤其在频繁的 IO 操作中,长时间未释放资源可能导致内存和文件句柄的耗尽。

解决方案
:使用
try-with-resources
语句自动管理流的生命周期,确保资源被及时释放。

try (BufferedReader br = new BufferedReader(new FileReader("file.txt"))) {
    // 读取文件
}

5.3 性能瓶颈

在小文件或频繁 IO 操作时,每次系统调用都可能导致性能开销。

解决方案
:使用缓冲流,减少对底层系统的直接调用。对于大量小文件的操作,可以将多个文件合并成一个大文件进行处理。

6. 性能优化

  • 使用缓冲流
    :通过使用
    BufferedInputStream

    BufferedOutputStream
    ,可以有效减少系统调用的次数。
  • 异步 IO
    :对于需要高性能的应用,考虑使用异步 IO(如 Java 7 的
    AsynchronousFileChannel

    AsynchronousSocketChannel
    ),可以进一步提高并发性能。
  • 优化对象序列化
    :在序列化过程中,避免使用
    ObjectInputStream

    ObjectOutputStream
    的默认实现,可以考虑使用更高效的序列化库(如 Kryo、Protobuf)来降低序列化和反序列化的开销。

# 借用Ultralytics Yolo快速训练一个物体检测器

[同步发表于 https://www.codebonobo.tech/post/14](https://www.codebonobo.tech/post/14 "https://www.codebonobo.tech/post/14")

大约在16/17年, 深度学习刚刚流行时, Object Detection 还是相当高端的技术, 各大高校还很流行水Fast RCNN / Faster RCNN之类的论文, 干着安全帽/行人/车辆检测之类的横项. 但到了2024年, 随着技术成熟, 物体检测几乎已经是个死方向了, 现在的学校应该忙着把别人好不容易训练的通用大模型退化成各领域的专用大模型吧...哈哈

回到主题, 目前物体检测模型的训练已经流程化了, 不再需要费脑去写训练代码, 使用Ultralytics Yolo就可以快速训练一个特定物品的检测器, 比如安全帽检测.

[https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)

![](https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png)

# Step-1 准备数据集
你需要一些待检测物体比如安全帽, 把它从各个角度拍摄一下. 再找一些不相关的背景图片. 然后把安全帽给放大缩小旋转等等贴到背景图片上去, 生成一堆训练数据.

配置文件:

```python
extract_cfg:
output_dir: '/datasets/images'
fps: 0.25

screen_images_path: '/datasets/待检测图片'
max_scale: 1.0
min_scale: 0.1
manual_scale: [ {name: 'logo', min_scale: 0.05, max_scale: 0.3},
{name: 'logo', min_scale: 0.1, max_scale: 0.5},
{name: '箭头', min_scale: 0.1, max_scale: 0.5}
]
data_cfgs: [ {id: 0, name: 'logo', min_scale: 0.05, max_scale: 0.3, gen_num: 2},
{id: 1, name: '截屏', min_scale: 0.1, max_scale: 1.0, gen_num: 3, need_full_screen: true},
{id: 2, name: '红包', min_scale: 0.1, max_scale: 0.5, gen_num: 2},
{id: 3, name: '箭头', min_scale: 0.1, max_scale: 0.5, gen_num: 2, rotate_aug: true},
]
save_oss_dir: /datasets/gen_datasets/
gen_num_per_image: 2
max_bg_img_sample:
```

数据集生成:

```python
from pathlib import Path
import io
import random

import cv2
import numpy as np
from PIL import Image
import hydra
from omegaconf import DictConfig
import json
from tqdm import tqdm

# 加载图片
def load_images(background_path, overlay_path):
background = cv2.imread(background_path)
overlay = cv2.imread(overlay_path, cv2.IMREAD_UNCHANGED)
return background, overlay

# 随机缩放和位置
def random_scale_and_position(bg_shape, overlay_shape, max_scale=1.0, min_scale=0.1):
max_height, max_width = bg_shape[:2]
overlay_height, overlay_width = overlay_shape[:2]

base_scale = min(max_height / overlay_height, max_width / overlay_width)

# 随机缩放
scale_factor = random.uniform(
min_scale * base_scale, max_scale * base_scale)
new_height, new_width = int(
overlay_height * scale_factor), int(overlay_width * scale_factor)

# 随机位置
max_x = max_width - new_width - 1
max_y = max_height - new_height - 1
position_x = random.randint(0, max_x)
position_y = random.randint(0, max_y)

return scale_factor, (position_x, position_y)

def get_resized_overlay(overlay, scale):
overlay_resized = cv2.resize(overlay, (0, 0), fx=scale, fy=scale)
return overlay_resized

def rotate_image(img, angle):
if isinstance(img, np.ndarray):
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))

# 确保图像具有alpha通道(透明度)
img = img.convert("RGBA")
# 旋转原始图像并粘贴到新的透明图像框架中
rotated_img = img.rotate(angle, resample=Image.BICUBIC, expand=True)
rotated_img = np.asarray(rotated_img)
return cv2.cvtColor(rotated_img, cv2.COLOR_RGBA2BGRA)

# 合成图片
def overlay_image(background, overlay_resized, position, scale):
h, w = overlay_resized.shape[:2]
x, y = position

# 透明度处理
alpha_s = overlay_resized[:, :, 3] / 255.0
alpha_l = 1.0 - alpha_s

for c in range(0, 3):
background[y:y + h, x:x + w, c] = (alpha_s * overlay_resized[:, :, c] +
alpha_l * background[y:y + h, x:x + w, c])

# 画出位置,调试使用
# print("position", x, y, w, h)
# cv2.rectangle(background, (x, y), (x + w, y + h), (0, 255, 0), 2)

background = cv2.cvtColor(background, cv2.COLOR_BGR2RGB)

return Image.fromarray(background)

class Box:
def __init__(self, x, y, width, height, category_id, image_width, image_height):
self.x = x
self.y = y
self.width = width
self.height = height
self.image_width = image_width
self.image_height = image_height
self.category_id = category_id

def to_yolo_format(self):
x_center = (self.x + self.width / 2) / self.image_width
y_center = (self.y + self.height / 2) / self.image_height
width = self.width / self.image_width
height = self.height / self.image_height
box_line = f"{self.category_id} {x_center} {y_center} {width} {height}"
return box_line

class SingleCategoryGen:
def __init__(self, cfg, data_cfg, output_dir):
self.output_dir = output_dir
self.screen_png_images = []
self.coco_images = []
self.coco_annotations = []
screen_images_path = Path(
cfg.screen_images_path.format(user_root=user_root))

self.manual_scale = {}

self.data_cfg = data_cfg
self.category_id = data_cfg.id
self.category_name = self.data_cfg.name
self.max_scale = self.data_cfg.max_scale
self.min_scale = self.data_cfg.min_scale
self.gen_num = self.data_cfg.gen_num
self.rotate_aug = self.data_cfg.get("rotate_aug", False)
self.need_full_screen = self.data_cfg.get("need_full_screen", False)

self.category_num = 0
self.category_names = {}

self.butcket = get_oss_bucket(cfg.bucket_name)
output_dir = Path(output_dir)
save_oss_dir = f"{cfg.save_oss_dir}/{output_dir.parent.name}/{output_dir.name}"
self.save_oss_dir = save_oss_dir
self.images_save_oss_dir = f"{save_oss_dir}/images"
self.label_save_oss_dir = f"{save_oss_dir}/labels"
self.annotations_save_oss_path = f"{save_oss_dir}/annotations.json"

self.load_screen_png_images_and_category(screen_images_path)

def load_screen_png_images_and_category(self, screen_images_dir):
screen_images_dir = Path(screen_images_dir)
category_id = self.category_id
screen_images_path = screen_images_dir / self.category_name
img_files = [p for p in screen_images_path.iterdir() if p.suffix in [
".png", ".jpg"]]
img_files.sort(key=lambda x: x.stem)
for i, img_file in enumerate(img_files):
self.screen_png_images.append(
dict(id=i, name=img_file.stem, supercategory=None, path=str(img_file)))

def add_new_images(self, bg_img_path: Path, gen_image_num=None, subset="train"):
gen_image_num = gen_image_num or self.gen_num
background_origin = cv2.imread(str(bg_img_path))
if background_origin is None:
print(f"open image {bg_img_path} failed")
return
max_box_num = 1

for gen_id in range(gen_image_num):
background = background_origin.copy()
category_id = self.category_id
overlay_img_path = self.sample_category_data()

overlay = cv2.imread(overlay_img_path, cv2.IMREAD_UNCHANGED)
if overlay.shape[2] == 3:
overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2BGRA)

if self.rotate_aug:
overlay = rotate_image(overlay, random.uniform(-180, 180))

# # 随机裁剪图片
# if random.random() < 0.5:
# origin_height = overlay.shape[0]
# min_height = origin_height // 4
# new_height = random.randint(min_height, origin_height)
# new_top = random.randint(0, origin_height - new_height)
# overlay = overlay[new_top:new_top+new_height, :, :]

box_num = random.randint(1, max_box_num)
# 获取随机缩放和位置
max_scale = self.max_scale
min_scale = self.min_scale

scale, position = random_scale_and_position(
background.shape, overlay.shape, max_scale, min_scale)

# 缩放overlay图片
overlay_resized = get_resized_overlay(overlay, scale)

# 合成后的图片
merged_img = overlay_image(background, overlay_resized, position, scale)

# 保存合成后的图片
filename = f"{bg_img_path.stem}_{category_id}_{gen_id:02d}.png"

merged_img.save(f'{output_dir}/{filename}')

# 生成COCO格式的标注数据
box = Box(*position, overlay_resized.shape[1], overlay_resized.shape[0], category_id, background.shape[1],
background.shape[0])
self.upload_image_to_oss(merged_img, filename, subset, [box])

def sample_category_data(self):
return random.choice(self.screen_png_images)["path"]

image_id = self.gen_image_id()

image_json = {
"id": image_id,
"width": image.width,
"height": image.height,
"file_name": image_name,
}
self.coco_images.append(image_json)

annotation_json = {
"id": image_id,
"image_id": image_id,
"category_id": 0,
"segmentation": None,
"area": bbox[2] * bbox[3],
"bbox": bbox,
"iscrowd": 0
}
self.coco_annotations.append(annotation_json)

def upload_image_to_oss(self, image, image_name, subset, box_list=None):
image_bytesio = io.BytesIO()
image.save(image_bytesio, format="PNG")
self.butcket.put_object(
f"{self.images_save_oss_dir}/{subset}/{image_name}", image_bytesio.getvalue())
if box_list:
label_str = "\n".join([box.to_yolo_format() for box in box_list])
label_name = image_name.split(".")[0] + ".txt"
self.butcket.put_object(
f"{self.label_save_oss_dir}/{subset}/{label_name}", label_str)

def upload_full_screen_image(self):
if not self.need_full_screen:
return
name = self.category_name
category_id = self.category_id
image_list = self.screen_png_images
subset_list = ["train" if i % 10 <= 7 else "val" if i %
10 <= 8 else "test" for i in range(len(image_list))]
for i in range(len(image_list)):
image_data = image_list[i]
subset = subset_list[i]
overlay_img_path = image_data["path"]
image = Image.open(overlay_img_path)
if random.random() < 0.5:
origin_height = image.height
min_height = origin_height // 4
new_height = random.randint(min_height, origin_height)
new_top = random.randint(0, origin_height - new_height)
image = image.crop(
(0, new_top, image.width, new_top + new_height))
filename = f"{name}_{category_id}_{i:05}.png"
box = Box(0, 0, image.width, image.height,
category_id, image.width, image.height)
self.upload_image_to_oss(image, filename, subset, [box])

class ScreenDatasetGen:
def __init__(self, cfg, output_dir):
self.output_dir = output_dir
self.screen_png_images = {}
self.coco_images = []
self.coco_annotations = []
screen_images_path = Path(
cfg.screen_images_path.format(user_root=user_root))
self.max_scale = cfg.max_scale
self.min_scale = cfg.min_scale
self.manual_scale = {}
for info in cfg.manual_scale:
self.manual_scale[info.name] = dict(
max_scale=info.max_scale, min_scale=info.min_scale)
self.category_num = 0
self.category_names = {}
self.category_id_loop = -1

self.butcket = get_oss_bucket(cfg.bucket_name)
output_dir = Path(output_dir)
save_oss_dir = f"{cfg.save_oss_dir}/{output_dir.parent.name}/{output_dir.name}"
self.save_oss_dir = save_oss_dir
self.images_save_oss_dir = f"{save_oss_dir}/images"
self.label_save_oss_dir = f"{save_oss_dir}/labels"
self.annotations_save_oss_path = f"{save_oss_dir}/annotations.json"

self.load_screen_png_images_and_category(screen_images_path)

def add_new_images(self, bg_img_path: Path, gen_image_num=1, subset="train"):
background_origin = cv2.imread(str(bg_img_path))
if background_origin is None:
print(f"open image {bg_img_path} failed")
return
max_box_num = 1

for gen_id in range(gen_image_num):
background = background_origin.copy()
category_id = self.get_category_id_loop()
overlay_img_path = self.sample_category_data(
category_id, subset=subset)

overlay = cv2.imread(overlay_img_path, cv2.IMREAD_UNCHANGED)
if overlay.shape[2] == 3:
overlay = cv2.cvtColor(overlay, cv2.COLOR_BGR2BGRA)

# # 随机裁剪图片
# if random.random() < 0.5:
# origin_height = overlay.shape[0]
# min_height = origin_height // 4
# new_height = random.randint(min_height, origin_height)
# new_top = random.randint(0, origin_height - new_height)
# overlay = overlay[new_top:new_top+new_height, :, :]

box_num = random.randint(1, max_box_num)
# 获取随机缩放和位置
category_name = self.category_names[category_id]
if category_name in self.manual_scale:
max_scale = self.manual_scale[category_name]["max_scale"]
min_scale = self.manual_scale[category_name]["min_scale"]
else:
max_scale = self.max_scale
min_scale = self.min_scale
scale, position = random_scale_and_position(
background.shape, overlay.shape, max_scale, min_scale)

# 缩放overlay图片
overlay_resized = get_resized_overlay(overlay, scale)

# 合成后的图片
merged_img = overlay_image(
background, overlay_resized, position, scale)

# 保存合成后的图片
filename = f"{bg_img_path.stem}_{category_id}_{gen_id:02d}.png"

# merged_img.save(f'{output_dir}/{filename}')

# 生成COCO格式的标注数据
box = Box(*position, overlay_resized.shape[1], overlay_resized.shape[0], category_id, background.shape[1],
background.shape[0])
self.upload_image_to_oss(merged_img, filename, subset, [box])
# self.add_image_annotion_to_coco(box, merged_img, filename)

def upload_full_screen_image(self, category_name=None):
if category_name is None:
return
if not isinstance(category_name, list):
category_name = [category_name]
for category_id in range(self.category_num):
name = self.category_names[category_id]
if name not in category_name:
continue
image_list = self.screen_png_images[category_id]
subset_list = ["train" if i % 10 <= 7 else "val" if i %
10 <= 8 else "test" for i in range(len(image_list))]
for i in range(len(image_list)):
image_data = image_list[i]
subset = subset_list[i]
overlay_img_path = image_data["path"]
image = Image.open(overlay_img_path)
if random.random() < 0.5:
origin_height = image.height
min_height = origin_height // 4
new_height = random.randint(min_height, origin_height)
new_top = random.randint(0, origin_height - new_height)
image = image.crop(
(0, new_top, image.width, new_top + new_height))
filename = f"{name}_{category_id}_{i:05}.png"
box = Box(0, 0, image.width, image.height,
category_id, image.width, image.height)
self.upload_image_to_oss(image, filename, subset, [box])

def load_screen_png_images_and_category(self, screen_images_dir):
screen_images_dir = Path(screen_images_dir)
screen_images_paths = [
f for f in screen_images_dir.iterdir() if f.is_dir()]
screen_images_paths.sort(key=lambda x: x.stem)
for category_id, screen_images_path in enumerate(screen_images_paths):
img_files = [p for p in screen_images_path.iterdir() if p.suffix in [
".png", ".jpg"]]
img_files.sort(key=lambda x: x.stem)
self.screen_png_images[category_id] = []
self.category_names[category_id] = screen_images_path.stem
print(f"{category_id}: {self.category_names[category_id]}")
for i, img_file in enumerate(img_files):
self.screen_png_images[category_id].append(
dict(id=i, name=img_file.stem, supercategory=None, path=str(img_file)))

self.category_num = len(screen_images_paths)
print(f"category_num: {self.category_num}")

def get_category_id_loop(self):
# self.category_id_loop = (self.category_id_loop + 1) % self.category_num
self.category_id_loop = random.randint(0, self.category_num - 1)
return self.category_id_loop

def sample_category_data(self, category_id, subset):
image_data = self.screen_png_images[category_id]
# valid_id = []
# if subset == "train":
# valid_id = [i for i in range(len(image_data)) if i % 10 <= 7]
# elif subset == "val":
# valid_id = [i for i in range(len(image_data)) if i % 10 == 8]
# elif subset == "test":
# valid_id = [i for i in range(len(image_data)) if i % 10 == 9]
# image_data = [image_data[i] for i in valid_id]
return random.choice(image_data)["path"]

def gen_image_id(self):
return len(self.coco_images) + 1

def add_image_annotion_to_coco(self, bbox, image: Image.Image, image_name):
image_id = self.gen_image_id()

image_json = {
"id": image_id,
"width": image.width,
"height": image.height,
"file_name": image_name,
}
self.coco_images.append(image_json)

annotation_json = {
"id": image_id,
"image_id": image_id,
"category_id": 0,
"segmentation": None,
"area": bbox[2] * bbox[3],
"bbox": bbox,
"iscrowd": 0
}
self.coco_annotations.append(annotation_json)

def upload_image_to_oss(self, image, image_name, subset, box_list=None):
image_bytesio = io.BytesIO()
image.save(image_bytesio, format="PNG")
self.butcket.put_object(
f"{self.images_save_oss_dir}/{subset}/{image_name}", image_bytesio.getvalue())
if box_list:
label_str = "\n".join([box.to_yolo_format() for box in box_list])
label_name = image_name.split(".")[0] + ".txt"
self.butcket.put_object(
f"{self.label_save_oss_dir}/{subset}/{label_name}", label_str)

def dump_coco_json(self):
categories = [{key: item[key] for key in ("id", "name", "supercategory")} for item in
self.screen_png_images.values()]
coco_json = {
"images": self.coco_images,
"annotations": self.coco_annotations,
"categories": categories
}
self.butcket.put_object(
self.annotations_save_oss_path, json.dumps(coco_json, indent=2))
# with open(f"{self.output_dir}/coco.json", "w") as fp:
# json.dump(coco_json, fp, indent=2)

@hydra.main(version_base=None, config_path=".", config_name="conf")
def main(cfg: DictConfig):
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
# get_image_and_annotation(output_dir)
# screen_dataset_gen = ScreenDatasetGen(cfg, output_dir)

category_generators = []
for data_cfg in cfg.data_cfgs:
category_generators.append(SingleCategoryGen(cfg, data_cfg, output_dir))

bg_img_files = [f for f in Path(cfg.extract_cfg.output_dir.format(user_root=user_root)).iterdir() if
f.suffix in [".png", ".jpg"]]

if cfg.get("max_bg_img_sample"):
bg_img_files = random.sample(bg_img_files, cfg.max_bg_img_sample)

img_index = 0
for bg_img_file in tqdm(bg_img_files):
subset = "train" if img_index % 10 <= 7 else "val" if img_index % 10 == 8 else "test"
img_index += 1
for category_generator in category_generators:
category_generator.add_new_images(bg_img_path=bg_img_file, subset=subset)

for category_generator in category_generators:
category_generator.upload_full_screen_image()

if __name__ == '__main__':
main()
```

运行后, 可以在outputs文件夹下生成符合要求的训练数据.

![](https://cdn.nlark.com/yuque/0/2024/png/114633/1730344557819-d66c0669-2275-413f-9194-39dca8bf2908.png)

image 就是背景+检测物体

labels 中的内容就是这样的文件:

```python
1 0.6701388888888888 0.289453125 0.5736111111111111 0.57421875
# 类型 box
```

# Step-2 训练模型
这个更简单, 在官网下载一个模型权重, 比如yolo8s.pt, 对付安全帽这种东西, 几M大的模型就够了.

训练配置文件:

```python
names:
0: logo
1: 截屏
2: 红包
path: /outputs
test: images/test
train: images/train
val: images/val

```

训练代码:

没错就这么一点

```python
from ultralytics import YOLO

model = YOLO('./yolo8s.pt')
model.train(data='dataset.yaml', epochs=100, imgsz=1280)
```

然后就可以自动化训练了, 结束后会自动保存模型与评估检测效果.

![](https://cdn.nlark.com/yuque/0/2024/png/114633/1730344671489-386881b2-8ccf-4fe7-aa9b-554c78439513.png)

# Step-3 检测
检测代码示意:

```python
class Special_Obj_Detect(object):

def __init__(self, cfg) -> None:
model_path = cfg.model_path
self.model = YOLO(model_path)
self.model.requires_grad_ = False
self.cls_names = {0: 'logo', 1: '截屏', 2: '红包'}

# 单帧图像检测
def detect_image(self, img_path):
results = self.model(img_path)
objects = []
objects_cnt = dict()
objects_area_pct = dict()
for result in results:
result = result.cpu()
boxes = list(result.boxes)
for box in boxes:
if box.conf < 0.8: continue
boxcls = box.cls[0].item()
objects.append(self.cls_names[boxcls])
objects_cnt[self.cls_names[boxcls]] = objects_cnt.get(self.cls_names[boxcls], 0) + 1
area_p = sum([ (xywh[2]*xywh[3]).item() for xywh in box.xywhn])
area_p = min(1, area_p)
objects_area_pct[self.cls_names[boxcls]] = area_p
objects = list(set(objects))
return objects, objects_cnt, objects_area_pct
```

收工.

#