2024年11月

在电商或服务平台中,
缓存
的使用是提高系统性能和响应速度的关键。然而,
缓存穿透
是一个常见的性能瓶颈问题,尤其是在查询不存在的数据时,系统会直接访问数据库,这不仅影响性能,还可能造成数据库负担过重。为了有效解决这个问题,我们提出了一种结合
布隆过滤器

空值缓存

分布式锁
的缓存穿透防护方案。以下是该方案的工作流程。

工作流程

1. 用户请求优惠券模板信息

用户首先发起对优惠券模板信息的请求。该请求包括一个优惠券模板ID,系统需要根据该ID返回相应的优惠券信息。

2. 缓存查询:Redis缓存

系统首先会在
Redis缓存
中查询是否已经缓存了相关的优惠券信息。Redis 是一个高效的缓存系统,通常可以极大地提高查询速度。如果缓存中存在相应的模板信息,系统直接返回给用户,查询过程结束。

3. 缓存未命中:布隆过滤器的使用

如果 Redis 缓存中没有找到对应的优惠券模板信息,系统会进一步通过
布隆过滤器
检查该模板ID是否有效。布隆过滤器是一种空间效率极高的数据结构,用来快速判断某个元素是否在集合中。

  • 如果布隆过滤器中没有该模板ID
    ,说明该优惠券模板ID不合法或已经失效,系统直接返回给用户
    “失败:无效的优惠券模板ID”
  • 如果布隆过滤器中存在该模板ID
    ,表示该优惠券模板ID可能有效,系统会继续查询数据库。

4. 空值缓存:防止重复查询

在布隆过滤器判断模板ID有效的情况下,系统继续检查 Redis 缓存中是否存在空值缓存。空值缓存是指对于某些查询,数据库返回了“空”结果(例如优惠券模板ID不存在于数据库中),为了避免重复查询数据库,这类空结果会被缓存一段时间。

  • 如果 Redis 缓存中存在空值
    ,系统会直接返回
    “失败:无效的优惠券模板ID”
    ,避免重复的数据库查询。
  • 如果 Redis 缓存中没有空值
    ,系统继续进行数据库查询操作。

5. 分布式锁:保证数据一致性

为了防止多个请求同时查询数据库,造成数据库压力过大,或者多个线程同时执行相同查询操作,系统使用了
分布式锁
来确保在同一时间只有一个请求会访问数据库查询数据。

  • 如果分布式锁可用
    ,系统获取锁,并进行以下步骤:


    1. 查询数据库获取优惠券模板信息。
    2. 如果数据库返回数据,系统将数据缓存到 Redis 中,减少后续请求对数据库的访问。
    3. 如果数据库返回空数据,系统在 Redis 中缓存空结果,并设置短时间过期,防止短时间内重复查询。
    4. 最后释放分布式锁。
  • 如果分布式锁不可用
    ,表示其他请求正在进行相同的数据库查询操作,系统会等待锁释放或返回错误信息。

6. 返回结果:缓存数据或数据库数据

  • 如果 Redis 缓存中有数据
    ,系统直接返回缓存的数据给用户。
  • 如果缓存中没有数据且查询成功
    ,系统将数据库中的数据返回给用户,并缓存该数据以提高后续查询的效率。
  • 如果查询失败
    (例如模板ID无效或数据库无数据),系统返回错误信息。


流程图

image

代码实现

public CouponTemplateQueryRespDTO getCouponTemplate(CouponTemplateQueryReqDTO requestParam) {
    // 查询 Redis 缓存中是否存在优惠券模板信息
    String cacheKey = String.format(RedisConstants.COUPON_TEMPLATE_KEY, requestParam.getTemplateId());
    Map<Object, Object> cacheMap = stringRedisTemplate.opsForHash().entries(cacheKey);

    // 如果缓存存在直接返回,否则通过布隆过滤器、空值缓存以及分布式锁查询数据库
    if (MapUtil.isEmpty(cacheMap)) {
        // 判断布隆过滤器是否存在指定模板 ID,不存在则返回错误
        if (!bloomFilter.contains(requestParam.getTemplateId())) {
            throw new ClientException("Coupon template does not exist");
        }

        // 查询 Redis 缓存中是否存在空值信息,如果存在则直接返回
        String nullCacheKey = String.format(RedisConstants.COUPON_TEMPLATE_NULL_KEY, requestParam.getTemplateId());
        Boolean isNullCached = stringRedisTemplate.hasKey(nullCacheKey);
        if (isNullCached) {
            throw new ClientException("Coupon template does not exist");
        }

        // 获取分布式锁
        RLock lock = redissonClient.getLock(String.format(RedisConstants.LOCK_COUPON_TEMPLATE_KEY, requestParam.getTemplateId()));
        lock.lock();

        try {
            // 双重检查空值缓存
            isNullCached = stringRedisTemplate.hasKey(nullCacheKey);
            if (isNullCached) {
                throw new ClientException("Coupon template does not exist");
            }

            // 使用双重检查锁避免并发查询数据库
            cacheMap = stringRedisTemplate.opsForHash().entries(cacheKey);
            if (MapUtil.isEmpty(cacheMap)) {
                LambdaQueryWrapper<CouponTemplate> queryWrapper = Wrappers.lambdaQuery(CouponTemplate.class)
                        .eq(CouponTemplate::getShopId, Long.parseLong(requestParam.getShopId()))
                        .eq(CouponTemplate::getId, Long.parseLong(requestParam.getTemplateId()))
                        .eq(CouponTemplate::getStatus, TemplateStatusEnum.ACTIVE.getStatus());
                CouponTemplate couponTemplate = couponTemplateMapper.selectOne(queryWrapper);

                // 如果模板不存在或已过期,设置空值缓存并抛出异常
                if (couponTemplate == null) {
                    stringRedisTemplate.opsForValue().set(nullCacheKey, "", 30, TimeUnit.MINUTES);
                    throw new ClientException("Coupon template does not exist or has expired");
                }

                // 将数据库记录序列化并存入 Redis 缓存
                CouponTemplateQueryRespDTO responseDTO = BeanUtil.toBean(couponTemplate, CouponTemplateQueryRespDTO.class);
                Map<String, Object> responseMap = BeanUtil.beanToMap(responseDTO, false, true);
                Map<String, String> cacheData = responseMap.entrySet().stream()
                        .collect(Collectors.toMap(
                                Map.Entry::getKey,
                                entry -> entry.getValue() != null ? entry.getValue().toString() : ""
                        ));

                // 使用 Lua 脚本将数据存入 Redis 并设置过期时间
                String luaScript = "redis.call('HMSET', KEYS[1], unpack(ARGV, 1, #ARGV - 1)) " +
                        "redis.call('EXPIREAT', KEYS[1], ARGV[#ARGV])";

                List<String> keys = Collections.singletonList(cacheKey);
                List<String> args = new ArrayList<>(cacheData.size() * 2 + 1);
                cacheData.forEach((key, value) -> {
                    args.add(key);
                    args.add(value);
                });

                // 设置优惠券活动的过期时间
                args.add(String.valueOf(couponTemplate.getEndTime().getTime() / 1000));

                // 执行 Lua 脚本
                stringRedisTemplate.execute(
                        new DefaultRedisScript<>(luaScript, Long.class),
                        keys,
                        args.toArray()
                );
                cacheMap = cacheData.entrySet()
                        .stream()
                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            }
        } finally {
            lock.unlock();
        }
    }

    // 返回从缓存中获取的数据
    return BeanUtil.mapToBean(cacheMap, CouponTemplateQueryRespDTO.class, false, CopyOptions.create());
}

1.
背景

花了整整两天时间,本
qiang~
开发了一个关于
AI
新闻资讯的自动聚合及报告生成工具。

本篇记录一下整体的框架和实现原理,并且本着它山之石可以攻玉,本
qiang~
开放了所有的源码,源码可见如下第
5
章节,感谢各位看官的大力支持。如有问题,可私信或留言沟通。

成品可以参考链接:《
AI
资讯每日速递
(2024.11.05)

2.
为什么要做这件事?

