From 8f61d4959a28f3950be9998b667a09f6a8213010 Mon Sep 17 00:00:00 2001
From: Kyle Swanson
Date: Wed, 30 Jul 2025 16:26:26 -0700
Subject: [PATCH 1/2] feature/vif: port all speed related changes

Co-authored-by: Nil Fons Miret
---
 libvmaf/src/feature/common/convolution.h     |    5 +
 libvmaf/src/feature/common/convolution_avx.c | 2562 ++----------------
 libvmaf/src/feature/float_adm.c              |    4 +-
 libvmaf/src/feature/float_ansnr.c            |    4 +-
 libvmaf/src/feature/float_moment.c           |    4 +-
 libvmaf/src/feature/float_motion.c           |    2 +-
 libvmaf/src/feature/float_ms_ssim.c          |    4 +-
 libvmaf/src/feature/float_psnr.c             |    4 +-
 libvmaf/src/feature/float_ssim.c             |    4 +-
 libvmaf/src/feature/float_vif.c              |    4 +-
 libvmaf/src/feature/picture_copy.c           |   32 +-
 libvmaf/src/feature/picture_copy.h           |    2 +-
 libvmaf/src/feature/vif.c                    |  157 +-
 libvmaf/src/feature/vif_tools.c              |  735 +++--
 libvmaf/src/feature/vif_tools.h              |   63 +-
 15 files changed, 790 insertions(+), 2796 deletions(-)

diff --git a/libvmaf/src/feature/common/convolution.h b/libvmaf/src/feature/common/convolution.h
index 0b030916f..de6b5ab61 100644
--- a/libvmaf/src/feature/common/convolution.h
+++ b/libvmaf/src/feature/common/convolution.h
@@ -19,6 +19,11 @@
 #ifndef CONVOLUTION_H_
 #define CONVOLUTION_H_
 
+/*
+ * Filter widths larger than this value will not use the AVX path for convolutions.
+ */
+#define MAX_FWIDTH_AVX_CONV 17
+
 /*
  * All functions listed here expect a SYMMETRICAL filter.
  * All array arguments must be 32-byte aligned.
diff --git a/libvmaf/src/feature/common/convolution_avx.c b/libvmaf/src/feature/common/convolution_avx.c
index c6bfb7096..2e0d8cf86 100644
--- a/libvmaf/src/feature/common/convolution_avx.c
+++ b/libvmaf/src/feature/common/convolution_avx.c
@@ -21,2414 +21,228 @@
 #include "convolution.h"
 #include "convolution_internal.h"
 
-void convolution_f32_avx_s_1d_h_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_h_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_h_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_v_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
-void convolution_f32_avx_s_1d_v_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
-void convolution_f32_avx_s_1d_v_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
+#define AVX_STEP (8)
 
-void convolution_f32_avx_s_1d_h_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_h_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_h_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end);
-void convolution_f32_avx_s_1d_v_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end);
-void convolution_f32_avx_s_1d_v_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * 
RESTRICT dst, int src_stride, int j_end); -void convolution_f32_avx_s_1d_v_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end); -void convolution_f32_avx_s_1d_h_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end); -void convolution_f32_avx_s_1d_h_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end); -void convolution_f32_avx_s_1d_h_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end); -void convolution_f32_avx_s_1d_v_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end); -void convolution_f32_avx_s_1d_v_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end); -void convolution_f32_avx_s_1d_v_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end); +void convolution_f32_avx_s_1d_h_scanline(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) { + int radius = filter_width / 2; -// Filter a single scanline. -static void convolution_f32_avx_s_1d_h_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - - if (N == 5) - { - convolution_f32_avx_s_1d_h_scanline_5(filter, filter_width, src, dst, j_end); - } - else if (N == 9) - { - convolution_f32_avx_s_1d_h_scanline_9(filter, filter_width, src, dst, j_end); - } - else if (N == 17) - { - convolution_f32_avx_s_1d_h_scanline_17(filter, filter_width, src, dst, j_end); - } - else { - - int radius = filter_width / 2; - - for (int x = 0; x < filter_width; x += 9) { - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_setzero_ps(); - f1 = _mm256_setzero_ps(); - f2 = _mm256_setzero_ps(); - f3 = _mm256_setzero_ps(); - f5 = _mm256_setzero_ps(); - f6 = _mm256_setzero_ps(); - f7 = _mm256_setzero_ps(); - f8 = _mm256_setzero_ps(); - - switch (filter_width - x) { - default: - f8 = _mm256_broadcast_ss(filter + x + 8); - // fall through - case 8: - f7 = _mm256_broadcast_ss(filter + x + 7); - // fall through - case 7: - f6 = _mm256_broadcast_ss(filter + x + 6); - // fall through - case 6: - f5 = _mm256_broadcast_ss(filter + x + 5); - // fall through - case 5: - f4 = _mm256_broadcast_ss(filter + x + 4); - // fall through - case 4: - f3 = _mm256_broadcast_ss(filter + x + 3); - // fall through - case 3: - f2 = _mm256_broadcast_ss(filter + x + 2); - // fall through - case 2: - f1 = _mm256_broadcast_ss(filter + x + 1); - // fall through - case 1: - f0 = _mm256_broadcast_ss(filter + x + 0); - // fall through - } - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - sum0 = _mm256_setzero_ps(); - sum1 = _mm256_setzero_ps(); - sum2 = _mm256_setzero_ps(); - sum3 = _mm256_setzero_ps(); - - switch (filter_width - x) { - default: - g = _mm256_loadu_ps(src + j + x + 8); - sum0 = 
_mm256_mul_ps(f8, g); - // fall through - case 8: - g = _mm256_loadu_ps(src + j + x + 7); - sum3 = _mm256_mul_ps(f7, g); - // fall through - case 7: - g = _mm256_loadu_ps(src + j + x + 6); - sum2 = _mm256_mul_ps(f6, g); - // fall through - case 6: - g = _mm256_loadu_ps(src + j + x + 5); - sum1 = _mm256_mul_ps(f5, g); - // fall through - case 5: - g = _mm256_loadu_ps(src + j + x + 4); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - case 4: - g = _mm256_loadu_ps(src + j + x + 3); - g = _mm256_mul_ps(f3, g); - sum3 = _mm256_add_ps(sum3, g); - // fall through - case 3: - g = _mm256_loadu_ps(src + j + x + 2); - g = _mm256_mul_ps(f2, g); - sum2 = _mm256_add_ps(sum2, g); - // fall through - case 2: - g = _mm256_loadu_ps(src + j + x + 1); - g = _mm256_mul_ps(f1, g); - sum1 = _mm256_add_ps(sum1, g); - // fall through - case 1: - g = _mm256_loadu_ps(src + j + x + 0); - g = _mm256_mul_ps(f0, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - } - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - if (x) - accum = _mm256_add_ps(accum, _mm256_loadu_ps(dst + j + radius)); - - _mm256_storeu_ps(dst + j + radius, accum); - } - } - - } -} - -void convolution_f32_avx_s_1d_h_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 5); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 6); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 7); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src + j + 8); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_store_ps(dst + j + 8, accum); // radius = 8 - } - - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = _mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - float *dst_ptr = dst + j + 8; // radius = 8 - - g = 
_mm256_loadu_ps(src + j + 9); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 10); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 11); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 12); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 13); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 14); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 15); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 16); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0); - _mm256_store_ps(dst_ptr, sum0); - } -} - -void convolution_f32_avx_s_1d_h_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 5); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 6); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 7); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src + j + 8); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 4, accum); // radius = 4 - } -} - -void convolution_f32_avx_s_1d_h_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = 
_mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 2, accum); // radius = 2 - } -} - -// Filter a single scanline. -static void convolution_f32_avx_s_1d_v_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - - if (N == 5) - { - convolution_f32_avx_s_1d_v_scanline_5(filter, filter_width, src, dst, src_stride, j_end); - } - else if (N == 9) - { - convolution_f32_avx_s_1d_v_scanline_9(filter, filter_width, src, dst, src_stride, j_end); - } - else if (N == 17) - { - convolution_f32_avx_s_1d_v_scanline_17(filter, filter_width, src, dst, src_stride, j_end); - } - else { - - int radius = filter_width / 2; - src -= radius * src_stride; - - for (int y = 0; y < filter_width; y += 9) { - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_setzero_ps(); - f1 = _mm256_setzero_ps(); - f2 = _mm256_setzero_ps(); - f3 = _mm256_setzero_ps(); - f5 = _mm256_setzero_ps(); - f6 = _mm256_setzero_ps(); - f7 = _mm256_setzero_ps(); - f8 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - f8 = _mm256_broadcast_ss(filter + y + 8); - // fall through - case 8: - f7 = _mm256_broadcast_ss(filter + y + 7); - // fall through - case 7: - f6 = _mm256_broadcast_ss(filter + y + 6); - // fall through - case 6: - f5 = _mm256_broadcast_ss(filter + y + 5); - // fall through - case 5: - f4 = _mm256_broadcast_ss(filter + y + 4); - // fall through - case 4: - f3 = _mm256_broadcast_ss(filter + y + 3); - // fall through - case 3: - f2 = _mm256_broadcast_ss(filter + y + 2); - // fall through - case 2: - f1 = _mm256_broadcast_ss(filter + y + 1); - // fall through - case 1: - f0 = _mm256_broadcast_ss(filter + y + 0); - // fall through - } - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - sum0 = _mm256_setzero_ps(); - sum1 = _mm256_setzero_ps(); - sum2 = _mm256_setzero_ps(); - sum3 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - g = _mm256_load_ps(src + (y + 8) * src_stride + j); - sum0 = _mm256_mul_ps(f8, g); - // fall through - case 8: - g = _mm256_load_ps(src + (y + 7) * src_stride + j); - sum3 = _mm256_mul_ps(f7, g); - // fall through - case 7: - g = _mm256_load_ps(src + (y + 6) * src_stride + j); - sum2 = _mm256_mul_ps(f6, g); - // fall through - case 6: - g = _mm256_load_ps(src + (y + 5) * src_stride + j); - sum1 = _mm256_mul_ps(f5, g); - // fall through - case 5: - g = _mm256_load_ps(src + (y + 4) * src_stride + j); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - case 4: - g = _mm256_load_ps(src + (y + 3) * src_stride + j); - g = _mm256_mul_ps(f3, g); - sum3 = _mm256_add_ps(sum3, g); - // fall through - case 3: - g = _mm256_load_ps(src + (y + 2) * src_stride + j); - g = _mm256_mul_ps(f2, g); - sum2 = _mm256_add_ps(sum2, g); - // fall through - case 2: - g = _mm256_load_ps(src + (y + 1) * src_stride + j); - g = _mm256_mul_ps(f1, g); - sum1 = _mm256_add_ps(sum1, g); - // fall through - case 1: - g = _mm256_load_ps(src + (y + 0) * src_stride + j); - g = _mm256_mul_ps(f0, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - } - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - if (y) - accum = 
_mm256_add_ps(accum, _mm256_load_ps(dst + j)); - - _mm256_store_ps(dst + j, accum); - } - } - } -} - -void convolution_f32_avx_s_1d_v_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src -= 8 * src_stride; // radius = 8 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src + 5 * src_stride + j); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src + 6 * src_stride + j); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 7 * src_stride + j); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src + 8 * src_stride + j); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } + __m256 f[MAX_FWIDTH_AVX_CONV]; - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = _mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); + for (int k = 0; k < filter_width; k++) { + f[k] = _mm256_broadcast_ss(filter + k); + } - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; + for (int j = 0; j < j_end; j += AVX_STEP) { + __m256 sum = _mm256_setzero_ps(); - g = _mm256_load_ps(src + 9 * src_stride + j); - g = _mm256_mul_ps(f0, g); - sum0 = g; + for (int k = 0; k < filter_width; k++) { + __m256 g = _mm256_loadu_ps(src + j + k); + g = _mm256_mul_ps(f[k], g); + sum = _mm256_add_ps(sum, g); + } - g = _mm256_load_ps(src + 10 * src_stride + j); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src + 11 * src_stride + j); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 12 * src_stride + j); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 13 * src_stride + j); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src + 14 * src_stride + j); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src + 15 * src_stride + j); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 16 * src_stride + j); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = 
_mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0); - _mm256_store_ps(dst + j, sum0); - } + _mm256_storeu_ps(dst + j + radius, sum); + } } -void convolution_f32_avx_s_1d_v_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src -= 4 * src_stride; // radius = 4 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(f0, g); - sum0 = g; +void convolution_f32_avx_s_1d_v_scanline(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) { + int radius = filter_width / 2; - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(f1, g); - sum1 = g; + src -= radius * src_stride; - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(f2, g); - sum2 = g; + __m256 f[MAX_FWIDTH_AVX_CONV]; - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(f3, g); - sum3 = g; + for (int k = 0; k < filter_width; k++) { + f[k] = _mm256_broadcast_ss(filter + k); + } - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); + for (int j = 0; j < j_end; j += AVX_STEP) { + __m256 sum = _mm256_setzero_ps(); - g = _mm256_load_ps(src + 5 * src_stride + j); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); + for (int k = 0; k < filter_width; k++) { + __m256 g = _mm256_load_ps(src + k * src_stride + j); + g = _mm256_mul_ps(f[k], g); + sum = _mm256_add_ps(sum, g); + } - g = _mm256_load_ps(src + 6 * src_stride + j); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 7 * src_stride + j); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src + 8 * src_stride + j); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } + _mm256_store_ps(dst + j, sum); + } } -void convolution_f32_avx_s_1d_v_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - src -= 2 * src_stride; // radius = 2 +void convolution_f32_avx_s_1d_v_sq_scanline(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) { + int radius = filter_width / 2; - // Evaluate filter taps 0-5 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); + src -= radius * src_stride; - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; + __m256 f[MAX_FWIDTH_AVX_CONV]; - g = 
_mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(f0, g); - sum0 = g; + for (int k = 0; k < filter_width; k++) { + f[k] = _mm256_broadcast_ss(filter + k); + } - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(f1, g); - sum1 = g; + for (int j = 0; j < j_end; j += AVX_STEP) { + __m256 sum = _mm256_setzero_ps(); - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(f2, g); - sum2 = g; + for (int k = 0; k < filter_width; k++) { + __m256 g = _mm256_load_ps(src + k * src_stride + j); + g = _mm256_mul_ps(g, g); + g = _mm256_mul_ps(f[k], g); + sum = _mm256_add_ps(sum, g); + } - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } + _mm256_store_ps(dst + j, sum); + } } -void convolution_f32_avx_s_1d( - int N, - const float * RESTRICT filter, - int filter_width, - const float * RESTRICT src, - float * RESTRICT dst, - float * RESTRICT tmp, - int width, - int height, - int src_stride, - int dst_stride) +void convolution_f32_avx_s_1d_v_xy_scanline(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end) { - int radius = filter_width / 2; - int width_mod8 = vmaf_floorn(width, 8); - int tmp_stride = vmaf_ceiln(width, 8); + int radius = filter_width / 2; - int i_vec_end = height - radius; - int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8); + src1 -= radius * src1_stride; + src2 -= radius * src2_stride; - // Vertical pass. - for (int i = 0; i < radius; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } - for (int i = radius; i < i_vec_end; ++i) { - convolution_f32_avx_s_1d_v_scanline(N, filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_mod8); + __m256 f[MAX_FWIDTH_AVX_CONV]; - for (int j = width_mod8; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } - for (int i = i_vec_end; i < height; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } + for (int k = 0; k < filter_width; k++) { + f[k] = _mm256_broadcast_ss(filter + k); + } - // Horizontal pass. 
- for (int i = 0; i < height; ++i) { - for (int j = 0; j < radius; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } + for (int j = 0; j < j_end; j += AVX_STEP) { + __m256 sum = _mm256_setzero_ps(); - convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end); + for (int k = 0; k < filter_width; k++) { + __m256 g = _mm256_load_ps(src1 + k * src1_stride + j); + __m256 g2 = _mm256_load_ps(src2 + k * src2_stride + j); + g = _mm256_mul_ps(g, g2); + g = _mm256_mul_ps(f[k], g); + sum = _mm256_add_ps(sum, g); + } - for (int j = j_vec_end + radius; j < width; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } - } -} - -void convolution_f32_avx_s(const float *filter, int filter_width, const float *src, float *dst, float *tmp, int width, int height, int src_stride, int dst_stride) -{ - switch (filter_width) { - case 17: - convolution_f32_avx_s_1d(17, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 9: - convolution_f32_avx_s_1d(9, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 5: - convolution_f32_avx_s_1d(5, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 3: - convolution_f32_avx_s_1d(3, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - default: - convolution_f32_avx_s_1d(0, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - } + _mm256_store_ps(dst + j, sum); + } } -void convolution_f32_avx_s_1d_h_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 5); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 6); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 7); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src + j + 8); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = 
_mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_store_ps(dst + j + 8, accum); // radius = 8 - } - - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = _mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - float *dst_ptr = dst + j + 8; // radius = 8 - - g = _mm256_loadu_ps(src + j + 9); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 10); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 11); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 12); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 13); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 14); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 15); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 16); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0); - _mm256_store_ps(dst_ptr, sum0); - } -} - -void convolution_f32_avx_s_1d_h_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src + j + 5); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src + j + 6); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src + j + 7); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src + j + 8); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f8, g); - sum0 = 
_mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 4, accum); // radius = 4 - } -} - -void convolution_f32_avx_s_1d_h_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_loadu_ps(src + j + 0); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src + j + 1); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src + j + 2); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src + j + 3); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src + j + 4); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 2, accum); // radius = 2 - } -} - -// Filter a single scanline. -static void convolution_f32_avx_s_1d_v_sq_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - - if (N == 5) - { - convolution_f32_avx_s_1d_v_sq_scanline_5(filter, filter_width, src, dst, src_stride, j_end); - } - else if (N == 9) - { - convolution_f32_avx_s_1d_v_sq_scanline_9(filter, filter_width, src, dst, src_stride, j_end); - } - else if (N == 17) - { - convolution_f32_avx_s_1d_v_sq_scanline_17(filter, filter_width, src, dst, src_stride, j_end); - } - else { - - int radius = filter_width / 2; - src -= radius * src_stride; - - for (int y = 0; y < filter_width; y += 9) { - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_setzero_ps(); - f1 = _mm256_setzero_ps(); - f2 = _mm256_setzero_ps(); - f3 = _mm256_setzero_ps(); - f5 = _mm256_setzero_ps(); - f6 = _mm256_setzero_ps(); - f7 = _mm256_setzero_ps(); - f8 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - f8 = _mm256_broadcast_ss(filter + y + 8); - // fall through - case 8: - f7 = _mm256_broadcast_ss(filter + y + 7); - // fall through - case 7: - f6 = _mm256_broadcast_ss(filter + y + 6); - // fall through - case 6: - f5 = _mm256_broadcast_ss(filter + y + 5); - // fall through - case 5: - f4 = _mm256_broadcast_ss(filter + y + 4); - // fall through - case 4: - f3 = _mm256_broadcast_ss(filter + y + 3); - // fall through - case 3: - f2 = _mm256_broadcast_ss(filter + y + 2); - // fall through - case 2: - f1 = _mm256_broadcast_ss(filter + y + 1); - // fall through - case 1: - f0 = _mm256_broadcast_ss(filter + y + 0); - // fall through - } - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - sum0 = _mm256_setzero_ps(); - sum1 = _mm256_setzero_ps(); - sum2 = _mm256_setzero_ps(); - sum3 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - g = _mm256_load_ps(src + (y + 8) * src_stride + j); - g = 
_mm256_mul_ps(g, g); - sum0 = _mm256_mul_ps(f8, g); - // fall through - case 8: - g = _mm256_load_ps(src + (y + 7) * src_stride + j); - g = _mm256_mul_ps(g, g); - sum3 = _mm256_mul_ps(f7, g); - // fall through - case 7: - g = _mm256_load_ps(src + (y + 6) * src_stride + j); - g = _mm256_mul_ps(g, g); - sum2 = _mm256_mul_ps(f6, g); - // fall through - case 6: - g = _mm256_load_ps(src + (y + 5) * src_stride + j); - g = _mm256_mul_ps(g, g); - sum1 = _mm256_mul_ps(f5, g); - // fall through - case 5: - g = _mm256_load_ps(src + (y + 4) * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - case 4: - g = _mm256_load_ps(src + (y + 3) * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = _mm256_add_ps(sum3, g); - // fall through - case 3: - g = _mm256_load_ps(src + (y + 2) * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = _mm256_add_ps(sum2, g); - // fall through - case 2: - g = _mm256_load_ps(src + (y + 1) * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = _mm256_add_ps(sum1, g); - // fall through - case 1: - g = _mm256_load_ps(src + (y + 0) * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - } - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - if (y) - accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j)); - - _mm256_store_ps(dst + j, accum); - } - } - } -} - -void convolution_f32_avx_s_1d_v_sq_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src -= 8 * src_stride; // radius = 8 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; +void convolution_f32_avx_s(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, float * RESTRICT tmp, int width, int height, int src_stride, int dst_stride) { + int radius = filter_width / 2; + int width_floor_step = vmaf_floorn(width, AVX_STEP); + int tmp_stride = vmaf_ceiln(width, AVX_STEP); - g = _mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; + int i_vec_end = height - radius; + int j_vec_end = vmaf_floorn(width - radius, AVX_STEP); - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; + // Vertical pass. 
+ for (int i = 0; i < radius; ++i) { + for (int j = 0; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + for (int i = radius; i < i_vec_end; ++i) { + convolution_f32_avx_s_1d_v_scanline(filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_floor_step); - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src + 5 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src + 6 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 7 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src + 8 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } - - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = _mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_load_ps(src + 9 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src + 10 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src + 11 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 12 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 13 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src + 14 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src + 15 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 16 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0); - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_v_sq_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src -= 4 * src_stride; // radius = 4 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = 
_mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src + 5 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src + 6 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src + 7 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src + 8 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_v_sq_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, int src_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - src -= 2 * src_stride; // radius = 2 - - // Evaluate filter taps 0-5 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g; - - g = _mm256_load_ps(src + 0 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src + 1 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src + 2 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src + 3 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src + 4 * src_stride + j); - g = _mm256_mul_ps(g, g); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_sq( - int N, - const float * RESTRICT filter, - int filter_width, - const float * RESTRICT src, - float * RESTRICT dst, - float * RESTRICT tmp, - int width, - int height, - int src_stride, - int dst_stride) -{ - int radius = filter_width / 2; - int width_mod8 = vmaf_floorn(width, 8); - int tmp_stride = vmaf_ceiln(width, 8); - - int i_vec_end = height - radius; - int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8); - - // Vertical pass. 
- for (int i = 0; i < radius; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } - for (int i = radius; i < i_vec_end; ++i) { - convolution_f32_avx_s_1d_v_sq_scanline(N, filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_mod8); - - for (int j = width_mod8; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } - for (int i = i_vec_end; i < height; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); - } - } - - // Horizontal pass. - for (int i = 0; i < height; ++i) { - for (int j = 0; j < radius; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } - - convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end); - - for (int j = j_vec_end + radius; j < width; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } - } -} - -void convolution_f32_avx_sq_s(const float *filter, int filter_width, const float *src, float *dst, float *tmp, int width, int height, int src_stride, int dst_stride) -{ - switch (filter_width) { - case 17: - convolution_f32_avx_s_1d_sq(17, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 9: - convolution_f32_avx_s_1d_sq(9, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 5: - convolution_f32_avx_s_1d_sq(5, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - case 3: - convolution_f32_avx_s_1d_sq(3, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - default: - convolution_f32_avx_s_1d_sq(0, filter, filter_width, src, dst, tmp, width, height, src_stride, dst_stride); - break; - } -} - -void convolution_f32_avx_s_1d_h_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_loadu_ps(src1 + j + 0); - g2 = _mm256_loadu_ps(src2 + j + 0); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src1 + j + 1); - g2 = _mm256_loadu_ps(src2 + j + 1); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src1 + j + 2); - g2 = _mm256_loadu_ps(src2 + j + 2); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src1 + j + 3); - g2 = _mm256_loadu_ps(src2 + j + 3); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src1 + j + 4); - g2 = 
_mm256_loadu_ps(src2 + j + 4); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src1 + j + 5); - g2 = _mm256_loadu_ps(src2 + j + 5); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src1 + j + 6); - g2 = _mm256_loadu_ps(src2 + j + 6); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src1 + j + 7); - g2 = _mm256_loadu_ps(src2 + j + 7); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src1 + j + 8); - g2 = _mm256_loadu_ps(src2 + j + 8); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_store_ps(dst + j + 8, accum); // radius = 8 - } - - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = _mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - float *dst_ptr = dst + j + 8; // radius = 8 - - g = _mm256_loadu_ps(src1 + j + 9); - g2 = _mm256_loadu_ps(src2 + j + 9); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src1 + j + 10); - g2 = _mm256_loadu_ps(src2 + j + 10); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src1 + j + 11); - g2 = _mm256_loadu_ps(src2 + j + 11); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src1 + j + 12); - g2 = _mm256_loadu_ps(src2 + j + 12); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src1 + j + 13); - g2 = _mm256_loadu_ps(src2 + j + 13); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src1 + j + 14); - g2 = _mm256_loadu_ps(src2 + j + 14); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src1 + j + 15); - g2 = _mm256_loadu_ps(src2 + j + 15); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src1 + j + 16); - g2 = _mm256_loadu_ps(src2 + j + 16); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst_ptr), sum0); - _mm256_store_ps(dst_ptr, sum0); - } -} - -void convolution_f32_avx_s_1d_h_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = 
_mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_loadu_ps(src1 + j + 0); - g2 = _mm256_loadu_ps(src2 + j + 0); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src1 + j + 1); - g2 = _mm256_loadu_ps(src2 + j + 1); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src1 + j + 2); - g2 = _mm256_loadu_ps(src2 + j + 2); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src1 + j + 3); - g2 = _mm256_loadu_ps(src2 + j + 3); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src1 + j + 4); - g2 = _mm256_loadu_ps(src2 + j + 4); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_loadu_ps(src1 + j + 5); - g2 = _mm256_loadu_ps(src2 + j + 5); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_loadu_ps(src1 + j + 6); - g2 = _mm256_loadu_ps(src2 + j + 6); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_loadu_ps(src1 + j + 7); - g2 = _mm256_loadu_ps(src2 + j + 7); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_loadu_ps(src1 + j + 8); - g2 = _mm256_loadu_ps(src2 + j + 8); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 4, accum); // radius = 4 - } -} - -void convolution_f32_avx_s_1d_h_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_loadu_ps(src1 + j + 0); - g2 = _mm256_loadu_ps(src2 + j + 0); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_loadu_ps(src1 + j + 1); - g2 = _mm256_loadu_ps(src2 + j + 1); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_loadu_ps(src1 + j + 2); - g2 = _mm256_loadu_ps(src2 + j + 2); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_loadu_ps(src1 + j + 3); - g2 = _mm256_loadu_ps(src2 + j + 3); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_loadu_ps(src1 + j + 4); - g2 = _mm256_loadu_ps(src2 + j + 4); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - _mm256_storeu_ps(dst + j + 2, accum); // radius = 2 - } -} - -// Filter a single scanline. 
-static void convolution_f32_avx_s_1d_v_xy_scanline(int N, const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end) -{ - - if (N == 5) - { - convolution_f32_avx_s_1d_v_xy_scanline_5(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end); - } - else if (N == 9) - { - convolution_f32_avx_s_1d_v_xy_scanline_9(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end); - } - else if (N == 17) - { - convolution_f32_avx_s_1d_v_xy_scanline_17(filter, filter_width, src1, src2, dst, src1_stride, src2_stride, j_end); - } - else { - - int radius = filter_width / 2; - src1 -= radius * src1_stride; - src2 -= radius * src2_stride; - - for (int y = 0; y < filter_width; y += 9) { - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - - f0 = _mm256_setzero_ps(); - f1 = _mm256_setzero_ps(); - f2 = _mm256_setzero_ps(); - f3 = _mm256_setzero_ps(); - f5 = _mm256_setzero_ps(); - f6 = _mm256_setzero_ps(); - f7 = _mm256_setzero_ps(); - f8 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - f8 = _mm256_broadcast_ss(filter + y + 8); - // fall through - case 8: - f7 = _mm256_broadcast_ss(filter + y + 7); - // fall through - case 7: - f6 = _mm256_broadcast_ss(filter + y + 6); - // fall through - case 6: - f5 = _mm256_broadcast_ss(filter + y + 5); - // fall through - case 5: - f4 = _mm256_broadcast_ss(filter + y + 4); - // fall through - case 4: - f3 = _mm256_broadcast_ss(filter + y + 3); - // fall through - case 3: - f2 = _mm256_broadcast_ss(filter + y + 2); - // fall through - case 2: - f1 = _mm256_broadcast_ss(filter + y + 1); - // fall through - case 1: - f0 = _mm256_broadcast_ss(filter + y + 0); - // fall through - } - - for (int j = 0; j < j_end; j += 8) { - __m256 accum = _mm256_setzero_ps(); - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - sum0 = _mm256_setzero_ps(); - sum1 = _mm256_setzero_ps(); - sum2 = _mm256_setzero_ps(); - sum3 = _mm256_setzero_ps(); - - switch (filter_width - y) { - default: - g = _mm256_load_ps(src1 + (y + 8) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 8) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - sum0 = _mm256_mul_ps(f8, g); - // fall through - case 8: - g = _mm256_load_ps(src1 + (y + 7) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 7) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - sum3 = _mm256_mul_ps(f7, g); - // fall through - case 7: - g = _mm256_load_ps(src1 + (y + 6) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 6) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - sum2 = _mm256_mul_ps(f6, g); - // fall through - case 6: - g = _mm256_load_ps(src1 + (y + 5) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 5) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - sum1 = _mm256_mul_ps(f5, g); - // fall through - case 5: - g = _mm256_load_ps(src1 + (y + 4) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 4) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - case 4: - g = _mm256_load_ps(src1 + (y + 3) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 3) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = _mm256_add_ps(sum3, g); - // fall through - case 3: - g = _mm256_load_ps(src1 + (y + 2) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 2) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = _mm256_add_ps(sum2, g); - // fall through - 
case 2: - g = _mm256_load_ps(src1 + (y + 1) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 1) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = _mm256_add_ps(sum1, g); - // fall through - case 1: - g = _mm256_load_ps(src1 + (y + 0) * src1_stride + j); - g2 = _mm256_load_ps(src2 + (y + 0) * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = _mm256_add_ps(sum0, g); - // fall through - } - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - accum = _mm256_add_ps(accum, sum0); - - if (y) - accum = _mm256_add_ps(accum, _mm256_load_ps(dst + j)); - - _mm256_store_ps(dst + j, accum); - } - } - } -} - -void convolution_f32_avx_s_1d_v_xy_scanline_17(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src1 -= 8 * src1_stride; // radius = 8 - src2 -= 8 * src2_stride; // radius = 8 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_load_ps(src1 + 0 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 0 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src1 + 1 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 1 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src1 + 2 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 2 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src1 + 3 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 3 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src1 + 4 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 4 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src1 + 5 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 5 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src1 + 6 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 6 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src1 + 7 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 7 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src1 + 8 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 8 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } - - // Evaluate filter taps 9-16 - f0 = _mm256_broadcast_ss(filter + 9); - f1 = _mm256_broadcast_ss(filter + 10); - f2 = _mm256_broadcast_ss(filter + 11); - f3 = 
_mm256_broadcast_ss(filter + 12); - f4 = _mm256_broadcast_ss(filter + 13); - f5 = _mm256_broadcast_ss(filter + 14); - f6 = _mm256_broadcast_ss(filter + 15); - f7 = _mm256_broadcast_ss(filter + 16); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_load_ps(src1 + 9 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 9 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src1 + 10 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 10 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src1 + 11 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 11 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src1 + 12 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 12 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src1 + 13 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 13 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src1 + 14 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 14 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src1 + 15 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 15 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src1 + 16 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 16 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - sum0 = _mm256_add_ps(_mm256_load_ps(dst + j), sum0); - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_v_xy_scanline_9(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4, f5, f6, f7, f8; - src1 -= 4 * src1_stride; // radius = 4 - src2 -= 4 * src2_stride; // radius = 4 - - // Evaluate filter taps 0-8 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - f5 = _mm256_broadcast_ss(filter + 5); - f6 = _mm256_broadcast_ss(filter + 6); - f7 = _mm256_broadcast_ss(filter + 7); - f8 = _mm256_broadcast_ss(filter + 8); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_load_ps(src1 + 0 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 0 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src1 + 1 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 1 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src1 + 2 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 2 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src1 + 3 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 3 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src1 + 4 * src1_stride 
+ j); - g2 = _mm256_load_ps(src2 + 4 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - g = _mm256_load_ps(src1 + 5 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 5 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f5, g); - sum1 = _mm256_add_ps(sum1, g); - - g = _mm256_load_ps(src1 + 6 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 6 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f6, g); - sum2 = _mm256_add_ps(sum2, g); - - g = _mm256_load_ps(src1 + 7 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 7 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f7, g); - sum3 = _mm256_add_ps(sum3, g); - - g = _mm256_load_ps(src1 + 8 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 8 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f8, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_v_xy_scanline_5(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, int src1_stride, int src2_stride, int j_end) -{ - (void) filter_width; - - __m256 f0, f1, f2, f3, f4; - src1 -= 2 * src1_stride; // radius = 2 - src2 -= 2 * src2_stride; // radius = 2 - - // Evaluate filter taps 0-5 - f0 = _mm256_broadcast_ss(filter + 0); - f1 = _mm256_broadcast_ss(filter + 1); - f2 = _mm256_broadcast_ss(filter + 2); - f3 = _mm256_broadcast_ss(filter + 3); - f4 = _mm256_broadcast_ss(filter + 4); - - for (int j = 0; j < j_end; j += 8) { - __m256 sum0, sum1, sum2, sum3; - __m256 g, g2; - - g = _mm256_load_ps(src1 + 0 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 0 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f0, g); - sum0 = g; - - g = _mm256_load_ps(src1 + 1 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 1 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f1, g); - sum1 = g; - - g = _mm256_load_ps(src1 + 2 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 2 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f2, g); - sum2 = g; - - g = _mm256_load_ps(src1 + 3 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 3 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f3, g); - sum3 = g; - - g = _mm256_load_ps(src1 + 4 * src1_stride + j); - g2 = _mm256_load_ps(src2 + 4 * src2_stride + j); - g = _mm256_mul_ps(g, g2); - g = _mm256_mul_ps(f4, g); - sum0 = _mm256_add_ps(sum0, g); - - sum0 = _mm256_add_ps(sum0, sum2); - sum1 = _mm256_add_ps(sum1, sum3); - - sum0 = _mm256_add_ps(sum0, sum1); - - _mm256_store_ps(dst + j, sum0); - } -} - -void convolution_f32_avx_s_1d_xy( - int N, - const float * RESTRICT filter, - int filter_width, - const float * RESTRICT src1, - const float * RESTRICT src2, - float * RESTRICT dst, - float * RESTRICT tmp, - int width, - int height, - int src1_stride, - int src2_stride, - int dst_stride) -{ - int radius = filter_width / 2; - int width_mod8 = vmaf_floorn(width, 8); - int tmp_stride = vmaf_ceiln(width, 8); - - int i_vec_end = height - radius; - int j_vec_end = width_mod8 - vmaf_ceiln(radius + 1, 8); - - // Vertical pass. 
- for (int i = 0; i < radius; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j); - } - } - for (int i = radius; i < i_vec_end; ++i) { - convolution_f32_avx_s_1d_v_xy_scanline(N, filter, filter_width, src1 + i * src1_stride, src2 + i * src2_stride, tmp + i * tmp_stride, src1_stride, src2_stride, width_mod8); - - for (int j = width_mod8; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j); - } - } - for (int i = i_vec_end; i < height; ++i) { - for (int j = 0; j < width; ++j) { - tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j); - } - } - - // Horizontal pass. - for (int i = 0; i < height; ++i) { - for (int j = 0; j < radius; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } - - convolution_f32_avx_s_1d_h_scanline(N, filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end); - - for (int j = j_vec_end + radius; j < width; ++j) { - dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); - } - } -} - -void convolution_f32_avx_xy_s(const float *filter, int filter_width, const float *src1, const float *src2, float *dst, float *tmp, int width, int height, int src1_stride, int src2_stride, int dst_stride) -{ - switch (filter_width) { - case 17: - convolution_f32_avx_s_1d_xy(17, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride); - break; - case 9: - convolution_f32_avx_s_1d_xy(9, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride); - break; - case 5: - convolution_f32_avx_s_1d_xy(5, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride); - break; - case 3: - convolution_f32_avx_s_1d_xy(3, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride); - break; - default: - convolution_f32_avx_s_1d_xy(0, filter, filter_width, src1, src2, dst, tmp, width, height, src1_stride, src2_stride, dst_stride); - break; - } -} + for (int j = width_floor_step; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + for (int i = i_vec_end; i < height; ++i) { + for (int j = 0; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + + // Horizontal pass. 
+ for (int i = 0; i < height; ++i) { + for (int j = 0; j < radius; ++j) { + dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); + } + + convolution_f32_avx_s_1d_h_scanline(filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end); + + for (int j = j_vec_end; j < width; ++j) { + dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); + } + } +} + +void convolution_f32_avx_sq_s(const float * RESTRICT filter, int filter_width, const float * RESTRICT src, float * RESTRICT dst, float * RESTRICT tmp, int width, int height, int src_stride, int dst_stride) { + int radius = filter_width / 2; + int width_floor_step = vmaf_floorn(width, AVX_STEP); + int tmp_stride = vmaf_ceiln(width, AVX_STEP); + + int i_vec_end = height - radius; + int j_vec_end = vmaf_floorn(width - radius, AVX_STEP); + + // Vertical pass. + for (int i = 0; i < radius; ++i) { + for (int j = 0; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + for (int i = radius; i < i_vec_end; ++i) { + convolution_f32_avx_s_1d_v_sq_scanline(filter, filter_width, src + i * src_stride, tmp + i * tmp_stride, src_stride, width_floor_step); + + for (int j = width_floor_step; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + for (int i = i_vec_end; i < height; ++i) { + for (int j = 0; j < width; ++j) { + tmp[i * tmp_stride + j] = convolution_edge_sq_s(false, filter, filter_width, src, width, height, src_stride, i, j); + } + } + + // Horizontal pass. + for (int i = 0; i < height; ++i) { + for (int j = 0; j < radius; ++j) { + dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); + } + + convolution_f32_avx_s_1d_h_scanline(filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end); + + for (int j = j_vec_end; j < width; ++j) { + dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j); + } + } +} + +void convolution_f32_avx_xy_s(const float * RESTRICT filter, int filter_width, const float * RESTRICT src1, const float * RESTRICT src2, float * RESTRICT dst, float * RESTRICT tmp, int width, int height, int src1_stride, int src2_stride, int dst_stride) +{ + int radius = filter_width / 2; + int width_floor_step = vmaf_floorn(width, AVX_STEP); + int tmp_stride = vmaf_ceiln(width, AVX_STEP); + + int i_vec_end = height - radius; + int j_vec_end = vmaf_floorn(width - radius, AVX_STEP); + + // Vertical pass. 
+    for (int i = 0; i < radius; ++i) {
+        for (int j = 0; j < width; ++j) {
+            tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
+        }
+    }
+    for (int i = radius; i < i_vec_end; ++i) {
+        convolution_f32_avx_s_1d_v_xy_scanline(filter, filter_width, src1 + i * src1_stride, src2 + i * src2_stride, tmp + i * tmp_stride, src1_stride, src2_stride, width_floor_step);
+
+        for (int j = width_floor_step; j < width; ++j) {
+            tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
+        }
+    }
+    for (int i = i_vec_end; i < height; ++i) {
+        for (int j = 0; j < width; ++j) {
+            tmp[i * tmp_stride + j] = convolution_edge_xy_s(false, filter, filter_width, src1, src2, width, height, src1_stride, src2_stride, i, j);
+        }
+    }
+
+    // Horizontal pass.
+    for (int i = 0; i < height; ++i) {
+        for (int j = 0; j < radius; ++j) {
+            dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
+        }
+
+        convolution_f32_avx_s_1d_h_scanline(filter, filter_width, tmp + i * tmp_stride, dst + i * dst_stride, j_vec_end);
+
+        for (int j = j_vec_end; j < width; ++j) {
+            dst[i * dst_stride + j] = convolution_edge_s(true, filter, filter_width, tmp, width, height, tmp_stride, i, j);
+        }
+    }
+}
\ No newline at end of file
diff --git a/libvmaf/src/feature/float_adm.c b/libvmaf/src/feature/float_adm.c
index 60d1986ec..4be29404e 100644
--- a/libvmaf/src/feature/float_adm.c
+++ b/libvmaf/src/feature/float_adm.c
@@ -137,8 +137,8 @@ static int extract(VmafFeatureExtractor *fex,
     (void) ref_pic_90;
     (void) dist_pic_90;

-    picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc);
-    picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc);
+    picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc, 0);
+    picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc, 0);

     double score, score_num, score_den;
     double scores[8];
diff --git a/libvmaf/src/feature/float_ansnr.c b/libvmaf/src/feature/float_ansnr.c
index 95ca32f4f..046171bcb 100644
--- a/libvmaf/src/feature/float_ansnr.c
+++ b/libvmaf/src/feature/float_ansnr.c
@@ -81,8 +81,8 @@ static int extract(VmafFeatureExtractor *fex,
     (void) ref_pic_90;
     (void) dist_pic_90;

-    picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc);
-    picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc);
+    picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc, 0);
+    picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc, 0);

     double score, score_psnr;
     err = compute_ansnr(s->ref, s->dist, ref_pic->w[0], ref_pic->h[0],
diff --git a/libvmaf/src/feature/float_moment.c b/libvmaf/src/feature/float_moment.c
index 5131e3eea..2135f6178 100644
--- a/libvmaf/src/feature/float_moment.c
+++ b/libvmaf/src/feature/float_moment.c
@@ -65,8 +65,8 @@ static int extract(VmafFeatureExtractor *fex,
     (void) ref_pic_90;
     (void) dist_pic_90;

-    picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc);
-    picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc);
+    picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc, 0);
+    picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc, 0);

     double score[4];
     err = compute_1st_moment(s->ref, ref_pic->w[0], ref_pic->h[0],
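A note on the picture_copy() changes above and in the call sites that follow: the new trailing argument selects the source plane, and every call site in this patch passes 0 to keep the existing luma-only behavior. As a hypothetical illustration (this helper is not part of the patch, and the name is made up), a caller wanting the first chroma plane would pass channel = 1:

#include <stddef.h>
#include "picture_copy.h"

/* Hypothetical helper, not part of this patch: copy the first chroma plane
 * into a float buffer by selecting channel 1 instead of the luma channel 0. */
static void copy_chroma_u(float *dst, ptrdiff_t dst_stride, VmafPicture *pic)
{
    picture_copy(dst, dst_stride, pic, 0, pic->bpc, 1);
}

diff --git a/libvmaf/src/feature/float_motion.c b/libvmaf/src/feature/float_motion.c
index 947247b4e..b48645b85 100644
--- a/libvmaf/src/feature/float_motion.c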
+++ b/libvmaf/src/feature/float_motion.c @@ -146,7 +146,7 @@ static int extract(VmafFeatureExtractor *fex, unsigned blur_idx_1 = (index + 1) % 3; unsigned blur_idx_2 = (index + 2) % 3; - picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc); + picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc, 0); convolution_f32_c_s(FILTER_5_s, 5, s->ref, s->blur[blur_idx_0], s->tmp, ref_pic->w[0], ref_pic->h[0], s->float_stride / sizeof(float), diff --git a/libvmaf/src/feature/float_ms_ssim.c b/libvmaf/src/feature/float_ms_ssim.c index b96d97387..efd637acd 100644 --- a/libvmaf/src/feature/float_ms_ssim.c +++ b/libvmaf/src/feature/float_ms_ssim.c @@ -109,8 +109,8 @@ static int extract(VmafFeatureExtractor *fex, (void) ref_pic_90; (void) dist_pic_90; - picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc); - picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc); + picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc, 0); + picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc, 0); double score, l_scores[5], c_scores[5], s_scores[5]; err = compute_ms_ssim(s->ref, s->dist, ref_pic->w[0], ref_pic->h[0], diff --git a/libvmaf/src/feature/float_psnr.c b/libvmaf/src/feature/float_psnr.c index 056fe82fb..4394b1399 100644 --- a/libvmaf/src/feature/float_psnr.c +++ b/libvmaf/src/feature/float_psnr.c @@ -81,8 +81,8 @@ static int extract(VmafFeatureExtractor *fex, (void) ref_pic_90; (void) dist_pic_90; - picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc); - picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc); + picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc, 0); + picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc, 0); double score; err = compute_psnr(s->ref, s->dist, ref_pic->w[0], ref_pic->h[0], diff --git a/libvmaf/src/feature/float_ssim.c b/libvmaf/src/feature/float_ssim.c index 93fd348f6..50f944219 100644 --- a/libvmaf/src/feature/float_ssim.c +++ b/libvmaf/src/feature/float_ssim.c @@ -109,8 +109,8 @@ static int extract(VmafFeatureExtractor *fex, (void) ref_pic_90; (void) dist_pic_90; - picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc); - picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc); + picture_copy(s->ref, s->float_stride, ref_pic, 0, ref_pic->bpc, 0); + picture_copy(s->dist, s->float_stride, dist_pic, 0, dist_pic->bpc, 0); double score, l_score, c_score, s_score; err = compute_ssim(s->ref, s->dist, ref_pic->w[0], ref_pic->h[0], diff --git a/libvmaf/src/feature/float_vif.c b/libvmaf/src/feature/float_vif.c index b1e5d63f7..9960568b5 100644 --- a/libvmaf/src/feature/float_vif.c +++ b/libvmaf/src/feature/float_vif.c @@ -113,8 +113,8 @@ static int extract(VmafFeatureExtractor *fex, (void) ref_pic_90; (void) dist_pic_90; - picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc); - picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc); + picture_copy(s->ref, s->float_stride, ref_pic, -128, ref_pic->bpc, 0); + picture_copy(s->dist, s->float_stride, dist_pic, -128, dist_pic->bpc, 0); double score, score_num, score_den; double scores[8]; diff --git a/libvmaf/src/feature/picture_copy.c b/libvmaf/src/feature/picture_copy.c index cedf466ed..01cd7aae9 100644 --- a/libvmaf/src/feature/picture_copy.c +++ b/libvmaf/src/feature/picture_copy.c @@ -21,44 +21,46 @@ #include void picture_copy_hbd(float *dst, ptrdiff_t dst_stride, - VmafPicture *src, int offset, float scaler) + VmafPicture *src, int offset, float scaler, int channel) { float 
*float_data = dst; - uint16_t *data = src->data[0]; + uint16_t *data = src->data[channel]; - for (unsigned i = 0; i < src->h[0]; i++) { - for (unsigned j = 0; j < src->w[0]; j++) { + for (unsigned i = 0; i < src->h[channel]; i++) { + for (unsigned j = 0; j < src->w[channel]; j++) { float_data[j] = (float) data[j] / scaler + offset; } float_data += dst_stride / sizeof(float); - data += src->stride[0] / 2; + data += src->stride[channel] / 2; } return; } void picture_copy(float *dst, ptrdiff_t dst_stride, - VmafPicture *src, int offset, unsigned bpc) + VmafPicture *src, int offset, unsigned bpc, int channel) { if (bpc == 10) { - picture_copy_hbd(dst, dst_stride, src, offset, 4.0f); + picture_copy_hbd(dst, dst_stride, src, offset, 4.0f, channel); return; - } else if (bpc == 12) { - picture_copy_hbd(dst, dst_stride, src, offset, 16.0f); + } + else if (bpc == 12) { + picture_copy_hbd(dst, dst_stride, src, offset, 16.0f, channel); return; - } else if (bpc == 16) { - picture_copy_hbd(dst, dst_stride, src, offset, 256.0f); + } + else if (bpc == 16) { + picture_copy_hbd(dst, dst_stride, src, offset, 256.0f, channel); return; } float *float_data = dst; - uint8_t *data = src->data[0]; + uint8_t *data = src->data[channel]; - for (unsigned i = 0; i < src->h[0]; i++) { - for (unsigned j = 0; j < src->w[0]; j++) { + for (unsigned i = 0; i < src->h[channel]; i++) { + for (unsigned j = 0; j < src->w[channel]; j++) { float_data[j] = (float) data[j] + offset; } float_data += dst_stride / sizeof(float); - data += src->stride[0]; + data += src->stride[channel]; } return; diff --git a/libvmaf/src/feature/picture_copy.h b/libvmaf/src/feature/picture_copy.h index 3be759892..11f16c0ea 100644 --- a/libvmaf/src/feature/picture_copy.h +++ b/libvmaf/src/feature/picture_copy.h @@ -18,4 +18,4 @@ #include void picture_copy(float *dst, ptrdiff_t dst_stride, VmafPicture *src, - int offset, unsigned bpc); + int offset, unsigned bpc, int channel); diff --git a/libvmaf/src/feature/vif.c b/libvmaf/src/feature/vif.c index 38a0520b8..edb6e8447 100644 --- a/libvmaf/src/feature/vif.c +++ b/libvmaf/src/feature/vif.c @@ -23,6 +23,7 @@ #include #include +#include "log.h" #include "mem.h" #include "offset.h" #include "vif_options.h" @@ -44,7 +45,7 @@ void apply_frame_differencing(const float *current_frame, const float *previous_ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride, int dis_stride, double *score, double *score_num, double *score_den, double *scores, - double vif_enhn_gain_limit, double vif_kernelscale) + double vif_enhn_gain_limit, double vif_kernelscale, int vif_skip_scale0, double vif_sigma_nsq) { float *data_buf = 0; char *data_top; @@ -59,7 +60,11 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride float *ref_dis_filt; float *tmpbuf; - const float *filter; + float *num_array; + float *den_array; + + /* Filters will never be larger than 128 elements */ + float filter[128]; int filter_width; /* Offset pointers to adjust for convolution border handling. */ @@ -73,6 +78,8 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride float *ref_sq_filt_adj; float *dis_sq_filt_adj; float *ref_dis_filt_adj = 0; + float *num_array_adj = 0; + float *den_array_adj = 0; #endif /* Special handling of first scale. 
*/ @@ -84,74 +91,58 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride int buf_stride = ALIGN_CEIL(w * sizeof(float)); size_t buf_sz_one = (size_t)buf_stride * h; - int scale; + float num = 0; + float den = 0; + int ret = 1; - int kernelscale_index = -1; - if (ALMOST_EQUAL(vif_kernelscale, 1.0)) { - kernelscale_index = vif_kernelscale_1; - } else if (ALMOST_EQUAL(vif_kernelscale, 1.0/2)) { - kernelscale_index = vif_kernelscale_1o2; - } else if (ALMOST_EQUAL(vif_kernelscale, 3.0/2)) { - kernelscale_index = vif_kernelscale_3o2; - } else if (ALMOST_EQUAL(vif_kernelscale, 2.0)) { - kernelscale_index = vif_kernelscale_2; - } else if (ALMOST_EQUAL(vif_kernelscale, 2.0/3)) { - kernelscale_index = vif_kernelscale_2o3; - } else if (ALMOST_EQUAL(vif_kernelscale, 2.4/1.0)) { - kernelscale_index = vif_kernelscale_24o10; - } else if (ALMOST_EQUAL(vif_kernelscale, 360/97.0)) { - kernelscale_index = vif_kernelscale_360o97; - } else if (ALMOST_EQUAL(vif_kernelscale, 4.0/3.0)) { - kernelscale_index = vif_kernelscale_4o3; - } else if (ALMOST_EQUAL(vif_kernelscale, 3.5/3.0)) { - kernelscale_index = vif_kernelscale_3d5o3; - } else if (ALMOST_EQUAL(vif_kernelscale, 3.75/3.0)) { - kernelscale_index = vif_kernelscale_3d75o3; - } else if (ALMOST_EQUAL(vif_kernelscale, 4.25/3.0)) { - kernelscale_index = vif_kernelscale_4d25o3; - } else { - printf("error: vif_kernelscale can only be 0.5, 1.0, 1.5, 2.0, 2.0/3, 2.4, 360/97, 4.0/3.0, 3.5/3.0, 3.75/3.0, 4.25/3.0 for now, but is %f\n", vif_kernelscale); - fflush(stdout); + if (!vif_validate_kernelscale(vif_kernelscale)) { + vmaf_log(VMAF_LOG_LEVEL_ERROR, "invalid vif_kernelscale: %f", vif_kernelscale); goto fail_or_end; } - // Code optimized to save on multiple buffer copies - // hence the reduction in the number of buffers required from 15 to 8 -#define VIF_BUF_CNT 8 - if (SIZE_MAX / buf_sz_one < VIF_BUF_CNT) - { - printf("error: SIZE_MAX / buf_sz_one < VIF_BUF_CNT, buf_sz_one = %zu.\n", buf_sz_one); - fflush(stdout); - goto fail_or_end; - } - - if (!(data_buf = aligned_malloc(buf_sz_one * VIF_BUF_CNT, MAX_ALIGN))) - { - printf("error: aligned_malloc failed for data_buf.\n"); - fflush(stdout); - goto fail_or_end; + // Code optimized to save on multiple buffer copies + // hence the reduction in the number of buffers required from 15 to 10 +#define VIF_BUF_CNT 10 + if (SIZE_MAX / buf_sz_one < VIF_BUF_CNT) + { + printf("error: SIZE_MAX / buf_sz_one < VIF_BUF_CNT, buf_sz_one = %zu.\n", buf_sz_one); + fflush(stdout); + goto fail_or_end; + } + + if (!(data_buf = aligned_malloc(buf_sz_one * VIF_BUF_CNT, MAX_ALIGN))) + { + printf("error: aligned_malloc failed for data_buf.\n"); + fflush(stdout); + goto fail_or_end; + } + + data_top = (char *)data_buf; + + ref_scale = (float *)data_top; data_top += buf_sz_one; + dis_scale = (float *)data_top; data_top += buf_sz_one; + mu1 = (float *)data_top; data_top += buf_sz_one; + mu2 = (float *)data_top; data_top += buf_sz_one; + ref_sq_filt = (float *)data_top; data_top += buf_sz_one; + dis_sq_filt = (float *)data_top; data_top += buf_sz_one; + ref_dis_filt = (float *)data_top; data_top += buf_sz_one; + num_array = (float *)data_top; data_top += buf_sz_one; + den_array = (float *)data_top; data_top += buf_sz_one; + tmpbuf = (float *)data_top; data_top += buf_sz_one; + + unsigned scale_start = 0; + if (vif_skip_scale0) { + scale_start = 1; } - data_top = (char *)data_buf; - - ref_scale = (float *)data_top; data_top += buf_sz_one; - dis_scale = (float *)data_top; data_top += buf_sz_one; - mu1 = (float 
*)data_top; data_top += buf_sz_one;
-    mu2 = (float *)data_top; data_top += buf_sz_one;
-    ref_sq_filt = (float *)data_top; data_top += buf_sz_one;
-    dis_sq_filt = (float *)data_top; data_top += buf_sz_one;
-    ref_dis_filt = (float *)data_top; data_top += buf_sz_one;
-    tmpbuf = (float *)data_top; data_top += buf_sz_one;
-
-    for (scale = 0; scale < 4; ++scale)
+    for (unsigned scale = scale_start; scale < 4; ++scale)
     {
 #ifdef VIF_OPT_DEBUG_DUMP
         char pathbuf[256];
 #endif
-
-        filter = vif_filter1d_table_s[kernelscale_index][scale];
-        filter_width = vif_filter1d_width[kernelscale_index][scale];
+        filter_width = vif_get_filter_size(scale, vif_kernelscale);
+        vif_get_filter(filter, scale, vif_kernelscale);

 #ifdef VIF_OPT_HANDLE_BORDERS
     int buf_valid_w = w;
@@ -196,15 +187,14 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride
         vif_filter1d_s(filter, curr_ref_scale, mu1, tmpbuf, w, h, curr_ref_stride, buf_stride, filter_width);
         vif_filter1d_s(filter, curr_dis_scale, mu2, tmpbuf, w, h, curr_dis_stride, buf_stride, filter_width);

-        // Code optimized by adding intrinsic code for the functions,
-        // vif_filter1d_sq and vif_filter1d_sq
+        // Code optimized by adding intrinsic code for the functions
+        // vif_filter1d_sq_s and vif_filter1d_xy_s
         vif_filter1d_sq_s(filter, curr_ref_scale, ref_sq_filt, tmpbuf, w, h, curr_ref_stride, buf_stride, filter_width);
         vif_filter1d_sq_s(filter, curr_dis_scale, dis_sq_filt, tmpbuf, w, h, curr_dis_stride, buf_stride, filter_width);
         vif_filter1d_xy_s(filter, curr_ref_scale, curr_dis_scale, ref_dis_filt, tmpbuf, w, h, curr_ref_stride, curr_dis_stride, buf_stride, filter_width);

-        float num, den;
-        vif_statistic_s(mu1, mu2, ref_sq_filt, dis_sq_filt, ref_dis_filt, &num, &den,
-                        w, h, buf_stride, buf_stride, buf_stride, buf_stride, buf_stride, vif_enhn_gain_limit);
+        vif_statistic_s(mu1, mu2, ref_sq_filt, dis_sq_filt, ref_dis_filt, num_array, den_array,
+                        w, h, buf_stride, buf_stride, buf_stride, buf_stride, buf_stride, vif_enhn_gain_limit, vif_sigma_nsq);

         mu1_adj = ADJUST(mu1);
         mu2_adj = ADJUST(mu2);
@@ -237,8 +227,17 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride
         sprintf(pathbuf, "stage/ref_dis_filt[%d].bin", scale);
         write_image(pathbuf, ref_dis_filt_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));
+
+        sprintf(pathbuf, "stage/num_array[%d].bin", scale);
+        write_image(pathbuf, num_array_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));
+
+        sprintf(pathbuf, "stage/den_array[%d].bin", scale);
+        write_image(pathbuf, den_array_adj, buf_valid_w, buf_valid_h, buf_stride, sizeof(float));
 #endif

+        num = *num_array;
+        den = *den_array;
+
         scores[2*scale] = num;
         scores[2*scale+1] = den;

@@ -250,7 +249,7 @@ int compute_vif(const float *ref, const float *dis, int w, int h, int ref_stride
     *score_num = 0.0;
     *score_den = 0.0;

-    for (scale = 0; scale < 4; ++scale)
+    for (unsigned scale = scale_start; scale < 4; ++scale)
     {
         *score_num += scores[2*scale];
         *score_den += scores[2*scale+1];
@@ -368,8 +367,8 @@ int vifdiff(int (*read_frame)(float *ref_data, float *main_data, float *temp_dat
         if (frm_idx > 0)
         {
             apply_frame_differencing(ref_buf, prev_ref_buf, ref_diff_buf, w, h, stride / sizeof(float));
-            apply_frame_differencing(dis_buf, prev_dis_buf, dis_diff_buf, w, h, stride / sizeof(float));
-        }
+            apply_frame_differencing(dis_buf, prev_dis_buf, dis_diff_buf, w, h, stride / sizeof(float));
+        }
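For readers unfamiliar with the differencing step used above: apply_frame_differencing() (declared earlier in vif.c) presumably reduces to a per-pixel subtraction of the previous frame from the current one, which is what vifdiff feeds into compute_vif(). The sketch below is written under that assumption and is illustrative only, not the patch's implementation:

/* Illustrative sketch, assuming apply_frame_differencing() is a plain
 * per-pixel subtraction; stride is in floats, since the call sites pass
 * stride / sizeof(float). */
static void frame_differencing_sketch(const float *current_frame,
                                      const float *previous_frame,
                                      float *frame_difference,
                                      int w, int h, int stride)
{
    for (int i = 0; i < h; i++) {
        for (int j = 0; j < w; j++) {
            frame_difference[i * stride + j] =
                current_frame[i * stride + j] - previous_frame[i * stride + j];
        }
    }
}

        // copy the current frame to the previous frame buffer to have it available for next time you apply frame differencing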
memcpy(prev_ref_buf, ref_buf, data_sz); @@ -381,22 +380,22 @@ int vifdiff(int (*read_frame)(float *ref_data, float *main_data, float *temp_dat // unreliable scores for an earlier video frame, rather than the latest one. This might be better for video quality calculations, since recency effects // places more weight on later frames. if (frm_idx == 0) - { - score = 0.0; - score_num = 0.0; - score_den = 0.0; - for(int scale = 0; scale < 4; scale++){ - scores[2 * scale] = 0.0; - scores[2 * scale + 1] = 0.0 + 1e-5; - } - } - else - { + { + score = 0.0; + score_num = 0.0; + score_den = 0.0; + for(int scale = 0; scale < 4; scale++){ + scores[2 * scale] = 0.0; + scores[2 * scale + 1] = 0.0 + 1e-5; + } + } + else + { // compute if ((ret = compute_vif(ref_diff_buf, dis_diff_buf, w, h, stride, stride, &score, &score_num, &score_den, scores, DEFAULT_VIF_ENHN_GAIN_LIMIT, - DEFAULT_VIF_KERNELSCALE))) + DEFAULT_VIF_KERNELSCALE, 0, 2.0))) { printf("error: compute_vifdiff failed.\n"); fflush(stdout); diff --git a/libvmaf/src/feature/vif_tools.c b/libvmaf/src/feature/vif_tools.c index 09a2c8122..1bbe0d4ff 100644 --- a/libvmaf/src/feature/vif_tools.c +++ b/libvmaf/src/feature/vif_tools.c @@ -16,14 +16,18 @@ * */ +#include +#include #include #include #include #include +#include #include #include "config.h" #include "cpu.h" +#include "log.h" #include "mem.h" #include "common/convolution.h" #include "vif_options.h" @@ -32,6 +36,10 @@ #define MIN(x, y) (((x) < (y)) ? (x) : (y)) #define MAX(x, y) (((x) > (y)) ? (x) : (y)) +#ifndef M_PI + #define M_PI 3.14159265358979323846 +#endif + #ifdef VIF_OPT_FAST_LOG2 // option to replace log2 calculation with faster speed static const float log2_poly_s[9] = { -0.012671635276421, 0.064841182402670, -0.157048836463065, 0.257167726303123, -0.353800560300520, 0.480131410397451, -0.721314327952201, 1.442694803896991, 0 }; @@ -82,159 +90,63 @@ static float log2f_approx(float x) #endif /* VIF_FAST_LOG2 */ -#if 1 //defined(_MSC_VER) -const float vif_filter1d_table_s[11][4][65] = { - { // kernelscale = 1.0 - {0.00745626912, 0.0142655009, 0.0250313189, 0.0402820669, 0.0594526194, 0.0804751068, 0.0999041125, 0.113746084, 0.118773937, 0.113746084, 0.0999041125, 0.0804751068, 0.0594526194, 0.0402820669, 0.0250313189, 0.0142655009, 0.00745626912}, - {0.0189780835, 0.0558981746, 0.120920904, 0.192116052, 0.224173605, 0.192116052, 0.120920904, 0.0558981746, 0.0189780835}, - {0.054488685, 0.244201347, 0.402619958, 0.244201347, 0.054488685}, - {0.166378498, 0.667243004, 0.166378498} - }, - { // kernelscale = 0.5 - {0.01483945381483146, 0.04981728920101665, 0.11832250618647212, 0.19882899654808195, 0.2363835084991954, 0.19882899654808195, 0.11832250618647212, 0.04981728920101665, 0.01483945381483146}, - {0.03765705331865383, 0.23993597758503457, 0.4448139381926232, 0.23993597758503457, 0.03765705331865383}, - {0.10650697891920077, 0.7869860421615985, 0.10650697891920077}, - {0.00383625879916893, 0.9923274824016621, 0.00383625879916893} - }, - { // kernelscale = 1.5 - {0.003061411879755733, 0.004950361473107714, 0.007702910451476288, 0.011533883865671363, 0.01661877803810386, 0.02304227676257057, 0.030743582802885815, 0.039471747635457924, 0.04876643642871142, 0.057977365173867444, 0.06632827707341536, 0.07301998612791924, 0.0773548527900372, 0.07885625899404015, 0.0773548527900372, 0.07301998612791924, 0.06632827707341536, 0.057977365173867444, 0.04876643642871142, 0.039471747635457924, 0.030743582802885815, 0.02304227676257057, 0.01661877803810386, 0.011533883865671363, 
0.007702910451476288, 0.004950361473107714, 0.003061411879755733}, - {0.005155279152239917, 0.012574282300781493, 0.026738695950137843, 0.049570492793231676, 0.08011839616706516, 0.11289306247174812, 0.13868460659463525, 0.14853036914032117, 0.13868460659463525, 0.11289306247174812, 0.08011839616706516, 0.049570492793231676, 0.026738695950137843, 0.012574282300781493, 0.005155279152239917}, - {0.007614419169296345, 0.03607496968918391, 0.10958608179781394, 0.2134445419434044, 0.26655997480060273, 0.2134445419434044, 0.10958608179781394, 0.03607496968918391, 0.007614419169296345}, - {0.03765705331865383, 0.23993597758503457, 0.4448139381926232, 0.23993597758503457, 0.03765705331865383} - }, - { // kernelscale = 2.0 - {0.0026037269587503567, 0.0037202012799425984, 0.005201699435890725, 0.007117572324831842, 0.009530733850673961, 0.012489027248675568, 0.016015433656466918, 0.02009817440675194, 0.024682113028122028, 0.02966305528852977, 0.03488648739727233, 0.04015192995666896, 0.045223422276224175, 0.04984576229541491, 0.05376515201297315, 0.05675202106130506, 0.058623211304833, 0.059260552433345506, 0.058623211304833, 0.05675202106130506, 0.05376515201297315, 0.04984576229541491, 0.045223422276224175, 0.04015192995666896, 0.03488648739727233, 0.02966305528852977, 0.024682113028122028, 0.02009817440675194, 0.016015433656466918, 0.012489027248675568, 0.009530733850673961, 0.007117572324831842, 0.005201699435890725, 0.0037202012799425984, 0.0026037269587503567}, - {0.004908790159284653, 0.009458290945313327, 0.01687098713840658, 0.027858513730912443, 0.04258582005051434, 0.06026452109898014, 0.07894925354042928, 0.09574673424578385, 0.10749531455964786, 0.11172354906145505, 0.10749531455964786, 0.09574673424578385, 0.07894925354042928, 0.06026452109898014, 0.04258582005051434, 0.027858513730912443, 0.01687098713840658, 0.009458290945313327, 0.004908790159284653}, - {0.008812229292562283, 0.027143577143479366, 0.06511405659938266, 0.12164907301380957, 0.17699835683135567, 0.2005654142388208, 0.17699835683135567, 0.12164907301380957, 0.06511405659938266, 0.027143577143479366, 0.008812229292562283}, - {0.014646255580395366, 0.08312087071417153, 0.23555925344404363, 0.3333472405227789, 0.23555925344404363, 0.08312087071417153, 0.014646255580395366} - }, - { // kernelscale = 2.0/3 - {0.005316919191158936, 0.015508616400966175, 0.03723543338250509, 0.07358853152438767, 0.11971104580936091, 0.16029821187968224, 0.17668248362387792, 0.16029821187968224, 0.11971104580936091, 0.07358853152438767, 0.03723543338250509, 0.015508616400966175, 0.005316919191158936}, - {0.014646255580395366, 0.08312087071417153, 0.23555925344404363, 0.3333472405227789, 0.23555925344404363, 0.08312087071417153, 0.014646255580395366}, - {0.006646032999923536, 0.1942255544092176, 0.5982568251817177, 0.1942255544092176, 0.006646032999923536}, - {0.04038789325328935, 0.9192242134934214, 0.04038789325328935} - }, - { // kernelscale = 2.4 - {0.0024545290205999935, 0.003289682353098358, 0.004343275824783345, 0.005648830035496486, 0.007237311348814264, 0.009134266042662214, 0.01135658376291313, 0.0139091122480005, 0.01678142232148741, 0.019945079588083118, 0.023351803735371327, 0.026932876628807743, 0.030600089953844747, 0.0342484022907008, 0.03776031258766088, 0.04101176848166929, 0.04387923676578804, 0.046247395983335465, 0.048016793525188006, 0.04911076259025178, 0.04948092982288614, 0.04911076259025178, 0.048016793525188006, 0.046247395983335465, 0.04387923676578804, 0.04101176848166929, 0.03776031258766088, 0.0342484022907008, 
0.030600089953844747, 0.026932876628807743, 0.023351803735371327, 0.019945079588083118, 0.01678142232148741, 0.0139091122480005, 0.01135658376291313, 0.009134266042662214, 0.007237311348814264, 0.005648830035496486, 0.004343275824783345, 0.003289682353098358, 0.0024545290205999935}, - {0.003637908247873508, 0.0063855489460743495, 0.010623647203380349, 0.016752435202957758, 0.02503866437045789, 0.03547098721241576, 0.047628214164962705, 0.0606155744031302, 0.07311947377207272, 0.08360086952227502, 0.09059775517499266, 0.09305784355881423, 0.09059775517499266, 0.08360086952227502, 0.07311947377207272, 0.0606155744031302, 0.047628214164962705, 0.03547098721241576, 0.02503866437045789, 0.016752435202957758, 0.010623647203380349, 0.0063855489460743495, 0.003637908247873508}, - {0.007350293016136075, 0.019098337505782634, 0.04171460426536077, 0.07659181203740582, 0.11821653158869994, 0.15338247260964522, 0.1672918979539392, 0.15338247260964522, 0.11821653158869994, 0.07659181203740582, 0.04171460426536077, 0.019098337505782634, 0.007350293016136075}, - {0.00585668412070982, 0.03167315245079435, 0.10575256915549137, 0.21799709606017553, 0.27744099642565795, 0.21799709606017553, 0.10575256915549137, 0.03167315245079435, 0.00585668412070982} - }, - { // kernelscale = 360.0 / 97.0 (1080p -> 291p) - {0.0012816791204082075, 0.0015620523909898277, 0.0018918399027484485, 0.0022769089575337843, 0.0027231994378439997, 0.003236575413413638, 0.0038226497837505957, 0.004486583874601567, 0.005232865467640872, 0.006065070407937307, 0.006985614620544574, 0.00799550497717726, 0.009094098875819743, 0.010278883514029925, 0.011545286536358305, 0.012886529914069893, 0.01429353848739285, 0.01575491351167377, 0.017256979780457923, 0.018783912474459312, 0.02031794687526095, 0.021839670601917226, 0.023328394235385276, 0.024762592283158524, 0.026120402622622926, 0.02738016907584346, 0.028521008836012274, 0.02952338429169555, 0.03036965754835163, 0.03104460574645926, 0.03153587618024757, 0.03183436222109175, 0.03193448406620358, 0.03183436222109175, 0.03153587618024757, 0.03104460574645926, 0.03036965754835163, 0.02952338429169555, 0.028521008836012274, 0.02738016907584346, 0.026120402622622926, 0.024762592283158524, 0.023328394235385276, 0.021839670601917226, 0.02031794687526095, 0.018783912474459312, 0.017256979780457923, 0.01575491351167377, 0.01429353848739285, 0.012886529914069893, 0.011545286536358305, 0.010278883514029925, 0.009094098875819743, 0.00799550497717726, 0.006985614620544574, 0.006065070407937307, 0.005232865467640872, 0.004486583874601567, 0.0038226497837505957, 0.003236575413413638, 0.0027231994378439997, 0.0022769089575337843, 0.0018918399027484485, 0.0015620523909898277, 0.0012816791204082075}, - {0.0023644174536255184, 0.0034221036550523224, 0.0048431811011422285, 0.006702499696259348, 0.009070086779378358, 0.012002027755765152, 0.015529817658257908, 0.01964927955626293, 0.02431058735737885, 0.029411204931348196, 0.03479354885462776, 0.04024882291916349, 0.0455277497133199, 0.05035791396527277, 0.0544662936498729, 0.05760450646430479, 0.05957357217624005, 0.060244772625455, 0.05957357217624005, 0.05760450646430479, 0.0544662936498729, 0.05035791396527277, 0.0455277497133199, 0.04024882291916349, 0.03479354885462776, 0.029411204931348196, 0.02431058735737885, 0.01964927955626293, 0.015529817658257908, 0.012002027755765152, 0.009070086779378358, 0.006702499696259348, 0.0048431811011422285, 0.0034221036550523224, 0.0023644174536255184}, - {0.005739705984167229, 0.010638831007972463, 
0.018338687970180532, 0.029397655644830867, 0.04382553457697562, 0.06075917001712261, 0.07833692675250763, 0.09392718641207363, 0.10473363656553764, 0.10860533013726342, 0.10473363656553764, 0.09392718641207363, 0.07833692675250763, 0.06075917001712261, 0.04382553457697562, 0.029397655644830867, 0.018338687970180532, 0.010638831007972463, 0.005739705984167229}, - {0.004765886490117418, 0.014449429659230515, 0.03580755131873149, 0.07252962766809999, 0.1200806914901692, 0.16249792490873866, 0.17973777692982543, 0.16249792490873866, 0.1200806914901692, 0.07252962766809999, 0.03580755131873149, 0.014449429659230515, 0.004765886490117418} - }, - { // kernelscale = 4.0 / 3.0 - {0.00468593450369258, 0.007810637942127317, 0.012400648305837857, 0.018752961660127826, 0.027012385060515166, 0.0370615509191398, 0.048434167336838044, 0.06029033231050093, 0.07148437238577607, 0.0807313343364795, 0.08684418500467246, 0.08898298046858494, 0.08684418500467246, 0.0807313343364795, 0.07148437238577607, 0.06029033231050093, 0.048434167336838044, 0.0370615509191398, 0.027012385060515166, 0.018752961660127826, 0.012400648305837857, 0.007810637942127317, 0.00468593450369258}, - {0.007350293016136075, 0.019098337505782634, 0.04171460426536077, 0.07659181203740582, 0.11821653158869994, 0.15338247260964522, 0.1672918979539392, 0.15338247260964522, 0.11821653158869994, 0.07659181203740582, 0.04171460426536077, 0.019098337505782634, 0.007350293016136075}, - {0.023977406661157635, 0.09784278911234541, 0.22749130044200694, 0.30137700756898, 0.22749130044200694, 0.09784278911234541, 0.023977406661157635}, - {0.021929644862389363, 0.22851214688447105, 0.49911641650627925, 0.22851214688447105, 0.021929644862389363} - }, - { // kernelscale = 3.5 / 3.0 - {0.004225481445163885, 0.0077284175478382604, 0.013264883597653383, 0.021365585324475772, 0.03229420764197194, 0.04580711688347506, 0.060973309616983565, 0.07616317964123727, 0.08927890406929252, 0.09820896199705818, 0.10137990446970037, 0.09820896199705818, 0.08927890406929252, 0.07616317964123727, 0.060973309616983565, 0.04580711688347506, 0.03229420764197194, 0.021365585324475772, 0.013264883597653383, 0.0077284175478382604, 0.004225481445163885}, - {0.011253064073270563, 0.03121967849226136, 0.06904092264126249, 0.12170411773975143, 0.17101117222357906, 0.19154208965975011, 0.17101117222357906, 0.12170411773975143, 0.06904092264126249, 0.03121967849226136, 0.011253064073270563}, - {0.012560200468474614, 0.07882796468173003, 0.23729607711717063, 0.34263151546524945, 0.23729607711717063, 0.07882796468173003, 0.012560200468474614}, - {0.009620056834605945, 0.20542369732245147, 0.5699124916858852, 0.20542369732245147, 0.009620056834605945} - }, - { // kernelscale = 3.75 / 3.0 - {0.003317211135862137, 0.0059324619144260895, 0.01003813004663814, 0.01607040018215851, 0.024342018059788077, 0.03488530174522498, 0.04730253352513053, 0.06068513681455791, 0.07366077493233217, 0.08459530222840861, 0.09192046741565756, 0.09450052399963076, 0.09192046741565756, 0.08459530222840861, 0.07366077493233217, 0.06068513681455791, 0.04730253352513053, 0.03488530174522498, 0.024342018059788077, 0.01607040018215851, 0.01003813004663814, 0.0059324619144260895, 0.003317211135862137}, - {0.0050830940167887065, 0.01506448350708729, 0.03664323313832898, 0.0731554625933815, 0.11987073595060249, 0.16121038612766267, 0.17794520933229682, 0.16121038612766267, 0.11987073595060249, 0.0731554625933815, 0.03664323313832898, 0.01506448350708729, 0.0050830940167887065}, - {0.017988208587689274, 
0.08909618039160763, 0.2326921801242316, 0.3204468617929429, 0.2326921801242316, 0.08909618039160763, 0.017988208587689274}, - {0.015199625365883403, 0.21875173294350592, 0.5320972833812214, 0.21875173294350592, 0.015199625365883403} - }, - { // kernelscale = 4.25 / 3.0 - {0.0037535214972137234, 0.006161856952968067, 0.009688687555946651, 0.014591465886006169, 0.021048130695628536, 0.02908096232945107, 0.03848439382348116, 0.04877992869222271, 0.059221350419380495, 0.06886460899713377, 0.07669984741736877, 0.08182265140547183, 0.08360518865545416, 0.08182265140547183, 0.07669984741736877, 0.06886460899713377, 0.059221350419380495, 0.04877992869222271, 0.03848439382348116, 0.02908096232945107, 0.021048130695628536, 0.014591465886006169, 0.009688687555946651, 0.006161856952968067, 0.0037535214972137234}, - {0.009923596815614544, 0.023121061729092417, 0.0461910237358901, 0.07912588025538068, 0.11622262324531331, 0.14637737178419374, 0.15807688486903043, 0.14637737178419374, 0.11622262324531331, 0.07912588025538068, 0.0461910237358901, 0.023121061729092417, 0.009923596815614544}, - {0.005235891862728509, 0.029948577627467506, 0.10407966050653025, 0.21976557644809397, 0.2819405871103595, 0.21976557644809397, 0.10407966050653025, 0.029948577627467506, 0.005235891862728509}, - {0.029519066107809168, 0.23537051469301665, 0.47022083839834844, 0.23537051469301665, 0.029519066107809168} - } -}; -#else -const float vif_filter1d_table_s[11][4][65] = { - { // kernelscale = 1.0 - { 0x1.e8a77p-8, 0x1.d373b2p-7, 0x1.9a1cf6p-6, 0x1.49fd9ep-5, 0x1.e7092ep-5, 0x1.49a044p-4, 0x1.99350ep-4, 0x1.d1e76ap-4, 0x1.e67f8p-4, 0x1.d1e76ap-4, 0x1.99350ep-4, 0x1.49a044p-4, 0x1.e7092ep-5, 0x1.49fd9ep-5, 0x1.9a1cf6p-6, 0x1.d373b2p-7, 0x1.e8a77p-8 }, - { 0x1.36efdap-6, 0x1.c9eaf8p-5, 0x1.ef4ac2p-4, 0x1.897424p-3, 0x1.cb1b88p-3, 0x1.897424p-3, 0x1.ef4ac2p-4, 0x1.c9eaf8p-5, 0x1.36efdap-6 }, - { 0x1.be5f0ep-5, 0x1.f41fd6p-3, 0x1.9c4868p-2, 0x1.f41fd6p-3, 0x1.be5f0ep-5 }, - { 0x1.54be4p-3, 0x1.55a0ep-1, 0x1.54be4p-3 } - }, - { // kernelscale = 0.5 - { 1.483945e-02, 4.981729e-02, 1.183225e-01, 1.988290e-01, 2.363835e-01, 1.988290e-01, 1.183225e-01, 4.981729e-02, 1.483945e-02 }, - { 3.765705e-02, 2.399360e-01, 4.448139e-01, 2.399360e-01, 3.765705e-02 }, - { 1.065070e-01, 7.869860e-01, 1.065070e-01 }, - { 3.836259e-03, 9.923275e-01, 3.836259e-03 } - }, - { // kernelscale = 1.5 - { 3.061412e-03, 4.950361e-03, 7.702910e-03, 1.153388e-02, 1.661878e-02, 2.304228e-02, 3.074358e-02, 3.947175e-02, 4.876644e-02, 5.797737e-02, 6.632828e-02, 7.301999e-02, 7.735485e-02, 7.885626e-02, 7.735485e-02, 7.301999e-02, 6.632828e-02, 5.797737e-02, 4.876644e-02, 3.947175e-02, 3.074358e-02, 2.304228e-02, 1.661878e-02, 1.153388e-02, 7.702910e-03, 4.950361e-03, 3.061412e-03 }, - { 5.155279e-03, 1.257428e-02, 2.673870e-02, 4.957049e-02, 8.011840e-02, 1.128931e-01, 1.386846e-01, 1.485304e-01, 1.386846e-01, 1.128931e-01, 8.011840e-02, 4.957049e-02, 2.673870e-02, 1.257428e-02, 5.155279e-03 }, - { 7.614419e-03, 3.607497e-02, 1.095861e-01, 2.134445e-01, 2.665600e-01, 2.134445e-01, 1.095861e-01, 3.607497e-02, 7.614419e-03 }, - { 3.765705e-02, 2.399360e-01, 4.448139e-01, 2.399360e-01, 3.765705e-02 } - }, - { // kernelscale = 2.0 - { 2.603727e-03, 3.720201e-03, 5.201699e-03, 7.117572e-03, 9.530734e-03, 1.248903e-02, 1.601543e-02, 2.009817e-02, 2.468211e-02, 2.966306e-02, 3.488649e-02, 4.015193e-02, 4.522342e-02, 4.984576e-02, 5.376515e-02, 5.675202e-02, 5.862321e-02, 5.926055e-02, 5.862321e-02, 5.675202e-02, 5.376515e-02, 4.984576e-02, 4.522342e-02, 4.015193e-02, 
3.488649e-02, 2.966306e-02, 2.468211e-02, 2.009817e-02, 1.601543e-02, 1.248903e-02, 9.530734e-03, 7.117572e-03, 5.201699e-03, 3.720201e-03, 2.603727e-03 }, - { 4.908790e-03, 9.458291e-03, 1.687099e-02, 2.785851e-02, 4.258582e-02, 6.026452e-02, 7.894925e-02, 9.574673e-02, 1.074953e-01, 1.117235e-01, 1.074953e-01, 9.574673e-02, 7.894925e-02, 6.026452e-02, 4.258582e-02, 2.785851e-02, 1.687099e-02, 9.458291e-03, 4.908790e-03 }, - { 8.812229e-03, 2.714358e-02, 6.511406e-02, 1.216491e-01, 1.769984e-01, 2.005654e-01, 1.769984e-01, 1.216491e-01, 6.511406e-02, 2.714358e-02, 8.812229e-03 }, - { 1.464626e-02, 8.312087e-02, 2.355593e-01, 3.333472e-01, 2.355593e-01, 8.312087e-02, 1.464626e-02 } - }, - { // kernelscale = 2.0/3 - { 5.316919e-03, 1.550862e-02, 3.723543e-02, 7.358853e-02, 1.197110e-01, 1.602982e-01, 1.766825e-01, 1.602982e-01, 1.197110e-01, 7.358853e-02, 3.723543e-02, 1.550862e-02, 5.316919e-03 }, - { 1.464626e-02, 8.312087e-02, 2.355593e-01, 3.333472e-01, 2.355593e-01, 8.312087e-02, 1.464626e-02 }, - { 6.646033e-03, 1.942256e-01, 5.982568e-01, 1.942256e-01, 6.646033e-03 }, - { 4.038789e-02, 9.192242e-01, 4.038789e-02 } - }, - { // kernelscale = 2.4 - { 2.454529e-03, 3.289682e-03, 4.343276e-03, 5.648830e-03, 7.237311e-03, 9.134266e-03, 1.135658e-02, 1.390911e-02, 1.678142e-02, 1.994508e-02, 2.335180e-02, 2.693288e-02, 3.060009e-02, 3.424840e-02, 3.776031e-02, 4.101177e-02, 4.387924e-02, 4.624740e-02, 4.801679e-02, 4.911076e-02, 4.948093e-02, 4.911076e-02, 4.801679e-02, 4.624740e-02, 4.387924e-02, 4.101177e-02, 3.776031e-02, 3.424840e-02, 3.060009e-02, 2.693288e-02, 2.335180e-02, 1.994508e-02, 1.678142e-02, 1.390911e-02, 1.135658e-02, 9.134266e-03, 7.237311e-03, 5.648830e-03, 4.343276e-03, 3.289682e-03, 2.454529e-03 }, - { 3.637908e-03, 6.385549e-03, 1.062365e-02, 1.675244e-02, 2.503866e-02, 3.547099e-02, 4.762821e-02, 6.061557e-02, 7.311947e-02, 8.360087e-02, 9.059776e-02, 9.305784e-02, 9.059776e-02, 8.360087e-02, 7.311947e-02, 6.061557e-02, 4.762821e-02, 3.547099e-02, 2.503866e-02, 1.675244e-02, 1.062365e-02, 6.385549e-03, 3.637908e-03 }, - { 7.350293e-03, 1.909834e-02, 4.171460e-02, 7.659181e-02, 1.182165e-01, 1.533825e-01, 1.672919e-01, 1.533825e-01, 1.182165e-01, 7.659181e-02, 4.171460e-02, 1.909834e-02, 7.350293e-03 }, - { 5.856684e-03, 3.167315e-02, 1.057526e-01, 2.179971e-01, 2.774410e-01, 2.179971e-01, 1.057526e-01, 3.167315e-02, 5.856684e-03 } - }, - { // kernelscale = 360.0 / 97.0 (1080p -> 291p) - { 1.281679e-03, 1.562052e-03, 1.891840e-03, 2.276909e-03, 2.723199e-03, 3.236575e-03, 3.822650e-03, 4.486584e-03, 5.232865e-03, 6.065070e-03, 6.985615e-03, 7.995505e-03, 9.094099e-03, 1.027888e-02, 1.154529e-02, 1.288653e-02, 1.429354e-02, 1.575491e-02, 1.725698e-02, 1.878391e-02, 2.031795e-02, 2.183967e-02, 2.332839e-02, 2.476259e-02, 2.612040e-02, 2.738017e-02, 2.852101e-02, 2.952338e-02, 3.036966e-02, 3.104461e-02, 3.153588e-02, 3.183436e-02, 3.193448e-02, 3.183436e-02, 3.153588e-02, 3.104461e-02, 3.036966e-02, 2.952338e-02, 2.852101e-02, 2.738017e-02, 2.612040e-02, 2.476259e-02, 2.332839e-02, 2.183967e-02, 2.031795e-02, 1.878391e-02, 1.725698e-02, 1.575491e-02, 1.429354e-02, 1.288653e-02, 1.154529e-02, 1.027888e-02, 9.094099e-03, 7.995505e-03, 6.985615e-03, 6.065070e-03, 5.232865e-03, 4.486584e-03, 3.822650e-03, 3.236575e-03, 2.723199e-03, 2.276909e-03, 1.891840e-03, 1.562052e-03, 1.281679e-03 }, - { 2.364417e-03, 3.422104e-03, 4.843181e-03, 6.702500e-03, 9.070087e-03, 1.200203e-02, 1.552982e-02, 1.964928e-02, 2.431059e-02, 2.941120e-02, 3.479355e-02, 4.024882e-02, 4.552775e-02, 
5.035791e-02, 5.446629e-02, 5.760451e-02, 5.957357e-02, 6.024477e-02, 5.957357e-02, 5.760451e-02, 5.446629e-02, 5.035791e-02, 4.552775e-02, 4.024882e-02, 3.479355e-02, 2.941120e-02, 2.431059e-02, 1.964928e-02, 1.552982e-02, 1.200203e-02, 9.070087e-03, 6.702500e-03, 4.843181e-03, 3.422104e-03, 2.364417e-03 }, - { 5.739706e-03, 1.063883e-02, 1.833869e-02, 2.939766e-02, 4.382553e-02, 6.075917e-02, 7.833693e-02, 9.392719e-02, 1.047336e-01, 1.086053e-01, 1.047336e-01, 9.392719e-02, 7.833693e-02, 6.075917e-02, 4.382553e-02, 2.939766e-02, 1.833869e-02, 1.063883e-02, 5.739706e-03 }, - { 4.765886e-03, 1.444943e-02, 3.580755e-02, 7.252963e-02, 1.200807e-01, 1.624979e-01, 1.797378e-01, 1.624979e-01, 1.200807e-01, 7.252963e-02, 3.580755e-02, 1.444943e-02, 4.765886e-03 } - }, - { // kernelscale = 4.0 / 3.0 - { 4.685935e-03, 7.810638e-03, 1.240065e-02, 1.875296e-02, 2.701239e-02, 3.706155e-02, 4.843417e-02, 6.029033e-02, 7.148437e-02, 8.073133e-02, 8.684419e-02, 8.898298e-02, 8.684419e-02, 8.073133e-02, 7.148437e-02, 6.029033e-02, 4.843417e-02, 3.706155e-02, 2.701239e-02, 1.875296e-02, 1.240065e-02, 7.810638e-03, 4.685935e-03 }, - { 7.350293e-03, 1.909834e-02, 4.171460e-02, 7.659181e-02, 1.182165e-01, 1.533825e-01, 1.672919e-01, 1.533825e-01, 1.182165e-01, 7.659181e-02, 4.171460e-02, 1.909834e-02, 7.350293e-03 }, - { 2.397741e-02, 9.784279e-02, 2.274913e-01, 3.013770e-01, 2.274913e-01, 9.784279e-02, 2.397741e-02 }, - { 2.192964e-02, 2.285121e-01, 4.991164e-01, 2.285121e-01, 2.192964e-02 } - }, - { // kernelscale = 3.5 / 3.0 - { 4.225481e-03, 7.728418e-03, 1.326488e-02, 2.136559e-02, 3.229421e-02, 4.580712e-02, 6.097331e-02, 7.616318e-02, 8.927890e-02, 9.820896e-02, 1.013799e-01, 9.820896e-02, 8.927890e-02, 7.616318e-02, 6.097331e-02, 4.580712e-02, 3.229421e-02, 2.136559e-02, 1.326488e-02, 7.728418e-03, 4.225481e-03 }, - { 1.125306e-02, 3.121968e-02, 6.904092e-02, 1.217041e-01, 1.710112e-01, 1.915421e-01, 1.710112e-01, 1.217041e-01, 6.904092e-02, 3.121968e-02, 1.125306e-02 }, - { 1.256020e-02, 7.882796e-02, 2.372961e-01, 3.426315e-01, 2.372961e-01, 7.882796e-02, 1.256020e-02 }, - { 9.620057e-03, 2.054237e-01, 5.699125e-01, 2.054237e-01, 9.620057e-03 } - }, - { // kernelscale = 3.75 / 3.0 - { 3.317211e-03, 5.932462e-03, 1.003813e-02, 1.607040e-02, 2.434202e-02, 3.488530e-02, 4.730253e-02, 6.068514e-02, 7.366077e-02, 8.459530e-02, 9.192047e-02, 9.450052e-02, 9.192047e-02, 8.459530e-02, 7.366077e-02, 6.068514e-02, 4.730253e-02, 3.488530e-02, 2.434202e-02, 1.607040e-02, 1.003813e-02, 5.932462e-03, 3.317211e-03 }, - { 5.083094e-03, 1.506448e-02, 3.664323e-02, 7.315546e-02, 1.198707e-01, 1.612104e-01, 1.779452e-01, 1.612104e-01, 1.198707e-01, 7.315546e-02, 3.664323e-02, 1.506448e-02, 5.083094e-03 }, - { 1.798821e-02, 8.909618e-02, 2.326922e-01, 3.204469e-01, 2.326922e-01, 8.909618e-02, 1.798821e-02 }, - { 1.519963e-02, 2.187517e-01, 5.320973e-01, 2.187517e-01, 1.519963e-02 } - }, - { // kernelscale = 4.25 / 3.0 - { 3.753521e-03, 6.161857e-03, 9.688688e-03, 1.459147e-02, 2.104813e-02, 2.908096e-02, 3.848439e-02, 4.877993e-02, 5.922135e-02, 6.886461e-02, 7.669985e-02, 8.182265e-02, 8.360519e-02, 8.182265e-02, 7.669985e-02, 6.886461e-02, 5.922135e-02, 4.877993e-02, 3.848439e-02, 2.908096e-02, 2.104813e-02, 1.459147e-02, 9.688688e-03, 6.161857e-03, 3.753521e-03 }, - { 9.923597e-03, 2.312106e-02, 4.619102e-02, 7.912588e-02, 1.162226e-01, 1.463774e-01, 1.580769e-01, 1.463774e-01, 1.162226e-01, 7.912588e-02, 4.619102e-02, 2.312106e-02, 9.923597e-03 }, - { 5.235892e-03, 2.994858e-02, 1.040797e-01, 2.197656e-01, 
2.819406e-01, 2.197656e-01, 1.040797e-01, 2.994858e-02, 5.235892e-03 },
-        { 2.951907e-02, 2.353705e-01, 4.702208e-01, 2.353705e-01, 2.951907e-02 }
+static int round_up_to_odd(float f) {
+    int ceiling = ceil(f);
+    if (ceiling % 2 == 0) {
+        return ceiling + 1;
+    }
+    else {
+        return ceiling;
+    }
+}
+
+static float get_gaussian_pdf(float x, float mean, float stdev) {
+    // value of the N(mean, stdev) density at x
+    float num = exp(-0.5 * (x - mean)/stdev * (x - mean)/stdev);
+    float den = stdev * sqrt(2 * M_PI);
+    return num / den;
+}
+
+static void get_1d_gaussian_kernel(float *out, int size, float stdev) {
+    assert(size % 2 == 1);
+
+    float sum = 0;
+    int k = (size - 1) / 2;
+    for (int i = 0; i < size; i++) {
+        int curr = i - k;
+        out[i] = get_gaussian_pdf(curr, 0, stdev);
+        sum += out[i];
+    }
+    for (int i = 0; i < size; i++) {
+        out[i] /= sum;
+    }
+}
+
+bool vif_validate_kernelscale(float kernelscale) {
+    for (int i = 0; i < NUM_KERNELSCALES; i++) {
+        if (fabsf(kernelscale - valid_kernelscales[i]) < 1e-3) {
+            return true;
+        }
-};
-#endif
+    }
+    return false;
+}
 
-const int vif_filter1d_width[11][4] = {
-    { 17, 9, 5, 3 },   // kernelscale = 1.0
-    { 9, 5, 3, 3 },    // kernelscale = 0.5
-    { 27, 15, 9, 5 },  // kernelscale = 1.5
-    { 35, 19, 11, 7 }, // kernelscale = 2.0
-    { 13, 7, 5, 3 },   // kernelscale = 2.0/3
-    { 41, 23, 13, 9 }, // kernelscale = 2.4/1.0
-    { 65, 35, 19, 13 }, // kernelscale = 360.0/97.0
-    { 23, 13, 7, 5 },  // kernelscale = 4.0/3.0
-    { 21, 11, 7, 5 },  // kernelscale = 3.5/3.0
-    { 23, 13, 7, 5 },  // kernelscale = 3.75/3.0
-    { 25, 13, 9, 5 }   // kernelscale = 4.25/3.0
-};
+int vif_get_filter_size(int scale, float kernelscale) {
+    assert(scale <= 4);
+
+    int n = (1 << (4 - scale)) + 1;
+    return MAX(round_up_to_odd(n * kernelscale), 3);
+}
+
+void vif_get_filter(float *out, int scale, float kernelscale) {
+    int window_size = vif_get_filter_size(scale, kernelscale);
+    get_1d_gaussian_kernel(out, window_size, window_size / 5.0f);
+}
+
+void speed_get_antialias_filter(float *out, int scale, float kernelscale) {
+    // replicates the sigma_trick logic: the antialias filter always has
+    // the size of the scale-1 filter, with a proportionally widened stdev
+    int window_size = vif_get_filter_size(1, kernelscale);
+    get_1d_gaussian_kernel(out, window_size, sqrt(scale) * window_size / 5.0f);
+}
 
 void vif_dec2_s(const float *src, float *dst, int src_w, int src_h, int src_stride, int dst_stride)
@@ -251,6 +163,21 @@ void vif_dec2_s(const float *src, float *dst, int src_w, int src_h, int src_stri
     }
 }
 
+void vif_dec16_s(const float *src, float *dst, int src_w, int src_h, int src_stride, int dst_stride)
+{
+    int src_px_stride = src_stride / sizeof(float); // src_stride is in bytes
+    int dst_px_stride = dst_stride / sizeof(float);
+
+    int i, j;
+
+    // decimation by 16 in each direction
+    for (i = 0; i < src_h / 16; ++i) {
+        for (j = 0; j < src_w / 16; ++j) {
+            dst[i * dst_px_stride + j] = src[(i * 16) * src_px_stride + (j * 16)];
+        }
+    }
+}
+
 float vif_sum_s(const float *x, int w, int h, int stride)
 {
     int px_stride = stride / sizeof(float);
@@ -272,58 +199,57 @@ float vif_sum_s(const float *x, int w, int h, int stride)
 }
 
 void vif_statistic_s(const float *mu1, const float *mu2, const float *xx_filt, const float *yy_filt, const float *xy_filt, float *num, float *den,
-    int w, int h, int mu1_stride, int mu2_stride, int xx_filt_stride, int yy_filt_stride, int xy_filt_stride,
-    double vif_enhn_gain_limit)
+    int w, int h, int mu1_stride, int mu2_stride, int xx_filt_stride, int yy_filt_stride, int xy_filt_stride,
+    double vif_enhn_gain_limit, double vif_sigma_nsq)
 {
-    static const float sigma_nsq = 2;
- 
static const float sigma_max_inv = 4.0 / (255.0*255.0); - - int mu1_px_stride = mu1_stride / sizeof(float); - int mu2_px_stride = mu2_stride / sizeof(float); - int xx_filt_px_stride = xx_filt_stride / sizeof(float); - int yy_filt_px_stride = yy_filt_stride / sizeof(float); - int xy_filt_px_stride = xy_filt_stride / sizeof(float); - - float mu1_sq_val, mu2_sq_val, mu1_mu2_val, xx_filt_val, yy_filt_val, xy_filt_val; - float sigma1_sq, sigma2_sq, sigma12; - float num_val, den_val; - int i, j; + const float sigma_max_inv = powf(vif_sigma_nsq, 2.0f) / (255.0*255.0); + + int mu1_px_stride = mu1_stride / sizeof(float); + int mu2_px_stride = mu2_stride / sizeof(float); + int xx_filt_px_stride = xx_filt_stride / sizeof(float); + int yy_filt_px_stride = yy_filt_stride / sizeof(float); + int xy_filt_px_stride = xy_filt_stride / sizeof(float); + + float mu1_sq_val, mu2_sq_val, mu1_mu2_val, xx_filt_val, yy_filt_val, xy_filt_val; + float sigma1_sq, sigma2_sq, sigma12; + float num_val, den_val; + int i, j; /* ==== vif_stat_mode = 'matching_c' ==== */ // float num_log_den, num_log_num; /* ==== vif_stat_mode = 'matching_matlab' ==== */ - float g, sv_sq, eps = 1.0e-10f; - float vif_enhn_gain_limit_f = (float) vif_enhn_gain_limit; + float g, sv_sq, eps = 1.0e-10f; + float vif_enhn_gain_limit_f = (float) vif_enhn_gain_limit; /* == end of vif_stat_mode = 'matching_matlab' == */ - float accum_num = 0.0f; - float accum_den = 0.0f; - - for (i = 0; i < h; ++i) { - float accum_inner_num = 0; - float accum_inner_den = 0; - for (j = 0; j < w; ++j) { - float mu1_val = mu1[i * mu1_px_stride + j]; - float mu2_val = mu2[i * mu2_px_stride + j]; - mu1_sq_val = mu1_val * mu1_val; // same name as the Matlab code vifp_mscale.m - mu2_sq_val = mu2_val * mu2_val; - mu1_mu2_val = mu1_val * mu2_val; //mu1_mu2[i * mu1_mu2_px_stride + j]; - xx_filt_val = xx_filt[i * xx_filt_px_stride + j]; - yy_filt_val = yy_filt[i * yy_filt_px_stride + j]; - xy_filt_val = xy_filt[i * xy_filt_px_stride + j]; - - sigma1_sq = xx_filt_val - mu1_sq_val; - sigma2_sq = yy_filt_val - mu2_sq_val; - sigma12 = xy_filt_val - mu1_mu2_val; - - /* ==== vif_stat_mode = 'matching_c' ==== */ - - /* if (sigma1_sq < sigma_nsq) { + float accum_num = 0.0f; + float accum_den = 0.0f; + + for (i = 0; i < h; ++i) { + float accum_inner_num = 0; + float accum_inner_den = 0; + for (j = 0; j < w; ++j) { + float mu1_val = mu1[i * mu1_px_stride + j]; + float mu2_val = mu2[i * mu2_px_stride + j]; + mu1_sq_val = mu1_val * mu1_val; // same name as the Matlab code vifp_mscale.m + mu2_sq_val = mu2_val * mu2_val; + mu1_mu2_val = mu1_val * mu2_val; //mu1_mu2[i * mu1_mu2_px_stride + j]; + xx_filt_val = xx_filt[i * xx_filt_px_stride + j]; + yy_filt_val = yy_filt[i * yy_filt_px_stride + j]; + xy_filt_val = xy_filt[i * xy_filt_px_stride + j]; + + sigma1_sq = xx_filt_val - mu1_sq_val; + sigma2_sq = yy_filt_val - mu2_sq_val; + sigma12 = xy_filt_val - mu1_mu2_val; + + /* ==== vif_stat_mode = 'matching_c' ==== */ + + /* if (sigma1_sq < vif_sigma_nsq) { num_val = 1.0 - sigma2_sq * sigma_max_inv; den_val = 1.0; } else { - num_log_num = (sigma2_sq + sigma_nsq) * sigma1_sq; + num_log_num = (sigma2_sq + vif_sigma_nsq) * sigma1_sq; if (sigma12 < 0) { num_val = 0.0; @@ -333,7 +259,7 @@ void vif_statistic_s(const float *mu1, const float *mu2, const float *xx_filt, c num_log_den = num_log_num - sigma12 * sigma12; num_val = log2f(num_log_num / num_log_den); } - den_val = log2f(1.0f + sigma1_sq / sigma_nsq); + den_val = log2f(1.0f + sigma1_sq / vif_sigma_nsq); } */ /* ==== vif_stat_mode = 
'matching_matlab' ==== */ @@ -341,36 +267,36 @@ void vif_statistic_s(const float *mu1, const float *mu2, const float *xx_filt, c sigma1_sq = MAX(sigma1_sq, 0.0f); sigma2_sq = MAX(sigma2_sq, 0.0f); - g = sigma12 / (sigma1_sq + eps); - sv_sq = sigma2_sq - g * sigma12; + g = sigma12 / (sigma1_sq + eps); + sv_sq = sigma2_sq - g * sigma12; - if (sigma1_sq < eps) { - g = 0.0f; + if (sigma1_sq < eps) { + g = 0.0f; sv_sq = sigma2_sq; sigma1_sq = 0.0f; - } + } - if (sigma2_sq < eps) { - g = 0.0f; - sv_sq = 0.0f; - } + if (sigma2_sq < eps) { + g = 0.0f; + sv_sq = 0.0f; + } - if (g < 0.0f) { - sv_sq = sigma2_sq; - g = 0.0f; - } - sv_sq = MAX(sv_sq, eps); + if (g < 0.0f) { + sv_sq = sigma2_sq; + g = 0.0f; + } + sv_sq = MAX(sv_sq, eps); g = MIN(g, vif_enhn_gain_limit_f); - num_val = log2f(1.0f + (g * g * sigma1_sq) / (sv_sq + sigma_nsq)); - den_val = log2f(1.0f + (sigma1_sq) / (sigma_nsq)); + num_val = log2f(1.0f + (g * g * sigma1_sq) / (sv_sq + vif_sigma_nsq)); + den_val = log2f(1.0f + (sigma1_sq) / (vif_sigma_nsq)); if (sigma12 < 0.0f) { num_val = 0.0f; } - if (sigma1_sq < sigma_nsq) { + if (sigma1_sq < vif_sigma_nsq) { num_val = 1.0f - sigma2_sq * sigma_max_inv; den_val = 1.0f; } @@ -378,14 +304,14 @@ void vif_statistic_s(const float *mu1, const float *mu2, const float *xx_filt, c /* == end of vif_stat_mode = 'matching_matlab' == */ accum_inner_num += num_val; - accum_inner_den += den_val; - } - - accum_num += accum_inner_num; - accum_den += accum_inner_den; - } - *num = accum_num; - *den = accum_den; + accum_inner_den += den_val; + } + + accum_num += accum_inner_num; + accum_den += accum_inner_den; + } + num[0] = accum_num; + den[0] = accum_den; } void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, int w, int h, int src_stride, int dst_stride, int fwidth) @@ -398,9 +324,8 @@ void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, #if ARCH_X86 const unsigned flags = vmaf_get_cpu_flags(); - if ((flags & VMAF_X86_CPU_FLAG_AVX2) && (fwidth == 17 || fwidth == 9 || fwidth == 5 || fwidth == 3)) { - convolution_f32_avx_s(f, fwidth, src, dst, tmpbuf, w, h, - src_px_stride, dst_px_stride); + if ((flags & VMAF_X86_CPU_FLAG_AVX2) && fwidth <= MAX_FWIDTH_AVX_CONV) { + convolution_f32_avx_s(f, fwidth, src, dst, tmpbuf, w, h, src_px_stride, dst_px_stride); return; } #endif @@ -421,7 +346,7 @@ void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, fcoeff = f[fi]; ii = i - fwidth / 2 + fi; - ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 1 : ii); + ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 2 : ii); imgcoeff = src[ii * src_px_stride + j]; @@ -439,7 +364,7 @@ void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, fcoeff = f[fj]; jj = j - fwidth / 2 + fj; - jj = jj < 0 ? -jj : (jj >= w ? 2 * w - jj - 1 : jj); + jj = jj < 0 ? -jj : (jj >= w ? 
2 * w - jj - 2 : jj); imgcoeff = tmp[jj]; @@ -459,131 +384,347 @@ void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, void vif_filter1d_sq_s(const float *f, const float *src, float *dst, float *tmpbuf, int w, int h, int src_stride, int dst_stride, int fwidth) { - int src_px_stride = src_stride / sizeof(float); - int dst_px_stride = dst_stride / sizeof(float); + int src_px_stride = src_stride / sizeof(float); + int dst_px_stride = dst_stride / sizeof(float); - /* if support avx */ - + /* if support avx */ + #if ARCH_X86 const unsigned flags = vmaf_get_cpu_flags(); - if ((flags & VMAF_X86_CPU_FLAG_AVX2) && (fwidth == 17 || fwidth == 9 || fwidth == 5 || fwidth == 3)) { - convolution_f32_avx_sq_s(f, fwidth, src, dst, tmpbuf, w, h, - src_px_stride, dst_px_stride); + if ((flags & VMAF_X86_CPU_FLAG_AVX2) && fwidth <= MAX_FWIDTH_AVX_CONV) { + convolution_f32_avx_sq_s(f, fwidth, src, dst, tmpbuf, w, h, src_px_stride, dst_px_stride); return; } #endif - /* fall back */ + /* fall back */ - float *tmp = aligned_malloc(ALIGN_CEIL(w * sizeof(float)), MAX_ALIGN); - float fcoeff, imgcoeff; + float *tmp = aligned_malloc(ALIGN_CEIL(w * sizeof(float)), MAX_ALIGN); + float fcoeff, imgcoeff; - int i, j, fi, fj, ii, jj; + int i, j, fi, fj, ii, jj; - for (i = 0; i < h; ++i) { - /* Vertical pass. */ - for (j = 0; j < w; ++j) { - float accum = 0; + for (i = 0; i < h; ++i) { + /* Vertical pass. */ + for (j = 0; j < w; ++j) { + float accum = 0; - for (fi = 0; fi < fwidth; ++fi) { - fcoeff = f[fi]; + for (fi = 0; fi < fwidth; ++fi) { + fcoeff = f[fi]; - ii = i - fwidth / 2 + fi; - ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 1 : ii); + ii = i - fwidth / 2 + fi; + ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 2 : ii); - imgcoeff = src[ii * src_px_stride + j]; + imgcoeff = src[ii * src_px_stride + j]; - accum += fcoeff * (imgcoeff * imgcoeff); - } + accum += fcoeff * (imgcoeff * imgcoeff); + } - tmp[j] = accum; - } + tmp[j] = accum; + } - /* Horizontal pass. */ - for (j = 0; j < w; ++j) { - float accum = 0; + /* Horizontal pass. */ + for (j = 0; j < w; ++j) { + float accum = 0; - for (fj = 0; fj < fwidth; ++fj) { - fcoeff = f[fj]; + for (fj = 0; fj < fwidth; ++fj) { + fcoeff = f[fj]; - jj = j - fwidth / 2 + fj; - jj = jj < 0 ? -jj : (jj >= w ? 2 * w - jj - 1 : jj); + jj = j - fwidth / 2 + fj; + jj = jj < 0 ? -jj : (jj >= w ? 
2 * w - jj - 2 : jj); - imgcoeff = tmp[jj]; + imgcoeff = tmp[jj]; - accum += fcoeff * imgcoeff; - } + accum += fcoeff * imgcoeff; + } - dst[i * dst_px_stride + j] = accum; - } - } + dst[i * dst_px_stride + j] = accum; + } + } - aligned_free(tmp); + aligned_free(tmp); } void vif_filter1d_xy_s(const float *f, const float *src1, const float *src2, float *dst, float *tmpbuf, int w, int h, int src1_stride, int src2_stride, int dst_stride, int fwidth) { - int src1_px_stride = src1_stride / sizeof(float); - int src2_px_stride = src2_stride / sizeof(float); - int dst_px_stride = dst_stride / sizeof(float); + int src1_px_stride = src1_stride / sizeof(float); + int src2_px_stride = src2_stride / sizeof(float); + int dst_px_stride = dst_stride / sizeof(float); - /* if support avx */ + /* if support avx */ #if ARCH_X86 const unsigned flags = vmaf_get_cpu_flags(); - if ((flags & VMAF_X86_CPU_FLAG_AVX2) && (fwidth == 17 || fwidth == 9 || fwidth == 5 || fwidth == 3)) { - convolution_f32_avx_xy_s(f, fwidth, src1, src2, dst, tmpbuf, w, h, - src1_px_stride, src2_px_stride, dst_px_stride); + if ((flags & VMAF_X86_CPU_FLAG_AVX2) && fwidth <= MAX_FWIDTH_AVX_CONV) { + convolution_f32_avx_xy_s(f, fwidth, src1, src2, dst, tmpbuf, w, h, src1_px_stride, src2_px_stride, dst_px_stride); return; } #endif - /* fall back */ + /* fall back */ - float *tmp = aligned_malloc(ALIGN_CEIL(w * sizeof(float)), MAX_ALIGN); - float fcoeff, imgcoeff, imgcoeff1, imgcoeff2; + float *tmp = aligned_malloc(ALIGN_CEIL(w * sizeof(float)), MAX_ALIGN); + float fcoeff, imgcoeff, imgcoeff1, imgcoeff2; - int i, j, fi, fj, ii, jj; + int i, j, fi, fj, ii, jj; - for (i = 0; i < h; ++i) { - /* Vertical pass. */ - for (j = 0; j < w; ++j) { - float accum = 0; + for (i = 0; i < h; ++i) { + /* Vertical pass. */ + for (j = 0; j < w; ++j) { + float accum = 0; - for (fi = 0; fi < fwidth; ++fi) { - fcoeff = f[fi]; + for (fi = 0; fi < fwidth; ++fi) { + fcoeff = f[fi]; - ii = i - fwidth / 2 + fi; - ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 1 : ii); + ii = i - fwidth / 2 + fi; + ii = ii < 0 ? -ii : (ii >= h ? 2 * h - ii - 2 : ii); - imgcoeff1 = src1[ii * src1_px_stride + j]; - imgcoeff2 = src2[ii * src2_px_stride + j]; + imgcoeff1 = src1[ii * src1_px_stride + j]; + imgcoeff2 = src2[ii * src2_px_stride + j]; - accum += fcoeff * (imgcoeff1 * imgcoeff2); - } + accum += fcoeff * (imgcoeff1 * imgcoeff2); + } - tmp[j] = accum; + tmp[j] = accum; + } + + /* Horizontal pass. */ + for (j = 0; j < w; ++j) { + float accum = 0; + + for (fj = 0; fj < fwidth; ++fj) { + fcoeff = f[fj]; + + jj = j - fwidth / 2 + fj; + jj = jj < 0 ? -jj : (jj >= w ? 2 * w - jj - 2 : jj); + + imgcoeff = tmp[jj]; + + accum += fcoeff * imgcoeff; + } + + dst[i * dst_px_stride + j] = accum; + } + } + + aligned_free(tmp); +} + +int vif_get_scaling_method(char *scaling_method_str, enum vif_scaling_method *scale_method) { + if (!strcmp(scaling_method_str, "nearest")) { + *scale_method = vif_scale_nearest; + } else if (!strcmp(scaling_method_str, "bilinear")) { + *scale_method = vif_scale_bilinear; + } else if (!strcmp(scaling_method_str, "bicubic")) { + *scale_method = vif_scale_bicubic; + } else if (!strcmp(scaling_method_str, "lanczos4")) { + *scale_method = vif_scale_lanczos4; + } else { + vmaf_log(VMAF_LOG_LEVEL_ERROR, "Invalid scale method %s. 
Supported scale methods: [nearest, bilinear, bicubic, lanczos4]\n", scaling_method_str); + return -EINVAL; + } + + return 0; +} + +static float bicubic_kernel(float t) { + float a = -0.75; + if (t < 0) { + t = -t; + } + if (t < 1) { + return ((a + 2) * t - (a + 3)) * t * t + 1; + } + if (t < 2) { + return (((t - 5) * t + 8) * t - 4) * a; + } + return 0; +} + +static float mirror(float i, float left, float right) { + return (i < left ? -i : i > right ? 2 * right - i : i); +} + +static float bicubic_interpolation(const float *src, int width, int height, int src_stride, float x, float y) { + int x0 = floor(x); + int y0 = floor(y); + + float dx = x - x0; + float dy = y - y0; + + float weights_x[4]; + float weights_y[4]; + + for (int i = -1; i <= 2; i++) { + weights_x[i + 1] = bicubic_kernel(i - dx); + weights_y[i + 1] = bicubic_kernel(i - dy); + } + + float interp_val = 0.0; + for (int j = -1; j <= 2; j++) { + for (int i = -1; i <= 2; i++) { + int x_index = mirror(x0 + i, 0, width - 1); + int y_index = mirror(y0 + j, 0, height - 1); + float weight = weights_x[i + 1] * weights_y[j + 1]; + interp_val += src[y_index * src_stride + x_index] * weight; } + } - /* Horizontal pass. */ - for (j = 0; j < w; ++j) { - float accum = 0; + return interp_val; +} - for (fj = 0; fj < fwidth; ++fj) { - fcoeff = f[fj]; +static void vif_scale_frame_bicubic_s(const float *src, float *dst, + int src_w, int src_h, int src_stride, + int dst_w, int dst_h, int dst_stride) { + // if the input and output sizes are the same + if (src_w == dst_w && src_h == dst_h) { + memcpy(dst, src, dst_stride * dst_h * sizeof(float)); + return; + } - jj = j - fwidth / 2 + fj; - jj = jj < 0 ? -jj : (jj >= w ? 2 * w - jj - 1 : jj); + float ratio_x = (float)src_w / dst_w; + float ratio_y = (float)src_h / dst_h; - imgcoeff = tmp[jj]; + for (int y = 0; y < dst_h; y++) { + float yy = (y + 0.5) * ratio_y - 0.5; + for (int x = 0; x < dst_w; x++) { + float xx = (x + 0.5) * ratio_x - 0.5; + dst[y * dst_stride + x] = bicubic_interpolation(src, src_w, src_h, src_stride, xx, yy); + } + } +} - accum += fcoeff * imgcoeff; - } +static float lanczos4_kernel(float x, float a) { + if (x == 0.0) return 1.0; + if (x > -a && x < a) { + return a * sin(M_PI * x) * sin(M_PI * x / a) / (M_PI * M_PI * x * x); + } + return 0.0; +} - dst[i * dst_px_stride + j] = accum; +static float lanczos4_interpolation(const float *src, int width, int height, int src_stride, float x, float y) { + int a = 4; + int x0 = floor(x); + int y0 = floor(y); + + float dx = x - x0; + float dy = y - y0; + + float value = 0.0; + float weight_sum = 0.0; + + float weights_x[9]; + float weights_y[9]; + for (int i = -a; i <= a; i++) { + weights_x[i + a] = lanczos4_kernel(i - dx, (float)a); + weights_y[i + a] = lanczos4_kernel(i - dy, (float)a); + } + + for (int iy = -a; iy <= a; iy++) { + for (int ix = -a; ix <= a; ix++) { + float weight = weights_x[ix + a] * weights_y[iy + a]; + weight_sum += weight; + + int x_index = mirror(x0 + ix, 0, width - 1); + int y_index = mirror(y0 + iy, 0, height - 1); + value += src[y_index * src_stride + x_index] * weight; } } + + return value / weight_sum; +} - aligned_free(tmp); +static void vif_scale_frame_lanczos4_s(const float *src, float *dst, + int src_w, int src_h, int src_stride, + int dst_w, int dst_h, int dst_stride) { + // if the input and output sizes are the same + if (src_w == dst_w && src_h == dst_h) { + memcpy(dst, src, dst_stride * dst_h * sizeof(float)); + return; + } + + float ratio_x = (float)src_w / dst_w; + float ratio_y = (float)src_h / dst_h; 
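+    // Note: the coordinate mapping used below, src = (dst + 0.5) * ratio - 0.5,
+    // is the same one used by the bicubic and bilinear paths; it aligns the
+    // pixel centers of the source and destination grids rather than their
+    // top-left corners (the nearest-neighbor path simply truncates). For
+    // example, at a 2x downscale (ratio = 2), dst x = 0 samples around
+    // src xx = 0.5, halfway between the first two source samples, and
+    // dst x = 1 samples around src xx = 2.5.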
+
+    for (int y = 0; y < dst_h; y++) {
+        float yy = (y + 0.5) * ratio_y - 0.5;
+        for (int x = 0; x < dst_w; x++) {
+            float xx = (x + 0.5) * ratio_x - 0.5;
+            dst[y * dst_stride + x] = lanczos4_interpolation(src, src_w, src_h, src_stride, xx, yy);
+        }
+    }
+}
+
+static float bilinear_interpolation(const float *src, int width, int height, int src_stride, float x, float y) {
+    int x1 = mirror(floor(x), 0, width - 1);
+    int x2 = mirror(ceil(x), 0, width - 1);
+    int y1 = mirror(floor(y), 0, height - 1);
+    int y2 = mirror(ceil(y), 0, height - 1);
+
+    float dx = x - x1;
+    float dy = y - y1;
+
+    return (
+        (1 - dy) * (1 - dx) * src[y1 * src_stride + x1] +
+        (1 - dy) * dx * src[y1 * src_stride + x2] +
+        dy * (1 - dx) * src[y2 * src_stride + x1] +
+        dy * dx * src[y2 * src_stride + x2]
+    );
+}
+
+static void vif_scale_frame_bilinear_s(const float *src, float *dst,
+                                       int src_w, int src_h, int src_stride,
+                                       int dst_w, int dst_h, int dst_stride) {
+    // if the input and output sizes are the same
+    if (src_w == dst_w && src_h == dst_h) {
+        memcpy(dst, src, dst_stride * dst_h * sizeof(float));
+        return;
+    }
+
+    float ratio_x = (float)src_w / dst_w;
+    float ratio_y = (float)src_h / dst_h;
+
+    for (int y = 0; y < dst_h; y++) {
+        float yy = (y + 0.5) * ratio_y - 0.5;
+        for (int x = 0; x < dst_w; x++) {
+            float xx = (x + 0.5) * ratio_x - 0.5;
+            dst[y * dst_stride + x] = bilinear_interpolation(src, src_w, src_h, src_stride, xx, yy);
+        }
+    }
+}
+
+static void vif_scale_frame_nearest_s(const float *src, float *dst,
+                                      int src_w, int src_h, int src_stride,
+                                      int dst_w, int dst_h, int dst_stride) {
+    // if the input and output sizes are the same
+    if (src_w == dst_w && src_h == dst_h) {
+        memcpy(dst, src, dst_stride * dst_h * sizeof(float));
+        return;
+    }
+
+    float ratio_x = (float)src_w / dst_w;
+    float ratio_y = (float)src_h / dst_h;
+
+    for (int y = 0; y < dst_h; y++) {
+        for (int x = 0; x < dst_w; x++) {
+            int rounded_y = (int)(y * ratio_y);
+            int rounded_x = (int)(x * ratio_x);
+            dst[y * dst_stride + x] = src[rounded_y * src_stride + rounded_x];
+        }
+    }
+}
+
+void vif_scale_frame_s(enum vif_scaling_method scale_method, const float *src, float *dst,
+                       int src_w, int src_h, int src_stride,
+                       int dst_w, int dst_h, int dst_stride) {
+    if (scale_method == vif_scale_nearest) {
+        vif_scale_frame_nearest_s(src, dst, src_w, src_h, src_stride, dst_w, dst_h, dst_stride);
+    } else if (scale_method == vif_scale_bilinear) {
+        vif_scale_frame_bilinear_s(src, dst, src_w, src_h, src_stride, dst_w, dst_h, dst_stride);
+    } else if (scale_method == vif_scale_bicubic) {
+        vif_scale_frame_bicubic_s(src, dst, src_w, src_h, src_stride, dst_w, dst_h, dst_stride);
+    } else if (scale_method == vif_scale_lanczos4) {
+        vif_scale_frame_lanczos4_s(src, dst, src_w, src_h, src_stride, dst_w, dst_h, dst_stride);
+    }
+}
diff --git a/libvmaf/src/feature/vif_tools.h b/libvmaf/src/feature/vif_tools.h
index 5566bee9f..8d5a721c1 100644
--- a/libvmaf/src/feature/vif_tools.h
+++ b/libvmaf/src/feature/vif_tools.h
@@ -21,30 +21,51 @@
 #ifndef VIF_TOOLS_H_
 #define VIF_TOOLS_H_
 
-enum vif_kernelscale_enum {
-    vif_kernelscale_1 = 0,
-    vif_kernelscale_1o2 = 1,
-    vif_kernelscale_3o2 = 2,
-    vif_kernelscale_2 = 3,
-    vif_kernelscale_2o3 = 4,
-    vif_kernelscale_24o10 = 5,
-    vif_kernelscale_360o97 = 6,
-    vif_kernelscale_4o3 = 7,
-    vif_kernelscale_3d5o3 = 8,
-    vif_kernelscale_3d75o3 = 9,
-    vif_kernelscale_4d25o3 = 10,
+#include <stdbool.h>
+
+enum vif_scaling_method {
+    vif_scale_nearest = 0,
+    vif_scale_bicubic = 1,
+    vif_scale_lanczos4 = 2,
+    vif_scale_bilinear = 3,
+};
+
+# define 
NUM_KERNELSCALES 21 + +static const float valid_kernelscales[NUM_KERNELSCALES] = { + 1.0, + 1.0f / 2.0f, + 3.0f / 2.0f, + 2.0f, + 2.0f / 3.0f, + 24.0f / 10.0f, + 360.0f / 97.0f, + 4.0f / 3.0f, + 3.5f / 3.0f, + 3.75f / 3.0f, + 4.25f / 3.0f, + 5.0f / 3.0f, + 3.0f, + 1.0f / 2.25f, + 1.4746f, + 1.54f, + 1.6f, + 1.06667f, + 0.711111f, + 0.740740f, + 1.111111f, }; -extern const float vif_filter1d_table_s[11][4][65]; // 4 is scale. since this is separable filter, filtering is 1d repeat horizontally and vertically -extern const int vif_filter1d_width[11][4]; /* s single precision, d double precision */ void vif_dec2_s(const float *src, float *dst, int src_w, int src_h, int src_stride, int dst_stride); // stride >= width, multiple of 16 or 32 typically +void vif_dec16_s(const float *src, float *dst, int src_w, int src_h, int src_stride, int dst_stride); // stride >= width, multiple of 16 or 32 typically + float vif_sum_s(const float *x, int w, int h, int stride); void vif_statistic_s(const float *mu1_sq, const float *mu2_sq, const float *xx_filt, const float *yy_filt, const float *xy_filt, float *num, float *den, - int w, int h, int mu1_sq_stride, int mu2_sq_stride, int xx_filt_stride, int yy_filt_stride, int xy_filt_stride, double vif_enhn_gain_limit); + int w, int h, int mu1_sq_stride, int mu2_sq_stride, int xx_filt_stride, int yy_filt_stride, int xy_filt_stride, double vif_enhn_gain_limit, double vif_sigma_nsq); void vif_filter1d_s(const float *f, const float *src, float *dst, float *tmpbuf, int w, int h, int src_stride, int dst_stride, int fwidth); @@ -52,4 +73,16 @@ void vif_filter1d_sq_s(const float *f, const float *src, float *dst, float *tmpb void vif_filter1d_xy_s(const float *f, const float *src1, const float *src2, float *dst, float *tmpbuf, int w, int h, int src1_stride, int src2_stride, int dst_stride, int fwidth); +int vif_get_scaling_method(char *scaling_method_str, enum vif_scaling_method *scale_method); + +void vif_scale_frame_s(enum vif_scaling_method scale_method, const float *src, float *dst, int src_w, int src_h, int src_stride, int dst_w, int dst_h, int dst_stride); + +int vif_get_filter_size(int scale, float kernelscale); + +void vif_get_filter(float *out, int scale, float kernelscale); + +void speed_get_antialias_filter(float *out, int scale, float kernelscale); + +bool vif_validate_kernelscale(float kernelscale); + #endif /* VIF_TOOLS_H_ */ From d670459999b4ede0974a30e0091c2ff87eccfff6 Mon Sep 17 00:00:00 2001 From: Kyle Swanson Date: Fri, 2 May 2025 11:55:37 -0700 Subject: [PATCH 2/2] libvmaf/feature: import speed feature extractor Co-authored-by: Nil Fons Miret --- libvmaf/src/feature/speed.c | 1566 +++++++++++++++++++++++++++++++++++ libvmaf/src/meson.build | 2 +- 2 files changed, 1567 insertions(+), 1 deletion(-) create mode 100644 libvmaf/src/feature/speed.c diff --git a/libvmaf/src/feature/speed.c b/libvmaf/src/feature/speed.c new file mode 100644 index 000000000..aafed9f88 --- /dev/null +++ b/libvmaf/src/feature/speed.c @@ -0,0 +1,1566 @@ +/** + * + * Copyright 2016-2025 Netflix, Inc. + * + * Licensed under the BSD+Patent License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSDplusPatent + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <assert.h>
+#include <errno.h>
+#include <math.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <string.h>
+
+#include "dict.h"
+#include "feature_collector.h"
+#include "feature_extractor.h"
+#include "feature_name.h"
+#include "log.h"
+#include "mem.h"
+#include "opt.h"
+#include "picture.h"
+#include "picture_copy.h"
+#include "vif_tools.h"
+
+typedef struct SpeedDimensions {
+    size_t original_height;
+    size_t original_width;
+    size_t scaled_height;
+    size_t scaled_width;
+    size_t alloc_height;
+    size_t alloc_width;
+    size_t operating_height;
+    size_t operating_width;
+    size_t block_size;
+    size_t truncated_width;
+    size_t truncated_height;
+    size_t num_blocks_horizontal;
+    size_t num_blocks_vertical;
+    size_t num_blocks;
+    size_t elements_in_block;
+    size_t submatrix_width;
+    size_t submatrix_height;
+} SpeedDimensions;
+
+typedef struct SpeedResultBuffers {
+    float *entropies;
+    float *variances;
+} SpeedResultBuffers;
+
+typedef struct SpeedBuffers {
+    float *independent_term;
+    float *linear_system_sol;
+    float *cov_mat;
+    float *eigenvalues;
+    float *tmp_buffer;
+} SpeedBuffers;
+
+// Everything that is passed in as a feature option and is needed for
+// SpEED computation
+typedef struct SpeedOptions {
+    double speed_kernelscale;
+    double speed_prescale;
+    char *speed_prescale_method;
+    double speed_sigma_nn;
+    double speed_nn_floor;
+    int speed_weight_var_mode;
+} SpeedOptions;
+
+// Everything that is needed to compute SpEED given a pair of float buffers
+// (ref, dis), except for what is provided in SpeedOptions
+typedef struct SpeedState {
+    SpeedDimensions dimensions;
+    SpeedResultBuffers ref_results;
+    SpeedResultBuffers dis_results;
+    SpeedBuffers buffers;
+    size_t float_stride;
+} SpeedState;
+
+#define DEFAULT_BLOCK_SIZE (5)
+#define NUM_SQUARE_BUFFERS (5)
+#define NUM_RECT_BUFFERS (1)
+#define NUM_FRAME_BUFFERS (2)
+#define NUM_SCALES (4)
+#define EIGENVALUE_EPS (1e-6)
+#define EIGENVALUE_MAX_ITERS (500)
+#define ALMOST_EQUAL(x,c) (fabs((x) - (c)) < 1.0e-3)
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+typedef struct Matrix {
+    int rows;
+    int cols;
+    float *data;
+} Matrix;
+
+static void matrix_init(Matrix *mat, int rows, int cols, float *buffer)
+{
+    mat->data = buffer;
+    mat->rows = rows;
+    mat->cols = cols;
+}
+
+static void matrix_transpose(Matrix *m)
+{
+    // only works for square matrices due to being an in-place operation
+    assert(m->rows == m->cols);
+    int size = m->rows;
+    for (int i = 0; i < size; i++) {
+        for (int j = 0; j < i; j++) {
+            float temp = m->data[i * size + j];
+            m->data[i * size + j] = m->data[j * size + i];
+            m->data[j * size + i] = temp;
+        }
+    }
+}
+
+static void matrix_zero(Matrix *m)
+{
+    for (int i = 0; i < m->rows; i++) {
+        for (int j = 0; j < m->cols; j++) {
+            m->data[i * m->cols + j] = 0;
+        }
+    }
+}
+
+static void matrix_identity(Matrix *m)
+{
+    for (int i = 0; i < m->rows; i++) {
+        for (int j = 0; j < m->cols; j++) {
+            m->data[i * m->cols + j] = (i == j ? 1 : 0);
+        }
+    }
+}
+
+static void matrix_copy(Matrix *dst, const Matrix *src)
+{
+    dst->rows = src->rows;
+    dst->cols = src->cols;
+    for (int i = 0; i < dst->rows; i++) {
+        for (int j = 0; j < dst->cols; j++) {
+            dst->data[i * dst->cols + j] = src->data[i * src->cols + j];
+        }
+    }
+}
+
+static void matrix_mul(Matrix *dst, const Matrix *x, const Matrix *y)
+{
+    assert(x->cols == y->rows);
+    matrix_zero(dst);
+    for (int i = 0; i < x->rows; i++) {
+        for (int k = 0; k < x->cols; k++) {
+            for (int j = 0; j < y->cols; j++) {
+                dst->data[i * dst->cols + j] +=
+                    x->data[i * x->cols + k] * y->data[k * y->cols + j];
+            }
+        }
+    }
+}
+
+static void matrix_minor(Matrix *mat, int d)
+{
+    // only works for square matrices
+    assert(mat->rows == mat->cols);
+    int size = mat->rows;
+    for (int i = 0; i < size; i++) {
+        for (int j = 0; j < size; j++) {
+            if (i < d || j < d) {
+                if (i == j)
+                    mat->data[i * size + j] = 1;
+                else
+                    mat->data[i * size + j] = 0;
+            }
+        }
+    }
+}
+
+static void matrix_identity_minus_v_vt(Matrix *dst, float *v)
+{
+    // dst = I - 2 v v^T (a Householder reflector)
+    // only works for square matrices
+    assert(dst->rows == dst->cols);
+    int size = dst->rows;
+    for (int i = 0; i < size; i++) {
+        for (int j = 0; j < size; j++) {
+            dst->data[i * size + j] = -2 * v[i] * v[j];
+        }
+    }
+
+    for (int i = 0; i < size; i++)
+        dst->data[i * size + i] += 1;
+}
+
+static float vector_norm(const float *x, int n)
+{
+    float sum = 0;
+    for (int i = 0; i < n; i++)
+        sum += x[i] * x[i];
+    return sqrt(sum);
+}
+
+static void vector_div(const float *x, float d, float *y, int n)
+{
+    for (int i = 0; i < n; i++)
+        y[i] = x[i] / d;
+}
+
+static void matrix_take_kth_column(const Matrix *mat, float *v, int k)
+{
+    // take the k-th column of mat and put it in v
+    for (int i = 0; i < mat->rows; i++)
+        v[i] = mat->data[i * mat->cols + k];
+}
+
+static int get_sign(float x)
+{
+    return (x >= 0 ? 1 : -1);
+}
+
+static float pythagoras(float x, float y)
+{
+    // hypotenuse length of the right triangle with sides x and y
+    return sqrt(x * x + y * y);
+}
+
+static float compute_column_norm(float *A, int col, int start_row, int size)
+{
+    // euclidean norm of the column vector A[start_row:size, col]
+    float norm = 0;
+    for (int i = start_row; i < size; i++)
+        norm += A[i * size + col] * A[i * size + col];
+    return sqrt(norm);
+}
+
+// Compute a Householder transformation (tau,v) of a vector
+// x so that P x = [ I - tau*v*v' ] x annihilates x(1:n-1)
+//
+// On output, v is normalized so that v[0] = 1. 
The 1 is +// not actually stored; instead v[0] = -sign(x[0])*||x|| so +// that: +// P x = v[0] * e_1 + +static float compute_householder_transform(float *A, int col, int start_row, + int size) +{ + if (size - start_row == 1) + return 0.0f; + float xnorm = compute_column_norm(A, col, start_row + 1, size); + if (xnorm == 0.0f) + return 0.0f; + float alpha = A[start_row * size + col]; + float beta = -get_sign(alpha) * pythagoras(alpha, xnorm); + float tau = (beta - alpha) / beta; + float s = alpha - beta; + if (s != 0.0f) { + for (int i = start_row; i < size; i++) + A[i * size + col] /= s; + A[start_row * size + col] = beta; + } + return tau; +} + +// x = tau_i + A * v +// NOTE: only the rows and columns [start:size] of the involved matrices and +// vectors are used +static void tridiagonal_multiply(float *A, float *v, float *x, float tau_i, + int start, int size) +{ + for (int i = start; i < size; i++) { + x[i] = 0.0f; + for (int j = start; j < size; j++) + x[i] += tau_i * A[i * size + j] * v[j]; + } +} + +// returns dot(x, v) +// NOTE: only the rows and columns [start:size] of the involved matrices and +// vectors are used +static float tridiagonal_dot_product(float *x, float *v, int start, int size) +{ + float res = 0; + for (int i = start; i < size; i++) + res += x[i] * v[i]; + return res; +} + +// x += alpha * v +// NOTE: only the rows and columns [start:size] of the involved matrices and +// vectors are used +static void tridiagonal_axpy(float *x, float *v, float alpha, int start, int size) +{ + for (int i = start; i < size; i++) + x[i] += alpha * v[i]; +} + +// A -= x * v' + v * x' +// NOTE: only the rows and columns [start:size] of the involved matrices and +// vectors are used +static void tridiagonal_syr2(float *A, float *x, float *v, int start, int size) +{ + for (int i = start; i < size; i++) { + for (int j = start; j < size; j++) { + A[i * size + j] -= x[i] * v[j] + v[i] * x[j]; + } + } +} + +// Adapted gsl_linalg_symmtd_decomp: +// https://github.com/ampl/gsl/blob/master/linalg/symmtd.c +static void convert_to_tridiagonal(float *A, int size, float *d, float *sd, + float *buffer) +{ + float *v = buffer; buffer += size; + float *x = buffer; buffer += size; + + // We apply N-2 Householder transformations to zero out the elements + // outside of the diagonal or subdiagonal of each of the N-2 first columns + for (int i = 0; i < size - 2; i++) { + // Compute the vector v of the Householder transform that + // annihilates A[i+2:size, i] + float tau_i = compute_householder_transform(A, i, i + 1, size); + + // Copy i'th subcolumn of a into v + for (int j = i + 1; j < size; j++) + v[j] = A[j * size + i]; + + if (tau_i != 0.0f) { + // Set the first element of the returned vector to 1, + // since compute_householder_transform does this implicitly + A[(i + 1) * size + i] = v[i + 1]; + v[i + 1] = 1.0f; + // All operations described here are applied only to the rows and + // columns in [i+1:size] + // x = tau_i * A * v + tridiagonal_multiply(A, v, x, tau_i, i + 1, size); + // x -= 0.5 * tau_i * dot(x, v) * v] + float xv = tridiagonal_dot_product(x, v, i + 1, size); + float alpha = -0.5 * tau_i * xv; + tridiagonal_axpy(x, v, alpha, i + 1, size); + // A = A - v * x' - x * v' + tridiagonal_syr2(A, x, v, i + 1, size); + } + } + + // Copy the diagonal and subdiagonal elements into d and sd, respectively + for (int i = 0; i < size; i++) + d[i] = A[i * size + i]; + for (int i = 0; i < size - 1; i++) + sd[i] = A[(i + 1) * size + i]; +} + +static void chop_small_elements(float *d, float *sd, int size) 
+{ + for (int i = 0; i < size - 1; i++) { + if (fabsf(sd[i]) < EIGENVALUE_EPS * (fabsf(d[i]) + fabsf(d[i + 1]))) + sd[i] = 0.0f; + } +} + +static float trailing_eigenvalue(float *d, float *sd, int n) +{ + float ta = d[n - 2]; + float tb = d[n - 1]; + float tab = sd[n - 2]; + float dt = (ta - tb) / 2.0; + + if (dt > 0) + return tb - tab * (tab / (dt + pythagoras(dt, tab))); + else if (dt == 0) + return tb - fabsf(tab); + else + return tb + tab * (tab / (-dt + pythagoras(dt, tab))); +} + +static void create_givens(const float a, const float b, float *c, float *s) +{ + if (b == 0) { + *c = 1; + *s = 0; + } else if (fabsf(b) > fabsf(a)) { + float t = -a / b; + float s1 = 1.0 / sqrt(1 + t * t); + *s = s1; + *c = s1 * t; + } else { + float t = -b / a; + float c1 = 1.0 / sqrt(1 + t * t); + *c = c1; + *s = c1 * t; + } +} + +static void qr_step(float *d, float *sd, int n) +{ + float mu = trailing_eigenvalue(d, sd, n); + if (EIGENVALUE_EPS * fabsf(mu) > fabsf(d[0]) + fabsf(sd[0])) + mu = 0; + + float x = d[0] - mu; + float z = sd[0]; + + float ak = 0; + float bk = 0; + float zk = 0; + + float ap = d[0]; + float bp = sd[0]; + + float aq = d[1]; + + if (n == 2) { + float c, s; + create_givens(x, z, &c, &s); + + float ap1 = c * (c * ap - s * bp) + s * (s * aq - c * bp); + float bp1 = c * (s * ap + c * bp) - s * (s * bp + c * aq); + float aq1 = s * (s * ap + c * bp) + c * (s * bp + c * aq); + + ak = ap1; + bk = bp1; + + ap = aq1; + + d[0] = ak; + sd[0] = bk; + d[1] = ap; + + return; + } + + float bq = sd[1]; + + for (int k = 0; k < n - 1; k++) { + float c, s; + create_givens(x, z, &c, &s); + + /* compute G' T G */ + float bk1 = c * bk - s * zk; + + float ap1 = c * (c * ap - s * bp) + s * (s * aq - c * bp); + float bp1 = c * (s * ap + c * bp) - s * (s * bp + c * aq); + float zp1 = -s * bq; + + float aq1 = s * (s * ap + c * bp) + c * (s * bp + c * aq); + float bq1 = c * bq; + + ak = ap1; + bk = bp1; + zk = zp1; + + ap = aq1; + bp = bq1; + + if (k < n - 2) + aq = d[k + 2]; + + if (k < n - 3) + bq = sd[k + 2]; + + d[k] = ak; + + if (k > 0) + sd[k - 1] = bk1; + + if (k < n - 2) { + sd[k + 1] = bp; + } + + x = bk; + z = zk; + } + + d[n - 1] = ap; + sd[n - 2] = bk; +} + +// Adapted gsl_eigen_symm: https://github.com/ampl/gsl/blob/master/eigen/symm.c +static void compute_eigenvalues_tridiagonal(float *d, float *sd, + float *eigenvalues, int size) +{ + // Initial pass to remove subdiagonal elements which are effectively zero + chop_small_elements(d, sd, size); + + // Progressively reduce the matrix until it is diagonal + int b = size - 1; + int iter = 0; + while (b > 0 && iter < EIGENVALUE_MAX_ITERS) { + if (sd[b - 1] == 0.0f) { + b--; + continue; + } + // Find the largest unreduced block (a, b) starting from b and working + // backwards + int a = b - 1; + while (a > 0) { + if (sd[a - 1] == 0.0f) { + break; + } + a--; + } + + const int n_block = b - a + 1; + float *d_block = d + a; + float *sd_block = sd + a; + qr_step(d_block, sd_block, n_block); + chop_small_elements(d_block, sd_block, n_block); + iter++; + } + + if (iter == EIGENVALUE_MAX_ITERS) { + vmaf_log(VMAF_LOG_LEVEL_WARNING, + "compute_eigenvalues_tridiagonal: max iterations reached, " + "possible non-convergence\n"); + } + + for (int i = 0; i < size; i++) + eigenvalues[i] = d[i]; +} + +static void compute_eigenvalues(float *A_immutable, float *eigenvalues, + int size, float *buffer) +{ + float *A = buffer; buffer += size * size; + float *d = buffer; buffer += size; + float *sd = buffer; buffer += size; + float *tmp = buffer; buffer += 2 * size; + // 
Operate on a copy of the matrix + memcpy(A, A_immutable, size * size * sizeof(float)); + // Handle special case + if (size == 1) { + eigenvalues[0] = A[0 * size + 0]; + return; + } + convert_to_tridiagonal(A, size, d, sd, tmp); + compute_eigenvalues_tridiagonal(d, sd, eigenvalues, size); +} + +// Implementation of the QR decomposition algorithm with Householder +// reflections for an arbitrary square matrix +// https://www.cs.utexas.edu/users/flame/Notes/NotesOnHouseholderQR.pdf +static void matrix_qr_decomposition(Matrix *A, Matrix *Q, Matrix *R, + Matrix *tmp_q, Matrix *tmp_z) +{ + assert(A->rows == A->cols); + int size = A->rows; + + // We need 3 temporary matrices. We can use R as a temporary matrix during + // the process since it's only used at the end. We can also use R for the + // temporary vector, since both uses are disjoint in time + Matrix *tmp_mul = R; + float *vec = R->data; + + // Initially, tmp_z = A + matrix_copy(tmp_z, A); + + // Q starts as the identity, and is left-multiplied by tmp_q at each iteration + matrix_identity(Q); + + for (int k = 0; k < size - 1; k++) { + matrix_minor(tmp_z, k); + matrix_take_kth_column(tmp_z, vec, k); + + float norm = vector_norm(vec, size); + int sign = get_sign(A->data[k * size + k]); + vec[k] += sign * norm; + + vector_div(vec, vector_norm(vec, size), vec, size); + matrix_identity_minus_v_vt(tmp_q, vec); + + matrix_mul(tmp_mul, tmp_q, tmp_z); + matrix_copy(tmp_z, tmp_mul); + matrix_mul(tmp_mul, tmp_q, Q); + matrix_copy(Q, tmp_mul); + } + + matrix_mul(R, Q, A); + matrix_transpose(Q); +} + +// Solves RX = B, where R is square, upper triangular and invertible +// Uses backward substitution algorithm +static int solve_triangular_system(const Matrix *R, Matrix *X, const Matrix *B) +{ + assert(R->rows == R->cols); + assert(R->rows == X->rows); + assert(R->rows == B->rows); + assert(X->cols == B->cols); + + for (int i = R->rows - 1; i >= 0; i--) { + float denominator = R->data[i * R->cols + i]; + if (fabsf(denominator) < EIGENVALUE_EPS) { + return -EINVAL; + } + for (int j = 0; j < X->cols; j++) { + float independent_term = B->data[i * B->cols + j]; + for (int k = i + 1; k < R->rows; k++) { + independent_term -= + (X->data[k * X->cols + j] * R->data[i * R->cols + k]); + } + X->data[i * X->cols + j] = independent_term / denominator; + } + } + return 0; +} + +// Solves the linear system A X = B, using the QR decomposition of A +static int solve_linear_system(float *A_data, int A_size, float *B_data, + int B_cols, float *output_data, + float *tmp_buffer) +{ + float *A_buffer = tmp_buffer; tmp_buffer += A_size * A_size; + float *Q_buffer = tmp_buffer; tmp_buffer += A_size * A_size; + float *R_buffer = tmp_buffer; tmp_buffer += A_size * A_size; + float *tmp1_buffer = tmp_buffer; tmp_buffer += A_size * A_size; + float *tmp2_buffer = tmp_buffer; tmp_buffer += A_size * A_size; + float *tmp_rect_buffer = tmp_buffer; tmp_buffer += A_size * B_cols; + + Matrix A_immutable; matrix_init(&A_immutable, A_size, A_size, A_data); + Matrix A; matrix_init(&A, A_size, A_size, A_buffer); + matrix_copy(&A, &A_immutable); + + Matrix Q; matrix_init(&Q, A_size, A_size, Q_buffer); + Matrix R; matrix_init(&R, A_size, A_size, R_buffer); + Matrix tmp1; matrix_init(&tmp1, A_size, A_size, tmp1_buffer); + Matrix tmp2; matrix_init(&tmp2, A_size, A_size, tmp2_buffer); + + Matrix B_immutable; matrix_init(&B_immutable, A_size, B_cols, B_data); + + Matrix tmp_rect; matrix_init(&tmp_rect, A_size, B_cols, tmp_rect_buffer); + Matrix X; matrix_init(&X, A_size, B_cols, output_data); + + 
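+    // With the default 5x5 blocks, A here is the 25x25 block covariance
+    // matrix and B carries one right-hand-side column per block of the
+    // frame, so the single QR factorization below is reused to solve the
+    // linear systems for all blocks at once.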
// A = Q R, where Q^{-1} = Q^T and R is upper triangular + // A X = B ==> Q R X = B ==> R X = Q^T B + matrix_qr_decomposition(&A, &Q, &R, &tmp1, &tmp2); + matrix_transpose(&Q); + matrix_mul(&tmp_rect, &Q, &B_immutable); + int err = solve_triangular_system(&R, &X, &tmp_rect); + return err; +} + +static float compute_mean(SpeedDimensions dim, const float *data, + size_t stride_px, int start_row, int start_col) +{ + float result = 0; + for (size_t i = 0; i < dim.submatrix_height; i++) { + for (size_t j = 0; j < dim.submatrix_width; j++) + result += data[(start_row + i) * stride_px + (start_col + j)]; + } + + return result / (dim.submatrix_width * dim.submatrix_height); +} + +// Computes the covariance between two arrays of the same given size +// The arrays are stored as submatrices of the input data, starting at +// (start_row_x, start_col_x) and (start_row_y, start_col_y) +// and having dimensions (submatrix_height, submatrix_width) +static float compute_covariance(SpeedDimensions dim, const float *data, + const float *means, size_t stride_px, + int start_row_x, int start_col_x, + int start_row_y, int start_col_y) +{ + double mean_x = means[start_row_x * dim.block_size + start_col_x]; + double mean_y = means[start_row_y * dim.block_size + start_col_y]; + double result = 0; + for (size_t i = 0; i < dim.submatrix_height; i++) { + for (size_t j = 0; j < dim.submatrix_width; j++) { + double val_x = + data[(start_row_x + i) * stride_px + (start_col_x + j)]; + double val_y = + data[(start_row_y + i) * stride_px + (start_col_y + j)]; + result += (val_x - mean_x) * (val_y - mean_y); + } + } + return result / (dim.submatrix_width * dim.submatrix_height); +} + +static void compute_covariance_matrix(SpeedDimensions dim, const float *data, + float *cov_mat, float *means, + size_t stride_px) +{ + for (size_t start_row = 0; start_row < dim.block_size; start_row++) { + for (size_t start_col = 0; start_col < dim.block_size; start_col++) { + means[start_row * dim.block_size + start_col] = + compute_mean(dim, data, stride_px, start_row, start_col); + } + } + size_t elements_in_block = dim.block_size * dim.block_size; + + for (size_t x_index = 0; x_index < dim.elements_in_block; x_index++) { + for (size_t y_index = 0; y_index <= x_index; y_index++) { + size_t start_row_x = x_index / dim.block_size; + size_t start_col_x = x_index % dim.block_size; + size_t start_row_y = y_index / dim.block_size; + size_t start_col_y = y_index % dim.block_size; + float covariance = + compute_covariance(dim, data, means, stride_px, start_row_x, + start_col_x, start_row_y, start_col_y); + cov_mat[x_index * elements_in_block + y_index] = covariance; + cov_mat[y_index * elements_in_block + x_index] = covariance; + } + } +} + +static void compute_independent_term(SpeedDimensions dim, const float *data, + float *independent_term, size_t stride_px) +{ + for (size_t start_i = 0; start_i < dim.block_size; start_i++) { + for (size_t start_j = 0; start_j < dim.block_size; start_j++) { + for (size_t i = start_i; i < dim.truncated_height; i += dim.block_size) { + for (size_t j = start_j; j < dim.truncated_width; j += dim.block_size) { + size_t out_row = start_i * dim.block_size + start_j; + size_t out_col = + (((i - start_i) / dim.block_size) * + dim.num_blocks_horizontal) + ((j - start_j) / dim.block_size); + independent_term[out_row * dim.num_blocks + out_col] = + data[i * stride_px + j]; + } + } + } + } +} + +static void compute_pointwise_product_and_division(SpeedDimensions dim, + float *X, const float *Y, + float denominator) +{ + for 
(size_t i = 0; i < dim.elements_in_block; i++) { + for (size_t j = 0; j < dim.num_blocks; j++) { + X[i * dim.num_blocks + j] = + (X[i * dim.num_blocks + j] * Y[i * dim.num_blocks + j]) / denominator; + } + } +} + +static void sum_columns(SpeedDimensions dim, float *X) +{ + for (size_t i = 1; i < dim.elements_in_block; i++) { + for (size_t j = 0; j < dim.num_blocks; j++) { + X[0 * dim.num_blocks + j] += X[i * dim.num_blocks + j]; + } + } +} + +static void update_entropy(SpeedDimensions dim, float *entropy, const float *S, + float L, float sigma_nn) +{ + for (size_t i = 0; i < dim.num_blocks_vertical; i++) { + for (size_t j = 0; j < dim.num_blocks_horizontal; j++) { + entropy[i * dim.num_blocks_horizontal + j] += + log2(L * S[i * dim.num_blocks_horizontal + j] + sigma_nn) + + log2(2 * M_PI * M_E); + } + } +} + +static bool is_matrix_regular(SpeedDimensions dim, const float *eigenvalues) +{ + for (size_t i = 0; i < dim.elements_in_block; i++) { + if (eigenvalues[i] < EIGENVALUE_EPS) { + return false; + } + } + return true; +} + +static int est_params(SpeedState *s, const float *data, float sigma_nn, + SpeedResultBuffers *output) +{ + SpeedDimensions dim = s->dimensions; + size_t stride_px = s->float_stride / sizeof(float); + + // Step 1: Compute the covariance matrix K + // We use the eigenvalues array as a temporary array to store the means + // needed for the covariance + compute_covariance_matrix(dim, data, s->buffers.cov_mat, + s->buffers.eigenvalues, stride_px); + + // Step 2: Compute the eigenvalues of the covariance matrix + compute_eigenvalues(s->buffers.cov_mat, s->buffers.eigenvalues, + dim.elements_in_block, s->buffers.tmp_buffer); + + // Step 3: Compute independent term for the linear system + compute_independent_term(dim, data, s->buffers.independent_term, stride_px); + + // Step 4: Solve the linear equation KX = Y, where K is the covariance + // matrix and Y is the independent term matrix above + int err = 0; + bool regular = is_matrix_regular(dim, s->buffers.eigenvalues); + if (regular) { + err = solve_linear_system(s->buffers.cov_mat, dim.elements_in_block, + s->buffers.independent_term, + dim.num_blocks, s->buffers.linear_system_sol, + s->buffers.tmp_buffer); + } + + bool cannot_invert = !regular || err; + if (cannot_invert) { + vmaf_log(VMAF_LOG_LEVEL_WARNING, + "est_params: covariance matrix is singular\n"); + memset(s->buffers.linear_system_sol, 0, + dim.elements_in_block * dim.num_blocks * sizeof(float)); + } + + // Step 5: Compute the pointwise product Z = (X * Y)/B^2, where X and Y are + // from the linear system above, and B is the block size. + // Store the results in s->linear_system_sol + compute_pointwise_product_and_division(dim, s->buffers.linear_system_sol, + s->buffers.independent_term, + dim.elements_in_block); + + // Step 6: Sum each column in Z into an array S. + // Store the results in the first row of s->linear_system_sol + sum_columns(dim, s->buffers.linear_system_sol); + + // Step 7: Reshape S into a matrix of dimensions (H/B, W/B) + // This is a no-op, we just need to index it with the right dimensions + + // Step 8: Construct a zeroed-out matrix E of size (H/B, W/B) + memset(output->entropies, 0, + dim.num_blocks_horizontal * dim.num_blocks_vertical * sizeof(float)); + + // Step 9: For each eigenvalue L, update the entropy matrix + for (size_t k = 0; k < dim.elements_in_block; k++) { + float L = s->buffers.eigenvalues[k] < 0 ? 
0 : s->buffers.eigenvalues[k]; + update_entropy(dim, output->entropies, s->buffers.linear_system_sol, L, + sigma_nn); + } + + // Step 10: Return S, E + memcpy(output->variances, s->buffers.linear_system_sol, + dim.num_blocks * sizeof(float)); + + return cannot_invert ? -EINVAL : 0; +} + +static float get_speed_score(SpeedDimensions dim, SpeedResultBuffers ref_results, + SpeedResultBuffers dis_results, float sigma_nn, + float nn_floor, int speed_weight_var_mode) +{ + float score = 0; + float base_entropy = dim.elements_in_block * (log2((1 + nn_floor) * sigma_nn) + log2(2 * M_PI * M_E)); + for (size_t i = 0; i < dim.num_blocks; i++) { + if ((ref_results.entropies[i] < base_entropy) && + (dis_results.entropies[i] < base_entropy)) + { + // If both entropies are below the base_entropy, + // there is no visible difference + score += 0; + } else { + float spatial_ref = 0.0; + float spatial_dis = 0.0; + if (speed_weight_var_mode == 0) { + spatial_ref = ref_results.entropies[i] * log2(1 + ref_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + dis_results.variances[i]); + } else if (speed_weight_var_mode == 1) { + spatial_ref = ref_results.entropies[i] * log2(1 + ref_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + ref_results.variances[i]); + } else if (speed_weight_var_mode == 2) { + spatial_ref = ref_results.entropies[i] * log2(1 + dis_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + dis_results.variances[i]); + } else if (speed_weight_var_mode == 3) { + spatial_ref = ref_results.entropies[i] * log2(1 + (ref_results.variances[i] + dis_results.variances[i]) / 2.0); + spatial_dis = dis_results.entropies[i] * log2(1 + (ref_results.variances[i] + dis_results.variances[i]) / 2.0); + } else if (speed_weight_var_mode == 4) { + spatial_ref = ref_results.entropies[i] * log2(1 + ref_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + (ref_results.variances[i] + dis_results.variances[i]) / 2.0); + } else if (speed_weight_var_mode == 5) { + spatial_ref = ref_results.entropies[i] * log2(1 + ref_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + (0.75 * ref_results.variances[i] + 0.25 * dis_results.variances[i])); + } else if (speed_weight_var_mode == 6) { + spatial_ref = ref_results.entropies[i] * log2(1 + ref_results.variances[i]); + spatial_dis = dis_results.entropies[i] * log2(1 + (0.25 * ref_results.variances[i] + 0.75 * dis_results.variances[i])); + } else { + return -EINVAL; + } + score += fabsf(spatial_ref - spatial_dis); + } + + } + + return score / dim.num_blocks; +} + +static void subtract_image(float *im1, float *im2, int w, int h, size_t stride) { + size_t stride_px = stride / sizeof(float); + for (int i = 0; i < h; i++) { + for (int j = 0; j < w; j++) { + im1[i * stride_px + j] -= im2[i * stride_px + j]; + } + } +} + +// Filters the image with a Gaussian filter and then performs local +// mean subtraction +static void filter_and_downscale(SpeedDimensions dim, SpeedOptions *opt, + float *frame_buffer, float *tmp_buffer, + size_t float_stride) +{ + size_t stride_px = float_stride / sizeof(float); + + size_t frame_size = stride_px * dim.alloc_height; + float *curr_scale = tmp_buffer; tmp_buffer += frame_size; + float *tmpbuf = tmp_buffer; tmp_buffer += frame_size; + + // The scaling method has been checked for validity in the init callback + enum vif_scaling_method scaling_method; + vif_get_scaling_method(opt->speed_prescale_method, &scaling_method); + + if 
(!ALMOST_EQUAL(opt->speed_prescale, 1.0)) { + memcpy(tmpbuf, frame_buffer, + stride_px * dim.alloc_height * sizeof(float)); + vif_scale_frame_s(scaling_method, tmpbuf, frame_buffer, + dim.original_width, dim.original_height, stride_px, + dim.scaled_width, dim.scaled_height, stride_px); + } + + // The kernelscale has been checked for validity in the init callback + int filter_width_antialias = vif_get_filter_size(1, opt->speed_kernelscale); + float filter_antialias[128]; + speed_get_antialias_filter(filter_antialias, NUM_SCALES, + opt->speed_kernelscale); + vif_filter1d_s(filter_antialias, frame_buffer, curr_scale, tmpbuf, + dim.scaled_width, dim.scaled_height, float_stride, + float_stride, filter_width_antialias); + + vif_dec16_s(curr_scale, frame_buffer, dim.scaled_width, dim.scaled_height, + float_stride, float_stride); + + size_t downscaled_w = dim.scaled_width >> NUM_SCALES; + size_t downscaled_h = dim.scaled_height >> NUM_SCALES; + + int filter_width = vif_get_filter_size(NUM_SCALES, opt->speed_kernelscale); + float filter[128]; + vif_get_filter(filter, NUM_SCALES, opt->speed_kernelscale); + vif_filter1d_s(filter, frame_buffer, curr_scale, tmpbuf, downscaled_w, + downscaled_h, float_stride, float_stride, filter_width); + subtract_image(frame_buffer, curr_scale, downscaled_w, downscaled_h, + float_stride); +} + +int speed_extract_score(SpeedState *s, SpeedOptions *opt, float *ref, + float *dis, float *score) +{ + filter_and_downscale(s->dimensions, opt, ref, s->buffers.tmp_buffer, + s->float_stride); + int err_ref = est_params(s, ref, opt->speed_sigma_nn, &(s->ref_results)); + + filter_and_downscale(s->dimensions, opt, dis, s->buffers.tmp_buffer, + s->float_stride); + + int err_dis = est_params(s, dis, opt->speed_sigma_nn, &(s->dis_results)); + + // If only one of ref and dis was numerically unstable (very rare) + // we return 0 instead of an inflated score that may skew the average + if ((err_ref && !err_dis) || (!err_ref && err_dis)) { + *score = 0.0f; + } else { + *score = get_speed_score(s->dimensions, s->ref_results, s->dis_results, + opt->speed_sigma_nn, opt->speed_nn_floor, + opt->speed_weight_var_mode); + } + + return err_ref || err_dis; +} + +static int speed_init_dimensions(SpeedDimensions *dim, int w, int h, + double speed_prescale) +{ + dim->original_height = h; + dim->original_width = w; + dim->scaled_height = (int)(dim->original_height * speed_prescale + 0.5); + dim->scaled_width = (int)(dim->original_width * speed_prescale + 0.5); + dim->alloc_height = MAX(dim->original_height, dim->scaled_height); + dim->alloc_width = MAX(dim->original_width, dim->scaled_width); + dim->operating_height = dim->scaled_height >> NUM_SCALES; + dim->operating_width = dim->scaled_width >> NUM_SCALES; + dim->block_size = DEFAULT_BLOCK_SIZE; + dim->truncated_width = + (dim->operating_width / dim->block_size) * dim->block_size; + dim->truncated_height = + (dim->operating_height / dim->block_size * dim->block_size); + dim->num_blocks_horizontal = dim->truncated_width / dim->block_size; + dim->num_blocks_vertical = dim->truncated_height / dim->block_size; + dim->num_blocks = dim->num_blocks_horizontal * dim->num_blocks_vertical; + dim->elements_in_block = dim->block_size * dim->block_size; + dim->submatrix_width = dim->truncated_width - dim->block_size + 1; + dim->submatrix_height = dim->truncated_height - dim->block_size + 1; + + if (dim->truncated_height == 0 || dim->truncated_width == 0) { + vmaf_log(VMAF_LOG_LEVEL_ERROR, + "SpEED: image too small, operating width or height is 0"); + return 
+
+int speed_init(SpeedState *s, SpeedOptions *opt, int w, int h)
+{
+    SpeedDimensions *dim = &s->dimensions;
+    int err = speed_init_dimensions(dim, w, h, opt->speed_prescale);
+    if (err)
+        return err;
+
+    // Check that the kernelscale is valid
+    if (!vif_validate_kernelscale(opt->speed_kernelscale)) {
+        vmaf_log(VMAF_LOG_LEVEL_ERROR, "invalid speed_kernelscale");
+        return -EINVAL;
+    }
+
+    enum vif_scaling_method scaling_method;
+    if (vif_get_scaling_method(opt->speed_prescale_method, &scaling_method)) {
+        return -EINVAL;
+    }
+
+    s->float_stride = ALIGN_CEIL(dim->alloc_width * sizeof(float));
+    size_t stride_px = s->float_stride / sizeof(float);
+
+    size_t tmp_buffer_size = sizeof(float) * (
+        NUM_SQUARE_BUFFERS * dim->elements_in_block * dim->elements_in_block +
+        NUM_RECT_BUFFERS * dim->elements_in_block * dim->num_blocks +
+        NUM_FRAME_BUFFERS * dim->alloc_height * stride_px
+    );
+
+    s->buffers.independent_term =
+        aligned_malloc(sizeof(float) * dim->num_blocks * dim->elements_in_block, 32);
+    if (!s->buffers.independent_term)
+        return -ENOMEM;
+    s->buffers.linear_system_sol =
+        aligned_malloc(sizeof(float) * dim->num_blocks * dim->elements_in_block, 32);
+    if (!s->buffers.linear_system_sol)
+        return -ENOMEM;
+    s->buffers.cov_mat =
+        aligned_malloc(sizeof(float) * dim->elements_in_block * dim->elements_in_block, 32);
+    if (!s->buffers.cov_mat)
+        return -ENOMEM;
+    s->buffers.eigenvalues =
+        aligned_malloc(sizeof(float) * dim->elements_in_block, 32);
+    if (!s->buffers.eigenvalues)
+        return -ENOMEM;
+    s->buffers.tmp_buffer = aligned_malloc(tmp_buffer_size, 32);
+    if (!s->buffers.tmp_buffer)
+        return -ENOMEM;
+
+    s->ref_results.entropies = aligned_malloc(sizeof(float) * dim->num_blocks, 32);
+    if (!s->ref_results.entropies)
+        return -ENOMEM;
+    s->ref_results.variances = aligned_malloc(sizeof(float) * dim->num_blocks, 32);
+    if (!s->ref_results.variances)
+        return -ENOMEM;
+    s->dis_results.entropies = aligned_malloc(sizeof(float) * dim->num_blocks, 32);
+    if (!s->dis_results.entropies)
+        return -ENOMEM;
+    s->dis_results.variances = aligned_malloc(sizeof(float) * dim->num_blocks, 32);
+    if (!s->dis_results.variances)
+        return -ENOMEM;
+
+    return 0;
+}
+
+int speed_close(SpeedState *s)
+{
+    if (s->buffers.independent_term)
+        aligned_free(s->buffers.independent_term);
+    if (s->buffers.linear_system_sol)
+        aligned_free(s->buffers.linear_system_sol);
+    if (s->buffers.cov_mat)
+        aligned_free(s->buffers.cov_mat);
+    if (s->buffers.eigenvalues)
+        aligned_free(s->buffers.eigenvalues);
+    if (s->buffers.tmp_buffer)
+        aligned_free(s->buffers.tmp_buffer);
+
+    if (s->ref_results.entropies)
+        aligned_free(s->ref_results.entropies);
+    if (s->ref_results.variances)
+        aligned_free(s->ref_results.variances);
+    if (s->dis_results.entropies)
+        aligned_free(s->dis_results.entropies);
+    if (s->dis_results.variances)
+        aligned_free(s->dis_results.variances);
+
+    return 0;
+}
+
+#define DEFAULT_SPEED_SIGMA_NN (0.29)
+#define DEFAULT_SPEED_MAX_VAL (1000.0)
+#define DEFAULT_SPEED_NN_FLOOR (0.0)
+#define DEFAULT_SPEED_KERNELSCALE (1.0)
+#define DEFAULT_SPEED_PRESCALE (1.0)
+#define DEFAULT_SPEED_PRESCALE_METHOD ("nearest")
+#define MIN(x, y) (((x) < (y)) ? (x) : (y))
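+
+// MIN is used when appending scores below, to clip each reported score to the
+// configured *_max_val option.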
+
+typedef struct SpeedChromaState {
+    SpeedState speed_state;
+    SpeedOptions speed_options;
+    float *frame_buffer_ref;
+    float *frame_buffer_dis;
+    VmafDictionary *feature_name_dict;
+    double speed_chroma_kernelscale;
+    double speed_chroma_prescale;
+    char *speed_chroma_prescale_method;
+    double speed_chroma_sigma_nn;
+    double speed_chroma_nn_floor;
+    double speed_chroma_max_val;
+    int speed_weight_var_mode;
+} SpeedChromaState;
+
+static const VmafOption options_chroma[] = {
+    {
+        .name = "speed_kernelscale",
+        .help = "scaling factor for the gaussian kernel (2.0 means "
+                "multiplying the standard deviation by 2 and enlarging "
+                "the kernel size accordingly)",
+        .offset = offsetof(SpeedChromaState, speed_chroma_kernelscale),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_KERNELSCALE,
+        .min = 0.1,
+        .max = 4.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "ks",
+    },
+    {
+        .name = "speed_prescale",
+        .help = "scaling factor for the frame (2.0 means "
+                "making the image twice as large on each dimension)",
+        .offset = offsetof(SpeedChromaState, speed_chroma_prescale),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_PRESCALE,
+        .min = 0.1,
+        .max = 4.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "ps",
+    },
+    {
+        .name = "speed_prescale_method",
+        .help = "scaling method for the frame, supported options: "
+                "[nearest, bilinear, bicubic, lanczos4]",
+        .offset = offsetof(SpeedChromaState, speed_chroma_prescale_method),
+        .type = VMAF_OPT_TYPE_STRING,
+        .default_val.s = DEFAULT_SPEED_PRESCALE_METHOD,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "psm",
+    },
+    {
+        .name = "speed_sigma_nn",
+        .help = "standard deviation of neural noise",
+        .offset = offsetof(SpeedChromaState, speed_chroma_sigma_nn),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_SIGMA_NN,
+        .min = 0.1,
+        .max = 2.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "snn",
+    },
+    {
+        .name = "speed_nn_floor",
+        .help = "neural noise floor, expressed as a fraction of sigma_nn",
+        .offset = offsetof(SpeedChromaState, speed_chroma_nn_floor),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_NN_FLOOR,
+        .min = 0.0,
+        .max = 1.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "nnf",
+    },
+    {
+        .name = "speed_max_val",
+        .help = "maximum value allowed; "
+                "larger values will be clipped to this value",
+        .offset = offsetof(SpeedChromaState, speed_chroma_max_val),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_MAX_VAL,
+        .min = 0.0,
+        .max = 1000.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "mxv",
+    },
+    {
+        .name = "speed_weight_var_mode",
+        .help = "different approaches to perform variance-based weighting",
+        .offset = offsetof(SpeedChromaState, speed_weight_var_mode),
+        .type = VMAF_OPT_TYPE_INT,
+        .default_val.i = 0,
+        .min = 0,
+        .max = 6,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "wvm",
+    },
+    { 0 }
+};
+
+static int init_chroma(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
+                       unsigned bpc, unsigned w, unsigned h)
+{
+    (void)bpc;
+
+    switch (pix_fmt) {
+    case VMAF_PIX_FMT_UNKNOWN:
+    case VMAF_PIX_FMT_YUV400P:
+        return -EINVAL;
+    case VMAF_PIX_FMT_YUV420P:
+        w /= 2;
+        h /= 2;
+        break;
+    case VMAF_PIX_FMT_YUV422P:
+        w /= 2;
+        break;
+    case VMAF_PIX_FMT_YUV444P:
+        break;
+    }
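+
+    // For YUV420P a 1920x1080 frame has 960x540 U and V planes, so
+    // speed_chroma operates at the subsampled chroma resolution.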
+
+    SpeedChromaState *s = fex->priv;
+    s->speed_options = (SpeedOptions) {
+        .speed_kernelscale = s->speed_chroma_kernelscale,
+        .speed_prescale = s->speed_chroma_prescale,
+        .speed_prescale_method = s->speed_chroma_prescale_method,
+        .speed_sigma_nn = s->speed_chroma_sigma_nn,
+        .speed_nn_floor = s->speed_chroma_nn_floor,
+        .speed_weight_var_mode = s->speed_weight_var_mode,
+    };
+    int err = speed_init(&s->speed_state, &s->speed_options, w, h);
+    if (err)
+        return err;
+    SpeedDimensions dim = s->speed_state.dimensions;
+
+    s->feature_name_dict =
+        vmaf_feature_name_dict_from_provided_features(fex->provided_features,
+                                                      fex->options, s);
+    if (!s->feature_name_dict)
+        return -ENOMEM;
+
+    s->frame_buffer_ref =
+        aligned_malloc(s->speed_state.float_stride * dim.alloc_height, 32);
+    if (!s->frame_buffer_ref)
+        return -ENOMEM;
+    s->frame_buffer_dis =
+        aligned_malloc(s->speed_state.float_stride * dim.alloc_height, 32);
+    if (!s->frame_buffer_dis)
+        return -ENOMEM;
+
+    return 0;
+}
+
+static int extract_channel(SpeedChromaState *s, VmafPicture *ref_pic,
+                           VmafPicture *dist_pic, int channel, float *score)
+{
+    picture_copy(s->frame_buffer_ref, s->speed_state.float_stride,
+                 ref_pic, -128, ref_pic->bpc, channel);
+    picture_copy(s->frame_buffer_dis, s->speed_state.float_stride,
+                 dist_pic, -128, dist_pic->bpc, channel);
+    return speed_extract_score(&s->speed_state, &s->speed_options,
+                               s->frame_buffer_ref, s->frame_buffer_dis, score);
+}
+
+static int extract_chroma(VmafFeatureExtractor *fex,
+                          VmafPicture *ref_pic, VmafPicture *ref_pic_90,
+                          VmafPicture *dist_pic, VmafPicture *dist_pic_90,
+                          unsigned index, VmafFeatureCollector *feature_collector)
+{
+    (void)ref_pic_90;
+    (void)dist_pic_90;
+
+    SpeedChromaState *s = fex->priv;
+
+    float score_u, score_v;
+    int err_u = extract_channel(s, ref_pic, dist_pic, 1, &score_u);
+    int err_v = extract_channel(s, ref_pic, dist_pic, 2, &score_v);
+
+    // There are edge cases where one or both channels (U and V) have singular
+    // covariance matrices, for example when the channel is completely flat.
+    // If only one channel is singular, we impute its score from the other
+    // channel, so the combined score_uv equals the other channel's score.
+    // This is a better approximation than imputing it to be zero.
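+    // Illustration with hypothetical values: if U is flat (err_u != 0) and V
+    // scores 0.8, then score_uv = 0.8 rather than (0 + 0.8) / 2 = 0.4.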
+
+    float score_uv;
+    if (err_u && !err_v) {
+        score_uv = score_v;
+    } else if (err_v && !err_u) {
+        score_uv = score_u;
+    } else {
+        score_uv = (score_u + score_v) / 2.0;
+    }
+
+    int err = 0;
+
+    err |=
+        vmaf_feature_collector_append_with_dict(feature_collector,
+            s->feature_name_dict, "Speed_chroma_feature_speed_chroma_u_score",
+            MIN(score_u, s->speed_chroma_max_val), index);
+    err |=
+        vmaf_feature_collector_append_with_dict(feature_collector,
+            s->feature_name_dict, "Speed_chroma_feature_speed_chroma_v_score",
+            MIN(score_v, s->speed_chroma_max_val), index);
+    err |=
+        vmaf_feature_collector_append_with_dict(feature_collector,
+            s->feature_name_dict, "Speed_chroma_feature_speed_chroma_uv_score",
+            MIN(score_uv, s->speed_chroma_max_val), index);
+    return err;
+}
+
+static int close_chroma(VmafFeatureExtractor *fex)
+{
+    SpeedChromaState *s = fex->priv;
+
+    speed_close(&s->speed_state);
+
+    if (s->frame_buffer_ref)
+        aligned_free(s->frame_buffer_ref);
+    if (s->frame_buffer_dis)
+        aligned_free(s->frame_buffer_dis);
+
+    if (s->feature_name_dict)
+        vmaf_dictionary_free(&s->feature_name_dict);
+
+    return 0;
+}
+
+static const char *provided_features_chroma[] = {
+    "Speed_chroma_feature_speed_chroma_u_score",
+    "Speed_chroma_feature_speed_chroma_v_score",
+    "Speed_chroma_feature_speed_chroma_uv_score",
+    NULL
+};
+
+VmafFeatureExtractor vmaf_fex_speed_chroma = {
+    .name = "speed_chroma",
+    .init = init_chroma,
+    .extract = extract_chroma,
+    .close = close_chroma,
+    .options = options_chroma,
+    .priv_size = sizeof(SpeedChromaState),
+    .provided_features = provided_features_chroma,
+};
+
+typedef struct SpeedTemporalState {
+    SpeedState speed_state;
+    SpeedOptions speed_options;
+    float *frame_buffer_ref[2];
+    float *frame_buffer_dis[2];
+    VmafDictionary *feature_name_dict;
+    int index;
+    double score;
+    double speed_temporal_kernelscale;
+    double speed_temporal_prescale;
+    char *speed_temporal_prescale_method;
+    double speed_temporal_sigma_nn;
+    double speed_temporal_nn_floor;
+    double speed_temporal_max_val;
+    bool speed_temporal_use_ref_diff;
+} SpeedTemporalState;
+
+static const VmafOption options[] = {
+    {
+        .name = "speed_kernelscale",
+        .help = "scaling factor for the gaussian kernel (2.0 means "
+                "multiplying the standard deviation by 2 and enlarging "
+                "the kernel size accordingly)",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_kernelscale),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_KERNELSCALE,
+        .min = 0.1,
+        .max = 4.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "ks",
+    },
+    {
+        .name = "speed_prescale",
+        .help = "scaling factor for the frame (2.0 means "
+                "making the image twice as large on each dimension)",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_prescale),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_PRESCALE,
+        .min = 0.1,
+        .max = 4.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "ps",
+    },
+    {
+        .name = "speed_prescale_method",
+        .help = "scaling method for the frame, supported options: "
+                "[nearest, bilinear, bicubic, lanczos4]",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_prescale_method),
+        .type = VMAF_OPT_TYPE_STRING,
+        .default_val.s = DEFAULT_SPEED_PRESCALE_METHOD,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "psm",
+    },
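+    // The remaining options mirror options_chroma[] above, with the addition
+    // of speed_use_ref_diff.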
+    {
+        .name = "speed_sigma_nn",
+        .help = "standard deviation of neural noise",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_sigma_nn),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_SIGMA_NN,
+        .min = 0.1,
+        .max = 2.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "snn",
+    },
+    {
+        .name = "speed_nn_floor",
+        .help = "neural noise floor, expressed as a fraction of sigma_nn",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_nn_floor),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_NN_FLOOR,
+        .min = 0.0,
+        .max = 1.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "nnf",
+    },
+    {
+        .name = "speed_max_val",
+        .help = "maximum value allowed; larger values will be clipped to this "
+                "value",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_max_val),
+        .type = VMAF_OPT_TYPE_DOUBLE,
+        .default_val.d = DEFAULT_SPEED_MAX_VAL,
+        .min = 0.0,
+        .max = 1000.0,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "mxv",
+    },
+    {
+        .name = "speed_use_ref_diff",
+        .help = "compute the distorted temporal difference against the "
+                "reference frame instead of the previous distorted frame",
+        .offset = offsetof(SpeedTemporalState, speed_temporal_use_ref_diff),
+        .type = VMAF_OPT_TYPE_BOOL,
+        .default_val.b = false,
+        .flags = VMAF_OPT_FLAG_FEATURE_PARAM,
+        .alias = "urd",
+    },
+    { 0 }
+};
+
+static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
+                unsigned bpc, unsigned w, unsigned h)
+{
+    (void)pix_fmt;
+    (void)bpc;
+
+    SpeedTemporalState *s = fex->priv;
+    s->speed_options = (SpeedOptions) {
+        .speed_kernelscale = s->speed_temporal_kernelscale,
+        .speed_prescale = s->speed_temporal_prescale,
+        .speed_prescale_method = s->speed_temporal_prescale_method,
+        .speed_sigma_nn = s->speed_temporal_sigma_nn,
+        .speed_nn_floor = s->speed_temporal_nn_floor,
+    };
+
+    int err = speed_init(&s->speed_state, &s->speed_options, w, h);
+    if (err)
+        return err;
+
+    size_t float_stride = s->speed_state.float_stride;
+    size_t frame_size = float_stride * s->speed_state.dimensions.alloc_height;
+    s->frame_buffer_ref[0] = aligned_malloc(frame_size, 32);
+    s->frame_buffer_ref[1] = aligned_malloc(frame_size, 32);
+    s->frame_buffer_dis[0] = aligned_malloc(frame_size, 32);
+    s->frame_buffer_dis[1] = aligned_malloc(frame_size, 32);
+
+    if (!s->frame_buffer_ref[0] || !s->frame_buffer_ref[1] ||
+        !s->frame_buffer_dis[0] || !s->frame_buffer_dis[1])
+    {
+        return -ENOMEM;
+    }
+
+    s->feature_name_dict =
+        vmaf_feature_name_dict_from_provided_features(fex->provided_features,
+                                                      fex->options, s);
+    if (!s->feature_name_dict)
+        return -ENOMEM;
+
+    return 0;
+}
+
+static int extract(VmafFeatureExtractor *fex,
+                   VmafPicture *ref_pic, VmafPicture *ref_pic_90,
+                   VmafPicture *dist_pic, VmafPicture *dist_pic_90,
+                   unsigned index, VmafFeatureCollector *feature_collector)
+{
+    SpeedTemporalState *s = fex->priv;
+    int err = 0;
+
+    (void) ref_pic_90;
+    (void) dist_pic_90;
+
+    s->index = index;
+    int cyclic_index = index % 2;
+    int other_index = (index + 1) % 2;
+
+    picture_copy(s->frame_buffer_ref[cyclic_index], s->speed_state.float_stride,
+                 ref_pic, -128, ref_pic->bpc, 0);
+    picture_copy(s->frame_buffer_dis[cyclic_index], s->speed_state.float_stride,
+                 dist_pic, -128, dist_pic->bpc, 0);
+
+    if (index == 0) {
+        err = vmaf_feature_collector_append_with_dict(
+            feature_collector, s->feature_name_dict,
+            "Speed_temporal_feature_speed_temporal_score", 0.0, index);
+        return err;
+    }
+
+    int w = s->speed_state.dimensions.original_width;
+    int h = s->speed_state.dimensions.original_height;
+    size_t float_stride = s->speed_state.float_stride;
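+    // Frames live in a two-slot ring buffer: slot index % 2 holds the current
+    // frame, while slot (index + 1) % 2 still holds the previous frame and is
+    // overwritten in place with the temporal difference below.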
+    subtract_image(s->frame_buffer_ref[other_index],
+                   s->frame_buffer_ref[cyclic_index], w, h, float_stride);
+    if (s->speed_temporal_use_ref_diff) {
+        subtract_image(s->frame_buffer_dis[other_index],
+                       s->frame_buffer_ref[cyclic_index], w, h, float_stride);
+    } else {
+        subtract_image(s->frame_buffer_dis[other_index],
+                       s->frame_buffer_dis[cyclic_index], w, h, float_stride);
+    }
+    float score;
+    speed_extract_score(&s->speed_state, &s->speed_options,
+                        s->frame_buffer_ref[other_index],
+                        s->frame_buffer_dis[other_index], &score);
+
+    err = vmaf_feature_collector_append_with_dict(
+        feature_collector, s->feature_name_dict,
+        "Speed_temporal_feature_speed_temporal_score",
+        MIN(score, s->speed_temporal_max_val), index);
+
+    return err;
+}
+
+static int close(VmafFeatureExtractor *fex)
+{
+    SpeedTemporalState *s = fex->priv;
+    speed_close(&s->speed_state);
+
+    if (s->frame_buffer_ref[0]) aligned_free(s->frame_buffer_ref[0]);
+    if (s->frame_buffer_ref[1]) aligned_free(s->frame_buffer_ref[1]);
+    if (s->frame_buffer_dis[0]) aligned_free(s->frame_buffer_dis[0]);
+    if (s->frame_buffer_dis[1]) aligned_free(s->frame_buffer_dis[1]);
+
+    if (s->feature_name_dict)
+        vmaf_dictionary_free(&s->feature_name_dict);
+    return 0;
+}
+
+static const char *provided_features[] = {
+    "Speed_temporal_feature_speed_temporal_score",
+    NULL
+};
+
+VmafFeatureExtractor vmaf_fex_speed_temporal = {
+    .name = "speed_temporal",
+    .init = init,
+    .extract = extract,
+    .close = close,
+    .options = options,
+    .priv_size = sizeof(SpeedTemporalState),
+    .provided_features = provided_features,
+    .flags = VMAF_FEATURE_EXTRACTOR_TEMPORAL,
+};
diff --git a/libvmaf/src/meson.build b/libvmaf/src/meson.build
index 6da7baed5..34f8a104a 100644
--- a/libvmaf/src/meson.build
+++ b/libvmaf/src/meson.build
@@ -416,7 +416,7 @@ if float_enabled
     feature_src_dir + 'float_motion.c',
     feature_src_dir + 'float_vif.c',
     feature_src_dir + 'float_moment.c',
-
+    feature_src_dir + 'speed.c',
     feature_src_dir + 'common/convolution.c',
     feature_src_dir + 'offset.c',
     feature_src_dir + 'adm.c',