from typing import Union, Dict, List, Tuple
import warnings
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from pymantra.statics import (
NODE_TYPE_LIST, EDGE_TYPE_LIST, DIRECT_EDGE_TYPE_LIST
)
# Default colour and marker per node type. Values are paired positionally
# with NODE_TYPE_LIST via ``zip``, which stops at the shorter sequence; if
# NODE_TYPE_LIST had more than four entries the extra ones would get no
# default styling.
NODE_COLOURS = {
    node_type: colour
    for node_type, colour in zip(
        NODE_TYPE_LIST, ["tab:orange", "tab:green", "tab:red", "tab:blue"]
    )
}
NODE_SHAPES = {
    node_type: shape
    for node_type, shape in zip(NODE_TYPE_LIST, ["^", "s", "v", "o"])
}

# Edge styling used for reaction graphs (keyed by EDGE_TYPE_LIST): all
# edges are grey and only the third entry of EDGE_TYPE_LIST is dashed.
EDGE_COLOURS = {
    edge_type: colour
    for edge_type, colour in zip(EDGE_TYPE_LIST, ["tab:grey"] * 8)
}
EDGE_STYLES = {
    edge_type: style
    for edge_type, style in zip(
        EDGE_TYPE_LIST, ["solid", "solid", "dashed"] + ["solid"] * 5
    )
}

# Edge styling used for graphs without reaction nodes (keyed by
# DIRECT_EDGE_TYPE_LIST): all grey, only the fourth entry is dashed.
DIRECT_EDGE_COLOURS = {
    edge_type: colour
    for edge_type, colour in zip(DIRECT_EDGE_TYPE_LIST, ["tab:grey"] * 4)
}
DIRECT_EDGE_STYLES = {
    edge_type: style
    for edge_type, style in zip(
        DIRECT_EDGE_TYPE_LIST, ["solid"] * 3 + ["dashed"]
    )
}
def _legend(
    node_types: List[str], edge_types: List[str],
    node_colours: Dict[str, str] = None, node_shapes: Dict[str, str] = None,
    edge_colours: Dict[str, str] = None, edge_styles: Dict[str, str] = None,
    reaction_graph: bool = False, **kwargs
) -> List[Line2D]:
    """Build legend handles for the node and edge types of a network plot.

    Parameters
    ----------
    node_types: List[str]
        Node types to show. Each must be a key of `node_colours` and
        `node_shapes`.
    edge_types: List[str]
        Edge types to show. Types missing from `edge_colours`/`edge_styles`
        fall back to a grey, solid line.
    node_colours: Dict[str, str], optional
        Marker face colour per node type. Defaults to `NODE_COLOURS`.
    node_shapes: Dict[str, str], optional
        Marker shape per node type. Defaults to `NODE_SHAPES`.
    edge_colours: Dict[str, str], optional
        Line colour per edge type. Defaults to `EDGE_COLOURS` if
        `reaction_graph` is True, else `DIRECT_EDGE_COLOURS`.
    edge_styles: Dict[str, str], optional
        Line style per edge type. Defaults to `EDGE_STYLES` if
        `reaction_graph` is True, else `DIRECT_EDGE_STYLES`.
    reaction_graph: bool, default False
        Selects which default edge mappings are used.
    kwargs
        Extra keyword arguments for every `Line2D` handle. `markersize`
        (default 12) is applied to node handles and `lw` (default 4) to edge
        handles.

    Returns
    -------
    List[Line2D]
        Node handles, an invisible spacer handle, then edge handles.
    """
    if node_colours is None:
        node_colours = NODE_COLOURS
    if node_shapes is None:
        node_shapes = NODE_SHAPES
    if edge_colours is None:
        edge_colours = EDGE_COLOURS if reaction_graph else DIRECT_EDGE_COLOURS
    if edge_styles is None:
        edge_styles = EDGE_STYLES if reaction_graph else DIRECT_EDGE_STYLES

    # Pop the size options once, before building any handle. Popping inside
    # the loops consumed a user-supplied value on the first iteration, so all
    # following handles silently fell back to the default size.
    markersize = kwargs.pop("markersize", 12)
    lw = kwargs.pop("lw", 4)

    # node entries: white line with a coloured marker, i.e. marker-only
    handles = [
        Line2D(
            [0], [0], color="w", marker=node_shapes[node_type],
            markerfacecolor=node_colours[node_type], label=node_type,
            markersize=markersize, **kwargs
        )
        for node_type in node_types
    ]
    # invisible placeholder separating node entries from edge entries
    handles.append(Line2D([], [], color="none"))
    handles.extend(
        Line2D(
            [0], [0], color=edge_colours.get(edge_type, "tab:grey"),
            label=edge_type, linestyle=edge_styles.get(edge_type, "solid"),
            lw=lw, **kwargs
        )
        for edge_type in edge_types
    )
    return handles
