基于PyTorch的不同深度学习解决方案对比
立即解锁
发布时间: 2025-09-10 01:24:42 阅读量: 8 订阅数: 13 AIGC 


表格数据机器学习实战
### 基于PyTorch的不同深度学习解决方案对比
#### 1. PyTorch与fastai结合
在分析了fastai解决方案中的关键代码区域后,我们来看看fastai在深度学习栈中的位置。下图展示了本示例中处理表格数据的深度学习栈结构:
```mermaid
graph LR
classDef process fill:#E5F6FF,stroke:#73A6FF,stroke-width:2px;
A(PyTorch Ecosystem):::process --> B(PyTorch):::process
C(Tabular data library):::process --> B
D(High-level API):::process --> B
C --> E(fastai):::process
D --> E
```
fastai在栈中既可以看作是高级API,也可以看作是表格数据处理库。回顾完fastai解决方案的代码后,接下来我们将其与Keras解决方案进行对比。
#### 2. fastai与Keras解决方案对比
对于Airbnb纽约房源价格预测问题,我们已经看到了Keras和fastai两种深度学习解决方案。下面是对这两种方案优缺点的对比:
| 方案 | 优点 | 缺点 |
| ---- | ---- | ---- |
| Keras | 1. 模型细节透明<br>2. 社区庞大,容易找到常见问题的解决方案 | 1. 没有内置对表格数据的支持,需要自定义代码来定义模型的管道和层<br>2. 训练过程较慢 |
| fastai | 1. 框架明确支持表格数据模型,代码更紧凑<br>2. 框架自动定义管道<br>3. 包含方便的函数,便于检查数据集 | 1. 框架的假设可能导致难以调试的问题<br>2. 用户社区较小,在部署生产应用方面不如Keras社区活跃,解决问题可能更困难 |
在性能方面,Keras模型的准确率在70% - 74%之间,而fastai模型的准确率始终保持在81%左右。Keras的底层框架是TensorFlow,而fastai是基于PyTorch构建的。它们都是通用的高级深度学习API,可用于处理多种数据类型,不仅仅是表格数据。
#### 3. PyTorch与TabNet结合
之前介绍的工具都是通用的深度学习库,现在我们来尝试专门为表格数据设计的库TabNet。
##### 3.1 TabNet解决方案的关键代码
- **导入必要的库**
```python
! pip install pytorch-tabnet
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
```
与fastai的导入语句不同,TabNet的导入语句明确包含了对PyTorch库`torch`的导入。
- **生成NumPy数组列表**
```python
list_of_lists_train = []
list_of_lists_test = []
list_of_lists_valid = []
for i in range(0,7):
list_of_lists_train.append(X_train_list[i].tolist())
list_of_lists_valid.append(X_valid_list[i].tolist())
list_of_lists_test.append(X_test_list[i].tolist())
X_train = np.array(list_of_lists_train).T
X_valid = np.array(list_of_lists_valid).T
X_test = np.array(list_of_lists_test).T
y_train = dtrain.target
y_valid = dvalid.target
y_test = test.target
```
TabNet解决方案要求模型的输入是NumPy数组列表的形式,因此需要进行上述转换。
- **定义TabNet模型**
```python
tb_cls = TabNetClassifier(optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=1e-3),
scheduler_params={"step_size":10,"gamma":0.9},
scheduler_fn=to
```
0
0
复制全文
相关推荐









