How to use the medcat.utils.ml_utils.train_network function in medcat

To help you get started, we’ve selected a few MedCAT examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: CogStack / MedCAT / medcat / meta_cat.py (view on GitHub)
# Mainly for testing, not really used in a normal workflow
            f1s = []
            ps = []
            rs = []
            cls_reports = []
            for i in range(cv):
                # Reset the model
                if fine_tune:
                    self.load_model(model=model_name)
                else:
                    if model_name == 'lstm':
                        from medcat.utils.models import LSTM
                        nclasses = len(self.category_values)
                        self.model = LSTM(self.embeddings, self.pad_id, nclasses=nclasses)

                (_f1, _p, _r, _cls_report) = train_network(self.model, data, max_seq_len=(self.cntx_left+self.cntx_right+1), lr=lr, test_size=test_size,
                        pad_id=self.pad_id, batch_size=batch_size, nepochs=nepochs, device=self.device,
                        class_weights=class_weights, ignore_cpos=ignore_cpos, save_dir=self.save_dir, score_average=score_average)
                f1s.append(_f1)
                ps.append(_p)
                rs.append(_r)
                cls_reports.append(_cls_report)
            f1 = np.average(f1s)
            p = np.average(ps)
            r = np.average(rs)

            # Average cls reports
            cls_report = {}
            _cls_report = cls_reports[0]
            for label in _cls_report.keys():
                cls_report[label] = {}
                if type(_cls_report[label]) == dict:
Source: CogStack / MedCAT / medcat / meta_cat.py (view on GitHub)
if not fine_tune:
            if model_name == 'lstm':
                from medcat.utils.models import LSTM
                nclasses = len(self.category_values)
                bid = model_config.get("bid", True)
                num_layers = model_config.get("num_layers", 2)
                input_size = model_config.get("input_size", 300)
                hidden_size = model_config.get("hidden_size", 300)
                dropout = model_config.get("dropout", 0.5)

                self.model = LSTM(self.embeddings, self.pad_id, nclasses=nclasses, bid=bid, num_layers=num_layers,
                             input_size=input_size, hidden_size=hidden_size, dropout=dropout)

        if cv == 0:
            (f1, p, r, cls_report) = train_network(self.model, data, max_seq_len=(self.cntx_left+self.cntx_right+1), lr=lr, test_size=test_size,
                    pad_id=self.pad_id, batch_size=batch_size, nepochs=nepochs, device=self.device,
                    class_weights=class_weights, ignore_cpos=ignore_cpos, save_dir=self.save_dir,
                    auto_save_model=auto_save_model, score_average=score_average)
        elif cv > 0:
            # Mainly for testing, not really used in a normal workflow
            f1s = []
            ps = []
            rs = []
            cls_reports = []
            for i in range(cv):
                # Reset the model
                if fine_tune:
                    self.load_model(model=model_name)
                else:
                    if model_name == 'lstm':
                        from medcat.utils.models import LSTM