Author: SaeedNajafi

Description:
Implementation of the Optimal Completion Distillation for Sequence Labeling
Language: Python
Repository: git://github.com/SaeedNajafi/pytorch-ocd.git
Created: 2019-03-07T19:47:34Z
Project page: https://github.com/SaeedNajafi/pytorch-ocd

License: MIT License




Optimal Completion Distillation (OCD) Training

Implementation of the Optimal Completion Distillation for Sequence Labeling

Paper: https://arxiv.org/abs/1810.01398
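As described in the paper, OCD trains on prefixes sampled from the model itself. At each step the target is the set of next tokens that begin an optimal completion of the prefix generated so far, where "optimal" means the finished sequence has the minimum edit distance to the gold sequence. The sketch below computes that set with the standard edit-distance dynamic program. It is a plain-Python illustration of that reading of the paper, not the code in this repository; `optimal_next_tokens` and its arguments are hypothetical names chosen for the example.

```python
from typing import Sequence, Set


def optimal_next_tokens(
    prefix: Sequence[int], target: Sequence[int], end_symbol_id: int
) -> Set[int]:
    """Return the tokens that can start an optimal completion of `prefix`.

    Hypothetical helper illustrating the paper's idea; not this repository's API.
    An optimal completion minimizes the edit distance between the finished
    sequence and `target`; `end_symbol_id` stands for end-of-sequence.
    """
    n = len(target)
    # row[j] = edit distance between the current prefix and target[:j].
    # For the empty prefix this is j insertions.
    row = list(range(n + 1))
    for token in prefix:
        new_row = [row[0] + 1]  # target[:0] is empty: delete every prefix token
        for j in range(1, n + 1):
            substitution = row[j - 1] + (token != target[j - 1])
            deletion = row[j] + 1            # drop `token`
            insertion = new_row[j - 1] + 1   # insert target[j - 1]
            new_row.append(min(substitution, deletion, insertion))
        row = new_row
    best = min(row)
    # Any target[:j] closest to the prefix can be completed by appending
    # target[j:], so target[j] is an optimal next token; j == n means "stop".
    return {
        target[j] if j < n else end_symbol_id
        for j in range(n + 1)
        if row[j] == best
    }
```

For the gold sequence `[1, 2, 3]` and the generated prefix `[1, 5]`, both `2` (treating `5` as an extra token) and `3` (treating `5` as a substitution for `2`) are optimal, so `optimal_next_tokens([1, 5], [1, 2, 3], end_symbol_id=9)` returns `{2, 3}`.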

Requirements

Python 3, PyTorch 1.0.0

Install

  python3 -m venv env
  source env/bin/activate
  pip3 install .

How to use?

See the implementation at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/ocd/__init__.py#L50
and the tests at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/tests/test_ocd.py#L132

  from ocd import OCD

  # Vocabulary of 10 tokens; id 9 is the end-of-sequence symbol.
  ocd_trainer = OCD(vocab_size=10, end_symbol_id=9)
  ...  # the model produces model_scores: a score for every output token at every decoding step
  ocd_loss = ocd_trainer(model_scores, gold_output_sequence)
  ...  # backpropagate through ocd_loss
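The call `ocd_trainer(model_scores, gold_output_sequence)` returns a single loss to backpropagate. As a rough picture of what such a loss measures, the sketch below assumes the zero-temperature target from the paper (equal probability on every optimal next token, zero elsewhere), scores each decoding step with KL(target || model softmax), and averages over steps. It works on plain Python lists rather than PyTorch tensors. The function names, the input layout (one list of scores per step and one set of optimal token ids per step), and the averaging over steps are assumptions, not this repository's API or reduction.

```python
import math
from typing import List, Sequence, Set


def log_softmax(scores: Sequence[float]) -> List[float]:
    # Subtract the max score for numerical stability.
    top = max(scores)
    log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_norm for s in scores]


def ocd_step_loss(scores: Sequence[float], optimal_tokens: Set[int]) -> float:
    """KL(target || model) for one decoding step (hypothetical helper).

    The assumed target puts equal probability on each optimal next token.
    """
    log_probs = log_softmax(scores)
    weight = 1.0 / len(optimal_tokens)
    return sum(weight * (math.log(weight) - log_probs[a]) for a in optimal_tokens)


def ocd_sequence_loss(
    step_scores: Sequence[Sequence[float]], optimal_sets: Sequence[Set[int]]
) -> float:
    # Assumed reduction: average the per-step losses over the generated sequence.
    losses = [ocd_step_loss(s, o) for s, o in zip(step_scores, optimal_sets)]
    return sum(losses) / len(losses)
```

With uniform scores over a vocabulary of 4 tokens and a single optimal token, the step loss is `log 4`; if every token is optimal it is 0. In PyTorch the same computation would apply `torch.nn.functional.log_softmax` to the score tensor so that gradients flow back into the model.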