Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _update_centroids(self, X):
    """Recompute each cluster centroid from the samples assigned to it.

    For every cluster index ``k`` in ``range(self.n_clusters)``, the entry
    ``self.cluster_centers_[k]`` is overwritten in place with a barycenter of
    ``X[self.labels_ == k]``. The barycenter is chosen according to
    ``self.metric``:

    * ``"dtw"``: ``dtw_barycenter_averaging`` (DBA), initialised from the
      current centroid, with ``max_iter=self.max_iter_barycenter``.
    * ``"softdtw"``: ``softdtw_barycenter``, initialised from the current
      centroid, with ``max_iter=self.max_iter_barycenter`` and the entries
      of ``self.metric_params`` forwarded as keyword arguments.
    * any other value: ``euclidean_barycenter``.

    Parameters
    ----------
    X : array-like
        Dataset whose first axis indexes samples. It is masked with
        ``self.labels_ == k``, so it must be aligned with ``self.labels_``.
    """
    # Work on a copy so the caller's metric_params mapping is never mutated.
    if self.metric_params is None:
        metric_params = {}
    else:
        metric_params = self.metric_params.copy()
    # softdtw_barycenter takes the smoothing parameter as ``gamma``, so a
    # ``gamma_sdtw`` entry is renamed before being forwarded.
    if "gamma_sdtw" in metric_params:
        metric_params["gamma"] = metric_params.pop("gamma_sdtw")

    for k in range(self.n_clusters):
        members = X[self.labels_ == k]
        if self.metric == "dtw":
            # Both iterative barycenter methods share the same iteration
            # budget (max_iter_barycenter), as the soft-DTW branch does.
            self.cluster_centers_[k] = dtw_barycenter_averaging(
                X=members,
                barycenter_size=None,
                init_barycenter=self.cluster_centers_[k],
                max_iter=self.max_iter_barycenter,
                verbose=False)
        elif self.metric == "softdtw":
            # metric_params are currently forwarded only to soft-DTW.
            self.cluster_centers_[k] = softdtw_barycenter(
                X=members,
                max_iter=self.max_iter_barycenter,
                init=self.cluster_centers_[k],
                **metric_params)
        else:
            self.cluster_centers_[k] = euclidean_barycenter(X=members)
from tslearn.barycenters import (euclidean_barycenter,
                                 dtw_barycenter_averaging,
                                 softdtw_barycenter)
from tslearn.datasets import CachedDatasets

# Example: compare Euclidean, DBA and soft-DTW barycenters on one class of
# the "Trace" dataset, drawing one subplot row per barycenter.

# Fix NumPy's global seed so the example is reproducible.
numpy.random.seed(0)

# Load the cached "Trace" dataset and keep the training series of class 2.
X_train, y_train, X_test, y_test = CachedDatasets().load_dataset("Trace")
X = X_train[y_train == 2]

# Compute the barycenters in the same order as they are displayed.
euclidean_bar = euclidean_barycenter(X)
dba_bar = dtw_barycenter_averaging(X, max_iter=100, verbose=False)
sdtw_bar = softdtw_barycenter(X, gamma=1., max_iter=100)

panels = [
    ("Euclidean barycenter", euclidean_bar),
    ("DBA", dba_bar),
    # Raw string: "\g" is not a valid escape sequence (Deprecation/Syntax
    # warning); the resulting title text is unchanged.
    (r"Soft-DTW barycenter ($\gamma$=1.)", sdtw_bar),
]

plt.figure()
for row, (title, barycenter) in enumerate(panels, start=1):
    plt.subplot(len(panels), 1, row)
    # Individual series in translucent black, barycenter in bold red.
    for ts in X:
        plt.plot(ts.ravel(), "k-", alpha=.2)
    plt.plot(barycenter.ravel(), "r-", linewidth=2)
    plt.title(title)
plt.tight_layout()
plt.show()