首页 > 代码库 > matlab 与c/c++ 混合MEX编程的几个需要注意的地方

matlab 与c/c++ 混合MEX编程的几个需要注意的地方

      最近做一个机器学习的课题,主体是matlab写的,其中有部分训练的核心算法是用c++写的,因为有太多的循环和数值计算用c++比较快。这也是我第一次用c++写matlab的模块,感觉走了很多弯路,下面给大家分享一点经验。

      matlab中的c++编程称为mex编程:matlab executive matlab 可执行文件,至于其中的具体机制我不是很清楚,有的大神会比较清楚编译期间产生的各种文件。

 

1) mex编程中指针和索引:

      matlab中默认的数据类型是double,用class()函数可以看到变量的数据类型:

                                    

    matlab代码如下:

mex mex.cpp -g;a = [1.1,2.1,3;4,5,6;7,8,9]mex(a)

      命令mex 用来编译mex文件,上面代码中  mex  mex.cpp ‘-g’ 编译了mex.cpp这个c++文件,编译完成之后会生成一个“mex.mexw64”的文件,后缀名说明这是在win64下编译完成的mex文件,后面的‘-g‘是一个附加参数,在这里不用理解。编译后的mex文件可以当matlab函数使用。    

      在matlab代码文件同目录下的c文件mex.cpp代码如下:

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){    double *input;    input = mxGetPr(prhs[0]);    printf("第一个值%f\n",*input);    printf("第二个值%f\n",*(input+1));}

      mexFunction有四个参数,nlhs( number left hand s):左边参数个数,也就是matlab函数输出值得个数,mxArray *plhs[]是一个指针数组,数组中的每一个元素都是一个指针,指向输出的矩阵;nrhs 是右边参数个数,也就是输入参数的个数,mxArray *prhs[]数组中的每个指针指向输入矩阵。mxGetPr()函数返回一个double*型的指针,指向矩阵的第一个元素,在matlab代码中调用:mex(a),那也就是 prhs[0]是输入矩阵a的地址,而 input = mxGetPr(prhs[0]) ,input指向了a第一个元素1 。
那矩阵第一排第二列的值a(1,2)的地址是多少呢?是(input+1)吗?在这里我们运行上面的matlab代码,得到的结果如下:

     可以看出,输出的*(input+1)是4,也就是说,c++中的matlab矩阵是按列进行索引的。这里是一个需要注意的地方,因为很多地方要对matlab输入的矩阵进行遍历得到矩阵的元素值,如果索引出错,那就完全错了。其实这里的内在原因,是因为在matlab中矩阵是按列进行索引的,而c++中指针式按行往后加的。

     有很多函数可以方便我们对矩阵进行索引,uint32 mxGetM(mxArray *)输入一个矩阵的指针,返回该矩阵的行数,uint32 mxGetN(mxArray *)返回列数,对行数和列数适当的计算,可以方便的访问矩阵元素,例如,访问a(i,j):   *(input+N*(j-1)+(i-1))  ,N为矩阵行数,这里需要-1的原因是,matlab的行数列数从1开始计数,而c的数组则从0开始索引。

 

2)mex编程中的数据类型与指针移位的重要关系,mxGetPr() 与 mxGetData():

     前面说过,matlab里的默认数据类型是double,那么,如果把mex函数的输入矩阵的数据类型转换一下,会出现什么结果呢?

matlab 代码:

1 clc2 mex mex.cpp -g;3 a = [1.1,2.1,3;4,5,6;7,8,9];4 a=single(a)5 mex(a)

c++代码:

1 #include "mex.h"2 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])3 {4     double *input;5     input = mxGetPr(prhs[0]);6     printf("第一个值%f\n",*input);7     printf("第二个值%f\n",*(input+1));8 }

c++代码并没有变,matlab代码也仅仅进行了一个数据类型转换,我们看看输出结果:

      可以看到这里输出的已经不是我们期望的数值了。在我调试mex代码的时候这个问题苦恼了我很久,因为mex不方便调试,很多时候输出的结果不是想要的,而且我的输入矩阵都是上万维的,很难调试。这里输入矩阵a变成了single单精度类型,前面我们说过,mxGetPr()返回double类型的指针,当我们用double类型指针访问一个单精度(在c++)中我们称之为浮点型float的数据的时候,当然会发成内存越界,用取值符号*去取值的时候超过了数据的内存块,因此发生错误,如果我们修改c++代码:

1 #include "mex.h"2 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])3 {4     float* input;5     input = (float*)mxGetPr(prhs[0]);6     printf("第一个值%f\n",*input);7     printf("第二个值%f\n",*(input+1));8 }

     将input类型设置成float,并将mxGetPr()的返回类型强制转换为float*就可以了。在这里还有一个函数mxGetData()也可以返回输入矩阵的头地址,只不过mxGetData()返回的是char*类型的指针,而mxGetPr()返回的是double*类型的指针,可以根据自己的需要选取函数,或者转换指针类型。如果指针类型不对,极有可能造成内存访问错误,导致matlab死掉。

 

3) nlhs 与 nrhs的作用 

     mexFunction函数中,两个指针参数分别指向输入输出的矩阵,而nrhs和nlhs分别记录输入输出矩阵的个数,在一般的操作中,我们仅仅对输入矩阵进行取值,运算,对输出矩阵进行赋值,nrhs和nlhs不是很常用,但是也是极其重要的。例如,在上面的代码中,如果我在matlab代码中这样调用mex:mex(),不输入任何参数,matlab就会马上死掉。因为在mex文件的cpp代码中,你用指针访问了输入矩阵的值,而在参数中你没有给mex输入任何参数,使得矩阵指针为野指针,导致内存错误。如果编码中出现这种参数不对的情况,将导致matlab频繁死掉,我的工作中数据特别多,准备数据需要几十分钟,这样让我非常痛苦。解决的方法就是利用nlhs和nrhs这两个参数。在mexFunction中判断nlhs的值来判断输入参数的个数,用nrhs判断输入参数的个数。如果输入参数少于某个值或者不满足你的要求可以让mexFunction直接return,避免后续的程序导致内存错误。


matlab 与c/c++ 混合MEX编程的几个需要注意的地方