Project author: CyberZHG

Project description: Transformer-XL with checkpoint loader

Language: Python

Project address: git://github.com/CyberZHG/keras-transformer-xl.git

Created: 2019-06-20T10:29:44Z

Project community: https://github.com/CyberZHG/keras-transformer-xl

License: MIT License


Keras Transformer-XL



Unofficial implementation of Transformer-XL.

Install

```bash
pip install keras-transformer-xl
```

Usage

Load Pretrained Weights

Several configuration files can be found in the info directory.

```python
import os
from keras_transformer_xl import load_trained_model_from_checkpoint

checkpoint_path = 'foo/bar/sota/enwiki8'
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'model.ckpt'),
)
model.summary()
```

About IO

The generated model has two inputs; the second input holds the lengths of the memories.
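As a sketch of what those two inputs look like (the shapes below are illustrative, chosen to match the dummy example further down, and are not mandated by the library), the first input carries the token ids of the current segment and the second reports how much memory each sample already has:

```python
import numpy as np

# Illustrative shapes: 3 samples, 10 target tokens per segment.
batch_size, target_len = 3, 10

# First input: token ids for the current segment.
token_input = np.ones((batch_size, target_len))

# Second input: the number of cached memory steps per sample.
# On the very first segment there is no memory yet, so the lengths are zero.
memory_length_input = np.zeros((batch_size, 1))

print(token_input.shape, memory_length_input.shape)
```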

You can use MemorySequence wrapper for training and prediction:

```python
from tensorflow import keras
import numpy as np
from keras_transformer_xl import MemorySequence, build_transformer_xl


class DummySequence(keras.utils.Sequence):

    def __init__(self):
        pass

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return np.ones((3, 5 * (index + 1))), np.ones((3, 5 * (index + 1), 3))


model = build_transformer_xl(
    units=4,
    embed_dim=4,
    hidden_dim=4,
    num_token=3,
    num_block=3,
    num_head=2,
    batch_size=3,
    memory_len=20,
    target_len=10,
)
seq = MemorySequence(
    model=model,
    sequence=DummySequence(),
    target_len=10,
)
model.predict(seq, verbose=True)
```
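Under the hood, Transformer-XL reuses hidden states from earlier segments as memory, keeping only the most recent `memory_len` steps. A minimal NumPy sketch of that caching rule (the `update_memory` helper is illustrative, not part of the library; `MemorySequence` handles this bookkeeping, along with the memory-length input, for you):

```python
import numpy as np

memory_len, units = 20, 4
memory = np.zeros((0, units))  # memory starts empty

def update_memory(memory, new_hidden):
    # Append the new segment's hidden states, then keep
    # only the last `memory_len` time steps.
    merged = np.concatenate([memory, new_hidden], axis=0)
    return merged[-memory_len:]

# Simulate three segments of 10 steps each.
for step in range(3):
    hidden = np.full((10, units), float(step))
    memory = update_memory(memory, hidden)

# The cache now holds the 20 most recent steps: segments 1 and 2.
print(memory.shape)
```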