How to use the gluonts.core.component.validated function in gluonts

To help you get started, we’ve selected a few gluonts examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github awslabs / gluon-ts / src / gluonts / distribution / dirichlet.py View on Github external
    @validated()
    def __init__(self, dim: int) -> None:
        """Configure a Dirichlet output of the given event dimension.

        Parameters
        ----------
        dim
            Event dimension; must be strictly greater than one.
        """
        assert dim > 1, "Dimension should be larger than one."
        # The Dirichlet is parameterized by a single `alpha` vector of
        # length `dim`.
        self.args_dim = {"alpha": dim}
        self.dim = dim
        self.distr_cls = Dirichlet
        self.mask = None
github awslabs / gluon-ts / src / gluonts / shell / serve.py View on Github external
    @validated()
    def __init__(self, base_address: str) -> None:
        """Remember the base address for later use.

        Parameters
        ----------
        base_address
            Address string this object is anchored to.
        """
        self.base_address = base_address
github awslabs / gluon-ts / src / gluonts / transform / _base.py View on Github external
    @validated()
    def __init__(self, trans: List[Transformation]) -> None:
        """Build a flat sequence of transformations.

        Directly nested ``Chain`` instances are spliced in member-by-member
        so the stored list holds no immediate sub-chains.
        """
        flattened: List[Transformation] = []
        for step in trans:
            # A nested chain contributes its members; anything else is
            # kept as a single step.
            if isinstance(step, Chain):
                flattened.extend(step.transformations)
            else:
                flattened.append(step)
        self.transformations = flattened
github awslabs / gluon-ts / src / gluonts / model / deepstate / issm.py View on Github external
    @validated()
    def __init__(
        self,
        seasonal_issms: List[SeasonalityISSM],
        add_trend: bool = DEFAULT_ADD_TREND,
    ) -> None:
        """Combine seasonal ISSMs with one non-seasonal component.

        Parameters
        ----------
        seasonal_issms
            Seasonal state-space components to compose.
        add_trend
            Anything other than ``False`` selects a level-plus-trend
            non-seasonal component; ``False`` selects level-only.
        """
        super().__init__()
        self.seasonal_issms = seasonal_issms
        # Identity comparison against False (not truthiness), matching
        # the original selection rule exactly.
        if add_trend is False:
            self.nonseasonal_issm = LevelISSM()
        else:
            self.nonseasonal_issm = LevelTrendISSM()
github awslabs / gluon-ts / src / gluonts / model / forecast.py View on Github external
    @validated()
    def __init__(
        self,
        distribution: Distribution,
        start_date,
        freq,
        item_id: Optional[str] = None,
        info: Optional[Dict] = None,
    ):
        """Wrap a parametric ``Distribution`` as a forecast object.

        Parameters
        ----------
        distribution
            The predictive distribution.
        start_date
            Start of the forecast window (untyped here).
        freq
            Forecast frequency (untyped here).
        item_id
            Optional identifier for the forecasted item.
        info
            Optional free-form metadata.
        """
        # NOTE(review): start_date and freq are accepted but not stored
        # in this excerpt — confirm against the full source.
        self.distribution = distribution
        # Overall shape is batch shape followed by event shape; the
        # leading axis is read as the prediction length.
        batch_shape = distribution.batch_shape
        event_shape = distribution.event_shape
        self.shape = batch_shape + event_shape
        self.prediction_length = self.shape[0]
        self.item_id = item_id
        self.info = info
github awslabs / gluon-ts / src / gluonts / model / deepstate / _network.py View on Github external
    @validated()
    def __init__(self, num_parallel_samples: int, *args, **kwargs) -> None:
        """Record the number of sample paths to draw in parallel.

        All remaining positional and keyword arguments are forwarded
        unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.num_parallel_samples = num_parallel_samples
github awslabs / gluon-ts / src / gluonts / block / feature.py View on Github external
    @validated()
    def __init__(
        self,
        T: int,
        use_static_cat: bool = False,
        use_static_real: bool = False,
        use_dynamic_cat: bool = False,
        use_dynamic_real: bool = False,
        embed_static: Optional[FeatureEmbedder] = None,
        embed_dynamic: Optional[FeatureEmbedder] = None,
        dtype: DType = np.float32,
        **kwargs,
    ) -> None:
        """Constructor for a feature-handling block.

        The flag names suggest selecting among static/dynamic and
        categorical/real features — confirm against the full source.

        NOTE(review): this excerpt ends immediately after the ``T``
        check; the remainder of the constructor is not visible here.
        """
        super().__init__(**kwargs)

        # T is the (time) length parameter; it must be strictly positive.
        assert T > 0, "The value of `T` should be > 0"
github awslabs / gluon-ts / src / gluonts / model / prophet / _predictor.py View on Github external
    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        prophet_params: Optional[Dict] = None,
        init_model: Callable = lambda m: m,
    ) -> None:
        """Predictor constructor backed by the Prophet library.

        NOTE(review): this excerpt is truncated mid-assert; the assertion
        message and the rest of the constructor are not visible here.
        """
        super().__init__(prediction_length, freq)

        # Fail fast if the optional Prophet dependency is absent.
        if not PROPHET_IS_INSTALLED:
            raise ImportError(USAGE_MESSAGE)

        # Treat a missing parameter dict as empty; using None as the
        # default avoids sharing a mutable default across calls.
        if prophet_params is None:
            prophet_params = {}

        # "uncertainty_samples" must not be supplied by the caller
        # (message truncated in this excerpt).
        assert "uncertainty_samples" not in prophet_params, (
github awslabs / gluon-ts / src / gluonts / model / transformer / trans_decoder.py View on Github external
    @validated()
    def __init__(self, decoder_length: int, config: Dict, **kwargs) -> None:
        """Transformer decoder constructor.

        NOTE(review): this excerpt is truncated mid-statement at the
        ``MultiHeadSelfAttention`` call; the remainder is not visible.
        """
        super().__init__(**kwargs)

        self.decoder_length = decoder_length
        # Cache dict, populated elsewhere — its use is not visible in
        # this excerpt.
        self.cache = {}

        # Child blocks are created inside this block's name scope
        # (presumably Gluon's `name_scope` — confirm against the full
        # source).
        with self.name_scope():
            self.enc_input_layer = InputLayer(model_size=config["model_dim"])

            # Processing block applied before self-attention, driven by
            # config["pre_seq"] and the configured dropout rate.
            self.dec_pre_self_att = TransformerProcessBlock(
                sequence=config["pre_seq"],
                dropout=config["dropout_rate"],
                prefix="pretransformerprocessblock_",
            )
            self.dec_self_att = MultiHeadSelfAttention(