bug概述:【踩坑记录📝】Removed shared tensor while saving.
简单来说,这个bug的危害是trainer.save()无法正确存储权重。这篇博文的作者也给出了两种处理方法,但要么要改transformers版本,要么要包裹Trainer类,太麻烦了。
我使用transformers版本4.38.0,deepspeed zero 0/1/2复现了这篇博文所述bug,研究得知bug的具体原因后,给出以下解决思路:
第一步:获得正确的state_dict
trainer调用_save方法时,默认入参state_dict=None,然后_save方法通过self.model.state_dict()获得state_dict,导致报错。原因是此时self.model是通过accelerator加载的,state_dict需要用self.accelerator.get_state_dict(self.model)的方式获得。
解决思路是提前在调用_save之前提前通过self.accelerator.get_state_dict(self.model)把state_dict取出,然后直接带state_dict入参来调用_save方法,实现权重存储。
我在这里借鉴了QWen-VL仓库finetune.py的写法,参考代码如下:
def safe_save_model_for_hf_trainer