当前位置: 首页>后端>正文

EM算法解决混合模型R代码 em算法的原理

EM算法的定义及应用范围

EM算法,即最大期望算法(Expectation Maximization Algorithm,又译期望最大化算法),是一种迭代算法,用于含有隐变量(latent variable)的概率参数模型的最大似然估计或极大后验概率估计。

EM算法应用于高斯混合模型(GMM)、聚类、隐式马尔科夫算法(HMM)、基于概率的PLSA模型等等。

问题引出

问题一:

现在一个班里有50个男生,50个女生。我们假定男生的身高服从正态分布

,女生的身高则服从另一个正态分布 

。这时候我们可以用极大似然法(MLE),分别通过这50个男生和50个女生的样本来估计这两个正态分布的参数。

问题二:

但现在我们让情况复杂一点,就是这50个男生和50个女生混在一起了。我们拥有100个人的身高数据,但是我们不知道抽取的那100个人里面的每一个人到底是从男生的那个身高分布里面抽取的,还是女生的那个身高分布抽取的。 用数学的语言就是,抽取得到的每个样本都不知道是从哪个分布抽取的。 这个时候,对于每一个样本,就有两个东西需要猜测或者估计:

                                                      (1)这个人是男的还是女的?

                                                      (2)男生和女生对应的身高的高斯分布的参数是多少?

EM算法要解决的问题是:     (1)求出每一个样本属于哪个分布

                                                      (2)求出每一个分布对应的参数

EM算法步骤

   1.初始化参数:先初始化男生身高的正态分布的参数:如均值=1.7,方差=0.1

   2.计算每一个人更可能属于男生分布或者女生分布;

   3.通过分为男生的n个人来重新估计男生身高分布的参数(最大似然估计),女生分布也按照相同的方式估计出来,更新分布。

   4.这时候两个分布的概率也变了,然后重复步骤(1)至(3),直到参数不发生变化为止。

EM算法解决混合模型R代码 em算法的原理,EM算法解决混合模型R代码 em算法的原理_EM算法解决混合模型R代码,第1张

总结:其实EM算法就是先通过假设的参数把数据进行分类,然后通过分类的数据计算参数,接着对比计算的参数和假设的参数是否满足精度,不满足就返回去,满足就结束。

图像分割案例

原图及效果图:

EM算法解决混合模型R代码 em算法的原理,EM算法解决混合模型R代码 em算法的原理_数据_02,第2张

EM算法解决混合模型R代码 em算法的原理,EM算法解决混合模型R代码 em算法的原理_数据_03,第3张

opencv源码:

#include <opencv2/opencv.hpp>

#include <iostream>



using namespace cv;

using namespace cv::ml;

using namespace std;



int main(int argc, char** argv) {

	Mat src = imread("1.jpg");

	if (src.empty()) {

		printf("could not load iamge...\n");

		return -1;

	}

	char* inputWinTitle = "input image";

	namedWindow(inputWinTitle, CV_WINDOW_AUTOSIZE);

	imshow(inputWinTitle, src);



	// 初始化

	int numCluster = 3;

	const Scalar colors[] = {

		Scalar(255, 0, 0),

		Scalar(0, 255, 0),

		Scalar(0, 0, 255),

		Scalar(255, 255, 0)

	};



	int width = src.cols;

	int height = src.rows;

	int dims = src.channels();

	int nsamples = width*height;

	Mat points(nsamples, dims, CV_64FC1);

	Mat labels;

	Mat result = Mat::zeros(src.size(), CV_8UC3);



	// 图像RGB像素数据转换为样本数据 

	int index = 0;

	for (int row = 0; row < height; row++) {

		for (int col = 0; col < width; col++) {

			index = row*width + col;

			Vec3b rgb = src.at<Vec3b>(row, col);

			points.at<double>(index, 0) = static_cast<int>(rgb[0]);

			points.at<double>(index, 1) = static_cast<int>(rgb[1]);

			points.at<double>(index, 2) = static_cast<int>(rgb[2]);

		}

	}



	// EM Cluster Train

	Ptr<EM> em_model = EM::create();

	em_model->setClustersNumber(numCluster);

	em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);

	em_model->setTermCriteria(TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 100, 0.1));

	em_model->trainEM(points, noArray(), labels, noArray());



	// 对每个像素标记颜色与显示

	Mat sample(1, dims, CV_64FC1);//

	double time = getTickCount();

	int r = 0, g = 0, b = 0;

	for (int row = 0; row < height; row++) {

		for (int col = 0; col < width; col++) {

			index = row*width + col;



			b = src.at<Vec3b>(row, col)[0];

			g = src.at<Vec3b>(row, col)[1];

			r = src.at<Vec3b>(row, col)[2];

			sample.at<double>(0, 0) = static_cast<double>(b);

			sample.at<double>(0, 1) = static_cast<double>(g);

			sample.at<double>(0, 2) = static_cast<double>(r);

			int response = cvRound(em_model->predict2(sample, noArray())[1]);

			Scalar c = colors[response];

			result.at<Vec3b>(row, col)[0] = c[0];

			result.at<Vec3b>(row, col)[1] = c[1];

			result.at<Vec3b>(row, col)[2] = c[2];

		}

	}

	printf("execution time(ms) : %.2f\n", (getTickCount() - time) / getTickFrequency() * 1000);

	imshow("EM-Segmentation", result);



	waitKey(0);

	return 0;

}

 


https://www.xamrdz.com/backend/3bc1960892.html

相关文章: