How to use the `tslearn.metrics.cdist_dtw` function in tslearn

To help you get started, we’ve selected a few examples showing how `cdist_dtw` is commonly used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source (excerpt): rtavenar/tslearn — `tslearn/clustering.py` (view on GitHub)
def _assign(self, X, update_class_attributes=True):
        if self.metric_params is None:
            metric_params = {}
        else:
            metric_params = self.metric_params.copy()
        if "gamma_sdtw" in metric_params.keys():
            metric_params["gamma"] = metric_params["gamma_sdtw"]
            del metric_params["gamma_sdtw"]
        if "n_jobs" in metric_params.keys():
            del metric_params["n_jobs"]
        if self.metric == "euclidean":
            dists = cdist(X.reshape((X.shape[0], -1)),
                          self.cluster_centers_.reshape((self.n_clusters, -1)),
                          metric="euclidean")
        elif self.metric == "dtw":
            dists = cdist_dtw(X, self.cluster_centers_, n_jobs=self.n_jobs,
                              **metric_params)
        elif self.metric == "softdtw":
            dists = cdist_soft_dtw(X, self.cluster_centers_, **metric_params)
        else:
            raise ValueError("Incorrect metric: %s (should be one of 'dtw', "
                             "'softdtw', 'euclidean')" % self.metric)
        matched_labels = dists.argmin(axis=1)
        if update_class_attributes:
            self.labels_ = matched_labels
            _check_no_empty_cluster(self.labels_, self.n_clusters)
            if self.dtw_inertia and self.metric != "dtw":
                inertia_dists = cdist_dtw(X, self.cluster_centers_,
                                          n_jobs=self.n_jobs)
            else:
                inertia_dists = dists
            self.inertia_ = _compute_inertia(inertia_dists,
Source (excerpt): rtavenar/tslearn — `tslearn/clustering.py` (view on GitHub)
self.cluster_centers_.reshape((self.n_clusters, -1)),
                          metric="euclidean")
        elif self.metric == "dtw":
            dists = cdist_dtw(X, self.cluster_centers_, n_jobs=self.n_jobs,
                              **metric_params)
        elif self.metric == "softdtw":
            dists = cdist_soft_dtw(X, self.cluster_centers_, **metric_params)
        else:
            raise ValueError("Incorrect metric: %s (should be one of 'dtw', "
                             "'softdtw', 'euclidean')" % self.metric)
        matched_labels = dists.argmin(axis=1)
        if update_class_attributes:
            self.labels_ = matched_labels
            _check_no_empty_cluster(self.labels_, self.n_clusters)
            if self.dtw_inertia and self.metric != "dtw":
                inertia_dists = cdist_dtw(X, self.cluster_centers_,
                                          n_jobs=self.n_jobs)
            else:
                inertia_dists = dists
            self.inertia_ = _compute_inertia(inertia_dists,
                                             self.labels_,
                                             self._squared_inertia)
        return matched_labels
Source (excerpt): rtavenar/tslearn — `tslearn/neighbors.py` (view on GitHub)
"""
        if self.metric in VARIABLE_LENGTH_METRICS:
            self._ts_metric = self.metric
            self.metric = "precomputed"

            if self.metric_params is None:
                metric_params = {}
            else:
                metric_params = self.metric_params.copy()
                if "n_jobs" in metric_params.keys():
                    del metric_params["n_jobs"]
            check_is_fitted(self, '_ts_fit')
            X = check_array(X, allow_nd=True, force_all_finite=False)
            X = to_time_series_dataset(X)
            if self._ts_metric == "dtw":
                X_ = cdist_dtw(X, self._ts_fit, n_jobs=self.n_jobs,
                               **metric_params)
            elif self._ts_metric == "softdtw":
                X_ = cdist_soft_dtw(X, self._ts_fit, **metric_params)
            else:
                raise ValueError("Invalid metric recorded: %s" %
                                 self._ts_metric)
            pred = super(KNeighborsTimeSeriesClassifier,
                         self).predict_proba(X_)
            self.metric = self._ts_metric
            return pred
        else:
            check_is_fitted(self, '_X_fit')
            X = check_array(X, allow_nd=True)
            X = to_time_series_dataset(X)
            X_ = to_sklearn_dataset(X)
            X_ = check_dims(X_, self._X_fit, extend=False)
Source (excerpt): rtavenar/tslearn — `tslearn/neighbors.py` (view on GitHub)
metric_params = self.metric_params.copy()
            if "n_jobs" in metric_params.keys():
                del metric_params["n_jobs"]
        else:
            metric_params = {}
        self_neighbors = False
        if n_neighbors is None:
            n_neighbors = self.n_neighbors
        if X is None:
            X = self._X_fit
            self_neighbors = True
        if self.metric == "precomputed":
            full_dist_matrix = X
        else:
            parallelize = False
            if self.metric == "dtw" or self.metric == cdist_dtw:
                cdist_fun = cdist_dtw
                parallelize = True
            elif self.metric == "softdtw" or self.metric == cdist_soft_dtw:
                cdist_fun = cdist_soft_dtw
            elif self.metric in ["euclidean", "sqeuclidean", "cityblock"]:
                def cdist_fun(X, Xp):
                    return scipy_cdist(X.reshape((X.shape[0], -1)),
                                       Xp.reshape((Xp.shape[0], -1)),
                                       metric=self.metric)
            else:
                raise ValueError("Unrecognized time series metric string: %s "
                                 "(should be one of 'dtw', 'softdtw', "
                                 "'euclidean', 'sqeuclidean' "
                                 "or 'cityblock')" % self.metric)

            if X.ndim == 2:  # sklearn-format case