import math
import os
import os.path as osp
import random
import sys
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as tordata
from .network import TripletLoss, GaitNet
from .utils import TripletSampler
class Model:
def __init__(self,
             hidden_dim,
             lr,
             hard_or_full_trip,
             margin,
             num_workers,
             batch_size,
             restore_iter,
             total_iter,
             save_name,
             train_pid_num,
             frame_num,
             model_name,
             train_source,
             test_source,
             img_size):
    """Store the training configuration and build network, losses and optimizer.

    Args:
        hidden_dim: width of the encoder's hidden representation (e.g. 256).
        lr: Adam learning rate (e.g. 1e-4).
        hard_or_full_trip: triplet-mining mode, 'hard' or 'full'.
        margin: triplet-loss margin (e.g. 0.2).
        num_workers: DataLoader worker count.
        batch_size: (P, M) tuple — P identities x M sequences per identity.
        restore_iter: iteration to resume from (0 = fresh run).
        total_iter: total training iterations.
        save_name: checkpoint file prefix (e.g. 'CSTL_CASIA-B').
        train_pid_num: number of training identities (e.g. 73).
        frame_num: frames sampled per sequence during training (e.g. 30).
        model_name: model identifier used in checkpoint names.
        train_source / test_source: dataset objects for train / test splits.
        img_size: input silhouette size (e.g. 128).
    """
    # --- plain configuration ------------------------------------------------
    self.save_name = save_name
    self.train_pid_num = train_pid_num
    self.train_source = train_source
    self.test_source = test_source
    self.hidden_dim = hidden_dim
    self.lr = lr
    self.hard_or_full_trip = hard_or_full_trip
    self.margin = margin
    self.frame_num = frame_num
    self.num_workers = num_workers
    self.batch_size = batch_size
    self.model_name = model_name
    self.P, self.M = batch_size  # P identities per batch, M sequences each
    self.restore_iter = restore_iter
    self.total_iter = total_iter
    self.img_size = img_size

    # --- network and losses (DataParallel-wrapped, moved to GPU) ------------
    self.encoder = nn.DataParallel(
        GaitNet(self.hidden_dim, self.train_pid_num, img_size // 2, 16).float())
    self.triplet_loss = nn.DataParallel(
        TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float())
    self.id_loss = nn.DataParallel(nn.CrossEntropyLoss())
    for module in (self.encoder, self.triplet_loss, self.id_loss):
        module.cuda()

    # Only the encoder is trained; the losses hold no parameters.
    self.optimizer = optim.Adam(
        [{'params': self.encoder.parameters()}], lr=self.lr)

    # --- running statistics gathered during fit() ---------------------------
    self.hard_loss_metric, self.full_loss_metric = [], []
    self.id_loss_metric, self.full_loss_num, self.dist_list = [], [], []
    self.mean_dist = 0.01
    self.sample_type = 'all'  # collate mode; fit() switches it to 'random'
def collate_fn(self, batch):
batch_size = len(batch)
feature_num = len(batch[0][0])
seqs = [batch[i][0] for i in range(batch_size)]
frame_sets = [batch[i][1] for i in range(batch_size)]
view = [batch[i][2] for i in range(batch_size)]
seq_type = [batch[i][3] for i in range(batch_size)]
label = [batch[i][4] for i in range(batch_size)]
batch = [seqs, view, seq_type, label, None]
def select_frame(index):
sample = seqs[index]
frame_set = frame_sets[index]
global _
if self.sample_type == 'random':
if len(frame_set) >= self.frame_num:
x = random.randint(0, (len(frame_set) - self.frame_num))
frame_set = frame_set[x:x+self.frame_num]
frame_id_list = random.sample(frame_set, k=self.frame_num)
frame_id_list.sort()
_ = [feature.loc[frame_id_list].values for feature in sample]
elif len(frame_set) > self.frame_num/2:
s_frame_id_list = np.random.choice(frame_set, size=(self.frame_num - len(frame_set)), replace=False).tolist()
frame_id_list = s_frame_id_list + frame_set
frame_id_list.sort()
_ = [feature.loc[frame_id_list].values for feature in sample]
else:
pass
else:
frame_id_list = random.sample(frame_set, k=len(frame_set))
frame_id_list.sort()
_ = [feature.loc[frame_id_list].values for feature in sample]
return _
seqs = list(map(select_frame, range(len(seqs))))
if self.sample_type == 'random':
seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
else:
gpu_num = min(torch.cuda.device_count(), batch_size)
batch_per_gpu = math.ceil(batch_size / gpu_num)
batch_frames = [[
len(frame_sets[i])
for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
if i < batch_size
] for _ in range(gpu_num)]
if len(batch_frames[-1]) != batch_per_gpu:
for _ in range(batch_per_gpu - len(batch_frames[-1])):
batch_frames[-1].append(0)
max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
seqs = [[
np.concatenate([
seqs[i][j]
for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
if i < batch_size
], 0) for _ in range(gpu_num)]
for j in range(feature_num)]
seqs = [np.asarray([
np.pad(seqs[j][_],
((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
'constant',
constant_values=0)
for _ in range(gpu_num)])
for j in range(feature_num)]
batch[4] = np.asarray(batch_frames)
batch[0] = seqs
return batch
def fit(self):
if self.restore_iter != 0:
self.load(self.restore_iter)
self.encoder.train()
self.sample_type = 'random'
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.lr
triplet_sampler = TripletSampler(self.train_source, self.batch_size)
train_loader = tordata.DataLoader(
dataset=self.train_source,
batch_sampler=triplet_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers)
train_label_set = list(self.train_source.label_set)
train_label_set.sort()
_time1 = datetime.now()
for seq, view, seq_type, label, batch_frame in train_loader:
self.restore_iter += 1
self.optimizer.zero_grad()
for i in range(len(seq)):
seq[i] = self.np2var(seq[i]).float()
if batch_frame is not None:
batch_frame = self.np2var(batch_frame).int()
feature, part_prob = self.encoder(*seq)
target_label = [train_label_set.index(l) for l in label]
target_label = self.np2var(np.array(target_label)).long()
triplet_feature = feature.permute(1, 0, 2).contiguous()
part_prob = part_prob.permute(1, 0, 2).contiguous()
p, n, c = part_prob.size()
part_label = target_label.unsqueeze(0).repeat(part_prob.size(0), 1).view(p*n)
triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
part_prob = part_prob.view(p*n, c)
id_loss = self.id_loss(part_prob, part_label)
(full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
) = self.triplet_loss(triplet_feature, triplet_label)
if self.hard_or_full_trip == 'hard':
loss = hard_loss_metric.mean()
elif self.hard_or_full_trip == 'full':
loss = full_loss_metric.mean()
loss = loss + id_loss.mean()
self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().num

onnx
- 粉丝: 1w+
最新资源
- 外文翻译--组合机床CAD系统(1)(1).doc
- 软件实习工作总结(1).docx
- 5G通信建设合同(1).docx
- 移动通信基站天馈设备安装规范简介课件2(1).ppt
- 高职类院校通信工程专业实践教学改革探究-教育文档.doc
- 计算机科学与技术的发展趋势探究(1).docx
- 对中职计算机专业工学结合教学模式的探讨(1).docx
- 本科毕业设计--基于组态软件的双容液位单回路过程控制系统设计(1).doc
- 毕业设计(论文)-基于单片机的电子称设计(1).doc
- 基于大数据处理的高速公路智能管控数据中心设计研究(1).docx
- 提高技工院校计算机教学质量的思考(1).docx
- 端盖零件加工工艺及工艺装备设计-机械制造与自动化专业论文(1).pdf
- 《SQLServer数据库应用基础教程》第二章数据库的基本操作(1).ppt
- 谈电力自动化的未来发展趋势(1).docx
- 智能化技术在电气工程自动化控制中的应用解析(1).docx
- 基于深度学习的分层教学实践(1).docx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈


