How to use the t5.data.get_mixture_or_task function in t5

To help you get started, we’ve selected a few t5 examples, based on popular ways it is used in public projects.

GitHub: google-research / text-to-text-transfer-transformer / t5 / models / mesh_transformer.py
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int maximum sequence
      length for that feature.
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load. In most cases
      this should be "train".
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A tf.data.Dataset of preprocessed, tokenized, and batched examples.
  """
  if not isinstance(vocabulary, t5.data.SentencePieceVocabulary):
    raise ValueError("vocabulary must be a SentencePieceVocabulary")

  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  ds = mixture_or_task.get_dataset(
      sequence_length, split=dataset_split, use_cached=use_cached, shuffle=True)
  ds = transformer_dataset.pack_or_pad(
      ds, sequence_length, pack=True,
      feature_keys=tuple(mixture_or_task.output_features), ensure_eos=True)
  return ds
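
The snippet above resolves a string name to a Mixture or Task object before asking it for a dataset. As a rough illustration of that kind of name-based lookup, here is a minimal, self-contained sketch that does not import t5. `ToyRegistry`, `TASKS`, `MIXTURES`, and `toy_get_mixture_or_task` are hypothetical names invented for this example, and the lookup order (mixtures first, then tasks) is an assumption of the sketch, not a statement about the library's implementation.

```python
class ToyRegistry:
    """Hypothetical name -> object registry (illustration only, not t5 code)."""

    def __init__(self):
        self._items = {}

    def add(self, name, item):
        # Refuse duplicate names so each identifier maps to exactly one object.
        if name in self._items:
            raise ValueError(f"Name already registered: {name}")
        self._items[name] = item

    def names(self):
        return self._items.keys()

    def get(self, name):
        return self._items[name]


# Two separate registries, one per kind of object (assumed layout).
TASKS = ToyRegistry()
MIXTURES = ToyRegistry()


def toy_get_mixture_or_task(name):
    """Return the mixture registered under `name`, else the task, else raise."""
    if name in MIXTURES.names():
        return MIXTURES.get(name)
    if name in TASKS.names():
        return TASKS.get(name)
    raise ValueError(f"No Task or Mixture found with name: {name}")


# Register one toy task and one toy mixture that refers to it by name.
TASKS.add("toy_translate", {"kind": "task"})
MIXTURES.add("toy_mix", {"kind": "mixture", "tasks": ["toy_translate"]})
```

Calling `toy_get_mixture_or_task("toy_mix")` returns the mixture entry, `toy_get_mixture_or_task("toy_translate")` returns the task entry, and an unregistered name raises `ValueError`, so a caller can fail early on a misspelled identifier.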