TFrecord, from_generator

2023. 2. 18. 04:43Python/- Tensorflow

TFrecord

A TFRecord file is a binary format for storing data: each record is serialized and written to the file using Google's Protocol Buffers format.

TFrecord can help decrease data load time.

 

Let's see how to use TFRecord through the following example.

 

How to use TFrecord

I use the VOC data from tfds, so you can follow the example without any extra downloads.

 

all code

더보기
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os, shutil, sys
import xml.etree.ElementTree as ET
import tqdm
from config import *
from PIL import Image
from utils import anchor_utils, io_utils

class Dataset():
    """Pascal VOC loader that can write samples to a TFRecord file and read them back.

    After ``load()``, ``self.data`` holds one ``[image_file, labels, width, height]``
    entry per image, where ``labels`` is a list of
    ``[xmin, ymin, xmax, ymax, class_index]`` rows (all floats).
    """

    def __init__(self, split, dtype=DTYPE, batch_size=BATCH_SIZE, anchors=ANCHORS, strides=STRIDES,
                 labels=LABELS, image_size=IMAGE_SIZE, num_classes=NUM_CLASSES):
        """Store the configuration; no file is touched until ``load()`` is called.

        Args:
            split: one of ``'train'``, ``'valid'`` or ``'test'`` (checked in ``load()``).
            dtype: used as the directory name of the TFRecord file (``./{dtype}/``).
            batch_size, strides, image_size, num_classes: stored as given
                (``strides`` is converted to a NumPy array).
            anchors: stored as a NumPy array; its length becomes ``num_anchors``.
            labels: sequence of class names; a name's position is its class index.
        """
        self.split = split
        self.dtype = dtype
        self.anchors = np.array(anchors)
        self.num_anchors = len(anchors)
        self.batch_size = batch_size
        self.labels = labels
        self.strides = np.array(strides)
        self.image_size = image_size
        self.num_classes = num_classes
        self.data = []

    def load(self):
        """Download the VOC archives if needed, then index the files of ``self.split``."""
        assert self.split in ['train', 'valid', 'test'], "Check your dataset type and split."
        self._download_dataset()
        self.read_files()

    def generator(self):
        """Yield ``(image, labels)`` for every indexed sample."""
        # Each entry is [image_file, labels, width, height]; the trailing
        # width/height are not emitted, so they are swallowed by ``*_``.
        for image_file, labels, *_ in self.data:
            yield self.read_image(image_file), labels

    def read_image(self, image_file):
        """Return the image stored at ``image_file`` as a NumPy array."""
        with Image.open(image_file) as image:
            return np.array(image)

    def read_files(self):
        """Append one ``[image_file, labels, width, height]`` entry per image, then shuffle.

        For the ``'test'`` split no annotation is parsed: ``labels`` is empty and
        the size is read from the image itself.  For the other splits an image
        without a matching annotation file is skipped.
        """
        directories = self.get_load_directory(self.split)
        print('Reading local_files...  ', end='', flush=True)
        for directory in directories:
            directory += '/VOCdevkit/'
            directory += os.listdir(directory)[0]
            image_dir = directory + '/JPEGImages/'
            anno_dir = directory + '/Annotations/'

            # os.listdir() returns entries in arbitrary order, so images are
            # sorted and each annotation is looked up by name instead of by index.
            for image_name in sorted(os.listdir(image_dir)):
                image_file = image_dir + image_name

                if self.split == 'test':
                    labels = []
                    with Image.open(image_file) as image:
                        width, height = (float(side) for side in image.size)
                else:
                    # assumes the annotation shares the image's base name (VOC layout)
                    anno_file = anno_dir + os.path.splitext(image_name)[0] + '.xml'
                    if not os.path.exists(anno_file):
                        continue
                    labels, width, height = self.parse_annotation(anno_file)
                self.data.append([image_file, labels, width, height])
        np.random.shuffle(self.data)
        print('Done!')

    def _download_dataset(self):
        """Fetch VOC 2012 and 2008 through tfds into ``./voc`` unless that directory exists.

        Only the extracted raw files are kept: the prepared tfds dataset
        (``./voc/voc``) and the downloaded ``.tar`` / ``.INFO`` files are removed.
        """
        if not os.path.exists("./voc"):
            os.mkdir("./voc")
            tfds.load('voc/2012', data_dir='./voc')
            tfds.load('voc/2008', data_dir='./voc')
            if os.path.exists("./voc/voc"):
                shutil.rmtree("./voc/voc")
            for file in os.listdir("./voc/downloads/"):
                if file.endswith((".tar", ".INFO")):
                    os.remove("./voc/downloads/" + file)

    def get_load_directory(self, split):
        """Return the extracted archive directories that belong to ``split``.

        Directories are selected by substrings of their names: ``"tra"`` for
        ``'train'``, ``"2007"`` and ``"test"`` for ``'valid'``, ``"2012"`` and
        ``"test"`` for ``'test'``.  Any other ``split`` yields an empty list.
        """
        extracted_dir = "./voc/downloads/extracted/"
        required_parts = {
            "train": ("tra",),
            "valid": ("2007", "test"),
            "test": ("2012", "test"),
        }.get(split)
        if required_parts is None:
            return []
        return [
            extracted_dir + name
            for name in os.listdir(extracted_dir)
            if all(part in name for part in required_parts)
        ]

    def parse_annotation(self, anno_path):
        """Parse one VOC XML annotation.

        Returns:
            ``(labels, width, height)`` where ``labels`` is a list of
            ``[xmin, ymin, xmax, ymax, class_index]`` float rows.

        Raises:
            ValueError: if the file has no width/height element, or if an
                object name is not in ``self.labels``.
        """
        tree = ET.parse(anno_path)
        labels = []
        width = height = None

        for elem in tree.iter():
            if "width" in elem.tag:
                width = float(elem.text)
            elif "height" in elem.tag:
                height = float(elem.text)
            elif "object" in elem.tag:
                # Collect the object's fields first and emit the row afterwards,
                # so the result does not depend on <name> preceding <bndbox> and
                # never reuses values left over from a previous object.
                label = None
                box = {}
                for attr in elem:
                    if "name" in attr.tag:
                        label = float(self.labels.index(attr.text))
                    elif "bndbox" in attr.tag:
                        for dim in attr:
                            for key in ("xmin", "ymin", "xmax", "ymax"):
                                if key in dim.tag:
                                    box[key] = float(dim.text)
                                    break
                if label is not None and len(box) == 4:
                    labels.append([box["xmin"], box["ymin"], box["xmax"], box["ymax"], label])
        if width is None or height is None:
            raise ValueError(f"Missing image size in annotation: {anno_path}")
        return labels, width, height

    def make_tfrecord(self):
        """Write every sample of the split to ``./{dtype}/{split}.tfrecord``.

        Does nothing except print a message when the file already exists.
        Exits the process after writing if a ``new_anchors`` attribute has been
        set to a truthy value on the instance.
        """
        filepath = f'./{self.dtype}/{self.split}.tfrecord'

        if os.path.exists(filepath):
            print(f'{filepath} is exist')
            return

        self.load()

        # Make sure the target directory exists before opening the writer.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        print(f'Start make {filepath}......      ', end='', flush=True)
        with tf.io.TFRecordWriter(filepath) as writer:
            for image_file, labels, width, height in tqdm.tqdm(self.data):
                image = self.read_image(image_file)
                writer.write(_data_features(image, labels, width, height))
        print('Done!')
        # ``new_anchors`` is never assigned inside this class, so read it defensively.
        if getattr(self, 'new_anchors', False):
            print('Anchors are changed. You need to restart file!')
            print("If you don't make new tfrecord, Anchors are't changed")
            sys.exit()

    def read_tfrecord(self):
        """Return a ``tf.data`` dataset of parsed ``(image, labels, width, height)`` records."""
        filepath = f'./{self.dtype}/{self.split}.tfrecord'
        # -1 is the value of tf.data.AUTOTUNE
        dataset = tf.data.TFRecordDataset(filepath, num_parallel_reads=-1) \
                        .map(self.parse_tfrecord_fn)
        return dataset

    def parse_tfrecord_fn(self, example):
        """Decode one serialized example into ``(image, labels, width, height)``.

        The image is JPEG-decoded with 3 channels and the flat label list is
        reshaped to ``(-1, 5)``.
        """
        feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'labels': tf.io.VarLenFeature(tf.float32),
            'width': tf.io.FixedLenFeature([], tf.float32),
            'height': tf.io.FixedLenFeature([], tf.float32)
        }
        example = tf.io.parse_single_example(example, feature_description)
        example['image'] = tf.io.decode_jpeg(example['image'], channels=3)
        example['labels'] = tf.reshape(tf.sparse.to_dense(example['labels']), (-1, 5))

        return example['image'], example['labels'], example['width'], example['height']


