How to use the art.DATA_PATH constant in art

To help you get started, we’ve selected a few art examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github IBM / adversarial-robustness-toolbox / tests / classifiers / test_tensorflow.py View on Github external
def test_pickle(self):
        """
        Pickle the classifier returned by `get_classifier_tf` into `DATA_PATH`, load it back and compare the
        restored object with the original.
        """
        classifier, sess = get_classifier_tf()
        full_path = os.path.join(DATA_PATH, 'my_classifier')
        folder = os.path.split(full_path)[0]

        if not os.path.exists(folder):
            os.makedirs(folder)

        # Write through a context manager so the handle is flushed and closed before the file is read back;
        # a bare `open(...)` passed to `pickle.dump` would leave the file descriptor open.
        with open(full_path, 'wb') as f:
            pickle.dump(classifier, f)

        # Unpickle:
        with open(full_path, 'rb') as f:
            loaded = pickle.load(f)
            self.assertEqual(classifier._clip_values, loaded._clip_values)
            self.assertEqual(classifier._channel_index, loaded._channel_index)
            self.assertEqual(set(classifier.__dict__.keys()), set(loaded.__dict__.keys()))

        # Test predict
        predictions_1 = classifier.predict(self.x_test)
github IBM / adversarial-robustness-toolbox / tests / test_visualization.py View on Github external
def test_save_image(self):
        """
        Save MNIST samples through `save_image` under plain, single-folder and nested-folder file names and check
        that each file shows up below `DATA_PATH`.
        """
        (images, _), (_, _), _, _ = load_mnist(raw=True)

        # Plain file names: one per image format, each removed right after the check.
        for index, plain_name in enumerate(('image1.png', 'image2.jpg')):
            save_image(images[index], plain_name)
            saved = os.path.join(DATA_PATH, plain_name)
            self.assertTrue(os.path.isfile(saved))
            os.remove(saved)

        # File name that includes one sub-folder.
        relative_name = os.path.join('images123456', 'image3.png')
        save_image(images[3], relative_name)
        saved = os.path.join(DATA_PATH, relative_name)
        self.assertTrue(os.path.isfile(saved))
        os.remove(saved)
        os.rmdir(os.path.dirname(saved))  # Remove also test folder

        # File name that includes two nested sub-folders.
        relative_name = os.path.join('images123456', 'inner', 'image4.png')
        save_image(images[3], relative_name)
        path_nested = os.path.join(DATA_PATH, relative_name)
        self.assertTrue(os.path.isfile(path_nested))
github IBM / adversarial-robustness-toolbox / art / visualization.py View on Github external
def save_image(image_array, f_name):
    """
    Write an image to disk below `DATA_PATH`, creating the containing folder when it does not exist yet.

    :param image_array: Pixel data handed to `PIL.Image.fromarray`.
    :type image_array: `np.ndarray`
    :param f_name: Target file name relative to `DATA_PATH`, extension included; it may contain sub-folders,
                   e.g., my_img.jpg, my_img.png, my_images/my_img.png
    :type f_name: `str`
    :return: `None`
    """
    target = os.path.join(DATA_PATH, f_name)
    parent_dir = os.path.dirname(target)
    if not os.path.exists(parent_dir):
        os.makedirs(parent_dir)

    # PIL is imported lazily, only once an image is actually being written.
    from PIL import Image
    Image.fromarray(image_array).save(target)
    logger.info('Image saved to %s.', target)
github IBM / adversarial-robustness-toolbox / art / visualization.py View on Github external
colors.append('C' + str(i))
        else:
            if len(colors) != len(np.unique(labels)):
                raise ValueError('The amount of provided colors should match the number of labels in the 3pd plot.')

        fig = plt.figure()
        axis = plt.axes(projection='3d')

        for i, coord in enumerate(points):
            try:
                color_point = labels[i]
                axis.scatter3D(coord[0], coord[1], coord[2], color=colors[color_point])
            except IndexError:
                raise ValueError('Labels outside the range. Should start from zero and be sequential there after')
        if save:
            file_name = os.path.realpath(os.path.join(DATA_PATH, f_name))
            folder = os.path.split(file_name)[0]

            if not os.path.exists(folder):
                os.makedirs(folder)
            fig.savefig(file_name, bbox_inches='tight')
            logger.info('3d-plot saved to %s.', file_name)

        return fig
    except ImportError:
        logger.warning("matplotlib not installed. For this reason, cluster visualization was not displayed.")
github IBM / adversarial-robustness-toolbox / art / classifiers / keras.py View on Github external
def save(self, filename, path=None):
        """
        Store the wrapped model on disk by calling the model's own `save` method. For Keras, .h5 format is used.

        :param filename: Name of the file where to store the model.
        :type filename: `str`
        :param path: Folder in which to store the model. When omitted, the library default location `DATA_PATH`
                     is used.
        :type path: `str`
        :return: None
        """
        import os

        if path is not None:
            base_dir = path
        else:
            # Fall back to the library-wide data folder.
            from art import DATA_PATH
            base_dir = DATA_PATH

        destination = os.path.join(base_dir, filename)
        parent_dir = os.path.dirname(destination)
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)

        self._model.save(str(destination))
        logger.info('Model saved in path: %s.', destination)
github IBM / adversarial-robustness-toolbox / art / utils.py View on Github external
can also be extracted. This is a simplified version of the function with the same name in Keras.

    :param filename: Name of the file.
    :type filename: `str`
    :param url: Download URL.
    :type url: `str`
    :param path: Folder to store the download. If not specified, `~/.art/data` is used instead.
    :type: `str`
    :param extract: If true, tries to extract the archive.
    :type extract: `bool`
    :return: Path to the downloaded file.
    :rtype: `str`
    """
    if path is None:
        from art import DATA_PATH
        path_ = os.path.expanduser(DATA_PATH)
    else:
        path_ = os.path.expanduser(path)
    if not os.access(path_, os.W_OK):
        path_ = os.path.join('/tmp', '.art')
    if not os.path.exists(path_):
        os.makedirs(path_)

    if extract:
        extract_path = os.path.join(path_, filename)
        full_path = extract_path + '.tar.gz'
    else:
        full_path = os.path.join(path_, filename)

    # Determine if dataset needs downloading
    download = not os.path.exists(full_path)
github IBM / adversarial-robustness-toolbox / art / poison_detection / activation_defence.py View on Github external
def _unpickle_classifier(file_name):
        """
        Load a pickled classifier from `art.DATA_PATH`.

        NOTE: `pickle.load` can execute arbitrary code while loading; only use this on files from a trusted source.

        :param file_name: Name of the pickle file, relative to `art.DATA_PATH`.
        :return: The object restored from the pickle file.
        """
        import pickle
        from art import DATA_PATH

        pickle_path = os.path.join(DATA_PATH, file_name)
        logger.info('Loading classifier from %s', pickle_path)
        with open(pickle_path, 'rb') as pickle_file:
            return pickle.load(pickle_file)
github IBM / adversarial-robustness-toolbox / art / classifiers / keras.py View on Github external
def __setstate__(self, state):
        """
        Restore a `KerasClassifier` from its pickled state.

        The instance attributes are taken from `state`; the Keras model itself is reloaded from the file
        `state['model_name']` inside `DATA_PATH`.

        :param state: State dictionary with instance parameters to restore.
        :type state: `dict`
        """
        self.__dict__.update(state)

        # Everything Keras-specific is rebuilt from the model file stored on disk.
        import os
        from art import DATA_PATH
        from keras.models import load_model

        model_file = os.path.join(DATA_PATH, state['model_name'])
        restored_model = load_model(str(model_file))
        self._model = restored_model

        # Remaining arguments for `_initialize_params`, looked up in call order.
        init_keys = ('_use_logits', '_input_layer', '_output_layer', '_custom_activation')
        self._initialize_params(restored_model, *[state[key] for key in init_keys])
github IBM / adversarial-robustness-toolbox / art / utils.py View on Github external
content = cPickle.load(file_)
            else:
                content = cPickle.load(file_, encoding='bytes')
                content_decoded = {}
                for key, value in content.items():
                    content_decoded[key.decode('utf8')] = value
                content = content_decoded
        data = content['data']
        labels = content['labels']

        data = data.reshape(data.shape[0], 3, 32, 32)
        return data, labels

    from art import DATA_PATH

    path = get_file('cifar-10-batches-py', extract=True, path=DATA_PATH,
                    url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')

    num_train_samples = 50000

    x_train = np.zeros((num_train_samples, 3, 32, 32), dtype=np.uint8)
    y_train = np.zeros((num_train_samples,), dtype=np.uint8)

    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        data, labels = load_batch(fpath)
        x_train[(i - 1) * 10000: i * 10000, :, :, :] = data
        y_train[(i - 1) * 10000: i * 10000] = labels

    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)
    y_train = np.reshape(y_train, (len(y_train), 1))