2024年1月

Ef Core花里胡哨系列(10) 动态起来的 DbContext

我们知道,
DbContext
有两种托管方式,一种是
AddDbContext

AddDbContextFactory
,但是呢他们各有优劣,例如工厂模式下性能更好呀等等。那么,我们能否自己托管
DbContext
呢?

Github Demo:
动态起来的 DbContext

场景:
结合我们之前的文章 [Ef Core花里胡哨系列(5) 动态修改追踪的实体、动态查询] 假设一个应用内有很多的子应用,且都需要更新追踪的动态实体,那么很多表在重置
OnModelCreating
的时候将会非常的慢。主要体现在
modelBuilder.Model.AddEntityType(type)
,每个实体都需要花费一小段时间,几百个实体就会按分钟计算了,而且还会数据库操作产生一定的影响。

我们先实现一个基础的
DbContext
用来添加一些通用的实体以及处理动态实体的逻辑,每次需要重置DbContext的时候,都会获取最新的动态实体进行更新:

public class DbContextBase : DbContext
{
    public DbSet<User> Users { get; set; } = null!;
    public DbSet<Department> Departments { get; set; } = null!;

    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        optionsBuilder.UseSqlite("Data Source=sample.db");
        optionsBuilder.ReplaceService<IModelCacheKeyFactory, MyModelCacheFactory>();

        base.OnConfiguring(optionsBuilder);
    }

    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        var name = GetType().Name.Split("_");
        if (name.Length > 1)
        {
            foreach (var item in FormTypeBuilder.GetAppTypes(name[0]).Where(item => modelBuilder.Model.FindEntityType(item.Value) is null))
            {
                modelBuilder.Model.AddEntityType(item.Value);
            }
        }

        base.OnModelCreating(modelBuilder);
    }
}

然后实现一个动态
DbContext
的生成器,用于针对不同的
AppId
生成不同的
DbContext

public class DbContextGenerator
{
    private readonly ConcurrentDictionary<string, Type> _contextTypes = new()
    {
    };

    public Type GetOrCreate(string appId)
    {
        if (!_contextTypes.TryGetValue(appId, out var value))
        {
            value = GeneratorDbContext(appId);
            _contextTypes.TryAdd(appId, value);
        }

        return value;
    }

    public Type GeneratorDbContext(string appId)
    {
        var assemblyName = new AssemblyName("__RuntimeDynamicDbContexts");
        var assemblyBuilder = AssemblyBuilder.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run);
        var moduleBuilder = assemblyBuilder.DefineDynamicModule("__RuntimeDynamicModule");
        var typeBuilder = moduleBuilder.DefineType($"{appId.ToLower()}_DbContext", TypeAttributes.Public | TypeAttributes.Class, typeof(DbContextBase));
        var constructorBuilder = typeBuilder.DefineConstructor(MethodAttributes.Public, CallingConventions.Standard, new Type[] { });
        var ilGenerator = constructorBuilder.GetILGenerator();
        ilGenerator.Emit(OpCodes.Ldarg_0);
        ilGenerator.Emit(OpCodes.Call, typeof(DbContextBase).GetConstructor(Type.EmptyTypes));
        ilGenerator.Emit(OpCodes.Ret);
        typeBuilder.CreateType();
        var dbContextType = assemblyBuilder.GetType($"{appId.ToLower()}_DbContext");
        return dbContextType;
    }
}

然后我们需要实现一个
DbContext
的容器用于管理我们生成的
DbContext
,以及负责初始化:

public class DbContextContainer : IDisposable
{
    private readonly DbContextGenerator _generator;
    private readonly Dictionary<string, DbContext> _contexts = new();

    public DbContextContainer(DbContextGenerator generator)
    {
        _generator = generator;
    }

    public DbContext Get(string appId)
    {
        if (!_contexts.TryGetValue(appId, out var context))
        {
            context = (DbContext)Activator.CreateInstance(_generator.GetOrCreate(appId))!;
            _contexts[appId] = context;
        }

        return context;
    }

    public void Dispose()
    {
        _contexts.Clear();
    }
}

DbContextContainer
的生命周期即
DbContext
的生命周期,因为
DbContext
的缓存是共享的,所以我们也不用担心一些性能问题。

使用时也非常简单,我们只需要在
DbContextContainer
取出我们对应
AppId

DbContext
进行操作就可以了:

public class DynamicController : ApiControllerBase
{
    private readonly DbContextContainer _container;

