首页 > 代码库 > 【EM】C++代码实现
【EM】C++代码实现
看了原理和别人的代码后,终于自己写了一个EM的实现。
我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。
实现的效果还行,正确率有84% (初始数据 男生170 女生160 方差都是10)
79% (初始数据 男生165 女生150 方差都是10)
正确率与初始值有关。
/*
 * Uses the EM algorithm to split a set of people into two groups (male /
 * female) from their height alone, then reports how often the group a height
 * falls into matches the sex label stored with the sample.
 *
 * The heights are modelled as a mixture of two Gaussians. The labels are never
 * used for fitting, only for scoring the result.
 */
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;

static const double PI = 3.14159265358979323846;

// A pair of per-sample values: f1 belongs to component 0 (male),
// f2 to component 1 (female).
typedef struct FLOAT2
{
    float f1;
    float f2;
} FLOAT2;

// Parameters of one Gaussian component.
typedef struct Gaussian
{
    float mean;
    float var;
} Gaussian;

// One labelled sample: first letter of the sex label and the height.
typedef struct EMData
{
    char sex;
    float fHeight;
} EMData;

// Density of the Gaussian `g` at `x`.
static double gaussianPdf(double x, const Gaussian &g)
{
    const double d = x - g.mean;
    return 1.0 / sqrt(2.0 * PI * g.var) * exp(-d * d / (2.0 * g.var));
}

// Magnitude of the change from `before` to `now`, relative to `before`
// (absolute when `before` is zero, so the result is never a division by zero).
static double relativeChange(double now, double before)
{
    const double scale = fabs(before);
    const double delta = fabs(now - before);
    return scale > 0.0 ? delta / scale : delta;
}

/*
 * Reads "<label> <height>" records from data.txt and appends them to Data.
 * Only the first character of the label is kept ('m' for "male", 'f' for
 * "female" in the accompanying data set).
 *
 * Returns 0 on success, -1 if the file cannot be opened.
 */
int getdata(vector<EMData> &Data)
{
    ifstream fin("data.txt");
    if (!fin)
    {
        cout << "error: can't open the file." << endl;
        return -1;
    }

    // Loop on the extraction itself: testing eof() before reading would push
    // one extra garbage record once the last real one has been consumed.
    // Reading the label into a std::string also removes the fixed-size buffer.
    string label;
    float height = 0.0f;
    while (fin >> label >> height)
    {
        EMData data;
        data.sex = label[0]; // label is non-empty after a successful extraction
        data.fHeight = height;
        Data.push_back(data);
    }
    return 0;
}

/*
 * Fits a two-component Gaussian mixture to the heights in Data with EM and
 * returns the fraction of samples whose more likely component matches the
 * stored label ('m' for component 0, 'f' for component 1).
 *
 * Returns 0 when Data is empty. The result depends on the starting values
 * chosen below.
 */
