2024年11月

1.
背景

花了整整两天时间,本
qiang~
开发了一个关于
AI
新闻资讯的自动聚合及报告生成工具。

本篇记录一下整体的框架和实现原理,并且本着它山之石可以攻玉,本
qiang~
开放了所有的源码,源码可见如下第
5
章节,感谢各位看官的大力支持。如有问题,可私信或留言沟通。

成品可以参考链接:《
AI
资讯每日速递
(2024.11.05)

2.
为什么要做这件事?

深处
AI
时代,想要追赶前沿的一手技术与资讯,有一个工具能够实时获取每天的重点内容,包括咨询和技术相关内容,并且能够按照公司及内容的优先级进行筛选,然后午后捧着一杯奶茶,点开自动生成的报告,岂不美哉美哉?

3.
相关技术

  1. Crawl4ai
    :

    一块集成
    LLM
    的开源爬虫工具
  2. Swarm
    : OpenAI

    发布的
    Multi-Agent
    编排框架,可以参考本人先前的辛苦整理:《
    LLM
    应用实战
    : OpenAI
    多代理框架
    -Swarm

  3. Python-docx
    : word

    的操作工具
  4. Textdistance
    :
    用于报告模块中资讯排序结果与原始资讯结果的对齐
  5. Gpt-4o-mini
    :

    采用的大模型是
    gpt-4o-mini
    ,每日免费调用
    200
    次,不够用
    ...

4.
整体框架

整体框架分为三个模块:

4.1
下载模块

下载模块的数据源包括各大
AI
新闻网站及知名博客,然后通过开源爬虫工具
crawl4ai
进行爬取,爬取的维度包括标题、内容、图片等。

4.2
解析模块

解析模块是针对爬取的结果进行解析,采用
OpenAi Swarm
框架,包含
4

Agent
,其中
Analysis Agent
是主体
Agent
,遍历下载的每一个资讯,将每条资讯分别同步给其他
Agent
完成具体的解析任务。其中
Translator Agent
主要功能是翻译,将英文翻译为中文;
Classifier Agent
主要功能是针对资讯进行分类,如涉及技术还是产品之类的;
Modifier Agent
主要功能是将资讯的标题和内容进行改写,标题可以改写更醒目一些,内容主要是提取摘要信息。

Analysis Agent
负责串联其他
3

Agent
,每个
Agent
结束后均会返回到
Analysis Agent
,以便让
Analysis Agent
决定下一步的操作。

4.3
报告模块

报告模块包含
Sorter Agent
,主要功能是将解析后的资讯按照公司、内容等维度进行排序,然后筛选出其中相对排名较高的资讯。

经过排序
Agent
后,最终将结果保存为
word

5.
全部源码

5.1
下载模块

采用
crawl4ai
工具进行网站爬取,示例的网站是
https://www.aibase.com
,网站存在中文及英文,但增加翻译
Agent
是为了兼容其他网站。

1.
文件处理
file_util.py

importjsonimporthashlibdef get_datas(file_path, json_flag=True, all_flag=False, mode='r'):"""读取文本文件"""results=[]

with open(file_path, mode, encoding
='utf-8') as f:for line inf.readlines():ifjson_flag:
results.append(json.loads(line))
else:
results.append(line.strip())
ifall_flag:ifjson_flag:return json.loads(''.join(results))else:return '\n'.join(results)returnresultsdef save_datas(file_path, datas, json_flag=True, all_flag=False, with_indent=False, mode='w'):"""保存文本文件"""with open(file_path, mode, encoding='utf-8') as f:ifall_flag:ifjson_flag:
f.write(json.dumps(datas, ensure_ascii
=False, indent= 4 if with_indent elseNone))else:
f.write(
''.join(datas))else:for data indatas:ifjson_flag:
f.write(json.dumps(data, ensure_ascii
=False) + '\n')else:
f.write(data
+ '\n')

2.
网站爬取
web_crawler.py


from crawl4ai importAsyncWebCrawlerfrom crawl4ai.extraction_strategy importJsonCssExtractionStrategyimportjsonfrom typing importDict, Any, Union, Listfrom bs4 importBeautifulSoupfrom file_util import *
importosimportdatetimeimportreimportrequestsclassAbstractAICrawler():def __init__(self) ->None:pass
    defcrawl():raiseNotImplementedError()classAINewsCrawler(AbstractAICrawler):def __init__(self, domain) ->None:
super().
__init__()
self.domain
=domain
self.file_path
= f'data/{self.domain}.json'self.history=self.init()definit(self):if notos.path.exists(self.file_path):return{}return {ele['id']: ele for ele inget_datas(self.file_path)}defsave(self, datas: Union[List, Dict]):ifisinstance(datas, dict):
datas
=[datas]
self.history.update({ele[
'id']: ele for ele indatas})
save_datas(self.file_path, datas
=list(self.history.values()))

async
def crawl(self, url:str, schema: Dict[str, Any]=None):
extraction_strategy
= JsonCssExtractionStrategy(schema, verbose=True) if schema elseNone
async with AsyncWebCrawler(verbose
=True) as crawler:
result
=await crawler.arun(
url
=url,
extraction_strategy
=extraction_strategy,
bypass_cache
=True,
)
assert result.success, "Failed to crawl the page" ifschema:returnjson.loads(result.extracted_content)returnresult.cleaned_htmlclassAIBasesCrawler(AINewsCrawler):def __init__(self) ->None:
self.domain
= 'aibase'super().__init__(self.domain)
self.url
= 'https://www.aibase.com'asyncdef crawl_home(self, url='https://www.aibase.com/news'):
schema
={'name': 'ai base home page crawler','baseSelector': '.flex','fields': [
{
'name': 'link','selector': 'a[rel="noopener noreferrer"]','type': 'nested_list','fields': [
{
'name': 'href', 'type': 'attribute', 'attribute':'href'}
]
}
]
}
links
=await super().crawl(url, schema)
links
= [link['href'] for ele in links for link in ele['link']]
links
= list(set([f'{self.url}{ele}' for ele in links if ele.startswith('/news')]))
links
= sorted(links, key=lambda x: x, reverse=True)returnlinks

async
defcrawl_newsletter_cn(self, url):
html
=await super().crawl(url)
body
= BeautifulSoup(html, 'html.parser')
title
= body.select_one('h1').get_text().replace('\u200b', '').strip()
date
= [ele.get_text().strip() for ele in body.find_all('span') if re.match(r'(\d{4}年\d{1,2}月\d{1,2}号)', ele.get_text().strip())][0]
date
= datetime.datetime.strptime(date, '%Y年%m月%d号 %H:%M').strftime("%Y-%m-%d")
content
= '\n'.join([ele.get_text().strip().replace('\n', '').replace(' ', '') for ele in body.find_all('p')])
content
= content[:content.index('划重点:')].strip() if '划重点:' in content elsecontentreturn{'title': title,'link': url,'content': content,'date': date
}

async
def crawl_home_cn(self, url='https://www.aibase.com/zh/news'):
schema
={'name': 'ai base home page crawler','baseSelector': '.flex','fields': [
{
'name': 'link','selector': 'a[rel="noopener noreferrer"]','type': 'nested_list','fields': [
{
'name': 'href', 'type': 'attribute', 'attribute':'href'}
]
}
]
}
links
=await super().crawl(url, schema)
links
= [link['href'] for ele in links for link in ele['link']]
links
= list(set([f'{self.url}{ele}' for ele in links if ele.startswith('/zh/news')]))
links
= sorted(links, key=lambda x: x, reverse=True)returnlinks

async
defcrawl_newsletter(self, url):
html
=await super().crawl(url)
body
= BeautifulSoup(html, 'html.parser')
title
= body.select_one('h1').get_text().replace('\u200b', '').strip()
date
= ';'.join([ele.get_text().strip() for ele in body.find_all('span')])
date
= re.findall(r'(\b\w{3}\s+\d{1,2},\s+\d{4}\b)', date)[0]
date
= datetime.datetime.strptime(date, '%b %d, %Y').strftime("%Y-%m-%d")
content
= '\n'.join([ele.get_text().strip().replace('\n', '') for ele in body.find_all('p')])
content
= content[:content.index('Key Points:')].strip() if 'Key Points:' in content elsecontent
pic_urls
= [ele.get('src').strip() for ele in body.select('img') if ele.get('title')]
pic_url
= pic_urls[0] if pic_urls else ''pic_url= pic_url.replace('\\"', '')
pic_path
= '' ifpic_url:
pic_path
= f'data/images/{md5(url)}.jpg'response=requests.get(pic_url)if response.status_code == 200:
with open(pic_path,
'wb') as f:
f.write(response.content)
return{'title': title,'link': url,'content': content,'date': date,'pic': pic_path,'id': md5(url)
}

async
defcrawl(self):
links
=await self.crawl_home()
results
=[]for link inlinks:
_id
=md5(link)if _id inself.history:continueresults.append({'id': _id,'link': link,'contents': await self.crawl_newsletter(link),'time': datetime.datetime.now().strftime('%Y-%m-%d')
})
self.save(results)
returnawait self.get_last_day_data()

async
defget_last_day_data(self):
last_day
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
datas
=self.init()for v indatas.values():
v[
'contents']['id'] = v['id']return [v['contents'] for v in datas.values() if v['contents']['date'] == last_day]

View Code

5.2
解析模块

1. 解析提示语prompt.py


