# -*- coding: utf-8 -*-
"""Functions for making pretty plots and whatnot."""
import os
from typing import Iterable
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa
import nibabel as nib
import numpy as np
from .freesurfer import FSIGNORE, _decode_list
def _grid_communities(communities):
"""
Generate boundaries of `communities`.
Parameters
----------
communities : array_like
Community assignment vector
Returns
-------
bounds : list
Boundaries of communities
"""
communities = np.asarray(communities)
if 0 in communities:
communities = communities + 1
comm = communities[np.argsort(communities)]
bounds = []
for i in np.unique(comm):
ind = np.where(comm == i)
if len(ind) > 0:
bounds.append(np.min(ind))
bounds.append(len(communities))
return bounds
def sort_communities(consensus, communities):
    """
    Sort `communities` in `consensus` according to strength.

    Nodes are grouped by community label (ascending). Within each community
    they are ordered by decreasing mean value of their row in that
    community's diagonal block of `consensus`.

    Parameters
    ----------
    consensus : (N, N) array_like
        Correlation matrix
    communities : (N,) array_like
        Community assignments for `consensus`

    Returns
    -------
    inds : np.ndarray
        Index array for sorting `consensus`
    """
    # `np.ix_` fancy indexing below requires an ndarray, not nested lists
    consensus = np.asarray(consensus)
    communities = np.asarray(communities)

    bounds = _grid_communities(communities)
    inds = np.argsort(communities)
    for start, stop in zip(bounds[:-1], bounds[1:]):
        members = inds[start:stop]
        strength = consensus[np.ix_(members, members)].mean(axis=1)
        # strongest nodes first within the community block
        inds[start:stop] = members[strength.argsort()[::-1]]
    return inds
