Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Second draft #38

Merged
merged 16 commits into from
Nov 12, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Improve graphics
  • Loading branch information
cdalvaro committed Nov 11, 2020
commit cf69873928705f72e7cdf6897e72c24d7571ac72
2 changes: 1 addition & 1 deletion src/cdalvaro/graphics/color_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ def set_color_palette():
sns.set_palette(color_palette(as_cmap=False))


def color_palette(as_cmap: bool = True, **kwargs):
def color_palette(as_cmap: bool = False, **kwargs):
# https://medium.com/@morganjonesartist/color-guide-to-seaborn-palettes-da849406d44f
return sns.color_palette('BrBG_r', as_cmap=as_cmap, **kwargs)
45 changes: 30 additions & 15 deletions src/cdalvaro/graphics/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

from .color_palette import color_palette
Expand All @@ -10,7 +11,7 @@ def plot_clusters_catalogue_distribution(data: pd.DataFrame,
xlim: tuple = None,
ylim: tuple = None,
hue: str = 'diam'):
fig, ax = plt.subplots(figsize=(12, 6))
fig, ax = plt.subplots(figsize=(12, 6), tight_layout=True)
if title is not None:
ax.set_title(title)

Expand All @@ -23,11 +24,12 @@ def plot_clusters_catalogue_distribution(data: pd.DataFrame,
ax.set_ylim(ylim)

palette = None
hue_order = None
if hue is not None:
n_colors = len(pd.unique(data[hue]))
palette = color_palette(n_colors=n_colors)
hue_order = np.sort(pd.unique(data[hue]))
palette = color_palette(n_colors=len(hue_order))

g = sns.scatterplot(data=data, x="ra", y="dec", hue=hue, size=hue, palette=palette, ax=ax)
g = sns.scatterplot(data=data, x="ra", y="dec", hue=hue, hue_order=hue_order, size=hue, palette=palette, ax=ax)

plt.legend().set_title("Diameter (arcmin)")

Expand All @@ -40,7 +42,7 @@ def plot_cluster_proper_motion(data: pd.DataFrame,
ylim: tuple = None,
hue: str = 'cluster_g',
legend: bool = True):
fig, ax = plt.subplots(figsize=(6, 6))
fig, ax = plt.subplots(figsize=(6, 6), tight_layout=True)
if title is not None:
ax.set_title(title)

Expand All @@ -53,16 +55,25 @@ def plot_cluster_proper_motion(data: pd.DataFrame,
ax.set_ylim(ylim)

palette = None
hue_order = None
if hue is not None:
n_colors = len(pd.unique(data[hue]))
palette = color_palette(n_colors=n_colors)
hue_order = np.sort(pd.unique(data[hue]))
palette = color_palette(n_colors=len(hue_order))

g = sns.scatterplot(data=data, x="pmra", y="pmdec", hue=hue, palette=palette, s=12, ax=ax, legend=legend)
g = sns.scatterplot(data=data,
x="pmra",
y="pmdec",
hue=hue,
hue_order=hue_order,
palette=palette,
s=12,
ax=ax,
legend=legend)

return fig, ax, g


def plot_cluster_parallax_histogram(df_cluster,
def plot_cluster_parallax_histogram(data,
title: str = None,
xlim: tuple = None,
ylim: tuple = None,
Expand All @@ -83,13 +94,15 @@ def plot_cluster_parallax_histogram(df_cluster,
ax.set_ylim(ylim)

palette = None
hue_order = None
if hue is not None:
n_colors = len(pd.unique(df_cluster[hue]))
palette = color_palette(n_colors=n_colors)
hue_order = np.sort(pd.unique(data[hue]))
palette = color_palette(n_colors=len(hue_order))

g = sns.histplot(data=df_cluster,
g = sns.histplot(data=data,
x='parallax',
hue=hue,
hue_order=hue_order,
palette=palette,
legend=legend,
bins=bins,
Expand All @@ -106,7 +119,7 @@ def plot_cluster_isochrone_curve(data: pd.DataFrame,
ylim: tuple = None,
hue: str = 'cluster_g',
legend: bool = True):
fig, ax = plt.subplots(figsize=(6, 6))
fig, ax = plt.subplots(figsize=(6, 6), tight_layout=True)
if title is not None:
ax.set_title(title)

Expand All @@ -119,14 +132,16 @@ def plot_cluster_isochrone_curve(data: pd.DataFrame,
ax.set_ylim(ylim)

palette = None
hue_order = None
if hue is not None:
n_colors = len(pd.unique(data[hue]))
palette = color_palette(n_colors=n_colors)
hue_order = np.sort(pd.unique(data[hue]))
palette = color_palette(n_colors=len(hue_order))

g = sns.scatterplot(data=data,
x="bp_rp",
y="phot_g_mean_mag",
hue=hue,
hue_order=hue_order,
size='parallax',
sizes=(2, 20),
palette=palette,
Expand Down