一、加载 TFLite 模型
const std::string model_path = "./best_saved_model/best_float16.tflite";
std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
if (!model) {
qCritical() << "加载TFLite模型失败: " << QString::fromStdString(model_path);
return -1;
}
二、构建解释器
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
qCritical() << "构建解释器失败。";
return -1;
}
三、为模型的输入/输出张量分配内存
interpreter->AllocateTensors();
四、获取所有输入张量的索引列表
std::vector<int>& input_tensors_indices = interpreter->inputs();
五、获取所有输入张量的总数:
interpreter->input_tensors_indices.size();(上面定义的vector)
六、加载图像并整理
std::string image_path = "1001.jpg";
cv::Mat image = cv::imread(image_path);
if (image.empty()) {
qCritical() << "加载图片失败: " << QString::fromStdString(image_path);
return -1;
}
cv::Mat image_rgb;
cv::cvtColor(image, image_rgb, cv::COLOR_BGR2RGB);
七、获取输入张量的维度
int current_input_tensor_idx=input_tensors_indices[i];//从vector数组中获取输入张量的索引号
TfLiteIntArray* input_dims = interpreter->tensor(current_input_tensor_idx)->dims;
八、获取输入张量的数据类型
TfLiteType input_type = interpreter->tensor(current_input_tensor_idx)->type;
九、获取输入张量的高和宽
// 假设输入张量的形状至少是 [batch, height, width, channels]
// 并且高度和宽度在索引 1 和 2
int input_height = input_dims->data[1];
int input_width = input_dims->data[2];
十、 为当前输入张量准备数据并复制
if (input_type == kTfLiteFloat32) {
cv::Mat float_input;
image_resized.convertTo(float_input, CV_32FC3, 1.0 / 255.0);
memcpy(interpreter->typed_input_tensor<float>(current_input_tensor_idx), float_input.data, input_height * input_width * 3 * sizeof(float));
} else if (input_type == kTfLiteUInt8) {
memcpy(interpreter->typed_input_tensor<uint8_t>(current_input_tensor_idx), image_resized.data, input_height * input_width * 3 * sizeof(uint8_t));
} else {
qCritical() << "不支持的输入张量类型,无法为第" << i << "个输入张量准备数据。";
return;
}
十一、 推理
if (interpreter->Invoke() != kTfLiteOk) {
qCritical() << "错误:模型推理失败。请检查模型、输入或TFLite版本兼容性。";
return -1;
}
qDebug() << "模型推理成功完成。";
--------------后处理(与单输入模型相同)-------------
十二、获取输出张量的索引:
// 假设只有一个输出张量,或者你只关心第一个输出
int output_tensor_idx = interpreter->outputs()[0];
十三、获取(创建)输出张量对象
TfLiteTensor* output_tensor = interpreter->tensor(output_tensor_idx);
if (!output_tensor) {
qCritical() << "错误:interpreter->tensor() 返回空指针。输出张量对象未成功获取。";
return -1;
}
十四、获取输出张量数据类型
TfLiteType output_type = output_tensor->type;
qDebug() << "输出张量实际数据类型 (TfLiteType): " << output_type << " (期望: kTfLiteFloat32)";
十五、获取输出张量指针
float* output_data_ptr = interpreter->typed_output_tensor<float>(output_tensor_idx);
if (!output_data_ptr) {
qCritical() << "错误:获取输出张量指针失败,指针为空。";
return -1;
}
qDebug() << "成功获取输出张量指针。";
十六、获取输出张量维度
TfLiteIntArray* output_dims = output_tensor->dims;
if (!output_dims) {
qCritical() << "错误:无法获取输出张量维度信息。";
return -1;
}
十七、复制输出数据到 std::vector
std::vector<float> predictions;
try {
predictions.assign(output_data_ptr, output_data_ptr + output_size);
qDebug() << "成功将输出数据复制到 std::vector。vector 大小: " << predictions.size();
} catch (const std::bad_alloc& e) {
qCritical() << "内存分配失败 (std::bad_alloc): " << e.what();
return -1;
} catch (const std::exception& e) {
qCritical() << "复制数据到 vector 时发生异常: " << e.what();
return -1;
} catch (...) {
qCritical() << "复制数据到 vector 时发生未知异常。";
return -1;
}
十八、检测结果输出
float conf_threshold = 0.5f;
std::vector<cv::Rect> boxes;
std::vector<float> confidences;
std::vector<int> class_ids;
int num_detections = output_dims->data[1];
int num_classes = output_dims->data[2] - 4;
int stride = output_dims->data[2];
qDebug() << "后处理参数 -> num_detections: " << num_detections
<< ", num_classes: " << num_classes
<< ", stride: " << stride;
if (num_classes != 7) {
qWarning() << "警告:模型报告的类别数与预期 (7) 不符。实际类别数: " << num_classes;
}
for (int i = 0; i < num_detections; ++i) {
const float* detection = &predictions[i * stride];
float x_center = detection[0];
float y_center = detection[1];
float width = detection[2];
float height = detection[3];
float max_score = 0.0f;
int class_id = -1;
for (int k = 0; k < num_classes; ++k) {
float score = detection[4 + k];
if (score > max_score) {
max_score = score;
class_id = k;
}
}
if (max_score > conf_threshold) {
int x1 = static_cast<int>((x_center - width / 2) * image.cols);
int y1 = static_cast<int>((y_center - height / 2) * image.rows);
int x2 = static_cast<int>((x_center + width / 2) * image.cols);
int y2 = static_cast<int>((y_center + height / 2) * image.rows);
boxes.push_back(cv::Rect(x1, y1, x2 - x1, y2 - y1));
confidences.push_back(max_score);
class_ids.push_back(class_id);
}
}
const std::vector<std::string> class_names = {"bus", "car", "human", "motobike", "rect", "square", "truck"};
float max_overall_confidence = 0.0f;
int best_detection_idx = -1;
for (int i = 0; i < confidences.size(); ++i) {
int current_class_id = class_ids[i];
if (current_class_id >= 0 && current_class_id < class_names.size()) {
if (confidences[i] > max_overall_confidence) {
max_overall_confidence = confidences[i];
best_detection_idx = i;
}
} else {
qWarning() << "发现无效的类别ID,将跳过此检测框: " << current_class_id;
}
}
if (best_detection_idx != -1) {
int idx = best_detection_idx;
cv::Rect box = boxes[idx];
float confidence = confidences[idx];
int class_id = class_ids[idx];
cv::rectangle(image, box, cv::Scalar(0, 255, 0), 2);
std::string label = class_names[class_id] + ": " + std::to_string(confidence * 100).substr(0, 5) + "%";
int baseline = 0;
cv::Size label_size = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 2, &baseline);
cv::rectangle(image, cv::Point(box.x, box.y - label_size.height - 10),
cv::Point(box.x + label_size.width, box.y), cv::Scalar(0, 255, 0), cv::FILLED);
cv::putText(image, label, cv::Point(box.x, box.y - 5), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0), 2);
qDebug() << "最高置信度检测结果 -> 类别: " << QString::fromStdString(class_names[class_id])
<< ", 置信度: " << confidence
<< ", 框: (" << box.x << ", " << box.y << ", " << box.width << ", " << box.height << ")";
} else {
qDebug() << "未找到任何高于阈值的有效类别检测结果。";
}
std::string output_image_path = "detection_result.jpg";
bool success = cv::imwrite(output_image_path, image);
if (success) {
qDebug() << "检测结果已成功保存到: " << QString::fromStdString(output_image_path);
} else {
qCritical() << "错误:保存检测结果图像失败。请检查路径或文件权限。";
}
秋风写淄博,业务联系与技术交流:Q375172665