wenmo8 发布的文章

之前写过
两篇关于Roslyn源生成器生成源代码的用例
,今天使用Roslyn的代码修复器
CodeFixProvider
实现一个cs文件头部注释的功能,

代码修复器会同时涉及到
CodeFixProvider

DiagnosticAnalyzer
,

实现FileHeaderAnalyzer

首先我们知道修复器的先决条件是分析器,比如这里,如果要对代码添加头部注释,那么分析器必须要给出对应的分析提醒:

我们首先实现实现名为
FileHeaderAnalyzer
的分析器:

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class FileHeaderAnalyzer : DiagnosticAnalyzer
{
    public const string DiagnosticId = "GEN050";
    private static readonly LocalizableString Title = "文件缺少头部信息";
    private static readonly LocalizableString MessageFormat = "文件缺少头部信息";
    private static readonly LocalizableString Description = "每个文件应包含头部信息.";
    private const string Category = "Document";

    private static readonly DiagnosticDescriptor Rule = new(
        DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Warning, isEnabledByDefault: true, description: Description);

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => [Rule];

    public override void Initialize(AnalysisContext context)
    {
        if (context is null)
            return;

        context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
        context.EnableConcurrentExecution();
        context.RegisterSyntaxTreeAction(AnalyzeSyntaxTree);
    }

    private static void AnalyzeSyntaxTree(SyntaxTreeAnalysisContext context)
    {
        var root = context.Tree.GetRoot(context.CancellationToken);
        var firstToken = root.GetFirstToken();

        // 检查文件是否以注释开头
        var hasHeaderComment = firstToken.LeadingTrivia.Any(trivia => trivia.IsKind(SyntaxKind.SingleLineCommentTrivia) || trivia.IsKind(SyntaxKind.MultiLineCommentTrivia));

        if (!hasHeaderComment)
        {
            var diagnostic = Diagnostic.Create(Rule, Location.Create(context.Tree, TextSpan.FromBounds(0, 0)));
            context.ReportDiagnostic(diagnostic);
        }
    }
}

FileHeaderAnalyzer分析器的原理很简单,需要重载几个方法,重点是
Initialize
方法,这里的
RegisterSyntaxTreeAction
即核心代码,
SyntaxTreeAnalysisContext
对象取到当前源代码的
SyntaxNode
根节点,然后判断TA的第一个
SyntaxToken
是否为注释行(SyntaxKind.SingleLineCommentTrivia|SyntaxKind.MultiLineCommentTrivia)

如果不为注释行,那么就通知分析器!

实现了上面的代码我们看一下效果:

image

并且编译的时候分析器将会在错误面板中显示警告清单:

image

实现CodeFixProvider

分析器完成了,现在我们就来实现名为
AddFileHeaderCodeFixProvider
的修复器,

/// <summary>
/// 自动给文件添加头部注释
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(AddFileHeaderCodeFixProvider))]
[Shared]
public class AddFileHeaderCodeFixProvider : CodeFixProvider
{
    private const string Title = "添加文件头部信息";
    //约定模板文件的名称
    private const string ConfigFileName = "Biwen.AutoClassGen.Comment";
    private const string VarPrefix = "$";//变量前缀
    //如果模板不存在的时候的默认注释文本
    private const string DefaultComment = """
        // Licensed to the {Product} under one or more agreements.
        // The {Product} licenses this file to you under the MIT license.
        // See the LICENSE file in the project root for more information.
        """;

    #region regex

    private const RegexOptions ROptions = RegexOptions.Compiled | RegexOptions.Singleline;
    private static readonly Regex VersionRegex = new(@"<Version>(.*?)</Version>", ROptions);
    private static readonly Regex CopyrightRegex = new(@"<Copyright>(.*?)</Copyright>", ROptions);
    private static readonly Regex CompanyRegex = new(@"<Company>(.*?)</Company>", ROptions);
    private static readonly Regex DescriptionRegex = new(@"<Description>(.*?)</Description>", ROptions);
    private static readonly Regex AuthorsRegex = new(@"<Authors>(.*?)</Authors>", ROptions);
    private static readonly Regex ProductRegex = new(@"<Product>(.*?)</Product>", ROptions);
    private static readonly Regex TargetFrameworkRegex = new(@"<TargetFramework>(.*?)</TargetFramework>", ROptions);
    private static readonly Regex TargetFrameworksRegex = new(@"<TargetFrameworks>(.*?)</TargetFrameworks>", ROptions);
    private static readonly Regex ImportRegex = new(@"<Import Project=""(.*?)""", ROptions);

    #endregion

    public sealed override ImmutableArray<string> FixableDiagnosticIds
    {
        //重写FixableDiagnosticIds,返回分析器的报告Id,表示当前修复器能修复的对应Id
        get { return [FileHeaderAnalyzer.DiagnosticId]; }
    }

    public sealed override FixAllProvider GetFixAllProvider()
    {
        return WellKnownFixAllProviders.BatchFixer;
    }

    public override Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var diagnostic = context.Diagnostics[0];
        var diagnosticSpan = diagnostic.Location.SourceSpan;

        context.RegisterCodeFix(
            CodeAction.Create(
                title: Title,
                createChangedDocument: c => FixDocumentAsync(context.Document, diagnosticSpan, c),
                equivalenceKey: Title),
            diagnostic);

