Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
itertools.starmap(
lambda x, y: f"{x:.4f}Β±{y:.4f}",
zip(
np.mean(results, axis=0).tolist(),
np.std(results, axis=0).tolist(),
),
)
)
)
return tab_data
if __name__ == "__main__":
    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)
    print(args)
    # Validate explicitly: `assert` statements are removed when Python runs
    # with -O, which would silently let a multi-device configuration through.
    if len(args.device_id) != 1:
        raise ValueError(
            f"expected exactly one device id, got {len(args.device_id)}"
        )
    variants = list(
        gen_variants(dataset=args.dataset, model=args.model, seed=args.seed)
    )
    # Run `main` once per variant. The comprehension variable is named
    # distinctly so it is not confused with the parsed top-level `args`.
    results = [
        main(variant_args)
        for variant_args in variant_args_generator(args, variants)
    ]
    # Collect results, grouping runs whose variants agree on every field
    # except the last one (presumably the seed -- confirm against gen_variants).
    results_dict = defaultdict(list)
    last_result = None
    for variant, result in zip(variants, results):
        results_dict[variant[:-1]].append(result)
        last_result = result
    if not results_dict:
        # Without this guard the header computation below would fail with a
        # NameError/AttributeError when no variant produced a result.
        raise SystemExit("No results were produced; nothing to tabulate.")
    # Column headers: "Variant" followed by the keys of the most recent result.
    col_names = ["Variant"] + list(last_result.keys())
    tab_data = tabulate_results(results_dict)
    print(tabulate(tab_data, headers=col_names, tablefmt="github"))
itertools.starmap(
lambda x, y: f"{x:.4f}Β±{y:.4f}",
zip(
np.mean(results, axis=0).tolist(),
np.std(results, axis=0).tolist(),
),
)
)
)
return tab_data
if __name__ == "__main__":
    parser = options.get_training_parser()
    args, _ = parser.parse_known_args()
    args = options.parse_args_and_arch(parser, args)
    print(args)
    # Validate explicitly: `assert` statements are removed when Python runs
    # with -O, which would silently let a multi-device configuration through.
    if len(args.device_id) != 1:
        raise ValueError(
            f"expected exactly one device id, got {len(args.device_id)}"
        )
    variants = list(
        gen_variants(dataset=args.dataset, model=args.model, seed=args.seed)
    )
    # Run `main` once per variant. The comprehension variable is named
    # distinctly so it is not confused with the parsed top-level `args`.
    results = [
        main(variant_args)
        for variant_args in variant_args_generator(args, variants)
    ]
    # Collect results, grouping runs whose variants agree on every field
    # except the last one (presumably the seed -- confirm against gen_variants).
    results_dict = defaultdict(list)
    last_result = None
    for variant, result in zip(variants, results):
        results_dict[variant[:-1]].append(result)
        last_result = result
    if not results_dict:
        # Without this guard the header computation below would fail with a
        # NameError/AttributeError when no variant produced a result.
        raise SystemExit("No results were produced; nothing to tabulate.")
    # Column headers: "Variant" followed by the keys of the most recent result.
    col_names = ["Variant"] + list(last_result.keys())
    tab_data = tabulate_results(results_dict)
    print(tabulate(tab_data, headers=col_names, tablefmt="github"))
return itertools.starmap(Variant, itertools.product(*items.values()))
def getpid(_, *, delay=1.0):
    """Return the PID of the process executing this call.

    The positional argument is ignored; it exists so the function can be
    passed to ``map``-style APIs (presumably a worker pool -- confirm at the
    call site).

    Args:
        _: Ignored.
        delay: Seconds to sleep before reporting the PID. Defaults to 1.0,
            matching the original hard-coded value.

    Returns:
        The PID of the current process.
    """
    # HACK to get different pids: the sleep keeps this call busy for a while,
    # presumably so that concurrent tasks land on different worker processes.
    time.sleep(delay)
    return mp.current_process().pid
if __name__ == "__main__":
# Magic for making multiprocessing work for PyTorch
mp.set_start_method("spawn")
parser = options.get_training_parser()
args, _ = parser.parse_known_args()
args = options.parse_args_and_arch(parser, args)
# Make sure datasets are downloaded first
datasets = args.dataset
for dataset in datasets:
args.dataset = dataset
_ = build_dataset(args)
args.dataset = datasets
print(args)
variants = list(
gen_variants(dataset=args.dataset, model=args.model, seed=args.seed)
)
device_ids = args.device_id
if args.cpu:
num_workers = 1