import os
# os.environ['DEVICE_ID'] = '6'
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import dataset
from mindspore.train.callback import LossMonitor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from PIL import Image
# Run MindSpore in graph mode with Ascend as the target device.
# To try this locally, change device_target to "CPU".
context.set_context(
    device_target="Ascend",
    mode=context.GRAPH_MODE,
)
# Resolution filter: the loader keeps an image only if its pixel array has
# shape (targetWidth, targetHeight, targetChannal).
targetWidth, targetHeight, targetChannal = 426, 640, 3

# Dataset root; every class has its own sub-directory named after the class.
rootDir = 'animal'
fileNameList = ['cat', 'elephant', 'sheep']
# Class name -> integer label, numbered in the order the classes are listed above.
label_map = {name: label for label, name in enumerate(fileNameList)}
# Load every image under rootDir/<class name>/ and keep only those whose pixel
# array matches the target resolution. X collects the flattened pixel arrays and
# Y the matching integer class labels taken from label_map.
X, Y = [], []
for fileName in fileNameList:
    fileDir = os.path.join(rootDir, fileName)
    # Sorted so the sample order does not depend on the filesystem's listing order.
    imgNameList = sorted(os.listdir(fileDir))
    for imgName in imgNameList:
        imgDir = os.path.join(fileDir, imgName)
        # The with-block closes the file handle PIL keeps open; np.array() reads
        # the pixel data before the file is closed.
        with Image.open(imgDir) as pil_img:
            img = np.array(pil_img)
        if img.ndim != 3:
            continue  # 2-D arrays (e.g. grayscale images) can never match the filter
        # An image array's shape is (rows, cols, channels): despite the names,
        # `width` holds the row count (image height) and `height` the column
        # count, so the filter keeps arrays of shape
        # (targetWidth, targetHeight, targetChannal).
        width, height, channal = img.shape
        if (width, height, channal) == (targetWidth, targetHeight, targetChannal):
            X.append(img.flatten())
            Y.append(label_map[fileName])  # class label
# Randomly split the samples: 80% go to the training set, the rest to the test
# set. No random seed is set here, so the split differs between runs.
sampleNum = len(X)
train_idx = np.random.choice(sampleNum, int(sampleNum * 0.8), replace=False)
# Boolean mask over all samples gives O(1) membership per sample instead of a
# linear scan of train_idx for every index.
is_train = np.zeros(sampleNum, dtype=bool)
is_train[train_idx] = True
test_idx = np.flatnonzero(~is_train)  # remaining sample indices, ascending
# Both splits keep the original sample order; images are cast to float32,
# labels are left unchanged.
X_train = [x.astype(np.float32) for x, keep in zip(X, is_train) if keep]
Y_train = [y for y, keep in zip(Y, is_train) if keep]
X_test = [x.astype(np.float32) for x, keep in zip(X, is_train) if not keep]
Y_test = [y for y, keep in zip(Y, is_train) if not keep]
# Pair each training image with its label and expose the pairs as a
# GeneratorDataset with columns 'x' (image) and 'y' (label).
XY_train = list(zip(X_train, Y_train))
ds_train = dataset.GeneratorDataset(XY_train, ['x', 'y'])
# Shuffle using a buffer of sampleNum entries (at least the training-set size),
# then group into batches of 32, discarding any final partial batch.
ds_train = ds_train.shuffle(buffer_size=sampleNum)
ds_train = ds_train.batch(32, drop_remainder=True)