深处
AI
时代,想要追赶前沿的一手技术与资讯,有一个工具能够实时获取每天的重点内容,包括咨询和技术相关内容,并且能够按照公司及内容的优先级进行筛选,然后午后捧着一杯奶茶,点开自动生成的报告,岂不美哉美哉?

3.
相关技术

  1. Crawl4ai
    :

    一块集成
    LLM
    的开源爬虫工具
  2. Swarm
    : OpenAI

    发布的
    Multi-Agent
    编排框架,可以参考本人先前的辛苦整理:《
    LLM
    应用实战
    : OpenAI
    多代理框架
    -Swarm

  3. Python-docx
    : word

    的操作工具
  4. Textdistance
    :
    用于报告模块中资讯排序结果与原始资讯结果的对齐
  5. Gpt-4o-mini
    :

    采用的大模型是
    gpt-4o-mini
    ,每日免费调用
    200
    次,不够用
    ...

4.
整体框架

整体框架分为三个模块:

4.1
下载模块

下载模块的数据源包括各大
AI
新闻网站及知名博客,然后通过开源爬虫工具
crawl4ai
进行爬取,爬取的维度包括标题、内容、图片等。

4.2
解析模块

解析模块是针对爬取的结果进行解析,采用
OpenAi Swarm
框架,包含
4

Agent
,其中
Analysis Agent
是主体
Agent
,遍历下载的每一个资讯,将每条资讯分别同步给其他
Agent
完成具体的解析任务。其中
Translator Agent
主要功能是翻译,将英文翻译为中文;
Classifier Agent
主要功能是针对资讯进行分类,如涉及技术还是产品之类的;
Modifier Agent
主要功能是将资讯的标题和内容进行改写,标题可以改写更醒目一些,内容主要是提取摘要信息。

Analysis Agent
负责串联其他
3

Agent
,每个
Agent
结束后均会返回到
Analysis Agent
,以便让
Analysis Agent
决定下一步的操作。

4.3
报告模块

报告模块包含
Sorter Agent
,主要功能是将解析后的资讯按照公司、内容等维度进行排序,然后筛选出其中相对排名较高的资讯。

经过排序
Agent
后,最终将结果保存为
word

5.
全部源码

5.1
下载模块

采用
crawl4ai
工具进行网站爬取,示例的网站是
https://www.aibase.com
,网站存在中文及英文,但增加翻译
Agent
是为了兼容其他网站。

1.
文件处理
file_util.py

importjsonimporthashlibdef get_datas(file_path, json_flag=True, all_flag=False, mode='r'):"""读取文本文件"""results=[]

with open(file_path, mode, encoding
='utf-8') as f:for line inf.readlines():ifjson_flag:
results.append(json.loads(line))
else:
results.append(line.strip())
ifall_flag:ifjson_flag:return json.loads(''.join(results))else:return '\n'.join(results)returnresultsdef save_datas(file_path, datas, json_flag=True, all_flag=False, with_indent=False, mode='w'):"""保存文本文件"""with open(file_path, mode, encoding='utf-8') as f:ifall_flag:ifjson_flag:
f.write(json.dumps(datas, ensure_ascii
=False, indent= 4 if with_indent elseNone))else:
f.write(
''.join(datas))else:for data indatas:ifjson_flag:
f.write(json.dumps(data, ensure_ascii
=False) + '\n')else:
f.write(data
+ '\n')

2.
网站爬取
web_crawler.py


from crawl4ai importAsyncWebCrawlerfrom crawl4ai.extraction_strategy importJsonCssExtractionStrategyimportjsonfrom typing importDict, Any, Union, Listfrom bs4 importBeautifulSoupfrom file_util import *
importosimportdatetimeimportreimportrequestsclassAbstractAICrawler():def __init__(self) ->None:pass
    defcrawl():raiseNotImplementedError()classAINewsCrawler(AbstractAICrawler):def __init__(self, domain) ->None:
super().
__init__()
self.domain
=domain
self.file_path
= f'data/{self.domain}.json'self.history=self.init()definit(self):if notos.path.exists(self.file_path):return{}return {ele['id']: ele for ele inget_datas(self.file_path)}defsave(self, datas: Union[List, Dict]):ifisinstance(datas, dict):
datas
=[datas]
self.history.update({ele[
'id']: ele for ele indatas})
save_datas(self.file_path, datas
=list(self.history.values()))

async
def crawl(self, url:str, schema: Dict[str, Any]=None):
extraction_strategy
= JsonCssExtractionStrategy(schema, verbose=True) if schema elseNone
async with AsyncWebCrawler(verbose
=True) as crawler:
result
=await crawler.arun(
url
=url,
extraction_strategy
=extraction_strategy,
bypass_cache
=True,
)
assert result.success, "Failed to crawl the page" ifschema:returnjson.loads(result.extracted_content)returnresult.cleaned_htmlclassAIBasesCrawler(AINewsCrawler):def __init__(self) ->None:
self.domain
= 'aibase'super().__init__(self.domain)
self.url
= 'https://www.aibase.com'asyncdef crawl_home(self, url='https://www.aibase.com/news'):
schema
={'name': 'ai base home page crawler','baseSelector': '.flex','fields': [
{
'name': 'link','selector': 'a[rel="noopener noreferrer"]','type': 'nested_list','fields': [
{
'name': 'href', 'type': 'attribute', 'attribute':'href'}
]
}
]
}
links
=await super().crawl(url, schema)
links
= [link['href'] for ele in links for link in ele['link']]
links
= list(set([f'{self.url}{ele}' for ele in links if ele.startswith('/news')]))
links
= sorted(links, key=lambda x: x, reverse=True)returnlinks

async
defcrawl_newsletter_cn(self, url):
html
=await super().crawl(url)
body
= BeautifulSoup(html, 'html.parser')
title
= body.select_one('h1').get_text().replace('\u200b', '').strip()
date
= [ele.get_text().strip() for ele in body.find_all('span') if re.match(r'(\d{4}年\d{1,2}月\d{1,2}号)', ele.get_text().strip())][0]
date
= datetime.datetime.strptime(date, '%Y年%m月%d号 %H:%M').strftime("%Y-%m-%d")
content
= '\n'.join([ele.get_text().strip().replace('\n', '').replace(' ', '') for ele in body.find_all('p')])
content
= content[:content.index('划重点:')].strip() if '划重点:' in content elsecontentreturn{'title': title,'link': url,'content': content,'date': date
}

async
def crawl_home_cn(self, url='https://www.aibase.com/zh/news'):
schema
={'name': 'ai base home page crawler','baseSelector': '.flex','fields': [
{
'name': 'link','selector': 'a[rel="noopener noreferrer"]','type': 'nested_list','fields': [
{
'name': 'href', 'type': 'attribute', 'attribute':'href'}
]
}
]
}
links
=await super().crawl(url, schema)
links
= [link['href'] for ele in links for link in ele['link']]
links
= list(set([f'{self.url}{ele}' for ele in links if ele.startswith('/zh/news')]))
links
= sorted(links, key=lambda x: x, reverse=True)returnlinks

async
defcrawl_newsletter(self, url):
html
=await super().crawl(url)
body
= BeautifulSoup(html, 'html.parser')
title
= body.select_one('h1').get_text().replace('\u200b', '').strip()
date
= ';'.join([ele.get_text().strip() for ele in body.find_all('span')])
date
= re.findall(r'(\b\w{3}\s+\d{1,2},\s+\d{4}\b)', date)[0]
date
= datetime.datetime.strptime(date, '%b %d, %Y').strftime("%Y-%m-%d")
content
= '\n'.join([ele.get_text().strip().replace('\n', '') for ele in body.find_all('p')])
content
= content[:content.index('Key Points:')].strip() if 'Key Points:' in content elsecontent
pic_urls
= [ele.get('src').strip() for ele in body.select('img') if ele.get('title')]
pic_url
= pic_urls[0] if pic_urls else ''pic_url= pic_url.replace('\\"', '')
pic_path
= '' ifpic_url:
pic_path
= f'data/images/{md5(url)}.jpg'response=requests.get(pic_url)if response.status_code == 200:
with open(pic_path,
'wb') as f:
f.write(response.content)
return{'title': title,'link': url,'content': content,'date': date,'pic': pic_path,'id': md5(url)
}

