matlab 转换 tfrecord,训练数据集与TFRecord互相转换的两种方式

这篇博客介绍了如何使用TensorFlow将图像数据和标签信息转换为TFRecord格式,以便于数据管理和高效训练。首先定义了用于创建TFRecord文件的辅助函数,然后通过读取label.txt文件,结合图像路径和目标信息,利用tf.gfile.FastGFile或PIL库解析图像并获取其尺寸信息。最后,将这些信息写入TFRecord文件。在读取TFRecord文件时,使用tf.image.decode_jpeg或tf.decode_raw解码图像,并进行相应处理。整个过程涉及到图像的解码、尺寸获取以及数据结构的序列化和反序列化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow使用TFRecord格式来统一存储数据,该格式可以将图像数据、标签信息、图像路径以及宽高等不同类型的信息放在一起进行统一存储,从而方便有效的管理不同的属性。

将训练数据集转成TFRecord

这里采用的数据集为目前正在做的项目的数据集,共包含两个目标文件夹(分别包含100幅图像)及对应的label.txt,label文件中的每一条内容分别对应两个文件夹中的一幅图像的路径及目标物的位置信息,即左上顶点和右下顶点的坐标信息(),接下来我们将上面的数据制作成TFRecord文件,由于后续需要验证制作的TFRecord数据是否正确,而每张图像的尺寸并不一致,因此在生成的TFRecord文件中除了包含图像内容和标签信息,还包括了图像的宽、高及通道的信息,这样在解析图像的时候,才能把图像数据重新reshape成图像。

根据读取图像数据方式的不同,共有两种方式将自己的数据集转换成TFRecord格式,同样对应两种方式对TFRecord格式进行解析。具体代码如下:# Convert own_data  to TFRecord of TF-Example protos.import tensorflow as tffrom PIL import Imageimport numpy as npimport os# 生成整数型的属性def int64_feature(values):

return tf.train.Feature(int64_list=tf.train.Int64List(value=values))# 生成浮点型的属性def float_feature(values):

return tf.train.Feature(float_list=tf.train.FloatList(value=values))# 生成字符串型的属性def bytes_feature(values):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))# 标签信息的地址dataset_dir = "/Users/**/**/label.txt"# 图像存放的根目录地址root_dir = '"/Users/**/**/'# 输出TFRecord文件的地址output_filename = "/Users/**/**/output.tfrecord"file_lines = open(dataset_dir).readlines()# 创建一个writer来写TFRecord文件writer = tf.python_io.TFRecordWriter(output_filename)# 统计有效数据valid_record_count = 0# 从label.txt循环读入要写入的数据信息for idx, line in enumerate(file_lines):

line = line.strip('\n')

image_target_path = line.split(",")[0]

image_search_path = line.split(",")[1]

image_labels_str = line.split(",")[2:]

image_format = str(image_target_path.split('.')[-1]).lower()

image_target_path = os.path.join(root_dir, image_target_path)

image_search_path = os.path.join(root_dir, image_search_path)    # 使用tf.gfile.FastGFile读取图像的原始数据,method_1

image_target_data = tf.gfile.FastGFile(image_target_path, 'r').read()

image_search_data = tf.gfile.FastGFile(image_search_path, 'r').read()    # 使用tf.image.decode_jpeg对图像进行解码,并利用img.eval().shape获得图像的宽高和通道信息

T_height, T_width, channels = tf.image.decode_jpeg(image_target_data).eval().shape

S_height, S_width, channels = tf.image.decode_jpeg(image_search_data).eval().shape    # 使用PIL的Image.open读取图像,method_2

image_target = Image.open(image_target_path, 'r')

image_target_data = image_target.tobytes()

T_height, T_width = image_target.size

image_search = Image.open(image_search_path, 'r')

image_search_data = image_search.tobytes()

S_height, S_width = image_search.size

image_labels = [float(x) for x in image_labels_str]    if not len(image_labels) == 4:

print("invalid label: " + line)        continue

# 将一个样例转化为Example Protocol Buffer,并将所有信息写入数据结构

example = tf.train.Example(features=tf.train.Features(feature={        'image_target/encoded': bytes_feature(image_target_data),        'image_search/encoded': bytes_feature(image_search_data),        'image_target/format': bytes_feature(image_format),        'image_search/format': bytes_feature(image_format),        'image/class/label': float_feature(image_labels),        'image_target/height': int64_feature(T_height),        'image_target/width': int64_feature(T_width),        'image_search/height': int64_feature(S_height),        'image_search/width': int64_feature(S_width),        'image/channels': int64_feature(channels),        'image_target/path': bytes_feature(image_target_path),        'image_search/path': bytes_feature(image_search_path) }))

# 将一个Example写入TFRecord文件

writer.write(example.SerializeToString())

valid_record_count += 1writer.close()

print("\nvalid image count: " + str(valid_record_count))

读取TFRecord文件,具体代码如下:# 使用 tf.image.decode_jpeg对jpg格式图像进行解码,对应tf.gfile读取图像,method_1image_target = tf.image.decode_jpeg(features['image_target/encoded'])# 使用tf.decode_raw将字符串解析成图像对应的像素数组,对应Image.open读取图像,method_2image_target = tf.decode_raw(features['image_target/encoded'], tf.uint8)

label = features['image/class/label']

T_height = tf.cast(features['image_target/height'], tf.int32)

T_width = tf.cast(features['image_target/width'], tf.int32)

channels = tf.cast(features['image/channels'], tf.int32)

image_target_path = features['image_target/path']

sess = tf.Session()

coord = tf.train.Coordinator()

threads = tf.train.start_queue_runners(sess=sess,coord=coord)# 每次运行可以读取TFRecord文件中的一个样例for i in range(100):

image_t, label_info,t_height, t_width, channnel, path = sess.run([image_target,label,T_height, T_width,channels,image_target_path])

image_name = path.split("/")[-1].split(".")[0]

sample = sess.run(tf.reshape(image_t, [t_height, t_width, channnel]))

image= Image.fromarray(sample,'RGB')    # 以图像名称_label信息对图像命名,并进行存储

image.save(decode_path+ image_name+'_'+ str(label_info[0])+'.jpg')

作者:我是笨徒弟

链接:https://www.jianshu.com/p/9448f71e9641

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值