How to use the wandb.util module in wandb

To help you get started, we’ve selected a few wandb examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github wandb / client / tests / test_retry.py View on Github external
def test_retry_with_noauth_401(capsys):
    """Check that a 401 failure under the no_retry_auth check raises CommError.

    The wrapped callable always raises a TransientException carrying an
    HTTPError whose response has status code 401; the resulting CommError
    must carry the "Invalid or missing api_key" message.
    """
    def fail():
        # Build a bare response object and mark it as unauthorized.
        unauthorized = requests.Response()
        unauthorized.status_code = 401
        http_error = requests.HTTPError(response=unauthorized)
        raise retry.TransientException(exc=http_error)

    retrying_call = retry.Retry(fail, check_retry_fn=util.no_retry_auth)
    expected_message = 'Invalid or missing api_key.  Run wandb login'

    with pytest.raises(CommError) as raised:
        retrying_call()
    assert raised.value.message == expected_message
github wandb / client / wandb / core.py View on Github external
newline (bool, optional): Print a newline at the end of the string
            repeat (bool, optional): If set to False only prints the string once per process
    """
    if string:
        line = '\n'.join(['{}: {}'.format(LOG_STRING, s)
                          for s in string.split('\n')])
    else:
        line = ''
    if not repeat and line in PRINTED_MESSAGES:
        return
    # Repeated line tracking limited to 1k messages
    if len(PRINTED_MESSAGES) < 1000:
        PRINTED_MESSAGES.add(line)
    if os.getenv(env.SILENT):
        from wandb import util
        util.mkdir_exists_ok(os.path.dirname(util.get_log_file_path()))
        with open(util.get_log_file_path(), 'w') as log:
            click.echo(line, file=log, nl=newline)
    else:
        click.echo(line, file=sys.stderr, nl=newline)
github wandb / client / wandb / run_manager.py View on Github external
if isinstance(v, six.string_types):
                    if len(v) >= 20:
                        v = v[:20] + '...'
                    wandb.termlog(format_str.format(k, v))
                elif isinstance(v, numbers.Number):
                    wandb.termlog(format_str.format(k, v))

        self._run.history.load()
        history_keys = self._run.history.keys()
        # Only print sparklines if the terminal is utf-8
        if len(history_keys) and sys.stdout.encoding == "UTF_8":
            logger.info("rendering history")
            wandb.termlog('Run history:')
            max_len = max([len(k) for k in history_keys])
            for key in history_keys:
                vals = util.downsample(self._run.history.column(key), 40)
                if any((not isinstance(v, numbers.Number) for v in vals)):
                    continue
                line = sparkline.sparkify(vals)
                format_str = u'  {:>%s} {}' % max_len
                wandb.termlog(format_str.format(key, line))

        wandb_files = set([save_name for save_name in self._file_pusher.files() if util.is_wandb_file(save_name)])
        media_files = set([save_name for save_name in self._file_pusher.files() if save_name.startswith('media')])
        other_files = set(self._file_pusher.files()) - wandb_files - media_files
        logger.info("syncing files to cloud storage")
        if other_files:
            wandb.termlog('Syncing files in %s:' % os.path.relpath(self._run.dir))
            for save_name in sorted(other_files):
                wandb.termlog('  %s' % save_name)
            wandb.termlog('plus {} W&B file(s) and {} media file(s)'.format(len(wandb_files), len(media_files)))
        else:
github wandb / client / wandb / apis / internal.py View on Github external
'heartbeat_seconds': 30,
        }
        self.client = Client(
            transport=RequestsHTTPTransport(
                headers={'User-Agent': self.user_agent, 'X-WANDB-USERNAME': env.get_username(env=self._environ)},
                use_json=True,
                # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
                # https://bugs.python.org/issue22889
                timeout=self.HTTP_TIMEOUT,
                auth=("api", self.api_key or ""),
                url='%s/graphql' % self.settings('base_url')
            )
        )
        self.gql = retry.Retry(self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException))
        self._current_run_id = None
        self._file_stream_api = None
github wandb / client / wandb / data_types.py View on Github external
self._grouping = grouping
        self._caption = caption
        self._width = None
        self._height = None
        self._image = None

        if isinstance(data_or_path, six.string_types):
            super(Image, self).__init__(data_or_path, is_tmp=False)
        else:
            data = data_or_path

            PILImage = util.get_module(
                "PIL.Image", required='wandb.Image needs the PIL package. To get it, run "pip install pillow".')
            if util.is_matplotlib_typename(util.get_full_typename(data)):
                buf = six.BytesIO()
                util.ensure_matplotlib_figure(data).savefig(buf)
                self._image = PILImage.open(buf)
            elif isinstance(data, PILImage.Image):
                self._image = data
            elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
                vis_util = util.get_module(
                    "torchvision.utils", "torchvision is required to render images")
                if hasattr(data, "requires_grad") and data.requires_grad:
                    data = data.detach()
                data = vis_util.make_grid(data, normalize=True)
                self._image = PILImage.fromarray(data.mul(255).clamp(
                    0, 255).byte().permute(1, 2, 0).cpu().numpy())
            else:
                if hasattr(data, "numpy"): # TF data eager tensors
                    data = data.numpy()
                if data.ndim > 2:
                    data = data.squeeze()  # get rid of trivial dimensions as a convenience
github wandb / client / wandb / data_types.py View on Github external
def val_to_json(run, key, val, step='summary'):
    # Converts a wandb datatype to its JSON representation.
   
    converted = val
    typename = util.get_full_typename(val)

    if util.is_pandas_data_frame(val):
        assert step == 'summary', "We don't yet support DataFrames in History."
        return data_frame_to_json(val, run, key, step)
    elif util.is_matplotlib_typename(typename):
        # This handles plots with images in it because plotly doesn't support it
        # TODO: should we handle a list of plots?
        val = util.ensure_matplotlib_figure(val)
        if any(len(ax.images) > 0 for ax in val.axes):
            PILImage = util.get_module(
                "PIL.Image", required="Logging plots with images requires pil: pip install pillow")
            buf = six.BytesIO()
            val.savefig(buf)
            val = Image(PILImage.open(buf))
        else:
            converted = plot_to_json(val)
    elif util.is_plotly_typename(typename):
        converted = plot_to_json(val)
    elif isinstance(val, collections.Sequence) and all(isinstance(v, WBValue) for v in val):
        # This check will break down if Image/Audio/... have child classes.
        if len(val) and isinstance(val[0], BatchableMedia) and all(isinstance(v, type(val[0])) for v in val):
github wandb / client / wandb / apis / file_stream.py View on Github external
def _read_queue(self):
        """Return the next batch of queued items for the push thread.

        Delegates to util.read_many_from_queue with this object's queue,
        MAX_ITEMS_PER_PUSH and the current rate_limit_seconds() value.
        """
        # Invoked from the push thread (_thread_body). The first read blocks
        # for at most rate_limit_seconds; after that we drain as much of the
        # queue as we can. The HTTP post to the server also runs inside
        # _thread_body and may take longer than the rate limit, so on the
        # next pass we want everything that queued up in the meantime.
        #
        # When the queue holds more than MAX_ITEMS_PER_PUSH items the push
        # thread falls behind and data keeps buffering in the queue.
        pending = self._queue
        batch_limit = self.MAX_ITEMS_PER_PUSH
        wait_seconds = self.rate_limit_seconds()
        return util.read_many_from_queue(pending, batch_limit, wait_seconds)
github wandb / client / wandb / internal_cli.py View on Github external
# handle non-git directories
    if not root:
        root = os.path.abspath(os.getcwd())
        host = socket.gethostname()
        remote_url = 'file://%s%s' % (host, root)

    run.save(program=args['program'], api=api)
    env = dict(os.environ)
    run.set_environment(env)

    try:
        rm = wandb.run_manager.RunManager(api, run)
    except wandb.run_manager.Error:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        wandb.termerror('An Exception was raised during setup, see %s for full traceback.' %
                        util.get_log_file_path())
        wandb.termerror(exc_value)
        if 'permission' in str(exc_value):
            wandb.termerror(
                'Are you sure you provided the correct API key to "wandb login"?')
        lines = traceback.format_exception(
            exc_type, exc_value, exc_traceback)
        logging.error('\n'.join(lines))
    else:
        rm.run_user_process(args['program'], args['args'], env)
github wandb / client / wandb / data_types.py View on Github external
def val_to_json(run, key, val, step='summary'):
    # Converts a wandb datatype to its JSON representation.
   
    converted = val
    typename = util.get_full_typename(val)

    if util.is_pandas_data_frame(val):
        assert step == 'summary', "We don't yet support DataFrames in History."
        return data_frame_to_json(val, run, key, step)
    elif util.is_matplotlib_typename(typename):
        # This handles plots with images in it because plotly doesn't support it
        # TODO: should we handle a list of plots?
        val = util.ensure_matplotlib_figure(val)
        if any(len(ax.images) > 0 for ax in val.axes):
            PILImage = util.get_module(
                "PIL.Image", required="Logging plots with images requires pil: pip install pillow")
            buf = six.BytesIO()
            val.savefig(buf)
            val = Image(PILImage.open(buf))
        else:
            converted = plot_to_json(val)
    elif util.is_plotly_typename(typename):
        converted = plot_to_json(val)