PyTorch Fairseq项目中的大规模反向翻译技术详解
fairseq 项目地址: https://gitcode.com/gh_mirrors/fai/fairseq
什么是反向翻译
反向翻译(Back-Translation)是机器翻译领域的一项重要技术,它通过利用目标语言的单语数据来提升翻译模型的质量。其核心思想是:首先训练一个反向翻译模型(如德语→英语),然后用这个模型将目标语言的单语数据(如德语)翻译回源语言(英语),最后将这些合成的平行数据与真实的平行数据一起训练最终的翻译模型(英语→德语)。
Fairseq中的预训练模型
Fairseq提供了基于WMT'18英语-德语翻译任务的预训练模型,该模型采用了Transformer架构,并使用了大规模反向翻译技术进行训练。这个模型在WMT'18比赛中取得了优异成绩。
模型特点:
- 基于Transformer架构
- 使用大规模反向翻译技术增强
- 在WMT'18英语-德语翻译任务上表现优异
如何使用预训练模型
要使用预训练的英语-德语翻译模型,需要先安装必要的依赖:
pip install subword_nmt sacremoses
然后可以通过以下Python代码加载和使用模型:
import torch
# 加载WMT'18英语-德语翻译模型
en2de_ensemble = torch.hub.load(
'pytorch/fairseq', 'transformer.wmt18.en-de',
checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
tokenizer='moses', bpe='subword_nmt')
# 进行翻译
print(en2de_ensemble.translate('Hello world!'))
# 输出: 'Hallo Welt!'
训练自己的反向翻译模型
下面详细介绍如何在Fairseq中训练自己的英语-德语反向翻译模型。
第一步:准备平行数据并训练基线模型
- 下载并预处理WMT'18英语-德语数据:
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..
- 对数据进行二进制化处理:
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
--joined-dictionary \
--source-lang en --target-lang de \
--trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
--destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
--workers 20
- (可选)仅使用平行数据训练基线模型:
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
第二步:反向翻译单语数据
- 训练反向翻译模型(德语→英语):
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
data-bin/wmt18_en_de \
--source-lang de --target-lang en \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 30000 \
--save-dir $CHECKPOINT_DIR
- 准备单语数据并进行反向翻译:
# 下载并准备单语数据
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..
# 对单语数据进行二进制化处理
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-preprocess \
--only-source \
--source-lang de --target-lang en \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.de.txt \
--testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
--destdir data-bin/wmt18_de_mono/shard${SHARD} \
--workers 20; \
cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
# 执行反向翻译
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
fairseq-generate --fp16 \
data-bin/wmt18_de_mono/shard${SHARD} \
--path $CHECKPOINT_DIR/checkpoint_best.pt \
--skip-invalid-size-inputs-valid-test \
--max-tokens 4096 \
--sampling --beam 1 \
> backtranslation_output/sampling.shard${SHARD}.out; \
done
第三步:训练最终模型
- 合并平行数据和反向翻译数据:
python examples/backtranslation/extract_bt_data.py \
--minlen 1 --maxlen 250 --ratio 1.5 \
--output backtranslation_output/bt_data --srclang en --tgtlang de \
backtranslation_output/sampling.shard*.out
# 二进制化处理
TEXT=backtranslation_output
fairseq-preprocess \
--source-lang en --target-lang de \
--joined-dictionary \
--srcdict data-bin/wmt18_en_de/dict.en.txt \
--trainpref $TEXT/bt_data \
--destdir data-bin/wmt18_en_de_bt \
--workers 20
- 训练最终模型:
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
data-bin/wmt18_en_de_para_plus_bt \
--upsample-primary 16 \
--source-lang en --target-lang de \
--arch transformer_wmt_en_de_big --share-all-embeddings \
--dropout 0.3 --weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--max-tokens 3584 --update-freq 16 \
--max-update 100000 \
--save-dir $CHECKPOINT_DIR
技术要点与最佳实践
-
数据准备:反向翻译的质量很大程度上取决于单语数据的质量和数量。建议使用高质量、领域相关的单语数据。
-
模型架构:Fairseq使用的是Transformer架构,特别是"transformer_wmt_en_de_big"配置,这是经过优化的WMT英德翻译模型架构。
-
训练技巧:
- 使用混合精度训练(--fp16)可以显著减少显存占用并加快训练速度
- 适当调整--update-freq参数以适应不同的GPU数量
- 使用--upsample-primary选项可以增加平行数据的采样比例
-
评估指标:建议使用sacreBLEU进行评估,这是机器翻译领域标准的评估方法。
-
反向翻译策略:可以使用不同的解码策略(如采样、beam search等)生成多样化的反向翻译数据。
通过Fairseq实现的大规模反向翻译技术,开发者可以显著提升低资源语言对的翻译质量,特别是在平行数据有限的情况下。这种方法已被证明在多个翻译任务中都能带来显著的性能提升。
fairseq 项目地址: https://gitcode.com/gh_mirrors/fai/fairseq
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考