首页 > 代码库 > 【EM】C++代码实现

【EM】C++代码实现

看了原理和别人的代码后,终于自己写了一个EM的实现。

我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。

实现的效果还行,正确率有84% (初始数据 男生170 女生160 方差都是10)

                                 79%  (初始数据 男生165 女生150 方差都是10)

正确率与初始值有关。

/*
 * Uses the EM algorithm on a two-component 1-D Gaussian mixture to guess a
 * person's sex from their height, then reports how often the guess matches
 * the label stored alongside each height.
 */
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;

// Kept under the original name; now a typed constant instead of a macro.
const double PI = 3.14159265358979323846;

// Per-sample pair of values: f1 belongs to component 0, f2 to component 1.
struct FLOAT2
{
    float f1;
    float f2;
};

// Parameters of one 1-D normal distribution.
struct Gaussian
{
    float mean;
    float var;
};

// One labelled sample: first character of the label and the height.
struct EMData
{
    char sex;
    float fHeight;
};

// Density of the normal distribution `g` evaluated at `x`.
static float gaussian_pdf(float x, const Gaussian &g)
{
    const double d = static_cast<double>(x) - g.mean;
    const double var = g.var;
    return static_cast<float>(1.0 / sqrt(2.0 * PI * var) * exp(-d * d / (2.0 * var)));
}

// Magnitude of the relative change from `old_value` to `new_value`.
// Falls back to the absolute change when `old_value` is zero.
static float relative_change(float new_value, float old_value)
{
    const float diff = std::fabs(new_value - old_value);
    return (old_value != 0.0f) ? diff / std::fabs(old_value) : diff;
}

// Reads "<label> <height>" pairs from data.txt and appends them to Data.
// Only the first character of the label is stored.
// Returns 0 on success, -1 if the file cannot be opened.
int getdata(vector<EMData> &Data)
{
    ifstream fin("data.txt");
    if (!fin)
    {
        cerr << "error: can't open the file." << '\n';
        return -1;
    }

    // Test the extraction itself: looping on eof() would append one bogus
    // record after the last successful read.
    string label;
    float height = 0.0f;
    while (fin >> label >> height)
    {
        EMData data;
        data.sex = label[0];
        data.fHeight = height;
        Data.push_back(data);
    }
    return 0;
}

