How to use the simpletransformers.classification.ClassificationModel class in simpletransformers

To help you get started, we’ve selected a few simpletransformers examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: ThilinaRajapakse / simpletransformers / examples / text_classification / binary_classification.py (View on GitHub)
from simpletransformers.classification import ClassificationModel
import pandas as pd


# Train and evaluation data must be Pandas DataFrames with two columns: the first
# column is the text (type str) and the second column is the label (type int).
train_data = [['Example sentence belonging to class 1', 1], ['Example sentence belonging to class 0', 0]]
train_df = pd.DataFrame(train_data)

eval_data = [['Example eval sentence belonging to class 1', 1], ['Example eval sentence belonging to class 0', 0]]
eval_df = pd.DataFrame(eval_data)

# Create a ClassificationModel.
# The args make the script safe to re-run:
# - 'overwrite_output_dir': simpletransformers refuses to train into an existing,
#   non-empty output directory unless this is set.
# - 'reprocess_input_data': rebuilds the cached features instead of reusing ones
#   left over from a previous run.
model = ClassificationModel('roberta', 'roberta-base', args={'reprocess_input_data': True, 'overwrite_output_dir': True})

# Train the model
model.train_model(train_df)

# Evaluate the model; returns the metrics, the raw model outputs and the
# misclassified examples.
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
GitHub: ThilinaRajapakse / simpletransformers / simpletransformers / experimental / classification / multi_label_classification_model.py (View on GitHub)
XLMForMultiLabelSequenceClassification,
                                                    DistilBertForMultiLabelSequenceClassification,
                                                    AlbertForMultiLabelSequenceClassification
                                                    )
from transformers import (
    WEIGHTS_NAME,
    BertConfig, BertTokenizer,
    XLNetConfig, XLNetTokenizer,
    XLMConfig, XLMTokenizer,
    RobertaConfig, RobertaTokenizer,
    DistilBertConfig, DistilBertTokenizer,
    AlbertConfig, AlbertTokenizer
)


class MultiLabelClassificationModel(ClassificationModel):
    def __init__(self, model_type, model_name, num_labels=None, pos_weight=None, args=None, use_cuda=True):
        """
        Initializes a MultiLabelClassification model.

        Args:
            model_type: The type of model (bert, roberta)
            model_name: Default Transformer model name or path to a directory containing Transformer model file (pytorch_nodel.bin).
            num_labels (optional): The number of labels or classes in the dataset.
            pos_weight (optional): A list of length num_labels containing the weights to assign to each label for loss calculation.
            args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
            use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
        """
        MODEL_CLASSES = {
            'bert':       (BertConfig, BertForMultiLabelSequenceClassification, BertTokenizer),
            'roberta':    (RobertaConfig, RobertaForMultiLabelSequenceClassification, RobertaTokenizer),
            'xlnet':      (XLNetConfig, XLNetForMultiLabelSequenceClassification, XLNetTokenizer),
GitHub: ThilinaRajapakse / simpletransformers / simpletransformers / classification / multi_label_classification_model.py (View on GitHub)
RobertaForMultiLabelSequenceClassification, 
                                                    XLNetForMultiLabelSequenceClassification,
                                                    XLMForMultiLabelSequenceClassification,
                                                    DistilBertForMultiLabelSequenceClassification
                                                    )
from transformers import (
    WEIGHTS_NAME,
    BertConfig, BertTokenizer,
    XLNetConfig, XLNetTokenizer,
    XLMConfig, XLMTokenizer,
    RobertaConfig, RobertaTokenizer,
    DistilBertConfig, DistilBertTokenizer
)


class MultiLabelClassificationModel(ClassificationModel):
    def __init__(self, model_type, model_name, num_labels=2, args=None, use_cuda=True):
        """
        Initializes a MultiLabelClassification model.

        Args:
            model_type: The type of model (bert, roberta)
            model_name: Default Transformer model name or path to a directory containing Transformer model file (pytorch_nodel.bin).
            num_labels (optional): The number of labels or classes in the dataset.
            args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
            use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
        """
        MODEL_CLASSES = {
            'bert': (BertConfig, BertForMultiLabelSequenceClassification, BertTokenizer),
            'roberta': (RobertaConfig, RobertaForMultiLabelSequenceClassification, RobertaTokenizer),
            'xlnet': (XLNetConfig, XLNetForMultiLabelSequenceClassification, XLNetTokenizer),
            'xlm': (XLMConfig, XLMForMultiLabelSequenceClassification, XLMTokenizer),
GitHub: ThilinaRajapakse / simpletransformers / examples / text_classification / multiclass_classification.py (View on GitHub)
from simpletransformers.classification import ClassificationModel
import pandas as pd


# Train and evaluation data must be Pandas DataFrames with at least two columns.
# If the DataFrame has a header, it should contain a 'text' column and a 'labels'
# column. If no header is present, the first column is the text (type str) and the
# second column is the label (type int).
train_data = [['Example sentence belonging to class 1', 1], ['Example sentence belonging to class 0', 0], ['Example sentence belonging to class 2', 2]]
train_df = pd.DataFrame(train_data)

eval_data = [['Example eval sentence belonging to class 1', 1], ['Example eval sentence belonging to class 0', 0], ['Example eval sentence belonging to class 2', 2]]
eval_df = pd.DataFrame(eval_data)

# Create a ClassificationModel. num_labels must match the number of distinct
# labels in the data (here 0, 1 and 2).
model = ClassificationModel('bert', 'bert-base-cased', num_labels=3, args={'reprocess_input_data': True, 'overwrite_output_dir': True})

# Train the model
model.train_model(train_df)

# Evaluate the model; returns the metrics, the raw model outputs and the
# misclassified examples.
result, model_outputs, wrong_predictions = model.eval_model(eval_df)

# Predict the class of unseen text.
predictions, raw_outputs = model.predict(["Some arbitrary sentence"])

simpletransformers

An easy-to-use wrapper library for the Transformers library.

Apache-2.0
Latest version published 5 months ago

Package Health Score

70 / 100
Full package analysis

Similar packages