How to use the hanlp.common.transform.Transform class in hanlp

To help you get started, we've selected a few HanLP examples based on popular ways Transform is used in public projects.

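Most of these examples subclass Transform rather than call it directly: a subclass reads a corpus with file_to_inputs, builds its vocabularies in fit, and maps tokens to indices with x_to_idx / y_to_idx. As a quick orientation, here is a minimal usage sketch built around TSVTaggingTransform from the first example below. The file path is hypothetical, and passing lower and max_seq_length as keyword arguments assumes they flow into the transform's config the way the snippets suggest.

from hanlp.transform.tsv import TSVTaggingTransform

# Hypothetical training file: one "word<TAB>tag" pair per line, blank lines between sentences.
transform = TSVTaggingTransform(lower=True, max_seq_length=128)
num_samples = transform.fit('data/pos/train.tsv')  # builds word_vocab and tag_vocab
print(num_samples, transform.word_vocab, transform.tag_vocab)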

From hankcs/HanLP on GitHub: hanlp/transform/tsv.py
        tag_vocab.lock()
    return word_vocab, char_vocab, tag_vocab


class TsvTaggingFormat(Transform, ABC):
    def file_to_inputs(self, filepath: str, gold=True):
        assert gold, 'TsvTaggingFormat does not support reading non-gold files'
        yield from generator_words_tags(filepath, gold=gold, lower=self.config.get('lower', False),
                                        max_seq_length=self.max_seq_length)

    @property
    def max_seq_length(self):
        return self.config.get('max_seq_length', None)


class TSVTaggingTransform(TsvTaggingFormat, Transform):
    def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, use_char=False, **kwargs) -> None:
        super().__init__(**merge_locals_kwargs(locals(), kwargs))
        self.word_vocab: Optional[Vocab] = None
        self.tag_vocab: Optional[Vocab] = None
        self.char_vocab: Optional[Vocab] = None

    def fit(self, trn_path: str, **kwargs) -> int:
        self.word_vocab = Vocab()
        self.tag_vocab = Vocab(pad_token=None, unk_token=None)
        num_samples = 0
        for words, tags in self.file_to_inputs(trn_path, True):
            self.word_vocab.update(words)
            self.tag_vocab.update(tags)
            num_samples += 1
        if self.char_vocab:
            self.char_vocab = Vocab()
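
TsvTaggingFormat and TSVTaggingTransform illustrate the two usual extension points: file_to_inputs yields (words, tags) pairs from disk, and fit consumes them to build vocabularies. A hypothetical format class following the same pattern, for a corpus whose tokens look like word/TAG separated by spaces (the class and corpus layout are illustrative, not part of HanLP):

from abc import ABC

from hanlp.common.transform import Transform


class SlashTaggingFormat(Transform, ABC):
    """Illustrative only: reads lines such as 'HanLP/NR is/VC great/VA'."""

    def file_to_inputs(self, filepath: str, gold=True):
        with open(filepath, encoding='utf-8') as src:
            for line in src:
                pairs = [token.rsplit('/', 1) for token in line.split()]
                if pairs:
                    words, tags = zip(*pairs)
                    yield list(words), list(tags)
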
From hankcs/HanLP on GitHub: hanlp/components/parsers/conll.py
            cells = line.strip().split()
            if cells:
                cells[0] = int(cells[0])
                cells[6] = int(cells[6])
                for i, x in enumerate(cells):
                    if x == '_':
                        cells[i] = None
                sent.append(cells)
            else:
                yield sent
                sent = []
    if sent:
        yield sent


class CoNLLTransform(Transform):

    def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, lower=True, n_buckets=32,
                 n_tokens_per_batch=5000, min_freq=2,
                 **kwargs) -> None:
        super().__init__(**merge_locals_kwargs(locals(), kwargs))
        self.form_vocab: Vocab = None
        self.cpos_vocab: Vocab = None
        self.rel_vocab: Vocab = None
        self.puncts: tf.Tensor = None

    def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
        form, cpos = x
        return self.form_vocab.token_to_idx_table.lookup(form), self.cpos_vocab.token_to_idx_table.lookup(cpos)

    def y_to_idx(self, y):
        head, rel = y
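
The snippet is cut off inside y_to_idx. Judging from x_to_idx above, and since heads are already parsed to integers when the CoNLL file is read, a plausible completion (a guess, not the verbatim source) only maps the relation labels through the vocabulary:

    def y_to_idx(self, y):
        head, rel = y
        # heads are already integer positions; only relation labels need a vocab lookup
        return head, self.rel_vocab.token_to_idx_table.lookup(rel)
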
From hankcs/HanLP on GitHub: hanlp/components/ner.py
from abc import ABC
from typing import Union, Tuple, Iterable

import tensorflow as tf

from hanlp.common.component import KerasComponent
from hanlp.common.transform import Transform
from hanlp.components.taggers.ngram_conv.ngram_conv_tagger import NgramConvTagger
from hanlp.components.taggers.rnn_tagger import RNNTagger
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.metrics.chunking.sequence_labeling import get_entities, iobes_to_span
from hanlp.utils.util import merge_locals_kwargs


class IOBES_NamedEntityRecognizer(KerasComponent, ABC):

    def predict_batch(self, batch, inputs=None):
        for words, tags in zip(inputs, super().predict_batch(batch, inputs)):
            yield from iobes_to_span(words, tags)


class IOBES_Transform(Transform):

    def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
        for words, tags in zip(inputs, super().Y_to_outputs(Y, gold, inputs=inputs, X=X)):
            yield from iobes_to_span(words, tags)


class RNNNamedEntityRecognizer(RNNTagger, IOBES_NamedEntityRecognizer):

    def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
            rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, logger=None,
            loss: Union[tf.keras.losses.Loss, str] = None,
            optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=32,
            dev_batch_size=32, lr_decay_per_epoch=None,
            run_eagerly=False,
            verbose=True, **kwargs):
        # assert kwargs.get('run_eagerly', True), 'This component can only run eagerly'
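
IOBES_Transform overrides Y_to_outputs so that predicted IOBES tag sequences come back as entity spans via iobes_to_span, and RNNNamedEntityRecognizer wires that behaviour into an RNN tagger. A hypothetical training call that mirrors the fit signature above; the paths and save_dir are made up, and the no-argument constructor is an assumption:

from hanlp.components.ner import RNNNamedEntityRecognizer

recognizer = RNNNamedEntityRecognizer()
recognizer.fit('data/ner/train.tsv', 'data/ner/dev.tsv', save_dir='model/rnn-ner',
               embeddings=100, rnn_units=100, epochs=20, batch_size=32)
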
From hankcs/HanLP on GitHub: hanlp/transform/text.py
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-04 11:46
from typing import Union, Tuple, Iterable, Any

import tensorflow as tf

from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.metrics.chunking.sequence_labeling import get_entities
from hanlp.utils.file_read_backwards import FileReadBackwards
from hanlp.utils.io_util import read_tsv


class TextTransform(Transform):

    def __init__(self,
                 forward=True,
                 seq_len=10,
                 tokenizer='char',
                 config: SerializableDict = None, map_x=True, map_y=True, **kwargs) -> None:
        super().__init__(config, map_x, map_y, seq_len=seq_len, tokenizer=tokenizer, forward=forward, **kwargs)
        self.vocab: Vocab = None

    def tokenize_func(self):
        if self.config.tokenizer == 'char':
            return list
        elif self.config.tokenizer == 'whitespace':
            return lambda x: x.split()
        else:
            return lambda x: x.split(self.config.tokenizer)
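
tokenize_func just picks a splitting strategy based on config.tokenizer, so its behaviour is easy to check directly. The calls below assume TextTransform can be constructed with only the keyword arguments shown in its __init__:

from hanlp.transform.text import TextTransform

tokenize = TextTransform(tokenizer='whitespace').tokenize_func()
print(tokenize('the quick brown fox'))  # ['the', 'quick', 'brown', 'fox']

tokenize = TextTransform(tokenizer='char').tokenize_func()
print(tokenize('自然语言'))  # ['自', '然', '语', '言']
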
From hankcs/HanLP on GitHub: hanlp/transform/tsv.py
                if lower:
                    word_vocab.add(word.lower())
                else:
                    word_vocab.add(word)
                char_vocab.update(list(word))
                tag_vocab.add(tag)
    if lock_word_vocab:
        word_vocab.lock()
    if lock_char_vocab:
        char_vocab.lock()
    if lock_tag_vocab:
        tag_vocab.lock()
    return word_vocab, char_vocab, tag_vocab
From hankcs/HanLP on GitHub: hanlp/components/parsers/conll.py
                corpus = list(x for x in (samples() if callable(samples) else samples))
                n_tokens = 0
                batch = []
                for idx, sent in enumerate(corpus):
                    sent_len = self.len_of_sent(sent)
                    if n_tokens + sent_len > batch_size and batch:
                        yield from self.batched_inputs_to_batches(corpus, batch, shuffle)
                        n_tokens = 0
                        batch = []
                    n_tokens += sent_len
                    batch.append(idx)
                if batch:
                    yield from self.batched_inputs_to_batches(corpus, batch, shuffle)

        # next(generator())
        return Transform.samples_to_dataset(self, generator, False, False, 0, False, repeat, drop_remainder, prefetch,
                                            cache)
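
The generator above packs whole sentences into a batch until a token budget is exceeded, hands the collected indices to batched_inputs_to_batches, and finally delegates dataset construction to Transform.samples_to_dataset. The packing logic on its own, as a self-contained sketch independent of HanLP:

def batches_by_token_budget(sentences, max_tokens=5000):
    """Yield lists of sentence indices, starting a new batch once the token budget is exceeded."""
    batch, n_tokens = [], 0
    for idx, sent in enumerate(sentences):
        if n_tokens + len(sent) > max_tokens and batch:
            yield batch
            batch, n_tokens = [], 0
        batch.append(idx)
        n_tokens += len(sent)
    if batch:
        yield batch
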
From hankcs/HanLP on GitHub: hanlp/components/taggers/transformers/transformer_transform.py
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 15:14
from typing import Union, Tuple, List, Iterable

import tensorflow as tf

from hanlp.common.structure import SerializableDict
from hanlp.common.transform import Transform
from hanlp.common.vocab import Vocab
from hanlp.components.taggers.transformers.utils import convert_examples_to_features, config_is
from hanlp.transform.tsv import TsvTaggingFormat


class TransformerTransform(TsvTaggingFormat, Transform):
    def __init__(self,
                 tokenizer=None,
                 config: SerializableDict = None,
                 map_x=False, map_y=False, **kwargs) -> None:
        super().__init__(config, map_x, map_y, **kwargs)
        self._tokenizer = tokenizer
        self.tag_vocab: Vocab = None
        self.special_token_ids = None
        self.pad = '[PAD]'
        self.unk = '[UNK]'

    @property
    def max_seq_length(self):
        # -2 for special tokens [CLS] and [SEP]
        return self.config.get('max_seq_length', 128) - 2
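
Because the transformer wraps every sequence in [CLS] and [SEP], the usable token budget is two less than the configured max_seq_length. Assuming max_seq_length passed as a keyword argument ends up in config (as in the other transforms), the arithmetic can be checked directly; tokenizer=None is only for illustration:

from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform

transform = TransformerTransform(tokenizer=None, max_seq_length=128)
print(transform.max_seq_length)  # 126: two slots reserved for [CLS] and [SEP]
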
From hankcs/HanLP on GitHub: hanlp/transform/txt.py
    types = tuple([tf.string] * (vec_dim - 1)), tf.string
    defaults = tuple([char_vocab.pad_token] + [
        ngram_vocab.pad_token if ngram_vocab else char_vocab.pad_token] * ngram_size), (
                   tag_vocab.pad_token if tag_vocab.pad_token else tag_vocab.first_token)
    dataset = tf.data.Dataset.from_generator(generator, output_shapes=shapes, output_types=types)
    if shuffle:
        if isinstance(shuffle, bool):
            shuffle = 1024
        dataset = dataset.shuffle(shuffle)
    if repeat:
        dataset = dataset.repeat(repeat)
    dataset = dataset.padded_batch(batch_size, shapes, defaults).prefetch(prefetch)
    return dataset


class TxtFormat(Transform, ABC):
    def file_to_inputs(self, filepath: str, gold=True):
        filepath = get_resource(filepath)
        with open(filepath, encoding='utf-8') as src:
            for line in src:
                sentence = line.strip()
                if not sentence:
                    continue
                yield sentence


class TxtBMESFormat(TxtFormat, ABC):
    def file_to_inputs(self, filepath: str, gold=True):
        max_seq_length = self.config.get('max_seq_length', False)
        if max_seq_length:
            if 'transformer' in self.config:
                max_seq_length -= 2  # allow for [CLS] and [SEP]
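
TxtBMESFormat extends TxtFormat to turn the raw sentences into character sequences with B/M/E/S tags, typically for word segmentation (the snippet is cut off before that part). A HanLP-independent sketch of the BMES labelling scheme itself, for pre-segmented text:

def to_bmes(words):
    """Illustrative only: tag every character of pre-segmented words with B/M/E/S."""
    chars, tags = [], []
    for word in words:
        chars.extend(word)
        if len(word) == 1:
            tags.append('S')
        else:
            tags.extend(['B'] + ['M'] * (len(word) - 2) + ['E'])
    return chars, tags

print(to_bmes(['商品', '和', '服务']))
# (['商', '品', '和', '服', '务'], ['B', 'E', 'S', 'B', 'E'])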