Source code for easytransfer.losses.pretrain_loss

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf


def masked_language_model_loss(lm_logits, masked_lm_ids, masked_lm_weights, vocab_size):
    """Masked language model (MLM) loss: weighted average of the per-position
    negative log-likelihood over the masked token positions."""
    log_probs = tf.nn.log_softmax(lm_logits, axis=-1)
    masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
    label_weights = tf.reshape(masked_lm_weights, [-1])
    one_hot_labels = tf.one_hot(masked_lm_ids, depth=vocab_size, dtype=tf.float32)
    # Cross-entropy for each masked position.
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    # Weighted mean; the small epsilon guards against division by zero
    # when no positions are masked.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    masked_lm_loss = numerator / denominator
    tf.summary.scalar("masked_lm_loss", masked_lm_loss)
    return masked_lm_loss
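
# Illustrative usage (not part of the original module): a minimal sketch of the
# shapes masked_language_model_loss expects, assuming a hypothetical batch of 8
# sequences, 20 masked positions per sequence, and a 30522-token vocabulary.
def _example_masked_language_model_loss():
    lm_logits = tf.zeros([8 * 20, 30522])                    # [batch * masked positions, vocab]
    masked_lm_ids = tf.zeros([8, 20], dtype=tf.int32)        # gold token ids at masked positions
    masked_lm_weights = tf.ones([8, 20], dtype=tf.float32)   # 1.0 for real masks, 0.0 for padding
    return masked_language_model_loss(lm_logits, masked_lm_ids, masked_lm_weights, vocab_size=30522)
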
def next_sentence_prediction_loss(nsp_logits, nx_sent_labels):
    """Next sentence prediction (NSP) loss: two-way cross-entropy averaged
    over the batch."""
    log_probs = tf.nn.log_softmax(nsp_logits, axis=-1)
    next_sentence_labels = tf.reshape(nx_sent_labels, [-1])
    one_hot_labels = tf.one_hot(next_sentence_labels, depth=2, dtype=tf.float32)
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    next_sentence_prediction_loss = tf.reduce_mean(per_example_loss)
    tf.summary.scalar("next_sentence_prediction_loss", next_sentence_prediction_loss)
    return next_sentence_prediction_loss
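
# Illustrative usage (not part of the original module): NSP operates on two-way
# logits, one pair per example, with 0/1 labels. The batch size below is an
# assumption for the sketch.
def _example_next_sentence_prediction_loss():
    nsp_logits = tf.zeros([8, 2])                    # [batch, 2] "is next" vs. "not next" scores
    nx_sent_labels = tf.zeros([8], dtype=tf.int32)   # 0/1 label per sequence pair
    return next_sentence_prediction_loss(nsp_logits, nx_sent_labels)
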
def image_reconstruction_mse_loss(mpm_logits, target_raw_patch_features,
                                  masked_image_token_num, patch_feature_size):
    """Masked patch reconstruction loss: mean squared error between the
    log-softmax-normalized predicted and target patch features."""
    image_pred_probs = tf.nn.log_softmax(mpm_logits)
    image_pred_probs = tf.reshape(
        image_pred_probs, (-1, masked_image_token_num, patch_feature_size))
    image_target_probs = tf.reshape(
        tf.nn.log_softmax(target_raw_patch_features),
        (-1, masked_image_token_num, patch_feature_size))
    image_loss = tf.keras.losses.mean_squared_error(image_target_probs, image_pred_probs)
    image_loss = tf.reduce_mean(image_loss)
    tf.summary.scalar("image_reconstruction_mse_loss", image_loss)
    return image_loss
def image_reconstruction_kld_loss(mpm_logits, target_raw_patch_features,
                                  masked_image_token_num, patch_feature_size):
    """Masked patch reconstruction loss: KL divergence between the softmax of
    the target patch features and the softmax of the predicted features."""
    image_pred_probs = tf.nn.softmax(mpm_logits)
    image_pred_probs = tf.reshape(
        image_pred_probs, (-1, masked_image_token_num, patch_feature_size))
    image_target_probs = tf.reshape(
        tf.nn.softmax(target_raw_patch_features),
        (-1, masked_image_token_num, patch_feature_size))
    image_loss = tf.keras.losses.KLD(image_target_probs, image_pred_probs)
    image_loss = tf.reduce_mean(image_loss)
    tf.summary.scalar("image_reconstruction_kld_loss", image_loss)
    return image_loss
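
# Illustrative usage (not part of the original module): both image reconstruction
# losses take the predicted and raw features of the masked patches and differ only
# in normalization (log-softmax + MSE vs. softmax + KL divergence, computed as
# KL(target || prediction)). The batch size, masked patch count, and feature size
# below are assumptions for the sketch.
def _example_image_reconstruction_losses():
    mpm_logits = tf.zeros([8 * 10, 2048])                  # predicted features for masked patches
    target_raw_patch_features = tf.zeros([8 * 10, 2048])   # raw features of the same patches
    mse = image_reconstruction_mse_loss(
        mpm_logits, target_raw_patch_features,
        masked_image_token_num=10, patch_feature_size=2048)
    kld = image_reconstruction_kld_loss(
        mpm_logits, target_raw_patch_features,
        masked_image_token_num=10, patch_feature_size=2048)
    return mse, kld
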