2023. 2. 18. 04:43ㆍPython/- Tensorflow
TFRecord
A TFRecord file is a binary format for storing data: it serializes each record with Google's Protocol Buffers format and writes the records to a file.
Using TFRecord can reduce data-loading time.
Let's see how to use TFRecord with the following example.
How to use TFRecord
I use the PASCAL VOC data from TFDS, so you can follow the example without downloading anything manually.
Full code
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import os, shutil, sys
import xml.etree.ElementTree as ET
import tqdm
from config import *
from PIL import Image
from utils import anchor_utils, io_utils
class Dataset():
    """PASCAL VOC loader that can export to and import from TFRecord files.

    Images and XML annotations are read from the raw archives that
    ``tensorflow_datasets`` downloads and extracts under ``./voc``.
    After ``load()``, each entry of ``self.data`` is
    ``[image_file, labels, width, height]`` where ``labels`` is a list of
    ``[xmin, ymin, xmax, ymax, class_index]`` floats (empty for 'test').
    """

    def __init__(self, split, dtype=DTYPE, batch_size=BATCH_SIZE, anchors=ANCHORS, strides=STRIDES,
                 labels=LABELS, image_size=IMAGE_SIZE, num_classes=NUM_CLASSES):
        self.split = split
        # Within this class, dtype is only used as the directory that holds
        # the TFRecord file (see make_tfrecord / read_tfrecord).
        self.dtype = dtype
        self.anchors = np.array(anchors)
        self.num_anchors = len(anchors)
        self.batch_size = batch_size
        self.labels = labels
        self.strides = np.array(strides)
        self.image_size = image_size
        self.num_classes = num_classes
        self.data = []

    def load(self):
        """Download VOC if needed, then index the files of ``self.split``.

        Raises:
            ValueError: if ``self.split`` is not 'train', 'valid' or 'test'.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self.split not in ('train', 'valid', 'test'):
            raise ValueError("Check your dataset type and split.")
        self._download_dataset()
        self.read_files()

    def generator(self):
        """Yield ``(image, labels)`` pairs, reading each image lazily."""
        # Records hold four fields [image_file, labels, width, height];
        # unpacking only two of them raised ValueError on the first item.
        for image_file, labels, *_ in self.data:
            yield self.read_image(image_file), labels

    def read_image(self, image_file):
        """Load ``image_file`` as an RGB ``uint8`` array of shape (H, W, 3)."""
        # `with` closes the file handle right away. Converting to RGB keeps the
        # array rank 3, as tf.io.encode_jpeg requires, and matches
        # decode_jpeg(channels=3) in parse_tfrecord_fn.
        with Image.open(image_file) as img:
            return np.array(img.convert('RGB'))

    def read_files(self):
        """Fill ``self.data`` with ``[image_file, labels, width, height]`` and shuffle it."""
        directories = self.get_load_directory(self.split)
        print('Reading local_files... ', end='', flush=True)
        for directory in directories:
            devkit_dir = os.path.join(directory, 'VOCdevkit')
            # Use the first (presumably only) folder under VOCdevkit, e.g. VOC2012.
            directory = os.path.join(devkit_dir, os.listdir(devkit_dir)[0])
            image_dir = os.path.join(directory, 'JPEGImages')
            anno_dir = os.path.join(directory, 'Annotations')
            for image_name in sorted(os.listdir(image_dir)):
                image_file = os.path.join(image_dir, image_name)
                if self.split == 'test':
                    # No annotations are read for 'test'. Take the size from the
                    # image header so width/height are always defined.
                    labels = []
                    with Image.open(image_file) as img:
                        width, height = (float(v) for v in img.size)
                else:
                    # Match the annotation by file stem. os.listdir order is
                    # arbitrary, so indexing two listings in parallel could
                    # attach the wrong XML to an image.
                    stem = os.path.splitext(image_name)[0]
                    anno_file = os.path.join(anno_dir, stem + '.xml')
                    labels, width, height = self.parse_annotation(anno_file)
                self.data.append([image_file, labels, width, height])
        np.random.shuffle(self.data)
        print('Done!')

    def _download_dataset(self):
        """Fetch VOC 2012 and 2007 via TFDS into ./voc once, keeping only extracted files."""
        # NOTE(review): the original indentation was lost. The whole download
        # is assumed to run only when ./voc does not exist yet.
        if os.path.exists("./voc"):
            return
        os.mkdir("./voc")
        tfds.load('voc/2012', data_dir='./voc')
        # The 'valid' split reads the 2007 test archive (see get_load_directory).
        # TFDS provides 'voc/2007' and 'voc/2012', not 'voc/2008'.
        tfds.load('voc/2007', data_dir='./voc')
        # Drop the TFDS-prepared copy and the raw archives/.INFO files; only
        # downloads/extracted is read afterwards.
        if os.path.exists("./voc/voc"):
            shutil.rmtree("./voc/voc")
        for file_name in os.listdir("./voc/downloads/"):
            if file_name.endswith((".tar", ".INFO")):
                os.remove("./voc/downloads/" + file_name)

    def get_load_directory(self, split):
        """Return the extracted archive directories used for ``split``.

        'train' selects names containing "tra" (presumably the trainval
        archives), 'valid' the 2007 test archive, and 'test' the 2012 test
        archive. Any other split yields an empty list.
        """
        load_directory = []
        extracted_dir = "./voc/downloads/extracted/"
        for name in os.listdir(extracted_dir):
            if split == "train":
                if "tra" in name:
                    load_directory.append(extracted_dir + name)
            elif split == "valid":
                if "2007" in name and "test" in name:
                    load_directory.append(extracted_dir + name)
            elif split == "test":
                if "2012" in name and "test" in name:
                    load_directory.append(extracted_dir + name)
        return load_directory

    def parse_annotation(self, anno_path):
        """Parse one VOC XML annotation file.

        Returns:
            ``(labels, width, height)``. ``labels`` holds one
            ``[xmin, ymin, xmax, ymax, class_index]`` list of floats per
            ``<object>``; ``width``/``height`` come from ``<size>``.

        Raises:
            ValueError: if an object's class name is not in ``self.labels``.
        """
        root = ET.parse(anno_path).getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        labels = []
        for obj in root.iter('object'):
            # Direct children only: nested <part> elements have their own
            # name/bndbox, which must not override the object's.
            label = float(self.labels.index(obj.find('name').text))
            box = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (
                float(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            labels.append([xmin, ymin, xmax, ymax, label])
        return labels, width, height

    def make_tfrecord(self):
        """Write ``./{dtype}/{split}.tfrecord`` unless it already exists."""
        filepath = f'./{self.dtype}/{self.split}.tfrecord'
        if os.path.exists(filepath):
            print(f'{filepath} is exist')
            return
        self.load()
        # TFRecordWriter does not create missing parent directories.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        print(f'Start make {filepath}...... ', end='', flush=True)
        with tf.io.TFRecordWriter(filepath) as writer:
            for image_file, labels, width, height in tqdm.tqdm(self.data):
                image = self.read_image(image_file)
                writer.write(_data_features(image, labels, width, height))
        print('Done!')
        # `new_anchors` is not set anywhere in this class, so default to False
        # instead of raising AttributeError.
        if getattr(self, 'new_anchors', False):
            print('Anchors are changed. You need to restart file!')
            print("If you don't make new tfrecord, Anchors are't changed")
            sys.exit()

    def read_tfrecord(self):
        """Return a dataset of ``(image, labels, width, height)`` from this split's TFRecord."""
        filepath = f'./{self.dtype}/{self.split}.tfrecord'
        # -1 is tf.data.AUTOTUNE; parsing and JPEG decoding run in parallel
        # (output order is still preserved by default).
        dataset = tf.data.TFRecordDataset(filepath, num_parallel_reads=-1) \
            .map(self.parse_tfrecord_fn, num_parallel_calls=-1)
        return dataset

    def parse_tfrecord_fn(self, example):
        """Decode one serialized Example written by ``_data_features``."""
        feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'labels': tf.io.VarLenFeature(tf.float32),
            'width': tf.io.FixedLenFeature([], tf.float32),
            'height': tf.io.FixedLenFeature([], tf.float32)
        }
        example = tf.io.parse_single_example(example, feature_description)
        example['image'] = tf.io.decode_jpeg(example['image'], channels=3)
        # Labels were stored flattened; restore one row of 5 values per box.
        example['labels'] = tf.reshape(tf.sparse.to_dense(example['labels']), (-1, 5))
        return example['image'], example['labels'], example['width'], example['height']
