@erenup
Contributor

@erenup erenup commented Dec 14, 2019

Hi, @julien-c @thomwolf this PR is based on #1386 and #1984.

  • This PR modifies run_squad.py and modeling_roberta.py to support RoBERTa.

  • This PR also uses multiprocessing to speed up converting examples to features. (The conversion took 15 minutes before and takes 34 seconds now with 24 CPU cores. The number of threads is 1 by default, which gives the same speed as the original single-process code.)

  • The result of RoBERTa-large on SQuAD 1.1:
    {'exact': 87.26584673604542, 'f1': 93.77663586186483, 'total': 10570, 'HasAns_exact': 87.26584673604542, 'HasAns_f1': 93.77663586186483, 'HasAns_total': 10570, 'best_exact': 87.26584673604542, 'best_exact_thresh': 0.0, 'best_f1': 93.77663586186483, 'best_f1_thresh': 0.0}
    This is slightly lower than #1386 in a single run.
    The parameters are:
    python ./examples/run_squad.py \
      --model_type roberta \
      --model_name_or_path roberta-large \
      --do_train \
      --do_eval \
      --do_lower_case \
      --train_file data/squad1/train-v1.1.json \
      --predict_file data/squad1/dev-v1.1.json \
      --learning_rate 1.5e-5 \
      --num_train_epochs 2 \
      --max_seq_length 384 \
      --doc_stride 128 \
      --output_dir ./models_roberta/large_squad1 \
      --per_gpu_eval_batch_size=3 \
      --per_gpu_train_batch_size=3 \
      --save_steps 10000 \
      --warmup_steps=500 \
      --weight_decay=0.01
    I hope this gap can be narrowed with `add_prefix_space=true`. I will run this comparison in the next few days.

  • The result of RoBERTa-base is `{'exact': 80.65279091769158, 'f1': 88.57296806525736, 'total': 10570, 'HasAns_exact': 80.65279091769158, 'HasAns_f1': 88.57296806525736, 'HasAns_total': 10570, 'best_exact': 80.65279091769158, 'best_exact_thresh': 0.0, 'best_f1': 88.57296806525736, 'best_f1_thresh': 0.0}`. RoBERTa-base was also tested since it is easier to reproduce.

  • The result of bert-base-uncased is `{'exact': 79.21475875118259, 'f1': 87.13734938098504, 'total': 10570, 'HasAns_exact': 79.21475875118259, 'HasAns_f1': 87.13734938098504, 'HasAns_total': 10570, 'best_exact': 79.21475875118259, 'best_exact_thresh': 0.0, 'best_f1': 87.13734938098504, 'best_f1_thresh': 0.0}`. This run checks whether multiprocessing affects other models. The result is the same as the bert-base-uncased result without multiprocessing.

  • I hope someone else can help reproduce my results. Thank you! I will keep looking for ways to improve the roberta-large results.

  • SQuAD 1.1 model on Google Drive, roberta-large-finetuned-squad:
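
The multiprocessing speed-up described in this comment can be sketched roughly as follows. This is a minimal illustration and not the PR's actual code: `convert_example`, `convert_examples_to_features`, and the whitespace "tokenizer" are hypothetical stand-ins, and only the `threads` name is borrowed from the `--threads` flag used later in this thread. With `threads=1` the work stays in the current process, which mirrors the default mentioned above.

```python
from functools import partial
from multiprocessing import Pool


def convert_example(example, max_seq_length):
    """Hypothetical per-example conversion: whitespace-split and truncate."""
    question, context = example
    tokens = question.split() + context.split()
    return tokens[:max_seq_length]


def convert_examples_to_features(examples, max_seq_length=384, threads=1):
    """Convert examples in input order, optionally using several processes."""
    # Bind the shared arguments once so each worker receives only the example.
    convert = partial(convert_example, max_seq_length=max_seq_length)
    if threads <= 1:
        # Default path: a plain loop in the current process.
        return [convert(example) for example in examples]
    with Pool(threads) as pool:
        # imap yields results in input order; chunksize batches the work
        # sent to each worker to reduce inter-process overhead.
        return list(pool.imap(convert, examples, chunksize=32))
```

On platforms that start worker processes by spawning, a call with `threads > 1` should sit under an `if __name__ == "__main__":` guard. Both paths return features in the same order, so the setting should only affect speed, not results.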

@codecov-io

codecov-io commented Dec 14, 2019

Codecov Report

Merging #2173 into master will decrease coverage by 1.35%.
The diff coverage is 9.09%.

@@            Coverage Diff             @@
##           master    #2173      +/-   ##
==========================================
- Coverage   80.79%   79.43%   -1.36%     
==========================================
  Files         113      113              
  Lines       17013    17067      +54     
==========================================
- Hits        13745    13558     -187     
- Misses       3268     3509     +241
| Impacted Files | Coverage Δ | |
|---|---|---|
| transformers/data/metrics/squad_metrics.py | 0% <0%> (ø) | ⬆️ |
| transformers/modeling_roberta.py | 53.2% <21.21%> (-18.57%) | ⬇️ |
| transformers/data/processors/squad.py | 14.75% <5.5%> (+0.56%) | ⬆️ |
| transformers/modeling_tf_pytorch_utils.py | 9.72% <0%> (-85.42%) | ⬇️ |
| transformers/tests/modeling_tf_common_test.py | 80.51% <0%> (-16.42%) | ⬇️ |
| transformers/modeling_xlnet.py | 72.21% <0%> (-2.33%) | ⬇️ |
| transformers/modeling_ctrl.py | 94.27% <0%> (-2.21%) | ⬇️ |
| transformers/modeling_openai.py | 80.13% <0%> (-1.33%) | ⬇️ |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7bd11dd...805c21a. Read the comment docs.

Comment on lines +124 to +125
sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + 1 \
if 'roberta' in str(type(tokenizer)) else tokenizer.max_len - tokenizer.max_len_single_sentence
Member

Good catch! We'll eventually have to think of an abstraction so that this method stays tokenizer-agnostic.

Contributor Author

Cool. That would be better.
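
To illustrate why the snippet under review adds one extra token for RoBERTa, here is a small self-contained sketch. It assumes the usual pair layouts, `[CLS] A [SEP] B [SEP]` for BERT and `<s> A </s></s> B </s>` for RoBERTa, and it assumes the value is used to locate where the context starts after the query. The helpers `build_pair` and `doc_start` are hypothetical, and the token lists are toy data rather than output from a real tokenizer.

```python
def build_pair(query_tokens, doc_tokens, style):
    """Lay out a (query, document) pair with the model's special tokens."""
    if style == "bert":
        return ["[CLS]"] + query_tokens + ["[SEP]"] + doc_tokens + ["[SEP]"]
    if style == "roberta":
        # RoBERTa places two separators between the two segments.
        return ["<s>"] + query_tokens + ["</s>", "</s>"] + doc_tokens + ["</s>"]
    raise ValueError(f"unknown style: {style}")


def doc_start(query_tokens, style):
    """Index of the first document token inside the combined sequence."""
    # Both models add 2 special tokens to a single sentence, which is what
    # max_len - max_len_single_sentence evaluates to in the reviewed code.
    single_sentence_added = 2
    # RoBERTa's doubled separator shifts the document by one more position.
    extra = 1 if style == "roberta" else 0
    return len(query_tokens) + single_sentence_added + extra


query = ["who", "wrote", "it"]
doc = ["a", "contributor", "wrote", "it"]
for style in ("bert", "roberta"):
    pair = build_pair(query, doc, style)
    print(style, doc_start(query, style), pair[doc_start(query, style)])
```

A tokenizer-agnostic abstraction, as suggested in the review, would hide the `extra` term behind the tokenizer itself instead of checking the tokenizer's type at the call site.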

Comment on lines 588 to 594
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMultipleChoice.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
loss, start_scores, end_scores = outputs[:2]
Member

We should update this to an example similar to the one for BertForQuestionAnswering.

Contributor Author

I have updated the usage example. Could you please help with the failed check? I only added some comments, but the check failed. I also ran `python -m pytest transformers/tests/modeling_roberta_test.py` and all tests passed. Thank you very much.

Member

There's currently an error with the test due to a segmentation fault. I'm fixing it on #2207, don't worry about it here.

@dayihengliu

dayihengliu commented Dec 17, 2019

Really nice job!
Here are my results for RoBERTa-large on SQuAD 2.0 using this PR:
Results: {'exact': 84.52792049187232, 'f1': 88.0216698977779, 'total': 11873, 'HasAns_exact': 80.90418353576248, 'HasAns_f1': 87.9017015344667, 'HasAns_total': 5928, 'NoAns_exact': 88.1412952060555, 'NoAns_f1': 88.1412952060555, 'NoAns_total': 5945, 'best_exact': 84.52792049187232, 'best_exact_thresh': 0.0, 'best_f1': 88.02166989777776, 'best_f1_thresh': 0.0}
The hyper-parameters are as follows:
python ./examples/run_squad.py \
  --model_type roberta \
  --model_name_or_path roberta-large \
  --do_train \
  --do_eval \
  --do_lower_case \
  --train_file data/squad2/train-v2.0.json \
  --predict_file data/squad2/dev-v2.0.json \
  --learning_rate 2e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir ./models_roberta/large_squad2 \
  --per_gpu_eval_batch_size=6 \
  --per_gpu_train_batch_size=6 \
  --save_steps 10000 \
  --warmup_steps=500 \
  --weight_decay=0.01 \
  --overwrite_cache \
  --overwrite_output_dir \
  --threads 24 \
  --version_2_with_negative
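
As a quick sanity check on the numbers reported in this comment, the overall `exact` and `f1` scores should be the example-weighted mean of the `HasAns` and `NoAns` subsets. The sketch below uses the figures copied from the results above; the `combine` helper is a hypothetical name introduced only for this illustration.

```python
# Values copied from the results reported above.
results = {
    "exact": 84.52792049187232,
    "f1": 88.0216698977779,
    "total": 11873,
    "HasAns_exact": 80.90418353576248,
    "HasAns_f1": 87.9017015344667,
    "HasAns_total": 5928,
    "NoAns_exact": 88.1412952060555,
    "NoAns_f1": 88.1412952060555,
    "NoAns_total": 5945,
}


def combine(results, metric):
    """Example-weighted mean of the HasAns and NoAns scores for a metric."""
    has_ans = results[f"HasAns_{metric}"] * results["HasAns_total"]
    no_ans = results[f"NoAns_{metric}"] * results["NoAns_total"]
    return (has_ans + no_ans) / results["total"]


for metric in ("exact", "f1"):
    # Print the recomputed score next to the reported one for comparison.
    print(metric, combine(results, metric), results[metric])
```

The recomputed and reported values agree to within floating-point rounding, which suggests the subset totals and scores are internally consistent.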

@julien-c changed the title from "run_squa with roberta" to "run_squad with roberta" on Dec 20, 2019
@thomwolf
Member

Really nice, thanks a lot @erenup

@thomwolf thomwolf merged commit 18601c3 into huggingface:master Dec 21, 2019