2024年7月

技术背景

在前面一篇
博客
中,我们介绍了使用Cython加速谐振势计算的方法。有了Cython对于计算过程更加灵活的配置(本质上是时间占用和空间占用的一种均衡),及其接近于C的性能,并且还最大程度上的保留了Python的编程语法特点,因此Cython确实是值得Python编程爱好者学习的一种加速手段。这里我们要介绍的是Cython与C语言相结合的一种方案,可以直接在pyx文件中加载C语言代码。

测试场景

我们测一个非常简单的场景————归约求和:

\[S=\sum_{i,j}A_{i,j}
\]

当然了,像这种基本运算,在Numpy中已经优化的非常极致了。所以,这里我们并不是要展现Cython在性能上的优势,而是Cython对于C语言和Python语言两者的兼容性。首先我们用C语言实现一个归约求和的简单函数:

// array_sum.c
double reduce_sum(int arr_len, double* arr){
    double s=0.0;
    int i;
    for (i=0; i<arr_len; i++){
        s = s + *arr;
        arr++;
    }
    return s;
}

这里我们使用了一个指针数组,然后用for循环进行遍历计算。在Cython中,我们可以使用extern来直接加载C语言中的这个函数:

# test_pointer.pyx
import numpy as np
cimport numpy as np

cdef extern from "array_sum.c":
    double reduce_sum(int arr_len, double* arr)

cpdef rsum(int arr_len, np.ndarray[np.float64_t, ndim=2, mode="c"] arr):
    cdef:
        double* arr_ptr = <double *>arr.data
        double res = 0.0
    res = reduce_sum(arr_len, arr_ptr)
    return res

这里加载了C语言中的
reduce_sum
函数,然后以Cython中定义的
rsum
函数作为一个接口,将传入的numpy数组的内存地址作为指针传给C语言中写好的函数。然后需要对这个pyx文件进行编译构建:

$ cythonize -i test_pointer.pyx

编译完成后会在当前路径下生成
*.c
文件和
*.so
文件:

$ ll | grep test_pointer  
-rw-r--r-- 1 root root  374450 Jul 25 14:52 test_pointer.c
-rwxr-xr-x 1 root root  234848 Jul 25 14:52 test_pointer.cpython-37m-x86_64-linux-gnu.so*
-rw-r--r-- 1 root root     347 Jul 25 15:02 test_pointer.pyx

调用Cython函数

我们可以开启一个Ipython,或者直接在Python脚本文件中调用Cython函数:

In [1]: import numpy as np

In [2]: from test_pointer import rsum

In [3]: num=10000

In [4]: x=np.random.random((num,num))

In [5]: x.shape
Out[5]: (10000, 10000)

In [6]: %timeit s=np.sum(x)
38.3 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [7]: %timeit rs=rsum(num*num,x)
51.7 ms ± 302 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [15]: np.sum(x)
Out[15]: 50003980.32921535

In [17]: rsum(num*num, x)
Out[17]: 50003980.32921728

经过测试,确实可以在Python中调用这个C语言实现的函数。当然,前面也提到过,Numpy对于这个简单的求和计算已经优化的非常好了,所以这里没有体现出性能上的优势,这里更多的是演示一个方法。

总结概要

这篇文章介绍了Python-Cython-C三种语言的简单耦合,以Cython为中间接口,实现Python数据传到C语言的后端执行相关计算。这就相当于可以在Python中调用C语言中的指针功能来进行跨维度的数组运算,至于性能依然存在优化空间,这里仅仅做一个简单的功能演示。

版权声明

本文首发链接为:
https://www.cnblogs.com/dechinphy/p/cython-c.html

作者ID:DechinPhy

更多原著文章:
https://www.cnblogs.com/dechinphy/

请博主喝咖啡:
https://www.cnblogs.com/dechinphy/gallery/image/379634.html

Rust 中 *、&、mut、&mut、ref、ref mut 的用法和区别


Rust
中,
*

ref

mut

&

ref mut
是用于处理引用、解引用和可变性的关键字和操作符,它们在不同的上下文中有不同的用法。

