import torch
import timm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from scipy.spatial.distance import cosine
# 1. Load a pretrained DeiT model to use as a feature extractor.
# Other sizes can be selected here, e.g. 'deit_small_patch16_224' or
# 'deit_tiny_patch16_224'.
# num_classes=0 makes timm build the network without a classification layer
# (the head becomes an identity), so the forward pass returns the pooled
# embedding instead of class logits. Unlike assigning `model.head` by hand,
# this also works when the model is swapped for an architecture whose
# classifier has another attribute name or that has extra heads (e.g. the
# distilled DeiT variants also carry a `head_dist`).
model = timm.create_model('deit_base_patch16_224', pretrained=True, num_classes=0)
model.eval()  # evaluation mode: inference behaviour for layers such as dropout
# 2. Image preprocessing pipeline applied to every input image.
# Standard ImageNet per-channel mean / standard deviation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # Resize to exactly 224x224 (aspect ratio is not preserved).
        transforms.Resize((224, 224)),
        # PIL image -> float tensor, channels first.
        transforms.ToTensor(),
        # Per-channel standardization with the statistics above.
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# 3. Load and preprocess an image.
def preprocess_image(image_path):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).

    Returns:
        The module-level ``transform`` applied to the RGB image, with a
        leading batch dimension of size 1 added by ``unsqueeze(0)``.
    """
    # The context manager closes the underlying file promptly. ``convert``
    # returns a new, fully loaded image, so ``rgb`` stays valid after the
    # ``with`` block has closed the file.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # force 3 channels (e.g. drop alpha, expand grayscale)
    return transform(rgb).unsqueeze(0)  # add the batch dimension
# 4. Feature extraction.
def extract_features(image_path):
    """Compute the feature vector of a single image.

    Args:
        image_path: Image location, forwarded to ``preprocess_image``.

    Returns:
        The module-level ``model``'s output for the image, with the batch
        dimension removed.
    """
    img_tensor = preprocess_image(image_path)
    # Inference only: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        features = model(img_tensor)
    # Remove only the batch axis. A bare ``squeeze()`` would also collapse
    # any other axis that happens to have size 1.
    return features.squeeze(0)
# 5. Cosine similarity.
def cosine_similarity(features1, features2):
    """Return the cosine similarity between two feature vectors.

    Computed as ``1 - scipy.spatial.distance.cosine(u, v)``: values closer
    to 1 mean the vectors are more similar (same direction).

    Args:
        features1: First feature tensor (scipy's ``cosine`` expects 1-D input).
        features2: Second feature tensor, same length as ``features1``.

    Returns:
        The similarity as a scalar float.
    """
    # detach()/cpu() let the NumPy conversion succeed for tensors that
    # require grad or live on a GPU; both are no-ops for plain CPU tensors.
    u = features1.detach().cpu().numpy()
    v = features2.detach().cpu().numpy()
    return 1 - cosine(u, v)
# 6. Euclidean distance.
def euclidean_distance(features1, features2):
    """Return the Euclidean (L2) distance between two feature tensors.

    Smaller values mean more similar features; identical inputs give 0.

    Args:
        features1: First feature tensor.
        features2: Second feature tensor, broadcast-compatible with ``features1``.

    Returns:
        The distance as a Python float.
    """
    # torch.linalg.vector_norm replaces the deprecated torch.norm. With the
    # defaults (ord=2, dim=None) it takes the 2-norm over all elements, which
    # matches the previous torch.norm(...) call.
    return torch.linalg.vector_norm(features1 - features2).item()
# 7. Load the two images to compare and extract their features.
image1_path, image2_path = '1.jpg', '2.jpg'
features1, features2 = map(extract_features, (image1_path, image2_path))

# 8. Compare the two feature vectors.
# Cosine: closer to 1 means more similar.
similarity_cosine = cosine_similarity(features1, features2)
# Note: despite the variable name this is a distance (smaller = more similar).
similarity_euclidean = euclidean_distance(features1, features2)

# 9. Report both metrics.
print(f"Cosine Similarity: {similarity_cosine:.4f}")
print(f"Euclidean Distance: {similarity_euclidean:.4f}")