ag_news数据集下载
时间: 2025-03-26 20:39:18 浏览: 27
### 下载并使用 AG_NEWS 数据集
#### 使用 `torchtext` 库下载和加载数据集
为了方便地获取 AG_NEWS 数据集,推荐使用 PyTorch 提供的 `torchtext` 工具库。该工具简化了文本数据处理流程,并内置支持多个常用自然语言处理(NLP)数据集。
```python
from torchtext.datasets import AG_NEWS
import os
# 设置保存路径
root_dir = './data'
os.makedirs(root_dir, exist_ok=True)
# 加载训练集和测试集
train_iter = AG_NEWS(split='train', root=root_dir)
test_iter = AG_NEWS(split='test', root=root_dir)
```
上述代码会自动检测本地是否存在指定版本的数据;如果不存在,则尝试从网络上下载并将解压后的文件存储到给定目录中[^1]。
#### 验证数据结构
一旦完成下载过程,可以通过迭代器查看部分样本:
```python
for idx, (label, text) in enumerate(train_iter):
print(f'Label: {label}, Text: {text[:50]}...')
if idx >= 2:
break
```
这段脚本打印前几个条目的标签及其对应的开头部分内容,帮助确认数据已被正确读取[^3]。
#### 构建词汇表与分词器
对于后续的任务如模型训练来说,通常还需要构建一个映射单词至索引编号的字典以及定义用于分割句子的方法:
```python
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(lambda data: tokenizer(data[1]), train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
print("Vocabulary size:", len(vocab))
```
这里采用了简单的英文基础分词方式,并创建了一个包含未知标记 `<unk>` 的词汇表实例。
#### 创建批量加载器
最后一步是设置好批处理机制以便高效喂入神经网络算法:
```python
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
processed_text = torch.tensor([vocab[token] for token in tokenizer(_text)], dtype=torch.int64)
label_list.append(label - 1)
text_list.append(processed_text)
offsets.append(len(processed_text))
label_list = torch.tensor(label_list, dtype=torch.int64)
text_list = torch.cat(text_list)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
return label_list.to(device), text_list.to(device), offsets.to(device)
batch_size = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_dataloader = DataLoader(AG_NEWS(split='train'), batch_size=batch_size,
shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(AG_NEWS(split='test'), batch_size=batch_size,
shuffle=False, collate_fn=collate_batch)
```
通过自定义函数 `collate_batch()` 来组合每批次内的所有输入项,从而形成适合传递给 GPU 计算设备的形式[^2]。
阅读全文
相关推荐