        return Task.CompletedTask;
    }


    private static async Task<Document> FixDocumentAsync(Document document, TextSpan span, CancellationToken ct)
    {
        var root = await document.GetSyntaxRootAsync(ct).ConfigureAwait(false);

        //从项目配置中获取文件头部信息
        var projFilePath = document.Project.FilePath ?? "C:\\test.csproj";//单元测试时没有文件路径,因此使用默认路径

        var projectDirectory = Path.GetDirectoryName(projFilePath);
        var configFilePath = Path.Combine(projectDirectory, ConfigFileName);

        var comment = DefaultComment;

        string? copyright = "MIT";
        string? author = Environment.UserName;
        string? company = string.Empty;
        string? description = string.Empty;
        string? title = document.Project.Name;
        string? version = document.Project.Version.ToString();
        string? product = document.Project.AssemblyName;
        string? file = Path.GetFileName(document.FilePath);
        string? targetFramework = string.Empty;
#pragma warning disable CA1305 // 指定 IFormatProvider
        string? date = DateTime.Now.ToString("yyyy-MM-dd HH:mm:ss");
#pragma warning restore CA1305 // 指定 IFormatProvider


        if (File.Exists(configFilePath))
        {
            comment = File.ReadAllText(configFilePath, System.Text.Encoding.UTF8);
        }

        #region 查找程序集元数据

        // 加载项目文件:
        var text = File.ReadAllText(projFilePath, System.Text.Encoding.UTF8);
        // 载入Import的文件,例如 : <Import Project="..\Version.props" />
        // 使用正则表达式匹配Project:
        var importMatchs = ImportRegex.Matches(text);
        foreach (Match importMatch in importMatchs)
        {
            var importFile = Path.Combine(projectDirectory, importMatch.Groups[1].Value);
            if (File.Exists(importFile))
            {
                text += File.ReadAllText(importFile);
            }
        }

        //存在变量引用的情况,需要解析
        string RawVal(string old, string @default)
        {
            if (old == null)
                return @default;

            //当取得的版本号为变量引用:$(Version)的时候,需要再次解析
            if (version.StartsWith(VarPrefix, StringComparison.Ordinal))
            {
                var varName = old.Substring(2, old.Length - 3);
                var varMatch = new Regex($@"<{varName}>(.*?)</{varName}>", RegexOptions.Singleline).Match(text);
                if (varMatch.Success)
                {
                    return varMatch.Groups[1].Value;
                }
                //未找到变量引用,返回默
                return @default;
            }
            return old;
        }

        var versionMatch = VersionRegex.Match(text);
        var copyrightMath = CopyrightRegex.Match(text);
        var companyMatch = CompanyRegex.Match(text);
        var descriptionMatch = DescriptionRegex.Match(text);
        var authorsMatch = AuthorsRegex.Match(text);
        var productMatch = ProductRegex.Match(text);
        var targetFrameworkMatch = TargetFrameworkRegex.Match(text);
        var targetFrameworksMatch = TargetFrameworksRegex.Match(text);

        if (versionMatch.Success)
        {
            version = RawVal(versionMatch.Groups[1].Value, version);
        }
        if (copyrightMath.Success)
        {
            copyright = RawVal(copyrightMath.Groups[1].Value, copyright);
        }
        if (companyMatch.Success)
        {
            company = RawVal(companyMatch.Groups[1].Value, company);
        }
        if (descriptionMatch.Success)
        {
            description = RawVal(descriptionMatch.Groups[1].Value, description);
        }
        if (authorsMatch.Success)
        {
            author = RawVal(authorsMatch.Groups[1].Value, author);
        }
        if (productMatch.Success)
        {
            product = RawVal(productMatch.Groups[1].Value, product);
        }
        if (targetFrameworkMatch.Success)
        {
            targetFramework = RawVal(targetFrameworkMatch.Groups[1].Value, targetFramework);
        }
        if (targetFrameworksMatch.Success)
        {
            targetFramework = RawVal(targetFrameworksMatch.Groups[1].Value, targetFramework);
        }

        #endregion

        //使用正则表达式替换
        comment = Regex.Replace(comment, @"\{(?<key>[^}]+)\}", m =>
        {
            var key = m.Groups["key"].Value;
            return key switch
            {
                "Product" => product,
                "Title" => title,
                "Version" => version,
                "Date" => date,
                "Author" => author,
                "Company" => company,
                "Copyright" => copyright,
                "File" => file,
                "Description" => description,
                "TargetFramework" => targetFramework,
                _ => m.Value,
            };
        }, RegexOptions.Singleline);

        var headerComment = SyntaxFactory.Comment(comment + Environment.NewLine);
        var newRoot = root?.WithLeadingTrivia(headerComment);
        if (newRoot == null)
        {
            return document;
        }
        var newDocument = document.WithSyntaxRoot(newRoot);

        return newDocument;
    }
}

代码修复器最重要的重载方法
RegisterCodeFixesAsync
,对象
CodeFixContext
包含项目和源代码以及对应分析器的信息:

比如:
CodeFixContext.Document
表示对应的源代码,
CodeFixContext.Document.Project
表示对应项目,
CodeFixContext.Document.Project.FilePath
就是代码中我需要的
*.csproj
的项目文件,