async
defcrawl(self):
links
=await self.crawl_home()
results
=[]for link inlinks:
_id
=md5(link)if _id inself.history:continueresults.append({'id': _id,'link': link,'contents': await self.crawl_newsletter(link),'time': datetime.datetime.now().strftime('%Y-%m-%d')
})
self.save(results)
returnawait self.get_last_day_data()

async
defget_last_day_data(self):
last_day
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
datas
=self.init()for v indatas.values():
v[
'contents']['id'] = v['id']return [v['contents'] for v in datas.values() if v['contents']['date'] == last_day]

View Code

5.2
解析模块

1. 解析提示语prompt.py


ANALYST = """你是一个AI领域的分析师,主要工作步骤如下:
1. 首先执行transform_to_translate_agent方法,切换到translate agent,执行翻译任务;
2. 然后再执行transform_to_classifier_agent,调用classifier agent,针对内容进行分类;
3. 接着再执行transform_to_modifier_agent,调用modifier agent,针对内容进行改写;
4. 前三步执行完毕后,意味着整个分析工作已经完成,最后调用finish方法,退出该整个工作流程。
需要注意的是:每个步骤必须执行完成后,才能执行后续的步骤,且同时只能有1个步骤在执行;如果modifier agent已经执行完毕,一定要调用finish退出整体工作流程。
"""TRANSLATE= """你现在是一个AI领域的翻译专家,请将如下英文的标题和内容分别翻译为中文。步骤及要求如下:
1. 首先调用translate方法进行翻译,要求如下:
a. 需要注意的标题和内容中如果包含公司名称、产品名称、技术名称等专业词汇,针对这些专业词汇需要保留英文形式,其他非专业词汇需要翻译为中文,注意标题也必须翻译;
b. 输出格式为 "标题: xxxxx\n内容: xxxxx",且需要保留换行符;
c. 注意该translate方法没有输入参数,返回的结果只是需要翻译的原始文本,需要你执行翻译操作,然后返回翻译结果;
d. 该translate方法执行完成后,需要你执行具体的翻译,等待翻译完成后,才能开展下一个步骤,不能直接将原文作为参数传给下一个步骤;

2. 抽取完成后,执行extract_translate_result方法,要求如下:
a. 该extract_translate_result方法存在1个输入参数,即执行1后得到的翻译结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作。

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤,
"""CLASSIFIER= """你是一个AI领域的分类器,请判断输入是否与AI的技术相关。步骤及要求如下:
1. 首先调用classify方法进行分类,要求如下:
a. 输入的内容包括标题和内容两部分,重点基于内容进行判断这条信息是否与AI技术相关;
b. 如果是相关技术细节、技术原理、代码说明、架构说明,则输出"是",如果是与公司的最新资讯相关,如发行新产品、成立新部门、公司合作等非技术相关的,则输出"否"
c. 输出的结果只能是"是"、"否"两个选项中的一个,不要输出其他内容,包括解释信息等。
d. 注意该classify方法没有输入参数,返回的结果只是需要分类的原始文本,需要你执行分类任务,然后返回分类结果;


2. 获取到分类结果后,执行extract_classify_result方法,要求如下:
a. 该extract_classify_result方法存在1个输入参数,即执行1后得到的分类结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤,
"""MODIFIER= """你是一个AI新闻的改写器,请基于输入中的标题和内容进行改写。步骤及要求如下:
1. 首先调用modify方法进行改写,要求如下:
a. 输入的内容包括"标题"和"内容"两部分,需要分别针对"标题"和"内容"进行改写;
b. "标题"的改写目标是需要醒目且具有吸引力,能够吸引读者进一步阅读,要求字数不能超过30字;
c. "内容"需要摘要总结,需要准确提取主要内容,要求字数不超过200字;
d. 输出格式为 "标题: xxxx\n内容: xxxxx",且需要保留换行符,"标题"和"内容"需要以输入的中文为准;
e. 注意该modify方法没有输入参数,返回的结果是需要改写的原始文本,需要你执行改写任务,然后返回改写结果;


2. 获取到改写结果后,执行extract_modify_result方法,要求如下:
a. 该extract_modify_result方法存在1个输入参数,即执行1后得到的改写结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤
"""

View Code

2. 解析Agent整体流程agent.py


agent copy 2from swarm importSwarm, Agentfrom web_crawler importAIBasesCrawlerimportasynciofrom prompt import *
from file_util import *
from tqdm importtqdmimportdatetime


client
=Swarm()defdownload():returnasyncio.run(AIBasesCrawler().crawl())deftransform_to_analysis_agent():returnanalysis_agentdeftransform_to_translate_agent():returntranslate_agentdeftransform_to_classifier_agent():returnclassifier_agentdeftransform_to_modifier_agent():returnmodifier_agentdeftranslate(context_variables):return f'现在请按要求翻译如下内容:\n标题: {context_variables["title"]}\n内容: {context_variables["content"]}' defextract_translate_result(result: str, context_variables: dict):"""翻译的结果进行抽取

Args:
result (str): 翻译结果
Returns:
str: 翻译结果提取结束标志
"""context_variables['title_zh'] = result[result.index('标题:')+len('标题:'):result.index('内容:')]
context_variables[
'content_zh'] = result[result.index('内容:')+len('内容:'):]return '翻译结果提取任务已经完成,请继续下一步操作。' defclassify(context_variables):return f'现在请按要求针对以下内容进行分类,\n输入:\n标题: {context_variables["title_zh"]}\n内容: {context_variables["content_zh"]},\n输出:' defextract_classify_result(result: str, context_variables: dict):"""分类的结果进行抽取

Args:
result (str): 翻译结果
Returns:
str: 分类结果提取结束标志
"""context_variables['classify'] =resultreturn '分类结果提取任务已经完成,请继续下一步操作。' defmodify(context_variables):return f'现在请按要求针对以下内容进行改写,\n输入:\n标题: {context_variables["title_zh"]}\n内容: {context_variables["content_zh"]},\n输出:' defextract_modify_result(result: str, context_variables: dict):"""改写的结果进行抽取

Args:
result (str): 改写结果
Returns:
str: 改写结果提取结束标志
"""context_variables['title_modify'] = result[result.index('标题:')+len('标题:'):result.index('内容:')]
context_variables[
'content_modify'] = result[result.index('内容:')+len('内容:'):]return '改写结果提取任务已经完成,请继续下一步操作。' deffinish():return '分析任务已经完成,请直接退出整个工作流程,直接输出"退出"。'analysis_agent= Agent(name='analysis_agent', instructions=ANALYST, functions=[transform_to_translate_agent, transform_to_classifier_agent, transform_to_modifier_agent, finish])
translate_agent
= Agent(name='translate_agent', instructions=TRANSLATE, functions=[translate, extract_translate_result, transform_to_analysis_agent])
classifier_agent
= Agent(name='classifier_agent', instructions=CLASSIFIER, functions=[classify, extract_classify_result, transform_to_analysis_agent])
modifier_agent
= Agent(name='modifier_agent', instructions=MODIFIER, functions=[modify, extract_modify_result, transform_to_analysis_agent])

output_file_pre
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y.%m.%d')
output_path
= f'data/{output_file_pre}_final_results.json'results=get_datas(output_path)
process_ids
= [data['id'] for data inresults]for data intqdm(download()):if data['id'] in process_ids: continuecontext_variables= {'title': data['title'], 'content': data['content']}try:
result
= client.run(analysis_agent, messages=[{"role": "user", "content": "现在,请开始分析!"}], context_variables=context_variables, debug=True)
context_variables
=result.context_variables
data[
'title_zh'] = context_variables['title_zh']
data[
'content_zh'] = context_variables['content_zh']
data[
'classify'] = context_variables['classify']
data[
'title_modify'] = context_variables['title_modify']
data[
'content_modify'] = context_variables['content_modify']
save_datas(output_path, [data], mode
='a')exceptException as e:print(e)continue

View Code

5.3
报告模块

1. 排序提示语prompt.py


