easytransfer.losses

classification_regression_loss

easytransfer.losses.classification_regression_loss.softmax_cross_entropy(labels, depth, logits)[source]
easytransfer.losses.classification_regression_loss.mean_square_error(labels, logits)[source]
easytransfer.losses.classification_regression_loss.multi_label_sigmoid_cross_entropy(labels, depth, logits)[source]

comprehension_loss

easytransfer.losses.comprehension_loss.comprehension_loss(logits, labels)[source]

kd_loss

easytransfer.losses.kd_loss.build_kd_loss(teacher_logits, student_logits, task_balance=0.3, distill_tempreture=2.0, labels=None, loss_type='mse')[source]
easytransfer.losses.kd_loss.mse_loss(teacher_logits, student_logits)[source]
easytransfer.losses.kd_loss.xent_loss(teacher_logits, student_logits, labels, distill_tempreture, task_balance)[source]
easytransfer.losses.kd_loss.kld_loss(teacher_logits, student_logits, labels, distill_temperature, task_balance)[source]
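easytransfer.losses.kd_loss.build_kd_probes_loss(teacher_logits, student_logits, task_balance=0.3, distill_tempreture=2.0, labels=None, loss_type='mse')[source]

Note: `build_kd_loss`, `xent_loss`, and `build_kd_probes_loss` spell the temperature parameter `distill_tempreture` (sic), whereas `kld_loss` spells it `distill_temperature`. When passing it by keyword, use the spelling shown in each function's signature.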

labeling_loss

easytransfer.losses.labeling_loss.sequence_labeling_loss(logits, labels, num_labels)[source]

matching_loss

easytransfer.losses.matching_loss.matching_embedding_margin_loss(emb1, emb2)[source]

pretrain_loss

easytransfer.losses.pretrain_loss.masked_language_model_loss(lm_logits, masked_lm_ids, masked_lm_weights, vocab_size)[source]
easytransfer.losses.pretrain_loss.next_sentence_prediction_loss(nsp_logits, nx_sent_labels)[source]
easytransfer.losses.pretrain_loss.image_reconstruction_mse_loss(mpm_logits, target_raw_patch_features, masked_image_token_num, patch_feature_size)[source]
easytransfer.losses.pretrain_loss.image_reconstruction_kld_loss(mpm_logits, target_raw_patch_features, masked_image_token_num, patch_feature_size)[source]