def _plot_graph(
    graph: Union[nx.Graph, nx.DiGraph],
    layout: callable = nx.kamada_kawai_layout,
    directed: bool = False, reaction_graph: bool = False,
    ax: plt.axis = None, reverse_directions: bool = False,
    formula_as_label: bool = False, show_labels: bool = True,
    edge_width: float = 1, **label_args
) -> plt.axis:
    """Draw a typed network with per-type node/edge styling and a legend.

    Parameters
    ----------
    graph: Union[nx.Graph, nx.DiGraph]
        Graph whose nodes carry a "node_type" and whose edges carry an
        "edge_type" attribute. Nodes without "node_type" get no marker and
        edges without "edge_type" are not drawn.
    layout: callable
        Function mapping `graph` to a dict of node positions.
    directed: bool
        Whether edges are drawn with arrow heads.
    reaction_graph: bool
        Use the reaction-graph edge styles (`EDGE_COLOURS`/`EDGE_STYLES`)
        instead of `DIRECT_EDGE_COLOURS`/`DIRECT_EDGE_STYLES`. The legend
        uses the same choice.
    ax: plt.axis, optional
        Axis to draw onto. A new 16x9 figure is created if None.
    reverse_directions: bool
        Not used by this function.
    formula_as_label: bool
        Label reaction nodes with their "Formula" instead of their
        "Description" attribute.
    show_labels: bool
        Whether to draw node labels.
    edge_width: float
        Line width for all edges.
    label_args
        Keyword arguments for `networkx.draw_networkx_labels`.

    Returns
    -------
    plt.axis
        Axis containing the network plot including legend.
    """
    if ax is None:
        # plt.subplots returns (figure, axes); only the axes is needed here
        _, ax = plt.subplots(figsize=(16, 9))
    positions = layout(graph)

    # nodes: one draw call per node type so each type gets its own style
    nodes_by_type = {}
    for node, node_type in nx.get_node_attributes(graph, "node_type").items():
        nodes_by_type.setdefault(node_type, []).append(node)
    for node_type, nodes in nodes_by_type.items():
        nx.draw_networkx_nodes(
            graph, positions, nodelist=nodes,
            node_color=NODE_COLOURS[node_type],
            node_shape=NODE_SHAPES[node_type],
            edgecolors="tab:grey", ax=ax
        )

    # edges: the mappings must match the ones `_legend` picks for the same
    # `reaction_graph` value, otherwise plot and legend disagree
    if reaction_graph:
        edge_colours, edge_styles = EDGE_COLOURS, EDGE_STYLES
    else:
        edge_colours, edge_styles = DIRECT_EDGE_COLOURS, DIRECT_EDGE_STYLES
    edges_by_type = {}
    for edge, edge_type in nx.get_edge_attributes(graph, "edge_type").items():
        edges_by_type.setdefault(edge_type, []).append(edge)
    for edge_type, edges in edges_by_type.items():
        nx.draw_networkx_edges(
            graph, positions, edgelist=edges,
            edge_color=edge_colours.get(edge_type, "tab:grey"),
            style=edge_styles.get(edge_type, "solid"),
            arrows=directed, ax=ax, width=edge_width
        )

    if show_labels:
        # node attribute used as label, per node type; unknown or missing
        # node types use "nodeLabel", and the node id if that is missing too
        label_key = {
            "reaction": "Formula" if formula_as_label else "Description",
            "metabolite": "Name",
            "organism": "nodeLabel",
            "gene": "nodeLabel"
        }
        labels = {
            node: node_data.get(
                label_key.get(node_data.get("node_type"), "nodeLabel"), node)
            for node, node_data in graph.nodes(data=True)
        }
        nx.draw_networkx_labels(
            graph, positions, labels=labels, ax=ax, **label_args)

    # legend
    # TODO: option to manually adapt colour scheme
    legend_handles = _legend(
        list(nodes_by_type), list(edges_by_type),
        reaction_graph=reaction_graph
    )
    ax.legend(handles=legend_handles)

    # axis formatting: a boolean is required, the string 'off' is truthy
    ax.grid(False)
    ax.axis("off")
    return ax
