How to use the lightfm.datasets._common.get_data function in lightfm

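To help you get started, we’ve selected a few lightfm examples, based on popular ways it is used in public projects. Both excerpts below come from lightfm's own dataset loaders, which call get_data to fetch a dataset file and store it locally.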


From lyst/lightfm, file lightfm/datasets/movielens.py:
```python
    test: sp.coo_matrix of shape [n_users, n_items]
         Contains testing set interactions.
    item_features: sp.csr_matrix of shape [n_items, n_item_features]
         Contains item features.
    item_feature_labels: np.array of strings of shape [n_item_features,]
         Labels of item features.
    item_labels: np.array of strings of shape [n_items,]
         Items' titles.
    """

    if not (indicator_features or genre_features):
        raise ValueError(
            "At least one of item_indicator_features " "or genre_features must be True"
        )

    zip_path = _common.get_data(
        data_home,  # local directory for downloaded datasets
        (
            "https://github.com/maciejkula/"
            "lightfm_datasets/releases/"
            "download/v0.1.0/movielens.zip"
        ),  # remote source URL
        "movielens100k",  # subdirectory under data_home
        "movielens.zip",  # file name inside that subdirectory
        download_if_missing,  # fetch the file if it is not present locally
    )

    # Load raw data
    try:
        (train_raw, test_raw, item_metadata_raw, genres_raw) = _read_raw_data(zip_path)
    except zipfile.BadZipFile:
        # Download was corrupted, get rid of the partially
        ...
```
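The excerpts show only the call site, so here is a minimal sketch of what a helper with the same five-argument call shape could do: build the path data_home/dest_subdir/dest_filename, download the file once if it is missing and download_if_missing is true, and raise otherwise. It also mirrors the BadZipFile branch above by deleting a corrupted file so the next call re-downloads it. This is a hypothetical stand-in, not lightfm's actual get_data; the `~/lightfm_data` default directory and the `read_zip_names` helper are assumptions made for illustration.

```python
import os
import urllib.request
import zipfile


def get_data(data_home, url, dest_subdir, dest_filename, download_if_missing):
    """Return a local path to dest_filename, downloading it on first use.

    Hypothetical sketch with the same argument order as the calls above;
    not lightfm's implementation.
    """
    if data_home is None:
        data_home = os.path.join("~", "lightfm_data")  # assumed default, hypothetical
    data_home = os.path.expanduser(data_home)
    dest_dir = os.path.join(data_home, dest_subdir)
    dest_path = os.path.join(dest_dir, dest_filename)

    if not os.path.isfile(dest_path):
        if not download_if_missing:
            raise IOError("Dataset missing: {}".format(dest_path))
        os.makedirs(dest_dir, exist_ok=True)
        # urlretrieve copies http(s):// and file:// URLs to dest_path.
        urllib.request.urlretrieve(url, dest_path)
    return dest_path


def read_zip_names(zip_path):
    """Hypothetical helper: list archive members, discarding a corrupted file."""
    try:
        with zipfile.ZipFile(zip_path) as archive:
            return archive.namelist()
    except zipfile.BadZipFile:
        # A truncated download would otherwise be reused on every call,
        # so remove it and let the next get_data call fetch a fresh copy.
        os.unlink(zip_path)
        raise ValueError("Corrupted download removed: {}".format(zip_path))
```

Because the destination path is derived only from the arguments, every call after the first returns the cached file without touching the network.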
From lyst/lightfm, file lightfm/datasets/stackexchange.py:
```python
    if not (0.0 < test_set_fraction < 1.0):
        raise ValueError("Test set fraction must be between 0 and 1")

    urls = {
        "crossvalidated": (
            "https://github.com/maciejkula/lightfm_datasets/releases/"
            "download/v0.1.0/stackexchange_crossvalidated.npz"
        ),
        "stackoverflow": (
            "https://github.com/maciejkula/lightfm_datasets/releases/"
            "download/v0.1.0/stackexchange_stackoverflow.npz"
        ),
    }

    path = _common.get_data(
        data_home,
        urls[dataset],  # download URL for the requested site
        os.path.join("stackexchange", dataset),  # per-dataset subdirectory
        "data.npz",  # local file name
        download_if_missing,
    )

    data = np.load(path)

    interactions = sp.coo_matrix(
        (
            data["interactions_data"],
            (data["interactions_row"], data["interactions_col"]),
        ),
        shape=data["interactions_shape"].flatten(),
    )
```
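The stackexchange file stores its interaction matrix as plain arrays inside an `.npz` archive, and the loader rebuilds a COO matrix from them. The sketch below shows both directions, assuming only the four key names visible in the excerpt. `save_interactions` is a hypothetical writer, since the excerpt shows only the reading side.

```python
import numpy as np
import scipy.sparse as sp


def save_interactions(path, interactions):
    """Hypothetical writer: store a sparse matrix under the keys read above."""
    coo = sp.coo_matrix(interactions)
    np.savez(
        path,
        interactions_data=coo.data,
        interactions_row=coo.row,
        interactions_col=coo.col,
        interactions_shape=np.array(coo.shape),
    )


def load_interactions(path):
    """Rebuild the COO matrix the same way the stackexchange excerpt does."""
    with np.load(path) as data:
        return sp.coo_matrix(
            (
                data["interactions_data"],
                (data["interactions_row"], data["interactions_col"]),
            ),
            shape=data["interactions_shape"].flatten(),
        )
```

Storing the data, row, and column arrays separately keeps the file readable with `np.load` alone, without pickling a SciPy object.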