SegmentationShader.metal 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. //
  2. // SegmentationShader.metal
  3. // ttpic
  4. //
  5. // Created by stonefeng on 2017/7/6.
  6. // Copyright © 2017年 Tencent. All rights reserved.
  7. //
  8. #include <metal_stdlib>
  9. using namespace metal;
  // refineFilter1
  //
  // For every output pixel, gathers a 3x3 neighbourhood where each tap is
  // (inTexture.rgb, maskTexture.r), builds the regularised 3x3 colour
  // covariance matrix and the colour/mask covariance vector, and solves the
  // resulting linear system through the adjugate and determinant.  The three
  // solved coefficients go to rgb and the (halved) offset term goes to alpha;
  // everything is then remapped with x * 0.5 + 0.5 before being written.
  // NOTE(review): the maths has the shape of the coefficient step of a
  // colour guided filter - confirm against the host pipeline.
  //
  //   inTexture   texture(0)  colour input, sampled (rgb used)
  //   outTexture  texture(1)  destination; its size also defines the texel step
  //   maskTexture texture(2)  mask input, sampled (r used)
  //   gid                     pixel handled by this thread
  kernel void refineFilter1(texture2d<float, access::sample> inTexture [[ texture(0) ]],
                            texture2d<float, access::write> outTexture [[ texture(1) ]],
                            texture2d<float, access::sample> maskTexture [[ texture(2) ]],
                            uint2 gid [[ thread_position_in_grid ]])
  {
      constexpr sampler quadSampler(coord::normalized, filter::linear, address::clamp_to_edge);

      const float regularizer = 0.01f;      // added to the covariance diagonal
      const float invTaps = 1.0f / 9.0f;
      const float texelW = 1.0f / (float)outTexture.get_width();
      const float texelH = 1.0f / (float)outTexture.get_height();
      // Normalised coordinate of this pixel (gid * step, no half-texel offset).
      const float2 origin = float2((float)gid.x * texelW, (float)gid.y * texelH);

      float4 meanI = float4(0.0);    // rgb: colour mean, a: mask mean
      float3 meanIp = float3(0.0);   // mean of colour * mask
      float covRR = 0.0;
      float covRG = 0.0;
      float covRB = 0.0;
      float covGG = 0.0;
      float covGB = 0.0;
      float covBB = 0.0;

      // Row-major walk over the 3x3 window: top-left first, bottom-right last.
      for (int row = -1; row <= 1; row++) {
          for (int col = -1; col <= 1; col++) {
              const float2 uv = float2(origin.x + (float)col * texelW,
                                       origin.y + (float)row * texelH);
              const float4 tap = float4(inTexture.sample(quadSampler, uv).rgb,
                                        maskTexture.sample(quadSampler, uv).r);
              meanI += tap;
              meanIp += tap.rgb * tap.a;
              covRR += tap.r * tap.r;
              covRG += tap.r * tap.g;
              covRB += tap.r * tap.b;
              covGG += tap.g * tap.g;
              covGB += tap.g * tap.b;
              covBB += tap.b * tap.b;
          }
      }

      meanI *= invTaps;
      meanIp *= invTaps;

      // E[xy] - E[x]E[y]; only the diagonal receives the regulariser.
      covRR = covRR * invTaps - meanI.r * meanI.r + regularizer;
      covRG = covRG * invTaps - meanI.r * meanI.g;
      covRB = covRB * invTaps - meanI.r * meanI.b;
      covGG = covGG * invTaps - meanI.g * meanI.g + regularizer;
      covGB = covGB * invTaps - meanI.g * meanI.b;
      covBB = covBB * invTaps - meanI.b * meanI.b + regularizer;

      const float3 covIp = meanIp - meanI.rgb * meanI.a;

      // Cofactors of the symmetric covariance matrix (its adjugate).
      const float adjRR = covGG * covBB - covGB * covGB;
      const float adjRG = covGB * covRB - covRG * covBB;
      const float adjRB = covRG * covGB - covGG * covRB;
      const float adjGG = covRR * covBB - covRB * covRB;
      const float adjGB = covRB * covRG - covRR * covGB;
      const float adjBB = covRR * covGG - covRG * covRG;
      const float determinant = adjRR * covRR + adjRG * covRG + adjRB * covRB;

      float4 packed = float4(0.0);
      packed.r = (adjRR * covIp.r + adjRG * covIp.g + adjRB * covIp.b) / determinant;
      packed.g = (adjRG * covIp.r + adjGG * covIp.g + adjGB * covIp.b) / determinant;
      packed.b = (adjRB * covIp.r + adjGB * covIp.g + adjBB * covIp.b) / determinant;
      // Offset term: mask mean minus the coefficients applied to the colour mean.
      packed.a = (meanI.a - packed.r * meanI.r - packed.g * meanI.g - packed.b * meanI.b) * 0.5;

      outTexture.write(packed * 0.5 + float4(0.5), gid);
  }
  // refineFilter2
  //
  // 3x3 box average of inTexture, sampled with a linear sampler at
  // gid * (1 / outTexture size) and the eight surrounding one-texel offsets.
  //
  //   inTexture  texture(0)  source, sampled
  //   outTexture texture(1)  destination; its size defines the texel step
  //   gid                    pixel handled by this thread
  kernel void refineFilter2(texture2d<float> inTexture [[ texture(0) ]],
                            texture2d<float, access::write> outTexture [[ texture(1) ]],
                            uint2 gid [[ thread_position_in_grid ]])
  {
      constexpr sampler quadSampler(coord::normalized, filter::linear, address::clamp_to_edge);

      const float texelW = 1.0f / (float)outTexture.get_width();
      const float texelH = 1.0f / (float)outTexture.get_height();
      const float2 origin = float2((float)gid.x * texelW, (float)gid.y * texelH);

      // The centre tap is accumulated first, then the neighbours row by row.
      float4 total = float4(0.0);
      total += inTexture.sample(quadSampler, origin);
      for (int row = -1; row <= 1; row++) {
          for (int col = -1; col <= 1; col++) {
              if (row == 0 && col == 0) {
                  continue;   // centre already added
              }
              const float2 uv = float2(origin.x + (float)col * texelW,
                                       origin.y + (float)row * texelH);
              total += inTexture.sample(quadSampler, uv);
          }
      }

      outTexture.write(total / 9.0, gid);
  }
  // refineFilter3
  //
  // Decodes maskTexture from the x * 0.5 + 0.5 encoding back to a signed
  // range, combines its rgb with the rgb of inTexture plus twice its alpha,
  // doubles the contrast around 0.5, and snaps values below 0.05 to 0 and
  // above 0.95 to 1.  The result is written as an opaque grey value.
  //
  //   inTexture   texture(0)  sampled; rgb is multiplied with the decoded rgb
  //   outTexture  texture(1)  destination; its size defines the texel step
  //   maskTexture texture(2)  sampled; holds the encoded four-component values
  //   gid                     pixel handled by this thread
  kernel void refineFilter3(texture2d<float> inTexture [[ texture(0) ]],
                            texture2d<float, access::write> outTexture [[ texture(1) ]],
                            texture2d<float> maskTexture [[ texture(2) ]],
                            uint2 gid [[ thread_position_in_grid ]])
  {
      constexpr sampler quadSampler(coord::normalized, filter::linear, address::clamp_to_edge);

      const float texelW = 1.0f / (float)outTexture.get_width();
      const float texelH = 1.0f / (float)outTexture.get_height();
      const float2 uv = float2((float)gid.x * texelW, (float)gid.y * texelH);

      const float4 colour = inTexture.sample(quadSampler, uv);
      const float4 decoded = (maskTexture.sample(quadSampler, uv) - float4(0.5)) * 2.0;

      float alpha = decoded.r * colour.r + decoded.g * colour.g + decoded.b * colour.b + decoded.a * 2.0;
      alpha = (alpha - 0.5) * 2.0 + 0.5;

      // Snap near-black and near-white to the exact extremes.
      if (alpha < 0.05) {
          alpha = 0.0;
      } else if (alpha > 0.95) {
          alpha = 1.0;
      }

      outTexture.write(float4(alpha, alpha, alpha, 1.0), gid);
  }
  // buffer2Texture2
  //
  // Bilinearly upsamples a row-major 20 x 26 float grid stored in uData to
  // the size of outTexture.  Each of the four grid samples is first remapped
  // with (u - 0.3) * 2 + 0.5 and limited to [0, 1]; the interpolated value is
  // written as an opaque grey pixel.  Sample positions use pixel centres
  // (gid + 0.5) and grid indices are held at the grid border.
  //
  //   outTexture texture(0)  destination
  //   uData      buffer(0)   20 * 26 floats, row stride 20
  //   gid                    pixel handled by this thread
  kernel void buffer2Texture2(texture2d<float, access::write> outTexture [[ texture(0) ]],
                              constant float* uData [[buffer(0)]],
                              uint2 gid [[thread_position_in_grid ]])
  {
      const int gridCols = 20;
      const int gridRows = 26;

      const int outW = outTexture.get_width();
      const int outH = outTexture.get_height();

      // Position of this pixel centre in grid coordinates.
      const float gridX = (gid.x + 0.5) * (float)gridCols / outW - 0.5;
      const float gridY = (gid.y + 0.5) * (float)gridRows / outH - 0.5;

      int col0 = (int)floor(gridX);
      int row0 = (int)floor(gridY);
      // The "next" indices are derived before the lower clamp, so at the
      // left/top border both indices end up on the same cell.
      int col1 = col0 + 1;
      int row1 = row0 + 1;
      col0 = max(col0, 0);
      row0 = max(row0, 0);
      if (col1 == gridCols) { col1 = col0; }
      if (row1 == gridRows) { row1 = row0; }

      const float fracX = gridX - col0;
      const float fracY = gridY - row0;
      const float restX = 1.0 - fracX;
      const float restY = 1.0 - fracY;

      // Corner order: top-left, top-right, bottom-right, bottom-left.
      float corner[4];
      corner[0] = uData[col0 + row0 * gridCols];
      corner[1] = uData[col1 + row0 * gridCols];
      corner[2] = uData[col1 + row1 * gridCols];
      corner[3] = uData[col0 + row1 * gridCols];
      for (int k = 0; k < 4; k++) {
          corner[k] = max(0.0, min(1.0, (corner[k] - 0.3) * 2.0 + 0.5));
      }

      const float grey = corner[0] * restX * restY +
                         corner[1] * fracX * restY +
                         corner[2] * fracX * fracY +
                         corner[3] * restX * fracY;

      outTexture.write(float4(grey, grey, grey, 1.0), gid);
  }
  // kernel_Float32toBGRA2
  //
  // uData holds two consecutive row-major planes of dimensions[0] x
  // dimensions[1] floats.  Both planes are bilinearly interpolated at the
  // position of this output pixel, combined as exp(d) / (exp(d) + 1) with
  // d = plane1 - plane0, stretched by 1.5 around 0.5, saturated, and snapped
  // to 0 below 0.05 and to 1 above 0.95.  The value is written as opaque grey.
  //
  //   outTexture texture(0)  destination
  //   uData      buffer(0)   2 * w * h floats (plane 0 followed by plane 1)
  //   dimensions buffer(1)   {w, h} of one plane
  //   gid                    pixel handled by this thread
  kernel void kernel_Float32toBGRA2(texture2d<float, access::write> outTexture [[ texture(0) ]],
                                    constant float* uData [[buffer(0)]],
                                    constant int* dimensions [[buffer(1)]],
                                    uint2 gid [[thread_position_in_grid ]])
  {
      const float outW = (float)outTexture.get_width();
      const float outH = (float)outTexture.get_height();
      const int gridW = dimensions[0];
      const int gridH = dimensions[1];

      const float srcX = (float)gid.x * (float)gridW / outW;
      const float srcY = (float)gid.y * (float)gridH / outH;

      int x0 = (int)floor(srcX);
      int y0 = (int)floor(srcY);
      int x1 = x0 + 1;
      int y1 = y0 + 1;
      x0 = max(x0, 0);
      y0 = max(y0, 0);
      // Hold the far neighbour on the last column / row.
      if (x1 == gridW) { x1 = x0; }
      if (y1 == gridH) { y1 = y0; }

      const float fracX = srcX - x0;
      const float fracY = srcY - y0;

      const int topLeft = x0 + y0 * gridW;
      const int topRight = x1 + y0 * gridW;
      const int bottomRight = x1 + y1 * gridW;
      const int bottomLeft = x0 + y1 * gridW;
      const int planeSize = gridW * gridH;

      // Same bilinear blend for plane 0 and plane 1.
      float blended[2];
      for (int plane = 0; plane < 2; plane++) {
          const int base = plane * planeSize;
          const float tl = uData[base + topLeft];
          const float tr = uData[base + topRight];
          const float br = uData[base + bottomRight];
          const float bl = uData[base + bottomLeft];
          blended[plane] = (tl * (1.0 - fracX) * (1.0 - fracY) +
                            tr * fracX * (1.0 - fracY) +
                            br * fracX * fracY +
                            bl * (1.0 - fracX) * fracY);
      }

      const float ratio = exp(blended[1] - blended[0]);
      const float probability = ratio / (ratio + 1.0f);

      float grey = saturate((probability - 0.5f) * 1.5f + 0.5f);
      if (grey < 0.05) {
          grey = 0;
      } else if (grey > 0.95) {
          grey = 1.0f;
      }

      outTexture.write(float4(grey, grey, grey, 1.0), gid);
  }
  // kernel_Float32toBGRA3
  //
  // Like kernel_Float32toBGRA2, but gated by a per-cell weight map: when all
  // four grid cells surrounding this pixel have a weight below 0.5 the pixel
  // is written as opaque black.  Otherwise the two planes in uData are
  // bilinearly interpolated and combined as exp(d) / (exp(d) + 1) with
  // d = plane1 - plane0, and that value is written as opaque grey (no
  // contrast stretch or snapping in this variant).
  //
  //   outTexture texture(0)  destination
  //   uData      buffer(0)   2 * w * h floats (plane 0 followed by plane 1)
  //   weight     buffer(1)   w * h floats, indexed like one plane of uData
  //   dimensions buffer(2)   {w, h} of one plane
  //   gid                    pixel handled by this thread
  kernel void kernel_Float32toBGRA3(texture2d<float, access::write> outTexture [[ texture(0) ]],
                                    constant float* uData [[buffer(0)]],
                                    constant float* weight [[buffer(1)]],
                                    constant int* dimensions [[buffer(2)]],
                                    uint2 gid [[thread_position_in_grid ]])
  {
      float width = (float)outTexture.get_width();
      float height = (float)outTexture.get_height();
      int uniform_w = dimensions[0];
      int uniform_h = dimensions[1];

      // Position of this pixel in grid coordinates.
      float posx = (float)gid.x * (float)uniform_w / width;
      float posy = (float)gid.y * (float)uniform_h / height;

      int dx = floor(posx);
      int dy = floor(posy);
      int dx2 = dx + 1;
      int dy2 = dy + 1;
      if (dx < 0) dx = 0;
      if (dy < 0) dy = 0;
      // Hold the far neighbour on the last column / row.
      if (dx2 == uniform_w) dx2 = dx;
      if (dy2 == uniform_h) dy2 = dy;

      float ratioX = posx - dx;
      float ratioY = posy - dy;

      // Corner order: top-left, top-right, bottom-right, bottom-left.
      int idx1 = dx + dy * uniform_w;
      int idx2 = dx2 + dy * uniform_w;
      int idx3 = dx2 + dy2 * uniform_w;
      int idx4 = dx + dy2 * uniform_w;

      // All four corners must be low-weight to blank the pixel.  (The fourth
      // test previously repeated idx3, so the bottom-left corner was ignored.)
      if (weight[idx1] < 0.5 && weight[idx2] < 0.5 && weight[idx3] < 0.5 && weight[idx4] < 0.5) {
          outTexture.write(float4(0, 0, 0, 1.0), gid);
      }
      else {
          float u1 = uData[idx1];
          float u2 = uData[idx2];
          float u3 = uData[idx3];
          float u4 = uData[idx4];
          float r0 = (u1 * (1.0 - ratioX) * (1.0 - ratioY) +
                      u2 * ratioX * (1.0 - ratioY) +
                      u3 * ratioX * ratioY +
                      u4 * (1.0 - ratioX) * ratioY);

          // Second plane starts one full grid after the first.
          int offset = uniform_w * uniform_h;
          u1 = uData[idx1 + offset];
          u2 = uData[idx2 + offset];
          u3 = uData[idx3 + offset];
          u4 = uData[idx4 + offset];
          float r1 = (u1 * (1.0 - ratioX) * (1.0 - ratioY) +
                      u2 * ratioX * (1.0 - ratioY) +
                      u3 * ratioX * ratioY +
                      u4 * (1.0 - ratioX) * ratioY);

          float diff = exp(r1 - r0);
          diff = diff / (diff + 1.0f);
          float r = diff;
          outTexture.write(float4(r, r, r, 1.0), gid);
      }
  }
  // kernel_smallmap
  //
  // For one grid cell, combines the two planes of uData as
  // exp(d) / (exp(d) + 1) with d = plane1 - plane0 and stores 1.0 in uData1
  // when that value exceeds 0.05, otherwise 0.0.
  //
  //   outTexture texture(0)  bound but not used by this kernel
  //   uData      buffer(0)   2 * w * h floats (plane 0 followed by plane 1)
  //   dimensions buffer(1)   {w, h} of one plane
  //   uData1     buffer(2)   w * h floats, receives the binary result
  //   gid                    grid cell handled by this thread
  kernel void kernel_smallmap(texture2d<float, access::read> outTexture [[ texture(0) ]],
                              constant float* uData [[buffer(0)]],
                              constant int* dimensions [[buffer(1)]],
                              device float* uData1 [[buffer(2)]],
                              uint2 gid [[thread_position_in_grid ]])
  {
      const int gridW = dimensions[0];
      const int gridH = dimensions[1];
      const int planeSize = gridW * gridH;
      const int cell = gid.y * gridW + gid.x;

      const float first = uData[cell];
      const float second = uData[cell + planeSize];

      const float ratio = exp(second - first);
      const float probability = ratio / (ratio + 1.0f);

      float flag = 0.0f;
      if (probability > 0.05) {
          flag = 1.0f;
      }
      uData1[cell] = flag;
  }
  // kernel_refineMask
  //
  // 3x3 median filter over two consecutive width x height float planes in
  // srcData, using a min/max exchange network; neighbour coordinates are
  // clamped to the plane.  The median of each plane is written to the same
  // plane and position in dstData.
  //
  //   outTexture texture(0)  only its width/height are read (plane size)
  //   srcData    buffer(0)   2 * width * height floats
  //   dstData    buffer(1)   2 * width * height floats, receives the medians
  //   gid                    cell handled by this thread
  //
  // Fix: the result used to be stored at dstData[index + offseti] although
  // index already contains offseti, so plane 1 landed at
  // 2 * width * height + position instead of mirroring the source layout.
  kernel void kernel_refineMask(texture2d<float, access::read> outTexture [[ texture(0) ]],
                                constant float* srcData [[buffer(0)]],
                                device float* dstData [[buffer(1)]],
                                uint2 gid [[thread_position_in_grid ]])
  {
      int width = outTexture.get_width();
      int height = outTexture.get_height();

      // Start offset of each plane inside the buffers.
      int offset[2];
      offset[0] = 0;
      offset[1] = width * height;

      int gx = gid.x;
      int gy = gid.y;

      // Clamped column / row indices: [0] previous, [1] current, [2] next.
      int x[3];
      int y[3];
      x[0] = max(gx - 1, 0);
      x[1] = gid.x;
      x[2] = min(gx + 1, width - 1);
      y[0] = max(gy - 1, 0);
      y[1] = gid.y;
      y[2] = min(gy + 1, height - 1);

  // s2 orders a pair (a = min, b = max); the mnmxN macros move the minimum
  // of their arguments into the first one and the maximum into the last one.
  #define s2(a, b) temp = a; a = min(a, b); b = max(temp, b);
  #define mn3(a, b, c) s2(a, b); s2(a, c);
  #define mx3(a, b, c) s2(b, c); s2(a, c);
  #define mnmx3(a, b, c) mx3(a, b, c); s2(a, b); // 3 exchanges
  #define mnmx4(a, b, c, d) s2(a, b); s2(c, d); s2(a, c); s2(b, d); // 4 exchanges
  #define mnmx5(a, b, c, d, e) s2(a, b); s2(c, d); mn3(a, c, e); mx3(b, d, e); // 6 exchanges
  #define mnmx6(a, b, c, d, e, f) s2(a, d); s2(b, e); s2(c, f); mn3(a, b, c); mx3(d, e, f); // 7 exchanges

      for (int i = 0; i < 2; i++) {
          float temp, v[6];
          int offseti = offset[i];

          // Load six of the nine neighbours and discard their min and max.
          uint index = width * y[2] + x[0] + offseti;
          v[0] = srcData[index];
          index = width * y[0] + x[2] + offseti;
          v[1] = srcData[index];
          index = width * y[0] + x[0] + offseti;
          v[2] = srcData[index];
          index = width * y[2] + x[2] + offseti;
          v[3] = srcData[index];
          index = width * y[1] + x[0] + offseti;
          v[4] = srcData[index];
          index = width * y[1] + x[2] + offseti;
          v[5] = srcData[index];
          mnmx6(v[0], v[1], v[2], v[3], v[4], v[5]);

          // Feed in the remaining three values one at a time, shrinking the
          // candidate set by its extremes after each.
          index = width * y[2] + x[1] + offseti;
          v[5] = srcData[index];
          mnmx5(v[1], v[2], v[3], v[4], v[5]);
          index = width * y[0] + x[1] + offseti;
          v[5] = srcData[index];
          mnmx4(v[2], v[3], v[4], v[5]);
          index = width * y[1] + x[1] + offseti;
          v[5] = srcData[index];
          mnmx3(v[3], v[4], v[5]);

          // index now addresses the centre cell of plane i (plane offset
          // already included); v[4] is the median of the nine values.
          dstData[index] = v[4];
      }
  }
  // erodeFilter
  //
  // 3x3 erosion: writes the component-wise minimum of the pixel and its eight
  // neighbours.  Neighbour coordinates are clamped to the texture, so border
  // pixels take the minimum over the neighbours that actually exist.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   gid                    pixel handled by this thread
  //
  // Fix: the previous version formed gid - 1 in unsigned arithmetic (wrapping
  // to a huge coordinate at x == 0 / y == 0) and read gid + 1 past the far
  // edge, so border results depended on out-of-bounds texture reads.
  kernel void erodeFilter(texture2d<half, access::read> inTexture [[ texture(0) ]],
                          texture2d<half, access::write> outTexture [[ texture(1) ]],
                          uint2 gid [[ thread_position_in_grid ]])
  {
      const int lastX = (int)inTexture.get_width() - 1;
      const int lastY = (int)inTexture.get_height() - 1;
      const int centreX = (int)gid.x;
      const int centreY = (int)gid.y;

      half4 inColor = inTexture.read(gid);
      for (int offsetY = -1; offsetY <= 1; offsetY++) {
          for (int offsetX = -1; offsetX <= 1; offsetX++) {
              // Signed arithmetic plus clamp keeps every tap inside the texture.
              const uint2 position = uint2(clamp(centreX + offsetX, 0, lastX),
                                           clamp(centreY + offsetY, 0, lastY));
              inColor = min(inTexture.read(position), inColor);
          }
      }

      outTexture.write(inColor, gid);
  }
  // dilateFilter
  //
  // Writes the component-wise maximum of the pixel and four taps along the
  // main diagonal: (+1,+1), (+2,+2), (-1,-1) and (-2,-2).  The step is fixed
  // at 1 in both axes (a uniform-driven step is not wired up).
  // NOTE(review): tap coordinates are not clamped; near the borders they fall
  // outside the texture - confirm that is acceptable for the caller.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   gid                    pixel handled by this thread
  kernel void dilateFilter(texture2d<half, access::read> inTexture [[ texture(0) ]],
                           texture2d<half, access::write> outTexture [[ texture(1) ]],
                           uint2 gid [[ thread_position_in_grid ]])
  {
      const int stepX = 1;
      const int stepY = 1;

      half4 peak = inTexture.read(gid);

      // Forward direction first (1 then 2 steps), then the backward direction.
      for (int direction = 1; direction >= -1; direction -= 2) {
          for (int count = 1; count <= 2; count++) {
              const int shiftX = direction * count * stepX;
              const int shiftY = direction * count * stepY;
              const uint2 tap = uint2(gid.x + shiftX, gid.y + shiftY);
              peak = max(peak, inTexture.read(tap));
          }
      }

      outTexture.write(peak, gid);
  }
  // fixNormalFilter
  //
  // Per-pixel contrast stretch: 2 * (c - 0.3) + 0.5 on every component,
  // limited to [0, 1].  The rgb result is written with alpha forced to 1.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   gid                    pixel handled by this thread
  kernel void fixNormalFilter(texture2d<half, access::read> inTexture [[ texture(0) ]],
                              texture2d<half, access::write> outTexture [[ texture(1) ]],
                              uint2 gid [[ thread_position_in_grid ]])
  {
      const half4 texel = inTexture.read(gid);
      const half4 stretched = 2.0 * (texel - half4(0.3)) + half4(0.5);
      const half4 limited = max(half4(0.0), min(half4(1.0), stretched));
      outTexture.write(half4(limited.rgb, 1.0), gid);
  }
  // kernel_diff
  //
  // Reads the red channel of two textures at gid and writes
  // exp(d) / (exp(d) + 1), with d = inTexture1.r - inTexture0.r, as an
  // opaque grey pixel.  No thresholding or contrast stretch is applied.
  //
  //   outTexture texture(0)  destination
  //   inTexture0 texture(1)  first input (subtracted)
  //   inTexture1 texture(2)  second input
  //   gid                    pixel handled by this thread
  kernel void kernel_diff(texture2d<float, access::write> outTexture [[ texture(0) ]],
                          texture2d<float, access::read> inTexture0 [[ texture(1) ]],
                          texture2d<float, access::read> inTexture1 [[ texture(2) ]],
                          uint2 gid [[thread_position_in_grid ]])
  {
      const float first = inTexture0.read(gid).r;
      const float second = inTexture1.read(gid).r;

      const float ratio = exp(second - first);
      const float grey = ratio / (ratio + 1.0f);

      outTexture.write(float4(grey, grey, grey, 1.0), gid);
  }
  // Blend used by kernel_diff2: interpolates from `previous` to `current` by
  // `change`, then averages that result with `current`.
  static float kernel_diff2_blend(float previous, float current, float change)
  {
      const float tracked = mix(previous, current, change);
      return mix(current, tracked, 0.5f);
  }

  // kernel_diff2
  //
  // Measures how much the colour at gid changed between preTexture and
  // curTexture (sum of absolute rgb differences, scaled by 1.7 and capped at
  // 1) and uses that amount to blend each previous value (preTexture0/1)
  // with the matching new value (cnnTexture0/1).  The two blended values are
  // written as opaque grey pixels to dstTexture0 and dstTexture1.
  //
  //   preTexture  texture(0)  earlier colour image
  //   curTexture  texture(1)  current colour image
  //   preTexture0 texture(2)  earlier value, channel 0 (r read)
  //   preTexture1 texture(3)  earlier value, channel 1 (r read)
  //   cnnTexture0 texture(4)  new value, channel 0 (r read)
  //   cnnTexture1 texture(5)  new value, channel 1 (r read)
  //   dstTexture0 texture(6)  blended output, channel 0
  //   dstTexture1 texture(7)  blended output, channel 1
  //   gid                     pixel handled by this thread
  kernel void kernel_diff2(texture2d<float, access::read> preTexture [[ texture(0) ]],
                           texture2d<float, access::read> curTexture [[ texture(1) ]],
                           texture2d<float, access::read> preTexture0 [[ texture(2) ]],
                           texture2d<float, access::read> preTexture1 [[ texture(3) ]],
                           texture2d<float, access::read> cnnTexture0 [[ texture(4) ]],
                           texture2d<float, access::read> cnnTexture1 [[ texture(5) ]],
                           texture2d<float, access::write> dstTexture0 [[ texture(6) ]],
                           texture2d<float, access::write> dstTexture1 [[ texture(7) ]],
                           uint2 gid [[thread_position_in_grid ]])
  {
      const float4 nowColour = curTexture.read(gid);
      const float4 thenColour = preTexture.read(gid);

      float change = fabs(nowColour.r - thenColour.r) + fabs(nowColour.g - thenColour.g) + fabs(nowColour.b - thenColour.b);
      change = min(1.0f, change * 1.7f);

      const float out0 = kernel_diff2_blend(preTexture0.read(gid).r, cnnTexture0.read(gid).r, change);
      dstTexture0.write(float4(out0, out0, out0, 1.0f), gid);

      const float out1 = kernel_diff2_blend(preTexture1.read(gid).r, cnnTexture1.read(gid).r, change);
      dstTexture1.write(float4(out1, out1, out1, 1.0f), gid);
  }
  // kernel_resize
  //
  // Bilinear resize of the red channel of inTexture to the size of
  // outTexture; the result is written as an opaque grey pixel.  Source
  // coordinates are gid * in / out (top-left aligned).
  //
  //   outTexture texture(0)  destination
  //   inTexture  texture(1)  source (r read)
  //   gid                    pixel handled by this thread
  //
  // Fix: the right/bottom neighbours (dx + 1, dy + 1) are now clamped to the
  // last column/row.  Previously they were read past the edge of inTexture
  // whenever the base texel was on the last column/row, so the border of the
  // output was blended with out-of-bounds reads.
  kernel void kernel_resize(texture2d<float, access::write> outTexture [[ texture(0) ]],
                            texture2d<float, access::read> inTexture [[ texture(1) ]],
                            uint2 gid [[thread_position_in_grid ]])
  {
      uint out_w = outTexture.get_width();
      uint out_h = outTexture.get_height();
      uint in_w = inTexture.get_width();
      uint in_h = inTexture.get_height();

      float posx = (float)(gid.x * in_w) / (float)out_w;
      float posy = (float)(gid.y * in_h) / (float)out_h;

      int dx = floor(posx);
      int dy = floor(posy);

      // Fractions come from the unclamped base texel so they stay in [0, 1).
      float ratioX = posx - dx;
      float ratioY = posy - dy;

      // Keep all four taps inside the source texture.
      int lastX = (int)in_w - 1;
      int lastY = (int)in_h - 1;
      int x0 = min(dx, lastX);
      int y0 = min(dy, lastY);
      int x1 = min(dx + 1, lastX);
      int y1 = min(dy + 1, lastY);

      float u1 = inTexture.read(uint2(x0, y0)).r;
      float u2 = inTexture.read(uint2(x1, y0)).r;
      float u3 = inTexture.read(uint2(x1, y1)).r;
      float u4 = inTexture.read(uint2(x0, y1)).r;

      float r = (u1 * (1.0 - ratioX) * (1.0 - ratioY) +
                 u2 * ratioX * (1.0 - ratioY) +
                 u3 * ratioX * ratioY +
                 u4 * (1.0 - ratioX) * ratioY);

      outTexture.write(float4(r, r, r, 1.0), gid);
  }
  // Interpolated per-fragment input of texturedQuadFragmentMaskBgFg2.
  struct VertexIO
  {
      // Position, tagged with the [[position]] attribute.
      float4 m_Position [[position]];

      // Texture coordinate, passed through the user(texturecoord) slot.
      float2 m_TexCoord [[user(texturecoord)]];
  };
  // texturedQuadFragmentMaskBgFg2
  //
  // Composites the sampled image over a fixed background colour using the
  // red channel of the mask.  The mask is rescaled with m * 1.4 - 0.15 and
  // saturated; values below 0.3 become 0 and values above 0.5 become 1, so
  // only the 0.3 .. 0.5 band blends softly.
  //
  //   inFrag  stage_in    interpolated vertex output (texture coordinate used)
  //   tex2D   texture(0)  foreground image
  //   mask2D  texture(1)  mask (r read)
  //   returns             mix(background, foreground, alpha)
  fragment half4 texturedQuadFragmentMaskBgFg2(VertexIO inFrag [[ stage_in ]],
                                               texture2d<half> tex2D [[ texture(0) ]],
                                               texture2d<half> mask2D [[ texture(1) ]])
  {
      constexpr sampler quadSampler(coord::normalized, filter::linear, address::clamp_to_edge);

      // Background colour, equal to (20, 40, 80) / 255.
      const half4 background = half4(0.078431373,0.15686275,0.31372549,1.0);

      const half4 foreground = tex2D.sample(quadSampler, inFrag.m_TexCoord);
      const half4 maskTexel = mask2D.sample(quadSampler, inFrag.m_TexCoord);

      half coverage = saturate(maskTexel.r * 1.4f - 0.15f);
      if (coverage < 0.3) {
          coverage = 0.0f;
      } else if (coverage > 0.5) {
          coverage = 1.0f;
      }

      return mix(background, foreground, coverage);
  }
  // kernel_box
  //
  // 2D box filter: sums every texel in a (2 * rx + 1) x (2 * ry + 1) window
  // centred on gid and divides by the window area.
  // NOTE(review): window coordinates are not clamped, and the divisor is
  // always the full window area - confirm the intended border behaviour.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   dimensions buffer(1)   {rx, ry} window radii
  //   gid                    pixel handled by this thread
  kernel void kernel_box(texture2d<float, access::read> inTexture [[ texture(0) ]],
                         texture2d<float, access::write> outTexture [[ texture(1) ]],
                         constant int* dimensions [[ buffer(1) ]],
                         uint2 gid [[ thread_position_in_grid ]])
  {
      const int radiusX = dimensions[0];
      const int radiusY = dimensions[1];
      const int tapsX = 2 * radiusX + 1;
      const int tapsY = 2 * radiusY + 1;

      float4 total = float4(0.0f);
      // Column-major walk: x offset in the outer loop, y offset in the inner.
      for (int shiftX = -radiusX; shiftX <= radiusX; shiftX++) {
          for (int shiftY = -radiusY; shiftY <= radiusY; shiftY++) {
              total += inTexture.read(uint2(gid.x + shiftX, gid.y + shiftY));
          }
      }

      outTexture.write(total / (float)(tapsX * tapsY), gid);
  }
  // kernel_box_horizon
  //
  // Horizontal pass of a box filter: averages the 2 * rx + 1 texels on the
  // row of gid, centred on gid.x.
  // NOTE(review): x coordinates are not clamped at the texture border.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   dimensions buffer(1)   dimensions[0] is the horizontal radius rx
  //   gid                    pixel handled by this thread
  kernel void kernel_box_horizon(texture2d<float, access::read> inTexture [[ texture(0) ]],
                                 texture2d<float, access::write> outTexture [[ texture(1) ]],
                                 constant int* dimensions [[ buffer(1) ]],
                                 uint2 gid [[ thread_position_in_grid ]])
  {
      const int radius = dimensions[0];
      const int taps = 2 * radius + 1;

      float4 total = float4(0.0f);
      for (int shift = -radius; shift <= radius; shift++) {
          total += inTexture.read(uint2(gid.x + shift, gid.y));
      }

      outTexture.write(total / (float)taps, gid);
  }
  // kernel_box_vertical
  //
  // Vertical pass of a box filter: averages the 2 * ry + 1 texels on the
  // column of gid, centred on gid.y.
  // NOTE(review): y coordinates are not clamped at the texture border.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   dimensions buffer(1)   dimensions[1] is the vertical radius ry
  //   gid                    pixel handled by this thread
  kernel void kernel_box_vertical(texture2d<float, access::read> inTexture [[ texture(0) ]],
                                  texture2d<float, access::write> outTexture [[ texture(1) ]],
                                  constant int* dimensions [[ buffer(1) ]],
                                  uint2 gid [[ thread_position_in_grid ]])
  {
      const int radius = dimensions[1];
      const int taps = 2 * radius + 1;

      float4 total = float4(0.0f);
      for (int shift = -radius; shift <= radius; shift++) {
          total += inTexture.read(uint2(gid.x, gid.y + shift));
      }

      outTexture.write(total / (float)(taps), gid);
  }
  // kernel_dilate_horizon
  //
  // Horizontal dilation: component-wise maximum of the 2 * rx + 1 texels on
  // the row of gid, centred on gid.x.  The running maximum starts at 0, so
  // the result is never negative.
  // NOTE(review): x coordinates are not clamped at the texture border.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   dimensions buffer(1)   dimensions[0] is the horizontal radius rx
  //   gid                    pixel handled by this thread
  kernel void kernel_dilate_horizon(texture2d<float, access::read> inTexture [[ texture(0) ]],
                                    texture2d<float, access::write> outTexture [[ texture(1) ]],
                                    constant int* dimensions [[ buffer(1) ]],
                                    uint2 gid [[ thread_position_in_grid ]])
  {
      const int radius = dimensions[0];

      float4 peak = float4(0.0f);
      for (int shift = -radius; shift <= radius; shift++) {
          const float4 texel = inTexture.read(uint2(gid.x + shift, gid.y));
          peak = max(texel, peak);
      }

      outTexture.write(peak, gid);
  }
  // kernel_dilate_vertical
  //
  // Vertical dilation: component-wise maximum of the 2 * ry + 1 texels on
  // the column of gid, centred on gid.y.  The running maximum starts at 0,
  // so the result is never negative.
  // NOTE(review): y coordinates are not clamped at the texture border.
  //
  //   inTexture  texture(0)  source
  //   outTexture texture(1)  destination
  //   dimensions buffer(1)   dimensions[1] is the vertical radius ry
  //   gid                    pixel handled by this thread
  kernel void kernel_dilate_vertical(texture2d<float, access::read> inTexture [[ texture(0) ]],
                                     texture2d<float, access::write> outTexture [[ texture(1) ]],
                                     constant int* dimensions [[ buffer(1) ]],
                                     uint2 gid [[ thread_position_in_grid ]])
  {
      const int radius = dimensions[1];

      float4 peak = float4(0.0f);
      for (int shift = -radius; shift <= radius; shift++) {
          const float4 texel = inTexture.read(uint2(gid.x, gid.y + shift));
          peak = max(texel, peak);
      }

      outTexture.write(peak, gid);
  }