SORTER = """你是一个AI新闻的排序助手,请给予输入的新闻标题进行排序。要求如下:
1. 排序的规则是基于标题中所提及公司、组织机构的名气和重要性进行排序,名气和重要性是基于你所学的知识进行排序,名气和重要性越高,排名越靠前;
2. 排序的结果只返回名气最高的top10即可,输出的格式为"1xxxxx\n2xxxxx\n3xxxxx...\n10xxxxx",注意一定要以"\n"进行换行;
3. 输出的每个标题,需要和输入中对应的标题保持完全一致,禁止更改;
"""

View Code

2. 排序流程agent.py


from swarm importSwarm, Agentfrom prompt import *
from file_util import *
from collections importdefaultdictimportreimporttextdistancefrom word_util importsave_2_wordimportdatetimeimportrandom


client
=Swarm()
output_file_pre
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y.%m.%d')
output_path
= f'data/{output_file_pre}_final_results.json'sort_agent= Agent(name='sort_agent', instructions=SORTER)

datas
=get_datas(output_path)for ele indatas:
ele[
'title_modify'] = ele['title_modify'].strip()
ele[
'content_modify'] = ele['content_modify'].strip()defget_most_similar(t1, texts):
most_similarity
= 0.0most_similar_text= '' for ele intexts:
similarity
=textdistance.levenshtein.similarity(t1, ele)if similarity >most_similarity:
most_similarity
=similarity
most_similar_text
=elereturnmost_similar_text

type_2_title
=defaultdict(list)
{type_2_title[ele[
'classify']].append(ele['title_modify']) for ele indatas}
title_2_data
= {ele['title_modify']: ele for ele indatas}
final_results
=defaultdict(list)for k, v intype_2_title.items():
content
= "\n".join([ele for ele inv])
message
= f'现在请根据你所学习的知识,按照要求对以下输入进行排序,并且按照输出格式进行输出,\n输入:\n{content},\n输出:'result= client.run(sort_agent, messages=[{"role": "user", "content": message}], debug=True)
sort_results
= [ele['content'] for ele in result.messages[::-1] if 'content' in ele and ele['content'] and ele['content']]
sort_results
= sort_results[0].split('\n') if sort_results else random.sample(v, 10)
sort_results
= [re.sub(r'^\d+[\.,、\s]*', '', ele).strip() for ele insort_results]
final_results[k].extend([title_2_data[get_most_similar(ele, list(title_2_data.keys()))]
for ele insort_results])

sort_output
= f'data/{output_file_pre}_sort_results.json'save_datas(sort_output, [final_results])#生成word save_2_word(final_results, output_file_pre)

View Code

3. 报告生成word_util.py


from docx importDocumentfrom docx.shared importInches, Pt, RGBColorfrom docx.enum.text importWD_PARAGRAPH_ALIGNMENTimportosdefsave_2_word(info_dict, file_pre):
doc
=Document()

categories
= ['', '']
category_color
= 'FF5733' for category incategories:
news
=info_dict[category]
category_paragraph
=doc.add_paragraph()
category
= '技术' if category == '' else '资讯'category_run=category_paragraph.add_run(category)
category_run.bold
=True
category_run.font.size
= Pt(25)
category_run.font.color.rgb
=RGBColor.from_string(category_color)
category_paragraph.alignment
=WD_PARAGRAPH_ALIGNMENT.CENTERfor i, item inenumerate(news):
title
= item['title_modify']
doc.add_heading(f
'{i+1}. {title}', level=1)

pic
= item['pic'] if 'pic' in item else '' if pic andos.path.exists(pic):
pic_paragraph
=doc.add_paragraph()
pic_paragraph.alignment
=WD_PARAGRAPH_ALIGNMENT.CENTER
doc.add_picture(pic, width
=Inches(5))

content
= item['content_modify']
doc.add_paragraph(content)

doc.save(f
'data/AI资讯每日速递({file_pre}).docx')

View Code

6.
优化思考

1.
爬取模块目前是串行下载,且未增加反爬机制
,后续可以增加多线程,且增加代理池机制。

2.
免费的
gpt-4o-mini
每日调用次数仅有
200

次,执行本任务远远不够
,因此后期尝试切换为私有部署的

Qwen2.5

其实已经尝试了
Qwen2.5
,以
vllm
部署,但与
Swarm
框架中的
OpenAi
接口存在少许不兼容,例如不支持特定的参数,只能运行一轮。不过可以进一步优化
Swarm
框架来进行适配。

本次实验本
qiang~
花费了
30
大洋,买了一个
gpt-4o-mini
,生成最终结果,直接耗费了其中的
8
个大洋,烧钱
....

3.
信息推送机制不支持,如一键同步
到公众号、

CSDN
、知乎,这块如果有精力可以基于网站的开发接口,实现一键自动发布文章。

7.
总结

一句话足矣
~

开发了一块
AI
资讯的自动聚合及报告生成工具,包括具体的框架、实现原理以及完整源码,满满诚意,提供给各位看官。欢迎转发、订阅
~

有问题可以私信或留言沟通!

8.
参考

(1)
Swarm:
https://github.com/openai/swarm

(2)
Crawl4ai:
https://github.com/unclecode/crawl4ai

(3)
资讯网站
:
https://www.aibase.com/news

本博客所有文章除特别声明外,均采用
CC BY-NC-SA 4.0
许可协议。转载请注明来自
唯你

使用场景

JAVA 与 Rust 互操作让 Rust 可以背靠 Java 大生态来做更多事情,而 Java 也可以享受 Rust 语言特性的内存安全,所有权机制,无畏并发。

互操作的典型场景包括:

  • 性能优化:利用 Rust 处理计算密集型任务,提高 Java 应用的整体性能。
  • 系统级编程:结合 Rust 的底层控制能力与 Java 的高级抽象,实现更高效的系统交互。
  • 跨平台开发:使用 Rust 编写核心逻辑,通过 JNI 在不同平台上与 Java 交互,实现高效跨平台开发。
  • 安全关键应用:在金融、医疗等领域,利用 Rust 处理敏感数据和核心功能,保证高度安全性。
  • 实时系统:在游戏引擎、音频处理等延迟敏感的应用中,使用 Rust 处理时间关键部分。

背景知识

JNI

image.png

  • 全称 Java Native Interface,它允许 Java 代码与其他语言(如 C 或 C++)编写的应用程序进行互操作。
  • JNI Specification
    :这是 JNI 的官方规范,详细描述了 JNI 的使用方法、接口和功能。

Java 虚拟机(JVM)

image 1.png

JNI 是 Java 虚拟机的一部分,JVM 在启动时为每个线程创建一个 JNI 环境。JNI 环境包括指向 JVM 内部数据结构的指针,这些数据结构用于存储 Java 对象、方法和字段的信息。

JNIEnv (JNI 环境)

image 2.png

  • JNIEnv
    是一个指向结构体的指针,代表当前线程的 JNI 环境。它包含所有 JNI 相关函数的指针,让你能在本地代码中使用这些函数。每个线程都有自己独立的
    JNIEnv
    ,所以不能在不同线程间传递这个指针。
  • 可以将
    JNIEnv
    视为一个"翻译器"。当 Rust 代码需要与 Java 交互时,它通过这个"翻译器"发送请求,当调用 Java 方法或获取 Java 对象的属性。每个线程都拥有自己独立的"翻译器",这确保了各线程与 Java 交互时的独立性。

另外

  • 当 Java 代码调用本地方法时,JVM 会加载相应的本地库并创建一个
    JNIEnv
    指针。
  • 本地代码可以使用这个指针访问 JNI 提供的函数,进行 Java 对象的操作。
  • 每个线程有独立
    JNIEnv
    ,保证线程安全。新线程需调用
    AttachCurrentThread
    获取对应
    JNIEnv
  • JNI 提供了数据类型转换机制,实现 Java 与 C/C++之间的数据传递。

在 Rust 生态中使用 jni 0.21.1 库可以实现与 Java 代码的交互。

JNI 0.21.1 简介

该项目为 Rust 提供了完整的 JNI 绑定,允许:

  • 使用 Rust 代码与 Java 库进行交互,调用 Java 方法和访问 Java 对象。

  • 从 Rust 代码中使用 Java 类和接口。

  • 实现跨语言的高效数据交换。

  • 利用 Rust 的性能优势和 Java 的成熟生态系统