def plot_directed_graph(
    graph: nx.DiGraph,
    layout: callable = nx.kamada_kawai_layout,
    reaction_graph: bool = True,
    formula_as_reaction_label: bool = False,
    ax: plt.axis = None,
    **label_args
) -> plt.axis:
    """
    Plot a directed network obtained from the database.

    Parameters
    ----------
    graph: nx.DiGraph
        networkx DiGraph object. Every node and edge has to carry its type
        as a "node_type" or "edge_type" attribute, respectively.
    layout: callable, default nx.kamada_kawai_layout
        Function computing node positions for a :obj:`nx.DiGraph` that is
        compatible with the networkx draw_* functions. Default is
        :meth:`nx.kamada_kawai_layout`. To use a networkx layout function
        with specific parameter settings, wrap it in a lambda.
    reaction_graph: bool, default True
        Whether the graph contains reaction nodes. This corresponds to
        whether the input graph was generated with
        :meth:`NetworkGenerator.get_reaction_subgraph` or
        :meth:`NetworkGenerator.get_subgraph`.
    formula_as_reaction_label: bool, default False
        If True, reaction formulas are used as node labels, otherwise the
        reaction description as provided in the source database.
    ax: plt.axis, optional
        Axis to plot the network onto. A new 16x9 figure is created if not
        given.
    label_args:
        Keyword arguments for `networkx.draw_networkx_labels`

    Returns
    -------
    plt.axis
        Axis containing the network plot including legend

    Examples
    --------
    >>> from pymantra import datasets
    >>> mantra_graph = datasets.example_graph()
    >>> plot_directed_graph(mantra_graph)
    """
    target_ax = ax
    if target_ax is None:
        target_ax = plt.subplots(figsize=(16, 9))[1]
    return _plot_graph(
        graph,
        layout=layout,
        directed=True,
        reaction_graph=reaction_graph,
        ax=target_ax,
        formula_as_label=formula_as_reaction_label,
        **label_args
    )
def plot_undirected_graph(
    graph: nx.Graph,
    layout: callable = nx.kamada_kawai_layout,
    reaction_graph: bool = False,
    formula_as_reaction_label: bool = False,
    ax: plt.axis = None,
    **label_args
) -> plt.axis:
    """
    Plot an undirected network obtained from the database.

    Parameters
    ----------
    graph: nx.Graph
        networkx Graph object. Every node and edge has to carry its type as
        a "node_type" or "edge_type" attribute, respectively.
    layout: callable, default nx.kamada_kawai_layout
        Function computing node positions for a :obj:`nx.Graph` that is
        compatible with the networkx draw_* functions. Default is
        :meth:`nx.kamada_kawai_layout`. To use a networkx layout function
        with specific parameter settings, wrap it in a lambda.
    reaction_graph: bool, default False
        Whether the graph contains reaction nodes. This corresponds to
        whether the input graph was generated with
        :meth:`NetworkGenerator.get_reaction_subgraph` or
        :meth:`NetworkGenerator.get_subgraph`.
        Generally, it is recommended to use :py:meth:`~plot_directed_graph`
        when plotting a reaction graph.
    formula_as_reaction_label: bool, default False
        If True, reaction formulas are used as node labels, otherwise the
        reaction description as provided in the source database.
    ax: plt.axis, optional
        Axis to plot the network onto. A new 16x9 figure is created if not
        given.
    label_args:
        Keyword arguments for `networkx.draw_networkx_labels`

    Returns
    -------
    plt.axis
        Axis containing the network plot including legend

    Examples
    --------
    >>> from pymantra import datasets
    >>> mantra_graph = datasets.example_graph()
    >>> plot_undirected_graph(mantra_graph)
    """
    target_ax = ax
    if target_ax is None:
        target_ax = plt.subplots(figsize=(16, 9))[1]
    return _plot_graph(
        graph,
        layout=layout,
        directed=False,
        reaction_graph=reaction_graph,
        ax=target_ax,
        formula_as_label=formula_as_reaction_label,
        **label_args
    )