def _image_feature(value):
    """JPEG-encode an image and wrap the bytes in a BytesList feature.

    Args:
        value: ``uint8`` image, either (height, width, channels) or a
            single-channel (height, width) array.

    Returns:
        tf.train.Feature holding the JPEG bytes.
    """
    # tf.io.encode_jpeg only accepts rank-3 input; give 2-D grayscale images
    # an explicit channel axis instead of failing.
    if np.ndim(value) == 2:
        value = np.expand_dims(value, axis=-1)
    return _bytes_feature(tf.io.encode_jpeg(value).numpy())
def _array_feature(value):
    """Flatten an array into a FloatList or Int64List feature.

    Args:
        value: NumPy array (or anything ``np.asarray`` accepts) of floating,
            integer or boolean dtype. It is flattened to 1-D.

    Returns:
        tf.train.Feature: FloatList for floating dtypes, Int64List for
        integer and boolean dtypes.

    Raises:
        TypeError: for any other dtype (TypeError subclasses Exception,
            the type raised previously).
    """
    value = np.asarray(value)
    flat = value.reshape(-1)
    if np.issubdtype(value.dtype, np.floating):
        return _float_feature(flat)
    if np.issubdtype(value.dtype, np.integer):
        return _int64_feature(flat)
    if value.dtype == np.bool_:
        # bool belongs in Int64List; cast so protobuf receives real integers.
        return _int64_feature(flat.astype(np.int64))
    raise TypeError(f"Wrong array dtype: {value.dtype}")
