How to use the nnabla.set_default_context function

To help you get started, we’ve selected a few examples showing popular ways nnabla.set_default_context is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: sony/nnabla-examples — GANs/pggan/validate.py (view on GitHub)
def main():
    # NOTE(review): excerpt only; the rest of main() is not shown on this page.

    # Args
    args = get_args()

    # Context: build the extension context (backend, device id, type config)
    # from the CLI args, log it, and install it as nnabla's default context.
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    logger.info(ctx)
    nn.set_default_context(ctx)
    # Auto-forward mode: forward propagation is invoked immediately as graph
    # functions are called, rather than via an explicit .forward().
    nn.set_auto_forward(True)

    # Monitor
    monitor = Monitor(args.monitor_path)

    # Validation
    logger.info("Start validation")

    num_images = args.valid_samples
    # Floor division: trailing samples that do not fill a whole batch are not
    # counted in num_batches.
    num_batches = num_images // args.batch_size

    # DataIterator over args.img_path, requesting square (imsize, imsize)
    # images (presumably resized by the iterator — confirm).
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.imsize, args.imsize),
                       num_samples=args.valid_samples,
                       dataset_name=args.dataset_name)
Source: sony/nnabla-examples — semantic-segmentation/deeplabv3plus/train.py (view on GitHub)
# NOTE(review): excerpt begins mid-function; the enclosing def is not shown.
distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        # Multi-process data-parallel run: the backend is hard-coded to
        # "cudnn", and each process binds to the device whose id equals its
        # communicator rank.
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        # The rank is written into the context's device_id as a string.
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        # Single process: use args.context, falling back to 'cpu' when it is
        # not set.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        # NOTE(review): device_id is 0 here even though the context above was
        # built with args.device_id; it presumably serves as a process rank in
        # the single-process case — confirm against later uses.
        device_id = 0

    # training data
    data = data_iterator_segmentation(
            args.train_samples, args.batch_size, args.train_dir, args.train_label_dir, target_width=args.image_width, target_height=args.image_height)
Source: sony/nnabla-examples — meta-learning/metric_based_meta_learning.py (view on GitHub)
# NOTE(review): excerpt begins mid-function; the enclosing def is not shown.
max_iteration = args.max_iteration
    lr_decay_interval = args.lr_decay_interval
    lr_decay = args.lr_decay
    iter_per_epoch = args.iter_per_epoch
    iter_per_valid = args.iter_per_valid
    n_episode_for_valid = args.n_episode_for_valid
    n_episode_for_test = args.n_episode_for_test
    work_dir = args.work_dir

    # Set context
    # args.context is passed through as-is; there is no fallback when it is
    # None.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Monitor outputs
    # interval= sets how often each MonitorSeries flushes its output: the
    # training loss uses iter_per_epoch, the validation error uses
    # iter_per_valid, and the two test series use the default.
    from nnabla.monitor import Monitor, MonitorSeries
    monitor = Monitor(args.work_dir)
    monitor_loss = MonitorSeries(
        "Training loss", monitor, interval=iter_per_epoch)
    monitor_valid_err = MonitorSeries(
        "Validation error", monitor, interval=iter_per_valid)
    monitor_test_err = MonitorSeries("Test error", monitor)
    monitor_test_conf = MonitorSeries("Test error confidence", monitor)

    # Output files
    # NOTE(review): plain string concatenation — work_dir must end with a path
    # separator, otherwise e.g. work_dir="out" yields "outparams.h5".
    param_file = work_dir + "params.h5"
    tsne_file = work_dir + "tsne.png"

    # Load data
Source: sony/nnabla — examples/cpp/forward_check/mnist/classification.py (view on GitHub)
* Execute forwardprop on the training graph.
      * Compute training error
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """
    args = get_args()

    # Get context.
    # NOTE(review): this uses nnabla.contrib.context.extension_context, which
    # appears to be an older API than nnabla.ext_utils.get_extension_context;
    # no type_config is passed here. Fall back to 'cpu' when args.context is
    # not set.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    # LeNet by default; args.net == 'resnet' selects the ResNet variant.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # TRAIN
    # Create input variables.
    # image: (batch_size, 1, 28, 28); label: (batch_size, 1).
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    pred = mnist_cnn_prediction(image, test=False)
    # Persistent: keep pred's buffer from being cleared during
    # forward/backward, presumably so training error can be computed from it.
    pred.persistent = True
    # Create loss function.
    # Mean softmax cross-entropy over the batch.
    loss = F.mean(F.softmax_cross_entropy(pred, label))
Source: sony/nnabla-examples — GANs/sagan/generate.py (view on GitHub)
def generate(args):
    # NOTE(review): excerpt only; the rest of generate() is not shown.
    # Communicator and Context
    # NOTE(review): no communicator is created in the lines shown. The backend
    # is hard-coded to "cudnn" and no device_id is passed, so
    # get_extension_context's default device applies.
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model
    # Load parameters from args.model_load_path before building the graph.
    nn.load_parameters(args.model_load_path)
    # Latent input (batch_size, latent) and one entry per sample in y_fake
    # (presumably class labels in [0, n_classes) — confirm).
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    # NOTE(review): a variable named not_sn is passed as sn=; confirm the
    # polarity of this flag is intended.
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
        .apply(persistent=True)
Source: sony/nnabla-examples — reduction/cifar10/structured-sparsity/classification.py (view on GitHub)
def train():
    # NOTE(review): excerpt only; the rest of train() is not shown.
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    # NOTE(review): only "cifar10_resnet23_prediction" is handled; any other
    # args.net leaves model_prediction unbound.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    # CIFAR-10 image geometry (channels, height, width) and split sizes.
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    # image: (batch_size, c, h, w); label: (batch_size, 1).
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
Source: sony/nnabla-examples — semantic-segmentation/deeplabv3plus/eval.py (view on GitHub)
def main():
    # NOTE(review): this excerpt may be truncated after the validate() call.
    args = get_args()
    # Fixed-seed RNG; it is not used in the lines shown.
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Run validation under the default context set above. The name suggests
    # the result is mean IoU — confirm against validate().
    miou = validate(args)
Source: sony/nnabla-examples — mnist-collection/classification_bnn.py (view on GitHub)
* Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Set parameter gradients zero
      * Execute forwardprop on the training graph.
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # monitor_path is overridden for this BNN script (presumably replacing
    # get_args' default).
    args = get_args(monitor_path='tmp.monitor.bnn')

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Initialize DataIterator for MNIST.
    # The boolean presumably selects the training split (data) vs. the
    # validation split (vdata) — confirm against data_iterator_mnist.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    # Create CNN network for both training and testing.
    # Select the binary network variant by args.net. The default is the same
    # function that 'bincon' selects. NOTE(review): this excerpt is cut off at
    # the 'binnet_resnet' branch.
    mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
    if args.net == 'bincon':
        mnist_cnn_prediction = mnist_binary_connect_lenet_prediction
    elif args.net == 'binnet':
        mnist_cnn_prediction = mnist_binary_net_lenet_prediction
    elif args.net == 'bwn':
        mnist_cnn_prediction = mnist_binary_weight_lenet_prediction
    elif args.net == 'bincon_resnet':
        mnist_cnn_prediction = mnist_binary_connect_resnet_prediction
    elif args.net == 'binnet_resnet':
Source: sony/nnabla-examples — GANs/sagan/morph.py (view on GitHub)
def morph(args):
    # NOTE(review): excerpt only; the rest of morph() is not shown.
    # Communicator and Context
    # NOTE(review): no communicator is created in the lines shown. The backend
    # is hard-coded to "cudnn" and no device_id is passed, so
    # get_extension_context's default device applies.
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model
    # Load parameters from args.model_load_path before building the graph.
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    # (1, 1) weights: alpha starts at 0, and beta is built from alpha with
    # Variable arithmetic as 1 - alpha. They are presumably the blend factors
    # for morphing between two conditions — confirm in the rest of morph().
    alpha = nn.Variable.from_numpy_array(np.zeros([1, 1]))
    beta = (nn.Variable.from_numpy_array(np.ones([1, 1])) - alpha)
    y_fake_a = nn.Variable([batch_size])