How to use the batchflow.models.tf.resnet.ResNet.block function in batchflow

To help you get started, we’ve selected a few batchflow examples based on popular ways ResNet.block is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: analysiscenter/batchflow — batchflow/models/tf/vnet.py (View on GitHub)
inputs : tf.Tensor
            input tensor
        name : str
            scope name

        Returns
        -------
        tf.Tensor
        """
        kwargs = cls.fill_params('body', **kwargs)
        layout, kernel_size = cls.pop(['layout', 'kernel_size'], kwargs)
        x, inputs = inputs, None
        with tf.variable_scope(name):
            if downsample:
                x = conv_block(x, layout='cna', kernel_size=2, strides=2, name='downsample', **kwargs)
            x = ResNet.block(x, layout=layout, kernel_size=kernel_size, downsample=0, name='conv', **kwargs)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
Returns
        -------
        tf.Tensor
        """
        kwargs = cls.fill_params('body', **kwargs)
        layout, filters = cls.pop(['layout', 'filters'], kwargs)
        mask = cls.pop('mask', kwargs)
        trunk = cls.pop('trunk', kwargs)
        mask = {**kwargs, **mask}
        trunk = {**kwargs, **trunk}

        x, inputs = inputs, None
        with tf.variable_scope(name):
            for i, b in enumerate(layout):
                if b == 'r':
                    x = ResNet.block(x, filters=filters[i], name='resblock-%d' % i, **{**kwargs, 'downsample':True})
                else:
                    x = cls.attention(x, level=int(b), filters=filters[i], name='attention-%d' % i, **kwargs)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
Parameters
        ----------
        inputs : tf.Tensor
            input tensor
        level : int
            nested mask level
        name : str
            scope name

        Returns
        -------
        tf.Tensor
        """
        with tf.variable_scope(name):
            x, inputs = inputs, None
            x = ResNet.block(x, name='initial', **kwargs)

            t = cls.trunk(x, **kwargs)
            m = cls.mask((x, t), level=level, **kwargs)

            x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
                           **{**kwargs, 'filters': kwargs['filters']*4})
            x = tf.sigmoid(x, name='attention_map')

            x = (1 + x) * t

            x = ResNet.block(x, name='last', **kwargs)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/gcn.py (View on GitHub)
""" An ordinary ResNet block

        Parameters
        ----------
        inputs : tf.Tensor
            input tensor
        name : str
            scope name

        Returns
        -------
        tf.Tensor
        """
        kwargs = cls.fill_params('body/br', **kwargs)
        kwargs['filters'] = cls.num_channels(inputs, data_format=kwargs['data_format'])
        return ResNet.block(inputs, name=name, **kwargs)
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
nested mask level
        name : str
            scope name

        Returns
        -------
        tf.Tensor
        """
        kwargs = cls.fill_params('body/mask', **kwargs)
        upsample_args = cls.pop('upsample', kwargs)

        with tf.variable_scope(name):
            x, skip = inputs
            inputs = None
            x = conv_block(x, layout='p', name='pool', **kwargs)
            b = ResNet.block(x, name='resblock_1', **kwargs)
            c = ResNet.block(b, name='resblock_2', **kwargs)

            if level > 0:
                i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
                c = ResNet.block(c + i, name='resblock_3', **kwargs)

            x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
                             **upsample_args)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
name : str
            scope name

        Returns
        -------
        tf.Tensor
        """
        kwargs = cls.fill_params('body/mask', **kwargs)
        upsample_args = cls.pop('upsample', kwargs)

        with tf.variable_scope(name):
            x, skip = inputs
            inputs = None
            x = conv_block(x, layout='p', name='pool', **kwargs)
            b = ResNet.block(x, name='resblock_1', **kwargs)
            c = ResNet.block(b, name='resblock_2', **kwargs)

            if level > 0:
                i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
                c = ResNet.block(c + i, name='resblock_3', **kwargs)

            x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
                             **upsample_args)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
-------
        tf.Tensor
        """
        kwargs = cls.fill_params('body/mask', **kwargs)
        upsample_args = cls.pop('upsample', kwargs)

        with tf.variable_scope(name):
            x, skip = inputs
            inputs = None
            x = conv_block(x, layout='p', name='pool', **kwargs)
            b = ResNet.block(x, name='resblock_1', **kwargs)
            c = ResNet.block(b, name='resblock_2', **kwargs)

            if level > 0:
                i = cls.mask((b, b), level=level-1, name='submask-%d' % level, **kwargs)
                c = ResNet.block(c + i, name='resblock_3', **kwargs)

            x = cls.upsample(c, resize_to=skip, name='interpolation', data_format=kwargs['data_format'],
                             **upsample_args)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
tf.Tensor
        """
        with tf.variable_scope(name):
            x, inputs = inputs, None
            x = ResNet.block(x, name='initial', **kwargs)

            t = cls.trunk(x, **kwargs)
            m = cls.mask((x, t), level=level, **kwargs)

            x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
                           **{**kwargs, 'filters': kwargs['filters']*4})
            x = tf.sigmoid(x, name='attention_map')

            x = (1 + x) * t

            x = ResNet.block(x, name='last', **kwargs)
        return x
GitHub: analysiscenter/batchflow — batchflow/models/tf/resattention.py (View on GitHub)
tf.Tensor
        """
        with tf.variable_scope(name):
            x, inputs = inputs, None
            x = ResNet.block(x, name='initial', **kwargs)

            t = cls.trunk(x, **kwargs)
            m = cls.mask((x, t), level=level, **kwargs)

            x = conv_block(m, layout='nac nac', kernel_size=1, name='scale',
                           **{**kwargs, 'filters': kwargs['filters']*4})
            x = tf.sigmoid(x, name='attention_map')

            x = (1 + x) * t

            x = ResNet.block(x, name='last', **kwargs)
        return x