TSM(Temporal Shift Module)源码解析

TSM(Temporal Shift Module)源码解析

论文名:TSM: Temporal Shift Module for Efficient Video Understanding

代码链接:https://github.com/mit-han-lab/temporal-shift-module

代码的主要结构如下:

python file function
mian.py 主要的训练函数
opts.py 代码的参数配置
ops/dataset.py 数据集的载入部分,核心是__getitem__函数。
ops/dataset_config.py 用于配置不同的数据集
ops/models.py 组装模型
ops/temporal_shift.py 核心的temporal shift操作

1.opts.py是参数配置。除了路径、超参外,还有几个参数要注意(不同于TSN的点):

parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')

– modality 表示输入的类型,RGB/FLOW

–num_segments 表示每个视频采样frames数,一般是8或者16。

–shift 表示是否加入tsm模块。

–shift_div 表示shift的特征的比例,一般是8。表示2*1/8比例的特征会移动,其中1/8的特征做shift left, 另1/8的特征做shift right。

2.dataset_config.py是数据级的配置。

每个数据集实现一个return_xxx(modality)

返回数据集支持的子类名,train_list路径,val_list路径,数据集的根路径等信息。

3.dataset.py 是数据集的载入部分。

dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。

dataset.py中实现了TSNDataSet类,处理原始数据,此类继承于torch.utils.data.dataset类。

其中首先定义了一个简单类:

3.1 VideoRecord,用于封装一个视频内容,包括图片的路径,frames的数量,标签信息。
class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])

    @property
    def label(self):
        return int(self._data[2])
3.2 TSNDataSet:
class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')

        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff

        self._parse_list()

设置一些参数和参数默认值之后,调用了_parse_list()函数,tmq是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。然后调用VideoRecord(

### 实现Temporal Shift Module (TSM) Temporal Shift Module (TSM) 是一种高效的视频理解技术,通过在卷积神经网络中引入时间维度的信息来提升模型性能[^1]。以下是有关 TSM 的实现方法及其核心概念: #### 1. **TSM的核心思想** TSM 利用通道级别的信息共享机制,在不增加额外参数的情况下捕获时间依赖关系。其主要特点是通过对输入特征图进行切片操作,将前一帧的部分通道移动到当前帧,后一帧的部分通道也移入当前帧,从而模拟时间上的信息传递[^2]。 #### 2. **PyTorch中的TSM实现** 以下是一个简单的 PyTorch 实现示例,展示了如何定义和应用 Temporal Shift 操作: ```python import torch import torch.nn as nn class TemporalShift(nn.Module): def __init__(self, net, n_segment=3, n_div=8): super(TemporalShift, self).__init__() self.net = net self.n_segment = n_segment self.fold_div = n_div def forward(self, x): nt, c, h, w = x.size() n_batch = nt // self.n_segment x = x.view(n_batch, self.n_segment, c, h, w) fold = c // self.fold_div out = torch.zeros_like(x) out[:, :-1, :fold] = x[:, 1:, :fold] # shift left out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # no shift return out.view(nt, c, h, w).contiguous() # 使用示例 base_model = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2)) tsm_layer = TemporalShift(base_model, n_segment=3, n_div=8) input_tensor = torch.randn((9, 3, 224, 224)) # 假设有3个片段 output = tsm_layer(input_tensor) print(output.shape) ``` 上述代码实现了基本的时间位移模块 `TemporalShift`,其中: - 参数 `n_segment` 表示每批数据中有多少个连续帧。 - 参数 `n_div` 控制每个通道被分割的比例,用于决定哪些通道参与左移或右移操作。 #### 3. **训练流程概述** 为了完成完整的 TSM 训练过程,可以参考以下步骤(基于 UCF101 数据集为例): - 准备好视频剪辑的数据加载器,确保每次迭代返回多个连续帧组成的批量。 - 将基础 CNN 模型(如 ResNet 或 MobileNet)替换为其对应的 TSM 版本。 - 定义损失函数并优化权重更新逻辑。 具体可参见飞桨 AI Studio 提供的相关教程链接,这些资源详细描述了从零开始搭建整个工作流的过程。 #### 4. **注意事项** 当实际部署时需要注意几个方面: - 输入张量形状应满足 `(batch_size * num_segments, channels, height, width)` 格式以便于后续处理。 - 如果 GPU 显存不足,则可以通过减少批次大小或者调整分辨率解决内存占用过高问题。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值