跨平台 UI 框架 Flutter 源码中的 MethodChannel 实现了 Dart 与 Android 层的通信,其底层 C++也是通过 JNI 调用插件中的 onMethodCall 来实现的。这与上述 jni 0.21.1 采用了相同的思路,但存在以下不同点:

  • 语言特性和类型安全:
    Rust 的
    jni
    库提供了一种更安全的方式来处理 Java 对象和方法调用。它利用 Rust 的所有权系统来减少潜在的内存错误,使得在 Rust 中使用 JNI 时更易于管理资源和避免常见错误。
  • 多平台支持:jni 0.21.1 提供了更广泛的跨平台支持。

如何运行示例

示例源码请
阅读原文
,见原文底部
源码获取

在 Windows 11 环境下运行示例时,笔者遇到了两个问题:

  1. Windows 自带的 PowerShell 无法直接执行 Makefile
  2. 由于 Rust 配置了特定目标平台,出现了不明原因的编译错误

以下是解决这些问题的方法:

在 MinGW-w64 中执行 Makefile

  1. 确保已在 MinGW-w64 环境中安装
    mingw32-make
    工具(通常随 MinGW-w64 一起安装)
  2. 打开 MinGW-w64 命令行
  3. 导航至 Makefile 所在目录
  4. 执行以下命令
//当前示例中是makefile,直接执行mingw32-make -f makefile即可
mingw32-make -f YourMakefileName

确认当前 Rust 环境

举例来说,笔者在 C:\Users\xxx.cargo 目录下配置了 config.toml 文件:

[build]
target = "aarch64-linux-android"

这导致在 Windows 上使用 mingw32-make 来编译针对 Android 平台的 Rust .so 文件,造成了混乱并引发了莫名其妙的编译错误。
解决方法是删除不必要的 config.toml 文件,确保当前运行环境与目标平台(如 Windows)一致。

输出结果:

Hello, josh!
[B@2f92e0f4
factCallback: res = 720
counterCallback: count = 1
counterCallback: count = 2
counterCallback: count = 3
counterCallback: count = 4
counterCallback: count = 5
Invoking asyncComputation (thread id = 1)
asyncCallback: thread id = 23, progress = 0%
asyncCallback: thread id = 23, progress = 10%
asyncCallback: thread id = 23, progress = 20%
asyncCallback: thread id = 23, progress = 30%
asyncCallback: thread id = 23, progress = 40%
asyncCallback: thread id = 23, progress = 50%
asyncCallback: thread id = 23, progress = 60%
asyncCallback: thread id = 23, progress = 70%
asyncCallback: thread id = 23, progress = 80%
asyncCallback: thread id = 23, progress = 90%
asyncCallback: thread id = 23, progress = 100%

示例简析

让我们深入分析示例中 asyncComputation 的流程。其核心目的是在 Rust 端执行一个异步计算,同时 Rust 端会调用 Java 端来报告计算进度。

image.png

整体流程图说明:

  1. Java 的 main()方法调用 asyncComputation()。
  2. asyncComputation()通过 JNI 调用 Rust 的 Java_HelloWorld_asyncComputation()函数。
  3. Rust 函数创建一个新线程来执行异步计算。
  4. 在新线程中,Rust 执行计算并周期性地调用 Java 的 asyncCallback()方法报告进度。
  5. 当 Rust 完成计算后,控制权返回到 Java 的主线程。

这个过程展示了 Java 调用 Rust(步骤 1)和 Rust 回调 Java(步骤 4)的双向交互。

源码如下

image 3.png

1 将一个 HelloWorld 类实例传递给 Rust 端,这对应下方 Rust 侧实现中 #3 处的 callback 对象

image 4.png

Rust 实现说明

  1. JNIEnv
    参数:详见上述相关概念中的解释。
  2. JClass
    :代表调用此本地方法的 Java 类引用,主要用于访问类级别的静态方法和字段。
  3. callback
    :Java 中新创建的 HelloWorld 对象实例。
  4. 获取 JVM 对象:因为
    env
    对象不支持线程间传递和共享(仅实现了 Send),而 JVM 支持。通过 JVM 对象可在线程间传递,并最终获得
    env
  5. 创建全局引用:获取 HelloWorld() 实例对象的全局引用,防止被垃圾回收。
  6. 线程安全:每个线程都有自己的
    JNIEnv
    ,确保线程安全。在新线程中需调用
    AttachCurrentThread
    获取对应的
    JNIEnv
  7. 反向调用 Java:通过
    env
    反向调用 Java 代码。调用对象是新创建的 HelloWorld 实例,回调方法是其中的 asyncCallback。在 JNI 中,
    "(I)V"
    是方法签名,描述了 Java 方法的参数和返回类型。
    "(I)V"
    表示接受一个整数参数并返回
    void
    的方法。在这里,asyncCallback 方法接收一个整数(
    progress
    )作为参数,无返回值。

总结

  1. Java 与 Rust 互操作让两种语言优势互补,提高性能和安全性,适用于多种场景如性能优化、系统级编程和跨平台开发。
  2. JNI(Java Native Interface)是实现 Java 与 Rust 互操作的关键技术,允许 Java 代码与其他语言编写的应用程序进行交互。
  3. 通过示例分析,我们了解了 Java 调用 Rust 函数和 Rust 回调 Java 方法的双向交互过程,展示了两种语言之间的无缝协作。

参考链接

https://github.com/jni-rs/jni-rs

JNI APIs and Developer Guides

Leveraging Rust in our high-performance Java database

上一篇:《构建人工智能模型基础:TFDS和Keras的完美搭配》

序言:
在人工智能模型的训练过程中,如何高效管理和处理大量数据是一个重要的课题。TensorFlow 的 TFRecord 格式为大规模数据存储和处理提供了一种灵活且高效的解决方案。在本节知识中,我们将介绍如何利用 TFRecord 结合 TensorFlow 的 Dataset API 进行数据的提取、转换和加载(ETL),从而更好地支持人工智能模型的训练和优化。通过 TFRecord,您可以将原始数据存储为一种轻量且易于处理的二进制格式,从而在大规模数据集的加载和解析上获得显著的性能提升。我们将详细探索如何构建 ETL 数据管道,通过并行处理、批处理和预取等技术,让数据加载与模型训练过程更加流畅、快速。这种数据处理方式不仅在单台机器上表现出色,还能在多核 CPU、GPU 或 TPU 上扩展,实现更大规模的人工智能模型训练。无论您是处理图像、文本,还是其他类型的大规模数据,理解并掌握 TFRecord 及其优化技术将为构建高效的数据管道奠定基础,使您能够更快速、智能地训练人工智能模型。

理解 TFRecord

当您使用 TFDS 时,数据会被下载并缓存到磁盘,因此您无需每次使用时都重新下载。TFDS 使用 TFRecord 格式进行缓存。如果您仔细观察下载过程,就会发现这一点——例如,图 4-1 展示了 cnn_dailymail 数据集如何被下载、打乱并写入 TFRecord 文件。

                        图 4-1. 将 cnn_dailymail 数据集下载为 TFRecord 文件

在 TensorFlow 中,TFRecord 是存储和检索大量数据的首选格式。这是一种非常简单的文件结构,按顺序读取以提高性能。在磁盘上,文件的结构相对直接,每条记录由一个表示记录长度的整数、其对应的循环冗余校验(CRC)、一个数据的字节数组及该字节数组的 CRC 组成。这些记录被连接成一个文件,如果数据集很大,则会进行分片。

                      例如,图 4-2 显示了 cnn_dailymail 的训练集在下载后被分成了 16 个文件。

为了更直观地了解一个简单的示例,可以下载 MNIST 数据集并打印其信息:

data, info = tfds.load("mnist", with_info=True)

print(info)

在 info 中,您会看到其特征是这样存储的:

features=FeaturesDict({

'image': Image(shape=(28, 28, 1), dtype=tf.uint8),

'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),

}),

与 CNN/DailyMail 的示例类似,文件会被下载到 /root/tensorflow_datasets/mnist/
/files 目录。

您可以像这样将原始记录作为 TFRecordDataset 加载:

filename="/root/tensorflow_datasets/mnist/3.0.0/mnist-test.tfrecord-00000-of-00001"

raw_dataset = tf.data.TFRecordDataset(filename)

for raw_record in raw_dataset.take(1):

print(repr(raw_record))

请注意,文件名的位置可能会根据您的操作系统而有所不同。

<tf.Tensor: shape=(), dtype=string, numpy=b"\n\x85\x03\n\xf2\x02\n\x05image\x12\xe8\x02\n\xe5\x02\n\xe2\x02\x89PNG\r \n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf \x80H\x00\x00\x01)IDAT(\x91\xc5\xd2\xbdK\xc3P\x14\x05\xf0S(v\x13)\x04,.\x82\xc5A q\xac\xedb\x1d\xdc\n.\x12\x87n\x0e\x82\x93\x7f@Q\xb2\x08\xba\tbQ0.\xe2\xe2\xd4\x b1\xa2h\x9c\x82\xba\x8a(\nq\xf0\x83Fh\x95\n6\x88\xe7R\x87\x88\xf9\xa8Y\xf5\x0e\x 8f\xc7\xfd\xdd\x0b\x87\xc7\x03\xfe\xbeb\x9d\xadT\x927Q\xe3\xe9\x07:\xab\xbf\xf4\ xf3\xcf\xf6\x8a\xd9\x14\xd29\xea\xb0\x1eKH\xde\xab\xea%\xaba\x1b=\xa4P/\xf5\x02\ xd7\\x07\x00\xc4=,L\xc0,>\x01@2\xf6\x12\xde\x9c\xde[t/\xb3\x0e\x87\xa2\xe2\ xc2\xe0A<\xca\xb26\xd5(\x1b\xa9\xd3\xe8\x0e\xf5\x86\x17\xceE\xdarV\xae\xb7_\xf3 I\xf7(\x06m\xaaE\xbb\xb6\xac\r
\x9b$e<\xb8\xd7\xa2\x0e\x00\xd0l\x92\xb2\xd5\x15\ xcc\xae'\x00\xf4m\x08O'+\xc2y\x9f\x8d\xc9\x15\x80\xfe\x99[q\x962@CN|i\xf7\xa9!=\ \xab\x19\x00\xc8\xd6\xb8\xeb\xa1\xf0\xd8l\xca\xfb]\xee\xfb]
\x9fV\xe1\x07\xb7\xc 9\x8b55\xe7M\xef\xb0\x04\xc0\xfd&\x89\x01<\xbe\xf9\x03*\x8a\xf5\x81\x7f\xaa/2y\x 87ks\xec\x1e\xc1\x00\x00\x00\x00IEND\xaeB`\x82\n\x0e\n\x05label\x12\x05\x1a\x03\ n\x01\x02">

它是一个包含记录详细信息的长字符串,里面还包括校验和等内容。但是如果我们已经知道特征,我们就可以创建一个特征描述,然后用它来解析数据。代码如下:

创建特征描述

feature_description = {

'image': tf.io.FixedLenFeature([], dtype=tf.string),

'label': tf.io.FixedLenFeature([], dtype=tf.int64),

}

def _parse_function(example_proto):

使用上面的字典解析输入的
tf.Example
proto

return tf.io.parse_single_example(example_proto, feature_description)

parsed_dataset = raw_dataset.map(_parse_function)

for parsed_record in parsed_dataset.take(1):

print((parsed_record))

这样输出的内容就友好多了!首先,您可以看到图像是一个 Tensor,并且它包含一个 PNG。PNG 是一种压缩图像格式,头部由 IHDR 定义,图像数据位于 IDAT 和 IEND 之间。如果仔细观察字节流,您可以看到它们。同时也有标签,它以整数形式存储,包含值 2:

{

'image': <tf.Tensor: shape=(), dtype=string,

numpy=b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf\x80H\x00\x00\x01)IDAT(\x91\xc5\xd2\xbdK\xc3P\x14\x05\xf0S(v\x13)\x04,.\x82\xc5Aq\xac\xedb\x1d\xdc\n.\x12\x87n\x0e\x82\x93\x7f@Q\xb2\x08\xba\tbQ0.\xe2\xe2\xd4\xb1\xa2h\x9c\x82\xba\x8a(\nq\xf0\x83Fh\x95\n6\x88\xe7R\x87\x88\xf9\xa8Y\xf5\x0e\x8f\xc7\xfd\xdd\x0b\x87\xc7\x03\xfe\xbeb\x9d\xadT\x927Q\xe3\xe9\x07:\xab\xbf\xf4\xf3\xcf\xf6\x8a\xd9\x14\xd29\xea\xb0\x1eKH\xde\xab\xea%\xaba\x1b=\xa4P/\xf5\x02\xd7\\x07\x00\xc4=,L\xc0,>\x01@2\xf6\x12\xde\x9c\xde[t/\xb3\x0e\x87\xa2\xe2\xc2\xe0A<\xca\xb26\xd5(\x1b\xa9\xd3\xe8\x0e\xf5\x86\x17\xceE\xdarV\xae\xb7_\xf3AR\r!I\xf7(\x06m\xaaE\xbb\xb6\xac\r*\x9b$e<\xb8\xd7\xa2\x0e\x00\xd0l\x92\xb2\xd5\x15\xcc\xae'\x00\xf4m\x08O'+\xc2y\x9f\x8d\xc9\x15\x80\xfe\x99[q\x962@CN|i\xf7\xa9!=\xd7

\xab\x19\x00\xc8\xd6\xb8\xeb\xa1\xf0\xd8l\xca\xfb]\xee\xfb]
\x9fV\xe1\x07\xb7\xc9\x8b55\xe7M\xef\xb0\x04\xc0\xfd&\x89\x01<\xbe\xf9\x03
\x8a\xf5\x81\x7f\xaa/2y\x87ks\xec\x1e\xc1\x00\x00\x00\x00IEND\xaeB`\x82">,

