from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.models import resnet34
import medmnist
from medmnist import INFO, Evaluator
print(f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}")

# --- experiment configuration ----------------------------------------------
data_flag = 'dermamnist'
# data_flag = 'breastmnist'
download = True
NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

# Dataset metadata looked up from the MedMNIST INFO registry.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Prefer GPU 3 as before, but fall back to the default GPU (or the CPU) on
# machines with fewer GPUs instead of failing at the first .to(device).
if torch.cuda.device_count() > 3:
    device = 'cuda:3'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dataset class (e.g. medmnist.DermaMNIST) named by the registry entry.
DataClass = getattr(medmnist, info['python_class'])
# preprocessing: ToTensor scales pixels to [0, 1]; Normalize(.5, .5) then maps
# them to [-1, 1]. The one-element mean/std is broadcast over every channel.
transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5]),
]
# The torchvision ResNet-34 used in this script takes 3-channel input, so
# replicate the single channel of grayscale datasets (e.g. 'breastmnist')
# instead of failing in the first convolution. The DataLoaders below use the
# default num_workers=0, so this lambda never has to be pickled.
if n_channels == 1:
    transform_steps.append(transforms.Lambda(lambda img: img.repeat(3, 1, 1)))
data_transform = transforms.Compose(transform_steps)

# load the data (224x224 variant, memory-mapped, stored under DATA_ROOT)
DATA_ROOT = '/root/workspace/dataset/DermaMNIST_new'
train_dataset = DataClass(split='train', transform=data_transform, download=download,
                          size=224, mmap_mode='r', root=DATA_ROOT)
test_dataset = DataClass(split='test', transform=data_transform, download=download,
                         size=224, mmap_mode='r', root=DATA_ROOT)

# encapsulate data into dataloader form; the test loader must not shuffle so
# that predicted scores stay aligned with the evaluator's labels
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)

print(train_dataset)
print("===================")
print(test_dataset)
# model: an ImageNet-pretrained torchvision ResNet-34 (not a small custom CNN)
# whose final fully connected layer is swapped for a freshly initialised head
# with one output per class of the selected dataset.
backbone = resnet34(pretrained=True)
backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)
model = backbone.to(device)  # Module.to moves in place and returns the same module

# loss and optimizer: cross-entropy over class logits, SGD with momentum
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: network to train; it is switched to training mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss applied to ``(logits, class_indices)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        Mean of the per-batch losses, or 0.0 if ``loader`` yields nothing.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        # Turn labels into 1-D class indices for CrossEntropyLoss. Squeeze only
        # dim 1 (assumed to be the singleton label axis): a bare .squeeze()
        # would also drop the batch axis of a size-1 final batch and produce a
        # 0-d tensor the loss rejects.
        targets = targets.to(device).squeeze(1).long()

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader, flag, split, device, root, size=224):
    """Score every sample in ``loader`` and compute MedMNIST metrics for it.

    Args:
        model: trained network; it is switched to eval mode here.
        loader: loader over the ``split`` subset. It must not shuffle, because
            scores are matched to the evaluator's labels by position.
        flag: MedMNIST dataset flag, e.g. ``'dermamnist'``.
        split: subset name handed to ``Evaluator`` (e.g. ``'test'``).
        device: device the inputs are moved to.
        root: directory holding the dataset files; pass the dataset's own root
            so the evaluator reads labels from the same files.
        size: image-size variant handed to ``Evaluator``.

    Returns:
        The result of ``Evaluator.evaluate``; the caller unpacks it as
        ``(auc, acc)``.
    """
    model.eval()
    batch_scores = []
    with torch.no_grad():
        for inputs, _ in loader:
            logits = model(inputs.to(device))
            # Softmax probabilities, collected on the CPU for numpy conversion.
            batch_scores.append(logits.softmax(dim=-1).cpu())
    # Concatenate once at the end rather than re-concatenating every batch.
    y_score = torch.cat(batch_scores, dim=0).numpy()
    evaluator = Evaluator(flag, split, size=size, root=root)
    return evaluator.evaluate(y_score)


# train, evaluating on the test split after every epoch
for epoch in range(NUM_EPOCHS):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'epoch {epoch + 1}/{NUM_EPOCHS} train loss: {train_loss:.4f}')

    split = 'test'
    # The datasets were downloaded to a custom root; without root=... the
    # Evaluator would look for the label file in MedMNIST's default directory.
    metrics = _evaluate(model, test_loader, data_flag, split, device,
                        root=test_dataset.root, size=224)
    print('%s auc: %.3f acc: %.3f' % (split, *metrics))