How to use the t5.data.utils.TaskRegistry class in t5

To help you get started, we’ve selected a few t5 examples based on popular ways TaskRegistry is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github google-research / text-to-text-transfer-transformer / t5 / data / tasks.py View on Github external
("cola", [metrics.matthews_corrcoef]),
    ("sst2", [metrics.accuracy]),
    ("mrpc", [metrics.f1_score_with_invalid, metrics.accuracy]),
    ("stsb", [metrics.pearson_corrcoef, metrics.spearman_corrcoef]),
    ("qqp", [metrics.f1_score_with_invalid, metrics.accuracy]),
    ("mnli", [metrics.accuracy]),
    ("mnli_matched", [metrics.accuracy]),
    ("mnli_mismatched", [metrics.accuracy]),
    ("qnli", [metrics.accuracy]),
    ("rte", [metrics.accuracy]),
    ("wnli", [metrics.accuracy]),
    ("ax", []),  # Only test set available.
])

# Register one TfdsTask per GLUE builder config.
for b in tfds.text.glue.Glue.builder_configs.values():
  # The "ax" config is pinned to a different dataset version than the rest and
  # is registered with its "test" split only.
  if b.name == "ax":
    _glue_version, _glue_splits = "1.0.0", ["test"]
  else:
    _glue_version, _glue_splits = "0.0.2", None
  TaskRegistry.add(
      "glue_%s_v002" % b.name,
      TfdsTask,
      tfds_name="glue/%s:%s" % (b.name, _glue_version),
      text_preprocessor=_get_glue_text_preprocessor(b),
      metric_fns=GLUE_METRICS[b.name],
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      postprocess_fn=_get_glue_postprocess_fn(b),
      splits=_glue_splits)

# =============================== CNN DailyMail ================================
TaskRegistry.add(
    "cnn_dailymail_v002",
    TfdsTask,
    tfds_name="cnn_dailymail/plain_text:0.0.2",
    text_preprocessor=functools.partial(preprocessors.summarize,
github google-research / text-to-text-transfer-transformer / t5 / data / utils.py View on Github external
def get_mixture_or_task(task_or_mixture_name):
  """Looks up a name in the Mixture registry first, then the Task registry.

  Args:
    task_or_mixture_name: name to search for in both registries.

  Returns:
    The result of `MixtureRegistry.get` when the name is a registered Mixture
    (a warning is logged if it is also a registered Task); otherwise the
    result of `TaskRegistry.get`.

  Raises:
    ValueError: if the name appears in neither registry.
  """
  is_mixture = task_or_mixture_name in MixtureRegistry.names()
  is_task = task_or_mixture_name in TaskRegistry.names()

  if not (is_mixture or is_task):
    raise ValueError("No Task or Mixture found with name: %s" %
                     task_or_mixture_name)

  if not is_mixture:
    return TaskRegistry.get(task_or_mixture_name)

  # Mixtures take precedence when the name is registered in both places.
  if is_task:
    logging.warning("%s is both a Task and a Mixture, returning Mixture",
                    task_or_mixture_name)
  return MixtureRegistry.get(task_or_mixture_name)
github google-research / text-to-text-transfer-transformer / t5 / data / utils.py View on Github external
def add(cls, name, task_cls=Task, **kwargs):
    """Delegates to the parent registry's `add`, supplying `name` twice.

    NOTE(review): this looks like a classmethod of `TaskRegistry` whose
    decorator and class header are outside this excerpt -- confirm.

    Args:
      cls: the registry class the call is dispatched on.
      name: key passed as the first argument to the parent `add`; it is also
        forwarded again as an extra positional argument (presumably handed to
        `task_cls` as the task's own name -- verify against the parent `add`).
      task_cls: class passed to the parent `add`; defaults to `Task`.
      **kwargs: forwarded unchanged to the parent `add`.
    """
    super(TaskRegistry, cls).add(name, task_cls, name, **kwargs)
github google-research / text-to-text-transfer-transformer / t5 / data / tasks.py View on Github external
from t5.data import postprocessors
from t5.data import preprocessors
from t5.data.utils import DEFAULT_SPM_PATH
from t5.data.utils import set_global_cache_dirs
from t5.data.utils import TaskRegistry
from t5.data.utils import TfdsTask
from t5.evaluation import metrics
import tensorflow_datasets as tfds




# ==================================== C4 ======================================
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
# Register one unsupervised TfdsTask per C4 config suffix.
for config_suffix in _c4_config_suffixes:
  # Task names use "_" where the TFDS config name uses ".".
  _c4_task_name = "c4{name}_v020_unsupervised".format(
      name=config_suffix.replace(".", "_"))
  _c4_tfds_name = "c4/en{config}:1.0.0".format(config=config_suffix)
  TaskRegistry.add(
      _c4_task_name,
      TfdsTask,
      tfds_name=_c4_tfds_name,
      text_preprocessor=functools.partial(
          preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
      token_preprocessor=preprocessors.unsupervised,
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      metric_fns=[])

# ================================ Wikipedia ===================================
TaskRegistry.add(
    "wikipedia_20190301.en_v003_unsupervised",
    TfdsTask,
    # 0.0.4 is identical to 0.0.3 except empty records removed.
    tfds_name="wikipedia/20190301.en:0.0.4",