经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
【NLP】使用bert
来源:cnblogs  作者:水奈樾  时间:2019/2/14 9:09:49  对本文有异议

# 参考 https://blog.csdn.net/luoyexuge/article/details/84939755 小做改动

需要:

  github上下载bert的代码:https://github.com/google-research/bert

  下载google训练好的中文语料模型:https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

使用:

  使用bert,其实是使用几个checkpoint(ckpt)文件。上面下载的zip是google训练好的bert,我们可以在那个zip内的ckpt文件基础上继续训练,获得更贴近具体任务的ckpt文件。

 如果是直接使用训练好的ckpt文件(就是bert模型),只需如下代码,定义model,获得model的值

  1. from bert import modeling
    # 使用数据加载BertModel,获取对应的字embedding
  2. model = modeling.BertModel(
  3. config=bert_config,
  4. is_training=is_training,
  5. input_ids=input_ids,
  6. input_mask=input_mask,
  7. token_type_ids=segment_ids,
  8. use_one_hot_embeddings=use_one_hot_embeddings
  9. )
  10. # 获取对应的embedding 输入数据[batch_size, seq_length, embedding_size]
  11. embedding = model.get_sequence_output()

 

这里的bert_config 是之前定义的bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file);输入是input_ids, input_mask, segment_ids三个向量;还有两个设置is_training(False), use_one_hot_embedding(False),这样的设置还有很多,这里只列举这两个。。

关于FLAGS,需要提到TensorFlow的flags,相当于配置运行变量,设置如下:

  1. import tensorflow as tf
  2. flags = tf.flags
  3. FLAGS = flags.FLAGS
  4. # 预训练的中文model路径和项目路径
  5. bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
  6. root_path = '/home/xiangbo_wang/xiangbo/NER/BERT-BiLSTM-CRF-NER'
  7.  
  8. # 设置bert_config_file
  9. flags.DEFINE_string(
  10. "bert_config_file", os.path.join(bert_path, 'bert_config.json'),
  11. "The config json file corresponding to the pre-trained BERT model."
  12. )

 关于输入的三个向量,具体内容可以参照之前的博客https://www.cnblogs.com/rucwxb/p/10277217.html

input_ids, segment_ids 分别是 token embedding, segment embedding

position embedding会自动生成

input_mask 是input中需要mask的位置,本来是随机取一部分,这里的做法是把全部输入位置都mask住。

获得输入的这三个向量的方式如下:

 

  1. # 获得三个向量的函数
  2. def inputs(vectors,maxlen=10):
  3. length=len(vectors)
  4. if length>=maxlen:
  5. return vectors[0:maxlen],[1]*maxlen,[0]*maxlen
  6. else:
  7. input=vectors+[0]*(maxlen-length)
  8. mask=[1]*length+[0]*(maxlen-length)
  9. segment=[0]*maxlen
  10. return input,mask,segment
  11. # 测试的句子
  12. text = request.args.get('text')
  13. vectors = [di.get("[CLS]")] + [di.get(i) if i in di else di.get("[UNK]") for i in list(text)] + [di.get("[SEP]")]
  14. # 转成1*maxlen的向量
  15. input, mask, segment = inputs(vectors)
  16. input_ids = np.reshape(np.array(input), [1, -1])
  17. input_mask = np.reshape(np.array(mask), [1, -1])
  18. segment_ids = np.reshape(np.array(segment), [1, -1])

 

最后是将变量输入模型获得最终的bert向量:

  1. # 定义输入向量形状
  2. input_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_ids_p")
  3. input_mask_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="input_mask_p")
  4. segment_ids_p=tf.placeholder(shape=[None,None],dtype=tf.int32,name="segment_ids_p")
  5. model = modeling.BertModel(
  6. config=bert_config,
  7. is_training=is_training,
  8. input_ids=input_ids_p,
  9. input_mask=input_mask_p,
  10. token_type_ids=segment_ids_p,
  11. use_one_hot_embeddings=use_one_hot_embeddings
  12. )
  13. # 载入预训练模型
  14. restore_saver = tf.train.Saver()
  15. restore_saver.restore(sess, init_checkpoint)
  16. # 一个[batch_size, seq_length, embedding_size]大小的向量
  17. embedding = tf.squeeze(model.get_sequence_output())
  18. # 运行结果
  19. ret=sess.run(embedding,feed_dict={"input_ids_p:0":input_ids,"input_mask_p:0":input_mask,"segment_ids_p:0":segment_ids})

