def plot_learning_curve(train_loss, dev_loss, title=''): total_steps = len(train_loss); x_1 = range(total_steps); x_2 = x_1[::len(train_loss) // len(dev_loss)]; plt.figure(1, figsize=(6, 4)); plt.plot(x_1, train_loss, c='tab:red', label='train'); plt.plot(x_2, dev_loss, c='tab:cyan', label='dev'); plt.ylim(0.0, 5.); plt.xlabel('Training steps'); plt.ylabel('MSE loss'); plt.title('Learning curve of {}'.format(title)); plt.legend(); plt.show()
Time: 2024-04-10 12:32:33 Views: 80
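This function plots a learning curve. It takes the training loss (`train_loss`), the validation loss (`dev_loss`), and an optional title (`title`).
It first computes the total number of steps (`total_steps`) and then builds two sets of x-axis positions. `x_1` runs from 0 to `total_steps - 1`, one position per training step. `x_2` takes every k-th element of `x_1`, where k = `len(train_loss) // len(dev_loss)`, so that the less frequently recorded validation losses are spread across the same axis.
Next it creates a figure of size (6, 4) and calls `plt.plot` twice to draw the training and validation loss curves in red and cyan, each with a label. It limits the y-axis to the range 0.0 to 5.0, labels the x-axis "Training steps" and the y-axis "MSE loss", and sets the title to "Learning curve of" followed by `title`. Finally, `plt.legend()` adds the legend and `plt.show()` displays the figure.
The function visualizes how the training and validation losses change during training, which makes it easier to judge how well the model is learning.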
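The stride used for the validation x-axis can be checked in isolation. The sketch below is not part of the original function; `dev_x_positions` is a hypothetical helper that reproduces the `x_1[::len(train_loss) // len(dev_loss)]` slice and, as an added assumption, truncates the result, because the plain slice can return more positions than there are validation losses when the two lengths do not divide evenly.

```python
def dev_x_positions(n_train, n_dev):
    """Return the training-step index at which each dev-loss value is plotted."""
    if n_dev <= 0 or n_train < n_dev:
        raise ValueError("need 0 < n_dev <= n_train")
    stride = n_train // n_dev  # same stride as in plot_learning_curve
    # The slice alone can yield more than n_dev points when n_train is not a
    # multiple of n_dev, so truncate to keep x and y the same length.
    return list(range(n_train))[::stride][:n_dev]


print(dev_x_positions(12, 3))  # [0, 4, 8]
print(dev_x_positions(10, 3))  # [0, 3, 6]; the untruncated slice is [0, 3, 6, 9]
```

With 10 training steps and 3 validation losses, the untruncated slice has 4 elements, and `plt.plot` raises an error when x and y have different lengths, so the truncation (or recording validation loss at an exact multiple of the training steps) matters.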
Related questions
PS D:\DeepStudy\RT-DETR-main\rtdetr_pytorch> python tools/train.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml --amp --seed=0
Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1+cu118
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Not init distributed mode.
Start training
Load PResNet18 state_dict
Initial lr: [1e-05, 1e-05, 0.0001, 0.0001]
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
number of params: 20184464
Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1+cu118
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1+cu118
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1+cu118
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1+cu118
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Epoch: [0] [ 0/130] eta: 0:50:29 lr: 0.000010 loss: 18.8544 (18.8544) loss_vfl: 0.5347 (0.5347) loss_bbox: 0.5407 (0.5407) loss_giou: 1.4987 (1.4987) loss_vfl_aux_0: 0.5732 (0.5732) loss_bbox_aux_0: 0.5383 (0.5383) loss_giou_aux_0: 1.5640 (1.5640) loss_vfl_aux_1: 0.5059 (0.5059) loss_bbox_aux_1: 0.6054 (0.6054) loss_giou_aux_1: 1.5544 (1.5544) loss_vfl_aux_2: 0.5083 (0.5083) loss_bbox_aux_2: 0.7757 (0.7757) loss_giou_aux_2: 1.6962 (1.6962) loss_vfl_dn_0: 0.9473 (0.9473) loss_bbox_dn_0: 0.4318 (0.4318) loss_giou_dn_0: 1.3516 (1.3516) loss_vfl_dn_1: 0.8394 (0.8394) loss_bbox_dn_1: 0.4318 (0.4318) loss_giou_dn_1: 1.3516 (1.3516) loss_vfl_dn_2: 0.8223 (0.8223) loss_bbox_dn_2: 0.4318 (0.4318) loss_giou_dn_2: 1.3516 (1.3516) time: 23.3010 data: 19.7480 max mem: 2363
Traceback (most recent call last):
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\train.py", line 50, in <module>
    main(args)
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\train.py", line 36, in main
    solver.fit()
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\..\src\solver\det_solver.py", line 37, in fit
    train_stats = train_one_epoch(
                  ^^^^^^^^^^^^^^^^
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\..\src\solver\det_engine.py", line 44, in train_one_epoch
    loss_dict = criterion(outputs, targets)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudyHJ\Anaconda3data\envs\rtdetr\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\..\src\zoo\rtdetr\rtdetr_criterion.py", line 238, in forward
    indices = self.matcher(outputs_without_aux, targets)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudyHJ\Anaconda3data\envs\rtdetr\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudyHJ\Anaconda3data\envs\rtdetr\Lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\..\src\zoo\rtdetr\matcher.py", line 99, in forward
    cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\DeepStudy\RT-DETR-main\rtdetr_pytorch\tools\..\src\zoo\rtdetr\box_ops.py", line 52, in generalized_box_iou
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
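### Problem analysis
Running the RT-DETR training script produces two distinct problems:
1. **PyTorch version message**: the log prints `Disabling PyTorch because PyTorch >= 2.1 is required but found 2.0.1`.
2. **AssertionError in box_ops.py**: an assertion fails during training and aborts the run.
#### PyTorch version issue
The RT-DETR PyTorch code base is commonly run with **PyTorch 2.0.1** built against CUDA 11.8[^4], which matches the installed `2.0.1+cu118`. The `PyTorch >= 2.1 is required` message is a warning rather than a fatal error: the log shows training starting and reaching the first iteration after it was printed. Its wording, together with the follow-up line "None of PyTorch, TensorFlow >= 2.0, or Flax have been found", matches the import-time notice of the Hugging Face `transformers` package, so it most likely comes from a newer `transformers` release installed in the same environment rather than from RT-DETR itself. It is not the cause of the crash. Moving to a PyTorch version newer than 2.0.1 may change how this code base behaves, so upgrading PyTorch only to silence the message is not necessarily the right fix.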
```bash
# Install the recommended PyTorch 2.0.1 (supports CUDA 11.7/11.8)
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
```
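To see why an installed `2.0.1+cu118` fails a `>= 2.1` requirement, the comparison can be reproduced by hand. This is a simplified sketch, not the code of any library: `release_tuple` and `meets_minimum` are hypothetical helpers, and the parsing assumes plain numeric release strings with an optional `+local` build tag.

```python
import re


def release_tuple(version):
    """Turn '2.0.1+cu118' into (2, 0, 1), ignoring the local build tag."""
    public = version.split("+", 1)[0]  # drop '+cu118'
    return tuple(int(part) for part in re.findall(r"\d+", public))


def meets_minimum(installed, minimum):
    """Compare numeric release components left to right."""
    return release_tuple(installed) >= release_tuple(minimum)


print(meets_minimum("2.0.1+cu118", "2.1"))  # False
print(meets_minimum("2.0.1+cu118", "2.0"))  # True
```

The `+cu118` suffix only identifies the CUDA build and plays no part in the comparison; the requirement fails because the minor version 0 is lower than 1.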
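#### AssertionError in box_ops.py
The failing assertion, `(boxes1[:, 2:] >= boxes1[:, :2]).all()` in `generalized_box_iou`, requires every box in xyxy format to satisfy x2 >= x1 and y2 >= y1. In the traceback, `boxes1` is `box_cxcywh_to_xyxy(out_bbox)`, that is, the model's predicted boxes. The check therefore fails when a predicted width or height is negative or when the predictions contain NaN, since any comparison with NaN is false. The run uses `--amp`, so non-finite values from mixed-precision training are worth ruling out. Possible causes include:
- Input box coordinates in the wrong format or out of range (for example, outside the image or negative).
- A data augmentation step that produces degenerate boxes, for example boxes with zero or negative width or height.
- Samples with invalid annotations (such as xmin > xmax or ymin > ymax).
Locate the assertion in `box_ops.py` and check whether the data preprocessing pipeline matches what the model expects. Adding log output helps identify the offending samples: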
```python
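# Example: add debugging information in box_ops.py
import torch

def box_xyxy_to_cxcywh(x):
    # Split the last dimension into the four corner coordinates.
    x1, y1, x2, y2 = x.unbind(-1)
    # Fail with the offending values if any box is degenerate.
    # NaN coordinates also fail these checks, because NaN comparisons are false.
    assert (x2 > x1).all(), f"x2 <= x1 detected: {x1}, {x2}"
    assert (y2 > y1).all(), f"y2 <= y1 detected: {y1}, {y2}"
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return torch.stack([cx, cy, w, h], dim=-1)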
```
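Invalid annotations can also be ruled out before training by scanning the dataset file directly. The following is a minimal sketch, assuming COCO-style annotations whose `bbox` field is `[x, y, width, height]`; `find_bad_boxes` and the sample records are hypothetical and not part of RT-DETR.

```python
import math


def find_bad_boxes(annotations):
    """Return (annotation id, bbox) pairs that are malformed, non-finite, or have non-positive size."""
    bad = []
    for ann in annotations:
        bbox = ann.get("bbox", [])
        ok = (
            len(bbox) == 4
            and all(math.isfinite(v) for v in bbox)  # rejects NaN and inf
            and bbox[2] > 0   # width must be positive
            and bbox[3] > 0   # height must be positive
        )
        if not ok:
            bad.append((ann.get("id"), bbox))
    return bad


# Hypothetical records for illustration only.
sample = [
    {"id": 1, "bbox": [10.0, 20.0, 30.0, 40.0]},
    {"id": 2, "bbox": [5.0, 5.0, 0.0, 12.0]},           # zero width
    {"id": 3, "bbox": [7.0, 8.0, float("nan"), 3.0]},   # NaN width
]
print([ann_id for ann_id, _ in find_bad_boxes(sample)])  # [2, 3]
```

For a real dataset, the list passed in would be the `annotations` entry of the loaded COCO JSON file. If the scan comes back empty, the annotations are unlikely to be the cause, and the model outputs are the next place to look.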
#### Environment setup recommendations
To avoid dependency conflicts, create a dedicated Conda virtual environment for development and testing:
```bash
conda create -n rtdetr python=3.9
conda activate rtdetr
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
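#### Code compatibility adjustments
If the problem persists, the source code may call some functions or modules in a way that is incompatible with the installed PyTorch. Check the official repository's changelog for patches or compatibility fixes that target newer PyTorch versions.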