McQueen is an MCQ-solving library that lets researchers and developers train and test several existing MCQ solvers on custom textual MCQ datasets. Related paper: Exploring ways to incorporate additional knowledge to improve Natural Language Commonsense Question Answering.
| Model | Description |
| --- | --- |
| bert-mcq-concat | A simple BERT-based MCQ solver that scores each choice string with BERT against a premise string. File: pytorch_transformers/models/hf_bert_mcq_concat.py |
| bert-mcq-parallel-max | A simple BERT-based MCQ solver that scores each choice string with BERT against an array of premise strings and takes the maximum score as the confidence value for the choice. File: pytorch_transformers/models/hf_bert_mcq_parallel.py |
| bert-mcq-weighted-sum | A simple BERT-based MCQ solver that, for each choice string, runs BERT on every (premise string, choice string) pair and performs a weighted sum over the pooled [CLS] token vectors to score the choice. File: pytorch_transformers/models/hf_bert_weighted_sum.py |
| bert-mac | Coming soon. |
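The three available solvers differ mainly in how per-premise evidence is aggregated into a single choice score. As a model-free sketch of that aggregation step (plain Python, with a toy word-overlap scorer standing in for BERT logits; for bert-mcq-weighted-sum the real model weights pooled [CLS] vectors, whereas scalar scores stand in for those vectors here):

```python
import math

def concat_score(score_fn, premises, choice):
    # bert-mcq-concat: join all premises into one string and score once.
    return score_fn(" ".join(premises), choice)

def parallel_max_score(score_fn, premises, choice):
    # bert-mcq-parallel-max: score the choice against each premise
    # separately and keep the maximum as the choice's confidence.
    return max(score_fn(p, choice) for p in premises)

def weighted_sum_score(score_fn, premises, choice):
    # bert-mcq-weighted-sum (simplified): softmax-weight the per-premise
    # scores and sum them.
    scores = [score_fn(p, choice) for p in premises]
    weights = [math.exp(s) for s in scores]
    total = sum(weights)
    return sum(w / total * s for w, s in zip(weights, scores))

# Toy stand-in for a BERT scorer: word overlap between premise and choice.
def toy_score(premise, choice):
    return len(set(premise.split()) & set(choice.split()))

premises = ["the cat sat on the mat", "dogs chase cats"]
choice = "the cat"
print(parallel_max_score(toy_score, premises, choice))  # max over premises: 2
```

A higher aggregate score means the choice is better supported by the premises; the solver picks the argmax over choices.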
To run the HuggingFace models:
- conda create --name env-name python=3.6
- source activate env-name
- pip install pytorch-transformers
- git clone the repo
The training data should be in the JSONL format described in the documentation.
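The authoritative schema lives in the repo's documentation. Purely as an illustration of the JSONL layout (every field name below is an assumption, not the library's actual schema), a record might pair per-choice premises with the choice strings and a gold label:

```python
import json

# Hypothetical record layout -- field names are illustrative assumptions.
record = {
    "premises": [["premise text for choice 1"], ["premise text for choice 2"]],
    "choices": ["choice 1", "choice 2"],
    "gold_label": 0,
}

# Each line of a .jsonl file is one independent JSON object.
line = json.dumps(record)
print(line)
```

One such JSON object per line, with no enclosing array, is what distinguishes JSONL from plain JSON.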
There are three models as of now, namely bert-mcq-concat, bert-mcq-parallel-max, and bert-mcq-weighted-sum; the sample commands below also include a bert-mcq-simple-sum variant.
Here is a sample command to train the bert-mcq-concat model with bert-large-uncased-whole-word-masking:
- Running bert-mcq-concat
```
nohup python hf_trainer.py --training_data_path mcq_abductive_train.jsonl --validation_data_path mcq_abductive_dev.jsonl --mcq_model bert-mcq-concat --bert_model bert-large-uncased-whole-word-masking --output_dir ./serdir_bertlgww_concat_2e5_abd --num_train_epochs 4 --train_batch_size 64 --do_eval --do_train --max_seq_length 68 --do_lower_case --gradient_accumulation_steps 1 --learning_rate 2e-6 --weight_decay 0.009 --eval_freq 1000 --warmup_steps 250 &> bertlgww_concat_2e5_009_abd.out
```
- Running bert-mcq-parallel-max
```
nohup python hf_trainer.py --training_data_path mcq_abductive_train.jsonl --validation_data_path mcq_abductive_dev.jsonl --mcq_model bert-mcq-parallel-max --bert_model bert-large-uncased-whole-word-masking --output_dir ./serdir_bertlgww_concat_2e5_abd --num_train_epochs 4 --train_batch_size 64 --do_eval --do_train --max_seq_length 68 --do_lower_case --gradient_accumulation_steps 1 --learning_rate 2e-6 --weight_decay 0.009 --eval_freq 1000 --warmup_steps 250
```
- Running the simple-sum model
```
nohup python hf_trainer.py --training_data_path mcq_sc_sim_train.jsonl --validation_data_path mcq_sc_sim_dev.jsonl --mcq_model bert-mcq-simple-sum --bert_model bert-large-uncased-whole-word-masking --output_dir ./serdir_bertlgww_simple_sum_4e6_001_social --num_train_epochs 4 --train_batch_size 16 --do_eval --do_train --max_seq_length 68 --do_lower_case --gradient_accumulation_steps 1 --eval_freq 500 --learning_rate 4e-6 --warmup_steps 400 --weight_decay 0.001
```
- Running the weighted-sum model with tied weights
```
nohup python hf_trainer.py --training_data_path mcq_sc_sim_train.jsonl --validation_data_path mcq_sc_sim_dev.jsonl --mcq_model bert-mcq-weighted-sum --tie_weights_weighted_sum --bert_model bert-large-uncased-whole-word-masking --output_dir ./serdir_bertlgww_simple_sum_4e6_001_social --num_train_epochs 4 --train_batch_size 16 --do_eval --do_train --max_seq_length 68 --do_lower_case --gradient_accumulation_steps 1 --eval_freq 500 --learning_rate 4e-6 --warmup_steps 400 --weight_decay 0.001
```
- Running the weighted-sum model without tied weights
```
nohup python hf_trainer.py --training_data_path mcq_sc_sim_train.jsonl --validation_data_path mcq_sc_sim_dev.jsonl --mcq_model bert-mcq-weighted-sum --bert_model bert-large-uncased-whole-word-masking --output_dir ./serdir_bertlgww_simple_sum_4e6_001_social --num_train_epochs 4 --train_batch_size 16 --do_eval --do_train --max_seq_length 68 --do_lower_case --gradient_accumulation_steps 1 --eval_freq 500 --learning_rate 4e-6 --warmup_steps 400 --weight_decay 0.001
```
Please see hf_trainer.py for the meaning and default value of each parameter.
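The flags used in the sample commands map onto a fairly standard argparse setup. As a rough sketch only (not the actual hf_trainer.py code; the default values below are placeholders, and the real script is the authoritative source):

```python
import argparse

# Illustrative subset of the flags seen in the sample commands above.
# Defaults here are placeholder assumptions, not hf_trainer.py's values.
parser = argparse.ArgumentParser()
parser.add_argument("--mcq_model", type=str, required=True)
parser.add_argument("--bert_model", type=str, required=True)
parser.add_argument("--num_train_epochs", type=int, default=3)
parser.add_argument("--train_batch_size", type=int, default=32)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
# Boolean switch: present on the command line means True.
parser.add_argument("--tie_weights_weighted_sum", action="store_true")

args = parser.parse_args([
    "--mcq_model", "bert-mcq-concat",
    "--bert_model", "bert-large-uncased-whole-word-masking",
    "--learning_rate", "2e-6",
])
print(args.learning_rate)  # 2e-06
```

Note that --tie_weights_weighted_sum is a presence flag rather than a key-value argument, which matches how it appears in the tied-weights command above.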
