Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _get_glue_postprocess_fn(builder_config):
    """Select the prediction postprocessor for a GLUE/SuperGLUE builder config.

    Args:
      builder_config: object exposing ``name``. For configs other than
        "stsb", "multirc" and "record" it must also expose ``label_classes``
        (presumably a TFDS builder config -- confirm against callers).

    Returns:
      ``postprocessors.string_to_float`` for "stsb",
      ``postprocessors.multirc`` for "multirc",
      ``postprocessors.qa`` for "record"; for any other config, a
      ``functools.partial`` of ``postprocessors.string_label_to_class_id``
      with ``label_classes`` bound to ``builder_config.label_classes``.
    """
    config_name = builder_config.name

    # Three configs have a dedicated postprocessor; return it directly.
    if config_name == "stsb":
        return postprocessors.string_to_float
    if config_name == "multirc":
        return postprocessors.multirc
    if config_name == "record":
        return postprocessors.qa

    # Fallback: bind this config's label vocabulary into the generic
    # label-string -> class-id postprocessor.
    return functools.partial(
        postprocessors.string_label_to_class_id,
        label_classes=builder_config.label_classes,
    )
# Training-only registration of the SuperGLUE WSC "simple" task.
# - Exposes only the "train" split of super_glue/wsc.fixed:1.0.2.
# - Configures no metrics and no postprocess_fn.
# - Binds the preprocessor with correct_referent_only=True (presumably keeps
#   only examples whose candidate is the true referent -- confirm in
#   preprocessors.wsc_simple).
TaskRegistry.add(
    "super_glue_wsc_v102_simple_train",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=True),
    metric_fns=[],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["train"])
# Evaluation registration of the SuperGLUE WSC "simple" task.
# - Exposes the "validation" and "test" splits of super_glue/wsc.fixed:1.0.2.
# - Unlike the training variant, the preprocessor is bound with
#   correct_referent_only=False.
# - Predictions pass through postprocessors.wsc_simple and are scored with
#   metrics.accuracy.
TaskRegistry.add(
    "super_glue_wsc_v102_simple_eval",
    TfdsTask,
    tfds_name="super_glue/wsc.fixed:1.0.2",
    text_preprocessor=functools.partial(
        preprocessors.wsc_simple, correct_referent_only=False),
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])
# =================================== WNLI =====================================
# Evaluation-only registration of GLUE WNLI (glue/wnli:0.0.2).
# - Uses preprocessors.wnli_simple for inputs.
# - Reuses the WSC postprocessor (postprocessors.wsc_simple) and accuracy
#   metric.
# - Only "validation" and "test" splits are exposed; no train split is
#   registered here.
TaskRegistry.add(
    "glue_wnli_v002_simple_eval",
    TfdsTask,
    tfds_name="glue/wnli:0.0.2",
    text_preprocessor=preprocessors.wnli_simple,
    postprocess_fn=postprocessors.wsc_simple,
    metric_fns=[metrics.accuracy],
    sentencepiece_model_path=DEFAULT_SPM_PATH,
    splits=["validation", "test"])
# =================================== Squad ====================================
# Context-free SQuAD task.
# - preprocessors.squad is bound with include_context=False.
# - Predictions go through postprocessors.qa and are scored with metrics.qa.
TaskRegistry.add(
    "squad_v010_context_free",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=functools.partial(
        preprocessors.squad, include_context=False),
    postprocess_fn=postprocessors.qa,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)
# SQuAD task that predicts answer spans instead of answer text.
# SQuAD (squad/plain_text:0.1.0) registered with span-based components:
# - inputs come from preprocessors.squad_span_space_tokenized;
# - predictions go through postprocessors.span_qa;
# - scoring uses metrics.span_qa.
TaskRegistry.add(
    "squad_v010_allanswers_span",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad_span_space_tokenized,
    postprocess_fn=postprocessors.span_qa,
    metric_fns=[metrics.span_qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)
# Deprecated: Use `squad_v010_allanswers` instead.
# Legacy SQuAD registration.
# - Uses the default preprocessors.squad and scores with metrics.qa.
# - Unlike the other SQuAD registrations in this section, it sets no
#   postprocess_fn. NOTE(review): presumably the task's default
#   postprocessing applies; confirm before relying on it.
# NOTE(review): the replacement `squad_v010_allanswers` is not registered in
# this section of the file; confirm it exists elsewhere.
TaskRegistry.add(
    "squad_v010",
    TfdsTask,
    tfds_name="squad/plain_text:0.1.0",
    text_preprocessor=preprocessors.squad,
    metric_fns=[metrics.qa],
    sentencepiece_model_path=DEFAULT_SPM_PATH)
# ================================= TriviaQA ===================================
TaskRegistry.add(
"trivia_qa_v010",
TfdsTask,