# Source code for easytransfer.datasets.csv_writer

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import six
from easytransfer.engines.distribution import Process

class CSVWriter(Process):
    """Write batches of features to a text file, one row per example.

    Despite the class name, columns are separated by tabs. Inside a column,
    commas separate the individual values and semicolons separate nested
    rows (see ``_format_element``).

    Args:
        output_glob: path of the output file. It is opened for writing when
            the writer is constructed and closed by ``close()``.
        output_schema: comma-separated feature names. They select which
            entries of each ``features`` mapping are written, and in which
            column order.
        input_queue: forwarded unchanged to ``Process.__init__``.
        **kwargs: accepted for call compatibility and otherwise ignored.
    """

    def __init__(self, output_glob, output_schema, input_queue=None, **kwargs):
        job_name = 'DistTableWriter'
        # NOTE(review): the positional ``1`` is presumably the number of
        # workers for this Process -- confirm against Process.__init__.
        super(CSVWriter, self).__init__(job_name, 1, input_queue)

        # ``open`` only accepts ``encoding`` on Python 3.
        if six.PY3:
            self.writer = open(output_glob, "w", encoding='utf8')
        elif six.PY2:
            self.writer = open(output_glob, "w")
        self.output_schema = output_schema

    def close(self):
        """Log completion and close the output file."""
        tf.logging.info('Finished writing')
        self.writer.close()

    @staticmethod
    def _format_element(element):
        """Render one cell value as text.

        Args:
            element: a float, int or str; a flat list; or a list of lists.

        Returns:
            ``str(element)`` for scalars, ``''`` for an empty list, the items
            joined by ``,`` for a flat list, and for a list of lists each
            inner list joined by ``,`` with the inner lists joined by ``;``.

        Raises:
            RuntimeError: if ``element`` is none of the supported types.
        """
        if isinstance(element, (float, int, str)):
            return str(element)
        # Confirm the value is a list *before* indexing it, so that an
        # unsupported value reaches the RuntimeError below instead of failing
        # with an unrelated TypeError/KeyError on ``element[0]``.
        if isinstance(element, list):
            if not element:
                return ''
            # Only the first item is inspected to decide flat vs. nested.
            if isinstance(element[0], list):
                return ';'.join(
                    ','.join(str(t) for t in item) for item in element)
            # NOTE(review): on Python 3 a ``bytes`` item is written as
            # "b'...'" by ``str``; confirm whether string features should be
            # decoded before they reach this writer.
            return ','.join(str(t) for t in element)
        raise RuntimeError("type {} not support".format(type(element)))

    def process(self, features):
        """Write one batch of features as tab-separated rows.

        Args:
            features: mapping from feature name to a batch of values. Each
                batch must be iterable and expose ``.shape``; presumably a
                numpy array whose first axis is the batch -- confirm with the
                producer. Items of batches that are not 1-D must provide
                ``.tolist()``.

        Returns:
            None. One line per example is written to ``self.writer``.
        """
        columns = []
        for feat_name in self.output_schema.split(","):
            batch_feat_value = features[feat_name]
            # Items of a 1-D batch are wrapped in a one-item list so every
            # cell handed to the formatter is a list; higher-rank batches
            # contribute one (possibly nested) list per example.
            columns.append([
                [feat] if len(batch_feat_value.shape) == 1 else feat.tolist()
                for feat in batch_feat_value
            ])

        # zip() pairs the i-th example of every column and stops at the
        # shortest column.
        for row in zip(*columns):
            cells = [self._format_element(cell) for cell in row]
            self.writer.write("\t".join(cells) + "\n")