import torch
import torch.nn as nn
# NOTE(review): the original "import vega" was unused here and removed; restore
# it if the vega package performs required import-time registration.


class SimpleCnn(nn.Module):
    """Simple CNN network.

    Architecture found by NAS: two conv stems (b1, b2) followed by
    `self.blocks` identical conv-bn-relu blocks and a final linear classifier.
    """

    def __init__(self):
        """Initialize layers with the NAS-found hyperparameters."""
        super(SimpleCnn, self).__init__()
        self.num_class = 11
        # nas best hyperparameter
        self.blocks = 3
        self.channel1 = 32
        self.channel2 = 64
        self.b1 = nn.Sequential(
            nn.Conv2d(1, self.channel1, padding=(0, 4), kernel_size=(1, 9), stride=(1, 1)),
            nn.BatchNorm2d(self.channel1),
            nn.ReLU(inplace=True)
        )
        self.b2 = nn.Sequential(
            nn.Conv2d(self.channel1, self.channel2, padding=(0, 1), kernel_size=(2, 3), stride=(1, 2)),
            nn.BatchNorm2d(self.channel2),
            nn.ReLU(inplace=True)
        )
        # NOTE: self.blocks is deliberately reused -- the int block count above
        # is consumed by _blocks() and then replaced by the ModuleList it builds.
        self.blocks = self._blocks(self.channel2, self.channel2)
        self.fc = nn.Linear(64 * self.channel2, self.num_class)

    def _blocks(self, in_channels, out_channels):
        """Build `self.blocks` conv-bn-relu blocks and return them as a ModuleList."""
        blocks = nn.ModuleList([None] * self.blocks)
        for i in range(self.blocks):
            blocks[i] = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, padding=(0, 1), kernel_size=(1, 3)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            in_channels = out_channels
        return blocks

    def forward(self, x):
        """Forward: stems, repeated blocks, flatten, classify."""
        x = self.b1(x)
        x = self.b2(x)
        for block in self.blocks:
            x = block(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


if __name__ == "__main__":
    model = SimpleCnn()
    model.load_state_dict(torch.load('nas_results/model_12.pth'))
    model.eval()

# Question: convert the code above from PyTorch to MindSpore (answered below).
时间: 2023-11-18 19:05:21 浏览: 133
将这段代码从基于PyTorch改为基于MindSpore,需要进行以下修改:
1. 导入MindSpore相关库:
```
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
```
2. 将`nn.Module`改为`nn.Cell`（`super(SimpleCnn, self).__init__()`可简化为`super().__init__()`），并将`forward`方法重命名为`construct`（MindSpore的前向函数必须叫`construct`）：
```
class SimpleCnn(nn.Cell):
"""Simple CNN network."""
def __init__(self):
"""Initialize."""
super().__init__()
...
```
3. 将`nn.Sequential`改为`nn.SequentialCell`，去掉`inplace`参数；注意MindSpore的`nn.Conv2d`默认`pad_mode='same'`，要使用显式padding必须设置`pad_mode='pad'`，且padding为(上, 下, 左, 右)四元组：
```
self.b1 = nn.SequentialCell([
    nn.Conv2d(1, self.channel1, kernel_size=(1, 9), stride=(1, 1),
              pad_mode='pad', padding=(0, 0, 4, 4)),
    nn.BatchNorm2d(self.channel1),
    nn.ReLU(),
])
```
4. 将`nn.ModuleList`改为`nn.CellList`。注意`nn.CellList`不能用`None`占位，应先构建Python列表再包装：
```
def _blocks(self, in_channels, out_channels):
    cells = []
    for _ in range(self.blocks):
        cells.append(nn.SequentialCell([
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3),
                      pad_mode='pad', padding=(0, 0, 1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]))
        in_channels = out_channels
    return nn.CellList(cells)
```
5. 将`x.size(0)`改为`x.shape[0]`（MindSpore中`shape`是属性而不是方法，不能加括号调用）：
```
x = x.reshape(x.shape[0], -1)
```
6. MindSpore没有`Tensor.load_npz`这个接口；加载checkpoint应使用`mindspore.load_checkpoint`和`mindspore.load_param_into_net`：
```
param_dict = mindspore.load_checkpoint('nas_results/model_12.ckpt')
mindspore.load_param_into_net(model, param_dict)
```
推理前用`model.set_train(False)`代替PyTorch的`model.eval()`。
修改后的完整代码如下:
```
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
class SimpleCnn(nn.Cell):
    """Simple CNN network (MindSpore version).

    Two conv stems (b1, b2) followed by repeated conv-bn-relu blocks and a
    final Dense classifier.
    """

    def __init__(self):
        """Initialize layers with the NAS-found hyperparameters."""
        super().__init__()
        self.num_class = 11
        # nas best hyperparameter
        num_blocks = 3
        self.channel1 = 32
        self.channel2 = 64
        # MindSpore Conv2d defaults to pad_mode='same' and has_bias=False;
        # explicit padding needs pad_mode='pad' with a (top, bottom, left,
        # right) tuple, and has_bias=True matches PyTorch's Conv2d default.
        self.b1 = nn.SequentialCell([
            nn.Conv2d(1, self.channel1, kernel_size=(1, 9), stride=(1, 1),
                      pad_mode='pad', padding=(0, 0, 4, 4), has_bias=True),
            nn.BatchNorm2d(self.channel1),
            nn.ReLU(),
        ])
        self.b2 = nn.SequentialCell([
            nn.Conv2d(self.channel1, self.channel2, kernel_size=(2, 3), stride=(1, 2),
                      pad_mode='pad', padding=(0, 0, 1, 1), has_bias=True),
            nn.BatchNorm2d(self.channel2),
            nn.ReLU(),
        ])
        self.blocks = self._blocks(self.channel2, self.channel2, num_blocks)
        self.fc = nn.Dense(64 * self.channel2, self.num_class)

    def _blocks(self, in_channels, out_channels, num_blocks=3):
        """Build `num_blocks` conv-bn-relu cells and return them as a CellList.

        nn.CellList cannot hold None placeholders (unlike torch's ModuleList),
        so the cells are collected in a plain list first.
        """
        cells = []
        for _ in range(num_blocks):
            cells.append(nn.SequentialCell([
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3),
                          pad_mode='pad', padding=(0, 0, 1, 1), has_bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]))
            in_channels = out_channels
        return nn.CellList(cells)

    def construct(self, x):
        """Forward: stems, repeated blocks, flatten, classify."""
        x = self.b1(x)
        x = self.b2(x)
        for block in self.blocks:
            x = block(x)
        # Tensor.shape is a property in MindSpore (not a callable).
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x
if __name__ == "__main__":
    model = SimpleCnn()
    # MindSpore loads checkpoints via load_checkpoint/load_param_into_net;
    # Tensor.load_npz does not exist in the MindSpore API.
    param_dict = mindspore.load_checkpoint('nas_results/model_12.ckpt')
    mindspore.load_param_into_net(model, param_dict)
    # Inference mode: MindSpore uses set_train(False) instead of eval().
    model.set_train(False)
```
阅读全文