ANALYST = """你是一个AI领域的分析师,主要工作步骤如下:
1. 首先执行transform_to_translate_agent方法,切换到translate agent,执行翻译任务;
2. 然后再执行transform_to_classifier_agent,调用classifier agent,针对内容进行分类;
3. 接着再执行transform_to_modifier_agent,调用modifier agent,针对内容进行改写;
4. 前三步执行完毕后,意味着整个分析工作已经完成,最后调用finish方法,退出该整个工作流程。
需要注意的是:每个步骤必须执行完成后,才能执行后续的步骤,且同时只能有1个步骤在执行;如果modifier agent已经执行完毕,一定要调用finish退出整体工作流程。
"""TRANSLATE= """你现在是一个AI领域的翻译专家,请将如下英文的标题和内容分别翻译为中文。步骤及要求如下:
1. 首先调用translate方法进行翻译,要求如下:
a. 需要注意的标题和内容中如果包含公司名称、产品名称、技术名称等专业词汇,针对这些专业词汇需要保留英文形式,其他非专业词汇需要翻译为中文,注意标题也必须翻译;
b. 输出格式为 "标题: xxxxx\n内容: xxxxx",且需要保留换行符;
c. 注意该translate方法没有输入参数,返回的结果只是需要翻译的原始文本,需要你执行翻译操作,然后返回翻译结果;
d. 该translate方法执行完成后,需要你执行具体的翻译,等待翻译完成后,才能开展下一个步骤,不能直接将原文作为参数传给下一个步骤;

2. 抽取完成后,执行extract_translate_result方法,要求如下:
a. 该extract_translate_result方法存在1个输入参数,即执行1后得到的翻译结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作。

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤,
"""CLASSIFIER= """你是一个AI领域的分类器,请判断输入是否与AI的技术相关。步骤及要求如下:
1. 首先调用classify方法进行分类,要求如下:
a. 输入的内容包括标题和内容两部分,重点基于内容进行判断这条信息是否与AI技术相关;
b. 如果是相关技术细节、技术原理、代码说明、架构说明,则输出"是",如果是与公司的最新资讯相关,如发行新产品、成立新部门、公司合作等非技术相关的,则输出"否"
c. 输出的结果只能是"是"、"否"两个选项中的一个,不要输出其他内容,包括解释信息等。
d. 注意该classify方法没有输入参数,返回的结果只是需要分类的原始文本,需要你执行分类任务,然后返回分类结果;


2. 获取到分类结果后,执行extract_classify_result方法,要求如下:
a. 该extract_classify_result方法存在1个输入参数,即执行1后得到的分类结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤,
"""MODIFIER= """你是一个AI新闻的改写器,请基于输入中的标题和内容进行改写。步骤及要求如下:
1. 首先调用modify方法进行改写,要求如下:
a. 输入的内容包括"标题"和"内容"两部分,需要分别针对"标题"和"内容"进行改写;
b. "标题"的改写目标是需要醒目且具有吸引力,能够吸引读者进一步阅读,要求字数不能超过30字;
c. "内容"需要摘要总结,需要准确提取主要内容,要求字数不超过200字;
d. 输出格式为 "标题: xxxx\n内容: xxxxx",且需要保留换行符,"标题"和"内容"需要以输入的中文为准;
e. 注意该modify方法没有输入参数,返回的结果是需要改写的原始文本,需要你执行改写任务,然后返回改写结果;


2. 获取到改写结果后,执行extract_modify_result方法,要求如下:
a. 该extract_modify_result方法存在1个输入参数,即执行1后得到的改写结果

3. 待步骤2执行完成后,执行transform_to_analysis_agent方法,切换至analysis agent,执行其他工作

4. 步骤1,2,3必须按照顺序执行,且同时只能有1个步骤在执行

5. 如果历史记录中已经执行了任何步骤,注意严格禁止再次重复执行,而要直接执行未执行的步骤
"""

View Code

2. 解析Agent整体流程agent.py


agent copy 2from swarm importSwarm, Agentfrom web_crawler importAIBasesCrawlerimportasynciofrom prompt import *
from file_util import *
from tqdm importtqdmimportdatetime


client
=Swarm()defdownload():returnasyncio.run(AIBasesCrawler().crawl())deftransform_to_analysis_agent():returnanalysis_agentdeftransform_to_translate_agent():returntranslate_agentdeftransform_to_classifier_agent():returnclassifier_agentdeftransform_to_modifier_agent():returnmodifier_agentdeftranslate(context_variables):return f'现在请按要求翻译如下内容:\n标题: {context_variables["title"]}\n内容: {context_variables["content"]}' defextract_translate_result(result: str, context_variables: dict):"""翻译的结果进行抽取

Args:
result (str): 翻译结果
Returns:
str: 翻译结果提取结束标志
"""context_variables['title_zh'] = result[result.index('标题:')+len('标题:'):result.index('内容:')]
context_variables[
'content_zh'] = result[result.index('内容:')+len('内容:'):]return '翻译结果提取任务已经完成,请继续下一步操作。' defclassify(context_variables):return f'现在请按要求针对以下内容进行分类,\n输入:\n标题: {context_variables["title_zh"]}\n内容: {context_variables["content_zh"]},\n输出:' defextract_classify_result(result: str, context_variables: dict):"""分类的结果进行抽取

Args:
result (str): 翻译结果
Returns:
str: 分类结果提取结束标志
"""context_variables['classify'] =resultreturn '分类结果提取任务已经完成,请继续下一步操作。' defmodify(context_variables):return f'现在请按要求针对以下内容进行改写,\n输入:\n标题: {context_variables["title_zh"]}\n内容: {context_variables["content_zh"]},\n输出:' defextract_modify_result(result: str, context_variables: dict):"""改写的结果进行抽取

Args:
result (str): 改写结果
Returns:
str: 改写结果提取结束标志
"""context_variables['title_modify'] = result[result.index('标题:')+len('标题:'):result.index('内容:')]
context_variables[
'content_modify'] = result[result.index('内容:')+len('内容:'):]return '改写结果提取任务已经完成,请继续下一步操作。' deffinish():return '分析任务已经完成,请直接退出整个工作流程,直接输出"退出"。'analysis_agent= Agent(name='analysis_agent', instructions=ANALYST, functions=[transform_to_translate_agent, transform_to_classifier_agent, transform_to_modifier_agent, finish])
translate_agent
= Agent(name='translate_agent', instructions=TRANSLATE, functions=[translate, extract_translate_result, transform_to_analysis_agent])
classifier_agent
= Agent(name='classifier_agent', instructions=CLASSIFIER, functions=[classify, extract_classify_result, transform_to_analysis_agent])
modifier_agent
= Agent(name='modifier_agent', instructions=MODIFIER, functions=[modify, extract_modify_result, transform_to_analysis_agent])

output_file_pre
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y.%m.%d')
output_path
= f'data/{output_file_pre}_final_results.json'results=get_datas(output_path)
process_ids
= [data['id'] for data inresults]for data intqdm(download()):if data['id'] in process_ids: continuecontext_variables= {'title': data['title'], 'content': data['content']}try:
result
= client.run(analysis_agent, messages=[{"role": "user", "content": "现在,请开始分析!"}], context_variables=context_variables, debug=True)
context_variables
=result.context_variables
data[
'title_zh'] = context_variables['title_zh']
data[
'content_zh'] = context_variables['content_zh']
data[
'classify'] = context_variables['classify']
data[
'title_modify'] = context_variables['title_modify']
data[
'content_modify'] = context_variables['content_modify']
save_datas(output_path, [data], mode
='a')exceptException as e:print(e)continue

View Code

5.3
报告模块

1. 排序提示语prompt.py


SORTER = """你是一个AI新闻的排序助手,请给予输入的新闻标题进行排序。要求如下:
1. 排序的规则是基于标题中所提及公司、组织机构的名气和重要性进行排序,名气和重要性是基于你所学的知识进行排序,名气和重要性越高,排名越靠前;
2. 排序的结果只返回名气最高的top10即可,输出的格式为"1xxxxx\n2xxxxx\n3xxxxx...\n10xxxxx",注意一定要以"\n"进行换行;
3. 输出的每个标题,需要和输入中对应的标题保持完全一致,禁止更改;
"""

View Code

2. 排序流程agent.py


from swarm importSwarm, Agentfrom prompt import *
from file_util import *
from collections importdefaultdictimportreimporttextdistancefrom word_util importsave_2_wordimportdatetimeimportrandom


client
=Swarm()
output_file_pre
= (datetime.date.today() - datetime.timedelta(days=1)).strftime('%Y.%m.%d')
output_path
= f'data/{output_file_pre}_final_results.json'sort_agent= Agent(name='sort_agent', instructions=SORTER)

datas
=get_datas(output_path)for ele indatas:
ele[
'title_modify'] = ele['title_modify'].strip()
ele[
'content_modify'] = ele['content_modify'].strip()defget_most_similar(t1, texts):
most_similarity
= 0.0most_similar_text= '' for ele intexts:
similarity
=textdistance.levenshtein.similarity(t1, ele)if similarity >most_similarity:
most_similarity
=similarity
most_similar_text
=elereturnmost_similar_text

type_2_title
=defaultdict(list)
{type_2_title[ele[
'classify']].append(ele['title_modify']) for ele indatas}
title_2_data
= {ele['title_modify']: ele for ele indatas}
final_results
=defaultdict(list)for k, v intype_2_title.items():
content
= "\n".join([ele for ele inv])
message
= f'现在请根据你所学习的知识,按照要求对以下输入进行排序,并且按照输出格式进行输出,\n输入:\n{content},\n输出:'result= client.run(sort_agent, messages=[{"role": "user", "content": message}], debug=True)
sort_results
= [ele['content'] for ele in result.messages[::-1] if 'content' in ele and ele['content'] and ele['content']]
sort_results
= sort_results[0].split('\n') if sort_results else random.sample(v, 10)
sort_results
= [re.sub(r'^\d+[\.,、\s]*', '', ele).strip() for ele insort_results]
final_results[k].extend([title_2_data[get_most_similar(ele, list(title_2_data.keys()))]
for ele insort_results])

sort_output
= f'data/{output_file_pre}_sort_results.json'save_datas(sort_output, [final_results])#生成word save_2_word(final_results, output_file_pre)

View Code

3. 报告生成word_util.py


from docx importDocumentfrom docx.shared importInches, Pt, RGBColorfrom docx.enum.text importWD_PARAGRAPH_ALIGNMENTimportosdefsave_2_word(info_dict, file_pre):
doc
=Document()

categories
= ['', '']
category_color
= 'FF5733' for category incategories:
news
=info_dict[category]
category_paragraph
=doc.add_paragraph()
category
= '技术' if category == '' else '资讯'category_run=category_paragraph.add_run(category)
category_run.bold
=True
category_run.font.size
= Pt(25)
category_run.font.color.rgb
=RGBColor.from_string(category_color)
category_paragraph.alignment
=WD_PARAGRAPH_ALIGNMENT.CENTERfor i, item inenumerate(news):
title
= item['title_modify']
doc.add_heading(f
'{i+1}. {title}', level=1)

pic
= item['pic'] if 'pic' in item else '' if pic andos.path.exists(pic):
pic_paragraph
=doc.add_paragraph()
pic_paragraph.alignment
=WD_PARAGRAPH_ALIGNMENT.CENTER
doc.add_picture(pic, width
=Inches(5))

content
= item['content_modify']
doc.add_paragraph(content)

doc.save(f
'data/AI资讯每日速递({file_pre}).docx')

View Code

6.
优化思考

1.
爬取模块目前是串行下载,且未增加反爬机制
,后续可以增加多线程,且增加代理池机制。

2.
免费的
gpt-4o-mini
每日调用次数仅有
200

次,执行本任务远远不够
,因此后期尝试切换为私有部署的

Qwen2.5

其实已经尝试了
Qwen2.5
,以
vllm
部署,但与
Swarm
框架中的
OpenAi
接口存在少许不兼容,例如不支持特定的参数,只能运行一轮。不过可以进一步优化
Swarm
框架来进行适配。

本次实验本
qiang~
花费了
30
大洋,买了一个
gpt-4o-mini
,生成最终结果,直接耗费了其中的
8
个大洋,烧钱
....

3.
信息推送机制不支持,如一键同步
到公众号、

CSDN
、知乎,这块如果有精力可以基于网站的开发接口,实现一键自动发布文章。

7.
总结

一句话足矣
~

开发了一块
AI
资讯的自动聚合及报告生成工具,包括具体的框架、实现原理以及完整源码,满满诚意,提供给各位看官。欢迎转发、订阅
~

有问题可以私信或留言沟通!

8.
参考

(1)
Swarm:
https://github.com/openai/swarm

(2)
Crawl4ai:
https://github.com/unclecode/crawl4ai

(3)
资讯网站
:
https://www.aibase.com/news

本博客所有文章除特别声明外,均采用
CC BY-NC-SA 4.0
许可协议。转载请注明来自
唯你

使用场景

JAVA 与 Rust 互操作让 Rust 可以背靠 Java 大生态来做更多事情,而 Java 也可以享受 Rust 语言特性的内存安全,所有权机制,无畏并发。

互操作的典型场景包括:

  • 性能优化:利用 Rust 处理计算密集型任务,提高 Java 应用的整体性能。
  • 系统级编程:结合 Rust 的底层控制能力与 Java 的高级抽象,实现更高效的系统交互。
  • 跨平台开发:使用 Rust 编写核心逻辑,通过 JNI 在不同平台上与 Java 交互,实现高效跨平台开发。
  • 安全关键应用:在金融、医疗等领域,利用 Rust 处理敏感数据和核心功能,保证高度安全性。
  • 实时系统:在游戏引擎、音频处理等延迟敏感的应用中,使用 Rust 处理时间关键部分。

背景知识

JNI

image.png

  • 全称 Java Native Interface,它允许 Java 代码与其他语言(如 C 或 C++)编写的应用程序进行互操作。
  • JNI Specification
    :这是 JNI 的官方规范,详细描述了 JNI 的使用方法、接口和功能。

Java 虚拟机(JVM)

image 1.png

JNI 是 Java 虚拟机的一部分,JVM 在启动时为每个线程创建一个 JNI 环境。JNI 环境包括指向 JVM 内部数据结构的指针,这些数据结构用于存储 Java 对象、方法和字段的信息。

JNIEnv (JNI 环境)

image 2.png

  • JNIEnv
    是一个指向结构体的指针,代表当前线程的 JNI 环境。它包含所有 JNI 相关函数的指针,让你能在本地代码中使用这些函数。每个线程都有自己独立的
    JNIEnv
    ,所以不能在不同线程间传递这个指针。
  • 可以将
    JNIEnv
    视为一个"翻译器"。当 Rust 代码需要与 Java 交互时,它通过这个"翻译器"发送请求,当调用 Java 方法或获取 Java 对象的属性。每个线程都拥有自己独立的"翻译器",这确保了各线程与 Java 交互时的独立性。

另外

  • 当 Java 代码调用本地方法时,JVM 会加载相应的本地库并创建一个
    JNIEnv
    指针。
  • 本地代码可以使用这个指针访问 JNI 提供的函数,进行 Java 对象的操作。
  • 每个线程有独立
    JNIEnv
    ,保证线程安全。新线程需调用
    AttachCurrentThread
    获取对应
    JNIEnv
  • JNI 提供了数据类型转换机制,实现 Java 与 C/C++之间的数据传递。

在 Rust 生态中使用 jni 0.21.1 库可以实现与 Java 代码的交互。

JNI 0.21.1 简介

该项目为 Rust 提供了完整的 JNI 绑定,允许:

  • 使用 Rust 代码与 Java 库进行交互,调用 Java 方法和访问 Java 对象。

  • 从 Rust 代码中使用 Java 类和接口。

  • 实现跨语言的高效数据交换。

  • 利用 Rust 的性能优势和 Java 的成熟生态系统

跨平台 UI 框架 Flutter 源码中的 MethodChannel 实现了 Dart 与 Android 层的通信,其底层 C++也是通过 JNI 调用插件中的 onMethodCall 来实现的。这与上述 jni 0.21.1 采用了相同的思路,但存在以下不同点:

  • 语言特性和类型安全:
    Rust 的
    jni
    库提供了一种更安全的方式来处理 Java 对象和方法调用。它利用 Rust 的所有权系统来减少潜在的内存错误,使得在 Rust 中使用 JNI 时更易于管理资源和避免常见错误。
  • 多平台支持:jni 0.21.1 提供了更广泛的跨平台支持。

如何运行示例

示例源码请
阅读原文
,见原文底部
源码获取

在 Windows 11 环境下运行示例时,笔者遇到了两个问题:

  1. Windows 自带的 PowerShell 无法直接执行 Makefile
  2. 由于 Rust 配置了特定目标平台,出现了不明原因的编译错误

以下是解决这些问题的方法:

在 MinGW-w64 中执行 Makefile

  1. 确保已在 MinGW-w64 环境中安装
    mingw32-make
    工具(通常随 MinGW-w64 一起安装)
  2. 打开 MinGW-w64 命令行
  3. 导航至 Makefile 所在目录
  4. 执行以下命令
//当前示例中是makefile,直接执行mingw32-make -f makefile即可
mingw32-make -f YourMakefileName

确认当前 Rust 环境

举例来说,笔者在 C:\Users\xxx.cargo 目录下配置了 config.toml 文件:

[build]
target = "aarch64-linux-android"

这导致在 Windows 上使用 mingw32-make 来编译针对 Android 平台的 Rust .so 文件,造成了混乱并引发了莫名其妙的编译错误。
解决方法是删除不必要的 config.toml 文件,确保当前运行环境与目标平台(如 Windows)一致。

输出结果:

Hello, josh!
[B@2f92e0f4
factCallback: res = 720
counterCallback: count = 1
counterCallback: count = 2
counterCallback: count = 3
counterCallback: count = 4
counterCallback: count = 5
Invoking asyncComputation (thread id = 1)
asyncCallback: thread id = 23, progress = 0%
asyncCallback: thread id = 23, progress = 10%
asyncCallback: thread id = 23, progress = 20%
asyncCallback: thread id = 23, progress = 30%
asyncCallback: thread id = 23, progress = 40%
asyncCallback: thread id = 23, progress = 50%
asyncCallback: thread id = 23, progress = 60%
asyncCallback: thread id = 23, progress = 70%
asyncCallback: thread id = 23, progress = 80%
asyncCallback: thread id = 23, progress = 90%
asyncCallback: thread id = 23, progress = 100%

示例简析

让我们深入分析示例中 asyncComputation 的流程。其核心目的是在 Rust 端执行一个异步计算,同时 Rust 端会调用 Java 端来报告计算进度。

image.png

整体流程图说明:

  1. Java 的 main()方法调用 asyncComputation()。
  2. asyncComputation()通过 JNI 调用 Rust 的 Java_HelloWorld_asyncComputation()函数。
  3. Rust 函数创建一个新线程来执行异步计算。
  4. 在新线程中,Rust 执行计算并周期性地调用 Java 的 asyncCallback()方法报告进度。
  5. 当 Rust 完成计算后,控制权返回到 Java 的主线程。

这个过程展示了 Java 调用 Rust(步骤 1)和 Rust 回调 Java(步骤 4)的双向交互。

源码如下

image 3.png

1 将一个 HelloWorld 类实例传递给 Rust 端,这对应下方 Rust 侧实现中 #3 处的 callback 对象

image 4.png

Rust 实现说明

  1. JNIEnv
    参数:详见上述相关概念中的解释。
  2. JClass
    :代表调用此本地方法的 Java 类引用,主要用于访问类级别的静态方法和字段。
  3. callback
    :Java 中新创建的 HelloWorld 对象实例。
  4. 获取 JVM 对象:因为
    env
    对象不支持线程间传递和共享(仅实现了 Send),而 JVM 支持。通过 JVM 对象可在线程间传递,并最终获得
    env
  5. 创建全局引用:获取 HelloWorld() 实例对象的全局引用,防止被垃圾回收。
  6. 线程安全:每个线程都有自己的
    JNIEnv
    ,确保线程安全。在新线程中需调用
    AttachCurrentThread
    获取对应的
    JNIEnv
  7. 反向调用 Java:通过
    env
    反向调用 Java 代码。调用对象是新创建的 HelloWorld 实例,回调方法是其中的 asyncCallback。在 JNI 中,
    "(I)V"
    是方法签名,描述了 Java 方法的参数和返回类型。
    "(I)V"
    表示接受一个整数参数并返回
    void
    的方法。在这里,asyncCallback 方法接收一个整数(
    progress
    )作为参数,无返回值。

总结

  1. Java 与 Rust 互操作让两种语言优势互补,提高性能和安全性,适用于多种场景如性能优化、系统级编程和跨平台开发。
  2. JNI(Java Native Interface)是实现 Java 与 Rust 互操作的关键技术,允许 Java 代码与其他语言编写的应用程序进行交互。
  3. 通过示例分析,我们了解了 Java 调用 Rust 函数和 Rust 回调 Java 方法的双向交互过程,展示了两种语言之间的无缝协作。

参考链接

https://github.com/jni-rs/jni-rs

JNI APIs and Developer Guides

Leveraging Rust in our high-performance Java database

上一篇:《构建人工智能模型基础:TFDS和Keras的完美搭配》

序言:
在人工智能模型的训练过程中,如何高效管理和处理大量数据是一个重要的课题。TensorFlow 的 TFRecord 格式为大规模数据存储和处理提供了一种灵活且高效的解决方案。在本节知识中,我们将介绍如何利用 TFRecord 结合 TensorFlow 的 Dataset API 进行数据的提取、转换和加载(ETL),从而更好地支持人工智能模型的训练和优化。通过 TFRecord,您可以将原始数据存储为一种轻量且易于处理的二进制格式,从而在大规模数据集的加载和解析上获得显著的性能提升。我们将详细探索如何构建 ETL 数据管道,通过并行处理、批处理和预取等技术,让数据加载与模型训练过程更加流畅、快速。这种数据处理方式不仅在单台机器上表现出色,还能在多核 CPU、GPU 或 TPU 上扩展,实现更大规模的人工智能模型训练。无论您是处理图像、文本,还是其他类型的大规模数据,理解并掌握 TFRecord 及其优化技术将为构建高效的数据管道奠定基础,使您能够更快速、智能地训练人工智能模型。

理解 TFRecord

当您使用 TFDS 时,数据会被下载并缓存到磁盘,因此您无需每次使用时都重新下载。TFDS 使用 TFRecord 格式进行缓存。如果您仔细观察下载过程,就会发现这一点——例如,图 4-1 展示了 cnn_dailymail 数据集如何被下载、打乱并写入 TFRecord 文件。

                        图 4-1. 将 cnn_dailymail 数据集下载为 TFRecord 文件

在 TensorFlow 中,TFRecord 是存储和检索大量数据的首选格式。这是一种非常简单的文件结构,按顺序读取以提高性能。在磁盘上,文件的结构相对直接,每条记录由一个表示记录长度的整数、其对应的循环冗余校验(CRC)、一个数据的字节数组及该字节数组的 CRC 组成。这些记录被连接成一个文件,如果数据集很大,则会进行分片。

                      例如,图 4-2 显示了 cnn_dailymail 的训练集在下载后被分成了 16 个文件。

为了更直观地了解一个简单的示例,可以下载 MNIST 数据集并打印其信息:

data, info = tfds.load("mnist", with_info=True)

print(info)

在 info 中,您会看到其特征是这样存储的:

features=FeaturesDict({

'image': Image(shape=(28, 28, 1), dtype=tf.uint8),

'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),

}),

与 CNN/DailyMail 的示例类似,文件会被下载到 /root/tensorflow_datasets/mnist/
/files 目录。

您可以像这样将原始记录作为 TFRecordDataset 加载:

filename="/root/tensorflow_datasets/mnist/3.0.0/mnist-test.tfrecord-00000-of-00001"

raw_dataset = tf.data.TFRecordDataset(filename)

for raw_record in raw_dataset.take(1):

print(repr(raw_record))

请注意,文件名的位置可能会根据您的操作系统而有所不同。

<tf.Tensor: shape=(), dtype=string, numpy=b"\n\x85\x03\n\xf2\x02\n\x05image\x12\xe8\x02\n\xe5\x02\n\xe2\x02\x89PNG\r \n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf \x80H\x00\x00\x01)IDAT(\x91\xc5\xd2\xbdK\xc3P\x14\x05\xf0S(v\x13)\x04,.\x82\xc5A q\xac\xedb\x1d\xdc\n.\x12\x87n\x0e\x82\x93\x7f@Q\xb2\x08\xba\tbQ0.\xe2\xe2\xd4\x b1\xa2h\x9c\x82\xba\x8a(\nq\xf0\x83Fh\x95\n6\x88\xe7R\x87\x88\xf9\xa8Y\xf5\x0e\x 8f\xc7\xfd\xdd\x0b\x87\xc7\x03\xfe\xbeb\x9d\xadT\x927Q\xe3\xe9\x07:\xab\xbf\xf4\ xf3\xcf\xf6\x8a\xd9\x14\xd29\xea\xb0\x1eKH\xde\xab\xea%\xaba\x1b=\xa4P/\xf5\x02\ xd7\\x07\x00\xc4=,L\xc0,>\x01@2\xf6\x12\xde\x9c\xde[t/\xb3\x0e\x87\xa2\xe2\ xc2\xe0A<\xca\xb26\xd5(\x1b\xa9\xd3\xe8\x0e\xf5\x86\x17\xceE\xdarV\xae\xb7_\xf3 I\xf7(\x06m\xaaE\xbb\xb6\xac\r
\x9b$e<\xb8\xd7\xa2\x0e\x00\xd0l\x92\xb2\xd5\x15\ xcc\xae'\x00\xf4m\x08O'+\xc2y\x9f\x8d\xc9\x15\x80\xfe\x99[q\x962@CN|i\xf7\xa9!=\ \xab\x19\x00\xc8\xd6\xb8\xeb\xa1\xf0\xd8l\xca\xfb]\xee\xfb]
\x9fV\xe1\x07\xb7\xc 9\x8b55\xe7M\xef\xb0\x04\xc0\xfd&\x89\x01<\xbe\xf9\x03*\x8a\xf5\x81\x7f\xaa/2y\x 87ks\xec\x1e\xc1\x00\x00\x00\x00IEND\xaeB`\x82\n\x0e\n\x05label\x12\x05\x1a\x03\ n\x01\x02">

它是一个包含记录详细信息的长字符串,里面还包括校验和等内容。但是如果我们已经知道特征,我们就可以创建一个特征描述,然后用它来解析数据。代码如下:

创建特征描述

feature_description = {

'image': tf.io.FixedLenFeature([], dtype=tf.string),

'label': tf.io.FixedLenFeature([], dtype=tf.int64),

}

def _parse_function(example_proto):

使用上面的字典解析输入的
tf.Example
proto

return tf.io.parse_single_example(example_proto, feature_description)

parsed_dataset = raw_dataset.map(_parse_function)

for parsed_record in parsed_dataset.take(1):

print((parsed_record))

这样输出的内容就友好多了!首先,您可以看到图像是一个 Tensor,并且它包含一个 PNG。PNG 是一种压缩图像格式,头部由 IHDR 定义,图像数据位于 IDAT 和 IEND 之间。如果仔细观察字节流,您可以看到它们。同时也有标签,它以整数形式存储,包含值 2:

{

'image': <tf.Tensor: shape=(), dtype=string,

numpy=b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x1c\x00\x00\x00\x1c\x08\x00\x00\x00\x00Wf\x80H\x00\x00\x01)IDAT(\x91\xc5\xd2\xbdK\xc3P\x14\x05\xf0S(v\x13)\x04,.\x82\xc5Aq\xac\xedb\x1d\xdc\n.\x12\x87n\x0e\x82\x93\x7f@Q\xb2\x08\xba\tbQ0.\xe2\xe2\xd4\xb1\xa2h\x9c\x82\xba\x8a(\nq\xf0\x83Fh\x95\n6\x88\xe7R\x87\x88\xf9\xa8Y\xf5\x0e\x8f\xc7\xfd\xdd\x0b\x87\xc7\x03\xfe\xbeb\x9d\xadT\x927Q\xe3\xe9\x07:\xab\xbf\xf4\xf3\xcf\xf6\x8a\xd9\x14\xd29\xea\xb0\x1eKH\xde\xab\xea%\xaba\x1b=\xa4P/\xf5\x02\xd7\\x07\x00\xc4=,L\xc0,>\x01@2\xf6\x12\xde\x9c\xde[t/\xb3\x0e\x87\xa2\xe2\xc2\xe0A<\xca\xb26\xd5(\x1b\xa9\xd3\xe8\x0e\xf5\x86\x17\xceE\xdarV\xae\xb7_\xf3AR\r!I\xf7(\x06m\xaaE\xbb\xb6\xac\r*\x9b$e<\xb8\xd7\xa2\x0e\x00\xd0l\x92\xb2\xd5\x15\xcc\xae'\x00\xf4m\x08O'+\xc2y\x9f\x8d\xc9\x15\x80\xfe\x99[q\x962@CN|i\xf7\xa9!=\xd7

\xab\x19\x00\xc8\xd6\xb8\xeb\xa1\xf0\xd8l\xca\xfb]\xee\xfb]
\x9fV\xe1\x07\xb7\xc9\x8b55\xe7M\xef\xb0\x04\xc0\xfd&\x89\x01<\xbe\xf9\x03
\x8a\xf5\x81\x7f\xaa/2y\x87ks\xec\x1e\xc1\x00\x00\x00\x00IEND\xaeB`\x82">,

'label': <tf.Tensor: shape=(), dtype=int64, numpy=2>

}

到这里,您可以读取原始的 TFRecord 并使用类似 Pillow 的 PNG 解码库将其解码为 PNG。

