2024年11月

【效果】

元服务链接格式(API>=12适用):
https://hoas.drcn.agconnect.link/ggMRM

生成二维码后效果:

【参考网址】

使用App Linking实现元服务跳转:https://developer.huawei.com/consumer/cn/doc/AppGallery-connect-Guides/agc-applinking-atomic-link-0000002046440041

草料二维码:https://cli.im/

【引言】

本文将详细介绍如何使用App Linking技术实现元服务之间的无缝跳转,并通过生成二维码的方式快速拉起元服务,从而简化用户操作流程,增强应用的互动性和推广效率。

【什么是元服务链接?】

元服务链接是一种专为开发者设计的受控URL服务,允许用户点击后直接进入特定的元服务内容页。这种即点即享的功能极大地简化了用户的操作流程,并且可以精准控制用户访问的时间范围。对于已上架的元服务,开发者能够为其生成并配置专属链接,同时设置有效期,以确保用户在有效期内能够访问到最新的内容或功能。

【使用场景】

• 扫码直达:用户可以通过扫描二维码直接进入特定的元服务页面。

• 社交分享:方便用户通过社交媒体分享特定的服务内容。

• 唤醒沉默用户:通过推送通知中的链接快速激活不活跃的用户。

• 营销推广:作为广告内容的一部分,引导用户进入体验服务,提高转化率。

【创建元服务链接】

要创建一个元服务链接,首先需要满足以下前提条件:

1. 在AGC(AppGallery Connect)平台上创建项目。

2. 开通App Linking服务。

3. 项目中存在已上架且支持HarmonyOS API 12及以上的元服务。

接下来,按照以下步骤创建链接:

1. 登录AppGallery Connect,选择“我的项目”。

2. 选择项目后,在左侧导航栏找到“增长 > App Linking”,选择“元服务链接(API>=12适用)”页签。

3. 点击“创建”,填写链接名称、设置链接的有效期等信息。

4. 可以选择添加自定义参数,以便更精确地定位到元服务中的指定页面。

5. 最后,保存或发布链接。

【自定义参数】

为了更灵活地控制跳转行为,开发者可以在创建元服务链接时设置自定义参数。这些参数通常用于指定页面路径或是导航目标。例如,可以通过pagePath参数指定具体的页面路径,或者使用navRouterName参数指向特定的导航目的地。如果涉及分包,则还需要提供subPackageName参数。

【应用内集成】

在应用内部,开发者可以使用UIAbilityContext.openLink接口来打开元服务链接。根据设置的不同,如果匹配到相应的元服务则会直接打开;否则,可能会抛出异常或者尝试通过浏览器打开链接。此外,还可以设置appLinkingOnly参数来控制是否仅限于通过App Linking打开元服务。

// 示例代码
import { common } from '@kit.AbilityKit';
import { BusinessError } from '@kit.BasicServicesKit';

let context: common.UIAbilityContext = getContext(this) as common.UIAbilityContext;
let link: string = "https://hoas.drcn.agconnect.link/9P7g";
context.openLink(link, { appLinkingOnly: true })
  .then(() => {
    console.info('openlink success.');
  })
  .catch((error: BusinessError) => {
    console.error(`openlink failed. error:${JSON.stringify(error)}`);
  });

【错误处理与调试】

当元服务链接过期或无效时,系统会给出相应的错误提示。开发者可以根据这些提示来进行错误处理。例如,当appLinkingOnly设为true时,如果遇到非法或失效链接,系统会抛出错误码"16000019"。在这种情况下,开发者应当准备好相应的错误处理逻辑,以保证良好的用户体验。

【二维码生成】

最后一步,开发者可以使用草料二维码工具将生成的元服务链接转换成二维码,方便用户通过扫描二维码的方式访问元服务。这不仅提升了用户体验,也增加了应用的互动性和传播性。

【结论】

通过以上步骤,开发者可以轻松地利用App Linking技术实现鸿蒙元服务之间的无缝跳转,并通过二维码方式快速拉起元服务。这项技术不仅有助于简化用户操作,还能增强应用的互动性和推广效果。希望本文能帮助开发者更好地理解和运用这一强大功能,为用户提供更加流畅便捷的服务体验。

前言:
大型语言模型(LLMs)的发展历程可以说是非常长,从早期的GPT模型一路走到了今天这些复杂的、公开权重的大型语言模型。最初,LLM的训练过程只关注预训练,但后来逐步扩展到了包括预训练和后训练在内的完整流程。后训练通常涵盖监督指导微调和对齐过程,而这些在ChatGPT的推广下变得广为人知。

自ChatGPT首次发布以来,训练方法学也在不断进化。在这几期的文章中,我将回顾近1年中在预训练和后训练方法学上的最新进展。

关于LLM开发与训练流程的概览,特别关注本文中讨论的新型预训练与后训练方法

每个月都有数百篇关于LLM的新论文提出各种新技术和新方法。然而,要真正了解哪些方法在实践中效果更好,一个非常有效的方式就是看看最近最先进模型的预训练和后训练流程。幸运的是,在近1年中,已经有四个重要的新型LLM发布,并且都附带了相对详细的技术报告。

在本文中,我将重点介绍以下模型中的Qwen 2预训练和后训练流程:

• 阿里巴巴的 Qwen 2

• 苹果的 智能基础语言模型

• 谷歌的 Gemma 2

• Meta AI 的 Llama 3.1

我会完整的介绍列表中的全部模型,但介绍顺序是基于它们各自的技术论文在arXiv.org上的发表日期,这也巧合地与它们的字母顺序一致。

1. 阿里的 Qwen 2

我们先来说说 Qwen 2,这是一个非常强大的 LLM 模型家族,与其他主流的大型语言模型具有竞争力。不过,不知为何,它的知名度不如 Meta AI、微软和谷歌那些公开权重的模型那么高。

1.1 Qwen 2 概览

在深入探讨 Qwen 2 技术报告中提到的预训练和后训练方法之前,我们先简单总结一下它的一些核心规格。

Qwen 2 系列模型共有 5 种版本,包括 4 个常规(密集型)的 LLM,分别为 5 亿、15 亿、70 亿和 720 亿参数。此外,还有一个专家混合模型(Mixture-of-Experts),参数量为 570 亿,但每次仅激活 140 亿参数。(由于这次不重点讨论模型架构细节,我就不深入讲解专家混合模型了,不过简单来说,它与 Mistral AI 的 Mixtral 模型类似,但激活的专家更多。如果想了解更高层次的概述,可以参考这一篇知识《模型融合、专家混合与更小型 LLM 的未来》中的 Mixtral 架构部分。)

Qwen 2 模型的一大亮点是它在 30 种语言中的出色多语言能力。此外,它的词汇量非常大,达到 151,642 个 token。(相比之下,Llama 2 的词汇量为 32k,而 Llama 3.1 则为 128k)。根据经验法则,词汇量增加一倍,输入 token 数量会减少一半,因此 LLM 可以在相同输入中容纳更多 token。这种大词汇量特别适用于多语言数据和编程场景,因为它能覆盖标准英语词汇之外的单词。

下面是与其他 LLM 在 MMLU 基准测试中的简要对比。(需要注意的是,MMLU 是一个多选基准测试,因此有其局限性,但仍是评估 LLM 性能的最受欢迎方法之一。)

MMLU基准测试得分,针对最新的开源权重模型(分数越高越好)。这个图中的得分是从每个模型的官方研究论文中收集的。

1.2 Qwen 2 预训练

Qwen 2 团队对参数规模为 15 亿、70 亿和 720 亿的模型进行了训练,使用了 7 万亿个训练 token,这是一个合理的规模。作为对比,Llama 2 模型使用了 2 万亿个 token,Llama 3.1 模型使用了 15 万亿个 token。

有趣的是,参数规模为 5 亿的模型使用了 12 万亿个 token 进行训练。然而,研究人员并没有用这个更大的 12 万亿 token 数据集来训练其他模型,因为在训练过程中并未观察到性能提升,同时额外的计算成本也难以合理化。

他们的一个重点是改进数据过滤流程,以去除低质量数据,同时增强数据混合,从而提升数据的多样性——这一点我们在分析其他模型时会再次提到。

有趣的是,他们还使用了 Qwen 模型(尽管没有明确说明细节,我猜是指前一代的 Qwen 模型)来生成额外的预训练数据。而且,预训练包含了“多任务指令数据……以增强模型的上下文学习能力和指令遵循能力。”

此外,他们的训练分为两个阶段:常规预训练和长上下文预训练。在预训练的最后阶段,他们使用了“高质量、长文本数据”将上下文长度从 4,096 token 增加到 32,768 token。

            Qwen 2 预训练技术总结。‘持续预训练’指的是两阶段预训练,研究人员先进行了常规预训练,然后接着进行长上下文的持续预训练。

(遗憾的是,这些技术报告的另一个特点是关于数据集的细节较少,因此如果总结看起来不够详细,是因为公开的信息有限。)

1.3 Qwen 2 后训练

Qwen 2 团队采用了流行的两阶段后训练方法,首先进行监督式指令微调(SFT),在 50 万个示例上进行了 2 个 epoch 的训练。这一阶段的目标是提高模型在预设场景下的响应准确性。

                                  典型的大语言模型开发流程