float predict(const vector<EMData> &Data)
{
    if (Data.empty())
        return 0.0f;

    const double tlimit = 0.000001; // convergence threshold on relative change
    const int maxIterations = 1000; // hard stop in case rounding noise never settles
    const float minVar = 1e-6f;     // keeps a collapsing component from dividing by zero

    // Starting values; index 0 is male, index 1 is female.
    Gaussian sex[2];
    sex[0].mean = 180.0f;
    sex[0].var = 10.0f;
    sex[1].mean = 150.0f;
    sex[1].var = 10.0f;
    double a[2] = {0.5, 0.5}; // mixing proportions of the two components

    vector<FLOAT2> resp(Data.size()); // per-sample responsibilities
    double t = 1.0;
    for (int iter = 0; iter < maxIterations && t > tlimit; ++iter)
    {
        Gaussian sex_old[2];
        sex_old[0] = sex[0];
        sex_old[1] = sex[1];
        const double a_old[2] = {a[0], a[1]};

        // E step: probability that each sample belongs to each component.
        for (size_t i = 0; i < Data.size(); ++i)
        {
            const double p0 = a[0] * gaussianPdf(Data[i].fHeight, sex[0]);
            const double p1 = a[1] * gaussianPdf(Data[i].fHeight, sex[1]);
            const double total = p0 + p1;
            if (total > 0.0)
            {
                resp[i].f1 = static_cast<float>(p0 / total);
                resp[i].f2 = static_cast<float>(p1 / total);
            }
            else
            {
                // Both densities underflowed; treat the sample as undecided
                // rather than producing 0/0.
                resp[i].f1 = 0.5f;
                resp[i].f2 = 0.5f;
            }
        }

        // M step: re-estimate proportions, means and variances.
        double weight[2] = {0.0, 0.0};
        double weightedHeight[2] = {0.0, 0.0};
        for (size_t i = 0; i < Data.size(); ++i)
        {
            weight[0] += resp[i].f1;
            weight[1] += resp[i].f2;
            weightedHeight[0] += resp[i].f1 * Data[i].fHeight;
            weightedHeight[1] += resp[i].f2 * Data[i].fHeight;
        }
        if (!(weight[0] > 0.0) || !(weight[1] > 0.0))
            break; // one component received no mass; keep the last valid fit

        const double totalWeight = weight[0] + weight[1];
        a[0] = weight[0] / totalWeight;
        a[1] = weight[1] / totalWeight;

        sex[0].mean = static_cast<float>(weightedHeight[0] / weight[0]);
        sex[1].mean = static_cast<float>(weightedHeight[1] / weight[1]);

        double weightedSq[2] = {0.0, 0.0};
        for (size_t i = 0; i < Data.size(); ++i)
        {
            const double d0 = Data[i].fHeight - sex[0].mean;
            const double d1 = Data[i].fHeight - sex[1].mean;
            weightedSq[0] += resp[i].f1 * d0 * d0;
            weightedSq[1] += resp[i].f2 * d1 * d1;
        }
        sex[0].var = std::max(minVar, static_cast<float>(weightedSq[0] / weight[0]));
        sex[1].var = std::max(minVar, static_cast<float>(weightedSq[1] / weight[1]));

        // Largest relative change of any parameter. Magnitudes are used so a
        // parameter that is still shrinking also keeps the loop running.
        t = 0.0;
        for (int k = 0; k < 2; ++k)
        {
            t = std::max(t, relativeChange(a[k], a_old[k]));
            t = std::max(t, relativeChange(sex[k].mean, sex_old[k].mean));
            t = std::max(t, relativeChange(sex[k].var, sex_old[k].var));
        }
    }

    // Score: a sample counts as correct when the component with the higher
    // density at its height carries the same letter as its label.
    int correct_num = 0;
    for (size_t i = 0; i < Data.size(); ++i)
    {
        const double p0 = gaussianPdf(Data[i].fHeight, sex[0]);
        const double p1 = gaussianPdf(Data[i].fHeight, sex[1]);
        const char csex = (p0 > p1) ? 'm' : 'f';
        if (csex == Data[i].sex)
            ++correct_num;
    }
    return static_cast<float>(correct_num) / Data.size();
}

int main()
{
    vector<EMData> Data;
    if (getdata(Data) != 0)
        return 1; // getdata already reported the problem

    const float correct_rate = predict(Data);
    cout << "correct rate = " << correct_rate << endl;
    return 0;
}
数据:data.txt内容
male 164
female 156
male 168
female 160
female 162
male 187
female 162
male 167
female 160.5
female 160
female 158
female 164
female 165
male 174
female 166
female 158
male 162
male 175
male 170
female 161
female 169
female 161
female 160
female 167
male 176
male 169
male 178
male 165
female 155
male 183
male 171
male 179
female 154
male 172
female 172
male 173
male 172
male 175
male 160
male 160
male 160
male 175
male 163
male 181
male 172
male 175
male 175
male 167
male 172
male 169
male 172
male 175
male 172
male 170
male 158
male 167
male 164
male 176
male 182
male 173
male 176
male 163
male 166
male 162
male 169
male 163
male 163
male 176
male 169
male 173
male 163
male 167
male 176
male 168
male 167
male 170
female 155
female 157
female 165
female 156
female 155
female 156
female 160
female 158
female 162
female 162
female 155
female 163
female 160
female 162
female 165
female 159
female 147
female 163
female 157
female 160
female 162
female 158
female 155
female 165
female 161
female 159
female 163
female 158
female 155
female 162
female 157
female 159
female 152
female 156
female 165
female 154
female 156
female 162
【EM】C++代码实现
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。