TFRecords是一种tensorflow的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议(百度百科给的定义:protocol buffer(以下简称PB)是google 的一种数据交换的格式,它独立于语言,独立于平台),其后缀一般为tfrecord。TFRecords文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是tensorflow“从文件里读取数据”的一种官方推荐方法.
其源代码主要位于文件tensorflow/python/lib/io/tf_record.py
官方例程tensorflow/examples/how_tos/reading_data/convert_to_records.py
第一步,生成TFRecord Writer
writer = tf.python_io.TFRecordWriter(path, options=None)
path:TFRecord文件的存放路径;
option:TFRecordOptions对象,定义TFRecord文件保存的压缩格式;
有三种文件压缩格式可选,分别为TFRecordCompressionType.ZLIB、TFRecordCompressionType.GZIP以及TFRecordCompressionType.NONE,默认为最后一种,即不做任何压缩,定义方法如下:
option = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
第二步,tf.train.Feature生成协议信息
一个协议信息特征(这里翻译可能不准确)是将原始数据编码成特定的格式,一般是features中又包含feature,内层feature是一个字典值,它是将某个类型列表编码成特定的feature格式,而该字典键用于读取TFRecords文件时索引得到不同的数据,某个类型列表可能包含零个或多个值,列表类型一般有BytesList, FloatList, Int64List,通常用如下方法来生成某个列表类型再送给内层的tf.train.Feature编码:
1
2
3
|
tf.train.BytesList(value=[value]) # value转化为字符串(二进制)列表
tf.train.FloatList(value=[value]) # value转化为浮点型列表
tf.train.Int64List(value=[value]) # value转化为整型列表
|
其中,value是你要保存的数据
内层feature编码方式:
1
2
3
4
5
|
feature_internal = {
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"weights":tf.train.Feature(float_list=tf.train.FloatList(value=[weights])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
}
|
外层features再将内层字典编码:
features_extern = tf.train.Features(feature_internal)
看起来,tf.train.Feature这个接口可以编码封装列表类型和字典类型,但注意用的接口是不一样的,内层用的是tf.train.Feature而外层用的是tf.train.Features,一个是对单一数据编码成单元feature,而另一个是将包含多个单元feature的字典数据再编码为集成的features。
第三步,使用tf.train.Example将features编码数据封装成特定的PB协议格式
example = tf.train.Example(features_extern)
第四步,将example数据系列化为字符串
example_str = example.SerializeToString()
第五步,将系列化为字符串的Example数据写入协议缓冲器
writer.write(example_str)
TFRecordWriter拥有类似python文件操作的接口,如writer.flush()立即将内存数据刷新到磁盘文件里,writer.close()关闭TFRecordWriter,在写完数据到协议缓冲区后通常需要调用writer.close()主动关闭TFRecords文件操作接口。
TFRecords文件的读取主要是使用tf.TFRecordReader和tf.python_io.tf_record_iterator其源代码位于tensorflow/python/ops/io_ops.py和tensorflow/python/lib/io/tf_record.py
第一步,使用tf.train.string_input_producer生成文件队列
filename_queues = tf.train.string_input_producer([tfrecord_path_none,tfrecord_path_zlib,tfrecord_path_gzip])
第二步,生成TFRecord Reader
reader = tf.TFRecordReader(name=None, options=None)
第三步,读取TFRecord文件
_,serialized_example = reader.read(filename)
filename是tf.train.string_input_producer得到的文件队列名,读取得到的是一个系列化的example。
第四步,使用tf.parse_single_example解析得到的系列化example
features = tf.parse_single_example(
serialized_example,
features={
"float_val":tf.FixedLenFeature([], tf.float),
"width":tf.FixedLenFeature([], tf.int64),
"height":tf.FixedLenFeature([], tf.int64),
"image_raw":tf.FixedLenFeature([], tf.string)
}
)
需要按照存储时的格式还原features,必须写明features内的字典的键索引得到特定的数据!
第五步,处理得到的数据
features是一个字典,要使用特定数据,需要用字典的key来索引得到相应的数据,如要得到的width的值,则可以以features['width']得到,对于得到的数据还需要做一些处理的,比如features['image_raw']需要decode才能显示整个图片。
在tensorflow2.0中实现了以上操作 贴一下代码,还有一些问题没有解决
后续遇到再补充
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
filename = "pic.jpg"
image = tf.io.read_file(filename)
image_jpeg = tf.io.decode_image(image,channels=3,name="decode_jpeg_picture",expand_animations=True)
image_jpeg = tf.reshape(image_jpeg,shape=(4032,3024,3))
img_shape = image_jpeg.shape
width = img_shape[0]
height = img_shape[1]
sess = tf.Session()
image = sess.run(image)
sess.close()
path_none = "img_compress_none.tfrecord"
path_zlib = "img_compress_zlib.tfrecord"
path_gzip = "img_compress_gzip.tfrecord"
options_zlib = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
options_gzip = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
# 定义不同压缩选项的TFRecordWriter
writer_none = tf.python_io.TFRecordWriter(path_none, options=None)
writer_zlib = tf.python_io.TFRecordWriter(path_zlib, options=options_zlib)
writer_gzip = tf.python_io.TFRecordWriter(path_gzip, options=options_gzip)
# 编码数据,将数据生成特定类型列表后再编码为feature格式,再生成字典形式
feature_internal_none = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[9.99])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
feature_internal_zlib = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[8.88])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
feature_internal_gzip = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[6.66])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
# 编码内层字典形式数据
features_extern_none = tf.train.Features(feature=feature_internal_none)
features_extern_zlib = tf.train.Features(feature=feature_internal_zlib)
features_extern_gzip = tf.train.Features(feature=feature_internal_gzip)
# 将外层features生成特定格式的example
example_none = tf.train.Example(features=features_extern_none)
example_zlib = tf.train.Example(features=features_extern_zlib)
example_gzip = tf.train.Example(features=features_extern_gzip)
# example系列化字符串
example_str_none = example_none.SerializeToString()
example_str_zlib = example_zlib.SerializeToString()
example_str_gzip = example_gzip.SerializeToString()
# 将系列化字符串写入协议缓冲区
writer_none.write(example_str_none)
writer_zlib.write(example_str_zlib)
writer_gzip.write(example_str_gzip)
# 关闭TFRecords文件操作接口
writer_none.close()
writer_zlib.close()
writer_gzip.close()
print("finish to write data to tfrecord file!")
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
#!/usr/bin/env python
# coding: utf-8
# In[5]:
import tensorflow as tf
# In[10]:
from IPython import display
# In[11]:
raw_image_dataset = tf.data.TFRecordDataset('img_compress_none.tfrecord')
# In[12]:
image_feature_description = {
"float_val":tf.io.FixedLenFeature([], tf.float32),
"width":tf.io.FixedLenFeature([], tf.int64),
"height":tf.io.FixedLenFeature([], tf.int64),
"image_raw":tf.io.FixedLenFeature([], tf.string)
}
# In[13]:
def _parse_image_function(example_proto):
# Parse the input tf.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
# In[14]:
for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))
|
tensorflow1.x与tensorflow2.x兼容的两种方式:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.compat.v1.Session()