我们取到项目文件,那么我们就可以读取配置在项目文件中的信息,比如
Company
,
Authors
,
Description
,甚至上一篇我们提到的版本号等有用信息,当前我用的正则表达式,当然如果可以你也可以使用
XPath
,
然后取到的有用数据替换模板即可得到想要的注释代码片段了!

比如我的Comment模板文件
Biwen.AutoClassGen.Comment

// Licensed to the {Product} under one or more agreements.
// The {Product} licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information.
// {Product} Author: {Author} Github: https://github.com/vipwan
// {Description}
// Modify Date: {Date} {File}

替换后将会生成如下的代码:

// Licensed to the Biwen.QuickApi under one or more agreements.
// The Biwen.QuickApi licenses this file to you under the MIT license. 
// See the LICENSE file in the project root for more information.
// Biwen.QuickApi Author: 万雅虎 Github: https://github.com/vipwan
// Biwen.QuickApi ,NET9+ MinimalApi CQRS
// Modify Date: 2024-09-07 15:22:42 Verb.cs

最后使用
SyntaxFactory.Comment(comment)
方法生成一个注释的
SyntaxTrivia
并附加到当前的根语法树上,最后返回这个新的
Document
即可!

大功告成,我们来看效果:
image

以上代码就完成了整个代码修复器步骤,最后你可以使用我发布的nuget包体验:

dotnet add package Biwen.AutoClassGen

源代码我发布到了GitHub,欢迎star!
https://github.com/vipwan/Biwen.AutoClassGen

https://github.com/vipwan/Biwen.AutoClassGen/blob/master/Biwen.AutoClassGen.Gen/CodeFixProviders/AddFileHeaderCodeFixProvider.cs

XGBoost模型 0基础小白也能懂(附代码)

原文链接

啥是XGBoost模型

XGBoost 是 eXtreme Gradient Boosting 的缩写称呼,它是一个非常强大的 Boosting 算法工具包,优秀的性能(效果与速度)让其在很长一段时间内霸屏数据科学比赛解决方案榜首,现在很多大厂的机器学习方案依旧会首选这个模型。

XGBoost 在并行计算效率、缺失值处理、控制过拟合、预测泛化能力上都变现非常优秀。本文我们给大家详细展开介绍 XGBoost,包含「算法原理」和「工程实现」两个方面。

关于 XGBoost 的原理,其作者陈天奇本人有一个非常详尽的
Slides
做了系统性的介绍。

Boosted Tree

Boosted Tree(提升树)是一种常用的机器学习方法,属于集成学习的一种。它通过将多个弱学习器(通常是决策树)组合起来,以提升整个模型的预测性能。Boosted Tree的核心思想是通过逐步训练多个决策树,每个树都试图修正前一个树的错误,最终得到一个更强大的模型。

模型:假设我们有
\(K\)
棵树
\(\hat{y_i}=\sum_{k=1}^Kf_k(x_i),f_k\in{F}\)

\(F\)
为包含所有回归树的函数空间。
目标函数:
\(Obj=\sum_{i=1}^nl(y_i,\hat{y_i})+\sum_{k=1}^K\Omega(f_k)\)
\(\sum_{i=1}^nl(y_i,\hat{y_i})\)
是成本函数
\(\sum_{k=1}^K\Omega(f_k)\)
是正则化项,代表树的复杂程度,树越复杂正则化项的值越高(正则化项如何定义我们会在后面详细说)。

当我们讨论决策树或相关的树模型时,通常是启发式的。启发式(heuristic)在机器学习中指的是使用经验法则或近似方法来解决问题,而不保证找到最优解。

Gradient Boosting(如何学习)

在做 GBDT 的时候,我们没有办法使用 SGD(Stochastic Gradient Descent,随机梯度下降),因为它们是树,而非数值向量——也就是说从原来我们熟悉的参数空间变成了函数空间。Gradient Boosting Decision Trees(GBDT)与深度学习或线性模型不同,它的核心不是直接通过参数更新来优化,而是通过构建新的决策树来逐步降低误差。

解决方案:初始化一个预测值,每次迭代添加一个新函数
\((f)\)

1)目标函数变换

根据解决方案可以对目标函数进行初步变形

其中constant是常数项,比如
\(\Omega(f_1),\Omega(f_2)\)
之类的,然后第三行就是考虑平方损失,
\(l(y_i,\hat{y_i})=\frac{1}{2}(y_i-\hat{y_i})^2\)
,代进去就行

所以我们的目的就是找到
\(f(t)\)
使得目标函数最低。然而,经过上面初次变形的目标函数仍然很复杂,目标函数会产生二次项。引入泰勒公式

这图也多少有点问题,是在还没考虑平方损失的地方引入泰勒公式,然后泰勒公式也有问题,后面两项应该是
\(f(x)\)
的一阶导数和二阶导数,所以才是
\(g_i,h_i\)

再把里面的常数项提取出,和
\(f_t\)
无关

2)重新定义树

前面已经用
\(f_t(x)\)
表示一棵树,在本小节,我们重新定义一下树:我们通过叶子结点中的分数向量和将实例映射到叶子结点的索引映射函数来定义树:(有点儿抽象,具体请看下图)

图里有问题,第一个叶子结点权重是+2

