Source code for easytransfer.engines.model

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import random
import numpy as np
import os
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
SEED = 123123
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)
tf.reset_default_graph()

flags = tf.app.flags
flags.DEFINE_string("config", default=None, help='')
flags.DEFINE_string("tables", default=None, help='')
flags.DEFINE_string("outputs", default=None, help='')
flags.DEFINE_integer('task_index', 0, 'task_index')
flags.DEFINE_string('worker_hosts', 'localhost:5001', 'worker_hosts')
flags.DEFINE_string('job_name', 'worker', 'job_name')
flags.DEFINE_string("mode", default=None, help="Which mode")
flags.DEFINE_string("modelZooBasePath", default=os.path.join(os.getenv("HOME"), ".eztransfer_modelzoo"), help="eztransfer_modelzoo")
flags.DEFINE_bool("usePAI", default=False, help="Whether to use pai")
flags.DEFINE_integer("workerCount", default=1, help="num_workers")
flags.DEFINE_integer("workerGPU", default=1, help="num_gpus")
flags.DEFINE_integer("workerCPU", default=1, help="num_cpus")
flags.DEFINE_string('f', '', 'kernel')
FLAGS = flags.FLAGS

from easytransfer.utils.hooks import avgloss_logger_hook
from easytransfer.optimizers import get_train_op