def _remove_zero_features(df: pd.DataFrame):
zero_idxs = df.index[(df == 0).all(axis=1)]
zero_cols = df.columns[(df == 0).all(axis=0)]
return df.drop(index=zero_idxs, columns=zero_cols)
def _prep_correlation_dfs(
corrs: Dict[str, pd.DataFrame], ref_group: str, tgt_group: str,
pvals: pd.Series, set_zero: bool, thresh: float, strip_column_names: bool
):
ref_df = corrs[ref_group].copy()
tgt_df = corrs[tgt_group].copy()
if set_zero:
ref_df[pvals[ref_group] > thresh] = 0
tgt_df[pvals[tgt_group] > thresh] = 0
if strip_column_names:
col_map = {col: col.split(", ")[0] for col in ref_df.columns}
ref_df.rename(columns=col_map, inplace=True)
tgt_df.rename(columns=col_map, inplace=True)
return ref_df, tgt_df
def _plot_clustmap(data, cluster, reorder, return_df, ax, **kwargs):
    """Draw `data` as a clustermap or as an (optionally reordered) heatmap.

    Parameters
    ----------
    data: pd.DataFrame
        Matrix to plot.
    cluster: bool
        If True, draw a `seaborn.clustermap`; `ax` and `reorder` are ignored.
    reorder: bool
        If True and `cluster` is False, order rows and columns as a
        default-parameter `seaborn.clustermap` would, then draw a plain
        heatmap.
    return_df: bool
        If True, return ``(data, plot)``. `data` is the input as given, not
        the reordered frame.
    ax: plt.axis or None
        Axis passed to `seaborn.heatmap`.
    kwargs
        Forwarded to `seaborn.clustermap` or `seaborn.heatmap`.

    Returns
    -------
    The ClusterGrid / heatmap axis, optionally preceded by `data`.
    """
    if cluster:
        plot = sns.clustermap(data, **kwargs)
    else:
        plot_data = data
        if reorder:
            clust = sns.clustermap(data)
            rows = clust.dendrogram_row.reordered_ind
            cols = clust.dendrogram_col.reordered_ind
            # The clustermap is only needed for its ordering. Close its
            # figure so it neither shows up as an extra plot nor receives
            # the heatmap when `ax` is None (heatmap then uses plt.gca()).
            temp_fig = getattr(clust, "figure", None)
            if temp_fig is None:
                temp_fig = clust.fig
            plt.close(temp_fig)
            plot_data = data.iloc[rows, cols]
        plot = sns.heatmap(plot_data, ax=ax, **kwargs)
    if return_df:
        return data, plot
    return plot
def plot_correlation_averages(
    corrs: Dict[str, pd.DataFrame], pvals: Dict[str, pd.DataFrame],
    ref_group: str, tgt_group: str, set_zero: bool = True, thresh: float = .05,
    cluster: bool = False, ax: plt.axis = None, reorder: bool = False,
    strip_column_names: bool = False, return_averages: bool = False,
    remove_all_zeros: bool = False, **kwargs
) -> Union[Union[sns.matrix.ClusterGrid, plt.axis],
           Tuple[pd.DataFrame, Union[sns.matrix.ClusterGrid, plt.axis]]]:
    """Plot the average multi-omics correlation over two sample groups

    Parameters
    ----------
    corrs: Dict[str, pd.DataFrame]
        Group-wise correlations as returned by
        `pymantra.compute_multiomics_associations`
    pvals: Dict[str, pd.DataFrame]
        Group-wise correlation p-values as returned by
        `pymantra.compute_multiomics_associations`
    ref_group: str
        Name of the first group to average.
    tgt_group: str
        Name of the second group to average. The plotted values are the
        element-wise mean of the correlations of `ref_group` and `tgt_group`.
    set_zero: bool, True
        Whether to set correlations with a p-value > some threshold (see
        `thresh`) to zero before averaging
    thresh: float, .05
        p-value cutoff above which all correlations will be set to zero. Only
        relevant if `set_zero` is True
    cluster: bool, False
        Whether to cluster features
    ax: plt.axis, optional
        Axis to plot onto. Only used if `cluster` is False
    reorder: bool, False
        Whether to reorder columns and rows by clustering. This essentially
        means plotting a clustermap but leaving away the clustering trees and
        only reordering. Only relevant if `cluster` is False.
    strip_column_names: bool, False
        Whether to strip the column names (i.e. reaction labels) to only retain
        the first reaction annotation
    return_averages: bool, False
        Whether to also return the averaged correlation data frame
    remove_all_zeros: bool, False
        Whether to remove features whose averaged correlations are all zero
        (e.g. no significant associations when `set_zero` is True).
    kwargs
        Keyword arguments to pass to `seaborn.heatmap` (`cluster` is False)
        or `seaborn.clustermap`. Defaults set here: cmap="vlag", vmin=-1,
        vmax=1, xticklabels=True, yticklabels=True.

    Returns
    -------
    Union[Union[sns.matrix.ClusterGrid, plt.axis],
          Tuple[pd.DataFrame, Union[sns.matrix.ClusterGrid, plt.axis]]]
        A seaborn `ClusterGrid` object if `cluster` is True, a matplotlib axis
        otherwise. If `return_averages` is True, a tuple of the averaged
        correlations and the plot object.

    Examples
    --------
    >>> from pymantra.datasets import example_multiomics_enrichment_data
    >>> from pymantra import (
    ...     compute_reaction_estimates, compute_multiomics_associations)
    >>> metabolite_data, microbiome_data, sample_groups, graph = \
    ...     example_multiomics_enrichment_data()
    >>> residuals = \
    ...     compute_reaction_estimates(graph, metabolite_data, sample_groups)
    >>> corrs, pvals = compute_multiomics_associations(
    ...     residuals, microbiome_data, sample_groups, comparison=("0", "1"))
    >>> avg_associations, clust_map = plot_correlation_averages(
    ...     corrs, pvals, "0", "1", cluster=True, return_averages=True)
    """
    ref_df, tgt_df = _prep_correlation_dfs(
        corrs, ref_group, tgt_group, pvals, set_zero, thresh,
        strip_column_names
    )
    # Label-aligned mean, consistent with `ref_df - tgt_df` in
    # `plot_correlation_differences`. The former np.mean over the raw
    # arrays silently paired entries by position.
    mean_data = (ref_df + tgt_df) / 2
    if remove_all_zeros:
        mean_data = _remove_zero_features(mean_data)
    return _plot_clustmap(
        mean_data, cluster, reorder, return_averages, ax,
        cmap=kwargs.pop("cmap", "vlag"), vmin=kwargs.pop("vmin", -1),
        vmax=kwargs.pop("vmax", 1),
        xticklabels=kwargs.pop("xticklabels", True),
        yticklabels=kwargs.pop("yticklabels", True), **kwargs
    )
