2024年4月

前言

对于后端开发同学来说,访问数据库,是代码中必不可少的一个环节。

系统中收集到用户的核心数据,为了安全性,我们一般会存储到数据库,比如:mysql,oracle等。

后端开发的日常工作,需要不断的建库和建表,来满足业务需求。

通常情况下,建库的频率比建表要低很多,所以,我们这篇文章主要讨论建表相关的内容。

如果我们在建表的时候不注意细节,等后面系统上线之后,表的维护成本变得非常高,而且很容易踩坑。

今天就跟大家一起聊聊,数据库建表的18个小技巧。

文章中介绍的很多细节,我在工作中踩过坑,并且实践过的,非常有借鉴意义,希望对你会有所帮助。

1.名字

建表的时候,给


字段

索引
起个好名字,真的太重要了。

1.1 见名知意

名字就像


字段

索引
的一张脸,可以给人留下第一印象。

好的名字,言简意赅,见名知意,让人心情愉悦,能够提高沟通和维护成本。

坏的名字,模拟两可,不知所云。而且显得杂乱无章,看得让人抓狂。

反例:

用户名称字段定义成:yong_hu_ming、用户_name、name、user_name_123456789

你看了可能会一脸懵逼,这是什么骚操作?

正例:

用户名称字段定义成:user_name

温馨提醒一下,名字也不宜过长,尽量控制在
30
个字符以内。

1.2 大小写

名字尽量都用
小写字母
,因为从视觉上,小写字母更容易让人读懂。

反例:

字段名:PRODUCT_NAME、PRODUCT_name

全部大写,看起来有点不太直观。而一部分大写,一部分小写,让人看着更不爽。

正例:

字段名:product_name

名字还是使用全小写字母,看着更舒服。

1.3 分隔符

很多时候,名字为了让人好理解,有可能会包含多个单词。

那么,多个单词间的
分隔符
该用什么呢?

反例:

字段名:productname、productName、product name、product@name

单词间没有分隔,或者单词间用驼峰标识,或者单词间用空格分隔,或者单词间用@分隔,这几种方式都不太建议。

正例:

字段名:product_name

强烈建议大家在单词间用
_
分隔。

1.4 表名

对于表名,在言简意赅,见名知意的基础之上,建议带上
业务前缀

如果是订单相关的业务表,可以在表名前面加个前缀:
order_

例如:order_pay、order_pay_detail等。

如果是商品相关的业务表,可以在表名前面加个前缀:
product_

例如:product_spu,product_sku等。

这样做的好处是为了方便归类,把相同业务的表,可以非常快速的聚集到一起。

另外,还有有个好处是,如果哪天有非订单的业务,比如:金融业务,也需要建一个名字叫做pay的表,可以取名:finance_pay,就能非常轻松的区分。

这样就不会出现
同名表
的情况。

1.5 字段名称

字段名称
是开发人员发挥空间最大,但也最容易发生混乱的地方。

比如有些表,使用flag表示状态,另外的表用status表示状态。

可以统一一下,使用status表示状态。

如果一个表使用了另一个表的主键,可以在另一张表的名后面,加
_id

_sys_no
,例如:

在product_sku表中有个字段,是product_spu表的主键,这时候可以取名:product_spu_id或product_spu_sys_no。

还有创建时间,可以统一成:create_time,修改时间统一成:update_time。

删除状态固定为:delete_status。

其实还有很多公共字段,在不同的表之间,可以使用全局统一的命名规则,定义成相同的名称,以便于大家好理解。

1.6 索引名

在数据库中,索引有很多种,包括:主键、普通索引、唯一索引、联合索引等。

每张表的主键只有一个,一般使用:
id
或者
sys_no
命名。

普通索引和联合索引,其实是一类。在建立该类索引时,可以加
ix_
前缀,比如:ix_product_status。

唯一索引,可以加
ux_
前缀,比如:ux_product_code。

2.字段类型

在设计表时,我们在选择
字段类型
时,可发挥空间很大。

时间格式的数据有:date、datetime和timestamp等等可以选择。

字符类型的数据有:varchar、char、text等可以选择。

数字类型的数据有:int、bigint、smallint、tinyint等可以选择。

说实话,选择很多,有时候是一件好事,也可能是一件坏事。

如何选择一个
合适
的字段类型,变成了我们不得不面对的问题。

如果字段类型选大了,比如:原本只有1-10之间的10个数字,结果选了
bigint
,它占
8
个字节。

其实,1-10之间的10个数字,每个数字
1
个字节就能保存,选择
tinyint
更为合适。

这样会白白浪费7个字节的空间。

如果字段类型择小了,比如:一个18位的id字段,选择了
int
类型,最终数据会保存失败。

所以选择一个合适的字段类型,还是非常重要的一件事情。

以下原则可以参考一下:

  1. 尽可能选择占用存储空间小的字段类型,在满足正常业务需求的情况下,从小到大,往上选。
  2. 如果字符串长度固定,或者差别不大,可以选择char类型。如果字符串长度差别较大,可以选择varchar类型。
  3. 是否字段,可以选择bit类型。
  4. 枚举字段,可以选择tinyint类型。
  5. 主键字段,可以选择bigint类型。
  6. 金额字段,可以选择decimal类型。
  7. 时间字段,可以选择timestamp或datetime类型。

3.字段长度

前面我们已经定义好了
字段名称
,选择了合适的
字段类型
,接下来,需要重点关注的是
字段长度
了。

比如:varchar(20),biginit(20)等。

那么问题来了,
varchar
代表的是
字节
长度,还是
字符
长度呢?

答:在mysql中除了
varchar

char
是代表
字符
长度之外,其余的类型都是代表
字节
长度。

biginit(n) 这个
n
表示什么意思呢?

假如我们定义的字段类型和长度是:bigint(4),bigint实际长度是
8
个字节。

现在有个数据a=1,a显示4个字节,所以在不满4个字节时前面填充0(前提是该字段设置了zerofill属性),比如:0001。

当满了4个字节时,比如现在数据是a=123456,它会按照实际的长度显示,比如:123456。

但需要注意的是,有些mysql客户端即使满了4个字节,也可能只显示4个字节的内容,比如会显示成:1234。

所以bigint(4),这里的4表示显示的长度为4个字节,实际长度还是占8个字节。

4.字段个数

我们在建表的时候,一定要对
字段个数
做一些限制。

我之前见过有人创建的表,有几十个,甚至上百个字段,表中保存的数据非常大,查询效率很低。

如果真有这种情况,可以将一张
大表
拆成多张
小表
,这几张表的主键相同。

建议每表的字段个数,不要超过
20
个。

5. 主键

在创建表时,一定要创建
主键

因为主键自带了主键索引,相比于其他索引,主键索引的查询效率最高,因为它不需要回表。

此外,主键还是天然的
唯一索引
,可以根据它来判重。


单个
数据库中,主键可以通过
AUTO_INCREMENT
,设置成
自动增长
的。

但在
分布式
数据库中,特别是做了分库分表的业务库中,主键最好由外部算法(比如:雪花算法)生成,它能够保证生成的id是全局唯一的。

除此之外,主键建议保存跟业务无关的值,减少业务耦合性,方便今后的扩展。

不过我也见过,有些一对一的表关系,比如:用户表和用户扩展表,在保存数据时是一对一的关系。

这样,用户扩展表的主键,可以直接保存用户表的主键。

6.存储引擎


mysql8
以前的版本,默认的存储引擎是
myslam
,而
mysql8
以后的版本,默认的存储引擎变成了
innodb

之前我们还在创建表时,还一直纠结要选哪种存储引擎?

myslam
的索引和数据分开存储,而有利于查询,但它不支持事务和外键等功能。


innodb
虽说查询性能,稍微弱一点,但它支持事务和外键等,功能更强大一些。

以前的建议是:读多写少的表,用myslam存储引擎。而写多读多的表,用innodb。

但虽说mysql对innodb存储引擎性能的不断优化,现在myslam和innodb查询性能相差已经越来越小。

所以,建议我们在使用
mysql8
以后的版本时,直接使用默认的
innodb
存储引擎即可,无需额外修改存储引擎。

7. NOT NULL

在创建字段时,需要选择该字段是否允许为
NULL

我们在定义字段时,应该尽可能明确该字段
NOT NULL

为什么呢?

我们主要以innodb存储引擎为例,myslam存储引擎没啥好说的。