完整可运行代码如下:

  1. import tensorflow as tf
  2. from bert import modeling
  3. import collections
  4. import os
  5. import numpy as np
  6. import json
  7. flags = tf.flags
  8. FLAGS = flags.FLAGS
  9. bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
  10. flags.DEFINE_string(
  11. 'bert_config_file', os.path.join(bert_path, 'bert_config.json'),
  12. 'config json file corresponding to the pre-trained BERT model.'
  13. )
  14. flags.DEFINE_string(
  15. 'bert_vocab_file', os.path.join(bert_path,'vocab.txt'),
  16. 'the config vocab file',
  17. )
  18. flags.DEFINE_string(
  19. 'init_checkpoint', os.path.join(bert_path,'bert_model.ckpt'),
  20. 'from a pre-trained BERT get an initial checkpoint',
  21. )
  22. flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
  23. def convert2Uni(text):
  24. if isinstance(text, str):
  25. return text
  26. elif isinstance(text, bytes):
  27. return text.decode('utf-8','ignore')
  28. else:
  29. print(type(text))
  30. print('####################wrong################')
  31. def load_vocab(vocab_file):
  32. vocab = collections.OrderedDict()
  33. vocab.setdefault('blank', 2)
  34. index = 0
  35. with open(vocab_file) as reader:
  36. # with tf.gfile.GFile(vocab_file, 'r') as reader:
  37. while True:
  38. tmp = reader.readline()
  39. if not tmp:
  40. break
  41. token = convert2Uni(tmp)
  42. token = token.strip()
  43. vocab[token] = index
  44. index+=1
  45. return vocab
  46. def inputs(vectors, maxlen = 50):
  47. length = len(vectors)
  48. if length > maxlen:
  49. return vectors[0:maxlen], [1]*maxlen, [0]*maxlen
  50. else:
  51. input = vectors+[0]*(maxlen-length)
  52. mask = [1]*length + [0]*(maxlen-length)
  53. segment = [0]*maxlen
  54. return input, mask, segment
  55. def response_request(text):
  56. vectors = [dictionary.get('[CLS]')] + [dictionary.get(i) if i in dictionary else dictionary.get('[UNK]') for i in list(text)] + [dictionary.get('[SEP]')]
  57. input, mask, segment = inputs(vectors)
  58. input_ids = np.reshape(np.array(input), [1, -1])
  59. input_mask = np.reshape(np.array(mask), [1, -1])
  60. segment_ids = np.reshape(np.array(segment), [1, -1])
  61. embedding = tf.squeeze(model.get_sequence_output())
  62. rst = sess.run(embedding, feed_dict={'input_ids_p:0':input_ids, 'input_mask_p:0':input_mask, 'segment_ids_p:0':segment_ids})
  63. return json.dumps(rst.tolist(), ensure_ascii=False)
  64. dictionary = load_vocab(FLAGS.bert_vocab_file)
  65. init_checkpoint = FLAGS.init_checkpoint
  66. sess = tf.Session()
  67. bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  68. input_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_ids_p')
  69. input_mask_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_mask_p')
  70. segment_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='segment_ids_p')
  71. model = modeling.BertModel(
  72. config = bert_config,
  73. is_training = FLAGS.use_tpu,
  74. input_ids = input_ids_p,
  75. input_mask = input_mask_p,
  76. token_type_ids = segment_ids_p,
  77. use_one_hot_embeddings = FLAGS.use_tpu,
  78. )
  79. print('####################################')
  80. restore_saver = tf.train.Saver()
  81. restore_saver.restore(sess, init_checkpoint)
  82. print(response_request('我叫水奈樾。'))
View Code

 

 

原文链接:http://www.cnblogs.com/rucwxb/p/10367609.html

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号