Categories

Calendar

January 2022
M T W T F S S
« Jun    
 12
3456789
10111213141516
17181920212223
24252627282930
31  

漫谈 Clustering (1): k-means

cluster_logo本文是“漫谈 Clustering 系列”中的第 1 篇,参见本系列的其他文章。

好久没有写 blog 了,一来是 blog 下线一段时间,而租 DreamHost 的事情又一直没弄好;二来是没有太多时间,天天都跑去实验室。现在主要折腾 Machine Learning 相关的东西,因为很多东西都不懂,所以平时也找一些资料来看。按照我以前的更新速度的话,这么长时间不写 blog 肯定是要被闷坏的,所以我也觉得还是不定期地整理一下自己了解到的东西,放在 blog 上,一来梳理总是有助于加深理解的,二来也算共享一下知识了。那么,还是从 clustering 说起吧。

Clustering 中文翻译作“聚类”,简单地说就是把相似的东西分到一组,同 Classification (分类)不同,对于一个 classifier ,通常需要你告诉它“这个东西被分为某某类”这样一些例子,理想情况下,一个 classifier 会从它得到的训练集中进行“学习”,从而具备对未知数据进行分类的能力,这种提供训练数据的过程通常叫做 supervised learning (监督学习),而在聚类的时候,我们并不关心某一类是什么,我们需要实现的目标只是把相似的东西聚到一起,因此,一个聚类算法通常只需要知道如何计算相似 度就可以开始工作了,因此 clustering 通常并不需要使用训练数据进行学习,这在 Machine Learning 中被称作 unsupervised learning (无监督学习)。

举一个简单的例子:现在有一群小学生,你要把他们分成几组,让组内的成员之间尽量相似一些,而组之间则差别大一些。最后分出怎样的结果,就取决于你对于“相似”的定义了,比如,你决定男生和男生是相似的,女生和女生也是相似的,而男生和女生之间则差别很大”,这样,你实际上是用一个可能取两个值“男”和“女”的离散变量来代表了原来的一个小学生,我们通常把这样的变量叫做“特征”。实际上,在这种情况下,所有的小学生都被映射到了两个点的其中一个上,已经很自然地形成了两个组,不需要专门再做聚类了。另一种可能是使用“身高”这个特征。我在读小学候,每周五在操场开会训话的时候会按照大家住的地方的地域和距离远近来列队,这样结束之后就可以结队回家了。除了让事物映射到一个单独的特征之外,一种常见的做法是同时提取 N 种特征,将它们放在一起组成一个 N 维向量,从而得到一个从原始数据集合到 N 维向量空间的映射——你总是需要显式地或者隐式地完成这样一个过程,因为许多机器学习的算法都需要工作在一个向量空间中。

那么让我们再回到 clustering 的问题上,暂且抛开原始数据是什么形式,假设我们已经将其映射到了一个欧几里德空间上,为了方便展示,就使用二维空间吧,如下图所示:

cluster

从数据点的大致形状可以看出它们大致聚为三个 cluster ,其中两个紧凑一些,剩下那个松散一些。我们的目的是为这些数据分组,以便能区分出属于不同的簇的数据,如果按照分组给它们标上不同的颜色,就是这个样子:

cluster

那么计算机要如何来完成这个任务呢?当然,计算机还没有高级到能够“通过形状大致看出来”,不过,对于这样的 N 维欧氏空间中的点进行聚类,有一个非常简单的经典算法,也就是本文标题中提到的 k-means 。在介绍 k-means 的具体步骤之前,让我们先来看看它对于需要进行聚类的数据的一个基本假设吧:对于每一个 cluster ,我们可以选出一个中心点 (center) ,使得该 cluster 中的所有的点到该中心点的距离小于到其他 cluster 的中心的距离。虽然实际情况中得到的数据并不能保证总是满足这样的约束,但这通常已经是我们所能达到的最好的结果,而那些误差通常是固有存在的或者问题本身的不可分性造成的。例如下图所示的两个高斯分布,从两个分布中随机地抽取一些数据点出来,混杂到一起,现在要让你将这些混杂在一起的数据点按照它们被生成的那个分布分开来:

gaussian

