How to use the t5.data.utils.TaskRegistry.add function in t5

To help you get started, we’ve selected a few t5 examples, based on popular ways `TaskRegistry.add` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Maximized evaluation metrics over all answers.
# Context-free variant: the SQuAD preprocessor is bound with
# `include_context=False`.
TaskRegistry.add(
    "squad_v010_context_free",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=functools.partial(
        preprocessors.squad, include_context=False),
    postprocess_fn=postprocessors.qa,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Squad span prediction task instead of text.
# Uses the span-specific preprocessor, postprocessor and metric in place of
# the text QA ones used by the other SQuAD registrations.
TaskRegistry.add(
    "squad_v010_allanswers_span",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad_span_space_tokenized,
    postprocess_fn=postprocessors.span_qa,
    metric_fns=[metrics.span_qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Deprecated: Use `squad_v010_allanswers` instead.
# NOTE(review): registered without a `postprocess_fn`, unlike the other SQuAD
# tasks; whatever default TaskRegistry/TfdsTask applies is not visible here.
TaskRegistry.add(
    "squad_v010",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
splits=["validation", "test"])

# =================================== WNLI =====================================
# Evaluation-only task: restricted to the validation and test splits.
# Pairs the WNLI preprocessor with the WSC postprocessor
# (`postprocessors.wsc_simple`) and scores with accuracy.
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    tfds_name="glue/wnli:0.0.2",
    text_preprocessor=preprocessors.wnli_simple,
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])

# =================================== Squad ====================================
# Maximized evaluation metrics over all answers.
# No `splits` argument is passed, so the task is not restricted to specific
# splits at registration time.
TaskRegistry.add(
    "squad_v010_allanswers",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
    postprocess_fn=postprocessors.qa,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Maximized evaluation metrics over all answers.
TaskRegistry.add(
    "squad_v010_context_free",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=functools.partial(
        preprocessors.squad, include_context=False),
    postprocess_fn=postprocessors.qa,
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
    text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# =================================== WSC ======================================
# Two tasks are registered over the same TFDS dataset ("wsc.fixed").
# Training task: train split only, no metrics, and the preprocessor is bound
# with `correct_referent_only=True`.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_train",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=True),
    metric_fns=[],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["train"])
# Evaluation task: validation and test splits, scored with accuracy; the
# preprocessor is bound with `correct_referent_only=False` and a
# postprocessor is attached.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_eval",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=False),
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])

# =================================== WNLI =====================================
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    tfds_name="glue/wnli:0.0.2",
    text_preprocessor=preprocessors.wnli_simple,
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
_get_glue_text_preprocessor(b)
    ]
  else:
    text_preprocessor = _get_glue_text_preprocessor(b)
  TaskRegistry.add(
      "super_glue_%s_v102" % b.name,
      TfdsTask,
      tfds_name="super_glue/%s:1.0.2" % b.name,
      text_preprocessor=text_preprocessor,
      metric_fns=SUPERGLUE_METRICS[b.name],
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      postprocess_fn=_get_glue_postprocess_fn(b),
      splits=["test"] if b.name in ["axb", "axg"] else None)

# ======================== Definite Pronoun Resolution =========================
# Scored with accuracy. Neither `postprocess_fn` nor `splits` is passed, so
# the task relies on whatever defaults TaskRegistry/TfdsTask provide (not
# visible in this excerpt).
TaskRegistry.add(
    "dpr_v001_simple",
    TfdsTask,
    tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
    text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# =================================== WSC ======================================
TaskRegistry.add(
    "super_glue_wsc_v102_simple_train",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=True),
    metric_fns=[],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
continue
  if b.name == "axb":
    text_preprocessor = [
        functools.partial(
            preprocessors.rekey,
            key_map={
                "premise": "sentence1",
                "hypothesis": "sentence2",
                "label": "label",
                "idx": "idx",
            }),
        _get_glue_text_preprocessor(b)
    ]
  else:
    text_preprocessor = _get_glue_text_preprocessor(b)
  TaskRegistry.add(
      "super_glue_%s_v102" % b.name,
      TfdsTask,
      tfds_name="super_glue/%s:1.0.2" % b.name,
      text_preprocessor=text_preprocessor,
      metric_fns=SUPERGLUE_METRICS[b.name],
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      postprocess_fn=_get_glue_postprocess_fn(b),
      splits=["test"] if b.name in ["axb", "axg"] else None)

# ======================== Definite Pronoun Resolution =========================
TaskRegistry.add(
    "dpr_v001_simple",
    TfdsTask,
    tfds_name="definite_pronoun_resolution/plain_text:0.0.1",
    text_preprocessor=preprocessors.definite_pronoun_resolution_simple,
    metric_fns=[metrics.accuracy],
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
metric_fns=[],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["train"])
# WSC evaluation task: validation and test splits, scored with accuracy; the
# preprocessor is bound with `correct_referent_only=False` and a
# postprocessor is attached.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_eval",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=False),
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])