'label': <tf.Tensor: shape=(), dtype=int64, numpy=2>

}

到这里,您可以读取原始的 TFRecord 并使用类似 Pillow 的 PNG 解码库将其解码为 PNG。

在 TensorFlow 中管理数据的 ETL 过程

无论规模大小,ETL 都是 TensorFlow 用于训练的核心模式。我们在本书中探索了小规模的单台计算机模型构建,但相同的技术可以用于大规模训练,跨多台机器并处理海量数据集。

ETL 过程的提取阶段是将原始数据从存储位置加载,并准备成可以转换的形式。转换阶段是对数据进行操作,使其适合或优化用于训练。例如,将数据进行批处理、图像增强、映射到特征列等逻辑,都可以算作转换阶段的一部分。加载阶段则是将数据加载到神经网络中进行训练。

来看一下完整的代码,用来训练“马匹或人类”分类器。这里我添加了注释,展示提取、转换和加载阶段的所在位置:

import tensorflow as tf

import tensorflow_datasets as tfds

import tensorflow_addons as tfa

模型定义开始

model = tf.keras.models.Sequential([

tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),

tf.keras.layers.MaxPooling2D(2, 2),

tf.keras.layers.Conv2D(32, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(512, activation='relu'),

tf.keras.layers.Dense(1, activation='sigmoid')

])

model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])

模型定义结束

提取阶段开始

data = tfds.load('horses_or_humans', split='train', as_supervised=True)

val_data = tfds.load('horses_or_humans', split='test', as_supervised=True)

提取阶段结束

转换阶段开始

def augmentimages(image, label):

image = tf.cast(image, tf.float32)

image = (image/255)

image = tf.image.random_flip_left_right(image)

image = tfa.image.rotate(image, 40, interpolation='NEAREST')

return image, label

train = data.map(augmentimages)

train_batches = train.shuffle(100).batch(32)

validation_batches = val_data.batch(32)

转换阶段结束

python

转换阶段结束

加载阶段开始

history = model.fit(train_batches, epochs=10, validation_data=validation_batches, validation_steps=1)

加载阶段结束

通过这样的流程,您的数据管道可以更少地受到数据和底层模式变化的影响。当您使用 TFDS 提取数据时,无论数据是小到可以放入内存,还是大到无法放入简单的机器中,都可以使用相同的底层结构。用于转换的 tf.data API 也是一致的,因此无论底层数据源是什么,都可以使用类似的 API。当然,一旦数据被转换,加载数据的过程也是一致的,无论您是在单个 CPU、GPU、多个 GPU 集群,甚至是 TPU 群组上进行训练。

