``` import torch import torch.nn as nn from skimage.segmentation import chan_vese import numpy as np from torch.nn import Conv2d from torchvision.ops import FeaturePyramidNetwork from torchvision.models import resnet50 import os os.environ['KMP_DUPLICATE_LIB_OK']='TRUE' def process_conv_output(conv_output): assert conv_output.shape == (64, 16, 3, 3), "Input tensor shape is not correct!" chunks = torch.chunk(conv_output, chunks=4, dim=0) for idx, chunk in enumerate(chunks): print(f"Chunk {idx} Shape: ", chunk.shape) assert chunk.shape == (16, 16, 3, 3) return chunks class ImageSegmentationModel(nn.Module): def __init__(self): super(ImageSegmentationModel,self).__init__() self.conv_layers = nn.Sequential( nn.Conv2d(1,128,kernel_size=3,stride=2), nn.MaxPool2d(kernel_size=3,stride=2), nn.ReLU(), nn.Conv2d(128,64, kernel_size=3, stride=2), nn.MaxPool2d(kernel_size=3, stride=2), nn.ReLU(), nn.Conv2d(64,32,kernel_size=3,stride=2), nn.MaxPool2d(kernel_size=3,stride=2), nn.ReLU(), nn.Conv2d(32,16,kernel_size=3,stride=2) ) self.resnet = resnet50(pretrained=True) self.initial_layer = nn.Sequential( self.resnet.conv1, # 输出通道64 self.resnet.bn1, self.resnet.relu, self.resnet.maxpool # 输出通道64 ) self.layer1 = self.resnet.layer1 self.layer2 = self.resnet.layer2 self.layer3 = self.resnet.layer3 self.layer4 = self.resnet.layer4 self.fpn = FeaturePyramidNetwork([256,512,1024,2048],256) self.conv_layers11 = nn.Conv2d(256,1,kernel_size=1,stride=1) self.final_conv = nn.Conv2d(16,21,kernel_size=1,stride=1) self.softmax = nn.Softmax(dim=1) def preprocess(self,x): x = torch.nn.functional.interpolate(x,size=(511,511),mode='bilinear',align_corners=False) x = torch.mean(x,dim=1,keepdim=True) x_np = x.detach().cpu().numpy() segmented = [] for i in range(x_np.shape[0]): img = x_np[i,0] #init =np.array([[img.shape[1]-1,0],[img.shape[1]-1,img.shape[0]-1],[0,img.shape[0]-1,],[0,0]]) snake = chan_vese(img,mu=0.25, lambda1=1.0, lambda2=1.0, tol=0.001, max_num_iter=500, dt=0.5) seg = np.zeros_like(img) from skimage.draw import polygon rr, cc = polygon(snake[:,1],snake[:,0],seg.shape) seg[rr, cc] = 1 segmented.append(seg) segmented = np.array(segmented) segmented = torch.from_numpy(segmented).unsqueeze(1).float().to(x.device) return segmented def forward(self,x): y = torch.nn.functional.interpolate(x,size=(511,511),mode='bilinear',align_corners=False) x = self.preprocess(x) conv_output = self.conv_layers(x) conv_output_tuple = process_conv_output(conv_output) print("conv_output:", conv_output.shape) z = self.initial_layer(y) c1 = self.layer1(z) # [batch,256,H,W] c2 = self.layer2(c1) # [batch,512,H,W] c3 = self.layer3(c2) # [batch,1024,H,W] c4 = self.layer4(c3) # [batch,2048,H,W] fpn_input = { 'feat1': c1, 'feat2': c2, 'feat3': c3, 'feat4': c4 } fpn_output = self.fpn(fpn_input) fpn_output_upsampled = torch.nn.functional.interpolate(fpn_output['feat1'], size=(511, 511), mode='bilinear', align_corners=False) final_output = nn.functional.conv2d(fpn_output_upsampled,conv_output_tuple,stride=1,padding=1,groups=16) final_output = self.final_conv(final_output) final_output = self.softmax(final_output) return final_output```改动不涉及fpn_output请重新修改
时间: 2025-06-27 18:05:48 浏览: 16
好的,我们来看一下如何对上述代码进行修改,并保证不影响 `fpn_output` 的部分。
---
### 修改目标说明
从您的描述可以看出,您希望保持 `fpn_output` 部分不变,同时对其他内容做出调整。以下是具体的分析:
1. **输入预处理**:保留 `preprocess()` 函数的功能,但可以优化其结构以提高效率。
2. **卷积层设计**:现有的 `conv_layers` 过于复杂且冗余,我们可以简化该模块的设计。
3. **最终输出计算**:将最后的特征融合逻辑(如 `final_conv` 和 `softmax`)重新设计,使其更清晰、合理。
下面是改进后的代码版本:
---
### 改进版代码
```python
import torch
import torch.nn as nn
from skimage.segmentation import chan_vese
import numpy as np
from torchvision.models import resnet50
class SimplifiedImageSegmentationModel(nn.Module):
def __init__(self):
super(SimplifiedImageSegmentationModel, self).__init__()
# 定义主干网络 ResNet-50 并提取其中的部分层
self.resnet = resnet50(pretrained=True)
self.backbone = nn.Sequential(
self.resnet.conv1, # 输入通道 -> 64
self.resnet.bn1,
self.resnet.relu,
self.resnet.maxpool, # 下采样到 1/4 原始尺寸
self.resnet.layer1, # 输出 [batch, 256, H/4, W/4]
self.resnet.layer2, # 输出 [batch, 512, H/8, W/8]
)
# 新增轻量级卷积分支
self.lightweight_conv = nn.Sequential(
nn.Conv2d(512, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 128, kernel_size=3, padding=1),
阅读全文
相关推荐

