def _image_feature(value):
    """JPEG-encode an image and wrap the encoded bytes in a bytes feature."""
    jpeg_bytes = tf.io.encode_jpeg(value).numpy()
    return _bytes_feature(jpeg_bytes)
def _array_feature(value):
    """Flatten an array and store it as a float or int64 feature.

    The feature type is chosen from the array's dtype name; any dtype whose
    name contains neither 'float' nor 'int' raises an Exception.
    """
    dtype_name = value.dtype.name
    if 'float' in dtype_name:
        flat = np.reshape(value, (-1))
        return _float_feature(flat)
    if 'int' in dtype_name:
        flat = np.reshape(value, (-1))
        return _int64_feature(flat)
    raise Exception(f"Wrong array dtype: {value.dtype}")
def _string_feature(value):
    """UTF-8 encode a string and wrap the result in a bytes feature."""
    encoded = value.encode('utf-8')
    return _bytes_feature(encoded)
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a number, or a sequence of numbers, in a float ``tf.train.Feature``.

    A single Python or NumPy numeric scalar is put into a one-element list;
    anything else is passed to ``FloatList`` unchanged.
    """
    # An exact ``type(value) == float`` check misses ints and NumPy scalars,
    # which would then reach FloatList without being wrapped in a list.
    if isinstance(value, (int, float, np.number)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
    """Wrap an integer, or a sequence of integers, in an int64 ``tf.train.Feature``.

    A single Python or NumPy integer scalar is put into a one-element list;
    anything else is passed to ``Int64List`` unchanged.
    """
    # An exact ``type(value) == int`` check misses bools and NumPy integer
    # scalars, which would then reach Int64List without being wrapped in a list.
    if isinstance(value, (int, np.integer)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _data_features(image, labels, width, height):
    """Serialize one sample to the bytes of a ``tf.train.Example``.

    The example has four features: 'image' (JPEG bytes), 'labels' (the label
    array, flattened), 'width' and 'height' (floats).
    """
    feature_map = {
        'image': _image_feature(image),
        'labels': _array_feature(np.array(labels)),
        'width': _float_feature(width),
        'height': _float_feature(height),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features).SerializeToString()

 

Make TFrecord

TFrecord data type

TFRecord uses the following data types:

 

1. tf.train.BytesList

    - string

    - byte

2. tf.train.FloatList

    - float

    - double

3. tf.train.Int64List

    - bool

    - enum

    - int32

    - uint32

    - int64

    - uint64

 

I will use some of these data types.

 

BytesList

def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

We can use BytesList for strings, images, audio, etc.

But we need to convert these data types to binary data (bytes) first.

def _image_feature(value):
    """JPEG-encode an image and wrap the encoded bytes in a bytes feature."""
    jpeg_bytes = tf.io.encode_jpeg(value).numpy()
    return _bytes_feature(jpeg_bytes)
def _string_feature(value):
    """UTF-8 encode a string and wrap the result in a bytes feature."""
    encoded = value.encode('utf-8')
    return _bytes_feature(encoded)

These functions help you convert the data.

 

FloatList

def _float_feature(value):
    """Wrap a single float value in a ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

