Model 复现系列(三)π0 -- Physical Intelligence Pi-zero(Pi0)

这篇博客就是复现 Physical Intelligence 的 π0 即 Pi0,虽然 PI 已经推出了这个模型的进化版本 Pi0.5,但复现 Pi0 仍然有很重要的意义(主要是很多公司还在基于 Pi0 进行微调,只能说 PI 速度太快了),因此这篇博客先复现 Pi0 后面会再去复现 Pi0.5。

pi0 在部署上存在不少坑,特别对于国内使用而言,因此下文中的第一章需要仔细阅读。


阅读须知

以下是使用本博客前必须知道的几个重点,你应该在完整运行前就需要知晓。

  1. 在 《1.1 软硬件约束》部分说明了使用 pi0 的最低软硬件配置,在执行前必须查看自己的软硬件是否满足条件;
  2. 在《2. 下载模型检查点》部分不是说必须都要下载,但至少下载一个以让后面的示例能够推动;
  3. 在《3. 下载数据集》部分不要求强制下载,如果你想要跑一下流程可以选择最小的那份下载,通常情况下你需要使用自己的数据集微调模型;
  4. 整个博客只有1~2章中的内容是要求全部完成的,后面的章节根据自己需要执行,不要求你每一章都完整执行;
  5. 一些示例中使用到的真机实验可能会因为我这里暂时没有协调到设备而跳过一部分内容,但我会尽可能确保操作是可行的,等我协调到设备后会补充上去;
  6. 因为整篇博客涉及到的内容很多,我会分次进行完善;

1. 前期准备

由于 pi0 相比较于之前 ACT 和 OpenVAL 模型而言更复杂,因此前期准备工作是必不可少的,特别是在 python 环境管理上使用了 uv,对很多新手而言是第一次接触。

1.1 [重要-必读] 软硬件条件

在一切开始之前,首先要检查自己的硬件是否满足要求,特别是 GPU 资源这部分,仓库中 ReadMe 文件中直接提到了对推理、LoRA、Full 三种操作的最小资源,如果你当前中没有准备好下面的最低硬件条件,那么建议先去搞定这块,否则后面的所有工作都无法开展:

ModeMemory RequiredExample GPU
Inference> 8 GBRTX 4090
Fine-Tuning (LoRA)> 22.5 GBRTX 4090
Fine-Tuning (Full)> 70 GBA100 (80GB) / H100

操作系统最好用 Ubuntu 22.04,这一点可能是坑,因为之前复现其他模型的时候出现过 Ubuntu 20.04 和 22.04 动态库不兼容的情况,我未来会在 Ubuntu 20.04 上测试原生和 Docker 两种方式运行,到时候会在博客中补上。

1.2 论文速览

Pi-Zero 的整体框架如下所示,从整体结构图上可以发现 Pi-Zero 最吸引人的地方是其 跨本体 性能,无论是机器人训练数据集还是最终执行端数据,Pi-Zero 都可以实现从本体 A 学到的知识迁移到本体 B 上。

在这里插入图片描述

上图中的模型结构部分如下图所示:

在这里插入图片描述

1.3 拉取源码

这里需要提前明一下,因为 PI 在他的博客和 Paper 中都没有提到他们的源码链接,所以有很多人在找的时候会找到其他人的仓库,其实直接在 Github 中搜索 Physical Intelligence 就可以找到他们的源码仓库。

在这里插入图片描述

我把源码和这篇博客涉及到的数据都放在了网盘中,不方便科学上网的可以直接从这里下,但要注意代码的时效性(2025年06月19日):

链接: https://pan.baidu.com/s/1LndXQWXF5TQpRr_psgcUkQ?pwd=x4ue 提取码: x4ue 

最推荐的仍然是从 Github 仓库中拉取代码:

$ git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git

1.4 环境部署

在环境部署方面源码提供了两种方式:一种是本地 uv ,另一种是使用 docker。

【Note】:后面运行时根据自己实际部署情况运行,这里为了与官方仓库同步就只演示用 uv 的方式运行代码,如果你用了 docker 或 conda 记得用对应的命令。

1.4.1 本地 uv 部署

