# Source code for easytransfer.datasets.odps_table_writer

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from easytransfer.engines.distribution import Process, Counter

class OdpsTableWriter(Process):
    """Write batches of features to an ODPS table.

    Every feature named in ``output_schema`` becomes one string column of the
    output table; each batch handed to :meth:`process` is serialized row by
    row and written through ``tf.python_io.TableWriter``.

    Args:
        output_glob: output table location, passed to
            ``tf.python_io.TableWriter``.
        output_schema: comma-separated feature names, one table column per
            name, in column order.
        slice_id: passed to ``tf.python_io.TableWriter`` as ``slice_id``.
        input_queue: forwarded to ``Process.__init__``.
        job_name: forwarded to ``Process.__init__``.
        **kwargs: accepted for call-site compatibility; not used here.
    """

    # Schemas whose features are serialized by applying ``str()`` directly to
    # every item of every row (no ``tolist()`` conversion), so the textual
    # form of the values stays exactly what it has always been for them.
    _ROW_JOINED_SCHEMAS = (
        "input_ids,input_mask,segment_ids,label_id",
        "input_ids,input_mask,segment_ids,masked_lm_positions,masked_lm_ids,masked_lm_weights",
    )

    def __init__(self, output_glob, output_schema, slice_id, input_queue,
                 job_name='DistOdpsTableWriter', **kwargs):
        # NOTE(review): the literal 1 is the second positional argument of
        # Process.__init__; its meaning is not visible from this file.
        super(OdpsTableWriter, self).__init__(job_name, 1, input_queue)
        self.table_writer = tf.python_io.TableWriter(output_glob, slice_id=slice_id)
        self.output_schema = output_schema
        # One column index per schema entry. Deriving the indices from the
        # schema keeps them in step with the rows written by process(); the
        # three-column "input_ids,input_mask,segment_ids" schema previously got
        # four indices although only three values per row are written.
        self.output_indices = list(range(len(output_schema.split(","))))
        self.counter = Counter()

    def close(self):
        """Log the shutdown and close the underlying table writer."""
        tf.logging.info('close table writer')
        self.table_writer.close()

    def process(self, features):
        """Serialize one batch of features and append it to the table.

        Args:
            features: mapping from feature name to a batch of values. Every
                name listed in ``output_schema`` must be present. For the
                schemas in ``_ROW_JOINED_SCHEMAS`` each value is iterated as
                rows of items; for any other schema each value must expose
                ``.shape`` and, when not one-dimensional, yield rows that
                support ``.tolist()``.

        Raises:
            KeyError: if a schema feature is missing from ``features``.
            RuntimeError: if a value of a generic schema has an unsupported
                type.
        """
        feature_names = self.output_schema.split(",")
        if self.output_schema in self._ROW_JOINED_SCHEMAS:
            columns = [
                [self._join_row(row) for row in features[name]]
                for name in feature_names
            ]
        else:
            columns = [self._format_column(features[name]) for name in feature_names]
        # Transpose the per-feature columns into per-record tuples.
        self.table_writer.write(list(zip(*columns)), self.output_indices)
        # Count every batch that has been written.
        self.counter.count()

    @staticmethod
    def _join_row(row):
        """Join the items of one row into a comma-separated string."""
        return ','.join(str(item) for item in row)

    @staticmethod
    def _format_element(element):
        """Render a scalar, a flat list or a list of lists as a string.

        Flat lists are comma-separated; nested lists use ';' between the
        inner lists and ',' inside them. An empty list becomes ''.
        """
        if isinstance(element, (float, int, str)):
            return str(element)
        if isinstance(element, list):
            if not element:
                return ''
            if isinstance(element[0], list):
                return ';'.join(','.join(str(t) for t in item) for item in element)
            return ','.join(str(t) for t in element)
        raise RuntimeError("type {} not support".format(type(element)))

    @classmethod
    def _format_column(cls, batch_feat_value):
        """Serialize every row of one feature batch for a generic schema."""
        # The rank of the batch is the same for all rows, so test it once.
        if len(batch_feat_value.shape) == 1:
            rows = [[feat] for feat in batch_feat_value]
        else:
            rows = [feat.tolist() for feat in batch_feat_value]
        return [cls._format_element(row) for row in rows]