Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def infer(save_model_path, test_id_path, test_word_path, test_label_path,
          word_dict_path=None, label_dict_path=None, save_pred_path=None,
          batch_size=64, dropout=0.5, embedding_dim=100,
          rnn_hidden_dim=200, maxlen=300):
    """
    Run a saved model over the test data and pass the predictions to save_preds.

    Steps, in order: read the test ids, load the forward and reverse
    word/label dictionaries, vectorize and pad the test word file, rebuild
    the model through load_model, predict, and forward everything to
    save_preds. Nothing is returned.

    :param save_model_path: forwarded to load_model as its last argument
    :param test_id_path: file read with load_test_id to obtain the test ids
    :param test_word_path: file vectorized with the word dictionary; also
        re-read with load_test_id before saving
    :param test_label_path: file read with load_test_id before saving
    :param word_dict_path: source for load_dict / load_reverse_dict (words)
    :param label_dict_path: source for load_dict / load_reverse_dict (labels)
    :param save_pred_path: forwarded to save_preds (presumably the output
        location -- confirm against save_preds)
    :param batch_size: batch size given to model.predict
    :param dropout: forwarded to load_model
    :param embedding_dim: forwarded to load_model
    :param rnn_hidden_dim: forwarded to load_model
    :param maxlen: second argument of pad_sequence for the word ids
    """
    # ids first, then the dictionaries (forward before reverse, words before labels)
    test_ids = load_test_id(test_id_path)
    word_ids_dict = load_dict(word_dict_path)
    ids_word_dict = load_reverse_dict(word_dict_path)
    label_ids_dict = load_dict(label_dict_path)
    ids_label_dict = load_reverse_dict(label_dict_path)

    # map the test words to indices and pad them
    word_seq = pad_sequence(vectorize_data(test_word_path, word_ids_dict), maxlen)

    # rebuild the model from the saved file
    model = load_model(
        word_ids_dict,
        label_ids_dict,
        embedding_dim,
        rnn_hidden_dim,
        dropout,
        save_model_path,
    )

    # argmax over the last axis collapses the raw model output to indices
    raw_output = model.predict(word_seq, batch_size=batch_size, verbose=0)
    predicted = raw_output.argmax(-1)
    assert len(predicted) == len(word_seq)
    print('probs.shape:', predicted.shape)

    # the word and label files are re-read with the same loader used for the ids
    test_words = load_test_id(test_word_path)
    test_labels = load_test_id(test_label_path)
    save_preds(
        predicted,
        test_ids,
        word_seq,
        ids_word_dict,
        label_ids_dict,
        ids_label_dict,
        save_pred_path,
        test_words,
        test_labels,
    )