在完成 SFT 之后,他们使用直接偏好优化(DPO)来将大语言模型(LLM)与人类偏好对齐。(有趣的是,他们的术语将其称为基于人类反馈的强化学习,RLHF。)正如我几周前在《LLM预训练和奖励模型评估技巧》文章中所讨论的,由于相比其他方法(例如结合 PPO 的 RLHF)更加简单易用,SFT+DPO 方法似乎是当前最流行的偏好调优策略。

对齐阶段本身也分为两个步骤。第一步是在现有数据集上使用 DPO(离线阶段);第二步是利用奖励模型形成偏好对,并进入“在线”优化阶段。在这里,模型在训练中会生成多个响应,奖励模型会选择优化步骤中更符合偏好的响应,这种方法也常被称为“拒绝采样”。

在数据集构建方面,他们使用了现有语料库,并通过人工标注来确定 SFT 的目标响应,以及识别偏好和被拒绝的响应(这是 DPO 的关键)。研究人员还合成了人工注释数据。

此外,团队还使用 LLM 生成了专门针对“高质量文学数据”的指令-响应对,以创建用于训练的高质量问答对。

                                Qwen2后训练技术汇总

1.4 结论

Qwen 2 是一个相对能力较强的模型,与早期的 Qwen 系列类似。在 2023 年 12 月的 NeurIPS LLM 效率挑战赛中,我记得大部分获胜方案都涉及 Qwen 模型。

关于 Qwen 2 的训练流程,值得注意的一点是,合成数据被用于预训练和后训练阶段。同时,将重点放在数据集过滤(而不是尽可能多地收集数据)也是 LLM 训练中的一个显著趋势。在我看来,数据确实是越多越好,但前提是要满足一定的质量标准。

从零实现直接偏好优化(DPO)对齐 LLM

直接偏好优化(DPO)已经成为将 LLM 更好地与用户偏好对齐的首选方法之一。这篇文章中你会多次看到这个概念。如果你想学习它是如何工作的,Sebastian Raschka博士有一篇很好的文章,即:《从零实现直接偏好优化(DPO)用于 LLM 对齐》,你可以看看它。在介绍完本文列表中的模型扣会根据它用中文语言为大家重新编写一篇发布出来。

利用DPO技术实现人工智能大语言模型与人类对齐流程概览

​ Vue 中的路由用于实现单页应用(SPA)中的页面导航。它允许你在不刷新整个页面的情况下,根据不同的 URL 路径显示不同的组件,提供了类似于多页面应用的用户体验。例如,在一个电商应用中,可以通过路由实现从首页到商品详情页、购物车页和用户个人中心页等不同页面的切换

​ Vue Router(Vue 官方的路由库)通过监听浏览器的 URL 变化,根据预先定义的路由规则,动态地加载和显示相应的组件。它利用了浏览器的
History
API 或者
Hash
模式来实现 URL 的管理,前端的路由key是路径,value是组件或者一个function,访问指定路径显示对应内容

SPA 应用的理解:

​ SPA(Single - Page Application)即单页应用,是一种现代的网页应用架构模式。在 SPA 中,整个应用只有一个 HTML 页面,当用户与应用进行交互(如点击链接、提交表单等)时,不会像传统的多页面应用那样进行整页刷新,而是通过 JavaScript 动态地更新页面的部分内容来呈现不同的视图或功能。例如,像一些大型的 Web 应用,如谷歌文档(Google Docs)和 GitHub,它们的用户体验很流畅,在使用过程中基本感觉不到页面的刷新,这就是典型的 SPA 应用

路由的基本使用

  1. 安装vue-router

    vue2使用vue-router3版本,vue3使用vue-router4版本

    npm i vue-router@3  
    
  2. 使用vue-router

    import  VueRouter from "vue-router"
    // vue-router是一个插件库,使用了VueRouter之后,vue可以配置router选项
    Vue.use(VueRouter)
    
  3. 创建router

    新建一个router文件夹/index.js

    import VueRouter from "vue-router"
    
    // 自定义的组件
    import Apage from "@/components/Bpage.vue";
    import Bpage from "@/components/Apage.vue";
    
    
    // 创建并暴露一个路由
    export default new VueRouter({
        routes: [
            // 如果请求路径是/a,触发a组件
            {
                path: "/a",
                component: Apage
            },
            // 如果请求路径是/b,触发b组件
            {
                path: "/b",
                component: Bpage
            }
        ]
    
    })
    
    
    
  4. 配置router

    import router from "./router"
    new Vue({
      render: h => h(App),
      // 配置引入的router
      router:router
    }).$mount('#app')
    
  5. 路由状态

    路由配置成功,浏览器地址栏的url后会带一个#

    http://localhost:8080/#/
    
  6. 实现路由跳转、显示

    <template>
      <div id="app">
    
        <!--    to指定的不是文件位置,而是路由定义的path-->
    <!--  active-class属性可以指定该路由激活时router-link上应用的样式,值是一个css类名  -->
        <router-link to="/a"  active-class="active">Apage</router-link>
        <router-link to="/b"  active-class="active">Bpage</router-link>
    
        <!--  指定路由组件呈现的位置   -->
        <router-view></router-view>
      </div>
    </template>
    
    
    <style>
    /* 路由被激活时用的样式 */
    .active{
      font-size: 20px;
      color: red;
    }
    </style>
    
  7. 注意事项


    • 路由组件(通过路由匹配切换)通常存放在
      pages
      文件夹
    • 一般组件(使用组件标签应用)通常存放在
      components
      文件夹。
    • 通过切换,隐藏了的路由组件,默认是被销毁掉的,需要的时候再去挂载。
    • 每个组件都有自己的
      $route
      属性,里面存储着自己的路由信息。
    • 整个应用只有一个router,可以通过组件的
      $router
      属性获取到

嵌套路由

// 创建并暴露一个路由
export default new VueRouter({
    routes: [
        // 如果请求路径是/a,触发a组件
        {
            path: "/a",
            component: Apage,
            // 通过children 配置多级路由,可以配置多个路由
            children: [
                {
                    path: "c",  // 路径不需要写/
                    component: Cpage
                }


            ]
        },
        // 如果请求路径是/b,触发b组件
        {
            path: "/b",
            component: Bpage
        }
    ]

})

<!--  父路由 -- >
<template>
  <div>
    <p>a page</p>
    <!-- 此处跳转路径要写完整的路径,父路由path/子路由path,不能只写c-->
    <router-link to="/a/c">to c page</router-link>
    <!-- 子路由显示位置 -->
    <router-view></router-view>
  </div>
</template>

路由query参数

路由跳转的时候在url上携带参数

  1. 传递参数-字符串写法

        <!--   使用数据绑定语法, 路径后?后拼接参数,多个参数用&连接 -->
        <!--    字符串引号里要使用``包裹-->
        <router-link :to="`/a/c?id=${id}&size=${num}`">to c page</router-link>
    
  2. 传递参数-对象写法

        <router-link :to="{
          // 路径
          path:'/a/c',
          // 使用query选项指定参数
          query:{
            id:id,
            size:num
          }
        }">to c page</router-link>
    
  3. 接收参数

    在对应的路由页面里

    {{$route.query.id}}
    {{$route.query.size}}
    

命名路由

命名路由是给路由定义一个名称。这就像给一个人起名字一样,方便在代码的其他地方通过这个名字来引用特定的路由,而不是使用冗长的路径字符串。它使得路由的跳转和操作更加灵活和可读

  1. 给路由命名-通过name属性

    export default new VueRouter({
        routes: [
            // 如果请求路径是/a,触发a组件
            {
                path: "/a",
                component: Apage,
                // 通过children 配置多级路由,可以配置多个路由
                children: [
                    {
    
                        path: "c",  // 子路由跳转路由,不需要写/
                        component: Cpage,
                        name:"c_page" // 给路由命名
                    }
                ]
            }
        ]
    })
    
  2. 路由跳转

        <!--    通过name属性指定跳转到c_page-->
        <router-link :to="{name:'c_page'}">to c page</router-link>
    <!-- 传递参数在后面写query配置即可-->
    

路由params参数

params
主要用于在路由路径中定义动态片段,这些动态片段是路由路径的一部分,用于区分不同的资源,
query
主要用于传递一些额外的、不影响资源定位的信息,比如搜索关键词、排序方式、分页信息等

  1. 声明params参数

     children: [
                    {
    
                        path: "c/:id",  // 使用占位符声明接收的params参数
                        component: Cpage,
                        name: "c_page", // 给路由命名
                        
                        
                    }
    
    
                ]
    
  2. 传递参数-字符串写法

     <router-link :to="`/a/c/${id}`">to c page</router-link>
    
  3. 传递参数-对象写法

    对象写法传递params参数 路由跳转必须用name别名,不能用path

      <router-link :to="{
        name:'c_page',
        params:{
          id:id
        }
    
      }"></router-link>
    
  4. 读取参数

    {{$route.params.id}}
    

路由props配置