3)定义树的复杂程度

其中
\(T\)
才是叶子节点的个数,
\(\gamma\)
是控制树的复杂度的参数,树的叶子节点越多,复杂度越高。通过调节
\(\gamma\)
可以控制模型的复杂度。后面一堆是 L2 Norm正则化系数

4)重新审视目标函数

定义在叶子结点
\(j\)
中的实例的集合为:
\(I_j=\{i|q(x_i)=j\}\)
,这么定义也是为了能够构建出第三个式子,都写成
\(\sum_{j=1}^T\)

同时也会发现上式是
\(T\)
个独立二次函数的和

5)计算叶子结点的值

搞了一大坨,其实也就是先把值换成
\(G_j,H_j\)
,然后用一元二次方程求一个最优值就完了。

下图是前面公式讲解对应的一个实际例子。

这里再次总结一下,我们已经把目标函数变成了仅与
\(G,H,\gamma,\lambda,T\)
这五项已知参数有关的函数,把之前的变量
\(f_t\)
消灭掉了,也就不需要对每一个叶子进行打分了!

那么现在问题来,刚才我们提到,以上这些是假设树结构确定的情况下得到的结果。但是树的结构有好多种,我们应该如何确定呢?

6) 贪婪算法生成树

上一部分中我们假定树的结构是固定的。但是,树的结构其实是有无限种可能的,本小节我们使用贪婪算法生成树:

首先生成一个深度为0的树(只有一个根结点,也叫叶子结点)

对于每棵树的每个叶子结点,尝试去做分裂(生成两个新的叶子结点,原来的叶子结点不再是叶子结点)。在增加了分裂后的目标函数前后变化为(我们希望增加了树之后的目标函数小于之前的目标函数,所以用之前的目标函数减去之后的目标函数):

\(Gain=\frac{1}{2}(\frac{G_L^2}{H_L+\lambda}+\frac{G_R^2}{H_R+\lambda}-\frac{(G_L+G_R)^2}{H_L+H_R+\lambda})-\gamma\)

接下来要考虑的是如何寻找最佳分裂点。

例如,如果
\(x_j\)
是年龄,当分裂点是
\(a\)
的时候的增益
\(Gain\)
是多少?

其实这里对排序后的实例进行从左到右的线性扫描就足以决定特征的最佳分裂点。从左到右依次扫描:一旦数据按照特征值进行了排序,我们从第一个样本开始,依次计算每个可能的分裂点。对于每个分裂点,我们把样本分为“左侧”和“右侧”两个子集,分别计算划分前后目标函数的变化。下面还有别的一些办法

7)如何处理分类型变量

在很多情况下,我们不需要为分类变量设计特殊的处理方式,可以将其转换为one-hot 编码来处理。

\(z_j=
\begin{cases}
0& \text{if x is in category y}\\
1& \text{otherwise}
\end{cases}
\)

如果有太多的分类的话,矩阵会非常稀疏,算法会优先处理稀疏数据。

8) 修剪和正则化

回顾之前的增益,当训练损失减少的值小于正则化带来的复杂度时,增益有可能会是负数,此时就是模型的简单性和可预测性之间的权衡

XGBoost核心原理归纳解析

铺垫了那么多,总算到这里了。XGBoost 也是一个 Boosting 加法模型,每一步迭代只优化当前步中的子模型。


\(m\)
步我们有:
\(F_m(x_i)=F_{m-1}(x_i)+f_m(x_i)\)

\(f_m(x_i)\)
为当前步的子模型。
\(F_{m-1}(x_i)\)
为前
\(m-1\)
个完成训练且固定了的子模型。

泰勒展开

然后去掉常数,带入复杂度(和之前一样)

1)近似算法

基于性能的考量,XGBoost 还对贪心准则做了一个近似版本,简单说,处理方式是「将特征分位数作为划分候选点」。这样将划分候选点集合由全样本间的遍历缩减到了几个分位数之间的遍历。

展开来看,特征分位数的选取还有 global 和 local 两种可选策略:

精确贪心准则:这是默认的精确算法,遍历所有可能分裂点,找到能最大化增益的点。计算量最大,但分裂效果最优。
Global 近似分裂:使用全体样本的特征分位数进行一次性划分,分裂点在所有节点中复用,计算量大幅减少,适合较大的数据集。
Local 近似分裂:在每个节点分裂前根据当前节点的样本重新计算特征分位数,能够更加灵活适应不同节点的特征分布,适合样本分布差异较大的情况。

近似算法的性能与精确贪心算法几乎相同,但大大降低了计算成本。

2)加权分位数

在 XGBoost 中,加权分位数(Weighted Quantile Sketch)用于加速分裂点的寻找过程。加权分位数算法并不是直接根据样本的特征值来划分分位点,而是考虑了
样本的二阶导数(Hessian)
作为权重,从而更好地平衡分裂点的选择,特别是在近似算法中。

令偏导为0易得
\(f_m^*(x_i)=-\frac{g_i}{h_i}\)

3) 列采样与学习率

列采样指的是在构建每棵决策树时,XGBoost 不会使用全部特征,而是随机选择部分特征用于分裂。这种方法源自于随机森林的思想,目的是增加模型的多样性,从而防止过拟合。

