How to use the deeppavlov.core.common.registry.register function in deeppavlov

To help you get started, we've selected a few deeppavlov examples based on popular ways it is used in public projects.

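Every snippet on this page follows the same pattern: a class is decorated with @register('some_name'), which adds it to DeepPavlov's registry so that pipeline configs can refer to the class by that string. Here is a minimal sketch of the idiom; the name my_lowercaser and the class body are illustrative, not part of the library:

# Minimal sketch: registering a toy component under an illustrative name.
# Only the imports and the decorator come from DeepPavlov; the rest is made up.
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component


@register('my_lowercaser')
class MyLowercaser(Component):
    """Toy component that lowercases a batch of strings."""

    def __init__(self, **kwargs) -> None:
        pass

    def __call__(self, batch):
        return [utterance.lower() for utterance in batch]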

github deepmipt / DeepPavlov / deeppavlov / models / kbqa / answer_generation_rus.py View on GitHub
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np
import pickle
from deeppavlov.core.models.serializable import Serializable

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from pathlib import Path


@register('answer_generation_rus')
class AnswerGeneration(Component, Serializable):
    """
       Class for generation of answer using triplets with the entity
       in the question and relations predicted from the question by the
       relation prediction model.
       We search a triplet with the predicted relations
    """
    
    def __init__(self, load_path: str, *args, **kwargs) -> None:
        super().__init__(save_path=None, load_path=load_path)
        self.load()

    def load(self) -> None:
        load_path = Path(self.load_path).expanduser()
        with open(load_path, 'rb') as fl:
            self.q_to_name = pickle.load(fl)
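Once registered, the string name is what a pipeline config refers to. A hypothetical config entry for the component above, written as a Python dict (the load_path value is made up for illustration; older DeepPavlov configs use the key "name" instead of "class_name"):

# Hypothetical pipeline-config entry referencing the registered name.
component_config = {
    "class_name": "answer_generation_rus",
    "load_path": "{DOWNLOADS_PATH}/kbqa/q_to_name.pickle",  # made-up path
}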
github deepmipt / DeepPavlov / deeppavlov / models / morpho_tagger / common.py View on GitHub
>>> self.prettify(sent, tags)
                1	John	_	PROPN	_	Number=Sing	_	_	_	_
                2	really	_	ADV	_	_	_	_	_	_
                3	likes	_	VERB	_	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	_	_	_	_
                4	pizza	_	NOUN	_	Number=Sing	_	_	_	_
                5	.	_	PUNCT	_	_	_	_	_	_
        """
        answer = []
        for i, (word, tag) in enumerate(zip(tokens, tags)):
            answer.append(self.format_string.format(i + 1, word, *make_pos_and_tag(tag)))
        if self.return_string:
            answer = self.begin + self.sep.join(answer) + self.end
        return answer


@register('lemmatized_output_prettifier')
class LemmatizedOutputPrettifier(Component):
    """Class which prettifies morphological tagger output to 4-column
    or 10-column (Universal Dependencies) format.

    Args:
        format_mode: output format,
            in `basic` mode output data contains 4 columns (id, word, pos, features),
            in `conllu` or `ud` mode it contains 10 columns:
            id, word, lemma, pos, xpos, feats, head, deprel, deps, misc
            (see http://universaldependencies.org/format.html for details)
            Only id, word, tag and pos values are filled in the current version,
            other columns are filled with the `_` value.
        return_string: whether to return a list of strings or a single string
        begin: a string to append in the beginning
        end: a string to append in the end
        sep: separator between word analyses
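The prettifying loop shown earlier reduces to filling a per-word template. Below is a standalone sketch with a hypothetical 10-column format string and pre-split (pos, feats) pairs standing in for format_string and make_pos_and_tag; it reproduces the doctest output above:

# Standalone sketch of the formatting loop; the template is illustrative.
format_string = "{}\t{}\t_\t{}\t_\t{}\t_\t_\t_\t_"

tokens = ["John", "really", "likes"]
pos_and_feats = [("PROPN", "Number=Sing"), ("ADV", "_"), ("VERB", "Mood=Ind")]

lines = [format_string.format(i + 1, word, pos, feats)
         for i, (word, (pos, feats)) in enumerate(zip(tokens, pos_and_feats))]
print("\n".join(lines))
# 1	John	_	PROPN	_	Number=Sing	_	_	_	_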
github deepmipt / DeepPavlov / deeppavlov / models / classifiers / sklearn_classifiers.py View on GitHub
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from scipy.sparse import vstack
from scipy.sparse import csr_matrix


from deeppavlov.core.common.registry import register
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.models.estimator import Estimator
from deeppavlov.core.common.errors import ConfigError

log = get_logger(__name__)


@register("logistic_regression")
class LogReg(Estimator):
    """
    The class implements the Logistic Regression classifier from the sklearn library.

    Args:
        save_path (str): save path
        load_path (str): load path
        mode: train/infer trigger
        **kwargs: additional arguments

    Attributes:
        model: Logistic Regression Classifier class from sklearn
    """

    def __init__(self, penalty='l2', dual=False, tol=0.0001, C=1.0, fit_intercept=True, intercept_scaling=1,
                 class_weight=None, random_state=None, solver='liblinear', max_iter=100, multi_class='ovr',
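The truncated constructor above simply forwards these arguments to sklearn. For reference, here is the wrapped estimator used directly on toy data, independent of DeepPavlov:

# The underlying sklearn estimator with the same kind of parameters
# the wrapper forwards; the data is a toy example.
from sklearn.linear_model import LogisticRegression

X = [[0.0, 1.0], [1.0, 0.0], [0.9, 0.1], [0.1, 0.8]]
y = [0, 1, 1, 0]

clf = LogisticRegression(penalty='l2', C=1.0, solver='liblinear', max_iter=100)
clf.fit(X, y)
print(clf.predict([[0.2, 0.9]]))  # e.g. [0]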
github deepmipt / DeepPavlov / deeppavlov / dataset_readers / multiwoz_reader.py View on GitHub
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Union, Tuple
from logging import getLogger

from overrides import overrides

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader
from deeppavlov.core.data.utils import download_decompress, mark_done


log = getLogger()


@register('multiwoz_reader')
class MultiWOZDatasetReader(DatasetReader):
    """
    Reader for the MultiWOZ dataset, a multi-domain task-oriented dialogue corpus.
    """

    url = 'http://files.deeppavlov.ai/datasets/multiwoz.tar.gz'

    DATA_FILES = ['data.json', 'dialogue_acts.json', 'valListFile.json',
                  'testListFile.json',
                  'attraction_db.json', 'bus_db.json',
                  'hospital_db.json', 'hotel_db.json', 'police_db.json',
                  'restaurant_db.json', 'taxi_db.json', 'train_db.json']
    PREPROCESSED = ['data_prep.json']

    @classmethod
    @overrides
github deepmipt / DeepPavlov / deeppavlov / models / trackers / hcn_et.py View on GitHub
import copy
from enum import Enum

import numpy as np

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.inferable import Inferable

ENTITIES = {
    '<party_size>': None,
    '<location>': None,
    '<cuisine>': None,
    '<rest_type>': None,
}


@register('hcn_et')
class EntityTracker(Inferable):
    def __init__(self, entities=copy.deepcopy(ENTITIES)):
        self.entities = entities
        self.num_features = 4  # tracking 4 entities
        self.rating = None

        # constants
        self.party_sizes = ['1', '2', '3', '4', '5', '6', '7', '8', 'one', 'two', 'three',
                            'four', 'five', 'six', 'seven', 'eight']
        self.locations = ['bangkok', 'beijing', 'bombay', 'hanoi', 'paris', 'rome', 'london',
                          'madrid', 'seoul', 'tokyo']
        self.cuisines = ['british', 'cantonese', 'french', 'indian', 'italian', 'japanese',
                         'korean', 'spanish', 'thai', 'vietnamese']
        self.rest_types = ['cheap', 'expensive', 'moderate']

        self.EntType = Enum('Entity Type',
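The last, truncated line uses Python's functional Enum API. A generic standalone illustration of that idiom follows; the member names are invented, not the tracker's real ones:

# Functional Enum API, as used by the tracker above; names are illustrative.
from enum import Enum

EntType = Enum('EntType', ['PARTY_SIZE', 'LOCATION', 'CUISINE', 'REST_TYPE'])
print(EntType.CUISINE.name, EntType.CUISINE.value)  # CUISINE 3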
github deepmipt / DeepPavlov / deeppavlov / models / tokenizers / spacy_tokenizer.py View on GitHub
def _try_load_spacy_model(model_name: str, disable: Iterable[str] = ()):
    disable = set(disable)
    try:
        model = spacy.load(model_name, disable=disable)
    except OSError as e:
        try:
            model = __import__(model_name).load(disable=disable)
            if not isinstance(model, spacy.language.Language):
                raise RuntimeError(f'{model_name} is not a spacy model module')
        except Exception:
            raise e
    return model
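Using the helper is straightforward: it tries spacy.load first and falls back to importing the model as a module. For example, with the common small English model:

# Loads en_core_web_sm via spacy.load, falling back to a module import;
# disabling pipes such as the parser and NER speeds up plain tokenization.
nlp = _try_load_spacy_model('en_core_web_sm', disable=['parser', 'ner'])
doc = nlp('DeepPavlov wraps spaCy for tokenization.')
print([token.text for token in doc])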


@register('stream_spacy_tokenizer')
class StreamSpacyTokenizer(Component):
    """Tokenize or lemmatize a list of documents. Default spacy model is **en_core_web_sm**.
    Return a list of tokens or lemmas for a whole document.
    If called on a ``List[str]``, it performs the detokenizing procedure.

    Args:
        disable: spacy pipeline elements to disable in order to speed up processing
        stopwords: a list of stopwords that should be ignored during tokenizing/lemmatizing
         and ngrams creation
        batch_size: a batch size for spaCy buffering
        ngram_range: size of ngrams to create; only unigrams are returned by default
        lemmas: whether to perform lemmatizing or not
        lowercase: whether to perform lowercasing or not; is performed by default by :meth:`_tokenize`
         and :meth:`_lemmatize` methods
        alphas_only: whether to filter out non-alpha tokens; is performed by default by
         :meth:`_filter` method
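The registry also supports the reverse lookup. A hedged sketch, assuming the get_model helper exposed by deeppavlov.core.common.registry, which resolves a registered name back to its class:

# Sketch of reverse lookup in the registry; assumes registry.get_model.
from deeppavlov.core.common.registry import get_model

tokenizer_cls = get_model('stream_spacy_tokenizer')
print(tokenizer_cls.__name__)  # StreamSpacyTokenizer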
github deepmipt / DeepPavlov / deeppavlov / models / evolution / evolution_param_generator.py View on GitHub
# limitations under the License.

from copy import deepcopy
from logging import getLogger
from pathlib import Path
from typing import List, Any

import numpy as np

from deeppavlov.core.common.params_search import ParamsSearch
from deeppavlov.core.common.registry import register

log = getLogger(__name__)


@register('params_evolution')
class ParamsEvolution(ParamsSearch):
    """
    Class performs full evolutionary process (task scores -> max):
    1. initializes random population
    2. makes replacement to get next generation:
        a. selection according to obtained scores
        b. crossover (recombination) with given probability p_crossover
        c. mutation with a given mutation rate p_mutation (the probability to mutate)
            according to a given mutation power sigma
            (the current mutation power is sampled randomly from -sigma to sigma)

    Args:
        population_size: number of individuals per generation
        p_crossover: probability of crossover for the current replacement
        crossover_power: fraction of EVOLVING parents' parameters to exchange for offspring
        p_mutation: probability of mutation for current replacement
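As a rough illustration of step 2c above (not the library's actual code), mutating a single numeric parameter with power sigma could look like this:

# Illustrative only: with probability p_mutation, shift a parameter by a
# mutation power drawn from [-sigma, sigma].
import numpy as np

def mutate(value: float, p_mutation: float, sigma: float) -> float:
    if np.random.random() < p_mutation:
        return value + np.random.uniform(-sigma, sigma)
    return value

print(mutate(0.5, p_mutation=1.0, sigma=0.1))  # e.g. 0.4628...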
github deepmipt / DeepPavlov / deeppavlov / models / embedders / dict_embedder.py View on GitHub
"""

import numpy as np
from pathlib import Path
from overrides import overrides

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.models.serializable import Serializable
from typing import List

log = get_logger(__name__)


@register('dict_emb')
class DictEmbedder(Component, Serializable):
    def __init__(self, load_path, save_path=None, dim=100, **kwargs):
        super().__init__(save_path=save_path, load_path=load_path)
        self.tok2emb = {}
        self.dim = dim

        self.load()

    def save(self, *args, **kwargs):
        raise NotImplementedError

    def load(self):
        """
        Load dictionary of embeddings from file.
        """
github deepmipt / DeepPavlov / deeppavlov / dataset_readers / ubuntu_v2_mt_reader.py View on GitHub
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
from pathlib import Path
from typing import List, Tuple, Union, Dict

from deeppavlov.core.common.registry import register
from deeppavlov.core.data.dataset_reader import DatasetReader


@register('ubuntu_v2_mt_reader')
class UbuntuV2MTReader(DatasetReader):
    """The class to read the Ubuntu V2 dataset from csv files taking into account multi-turn dialogue ``context``.

    Please see https://github.com/rkadlec/ubuntu-ranking-dataset-creator.

    Args:
        data_path: A path to a folder with dataset csv files.
        num_context_turns: A maximum number of dialogue ``context`` turns.
        padding: "post" or "pre" context sentences padding
    """
    
    def read(self, data_path: str,
             num_context_turns: int = 1,
             padding: str = "post",
             *args, **kwargs) -> Dict[str, List[Tuple[List[str], int]]]:
        """Read the Ubuntu V2 dataset from csv files taking into account multi-turn dialogue ``context``.
github deepmipt / DeepPavlov / deeppavlov / models / chitchat_bot / adapter.py View on GitHub
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger

from overrides import overrides
from jsonschema import validate

from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component

log = getLogger(__name__)


@register("chitchat_bot_adapter")
class ChitChatBotAdapter(Component):
    """
    Example of input_data:
    [
        {
            "dialog": [
                {
                    "sender_id": "text",
                    "sender_class": "text",
                    "text": "text",
                    "system": False,
                    "time": "text",
                }
            ],
            "start_time": "text",
            "users": [