Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def evaluate(self, explain=True):
"""
This method is used to score our trained model.
"""
# Load UBM model
model_name = "ubm_{}.h5".format(self.NUM_GAUSSIANS)
ubm = sidekit.Mixture()
ubm.read(os.path.join(self.BASE_DIR, "ubm", model_name))
# Load TV matrix
filename = "tv_matrix_{}".format(self.NUM_GAUSSIANS)
outputPath = os.path.join(self.BASE_DIR, "ivector", filename)
fa = sidekit.FactorAnalyser(outputPath+".h5")
# Extract i-vectors from enrollment data
logging.info("Extracting i-vectors from enrollment data")
filename = 'enroll_stat_{}.h5'.format(self.NUM_GAUSSIANS)
enroll_stat = sidekit.StatServer.read(os.path.join(self.BASE_DIR, 'stat', filename))
enroll_iv = fa.extract_ivectors_single( ubm=ubm,
stat_server=enroll_stat,
uncertainty=False
)
# Extract i-vectors from test data
logging.info("Extracting i-vectors from test data")
filename = 'test_stat_{}.h5'.format(self.NUM_GAUSSIANS)
test_stat = sidekit.StatServer.read(os.path.join(self.BASE_DIR, 'stat', filename))
test_iv = fa.extract_ivectors_single(ubm=ubm,
stat_server=test_stat,
This method is used to train the Total Variability (TV) matrix
and save it into 'ivector' directory !!
"""
# Create status servers
self.__create_stats()
# Load UBM model
model_name = "ubm_{}.h5".format(self.NUM_GAUSSIANS)
ubm = sidekit.Mixture()
ubm.read(os.path.join(self.BASE_DIR, "ubm", model_name))
# Train TV matrix using FactorAnalyser
filename = "tv_matrix_{}".format(self.NUM_GAUSSIANS)
outputPath = os.path.join(self.BASE_DIR, "ivector", filename)
tv_filename = 'tv_stat_{}.h5'.format(self.NUM_GAUSSIANS)
fa = sidekit.FactorAnalyser()
fa.total_variability_single(os.path.join(self.BASE_DIR, "stat", tv_filename),
ubm,
tv_rank=self.TV_RANK,
nb_iter=self.TV_ITERATIONS,
min_div=True,
tv_init=None,
batch_size=self.BATCH_SIZE,
save_init=False,
output_file_name=outputPath
)
# tv = fa.F # TV matrix
# tv_mean = fa.mean # Mean vector
# tv_sigma = fa.Sigma # Residual covariance matrix
# Clear files produced at each iteration
filename_regex = "tv_matrix_{}_it-*.h5".format(self.NUM_GAUSSIANS)