在 TensorFlow 中管理数据的 ETL 过程

无论规模大小,ETL 都是 TensorFlow 用于训练的核心模式。我们在本书中探索了小规模的单台计算机模型构建,但相同的技术可以用于大规模训练,跨多台机器并处理海量数据集。

ETL 过程的提取阶段是将原始数据从存储位置加载,并准备成可以转换的形式。转换阶段是对数据进行操作,使其适合或优化用于训练。例如,将数据进行批处理、图像增强、映射到特征列等逻辑,都可以算作转换阶段的一部分。加载阶段则是将数据加载到神经网络中进行训练。

来看一下完整的代码,用来训练“马匹或人类”分类器。这里我添加了注释,展示提取、转换和加载阶段的所在位置:

import tensorflow as tf

import tensorflow_datasets as tfds

import tensorflow_addons as tfa

模型定义开始

model = tf.keras.models.Sequential([

tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(300, 300, 3)),

tf.keras.layers.MaxPooling2D(2, 2),

tf.keras.layers.Conv2D(32, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),

tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(512, activation='relu'),

tf.keras.layers.Dense(1, activation='sigmoid')

])

model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])

模型定义结束

提取阶段开始

data = tfds.load('horses_or_humans', split='train', as_supervised=True)

val_data = tfds.load('horses_or_humans', split='test', as_supervised=True)

提取阶段结束

转换阶段开始

def augmentimages(image, label):

image = tf.cast(image, tf.float32)

image = (image/255)

image = tf.image.random_flip_left_right(image)

image = tfa.image.rotate(image, 40, interpolation='NEAREST')

return image, label

train = data.map(augmentimages)

train_batches = train.shuffle(100).batch(32)

validation_batches = val_data.batch(32)

转换阶段结束

python

转换阶段结束

加载阶段开始

history = model.fit(train_batches, epochs=10, validation_data=validation_batches, validation_steps=1)

加载阶段结束

通过这样的流程,您的数据管道可以更少地受到数据和底层模式变化的影响。当您使用 TFDS 提取数据时,无论数据是小到可以放入内存,还是大到无法放入简单的机器中,都可以使用相同的底层结构。用于转换的 tf.data API 也是一致的,因此无论底层数据源是什么,都可以使用类似的 API。当然,一旦数据被转换,加载数据的过程也是一致的,无论您是在单个 CPU、GPU、多个 GPU 集群,甚至是 TPU 群组上进行训练。

然而,加载数据的方式可能对训练速度产生巨大影响。接下来,我们来看看如何优化加载阶段。

优化加载阶段

在训练模型时,我们可以深入了解提取-转换-加载(ETL)过程。我们可以认为数据的提取和转换可以在任何处理器上进行,包括 CPU。事实上,这些阶段的代码执行诸如下载数据、解压缩数据、逐条记录处理等任务,这些都不是 GPU 或 TPU 的强项,所以这部分代码通常会在 CPU 上运行。但是在训练阶段,GPU 或 TPU 能显著提升性能,因此如果有条件,最好在这一阶段使用 GPU 或 TPU。

因此,在有 GPU 或 TPU 的情况下,理想的做法是将工作负载分配到 CPU 和 GPU/TPU 上:提取和转换在 CPU 上完成,而加载则在 GPU/TPU 上完成。

假设您正在处理一个大型数据集。由于数据量大,必须以批次方式准备数据(即,执行提取和转换),这样就会出现类似图 4-3 所示的情况。当第一个批次正在准备时,GPU/TPU 处于空闲状态。当这个批次准备好时,它会被发送到 GPU/TPU 进行训练,但此时 CPU 则空闲,直到训练完成,CPU 才能开始准备第二个批次。在这里会有大量的空闲时间,因此我们可以看到优化的空间。

            图 4-3. 在 CPU/GPU 上训练

逻辑上的解决方案是并行处理,让数据准备和训练同时进行。这种过程称为流水线处理,见图 4-4。

            图 4-4. 流水线处理

在这种情况下,当 CPU 准备第一个批次时,GPU/TPU 仍然没有任务,因此处于空闲状态。当第一个批次准备好后,GPU/TPU 可以开始训练——同时,CPU 开始准备第二个批次。当然,训练第 n-1 批次和准备第 n 批次所需的时间并不总是相同的。如果训练时间更快,GPU/TPU 会有一段空闲时间;如果训练时间更慢,CPU 会有一段空闲时间。选择合适的批次大小可以帮助优化这里的性能——由于 GPU/TPU 的时间往往更昂贵,您可能会尽量减少它们的空闲时间。

您可能已经注意到,当我们从 Keras 中的简单数据集(如 Fashion MNIST)转向使用 TFDS 版本时,必须在训练之前对它们进行批处理。这就是原因:流水线模型的存在使得无论数据集有多大,您都可以使用一致的 ETL 模式来处理它。

并行 ETL 以提高训练性能

TensorFlow 为您提供了所有并行化提取和转换过程所需的 API。让我们通过 Dogs vs. Cats 数据集和底层 TFRecord 结构来看看它们的样子。

首先,使用 tfds.load 获取数据集:

train_data = tfds.load('cats_vs_dogs', split='train', with_info=True)

如果您想使用底层的 TFRecords,您需要访问下载的原始文件。由于数据集较大,它被分成多个文件(在 4.0.0 版本中分为 8 个)。

您可以创建这些文件的列表,并使用 tf.data.Dataset.list_files 来加载它们:

file_pattern = f'/root/tensorflow_datasets/cats_vs_dogs/4.0.0/cats_vs_dogs-train.tfrecord*'

files = tf.data.Dataset.list_files(file_pattern)

获取文件后,可以使用 files.interleave 将它们加载到数据集中,如下所示:

train_dataset = files.interleave(

tf.data.TFRecordDataset,

cycle_length=4,

num_parallel_calls=tf.data.experimental.AUTOTUNE

)

这里有几个新概念,我们来花点时间解释一下。

cycle_length 参数指定同时处理的输入元素数量。稍后您会看到解码记录的映射函数,它会在从磁盘加载时解码记录。因为 cycle_length 设置为 4,所以这个过程会一次处理四条记录。如果不指定该值,它会根据可用的 CPU 核心数量自动确定。

num_parallel_calls 参数用于指定要执行的并行调用数量。在这里设置为 tf.data.experimental.AUTOTUNE 可以使代码更具可移植性,因为值会根据可用的 CPU 动态调整。结合 cycle_length 参数,您就设置了并行度的最大值。例如,如果在自动调整后 num_parallel_calls 设置为 6 而 cycle_length 是 4,那么会有六个独立线程,每个线程一次加载四条记录。

现在提取过程已经并行化了,我们来看看如何并行化数据的转换。首先,创建加载原始 TFRecord 并将其转换为可用内容的映射函数——例如,将 JPEG 图像解码成图像缓冲区:

def read_tfrecord(serialized_example):

feature_description = {

"image": tf.io.FixedLenFeature((), tf.string, ""),

"label": tf.io.FixedLenFeature((), tf.int64, -1),

}

example = tf.io.parse_single_example(serialized_example, feature_description)

image = tf.io.decode_jpeg(example['image'], channels=3)

image = tf.cast(image, tf.float32)

image = image / 255

image = tf.image.resize(image, (300, 300))

return image, example['label']

如您所见,这是一个典型的映射函数,没有做任何特定的工作来使它并行化。并行化将在调用映射函数时完成。以下是实现方法:

import multiprocessing

cores = multiprocessing.cpu_count()

print(cores)

train_dataset = train_dataset.map(read_tfrecord, num_parallel_calls=cores)

train_dataset = train_dataset.cache()

首先,如果您不想自动调优,可以使用 multiprocessing 库获取 CPU 的数量。然后,在调用映射函数时,您可以将此 CPU 数量作为并行调用的数量传入。就是这么简单。

cache 方法会将数据集缓存到内存中。如果您的 RAM 足够多,这会显著加快速度。不过,如果在 Colab 中使用 Dogs vs. Cats 数据集尝试此操作,可能会导致虚拟机崩溃,因为数据集无法完全装入内存。在这种情况下,Colab 的基础设施会为您提供一个新的、更高 RAM 的机器。

加载和训练过程同样可以并行化。在对数据进行打乱和批处理时,还可以根据可用 CPU 核心数量进行预取。代码如下:

train_dataset = train_dataset.shuffle(1024).batch(32)

train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

当训练集完全并行化后,您可以像以前一样训练模型:

model.fit(train_dataset, epochs=10, verbose=1)

我在 Google Colab 中试验了这一点,发现这些用于并行化 ETL 过程的额外代码将每个 epoch 的训练时间从 75 秒减少到约 40 秒。如此简单的更改几乎将我的训练时间减半!

总结

到此为止我们完成了介绍谷歌的 TensorFlow Datasets,这是一个可以让您访问各种数据集的库,从小规模学习数据集到用于研究的全规模数据集。您看到它们使用了通用 API 和格式,以减少您获取数据时所需编写的代码量。我们还讨论了 ETL 过程,它是 TFDS 设计的核心,特别是我们探索了并行化提取、转换和加载数据以提高训练性能。在接下来的知识中,我们将细分学习当今最热的人工智能主题,自然语言处理技术。

一、引言

站长接触 AOT 已有 3 个月之久,此前在《
好消息:NET 9 X86 AOT的突破 - 支持老旧Win7与XP环境
》一文中就有所提及。在这段时间里,站长使用 Avalonia 开发的项目也成功完成了 AOT 发布测试。然而,这一过程并非一帆风顺。站长在项目功能完成大半部分才开始进行 AOT 测试,期间遭遇了不少问题,可谓是 “踩坑无数”。为了方便日后回顾,也为了给广大读者提供参考,在此将这段经历进行总结。

.NET AOT是将.NET代码提前编译为本机代码的技术。其优势众多,启动速度快,减少运行时资源占用,还提高安全性。AOT发布后无需再安装.NET运行时等依赖。.NET 8、9 AOT发布后,可在XP、Win7非SP1操作系统下运行。这使得应用部署更便捷,能适应更多老旧系统环境,为开发者拓展了应用场景,在性能提升的同时,也增加了系统兼容性,让.NET应用的开发和部署更具灵活性和广泛性,给用户带来更好的体验。

二、经验之谈

(一)测试策略的重要性

从项目创建伊始,就应养成良好的习惯,即只要添加了新功能或使用了较新的语法,就及时进行 AOT 发布测试。否则,问题积累到后期,解决起来会异常艰难,站长就因前期忽视了这一点,付出了惨痛的代价。无奈的解决方法是重新创建项目,然后逐个还原功能并进行 AOT 测试。经过了一周的加班AOT测试,每个 AOT 发布过程大致如下:

  1. 内网 AOT 发布一次需 2、3 分钟,这段时间只能看看需求文档、技术文章、需求文档、技术文章。。。
  2. 发布完成,运行无效果,体现在双击未出现界面,进程列表没有它,说明程序崩溃了,查看系统应用事件日志,日志中通常会包含异常警告信息。
  3. 依据日志信息检查代码,修改相关 API。
  4. 再次进行 AOT 发布,重复上述 1 - 3 步骤。

经过一周的努力,项目 AOT 后功能测试终于正常,至此收工。

(二)AOT 需要注意的点及解决方法

1. 添加rd.xml

在主工程创建一个XML文件,例如
Roots.xml
,内容大致如下:

<linker>
	<assembly fullname="CodeWF.Toolbox.Desktop" preserve="All" />
</linker>

需要支持AOT的工程,在该XML中添加一个
assembly
节点,
fullname
是程序集名称,
CodeWF.Toolbox.Desktop
是站长小工具的主工程名,
点击
查看源码。

在主工程添加
ItemGroup
节点关联该XML文件:

<ItemGroup>
    <TrimmerRootDescriptor Include="Roots.xml" />
</ItemGroup>

2. Prism支持

站长使用了Prism框架及DryIOC容器,若要支持 AOT,需要添加以下 NuGet 包:

<PackageReference Include="Prism.Avalonia" Version="8.1.97.11073" />
<PackageReference Include="Prism.DryIoc.Avalonia" Version="8.1.97.11073" />

rd.xml
需要添加

<assembly fullname="Prism" preserve="All" />
<assembly fullname="DryIoc" preserve="All" />
<assembly fullname="Prism.Avalonia" preserve="All" />
<assembly fullname="Prism.DryIoc.Avalonia" preserve="All" />

3. App.config读写

在.NET Core中使用
System.Configuration.ConfigurationManager
包操作App.config文件,
rd.xml
需添加如下内容:

<assembly fullname="System.Configuration.ConfigurationManager" preserve="All" />

使用
Assembly.GetEntryAssembly().location
失败,目前使用
ConfigurationManager.OpenExeConfiguration(ConfigurationUserLevel.None)
获取的应用程序程序配置,指定路径的方式后续再研究。

4. HttpClient使用

rd.xml
添加如下内容:

<assembly fullname="System.Net.Http" preserve="All" />

5. Dapper支持

Dapper的AOT支持需要安装
Dapper.AOT
包,
rd.xml
添加如下内容:

<assembly fullname="Dapper" preserve="All" />
<assembly fullname="Dapper.AOT" preserve="All" />

数据库操作的方法需要添加
DapperAOT
特性,举例如下:

[DapperAot]
public static bool EnsureTableIsCreated()
{
    try
    {
        using var connection = new SqliteConnection(DBConst.DBConnectionString);
        connection.Open();

        const string sql = $@"
            CREATE TABLE IF NOT EXISTS {nameof(JsonPrettifyEntity)}(
                {nameof(JsonPrettifyEntity.IsSortKey)} Bool,
                {nameof(JsonPrettifyEntity.IndentSize)} INTEGER
        )";

        using var command = new SqliteCommand(sql, connection);
        return command.ExecuteNonQuery() > 0;
    }
    catch (Exception ex)
    {
        return false;
    }
}

6. System.Text.Json

参考
JsonExtensions.cs

序列化

public static bool ToJson<T>(this T obj, out string? json, out string? errorMsg)
{
    if (obj == null)
    {
        json = default;
        errorMsg = "Please provide object";
        return false;
    }

    var options = new JsonSerializerOptions()
    {
        WriteIndented = true,
        Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
        TypeInfoResolver = new DefaultJsonTypeInfoResolver()
    };
    try
    {
        json = JsonSerializer.Serialize(obj, options);
        errorMsg = default;
        return true;
    }
    catch (Exception ex)
    {
        json = default;
        errorMsg = ex.Message;
        return false;
    }
}

反序列化

public static bool FromJson<T>(this string? json, out T? obj, out string? errorMsg)
{
    if (string.IsNullOrWhiteSpace(json))
    {
        obj = default;
        errorMsg = "Please provide json string";
        return false;
    }

    try
    {
        var options = new JsonSerializerOptions()
        {
            Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
            TypeInfoResolver = new DefaultJsonTypeInfoResolver()
        };
        obj = JsonSerializer.Deserialize<T>(json!, options);
        errorMsg = default;
        return true;
    }
    catch (Exception ex)
    {
        obj = default;
        errorMsg = ex.Message;
        return false;
    }
}

7. 反射问题

参考项目
CodeWF.NetWeaver

  1. 创建指定类型的
    List<T>

    Dictionary<T>
    实例:
public static object CreateInstance(Type type)
{
    var itemTypes = type.GetGenericArguments();
    if (typeof(IList).IsAssignableFrom(type))
    {
        var lstType = typeof(List<>);
        var genericType = lstType.MakeGenericType(itemTypes.First());
        return Activator.CreateInstance(genericType)!;
    }
    else
    {
        var dictType = typeof(Dictionary<,>);
        var genericType = dictType.MakeGenericType(itemTypes.First(), itemTypes[1]);
        return Activator.CreateInstance(genericType)!;
    }
}
  1. 反射调用
    List<T>

    Dictionary<T>

    Add
    方法添加元素失败,下面是伪代码:
// List<T>
var addMethod = type.GetMethod("Add");
addMethod.Invoke(obj, new[]{ child })
    
// Dictionary<Key, Value>
var addMethod = type.GetMethod("Add");
addMethod.Invoke(obj, new[]{ key, value })

解决办法,转换为实现的接口调用:

// List<T>
(obj as IList).Add(child);

// Dictionary<Key, Value>
(obj as IDictionary)[key] = value;
  1. 获取数组、
    List<T>

    Dictionary<key, value>
    的元素个数

同上面Add方法反射获取Length或Count属性皆返回0,
value.Property("Length", 0)
,封装的Property非AOT运行正确:

public static T Property<T>(this object obj, string propertyName, T defaultValue = default)
{
    if (obj == null) throw new ArgumentNullException(nameof(obj));
    if (string.IsNullOrEmpty(propertyName)) throw new ArgumentNullException(nameof(propertyName));

    var propertyInfo = obj.GetType().GetProperty(propertyName);
    if (propertyInfo == null)
    {
        return defaultValue;
    }

    var value = propertyInfo.GetValue(obj);

    try
    {
        return (T)Convert.ChangeType(value, typeof(T));
    }
    catch (InvalidCastException)
    {
        return defaultValue;
    }
}

AOT成功:直接通过转换为基类型或实现的接口调用属性即可:

// 数组
var length = ((Array)value).Length;

// List<T>
 if (value is IList list)
{
    var count = list.Count;
}

// Dictionary<key, value>
if (value is IDictionary dictionary)
{
    var count = dictionary.Count;
}

8. Windows 7支持

如遇AOT后无法在
Windows 7
运行,请添加
YY-Thunks
包:

<PackageReference Include="YY-Thunks" Version="1.1.4-Beta3" />

并指定目标框架为
net9.0-windows

9. Winform\兼容XP

如果第8条后还运行不了,请参考上一篇文章《
.NET 9 AOT的突破 - 支持老旧Win7与XP环境 - 码界工坊 (dotnet9.com)
》添加VC-LTL包,这里不赘述。

10. 其他

还有许多其他需要注意的地方,后续想起来逐渐完善本文。

三、总结

AOT 发布测试虽然过程中可能会遇到诸多问题,但通过及时的测试和正确的配置调整,最终能够实现项目的顺利发布。希望以上总结的经验能对大家在 AOT 使用过程中有所帮助,让大家在开发过程中少走弯路,提高项目的开发效率和质量。同时,也期待大家在实践中不断探索和总结,共同推动技术的进步和发展。

AOT可参考项目:

前言

Vue3.5响应式重构主要分为两部分:
双向链表

版本计数
。在上一篇文章中我们讲了
双向链表
,这篇文章我们接着来讲
版本计数

欧阳年底也要毕业了,加入欧阳的面试交流群(分享内推信息)、高质量vue源码交流群

版本计数

看这篇文章之前最好先看一下欧阳之前写的
双向链表
文章,不然有些部分可能看着比较迷茫。

在上篇
双向链表
文章中我们知道了新的响应式模型中主要分为三个部分:
Sub订阅者

Dep依赖

Link节点

  • Sub订阅者
    :主要有watchEffect、watch、render函数、computed等。

  • Dep依赖
    :主要有ref、reactive、computed等响应式变量。

  • Link节点
    :连接
    Sub订阅者

    Dep依赖
    之间的桥梁,
    Sub订阅者
    想访问
    Dep依赖
    只能通过
    Link节点
    ,同样
    Dep依赖
    想访问
    Sub订阅者
    也只能通过
    Link节点

细心的小伙伴可能发现了computed计算属性不仅是
Sub订阅者
还是
Dep依赖

原因是computed可以像
watchEffect
那样监听里面的响应式变量,当响应式变量改变后会触发computed的回调。

