# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from .reader import Reader
class OdpsTableReader(Reader):
    """Read an ODPS table for training / evaluation.

    Validates that every requested input column exists in the table schema,
    records the total example count, and precomputes per-column record
    defaults and shapes used later when decoding records into tensors.

    Args:
        input_glob : input ODPS table path.
        batch_size : input batch size.
        is_training : True for training, False for eval/predict.
        thread_num : number of reader threads (passed to base Reader).
        input_queue : optional input queue (passed to base Reader).
        output_queue : optional output queue (passed to base Reader).
        slice_id : id of this worker's table slice.
        slice_count : total number of table slices.
        job_name : reader job name (passed to base Reader).

    Raises:
        ValueError: if a requested input column is missing from the table schema.
    """

    def __init__(self,
                 input_glob,
                 batch_size,
                 is_training,
                 thread_num=1,
                 input_queue=None,
                 output_queue=None,
                 slice_id=0,
                 slice_count=1,
                 job_name='DISTOdpsTableReader',
                 **kwargs):
        super(OdpsTableReader, self).__init__(batch_size,
                                              is_training,
                                              thread_num,
                                              input_queue,
                                              output_queue,
                                              job_name,
                                              **kwargs)
        self.input_glob = input_glob
        self.table_reader = tf.python_io.TableReader(
            input_glob,
            selected_cols=','.join(self.input_tensor_names),
            slice_id=slice_id,
            slice_count=slice_count)

        # Fail fast if any requested column is absent from the table schema.
        table_schema_list = [item[0] for item in self.table_reader.get_schema().tolist()]
        for input_column_name in self.input_tensor_names:
            if input_column_name not in table_schema_list:
                raise ValueError("{} doesn't appear in odps table schema {}"
                                 .format(input_column_name, ",".join(table_schema_list)))

        if is_training:
            self.num_train_examples = self.table_reader.get_row_count()
            tf.logging.info("{}, total number of training examples {}".format(
                input_glob, self.num_train_examples))
        else:
            self.num_eval_examples = self.table_reader.get_row_count()
            tf.logging.info(
                "{}, total number of eval or predict examples {}".format(input_glob, self.num_eval_examples))

        self.record_defaults = []
        self.feature_types = []
        # NOTE(review): the slice_id/slice_count constructor arguments are used
        # only by self.table_reader above; the dataset built in get_input_fn
        # always reads the whole table (slice 0 of 1). Confirm this is
        # intentional before changing.
        self.slice_id = 0
        self.slice_count = 1
        self.shapes = []
        for name, tensor in self.input_tensors.items():
            if tensor.shape[0] > 1:
                # Multi-valued columns arrive as strings: 'base64' marks
                # base64-encoded raw float32 bytes, '' marks comma-separated
                # numbers (see _decode_odps_table).
                default_value = 'base64' if tensor.default_value == 'base64' else ''
            else:
                default_value = tensor.default_value
            self.record_defaults.append([default_value])
            self.shapes.append(tensor.shape)

    def get_input_fn(self):
        """Return an Estimator-style input_fn producing decoded batches."""
        def input_fn():
            dataset = tf.data.TableRecordDataset(self.input_glob,
                                                 record_defaults=self.record_defaults,
                                                 selected_cols=','.join(self.input_tensor_names),
                                                 slice_id=self.slice_id,
                                                 slice_count=self.slice_count)
            return self._get_data_pipeline(dataset, self._decode_odps_table)
        return input_fn

    def _decode_odps_table(self, *items):
        """Decode one table record into a dict of named tensors.

        When every feature is scalar (the sum of all shapes equals the number
        of tensors), columns are passed through unchanged (raw-text
        preprocessing). Otherwise, multi-valued string columns are parsed as
        either comma-separated numbers or base64-encoded float32, according to
        the record default chosen in __init__.
        """
        num_tensors = len(self.input_tensor_names)
        total_shape = sum(sum(shape) for shape in self.shapes)
        ret = dict()
        for idx, (name, feature) in enumerate(self.input_tensors.items()):
            if total_shape == num_tensors:
                # preprocess raw_text: every feature is scalar, pass through
                output = items[idx]
            else:
                # finetune feature_text
                input_tensor = tf.squeeze(items[idx])
                if sum(feature.shape) > 1:
                    default_value = self.record_defaults[idx]
                    if default_value[0] == '':
                        # comma-separated numeric string -> 1-D tensor
                        output = tf.string_to_number(
                            tf.string_split(tf.expand_dims(input_tensor, axis=0), delimiter=",").values,
                            feature.dtype)
                        output = tf.reshape(output, [feature.shape[0], ])
                    elif default_value[0] == 'base64':
                        # base64-encoded raw float32 bytes -> 1-D tensor
                        decode_b64_data = tf.io.decode_base64(input_tensor)
                        output = tf.reshape(tf.io.decode_raw(decode_b64_data, out_type=tf.float32),
                                            [feature.shape[0], ])
                else:
                    # scalar feature kept as a shape-[1] tensor
                    output = tf.reshape(input_tensor, [1, ])
            ret[name] = output
        return ret

    def process(self, input_data):
        """Read batches from the table and enqueue one dict per record.

        Loops until the table is exhausted; exhaustion is signaled to the
        caller by raising IndexError (both stock-TF and PAI-TF out-of-range
        exceptions are handled).
        """
        while True:
            try:
                batch_records = self.table_reader.read(self.batch_size)
                for record in batch_records:
                    output_dict = {}
                    for idx, name in enumerate(self.input_tensor_names):
                        output_dict[name] = record[idx]
                    self.put(output_dict)
            except tf.errors.OutOfRangeError:
                raise IndexError('read table data done')
            except tf.python_io.OutOfRangeException:
                raise IndexError('read table data done')