embedding_dim=100,
rnn_hidden_dim=200,
maxlen=300,
cutoff_frequency=0):
"""
Train the bilstm_crf model for grammar correction.
"""
# build the word dictionary
build_dict(train_word_path,
word_dict_path,
cutoff_frequency,
insert_extra_words=[UNK_TOKEN, PAD_TOKEN])
# build the label dictionary
build_dict(train_label_path, label_dict_path)
# load dict
word_ids_dict = load_dict(word_dict_path)
label_ids_dict = load_dict(label_dict_path)
# read data to index
word_ids = vectorize_data(train_word_path, word_ids_dict)
label_ids = vectorize_data(train_label_path, label_ids_dict)
max_len = np.max([len(i) for i in word_ids])
print('max_len:', max_len)
# pad sequence
word_seq = pad_sequence(word_ids, maxlen=maxlen)
label_seq = pad_sequence(label_ids, maxlen=maxlen)
# reshape label for crf model use
label_seq = np.reshape(label_seq, (label_seq.shape[0], label_seq.shape[1], 1))
print(word_seq.shape)
print(label_seq.shape)
logger.info("Data loaded.")
# model
logger.info("Training BILSTM_CRF model...")
def infer(save_model_path, test_id_path, test_word_path, test_label_path,
          word_dict_path=None, label_dict_path=None, save_pred_path=None,
          batch_size=64, dropout=0.5, embedding_dim=100,
          rnn_hidden_dim=200, maxlen=300):
    """
    Run a saved model over the test data and pass the predictions to save_preds.

    Steps, in order: read the test ids, load the forward and reverse
    word/label dictionaries, vectorize and pad the test word file, rebuild
    the model through load_model, predict, and forward everything to
    save_preds. Nothing is returned.

    :param save_model_path: forwarded to load_model as its last argument
    :param test_id_path: file read with load_test_id to obtain the test ids
    :param test_word_path: file vectorized with the word dictionary; also
        re-read with load_test_id before saving
    :param test_label_path: file read with load_test_id before saving
    :param word_dict_path: source for load_dict / load_reverse_dict (words)
    :param label_dict_path: source for load_dict / load_reverse_dict (labels)
    :param save_pred_path: forwarded to save_preds (presumably the output
        location -- confirm against save_preds)
    :param batch_size: batch size given to model.predict
    :param dropout: forwarded to load_model
    :param embedding_dim: forwarded to load_model
    :param rnn_hidden_dim: forwarded to load_model
    :param maxlen: second argument of pad_sequence for the word ids
    """
    # ids first, then the dictionaries (forward before reverse, words before labels)
    test_ids = load_test_id(test_id_path)
    word_ids_dict = load_dict(word_dict_path)
    ids_word_dict = load_reverse_dict(word_dict_path)
    label_ids_dict = load_dict(label_dict_path)
    ids_label_dict = load_reverse_dict(label_dict_path)

    # map the test words to indices and pad them
    word_seq = pad_sequence(vectorize_data(test_word_path, word_ids_dict), maxlen)

    # rebuild the model from the saved file
    model = load_model(
        word_ids_dict,
        label_ids_dict,
        embedding_dim,
        rnn_hidden_dim,
        dropout,
        save_model_path,
    )

    # argmax over the last axis collapses the raw model output to indices
    raw_output = model.predict(word_seq, batch_size=batch_size, verbose=0)
    predicted = raw_output.argmax(-1)
    assert len(predicted) == len(word_seq)
    print('probs.shape:', predicted.shape)

    # the word and label files are re-read with the same loader used for the ids
    test_words = load_test_id(test_word_path)
    test_labels = load_test_id(test_label_path)
    save_preds(
        predicted,
        test_ids,
        word_seq,
        ids_word_dict,
        label_ids_dict,
        ids_label_dict,
        save_pred_path,
        test_words,
        test_labels,
    )
rnn_hidden_dim=200,
maxlen=300,
cutoff_frequency=0):
"""
Train the bilstm_crf model for grammar correction.
"""
# build the word dictionary
build_dict(train_word_path,
word_dict_path,
cutoff_frequency,
insert_extra_words=[UNK_TOKEN, PAD_TOKEN])
# build the label dictionary
build_dict(train_label_path, label_dict_path)
# load dict
word_ids_dict = load_dict(word_dict_path)
label_ids_dict = load_dict(label_dict_path)
# read data to index
word_ids = vectorize_data(train_word_path, word_ids_dict)
label_ids = vectorize_data(train_label_path, label_ids_dict)
max_len = np.max([len(i) for i in word_ids])
print('max_len:', max_len)
# pad sequence
word_seq = pad_sequence(word_ids, maxlen=maxlen)
label_seq = pad_sequence(label_ids, maxlen=maxlen)
# reshape label for crf model use
label_seq = np.reshape(label_seq, (label_seq.shape[0], label_seq.shape[1], 1))
print(word_seq.shape)
print(label_seq.shape)
logger.info("Data loaded.")
# model
logger.info("Training BILSTM_CRF model...")
model = create_model(word_ids_dict, label_ids_dict,