    public DynamicController(DbContextContainer container)
    {
        _container = container;
    }

    [HttpGet]
    public async Task<IActionResult> GetCompanies()
    {
        var res = await _container.Get("test1").DynamicSet(typeof(Company)).ToDynamicListAsync();

        return Ok(res);
    }

    [HttpGet]
    public async Task<IActionResult> AddCompany()
    {
        var db = _container.Get("test1");
        FormTypeBuilder.AddDynamicEntity("test1", "Companies", typeof(Company));
        db.UpdateVersion();

        return Ok();
    }
}

DataHub 更青睐于PythonAPI对血缘与元数据操作

image

虽然开源源码都有Java示例和Python示例:但是这个API示例数量简直是1:100的差距!!不知为何,项目使用Java编写,示例推送偏爱Python的官方;;;搞不懂也许就是开源官方团队写脚本的是Python一哥吧!

显然DataHub 更青睐于Python API对血缘与元数据操作

Java示例:屈指可数

image

Python示例 就是海量丰富了

image

目前Java示例就两个好用:

DatasetAdd.java 和 DataJobLineageAdd.java

(一)DatasetAdd.java 是设置元数据到Datahub


 private static void extractedTable() {
    String token = "";
    try (RestEmitter emitter =
        RestEmitter.create(b -> b.server("http://10.130.1.49:8080").token(token))) {
      MetadataChangeProposal dataJobIOPatch =
              new DataJobInputOutputPatchBuilder()
                      .urn(
                              UrnUtils.getUrn(
                                      "urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_456)")) //这个是使用的JOB输入表级:中转处理任务
                      .addInputDatasetEdge(
                              DatasetUrn.createFromString(
                                      "urn:li:dataset:(urn:li:dataPlatform:mysql,JDK-Name,PROD)")) //这个是使用的JOB输入表级:入口节点
                      .addOutputDatasetEdge(
                              DatasetUrn.createFromString(
                                      "urn:li:dataset:(urn:li:dataPlatform:hive,JDK-Name,PROD)")) //这个是使用的JOB输入表级:出口节点
                      .addInputDatajobEdge(
                              DataJobUrn.createFromString(
                                      "urn:li:dataJob:(urn:li:dataFlow:(airflow,dag_abc,PROD),task_123)")) // 这里定义字段列级别的血缘关系:中转处理任务
                      .addInputDatasetField(
                              UrnUtils.getUrn(
                                      "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:hive,JDK-Name,PROD),userName)")) // 列字段的入口节点
                      .addOutputDatasetField(
                              UrnUtils.getUrn(
                                      "urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mysql,JDK-Name,PROD),userName)")) // 列字段的出口节点
                      .build();
      Future<MetadataWriteResponse> response = emitter.emit(dataJobIOPatch);
      System.out.println(response.get().getResponseContent());
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println("Failed to emit metadata to DataHub"+ e.getMessage());
      throw new RuntimeException(e);
    }
  }

