首页 > 代码库 > 菜鸟笔记——用SSE指令作点乘和累加计算
菜鸟笔记——用SSE指令作点乘和累加计算
这几天在做学校的一个学习小项目,需要用到SIMD指令计算提速。也是第一次碰这个,看了一些资料和代码,模仿着写了两个函数。
void sse_mul_float(float *A, float *B, int cnt):两段内存float数据点乘,结果覆盖第一组内存。
float sse_acc_float(float *A, int cnt):一组内存float值累加。
注:
1. 没有考虑中间的精确问题,结果会有误差。
2. 每个函数包括指令操作部分和C++语句计算部分。本文附的代码注释介绍指令部分思路。
**3. 关于内存对齐,我不是很懂,所以下面的代码中判断是否对齐的相关语句我写的也不是很正确,所有后面都补上了一点C++的明白操作。
因此,有些指令操作也许没用上。
头文件
#include "time.h" #include "stdafx.h" #include<iostream> #include <stdlib.h> #include <stdio.h> #include <tchar.h> #include <math.h> #include <time.h> #include <windows.h> #include <iomanip> #include <sys/timeb.h> using namespace std;
sse_mul_float asm部分
1 //MOV EAX,1 ;request CPU feature flags 2 //CPUID ;0Fh, 0A2h CPUID instruction 3 //TEST EDX,4000000h ;test bit 26 (SSE2) 4 //JNZ >L18 ;SSE2 available 5 6 int cnt1; 7 int cnt2; 8 int cnt3; 9 10 //we process the majority by using SSE instructions 11 if (((int)A % 16) || ((int)B % 16)) //如果内存不对齐 12 { 13 14 cnt1 = cnt / 16; //该loop一轮处理16个float*float 15 cnt2 = (cnt - (16 * cnt1)) / 4; //该loop一轮处理4个float*float 16 cnt3 = (cnt - (16 * cnt1) - (4 * cnt2)); //该loop一轮处理1个float*float 17 18 _asm 19 { 20 21 mov edi, A; //先将内存地址放入指针寄存器 22 mov esi, B; 23 mov ecx, cnt1; //循环寄存器置值 24 jecxz ZERO; //如果数据量不超过16个,则跳过L1 25 26 L1: 27 28 29 //xmm 寄存器有128bit 30 //movups XMM,XMM/m128 31 //传128bit数据,不必对齐内存16字节. 32 movups xmm0, [edi]; 33 movups xmm1, [edi + 16]; 34 movups xmm2, [edi + 32]; 35 movups xmm3, [edi + 48]; 36 //为什么只载入4*4个float? 到上面看看这一轮需要处理多少数据 37 38 movups xmm4, [esi]; 39 movups xmm5, [esi + 16]; 40 movups xmm6, [esi + 32]; 41 movups xmm7, [esi + 48]; 42 43 //mulps XMM,XMM/m128 44 //寄存器按双字对齐, 45 //共4个单精度浮点数与目的寄存器里的4个对应相乘, 46 //结果送入目的寄存器, 内存变量必须对齐内存16字节. 47 mulps xmm0, xmm4; 48 mulps xmm1, xmm5; 49 mulps xmm2, xmm6; 50 mulps xmm3, xmm7; 51 52 //(一个float占4字节,也就是32bit) 53 //到这里,xmm0-3寄存器里都有了4个float的乘积结果 54 //然后回载到相应内存 55 movups[edi], xmm0; 56 movups[edi + 16], xmm1; 57 movups[edi + 32], xmm2; 58 movups[edi + 48], xmm3; 59 60 //记得给指针移位 61 //64=16 * 4 62 //每一轮处理了16次float * float,每一个float占4字节 63 //所以移位应该加64 64 add edi, 64; 65 add esi, 64; 66 67 loop L1; 68 69 ZERO: 70 mov ecx, cnt2; 71 jecxz ZERO1; 72 73 L2: 74 75 movups xmm0, [edi]; //对于4个float,一个xmm寄存器正好够用 76 movups xmm1, [esi]; 77 mulps xmm0, xmm1; //对应相乘,结果在xmm0 78 movups[edi], xmm0; //由xmm0回载内存 79 add edi, 16; //指针移位 80 add esi, 16; 81 82 loop L2; 83 84 ZERO1: 85 86 mov ecx, cnt3; 87 jecxz ZERO2; 88 89 mov eax, 0; 90 91 L3: 92 93 movd eax, [edi]; //对于单个float * float,无需sse指令 94 imul eax, [esi]; 95 movd[edi], eax; 96 add esi, 4; 97 add edi, 4; 98 99 loop L3; 100 101 ZERO2: 102 103 EMMS; //清空 104 105 } 106 107 } 108 else 109 { 110 111 cnt1 = cnt / 28; //该loop一轮处理28个float*float 112 cnt2 = (cnt - (28 * cnt1)) / 4; //该loop一轮处理4个float*float 113 cnt3 = (cnt - (28 * cnt1) - (4 * cnt2)); //该loop一轮处理1个float*float 114 115 _asm 116 { 117 118 119 mov edi, A; 120 mov esi, B; 121 mov ecx, cnt1; 122 jecxz AZERO; 123 124 AL1: 125 126 //movaps XMM, XMM / m128 127 //把源存储器内容值送入目的寄存器, 当有m128时, 必须对齐内存16字节, 也就是内存地址低4位为0. 128 movaps xmm0, [edi]; 129 movaps xmm1, [edi + 16]; 130 movaps xmm2, [edi + 32]; 131 movaps xmm3, [edi + 48]; 132 movaps xmm4, [edi + 64]; 133 movaps xmm5, [edi + 80]; 134 movaps xmm6, [edi + 96]; 135 //7*4=28,处理28个float*float 136 137 mulps xmm0, [esi]; //对应点乘 138 mulps xmm1, [esi + 16]; 139 mulps xmm2, [esi + 32]; 140 mulps xmm3, [esi + 48]; 141 mulps xmm4, [esi + 64]; 142 mulps xmm5, [esi + 80]; 143 mulps xmm6, [esi + 96]; 144 145 movaps[edi], xmm0; //回载 146 movaps[edi + 16], xmm1; 147 movaps[edi + 32], xmm2; 148 movaps[edi + 48], xmm3; 149 movaps[edi + 64], xmm4; 150 movaps[edi + 80], xmm5; 151 movaps[edi + 96], xmm6; 152 153 add edi, 112; 154 add esi, 112; 155 156 loop AL1; 157 158 AZERO: 159 mov ecx, cnt2; 160 jecxz AZERO1; 161 162 AL2: 163 164 movaps xmm0, [edi]; 165 mulps xmm0, [esi]; 166 movaps[edi], xmm0; 167 add edi, 16; 168 add esi, 16; 169 170 loop AL2; 171 172 AZERO1: 173 174 mov ecx, cnt3; 175 jecxz AZERO2; 176 177 mov eax, 0; 178 179 AL3: 180 181 movd eax, [edi]; 182 imul eax, [esi]; 183 movd[edi], eax; 184 add esi, 4; 185 add edi, 4; 186 187 loop AL3; 188 189 AZERO2: 190 191 EMMS; 192 193 } 194 195 }
由于内存对齐的问题,导致末尾有部分数据不正常,特添加C++部分修复。
sse_mul_float c++部分
1 int start; 2 start = cnt - (cnt % 4); 3 for (int i = start; i < cnt; i++) 4 { 5 A[i] *= B[i]; 6 }
用于累加的这个函数,分两块。一块是用指令把大部分数据处理掉,而极少部分数据使用C++语句,这样能各取所长。
sse_acc_float asm部分
1 float temp = 0; 2 3 int cnt1; 4 int cnt2; 5 int cnt3; 6 int select = 0; 7 8 //we process the majority by using SSE instructions 9 if (((int)A % 16)) //unaligned 如果这次调用,内存数据不对齐 10 { 11 select = 1; 12 13 cnt1 = cnt / 24; 14 cnt2 = (cnt - (24 * cnt1)) / 8; 15 cnt3 = (cnt - (24 * cnt1) - (8 * cnt2)); 16 17 __asm 18 { 19 20 mov edi, A; 21 mov ecx, cnt1; 22 pxor xmm0, xmm0; 23 jecxz ZERO; 24 25 L1: 26 27 movups xmm1, [edi]; 28 movups xmm2, [edi + 16]; 29 movups xmm3, [edi + 32]; 30 movups xmm4, [edi + 48]; 31 movups xmm5, [edi + 64]; 32 movups xmm6, [edi + 80]; 33 34 //addps 对应相加 35 //结果返回目的寄存器 36 addps xmm1, xmm2; 37 addps xmm3, xmm4; 38 addps xmm5, xmm6; 39 40 addps xmm1, xmm5; 41 addps xmm0, xmm3; 42 43 addps xmm0, xmm1; 44 //至此,xmm0内4个float的和就是24个float的和 45 46 add edi, 96; 47 48 loop L1; 49 50 ZERO: 51 52 53 movd ebx, xmm0; //低4个字节(第一个float)传入ebx 54 psrldq xmm0, 4; //xmm0右移4字节 55 movd eax, xmm0; //右移后,低4个字节(第二个float)传入eax 56 57 movd xmm1, eax; //第一个float传入xmm1低32bit 58 movd xmm2, ebx; //第二个float传入xmm2低32bit 59 addps xmm1, xmm2; //两个寄存器内4个float对应相加 60 movd eax, xmm1; //只取我们要的低位float,传入eax 61 movd xmm3, eax; //第一和第二个float的和存在xmm3低32位 62 psrldq xmm0, 4; //又截掉一个float 63 64 65 movd ebx, xmm0; //第三个float进ebx 66 psrldq xmm0, 4; //截掉第三个float 67 movd eax, xmm0; //第四个float进eax 68 69 movd xmm1, eax; 70 movd xmm2, ebx; 71 addps xmm1, xmm2; //第三和第四个float的和存在xmm1低32位 72 movd eax, xmm1; 73 movd xmm4, eax; 74 addps xmm3, xmm4; //4个float的和存在xmm3低32位 75 76 77 movd eax, xmm3; 78 mov temp, eax; //这部分求和存在temp地址区 79 80 81 82 EMMS; 83 84 } 85 } 86 else // aligned 如果这次调用,内存数据对齐 87 { 88 select = 2; 89 90 cnt1 = cnt / 56; 91 cnt2 = (cnt - (56 * cnt1)) / 8; 92 cnt3 = (cnt - (56 * cnt1) - (8 * cnt2)); 93 94 __asm 95 { 96 97 mov edi, A; 98 mov ecx, cnt1; 99 pxor xmm0, xmm0; 100 jecxz ZZERO; 101 102 LL1: 103 104 movups xmm1, [edi]; 105 movups xmm2, [edi + 16]; 106 movups xmm3, [edi + 32]; 107 movups xmm4, [edi + 48]; 108 movups xmm5, [edi + 64]; 109 movups xmm6, [edi + 80]; 110 111 addps xmm1, xmm2; 112 addps xmm3, xmm4; 113 addps xmm5, xmm6; 114 addps xmm1, xmm5; 115 addps xmm0, xmm3; 116 addps xmm0, xmm1; 117 118 add edi, 96; 119 120 movups xmm1, [edi]; 121 movups xmm2, [edi + 16]; 122 movups xmm3, [edi + 32]; 123 movups xmm4, [edi + 48]; 124 movups xmm5, [edi + 64]; 125 movups xmm6, [edi + 80]; 126 127 addps xmm1, xmm2; 128 addps xmm3, xmm4; 129 addps xmm5, xmm6; 130 addps xmm1, xmm5; 131 addps xmm0, xmm3; 132 addps xmm0, xmm1; 133 134 add edi, 96; 135 136 movups xmm1, [edi]; 137 movups xmm2, [edi + 16]; 138 139 addps xmm1, xmm2; 140 addps xmm0, xmm1; 141 142 add edi, 32; 143 144 loop LL1; 145 146 ZZERO: 147 148 149 movd ebx, xmm0; 150 psrldq xmm0, 4; 151 movd eax, xmm0; 152 153 movd xmm1, eax; 154 movd xmm2, ebx; 155 addps xmm1, xmm2; 156 movd eax, xmm1; 157 movd xmm3, eax; 158 psrldq xmm0, 4; 159 160 161 movd ebx, xmm0; 162 psrldq xmm0, 4; 163 movd eax, xmm0; 164 165 movd xmm1, eax; 166 movd xmm2, ebx; 167 addps xmm1, xmm2; 168 movd eax, xmm1; 169 movd xmm4, eax; 170 addps xmm3, xmm4; 171 172 173 movd eax, xmm3; 174 mov temp, eax; 175 176 EMMS; 177 178 } 179 }
sse_acc_float c++部分
//上面的select记录本次调用sse_acc_float时,数据是否对齐内存 //下面分情况把剩余的和累加 int start; float c = 0.0f; if (select == 1) { start = cnt - (cnt % 24); for (int i = start; i < cnt; i++) { c += A[i]; } } else { start = cnt - (cnt % 56); for (int i = start; i < cnt; i++) { c += A[i]; } } //temp 是用指令计算 ,大部分数据的和 //c 是用C++语句计算, 所有数据模24或者56剩余部分数据的和 return(temp + c);
我是一名编程菜鸟,有什么技术上的问题,欢迎讨论和交流指正。谢谢!
获取全部源码:点此 dot_acc.cpp
菜鸟笔记——用SSE指令作点乘和累加计算
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。