Source code for easytransfer.preprocessors.pretrain_preprocessor

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
import random
from collections import OrderedDict
import collections
from .tokenization import convert_to_unicode
from .preprocessor import Preprocessor, PreprocessorConfig, truncate_seq_pair

MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])

[docs]def create_chinese_subwords(segment): new_segment = [] for token in segment: if u'\u4e00' > token[0] or token[0] > u'\u9fa5': new_segment.append(token) else: if len(token) > 1: new_segment.append(token[0]) for ele in token[1:]: new_segment.append("##"+ele) else: new_segment.append(token) return new_segment
[docs]def create_int_feature(values): feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) return feature
[docs]def create_float_feature(values): feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) return feature
[docs]def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, do_whole_word_mask, rng): """Creates the predictions for the masked LM objective.""" cand_indexes = [] for (i, token) in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue # Whole Word Masking means that if we mask all of the wordpieces # corresponding to an original word. When a word has been split into # WordPieces, the first token does not have any marker and any subsequence # tokens are prefixed with ##. So whenever we see the ## token, we # append it to the previous set of word indexes. # # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. if (do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): cand_indexes[-1].append(i) else: cand_indexes.append([i]) rng.shuffle(cand_indexes) output_tokens = list(tokens) num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) masked_lms = [] covered_indexes = set() for index_set in cand_indexes: if len(masked_lms) >= num_to_predict: break # If adding a whole-word mask would exceed the maximum number of # predictions, then just skip this candidate. if len(masked_lms) + len(index_set) > num_to_predict: continue is_any_index_covered = False for index in index_set: if index in covered_indexes: is_any_index_covered = True break if is_any_index_covered: continue for index in index_set: covered_indexes.add(index) masked_token = None # 80% of the time, replace with [MASK] if rng.random() < 0.8: masked_token = "[MASK]" else: # 10% of the time, keep original if rng.random() < 0.5: masked_token = tokens[index] # 10% of the time, replace with random word else: masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] output_tokens[index] = masked_token masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) assert len(masked_lms) <= num_to_predict masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_positions = [] masked_lm_labels = [] for p in masked_lms: masked_lm_positions.append(p.index) masked_lm_labels.append(p.label) return (output_tokens, masked_lm_positions, masked_lm_labels)
[docs]class TrainingInstance(object): """A single training instance (sentence pair).""" def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next): self.tokens = tokens self.segment_ids = segment_ids self.is_random_next = is_random_next self.masked_lm_positions = masked_lm_positions self.masked_lm_labels = masked_lm_labels
[docs]class PretrainPreprocessorConfig(PreprocessorConfig): def __init__(self, **kwargs): super(PretrainPreprocessorConfig, self).__init__(**kwargs) self.input_schema = kwargs.get("input_schema") self.sequence_length = kwargs.get("sequence_length") self.first_sequence = kwargs.get("first_sequence") self.second_sequence = kwargs.get("second_sequence") self.label_name = kwargs.get("label_name") self.label_enumerate_values = kwargs.get("label_enumerate_values") self.max_predictions_per_seq = kwargs.get("max_predictions_per_seq") self.masked_lm_prob = kwargs.get("masked_lm_prob", 0.15) self.do_whole_word_mask = kwargs.get("do_whole_word_mask", True)
[docs]class PretrainPreprocessor(Preprocessor): config_class = PretrainPreprocessorConfig def __init__(self, config, **kwargs): Preprocessor.__init__(self, config, **kwargs) self.config = config self.input_tensor_names = [] for schema in config.input_schema.split(","): name = schema.split(":")[0] self.input_tensor_names.append(name) self.vocab_words = list(self.config.tokenizer.vocab.keys()) self.rng = random.Random(12345) self.label_idx_map = OrderedDict() if self.config.label_enumerate_values is not None: for (i, label) in enumerate(self.config.label_enumerate_values.split(",")): self.label_idx_map[convert_to_unicode(label)] = i self.feature_type = kwargs.get('feature_type', "pretrain_lm")
[docs] def set_feature_schema(self): if self.mode.startswith("predict") or self.mode == "preprocess": self.output_schema = self.config.output_schema self.output_tensor_names = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"] if self.feature_type == "pretrain_lm": self.Tout = [tf.string] * 6 self.seq_lens = [self.config.sequence_length] * 3 + [self.config.max_predictions_per_seq] * 3 self.feature_value_types = [tf.int64] * 5 + [tf.float32] elif self.feature_type == "pretrain_multimodel": self.Tout = [tf.string] * 7 self.seq_lens = [self.config.sequence_length] * 3 + [self.config.max_predictions_per_seq] * 3 + [4] self.feature_value_types = [tf.int64] * 5 + [tf.float32] + [tf.int64]
[docs] def convert_example_to_features(self, items): text_a = items[self.input_tensor_names.index(self.config.first_sequence)] tokens_a = self.config.tokenizer.tokenize(convert_to_unicode(text_a)) if self.config.second_sequence in self.input_tensor_names: text_b = items[self.input_tensor_names.index(self.config.second_sequence)] tokens_b = self.config.tokenizer.tokenize(convert_to_unicode(text_b)) truncate_seq_pair(tokens_a, tokens_b, self.config.sequence_length - 3) else: if len(tokens_a) > self.config.sequence_length - 2: tokens_a = tokens_a[0:(self.config.sequence_length - 2)] tokens_b = None tokens = [] segment_ids = [] tokens.append("[CLS]") segment_ids.append(0) for token in tokens_a: tokens.append(token) segment_ids.append(0) tokens.append("[SEP]") segment_ids.append(0) if tokens_b: for token in tokens_b: tokens.append(token) segment_ids.append(1) tokens.append("[SEP]") segment_ids.append(1) input_ids = self.config.tokenizer.convert_tokens_to_ids(tokens) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < self.config.sequence_length: input_ids.append(0) input_mask.append(0) segment_ids.append(0) assert len(input_ids) == self.config.sequence_length assert len(input_mask) == self.config.sequence_length assert len(segment_ids) == self.config.sequence_length tokens, masked_lm_positions, masked_lm_labels = \ create_masked_lm_predictions(tokens, self.config.masked_lm_prob, self.config.max_predictions_per_seq, self.vocab_words, self.config.do_whole_word_mask, self.rng) input_ids = self.config.tokenizer.convert_tokens_to_ids(tokens) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < self.config.sequence_length: input_ids.append(0) input_mask.append(0) assert len(input_ids) == self.config.sequence_length assert len(input_mask) == self.config.sequence_length assert len(segment_ids) == self.config.sequence_length masked_lm_positions = list(masked_lm_positions) masked_lm_ids = self.config.tokenizer.convert_tokens_to_ids(masked_lm_labels) masked_lm_weights = [1.0] * len(masked_lm_ids) if len(masked_lm_positions) >= self.config.max_predictions_per_seq: masked_lm_positions = masked_lm_positions[0:self.config.max_predictions_per_seq] masked_lm_ids = masked_lm_ids[0:self.config.max_predictions_per_seq] masked_lm_weights = masked_lm_weights[0:self.config.max_predictions_per_seq] while len(masked_lm_positions) < self.config.max_predictions_per_seq: masked_lm_positions.append(0) masked_lm_ids.append(0) masked_lm_weights.append(0.0) if self.feature_type == "pretrain_lm": return ' '.join([str(t) for t in input_ids]), \ ' '.join([str(t) for t in input_mask]), \ ' '.join([str(t) for t in segment_ids]), \ ' '.join([str(t) for t in masked_lm_positions]), \ ' '.join([str(t) for t in masked_lm_ids]), \ ' '.join([str(t) for t in masked_lm_weights]) elif self.feature_type == "pretrain_multimodel": return ' '.join([str(t) for t in input_ids]), \ ' '.join([str(t) for t in input_mask]), \ ' '.join([str(t) for t in segment_ids]), \ ' '.join([str(t) for t in masked_lm_positions]), \ ' '.join([str(t) for t in masked_lm_ids]), \ ' '.join([str(t) for t in masked_lm_weights]), \ ' '.join(sorted(random.sample([str(x) for x in range(10)], 4)))