We can use FloatList for float data. To store a single float value, you must wrap it in a list.

 

Int64List

def _int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

We can use Int64List for integer data. Likewise, to store a single integer value, you must wrap it in a list.

 

tf.train.Example

So far we have converted the original data to binary data. Additionally, we need to combine the individual features into a single record object using tf.train.Example.

def _data_features(image, labels, width, height):
    """Serialize one sample to the bytes of a ``tf.train.Example``.

    The example has four features: 'image' (JPEG bytes), 'labels' (the label
    array, flattened), 'width' and 'height' (floats).
    """
    feature_map = {
        'image': _image_feature(image),
        'labels': _array_feature(np.array(labels)),
        'width': _float_feature(width),
        'height': _float_feature(height),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features).SerializeToString()

 

Make TFrecord

Now let's make TFrecord data.

# Serialize every sample and write it as one record of the TFRecord file
# at `filepath`.
with tf.io.TFRecordWriter(filepath) as writer:
    for image_file, labels, width, height in tqdm.tqdm(self.data):
        image = self.read_image(image_file)
        writer.write(_data_features(image, labels, width, height))

You can adapt this code to any kind of data.

 

Convert TFrecord to tf.data.Dataset

def read_tfrecord():
    """Return a ``tf.data`` dataset of parsed records from ``{filename}.tfrecord``.

    ``filename`` is not a parameter; it must be defined in the enclosing scope.
    """
    filepath = f'{filename}.tfrecord'
    # This is a standalone function, so the parser is referenced directly
    # rather than through ``self``.  -1 is the value of tf.data.AUTOTUNE.
    dataset = tf.data.TFRecordDataset(filepath, num_parallel_reads=-1) \
                    .map(parse_tfrecord_fn)
    return dataset