def plot_correlation_differences(
    corrs: Dict[str, pd.DataFrame], pvals: Dict[str, pd.DataFrame],
    ref_group: str, tgt_group: str, set_zero: bool = True, thresh: float = .05,
    cluster: bool = False, ax: plt.axis = None, reorder: bool = False,
    strip_column_names: bool = False, return_differences: bool = False,
    remove_all_zeros: bool = False, **kwargs
) -> Union[Union[sns.matrix.ClusterGrid, plt.axis],
           Tuple[pd.DataFrame, Union[sns.matrix.ClusterGrid, plt.axis]]]:
    """Plot the multi-omics correlation differences between two sample groups

    Parameters
    ----------
    corrs: Dict[str, pd.DataFrame]
        Group-wise correlations as returned by
        `pymantra.compute_multiomics_associations`
    pvals: Dict[str, pd.DataFrame]
        Group-wise correlation p-values as returned by
        `pymantra.compute_multiomics_associations`
    ref_group: str
        Name of the reference group.
    tgt_group: str
        Name of the target group. The correlation values of this group will be
        subtracted from the correlations of `ref_group`
    set_zero: bool, True
        Whether to set correlations with a p-value > some threshold (see
        `thresh`) to zero
    thresh: float, .05
        p-value cutoff above which all correlations will be set to zero. Only
        relevant if `set_zero` is True
    cluster: bool, False
        Whether to cluster features
    ax: plt.axis, optional
        Axis to plot onto. Only used if `cluster` is False
    reorder: bool, False
        Whether to reorder columns and rows by clustering. Only relevant if
        `cluster` is False.
    strip_column_names: bool, False
        Whether to strip the column names (i.e. reaction labels) to only retain
        the first reaction annotation
    return_differences: bool, False
        Whether to also return the correlation difference data frame
    remove_all_zeros: bool, False
        Whether to remove features whose correlation differences are all zero
        (e.g. no significant associations when `set_zero` is True).
    kwargs
        Keyword arguments to pass to `seaborn.heatmap` (`cluster` is False)
        or `seaborn.clustermap`. Defaults set here: cmap="vlag", vmin=-2,
        vmax=2, xticklabels=True, yticklabels=True.

    Returns
    -------
    Union[Union[sns.matrix.ClusterGrid, plt.axis],
          Tuple[pd.DataFrame, Union[sns.matrix.ClusterGrid, plt.axis]]]
        A seaborn `ClusterGrid` object if `cluster` is True, a matplotlib axis
        otherwise. If `return_differences` is True, a tuple of the difference
        data frame and the plot object.

    Examples
    --------
    >>> from pymantra.datasets import example_multiomics_enrichment_data
    >>> from pymantra import (
    ...     compute_reaction_estimates, compute_multiomics_associations)
    >>> metabolite_data, microbiome_data, sample_groups, graph = \
    ...     example_multiomics_enrichment_data()
    >>> residuals = \
    ...     compute_reaction_estimates(graph, metabolite_data, sample_groups)
    >>> corrs, pvals = compute_multiomics_associations(
    ...     residuals, microbiome_data, sample_groups, comparison=("0", "1"))
    >>> diff_associations, clust_map = plot_correlation_differences(
    ...     corrs, pvals, "0", "1", cluster=True, return_differences=True)
    """
    reference, target = _prep_correlation_dfs(
        corrs, ref_group, tgt_group, pvals, set_zero, thresh,
        strip_column_names
    )
    differences = reference - target
    if remove_all_zeros:
        differences = _remove_zero_features(differences)

    # defaults suited to a difference of two correlations (range [-2, 2]);
    # any matching user keyword overrides its default
    plot_kwargs = {
        "cmap": "vlag",
        "vmin": -2,
        "vmax": 2,
        "xticklabels": True,
        "yticklabels": True,
    }
    plot_kwargs.update(kwargs)
    return _plot_clustmap(
        differences, cluster, reorder, return_differences, ax, **plot_kwargs
    )
