【大模型开发】多模态大模型的训练和推理

下面给出一个多模态大模型图像描述(Image Captioning)任务上的可执行示例,演示如何使用 Hugging Face 的开源工具、公开数据集以及一个代表性的多模态模型(如 BLIP-2)来完成基础训练和推理。本示例可以在单卡 GPU 或 CPU 上运行(GPU 推荐),但请注意完整训练需要大量资源,示例中仅演示小规模训练流程。


目录

  1. 环境准备
  2. 数据集介绍与加载
  3. 模型与处理器加载
  4. 训练流程
  5. 推理流程
  6. 优化方向与未来建议

1. 环境准备

# 建议 Python 3.8+ 环境
pip install torch torchvision
pip install transformers datasets accelerate
  • 如果想在 GPU 上训练,请确保已安装对应 CUDA 版本的 PyTorch(例如 pip install torch --extra-index-url https://download.pytorch.org/whl/cu118)。
  • 本示例使用的库版本参考(可能随时间升级):
    • torch >= 1.13
    • transformers >= 4.26
    • datasets >= 2.6
    • accelerate >= 0.15

2. 数据集介绍与加载

在多模态训练中,常用的公开数据集有 MS COCO Captions、Flickr30k、Visual Genome 等。这里我们使用 MS COCO Captions(2017 版本)做一个小规模示例,只取训练集的一部分(如 1%)来节省时间。

  • MS COCO Captions:包含图像与 5 条英文描述。
  • 真实训练建议使用全部数据(约 118K+ 图像)。
from datasets import load_dataset

# 加载 COCO Captions 2017,train[:1%] 仅取 1% 数据演示
dataset = load_dataset("coco_captions", "2017", split="train[:1%]")

print(dataset)
# 示例输出: Dataset({
#     features: ['image', 'annotations'],
#     num_rows: 118...
# })

dataset 中的字段:

  • image:包含实际图像对象(可用 PIL 访问)。
  • annotations:列表形式,每个元素带有 "caption" 字段。

每条数据里通常有 5 条 caption,这里我们只使用第一条 caption 进行训练演示。


3. 模型与处理器加载

在 Hugging Face 上,你可以找到许多多模态/图文模型,如 BLIP, BLIP-2, Flamingo, OFA 等。下面以 BLIP-2 为例进行示范(它能够将视觉特征编码后输入语言大模型进行生成),也可切换到其他多模态模型。

import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# 可选择不同的权重,如 "Salesforce/blip2-opt-2.7b" 等
model_name = "Salesforce/blip2-opt-2.7b"  

# 加载多模态处理器(同时处理图像与文本)
processor = Blip2Processor.from_pretrained(model_name)

# 加载多模态模型
model = Blip2ForConditionalGeneration.from_pretrained(model_name)

如有 GPU,请执行:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

4. 训练流程

下面示例演示自定义训练循环,我们将图像与它的描述(caption)作为输入与标签,让模型学会“看图说话”。由于我们只取了 1% 的数据,并且只跑 1-2 个 epoch,主要目的是演示流程,真实场景需根据任务需求适当加大数据量与训练轮次。

4.1 数据打包与 Dataloader

from torch.utils.data import Dataset, DataLoader
import random

class CocoCaptionDataset(Dataset):
    def __init__(self, hf_dataset, processor):
        self.hf_dataset = hf_dataset
        self.processor = processor

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        sample = self.hf_dataset[idx]
        image = sample["image"]   # PIL 图片对象
        # annotations 是多个 caption,这里仅用第一个
        caption = sample["annotations"][0]["caption"] if sample["annotations"] else ""

        return image, caption

def collate_fn(batch):
    """
    将一批 (image, caption) 用 processor 打包成可输入模型的格式
    """
    images = [item[0] for item in batch]
    captions = [item[1] for item in batch]

    # BLIP-2 的处理器可以一次性处理多张图像和文本
    encoding = processor(
        images=images,
        text=captions,
        return_tensors="pt",
        padding=True
    )
    return encoding

# 构建 Dataset & Dataloader
train_dataset = CocoCaptionDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

4.2 训练循环

以下为简化版训练示例:

import torch
import torch.nn.functional as F
from tqdm import tqdm

# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
epochs = 1  # 演示用,实际可更多

for epoch in range(epochs):
    print(f"=== Epoch {epoch+1}/{epochs} ===")
    total_loss = 0.0

    for batch in tqdm(train_dataloader):
        # 将 batch 数据移动到 GPU(如果可用)
        for k, v in batch.items():
            batch[k] = v.to(device)

        # 前向计算
        outputs = model(**batch)
        loss = outputs.loss

        # 反向传播 & 更新
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Train Loss: {avg_loss:.4f}")

# (可选)保存微调后权重
# model.save_pretrained("blip2_coco_finetuned")
# processor.save_pretrained("blip2_coco_finetuned")
  • loss:对图像-文本生成进行交叉熵损失。
  • 实际训练中,需要更多 epoch,并可切分验证集做校验(或使用 HuggingFace Trainer)。
  • lrbatch_sizeepochs 等都需根据硬件与实际任务情况调参。

5. 推理流程

下面我们使用训练/预训练好的 BLIP-2 模型,对一张本地图进行图像描述

import requests
from PIL import Image

def infer_image_caption(model, processor, image_path_or_url, device, max_new_tokens=30):
    """
    输入:模型、处理器,图像路径或URL
    输出:模型生成的描述文本
    """
    if image_path_or_url.startswith("http"):
        image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_path_or_url).convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)
    
    # BLIP-2 中的 generate 会根据视觉特征+默认prompt 生成描述
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # decode
    generated_text = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text