def plot_mod_heatmap(data, communities, *, inds=None, edgecolor='black',
                     ax=None, figsize=(6.4, 4.8), xlabels=None, ylabels=None,
                     xlabelrotation=90, ylabelrotation=0, cbar=True,
                     square=True, xticklabels=None, yticklabels=None,
                     mask_diagonal=True, **kwargs):
    """
    Plot `data` as heatmap with borders drawn around `communities`.

    Parameters
    ----------
    data : (N, N) array_like
        Correlation matrix
    communities : (N,) array_like
        Community assignments for `data`
    inds : (N,) array_like, optional
        Index array for sorting `data` within `communities`. If None, these
        will be generated from `data` with :func:`sort_communities`.
        Default: None
    edgecolor : str, optional
        Color for lines demarcating community boundaries. Default: 'black'
    ax : matplotlib.axes.Axes, optional
        Axis on which to plot the heatmap. If none provided, a new figure and
        axis will be created. Default: None
    figsize : tuple, optional
        Size of figure to create if `ax` is not provided. Default: (6.4, 4.8)
    {x,y}labels : list, optional
        List of labels on {x,y}-axis for each community in `communities`. The
        number of labels should match the number of unique communities.
        Default: None
    {x,y}labelrotation : float, optional
        Angle of the rotation of the labels. Available only if `{x,y}labels`
        provided. Default : xlabelrotation: 90, ylabelrotation: 0
    square : bool, optional
        Setting the matrix with equal aspect. Default: True
    {x,y}ticklabels : list, optional
        Incompatible with `{x,y}labels`. List of labels for each entry (not
        community) in `data`, in the original (unsorted) order; they are
        reordered with `inds`. Default: None
    cbar : bool, optional
        Whether to plot colorbar. Default: True
    mask_diagonal : bool, optional
        Whether to mask the diagonal in the plotted heatmap. Default: True
    kwargs : key-value mapping
        Keyword arguments for `plt.pcolormesh()`

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axis object containing plot

    Raises
    ------
    ValueError
        If both `{x,y}labels` and `{x,y}ticklabels` are given for the same
        axis, or if the number of `{x,y}labels` does not match the number of
        unique communities.
    """
    for ticks, labels in ((xticklabels, xlabels), (yticklabels, ylabels)):
        if ticks is not None and labels is not None:
            raise ValueError('Cannot set both {x,y}labels and {x,y}ticklabels')

    # accept any array_like; the `np.ix_` reordering below needs an ndarray
    data = np.asarray(data)

    # get indices for sorting consensus
    if inds is None:
        inds = sort_communities(data, communities)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    # plot data re-ordered based on community and node strength
    plot_data = data[np.ix_(inds, inds)]
    if mask_diagonal:
        plot_data = np.ma.masked_where(np.eye(len(data)), plot_data)
    coll = ax.pcolormesh(plot_data, edgecolor='none', **kwargs)
    ax.set(xlim=(0, plot_data.shape[1]), ylim=(0, plot_data.shape[0]))

    # set equal aspect
    if square:
        ax.set_aspect('equal')

    for side in ['top', 'right', 'left', 'bottom']:
        ax.spines[side].set_visible(False)

    # invert the y-axis so it looks "as expected"
    ax.invert_yaxis()

    # plot the colorbar; pass `ax` explicitly so it is attached next to this
    # heatmap even when `ax` is not the figure's current axes
    if cbar:
        cb = ax.figure.colorbar(coll, ax=ax)
        if kwargs.get('rasterized', False):
            cb.solids.set_rasterized(True)

    # draw borders around communities; the outermost edges are nudged inward
    # by 0.2 (presumably so the border is not clipped at the axis limits)
    bounds = _grid_communities(communities)
    edges = list(bounds)
    edges[0] += 0.2
    edges[-1] -= 0.2
    for start, width in zip(edges[:-1], np.diff(edges)):
        ax.add_patch(patches.Rectangle((start, start), width, width,
                                       fill=False, linewidth=2,
                                       edgecolor=edgecolor))

    if xlabels is not None or ylabels is not None:
        # one tick at the midpoint of each (un-nudged) community block
        tickloc = [(lo + hi) / 2 for lo, hi in zip(bounds[:-1], bounds[1:])]
        if xlabels is not None:
            # make sure number of labels match the number of ticks
            if len(tickloc) != len(xlabels):
                raise ValueError('Number of labels do not match the number of '
                                 'unique communities.')
            ax.set_xticks(tickloc)
            ax.set_xticklabels(labels=xlabels, rotation=xlabelrotation)
            ax.tick_params(left=False, bottom=False)
        if ylabels is not None:
            # make sure number of labels match the number of ticks
            if len(tickloc) != len(ylabels):
                raise ValueError('Number of labels do not match the number of '
                                 'unique communities.')
            ax.set_yticks(tickloc)
            ax.set_yticklabels(labels=ylabels, rotation=ylabelrotation)
            ax.tick_params(left=False, bottom=False)

    # per-entry labels follow the same reordering as the plotted matrix
    if xticklabels is not None:
        labels_ind = [xticklabels[i] for i in inds]
        ax.set_xticks(np.arange(len(labels_ind)) + 0.5)
        ax.set_xticklabels(labels_ind, rotation=90)
    if yticklabels is not None:
        labels_ind = [yticklabels[i] for i in inds]
        ax.set_yticks(np.arange(len(labels_ind)) + 0.5)
        ax.set_yticklabels(labels_ind)

    return ax