def plot_reaction_association(
    residuals: pd.DataFrame, omics_data: pd.DataFrame,
    corrs: pd.DataFrame, groups: pd.Series = None, top_n: int = 3,
    axes: List[plt.axis] = None, pal: Dict[str, str] = None
):
    """Plot the highest correlating associations between reactions and
    non-metabolomic omics-data

    Parameters
    ----------
    residuals: pd.DataFrame
        Residual values as computed by `pymantra.compute_reaction_estimates`
        (samples in rows, reactions in columns)
    omics_data: pd.DataFrame
        Multi-omics data associated with `residuals`. Rows must be features
        (e.g. microbial species or transcript abundances) and columns samples.
    corrs: pd.DataFrame
        Correlations between omics-data (rows) and reaction residuals
        (columns), in the same order as the rows of `omics_data` and the
        columns of `residuals`. Entries are matched by position.
    groups: pd.Series
        Sample groups, index must be the same as the index of `residuals`
    top_n: int, 3
        Number of associations to plot. If `axes` is given, len(axes) will be
        used instead.
    axes: List[plt.axis], optional
        (Flat) List of axes to plot onto. If `None`, `top_n` axes will be
        created inside the function. If you want to have a multi-row layout,
        we recommend using `plt.subplots` with the `ncols`/`nrows` parameters
        and then passing the axes with `axes.flatten()`.
    pal: Dict[str, str], optional
        Colour palette to use. Only used if `groups` is given. The keys must
        be the group names that appear in `groups` and the values must be
        strings defining colors in a matplotlib-compatible format.

    Returns
    -------
    List[plt.axis]
        The axes the association plots were drawn onto (as passed in, or as
        created by this function)

    Raises
    ------
    ValueError
        If the sample counts of `residuals` and `omics_data` differ, or if
        the shape of `corrs` does not match them.

    Examples
    --------
    >>> from pymantra.datasets import example_multiomics_enrichment_data
    >>> from pymantra import (
    ...     compute_reaction_estimates, compute_multiomics_associations)
    >>> metabolite_data, microbiome_data, sample_groups, graph = \
    ...     example_multiomics_enrichment_data()
    >>> residuals = \
    ...     compute_reaction_estimates(graph, metabolite_data, sample_groups)
    >>> corrs, pvals = compute_multiomics_associations(
    ...     residuals, microbiome_data, sample_groups, comparison=("0", "1"))
    >>> diff_associations, clust_map = plot_correlation_differences(
    ...     corrs, pvals, "0", "1", cluster=True, return_differences=True)
    >>> plot_reaction_association(
    ...     residuals, microbiome_data, diff_associations, sample_groups)
    """
    # shape checking
    if residuals.shape[0] != omics_data.shape[1]:
        raise ValueError(
            "'residuals' and 'omics_data' need to have matching samples")
    n, m = residuals.shape[1], omics_data.shape[0]
    if n != corrs.shape[1] or m != corrs.shape[0]:
        raise ValueError(
            "'corrs' must have multi-omics data in rows and residuals in "
            f"columns. Expected {m} rows, found {corrs.shape[0]}.\n"
            f"Expected {n} columns, found {corrs.shape[1]}."
        )

    # (N * M) x 2 array of (row, column) positions in `corrs`, sorted by
    # decreasing absolute value (NaN last). column_stack keeps the 2-D shape
    # even for a single correlation, where dstack + squeeze collapsed to 1-D.
    order = np.argsort(-corrs.abs().to_numpy().ravel())
    top_corrs = np.column_stack(np.unravel_index(order, corrs.shape))

    if axes is None:
        _, axes = plt.subplots(ncols=top_n, figsize=(16, 9))
        if top_n == 1:
            axes = [axes]

    # zip stops early if there are more axes than correlations
    if groups is None:
        # plotting without group coloring
        for ax, (yidx, xidx) in zip(axes, top_corrs):
            ax.scatter(residuals.iloc[:, xidx], omics_data.iloc[yidx, :])
            # labels taken from the frames the plotted data comes from
            ax.set_xlabel(residuals.columns[xidx])
            ax.set_ylabel(omics_data.index[yidx])
    else:
        # plotting with group coloring
        masks = {group: groups == group for group in groups.unique()}
        if pal is None:
            pal = dict(zip(
                masks.keys(), plt.rcParams["axes.prop_cycle"].by_key()["color"]
            ))
        last_ax = None
        for ax, (yidx, xidx) in zip(axes, top_corrs):
            for group, mask in masks.items():
                ax.scatter(
                    residuals.loc[mask, :].iloc[:, xidx],
                    omics_data.loc[:, mask].iloc[yidx, :],
                    c=pal.get(group, "tab:grey"), label=group
                )
            ax.set_xlabel(residuals.columns[xidx])
            ax.set_ylabel(omics_data.index[yidx])
            last_ax = ax
        # Legend only on the last plotted axis. plt.legend() would target the
        # pyplot "current" axes, which need not be one of `axes`.
        if last_ax is not None:
            last_ax.legend()
    return axes
