import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path
import visdom
from constants import settings
from constants.enum_keys import HK
from models.pose_estimation_model import PoseEstimationModel
from models.pafs_network import PAFsLoss
from aichallenger import AicNorm
from imgaug.augmentables.heatmaps import HeatmapsOnImage
class Trainer:
def __init__(self, batch_size, is_unittest=False):
# self.debug_mode: Set num_worker to 0. Otherwise pycharm debug won't work due to multithreading.
self.is_unittest = is_unittest
self.epochs = 5
self.val_step = 500
self.batch_size = batch_size
self.vis = visdom.Visdom()
self.img_key = HK.NORM_IMAGE
self.pcm_key = HK.PCM_ALL
self.paf_key = HK.PAF_ALL
if torch.cuda.is_available():
print("GPU available.")
else:
print("GPU not available. running with CPU.")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model_pose = PoseEstimationModel()
self.model_optimizer = optim.Adam(self.model_pose.parameters(), lr=1e-3)
self.loss = PAFsLoss()
train_dataset = AicNorm(Path.home() / "AI_challenger_keypoint", is_train=True,
resize_img_size=(512, 512), heat_size=(64, 64))
self.train_loader = DataLoader(train_dataset, self.batch_size, shuffle=True, num_workers=settings.num_workers,
pin_memory=True, drop_last=True)
# test_dataset = AicNorm(Path.home() / "AI_challenger_keypoint", is_train=False,
# resize_img_size=(512, 512), heat_size=(64, 64))
# self.val_loader = DataLoader(test_dataset, self.batch_size, shuffle=False, num_workers=workers, pin_memory=True,
# drop_last=True)
# self.val_iter = iter(self.val_loader)
def set_train(self):
"""Convert models to training mode
"""
self.model_pose.train()
def set_eval(self):
"""Convert models to testing/evaluation mode
"""
self.model_pose.eval()
def train(self):
self.epoch = 0
self.step = 0
self.model_pose.load_ckpt()
for self.epoch in range(self.epochs):
print("Epoch:{}".format(self.epoch))
self.run_epoch()
if self.is_unittest:
break
def run_epoch(self):
self.set_train()
for batch_idx, inputs in enumerate(self.train_loader):
inputs[self.img_key] = inputs[self.img_key].to(self.device, dtype=torch.float32)
inputs[self.pcm_key] = inputs[self.pcm_key].to(self.device, dtype=torch.float32)
inputs[self.paf_key] = inputs[self.paf_key].to(self.device, dtype=torch.float32)
loss, b1_out, b2_out = self.process_batch(inputs)
self.model_optimizer.zero_grad()
loss.backward()
self.model_optimizer.step()
# Clear visdom environment
if self.step == 0:
self.vis.close()
# Validate
if self.step % self.val_step == 0 and self.step != 0:
# self.val()
self.model_pose.save_ckpt()
if self.step % 50 == 0:
print("step {}; Loss {}".format(self.step, loss.item()))
# Show training materials
if self.step % self.val_step == 0:
pcm_CHW = b1_out[0].cpu().detach().numpy()
paf_CHW = b2_out[0].cpu().detach().numpy()
img_CHW = inputs[self.img_key][0].cpu().detach().numpy()[::-1, ...]
pred_pcm_amax = np.amax(pcm_CHW, axis=0) # HW
gt_pcm_amax = np.amax(inputs[self.pcm_key][0].cpu().detach().numpy(), axis=0)
pred_paf_amax = np.amax(paf_CHW, axis=0)
gt_paf_amax = np.amax(inputs[self.paf_key][0].cpu().detach().numpy(), axis=0)
self.vis.image(img_CHW, win="Input", opts={'title': "Input"})
self.vis.heatmap(np.flipud(pred_pcm_amax), win="Pred-PCM", opts={'title': "Pred-PCM"})
self.vis.heatmap(np.flipud(gt_pcm_amax), win="GT-PCM", opts={'title': "GT-PCM"})
self.vis.heatmap(np.flipud(pred_paf_amax), win="Pred-PAF", opts={'title': "Pred-PAF"})
self.vis.heatmap(np.flipud(gt_paf_amax), win="GT-PAF", opts={'title': "GT-PAF"})
self.vis.line(X=np.array([self.step]), Y=loss.cpu().detach().numpy()[np.newaxis], win='Loss', update='append')
if self.is_unittest:
break
self.step += 1
# def val(self, ):
# """Validate the model on a single minibatch
# """
# self.set_eval()
#
# try:
# inputs = next(self.val_iter)
# except StopIteration:
# self.val_iter = iter(self.val_loader)
# inputs = next(self.val_iter)
#
# inputs_gpu = {self.img_key: inputs[self.img_key].to(self.device, dtype=torch.float32),
# self.pcm_key: inputs[self.pcm_key].to(self.device, dtype=torch.float32),
# self.paf_key: inputs[self.paf_key].to(self.device, dtype=torch.float32)}
# with torch.no_grad():
# loss, b1_out, b2_out = self.process_batch(inputs_gpu)
# pred_pcm_amax = np.amax(b1_out[0].cpu().numpy(), axis=0) # HW
# gt_pcm_amax = np.amax(inputs[self.pcm_key][0].cpu().numpy(), axis=0)
# pred_paf_amax = np.amax(b2_out[0].cpu().numpy(), axis=0)
# gt_paf_amax = np.amax(inputs[self.paf_key][0].cpu().numpy(), axis=0)
# # Image augmentation disabled due to pred phase
# self.vis.image(inputs[self.img_key][0].cpu().numpy()[::-1, ...], win="Input", opts={'title': "Input"})
# self.vis.heatmap(np.flipud(pred_pcm_amax), win="Pred-PCM", opts={'title': "Pred-PCM"})
# self.vis.heatmap(np.flipud(gt_pcm_amax), win="GT-PCM", opts={'title': "GT-PCM"})
# self.vis.heatmap(np.flipud(pred_paf_amax), win="Pred-PAF", opts={'title': "Pred-PAF"})
# self.vis.heatmap(np.flipud(gt_paf_amax), win="GT-PAF", opts={'title': "GT-PAF"})
# self.vis.line(X=np.array([self.step]), Y=loss.cpu().numpy()[np.newaxis], win='Loss', update='append')
# self.set_train()
def process_batch(self, inputs):
"""Pass a minibatch through the network and generate images and losses
"""
res = self.model_pose(inputs[self.img_key])
gt_pcm = inputs[self.pcm_key] # ["heatmap"]: {"vis_or_not": NJHW, "visible": NJHW}
gt_pcm = gt_pcm.unsqueeze(1)
gt_paf = inputs[self.paf_key]
gt_paf = gt_paf.unsqueeze(1)
b1_stack = torch.stack(res[HK.B1_SUPERVISION], dim=1) # Shape (N, Stage, C, H, W)
b2_stack = torch.stack(res[HK.B2_SUPERVISION], dim=1)
loss = self.loss(b1_stack, b2_stack, gt_pcm, gt_paf)
return loss, res[HK.B1_OUT], res[HK.B2_OUT]
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于深度学习pytorch框架的交通警察指挥手势识别系统(含源码+训练好的模型+数据集+操作教程).zi 识别8种中国交通警察指挥手势的Pytorch深度学习项目 带训练好的模型以及数据集 下载模型参数文件checkpoint和生成的骨架generated 放置在: ctpgr-pytorch/checkpoints ctpgr-pytorch/generated 下载交警手势数据集(必选) 交警手势数据集下载: 放置在: (用户文件夹)/PoliceGestureLong (用户文件夹)/AI_challenger_keypoint # 用户文件夹 在 Windows下是'C:\Users\(用户名)',在Linux下是 '/home/(用户名)' 安装Pytorch和其它依赖: # Python 3.8.5 conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch conda install ujson pip install visdom opencv-python imgaug 数据集和
资源推荐
资源详情
资源评论

























收起资源包目录















































共 37 条
- 1

onnx
- 粉丝: 1w+
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助


最新资源
- 电子闹钟设计自动化与信息工程系-学位论文(1).doc
- 互联网时代下我国职业分类发展趋势(1).docx
- 农业物联网的服务平台设计(1).docx
- 信息化技术在建筑施工技术课程中的应用初探(1).docx
- 城市轨道交通信号系统的安全策略与可靠性分析(1).docx
- 简论安全使用网络的论文-计算机网络论文(1).docx
- 运营商的互联网新思维(1).docx
- 河北省中小学计算机教室管理制度(2)(1).doc
- 大学毕业论文-—基于单片机智能循迹小车(1).doc
- 计算机二级Ms-office高级应用选择题基础知识点——王永辉(1).docx
- 汇存电子商务平台设计书样本样本(1).doc
- 高校档案信息化建设策略探析.doc
- 互联网+时代大学生创业路径研究(1).docx
- 基于ASP.NET的快捷网上购物模拟系统构建(1).docx
- 基于多元化理念下的中专计算机教学研究(1).docx
- 互联网不正当竞争行为法律问题研究(1).docx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈



安全验证
文档复制为VIP权益,开通VIP直接复制

- 1
- 2
前往页