主要有以下原因:

  1. 在innodb中,需要额外的空间存储null值,需要占用更多的空间。
  2. null值可能会导致索引失效。
  3. null值只能用
    is null
    或者
    is not null
    判断,用
    =号
    判断永远返回false。

因此,建议我们在定义字段时,能定义成NOT NULL,就定义成NOT NULL。

但如果某个字段直接定义成NOT NULL,万一有些地方忘了给该字段写值,就会
insert
不了数据。

这也算合理的情况。

但有一种情况是,系统有新功能上线,新增了字段。上线时一般会先执行sql脚本,再部署代码。

由于老代码中,不会给新字段赋值,则insert数据时,也会报错。

由此,非常有必要给NOT NULL的字段设置默认值,特别是后面新增的字段。

例如:

alter table product_sku add column  brand_id int(10) not null default 0;

8.外键

在mysql中,是存在
外键
的。

外键存在的主要作用是:保证数据的
一致性

完整性

例如:

create table class (
  id int(10) primary key auto_increment,
  cname varchar(15)
);

有个班级表class。

然后有个student表:

create table student(
  id int(10) primary key auto_increment,
  name varchar(15) not null,
  gender varchar(10) not null,
  cid int,
  foreign key(cid) references class(id)
);

其中student表中的cid字段,保存的class表的id,这时通过
foreign key
增加了一个外键。

这时,如果你直接通过student表的id删除数据,会报异常:

a foreign key constraint fails

必须要先删除class表对于的cid那条数据,再删除student表的数据才行,这样能够保证数据的一致性和完整性。

顺便说一句:只有存储引擎是innodb时,才能使用外键。

如果只有两张表的关联还好,但如果有十几张表都建了外键关联,每删除一次主表,都需要同步删除十几张子表,很显然性能会非常差。

因此,互联网系统中,一般建议不使用外键。因为这类系统更多的是为了性能考虑,宁可牺牲一点数据一致性和完整性。

除了
外键
之外,
存储过程

触发器
也不太建议使用,他们都会影响性能。

9. 索引

在建表时,除了指定
主键索引
之外,还需要创建一些
普通索引

例如:

create table product_sku(
  id int(10) primary key auto_increment,
  spu_id int(10) not null,
  brand_id int(10) not null,
  name varchar(15) not null
);

在创建商品表时,使用spu_id(商品组表)和brand_id(品牌表)的id。

像这类保存其他表id的情况,可以增加普通索引:

create table product_sku (
  id int(10) primary key auto_increment,
  spu_id int(10) not null,
  brand_id int(10) not null,
  name varchar(15) not null,
	 KEY `ix_spu_id` (`spu_id`) USING BTREE,
	 KEY `ix_brand_id` (`brand_id`) USING BTREE
);

后面查表的时候,效率更高。

但索引字段也不能建的太多,可能会影响保存数据的效率,因为索引需要额外的存储空间。

建议单表的索引个数不要超过:
5
个。

如果在建表时,发现索引个数超过5个了,可以删除部分
普通索引
,改成
联合索引

顺便说一句:在创建联合索引的时候,需要使用注意
最左匹配原则
,不然,建的联合索引效率可能不高。

对于数据重复率非常高的字段,比如:状态,不建议单独创建普通索引。因为即使加了索引,如果mysql发现
全表扫描
效率更高,可能会导致索引失效。

如果你对索引失效问题比较感兴趣,可以看看我的另一篇文章《
聊聊索引失效的10种场景,太坑了
》,里面有非常详细的介绍。

10.时间字段

时间字段
的类型,我们可以选择的范围还是比较多的,目前mysql支持:date、datetime、timestamp、varchar等。

varchar
类型可能是为了跟接口保持一致,接口中的时间类型是String。

但如果哪天我们要通过时间范围查询数据,效率会非常低,因为这种情况没法走索引。

date
类型主要是为了保存
日期
,比如:2020-08-20,不适合保存
日期和时间
,比如:2020-08-20 12:12:20。


datetime

timestamp
类型更适合我们保存
日期和时间

但它们有略微区别。

  • timestamp
    :用4个字节来保存数据,它的取值范围为
    1970-01-01 00:00:01
    UTC ~
    2038-01-19 03:14:07
    。此外,它还跟时区有关。

  • datetime
    :用8个字节来保存数据,它的取值范围为
    1000-01-01 00:00:00
    ~
    9999-12-31 23:59:59
    。它跟时区无关。

优先推荐使用
datetime
类型保存日期和时间,可以保存的时间范围更大一些。

温馨提醒一下,在给时间字段设置默认值是,建议不要设置成:
0000-00-00 00:00:00
,不然查询表时可能会因为转换不了,而直接报错。

11.金额字段

mysql中有多个字段可以表示浮点数:float、double、decimal等。


float

double
可能会丢失精度,因此推荐大家使用
decimal
类型保存金额。

一般我们是这样定义浮点数的:decimal(m,n)。

其中
n
是指
小数
的长度,而
m
是指
整数加小数
的总长度。

假如我们定义的金额类型是这样的:decimal(10,2),则表示整数长度是8位,并且保留2位小数。

12. json字段

我们在设计表结构时,经常会遇到某个字段保存的数据值不固定的需求。

举个例子,比如:做异步excel导出功能时,需要在异步任务表中加一个字段,保存用户通过前端页面选择的查询条件,每个用户的查询条件可能都不一样。

这种业务场景,使用传统的数据库字段,不太好实现。

这时候就可以使用MySQL的json字段类型了,可以保存json格式的结构化数据。

保存和查询数据都是非常方便的。

MySQL还支持按字段名称或者字段值,查询json中的数据。

最近就业形式比较困难,为了感谢各位小伙伴对苏三一直以来的支持,我特地创建了一些工作内推群, 看看能不能帮助到大家。
你可以在群里发布招聘信息,也可以内推工作,也可以在群里投递简历找工作,也可以在群里交流面试或者工作的话题。

进群方式,添加苏三的私人微信:su_san_java,备注:博客园+所在城市,即可加入。

13.唯一索引

唯一索引
在我们实际工作中,使用频率相当高。

你可以给单个字段,加唯一索引,比如:组织机构code。

也可以给多个字段,加一个联合的唯一索引,比如:分类编号、单位、规格等。

单个的唯一索引还好,但如果是联合的唯一索引,字段值出现null时,则唯一性约束可能会失效。

关于唯一索引失效的问题,感兴趣的小伙伴可以看看我的另一篇文章《
明明加了唯一索引,为什么还是产生重复数据?
》。

创建唯一索引时,相关字段一定不能包含null值,否则唯一性会失效。

14.字符集

mysql中支持的
字符集
有很多,常用的有:latin1、utf-8、utf8mb4、GBK等。

这4种字符集情况如下:

latin1
容易出现乱码问题,在实际项目中使用比较少。


GBK
支持中文,但不支持国际通用字符,在实际项目中使用也不多。

从目前来看,mysql的字符集使用最多的还是:
utf-8

utf8mb4

其中
utf-8
占用3个字节,比
utf8mb4
的4个字节,占用更小的存储空间。

但utf-8有个问题:即无法存储emoji表情,因为emoji表情一般需要4个字节。

由此,使用utf-8字符集,保存emoji表情时,数据库会直接报错。

所以,建议在建表时字符集设置成:
utf8mb4
,会省去很多不必要的麻烦。

15. 排序规则

不知道,你关注过没,在mysql中创建表时,有个
COLLATE
参数可以设置。

例如:

CREATE TABLE `order` (
  `id` bigint NOT NULL AUTO_INCREMENT,
  `code` varchar(20) COLLATE utf8mb4_bin NOT NULL,
  `name` varchar(30) COLLATE utf8mb4_bin NOT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `un_code` (`code`),
  KEY `un_code_name` (`code`,`name`) USING BTREE,
  KEY `idx_name` (`name`)
) ENGINE=InnoDB AUTO_INCREMENT=5 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin

它是用来设置
排序规则
的。

字符排序规则跟字符集有关,比如:字符集如果是
utf8mb4
,则字符排序规则也是以:
utf8mb4_
开头的,常用的有:
utf8mb4_general_ci

utf8mb4_bin
等。

其中utf8mb4_general_ci排序规则,对字母的大小写不敏感。说得更直白一点,就是不区分大小写。

而utf8mb4_bin排序规则,对字符大小写敏感,也就是区分大小写。

说实话,这一点还是非常重要的。

假如order表中现在有一条记录,name的值是大写的YOYO,但我们用小写的yoyo去查,例如:

select * from order where name='yoyo';

如果字符排序规则是utf8mb4_general_ci,则可以查出大写的YOYO的那条数据。

如果字符排序规则是utf8mb4_bin,则查不出来。

由此,字符排序规则一定要根据实际的业务场景选择,否则容易出现问题。

16.大字段

我们在创建表时,对一些特殊字段,要额外关注,比如:
大字段
,即占用较多存储空间的字段。

比如:用户的评论,这就属于一个大字段,但这个字段可长可短。

但一般会对评论的总长度做限制,比如:最多允许输入500个字符。

如果直接定义成
text
类型,可能会浪费存储空间,所以建议将这类字段定义成
varchar
类型的存储效率更高。

当然,我还见过更大的字段,即该字段直接保存合同数据。

一个合同可能会占
几Mb

在mysql中保存这种数据,从系统设计的角度来说,本身就不太合理。

像合同这种非常大的数据,可以保存到
mongodb
中,然后在mysql的业务表中,保存mongodb表的id。

17.冗余字段

我们在设计表的时候,为了性能考虑,提升查询速度,有时可以冗余一些字段。

举个例子,比如:订单表中一般会有userId字段,用来记录用户的唯一标识。

但很多订单的查询页面,或者订单的明细页面,除了需要显示订单信息之外,还需要显示用户ID和用户名称。

如果订单表和用户表的数据量不多,我们可以直接用userId,将这两张表join起来,查询出用户名称。

但如果订单表和用户表的数据量都非常多,这样join是比较消耗查询性能的。

这时候我们可以通过冗余字段的方案,来解决性能问题。

我们可以在订单表中,可以再加一个userName字段,在系统创建订单时,将userId和userName同时写值。

当然订单表中历史数据的userName是空的,可以刷一下历史数据。

这样调整之后,后面只需要查询订单表,即可查询出我们所需要的数据。

不过冗余字段的方案,有利也有弊。

对查询性能有利。

但需要额外的存储空间,还可能会有数据不一致的情况,比如用户名称修改了。

我们在实际业务场景中,需要综合评估,冗余字段方案不适用于所有业务场景。

18.注释

我们在做表设计的时候,一定要把表和相关字段的注释加好。

例如下面这样的:

CREATE TABLE `sys_dept` (
  `id` bigint NOT NULL AUTO_INCREMENT COMMENT 'ID',
  `name` varchar(30) NOT NULL COMMENT '名称',
  `pid` bigint NOT NULL COMMENT '上级部门',
  `valid_status` tinyint(1) NOT NULL DEFAULT 1 COMMENT '有效状态 1:有效 0:无效',
  `create_user_id` bigint NOT NULL COMMENT '创建人ID',
  `create_user_name` varchar(30) NOT NULL COMMENT '创建人名称',
  `create_time` datetime(3) DEFAULT NULL COMMENT '创建日期',
  `update_user_id` bigint DEFAULT NULL COMMENT '修改人ID',
  `update_user_name` varchar(30)  DEFAULT NULL COMMENT '修改人名称',
  `update_time` datetime(3) DEFAULT NULL COMMENT '修改时间',
  `is_del` tinyint(1) DEFAULT '0' COMMENT '是否删除 1:已删除 0:未删除',
  PRIMARY KEY (`id`) USING BTREE,
  KEY `index_pid` (`pid`) USING BTREE
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='部门';

表和字段的注释,都列举的非常详细。

特别是有些状态类型的字段,比如:valid_status字段,该字段表示有效状态, 1:有效 0:无效。

让人可以一目了然,表和字段是干什么用的,字段的值可能有哪些。

最怕的情况是,你在表中创建了很多status字段,每个字段都有1、2、3、4、5、6、7、8、9等多个值。

没有写什么注释。

谁都不知道1代表什么含义,2代表什么含义,3代表什么含义。

可能刚开始你还记得。

但系统上线使用一年半载之后,可能连你自己也忘记了这些status字段,每个值的具体含义了,埋下了一个巨坑。

由此,我们在做表设计时,一定要写好相关的注释,并且经常需要更新这些注释。

最后说一句(求关注,别白嫖我)

如果这篇文章对您有所帮助,或者有所启发的话,帮忙扫描下发二维码关注一下,您的支持是我坚持写作最大的动力。
求一键三连:点赞、转发、在看。
关注苏三的公众号:【苏三说技术】,在公众号中回复:面试、代码神器、开发手册、时间管理有超赞的粉丝福利,另外回复:加群,可以跟很多BAT大厂的前辈交流和学习。

1. 背景

根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。

此文是本qiang~
针对大语言模型的MoE的整理,包括原理、流程及部分源码

2. MoE原理

MoE的流行源于”欧洲的OpenAI” Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。

《Mixtral of Experts》基础模型使用的是Mistral AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。

本文以《Mixtral of Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:

图2.1 MoE的原理

(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate router)进行控制

(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出

(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。

(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。

3. 源码

本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…

为了普适性,下面的代码截取了transformers框架中的代码。

首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral, 主要流程是:

(1) 先经过gate_proj和up_proj的2个[hidden_size, intermediate_size]的线性转换

(2) 使用激活函数对gate_proj进行激活

(3) 二者的内积再经过down_proj线性转换。

1 classMistralMLP(nn.Module):2     def __init__(self, config):3         super().__init__()4         self.config =config5         self.hidden_size =config.hidden_size6         self.intermediate_size =config.intermediate_size7         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)8         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)9         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)10         self.act_fn =ACT2FN[config.hidden_act]11 
12     defforward(self, x):13         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:

(1) 首先经过网关路由self.gate

(2) 然后选择其中2个专家,并归一化

(3) 之后遍历每个专家网络,并按照expert_mask进行筛选

(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权

(5) 最后原地增加先前初始化的最终结果变量final_hidden_states

classMixtralSparseMoeBlock(nn.Module):def __init__(self, config):
super().
__init__()
self.hidden_dim
=config.hidden_size
self.ffn_dim
=config.intermediate_size
self.num_experts
=config.num_local_experts
self.top_k
=config.num_experts_per_tok#gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.experts
= nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ inrange(self.num_experts)])def forward(self, hidden_states: torch.Tensor) ->torch.Tensor:""" """batch_size, sequence_length, hidden_dim=hidden_states.shape
hidden_states
= hidden_states.view(-1, hidden_dim)#router_logits: (batch * sequence_length, n_experts) router_logits =self.gate(hidden_states)

routing_weights
= F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts
= torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights
/= routing_weights.sum(dim=-1, keepdim=True)#we cast back to the input dtype routing_weights =routing_weights.to(hidden_states.dtype)

final_hidden_states
=torch.zeros(
(batch_size
* sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
#One hot encode the selected experts to create an expert mask #this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)#Loop over all available experts in the model and perform the computation on each expert for expert_idx inrange(self.num_experts):
expert_layer
=self.experts[expert_idx]
idx, top_x
=torch.where(expert_mask[expert_idx])if top_x.shape[0] ==0:continue #in torch it is faster to index using lists than torch tensors top_x_list =top_x.tolist()
idx_list
=idx.tolist()#Index the correct hidden states and compute the expert hidden state for #the current expert. We need to make sure to multiply the output hidden #states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states
= expert_layer(current_state) *routing_weights[top_x_list, idx_list, None]#However `index_add_` only support torch tensors for indexing so we'll use #the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states
=final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)return final_hidden_states, router_logits

其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。

classMixtralBlockSparseTop2MLP(nn.Module):def __init__(self, config: MixtralConfig):
super().
__init__()
self.ffn_dim
=config.intermediate_size
self.hidden_dim
=config.hidden_size

self.w1
= nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2
= nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3
= nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

self.act_fn
=ACT2FN[config.hidden_act]defforward(self, hidden_states):
current_hidden_states
= self.act_fn(self.w1(hidden_states)) *self.w3(hidden_states)
current_hidden_states
=self.w2(current_hidden_states)return current_hidden_states

4. MoE微调

由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。

可以参考开源项目:https://github.com/yangjianxin1/Firefly

5. 答疑解惑

(1) 问:MoE 8*7B的模型是56B参数?

答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。

(2) 问:MoE的基础模型是Mistral 7B?

答:不是,MoE的模型架构与Mistral 7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。

(3) MoE的稀疏性(sparse)体现在哪里?

答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。

6. 总结

一句话足矣~

本文主要针对大语言模型的MoE,包括原理及部分源码。

此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。

7. 参考

(1) Mistral 7B:
https://arxiv.org/pdf/2310.06825v1.pdf

(2) MoE:
https://arxiv.org/pdf/2401.04088v1.pdf

(3) MoE开源指令微调框架Firefly:
https://github.com/yangjianxin1/Firefly

本文分享自华为云社区《
GaussDB(DWS)对象设计之序列SEQUENCE原理与使用方法介绍
》,作者:VV一笑。

1. 前言

  • 适用版本:8.2.1及以上版本

序列SEQUENCE用来生成唯一整数的数据库对象,本文对序列SEQUENCE的使用场景、使用方法及相关函数进行了介绍,并针对序列SEQUENCE在使用中容易遇到的问题和对应的解决方法进行了梳理总结。

2. SEQUENCE——自增整数序列

序列Sequence是用来产生唯一整数的数据库对象。序列的值是按照一定规则自增的整数。因为自增所以不重复,因此说Sequence具有唯一标识性。因此,在数据库中Sequence常常被作为主键使用。

3. 创建序列

通过序列使某字段成为唯一标识符的方法有两种:

  • 是声明字段的类型为序列整型,由数据库在后台自动创建一个对应的Sequence。
  • 使用CREATE SEQUENCE自定义一个新的Sequence,然后将nextval(‘sequence_name’)函数读取的序列值,指定为某一字段的默认值,这样该字段就可以作为唯一标识符。

方法一: 声明字段类型为序列整型来定义标识符字段。例如:

postgres=# CREATE TABLE T1
(
id serial,
name text
);

方法二: 创建序列,并通过nextval(‘sequence_name’)函数指定为某一字段的默认值。这种方式更灵活,可以为序列定义cache,一次预申请多个序列值,减少与GTM的交互次数,来提高性能。
1.创建序列

postgres=# CREATE SEQUENCE seq1 cache 100;

2.指定为某一字段的默认值,使该字段具有唯一标识属性。

postgres=# CREATE TABLE T2 
(
id
int not null default nextval('seq1'),
name text
);

【注意】

除了为序列指定了cache,方法二所实现的功能基本与方法一类似。但是一旦定义cache,序列将会产生空洞(序列值为不连贯的数值,如:1.4.5),并且不能保序。另外为某序列指定从属列后,该列删除,对应的sequence也会被删除。 虽然数据库并不限制序列只能为一列产生默认值,但最好不要多列共用同一个序列。
当前版本只支持在定义表的时候指定自增列,或者指定某列的默认值为nextval(‘seqname’), 不支持在已有表中增加自增列或者增加默认值为nextval(‘seqname’)的列。

3.1 CREATE SEQUENCE语句的使用方法

CREATE SEQUENCE用于向当前数据库里增加一个新的序列。序列的Owner为创建此序列的用户。

注意事项

  • Sequence是一个存放等差数列的特殊表,该表受DBMS控制。这个表没有实际意义,通常用于为行或者表生成唯一的标识符。
  • 如果给出一个模式名,则该序列就在给定的模式中创建,否则会在当前模式中创建。序列名必须和同一个模式中的其他序列、表、索引、视图或外表的名字不同。
  • 创建序列后,在表中使用序列的nextval()函数和generate_series(1,N)函数对表插入数据,请保证nextval的可调用次数大于等于N+1次,否则会因为generate_series()函数会调用N+1次而导致报错。
  • 不支持在template1数据库中创建SEQUENCE。

语法格式

CREATE SEQUENCE name [ INCREMENT [ BY ] increment ]
[ MINVALUE minvalue
| NO MINVALUE | NOMINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE |NOMAXVALUE]
[ START [ WITH ] start ] [ CACHE cache ] [ [ NO ] CYCLE
|NOCYCLE ]
[ OWNED BY { table_name.column_name
| NONE } ];

参数说明

  • name

将要创建的序列名称。

取值范围: 仅可以使用小写字母(a~z)、 大写字母(A~Z),数字和特殊字符"#","_","$"的组合。

  • increment

指定序列的步长。一个正数将生成一个递增的序列,一个负数将生成一个递减的序列。

缺省值为1。

  • MINVALUE minvalue | NO MINVALUE| NOMINVALUE

执行序列的最小值。如果没有声明minvalue或者声明了NO MINVALUE,则递增序列的缺省值为1,递减序列的缺省值为-263-1。NOMINVALUE等价于NO MINVALUE

  • MAXVALUE maxvalue | NO MAXVALUE| NOMAXVALUE

执行序列的最大值。如果没有声明maxvalue或者声明了NO MAXVALUE,则递增序列的缺省值为263-1,递减序列的缺省值为-1。NOMAXVALUE等价于NO MAXVALUE

start

指定序列的起始值。缺省值:对于递增序列为minvalue,递减序列为maxvalue。

cache

为了快速访问,而在内存中预先存储序列号的个数。一个缓存周期内,CN不再向GTM索取序列号,而是使用本地预先申请的序列号。

缺省值为1,表示一次只能生成一个值,也就是没有缓存。

【注意】

◾不建议同时定义cache和maxvalue或minvalue。因为定义cache后不能保证序列的连续性,可能会产生空洞,造成序列号段浪费。

◾建议cache值不要设置过大,否则会出现缓存序列号时(每个cache周期的第一个nextval)耗时过长的情况;同时建议cache值小于100000000。实际使用时应根据业务设置合理的cache值,既能保证快速访问,又不会浪费序列号。

CYCLE

用于使序列达到maxvalue或者minvalue后可循环并继续下去。

如果声明了NO CYCLE,则在序列达到其最大值后任何对nextval的调用都会返回一个错误。

NOCYCLE的作用等价于NO CYCLE。

缺省值为NO CYCLE。

若定义序列为CYCLE,则不能保证序列的唯一性。

OWNED BY-

将序列和一个表的指定字段进行关联。这样,在删除那个字段或其所在表的时候会自动删除已关联的序列。关联的表和序列的所有者必须是同一个用户,并且在同一个模式中。需要注意的是,通过指定OWNED BY,仅仅是建立了表的对应列和sequence之间关联关系,并不会在插入数据时在该列上产生自增序列。

缺省值为OWNED BY NONE,表示不存在这样的关联。

【注意】

◾通过OWNED BY创建的Sequence不建议用于其他表,如果希望多个表共享Sequence,该Sequence不应该从属于特定表。

示例

创建一个从101开始的递增序列,名为serial:

CREATE SEQUENCE serial
START
101CACHE20;

从序列中选出下一个数字:

SELECT nextval('serial');
nextval
--------- 101

从序列中选出下一个自增数字:

SELECT nextval('serial');
nextval
--------- 102

创建与表关联的序列:

CREATE TABLE customer_address
(
ca_address_sk integer not
null,
ca_address_id
char(16) not null,
ca_street_number
char(10) ,
ca_street_name varchar(
60) ,
ca_street_type
char(15) ,
ca_suite_number
char(10) ,
ca_city varchar(
60) ,
ca_county varchar(
30) ,
ca_state
char(2) ,
ca_zip
char(10) ,
ca_country varchar(
20) ,
ca_gmt_offset
decimal(5,2) ,
ca_location_type
char(20)
) ;

CREATE SEQUENCE serial1
START
101CACHE20OWNED BY customer_address.ca_address_sk;

使用serial创建主键自增序列表serial_table:

CREATE TABLE serial_table(a int, b serial);
INSERT INTO serial_table (a) VALUES (
1),(2),(3);
SELECT
*FROM serial_table ORDER BY b;
a
|b---+--- 1 | 1 2 | 2 3 | 3(3 rows)

4. 修改序列

ALTER SEQUENCE命令更改现有序列的属性,包括修改修改拥有者、归属列和最大值。

指定序列与列的归属关系:将序列和一个表的指定字段进行关联。在删除那个字段或其所在表的时候会自动删除已关联的序列。

postgres=# ALTER SEQUENCE seq1 OWNED BY T2.id;

将序列serial的最大值修改为300:

ALTER SEQUENCE seq1 MAXVALUE 300;

4.1 ALTER SEQUENCE语句的使用方法

ALTER SEQUENCE用于修改一个现有的序列的参数。

注意事项

  • 使用ALTER SEQUENCE的用户必须是该序列的所有者。
  • 当前版本仅支持修改拥有者、归属列和最大值。若要修改其他参数,可以删除重建,并用Setval函数恢复当前值。
  • ALTER SEQUENCE MAXVALUE不支持在事务、函数和存储过程中使用。
  • 修改序列的最大值后,会清空该序列在所有会话的cache。
  • ALTER SEQUENCE会阻塞nextval、setval、currval和lastval的调用。

语法格式

修改序列最大值或归属列

ALTER SEQUENCE [ IF EXISTS ] name 
[ MAXVALUE maxvalue
| NO MAXVALUE |NOMAXVALUE ]
[ OWNED BY { table_name.column_name
| NONE } ] ;

修改序列的拥有者

ALTER SEQUENCE [ IF EXISTS ] name OWNER TO new_owner;

参数说明

  • name

将要修改的序列名称。

  • IF EXISTS

当序列不存在时使用该选项不会出现错误消息,仅有一个通知。

  • MAXVALUE maxvalue | NO MAXVALUE

序列所能达到的最大值。如果声明了NO MAXVALUE,则递增序列的缺省值为263-1,递减序列的缺省值为-1。NOMAXVALUE等价于NO MAXVALUE。

  • OWNED BY

将序列和一个表的指定字段进行关联。这样,在删除那个字段或其所在表的时候会自动删除已关联的序列。

如果序列已经和表有关联后,使用这个选项后新的关联关系会覆盖旧的关联。

关联的表和序列的所有者必须是同一个用户,并且在同一个模式中。

使用OWNED BY NONE将删除任何已经存在的关联。

new_owner

序列新所有者的用户名。用户要修改序列的所有者,必须是新角色的直接或者间接成员,并且那个角色必须有序列所在模式上的CREATE权限。

示例

将序列serial的最大值修改为200:

ALTER SEQUENCE serial MAXVALUE 200;

创建一个表,定义默认值:

CREATE TABLE T1(C1 bigint default nextval('serial'));

将序列serial的归属列变为T1.C1:

ALTER SEQUENCE serial OWNED BY T1.C1;

5. 删除序列

使用DROP SEQUENCE命令删除一个序列。 例如,将删除名为seq1的序列:

DROP SEQUENCE seq1;

5.1 DROP SEQUENCE语句的使用方法

DROP SEQUENCE用于从当前数据库里删除序列。

注意事项

只有序列的所有者或者系统管理员才能删除。

语法格式

DROP SEQUENCE [ IF EXISTS ] {[schema.]sequence_name} [ , ... ] [ CASCADE | RESTRICT ];

参数说明

  • IF EXISTS

如果指定的序列不存在,则发出一个notice而不是抛出一个错误。

  • name

序列名称。

  • CASCADE

级联删除依赖序列的对象。

  • RESTRICT

如果存在任何依赖的对象,则拒绝删除序列。此项是缺省值。

6. SEQUENCE相关函数

序列函数为用户从序列对象中获取后续的序列值提供了简单的多用户安全的方法。DWS目前支持以下SEQUENCE函数:

6.1 nextval(regclass)

nextval(regclass)用于递增序列并返回新值。
返回类型:bigint
nextval函数有两种调用方式(其中第二种调用方式兼容Oracle的语法,目前不支持Sequence命名中有特殊字符"."的情况),调用方式如下:

示例1:

postgres=# SELECT nextval('seqDemo'); 
nextval
--------- 2(1 row)

示例2:

postgres=# SELECT seqDemo.nextval; 
nextval
--------- 2(1 row)

注意事项

为了避免从同一个序列获取值的并发事务被阻塞, nextval操作不会回滚;也就是说,一旦一个值已经被抓取, 那么就认为它已经被用过了,并且不会再被返回。 即使该操作处于事务中,当事务之后中断,或者如果调用查询结束不使用该值,也是如此。这种情况将在指定值的顺序中留下未使用的"空洞"。 因此,GaussDB(DWS)序列对象不能用于获得"无间隙"序列。

如果nextval被下推到DN上时,各个DN会自动连接GTM,请求next values值,例如(insert into t1 select xxx,t1某一列需要调用nextval函数),由于GTM上有最大连接数为8192的限制,而这类下推语句会导致消耗过多的GTM连接数,因此对于这类语句的并发数目限制为7000(其它语句需要占用部分连接)/集群DN数目。

6.2 currval(regclass)

currval(regclass)用于返回当前会话里最近一次nextval返回的指定的sequence的数值。如果当前会话还没有调用过指定的sequence的nextval,那么调用currval将会报错。需要注意的是,这个函数在默认情况下是不支持的,需要通过设置enable_beta_features为true之后,才能使用这个函数。同时在设置enable_beta_features为true之后,nextval()函数将不支持下推。
返回类型:bigint
currval函数有两种调用方式(其中第二种调用方式兼容Oracle的语法,目前不支持Sequence命名中有特殊字符"."的情况),调用方式如下:

示例1:

postgres=# SELECT currval('seq1'); 
currval
--------- 2(1 row)

示例2:

postgres=# SELECT seq1.currval seq1; 
currval
--------- 2(1 row)

6.3 lastval()

lastval()用于返回当前会话里最近一次nextval返回的数值。这个函数等效于currval,只是它不用序列名为参数,它抓取当前会话里面最近一次nextval使用的序列。如果当前会话还没有调用过nextval,那么调用lastval将会报错。

需要注意的是,lastval()函数在默认情况下是不支持的,需要通过设置enable_beta_features或者lastval_supported为true之后,才能使用这个函数。同时这种情况下,nextval()函数将不支持下推。

返回类型:bigint

示例:

postgres=# SELECT lastval(); 
lastval
--------- 2(1 row)

6.4 setval(regclass, bigint)

setval(regclass, bigint)用于设置序列的当前数值。

返回类型:bigint

示例:

postgres=# SELECT setval('seqDemo',1);
setval
-------- 1(1 row)

6.5 setval(regclass, bigint, boolean)

setval(regclass, bigint, boolean)用于设置序列的当前数值以及is_called标志。

返回类型:bigint

示例:

postgres=# SELECT setval('seqDemo',1,true);
setval
-------- 1(1 row)

注意事项

Setval后当前会话及GTM上会立刻生效,但如果其他会话有缓存的序列值,只能等到缓存值用尽才能感知Setval的作用。所以为了避免序列值冲突,setval要谨慎使用。因为序列是非事务的,setval造成的改变不会由于事务的回滚而撤销。

7. 注意事项

新序列值的产生是靠GTM维护的,默认情况下,每申请一个序列值都要向GTM发送一次申请,GTM在当前值的基础上加上步长值作为产生的新值返回给调用者。GTM作为全局唯一的节点,势必成为性能的瓶颈,所以对于需要大量频繁产生序列号的操作,如使用Bulkload工具进行数据导入场景,是非常不推荐产生默认序列值的。比如,在下面所示的场景中, INSERT FROM SELECT语句的性能会非常慢。

CREATE SEQUENCE newSeq1;
CREATE TABLE newT1
(
id
int not null default nextval('newSeq1'),
name text
);
INSERT INTO newT1(name) SELECT name
from T1;

可以提高性能的写法是(假设T1表导入newT1表中的数据为10000行):

INSERT INTO newT1(id, name) SELECT id,name fromT1;
SELECT SETVAL(
'newSeq1',10000);

序列操作函数nextval(),setval() 等均不支持回滚。另外setval设置的新值,会对当前会话的nextval立即生效,但对其他会话,如果定义了cache,不会立即生效,在用尽所有缓存的值后,其变动才被其他会话感知。所以为了避免产生重复值,要谨慎使用setval,设置的新值不能是已经产生的值或者在缓存中的值。

如果必须要在bulkload场景下产生默认序列值,则一定要为newSeq1定义足够大的cache,并且不要定义Maxvalue或者Minvalue。数据库会试图将nextval(‘sequence_name’)的调用下推到Data Node,以提高性能。 目前GTM对并发的连接请求是有限制的,当Data Node很多时,将产生大量并发连接, 这时一定要控制bulkload的并发数目,避免耗尽GTM的连接资源。如果目标表为复制表(DISTRIBUTE BY REPLICATION)时下推将不能进行。当数据量较大时,这对数据库将是个灾难。除了性能问题之外,空间也可能会剧烈膨胀,在导入结束后,需要用vacuum full来恢复。最好的方式还是如上建议的,不要在bulkload的场景中产生默认序列值。

另外,序列创建后,在每个节点上都维护了一张单行表,存储序列的定义及当前值,但此当前值并非GTM上的当前值,只是保存本节点与GTM交互后的状态。如果其他节点也向GTM申请了新值,或者调用了Setval修改了序列的状态,不会刷新本节点的单行表,但因每次申请序列值是向GTM申请,所以对序列正确性没有影响。

8. 使用案例

DWS如何重置自增列的开始序号?

使用函数setval(regclass, bigint)对自增列值进行重置。

示例:

将seqDemo列的开始序号重置为1:

postgres
=# SELECT setval('seqDemo',1);
setval
-------- 1(1 row)

DWS如何确定sequence和哪个表有关联?

先在pg_class查找目标sequence的oid,然后在pg_depend根据oid查依赖该sequence的对象

示例:

先创建自增序列seq1和依赖seq1的表T2:

postgres
=# CREATE SEQUENCE seq1 cache 100;

postgres
=# CREATE TABLE T2
postgres
-# (
postgres(# id
int not null default nextval('seq1'),
postgres(# name text
postgres(# );
根据seq1从表pg_class、pg_depend联合查询到依赖表T2的oid:
postgres=# select * from pg_depend where objid = (select oid from pg_class where relname = 'seq1')
classid
| objid | objsubid | refclassid | refobjid | refobjsubid |deptype---------+------------+----------+------------+----------+-------------+--------- 1259 | 2147485853 | 0 | 2615 | 2200 | 0 |n
(
1 row)

如何查询序列的last_value?

由于SEQUENCE在自增过程中并不是严格逐个增加,因此序列号中会存在空端数据,所以last_value本身并没有实际意义,可以采用函数lastval()进行查询。

示例:

postgres=# SELECT lastval(); 
lastval
--------- 2(1 row)

注意事项

如果当前会话还没有调用过nextval,那么调用lastval将会报错。此外,lastval()函数在默认情况下是不支持的,需要通过设置enable_beta_features或者lastval_supported为true之后,才能使用这个函数。同时这种情况下,nextval()函数将不支持下推。

如何查询SEQUENC的当前最新值?

通过currval函数可以查询SEQUENC的当前最新值。

示例:

currval函数有两种调用方式(其中第二种调用方式兼容Oracle的语法,目前不支持Sequence命名中有特殊字符"."的情况),调用方式如下:
示例1:
postgres
=# SELECT currval('seq1');
currval
--------- 2(1 row)
示例2:
postgres=# SELECT seq1.currval seq1; 
currval
--------- 2(1 row)

如何解决SEQUENC取值超出范围的问题?

1.可以在创建SEQUENC时设置CYCLE字段,从而使得序列达到maxvalue或者minvalue后可循环并继续下去。但需要注意,若定义序列为CYCLE,则不能保证序列的唯一性。

2.通过调用setval(regclass, bigint)函数对序列取值进行重置。

9. 总结

本文介绍了SEQUENCE的使用场景和相关的函数的使用方法,并对使用SEQUENCE过程中遇到的常见问题及解决方法进行了汇总。

点击关注,第一时间了解华为云新鲜技术~

前言

本文将探讨如何利用WPF框架实现树形表格控件,该控件不仅能够有效地展示复杂的层级数据,还能够提供丰富的个性化定制选项。我们将介绍如何使用WPF提供的控件、模板、布局、数据绑定等技术来构建这样一个树形表格。

一、运行效果

1.1默认样式

1.2 自定义样式

二、代码实现

2.1 创建自定义控件(TreeListView)

新建一个继承自TreeView的控件,并定义一个类型为ViewBase的View依赖属性,用于在代码中指定列。

public classTreeListView : TreeView
{
publicViewBase View
{
get { return(ViewBase)GetValue(ViewProperty); }set{ SetValue(ViewProperty, value); }
}
public static readonly DependencyProperty ViewProperty =DependencyProperty.Register("View", typeof(ViewBase), typeof(TreeListView));
}

2.2 在TreeListView控件模板中处理列头

为了在TreeListView中显示列头,需要在合适的位置添加GridViewHeaderRowPresenter控件,并在Columns属性上绑定我们之前定义的View.Columns属性。下面我们首先来分析TreeView控件模板的代码。

<ControlTemplateTargetType="{x:Type TreeView}">
    <Borderx:Name="Bd"BorderBrush="{TemplateBinding BorderBrush}"BorderThickness="{TemplateBinding BorderThickness}"SnapsToDevicePixels="true">
        <ScrollViewerx:Name="_tv_scrollviewer_"Background="{TemplateBinding Background}"CanContentScroll="false"Focusable="false"HorizontalScrollBarVisibility="{TemplateBinding ScrollViewer.HorizontalScrollBarVisibility}"Padding="{TemplateBinding Padding}"SnapsToDevicePixels="{TemplateBinding SnapsToDevicePixels}"VerticalScrollBarVisibility="{TemplateBinding ScrollViewer.VerticalScrollBarVisibility}">
            <ItemsPresenter/>
        </ScrollViewer>
    </Border>
    <ControlTemplate.Triggers>
        <TriggerProperty="IsEnabled"Value="false">
            <SetterProperty="Background"TargetName="Bd"Value="{DynamicResource {x:Static SystemColors.ControlBrushKey}}"/>
        </Trigger>
        <TriggerProperty="VirtualizingPanel.IsVirtualizing"Value="true">
            <SetterProperty="CanContentScroll"TargetName="_tv_scrollviewer_"Value="true"/>
        </Trigger>
    </ControlTemplate.Triggers>
</ControlTemplate>

通过以上代码我们可以看出,只要将GridViewHeaderRowPresenter控件添加到ScrollViewer控件上面即可实现列头功能,但这样会有一个问题,那就是内容宽度超出控件宽度后,鼠标拖动横向滚动条时列头不会跟随下方的数据列表一起滚动。为解决这个问题我们需要将GridViewHeaderRowPresenter放置到ScrollViewer控件模板中,以下为完整代码。

<Stylex:Key="{x:Static GridView.GridViewScrollViewerStyleKey}"TargetType="{x:Type ScrollViewer}">
    <SetterProperty="Focusable"Value="false" />
    <SetterProperty="Template">
        <Setter.Value>
            <ControlTemplateTargetType="{x:Type ScrollViewer}">
                <GridBackground="{TemplateBinding Background}"SnapsToDevicePixels="true">
                    <Grid.ColumnDefinitions>
                        <ColumnDefinitionWidth="*" />
                        <ColumnDefinitionWidth="Auto" />
                    </Grid.ColumnDefinitions>
                    <Grid.RowDefinitions>
                        <RowDefinitionHeight="*" />
                        <RowDefinitionHeight="Auto" />
                    </Grid.RowDefinitions>

                    <DockPanelMargin="{TemplateBinding Padding}">
                        <ScrollViewerDockPanel.Dock="Top"Focusable="false"HorizontalScrollBarVisibility="Hidden"VerticalScrollBarVisibility="Hidden">
                            <GridViewHeaderRowPresenterMargin="2,0,2,0"AllowsColumnReorder="{Binding TemplatedParent.View.AllowsColumnReorder, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderContainerStyle="{Binding TemplatedParent.View.ColumnHeaderContainerStyle, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderContextMenu="{Binding TemplatedParent.View.ColumnHeaderContextMenu, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderStringFormat="{Binding TemplatedParent.View.ColumnHeaderStringFormat, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderTemplate="{Binding TemplatedParent.View.ColumnHeaderTemplate, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderTemplateSelector="{Binding TemplatedParent.View.ColumnHeaderTemplateSelector, RelativeSource={RelativeSource TemplatedParent}}"ColumnHeaderToolTip="{Binding TemplatedParent.View.ColumnHeaderToolTip, RelativeSource={RelativeSource TemplatedParent}}"Columns="{Binding TemplatedParent.View.Columns, RelativeSource={RelativeSource TemplatedParent}}"SnapsToDevicePixels="{TemplateBinding SnapsToDevicePixels}" />
                        </ScrollViewer>
                        <ScrollContentPresenterx:Name="PART_ScrollContentPresenter"CanContentScroll="{TemplateBinding CanContentScroll}"Content="{TemplateBinding Content}"ContentTemplate="{TemplateBinding ContentTemplate}"KeyboardNavigation.DirectionalNavigation="Local"SnapsToDevicePixels="{TemplateBinding SnapsToDevicePixels}" />
                    </DockPanel>

                    <ScrollBarx:Name="PART_HorizontalScrollBar"Grid.Row="1"Cursor="Arrow"Maximum="{TemplateBinding ScrollableWidth}"Minimum="0.0"Orientation="Horizontal"ViewportSize="{TemplateBinding ViewportWidth}"Visibility="{TemplateBinding ComputedHorizontalScrollBarVisibility}"Value="{Binding HorizontalOffset, Mode=OneWay, RelativeSource={RelativeSource TemplatedParent}}" />
                    <ScrollBarx:Name="PART_VerticalScrollBar"Grid.Column="1"Cursor="Arrow"Maximum="{TemplateBinding ScrollableHeight}"Minimum="0.0"Orientation="Vertical"ViewportSize="{TemplateBinding ViewportHeight}"Visibility="{TemplateBinding ComputedVerticalScrollBarVisibility}"Value="{Binding VerticalOffset, Mode=OneWay, RelativeSource={RelativeSource TemplatedParent}}" />
                    <DockPanelGrid.Row="1"Grid.Column="1"Background="{Binding Background, ElementName=PART_VerticalScrollBar}"LastChildFill="false">
                        <RectangleWidth="1"DockPanel.Dock="Left"Fill="White"Visibility="{TemplateBinding ComputedVerticalScrollBarVisibility}" />
                        <RectangleHeight="1"DockPanel.Dock="Top"Fill="White"Visibility="{TemplateBinding ComputedHorizontalScrollBarVisibility}" />
                    </DockPanel>
                </Grid>
            </ControlTemplate>
        </Setter.Value>
    </Setter>
</Style>

<StyleTargetType="{x:Type local:TreeListView}">
    <SetterProperty="ScrollViewer.HorizontalScrollBarVisibility"Value="Auto" />
    <SetterProperty="ScrollViewer.VerticalScrollBarVisibility"Value="Auto" />
    <SetterProperty="ScrollViewer.CanContentScroll"Value="true" />
    <SetterProperty="Template">
        <Setter.Value>
            <ControlTemplateTargetType="{x:Type local:TreeListView}">
                <BorderBorderBrush="{TemplateBinding BorderBrush}"BorderThickness="{TemplateBinding BorderThickness}">
                    <ScrollViewerPadding="{TemplateBinding Padding}"Style="{StaticResource {x:Static GridView.GridViewScrollViewerStyleKey}}">
                        <ItemsPresenter/>
                    </ScrollViewer>
                </Border>
            </ControlTemplate>
        </Setter.Value>
    </Setter>
</Style>

2.3 在TreeListViewItem模板中处理子项的展开和收缩

新建一个继承自TreeViewItem的类,命名为TreeListViewItem(如有个性化需求,可以在该类中处理),编辑控件模板,在模板中添加以下代码。

<StyleTargetType="{x:Type local:TreeListViewItem}">
    <SetterProperty="BorderThickness"Value="1" />
    <SetterProperty="Template">
        <Setter.Value>
            <ControlTemplateTargetType="{x:Type local:TreeListViewItem}">
                <StackPanel>
                    <BorderName="Bd"Padding="{TemplateBinding Padding}"Background="{TemplateBinding Background}"BorderBrush="{TemplateBinding BorderBrush}"BorderThickness="{TemplateBinding BorderThickness}">
                        <GridViewRowPresenterx:Name="PART_Header"Columns="{Binding RelativeSource={RelativeSource AncestorType=local:TreeListView}, Path=View.Columns}"Content="{TemplateBinding Header}" />
                    </Border>
                    <ItemsPresenterx:Name="ItemsHost" />
                </StackPanel>
                <ControlTemplate.Triggers>
                    <TriggerProperty="IsExpanded"Value="false">
                        <SetterTargetName="ItemsHost"Property="Visibility"Value="Collapsed" />
                    </Trigger>
                    <MultiTrigger>
                        <MultiTrigger.Conditions>
                            <ConditionProperty="HasHeader"Value="false" />
                            <ConditionProperty="Width"Value="Auto" />
                        </MultiTrigger.Conditions>
                        <SetterTargetName="PART_Header"Property="MinWidth"Value="75" />
                    </MultiTrigger>
                    <MultiTrigger>
                        <MultiTrigger.Conditions>
                            <ConditionProperty="HasHeader"Value="false" />
                            <ConditionProperty="Height"Value="Auto" />
                        </MultiTrigger.Conditions>
                        <SetterTargetName="PART_Header"Property="MinHeight"Value="19" />
                    </MultiTrigger>

                    <MultiTrigger>
                        <MultiTrigger.Conditions>
                            <ConditionProperty="extensions:TreeViewItemExtensions.IsMouseDirectlyOverItem"Value="True" />
                        </MultiTrigger.Conditions>
                        <SetterTargetName="Bd"Property="Background"Value="{StaticResource Item.MouseOver.Background}" />
                        <SetterTargetName="Bd"Property="BorderBrush"Value="{StaticResource Item.MouseOver.Border}" />
                    </MultiTrigger>
                    <MultiTrigger>
                        <MultiTrigger.Conditions>
                            <ConditionProperty="Selector.IsSelectionActive"Value="False" />
                            <ConditionProperty="IsSelected"Value="True" />
                        </MultiTrigger.Conditions>
                        <SetterTargetName="Bd"Property="Background"Value="{StaticResource Item.SelectedInactive.Background}" />
                        <SetterTargetName="Bd"Property="BorderBrush"Value="{StaticResource Item.SelectedInactive.Border}" />
                    </MultiTrigger>
                    <MultiTrigger>
                        <MultiTrigger.Conditions>
                            <ConditionProperty="Selector.IsSelectionActive"Value="True" />
                            <ConditionProperty="IsSelected"Value="True" />
                        </MultiTrigger.Conditions>
                        <SetterTargetName="Bd"Property="Background"Value="{StaticResource Item.SelectedActive.Background}" />
                        <SetterTargetName="Bd"Property="BorderBrush"Value="{StaticResource Item.SelectedActive.Border}" />
                    </MultiTrigger>
                    <TriggerProperty="IsEnabled"Value="False">
                        <SetterProperty="Foreground"Value="{DynamicResource {x:Static SystemColors.GrayTextBrushKey}}" />
                    </Trigger>
                </ControlTemplate.Triggers>
            </ControlTemplate>
        </Setter.Value>
    </Setter>
</Style>

此处的核心在于模板中添加了GridViewRowPresenter控件,并在Columns属性上绑定了我们之前定义的View.Columns属性,这样就可以在每一行上面显示列数据。还有一个关键点是ItemsPresenter,它用于显示子项数据,此处命名为ItemsHost,它由属性触发器中的代码来控件展开和收起。以下是属性触发器代码。

<TriggerProperty="IsExpanded"Value="false">
    <SetterTargetName="ItemsHost"Property="Visibility"Value="Collapsed" />
</Trigger>

2.4 在单元格模板中控件子项的展开与收起

为了达到展开和收起的效果,需要在首列的单元格中控制TreeListViewItem的IsExpanded属性。以下为完整代码。

<DataTemplatex:Key="ExpandCellTemplate">
    <DockPanel>
        <ToggleButtonx:Name="Expander"Margin="{Binding Path=Level, Converter={StaticResource LevelIndentConverter}, RelativeSource={RelativeSource AncestorType={x:Type TreeListViewItem}}}"ClickMode="Press"IsChecked="{Binding Path=IsExpanded, RelativeSource={RelativeSource AncestorType={x:Type TreeListViewItem}}}"Style="{StaticResource ExpandCollapseToggleStyle}" />
        <TextBlockText="{Binding Property1}" />
    </DockPanel>
    <DataTemplate.Triggers>
        <DataTriggerBinding="{Binding Path=HasItems, RelativeSource={RelativeSource AncestorType={x:Type TreeListViewItem}}}"Value="False">
            <SetterTargetName="Expander"Property="Visibility"Value="Hidden" />
        </DataTrigger>
    </DataTemplate.Triggers>
</DataTemplate>

其关键代码为

IsChecked="{Binding Path=IsExpanded, RelativeSource={RelativeSource AncestorType={x:Type TreeListViewItem}}}"

2.5 控件使用

<TreeListViewItemsSource="{Binding Collection}">
    <TreeListView.ItemTemplate>
        <HierarchicalDataTemplateItemsSource="{Binding Collection, IsAsync=True}" />
    </TreeListView.ItemTemplate>
    <TreeListView.View>
        <GridView>
            <GridViewColumnCellTemplate="{StaticResource ExpandCellTemplate}"Header="Property1" />
            <GridViewColumnDisplayMemberBinding="{Binding Property2}"Header="Property2" />
            <GridViewColumnDisplayMemberBinding="{Binding Property3}"Header="Property3" />
            <GridViewColumnDisplayMemberBinding="{Binding Property4}"Header="Property4" />
            <GridViewColumnDisplayMemberBinding="{Binding Property5}"Header="Property5" />
            <GridViewColumnDisplayMemberBinding="{Binding Property6}"Header="Property6" />
            <GridViewColumnDisplayMemberBinding="{Binding Property7}"Header="Property7" />
            <GridViewColumnDisplayMemberBinding="{Binding Property8}"Header="Property8" />
            <GridViewColumnDisplayMemberBinding="{Binding Property9}"Header="Property9" />
            <GridViewColumnDisplayMemberBinding="{Binding Property10}"Header="Property10" />
            <GridViewColumnDisplayMemberBinding="{Binding Property11}"Header="Property11" />
            <GridViewColumnDisplayMemberBinding="{Binding Property12}"Header="Property12" />
        </GridView>
    </TreeListView.View>
</TreeListView>

前面学习了一些Source Generators的基础只是,接下来就来实践一下,用这个来生成我们所需要的代码。
本文将通过读取swagger.json的内容,解析并生成对应的请求响应类的代码。

创建项目

首先还是先创建两个项目,一个控制台程序,一个类库。
image.png

添加swagger文件

在控制台程序中添加Files目录,并把swagger文件放进去。别忘了还需要添加AdditionalFiles。

<ItemGroup>
  <AdditionalFiles Include="Files\swagger.json" />
</ItemGroup>

image.png

实现ClassFromSwaggerGenerator

安装依赖

由于我们需要解析swagger,所以需要安装一下JSON相关的包。这里我们安装了Newtonsoft.Json。
需要注意的是,依赖第三方包的时候需要在项目文件添加下面内容:

<PropertyGroup>
  <GetTargetPathDependsOn>$(GetTargetPathDependsOn);GetDependencyTargetPaths</GetTargetPathDependsOn>
</PropertyGroup>
<Target Name="GetDependencyTargetPaths" AfterTargets="ResolvePackageDependenciesForBuild">
  <ItemGroup>
    <TargetPathWithTargetPlatformMoniker Include="@(ResolvedCompileFileDefinitions)" IncludeRuntimeDependency="false" />
  </ItemGroup>
</Target>

否则编译时会出现FileNotFound的异常。

构建管道

这里我们通过AdditionalTextsProvider筛选以及过滤我们的swagger文件。

var pipeline = context.AdditionalTextsProvider.Select(static (text, cancellationToken) =>
  {
      if (!text.Path.EndsWith("swagger.json", StringComparison.OrdinalIgnoreCase))
      {
          return default;
      }

      return JObject.Parse(text.GetText(cancellationToken)!.ToString());
  })
    .Where((pair) => pair is not null);

实现生成代码逻辑

接下来我们就解析Swagger中的内容,并且动态拼接代码内容。主要代码部分如下:

context.RegisterSourceOutput(pipeline, static (context, swagger) =>
 {

     List<(string name, string sourceString)> sources = new List<(string name, string sourceString)>();


     #region 生成实体
     var schemas = (JObject)swagger["components"]!["schemas"]!;
     foreach (JProperty item in schemas.Properties())
     {
         if (item != null)
         {
             sources.Add((HandleClassName(item.Name), $@"#nullable enable
using System;
using System.Collections.Generic;

namespace SwaggerEntities;
public {ClassOrEnum((JObject)item.Value)} {HandleClassName(item.Name)} 
{{
    {BuildProperty((JObject)item.Value)}
}}
"));
         }
     }
     foreach (var (name, sourceString) in sources)
     {
         var sourceText = SourceText.From(sourceString, Encoding.UTF8);

         context.AddSource($"{name}.g.cs", sourceText);
     }
     #endregion
     });

完整的代码如下:

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;

namespace GenerateClassFromSwagger.Analysis
{
    [Generator]
    public class ClassFromSwaggerGenerator : IIncrementalGenerator
    {
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            var pipeline = context.AdditionalTextsProvider.Select(static (text, cancellationToken) =>
            {
                if (!text.Path.EndsWith("swagger.json", StringComparison.OrdinalIgnoreCase))
                {
                    return default;
                }

                return JObject.Parse(text.GetText(cancellationToken)!.ToString());
            })
            .Where((pair) => pair is not null);

            context.RegisterSourceOutput(pipeline, static (context, swagger) =>
            {

                List<(string name, string sourceString)> sources = new List<(string name, string sourceString)>();


                #region 生成实体
                var schemas = (JObject)swagger["components"]!["schemas"]!;
                foreach (JProperty item in schemas.Properties())
                {
                    if (item != null)
                    {
                        sources.Add((HandleClassName(item.Name), $@"#nullable enable
using System;
using System.Collections.Generic;

namespace SwaggerEntities;
public {ClassOrEnum((JObject)item.Value)} {HandleClassName(item.Name)} 
{{
    {BuildProperty((JObject)item.Value)}
}}
                "));
                    }
                }
                foreach (var (name, sourceString) in sources)
                {
                    var sourceText = SourceText.From(sourceString, Encoding.UTF8);

                    context.AddSource($"{name}.g.cs", sourceText);
                }
                #endregion
            });
        }

        static string HandleClassName(string name)
        {
            return name.Split('.').Last().Replace("<", "").Replace(">", "").Replace(",", "");
        }
        static string ClassOrEnum(JObject value)
        {
            return value.ContainsKey("enum") ? "enum" : "partial class";
        }


        static string BuildProperty(JObject value)
        {
            var sb = new StringBuilder();
            if (value.ContainsKey("properties"))
            {
                var propertys = (JObject)value["properties"]!;
                foreach (JProperty item in propertys!.Properties())
                {
                    sb.AppendLine($@"
    public {BuildProertyType((JObject)item.Value)} {ToUpperFirst(item.Name)}  {{ get; set; }}
");
                }
            }
            if (value.ContainsKey("enum"))
            {
                foreach (var item in JsonConvert.DeserializeObject<List<int>>(value["enum"]!.ToString())!)
                {
                    sb.Append($@"
    _{item},
");
                }
                sb.Remove(sb.Length - 1, 1);
            }
            return sb.ToString();
        }

        static string BuildProertyType(JObject value)
        {
            ;
            var type = GetType(value);
            var nullable = value.ContainsKey("nullable") ? value["nullable"]!.Value<bool?>() switch
            {
                true => "?",
                false => "",
                _ => ""
            } : "";
            return type + nullable;
        }

        static string GetType(JObject value)
        {
            return value.ContainsKey("type") ? value["type"]!.Value<string>() switch
            {
                "string" => "string",
                "boolean" => "bool",
                "number" => value["format"]!.Value<string>() == "float" ? "float" : "double",
                "integer" => value["format"]!.Value<string>() == "int32" ? "int" : "long",
                "array" => ((JObject)value["items"]!).ContainsKey("items") ?
                $"List<{HandleClassName(value["items"]!["$ref"]!.Value<string>()!)}>"
                : $"List<{GetType((JObject)value["items"]!)}>",
                "object" => value.ContainsKey("additionalProperties") ? $"Dictionary<string, {GetType((JObject)value["additionalProperties"]!)}>" : "object",
                _ => "object"
            } : value.ContainsKey("$ref") ? HandleClassName(value["$ref"]!.Value<string>()!) : "object";
        }

        static unsafe string ToUpperFirst(string str)
        {
            if (str == null) return null;
            string ret = string.Copy(str);
            fixed (char* ptr = ret)
                *ptr = char.ToUpper(*ptr);
            return ret;
        }
    }
}

详细的处理过程大家可以仔细看看代码,这里就不一一解释了。

启动编译

接下来编译控制台程序。编译成功后可以看到生成了很多cs的代码。若是看不见,可以重启VS。
image.png
点开一个文件,可以看到内容,并且在上方提示自动生成,无法编辑。
image.png
到这我们就完成了通过swagger来生成我们的请求和响应类的功能。

结语

本文章应用SourceGenerator,在编译时读取swagger.json的内容并解析,成功生成了我们API的请求和响应类的代码。
我们可以发现,代码生成没有问题,无法移动或者编辑生成的代码。
下一篇文章我们就来学习下如何输出SourceGenerator生成的代码文件到我们的文件目录。

本文代码仓库地址https://github.com/fanslead/Learn-SourceGenerator