Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _init():
    """Populate ``CharTable.convert`` from the HanLP character-table resource.

    Each valid line is exactly three characters — source char, separator,
    target char; lines of any other length are skipped.
    """
    table_path = get_resource(HANLP_CHAR_TABLE)
    with open(table_path, encoding='utf-8') as reader:
        for raw in reader:
            entry = raw.rstrip('\n')
            if len(entry) == 3:
                src_char, _sep, dst_char = entry
                CharTable.convert[src_char] = dst_char
def read_conll(filepath):
    """Read a CoNLL-style file and yield one sentence at a time.

    Comment lines (starting with ``#``) are skipped. Within each token row,
    column 0 (token id) and column 6 (head id) are converted to ``int`` and
    ``'_'`` placeholders become ``None``. Sentences are separated by blank
    lines.

    :param filepath: Path or resource identifier resolved via ``get_resource``.
    :return: Generator over sentences, each a list of cell lists.
    """
    sent = []
    filepath = get_resource(filepath)
    with open(filepath, encoding='utf-8') as src:
        for line in src:
            if line.startswith('#'):
                continue
            cells = line.strip().split()
            if cells:
                cells[0] = int(cells[0])
                cells[6] = int(cells[6])
                # Normalize CoNLL's '_' placeholder to None.
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
                sent.append(cells)
            else:
                # Blank line terminates the current sentence.
                yield sent
                sent = []
    # Fix: the original `if sent:` had no body. Yield the trailing sentence
    # so files without a final blank line do not silently drop their last
    # sentence.
    if sent:
        yield sent
def export_model_for_serving(self, export_dir=None, version=1, overwrite=False, show_hint=False):
    """Export the fitted model as a versioned TF SavedModel for TF Serving.

    :param export_dir: Destination directory; defaults to the resolved
        ``meta['load_path']`` when omitted.
    :param version: Sub-directory name used as the serving version.
    :param overwrite: When False, an existing version directory is kept.
    :param show_hint: Log a sample ``tensorflow_model_server`` command.
    :return: The export directory actually used.
    """
    assert self.model, 'You have to fit or load a model before exporting it'
    if not export_dir:
        assert 'load_path' in self.meta, 'When not specifying save_dir, load_path has to present'
        export_dir = get_resource(self.meta['load_path'])
    versioned_dir = os.path.join(export_dir, str(version))
    # Respect an existing export unless the caller explicitly overwrites.
    if os.path.isdir(versioned_dir) and not overwrite:
        logger.info(f'{versioned_dir} exists, skip since overwrite = {overwrite}')
        return export_dir
    logger.info(f'Exporting to {export_dir} ...')
    tf.saved_model.save(self.model, versioned_dir)
    logger.info(f'Successfully exported model to {export_dir}')
    if show_hint:
        served_name = os.path.splitext(os.path.basename(self.meta["load_path"]))[0]
        logger.info(f'You can serve it through \n'
                    f'tensorflow_model_server --model_name={served_name} '
                    f'--model_base_path={export_dir} --rest_api_port=8888')
    return export_dir
def load_config(self, save_dir, filename='config.json'):
    """Load the component's JSON hyperparameters into ``self.config``.

    :param save_dir: Directory (or resource id) holding the config file.
    :param filename: Name of the JSON file inside ``save_dir``.
    """
    config_path = os.path.join(get_resource(save_dir), filename)
    self.config.load_json(config_path)
def load_vocabs(self, save_dir, filename='vocabs.json'):
    """Restore serialized vocabularies as attributes of ``self.transform``.

    Each entry of the JSON file becomes a ``Vocab`` set on the transform
    under the entry's key.

    :param save_dir: Directory (or resource id) holding the vocab file.
    :param filename: Name of the JSON file inside ``save_dir``.
    """
    resolved_dir = get_resource(save_dir)
    serialized = SerializableDict()
    serialized.load_json(os.path.join(resolved_dir, filename))
    for name, payload in serialized.items():
        restored = Vocab()
        restored.copy_from(payload)
        setattr(self.transform, name, restored)
