# NOTE(review): scanner-banner artifact from a code-hosting page, not part of
# the program; kept as a comment so the line is not a syntax error.
# "Secure your code as it's written. Use Snyk Code to scan source code in
# minutes - no build needed - and fix issues immediately."
'type': 'confirm',
'name': common.STRIP_DEBUG_OPS,
'message': 'Do you want to strip debug ops? \n'
'This will improve model execution performance.',
'default': True,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
}
]
params = PyInquirer.prompt(questions, format_params, style=prompt_style)
# Prompts for choosing where to write the converted model: first an input
# question collecting the target directory, then a confirm question that
# guards against clobbering a directory that already exists on disk.
output_options = [
    {
        'type': 'input',
        'name': common.OUTPUT_PATH,
        'message': 'Which directory do you want to save '
                   'the converted model in?',
        # Normalize the answer: trim surrounding whitespace and expand '~'.
        'filter': lambda path: os.path.expanduser(path.strip()),
        # Reject empty input; any non-empty path is accepted at this stage.
        'validate': lambda path: len(path) > 0
    }, {
        'type': 'confirm',
        # Fixed user-facing message word order (was: 'The output already
        # directory exists').
        'message': 'The output directory already exists, '
                   'do you want to overwrite it?',
        'name': 'overwrite_output_path',
        'default': False,
        # Only ask when the chosen directory is already present.
        'when': lambda ans: output_path_exists(ans[common.OUTPUT_PATH])
    }
]
while (not common.OUTPUT_PATH in params or
output_path_exists(params[common.OUTPUT_PATH]) and
'message': 'Do you want to compress the model? '
'(this will decrease the model precision.)',
'choices': [{
'name': 'No compression (Higher accuracy)',
'value': None
}, {
'name': '2x compression (Accuracy/size trade-off)',
'value': 2
}, {
'name': '4x compression (Smaller size)',
'value': 1
}]
},
{
'type': 'input',
'name': common.WEIGHT_SHARD_SIZE_BYTES,
'message': 'Please enter shard size (in bytes) of the weight files?',
'default': str(4 * 1024 * 1024),
'when': lambda answers: value_in_list(answers, common.OUTPUT_FORMAT,
(common.TFJS_LAYERS_MODEL))
},
{
'type': 'confirm',
'name': common.SPLIT_WEIGHTS_BY_LAYER,
'message': 'Do you want to split weights by layers?',
'default': False,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TFJS_LAYERS_MODEL))
},
{
'type': 'confirm',
'name': common.SKIP_OP_CHECK,
'"tf_saved_model".')
# --quantization_bytes: optional weight quantization; valid byte widths come
# from quantization.QUANTIZATION_BYTES_TO_DTYPES (help text says 1 and 2).
parser.add_argument(
'--%s' % common.QUANTIZATION_BYTES,
type=int,
choices=set(quantization.QUANTIZATION_BYTES_TO_DTYPES.keys()),
help='How many bytes to optionally quantize/compress the weights to. 1- '
'and 2-byte quantizaton is supported. The default (unquantized) size is '
'4 bytes.')
# --split_weights_by_layer: boolean flag (store_true, default False);
# per the help text it only applies to the keras input format.
parser.add_argument(
'--%s' % common.SPLIT_WEIGHTS_BY_LAYER,
action='store_true',
help='Applicable to keras input_format only: Whether the weights from '
'different layers are to be stored in separate weight groups, '
'corresponding to separate binary weight files. Default: False.')
# --version/-v: stored as `show_version` rather than the flag name, so the
# caller presumably checks args.show_version — confirm against the caller.
parser.add_argument(
'--%s' % common.VERSION,
'-v',
dest='show_version',
action='store_true',
help='Show versions of tensorflowjs and its dependencies')
# --skip_op_check: boolean flag disabling op validation during conversion.
parser.add_argument(
'--%s' % common.SKIP_OP_CHECK,
action='store_true',
help='Skip op validation for TensorFlow model conversion.')
# --strip_debug_ops: whether to drop Print/Assert/CheckNumerics from the
# graph. NOTE: the previous `type=bool` was broken for command-line flags --
# argparse passes a string, and bool('False') is True, so every explicit
# value (including --strip_debug_ops=False) enabled stripping. Parse the
# common false-y spellings explicitly instead; the default remains True.
parser.add_argument(
    '--%s' % common.STRIP_DEBUG_OPS,
    type=lambda v: str(v).lower() not in ('false', '0', 'no', 'n', ''),
    default=True,
    help='Strip debug ops (Print, Assert, CheckNumerics) from graph.')