def _string_feature(value):
    """Store a ``str`` as a BytesList feature after UTF-8 encoding it."""
    encoded = value.encode('utf-8')
    return _bytes_feature(encoded)
def _bytes_feature(value):
    """Wrap one ``bytes`` value in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Build a FloatList ``tf.train.Feature``.

    Args:
        value: a single number (Python or NumPy scalar, including ints) or
            an iterable of numbers.

    Returns:
        tf.train.Feature whose float_list holds ``value``.
    """
    # `type(value) == float` missed ints and NumPy scalars such as
    # np.float32. FloatList then received a bare scalar and rejected it
    # because it expects an iterable.
    if np.isscalar(value):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
    """Build an Int64List ``tf.train.Feature``.

    Args:
        value: a single integer (Python or NumPy scalar, including bool) or
            an iterable of integers.

    Returns:
        tf.train.Feature whose int64_list holds ``value``.
    """
    # `type(value) == int` missed NumPy integer scalars and bools, which
    # then reached Int64List unwrapped and were rejected.
    if np.isscalar(value):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _data_features(image, labels, width, height):
    """Serialize one sample into a ``tf.train.Example`` byte string.

    The keys ('image', 'labels', 'width', 'height') must match the feature
    description used when parsing the TFRecord.
    """
    feature_map = {
        'image': _image_feature(image),
        'labels': _array_feature(np.array(labels)),
        'width': _float_feature(width),
        'height': _float_feature(height),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features).SerializeToString()
Make TFRecord
TFRecord data types
TFRecord stores values using these three list types:
1. tf.train.BytesList
- string
- byte
2. tf.train.FloatList
- float (float32)
- double (float64; stored with float32 precision)
3. tf.train.Int64List
- bool
- enum
- int32
- uint32
- int64
- uint64
I will use several of these types below.
BytesList
def _bytes_feature(value):
    """Put a single ``bytes`` value into a BytesList-backed feature."""
    wrapped = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=wrapped))
BytesList can hold strings, images, audio, and so on,
but these values must first be converted to bytes.
def _image_feature(value):
    """JPEG-encode an image and store the encoded bytes as a BytesList feature."""
    jpeg_tensor = tf.io.encode_jpeg(value)
    jpeg_bytes = jpeg_tensor.numpy()
    return _bytes_feature(jpeg_bytes)
def _string_feature(value):
    """UTF-8 encode a ``str`` and store it as a BytesList feature."""
    raw = value.encode('utf-8')
    return _bytes_feature(raw)
These functions handle the conversion.
FloatList
def _float_feature(value):
    """Store one float as a FloatList feature; the scalar is wrapped in a list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
Use FloatList for float data. FloatList expects a list of values, so a single float variable must be wrapped in a list, as done here.
Int64List
def _int64_feature(value):
    """Store one integer as an Int64List feature; the scalar is wrapped in a list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