# 示例:推理
sample_image_url = "https://raw.githubusercontent.com/salesforce/BLIP/main/assets/example_image.jpg"
caption = infer_image_caption(model, processor, sample_image_url, device)
print("生成的图像描述:", caption)
  • 如果使用本地图片,直接指定 image_path_or_url="test.jpg" 即可。
  • 也可以给模型一个文本Prompt来做图片问答,如 processor(images=image, text="Question: How many cats are there? Answer:", return_tensors="pt"),再用 model.generate(...) 获得答案。

6. 优化方向与未来建议

  1. 完整数据训练

    • 示例只用了 1% 数据,完整训练需下载 MS COCO 全量数据(或其他多模态数据),确保多样性和大规模性来提升模型性能。
  2. 更多下游任务

    • 不仅是图像描述,也可以进行 视觉问答(VQA)图文检索图文生成 等多模态场景。对应需要调整训练目标与数据标注方式。
  3. 更大模型 / 更强算力

    • 如果需求场景需要更高的准确率,可使用更大预训练模型(如 blip2-flan-t5-xxl,或阿里开源的更大多模态 Qwen),并在多 GPU / 分布式环境中加速训练。
  4. 轻量化与部署

    • 模型量化(8-bit / 4-bit)或蒸馏,以减少推理时的显存占用;
    • 使用 Hugging Face Optimum / ONNX Runtime / TensorRT 等方案进行推理加速。
  5. 多语言、多模态扩展

    • 针对中文场景,可以换用中文多模态数据
    • 支持更多模态(视频、音频)时,需要相应的处理器和预训练数据。

总结

以上演示了一个可执行的多模态大模型基础训练推理流程,使用了 Hugging Face 的 BLIP-2 作为示例模型、COCO Captions 作为示例数据集。通过此示例,你可以:

  • 了解如何将图像与文本组成对儿输入模型并执行自定义训练循环
  • 使用 model.generate 进行图像描述问答推理;
  • 在此基础上扩展到更大规模、更多任务或更复杂的多模态应用(如问答、多语言场景、视频理解等)。

哈佛博后带小白玩转机器学习】 哔哩哔哩_bilibili

总课时超400+,时长75+小时

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值