How to use the `chemprop.models.mpn.MPN` class in chemprop

To help you get started, we’ve selected a few chemprop examples based on popular ways `MPN` is used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: wengong-jin/chemprop — chemprop/models/jtnn.py (view on GitHub)
def __init__(self, args: Namespace):
        """Build the JTNN encoder from ``args``.

        Reads a vocabulary from ``args.vocab_path`` (one entry per line) and
        creates two ``MPN`` instances:

        - ``self.jtnn``: atom and bond input widths both set to
          ``args.hidden_size``, with ``graph_input=True``.
        - ``self.mpn``: default widths derived from ``args``.

        It also creates ``self.embedding``, an ``nn.Embedding`` that maps each
        vocabulary index to a ``hidden_size``-wide vector.

        :param args: Namespace that must provide ``vocab_path``, ``hidden_size``
            and ``depth``, plus whatever ``MPN`` reads from it.
        """
        super(JTNN, self).__init__()

        # Explicit UTF-8 so the vocabulary is decoded the same way regardless
        # of the platform's default locale encoding.
        # NOTE(review): a blank line in the file becomes an "" vocabulary entry.
        with open(args.vocab_path, encoding="utf-8") as f:
            self.vocab = Vocab([line.strip("\r\n ") for line in f])

        self.hidden_size = args.hidden_size
        self.depth = args.depth
        self.args = args

        # Input widths equal hidden_size, matching self.embedding's output
        # width. Presumably the caller feeds those embeddings in as a prebuilt
        # graph (graph_input=True) — confirm against MPN.forward.
        self.jtnn = MPN(args, atom_fdim=self.hidden_size, bond_fdim=self.hidden_size, graph_input=True)
        self.embedding = nn.Embedding(self.vocab.size(), self.hidden_size)
        self.mpn = MPN(args)
GitHub: wengong-jin/chemprop — chemprop/models/mpn.py (view on GitHub)
def __init__(self,
                 args: Namespace,
                 atom_fdim: int = None,
                 bond_fdim: int = None,
                 graph_input: bool = False,
                 params: Dict[str, nn.Parameter] = None):
        """Configure the feature widths and wrap an ``MPNEncoder``.

        :param args: Namespace forwarded to the feature-size helpers and to
            ``MPNEncoder``. ``args.atom_messages`` is read when ``bond_fdim``
            has to be derived.
        :param atom_fdim: atom input width. A falsy value (``None`` or ``0``)
            means "use ``get_atom_fdim(args)``".
        :param bond_fdim: bond input width. A falsy value means "use
            ``get_bond_fdim(args)``, widened by the atom width unless
            ``args.atom_messages`` is set".
        :param graph_input: stored unchanged on ``self.graph_input``.
        :param params: optional parameter dict passed through to
            ``MPNEncoder``.
        """
        super(MPN, self).__init__()
        self.args = args

        # Falsy widths fall back to values derived from args.
        self.atom_fdim = atom_fdim if atom_fdim else get_atom_fdim(args)

        if bond_fdim:
            self.bond_fdim = bond_fdim
        else:
            derived_width = get_bond_fdim(args)
            if not args.atom_messages:
                # Without atom messages the bond input width also includes the
                # atom width. Presumably atom features are concatenated onto
                # bond features — confirm in featurization.
                derived_width += self.atom_fdim
            self.bond_fdim = derived_width

        self.graph_input = graph_input
        self.encoder = MPNEncoder(self.args, self.atom_fdim, self.bond_fdim, params=params)
GitHub: wengong-jin/chemprop — chemprop/models/model.py (view on GitHub)
def create_encoder(self, args: Namespace, params: Dict[str, nn.Parameter] = None):
        """Create ``self.encoder`` according to ``args``.

        Encoder selection:

        - ``args.jtnn``: a ``JTNN`` encoder. Explicit ``params`` are rejected.
        - ``args.dataset_type == 'bert_pretraining'``: an ``MPN`` built with
          ``graph_input=True``.
        - otherwise: a plain ``MPN``.

        Afterwards, every encoder parameter is marked ``requires_grad = False``
        when ``args.freeze_encoder`` is set, and ``self.create_unfreeze_queue``
        is called when ``args.gradual_unfreezing`` is set.

        :param args: Namespace providing ``jtnn``, ``dataset_type``,
            ``freeze_encoder`` and ``gradual_unfreezing``.
        :param params: optional parameter dict forwarded to ``MPN``.
        :raises ValueError: if ``params`` is given together with ``args.jtnn``.
        """
        if args.jtnn:
            if params is not None:
                raise ValueError('Setting parameters not yet supported for JTNN')
            self.encoder = JTNN(args)
        elif args.dataset_type == 'bert_pretraining':
            self.encoder = MPN(args, graph_input=True, params=params)
        else:
            self.encoder = MPN(args, params=params)

        if args.freeze_encoder:
            # Autograd will not compute gradients for these tensors.
            for param in self.encoder.parameters():
                param.requires_grad = False

        if args.gradual_unfreezing:
            self.create_unfreeze_queue(args)
GitHub: wengong-jin/chemprop — chemprop/models/moe.py (view on GitHub)
def __init__(self, args):
        super(MOE, self).__init__()
        self.args = args
        self.num_sources = args.num_sources
        self.classifiers = nn.ModuleList([Classifier(args) for _ in range(args.num_sources)])
        self.encoder = MPN(args)
        self.mmd = MMD(args)
        self.Us = nn.ParameterList(
            [nn.Parameter(torch.zeros((args.hidden_size, args.m_rank)), requires_grad=True) for _ in
             range(args.num_sources)])
        # note zeros are replaced during initialization later
        if args.dataset_type == 'regression':
            self.mtl_criterion = nn.MSELoss(reduction='none')
            self.moe_criterion = nn.MSELoss(reduction='none')
        elif args.dataset_type == 'classification':  # this half untested
            self.mtl_criterion = nn.BCELoss(reduction='none')
            self.moe_criterion = nn.BCELoss(reduction='none')
        self.entropy_criterion = HLoss()
        self.lambda_moe = args.lambda_moe
        self.lambda_critic = args.lambda_critic
        self.lambda_entropy = args.lambda_entropy
GitHub: wengong-jin/chemprop — chemprop/models/model.py (view on GitHub)
def create_encoder(self, args: Namespace, params: Dict[str, nn.Parameter] = None):
        """Create ``self.encoder`` according to ``args``.

        Encoder selection:

        - ``args.jtnn``: a ``JTNN`` encoder. Explicit ``params`` are rejected.
        - ``args.dataset_type == 'bert_pretraining'``: an ``MPN`` built with
          ``graph_input=True``.
        - otherwise: a plain ``MPN``.

        Afterwards, every encoder parameter is marked ``requires_grad = False``
        when ``args.freeze_encoder`` is set, and ``self.create_unfreeze_queue``
        is called when ``args.gradual_unfreezing`` is set.

        :param args: Namespace providing ``jtnn``, ``dataset_type``,
            ``freeze_encoder`` and ``gradual_unfreezing``.
        :param params: optional parameter dict forwarded to ``MPN``.
        :raises ValueError: if ``params`` is given together with ``args.jtnn``.
        """
        if args.jtnn:
            if params is not None:
                raise ValueError('Setting parameters not yet supported for JTNN')
            self.encoder = JTNN(args)
        elif args.dataset_type == 'bert_pretraining':
            self.encoder = MPN(args, graph_input=True, params=params)
        else:
            self.encoder = MPN(args, params=params)

        if args.freeze_encoder:
            # Autograd will not compute gradients for these tensors.
            for param in self.encoder.parameters():
                param.requires_grad = False

        if args.gradual_unfreezing:
            self.create_unfreeze_queue(args)
GitHub: wengong-jin/chemprop — chemprop/models/jtnn.py (view on GitHub)
def __init__(self, args: Namespace):
        """Build the JTNN encoder from ``args``.

        Reads a vocabulary from ``args.vocab_path`` (one entry per line) and
        creates two ``MPN`` instances:

        - ``self.jtnn``: atom and bond input widths both set to
          ``args.hidden_size``, with ``graph_input=True``.
        - ``self.mpn``: default widths derived from ``args``.

        It also creates ``self.embedding``, an ``nn.Embedding`` that maps each
        vocabulary index to a ``hidden_size``-wide vector.

        :param args: Namespace that must provide ``vocab_path``, ``hidden_size``
            and ``depth``, plus whatever ``MPN`` reads from it.
        """
        super(JTNN, self).__init__()

        # Explicit UTF-8 so the vocabulary is decoded the same way regardless
        # of the platform's default locale encoding.
        # NOTE(review): a blank line in the file becomes an "" vocabulary entry.
        with open(args.vocab_path, encoding="utf-8") as f:
            self.vocab = Vocab([line.strip("\r\n ") for line in f])

        self.hidden_size = args.hidden_size
        self.depth = args.depth
        self.args = args

        # Input widths equal hidden_size, matching self.embedding's output
        # width. Presumably the caller feeds those embeddings in as a prebuilt
        # graph (graph_input=True) — confirm against MPN.forward.
        self.jtnn = MPN(args, atom_fdim=self.hidden_size, bond_fdim=self.hidden_size, graph_input=True)
        self.embedding = nn.Embedding(self.vocab.size(), self.hidden_size)
        self.mpn = MPN(args)