2024年12月

做全文搜索,es比较好用,安装可能有点费时费力。mysql安装就不说了。主要是elastic8.4.0+kibana8.4.0+logstash-8.16.1,可视化操作及少量netcore查询代码。

安装elastic8.4.0+kibana8.4.0使用docker-desktop,logstash-8.16.1是线程解压执行文件。

  • 1.
    docker-compose.yml 如下: 首先使用docker network创建一个es-net内部通讯网络,这样kibana连接es可以通过容器名ELASTICSEARCH_HOSTS=http://elasticsearch:9200,此作为单机测试使用单机的es.
services:

elasticsearch:
container_name: elasticsearch
image: docker.elastic.co
/elasticsearch/elasticsearch:8.4.0environment:- discovery.type=single-node
ulimits:
memlock:
soft:
-1hard:-1cap_add:-IPC_LOCK
ports:
- "9200:9200"networks:- es-net

kibana:
container_name: kibana
image: docker.elastic.co
/kibana/kibana:8.4.0environment:- ELASTICSEARCH_HOSTS=http://elasticsearch:9200 ports:- "5601:5601"networks:- es-net

networks:
es
-net:
driver: bridge

作为es的8以上版本是有账号密码和crt证书的,需要做如下配置:

安装好es后默认给一个elastic账号,需要重置一下密码,进入es容器执行重置密码命令,会给你一个密码。

docker exec  -it -u root elasticsearch /bin/bash
bin
/elasticsearch-reset-password -u elastic

这里登录的其实是https带证书的,但是kibana使用的是http的,所以在容器内部,config/elasticsearch.yml中需要把下面的两个参数置为false ,生产环境不建议这么操作。

因为es带账号密码,所以kibana连接es也需要账号密码信息,但是默认的elastic是超级管理员,kibana默认是不支持的,需要自己新建账号。但是es默认是给了账号的,用他的就行。自己新建es账号给一个超级管理员角色依然没有重建所应权限,导致kibana起不来,用kibana_system就行。

进入es容器内部给kibana_system重置一个密码,用下面的命令在内部调用也行,我设置的elastic和kibana_system的密码一样,方便使用。

curl -u elastic:DiVnR2F6OGYmP+Ms+n2o -X POST "http://localhost:9200/_security/user/kibana_system/_password" -H 'Content-Type: application/json' -d'{"password": "DiVnR2F6OGYmP+Ms+n2o"}'

  • 2.
    然后在kibana容器中,加上账号密码信息即可,重启。还有最后一行加上i18n.locale: zh-CN  ,改变ui为中文。

然后通过开发工具就可以做es的调试了,这里注意下需要中文分词的可以去 https://github.com/infinilabs/analysis-ik/releases 下载对应版本8.4.0的中文分词器 ,改个名放到es容器内plugins中去。也可以自定义分词文件丢进去

  • 3. 下面就是logstash安装跟mysql的同步了,测试数据如下:

首先去logstash官网下载对应的包,我选的版本是8.16.1,目录如下是可以通过控制台执行的。

这里只需要配置好mysql-connector的驱动和链接信息即可。

jdbc.conf文件内容如下:

input {
stdin {}
jdbc {
type
=> "jdbc"# 数据库连接地址
jdbc_connection_string
=> "jdbc:mysql://192.168.200.2:3306/bbs?characterEncoding=UTF-8&autoReconnect=true"# 数据库连接账号密码;
jdbc_user
=> "admin"jdbc_password=> "这是密码"# MySQL依赖包路径;
jdbc_driver_library
=> "D:\software\logstash-8.16.1\mysql\mysql-connector-j-8.0.32.jar"# the name of the driverclass formysql
jdbc_driver_class
=> "com.mysql.jdbc.Driver"# 数据库重连尝试次数
connection_retry_attempts
=> "3"# 判断数据库连接是否可用,默认false不开启
jdbc_validate_connection
=> "true"# 数据库连接可用校验超时时间,默认3600S
jdbc_validation_timeout
=> "3600"# 开启分页查询(默认false不开启);
jdbc_paging_enabled
=> "true"# 单次分页查询条数(默认100000,若字段较多且更新频率较高,建议调低此值);
jdbc_page_size
=> "500"# statement为查询数据sql,如果sql较复杂,建议配通过statement_filepath配置sql文件的存放路径;
# sql_last_value为内置的变量,存放上次查询结果中最后一条数据tracking_column的值,此处即为ModifyTime;
# statement_filepath
=> "mysql/jdbc.sql"statement=> "SELECT ArticleID,UserID,ArticleTitle,ArticleContent,ImageAddress,StandPoint,PublishTime,`Status`,Likes, Shares,Comments,Reports, Sort,PublishingMode,SourceType,Reply,IsTop,TopEndTime,Hot,EditUserId,CreatedTime,EditTime,UserType,UserNickname,ForbiddenState,PublishDateTime,TopArea,SubscribeType,CollectionCount,Articletype,NewsID,CommentUserCount,TopStartTime,`View`,ViewDuration,Forwardings,ForwardingFId,Freshness,Shelf_Reason,AuditTime FROM bbs_articles"# 是否将字段名转换为小写,默认true(如果有数据序列化、反序列化需求,建议改为false);
lowercase_column_names
=> false# Value can be any of: fatal,error,warn,info,debug,默认info;
sql_log_level
=>warn
#
# 是否记录上次执行结果,true表示会将上次执行结果的tracking_column字段的值保存到last_run_metadata_path指定的文件中;
record_last_run
=> true# 需要记录查询结果某字段的值时,此字段为true,否则默认tracking_column为timestamp的值;
use_column_value
=> true# 需要记录的字段,用于增量同步,需是数据库字段
tracking_column
=> "PublishTime"# Value can be any of: numeric,timestamp,Default valueis "numeric"tracking_column_type=>timestamp
# record_last_run上次数据存放位置;
last_run_metadata_path
=> "mysql/last_id.txt"# 是否清除last_run_metadata_path的记录,需要增量同步时此字段必须为false;
clean_run
=> false#
# 同步频率(分 时 天 月 年),默认每分钟同步一次;
schedule
=> "* * * * *"}
}