一、
*
解引用

*
属于操作符

1. 作用

用于解引用指针或引用,以访问其指向的值。
通过解引用,可以从指针或引用中获取实际的值。

2. 用法

2.1. 解引用不可变引用

fn main() {
	let x = 5;
	let y = &x; // y 是对 x 的不可变引用
	println!("y: {}", *y); // 通过解引用 y 获取 x 的值,输出: y: 5
}

2.2. 解引用可变引用

fn main() {
	let mut x = 10;
    let y = &mut x; // y 是对 x 的可变引用
    *y += 5; // 通过解引用 y 修改 x 的值
    println!("x: {}", x); // 输出: x: 15
}

2.3. 解引用指针

fn main() {
	let x = 42;
    let y = &x as *const i32; // 创建不可变裸指针
    unsafe {
        println!("y: {}", *y); // 解引用不可变裸指针
    }
	
	let x = Box::new(10); // Box 智能指针
    println!("x: {}", *x); // 解引用 Box,获取其值,输出: x: 10
}

二、
&
借用引用

&
也是操作符

1. 作用

创建一个值的不可变引用,允许读而不获取所有权,该值在引用期间是只读的。

2. 用法

2.1. 不可变引用

fn main() {
    let x = 10;
    let y = &x; // y 是对 x 的不可变引用
    println!("y: {}", y); // 输出: y: 10
}

2.2. 函数中的借用

fn print_value(x: &i32) {
    println!("Value: {}", x);
}

fn main() {
    let a = 10;
    print_value(&a); // 传递 a 的不可变引用
}

2.3. match 中使用

fn main() {
	let reference = &4;
    match reference {
        &val => println!("Got a value via destructuring: {:?}", val),
    }
}

2.4. 结构体中使用

struct Point<'a> {
    x: &'a i32,
    y: &'a i32,
}
fn main() {
    let x = 10;
    let y = 20;
    let point = Point { x: &x, y: &y }; // 使用引用初始化结构体字段
    println!("Point: ({}, {})", point.x, point.y); // 输出: Point: (10, 20)
}

2.5. 集合中使用

fn main() {
    let vec = vec![1, 2, 3];
    for val in &vec {
        println!("Value: {}", val); // 输出: 1, 2, 3
    }
}

2.6. 切片中使用

fn main() {
    let s = String::from("hello");
    let slice = &s[0..2]; // 创建字符串切片
    println!("Slice: {}", slice); // 输出: Slice: he
}

三、
mut
可变

mut
是一个关键字

1. 作用

声明一个变量或引用为可变的,可以修改其值。

2. 用法

2.1. 可变变量

fn main() {
    let mut x = 5; // x 是可变的
    x += 1;
    println!("x: {}", x); // 输出: x: 6
}

2.2. 函数中可变参数

fn increment(mut num: i32) -> i32 {
    num += 1;
    num
}

2.3. 可变引用

fn main() {
    let mut x = 5;
    let y = &mut x;
    *y += 1;
    println!("{}", x); // 输出 6
}

2.4. 可变结构体

struct Point {
    x: i32,
    y: i32,
}
fn main() {
    let mut p = Point { x: 0, y: 0 };
    p.x = 5;
    p.y = 10;
    println!("Point: ({}, {})", p.x, p.y); // 输出 Point: (5, 10)
}

2.5. 可变元组

let mut tuple = (5, 10);
tuple.0 = 15;

2.6. match 中使用

match Some(10) {
    Some(mut value) => {
        value += 1;
        println!("{}", value); // 输出 11
    }
    None => {},
}

2.7. 集合中使用

let mut vec = vec![1, 2, 3];
for num in &mut vec {
    *num += 1;
}
println!("{:?}", vec);

四、
&mut
可变借用引用

&mut
既不属于操作符也不属于关键字

1. 作用

创建一个值的可变引用,允许修改值而不获取所有权。

2. 用法

2.1. 可变引用