def plot_conte69(data, lhlabel, rhlabel, surf='midthickness',
                 vmin=None, vmax=None, colormap='viridis',
                 colorbar=True, num_labels=4, orientation='horizontal',
                 colorbartitle=None, backgroundcolor=(1, 1, 1),
                 foregroundcolor=(0, 0, 0), **kwargs):
    """
    Plot surface `data` on the Conte69 atlas.

    Convenience wrapper around :func:`plot_fslr` with the atlas fixed to
    ``'conte69'``; every other argument is forwarded unchanged.

    Parameters
    ----------
    data : (N,) array_like
        Surface data for N parcels
    lhlabel : str
        Path to .gii file (generic GIFTI file) containing labels to N/2 parcels
        on the left hemisphere
    rhlabel : str
        Path to .gii file (generic GIFTI file) containing labels to N/2 parcels
        on the right hemisphere
    surf : {'midthickness', 'inflated', 'vinflated'}, optional
        Type of brain surface (forwarded as `surf_type`).
        Default: 'midthickness'
    vmin : float, optional
        Minimum value to scale the colormap. If None, the min of the data will
        be used. Default: None
    vmax : float, optional
        Maximum value to scale the colormap. If None, the max of the data will
        be used. Default: None
    colormap : str, optional
        Any colormap from matplotlib. Default: 'viridis'
    colorbar : bool, optional
        Whether to display a colorbar. Default: True
    num_labels : int, optional
        The number of labels to display on the colorbar.
        Available only if colorbar=True. Default: 4
    orientation : str, optional
        Defines the orientation of colorbar. Can be 'horizontal' or 'vertical'.
        Available only if colorbar=True. Default: 'horizontal'
    colorbartitle : str, optional
        The title of colorbar. Available only if colorbar=True. Default: None
    backgroundcolor : tuple of float values with RGB code in [0, 1], optional
        Defines the background color. Default: (1, 1, 1)
    foregroundcolor : tuple of float values with RGB code in [0, 1], optional
        Defines the foreground color (e.g., colorbartitle color).
        Default: (0, 0, 0)
    kwargs : key-value mapping
        Keyword arguments for `mayavi.mlab.triangular_mesh()`

    Returns
    -------
    figures
        Whatever :func:`plot_fslr` returns: the left- and right-hemisphere
        figures, in that order.
    """
    display_opts = dict(vmin=vmin, vmax=vmax, colormap=colormap,
                        colorbar=colorbar, num_labels=num_labels,
                        orientation=orientation,
                        colorbartitle=colorbartitle,
                        backgroundcolor=backgroundcolor,
                        foregroundcolor=foregroundcolor)
    return plot_fslr(data, lhlabel, rhlabel, surf_atlas='conte69',
                     surf_type=surf, **display_opts, **kwargs)