由于这两个分布本身有很大一部分重叠在一起了,例如,对于数据点 2.5 来说,它由两个分布产生的概率都是相等的,你所做的只能是一个猜测;稍微好一点的情况是 2 ,通常我们会将它归类为左边的那个分布,因为概率大一些,然而此时它由右边的分布生成的概率仍然是比较大的,我们仍然有不小的几率会猜错。而整个阴影部分是我们所能达到的最小的猜错的概率,这来自于问题本身的不可分性,无法避免。因此,我们将 k-means 所依赖的这个假设看作是合理的。

基于这样一个假设,我们再来导出 k-means 所要优化的目标函数:设我们一共有 N 个数据点需要分为 K 个 cluster ,k-means 要做的就是最小化

\displaystyle J = \sum_{n=1}^N\sum_{k=1}^K r_{nk} \|x_n-\mu_k\|^2

这个函数,其中 r_{nk} 在数据点 n 被归类到 cluster k 的时候为 1 ,否则为 0 。直接寻找 r_{nk} 和 \mu_k 来最小化 J 并不容易,不过我们可以采取迭代的办法:先固定 \mu_k ,选择最优的 r_{nk} ,很容易看出,只要将数据点归类到离他最近的那个中心就能保证 J 最小。下一步则固定 r_{nk},再求最优的 \mu_k。将 J 对 \mu_k 求导并令导数等于零,很容易得到 J 最小的时候 \mu_k 应该满足:

\displaystyle \mu_k=\frac{\sum_n r_{nk}x_n}{\sum_n r_{nk}}

亦即 \mu_k 的值应当是所有 cluster k 中的数据点的平均值。由于每一次迭代都是取到 J 的最小值,因此 J 只会不断地减小(或者不变),而不会增加,这保证了 k-means 最终会到达一个极小值。虽然 k-means 并不能保证总是能得到全局最优解,但是对于这样的问题,像 k-means 这种复杂度的算法,这样的结果已经是很不错的了。

下面我们来总结一下 k-means 算法的具体步骤:

  1. 选定 K 个中心 \mu_k 的初值。这个过程通常是针对具体的问题有一些启发式的选取方法,或者大多数情况下采用随机选取的办法。因为前面说过 k-means 并不能保证全局最优,而是否能收敛到全局最优解其实和初值的选取有很大的关系,所以有时候我们会多次选取初值跑 k-means ,并取其中最好的一次结果。
  2. 将每个数据点归类到离它最近的那个中心点所代表的 cluster 中。
  3. 用公式 \mu_k = \frac{1}{N_k}\sum_{j\in\text{cluster}_k}x_j 计算出每个 cluster 的新的中心点。
  4. 重复第二步,一直到迭代了最大的步数或者前后的 J 的值相差小于一个阈值为止。

按照这个步骤写一个 k-means 实现其实相当容易了,在 SciPy 或者 Matlab 中都已经包含了内置的 k-means 实现,不过为了看看 k-means 每次迭代的具体效果,我们不妨自己来实现一下,代码如下(需要安装 SciPy 和 matplotlib) :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/python
 
from __future__ import with_statement
import cPickle as pickle
from matplotlib import pyplot
from numpy import zeros, array, tile
from scipy.linalg import norm
import numpy.matlib as ml
import random
 
def kmeans(X, k, observer=None, threshold=1e-15, maxiter=300):
    N = len(X)
    labels = zeros(N, dtype=int)
    centers = array(random.sample(X, k))
    iter = 0
 
    def calc_J():
        sum = 0
        for i in xrange(N):
            sum += norm(X[i]-centers[labels[i]])
        return sum
 
    def distmat(X, Y):
        n = len(X)
        m = len(Y)
        xx = ml.sum(X*X, axis=1)
        yy = ml.sum(Y*Y, axis=1)
        xy = ml.dot(X, Y.T)
 
        return tile(xx, (m, 1)).T+tile(yy, (n, 1)) - 2*xy
 
    Jprev = calc_J()
    while True:
        # notify the observer
        if observer is not None:
            observer(iter, labels, centers)
 
        # calculate distance from x to each center
        # distance_matrix is only available in scipy newer than 0.7
        # dist = distance_matrix(X, centers)
        dist = distmat(X, centers)
        # assign x to nearst center
        labels = dist.argmin(axis=1)
        # re-calculate each center
        for j in range(k):
            idx_j = (labels == j).nonzero()
            centers[j] = X[idx_j].mean(axis=0)
 
        J = calc_J()
        iter += 1
 
        if Jprev-J < threshold:
            break
        Jprev = J
        if iter >= maxiter:
            break
 
    # final notification
    if observer is not None:
        observer(iter, labels, centers)
 
if __name__ == '__main__':
    # load previously generated points
    with open('cluster.pkl') as inf:
        samples = pickle.load(inf)
    N = 0
    for smp in samples:
        N += len(smp[0])
    X = zeros((N, 2))
    idxfrm = 0
    for i in range(len(samples)):
        idxto = idxfrm + len(samples[i][0])
        X[idxfrm:idxto, 0] = samples[i][0]
        X[idxfrm:idxto, 1] = samples[i][1]
        idxfrm = idxto
 
    def observer(iter, labels, centers):
        print "iter %d." % iter
        colors = array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        pyplot.plot(hold=False)  # clear previous plot
        pyplot.hold(True)
 
        # draw points
        data_colors=[colors[lbl] for lbl in labels]
        pyplot.scatter(X[:, 0], X[:, 1], c=data_colors, alpha=0.5)
        # draw centers
        pyplot.scatter(centers[:, 0], centers[:, 1], s=200, c=colors)
 
        pyplot.savefig('kmeans/iter_%02d.png' % iter, format='png')
 
    kmeans(X, 3, observer=observer)

代码有些长,不过因为用 Python 来做这个事情确实不如 Matlab 方便,实际的 k-means 的代码只是 41 到 47 行。首先 3 个中心点被随机初始化,所有的数据点都还没有进行聚类,默认全部都标记为红色,如下图所示:

iter_00

然后进入第一次迭代:按照初始的中心点位置为每个数据点着上颜色,这是代码中第 41 到 43 行所做的工作,然后 45 到 47 行重新计算 3 个中心点,结果如下图所示:

iter_01

可以看到,由于初始的中心点是随机选的,这样得出来的结果并不是很好,接下来是下一次迭代的结果:

iter_02

可以看到大致形状已经出来了。再经过两次迭代之后,基本上就收敛了,最终结果如下:

iter_04

不过正如前面所说的那样 k-means 也并不是万能的,虽然许多时候都能收敛到一个比较好的结果,但是也有运气不好的时候会收敛到一个让人不满意的局部最优解,例如选用下面这几个初始中心点:

iter_00_bad

最终会收敛到这样的结果:

iter_03_bad

不得不承认这并不是很好的结果。不过其实大多数情况下 k-means 给出的结果都还是很令人满意的,算是一种简单高效应用广泛的 clustering 方法。

Update 2010.04.25: 很多人都问我要 cluster.pkl ,我干脆把它上传上来吧,其实是很容易自己生成的,点击这里下载。