fn main() {
    let mut x = 10;
    {
        let y = &mut x; // y 是对 x 的可变引用
        *y += 5; // 修改 x 的值
    } // y 的生命周期结束,此时 x 的可变借用结束
    println!("x: {}", x); // 输出: x: 15
}

2.2. 函数中的可变引用

fn add_one(x: &mut i32) {
    *x += 1;
}

fn main() {
    let mut a = 10;
    add_one(&mut a); // 传递 a 的可变引用
    println!("a: {}", a); // 输出: a: 11
}

2.3. 结构体中的可变引用

struct Point<'a> {
    x: &'a mut i32,
    y: &'a mut i32,
}
fn main() {
    let mut x = 10;
    let mut y = 20;
    let point = Point { x: &mut x, y: &mut y }; // 使用可变引用初始化结构体字段
    *point.x += 1;
    *point.y += 1;
    println!("Point: ({}, {})", point.x, point.y); // 输出: Point: (11, 21)
}

2.4. 集合中的可变引用

fn main() {
    let mut vec = vec![1, 2, 3];
    for val in &mut vec {
        *val += 1; // 修改集合中的元素
    }
    println!("{:?}", vec); // 输出: [2, 3, 4]
}

2.5. match 中使用

fn main() {
    let mut pair = (10, 20);
    match pair {
        (ref mut x, ref mut y) => {
            *x += 1;
            *y += 1;
            println!("x: {}, y: {}", x, y); // 输出: x: 11, y: 21
        },
    }
}

2.6. 结构体中使用

struct Counter {
    value: i32,
}
impl Counter {
    fn increment(&mut self) {
        self.value += 1;
    }
}
fn main() {
    let mut counter = Counter { value: 0 };
    counter.increment(); // 使用可变引用调用方法
    println!("Counter value: {}", counter.value); // 输出: Counter value: 1
}

五、
ref
模式匹配中创建引用

ref
属于关键字

1. 作用

在模式匹配中借用值的不可变引用,而不是获取所有权。

2. 用法

2.1. 元组中使用

fn main() {
    let tuple = (1, 2);
    let (ref x, ref y) = tuple; // x 和 y 是对 tuple 中元素的不可变引用
    println!("x: {}, y: {}", x, y); // 输出: x: 1, y: 2
}

2.2. match 中使用

fn main() {
    let pair = (10, 20);
    match pair {
        (ref x, ref y) => {
            println!("x: {}, y: {}", x, y); // x 和 y 是 pair 元素的不可变引用
        }
    }
}

2.3. if let / while let 中使用

// if let
fn main() {
    let some_value = Some(42);
    if let Some(ref x) = some_value {
        println!("Found a value: {}", x); // x 是 some_value 的不可变引用
    }
}
// while let
fn main() {
    let mut stack = vec![1, 2, 3];
    while let Some(ref x) = stack.pop() {
        println!("Popped: {}", x); // x 是 stack 中最后一个元素的不可变引用
    }
}

2.4. 函数中使用

fn print_ref((ref x, ref y): &(i32, i32)) {
    println!("x: {}, y: {}", x, y); // x 和 y 是元组元素的不可变引用
}
fn main() {
    let pair = (10, 20);
    print_ref(&pair); // 传递 pair 的引用
}

2.5. for 循环中使用

fn main() {
    let vec = vec![1, 2, 3];
    for ref x in &vec {
        println!("x: {}", x); // x 是 vec 中元素的不可变引用
    }
}

六、
ref mut
模式匹配中创建可变引用

ref mut
属于关键字

1. 作用

在模式匹配中借用值的可变引用,允许修改该值。

2. 用法

2.1. match 中使用

fn main() {
    let mut pair = (10, 20);
    match pair {
        (ref mut x, ref mut y) => {
            *x += 1;
            *y += 1;
            println!("x: {}, y: {}", x, y); // 输出: x: 11, y: 21
        }
    }
    // pair 的值已经被修改
}

2.2. if let / while let 中使用