学习率在梯度提升树(GBDT)中是一个非常重要的超参数,用于控制每棵树对模型的贡献。学习率可以防止模型更新过快,从而提升模型的稳定性和性能。也叫步长、shrinkage,具体的操作是在每个子模型前(即每个叶节点的回归值上)乘上该系数,不让单颗树太激进地拟合,留有一定空间,使迭代更稳定。XGBoost默认设定为 。

4) 特征缺失与稀疏性

简单说,它的做法是将缺失值和稀疏
\(0\)
值等同视作缺失值,将其「绑定」在一起,分裂节点的遍历会跳过缺失值的整体。这样大大提高了运算效率。

比如在下面的例子中有六种划分情况,XGBoost 会遍历以上6种情况(3个非缺失值的切分点×缺失值的两个方向),最大的分裂收益就是本特征上的分裂收益

XGBoost工程优化

1)并行列块设计

XGBoost 将每一列特征提前进行排序,以块(Block)的形式储存在缓存中,并以索引将特征值和梯度统计量对应起来,每次节点分裂时会重复调用排好序的块。而且不同特征会分布在独立的块中,因此可以进行分布式或多线程的计算。

2)缓存访问优化

特征值排序后通过索引来取梯度
\(g_i,h_i\)
会导致访问的内存空间不一致,进而降低缓存的命中率,影响算法效率。为解决这个问题,XGBoost为每个线程分配一个单独的连续缓存区,用来存放梯度信息。

3) 核外块计算

数据量非常大的情形下,无法同时全部载入内存。XGBoost 将数据分为多个 blocks 储存在硬盘中,使用一个独立的线程专门从磁盘中读取数据到内存中,实现计算和读取数据的同时进行。
为了进一步提高磁盘读取数据性能,XGBoost 还使用了两种方法:

① 压缩 block,用解压缩的开销换取磁盘读取的开销。
② 将 block 分散储存在多个磁盘中,提高磁盘吞吐量。

XGBoost vs GBDT

GBDT 是机器学习算法,XGBoost 在算法基础上还有一些工程实现方面的优化。

GBDT 使用的是损失函数一阶导数,相当于函数空间中的梯度下降;XGBoost 还使用了损失函数二阶导数,相当于函数空间中的牛顿法。

正则化
:XGBoost 显式地加入了正则项来控制模型的复杂度,能有效防止过拟合。

列采样
:XGBoost 采用了随机森林中的做法,每次节点分裂前进行列随机采样。

缺失值
:XGBoost 运用稀疏感知策略处理缺失值,GBDT无缺失值处理策略。

并行高效
:XGBoost 的列块设计能有效支持并行运算,效率更优。

代码实现

需要先下载xgboost

pip install xgboost

代码如下

# 导入所需的库
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes  # 替换为 load_diabetes
from sklearn.metrics import mean_squared_error

# 1. 加载糖尿病数据集
# 这个数据集包含442个样本,10个特征,用于预测一个连续目标变量
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target  # X是特征数据,y是标签(目标变量)

# 2. 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 3. 将数据转换为 DMatrix 格式
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# 4. 设置 XGBoost 模型的超参数
params = {
    'objective': 'reg:squarederror',  # 回归任务使用的目标函数,平方误差
    'max_depth': 3,                   # 决策树的最大深度,控制模型的复杂度
    'eta': 0.05,                       # 学习率,控制每棵树对整体模型的贡献
    'eval_metric': 'rmse' ,            # 评估指标,使用均方根误差(RMSE)
    'lambda': 2,                        # L2 正则化项,防止过拟合
    'alpha': 0.5   # L1 正则化项
}

# 5. 设定训练轮数
num_round = 200  # 训练的轮数,即构建多少棵树

# 6. 定义评估数据集
evals = [(dtrain, 'train'), (dtest, 'eval')]  # (数据集, 数据集名称)

# 7. 训练 XGBoost 模型,加入 early_stopping_rounds早停机制,防止过拟合
bst = xgb.train(params, dtrain, num_round, evals, early_stopping_rounds=10)

# 8. 使用训练好的模型对测试集进行预测
y_pred = bst.predict(dtest)

# 9. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

# 10. 保存训练好的模型
bst.save_model('xgboost_model.json')

# 11. 加载已保存的模型
loaded_bst = xgb.Booster()
loaded_bst.load_model('xgboost_model.json')

# 12. 使用加载的模型进行预测
y_pred_loaded = loaded_bst.predict(dtest)
mse_loaded = mean_squared_error(y_test, y_pred_loaded)
print(f"Mean Squared Error from loaded model: {mse_loaded}")

结果如下

