Add averaging option to AMI and NMI #11124
Merged
15 commits
80844ae
Add averaging option to AMI and NMI
aryamccarthy 4794486
Flake8 fixes
aryamccarthy 6279c25
Incorporate tests of means for AMI and NMI
aryamccarthy ed500d6
Add note about `average_method` in NMI
aryamccarthy 5ed8527
Update docs from AMI, NMI changes (#1)
aryamccarthy df60d46
Update documentation and remove nose tests (#2)
aryamccarthy b449cb9
Fix multiple spaces after operator
aryamccarthy 1b36da5
Rename all arguments
aryamccarthy 3d8bf2c
No more arbitrary values!
aryamccarthy 2854014
Improve handling of floating-point imprecision
aryamccarthy 059bae6
Clearly state when the change occurs
aryamccarthy e8b9579
Update AMI/NMI docs
aryamccarthy afe9776
Merge branch 'master' into pr/11124
amueller c65d2b3
Update v0.20.rst
amueller a5b3c0f
Catch FutureWarnings in AMI and NMI
aryamccarthy
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,11 +11,13 @@ | |
# Thierry Guillemot <[email protected]> | ||
# Gregory Stupp <[email protected]> | ||
# Joel Nothman <[email protected]> | ||
# Arya McCarthy <[email protected]> | ||
# License: BSD 3 clause | ||
|
||
from __future__ import division | ||
|
||
from math import log | ||
import warnings | ||
|
||
import numpy as np | ||
from scipy import sparse as sp | ||
|
@@ -59,6 +61,21 @@ def check_clusterings(labels_true, labels_pred): | |
return labels_true, labels_pred | ||
|
||
|
||
def _generalized_average(U, V, average_method): | ||
    """Return a particular mean of two numbers.""" | ||
    if average_method == "min": | ||
        return min(U, V) | ||
    elif average_method == "geometric": | ||
        return np.sqrt(U * V) | ||
    elif average_method == "arithmetic": | ||
        return np.mean([U, V]) | ||
    elif average_method == "max": | ||
        return max(U, V) | ||
    else: | ||
        raise ValueError("'average_method' must be 'min', 'geometric', " | ||
                         "'arithmetic', or 'max'") | ||
|
||
|
||
def contingency_matrix(labels_true, labels_pred, eps=None, sparse=False): | ||
"""Build a contingency matrix describing the relationship between labels. | ||
|
||
|
@@ -245,7 +262,9 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred): | |
|
||
V-Measure is furthermore symmetric: swapping ``labels_true`` and | ||
``labels_pred`` will give the same score. This does not hold for | ||
homogeneity and completeness. | ||
homogeneity and completeness. V-Measure is identical to | ||
:func:`normalized_mutual_info_score` with the arithmetic averaging | ||
method. | ||
|
||
Read more in the :ref:`User Guide <homogeneity_completeness>`. | ||
|
||
|
@@ -444,7 +463,8 @@ def completeness_score(labels_true, labels_pred): | |
def v_measure_score(labels_true, labels_pred): | ||
"""V-measure cluster labeling given a ground truth. | ||
|
||
This score is identical to :func:`normalized_mutual_info_score`. | ||
This score is identical to :func:`normalized_mutual_info_score` with | ||
the ``'arithmetic'`` option for averaging. | ||
|
||
The V-measure is the harmonic mean between homogeneity and completeness:: | ||
|
||
|
@@ -459,6 +479,7 @@ def v_measure_score(labels_true, labels_pred): | |
measure the agreement of two independent label assignments strategies | ||
on the same dataset when the real ground truth is not known. | ||
|
||
|
||
Read more in the :ref:`User Guide <homogeneity_completeness>`. | ||
|
||
Parameters | ||
|
@@ -485,6 +506,7 @@ def v_measure_score(labels_true, labels_pred): | |
-------- | ||
homogeneity_score | ||
completeness_score | ||
normalized_mutual_info_score | ||
|
||
Examples | ||
-------- | ||
|
@@ -617,7 +639,8 @@ def mutual_info_score(labels_true, labels_pred, contingency=None): | |
return mi.sum() | ||
|
||
|
||
def adjusted_mutual_info_score(labels_true, labels_pred): | ||
def adjusted_mutual_info_score(labels_true, labels_pred, | ||
average_method='warn'): | ||
"""Adjusted Mutual Information between two clusterings. | ||
|
||
Adjusted Mutual Information (AMI) is an adjustment of the Mutual | ||
|
@@ -626,7 +649,7 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
clusters, regardless of whether there is actually more information shared. | ||
For two clusterings :math:`U` and :math:`V`, the AMI is given as:: | ||
|
||
AMI(U, V) = [MI(U, V) - E(MI(U, V))] / [max(H(U), H(V)) - E(MI(U, V))] | ||
AMI(U, V) = [MI(U, V) - E(MI(U, V))] / [avg(H(U), H(V)) - E(MI(U, V))] | ||
|
||
This metric is independent of the absolute values of the labels: | ||
a permutation of the class or cluster label values won't change the | ||
|
@@ -650,9 +673,17 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
A clustering of the data into disjoint subsets. | ||
|
||
average_method : string, optional (default: 'warn') | ||
How to compute the normalizer in the denominator. Possible options | ||
are 'min', 'geometric', 'arithmetic', and 'max'. | ||
If 'warn', 'max' will be used. The default will change to | ||
'arithmetic' in version 0.22. | ||
|
||
.. versionadded:: 0.20 | ||
|
||
Returns | ||
------- | ||
ami: float(upperlimited by 1.0) | ||
ami: float (upperlimited by 1.0) | ||
The AMI returns a value of 1 when the two partitions are identical | ||
(i.e. perfectly matched). Random partitions (independent labellings) have | ||
an expected AMI around 0 on average hence can be negative. | ||
|
@@ -691,6 +722,12 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
<https://en.wikipedia.org/wiki/Adjusted_Mutual_Information>`_ | ||
|
||
""" | ||
    if average_method == 'warn': | ||
        warnings.warn("The behavior of AMI will change in version 0.22. " | ||
                      "To match the behavior of 'v_measure_score', AMI will " | ||
                      "use average_method='arithmetic' by default.", | ||
                      FutureWarning) | ||
        average_method = 'max' | ||
    labels_true, labels_pred = check_clusterings(labels_true, labels_pred) | ||
    n_samples = labels_true.shape[0] | ||
    classes = np.unique(labels_true) | ||
|
@@ -709,17 +746,29 @@ def adjusted_mutual_info_score(labels_true, labels_pred): | |
    emi = expected_mutual_information(contingency, n_samples) | ||
    # Calculate entropy for each labeling | ||
    h_true, h_pred = entropy(labels_true), entropy(labels_pred) | ||
    ami = (mi - emi) / (max(h_true, h_pred) - emi) | ||
    normalizer = _generalized_average(h_true, h_pred, average_method) | ||
    denominator = normalizer - emi | ||
    # Avoid 0.0 / 0.0 when expectation equals maximum, i.e. a perfect match. | ||
    # normalizer should always be >= emi, but because of floating-point | ||
    # representation, sometimes emi is slightly larger. Correct this | ||
    # by preserving the sign. | ||
    if denominator < 0: | ||
        denominator = min(denominator, -np.finfo('float64').eps) | ||
    else: | ||
        denominator = max(denominator, np.finfo('float64').eps) | ||
    ami = (mi - emi) / denominator | ||
    return ami | ||
|
||
|
||
def normalized_mutual_info_score(labels_true, labels_pred): | ||
def normalized_mutual_info_score(labels_true, labels_pred, | ||
average_method='warn'): | ||
"""Normalized Mutual Information between two clusterings. | ||
|
||
Normalized Mutual Information (NMI) is a normalization of the Mutual | ||
Information (MI) score to scale the results between 0 (no mutual | ||
information) and 1 (perfect correlation). In this function, mutual | ||
information is normalized by ``sqrt(H(labels_true) * H(labels_pred))``. | ||
information is normalized by some generalized mean of ``H(labels_true)`` | ||
and ``H(labels_pred)``, defined by the `average_method`. | ||
|
||
This measure is not adjusted for chance. Therefore | ||
:func:`adjusted_mutual_info_score` might be preferred. | ||
|
@@ -743,13 +792,22 @@ def normalized_mutual_info_score(labels_true, labels_pred): | |
labels_pred : array, shape = [n_samples] | ||
A clustering of the data into disjoint subsets. | ||
|
||
average_method : string, optional (default: 'warn') | ||
How to compute the normalizer in the denominator. Possible options | ||
are 'min', 'geometric', 'arithmetic', and 'max'. | ||
If 'warn', 'geometric' will be used. The default will change to | ||
'arithmetic' in version 0.22. | ||
|
||
.. versionadded:: 0.20 | ||
|
||
Returns | ||
------- | ||
nmi : float | ||
score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling | ||
|
||
See also | ||
-------- | ||
v_measure_score: V-Measure (NMI with arithmetic mean option.) | ||
adjusted_rand_score: Adjusted Rand Index | ||
adjusted_mutual_info_score: Adjusted Mutual Information (adjusted | ||
against chance) | ||
|
@@ -773,6 +831,12 @@ def normalized_mutual_info_score(labels_true, labels_pred): | |
0.0 | ||
|
||
""" | ||
    if average_method == 'warn': | ||
        warnings.warn("The behavior of NMI will change in version 0.22. " | ||
                      "To match the behavior of 'v_measure_score', NMI will " | ||
                      "use average_method='arithmetic' by default.", | ||
                      FutureWarning) | ||
        average_method = 'geometric' | ||
    labels_true, labels_pred = check_clusterings(labels_true, labels_pred) | ||
    classes = np.unique(labels_true) | ||
    clusters = np.unique(labels_pred) | ||
|
@@ -789,7 +853,10 @@ def normalized_mutual_info_score(labels_true, labels_pred): | |
    # Calculate the expected value for the mutual information | ||
    # Calculate entropy for each labeling | ||
    h_true, h_pred = entropy(labels_true), entropy(labels_pred) | ||
    nmi = mi / max(np.sqrt(h_true * h_pred), 1e-10) | ||
    normalizer = _generalized_average(h_true, h_pred, average_method) | ||
    # Avoid 0.0 / 0.0 when either entropy is zero. | ||
    normalizer = max(normalizer, np.finfo('float64').eps) | ||
    nmi = mi / normalizer | ||
    return nmi | ||
|
||
|
||
|
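The sign-preserving guard on the AMI denominator in the diff above can be tried in isolation. This is a minimal sketch, not the code from the PR: `clamp_denominator` is a hypothetical helper name, and `sys.float_info.epsilon` stands in for `np.finfo('float64').eps` (the two hold the same value) so that the snippet needs only the standard library.

```python
import sys

# Machine epsilon for a double; same value as np.finfo('float64').eps.
EPS = sys.float_info.epsilon


def clamp_denominator(denominator, eps=EPS):
    """Move a near-zero denominator at least `eps` away from zero.

    Hypothetical helper mirroring the branch in the diff: the sign of the
    input is kept, so a slightly negative value stays negative.
    """
    if denominator < 0:
        return min(denominator, -eps)
    return max(denominator, eps)


# normalizer == emi would give 0.0 / 0.0 without the guard.
print(clamp_denominator(0.0))
# emi slightly larger than normalizer because of rounding: sign is preserved.
print(clamp_denominator(-1e-20))
# Ordinary denominators pass through unchanged.
print(clamp_denominator(0.25))
```

Clamping towards `-eps` for negative inputs, rather than taking `max(denominator, eps)` unconditionally, means a tiny negative denominator is not silently flipped to a positive one.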
The fact that mean is configurable and varies in the literature should be discussed here, perhaps with some notes on when one is more appropriate than another
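One way to make the effect of the configurable mean concrete is a small, self-contained sketch. It uses only the standard library rather than scikit-learn's own `entropy` and `mutual_info_score`, and the names `entropy`, `mutual_info`, `generalized_average`, and `nmi` below are illustrative re-implementations, not the functions touched by this PR. Because min <= geometric <= arithmetic <= max for any two non-negative numbers, the NMI scores come out in the reverse order, and the arithmetic option reproduces the V-measure (the harmonic mean of homogeneity and completeness).

```python
import sys
from collections import Counter
from math import log, sqrt


def entropy(labels):
    """Shannon entropy (natural log) of a labeling."""
    n = len(labels)
    return -sum((c / n) * log(c / n) for c in Counter(labels).values())


def mutual_info(labels_true, labels_pred):
    """Mutual information between two labelings of the same samples."""
    n = len(labels_true)
    joint = Counter(zip(labels_true, labels_pred))
    count_true = Counter(labels_true)
    count_pred = Counter(labels_pred)
    return sum(
        (n_ij / n) * log(n * n_ij / (count_true[i] * count_pred[j]))
        for (i, j), n_ij in joint.items()
    )


def generalized_average(u, v, average_method):
    """Return the requested mean of two numbers (same options as the PR)."""
    if average_method == "min":
        return min(u, v)
    if average_method == "geometric":
        return sqrt(u * v)
    if average_method == "arithmetic":
        return (u + v) / 2
    if average_method == "max":
        return max(u, v)
    raise ValueError("'average_method' must be 'min', 'geometric', "
                     "'arithmetic', or 'max'")


def nmi(labels_true, labels_pred, average_method):
    """Mutual information divided by the chosen mean of the two entropies."""
    normalizer = generalized_average(
        entropy(labels_true), entropy(labels_pred), average_method)
    # Guard against division by zero when an entropy is zero.
    return mutual_info(labels_true, labels_pred) / max(
        normalizer, sys.float_info.epsilon)


# Example labelings chosen for illustration only.
labels_true = [0, 0, 0, 1, 1, 1]
labels_pred = [0, 0, 1, 1, 2, 2]

for method in ("min", "geometric", "arithmetic", "max"):
    print(method, nmi(labels_true, labels_pred, method))

# V-measure computed directly as the harmonic mean of homogeneity and
# completeness, for comparison with the 'arithmetic' option.
mi = mutual_info(labels_true, labels_pred)
homogeneity = mi / entropy(labels_true)
completeness = mi / entropy(labels_pred)
v_measure = 2 * homogeneity * completeness / (homogeneity + completeness)
print("v_measure", v_measure)
```

The smaller the normalizer, the larger the reported score, so values computed with different `average_method` settings are not directly comparable across papers or library versions.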