# Hugging Face Space demo (page status when captured: Running on CPU Upgrade).
import gradio as gr
import matplotlib

# Select the non-interactive Agg backend: figures are only rendered to images.
matplotlib.use('agg')

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn import metrics
from sklearn.cluster import AffinityPropagation
from sklearn.datasets import make_blobs
def generate_data(num_centers, num_samples):
    """Generate a 2-D blob dataset around the first ``num_centers`` fixed centers.

    The candidate centers are the four points (1, 1), (-1, -1), (1, -1) and
    (-1, 1); only the first ``num_centers`` of them are used. Blobs are drawn
    with a standard deviation of 0.5 and a fixed random seed (0).

    Args:
        num_centers: How many of the predefined centers to use.
        num_samples: Total number of points requested from ``make_blobs``.

    Returns:
        The ``(X, labels_true)`` pair produced by ``make_blobs``.
    """
    corner_centers = [[1, 1], [-1, -1], [1, -1], [-1, 1]]
    points, true_labels = make_blobs(
        n_samples=num_samples,
        centers=corner_centers[:num_centers],
        cluster_std=0.5,
        random_state=0,
    )
    return points, true_labels
def _clustering_report(X, labels_true, labels, n_clusters):
    """Return the newline-terminated text report of clustering metrics."""
    report = [
        f"Estimated number of clusters: {n_clusters}",
        f"Homogeneity: {metrics.homogeneity_score(labels_true, labels):0.3f}",
        f"Completeness: {metrics.completeness_score(labels_true, labels):0.3f}",
        f"V-measure: {metrics.v_measure_score(labels_true, labels):0.3f}",
        f"Adjusted Rand Index: {metrics.adjusted_rand_score(labels_true, labels):0.3f}",
        f"Adjusted Mutual Information: {metrics.adjusted_mutual_info_score(labels_true, labels):0.3f}",
    ]

    # silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1,
    # which is not guaranteed (e.g. when affinity propagation yields 0 or 1
    # clusters), so only compute it when it is defined.
    n_labels = len(np.unique(labels))
    if 1 < n_labels < len(X):
        silhouette = metrics.silhouette_score(X, labels, metric='sqeuclidean')
        report.append(f"Silhouette Coefficient: {silhouette:0.3f}")
    else:
        report.append("Silhouette Coefficient: undefined")

    return "\n".join(report) + "\n"


def _draw_clusters(X, labels, cluster_centers_indices):
    """Draw every estimated cluster (members, exemplar, spokes) on a new figure."""
    n_clusters = len(cluster_centers_indices)

    # A standalone Figure is not registered with pyplot, so successive
    # callback invocations do not share or accumulate global figure state.
    fig = Figure()
    ax = fig.subplots()

    # One RGBA row per cluster, evenly spaced over the viridis colormap.
    palette = plt.cm.viridis(np.linspace(0, 1, n_clusters))
    for k, color in enumerate(palette):
        members = X[labels == k]
        center = X[cluster_centers_indices[k]]
        ax.scatter(members[:, 0], members[:, 1], color=color, marker=".")
        ax.scatter(center[0], center[1], s=14, color=color, marker="o")
        # Connect each member to the exemplar of its cluster.
        for point in members:
            ax.plot([center[0], point[0]], [center[1], point[1]], color=color)

    ax.set_title(f"Estimated number of clusters: {n_clusters}")
    return fig


def create_plot(num_clusters, num_samples):
    """Cluster a synthetic dataset with affinity propagation and visualise it.

    Data comes from ``generate_data(num_clusters, num_samples)``; the model is
    ``AffinityPropagation(preference=-50, random_state=0)``.

    Args:
        num_clusters: Number of ground-truth blobs to generate.
        num_samples: Number of points to generate.

    Returns:
        A ``(fig, metrics_str)`` tuple: a matplotlib figure of the estimated
        clusters and a multi-line string of clustering quality metrics. The
        silhouette line reads "undefined" when the estimated labelling has
        fewer than two (or as many as ``n_samples``) distinct labels.
    """
    X, labels_true = generate_data(num_clusters, num_samples)

    af = AffinityPropagation(preference=-50, random_state=0).fit(X)
    cluster_centers_indices = af.cluster_centers_indices_
    labels = af.labels_

    metrics_str = _clustering_report(
        X, labels_true, labels, len(cluster_centers_indices)
    )
    fig = _draw_clusters(X, labels, cluster_centers_indices)
    return fig, metrics_str
# Gradio UI: two sliders drive create_plot, which fills a plot and a text box.
title = "Affinity propagation clustering algorithm"
description = "This demo plots clusters of a synthetic 2D dataset that contains up to 4 clusters using the affinity propagation algorithm."

with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    num_clusters = gr.Slider(minimum=2, maximum=4, step=1, value=2, label="Number of clusters")
    num_samples = gr.Slider(minimum=100, maximum=300, step=100, value=200, label="Number of samples")

    with gr.Row():
        plot = gr.Plot()
        text_box = gr.Textbox(label="Results")

    # Both sliders feed the same callback with the same inputs and outputs.
    for slider in (num_clusters, num_samples):
        slider.change(fn=create_plot, inputs=[num_clusters, num_samples], outputs=[plot, text_box])

    # Render once with the default slider values so the page does not start
    # with an empty plot.
    demo.load(fn=create_plot, inputs=[num_clusters, num_samples], outputs=[plot, text_box])

# launch(enable_queue=True) was deprecated in Gradio 3.x and is rejected by
# newer releases; queueing is enabled through Blocks.queue() instead.
demo.queue().launch()