How to use the memcnn.models.additive.AdditiveBlockInverseFunction2.apply function in memcnn

To help you get started, we’ve selected a few memcnn examples, based on popular ways it is used in public projects.


Source: silvandeleemput/memcnn, memcnn/models/additive.py (GitHub)
def inverse(self, y):
    # Arguments passed to the .apply calls below: the block output,
    # both sub-modules, then the parameters of each sub-module.
    args = [y, self.Fm, self.Gm] + list(self.Fm.parameters()) + list(self.Gm.parameters())

    if self.implementation_bwd == 0:
        x = AdditiveBlockInverseFunction.apply(*args)
    elif self.implementation_bwd == 1:
        x = AdditiveBlockInverseFunction2.apply(*args)
    elif self.implementation_bwd == -1:
        # Direct computation: split the channels in half and undo the coupling.
        y1, y2 = torch.chunk(y, 2, dim=1)
        y1, y2 = y1.contiguous(), y2.contiguous()
        gmd = self.Gm.forward(y1)
        x2 = y2 - gmd              # recover the second half first
        fmd = self.Fm.forward(x2)
        x1 = y1 - fmd              # then the first half, using x2
        x = torch.cat([x1, x2], dim=1)
    else:
        raise NotImplementedError("Inverse for selected implementation ({}) not implemented..."
                                  .format(self.implementation_bwd))

    return x
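
The `implementation_bwd == -1` branch of `inverse` computes x2 = y2 - Gm(y1) and then x1 = y1 - Fm(x2). The sketch below is a stand-alone illustration of that arithmetic in plain PyTorch. It does not call memcnn. The forward pass is not shown in the excerpt, so the form used here (y1 = x1 + Fm(x2), y2 = x2 + Gm(y1)) is an assumption inferred from the inverse. The helper names `additive_forward` and `additive_inverse` and the two small convolutions are hypothetical and chosen only for the example.

```python
import torch
import torch.nn as nn


def additive_forward(x, Fm, Gm):
    # Assumed forward coupling, inferred from the inverse in the excerpt.
    x1, x2 = torch.chunk(x, 2, dim=1)
    y1 = x1 + Fm(x2)
    y2 = x2 + Gm(y1)
    return torch.cat([y1, y2], dim=1)


def additive_inverse(y, Fm, Gm):
    # Same steps as the implementation_bwd == -1 branch.
    y1, y2 = torch.chunk(y, 2, dim=1)
    x2 = y2 - Gm(y1)
    x1 = y1 - Fm(x2)
    return torch.cat([x1, x2], dim=1)


torch.manual_seed(0)
# Hypothetical sub-modules: each must map half the channels to the same shape.
Fm = nn.Conv2d(2, 2, kernel_size=3, padding=1).double()
Gm = nn.Conv2d(2, 2, kernel_size=3, padding=1).double()

x = torch.randn(1, 4, 8, 8, dtype=torch.float64)
with torch.no_grad():
    y = additive_forward(x, Fm, Gm)
    x_rec = additive_inverse(y, Fm, Gm)

# The input is recovered up to floating-point rounding.
print(torch.allclose(x, x_rec, atol=1e-10))
```

Fm and Gm do not need to be invertible themselves. The inverse only evaluates them in the forward direction and subtracts the result, so any pair of shape-preserving modules should work in this sketch.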