Use Int64List for integer data. Likewise, a single int variable must be wrapped in a list.
tf.train.Example
Once each field has been converted to a tf.train.Feature, we combine the features of one sample into a tf.train.Example and serialize it to bytes.
def _data_features(image, labels, width, height):
    """Turn one sample into a serialized ``tf.train.Example``.

    The feature keys used here are the ones the parser must request later.
    """
    features = tf.train.Features(feature={
        'image': _image_feature(image),
        'labels': _array_feature(np.array(labels)),
        'width': _float_feature(width),
        'height': _float_feature(height),
    })
    example = tf.train.Example(features=features)
    return example.SerializeToString()
Make TFRecord
Now let's write the TFRecord file.
# Excerpt from Dataset.make_tfrecord: `self` and `filepath` come from that method.
with tf.io.TFRecordWriter(filepath) as writer:
    for record in tqdm.tqdm(self.data):
        image_file, labels, width, height = record
        writer.write(_data_features(self.read_image(image_file), labels, width, height))
You can adapt this code to any dataset.
Convert TFRecord to tf.data.Dataset
def read_tfrecord(path=None):
    """Open a TFRecord file as a ``tf.data.Dataset`` of parsed samples.

    Args:
        path: TFRecord file to read. If omitted, the file is
            ``f'{filename}.tfrecord'`` built from a module-level ``filename``,
            as in the original snippet.

    Returns:
        tf.data.Dataset yielding ``(image, labels, width, height)``.
    """
    filepath = path if path is not None else f'{filename}.tfrecord'
    # This is a plain function with no `self`, so use the module-level
    # parse_tfrecord_fn. -1 is tf.data.AUTOTUNE (parallel parse/decode).
    dataset = tf.data.TFRecordDataset(filepath, num_parallel_reads=-1) \
        .map(parse_tfrecord_fn, num_parallel_calls=-1)
    return dataset
def parse_tfrecord_fn(example):
    """Decode one serialized Example into ``(image, labels, width, height)``."""
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'labels': tf.io.VarLenFeature(tf.float32),
        'width': tf.io.FixedLenFeature([], tf.float32),
        'height': tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(example, schema)
    image = tf.io.decode_jpeg(parsed['image'], channels=3)
    # Labels were stored as a flat float list; regroup into rows of 5.
    dense_labels = tf.sparse.to_dense(parsed['labels'])
    labels = tf.reshape(dense_labels, (-1, 5))
    return image, labels, parsed['width'], parsed['height']
You can read the TFRecord file with tf.data.TFRecordDataset. Each record must be parsed (as parse_tfrecord_fn does) before the data can be used.
from_generator
There is another option: tf.data.Dataset.from_generator builds a dataset without a TFRecord file.
First load the raw data, then write a generator.
def generator(data, read_image):
    """Yield ``(image, labels)`` pairs for ``tf.data.Dataset.from_generator``.

    Args:
        data: iterable of ``[image_file, labels, width, height]`` records
            (the layout built by ``Dataset.read_files``).
        read_image: callable that loads an image file into an array.

    Yields:
        ``(read_image(image_file), labels)`` for every record.
    """
    # The original snippet referenced an undefined `self`. It also unpacked
    # only two of the four fields in each record, which raised ValueError.
    for image_file, labels, *_ in data:
        yield read_image(image_file), labels
Once you have a generator, you can pass it to from_generator:
# `data_gen` must be a zero-argument callable returning an iterator of
# (image, labels) pairs. `output_signature` replaces the deprecated positional
# `output_types`; shapes stay unknown, as they were before.
data = tf.data.Dataset.from_generator(
    data_gen,
    output_signature=(
        tf.TensorSpec(shape=None, dtype=tf.uint8),
        tf.TensorSpec(shape=None, dtype=tf.float32),
    ),
)
Result
Both approaches produce a tf.data.Dataset, but their performance differs.
I compared them on the VOC dataset: the TFRecord pipeline took about 10% less processing time than from_generator.
'Python > - Tensorflow' 카테고리의 다른 글
| TFDS - custom dataset(images, xmls) (0) | 2022.10.17 |
|---|