# Imports reconstructed for this truncated snippet (assumption): the file header is missing.
from abc import ABC
from typing import Optional

from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.utils.util import merge_locals_kwargs
class TsvTaggingFormat(Transform, ABC):
def file_to_inputs(self, filepath: str, gold=True):
assert gold, 'TsvTaggingFormat does not support reading non-gold files'
yield from generator_words_tags(filepath, gold=gold, lower=self.config.get('lower', False),
max_seq_length=self.max_seq_length)
@property
def max_seq_length(self):
return self.config.get('max_seq_length', None)
class TSVTaggingTransform(TsvTaggingFormat, Transform):
def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, use_char=False, **kwargs) -> None:
super().__init__(**merge_locals_kwargs(locals(), kwargs))
self.word_vocab: Optional[Vocab] = None
self.tag_vocab: Optional[Vocab] = None
self.char_vocab: Optional[Vocab] = None
def fit(self, trn_path: str, **kwargs) -> int:
self.word_vocab = Vocab()
self.tag_vocab = Vocab(pad_token=None, unk_token=None)
num_samples = 0
for words, tags in self.file_to_inputs(trn_path, True):
self.word_vocab.update(words)
self.tag_vocab.update(tags)
num_samples += 1
        if self.config.get('use_char', False):
            # Guard and tail reconstructed (assumption): the snippet breaks off here.
            # Build the character vocabulary from every word seen during fitting.
            self.char_vocab = Vocab()
            for word in self.word_vocab.token_to_idx:
                self.char_vocab.update(list(word))
        return num_samples
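# Usage sketch, assuming a two-column word<TAB>tag training corpus; the path below is
# hypothetical:
#
#     transform = TSVTaggingTransform(use_char=True)
#     transform.fit('data/pos/train.tsv')
#     print(len(transform.word_vocab), len(transform.tag_vocab))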
def read_conll(filepath):
    # Signature reconstructed (assumption): only the loop body of this CoNLL reader
    # survives in the snippet. Sentences are separated by blank lines; '_' marks a
    # missing field.
    sent = []
    with open(filepath, encoding='utf-8') as src:
        for line in src:
            cells = line.strip().split()
            if cells:
                cells[0] = int(cells[0])  # ID column
                cells[6] = int(cells[6])  # HEAD column
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
                sent.append(cells)
            else:
                yield sent
                sent = []
    if sent:
        yield sent
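# Usage sketch for the reader above; the path is hypothetical. Each sentence is a list
# of CoNLL rows with ID and HEAD converted to int and '_' mapped to None:
#
#     for sent in read_conll('data/ptb/train.conllx'):
#         heads = [cells[6] for cells in sent]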
# Imports reconstructed for this truncated snippet (assumption).
from typing import Tuple, Union

import tensorflow as tf

from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.utils.util import merge_locals_kwargs


class CoNLLTransform(Transform):
def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
n_tokens_per_batch=5000, min_freq=2,
**kwargs) -> None:
super().__init__(**merge_locals_kwargs(locals(), kwargs))
self.form_vocab: Vocab = None
self.cpos_vocab: Vocab = None
self.rel_vocab: Vocab = None
self.puncts: tf.Tensor = None
def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
form, cpos = x
return self.form_vocab.token_to_idx_table.lookup(form), self.cpos_vocab.token_to_idx_table.lookup(cpos)
    def y_to_idx(self, y):
        head, rel = y
        # Return reconstructed (assumption): heads are already integer indices, so only
        # the relation labels go through a vocabulary lookup.
        return head, self.rel_vocab.token_to_idx_table.lookup(rel)
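# Self-contained sketch of the lookup mechanism behind x_to_idx/y_to_idx: a Vocab's
# token_to_idx_table is a tf.lookup table from token strings to integer ids.
# `_demo_lookup` is a hypothetical helper, not part of hanlp.
def _demo_lookup():
    keys = tf.constant(['<pad>', '<unk>', 'nsubj', 'dobj'])
    values = tf.constant([0, 1, 2, 3], dtype=tf.int64)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values), default_value=1)  # 1 = unk
    return table.lookup(tf.constant(['nsubj', 'amod']))  # -> [2, 1]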
from abc import ABC
from typing import Union, Tuple, Iterable

import tensorflow as tf

from hanlp.common.component import KerasComponent
from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramConvTagger
from hanlp.components.taggers.rnn_tagger import RNNTagger
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.metrics.chunking.sequence_labeling import get_entities, iobes_to_span
from hanlp.utils.util import merge_locals_kwargs
class IOBES_NamedEntityRecognizer(KerasComponent, ABC):
def predict_batch(self, batch, inputs=None):
for words, tags in zip(inputs, super().predict_batch(batch, inputs)):
yield from iobes_to_span(words, tags)
class IOBES_Transform(Transform):
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
for words, tags in zip(inputs, super().Y_to_outputs(Y, gold, inputs=inputs, X=X)):
yield from iobes_to_span(words, tags)
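# Simplified stand-in for iobes_to_span (assumption: the real helper differs in detail):
# decode IOBES tags into (surface form, label) entity spans.
def _demo_iobes_decode(words, tags):
    i = 0
    while i < len(tags):
        tag = tags[i]
        if tag.startswith('S-'):  # single-token entity
            yield words[i], tag[2:]
            i += 1
        elif tag.startswith('B-'):  # multi-token entity: B- (I-/M-)* E-
            j = i + 1
            while j < len(tags) and tags[j].startswith(('I-', 'M-')):
                j += 1
            end = j + 1 if j < len(tags) and tags[j].startswith('E-') else j
            yield ' '.join(words[i:end]), tag[2:]
            i = end
        else:  # 'O' and anything malformed
            i += 1

# list(_demo_iobes_decode(['Steve', 'Jobs', 'founded', 'Apple'],
#                         ['B-PER', 'E-PER', 'O', 'S-ORG']))
# -> [('Steve Jobs', 'PER'), ('Apple', 'ORG')]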
class RNNNamedEntityRecognizer(RNNTagger, IOBES_NamedEntityRecognizer):
def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, logger=None,
loss: Union[tf.keras.losses.Loss, str] = None,
optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=32,
dev_batch_size=32, lr_decay_per_epoch=None,
run_eagerly=False,
            verbose=True, **kwargs):
        # assert kwargs.get('run_eagerly', True), 'This component can only run eagerly'
        # Body reconstructed (assumption): delegate to the generic tagger fit, forwarding
        # all hyperparameters via merge_locals_kwargs.
        return super().fit(**merge_locals_kwargs(locals(), kwargs))
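# Usage sketch; paths are hypothetical and training data is expected as word<TAB>tag
# TSV with IOBES chunk tags:
#
#     ner = RNNNamedEntityRecognizer()
#     ner.fit('data/ner/train.tsv', 'data/ner/dev.tsv', save_dir='model/ner')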
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-04 11:46
from typing import Union, Tuple, Iterable, Any
import tensorflow as tf
from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.metrics.chunking.sequence_labeling import get_entities
from hanlp.utils.file_read_backwards import FileReadBackwards
from hanlp.utils.io_util import read_tsv
class TextTransform(Transform):
def __init__(self,
forward=True,
seq_len=10,
tokenizer='char',
config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
super().__init__(config, map_x, map_y, seq_len=seq_len, tokenizer=tokenizer, forward=forward, **kwargs)
self.vocab: Vocab = None
def tokenize_func(self):
if self.config.tokenizer == 'char':
return list
elif self.config.tokenizer == 'whitespace':
return lambda x: x.split()
else:
return lambda x: x.split(self.config.tokenizer)
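    # The three tokenizer modes above, by example:
    #   'char'       -> list('深度学习')          == ['深', '度', '学', '习']
    #   'whitespace' -> 'deep learning'.split()   == ['deep', 'learning']
    #   other string -> 'a,b,c'.split(',')        == ['a', 'b', 'c']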
def vocab_from_tsv(tsv_file_path, lower=False, lock_word_vocab=False, lock_char_vocab=True,
                   lock_tag_vocab=True, word_vocab=None, char_vocab=None, tag_vocab=None):
    # Signature and setup reconstructed (assumption): only the tail of this vocabulary
    # builder survives in the snippet.
    word_vocab = word_vocab or Vocab()
    char_vocab = char_vocab or Vocab()
    tag_vocab = tag_vocab or Vocab(pad_token=None, unk_token=None)
    for sent in read_tsv(tsv_file_path):
        for word, tag in sent:
            if lower:
                word_vocab.add(word.lower())
            else:
                word_vocab.add(word)
            char_vocab.update(list(word))
            tag_vocab.add(tag)
    if lock_word_vocab:
        word_vocab.lock()
    if lock_char_vocab:
        char_vocab.lock()
    if lock_tag_vocab:
        tag_vocab.lock()
    return word_vocab, char_vocab, tag_vocab
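# Usage sketch for the builder above; the path is hypothetical. Locking a vocab is
# expected to freeze it, so later lookups of unseen tokens fall back to <unk> instead
# of growing the vocabulary:
#
#     word_vocab, char_vocab, tag_vocab = vocab_from_tsv('data/pos/train.tsv', lower=True)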
    def samples_to_dataset(self, samples, map_x=None, map_y=None, batch_size=5000, shuffle=None,
                           repeat=None, drop_remainder=False, prefetch=1, cache='') -> tf.data.Dataset:
        # Signature reconstructed (assumption); this is a method of the surrounding
        # transform class, whose context is elided in the snippet. Sentences are packed
        # greedily into batches of at most batch_size tokens.
        def generator():
            corpus = list(x for x in (samples() if callable(samples) else samples))
            n_tokens = 0
            batch = []
            for idx, sent in enumerate(corpus):
                sent_len = self.len_of_sent(sent)
                if n_tokens + sent_len > batch_size and batch:
                    yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
                    n_tokens = 0
                    batch = []
                n_tokens += sent_len
                batch.append(idx)
            if batch:
                yield from self.batched_inputs_to_batches(corpus, batch, shuffle)

        # next(generator())
        return Transform.samples_to_dataset(self, generator, False, False, 0, False, repeat,
                                            drop_remainder, prefetch, cache)
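# Pure-Python restatement of the greedy token-count batching above, stripped of the
# tf.data plumbing (`_demo_token_batches` is illustrative, not part of hanlp):
def _demo_token_batches(lengths, max_tokens):
    batch, n_tokens = [], 0
    for idx, sent_len in enumerate(lengths):
        if n_tokens + sent_len > max_tokens and batch:
            yield batch
            batch, n_tokens = [], 0
        batch.append(idx)
        n_tokens += sent_len
    if batch:
        yield batch

# list(_demo_token_batches([3, 4, 5, 2], max_tokens=8)) -> [[0, 1], [2, 3]]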
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 15:14
from typing import Union, Tuple, List, Iterable
import tensorflow as tf
from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.components.taggers.transformers.utils import convert_examples_to_features, config_is
from hanlp.transform.tsv import TsvTaggingFormat
class TransformerTransform(TsvTaggingFormat, Transform):
def __init__(self,
tokenizer=None,
config: SerializableDict = None,
map_x=False, map_y=False, **kwargs) -> None:
super().__init__(config, map_x, map_y, **kwargs)
self._tokenizer = tokenizer
self.tag_vocab: Vocab = None
self.special_token_ids = None
self.pad = '[PAD]'
self.unk = '[UNK]'
@property
def max_seq_length(self):
# -2 for special tokens [CLS] and [SEP]
return self.config.get('max_seq_length', 128) - 2
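    # Example: with the default max_seq_length of 128, the usable budget is
    # 128 - 2 = 126 word pieces, leaving room for [CLS] and [SEP]:
    #
    #     transform = TransformerTransform(tokenizer=bert_tokenizer)  # hypothetical tokenizer
    #     transform.max_seq_length  # -> 126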
    # Fragment note: this is the tail of a dataset-building function; `generator`,
    # `shapes`, `vec_dim`, `ngram_size`, `ngram_vocab`, `char_vocab`, `tag_vocab`,
    # `batch_size` and `prefetch` are defined in the elided part of the snippet.
    types = tuple([tf.string] * (vec_dim - 1)), tf.string
    defaults = (tuple([char_vocab.pad_token] +
                      [ngram_vocab.pad_token if ngram_vocab else char_vocab.pad_token] * ngram_size),
                tag_vocab.pad_token if tag_vocab.pad_token else tag_vocab.first_token)
dataset = tf.data.Dataset.from_generator(generator, output_shapes=shapes, output_types=types)
if shuffle:
if isinstance(shuffle, bool):
shuffle = 1024
dataset = dataset.shuffle(shuffle)
if repeat:
dataset = dataset.repeat(repeat)
dataset = dataset.padded_batch(batch_size, shapes, defaults).prefetch(prefetch)
return dataset
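# Minimal, self-contained sketch of the same tf.data pattern (illustrative values only;
# `_demo_padded_dataset` is not part of hanlp):
def _demo_padded_dataset():
    def gen():
        yield ['a', 'b'], ['X', 'Y']
        yield ['c'], ['Z']

    shapes = [None], [None]
    types = tf.string, tf.string
    defaults = '<pad>', '<pad>'
    dataset = tf.data.Dataset.from_generator(gen, output_shapes=shapes, output_types=types)
    dataset = dataset.shuffle(1024).repeat(1)
    return dataset.padded_batch(2, shapes, defaults).prefetch(1)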
# Imports reconstructed for this truncated snippet (assumption).
from abc import ABC

from hanlp.common.transform import Transform
from hanlp.utils.io_util import get_resource


class TxtFormat(Transform, ABC):
def file_to_inputs(self, filepath: str, gold=True):
filepath = get_resource(filepath)
with open(filepath, encoding='utf-8') as src:
for line in src:
sentence = line.strip()
if not sentence:
continue
yield sentence
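# Usage sketch for a concrete TxtFormat subclass; the path is hypothetical.
# file_to_inputs yields one stripped, non-empty sentence per line:
#
#     for sentence in transform.file_to_inputs('data/cws/train.txt'):
#         ...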
class TxtBMESFormat(TxtFormat, ABC):
def file_to_inputs(self, filepath: str, gold=True):
max_seq_length = self.config.get('max_seq_length', False)
if max_seq_length:
if 'transformer' in self.config:
max_seq_length -= 2 # allow for [CLS] and [SEP]