How to use the wandb.util.get_module function in wandb

To help you get started, we’ve selected a few wandb examples based on popular ways `wandb.util.get_module` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: wandb/client — tests/utils.py (view on GitHub)
import pytest
import os
import click
from click.testing import CliRunner
import git
import requests
import json
from wandb import util
from wandb.apis import InternalApi

import webbrowser
from wandb.git_repo import GitRepo
from distutils.version import LooseVersion

torch = util.get_module("torch")
# When torch is unavailable (util.get_module presumably returns a falsy value),
# neither OLD_PYTORCH nor pytorch_tensor is defined by this module.
if torch:
    # Before 0.4, torch.tensor was a module rather than a callable factory
    # (the factory also adds 0d-tensor support), so older releases must build
    # tensors through the torch.Tensor class instead.
    OLD_PYTORCH = LooseVersion(torch.__version__) < LooseVersion("0.4")
    pytorch_tensor = torch.Tensor if OLD_PYTORCH else torch.tensor


@pytest.fixture
def runner(monkeypatch):
    whaaaaat = util.vendor_import("whaaaaat")
    monkeypatch.setattr('wandb.cli.api', InternalApi(
        default_settings={'project': 'test', 'git_tag': True}, load_settings=False))
    monkeypatch.setattr(click, 'launch', lambda x: 1)
GitHub: wandb/client — wandb/stats.py (view on GitHub)
import collections
import os
from pynvml import *
import time
from numbers import Number
import threading
from wandb import util
from wandb import termlog
# Optional dependency: util.get_module presumably returns None when psutil is
# not installed, which is why SystemStats guards its use with `if psutil:`.
psutil = util.get_module("psutil")


class SystemStats(object):
    def __init__(self, run, api):
        try:
            nvmlInit()
            self.gpu_count = nvmlDeviceGetCount()
        except NVMLError as err:
            self.gpu_count = 0
        self.run = run
        self._api = api
        self.sampler = {}
        self.samples = 0
        self._shutdown = False
        if psutil:
            net = psutil.net_io_counters()
GitHub: wandb/client — wandb/data_types.py (view on GitHub)
def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accepts a path to an audio file or a numpy array of audio data.

        Args:
            data_or_path: Either a path string to an existing audio file, or
                raw audio samples (presumably a numpy array -- anything that
                ``soundfile.write`` accepts).
            sample_rate: Sample rate of the raw data (presumably Hz). Required
                when ``data_or_path`` is raw data; only stored for paths.
            caption: Optional caption stored on the instance.

        Raises:
            ValueError: If raw data is passed without ``sample_rate``.
        """
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, six.string_types):
            # Existing file on disk: pass it to the parent flagged as not
            # temporary.
            super(Audio, self).__init__(data_or_path, is_tmp=False)
            return

        # Identity check rather than ``== None``: objects that overload
        # ``__eq__`` (e.g. numpy arrays) return elementwise results whose
        # truth value is ambiguous.
        if sample_rate is None:
            raise ValueError('Argument "sample_rate" is required when instantiating wandb.Audio with raw data.')

        # soundfile is an optional dependency; get_module(required=...)
        # presumably raises with this message when it is not installed.
        soundfile = util.get_module(
            "soundfile", required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"')

        # Serialize the raw samples to a uniquely named temporary .wav file.
        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.wav')
        soundfile.write(tmp_path, data_or_path, sample_rate)
        # Duration = number of frames along the first axis / sample rate.
        self._duration = len(data_or_path) / float(sample_rate)

        super(Audio, self).__init__(tmp_path, is_tmp=True)
GitHub: wandb/client — wandb/agent_controller.py (view on GitHub)
from __future__ import print_function
import logging
import yaml
import time
import json
import random
import string
import sys
import six

import wandb
from wandb import util
from wandb.apis import InternalApi


logger = logging.getLogger(__name__)
# Sweeps support is looked up dynamically; util.get_module presumably yields
# None if the module cannot be imported.
wandb_sweeps = util.get_module("wandb.sweeps.sweeps")


# Run mirrors a run payload that (presumably on the Go backend side) is
# assembled from these fields, with several carried as raw JSON:
#   Name:           run.Name,
#   Config:         json.RawMessage(config),
#   History:        json.RawMessage(history),
#   State:          state,
#   SummaryMetrics: json.RawMessage(summary),
class Run(object):
    """Lightweight record for a single run handled by the agent controller.

    Every constructor argument is stored verbatim as an attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self, name, state, history, config, summaryMetrics, stopped):
        # Identity and lifecycle status.
        self.name = name
        self.state = state
        self.stopped = stopped
        # Run payloads, kept exactly as given.
        self.config = config
        self.history = history
        self.summaryMetrics = summaryMetrics
GitHub: wandb/client — wandb/data_types.py (view on GitHub)
def plot_to_json(obj):
    """Convert a matplotlib or plotly figure into a form wandb can send.

    A matplotlib figure is first converted to a plotly figure with
    ``plotly.tools.mpl_to_plotly`` (plotly must be installed for that step).
    A plotly figure is then wrapped as ``{"_type": "plotly", "plot": ...}``,
    with its JSON passed through ``numpy_arrays_to_lists`` (presumably so it
    is serializable). Any other object is returned unchanged.
    """
    is_matplotlib = util.is_matplotlib_typename(util.get_full_typename(obj))
    if is_matplotlib:
        tools = util.get_module(
            "plotly.tools", required="plotly is required to log interactive plots, install with: pip install plotly or convert the plot to an image with `wandb.Image(plt)`")
        obj = tools.mpl_to_plotly(obj)

    # Re-inspect the type: obj may now be the freshly converted plotly figure.
    if not util.is_plotly_typename(util.get_full_typename(obj)):
        return obj

    plot_json = numpy_arrays_to_lists(obj.to_plotly_json())
    return {"_type": "plotly", "plot": plot_json}
GitHub: wandb/client — wandb/tensorboard/__init__.py (view on GitHub)
last = value.histo.bucket_limit[-2] + \
                    value.histo.bucket_limit[-2] - value.histo.bucket_limit[-3]
                np_histogram = (list(value.histo.bucket), [
                    first] + value.histo.bucket_limit[:-1] + [last])
                try:
                    #TODO: we should just re-bin if there are too many buckets
                    values[tag] = wandb.Histogram(
                        np_histogram=np_histogram)
                except ValueError:
                    wandb.termwarn("Not logging key \"{}\".  Histograms must have fewer than {} bins".format(
                        tag, wandb.Histogram.MAX_LENGTH), repeat=False)
            else:
                #TODO: is there a case where we can render this?
                wandb.termwarn("Not logging key \"{}\".  Found a histogram with only 2 bins.".format(tag), repeat=False)
        elif value.tag == "_hparams_/session_start_info":
            if wandb.util.get_module("tensorboard.plugins.hparams"):
                from tensorboard.plugins.hparams import plugin_data_pb2
                plugin_data = plugin_data_pb2.HParamsPluginData()
                plugin_data.ParseFromString(
                    value.metadata.plugin_data.content)
                for key, param in six.iteritems(plugin_data.session_start_info.hparams):
                    if not wandb.run.config.get(key):
                        wandb.run.config[key] = param.number_value or param.string_value or param.bool_value
            else:
                wandb.termerror(
                    "Received hparams tf.summary, but could not import the hparams plugin from tensorboard")

    return values
GitHub: wandb/client — wandb/data_types.py (view on GitHub)
def encode(self):
        mpy = util.get_module("moviepy.editor", required='wandb.Video requires moviepy and imageio when passing raw data.  Install with "pip install moviepy imageio"')
        tensor = self._prepare_video(self.data)
        _, self._height, self._width, self._channels = tensor.shape

        # encode sequence of images into gif string
        clip = mpy.ImageSequenceClip(list(tensor), fps=self._fps)

        filename = os.path.join(MEDIA_TMP.name, util.generate_id() + '.'+ self._format)
        try:  # older version of moviepy does not support progress_bar argument.
            if self._format == "gif":
                clip.write_gif(filename, verbose=False, progress_bar=False)
            else:
                clip.write_videofile(filename, verbose=False, progress_bar=False)
        except TypeError:
            if self._format == "gif":
                clip.write_gif(filename, verbose=False)
            else:
GitHub: wandb/client — wandb/dataframes.py (view on GitHub)
def image_segmentation_multiclass_dataframe(x, y_true, y_pred, labels, example_ids=None, class_colors=None):
    np = util.get_module('numpy', required='dataframes require numpy')
    pd = util.get_module('pandas', required='dataframes require pandas')

    x, y_true, y_pred= np.array(x), np.array(y_true), np.array(y_pred)

    if x.shape[0] != y_true.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_true(%d). skipping evaluation' % (x.shape[0], y_true.shape[0]))
        return
    if x.shape[0] != y_pred.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_pred(%d). skipping evaluation' % (x.shape[0], y_pred.shape[0]))
        return
    if class_colors is not None and len(class_colors) != y_true.shape[-1]:
        termwarn('Class color count mismatch: y_true(%d) != class_colors(%d). using generated colors' % (y_true.shape[-1], len(class_colors)))
        class_colors = None

    class_count = y_true.shape[-1]

    if class_colors is None:
GitHub: wandb/client — wandb/dataframes.py (view on GitHub)
def image_segmentation_multiclass_dataframe(x, y_true, y_pred, labels, example_ids=None, class_colors=None):
    np = util.get_module('numpy', required='dataframes require numpy')
    pd = util.get_module('pandas', required='dataframes require pandas')

    x, y_true, y_pred= np.array(x), np.array(y_true), np.array(y_pred)

    if x.shape[0] != y_true.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_true(%d). skipping evaluation' % (x.shape[0], y_true.shape[0]))
        return
    if x.shape[0] != y_pred.shape[0]:
        termwarn('Sample count mismatch: x(%d) != y_pred(%d). skipping evaluation' % (x.shape[0], y_pred.shape[0]))
        return
    if class_colors is not None and len(class_colors) != y_true.shape[-1]:
        termwarn('Class color count mismatch: y_true(%d) != class_colors(%d). using generated colors' % (y_true.shape[-1], len(class_colors)))
        class_colors = None

    class_count = y_true.shape[-1]