1132 lines
43 KiB
Python
1132 lines
43 KiB
Python
"""
|
|
This module contains a more flexible API for Scikit-plot users, exposing
|
|
simple functions to generate plots.
|
|
"""
|
|
from __future__ import absolute_import, division, print_function, \
|
|
unicode_literals
|
|
|
|
import warnings
|
|
import itertools
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
from sklearn.metrics import confusion_matrix
|
|
from sklearn.preprocessing import label_binarize
|
|
from sklearn.metrics import roc_curve
|
|
from sklearn.metrics import auc
|
|
from sklearn.metrics import precision_recall_curve
|
|
from sklearn.metrics import average_precision_score
|
|
from sklearn.utils.multiclass import unique_labels
|
|
from sklearn.model_selection import learning_curve
|
|
from sklearn.base import clone
|
|
from sklearn.metrics import silhouette_score
|
|
from sklearn.metrics import silhouette_samples
|
|
from sklearn.utils import deprecated
|
|
|
|
from scipy import interp
|
|
|
|
from scikitplot.helpers import binary_ks_curve, validate_labels
|
|
|
|
|
|
# Importing this legacy module is itself deprecated: the warning fires once at
# import time so callers are pointed at the per-topic replacement modules.
warnings.warn(
    "This module was deprecated in version 0.3.0 and its functions "
    "are spread throughout different modules. Please check the "
    "documentation and update your function calls as soon as "
    "possible. This module will be removed in 0.4.0",
    category=DeprecationWarning,
)
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.metrics.plot_confusion_matrix instead.')
def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
                          pred_labels=None, title=None, normalize=False,
                          hide_zeros=False, x_tick_rotation=0, ax=None,
                          figsize=None, cmap='Blues', title_fontsize="large",
                          text_fontsize="medium"):
    """Generates confusion matrix plot from predictions and true labels

    Args:
        y_true (array-like, shape (n_samples)):
            Ground truth (correct) target values.

        y_pred (array-like, shape (n_samples)):
            Estimated targets as returned by a classifier.

        labels (array-like, shape (n_classes), optional): List of labels to
            index the matrix. This may be used to reorder or select a subset
            of labels. If none is given, those that appear at least once in
            ``y_true`` or ``y_pred`` are used in sorted order. (new in v0.2.5)

        true_labels (array-like, optional): The true labels to display.
            If none is given, then all of the labels are used.

        pred_labels (array-like, optional): The predicted labels to display.
            If none is given, then all of the labels are used.

        title (string, optional): Title of the generated plot. Defaults to
            "Normalized Confusion Matrix" if `normalize` is True. Else,
            defaults to "Confusion Matrix".

        normalize (bool, optional): If True, normalizes the confusion matrix
            before plotting. Rows with no true samples are shown as 0.
            Defaults to False.

        hide_zeros (bool, optional): If True, does not plot cells containing a
            value of zero. Defaults to False.

        x_tick_rotation (int, optional): Rotates x-axis tick labels by the
            specified angle. This is useful in cases where there are numerous
            categories and the labels overlap each other.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.
            The colorbar is attached to the figure that owns this axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used for plotting the projection. View Matplotlib Colormap
            documentation for available options.
            https://matplotlib.org/users/colormaps.html

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> rf = RandomForestClassifier()
        >>> rf = rf.fit(X_train, y_train)
        >>> y_pred = rf.predict(X_test)
        >>> skplt.plot_confusion_matrix(y_test, y_pred, normalize=True)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_confusion_matrix.png
           :align: center
           :alt: Confusion matrix
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if labels is None:
        classes = unique_labels(y_true, y_pred)
    else:
        classes = np.asarray(labels)

    if normalize:
        # A class with no true samples has a zero row sum; silence the 0/0
        # warning here and map the resulting NaNs to 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cm = np.around(cm, decimals=2)
        cm[np.isnan(cm)] = 0.0

    if true_labels is None:
        true_classes = classes
    else:
        validate_labels(classes, true_labels, "true_labels")
        true_label_indexes = np.isin(classes, true_labels)
        true_classes = classes[true_label_indexes]
        cm = cm[true_label_indexes]

    if pred_labels is None:
        pred_classes = classes
    else:
        validate_labels(classes, pred_labels, "pred_labels")
        pred_label_indexes = np.isin(classes, pred_labels)
        pred_classes = classes[pred_label_indexes]
        cm = cm[:, pred_label_indexes]

    if title:
        ax.set_title(title, fontsize=title_fontsize)
    elif normalize:
        ax.set_title('Normalized Confusion Matrix', fontsize=title_fontsize)
    else:
        ax.set_title('Confusion Matrix', fontsize=title_fontsize)

    image = ax.imshow(cm, interpolation='nearest', cmap=plt.get_cmap(cmap))
    # Attach the colorbar to the figure owning ``ax``; ``plt.colorbar`` would
    # use the *current* figure, which differs when the caller passes ``ax``.
    ax.figure.colorbar(image, ax=ax)
    x_tick_marks = np.arange(len(pred_classes))
    y_tick_marks = np.arange(len(true_classes))
    ax.set_xticks(x_tick_marks)
    ax.set_xticklabels(pred_classes, fontsize=text_fontsize,
                       rotation=x_tick_rotation)
    ax.set_yticks(y_tick_marks)
    ax.set_yticklabels(true_classes, fontsize=text_fontsize)

    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if not (hide_zeros and cm[i, j] == 0):
            ax.text(j, i, cm[i, j],
                    horizontalalignment="center",
                    verticalalignment="center",
                    fontsize=text_fontsize,
                    color="white" if cm[i, j] > thresh else "black")

    ax.set_ylabel('True label', fontsize=text_fontsize)
    ax.set_xlabel('Predicted label', fontsize=text_fontsize)
    # ``grid('off')`` passes a truthy string in modern matplotlib, which turns
    # the grid ON; ``False`` disables it on every version.
    ax.grid(False)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.metrics.plot_roc_curve instead.')
def plot_roc_curve(y_true, y_probas, title='ROC Curves',
                   curves=('micro', 'macro', 'each_class'),
                   ax=None, figsize=None, cmap='nipy_spectral',
                   title_fontsize="large", text_fontsize="medium"):
    """Generates the ROC curves from labels and predicted scores/probabilities

    Args:
        y_true (array-like, shape (n_samples)):
            Ground truth (correct) target values.

        y_probas (array-like, shape (n_samples, n_classes)):
            Prediction probabilities for each class returned by a classifier.
            Column ``i`` must correspond to the ``i``-th class of
            ``np.unique(y_true)``.

        title (string, optional): Title of the generated plot. Defaults to
            "ROC Curves".

        curves (array-like): A listing of which curves should be plotted on the
            resulting plot. Defaults to `("micro", "macro", "each_class")`
            i.e. "micro" for micro-averaged curve, "macro" for macro-averaged
            curve, and "each_class" for one curve per class.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used for plotting the projection. View Matplotlib Colormap
            documentation for available options.
            https://matplotlib.org/users/colormaps.html

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        ValueError: If ``curves`` contains none of "micro", "macro" or
            "each_class".

    Example:
        >>> import scikitplot.plotters as skplt
        >>> nb = GaussianNB()
        >>> nb = nb.fit(X_train, y_train)
        >>> y_probas = nb.predict_proba(X_test)
        >>> skplt.plot_roc_curve(y_test, y_probas)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_roc_curve.png
           :align: center
           :alt: ROC Curves
    """
    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    if 'micro' not in curves and 'macro' not in curves and \
            'each_class' not in curves:
        raise ValueError('Invalid argument for curves as it '
                         'only takes "micro", "macro", or "each_class"')

    classes = np.unique(y_true)
    probas = y_probas

    # Per-class curves, keyed by the integer column index.
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(len(classes)):
        fpr[i], tpr[i], _ = roc_curve(y_true, probas[:, i],
                                      pos_label=classes[i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC curve and ROC area. The key is chosen so it
    # cannot collide with an existing entry of ``fpr``.
    micro_key = 'micro'
    i = 0
    while micro_key in fpr:
        i += 1
        micro_key += str(i)

    y_true = label_binarize(y_true, classes=classes)
    # label_binarize yields a single column for two classes; expand it so it
    # lines up column-for-column with ``probas``.
    if len(classes) == 2:
        y_true = np.hstack((1 - y_true, y_true))

    fpr[micro_key], tpr[micro_key], _ = roc_curve(y_true.ravel(),
                                                  probas.ravel())
    roc_auc[micro_key] = auc(fpr[micro_key], tpr[micro_key])

    # Compute macro-average ROC curve and ROC area

    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[x] for x in range(len(classes))]))

    # Then interpolate all ROC curves at these points. ``np.interp`` is the
    # function ``scipy.interp`` aliased, and it survives newer SciPy releases.
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(len(classes)):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # Finally average it and compute AUC
    mean_tpr /= len(classes)

    macro_key = 'macro'
    i = 0
    while macro_key in fpr:
        i += 1
        macro_key += str(i)
    fpr[macro_key] = all_fpr
    tpr[macro_key] = mean_tpr
    roc_auc[macro_key] = auc(fpr[macro_key], tpr[macro_key])

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    if 'each_class' in curves:
        # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap``
        # accepts both names and Colormap instances.
        colormap = plt.get_cmap(cmap)
        for i in range(len(classes)):
            color = colormap(float(i) / len(classes))
            ax.plot(fpr[i], tpr[i], lw=2, color=color,
                    label='ROC curve of class {0} (area = {1:0.2f})'
                          ''.format(classes[i], roc_auc[i]))

    if 'micro' in curves:
        ax.plot(fpr[micro_key], tpr[micro_key],
                label='micro-average ROC curve '
                      '(area = {0:0.2f})'.format(roc_auc[micro_key]),
                color='deeppink', linestyle=':', linewidth=4)

    if 'macro' in curves:
        ax.plot(fpr[macro_key], tpr[macro_key],
                label='macro-average ROC curve '
                      '(area = {0:0.2f})'.format(roc_auc[macro_key]),
                color='navy', linestyle=':', linewidth=4)

    # Chance-level diagonal.
    ax.plot([0, 1], [0, 1], 'k--', lw=2)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=text_fontsize)
    ax.set_ylabel('True Positive Rate', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='lower right', fontsize=text_fontsize)
    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.metrics.plot_ks_statistic instead.')
def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
                      ax=None, figsize=None, title_fontsize="large",
                      text_fontsize="medium"):
    """Generates the KS Statistic plot from labels and scores/probabilities

    Only binary targets are supported: the second column of ``y_probas`` is
    taken as the positive-class score.

    Args:
        y_true (array-like, shape (n_samples)):
            Ground truth (correct) target values. Must contain exactly two
            distinct values.

        y_probas (array-like, shape (n_samples, n_classes)):
            Prediction probabilities for each class returned by a classifier.

        title (string, optional): Title of the generated plot. Defaults to
            "KS Statistic Plot".

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the learning curve. If None, the plot is drawn on a new set of
            axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        ValueError: If ``y_true`` does not contain exactly two classes.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> lr = LogisticRegression()
        >>> lr = lr.fit(X_train, y_train)
        >>> y_probas = lr.predict_proba(X_test)
        >>> skplt.plot_ks_statistic(y_test, y_probas)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_ks_statistic.png
           :align: center
           :alt: KS Statistic
    """
    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    n_classes = len(np.unique(y_true))
    if n_classes != 2:
        raise ValueError('Cannot calculate KS statistic for data with '
                         '{} category/ies'.format(n_classes))

    # The KS curve only needs the score of the second (positive) column.
    positive_scores = y_probas[:, 1].ravel()
    (thresholds, pct1, pct2, ks_statistic,
     max_distance_at, classes) = binary_ks_curve(y_true, positive_scores)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    for curve, class_name in ((pct1, classes[0]), (pct2, classes[1])):
        ax.plot(thresholds, curve, lw=3, label='Class {}'.format(class_name))

    # Draw the KS gap as a vertical segment between the two curves at the
    # threshold of maximum separation. The y-limits are fixed to [0, 1] below,
    # so the axes-fraction ymin/ymax coincide with data coordinates.
    gap_index = np.flatnonzero(thresholds == max_distance_at)[0]
    segment_bottom, segment_top = sorted([pct1[gap_index], pct2[gap_index]])
    ks_label = 'KS Statistic: {:.3f} at {:.3f}'.format(ks_statistic,
                                                        max_distance_at)
    ax.axvline(max_distance_at, segment_bottom, segment_top,
               label=ks_label, linestyle=':', lw=3, color='black')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.0])

    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Percentage below threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='lower right', fontsize=text_fontsize)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.metrics.plot_precision_recall_curve instead.')
def plot_precision_recall_curve(y_true, y_probas,
                                title='Precision-Recall Curve',
                                curves=('micro', 'each_class'), ax=None,
                                figsize=None, cmap='nipy_spectral',
                                title_fontsize="large",
                                text_fontsize="medium"):
    """Generates the Precision Recall Curve from labels and probabilities

    Args:
        y_true (array-like, shape (n_samples)):
            Ground truth (correct) target values.

        y_probas (array-like, shape (n_samples, n_classes)):
            Prediction probabilities for each class returned by a classifier.
            Column ``i`` must correspond to the ``i``-th class of
            ``np.unique(y_true)``.

        title (string, optional): Title of the generated plot. Defaults to
            "Precision-Recall Curve".

        curves (array-like): A listing of which curves should be plotted on the
            resulting plot. Defaults to `("micro", "each_class")`
            i.e. "micro" for micro-averaged curve and "each_class" for one
            curve per class.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used for plotting the projection. View Matplotlib Colormap
            documentation for available options.
            https://matplotlib.org/users/colormaps.html

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium". Also applied to the axis labels.

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        ValueError: If ``curves`` contains neither "micro" nor "each_class".

    Example:
        >>> import scikitplot.plotters as skplt
        >>> nb = GaussianNB()
        >>> nb = nb.fit(X_train, y_train)
        >>> y_probas = nb.predict_proba(X_test)
        >>> skplt.plot_precision_recall_curve(y_test, y_probas)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_precision_recall_curve.png
           :align: center
           :alt: Precision Recall Curve
    """
    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    classes = np.unique(y_true)
    probas = y_probas

    if 'micro' not in curves and 'each_class' not in curves:
        raise ValueError('Invalid argument for curves as it '
                         'only takes "micro" or "each_class"')

    # Compute Precision-Recall curve and area for each class
    precision = dict()
    recall = dict()
    average_precision = dict()
    for i in range(len(classes)):
        precision[i], recall[i], _ = precision_recall_curve(
            y_true, probas[:, i], pos_label=classes[i])

    y_true = label_binarize(y_true, classes=classes)
    # label_binarize yields a single column for two classes; expand it so it
    # lines up column-for-column with ``probas``.
    if len(classes) == 2:
        y_true = np.hstack((1 - y_true, y_true))

    for i in range(len(classes)):
        average_precision[i] = average_precision_score(y_true[:, i],
                                                       probas[:, i])

    # Compute micro-average precision-recall curve and average precision.
    # The key is chosen so it cannot collide with an existing entry.
    micro_key = 'micro'
    i = 0
    while micro_key in precision:
        i += 1
        micro_key += str(i)

    precision[micro_key], recall[micro_key], _ = precision_recall_curve(
        y_true.ravel(), probas.ravel())
    average_precision[micro_key] = average_precision_score(y_true, probas,
                                                           average='micro')

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    if 'each_class' in curves:
        # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap``
        # accepts both names and Colormap instances.
        colormap = plt.get_cmap(cmap)
        for i in range(len(classes)):
            color = colormap(float(i) / len(classes))
            ax.plot(recall[i], precision[i], lw=2,
                    label='Precision-recall curve of class {0} '
                          '(area = {1:0.3f})'.format(classes[i],
                                                     average_precision[i]),
                    color=color)

    if 'micro' in curves:
        ax.plot(recall[micro_key], precision[micro_key],
                label='micro-average Precision-recall curve '
                      '(area = {0:0.3f})'.format(average_precision[micro_key]),
                color='navy', linestyle=':', linewidth=4)

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    # Honour text_fontsize like every other plotter in this module.
    ax.set_xlabel('Recall', fontsize=text_fontsize)
    ax.set_ylabel('Precision', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='best', fontsize=text_fontsize)
    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.estimators.plot_feature_importances instead.')
def plot_feature_importances(clf, title='Feature Importance',
                             feature_names=None, max_num_features=20,
                             order='descending', x_tick_rotation=0, ax=None,
                             figsize=None, title_fontsize="large",
                             text_fontsize="medium"):
    """Generates a plot of a classifier's feature importances.

    Args:
        clf: Classifier instance that implements ``fit`` and ``predict_proba``
            methods. The classifier must also have a ``feature_importances_``
            attribute. If it also has a non-empty list ``estimators_`` whose
            members expose ``feature_importances_`` (e.g. a random forest),
            their standard deviation is drawn as error bars.

        title (string, optional): Title of the generated plot. Defaults to
            "Feature Importance".

        feature_names (None, :obj:`list` of string, optional): Determines the
            feature names used to plot the feature importances. If None,
            feature names will be numbered.

        max_num_features (int): Determines the maximum number of features to
            plot. Defaults to 20.

        order ('ascending', 'descending', or None, optional): Determines the
            order in which the feature importances are plotted. Defaults to
            'descending'.

        x_tick_rotation (int, optional): Rotates x-axis tick labels by the
            specified angle. This is useful in cases where there are numerous
            categories and the labels overlap each other.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        TypeError: If ``clf`` has no ``feature_importances_`` attribute.
        ValueError: If ``order`` is not 'ascending', 'descending' or None.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> rf = RandomForestClassifier()
        >>> rf.fit(X, y)
        >>> skplt.plot_feature_importances(
        ...     rf, feature_names=['petal length', 'petal width',
        ...                        'sepal length', 'sepal width'])
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_feature_importances.png
           :align: center
           :alt: Feature Importances
    """
    if not hasattr(clf, 'feature_importances_'):
        raise TypeError('"feature_importances_" attribute not in classifier. '
                        'Cannot plot feature importances.')

    # Coerce to ndarray so the fancy indexing below also works for
    # list-valued importances.
    importances = np.asarray(clf.feature_importances_)

    # Error bars come from the spread across sub-estimators. Guard against an
    # empty list, which previously raised IndexError on ``estimators_[0]``.
    estimators = getattr(clf, 'estimators_', None)
    if isinstance(estimators, list) and estimators \
            and hasattr(estimators[0], 'feature_importances_'):
        std = np.std([tree.feature_importances_ for tree in estimators],
                     axis=0)
    else:
        std = None

    if order == 'descending':
        indices = np.argsort(importances)[::-1]
    elif order == 'ascending':
        indices = np.argsort(importances)
    elif order is None:
        indices = np.arange(len(importances))
    else:
        raise ValueError('Invalid argument {} for "order"'.format(order))

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if feature_names is None:
        feature_names = indices
    else:
        feature_names = np.array(feature_names)[indices]

    max_num_features = min(max_num_features, len(importances))

    ax.set_title(title, fontsize=title_fontsize)

    if std is not None:
        ax.bar(range(max_num_features),
               importances[indices][:max_num_features], color='r',
               yerr=std[indices][:max_num_features], align='center')
    else:
        ax.bar(range(max_num_features),
               importances[indices][:max_num_features],
               color='r', align='center')

    ax.set_xticks(range(max_num_features))
    ax.set_xticklabels(feature_names[:max_num_features],
                       rotation=x_tick_rotation)
    ax.set_xlim([-1, max_num_features])
    ax.tick_params(labelsize=text_fontsize)
    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.estimators.plot_learning_curve instead.')
def plot_learning_curve(clf, X, y, title='Learning Curve', cv=None,
                        train_sizes=None, n_jobs=1, scoring=None,
                        ax=None, figsize=None, title_fontsize="large",
                        text_fontsize="medium"):
    """Generates a plot of the train and test learning curves for a classifier.

    The mean score at each training size is drawn as a line, with a shaded
    band of one standard deviation across the cross-validation folds.

    Args:
        clf: Classifier instance that implements ``fit`` and ``predict``
            methods.

        X (array-like, shape (n_samples, n_features)):
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.

        y (array-like, shape (n_samples) or (n_samples, n_features)):
            Target relative to X for classification or regression;
            None for unsupervised learning.

        title (string, optional): Title of the generated plot. Defaults to
            "Learning Curve"

        cv (int, cross-validation generator, iterable, optional): Determines
            the cross-validation strategy to be used for splitting. It is
            passed unchanged to :func:`sklearn.model_selection.learning_curve`.

            Possible inputs for cv are:
              - None, to use scikit-learn's default cross-validation (3-fold
                before scikit-learn 0.22, 5-fold from 0.22 on),
              - integer, to specify the number of folds.
              - An object to be used as a cross-validation generator.
              - An iterable yielding train/test splits.

            For integer/None inputs, if ``y`` is binary or multiclass,
            :class:`StratifiedKFold` used. If the estimator is not a classifier
            or if ``y`` is neither binary nor multiclass, :class:`KFold` is
            used.

        train_sizes (iterable, optional): Determines the training sizes used to
            plot the learning curve. If None, ``np.linspace(.1, 1.0, 5)`` is
            used.

        n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
            1.

        scoring (string, callable or None, optional): default: None
            A string (see scikit-learn model evaluation documentation) or a
            scorer callable object / function with signature
            scorer(estimator, X, y).

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> rf = RandomForestClassifier()
        >>> skplt.plot_learning_curve(rf, X, y)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_learning_curve.png
           :align: center
           :alt: Learning Curve
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if train_sizes is None:
        train_sizes = np.linspace(.1, 1.0, 5)

    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlabel("Training examples", fontsize=text_fontsize)
    ax.set_ylabel("Score", fontsize=text_fontsize)
    # learning_curve returns the absolute training sizes actually used, which
    # replace the (possibly fractional) requested sizes for plotting.
    train_sizes, train_scores, test_scores = learning_curve(
        clf, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes,
        scoring=scoring)
    # Aggregate per-fold scores (axis 1) at each training size.
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)
    ax.grid()
    ax.fill_between(train_sizes, train_scores_mean - train_scores_std,
                    train_scores_mean + train_scores_std, alpha=0.1, color="r")
    ax.fill_between(train_sizes, test_scores_mean - test_scores_std,
                    test_scores_mean + test_scores_std, alpha=0.1, color="g")
    ax.plot(train_sizes, train_scores_mean, 'o-', color="r",
            label="Training score")
    ax.plot(train_sizes, test_scores_mean, 'o-', color="g",
            label="Cross-validation score")
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc="best", fontsize=text_fontsize)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.metrics.plot_silhouette instead.')
def plot_silhouette(clf, X, title='Silhouette Analysis', metric='euclidean',
                    copy=True, ax=None, figsize=None, cmap='nipy_spectral',
                    title_fontsize="large", text_fontsize="medium"):
    """Plots silhouette analysis of clusters using fit_predict.

    Each cluster is drawn as a horizontal band of its sorted per-sample
    silhouette coefficients, labelled with the cluster label returned by
    ``fit_predict``. Arbitrary labels are supported, e.g. the -1 noise label
    used by DBSCAN.

    Args:
        clf: Clusterer instance that implements ``fit`` and ``fit_predict``
            methods.

        X (array-like, shape (n_samples, n_features)):
            Data to cluster, where n_samples is the number of samples and
            n_features is the number of features.

        title (string, optional): Title of the generated plot. Defaults to
            "Silhouette Analysis"

        metric (string or callable, optional): The metric to use when
            calculating distance between instances in a feature array.
            If metric is a string, it must be one of the options allowed by
            sklearn.metrics.pairwise.pairwise_distances. If X is
            the distance array itself, use "precomputed" as the metric.

        copy (boolean, optional): Determines whether ``fit`` is used on
            **clf** or on a copy of **clf**.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used for plotting the projection. View Matplotlib Colormap
            documentation for available options.
            https://matplotlib.org/users/colormaps.html

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> kmeans = KMeans(n_clusters=4, random_state=1)
        >>> skplt.plot_silhouette(kmeans, X)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_silhouette.png
           :align: center
           :alt: Silhouette Plot
    """
    if copy:
        clf = clone(clf)

    cluster_labels = np.asarray(clf.fit_predict(X))

    # Labels need not be 0..k-1 (DBSCAN uses -1 for noise and may skip
    # values), so iterate over the labels actually present rather than
    # range(n_clusters), which silently dropped or emptied clusters.
    unique_cluster_labels = np.unique(cluster_labels)
    n_clusters = len(unique_cluster_labels)

    silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)

    sample_silhouette_values = silhouette_samples(X, cluster_labels,
                                                  metric=metric)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)
    ax.set_xlim([-0.1, 1])

    # Room for every sample plus a 10-unit gap around each cluster band.
    ax.set_ylim([0, len(X) + (n_clusters + 1) * 10 + 10])

    ax.set_xlabel('Silhouette coefficient values', fontsize=text_fontsize)
    ax.set_ylabel('Cluster label', fontsize=text_fontsize)

    # ``plt.cm.get_cmap`` was removed in matplotlib 3.9; ``plt.get_cmap``
    # accepts both names and Colormap instances.
    colormap = plt.get_cmap(cmap)

    y_lower = 10

    for i, label in enumerate(unique_cluster_labels):
        ith_cluster_silhouette_values = np.sort(
            sample_silhouette_values[cluster_labels == label])

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = colormap(float(i) / n_clusters)

        ax.fill_betweenx(np.arange(y_lower, y_upper),
                         0, ith_cluster_silhouette_values,
                         facecolor=color, edgecolor=color, alpha=0.7)

        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(label),
                fontsize=text_fontsize)

        y_lower = y_upper + 10

    ax.axvline(x=silhouette_avg, color="red", linestyle="--",
               label='Silhouette score: {0:0.3f}'.format(silhouette_avg))

    ax.set_yticks([])  # Clear the y-axis labels / ticks
    ax.set_xticks(np.arange(-0.1, 1.0, 0.2))

    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='best', fontsize=text_fontsize)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.cluster.plot_elbow_curve instead.')