parser.add_argument(
'--%s' % common.WEIGHT_SHARD_SIZE_BYTES,
type=int,
common.TFJS_GRAPH_MODEL))
},
{
'type': 'list',
'name': common.SIGNATURE_NAME,
'message': 'What is signature name of the model?',
'choices': available_signature_names,
'when': lambda answers: (is_saved_model(answers[common.INPUT_FORMAT])
and
(common.OUTPUT_FORMAT not in format_params
or format_params[common.OUTPUT_FORMAT] ==
common.TFJS_GRAPH_MODEL))
},
{
'type': 'list',
'name': common.QUANTIZATION_BYTES,
'message': 'Do you want to compress the model? '
'(this will decrease the model precision.)',
'choices': [{
'name': 'No compression (Higher accuracy)',
'value': None
}, {
'name': '2x compression (Accuracy/size trade-off)',
'value': 2
}, {
'name': '4x compression (Smaller size)',
'value': 1
}]
},
{
'type': 'input',
'name': common.WEIGHT_SHARD_SIZE_BYTES,
'type': 'confirm',
'name': common.STRIP_DEBUG_OPS,
'message': 'Do you want to strip debug ops? \n'
'This will improve model execution performance.',
'default': True,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
}
]
params = PyInquirer.prompt(questions, format_params, style=prompt_style)
# Prompts for the output directory; the second question guards against
# overwriting a directory that already exists.
output_options = [
    {
        'type': 'input',
        'name': common.OUTPUT_PATH,
        'message': 'Which directory do you want to save '
                   'the converted model in?',
        # Presumably normalizes the path and syncs it into `params`;
        # confirm against update_output_path's definition.
        'filter': lambda path: update_output_path(path, params),
        # Reject empty input; any non-empty path is accepted at this stage.
        'validate': lambda path: len(path) > 0
    },
    {
        'type': 'confirm',
        # Fixed user-facing message word order (was: 'The output already
        # directory exists').
        'message': 'The output directory already exists, '
                   'do you want to overwrite it?',
        'name': 'overwrite_output_path',
        'default': False,
        # Only ask when the chosen directory is already present.
        'when': lambda ans: output_path_exists(ans[common.OUTPUT_PATH])
    }
]
while (common.OUTPUT_PATH not in params or
'message': 'Please enter shard size (in bytes) of the weight files?',
'default': str(4 * 1024 * 1024),
'when': lambda answers: value_in_list(answers, common.OUTPUT_FORMAT,
(common.TFJS_LAYERS_MODEL))
},
{
'type': 'confirm',
'name': common.SPLIT_WEIGHTS_BY_LAYER,
'message': 'Do you want to split weights by layers?',
'default': False,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TFJS_LAYERS_MODEL))
},
{
'type': 'confirm',
'name': common.SKIP_OP_CHECK,
'message': 'Do you want to skip op validation? \n'
'This will allow conversion of unsupported ops, \n'
'you can implement them as custom ops in tfjs-converter.',
'default': False,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
common.TF_HUB_MODEL))
},
{
'type': 'confirm',
'name': common.STRIP_DEBUG_OPS,
'message': 'Do you want to strip debug ops? \n'
'This will improve model execution performance.',
'default': True,
'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT,
(common.TF_SAVED_MODEL,
# https://github.com/tensorflow/tfjs/issues/1292: Remove the logic for the
# explicit error message of the deprecated model type name 'tensorflowjs'
# at version 1.1.0.
# NOTE(review): indentation was lost in this chunk; structure reconstructed
# so the second 'tensorflowjs' check is an elif of `output_format is None`.
if input_format == 'tensorflowjs':
  raise ValueError(
      '--input_format=tensorflowjs has been deprecated. '
      'Use --input_format=tfjs_layers_model instead.')

# Classify the input format once so the branches below read clearly.
input_format_is_keras = (
    input_format in [common.KERAS_MODEL, common.KERAS_SAVED_MODEL])
input_format_is_tf = (
    input_format in [common.TF_SAVED_MODEL, common.TF_HUB_MODEL])
if output_format is None:
  # If no explicit output_format is provided, infer it from input format.
  if input_format_is_keras:
    output_format = common.TFJS_LAYERS_MODEL
  elif input_format_is_tf:
    output_format = common.TFJS_GRAPH_MODEL
  elif input_format == common.TFJS_LAYERS_MODEL:
    output_format = common.KERAS_MODEL
elif output_format == 'tensorflowjs':
  # Deprecated output-format alias: point the user at the replacement name
  # appropriate for their input format.
  if input_format_is_keras:
    raise ValueError(
        '--output_format=tensorflowjs has been deprecated under '
        '--input_format=%s. Use --output_format=tfjs_layers_model '
        'instead.' % input_format)
  if input_format_is_tf:
    raise ValueError(
        '--output_format=tensorflowjs has been deprecated under '
        '--input_format=%s. Use --output_format=tfjs_graph_model '
        'instead.' % input_format)
# Ask the format-selection questions; answers land in `format_params`.
format_params = PyInquirer.prompt(formats, input_params, style=prompt_style)
# Build the input-path prompt text based on the chosen input format.
message = input_path_message(format_params)
questions = [
{
'type': 'input',
'name': common.INPUT_PATH,
'message': message,
'filter': expand_input_path,
'validate': lambda value: validate_input_path(
value, format_params[common.INPUT_FORMAT]),
'when': lambda answers: (not detected_input_format)
},
{
'type': 'list',
'name': common.SAVED_MODEL_TAGS,
'choices': available_tags,
'message': 'What is tags for the saved model?',
'when': lambda answers: (is_saved_model(answers[common.INPUT_FORMAT])
and
(not common.OUTPUT_FORMAT in format_params
or format_params[common.OUTPUT_FORMAT] ==
common.TFJS_GRAPH_MODEL))
},
{
'type': 'list',
'name': common.SIGNATURE_NAME,
'message': 'What is signature name of the model?',
'choices': available_signature_names,
'when': lambda answers: (is_saved_model(answers[common.INPUT_FORMAT])
and
(not common.OUTPUT_FORMAT in format_params