def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
             callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
    # Evaluate the fitted model on a test file.
    # NOTE(review): this definition is truncated in the visible source — the
    # actual metric computation follows beyond this excerpt.
    input_path = get_resource(input_path)
    # Derive a run name from the input file; fall back to 'evaluate' when the
    # path has no basename (e.g. ends with a separator).
    file_prefix, ext = os.path.splitext(input_path)
    name = os.path.basename(file_prefix)
    if not name:
        name = 'evaluate'
    if save_dir and not logger:
        # No logger supplied: create one writing into save_dir.
        logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                             mode='w')
    tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
    samples = size_of_dataset(tst_data)
    num_batches = math.ceil(samples / batch_size)
    if warm_up:
        # One throwaway prediction — presumably to trigger graph building
        # before the measured evaluation; confirm against callers.
        self.model.predict_on_batch(tst_data.take(1))
    if output:
        assert save_dir, 'Must pass save_dir in order to output'
        if isinstance(output, bool):
            # output=True means "pick a default path": <save_dir>/<name>.predict<ext>
            output = os.path.join(save_dir, name) + '.predict' + ext
def load_transform(self, save_dir) -> Transform:
    """Load only the transform (config + vocabs), skipping model construction.

    This is a lightweight alternative to ``load``. It can fail for
    components whose transform cannot be built without the model; in that
    case fall back to the heavier ``load``.

    :param save_dir: The path to load.
    :return: The component's transform, configured and with locked vocabs.
    """
    resolved = get_resource(save_dir)
    self.load_config(resolved)
    self.load_vocabs(resolved)
    self.transform.build_config()
    self.transform.lock_vocabs()
    return self.transform
# NOTE(review): this excerpt starts mid-`if` — the opening branch header
# (apparently selecting ALBERT models served from TF Hub) lies above the
# visible source, and the code continues past the last line shown.
# Indentation below is reconstructed; `spm_model_file` is presumably
# initialized (e.g. to None) before this fragment — confirm upstream.
    # ALBERT-from-TFHub branch: fetch the model while silencing stdout noise.
    with stdout_redirected(to=os.devnull):
        model_url = fetch_tfhub_albert_model(transformer,
                                             os.path.join(hanlp_home(), 'thirdparty', 'tfhub.dev', 'google',
                                                          transformer))
    albert = True
    # The TF Hub ALBERT package ships a SentencePiece model under assets/.
    spm_model_file = glob.glob(os.path.join(model_url, 'assets', '*.model'))
    assert len(spm_model_file) == 1, 'No vocab found or unambiguous vocabs found'
    spm_model_file = spm_model_file[0]
elif transformer in bert_models_google:
    # Google BERT checkpoints: plain WordPiece tokenizer, no SentencePiece.
    from bert.tokenization.bert_tokenization import FullTokenizer
    model_url = bert_models_google[transformer]
    albert = False
else:
    raise ValueError(
        f'Unknown model {transformer}, available ones: {list(bert_models_google.keys()) + list(zh_albert_models_google.keys()) + list(albert_models_tfhub.keys())}')
bert_dir = get_resource(model_url)
# SentencePiece (ALBERT) vocabs live in assets/*.vocab; BERT vocabs match *vocab*.txt.
if spm_model_file:
    vocab = glob.glob(os.path.join(bert_dir, 'assets', '*.vocab'))
else:
    vocab = glob.glob(os.path.join(bert_dir, '*vocab*.txt'))
assert len(vocab) == 1, 'No vocab found or unambiguous vocabs found'
vocab = vocab[0]
# Treat these checkpoint families as lower-cased when tokenizing.
lower_case = any(key in transformer for key in ['uncased', 'multilingual', 'chinese', 'albert'])
if spm_model_file:
    # noinspection PyTypeChecker
    tokenizer = FullTokenizer(vocab_file=vocab, spm_model_file=spm_model_file, do_lower_case=lower_case)
else:
    tokenizer = FullTokenizer(vocab_file=vocab, do_lower_case=lower_case)
if tokenizer_only:
    return tokenizer
if spm_model_file:
    bert_params = albert_params(bert_dir)
def file_to_inputs(self, filepath: str, gold=True):
    """Yield each non-empty line of ``filepath``, stripped, as a raw sentence.

    :param filepath: Path or resource id resolved via ``get_resource``.
    :param gold: Unused here; kept for interface compatibility.
    """
    resolved = get_resource(filepath)
    with open(resolved, encoding='utf-8') as reader:
        for raw in reader:
            text = raw.strip()
            if text:
                yield text