还可以将computed的返回值当做ref那样的普通响应式变量去使用,
所以我们才说computed不仅是Sub订阅者还是Dep依赖。

版本计数
中由4个version实现,分别是:全局变量
globalVersion

dep.version

link.version

computed.globalVersion

  • globalVersion
    是一个全局变量,初始值为
    0
    ,仅有响应式变量改变后才会触发
    globalVersion++

  • dep.version
    是在
    dep
    依赖上面的一个属性,初始值是0。当dep依赖是ref这种普通响应式变量,仅有响应式变量改变后才会触发
    dep.version++
    。当computed计算属性作为dep依赖时,只有等computed最终计算出来的值改变后才会触发
    dep.version++

  • link.version
    是Link节点上面的一个属性,初始值是0。每次响应式更新完了后都会保持和
    dep.version
    的值相同。在响应式更新前就是通过
    link.version

    dep.version
    的值是否相同判断是否需要更新。

  • computed.globalVersion
    :计算属性上面的版本,如果
    computed.globalVersion === globalVersion
    说明没有响应式变量改变,计算属性的回调就不需要重新执行。

而版本计数最大的受益者就是computed计算属性,这篇文章接下来我们将以computed举例讲解。

看个例子

我们来看个简单的demo,代码如下:

<template>
  <p>{{ doubleCount }}</p>
  <button @click="flag = !flag">切换flag</button>
  <button @click="count1++">count1++</button>
  <button @click="count2++">count2++</button>
</template>

<script setup>
import { computed, ref } from "vue";
const count1 = ref(1);
const count2 = ref(10);
const flag = ref(true);

const doubleCount = computed(() => {
  console.log("computed");
  if (flag.value) {
    return count1.value * 2;
  } else {
    return count2.value * 2;
  }
});
</script>

在computed中根据
flag.value
的值去决定到底返回
count1.value * 2
还是
count2.value * 2

那么问题来了,当
flag
的值为
true
时,点击
count2++
按钮,
console.log("computed")
会执行打印吗?也就是
doubleCount
的值会重新计算吗?

答案是:
不会
。虽然
count2
也是computed中使用到的响应式变量,但是他不参与返回值的计算,所以改变他不会导致computed重新计算。

有的同学想问为什么能够做到这么精细的控制呢?这就要归功于
版本计数
了,我们接下来会细讲。

依赖触发

还是前面那个demo,初始化时
flag
的值是true,所以在computed中会对
count1
变量进行读操作,然后触发get拦截。
count1
这个ref响应式变量就是由
RefImpl
类new出来的一个对象,代码如下:

class RefImpl {
  dep: Dep = new Dep();
  get value() {
    this.dep.track()
  }
  set value() {
    this.dep.trigger();
  }
}

在get拦截中会执行
this.dep.track()
,其中
dep
是由
Dep
类new出来的对象,代码如下

class Dep {
  version = 0;
  track() {
    let link = new Link(activeSub, this);
    // ...省略
  }
  trigger() {
    this.version++;
    globalVersion++;
    this.notify();
  }
}


track
方法中使用
Link
类new出来一个link对象,
Link
类代码如下:

class Link {
  version: number

  /**
   * Pointers for doubly-linked lists
   */
  nextDep?: Link
  prevDep?: Link
  nextSub?: Link
  prevSub?: Link
  prevActiveLink?: Link

  constructor(
    public sub: Subscriber,
    public dep: Dep,
  ) {
    this.version = dep.version
    this.nextDep =
      this.prevDep =
      this.nextSub =
      this.prevSub =
      this.prevActiveLink =
        undefined
  }
}

这里我们只关注Link中的
version
属性,其他的属性在上一篇双向链表文章中已经讲过了。


constructor
中使用
dep.version

link.version
赋值,保证
dep.version

link.version
的值是相等的,也就是等于0。因为
dep.version
的初始值是0,接着就会讲。

当我们点击
count1++
按钮时会让响应式变量
count1
的值自增。因为
count1
是一个ref响应式变量,所以会触发其set拦截。代码如下:

class RefImpl {
  dep: Dep = new Dep();
  get value() {
    this.dep.track()
  }
  set value() {
    this.dep.trigger();
  }
}

在set拦截中执行的是
this.dep.trigger()

trigger
函数代码如下:

class Dep {
  version = 0;
  track() {
    let link = new Link(activeSub, this);
    // ...省略
  }
  trigger() {
    this.version++;
    globalVersion++;
    this.notify();
  }
}

前面讲过了
globalVersion
是一个全局变量,初始值为0。

dep上面的
version
属性初始值也是0。


trigger
中分别执行了
this.version++

globalVersion++
,这里的this就是指向的dep。执行完后
dep.version

globalVersion
的值就是1了。而此时
link.version
的值依然还是0,这个时候
dep.version

link.version
的值就已经不相等了。

接着就是执行
notify
方法按照新的响应式模型进行通知订阅者进行更新,我们这个例子此时新的响应式模型如下图:
reactive

如果修改的响应式变量会触发多个订阅者,比如
count1
变量被多个
watchEffect
使用,修改
count1
变量的值就需要触发多个订阅者的更新。
notify
方法中正是将多个更新操作放到一个批次中处理,从而提高性能。由于篇幅有限我们就不去细讲
notify
方法的内容,你只需要知道执行
notify
方法就会触发订阅者的更新。

(这两段是
notify
方法内的逻辑)按照正常的逻辑如果
count1
变量的值改变,就可以通过
Link2
节点找到
Sub1
订阅者,然后执行订阅者的
notify
方法从而进行更新。

如果我们的
Sub1
订阅者是render函数,是这个正常的逻辑。但是此时我们的
Sub1
订阅者是计算属性
doubleCount
,这里会有一个优化,如果订阅者是一个计算属性,触发其更新时不会直接执行计算属性的回调函数,而是直接去通知计算属性的订阅者去更新,在更新前才会去执行计算属性的回调函数(这个接下来的文章会讲)。代码如下:

if (link.sub.notify()) {
  // if notify() returns `true`, this is a computed. Also call notify
  // on its dep - it's called here instead of inside computed's notify
  // in order to reduce call stack depth.
  link.sub.dep.notify()
}

link.sub.notify()
的执行结果是true就代表当前的订阅者是计算属性,然后就会触发计算属性“作为依赖”时对应的订阅者。我们这里的计算属性
doubleCount
是在template中使用,所以计算属性
doubleCount
的订阅者就是render函数。

所以这里就是调用
link.sub.notify()
不会触发计算属性
doubleCount
中的回调函数重新执行,而是去触发计算属性
doubleCount
的订阅者,也就是render函数。在执行render函数之前会再去通过
脏检查
(依靠版本计数实现)去判断是否需要重新执行计算属性的回调,如果需要执行计算属性的回调那么就去执行render函数重新渲染。

脏检查

所有的
Sub订阅者
内部都是基于
ReactiveEffect
类去实现的,调用订阅者的
notify
方法通知更新实际底层就是在调用
ReactiveEffect
类中的
runIfDirty
方法。代码如下:

class ReactiveEffect<T = any> implements Subscriber, ReactiveEffectOptions {
  /**
   * @internal
   */
  runIfDirty(): void {
    if (isDirty(this)) {
      this.run();
    }
  }
}


runIfDirty
方法中首先会调用
isDirty
方法判断当前是否需要更新,如果返回true,那么就执行
run
方法去执行Sub订阅者的回调函数进行更新。如果是
computed

watch

watchEffect
等订阅者调用run方法就会执行其回调函数,如果是render函数这种订阅者调用run方法就会再次执行render函数。

调用
isDirty
方法时传入的是this,值得注意的是this是指向
ReactiveEffect
实例。而
ReactiveEffect
又是继承自
Subscriber
订阅者,所以这里的this是指向的是订阅者。

前面我们讲过了,修改响应式变量
count1
的值时会通知
作为订阅者

doubleCount
计算属性。当通知
作为订阅者
的计算属性更新时不会去像watchEffect这样的订阅者一样去执行其回调,而是去通知计算属性
作为Dep依赖
时订阅他的订阅者进行更新。在这里计算属性
doubleCount
是在template中使用,所以他的订阅者是render函数。

所以修改count1变量执行runIfDirty时此时触发的订阅者是作为Sub订阅者的render函数,也就是说此时的this是render函数!!

我们来看看
isDirty
是如何进行脏检查,代码如下:

function isDirty(sub: Subscriber): boolean {
  for (let link = sub.deps; link; link = link.nextDep) {
    if (
      link.dep.version !== link.version ||
      (link.dep.computed &&
        (refreshComputed(link.dep.computed) ||
          link.dep.version !== link.version))
    ) {
      return true;
    }
  }
  return false;
}

这里就涉及到我们上一节讲过的双向链表了,回顾一下前面讲过的响应式模型图,如下图:
reactive
此时的sub订阅者是render函数,也就是图中的
Sub2

sub.deps
是指向指向
Sub2
订阅者X轴(横向)上面的Link节点组成的队列的头部,
link.nextDep
就是指向X轴上面下一个Link节点,通过Link节点就可以访问到对应的Dep依赖。

在这里render函数对应的订阅者
Sub2
在X轴上面只有一个节点
Link3

这里的for循环就是去便利Sub订阅者在X轴上面的所有Link节点,然后在for循环内部去通过Link节点访问到对应的Dep依赖去做版本计数的判断。

这里的for循环内部的if语句判断主要分为两部分:

 if (
  link.dep.version !== link.version ||
  (link.dep.computed &&
    (refreshComputed(link.dep.computed) ||
      link.dep.version !== link.version))
) {
  return true;
}

这两部分中只要有一个是true,那么就说明当前Sub订阅者需要更新,也就是执行其回调。

我们来看看第一个判断:

link.dep.version !== link.version

还记得我们前面讲过吗,初始化时会保持
dep.version

link.version
的值相同。每次响应式变量改变时走到set拦截中,在拦截中会去执行
dep.version++
,执行完了后此时
dep.version

