一个项目文件,首先要包括
data_loader.py
数据集类文件
class SYSUData(data.Dataset):# class数据集名称(data是从torch.utils中导入的模块)
#该类下一般有三个函数,初始化函数__init__、生成函数getitem和长度获取函数len;
def init(self, data_dir,):
1.该函数中一般是用来获取图像和标签的地址;
地址的语法:
./当前目录
…/上级目录
Windows | ubuntu | |
---|---|---|
相对地址 | …/Datasets/SYSU-MM01/ | …/Datasets/SYSU-MM01/ |
绝对地址 | “G:\xsf\ast\function_representation_learning-master\” | ‘/home/jdk1.8.0_65/bin/java’ |
np文件操作相关语法:
变量名 = np.load(文件夹地址 + 文件名),例如:
data_dir = '../Datasets/SYSU-MM01/'
self.train_label = np.load(data_dir + 'train _label.npy')
2.常用语法
Self.名字=名字(实例化类时传入的变量)
def __getitem__(self, index):
img1 = self.transform(img1)
img2 = self.transform(img2)
return img1, img2, target1, target2
#python类中的__getitem__方法,可以将实例化的类,实现字典dict的键值形式。实例对象的key不管是否存在都会调用类中的__getitem__()方法。而且返回值就是__getitem__()方法中规定的return值。语法: 实例化名称【字典中的键值】
def len(self):
return len(self.train_color_label)
Train.py
1、 添加参数
一般常用参数有:数据集名称和路径、学习率、参数优化方式、backbone、模型路径、训练日志路径、输入图像的宽和高、batchsize、gpu选择等,在输入运行命令时可以改变
2、 OS文件操作相关:
os.path.makedirs 创建文件夹
os.path.isdir 判断是否存在该文件夹
3、 确保训练生成的日志的路径正确,以及日志打印是的格式。
4、 打印信息
```handlebars
```handlebars
print("==========\nArgs:{}\n==========".format(args))
print('==> Loading data..')
print('Dataset 名字 statistics:)
print(' subset | # ids | # images')
print(' ------------------------------')
print(' visible | 类别数 | 训练集标签长度)
print(' query | {
:5d} | {
:8d}'.format(len(np.unique(query_label)), nquery))#搜寻标签
print(' gallery | {
:5d} | {
:8d}'.format(len(np.unique(gall_label)), ngall))#图库标签
print(' ------------------------------')
print('Data Loading Time:\t {
:.3f}'.format(time.time() - end))
print('==> Building model..')
print('Epoch: [{
}][{
}/{
}] '
'Time: {
batch_time.val:.3f} ({
batch_time.avg:.3f}) '
'lr:{
:.3f} '
'Loss: {
train_loss.val:.4f}