Source code for easytransfer.datasets.tfrecord_writer

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from easytransfer.engines.distribution import Process
import collections

class TFRecordWriter(Process):
    """Write batches of features to a TFRecord file as ``tf.train.Example``s.

    Args:
        output_glob: Path handed to ``tf.python_io.TFRecordWriter``.
        output_schema: Comma-separated feature names. Each name must be a key
            of the ``features`` mapping passed to :meth:`process`; the names
            also become the feature keys of every written example.
        input_queue: Forwarded unchanged to ``Process.__init__``.
    """

    def __init__(self, output_glob, output_schema, input_queue):
        job_name = 'DistTFRecordWriter'
        # NOTE(review): the literal 1 is forwarded as the second positional
        # argument of Process.__init__ -- presumably the worker count; confirm
        # against the Process base class.
        super(TFRecordWriter, self).__init__(job_name, 1, input_queue)

        self.writer = tf.python_io.TFRecordWriter(output_glob)
        self.output_schema = output_schema

    def close(self):
        """Log completion and close the underlying TFRecord writer."""
        tf.logging.info('Finished writing')
        self.writer.close()

    def create_int_feature(self, values):
        """Return a ``tf.train.Feature`` holding ``values`` as an int64 list."""
        return tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(values)))

    def create_float_feature(self, values):
        """Return a ``tf.train.Feature`` holding ``values`` as a float list."""
        return tf.train.Feature(
            float_list=tf.train.FloatList(value=list(values)))

    def _make_feature(self, value):
        """Build the feature for one per-example value list.

        The feature type is chosen from the Python type of ``value[0]``:
        ``float`` -> float list, ``int`` -> int64 list, ``str`` -> every
        element is parsed with ``int()`` and stored as an int64 list.

        Returns:
            The ``tf.train.Feature``, or ``None`` when the type of the first
            element is not one of the supported ones.
        """
        first = value[0]
        if isinstance(first, float):
            return self.create_float_feature(value)
        if isinstance(first, int):
            return self.create_int_feature(value)
        if isinstance(first, str):
            return self.create_int_feature([int(x) for x in value])
        return None

    def process(self, features):
        """Serialize one batch and append one record per example.

        Args:
            features: Mapping from feature name to a batch of values. Each
                batch must expose ``.shape`` and be iterable along its first
                axis; for batches of rank > 1 each row must provide
                ``tolist()`` (NumPy arrays satisfy all of this).
        """
        feat_names = self.output_schema.split(",")

        # One entry per feature name; each entry holds one value list per
        # example in the batch.
        columns = []
        for feat_name in feat_names:
            batch_feat_value = features[feat_name]
            if len(batch_feat_value.shape) == 1:
                # Rank-1 batch: every example carries a single scalar.
                # Convert NumPy scalars to native Python values via item() so
                # the isinstance() dispatch in _make_feature recognises them;
                # on Python 3 np.int64 / np.float32 are not int / float
                # subclasses and the feature would otherwise be dropped
                # silently.
                per_example = [
                    [feat.item() if hasattr(feat, "item") else feat]
                    for feat in batch_feat_value
                ]
            else:
                # tolist() already yields native Python values.
                per_example = [feat.tolist() for feat in batch_feat_value]
            columns.append(per_example)

        # Transpose feature-major columns into one tuple of values per example.
        for example_values in zip(*columns):
            example_features = collections.OrderedDict()
            for feat_name, value in zip(feat_names, example_values):
                feature = self._make_feature(value)
                # Values of an unsupported type are left out of the example,
                # as before.
                if feature is not None:
                    example_features[feat_name] = feature
            tf_example = tf.train.Example(
                features=tf.train.Features(feature=example_features))
            self.writer.write(tf_example.SerializeToString())