How to use the ludwig.utils.data_utils.split_dataset_tvt function in ludwig

To help you get started, we’ve selected a few Ludwig examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github uber / ludwig / tests / integration_tests / test_visualization_api.py View on Github external
def obtain_df_splits(data_csv):
    """Read a data CSV file and break it into test, train and validation
    dataframes.

    :param data_csv: Input data CSV file.
    :return test_df, train_df, val_df: Test, train and validation dataframe
            splits, returned in that order
    """
    full_df = read_csv(data_csv)
    # Per-row split assignment: 0-train, 1-validation, 2-test
    split_assignment = get_split(full_df)
    # split_dataset_tvt hands back the splits as (train, test, validation)
    train_part, test_part, val_part = split_dataset_tvt(
        full_df, split_assignment
    )
    # The splits come back as python dictionaries rather than dataframes,
    # so each one is wrapped in a DataFrame (test first, then train, then
    # validation) before being returned.
    return tuple(
        pd.DataFrame(part) for part in (test_part, train_part, val_part)
    )
github uber / ludwig / ludwig / data / preprocessing.py View on Github external
else:
            dataset[output_feature['name']] = hdf5_data[
                output_feature['name']][()]
        if 'limit' in output_feature:
            dataset[output_feature['name']] = collapse_rare_labels(
                dataset[output_feature['name']],
                output_feature['limit']
            )

    if not split_data:
        hdf5_data.close()
        return dataset

    split = hdf5_data['split'][()]
    hdf5_data.close()
    training_set, test_set, validation_set = split_dataset_tvt(dataset, split)

    # shuffle up
    if shuffle_training:
        training_set = data_utils.shuffle_dict_unison_inplace(training_set)

    return training_set, test_set, validation_set
github uber / ludwig / ludwig / data / preprocessing.py View on Github external
)
        if not skip_save_processed_input:
            logger.info('Writing dataset')
            data_hdf5_fp = replace_file_extension(data_csv, 'hdf5')
            data_utils.save_hdf5(data_hdf5_fp, data, train_set_metadata)
            train_set_metadata[DATA_TRAIN_HDF5_FP] = data_hdf5_fp
            logger.info('Writing train set metadata with vocabulary')

            train_set_metadata_json_fp = replace_file_extension(
                data_csv,
                'json'
            )
            data_utils.save_json(
                train_set_metadata_json_fp, train_set_metadata)

        training_set, test_set, validation_set = split_dataset_tvt(
            data,
            data['split']
        )

    elif data_train_csv is not None:
        # use data_train (including _validation and _test if they are present)
        # and ignore data and train set metadata
        # needs preprocessing
        logger.info(
            'Using training raw csv, no hdf5 and json '
            'file with the same name have been found'
        )
        logger.info('Building dataset (it may take a while)')
        concatenated_df = concatenate_csv(
            data_train_csv,
            data_validation_csv,
github uber / ludwig / ludwig / data / preprocessing.py View on Github external
)
        logger.info('Building dataset (it may take a while)')
        concatenated_df = concatenate_csv(
            data_train_csv,
            data_validation_csv,
            data_test_csv
        )
        concatenated_df.csv = data_train_csv
        data, train_set_metadata = build_dataset_df(
            concatenated_df,
            features,
            preprocessing_params,
            train_set_metadata=train_set_metadata,
            random_seed=random_seed
        )
        training_set, test_set, validation_set = split_dataset_tvt(
            data,
            data['split']
        )
        if not skip_save_processed_input:
            logger.info('Writing dataset')
            data_train_hdf5_fp = replace_file_extension(data_train_csv, 'hdf5')
            data_utils.save_hdf5(
                data_train_hdf5_fp,
                training_set,
                train_set_metadata
            )
            train_set_metadata[DATA_TRAIN_HDF5_FP] = data_train_hdf5_fp
            if validation_set is not None:
                data_validation_hdf5_fp = replace_file_extension(
                    data_validation_csv,
                    'hdf5'