活动介绍

class MultiDataLoader(object): def __init__(self, dataloaders, n_batches): if n_batches <= 0: raise ValueError("n_batches should be > 0") self._dataloaders = dataloaders self._n_batches = np.maximum(1, n_batches) self._init_iterators()什么意思

时间: 2023-05-11 16:01:54 浏览: 173
这段代码定义了一个名为 MultiDataLoader 的类,它的构造函数接受两个参数:一个 dataloaders 列表和一个 n_batches 数字。如果 n_batches 小于等于 0,将会抛出 ValueError 异常。该类还有一个名为 _init_iterators 的私有方法。
相关问题

CONF_THRESH = 0.5 IOU_THRESHOLD = 0.4 POSE_NUM = 17 * 3 DET_NUM = 6 SEG_NUM = 32 OBB_NUM = 1 def get_img_path_batches(batch_size, img_dir): ret = [] batch = [] for root, dirs, files in os.walk(img_dir): for name in files: if len(batch) == batch_size: ret.append(batch) batch = [] batch.append(os.path.join(root, name)) if len(batch) > 0: ret.append(batch) return ret def plot_one_box(x, img, color=None, label=None, line_thickness=None): """ description: Plots one bounding box on image img, this function comes from YoLov8 project. param: x: a box likes [x1,y1,x2,y2] img: a opencv image object color: color to draw rectangle, such as (0,255,0) label: str line_thickness: int return: no return """ tl = ( line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 ) # line/font thickness color = color or [random.randint(0, 255) for _ in range(3)] c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) if label: tf = max(tl - 1, 1) # font thickness t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled cv2.putText( img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA, ) class YoLov8TRT(object): """ description: A YOLOv8 class that warps TensorRT ops, preprocess and postprocess ops. """ def __init__(self, engine_file_path): # Create a Context on this device, self.ctx = cuda.Device(0).make_context() stream = cuda.Stream() TRT_LOGGER = trt.Logger(trt.Logger.INFO) runtime = trt.Runtime(TRT_LOGGER) # Deserialize the engine from file with open(engine_file_path, "rb") as f: engine = runtime.deserialize_cuda_engine(f.read()) context = engine.create_execution_context() host_inputs = [] cuda_inputs = [] host_outputs = [] cuda_outputs = [] bindings = [] for binding in engine: print('bingding:', binding, engine.get_binding_shape(binding)) size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size dtype = trt.nptype(engine.get_binding_dtype(binding)) # Allocate host and device buffers host_mem = cuda.pagelocked_empty(size, dtype) cuda_mem = cuda.mem_alloc(host_mem.nbytes) # Append the device buffer to device bindings. bindings.append(int(cuda_mem)) # Append to the appropriate list. if engine.binding_is_input(binding): self.input_w = engine.get_binding_shape(binding)[-1] self.input_h = engine.get_binding_shape(binding)[-2] host_inputs.append(host_mem) cuda_inputs.append(cuda_mem) else: host_outputs.append(host_mem) cuda_outputs.append(cuda_mem) # Store self.stream = stream self.context = context self.engine = engine self.host_inputs = host_inputs self.cuda_inputs = cuda_inputs self.host_outputs = host_outputs self.cuda_outputs = cuda_outputs self.bindings = bindings self.batch_size = engine.max_batch_size self.det_output_length = host_outputs[0].shape[0] def infer(self, raw_image_generator): threading.Thread.__init__(self) # Make self the active context, pushing it on top of the context stack. self.ctx.push() # Restore stream = self.stream context = self.context host_inputs = self.host_inputs cuda_inputs = self.cuda_inputs host_outputs = self.host_outputs cuda_outputs = self.cuda_outputs bindings = self.bindings # Do image preprocess batch_image_raw = [] batch_origin_h = [] batch_origin_w = [] batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w]) for i, image_raw in enumerate(raw_image_generator): input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw) batch_image_raw.append(image_raw) batch_origin_h.append(origin_h) batch_origin_w.append(origin_w) np.copyto(batch_input_image[i], input_image) batch_input_image = np.ascontiguousarray(batch_input_image) # Copy input image to host buffer np.copyto(host_inputs[0], batch_input_image.ravel()) start = time.time() # Transfer input data to the GPU. cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream) # Run inference. context.execute_async(batch_size=self.batch_size, bindings=bindings, stream_handle=stream.handle) # Transfer predictions back from the GPU. cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream) # Synchronize the stream stream.synchronize() end = time.time() # Remove any context from the top of the context stack, deactivating it. self.ctx.pop() # Here we use the first row of output in that batch_size = 1 output = host_outputs[0] # Do postprocess for i in range(self.batch_size): result_boxes, result_scores, result_classid = self.post_process( output[i * self.det_output_length: (i + 1) * self.det_output_length], batch_origin_h[i], batch_origin_w[i] ) # Draw rectangles and labels on the original image for j in range(len(result_boxes)): box = result_boxes[j] plot_one_box( box, batch_image_raw[i], label="{}:{:.2f}".format( categories[int(result_classid[j])], result_scores[j] ), ) return batch_image_raw, end - start def destroy(self): # Remove any context from the top of the context stack, deactivating it. self.ctx.pop() def get_raw_image(self, image_path_batch): """ description: Read an image from image path """ for img_path in image_path_batch: yield cv2.imread(img_path) def get_raw_image_zeros(self, image_path_batch=None): """ description: Ready data for warmup """ for _ in range(self.batch_size): yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8) def preprocess_image(self, raw_bgr_image): """ description: Convert BGR image to RGB, resize and pad it to target size, normalize to [0,1], transform to NCHW format. param: input_image_path: str, image path return: image: the processed image image_raw: the original image h: original height w: original width """

