How to use the `t5.data.utils.Task` class in t5

To help you get started, we’ve selected a few t5 examples based on popular ways `Task` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: GitHub, google-research/text-to-text-transfer-transformer, t5/data/utils.py (view on GitHub)
**task_kwargs)

  @property
  def splits(self):
    """Override since `info.splits` can't be read until after init.

    An explicitly configured `_splits` wins when it is truthy; otherwise the
    splits are taken from the wrapped TFDS dataset's `info.splits`.
    """
    if self._splits:
      return self._splits
    # Deferred lookup: only touch the dataset info once it is actually needed.
    return self._tfds_dataset.info.splits

  @property
  def tfds_dataset(self):
    """Read-only access to the dataset object stored in `_tfds_dataset`."""
    dataset = self._tfds_dataset
    return dataset

  def num_input_examples(self, split):
    """Return the example count for `split` via `self.tfds_dataset.size(split)`.

    Goes through the `tfds_dataset` attribute (not `_tfds_dataset`) so that any
    override of that accessor is respected.
    """
    dataset = self.tfds_dataset
    return dataset.size(split)


class TextLineTask(Task):
  """A `Task` that reads text lines as input.

  Requires a text_processor to be passed that takes a tf.data.Dataset of
  strings and returns a tf.data.Dataset of feature dictionaries.
  e.g. preprocessors.preprocess_tsv()
  """

  def __init__(
      self,
      name,
      split_to_filepattern,
      text_preprocessor,
      sentencepiece_model_path,
      metric_fns,
      skip_header_lines=0,
      **task_kwargs):
Source: GitHub, google-research/text-to-text-transfer-transformer, t5/data/utils.py (view on GitHub)
def get_subtasks(task_or_mixture):
  """Returns all the Tasks in a Mixture as a list or the Task itself.

  Args:
    task_or_mixture: either a `Task` instance or an object exposing a `.tasks`
      attribute (presumably a Mixture; verify against callers).

  Returns:
    A one-element list holding the Task when given a `Task`; otherwise the
    object's `.tasks` attribute, returned as-is.
  """
  if not isinstance(task_or_mixture, Task):
    return task_or_mixture.tasks
  return [task_or_mixture]
Source: GitHub, google-research/text-to-text-transfer-transformer, t5/data/utils.py (view on GitHub)
"%s-*-of-*%d" % (
            get_tfrecord_prefix(self.cache_dir, split),
            split_info["num_shards"]),
        shuffle=shuffle)
    ds = ds.interleave(
        tf.data.TFRecordDataset,
        cycle_length=16, block_length=16,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.map(lambda ex: tf.parse_single_example(ex, feature_desc),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if self.get_cached_stats(split)["examples"] <= _MAX_EXAMPLES_TO_MEM_CACHE:
      ds = ds.cache()
    return ds


class TfdsTask(Task):
  """A `Task` that uses TensorFlow Datasets to provide the input dataset."""

  def __init__(
      self,
      name,
      tfds_name,
      text_preprocessor,
      sentencepiece_model_path,
      metric_fns,
      tfds_data_dir=None,
      splits=None,
      **task_kwargs):
    """TfdsTask constructor.

    Args:
      name: string, a unique name for the Task. A ValueError will be raised if