How to use the braindecode.datautil.signal_target.SignalAndTarget class in braindecode

To help you get started, we’ve selected a few braindecode examples that show popular ways SignalAndTarget is used in public projects.
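SignalAndTarget is a thin container that pairs input signals X with labels y, as the examples below show. A minimal sketch of constructing one from numpy arrays (shapes are illustrative, not taken from any of the projects below):

    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget

    # Illustrative shapes: 100 trials, 22 channels, 500 time samples per trial
    X = np.random.randn(100, 22, 500).astype(np.float32)
    y = np.random.randint(0, 2, size=100).astype(np.int64)

    dataset = SignalAndTarget(X, y=y)
    print(dataset.X.shape, dataset.y.shape)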


github TNTLFreiburg / braindecode / test / acceptance_tests / from_notebooks / test_cropped_decoding.py (View on Github)
    # Pick only the EEG channels (head of the call assumed; this follows
    # the standard MNE usage in the notebook this test was taken from)
    eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True,
                                      stim=False, eog=False,
                                      exclude='bads')

    # Extract trials, only using EEG channels
    epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                         proj=False, picks=eeg_channel_inds,
                         baseline=None, preload=True)
    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    # Convert data from volt to microvolt (multiply by 1e6)
    # Pytorch expects float32 for input and int64 for labels.
    X = (epoched.get_data() * 1e6).astype(np.float32)
    y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

    train_set = SignalAndTarget(X[:60], y=y[:60])
    test_set = SignalAndTarget(X[60:], y=y[60:])
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from torch import nn
    from braindecode.torch_ext.util import set_random_seeds
    from braindecode.models.util import to_dense_prediction_model

    # Set if you want to use GPU
    # You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
    cuda = False
    set_random_seeds(seed=20170629, cuda=cuda)

    # This will determine how many crops are processed in parallel
    input_time_length = 450
    n_classes = 2
    in_chans = train_set.X.shape[1]
    # final_conv_length determines the size of the receptive field of the ConvNet
    model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=12).create_network()
    # Assumed continuation of the truncated snippet: convert the model so it
    # outputs dense per-timestep predictions for cropped decoding
    to_dense_prediction_model(model)
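For cropped decoding, the notebook then determines how many predictions the dense-prediction model produces per input window via a dummy forward pass. A sketch continuing from the snippet above (np_to_var is from braindecode.torch_ext.util; the batch size of 2 is arbitrary):

    from braindecode.torch_ext.util import np_to_var

    # Dummy forward pass to find out how many predictions the ConvNet
    # produces per window of input_time_length samples
    test_input = np_to_var(
        np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[2]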
github TNTLFreiburg / braindecode / braindecode / datautil / splitters.py (View on Github)
def select_examples(dataset, indices):
    """
    Select examples from dataset.

    Parameters
    ----------
    dataset: :class:`.SignalAndTarget`
    indices: ndarray of int
        Indices of the examples to select.

    Returns
    -------
    reduced_set: :class:`.SignalAndTarget`
        Dataset with only examples selected.
    """
    # probably not necessary
    indices = np.array(indices)
    if hasattr(dataset.X, "ndim"):
        # numpy array
        new_X = np.array(dataset.X)[indices]
    else:
        # list
        new_X = [dataset.X[i] for i in indices]
    new_y = np.asarray(dataset.y)[indices]
    return SignalAndTarget(new_X, new_y)
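A short usage sketch for the helper above (random data for illustration), keeping every second trial of a dataset:

    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    from braindecode.datautil.splitters import select_examples

    X = np.random.randn(20, 3, 100).astype(np.float32)
    y = np.random.randint(0, 2, size=20).astype(np.int64)
    dataset = SignalAndTarget(X, y)

    # Keep only every second trial
    even_set = select_examples(dataset, np.arange(0, 20, 2))
    print(even_set.X.shape)  # (10, 3, 100)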
github TNTLFreiburg / braindecode / braindecode / datautil / splitters.py (View on Github)
def concatenate_two_sets(set_a, set_b):
    """
    Concatenate two sets together.
    
    Parameters
    ----------
    set_a, set_b: :class:`.SignalAndTarget`

    Returns
    -------
    concatenated_set: :class:`.SignalAndTarget`
    """
    new_X = concatenate_np_array_or_add_lists(set_a.X, set_b.X)
    new_y = concatenate_np_array_or_add_lists(set_a.y, set_b.y)
    return SignalAndTarget(new_X, new_y)
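For example, merging two recordings into a single training set (random data for illustration; numpy inputs are concatenated along the trial axis):

    import numpy as np
    from braindecode.datautil.signal_target import SignalAndTarget
    from braindecode.datautil.splitters import concatenate_two_sets

    set_a = SignalAndTarget(np.zeros((5, 3, 100), dtype=np.float32),
                            np.zeros(5, dtype=np.int64))
    set_b = SignalAndTarget(np.ones((7, 3, 100), dtype=np.float32),
                            np.ones(7, dtype=np.int64))

    merged = concatenate_two_sets(set_a, set_b)
    print(len(merged.X), len(merged.y))  # 12 12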
github TNTLFreiburg / braindecode / braindecode / datautil / trial_segment.py (View on Github)
"lowest class"
                    )
            else:
                if np.max(np.sum(this_y, axis=1)) > 1:
                    log.warning(
                        "Have multiple active classes and will convert to "
                        "lowest class"
                    )
                this_new_y = np.argmax(this_y, axis=1)
                this_new_y[np.sum(this_y, axis=1) == 0] = -1
            new_y.append(this_new_y)
        y = new_y
    if one_label_per_trial:
        y = np.array(y, dtype=np.int64)

    return SignalAndTarget(X, y)
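The conversion above turns one-hot label rows into integer class labels: np.argmax picks the first (lowest) active class index, and rows with no active class are marked -1, which is why the code warns that multiple active classes are converted to the lowest class. A standalone sketch of that logic:

    import numpy as np

    # Rows are timesteps within a trial, columns are classes
    this_y = np.array([[0, 1, 0],
                       [1, 1, 0],   # two active classes -> lowest index wins
                       [0, 0, 0]])  # no active class -> -1

    this_new_y = np.argmax(this_y, axis=1)
    this_new_y[np.sum(this_y, axis=1) == 0] = -1
    print(this_new_y)  # [ 1  0 -1]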
github TNTLFreiburg / braindecode / braindecode / models / base.py (View on Github)
label 0 or 1, e.g. 0.5.
        individual_crops: bool

        Returns
        -------
            outs_per_trial: 2darray or list of 2darrays
                Network outputs for each trial, optionally for each crop within trial.
        """
        if individual_crops:
            assert self.cropped, "Cropped labels only for cropped decoding"
        X = _ensure_float32(X)
        all_preds = []
        with th.no_grad():
            dummy_y = np.ones(len(X), dtype=np.int64)
            for b_X, _ in self.iterator.get_batches(
                SignalAndTarget(X, dummy_y), False
            ):
                b_X_var = np_to_var(b_X)
                if self.cuda:
                    b_X_var = b_X_var.cuda()
                all_preds.append(var_to_np(self.network(b_X_var)))
        if self.cropped:
            outs_per_trial = compute_preds_per_trial_from_crops(
                all_preds, self.iterator.input_time_length, X
            )
            if not individual_crops:
                outs_per_trial = np.array(
                    [np.mean(o, axis=1) for o in outs_per_trial]
                )
        else:
            outs_per_trial = np.concatenate(all_preds)
        return outs_per_trial
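When individual_crops is False, the per-crop network outputs are averaged over the crop axis so that each trial ends up with a single output vector. A sketch of that reduction with illustrative shapes:

    import numpy as np

    # One entry per trial: outputs of shape (n_classes, n_crop_predictions)
    outs_per_trial = [np.random.randn(2, 4), np.random.randn(2, 4)]

    # Average over crop predictions -> one (n_classes,) vector per trial
    trial_outs = np.array([np.mean(o, axis=1) for o in outs_per_trial])
    print(trial_outs.shape)  # (2, 2)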
github TNTLFreiburg / braindecode / braindecode / models / base.py (View on Github)
if optimizer.__class__.__name__ == "AdamW":
                schedule_weight_decay = True
            optimizer = ScheduledOptimizer(
                scheduler,
                self.optimizer,
                schedule_weight_decay=schedule_weight_decay,
            )
        loss_function = self.loss
        if self.cropped:
            loss_function = lambda outputs, targets: self.loss(
                th.mean(outputs, dim=2), targets
            )
        if validation_data is not None:
            valid_X = _ensure_float32(validation_data[0])
            valid_y = validation_data[1]
            valid_set = SignalAndTarget(valid_X, valid_y)
        else:
            valid_set = None
        test_set = None
        self.monitors = [LossMonitor()]
        if self.cropped:
            self.monitors.append(CroppedTrialMisclassMonitor(input_time_length))
        else:
            self.monitors.append(MisclassMonitor())
        if self.extra_monitors is not None:
            self.monitors.extend(self.extra_monitors)
        self.monitors.append(RuntimeMonitor())
        exp = Experiment(
            self.network,
            train_set,
            valid_set,
            test_set,
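At the user level, this plumbing is driven by the model wrapper's compile/fit API, where validation_data is wrapped into the SignalAndTarget shown above. A hedged sketch following the braindecode 0.x Keras-style API (data shapes and hyperparameters are illustrative):

    import numpy as np
    import torch.nn.functional as F
    from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
    from braindecode.torch_ext.optimizers import AdamW

    # Illustrative data: (n_trials, n_channels, n_timesteps)
    train_X = np.random.randn(60, 22, 500).astype(np.float32)
    train_y = np.random.randint(0, 2, size=60).astype(np.int64)
    valid_X = np.random.randn(20, 22, 500).astype(np.float32)
    valid_y = np.random.randint(0, 2, size=20).astype(np.int64)

    model = ShallowFBCSPNet(in_chans=22, n_classes=2,
                            input_time_length=500,
                            final_conv_length='auto')
    optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
    model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1)
    model.fit(train_X, train_y, epochs=5, batch_size=16,
              validation_data=(valid_X, valid_y))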