1、源码准备
git clone https://github.com/jfzhang95/pytorch-video-recognition
下载完成后,即可得到__C3D-pytorch__版本的源码
2、源码结构
文件名称 | 功能作用 |
---|---|
dataset.py | 数据准备 |
C3D_model.py | 模型网络结构构建 |
train.py | 网络模型训练 |
inference.py | 数据预处理 |
mypath.py | 数据模型路径自定义 |
3、源码分析
代码可主要分为数据,模型和训练三个部分,接下来我们将通过这个顺序来一一讲解。
3.1 数据
dataset.py的主要功能是对数据集的读取和处理
它首先定义了一个VideoDataset类,用来处理视频数据。该类通过重写内部的__init__和__getitem__函数来实现对数据的读取与自定义处理。
接下来详细介绍一下各个函数
3.1.1 数据初始化部分
1)__init__函数
__init__函数主要功能是对类VideoDataset进行初始化,设置参数并且对数据进行了一些操作。
def __init__(self, dataset='ucf101', split='train', clip_len=16, preprocess=False):
self.root_dir, self.output_dir = Path.db_dir(dataset)
folder = os.path.join(self.output_dir, split)
self.clip_len = clip_len
self.split = split
self.resize_height = 128
self.resize_width = 171
self.crop_size = 112
if not self.check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You need to download it from official website.')
if (not self.check_preprocess(