项目作者: hwaking

项目描述 :
this project used CNN, RNN and Attention in text classification
高级语言: Python
项目地址: git://github.com/hwaking/deeplearning-text-classification.git
创建时间: 2019-01-08T13:09:13Z
项目社区:https://github.com/hwaking/deeplearning-text-classification

开源协议:MIT License

下载


基于 dennybritz’s 项目 cnn-text-classification-tf, 添加RNN+Attention 实现,同时对代码进行了简化和修改

Requirements

  • Python 3
  • Tensorflow > 0.12
  • Numpy

Training

  • 首先在config.py中设置模型参数,具体参数含义如下:
  1. config parameters:
  2. # 常规参数
  3. -- learning_rate 学习率
  4. -- training_steps 迭代次数
  5. -- batch_size 批数据量
  6. -- display_step 多少次打印一次结果
  7. -- evaluate_every 多少次评估一次模型
  8. -- checkpoint_every 多少次保存一次模型
  9. -- num_checkpoints 保存模型个数
  10. -- early_stop_steps 提前停止
  11. # 网络参数
  12. -- embedding_size 词向量embedding长度
  13. -- num_hidden 隐藏神经元个数
  14. -- num_classes 类别数目
  15. -- dropout_keep_prob dropout比例
  16. -- l2_reg_lambda l2正则化强度
  17. # CNN 网络参数
  18. -- filter_sizes 卷积核规格
  19. # RNN 网络参数
  20. -- network 网络类型lstm/gru
  21. -- bi_drection 是否选择双向网络
  22. -- timesteps = 56 序列长度
  23. -- attention_size attention神经元个数
  24. # 硬件设置
  25. -- allow_soft_placement = True
  26. -- log_device_placement = False
  27. # 数据路径
  28. -- dev_sample_percentage = 0.1 验证集比例
  29. -- positive_data_file = "./data/rt-polaritydata/rt-polarity.pos"
  30. -- negative_data_file = "./data/rt-polaritydata/rt-polarity.neg"

Train:

  1. # text cnn training command
  2. python train_cnn.py
  3. # text rnn training command
  4. python train_rnn.py

Evaluating

‘’’
待加入
‘’’

References