props是一个用于将路由参数以props的形式传递给组件的选项,让路由组件更方便的收到参数

  1. 对象写法-传递的是写死的数据

                // 通过children 配置多级路由,可以配置多个路由
                children: [
                    {
    
                        path: "c/:id",  // 使用占位符声明接收的params参数
                        component: Cpage,
                        name: "c_page", // 给路由命名
                        // 对象写法,key和value会以props形式传递给Cpage组件,需要在组件内接收
                        props:{
                            id:1
                        }
                    }
                ]
    

    export default {
      name: "Cpage",
      props:["id"]
    }
    
  2. 布尔值开关写法

     // 通过children 配置多级路由,可以配置多个路由
                children: [
                    {
    
                        path: "c/:id",  // 使用占位符声明接收的params参数
                        component: Cpage,
                        name: "c_page", // 给路由命名
                        // 开启了该选项,会把该路由组件收到的params参数以props的形式传给组件,需要在组件接收
                        props:true
                    }
                ]
    

    export default {
      name: "Cpage",
      props:["id"]
    }
    
  3. 函数写法-同时支持动态params、query

     children: [
            {
    
                path: "c",  // 使用占位符声明接收的params参数
                component: Cpage,
                name: "c_page", // 给路由命名
    
                // 函数返回值中的每一组key-value都会通过props传递给组件
                props($route){
                // query参数、params参数
    
                    return {id:$route.query.id,p:$route.params.p}
                }
    
            }
        ]
    

    export default {
      name: "Cpage",
      props:["id","p"]
    }
    

router-link的replace

控制路由跳转时操作浏览器历史记录的模式

浏览器的历史记录有两种写入方式:分别为push和replace,默认为push

push
是指将一个新的路由记录添加到浏览器的历史记录栈中。这意味着当用户进行导航操作时,会在历史记录中新增一条记录,就好像在栈顶放入了一个新的元素。例如,在一个网页应用中,用户从首页点击链接进入产品详情页,使用
push
操作后,浏览器的历史记录栈中就会新增一条记录,表示用户访问了产品详情页,用户可以通过浏览器的 “后退” 按钮回到之前的首页


push
不同,
replace
操作是替换当前的路由记录,而不是添加新的记录。它就像是修改了历史记录栈顶的元素,而不是在栈顶添加新元素。还是以网页应用为例,如果用户在登录页面登录成功后,使用
replace
操作跳转到用户主页,那么在历史记录中,登录页面的记录会被用户主页的记录替换,用户在用户主页时,点击 “后退” 按钮不会回到登录页面

   ```html
<!--     开启replace模式  -->
<router-link to="/a/c" replace>to c page</router-link>
   ```

编程式路由导航

编程式路由导航是指在 JavaScript 代码中通过调用相关的方法来实现路由跳转,不借助
<router-link>
实现路由跳转,让路由跳转更加灵活

  1. push形式路由跳转
export default {
  name: "Apage",
  methods:{
    puscCpage(){
      //  以push的形式跳转到c_page路由
      this.$router.push({
        name:"c_page", // name是路由命名name的值
        query:{
          id:1
        }

      })
    }
  }
}
<template>
  <div>
    <p>a page</p>

<!-- 点击按钮触发路由跳转-->
    <button @click="puscCpage"></button>

    <router-view></router-view>
  </div>
</template>
  1. replace形式路由跳转

    <script>
    export default {
      name: "Apage",
      methods:{
        replaceCpage(){
          //  以replace的形式跳转到c_page路由
          this.$router.replace({
            name:"c_page", // name是路由命名name的值
            query:{
              id:1
            }
    
          })
        }
      }
    }
    </script>
    
    <template>
      <div>
        <p>a page</p>
    
    
        <button @click="replaceCpage"></button>
    
        <router-view></router-view>
      </div>
    </template>
    
  2. 路由前进、后退

    // 触发前进
    this.$router.forward()
    // 触发后退
    this.$router.back()
    // 可以前进也可以后退,传正数 前进指定数的记录,传负数 后退指定数的记录
    this.$router.go(3)
    

缓存路由组件

让不展示的路由组件保持挂载,不被销毁,使用kepp-alive

<!--  include的值是组件的名字-->
<!-- 如果指定多个组件 :include=[组件名字] -->
<!--  如果不指定include,则是被keep-alive包裹的区域所有的组件-->
<keep-alive include="Cpage"> 
    <router-view></router-view>
</keep-alive>

路由的生命周期钩子

路由组件所独有的两个钩子,用于捕获路由组件的激活状态

export default {
  name: "Apage",
  // activated:路由组件被激活时触发
  activated(){
    console.log("activated")
  },
  //deactivated:路由组件失活时触发
  deactivated(){
    console.log("deactivated")
    
  }


}

路由守卫

路由守卫是 Vue Router(以及其他路由框架也有类似概念)提供的一种机制,用于在路由跳转过程中进行拦截和处理。它就像是在路由的各个关键节点设置的 “关卡”,可以检查用户是否有权限访问某个页面、在进入或离开页面时执行一些特定的操作,如数据获取、页面过渡动画控制等

路由元信息

路由元信息(
meta
)是在定义路由时可以添加的额外信息,它可以是任何你想要的数据类型,如对象、字符串、数字等。这些信息主要用于在路由守卫或者其他需要获取路由额外信息的场景中使用,是一种灵活的方式来为每个路由配置自定义的属性

const routes = [
  {
    path: '/login',
    component: HomeComponent,
    // 自定义的元信息
    meta: {
      title: '首页',
      requiresAuth: false
    }
  },
  {
    path: '/admin',
    component: AdminComponent,
    meta: {
      title: '管理页面',
      requiresAuth: true
    }
  }
];
// 读取元信息
this.$route.meta
全局路由守卫
// 创建并暴露一个路由
const router = new VueRouter({
  routes:[]
})

export  default router
  1. beforeEach前置守卫

    在初始化、每次路由跳转之前被调用,可以用来进行全局的权限检查、加载显示等操作

    router.beforeEach((to, from, next) => {
      // to是即将要进入的目标路由对象
      // from是当前正要离开的路由对象
      // next是一个函数,用于决定是否继续路由跳转
      // 检查用户是否登录,通过检查localStorage中的token来判断
      const token = localStorage.getItem('token');
      if (to.meta.requiresAuth &&!token) {
        // 如果目标路由需要认证(通过路由元信息meta中的requiresAuth字段判断)
        
        // 跳转登录页
        next('/login');
      } else {
        // 跳转要跳转的页面
        next();
      }
    });
    
    
  2. afterEach后置守卫

    初始化、每次路由跳转完成后被调用,主要用于执行一些不需要中断路由跳转的操作,如页面标题更新、统计页面访问次数等

    router.afterEach((to, from) => {
      // 更新页面标题,假设每个路由的meta信息中有title字段用于定义页面标题
      document.title = to.meta.title || '默认标题';
    });
    
路由独享守卫

可以在单个路由配置中定义,只对该路由生效。它和全局
beforeEach
守卫类似,但作用范围仅限于特定的路由

routes:[
  {
    path: '/admin',
    component: AdminComponent,
    beforeEnter: (to, from, next) => {
      // 检查用户是否是管理员,假设通过检查localStorage中的adminToken来判断
      const adminToken = localStorage.getItem('adminToken');
      if (adminToken) {
        next();
      } else {
        next('/login');
      }
    }
  }
]
组件内路由守卫
  1. beforeRouteEnter


    在组件被渲染之前调用,此时组件实例还没有被创建,所以不能访问组件的this,但是可以通过传递给
    next
    函数的回调来访问组件实例


    export default {
      name: 'Admin',
      // 通过路由规则,进入该组件时被调用
      beforeRouteEnter(to, from, next) {
        // 检查用户是否有权限访问该组件
        const authService = {
          // 权限检查函数,根据用户的角色或其他条件返回布尔值
          checkAccess: () => {
            const userRole = 'admin'
            // 根据角色判断是否有权限访问
            return userRole === 'admin'
          }
        }
    
        // 检查用户是否有权限访问该组件
        if (authService.checkAccess()) {
          // 如果有权限,则继续导航到目标路由
          next()
        } else {
          // 如果没有权限,则重定向到无权访问页面
          next('/unauthorized')
        }
      },
      data() {
        return {
          message: 'Welcome'
        }
      }
    }
    

  2. beforeRouteUpdate


    在当前路由改变,但是该组件被复用时调用。这对于动态路由参数的组件很有用,可以在这里响应路由参数的变化并更新组件数据

    例如,有一个用户详情组件,其路由路径为
    /user/:id
    ,当从
    /user/1
    跳转到
    /user/2
    时,由于组件是复用的(都是
    UserDetailComponent
    ),
    beforeRouteUpdate
    就会被调用


    export default {
    
      beforeRouteUpdate(to, from, next) {
        //  方法updateData用于更新数据
        this.updateData(to.params.id);
        next();
      }
    }
    
  3. beforeRouteLeave


    在离开当前组件对应的路由时调用,可以用于保存未完成的数据、询问用户是否确认离开等操作


    export default {
      // 通过路由规则,离开该组件时被调用
      beforeRouteLeave(to, from, next) {
        // 检查是否有未保存的数据
        const unsavedData = this.checkUnsavedData();
        if (unsavedData) {
          // 如果有未保存的数据,弹出确认框询问用户是否确认离开
          const confirmLeave = window.confirm('你有未保存的数据,确定要离开吗?');
          if (confirmLeave) {
            next()
          } else {
            // false代表中断当前路由操作
            next(false)
          }
        } else {
          next()
        }
      },
    }
    

