Project author: charles9n

Project description:
a sklearn wrapper for Google's BERT model
Primary language: Jupyter Notebook
Project URL: git://github.com/charles9n/bert-sklearn.git
Created: 2019-02-18T03:48:10Z
Project community: https://github.com/charles9n/bert-sklearn

License: Apache License 2.0



scikit-learn wrapper to finetune BERT

A scikit-learn wrapper to finetune Google’s BERT model for text and token sequence tasks based on the huggingface pytorch port.

  • Includes configurable MLP as final classifier/regressor for text and text pair tasks
  • Includes token sequence classifier for NER, PoS, and chunking tasks
  • Includes SciBERT and BioBERT pretrained models for scientific and biomedical domains.

Try in Google Colab!

installation

requires python >= 3.5 and pytorch >= 0.4.1

  git clone -b master https://github.com/charles9n/bert-sklearn
  cd bert-sklearn
  pip install .
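
As a quick sanity check, the package should import cleanly after installation (a trivial sketch):

  # smoke test: the main estimators should be importable after install
  import bert_sklearn
  from bert_sklearn import BertClassifier, BertRegressor, load_model
  print('bert-sklearn imported OK')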

basic operation

model.fit(X, y), i.e. finetune BERT

  • X: list, pandas dataframe, or numpy array of text, text pairs, or token lists

  • y: list, pandas dataframe, or numpy array of labels/targets
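
For illustration, the three input forms might look like the following (a minimal sketch; the texts and labels are invented):

  # single texts with class labels
  X = ['this movie was great', 'this movie was terrible']
  y = ['pos', 'neg']

  # text pairs, e.g. for similarity or entailment tasks
  X_pairs = [['a man is playing guitar', 'someone plays an instrument']]
  y_pairs = ['entailment']

  # token lists with per-token tags, e.g. for NER
  X_tokens = [['John', 'lives', 'in', 'Paris']]
  y_tokens = [['B-PER', 'O', 'O', 'B-LOC']]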

  from bert_sklearn import BertClassifier
  from bert_sklearn import BertRegressor
  from bert_sklearn import BertTokenClassifier
  from bert_sklearn import load_model

  # define model
  model = BertClassifier()          # text/text pair classification
  # model = BertRegressor()        # text/text pair regression
  # model = BertTokenClassifier()  # token sequence classification

  # finetune model
  model.fit(X_train, y_train)

  # make predictions
  y_pred = model.predict(X_test)

  # make probability predictions
  y_pred = model.predict_proba(X_test)

  # score model on test data
  model.score(X_test, y_test)

  # save model to disk
  savefile = '/data/mymodel.bin'
  model.save(savefile)

  # load model from disk
  new_model = load_model(savefile)

  # do stuff with new model
  new_model.score(X_test, y_test)

See demo notebook.

model options

  # try different options...
  model.bert_model = 'bert-large-uncased'
  model.num_mlp_layers = 3
  model.max_seq_length = 196
  model.epochs = 4
  model.learning_rate = 4e-5
  model.gradient_accumulation_steps = 4

  # finetune
  model.fit(X_train, y_train)

  # do stuff...
  model.score(X_test, y_test)
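
Since the estimators follow the scikit-learn convention, the same options can equivalently be passed to the constructor (a sketch reusing the settings above):

  model = BertClassifier(bert_model='bert-large-uncased',
                         num_mlp_layers=3,
                         max_seq_length=196,
                         epochs=4,
                         learning_rate=4e-5,
                         gradient_accumulation_steps=4)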

See options

hyperparameter tuning

  from sklearn.model_selection import GridSearchCV

  params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

  # wrap classifier in GridSearchCV
  clf = GridSearchCV(BertClassifier(validation_fraction=0),
                     params,
                     scoring='accuracy',
                     verbose=True)

  # fit gridsearch
  clf.fit(X_train, y_train)
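
After the search finishes, the standard GridSearchCV attributes are available:

  # inspect the best hyperparameters and the best cross-validated accuracy
  print(clf.best_params_)
  print(clf.best_score_)
  best_model = clf.best_estimator_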

See demo_tuning_hyperparameters notebook.

GLUE datasets

The train and dev data sets from the GLUE (General Language Understanding Evaluation) benchmark were used with the bert-base-uncased model and compared against the results reported in the Google paper and on the GLUE leaderboard.

                         MNLI(m/mm)  QQP   QNLI  SST-2  CoLA  STS-B  MRPC  RTE
BERT base (leaderboard)  84.6/83.4   89.2  90.1  93.5   52.1  87.1   84.8  66.4
bert-sklearn             83.7/83.9   90.2  88.6  92.32  58.1  89.7   86.8  64.6

Individual runs can be found here.

CoNLL-2003 Named Entity Recognition (NER)

NER results for the CoNLL-2003 shared task:

              dev f1  test f1
BERT paper    96.4    92.4
bert-sklearn  96.04   91.97

Span-level stats on the test set:

  processed 46666 tokens with 5648 phrases; found: 5740 phrases; correct: 5173.
  accuracy: 98.15%; precision: 90.12%; recall: 91.59%; FB1: 90.85
       LOC: precision: 92.24%; recall: 92.69%; FB1: 92.46  1676
      MISC: precision: 78.07%; recall: 81.62%; FB1: 79.81  734
       ORG: precision: 87.64%; recall: 90.07%; FB1: 88.84  1707
       PER: precision: 96.00%; recall: 96.35%; FB1: 96.17  1623

See the ner_english notebook for a demo using the 'bert-base-cased' model.
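
For reference, a token-level run on CoNLL-style data might look like this sketch, reusing BertTokenClassifier from the basic operation section (X_train/y_train are placeholders for tokenized sentences and their tag sequences):

  from bert_sklearn import BertTokenClassifier

  # X: list of token lists, y: list of matching tag lists
  model = BertTokenClassifier(bert_model='bert-base-cased')
  model.fit(X_train, y_train)
  y_pred = model.predict(X_test)  # per-token tag predictions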

NCBI Biomedical NER

NER results using bert-sklearn with SciBERT and BioBERT on the NCBI Disease Corpus named entity recognition task.

The previous SOTA for this task is an f1 of 87.34 on the test set.

                          test f1 (bert-sklearn)  test f1 (from papers)
BERT base cased           85.09                   85.49
SciBERT basevocab cased   88.29                   86.91
SciBERT scivocab cased    87.73                   86.45
BioBERT pubmed_v1.0       87.86                   87.38
BioBERT pubmed_pmc_v1.0   88.26                   89.36
BioBERT pubmed_v1.1       87.26                   NA

See ner_NCBI_disease_BioBERT_SciBERT notebook for a demo using SciBERT and BioBERT models.

See the SciBERT paper and BioBERT paper for more info on the respective models.
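
Switching to one of these domain models should only require changing the bert_model option; the identifier below is illustrative, so consult the supported model list for the exact names:

  # hypothetical identifier -- check the options list for the exact model names
  model = BertTokenClassifier(bert_model='scibert-scivocab-cased')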

Other examples

tests

Run tests with pytest:

  python -m pytest -sv tests/

references