(二)DataJobLineageAdd.java 是设置元数据带JOB任务的血缘到Datahub

 public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // Create a DatasetUrn object from a string
    DatasetUrn datasetUrn = UrnUtils.toDatasetUrn("hive", "JDK-Mysql", "PROD");
    // Create a CorpuserUrn object from a string
    CorpuserUrn userUrn = new CorpuserUrn("ingestion");
    // Create an AuditStamp object with the current time and the userUrn
    AuditStamp lastModified = new AuditStamp().setTime(1640692800000L).setActor(userUrn);

    // Create a SchemaMetadata object with the necessary parameters
    SchemaMetadata schemaMetadata =
        new SchemaMetadata()
            .setSchemaName("customer")
            .setPlatform(new DataPlatformUrn("hive"))
            .setVersion(0L)
            .setHash("")
            .setPlatformSchema(
                SchemaMetadata.PlatformSchema.create(
                    new OtherSchema().setRawSchema("__RawSchemaJDK__")))
            .setLastModified(lastModified);

    // Create a SchemaFieldArray object
    SchemaFieldArray fields = new SchemaFieldArray();

    // Create a SchemaField object with the necessary parameters
    SchemaField field1 =
        new SchemaField()
            .setFieldPath("mysqlId")
            .setType(
                new SchemaFieldDataType()
                    .setType(SchemaFieldDataType.Type.create(new StringType())))
            .setNativeDataType("VARCHAR(50)")
            .setDescription(
                "Java用户mysqlId名称VARCHAR")
            .setLastModified(lastModified);
    fields.add(field1);

    SchemaField field2 =
        new SchemaField()
            .setFieldPath("PassWord")
            .setType(
                new SchemaFieldDataType()
                    .setType(SchemaFieldDataType.Type.create(new StringType())))
            .setNativeDataType("VARCHAR(100)")
            .setDescription("Java用户密码VARCHAR")
            .setLastModified(lastModified);
    fields.add(field2);

    SchemaField field3 =
        new SchemaField()
            .setFieldPath("CreateTime")
            .setType(
                new SchemaFieldDataType().setType(SchemaFieldDataType.Type.create(new DateType())))
            .setNativeDataType("Date")
            .setDescription("Java用户创建时间Date")
            .setLastModified(lastModified);
    fields.add(field3);

    // Set the fields of the SchemaMetadata object to the SchemaFieldArray
    schemaMetadata.setFields(fields);

    // Create a MetadataChangeProposalWrapper object with the necessary parameters
    MetadataChangeProposalWrapper mcpw =
        MetadataChangeProposalWrapper.builder()
            .entityType("dataset")
            .entityUrn(datasetUrn)
            .upsert()
            .aspect(schemaMetadata)
            .build();

    // Create a token
    String token = "";
    // Create a RestEmitter object with the necessary parameters
    RestEmitter emitter = RestEmitter.create(b -> b.server("http://10.130.1.49:8080").token(token));
    // Emit the MetadataChangeProposalWrapper object
    Future<MetadataWriteResponse> response = emitter.emit(mcpw, null);
    // Print the response content
    System.out.println(response.get().getResponseContent());
    emitter.close();
  }

我们大多数时候不是需要带JOb的血缘关系

例如: 直接是表与表之间有关系

image

python脚本这里不赘述:太多示例了。重点是Java这边怎么实现这个东西

参考DataJobLineageAdd示例:他这里核心分析

(1.1) 就是把血缘关系提交到Datahub

代码====>

Future<MetadataWriteResponse> response = emitter.emit(dataJobIOPatch);
System.out.println(response.get().getResponseContent());

分析====>

emitter.emit(?) 这个方法就是提交血缘关系;
里面填充好的就是血缘关系数据吧:示例是dataJobIOPatch 就是携带JOB的血缘关系数据;

因为他初始化变量的时候就是DataJobInputOutputPatchBuilder构建的,见名知意就是JOb相关的

 MetadataChangeProposal dataJobIOPatch =
              new DataJobInputOutputPatchBuilder()......

所以我们是否是MetadataChangeProposal的实现替换为别的方式:找找源码

类比思想:看看同样的builder实现的地方有别的实现没有

image

挑出了看着很像的实现:猜一下肯定是和JOB没关系了,而且是直接操作元数据的关系的
DatasetPropertiesPatchBuilder
EditableSchemaMetadataPatchBuilder
UpstreamLineagePatchBuilder

SO 简单改造一下 取名为:DataSetLineageAdd

@Slf4j
class DataSetLineageAdd {

  private DataSetLineageAdd() {}

  /**
   * Adds lineage to an existing DataJob without affecting any lineage
   *
   * @param args
   * @throws IOException
   * @throws ExecutionException
   * @throws InterruptedException
   */
  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    extractedTable();
  }

  private static void extractedRow() {
   // 没有java版本。。。。
  }
  private static void extractedTable() {
    String token = "";
    try (RestEmitter emitter =
        RestEmitter.create(b -> b.server("http://10.130.1.49:8080").token(token))) {
      MetadataChangeProposal mcp =
              new UpstreamLineagePatchBuilder().
                      urn(UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:mysql,ctmop.assets_info,PROD)"))
                      .addUpstream(DatasetUrn.createFromString(
                                      "urn:li:dataset:(urn:li:dataPlatform:mysql,ctmop.operation_fee_info,PROD)"), DatasetLineageType.TRANSFORMED)
                      .build();
      Future<MetadataWriteResponse> response = emitter.emit(mcp);
      System.out.println(response.get().getResponseContent());
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println("Failed to emit metadata to DataHub"+ e.getMessage());
      throw new RuntimeException(e);
    }
  }
}

表级血缘用JAVA代码就实现了;这是一个简单的Demo;更深入的拓展需要自行挖掘!!!

image