由于国内访问可能在这一步会出现很多次失败,只要失败了重新执行失败的命令即可(是存在你配置了国内源后仍然很慢的可能的)。虽然说 uv 可以并行下载包能更快部署,但总体进度取决于你的 带宽和网络稳定性 。如果网络有波动,那么很容易这一批已经下载一部分的包都要重新下载。

这一步直接按照 uv 官网 的步骤安装即可,但要注意如果你安装了 conda 并且在 bash 中自动激活了,那么需要先 退出conda 再执行下面的命令:

(base) $ conda deactivate

【Note】:配置国内清华源,如果不配置国内源能卡到你怀疑人生。

$ export UV_DEFAULT_INDEX="https://mirrors.aliyun.com/pypi/simple"

最后再安装 uv :

$ wget -qO- https://astral.sh/uv/install.sh | sh

如果你想要在使用 uv 的时候有自动补全功能则需要执行下面的命令:

$ echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc

安装 python uv:

$ pip install uv

在安装完后进入到 openpi 的根目录下执行下面的命令:

$ cd openpi
$ GIT_LFS_SKIP_SMUDGE=1 uv sync
$ GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .

然后创建一个临时的 python 脚本试一下 uv 能否正常运行:

print("UV work done")

执行 uv run 命令:

$ uv run test.py

如果运行后报出以下类似的错误:

error: Project virtual environment directory /gemini/code/pi0/openpi/.venv cannot be used because it is not a valid

那么就删除当前目录下的 .venv 文件然后再初始化一下:

【Note】:只要你还在当前的bash没有退出那么下面的操作并不会重新下载刚才的包。

$ rm -rf .venv
$ uv venv

看到下面的提示就说明正常:

在这里插入图片描述

再去执行刚才的 uv run test.py 命令,没有任何报错并输出 UV work done 就说明你的 uv 初始化成功了。如果还有其他错误可以在评论区留言。

1.4.2 [未完待续] 使用 docker

如果你实在受不了 uv 并且对 docker 比较熟悉且,那么就按如下操作在 docker 中使用,前期准备工作具体操作步骤参考文档

如果你想要将 docker 部署到 GPU 服务器上还需要安装下面的库:

$ sudo apt install fuse-overlayfs

如果之前没有安装过 docker 那么按照下面的步骤进行安装:

$ sudo apt-get update
$ sudo apt-get install ca-certificates curl
$ sudo install -m 0755 -d /etc/apt/keyrings
$ sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
$ sudo chmod a+r /etc/apt/keyrings/docker.asc

$ echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \
  sudo tee /etc/apt/sources.list.d/docker.list > /dev/null

$ sudo apt-get update

$ sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

执行完上面的代码后用 hello-world 镜像测试一下:

$ sudo docker run hello-world

如果你的设备是 Ubuntu 22.04 并且已经部署好了 Docker 那么可以直接使用源码中的脚本快速完成 docker 部署工作:

$ docker compose -f scripts/docker/compose.yml up --build

1.4.3 conda 部署

uv 本质上仍然是对 python 的包进行了一层管理,那么同样可以用 conda 实现。python 版本需要查看 openpi/.python-version 里面只有一行内容 3.11,使用下面的命令创建一个新的 conda 环境:

(base) $ conda create -n pi0 python=3.11

然后切换到 pi0 环境下安装依赖包:

(base) $ conda activate pi0 

剩下的操作就和 1.4.1 完全一致。


2. 下载模型检查点

PI 在其论文中提供了很多很多对比实验,同时开源了对应的检查点,因为本身每个模型都十几 GB,考虑到有些同学不方便下载,这一章节提供了两种获取模型的方法,根据自己情况选择最适合的就行。

2.1 从网盘中拉取

我将检查点全部下载好放在了网盘中的 Models 文件夹中,该文件夹中有两个子文件夹 Base ModelFine-Tuned Model 分别存储了 ReadMe 中对应的基础模型和微调模型,每个文件都是 zip 包:

链接: https://pan.baidu.com/s/10s9o7cgxTmfaZE5-EYzn8Q?pwd=r4xm 提取码: r4xm 