def plot_fslr(data, lhlabel, rhlabel, surf_atlas='conte69',
              surf_type='midthickness', vmin=None, vmax=None,
              colormap='viridis', colorbar=True, num_labels=4,
              orientation='horizontal', colorbartitle=None,
              backgroundcolor=(1, 1, 1), foregroundcolor=(0, 0, 0),
              **kwargs):
    """
    Plot surface `data` on a given fsLR32k atlas.

    Both label files index into the same `data` vector: a vertex whose label
    value is 0 is drawn as missing (NaN), and a vertex whose label value is
    ``k > 0`` is assigned ``data[k - 1]``.

    Parameters
    ----------
    data : (N,) array_like
        Surface data for N parcels
    lhlabel : str
        Path to .gii file (generic GIFTI file) containing labels to N/2 parcels
        on the left hemisphere
    rhlabel : str
        Path to .gii file (generic GIFTI file) containing labels to N/2 parcels
        on the right hemisphere
    surf_atlas: {'conte69', 'yerkes19'}, optional
        Surface atlas on which to plot 'data'. Default: 'conte69'
    surf_type : {'midthickness', 'inflated', 'vinflated'}, optional
        Type of brain surface. Default: 'midthickness'
    vmin : float, optional
        Minimum value to scale the colormap. If None, the min of the data will
        be used. Default: None
    vmax : float, optional
        Maximum value to scale the colormap. If None, the max of the data will
        be used. Default: None
    colormap : str, optional
        Any colormap from matplotlib. Default: 'viridis'
    colorbar : bool, optional
        Whether to display a colorbar. Default: True
    num_labels : int, optional
        The number of labels to display on the colorbar.
        Available only if colorbar=True. Default: 4
    orientation : str, optional
        Defines the orientation of colorbar. Can be 'horizontal' or 'vertical'.
        Available only if colorbar=True. Default: 'horizontal'
    colorbartitle : str, optional
        The title of colorbar. Available only if colorbar=True. Default: None
    backgroundcolor : tuple of float values with RGB code in [0, 1], optional
        Defines the background color. Default: (1, 1, 1)
    foregroundcolor : tuple of float values with RGB code in [0, 1], optional
        Defines the foreground color (e.g., colorbartitle color).
        Default: (0, 0, 0)
    kwargs : key-value mapping
        Keyword arguments for `mayavi.mlab.triangular_mesh()`

    Returns
    -------
    lhplot, rhplot : mayavi figures
        Figures (created with ``mlab.figure()``) for the left and right
        hemispheres

    Raises
    ------
    ImportError
        If mayavi is not installed.
    ValueError
        If `surf_atlas` or `surf_type` is not a valid choice.
    """
    from .datasets import fetch_conte69, fetch_yerkes19
    try:
        from mayavi import mlab
    except ImportError:
        raise ImportError('Cannot use plot_fslr() if mayavi is not '
                          'installed. Please install mayavi and try again.') from None

    opts = dict(kwargs)

    fetchers = {'conte69': fetch_conte69, 'yerkes19': fetch_yerkes19}
    if surf_atlas not in fetchers:
        # fail fast: an unknown atlas would otherwise leave `surface` unset
        raise ValueError('Provided surf_atlas "{}" is not valid. Must be one '
                         'of {}'.format(surf_atlas, sorted(fetchers)))
    try:
        surface = fetchers[surf_atlas]()[surf_type]
    except KeyError:
        raise ValueError('Provided surf "{}" is not valid. Must be one of '
                         '[\'midthickness\', \'inflated\', \'vinflated\']'
                         .format(surf_type)) from None

    lhsurface, rhsurface = [nib.load(s) for s in surface]

    lhlabels = nib.load(lhlabel).darrays[0].data
    rhlabels = nib.load(rhlabel).darrays[0].data
    lhvert, lhface = [d.data for d in lhsurface.darrays]
    rhvert, rhface = [d.data for d in rhsurface.darrays]

    # prepend a NaN so that label 0 (medial wall) maps to NaN and label k
    # maps to data[k - 1]
    data = np.append(np.nan, data)

    # get lh and rh data
    lhdata = np.squeeze(data[lhlabels.astype(int)])
    rhdata = np.squeeze(data[rhlabels.astype(int)])

    # RGBA color used for NaN (masked) vertices
    nan_rgba = [0.863, 0.863, 0.863, 1]

    # plot
    lhplot = mlab.figure()
    rhplot = mlab.figure()
    lhmesh = mlab.triangular_mesh(lhvert[:, 0], lhvert[:, 1], lhvert[:, 2],
                                  lhface, figure=lhplot, colormap=colormap,
                                  mask=np.isnan(lhdata), scalars=lhdata,
                                  vmin=vmin, vmax=vmax, **opts)
    lhmesh.module_manager.scalar_lut_manager.lut.nan_color = list(nan_rgba)
    lhmesh.update_pipeline()
    if colorbar is True:
        mlab.colorbar(title=colorbartitle, nb_labels=num_labels,
                      orientation=orientation)
    rhmesh = mlab.triangular_mesh(rhvert[:, 0], rhvert[:, 1], rhvert[:, 2],
                                  rhface, figure=rhplot, colormap=colormap,
                                  mask=np.isnan(rhdata), scalars=rhdata,
                                  vmin=vmin, vmax=vmax, **opts)
    rhmesh.module_manager.scalar_lut_manager.lut.nan_color = list(nan_rgba)
    rhmesh.update_pipeline()
    if colorbar is True:
        mlab.colorbar(title=colorbartitle, nb_labels=num_labels,
                      orientation=orientation)

    mlab.view(azimuth=180, elevation=90, distance=450, figure=lhplot)
    mlab.view(azimuth=180, elevation=-90, distance=450, figure=rhplot)

    mlab.figure(bgcolor=backgroundcolor, fgcolor=foregroundcolor,
                figure=lhplot)
    mlab.figure(bgcolor=backgroundcolor, fgcolor=foregroundcolor,
                figure=rhplot)

    return lhplot, rhplot
def _get_fs_subjid(subject_id, subjects_dir=None):
    """
    Resolve `subject_id` and `subjects_dir`, fetching fsaverage if required.

    Parameters
    ----------
    subject_id : str
        FreeSurfer subject ID
    subjects_dir : str, optional
        Path to FreeSurfer subject directory. If not set, will inherit from
        the environmental variable $SUBJECTS_DIR. Default: None

    Returns
    -------
    subject_id : str
        FreeSurfer subject ID
    subjects_dir : str
        Path to subject directory with `subject_id`

    Raises
    ------
    ValueError
        If `subject_id` is not found and its name does not contain
        'fsaverage' (only fsaverage templates can be fetched).
    """
    from netneurotools.utils import check_fs_subjid

    # first choice: an existing FreeSurfer install that already has the subject
    try:
        found_id, found_dir = check_fs_subjid(subject_id, subjects_dir)
    except FileNotFoundError:
        if 'fsaverage' not in subject_id:
            raise ValueError('Provided subject {} does not exist in provided '
                             'subjects_dir {}'
                             .format(subject_id, subjects_dir)) from None
    else:
        return found_id, found_dir

    # fallback: download the requested fsaverage template into the data dir
    from netneurotools.datasets import fetch_fsaverage
    from netneurotools.datasets.utils import _get_data_dir
    fetch_fsaverage(subject_id)
    template_dir = os.path.join(_get_data_dir(), 'tpl-fsaverage')
    found_id, found_dir = check_fs_subjid(subject_id, template_dir)
    return found_id, found_dir
