How to use the seaborn.color_palette function in seaborn

To help you get started, we’ve selected a few seaborn examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github YosefLab / scVI / tests / notebooks / utils / gimvi_tutorial.py View on Github external
def plot_umap(trainer):
    """Draw UMAP projections of the two latent spaces produced by ``trainer``.

    The two latents returned by ``trainer.get_latent()`` are concatenated,
    embedded jointly with UMAP and split back apart by row count.  Panel 1 of
    a 1x3 figure overlays both embeddings; panel 2 colours the first
    embedding by the labels of the first gene dataset.
    """
    seq_latent, fish_latent = trainer.get_latent()
    stacked = np.concatenate([seq_latent, fish_latent])
    embedding = umap.UMAP().fit_transform(stacked)

    # Rows of the joint embedding keep the concatenation order.
    n_seq = seq_latent.shape[0]
    seq_2d, fish_2d = embedding[:n_seq], embedding[n_seq:]

    # Unpacking requires exactly two datasets; the second is not used below.
    seq_data, _fish_data = [p.gene_dataset for p in trainer.all_dataset]

    palette = sns.color_palette(n_colors=30)
    plt.figure(figsize=(25, 10))

    overview_ax = plt.subplot(1, 3, 1)
    for points, colour, name in ((seq_2d, "r", "seq"), (fish_2d, "b", "osm")):
        overview_ax.scatter(*points.T, color=colour, label=name, alpha=0.5, s=0.5)
    overview_ax.legend()

    typed_ax = plt.subplot(1, 3, 2)
    seq_labels = seq_data.labels.ravel()
    for idx, cell_type in enumerate(seq_data.cell_types):
        mask = seq_labels == idx
        typed_ax.scatter(
            *seq_2d[mask].T,
            color=palette[idx],
            label=cell_type[:12],  # legend text is cut to 12 characters
            alpha=0.5,
            s=5
        )
github joeybose / FloRL / pytorch-soft-actor-critic / plots / comet_plot.py View on Github external
y_label='Rewards'):

    fig = plt.figure(figsize=(12, 8))
    ax = plt.subplot()
    for label in (ax.get_xticklabels()):
        label.set_fontname('Arial')
        label.set_fontsize(28)
    for label in (ax.get_yticklabels()):
        label.set_fontname('Arial')
        label.set_fontsize(28)
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    ax.xaxis.get_offset_text().set_fontsize(20)
    axis_font = {'fontname':'Arial', 'size':'32'}

    # get a list of colors here.
    colors = sns.color_palette('colorblind', n_colors=len(list_of_data))
    #colors = sns.color_palette('cubehelix', n_colors=len(list_of_data))
    rewards_smoothed = []

    for data, label, color in zip(list_of_data, labels, colors):
        episodes = np.arange(data.shape[0])
        smoothed_data = pd.DataFrame(data).rolling(smoothing_window, min_periods=smoothing_window).mean()

        rewards_smoothed.append(smoothed_data)
        data_mean = smoothed_data.mean(axis=1)
        data_std = smoothed_data.std(axis=1)
        ax.fill_between(episodes,  data_mean + data_std, data_mean - data_std, alpha=0.3,
                        edgecolor=color, facecolor=color)
        plt.plot(episodes, data_mean, color=color, linewidth=1.5,  label=label)

    ax.legend(loc='lower right', prop={'size' : 26})
    ax.set_xlabel(x_label,**axis_font)
github dhiana / pretty_tsne / pretty_tsne.py View on Github external
def scatter(x):
    """Scatter the first two columns of ``x`` on a square, axis-free figure.

    Returns the figure, its axes and the object returned by ``ax.scatter``.
    """
    # Ten-colour "hls" palette from seaborn as an array (not used below).
    hls_colors = np.array(sns.color_palette("hls", 10))

    fig = plt.figure(figsize=(8, 8))
    axes = plt.subplot(aspect='equal')
    points = axes.scatter(x[:, 0], x[:, 1], lw=0, s=40)

    # Fixed view window, then switch the axes off and apply the 'tight' mode.
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    for mode in ('off', 'tight'):
        axes.axis(mode)

    return fig, axes, points
