How to use the livelossplot.matplotlib_subplots.BaseSubplot class in livelossplot

To help you get started, we’ve selected a few livelossplot examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: stared/livelossplot — livelossplot/matplotlib_subplots.py (view on GitHub)
self.model = model
        self.X = X 
        self.Y = Y

    def predict(self, model, X):
        """Return the predictions of *model* for the inputs *X*.

        The default simply delegates to ``model.predict``. Models without
        such a method need a different call here, e.g. for PyTorch:
        ``model(torch.from_numpy(X)).detach().numpy()``.
        """
        predictor = model.predict
        return predictor(X)

    def draw(self, *args, **kwargs):
        """Plot ground-truth points and overlay the model's output as a line.

        Any positional or keyword arguments are accepted and ignored.
        """
        inputs, targets = self.X, self.Y
        # Ground truth as red dots.
        plt.plot(inputs, targets, 'r.', label="Ground truth")
        # Model output as a solid line over the same inputs.
        predicted = self.predict(self.model, inputs)
        plt.plot(inputs, predicted, '-', label="Model")
        plt.title("Prediction")
        plt.legend(loc='lower right')


class Plot2d(BaseSubplot):
    def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25):
        super().__init__()

        self.model = model
        self.X = X 
        self.Y = Y
        self.X_test, self.Y_test = valiation_data

        # add size assertions

        self.cm_bg = plt.cm.RdBu
        self.cm_points = ListedColormap(['#FF0000', '#0000FF'])

        h = .02  # step size in the mesh
        x_min = X[:, 0].min() - margin
        x_max = X[:, 0].max() + margin
GitHub: stared/livelossplot — livelossplot/matplotlib_subplots.py (view on GitHub)
serie_metric_name = serie_fmt.format(self.metric)
            serie_metric_logs = [(log.get('_i', i + 1), log[serie_metric_name])
                                for i, log in enumerate(logs[skip:])
                                if serie_metric_name in log]

            if len(serie_metric_logs) > 0:
                xs, ys = zip(*serie_metric_logs)
                plt.plot(xs, ys, label=serie_label)

        plt.title(self.title)
        plt.xlabel('epoch')
        plt.legend(loc='center right')


class Plot1D(BaseSubplot):
    """Subplot comparing 1-D ground-truth data with a model's predictions.

    Args:
        model: object whose predictions are obtained via :meth:`predict`
            (by default ``model.predict(X)``).
        X: input values, used as the x-coordinates of both series.
        Y: ground-truth target values plotted against ``X``.
    """

    def __init__(self, model, X, Y):
        # BaseSubplot.__init__ accepts only ``self``; passing ``self`` again
        # as an argument would raise TypeError.
        super().__init__()
        self.model = model
        self.X = X
        self.Y = Y

    def predict(self, model, X):
        """Return ``model.predict(X)``.

        Models without a ``predict`` method need a different call here,
        e.g. for PyTorch: ``model(torch.from_numpy(X)).detach().numpy()``.
        """
        return model.predict(X)

    def draw(self, *args, **kwargs):
        """Plot ground truth as red dots and the model output as a line.

        Arguments forwarded by ``BaseSubplot.__call__`` are accepted and ignored.
        """
        plt.plot(self.X, self.Y, 'r.', label="Ground truth")
        plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model")
        plt.title("Prediction")
        plt.legend(loc='lower right')
GitHub: stared/livelossplot — livelossplot/matplotlib_subplots.py (view on GitHub)
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


class BaseSubplot:
    """Base class for a single subplot.

    Subclasses implement :meth:`draw`. Calling an instance forwards all
    positional and keyword arguments to :meth:`draw` and returns its result.
    """

    def __init__(self):
        pass

    def draw(self, *args, **kwargs):
        """Draw this subplot; must be overridden by subclasses.

        Accepts the same arguments that :meth:`__call__` forwards, so an
        unimplemented subclass fails with a clear ``NotImplementedError``
        rather than a signature ``TypeError``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Not implemented")

    def __call__(self, *args, **kwargs):
        """Draw the subplot; shorthand for ``self.draw(*args, **kwargs)``."""
        return self.draw(*args, **kwargs)


class LossSubplot(BaseSubplot):
    def __init__(self,
                 metric,
                 title="",
                 series_fmt=None,
                 skip_first=2,
                 max_epoch=None):
        """Configure a subplot that plots one metric's series from the logs.

        Args:
            metric: base metric name; it is substituted into each format
                string of ``series_fmt`` (via ``str.format``) to build the log
                key to read.
            title: subplot title.
            series_fmt: mapping of series label -> log-key format string.
                ``None`` (the default) means
                ``{'training': '{}', 'validation': 'val_{}'}``.
            skip_first: threshold consulted by ``_how_many_to_skip``; logs
                shorter than this skip nothing (presumably leading epochs are
                skipped otherwise — TODO confirm).
            max_epoch: optional epoch limit, stored as ``self.max_epoch``
                (its use is outside this method — NOTE(review): confirm).
        """
        # BaseSubplot.__init__ accepts only ``self``; passing ``self`` again
        # as an argument would raise TypeError.
        super().__init__()
        if series_fmt is None:
            # Build a fresh dict per instance so mutating one subplot's
            # formats cannot leak into other instances.
            series_fmt = {'training': '{}', 'validation': 'val_{}'}
        self.metric = metric
        self.title = title
        self.series_fmt = series_fmt
        self.skip_first = skip_first
        self.max_epoch = max_epoch
    
    def _how_many_to_skip(self, log_length, skip_first):
        if log_length < skip_first:
            return 0