• 为了保证你在浏览本网站时有着更好的体验,建议使用类似Chrome、Firefox之类的浏览器~~
    • 如果你喜欢本站的内容何不Ctrl+D收藏一下呢,与大家一起分享各种编程知识~
    • 本网站研究机器学习、计算机视觉、模式识别~当然不局限于此,生命在于折腾,何不年轻时多折腾一下

Keras学习-0x06-Tfrecord相关

Keras admin 1个月前 (08-20) 130次浏览 0个评论 扫描二维码

前面的描述中讲完了序列模型和函数式模型的理论,对于 keras 而言后面所有模型代码的实现都是基于这两种方式来实现,所以这也是有讨论,接下来就是要从每一点去学习,自己计划的方式是从 input 到 output 依次来学习,所以最先开始讲的就是输入数据相关,总的顺序是这样,每一步都可以发散很多点来了解。先从 tfrecord 学习开始。

tfrecord 也是官方推荐的一种数据存储方式,也是基于 pb 协议存储的方式。为什么推荐使用 tfrecord?现在大部分场景下数据量都是很大的,所以你打算使用基于内存型的方式那么肯定不行的,一下子将数据塞入内存效率比较低,那你肯定想说我一点一点的写入到内存基于 batch 方式来训练模型,但是也要兼顾效率的问题。

使用 tfrecord 的方式可以借助多线程 io 并行读取数据用来训练,tf.data 对此也有很好的处理,使用 pipline 加速数据的加载处理。

TFRecord

整体上建立 TFRecord 文件的流程主要如下;

  • 在 TFRecord 数据文件中,任何数据都是以 bytes 列表或 float 列表或 int64 列表的形式存储(注意:是列表形式),因此,将每条数据转化为列表格式。
  • 创建的每条数据列表都必须由一个 Feature 类包装,并且,每个 feature 都存储在一个 key-value 键值对中,其中 key 对应每个 feature 的名称。这些 key 将在后面从 TFRecord 提取数据时使用。
  • 当所需的字典创建完之后,会传递给 Features 类。
  • 最后,将 features 对象作为输入传递给 example 类,然后这个 example 类对象会被追加到 TFRecord 中。
  • 对于所有数据,重复上述过程。

TFRecord 建立

下面给出一个实例来说明 tfrecord

import  tensorflow as tf
data_arr = [
    {
        'clothes_category': 10, # 整型
        'clothes_prices':100.6, #浮点型
        'clothes_name':'jack jones'.encode(), # 字符串型,python3 下转化为 byte
        'clothes_topic':[110,120,78]  # 列表型
    },
    {clothes_category': 11, # 整型 
'clothes_prices':101.6, #浮点型 
'clothes_name':'cat'.encode(), # 字符串型,python3 下转化为 byte 
'clothes_topic':[89,130,87,522] # 列表型
    }
]

上面两条数据列了四个字段,第一个是衣服的分类数据,就是常见的 category 特征,第二个是衣服的价格,是个浮点型数据,

第三个是字符串类型,一般情况下可以经过哈希处理,第四个是衣服的 topic 主题向量,这是自己瞎编的。注意一点第四个特征是变长的。

完整的代码程序如下所示

# -*- coding: utf-8 -*-
# @Time    : 2019-08-20 23:13
# @Author  : zhusimaji
# @File    : gen_tfrecord.py
# @Software: PyCharm


import  tensorflow as tf
data_arr = [
    {
        'clothes_category': 10, # 整型
        'clothes_prices':100.6, #浮点型
        'clothes_name':'jack jones'.encode(), # 字符串型,python3 下转化为 byte
        'clothes_topic':[110,120,78]  # 列表型
    },
    {'clothes_category': 11, # 整型
'clothes_prices':101.6, #浮点型
'clothes_name':'cat'.encode(), # 字符串型,python3 下转化为 byte
'clothes_topic':[89,130,87,522] # 列表型
    }
]

def get_example_object(data_record):
    # 将数据转化为 int64 float 或 bytes 类型的列表
    # 注意都是 list 形式
    int_list1 = tf.train.Int64List(value = [data_record['clothes_category']])
    float_list1 = tf.train.FloatList(value = [data_record['clothes_prices']])
    str_list1 = tf.train.BytesList(value = [data_record['clothes_name']])
    float_list2 = tf.train.FloatList(value = data_record['clothes_topic'])
    feature_key_value_pair = {
        'clothes_category':tf.train.Feature(int64_list = int_list1),
        'clothes_prices': tf.train.Feature(float_list=float_list1),
        'clothes_name': tf.train.Feature(bytes_list=str_list1),
        'clothes_topic': tf.train.Feature(float_list=float_list2),
    }
    # 创建一个 features
    features = tf.train.Features(feature=feature_key_value_pair)
    # 创建一个 example
    example = tf.train.Example(features=features)
    return example
with tf.python_io.TFRecordWriter('./resources/clothes.tfrecord') as tfwriter:
    #遍历所有数据
    for data_record in data_arr:
        example = get_example_object(data_record)
        # 写入 tfrecord 数据文件
        tfwriter.write(example.SerializeToString())

此时你发现 tfrecord 数据已经生成了。上面的模式其实跟 python 正常读写文件的方式是一样的,只是写数据的内容不一样。

上面是借助 python io 方法来实现 tfrecord 建立,你也可以使用 tf.data 来创建 tfrecord,tf.data 还是很强大的,后面关于的数据流这块应该都会使用到它。

下一节再去学习读 tfrecord,可以的话在把 tf.data 生成 tfrecord 方法也写出来。


Deeplearn, 版权所有丨如未注明 , 均为原创丨本网站采用BY-NC-SA协议进行授权 , 转载请注明Keras 学习-0x06-Tfrecord 相关
喜欢 (0)
admin
关于作者:
互联网行业码农一枚/业余铲屎官/数码影音爱好者/二次元

您必须 登录 才能发表评论!