基本要求:
熟悉 C++ 编程。
确保下载 TensorFlow源文件 , 并可编译使用。
我们将支持文件格式的任务分成两部分:
文件格式: 我们使用 Reader Op来从文件中读取一个 record (可以使任意字符串)。
记录格式: 我们使用解码器或者解析运算将一个字符串记录转换为TensorFlow可以使用的张量。
例如, 读取一个 CSV 文件,我们使用 一个文本读写器, 然后是从一行文本中解析CSV数据的运算。
主要内容
自定义数据读取
编写一个文件格式读写器
编写一个记录格式Op
编写一个文件格式读写器
Reader 是专门用来读取文件中的记录的。TensorFlow中内建了一些读写器Op的实例:
tf.TFRecordReader (代码位于kernels/tf_record_reader_op.cc)
tf.FixedLengthRecordReader (代码位于 kernels/fixed_length_record_reader_op.cc)
tf.TextLineReader (代码位于 kernels/text_line_reader_op.cc)
你可以看到这些读写器的界面是一样的,唯一的差异是在它们的构造函数中。最重要的方法是 Read。
它需要一个行列参数,通过这个行列参数,可以在需要的时候随时读取文件名 (例如: 当 Read
Op首次运行,或者前一个 Read` 从一个文件中读取最后一条记录时)。它将会生成两个标量张量:
一个字符串和一个字符串关键值。
新创建一个名为 SomeReader 的读写器,需要以下步骤:
在 C++ 中, 定义一个 tensorflow::ReaderBase的子类,命名为 "SomeReader".
在 C++ 中,注册一个新的读写器Op和Kernel,命名为 "SomeReader"。
在 Python 中, 定义一个 tf.ReaderBase 的子类,命名为 "SomeReader"。
你可以把所有的 C++ 代码放在 tensorflow/core/user_ops/some_reader_op.cc文件中.
读取文件的代码将被嵌入到C++ 的 ReaderBase 类的迭代中。 这个 ReaderBase
类 是在 tensorflow/core/kernels/reader_base.h 中定义的。
你需要执行以下的方法:
OnWorkStartedLocked:打开下一个文件
ReadLocked:读取一个记录或报告 EOF/error
OnWorkFinishedLocked:关闭当前文件
ResetLocked:清空记录,例如:一个错误记录
以上这些方法的名字后面都带有 "Locked", 表示 ReaderBase
在调用任何一个方法之前确保获得互斥锁,这样就不用担心线程安全(虽然只保护了该类中的元素而不是全局的)。
对于 OnWorkStartedLocked, 需要打开的文件名是 current_work()
函数的返回值。此时的 ReadLocked 的数字签名如下:
Status ReadLocked(string*
key, string* value, bool* produced, bool*
at_end)
|
如果 ReadLocked 从文件中成功读取了一条记录,它将更新为:
*key: 记录的标志位,通过该标志位可以重新定位到该记录。 可以包含从 current_work()
返回值获得的文件名,并追加一个记录号或其他信息。
*value: 包含记录的内容。
*produced: 设置为 true。
当你在文件(EOF)末尾,设置 *at_end 为 true ,在任何情况下,都将返回 Status::OK()。
当出现错误的时候,只需要使用 tensorflow/core/lib/core/errors.h
中的一个辅助功能就可以简单地返回,不需要做任何参数修改。
接下来你讲创建一个实际的读写器Op。 如果你已经熟悉了添加新的Op 那会很有帮助。 主要步骤如下:
注册Op。
定义并注册 OpKernel。
要注册Op,你需要用到一个调用指令定义在 tensorflow/core/framework/op.h中的REGISTER_OP。
读写器 Op 没有输入,只有 Ref(string) 类型的单输出。它们调用 SetIsStateful(),并有一个
container 字符串和 shared_name 属性. 你可以在一个 Doc 中定义配置或包含文档的额外属性。
例如:详见 tensorflow/core/ops/io_ops.cc等:
#include "tensorflow/core/framework/op.h"
REGISTER_OP("TextLineReader")
.Output("reader_handle: Ref(string)")
.Attr("skip_header_lines: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.Doc(R"doc(
A Reader that outputs the lines of a file
delimited by '\n'.
)doc");
|
要定义一个 OpKernel, 读写器可以使用定义在tensorflow/core/framework/reader_op_kernel.h中的
ReaderOpKernel 的递减快捷方式,并运行一个叫 SetReaderFactory
的构造函数。 定义所需要的类之后,你需要通过 REGISTER_KERNEL_BUILDER(...)
注册这个类。
一个没有属性的例子:
#include "tensorflow/core/framework/reader_op_kernel.h"
class TFRecordReaderOp : public ReaderOpKernel
{
public:
explicit TFRecordReaderOp (OpKernelConstruction*
context)
: ReaderOpKernel(context) {
Env* env = context->env();
SetReaderFactory([this, env]() { return
new TFRecordReader(name(), env); });
}
};
REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
TFRecordReaderOp);
|
一个带有属性的例子:
#include "tensorflow/core/framework/reader_op_kernel.h"
class TextLineReaderOp : public ReaderOpKernel
{
public:
explicit TextLineReaderOp(OpKernelConstruction*
context)
: ReaderOpKernel(context) {
int skip_header_lines = -1;
OP_REQUIRES_OK(context,
context->GetAttr("skip_header_lines",
&skip_header_lines));
OP_REQUIRES(context, skip_header_lines >=
0,
errors::InvalidArgument("skip_header_lines
must be >= 0 not ",
skip_header_lines));
Env* env = context->env();
SetReaderFactory([this, skip_header_lines,
env]() {
return new TextLineReader(name(), skip_header_lines,
env);
});
}
};
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
TextLineReaderOp);
|
最后一步是添加 Python 包装器,你需要将 tensorflow.python.ops.io_ops
导入到 tensorflow/python/user_ops/user_ops.py,并添加一个
io_ops.ReaderBase的衍生函数。
from tensorflow.python.framework
import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops
class SomeReader(io_ops.ReaderBase):
def __init__(self, name=None):
rr = gen_user_ops.some_reader(name=name)
super(SomeReader, self).__init__(rr)
ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)
|
你可以在 tensorflow/python/ops/io_ops.py中查看一些范例。
编写一个记录格式Op
一般来说,这是一个普通的Op, 需要一个标量字符串记录作为输入, 因此遵循 添加Op的说明。
你可以选择一个标量字符串作为输入, 并包含在错误消息中报告不正确的格式化数据。
用于解码记录的运算实例:
tf.parse_single_example (and tf.parse_example)
tf.decode_csv
tf.decode_raw
请注意,使用多个Op 来解码某个特定的记录格式也是有效的。 例如,你有一张以字符串格式保存在
tf.train.Example 协议缓冲区的图像文件。 根据该图像的格式, 你可能从 tf.parse_single_example
的Op 读取响应输出并调用 tf.decode_jpeg, tf.decode_png, 或者
tf.decode_raw。通过读取 tf.decode_raw 的响应输出并使用tf.slice
和 tf.reshape 来提取数据是通用的方法。 |