关键词:Bert、T5EncoderModel、last_hidden_state、pooler_output、Pytorch
目标:使用T5中文预训练模型(t5-chinese-base)实现分类任务
先观察 T5EncoderModel 的输出:
模型输出只有 last_hidden_state,该张量为每个token的隐藏层输出。
为方便后续实现分类任务,可以仿照BERT模型构造pooler_output。
首先看下BERT关于pooler_output的源码部分:
再来看下 BertPooler 部分的源码:
所以BERT中 pooled_output 即为 last_hidden_state[:0] 再经全连接层而得。于是可以对 T5EncoderModel 添加全连接层,构造出 pooled_output,并修改输出维度为分类类别数量。