class MultiDataLoader(object): def __init__(self, dataloaders, n_batches): if n_batches <= 0: raise ValueError("n_batches should be > 0") self._dataloaders = dataloaders self._n_batches = np.maximum(1, n_batches) self._init_iterators()什么意思
时间: 2023-05-11 16:01:54 浏览: 173
这段代码定义了一个名为 MultiDataLoader 的类,它的构造函数接受两个参数:一个 dataloaders 列表和一个 n_batches 数字。如果 n_batches 小于等于 0,将会抛出 ValueError 异常。该类还有一个名为 _init_iterators 的私有方法。
相关问题
CONF_THRESH = 0.5 IOU_THRESHOLD = 0.4 POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 OBB_NUM = 1 def get_img_path_batches(batch_size, img_dir): ret = [] batch = [] for root, dirs, files in os.walk(img_dir): for name in files: if len(batch) == batch_size: ret.append(batch) batch = [] batch.append(os.path.join(root, name)) if len(batch) > 0: ret.append(batch) return ret def plot_one_box(x, img, color=None, label=None, line_thickness=None): """ description: Plots one bounding box on image img, this function comes from YoLov8 project. param: x: a box likes [x1,y1,x2,y2] img: a opencv image object color: color to draw rectangle, such as (0,255,0) label: str line_thickness: int return: no return """ tl = ( line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 ) # line/font thickness color = color or [random.randint(0, 255) for _ in range(3)] c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) if label: tf = max(tl - 1, 1) # font thickness t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled cv2.putText( img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, ) class YoLov8TRT(object): """ description: A YOLOv8 class that warps TensorRT ops, preprocess and postprocess ops. """ def __init__(self, engine_file_path): # Create a Context on this device, self.ctx = cuda.Device(0).make_context() stream = cuda.Stream() TRT_LOGGER = trt.Logger(trt.Logger.INFO) runtime = trt.Runtime(TRT_LOGGER) # Deserialize the engine from file with open(engine_file_path, "rb") as f: engine = runtime.deserialize_cuda_engine(f.read()) context = engine.create_execution_context() host_inputs = [] cuda_inputs = [] host_outputs = [] cuda_outputs = [] bindings = [] for binding in engine: print('bingding:', binding, engine.get_binding_shape(binding)) size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size dtype = trt.nptype(engine.get_binding_dtype(binding)) # Allocate host and device buffers host_mem = cuda.pagelocked_empty(size, dtype) cuda_mem = cuda.mem_alloc(host_mem.nbytes) # Append the device buffer to device bindings. bindings.append(int(cuda_mem)) # Append to the appropriate list. if engine.binding_is_input(binding): self.input_w = engine.get_binding_shape(binding)[-1] self.input_h = engine.get_binding_shape(binding)[-2] host_inputs.append(host_mem) cuda_inputs.append(cuda_mem) else: host_outputs.append(host_mem) cuda_outputs.append(cuda_mem) # Store self.stream = stream self.context = context self.engine = engine self.host_inputs = host_inputs self.cuda_inputs = cuda_inputs self.host_outputs = host_outputs self.cuda_outputs = cuda_outputs self.bindings = bindings self.batch_size = engine.max_batch_size self.det_output_length = host_outputs[0].shape[0] def infer(self, raw_image_generator): threading.Thread.__init__(self) # Make self the active context, pushing it on top of the context stack. self.ctx.push() # Restore stream = self.stream context = self.context host_inputs = self.host_inputs cuda_inputs = self.cuda_inputs host_outputs = self.host_outputs cuda_outputs = self.cuda_outputs bindings = self.bindings # Do image preprocess batch_image_raw = [] batch_origin_h = [] batch_origin_w = [] batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w]) for i, image_raw in enumerate(raw_image_generator): input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw) batch_image_raw.append(image_raw) batch_origin_h.append(origin_h) batch_origin_w.append(origin_w) np.copyto(batch_input_image[i], input_image) batch_input_image = np.ascontiguousarray(batch_input_image) # Copy input image to host buffer np.copyto(host_inputs[0], batch_input_image.ravel()) start = time.time() # Transfer input data to the GPU. cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream) # Run inference. context.execute_async(batch_size=self.batch_size, bindings=bindings, stream_handle=stream.handle) # Transfer predictions back from the GPU. cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream) # Synchronize the stream stream.synchronize() end = time.time() # Remove any context from the top of the context stack, deactivating it. self.ctx.pop() # Here we use the first row of output in that batch_size = 1 output = host_outputs[0] # Do postprocess for i in range(self.batch_size): result_boxes, result_scores, result_classid = self.post_process( output[i * self.det_output_length: (i + 1) * self.det_output_length], batch_origin_h[i], batch_origin_w[i] ) # Draw rectangles and labels on the original image for j in range(len(result_boxes)): box = result_boxes[j] plot_one_box( box, batch_image_raw[i], label="{}:{:.2f}".format( categories[int(result_classid[j])], result_scores[j] ), ) return batch_image_raw, end - start def destroy(self): # Remove any context from the top of the context stack, deactivating it. self.ctx.pop() def get_raw_image(self, image_path_batch): """ description: Read an image from image path """ for img_path in image_path_batch: yield cv2.imread(img_path) def get_raw_image_zeros(self, image_path_batch=None): """ description: Ready data for warmup """ for _ in range(self.batch_size): yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8) def preprocess_image(self, raw_bgr_image): """ description: Convert BGR image to RGB, resize and pad it to target size, normalize to [0,1], transform to NCHW format. param: input_image_path: str, image path return: image: the processed image image_raw: the original image h: original height w: original width """
<think>我们正在使用一个基于TensorRT加速的YOLOv8推理代码,目前代码只支持检测一个类别(drone)。现在需要扩展为多类别检测。
从代码中可以看出,当前代码中定义了几个常量:
POSE_NUM = 17 * 3 # 关键点检测的数量(每个点3个值)
DET_NUM = 6 # 每个检测框的原始输出数量(x, y, w, h, conf, class_id)? 注意:实际上YOLOv8的输出格式可能不同
SEG_NUM = 32 # 分割掩码的参数数量
OBB_NUM = 1 # 定向边界框(OBB)的参数,这里可能是1个角度?
在`post_process`函数中,我们注意到:
num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
这个值表示每个检测结果的参数总数。
然而,在标准的YOLOv8目标检测中(不带分割和姿态估计),每个检测框的输出应该是:4个坐标(x,y,w,h) + 1个置信度 + (类别数)个类别分数。
但是,在代码中,DET_NUM被设置为6,这可能是用于一个类别的情况(4坐标+1置信度+1类别分数)?但多类别情况下,类别分数应该是多个。
实际上,YOLOv8的TensorRT输出格式取决于模型。在导出模型时,如果是检测模型,输出维度为[1, 84, 8400](84=4+80,80是COCO类别数)。
但是,我们当前代码中,每个检测框的参数数量是固定的(由DET_NUM+SEG_NUM+...等构成)。这暗示我们使用的可能是YOLOv8的集成输出(同时包含检测、分割、姿态等)?
然而,用户的问题中只提到目标检测,并且当前代码中只处理了检测框(没有分割和姿态)。因此,我们需要明确:
1. 我们使用的模型是什么?是纯检测模型,还是带有其他任务(如分割、姿态)的模型?
2. 当前代码中,DET_NUM=6,这显然只适用于单类别(4坐标+1置信度+1类别分数)或者多类别但只取最大类别分数的情况?
观察代码中的`post_process`函数:
pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
然后,在`non_max_suppression`函数中,我们只使用了前6个值(x,y,w,h,conf,class_id)? 因为boxes = prediction[prediction[:, 4] >= conf_thres] 中,第4个是置信度。
实际上,在代码中,每个检测框的参数包括:
[x, y, w, h, conf, class_id, ... (其他参数,如分割、姿态等)]
因此,对于多类别检测,我们需要改变的是:
- 模型输出中,每个检测框的类别分数部分应该是多个分数(每个类别一个分数),而不是一个类别id。
但是,在代码中,我们注意到在`post_process`函数中,我们只取每个检测框的前DET_NUM(6)个参数。这6个参数中,第5个(索引4)是置信度,第6个(索引5)是类别id吗?
然而,在标准的多类别YOLOv8中,每个检测框的输出应该是:
[x, y, w, h, obj_conf, class_conf_0, class_conf_1, ... , class_conf_n]
然后,总的置信度 = obj_conf * max(class_conf) ?而类别id是argmax(class_conf)。
因此,当前代码中DET_NUM=6,只能处理5个类别(因为4坐标+1个obj_conf+5个类别分数=10,但这里DET_NUM=6)?这显然不对。
重新审视:在代码中,我们定义:
DET_NUM = 6
但实际在YOLOv8的TensorRT输出中,每个检测框的向量长度应该是:4(坐标)+ 1(obj_conf) + 类别数(class_conf)。因此,DET_NUM应该等于5+类别数。
那么,为什么这里DET_NUM=6?因为当前模型是单类别(drone)?所以类别数=1,那么总长度=4+1+1=6。
所以,要支持多类别,我们需要:
1. 确保模型是多类别模型(在训练时指定了多个类别)。
2. 修改代码中的DET_NUM,使其等于5+类别数(因为4坐标+1个obj_conf+类别数个类别分数)。
3. 修改后处理部分,正确解析类别分数。
但是,在代码中,我们注意到:
num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
因此,如果我们修改了DET_NUM,那么整个向量长度就变了。但是,在代码中,我们只使用了每个检测框的前DET_NUM个值(在`non_max_suppression`函数中,我们只取prediction[:, :DET_NUM]?并不是,我们直接取了整个prediction,但在非极大值抑制中只用了前6个值(通过boxes = prediction[prediction[:,4]>=conf_thres]))。
实际上,在`non_max_suppression`函数中,我们传入的是整个预测数组(每个检测框有num_values_per_detection个值),但在函数内部,我们只使用了前6个(即x,y,w,h,conf,class_id)?但是,在单类别情况下,第5个(索引4)是置信度,第6个(索引5)是类别id(只有0,因为只有一个类别)。
对于多类别,我们需要改变的是:在`non_max_suppression`函数中,我们不应该直接使用第5个(索引5)作为类别id,而应该从第5个位置开始(索引4是obj_conf,然后后面是类别分数)计算每个类别的分数,然后取最大值作为类别置信度,并得到类别id。
因此,我们需要重构后处理部分。
步骤:
1. 修改DET_NUM:根据类别数,DET_NUM = 5 + num_classes(因为4坐标+1个obj_conf+num_classes个类别分数)。
但是,注意:在模型导出时,这个值已经固定了。所以,我们应该根据实际模型来设置这个值。
2. 在`post_process`函数中,我们得到每个检测框的向量(长度为num_values_per_detection),然后我们只取前DET_NUM个值(即检测部分),然后我们需要解析这个向量,得到每个类别的分数。
3. 修改`non_max_suppression`函数,使其能够处理多类别。
具体修改如下:
首先,我们需要知道类别数。可以在代码开头定义一个全局变量:
NUM_CLASSES = 80 # 例如,COCO有80个类别
然后,修改:
DET_NUM = 5 + NUM_CLASSES # 4坐标 + 1个obj_conf + NUM_CLASSES个类别分数
注意:在YOLOv8的原始输出中,每个检测框的格式是:
[x, y, w, h, obj_conf, class_conf_0, class_conf_1, ... , class_conf_{NUM_CLASSES-1}]
然后,在`non_max_suppression`函数中,我们需要做如下修改:
- 首先,从每个检测框的向量中取出前DET_NUM个值,然后分离出坐标(前4个)、obj_conf(第5个)和类别分数(后面的NUM_CLASSES个)。
- 计算每个类别的置信度:conf = obj_conf * class_conf_i
- 然后,我们保留那些最大类别置信度超过阈值的框,并记录其类别id和置信度。
但是,当前代码中,我们直接使用了prediction[:,4]作为置信度,这实际上只是obj_conf(没有乘以类别分数)。而且,prediction[:,5]作为类别id,这只有在单类别情况下正确(因为只有一个类别分数,所以类别id总是0)。
因此,我们需要重写`non_max_suppression`函数,使其能够处理多类别。
另外,注意在原始代码中,我们传入`non_max_suppression`函数的prediction参数是每个检测框的完整向量(包括检测、分割、姿态等),但我们只使用前DET_NUM个值(即检测部分)。所以,在函数内部,我们只处理检测部分。
修改后的`non_max_suppression`函数(多类别版本):
步骤:
1. 提取每个检测框的检测部分(前DET_NUM个值)
2. 分离坐标(0:4)、obj_conf(4:5)、类别分数(5:5+NUM_CLASSES)
3. 计算每个类别的置信度 = obj_conf * 类别分数(每个类别分别乘)
4. 找到每个框的最大置信度以及对应的类别id
5. 根据最大置信度进行初步筛选(大于conf_thres)
6. 然后,对每个类别分别进行NMS(或者使用多类别NMS)
但是,为了简单和效率,我们通常采用:
- 先对每个类别进行循环,然后分别做NMS,最后合并。
- 或者,使用单次多类别NMS(如cv2.dnn.NMSBoxes支持多类别)。
然而,由于我们使用的是自己的后处理,我们可以采用以下方法:
1. 先创建一个空列表,用于存放所有保留的框。
2. 遍历每个类别:
- 取出当前类别的置信度
- 筛选出大于阈值的框
- 对这些框进行NMS(使用同一类别内的框)
- 将保留的框加入列表
3. 或者,我们也可以使用另一种常见做法:先找出每个框最大分数对应的类别,然后按照这个类别分组,再对每个类别做NMS。
但是,这样效率较低。另一种高效做法是:先选出每个框的最大分数,然后按照这个最大分数进行排序,然后进行跨类别的NMS(即不同类别的框也参与抑制,但通常不同类别的框不互相抑制?)—— 实际上,不同类别的框即使重叠也不应该被抑制,所以应该按类别分别进行NMS。
然而,在目标检测中,通常NMS是每个类别独立进行的。因此,我们需要按类别处理。
但是,由于我们的检测框数量可能很大,且类别较多,循环处理可能较慢。我们可以使用矩阵操作来避免循环,但这里为了清晰,我们使用循环每个类别的方式。
修改`non_max_suppression`函数:
注意:这个函数现在需要额外的参数`num_classes`,并且需要知道DET_NUM(即检测框向量的长度)?实际上,我们已经在函数内部使用了全局的NUM_CLASSES和DET_NUM。
但是,为了减少全局变量,我们可以将类别数作为参数传入。
由于修改较大,我们重新设计`non_max_suppression`函数:
输入:
prediction: 一个数组,形状为[N, DET_NUM+SEG_NUM+...],但我们现在只关心前DET_NUM个值。
origin_h, origin_w: 原始图像尺寸
conf_thres: 置信度阈值
iou_thres: IOU阈值
num_classes: 类别数
输出:
一个数组,每行是[x1, y1, x2, y2, conf, class_id]
步骤:
1. 提取检测部分(前DET_NUM列) -> dets: [N, DET_NUM]
2. 分离:boxes = dets[:, :4] # [N,4]
obj_conf = dets[:, 4] # [N]
cls_scores = dets[:, 5:5+num_classes] # [N, num_classes]
3. 计算类别置信度:class_conf = obj_conf.reshape(-1,1) * cls_scores [N, num_classes]
4. 找到每个框的最大类别置信度:max_conf = np.max(class_conf, axis=1) [N]
class_id = np.argmax(class_conf, axis=1) [N]
5. 初步筛选:保留max_conf > conf_thres的框索引
mask = max_conf > conf_thres
boxes = boxes[mask]
class_id = class_id[mask]
max_conf = max_conf[mask]
6. 将坐标转换到原始图像(使用xywh2xyxy函数,但注意这个函数原本是针对单框向量设计的,现在boxes是多个框)
7. 对每个类别进行循环,进行NMS:
总保留框列表:keep_boxes = []
对每个类别id c:
从当前类别的框中,选出所有类别id等于c的框(注意:我们上面已经按最大类别筛选,所以每个框的类别就是class_id)
然后,对当前类别的框进行NMS(使用IOU阈值)
将保留的框加入keep_boxes
8. 如果keep_boxes不为空,则合并所有类别的框,然后返回。
但是,这样实现可能会比较慢。另一种方法是:先按类别分组,然后对每个组分别进行NMS,最后合并。
然而,在原始代码中,我们只进行了单类别的NMS(因为当时只有单类别)。现在需要支持多类别,所以必须按类别分组处理。
修改后的`non_max_suppression`函数(多类别版本)如下:
注意:由于原始代码中,这个函数是在类的方法中,所以我们需要修改为类的方法。
但是,由于代码较长,我们只给出关键部分。
然而,考虑到时间,我们也可以采用另一种方式:使用OpenCV的NMSBoxes,它支持多类别。但注意,OpenCV的NMSBoxes需要每个框的坐标和置信度,并且需要按类别提供。
我们可以这样做:
1. 将每个框的坐标转换为xyxy
2. 创建一个列表,存放所有框的坐标(xyxy)和对应的置信度、类别id
3. 调用cv2.dnn.NMSBoxes进行多类别NMS(注意:它要求传入所有框的坐标和每个框的得分,以及类别信息?)
但是,cv2.dnn.NMSBoxes只支持单类别NMS(即每个框只有一个分数),它不能直接处理多类别。所以,我们需要将每个框按照其最大类别分数作为该框的分数,然后使用NMSBoxes进行抑制(这样会抑制不同类别但重叠的框吗?不会,因为NMSBoxes在抑制时并不考虑类别,它只根据框的坐标和分数来抑制。所以,不同类别的框即使重叠也不会被抑制。因此,我们可以将所有框放在一起,然后调用一次NMSBoxes。
具体步骤:
1. 将每个框的坐标转换为xyxy(在原始图像坐标上)
2. 计算每个框的最大类别分数(已经在上面的步骤4中计算)
3. 将所有框(跨类别)的xyxy坐标、最大类别分数、类别id收集起来
4. 使用NMSBoxes进行非极大值抑制(指定score_threshold=conf_thres, nms_threshold=iou_thres)
这样,NMSBoxes会返回保留框的索引。
然后,我们根据这些索引,取出保留的框、分数和类别id。
因此,我们可以这样修改:
在`non_max_suppression`函数中:
... 步骤1-4(得到boxes, max_conf, class_id)...
# 步骤5:筛选
mask = max_conf > conf_thres
boxes = boxes[mask]
max_conf = max_conf[mask]
class_id = class_id[mask]
# 步骤6:将boxes从xywh转换到xyxy(注意:这里boxes是[N,4]格式,每行是[x,y,w,h])
# 注意:原始代码中有一个xywh2xyxy函数,但它处理的是单个框?实际上,我们修改为可以处理多个框。
# 修改xywh2xyxy函数,使其支持批量处理(现在输入是[N,4])
# 由于原始xywh2xyxy函数是针对一个框(一行)的,现在我们需要一个批量版本
# 定义一个新函数:xywh2xyxy_batch(batch_boxes)
# 或者,我们直接在函数内部转换:
# x1 = boxes[:,0] - boxes[:,2]/2
# y1 = boxes[:,1] - boxes[:,3]/2
# x2 = boxes[:,0] + boxes[:,2]/2
# y2 = boxes[:,1] + boxes[:,3]/2
# 但是注意:原始代码中xywh2xyxy函数考虑了预处理时的填充和缩放,所以我们需要用原始图像的尺寸进行转换。因此,我们仍然需要使用类中的xywh2xyxy函数,但需要修改为批量处理。
因此,我们修改类中的`xywh2xyxy`函数,使其支持批量处理。
原函数:
def xywh2xyxy(self, origin_h, origin_w, x):
# x: [center_x, center_y, w, h]
...
修改为:
def xywh2xyxy(self, origin_h, origin_w, x):
"""
x: 一个框或者多个框,形状为[N,4](每行是center_x, center_y, w, h)
返回:转换后的xyxy坐标,形状[N,4]
"""
# 复制原始方法,但改为批量计算
y = np.zeros_like(x)
r_w = self.input_w / origin_w
r_h = self.input_h / origin_h
if r_h > r_w:
# 批量处理
y[:, 0] = x[:, 0] # center_x
y[:, 2] = x[:, 2] # w
y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
y /= r_w
else:
y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
y[:, 1] = x[:, 1]
y[:, 3] = x[:, 3]
y /= r_h
# 将center_x, center_y, w, h 转换为 x1,y1,x2,y2
# 注意:y现在每行是[center_x, center_y, w, h](已经转换到原始图像尺寸)
# 然后,我们转换为xyxy
x1 = y[:, 0] - y[:, 2] / 2
y1 = y[:, 1] - y[:, 3] / 2
x2 = y[:, 0] + y[:, 2] / 2
y2 = y[:, 1] + y[:, 3] / 2
return np.stack([x1, y1, x2, y2], axis=1)
但是,这样修改会破坏原来的单框处理?所以,我们重新写一个批量处理的函数,或者修改原函数使其同时支持单个框和多个框。
为了简单,我们直接修改为批量处理。
注意:原函数中x的输入是[4]或[N,4],输出对应。
修改后的`xywh2xyxy`函数:
def xywh2xyxy(self, origin_h, origin_w, x):
if len(x.shape) == 1:
x = x[None, :] # 变成[1,4]
y = np.zeros_like(x)
r_w = self.input_w / origin_w
r_h = self.input_h / origin_h
if r_h > r_w:
y[:, 0] = x[:, 0]
y[:, 2] = x[:, 2]
y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
y /= r_w
else:
y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
y[:, 1] = x[:, 1]
y[:, 3] = x[:, 3]
y /= r_h
# 转换为xyxy
xyxy_boxes = np.zeros_like(y)
xyxy_boxes[:, 0] = y[:, 0] - y[:, 2] / 2 # x1 = center_x - w/2
xyxy_boxes[:, 1] = y[:, 1] - y[:, 3] / 2 # y1
xyxy_boxes[:, 2] = y[:, 0] + y[:, 2] / 2 # x2
xyxy_boxes[:, 3] = y[:, 1] + y[:, 3] / 2 # y2
return xyxy_boxes
这样,我们就可以批量处理了。
回到`non_max_suppression`函数:
# 转换坐标
xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4] 每行是x1,y1,x2,y2
# 现在,xyxy_boxes是原始图像坐标下的边界框
# 然后,我们使用OpenCV的NMSBoxes
indices = cv2.dnn.NMSBoxes(
bboxes=xyxy_boxes.tolist(),
scores=max_conf.tolist(),
score_threshold=conf_thres,
nms_threshold=iou_thres,
eta=1.0,
top_k=200
)
# indices是一个一维数组,表示保留的框的索引
if len(indices) > 0:
indices = indices.flatten()
result_boxes = xyxy_boxes[indices]
result_scores = max_conf[indices]
result_classid = class_id[indices]
else:
result_boxes = np.array([])
result_scores = np.array([])
result_classid = np.array([])
# 合并结果,返回一个数组,每行是[x1,y1,x2,y2, conf, class_id]
if len(result_boxes) > 0:
result = np.concatenate([
result_boxes,
result_scores[:, None],
result_classid[:, None]
], axis=1)
else:
result = np.zeros((0,6))
return result
这样,我们就完成了多类别的NMS。
但是,注意:在原始代码中,`non_max_suppression`函数返回的是三个独立的数组(result_boxes, result_scores, result_classid)。现在,为了保持接口一致,我们仍然返回三个独立的数组(因为调用方是这么使用的)。
因此,我们可以这样返回:
return result_boxes, result_scores, result_classid
但是,注意:在调用`non_max_suppression`的地方(在`post_process`函数中):
result_boxes, result_scores, result_classid = self.non_max_suppression(...)
所以,我们修改`non_max_suppression`函数,让它返回三个独立的数组(即使没有检测框,返回空数组)。
修改后的`non_max_suppression`函数(多类别版)如下:
def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4):
# prediction: [N, num_values_per_detection] 其中N是检测框数量
# 我们只取前DET_NUM个值(检测部分)
dets = prediction[:, :DET_NUM] # [N, DET_NUM]
# 分离
boxes = dets[:, :4] # [N,4] (x,y,w,h)
obj_conf = dets[:, 4] # [N]
cls_scores = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES]
# 计算类别置信度
class_conf = obj_conf[:, None] * cls_scores # [N, NUM_CLASSES]
# 找到每个框的最大类别置信度及其类别id
max_conf = np.max(class_conf, axis=1) # [N]
class_id = np.argmax(class_conf, axis=1) # [N]
# 初步筛选:保留大于置信度阈值的框
mask = max_conf > conf_thres
boxes = boxes[mask]
max_conf = max_conf[mask]
class_id = class_id[mask]
if len(boxes) == 0:
return np.array([]), np.array([]), np.array([])
# 转换坐标到原始图像(xywh to xyxy for batch)
xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4] (x1,y1,x2,y2)
# 使用OpenCV的NMSBoxes
# 注意:需要将坐标转换为整数?不需要,NMSBoxes支持浮点数
indices = cv2.dnn.NMSBoxes(
bboxes=xyxy_boxes.tolist(),
scores=max_conf.tolist(),
score_threshold=conf_thres, # 这个参数在OpenCV中实际上不起作用,因为我们已经筛选过了,所以设为0
nms_threshold=iou_thres,
eta=1.0,
top_k=200
)
if len(indices) > 0:
indices = indices.flatten()
result_boxes = xyxy_boxes[indices]
result_scores = max_conf[indices]
result_classid = class_id[indices]
else:
result_boxes = np.array([])
result_scores = np.array([])
result_classid = np.array([])
return result_boxes, result_scores, result_classid
注意:在调用NMSBoxes时,score_threshold参数可以设为0,因为我们已经在前面筛选过了。或者,我们可以不提前筛选,而是让NMSBoxes来筛选,但这样需要传入所有框(包括低置信度的),效率低。所以,我们提前筛选。
但是,在NMSBoxes中,score_threshold参数是用于筛选的,所以我们可以将score_threshold设为conf_thres,然后去掉我们自己的筛选步骤。这样,我们就不需要提前筛选了。
因此,我们可以不进行提前筛选,而是将所有框(包括低于阈值的)传入,然后让NMSBoxes根据score_threshold进行筛选。
修改:
# 去掉提前筛选的步骤,直接传入所有框(但注意:框的数量可能很多,所以效率问题?)
# 转换坐标
xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # 这里的boxes是原始的所有框(没有经过筛选)
# 然后调用NMSBoxes,设置score_threshold=conf_thres
indices = cv2.dnn.NMSBoxes(
bboxes=xyxy_boxes.tolist(),
scores=max_conf.tolist(), # 每个框的最大类别分数
score_threshold=conf_thres, # 这里会过滤掉低于阈值的框
nms_threshold=iou_thres
)
这样,我们就不需要提前筛选了。
但是,注意:max_conf和class_id是每个框的最大类别分数和类别id,即使这个最大分数低于阈值,我们也计算了。所以,我们可以先计算,然后传入NMSBoxes,它会根据score_threshold过滤。
因此,修改后的函数:
def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4):
dets = prediction[:, :DET_NUM] # [N, DET_NUM]
boxes = dets[:, :4] # [N,4]
obj_conf = dets[:, 4] # [N]
cls_scores = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES]
class_conf = obj_conf[:, None] * cls_scores # [N, NUM_CLASSES]
max_conf = np.max(class_conf, axis=1) # [N]
class_id = np.argmax(class_conf, axis=1) # [N]
# 转换坐标
xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4]
# 使用NMSBoxes
indices = cv2.dnn.NMSBoxes(
xyxy_boxes.tolist(),
max_conf.tolist(),
conf_thres,
iou_thres
)
if len(indices) > 0:
indices = indices.flatten()
result_boxes = xyxy_boxes[indices]
result_scores = max_conf[indices]
result_classid = class_id[indices]
else:
result_boxes = np.array([])
result_scores = np.array([])
result_classid = np.array([])
return result_boxes, result_scores, result_classid
注意:在调用NMSBoxes时,我们传入的bboxes是列表(每个元素是一个框的四个坐标),scores是每个框的分数(最大类别分数),score_threshold是conf_thres,nms_threshold是iou_thres。
这样,我们就完成了多类别的支持。
最后,我们需要修改几个全局变量:
NUM_CLASSES = 80 # 根据你的模型修改
DET_NUM = 5 + NUM_CLASSES # 4坐标+1个obj_conf+类别数
注意:SEG_NUM, POSE_NUM, OBB_NUM 可能需要根据模型调整。如果你的模型是纯检测模型(没有分割、姿态等),那么SEG_NUM和POSE_NUM应该为0。但是,在原始代码中,这些值被加在num_values_per_detection中。所以,如果你的模型输出中包含了这些额外的参数,那么num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM 应该等于模型输出中每个检测框的向量长度。
因此,请根据实际模型设置这些常量。
总结修改步骤:
1. 在代码开头,根据模型设置类别数:NUM_CLASSES = 80(根据你的模型修改)
2. 修改DET_NUM:DET_NUM = 5 + NUM_CLASSES
3. 修改xywh2xyxy函数,使其支持批量处理(如上述)
4. 修改non_max_suppression函数,使用上述的多类别处理方法(使用OpenCV的NMSBoxes)
5. 在post_process函数中,确保num_values_per_detection的计算正确(DET_NUM+SEG_NUM+...)
注意:如果你的模型没有分割和姿态等任务,那么SEG_NUM, POSE_NUM, OBB_NUM都应该设为0。
这样,代码就应该可以支持多类别检测了。
另外,在绘制框时,我们使用:
label="{}:{:.2f}".format(categories[int(result_classid[j])], result_scores[j])
其中,categories是类别名称列表(现在应该是多类别的列表)。所以,你需要将categories修改为包含所有类别名称的列表。
例如:
categories = ["class0", "class1", ...] # 你的类别名称
在代码的最后,我们初始化categories时,只定义了["drone"],现在需要修改为所有类别。
最后,请注意:模型本身必须是多类别模型。如果模型是单类别训练的,那么即使我们修改了代码,也无法检测多类别。
因此,确保你使用的TensorRT引擎文件是多类别模型导出的。
修改后的代码结构如下(仅显示修改部分):</think>要修改 YOLOv8 TensorRT 推理代码以支持多类别检测,需要调整以下几个关键部分。以下是详细的修改步骤和代码实现:
### 1. **修改类别定义和输出维度**
```python
# 修改类别数量(根据实际模型调整)
NUM_CLASSES = 80 # 例如 COCO 数据集有 80 个类别
DET_NUM = 5 + NUM_CLASSES # 4坐标 + 1置信度 + 类别数
POSE_NUM = 17 * 3
SEG_NUM = 32
OBB_NUM = 1
# 更新类别标签(示例使用 COCO 类别)
categories = ["person", "bicycle", "car", ... , "toothbrush"] # 完整 80 个类别
```
### 2. **重构后处理函数**
修改 `post_process` 和 `non_max_suppression` 函数以支持多类别:
```python
def post_process(self, output, origin_h, origin_w):
num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
num = int(output[0])
pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
# 多类别处理
boxes = self.multi_class_nms(pred, origin_h, origin_w,
conf_thres=CONF_THRESH,
iou_thres=IOU_THRESHOLD)
result_boxes = boxes[:, :4] if len(boxes) else np.array([])
result_scores = boxes[:, 4] if len(boxes) else np.array([])
result_classid = boxes[:, 5] if len(boxes) else np.array([])
return result_boxes, result_scores, result_classid
def multi_class_nms(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4):
"""
多类别非极大值抑制处理
"""
# 提取检测部分 [x, y, w, h, obj_conf, class_conf_0, ..., class_conf_n]
dets = prediction[:, :DET_NUM]
boxes = dets[:, :4] # [N, 4]
obj_conf = dets[:, 4] # [N]
cls_conf = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES]
# 计算类别置信度 = obj_conf * class_conf
class_conf = obj_conf[:, None] * cls_conf # [N, NUM_CLASSES]
# 找到每个框的最大类别分数和对应类别ID
max_conf = np.max(class_conf, axis=1) # [N]
class_id = np.argmax(class_conf, axis=1) # [N]
# 转换坐标到原始图像尺寸
xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N, 4]
# 使用 OpenCV 的多类别 NMS
nms_boxes = []
for cls_idx in range(NUM_CLASSES):
cls_mask = (class_id == cls_idx) & (max_conf > conf_thres)
if not np.any(cls_mask):
continue
cls_boxes = xyxy_boxes[cls_mask]
cls_scores = max_conf[cls_mask]
# 执行 NMS
indices = cv2.dnn.NMSBoxes(
cls_boxes.tolist(),
cls_scores.tolist(),
conf_thres,
iou_thres
)
if len(indices) > 0:
indices = indices.flatten()
for idx in indices:
box = cls_boxes[idx]
score = cls_scores[idx]
nms_boxes.append([*box, score, cls_idx])
return np.array(nms_boxes) if nms_boxes else np.zeros((0, 6))
```
### 3. **修改坐标转换函数**
```python
def xywh2xyxy(self, origin_h, origin_w, x):
"""
支持批量处理的坐标转换
x: [N, 4] 格式的数组,每行是 [center_x, center_y, w, h]
返回: [N, 4] 格式的 [x1, y1, x2, y2]
"""
if len(x.shape) == 1:
x = x[None, :] # 单个框转为批量处理
y = np.zeros_like(x)
r_w = self.input_w / origin_w
r_h = self.input_h / origin_h
if r_h > r_w:
y[:, 0] = x[:, 0]
y[:, 2] = x[:, 2]
y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
y /= r_w
else:
y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
y[:, 1] = x[:, 1]
y[:, 3] = y[:, 1] # 保持 y 坐标不变
y /= r_h
# 转换为 xyxy 格式
xyxy_boxes = np.zeros_like(y)
xyxy_boxes[:, 0] = y[:, 0] - y[:, 2] / 2 # x1
xyxy_boxes[:, 1] = y[:, 1] - y[:, 3] / 2 # y1
xyxy_boxes[:, 2] = y[:, 0] + y[:, 2] / 2 # x2
xyxy_boxes[:, 3] = y[:, 1] + y[:, 3] / 2 # y2
return xyxy_boxes
```
### 4. **确保模型输出匹配**
在导出模型时,需要确认输出层包含多类别信息:
```bash
# 导出 ONNX 模型时指定类别数
yolo mode=export model=yolov8n.pt format=onnx opset=12
```
### 5. **可视化多类别结果**
修改绘图函数以显示不同类别:
```python
# 在 infer 方法中
for j in range(len(result_boxes)):
box = result_boxes[j]
cls_id = int(result_classid[j])
label = f"{categories[cls_id]}:{result_scores[j]:.2f}"
# 为不同类别生成不同颜色
color = [int(c) for c in COLORS[cls_id % len(COLORS)]]
plot_one_box(box, batch_image_raw[i], color=color, label=label)
# 定义颜色常量
COLORS = [
(255, 0, 0), (0, 255, 0), (0, 0, 255),
(255, 255, 0), (255, 0, 255), (0, 255, 255)
]
```
### 关键修改说明:
1. **输出维度调整**:
- `DET_NUM = 5 + NUM_CLASSES` 包含坐标、置信度和类别分数
- 解决了单类别检测中固定 `DET_NUM=6` 的限制
2. **多类别 NMS**:
- 按类别分组处理,避免不同类别间的错误抑制
- 使用 `cv2.dnn.NMSBoxes` 替代原单类别实现
3. **批量坐标转换**:
- 修改 `xywh2xyxy` 支持批量处理
- 正确处理不同长宽比的图像填充
4. **模型一致性**:
- 确保导出的 ONNX/TensorRT 模型包含多类别输出
- 验证输出维度匹配 `(4 + 1 + NUM_CLASSES)`
> **注意**:如果模型包含分割(SEG_NUM)或姿态估计(POSE_NUM)任务,需要相应调整 `num_values_per_detection` 的计算[^2]。
### 性能优化建议:
1. **类别过滤优化**:
```python
# 提前过滤低置信度框,减少 NMS 计算量
conf_mask = max_conf > conf_thres * 0.5 # 宽松阈值
boxes = boxes[conf_mask]
cls_conf = cls_conf[conf_mask]
```
2. **并行处理**:
```python
# 使用多线程处理不同类别
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
results = list(executor.map(process_class, range(NUM_CLASSES)))
```
3. **INT8 量化**:
```bash
# 导出时启用 INT8 量化
trtexec --onnx=yolov8n.onnx --saveEngine=yolov8n_int8.engine --int8
```
### 验证方法:
1. 检查输出张量形状:
```python
print(output.shape) # 应为 [batch, 4+1+NUM_CLASSES, num_boxes]
```
2. 测试不同类别的检测结果
3. 验证 mAP 指标是否与原始模型一致
> 修改后,模型可以同时检测多个类别,如 "car"、"person"、"drone" 等,并在图像中用不同颜色标注[^1][^5]。
以上代码出现(style_tune) C:\Users\28996\Desktop\AI\persona_contrastive_finetuning>python Contrastive_Training_LM.py 0%| | 0/3 [00:00<?, ?it/s]You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. Traceback (most recent call last): File "C:\Users\28996\Desktop\AI\persona_contrastive_finetuning\Contrastive_Training_LM.py", line 227, in <module> trainer.train() File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 2171, in train return inner_training_loop( File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 2480, in _inner_training_loop batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 5156, in get_batch_samples batch_samples += [next(epoch_iterator)] File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\accelerate\data_loader.py", line 567, in __iter__ current_batch = next(dataloader_iter) File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\dataloader.py", line 701, in __next__ data = self._next_data() File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\dataloader.py", line 757, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\_utils\fetch.py", line 55, in fetch return self.collate_fn(data) File "C:\Users\28996\Desktop\AI\persona_contrastive_finetuning\Contrastive_Training_LM.py", line 38, in __call__ batch_anchor = self.tokenizer.pad( File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\tokenization_utils_base.py", line 3337, in pad raise ValueError( ValueError: type of 你如何看待气候变化? unknown: <class 'str'>. Should be one of a python, numpy, pytorch or tensorflow object. 0%| | 0/3 [00:00<?, ?it/s] 请分析解决
<think>我们遇到了一个新的错误:在调用tokenizer.pad方法时,传入的features中的某个元素是字符串类型,而不是整数列表。
根据错误信息,我们的数据集中可能有样本的字段是字符串而不是整数列表。我们需要确保在数据集中,每个样本的anchor_input_ids、positive_input_ids和negative_input_ids都是整数列表(List[int])。
可能的原因:
1. 在预处理数据集时,将token id列表保存为字符串(例如,用JSON保存时没有正确处理)。
2. 在加载数据集后,没有将字符串转换回整数列表。
解决方案:
在将数据传递给ContrastiveDataCollator之前,需要确保每个样本的三个字段都是整数列表。
我们可以修改数据集的格式,在加载后转换每个样本。具体来说,如果某个字段是字符串,我们需要将其解析为整数列表。
例如,假设我们的JSON文件中的样本是这样的:
{"anchor_input_ids": "[1, 2, 3]", ...} # 字符串形式
但我们期望的是:
{"anchor_input_ids": [1, 2, 3], ...} # 整数列表
因此,我们需要在加载数据集后,遍历每个样本,将字符串形式的token id列表转换为整数列表。
但是,也有可能我们的数据集已经是整数列表,那么就不需要转换。所以我们可以添加一个检查,如果遇到字符串,就尝试解析它。
然而,更可靠的方法是在数据预处理阶段确保保存为整数列表,或者在加载后转换。
由于我们使用`load_dataset`加载,我们可以使用`map`函数来转换整个数据集。
修改主程序部分,在加载数据集后,添加一个转换函数,将每个字段从字符串转换为整数列表(如果它们是字符串的话)。
注意:我们不知道具体是哪个字段出现了字符串,所以需要对三个字段都进行检查。
步骤:
1. 定义一个转换函数,遍历每个样本,将每个字段(如果它是字符串)转换为整数列表。
2. 使用数据集的map方法应用这个转换函数。
但是,我们也可以只对出现问题的样本进行处理。为了安全,我们处理所有样本。
修改后的代码(在主程序中加载数据集后):
```python
# 加载数据集
train_dataset = load_dataset('json', data_files='data/processed/train_style_triplets.json')['train']
val_dataset = load_dataset('json', data_files='data/processed/val_style_triplets.json')['train']
# 转换函数:将字符串形式的token id列表转换为整数列表
def convert_string_to_list(example):
# 定义要处理的字段
fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids']
for field in fields:
if field in example and isinstance(example[field], str):
# 去除字符串中的方括号,然后分割成数字字符串,再转换为整数
# 注意:这里假设字符串格式为 "[1, 2, 3]"
example[field] = [int(x) for x in example[field].strip('[]').split(', ') if x]
return example
# 应用转换函数
train_dataset = train_dataset.map(convert_string_to_list)
val_dataset = val_dataset.map(convert_string_to_list)
```
但是,注意:如果我们的数据集已经是整数列表,那么map操作不会改变它们,因为isinstance(example[field], str)会返回False。
然而,还有一种可能:在JSON中保存的是字符串,但load_dataset解析后可能是整数列表,也可能是字符串?这取决于JSON文件中的具体格式。
如果JSON文件中是:
"anchor_input_ids": [1,2,3] -> 解析为整数列表
"anchor_input_ids": "[1,2,3]" -> 解析为字符串
所以,我们需要确保JSON文件中的格式正确。如果已经是整数列表,那么就不需要转换。但为了处理可能存在的字符串,我们加上这个转换。
另外,我们也可以使用更健壮的方法,比如使用ast.literal_eval来安全地转换字符串表示的列表。
修改转换函数:
```python
import ast
def convert_string_to_list(example):
fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids']
for field in fields:
if field in example:
if isinstance(example[field], str):
try:
# 使用ast.literal_eval安全地转换字符串
example[field] = ast.literal_eval(example[field])
except:
logger.error(f"Error converting {field} for example: {example[field]}")
raise
# 确保转换后是整数列表
if isinstance(example[field], list) and all(isinstance(x, int) for x in example[field]):
pass # 正确
else:
logger.warning(f"Field {field} is not a list of integers: {example[field]}")
return example
```
但是,注意:我们可能不需要这么复杂,因为我们的数据集应该是我们预处理过的,所以我们可以先检查一下数据集的样本。
我们可以在转换前打印一个样本看看。
因此,修改后的主程序部分:
```python
# 加载数据集
train_dataset = load_dataset('json', data_files='data/processed/train_style_triplets.json')['train']
val_dataset = load_dataset('json', data_files='data/processed/val_style_triplets.json')['train']
# 打印一个样本以检查
print("训练集第一个样本(转换前):", train_dataset[0])
# 转换函数
def convert_string_to_list(example):
fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids']
for field in fields:
if field in example and isinstance(example[field], str):
try:
example[field] = ast.literal_eval(example[field])
except Exception as e:
print(f"Error converting {field}: {example[field]}, error: {e}")
raise
return example
import ast
train_dataset = train_dataset.map(convert_string_to_list)
val_dataset = val_dataset.map(convert_string_to_list)
print("训练集第一个样本(转换后):", train_dataset[0])
```
这样,我们可以确保输入到数据收集器的是整数列表。
如果问题仍然存在,我们可能需要检查数据收集器中的代码。但是错误信息明确指出了类型是字符串,所以问题应该出在数据集上。
另外,注意:如果我们的数据集很大,这个map操作可能会比较慢,但这是必要的。
现在,我们修改代码,加入这个转换步骤。
完整代码(只修改主程序的数据集加载部分):</think>### 问题分析
错误信息表明在数据预处理阶段,`tokenizer.pad()` 方法遇到了字符串类型的数据,而它期望的是整数列表(token IDs)。核心问题在于:
1. **数据类型不匹配**:数据集中的输入ID字段(如 `anchor_input_ids`)被错误地存储为字符串(如 `"[1, 2, 3]"`)而不是整数列表
2. **JSON解析问题**:当使用 `load_dataset` 加载JSON文件时,如果token IDs被存储为字符串,它们不会被自动转换为整数列表
3. **错误根源**:在数据预处理阶段,token IDs应该保存为整数列表,但实际被保存为JSON字符串
### 解决方案
需要修改两个地方:
1. **修复数据集加载**:将字符串格式的token IDs转换为整数列表
2. **优化数据收集器**:添加类型检查和转换逻辑
以下是修复后的完整代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
PreTrainedTokenizerBase
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from datasets import load_dataset, Dataset
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
from dataclasses import dataclass
import json
import ast
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class ContrastiveDataCollator:
"""数据收集器,用于处理对比学习的三元组数据"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
# 类型检查和转换:确保输入ID是整数列表
def convert_to_int_list(input_ids):
"""将字符串格式的token IDs转换为整数列表"""
if isinstance(input_ids, str):
try:
# 尝试解析JSON格式的字符串
return json.loads(input_ids)
except json.JSONDecodeError:
try:
# 尝试解析Python列表格式的字符串
return ast.literal_eval(input_ids)
except:
logger.error(f"无法解析token IDs: {input_ids}")
return []
return input_ids
# 预处理每个特征
processed_features = []
for f in features:
processed_f = {
"anchor_input_ids": convert_to_int_list(f["anchor_input_ids"]),
"positive_input_ids": convert_to_int_list(f["positive_input_ids"]),
"negative_input_ids": convert_to_int_list(f["negative_input_ids"])
}
processed_features.append(processed_f)
# 分离出三元组的各个部分
anchor_features = [{"input_ids": f["anchor_input_ids"]} for f in processed_features]
positive_features = [{"input_ids": f["positive_input_ids"]} for f in processed_features]
negative_features = [{"input_ids": f["negative_input_ids"]} for f in processed_features]
# 对每个部分分别进行填充
batch_anchor = self.tokenizer.pad(
anchor_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch_positive = self.tokenizer.pad(
positive_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch_negative = self.tokenizer.pad(
negative_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
# 创建注意力掩码
def create_attention_mask(input_ids):
return (input_ids != self.tokenizer.pad_token_id).int()
return {
"anchor_input_ids": batch_anchor["input_ids"],
"anchor_attention_mask": create_attention_mask(batch_anchor["input_ids"]),
"positive_input_ids": batch_positive["input_ids"],
"positive_attention_mask": create_to_attention_mask(batch_positive["input_ids"]),
"negative_input_ids": batch_negative["input_ids"],
"negative_attention_mask": create_attention_mask(batch_negative["input_ids"]),
}
class ContrastiveTrainer(Trainer):
"""自定义训练器,实现对比学习损失"""
def __init__(self, *args, contrastive_config=None, **kwargs):
# 修复:使用新的参数名代替将被弃用的tokenizer参数
if "tokenizer" in kwargs:
kwargs["data_collator"] = kwargs.pop("tokenizer")
super().__init__(*args, **kwargs)
if contrastive_config is None:
contrastive_config = {}
# 设置默认值并验证配置
self.temperature = contrastive_config.get("temperature", 0.07)
self.margin = contrastive_config.get("margin", 0.3)
self.contrastive_weight = contrastive_config.get("weight", 0.8)
self.repr_layer = contrastive_config.get("repr_layer", -1)
# 验证必要参数
if not hasattr(self.model.config, "output_hidden_states") or not self.model.config.output_hidden_states:
raise ValueError("模型必须设置output_hidden_states=True以获取隐藏状态")
self.cross_entropy = nn.CrossEntropyLoss()
def compute_contrastive_loss(self, anchor_emb, pos_emb, neg_emb):
"""计算对比损失:结合InfoNCE和三元组损失"""
# 计算余弦相似度
pos_sim = F.cosine_similarity(anchor_emb, pos_emb)
neg_sim = F.cosine_similarity(anchor_emb, neg_emb)
# InfoNCE损失
numerator = torch.exp(pos_sim / self.temperature)
denominator = numerator + torch.exp(neg_sim / self.temperature)
info_nce_loss = -torch.log(numerator / (denominator + 1e-8)).mean()
# 三元组损失
triplet_loss = F.relu(neg_sim - pos_sim + self.margin).mean()
# 加权组合
return info_nce_loss + triplet_loss
def get_sequence_representation(self, outputs, attention_mask):
"""获取序列表示(取最后一个非填充token的隐藏状态)"""
# 获取指定层的隐藏状态
hidden_states = outputs.hidden_states[self.repr_layer]
# 获取每个序列的最后一个非填充token
seq_lengths = attention_mask.sum(dim=1) - 1
batch_indices = torch.arange(hidden_states.size(0))
return hidden_states[batch_indices, seq_lengths]
def compute_loss(self, model, inputs, return_outputs=False):
"""计算总损失:语言模型损失 + 对比损失"""
# 提取输入
anchor_ids = inputs["anchor_input_ids"]
anchor_mask = inputs["anchor_attention_mask"]
positive_ids = inputs["positive_input_ids"]
positive_mask = inputs["positive_attention_mask"]
negative_ids = inputs["negative_input_ids"]
negative_mask = inputs["negative_attention_mask"]
# 前向传播获取隐藏状态
def get_embeddings(input_ids, attention_mask):
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True
)
return self.get_sequence_representation(outputs, attention_mask)
# 获取三元组的嵌入表示
anchor_emb = get_embeddings(anchor_ids, anchor_mask)
pos_emb = get_embeddings(positive_ids, positive_mask)
neg_emb = get_embeddings(negative_ids, negative_mask)
# 计算对比损失
cl_loss = self.compute_contrastive_loss(anchor_emb, pos_emb, neg_emb)
cl_loss = cl_loss * self.contrastive_weight
# 计算语言建模损失(仅针对positive回复)
lm_labels = positive_ids.clone()
lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100
lm_outputs = model(
input_ids=positive_ids,
attention_mask=positive_mask,
labels=lm_labels
)
lm_loss = lm_outputs.loss
# 总损失 = LM损失 + 对比损失
total_loss = lm_loss + cl_loss
# 记录自定义指标
self.log({"cl_loss": cl_loss.item(), "lm_loss": lm_loss.item()})
return (total_loss, lm_outputs) if return_outputs else total_loss
# ================ 主程序 ================ #
if __name__ == "__main__":
# 加载模型和分词器
model = AutoModelForCausalLM.from_pretrained(
"model/Qwen/Qwen1.5-1.8B",
output_hidden_states=True, # 必须设置以获取隐藏状态
return_dict_in_generate=True # 修复隐藏状态警告
)
tokenizer = AutoTokenizer.from_pretrained("model/Qwen/Qwen1.5-1.8B")
tokenizer.pad_token = tokenizer.eos_token # 设置填充token
# 加载数据集并修复数据类型问题
def load_and_convert_dataset(file_path):
"""加载数据集并确保token IDs是整数列表"""
dataset = load_dataset('json', data_files=file_path)['train']
# 类型转换函数
def convert_token_ids(example):
for key in ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids']:
if key in example and isinstance(example[key], str):
try:
# 尝试解析JSON格式的字符串
example[key] = json.loads(example[key])
except json.JSONDecodeError:
try:
# 尝试解析Python列表格式的字符串
example[key] = ast.literal_eval(example[key])
except:
logger.error(f"无法解析{key}: {example[key]}")
return example
return dataset.map(convert_token_ids)
# 加载训练集和验证集
train_dataset = load_and_convert_dataset('data/processed/train_style_triplets.json')
val_dataset = load_and_convert_dataset('data/processed/val_style_triplets.json')
# 验证数据集格式
print("训练集样本示例:", train_dataset[0])
print("验证集样本示例:", val_dataset[0])
# 训练参数配置
training_args = TrainingArguments(
output_dir="./model/lora_adapter",
per_device_train_batch_size=4,
num_train_epochs=3,
learning_rate=2e-4,
logging_steps=50,
save_steps=500,
fp16=True,
gradient_accumulation_steps=4,
report_to="none", # 禁用默认日志
remove_unused_columns=False, # 关键修复:保留自定义列
)
# 对比学习配置
contrastive_config = {
"temperature": 0.07, # 温度参数
"margin": 0.3, # 三元组损失边界
"weight": 0.8, # 对比损失权重
"repr_layer": -1 # 使用最后一层隐藏状态
}
# 初始化数据收集器
data_collator = ContrastiveDataCollator(
tokenizer=tokenizer,
max_length=512,
padding="max_length"
)
# 初始化训练器
trainer = ContrastiveTrainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=val_dataset,
contrastive_config=contrastive_config
)
# 开始训练
trainer.train()
# 评估模型
eval_results = trainer.evaluate()
print("评估结果:", eval_results)
```
### 关键修复点解释
1. **数据加载修复**:
```python
def load_and_convert_dataset(file_path):
dataset = load_dataset('json', data_files=file_path)['train']
def convert_token_ids(example):
for key in ['anchor_input_ids', ...]:
if isinstance(example[key], str):
try:
example[key] = json.loads(example[key])
except:
example[key] = ast.literal_eval(example[key])
return dataset.map(convert_token_ids)
```
- 自动检测字符串格式的token IDs
- 尝试两种解析方法:JSON格式和Python列表格式
- 确保最终token IDs是整数列表
2. **数据收集器增强**:
```python
def convert_to_int_list(input_ids):
if isinstance(input_ids, str):
try:
return json.loads(input_ids)
except:
return ast.literal_eval(input_ids)
return input_ids
```
- 在数据收集器内部添加类型检查和转换
- 双重解析机制确保兼容不同格式
- 错误处理防止解析失败导致程序崩溃
3. **数据验证**:
```python
print("训练集样本示例:", train_dataset[0])
print("验证集样本示例:", val_dataset[0])
```
- 训练前打印样本验证数据格式
- 确保token IDs已正确转换为整数列表
### 预防措施
1. **数据预处理脚本**:
在创建三元组数据集时,确保正确保存为整数列表:
```python
# 正确保存token IDs
with open('data.json', 'w') as f:
json.dump({
"anchor_input_ids": [1, 2, 3], # 整数列表
"positive_input_ids": [4, 5, 6],
"negative_input_ids": [7, 8, 9]
}, f)
# 错误方式(会导致问题):
json.dump({
"anchor_input_ids": str([1, 2, 3]) # 字符串
}, f)
```
2. **数据集检查工具**:
```python
def check_dataset(dataset):
for i in range(min(5, len(dataset))):
sample = dataset[i]
for key in ['anchor_input_ids', ...]:
if not isinstance(sample[key], list):
raise TypeError(f"样本{i}的{key}不是列表")
if not all(isinstance(x, int) for x in sample[key]):
raise TypeError(f"样本{i}的{key}包含非整数元素")
```
3. **使用Parquet格式**:
考虑使用Parquet格式替代JSON,它能更好地保留数据类型:
```python
# 保存
dataset.to_parquet('data.parquet')
# 加载
dataset = load_dataset('parquet', data_files='data.parquet')
```
阅读全文
相关推荐















