Reproducing MobileNetV3 in PyTorch
### How to Implement the MobileNetV3 Architecture in PyTorch
To build MobileNetV3 in PyTorch, you can define a series of modules representing the different parts of the network. Below is a simplified `MobileNetV3` implementation based on existing research and open-source projects[^1].
#### Defining the Basic Components
First, define the necessary building blocks: convolution layers, batch normalization, and the h-swish and ReLU activation functions:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Hswish(nn.Module):
    """h-swish activation: x * ReLU6(x + 3) / 6."""
    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3., inplace=self.inplace) / 6.


def conv_3x3_bn(inp, oup, stride):
    """3x3 convolution + batch norm + h-swish, used as the network stem."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        Hswish()
    )
```
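As a quick sanity check, the stem built from `conv_3x3_bn` with stride 2 should halve the spatial resolution. The snippet below is a minimal sketch; the 224×224 input size is only an assumption:
```python
# Hypothetical check: a stride-2 stem halves the spatial dimensions.
stem = conv_3x3_bn(3, 16, 2)
x = torch.randn(1, 3, 224, 224)   # assumed ImageNet-style input
print(stem(x).shape)              # torch.Size([1, 16, 112, 112])
```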
#### Building the Bottleneck
Next, create a class that generates the bottleneck unit (bneck) used within each stage; this is one of the core components of MobileNetV3:
```python
class Bottleneck(nn.Module):
    def __init__(self, inp, oup, kernel, exp_size, se, nl, s):
        super(Bottleneck, self).__init__()
        assert s in [1, 2]
        hidden_dim = exp_size

        # non-linearity used inside the block: ReLU ("RE") or h-swish ("HS")
        def activation():
            return nn.ReLU(inplace=True) if nl == "RE" else Hswish()

        layers = []
        # pw: 1x1 expansion convolution (skipped when there is no expansion)
        if inp != hidden_dim:
            layers.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                activation(),
            ])
        layers.extend([
            # dw: depthwise convolution
            nn.Conv2d(hidden_dim, hidden_dim, kernel, s, kernel // 2,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            activation(),
            # SE: optional squeeze-and-excitation
            SELayer(hidden_dim) if se else nn.Identity(),
            # pw-linear: 1x1 projection with no activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)
        # residual connection only when stride is 1 and channels match
        self.use_res_connect = s == 1 and inp == oup

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out
```
This also involves the design of the Squeeze-and-Excitation (SE) layer, which helps improve model performance[^4]:
```python
class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Hardsigmoid()  # h-sigmoid gate; a plain sigmoid also works
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # squeeze: global average pooling to one value per channel
        y = self.avg_pool(x).view(b, c)
        # excite: per-channel weights in [0, 1], then rescale the input
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
```
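With `SELayer` in place, the `Bottleneck` above can be exercised end to end. The following is a minimal sketch (the channel counts and the 56×56 feature-map size are arbitrary assumptions) showing that a stride-1 block with matching input and output channels keeps the residual connection and preserves the input shape:
```python
# Hypothetical check of an SE-equipped bneck with a residual connection.
block = Bottleneck(inp=16, oup=16, kernel=3, exp_size=64, se=True, nl="HS", s=1)
x = torch.randn(1, 16, 56, 56)    # assumed feature-map size
print(block.use_res_connect)      # True: stride 1 and inp == oup
print(block(x).shape)             # torch.Size([1, 16, 56, 56])
```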
Finally, combine all the components above to assemble the full MobileNetV3 architecture, and add support for loading pretrained weights so the model can be used right away[^2]:
```python
cfgs_large = [
    # k, exp_size, c, SE, NL, s
    [3, 16, 16, False, 'RE', 1],
    [3, 64, 24, False, 'RE', 2],
    ...
]


def make_layers(cfgs, input_channel=16):
    # stem: 3x3 convolution with stride 2
    layers = [conv_3x3_bn(3, input_channel, 2)]
    for k, exp_size, c, use_se, nl, s in cfgs:
        output_channel = c
        layers.append(Bottleneck(input_channel, output_channel, k, exp_size, use_se, nl, s))
        input_channel = output_channel
    return nn.Sequential(*layers)


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Large, self).__init__()
        self.features = make_layers(cfgs_large)
        last_channel = cfgs_large[-1][2]  # output channels of the final bneck
        self.classifier = nn.Sequential(
            nn.Linear(last_channel, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])  # global average pooling
        x = self.classifier(x)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileNetV3_Large().to(device)
pretrained_dict = torch.load('path_to_pretrained_weights.pth', map_location=device)
model.load_state_dict(pretrained_dict)
```
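If the checkpoint was saved from a slightly different variant of the network (for example, one with the extra 1x1 head convolutions of the official model), the parameter names will not match exactly. A common workaround is to load only the overlapping keys with `strict=False`; the weight file name below is just a placeholder, as above:
```python
# Ignore keys that this simplified model does not define (and vice versa).
state_dict = torch.load('path_to_pretrained_weights.pth', map_location=device)
result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)
```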
This is a simple example of implementing the large variant of MobileNetV3 in PyTorch. In practice, additional details and optimizations may need to be considered.
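For completeness, here is a minimal inference sketch; it assumes the elided entries of `cfgs_large` have been filled in and that the weights above loaded successfully:
```python
# Run a single forward pass on a dummy input and pick the top class.
model.eval()
with torch.no_grad():
    img = torch.randn(1, 3, 224, 224).to(device)  # dummy ImageNet-sized input
    probs = torch.softmax(model(img), dim=1)
print(probs.argmax(dim=1))  # predicted class index
```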