# =================================== WNLI =====================================
# Evaluation-only task: restricted to the validation and test splits.
# Pairs the WNLI preprocessor with the WSC postprocessor
# (`postprocessors.wsc_simple`) and scores with accuracy.
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    tfds_name="glue/wnli:0.0.2",
    text_preprocessor=preprocessors.wnli_simple,
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])

# =================================== Squad ====================================
# Maximized evaluation metrics over all answers.
TaskRegistry.add(
    "squad_v010_allanswers",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
# ==================================== C4 ======================================
# One unsupervised task is registered per C4 config suffix. The empty suffix
# yields "c4_v020_unsupervised"; for the others the "." in the suffix is
# replaced by "_" in the task name (e.g. "c4_noclean_v020_unsupervised"),
# while the TFDS name keeps the dotted suffix ("c4/en.noclean:1.0.0").
_c4_config_suffixes = ["", ".noclean", ".realnewslike", ".webtextlike"]
for config_suffix in _c4_config_suffixes:
  TaskRegistry.add(
      "c4{name}_v020_unsupervised".format(
          name=config_suffix.replace(".", "_")),
      TfdsTask,
      tfds_name="c4/en{config}:1.0.0".format(config=config_suffix),
      # Presumably maps the dataset's "text" field to "targets" and leaves
      # "inputs" empty -- confirm against `preprocessors.rekey`.
      text_preprocessor=functools.partial(
          preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
      token_preprocessor=preprocessors.unsupervised,
      sentencepiece_model_path=DEFAULT_SPM_PATH,
      # No metrics are registered for this task.
      metric_fns=[])

# ================================ Wikipedia ===================================
# Same preprocessing pipeline as the C4 tasks above (rekey + unsupervised
# token preprocessor), with no metrics.
TaskRegistry.add(
    "wikipedia_20190301.en_v003_unsupervised",
    TfdsTask,
    # 0.0.4 is identical to 0.0.3 except empty records removed.
    tfds_name="wikipedia/20190301.en:0.0.4",
    text_preprocessor=functools.partial(
        preprocessors.rekey, key_map={"inputs": None, "targets": "text"}),
    token_preprocessor=preprocessors.unsupervised,
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    metric_fns=[])


# =================================== GLUE =====================================
def _get_glue_text_preprocessor(builder_config):
  """Return the glue preprocessor.

  Args:
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
# Registers one translation task per entry of `b_configs` (defined outside
# this excerpt); each entry unpacks to (prefix, builder config, tfds version).
# `language_pair[1]` is used as the source language and `language_pair[0]` as
# the target, and the task name concatenates them in that same order.
for prefix, b, tfds_version in b_configs:
  TaskRegistry.add(
      "wmt%s_%s%s_v003" % (prefix, b.language_pair[1], b.language_pair[0]),
      TfdsTask,
      tfds_name="wmt%s_translate/%s:%s" % (prefix, b.name, tfds_version),
      text_preprocessor=functools.partial(
          preprocessors.translate,
          source_language=b.language_pair[1],
          target_language=b.language_pair[0],
          ),
      metric_fns=[metrics.bleu],
      sentencepiece_model_path=DEFAULT_SPM_PATH)

# Special case for t2t ende.
# `b` is rebound here (it was the loop variable above) to the wmt_t2t "de-en"
# builder config; the task name and TFDS name are hard-coded rather than
# formatted from the config.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add(
    "wmt_t2t_ende_v003",
    TfdsTask,
    tfds_name="wmt_t2t_translate/de-en:0.0.1",
    text_preprocessor=functools.partial(
        preprocessors.translate,
        source_language=b.language_pair[1],
        target_language=b.language_pair[0],
        ),
    metric_fns=[metrics.bleu],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# ================================= SuperGlue ==================================
SUPERGLUE_METRICS = collections.OrderedDict([
    ("boolq", [metrics.accuracy]),
    ("cb", [
        metrics.mean_multiclass_f1(num_classes=3),
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
postprocess_fn=postprocessors.qa,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Squad span prediction task instead of text.
# Uses the span-specific preprocessor, postprocessor and metric in place of
# the text QA ones used by the other SQuAD registrations.
TaskRegistry.add(
    "squad_v010_allanswers_span",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad_span_space_tokenized,
    postprocess_fn=postprocessors.span_qa,
    metric_fns=[metrics.span_qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# Deprecated: Use `squad_v010_allanswers` instead.
# NOTE(review): registered without a `postprocess_fn`, unlike the other SQuAD
# tasks; whatever default TaskRegistry/TfdsTask applies is not visible here.
TaskRegistry.add(
    "squad_v010",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)

# ================================= TriviaQA ===================================
# No metrics are registered. In addition to the text preprocessor, a token
# preprocessor (`trivia_qa_truncate_inputs`) is attached.
TaskRegistry.add(
    "trivia_qa_v010",
    TfdsTask,
    tfds_name="trivia_qa:0.1.0",
    text_preprocessor=preprocessors.trivia_qa,
    metric_fns=[],
    token_preprocessor=preprocessors.trivia_qa_truncate_inputs,
    sentencepiece_model_path=DEFAULT_SPM_PATH)
Source: google-research/text-to-text-transfer-transformer, t5/data/tasks.py (view on GitHub)
# Format: year, tfds builder config, tfds version
# The year string is spliced into both the task name ("wmt<year>_...") and the
# TFDS dataset name ("wmt<year>_translate/...") by the loop below.
b_configs = [
    ("14", tfds.translate.wmt14.Wmt14Translate.builder_configs["de-en"], "0.0.3"
    ),
    ("14", tfds.translate.wmt14.Wmt14Translate.builder_configs["fr-en"], "0.0.3"
    ),
    ("16", tfds.translate.wmt16.Wmt16Translate.builder_configs["ro-en"], "0.0.3"
    ),
    ("15", tfds.translate.wmt15.Wmt15Translate.builder_configs["fr-en"], "0.0.4"
    ),
    ("19", tfds.translate.wmt19.Wmt19Translate.builder_configs["de-en"], "0.0.3"
    ),
]

# Registers one translation task per entry above. `language_pair[1]` is used
# as the source language and `language_pair[0]` as the target, and the task
# name concatenates them in that same order (presumably "wmt14_ende_v003" for
# the first entry -- confirm `language_pair` is ("de", "en") for "de-en").
for prefix, b, tfds_version in b_configs:
  TaskRegistry.add(
      "wmt%s_%s%s_v003" % (prefix, b.language_pair[1], b.language_pair[0]),
      TfdsTask,
      tfds_name="wmt%s_translate/%s:%s" % (prefix, b.name, tfds_version),
      text_preprocessor=functools.partial(
          preprocessors.translate,
          source_language=b.language_pair[1],
          target_language=b.language_pair[0],
          ),
      metric_fns=[metrics.bleu],
      sentencepiece_model_path=DEFAULT_SPM_PATH)

# Special case for t2t ende.
b = tfds.translate.wmt_t2t.WmtT2tTranslate.builder_configs["de-en"]
TaskRegistry.add(
    "wmt_t2t_ende_v003",
    TfdsTask,