fn main() {
    let mut some_value = Some(42);
    if let Some(ref mut x) = some_value {
        *x += 1;
        println!("Found a value: {}", x); // 输出: Found a value: 43
    }
}
fn main() {
    let mut stack = vec![1, 2, 3];
    while let Some(ref mut x) = stack.pop() {
        *x += 1;
        println!("Popped: {}", x); // 输出: Popped: 4, Popped: 3, Popped: 2
    }
}

2.3. 函数中使用

fn increment_tuple((ref mut x, ref mut y): &mut (i32, i32)) {
    *x += 1;
    *y += 1;
}

fn main() {
    let mut pair = (10, 20);
    increment_tuple(&mut pair); // 传递 pair 的可变引用
    println!("pair: {:?}", pair); // 输出: pair: (11, 21)
}

2.4. 解构赋值

fn main() {
    let mut pair = (10, 20);
    let (ref mut x, ref mut y) = pair;
    *x += 1;
    *y += 1;
    println!("x: {}, y: {}", x, y); // 输出: x: 11, y: 21
    println!("{:?}", pair); // (11, 21)
}

七、总结

  • *
    :解引用操作符,用于访问指针或引用指向的值的类型。
  • &
    :借用操作符,用于创建不可变引用的类型,允许只读访问。
  • mut
    :关键字,用于声明可变变量或参数的类型,允许其值被修改。
  • &mut
    :借用操作符,用于创建可变引用的类型,允许读写访问。
  • ref
    :模式匹配中的关键字,用于创建不可变引用的类型,避免所有权转移。
  • ref mut
    :模式匹配中的关键字,用于创建可变引用的类型,允许修改引用的值。

工具:

apktool
ADT

命令:

反编译

java -jar apktool.jar d test.apk

重打包

java -jar apktool.jar b test

签名使用ADT

smail语言粗略理解(其实对于修改来说, 大概熟悉就就ok)

类定义

.class public Lcom/example/MyClass;
.super Ljava/lang/Object;
.class 指定类名和修饰符。
.super 指定父类。

字段定义

.field public myField:I
.field 定义字段。
I 表示整数类型(int)。

方法定义

.method public myMethod()V
    .locals 1
    .prologue
    .line 10
    return-void
.end method
.method 定义方法,V 表示返回类型为 void。
.locals 声明局部变量数量。
.prologue 和 .line 用于调试和代码注释。
return-void 表示方法结束并返回。

Smali 使用汇编语言风格的指令,以下是一些常见指令:

加载和存储指令

const/4 v0, 0x1  ; 将整数 1 加载到寄存器 v0

算术指令

add-int v0, v1, v2  ; v0 = v1 + v2

方法调用

invoke-virtual {v0}, Lcom/example/MyClass;->myMethod()

实战

三星通话app修复

这个类找不到

增加类

重打包签名后安装测试!

ok, 报其他类错误了, 依葫芦画瓢一一修复即可大功告成!

题目链接:
https://leetcode.cn/problems/find-bottom-left-tree-value/description/

题目叙述:

给定一个二叉树的 根节点 root,请找出该二叉树的 最底层 最左边 节点的值。

假设二叉树中至少有一个节点。

示例 1:

输入: root = [2,1,3]
输出: 1

示例 2:

输入: [1,2,3,4,null,5,6,null,null,7]
输出: 7

提示:

二叉树的节点个数的范围是 [1,10^4]
-2^31 <= Node.val <= 2^31 - 1

思路:

这题我们有递归和迭代两种写法,我们在这里重点介绍递归的解法,如果用层序遍历的迭代法的话,我们这道题就十分简单了,不过我在后面也会介绍层序遍历的写法。

递归法

递归法我们一定要清楚的是三点:

  1. 我们递归函数要传入的参数和递归函数的返回值
  2. 递归结束的条件(也就是递归的边界)
  3. 单层递归的逻辑

其实本题当中递归里面也蕴含着回溯的逻辑,其实所有的递归算法都离不开回溯,只是我们没有意识到回溯的过程,或者说回溯的过程被隐藏掉了。

下面的代码中我会重点强调回溯的逻辑

步骤1.确定我们的参数和返回值