link.version
的值就已经不相同了,在这里就能知道此时响应式变量改变过了,需要通知Sub订阅者更新执行其回调。

常规情况下Dep依赖是一个ref变量、Sub订阅者是wachEffect这种确实第一个判断就可以满足了。

但是我们这里的
link.dep
是计算属性
doubleCount
,计算属性是由
ComputedRefImpl
类new出来的对象,简化后代码如下:

class ComputedRefImpl<T = any> implements Subscriber {
  _value: any = undefined;
  readonly dep: Dep = new Dep(this);
  globalVersion: number = globalVersion - 1;
  get value(): T {
    // ...省略
  }
  set value(newValue) {
    // ...省略
  }
}

ComputedRefImpl
继承了
Subscriber
类,所以说他是一个订阅者。同时还有get和set拦截,以及初始化一个计算属性时也会去new一个对应的Dep依赖。

还有一点值得注意的是计算属性上面的
computed.globalVersion
属性初始值为
globalVersion - 1
,默认是不等于
globalVersion
的,这是为了第一次执行计算属性时能够去触发执行计算属性的回调,这个在后面的
refreshComputed
函数中会讲。

我们是直接修改的
count1
变量,在
count1
变量的set拦截中触发了
dep.version++
,但是并没有修改计算属性对应的
dep.version
。所以当计算属性作为依赖时单纯的使用
link.dep.version !== link.version
就不能满足需求了,需要使用到第二个判断:

(link.dep.computed &&
    (refreshComputed(link.dep.computed) ||
      link.dep.version !== link.version))

在第二个判断中首先判断当前当前的Dep依赖是不是计算属性,如果是就调用
refreshComputed
函数去执行计算属性的回调。然后判断计算属性的结果是否改变,如果改变了在
refreshComputed
函数中就会去执行
link.dep.version++
,所以执行完
refreshComputed
函数后
link.dep.version

link.version
的值就不相同了,表示计算属性的值更新了,当然就需要执行依赖计算属性的render函数啦。

refreshComputed函数

我们来看看
refreshComputed
函数的代码,简化后的代码如下:

function refreshComputed(computed: ComputedRefImpl): undefined {
  if (computed.globalVersion === globalVersion) {
    return;
  }
  computed.globalVersion = globalVersion;

  const dep = computed.dep;
  try {
    prepareDeps(computed);
    const value = computed.fn(computed._value);
    if (dep.version === 0 || hasChanged(value, computed._value)) {
      computed._value = value;
      dep.version++;
    }
  } catch (err) {
    dep.version++;
    throw err;
  } finally {
    cleanupDeps(computed);
  }
}

首先会去判断
computed.globalVersion === globalVersion
是否相等,如果相等就说明根本就没有响应式变量改变,那么当然就无需去重新执行计算属性回调。

还记得我们前面讲过每当响应式变量改变后触发set拦截是都会执行
globalVersion++
吗?所以这里就可以通过
computed.globalVersion === globalVersion
判断是否有响应式变量改变,如果没有说明计算属性的值肯定就没有改变。

接着就是执行
computed.globalVersion = globalVersion

computed.globalVersion
的值同步为
globalVersion
,为了下次判断是否需要重新执行计算属性做准备。

在try中会先去执行
prepareDeps
函数,这个先放放接下来讲,先来看看try中其他的代码。

首先调用
const value = computed.fn(computed._value)
去重新执行计算属性的回调函数拿到计算属性新的返回值
value

接着就是执行
if (dep.version === 0 || hasChanged(value, computed._value))

我们前面讲过了dep上面的version默认值为0,这里的
dep.version === 0
说明是第一次渲染计算属性。接着就是使用
hasChanged(value, computed._value)
判断计算属性新的值和旧的值相比较是否有修改。

上面这两个条件满足一个就执行if里面的内容,将新得到的计算属性的值更新上去,并且执行
dep.version++
。因为前面讲过了在外面会使用
link.dep.version !== link.version
判断dep的版本是否和link上面的版本是否相同,如果不相等就执行render函数。

这里由于计算属性的值确实改变了,所以会执行
dep.version++
,dep的版本和link上面的版本此时就不同了,所以就会被标记为dirty,从而执行render函数。

如果执行计算属性的回调函数出错了,同样也执行一次
dep.version++

最后就是剩余执行计算属性回调函数之前调用的
prepareDeps
和finally调用的
cleanupDeps
函数没讲了。

更新响应式模型

回顾一下demo的代码:

<template>
  <p>{{ doubleCount }}</p>
  <button @click="flag = !flag">切换flag</button>
  <button @click="count1++">count1++</button>
  <button @click="count2++">count2++</button>
</template>

<script setup>
import { computed, ref } from "vue";
const count1 = ref(1);
const count2 = ref(10);
const flag = ref(true);

const doubleCount = computed(() => {
  console.log("computed");
  if (flag.value) {
    return count1.value * 2;
  } else {
    return count2.value * 2;
  }
});
</script>


flag
的值为true时,对应的响应式模型前面我们已经讲过了,如下图:
reactive

如果我们将
flag
的值设置为false呢?此时的计算属性
doubleCount
就不再依赖于响应式变量
count1
,而是依赖于响应式变量
count2
。小伙伴们猜猜此时的响应式模型应该是什么样的呢?
reactive2

现在多了一个
count2
变量对应的
Link4
,原本
Link1

Link2
之间的连接也因为计算属性不再依赖于
count1
变量后,他们俩之间的连接也没有了,转而变成了
Link1

Link4
之间建立连接。

前面没有讲的
prepareDeps

cleanupDeps
函数就是去掉
Link1

Link2
之间的连接。

prepareDeps
函数代码如下:

function prepareDeps(sub: Subscriber) {
  // Prepare deps for tracking, starting from the head
  for (let link = sub.deps; link; link = link.nextDep) {
    // set all previous deps' (if any) version to -1 so that we can track
    // which ones are unused after the run
    link.version = -1
    // store previous active sub if link was being used in another context
    link.prevActiveLink = link.dep.activeLink
    link.dep.activeLink = link
  }
}

这里使用for循环遍历计算属性Sub1在X轴上面的Link节点,也就是Link1和Link2,并且将这些Link节点的
version
属性设置为-1。


flag
的值设置为false后,重新执行计算属性
doubleCount
中的回调函数时,就会对回调函数中的所有响应式变量进行读操作。从而再次触发响应式变量的get拦截,然后执行
track
方法进行依赖收集。注意此时新收集了一个响应式变量
count2
。收集完成后响应式模型图如下图:
reactive3

从上图中可以看到虽然计算属性虽然不再依赖
count1
变量,但是
count1
变量变量对应的
Link2
节点还在队列的连接上。

我们在
prepareDeps
方法中将计算属性依赖的所有Link节点的version属性都设置为-1,在
track
方法收集依赖时会执行这样一行代码,如下:

class Dep {
  track() {
    if (link === undefined || link.sub !== activeSub) {
      // ...省略
    } else if (link.version === -1) {
      link.version = this.version;
      // ...省略
    }
  }
}

如果
link.version === -1
,那么就将
link.version
的值同步为
dep.version
的值。

只有计算属性最新依赖的响应式变量才会触发
track
方法进行依赖收集,从而将对应的
link.version

-1
更新为
dep.version

而变量
count1
现在已经不会触发
track
方法了,所以变量
count1
对应的
link.version
的值还是
-1

最后就是执行
cleanupDeps
函数将
link.version
的值还是-1的响应式变量(也就是不再使用的
count1
变量)对应的Link节点,从双向链表中给干掉。代码如下:

function cleanupDeps(sub: Subscriber) {
  // Cleanup unsued deps
  let head;
  let tail = sub.depsTail;
  let link = tail;
  while (link) {
    const prev = link.prevDep;
    if (link.version === -1) {
      if (link === tail) tail = prev;
      // unused - remove it from the dep's subscribing effect list
      removeSub(link);
      // also remove it from this effect's dep list
      removeDep(link);
    } else {
      // The new head is the last node seen which wasn't removed
      // from the doubly-linked list
      head = link;
    }

    // restore previous active link if any
    link.dep.activeLink = link.prevActiveLink;
    link.prevActiveLink = undefined;
    link = prev;
  }
  // set the new head & tail
  sub.deps = head;
  sub.depsTail = tail;
}

遍历Sub1计算属性横向队列(X轴)上面的Link节点,当
link.version === -1
时,说明这个Link节点对应的Dep依赖已经不被计算属性所依赖了,所以执行
removeSub

removeDep
将其从双向链表中移除。

执行完
cleanupDeps
函数后此时的响应式模型就是我们前面所提到的样子,如下图:
reactive2

总结

版本计数主要有四个版本:全局变量
globalVersion

dep.version

link.version

computed.globalVersion

dep.version

link.version
如果不相等就说明当前响应式变量的值改变了,就需要让Sub订阅者进行更新。

如果是计算属性作为Dep依赖时就不能通过
dep.version

link.version
去判断了,而是执行
refreshComputed
函数进行判断。在
refreshComputed
函数中首先会判断
globalVersion

computed.globalVersion
是否相等,如果相等就说明并没有响应式变量更新。如果不相等那么就会执行计算属性的回调函数,拿到最新的值后去比较计算属性的值是否改变。并且还会执行
prepareDeps

cleanupDeps
函数将那些计算属性不再依赖的响应式变量对应的Link节点从双向链表中移除。

最后说一句,版本计数最大的赢家应该是computed计算属性,虽然引入版本计数后代码更难理解了。但是整体流程更加优雅,以及现在只需要通过判断几个version是否相等就能知道订阅者是否需要更新,性能当然也更好了。

关注公众号:【前端欧阳】,给自己一个进阶vue的机会

另外欧阳写了一本开源电子书
vue3编译原理揭秘
,看完这本书可以让你对vue编译的认知有质的提升。这本书初、中级前端能看懂,完全免费,只求一个star。