ez_transfer
ez_transfer.base_model
class easytransfer.base_model(**kwargs)
build_logits(features, mode)
Given the input features, this method builds the model graph for training, evaluation, and prediction.
Parameters: - features -- either raw text features or numerical features such as input_ids, input_mask ...
- mode -- tf.estimator.ModeKeys.TRAIN | tf.estimator.ModeKeys.EVAL | tf.estimator.ModeKeys.PREDICT
Returns: logits, labels
Examples:
def build_logits(self, features, mode=None):
    # Preprocessor and pretrained encoder matching the checkpoint name
    preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path)
    model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)
    # Linear head: one logit per label
    dense = layers.Dense(self.num_labels,
                         kernel_initializer=layers.get_initializer(0.02),
                         name='dense')
    input_ids, input_mask, segment_ids, label_ids = preprocessor(features)
    outputs = model([input_ids, input_mask, segment_ids], mode=mode)
    pooled_output = outputs[1]  # pooled sequence representation
    logits = dense(pooled_output)
    return logits, label_ids
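Conceptually, the dense head in the example maps the pooled output of shape [batch, hidden] to logits of shape [batch, num_labels] by computing x · W + b for each example. The plain-Python sketch below shows only that arithmetic on toy numbers. The function name `dense_logits` and all values are hypothetical, and the sketch assumes the `Dense` layer applies no activation (as with a Keras-style `Dense` left at its defaults); weight initialization is ignored.

```python
from typing import List

Matrix = List[List[float]]


def dense_logits(pooled: Matrix, kernel: Matrix, bias: List[float]) -> Matrix:
    """Hypothetical sketch of a linear classification head.

    pooled: [batch, hidden] pooled encoder output
    kernel: [hidden, num_labels] weight matrix
    bias:   [num_labels] bias vector
    Returns [batch, num_labels] logits with no activation applied.
    """
    hidden = len(kernel)
    num_labels = len(bias)
    logits = []
    for row in pooled:
        if len(row) != hidden:
            raise ValueError("pooled width must match the number of kernel rows")
        # Logit j is dot(row, kernel[:, j]) + bias[j]
        logits.append([
            sum(row[i] * kernel[i][j] for i in range(hidden)) + bias[j]
            for j in range(num_labels)
        ])
    return logits


# Toy batch: 2 examples, hidden size 3, 2 labels.
pooled = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]]
kernel = [[1.0, -1.0], [0.0, 2.0], [0.5, 0.0]]
bias = [0.1, 0.0]
print(dense_logits(pooled, kernel, bias))
```

Each row of the result holds one unnormalized score per label; the loss and prediction methods below consume exactly this shape.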
build_loss(logits, labels)
Builds the loss function.
Parameters: - logits -- logits returned from build_logits
- labels -- labels returned from build_logits
Returns: loss
Examples:
def build_loss(self, logits, labels):
    return softmax_cross_entropy(labels, depth=self.config.num_labels, logits=logits)
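The example delegates to `softmax_cross_entropy`. As an assumption, not something this page states, that helper is taken to one-hot encode `labels` to `depth` classes and return the batch-mean cross-entropy, which is what TensorFlow's `tf.compat.v1.losses.softmax_cross_entropy` returns with unit weights. The plain-Python sketch below computes that quantity; `log_softmax` and `softmax_xent` are hypothetical names.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def softmax_xent(labels: List[int], depth: int, logits: List[List[float]]) -> float:
    """Batch-mean cross-entropy between one_hot(labels, depth) and softmax(logits)."""
    if not labels or len(labels) != len(logits):
        raise ValueError("labels and logits must be non-empty and equal length")
    total = 0.0
    for label, row in zip(labels, logits):
        if len(row) != depth or not 0 <= label < depth:
            raise ValueError("label or logits row incompatible with depth")
        # With a one-hot target, cross-entropy reduces to -log p[label].
        total -= log_softmax(row)[label]
    return total / len(labels)


print(softmax_xent([0, 1], depth=2, logits=[[2.0, 0.0], [0.0, 0.0]]))
```

Uniform logits give a loss of log(depth), a useful sanity check for an untrained head.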
build_eval_metrics(logits, labels)
Builds the evaluation metrics.
Parameters: - logits -- logits returned from build_logits
- labels -- labels returned from build_logits
Returns: metric_dict
Examples:
def build_eval_metrics(self, logits, labels):
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    info_dict = {
        "predictions": predictions,
        "labels": labels,
    }
    evaluator = PyEvaluator()
    # All label ids; a separate name avoids shadowing the `labels` argument
    label_list = list(range(self.num_labels))
    metric_dict = evaluator.get_metric_ops(info_dict, label_list)
    ret_metrics = evaluator.evaluate(label_list)
    tf.summary.scalar("eval accuracy", ret_metrics['py_accuracy'])
    tf.summary.scalar("eval F1 micro score", ret_metrics['py_micro_f1'])
    tf.summary.scalar("eval F1 macro score", ret_metrics['py_macro_f1'])
    return metric_dict
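The summaries in the example report accuracy, micro-averaged F1, and macro-averaged F1. As a hedged illustration of what those numbers mean, and not of how `PyEvaluator` actually computes them, the sketch below derives all three from predicted and gold label ids over a fixed label list. The name `classification_metrics` is hypothetical, and every id is assumed to appear in `label_list`. For single-label multi-class data, micro F1 over all labels equals accuracy, because every wrong prediction counts once as a false positive and once as a false negative.

```python
from typing import Dict, Sequence


def classification_metrics(preds: Sequence[int], golds: Sequence[int],
                           label_list: Sequence[int]) -> Dict[str, float]:
    """Accuracy plus micro- and macro-averaged F1 over label_list."""
    if not preds or len(preds) != len(golds):
        raise ValueError("preds and golds must be non-empty and equal length")
    tp = {lab: 0 for lab in label_list}
    fp = {lab: 0 for lab in label_list}
    fn = {lab: 0 for lab in label_list}
    for p, g in zip(preds, golds):
        if p == g:
            tp[g] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[g] += 1  # missed the gold label g

    def f1(t: int, f_p: int, f_n: int) -> float:
        # F1 = 2TP / (2TP + FP + FN); taken as 0 when undefined.
        denom = 2 * t + f_p + f_n
        return 2 * t / denom if denom else 0.0

    accuracy = sum(tp.values()) / len(golds)
    micro_f1 = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    macro_f1 = sum(f1(tp[lab], fp[lab], fn[lab]) for lab in label_list) / len(label_list)
    return {"accuracy": accuracy, "micro_f1": micro_f1, "macro_f1": macro_f1}


print(classification_metrics([0, 1, 1, 2], [0, 1, 2, 2], label_list=[0, 1, 2]))
```

Macro F1 weights every label equally, so it falls below accuracy when errors concentrate on some classes, as in the toy data above.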
build_predictions(logits)
Builds the prediction outputs.
Parameters: logits -- the value returned from build_logits (the example below names this argument output and unpacks it as a (logits, labels) tuple)
Returns: predictions
Examples:
def build_predictions(self, output):
    # build_logits returns (logits, labels); only the logits are needed here
    logits, _ = output
    predictions = dict()
    predictions["predictions"] = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return predictions
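`tf.argmax(logits, axis=-1)` selects, for each example, the index of its largest logit. The plain-Python sketch below mirrors that step together with the tuple unpacking and the {"predictions": ...} dictionary from the example. The name `build_predictions_py` is hypothetical. Ties resolve to the smallest index here because Python's `max` keeps the first maximal element; TensorFlow's documentation does not guarantee which index `tf.argmax` returns on ties.

```python
from typing import Dict, List, Tuple


def build_predictions_py(output: Tuple[List[List[float]], object]) -> Dict[str, List[int]]:
    # Mirror the example: output is (logits, labels); labels are unused here.
    logits, _ = output
    # Row-wise argmax: index of the largest logit in each row.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return {"predictions": preds}


print(build_predictions_py(([[0.1, 2.0, -1.0], [3.0, 3.0, 0.0]], None)))
# {'predictions': [1, 0]}
```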