这题的参数,既然是要求最后一层的最左边的节点,那么我们必然要使用一个参数
depth
来表示深度,然后我们也需要一个参数
maxdepth
来表示当前是否是达到了最大的深度,不过这个
maxdepth
变量不需要

传入函数中,我们可以定义为全局变量,如果depth>maxdepth,就证明当前还未到达最大深度,也就不是我们要处理的最左边的节点了。 同时,我们还需要一个参数
result
来接收我们需要求得这个节点的节点

值,这个变量我们也定义为全局变量。

确定递归的中止条件

我们要处理的是什么节点?是不是叶子节点,我们处理叶子节点的逻辑判断是什么?是不是只需要当前这个节点它的左右孩子都为空的时候,我们就到达了我们需要处理的时候了,这个时候就是返回的时候了。

那我们要处理这个节点,要做些什么事情呢?——我们要判断当前深度是否是最大深度,如果不是,我们就得更新这个最大深度,同时我们要更新result变量的值,然后再返回,这样就处理好了递归的边界条件,

对吧?

这段逻辑的代码如下:

       //处理到叶子节点就返回
        if(cur->left==NULL&&cur->right==NULL){
            if(depth>maxdepth){
                maxdepth=depth;
                result=cur->val;
            }
            return;
        }

单层递归的逻辑

我们现在找到了最深层次的叶子节点,那么我们如何保证它一定是最左边的节点呢?那还不简单嘛!只需要我们处理递归的时候,优先处理左子树,不就能保证我们先处理的是左孩子了嘛!对吧,

这段逻辑的代码如下:

            if(cur->left!=NULL){
            //先让depth++,让他处理下一层的节点
            depth++;
            traversal(cur->left,depth);
            //再让depth--,这就是回溯的过程,退到上一层的节点,再处理右边的子树
            depth--;
        }
            if(cur->right!=NULL){
            //这里也是一样的道理
            depth++;
            traversal(cur->right,depth);
            //这里也是回溯的过程
            depth--;
        }

其实,处理好了这几个边界条件,我们的代码就出来了

整体代码:

class Solution {
public:
    int result=0;
    int maxdepth=INT_MIN;
    void traversal(TreeNode*cur,int depth){
        //处理到叶子节点就返回
        if(cur->left==NULL&&cur->right==NULL){
            if(depth>maxdepth){
                maxdepth=depth;
                result=cur->val;
            }
            return;
        }
            if(cur->left!=NULL){
            //先让depth++,让他处理下一层的节点
            depth++;
            traversal(cur->left,depth);
            //再让depth--,这就是回溯的过程,退到上一层的节点,再处理右边的子树
            depth--;
        }
            if(cur->right!=NULL){
            //这里也是一样的道理
            depth++;
            traversal(cur->right,depth);
            //这里也是回溯的过程
            depth--;
        }
    }
    int findBottomLeftValue(TreeNode* root) {
        traversal(root,0);
        return result;
    }
};

层序遍历(迭代法)

其实,这题使用层序遍历才是最方便,最简单的做法。我们只需要处理每一层的第一个元素,然后处理到最后一层,它自然就是最后一层的左边第一个元素了,这题只需要在层序遍历的模板上面改动一点点

就可以实现了!

如果不会层序遍历的话,推荐去看看我的层序遍历的文章,里面详细讲解了层序遍历实现的过程!

层序遍历:
https://www.cnblogs.com/Tomorrowland/p/18318744

class Solution {
public:
    int findBottomLeftValue(TreeNode* root) {
        queue<TreeNode*> que;
        if (root != NULL) que.push(root);
        int result = 0;
        while (!que.empty()) {
            int size = que.size();
            for (int i = 0; i < size; i++) {
                TreeNode* node = que.front();
                que.pop();
                if (i == 0) result = node->val; // 记录最后一行第一个元素
                if (node->left) que.push(node->left);
                if (node->right) que.push(node->right);
            }
        }
        return result;
    }
};


0 abstract