github vahndi / SyNaPSE / modules / dataframe_figures.py View on Github external
second object-dtype column.

    Inputs
    ------    
    dataframe:
        a long-form dataframe with 2 numeric columns and 2 object columns
    '''
    fig = Figure()
    ax = fig.add_subplot(111)
    categorical = dataframe.select_dtypes([object])
    numeric = dataframe.select_dtypes(['number'])
    colour_col = categorical.columns[0]
    marker_col = categorical.columns[1]
    colour_labels = sorted(categorical[colour_col].unique())
    marker_labels = sorted(categorical[marker_col].unique())
    palette = sns.color_palette('cubehelix', len(colour_labels))
    for m in marker_labels:
        for c in colour_labels:
            sub_df = dataframe[(dataframe[marker_col] == m) &
                               (dataframe[colour_col] == c)]            
            sns.regplot(
                    data = sub_df, fit_reg = False, ax = ax,
                    x = numeric.columns[0], y = numeric.columns[1], 
                    color = palette[colour_labels.index(c)],
                    marker = markers[marker_labels.index(m)],
                    label = '%s - %s' % (m, c))
    ax.legend()
    return fig
github krischer / LASIF / lasif / components / visualizations.py View on Github external
:return: The potentially created axes object.
        """
        # Get the statistics.
        data = self.comm.windows.get_window_statistics(window_set_name, events)

        import matplotlib
        import matplotlib.pylab as plt
        import seaborn as sns

        if ax is None:
            plt.figure(figsize=(10, 6))
            ax = plt.gca()

        ax.invert_yaxis()

        pal = sns.color_palette("Set1", n_colors=4)

        total_count = []
        count_z = []
        count_n = []
        count_e = []
        event_names = []

        width = 0.2
        ind = np.arange(len(data))

        cm = matplotlib.cm.RdYlGn

        for _i, event in enumerate(sorted(data.keys())):
            d = data[event]
            event_names.append(event)
            total_count.append(d["total_station_count"])
github calico / basenji / bin / basenji_sat_plot.py View on Github external
def plot_sad(ax, sat_loss_ti, sat_gain_ti):
  """ Draw the loss and gain SAD score traces on a single axis.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_loss_ti (L_sm array): Minimum mutation delta across satmut length.
        sat_gain_ti (L_sm array): Maximum mutation delta across satmut length.
    """

  # The two ends of a 10-step 'RdBu_r' palette colour the two traces.
  palette = sns.color_palette('RdBu_r', 10)
  loss_color, gain_color = palette[0], palette[-1]

  # The loss trace is plotted negated.
  ax.plot(-sat_loss_ti, c=loss_color, label='loss', linewidth=1)
  ax.plot(sat_gain_ti, c=gain_color, label='gain', linewidth=1)
  ax.set_xlim(0, len(sat_loss_ti))
  ax.legend()

  # No x ticks, and a thin frame on all four sides.
  ax.xaxis.set_ticks([])
  for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(0.5)
github edraizen / molmimic / molmimic / visualize / mayavi_vieiwer.py View on Github external
plot_matrix(ax, volume, colors=colors)

    if truth is not None:
        return rot_z180, rot_x45


def plot_cube_at(pos = (0,0,0), ax = None, color=(0,1,0), alpha=0.4):
    """Plots a cube element at position pos
    From: EnzyNet

    Parameters
    ----------
    pos : position handed to ``cuboid_data`` to build the cube surfaces.
    ax : axes object providing ``plot_surface``; when None nothing is drawn.
    color : surface colour forwarded to ``plot_surface``.
    alpha : surface transparency forwarded to ``plot_surface``.

    Returns
    -------
    None
    """
    # Identity test: `!= None` would dispatch to the object's own __ne__.
    if ax is None:
        return

    X, Y, Z = cuboid_data(pos)
    ax.plot_surface(X, Y, Z, color=color, rstride=1, cstride=1, alpha=alpha, shade=False)

# Module-level palette: 8 colours from seaborn's "Set2".
set2 = sns.color_palette("Set2", 8)
def plot_matrix(ax, matrix, truth=False, colors=False):
    'Plots cubes from a volumic matrix'
    if len(matrix.shape) >= 3:
        if len(matrix.shape) == 4 and matrix.shape[3]==3:
            use_raw_color = True
        else:
            use_raw_color = False

        half_k = matrix.shape[2]/2.
        for i in xrange(matrix.shape[0]):
            for j in xrange(matrix.shape[1]):
                for k in xrange(matrix.shape[2]):
                    #if matrix[i,j,k] == 1:
                    #print "Plotting voxel at", i, j, k
                    if truth:
                        color = (0,0,1)
github rymc / n2d / n2d.py View on Github external
def plot(x, y, plot_id, names=None):
    """Save a labelled scatter plot, and its data, for up to 5000 samples.

    The first 5000 rows of ``x`` and ``y`` are put in a DataFrame (labels in
    the 'Label' column, optionally mapped through ``names``), written to
    ``<save_dir>/<dataset>.csv`` and plotted using columns 0 and 1 as the
    coordinates.  The figure is saved as ``<save_dir>/<dataset>-<plot_id>.png``
    and then cleared.  ``save_dir``, ``dataset`` and ``n_clusters`` are read
    from the module-level ``args``.
    """
    frame = pd.DataFrame(data=x[:5000])
    frame['Label'] = y[:5000]
    if names is not None:
        frame['Label'] = frame['Label'].map(names)

    # Common prefix of both output files.
    out_stem = args.save_dir + '/' + args.dataset
    frame.to_csv(out_stem + '.csv')

    plt.subplots(figsize=(8, 5))
    label_order = sorted(frame['Label'].unique())
    cluster_palette = sns.color_palette("hls", n_colors=args.n_clusters)
    sns.scatterplot(
        x=0,
        y=1,
        hue='Label',
        legend='full',
        hue_order=label_order,
        palette=cluster_palette,
        alpha=.5,
        data=frame,
    )

    legend = plt.legend(
        bbox_to_anchor=(-.1, 1.00, 1.1, .5),
        loc="lower left",
        markerfirst=True,
        mode="expand",
        borderaxespad=0,
        ncol=args.n_clusters + 1,
        handletextpad=0.01,
    )
    # Blank out the first legend text entry.
    legend.texts[0].set_text("")

    for clear_axis_label in (plt.ylabel, plt.xlabel):
        clear_axis_label("")
    plt.tight_layout()
    plt.savefig(out_stem + '-' + plot_id + '.png', dpi=300)
    plt.clf()
github bienz2 / PyFancyPlots / pyfancyplot / plot.py View on Github external
def get_palette(num_colors = None):
    """Return a seaborn palette built from the module-level palette settings.

    Parameters
    ----------
    num_colors : optional; when not None it overwrites the module-level
        ``palette_n_colors`` before the palette is built, so the new value
        stays in effect for later calls as well.

    Returns
    -------
    The result of ``sns.color_palette(palette_name, palette_n_colors,
    palette_desat)``.
    """
    # Only the name that is rebound needs a `global` declaration; the other
    # two settings are merely read.
    global palette_n_colors
    if num_colors is not None:
        palette_n_colors = num_colors
    return sns.color_palette(palette_name, palette_n_colors, palette_desat)
github samplchallenges / SAMPL6 / host_guest / Analysis / Scripts / analyze_sampling.py View on Github external
reference_free_energies.set_index('System name', inplace=True)

    # Import user map.
    with open('../SubmissionsDoNotUpload/SAMPL6_user_map.csv', 'r') as f:
        user_map = pd.read_csv(f)

    # Load submissions data. We do OA and TEMOA together.
    submissions = load_submissions(SamplingSubmission, SAMPLING_SUBMISSIONS_DIR_PATH, user_map)

    # Export YANK analysis and submissions to CSV/JSON tables.
    yank_analysis.export(os.path.join(SAMPLING_DATA_DIR_PATH, 'reference_free_energies'))
    export_submissions(submissions, reference_free_energies)

    # Create output directory for plots.
    os.makedirs(SAMPLING_PLOT_DIR_PATH, exist_ok=True)
    palette_mean = sns.color_palette('dark')

    # Plot submission data.
    for submission in submissions:
        # CB8-G3 calculations haven't converged yet.
        if submission.name == 'Expanded-ensemble/MBAR':
            continue

        mean_free_energies = submission.mean_free_energies()
        unique_system_names = submission.data['System name'].unique()

        # Create a figure with 3 axes (one for each system).
        n_systems = len(unique_system_names)
        if PLOT_ERRORS:
            # The second row will plot the errors.
            fig, axes = plt.subplots(nrows=2, ncols=n_systems, figsize=(6*n_systems, 12))
            trajectory_axes = axes[0]