```python
import datetime
import pathlib
import threading
import time
import hydra
import numpy as np
import torch
import zmq  # ZeroMQ for inter-process communication
from typing import Optional
import dill  # serializes Python objects that standard pickle cannot
from omegaconf import OmegaConf  # configuration management

# Custom modules
from agent.schema import RobotState, RobotObsShape  # robot state and observation/message layout
from agent.utils import (
    dict_apply,  # recursively apply a function to the tensors in a dict
    interpolate_image_batch,  # batch-resize and normalize images
)
from controller.policy.dexgraspvla_controller import DexGraspVLAController  # grasping policy
import logging
import cv2
import os

# Directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Relative path to the pretrained DINOv2 weights
pretrain_path = os.path.join(current_dir, "..", "..", "..", "dinov2", "checkpoints", "dinov2_vitb14_pretrain.pth")


def log_init():
    """Configure and initialize the logger."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    format = (
        "[%(asctime)s %(levelname)s %(filename)s %(funcName)s:%(lineno)d] %(message)s"
    )
    handler = logging.StreamHandler()  # console handler
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.info("Logger inited")
    return logger


logger = log_init()  # global logger


class Robot:
    """Main robot control class: state management, observation handling, action execution."""

    def __init__(self, config):
        """
        Initialize the robot system from a config with the key fields:
        - controller_checkpoint_path: path to the controller checkpoint
        - device: compute device (cpu/cuda)
        - port: ZeroMQ port
        - executions_per_action_chunk: actions executed per predicted chunk
        """
        # Load the pretrained controller. pickle_module=dill is needed because the
        # checkpoint was saved with dill, which can serialize objects (lambdas,
        # local functions, ...) that standard pickle cannot.
        checkpoint_path = config.controller_checkpoint_path
        payload = torch.load(checkpoint_path, pickle_module=dill)
        # Point the model config at the local weights
        # payload["cfg"]["policy"]["obs_encoder"]["model_config"]["head"][
        #     "local_weights_path"
        # ] = "/home/fishros/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth"
        payload["cfg"]["policy"]["obs_encoder"]["model_config"]["head"][
            "local_weights_path"
        ] = pretrain_path
        cfg = payload["cfg"]
        cls = hydra.utils.get_class(cfg._target_)  # resolve the workspace class dynamically
        workspace = cls(cfg)
        workspace.load_payload(payload, exclude_keys=None, include_keys=None)  # load weights

        # Controller
        self.controller: DexGraspVLAController
        self.controller = workspace.model

        # Config parameters
        self.device = config.device
        self.executions_per_action_chunk = config.executions_per_action_chunk

        # Eval mode, move to the target device
        self.controller.eval()
        print(f'!! torch.cuda.is_available()={torch.cuda.is_available()}')
        self.controller.to(self.device if torch.cuda.is_available() else "cpu")

        # Sensor data
        self.head_image = None  # head camera image
        self.wrist_image = None  # wrist camera image
        self.proprioception = None  # proprioception (joint angles etc.)

        # ZeroMQ setup
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)  # reply socket
        self.port = config.port
        self.socket.bind(f"tcp://*:{self.port}")

        # State management
        self.state = RobotState.IDLE  # start idle

        # Reset action and user-input listener thread
        self.resetting_action = None
        self.listening_thread = threading.Thread(target=self.listening_mannual)
        self.listening_thread.start()  # start the user-input listener thread

    def _parse_obs(self, message: bytes) -> Optional[dict]:
        """Parse a binary observation received from the socket."""
        # Validate message length
        if len(message) != RobotObsShape.CHUNK_SIZE:
            logger.error(
                f"Invalid message size, required {RobotObsShape.CHUNK_SIZE} bytes"
            )
            return None
        # Head camera image (uint8 array)
        head_image = np.frombuffer(
            message.buffer[: RobotObsShape.HEAD_IMAGE_SIZE],
            dtype=np.uint8,
        ).reshape(RobotObsShape.HEAD_IMAGE_SHAPE)
        # Wrist camera image
        wrist_image = np.frombuffer(
            message.buffer[RobotObsShape.HEAD_IMAGE_SIZE : RobotObsShape.HEAD_IMAGE_SIZE + RobotObsShape.WRIST_IMAGE_SIZE],
            dtype=np.uint8,
        ).reshape(RobotObsShape.WRIST_IMAGE_SHAPE)
        # Proprioception (float32 array)
        proprioception = np.frombuffer(
            message.buffer[-RobotObsShape.STATE_SIZE :],
            dtype=np.float32,
        ).reshape(RobotObsShape.STATE_SHAPE)
        logger.info("Received head_image, wrist_image, and joint_angle")
        return {
            "head_image": head_image,
            "wrist_image": wrist_image,
            "proprioception": proprioception,
        }

    def listening_mannual(self) -> None:
        """User-input listener thread for manual state control."""
        logger.info("Robot is listening...")
        while True:
            user_input = input("Press <Enter> or <q> to quit: ")
            if user_input == "q":
                self.state = RobotState.FINISHED  # quit the program
            elif user_input == "i":
                self.state = RobotState.INITALIZING  # initialize
            elif user_input == "r":
                self.state = RobotState.RESETTING  # reset
            elif user_input == "f":
                self.state = RobotState.FINISHED  # finish
            else:
                logger.info("Invalid input. Please press <Enter> or <q>.")

    def _initialize(self) -> None:
        """Move the robot into the ready-to-grasp pose."""
        assert self.state == RobotState.INITALIZING
        logger.info("Initializing robot...")
        # A real implementation would perform the arm's initialization motion here
        self.state = RobotState.ACTING  # enter acting state
        logger.info("Robot initialized")

    def _reset_socket(self) -> None:
        """Reset the ZeroMQ socket."""
        logger.info("Resetting socket...")
        self.socket.close()
        self.context.term()
        # Recreate the socket
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://*:{self.port}")
        logger.info("Socket reset")

    def _reset(self) -> None:
        """Reset the robot to its home pose after a task."""
        assert self.state == RobotState.RESETTING
        logger.info("Resetting robot...")
        # A real implementation would perform the arm's reset motion here
        self.state = RobotState.ACTING
        logger.info("Robot reset")

    def _get_obs(self) -> Optional[dict]:
        """Receive and preprocess an observation."""
        logger.info("Waiting for obs...")
        message = self.socket.recv(copy=False)  # receive the observation
        obs = self._parse_obs(message)
        if obs is None:
            self._reset_socket()  # reset the socket if parsing fails
            return None
        # Update sensor data
        self.head_image = obs["head_image"]
        self.wrist_image = obs["wrist_image"]
        self.proprioception = obs["proprioception"]
        # self.head_image = cv2.imread("/home/fishros/hdx/tool/dataset_ori/imgs/0_130.jpg")
        # self.wrist_image = cv2.imread("/home/fishros/hdx/tool/dataset_ori/imgs/0_130.jpg")
        # self.proprioception = np.array([244.02, 39.33, 17.21, 291.47, 119.56, 75.05, 0.8], dtype=np.float32)
        # self.proprioception = np.array([188.07692307692307, 47.12087912087912, -3.1868131868131866, 311.56043956043953, 156.26373626373626, 64.46153846153847, 1], dtype=np.float32)
        # Image preprocessing (interpolation and dimension reordering)
        rgb_head = interpolate_image_batch(self.head_image[None, ...]).unsqueeze(0)
        rgb_wrist = interpolate_image_batch(self.wrist_image[None, ...]).unsqueeze(0)
        logger.info("Robot state updated")
        return {
            "rgb": rgb_head,  # (1,1,3,H,W)
            "right_cam_img": rgb_wrist,  # (1,1,3,H,W)
            "right_state": torch.from_numpy(self.proprioception)
            .unsqueeze(0)
            .unsqueeze(0),  # (1,1,7)
        }

    def act(self, obs: dict) -> bool:
        """Predict an action chunk with the controller and send it."""
        # Move observations to the model device
        obs = dict_apply(obs, lambda x: x.to(self.controller.device))
        # Inference without gradients
        with torch.no_grad():
            actions = self.controller.predict_action(obs_dict=obs)  # (B,64,action_dim)
        # Post-process the actions
        n_latency_steps = 3  # latency compensation steps
        actions = (
            actions.detach()
            .cpu()
            .numpy()[
                0, n_latency_steps : self.executions_per_action_chunk + n_latency_steps
            ]  # (executions_per_action_chunk, action_dim)
        )
        # Send the actions over the socket
        logger.info(f"Sent action {actions}")
        self.socket.send(actions.tobytes())
        return True

    def step(self) -> bool:
        """One step: get an observation, then act."""
        logger.info("Waiting for obs...")
        obs = self._get_obs()
        if obs is None:
            logger.error("Broken obs")
            return False
        logger.info("Robot state updated, acting...")
        if not self.act(obs):
            logger.error("Failed to send action")
            return False
        logger.info("Action sent, waiting for next obs...")
        return True

    def run(self) -> None:
        """Main control loop."""
        logger.info("Robot loop starting...")
        assert self.state == RobotState.IDLE
        self.state = RobotState.INITALIZING
        # State machine loop
        while True:
            logger.info(f"run loop with robot state: {self.state}")
            if self.state == RobotState.INITALIZING:
                self._initialize()
            elif self.state == RobotState.RESETTING:
                self._reset()
            elif self.state == RobotState.ACTING:
                self.step()  # main control step
            elif self.state == RobotState.FINISHED:
                logger.info("Robot loop finished, waiting for next command")
                # logic for awaiting a new command could go here
            else:
                logger.error("Robot loop in unknown state.")
                break


# OmegaConf resolver registration
def now_resolver(pattern: str):
    """Resolver for ${now:} time formatting."""
    return datetime.now().strftime(pattern)


OmegaConf.register_new_resolver("now", now_resolver, replace=True)
OmegaConf.register_new_resolver("eval", eval, replace=True)


@hydra.main(version_base=None, config_path="config", config_name=pathlib.Path(__file__).stem)
def main(cfg):
    """Entry point: build and run the robot."""
    robot = Robot(cfg)
    robot.run()


if __name__ == "__main__":
    main()
```

This is the server side that sends the actions, which are chunks of 8 joint-angle groups plus a gripper state. The message parsing on my receiving end has a problem: the parsed actions are all None.

```python
def parse_actions(self, action_bytes):
    """Parse the received action payload."""
    # Validate message length
    if len(action_bytes) != RobotObsShape.ACTIONS_SHAPE:
        return None
    actions = np.frombuffer(action_bytes, dtype=np.float32)
    return actions.reshape(RobotObsShape.ACTIONS_SHAPE)
```
<think>Problem analysis:
In the receiver's `parse_actions`, we check whether the received byte length equals `RobotObsShape.ACTIONS_SHAPE`. But `ACTIONS_SHAPE` is the tuple `(8, 7)`, while `len(action_bytes)` returns an integer, so the comparison is always unequal and the function always returns None.
The check should instead compare against `RobotObsShape.ACTIONS_SIZE`, i.e. 8 * 7 * 4 = 224 bytes, since each float32 occupies 4 bytes.
Suggested fix in `parse_actions`:
if len(action_bytes) != RobotObsShape.ACTIONS_SIZE:
    return None
On the server side, the payload is `actions.tobytes()`, where `actions` is a float32 numpy array of shape `(8, 7)`, so 224 bytes are sent. The client therefore needs `RobotObsShape.ACTIONS_SIZE` defined and used in the length check.
The server imports `RobotObsShape` from `agent.schema`, so `ACTIONS_SHAPE` and `ACTIONS_SIZE` should be defined consistently on both ends. On the client they were defined as:
ACTIONS_SHAPE = (NUM_ACTIONS, DOF)  # (8, 7)
ACTIONS_SIZE = NUM_ACTIONS * DOF * 4  # total payload size in bytes
So fixing the client's `parse_actions` should suffice, with one caveat. The server does not use `ACTIONS_SHAPE` directly; it slices the predicted chunk as:
actions = self.controller.predict_action(obs_dict=obs)  # (B, 64, action_dim)
actions = actions[0, n_latency_steps : n_latency_steps + self.executions_per_action_chunk]  # (executions_per_action_chunk, action_dim)
Here `action_dim` should be 7 (joint angles plus gripper), but `executions_per_action_chunk` comes from the Hydra config and is not necessarily 8 (the client's NUM_ACTIONS). If the config sets a different value, the server sends an array whose byte count does not match the client's expected (8, 7) layout, and parsing fails even after the length check is fixed.
Two fixes, then:
1. In the client's `parse_actions`, compare the byte length against `ACTIONS_SIZE` (224 bytes) instead of `ACTIONS_SHAPE` (the tuple (8, 7)).
2. Make sure the server config sets `executions_per_action_chunk = 8`.
To confirm the payload size, log it in the server's `act` method before sending:
logger.info(f"Actions shape: {actions.shape}, bytes: {actions.nbytes}")
</think>There is a problem in the action path between your server and client: the client's length check can never pass, and the action payload the server sends must exactly match the byte layout the client expects. Let's fix both:
### Problem Analysis
1. **Action payload format mismatch**:
   - The server sends an action array of shape `(executions_per_action_chunk, action_dim)`
   - The client expects `(NUM_ACTIONS, DOF)`, i.e. `(8, 7)`
   - The two must match byte for byte, or the client cannot parse the payload (see the size-check sketch below)
2. **Server-side action generation**:
```python
actions = self.controller.predict_action(obs_dict=obs)  # (B, 64, action_dim)
actions = actions.detach().cpu().numpy()[
    0, n_latency_steps : self.executions_per_action_chunk + n_latency_steps
]  # (executions_per_action_chunk, action_dim)
```
   - The resulting `actions` array has shape `(executions_per_action_chunk, action_dim)`
   - The client, however, expects `(8, 7)`
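Two details drive the failure. Here is a minimal standalone sketch (the constants are assumed from the client-side schema described above) showing why the original length check can never pass, and what byte count the client should expect instead:

```python
import numpy as np

ACTIONS_SHAPE = (8, 7)  # assumed client schema: 8 actions x 7 DoF
ACTIONS_SIZE = 8 * 7 * np.dtype(np.float32).itemsize  # 224 bytes

payload = np.zeros(ACTIONS_SHAPE, dtype=np.float32).tobytes()

# The buggy check compares an int against a tuple, which is always unequal:
print(len(payload) != ACTIONS_SHAPE)  # True -> parse_actions always returns None
# The correct check compares byte counts:
print(len(payload) == ACTIONS_SIZE)   # True (224 == 224)
```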
### Solution
Modify the server so that the action payload it sends matches what the client expects:
```python
# Modified action-generation logic in the Robot class's act method
def act(self, obs: dict) -> bool:
    """Predict actions with the controller and send them."""
    # Move observations to the model device
    obs = dict_apply(obs, lambda x: x.to(self.controller.device))
    # Inference without gradients
    with torch.no_grad():
        # Predicted actions, shape (B, 64, action_dim)
        actions = self.controller.predict_action(obs_dict=obs)
    # Post-process the actions
    n_latency_steps = 3  # latency compensation steps
    # Make sure the number of sent actions matches what the client expects
    num_actions = RobotObsShape.NUM_ACTIONS  # 8
    action_dim = RobotObsShape.DOF  # 7
    # Take the first batch element: (64, action_dim)
    actions = actions.detach().cpu().numpy()[0]
    # Make sure enough actions are available
    if actions.shape[0] < num_actions + n_latency_steps:
        logger.error(f"Not enough actions predicted: {actions.shape[0]} < {num_actions + n_latency_steps}")
        return False
    # Select NUM_ACTIONS actions starting after the latency steps
    selected_actions = actions[
        n_latency_steps : n_latency_steps + num_actions
    ]  # (num_actions, action_dim)
    # Validate the shape
    if selected_actions.shape != (num_actions, action_dim):
        logger.error(f"Invalid action shape: {selected_actions.shape} != {(num_actions, action_dim)}")
        return False
    # Send the actions over the socket
    logger.info(f"Sent action with shape: {selected_actions.shape}")
    self.socket.send(selected_actions.tobytes())
    return True
```
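The server-side changes above only help if the client's length check is fixed as well; as noted in the analysis, comparing `len(action_bytes)` against the tuple `ACTIONS_SHAPE` is always unequal, which is why you see None. The matching client-side fix (assuming your schema defines `ACTIONS_SIZE = NUM_ACTIONS * DOF * 4`):

```python
def parse_actions(self, action_bytes):
    """Parse the received action payload into an (8, 7) float32 array."""
    # Compare byte counts, not a byte count against a shape tuple
    if len(action_bytes) != RobotObsShape.ACTIONS_SIZE:  # 8 * 7 * 4 = 224 bytes
        return None
    actions = np.frombuffer(action_bytes, dtype=np.float32)
    return actions.reshape(RobotObsShape.ACTIONS_SHAPE)  # (NUM_ACTIONS, DOF)
```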
### Complete Fixed Server Code
```python
import datetime
import pathlib
import threading
import time
import hydra
import numpy as np
import torch
import zmq
from typing import Optional
import dill
from omegaconf import OmegaConf
from agent.schema import RobotState, RobotObsShape
from agent.utils import dict_apply, interpolate_image_batch
from controller.policy.dexgraspvla_controller import DexGraspVLAController
import logging
import cv2
import os

# Directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Relative path to the pretrained DINOv2 weights
pretrain_path = os.path.join(current_dir, "..", "..", "..", "dinov2", "checkpoints", "dinov2_vitb14_pretrain.pth")


def log_init():
    """Configure and initialize the logger."""
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    format = "[%(asctime)s %(levelname)s %(filename)s %(funcName)s:%(lineno)d] %(message)s"
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.info("Logger initialized")
    return logger


logger = log_init()


class Robot:
    """Main robot control class."""

    def __init__(self, config):
        # Load the pretrained controller checkpoint
        checkpoint_path = config.controller_checkpoint_path
        payload = torch.load(checkpoint_path, pickle_module=dill)
        # Point the model config at the local weights
        payload["cfg"]["policy"]["obs_encoder"]["model_config"]["head"]["local_weights_path"] = pretrain_path
        cfg = payload["cfg"]
        cls = hydra.utils.get_class(cfg._target_)
        workspace = cls(cfg)
        workspace.load_payload(payload)
        # Controller
        self.controller: DexGraspVLAController = workspace.model
        # Config parameters
        self.device = config.device
        self.executions_per_action_chunk = config.executions_per_action_chunk
        # Model setup
        self.controller.eval()
        logger.info(f'Torch CUDA available: {torch.cuda.is_available()}')
        self.controller.to(self.device if torch.cuda.is_available() else "cpu")
        # Sensor data
        self.head_image = None
        self.wrist_image = None
        self.proprioception = None
        # ZeroMQ communication
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.port = config.port
        self.socket.bind(f"tcp://*:{self.port}")
        # State management
        self.state = RobotState.IDLE
        # Start the user-input listener thread
        self.listening_thread = threading.Thread(target=self.listening_manual, daemon=True)
        self.listening_thread.start()

    def listening_manual(self) -> None:
        """User-input listener thread."""
        logger.info("Robot command listener started. Commands: [q]uit, [i]nitialize, [r]eset, [f]inish")
        while True:
            user_input = input("> ").lower()
            if user_input == "q":
                self.state = RobotState.FINISHED
                logger.info("Exit command received")
            elif user_input == "i":
                self.state = RobotState.INITALIZING
                logger.info("Initialize command received")
            elif user_input == "r":
                self.state = RobotState.RESETTING
                logger.info("Reset command received")
            elif user_input == "f":
                self.state = RobotState.FINISHED
                logger.info("Finish command received")
            else:
                logger.info(f"Unknown command: {user_input}")

    def _parse_obs(self, message: bytes) -> Optional[dict]:
        """Parse an observation payload."""
        if len(message) != RobotObsShape.CHUNK_SIZE:
            logger.error(f"Invalid message size: {len(message)} != {RobotObsShape.CHUNK_SIZE}")
            return None
        try:
            # Head camera image
            head_start = 0
            head_end = RobotObsShape.HEAD_IMAGE_SIZE
            head_image = np.frombuffer(
                message[head_start:head_end],
                dtype=np.uint8
            ).reshape(RobotObsShape.HEAD_IMAGE_SHAPE)
            # Wrist camera image
            wrist_start = head_end
            wrist_end = wrist_start + RobotObsShape.WRIST_IMAGE_SIZE
            wrist_image = np.frombuffer(
                message[wrist_start:wrist_end],
                dtype=np.uint8
            ).reshape(RobotObsShape.WRIST_IMAGE_SHAPE)
            # Proprioception
            state_start = wrist_end
            state_end = state_start + RobotObsShape.STATE_SIZE
            proprioception = np.frombuffer(
                message[state_start:state_end],
                dtype=np.float32
            ).reshape(RobotObsShape.STATE_SHAPE)
            logger.debug("Observation parsed successfully")
            return {
                "head_image": head_image,
                "wrist_image": wrist_image,
                "proprioception": proprioception,
            }
        except Exception as e:
            logger.error(f"Error parsing observation: {e}")
            return None

    def _initialize(self) -> None:
        """Initialize the robot."""
        logger.info("Initializing robot...")
        # Actual initialization motion goes here
        time.sleep(1)  # simulated initialization
        self.state = RobotState.ACTING
        logger.info("Robot initialized")

    def _reset(self) -> None:
        """Reset the robot."""
        logger.info("Resetting robot...")
        # Actual reset motion goes here
        time.sleep(1)  # simulated reset
        self.state = RobotState.ACTING
        logger.info("Robot reset")

    def _reset_socket(self) -> None:
        """Reset the socket connection."""
        logger.warning("Resetting socket connection...")
        self.socket.close()
        self.context.term()
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://*:{self.port}")
        logger.info("Socket reset complete")

    def _get_obs(self) -> Optional[dict]:
        """Receive and preprocess an observation."""
        try:
            # recv() returns bytes, which support the slicing used in _parse_obs
            # (recv(copy=False) would return a zmq.Frame, which does not)
            message = self.socket.recv()
            obs = self._parse_obs(message)
            if obs is None:
                logger.error("Failed to parse observation")
                self._reset_socket()
                return None
            # Update sensor data
            self.head_image = obs["head_image"]
            self.wrist_image = obs["wrist_image"]
            self.proprioception = obs["proprioception"]
            # Image preprocessing
            rgb_head = interpolate_image_batch(self.head_image[None, ...]).unsqueeze(0)
            rgb_wrist = interpolate_image_batch(self.wrist_image[None, ...]).unsqueeze(0)
            logger.debug("Observation processed")
            return {
                "rgb": rgb_head,  # (1,1,3,H,W)
                "right_cam_img": rgb_wrist,  # (1,1,3,H,W)
                "right_state": torch.from_numpy(self.proprioception)
                .unsqueeze(0)
                .unsqueeze(0),  # (1,1,7)
            }
        except zmq.ZMQError as e:
            logger.error(f"Socket error: {e}")
            self._reset_socket()
            return None

    def act(self, obs: dict) -> bool:
        """Generate and send actions."""
        try:
            # Move observations to the model device
            obs = dict_apply(obs, lambda x: x.to(self.controller.device))
            # Inference
            with torch.no_grad():
                # Predicted actions, shape (B, 64, action_dim)
                actions = self.controller.predict_action(obs_dict=obs)
            # Post-process the actions
            n_latency_steps = 3  # latency compensation steps
            num_actions = RobotObsShape.NUM_ACTIONS  # 8
            action_dim = RobotObsShape.DOF  # 7
            # Take the first batch element: (64, action_dim)
            actions_np = actions.detach().cpu().numpy()[0]
            # Make sure enough actions are available
            if actions_np.shape[0] < num_actions + n_latency_steps:
                logger.error(f"Insufficient actions: {actions_np.shape[0]} < {num_actions + n_latency_steps}")
                return False
            # Select 8 actions starting after the latency steps
            selected_actions = actions_np[
                n_latency_steps : n_latency_steps + num_actions
            ]  # (8, 7)
            # Validate the shape
            if selected_actions.shape != (num_actions, action_dim):
                logger.error(f"Invalid action shape: {selected_actions.shape} != {(num_actions, action_dim)}")
                return False
            # Send
            logger.info(f"Sending actions with shape {selected_actions.shape}")
            self.socket.send(selected_actions.tobytes())
            return True
        except Exception as e:
            logger.error(f"Error in action generation: {e}")
            return False

    def step(self) -> bool:
        """One observation-action step."""
        obs = self._get_obs()
        if obs is None:
            return False
        return self.act(obs)

    def run(self) -> None:
        """Main control loop."""
        logger.info("Starting robot control loop")
        while True:
            state = self.state
            logger.info(f"Current state: {state}")
            if state == RobotState.IDLE:
                time.sleep(0.5)  # wait for a user command
            elif state == RobotState.INITALIZING:
                self._initialize()
            elif state == RobotState.RESETTING:
                self._reset()
            elif state == RobotState.ACTING:
                if not self.step():
                    logger.warning("Step failed, retrying...")
                    time.sleep(0.5)
            elif state == RobotState.FINISHED:
                logger.info("Robot finished, exiting...")
                break
            else:
                logger.error(f"Unknown state: {state}")
                time.sleep(1)
        # Clean up resources
        self.socket.close()
        self.context.term()
        logger.info("Robot control loop terminated")


# OmegaConf resolver registration
def now_resolver(pattern: str):
    return datetime.datetime.now().strftime(pattern)


OmegaConf.register_new_resolver("now", now_resolver, replace=True)
OmegaConf.register_new_resolver("eval", eval, replace=True)


@hydra.main(version_base=None, config_path="config", config_name=pathlib.Path(__file__).stem)
def main(cfg):
    robot = Robot(cfg)
    robot.run()


if __name__ == "__main__":
    main()
```
### Key Fixes:
1. **Action generation**:
   - The sent action chunk now has shape `(8, 7)`, matching what the client expects
   - Added shape validation and error handling
2. **Observation parsing**:
   - Byte offsets for the three payload segments are now computed explicitly
   - `recv()` is used instead of `recv(copy=False)` so the message is plain `bytes` and supports slicing
   - Added more robust error handling
3. **State management**:
   - Improved the state-transition logic
   - Added more detailed logging
4. **Socket handling**:
   - Added socket error handling and automatic reset
   - Made message receiving and sending more reliable
5. **Thread management**:
   - The listener thread is now a daemon thread
   - Added friendlier command prompts
### Usage:
1. **Start the server**:
```bash
python server.py
```
2. **Available commands**:
   - `i`: initialize the robot
   - `r`: reset the robot
   - `f`: finish the task
   - `q`: quit the program
3. **Log monitoring**:
   - The server logs details of the received observations and the sent actions
   - On errors it attempts automatic recovery
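To verify the round trip end to end without the real robot, a minimal REQ client can send a dummy observation and parse the returned actions. This is a sketch under assumptions: the schema constants below (image shapes, `ACTIONS_SHAPE`, `ACTIONS_SIZE`) are placeholders that must be replaced with your actual `RobotObsShape` values, and the port must match the server config:

```python
import numpy as np
import zmq

# Assumed layout -- substitute your real RobotObsShape values
HEAD_IMAGE_SHAPE = (480, 640, 3)
WRIST_IMAGE_SHAPE = (480, 640, 3)
STATE_SHAPE = (7,)
ACTIONS_SHAPE = (8, 7)
ACTIONS_SIZE = 8 * 7 * 4  # float32 -> 4 bytes each
PORT = 5555

context = zmq.Context()
socket = context.socket(zmq.REQ)  # REQ pairs with the server's REP socket
socket.connect(f"tcp://localhost:{PORT}")

# Build a dummy observation: head image + wrist image + proprioception
payload = (
    np.zeros(HEAD_IMAGE_SHAPE, dtype=np.uint8).tobytes()
    + np.zeros(WRIST_IMAGE_SHAPE, dtype=np.uint8).tobytes()
    + np.zeros(STATE_SHAPE, dtype=np.float32).tobytes()
)
socket.send(payload)

# Receive and parse the action chunk
reply = socket.recv()
if len(reply) != ACTIONS_SIZE:
    print(f"Unexpected reply size: {len(reply)} != {ACTIONS_SIZE}")
else:
    actions = np.frombuffer(reply, dtype=np.float32).reshape(ACTIONS_SHAPE)
    print(f"Received actions with shape {actions.shape}:\n{actions}")
```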
### Notes:
1. Make sure the client and server use the same `RobotObsShape` definition (a sketch of such a shared schema follows below)
2. If problems persist, raise the log verbosity to inspect the communication flow in detail
3. Make sure the network connection is stable and the ZeroMQ port is not blocked by a firewall
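For reference, a shared schema consistent with the sizes assumed above might look like the following sketch. The image shapes are placeholders (your camera resolutions will differ); only `NUM_ACTIONS`, `DOF`, and the derived byte sizes are pinned down by this discussion:

```python
import numpy as np

class RobotObsShape:
    """Byte layout shared by server and client (image shapes are placeholders)."""
    # Observation layout: head image + wrist image + proprioception
    HEAD_IMAGE_SHAPE = (480, 640, 3)
    WRIST_IMAGE_SHAPE = (480, 640, 3)
    STATE_SHAPE = (7,)
    HEAD_IMAGE_SIZE = int(np.prod(HEAD_IMAGE_SHAPE))    # uint8: 1 byte per element
    WRIST_IMAGE_SIZE = int(np.prod(WRIST_IMAGE_SHAPE))  # uint8: 1 byte per element
    STATE_SIZE = int(np.prod(STATE_SHAPE)) * 4          # float32: 4 bytes per element
    CHUNK_SIZE = HEAD_IMAGE_SIZE + WRIST_IMAGE_SIZE + STATE_SIZE

    # Action layout: 8 steps of (joint angles + gripper state)
    NUM_ACTIONS = 8
    DOF = 7
    ACTIONS_SHAPE = (NUM_ACTIONS, DOF)
    ACTIONS_SIZE = NUM_ACTIONS * DOF * 4                # float32: 4 bytes per element
```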