class Config(object):
    def __init__(self, mode, config_json):
        self._config_json = copy.deepcopy(config_json)
        self.mode = mode

        self.worker_hosts = str(config_json["worker_hosts"])
        self.task_index = int(config_json["task_index"])
        self.job_name = str(config_json["job_name"])
        self.num_gpus = int(config_json["num_gpus"])
        self.num_workers = int(config_json["num_workers"])
        if not FLAGS.usePAI:
            FLAGS.modelZooBasePath = config_json.get("modelZooBasePath", os.path.join(os.getenv("HOME"), ".eztransfer_modelzoo"))
        tf.logging.info("***************** modelZooBasePath {} ***************".format(FLAGS.modelZooBasePath))
        if self.mode == 'train' or self.mode == "train_and_evaluate" \
                or self.mode == "train_and_evaluate_on_the_fly" or self.mode == "train_on_the_fly":

            self.enable_xla = bool(config_json["train_config"]['distribution_config'].get(
                'enable_xla', False))

            self.enable_auto_mixed_precision = bool(config_json["train_config"]['distribution_config'].get(
                'enable_auto_mixed_precision', False))

            self.distribution_strategy = str(
                config_json["train_config"]['distribution_config'].get("distribution_strategy", None))
            self.pull_evaluation_in_multiworkers_training = bool(config_json["train_config"]['distribution_config'].get(
                'pull_evaluation_in_multiworkers_training', False))

            self.num_accumulated_batches = int(config_json["train_config"]['distribution_config'].get(
                'num_accumulated_batches', 1))

            self.num_model_replica = int(config_json["train_config"]['distribution_config'].get(
                'num_model_replica', 1))

            # optimizer
            self.optimizer = str(config_json["train_config"]['optimizer_config'].get('optimizer', "adam"))
            self.learning_rate = float(config_json['train_config']['optimizer_config'].get('learning_rate', 0.001))

            self.weight_decay_ratio = float(
                config_json['train_config']['optimizer_config'].get('weight_decay_ratio', 0))
            self.lr_decay = config_json['train_config']['optimizer_config'].get('lr_decay', "polynomial")
            self.warmup_ratio = float(config_json['train_config']['optimizer_config'].get('warmup_ratio', 0.1))
            self.clip_norm_value = float(config_json['train_config']['optimizer_config'].get('clip_norm_value', 1.0))
            self.gradient_clip = bool(config_json['train_config']['optimizer_config'].get('gradient_clip', True))
            self.num_freezed_layers = int(config_json['train_config']['optimizer_config'].get('num_freezed_layers', 0))

            # misc
            self.num_epochs = float(config_json['train_config'].get('num_epochs', 1))
            try:
                self.model_dir = str(config_json['train_config'].get('model_dir', None))
            except:
                raise ValueError("input model dir")

            self.throttle_secs = int(config_json['train_config'].get('throttle_secs', 100))
            self.keep_checkpoint_max = int(config_json['train_config'].get('keep_checkpoint_max', 10))

            if 'save_steps' not in config_json['train_config']:
                self.save_steps = None
            else:
                self.save_steps = int(config_json['train_config']['save_steps']) \
                    if config_json['train_config']['save_steps'] else \
                    config_json['train_config']['save_steps']

            self.log_step_count_steps = int(config_json['train_config'].get('log_step_count_steps', 100))

            # model
            for key, val in config_json['model_config'].items():
                setattr(self, key, val)

            # data
            self.input_schema = str(config_json['preprocess_config'].get('input_schema', None))

            if self.mode == 'train_and_evaluate_on_the_fly' or self.mode == 'train_on_the_fly':
                self.sequence_length = int(config_json['preprocess_config']['sequence_length'])
                self.first_sequence = str(config_json['preprocess_config']['first_sequence'])
                self.second_sequence = str(config_json['preprocess_config']['second_sequence'])
                self.label_name = str(config_json['preprocess_config']['label_name'])
                self.label_enumerate_values = config_json['preprocess_config'].get('label_enumerate_values', None)
                self.append_feature_columns = config_json['preprocess_config'].get('append_feature_columns', None)

            if self.mode == 'train_and_evaluate' or self.mode == 'train_and_evaluate_on_the_fly':
                self.eval_batch_size = int(config_json['evaluate_config']['eval_batch_size'])

                if 'num_eval_steps' not in config_json['evaluate_config']:
                    self.num_eval_steps = None
                else:
                    self.num_eval_steps = int(config_json['evaluate_config']['num_eval_steps']) \
                        if config_json['evaluate_config']['num_eval_steps'] else \
                        config_json['evaluate_config']['num_eval_steps']

                self.eval_input_fp = str(config_json['evaluate_config']['eval_input_fp'])

            self.train_input_fp = str(config_json['train_config']['train_input_fp'])
            self.train_batch_size = int(config_json['train_config']['train_batch_size'])

        elif self.mode == "evaluate" or self.mode == "evaluate_on_the_fly":
            self.eval_ckpt_path = config_json['evaluate_config']['eval_checkpoint_path']

            self.input_schema = config_json['preprocess_config']['input_schema']
            if self.mode == "evaluate_on_the_fly":
                self.sequence_length = int(config_json['preprocess_config']['sequence_length'])
                self.first_sequence = str(config_json['preprocess_config']['first_sequence'])
                self.second_sequence = str(config_json['preprocess_config']['second_sequence'])
                self.label_name = str(config_json['preprocess_config']['label_name'])
                self.label_enumerate_values = config_json['preprocess_config'].get('label_enumerate_values', None)

            for key, val in config_json['model_config'].items():
                setattr(self, key, val)

            self.eval_batch_size = config_json['evaluate_config']['eval_batch_size']
            self.num_eval_steps = config_json['evaluate_config'].get('num_eval_steps', None)
            self.eval_input_fp = config_json['evaluate_config']['eval_input_fp']

        elif self.mode == 'predict' or self.mode == 'predict_on_the_fly':
            self.predict_checkpoint_path = config_json['predict_config']['predict_checkpoint_path']
            self.input_schema = config_json['preprocess_config']['input_schema']
            self.label_name = config_json['preprocess_config'].get('label_name', None)
            self.label_enumerate_values = config_json['preprocess_config'].get('label_enumerate_values', None)
            self.append_feature_columns = config_json['preprocess_config'].get('append_feature_columns', None)

            if self.mode == 'predict_on_the_fly':
                self.first_sequence = config_json['preprocess_config']['first_sequence']
                self.second_sequence = config_json['preprocess_config']['second_sequence']
                self.sequence_length = config_json['preprocess_config']['sequence_length']
                self.max_predictions_per_seq = config_json['preprocess_config'].get('max_predictions_per_seq', None)

            self.predict_batch_size = config_json['predict_config']['predict_batch_size']

            if config_json['preprocess_config']['output_schema'] == "bert_finetune":
                self.output_schema = "input_ids,input_mask,segment_ids,label_id"

            elif config_json['preprocess_config']['output_schema'] == "bert_pretrain":
                self.output_schema = "input_ids,input_mask,segment_ids,masked_lm_positions,masked_lm_ids,masked_lm_weights"

            elif config_json['preprocess_config']['output_schema'] == "bert_predict":
                self.output_schema = "input_ids,input_mask,segment_ids"
            else:
                self.output_schema = config_json['preprocess_config']['output_schema']

            self.model_config = config_json['model_config']
            for key, val in config_json['model_config'].items():
                setattr(self, key, val)

            self.predict_input_fp = config_json['predict_config']['predict_input_fp']
            self.predict_output_fp = config_json['predict_config']['predict_output_fp']

        elif self.mode == 'export':
            self.checkpoint_path = config_json['export_config']['checkpoint_path']
            for key, val in config_json['model_config'].items():
                setattr(self, key, val)
            self.export_dir_base = config_json['export_config']['export_dir_base']
            self.checkpoint_path = config_json['export_config']['checkpoint_path']
            self.receiver_tensors_schema = config_json['export_config']['receiver_tensors_schema']
            self.input_tensors_schema = config_json['export_config']['input_tensors_schema']

        elif self.mode == 'preprocess':
            self.input_schema = config_json['preprocess_config']['input_schema']
            self.first_sequence = config_json['preprocess_config']['first_sequence']
            self.second_sequence = config_json['preprocess_config'].get('second_sequence', None)
            self.label_name = config_json['preprocess_config'].get('label_name', None)
            self.label_enumerate_values = config_json['preprocess_config'].get('label_enumerate_values', None)
            self.sequence_length = config_json['preprocess_config']['sequence_length']
            self.max_predictions_per_seq = config_json['preprocess_config'].get('max_predictions_per_seq', None)

            if config_json['preprocess_config']['output_schema'] == "bert_finetune":
                self.output_schema = "input_ids,input_mask,segment_ids,label_id"

            elif config_json['preprocess_config']['output_schema'] == "bert_pretrain":
                self.output_schema = "input_ids,input_mask,segment_ids,masked_lm_positions,masked_lm_ids,masked_lm_weights"

            elif config_json['preprocess_config']['output_schema'] == "bert_predict":
                self.output_schema = "input_ids,input_mask,segment_ids"
            else:
                self.output_schema = config_json['preprocess_config']['output_schema']

            self.preprocess_input_fp = config_json['preprocess_config']['preprocess_input_fp']
            self.preprocess_output_fp = config_json['preprocess_config']['preprocess_output_fp']
            self.preprocess_batch_size = config_json['preprocess_config']['preprocess_batch_size']
            self.tokenizer_name_or_path = config_json['preprocess_config']['tokenizer_name_or_path']

    def __str__(self):
        return json.dumps(self.__dict__, sort_keys=False, indent=4)


class EzTransEstimator(object):
    def __init__(self, **kwargs):

        if self.config.mode == 'train' or self.config.mode == "train_and_evaluate" or \
                self.config.mode == "train_and_evaluate_on_the_fly" or self.config.mode == "train_on_the_fly":

            tf.logging.info("***********Running in {} mode***********".format(self.config.mode))

            if self.config.enable_xla is True:
                tf.logging.info("***********Enable Tao***********")
                os.environ['BRIDGE_ENABLE_TAO'] = 'True'
                os.environ["TAO_ENABLE_CHECK"] = "false"
                os.environ["TAO_COMPILATION_MODE_ASYNC"] = "false"
                os.environ["DISABLE_DEADNESS_ANALYSIS"] = "true"
            else:
                tf.logging.info("***********Disable Tao***********")

            if self.config.enable_auto_mixed_precision is True:
                tf.logging.info("***********Enable AUTO_MIXED_PRECISION***********")
                os.environ['TF_AUTO_MIXED_PRECISION'] = 'True'
                os.environ['lossScaling'] = 'auto'
            else:
                tf.logging.info("***********Disable AUTO_MIXED_PRECISION***********")

            NCCL_MAX_NRINGS = "4"
            NCCL_MIN_NRINGS = "4"
            TF_JIT_PROFILING = 'False'
            PAI_ENABLE_HLO_DUMPER = 'False'
            os.environ['PAI_ENABLE_HLO_DUMPER'] = PAI_ENABLE_HLO_DUMPER
            os.environ['TF_JIT_PROFILING'] = TF_JIT_PROFILING
            os.environ["NCCL_MAX_NRINGS"] = NCCL_MAX_NRINGS
            os.environ["NCCL_MIN_NRINGS"] = NCCL_MIN_NRINGS
            os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
            tf.logging.info("***********NCCL_MAX_NRINGS {}***********".format(NCCL_MAX_NRINGS))
            tf.logging.info("***********NCCL_MIN_NRINGS {}***********".format(NCCL_MIN_NRINGS))
            tf.logging.info("***********TF_JIT_PROFILING {}***********".format(TF_JIT_PROFILING))
            tf.logging.info("***********PAI_ENABLE_HLO_DUMPER {}***********".format(PAI_ENABLE_HLO_DUMPER))

            self.strategy = None
            if self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    (self.config.distribution_strategy == "ExascaleStrategy" or
                     self.config.distribution_strategy == "CollectiveAllReduceStrategy"):

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info("***********Job Name is {}***********".format(self.config.job_name))
                    tf.logging.info("***********Task Index is {}***********".format(self.config.task_index))
                    tf.logging.info("***********Worker Hosts is {}***********".format(self.config.worker_hosts))
                    pai.distribute.set_tf_config(self.config.job_name,
                                                 self.config.task_index,
                                                 worker_hosts,
                                                 has_evaluator=self.config.pull_evaluation_in_multiworkers_training)


                if self.config.distribution_strategy == "ExascaleStrategy":
                    tf.logging.info("*****************Using ExascaleStrategy*********************")
                    if FLAGS.usePAI:
                        self.strategy = pai.distribute.ExascaleStrategy(num_gpus=self.config.num_gpus,
                                                                        num_micro_batches=self.config.num_accumulated_batches,
                                                                        max_splits=1,
                                                                        enable_sparse_allreduce=False)
                    else:
                        raise ValueError("Please set usePAI is True")

                elif self.config.distribution_strategy == "CollectiveAllReduceStrategy":
                    tf.logging.info("*****************Using CollectiveAllReduceStrategy*********************")
                    if FLAGS.usePAI:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus,
                            cross_tower_ops_type='default',
                            all_dense=True,
                            iter_size=self.config.num_accumulated_batches)
                    else:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus)

                if self.config.pull_evaluation_in_multiworkers_training is True:
                    real_num_workers = self.config.num_workers - 1
                else:
                    real_num_workers = self.config.num_workers

                global_batch_size = self.config.train_batch_size * self.config.num_gpus * real_num_workers


            elif self.config.num_gpus > 1 and self.config.num_workers == 1 and \
                    self.config.distribution_strategy == "MirroredStrategy":
                tf.logging.info("*****************Using MirroredStrategy*********************")
                if FLAGS.usePAI:
                    from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
                    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl')
                    self.strategy = tf.contrib.distribute.MirroredStrategy(
                        num_gpus=self.config.num_gpus,
                        cross_tower_ops=cross_tower_ops,
                        all_dense=True,
                        iter_size=self.config.num_accumulated_batches)
                else:
                    self.strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=self.config.num_gpus)

                global_batch_size = self.config.train_batch_size * self.config.num_gpus * self.config.num_accumulated_batches

            elif self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    self.config.distribution_strategy == "WhaleStrategy":

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info("***********Job Name is {}***********".format(self.config.job_name))
                    tf.logging.info("***********Task Index is {}***********".format(self.config.task_index))
                    tf.logging.info("***********Worker Hosts is {}***********".format(self.config.worker_hosts))
                    pai.distribute.set_tf_config(self.config.job_name,
                                                 self.config.task_index,
                                                 worker_hosts,
                                                 has_evaluator=self.config.pull_evaluation_in_multiworkers_training)

                tf.logging.info("*****************Using WhaleStrategy*********************")
                os.environ["WHALE_COMMUNICATION_SPARSE_AS_DENSE"] = "True"
                os.environ["WHALE_COMMUNICATION_NUM_COMMUNICATORS"] = "2"
                os.environ["WHALE_COMMUNICATION_NUM_SPLITS"] = "8"
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches * self.config.num_model_replica

            elif self.config.num_gpus == 1 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info("***********Single worker, Single gpu, Don't use distribution strategy***********")

            elif self.config.num_gpus == 0 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info("***********Single worker, Running on CPU***********")

            else:
                raise ValueError(
                    "In train model, Please set correct num_workers, num_gpus and distribution_strategy, \n"
                    "num_workers>=1, num_gpus>=1, distribution_strategy=WhaleStrategy|ExascaleStrategy|CollectiveAllReduceStrategy \n"
                    "num_workers>1, num_gpus==1, distribution_strategy=MirroredStrategy \n"
                    "num_workers=1, num_gpus=1, distribution_strategy=None")

            # Validate optional keyword arguments.
            if "num_train_examples" not in kwargs:
                raise ValueError('Please pass num_train_examples')

            self.num_train_examples = kwargs['num_train_examples']

            # if save steps is None, save per epoch
            if self.config.save_steps is None:
                self.save_steps = int(self.num_train_examples / global_batch_size)
            else:
                self.save_steps = self.config.save_steps

            self.train_steps = int(self.num_train_examples *
                                   self.config.num_epochs / global_batch_size) + 1

            self.throttle_secs = self.config.throttle_secs
            self.model_dir = self.config.model_dir
            tf.logging.info("model_dir: {}".format(self.config.model_dir))
            tf.logging.info("num workers: {}".format(self.config.num_workers))
            tf.logging.info("num gpus: {}".format(self.config.num_gpus))
            tf.logging.info("learning rate: {}".format(self.config.learning_rate))
            tf.logging.info("train batch size: {}".format(self.config.train_batch_size))
            tf.logging.info("global batch size: {}".format(global_batch_size))
            tf.logging.info("num accumulated batches: {}".format(self.config.num_accumulated_batches))
            tf.logging.info("num model replica: {}".format(self.config.num_model_replica))
            tf.logging.info("num train examples per epoch: {}".format(self.num_train_examples))
            tf.logging.info("num epochs: {}".format(self.config.num_epochs))
            tf.logging.info("train steps: {}".format(self.train_steps))
            tf.logging.info("save steps: {}".format(self.save_steps))
            tf.logging.info("throttle secs: {}".format(self.throttle_secs))
            tf.logging.info("keep checkpoint max: {}".format(self.config.keep_checkpoint_max))
            tf.logging.info("warmup ratio: {}".format(self.config.warmup_ratio))
            tf.logging.info("gradient clip: {}".format(self.config.gradient_clip))
            tf.logging.info("log step count steps: {}".format(self.config.log_step_count_steps))

            if self.config.distribution_strategy != "WhaleStrategy":
                self.estimator = tf.estimator.Estimator(
                    model_fn=self._build_model_fn(),
                    model_dir=self.config.model_dir,
                    config=self._get_run_train_config(config=self.config))
            else:
                tf.logging.info("***********Using Whale Estimator***********")
                try:
                    from easytransfer.engines.whale_estimator import WhaleEstimator
                    import whale as wh
                    wh.init()
                    self.estimator = WhaleEstimator(
                        model_fn=self._build_model_fn(),
                        model_dir=self.config.model_dir,
                        num_model_replica=self.config.num_model_replica,
                        num_accumulated_batches=self.config.num_accumulated_batches)
                except:
                    raise NotImplementedError("WhaleStrategy doesn't work well")


            if self.config.mode == 'train_and_evaluate' or self.config.mode == 'train_and_evaluate_on_the_fly':
                self.num_eval_steps = self.config.num_eval_steps
                tf.logging.info("num eval steps: {}".format(self.num_eval_steps))

        elif self.config.mode == 'evaluate' or self.config.mode == 'evaluate_on_the_fly':
            self.num_eval_steps = self.config.num_eval_steps
            tf.logging.info("num eval steps: {}".format(self.num_eval_steps))
            tf.logging.info("***********Running in {} mode***********".format(self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'predict' or self.config.mode == 'predict_on_the_fly':
            tf.logging.info("***********Running in {} mode***********".format(self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'export':
            tf.logging.info("***********Running in {} mode***********".format(self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'preprocess':
            tf.logging.info("***********Running in {} mode***********".format(self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=tf.estimator.RunConfig())

            self.first_sequence = self.config.first_sequence
            self.second_sequence = self.config.second_sequence
            self.label_enumerate_values = self.config.label_enumerate_values
            self.label_name = self.config.label_name

    def _get_run_train_config(self, config):
        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            intra_op_parallelism_threads=1024,
            inter_op_parallelism_threads=1024,
            gpu_options=tf.GPUOptions(allow_growth=True,
                                      force_gpu_compatible=True,
                                      per_process_gpu_memory_fraction=1.0))

        run_config = tf.estimator.RunConfig(session_config=session_config,
                                            model_dir=config.model_dir,
                                            tf_random_seed=123123,
                                            train_distribute=self.strategy,
                                            log_step_count_steps=100,
                                            save_checkpoints_steps=self.save_steps,
                                            keep_checkpoint_max=config.keep_checkpoint_max
                                            )
        return run_config

    def _get_run_predict_config(self):

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            intra_op_parallelism_threads=1024,
            inter_op_parallelism_threads=1024,
            gpu_options=tf.GPUOptions(allow_growth=True,
                                      force_gpu_compatible=True,
                                      per_process_gpu_memory_fraction=1.0))

        run_config = tf.estimator.RunConfig(session_config=session_config)
        return run_config

    def _build_model_fn(self):
        def model_fn(features, labels, mode, params):
            if mode == tf.estimator.ModeKeys.TRAIN:
                logits, labels = self.build_logits(features, mode=mode)
                total_loss = self.build_loss(logits, labels)
                train_op = get_train_op(learning_rate=self.config.learning_rate,
                                        weight_decay_ratio=self.config.weight_decay_ratio,
                                        loss=total_loss,
                                        lr_decay=self.config.lr_decay,
                                        warmup_ratio=self.config.warmup_ratio,
                                        optimizer_name=self.config.optimizer,
                                        tvars=self.tvars if hasattr(self, "tvars") else None,
                                        train_steps=self.train_steps,
                                        clip_norm=self.config.gradient_clip,
                                        clip_norm_value=self.config.clip_norm_value,
                                        num_freezed_layers=self.config.num_freezed_layers
                                        )

                if self.config.distribution_strategy == "WhaleStrategy":
                    return total_loss, train_op

                avgloss_hook = avgloss_logger_hook(self.train_steps,
                                                   total_loss,
                                                   self.model_dir,
                                                   self.config.log_step_count_steps)

                summary_hook = tf.train.SummarySaverHook(save_steps=100, summary_op=tf.summary.merge_all())
                return tf.estimator.EstimatorSpec(
                    mode=mode, loss=total_loss, train_op=train_op,
                    training_hooks=[summary_hook, avgloss_hook])

            elif mode == tf.estimator.ModeKeys.EVAL:
                logits, labels = self.build_logits(features, mode=mode)
                eval_loss = self.build_loss(logits, labels)
                tf.summary.scalar("eval_loss", eval_loss)
                metrics = self.build_eval_metrics(logits, labels)
                summary_hook = tf.train.SummarySaverHook(save_steps=100,
                                                         summary_op=tf.summary.merge_all())
                return tf.estimator.EstimatorSpec(mode, loss=eval_loss,
                                                  eval_metric_ops=metrics,
                                                  evaluation_hooks=[summary_hook])

            elif mode == tf.estimator.ModeKeys.PREDICT:
                if self.config.mode == 'predict' or self.config.mode == 'export':
                    output = self.build_logits(features, mode=mode)
                    predictions = self.build_predictions(output)

                elif self.config.mode == 'predict_on_the_fly':
                    output = self.build_logits(features, mode=mode)
                    predictions = self.build_predictions(output)

                elif self.config.mode == 'preprocess':
                    output = self.build_logits(features, mode=mode)
                    predictions = self.build_predictions(output)
                else:
                    predictions = features

                output = {'serving_default': tf.estimator.export.PredictOutput(predictions)}
                predictions.update(features)

                return tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    export_outputs=output)

        return model_fn

    def run_train_and_evaluate(self, train_reader, eval_reader):
        train_spec = tf.estimator.TrainSpec(input_fn=train_reader.get_input_fn(),
                                            max_steps=self.train_steps)

        eval_spec = tf.estimator.EvalSpec(input_fn=eval_reader.get_input_fn(),
                                          steps=self.num_eval_steps,
                                          throttle_secs=self.throttle_secs)

        tf.estimator.train_and_evaluate(self.estimator,
                                        train_spec=train_spec,
                                        eval_spec=eval_spec)


    def run_train(self, reader):
        self.estimator.train(input_fn=reader.get_input_fn(),
                             max_steps=self.train_steps)

    def run_evaluate(self, reader, checkpoint_path=None):
        return self.estimator.evaluate(input_fn=reader.get_input_fn(),
                                       steps=self.num_eval_steps,
                                       checkpoint_path=checkpoint_path)

    def run_predict(self, reader, writer=None, checkpoint_path=None, yield_single_examples=False):

        if writer is None:
            return self.estimator.predict(
                input_fn=reader.get_input_fn(),
                yield_single_examples=yield_single_examples,
                checkpoint_path=checkpoint_path)

        for batch_idx, outputs in enumerate(self.estimator.predict(input_fn=reader.get_input_fn(),
                                                                   yield_single_examples=yield_single_examples,
                                                                   checkpoint_path=checkpoint_path)):

            if batch_idx % 100 == 0:
                tf.logging.info("Processing %d batches" % (batch_idx))
            writer.process(outputs)

        writer.close()

    def run_preprocess(self, reader, writer):
        for batch_idx, outputs in enumerate(self.estimator.predict(input_fn=reader.get_input_fn(),
                                                                   yield_single_examples=False,
                                                                   checkpoint_path=None)):
            if batch_idx % 100 == 0:
                tf.logging.info("Processing %d batches" % (batch_idx))
            writer.process(outputs)

        writer.close()

    def export_model(self):

        export_dir_base = self.config.export_dir_base
        checkpoint_path = self.config.checkpoint_path

        def serving_input_receiver_fn():
            export_features, receiver_tensors = self.get_export_features()
            return tf.estimator.export.ServingInputReceiver(
                features=export_features, receiver_tensors=receiver_tensors, receiver_tensors_alternatives={})

        return self.estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn,
                                                checkpoint_path=checkpoint_path)

[docs]class base_model(EzTransEstimator): def __init__(self, **kwargs): user_defined_config = kwargs.get("user_defined_config", None) if user_defined_config is None: assert FLAGS.mode is not None with tf.gfile.Open(FLAGS.config, "r") as f: tf.logging.info("config file is {}".format(FLAGS.config)) config_json = json.load(f) # enhance config_json config_json["worker_hosts"] = FLAGS.worker_hosts config_json["task_index"] = FLAGS.task_index config_json["job_name"] = FLAGS.job_name config_json["num_gpus"] = FLAGS.workerGPU config_json["num_workers"] = FLAGS.workerCount if FLAGS.tables is not None: if FLAGS.mode.startswith("train_and_evaluate"): config_json['train_config']['train_input_fp'] = FLAGS.tables.split(",")[0] config_json['evaluate_config']['eval_input_fp'] = FLAGS.tables.split(",")[1] elif FLAGS.mode.startswith("train"): config_json['train_config']['train_input_fp'] = FLAGS.tables.split(",")[0] elif FLAGS.mode.startswith("evaluate"): config_json['evaluate_config']['eval_input_fp'] = FLAGS.tables.split(",")[0] elif FLAGS.mode.startswith("predict"): config_json['predict_config']['predict_input_fp'] = FLAGS.tables.split(",")[0] elif FLAGS.mode.startswith("preprocess"): config_json['preprocess_config']['preprocess_input_fp'] = FLAGS.tables.split(",")[0] else: raise RuntimeError if FLAGS.outputs is not None: if FLAGS.mode.startswith("predict"): config_json['predict_config']['predict_output_fp'] = FLAGS.outputs.split(",")[0] elif FLAGS.mode.startswith("preprocess"): config_json['preprocess_config']['preprocess_output_fp'] = FLAGS.outputs.split(",")[0] else: raise RuntimeError if "predict" in FLAGS.mode: model_ckpt = config_json['predict_config']['predict_checkpoint_path'].split("/")[-1] config_fp = config_json['predict_config']['predict_checkpoint_path'].replace(model_ckpt, "train_config.json") if tf.gfile.Exists(config_fp): with tf.gfile.Open(config_fp, "r") as f: saved_config = json.load(f) model_config = saved_config.get("model_config", None) config_json["model_config"] = model_config self.config = Config(mode=FLAGS.mode, config_json=config_json) if "train" in FLAGS.mode: assert self.config.model_dir is not None if not tf.gfile.Exists(self.config.model_dir): tf.gfile.MakeDirs(self.config.model_dir) if not tf.gfile.Exists(self.config.model_dir + "/train_config.json"): with tf.gfile.GFile(self.config.model_dir + "/train_config.json", mode='w') as f: json.dump(config_json, f) else: self.config = user_defined_config for key, val in self.config.__dict__.items(): setattr(self, key, val) num_train_examples = 0 num_predict_examples = 0 if "train" in self.config.mode: if "odps://" in self.config.train_input_fp: reader = tf.python_io.TableReader(self.config.train_input_fp, selected_cols="", excluded_cols="", slice_id=0, slice_count=1, num_threads=0, capacity=0) num_train_examples = reader.get_row_count() elif ".tfrecord" in self.config.train_input_fp: for record in tf.python_io.tf_record_iterator(self.config.train_input_fp): num_train_examples += 1 elif ".list_tfrecord" in self.config.train_input_fp: with tf.gfile.Open(self.config.train_input_fp, 'r') as f: for i, line in enumerate(f): if i == 0 and line.strip().isdigit(): num_train_examples = int(line.strip()) tf.logging.info("Reading {} training examples from list_tfrecord".format(str(num_train_examples))) break if i%10 ==0: tf.logging.info("Reading {} files".format(i)) fp = line.strip() for record in tf.python_io.tf_record_iterator(fp): num_train_examples += 1 elif ".list_csv" in self.config.train_input_fp: with tf.gfile.Open(self.config.train_input_fp, 'r') as f: for i, line in enumerate(f): if i == 0 and line.strip().isdigit(): num_train_examples = int(line.strip()) tf.logging.info("Reading {} training examples from list_csv".format(str(num_train_examples))) break if i%10 ==0: tf.logging.info("Reading {} files".format(i)) fp = line.strip() with tf.gfile.Open(fp, 'r') as f: for record in f: num_train_examples += 1 else: with tf.gfile.Open(self.config.train_input_fp, 'r') as f: for record in f: num_train_examples += 1 assert num_train_examples > 0 tf.logging.info("total number of training examples {}".format(num_train_examples)) elif "predict" in self.config.mode: if "odps" in self.config.predict_input_fp: reader = tf.python_io.TableReader(self.config.predict_input_fp, selected_cols="", excluded_cols="", slice_id=0, slice_count=1, num_threads=0, capacity=0) num_predict_examples = reader.get_row_count() elif ".tfrecord" in self.config.predict_input_fp: for record in tf.python_io.tf_record_iterator(self.config.predict_input_fp): num_predict_examples += 1 elif ".list_csv" in self.config.predict_input_fp: with tf.gfile.Open(self.config.predict_input_fp, 'r') as f: for i, line in enumerate(f): if i == 0 and line.strip().isdigit(): num_predict_examples = int(line.strip()) tf.logging.info("Use preset num training examples") break if i%10 ==0: tf.logging.info("Reading {} files".format(i)) fp = line.strip() with tf.gfile.Open(fp, 'r') as f: for record in f: num_predict_examples += 1 else: with tf.gfile.Open(self.config.predict_input_fp, 'r') as f: for record in f: num_predict_examples += 1 assert num_predict_examples > 0 tf.logging.info("total number of predicting examples {}".format(num_predict_examples)) super(base_model, self).__init__(num_train_examples=num_train_examples) def get_export_features(self): export_features = {} for feat in self.config.input_tensors_schema.split(","): feat_name = feat.split(":")[0] feat_type = feat.split(":")[1] seq_len = int(feat.split(":")[2]) feat = {} feat['name'] = feat_name feat['type'] = feat_type if feat_type == "int": dtype = tf.int32 elif feat_type == "float": dtype = tf.float32 if seq_len == 1: ph = tf.placeholder(dtype=dtype, shape=[None], name=feat_name) else: ph = tf.placeholder(dtype=dtype, shape=[None, None], name=feat_name) export_features[feat_name] = ph receiver_tensors = {} feat_names = [] for feat in self.config.receiver_tensors_schema.split(","): feat_names.append(feat.split(":")[0]) for feat_name in feat_names: receiver_tensors[feat_name] = export_features[feat_name] return export_features, receiver_tensors
[docs] def build_logits(self, features, mode): """ Given features, this method take care of building graph for train/eval/predict Args: features : either raw text features or numerical features such as input_ids, input_mask ... mode : tf.estimator.ModeKeys.TRAIN | tf.estimator.ModeKeys.EVAL | tf.estimator.ModeKeys.PREDICT Returns: logits, labels Examples:: def build_logits(self, features, mode=None): preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path) model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path) dense = layers.Dense(self.num_labels, kernel_initializer=layers.get_initializer(0.02), name='dense') input_ids, input_mask, segment_ids, label_ids = preprocessor(features) outputs = model([input_ids, input_mask, segment_ids], mode=mode) pooled_output = outputs[1] logits = dense(pooled_output) return logits, label_ids """ raise NotImplementedError("must be implemented in descendants")
[docs] def build_loss(self, logits, labels): """Build loss function Args: logits : logits returned from build_logits labels : labels returned from build_logits Returns: loss Examples:: def build_loss(self, logits, labels): return softmax_cross_entropy(labels, depth=self.config.num_labels, logits=logits) """ raise NotImplementedError("must be implemented in descendants")
[docs] def build_eval_metrics(self, logits, labels): """Build evaluation metrics Args: logits : logits returned from build_logits labels : labels returned from build_logits Returns: metric_dict Examples:: def build_eval_metrics(self, logits, labels): predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) info_dict = { "predictions": predictions, "labels": labels, } evaluator = PyEvaluator() labels = [i for i in range(self.num_labels)] metric_dict = evaluator.get_metric_ops(info_dict, labels) ret_metrics = evaluator.evaluate(labels) tf.summary.scalar("eval accuracy", ret_metrics['py_accuracy']) tf.summary.scalar("eval F1 micro score", ret_metrics['py_micro_f1']) tf.summary.scalar("eval F1 macro score", ret_metrics['py_macro_f1']) return metric_dict """ raise NotImplementedError("must be implemented in descendants")
[docs] def build_predictions(self, logits): """Build predictions Args: logits : logits returned from build_logits Returns: predictions Examples:: def build_predictions(self, output): logits, _ = output predictions = dict() predictions["predictions"] = tf.argmax(logits, axis=-1, output_type=tf.int32) return predictions """ raise NotImplementedError("must be implemented in descendants")