<think>我们正在使用一个基于TensorRT加速的YOLOv8推理代码,目前代码只支持检测一个类别(drone)。现在需要扩展为多类别检测。 从代码中可以看出,当前代码中定义了几个常量: POSE_NUM = 17 * 3 # 关键点检测的数量(每个点3个值) DET_NUM = 6 # 每个检测框的原始输出数量(x, y, w, h, conf, class_id)? 注意:实际上YOLOv8的输出格式可能不同 SEG_NUM = 32 # 分割掩码的参数数量 OBB_NUM = 1 # 定向边界框(OBB)的参数,这里可能是1个角度? 在`post_process`函数中,我们注意到: num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM 这个值表示每个检测结果的参数总数。 然而,在标准的YOLOv8目标检测中(不带分割和姿态估计),每个检测框的输出应该是:4个坐标(x,y,w,h) + 1个置信度 + (类别数)个类别分数。 但是,在代码中,DET_NUM被设置为6,这可能是用于一个类别的情况(4坐标+1置信度+1类别分数)?但多类别情况下,类别分数应该是多个。 实际上,YOLOv8的TensorRT输出格式取决于模型。在导出模型时,如果是检测模型,输出维度为[1, 84, 8400](84=4+80,80是COCO类别数)。 但是,我们当前代码中,每个检测框的参数数量是固定的(由DET_NUM+SEG_NUM+...等构成)。这暗示我们使用的可能是YOLOv8的集成输出(同时包含检测、分割、姿态等)? 然而,用户的问题中只提到目标检测,并且当前代码中只处理了检测框(没有分割和姿态)。因此,我们需要明确: 1. 我们使用的模型是什么?是纯检测模型,还是带有其他任务(如分割、姿态)的模型? 2. 当前代码中,DET_NUM=6,这显然只适用于单类别(4坐标+1置信度+1类别分数)或者多类别但只取最大类别分数的情况? 观察代码中的`post_process`函数: pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :] 然后,在`non_max_suppression`函数中,我们只使用了前6个值(x,y,w,h,conf,class_id)? 因为boxes = prediction[prediction[:, 4] >= conf_thres] 中,第4个是置信度。 实际上,在代码中,每个检测框的参数包括: [x, y, w, h, conf, class_id, ... (其他参数,如分割、姿态等)] 因此,对于多类别检测,我们需要改变的是: - 模型输出中,每个检测框的类别分数部分应该是多个分数(每个类别一个分数),而不是一个类别id。 但是,在代码中,我们注意到在`post_process`函数中,我们只取每个检测框的前DET_NUM(6)个参数。这6个参数中,第5个(索引4)是置信度,第6个(索引5)是类别id吗? 然而,在标准的多类别YOLOv8中,每个检测框的输出应该是: [x, y, w, h, obj_conf, class_conf_0, class_conf_1, ... , class_conf_n] 然后,总的置信度 = obj_conf * max(class_conf) ?而类别id是argmax(class_conf)。 因此,当前代码中DET_NUM=6,只能处理5个类别(因为4坐标+1个obj_conf+5个类别分数=10,但这里DET_NUM=6)?这显然不对。 重新审视:在代码中,我们定义: DET_NUM = 6 但实际在YOLOv8的TensorRT输出中,每个检测框的向量长度应该是:4(坐标)+ 1(obj_conf) + 类别数(class_conf)。因此,DET_NUM应该等于5+类别数。 那么,为什么这里DET_NUM=6?因为当前模型是单类别(drone)?所以类别数=1,那么总长度=4+1+1=6。 所以,要支持多类别,我们需要: 1. 确保模型是多类别模型(在训练时指定了多个类别)。 2. 修改代码中的DET_NUM,使其等于5+类别数(因为4坐标+1个obj_conf+类别数个类别分数)。 3. 修改后处理部分,正确解析类别分数。 但是,在代码中,我们注意到: num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM 因此,如果我们修改了DET_NUM,那么整个向量长度就变了。但是,在代码中,我们只使用了每个检测框的前DET_NUM个值(在`non_max_suppression`函数中,我们只取prediction[:, :DET_NUM]?并不是,我们直接取了整个prediction,但在非极大值抑制中只用了前6个值(通过boxes = prediction[prediction[:,4]>=conf_thres]))。 实际上,在`non_max_suppression`函数中,我们传入的是整个预测数组(每个检测框有num_values_per_detection个值),但在函数内部,我们只使用了前6个(即x,y,w,h,conf,class_id)?但是,在单类别情况下,第5个(索引4)是置信度,第6个(索引5)是类别id(只有0,因为只有一个类别)。 对于多类别,我们需要改变的是:在`non_max_suppression`函数中,我们不应该直接使用第5个(索引5)作为类别id,而应该从第5个位置开始(索引4是obj_conf,然后后面是类别分数)计算每个类别的分数,然后取最大值作为类别置信度,并得到类别id。 因此,我们需要重构后处理部分。 步骤: 1. 修改DET_NUM:根据类别数,DET_NUM = 5 + num_classes(因为4坐标+1个obj_conf+num_classes个类别分数)。 但是,注意:在模型导出时,这个值已经固定了。所以,我们应该根据实际模型来设置这个值。 2. 在`post_process`函数中,我们得到每个检测框的向量(长度为num_values_per_detection),然后我们只取前DET_NUM个值(即检测部分),然后我们需要解析这个向量,得到每个类别的分数。 3. 修改`non_max_suppression`函数,使其能够处理多类别。 具体修改如下: 首先,我们需要知道类别数。可以在代码开头定义一个全局变量: NUM_CLASSES = 80 # 例如,COCO有80个类别 然后,修改: DET_NUM = 5 + NUM_CLASSES # 4坐标 + 1个obj_conf + NUM_CLASSES个类别分数 注意:在YOLOv8的原始输出中,每个检测框的格式是: [x, y, w, h, obj_conf, class_conf_0, class_conf_1, ... , class_conf_{NUM_CLASSES-1}] 然后,在`non_max_suppression`函数中,我们需要做如下修改: - 首先,从每个检测框的向量中取出前DET_NUM个值,然后分离出坐标(前4个)、obj_conf(第5个)和类别分数(后面的NUM_CLASSES个)。 - 计算每个类别的置信度:conf = obj_conf * class_conf_i - 然后,我们保留那些最大类别置信度超过阈值的框,并记录其类别id和置信度。 但是,当前代码中,我们直接使用了prediction[:,4]作为置信度,这实际上只是obj_conf(没有乘以类别分数)。而且,prediction[:,5]作为类别id,这只有在单类别情况下正确(因为只有一个类别分数,所以类别id总是0)。 因此,我们需要重写`non_max_suppression`函数,使其能够处理多类别。 另外,注意在原始代码中,我们传入`non_max_suppression`函数的prediction参数是每个检测框的完整向量(包括检测、分割、姿态等),但我们只使用前DET_NUM个值(即检测部分)。所以,在函数内部,我们只处理检测部分。 修改后的`non_max_suppression`函数(多类别版本): 步骤: 1. 提取每个检测框的检测部分(前DET_NUM个值) 2. 分离坐标(0:4)、obj_conf(4:5)、类别分数(5:5+NUM_CLASSES) 3. 计算每个类别的置信度 = obj_conf * 类别分数(每个类别分别乘) 4. 找到每个框的最大置信度以及对应的类别id 5. 根据最大置信度进行初步筛选(大于conf_thres) 6. 然后,对每个类别分别进行NMS(或者使用多类别NMS) 但是,为了简单和效率,我们通常采用: - 先对每个类别进行循环,然后分别做NMS,最后合并。 - 或者,使用单次多类别NMS(如cv2.dnn.NMSBoxes支持多类别)。 然而,由于我们使用的是自己的后处理,我们可以采用以下方法: 1. 先创建一个空列表,用于存放所有保留的框。 2. 遍历每个类别: - 取出当前类别的置信度 - 筛选出大于阈值的框 - 对这些框进行NMS(使用同一类别内的框) - 将保留的框加入列表 3. 或者,我们也可以使用另一种常见做法:先找出每个框最大分数对应的类别,然后按照这个类别分组,再对每个类别做NMS。 但是,这样效率较低。另一种高效做法是:先选出每个框的最大分数,然后按照这个最大分数进行排序,然后进行跨类别的NMS(即不同类别的框也参与抑制,但通常不同类别的框不互相抑制?)—— 实际上,不同类别的框即使重叠也不应该被抑制,所以应该按类别分别进行NMS。 然而,在目标检测中,通常NMS是每个类别独立进行的。因此,我们需要按类别处理。 但是,由于我们的检测框数量可能很大,且类别较多,循环处理可能较慢。我们可以使用矩阵操作来避免循环,但这里为了清晰,我们使用循环每个类别的方式。 修改`non_max_suppression`函数: 注意:这个函数现在需要额外的参数`num_classes`,并且需要知道DET_NUM(即检测框向量的长度)?实际上,我们已经在函数内部使用了全局的NUM_CLASSES和DET_NUM。 但是,为了减少全局变量,我们可以将类别数作为参数传入。 由于修改较大,我们重新设计`non_max_suppression`函数: 输入: prediction: 一个数组,形状为[N, DET_NUM+SEG_NUM+...],但我们现在只关心前DET_NUM个值。 origin_h, origin_w: 原始图像尺寸 conf_thres: 置信度阈值 iou_thres: IOU阈值 num_classes: 类别数 输出: 一个数组,每行是[x1, y1, x2, y2, conf, class_id] 步骤: 1. 提取检测部分(前DET_NUM列) -> dets: [N, DET_NUM] 2. 分离:boxes = dets[:, :4] # [N,4] obj_conf = dets[:, 4] # [N] cls_scores = dets[:, 5:5+num_classes] # [N, num_classes] 3. 计算类别置信度:class_conf = obj_conf.reshape(-1,1) * cls_scores [N, num_classes] 4. 找到每个框的最大类别置信度:max_conf = np.max(class_conf, axis=1) [N] class_id = np.argmax(class_conf, axis=1) [N] 5. 初步筛选:保留max_conf > conf_thres的框索引 mask = max_conf > conf_thres boxes = boxes[mask] class_id = class_id[mask] max_conf = max_conf[mask] 6. 将坐标转换到原始图像(使用xywh2xyxy函数,但注意这个函数原本是针对单框向量设计的,现在boxes是多个框) 7. 对每个类别进行循环,进行NMS: 总保留框列表:keep_boxes = [] 对每个类别id c: 从当前类别的框中,选出所有类别id等于c的框(注意:我们上面已经按最大类别筛选,所以每个框的类别就是class_id) 然后,对当前类别的框进行NMS(使用IOU阈值) 将保留的框加入keep_boxes 8. 如果keep_boxes不为空,则合并所有类别的框,然后返回。 但是,这样实现可能会比较慢。另一种方法是:先按类别分组,然后对每个组分别进行NMS,最后合并。 然而,在原始代码中,我们只进行了单类别的NMS(因为当时只有单类别)。现在需要支持多类别,所以必须按类别分组处理。 修改后的`non_max_suppression`函数(多类别版本)如下: 注意:由于原始代码中,这个函数是在类的方法中,所以我们需要修改为类的方法。 但是,由于代码较长,我们只给出关键部分。 然而,考虑到时间,我们也可以采用另一种方式:使用OpenCV的NMSBoxes,它支持多类别。但注意,OpenCV的NMSBoxes需要每个框的坐标和置信度,并且需要按类别提供。 我们可以这样做: 1. 将每个框的坐标转换为xyxy 2. 创建一个列表,存放所有框的坐标(xyxy)和对应的置信度、类别id 3. 调用cv2.dnn.NMSBoxes进行多类别NMS(注意:它要求传入所有框的坐标和每个框的得分,以及类别信息?) 但是,cv2.dnn.NMSBoxes只支持单类别NMS(即每个框只有一个分数),它不能直接处理多类别。所以,我们需要将每个框按照其最大类别分数作为该框的分数,然后使用NMSBoxes进行抑制(这样会抑制不同类别但重叠的框吗?不会,因为NMSBoxes在抑制时并不考虑类别,它只根据框的坐标和分数来抑制。所以,不同类别的框即使重叠也不会被抑制。因此,我们可以将所有框放在一起,然后调用一次NMSBoxes。 具体步骤: 1. 将每个框的坐标转换为xyxy(在原始图像坐标上) 2. 计算每个框的最大类别分数(已经在上面的步骤4中计算) 3. 将所有框(跨类别)的xyxy坐标、最大类别分数、类别id收集起来 4. 使用NMSBoxes进行非极大值抑制(指定score_threshold=conf_thres, nms_threshold=iou_thres) 这样,NMSBoxes会返回保留框的索引。 然后,我们根据这些索引,取出保留的框、分数和类别id。 因此,我们可以这样修改: 在`non_max_suppression`函数中: ... 步骤1-4(得到boxes, max_conf, class_id)... # 步骤5:筛选 mask = max_conf > conf_thres boxes = boxes[mask] max_conf = max_conf[mask] class_id = class_id[mask] # 步骤6:将boxes从xywh转换到xyxy(注意:这里boxes是[N,4]格式,每行是[x,y,w,h]) # 注意:原始代码中有一个xywh2xyxy函数,但它处理的是单个框?实际上,我们修改为可以处理多个框。 # 修改xywh2xyxy函数,使其支持批量处理(现在输入是[N,4]) # 由于原始xywh2xyxy函数是针对一个框(一行)的,现在我们需要一个批量版本 # 定义一个新函数:xywh2xyxy_batch(batch_boxes) # 或者,我们直接在函数内部转换: # x1 = boxes[:,0] - boxes[:,2]/2 # y1 = boxes[:,1] - boxes[:,3]/2 # x2 = boxes[:,0] + boxes[:,2]/2 # y2 = boxes[:,1] + boxes[:,3]/2 # 但是注意:原始代码中xywh2xyxy函数考虑了预处理时的填充和缩放,所以我们需要用原始图像的尺寸进行转换。因此,我们仍然需要使用类中的xywh2xyxy函数,但需要修改为批量处理。 因此,我们修改类中的`xywh2xyxy`函数,使其支持批量处理。 原函数: def xywh2xyxy(self, origin_h, origin_w, x): # x: [center_x, center_y, w, h] ... 修改为: def xywh2xyxy(self, origin_h, origin_w, x): """ x: 一个框或者多个框,形状为[N,4](每行是center_x, center_y, w, h) 返回:转换后的xyxy坐标,形状[N,4] """ # 复制原始方法,但改为批量计算 y = np.zeros_like(x) r_w = self.input_w / origin_w r_h = self.input_h / origin_h if r_h > r_w: # 批量处理 y[:, 0] = x[:, 0] # center_x y[:, 2] = x[:, 2] # w y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2 y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2 y /= r_w else: y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2 y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2 y[:, 1] = x[:, 1] y[:, 3] = x[:, 3] y /= r_h # 将center_x, center_y, w, h 转换为 x1,y1,x2,y2 # 注意:y现在每行是[center_x, center_y, w, h](已经转换到原始图像尺寸) # 然后,我们转换为xyxy x1 = y[:, 0] - y[:, 2] / 2 y1 = y[:, 1] - y[:, 3] / 2 x2 = y[:, 0] + y[:, 2] / 2 y2 = y[:, 1] + y[:, 3] / 2 return np.stack([x1, y1, x2, y2], axis=1) 但是,这样修改会破坏原来的单框处理?所以,我们重新写一个批量处理的函数,或者修改原函数使其同时支持单个框和多个框。 为了简单,我们直接修改为批量处理。 注意:原函数中x的输入是[4]或[N,4],输出对应。 修改后的`xywh2xyxy`函数: def xywh2xyxy(self, origin_h, origin_w, x): if len(x.shape) == 1: x = x[None, :] # 变成[1,4] y = np.zeros_like(x) r_w = self.input_w / origin_w r_h = self.input_h / origin_h if r_h > r_w: y[:, 0] = x[:, 0] y[:, 2] = x[:, 2] y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2 y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2 y /= r_w else: y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2 y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2 y[:, 1] = x[:, 1] y[:, 3] = x[:, 3] y /= r_h # 转换为xyxy xyxy_boxes = np.zeros_like(y) xyxy_boxes[:, 0] = y[:, 0] - y[:, 2] / 2 # x1 = center_x - w/2 xyxy_boxes[:, 1] = y[:, 1] - y[:, 3] / 2 # y1 xyxy_boxes[:, 2] = y[:, 0] + y[:, 2] / 2 # x2 xyxy_boxes[:, 3] = y[:, 1] + y[:, 3] / 2 # y2 return xyxy_boxes 这样,我们就可以批量处理了。 回到`non_max_suppression`函数: # 转换坐标 xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4] 每行是x1,y1,x2,y2 # 现在,xyxy_boxes是原始图像坐标下的边界框 # 然后,我们使用OpenCV的NMSBoxes indices = cv2.dnn.NMSBoxes( bboxes=xyxy_boxes.tolist(), scores=max_conf.tolist(), score_threshold=conf_thres, nms_threshold=iou_thres, eta=1.0, top_k=200 ) # indices是一个一维数组,表示保留的框的索引 if len(indices) > 0: indices = indices.flatten() result_boxes = xyxy_boxes[indices] result_scores = max_conf[indices] result_classid = class_id[indices] else: result_boxes = np.array([]) result_scores = np.array([]) result_classid = np.array([]) # 合并结果,返回一个数组,每行是[x1,y1,x2,y2, conf, class_id] if len(result_boxes) > 0: result = np.concatenate([ result_boxes, result_scores[:, None], result_classid[:, None] ], axis=1) else: result = np.zeros((0,6)) return result 这样,我们就完成了多类别的NMS。 但是,注意:在原始代码中,`non_max_suppression`函数返回的是三个独立的数组(result_boxes, result_scores, result_classid)。现在,为了保持接口一致,我们仍然返回三个独立的数组(因为调用方是这么使用的)。 因此,我们可以这样返回: return result_boxes, result_scores, result_classid 但是,注意:在调用`non_max_suppression`的地方(在`post_process`函数中): result_boxes, result_scores, result_classid = self.non_max_suppression(...) 所以,我们修改`non_max_suppression`函数,让它返回三个独立的数组(即使没有检测框,返回空数组)。 修改后的`non_max_suppression`函数(多类别版)如下: def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4): # prediction: [N, num_values_per_detection] 其中N是检测框数量 # 我们只取前DET_NUM个值(检测部分) dets = prediction[:, :DET_NUM] # [N, DET_NUM] # 分离 boxes = dets[:, :4] # [N,4] (x,y,w,h) obj_conf = dets[:, 4] # [N] cls_scores = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES] # 计算类别置信度 class_conf = obj_conf[:, None] * cls_scores # [N, NUM_CLASSES] # 找到每个框的最大类别置信度及其类别id max_conf = np.max(class_conf, axis=1) # [N] class_id = np.argmax(class_conf, axis=1) # [N] # 初步筛选:保留大于置信度阈值的框 mask = max_conf > conf_thres boxes = boxes[mask] max_conf = max_conf[mask] class_id = class_id[mask] if len(boxes) == 0: return np.array([]), np.array([]), np.array([]) # 转换坐标到原始图像(xywh to xyxy for batch) xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4] (x1,y1,x2,y2) # 使用OpenCV的NMSBoxes # 注意:需要将坐标转换为整数?不需要,NMSBoxes支持浮点数 indices = cv2.dnn.NMSBoxes( bboxes=xyxy_boxes.tolist(), scores=max_conf.tolist(), score_threshold=conf_thres, # 这个参数在OpenCV中实际上不起作用,因为我们已经筛选过了,所以设为0 nms_threshold=iou_thres, eta=1.0, top_k=200 ) if len(indices) > 0: indices = indices.flatten() result_boxes = xyxy_boxes[indices] result_scores = max_conf[indices] result_classid = class_id[indices] else: result_boxes = np.array([]) result_scores = np.array([]) result_classid = np.array([]) return result_boxes, result_scores, result_classid 注意:在调用NMSBoxes时,score_threshold参数可以设为0,因为我们已经在前面筛选过了。或者,我们可以不提前筛选,而是让NMSBoxes来筛选,但这样需要传入所有框(包括低置信度的),效率低。所以,我们提前筛选。 但是,在NMSBoxes中,score_threshold参数是用于筛选的,所以我们可以将score_threshold设为conf_thres,然后去掉我们自己的筛选步骤。这样,我们就不需要提前筛选了。 因此,我们可以不进行提前筛选,而是将所有框(包括低于阈值的)传入,然后让NMSBoxes根据score_threshold进行筛选。 修改: # 去掉提前筛选的步骤,直接传入所有框(但注意:框的数量可能很多,所以效率问题?) # 转换坐标 xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # 这里的boxes是原始的所有框(没有经过筛选) # 然后调用NMSBoxes,设置score_threshold=conf_thres indices = cv2.dnn.NMSBoxes( bboxes=xyxy_boxes.tolist(), scores=max_conf.tolist(), # 每个框的最大类别分数 score_threshold=conf_thres, # 这里会过滤掉低于阈值的框 nms_threshold=iou_thres ) 这样,我们就不需要提前筛选了。 但是,注意:max_conf和class_id是每个框的最大类别分数和类别id,即使这个最大分数低于阈值,我们也计算了。所以,我们可以先计算,然后传入NMSBoxes,它会根据score_threshold过滤。 因此,修改后的函数: def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4): dets = prediction[:, :DET_NUM] # [N, DET_NUM] boxes = dets[:, :4] # [N,4] obj_conf = dets[:, 4] # [N] cls_scores = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES] class_conf = obj_conf[:, None] * cls_scores # [N, NUM_CLASSES] max_conf = np.max(class_conf, axis=1) # [N] class_id = np.argmax(class_conf, axis=1) # [N] # 转换坐标 xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N,4] # 使用NMSBoxes indices = cv2.dnn.NMSBoxes( xyxy_boxes.tolist(), max_conf.tolist(), conf_thres, iou_thres ) if len(indices) > 0: indices = indices.flatten() result_boxes = xyxy_boxes[indices] result_scores = max_conf[indices] result_classid = class_id[indices] else: result_boxes = np.array([]) result_scores = np.array([]) result_classid = np.array([]) return result_boxes, result_scores, result_classid 注意:在调用NMSBoxes时,我们传入的bboxes是列表(每个元素是一个框的四个坐标),scores是每个框的分数(最大类别分数),score_threshold是conf_thres,nms_threshold是iou_thres。 这样,我们就完成了多类别的支持。 最后,我们需要修改几个全局变量: NUM_CLASSES = 80 # 根据你的模型修改 DET_NUM = 5 + NUM_CLASSES # 4坐标+1个obj_conf+类别数 注意:SEG_NUM, POSE_NUM, OBB_NUM 可能需要根据模型调整。如果你的模型是纯检测模型(没有分割、姿态等),那么SEG_NUM和POSE_NUM应该为0。但是,在原始代码中,这些值被加在num_values_per_detection中。所以,如果你的模型输出中包含了这些额外的参数,那么num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM 应该等于模型输出中每个检测框的向量长度。 因此,请根据实际模型设置这些常量。 总结修改步骤: 1. 在代码开头,根据模型设置类别数:NUM_CLASSES = 80(根据你的模型修改) 2. 修改DET_NUM:DET_NUM = 5 + NUM_CLASSES 3. 修改xywh2xyxy函数,使其支持批量处理(如上述) 4. 修改non_max_suppression函数,使用上述的多类别处理方法(使用OpenCV的NMSBoxes) 5. 在post_process函数中,确保num_values_per_detection的计算正确(DET_NUM+SEG_NUM+...) 注意:如果你的模型没有分割和姿态等任务,那么SEG_NUM, POSE_NUM, OBB_NUM都应该设为0。 这样,代码就应该可以支持多类别检测了。 另外,在绘制框时,我们使用: label="{}:{:.2f}".format(categories[int(result_classid[j])], result_scores[j]) 其中,categories是类别名称列表(现在应该是多类别的列表)。所以,你需要将categories修改为包含所有类别名称的列表。 例如: categories = ["class0", "class1", ...] # 你的类别名称 在代码的最后,我们初始化categories时,只定义了["drone"],现在需要修改为所有类别。 最后,请注意:模型本身必须是多类别模型。如果模型是单类别训练的,那么即使我们修改了代码,也无法检测多类别。 因此,确保你使用的TensorRT引擎文件是多类别模型导出的。 修改后的代码结构如下(仅显示修改部分):</think>要修改 YOLOv8 TensorRT 推理代码以支持多类别检测,需要调整以下几个关键部分。以下是详细的修改步骤和代码实现: ### 1. **修改类别定义和输出维度** ```python # 修改类别数量(根据实际模型调整) NUM_CLASSES = 80 # 例如 COCO 数据集有 80 个类别 DET_NUM = 5 + NUM_CLASSES # 4坐标 + 1置信度 + 类别数 POSE_NUM = 17 * 3 SEG_NUM = 32 OBB_NUM = 1 # 更新类别标签(示例使用 COCO 类别) categories = ["person", "bicycle", "car", ... , "toothbrush"] # 完整 80 个类别 ``` ### 2. **重构后处理函数** 修改 `post_process` 和 `non_max_suppression` 函数以支持多类别: ```python def post_process(self, output, origin_h, origin_w): num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM num = int(output[0]) pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :] # 多类别处理 boxes = self.multi_class_nms(pred, origin_h, origin_w, conf_thres=CONF_THRESH, iou_thres=IOU_THRESHOLD) result_boxes = boxes[:, :4] if len(boxes) else np.array([]) result_scores = boxes[:, 4] if len(boxes) else np.array([]) result_classid = boxes[:, 5] if len(boxes) else np.array([]) return result_boxes, result_scores, result_classid def multi_class_nms(self, prediction, origin_h, origin_w, conf_thres=0.5, iou_thres=0.4): """ 多类别非极大值抑制处理 """ # 提取检测部分 [x, y, w, h, obj_conf, class_conf_0, ..., class_conf_n] dets = prediction[:, :DET_NUM] boxes = dets[:, :4] # [N, 4] obj_conf = dets[:, 4] # [N] cls_conf = dets[:, 5:5+NUM_CLASSES] # [N, NUM_CLASSES] # 计算类别置信度 = obj_conf * class_conf class_conf = obj_conf[:, None] * cls_conf # [N, NUM_CLASSES] # 找到每个框的最大类别分数和对应类别ID max_conf = np.max(class_conf, axis=1) # [N] class_id = np.argmax(class_conf, axis=1) # [N] # 转换坐标到原始图像尺寸 xyxy_boxes = self.xywh2xyxy(origin_h, origin_w, boxes) # [N, 4] # 使用 OpenCV 的多类别 NMS nms_boxes = [] for cls_idx in range(NUM_CLASSES): cls_mask = (class_id == cls_idx) & (max_conf > conf_thres) if not np.any(cls_mask): continue cls_boxes = xyxy_boxes[cls_mask] cls_scores = max_conf[cls_mask] # 执行 NMS indices = cv2.dnn.NMSBoxes( cls_boxes.tolist(), cls_scores.tolist(), conf_thres, iou_thres ) if len(indices) > 0: indices = indices.flatten() for idx in indices: box = cls_boxes[idx] score = cls_scores[idx] nms_boxes.append([*box, score, cls_idx]) return np.array(nms_boxes) if nms_boxes else np.zeros((0, 6)) ``` ### 3. **修改坐标转换函数** ```python def xywh2xyxy(self, origin_h, origin_w, x): """ 支持批量处理的坐标转换 x: [N, 4] 格式的数组,每行是 [center_x, center_y, w, h] 返回: [N, 4] 格式的 [x1, y1, x2, y2] """ if len(x.shape) == 1: x = x[None, :] # 单个框转为批量处理 y = np.zeros_like(x) r_w = self.input_w / origin_w r_h = self.input_h / origin_h if r_h > r_w: y[:, 0] = x[:, 0] y[:, 2] = x[:, 2] y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2 y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2 y /= r_w else: y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2 y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2 y[:, 1] = x[:, 1] y[:, 3] = y[:, 1] # 保持 y 坐标不变 y /= r_h # 转换为 xyxy 格式 xyxy_boxes = np.zeros_like(y) xyxy_boxes[:, 0] = y[:, 0] - y[:, 2] / 2 # x1 xyxy_boxes[:, 1] = y[:, 1] - y[:, 3] / 2 # y1 xyxy_boxes[:, 2] = y[:, 0] + y[:, 2] / 2 # x2 xyxy_boxes[:, 3] = y[:, 1] + y[:, 3] / 2 # y2 return xyxy_boxes ``` ### 4. **确保模型输出匹配** 在导出模型时,需要确认输出层包含多类别信息: ```bash # 导出 ONNX 模型时指定类别数 yolo mode=export model=yolov8n.pt format=onnx opset=12 ``` ### 5. **可视化多类别结果** 修改绘图函数以显示不同类别: ```python # 在 infer 方法中 for j in range(len(result_boxes)): box = result_boxes[j] cls_id = int(result_classid[j]) label = f"{categories[cls_id]}:{result_scores[j]:.2f}" # 为不同类别生成不同颜色 color = [int(c) for c in COLORS[cls_id % len(COLORS)]] plot_one_box(box, batch_image_raw[i], color=color, label=label) # 定义颜色常量 COLORS = [ (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255) ] ``` ### 关键修改说明: 1. **输出维度调整**: - `DET_NUM = 5 + NUM_CLASSES` 包含坐标、置信度和类别分数 - 解决了单类别检测中固定 `DET_NUM=6` 的限制 2. **多类别 NMS**: - 按类别分组处理,避免不同类别间的错误抑制 - 使用 `cv2.dnn.NMSBoxes` 替代原单类别实现 3. **批量坐标转换**: - 修改 `xywh2xyxy` 支持批量处理 - 正确处理不同长宽比的图像填充 4. **模型一致性**: - 确保导出的 ONNX/TensorRT 模型包含多类别输出 - 验证输出维度匹配 `(4 + 1 + NUM_CLASSES)` > **注意**:如果模型包含分割(SEG_NUM)或姿态估计(POSE_NUM)任务,需要相应调整 `num_values_per_detection` 的计算[^2]。 ### 性能优化建议: 1. **类别过滤优化**: ```python # 提前过滤低置信度框,减少 NMS 计算量 conf_mask = max_conf > conf_thres * 0.5 # 宽松阈值 boxes = boxes[conf_mask] cls_conf = cls_conf[conf_mask] ``` 2. **并行处理**: ```python # 使用多线程处理不同类别 from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor() as executor: results = list(executor.map(process_class, range(NUM_CLASSES))) ``` 3. **INT8 量化**: ```bash # 导出时启用 INT8 量化 trtexec --onnx=yolov8n.onnx --saveEngine=yolov8n_int8.engine --int8 ``` ### 验证方法: 1. 检查输出张量形状: ```python print(output.shape) # 应为 [batch, 4+1+NUM_CLASSES, num_boxes] ``` 2. 测试不同类别的检测结果 3. 验证 mAP 指标是否与原始模型一致 > 修改后,模型可以同时检测多个类别,如 "car"、"person"、"drone" 等,并在图像中用不同颜色标注[^1][^5]。

以上代码出现(style_tune) C:\Users\28996\Desktop\AI\persona_contrastive_finetuning>python Contrastive_Training_LM.py 0%| | 0/3 [00:00<?, ?it/s]You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. Traceback (most recent call last): File "C:\Users\28996\Desktop\AI\persona_contrastive_finetuning\Contrastive_Training_LM.py", line 227, in <module> trainer.train() File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 2171, in train return inner_training_loop( File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 2480, in _inner_training_loop batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\trainer.py", line 5156, in get_batch_samples batch_samples += [next(epoch_iterator)] File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\accelerate\data_loader.py", line 567, in __iter__ current_batch = next(dataloader_iter) File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\dataloader.py", line 701, in __next__ data = self._next_data() File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\dataloader.py", line 757, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\torch\utils\data\_utils\fetch.py", line 55, in fetch return self.collate_fn(data) File "C:\Users\28996\Desktop\AI\persona_contrastive_finetuning\Contrastive_Training_LM.py", line 38, in __call__ batch_anchor = self.tokenizer.pad( File "C:\Users\28996\miniconda3\envs\style_tune\lib\site-packages\transformers\tokenization_utils_base.py", line 3337, in pad raise ValueError( ValueError: type of 你如何看待气候变化? unknown: <class 'str'>. Should be one of a python, numpy, pytorch or tensorflow object. 0%| | 0/3 [00:00<?, ?it/s] 请分析解决

<think>我们遇到了一个新的错误:在调用tokenizer.pad方法时,传入的features中的某个元素是字符串类型,而不是整数列表。 根据错误信息,我们的数据集中可能有样本的字段是字符串而不是整数列表。我们需要确保在数据集中,每个样本的anchor_input_ids、positive_input_ids和negative_input_ids都是整数列表(List[int])。 可能的原因: 1. 在预处理数据集时,将token id列表保存为字符串(例如,用JSON保存时没有正确处理)。 2. 在加载数据集后,没有将字符串转换回整数列表。 解决方案: 在将数据传递给ContrastiveDataCollator之前,需要确保每个样本的三个字段都是整数列表。 我们可以修改数据集的格式,在加载后转换每个样本。具体来说,如果某个字段是字符串,我们需要将其解析为整数列表。 例如,假设我们的JSON文件中的样本是这样的: {"anchor_input_ids": "[1, 2, 3]", ...} # 字符串形式 但我们期望的是: {"anchor_input_ids": [1, 2, 3], ...} # 整数列表 因此,我们需要在加载数据集后,遍历每个样本,将字符串形式的token id列表转换为整数列表。 但是,也有可能我们的数据集已经是整数列表,那么就不需要转换。所以我们可以添加一个检查,如果遇到字符串,就尝试解析它。 然而,更可靠的方法是在数据预处理阶段确保保存为整数列表,或者在加载后转换。 由于我们使用`load_dataset`加载,我们可以使用`map`函数来转换整个数据集。 修改主程序部分,在加载数据集后,添加一个转换函数,将每个字段从字符串转换为整数列表(如果它们是字符串的话)。 注意:我们不知道具体是哪个字段出现了字符串,所以需要对三个字段都进行检查。 步骤: 1. 定义一个转换函数,遍历每个样本,将每个字段(如果它是字符串)转换为整数列表。 2. 使用数据集的map方法应用这个转换函数。 但是,我们也可以只对出现问题的样本进行处理。为了安全,我们处理所有样本。 修改后的代码(在主程序中加载数据集后): ```python # 加载数据集 train_dataset = load_dataset('json', data_files='data/processed/train_style_triplets.json')['train'] val_dataset = load_dataset('json', data_files='data/processed/val_style_triplets.json')['train'] # 转换函数:将字符串形式的token id列表转换为整数列表 def convert_string_to_list(example): # 定义要处理的字段 fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids'] for field in fields: if field in example and isinstance(example[field], str): # 去除字符串中的方括号,然后分割成数字字符串,再转换为整数 # 注意:这里假设字符串格式为 "[1, 2, 3]" example[field] = [int(x) for x in example[field].strip('[]').split(', ') if x] return example # 应用转换函数 train_dataset = train_dataset.map(convert_string_to_list) val_dataset = val_dataset.map(convert_string_to_list) ``` 但是,注意:如果我们的数据集已经是整数列表,那么map操作不会改变它们,因为isinstance(example[field], str)会返回False。 然而,还有一种可能:在JSON中保存的是字符串,但load_dataset解析后可能是整数列表,也可能是字符串?这取决于JSON文件中的具体格式。 如果JSON文件中是: "anchor_input_ids": [1,2,3] -> 解析为整数列表 "anchor_input_ids": "[1,2,3]" -> 解析为字符串 所以,我们需要确保JSON文件中的格式正确。如果已经是整数列表,那么就不需要转换。但为了处理可能存在的字符串,我们加上这个转换。 另外,我们也可以使用更健壮的方法,比如使用ast.literal_eval来安全地转换字符串表示的列表。 修改转换函数: ```python import ast def convert_string_to_list(example): fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids'] for field in fields: if field in example: if isinstance(example[field], str): try: # 使用ast.literal_eval安全地转换字符串 example[field] = ast.literal_eval(example[field]) except: logger.error(f"Error converting {field} for example: {example[field]}") raise # 确保转换后是整数列表 if isinstance(example[field], list) and all(isinstance(x, int) for x in example[field]): pass # 正确 else: logger.warning(f"Field {field} is not a list of integers: {example[field]}") return example ``` 但是,注意:我们可能不需要这么复杂,因为我们的数据集应该是我们预处理过的,所以我们可以先检查一下数据集的样本。 我们可以在转换前打印一个样本看看。 因此,修改后的主程序部分: ```python # 加载数据集 train_dataset = load_dataset('json', data_files='data/processed/train_style_triplets.json')['train'] val_dataset = load_dataset('json', data_files='data/processed/val_style_triplets.json')['train'] # 打印一个样本以检查 print("训练集第一个样本(转换前):", train_dataset[0]) # 转换函数 def convert_string_to_list(example): fields = ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids'] for field in fields: if field in example and isinstance(example[field], str): try: example[field] = ast.literal_eval(example[field]) except Exception as e: print(f"Error converting {field}: {example[field]}, error: {e}") raise return example import ast train_dataset = train_dataset.map(convert_string_to_list) val_dataset = val_dataset.map(convert_string_to_list) print("训练集第一个样本(转换后):", train_dataset[0]) ``` 这样,我们可以确保输入到数据收集器的是整数列表。 如果问题仍然存在,我们可能需要检查数据收集器中的代码。但是错误信息明确指出了类型是字符串,所以问题应该出在数据集上。 另外,注意:如果我们的数据集很大,这个map操作可能会比较慢,但这是必要的。 现在,我们修改代码,加入这个转换步骤。 完整代码(只修改主程序的数据集加载部分):</think>### 问题分析 错误信息表明在数据预处理阶段,`tokenizer.pad()` 方法遇到了字符串类型的数据,而它期望的是整数列表(token IDs)。核心问题在于: 1. **数据类型不匹配**:数据集中的输入ID字段(如 `anchor_input_ids`)被错误地存储为字符串(如 `"[1, 2, 3]"`)而不是整数列表 2. **JSON解析问题**:当使用 `load_dataset` 加载JSON文件时,如果token IDs被存储为字符串,它们不会被自动转换为整数列表 3. **错误根源**:在数据预处理阶段,token IDs应该保存为整数列表,但实际被保存为JSON字符串 ### 解决方案 需要修改两个地方: 1. **修复数据集加载**:将字符串格式的token IDs转换为整数列表 2. **优化数据收集器**:添加类型检查和转换逻辑 以下是修复后的完整代码: ```python import torch import torch.nn as nn import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, PreTrainedTokenizerBase ) from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy from datasets import load_dataset, Dataset from typing import Any, Dict, List, Optional, Tuple, Union import logging from dataclasses import dataclass import json import ast # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @dataclass class ContrastiveDataCollator: """数据收集器,用于处理对比学习的三元组数据""" tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None pad_to_multiple_of: Optional[int] = None return_tensors: str = "pt" def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: # 类型检查和转换:确保输入ID是整数列表 def convert_to_int_list(input_ids): """将字符串格式的token IDs转换为整数列表""" if isinstance(input_ids, str): try: # 尝试解析JSON格式的字符串 return json.loads(input_ids) except json.JSONDecodeError: try: # 尝试解析Python列表格式的字符串 return ast.literal_eval(input_ids) except: logger.error(f"无法解析token IDs: {input_ids}") return [] return input_ids # 预处理每个特征 processed_features = [] for f in features: processed_f = { "anchor_input_ids": convert_to_int_list(f["anchor_input_ids"]), "positive_input_ids": convert_to_int_list(f["positive_input_ids"]), "negative_input_ids": convert_to_int_list(f["negative_input_ids"]) } processed_features.append(processed_f) # 分离出三元组的各个部分 anchor_features = [{"input_ids": f["anchor_input_ids"]} for f in processed_features] positive_features = [{"input_ids": f["positive_input_ids"]} for f in processed_features] negative_features = [{"input_ids": f["negative_input_ids"]} for f in processed_features] # 对每个部分分别进行填充 batch_anchor = self.tokenizer.pad( anchor_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=self.return_tensors, ) batch_positive = self.tokenizer.pad( positive_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=self.return_tensors, ) batch_negative = self.tokenizer.pad( negative_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=self.return_tensors, ) # 创建注意力掩码 def create_attention_mask(input_ids): return (input_ids != self.tokenizer.pad_token_id).int() return { "anchor_input_ids": batch_anchor["input_ids"], "anchor_attention_mask": create_attention_mask(batch_anchor["input_ids"]), "positive_input_ids": batch_positive["input_ids"], "positive_attention_mask": create_to_attention_mask(batch_positive["input_ids"]), "negative_input_ids": batch_negative["input_ids"], "negative_attention_mask": create_attention_mask(batch_negative["input_ids"]), } class ContrastiveTrainer(Trainer): """自定义训练器,实现对比学习损失""" def __init__(self, *args, contrastive_config=None, **kwargs): # 修复:使用新的参数名代替将被弃用的tokenizer参数 if "tokenizer" in kwargs: kwargs["data_collator"] = kwargs.pop("tokenizer") super().__init__(*args, **kwargs) if contrastive_config is None: contrastive_config = {} # 设置默认值并验证配置 self.temperature = contrastive_config.get("temperature", 0.07) self.margin = contrastive_config.get("margin", 0.3) self.contrastive_weight = contrastive_config.get("weight", 0.8) self.repr_layer = contrastive_config.get("repr_layer", -1) # 验证必要参数 if not hasattr(self.model.config, "output_hidden_states") or not self.model.config.output_hidden_states: raise ValueError("模型必须设置output_hidden_states=True以获取隐藏状态") self.cross_entropy = nn.CrossEntropyLoss() def compute_contrastive_loss(self, anchor_emb, pos_emb, neg_emb): """计算对比损失:结合InfoNCE和三元组损失""" # 计算余弦相似度 pos_sim = F.cosine_similarity(anchor_emb, pos_emb) neg_sim = F.cosine_similarity(anchor_emb, neg_emb) # InfoNCE损失 numerator = torch.exp(pos_sim / self.temperature) denominator = numerator + torch.exp(neg_sim / self.temperature) info_nce_loss = -torch.log(numerator / (denominator + 1e-8)).mean() # 三元组损失 triplet_loss = F.relu(neg_sim - pos_sim + self.margin).mean() # 加权组合 return info_nce_loss + triplet_loss def get_sequence_representation(self, outputs, attention_mask): """获取序列表示(取最后一个非填充token的隐藏状态)""" # 获取指定层的隐藏状态 hidden_states = outputs.hidden_states[self.repr_layer] # 获取每个序列的最后一个非填充token seq_lengths = attention_mask.sum(dim=1) - 1 batch_indices = torch.arange(hidden_states.size(0)) return hidden_states[batch_indices, seq_lengths] def compute_loss(self, model, inputs, return_outputs=False): """计算总损失:语言模型损失 + 对比损失""" # 提取输入 anchor_ids = inputs["anchor_input_ids"] anchor_mask = inputs["anchor_attention_mask"] positive_ids = inputs["positive_input_ids"] positive_mask = inputs["positive_attention_mask"] negative_ids = inputs["negative_input_ids"] negative_mask = inputs["negative_attention_mask"] # 前向传播获取隐藏状态 def get_embeddings(input_ids, attention_mask): outputs = model( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True ) return self.get_sequence_representation(outputs, attention_mask) # 获取三元组的嵌入表示 anchor_emb = get_embeddings(anchor_ids, anchor_mask) pos_emb = get_embeddings(positive_ids, positive_mask) neg_emb = get_embeddings(negative_ids, negative_mask) # 计算对比损失 cl_loss = self.compute_contrastive_loss(anchor_emb, pos_emb, neg_emb) cl_loss = cl_loss * self.contrastive_weight # 计算语言建模损失(仅针对positive回复) lm_labels = positive_ids.clone() lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100 lm_outputs = model( input_ids=positive_ids, attention_mask=positive_mask, labels=lm_labels ) lm_loss = lm_outputs.loss # 总损失 = LM损失 + 对比损失 total_loss = lm_loss + cl_loss # 记录自定义指标 self.log({"cl_loss": cl_loss.item(), "lm_loss": lm_loss.item()}) return (total_loss, lm_outputs) if return_outputs else total_loss # ================ 主程序 ================ # if __name__ == "__main__": # 加载模型和分词器 model = AutoModelForCausalLM.from_pretrained( "model/Qwen/Qwen1.5-1.8B", output_hidden_states=True, # 必须设置以获取隐藏状态 return_dict_in_generate=True # 修复隐藏状态警告 ) tokenizer = AutoTokenizer.from_pretrained("model/Qwen/Qwen1.5-1.8B") tokenizer.pad_token = tokenizer.eos_token # 设置填充token # 加载数据集并修复数据类型问题 def load_and_convert_dataset(file_path): """加载数据集并确保token IDs是整数列表""" dataset = load_dataset('json', data_files=file_path)['train'] # 类型转换函数 def convert_token_ids(example): for key in ['anchor_input_ids', 'positive_input_ids', 'negative_input_ids']: if key in example and isinstance(example[key], str): try: # 尝试解析JSON格式的字符串 example[key] = json.loads(example[key]) except json.JSONDecodeError: try: # 尝试解析Python列表格式的字符串 example[key] = ast.literal_eval(example[key]) except: logger.error(f"无法解析{key}: {example[key]}") return example return dataset.map(convert_token_ids) # 加载训练集和验证集 train_dataset = load_and_convert_dataset('data/processed/train_style_triplets.json') val_dataset = load_and_convert_dataset('data/processed/val_style_triplets.json') # 验证数据集格式 print("训练集样本示例:", train_dataset[0]) print("验证集样本示例:", val_dataset[0]) # 训练参数配置 training_args = TrainingArguments( output_dir="./model/lora_adapter", per_device_train_batch_size=4, num_train_epochs=3, learning_rate=2e-4, logging_steps=50, save_steps=500, fp16=True, gradient_accumulation_steps=4, report_to="none", # 禁用默认日志 remove_unused_columns=False, # 关键修复:保留自定义列 ) # 对比学习配置 contrastive_config = { "temperature": 0.07, # 温度参数 "margin": 0.3, # 三元组损失边界 "weight": 0.8, # 对比损失权重 "repr_layer": -1 # 使用最后一层隐藏状态 } # 初始化数据收集器 data_collator = ContrastiveDataCollator( tokenizer=tokenizer, max_length=512, padding="max_length" ) # 初始化训练器 trainer = ContrastiveTrainer( model=model, args=training_args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=val_dataset, contrastive_config=contrastive_config ) # 开始训练 trainer.train() # 评估模型 eval_results = trainer.evaluate() print("评估结果:", eval_results) ``` ### 关键修复点解释 1. **数据加载修复**: ```python def load_and_convert_dataset(file_path): dataset = load_dataset('json', data_files=file_path)['train'] def convert_token_ids(example): for key in ['anchor_input_ids', ...]: if isinstance(example[key], str): try: example[key] = json.loads(example[key]) except: example[key] = ast.literal_eval(example[key]) return dataset.map(convert_token_ids) ``` - 自动检测字符串格式的token IDs - 尝试两种解析方法:JSON格式和Python列表格式 - 确保最终token IDs是整数列表 2. **数据收集器增强**: ```python def convert_to_int_list(input_ids): if isinstance(input_ids, str): try: return json.loads(input_ids) except: return ast.literal_eval(input_ids) return input_ids ``` - 在数据收集器内部添加类型检查和转换 - 双重解析机制确保兼容不同格式 - 错误处理防止解析失败导致程序崩溃 3. **数据验证**: ```python print("训练集样本示例:", train_dataset[0]) print("验证集样本示例:", val_dataset[0]) ``` - 训练前打印样本验证数据格式 - 确保token IDs已正确转换为整数列表 ### 预防措施 1. **数据预处理脚本**: 在创建三元组数据集时,确保正确保存为整数列表: ```python # 正确保存token IDs with open('data.json', 'w') as f: json.dump({ "anchor_input_ids": [1, 2, 3], # 整数列表 "positive_input_ids": [4, 5, 6], "negative_input_ids": [7, 8, 9] }, f) # 错误方式(会导致问题): json.dump({ "anchor_input_ids": str([1, 2, 3]) # 字符串 }, f) ``` 2. **数据集检查工具**: ```python def check_dataset(dataset): for i in range(min(5, len(dataset))): sample = dataset[i] for key in ['anchor_input_ids', ...]: if not isinstance(sample[key], list): raise TypeError(f"样本{i}的{key}不是列表") if not all(isinstance(x, int) for x in sample[key]): raise TypeError(f"样本{i}的{key}包含非整数元素") ``` 3. **使用Parquet格式**: 考虑使用Parquet格式替代JSON,它能更好地保留数据类型: ```python # 保存 dataset.to_parquet('data.parquet') # 加载 dataset = load_dataset('parquet', data_files='data.parquet') ```
阅读全文

相关推荐

最新推荐

recommend-type

【网络工程】OSPF协议.docx

【网络工程】OSPF协议.docx
recommend-type

【培训课件】网络工程师—第3章ip地址及其规划.ppt

【培训课件】网络工程师—第3章ip地址及其规划.ppt
recommend-type

《基于人工神经网络的压力传感器的温度补偿办法》.ppt

《基于人工神经网络的压力传感器的温度补偿办法》.ppt
recommend-type

《Spring-in-china》Seasons-PPT课件【品质课件PPT】.pptx

《Spring-in-china》Seasons-PPT课件【品质课件PPT】.pptx
recommend-type

C++实现的DecompressLibrary库解压缩GZ文件

根据提供的文件信息,我们可以深入探讨C++语言中关于解压缩库(Decompress Library)的使用,特别是针对.gz文件格式的解压过程。这里的“lib”通常指的是库(Library),是软件开发中用于提供特定功能的代码集合。在本例中,我们关注的库是用于处理.gz文件压缩包的解压库。 首先,我们要明确一个概念:.gz文件是一种基于GNU zip压缩算法的压缩文件格式,广泛用于Unix、Linux等操作系统上,对文件进行压缩以节省存储空间或网络传输时间。要解压.gz文件,开发者需要使用到支持gzip格式的解压缩库。 在C++中,处理.gz文件通常依赖于第三方库,如zlib或者Boost.IoStreams。codeproject.com是一个提供编程资源和示例代码的网站,程序员可以在该网站上找到现成的C++解压lib代码,来实现.gz文件的解压功能。 解压库(Decompress Library)提供的主要功能是读取.gz文件,执行解压缩算法,并将解压缩后的数据写入到指定的输出位置。在使用这些库时,我们通常需要链接相应的库文件,这样编译器在编译程序时能够找到并使用这些库中定义好的函数和类。 下面是使用C++解压.gz文件时,可能涉及的关键知识点: 1. Zlib库 - zlib是一个用于数据压缩的软件库,提供了许多用于压缩和解压缩数据的函数。 - zlib库支持.gz文件格式,并且在多数Linux发行版中都预装了zlib库。 - 在C++中使用zlib库,需要包含zlib.h头文件,同时链接z库文件。 2. Boost.IoStreams - Boost是一个提供大量可复用C++库的组织,其中的Boost.IoStreams库提供了对.gz文件的压缩和解压缩支持。 - Boost库的使用需要下载Boost源码包,配置好编译环境,并在编译时链接相应的Boost库。 3. C++ I/O操作 - 解压.gz文件需要使用C++的I/O流操作,比如使用ifstream读取.gz文件,使用ofstream输出解压后的文件。 - 对于流操作,我们常用的是std::ifstream和std::ofstream类。 4. 错误处理 - 解压缩过程中可能会遇到各种问题,如文件损坏、磁盘空间不足等,因此进行适当的错误处理是必不可少的。 - 正确地捕获异常,并提供清晰的错误信息,对于调试和用户反馈都非常重要。 5. 代码示例 - 从codeproject找到的C++解压lib很可能包含一个或多个源代码文件,这些文件会包含解压.gz文件所需的函数或类。 - 示例代码可能会展示如何初始化库、如何打开.gz文件、如何读取并处理压缩数据,以及如何释放资源等。 6. 库文件的链接 - 编译使用解压库的程序时,需要指定链接到的库文件,这在不同的编译器和操作系统中可能略有不同。 - 通常,在编译命令中加入-l参数,比如使用g++的话可能是`g++ -o DecompressLibrary DecompressLibrary.cpp -lz`,其中`-lz`表示链接zlib库。 7. 平台兼容性 - 在不同平台上使用解压库可能需要考虑平台兼容性问题。 - Windows系统可能需要额外的配置和库文件,因为zlib或其他库可能不是默认预装的。 根据以上知识点,我们可以得出,在C++中解压.gz文件主要涉及到对zlib或类似库的使用,以及熟悉C++的I/O操作。正确使用这些库,能够有效地对压缩文件进行解压,并处理可能出现的错误情况。如果从codeproject获取到的C++解压lib确实是针对.gz文件格式的,那么它很可能已经封装好了大部分的操作细节,让开发者能够以更简单的方式实现解压功能。
recommend-type

【数据融合技术】:甘肃土壤类型空间分析中的专业性应用

# 摘要 数据融合技术作为一种集成多源数据信息的方法,在土壤类型空间分析中发挥着关键作用。本文介绍了数据融合技术的基本概念及其理论基础,阐述了数据预处理、同步整合及冲突解决等关键技术,并详细描述了甘肃土壤类型数据准备的流程,包括数据采集、质量评估、空间化处理及融合实践准备。通过具体案例分析,展示了数据融合在土壤类型空间分布分析、土壤质量评估及土壤保护规划中的应用。同时,文
recommend-type

VM ware如何查看软件版本信息

<think>首先,用户的问题是:“我想查找VMware软件版本信息的查看方法请问如何在VMware中查看软件版本信息”。用户参考了之前的引用,但那些引用主要是关于查看VMware版本的一般方法,现在用户具体想知道在VMware软件中如何查看版本信息。回顾之前的对话:-用户第一次问的是VSCode的版本查看方法。-现在用户问的是VMware的版本查看方法。-系统级指令要求使用中文回答,正确格式化数学表达式(如果需要),但这里可能不需要数学表达式。-指令还要求生成相关问题,并在回答中引用段落时添加引用标识。用户提供的引用[1]到[5]是关于VMware版本的查看方法、下载等,但用户特别强调“参考
recommend-type

数据库课程设计报告:常用数据库综述

数据库是现代信息管理的基础,其技术广泛应用于各个领域。在高等教育中,数据库课程设计是一个重要环节,它不仅是学习理论知识的实践,也是培养学生综合运用数据库技术解决问题能力的平台。本知识点将围绕“经典数据库课程设计报告”展开,详细阐述数据库的基本概念、课程设计的目的和内容,以及在设计报告中常用的数据库技术。 ### 1. 数据库基本概念 #### 1.1 数据库定义 数据库(Database)是存储在计算机存储设备中的数据集合,这些数据集合是经过组织的、可共享的,并且可以被多个应用程序或用户共享访问。数据库管理系统(DBMS)提供了数据的定义、创建、维护和控制功能。 #### 1.2 数据库类型 数据库按照数据模型可以分为关系型数据库(如MySQL、Oracle)、层次型数据库、网状型数据库、面向对象型数据库等。其中,关系型数据库因其简单性和强大的操作能力而广泛使用。 #### 1.3 数据库特性 数据库具备安全性、完整性、一致性和可靠性等重要特性。安全性指的是防止数据被未授权访问和破坏。完整性指的是数据和数据库的结构必须符合既定规则。一致性保证了事务的执行使数据库从一个一致性状态转换到另一个一致性状态。可靠性则保证了系统发生故障时数据不会丢失。 ### 2. 课程设计目的 #### 2.1 理论与实践结合 数据库课程设计旨在将学生在课堂上学习的数据库理论知识与实际操作相结合,通过完成具体的数据库设计任务,加深对数据库知识的理解。 #### 2.2 培养实践能力 通过课程设计,学生能够提升分析问题、设计解决方案以及使用数据库技术实现这些方案的能力。这包括需求分析、概念设计、逻辑设计、物理设计、数据库实现、测试和维护等整个数据库开发周期。 ### 3. 课程设计内容 #### 3.1 需求分析 在设计报告的开始,需要对项目的目标和需求进行深入分析。这涉及到确定数据存储需求、数据处理需求、数据安全和隐私保护要求等。 #### 3.2 概念设计 概念设计阶段要制定出数据库的E-R模型(实体-关系模型),明确实体之间的关系。E-R模型的目的是确定数据库结构并形成数据库的全局视图。 #### 3.3 逻辑设计 基于概念设计,逻辑设计阶段将E-R模型转换成特定数据库系统的逻辑结构,通常是关系型数据库的表结构。在此阶段,设计者需要确定各个表的属性、数据类型、主键、外键以及索引等。 #### 3.4 物理设计 在物理设计阶段,针对特定的数据库系统,设计者需确定数据的存储方式、索引的具体实现方法、存储过程、触发器等数据库对象的创建。 #### 3.5 数据库实现 根据物理设计,实际创建数据库、表、视图、索引、触发器和存储过程等。同时,还需要编写用于数据录入、查询、更新和删除的SQL语句。 #### 3.6 测试与维护 设计完成之后,需要对数据库进行测试,确保其满足需求分析阶段确定的各项要求。测试过程包括单元测试、集成测试和系统测试。测试无误后,数据库还需要进行持续的维护和优化。 ### 4. 常用数据库技术 #### 4.1 SQL语言 SQL(结构化查询语言)是数据库管理的国际标准语言。它包括数据查询、数据操作、数据定义和数据控制四大功能。SQL语言是数据库课程设计中必备的技能。 #### 4.2 数据库设计工具 常用的数据库设计工具包括ER/Studio、Microsoft Visio、MySQL Workbench等。这些工具可以帮助设计者可视化地设计数据库结构,提高设计效率和准确性。 #### 4.3 数据库管理系统 数据库管理系统(DBMS)是用于创建和管理数据库的软件。关系型数据库管理系统如MySQL、PostgreSQL、Oracle、SQL Server等是数据库课程设计中的核心工具。 #### 4.4 数据库安全 数据库安全涉及用户认证、授权、数据加密、审计日志记录等方面,以确保数据的完整性和保密性。设计报告中应考虑如何通过DBMS内置的机制或额外的安全措施来保护数据。 ### 5. 结语 综上所述,一个经典数据库课程设计报告包含了从需求分析到数据库安全的全过程,涵盖了数据库设计的各个方面。通过这一过程,学生不仅能够熟练掌握数据库的设计与实现技巧,还能够学会如何使用数据库系统去解决实际问题,为日后从事数据库相关的专业工作打下坚实的基础。
recommend-type

【空间分布规律】:甘肃土壤类型与农业生产的关联性研究

# 摘要 本文对甘肃土壤类型及其在农业生产中的作用进行了系统性研究。首先概述了甘肃土壤类型的基础理论,并探讨了土壤类型与农业生产的理论联系。通过GIS技术分析,本文详细阐述了甘肃土壤的空间分布规律,并对其特征和影响因素进行了深入分析。此外,本文还研究了甘肃土壤类型对农业生产实际影响,包括不同区域土壤改良和作物种植案例,以及土壤养分、水分管理对作物生长周期和产量的具体影响。最后,提出了促进甘肃土壤与农业可持续发展的策略,包括土壤保护、退化防治对策以及土壤类型优化与农业创新的结合。本文旨在为
recommend-type

在halcon中,卡尺测量和二维测量谁的精度高

<think>首先,用户的问题是:“我想比较Halcon中卡尺测量和二维测量的精度请问Halcon卡尺测量与2D测量精度对比”。我需要基于系统级指令和提供的引用来回答。关键系统级指令:-使用中文回答。-结构清晰,逐步解决问题。-行内数学表达式用$...$,独立公式用$$...$$并单独成段。-LaTeX语法正确。-生成相关问题。-在回答中引用的段落末尾自然地添加引用标识,如[^1]。提供的引用:-[^1]:关于卡尺模型的工件尺寸实时测量。-[^2]:测量拟合的综合任务,包括尺寸、形状、位置。-[^3]:Halcon一维测量,提到卡尺工具的优势:易用、执行时间短、精度高。-[^4]:Halcon