有人说表级血缘太low了,能不能做到JAVA的字段级血缘关系呢。。。。当然没问题

看我示例用的这个:UpstreamLineagePatchBuilder 他意思没有指定表级还是字段级;API 方法 addUpstream 和 urn都是泛用型,理论上都OK

分析:
表级的元数据: urn:li:dataset:(urn:li:dataPlatform:mysql,ctmop.assets_info,PROD) 这个样子
列级的元数据: urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:mysql,JDK-Name,PROD),userName) 这个样子

发现规律了:表级外面包一层urn:li:schemaField:XXXX,字段名 那不就是列字段了,。。。。。浅谈捯饬结束!!!

有问题还望大家指正:!!!

1、准备材料

开发板(
STM32F407G-DISC1

ST-LINK/V2驱动
STM32CubeMX软件(
Version 6.10.0

keil µVision5 IDE(
MDK-Arm

逻辑分析仪
nanoDLA

2、实验目标

使用STM32CubeMX软件配置STM32F407开发板
使用基本定时器TIM6实现每500ms控制绿灯状态变化一次,基本定时器TIM7实现每1s控制红灯状态变化一次

3、定时器概述

STM32F407拥有2个基础定时器、10个通用定时器和2个高级定时器,14个定时器全部挂载在APB1和APB2时钟总线上,APB2时钟总线时钟频率最高可达84MHz,APB1时钟总线时钟频率最高可达42MHz,除TIM2和TIM5为32位外,其余定时器全部为16位,其结构框图如下图所示
(注释1)

不同的定时器具有不同的特性,有些定时器的计数器长度为16位,有些则为32位;有些定时器可以递增、递减或递增/递减计数,但有些定时器只能递增计数;有些定时器可以产生DMA请求,有些则不可以;另外定时器捕获/比较通道数量也不一样;具体特性区别请看下表
(注释2)

4、实验流程

4.0、前提知识

基本定时器由TIM6和TIM7组成,计数器为16位,内部结构较为简单,只有定时器的基本功能,可以做定时或驱动DAC,本实验暂不讨论DAC,只用定时功能,如下图所示为基本定时器框架
(注释3)
,基本定时器的时钟来源为APB1 Timer clocks,当通过控制器启动基本定时器TIM6/7时,时钟信号经过PSC预分频器将时钟分频,然后以分频后的时钟频率增加计数器的值,当计数器达到自动重载寄存器设置的值之后,产生溢出

4.1、CubeMX相关配置

请先阅读“
STM32CubeMX 工程建立
”实验3.4.1小节配置RCC和SYS

4.1.1、时钟树配置

基本定时器涉及到定时时间的问题,而TIM6/7的时钟来源自APB1 Timer clocks,因此需要先设置时钟树,知道APB1 Timer clocks的频率,才可以计算基本定时器的溢出时间

如下图所示,时钟树上所有总线频率均设置为了STM32F4能达到的最高频率,此时APB1 Timer clocks=84MHz

4.1.2、外设参数配置

在Pinout & Configuration页面左侧功能分类栏目中点开Timers栏目,单击栏目下的TIM6和TIM7

在页面中间TIM6/7 Mode and Configuration 中勾选Activated激活基本定时器,
One Pulse Mode为单次定时模式
,勾选该模式则定时器只触发一次,默认定时器为连续触发,触发完一次后自动重载ARR中设置的值重新计数

在页面中间Configuration栏中可以设置基本定时器的参数,包括预分频器系数、计数模式、ARR寄存器的值和预装载值自动重载,通过这些参数的设置可以决定基本定时器的溢出时间,
APB1 Timer clocks=84MHz,PSC=8399,ARR=4999,此时可计算溢出时间为(PSC+1)(ARR+1)/APB1 Timer clocks=0.5秒=500毫秒,则每500ms定时器产生一次溢出,ARR设置为9999则定时器1s溢出一次

参数auto re-load preload可以选择使能或不使能,如果不使能该参数,则在使用__HAL_TIM_SET_AUTORELOAD()函数动态修改基本定时器ARR参数值时,修改的值会立马生效;而如果使能该参数,则修改的值会在当前计数溢出之后下次得到修改

Trigger Output (TRGO) Parameters一般是用来设置用作其他外设的触发源的,比如将Trigger Event Selection选择为Update Event,然后在其他外设比如ADC中配置外部触发源时选择该定时器的触发事件(如果可以的话),这样在定时器产生Update Event时就可以启动外设,实现用定时器来控制外设启动的功能

上述配置如下图所示

4.1.3、外设中断配置

基本定时器的触发有三种模式①轮询方式②中断方式③DMA方式,这里只介绍前两种方式

①对于轮询方式,当前设置已经足够,只需要在生成的程序中使用HAL_TIM_Base_Start(&htim6)启动基本定时器,然后不断轮询计数值或UEV事件标志来判断是否发生了计数溢出

②中断方式是基本定时器最常用的方式,在Pinout & Configuration页面左侧功能分类栏目中点开NVIC栏目,然后选择合适的中断优先级并勾选基本定时器6和7的中断使能

4.2、生成代码

请先阅读“
STM32CubeMX 工程建立
”实验3.4.3小节配置Project Manager

单击页面右上角GENERATE CODE生成工程

4.2.1、外设初始化调用流程

在工程代码主函数main()中调用MX_TIM6_Init()函数对基本定时器TIM6参数进行了配置

在该MX_TIM6_Init()函数中调用了HAL_TIM_Base_Init()对定时器进行了初始化

然后在HAL_TIM_Base_Init()函数中调用了HAL_TIM_Base_MspInit()函数对TIM6时钟和中断设置/使能

TIM7初始化流程类似,具体定时器TIM6初始化流程如下图所示

4.2.2、外设中断调用流程

激活了基本定时器并启动TIM6/7全局中断之后,会在stm32f4xx_it.c中新增TIM6/7的中断服务函数TIM6_DAC_IRQHandler()和TIM7_IRQHandler()

该函数均调用HAL库的定时器中断统一处理函数HAL_TIM_IRQHandler(),该函数通过一系列的判断最终得出基本定时器目的为周期回调(注释4),因此最终调用周期回调函数HAL_TIM_PeriodElapsedCallback(),该函数为虚函数

TIM7中断调用流程类似,具体定时器TIM6中断调用流程如下图所示

4.2.3、添加其他必要代码

重新在tim.c中实现周期回调函数HAL_TIM_PeriodElapsedCallback(),当定时器TIM6溢出则翻转GREEN_LED引脚状态,当定时器TIM7溢出则翻转RED_LED引脚状态,具体代码如下图所示

源代码如下

/*基本定时器周期回调函数*/
void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim)
{
    if(htim == &htim6)
    {
        HAL_GPIO_TogglePin(GREEN_LED_GPIO_Port, GREEN_LED_Pin) ;
    }
 
    if(htim == &htim7)
    {
        HAL_GPIO_TogglePin(RED_LED_GPIO_Port, RED_LED_Pin) ;
    }
}

在主函数中以中断方式启动基本定时器TIM6/7,具体代码如下图所示

5、常用函数

/*以轮询工作方式启动定时器*/
HAL_StatusTypeDef HAL_TIM_Base_Start(TIM_HandleTypeDef *htim)
 
/*停止轮询工作方式的定时器*/
HAL_StatusTypeDef HAL_TIM_Base_Stop(TIM_HandleTypeDef *htim)
 
/*以中断工作方式启动定时器*/
HAL_StatusTypeDef HAL_TIM_Base_Start_IT(TIM_HandleTypeDef *htim)
 
/*停止中断工作方式的定时器*/
HAL_StatusTypeDef HAL_TIM_Base_Stop_IT(TIM_HandleTypeDef *htim)
 
/*定时器周期回调子函数*/
void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim)

6、烧录验证

6.1、具体步骤

“激活基本定时器TIM6/7 -> 配置合适参数实现500ms/1s定时时间 -> 勾选TIM6/7全局中断 -> 在生成的代码中重新实现周期回调函数HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim) -> 在该回调函数中判断定时器溢出来源 -> 根据来源编程对应响应 -> 在主函数主循环前使用HAL_TIM_Base_Start_IT(&htim6/7)以中断方式启动定时器”