[0]	train-rmse:76.08309	eval-rmse:71.75905
[1]	train-rmse:74.34324	eval-rmse:70.47408
[2]	train-rmse:72.66427	eval-rmse:69.24759
[3]	train-rmse:71.10664	eval-rmse:68.09809
[4]	train-rmse:69.63498	eval-rmse:67.14668
[5]	train-rmse:68.24045	eval-rmse:66.09854
[6]	train-rmse:66.93042	eval-rmse:64.91738
[7]	train-rmse:65.73304	eval-rmse:64.08775
[8]	train-rmse:64.58640	eval-rmse:63.26052
[9]	train-rmse:63.51304	eval-rmse:62.49745
[10]	train-rmse:62.44810	eval-rmse:61.64759
[11]	train-rmse:61.51387	eval-rmse:60.96222
[12]	train-rmse:60.61767	eval-rmse:60.32972
[13]	train-rmse:59.77722	eval-rmse:59.74329
[14]	train-rmse:59.01348	eval-rmse:59.13121
[15]	train-rmse:58.24704	eval-rmse:58.55106
[16]	train-rmse:57.57392	eval-rmse:58.15165
[17]	train-rmse:56.92761	eval-rmse:57.68188
[18]	train-rmse:56.33319	eval-rmse:57.37781
[19]	train-rmse:55.72582	eval-rmse:56.97001
[20]	train-rmse:55.14420	eval-rmse:56.45029
[21]	train-rmse:54.61096	eval-rmse:55.97904
[22]	train-rmse:54.12594	eval-rmse:55.57225
[23]	train-rmse:53.68383	eval-rmse:55.39305
[24]	train-rmse:53.24822	eval-rmse:55.01127
[25]	train-rmse:52.85214	eval-rmse:54.85699
[26]	train-rmse:52.43814	eval-rmse:54.49904
[27]	train-rmse:52.07004	eval-rmse:54.42905
[28]	train-rmse:51.68191	eval-rmse:54.25354
[29]	train-rmse:51.28268	eval-rmse:54.09452
[30]	train-rmse:50.94229	eval-rmse:54.06703
[31]	train-rmse:50.58475	eval-rmse:53.88010
[32]	train-rmse:50.24739	eval-rmse:53.74475
[33]	train-rmse:49.97042	eval-rmse:53.49905
[34]	train-rmse:49.65855	eval-rmse:53.41597
[35]	train-rmse:49.38190	eval-rmse:53.34692
[36]	train-rmse:49.07203	eval-rmse:53.32202
[37]	train-rmse:48.81472	eval-rmse:53.22084
[38]	train-rmse:48.57124	eval-rmse:53.24058
[39]	train-rmse:48.33730	eval-rmse:53.13983
[40]	train-rmse:47.97171	eval-rmse:53.05406
[41]	train-rmse:47.75619	eval-rmse:52.87405
[42]	train-rmse:47.43067	eval-rmse:52.80852
[43]	train-rmse:47.18844	eval-rmse:52.70296
[44]	train-rmse:46.96694	eval-rmse:52.61260
[45]	train-rmse:46.79053	eval-rmse:52.58588
[46]	train-rmse:46.58746	eval-rmse:52.51602
[47]	train-rmse:46.38476	eval-rmse:52.50433
[48]	train-rmse:46.15591	eval-rmse:52.44922
[49]	train-rmse:46.00542	eval-rmse:52.36981
[50]	train-rmse:45.84480	eval-rmse:52.27445
[51]	train-rmse:45.63700	eval-rmse:52.23794
[52]	train-rmse:45.49250	eval-rmse:52.25740
[53]	train-rmse:45.31208	eval-rmse:52.16836
[54]	train-rmse:45.15374	eval-rmse:52.22044
[55]	train-rmse:45.00284	eval-rmse:52.15072
[56]	train-rmse:44.87677	eval-rmse:52.04112
[57]	train-rmse:44.71921	eval-rmse:52.08482
[58]	train-rmse:44.55626	eval-rmse:52.02783
[59]	train-rmse:44.41483	eval-rmse:52.09304
[60]	train-rmse:44.27997	eval-rmse:52.03098
[61]	train-rmse:44.15710	eval-rmse:52.08378
[62]	train-rmse:44.00683	eval-rmse:52.02136
[63]	train-rmse:43.84878	eval-rmse:52.06178
[64]	train-rmse:43.74180	eval-rmse:52.06495
[65]	train-rmse:43.59775	eval-rmse:52.08875
[66]	train-rmse:43.44009	eval-rmse:52.20317
[67]	train-rmse:43.29717	eval-rmse:52.14245
[68]	train-rmse:43.10437	eval-rmse:52.15464
[69]	train-rmse:43.00768	eval-rmse:52.17011
[70]	train-rmse:42.87951	eval-rmse:52.11852
[71]	train-rmse:42.79951	eval-rmse:52.21249
[72]	train-rmse:42.66769	eval-rmse:52.22331
Mean Squared Error: 2727.2736118611274
Mean Squared Error from loaded model: 2727.2736118611274

train-rmse是训练集上的预测值与真实值之间的误差。eval-rmse是模型在测试集上的 RMSE

分析下早停机制下最后的数据,42.66769 表示在训练集上,模型的预测误差为 42.67。RMSE 越低,表示模型在训练集上拟合得越好。52.22 说明模型在测试集上的预测误差明显高于训练集,表明模型可能存在一定的过拟合问题,模型在训练集上表现良好,但在新数据(测试集)上的泛化能力不如在训练集上的表现。

manim
中绘制一个角度其实就是绘制两条直线,本篇介绍的不是绘制角度,而是绘制
角度标记

对于
锐角

钝角

角度标记
是一个弧,弧的度数与角的度数一样;

对于
直角

角度标记
是一个垂直的拐角。

manim
中关于
角度标记
的模型主要有3个:

  1. Angle
    :根据两条直线绘制角度标记
  2. RightAngle
    :根据两条
    互相垂直
    的线绘制直角标记
  3. Elbow
    :不受限于直线,任意方向和大小的直角标记

