def action_compare(argc, argv):
    args = parse_args(argv)
    metrics = {}
    projects = {
        args.path_1: None,
        args.path_2: None,
    }
    ref = None
    inp_shape = None
    out_shape = None
    is_prepared = None
    prjs = []

    # first pass: load both projects and make sure their models are comparable
    for path in projects:
        prj = Project(path)
        err = prj.load()
        if err is not None:
            log.error("error while loading project %s: %s", path, err)
            quit()
        prjs.append(prj)

        if not is_prepared:
            is_prepared = True
        else:
            small_dataset = generate_reduced_dataset(args.dataset)
            are_equal = are_preparation_equal(prjs, small_dataset)
            log.info("deleting temporal file %s", small_dataset)
            os.remove(small_dataset)

        if out_shape is None:
            out_shape = prj.model.output_shape
        elif out_shape != prj.model.output_shape:
            log.error("model %s output shape is %s, expected %s", path, prj.model.output_shape, out_shape)
            quit()

        projects[path] = prj

    # second pass: prepare the first project as the reference, then reuse (or rebuild)
    # its dataset for the second one
    for prj, path in zip(prjs, projects):
        prj = Project(path)
        err = prj.load()
        if err is not None:
            log.error("error while loading project %s: %s", path, err)
            quit()
        if ref is None:
            prj.prepare(args.dataset, 0, 0, False)
            ref = prj
            is_prepared = True
        else:
            if are_equal:
                log.info("Projects use same prepare.py file ...")
                prj.dataset.X, prj.dataset.Y, prj.dataset.n_labels = ref.dataset.X.copy(), ref.dataset.Y.copy(), ref.dataset.n_labels
            else:
                log.info("Projects use different prepare.py files, reloading dataset ...")
def action_view(argc, argv):
    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()
    prj.view(args.img_only)
def action_to_fdeep(argc, argv):
    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()
    elif not prj.is_trained():
        log.error("no trained model found for this project")
        quit()
    convert(prj.weights_path, prj.fdeep_path, args.no_tests, args.metadata)
def action_relevance(argc, argv):
    global prj, deltas, tot, start, speed, nrows, ncols, attributes

    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()
    elif not prj.is_trained():
        log.error("no trained Keras model found for this project")
        quit()

    prj.prepare(args.dataset, 0.0, 0.0)

    # one single worker in blocking mode = serial
    if args.workers == 0:
        args.workers = 1

    X, y = prj.dataset.subsample(args.ratio)
    nrows, ncols = X.shape if prj.dataset.is_flat else (X[0].shape[0], len(X))
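# The per-feature relevance loop is cut off in this fragment. Scores of this kind are
# commonly obtained by nulling one feature at a time and measuring how far a reference
# accuracy drops; a hypothetical sketch, assuming a Keras-style model whose evaluate()
# returns (loss, accuracy) on flat inputs:
import numpy as np

def column_relevance_sketch(model, X, y):
    _, ref_accuracy = model.evaluate(X, y, verbose=0)
    deltas = []
    for col in range(X.shape[1]):
        X_null = np.array(X, copy=True)
        X_null[:, col] = 0  # suppress this feature's contribution
        _, accuracy = model.evaluate(X_null, y, verbose=0)
        deltas.append(ref_accuracy - accuracy)
    return deltas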
def action_train(argc, argv):
    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()

    if args.dataset is not None:
        # a dataset was specified, split it and generate
        # the subsets
        prj.dataset.do_save = not args.no_save
        prj.prepare(args.dataset, args.test, args.validation, not args.no_shuffle)
    elif prj.dataset.exists():
        # no dataset passed, attempt to use the previously
        # generated subsets
        prj.dataset.load()
    else:
        log.error("no test/train/validation subsets found in %s, please specify a --dataset argument", args.path)
def action_encode(argc, argv):
    args = parse_args(argv)
    if not os.path.exists(args.path):
        log.error("%s does not exist.", args.path)
        quit()

    prj = Project(args.project)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()

    args.label = args.label.strip().lower()
    log.info("using %s labeling", 'auto' if args.label == 'auto' else 'hardcoded')

    inputs = []
    if os.path.isdir(args.path):
        in_files = []
        if args.label == 'auto':
            # the label is inferred from the dirname, so we expect
            # args.path to contain multiple subfolders
            for subfolder in glob.glob(os.path.join(args.path, "*")):
                log.info("enumerating %s ...", subfolder)
def action_serve(argc, argv):
    global prj, app, classes, num_outputs

    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()
    elif not prj.is_trained():
        log.error("no trained Keras model found for this project")
        quit()

    if args.classes is None:
        num_outputs = prj.model.output.shape[1]
        if prj.classes is None:
            classes = ["class_%d" % i for i in range(num_outputs)]
        else:
            classes = [prj.classes[i] for i in range(num_outputs)]
    else:
        classes = [s.strip() for s in args.classes.split(',') if s.strip() != ""]
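# The module-level globals above (prj, app, classes, num_outputs) point at a small web
# service, but the route itself is not part of this fragment. A purely hypothetical
# sketch of such a prediction endpoint, assuming Flask and a prj.model exposing the
# standard Keras predict() API:
from flask import Flask, jsonify, request
import numpy as np

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    # read a flat feature vector from the JSON body and run the model on it
    x = np.asarray(request.json["x"], dtype=float).reshape(1, -1)
    probs = prj.model.predict(x)[0]
    return jsonify({label: float(p) for label, p in zip(classes, probs)})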
def action_to_tf(argc, argv):
    args = parse_args(argv)
    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()
    elif not prj.is_trained():
        log.error("no trained Keras model found for this project")
        quit()

    frozen_graph = freeze_session(K.get_session(),
                                  output_names=[out.op.name for out in prj.model.outputs])
    log.info("saving protobuf to %s ...", os.path.join(prj.path, 'model.pb'))
    tf.train.write_graph(frozen_graph, prj.path, "model.pb", as_text=False)
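# freeze_session() is called above but not defined in this listing. A common TF1-era
# sketch (an assumption, not necessarily this project's exact helper) folds the
# session's variables into constants so the graph can be written out as a standalone
# protobuf:
from tensorflow.python.framework import graph_util

def freeze_session_sketch(session, output_names, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        graph_def = graph.as_graph_def()
        if clear_devices:
            # drop device placements so the frozen graph stays portable
            for node in graph_def.node:
                node.device = ""
        return graph_util.convert_variables_to_constants(session, graph_def, output_names)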
    # data-exploration action (fragment): pick the worker count, validate the requested
    # analyses, then load and prepare the project before exploring its dataset
    if args.workers == 0:
        import multiprocessing
        n_jobs = multiprocessing.cpu_count()
    elif args.workers != 0:
        n_jobs = args.workers
    log.info("using %d workers" % n_jobs)

    if args.nclusters and not args.cluster:
        log.warning("number of clusters specified but clustering won't be performed")

    if not (args.pca or args.correlations or args.stats or args.cluster):
        log.error("No exploration action was specified")
        print("")
        parse_args(["-h"])
        quit()

    prj = Project(args.path)
    err = prj.load()
    if err is not None:
        log.error("error while loading project: %s", err)
        quit()

    prj.prepare(args.dataset, 0.0, 0.0)
    if not prj.dataset.is_flat:
        log.error("data exploration can only be applied to flat inputs")
        quit()

    X, y = prj.dataset.subsample(args.ratio)
    nrows, ncols = X.shape
    attributes = get_attributes(args.attributes, ncols)

    if args.correlations:
        log.info("computing correlations of each feature with target")