Project author: sberbank-ai

Project description:
BERT-NER (nert-bert) with google bert https://github.com/google-research.
Primary language: Jupyter Notebook
Project address: git://github.com/sberbank-ai/ner-bert.git
Created: 2018-11-21T10:15:33Z
Project community: https://github.com/sberbank-ai/ner-bert

License: MIT License

0. Papers

There are two solutions based on this architecture.

  1. BSNLP 2019 ACL workshop: solution and paper for the multilingual shared task.
  2. The second-place solution of the Dialogue AGRR-2019 task and paper.

1. Description

This repository contains a solution to the NER task based on a PyTorch reimplementation of Google’s TensorFlow repository for the BERT model, released together with the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.

This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular Google’s pre-trained models).

The old version is in the “old” branch.

2. Usage

2.1 Create data

```python
from modules.data import bert_data

data = bert_data.LearnData.create(
    train_df_path=train_df_path,
    valid_df_path=valid_df_path,
    idx2labels_path="/path/to/vocab",
    clear_cache=True
)
```
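The `train_df_path` and `valid_df_path` arguments point at the training and validation data files. As a purely illustrative sketch of preparing such a file — the column names and layout below are an assumption, not the repository’s documented schema, so check the example notebooks for the real format — one row per sentence with whitespace-separated IO labels could be written with the standard library:

```python
import csv

# Hypothetical layout (an assumption, not the repository's documented schema):
# one sentence per row, whitespace-tokenized text, one IO label per token.
rows = [
    {"labels": "I-PER I-PER O O", "text": "John Smith lives here"},
    {"labels": "O O I-LOC", "text": "Welcome to Moscow"},
]

with open("train.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["labels", "text"])
    writer.writeheader()
    writer.writerows(rows)
```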

2.2 Create model

```python
from modules.models.bert_models import BERTBiLSTMAttnCRF

model = BERTBiLSTMAttnCRF.create(len(data.train_ds.idx2label))
```

2.3 Create Learner

```python
from modules.train.train import NerLearner

num_epochs = 100
learner = NerLearner(
    model, data, "/path/for/save/best/model",
    t_total=num_epochs * len(data.train_dl))
```

2.4 Predict

```python
from modules.data.bert_data import get_data_loader_for_predict

learner.load_model()
dl = get_data_loader_for_predict(data, df_path="/path/to/df/for/predict")
preds = learner.predict(dl)
```

2.5 Evaluate

```python
from sklearn_crfsuite.metrics import flat_classification_report
from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer
from modules.analyze_utils.plot_metrics import get_bert_span_report
from modules.analyze_utils.main_metrics import precision_recall_f1

pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])

tokens_report = flat_classification_report(true_labels, pred_labels, digits=4)
print(tokens_report)

results = precision_recall_f1(true_labels, pred_labels)
```

3. Results

We did not search for the best hyperparameters and obtained the following results.

| Model | Data set | Dev F1 tok | Dev F1 span | Test F1 tok | Test F1 span |
|---|---|---|---|---|---|
| **OURS** | | | | | |
| M-BERTCRF-IO | FactRuEval | - | - | 0.8543 | 0.8409 |
| M-BERTNCRF-IO | FactRuEval | - | - | 0.8637 | 0.8516 |
| M-BERTBiLSTMCRF-IO | FactRuEval | - | - | 0.8835 | 0.8718 |
| M-BERTBiLSTMNCRF-IO | FactRuEval | - | - | 0.8632 | 0.8510 |
| M-BERTAttnCRF-IO | FactRuEval | - | - | 0.8503 | 0.8346 |
| M-BERTBiLSTMAttnCRF-IO | FactRuEval | - | - | 0.8839 | 0.8716 |
| M-BERTBiLSTMAttnNCRF-IO | FactRuEval | - | - | 0.8807 | 0.8680 |
| M-BERTBiLSTMAttnCRF-fit_BERT-IO | FactRuEval | - | - | 0.8823 | 0.8709 |
| M-BERTBiLSTMAttnNCRF-fit_BERT-IO | FactRuEval | - | - | 0.8583 | 0.8456 |
| BERTBiLSTMCRF-IO | CoNLL-2003 | 0.9629 | - | 0.9221 | - |
| B-BERTBiLSTMCRF-IO | CoNLL-2003 | 0.9635 | - | 0.9229 | - |
| B-BERTBiLSTMAttnCRF-IO | CoNLL-2003 | 0.9614 | - | 0.9237 | - |
| B-BERTBiLSTMAttnNCRF-IO | CoNLL-2003 | 0.9631 | - | 0.9249 | - |
| **Current SOTA** | | | | | |
| DeepPavlov-RuBERT-NER | FactRuEval | - | - | - | 0.8266 |
| CSE | CoNLL-2003 | - | - | 0.931 | - |
| BERT-LARGE | CoNLL-2003 | 0.966 | - | 0.928 | - |
| BERT-BASE | CoNLL-2003 | 0.964 | - | 0.924 | - |
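The table reports both token-level (“F1 tok”) and span-level (“F1 span”) scores. A self-contained sketch — not the repository’s code, using toy IO-encoded labels — shows why the two can differ: a prediction that gets most tokens of an entity right still misses the span unless the boundaries match exactly.

```python
# Toy illustration of token-level vs span-level F1 for IO labels ("I-TYPE"/"O").

def io_spans(labels):
    """Extract (label, start, end) spans from an IO label sequence."""
    spans, start = [], None
    for i, lab in enumerate(labels + ["O"]):  # sentinel closes a trailing span
        if start is not None and lab != labels[start]:
            spans.append((labels[start], start, i))
            start = None
        if lab != "O" and start is None:
            start = i
    return spans

def f1(tp, n_pred, n_true):
    """Micro-averaged F1 from raw counts."""
    if tp == 0:
        return 0.0
    p, r = tp / n_pred, tp / n_true
    return 2 * p * r / (p + r)

true = ["I-PER", "I-PER", "O", "I-LOC"]
pred = ["I-PER", "O",     "O", "I-LOC"]  # one token of the PER span is missed

# Token level: 2 correct entity tokens of 2 predicted and 3 gold -> F1 ≈ 0.8
tok_f1 = f1(sum(t == p != "O" for t, p in zip(true, pred)),
            sum(p != "O" for p in pred),
            sum(t != "O" for t in true))

# Span level: the shortened PER span no longer matches exactly -> F1 = 0.5
ts, ps = set(io_spans(true)), set(io_spans(pred))
span_f1 = f1(len(ts & ps), len(ps), len(ts))
```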