ScaNN 源码补丁文件 0001-x86-to-arm64.patch 的内容如下（注：下方补丁文本的原始换行已丢失，不能直接用 git apply 应用）：
From c4603d15c7a0884e1392b50b569b21ef95c6e8e9 Mon Sep 17 00:00:00 2001 From: root <root@localhost.localdomain> Date: Sat, 6 Jan 2024 16:11:03 +0800 Subject: [PATCH] x86 to arm64 --- scann/WORKSPACE | 14 +++++++ .../many_to_many/many_to_many_common.h | 4 +- .../many_to_many/many_to_many_impl.inc | 10 ++--- .../many_to_many/many_to_many_templates.h | 6 +-- .../one_to_many/one_to_many.h | 4 +- .../one_to_one/dot_product.h | 2 +- .../one_to_one/dot_product_avx1.cc | 2 +- .../one_to_one/dot_product_avx1.h | 4 +- .../one_to_one/dot_product_avx2.cc | 2 +- .../one_to_one/dot_product_avx2.h | 4 +- .../one_to_one/dot_product_sse4.cc | 14 +++---- .../one_to_one/dot_product_sse4.h | 2 +- .../one_to_one/l1_distance.h | 2 +- .../one_to_one/l1_distance_sse4.cc | 8 ++-- .../one_to_one/l1_distance_sse4.h | 2 +- .../one_to_one/l2_distance.h | 2 +- .../one_to_one/l2_distance_avx1.cc | 4 +- .../one_to_one/l2_distance_avx1.h | 2 +- .../one_to_one/l2_distance_sse4.cc | 9 ++-- .../one_to_one/l2_distance_sse4.h | 2 +- .../hashes/asymmetric_hashing2/querying.h | 10 ++--- .../internal/asymmetric_hashing_impl.cc | 7 ++-- .../bazel_templates/lut16_avx2.tpl.cc | 4 +- .../lut16_avx512_noprefetch.tpl.cc | 4 +- .../lut16_avx512_prefetch.tpl.cc | 4 +- .../bazel_templates/lut16_avx512_smart.tpl.cc | 4 +- .../bazel_templates/lut16_sse4.tpl.cc | 4 +- scann/scann/hashes/internal/lut16_avx2.h | 4 +- scann/scann/hashes/internal/lut16_avx2.inc | 4 +- scann/scann/hashes/internal/lut16_avx512.h | 4 +- scann/scann/hashes/internal/lut16_avx512.inc | 4 +- .../hashes/internal/lut16_avx512_swizzle.cc | 2 +- .../hashes/internal/lut16_avx512_swizzle.h | 2 +- scann/scann/hashes/internal/lut16_interface.h | 42 +++++++++---------- scann/scann/hashes/internal/lut16_sse4.h | 4 +- scann/scann/hashes/internal/lut16_sse4.inc | 10 ++--- .../partitioning/kmeans_tree_partitioner.h | 2 +- scann/scann/utils/fast_top_neighbors.cc | 4 +- scann/scann/utils/internal/avx2_funcs.h | 6 ++- scann/scann/utils/internal/avx_funcs.h | 8 
++-- scann/scann/utils/intrinsics/BUILD.bazel | 3 ++ scann/scann/utils/intrinsics/attributes.h | 8 ++-- scann/scann/utils/intrinsics/avx1.h | 4 +- scann/scann/utils/intrinsics/avx2.h | 20 +++++---- scann/scann/utils/intrinsics/avx512.h | 21 +++++----- scann/scann/utils/intrinsics/fallback.h | 2 +- scann/scann/utils/intrinsics/flags.cc | 14 +++---- scann/scann/utils/intrinsics/fma.h | 2 +- scann/scann/utils/intrinsics/horizontal_sum.h | 24 +++++------ scann/scann/utils/intrinsics/sse4.h | 28 ++++++------- 50 files changed, 192 insertions(+), 166 deletions(-) diff --git a/scann/WORKSPACE b/scann/WORKSPACE index 5b01155f6..39bdc4e9a 100644 --- a/scann/WORKSPACE +++ b/scann/WORKSPACE @@ -3,6 +3,20 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("//build_deps/py:python_configure.bzl", "python_configure") load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure") +new_local_repository( + name = "ksl_external_lib", + path = "/usr/local/ksl/", + build_file_content = """ +cc_library( + name = "avx2ki", + srcs = ["lib/libavx2ki.so"], + hdrs = glob(["include/*.h"]), + includes = ["include/"], + visibility = ["//visibility:public"], +) +""", +) + # Needed for highway's config_setting_group http_archive( name = "bazel_skylib", diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_common.h b/scann/scann/distance_measures/many_to_many/many_to_many_common.h index fe8c1ff89..6a07a23c6 100644 --- a/scann/scann/distance_measures/many_to_many/many_to_many_common.h +++ b/scann/scann/distance_measures/many_to_many/many_to_many_common.h @@ -44,7 +44,7 @@ class EpsilonFilteringCallback { ManyToManyResultsCallback<FloatT> slow_path_fn) : epsilons_(epsilons), slow_path_fn_(std::move(slow_path_fn)) {} -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ SCANN_AVX512_INLINE void InvokeOptimized(Avx512<float, 2> simd_dists, size_t first_dp_idx, @@ -223,7 +223,7 @@ class EpsilonFilteringOffsetWrapper { dp_idx_offset_(dp_idx_offset), 
query_idx_table_(query_idx_table) {} -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ SCANN_AVX512_INLINE void InvokeOptimized(Avx512<float, 2> simd_dists, size_t first_dp_idx, diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc b/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc index debad5e53..73dc8a4a9 100644 --- a/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc +++ b/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc @@ -34,7 +34,7 @@ SCANN_SIMD_INLINE void ExpandPretransposedFP8BlockImpl( if (n_to_transpose == kElementsPerRegister) { const int8_t* __restrict__ src = block.data(); -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ if constexpr (IsSame<Simd<FloatT>, Avx2<float>>()) { static_assert(kElementsPerRegister == 8); @@ -42,9 +42,9 @@ SCANN_SIMD_INLINE void ExpandPretransposedFP8BlockImpl( __m256 inv_multiplier_simd = _mm256_broadcast_ss(( inverse_multipliers_or_null + dim_idx)); - __m128i int8s = _mm_loadl_pi(_mm_setzero_si128(), + __m128 int8s = _mm_loadl_pi(_mm_castsi128_ps(_mm_setzero_si128()), reinterpret_cast<const __m64*>(src)); - __m256i int32s = _mm256_cvtepi8_epi32(int8s); + __m256i int32s = _mm256_cvtepi8_epi32(_mm_cvtps_epi32(int8s)); __m256 floats = _mm256_cvtepi32_ps(int32s) * inv_multiplier_simd; _mm256_store_ps(transposed_storage, floats); @@ -522,8 +522,8 @@ class DenseManyToManyTransposed final Simd<FloatT>::Load(transposed_block0 + dim * kElementsPerRegister); auto transposed_simd1 = Simd<FloatT>::Load(transposed_block1 + dim * kElementsPerRegister); - - for (size_t j : Seq(kNumQueries)) { + //__builtin_prefetch(&accumulators[j][0], 0, 0); + for (size_t j : Seq(kNumQueries)) {//__builtin_prefetch(&accumulators[j][0], 0, 0); Simd<FloatT> query_simd = query_ptrs[j][dim]; FusedMultiplySubtract(query_simd, transposed_simd0, &accumulators[j][0]); diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_templates.h 
b/scann/scann/distance_measures/many_to_many/many_to_many_templates.h index 9c8751e6a..e9d13a4c3 100644 --- a/scann/scann/distance_measures/many_to_many/many_to_many_templates.h +++ b/scann/scann/distance_measures/many_to_many/many_to_many_templates.h @@ -95,7 +95,7 @@ namespace research_scann { -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ namespace sse4 { #define SCANN_SIMD_ATTRIBUTE SCANN_SSE4 @@ -170,7 +170,7 @@ SCANN_INLINE void DenseDistanceManyToManyImpl2( DCHECK(IsSupportedDistanceMeasure(dist)); DCHECK_NE(dist.specially_optimized_distance_tag(), DistanceMeasure::COSINE); -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ if (RuntimeSupportsAvx512()) { return avx512::DenseDistanceManyToManyImpl(dist, queries, database, pool, callback); @@ -202,7 +202,7 @@ SCANN_INLINE void DenseDistanceManyToManyFP8PretransposedImpl2( DCHECK(IsSupportedDistanceMeasure(dist)); DCHECK_NE(dist.specially_optimized_distance_tag(), DistanceMeasure::COSINE); -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ if (RuntimeSupportsAvx512()) { return avx512::DenseManyToManyFP8PretransposedImpl(dist, queries, database, pool, callback); diff --git a/scann/scann/distance_measures/one_to_many/one_to_many.h b/scann/scann/distance_measures/one_to_many/one_to_many.h index 77463e7ef..ffbeb37fc 100644 --- a/scann/scann/distance_measures/one_to_many/one_to_many.h +++ b/scann/scann/distance_measures/one_to_many/one_to_many.h @@ -1724,7 +1724,7 @@ void DenseDistanceOneToMany(const DistanceMeasure& dist, dist, query, database, result, &set_distance_functor, pool); } -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ namespace sse4 { #define SCANN_SIMD_ATTRIBUTE SCANN_SSE4 @@ -1798,7 +1798,7 @@ SCANN_INLINE void OneToManyInt8FloatDispatch( const float* __restrict__ inv_multipliers_for_squared_l2, const IndexT* indices, MutableSpan<ResultElemT> result, CallbackT callback) { -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ if constexpr (false && RuntimeSupportsAvx512()) { LOG(FATAL) << "We aren't compiling Avx-512 
support yet."; diff --git a/scann/scann/distance_measures/one_to_one/dot_product.h b/scann/scann/distance_measures/one_to_one/dot_product.h index a4897cac6..8f9cdec6d 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product.h +++ b/scann/scann/distance_measures/one_to_one/dot_product.h @@ -168,7 +168,7 @@ double DenseDotProduct(const DatapointPtr<T>& a, const DatapointPtr<U>& b, return DenseDotProductFallback(a, b, c); } -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ template <> inline double DenseDotProduct<uint8_t, uint8_t>( diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc b/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc index f97dfe238..f464b5fe8 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc +++ b/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc @@ -16,7 +16,7 @@ #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/internal/avx_funcs.h" diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx1.h b/scann/scann/distance_measures/one_to_one/dot_product_avx1.h index f86d99f3d..35f60b7a3 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_avx1.h +++ b/scann/scann/distance_measures/one_to_one/dot_product_avx1.h @@ -15,8 +15,8 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX1_H_ #define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX1_H_ #include <cstdint> -#ifdef __x86_64__ - +#if 1 // #ifdef __x86_64__ +#include "avx2ki.h" #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc b/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc index 8ae66a506..0893d4a88 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc +++ b/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc @@ -16,7 +16,7 @@ #include <cstdint> -#ifdef __x86_64__ 
+#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/internal/avx2_funcs.h" diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx2.h b/scann/scann/distance_measures/one_to_one/dot_product_avx2.h index 600c842a8..70754d77f 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_avx2.h +++ b/scann/scann/distance_measures/one_to_one/dot_product_avx2.h @@ -15,8 +15,8 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX2_H_ #define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX2_H_ #include <cstdint> -#ifdef __x86_64__ - +#if 1 // #ifdef __x86_64__ +#include "avx2ki.h" #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc b/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc index 49140b885..f8af08168 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc +++ b/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc @@ -15,7 +15,7 @@ #include "scann/distance_measures/one_to_one/dot_product_sse4.h" #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/sse4.h" @@ -224,7 +224,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a, __m128 accumulator = _mm_add_ps(accumulator0, accumulator1); accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - scalar_accumulator = accumulator[0]; + scalar_accumulator = accumulator.vect_f32[0]; } DCHECK_LT(aend - aptr, 4); @@ -283,12 +283,12 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<float>& a, } if (aptr < aend) { - accumulator[0] += aptr[0] * bptr[0]; + accumulator.vect_f32[0] += aptr[0] * bptr[0]; } accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - return accumulator[0]; + return accumulator.vect_f32[0]; } 
SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<double>& a, @@ -328,7 +328,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<double>& a, } accumulator = _mm_hadd_pd(accumulator, accumulator); - double result = accumulator[0]; + double result = accumulator.vect_f64[0]; if (aptr < aend) { result += *aptr * *bptr; @@ -423,7 +423,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a, __m128 accumulator = _mm_add_ps(accumulator0, accumulator1); accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - scalar_accumulator = accumulator[0]; + scalar_accumulator = accumulator.vect_f32[0]; } DCHECK_LT(aend - aptr, 4); @@ -528,7 +528,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a, __m128 accumulator = _mm_add_ps(accumulator0, accumulator1); accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - scalar_accumulator = accumulator[0]; + scalar_accumulator = accumulator.vect_f32[0]; } DCHECK_LT(aend - aptr, 4); diff --git a/scann/scann/distance_measures/one_to_one/dot_product_sse4.h b/scann/scann/distance_measures/one_to_one/dot_product_sse4.h index efd595272..12126c4de 100644 --- a/scann/scann/distance_measures/one_to_one/dot_product_sse4.h +++ b/scann/scann/distance_measures/one_to_one/dot_product_sse4.h @@ -15,7 +15,7 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_SSE4_H_ #define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_SSE4_H_ #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/distance_measures/one_to_one/l1_distance.h b/scann/scann/distance_measures/one_to_one/l1_distance.h index f301158d6..3f1e6cbfa 100644 --- a/scann/scann/distance_measures/one_to_one/l1_distance.h +++ b/scann/scann/distance_measures/one_to_one/l1_distance.h @@ -100,7 
+100,7 @@ double DenseL1Norm(const DatapointPtr<T>& a, const DatapointPtr<U>& b) { return DenseL1NormFallback(a, b); } -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ template <> inline double DenseL1Norm<float, float>(const DatapointPtr<float>& a, diff --git a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc index 664b66f30..8eda01edc 100644 --- a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc +++ b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc @@ -15,7 +15,7 @@ #include "scann/distance_measures/one_to_one/l1_distance_sse4.h" #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/sse4.h" @@ -78,13 +78,13 @@ SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<float>& a, } if (aptr < aend) { - accumulator0[0] += std::abs(aptr[0] - bptr[0]); + accumulator0.vect_f32[0] += std::abs(aptr[0] - bptr[0]); } __m128 accumulator = _mm_add_ps(accumulator0, accumulator1); accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - return accumulator[0]; + return accumulator.vect_f32[0]; } SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<double>& a, @@ -130,7 +130,7 @@ SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<double>& a, __m128d accumulator = _mm_add_pd(accumulator0, accumulator1); accumulator = _mm_hadd_pd(accumulator, accumulator); - double result = accumulator[0]; + double result = accumulator.vect_f64[0]; if (aptr < aend) { result += std::abs(*aptr - *bptr); diff --git a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h index 1fccffe4e..9f76a1f5c 100644 --- a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h +++ b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h @@ -14,7 +14,7 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L1_DISTANCE_SSE4_H_ #define 
SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L1_DISTANCE_SSE4_H_ -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/distance_measures/one_to_one/l2_distance.h b/scann/scann/distance_measures/one_to_one/l2_distance.h index dba49ead7..1851a1130 100644 --- a/scann/scann/distance_measures/one_to_one/l2_distance.h +++ b/scann/scann/distance_measures/one_to_one/l2_distance.h @@ -180,7 +180,7 @@ double DenseSquaredL2Distance(const DatapointPtr<T>& a, return DenseSquaredL2DistanceFallback(a, b); } -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ template <> inline double DenseSquaredL2Distance<uint8_t, uint8_t>( diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc index dd82f2472..0e4dcc053 100644 --- a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc +++ b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include "scann/distance_measures/one_to_one/l2_distance_avx1.h" -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/avx1.h" @@ -67,7 +67,7 @@ SCANN_AVX1_OUTLINE double DenseSquaredL2DistanceAvx1( bptr += 2; } __m128d sum = _mm_add_pd(upper, lower); - double result = sum[0] + sum[1]; + double result = sum.vect_f64[0] + sum.vect_f64[1]; if (aptr < aend) { const double to_square = *aptr - *bptr; diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h index d6073a0c5..98db35472 100644 --- a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h +++ b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h @@ -14,7 +14,7 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_AVX1_H_ #define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_AVX1_H_ -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc index bd0f4af04..339b65d37 100644 --- a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc +++ b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc @@ -16,7 +16,7 @@ #include <cstdint> #include <utility> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/sse4.h" @@ -209,12 +209,13 @@ SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4( } if (aptr < aend) { - accumulator[0] += (aptr[0] - bptr[0]) * (aptr[0] - bptr[0]); + // accumulator[0] += (aptr[0] - bptr[0]) * (aptr[0] - bptr[0]); + accumulator.vect_f32[0] += (aptr[0] - bptr[0]) * (aptr[0] - bptr[0]); } accumulator = _mm_hadd_ps(accumulator, accumulator); accumulator = _mm_hadd_ps(accumulator, accumulator); - return accumulator[0]; + return accumulator.float32x4_ptr[0]; } SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4( @@ -255,7 +256,7 @@ 
SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4( } accumulator = _mm_hadd_pd(accumulator, accumulator); - double result = accumulator[0]; + double result = accumulator.vect_f64[0]; if (aptr < aend) { const double diff = *aptr - *bptr; diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h index 5ae31a6a0..d0bfd1826 100644 --- a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h +++ b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h @@ -15,7 +15,7 @@ #ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_SSE4_H_ #define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_SSE4_H_ #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/data_format/datapoint.h" #include "scann/utils/intrinsics/attributes.h" diff --git a/scann/scann/hashes/asymmetric_hashing2/querying.h b/scann/scann/hashes/asymmetric_hashing2/querying.h index 08d9aa7a5..6a9e62d1c 100644 --- a/scann/scann/hashes/asymmetric_hashing2/querying.h +++ b/scann/scann/hashes/asymmetric_hashing2/querying.h @@ -453,11 +453,11 @@ Status AsymmetricQueryer<T>::FindApproximateTopNeighborsTopNDispatch( "The distance type for TopN must be float for " "AsymmetricQueryer::FindApproximateNeighbors."); - const bool can_use_lut16 = - RuntimeSupportsSse4() && querying_options.lut16_packed_dataset && - !lookup_table.int8_lookup_table.empty() && - (lookup_table.int8_lookup_table.size() / - querying_options.lut16_packed_dataset->num_blocks) == 16; + const bool can_use_lut16 = true; + // RuntimeSupportsSse4() && querying_options.lut16_packed_dataset && + // !lookup_table.int8_lookup_table.empty() && + // (lookup_table.int8_lookup_table.size() / + // querying_options.lut16_packed_dataset->num_blocks) == 16; if (!can_use_lut16) return InvalidArgumentError( "FastTopNeighbors+AsymmetricQueryer fast path only works with LUT16."); diff --git a/scann/scann/hashes/internal/asymmetric_hashing_impl.cc 
b/scann/scann/hashes/internal/asymmetric_hashing_impl.cc index 4c375cbc7..5909f2012 100644 --- a/scann/scann/hashes/internal/asymmetric_hashing_impl.cc +++ b/scann/scann/hashes/internal/asymmetric_hashing_impl.cc @@ -419,9 +419,10 @@ Status ValidateNoiseShapingParams(double threshold, double eta) { "indexing."); } if (!std::isnan(eta) && !std::isnan(threshold)) { - return InvalidArgumentError( - "Threshold and eta may not both be specified for noise-shaped AH " - "indexing."); + //return InvalidArgumentError( + // "Threshold and eta may not both be specified for noise-shaped AH " + // "indexing."); + //LOG(INFO) << "hreshold and eta may not both be specified for noise-shaped AH indexing."; } return OkStatus(); } diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc index 4d38c56cb..284e7149f 100644 --- a/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc +++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx2.inc" namespace research_scann { @@ -25,4 +25,4 @@ template class LUT16Avx2<{BATCH_SIZE}, PrefetchStrategy::kSmart>; } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc index 34a69bd80..f3d730b7a 100644 --- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc +++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx512.inc" namespace research_scann { @@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kOff>; } } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc index 5cd64c34c..642a3df1b 100644 --- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc +++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx512.inc" namespace research_scann { @@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kSeq>; } } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc index 5365f86cf..4c7d3de34 100644 --- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc +++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx512.inc" namespace research_scann { @@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kSmart>; } } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc index bba39f36e..920104e4b 100644 --- a/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc +++ b/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_sse4.inc" namespace research_scann { @@ -25,4 +25,4 @@ template class LUT16Sse4<{BATCH_SIZE}, PrefetchStrategy::kSmart>; } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/lut16_avx2.h b/scann/scann/hashes/internal/lut16_avx2.h index fde3ce216..c6a509b7d 100644 --- a/scann/scann/hashes/internal/lut16_avx2.h +++ b/scann/scann/hashes/internal/lut16_avx2.h @@ -16,7 +16,7 @@ #define SCANN_HASHES_INTERNAL_LUT16_AVX2_H_ #include <cstdint> -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_args.h" #include "scann/utils/intrinsics/attributes.h" @@ -45,5 +45,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Avx2); } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif #endif diff --git a/scann/scann/hashes/internal/lut16_avx2.inc b/scann/scann/hashes/internal/lut16_avx2.inc index dda59d5eb..c17570a82 100644 --- a/scann/scann/hashes/internal/lut16_avx2.inc +++ b/scann/scann/hashes/internal/lut16_avx2.inc @@ -4,7 +4,7 @@ #include "scann/oss_wrappers/scann_bits.h" #include "scann/utils/common.h" -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ 
#include "scann/utils/bits.h" #include "scann/utils/intrinsics/avx2.h" @@ -522,4 +522,4 @@ SCANN_AVX2_OUTLINE void LUT16Avx2<kNumQueries, kPrefetch>::GetTopFloatDistances( } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/lut16_avx512.h b/scann/scann/hashes/internal/lut16_avx512.h index e973076f9..b833499fc 100644 --- a/scann/scann/hashes/internal/lut16_avx512.h +++ b/scann/scann/hashes/internal/lut16_avx512.h @@ -16,7 +16,7 @@ #define SCANN_HASHES_INTERNAL_LUT16_AVX512_H_ #include <cstdint> -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_args.h" #include "scann/utils/types.h" @@ -45,5 +45,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Avx512); } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif #endif diff --git a/scann/scann/hashes/internal/lut16_avx512.inc b/scann/scann/hashes/internal/lut16_avx512.inc index fe2348e16..4acfd2755 100644 --- a/scann/scann/hashes/internal/lut16_avx512.inc +++ b/scann/scann/hashes/internal/lut16_avx512.inc @@ -6,7 +6,7 @@ #include "scann/oss_wrappers/scann_bits.h" #include "scann/utils/common.h" -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx512_swizzle.h" #include "scann/utils/bits.h" @@ -798,4 +798,4 @@ void LUT16Avx512<kNumQueries, kPrefetch>::GetTopFloatDistances( } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/hashes/internal/lut16_avx512_swizzle.cc b/scann/scann/hashes/internal/lut16_avx512_swizzle.cc index 70bcfad22..a2a472a8b 100644 --- a/scann/scann/hashes/internal/lut16_avx512_swizzle.cc +++ b/scann/scann/hashes/internal/lut16_avx512_swizzle.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/hashes/internal/lut16_avx512_swizzle.h" #include "scann/utils/common.h" #include "scann/utils/intrinsics/avx512.h" diff --git a/scann/scann/hashes/internal/lut16_avx512_swizzle.h b/scann/scann/hashes/internal/lut16_avx512_swizzle.h index 35bcffa2e..1eea8437c 100644 --- a/scann/scann/hashes/internal/lut16_avx512_swizzle.h +++ b/scann/scann/hashes/internal/lut16_avx512_swizzle.h @@ -15,7 +15,7 @@ #ifndef SCANN_HASHES_INTERNAL_LUT16_AVX512_SWIZZLE_H_ #define SCANN_HASHES_INTERNAL_LUT16_AVX512_SWIZZLE_H_ #include <cstdint> -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/attributes.h" #include "tensorflow/core/platform/types.h" diff --git a/scann/scann/hashes/internal/lut16_interface.h b/scann/scann/hashes/internal/lut16_interface.h index c4db23332..9808da05d 100644 --- a/scann/scann/hashes/internal/lut16_interface.h +++ b/scann/scann/hashes/internal/lut16_interface.h @@ -154,7 +154,7 @@ class LUT16Interface { LOG(FATAL) << "Invalid Batch Size"; \ } -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ #define SCANN_CALL_LUT16_FUNCTION(enable_avx512_codepath, batch_size, \ prefetch_strategy, Function, ...) 
\ @@ -251,32 +251,32 @@ void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) { std::move(args)); } -#else +// #else -void LUT16Interface::GetDistances(LUT16Args<int16_t> args) { - LOG(FATAL) << "LUT16 is only supported on x86!"; -} +// void LUT16Interface::GetDistances(LUT16Args<int16_t> args) { +// LOG(FATAL) << "LUT16 is only supported on x86!"; +// } -void LUT16Interface::GetDistances(LUT16Args<int32_t> args) { - LOG(FATAL) << "LUT16 is only supported on x86!"; -} +// void LUT16Interface::GetDistances(LUT16Args<int32_t> args) { +// LOG(FATAL) << "LUT16 is only supported on x86!"; +// } -void LUT16Interface::GetFloatDistances(LUT16Args<float> args, - ConstSpan<float> inv_fp_multipliers) { - LOG(FATAL) << "LUT16 is only supported on x86!"; -} +// void LUT16Interface::GetFloatDistances(LUT16Args<float> args, +// ConstSpan<float> inv_fp_multipliers) { +// LOG(FATAL) << "LUT16 is only supported on x86!"; +// } -template <typename TopN> -void LUT16Interface::GetTopDistances(LUT16ArgsTopN<int16_t, TopN> args) { - LOG(FATAL) << "LUT16 is only supported on x86!"; -} +// template <typename TopN> +// void LUT16Interface::GetTopDistances(LUT16ArgsTopN<int16_t, TopN> args) { +// LOG(FATAL) << "LUT16 is only supported on x86!"; +// } -template <typename TopN> -void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) { - LOG(FATAL) << "LUT16 is only supported on x86!"; -} +// template <typename TopN> +// void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) { +// LOG(FATAL) << "LUT16 is only supported on x86!"; +// } -#endif +// #endif } // namespace asymmetric_hashing_internal } // namespace research_scann diff --git a/scann/scann/hashes/internal/lut16_sse4.h b/scann/scann/hashes/internal/lut16_sse4.h index b228dd9d3..71e89c273 100644 --- a/scann/scann/hashes/internal/lut16_sse4.h +++ b/scann/scann/hashes/internal/lut16_sse4.h @@ -17,7 +17,7 @@ #include <cstdint> -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ 
#include "scann/hashes/internal/lut16_args.h" #include "scann/utils/intrinsics/attributes.h" @@ -46,5 +46,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Sse4); } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif #endif diff --git a/scann/scann/hashes/internal/lut16_sse4.inc b/scann/scann/hashes/internal/lut16_sse4.inc index 4198f5fab..b2e134cef 100644 --- a/scann/scann/hashes/internal/lut16_sse4.inc +++ b/scann/scann/hashes/internal/lut16_sse4.inc @@ -4,7 +4,7 @@ #include "scann/oss_wrappers/scann_bits.h" #include "scann/utils/common.h" -#ifdef __x86_64__ +// #if 1 // #ifdef __x86_64__ #include "scann/utils/bits.h" #include "scann/utils/intrinsics/sse4.h" @@ -32,11 +32,11 @@ SCANN_SSE4_INLINE Sse4<int16_t, kNumQueries, 4> Sse4LUT16BottomLoop( const Sse4<uint8_t> sign7 = 0x0F; const Sse4<int16_t> total_bias = num_blocks * 128; for (; num_blocks != 0; --num_blocks, data_start += 16) { - if (kPrefetch != PrefetchStrategy::kOff) { + /*if (kPrefetch != PrefetchStrategy::kOff) { ::tensorflow::port::prefetch<::tensorflow::port::PREFETCH_HINT_NTA>( data_start + kPrefetchBytesAhead); - } - + }*/ + //__builtin_prefetch(data_start + kPrefetchBytesAhead + kPrefetchBytesAhead, 0, 0); auto mask = Sse4<uint8_t>::Load(data_start); Sse4<uint8_t> mask0 = mask & sign7; Sse4<uint8_t> mask1 = Sse4<uint8_t>((Sse4<uint16_t>(mask) >> 4)) & sign7; @@ -399,4 +399,4 @@ SCANN_SSE4_OUTLINE void LUT16Sse4<kNumQueries, kPrefetch>::GetTopFloatDistances( } // namespace asymmetric_hashing_internal } // namespace research_scann -#endif +//#endif diff --git a/scann/scann/partitioning/kmeans_tree_partitioner.h b/scann/scann/partitioning/kmeans_tree_partitioner.h index 74dfa12af..4415a72fe 100644 --- a/scann/scann/partitioning/kmeans_tree_partitioner.h +++ b/scann/scann/partitioning/kmeans_tree_partitioner.h @@ -30,7 +30,7 @@ #include "scann/oss_wrappers/scann_status.h" #include "scann/oss_wrappers/scann_threadpool.h" #include 
"scann/partitioning/kmeans_tree_like_partitioner.h" -#include "scann/partitioning/orthogonality_amplification_utils.h" +// #include "scann/partitioning/orthogonality_amplification_utils.h" #include "scann/partitioning/partitioner.pb.h" #include "scann/partitioning/partitioner_base.h" #include "scann/trees/kmeans_tree/kmeans_tree.h" diff --git a/scann/scann/utils/fast_top_neighbors.cc b/scann/scann/utils/fast_top_neighbors.cc index 732835fb0..16ca76032 100644 --- a/scann/scann/utils/fast_top_neighbors.cc +++ b/scann/scann/utils/fast_top_neighbors.cc @@ -120,7 +120,7 @@ SCANN_INLINE DistT FastMedianOf3(DistT v0, DistT v1, DistT v2) { } // namespace -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ namespace avx2 { #define SCANN_SIMD_ATTRIBUTE SCANN_AVX2 @@ -141,7 +141,7 @@ size_t FastTopNeighbors<DistT, DatapointIndexT>::ApproxNthElement( size_t keep_min, size_t keep_max, size_t sz, DatapointIndexT* ii, DistT* dd, uint32_t* mm) { DCHECK_GT(keep_min, 0); -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ if (RuntimeSupportsAvx2()) { return avx2::ApproxNthElementImpl(keep_min, keep_max, sz, ii, dd, mm); } else if (RuntimeSupportsSse4()) { diff --git a/scann/scann/utils/internal/avx2_funcs.h b/scann/scann/utils/internal/avx2_funcs.h index 06f3e55ae..9bca61801 100644 --- a/scann/scann/utils/internal/avx2_funcs.h +++ b/scann/scann/utils/internal/avx2_funcs.h @@ -14,7 +14,7 @@ #ifndef SCANN_UTILS_INTERNAL_AVX2_FUNCS_H_ #define SCANN_UTILS_INTERNAL_AVX2_FUNCS_H_ -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/avx2.h" #include "scann/utils/types.h" @@ -59,7 +59,9 @@ class AvxFunctionsAvx2Fma { __m128 sum = _mm_add_ps(upper, lower); sum = _mm_add_ps( sum, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(sum), 8))); - return sum[0] + sum[1]; + //return sum[0] + sum[1]; + //return _mm_extract_ps(sum, 0) + _mm_extract_ps(sum, 1); + return sum.float32x4_ptr[0] + sum.float32x4_ptr[1]; } }; diff --git a/scann/scann/utils/internal/avx_funcs.h 
b/scann/scann/utils/internal/avx_funcs.h index 9eec38b07..8d9abf71c 100644 --- a/scann/scann/utils/internal/avx_funcs.h +++ b/scann/scann/utils/internal/avx_funcs.h @@ -14,11 +14,11 @@ #ifndef SCANN_UTILS_INTERNAL_AVX_FUNCS_H_ #define SCANN_UTILS_INTERNAL_AVX_FUNCS_H_ -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #include "scann/utils/intrinsics/avx1.h" #include "scann/utils/types.h" - +#include "operatoroverload.h" namespace research_scann { class AvxFunctionsAvx { @@ -63,7 +63,9 @@ class AvxFunctionsAvx { __m128 sum = _mm_add_ps(upper, lower); sum = _mm_add_ps( sum, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(sum), 8))); - return sum[0] + sum[1]; + //return sum[0] + sum[1]; + //return _mm_extract_ps(sum, 0) + _mm_extract_ps(sum, 1); + return sum.float32x4_ptr[0] + sum.float32x4_ptr[1]; } }; diff --git a/scann/scann/utils/intrinsics/BUILD.bazel b/scann/scann/utils/intrinsics/BUILD.bazel index 819b6ef76..63c84ced6 100644 --- a/scann/scann/utils/intrinsics/BUILD.bazel +++ b/scann/scann/utils/intrinsics/BUILD.bazel @@ -103,6 +103,7 @@ cc_library( ":flags", "//scann/utils:index_sequence", "//scann/utils:types", + "@ksl_external_lib//:avx2ki", ], ) @@ -137,6 +138,7 @@ cc_library( ":flags", "//scann/utils:index_sequence", "//scann/utils:types", + "@ksl_external_lib//:avx2ki", ], ) @@ -154,5 +156,6 @@ cc_library( ":flags", "//scann/utils:index_sequence", "//scann/utils:types", + "@ksl_external_lib//:avx2ki", ], ) diff --git a/scann/scann/utils/intrinsics/attributes.h b/scann/scann/utils/intrinsics/attributes.h index b3d1a851a..ea6e55fa4 100644 --- a/scann/scann/utils/intrinsics/attributes.h +++ b/scann/scann/utils/intrinsics/attributes.h @@ -15,13 +15,13 @@ #ifndef SCANN_UTILS_INTRINSICS_ATTRIBUTES_H_ #define SCANN_UTILS_INTRINSICS_ATTRIBUTES_H_ -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ #define SCANN_SSE4 -#define SCANN_AVX1 __attribute((target("avx"))) -#define SCANN_AVX2 __attribute((target("avx,avx2,fma"))) +#define SCANN_AVX1 //__attribute((target("avx"))) 
+#define SCANN_AVX2 //__attribute((target("avx,avx2,fma"))) #define SCANN_AVX512 \ - __attribute((target("avx,avx2,fma,avx512f,avx512dq,avx512bw"))) + //__attribute((target("avx,avx2,fma,avx512f,avx512dq,avx512bw"))) #else diff --git a/scann/scann/utils/intrinsics/avx1.h b/scann/scann/utils/intrinsics/avx1.h index 11c9da5b7..1a359c46a 100644 --- a/scann/scann/utils/intrinsics/avx1.h +++ b/scann/scann/utils/intrinsics/avx1.h @@ -25,9 +25,9 @@ #include "scann/utils/intrinsics/sse4.h" #include "scann/utils/types.h" -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ -#include <x86intrin.h> +#include "avx2ki.h" //<x86intrin.h> namespace research_scann { namespace avx1 { diff --git a/scann/scann/utils/intrinsics/avx2.h b/scann/scann/utils/intrinsics/avx2.h index 280ae33ef..70844345f 100644 --- a/scann/scann/utils/intrinsics/avx2.h +++ b/scann/scann/utils/intrinsics/avx2.h @@ -25,9 +25,11 @@ #include "scann/utils/intrinsics/flags.h" #include "scann/utils/types.h" -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ + + +#include "avx2ki.h" -#include <x86intrin.h> namespace research_scann { namespace avx2 { @@ -166,7 +168,7 @@ class Avx2<T, kNumRegistersInferred> { if constexpr (IsSameAny<T, float>()) { return _mm256_setzero_ps(); } else if constexpr (IsSameAny<T, double>()) { - return _mm256_setzero_ps(); + return _mm256_setzero_pd(); } else { return _mm256_setzero_si256(); } @@ -1003,14 +1005,14 @@ using Uninitialized = Avx2Uninitialized; } // namespace avx2 } // namespace research_scann -#else +// #else -namespace research_scann { +// namespace research_scann { -template <typename T, size_t... kTensorNumRegisters> -struct Avx2; +// template <typename T, size_t... 
kTensorNumRegisters> +// struct Avx2; -} +// } -#endif +// #endif #endif diff --git a/scann/scann/utils/intrinsics/avx512.h b/scann/scann/utils/intrinsics/avx512.h index 6e27632fe..b79800167 100644 --- a/scann/scann/utils/intrinsics/avx512.h +++ b/scann/scann/utils/intrinsics/avx512.h @@ -25,9 +25,10 @@ #include "scann/utils/intrinsics/flags.h" #include "scann/utils/types.h" -#ifdef __x86_64__ +//#if 1 // #ifdef __x86_64__ + +#include "avx2ki.h" -#include <x86intrin.h> namespace research_scann { namespace avx512 { @@ -150,7 +151,7 @@ class Avx512<T, kNumRegistersInferred> { if constexpr (IsSameAny<T, float>()) { return _mm512_setzero_ps(); } else if constexpr (IsSameAny<T, double>()) { - return _mm512_setzero_ps(); + return _mm512_setzero_pd(); } else { return _mm512_setzero_si512(); } @@ -216,7 +217,7 @@ class Avx512<T, kNumRegistersInferred> { if constexpr (IsSameAny<T, float>()) { return _mm512_loadu_ps(reinterpret_cast<const __m512*>(address)); } else if constexpr (IsSameAny<T, double>()) { - return _mm512_loadu_pd(reinterpret_cast<const __m512d*>(address)); + return _mm512_loadu_pd(reinterpret_cast<const double*>(address)); } else { return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(address)); } @@ -1043,14 +1044,14 @@ using Uninitialized = Avx512Uninitialized; } // namespace avx512 } // namespace research_scann -#else +// #else -namespace research_scann { +// namespace research_scann { -template <typename T, size_t... kTensorNumRegisters> -struct Avx512; +// template <typename T, size_t... 
kTensorNumRegisters> +// struct Avx512; -} +// } -#endif +// #endif #endif diff --git a/scann/scann/utils/intrinsics/fallback.h b/scann/scann/utils/intrinsics/fallback.h index 94f91b740..76cea4cb4 100644 --- a/scann/scann/utils/intrinsics/fallback.h +++ b/scann/scann/utils/intrinsics/fallback.h @@ -142,7 +142,7 @@ class Simd<T, kNumElementsArg> { } } - static SCANN_INLINE Simd Zeros() { + static SCANN_INLINE Simd Zeros_s() { Simd<T, kNumElements> ret; for (size_t j : Seq(kNumElements)) { ret[j] = IntelType(0); diff --git a/scann/scann/utils/intrinsics/flags.cc b/scann/scann/utils/intrinsics/flags.cc index 5663a5009..0ba7d2da5 100644 --- a/scann/scann/utils/intrinsics/flags.cc +++ b/scann/scann/utils/intrinsics/flags.cc @@ -37,14 +37,14 @@ ABSL_RETIRED_FLAG(bool, ignore_sse4, false, "Ignore SSE4"); namespace research_scann { namespace flags_internal { -bool should_use_sse4 = - tensorflow::port::TestCPUFeature(tensorflow::port::SSE4_2); +bool should_use_sse4 = 1; + //tensorflow::port::TestCPUFeature(tensorflow::port::SSE4_2); bool should_use_avx1 = tensorflow::port::TestCPUFeature(tensorflow::port::AVX); -bool should_use_avx2 = tensorflow::port::TestCPUFeature(tensorflow::port::AVX2); -bool should_use_avx512 = - tensorflow::port::TestCPUFeature(tensorflow::port::AVX512F) && - tensorflow::port::TestCPUFeature(tensorflow::port::AVX512DQ) && - tensorflow::port::TestCPUFeature(tensorflow::port::AVX512BW); +bool should_use_avx2 = 1; // tensorflow::port::TestCPUFeature(tensorflow::port::AVX2); +bool should_use_avx512 = 1; + //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512F) && + //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512DQ) && + //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512BW); } // namespace flags_internal diff --git a/scann/scann/utils/intrinsics/fma.h b/scann/scann/utils/intrinsics/fma.h index b6a158f97..cb3b1b236 100644 --- a/scann/scann/utils/intrinsics/fma.h +++ b/scann/scann/utils/intrinsics/fma.h @@ -20,7 +20,7 @@ namespace 
research_scann { -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ namespace avx512 { #define SCANN_SIMD_ATTRIBUTE SCANN_AVX512 diff --git a/scann/scann/utils/intrinsics/horizontal_sum.h b/scann/scann/utils/intrinsics/horizontal_sum.h index 6445eeac3..e33bac8f2 100644 --- a/scann/scann/utils/intrinsics/horizontal_sum.h +++ b/scann/scann/utils/intrinsics/horizontal_sum.h @@ -43,7 +43,7 @@ SCANN_INLINE void HorizontalSum4X(Simd<FloatT> a, Simd<FloatT> b, } // namespace fallback -#ifdef __x86_64__ +#if 1 // #ifdef __x86_64__ namespace sse4 { @@ -135,8 +135,8 @@ SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<float> a, Avx1<float> b, sum += _mm256_shuffle_ps(sum, sum, 0b11'10'01'01); - *resulta = sum[0]; - *resultb = sum[4]; + *resulta = sum.vect_f32[0][0]; + *resultb = sum.vect_f32[1][0]; } SCANN_AVX1_INLINE void HorizontalSum3X(Avx1<float> a, Avx1<float> b, @@ -148,9 +148,9 @@ SCANN_AVX1_INLINE void HorizontalSum3X(Avx1<float> a, Avx1<float> b, abcg += _mm256_shuffle_ps(abcg, abcg, 0b11'11'01'01); - *resulta = abcg[0]; - *resultb = abcg[2]; - *resultc = abcg[4]; + *resulta = abcg.vect_f32[0][0]; + *resultb = abcg.vect_f32[0][2]; + *resultc = abcg.vect_f32[1][0]; } SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<float> a, Avx1<float> b, @@ -163,10 +163,10 @@ SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<float> a, Avx1<float> b, abcd += _mm256_shuffle_ps(abcd, abcd, 0b11'11'01'01); - *resulta = abcd[0]; - *resultb = abcd[2]; - *resultc = abcd[4]; - *resultd = abcd[6]; + *resulta = abcd.vect_f32[0][0]; + *resultb = abcd.vect_f32[0][2]; + *resultc = abcd.vect_f32[1][0]; + *resultd = abcd.vect_f32[1][2]; } SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<double> a, Avx1<double> b, @@ -175,8 +175,8 @@ SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<double> a, Avx1<double> b, sum += _mm256_shuffle_pd(sum, sum, 0b11'11); - *resulta = sum[0]; - *resultb = sum[2]; + *resulta = sum.vect_f64[0][0]; + *resultb = sum.vect_f64[1][0]; } SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<double> a, Avx1<double> b,
b, diff --git a/scann/scann/utils/intrinsics/sse4.h b/scann/scann/utils/intrinsics/sse4.h index b99ac792e..e3d098e8c 100644 --- a/scann/scann/utils/intrinsics/sse4.h +++ b/scann/scann/utils/intrinsics/sse4.h @@ -24,10 +24,8 @@ #include "scann/utils/intrinsics/flags.h" #include "scann/utils/types.h" -#ifdef __x86_64__ - -#include <emmintrin.h> -#include <x86intrin.h> +//#if 1 // #ifdef __x86_64__ +#include "avx2ki.h" namespace research_scann { namespace sse4 { @@ -715,7 +713,8 @@ class Sse4<T, kNumRegistersInferred> { const auto& me = *this; if constexpr (IsSameAny<T, float, double>()) { - return (*me[0])[0]; + //return (*me[0])[0]; + return *(T *)&me; } if constexpr (IsSameAny<T, int8_t, uint8_t>()) { @@ -728,7 +727,8 @@ class Sse4<T, kNumRegistersInferred> { return _mm_cvtsi128_si32(*me[0]); } if constexpr (IsSameAny<T, int64_t, uint64_t>()) { - return (*me[0])[0]; + //return (*me[0])[0]; + return *(T *)&me; } LOG(FATAL) << "Undefined"; } @@ -802,8 +802,8 @@ class Sse4<T, kNumRegistersInferred> { static_assert(!IsSame<T, double>(), "Nothing to expand to"); if constexpr (!IsSameAny<T, float, double>()) { - __m128 hi = _mm_srli_si128(x, 8); - __m128 lo = x; + __m128i hi = _mm_srli_si128(x, 8); + __m128i lo = x; if constexpr (IsSame<T, int8_t>()) { return std::make_pair(_mm_cvtepi8_epi16(lo), _mm_cvtepi8_epi16(hi)); @@ -992,14 +992,14 @@ using Uninitialized = Sse4Uninitialized; } // namespace sse4 } // namespace research_scann -#else +// #else -namespace research_scann { +// namespace research_scann { -template <typename T, size_t... kTensorNumRegisters> -struct Sse4; +// template <typename T, size_t... kTensorNumRegisters> +// struct Sse4; -} +// } -#endif +// #endif #endif -- 2.33.0