# Source code for easytransfer.app_zoo.feature_extractor
# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from easytransfer import preprocessors, model_zoo
from easytransfer.app_zoo.base import ApplicationModel
class BertFeatureExtractor(ApplicationModel):
    """ Bert Feature Extraction Model (Only for predicting)

    Wraps a pretrained BERT backbone and exposes its sequence-level and
    pooled outputs as prediction features.
    """

    def __init__(self, **kwargs):
        super(BertFeatureExtractor, self).__init__(**kwargs)
        # Optional finetune model name; when set to
        # "text_match_bert_two_tower" the backbone is built inside the
        # matching variable scope so two-tower checkpoints restore cleanly.
        self.finetune_model_name = getattr(self.config, "finetune_model_name", None)

    def build_logits(self, features, mode):
        """ Building BERT feature extraction graph

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`bool`): tell the model whether it is under training

        Returns:
            sequence_output (`Tensor`): The last hidden outputs of all sequence.
                Shape of [None, seq_len, hidden_size]
            pooled_output (`Tensor`): The output after pooling. Shape of [None, 768]
        """
        bert_preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            user_defined_config=self.config)
        # Preprocessor may emit more than three tensors; only the first
        # three (ids, mask, segment ids) feed the backbone.
        input_ids, input_mask, segment_ids = bert_preprocessor(features)[:3]

        def _forward_backbone():
            # Single place that builds the pretrained backbone and runs the
            # forward pass; shared by the scoped and unscoped branches below.
            bert_backbone = model_zoo.get_pretrained_model(
                self.config.pretrain_model_name_or_path)
            return bert_backbone(
                [input_ids, input_mask, segment_ids],
                output_features=True, mode=mode)

        if self.finetune_model_name == "text_match_bert_two_tower":
            # Two-tower text-match checkpoints store variables under this
            # scope; AUTO_REUSE lets both towers share the same variables.
            with tf.variable_scope('text_match_bert_two_tower', reuse=tf.AUTO_REUSE):
                sequence_output, pooled_output = _forward_backbone()
        else:
            sequence_output, pooled_output = _forward_backbone()
        return sequence_output, pooled_output

    def build_predictions(self, predict_output):
        """ Building BERT feature extraction prediction dict.

        Args:
            predict_output (`tuple`): (sequence_output, pooled_output)

        Returns:
            ret_dict (`dict`): A dict with (`pool_output`, `first_token_output`,
                `all_hidden_outputs`)
        """
        all_hidden_outputs, pool_output = predict_output
        # Hidden state of the first ([CLS]) token, commonly used as a
        # sentence-level embedding.
        first_token_output = all_hidden_outputs[:, 0, :]
        ret_dict = {
            "pool_output": pool_output,
            "first_token_output": first_token_output,
            "all_hidden_outputs": all_hidden_outputs
        }
        return ret_dict