history和hash模式

  1. Hash 模式
  • vue路由默认是hash模式

  • 在 URL 中,#符号及其后面的部分被称为哈希(hash)部分

  • Hash 模式的路由是基于 URL 的哈希值变化来实现的。当哈希值改变时,浏览器不会向服务器发送请求,而是会触发浏览器的hashchange事件,JavaScript 可以通过监听这个事件来实现页面的局部更新,从而达到切换页面(路由)的效果

  • 兼容性好:几乎所有的浏览器都支持hash模式,包括一些较老的浏览器。

  • 不会发送请求到服务器:这使得在单页应用(SPA)中可以实现无刷新的页面切换,因为服务器不会对哈希部分的变化做出响应,所有的路由切换逻辑都在前端完成。

  • URL 相对不美观:带有#符号的 URL 可能看起来不够简洁、美观,并且在某些场景下可能会让人感觉比较 “奇怪”,例如在分享链接或者生成静态资源链接时。

  1. History 模式
  • History 模式利用了 HTML5 新增的History API,特别是pushState和replaceState方法
  • 通过这些方法,JavaScript 可以改变浏览器的历史记录栈,并且在改变 URL 时不会引起页面的刷新(除非手动刷新),浏览器不会自动发送请求到服务器,而是由前端应用来处理这个路由变化。
  • URL 更美观:没有#符号,使得 URL 看起来更像传统的多页应用的链接,更符合用户的认知习惯,在分享链接、搜索引擎优化(SEO)等方面有一定优势。
  • 需要服务器配置支持:因为改变后的 URL 和普通的 URL 没有区别,所以当用户直接访问或者刷新一个通过History API修改后的 URL 时,浏览器会向服务器发送请求。这就需要服务器进行配置,将所有的前端路由请求都重定向到应用的入口文件(例如index.html),否则会出现 404 错误。
  • 兼容性稍差:虽然现代浏览器都支持History API,但一些较老的浏览器可能不支持,在使用时需要考虑兼容性问题
  1. 切换history模式

    前端配置

    import Vue from 'vue';
    import Router from 'vue-router';
    Vue.use(Router);
    const router = new Router({
      // 指定History
      mode: 'History',
      routes: [
    
      ]
    });
    export default router;
    

    服务端配置


    当使用History模式时,因为浏览器会把修改后的 URL 当作普通的请求发送给服务器,所以需要服务器进行相应的配置,具体参考使用对应的Server的解决方案

技术背景

PySAGES是一款可以使用GPU加速的增强采样插件,它可以直接对接到
OpenMM
上进行增强采样分子动力学模拟,这里我们测试一下相关的安装,并尝试跑一个简单的增强采样示例。

安装PySAGES

PySAGES本身可以使用pip进行安装:

python3 -m pip install git+https://github.com/SSAGESLabs/PySAGES.git
$ python3 -m pip install git+https://github.com/SSAGESLabs/PySAGES.git
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting git+https://github.com/SSAGESLabs/PySAGES.git
  Cloning https://github.com/SSAGESLabs/PySAGES.git to /tmp/pip-req-build-1fcvtmpb
  Running command git clone --filter=blob:none --quiet https://github.com/SSAGESLabs/PySAGES.git /tmp/pip-req-build-1fcvtmpb
  Resolved https://github.com/SSAGESLabs/PySAGES.git to commit 5f5bfc7ab97c8027bb60eedd65cdcd66b5556b57
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: cython in /home/dechin/.local/lib/python3.10/site-packages (from pysages==0.5.0) (3.0.11)
Collecting dill (from pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl (119 kB)
Requirement already satisfied: jax>=0.3.5 in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from pysages==0.5.0) (0.3.25)
Collecting plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4 (from pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/56/48/253352df240f5f1d4226f757e4107344bc7f49a4f84ba7d1affb5916d622/plum_dispatch-2.5.3-py3-none-any.whl (42 kB)
Collecting numba (from pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/79/58/cb4ac5b8f7ec64200460aef1fed88258fb872ceef504ab1f989d2ff0f684/numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.7/3.7 MB 1.3 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.20 in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from jax>=0.3.5->pysages==0.5.0) (1.24.3)
Requirement already satisfied: opt-einsum in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from jax>=0.3.5->pysages==0.5.0) (3.3.0)
Requirement already satisfied: scipy>=1.5 in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from jax>=0.3.5->pysages==0.5.0) (1.10.0)
Requirement already satisfied: typing-extensions in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from jax>=0.3.5->pysages==0.5.0) (4.11.0)
Collecting beartype>=0.16.2 (from plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4->pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/64/69/f6db6e4cb2fe2f887dead40b76caa91af4844cb647dd2c7223bb010aa416/beartype-0.19.0-py3-none-any.whl (1.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.0/1.0 MB 1.4 MB/s eta 0:00:00
Collecting rich>=10.0 (from plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4->pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl (242 kB)
Collecting llvmlite<0.44,>=0.43.0dev0 (from numba->pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/21/2ffbab5714e72f2483207b4a1de79b2eecd9debbf666ff4e7067bcc5c134/llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.9/43.9 MB 1.6 MB/s eta 0:00:00
Collecting markdown-it-py>=2.2.0 (from rich>=10.0->plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4->pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages (from rich>=10.0->plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4->pysages==0.5.0) (2.15.1)
Collecting mdurl~=0.1 (from markdown-it-py>=2.2.0->rich>=10.0->plum-dispatch!=2.0.0,!=2.0.1,>=1.5.4->pysages==0.5.0)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Building wheels for collected packages: pysages
  Building wheel for pysages (pyproject.toml) ... done
  Created wheel for pysages: filename=pysages-0.5.0-py3-none-any.whl size=117796 sha256=d9d97db55522297ba4ec17a6680790e0c13e87d92ea35e92a00b4055b7bc47b7
  Stored in directory: /tmp/pip-ephem-wheel-cache-zec6nip9/wheels/85/08/28/c73436bba0d28b37f2bbcf4081cdaa187aa06eef51e869c8a1
Successfully built pysages
Installing collected packages: mdurl, llvmlite, dill, beartype, numba, markdown-it-py, rich, plum-dispatch, pysages
Successfully installed beartype-0.19.0 dill-0.3.9 llvmlite-0.43.0 markdown-it-py-3.0.0 mdurl-0.1.2 numba-0.60.0 plum-dispatch-2.5.3 pysages-0.5.0 rich-13.9.4

安装测试

看起来是安装成功了,跑一个简单的用例试一试。先准备一个简单的pdb文件:

input.pdb
CRYST1    0.000    0.000    0.000  90.00  90.00  90.00 P 1           1
ATOM      1  H1  ACE A   1      -1.838  -6.570  -0.492  0.00  0.00
ATOM      2  CH3 ACE A   1      -0.764  -6.587  -0.283  0.00  0.00
ATOM      3  H2  ACE A   1      -0.392  -7.533  -0.746  0.00  0.00
ATOM      4  H3  ACE A   1      -0.592  -6.446   0.740  0.00  0.00
ATOM      5  C   ACE A   1      -0.006  -5.404  -0.828  0.00  0.00
ATOM      6  O   ACE A   1      -0.544  -4.619  -1.673  0.00  0.00
ATOM      7  N   ALA A   2       1.278  -5.323  -0.423  0.00  0.00
ATOM      8  H   ALA A   2       1.622  -5.845   0.368  0.00  0.00
ATOM      9  CA  ALA A   2       2.284  -4.164  -0.399  0.00  0.00
ATOM     10  HA  ALA A   2       2.098  -3.653   0.505  0.00  0.00
ATOM     11  CB  ALA A   2       3.651  -4.787  -0.566  0.00  0.00
ATOM     12  HB1 ALA A   2       4.274  -4.031  -0.972  0.00  0.00
ATOM     13  HB2 ALA A   2       3.977  -5.106   0.419  0.00  0.00
ATOM     14  HB3 ALA A   2       3.697  -5.612  -1.274  0.00  0.00
ATOM     15  C   ALA A   2       1.995  -3.152  -1.576  0.00  0.00
ATOM     16  O   ALA A   2       1.544  -2.065  -1.221  0.00  0.00
ATOM     17  N   NME A   3       2.255  -3.614  -2.845  0.00  0.00
ATOM     18  H   NME A   3       2.788  -4.485  -2.929  0.00  0.00
ATOM     19  CH3 NME A   3       1.991  -2.802  -4.055  0.00  0.00
ATOM     20 HH31 NME A   3       2.561  -1.891  -3.988  0.00  0.00
ATOM     21 HH32 NME A   3       1.897  -3.419  -4.937  0.00  0.00
ATOM     22 HH33 NME A   3       0.985  -2.388  -3.930  0.00  0.00
END

然后在上一篇文章中介绍的
OpenMM基础案例
的基础上增加一个PySAGES的MetaDynamics案例:

from openmm.app import PDBFile, ForceField, Simulation, PDBReporter, StateDataReporter, HBonds
from openmm import LangevinMiddleIntegrator
from openmm.unit import nanometer, kelvin, picoseconds, picosecond, BOLTZMANN_CONSTANT_kB, AVOGADRO_CONSTANT_NA, kilojoules_per_mole

import pysages
from pysages.colvars import DihedralAngle
from numpy import pi
from pysages.methods import Metadynamics, MetaDLogger

kB = BOLTZMANN_CONSTANT_kB * AVOGADRO_CONSTANT_NA
kB = kB.value_in_unit(kilojoules_per_mole / kelvin)

def NVT(pdb_name='input.pdb', pdb_out='output.pdb', ff='amber14-all.xml', log_file='log.dat'):
    pdb = PDBFile(pdb_name)
    forcefield = ForceField(ff)
    system = forcefield.createSystem(pdb.topology, nonbondedCutoff=1*nanometer, constraints=HBonds)

    integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.004*picoseconds)
    simulation = Simulation(pdb.topology, system, integrator)
    simulation.context.setPositions(pdb.positions)

    simulation.minimizeEnergy()
    simulation.reporters.append(PDBReporter(pdb_out, 1000))
    simulation.reporters.append(StateDataReporter(log_file, 1000, step=True, potentialEnergy=True, temperature=True, volume=True))
    return simulation

def MetaD(hills_file="hills.dat", time_steps=10000):
    cvs = [DihedralAngle([4, 6, 8, 14]), DihedralAngle([6, 8, 14, 16])]
    height = 1.2  # kJ/mol
    sigma = [0.35, 0.35]  # radians
    deltaT = 5000
    stride = 500
    ngauss = time_steps // stride + 1
    grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(50, 50), periodic=True)
    method = Metadynamics(cvs, height, sigma, stride, ngauss, deltaT=deltaT, kB=kB, grid=grid)
    callback = MetaDLogger(hills_file, stride)
    run_result = pysages.run(method, NVT, time_steps, callback)
    result = pysages.analyze(run_result)
    metapotential = result["metapotential"]
    return metapotential

if __name__ == '__main__':
    potential = MetaD()
    print (potential)

发生了一个报错:

Traceback (most recent call last):
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 43, in <module>
    potential = MetaD()
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 37, in MetaD
    run_result = pysages.run(method, NVT, time_steps, callback)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in run
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in <listcomp>
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 218, in submit_work
    return executor.submit(
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/utils.py", line 33, in submit
    future.set_result(fn(*args, **kwargs))
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 324, in _run_replica
    return run(method, *args, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 371, in _run
    sampling_context = SamplingContext(method, context_generator, callback, context_args)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/core.py", line 101, in __init__
    backend = import_module("." + self._backend_name, package="pysages.backends")
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/openmm.py", line 8, in <module>
    import openmm_dlext as dlext
ModuleNotFoundError: No module named 'openmm_dlext'

提示是需要安装一个
openmm_dlext
的插件。因为这个插件只有一个Github仓库,没有太多的文档,也没有介绍怎么安装的。我测试过下载源码下来,
cmake&&make install
去编译构建,但是又会有很多其他的报错提示要处理,最终我采取的方案是使用conda安装:

$ conda install conda-forge::openmm-dlext

安装完成后再次运行上面的案例,又有一个新的报错:

$ python3 test_openmm.py 
Traceback (most recent call last):
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 43, in <module>
    potential = MetaD()
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 37, in MetaD
    run_result = pysages.run(method, NVT, time_steps, callback)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in run
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in <listcomp>
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 218, in submit_work
    return executor.submit(
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/utils.py", line 33, in submit
    future.set_result(fn(*args, **kwargs))
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 324, in _run_replica
    return run(method, *args, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 371, in _run
    sampling_context = SamplingContext(method, context_generator, callback, context_args)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/core.py", line 102, in __init__
    self.sampler = backend.bind(self, callback, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/openmm.py", line 194, in bind
    force.add_to(context)  # OpenMM will handle the lifetime of the force
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/openmm_dlext/__init__.py", line 29, in add_to
    self.__alt__.add_to(_to_capsule(context), _to_capsule(system))
RuntimeError: Unsupported platform

关于这个报错,我在
openmm_dlext的Issue
里面找到了相应的解决方案,说是要手动配置一下CUDA Platform,于是修改一下代码:

from openmm.app import PDBFile, ForceField, Simulation, PDBReporter, StateDataReporter, HBonds
from openmm import LangevinMiddleIntegrator, Platform
from openmm.unit import nanometer, kelvin, picoseconds, picosecond, BOLTZMANN_CONSTANT_kB, AVOGADRO_CONSTANT_NA, kilojoules_per_mole

import pysages
from pysages.colvars import DihedralAngle
from numpy import pi
from pysages.methods import Metadynamics, MetaDLogger

openmm_platform = Platform.getPlatformByName('CUDA')
kB = BOLTZMANN_CONSTANT_kB * AVOGADRO_CONSTANT_NA
kB = kB.value_in_unit(kilojoules_per_mole / kelvin)

def NVT(pdb_name='input.pdb', pdb_out='output.pdb', ff='amber14-all.xml', log_file='log.dat', platform=openmm_platform):
    pdb = PDBFile(pdb_name)
    forcefield = ForceField(ff)
    system = forcefield.createSystem(pdb.topology, nonbondedCutoff=1*nanometer, constraints=HBonds)

    integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.004*picoseconds)
    simulation = Simulation(pdb.topology, system, integrator, platform=platform)
    simulation.context.setPositions(pdb.positions)

    simulation.minimizeEnergy()
    simulation.reporters.append(PDBReporter(pdb_out, 1000))
    simulation.reporters.append(StateDataReporter(log_file, 1000, step=True, potentialEnergy=True, temperature=True, volume=True))
    return simulation

def MetaD(hills_file="hills.dat", time_steps=10000):
    cvs = [DihedralAngle([4, 6, 8, 14]), DihedralAngle([6, 8, 14, 16])]
    height = 1.2  # kJ/mol
    sigma = [0.35, 0.35]  # radians
    deltaT = 5000
    stride = 500
    ngauss = time_steps // stride + 1
    grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(50, 50), periodic=True)
    method = Metadynamics(cvs, height, sigma, stride, ngauss, deltaT=deltaT, kB=kB, grid=grid)
    callback = MetaDLogger(hills_file, stride)
    run_result = pysages.run(method, NVT, time_steps, callback)
    result = pysages.analyze(run_result)
    metapotential = result["metapotential"]
    return metapotential

if __name__ == '__main__':
    potential = MetaD()
    print (potential)

再次运行,又出现一个新的报错:

Traceback (most recent call last):
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 44, in <module>
    potential = MetaD()
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 38, in MetaD
    run_result = pysages.run(method, NVT, time_steps, callback)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in run
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in <listcomp>
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 218, in submit_work
    return executor.submit(
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/utils.py", line 33, in submit
    future.set_result(fn(*args, **kwargs))
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 324, in _run_replica
    return run(method, *args, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 371, in _run
    sampling_context = SamplingContext(method, context_generator, callback, context_args)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/core.py", line 78, in __init__
    context = context_generator(**context_args)
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 20, in NVT
    simulation = Simulation(pdb.topology, system, integrator, platform=platform)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/openmm/app/simulation.py", line 104, in __init__
    self.context = mm.Context(self.system, self.integrator, platform)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/openmm/openmm.py", line 19511, in __init__
    _openmm.Context_swiginit(self, _openmm.new_Context(*args))
openmm.OpenMMException: Error loading CUDA module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION (222)

好,这次是CUDA版本号不支持,类似的问题在一条
openmm的issue
中有讨论过,需要检查一下自己本地cudatoolkit的配置信息:

$ conda list | grep cudatoolkit
$

这里发现在这个虚拟环境里面没有配置cudatoolkit,需要再装一个:

$ conda install -c conda-forge cudatoolkit=11.6
Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Collecting package metadata (repodata.json): done
Solving environment: done


==> WARNING: A newer version of conda exists. <==
  current version: 23.1.0
  latest version: 24.11.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.11.0



## Package Plan ##

  environment location: /home/dechin/anaconda3/envs/jax

  added / updated specs:
    - cudatoolkit=11.6


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cudatoolkit-11.6.2         |      hfc3e2af_13       598.8 MB  conda-forge
    openmm-8.1.1               |  py310h358ce72_1        10.8 MB  conda-forge
    openmm-dlext-0.2.1         |  py310h552f1b7_8         115 KB  conda-forge
    ------------------------------------------------------------
                                           Total:       609.8 MB

The following NEW packages will be INSTALLED:

  cudatoolkit        conda-forge/linux-64::cudatoolkit-11.6.2-hfc3e2af_13 

The following packages will be REMOVED:

  cuda-nvrtc-12.4.127-h99ab3db_1
  cuda-version-12.4-hbda6634_3
  libcufft-11.2.1.3-h99ab3db_1

The following packages will be UPDATED:

  openssl            anaconda/pkgs/main::openssl-3.0.15-h5~ --> conda-forge::openssl-3.4.0-hb9d3cd8_0 

The following packages will be SUPERSEDED by a higher-priority channel:

  ca-certificates    anaconda/pkgs/main::ca-certificates-2~ --> conda-forge::ca-certificates-2024.8.30-hbcca054_0 
  openmm             anaconda/cloud/conda-forge::openmm-8.~ --> conda-forge::openmm-8.1.1-py310h358ce72_1 

The following packages will be DOWNGRADED:

  openmm-dlext                        0.2.1-py310hcb41016_8 --> 0.2.1-py310h552f1b7_8 


Proceed ([y]/n)? y


Downloading and Extracting Packages
                                                                                                                                                                        
Preparing transaction: done                                                                                                                                             
Verifying transaction: done                                                                                                                                             
Executing transaction: | By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html

done

再次执行上面的程序,报错+1:

$ python3 test_openmm.py 
Traceback (most recent call last):
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 44, in <module>
    potential = MetaD()
  File "/home/dechin/projects/gitee/dechin/tests/test_openmm.py", line 38, in MetaD
    run_result = pysages.run(method, NVT, time_steps, callback)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in run
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 230, in <listcomp>
    futures = [submit_work(ex, method, callback) for _ in range(config.copies)]
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 218, in submit_work
    return executor.submit(
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/utils.py", line 33, in submit
    future.set_result(fn(*args, **kwargs))
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 324, in _run_replica
    return run(method, *args, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/methods/core.py", line 371, in _run
    sampling_context = SamplingContext(method, context_generator, callback, context_args)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/core.py", line 102, in __init__
    self.sampler = backend.bind(self, callback, **kwargs)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/openmm.py", line 197, in bind
    helpers, restore, bias = build_helpers(sampling_context.view, sampling_method)
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/openmm.py", line 135, in build_helpers
    sync_forces, view = utils.cupy_helpers()
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/site-packages/pysages/backends/utils.py", line 21, in cupy_helpers
    cupy = importlib.import_module("cupy")
  File "/home/dechin/anaconda3/envs/jax/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1004, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'cupy'

不过这个看起来好处理,就是少装了一个cupy的依赖,稳妥起见,我们还是选择使用conda来安装cupy:

$ conda install -c conda-forge cupy -y
Collecting package metadata (current_repodata.json): done
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: - 
failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): done
Solving environment: done


==> WARNING: A newer version of conda exists. <==
  current version: 23.1.0
  latest version: 24.11.0

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.11.0



## Package Plan ##

  environment location: /home/dechin/anaconda3/envs/jax

  added / updated specs:
    - cupy


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    cuda-version-11.6          |       hca96458_3          21 KB  conda-forge
    cupy-13.3.0                |  py310h189a05f_2         347 KB  conda-forge
    cupy-core-13.3.0           |  py310h5da974a_2        42.9 MB  conda-forge
    fastrlock-0.8.2            |  py310hc6cd4ac_2          37 KB  conda-forge
    ------------------------------------------------------------
                                           Total:        43.3 MB

The following NEW packages will be INSTALLED:

  cuda-version       conda-forge/noarch::cuda-version-11.6-hca96458_3 
  cupy               conda-forge/linux-64::cupy-13.3.0-py310h189a05f_2 
  cupy-core          conda-forge/linux-64::cupy-core-13.3.0-py310h5da974a_2 
  fastrlock          conda-forge/linux-64::fastrlock-0.8.2-py310hc6cd4ac_2 



Downloading and Extracting Packages
                                                                                                                                                                        
Preparing transaction: done                                                                                                                                             
Verifying transaction: done                                                                                                                                             
Executing transaction: done                 

然后再执行测试程序:

$ python3 test_openmm.py 
<CompiledFunction of <function analyze.<locals>.<lambda> at 0x7fc1a04aba30>>

同时会在执行路径下生成
log.dat
文件和
hills.dat
文件如下:

log.dat
#"Step","Potential Energy (kJ/mole)","Temperature (K)","Box Volume (nm^3)"
1000,-39.254608154296875,284.2437497410935,8.0
2000,-18.68511962890625,318.9596355900776,8.0
3000,-30.86761474609375,289.998240351891,8.0
4000,-21.921295166015625,338.6328004320157,8.0
5000,-31.451812744140625,245.66209355694957,8.0
6000,-32.077880859375,182.1505555238515,8.0
7000,-56.050750732421875,252.68200201473644,8.0
8000,-27.819427490234375,289.4957194622587,8.0
9000,-36.86553955078125,271.0313362334861,8.0
10000,-12.531005859375,254.41902934626566,8.0
hills.dat
500	-1.3592318296432495	1.5952203273773193	0.35	0.35	1.2
1000	-1.1732323169708252	0.8145138621330261	0.35	0.35	1.1973907824735082
1500	-1.433311104774475	1.7097197771072388	0.35	0.35	1.1671557288100634
2000	-1.114228367805481	2.042632579803467	0.35	0.35	1.179103510697602
2500	-1.1403875350952148	0.9936402440071106	0.35	0.35	1.1603072043059584
3000	-1.2672390937805176	0.40286365151405334	0.35	0.35	1.173835519022326
3500	-1.302258014678955	1.4455255270004272	0.35	0.35	1.1211946752964734
4000	-1.4070658683776855	1.013744592666626	0.35	0.35	1.121531633082655
4500	-2.6735711097717285	2.6055266857147217	0.35	0.35	1.1999969285502206
5000	-2.8140780925750732	2.9895386695861816	0.35	0.35	1.1809462925907876
5500	-2.6453146934509277	2.5226593017578125	0.35	0.35	1.1505253387717271
6000	-2.476658344268799	2.7877397537231445	0.35	0.35	1.1410532890823843
6500	-2.7321791648864746	-2.84220814704895	0.35	0.35	1.1797609616409976
7000	-1.596192479133606	1.0611979961395264	0.35	0.35	1.1126547940366505
7500	-1.3219820261001587	0.2645364999771118	0.35	0.35	1.1460450294138773
8000	-1.5232703685760498	1.4845924377441406	0.35	0.35	1.0865578670844307
8500	-1.3037762641906738	0.6937571167945862	0.35	0.35	1.076390175717164
9000	-1.3598891496658325	2.0672760009765625	0.35	0.35	1.1276490649082718
9500	-2.420367479324341	2.7348878383636475	0.35	0.35	1.1032026693365438

喜大普奔,PySAGES环境部署完毕!

案例测试

还是沿用上面的
input.pdb
案例,这里我们测试一个MetaDynamics的FES,增加了一个analyse的plot模块:

from openmm.app import PDBFile, ForceField, Simulation, PDBReporter, StateDataReporter, HBonds
from openmm import LangevinMiddleIntegrator, Platform
from openmm.unit import nanometer, kelvin, picoseconds, picosecond, BOLTZMANN_CONSTANT_kB, AVOGADRO_CONSTANT_NA, kilojoules_per_mole
from sys import stdout

import pysages
from pysages.colvars import DihedralAngle
from numpy import pi
from pysages.methods import Metadynamics, MetaDLogger
from pysages.approxfun import compute_mesh

import matplotlib.pyplot as plt

openmm_platform = Platform.getPlatformByName('CUDA')
kB = BOLTZMANN_CONSTANT_kB * AVOGADRO_CONSTANT_NA
kB = kB.value_in_unit(kilojoules_per_mole / kelvin)
T = 300*kelvin
dt = 0.004*picoseconds

def NVT(pdb_name='input.pdb', pdb_out='output.pdb', ff='amber14-all.xml', log_file='log.dat', platform=openmm_platform):
    pdb = PDBFile(pdb_name)
    forcefield = ForceField(ff)
    system = forcefield.createSystem(pdb.topology, nonbondedCutoff=1*nanometer, constraints=HBonds)

    integrator = LangevinMiddleIntegrator(T, 1/picosecond, dt)
    simulation = Simulation(pdb.topology, system, integrator, platform=platform)
    simulation.context.setPositions(pdb.positions)

    simulation.minimizeEnergy()
    simulation.reporters.append(PDBReporter(pdb_out, 1000))
    simulation.reporters.append(StateDataReporter(stdout, 1000, step=True, potentialEnergy=True, temperature=True, volume=True))
    simulation.reporters.append(StateDataReporter(log_file, 1000, step=True, potentialEnergy=True, temperature=True, volume=True))
    return simulation

def plot_grid(metapotential, method):
    plot_grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(64, 64), periodic=True)
    xi = (compute_mesh(plot_grid) + 1) / 2 * plot_grid.size + plot_grid.lower
    alpha = (
        1
        if method.deltaT is None
        else (T.value_in_unit(kelvin) + method.deltaT) / method.deltaT
    )
    kT = kB * T.value_in_unit(kelvin)
    A = metapotential(xi) * -alpha / kT
    A = A - A.min()
    A = A.reshape(plot_grid.shape)
    # plot and save free energy to a PNG file
    fig, ax = plt.subplots(dpi=120)

    im = ax.imshow(A, interpolation="bicubic", origin="lower", extent=[-pi, pi, -pi, pi])
    ax.contour(A, levels=12, linewidths=0.75, colors="k", extent=[-pi, pi, -pi, pi])
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")

    cbar = plt.colorbar(im)
    cbar.ax.set_ylabel(r"$A~[k_{B}T]$", rotation=270, labelpad=20)

    fig.savefig("Figure.png", dpi=fig.dpi)

def MetaD(hills_file="hills.dat", time_steps=50000):
    cvs = [DihedralAngle([4, 6, 8, 14]), DihedralAngle([6, 8, 14, 16])]
    height = 2.0  # kJ/mol
    sigma = [0.2, 0.2]  # radians
    deltaT = 5000
    stride = 50
    ngauss = time_steps // stride + 1
    grid = pysages.Grid(lower=(-pi, -pi), upper=(pi, pi), shape=(50, 50), periodic=True)
    method = Metadynamics(cvs, height, sigma, stride, ngauss, deltaT=deltaT, kB=kB, grid=grid)
    callback = MetaDLogger(hills_file, stride)
    run_result = pysages.run(method, NVT, time_steps, callback)
    result = pysages.analyze(run_result)
    metapotential = result["metapotential"]
    plot_grid(metapotential, method)
    return result

if __name__ == '__main__':
    res = MetaD()

因为在OpenMM的Simulation的Report中我们增加了一个
stdout
,因此会同时在屏幕上输出结果,也会在相应的
log.dat
文件中保存结果,运行输出如下:

#"Step","Potential Energy (kJ/mole)","Temperature (K)","Box Volume (nm^3)"
1000,9.73681640625,377.6993364201515,8.0
2000,-23.971282958984375,298.44721588786604,8.0
3000,-15.113677978515625,213.76058649837066,8.0
4000,-9.906219482421875,444.7447921758242,8.0
5000,-35.83831787109375,299.89809403902336,8.0
6000,-33.826202392578125,313.4731859605099,8.0
7000,-19.394073486328125,337.10699269365875,8.0
8000,-45.882415771484375,250.66251735991736,8.0
9000,-14.17413330078125,358.22016011687015,8.0
10000,-29.421051025390625,246.90072858113598,8.0
11000,-19.7567138671875,301.9975514069083,8.0
12000,-32.948822021484375,367.195135668361,8.0
13000,-9.27825927734375,289.7929305791173,8.0
14000,-30.180389404296875,309.66557282887885,8.0
15000,-2.736083984375,302.5309003205113,8.0
16000,-32.576629638671875,291.86829747083937,8.0
17000,-14.334503173828125,231.850481046364,8.0
18000,-20.755645751953125,298.57497669296873,8.0
19000,-43.75299072265625,306.65343794873587,8.0
20000,33.35467529296875,258.82936068957366,8.0
21000,-1.04156494140625,339.65646518408494,8.0
22000,0.01190185546875,197.8572390770094,8.0
23000,5.273040771484375,289.2310517046787,8.0
24000,-14.901947021484375,383.9835521646287,8.0
25000,-0.839019775390625,268.7104144595147,8.0
26000,-23.747772216796875,222.84395037451839,8.0
27000,-27.284759521484375,285.9093985100245,8.0
28000,-23.12164306640625,248.21416090812198,8.0
29000,10.6822509765625,319.4106894537426,8.0
30000,-16.64678955078125,304.24748131130184,8.0
31000,-6.5423583984375,329.8362141299685,8.0
32000,-3.944793701171875,333.3584976075751,8.0
33000,-29.894744873046875,355.53462625307105,8.0
34000,-22.54876708984375,366.93298893561547,8.0
35000,-14.81097412109375,330.12835522481674,8.0
36000,-45.39825439453125,363.9047710837139,8.0
37000,-5.33160400390625,355.4129749973852,8.0
38000,-19.806365966796875,361.73243838698073,8.0
39000,-13.85650634765625,411.9526625662002,8.0
40000,1.711639404296875,225.61063301965956,8.0
41000,-34.70196533203125,389.73863301467037,8.0
42000,-33.63153076171875,307.1604571229406,8.0
43000,-37.86602783203125,277.5274210978626,8.0
44000,-3.5263671875,248.0723224072663,8.0
45000,21.574676513671875,294.5427172838524,8.0
46000,-31.097808837890625,302.0992611330069,8.0
47000,-23.125152587890625,307.7661366778404,8.0
48000,8.406402587890625,174.53798831979185,8.0
49000,-19.694549560546875,380.88297517197196,8.0
50000,12.754608154296875,380.40854584876394,8.0

这里输出的FES被保存成了一个图片,内容为:

这就是PySAGES的Well-Tempered MetaDynamics输出的FES结果。

工作流

PySAGES的工作流是这样的:

这里我们的backend使用的就是OpenMM了,大致的流程是,通过PySAGES来构建对应backend的Simulation对象,然后启动Simulation。每一次需要
update bias force
的时候,从backend传回来一个force,在PySAGES层面加入bias force然后传回backend。循环迭代,直至time step截止。

PySAGES自带了一些增强采样的方法和一些定义好的CV,当然,因为其基于Jax-Python开发,因此自定义一个新的CV在形式上也非常的简洁:

最关键的,一般这种外接的增强采样软件会很大程度上影响到整体分子模拟的性能,甚至很可能成为Bottleneck。而根据PySAGES官方给出的profile结果来看:

MetaDynamics部分的时间占比并没有成为Bottleneck,从时长比例上来说,这个性能表现是非常突出的。

总结概要

本文主要介绍了增强采样外接软件PySAGES的基本安装和使用方法,重点是安装过程中没有写清楚的一些环境依赖和可能出现的问题介绍,以及相应的解决方案。并简单的梳理了一下PySAGES软件的工作流机制,其能够做到Zero Copy,并使得Enhanced Sampling不再成为很多模拟的Bottleneck,这是一个相当出色的结果。

版权声明

本文首发链接为:
https://www.cnblogs.com/dechinphy/p/pysages.html

作者ID:DechinPhy

更多原著文章:
https://www.cnblogs.com/dechinphy/

请博主喝咖啡:
https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://pysages.readthedocs.io/en/latest/installation.html

简介

在上一篇文章《
机器学习:线性回归(上)
》中讨论了二维数据下的线性回归及求解方法,本节中我们将进一步的将其推广至高维情形。

章节安排

  1. 背景介绍
  2. 最小二乘法
  3. 梯度下降法
  4. 程序实现

一、背景介绍

1.1 超平面
\(L\)
的定义


定义在
\(D\)
维空间中的超平面
\(L\)
的方程为:

\[\begin{align*}
L:\text w^T\text x+b=0 \tag{1.1}
\end{align*}
\]

其中:
\(\text w^T=[w_0,w_1,\dots,w_D]\)
为不同维度的系数或权重,
\(\text x^T=[x_0,x_1,\dots ,x_D]\)
为数据样本的特征向量。

在该定义中,超平面
\(L\)
是由是由法向量
\(w\)
和偏置项
\(b\)
决定的。具体来说,超平面
\(L\)

\(D\)
维空间划分为两个半空间,一个半空间满足
\(\text w^T\text x+b>0\)
,另一个半空间满足
\(\text w^T\text x+b<0\)
,式
\((1.1)\)
称为矩阵表示法,也可以用标量表示法表示为:

\[\begin{align*}
L:\sum_{i=1}^{D}w_ix_i+b=w_1x_1+w_2x_2+\cdots+w_Dx_D+b=0
\tag{1.2}
\end{align*}
\]

在一些情况下,也会将偏置项
\(b\)
引入向量中,该方法分别对权重
\(w\)
和特征值
\(x\)
做增广:

\[\begin{align*}
x^T&=[1,x_1,x_2,\dots,x_D]\\
w^T&=[b,w_1,w_2,\dots,w_D]
\end{align*}
\]

在此基础上,超平面
\(L\)
的定义可以简化为:

\[\begin{align*}
L:\text w^T\text x=0 \tag{1.3}
\end{align*}
\]

有时也简称

\[\begin{align*}
L(\text x)=0 \tag{1.4}
\end{align*}
\]

示例

为方便读者理解,这里给出一个从二维的直线方程到超平面方程
\(L\)
的转换

\[\begin{align*}
y&=kx+b\\
kx-y+b&=0\\
\begin{bmatrix}
b&k&-1
\end{bmatrix}
\cdot
\begin{bmatrix}
1\\
x\\
y
\end{bmatrix}
&=0
\end{align*}
\]

1.2 高维线性回归


在高维线性回归任务中,采样数据的形式为
\(S=\{\text X,\text y\}\)
,其中
\(X\)
称为采样数据,为
\(N\times D\)
的矩阵,
\(y\)
称为标签数据,更具体的有:

\[\text X^T=[\text x_0,\text x_1, \dots, \text x_N], \text x_i=[x_{i1},x_{i2},\dots,x_{iD}], \text x_i \in \mathbb{R}^D
\]

\[\text y^T =[y_0,y_1,\dots,y_N]
\]

在高维数据的回归任务中,我们的目标是找到一个权重
\(\text w\)
,使得其能够对特征数据
\(\text X\)
给出预测
\(\hat{\text y}\)

\[\hat{\text y}=\text X \text w
\]

其中:
\(\text w^T=[w_1,\dots,w_D]\)
是大小为
\(D*1\)
的向量。
同时,我们可以定义
均方根误差(MSE)
如下:

\[\begin{align*}
\text{MSE}=\big \| \text y-\text X\text w\big\|_2^2
\end{align*}
\]

其中
\(\|\cdot\|_2\)
为二范数,或欧几里得距离。
线性回归的目标为,最小化损失,下面我们将从最小二乘法和梯度下降法两个角度实现线性回归。

二、最小二乘法


最小二乘法(Least Squares Method)是一种广泛使用的线性回归问题的求解方法,其核心思想是,均方根误差MSE关于权重
\(w\)
的偏导为0时所求得的
\(w\)
为最优解,故对MSE化简如下:

\[\begin{align*}
\text{MSE}&=\big \| \text y-\text X\text w\big\|_2^2\\
&=(\text y-\text X\text w)^T(\text y-\text X\text w)\\
&=\text y^T\text y-\text w^T\text X^T \text y-\text y\text X\text w+\text w^T \text X^T \text X \text w\\
\end{align*}
\]

由于
\(\text w^T\text X^T \text y\)

\(\text y\text X\text w\)
是标量,其数值相等,故有:

\[\begin{align*}
\text{MSE}&=\text y^T\text y-2\text w^T\text X^T \text y+\text w^T \text X^T \text X \text w
\end{align*}
\]


\(\text {MSE}\)
关于
\(\text w\)
的偏导得:

\[\begin{align*}
\frac{\partial\text{MSE}}{\partial\text w}&=-2\text X^T\text y+2 \text X^T \text X \text w
\end{align*}
\]

另偏导等于
\(0\)
得:

\[\begin{align*}
\text X^T\text y&= \text X^T \text X \text w \tag{2.1}
\end{align*}
\]

该方程称为
正规方程(Normal Equation)
,解该方程可得:

\[\text w =(\text X^T\text X)^{-1}\text X^T \text y
\]

2.1 最小二乘法缺点

以下是最小二乘法的主要缺点:

矩阵逆计算的复杂性
最小二乘法的解析解需要计算矩阵
\(\text X^T \text X\)
的逆矩阵:

\[\text w = (\text X^T \text X)^{-1} \text X^T \text y \tag{2.2}
\]

在高维情况下(即特征数量
\(d\)
较大),计算
\(\text X^T \text X\)
的逆矩阵的计算复杂度很高,甚至可能不可行。具体来说:

  • 计算
    \(\text X^T \text X\)
    的时间复杂度为
    \(O(n d^2)\)
    ,其中
    \(n\)
    是样本数量,
    \(d\)
    是特征数量。
  • 计算矩阵逆的时间复杂度为
    \(O(d^3)\)

因此,当
\(d\)
很大时,计算逆矩阵的代价非常高。

矩阵不可逆问题

在高维情况下,特征数量
\(d\)
可能大于样本数量
\(n\)
,此时矩阵
\(\text X^T \text X\)
可能是不可逆的(即奇异矩阵),这意味着无法直接计算其逆矩阵。此外,即使矩阵可逆,也可能因为浮点数精度问题导致计算结果不稳定。

对异常值敏感

最小二乘法对异常值非常敏感。因为最小二乘法最小化的是平方误差,所以异常值会对模型的拟合产生较大的影响。这可能导致模型的泛化能力下降。

不适用于稀疏数据

对于稀疏数据(即特征矩阵中有大量零元素),最小二乘法的计算效率较低。稀疏数据通常更适合使用稀疏矩阵的优化方法,如 Lasso 或 Ridge 回归。

过拟合问题

如果没有正则化,最小二乘法容易过拟合,尤其是在特征数量远大于样本数量的情况下。过拟合会导致模型在训练集上表现很好,但在测试集上表现很差。

总结

尽管最小二乘法在许多情况下是一个简单有效的线性回归求解方法,但它也存在一些明显的缺点,特别是在高维数据和复杂情况下。为了克服这些缺点,可以考虑使用其他优化方法,如梯度下降、岭回归(Ridge Regression)、Lasso 回归等,这些方法在计算效率、对异常值的鲁棒性和防止过拟合方面有更好的表现。

三、梯度下降法


梯度下降法是一种常用的优化算法。通过迭代更新模型的参数,使得均方误差逐步减小,最终达到最优解。

对于单个样本
\(\{\text x_i, y_i\}\)
,其损失函数定义为:

\[J(\text w)=(y-\text x_i \text w)^2
\]

求其关于权重的偏导得:

\[\begin{align*}
\frac{\partial}{\partial \text w}J(\text w)&=\frac{\partial}{\partial \text w}(y-\text x_i\text w)^2\\
&=2(y-\text x\text w)\text x\tag{3.1}
\end{align*}
\]

故有参数修正公式如下:

\[\begin{align*}
\text w:=\text w -\lambda\cdot \frac{\partial J}{\partial \text w} \tag{3.2}
\end{align*}
\]

四、程序实现

4.1 生成测试数据


程序流程:

  1. 定义特征维数
    feature_num
    及点个数
    point_num
  2. 定义权重向量
    w
    ,特征数据
    X
    ,标签数据
    y
  3. 生成随机数,填充
    w

    X
  4. 定义误差向量
    error
    ,并用随机数填充
  5. 计算
    y
#include <iostream>
#include <vector>
#include <Eigen/Dense>

// Multiple linear regression data generation
namespace MLR {
    void gen(Eigen::VectorXd& w, Eigen::MatrixXd& X, Eigen::VectorXd& y) {
        if (w.rows() != X.cols()) {
            throw std::invalid_argument("Dimension mismatch: The number of rows in w must equal the number of columns in X.");
        }
        if (X.rows() != y.rows()) {
            throw std::invalid_argument("Dimension mismatch: The number of rows in X must equal the number of rows in y.");
        }

        w.setRandom();
        X.setRandom();

        Eigen::VectorXd error(y.rows());
        error.setRandom();
        error *= 0.02;

        y = X * w + error;

        return;
    }
}


int main() {
    const size_t point_num = 10;
    const size_t feature_num = 7;

    Eigen::VectorXd w(feature_num);
    Eigen::MatrixXd X(point_num, feature_num);
    Eigen::VectorXd y(point_num);

    MLR::gen(w, X, y);

    std::cout << "y =\n" << y << "\n";

    return 0;
}

4.2 最小二乘法实现:


程序流程:

  1. 构建向量
    wp
    用以存储计算结果
  2. 采用公式
    \((2.2)\)
    计算权重
    wp
  3. 输出
    w-wp
    以观察计算误差

Eigen库中求逆、求转置都需要以矩阵为主体,例如:
M.inverse()

M.transpose()

取名
wp
是因为Weight prediction的首字母。

void LSM(Eigen::VectorXd& w, Eigen::MatrixXd& X, Eigen::VectorXd& y) {
    if (w.rows() != X.cols()) {
        throw std::invalid_argument("Dimension mismatch: The number of rows in w must equal the number of columns in X.");
    }
    if (X.rows() != y.rows()) {
        throw std::invalid_argument("Dimension mismatch: The number of rows in X must equal the number of rows in y.");
    }

    w = (X.transpose() * X).inverse() * X.transpose() * y;
}

int main() {
    // ...

    Eigen::VectorXd wp(feature_num);

    LSM(wp, X, y);

    std::cout << "w_error =\n" << w-wp << "\n";

    return 0;
}

下图为程序输出结果,由该图可以看出,最小二乘法的估计较为准确。
description

4.3 梯度下降法实现


程序流程:

  1. 构建向量
    wp
    ,并初始化为随机权重。
  2. 每一个数据样本
    x
    ,依据公式
    \((3.2)\)
    更新一次权重。(
    GD_step
    函数功能)
  3. 重复步骤2,100次。
  4. 输出
    w-wp
    以观察计算误差

注意事项:

在该算法中,我们将样本的个数改为100个,即:
feature_num = 100

学习率过高会导致发散,详细参考上一篇文章:《
机器学习:线性回归(上)

下式子作用是将矩阵
X
的第
idx
行读取为列向量
Eigen::VectorXd x = X.row(idx);
这与我们的使用直觉不符,实际上应为行向量。为避免出错,在后续计算中应使用
x.transpose()
而非直接使用
x

有一种方法可以规避该问题,即使用点积(内积)进行计算。在代码中给出了相关的示例(注释部分)

void GD_step(Eigen::VectorXd& w, Eigen::MatrixXd& X, Eigen::VectorXd& y, const double& lambda) {
    if (w.rows() != X.cols()) {
        throw std::invalid_argument("Dimension mismatch: The number of rows in w must equal the number of columns in X.");
    }
    if (X.rows() != y.rows()) {
        throw std::invalid_argument("Dimension mismatch: The number of rows in X must equal the number of rows in y.");
    }

    for (size_t idx = 0; idx < X.rows(); ++idx) {
        Eigen::VectorXd x = X.row(idx);

        // 使用点积
        // Eigen::VectorXd gradient = 2 * (y(idx) - x.dot(w)) * x;

        // 因为 y-x*w是标量,且输出结果为VectorXd,因此最后的transpose是可去的。
        // Eigen::VectorXd gradient = 2 * (y(idx) - x.transpose() * w) * x.transpose();

        Eigen::VectorXd gradient = 2 * (y(idx) - x.transpose() * w) * x;

        w += lambda * gradient;
    }
}

int main() {
    const size_t point_num = 100;
    
    // ...

    Eigen::VectorXd wp(feature_num);
    wp.setRandom(); // 生成初始值

    double lambda = 2e-3;

    for (int _ = 0; _ < 100; ++_) {
        GD_step(wp, X, y, lambda);
    }

    std::cout << "w_error =\n" << w - wp << "\n";

    return 0;
}

下图为程序输出结果,由该图可以看出,梯度下降法的估计较为准确。
description