本篇文章概览:
- 什么是高斯混合模型
- 什么是EM算法
- 如何利用EM算法推导GMM
- 使用Python实现GMM
一. 高斯混合模型(GMM)
高斯混合模型, 英文为Gaussian Mixture Model, 简称GMM, 是一种聚类算法. 它和K-means算法很像, 只不过GMM得到的结果是对概率密度的估计, 是一种软聚类. 那么究竟什么是高斯混合模型呢? 其实顾名思义, 其就假设数据是由多个服从高斯分布的数据混合而成的. 这里究竟有几个高斯分布不能确定, 就像K-means算法里的k值一样, 是一种超参数, 更多时候需要领域知识来决定. 模型中的每一个高斯分布被称为component, 即组分. 每一个组分的概率密度线性叠加就组成了GMM的概率密度函数:
根据上面的式子,如果我们要从 GMM 的分布中随机地取一个点的话,实际上可以分为两步:首先随机地在这 K 个 Component 之中选一个,每个 Component 被选中的概率实际上就是它的系数 ,选中了 Component 之后,再单独地考虑从这个 Component 的分布中选取一个点就可以了──这里已经回到了普通的 Gaussian 分布,转化为了已知的问题。[1]
给定一批数据, 我们想用GMM来对这批数据进行聚类. 具体如何做呢? 很简单, 我们只要通过这批数据来得到GMM的概率密度函数即可. 本质上就是通过数据来计算等参数.其中, 通过数据来推算概率密度被称作density estimation, 而估算参数被称作parameter estimation.
如何估计这些参数? 这就回到我们熟悉的最大似然估计了. 根据概率密度函数, 很容易写出对应的log似然函数:
对于上式, 我们无法像普通的log似然函数那样通过求导来求出最大值. 具体地, 我们通过如下步骤来解决这个问题.
- 估计当前模型下第i个观测数据来自第k个分模型的概率, 称为分模型k对观测数据的响应度.
此时, 假设均已知(随机初始值). - 利用第一步的估计每个组分的参数. 直观理解, 可以将看作这个值其中有这部分是由 组分所生成的, 即组份在生成数据时所做的贡献.

