Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
temp_inds = range(self.ntemp) if temp_inds is None else temp_inds
ntemp = len(temp_inds)
spins = range(self.ebands.nsppol) if spins is None else spins
kpt_inds = range(self.ebands.nkpt) if kpt_inds is None else kpt_inds
nkpt = len(kpt_inds)
xs, emin, emax = self.get_emesh_eminmax(estep)
nene = len(xs)
num_plots, ncols, nrows = nkpt, 1, 1
if num_plots > 1:
ncols = 2
nrows = (num_plots // ncols) + (num_plots % ncols)
# Build plot grid.
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
ax_list = np.array(ax_list).ravel()
cmap = plt.get_cmap(colormap)
for isp, spin in enumerate(spins):
spin_sign = +1 if spin == 0 else -1
for ik, (ikpt, ax) in enumerate(zip(kpt_inds, ax_list)):
ax.grid(True)
atw = self.get_atw(xs, spin, ikpt, band_inds, temp_inds)
for it, itemp in enumerate(temp_inds):
ys = spin_sign * atw[it] + (it * apad)
ax.plot(xs, ys, lw=2, alpha=0.8, color=cmap(float(it) / ntemp),
label="T = %.1f K" % self.tmesh[itemp] if (ik, isp) == (0, 0) else None)
if spin == 0:
kpt = self.ebands.kpoints[ikpt]
"""
Plot (k, e) color maps for different temperatures.
Args:
fontsize (int): fontsize for titles and legend
Return: |matplotlib-Figure|
"""
temp_inds = range(self.ntemp) if temp_inds is None else temp_inds
# Build plot grid.
num_plots, ncols, nrows = len(temp_inds), 1, 1
if num_plots > 1:
ncols = 2
nrows = (num_plots // ncols) + (num_plots % ncols)
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
ax_list = ax_list.ravel()
# Don't show the last ax if num_plots is odd.
if num_plots % ncols != 0: ax_list[-1].axis("off")
for itemp, ax in zip(temp_inds, ax_list):
self.plot_ekmap_itemp(itemp=itemp, spins=spins, estep=estep, ax=ax, ylims=ylims,
with_colorbar=with_colorbar, show=False, **kwargs)
ax.set_title("T = %.1f K" % self.tmesh[itemp], fontsize=fontsize)
return fig
Args:
e0: Option used to define the zero of energy in the band structure plot. Possible values:
- ``fermie``: shift all eigenvalues to have zero energy at the Fermi energy (`self.fermie`).
- Number e.g e0=0.5: shift all eigenvalues to have zero energy at 0.5 eV
- None: Don't shift energies, equivalent to e0=0
fontsize: Fontsize for title.
Return: |matplotlib-Figure|
"""
num_plots, ncols, nrows = self.ntemp, 1, 1
if num_plots > 1:
ncols = 2
nrows = (num_plots // ncols) + (num_plots % ncols)
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
ax_list = np.array(ax_list).ravel()
# don't show the last ax if num_plots is odd.
if num_plots % ncols != 0: ax_list[-1].axis("off")
#e0 = 0
for itemp, ax in enumerate(ax_list):
fig = self.plot_itemp(itemp, ax=ax, e0=e0, ylims=ylims, fontsize=fontsize, show=False)
if itemp != 0:
set_visible(ax, False, "ylabel", "legend")
return fig
def plot_a2(self, phdos, atol=1e-12, **kwargs):
"""
Grid with 3 plots showing: a2F(w), F(w), a2F(w). Requires phonon DOS.
Args:
phdos: |PhononDos|
atol: F(w) is replaced by atol in a2F(w) / F(w) ratio where :math:`|F(w)|` < atol
Returns: |matplotlib-Figure|
"""
phdos = PhononDos.as_phdos(phdos)
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=3, ncols=1,
sharex=True, sharey=False, squeeze=True)
ax_list = ax_list.ravel()
# Spline phdos onto a2f mesh and compute a2F(w) / F(w)
f = phdos.spline(self.mesh)
f = self.values / np.where(np.abs(f) > atol, f, atol)
ax = ax_list[0]
ax.plot(self.mesh, f, color="k", linestyle="-")
ax.grid(True)
ax.set_ylabel(r"$\alpha^2(\omega)$ [1/eV]")
# Plot F(w). TODO: This should not be called plot_dos_idos!
ax = ax_list[1]
phdos.plot_dos_idos(ax=ax, what="d", color="k", linestyle="-")
ax.grid(True)
ax.set_ylabel(r"$F(\omega)$ [states/eV]")
Args:
what_list: ``phfreqs`` for phonons, ``lambda`` for the eph coupling strength,
``gamma`` for phonon linewidths.
ax_list: List of |matplotlib-Axes| (same length as what_list)
or None if a new figure should be created.
ylims: Set the data limits for the y-axis. Accept tuple e.g. ``(left, right)``
or scalar e.g. ``left``. If left (right) is None, default values are used
label: String used to label the plot in the legend.
fontsize: Legend and title fontsize.
Returns: |matplotlib-Figure|
"""
what_list = list_strings(what_list)
nrows, ncols = len(what_list), 1
ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
sharex=True, sharey=False, squeeze=False)
ax_list = np.array(ax_list).ravel()
units = "eV"
for i, (ax, what) in enumerate(zip(ax_list, what_list)):
# Decorate the axis (e.g add ticks and labels).
self.phbands.decorate_ax(ax, units="")
if what == "phbands":
# Plot phonon bands
self.phbands.plot(ax=ax, units=units, show=False)
else:
# Add eph coupling.
if what == "lambda":
yvals = self.reader.read_phlambda_qpath()
ylabel = r"$\lambda(q,\nu)$"
or scalar e.g. `left`. If left (right) is None, default values are used
ylims: Same meaning as `ylims` but for the y-axis
fontsize: fontsize for titles and legend.
Return: |matplotlib-Figure|
"""
# Build plot grid.
if qview == "avg":
ncols, nrows = 2, 1
elif qview == "all":
qpoints = self._get_qpoints()
ncols, nrows = 2, len(qpoints)
else:
raise ValueError("Invalid value of qview: %s" % str(qview))
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
if qview == "avg":
# Plot averaged values
self.plot_mdftype_cplx(mdf_type, "Re", ax=ax_mat[0, 0], xlims=xlims, ylims=ylims,
fontsize=fontsize, with_legend=True, show=False)
self.plot_mdftype_cplx(mdf_type, "Im", ax=ax_mat[0, 1], xlims=xlims, ylims=ylims,
fontsize=fontsize, with_legend=False, show=False)
elif qview == "all":
# Plot MDF(q)
nqpt = len(qpoints)
for iq, qpt in enumerate(qpoints):
islast = (iq == nqpt - 1)
self.plot_mdftype_cplx(mdf_type, "Re", qpoint=qpt, ax=ax_mat[iq, 0], xlims=xlims, ylims=ylims,
fontsize=fontsize, with_legend=(iq == 0), with_xlabel=islast, with_ylabel=islast, show=False)
self.plot_mdftype_cplx(mdf_type, "Im", qpoint=qpt, ax=ax_mat[iq, 1], xlims=xlims, ylims=ylims,
def plot_all(self, **kwargs):
    """
    Draw a 2x2 grid of panels for the dielectric tensor as a function of frequency.

    Rows select the tensor component (``diag`` first, then ``offdiag``);
    columns select the real (``re``) and then the imaginary (``im``) part.

    Args:
        kwargs: Forwarded unchanged to the `plot` method for every panel.
            ``fontsize`` (default: 8) is extracted here and passed explicitly.
            `component` and `reim` are set internally for each panel,
            so they must not be supplied by the caller.

    Returns: |matplotlib-Figure|
    """
    grid, fig, _ = get_axarray_fig_plt(None, nrows=2, ncols=2,
                                       sharex=True, sharey=False, squeeze=False)
    fontsize = kwargs.pop("fontsize", 8)

    # Visit the panels row by row: (diag, re), (diag, im), (offdiag, re), (offdiag, im).
    for row, component in enumerate(("diag", "offdiag")):
        for col, reim in enumerate(("re", "im")):
            panel = grid[row, col]
            self.plot(component=component, reim=reim, ax=panel,
                      fontsize=fontsize, show=False, **kwargs)

    return fig
site = self.structure[site_index]
nn_list = self.structure.get_neighbors_old(site, radius, include_index=True)
if not nn_list:
cprint("Zero neighbors found for radius %s Ang. Returning None." % radius, "yellow")
return None
# Sort sites by distance.
nn_list = list(sorted(nn_list, key=lambda t: t[1]))
if max_nn is not None and len(nn_list) > max_nn:
cprint("For radius %s, found %s neighbors but only max_nn %s sites are show." %
(radius, len(nn_list), max_nn), "yellow")
nn_list = nn_list[:max_nn]
# Get grid of axes.
nrows, ncols = len(nn_list), 1
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=True)
ax_list = ax_list.ravel()
interpolator = self.get_interpolator()
for i, (nn, ax) in enumerate(zip(nn_list, ax_list)):
nn_site, nn_dist, nn_sc_index = nn
title = "%s, %s, dist=%.3f A" % (nn_site.species_string, str(nn_site.frac_coords), nn_dist)
r = interpolator.eval_line(site.frac_coords, nn_site.frac_coords, num=num, kpoint=None)
for ispden in range(self.nspden):
ax.plot(r.dist, r.values[ispden],
label=latexlabel_ispden(ispden, self.nspden) if i == 0 else None)
ax.set_title(title, fontsize=fontsize)
Produce two subplots:
1. Re/Imag part and intersection with omega - eKs
2. A(w)
Args:
itemp: List of temperature indices. "all" to plot'em all.
ax_list: List of |matplotlib-Axes|. If None, new figure is produced.
xlims: Set the data limits for the x-axis. Accept tuple e.g. ``(left, right)``
or scalar e.g. ``left``. If left (right) is None, default values are used.
fontsize: legend and label fontsize.
kwargs: Keyword arguments passed to ax.plot
Returns: |matplotlib-Figure|
"""
ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=2, ncols=1, sharex=True, sharey=False)
xs, xlabel = self._get_wmesh_xlabel("e0")
ax0, ax1 = ax_list
ax0.grid(True)
ax0.plot(xs, self.vals_wr[itemp].real, label=r"$\Re(\Sigma)$")
ax0.plot(xs, self.vals_wr[itemp].imag, ls="--", label=r"$\Im(\Sigma)$")
ax0.plot(xs, self.wmesh - self.qp.e0, color="k", lw=1, ls="--", label=r"$\omega - \epsilon^0$")
#ax0.axvline(x=0.0, color='k', linestyle='--', lw=1)
ax0.set_ylabel(r"$\Sigma(\omega)\,$(eV)")
ax0.legend(loc="best", fontsize=fontsize, shadow=True)
set_axlims(ax0, xlims, "x")
ymin = min(self.vals_wr[itemp].real.min(), self.vals_wr[itemp].imag.min())
ymin = ymin - abs(ymin) * 0.2
ymax = max(self.vals_wr[itemp].real.max(), self.vals_wr[itemp].imag.max())
ymax = ymax + abs(ymax) * 0.2
if not self.has_atom[iatom]: continue
for l in range(min(self.lmax_atom[iatom] + 1, mylsize)):
totdos_al[iatom, l, spin] += weight * gs * wal_sbk[iatom, l, spin, band, k]
paw1dos_al[iatom, l, spin] += weight * gs * paw1_wal_sbk[iatom, l, spin, band, k]
pawt1dos_al[iatom, l, spin] += weight * gs * pawt1_wal_sbk[iatom, l, spin, band, k]
else:
raise ValueError("Method %s is not supported" % method)
# TOT = PW + AE - PS
pwdos_al = totdos_al - paw1dos_al + pawt1dos_al
# Build plot grid.
nrows, ncols = np.count_nonzero(self.has_atom), self.lsize
ax_mat = None
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
ax_mat = np.reshape(ax_mat, (nrows, ncols))
irow = -1
for iatom in range(self.natom):
if not self.has_atom[iatom]: continue
irow += 1
#for l in range(min(self.lmax_atom[iatom] + 1, mylsize)):
for l in range(min(self.lsize, mylsize)):
ax = ax_mat[irow, l]
if l >= self.lmax_atom[iatom]+1:
# don't show this plots and cycle
ax.axis("off")
continue
ax.grid(True)
if l != 0: