# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def main():
    """Compare MLE- and CRPS-based fits on synthetic censored data.

    The true parameters are mu = 0 and log-std = log(1), i.e. std = 1. For
    each requested censoring fraction this draws a dataset with
    ``generate_data(10000, ...)``, prints the mean of the returned ``C``,
    then calls ``fit`` twice from the same initial values (once with
    ``MLE_surv``, once with ``CRPS_surv``) and prints each estimate.
    """
    true_mu = 0.
    true_logstd = np.log(1.)
    print('True mu=%.03f std=%.03f' % (true_mu, np.exp(true_logstd)))

    censoring_fractions = (0.01, 0.05, 0.1, 0.5, 0.9, 0.95, 0.99)
    # Each scoring rule is paired with the exact report format it prints.
    scorers = (
        ('MLE mu=%.03f std=%.03f', MLE_surv),
        ('CRPS mu=%.03f std=%.03f', CRPS_surv),
    )
    for target_frac in censoring_fractions:
        Y, C = generate_data(10000, frac_cens=target_frac, mu=true_mu, logstd=true_logstd)
        # Print the empirical mean of C, not the requested fraction.
        # Presumably C is a 0/1 censoring indicator (confirm).
        print('==== Censoring fraction %.2f ====' % torch.mean(C))
        for report_fmt, scoring_rule in scorers:
            est_mu, est_std = fit(Y, C, scoring_rule, mu_init=0., logstd_init=0.)
            print(report_fmt % (est_mu, est_std))
def cv_n_estimators(X, y, C, cv_list, n_folds=10, distrib = HomoskedasticNormal, quadrant = False, s = CRPS_surv):
    """Cross-validate each candidate value in ``cv_list`` with ``n_folds``-fold KFold.

    NOTE(review): only the start of the body is visible in this chunk (see
    the note at the end). The description below covers what is visible.

    Args:
        X, y, C: Features, targets and a third per-sample array (presumably
            the censoring indicator; confirm). Each is indexed with the
            fold index arrays produced by ``kf.split(X)``, so each must
            support array indexing (e.g. NumPy arrays).
        cv_list: Candidate parameter values. Each is printed with ``%.2f``,
            so they must be numeric. Callers in this file pass
            ``n_estimators_list``.
        n_folds: Number of KFold splits.
        distrib, quadrant, s: Not referenced in the visible body. Callers
            pass ``HomoskedasticNormal`` or ``Normal`` as ``distrib``,
            ``quadrant=True`` for runs they label "Orthan Search" and
            ``quadrant=False`` for "Line Search", and ``CRPS_surv`` or
            ``MLE_surv`` as ``s``.

    Returns:
        Not visible here. Callers pass the result to ``np.argmax`` and
        index ``n_estimators_list`` with that position, so it is presumably
        one score per entry of ``cv_list`` (confirm).
    """
    kf = KFold(n_splits=n_folds)
    # NOTE(review): the return value is discarded. If KFold is scikit-learn's,
    # this call only returns n_splits, so the line looks removable (confirm).
    kf.get_n_splits(X)
    # Per-candidate result lists. They are not appended to in the visible portion.
    mse_list = []
    lkh_list = []
    for param in cv_list:
        print("Cross validating with parameter %.2f" % (param))
        # Running totals over folds, reset for each candidate. They are not
        # updated in the visible portion.
        mse = 0
        lkh = 0
        for train_index, val_index in kf.split(X):
            X_train_cv, X_val_cv = X[train_index], X[val_index]
            y_train_cv, y_val_cv = y[train_index], y[val_index]
            C_train_cv, C_val_cv = C[train_index], C[val_index]
            # Zero-argument factory for a depth-3 DecisionTreeRegressor with the
            # 'friedman_mse' criterion. Not called in the visible portion.
            base_learner = lambda: DecisionTreeRegressor(criterion='friedman_mse', \
                min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_depth=3)
            # NOTE(review): the visible body ends here. No model is fit with
            # base_learner/distrib/s, mse/lkh are never updated, and nothing is
            # returned, yet callers use the result with np.argmax. The rest of
            # this function appears to be missing from this chunk; confirm
            # against the full source.
# NOTE(review): within this chunk, optimal_het_q_crps is only assigned further
# down (in the heteroskedastic CRPS orthan-search section), so this print
# executes before that assignment. It looks like the stray tail of an earlier
# report. Unless the name is bound above this chunk, it raises NameError (confirm).
print(optimal_het_q_crps)
print("*"*6 + " Homoskedastic Distributions with CRPS [Orthan Search] " + "*"*6)
hom_q_crps = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = HomoskedasticNormal, quadrant = True, s = CRPS_surv)
optimal_hom_q_crps = n_estimators_list[np.argmax(hom_q_crps)]
print("--- Cross Validation MSE ---")
print(hom_q_crps)
print("--- Optimal parameter for Heteroskedastic Distributions with CRPS [Orthan Search] ---")
print(optimal_hom_q_crps)
print("*"*6 + " Heteroskedastic Distributions with CRPS [Line Search] " + "*"*6)
het_l_crps = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = Normal, quadrant = False, s=CRPS_surv)
optimal_het_l_crps = n_estimators_list[np.argmax(het_l_crps)]
print("--- Cross Validation MSE ---")
print(het_l_crps)
print("--- Optimal parameter for Heteroskedastic Distributions with CRPS [Line Search] ---")
print(optimal_het_l_crps)
print("*"*6 + " Homoskedastic Distributions with CRPS [Line Search] " + "*"*6)
hom_l_crps = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = HomoskedasticNormal, quadrant = False, s = CRPS_surv)
optimal_hom_l_crps = n_estimators_list[np.argmax(hom_l_crps)]
print("--- Cross Validation MSE ---")
print(hom_l_crps)
print("--- Optimal parameter for Heteroskedastic Distributions with CRPS [Line Search] ---")
print(optimal_het_l_mle)
print("*"*6 + " Homoskedastic Distributions with MLE [Line Search] " + "*"*6)
hom_l_mle = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = HomoskedasticNormal, quadrant = False, s = MLE_surv)
optimal_hom_l_mle = n_estimators_list[np.argmax(hom_l_mle)]
print("--- Cross Validation MSE ---")
print(hom_l_mle)
print("--- Optimal parameter for Heteroskedastic Distributions with MLE [Line Search] ---")
print(optimal_hom_l_mle)
print("*"*6 + " Heteroskedastic Distributions with CRPS [Orthan Search] " + "*"*6)
het_q_crps = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = Normal, quadrant = True, s=CRPS_surv)
optimal_het_q_crps = n_estimators_list[np.argmax(het_q_crps)]
print("--- Cross Validation MSE ---")
print(het_q_crps)
print("--- Optimal parameter for Heteroskedastic Distributions with CRPS [Orthan Search] ---")
print(optimal_het_q_crps)
print("*"*6 + " Homoskedastic Distributions with CRPS [Orthan Search] " + "*"*6)
hom_q_crps = cv_n_estimators(X_train, y_train, C_train, cv_list = n_estimators_list, \
n_folds=fold_num, distrib = HomoskedasticNormal, quadrant = True, s = CRPS_surv)
optimal_hom_q_crps = n_estimators_list[np.argmax(hom_q_crps)]
print("--- Cross Validation MSE ---")
print(hom_q_crps)
print("--- Optimal parameter for Heteroskedastic Distributions with CRPS [Orthan Search] ---")
print(optimal_hom_q_crps)