在这里插入图片描述

下载好后需要将其解压到 ~/.cache/openpi/openpi-assets/checkpoints 路径下来告诉脚本已经下载完成,这里以 pi0_fast_droid 模型为例,处理好后你的文件结构应该如下:

$ tree ~/.cache/openpi/openpi-assets/checkpoints/pi0_fast_droid
.
|-- assets
|   `-- droid
|-- norm_stats.json
`-- params
    |-- _CHECKPOINT_METADATA
    |-- _METADATA
    |-- _sharding
    |-- d
    |-- manifest.ocdbt
    `-- ocdbt.process_0

5 directories, 5 files

在这里插入图片描述

2.2 使用脚本下载

同样也可以通过修改下面脚本中的 model_namemodel_link 变量来指定想要的模型。这两个变量其实就是 ReadMe 提供的链接,如 gs://openpi-assets/checkpoints/pi0_fast_droid,那么模型名就是 pi0_fast_droid

为了能更快地下载模型,首先在终端设置下面的环境变量:

$ export HF_ENDPOINT="https://hf-mirror.com" 

【Note】:模型会自动下载到 ~/.cache/openpi/openpi-assets/checkpoints 路径下,如果你想要修改默认下载路径可以在当前bash中执行以下命令:

$ export OPENPI_DATA_HOME="your model save path"

然后再执行下面的脚本,这个脚本要放在 openpi/src 目录下:

from openpi.training import config
from openpi.policies import policy_config
from openpi.shared import download

model_name = "pi0_fast_droid"
model_link = "gs://openpi-assets/checkpoints/pi0_fast_droid"

config = config.get_config(model_name)
checkpoint_dir = download.maybe_download(model_link)

policy = policy_config.create_trained_policy(config, checkpoint_dir)

使用命令开启下载:

$ uv run src/download.py

看到正常读条后就可以让其挂着下载了:

在这里插入图片描述

下载完成后应该可以看到下面的内容:

在这里插入图片描述

如果你在下载完几十 GB 完整模型后出现了如下报错:

OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like physical-intelligence/fast is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

在这里插入图片描述

需要确定是否设置了环境变量,如果没设置则设置一次后重新执行下载命令,这次下载只会下载缺失的部分,不会下载已经下好的内容:

$ export HF_ENDPOINT="https://hf-mirror.com" 

那么最终你会得到以下输出信息:

在这里插入图片描述

【Note】:下载完成后一定要去 .cache/openpi 文件夹中将模型备份一份,如果后面执行某些命令触发了一些错误,这个文件夹中的数据可能会消失!

上面脚本中可修改的两个变量名列表如下:

【Note】:官方 ReadMe 中的 pi0_basepi0_fast_base 已经改名为 pi0_liberopi0_fast_libero ,如果你使用旧的会报错。

model_namemodel_linkcommentsize
pi0_liberos3://openpi-assets/checkpoints/pi0_libero基础模型0.0 GB
pi0_fast_liberos3://openpi-assets/checkpoints/pi0_fast_libero自回归基础模型0.0 GB
pi0_fast_droids3://openpi-assets/checkpoints/pi0_fast_droid在DROID数据集上训练的可部署模型10.1 GB
pi0_droids3://openpi-assets/checkpoints/pi0_droid在DROID数据集上训练的可部署模型,但语言指令跟随能力可能不强12.3 GB
pi0_aloha_towels3://openpi-assets/checkpoints/pi0_aloha_towelAloha 机器人上开箱即用叠毛巾的模型0.0 GB
pi0_aloha_tupperwares3://openpi-assets/checkpoints/pi0_aloha_tupperwareAloha 机器人上开箱即用从容器中取出食物的模型0.0 GB
pi0_aloha_pen_uncaps3://openpi-assets/checkpoints/pi0_aloha_pen_uncapAloha 机器人上开箱即用打开笔帽的模型0.0 GB

【Note】:huggingface 的第三方仓库会发布一些基于其任务微调好的预训练模型,感兴趣的可以拉一份试试。


3. 数据集相关

pi0 在训练时涉及到了很多数据集,包括第三方的开源数据以及他们自己组编的高质量数据,这一章节将分别介绍如何下载部分数据集并进行简单测试。

【Note】:这里并不要求你将每个数据集都下载下来,而是根据自己实际情况有选择地去下载。如果你已经有了自己构建的数据集,直接跳过这一整章节都是可以的。

根据官方 ReadMe 文件的描述,他们使用了 Libero 数据集的格式作为输入,这个数据集是在一系列 仿真 环境下构建的 单臂 操作数据集,具体有哪些任务可以通过他们的 官网链接 点进去后拉到最下面看到一张大表,里面是 Libero-90 数据集中的一些任务在不同模型下成功率的示例:

在这里插入图片描述

从这个官网上你可以下载得到以下四个数据集:

数据集size描述
LIBERO-Spatial2.88 GB包含10类空间推理任务
LIBERO-Object4.26 GB包含10类对象操作任务
LIBERO-Goal2.88 GB包含10类任务推理任务
LIBERO-10033.98 GB由LIBERO-90 和 LIBERO-10构成

3.1 网盘下载

为了方便取用,我将数据集上传到我的网盘中,此外该网盘下还包含了 5 个样本数据,你可以先下载这几个样本看看 pi0 能够接受的数据格式是怎样的,然后根据这个格式适配自己的数据集。

链接: https://pan.baidu.com/s/1stXb2f1bEuSTBMmyiTINaw?pwd=u5e6 提取码: u5e6 

位置存放在 DataSets 文件夹下:

在这里插入图片描述

在该文件夹下的样本信息如下:

文件名SizeSource
pick_up_the_black_bowl_next_to_the_cookie_box_and_place_it_on_the_plate_demo.hdf5603.05 MBlibero_spatial
pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate_demo.hdf5485.21 MBlibero_spatial
pick_up_the_ketchup_and_place_it_in_the_basket_demo.hdf5804.7 MBlibero_object
pick_up_the_salad_dressing_and_place_it_in_the_basket_demo.hdf5664.2 MBlibero_object
put_the_bowl_on_the_stove_demo.hdf5509.1 MBlibero_goal
turn_on_the_stove_demo.hdf5447.5 MBlibero_goal

3.2 网页端下载

打开他们 官网链接 后向下拉看到 Download Datasets 字段,直接点击链接即可进行下载。
在这里插入图片描述

3.3 抽查数据集

按照我个人的习惯,在下载好数据集后都会抽一份看看内容,这里以 libero-goal 文件下的 turn_on_the_stove_demo.hdf5 样本为例,其结构如下图所示:

在这里插入图片描述

使用下面的脚本查看一下这个数据集中有哪些内容:

# !pip install pyh5 matplotlib

import h5py
import matplotlib.pylab as plt
import time

file = "turn_on_the_stove_demo.hdf5"
frames = []

with h5py.File(file, 'r') as f:
    print(f'File {file} contan {f.keys()}')
    demo_0_data = f['data']['demo_0']
    print(f'Length of agentview: {len(demo_0_data["obs"]["agentview_rgb"])}')
    for frame in demo_0_data["obs"]["agentview_rgb"]:
        frames.append(frame)
    
    
plt.ion()
fig, ax = plt.subplots()
img = ax.imshow(frames[0], vmin=0, vmax=255)
ax.axis('off')

for frame in frames:
    img.set_data(frame)
    fig.canvas.draw()
    fig.canvas.flush_events()
    time.sleep(0.02)  # 控制帧率

plt.ioff()
plt.show()

在这里插入图片描述

[未完] 3.4 数据集转换

根据 ReadMe 文件描述如果想要将自己的数据以最低的代价快速在模型上跑起来,可以使用下面的方式


4. 噪声输入推理

通常情况下在决定好使用哪个模型并完成本地部署后可以用 噪声输入 来测试一下自己的部署是否存在问题。这部分内容是源码中 openpi/examples/simple_client 的部分,涉及到的内容为以下两个链接:

在使用命令之前可以先用下面的命令查看你能运行哪些机器人设备,其实主要就看 --env 字段支持哪些型号:

$ uv run examples/simple_client/main.py --help

usage: main.py [-h] [OPTIONS]

Command line arguments.

_─ options ─────────────────────────────────────────────────────────────────────────────────────────────────────────────_
│ -h, --help              show this help message and exit                                                               │
│ --host STR              Host and port to connect to the server. (default: 0.0.0.0)                                    │
│ --port {None}|INT       Port to connect to the server. If None, the server will use the default port. (default: 8000) │
│ --api-key {None}|STR    API key to use for the server. (default: None)                                                │
│ --num-steps INT         Number of steps to run the policy for. (default: 20)                                          │
│ --timing-file {None}|PATH                                                                                             │
│                         Path to save the timings to a parquet file. (e.g., timing.parquet) (default: None)            │
│ --env {ALOHA,ALOHA_SIM,DROID,LIBERO}                                                                                  │
│                         Environment to run the policy in. (default: ALOHA_SIM)                                        │
_───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────_

在这里插入图片描述

命令执行后告诉我们可以给 --env 参数传入 ALOHA,ALOHA_SIM,DROID,LIBERO 这几个字段。

  • 打开一个 终端 1 作为 机器人端 并输入以下命令:
$ uv run examples/simple_client/main.py --env DROID

INFO:root:Waiting for server at ws://0.0.0.0:8000...
INFO:root:Still waiting for server...

这个 examples/simple_client/main.py 文件就是生成了一堆无意义的观察噪声传递给后端服务,执行后会等待服务器就绪:
在这里插入图片描述

  • 开一个 终端2 作为 GPU推理端 并输入以下命令,这一步可能会消耗很长时间,因为要加载模型到 GPU 中,这里会下载一个 paligemma_tokenizer.model ,下载好后便运行推理:
$ uv run scripts/serve_policy.py --env DROID

INFO:root:Loading model...
...
he/openpi/big_vision/paligemma_tokenizer.model
 29%|███████████████▏                                    | 1.19M/4.07M [00:30<02:23, 21.0kiB/s]

在这里插入图片描述

然后回到 终端 1 可以看到 GPU 服务器返回了计算结果,但要注意这里的计算结果只是几帧的数值:

在这里插入图片描述

如果你想要测试连续很多帧的计算可以修改 examples/simple_client/main.py 文件中的 main 函数内的 for 循环,增加循环迭代次数:

在这里插入图片描述

【Note】:无论上面你使用的是哪种部署方式,亦或你打算用后面哪些机器人进行验证,都建议你跑一次本章节的内容,目的在于验证环境是否正常。


[未完] 4. DRIOD 单臂机器人

这部分内容是依照源码中 openpi/examples/droid/ 的部分,涉及到的内容为以下两个链接:

在这里插入图片描述

【未完待续】


5. Aloha 仿真机器人

在没有 Aloha 真机机器人的情况下,可以现在仿真环境中先进行尝试,这部分是依照源码中 openpi/examples/aloha_sim ,涉及到的内容有以下链接:

首先创建一个 uv 的虚拟环境并安装一些依赖库:

uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

然后运行仿真客户端:

$ MUJOCO_GL=egl python examples/aloha_sim/main.py

在这里插入图片描述

如果在运行后出现 GL 相关的报错,安装下面的库:

$ sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev

除了官方 ReadMe 中提到的上面的库以外,还需要安装以下 python 库:

$ uv pip install PyOpenGL PyOpenGL.accelerate

新开一个终端运行仿真服务端,这里需要联网下载一些资源:

$ uv run scripts/serve_policy.py --env ALOHA_SIM

在这里插入图片描述


[未完] 6. Aloha 真机机器人

Aloha 真机是 pi0 模型重点推动的验证平台,这部分是依照源码中 openpi/examples/aloha_real ,涉及到的内容有以下链接:

真机 Aloha 使用的是松灵的产品,虽然在松灵官网上并没有具体的介绍页,但在其 Github 仓库中却有相关源码,产品照片如下所示:

【未完待续】


[未完] 7. 远程服务器推理 + 本地执行

这部分内容是依照源码中 openpi/docs/remote_inference.md 的部分,涉及到的内容为以下两个链接:

Pi0 提供了在 GPU 服务器上进行推理然后将计算结果传输给真机上,原理是使用 websocket 进行数据通讯。

【未完待续】


[编写中] 8. 使用自己数据微调 Base Model

这部分内容是依照源码中 examples/libero/convert_libero_data_to_lerobot.py 的部分,涉及到的内容为以下两个链接:

因为在这部分中每人使用的数据集在格式上存在不少差异,但受限于精力,我这里只对 rosbagliberoHDF5 三种形式进行 lerobot 转译与微调;

【Note】:考虑到有些人是第一次使用 uv 方式管理包,在适配自己数据集的时候难免出现缺包的情况,想要不污染自己默认环境则使用下面的命令补装 uv 包,前面一定要带上 uv

$ uv pip install <package_name>

8.1 rosbag 转译

因为 rosbag 是最常见的机器人数据包格式,并且我这里使用的设备也是直接产出 rosbag 格式的数据,因此这里先以该格式数据为例进行转译。数据样本我也放在了网盘中,感兴趣的可以自行下载,位置在 DataSets\Fold-T-shirt 中,一共有 10 份数据:

链接: https://pan.baidu.com/s/1T6h2dVI2GqGM2KhTCj1INg?pwd=qyh4 提取码: qyh4 

在这里插入图片描述
首先要安装一些依赖:

$ uv pip install pycryptodomex gnupg rospkg 

examples/aloha_real/convert_rosbag_to_lerobot.py 文件夹下创建一个 convert_rosbag_to_lerobot.py 的脚本,并将以下内容填充进去(我这里删除了很多不必要的内容,只保留了最小结构),这份代码会将数据加载到内存中:

【Note】:我在下面的代码中添加一个将图像写入到内存的操作,目的是随时能抽看一张图像是否正确。

import dataclasses
from pathlib import Path
from typing import Literal

import rosbag
import cv2
import numpy as np
from cv_bridge import CvBridge

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode = "video",
    *,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    cameras = [
        "cam_high",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
        "action": {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        },
    }

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=20,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    bridge = CvBridge()
    bag = rosbag.Bag(ep_path)
    images = {
        "/camera_f/color/compressed": [],
        "/camera_l/color/compressed": [],
        "/camera_r/color/compressed": []
    }
    state = []
    print(f"total catched {bag.get_compression_info()} topices")
    robot_arm_state = [0.0] * 14
    for topic, msg, t in bag.read_messages():
        if topic in ["/camera_f/color/compressed", "/camera_l/color/compressed", "/camera_r/color/compressed"]:
            cv_image = bridge.compressed_imgmsg_to_cv2(msg)
            images[topic].append(cv_image)
            state.append(robot_arm_state.copy())
            cv2.imwrite("./check_image.png", cv_image)
        elif topic == "/puppet/joint_left":
            robot_arm_state[:7] = msg.position
        elif topic == "/puppet/joint_right":
            robot_arm_state[7:] = msg.position
    bag.close()

    print(f"State Frames = {len(state)}")
    print(f"Image Frames =")
    print(len(images["/camera_f/color/compressed"]))
    print(len(images["/camera_l/color/compressed"]))
    print(len(images["/camera_r/color/compressed"]))

    state = np.array(state, dtype=np.float32)

    print(f'rosbag file {ep_path} parase done.')
    action = state

    imgs_per_cam = {
        "cam_high": np.array(images["/camera_f/color/compressed"]),
        "cam_left_wrist": np.array(images["/camera_l/color/compressed"]),
        "cam_right_wrist": np.array(images["/camera_r/color/compressed"]),
    }
    return imgs_per_cam, state, action


def populate_dataset(
    dataset: LeRobotDataset,
    rosbag_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    if episodes is None:
        episodes = range(len(rosbag_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = rosbag_files[ep_idx]
        imgs_per_cam, state, action = load_raw_episode_data(ep_path)
        # num_frames = state.shape[0]
        
        img_frame_counts = [img_array.shape[0] for img_array in imgs_per_cam.values()]
        target_frame_count = min(state.shape[0], *img_frame_counts)

        # 生成等距索引
        indices = np.linspace(0, state.shape[0] - 1, target_frame_count).astype(int)
        state = state[indices]
        action = action[indices]
                        
        num_frames = target_frame_count
        print(f"[INFO] Synced num_frames: {num_frames}")
        
        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
                "task":task
            }
            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]
            dataset.add_frame(frame)
        dataset.save_episode()
    
    return dataset


def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")

    rosbag_files = sorted(raw_dir.glob("episode_*.bag"))
    print(rosbag_files)

    dataset = create_empty_dataset(
        repo_id,
        robot_type="aloha",
        mode=mode,
        dataset_config=dataset_config,
    )
    print(dataset)

    dataset = populate_dataset(
        dataset,
        rosbag_files,
        task=task,
        episodes=episodes,
    )
    print(dataset)


if __name__ == "__main__":
    tyro.cli(port_aloha)

然后执行命令,其中 --raw_dir 参数是你原始数据路径;--repo-id 是你当前新建的任务名字可以任意,在执行后会将转换好的数据存放在 ~/.cache/huggingface/lerobot 目录下:

【Note】:执行下面命令的时候需要联网下载一些东西。

$ uv run examples/aloha_real/convert_rosbag_to_lerobot.py --raw_dir ~/Desktop/SmallDatas --repo-id vlalab

你可以在终端使用下面的命令以 2s 为单位刷新,实时监控内存使用情况:

$ watch -n 2 "free -h"

正常运行的前提下你会在终端看见一堆读条的操作:

在这里插入图片描述

8.N 微调模型

根据 ReadMe 文件的介绍,最好在正式训练之前对数据进行一次统计以检查,这一步比较耗时目的是统计整个模型中被激活和冻结的部分,参数 --config-name 表示你想要微调配置:

【Note】:这一步首次执行是需要联网的,建议设置好环境变量后再执行。

如果你不清楚这里的配置是什么意思,可以去查看src/openpi/training/config.py 文件:

修改这个文件就可以调整训练配置,包括训练 batch-sizeepochs 等参数。在确定好想要的训练配置之后就可以直接执行下面的命令进行训练或微调(取决于你命令中的配置名),这里以 PI 提供的开箱即用叠毛巾配置为例:

$ export HF_ENDPOINT="https://hf-mirror.com" 
$ uv run scripts/compute_norm_stats.py --config-name pi0_aloha_towel

运行后你会看见代码在拉取一些数据以计算训练数据的归一化信息:

在这里插入图片描述

然后就可以使用下面的命令进行模型训练,该命令会将训练进度记录到控制台,并将检查点保存到检查点目录。为了最大限度地利用 GPU 显存 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 可以使 JAX 能够使用高达 90% 的 GPU 内存,参数 --exp-name 就是你上面一步中的 --repo-id; 参数 --overwrite 用于覆盖检查点:

【Note】此处就需要用到缓存的模型文件,如果你之前在上面下载过就直接解压到 ~/.cache/openpi/openpi-assets/checkpoints 中,切记你用到哪个模型就需要哪个文件:

$ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=vlalab --overwrite

运行后出现下面这个弹窗直接输入 `3`` 不可视化结果即可:

在这里插入图片描述
然后就是将数据加载进来开始训练:
在这里插入图片描述
【Note】:如果你的设备是 Nvidia 50 系显卡,那么可能会出现以下报错:

在这里插入图片描述
这是由于 CUDA 版本和 torch 版本不匹配导致的,如果你的 CUDA 版本是 12.8 就可以直接重装 torch:

$ uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

重装 pytorch 版本还报错的话那就真的是你显存不够了,翻到博客的 1.1 节,即便是 LoRA 微调也需要 22 GB 显存,如果你只是想体验下这个微调流程并且确定有大于 22 GB 现存的硬件,那么可以在上面的执行命令中尝试使用 pi_libero ,完整命令如下但需要重新计算对应模型的方差:

$ uv run scripts/compute_norm_stats.py --config-name pi0_libero_low_mem_finetune	# 重新计算方差
$ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_libero_low_mem_finetune --exp-name=vlalab --overwrite

[未完] 9. 使用 huggingface 第三方微调模型

【未完待续】

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值