How to use the snorkel.models.Candidate class in snorkel

To help you get started, we’ve selected a few snorkel examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github snorkel-team / snorkel / snorkel / contrib / brat / tools.py View on Github external
spans.append(tc)
                            except Exception as e:
                                print "BRAT candidate conversion error", len(doc.sentences), j
                                print e

                entity_types[class_type].append(spans)

        for i, class_type in enumerate(stable_labels_by_type):

            if class_type in self.subclasses:
                class_name = self.subclasses[class_type]
            else:
                class_name = self.subclasses[self._get_normed_rela_name(class_type)]

            if clear:
                self.session.query(Candidate).filter(Candidate.split == i).delete()

            candidate_args = {'split': i}
            for args in entity_types[class_type]:
                for j, arg_name in enumerate(class_name.__argnames__):
                    candidate_args[arg_name + '_id'] = args[j].id

                candidate = class_name(**candidate_args)
                self.session.add(candidate)

        self.session.commit()
github snorkel-team / snorkel / snorkel / contrib / brat / brat.py View on Github external
collection_path = "{}/{}".format(self.data_root, annotation_dir)
        if os.path.exists(collection_path) and not overwrite:
            msg = "Error! Collection at '{}' already exists. ".format(annotation_dir)
            msg += "Please set overwrite=True to erase all existing annotations.\n"
            sys.stderr.write(msg)
            return

        # remove existing annotations
        if os.path.exists(collection_path):
            shutil.rmtree(collection_path, ignore_errors=True)
            print("Removed existing collection at '{}'".format(annotation_dir))

        # create subquery based on candidate split
        if split != None:
            cid_query = self.session.query(Candidate.id).filter(Candidate.split == split).subquery()

        # generate all documents for this candidate set
        doc_ids = get_doc_ids_by_query(self.session, self.candidate_class, cid_query)
        documents = self.session.query(Document).filter(Document.id.in_(doc_ids)).all()

        # create collection on disk
        os.makedirs(collection_path)

        for doc in documents:
            text = doc_to_text(doc)
            outfpath = "{}/{}".format(collection_path, doc.name)
            with codecs.open(outfpath + ".txt","w", self.encoding, errors=errors) as fp:
                fp.write(text)
            with codecs.open(outfpath + ".ann","w", self.encoding, errors=errors) as fp:
                fp.write("")
github snorkel-team / snorkel / snorkel / annotations.py View on Github external
def load_matrix(matrix_class, annotation_key_class, annotation_class, session,
    split=0, cids_query=None, key_group=0, key_names=None, zero_one=False,
    load_as_array=False, coerce_int=True):
    """
    Returns the annotations corresponding to a split of candidates with N members
    and an AnnotationKey group with M distinct keys as an N x M CSR sparse matrix.
    """
    cid_query = cids_query or session.query(Candidate.id)\
                                     .filter(Candidate.split == split)
    cid_query = cid_query.order_by(Candidate.id)

    keys_query = session.query(annotation_key_class.id)
    keys_query = keys_query.filter(annotation_key_class.group == key_group)
    if key_names is not None:
        keys_query = keys_query.filter(annotation_key_class.name.in_(frozenset(key_names)))
    keys_query = keys_query.order_by(annotation_key_class.id)

    # First, we query to construct the row index map
    cid_to_row = {}
    row_to_cid = {}
    for cid, in cid_query.all():
        if cid not in cid_to_row:
            j = len(cid_to_row)

            # Create both mappings
github snorkel-team / snorkel / snorkel / candidates.py View on Github external
def clear(self, session, split, **kwargs):
        """Delete every Candidate row belonging to the given split."""
        split_filter = Candidate.split == split
        session.query(Candidate).filter(split_filter).delete()
github snorkel-team / snorkel / tutorials / workshop / lib / scoring.py View on Github external
def score(session, lf, split, gold, unlabled_as_neg=False):
    """Score a single labeling function against gold labels and print the result.

    Applies ``lf`` to every candidate in ``split`` (ordered by id), buckets the
    predictions into true/false positives/negatives w.r.t. ``gold``, and prints
    a summary via ``print_scores``.

    NOTE(review): ``gold`` is indexed as ``gold[i, 0]`` — presumably a label
    matrix whose rows align with candidates ordered by id; confirm at callers.
    """
    candidates = (session.query(Candidate)
                         .filter(Candidate.split == split)
                         .order_by(Candidate.id)
                         .all())

    tp, fp, tn, fn = [], [], [], []
    # (prediction, gold) -> bucket; pairs outside {-1, 1} x {-1, 1} are ignored,
    # matching the original if/elif chain.
    buckets = {(1, 1): tp, (1, -1): fp, (-1, -1): tn, (-1, 1): fn}

    for idx, cand in enumerate(candidates):
        pred = lf(cand)
        # Optionally treat abstentions (0) as negative predictions.
        if pred == 0 and unlabled_as_neg:
            pred = -1
        bucket = buckets.get((pred, gold[idx, 0]))
        if bucket is not None:
            bucket.append(cand)

    print_scores(len(tp), len(fp), len(tn), len(fn), title='LF Score')