其中,
RightAngle
模块继承自
Angle

角度标记
的主要作用是在动画中标记出一些特殊角度,更好的展示数学定理的证明过程。

1. 主要参数

Angle
模块是通用的角度标记,它的主要参数有:

参数名称 类型 说明
line1 Line 构成角度的第一条线
line2 Line 构成角度的第二条线
radius float 角度标记的半径
quadrant Point2D 此参数控制角度标记显示在哪个位置
other_angle bool True
:顺时针从line1到line2
False
:逆时针从line1到line2
dot bool 是否在角度标记中显示一个点
dot_radius float 点的半径
dot_distance float 点到圆弧(角度标记)的相对距离
dot_color Color 点的颜色
elbow bool 是否显示成直角的形状

后面在使用示例中演示这些参数的使用。

RightAngle
模块继承自
Angle
,除了上面
Angle
的参数之外,还有一个自己特有的参数。

参数名称 类型 说明
length float 标记的大小

Elbow
模块与上面两个不一样,它不是根据两条线来生成角度标记。

参数名称 类型 说明
width float 标记的大小
angle float 标记朝向那个方向

Elbow
的形状和
RightAngle
是一样的。

2. 主要方法

Angle
模块的方法主要有3个:

名称 说明
from_three_points 根据三个点来生成角度标记
get_lines 获取生成角度的两条线
get_value 获取角度的值

一般我绘制一个角度标记时,都是根据两条相交的线来确定角度位置的。

通过
from_three_points
方法,可以根据任意3个点来生成一个角度标记。

A = np.array([2, -1, 0])
B = np.array([0, 0, 0])
C = np.array([1, 1, 0])

angle = Angle.from_three_points(A, B, C)

函数的参数是
A

B

C
三个点,

  • A:角度的起点
  • B:角度的顶点
  • C:角度的终点

生成的角度以
B
为顶点,从点A到点C逆时针旋转。

方法
get_lines
可获取构成角度的两条线,也就是上图中的
BA

BC
两条线。

lines = angle.get_lines()

最后,
get_value
方法,可以实时得到当前角度的值,值可以是度数,也可以是弧度。

print(f"角度:{angle.get_value(degrees=True)}")
print(f"弧度:{angle.get_value()}")

# 运行结果
角度:71.56505117707799
弧度:1.2490457723982544

3. 使用示例

3.1. 角度大小

因为角度标记
Angle
是一个弧形,所以角度的大小通过参数
radius
(半径)来调整。

line1 = Line(LEFT, RIGHT)
line2 = Line(DOWN, UP)

Angle(line1, line2)
Angle(line1, line2, radius=0.2)
Angle(line1, line2, radius=0.5)
Angle(line1, line2, radius=0.8)

3.2. 角度位置

角度标记的位置由两个参数来控制,
quadrant

other_angle

quadrant
参数一共有四种选项:
(1, 1)

(1, -1)

(-1, 1)

(-1, -1)

这个参数分两部分,分别表示角度标记在
Line1
上的
起点位置
和在
Line2
上的
终点位置

比如下面相交的两条直线,
quadrant
的第一个值和第二个值分别在
Line1

Line2
上的位置如图。

other_angle
默认为
False
,表示绘制角度时从
Line1

Line2

设置
other_angle
为True时,绘制角度的顺序相反,从
Line2

Line1

l1 = Line(
    LEFT + (1 / 3) * UP,
    RIGHT + (1 / 3) * DOWN,
)
l2 = Line(
    DOWN + (1 / 3) * RIGHT,
    UP + (1 / 3) * LEFT,
)

Angle(l1, l2)
Angle(l1, l2, quadrant=(1, -1))
Angle(l1, l2, quadrant=(-1, 1))
Angle(l1, l2, quadrant=(-1, -1))
Angle(l1, l2, other_angle=True)
Angle(l1, l2, quadrant=(1, -1), other_angle=True)
Angle(l1, l2, quadrant=(-1, 1), other_angle=True)
Angle(l1, l2, quadrant=(-1, -1), other_angle=True)

3.3. 角度中的点

Angle
中可以加一个点的标记,当一个画面中有很多角度的时候,这个标记可以帮助我们区分不同的角。

通过
dot_radius

dot_distance

dot_color
等参数,可以调整点的大小,位置和颜色。

line1 = Line(
    LEFT / 2,
    RIGHT / 2,
)
line2 = Line(
    DOWN / 2,
    UP / 2,
)

Angle(
    line1,
    line2,
    dot=True,
    dot_radius=0.02,
    dot_color=RED,
)
Angle(
    line1,
    line2,
    dot=True,
    dot_radius=0.08,
    dot_color=BLUE,
)
Angle(
    line1,
    line2,
    dot=True,
    dot_distance=0.2,
    dot_color=GREEN,
)
Angle(
    line1,
    line2,
    dot=True,
    dot_distance=0.8,
    dot_color=YELLOW,
)

3.4. 直角标记

最后,还有一个特殊的角度标记--直角标记。

manim
中提供了2个模块来标记直角,
RightAngle

Elbow

它们的显示效果差不多,区别在于,
RightAngle
需要根据两条线来生成,


Elbow
更加灵活一些,它可以在任意位置生成直角标记。

