Source code for easytransfer.app_zoo.text_comprehension

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from easytransfer import preprocessors, model_zoo
from easytransfer.app_zoo.base import ApplicationModel
from easytransfer.evaluators import comprehension_eval_metrics
from easytransfer.losses import comprehension_loss
from easytransfer.model_zoo.modeling_hae import BertHAEPretrainedModel


class BaseTextComprehension(ApplicationModel):
    def __init__(self, **kwargs):
        """ Basic Text Comprehension Model """
        super(BaseTextComprehension, self).__init__(**kwargs)

    @staticmethod
    def default_model_params():
        """ The default parameters of the Text Comprehension Model """
        raise NotImplementedError

    def build_logits(self, features, mode=None):
        """ Building the graph of the Text Comprehension Model """
        raise NotImplementedError

    def build_loss(self, logits, labels):
        """ Building the loss for training the Text Comprehension Model """
        return comprehension_loss(logits, labels)

    def build_eval_metrics(self, logits, labels):
        """ Building evaluation metrics while evaluating

        Args:
            logits (`tuple`): (start_logits, end_logits), each of shape [None, seq_length]
            labels (`tuple`): (start_positions, end_positions), each of shape [None]

        Returns:
            ret_dict (`dict`): A dict of tf.metrics ops computed by `comprehension_eval_metrics`
        """
        return comprehension_eval_metrics(logits, labels)

    def build_predictions(self, predict_output):
        """ Building the prediction dict of the Text Comprehension Model

        Args:
            predict_output (`tuple`): ((start_logits, end_logits), _)

        Returns:
            ret_dict (`dict`): A dict with keys `start_logits` and `end_logits`
        """
        (start_logits, end_logits), _ = predict_output
        ret_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits
        }
        return ret_dict
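
# -- Illustrative sketch (not part of the original module) --------------------
# build_predictions returns raw (start_logits, end_logits); turning them into
# an answer span is left to post-processing. A minimal decoder sketch: the
# top-k candidate search and the max_answer_length cutoff are illustrative
# assumptions, not easytransfer's actual evaluation code.
def _example_decode_best_span(start_logits, end_logits, max_answer_length=30, top_k=20):
    """Return the (start, end) pair maximizing start_logit + end_logit,
    subject to start <= end < start + max_answer_length."""
    import numpy as np
    best_score, best_span = -np.inf, (0, 0)
    # Only consider the top-k highest-scoring start and end positions
    for start in np.argsort(start_logits)[::-1][:top_k]:
        for end in np.argsort(end_logits)[::-1][:top_k]:
            if start <= end < start + max_answer_length:
                score = start_logits[start] + end_logits[end]
                if score > best_score:
                    best_score, best_span = score, (int(start), int(end))
    return best_span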
class BERTTextComprehension(BaseTextComprehension):
    """ BERT Text Comprehension Model

    .. highlight:: python
    .. code-block:: python

        default_param_dict = {
            "pretrain_model_name_or_path": "pai-bert-base-zh",
            "num_labels": 2,
            "max_query_length": 64,
            "doc_stride": 128,
        }
    """
    def __init__(self, **kwargs):
        super(BERTTextComprehension, self).__init__(**kwargs)

    @staticmethod
    def get_input_tensor_schema(sequence_length=64):
        return "input_ids:int:{},input_mask:int:{},segment_ids:int:{}," \
               "start_position:int:1,end_position:int:1".format(
                   sequence_length, sequence_length, sequence_length)

    @staticmethod
    def get_received_tensor_schema():
        return "input_ids:int:64,input_mask:int:64,segment_ids:int:64"

    @staticmethod
    def default_model_params():
        """ Get the default model parameters

        Returns:
            default_param_dict (`dict`): key/value pairs of the default model parameters
        """
        default_param_dict = {
            "pretrain_model_name_or_path": "pai-bert-base-zh",
            "num_labels": 2,
            "max_query_length": 64,
            "doc_stride": 128,
        }
        return default_param_dict

    def build_logits(self, features, mode=None):
        """ Building the graph of BERT Text Comprehension

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`str`): tells the model whether it is under training

        Returns:
            logits (`tuple`): (start_logits, end_logits), the outputs after the last
                dense layer. Two tensors of shape [None, seq_length]
            label_ids (`tuple`): (start_positions, end_positions). Two tensors of shape [None]
        """
        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            app_model_name="text_comprehension_bert",
            user_defined_config=self.config)
        input_ids, input_mask, segment_ids, start_positions, end_positions = preprocessor(features)

        bert_backbone = model_zoo.get_pretrained_model(self.config.pretrain_model_name_or_path)
        sequence_output, _ = bert_backbone([input_ids, input_mask, segment_ids], mode=mode)

        seq_length = self.config.sequence_length
        hidden_size = int(sequence_output.shape[2])
        output_weights = tf.get_variable(
            "app/output_weights", [2, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.get_variable(
            "app/output_bias", [2], initializer=tf.zeros_initializer())

        # Project each token's hidden state to a (start, end) score pair
        final_hidden_matrix = tf.reshape(sequence_output, [-1, hidden_size])
        logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        # [batch * seq, 2] -> [2, batch, seq], then split into start/end logits
        logits = tf.reshape(logits, [-1, seq_length, 2])
        logits = tf.transpose(logits, [2, 0, 1])
        unstacked_logits = tf.unstack(logits, axis=0)
        (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

        self.check_and_init_from_checkpoint(mode)
        return (start_logits, end_logits), (start_positions, end_positions)
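
# -- Illustrative sketch (not part of the original module) --------------------
# The span head in build_logits projects every token's hidden vector to two
# scores and splits them into start/end logits. A NumPy re-statement of the
# same shape bookkeeping (the batch/seq_length/hidden_size values are
# arbitrary placeholders):
def _example_span_head_shapes(batch=4, seq_length=64, hidden_size=768):
    """Mirror the reshape/matmul/transpose/unstack dance of the TF span head."""
    import numpy as np
    sequence_output = np.random.randn(batch, seq_length, hidden_size)
    output_weights = np.random.randn(2, hidden_size) * 0.02
    output_bias = np.zeros(2)
    final_hidden_matrix = sequence_output.reshape(-1, hidden_size)    # [batch*seq, hidden]
    logits = final_hidden_matrix @ output_weights.T + output_bias     # [batch*seq, 2]
    logits = logits.reshape(batch, seq_length, 2).transpose(2, 0, 1)  # [2, batch, seq]
    start_logits, end_logits = logits[0], logits[1]                   # each [batch, seq]
    assert start_logits.shape == end_logits.shape == (batch, seq_length)
    return start_logits, end_logits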
class BERTTextHAEComprehension(BaseTextComprehension):
    """ BERT-HAE Text Comprehension Model

    .. highlight:: python
    .. code-block:: python

        default_param_dict = {
            "pretrain_model_name_or_path": "pai-bert-base-zh",
            "max_query_length": 64,
            "doc_stride": 128,
            "max_considered_history_turns": 11,
        }
    """
    def __init__(self, **kwargs):
        super(BERTTextHAEComprehension, self).__init__(**kwargs)

    @staticmethod
    def get_input_tensor_schema(sequence_length=64):
        return "input_ids:int:{},input_mask:int:{},segment_ids:int:{},history_answer_marker:int:{}," \
               "start_position:int:1,end_position:int:1".format(
                   sequence_length, sequence_length, sequence_length, sequence_length)

    @staticmethod
    def get_received_tensor_schema():
        return "input_ids:int:64,input_mask:int:64,segment_ids:int:64,history_answer_marker:int:64"

    @staticmethod
    def default_model_params():
        """ Get the default model parameters

        Returns:
            default_param_dict (`dict`): key/value pairs of the default model parameters
        """
        default_param_dict = {
            "pretrain_model_name_or_path": "pai-bert-base-zh",
            "max_query_length": 64,
            "doc_stride": 128,
            "max_considered_history_turns": 11
        }
        return default_param_dict

    def build_logits(self, features, mode=None):
        """ Building the graph of BERT-HAE Text Comprehension

        Args:
            features (`OrderedDict`): A dict mapping raw input to tensors
            mode (`str`): tells the model whether it is under training

        Returns:
            logits (`tuple`): (start_logits, end_logits), the outputs after the last
                dense layer. Two tensors of shape [None, seq_length]
            label_ids (`tuple`): (start_positions, end_positions). Two tensors of shape [None]
        """
        preprocessor = preprocessors.get_preprocessor(
            self.config.pretrain_model_name_or_path,
            app_model_name="text_comprehension_bert_hae",
            user_defined_config=self.config)
        input_ids, input_mask, segment_ids, history_answer_marker, start_positions, end_positions = \
            preprocessor(features)

        bert_backbone = BertHAEPretrainedModel.get(self.config.pretrain_model_name_or_path)
        sequence_output, _ = bert_backbone(
            [input_ids, input_mask, segment_ids, history_answer_marker], mode=mode)

        seq_length = self.config.sequence_length
        hidden_size = int(sequence_output.shape[2])
        output_weights = tf.get_variable(
            "app/output_weights", [2, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        output_bias = tf.get_variable(
            "app/output_bias", [2], initializer=tf.zeros_initializer())

        # Project each token's hidden state to a (start, end) score pair
        final_hidden_matrix = tf.reshape(sequence_output, [-1, hidden_size])
        logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        # [batch * seq, 2] -> [2, batch, seq], then split into start/end logits
        logits = tf.reshape(logits, [-1, seq_length, 2])
        logits = tf.transpose(logits, [2, 0, 1])
        unstacked_logits = tf.unstack(logits, axis=0)
        (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

        self.check_and_init_from_checkpoint(mode)
        return (start_logits, end_logits), (start_positions, end_positions)
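
# -- Illustrative sketch (not part of the original module) --------------------
# BERT-HAE differs from plain BERT only in the extra history_answer_marker
# input: a 0/1 vector flagging tokens that appeared in answers to earlier
# conversation turns, which the HAE backbone turns into an additional
# (History Answer) embedding. A toy marker construction; the marked positions
# below are made up for illustration:
def _example_history_answer_marker(seq_length=64):
    """Build a toy history_answer_marker matching get_received_tensor_schema()."""
    import numpy as np
    marker = np.zeros(seq_length, dtype=np.int64)
    marker[10:15] = 1  # pretend tokens 10..14 occurred in a previous-turn answer
    return marker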