然而,加载数据的方式可能对训练速度产生巨大影响。接下来,我们来看看如何优化加载阶段。

优化加载阶段

在训练模型时,我们可以深入了解提取-转换-加载(ETL)过程。我们可以认为数据的提取和转换可以在任何处理器上进行,包括 CPU。事实上,这些阶段的代码执行诸如下载数据、解压缩数据、逐条记录处理等任务,这些都不是 GPU 或 TPU 的强项,所以这部分代码通常会在 CPU 上运行。但是在训练阶段,GPU 或 TPU 能显著提升性能,因此如果有条件,最好在这一阶段使用 GPU 或 TPU。

因此,在有 GPU 或 TPU 的情况下,理想的做法是将工作负载分配到 CPU 和 GPU/TPU 上:提取和转换在 CPU 上完成,而加载则在 GPU/TPU 上完成。

假设您正在处理一个大型数据集。由于数据量大,必须以批次方式准备数据(即,执行提取和转换),这样就会出现类似图 4-3 所示的情况。当第一个批次正在准备时,GPU/TPU 处于空闲状态。当这个批次准备好时,它会被发送到 GPU/TPU 进行训练,但此时 CPU 则空闲,直到训练完成,CPU 才能开始准备第二个批次。在这里会有大量的空闲时间,因此我们可以看到优化的空间。

            图 4-3. 在 CPU/GPU 上训练

逻辑上的解决方案是并行处理,让数据准备和训练同时进行。这种过程称为流水线处理,见图 4-4。

            图 4-4. 流水线处理

在这种情况下,当 CPU 准备第一个批次时,GPU/TPU 仍然没有任务,因此处于空闲状态。当第一个批次准备好后,GPU/TPU 可以开始训练——同时,CPU 开始准备第二个批次。当然,训练第 n-1 批次和准备第 n 批次所需的时间并不总是相同的。如果训练时间更快,GPU/TPU 会有一段空闲时间;如果训练时间更慢,CPU 会有一段空闲时间。选择合适的批次大小可以帮助优化这里的性能——由于 GPU/TPU 的时间往往更昂贵,您可能会尽量减少它们的空闲时间。

您可能已经注意到,当我们从 Keras 中的简单数据集(如 Fashion MNIST)转向使用 TFDS 版本时,必须在训练之前对它们进行批处理。这就是原因:流水线模型的存在使得无论数据集有多大,您都可以使用一致的 ETL 模式来处理它。

并行 ETL 以提高训练性能

TensorFlow 为您提供了所有并行化提取和转换过程所需的 API。让我们通过 Dogs vs. Cats 数据集和底层 TFRecord 结构来看看它们的样子。

首先,使用 tfds.load 获取数据集:

train_data = tfds.load('cats_vs_dogs', split='train', with_info=True)

如果您想使用底层的 TFRecords,您需要访问下载的原始文件。由于数据集较大,它被分成多个文件(在 4.0.0 版本中分为 8 个)。

您可以创建这些文件的列表,并使用 tf.data.Dataset.list_files 来加载它们:

file_pattern = f'/root/tensorflow_datasets/cats_vs_dogs/4.0.0/cats_vs_dogs-train.tfrecord*'

files = tf.data.Dataset.list_files(file_pattern)

获取文件后,可以使用 files.interleave 将它们加载到数据集中,如下所示:

train_dataset = files.interleave(

tf.data.TFRecordDataset,

cycle_length=4,

num_parallel_calls=tf.data.experimental.AUTOTUNE

)

这里有几个新概念,我们来花点时间解释一下。

cycle_length 参数指定同时处理的输入元素数量。稍后您会看到解码记录的映射函数,它会在从磁盘加载时解码记录。因为 cycle_length 设置为 4,所以这个过程会一次处理四条记录。如果不指定该值,它会根据可用的 CPU 核心数量自动确定。

num_parallel_calls 参数用于指定要执行的并行调用数量。在这里设置为 tf.data.experimental.AUTOTUNE 可以使代码更具可移植性,因为值会根据可用的 CPU 动态调整。结合 cycle_length 参数,您就设置了并行度的最大值。例如,如果在自动调整后 num_parallel_calls 设置为 6 而 cycle_length 是 4,那么会有六个独立线程,每个线程一次加载四条记录。

现在提取过程已经并行化了,我们来看看如何并行化数据的转换。首先,创建加载原始 TFRecord 并将其转换为可用内容的映射函数——例如,将 JPEG 图像解码成图像缓冲区:

def read_tfrecord(serialized_example):

feature_description = {

"image": tf.io.FixedLenFeature((), tf.string, ""),

"label": tf.io.FixedLenFeature((), tf.int64, -1),

}

example = tf.io.parse_single_example(serialized_example, feature_description)

image = tf.io.decode_jpeg(example['image'], channels=3)

image = tf.cast(image, tf.float32)

image = image / 255

image = tf.image.resize(image, (300, 300))

return image, example['label']

如您所见,这是一个典型的映射函数,没有做任何特定的工作来使它并行化。并行化将在调用映射函数时完成。以下是实现方法:

import multiprocessing

cores = multiprocessing.cpu_count()

print(cores)

train_dataset = train_dataset.map(read_tfrecord, num_parallel_calls=cores)

train_dataset = train_dataset.cache()

首先,如果您不想自动调优,可以使用 multiprocessing 库获取 CPU 的数量。然后,在调用映射函数时,您可以将此 CPU 数量作为并行调用的数量传入。就是这么简单。

cache 方法会将数据集缓存到内存中。如果您的 RAM 足够多,这会显著加快速度。不过,如果在 Colab 中使用 Dogs vs. Cats 数据集尝试此操作,可能会导致虚拟机崩溃,因为数据集无法完全装入内存。在这种情况下,Colab 的基础设施会为您提供一个新的、更高 RAM 的机器。

加载和训练过程同样可以并行化。在对数据进行打乱和批处理时,还可以根据可用 CPU 核心数量进行预取。代码如下:

train_dataset = train_dataset.shuffle(1024).batch(32)

train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

当训练集完全并行化后,您可以像以前一样训练模型:

model.fit(train_dataset, epochs=10, verbose=1)

我在 Google Colab 中试验了这一点,发现这些用于并行化 ETL 过程的额外代码将每个 epoch 的训练时间从 75 秒减少到约 40 秒。如此简单的更改几乎将我的训练时间减半!

总结

到此为止我们完成了介绍谷歌的 TensorFlow Datasets,这是一个可以让您访问各种数据集的库,从小规模学习数据集到用于研究的全规模数据集。您看到它们使用了通用 API 和格式,以减少您获取数据时所需编写的代码量。我们还讨论了 ETL 过程,它是 TFDS 设计的核心,特别是我们探索了并行化提取、转换和加载数据以提高训练性能。在接下来的知识中,我们将细分学习当今最热的人工智能主题,自然语言处理技术。

一、引言

站长接触 AOT 已有 3 个月之久,此前在《
好消息:NET 9 X86 AOT的突破 - 支持老旧Win7与XP环境
》一文中就有所提及。在这段时间里,站长使用 Avalonia 开发的项目也成功完成了 AOT 发布测试。然而,这一过程并非一帆风顺。站长在项目功能完成大半部分才开始进行 AOT 测试,期间遭遇了不少问题,可谓是 “踩坑无数”。为了方便日后回顾,也为了给广大读者提供参考,在此将这段经历进行总结。

.NET AOT是将.NET代码提前编译为本机代码的技术。其优势众多,启动速度快,减少运行时资源占用,还提高安全性。AOT发布后无需再安装.NET运行时等依赖。.NET 8、9 AOT发布后,可在XP、Win7非SP1操作系统下运行。这使得应用部署更便捷,能适应更多老旧系统环境,为开发者拓展了应用场景,在性能提升的同时,也增加了系统兼容性,让.NET应用的开发和部署更具灵活性和广泛性,给用户带来更好的体验。

二、经验之谈

(一)测试策略的重要性