Preference-based Reinforcement Learning (PbRL) circumvents the need for reward engineering by harnessing human preferences as the reward signal. However, current PbRL methods excessively depend on high-quality feedback from domain experts, which results in a lack of robustness. In this paper, we present RIME, a robust PbRL algorithm for effective reward learning from noisy preferences. Our method utilizes a sample selection-based discriminator to dynamically filter out noise and ensure robust training. To counteract the cumulative error stemming from incorrect selection, we suggest a warm start for the reward model, which additionally bridges the performance gap during the transition from pre-training to online training in PbRL. Our experiments on robotic manipulation and locomotion tasks demonstrate that RIME significantly enhances the robustness of the state-of-the-art PbRL method. Code is available at
https://github.com/CJReinforce/RIME_ICML2024
.

  • background 和 gap:基于偏好的强化学习 (PbRL) 通过利用人类偏好作为奖励信号,来规避奖励工程的需求。然而,目前的 PbRL 方法过度依赖专家的高质量反馈,导致缺乏鲁棒性。
  • method:在本文中,我们介绍了 RIME,这是一种鲁棒的 PbRL 算法,用于从嘈杂的偏好中有效地进行奖励学习。
    • 1 利用一个基于样本选择的鉴别器(discriminator),动态过滤噪声,确保鲁棒训练。
    • 2 为了抵消因错误选择而产生的累积误差(?),提出 reward model 的热启动(warm start),这进一步弥合了 PbRL 中的 pretrain → 正式训练 的性能差距。
  • 实验:在机器人操作(Meta-world)和运动任务(DMControl)上的实验表明,RIME 显著增强了最先进的 PbRL 方法(指 pebble)的稳健性。

