Module skm_pyutils.plot
Utilities for plotting with matplotlib.
Expand source code
"""Utilities for plotting with matplotlib."""
import colorsys
import os
from collections import OrderedDict
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import seaborn as sns
from skm_pyutils.path import make_path_if_not_exists
class GroupManager:
def __init__(self, group_list):
self.group_list = group_list
print(self.group_list)
self.info_dict = OrderedDict()
self.index = 0
self.color_list = ["Blues", "Oranges", "Greens", "Reds", "Purples", "Greys"]
set_vals = sorted(set(group_list), key=group_list.index)
print(set_vals)
import numpy as np
if len(set_vals) > len(self.color_list):
start_vals = np.arange(0.0, 2.5, 2.5 / (len(set_vals) - 0.99))
for set_v, start_v in zip(set_vals, start_vals):
self.info_dict[set_v] = ColorManager(
group_list.count(set_v), "sns_helix", start=start_v
)
else:
start_vals = self.color_list[: len(set_vals)]
for set_v, start_v in zip(set_vals, start_vals):
self.info_dict[set_v] = ColorManager(
group_list.count(set_v), "sns", sns_style=start_v
)
print(self.info_dict)
def get_next_color(self):
out = self.info_dict[self.group_list[self.index]].get_next_color()
self._increment()
return out
def _increment(self):
self.index += 1
if self.index == len(self.group_list):
self.index = 0
def test_plot(self):
import numpy as np
from scipy.stats import norm
fig, ax = plt.subplots()
x_axis = np.arange(-15, 5, 0.001)
std_devs = np.arange(0.8, 3, 2.20 / len(self.group_list))
for sd in std_devs:
ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color())
sns.despine(top=True, bottom=True, right=True, left=True)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("test.png", dpi=400)
class ColorManager:
def __init__(self, num_colors, method="sns", **kwargs):
self.num_colors = num_colors
if method == "sns":
sns_style = kwargs.get("sns_style", None)
self.create_sns_palette(s_type=sns_style)
elif method == "rgb":
self.create_rgb()
elif method == "sns_helix":
start = kwargs.get("start", 0)
self.create_sns_helix(start)
else:
raise ValueError("{} not recognised method".format(method))
self.idx = 0
def create_rgb(self):
N = self.num_colors
HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)]
RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)
self.colors = list(RGB_tuples)
def create_sns_palette(self, s_type=None):
self.colors = sns.color_palette(s_type, self.num_colors)
def create_sns_helix(self, start):
self.colors = sns.cubehelix_palette(self.num_colors, start)
def get_next_color(self):
result = self.colors[self.idx]
self._increment()
return result
def get_color(self, index):
return self.colors[index]
def test_plot(self):
import numpy as np
from scipy.stats import norm
fig, ax = plt.subplots()
x_axis = np.arange(-15, 5, 0.001)
std_devs = np.arange(0.8, 3, 2.20 / self.num_colors)
for sd in std_devs:
ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color())
sns.despine(top=True, bottom=True, right=True, left=True)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
fig.savefig("test.png", dpi=400)
def _increment(self):
"""Private function to increase the internal idx counter."""
self.idx = self.idx + 1
if self.idx == self.num_colors:
self.idx = 0
class GridFig:
"""Handles gridded figures."""
def __init__(
self,
rows,
cols=None,
size_multiplier_x=5,
size_multiplier_y=5,
wspace=0.3,
hspace=0.3,
tight_layout=False,
traverse_rows=True,
):
"""
Set up the grid specifications.
If only rows is passed, the number of columns and rows
are automatically determined to be close to a square.
size_multiplier, wspace, and hspace are used for spacing
"""
if cols is None:
rows, cols = self.auto_determine_grid(rows)
self.fig = plt.figure(
figsize=(cols * size_multiplier_x, rows * size_multiplier_y),
tight_layout=tight_layout,
)
self.gs = gridspec.GridSpec(rows, cols, wspace=wspace, hspace=hspace)
self.idx = 0
self.rows = rows
self.cols = cols
self.along_rows = traverse_rows
def auto_determine_grid(self, rows):
closest_sqrt = 1
while closest_sqrt**2 < rows:
closest_sqrt += 1
return closest_sqrt, closest_sqrt
def get_ax(self, row_idx, col_idx, circular=False):
"""Add subplot with standard 1x1 gs -> returns ax."""
if circular:
return self.fig.add_subplot(self.gs[row_idx, col_idx], projection="polar")
return self.fig.add_subplot(self.gs[row_idx, col_idx])
def get_multi_ax(self, row_start, row_end, col_start, col_end):
"""Add subplot with custom gs sizes -> returns ax."""
ax = self.fig.add_subplot(self.gs[row_start:row_end, col_start:col_end])
plt.subplots_adjust(top=0.85)
return ax
def save_fig(self, out_dir, out_name):
"""Names and saves figure."""
out_loc = os.path.join(out_dir, out_name)
print(f"Saved figure to {out_loc}")
make_path_if_not_exists(out_loc)
self.fig.savefig(out_loc, dpi=400)
plt.close(self.fig)
def savefig(self, fname, **kwargs):
"""Passes all to matplotlib savefig call"""
print(f"Saved figure to {fname}")
make_path_if_not_exists(fname)
if "dpi" not in kwargs.keys():
kwargs["dpi"] = 400
self.fig.savefig(fname, **kwargs)
plt.close(self.fig)
def get_fig(self):
"""Return the figure object in this class."""
return self.fig
def get_next(self, circular=False):
"""
Get next index along rows or columns.
along rows:
1 2 3 4 5 6
7 8 9 10 11 12 ...
else:
1 3 5 7 9 11
2 4 6 8 10 12 ...
"""
if self.along_rows:
row_idx = self.idx // self.cols
col_idx = self.idx % self.cols
else:
row_idx = self.idx % self.rows
col_idx = self.idx // self.rows
ax = self.get_ax(row_idx, col_idx, circular=circular)
self._increment()
return ax
def get_next_snake(self):
"""
Get the next index in a snake like pattern.
1 2 5 6 9 10 ...
3 4 7 8 11 12 ..
"""
if self.rows != 2:
raise ValueError("Can't get snake unless there are two rows")
i = self.idx
row_idx = (i // 2) % 2
col_idx = (i // 2) + (i % 2) - row_idx
ax = self.get_ax(row_idx, col_idx)
self._increment()
return ax
def _increment(self):
"""Private function to increase the internal idx counter."""
self.idx = self.idx + 1
if self.idx == self.rows * self.cols:
self.idx = 0
class UnicodeGrabber(object):
"""This is a fully static class to get unicode chars for plotting."""
char_dict = {
"micro": "\u00B5",
"pow2": "\u00B2",
}
@staticmethod
def get_chars():
return list(UnicodeGrabber.char_dict.keys())
@staticmethod
def get(char, default=""):
return UnicodeGrabber.char_dict.get(char, default)
@staticmethod
def to_sub(num):
return chr(int(f"208{num}", 16))
Classes
class ColorManager (num_colors, method='sns', **kwargs)
-
Expand source code
class ColorManager: def __init__(self, num_colors, method="sns", **kwargs): self.num_colors = num_colors if method == "sns": sns_style = kwargs.get("sns_style", None) self.create_sns_palette(s_type=sns_style) elif method == "rgb": self.create_rgb() elif method == "sns_helix": start = kwargs.get("start", 0) self.create_sns_helix(start) else: raise ValueError("{} not recognised method".format(method)) self.idx = 0 def create_rgb(self): N = self.num_colors HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples) self.colors = list(RGB_tuples) def create_sns_palette(self, s_type=None): self.colors = sns.color_palette(s_type, self.num_colors) def create_sns_helix(self, start): self.colors = sns.cubehelix_palette(self.num_colors, start) def get_next_color(self): result = self.colors[self.idx] self._increment() return result def get_color(self, index): return self.colors[index] def test_plot(self): import numpy as np from scipy.stats import norm fig, ax = plt.subplots() x_axis = np.arange(-15, 5, 0.001) std_devs = np.arange(0.8, 3, 2.20 / self.num_colors) for sd in std_devs: ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color()) sns.despine(top=True, bottom=True, right=True, left=True) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) fig.savefig("test.png", dpi=400) def _increment(self): """Private function to increase the internal idx counter.""" self.idx = self.idx + 1 if self.idx == self.num_colors: self.idx = 0
Methods
def create_rgb(self)
-
Expand source code
def create_rgb(self): N = self.num_colors HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] RGB_tuples = map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples) self.colors = list(RGB_tuples)
def create_sns_helix(self, start)
-
Expand source code
def create_sns_helix(self, start): self.colors = sns.cubehelix_palette(self.num_colors, start)
def create_sns_palette(self, s_type=None)
-
Expand source code
def create_sns_palette(self, s_type=None): self.colors = sns.color_palette(s_type, self.num_colors)
def get_color(self, index)
-
Expand source code
def get_color(self, index): return self.colors[index]
def get_next_color(self)
-
Expand source code
def get_next_color(self): result = self.colors[self.idx] self._increment() return result
def test_plot(self)
-
Expand source code
def test_plot(self): import numpy as np from scipy.stats import norm fig, ax = plt.subplots() x_axis = np.arange(-15, 5, 0.001) std_devs = np.arange(0.8, 3, 2.20 / self.num_colors) for sd in std_devs: ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color()) sns.despine(top=True, bottom=True, right=True, left=True) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) fig.savefig("test.png", dpi=400)
class GridFig (rows, cols=None, size_multiplier_x=5, size_multiplier_y=5, wspace=0.3, hspace=0.3, tight_layout=False, traverse_rows=True)
-
Handles gridded figures.
Set up the grid specifications.
If only rows is passed, the number of columns and rows are automatically determined to be close to a square.
size_multiplier, wspace, and hspace are used for spacing
Expand source code
class GridFig: """Handles gridded figures.""" def __init__( self, rows, cols=None, size_multiplier_x=5, size_multiplier_y=5, wspace=0.3, hspace=0.3, tight_layout=False, traverse_rows=True, ): """ Set up the grid specifications. If only rows is passed, the number of columns and rows are automatically determined to be close to a square. size_multiplier, wspace, and hspace are used for spacing """ if cols is None: rows, cols = self.auto_determine_grid(rows) self.fig = plt.figure( figsize=(cols * size_multiplier_x, rows * size_multiplier_y), tight_layout=tight_layout, ) self.gs = gridspec.GridSpec(rows, cols, wspace=wspace, hspace=hspace) self.idx = 0 self.rows = rows self.cols = cols self.along_rows = traverse_rows def auto_determine_grid(self, rows): closest_sqrt = 1 while closest_sqrt**2 < rows: closest_sqrt += 1 return closest_sqrt, closest_sqrt def get_ax(self, row_idx, col_idx, circular=False): """Add subplot with standard 1x1 gs -> returns ax.""" if circular: return self.fig.add_subplot(self.gs[row_idx, col_idx], projection="polar") return self.fig.add_subplot(self.gs[row_idx, col_idx]) def get_multi_ax(self, row_start, row_end, col_start, col_end): """Add subplot with custom gs sizes -> returns ax.""" ax = self.fig.add_subplot(self.gs[row_start:row_end, col_start:col_end]) plt.subplots_adjust(top=0.85) return ax def save_fig(self, out_dir, out_name): """Names and saves figure.""" out_loc = os.path.join(out_dir, out_name) print(f"Saved figure to {out_loc}") make_path_if_not_exists(out_loc) self.fig.savefig(out_loc, dpi=400) plt.close(self.fig) def savefig(self, fname, **kwargs): """Passes all to matplotlib savefig call""" print(f"Saved figure to {fname}") make_path_if_not_exists(fname) if "dpi" not in kwargs.keys(): kwargs["dpi"] = 400 self.fig.savefig(fname, **kwargs) plt.close(self.fig) def get_fig(self): """Return the figure object in this class.""" return self.fig def get_next(self, circular=False): """ Get next index along rows or columns. along rows: 1 2 3 4 5 6 7 8 9 10 11 12 ... else: 1 3 5 7 9 11 2 4 6 8 10 12 ... """ if self.along_rows: row_idx = self.idx // self.cols col_idx = self.idx % self.cols else: row_idx = self.idx % self.rows col_idx = self.idx // self.rows ax = self.get_ax(row_idx, col_idx, circular=circular) self._increment() return ax def get_next_snake(self): """ Get the next index in a snake like pattern. 1 2 5 6 9 10 ... 3 4 7 8 11 12 .. """ if self.rows != 2: raise ValueError("Can't get snake unless there are two rows") i = self.idx row_idx = (i // 2) % 2 col_idx = (i // 2) + (i % 2) - row_idx ax = self.get_ax(row_idx, col_idx) self._increment() return ax def _increment(self): """Private function to increase the internal idx counter.""" self.idx = self.idx + 1 if self.idx == self.rows * self.cols: self.idx = 0
Methods
def auto_determine_grid(self, rows)
-
Expand source code
def auto_determine_grid(self, rows): closest_sqrt = 1 while closest_sqrt**2 < rows: closest_sqrt += 1 return closest_sqrt, closest_sqrt
def get_ax(self, row_idx, col_idx, circular=False)
-
Add subplot with standard 1x1 gs -> returns ax.
Expand source code
def get_ax(self, row_idx, col_idx, circular=False): """Add subplot with standard 1x1 gs -> returns ax.""" if circular: return self.fig.add_subplot(self.gs[row_idx, col_idx], projection="polar") return self.fig.add_subplot(self.gs[row_idx, col_idx])
def get_fig(self)
-
Return the figure object in this class.
Expand source code
def get_fig(self): """Return the figure object in this class.""" return self.fig
def get_multi_ax(self, row_start, row_end, col_start, col_end)
-
Add subplot with custom gs sizes -> returns ax.
Expand source code
def get_multi_ax(self, row_start, row_end, col_start, col_end): """Add subplot with custom gs sizes -> returns ax.""" ax = self.fig.add_subplot(self.gs[row_start:row_end, col_start:col_end]) plt.subplots_adjust(top=0.85) return ax
def get_next(self, circular=False)
-
Get next index along rows or columns.
along rows: 1 2 3 4 5 6 7 8 9 10 11 12 …
else: 1 3 5 7 9 11 2 4 6 8 10 12 …
Expand source code
def get_next(self, circular=False): """ Get next index along rows or columns. along rows: 1 2 3 4 5 6 7 8 9 10 11 12 ... else: 1 3 5 7 9 11 2 4 6 8 10 12 ... """ if self.along_rows: row_idx = self.idx // self.cols col_idx = self.idx % self.cols else: row_idx = self.idx % self.rows col_idx = self.idx // self.rows ax = self.get_ax(row_idx, col_idx, circular=circular) self._increment() return ax
def get_next_snake(self)
-
Get the next index in a snake like pattern.
1 2 5 6 9 10 … 3 4 7 8 11 12 ..
Expand source code
def get_next_snake(self): """ Get the next index in a snake like pattern. 1 2 5 6 9 10 ... 3 4 7 8 11 12 .. """ if self.rows != 2: raise ValueError("Can't get snake unless there are two rows") i = self.idx row_idx = (i // 2) % 2 col_idx = (i // 2) + (i % 2) - row_idx ax = self.get_ax(row_idx, col_idx) self._increment() return ax
def save_fig(self, out_dir, out_name)
-
Names and saves figure.
Expand source code
def save_fig(self, out_dir, out_name): """Names and saves figure.""" out_loc = os.path.join(out_dir, out_name) print(f"Saved figure to {out_loc}") make_path_if_not_exists(out_loc) self.fig.savefig(out_loc, dpi=400) plt.close(self.fig)
def savefig(self, fname, **kwargs)
-
Passes all to matplotlib savefig call
Expand source code
def savefig(self, fname, **kwargs): """Passes all to matplotlib savefig call""" print(f"Saved figure to {fname}") make_path_if_not_exists(fname) if "dpi" not in kwargs.keys(): kwargs["dpi"] = 400 self.fig.savefig(fname, **kwargs) plt.close(self.fig)
class GroupManager (group_list)
-
Expand source code
class GroupManager: def __init__(self, group_list): self.group_list = group_list print(self.group_list) self.info_dict = OrderedDict() self.index = 0 self.color_list = ["Blues", "Oranges", "Greens", "Reds", "Purples", "Greys"] set_vals = sorted(set(group_list), key=group_list.index) print(set_vals) import numpy as np if len(set_vals) > len(self.color_list): start_vals = np.arange(0.0, 2.5, 2.5 / (len(set_vals) - 0.99)) for set_v, start_v in zip(set_vals, start_vals): self.info_dict[set_v] = ColorManager( group_list.count(set_v), "sns_helix", start=start_v ) else: start_vals = self.color_list[: len(set_vals)] for set_v, start_v in zip(set_vals, start_vals): self.info_dict[set_v] = ColorManager( group_list.count(set_v), "sns", sns_style=start_v ) print(self.info_dict) def get_next_color(self): out = self.info_dict[self.group_list[self.index]].get_next_color() self._increment() return out def _increment(self): self.index += 1 if self.index == len(self.group_list): self.index = 0 def test_plot(self): import numpy as np from scipy.stats import norm fig, ax = plt.subplots() x_axis = np.arange(-15, 5, 0.001) std_devs = np.arange(0.8, 3, 2.20 / len(self.group_list)) for sd in std_devs: ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color()) sns.despine(top=True, bottom=True, right=True, left=True) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) fig.savefig("test.png", dpi=400)
Methods
def get_next_color(self)
-
Expand source code
def get_next_color(self): out = self.info_dict[self.group_list[self.index]].get_next_color() self._increment() return out
def test_plot(self)
-
Expand source code
def test_plot(self): import numpy as np from scipy.stats import norm fig, ax = plt.subplots() x_axis = np.arange(-15, 5, 0.001) std_devs = np.arange(0.8, 3, 2.20 / len(self.group_list)) for sd in std_devs: ax.plot(x_axis, norm.pdf(x_axis, -sd * 2, sd), color=self.get_next_color()) sns.despine(top=True, bottom=True, right=True, left=True) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) fig.savefig("test.png", dpi=400)
class UnicodeGrabber
-
This is a fully static class to get unicode chars for plotting.
Expand source code
class UnicodeGrabber(object): """This is a fully static class to get unicode chars for plotting.""" char_dict = { "micro": "\u00B5", "pow2": "\u00B2", } @staticmethod def get_chars(): return list(UnicodeGrabber.char_dict.keys()) @staticmethod def get(char, default=""): return UnicodeGrabber.char_dict.get(char, default) @staticmethod def to_sub(num): return chr(int(f"208{num}", 16))
Class variables
var char_dict
Static methods
def get(char, default='')
-
Expand source code
@staticmethod def get(char, default=""): return UnicodeGrabber.char_dict.get(char, default)
def get_chars()
-
Expand source code
@staticmethod def get_chars(): return list(UnicodeGrabber.char_dict.keys())
def to_sub(num)
-
Expand source code
@staticmethod def to_sub(num): return chr(int(f"208{num}", 16))