How to use tfrecord - 10 common examples

To help you get started, we’ve selected a few tfrecord examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: jinnovation/rainy-image-dataset, tfrecord.py (view on GitHub)
import tensorflow as tf


class ExitOnExceptionHandler(logging.StreamHandler):
    """Logging handler that terminates the process on selected log levels.

    The handler writes nothing itself. When a record whose ``levelno`` is
    contained in ``critical_levels`` reaches it, it raises
    ``SystemExit(-1)``. Records at any other level are ignored.

    Args:
        critical_levels: Container of numeric logging levels
            (e.g. ``[logging.CRITICAL]``) that should stop the program.
        *args, **kwargs: Forwarded unchanged to ``logging.StreamHandler``.
    """

    def __init__(self, critical_levels, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Levels that trigger an exit; consulted with ``in`` on every record.
        self.lvls = critical_levels

    def emit(self, record):
        """Raise ``SystemExit(-1)`` if ``record`` is at one of the exit levels."""
        must_exit = record.levelno in self.lvls
        if not must_exit:
            return
        raise SystemExit(-1)


logger = logging.getLogger(__name__)
click_log.basic_config(logger)
# Use addHandler rather than mutating ``logger.handlers`` directly: it takes
# the logging module lock and skips duplicate registrations. Handlers run in
# registration order, so this one runs after whatever click_log.basic_config
# installed (presumably the handler that prints the message) and exits only
# once that handler has had the record.
logger.addHandler(ExitOnExceptionHandler([logging.CRITICAL]))

# TF1.x requires eager mode to be switched on explicitly. TF2.x is eager by
# default and no longer exposes ``tf.enable_eager_execution`` at the top
# level, so calling it unconditionally would fail at import time there.
if hasattr(tf, "enable_eager_execution"):
    tf.enable_eager_execution()


def indices_all(ground_truth_dir):
    """Return the base names (without extension) of every ``*.jpg`` in a directory.

    Args:
        ground_truth_dir: Directory to scan (not recursive).

    Returns:
        A sorted list of file stems, e.g. ``"0001"`` for ``<dir>/0001.jpg``.
        The list is empty if the directory has no ``.jpg`` files or does not
        exist.
    """
    # Escape the directory so glob metacharacters in its name ("[", "*", "?")
    # are matched literally; only the "*.jpg" part is meant as a pattern.
    pattern = os.path.join(glob.escape(ground_truth_dir), "*.jpg")
    # glob yields files in arbitrary, filesystem-dependent order; sort so the
    # result is reproducible across machines and runs.
    return sorted(
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob(pattern)
    )


def serialize_example(f_in, f_out, is_strict=True):
    def _bytes(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int(value):
github vahidk / tfrecord / tfrecord / writer.py View on Github external
            "int": lambda f: example_pb2.Feature(
                int64_list=example_pb2.Int64List(value=f))
        }
github vahidk / tfrecord / tfrecord / writer.py View on Github external
            "byte": lambda f: example_pb2.Feature(
                bytes_list=example_pb2.BytesList(value=f)),
            "float": lambda f: example_pb2.Feature(
github vahidk / tfrecord / tfrecord / writer.py View on Github external
            "float": lambda f: example_pb2.Feature(
                float_list=example_pb2.FloatList(value=f)),
            "int": lambda f: example_pb2.Feature(
github vahidk / tfrecord / tfrecord / writer.py View on Github external
feature_map = {
            "byte": lambda f: example_pb2.Feature(
                bytes_list=example_pb2.BytesList(value=f)),
            "float": lambda f: example_pb2.Feature(
                float_list=example_pb2.FloatList(value=f)),
            "int": lambda f: example_pb2.Feature(
                int64_list=example_pb2.Int64List(value=f))
        }

        def serialize(value, dtype):
            """Build the feature for *value* using the ``feature_map`` entry for *dtype*.

            Lists, tuples and NumPy arrays are passed through unchanged; any
            other value is wrapped in a one-element list first.
            """
            is_sequence = isinstance(value, (list, tuple, np.ndarray))
            items = value if is_sequence else [value]
            return feature_map[dtype](items)

        features = {key: serialize(value, dtype) for key, (value, dtype) in datum.items()}
        example_proto = example_pb2.Example(features=example_pb2.Features(feature=features))
        return example_proto.SerializeToString()
github vahidk / tfrecord / tfrecord / reader.py View on Github external
-------
    features: dict of {str, np.ndarray}
        Decoded bytes of the features into its respective data type (for
        an individual record).
    """

    typename_mapping = {
        "byte": "bytes_list",
        "float": "float_list",
        "int": "int64_list"
    }

    record_iterator = tfrecord_iterator(data_path, index_path, shard)

    for record in record_iterator:
        example = example_pb2.Example()
        example.ParseFromString(record)

        all_keys = list(example.features.feature.keys())
        if description is None:
            description = dict.fromkeys(all_keys, None)
        elif isinstance(description, list):
            description = dict.fromkeys(description, None)

        features = {}
        for key, typename in description.items():
            if key not in all_keys:
                raise KeyError(f"Key {key} doesn't exist (select from {all_keys})!")
            # NOTE: We assume that each key in the example has only one field
            # (either "bytes_list", "float_list", or "int64_list")!
            field = example.features.feature[key].ListFields()[0]
            inferred_typename, value = field[0].name, field[1].value
github vahidk / tfrecord / tfrecord / writer.py View on Github external
"float": lambda f: example_pb2.Feature(
                float_list=example_pb2.FloatList(value=f)),
            "int": lambda f: example_pb2.Feature(
github vahidk / tfrecord / tfrecord / writer.py View on Github external
"byte": lambda f: example_pb2.Feature(
                bytes_list=example_pb2.BytesList(value=f)),
            "float": lambda f: example_pb2.Feature(
github vahidk / tfrecord / tfrecord / writer.py View on Github external
"int": lambda f: example_pb2.Feature(
                int64_list=example_pb2.Int64List(value=f))
        }
github vahidk / tfrecord / tfrecord / torch / dataset.py View on Github external
def __iter__(self):
    """Return an iterator over this dataset's records.

    When ``torch.utils.data.get_worker_info()`` reports a worker, NumPy's
    global RNG is reseeded from that worker's seed (presumably so each
    DataLoader worker gets its own random stream). Records come from
    ``reader.multi_tfrecord_loader``. If ``shuffle_queue_size`` is truthy,
    they are passed through ``iterator_utils.shuffle_iterator``; if
    ``transform`` is truthy, they are mapped through it.
    """
    info = torch.utils.data.get_worker_info()
    if info is not None:
        # NOTE(review): the modulus is np.iinfo(np.uint32).max (2**32 - 1),
        # not 2**32, so the seed 2**32 - 1 is never produced. Kept as-is
        # because changing it would change every existing worker seed.
        np.random.seed(info.seed % np.iinfo(np.uint32).max)

    records = reader.multi_tfrecord_loader(
        self.data_pattern, self.index_pattern, self.splits, self.description)
    if self.shuffle_queue_size:
        records = iterator_utils.shuffle_iterator(records, self.shuffle_queue_size)
    return map(self.transform, records) if self.transform else records