[Tf]TFRecords文件的生成和读取方法

TFRecords文件的生成和读取方法

TFRecords是一种tensorflow的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议(百度百科给的定义:protocol buffer(以下简称PB)是google 的一种数据交换的格式,它独立于语言,独立于平台),其后缀一般为tfrecord。TFRecords文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是tensorflow“从文件里读取数据”的一种官方推荐方法.

TFRecords文件的生成

其源代码主要位于文件tensorflow/python/lib/io/tf_record.py 官方例程tensorflow/examples/how_tos/reading_data/convert_to_records.py

第一步,生成TFRecord Writer writer = tf.python_io.TFRecordWriter(path, options=None)

path:TFRecord文件的存放路径;

option:TFRecordOptions对象,定义TFRecord文件保存的压缩格式;

有三种文件压缩格式可选,分别为TFRecordCompressionType.ZLIB、TFRecordCompressionType.GZIP以及TFRecordCompressionType.NONE,默认为最后一种,即不做任何压缩,定义方法如下:

option = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)

第二步,tf.train.Feature生成协议信息 一个协议信息特征(这里翻译可能不准确)是将原始数据编码成特定的格式,一般是features中又包含feature,内层feature是一个字典值,它是将某个类型列表编码成特定的feature格式,而该字典键用于读取TFRecords文件时索引得到不同的数据,某个类型列表可能包含零个或多个值,列表类型一般有BytesList, FloatList, Int64List,通常用如下方法来生成某个列表类型再送给内层的tf.train.Feature编码:

1
2
3
tf.train.BytesList(value=[value]) # value转化为字符串(二进制)列表
tf.train.FloatList(value=[value]) # value转化为浮点型列表
tf.train.Int64List(value=[value]) # value转化为整型列表

其中,value是你要保存的数据

内层feature编码方式:

1
2
3
4
5
feature_internal = {
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"weights":tf.train.Feature(float_list=tf.train.FloatList(value=[weights])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
}

外层features再将内层字典编码: features_extern = tf.train.Features(feature_internal)

看起来,tf.train.Feature这个接口可以编码封装列表类型和字典类型,但注意用的接口是不一样的,内层用的是tf.train.Feature而外层用的是tf.train.Features,一个是对单一数据编码成单元feature,而另一个是将包含多个单元feature的字典数据再编码为集成的features。

第三步,使用tf.train.Example将features编码数据封装成特定的PB协议格式

example = tf.train.Example(features_extern)

第四步,将example数据系列化为字符串

example_str = example.SerializeToString()

第五步,将系列化为字符串的Example数据写入协议缓冲器

writer.write(example_str)

TFRecordWriter拥有类似python文件操作的接口,如writer.flush()立即将内存数据刷新到磁盘文件里,writer.close()关闭TFRecordWriter,在写完数据到协议缓冲区后通常需要调用writer.close()主动关闭TFRecords文件操作接口。


TFRecords文件的读取

TFRecords文件的读取主要是使用tf.TFRecordReader和tf.python_io.tf_record_iterator其源代码位于tensorflow/python/ops/io_ops.py和tensorflow/python/lib/io/tf_record.py

第一步,使用tf.train.string_input_producer生成文件队列

filename_queues = tf.train.string_input_producer([tfrecord_path_none,tfrecord_path_zlib,tfrecord_path_gzip])

第二步,生成TFRecord Reader reader = tf.TFRecordReader(name=None, options=None)

第三步,读取TFRecord文件 _,serialized_example = reader.read(filename)

filename是tf.train.string_input_producer得到的文件队列名,读取得到的是一个系列化的example。

第四步,使用tf.parse_single_example解析得到的系列化example

features = tf.parse_single_example( 
serialized_example,
features={
"float_val":tf.FixedLenFeature([], tf.float),
"width":tf.FixedLenFeature([], tf.int64),
"height":tf.FixedLenFeature([], tf.int64),
"image_raw":tf.FixedLenFeature([], tf.string)
}
)

需要按照存储时的格式还原features,必须写明features内的字典的键索引得到特定的数据!

第五步,处理得到的数据

features是一个字典,要使用特定数据,需要用字典的key来索引得到相应的数据,如要得到的width的值,则可以以features['width']得到,对于得到的数据还需要做一些处理的,比如features['image_raw']需要decode才能显示整个图片。


在tensorflow2.0中实现了以上操作 贴一下代码,还有一些问题没有解决 后续遇到再补充


tfrecord_image.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

filename = "pic.jpg"

image = tf.io.read_file(filename)

image_jpeg = tf.io.decode_image(image,channels=3,name="decode_jpeg_picture",expand_animations=True)

image_jpeg = tf.reshape(image_jpeg,shape=(4032,3024,3))

img_shape = image_jpeg.shape

width = img_shape[0]

height = img_shape[1]

sess = tf.Session()

image = sess.run(image)

sess.close()


path_none = "img_compress_none.tfrecord"
path_zlib = "img_compress_zlib.tfrecord"
path_gzip = "img_compress_gzip.tfrecord"


options_zlib = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
options_gzip = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)

# 定义不同压缩选项的TFRecordWriter
writer_none = tf.python_io.TFRecordWriter(path_none, options=None)
writer_zlib = tf.python_io.TFRecordWriter(path_zlib, options=options_zlib)
writer_gzip = tf.python_io.TFRecordWriter(path_gzip, options=options_gzip)
# 编码数据,将数据生成特定类型列表后再编码为feature格式,再生成字典形式
feature_internal_none = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[9.99])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
feature_internal_zlib = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[8.88])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
feature_internal_gzip = {
"float_val":tf.train.Feature(float_list=tf.train.FloatList(value=[6.66])),
"width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
"height":tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
"image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
}
# 编码内层字典形式数据
features_extern_none = tf.train.Features(feature=feature_internal_none)
features_extern_zlib = tf.train.Features(feature=feature_internal_zlib)
features_extern_gzip = tf.train.Features(feature=feature_internal_gzip)
# 将外层features生成特定格式的example
example_none = tf.train.Example(features=features_extern_none)
example_zlib = tf.train.Example(features=features_extern_zlib)
example_gzip = tf.train.Example(features=features_extern_gzip)
# example系列化字符串
example_str_none = example_none.SerializeToString()
example_str_zlib = example_zlib.SerializeToString()
example_str_gzip = example_gzip.SerializeToString()
# 将系列化字符串写入协议缓冲区
writer_none.write(example_str_none)
writer_zlib.write(example_str_zlib)
writer_gzip.write(example_str_gzip)
# 关闭TFRecords文件操作接口
writer_none.close()
writer_zlib.close()
writer_gzip.close()
print("finish to write data to tfrecord file!")

tfread_img.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python
# coding: utf-8

# In[5]:


import tensorflow as tf


# In[10]:


from IPython import display


# In[11]:


raw_image_dataset = tf.data.TFRecordDataset('img_compress_none.tfrecord')


# In[12]:


image_feature_description = {
    "float_val":tf.io.FixedLenFeature([], tf.float32),
    "width":tf.io.FixedLenFeature([], tf.int64),
    "height":tf.io.FixedLenFeature([], tf.int64),
    "image_raw":tf.io.FixedLenFeature([], tf.string)
}


# In[13]:


def _parse_image_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset


# In[14]:


for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))


tensorflow1.x与tensorflow2.x兼容的两种方式:

  1. import tensorflow.compat.v1 as tf tf.disable_v2_behavior()
  2. tf.compat.v1.Session()
updatedupdated2020-04-292020-04-29