- 不断迭代上面两步, 知道收敛为止. 上面这三步其实就是GMM的核心了, 至此我们应该可以轻松的实现GMM的代码了. 不过先不着急, 上面的步骤只是直观地展示了GMM求解的步骤, 那么这些步骤是怎么来的呢? 有没有严格的数学证明? 下面我们就来看看什么是EM算法.
二. EM算法
假定有训练集
包含m个独立样本,希望从中找到该组数据的模型p(x,z)的参数.
常规操作, 对数似然函数为:
z是隐随机变量,不方便直接找到参数估计。 策略:计算l(θ)下界,求该下界的最大值; 重复该过程,直到收敛到局部最大值。
令是的某一个分布,, 有:
其中最后一步利用了Jensen不等式. 当且仅当时, 等号成立, 即:
![]()
这样我们就推出了, 解决了如何选择的问题, 这就是E步, 有了, 就有了的下界. 在M步中, 我们极大化这个下界. 一般的EM算法的步骤如下:
这里值得注意的是, 当我们把似然函数看成是关于和的函数时, 其实我们上面的迭代步骤就是关于和的坐标上升.
接下来, 我们来利用EM算法推导GMM.
三. 利用EM算法推导GMM
随机变量是有个高斯分布混合而成,取各个高斯分布的概率为,第个高斯分布的均值为,方差为。若观测到随机变量的一系列样本试估计参数 。
E-step:
M-step:
将多项分布和高斯分布的参数带入:
对求偏导:
令上式=0:
同理对求偏导并令结果为0可得:
上面就解决了高斯分布中的参数. 下面看多项分布中的参数.
考察M-step的目标函数,对于,删除常数项:
得到:
由于多项分布的概率和为1,建立拉格朗日方程
对 求偏导:
令上式等于0:
带回式中可得:
这样, 我们通过EM算法一步步推导得到了第一节中的结论.
到这里, 我们就掌握了GMM和EM算法. 这里还需注意的是, EM算法是一种通用的算法, 常常用于解决含有因变量的参数估计问题. 它不仅可以用在这里的GMM, 在HMM和LDA(Latent Dirichlet Allocation)中, 我们还会看到EM的身影.
最后, 附上Python实现GMM的代码.
四. Python实现GMM
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from sklearn.utils import shuffle
class GMM():
def __init__(self, n_components=2, max_iter=100):
self.n_comp = 2
self.max_iter = max_iter
self.weights_ = []
self.means_ = []
self.covariances_ = []
def fit(self, X):
m, n = X.shape
means = [np.random.standard_normal(n) for i in range(self.n_comp)]
sigmas = [np.identity(n) for i in range(self.n_comp)]
pis = [1/self.n_comp for i in range(self.n_comp)]
# EM
for i in range(self.max_iter):
# E Step
predict_gausses = [multivariate_normal(mean, sigma) for mean, sigma in zip(means, sigmas)]
gauss_sum = 0
for pi, predict_gauss in zip(pis, predict_gausses):
gauss_sum += pi * predict_gauss.pdf(X)
gammas = [pi * predict_gauss.pdf(X) / gauss_sum for pi, predict_gauss in zip(pis, predict_gausses)]
# M Step
means = [np.dot(gamma, X) / np.sum(gamma) for gamma in gammas]
sigmas = [np.dot(gamma * (X - mean).T, X - mean) / np.sum(gamma) for gamma, mean in zip(gammas, means)]
pis = [np.sum(gamma) / m for gamma in gammas]
self.weights_ = pis
self.covariances_ = sigmas
self.means_ = means
return self
if __name__ == '__main__':
mean1, sigma1 = [0, 0], [[1, 0], [0, 1]]
mean2, sigma2 = [2, 4], [[3, 0], [0, 1]]
# mean1, sigma1 = [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
# mean2, sigma2 = [2, 4, 1], [[3, 0, 0], [0, 1, 0], [0, 0, 2]]
np.random.seed(8827)
X1 = np.random.multivariate_normal(mean1, sigma1, 500)
X2 = np.random.multivariate_normal(mean2, sigma2, 300)
y = np.array([1]*500 + [0]*300)
X = np.vstack([X1, X2])
X, y = shuffle(X, y)
gmm = GMM(n_components=2).fit(X)
weight1, weight2 = gmm.weights_
predict_mean1, predict_mean2 = gmm.means_
predict_sigma1, predict_sigma2 = gmm.covariances_
predict_gauss1 = multivariate_normal(predict_mean1, predict_sigma1)
predict_gauss2 = multivariate_normal(predict_mean2, predict_sigma2)
predict_y1 = predict_gauss1.pdf(X)
predict_y2 = predict_gauss2.pdf(X)
predict1 = (predict_y1 > predict_y2).astype(int)
predict2 = (predict_y1 < predict_y2).astype(int)
acc1, acc2 = np.mean(predict1 == y), np.mean(predict2 == y)
print('accuracy: {}'.format(acc1 if acc1 > acc2 else acc2))
fig = plt.figure(figsize=(12, 6))
ax = fig.add_subplot(121)
ax.set_title('True')
ax.scatter(X[y==1, 0], X[y==1, 1], c='r', s=10)
ax.scatter(X[y==0, 0], X[y==0, 1], c='b', s=10)
ax = fig.add_subplot(122)
ax.set_title('Predict')
ax.scatter(X[predict1==1, 0], X[predict1==1, 1], c='r', s=10)
ax.scatter(X[predict1==0, 0], X[predict1==0, 1], c='b', s=10)
plt.show()
Output:
accuracy: 0.9825
![]()
参考文献: