Source code for easytransfer.datasets.csv_reader

# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from .reader import Reader

class CSVReader(Reader):
    """ Read csv format data

    Args:

        input_glob : input file path
        batch_size : input batch size
        is_training : True or False
        thread_num : number of reader threads
    """
    def __init__(self, input_glob, batch_size, is_training, thread_num=1,
                 input_queue=None, output_queue=None,
                 job_name='DISTCSVReader', **kwargs):

        super(CSVReader, self).__init__(batch_size, is_training, thread_num,
                                        input_queue, output_queue, job_name,
                                        **kwargs)

        self.input_glob = input_glob

        # Count the examples up front so the caller knows the epoch size.
        if is_training:
            with tf.gfile.Open(input_glob, 'r') as f:
                for record in f:
                    self.num_train_examples += 1
            tf.logging.info("{}, total number of training examples {}".format(
                input_glob, self.num_train_examples))
        else:
            with tf.gfile.Open(input_glob, 'r') as f:
                for record in f:
                    self.num_eval_examples += 1
            tf.logging.info("{}, total number of eval examples {}".format(
                input_glob, self.num_eval_examples))

        self.csv_reader = tf.gfile.Open(input_glob)

    def get_input_fn(self):
        def input_fn():
            dataset = tf.data.TextLineDataset(self.input_glob)
            return self._get_data_pipeline(dataset, self._decode_csv)
        return input_fn

    def _decode_csv(self, record):
        record_defaults = []
        tensor_names = []
        shapes = []
        for name, feature in self.input_tensors.items():
            default_value = feature.default_value
            shape = feature.shape
            # A multi-element feature is stored in a single string column,
            # either base64-encoded float32 bytes or comma-separated numbers.
            if shape[0] > 1 and default_value != 'base64':
                default_value = ''
            record_defaults.append([default_value])
            tensor_names.append(name)
            shapes.append(feature.shape)

        num_tensors = len(tensor_names)
        items = tf.decode_csv(record, field_delim='\t',
                              record_defaults=record_defaults,
                              use_quote_delim=False)

        outputs = dict()
        total_shape = sum(sum(shape) for shape in shapes)
        for idx, (name, feature) in enumerate(self.input_tensors.items()):
            # finetune feature_text
            if total_shape != num_tensors:
                input_tensor = items[idx]
                if sum(feature.shape) > 1:
                    default_value = record_defaults[idx]
                    if default_value[0] == '':
                        # Comma-separated numbers, e.g. "0.1,0.2,0.3"
                        output = tf.string_to_number(
                            tf.string_split(tf.expand_dims(input_tensor, axis=0),
                                            delimiter=",").values,
                            feature.dtype)
                        output = tf.reshape(output, [feature.shape[0], ])
                    elif default_value[0] == 'base64':
                        # base64-encoded raw float32 bytes
                        decode_b64_data = tf.io.decode_base64(
                            tf.expand_dims(input_tensor, axis=0))
                        output = tf.reshape(
                            tf.io.decode_raw(decode_b64_data, out_type=tf.float32),
                            [feature.shape[0], ])
                else:
                    output = tf.reshape(input_tensor, [1, ])
            # preprocess raw_text
            elif total_shape == num_tensors:
                output = items[idx]
            outputs[name] = output
        return outputs
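    # Illustrative note (an assumption drawn from _decode_csv above, not part
    # of the original source): with input_tensors holding, say, an "emb"
    # feature of shape [3] and a "label" feature of shape [1], each input line
    # would be expected to look like
    #
    #     0.1,0.2,0.3<TAB>1
    #
    # i.e. a multi-element feature occupies one tab-separated column, encoded
    # as comma-separated numbers or as base64 float32 bytes.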
    def process(self, input_data):
        for line in self.csv_reader:
            line = line.strip()
            segments = line.split("\t")
            output_dict = {}
            for idx, name in enumerate(self.input_tensor_names):
                output_dict[name] = segments[idx]
            self.put(output_dict)
        # Signal that the table has been fully consumed.
        raise IndexError("Read table done")
    def close(self):
        self.csv_reader.close()
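# Usage sketch (illustrative, not part of the original module). The file name
# "train.tsv" and the batch size are hypothetical placeholders:
#
#     reader = CSVReader(input_glob="train.tsv", batch_size=32, is_training=True)
#     input_fn = reader.get_input_fn()
#     # input_fn can then be passed to, e.g., tf.estimator.Estimator.train(...)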
class BundleCSVReader(CSVReader):
    """ Read a group of csv format files

    Args:

        input_glob : path to a manifest file listing one input file per line
        batch_size : input batch size
        worker_hosts : comma-separated string of worker hosts
        task_index : index of this worker
        is_training : True or False
    """
    def __init__(self, input_glob, batch_size, worker_hosts, task_index,
                 is_training=False, **kwargs):

        super(BundleCSVReader, self).__init__(input_glob, batch_size,
                                              is_training, **kwargs)

        # Collect the input file paths, skipping blank and purely numeric lines.
        self.input_fps = []
        with tf.gfile.Open(input_glob, 'r') as f:
            for line in f:
                line = line.strip()
                if line == '' or line.isdigit():
                    continue
                self.input_fps.append(line)

        self.worker_hosts = worker_hosts
        self.task_index = task_index

    def get_input_fn(self):
        def input_fn():
            if self.is_training:
                d = tf.data.Dataset.from_tensor_slices(tf.constant(self.input_fps))
                # Each worker reads a disjoint shard of the file list.
                d = d.shard(len(self.worker_hosts.split(',')), self.task_index)
                d = d.repeat()
                d = d.shuffle(buffer_size=len(self.input_fps))
                cycle_length = min(4, len(self.input_fps))
                d = d.apply(
                    tf.data.experimental.parallel_interleave(
                        tf.data.TextLineDataset,
                        sloppy=True,
                        cycle_length=cycle_length))
                d = d.shuffle(buffer_size=self.shuffle_buffer_size)
            else:
                d = tf.data.Dataset.from_tensor_slices(tf.constant(self.input_fps))
                d = d.shard(len(self.worker_hosts.split(',')), self.task_index)
                d = d.repeat(1)
                cycle_length = min(4, len(self.input_fps))
                d = d.apply(
                    tf.data.experimental.parallel_interleave(
                        tf.data.TextLineDataset,
                        sloppy=True,
                        cycle_length=cycle_length))
                # Since we evaluate for a fixed number of steps we don't want
                # to encounter out-of-range exceptions, so the eval dataset is
                # not repeated.
            d = self._map_batch_prefetch(d, self._decode_csv)
            return d
        return input_fn
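# Usage sketch (illustrative, not part of the original module). The manifest
# name, host string, and task index below are hypothetical placeholders; in a
# distributed job they would typically come from the cluster configuration:
#
#     reader = BundleCSVReader(input_glob="file_list.txt", batch_size=32,
#                              worker_hosts="host0:2222,host1:2222",
#                              task_index=0, is_training=True)
#     input_fn = reader.get_input_fn()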