def _residual_to_seaborn_long(
residuals_: pd.DataFrame, groups: pd.Series, id_vars: List[str] = None,
):
"""Turn a wide residual data frame into a long format"""
residuals = residuals_.copy()
residuals["Group"] = groups
if id_vars is None:
id_vars = ["Group"]
return residuals.melt(
id_vars=id_vars, value_vars=residuals.columns,
value_name="Residual", var_name="Reaction"
)
def _move_legend(
    ax: plt.axis, legend_position: Tuple[str, Tuple[float, float]],
    drop_legend: bool = False
):
    """Remove or reposition the legend of a seaborn plot.

    Parameters
    ----------
    ax: plt.axis
        Axis whose legend is handled. If it has no legend, nothing happens.
    legend_position: Tuple[str, Tuple[float, float]] or None
        Legend location name and `bbox_to_anchor` offset. If None, the legend
        is placed "lower right" anchored at (1.05, 0), i.e. right of the axis.
    drop_legend: bool, False
        Remove the legend instead of moving it.
    """
    legend = ax.get_legend()
    # Without a legend, ``legend.remove()`` fails on None and
    # ``sns.move_legend`` raises. This can happen e.g. when legend drawing
    # was disabled in the seaborn call.
    if legend is None:
        return
    if drop_legend:
        legend.remove()
    elif legend_position is None:
        sns.move_legend(ax, "lower right", bbox_to_anchor=(1.05, 0))
    else:
        sns.move_legend(
            ax, legend_position[0], bbox_to_anchor=legend_position[1])
def _rotate_xticks(
    ax: plt.axis, labels: List[str], rotate_labels: bool
):
    """Replace the x tick labels of `ax` with `labels`.

    The current tick positions are kept. With `rotate_labels`, labels are
    turned by 90 degrees and horizontally centred on their ticks.
    """
    label_style = dict(rotation=90, ha="center") if rotate_labels else {}
    ax.set_xticks(ax.get_xticks(), labels, **label_style)
def _pvalue_to_stars(pval: float) -> str:
    """Translate a p-value into the asterisk label used by violin plots.

    "****" for p < 0.001, "***" for p < 0.005, "**" for p < 0.01, "*" for
    p < 0.05 and "ns" otherwise (including NaN).
    """
    for cutoff, label in (
        (0.001, "****"), (0.005, "***"), (0.01, "**"), (0.05, "*")
    ):
        if pval < cutoff:
            return label
    return "ns"