6.2、实验现象

烧录程序,观察现象为绿灯每隔500ms状态改变一次,红灯每隔1s状态改变一次

使用逻辑分析仪监测PD12/14引脚状态,可以看出TIM6每500ms翻转一次PD12引脚状态,TIM7每1000ms翻转一次PD14引脚状态

7、注释详解

注释1
:图片来源STM32F407VGT6 Datasheet DS8626
注释2
:图片来源STM32 CubeMX 学习:003-定时器(其原表有错误)
注释3
:图片来源STM32F4xx中文参考手册
注释4
:具体过程请参看 HAL_TIM_IRQHandler(TIM_HandleTypeDef *htim) 函数详解

更多内容请浏览
OSnotes的CSDN博客

本文深入探讨了 PyTorch 中 Autograd 的核心原理和功能。从基本概念、Tensor 与 Autograd 的交互,到计算图的构建和管理,再到反向传播和梯度计算的细节,最后涵盖了 Autograd 的高级特性。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人

file

一、Pytorch与自动微分Autograd

自动微分(Automatic Differentiation,简称 Autograd)是深度学习和科学计算领域的核心技术之一。它不仅在神经网络的训练过程中发挥着至关重要的作用,还在各种工程和科学问题的数值解法中扮演着关键角色。

1.1 自动微分的基本原理