def plot_fsaverage(data, *, lhannot, rhannot, order='lr', mask=None,
                   noplot=None, subject_id='fsaverage', subjects_dir=None,
                   vmin=None, vmax=None, **kwargs):
    """
    Plot `data` to fsaverage brain using `annot` as parcellation.

    Parameters
    ----------
    data : (N,) array_like
        Data for `N` parcels as defined in `annot`
    lhannot : str or os.PathLike
        Filepath to .annot file containing labels to parcels on the left
        hemisphere. If a full path is not provided the file is assumed to
        exist inside the `subjects_dir`/`subject`/label directory.
    rhannot : str or os.PathLike
        Filepath to .annot file containing labels to parcels on the right
        hemisphere. If a full path is not provided the file is assumed to
        exist inside the `subjects_dir`/`subject`/label directory.
    order : str, optional
        Order of the hemispheres in the data vector (either 'LR' or 'RL',
        case-insensitive). Default: 'LR'
    mask : (N,) array_like, optional
        Binary array where entries indicate whether values in `data` should be
        masked from plotting (True = mask; False = show). Default: None
    noplot : list, optional
        List of names in `lhannot` and `rhannot` to not plot. It is assumed
        these are NOT present in `data`. By default 'unknown' and
        'corpuscallosum' will never be plotted if they are present in the
        provided annotation files. Default: None
    subject_id : str, optional
        Subject ID to use; must be present in `subjects_dir`. Default:
        'fsaverage'
    subjects_dir : str, optional
        Path to FreeSurfer subject directory. If not set, will inherit from
        the environmental variable $SUBJECTS_DIR. Default: None
    vmin : float, optional
        Minimum value for colorbar. If not provided, a robust estimation will
        be used from values in `data`. Default: None
    vmax : float, optional
        Maximum value for colorbar. If not provided, a robust estimation will
        be used from values in `data`. Default: None
    kwargs : key-value pairs
        Provided directly to :func:`~.plot_fsvertex` without modification.

    Returns
    -------
    brain : surfer.Brain
        Plotted PySurfer brain

    Raises
    ------
    ValueError
        If `order` is not 'lr'/'rl' or `mask` differs in length from `data`.

    Examples
    --------
    >>> import numpy as np
    >>> from netneurotools.datasets import (fetch_cammoun2012,
    ...                                     fetch_schaefer2018)
    >>> from netneurotools.plotting import plot_fsaverage

    Plotting with the Cammoun 2012 parcellation we specify `order='RL'` because
    many of the Lausanne connectomes have data for the right hemisphere before
    the left hemisphere.

    >>> values = np.random.rand(219)
    >>> scale = 'scale125'
    >>> cammoun = fetch_cammoun2012('fsaverage', verbose=False)[scale]
    >>> plot_fsaverage(values, order='RL',
    ...                lhannot=cammoun.lh, rhannot=cammoun.rh) # doctest: +SKIP

    Plotting with the Schaefer 2018 parcellation we can use the default
    parameter for `order`:

    >>> values = np.random.rand(400)
    >>> scale = '400Parcels7Networks'
    >>> schaefer = fetch_schaefer2018('fsaverage', verbose=False)[scale]
    >>> plot_fsaverage(values,
    ...                lhannot=schaefer.lh,
    ...                rhannot=schaefer.rh)  # doctest: +SKIP
    """
    subject_id, subjects_dir = _get_fs_subjid(subject_id, subjects_dir)

    # cast data to float (required for NaNs)
    data = np.asarray(data, dtype='float')

    order = order.lower()
    if order not in ('lr', 'rl'):
        raise ValueError('order must be either \'lr\' or \'rl\'')

    if mask is not None:
        if len(mask) != len(data):
            raise ValueError('Provided mask must be the same length as data.')
        # float so NaN placeholders can be inserted for dropped parcels below
        # (`np.insert` cannot place NaN into an integer array)
        mask = np.asarray(mask, dtype='float')

    if vmin is None:
        vmin = np.nanpercentile(data, 2.5)
    if vmax is None:
        vmax = np.nanpercentile(data, 97.5)

    # parcels that should not be included in parcellation
    drop = FSIGNORE.copy()
    if noplot is not None:
        if isinstance(noplot, str):
            noplot = [noplot]
        drop += list(noplot)
    drop = _decode_list(drop)

    vtx_data = []
    for annot, hemi in zip((lhannot, rhannot), ('lh', 'rh')):
        # relative annotation paths live in the subject's label directory;
        # `os.path.isabs` also accepts PathLike objects
        if not os.path.isabs(annot):
            annot = os.path.join(subjects_dir, subject_id, 'label', annot)

        # loads annotation data for hemisphere, including vertex `labels`!
        labels, _, names = nib.freesurfer.read_annot(annot)
        names = _decode_list(names)

        # get appropriate data, accounting for hemispheric asymmetry; the
        # split point is computed once, from the left-hemisphere annotation
        currdrop = np.intersect1d(drop, names)
        if hemi == 'lh':
            if order == 'lr':
                split_id = len(names) - len(currdrop)
                ldata, rdata = np.split(data, [split_id])
                if mask is not None:
                    lmask, rmask = np.split(mask, [split_id])
            elif order == 'rl':
                split_id = len(data) - len(names) + len(currdrop)
                rdata, ldata = np.split(data, [split_id])
                if mask is not None:
                    rmask, lmask = np.split(mask, [split_id])
        hemidata = ldata if hemi == 'lh' else rdata

        # our `data` don't include the "ignored" parcels but our `labels` do,
        # so we need to account for that. find the label ids that correspond
        # to those and set them to NaN in the `data vector`. An explicit int
        # dtype keeps `np.insert` happy when nothing is dropped.
        inds = np.array(sorted(names.index(n) for n in currdrop), dtype=int)
        # offset so that, after insertion, the NaNs land exactly at `inds`
        insert_at = inds - np.arange(len(inds))
        hemidata = np.insert(hemidata, insert_at, np.nan)
        vtx = hemidata[labels]

        # let's also mask data, if necessary
        if mask is not None:
            maskdata = lmask if hemi == 'lh' else rmask
            maskdata = np.insert(maskdata, insert_at, np.nan)
            vtx[maskdata[labels] > 0] = np.nan

        # nibabel's read_annot marks vertices outside every label with -1,
        # which would otherwise wrap around to the last parcel's value
        vtx[labels < 0] = np.nan

        vtx_data.append(vtx)

    brain = plot_fsvertex(np.hstack(vtx_data), order='lr', mask=None,
                          subject_id=subject_id, subjects_dir=subjects_dir,
                          vmin=vmin, vmax=vmax, **kwargs)

    return brain
