tf.expand_dims() Usage Explained

This article explains how to use TensorFlow's tf.expand_dims to insert a new dimension into the NumPy array current at different positions, with an example showing the effect of each expansion axis.


Passing axis = x inserts a new dimension of size 1 at position x. The example below demonstrates this directly.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x style (uses tf.Session)

# A 2-D array with shape (6, 5)
current = np.array([
    [0, 7, 1, 2, 2],
    [1, 7, 3, 4, 3],
    [2, 7, 5, 6, 6],
    [3, 7, 7, 8, 7],
    [4, 7, 7, 8, 7],
    [5, 7, 7, 8, 7]
])

current = tf.constant(current)

# Insert a size-1 axis at positions 0, 1 and 2
points_e0 = tf.expand_dims(current, axis=0)  # shape (1, 6, 5)
points_e1 = tf.expand_dims(current, axis=1)  # shape (6, 1, 5)
points_e2 = tf.expand_dims(current, axis=2)  # shape (6, 5, 1)

with tf.Session() as sess:
    print(points_e0.eval())
    print(points_e1.eval())
    print(points_e2.eval())

Output (shapes (1, 6, 5), (6, 1, 5), and (6, 5, 1), respectively):

[[[0 7 1 2 2]
  [1 7 3 4 3]
  [2 7 5 6 6]
  [3 7 7 8 7]
  [4 7 7 8 7]
  [5 7 7 8 7]]]
[[[0 7 1 2 2]]
 [[1 7 3 4 3]]
 [[2 7 5 6 6]]
 [[3 7 7 8 7]]
 [[4 7 7 8 7]]
 [[5 7 7 8 7]]]
[[[0]
  [7]
  [1]
  [2]
  [2]]
 [[1]
  [7]
  [3]
  [4]
  [3]]
 [[2]
  [7]
  [5]
  [6]
  [6]]
 [[3]
  [7]
  [7]
  [8]
  [7]]
 [[4]
  [7]
  [7]
  [8]
  [7]]
 [[5]
  [7]
  [7]
  [8]
  [7]]]
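For reference, here is a minimal sketch of the same expansions written for TensorFlow 2.x, where eager execution removes the need for a Session; this assumes TensorFlow 2.x is installed and only prints the resulting shapes. np.expand_dims behaves the same way on plain NumPy arrays.

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

current = tf.constant(np.arange(30).reshape(6, 5))  # any tensor of shape (6, 5)

# tf.expand_dims inserts a size-1 axis at the given position
print(tf.expand_dims(current, axis=0).shape)   # (1, 6, 5)
print(tf.expand_dims(current, axis=1).shape)   # (6, 1, 5)
print(tf.expand_dims(current, axis=2).shape)   # (6, 5, 1)
print(tf.expand_dims(current, axis=-1).shape)  # (6, 5, 1); negative axes count from the end

# The NumPy equivalent
print(np.expand_dims(np.zeros((6, 5)), axis=1).shape)  # (6, 1, 5)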