def plot_elbow_curve(clf, X, title='Elbow Plot', cluster_ranges=None,
                     ax=None, figsize=None, title_fontsize="large",
                     text_fontsize="medium"):
    """Plots elbow curve of different values of K for KMeans clustering.

    For every value in ``cluster_ranges`` an unfitted copy of ``clf`` is made
    with :func:`sklearn.base.clone`. The copy has its ``n_clusters``
    attribute set to that value, is fitted on ``X`` and then scored on
    ``X``. The absolute values of those scores are plotted against the
    number of clusters. ``clf`` itself is never fitted or modified.

    Args:
        clf: Clusterer instance that can be cloned with
            :func:`sklearn.base.clone`, has an ``n_clusters`` attribute and
            implements ``fit`` and ``score`` methods.

        X (array-like, shape (n_samples, n_features)):
            Data to cluster, where n_samples is the number of samples and
            n_features is the number of features.

        title (string, optional): Title of the generated plot. Defaults to
            "Elbow Plot"

        cluster_ranges (None or :obj:`list` of int, optional): Values of
            n_clusters for which to plot the (absolute) clustering score.
            The values are sorted before use. Defaults to
            ``range(1, 12, 2)``.

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        TypeError: If ``clf`` has no ``n_clusters`` attribute.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> kmeans = KMeans(random_state=1)
        >>> skplt.plot_elbow_curve(kmeans, X, cluster_ranges=range(1, 11))
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_elbow_curve.png
           :align: center
           :alt: Elbow Curve
    """
    if cluster_ranges is None:
        cluster_ranges = range(1, 12, 2)
    else:
        cluster_ranges = sorted(cluster_ranges)

    if not hasattr(clf, 'n_clusters'):
        raise TypeError('"n_clusters" attribute not in classifier. '
                        'Cannot plot elbow method.')

    scores = []
    for n_clusters in cluster_ranges:
        # Fit a fresh, unfitted copy so the caller's estimator is untouched.
        current_clf = clone(clf)
        setattr(current_clf, "n_clusters", n_clusters)
        scores.append(current_clf.fit(X).score(X))

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)
    # KMeans.score returns the negative inertia, so its magnitude is the
    # sum of squared distances that the y-axis label refers to.
    ax.plot(cluster_ranges, np.absolute(scores), 'b*-')
    ax.grid(True)
    ax.set_xlabel('Number of clusters', fontsize=text_fontsize)
    ax.set_ylabel('Sum of Squared Errors', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.decomposition.plot_pca_component_variance instead.')
def plot_pca_component_variance(clf, title='PCA Component Explained Variances',
                                target_explained_variance=0.75, ax=None,
                                figsize=None, title_fontsize="large",
                                text_fontsize="medium"):
    """Plots PCA components' explained variance ratios. (new in v0.2.2)

    Draws the cumulative explained variance ratio of the first n principal
    components against n, starting from (0 components, ratio 0). If the
    cumulative ratio reaches ``target_explained_variance``, the first point
    at which it does so is marked with a red dot and a dotted horizontal
    line, and a legend describing that point is shown.

    Args:
        clf: PCA instance that has the ``explained_variance_ratio_`` attribute.

        title (string, optional): Title of the generated plot. Defaults to
            "PCA Component Explained Variances"

        target_explained_variance (float, optional): Looks for the minimum
            number of principal components whose cumulative explained
            variance ratio is at least this value and emphasizes it on the
            plot. Defaults to 0.75

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Raises:
        TypeError: If ``clf`` has no ``explained_variance_ratio_`` attribute
            (e.g. the PCA has not been fitted).

    Example:
        >>> import scikitplot.plotters as skplt
        >>> pca = PCA(random_state=1)
        >>> pca.fit(X)
        >>> skplt.plot_pca_component_variance(pca)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_pca_component_variance.png
           :align: center
           :alt: PCA Component variances
    """
    if not hasattr(clf, 'explained_variance_ratio_'):
        raise TypeError('"clf" does not have explained_variance_ratio_ '
                        'attribute. Has the PCA been fitted?')

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    cumulative_sum_ratios = np.cumsum(clf.explained_variance_ratio_)

    # First index at which the cumulative ratio reaches the target.
    # Explained variance ratios are non-negative, so the cumulative sum is
    # sorted as searchsorted requires. idx == len(cumulative_sum_ratios)
    # means the target is never reached.
    idx = np.searchsorted(cumulative_sum_ratios, target_explained_variance)

    # Prepend 0 so the curve starts at "0 components explain 0 variance".
    ax.plot(range(len(cumulative_sum_ratios) + 1),
            np.concatenate(([0], cumulative_sum_ratios)),
            '*-')
    ax.grid(True)
    ax.set_xlabel('First n principal components', fontsize=text_fontsize)
    ax.set_ylabel('Explained variance ratio of first n components',
                  fontsize=text_fontsize)
    ax.set_ylim([-0.02, 1.02])
    if idx < len(cumulative_sum_ratios):
        ax.plot(idx+1, cumulative_sum_ratios[idx], 'ro',
                label='{0:0.3f} Explained variance ratio for '
                'first {1} components'.format(cumulative_sum_ratios[idx],
                                              idx+1),
                markersize=4, markeredgewidth=4)
        ax.axhline(cumulative_sum_ratios[idx],
                   linestyle=':', lw=3, color='black')
        # The red marker is the only labelled artist. Creating the legend
        # without it would warn "No artists with labels" and draw an empty box.
        ax.legend(loc="best", fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
|
|
|
|
|
|
@deprecated('This will be removed in v0.4.0. Please use '
            'scikitplot.decomposition.plot_pca_2d_projection instead.')
def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection', ax=None,
                           figsize=None, cmap='Spectral',
                           title_fontsize="large", text_fontsize="medium"):
    """Plots the 2-dimensional projection of PCA on a given dataset.

    ``X`` is projected with ``clf.transform`` and the first two output
    columns are scattered against each other. There is one colour per
    distinct value of ``y``, taken evenly from ``cmap``.

    Args:
        clf: Fitted PCA instance that can ``transform`` given data set into 2
            dimensions.

        X (array-like, shape (n_samples, n_features)):
            Feature set to project, where n_samples is the number of samples
            and n_features is the number of features.

        y (array-like, shape (n_samples)):
            Target relative to X for labeling. Plain lists are accepted; they
            are converted to a NumPy array before masking.

        title (string, optional): Title of the generated plot. Defaults to
            "PCA 2-D Projection"

        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
            plot the curve. If None, the plot is drawn on a new set of axes.

        figsize (2-tuple, optional): Tuple denoting figure size of the plot
            e.g. (6, 6). Defaults to ``None``.

        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
            Colormap used for plotting the projection. View Matplotlib Colormap
            documentation for available options.
            https://matplotlib.org/users/colormaps.html

        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "large".

        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
            Use e.g. "small", "medium", "large" or integer-values. Defaults to
            "medium".

    Returns:
        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.

    Example:
        >>> import scikitplot.plotters as skplt
        >>> pca = PCA(random_state=1)
        >>> pca.fit(X)
        >>> skplt.plot_pca_2d_projection(pca, X, y)
        <matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
        >>> plt.show()

        .. image:: _static/examples/plot_pca_2d_projection.png
           :align: center
           :alt: PCA 2D Projection
    """
    transformed_X = clf.transform(X)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    # Convert once: comparing a plain list to a scalar yields a single
    # False rather than an element-wise mask, which would plot nothing.
    y = np.asarray(y)
    classes = np.unique(y)

    # pyplot.get_cmap accepts a name or a Colormap instance and, unlike
    # matplotlib.cm.get_cmap, is still available in matplotlib >= 3.9.
    colors = plt.get_cmap(cmap)(np.linspace(0, 1, len(classes)))

    for label, color in zip(classes, colors):
        mask = y == label
        ax.scatter(transformed_X[mask, 0], transformed_X[mask, 1],
                   alpha=0.8, lw=2, label=label, color=color)

    ax.legend(loc='best', shadow=False, scatterpoints=1,
              fontsize=text_fontsize)
    ax.set_xlabel('First Principal Component', fontsize=text_fontsize)
    ax.set_ylabel('Second Principal Component', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
|