def plot_fsvertex(data, *, order='lr', surf='pial', views='lat',
                  vmin=None, vmax=None, center=None, mask=None,
                  colormap='viridis', colorbar=True, alpha=0.8,
                  label_fmt='%.2f', num_labels=3, size_per_view=500,
                  subject_id='fsaverage', subjects_dir=None, data_kws=None,
                  **kwargs):
    """
    Plot vertex-wise `data` to fsaverage brain.

    Parameters
    ----------
    data : (N,) array_like
        Vertex-wise data for both hemispheres; split into two equal halves,
        one per hemisphere, in the order given by `order`. The caller's
        array is not modified.
    order : {'lr', 'rl'}, optional
        Order of the hemispheres in the data vector. Default: 'lr'
    surf : str, optional
        Surface on which to plot data. Default: 'pial'
    views : str or list, optional
        Which views to plot of brain. Default: 'lat'
    vmin : float, optional
        Minimum value for colorbar. If not provided, a robust estimation will
        be used from values in `data`. Default: None
    vmax : float, optional
        Maximum value for colorbar. If not provided, a robust estimation will
        be used from values in `data`. Default: None
    center : float, optional
        Center of colormap, if desired. Default: None
    mask : (N,) array_like, optional
        Binary array where entries indicate whether values in `data` should be
        masked from plotting (True = mask; False = show). Default: None
    colormap : str, optional
        Which colormap to use for plotting `data`. Default: 'viridis'
    colorbar : bool, optional
        Whether to display the colorbar in the plot. Default: True
    alpha : [0, 1] float, optional
        Transparency of plotted `data`. Default: 0.8
    label_fmt : str, optional
        Format of colorbar labels. Default: '%.2f'
    num_labels : int, optional
        Number of labels to display on colorbar. Default: 3
    size_per_view : int, optional
        Size, in pixels, of each frame in the plotted display. Default: 500
    subject_id : str, optional
        Subject ID to use; must be present in `subjects_dir`. Default:
        'fsaverage'
    subjects_dir : str, optional
        Path to FreeSurfer subject directory. If not set, will inherit from
        the environmental variable $SUBJECTS_DIR. Default: None
    data_kws : dict, optional
        Keyword arguments for Brain.add_data()
    kwargs : key-value pairs
        Extra keyword arguments for ``surfer.Brain`` (override the defaults
        ``background='white'`` and the computed ``size``).

    Returns
    -------
    brain : surfer.Brain
        Plotted PySurfer brain

    Raises
    ------
    ImportError
        If pysurfer is not installed.
    ValueError
        If `mask` differs in length from `data` or `order` is invalid.
    """
    # import lazily so that this module can be imported without pysurfer
    try:
        from surfer import Brain
    except ImportError:
        raise ImportError('Cannot use plot_fsaverage() if pysurfer is not '
                          'installed. Please install pysurfer and try again.') from None

    subject_id, subjects_dir = _get_fs_subjid(subject_id, subjects_dir)

    # always copy to float: NaNs are required, and the per-hemisphere views
    # from `np.split` below are written to in place, which must not leak
    # back into the caller's array
    data = np.array(data, dtype='float')

    # handle data_kws if None
    if data_kws is None:
        data_kws = {}

    if mask is not None and len(mask) != len(data):
        raise ValueError('Provided mask must be the same length as data.')

    order = order.lower()
    if order not in ['lr', 'rl']:
        raise ValueError('Specified order must be either \'lr\' or \'rl\'')

    if vmin is None:
        vmin = np.nanpercentile(data, 2.5)
    if vmax is None:
        vmax = np.nanpercentile(data, 97.5)

    # set up brain views
    if not isinstance(views, (np.ndarray, list)):
        views = [views]

    # size of window will depend on # of views provided
    size = (size_per_view * 2, size_per_view * len(views))
    brain_kws = dict(background='white', size=size)
    brain_kws.update(**kwargs)
    brain = Brain(subject_id=subject_id, hemi='split', surf=surf,
                  subjects_dir=subjects_dir, views=views, **brain_kws)

    hemis = ('lh', 'rh') if order == 'lr' else ('rh', 'lh')
    hemi_data = np.split(data, 2)
    if mask is not None:
        hemi_masks = np.split(np.asarray(mask, dtype=bool), 2)
    else:
        hemi_masks = (None, None)

    for hemi, vtx_data, hemi_mask in zip(hemis, hemi_data, hemi_masks):
        # let's mask data, if necessary
        if hemi_mask is not None:
            vtx_data[hemi_mask] = np.nan

        # we don't want NaN values plotted: replace them with (min - 1) and
        # put the threshold at (min - 0.5) so only they fall below it
        thresh, nanmask = None, np.isnan(vtx_data)
        if np.any(nanmask):
            thresh = np.nanmin(vtx_data) - 1
            vtx_data[nanmask] = thresh
            thresh += 0.5

        # finally, add data to this hemisphere!
        brain.add_data(vtx_data, vmin, vmax, hemi=hemi, mid=center,
                       thresh=thresh, alpha=1.0, remove_existing=False,
                       colormap=colormap, colorbar=colorbar, verbose=False,
                       **data_kws)

        if alpha != 1.0:
            for surface in brain.data_dict[hemi]['surfaces']:
                surface.actor.property.opacity = alpha
                surface.render()

        # if we have a colorbar, update parameters accordingly
        if colorbar:
            surfaces = brain.data_dict[hemi]['surfaces']
            cbars = brain.data_dict[hemi]['colorbars']

            # this updates the format of the colorbar labels
            if label_fmt is not None:
                for idx, cm in enumerate(cbars):
                    cm.scalar_bar.label_format = label_fmt
                    surfaces[idx].render()

            # this updates how many labels are shown on the colorbar
            if num_labels is not None:
                for idx, cm in enumerate(cbars):
                    cm.scalar_bar.number_of_labels = num_labels
                    surfaces[idx].render()

    return brain