在数学中,微分是一种计算函数局部变化率的方法,广泛应用于物理、工程、经济学等领域。自动微分则是通过计算机程序来自动计算函数导数或梯度的技术。

自动微分的关键在于将复杂的函数分解为一系列简单函数的组合,然后应用链式法则(Chain Rule)进行求导。这个过程不同于数值微分(使用有限差分近似)和符号微分(进行符号上的推导),它可以精确地计算导数,同时避免了符号微分的表达式膨胀问题和数值微分的精度损失。

import torch

# 示例:简单的自动微分
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x + 1
y.backward()

# 打印梯度
print(x.grad)  # 输出应为 2*x + 3 在 x=2 时的值,即 7

1.2 自动微分在深度学习中的应用

在深度学习中,训练神经网络的核心是优化损失函数,即调整网络参数以最小化损失。这一过程需要计算损失函数相对于网络参数的梯度,自动微分在这里发挥着关键作用。

以一个简单的线性回归模型为例,模型的目标是找到一组参数,使得模型的预测尽可能接近实际数据。在这个过程中,自动微分帮助我们有效地计算损失函数关于参数的梯度,进而通过梯度下降法更新参数。

# 示例:线性回归中的梯度计算
x_data = torch.tensor([1.0, 2.0, 3.0])
y_data = torch.tensor([2.0, 4.0, 6.0])

# 模型参数
weight = torch.tensor([1.0], requires_grad=True)

# 前向传播
def forward(x):
    return x * weight

# 损失函数
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# 计算梯度
l = loss(x_data, y_data)
l.backward()

print(weight.grad)  # 打印梯度

1.3 自动微分的重要性和影响

自动微分技术的引入极大地简化了梯度的计算过程,使得研究人员可以专注于模型的设计和训练,而不必手动计算复杂的导数。这在深度学习的快速发展中起到了推波助澜的作用,尤其是在训练大型神经网络时。

此外,自动微分也在非深度学习的领域显示出其强大的潜力,例如在物理模拟、金融工程和生物信息学等领域的应用。

二、PyTorch Autograd 的核心机制

file
PyTorch Autograd 是一个强大的工具,它允许研究人员和工程师以极少的手动干预高效地计算导数。理解其核心机制不仅有助于更好地利用这一工具,还能帮助开发者避免常见错误,提升模型的性能和效率。

2.1 Tensor 和 Autograd 的相互作用

file
在 PyTorch 中,Tensor 是构建神经网络的基石,而 Autograd 则是实现神经网络训练的关键。了解 Tensor 和 Autograd 如何协同工作,对于深入理解和有效使用 PyTorch 至关重要。

Tensor:PyTorch 的核心

Tensor 在 PyTorch 中类似于 NumPy 的数组,但它们有一个额外的超能力——能在 Autograd 系统中自动计算梯度。

  • Tensor 的属性:
    每个 Tensor 都有一个
    requires_grad
    属性。当设置为
    True
    时,PyTorch 会跟踪在该 Tensor 上的所有操作,并自动计算梯度。

Autograd:自动微分的引擎

Autograd 是 PyTorch 的自动微分引擎,负责跟踪那些对于计算梯度重要的操作。

  • 计算图:
    在背后,Autograd 通过构建一个计算图来跟踪操作。这个图是一个有向无环图(DAG),它记录了创建最终输出 Tensor 所涉及的所有操作。

Tensor 和 Autograd 的协同工作

