TSM(Temporal Shift Module)源码解析
论文名:TSM: Temporal Shift Module for Efficient Video Understanding
代码链接:https://github.com/mit-han-lab/temporal-shift-module
代码的主要结构如下:
python file | function |
---|---|
mian.py | 主要的训练函数 |
opts.py | 代码的参数配置 |
ops/dataset.py | 数据集的载入部分,核心是__getitem__函数。 |
ops/dataset_config.py | 用于配置不同的数据集 |
ops/models.py | 组装模型 |
ops/temporal_shift.py | 核心的temporal shift操作 |
1.opts.py是参数配置。除了路径、超参外,还有几个参数要注意(不同于TSN的点):
parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
– modality 表示输入的类型,RGB/FLOW
–num_segments 表示每个视频采样frames数,一般是8或者16。
–shift 表示是否加入tsm模块。
–shift_div 表示shift的特征的比例,一般是8。表示2*1/8比例的特征会移动,其中1/8的特征做shift left, 另1/8的特征做shift right。
2.dataset_config.py是数据级的配置。
每个数据集实现一个return_xxx(modality)
返回数据集支持的子类名,train_list路径,val_list路径,数据集的根路径等信息。
3.dataset.py 是数据集的载入部分。
dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。
dataset.py中实现了TSNDataSet类,处理原始数据,此类继承于torch.utils.data.dataset类。
其中首先定义了一个简单类:
3.1 VideoRecord,用于封装一个视频内容,包括图片的路径,frames的数量,标签信息。
class VideoRecord(object):
def __init__(self, row):
self._data = row
@property
def path(self):
return self._data[0]
@property
def num_frames(self):
return int(self._data[1])
@property
def label(self):
return int(self._data[2])
3.2 TSNDataSet:
class TSNDataSet(data.Dataset):
def __init__(self, root_path, list_file,
num_segments=3, new_length=1, modality='RGB',
image_tmpl='img_{:05d}.jpg', transform=None,
random_shift=True, test_mode=False,
remove_missing=False, dense_sample=False, twice_sample=False):
self.root_path = root_path
self.list_file = list_file
self.num_segments = num_segments
self.new_length = new_length
self.modality = modality
self.image_tmpl = image_tmpl
self.transform = transform
self.random_shift = random_shift
self.test_mode = test_mode
self.remove_missing = remove_missing
self.dense_sample = dense_sample # using dense sample as I3D
self.twice_sample = twice_sample # twice sample for more validation
if self.dense_sample:
print('=> Using dense sample for the dataset...')
if self.twice_sample:
print('=> Using twice sample for the dataset...')
if self.modality == 'RGBDiff':
self.new_length += 1 # Diff needs one more image to calculate diff
self._parse_list()
设置一些参数和参数默认值之后,调用了_parse_list()函数,tmq是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。然后调用VideoRecord(