def parse_tfrecord_fn(example):
    """Decode one serialized example into ``(image, labels, width, height)``.

    The image is JPEG-decoded with 3 channels and the flat label list is
    reshaped to ``(-1, 5)``.
    """
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.VarLenFeature(tf.float32),
        'width': tf.io.FixedLenFeature([], tf.float32),
        'height': tf.io.FixedLenFeature([], tf.float32)
    }
    parsed = tf.io.parse_single_example(example, schema)
    image = tf.io.decode_jpeg(parsed['image'], channels=3)
    labels = tf.reshape(tf.sparse.to_dense(parsed['labels']), (-1, 5))
    return image, labels, parsed['width'], parsed['height']

You can read a TFRecord file using tf.data.TFRecordDataset. You must parse each record before using the data.

 

from_generator

You have another option for feeding data: from_generator helps you build a dataset without TFRecord.

First you need to load the raw data and write a generator.

def generator():
    """Yield ``(image, labels)`` for every loaded sample.

    NOTE: this is an excerpt of the ``Dataset.generator`` method shown in the
    full listing; ``self`` is assumed to be the loaded ``Dataset`` instance.
    """
    # Each entry of self.data is [image_file, labels, width, height]; the
    # trailing width/height are not emitted, so they are swallowed by ``*_``.
    for image_file, labels, *_ in self.data:
        image = read_image(image_file)
        yield image, labels

Once you have a generator, you can use from_generator:

# Pass the generator function itself (not a call to it); the second argument
# gives the dtypes of the yielded (image, labels) pair.
data = tf.data.Dataset.from_generator(generator, (tf.uint8, tf.float32))

 

Result

Both approaches use tf.data, but they do not have the same performance.

I compared the two approaches on the VOC dataset: TFRecord took 10% less processing time than from_generator.

'Python > - Tensorflow' 카테고리의 다른 글

TFDS - custom dataset(images, xmls)  (0) 2022.10.17