github snorkel-team / snorkel / snorkel / lf_helpers.py View on Github external
def test_LF(session, lf, split, annotator_name):
    """
    Gets the accuracy of a single LF on a split of the candidates, w.r.t. annotator labels,
    and also returns the error buckets of the candidates.
    """
    candidates = session.query(Candidate).filter(Candidate.split == split).all()
    gold = load_gold_labels(session, annotator_name=annotator_name, split=split)
    # Map LF outputs in {-1, 0, 1} onto marginal probabilities {0.0, 0.5, 1.0}.
    marginals = np.array([(lf(c) + 1) * 0.5 for c in candidates])
    scorer = MentionScorer(candidates, gold)
    return scorer.score(marginals, set_unlabeled_as_neg=False, set_at_thresh_as_neg=False)
github HazyResearch / metal / metal / contrib / backends / wrapper.py View on Github external
self.session = SnorkelDataset.session

        self.class_type = candidate_subclass(*candidate_def)
        self.cardinality = len(candidate_def[-1])
        self.split = split
        self.max_seq_len = max_seq_len

        # create markup sequences and labels
        markers = [
            m.format(i)
            for i in range(self.cardinality)
            for m in ["~~[[{}", "{}]]~~"]
        ]
        self.X = (
            self.session.query(Candidate)
            .filter(Candidate.split == split)
            .order_by(Candidate.id)
            .all()
        )
        self.X = [self._mark_entities(x, markers) for x in self.X]

        # initalize vocabulary
        self.word_dict = (
            self._build_vocab(self.X, markers) if not word_dict else word_dict
        )
        if pretrained_word_dict:
            # include pretrained embedding terms
            self._include_pretrained_vocab(
                pretrained_word_dict, self.session.query(Candidate).all()
            )

        # initalize labels (from either LFs or gold labels)
github snorkel-team / snorkel / snorkel / contrib / brat / tools.py View on Github external
def export_project(self, output_dir, positive_only_labels=True):
        """Write one BRAT ``.ann`` annotation file per document.

        :param output_dir: path prefix for the generated files; it is joined to
            the document name with no separator, so include a trailing "/" if a
            directory is intended
        :param positive_only_labels: when True, only gold labels with
            ``value == 1`` are exported; otherwise all gold labels are kept
        :return: None
        """
        candidates = self.session.query(Candidate).all()
        documents = self.session.query(Document).all()

        # Index gold labels by candidate id, then optionally keep positives only.
        labels_by_cid = {label.candidate_id: label
                         for label in self.session.query(GoldLabel).all()}
        gold_labels = {cid: label for cid, label in labels_by_cid.items()
                       if not positive_only_labels or label.value == 1}

        doc_index = {doc.name: doc for doc in documents}
        cand_index = _group_by_document(candidates)
        # NOTE(review): computed but never read in this method — confirm intent.
        snorkel_types = {type(c): 1 for c in candidates}

        for doc_name in doc_index:
            if doc_name in cand_index:
                annotations = self._build_doc_annotations(cand_index[doc_name], gold_labels)
            else:
                annotations = []
            out_path = "{}{}".format(output_dir, doc_name)
            #  write .ann files
            with codecs.open(out_path + ".ann", 'w', self.encoding) as fp:
                fp.write("\n".join(annotations))
github snorkel-team / snorkel / snorkel / learning / tensorflow / rnn / rnn_base.py View on Github external
def _marginals_batch(self, test_candidates):
        """Get likelihood of tagged sequences represented by test_candidates
            @test_candidates: list of lists representing test sentence
        """
        # Raw Candidate objects still need feature preprocessing; anything else
        # is assumed to be already-preprocessed input and passed through as-is.
        if not isinstance(test_candidates[0], Candidate):
            features = test_candidates
        else:
            features, sentence_ends = self._preprocess_data(test_candidates, extend=False)
            self._check_max_sentence_length(sentence_ends)

        # Build the input tensor and evaluate the marginals op with dropout
        # disabled (keep_prob = 1.0) for inference.
        batch, batch_lens = self._make_tensor(features)
        feed = {
            self.sentences: batch,
            self.sentence_lengths: batch_lens,
            self.keep_prob: 1.0,
        }
        return self.session.run(self.marginals_op, feed)