1 intro

  • background:PbRL 省去 reward engineering,PbRL 好。
  • gap 1:PbRL 假设 preference 都是专家打的、没有错误,但人类是容易犯错的。
  • gap 2:从 noisy 的标签中学习,也称为鲁棒训练。
    • Song et al. ( 2022) 将鲁棒训练方法分为四个关键类别:鲁棒架构 (Cheng et al., 2020)、鲁棒正则化 (Xia et al., 2020)、鲁棒损失设计 (Lyu & Tsang, 2019) 和样本选择 (Li et al., 2020;Song 等人,2021 年)。
    • 然而,把它们整合到 PbRL 中很难,貌似因为 1 需要大量样本,而 PbRL 的 feedback 数量(我们常跑的几个 benchmark)最多几万;2 RL 训练期间有 distribution shift,破坏了 i.i.d(独立同分布)输入数据的假设,这是支持稳健训练方法的核心原则。
  • 我们提出了 RIME(
    R
    obust preference-based re
    I
    nforcement learning via war
    M
    -start d
    E
    noising discriminator),据他们生成是第一个研究 PbRL noisy label 的工作(?)
  • 主要方法:
    • 1 使用一个 discriminator,用一个阈值找到认为正确的样本
      \(\mathcal D_t\)
      ,再用一个阈值找到 看起来很错误的样本
      \(\mathcal D_f\)
      ,将其翻转,最后我们使用的样本是
      \(\mathcal D_t \cup\mathcal D_f\)
    • 具体的,这里的阈值是交叉熵 loss,有一个理论,感觉很 intuitive,是好工作ww
    • 2 用预训练的 intrinsic reward,初始化训一下 reward model。
    • 具体的,要在预训练时就归一化 intrinsic reward 到 (-1,1),这是因为 reward model 一般采用 tanh 做激活函数,而 tanh 的输出是 (-1,1)。
  • PbRL。
  • learning from noisy labels:
    • 把 intro 的介绍又说了一遍。
    • 提到,在 PbRL 背景下,Xue 等人(2023 年)提出了一种编码器-解码器架构,来模拟不同的人类偏好,但是相比 RIME 的工作,大概需要 100 倍的 preference 数量。
  • Policy-to-Value Reincarnating RL(PVRL):
    • Reincarnate:vt,使投胎、转世、赋予新形体。
    • PVRL,指将次优的 teacher policy 转移到一个 value-based 的 student RL agent(Agarwal 等人,2022 年)。
    • 启发:Uchendu et al. ( 2023) 发现,PVRL 中随机初始化的 Q 网络,会导致 teacher policy 很快被遗忘。
    • gap:在广泛采用的 PbRL pipeline 中,PVRL 挑战也出现在从 pretrain 到 online training 的过渡过程中,但在以前的研究中被忽视了。在 noisy feedback 下,忘记预训练策略的问题变得更加重要,详见第 4.2 节。
    • (这里的预训练指的是 pebble 等工作的 比如说 最大熵预训练策略。
    • 引出 reward model 的热启动。

3 preliminaries

  • PbRL。
  • Unsupervised Pre-training in PbRL:讲了
    pebble
    的预训练。
  • Noisy Preferences in PbRL:讲了
    BPref
    的模仿人类 scripted teacher,使用 error teacher。

4 method: RIME

4.1 RIME 的 denoising discriminator

  • 省流:用各个 (σ0, σ1, p) 的 CELoss 大小,来判断它是正确 / 错误样本,并翻转所有错误样本的 p。
  • 为什么用交叉熵 loss 来判断 是 正确 / 错误样本?
    • 现有研究表明,深度神经网络首先学习可泛化的模式,然后再过度拟合数据中的噪声(Arpit et al., 2017; Li 等人, 2020 年)。
    • 因此,将与较小损失相关的 sample 优先为正确样本,是提高稳健性的有充分依据的方法。(其实没太理解)
  • 回顾
    交叉熵与 KL 散度的关系
  • 如何确定交叉熵 loss 的阈值?
    • 定理 4.1,假设干净数据的 x 交叉熵 loss 以 ρ 为界,即
      \(\mathcal L^\text{CE}(x)\le\rho\)
      ;则有,损坏样本 x 的预测偏好
      \(P_\psi(x)\)
      ,和
      \(\tilde y(x)=1-y\)
      之间的 KL 散度,下限为
      \(D_{\mathrm{KL}}(\tilde{y}(x)\parallel P_{\psi}(x))\geq-\ln\rho+\frac{\rho}{2}+O(\rho^{2})\)
    • 然后,我们制定 KL 散度阈值的下限
      \(\tau_\text{base}=\ln \rho+\alpha\rho\)
      ,以过滤掉不可信样本。其中,
      \(\rho\)
      表示上次更新期间观察到的 可信样本的最大交叉熵 loss,
      \(\alpha\in(0,0.5]\)
      是可调的超参数。
    • 但是还要考虑 distribution shift 问题。为了在 distribution shift 的情况下,增加对干净样本的 tolerance,我们引入一个辅助项
      \(\tau_\text{unc}=\beta_t\cdot s_\mathrm{KL}\)
      ,来表征过滤的不确定性,其中
      \(\beta_t=\max(\beta_\min,\beta_\max-kt)\)
      是随时间变化的参数(β max = 3, β min = 1),
      \(s_\mathrm{KL}\)
      是 KL 散度的标准差(看起来是
      \(D_{\mathrm{KL}}(\tilde{y}(x)\parallel P_{\psi}(x))\)
      的 KL 散度)。这里的 intuition 是,训到 OOD 数据可能导致 CELoss 的波动(其实也没太听懂)
  • 识别可信样本的数据集:
    \(D_t=\{(\sigma^0,\sigma^1,\tilde{y}) | D_{\mathrm{KL}}(\tilde{y}\parallel P_\psi(\sigma^0,\sigma^1))<\tau_{\mathrm{lower}}\}\)
    ,其中
    \(\tau_{\mathrm{lower}}=\tau_{\mathrm{base}}+\tau_{\mathrm{unc}}=-\ln\rho+\alpha\rho+\beta_{t}\cdot s_{\mathrm{KL}}\)
  • 识别不可信样本的数据集:
    \(D_f=\{(\sigma^0,\sigma^1,\tilde{y}) | D_{\mathrm{KL}}(\tilde{y}\parallel P_\psi(\sigma^0,\sigma^1))>\tau_{\mathrm{upper}}\}\)

    \(\tau_{\mathrm{upper}}\)
    貌似是预先定义的值,定义成
    \(3\ln(10)\)
    了。 然后翻转 Df,将翻转后的 Df 与 Dt 并起来,拿去训 reward model。

4.2 reward model 的 warm start

  • 省流:用 intrinsic reward 训一下 reward model。
  • 观察:
    • 观察到在从预训练到在线训练的过渡过程中,性能显著下降(见图 2)。在 noisy feedback的 setting 下,这种差距是可以明显观察到的,并且对鲁棒性是致命的。
    • 在预训练后,PEBBLE 会重置 Q 网络,仅保留预训练的 policy。由于 Q 网络学的是最小化 noisy feedback 的 reward model 下的 TD-error,因此这种 biased Q 函数会导致 policy 学的不好,从而抹去预训练期间的收益。
  • reward model 的 warm start:
    • 具体来说,我们在预训练阶段,先拿 intrinsic reward 训一下 reward model。
    • 由于 reward model 的输出层通常使用 tanh 激活函数(Lee et al., 2021b),因此我们首先将内在奖励归一化到 (-1,1),使用当前已获得的 intrinsic reward 的 mean
      \(\hat r\)
      和 variance
      \(\sigma_r\)
      来做:
      \(r_{\mathrm{norm}}^{\mathrm{int}}(\mathbf{s}_t)=\mathrm{clip}(\frac{r^{\mathrm{int}}(\mathbf{s}_t)-\hat r}{3\sigma_r},-1+\delta,1-\delta)\)
    • 预训练 reward model 的数据,貌似就是
      \((s_t,a_t,r_{\mathrm{norm}}^{\mathrm{int}},s_{t+1})\)
      ,而不是用 segment 的形式。(这里提到一个最近邻,我没太看懂w)

4.3 整体算法流程

在附录 A 放了伪代码。在附录 A 放伪代码,真是好文明。

关键点:

  • 预训练与 reward model 的 warm start:
    • 第 5 行,收集的 intrinsic reward 是归一化过的。
    • 第 10 行,训 reward model 用的是
      \(r_{\mathrm{norm}}^{\mathrm{int}}\)

      \(\hat r\)
      的 MSE,而非 segment。
  • 鉴别错误 preference 的 denoising discriminator:
    • 第 13 行,初始化 ρ 为正无穷。
    • 第 19 行,算 辨别可信样本的阈值 τ lower。
    • 第 24 行,用 可信样本 ∪ 错误样本翻转 的数据集,来算新 ρ,其中 ρ 是 KL 散度的下界。

5 experiments

  • setting:跟 pebble 一样,三个 DMControl + 三个 Meta-world。
  • baselines:
    pebble

    surf

    rune
    、MRN(MRN 我还没看)。
  • error rate(即随机挑选 (σ0,σ1,p) 并翻转 p 的概率)是 0.1 到 0.3。
  • 大量 ablation:
    • 在 Appendix D.3 尝试了更多种 noisy teacher,放在正文的表比的是 各种 noisy teacher 的 average。
    • 与其他稳健的训练方法的比较:自适应去噪训练 (ADT)(Wang 等人,2021 年),即丢弃一定比例的 CELoss 大的样本,貌似效果不错;使用 MAE 和 t-CE 作为替代 CELoss(?)的损失函数;使用标签平滑 (LS)来处理所有 preference label(?)。
    • 居然有真 human,见 Appendix D.4。总反馈量和每个会话的反馈量分别为 100 和 10。任务是 hopper 后空翻(真假的,这么好学(?)难道 hopper 后空翻是一个 只要控制变量拉到极限 就能一直后空翻 的任务嘛)。但是怎么截图变成了 OpenAI gym 而非 DMControl。
    • 增加 feedback 总数,可以有效提升性能。
    • 各个模块是否有效?当反馈数量相当有限时(即,在Walker-walk上),热启动对于鲁棒性至关重要,可以节省 query 数量。