filter {
json {
source
=> "message"remove_field=> ["message"]
}
# convert 字段类型转换,将字段TotalMoney数据类型改为float;
mutate {
convert
=>{
#
"TotalMoney" => "float"}
}
}
output {
elasticsearch {
# host
=> "127.0.0.1"# port=> "9200"# 配置ES集群地址
# hosts
=> ["192.168.1.1:9200", "192.168.1.2:9200", "192.168.1.3:9200"]
hosts
=> ["127.0.0.1:9200"]
user
=> "elastic"password=> "DiVnR2F6OGYmP+Ms+n2o"ssl=> false# 索引名字,必须小写
index
=> "bbs_act"# 数据唯一索引(建议使用数据库KeyID)
document_id
=> "%{ArticleID}"}
stdout {
codec
=>json_lines
}
}

配置文成后执行该命令,数据实时同步开始

bin\logstash.bat -f mysql\jdbc.conf

可以通过kibana的discover查看数据,也可以通过开发工具查询,elk日志就是这么玩。

  • 4. 下面就是代码,这里的实体没给全,注意实体需要给Text的Name属性,否则会解析不到数据的:
 public class ArticleEsContext : EsBase<ArticleDto>{public ArticleEsContext(EsConfig esConfig) : base(esConfig)
{
}
public override string IndexName => "bbs_act";public async Task<List<ArticleDto>>GetArticles(ArticleParameter parameter)
{
var client =_esConfig.GetClient(IndexName);//计算分页的起始位置 var from = (parameter.PageNumber - 1) *parameter.PageSize;var searchResponse = await client.SearchAsync<ArticleDto>(s =>s
.Index(IndexName)
.Query(q
=>q
.Bool(b
=>b
.Should(
sh
=> sh.Match(m =>m
.Field(f
=> f.ArticleTitle) //查询 ArticleTitle .Query(parameter.KeyWords)
.Fuzziness(Fuzziness.Auto)
//启用模糊查询 ),
sh
=> sh.Match(m =>m
.Field(f
=> f.ArticleContent) //查询 ArticleContent .Query(parameter.KeyWords)
.Fuzziness(Fuzziness.Auto)
//启用模糊查询 )
)
.MinimumShouldMatch(
1) //至少一个条件必须匹配 )
)
.From(
from) //设置分页的起始位置 .Size(parameter.PageSize) //设置每页大小 );if (!searchResponse.IsValid)
{
Console.WriteLine(searchResponse.DebugInformation);
return new List<ArticleDto>();
}
returnsearchResponse.Documents.ToList();
}
}
public classArticleDto
{
[Text(Name
= "ArticleID")]public int ArticleId { get; set; }
[Text(Name
= "ArticleTitle")]public string ArticleTitle { get; set; }
[Text(Name
= "ArticleContent")]public string ArticleContent { get; set; }
[Date(Name
= "CreatedTime")]public DateTime CreatedTime { get; set; }
}

代码调用结果如下:

一、数据库

linux下登录:

mysql -u root -p

查看数据库:

show databases;

可以在phpmyadmin面板点击SQL进行操作

1. 增加/创建

创建xxx数据库,并使用utf-8编码

create database xxx charset utf8;

2. 删除

删除xxx数据库

drop database xxx;

3. 选择进入数据库

进入xxx数据库

use xxx;

二、数据表

1. 增加/创建表

create table xxx;

定义表属性

varchar(40)字段可以存储的最大字符数为40个字符

(id int,
name varchar(40),
sex char(4),
birthday date,
job varchar(100)
);

这么使用

create table track(id int,
name varchar(40),
sex char(4),
birthday date,
job varchar(100)
);

形式如下:

2. 查看

查看数据表信息

show full columns from xxx;

结果:

查看数据表列表,* 代表所有列表

select * from xxx;

结果:

3. 删除

删除数据表

drop table xxx;
delete from xxx;

4. 修改

修改数据表名xxx为yyy

rename table xxx to yyy;

三、数据列和数据行

1. 增加/创建

增加一行

insert into xxx(id,name,sex,birthday,job)
values(1,'track','男','2000-00-00','IT');

结果:

增加一列

在xxx表中增加一列名为zenjia,可以存储最多8位数字,其中2位是小数点后的数字,-99999999.99 到 99999999.99

alter table xxx add zenjia decimal(8,2);

结果:

2. 修改

修改xxx表zenjia列所有值为5000

update xxx set zenjia=5000;

结果:

修改xxx表id=1的行,name值为name1

update xxx set name='name1' where id=1;

结果:

修改xxx表id=1的行,name=name2,zenjia列第一行的值为2000

update xxx set name='name2',zenjia=2000 where id=1;

结果:

3. 删除

删除列

删除zenjia列

alter table xxx drop zenjia;

删除行

删除job列值为it的行,不区分大小写

delete from xxx where job='it';

结果:

书接上回,前面章节已经实现Excel帮助类的第一步TableHeper的对象集合与DataTable相互转换功能,今天实现进入其第二步的核心功能ExcelHelper实现。

01
、接口设计

下面我们根据第一章中讲解的核心设计思路,先进行接口设计,确定ExcelHelper需要哪些接口即可满足我们的要求,然后再一个一个接口实现即可。

先简单回顾一下核心设计思路,主要涉及两类操作:读和写,两种转换:DataTable与Excel转换和对象集合与Excel转换。

下面先看看设计的所有接口:

//根据文件路径读取Excel到DataSet
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则读取所有工作簿Sheet
public static DataSet Read(string path, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//根据文件流读取Excel到DataSet
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则读取所有工作簿Sheet
public static DataSet Read(Stream stream, string fileName, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//根据文件流读取Excel到DataSet
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则读取所有工作簿Sheet
public static DataSet Read(Stream stream, bool isXlsx, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//根据文件流读取Excel到对象集合
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则默认读取第一个工作簿Sheet
public static IEnumerable<T> Read<T>(string path, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//根据文件流读取Excel到对象集合
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则默认读取第一个工作簿Sheet
public static IEnumerable<T> Read<T>(Stream stream, string fileName, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//根据文件流读取Excel到对象集合
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则默认读取第一个工作簿Sheet
public static IEnumerable<T> Read<T>(Stream stream, bool isXlsx, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null);

//把表格数组写入Excel文件流
public static MemoryStream Write(DataTable[] dataTables, bool isXlsx, bool isColumnNameAsData);

//把表格数组写入Excel文件
public static void Write(DataTable[] dataTables, string path, bool isColumnNameAsData);

//把对象集合写入Excel文件流
public static MemoryStream Write<T>(IEnumerable<T> models, bool isXlsx, bool isColumnNameAsData, string? sheetName = null);

//把对象集合写入Excel文件
public static void Write<T>(IEnumerable<T> models, string path, bool isColumnNameAsData, string? sheetName = null);

02
、根据文件路径读取Excel到DataSet

该方法是通过Excel完全路径直接读取Excel文件,因此我们首先读取到文件流,然后再调用具体处理文件流实现方法。

因为Excel中工作簿Sheet正好对应DataSet中表格DataTable,因此在不指定读取某个工作簿Sheet的情况下,默认是读取Excel中所有工作簿Sheet。

指定工作簿方式也很简单,只要传参数指定工作簿名称sheetName或者工作簿编号sheetNumber即可,提供两个参数是考虑到可能名字不好记,但是第几个工作簿Sheet会比较好记,也因此工作簿编号sheetNumber是从1开始。两者会优先处理工作簿名称sheetName。

因为表格DataTable是有列名的,通过这个列名我们可以把它和对象属性关联上,最后实现相互映射转换,而工作簿Sheet则没有这个概念,因此我们要想最终实现对象和工作簿Sheet的相互转换,就需要人为指定这样的数据。

通常的做法是以工作簿Sheet中第一行数据作为表格DataTable列名,因此我们在接口中设计了这个参数用来指定是否需要把第一行数据作为表格列名。

具体代码实现如下:

//根据文件路径读取Excel到DataSet
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则读取所有工作簿Sheet
public static DataSet Read(string path, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null)
{
    using var stream = new FileStream(path, FileMode.Open, FileAccess.Read);
    return Read(stream, IsXlsxFile(path), isFirstRowAsColumnName, sheetName, sheetNumber);
}

03
、根据文件流、文件名读取Excel到DataSet

在有些场景下,不需要我们直接读取Excel文件,而是直接给一个Excel文件流。比如说文件上传,前端上传文件后,后端接收到的就是一个文件流。

同时该方法还需要传一个文件名的参数,这是因为我们Excel有两种后缀格式即“.xls”和“.xlsx”,而两种格式处理方式又不相同,因此我们需要通过名字来说识别Excel文件流的具体格式,当然如果调用方法时已经明确知道文件流是什么格式,也可以直接调用下一个重载方法。

其他参数解释上节以及详细讲解了,实现代码如下:

//根据文件流读取Excel到DataSet
//指定sheetName,sheetNumber则读取相应工作簿Sheet
//如果不指定则读取所有工作簿Sheet
public static DataSet Read(Stream stream, string fileName, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null)
{
    return Read(stream, IsXlsxFile(fileName), isFirstRowAsColumnName, sheetName, sheetNumber);
}

04
、根据文件流、文件后缀读取Excel到DataSet

该方法是上面两个方法的最终实现,该方法首先会识别读取所有工作簿Sheet还是读取指定工作簿Sheet,然后调不同的方法。而两者差别也这是读一个还是读多个工作簿Sheet的差别,具体代码如下:

//根据文件流读取Excel到DataSet
public static DataSet Read(Stream stream, bool isXlsx, bool isFirstRowAsColumnName = false, string? sheetName = null, int? sheetNumber = null)
{
    if (sheetName == null && sheetNumber == null)
    {
        //读取所有工作簿Sheet至DataSet
        return CreateDataSetWithStreamOfSheets(stream, isXlsx, isFirstRowAsColumnName);
    }
    //读取指定工作簿Sheet至DataSet
    return CreateDataSetWithStreamOfSheet(stream, isXlsx, isFirstRowAsColumnName, sheetName, sheetNumber ?? 1);
}
//读取所有工作簿Sheet至DataSet
private static DataSet CreateDataSetWithStreamOfSheets(Stream stream, bool isXlsx, bool isFirstRowAsColumnName)
{
    //根据Excel文件后缀创建IWorkbook
    using var workbook = CreateWorkbook(isXlsx, stream);
    //根据Excel文件后缀创建公式求值器
    var evaluator = CreateFormulaEvaluator(isXlsx, workbook);
    var dataSet = new DataSet();
    for (var i = 0; i < workbook.NumberOfSheets; i++)
    {
        //获取工作簿Sheet
        var sheet = workbook.GetSheetAt(i);
        //通过工作簿Sheet创建表格
        var table = CreateDataTableBySheet(sheet, evaluator, isFirstRowAsColumnName);
        dataSet.Tables.Add(table);
    }
    return dataSet;
}
//读取指定工作簿Sheet至DataSet
private static DataSet CreateDataSetWithStreamOfSheet(Stream stream, bool isXlsx, bool isFirstRowAsColumnName, string? sheetName = null, int sheetNumber = 1)
{
    //把工作簿sheet编号转为索引
    var sheetIndex = sheetNumber - 1;
    var dataSet = new DataSet();
    if (string.IsNullOrWhiteSpace(sheetName) && sheetIndex < 0)
    {
        //工作簿sheet索引非法则返回
        return dataSet;
    }
    //根据Excel文件后缀创建IWorkbook
    using var workbook = CreateWorkbook(isXlsx, stream);
    if (string.IsNullOrWhiteSpace(sheetName) && sheetIndex >= workbook.NumberOfSheets)
    {
        //工作簿sheet索引非法则返回
        return dataSet;
    }
    //根据Excel文件后缀创建公式求值器
    var evaluator = CreateFormulaEvaluator(isXlsx, workbook);
    //优先通过工作簿名称获取工作簿sheet
    var sheet = !string.IsNullOrWhiteSpace(sheetName) ? workbook.GetSheet(sheetName) : workbook.GetSheetAt(sheetIndex);
    if (sheet != null)
    {
        //通过工作簿sheet创建表格
        var table = CreateDataTableBySheet(sheet, evaluator, isFirstRowAsColumnName);
        dataSet.Tables.Add(table);
    }
    return dataSet;
}

通过上图实现工作簿Sheet转换DataSet过程,可以发现大致分为三步:

第一步首先根据文件格式以及文件流获取IWorkbook;

第二步再通过文件格式以及IWorkbook获取公式求值器;

第三步再实现把工作簿Sheet转换为表格DataTable;

我们一起看看这三个代码实现:

//根据Excel文件后缀创建IWorkbook
private static IWorkbook CreateWorkbook(bool isXlsx, Stream? stream = null)
{
    if (stream == null)
    {
        return isXlsx ? new XSSFWorkbook() : new HSSFWorkbook();
    }
    return isXlsx ? new XSSFWorkbook(stream) : new HSSFWorkbook(stream);
}
//根据Excel文件后缀创建公式求值器
private static IFormulaEvaluator CreateFormulaEvaluator(bool isXlsx, IWorkbook workbook)
{
    return isXlsx ? new XSSFFormulaEvaluator(workbook) : new HSSFFormulaEvaluator(workbook);
}
//工作簿Sheet转换为表格DataTable
private static DataTable CreateDataTableBySheet(ISheet sheet, IFormulaEvaluator evaluator, bool isFirstRowAsColumnName)
{
    var dataTable = new DataTable(sheet.SheetName);
    //获取Sheet中最大的列数,并以此数为新的表格列数
    var maxColumnNumber = GetMaxColumnNumber(sheet);
    if (isFirstRowAsColumnName)
    {
        //如果第一行数据作为表头,则先获取第一行数据
        var firstRow = sheet.GetRow(sheet.FirstRowNum);
        for (var i = 0; i < maxColumnNumber; i++)
        {
            //尝试读取第一行每一个单元格数据,有值则作为列名,否则忽略
            string? columnName = null;
            var cell = firstRow?.GetCell(i);
            if (cell != null)
            {
                cell.SetCellType(CellType.String);
                if (cell.StringCellValue != null)
                {
                    columnName = cell.StringCellValue;
                }
            }
            dataTable.Columns.Add(columnName);
        }
    }
    else
    {
        for (var i = 0; i < maxColumnNumber; i++)
        {
            dataTable.Columns.Add();
        }
    }
    //循环处理有效行数据
    for (var i = isFirstRowAsColumnName ? sheet.FirstRowNum + 1 : sheet.FirstRowNum; i <= sheet.LastRowNum; i++)
    {
        var row = sheet.GetRow(i);
        var newRow = dataTable.NewRow();
        //通过工作簿sheet行数据填充表格新行数据
        FillDataRowBySheetRow(row, evaluator, newRow);
        //检查每单元格是否都有值
        var isNullRow = true;
        for (var j = 0; j < maxColumnNumber; j++)
        {
            isNullRow = isNullRow && newRow.IsNull(j);
        }
        if (!isNullRow)
        {
            dataTable.Rows.Add(newRow);
        }
    }
    return dataTable;
}

在实现工作簿Sheet转换为表格DataTable过程中,大致可以分为两步:

第一步求出工作簿Sheet中所有有效行中最宽的列编号,并以此为列数创建表格;

第二步把工作簿Sheet中所有有效行数据填充至表格中;

下面我们看看具体实现代码:

//获取工作簿Sheet中最大的列数
private static int GetMaxColumnNumber(ISheet sheet)
{
    var maxColumnNumber = 0;
    //在有效的行数据中
    for (var i = sheet.FirstRowNum; i <= sheet.LastRowNum; i++)
    {
        var row = sheet.GetRow(i);
        //找到最大的列编号
        if (row != null && row.LastCellNum > maxColumnNumber)
        {
            maxColumnNumber = row.LastCellNum;
        }
    }
    return maxColumnNumber;
}
//通过工作簿sheet行数据填充表格行数据
private static void FillDataRowBySheetRow(IRow row, IFormulaEvaluator evaluator, DataRow dataRow)
{
    if (row == null)
    {
        return;
    }
    for (var j = 0; j < dataRow.Table.Columns.Count; j++)
    {
        var cell = row.GetCell(j);
        if (cell != null)
        {
            switch (cell.CellType)
            {
                case CellType.Blank:
                    dataRow[j] = DBNull.Value;
                    break;
                case CellType.Boolean:
                    dataRow[j] = cell.BooleanCellValue;
                    break;
                case CellType.Numeric:
                    if (DateUtil.IsCellDateFormatted(cell))
                    {
                        dataRow[j] = cell.DateCellValue;
                    }
                    else
                    {
                        dataRow[j] = cell.NumericCellValue;
                    }
                    break;
                case CellType.String:
                    dataRow[j] = !string.IsNullOrWhiteSpace(cell.StringCellValue) ? cell.StringCellValue : DBNull.Value;
                    break;
                case CellType.Error:
                    dataRow[j] = cell.ErrorCellValue;
                    break;
                case CellType.Formula:
                    dataRow[j] = evaluator.EvaluateInCell(cell).ToString();
                    break;
                default:
                    throw new NotSupportedException("Unsupported cell type.");
            }
        }
    }
}


:测试方法代码以及示例源码都已经上传至代码库,有兴趣的可以看看。
https://gitee.com/hugogoos/Ideal

XPath解析

XPath(XML Path Language)是一种用于在XML和HTML文档中查找信息的语言,其通过路径表达式来定位节点,属性和文本内容,并支持复杂查询条件,XPath 是许多 Web 抓取工具如
Scrapy,Selenium
等的核心技术之一

XPath 解析的基本步骤

  1. 导入lxml.etree

    from lxml import etree
    
  2. 使用etree.parse(filename, parser=None)函数返回一个树形结构


    • etree.parse()
      用于解析本地XML或HTML文件,并将其转换为一个树形结构即
      ElementTree
      对象,可以通过该对象访问文档的各个节点
    • filename
      :要解析的文件路径
    • parser
      (可选):默认情况下,parser()会根据文件扩展名自动选择合适的解析器,如
      .xml
      文件使用XML解析器,.html使用HTML解析器
  3. 使用etree.HTML(html_string, parser=None)解析网络html字符串


    • html_string
      :要解析的HTML字符串
    • parser
      :(可选):默认情况下
      etree.HTML()
      使用
      etree.HTMLparser()
      进行解析
    • 返回值
      :etree.HTML()返回一个
      ELement
      对象,表示HTML文档的
      根元素
      ,可以通过该对象访问文档各个节点
  4. 使用.xpath(xpath_expression)在已经解析好的HTML文档中执行XPath查询

    result = html_tree.xpath(xpath_expression)
    

    • xpath_expression
      :XPath表达式,用于在文档中查找节点,XPath表达式可以是绝对路径或相对路径,也可以包含谓词,函数和轴操作,主要的XPath语法下面会展开讲解
    • html_tree
      :可以是
      ElementTree
      对象(由 etree.parse() 返回)或
      Element
      对象(由 etree.HTML() 返回)
from lxml import etree

# 使用etree.parser()解析文件路径
parser = etree.HTMLParser(encoding='utf-8')  # 以utf8进行编码
tree = etree.parse('../Learning02/三国演义.html', parser=parser)
print(tree)
#output-> <lxml.etree._ElementTree object at 0x000001A240107000>

# 使用etree.HTML()解析本地文件或网络动态HTML
# 读取文件 解析为字符串
file = open('../Learning02/三国演义.html', 'r', encoding='utf-8')
data = file.read()
root = etree.HTML(data)
print(root)

#整合
root = etree.HTML(open('../Learning02/三国演义.html', 'r', encoding='utf-8').read())
print(root)
#output-> <Element html at 0x1a23e462880>

XPath语法

XPath
语法可以用于在XML与HTML文档中查找信息的语言

路径表达式

XPath使用路径表达式来定位文档中的节点,路径也可以分为绝对路径与相对路径

绝对路径

  • /
    :表示从根节点开始选择,其用于定义一个绝对路径

从根节点html开始查找到head,再从head下找出title标签

root = etree.HTML(open('../Learning02/三国演义.html', 'r', encoding='utf-8').read())
all_titles = root.xpath('/html/head/title')
for title in all_titles:
    print(etree.tostring(title, encoding='utf-8').decode('utf-8'))
#output-> <title>《三国演义》全集在线阅读_史书典籍_诗词名句网</title>

相对路径

相比与绝对路径,相对路径使用率更好,更好用

  • //
    :表示从
    当前节点开始,
    选择文档中
    所有符合条件的节点,
    并且不考虑他们的位置
root = etree.HTML(open('../Learning02/三国演义.html', 'r', encoding='utf-8').read())
all_a = root.xpath('//a')
for a in all_a:
    print(a.text)
#None
#首页
#分类
#作者
#...

当前节点

  • ./
    :表示当前节点,通常用于指明当前节点本身,避免混淆
all_a = root.xpath('//a')
print(all_a[1].xpath('./text()')) #./表示当前的a标签
#output-> ['首页']

选择属性

  • @
    :用于选择元素的属性,而不是元素本身
# 使用 @ 选择 <a> 标签的 href 属性
all_hrefs = root.xpath('//a[@href]')
for hrefs in all_hrefs:
    print(etree.tostring(hrefs, encoding='unicode'))

XPath谓语

谓语是
xpath
中用于进一步筛选节点的表达式,通常放在方括号
[]
内,其可以基于节点的位置,属性值,文本内容或其他条件来
选择特定的节点,谓语可以嵌套使用,也可以与其他谓语组合使用

  • 基本语法

    //element[condition]
    

    • element
      :要选择的元素
    • condition
      :谓语中的条件,用于进一步筛选符合条件的元素

位置谓语

位置谓语用于根据节点在兄弟节点中的位置进行选择,可以使用
position()
或直接指定位置编号

  • 获取第一个
    ul
    标签中的第一个
    li
    标签

    #//ul获取的是所有ul,[0]选择第一个
    lis = root.xpath('//ul')[0].xpath('./li[1]')
    for li in lis:
        print(etree.tostring(li, encoding='unicode'))
    #output-> <li><a href="/">首页</a></li>
    
  • 使用
    last()
    获取最后第一个节点,和导数第二个节点

    # 倒一个
    last_li = root.xpath('//ul')[0].xpath('./li[last()]')
    print(etree.tostring(last_li[0], encoding='unicode'))
    # 倒二个
    last_second_li = root.xpath('//ul')[0].xpath('./li[last()-1]')
    print(etree.tostring(last_second_li[0], encoding='unicode'))
    #output-> <li><a href="/app/">安卓下载</a></li>
    #<li><a href="/book/">古籍</a></li>
    
  • 使用
    position()
    获取位置进行筛选

    # 获取前两个li标签
    last_li = root.xpath('//ul')[0].xpath('./li[position()<3]')
    for li in last_li:
        print(etree.tostring(li, encoding='unicode'))
    # 获取偶数位标签
    lis = root.xpath('//ul')[0].xpath('./li[position() mod 2=0]')
    for li in lis:
        print(etree.tostring(li, encoding='unicode'))
    
  • 属性谓语


    属性谓语用于
    选择具体特定属性的节点


    • 使用
      @attribute
      来获取属性名称,结合条件进行筛选

    # 选取所有具有 href 属性的 a 元素
    hrefs = root.xpath("//a[@href]")
    for href in hrefs:
        print(etree.tostring(href, encoding='unicode'))
    

    • 查找
      class
      属性值

    all_class = root.xpath('//@class')
    print(all_class)
    
  • 组合谓语


    将多个条件组合在一起,使用逻辑运算符
    and,or
    等来创建更复杂的谓语


    #选取href属性值为https://example.com且class属性值为link的a元素
    //a[@href='https://example.com' and @class='link']
    

    #选取href属性值为https://example.com或https://another.com的a 元素
    //a[@href='https://example.com' or @href='https://another.com']
    
  • 函数谓语


    Xpath提供了许多内置函数,来应对更复杂的筛选条件


    • contains((string1, string2)
      函数:


      • string1
        :要搜索的字符串
      • string2
        :要查找的字符串

      # 选取class包含"book"的img标签
      images = root.xpath('//img[contains(@src,"book")]')
      for image in images:
          print(etree.tostring(image, encoding='unicode'))
      
    • starts-with(string1, string2)
      函数:


      检查一个字符串是否以指定字符的前缀开始,是返回
      true
      ,否返回
      false


      • string1:
        要检查的字符串
      • string2:
        作为前缀的字符串

      # 选取所有href以https://开头的a标签
      all_a = root.xpath('//a[starts-with(@href,"https:")]')
      for a in all_a:
          print(etree.tostring(a, encoding='unicode'))
      
  • 文本内容谓语


    用于选择包含特定文本内容的节点,可以使用
    text()
    函数来提取节点的文本内容


    # 选择使用包含"三国"文本的p标签
    paragraphs = root.xpath('//p[contains(text(),"三国")]')
    for p in paragraphs:
        print(etree.tostring(p, encoding='unicode'))
    

通配符

xpath提供了多种通配符,用于在路径表达式中匹配未知的元素,属性,或任何节点.这些通配符非常有用,尤其是当不确定具体节点名称和结构的情况下

通配符 描述
* 匹配任何元素节点。 一般用于浏览器copy xpath会出现
@* 匹配任何属性节点。
node() 匹配任何类型的节点。

使用
*
匹配任何元素节点

  • *
    是最常用的通配符之一,其可以匹配任何元素,而不需要具体标签名.这在不确定元素名称或希望选择所有类型的元素时非常有用
# 选择所有 div 下的所有子元素
divs = root.xpath("//div/*")
for div in divs:
    print(etree.tostring(div, encoding='unicode'))

使用
@*
匹配任何属性节点

  • @*
    用于匹配任何属性节点,而不用指定具体属性名称,在你不确定属性名称或希望选择所有属性时非常有用
# 选择所有 a 元素的所有属性
all_a = root.xpath('//a/@*')
for a in all_a:
    print(a)

使用
node()
匹配任何类型的节点

  • node()
    是一个更通用的通配符,其能匹配任何类型节点,包括元素节点,文本节点,属性节点,注释节点等等,其在需要选择不仅仅是元素节点是十分有用
# 选择所有 ul 下的所有子节点(包括文本节点)
nodes = root.xpath('//ul/node()')
print(nodes)
#output-> ['\n ', <Element li at 0x2009621d800>, '\n,...] 

XPath,re正则,BeautifulSoup对比

在之前的学习中我们首先学习了re正则表达式,其次学习了更加便捷的bs4,哪为何还要学习XPath解析呢,接下来我们将它们的优点和适用场景进行对比学习

工具 优点 缺点 适用场景
XPath 强大的路径表达能力,支持层级结构和条件查询 学习曲线较陡,对不规范 HTML 容错性较差 结构化良好的 XML/HTML,复杂查询
re 灵活性高,适合处理纯文本中的模式匹配 不适合解析 HTML/XML,可读性差 从纯文本中提取特定模式的数据
BeautifulSoup 易于使用,容错性强,适合初学者 性能稍低,功能有限 不规范的 HTML,简单数据提取,网页抓取
  • 总结
    • 若需要处理结构良好的XML或HTML文档,并需要进行复杂查询
      ,那么XPath解析是最佳选择
    • 若需要从纯文本中提取特定模式的数据时
      ,如从日志中提取日期,IP地址的,re正则表达式是最佳选择
    • 需要解析不规范的 HTML 或者只需要进行简单的数据提取,
      BeautifulSoup 是最友好的选择

简介

在上一篇文章《
机器学习:神经网络构建(上)
》中讨论了线性层、激活函数以及损失函数层的构建方式,本节中将进一步讨论网络构建方式,并完整的搭建一个简单的分类器网络。

目录

  1. 网络Network
  2. 数据集管理器 DatasetManager
  3. 优化器 Optimizer
  4. 代码测试

网络Network

网络定义


在设计神经网络时,其基本结构是由一层层的神经元组成的,这些层可以是输入层、隐藏层和输出层。为了实现这一结构,通常会使用向量(vector)容器来存储这些层,因为层的数量是可变的,可能根据具体任务的需求而变化。

即使在网络已经进行了预训练并具有一定的参数的情况下,对于特定的任务,通常还是需要进行模型微调。这是因为不同的任务可能有不同的数据分布和要求,因此训练是构建高性能神经网络模型的重要步骤。

在训练过程中,有三个关键组件:

  1. 损失函数
    :神经网络的学习目标,通过最小化损失函数来优化模型参数。选择合适的损失函数对于确保模型能够学习到有效的特征表示至关重要。

  2. 优化器
    :优化器负责调整模型的参数以最小化损失函数。除了基本的参数更新功能外,优化器还可以提供更高级的功能,如学习率调整和参数冻结,这些功能有助于提高训练效率和模型性能。

  3. 数据集管理器
    :负责在训练过程中有效地管理和提供数据,包括数据的加载、预处理和批处理,以确保数据被充分利用。

对于网络的外部接口(公有方法),主要有以下几类:

  1. 网络设置
    :添加网络层、设置损失函数、优化器和数据集等操作,用于配置网络的结构和训练参数。
  2. 网络推理
    :前向传播和反向传播方法,用于在训练和测试过程中进行预测和参数更新。
  3. 网络训练
    :使用配置好的数据集和训练方法,执行指定次数的训练迭代,以优化网络参数。

以下是代码示例:

class Network {
private:
    vector<shared_ptr<Layer>> layers;

    shared_ptr<LossFunction> lossFunction;
    shared_ptr<Optimizer> optimizer;
    shared_ptr<DatasetManager> datasetManager;

public:
    void addLayer(shared_ptr<Layer> layer);

    void setLossFunction(shared_ptr<LossFunction> lossFunc);
    void setOptimizer(shared_ptr<Optimizer> opt);
    void setDatasetManager(shared_ptr<DatasetManager> manager);

    MatrixXd forward(const MatrixXd& input);
    void backward(const MatrixXd& outputGrad);

    double train(size_t epochs, size_t batchSize);
};

使用shared_ptr的好处:
存储方式vector<shared_ptr
>和vector 相比,如果直接存储 Layer 对象,需要手动管理内存,包括分配和释放内存,这不仅容易出错,还可能导致内存泄漏或悬挂指针的问题。而使用 std::shared_ptr 可以大大简化内存管理,提高代码的健壮性和可维护性。

网络训练


网络的训练函数通常包含两个输入参数,训练的集数和批尺寸:

  • 集数
    epochs

    :指训练集被完整的迭代的次数。在每一个epoch中,网络会使用训练集中的所有样本进行参数更新。

  • 批尺寸
    batchSize

    :指在一次迭代中用于更新模型参数的样本数量。在每次迭代中,模型会计算这些样本的总梯度,并据此调整模型的参数。

因此,网络的训练函数由两层循环结构组成,外层循环结构表示完整迭代的次数,直至完成所有迭代时停止。内层循环表示训练集中样本被网络调取的进度,直至训练集中的所有数据被调用时停止。

网络的训练过程是由多次的参数迭代(更新)完成的。而参数的的迭代是以批(Batch)为单位的。具体来说,一次迭代包含如下步骤:

  1. 获取数据
    :从数据集管理器中获取一批的数据(包含输入和输出)
  2. 前向传播
    :采用网络对数据进行推理,得到预测结果,依据预测结果评估损失。
  3. 反向传播
    :计算损失函数关于各层参数的梯度。
  4. 参数更新
    :依据损失、梯度等信息,更新各层梯度。
  5. 日志更新
    :计算并输出每个epoch的累积误差。

代码设计如下:

double Network::train(size_t epochs, size_t batchSize) {
    double totalLoss = 0.0;
    size_t sampleCount = datasetManager->getTrainSampleCount();

    for (size_t epoch = 0; epoch < epochs; ++epoch) {
        datasetManager->shuffleTrainSet();
        totalLoss = 0.0;
        for (size_t i = 0; i < sampleCount; i += batchSize) {
            // 获取一个小批量样本
            auto batch = datasetManager->getTrainBatch(batchSize, i / batchSize);
            MatrixXd batchInput = batch.first;
            MatrixXd batchLabel = batch.second;

            // 前向传播
            MatrixXd predicted = forward(batchInput);
            double loss = lossFunction->computeLoss(predicted, batchLabel);

            // 反向传播
            MatrixXd outputGrad = lossFunction->computeGradient(predicted, batchLabel);
            backward(outputGrad);

            // 参数更新
            optimizer->update(layers);

            // 累计损失
            totalLoss += loss;
        }
        totalLoss /= datasetManager->getTrainSampleCount();
        // 输出每个epoch的损失等信息
        std::cout << "Epoch " << epoch << ", totalLoss = " << totalLoss << "\n";
    }
    return totalLoss / (epochs * (sampleCount / batchSize)); // 返回平均损失(简化示例)
}

网络的其它公有方法


下面的代码给出了网络的其它公有方法的代码实现:

void Network::addLayer(std::shared_ptr<Layer> layer) {
    layers.push_back(layer);
}

void Network::setLossFunction(std::shared_ptr<LossFunction> lossFunc) {
    lossFunction = lossFunc;
}

void Network::setOptimizer(std::shared_ptr<Optimizer> opt) {
    optimizer = opt;
}

void Network::setDatasetManager(std::shared_ptr<DatasetManager> manager) {
    datasetManager = manager;
}

MatrixXd Network::forward(const MatrixXd& input) {
    MatrixXd currentInput = input;
    for (const auto& layer : layers) {
        currentInput = layer->forward(currentInput);
    }
    return currentInput;
}

void Network::backward(const MatrixXd& outputGrad) {
    MatrixXd currentGrad = outputGrad;
    for (auto it = layers.rbegin(); it != layers.rend(); ++it) {
        currentGrad = (*it)->backward(currentGrad);
    }
}

forward
方法除了作为训练时的步骤之一,还经常用于网络推理(预测),因此声明为公有方法

backward
方法只在训练时使用,在正常的使用用途中,不会被外部调用,因此,其可以声明为私有方法。

数据集管理器 DatasetManager


数据集管理器本质目的是提高网络对数据的利用率,其主要职能有:

  1. 保存数据:提供更为安全可靠的数据管理。
  2. 数据打乱:以避免顺序偏差,同时提升模型的泛化能力。
  3. 数据集划分:讲数据划分为训练集、验证集和测试集。
  4. 数据接口:使得外部可以轻松的获取批量数据。
    class DatasetManager {
    private:
        MatrixXd input;
        MatrixXd label;
        std::vector<int> trainIndices;
        std::vector<int> valIndices;
        std::vector<int> testIndices;

    public:
        // 设置数据集的方法
        void setDataset(const MatrixXd& inputData, const MatrixXd& labelData);

        // 划分数据集为训练集、验证集和测试集
        void splitDataset(double trainRatio = 0.8, double valRatio = 0.1, double testRatio = 0.1);

        // 获取训练集、验证集和测试集的小批量数据
        std::pair<MatrixXd, MatrixXd> getBatch(std::vector<int>& indices, size_t batchSize, size_t offset = 0);

        // 随机打乱训练集
        void shuffleTrainSet();

        // 获取批量数据
        std::pair<MatrixXd, MatrixXd> getTrainBatch(size_t batchSize, size_t offset = 0);
        std::pair<MatrixXd, MatrixXd> getValidationBatch(size_t batchSize, size_t offset = 0);
        std::pair<MatrixXd, MatrixXd> getTestBatch(size_t batchSize, size_t offset = 0);

        // 获取样本数量的方法
        size_t getSampleCount() const;
        size_t getTrainSampleCount() const;
        size_t getValidationSampleCount() const;
        size_t getTestSampleCount() const;
    };

数据集初始化


数据集初始化分为三步:数据集设置、数据集划分、数据集打乱。

// 设置数据集
void  ML::DatasetManager::setDataset(const MatrixXd& inputData, const MatrixXd& labelData) {
    input = inputData;
    label = labelData;

    trainIndices.resize(input.rows());
    std::iota(trainIndices.begin(), trainIndices.end(), 0);
    valIndices.clear();
    testIndices.clear();
}

// 打乱训练集
void ML::DatasetManager::shuffleTrainSet() {
    std::shuffle(trainIndices.begin(), trainIndices.end(), std::mt19937{ std::random_device{}() });
}

// 划分数据集为训练集、验证集和测试集
void ML::DatasetManager::splitDataset(double trainRatio, double valRatio, double testRatio) {
    size_t totalSamples = input.rows();
    size_t trainSize = static_cast<size_t>(totalSamples * trainRatio);
    size_t valSize = static_cast<size_t>(totalSamples * valRatio);
    size_t testSize = totalSamples - trainSize - valSize;

    shuffleTrainSet();

    valIndices.assign(trainIndices.begin() + trainSize, trainIndices.begin() + trainSize + valSize);
    testIndices.assign(trainIndices.begin() + trainSize + valSize, trainIndices.end());
    trainIndices.resize(trainSize);
}

对于打乱操作较频繁的场景,打乱索引是更为高效的操作;而对于不经常打乱的场景,直接在数据集上打乱更为高效。本例中仅给出打乱索引的代码示例。

数据获取


在获取数据时,首先明确所需数据集的类型(训练集或验证集)。然后,根据预设的批次大小(Batchsize),从索引列表中提取相应数量的索引,并将这些索引对应的数据存储到临时矩阵中。最后,导出数据,完成读取操作。

// 获取训练集、验证集和测试集的小批量数据
std::pair<MatrixXd, MatrixXd> ML::DatasetManager::getBatch(std::vector<int>& indices, size_t batchSize, size_t offset) {
    size_t start = offset * batchSize;
    size_t end = std::min(start + batchSize, indices.size());
    MatrixXd batchInput = MatrixXd::Zero(end - start, input.cols());
    MatrixXd batchLabel = MatrixXd::Zero(end - start, label.cols());

    for (size_t i = start; i < end; ++i) {
        batchInput.row(i - start) = input.row(indices[i]);
        batchLabel.row(i - start) = label.row(indices[i]);
    }

    return std::make_pair(batchInput, batchLabel);
}

// 获取训练集的批量数据
std::pair<MatrixXd, MatrixXd> ML::DatasetManager::getTrainBatch(size_t batchSize, size_t offset) {
    return getBatch(trainIndices, batchSize, offset);
}

// 获取验证集的批量数据
std::pair<MatrixXd, MatrixXd> ML::DatasetManager::getValidationBatch(size_t batchSize, size_t offset) {
    return getBatch(valIndices, batchSize, offset);
}

// 获取测试集的批量数据
std::pair<MatrixXd, MatrixXd> ML::DatasetManager::getTestBatch(size_t batchSize, size_t offset) {
    return getBatch(testIndices, batchSize, offset);
}

数据集尺寸的外部接口


为便于代码开发,需要为数据集管理器设计外部接口,以便于外部可以获取各个数据集的尺寸。

size_t ML::DatasetManager::getSampleCount() const {
    return input.rows();
}

size_t ML::DatasetManager::getTrainSampleCount() const {
    return trainIndices.size();
}

size_t ML::DatasetManager::getValidationSampleCount() const {
    return valIndices.size();
}

size_t ML::DatasetManager::getTestSampleCount() const {
    return testIndices.size();
}

优化器 Optimizer


随机梯度下降是一种优化算法,用于最小化损失函数以训练模型参数。与批量梯度下降(Batch Gradient Descent)不同,SGD在每次更新参数时只使用一个样本(或一个小批量的样本),而不是整个训练集。这使得SGD在计算上更高效,且能够更快地收敛,尤其是在处理大规模数据时。以下为随机梯度下降的代码示例:

class Optimizer {
public:
    virtual void update(std::vector<std::shared_ptr<Layer>>& layers) = 0;
    virtual ~Optimizer() {}
};

class SGDOptimizer : public Optimizer {
private:
    double learningRate;
public:
    SGDOptimizer(double learningRate) : learningRate(learningRate) {}
    void update(std::vector<std::shared_ptr<Layer>>& layers) override;
};

void SGDOptimizer::update(std::vector<std::shared_ptr<Layer>>& layers) {
    for (auto& layer : layers) {
        layer->update(learningRate);
    }
}

代码测试


如果你希望测试这些代码,首先可以从本篇文章,以及
上一篇文章
中复制代码,并参考下述图片构建你的解决方案。
description
如果你有遇到问题,欢迎联系作者!

示例1:线性回归


下述代码为线性回归的测试样例:

namespace LNR{
    // linear_regression
    void gen(MatrixXd& X, MatrixXd& y);
    void test();
}

void LNR::gen(MatrixXd& X, MatrixXd& y) {
    MatrixXd w(X.cols(), 1);

    X.setRandom();
    w.setRandom();

    X.rowwise() -= X.colwise().mean();
    X.array().rowwise() /= X.array().colwise().norm();

    y = X * w;
}

void LNR::test() {
    std::cout << std::fixed << std::setprecision(2);

    size_t input_dim = 10;
    size_t sample_num = 2000;

    MatrixXd X(sample_num, input_dim);
    MatrixXd y(sample_num, 1);

    gen(X, y);

    ML::DatasetManager dataset;
    dataset.setDataset(X, y);

    ML::Network net;

    net.addLayer(std::make_shared<ML::Linear>(input_dim, 1));

    net.setLossFunction(std::make_shared<ML::MSELoss>());
    net.setOptimizer(std::make_shared<ML::SGDOptimizer>(0.25));
    net.setDatasetManager(std::make_shared<ML::DatasetManager>(dataset));

    size_t epochs = 600;
    size_t batch_size = 50;
    net.train(epochs, batch_size);

    MatrixXd error(sample_num, 1);

    error = net.forward(X) - y;

    std::cout << "error=\n" << error << "\n";
}

详细解释

  1. gen
    函数:用以生成测试数据。
  2. 网络结构:本例的网络结构中只包含一个线性层,其中输入尺寸为特征维度,输出尺寸为1。
  3. 损失函数:采用MSE均方根误差作为损失函数。

输出展示

完成训练后,网络预测值与真实值的误差如下图;容易发现,网络具有较好的预测精度。

description

示例2:逻辑回归


下述代码为逻辑回归的测试样例:

namespace LC {
    // Linear classification
    void gen(MatrixXd& X, MatrixXd& y);
    void test();
}

void LC::gen(MatrixXd& X, MatrixXd& y) {
    MatrixXd w(X.cols(), 1);

    X.setRandom();
    w.setRandom();

    X.rowwise() -= X.colwise().mean();
    X.array().rowwise() /= X.array().colwise().norm();

    y = X * w;

    y = y.unaryExpr([](double x) { return x > 0.0 ? 1.0 : 0.0; });
}

void LC::test() {
    std::cout << std::fixed << std::setprecision(3);

    size_t input_dim = 10;
    size_t sample_num = 2000;

    MatrixXd X(sample_num, input_dim);
    MatrixXd y(sample_num, 1);

    gen(X, y);

    ML::DatasetManager dataset;
    dataset.setDataset(X, y);

    ML::Network net;

    net.addLayer(std::make_shared<ML::Linear>(input_dim, 1));
    net.addLayer(std::make_shared<ML::Sigmoid>());

    net.setLossFunction(std::make_shared<ML::LogisticLoss>());
    net.setOptimizer(std::make_shared<ML::SGDOptimizer>(0.05));
    net.setDatasetManager(std::make_shared<ML::DatasetManager>(dataset));

    size_t epochs = 200;
    size_t batch_size = 25;
    net.train(epochs, batch_size);

    MatrixXd predict(sample_num, 1);

    predict = net.forward(X);

    predict = predict.unaryExpr([](double x) { return x > 0.5 ? 1.0 : 0.0; });

    MatrixXd error(sample_num, 1);

    error = y - predict;

    error = error.unaryExpr([](double x) {return (x < 0.01 && x>-0.01) ? 1.0 : 0.0; });

    std::cout << "正确率=\n" << error.sum() / sample_num << "\n";
}

详细解释

  1. gen
    函数:用以生成测试数据。
  2. 网络结构:本例的网络结构中包含一个线性层及一个激活函数层,其中:线性层输入尺寸为特征维度,输出尺寸为1。
  3. 损失函数:采用对数误差作为损失函数。

输出展示
下图反映了网络预测过程中的损失变化,可以看到损失逐渐下降的趋势。
description

完成训练后,输出网络的预测结果的正确率。可以发现,网络具有较好的预测精度。
description