Basic steps for loading a TFLite model and running inference with TensorFlow Lite (C++)
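
The snippets below use the TensorFlow Lite C++ API together with OpenCV (image I/O) and Qt (logging). A minimal set of includes they assume (header paths may differ slightly depending on your TensorFlow Lite version and build setup):

#include <cmath>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <QDebug>

#include <opencv2/opencv.hpp>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"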

1. Load the TFLite model

const std::string model_path = "./best_saved_model/best_float16.tflite";
std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
if (!model) {
    qCritical() << "Failed to load TFLite model: " << QString::fromStdString(model_path);
    return -1;
}

2. Build the interpreter

tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
    qCritical() << "Failed to build the interpreter.";
    return -1;
}
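
Optionally, the interpreter can be tuned before tensors are allocated, for example by setting the number of CPU threads (the value 4 here is an illustrative choice, not something the original steps specify):

interpreter->SetNumThreads(4);  // use 4 CPU threads for inference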

3. Allocate memory for the model's input/output tensors

if (interpreter->AllocateTensors() != kTfLiteOk) { qCritical() << "Failed to allocate tensors."; return -1; }

4. Get the list of indices of all input tensors

const std::vector<int>& input_tensors_indices = interpreter->inputs();

5. Get the total number of input tensors:

size_t num_inputs = input_tensors_indices.size();  // size of the vector obtained in the previous step
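
Steps 7 through 10 below operate on one input tensor at a time; the variable i appearing in those snippets is the loop index over the inputs. A minimal sketch of that loop (the body is filled in by the later steps):

for (size_t i = 0; i < num_inputs; ++i) {
    // Steps 7-10: query the dims/type of input i, resize the image, and copy the data in
}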

6. Load and prepare the image

std::string image_path = "1001.jpg";
cv::Mat image = cv::imread(image_path);
if (image.empty()) {
    qCritical() << "Failed to load image: " << QString::fromStdString(image_path);
    return -1;
}
cv::Mat image_rgb;
cv::cvtColor(image, image_rgb, cv::COLOR_BGR2RGB);  // OpenCV loads BGR; the model expects RGB

7. Get the dimensions of the input tensor

int current_input_tensor_idx = input_tensors_indices[i];  // tensor index of the i-th input, taken from the vector above
TfLiteIntArray* input_dims = interpreter->tensor(current_input_tensor_idx)->dims;

8. Get the data type of the input tensor

TfLiteType input_type = interpreter->tensor(current_input_tensor_idx)->type;

9. Get the height and width of the input tensor

// Assumes the input tensor shape is at least [batch, height, width, channels] (NHWC),
// with height and width at indices 1 and 2.
int input_height = input_dims->data[1];
int input_width = input_dims->data[2];
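
Before relying on that layout assumption, it can help to log the full input shape; a small sketch (not part of the original steps):

QString shape;
for (int d = 0; d < input_dims->size; ++d) {
    shape += QString::number(input_dims->data[d]) + (d + 1 < input_dims->size ? "x" : "");
}
qDebug() << "Input tensor" << static_cast<int>(i) << "shape:" << shape;  // e.g. "1x640x640x3" for an NHWC image input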

10. Prepare data for the current input tensor and copy it in

// Resize the RGB image to the model's expected input size (image_resized is used below)
cv::Mat image_resized;
cv::resize(image_rgb, image_resized, cv::Size(input_width, input_height));

// typed_tensor<>() takes the global tensor index; typed_input_tensor<>() would instead take the position i in inputs()
if (input_type == kTfLiteFloat32) {
    cv::Mat float_input;
    image_resized.convertTo(float_input, CV_32FC3, 1.0 / 255.0);  // normalize pixels to [0, 1]
    memcpy(interpreter->typed_tensor<float>(current_input_tensor_idx), float_input.data, input_height * input_width * 3 * sizeof(float));
} else if (input_type == kTfLiteUInt8) {
    memcpy(interpreter->typed_tensor<uint8_t>(current_input_tensor_idx), image_resized.data, input_height * input_width * 3 * sizeof(uint8_t));
} else {
    qCritical() << "Unsupported input tensor type; cannot prepare data for input tensor" << static_cast<int>(i) << ".";
    return -1;
}
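
Copying raw uint8 pixels as above implicitly assumes the quantized input uses a scale of roughly 1/255 and a zero point of 0, which is common for image models but not guaranteed. A hedged sketch of explicit quantization for an int8 input, using the tensor's quantization parameters (clamping to [-128, 127] omitted for brevity):

TfLiteTensor* in = interpreter->tensor(current_input_tensor_idx);
if (in->type == kTfLiteInt8) {
    float scale = in->params.scale;          // quantization scale
    int zero_point = in->params.zero_point;  // quantization zero point
    int8_t* dst = interpreter->typed_tensor<int8_t>(current_input_tensor_idx);
    for (int p = 0; p < input_height * input_width * 3; ++p) {
        float normalized = image_resized.data[p] / 255.0f;  // pixel value in [0, 1]
        dst[p] = static_cast<int8_t>(std::round(normalized / scale) + zero_point);
    }
}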

11. Run inference

if (interpreter->Invoke() != kTfLiteOk) {
    qCritical() << "Error: model inference failed. Check the model, the inputs, or TFLite version compatibility.";
    return -1;
}
qDebug() << "Model inference completed successfully.";

-------------- Post-processing (same as for a single-input model) --------------
12. Get the output tensor index:

// Assumes there is only one output tensor, or that only the first output matters
int output_tensor_idx = interpreter->outputs()[0];

13. Get the output tensor object

TfLiteTensor* output_tensor = interpreter->tensor(output_tensor_idx);
if (!output_tensor) {
    qCritical() << "Error: interpreter->tensor() returned a null pointer; the output tensor object could not be obtained.";
    return -1;
}

14. Get the data type of the output tensor

TfLiteType output_type = output_tensor->type;
qDebug() << "Actual output tensor data type (TfLiteType): " << output_type << " (expected: kTfLiteFloat32)";

15. Get a pointer to the output tensor data

float* output_data_ptr = interpreter->typed_tensor<float>(output_tensor_idx);  // typed_output_tensor<float>(0) would also work; it takes the position in outputs()
if (!output_data_ptr) {
    qCritical() << "Error: failed to obtain the output tensor pointer (null).";
    return -1;
}
qDebug() << "Successfully obtained the output tensor pointer.";

16. Get the dimensions of the output tensor

TfLiteIntArray* output_dims = output_tensor->dims;
if (!output_dims) {
    qCritical() << "Error: failed to obtain the output tensor's dimension information.";
    return -1;
}
// Total number of elements in the output tensor (needed in the next step)
int output_size = 1;
for (int d = 0; d < output_dims->size; ++d) {
    output_size *= output_dims->data[d];
}

17. Copy the output data into a std::vector

std::vector<float> predictions;
try {
    predictions.assign(output_data_ptr, output_data_ptr + output_size);
    qDebug() << "Successfully copied the output data into a std::vector. Vector size: " << predictions.size();
} catch (const std::bad_alloc& e) {
    qCritical() << "Memory allocation failed (std::bad_alloc): " << e.what();
    return -1;
} catch (const std::exception& e) {
    qCritical() << "Exception while copying data into the vector: " << e.what();
    return -1;
} catch (...) {
    qCritical() << "Unknown exception while copying data into the vector.";
    return -1;
}

18. Output the detection results

   float conf_threshold = 0.5f;
    std::vector<cv::Rect> boxes;
    std::vector<float> confidences;
    std::vector<int> class_ids;

    int num_detections = output_dims->data[1];
    int num_classes = output_dims->data[2] - 4;
    int stride = output_dims->data[2];

    qDebug() << "后处理参数 -> num_detections: " << num_detections
             << ", num_classes: " << num_classes
             << ", stride: " << stride;

    if (num_classes != 7) {
         qWarning() << "警告:模型报告的类别数与预期 (7) 不符。实际类别数: " << num_classes;
    }

    for (int i = 0; i < num_detections; ++i) {
        const float* detection = &predictions[i * stride];

        float x_center = detection[0];
        float y_center = detection[1];
        float width = detection[2];
        float height = detection[3];

        float max_score = 0.0f;
        int class_id = -1;

        for (int k = 0; k < num_classes; ++k) {
            float score = detection[4 + k];
            if (score > max_score) {
                max_score = score;
                class_id = k;
            }
        }

        if (max_score > conf_threshold) {
            int x1 = static_cast<int>((x_center - width / 2) * image.cols);
            int y1 = static_cast<int>((y_center - height / 2) * image.rows);
            int x2 = static_cast<int>((x_center + width / 2) * image.cols);
            int y2 = static_cast<int>((y_center + height / 2) * image.rows);

            boxes.push_back(cv::Rect(x1, y1, x2 - x1, y2 - y1));
            confidences.push_back(max_score);
            class_ids.push_back(class_id);
        }
    }

    const std::vector<std::string> class_names = {"bus", "car", "human", "motobike", "rect", "square", "truck"};

    float max_overall_confidence = 0.0f;
    int best_detection_idx = -1;

    for (size_t i = 0; i < confidences.size(); ++i) {
        int current_class_id = class_ids[i];
        if (current_class_id >= 0 && current_class_id < static_cast<int>(class_names.size())) {
            if (confidences[i] > max_overall_confidence) {
                max_overall_confidence = confidences[i];
                best_detection_idx = static_cast<int>(i);
            }
        } else {
            qWarning() << "Invalid class ID found; this detection box will be skipped: " << current_class_id;
        }
    }

    if (best_detection_idx != -1) {
        int idx = best_detection_idx;
        cv::Rect box = boxes[idx];
        float confidence = confidences[idx];
        int class_id = class_ids[idx];

        cv::rectangle(image, box, cv::Scalar(0, 255, 0), 2);

        std::string label = class_names[class_id] + ": " + std::to_string(confidence * 100).substr(0, 5) + "%";
        int baseline = 0;
        cv::Size label_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 2, &baseline);
        cv::rectangle(image, cv::Point(box.x, box.y - label_size.height - 10),
                      cv::Point(box.x + label_size.width, box.y), cv::Scalar(0, 255, 0), cv::FILLED);
        cv::putText(image, label, cv::Point(box.x, box.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0), 2);

        qDebug() << "最高置信度检测结果 -> 类别: " << QString::fromStdString(class_names[class_id])
                 << ", 置信度: " << confidence
                 << ", 框: (" << box.x << ", " << box.y << ", " << box.width << ", " << box.height << ")";

    } else {
        qDebug() << "未找到任何高于阈值的有效类别检测结果。";
    }

    std::string output_image_path = "detection_result.jpg";
    bool success = cv::imwrite(output_image_path, image);

    if (success) {
        qDebug() << "Detection result saved to: " << QString::fromStdString(output_image_path);
    } else {
        qCritical() << "Error: failed to save the detection result image. Check the path or file permissions.";
    }
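
The code above keeps only the single highest-confidence detection. If several non-overlapping detections should be kept instead, OpenCV's cv::dnn::NMSBoxes (from <opencv2/dnn.hpp>) can run non-maximum suppression on the boxes and confidences collected earlier; a minimal sketch (the 0.45 IoU threshold is an illustrative value, tune it for your model):

std::vector<int> kept_indices;
cv::dnn::NMSBoxes(boxes, confidences, conf_threshold, 0.45f, kept_indices);
for (int idx : kept_indices) {
    cv::rectangle(image, boxes[idx], cv::Scalar(0, 255, 0), 2);
    qDebug() << "Kept detection -> class: " << QString::fromStdString(class_names[class_ids[idx]])
             << ", confidence: " << confidences[idx];
}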

By 秋风写淄博. Business inquiries and technical exchange: QQ 375172665
