分支判断的语句一般来说是不太适合进行SSE优化的,因为他会破坏代码的并行性,但是也不是所有的都是这样的,在合适的场景中运用SSE还是能对分支预测进行一定的优化的,我们这里以某一个算法的部分代码为例进行讲解。
在某一个版本的USM锐化算法中有这样的一段代码:
int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold)
{
int Channel = Stride / Width;
if ((Src == NULL) || (Dest == NULL)) return IM_STATUS_NULLREFRENCE;
if ((Width <= 0) || (Height <= 0)) return IM_STATUS_INVALIDPARAMETER;
if ((Channel != 1) && (Channel != 3) && (Channel != 4)) return IM_STATUS_INVALIDPARAMETER;
int Status = IM_STATUS_OK;
Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius); // 这里标准过程是用IM_GaussBlur代替
if (Status != IM_STATUS_OK) return Status;
const float Inv255 = 1.0f / 255.0f;
int *Table = (int *)malloc(511 * 256 * sizeof(int));
if (Table == NULL) return IM_STATUS_OUTOFMEMORY;
for (int Y = 0; Y < 256; Y++)
{
float TempUp = Amount * sqrtf(1.0f - Y * Inv255) / 100.0f;
float TempDown = Amount * sqrtf(Y * Inv255) / 100.0f;
for (int X = -255; X <= 255; X++)
{
int Diff = X;
if (Diff >= Threshold)
{
Diff -= Threshold;
Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempUp + 0.5f) + Y);
}
else if (Diff < -Threshold)
{
Diff += Threshold;
Table[((X + 255) << 8) + Y] = IM_ClampToByte(int(Diff * TempDown + 0.5f) + Y);
}
else
{
Table[((X + 255) << 8) + Y] = Y; // 不做变化
}
}
}
for (int Y = 0; Y < Height * Stride; Y++) // 分四路并行速度只有一点点提高
{
Dest[Y] = Table[((Src[Y] - Dest[Y] + 255) << 8) + Src[Y]];
}
free(Table);
return IM_STATUS_OK;
}
这个USM锐化的算法参考自:https://github.com/pluginguy/plugins/tree/master/USM2,源代码中的算法还提供了对高光、暗调和中间调进行不同调节的参数,我这里对他那个代码进行了适度的修改和简化,并且用查找表进行了优化。这个github的作者还提供了关于高斯模糊方面的资料,是个不错的参考点。
上述代码起始已经很高效了,复杂的浮点和开方计算都已经用查表的形式进行了简化,实测一副1080P的24位图像大处理时间大约在14.5ms左右,而其中的IM_ExpBlur耗时约有6.75ms,建立查找表花了0.75ms,后面的遍历图像进行查找表替换使用了7ms,注意前面的IM_ExpBlur的时间是已经进行了SSE编码后的优化时间。
查找表其实本身也是个耗时的工作,因为这个可能有着严重的cache miss,特别是查找表比较大的时候。但是查找表本身呢在目前SIMD框架下是无法使用SSE优化的(除非是16个字节的查找表,可以使用_mm_shuffle_epi8来优化),因此,如果查找表本身的建立算法并不特别复杂,是可以考虑使用SSE来对表中每个元素进行直接的实现的,鉴于此,我们来考虑上述代码的查找表的直接SSE实现。
为了表示清楚,我们把上述算法的非查找表方式实现的代码整理出来如下:
int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold) { int Channel = Stride / Width; if ((Src == NULL) || (Dest == NULL)) return IM_STATUS_NULLREFRENCE; if ((Width <= 0) || (Height <= 0)) return IM_STATUS_INVALIDPARAMETER; if ((Channel != 1) && (Channel != 3) && (Channel != 4)) return IM_STATUS_INVALIDPARAMETER; int Status = IM_STATUS_OK; Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius); // 这里标准过程是用IM_GaussBlur代替 if (Status != IM_STATUS_OK) return Status; float Adjust = Amount / 100.0f / sqrtf(255.0f); for (int Y = 0; Y < Height * Stride; Y++) { int Diff = Src[Y] - Dest[Y]; if (Diff >= Threshold) { Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]); } else if (Diff < -Threshold) { Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf((float)Src[Y]) + 0.5f) + Src[Y]); } else { Dest[Y] = Src[Y]; // 不做变化 } } return IM_STATUS_OK; }
注意为减少计算我已经把一些重复的计算提取到Adjust变量中,其中的/sqrtf(255.0f)可以让循环内部的sqrtf的参数少一次乘法计算,并且在后面我们还可以看到他起到了另外一个特殊的作用。运行上述代码的同参数同照片耗时变为了55ms左右,可见查找表的优化也是很给力的。
我注意到这段代码已经有很久了,也一直想使用SSE优化他们,但苦于能力,一直未得良方,不过最近过年重新审视这段代码,发现只要手指按住键盘,总会有新大陆发现的。
第一方案:既然SSE不太好做分支判断,我就把所有分支的结果都计算出来,最后再根据分支条件做数据融合不就可以了吗,可以肯定SSE计算每个分支的速度肯定比C快,但是如果要每个分支都计算,这个增加的耗时和加速的时间比例如何呢,只有实践才知道,于是我硬着头皮把他们用SSE做个硬编码,代码如下所示:
// 实在没有好的办法,极端情况下把所有的分支的结果都算出来,然后在最后根据判断条件合成,比如下面的代码,写出来后比原始的查找表方式也还是要快一点的。 int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold) { int Channel = Stride / Width; if ((Src == NULL) || (Dest == NULL)) return IM_STATUS_NULLREFRENCE; if ((Width <= 0) || (Height <= 0)) return IM_STATUS_INVALIDPARAMETER; if ((Channel != 1) && (Channel != 3) && (Channel != 4)) return IM_STATUS_INVALIDPARAMETER; int Status = IM_STATUS_OK; Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius); if (Status != IM_STATUS_OK) return Status; const float Adjust = Amount / 100.0f / sqrt(255.0f); const int BlockSize = 8; int Block = (Height * Stride) / BlockSize; const __m128i Zero = _mm_setzero_si128(); const __m128i ThresholdV = _mm_set1_epi16(Threshold); const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold); const __m128i One = _mm_set1_epi16(1); const __m128i MinusOne = _mm_set1_epi16(-1); const __m128 Const255 = _mm_set1_ps(255.0f); const __m128 AdjustV = _mm_set1_ps(Adjust); for (int Y = 0; Y < Block * BlockSize; Y += BlockSize) { __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero); __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero); __m128 SrcL = _mm_cvtepi32_ps(_mm_unpacklo_epi8(SrcV, Zero)); __m128 SrcH = _mm_cvtepi32_ps(_mm_unpackhi_epi8(SrcV, Zero)); __m128i Diff = _mm_sub_epi16(SrcV, DstV); __m128i DiffA = _mm_add_epi16(Diff, ThresholdV); __m128i DiffS = _mm_sub_epi16(Diff, ThresholdV); __m128 DiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(Diff)); __m128 DiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(Diff, 8))); __m128 UpL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcL))); __m128 UpH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(_mm_sub_ps(Const255, SrcH))); __m128 DownL = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcL)); __m128 DownH = _mm_mul_ps(AdjustV, _mm_sqrt_ps(SrcH)); __m128 DiffUpL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffS)), UpL); __m128 DiffUpH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffS, 8))), UpH); __m128 DiffDownL = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(DiffA)), DownL); __m128 DiffDownH = _mm_mul_ps(_mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(DiffA, 8))), DownH); __m128i DiffUp = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffUpL), _mm_cvtps_epi32(DiffUpH)), SrcV); __m128i DiffDown = _mm_adds_epi16(_mm_packs_epi32(_mm_cvtps_epi32(DiffDownL), _mm_cvtps_epi32(DiffDownH)), SrcV); __m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV)); DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV));
_mm_storel_epi64((__m128i *)(Dest + Y), _mm_packus_epi16(DestV, Zero)); } for (int Y = Block * BlockSize; Y < Height * Stride; Y++) { int Diff = Src[Y] - Dest[Y]; if (Diff >= Threshold) { Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]); } else if (Diff < -Threshold) { Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]); } else { Dest[Y] = Src[Y]; } } return IM_STATUS_OK; }
上述代码基本就是普通C语言的翻译,这里讲几个需要注意的地方。
第一、_mm_cvtepi16_epi32这是个讲signed short转换为signed int的函数,只处理XMM寄存的低8位,如果需要将高8位也进行转换,就必须得配合_mm_srli_si128一起使用,如果需要转换的signed short能确认是大于等于0的,也可以使用_mm_unpacklo_epi16及_mm_unpackhi_epi16配合_mm_setzero_si128来实现,比如上面的SrcL和SrcH就是使用的这个技巧,但是如果有小于0的情况出现,一定只能用_mm_cvtepi16_epi32来实现,比如上面的DiffL和DiffH,我以前在这个上面吃过很多亏。
第二、在计算DiffUp和DiffDown这两个结果时,注意需要使用_mm_packs_epi32,而不是_mm_packus_epi32,因为计算结果是有负数存在的。
第三、结果的融合这里的技巧很好,我们知道SSE4提供了两个__m128i变量融合的函数,比如_mm_blendv_epi8,但是他要求最后的融合选项是个常数,而我们这里的融合选项是变化的,所以无法使用,我们使用了一个叫做_mm_blendv_si128的内联函数,这个函数用一个__m128i变量作为融合参数,对128个位进行融合,其代码如下:
static inline __m128i _mm_blendv_si128(__m128i x, __m128i y, __m128i mask) { return _mm_or_si128(_mm_andnot_si128(mask, x), _mm_and_si128(mask, y)); }
当mask的某一位为0时,选择x中的对应位的值,否则选择y中对应位的值。
这个函数正是我需要的,而且恰好前几天在浏览文章:A few missing SSE intrinsics发现了他,有的时候真的觉得处处留心皆学问啊。
这时我们来看下上面的融合的代码:__m128i DestV = _mm_blendv_si128(SrcV, DiffUp, _mm_cmpgt_epi16(Diff, ThresholdV));
后面的_mm_cmpgt_epi16的比较函数会返回一个__m128i变量,当Diff > Threshold时,对应的16位数据为0xFFFF,否则为0,这样我们使用_mm_blendv_si128融合时,满足条件的部分结果就为DiffUp了,其他部分还保持SrcV不变。
接着 DestV = _mm_blendv_si128(DestV, DiffDown, _mm_cmplt_epi16(Diff, MinusThresholdV)); 使用Diff < -Threshold作为判断条件,因为该条件和Diff > Threshold不可能同时成立,所以_mm_cmplt_epi16的返回结果中的为true的部分和_mm_cmpgt_epi16返回的true部分的值不可能重叠,因此,再次执行_mm_blendv_si128混合的值就是我们融合的正确结果。
那么我们最关心的速度来了,经过测试,上述算法对1080P彩色图能达到约14ms的执行速度,和查找表的C语言版本速度差不多,唯一的优势就是运算时少占用了一部分内存。但是同时也说明SSE的计算能力真的不是盖的,算一算,正正的SSE执行时间实际上只有14-6.75 =7.25ms,而不用查找表的C代码的用时为55-6.75=48.25ms,达到了进7倍的提速比,但这就是我们的终点了吗?
第二方案:我们在仔细观察下Diff > Threshold和Diff < -Threshold时计算的不同,第一个不同是Diff > Threshold时使用了Diff - Threshold,而Diff < -Threshold时使用了Diff + Threshold;第二个不同为Diff > Threshold时使用了255.0f - Src[Y]作为开平方的算式,而Diff < -Threshold时使用了 Src[Y]。关于第一个不同,我们可以看到仅仅是个符号位不同,如果在Threshold前面根据不同的条件加个符号位在进行乘法不就可以了,也就是说如果我们根据Diff和Threshold的关系构建一个-1和1的中间变量,则可以把他们写在一个式子里,那这样的符号为要如何构建呢?
自然而然我们又想到了上述方法的_mm_blendv_si128,简单的方式如下所示:
__m128i Sign = _mm_blendv_si128(Zero, MinusOne, _mm_cmpgt_epi16(Diff, ThresholdV));
Sign = _mm_blendv_si128(Sign, One, _mm_cmplt_epi16(Diff, MinusThresholdV));
Zero,MinusOne,One这个还需不需要解释,上面的代码还需不需要解释?
第二个不同,我们这样看,我们把它们放在一起 255.0f - Src[Y] | Src[Y],稍微改写一下255 - Src[Y] | 0 - Src[Y],后面的+和-可以用类似前面的同样的方法处理,我们还需处理255和0,如果我们能够根据判断条件构造出255 和 0这样的序列,那是不是就解决问题了,如何构造?
前面说过,_mm_cmpgt_epi16会返回0xFFFF和0,看成unsigned short类型则为65535和0, 如果我们把这个返回结果右移8位,是不是就变为了255和0呢,明白了吗?
最后我们注意一点,当-Threshold < Diff <Threshold时,我们的返回的是原图像的值,那在这种情况下是不是有问题呢,其实也不会,我们注意到此条件下Sign对应的符号位为0,而_mm_cmpgt_epi16返回的那部分数据也为0,也就是说此时对应的sqrt参数为0,那么作为乘法的一部分,整个前面的算式就为0,结果返回的恰好是原值。
我们还来在说下前面的符号问题,正或者负某个数,直接用符号位加乘法固然是可以实现的,但是有么有其他的方式更好的实现呢,翻一番SSE的手册,我们会发现有_mm_sign_epi8 、_mm_sign_epi16 、_mm_sign_epi32 这样的函数,他们是干什么的呢,我们以_mm_sign_epi16为例,看看他的文档说明:
extern __m128i _mm_sign_epi16 (__m128i a, __m128i b); Negate packed words in a if corresponding sign in b is less than zero. Interpreting a, b, and r as arrays of signed 16-bit integers: for (i = 0; i < 8; i++) { if (b[i] < 0) { r[i] = -a[i]; } else if (b[i] == 0) { r[i] = 0; } else { r[i] = a[i]; } }
什么意思,就是以参数b的符号位来决定a的值,当b为负数是,对a求反,当b为0时,a也为0,否则a值保持不变。这不就可以直接实现上述的符号位的问题了吗?
说了那么多,我贴出代码大家看一看:
int IM_UnsharpMask(unsigned char *Src, unsigned char *Dest, int Width, int Height, int Stride, int Radius, int Amount, int Threshold) { int Channel = Stride / Width; if ((Src == NULL) || (Dest == NULL)) return IM_STATUS_NULLREFRENCE; if ((Width <= 0) || (Height <= 0)) return IM_STATUS_INVALIDPARAMETER; if ((Channel != 1) && (Channel != 3) && (Channel != 4)) return IM_STATUS_INVALIDPARAMETER; int Status = IM_STATUS_OK; Status = IM_ExpBlur(Src, Dest, Width, Height, Stride, Radius); if (Status != IM_STATUS_OK) return Status; const float Adjust = Amount / 100.0f / sqrt(255.0f); const int BlockSize = 8; int Block = (Height * Stride) / BlockSize; const __m128i Zero = _mm_setzero_si128(); const __m128i ThresholdV = _mm_set1_epi16(Threshold); const __m128i MinusThresholdV = _mm_set1_epi16(-Threshold); const __m128i MinusOne = _mm_set1_epi16(-1); const __m128 AdjustV = _mm_set1_ps(Adjust); const __m128i One = _mm_set1_epi16(1); for (int Y = 0; Y < Block * BlockSize; Y += BlockSize) { __m128i SrcV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Src + Y)), Zero); __m128i DstV = _mm_unpacklo_epi8(_mm_loadl_epi64((__m128i *)(Dest + Y)), Zero); __m128i Diff = _mm_sub_epi16(SrcV, DstV); // int Diff = Src[Y] - Dest[Y]; // 当Diff > ThresholdV时,Sign设置为负数,当Diff < -ThresholdV时,Sign设置为正数, // 介于-ThresholdV和ThresholdV之间时为0,这里One和MinusOne只是取得一个代表性的值 __m128i SignA = _mm_cmpgt_epi16(Diff, ThresholdV); __m128i SignB = _mm_cmplt_epi16(Diff, MinusThresholdV); __m128i Sign = _mm_blendv_si128(Zero, MinusOne, SignA); Sign = _mm_blendv_si128(Sign, One, SignB); // Diff 为不同值时,NewDiff需要带上不同符号,利用上面的Sign配合_mm_sign_epi16能很好的解决问题 __m128i NewDiff = _mm_add_epi16(Diff, _mm_sign_epi16(ThresholdV, Sign)); // _mm_cmpgt_epi16返回0xfffff和0两种值,我们这里需要的是0xff和0,因此需要进行下移位,注意此时在Diff < Threshold(Sign为0或者1时) // _mm_add_epi16的第一个参数都是0,而第二个参数对于Sign为0的情况则也返回0,这样0+0正好为0,Sqrt后也为0,对结果正好没有影响(巧合还是天意?) __m128i NewPower = _mm_add_epi16(_mm_srli_epi16(SignA, 8), _mm_sign_epi16(SrcV, Sign)); // 注意这里有负数存在,则必须用这种强制转换函数 __m128 NewDiffL = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(NewDiff)); __m128 NewDiffH = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(_mm_srli_si128(NewDiff, 8))); // 都是正数就可以这样转化 __m128 NewPowerL = _mm_cvtepi32_ps(_mm_unpacklo_epi16(NewPower, Zero)); __m128 NewPowerH = _mm_cvtepi32_ps(_mm_unpackhi_epi16(NewPower, Zero)); // 按公式计算结果 __m128 DstL = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffL), _mm_sqrt_ps(NewPowerL)); __m128 DstH = _mm_mul_ps(_mm_mul_ps(AdjustV, NewDiffH), _mm_sqrt_ps(NewPowerH)); // 合成到16位的结果,注意这里不要用_mm_packus_epi32,因为后面还有一个加法要进行 __m128i Result = _mm_packs_epi32(_mm_cvtps_epi32(DstL), _mm_cvtps_epi32(DstH)); // 合成到8位的结果,注意这要用抗饱和的加法_mm_adds_epi16 _mm_storel_epi64((__m128i *)(Dest + Y), _mm_packus_epi16(_mm_adds_epi16(Result, SrcV), Zero)); } for (int Y = Block * BlockSize; Y < Height * Stride; Y++) { int Diff = Src[Y] - Dest[Y]; if (Diff >= Threshold) { Dest[Y] = IM_ClampToByte(int((Diff - Threshold) * Adjust * sqrtf(255.0f - Src[Y]) + 0.5f) + Src[Y]); } else if (Diff < -Threshold) { Dest[Y] = IM_ClampToByte(int((Diff + Threshold) * Adjust * sqrtf(0.0f + Src[Y]) + 0.5f) + Src[Y]); } else { Dest[Y] = IM_ClampToByte(int(Diff * Adjust * sqrtf(0.0f + 0.0f) + 0.5f) + Src[Y]); // 不做变化 } } return IM_STATUS_OK; }
最后回到我们关心的速度问题上去,经过上述优化后能达到的速度平均值在11.5ms左右,比查找表版本的还要快了3ms左右。
实际上上述求Sign的过程还有更为简单的优化过程的,想通了也很有道理,这个留个读者自行去研究,大概能加快0.4ms左右的速度。
关于分支预测的SSE优化,目前我掌握的技巧也就这么多,管件还是要看算法本身,有的时候要脱离原始算法,为了能用SSE而稍微改变下算法的外表。这就各位神仙各显神通了,当然有很多分支预测由于太复杂还是不能够用SIMD指令优化的。
最后说一句,关于Photoshop的标准USM锐化并不是使用的上述算法,其原理应该说比上面的还要简单,但也不是网络上流行的那个计算公式,我已经通过测试推到得到了和其一模一样的计算式,这里不提,不过呢,为什么非要一样呢,这里的这个算法也是不错的。
算法Demo下载地址:https://files.cnblogs.com/files/Imageshop/SSE_Optimization_Demo.rar