Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _threadpool_info():
    """Build a ``_ThreadpoolInfo`` for every API listed in ``_ALL_USER_APIS``.

    This mirrors ``threadpool_info`` but hands back the info object itself
    rather than its list-of-dicts form, so callers can use methods such as
    ``get_modules`` on it.
    """
    info = _ThreadpoolInfo(user_api=_ALL_USER_APIS)
    return info
def test_threadpool_limits_bad_input():
    """``threadpool_limits`` must reject invalid arguments with clear errors.

    * An unknown ``user_api`` string raises ``ValueError`` whose message
      lists the accepted APIs (escaped, since it is used as a regex).
    * A ``limits`` value of an unsupported type (a tuple here) raises
      ``TypeError``.
    """
    bad_api_msg = "user_api must be either in {} or None.".format(
        _ALL_USER_APIS
    )
    bad_api_pattern = re.escape(bad_api_msg)
    with pytest.raises(ValueError, match=bad_api_pattern):
        threadpool_limits(limits=1, user_api="wrong")

    # No regex metacharacters here, so the literal text is a safe pattern.
    bad_limits_pattern = "limits must either be an int, a list or a dict"
    with pytest.raises(TypeError, match=bad_limits_pattern):
        threadpool_limits(limits=(1, 2, 3))
def test_set_threadpool_limits_by_api(user_api, limit):
    """Capping threads for a ``user_api`` bounds every module using that API.

    The test is skipped when no detected module matches ``user_api``. Inside
    the ``threadpool_limits`` context, each matching module (except those
    for which ``is_old_openblas`` is true) must report between 1 and
    ``limit`` threads. After the context exits, the reported threadpool info
    must compare equal to the info captured beforehand.
    """
    before = _threadpool_info()

    matching = before.get_modules("user_api", user_api)
    if not matching:
        requested_apis = [user_api] if user_api is not None else _ALL_USER_APIS
        skip_msg = "Requires a library which api is in {}".format(
            requested_apis
        )
        pytest.skip(skip_msg)

    with threadpool_limits(limits=limit, user_api=user_api):
        # Old OpenBLAS builds are excluded from the thread-count check.
        checked = (mod for mod in matching if not is_old_openblas(mod))
        for mod in checked:
            # threadpool_limits only caps the thread count, so any value in
            # the range [1, limit] is acceptable.
            assert 0 < mod.get_num_threads() <= limit

    assert _threadpool_info() == before