How to use threadpoolctl - 10 common examples

To help you get started, we’ve selected a few threadpoolctl examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_threadpool_limits_manual_unregister():
    """threadpool_limits can be used as a plain object, not only as a context
    manager: it remembers the original thread-pool state, and calling its
    ``unregister`` method restores that state.
    """
    info_before = _threadpool_info()

    controller = threadpool_limits(limits=1)
    try:
        limited_info = _threadpool_info()
        checked = (m for m in limited_info if not is_old_openblas(m))
        for module in checked:
            assert module.num_threads == 1
    finally:
        # Always restore the original limits, even if an assertion above
        # fails, so this test has no side effect on the rest of the session.
        controller.unregister()

    assert _threadpool_info() == info_before
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_shipped_openblas():
    """Check that OpenBLAS effectively uses the number of threads requested by
    the ``threadpool_limits`` context manager, and that the original
    thread-pool state is restored when the context exits.
    """
    original_info = _threadpool_info()

    openblas_modules = original_info.get_modules("internal_api", "openblas")
    if not openblas_modules:
        # Without any loaded OpenBLAS the loop below would be empty and the
        # test would pass vacuously; report it as skipped instead.
        pytest.skip("Requires openblas runtime")

    with threadpool_limits(1):
        for module in openblas_modules:
            # Query the library live rather than reading a cached value.
            assert module.get_num_threads() == 1

    assert original_info == _threadpool_info()
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def _threadpool_info():
    """Return the private ``_ThreadpoolInfo`` object covering every user API.

    Unlike the public ``threadpool_info()``, which returns a list of dicts,
    this helper hands back the ``_ThreadpoolInfo`` object itself, so tests can
    call its methods (e.g. ``get_modules``) and compare snapshots with ``==``.
    """
    info = _ThreadpoolInfo(user_api=_ALL_USER_APIS)
    return info
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_command_line_command_flag():
    """Check the ``-c`` flag of the threadpoolctl command line interface.

    ``-m threadpoolctl -c "import numpy"`` is run in a child process whose
    stdout is parsed as JSON. Every library entry it reports must also be
    reported by ``threadpool_info()`` in this process, which has imported
    numpy too (via ``importorskip``).
    """
    import sys

    pytest.importorskip("numpy")
    # Use the interpreter running the tests rather than whatever "python"
    # resolves to on PATH: the latter may be missing, or may belong to another
    # environment with different BLAS/OpenMP libraries, which would make the
    # comparison below meaningless.
    output = subprocess.check_output(
        [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"])
    cli_info = json.loads(output.decode("utf-8"))

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_threadpool_limits_public_api():
    """Check consistency between ``threadpool_info`` and ``_ThreadpoolInfo``.

    The public list of dicts must match, entry by entry, the ``todict()``
    rendering of the modules held by the private info object.
    """
    public_info = threadpool_info()
    private_info = _threadpool_info()

    # zip() silently stops at the shorter sequence, so compare the lengths
    # first; otherwise a module present in only one of the two views would go
    # unnoticed.
    assert len(public_info) == len(private_info.modules)
    for module1, module2 in zip(public_info, private_info):
        assert module1 == module2.todict()
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_command_line_import_flag():
    """Check the ``-i`` flag of the threadpoolctl command line interface.

    The CLI is asked to import several modules, two of which do not exist.
    Every library entry it prints (as JSON on stdout) must also be reported by
    ``threadpool_info()`` in this process. A warning must appear on stderr for
    each module that could not be imported, and for ``scipy.linalg`` only when
    scipy is unavailable.
    """
    import sys

    # Use the interpreter running the tests rather than whatever "python"
    # resolves to on PATH, so the child process sees the same installed
    # packages (and the same BLAS/OpenMP libraries) as this process.
    result = subprocess.run([
        sys.executable, "-m", "threadpoolctl", "-i",
        "numpy",
        "scipy.linalg",
        "invalid_package",
        "numpy.invalid_sumodule",
    ], capture_output=True, check=True, encoding="utf-8")
    cli_info = json.loads(result.stdout)

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info

    warnings = [w.strip() for w in result.stderr.splitlines()]
    assert "WARNING: could not import invalid_package" in warnings
    assert "WARNING: could not import numpy.invalid_sumodule" in warnings
    # `scipy` is a module-level name, presumably None when scipy is not
    # installed — TODO confirm against the file's imports.
    if scipy is None:
        assert "WARNING: could not import scipy.linalg" in warnings
    else:
        assert "WARNING: could not import scipy.linalg" not in warnings
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_ThreadpoolInfo_todicts():
    """Check that the ``todict``/``todicts`` methods agree with the public API
    and that every dict carries the keys the public API promises.
    """
    info = _threadpool_info()

    # threadpool_info(), todicts() and per-module todict() must all agree,
    # whether the info object is iterated directly or through .modules.
    public_dicts = threadpool_info()
    assert public_dicts == [module.todict() for module in info.modules]
    assert info.todicts() == [module.todict() for module in info]
    assert info.todicts() == [module.todict() for module in info.modules]

    mandatory_keys = (
        "user_api",
        "internal_api",
        "prefix",
        "filepath",
        "version",
        "num_threads",
    )
    layered_apis = ("mkl", "blis", "openblas")

    for module in info:
        as_dict = module.todict()
        for key in mandatory_keys:
            assert key in as_dict

        # These libraries additionally expose their threading layer.
        if module.internal_api in layered_apis:
            assert "threading_layer" in as_dict
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_threadpool_limits_bad_input():
    """Check that invalid arguments to threadpool_limits raise clear errors."""
    # An unknown user_api is rejected with a ValueError that lists the valid
    # choices. The formatted list may contain regex metacharacters (such as
    # brackets), so the message is escaped before being used as a pattern.
    expected_value_msg = "user_api must be either in {} or None.".format(
        _ALL_USER_APIS)
    with pytest.raises(ValueError, match=re.escape(expected_value_msg)):
        threadpool_limits(limits=1, user_api="wrong")

    # A tuple is not one of the accepted types for `limits`.
    expected_type_msg = "limits must either be an int, a list or a dict"
    with pytest.raises(TypeError, match=expected_type_msg):
        threadpool_limits(limits=(1, 2, 3))
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
def test_set_threadpool_limits_by_api(user_api, limit):
    """Check that the maximum number of threads can be set by user_api.

    ``user_api`` and ``limit`` are expected to be supplied by pytest
    parametrization that is not visible in this snippet — TODO confirm.
    The test is skipped when no loaded library matches ``user_api``.
    """
    info_before = _threadpool_info()

    targeted = info_before.get_modules("user_api", user_api)
    if not targeted:
        # A None user_api stands for every known API in the skip message.
        if user_api is None:
            required_apis = _ALL_USER_APIS
        else:
            required_apis = [user_api]
        pytest.skip(
            "Requires a library which api is in {}".format(required_apis))

    with threadpool_limits(limits=limit, user_api=user_api):
        for module in targeted:
            if not is_old_openblas(module):
                # threadpool_limits only sets an upper bound on the number
                # of threads, so the effective value may be lower.
                assert 0 < module.get_num_threads() <= limit

    assert _threadpool_info() == info_before
GitHub: joblib / threadpoolctl / tests / test_threadpoolctl.py (View on GitHub)
@pytest.mark.parametrize("prefix", _ALL_PREFIXES)
@pytest.mark.parametrize("limit", [1, 3])
def test_threadpool_limits_by_prefix(prefix, limit):
    """Check that the maximum number of threads can be set per library prefix.

    For every prefix in ``_ALL_PREFIXES`` and every limit in ``[1, 3]``, the
    limit is passed to ``threadpool_limits`` as a ``{prefix: limit}`` dict.
    Each loaded module with that prefix must then report between 1 and
    ``limit`` threads. The test is skipped when no loaded module has the
    given prefix.
    """
    # Check that the maximum number of threads can be set by prefix
    original_info = _threadpool_info()

    modules_matching_prefix = original_info.get_modules("prefix", prefix)
    if not modules_matching_prefix:
        pytest.skip("Requires {} runtime".format(prefix))

    with threadpool_limits(limits={prefix: limit}):
        for module in modules_matching_prefix:
            # Modules flagged by is_old_openblas are not checked; presumably
            # old OpenBLAS builds do not report the limit reliably — TODO
            # confirm against is_old_openblas.
            if is_old_openblas(module):
                continue
            # threadpool_limits only sets an upper bound on the number of
            # threads.
            assert 0 < module.get_num_threads() <= limit