def residual_violinplot(
    residuals_: pd.DataFrame, groups: pd.Series, pvalues: pd.Series = None,
    ax: plt.axis = None, fontsize: int = 12, significance_only: bool = True,
    legend_position: Tuple[str, Tuple[float, float]] = None,
    drop_legend: bool = False, rotate_labels: bool = False,
    plot_significant_features: bool = False, thresh: float = .05, **kwargs
) -> plt.axis:
    """Plot the residual values per group as a violinplot

    Parameters
    ----------
    residuals_: pd.DataFrame
        Residual values as computed by `pymantra.compute_reaction_estimates`
    groups: pd.Series
        Sample groups, index must be the same as the indices in `residuals_`
    pvalues: pd.Series, optional
        p-values to be added to the plot, indexed by reaction (column names
        of `residuals_`)
    ax: plt.axis, optional
        Axis to plot onto. If `None`, a new figure will be initialised inside
        the function
    fontsize: int, 12
        Font size for p-value annotations
    significance_only: bool, True
        If True and `pvalues` is not None, the p-values will be shown as
        asterisks or 'ns' indicating different levels of significance
        (* < 0.05, ** < 0.01, *** < 0.005, **** < 0.001). If False, p-values
        rounded to three digits are displayed.
        If `pvalues` is None, this has no effect.
    legend_position: Tuple[str, Tuple[float, float]], optional
        Where to position the legend. First element is a string compatible with
        the typical matplotlib legend location names, the second is a 2-tuple
        giving the x- and y-axis offset (via `bbox_to_anchor`)
    drop_legend: bool, False
        If set to True legend drawing will be suppressed
    rotate_labels: bool, False
        Whether to rotate the x-axis ticklabels by 90 degrees or not
    plot_significant_features: bool, False
        Whether to show only significant features. Significance is defined via
        the `thresh` parameter. Can only be used if `pvalues` is given. If no
        feature passes `thresh`, the 10 features with the lowest p-values are
        shown instead.
    thresh: float, .05
        p-value threshold used to subset if `plot_significant_features` is True
    kwargs
        Optional keyword arguments to `seaborn.violinplot`. Defaults set here:
        inner="quart", split=True, saturation=1.

    Returns
    -------
    plt.axis
        Matplotlib axis used for plotting

    Examples
    --------
    >>> from pymantra.datasets import example_multiomics_enrichment_data
    >>> from pymantra import (
    ...     compute_reaction_estimates, compute_multiomics_associations)
    >>> metabolite_data, microbiome_data, sample_groups, graph = \
    ...     example_multiomics_enrichment_data()
    >>> residuals = \
    ...     compute_reaction_estimates(graph, metabolite_data, sample_groups)
    >>> residual_violinplot(residuals, sample_groups)
    """
    # select the reactions (columns of residuals_) to plot
    sig_features = residuals_.columns
    if plot_significant_features:
        if pvalues is None:
            warnings.warn(
                "'plot_significant_features' can only be used when "
                "'pvalues' is set"
            )
        else:
            sig_features = pvalues.index[pvalues < thresh]
            if sig_features.empty:
                warnings.warn(
                    "No significant features detected, with threshold "
                    f"{thresh}. Defaulting to top-10 lowest p-values."
                )
                # nsmallest skips NaN p-values. Series.argsort may report
                # NaN as position -1, which picked up the last feature.
                sig_features = pvalues.nsmallest(10).index

    # refactoring residuals data into long-format with group annotation
    res_long = _residual_to_seaborn_long(
        residuals_.loc[:, sig_features], groups)

    # actual plotting
    if ax is None:
        _, ax = plt.subplots(figsize=(16, 9))
    sns.violinplot(
        data=res_long, x="Reaction", y="Residual", hue="Group", ax=ax,
        inner=kwargs.pop("inner", "quart"), split=kwargs.pop("split", True),
        saturation=kwargs.pop("saturation", 1), **kwargs
    )

    # remove the legend or move it aside so all annotations stay visible
    _move_legend(ax, legend_position, drop_legend)

    # annotations sit one tenth of the upper y-limit below that limit
    _, ymax = ax.get_ylim()
    ymax -= ymax / 10

    # p-value annotation
    if pvalues is not None:
        def get_annotation(reac):
            pval = pvalues[reac]
            if significance_only:
                return _pvalue_to_stars(pval)
            return str(round(pval, ndigits=3))

        # one annotation per violin, matched to p-values via tick label text
        for ticklabel in ax.get_xticklabels():
            reaction = ticklabel.get_text()
            xpos = ticklabel.get_position()[0]
            ax.text(
                xpos, ymax, get_annotation(reaction), ha="center",
                va="bottom", fontsize=fontsize
            )

    # set the tick labels to the plotted reactions, optionally rotated
    _rotate_xticks(ax, sig_features, rotate_labels)
    return ax