从项目创建伊始,就应养成良好的习惯,即只要添加了新功能或使用了较新的语法,就及时进行 AOT 发布测试。否则,问题积累到后期,解决起来会异常艰难,站长就因前期忽视了这一点,付出了惨痛的代价。无奈的解决方法是重新创建项目,然后逐个还原功能并进行 AOT 测试。经过了一周的加班AOT测试,每个 AOT 发布过程大致如下:

  1. 内网 AOT 发布一次需 2、3 分钟,这段时间只能看看需求文档、技术文章、需求文档、技术文章。。。
  2. 发布完成,运行无效果,体现在双击未出现界面,进程列表没有它,说明程序崩溃了,查看系统应用事件日志,日志中通常会包含异常警告信息。
  3. 依据日志信息检查代码,修改相关 API。
  4. 再次进行 AOT 发布,重复上述 1 - 3 步骤。

经过一周的努力,项目 AOT 后功能测试终于正常,至此收工。

(二)AOT 需要注意的点及解决方法

1. 添加rd.xml

在主工程创建一个XML文件,例如
Roots.xml
,内容大致如下:

<linker>
	<assembly fullname="CodeWF.Toolbox.Desktop" preserve="All" />
</linker>

需要支持AOT的工程,在该XML中添加一个
assembly
节点,
fullname
是程序集名称,
CodeWF.Toolbox.Desktop
是站长小工具的主工程名,
点击
查看源码。

在主工程添加
ItemGroup
节点关联该XML文件:

<ItemGroup>
    <TrimmerRootDescriptor Include="Roots.xml" />
</ItemGroup>

2. Prism支持

站长使用了Prism框架及DryIOC容器,若要支持 AOT,需要添加以下 NuGet 包:

<PackageReference Include="Prism.Avalonia" Version="8.1.97.11073" />
<PackageReference Include="Prism.DryIoc.Avalonia" Version="8.1.97.11073" />

rd.xml
需要添加

<assembly fullname="Prism" preserve="All" />
<assembly fullname="DryIoc" preserve="All" />
<assembly fullname="Prism.Avalonia" preserve="All" />
<assembly fullname="Prism.DryIoc.Avalonia" preserve="All" />

3. App.config读写

在.NET Core中使用
System.Configuration.ConfigurationManager
包操作App.config文件,
rd.xml
需添加如下内容:

<assembly fullname="System.Configuration.ConfigurationManager" preserve="All" />

使用
Assembly.GetEntryAssembly().location
失败,目前使用
ConfigurationManager.OpenExeConfiguration(ConfigurationUserLevel.None)
获取的应用程序程序配置,指定路径的方式后续再研究。

4. HttpClient使用

rd.xml
添加如下内容:

<assembly fullname="System.Net.Http" preserve="All" />

5. Dapper支持

Dapper的AOT支持需要安装
Dapper.AOT
包,
rd.xml
添加如下内容:

<assembly fullname="Dapper" preserve="All" />
<assembly fullname="Dapper.AOT" preserve="All" />

数据库操作的方法需要添加
DapperAOT
特性,举例如下:

[DapperAot]
public static bool EnsureTableIsCreated()
{
    try
    {
        using var connection = new SqliteConnection(DBConst.DBConnectionString);
        connection.Open();

        const string sql = $@"
            CREATE TABLE IF NOT EXISTS {nameof(JsonPrettifyEntity)}(
                {nameof(JsonPrettifyEntity.IsSortKey)} Bool,
                {nameof(JsonPrettifyEntity.IndentSize)} INTEGER
        )";

        using var command = new SqliteCommand(sql, connection);
        return command.ExecuteNonQuery() > 0;
    }
    catch (Exception ex)
    {
        return false;
    }
}

6. System.Text.Json

参考
JsonExtensions.cs

序列化

public static bool ToJson<T>(this T obj, out string? json, out string? errorMsg)
{
    if (obj == null)
    {
        json = default;
        errorMsg = "Please provide object";
        return false;
    }

    var options = new JsonSerializerOptions()
    {
        WriteIndented = true,
        Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
        TypeInfoResolver = new DefaultJsonTypeInfoResolver()
    };
    try
    {
        json = JsonSerializer.Serialize(obj, options);
        errorMsg = default;
        return true;
    }
    catch (Exception ex)
    {
        json = default;
        errorMsg = ex.Message;
        return false;
    }
}

反序列化

public static bool FromJson<T>(this string? json, out T? obj, out string? errorMsg)
{
    if (string.IsNullOrWhiteSpace(json))
    {
        obj = default;
        errorMsg = "Please provide json string";
        return false;
    }

    try
    {
        var options = new JsonSerializerOptions()
        {
            Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
            TypeInfoResolver = new DefaultJsonTypeInfoResolver()
        };
        obj = JsonSerializer.Deserialize<T>(json!, options);
        errorMsg = default;
        return true;
    }
    catch (Exception ex)
    {
        obj = default;
        errorMsg = ex.Message;
        return false;
    }
}

7. 反射问题

参考项目
CodeWF.NetWeaver

  1. 创建指定类型的
    List<T>

    Dictionary<T>
    实例:
public static object CreateInstance(Type type)
{
    var itemTypes = type.GetGenericArguments();
    if (typeof(IList).IsAssignableFrom(type))
    {
        var lstType = typeof(List<>);
        var genericType = lstType.MakeGenericType(itemTypes.First());
        return Activator.CreateInstance(genericType)!;
    }
    else
    {
        var dictType = typeof(Dictionary<,>);
        var genericType = dictType.MakeGenericType(itemTypes.First(), itemTypes[1]);
        return Activator.CreateInstance(genericType)!;
    }
}
  1. 反射调用
    List<T>

    Dictionary<T>

    Add
    方法添加元素失败,下面是伪代码:
// List<T>
var addMethod = type.GetMethod("Add");
addMethod.Invoke(obj, new[]{ child })
    
// Dictionary<Key, Value>
var addMethod = type.GetMethod("Add");
addMethod.Invoke(obj, new[]{ key, value })

解决办法,转换为实现的接口调用:

// List<T>
(obj as IList).Add(child);

// Dictionary<Key, Value>
(obj as IDictionary)[key] = value;
  1. 获取数组、
    List<T>

    Dictionary<key, value>
    的元素个数

同上面Add方法反射获取Length或Count属性皆返回0,
value.Property("Length", 0)
,封装的Property非AOT运行正确:

public static T Property<T>(this object obj, string propertyName, T defaultValue = default)
{
    if (obj == null) throw new ArgumentNullException(nameof(obj));
    if (string.IsNullOrEmpty(propertyName)) throw new ArgumentNullException(nameof(propertyName));

    var propertyInfo = obj.GetType().GetProperty(propertyName);
    if (propertyInfo == null)
    {
        return defaultValue;
    }

    var value = propertyInfo.GetValue(obj);

    try
    {
        return (T)Convert.ChangeType(value, typeof(T));
    }
    catch (InvalidCastException)
    {
        return defaultValue;
    }
}

AOT成功:直接通过转换为基类型或实现的接口调用属性即可:

// 数组
var length = ((Array)value).Length;

// List<T>
 if (value is IList list)
{
    var count = list.Count;
}

// Dictionary<key, value>
if (value is IDictionary dictionary)
{
    var count = dictionary.Count;
}

8. Windows 7支持

如遇AOT后无法在
Windows 7
运行,请添加
YY-Thunks
包:

<PackageReference Include="YY-Thunks" Version="1.1.4-Beta3" />

并指定目标框架为
net9.0-windows

9. Winform\兼容XP

如果第8条后还运行不了,请参考上一篇文章《
.NET 9 AOT的突破 - 支持老旧Win7与XP环境 - 码界工坊 (dotnet9.com)
》添加VC-LTL包,这里不赘述。

10. 其他

还有许多其他需要注意的地方,后续想起来逐渐完善本文。

三、总结

AOT 发布测试虽然过程中可能会遇到诸多问题,但通过及时的测试和正确的配置调整,最终能够实现项目的顺利发布。希望以上总结的经验能对大家在 AOT 使用过程中有所帮助,让大家在开发过程中少走弯路,提高项目的开发效率和质量。同时,也期待大家在实践中不断探索和总结,共同推动技术的进步和发展。

AOT可参考项目: