How to use the memcnn.models.affine.AffineBlockFunction2.apply function in memcnn

To help you get started, we’ve selected a few memcnn examples, based on popular ways it is used in public projects.


Source: silvandeleemput/memcnn on GitHub, file memcnn/models/affine.py
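def forward(self, x):
    # Pack the input, both sub-networks, and all of their parameters into one
    # argument list. The custom autograd Functions then receive the parameters
    # as explicit inputs and can return gradients for them.
    args = [x, self.Fm, self.Gm] + list(self.Fm.parameters()) + list(self.Gm.parameters())

    if self.implementation_fwd == 0:
        # Custom torch.autograd.Function, first variant.
        out = AffineBlockFunction.apply(*args)
    elif self.implementation_fwd == 1:
        # Custom torch.autograd.Function, second variant.
        out = AffineBlockFunction2.apply(*args)
    elif self.implementation_fwd == -1:
        # Naive reference path built from plain PyTorch ops; autograd keeps
        # the intermediate activations as usual.
        # Split the input into two halves along the channel dimension.
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1, x2 = x1.contiguous(), x2.contiguous()
        # Fm returns a (scale, shift) pair computed from x2.
        fmr1, fmr2 = self.Fm.forward(x2)
        y1 = (x1 * fmr1) + fmr2
        # Gm returns a (scale, shift) pair computed from the new y1.
        gmr1, gmr2 = self.Gm.forward(y1)
        y2 = (x2 * gmr1) + gmr2
        out = torch.cat([y1, y2], dim=1)
    else:
        raise NotImplementedError("Selected implementation ({}) not implemented..."
                                  .format(self.implementation_fwd))

    return out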
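The `-1` branch spells out the affine coupling computation that the two autograd Functions also perform. The sketch below is a minimal illustration, not memcnn code. It assumes NumPy in place of PyTorch and uses hypothetical functions `fm` and `gm` as stand-ins for `self.Fm` and `self.Gm`, each returning a (scale, shift) pair. It also adds a hypothetical `affine_inverse` to show why the block is invertible: the inputs can be recomputed from the outputs, which is the property memcnn's memory-saving design depends on.

```python
import numpy as np

# Hypothetical stand-ins for Fm and Gm (not memcnn's modules). Each returns a
# (scale, shift) pair, mirroring `fmr1, fmr2 = self.Fm.forward(x2)`.
# Scales are kept strictly positive via exp so the inverse can divide by them.
def fm(x):
    return np.exp(0.5 * np.tanh(x)), 0.1 * x

def gm(x):
    return np.exp(-0.3 * np.tanh(x)), 0.05 * x ** 2

def affine_forward(x):
    # Split along the channel axis (axis 1), like torch.chunk(x, 2, dim=1).
    # np.split requires an even channel count here.
    x1, x2 = np.split(x, 2, axis=1)
    s1, t1 = fm(x2)
    y1 = x1 * s1 + t1
    s2, t2 = gm(y1)
    y2 = x2 * s2 + t2
    return np.concatenate([y1, y2], axis=1)

def affine_inverse(y):
    # Hypothetical inverse: undo the two steps in reverse order.
    y1, y2 = np.split(y, 2, axis=1)
    # y1 is known, so Gm(y1) can be recomputed to recover x2.
    s2, t2 = gm(y1)
    x2 = (y2 - t2) / s2
    # With x2 recovered, Fm(x2) can be recomputed to recover x1.
    s1, t1 = fm(x2)
    x1 = (y1 - t1) / s1
    return np.concatenate([x1, x2], axis=1)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6, 3))  # (batch, channels, features)
y = affine_forward(x)
print(np.allclose(affine_inverse(y), x))  # True
```

Because each half is only ever scaled and shifted by values computed from the *other* half, the inverse never needs to invert `fm` or `gm` themselves, so these sub-networks can be arbitrary functions.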