70 comments to 漫谈 Clustering (1): k-means

  • 如果用模拟退火类似的算法综合k-means 是不是可以走出局部极值

  • @winsty
    恩,不知道模拟退火算法这个东西,刚才查了一下,发现看起来好像很牛的样子:

    • 初值无关
    • 几乎以概率 1 收敛于全局最优解
    • 具有并行性

    不过大致看了一下它的求解过程,看起来和求解 PageRank 那里用的方法差不多,就是有一个概率能跳出局部最优,而不是死陷在那里。真的要用在这里的话,就是直接去求 J 的极值了,和 K-means 基本上无关了。

    不过 K-means 的真正目的实际上是进行聚类,也就是标 label ,求得 J 的最小值只是其中一个附加产品,就算用退火能求出全局最小的 J ,却只是得到了一个 bound 而已,要从这个 J 的数值推导出所对应的 cluster 形态还是没有办法的事情啊。

  • @pluskid
    不是啊 在最小化J的同时也能够求出对应的参数rnk吧
    这里有个伪代码
    2.1.1步骤就是按照你这篇文章里提到的那个办法,本质没有区别
    只是用SA走出局部极值,避免发生你最后一个图片那样的问题

    Simulated Annealing (SA) Algorithm:
    1 初始化:系统初温T ,初始状态S0 ,马尔可夫链长L,终止条件AIM
    2 while (true)
    2.1 对于k=1..L, 执行2.1.1到2.1.4
    2.1.1 从当前解S,产生新解SN ,他们之间的差值为D.
    2.1.2 若 (D<0 或 满足概率 exp(-D/T)),则 S:=SN.
    2.1.3 若(当前解S<当前最优解 SB),则SB:=S.
    2.1.4 if (T趋于0 或 连续AIM次迭代物最优值), 则可近似认为SB为最优,转3.
    2.2 降温
    2.3 S=SB
    3 输出 SB

  • @winsty
    哦!我大致明白了,就是每次迭代的时候实际上是有一个概率是否接受新的解了。不过还是不能直接套到 K-means 里面去,因为 K-means 每次产生新的解的方式是固定的,而不是随机的,换句话说,初值确定之后,后面会得到什么样的结果就已经定了。要用在这里的话,还要设计一个产生新解的步骤,就是对应到你那个伪代码的 2.1.1 那一步。

  • @pluskid
    嗯 是的-.-
    不过这个也应该不太困难

  • @winsty
    恩,回头好好研究下,这个好像和 Markov random walk 有关系。

    ps: blog 上的时间好像和中国时间相差了十几个小时啊,得好好设置一下……

  • @pluskid
    期待后续连载……
    花了一上午看这些文章和相关的wiki链接
    爽死了

  • @winsty
    赞一下看链接的人,难得我用心良苦呢。恩,后面的会陆续出来,不过也急不得,每写一篇都得花不少功夫呢。我也是一边学一边写啊。

  • rhythm

    看到clustering,第一反应是服务器集群,结果发现完全不是……不过也是很有趣的话题。话说这篇文章里的代码、公式和图是不是用CodeColorer、Latex for WordPress以及gnuplot做的?

  • @rhythm
    代码是用 wp-syntax 高亮的,推荐一下这个插件。公式是 LaTeXRender 这个插件吧,好像不太好找,你需要的话我可以拷贝给你。图多是用 Python 的 matplotlib 或者直接用 Matlab 画的。

  • 写得很清楚,赞一个。特别是图,下次我写考试笔记的时候就直接用了,呵呵。

  • lyslys34

    请问使用matlab如何画图以及运行此实例呢?

  • @lyslys34
    这里的代码是 Python 的,Matlab 里自带了一个 kmeans 函数可以用的。

  • heshizhu

    学习…
    k-mean 算法一般都是做baseline比较的,易理解,易实现,效率效果都不错。大部分时间就是花在计算各个对象之间的距离(相异度/相似度)上!matlab可以直接有这个函数!!!对matlab还不熟,我们这里做实验是用java,你们都是用matlab吗?

  • @heshizhu
    是的,能用 Matlab 的话就直接用了,因为很方便。

  • 花瓣雨

    你好,我关注你的文章很久了,你的关于聚类的这几篇文章我都仔细地看过,其中一些层面,我还是有些不太明了,希望能和你交流一下,以期待共同进步,可以吗?多谢!
    邮箱:sunyxrizhao@yahoo.com.cn
    qq:251562907

  • @花瓣雨
    恩,有问题大家可以互相讨论的。

  • 花瓣雨

    @pluskid
    可以告诉我你的qq号或是邮箱吗呵呵或许有点冒失了,不好意思,因为我正在研究聚类集成的问题,想请教一下你在这个方面,有什么高见吗?还请多多指教,谢谢。

  • @花瓣雨
    你好,我的邮箱是 pluskid at gmail.com ,在 about 页面可以看到。我没有 QQ 号,不好意思。 :)

  • fion_ly

    你好,请问为什么没有这个系列Hierarchical Clustering那篇文章?

  • @fion_ly
    唔,不好意思,最近一直比较忙,所以还没有写出来。 >_<

  • luiqt

    请教,抓取网上的文章,分类存储。我现在了解的有两种方式:
    1)TF/IDF + 余弦定理: 每类有个特征那个词库,计算待分类文章与特征词库余弦夹角,取夹角最小的分类
    2)Fisher Method:统计每个分类的概率,去最大者。
    第一个用了关键词在某篇文档中出现的次数,而第二个只用了关键词在多少个文档中出现,而不关心一篇文章中的词频,这是为什么呢?
    这两种方法的区别的优劣势是怎样的呢?

  • @luiqt
    你好,TF 和 IDF 就是分别代表“关键词再某篇文档中出现的次数”和“关键词在多少个文档中出现”,所以你说的第一种方法实际两个信息都用到了。

  • 昨天我把tf/idf+余弦和费舍尔方法仔细思考对比了一下,决定采用tf/idf+余弦方法来实现我的文章分类。
    之前pluskid有篇文章:“训练数据对分类器性能的影响https://blog.pluskid.org/?p=223”,在样本集中于特定一类文章,或者各类文章样本分布不均的情况下,用tf/idf+余弦更加简单,性能更好。
    因为它只需要比较样本集和待分类文章的余弦夹角就可以了。而费舍尔方法需要计算比较待分类文章关键字在各类别中出现的综合概率。

    另外,tf/idf+余弦方法在对文章分类的时候,还可以把IDF省略,减少复杂度和计算量。原始数据只需要tf词频,非常简化了。

  • […] K-means å’Œ GMM 这些聚类的方法是古代流行的算法的话,那么这次要讲的 Spectral […]

  • blackball

    你好,我想请教一下那幅两个高斯分布的插图中间的阴影部分是如何画的?3Q

  • […] K-means å’Œ GMM 这些聚类的方法是古代流行的算法的话,那么这次要讲的 Spectral […]

  • […] 上一次我们谈到了用 k-means 进行聚类的方法,这次我们来说一下另一个很流行的算法:Gaussian Mixture Model (GMM)。事实上,GMM å’Œ k-means 很像,不过 GMM 是学习出一些概率密度函数来(所以 GMM 除了用在 clustering 上之外,还经常被用于 density estimation ),简单地说,k-means 的结果是每个数据点被 assign 到其中某一个 cluster 了,而 GMM 则给出这些数据点被 assign 到每个 cluster 的概率,又称作 soft assignment 。 […]

  • POZEN

    我想请问一下:要是分类的时候我不知道要分成多少类,应该怎么办呢?

    • 这个是个很难的问题,没有什么通用的特别有效的办法啦,一般需要根据领域特定的知识来分析问题。

  • POZEN

    读了你的很多文章,觉得写得非常好,清晰明了,容易理解。期待更多好文章。

  • guobo

    看了博文收获很大
    不晓得博主对m-tree算法是不是了解。

    想和您讨论一下~

  • guobo

    如果想把扫描文件中的文字,这里我们认为只是英文单词。
    想把这些字母聚类,按照某种特性存储,不晓得是否可以用这个算法?

    • kmeans 是通用的聚类算法,如果你的问题确实是需要聚类来解决的话,是可以尝试一下的。

  • […] 上一次我们了解了一个最基本的 clustering 办法 k-means ,这次要说的 k-medoids 算法,其实从名字上就可以看出来,和 k-means 肯定是非常相似的。事实也确实如此,k-medoids 可以算是 k-means 的一个变种。 […]

  • […] 上一次我们谈到了用 k-means 进行聚类的方法,这次我们来说一下另一个很流行的算法:Gaussian Mixture Model (GMM)。事实上,GMM å’Œ k-means 很像,不过 GMM 是学习出一些概率密度函数来(所以 GMM 除了用在 clustering 上之外,还经常被用于 density estimation ),简单地说,k-means 的结果是每个数据点被 assign 到其中某一个 cluster 了,而 GMM 则给出这些数据点被 assign 到每个 cluster 的概率,又称作 soft assignment 。 […]

  • […] K-means å’Œ GMM 这些聚类的方法是古代流行的算法的话,那么这次要讲的 Spectral […]

  • michael

    你好,请问J对Uk求导时是如何进行的,Uk的维度对求导没有影响吗?

    • 你好,J 是一个(一维)函数,对向量求导,可以看作是多元函数求导,得到一个梯度。

  • meng

    Test
    内容总结的相当不错,还望再接再厉。

    PS:如果可以的话,希望能把相关的两点之间距离的求法,以及Cluster重心的公式写上来。

  • […] 参考文章二 维基百科k-means链接 泰森多边形法维基百科链接(Voronoi […]

  • LoveU

    推荐看bishop的pattern recognition and ML chapter 9

  • lsxpu

    每看一篇,Mark 一下

  • 不是的,我的意思是聚类啊,里面的这些python代码,我不能运行。

  • llxlf2012

    请问对于聚类中心点的选择有没有什么方法可以帮助我们?因为我看到的办法都是建议随机选择,谢谢

  • Friday

    请问有没有对聚类的结果作评价的函数?就是我怎么 样才知道聚类聚的好不好?

    • 如果你有真实 label 的话,比较聚类结果和真实 label 就可以了,否则不太好弄。

  • Friday

    楼主~ 知道矩阵聚类么?这个聚类网上的介绍不多啊

  • Friday

    楼主&校友……矩阵聚类有没有听过?

    • geron

      你是指matrix clustering么~之前data mining老师课上略略带过~好像是用在weg data analysis&mining里面的?

  • […] 如果说 K-means 和 GMM 这些聚类的方法是古代流行的算法的话,那么这次要讲的 Spectral Clustering 就可以算是现代流行的算法了,中文通常称为“谱聚类”。由于使用的矩阵的细微差别,谱聚类实际上可以说是一“类”算法。 […]

  • Everest1573

    楼主你的博客太强大了,膜拜之~我正在看你写的Clustering的系列,根据我的理解,你的Clustering应该是针对一些已经存在的数据,然后将之根据一定的准则进行聚类吧?
    但是如果有新的数据加入进来,需要对新的数据进行聚类的话是不是这个问题就变成了“分类”了呢?聚类的算法对“分类”的问题还适用么?
    谢谢楼主哈,我比较菜鸟~

  • austinls

    好文章!最近正在学习CS229,配合楼主的博客学习刚刚好!大谢!

  • Albert

    本人实属菜鸟一枚~~点击下载“cluster.pkl”,怎么成了cluster_sql.mht了,咋弄啊,大神

  • systolic

    作为一个python的beginner,我对cluster.pkl中的data很感兴趣。
    根据博主的文章(强文啊),尤其是聚类图,那些点应该是(x,y)的值
    Google了一点pickel的资料,
    http://blog.sina.com.cn/s/blog_4ba2c6a201012afq.html

    花了点时间,勉强写了一个python脚本,将数据用1450个x y的二元组形式写出来了。
    和博主的图基本是对上的。

    #!/usr/bin/python
    
    from __future__ import with_statement
    import cPickle as pickle
    from numpy import zeros, array, tile
    
    if __name__ == '__main__':
    	# load previously generated points
    	inf = open("cluster.pkl") 
    	samples = pickle.load(inf)
    	
    	N = 0
    	xmin = 1000
    	xmax = -1000
    	ymin = 1000
    	ymax = -1000
    	fRes = open('points.txt', 'w')
    	for smp in samples :
    		N += len(smp[0])
    
    	X= zeros( (N,2))
    	idxfrm = 0
    	for i in range(len(samples)):
    		idxto = idxfrm + len(samples[i][0])
    		X[idxfrm:idxto, 0] = samples[i][0]
    		X[idxfrm:idxto, 1] = samples[i][1]
    		idxfrm = idxto
    	for i in xrange(N):
    		if X[i][0]>xmax:
    			xmax = X[i][0]
    		if X[i][0]ymax:
    			ymax = X[i][1]
    		if X[i][1]<ymin:
    			ymin = X[i][1]
    		fRes.write( "%d" % X[i][0])
    		fRes.write( '	' )
    		fRes.write( "%d" % X[i][1])
    		fRes.write('\n')
    	
    	fRes.write("xmax = %d \n" % xmax)
    	fRes.write("xmin = %d \n" % xmin)
    	fRes.write("ymax = %d \n" % ymax)
    	fRes.write("ymin = %d \n" % ymin)
    

Leave a Reply

 

 

 

You can use these HTML tags

<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <s> <strike> <strong>