line1 = Line(
    LEFT / 2,
    RIGHT / 2,
)
line2 = Line(
    DOWN / 2,
    UP / 2,
)

RightAngle(
    line1,
    line2,
    length=0.2,
)
RightAngle(
    line1,
    line2,
    length=0.4,
)
RightAngle(
    line1,
    line2,
    quadrant=(1, -1),
)
RightAngle(
    line1,
    line2,
    quadrant=(-1, -1),
)
Elbow(width=0.5)
Elbow(width=1)
Elbow(width=1, angle=PI / 2)
Elbow(width=1, angle=5 * PI / 4)

4. 附件

文中完整的代码放在网盘中了(
angle.py
),

下载地址:
完整代码
(访问密码: 6872)

在电商行业竞争尤为激烈的当下,除了打价格战外,如何有效的控制成本,是每个从业者都在思考的问题

IDM-VTON
是一个
AI
虚拟换装工具,旨在帮助服装商家解决约拍模特导致的高昂成本问题,只需一张服装图片,就可以生成各种身穿该服装的模特,大大简化了传统的产品展示过程

IDM-VTON
最新中文版:

百度网盘:
https://pan.baidu.com/s/1fMTJHrLGr6-CWoFv1LmkDw?pwd=gw15

IDM-VTON
采用了先进的图像识别和视觉检测算法,在用户上传服装图片和模特姿势图后,能在短时间内生成多张商业用级的照片,在操作界面可以直观地看到服装的实际穿着效果

IDM-VTON
为服装商家提供了一个创新的解决方案,不仅节省了场地、拍摄、后期等费用,在降低成本的同时,还加速了新服装上市的效率,提高了商品的竞争力

技术优势

1.
数据学习:通过学习大量的真实穿着图片,扩散模型能够理解衣物在不同体型和肤色上的表现

2.
个性化适配:利用深度学习技术,模型能够根据用户的体型、肤色等特征,生成个性化的试穿效果

3.
环境适应:通过先进的图像处理技术,
IDM-VTON
能够适应不同的光线和背景,确保试穿效果的真实性

应用场景

·
电子商务:提高在线购物体验,让用户在购买前能够更直观地看到衣物的穿着效果

·
时尚设计:帮助设计师快速预览设计效果,加速设计流程

·
个性化推荐:根据用户身材和偏好数据,精准匹配合适的服装

·
社交媒体:博主可以在社交账号上尝试不同的穿衣风格,增强粉丝互动性和娱乐性

使用方法

1.
上传人物图片

2.
上传服装图片

3.
选择服装所在的部位(上半身,下半身,鞋类等)

4.
想要改变服装的大小(如:长袖变短袖),去掉“自动生成遮罩”的勾选,点下画笔后,在上传的人物图片处手动绘制遮罩的面积

5.
提示词(可选),帮助程序更精准的识别衣服,如
Dress
(连衣裙)

6.
点击“试穿运行”开始换装

在控制台可以查看当前的处理进度,程序执行完毕会输出信息“图片处理完成”

稍微等待一下,图片就生成好了,模特的试衣效果非常惊艳!(从
Output
界面右上角下载图片)

注意事项

①项目安装路径不要包含中文

②推荐使用
GTX1060
以上显卡运行此项目

③使用过程中若不慎关闭软件后台,请重新打开,并刷新网页

一位离职的前端同事,最近接了个
React Native
的活儿,遇到许多搞不定的问题,于是找到我帮忙“补课”(没有系统的学习
React Native
,也不具备原生
Android

iOS
开发基础知识)。

此前带过另一位前端同事入门
React Native
开发,有段时间甚至一天得花一两个小时,专门视频连线手把手传帮带,帮忙解决各种疑难杂症。

这可能是纯前端开发小伙伴,刚开始接触 APP 开发最头痛的一段。不管是
React Native
还是
Flutter
,虽然都号称跨平台开发,但如果没有相应平台(
Android

iOS
等)的开发基础,还是很难深入的。一般会卡在下面这些问题上:

  1. 开发环境的搭建和修改,可能涉及到
    JDK

    Maven

    Ruby

    Gems

    CocoaPods
    等;
  2. Android

    iOS
    项目的一些配置,如
    build.gradle

    AndroidManifest.xml

    Info.plist

    Podfile
    等;
  3. Android Studio

    XCode
    的使用;
  4. Android

    iOS
    原生代码的修改,包括
    Java
    /
    Kotlin

    Objective-C
    /
    Swift
  5. C/C++
    源码编译问题的处理;
  6. 命令行工具的使用,如
    Shell
    脚本编写、
    ADB
    的使用、
    react-native run-*
    命令的使用等;
  7. 调试工具的使用,如
    Flipper

    DevTools

    Reactotron
    等;
  8. 各种原生相关的三方库依赖问题处理;
  9. 由各种缓存所引发的问题处理;

另外,可能还会遇到诸如
切换/点击响应很慢、画面卡顿
等问题,
感觉上没有原生的看上去丝滑
。很多人会归咎于非原生,然而大部分时候并非如此。很可能是由于不明白相应的原理,导致写的代码执行效率太差。这是很大一部分纯前端小伙伴的通病。

时间精力允许的情况下,会在这里记录一些相应的问题解决方法。