def plot_point_brain(data, coords, views=None, views_orientation='vertical',
                     views_size=(4, 2.4), cbar=False, robust=True, size=50,
                     **kwargs):
    """
    Plot `data` as a cloud of points in 3D space based on specified `coords`.

    Parameters
    ----------
    data : (N,) array_like
        Data for an `N` node parcellation; determines color of points
    coords : (N, 3) array_like
        x, y, z coordinates for `N` node parcellation
    views : str or list, optional
        Which views to use. Can be any of {'sagittal', 'sag', 'coronal',
        'cor', 'axial', 'ax'}. If not specified will use 'sagittal' and
        'axial'. Default: None
    views_orientation: str, optional
        Orientation of the views. Can be either 'vertical' or 'horizontal'.
        Default: 'vertical'.
    views_size : tuple, optional
        Figure size of each view. Default: (4, 2.4)
    cbar : bool, optional
        Whether to also show colorbar. Default: False
    robust : bool, optional
        Whether to use robust calculation of `vmin` and `vmax` for color scale
        (2.5th / 97.5th percentiles of `data`, ignoring NaNs). Explicit
        `vmin` / `vmax` in `kwargs` take precedence. Default: True
    size : int, optional
        Size of points on plot. Default: 50
    **kwargs
        Key-value pairs passed to `matplotlib.axes.Axis.scatter`

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        Figure containing one 3D axis per requested view

    Raises
    ------
    KeyError
        If an entry of `views` is not a recognized view name.
    ValueError
        If `views_orientation` is neither 'vertical' nor 'horizontal'.
    """
    # (elevation, azimuth) passed to `Axes3D.view_init` for each view name
    _views = dict(sagittal=(0, 180), sag=(0, 180),
                  axial=(90, 180), ax=(90, 180),
                  coronal=(0, 90), cor=(0, 90))

    # accept any array_like (e.g. nested lists) for the coordinates
    coords = np.asarray(coords)
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    if views is None:
        views = ['sagittal', 'axial']
    elif not isinstance(views, Iterable) or isinstance(views, str):
        views = [views]
    views = [_views[f] for f in views]

    if views_orientation == 'vertical':
        ncols, nrows = 1, len(views)
    elif views_orientation == 'horizontal':
        ncols, nrows = len(views), 1
    else:
        raise ValueError('views_orientation must be either \'vertical\' or '
                         '\'horizontal\', got {!r}'.format(views_orientation))
    figsize = (ncols * views_size[0], nrows * views_size[1])

    # create figure and axes (3d projections); `squeeze=False` guarantees an
    # array of axes even for a single view, so indexing and the colorbar
    # below work uniformly
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows,
                             figsize=figsize,
                             subplot_kw=dict(projection='3d'),
                             squeeze=False)
    axes = axes.flatten()

    opts = dict(linewidth=0.5, edgecolor='gray', cmap='viridis')
    if robust:
        vmin, vmax = np.nanpercentile(data, [2.5, 97.5])
        opts.update(dict(vmin=vmin, vmax=vmax))
    opts.update(kwargs)

    # iterate through sagittal/axial views and plot, rotating as needed
    for ax, view in zip(axes, views):
        # make the actual scatterplot and update the view / aspect ratios
        col = ax.scatter(x, y, z, c=data, s=size, **opts)
        ax.view_init(*view)
        ax.axis('off')
        scaling = np.array([ax.get_xlim(),
                            ax.get_ylim(),
                            ax.get_zlim()])
        ax.set_box_aspect(tuple(scaling[:, 1] - scaling[:, 0]))

    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)

    # add colorbar to axes
    if cbar:
        colorbar_obj = fig.colorbar(col, ax=axes,
                                    drawedges=False, shrink=0.7)
        colorbar_obj.outline.set_linewidth(0)

    return fig