Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_command_line_command_flag():
    """Check the ``-c`` flag of the ``python -m threadpoolctl`` command line.

    A child interpreter runs ``import numpy`` through the CLI and prints its
    findings as JSON. Every entry reported by the child must also be present
    in ``threadpool_info()`` for the current process.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import sys

    pytest.importorskip("numpy")

    # Launch the very interpreter that runs this test. A bare "python" is
    # resolved through PATH and may be a different installation (or missing),
    # which would make the comparison with this process meaningless.
    output = subprocess.check_output(
        [sys.executable, "-m", "threadpoolctl", "-c", "import numpy"])
    cli_info = json.loads(output.decode("utf-8"))

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info
def test_threadpool_limits_public_api():
    """Check consistency between threadpool_info and _ThreadpoolInfo.

    Each dict returned by the public ``threadpool_info()`` must equal the
    ``todict()`` of the entry at the same position in the object returned by
    ``_threadpool_info()``.
    """
    public_info = threadpool_info()
    private_modules = list(_threadpool_info())

    # zip() stops at the shorter input, so without this check a missing or
    # extra entry on either side would go unnoticed.
    assert len(public_info) == len(private_modules)

    for module1, module2 in zip(public_info, private_modules):
        assert module1 == module2.todict()
def test_command_line_import_flag():
    """Check the ``-i`` flag of the ``python -m threadpoolctl`` command line.

    The CLI is asked to import a mix of valid and invalid module names. The
    JSON written to stdout must only contain entries that are also reported
    by ``threadpool_info()`` in this process, and stderr must contain one
    "could not import" warning per module that failed to import.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import sys

    # Launch the very interpreter that runs this test. A bare "python" is
    # resolved through PATH and may be a different installation (or missing),
    # which would make the comparison with this process meaningless.
    result = subprocess.run([
        sys.executable, "-m", "threadpoolctl", "-i",
        "numpy",
        "scipy.linalg",
        "invalid_package",
        "numpy.invalid_sumodule",
    ], capture_output=True, check=True, encoding="utf-8")
    cli_info = json.loads(result.stdout)

    this_process_info = threadpool_info()
    for module in cli_info:
        assert module in this_process_info

    warnings = [w.strip() for w in result.stderr.splitlines()]
    assert "WARNING: could not import invalid_package" in warnings
    assert "WARNING: could not import numpy.invalid_sumodule" in warnings

    # NOTE(review): `scipy` is a module-level name defined outside this
    # block; presumably it is None when scipy is not installed - confirm.
    if scipy is None:
        assert "WARNING: could not import scipy.linalg" in warnings
    else:
        assert "WARNING: could not import scipy.linalg" not in warnings
def test_ThreadpoolInfo_todicts():
    """Verify the dicts produced by the ``.todict(s)`` methods.

    The public ``threadpool_info()`` output, ``info.todicts()`` and the
    per-module ``todict()`` results must all agree, and every per-module
    dict must expose the keys expected by the public API.
    """
    info = _threadpool_info()

    public_dicts = threadpool_info()
    dicts_from_modules = [entry.todict() for entry in info.modules]
    assert public_dicts == dicts_from_modules
    assert info.todicts() == [entry.todict() for entry in info]
    assert info.todicts() == dicts_from_modules

    required_keys = (
        "user_api",
        "internal_api",
        "prefix",
        "filepath",
        "version",
        "num_threads",
    )
    for entry in info:
        as_dict = entry.todict()
        for key in required_keys:
            assert key in as_dict

        # Only these internal APIs are required to report a threading layer.
        if entry.internal_api not in ("mkl", "blis", "openblas"):
            continue
        assert "threading_layer" in as_dict
from pprint import pprint
from statistics import mean, stdev
from threadpoolctl import threadpool_info, threadpool_limits
# Benchmark script: measure the cost of entering and leaving a
# threadpool_limits(limits=1) context.
#
# Command-line options:
#   --import PKG [PKG ...]  packages to import before measuring
#   --n-calls N             number of timed enter/exit cycles (default 100)
parser = ArgumentParser(description='Measure threadpool_limits call overhead.')
parser.add_argument('--import', dest="packages", default=[], nargs='+',
                    help='Python packages to import to load threadpool enabled'
                         ' libraries.')
parser.add_argument("--n-calls", type=int, default=100,
                    help="Number of iterations")
args = parser.parse_args()

# With no sample at all mean() would raise after the run; reject it up front
# with a regular usage error instead of a traceback.
if args.n_calls < 1:
    parser.error("--n-calls must be at least 1")

# Import the requested packages before measuring and show what
# threadpool_info() reports at that point.
for package_name in args.packages:
    __import__(package_name)

pprint(threadpool_info())

timings = []
for _ in range(args.n_calls):
    # perf_counter() is the high-resolution clock meant for measuring short
    # intervals; time.time() can be too coarse and may jump when the system
    # clock is adjusted.
    t = time.perf_counter()
    with threadpool_limits(limits=1):
        pass
    timings.append(time.perf_counter() - t)

# stdev() needs at least two samples; report the spread as "nan" for a
# single call rather than crashing once the measurement is done.
spread = stdev(timings) if len(timings) > 1 else float("nan")
print("Overhead per call: {:.3f} +/-{:.3f} ms"
      .format(mean(timings) * 1e3, spread * 1e3))