// Fits the two-component mixture to the heights with EM, classifies every
// sample, and returns the fraction whose predicted label ('m' / 'f') equals
// the stored one. Returns 0 when Data is empty.
float predict(const vector<EMData> &Data)
{
    if (Data.empty())
        return 0.0f;

    const float tlimit = 0.000001f;   // convergence threshold on relative change
    const int max_iterations = 1000;  // safety net if float noise never settles
    const float min_var = 1e-6f;      // keeps a collapsing component finite

    // Initial guess. Index 0 is the male component, index 1 the female one.
    Gaussian sex[2];
    sex[0].mean = 180.0f;
    sex[0].var = 10.0f;
    sex[1].mean = 150.0f;
    sex[1].var = 10.0f;
    float a[2] = {0.5f, 0.5f};  // mixing weights

    vector<FLOAT2> px(Data.size());  // responsibilities, reused every iteration
    float t = 1.0f;
    for (int iter = 0; iter < max_iterations && t > tlimit; ++iter)
    {
        const Gaussian sex_old[2] = {sex[0], sex[1]};
        const float a_old[2] = {a[0], a[1]};

        // E step: posterior probability of each component for each sample.
        for (size_t i = 0; i < Data.size(); ++i)
        {
            const float w0 = a[0] * gaussian_pdf(Data[i].fHeight, sex[0]);
            const float w1 = a[1] * gaussian_pdf(Data[i].fHeight, sex[1]);
            const float sum = w0 + w1;
            if (sum > 0.0f)
            {
                px[i].f1 = w0 / sum;
                px[i].f2 = w1 / sum;
            }
            else
            {
                // Both densities underflowed; avoid 0/0.
                px[i].f1 = 0.5f;
                px[i].f2 = 0.5f;
            }
        }

        // M step: re-estimate weights, means and variances.
        float sum_male = 0.0f, sum_female = 0.0f;
        float sum_mean_male = 0.0f, sum_mean_female = 0.0f;
        for (size_t i = 0; i < Data.size(); ++i)
        {
            sum_male += px[i].f1;
            sum_female += px[i].f2;
            sum_mean_male += px[i].f1 * Data[i].fHeight;
            sum_mean_female += px[i].f2 * Data[i].fHeight;
        }
        if (sum_male <= 0.0f || sum_female <= 0.0f)
            break;  // one component owns no samples; keep the last parameters

        a[0] = sum_male / (sum_male + sum_female);
        a[1] = sum_female / (sum_male + sum_female);
        sex[0].mean = sum_mean_male / sum_male;
        sex[1].mean = sum_mean_female / sum_female;

        float sum_var_male = 0.0f, sum_var_female = 0.0f;
        for (size_t i = 0; i < Data.size(); ++i)
        {
            const float d0 = Data[i].fHeight - sex[0].mean;
            const float d1 = Data[i].fHeight - sex[1].mean;
            sum_var_male += px[i].f1 * d0 * d0;
            sum_var_female += px[i].f2 * d1 * d1;
        }
        sex[0].var = std::max(sum_var_male / sum_male, min_var);
        sex[1].var = std::max(sum_var_female / sum_female, min_var);

        // Largest relative change of any parameter. The magnitude is used so
        // that a parameter that shrinks also counts as "still moving".
        t = relative_change(a[0], a_old[0]);
        t = std::max(t, relative_change(a[1], a_old[1]));
        t = std::max(t, relative_change(sex[0].mean, sex_old[0].mean));
        t = std::max(t, relative_change(sex[1].mean, sex_old[1].mean));
        t = std::max(t, relative_change(sex[0].var, sex_old[0].var));
        t = std::max(t, relative_change(sex[1].var, sex_old[1].var));
    }

    // Score: a sample is labelled by the component with the higher density.
    // NOTE(review): the mixing weights are not applied here, as in the
    // original; confirm whether a weighted (posterior) decision is wanted.
    int correct_num = 0;
    for (size_t i = 0; i < Data.size(); ++i)
    {
        const float p0 = gaussian_pdf(Data[i].fHeight, sex[0]);
        const float p1 = gaussian_pdf(Data[i].fHeight, sex[1]);
        const char csex = (p0 > p1) ? 'm' : 'f';
        if (csex == Data[i].sex)
            ++correct_num;
    }
    return static_cast<float>(correct_num) / Data.size();
}

// Loads the samples, runs the classifier and prints the accuracy.
int main()
{
    vector<EMData> Data;
    if (getdata(Data) != 0)
        return 1;

    const float correct_rate = predict(Data);
    cout << "correct rate = " << correct_rate << '\n';
    return 0;
}

 

数据:data.txt内容

male    164female    156male    168female    160female    162male    187female    162male    167female    160.5female    160female    158female    164female    165male    174female    166female    158male     162male    175male    170female    161female    169female    161female    160female    167male    176male    169male    178male    165female    155male    183male    171male    179female    154male    172female    172male    173male    172male    175male    160male    160male    160male    175male    163male    181male    172male    175male    175male    167male    172male    169male    172male    175male    172male    170male    158male    167male    164male    176male    182male    173male    176male    163male    166male    162male    169male    163male    163male    176male    169male    173male    163male    167male    176male    168male    167male    170female    155female    157female    165female    156female    155female    156female    160female    158female    162female    162female    155female    163female    160female    162female    165female    159female    147female    163female    157female    160female    162female    158female    155female    165female    161female    159female    163female    158female    155female    162female    157female    159female    152female    156female    165female    154female    156female    162

 

【EM】C++代码实现