How to use the `t5.data.utils.TaskRegistry.get` function in t5

To help you get started, we’ve selected a few t5 examples based on popular ways `TaskRegistry.get` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: google-research / text-to-text-transfer-transformer / t5 / data / utils.py (View on GitHub)
pair whose first element is the task name and whose second element
        is either a float (rate) or a function from Task to float.
      default_rate: a float or a function from Task to float. This specifies the
        default rate if rates are not provided in the `tasks` argument.
    """
    self._task_to_rate = {}
    self._tasks = []
    for t in tasks:
      if isinstance(t, str):
        task_name = t
        rate = default_rate
        if default_rate is None:
          raise ValueError("need a rate for each task")
      else:
        task_name, rate = t
      self._tasks.append(TaskRegistry.get(task_name))
      self._task_to_rate[task_name] = rate
    if len(set(tuple(t.output_features) for t in self._tasks)) != 1:
      raise ValueError(
          "All Tasks in a Mixture must have the same output features."
      )
    if len(set(t.sentencepiece_model_path for t in self._tasks)) != 1:
      raise ValueError(
          "All Tasks in a Mixture must have the same sentencepiece_model_path."
      )
GitHub: google-research / text-to-text-transfer-transformer / t5 / data / utils.py (View on GitHub)
def get_mixture_or_task(task_or_mixture_name):
  """Look up a name in the Mixture and Task registries and return the match.

  Mixtures take precedence over Tasks. When the name is registered as both,
  a warning is logged and the Mixture is returned.

  Args:
    task_or_mixture_name: the registered name of a Task or a Mixture.

  Returns:
    The result of `MixtureRegistry.get(...)` if the name is a registered
    Mixture, otherwise the result of `TaskRegistry.get(...)`.

  Raises:
    ValueError: if the name is registered in neither registry.
  """
  name = task_or_mixture_name
  known_mixtures = MixtureRegistry.names()
  known_tasks = TaskRegistry.names()
  is_mixture = name in known_mixtures
  is_task = name in known_tasks

  if not (is_mixture or is_task):
    raise ValueError("No Task or Mixture found with name: %s" % name)

  if not is_mixture:
    return TaskRegistry.get(name)

  # On a name collision the Mixture wins; surface the ambiguity in the logs.
  if is_task:
    logging.warning("%s is both a Task and a Mixture, returning Mixture",
                    name)
  return MixtureRegistry.get(name)