How to use the cgt.shape function in cgt

To help you get started, we've selected a few examples of how `cgt.shape` is commonly used in public projects.

Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: joschu / cgt / test / test_multi_output.py — View on GitHub
def shp_apply(self, inputs):
        """Return one shape per output; both mirror the shape of ``inputs[0]``.

        ``cgt.shape`` is called once per output, so the two entries come from
        separate calls rather than sharing a single result.
        """
        src = inputs[0]
        return tuple(cgt.shape(src) for _ in range(2))
    def get_py_impl(self):
GitHub: joschu / cgt / cgt / distributions.py — View on GitHub
def sample(self, p, shape=None):
        """Draw samples that are true where ``cgt.rand(*shape) <= p``.

        Assuming ``cgt.rand`` draws uniform values in [0, 1) (not shown
        here -- confirm), each element is true with probability ``p``.

        Parameters
        ----------
        p : node, or any value accepted by ``core.as_node``
            Threshold / success probability.
        shape : sequence of sizes, optional
            Shape of the draw. When omitted (``None``), the symbolic shape
            of ``p`` is used. An explicit empty shape ``()`` is honored.

        Returns
        -------
        The symbolic comparison node ``cgt.rand(*shape) <= p``.
        """
        p = core.as_node(p)
        # Test the sentinel explicitly: ``shape or ...`` would also discard a
        # caller-supplied empty shape (falsy) and substitute shape(p).
        if shape is None:
            shape = cgt.shape(p)
        return cgt.rand(*shape) <= p
GitHub: joschu / cgt / cgt / ez.py — View on GitHub
def shp_apply(self, parents):
        """Return the output shape for this op.

        If a (truthy) ``self.shapefun`` is set, its result for ``parents`` is
        returned; otherwise this falls back to ``cgt.shape(self)``.
        """
        shapefun = self.shapefun
        if not shapefun:
            # NOTE(review): the fallback passes the op object itself, not a
            # node, to cgt.shape -- confirm this path is actually reachable.
            return cgt.shape(self)
        return shapefun(parents)
    def typ_apply(self, _parents):
GitHub: joschu / cgt / cgt / compilation.py — View on GitHub
for node in nodes_sorted:

        base = node # by default, 
        if node.is_argument():
            pass
        elif node.op.writes_to_input >= 0:
            base = node2memowner[node.parents[node.op.writes_to_input]]
        elif node in after2before:
            base = after2before[node]
        elif enable_inplace_opt and node.op.return_type == "byref": # TODO think about if we need any other conditions
            nodeshape = node.op.shp_apply(node.parents)
            for parent in node.parents:
                parentowner = node2memowner[parent]
                if (len(node2child[parent])==1
                        and nodeshape==cgt.shape(parent) # XXX not a very robust way to check
                        and node.dtype == parent.dtype
                        and _is_data_mutable(parentowner)
                        and parent not in outputs
                        ):
                    base = parentowner
                    break
        # TODO: add optimization for in-place incrementing
        node2memowner[node] = base

    return node2memowner
GitHub: joschu / cgt / cgt / api.py — View on GitHub
def _subtensor2(x, slis, y):

    dims2drop = []
    for (ax,sli) in enumerate(slis):
        if _is_int_scalar(sli):
            if y is None:
                dims2drop.append(ax)
            else:
                yshape = cgt.shape(y)
                yshape.insert(ax, 1)
                y = y.reshape(yshape)
            sli = slice(sli, sli + 1, 1)

        assert isinstance(sli.step, int) or sli.step is None
        step = 1 if sli.step is None else sli.step

        if step < 0:
            start = size(x, ax)-1 if sli.start is None else sli.start
            stop = -1 if sli.stop is None else sli.stop
        else:
            start = 0 if sli.start is None else sli.start
            stop = size(x, ax) if sli.stop is None else sli.stop

        assert isinstance(step, (int, core.Node)), "step argument of a slice should be an integer or a symbolic variable"
GitHub: joschu / cgt / cgt / nn_ops / max_pool_2d.py — View on GitHub
def shp_apply(self, inputs):
        """Report the output shape, which is identical to that of ``inputs[0]``."""
        source = inputs[0]
        return cgt.shape(source)
    def typ_apply(self, inputs):
GitHub: joschu / cgt / cgt / ez.py — View on GitHub
            shapefun = lambda *args : tuple(cgt.shape(x) for x in inputs)  )
        return cgt.core.unpack(core.Result(pbop, inputs + [output, goutput]))
GitHub: joschu / cgt / cgt / nn_ops / max_pool_2d.py — View on GitHub
def shp_apply(self, inputs):
        """Compute the output shape(s) of 2-D max pooling.

        The spatial output size follows Caffe's pooling convention:

            pooled_height_ = static_cast<int>(ceil(static_cast<float>(height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
            pooled_width_  = static_cast<int>(ceil(static_cast<float>(width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;

        Parameters
        ----------
        inputs : sequence
            ``inputs[0]`` is the 4-D input, unpacked here as
            (batch_size, channels, height, width).

        Returns
        -------
        tuple
            The same shape list twice; presumably one entry per output of
            this op -- confirm against ``typ_apply``.
        """
        info = self.info
        batch_size, channels, height, width = cgt.shape(inputs[0])
        # ceil((in + 2*pad - kernel) / stride) counts the strides that fit
        # after the first window; the trailing "+ 1" counts that first window.
        # Without it a 4x4 input, 2x2 kernel, stride 2 would give 1x1, not 2x2.
        pooled_height = cgt.ceil_divide(height + 2*info.pad_h - info.kernel_h, info.stride_h) + 1
        pooled_width = cgt.ceil_divide(width + 2*info.pad_w - info.kernel_w, info.stride_w) + 1
        outshape = [batch_size, channels, pooled_height, pooled_width]
        return (outshape, outshape)
    def typ_apply(self, inputs):