深度学习技术栈 —— Pytorch之TensorDataset、DataLoader
前言
简单来说,TensorDataset
与DataLoader
这两个类的作用, 就是将数据读入并做整合,以便交给模型处理。就像石油加工厂一样,你不关心石油是如何采集与加工的,你关心的是自己去哪加油,油价是多少,对于一个模型而言,DataLoader就是这样的一个予取予求的数据服务商。
参考文章或视频链接 |
---|
[1] How to use TensorDataset, Dataloader (pytorch) |
一、TensorDataset、DataLoader的用法?
# coding:utf-8
# @Time: 2024/1/23 上午9:57
# @Author: 键盘国治理专家
# @File: __init__.py.py
# @Description:
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
def test_TensorDataset():
input = np.random.rand(4, 2) # Input data
correct = np.random.rand(4, 1) # Correct answer data
input = torch.FloatTensor(input