How to use the gradio.outputs.AbstractOutput class in gradio

To help you get started, we've selected a few gradio examples based on popular ways `AbstractOutput` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: GitHub, gradio-app/gradio-UI — gradio/interface.py (view on GitHub)
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array

        """
        if isinstance(inputs, str):
            self.input_interface = gradio.inputs.registry[inputs.lower()](
                preprocessing_fns
            )
        elif isinstance(inputs, gradio.inputs.AbstractInput):
            self.input_interface = inputs
        else:
            raise ValueError("Input interface must be of type `str` or `AbstractInput`")
        if isinstance(outputs, str):
            self.output_interface = gradio.outputs.registry[outputs.lower()](
                postprocessing_fns
            )
        elif isinstance(outputs, gradio.outputs.AbstractOutput):
            self.output_interface = outputs
        else:
            raise ValueError(
                "Output interface must be of type `str` or `AbstractOutput`"
            )
        self.model_obj = model
        if model_type is None:
            model_type = self._infer_model_type(model)
            if verbose:
                print(
                    "Model type not explicitly identified, inferred to be: {}".format(
                        self.VALID_MODEL_TYPES[model_type]
                    )
                )
        elif not (model_type.lower() in self.VALID_MODEL_TYPES):
            ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
Source: GitHub, gradio-app/gradio-UI — gradio/outputs.py (view on GitHub)
"""
        """
        return prediction

    def rebuild_flagged(self, dir, msg):
        """
        Decode the base64-encoded image in ``msg`` and save it as a PNG in ``dir``.

        The file is named ``output_<YYYY-mm-dd-HH-MM-SS>.png`` using the current
        local time. Returns that file name only (not the joined path).
        """
        image = preprocessing_utils.decode_base64_to_image(msg)
        # NOTE(review): the name has one-second resolution, so two flags saved
        # in the same second to the same ``dir`` overwrite each other — confirm
        # that is acceptable before relying on this for bulk flagging.
        stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        png_name = "output_" + stamp + ".png"
        image.save(f"{dir}/{png_name}", "PNG")
        return png_name


# Lookup table from lowercased class name (e.g. "label", "textbox", "image")
# to the output-interface class, so an interface can be chosen by string.
# Only *direct* subclasses of AbstractOutput that exist when this line runs
# are included.
registry = {
    output_cls.__name__.lower(): output_cls
    for output_cls in AbstractOutput.__subclasses__()
}
Source: GitHub, gradio-app/gradio-UI — gradio/outputs.py (view on GitHub)
})
                        prediction[prediction.argmax()] = 0
        elif isinstance(prediction, str):
            response[Label.LABEL_KEY] = prediction
        else:
            raise ValueError("Unable to post-process model prediction.")
        return json.dumps(response)

    def rebuild_flagged(self, dir, msg):
        """
        Rebuild a flagged label output by decoding the JSON string ``msg``.

        ``dir`` is not used by this implementation.
        """
        rebuilt = json.loads(msg)
        return rebuilt


class Textbox(AbstractOutput):
    """Output interface whose identifier is ``'textbox'``."""

    def get_name(self):
        """Return the identifier ``'textbox'``."""
        interface_name = 'textbox'
        return interface_name

    def postprocess(self, prediction):
        """Hand the model's prediction back without modification."""
        return prediction

    def rebuild_flagged(self, dir, msg):
        """
        Rebuild a flagged textbox output by JSON-decoding ``msg``.

        ``dir`` is not used by this implementation.
        """
        # NOTE(review): postprocess() returns the prediction without JSON
        # encoding, yet this decodes ``msg`` as JSON — confirm the flagging
        # path JSON-encodes textbox output before it reaches this method.
        decoded = json.loads(msg)
        return decoded
Source: GitHub, gradio-app/gradio-UI — gradio/outputs.py (view on GitHub)
def get_name(self):
        """Return the identifier ``'textbox'`` for this output interface."""
        return 'textbox'

    def postprocess(self, prediction):
        """Return ``prediction`` unchanged; no conversion is applied."""
        passthrough = prediction
        return passthrough

    def rebuild_flagged(self, dir, msg):
        """
        Rebuild a flagged textbox output by JSON-decoding ``msg``.

        ``dir`` is not used by this implementation.
        """
        # NOTE(review): confirm the flagged textbox message is JSON-encoded;
        # a plain, unquoted string would make json.loads raise.
        decoded = json.loads(msg)
        return decoded


class Image(AbstractOutput):

    def get_name(self):
        """Return the identifier ``'image'`` for this output interface."""
        identifier = 'image'
        return identifier

    def postprocess(self, prediction):
        """Return ``prediction`` as-is; this interface applies no conversion."""
        result = prediction
        return result

    def rebuild_flagged(self, dir, msg):
        """
        Default rebuild method to decode a base64 image
        """
        im = preprocessing_utils.decode_base64_to_image(msg)
        timestamp = datetime.datetime.now()
        filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
Source: GitHub, gradio-app/gradio-UI — build/lib/gradio/interface.py (view on GitHub)
:param saliency: an optional function that takes the model and the processed input and returns a 2-d array

        """
        if isinstance(inputs, str):
            self.input_interface = gradio.inputs.registry[inputs.lower()](
                preprocessing_fns
            )
        elif isinstance(inputs, gradio.inputs.AbstractInput):
            self.input_interface = inputs
        else:
            raise ValueError("Input interface must be of type `str` or `AbstractInput`")
        if isinstance(outputs, str):
            self.output_interface = gradio.outputs.registry[outputs.lower()](
                postprocessing_fns
            )
        elif isinstance(outputs, gradio.outputs.AbstractOutput):
            self.output_interface = outputs
        else:
            raise ValueError(
                "Output interface must be of type `str` or `AbstractOutput`"
            )
        self.model_obj = model
        if model_type is None:
            model_type = self._infer_model_type(model)
            if verbose:
                print(
                    "Model type not explicitly identified, inferred to be: {}".format(
                        self.VALID_MODEL_TYPES[model_type]
                    )
                )
        elif not (model_type.lower() in self.VALID_MODEL_TYPES):
            ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
Source: GitHub, gradio-app/gradio-UI — gradio/outputs.py (view on GitHub)
    @abstractmethod
    def postprocess(self, prediction):
        """
        Postprocess a raw model prediction for this output interface.

        Declared abstract so every concrete interface supplies its own default
        postprocessing; this base body does nothing and returns None.
        """

    @abstractmethod
    def rebuild_flagged(self, dir, msg):
        """
        All interfaces should define a method that rebuilds the flagged output
        when it's passed back (i.e. rebuilds image from base64).

        :param dir: directory a subclass may write rebuilt files into (the image
            interface, for example, saves a PNG there and returns its file name).
        :param msg: the flagged output as passed back to the interface.
        :return: the rebuilt output; its form is defined by each subclass.
        """
        # Keep this signature in sync with the concrete overrides, which are all
        # declared as rebuild_flagged(self, dir, msg).


class Label(AbstractOutput):
    # Key under which the predicted label is stored in the response dict that
    # this class JSON-encodes for the output (see `response[Label.LABEL_KEY]`).
    LABEL_KEY = 'label'
    # NOTE(review): the two keys below are not used in the visible code;
    # presumably they name the list of per-class confidences and a single
    # confidence value in that same response — confirm against postprocess.
    CONFIDENCES_KEY = 'confidences'
    CONFIDENCE_KEY = 'confidence'

    def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
                 max_label_length=None, max_label_words=None, word_delimiter=" "):
        """
        Store the label display options, then run the base-class initializer.

        :param postprocessing_fn: forwarded unchanged to ``AbstractOutput.__init__``.
        :param num_top_classes: stored as ``self.num_top_classes``; presumably
            how many highest-scoring classes to report — confirm in postprocess.
        :param show_confidences: stored as ``self.show_confidences``.
        :param label_names: stored as ``self.label_names``.
        :param max_label_length: stored as ``self.max_label_length``.
        :param max_label_words: stored as ``self.max_label_words``.
        :param word_delimiter: stored as ``self.word_delimiter``.
        """
        # Display options (assigned left to right, same order as before).
        self.num_top_classes, self.show_confidences = num_top_classes, show_confidences
        # Label-text options.
        self.label_names = label_names
        self.max_label_length, self.max_label_words = max_label_length, max_label_words
        self.word_delimiter = word_delimiter
        # Base-class setup runs only after this class's attributes are in place.
        super().__init__(postprocessing_fn=postprocessing_fn)

    def get_name(self):
        """Return ``'label'``, the identifier of this output interface."""
        identifier = 'label'
        return identifier