- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Tensorflow 2.10.0
(二)具体步骤
from absl.logging import warning
import tensorflow as tf
from tensorflow.python.data import AUTOTUNE
from utils import GPU_ON
import matplotlib.pyplot as plt
# 目标:主要学习数据增强的方式方法
# 第一步:准备环境
# 查询tensorflow版本
print("Tensorflow Version:", tf.__version__)
# print(tf.config.experimental.list_physical_devices('GPU'))
# 设置使用GPU
gpus = tf.config.list_physical_devices("GPU")
print(gpus)
if gpus:
gpu0 = gpus[0] # 如果有多个GPU,仅使用第0个GPU
tf.config.experimental.set_memory_growth(gpu0, True) # 设置GPU显存按需使用
tf.config.set_visible_devices([gpu0], "GPU")>)
# ##########output#############################################
# Tensorflow Version: 2.10.0# [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
# [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
# ##########end output##########################################
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
import os, PIL, pathlib
# 隐藏警告
import warnings
warnings.filterwarnings('ignore')
# 第二步:导入数据
data_dir = "./datasets/365-7-data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)
# ########output##############################################
# 图片总数为: 3400# ########end output##########################################
# 第三步:数据预处理
batch_size = 8
img_height, img_width = 224, 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size,
)
# ############output##########################################
# Found 3400 files belonging to 2 classes.