当一个 Tensor 被操作并生成新的 Tensor 时,PyTorch 会自动构建一个表示这个操作的计算图节点。

  • 示例:简单操作的跟踪

    import torch
    
    # 创建一个 Tensor,设置 requires_grad=True 来跟踪与它相关的操作
    x = torch.tensor([2.0], requires_grad=True)
    
    # 执行一个操作
    y = x * x
    
    # 查看 y 的 grad_fn 属性
    print(y.grad_fn)  # 这显示了 y 是通过哪种操作得到的
    

    这里的
    y
    是通过一个乘法操作得到的。PyTorch 会自动跟踪这个操作,并将其作为计算图的一部分。

  • 反向传播和梯度计算

    当我们对输出的 Tensor 调用
    .backward()
    方法时,PyTorch 会自动计算梯度并将其存储在各个 Tensor 的
    .grad
    属性中。

    # 反向传播,计算梯度
    y.backward()
    
    # 查看 x 的梯度
    print(x.grad)  # 应输出 4.0,因为 dy/dx = 2 * x,在 x=2 时值为 4
    

2.2 计算图的构建和管理

file
在深度学习中,理解计算图的构建和管理是理解自动微分和神经网络训练过程的关键。PyTorch 使用动态计算图,这是其核心特性之一,提供了极大的灵活性和直观性。

计算图的基本概念

计算图是一种图形化的表示方法,用于描述数据(Tensor)之间的操作(如加法、乘法)关系。在 PyTorch 中,每当对 Tensor 进行操作时,都会创建一个表示该操作的节点,并将操作的输入和输出 Tensor 连接起来。

  • 节点(Node)
    :代表了数据的操作,如加法、乘法。
  • 边(Edge)
    :代表了数据流,即 Tensor。

动态计算图的特性

PyTorch 的计算图是动态的,即图的构建是在运行时发生的。这意味着图会随着代码的执行而实时构建,每次迭代都可能产生一个新的图。

  • 示例:动态图的创建

    import torch
    
    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(2.0, requires_grad=True)
    
    # 一个简单的运算
    z = x * y
    
    # 此时,一个计算图已经形成,其中 z 是由 x 和 y 通过乘法操作得到的
    

反向传播与计算图

在深度学习的训练过程中,反向传播是通过计算图进行的。当调用
.backward()
方法时,PyTorch 会从该点开始,沿着图逆向传播,计算每个节点的梯度。

  • 示例:反向传播过程

    # 继续上面的例子
    z.backward()
    
    # 查看梯度
    print(x.grad)  # dz/dx,在 x=1, y=2 时应为 2
    print(y.grad)  # dz/dy,在 x=1, y=2 时应为 1
    

计算图的管理

在实际应用中,对计算图的管理是优化内存和计算效率的重要方面。

  • 图的清空
    :默认情况下,在调用
    .backward()
    后,PyTorch 会自动清空计算图。这意味着每个
    .backward()
    调用都是一个独立的计算过程。对于涉及多次迭代的任务,这有助于节省内存。

  • 禁止梯度跟踪
    :在某些情况下,例如在模型评估或推理阶段,不需要计算梯度。使用
    torch.no_grad()
    可以暂时禁用梯度计算,从而提高计算效率和减少内存使用。

    with torch.no_grad():
        # 在这个块内,所有计算都不会跟踪梯度
        y = x * 2
        # 这里 y 的 grad_fn 为 None
    

2.3 反向传播和梯度计算的细节

反向传播是深度学习中用于训练神经网络的核心算法。在 PyTorch 中,这一过程依赖于 Autograd 系统来自动计算梯度。理解反向传播和梯度计算的细节是至关重要的,它不仅帮助我们更好地理解神经网络是如何学习的,还能指导我们进行更有效的模型设计和调试。

反向传播的基础

反向传播算法的目的是计算损失函数相对于网络参数的梯度。在 PyTorch 中,这通常通过在损失函数上调用
.backward()
方法实现。

  • 链式法则:
    反向传播基于链式法则,用于计算复合函数的导数。在计算图中,从输出到输入反向遍历,乘以沿路径的导数。

反向传播的 PyTorch 实现

以下是一个简单的 PyTorch 示例,说明了反向传播的基本过程:

import torch

# 创建 Tensor
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

# 构建一个简单的线性函数
y = w * x + b

# 计算损失
loss = y - 5

# 反向传播
loss.backward()

# 检查梯度
print(x.grad)  # dy/dx
print(w.grad)  # dy/dw
print(b.grad)  # dy/db

在这个例子中,
loss.backward()
调用触发了整个计算图的反向传播过程,计算了
loss
相对于
x

w

b
的梯度。

梯度积累

在 PyTorch 中,默认情况下梯度是累积的。这意味着在每次调用
.backward()
时,梯度都会加到之前的值上,而不是被替换。

  • 梯度清零:
    在大多数训练循环中,我们需要在每个迭代步骤之前清零梯度,以防止梯度累积影响当前步骤的梯度计算。
# 清零梯度
x.grad.zero_()
w.grad.zero_()
b.grad.zero_()

# 再次进行前向和反向传播
y = w * x + b
loss = y - 5
loss.backward()

# 检查梯度
print(x.grad)  # dy/dx
print(w.grad)  # dy/dw
print(b.grad)  # dy/db

高阶梯度

PyTorch 还支持高阶梯度计算,即对梯度本身再次进行微分。这在某些高级优化算法和二阶导数的应用中非常有用。

# 启用高阶梯度计算
z = y * y
z.backward(create_graph=True)

# 计算二阶导数
x_grad = x.grad
x_grad2 = torch.autograd.grad(outputs=x_grad, inputs=x)[0]
print(x_grad2)  # d^2y/dx^2

三、Autograd 特性全解

PyTorch 的 Autograd 系统提供了一系列强大的特性,使得它成为深度学习和自动微分中的重要工具。这些特性不仅提高了编程的灵活性和效率,还使得复杂的优化和计算变得可行。

动态计算图(Dynamic Graph)

PyTorch 中的 Autograd 系统基于动态计算图。这意味着计算图在每次执行时都是动态构建的,与静态图相比,这提供了更大的灵活性。

  • 示例:动态图的适应性

    import torch
    
    x = torch.tensor(1.0, requires_grad=True)
    if x > 0:
        y = x * 2
    else:
        y = x / 2
    y.backward()
    

    这段代码展示了 PyTorch 的动态图特性。根据
    x
    的值,计算路径可以改变,这在静态图框架中是难以实现的。

自定义自动微分函数

PyTorch 允许用户通过继承
torch.autograd.Function
来创建自定义的自动微分函数,这为复杂或特殊的前向和后向传播提供了可能。

  • 示例:自定义自动微分函数

    class MyReLU(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.clamp(min=0)
    
        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            return grad_input
    
    x = torch.tensor([-1.0, 1.0, 2.0], requires_grad=True)
    y = MyReLU.apply(x)
    y.backward(torch.tensor([1.0, 1.0, 1.0]))
    print(x.grad)  # 输出梯度
    

    这个例子展示了如何定义一个自定义的 ReLU 函数及其梯度计算。

requires_grad

no_grad

在 PyTorch 中,
requires_grad
属性用于指定是否需要计算某个 Tensor 的梯度。
torch.no_grad()
上下文管理器则用于临时禁用所有计算图的构建。

  • 示例:使用
    requires_grad

    no_grad

    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    
    with torch.no_grad():
        y = x * 2  # 在这里不会追踪 y 的梯度计算
    
    z = x * 3
    z.backward(torch.tensor([1.0, 1.0, 1.0]))
    print(x.grad)  # 只有 z 的梯度被计算
    

    在这个例子中,
    y
    的计算不会影响梯度,因为它在
    torch.no_grad()
    块中。

性能优化和内存管理

PyTorch 的 Autograd 系统还包括了针对性能优化和内存管理的特性,比如梯度检查点(用于减少内存使用)和延迟执行(用于优化性能)。

  • 示例:梯度检查点

    使用
    torch.utils.checkpoint
    来减少大型网络中的内存占用。

    import torch.utils.checkpoint as checkpoint
    
    def run_fn(x):
        return x * 2
    
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = checkpoint.checkpoint(run_fn, x)
    y.backward(torch.tensor([1.0, 1.0, 1.0]))
    

    这个例子展示了如何使用梯度检查点来优化内存使用。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人
如有帮助,请多关注
TeahLead KrisChang,10+年的互联网和人工智能从业经验,10年+技术和业务团队管理经验,同济软件工程本科,复旦工程管理硕士,阿里云认证云服务资深架构师,上亿营收AI产品业务负责人。


最近,Mistral 发布了一个激动人心的大语言模型: Mixtral 8x7b,该模型把开放模型的性能带到了一个新高度,并在许多基准测试上表现优于 GPT-3.5。我们很高兴能够在 Hugging Face 生态系统中全面集成 Mixtral 以对其提供全方位的支持