EM算法的定义及应用范围
EM算法,即最大期望算法(Expectation Maximization Algorithm,又译期望最大化算法),是一种迭代算法,用于含有隐变量(latent variable)的概率参数模型的最大似然估计或极大后验概率估计。
EM算法应用于高斯混合模型(GMM)、聚类、隐式马尔科夫算法(HMM)、基于概率的PLSA模型等等。
问题引出
问题一:
现在一个班里有50个男生,50个女生。我们假定男生的身高服从正态分布
,女生的身高则服从另一个正态分布
。这时候我们可以用极大似然法(MLE),分别通过这50个男生和50个女生的样本来估计这两个正态分布的参数。
问题二:
但现在我们让情况复杂一点,就是这50个男生和50个女生混在一起了。我们拥有100个人的身高数据,但是我们不知道抽取的那100个人里面的每一个人到底是从男生的那个身高分布里面抽取的,还是女生的那个身高分布抽取的。 用数学的语言就是,抽取得到的每个样本都不知道是从哪个分布抽取的。 这个时候,对于每一个样本,就有两个东西需要猜测或者估计:
(1)这个人是男的还是女的?
(2)男生和女生对应的身高的高斯分布的参数是多少?
EM算法要解决的问题是: (1)求出每一个样本属于哪个分布
(2)求出每一个分布对应的参数
EM算法步骤
1.初始化参数:先初始化男生身高的正态分布的参数:如均值=1.7,方差=0.1
2.计算每一个人更可能属于男生分布或者女生分布;
3.通过分为男生的n个人来重新估计男生身高分布的参数(最大似然估计),女生分布也按照相同的方式估计出来,更新分布。
4.这时候两个分布的概率也变了,然后重复步骤(1)至(3),直到参数不发生变化为止。
总结:其实EM算法就是先通过假设的参数把数据进行分类,然后通过分类的数据计算参数,接着对比计算的参数和假设的参数是否满足精度,不满足就返回去,满足就结束。
图像分割案例
原图及效果图:
opencv源码:
#include <opencv2/opencv.hpp>
#include <iostream>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main(int argc, char** argv) {
Mat src = imread("1.jpg");
if (src.empty()) {
printf("could not load iamge...\n");
return -1;
}
char* inputWinTitle = "input image";
namedWindow(inputWinTitle, CV_WINDOW_AUTOSIZE);
imshow(inputWinTitle, src);
// 初始化
int numCluster = 3;
const Scalar colors[] = {
Scalar(255, 0, 0),
Scalar(0, 255, 0),
Scalar(0, 0, 255),
Scalar(255, 255, 0)
};
int width = src.cols;
int height = src.rows;
int dims = src.channels();
int nsamples = width*height;
Mat points(nsamples, dims, CV_64FC1);
Mat labels;
Mat result = Mat::zeros(src.size(), CV_8UC3);
// 图像RGB像素数据转换为样本数据
int index = 0;
for (int row = 0; row < height; row++) {
for (int col = 0; col < width; col++) {
index = row*width + col;
Vec3b rgb = src.at<Vec3b>(row, col);
points.at<double>(index, 0) = static_cast<int>(rgb[0]);
points.at<double>(index, 1) = static_cast<int>(rgb[1]);
points.at<double>(index, 2) = static_cast<int>(rgb[2]);
}
}
// EM Cluster Train
Ptr<EM> em_model = EM::create();
em_model->setClustersNumber(numCluster);
em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));
em_model->trainEM(points, noArray(), labels, noArray());
// 对每个像素标记颜色与显示
Mat sample(1, dims, CV_64FC1);//
double time = getTickCount();
int r = 0, g = 0, b = 0;
for (int row = 0; row < height; row++) {
for (int col = 0; col < width; col++) {
index = row*width + col;
b = src.at<Vec3b>(row, col)[0];
g = src.at<Vec3b>(row, col)[1];
r = src.at<Vec3b>(row, col)[2];
sample.at<double>(0, 0) = static_cast<double>(b);
sample.at<double>(0, 1) = static_cast<double>(g);
sample.at<double>(0, 2) = static_cast<double>(r);
int response = cvRound(em_model->predict2(sample, noArray())[1]);
Scalar c = colors[response];
result.at<Vec3b>(row, col)[0] = c[0];
result.at<Vec3b>(row, col)[1] = c[1];
result.at<Vec3b>(row, col)[2] = c[2];
}
}
printf("execution time(ms) : %.2f\n", (getTickCount() - time) / getTickFrequency() * 1000);
imshow("EM-Segmentation", result);
waitKey(0);
return 0;
}