ScaNN源码补丁文件
ScaNN 源码补丁文件 0001-x86-to-arm64.patch 的内容如下：
From c4603d15c7a0884e1392b50b569b21ef95c6e8e9 Mon Sep 17 00:00:00 2001
From: root <root@localhost.localdomain>
Date: Sat, 6 Jan 2024 16:11:03 +0800
Subject: [PATCH] x86 to arm64
---
scann/WORKSPACE | 14 +++++++
.../many_to_many/many_to_many_common.h | 4 +-
.../many_to_many/many_to_many_impl.inc | 10 ++---
.../many_to_many/many_to_many_templates.h | 6 +--
.../one_to_many/one_to_many.h | 4 +-
.../one_to_one/dot_product.h | 2 +-
.../one_to_one/dot_product_avx1.cc | 2 +-
.../one_to_one/dot_product_avx1.h | 4 +-
.../one_to_one/dot_product_avx2.cc | 2 +-
.../one_to_one/dot_product_avx2.h | 4 +-
.../one_to_one/dot_product_sse4.cc | 14 +++----
.../one_to_one/dot_product_sse4.h | 2 +-
.../one_to_one/l1_distance.h | 2 +-
.../one_to_one/l1_distance_sse4.cc | 8 ++--
.../one_to_one/l1_distance_sse4.h | 2 +-
.../one_to_one/l2_distance.h | 2 +-
.../one_to_one/l2_distance_avx1.cc | 4 +-
.../one_to_one/l2_distance_avx1.h | 2 +-
.../one_to_one/l2_distance_sse4.cc | 9 ++--
.../one_to_one/l2_distance_sse4.h | 2 +-
.../hashes/asymmetric_hashing2/querying.h | 10 ++---
.../internal/asymmetric_hashing_impl.cc | 7 ++--
.../bazel_templates/lut16_avx2.tpl.cc | 4 +-
.../lut16_avx512_noprefetch.tpl.cc | 4 +-
.../lut16_avx512_prefetch.tpl.cc | 4 +-
.../bazel_templates/lut16_avx512_smart.tpl.cc | 4 +-
.../bazel_templates/lut16_sse4.tpl.cc | 4 +-
scann/scann/hashes/internal/lut16_avx2.h | 4 +-
scann/scann/hashes/internal/lut16_avx2.inc | 4 +-
scann/scann/hashes/internal/lut16_avx512.h | 4 +-
scann/scann/hashes/internal/lut16_avx512.inc | 4 +-
.../hashes/internal/lut16_avx512_swizzle.cc | 2 +-
.../hashes/internal/lut16_avx512_swizzle.h | 2 +-
scann/scann/hashes/internal/lut16_interface.h | 42 +++++++++----------
scann/scann/hashes/internal/lut16_sse4.h | 4 +-
scann/scann/hashes/internal/lut16_sse4.inc | 10 ++---
.../partitioning/kmeans_tree_partitioner.h | 2 +-
scann/scann/utils/fast_top_neighbors.cc | 4 +-
scann/scann/utils/internal/avx2_funcs.h | 6 ++-
scann/scann/utils/internal/avx_funcs.h | 8 ++--
scann/scann/utils/intrinsics/BUILD.bazel | 3 ++
scann/scann/utils/intrinsics/attributes.h | 8 ++--
scann/scann/utils/intrinsics/avx1.h | 4 +-
scann/scann/utils/intrinsics/avx2.h | 20 +++++----
scann/scann/utils/intrinsics/avx512.h | 21 +++++-----
scann/scann/utils/intrinsics/fallback.h | 2 +-
scann/scann/utils/intrinsics/flags.cc | 14 +++----
scann/scann/utils/intrinsics/fma.h | 2 +-
scann/scann/utils/intrinsics/horizontal_sum.h | 24 +++++------
scann/scann/utils/intrinsics/sse4.h | 28 ++++++-------
50 files changed, 192 insertions(+), 166 deletions(-)
diff --git a/scann/WORKSPACE b/scann/WORKSPACE
index 5b01155f6..39bdc4e9a 100644
--- a/scann/WORKSPACE
+++ b/scann/WORKSPACE
@@ -3,6 +3,20 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("//build_deps/py:python_configure.bzl", "python_configure")
load("//build_deps/tf_dependency:tf_configure.bzl", "tf_configure")
+new_local_repository(
+ name = "ksl_external_lib",
+ path = "/usr/local/ksl/",
+ build_file_content = """
+cc_library(
+ name = "avx2ki",
+ srcs = ["lib/libavx2ki.so"],
+ hdrs = glob(["include/*.h"]),
+ includes = ["include/"],
+ visibility = ["//visibility:public"],
+)
+""",
+)
+
# Needed for highway's config_setting_group
http_archive(
name = "bazel_skylib",
diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_common.h b/scann/scann/distance_measures/many_to_many/many_to_many_common.h
index fe8c1ff89..6a07a23c6 100644
--- a/scann/scann/distance_measures/many_to_many/many_to_many_common.h
+++ b/scann/scann/distance_measures/many_to_many/many_to_many_common.h
@@ -44,7 +44,7 @@ class EpsilonFilteringCallback {
ManyToManyResultsCallback<FloatT> slow_path_fn)
: epsilons_(epsilons), slow_path_fn_(std::move(slow_path_fn)) {}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
SCANN_AVX512_INLINE void InvokeOptimized(Avx512<float, 2> simd_dists,
size_t first_dp_idx,
@@ -223,7 +223,7 @@ class EpsilonFilteringOffsetWrapper {
dp_idx_offset_(dp_idx_offset),
query_idx_table_(query_idx_table) {}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
SCANN_AVX512_INLINE void InvokeOptimized(Avx512<float, 2> simd_dists,
size_t first_dp_idx,
diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc b/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc
index debad5e53..73dc8a4a9 100644
--- a/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc
+++ b/scann/scann/distance_measures/many_to_many/many_to_many_impl.inc
@@ -34,7 +34,7 @@ SCANN_SIMD_INLINE void ExpandPretransposedFP8BlockImpl(
if (n_to_transpose == kElementsPerRegister) {
const int8_t* __restrict__ src = block.data();
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
if constexpr (IsSame<Simd<FloatT>, Avx2<float>>()) {
static_assert(kElementsPerRegister == 8);
@@ -42,9 +42,9 @@ SCANN_SIMD_INLINE void ExpandPretransposedFP8BlockImpl(
__m256 inv_multiplier_simd = _mm256_broadcast_ss((
inverse_multipliers_or_null + dim_idx));
- __m128i int8s = _mm_loadl_pi(_mm_setzero_si128(),
+ __m128 int8s = _mm_loadl_pi(_mm_castsi128_ps(_mm_setzero_si128()), // low 64 bits = 8 raw int8 bytes
reinterpret_cast<const __m64*>(src));
- __m256i int32s = _mm256_cvtepi8_epi32(int8s);
+ __m256i int32s = _mm256_cvtepi8_epi32(_mm_castps_si128(int8s)); // bit-cast, NOT _mm_cvtps_epi32 (value conversion would garble the bytes)
__m256 floats = _mm256_cvtepi32_ps(int32s) * inv_multiplier_simd;
_mm256_store_ps(transposed_storage, floats);
@@ -522,8 +522,8 @@ class DenseManyToManyTransposed final
Simd<FloatT>::Load(transposed_block0 + dim * kElementsPerRegister);
auto transposed_simd1 =
Simd<FloatT>::Load(transposed_block1 + dim * kElementsPerRegister);
-
- for (size_t j : Seq(kNumQueries)) {
+ //__builtin_prefetch(&accumulators[j][0], 0, 0);
+ for (size_t j : Seq(kNumQueries)) {//__builtin_prefetch(&accumulators[j][0], 0, 0);
Simd<FloatT> query_simd = query_ptrs[j][dim];
FusedMultiplySubtract(query_simd, transposed_simd0,
&accumulators[j][0]);
diff --git a/scann/scann/distance_measures/many_to_many/many_to_many_templates.h b/scann/scann/distance_measures/many_to_many/many_to_many_templates.h
index 9c8751e6a..e9d13a4c3 100644
--- a/scann/scann/distance_measures/many_to_many/many_to_many_templates.h
+++ b/scann/scann/distance_measures/many_to_many/many_to_many_templates.h
@@ -95,7 +95,7 @@
namespace research_scann {
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
namespace sse4 {
#define SCANN_SIMD_ATTRIBUTE SCANN_SSE4
@@ -170,7 +170,7 @@ SCANN_INLINE void DenseDistanceManyToManyImpl2(
DCHECK(IsSupportedDistanceMeasure(dist));
DCHECK_NE(dist.specially_optimized_distance_tag(), DistanceMeasure::COSINE);
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
if (RuntimeSupportsAvx512()) {
return avx512::DenseDistanceManyToManyImpl(dist, queries, database, pool,
callback);
@@ -202,7 +202,7 @@ SCANN_INLINE void DenseDistanceManyToManyFP8PretransposedImpl2(
DCHECK(IsSupportedDistanceMeasure(dist));
DCHECK_NE(dist.specially_optimized_distance_tag(), DistanceMeasure::COSINE);
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
if (RuntimeSupportsAvx512()) {
return avx512::DenseManyToManyFP8PretransposedImpl(dist, queries, database,
pool, callback);
diff --git a/scann/scann/distance_measures/one_to_many/one_to_many.h b/scann/scann/distance_measures/one_to_many/one_to_many.h
index 77463e7ef..ffbeb37fc 100644
--- a/scann/scann/distance_measures/one_to_many/one_to_many.h
+++ b/scann/scann/distance_measures/one_to_many/one_to_many.h
@@ -1724,7 +1724,7 @@ void DenseDistanceOneToMany(const DistanceMeasure& dist,
dist, query, database, result, &set_distance_functor, pool);
}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
namespace sse4 {
#define SCANN_SIMD_ATTRIBUTE SCANN_SSE4
@@ -1798,7 +1798,7 @@ SCANN_INLINE void OneToManyInt8FloatDispatch(
const float* __restrict__ inv_multipliers_for_squared_l2,
const IndexT* indices, MutableSpan<ResultElemT> result,
CallbackT callback) {
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
if constexpr (false && RuntimeSupportsAvx512()) {
LOG(FATAL) << "We aren't compiling Avx-512 support yet.";
diff --git a/scann/scann/distance_measures/one_to_one/dot_product.h b/scann/scann/distance_measures/one_to_one/dot_product.h
index a4897cac6..8f9cdec6d 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product.h
+++ b/scann/scann/distance_measures/one_to_one/dot_product.h
@@ -168,7 +168,7 @@ double DenseDotProduct(const DatapointPtr<T>& a, const DatapointPtr<U>& b,
return DenseDotProductFallback(a, b, c);
}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
template <>
inline double DenseDotProduct<uint8_t, uint8_t>(
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc b/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc
index f97dfe238..f464b5fe8 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc
+++ b/scann/scann/distance_measures/one_to_one/dot_product_avx1.cc
@@ -16,7 +16,7 @@
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/internal/avx_funcs.h"
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx1.h b/scann/scann/distance_measures/one_to_one/dot_product_avx1.h
index f86d99f3d..35f60b7a3 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_avx1.h
+++ b/scann/scann/distance_measures/one_to_one/dot_product_avx1.h
@@ -15,8 +15,8 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX1_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX1_H_
#include <cstdint>
-#ifdef __x86_64__
-
+#if 1 // #ifdef __x86_64__
+#include "avx2ki.h"
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc b/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc
index 8ae66a506..0893d4a88 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc
+++ b/scann/scann/distance_measures/one_to_one/dot_product_avx2.cc
@@ -16,7 +16,7 @@
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/internal/avx2_funcs.h"
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_avx2.h b/scann/scann/distance_measures/one_to_one/dot_product_avx2.h
index 600c842a8..70754d77f 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_avx2.h
+++ b/scann/scann/distance_measures/one_to_one/dot_product_avx2.h
@@ -15,8 +15,8 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX2_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_AVX2_H_
#include <cstdint>
-#ifdef __x86_64__
-
+#if 1 // #ifdef __x86_64__
+#include "avx2ki.h"
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc b/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc
index 49140b885..f8af08168 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc
+++ b/scann/scann/distance_measures/one_to_one/dot_product_sse4.cc
@@ -15,7 +15,7 @@
#include "scann/distance_measures/one_to_one/dot_product_sse4.h"
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/sse4.h"
@@ -224,7 +224,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a,
__m128 accumulator = _mm_add_ps(accumulator0, accumulator1);
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- scalar_accumulator = accumulator[0];
+ scalar_accumulator = accumulator.vect_f32[0];
}
DCHECK_LT(aend - aptr, 4);
@@ -283,12 +283,12 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<float>& a,
}
if (aptr < aend) {
- accumulator[0] += aptr[0] * bptr[0];
+ accumulator.vect_f32[0] += aptr[0] * bptr[0];
}
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- return accumulator[0];
+ return accumulator.vect_f32[0];
}
SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<double>& a,
@@ -328,7 +328,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<double>& a,
}
accumulator = _mm_hadd_pd(accumulator, accumulator);
- double result = accumulator[0];
+ double result = accumulator.vect_f64[0];
if (aptr < aend) {
result += *aptr * *bptr;
@@ -423,7 +423,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a,
__m128 accumulator = _mm_add_ps(accumulator0, accumulator1);
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- scalar_accumulator = accumulator[0];
+ scalar_accumulator = accumulator.vect_f32[0];
}
DCHECK_LT(aend - aptr, 4);
@@ -528,7 +528,7 @@ SCANN_SSE4_OUTLINE double DenseDotProductSse4(const DatapointPtr<int8_t>& a,
__m128 accumulator = _mm_add_ps(accumulator0, accumulator1);
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- scalar_accumulator = accumulator[0];
+ scalar_accumulator = accumulator.vect_f32[0];
}
DCHECK_LT(aend - aptr, 4);
diff --git a/scann/scann/distance_measures/one_to_one/dot_product_sse4.h b/scann/scann/distance_measures/one_to_one/dot_product_sse4.h
index efd595272..12126c4de 100644
--- a/scann/scann/distance_measures/one_to_one/dot_product_sse4.h
+++ b/scann/scann/distance_measures/one_to_one/dot_product_sse4.h
@@ -15,7 +15,7 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_SSE4_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_DOT_PRODUCT_SSE4_H_
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/distance_measures/one_to_one/l1_distance.h b/scann/scann/distance_measures/one_to_one/l1_distance.h
index f301158d6..3f1e6cbfa 100644
--- a/scann/scann/distance_measures/one_to_one/l1_distance.h
+++ b/scann/scann/distance_measures/one_to_one/l1_distance.h
@@ -100,7 +100,7 @@ double DenseL1Norm(const DatapointPtr<T>& a, const DatapointPtr<U>& b) {
return DenseL1NormFallback(a, b);
}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
template <>
inline double DenseL1Norm<float, float>(const DatapointPtr<float>& a,
diff --git a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc
index 664b66f30..8eda01edc 100644
--- a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc
+++ b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.cc
@@ -15,7 +15,7 @@
#include "scann/distance_measures/one_to_one/l1_distance_sse4.h"
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/sse4.h"
@@ -78,13 +78,13 @@ SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<float>& a,
}
if (aptr < aend) {
- accumulator0[0] += std::abs(aptr[0] - bptr[0]);
+ accumulator0.vect_f32[0] += std::abs(aptr[0] - bptr[0]);
}
__m128 accumulator = _mm_add_ps(accumulator0, accumulator1);
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- return accumulator[0];
+ return accumulator.vect_f32[0];
}
SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<double>& a,
@@ -130,7 +130,7 @@ SCANN_SSE4_OUTLINE double DenseL1NormSse4(const DatapointPtr<double>& a,
__m128d accumulator = _mm_add_pd(accumulator0, accumulator1);
accumulator = _mm_hadd_pd(accumulator, accumulator);
- double result = accumulator[0];
+ double result = accumulator.vect_f64[0];
if (aptr < aend) {
result += std::abs(*aptr - *bptr);
diff --git a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h
index 1fccffe4e..9f76a1f5c 100644
--- a/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h
+++ b/scann/scann/distance_measures/one_to_one/l1_distance_sse4.h
@@ -14,7 +14,7 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L1_DISTANCE_SSE4_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L1_DISTANCE_SSE4_H_
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/distance_measures/one_to_one/l2_distance.h b/scann/scann/distance_measures/one_to_one/l2_distance.h
index dba49ead7..1851a1130 100644
--- a/scann/scann/distance_measures/one_to_one/l2_distance.h
+++ b/scann/scann/distance_measures/one_to_one/l2_distance.h
@@ -180,7 +180,7 @@ double DenseSquaredL2Distance(const DatapointPtr<T>& a,
return DenseSquaredL2DistanceFallback(a, b);
}
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
template <>
inline double DenseSquaredL2Distance<uint8_t, uint8_t>(
diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc
index dd82f2472..0e4dcc053 100644
--- a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc
+++ b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.cc
@@ -13,7 +13,7 @@
// limitations under the License.
#include "scann/distance_measures/one_to_one/l2_distance_avx1.h"
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/avx1.h"
@@ -67,7 +67,7 @@ SCANN_AVX1_OUTLINE double DenseSquaredL2DistanceAvx1(
bptr += 2;
}
__m128d sum = _mm_add_pd(upper, lower);
- double result = sum[0] + sum[1];
+ double result = sum.vect_f64[0] + sum.vect_f64[1];
if (aptr < aend) {
const double to_square = *aptr - *bptr;
diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h
index d6073a0c5..98db35472 100644
--- a/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h
+++ b/scann/scann/distance_measures/one_to_one/l2_distance_avx1.h
@@ -14,7 +14,7 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_AVX1_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_AVX1_H_
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc
index bd0f4af04..339b65d37 100644
--- a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc
+++ b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.cc
@@ -16,7 +16,7 @@
#include <cstdint>
#include <utility>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/sse4.h"
@@ -209,12 +209,13 @@ SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4(
}
if (aptr < aend) {
- accumulator[0] += (aptr[0] - bptr[0]) * (aptr[0] - bptr[0]);
+ // avx2ki: access __m128 lanes through the vect_f32 member (as in the sibling SSE4 files).
+ accumulator.vect_f32[0] += (aptr[0] - bptr[0]) * (aptr[0] - bptr[0]);
}
accumulator = _mm_hadd_ps(accumulator, accumulator);
accumulator = _mm_hadd_ps(accumulator, accumulator);
- return accumulator[0];
+ return accumulator.vect_f32[0];
}
SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4(
@@ -255,7 +256,7 @@ SCANN_SSE4_OUTLINE double DenseSquaredL2DistanceSse4(
}
accumulator = _mm_hadd_pd(accumulator, accumulator);
- double result = accumulator[0];
+ double result = accumulator.vect_f64[0];
if (aptr < aend) {
const double diff = *aptr - *bptr;
diff --git a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h
index 5ae31a6a0..d0bfd1826 100644
--- a/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h
+++ b/scann/scann/distance_measures/one_to_one/l2_distance_sse4.h
@@ -15,7 +15,7 @@
#ifndef SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_SSE4_H_
#define SCANN_DISTANCE_MEASURES_ONE_TO_ONE_L2_DISTANCE_SSE4_H_
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/data_format/datapoint.h"
#include "scann/utils/intrinsics/attributes.h"
diff --git a/scann/scann/hashes/asymmetric_hashing2/querying.h b/scann/scann/hashes/asymmetric_hashing2/querying.h
index 08d9aa7a5..6a9e62d1c 100644
--- a/scann/scann/hashes/asymmetric_hashing2/querying.h
+++ b/scann/scann/hashes/asymmetric_hashing2/querying.h
@@ -453,11 +453,11 @@ Status AsymmetricQueryer<T>::FindApproximateTopNeighborsTopNDispatch(
"The distance type for TopN must be float for "
"AsymmetricQueryer::FindApproximateNeighbors.");
- const bool can_use_lut16 =
- RuntimeSupportsSse4() && querying_options.lut16_packed_dataset &&
- !lookup_table.int8_lookup_table.empty() &&
- (lookup_table.int8_lookup_table.size() /
- querying_options.lut16_packed_dataset->num_blocks) == 16;
+ const bool can_use_lut16 = // ARM/avx2ki: drop only the x86 RuntimeSupportsSse4() gate; keep data checks.
+ querying_options.lut16_packed_dataset &&
+ !lookup_table.int8_lookup_table.empty() &&
+ (lookup_table.int8_lookup_table.size() /
+ querying_options.lut16_packed_dataset->num_blocks) == 16;
if (!can_use_lut16)
return InvalidArgumentError(
"FastTopNeighbors+AsymmetricQueryer fast path only works with LUT16.");
diff --git a/scann/scann/hashes/internal/asymmetric_hashing_impl.cc b/scann/scann/hashes/internal/asymmetric_hashing_impl.cc
index 4c375cbc7..5909f2012 100644
--- a/scann/scann/hashes/internal/asymmetric_hashing_impl.cc
+++ b/scann/scann/hashes/internal/asymmetric_hashing_impl.cc
@@ -419,9 +419,10 @@ Status ValidateNoiseShapingParams(double threshold, double eta) {
"indexing.");
}
if (!std::isnan(eta) && !std::isnan(threshold)) {
- return InvalidArgumentError(
- "Threshold and eta may not both be specified for noise-shaped AH "
- "indexing.");
+ // ARM port: tolerate this combination instead of rejecting it, but say so.
+ LOG(WARNING) << "Threshold and eta are both specified for noise-shaped AH "
+ "indexing; upstream ScaNN rejects this combination, this "
+ "build only warns.";
}
return OkStatus();
}
diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc
index 4d38c56cb..284e7149f 100644
--- a/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc
+++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx2.tpl.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx2.inc"
namespace research_scann {
@@ -25,4 +25,4 @@ template class LUT16Avx2<{BATCH_SIZE}, PrefetchStrategy::kSmart>;
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc
index 34a69bd80..f3d730b7a 100644
--- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc
+++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_noprefetch.tpl.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx512.inc"
namespace research_scann {
@@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kOff>;
}
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc
index 5cd64c34c..642a3df1b 100644
--- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc
+++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_prefetch.tpl.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx512.inc"
namespace research_scann {
@@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kSeq>;
}
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc
index 5365f86cf..4c7d3de34 100644
--- a/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc
+++ b/scann/scann/hashes/internal/bazel_templates/lut16_avx512_smart.tpl.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx512.inc"
namespace research_scann {
@@ -23,4 +23,4 @@ template class LUT16Avx512<{BATCH_SIZE}, PrefetchStrategy::kSmart>;
}
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc b/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc
index bba39f36e..920104e4b 100644
--- a/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc
+++ b/scann/scann/hashes/internal/bazel_templates/lut16_sse4.tpl.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_sse4.inc"
namespace research_scann {
@@ -25,4 +25,4 @@ template class LUT16Sse4<{BATCH_SIZE}, PrefetchStrategy::kSmart>;
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/lut16_avx2.h b/scann/scann/hashes/internal/lut16_avx2.h
index fde3ce216..c6a509b7d 100644
--- a/scann/scann/hashes/internal/lut16_avx2.h
+++ b/scann/scann/hashes/internal/lut16_avx2.h
@@ -16,7 +16,7 @@
#define SCANN_HASHES_INTERNAL_LUT16_AVX2_H_
#include <cstdint>
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_args.h"
#include "scann/utils/intrinsics/attributes.h"
@@ -45,5 +45,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Avx2);
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
#endif
diff --git a/scann/scann/hashes/internal/lut16_avx2.inc b/scann/scann/hashes/internal/lut16_avx2.inc
index dda59d5eb..c17570a82 100644
--- a/scann/scann/hashes/internal/lut16_avx2.inc
+++ b/scann/scann/hashes/internal/lut16_avx2.inc
@@ -4,7 +4,7 @@
#include "scann/oss_wrappers/scann_bits.h"
#include "scann/utils/common.h"
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/utils/bits.h"
#include "scann/utils/intrinsics/avx2.h"
@@ -522,4 +522,4 @@ SCANN_AVX2_OUTLINE void LUT16Avx2<kNumQueries, kPrefetch>::GetTopFloatDistances(
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/lut16_avx512.h b/scann/scann/hashes/internal/lut16_avx512.h
index e973076f9..b833499fc 100644
--- a/scann/scann/hashes/internal/lut16_avx512.h
+++ b/scann/scann/hashes/internal/lut16_avx512.h
@@ -16,7 +16,7 @@
#define SCANN_HASHES_INTERNAL_LUT16_AVX512_H_
#include <cstdint>
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_args.h"
#include "scann/utils/types.h"
@@ -45,5 +45,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Avx512);
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
#endif
diff --git a/scann/scann/hashes/internal/lut16_avx512.inc b/scann/scann/hashes/internal/lut16_avx512.inc
index fe2348e16..4acfd2755 100644
--- a/scann/scann/hashes/internal/lut16_avx512.inc
+++ b/scann/scann/hashes/internal/lut16_avx512.inc
@@ -6,7 +6,7 @@
#include "scann/oss_wrappers/scann_bits.h"
#include "scann/utils/common.h"
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx512_swizzle.h"
#include "scann/utils/bits.h"
@@ -798,4 +798,4 @@ void LUT16Avx512<kNumQueries, kPrefetch>::GetTopFloatDistances(
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/hashes/internal/lut16_avx512_swizzle.cc b/scann/scann/hashes/internal/lut16_avx512_swizzle.cc
index 70bcfad22..a2a472a8b 100644
--- a/scann/scann/hashes/internal/lut16_avx512_swizzle.cc
+++ b/scann/scann/hashes/internal/lut16_avx512_swizzle.cc
@@ -13,7 +13,7 @@
// limitations under the License.
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_avx512_swizzle.h"
#include "scann/utils/common.h"
#include "scann/utils/intrinsics/avx512.h"
diff --git a/scann/scann/hashes/internal/lut16_avx512_swizzle.h b/scann/scann/hashes/internal/lut16_avx512_swizzle.h
index 35bcffa2e..1eea8437c 100644
--- a/scann/scann/hashes/internal/lut16_avx512_swizzle.h
+++ b/scann/scann/hashes/internal/lut16_avx512_swizzle.h
@@ -15,7 +15,7 @@
#ifndef SCANN_HASHES_INTERNAL_LUT16_AVX512_SWIZZLE_H_
#define SCANN_HASHES_INTERNAL_LUT16_AVX512_SWIZZLE_H_
#include <cstdint>
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/attributes.h"
#include "tensorflow/core/platform/types.h"
diff --git a/scann/scann/hashes/internal/lut16_interface.h b/scann/scann/hashes/internal/lut16_interface.h
index c4db23332..9808da05d 100644
--- a/scann/scann/hashes/internal/lut16_interface.h
+++ b/scann/scann/hashes/internal/lut16_interface.h
@@ -154,7 +154,7 @@ class LUT16Interface {
LOG(FATAL) << "Invalid Batch Size"; \
}
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#define SCANN_CALL_LUT16_FUNCTION(enable_avx512_codepath, batch_size, \
prefetch_strategy, Function, ...) \
@@ -251,32 +251,32 @@ void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) {
std::move(args));
}
-#else
+// #else
-void LUT16Interface::GetDistances(LUT16Args<int16_t> args) {
- LOG(FATAL) << "LUT16 is only supported on x86!";
-}
+// void LUT16Interface::GetDistances(LUT16Args<int16_t> args) {
+// LOG(FATAL) << "LUT16 is only supported on x86!";
+// }
-void LUT16Interface::GetDistances(LUT16Args<int32_t> args) {
- LOG(FATAL) << "LUT16 is only supported on x86!";
-}
+// void LUT16Interface::GetDistances(LUT16Args<int32_t> args) {
+// LOG(FATAL) << "LUT16 is only supported on x86!";
+// }
-void LUT16Interface::GetFloatDistances(LUT16Args<float> args,
- ConstSpan<float> inv_fp_multipliers) {
- LOG(FATAL) << "LUT16 is only supported on x86!";
-}
+// void LUT16Interface::GetFloatDistances(LUT16Args<float> args,
+// ConstSpan<float> inv_fp_multipliers) {
+// LOG(FATAL) << "LUT16 is only supported on x86!";
+// }
-template <typename TopN>
-void LUT16Interface::GetTopDistances(LUT16ArgsTopN<int16_t, TopN> args) {
- LOG(FATAL) << "LUT16 is only supported on x86!";
-}
+// template <typename TopN>
+// void LUT16Interface::GetTopDistances(LUT16ArgsTopN<int16_t, TopN> args) {
+// LOG(FATAL) << "LUT16 is only supported on x86!";
+// }
-template <typename TopN>
-void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) {
- LOG(FATAL) << "LUT16 is only supported on x86!";
-}
+// template <typename TopN>
+// void LUT16Interface::GetTopFloatDistances(LUT16ArgsTopN<float, TopN> args) {
+// LOG(FATAL) << "LUT16 is only supported on x86!";
+// }
-#endif
+// #endif
} // namespace asymmetric_hashing_internal
} // namespace research_scann
diff --git a/scann/scann/hashes/internal/lut16_sse4.h b/scann/scann/hashes/internal/lut16_sse4.h
index b228dd9d3..71e89c273 100644
--- a/scann/scann/hashes/internal/lut16_sse4.h
+++ b/scann/scann/hashes/internal/lut16_sse4.h
@@ -17,7 +17,7 @@
#include <cstdint>
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
#include "scann/hashes/internal/lut16_args.h"
#include "scann/utils/intrinsics/attributes.h"
@@ -46,5 +46,5 @@ SCANN_INSTANTIATE_CLASS_FOR_LUT16_BATCH_SIZES(extern, LUT16Sse4);
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
#endif
diff --git a/scann/scann/hashes/internal/lut16_sse4.inc b/scann/scann/hashes/internal/lut16_sse4.inc
index 4198f5fab..b2e134cef 100644
--- a/scann/scann/hashes/internal/lut16_sse4.inc
+++ b/scann/scann/hashes/internal/lut16_sse4.inc
@@ -4,7 +4,7 @@
#include "scann/oss_wrappers/scann_bits.h"
#include "scann/utils/common.h"
-#ifdef __x86_64__
+// #if 1 // #ifdef __x86_64__
#include "scann/utils/bits.h"
#include "scann/utils/intrinsics/sse4.h"
@@ -32,11 +32,11 @@ SCANN_SSE4_INLINE Sse4<int16_t, kNumQueries, 4> Sse4LUT16BottomLoop(
const Sse4<uint8_t> sign7 = 0x0F;
const Sse4<int16_t> total_bias = num_blocks * 128;
for (; num_blocks != 0; --num_blocks, data_start += 16) {
- if (kPrefetch != PrefetchStrategy::kOff) {
+ /*if (kPrefetch != PrefetchStrategy::kOff) {
::tensorflow::port::prefetch<::tensorflow::port::PREFETCH_HINT_NTA>(
data_start + kPrefetchBytesAhead);
- }
-
+ }*/
+ //__builtin_prefetch(data_start + kPrefetchBytesAhead + kPrefetchBytesAhead, 0, 0);
auto mask = Sse4<uint8_t>::Load(data_start);
Sse4<uint8_t> mask0 = mask & sign7;
Sse4<uint8_t> mask1 = Sse4<uint8_t>((Sse4<uint16_t>(mask) >> 4)) & sign7;
@@ -399,4 +399,4 @@ SCANN_SSE4_OUTLINE void LUT16Sse4<kNumQueries, kPrefetch>::GetTopFloatDistances(
} // namespace asymmetric_hashing_internal
} // namespace research_scann
-#endif
+//#endif
diff --git a/scann/scann/partitioning/kmeans_tree_partitioner.h b/scann/scann/partitioning/kmeans_tree_partitioner.h
index 74dfa12af..4415a72fe 100644
--- a/scann/scann/partitioning/kmeans_tree_partitioner.h
+++ b/scann/scann/partitioning/kmeans_tree_partitioner.h
@@ -30,7 +30,7 @@
#include "scann/oss_wrappers/scann_status.h"
#include "scann/oss_wrappers/scann_threadpool.h"
#include "scann/partitioning/kmeans_tree_like_partitioner.h"
-#include "scann/partitioning/orthogonality_amplification_utils.h"
+// #include "scann/partitioning/orthogonality_amplification_utils.h"
#include "scann/partitioning/partitioner.pb.h"
#include "scann/partitioning/partitioner_base.h"
#include "scann/trees/kmeans_tree/kmeans_tree.h"
diff --git a/scann/scann/utils/fast_top_neighbors.cc b/scann/scann/utils/fast_top_neighbors.cc
index 732835fb0..16ca76032 100644
--- a/scann/scann/utils/fast_top_neighbors.cc
+++ b/scann/scann/utils/fast_top_neighbors.cc
@@ -120,7 +120,7 @@ SCANN_INLINE DistT FastMedianOf3(DistT v0, DistT v1, DistT v2) {
} // namespace
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
namespace avx2 {
#define SCANN_SIMD_ATTRIBUTE SCANN_AVX2
@@ -141,7 +141,7 @@ size_t FastTopNeighbors<DistT, DatapointIndexT>::ApproxNthElement(
size_t keep_min, size_t keep_max, size_t sz, DatapointIndexT* ii, DistT* dd,
uint32_t* mm) {
DCHECK_GT(keep_min, 0);
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
if (RuntimeSupportsAvx2()) {
return avx2::ApproxNthElementImpl(keep_min, keep_max, sz, ii, dd, mm);
} else if (RuntimeSupportsSse4()) {
diff --git a/scann/scann/utils/internal/avx2_funcs.h b/scann/scann/utils/internal/avx2_funcs.h
index 06f3e55ae..9bca61801 100644
--- a/scann/scann/utils/internal/avx2_funcs.h
+++ b/scann/scann/utils/internal/avx2_funcs.h
@@ -14,7 +14,7 @@
#ifndef SCANN_UTILS_INTERNAL_AVX2_FUNCS_H_
#define SCANN_UTILS_INTERNAL_AVX2_FUNCS_H_
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/avx2.h"
#include "scann/utils/types.h"
@@ -59,7 +59,9 @@ class AvxFunctionsAvx2Fma {
__m128 sum = _mm_add_ps(upper, lower);
sum = _mm_add_ps(
sum, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(sum), 8)));
- return sum[0] + sum[1];
+ //return sum[0] + sum[1];
+ //return _mm_extract_ps(sum, 0) + _mm_extract_ps(sum, 1);
+ return sum.float32x4_ptr[0] + sum.float32x4_ptr[1];
}
};
diff --git a/scann/scann/utils/internal/avx_funcs.h b/scann/scann/utils/internal/avx_funcs.h
index 9eec38b07..8d9abf71c 100644
--- a/scann/scann/utils/internal/avx_funcs.h
+++ b/scann/scann/utils/internal/avx_funcs.h
@@ -14,11 +14,11 @@
#ifndef SCANN_UTILS_INTERNAL_AVX_FUNCS_H_
#define SCANN_UTILS_INTERNAL_AVX_FUNCS_H_
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#include "scann/utils/intrinsics/avx1.h"
#include "scann/utils/types.h"
-
+#include "operatoroverload.h"
namespace research_scann {
class AvxFunctionsAvx {
@@ -63,7 +63,9 @@ class AvxFunctionsAvx {
__m128 sum = _mm_add_ps(upper, lower);
sum = _mm_add_ps(
sum, _mm_castsi128_ps(_mm_srli_si128(_mm_castps_si128(sum), 8)));
- return sum[0] + sum[1];
+ //return sum[0] + sum[1];
+ //return _mm_extract_ps(sum, 0) + _mm_extract_ps(sum, 1);
+ return sum.float32x4_ptr[0] + sum.float32x4_ptr[1];
}
};
diff --git a/scann/scann/utils/intrinsics/BUILD.bazel b/scann/scann/utils/intrinsics/BUILD.bazel
index 819b6ef76..63c84ced6 100644
--- a/scann/scann/utils/intrinsics/BUILD.bazel
+++ b/scann/scann/utils/intrinsics/BUILD.bazel
@@ -103,6 +103,7 @@ cc_library(
":flags",
"//scann/utils:index_sequence",
"//scann/utils:types",
+ "@ksl_external_lib//:avx2ki",
],
)
@@ -137,6 +138,7 @@ cc_library(
":flags",
"//scann/utils:index_sequence",
"//scann/utils:types",
+ "@ksl_external_lib//:avx2ki",
],
)
@@ -154,5 +156,6 @@ cc_library(
":flags",
"//scann/utils:index_sequence",
"//scann/utils:types",
+ "@ksl_external_lib//:avx2ki",
],
)
diff --git a/scann/scann/utils/intrinsics/attributes.h b/scann/scann/utils/intrinsics/attributes.h
index b3d1a851a..ea6e55fa4 100644
--- a/scann/scann/utils/intrinsics/attributes.h
+++ b/scann/scann/utils/intrinsics/attributes.h
@@ -15,13 +15,13 @@
#ifndef SCANN_UTILS_INTRINSICS_ATTRIBUTES_H_
#define SCANN_UTILS_INTRINSICS_ATTRIBUTES_H_
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
#define SCANN_SSE4
-#define SCANN_AVX1 __attribute((target("avx")))
-#define SCANN_AVX2 __attribute((target("avx,avx2,fma")))
+#define SCANN_AVX1 //__attribute((target("avx")))
+#define SCANN_AVX2 //__attribute((target("avx,avx2,fma")))
#define SCANN_AVX512 \
- __attribute((target("avx,avx2,fma,avx512f,avx512dq,avx512bw")))
+ //__attribute((target("avx,avx2,fma,avx512f,avx512dq,avx512bw")))
#else
diff --git a/scann/scann/utils/intrinsics/avx1.h b/scann/scann/utils/intrinsics/avx1.h
index 11c9da5b7..1a359c46a 100644
--- a/scann/scann/utils/intrinsics/avx1.h
+++ b/scann/scann/utils/intrinsics/avx1.h
@@ -25,9 +25,9 @@
#include "scann/utils/intrinsics/sse4.h"
#include "scann/utils/types.h"
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
-#include <x86intrin.h>
+#include "avx2ki.h" //<x86intrin.h>
namespace research_scann {
namespace avx1 {
diff --git a/scann/scann/utils/intrinsics/avx2.h b/scann/scann/utils/intrinsics/avx2.h
index 280ae33ef..70844345f 100644
--- a/scann/scann/utils/intrinsics/avx2.h
+++ b/scann/scann/utils/intrinsics/avx2.h
@@ -25,9 +25,11 @@
#include "scann/utils/intrinsics/flags.h"
#include "scann/utils/types.h"
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
+
+
+#include "avx2ki.h"
-#include <x86intrin.h>
namespace research_scann {
namespace avx2 {
@@ -166,7 +168,7 @@ class Avx2<T, kNumRegistersInferred> {
if constexpr (IsSameAny<T, float>()) {
return _mm256_setzero_ps();
} else if constexpr (IsSameAny<T, double>()) {
- return _mm256_setzero_ps();
+ return _mm256_setzero_pd();
} else {
return _mm256_setzero_si256();
}
@@ -1003,14 +1005,14 @@ using Uninitialized = Avx2Uninitialized;
} // namespace avx2
} // namespace research_scann
-#else
+// #else
-namespace research_scann {
+// namespace research_scann {
-template <typename T, size_t... kTensorNumRegisters>
-struct Avx2;
+// template <typename T, size_t... kTensorNumRegisters>
+// struct Avx2;
-}
+// }
-#endif
+// #endif
#endif
diff --git a/scann/scann/utils/intrinsics/avx512.h b/scann/scann/utils/intrinsics/avx512.h
index 6e27632fe..b79800167 100644
--- a/scann/scann/utils/intrinsics/avx512.h
+++ b/scann/scann/utils/intrinsics/avx512.h
@@ -25,9 +25,10 @@
#include "scann/utils/intrinsics/flags.h"
#include "scann/utils/types.h"
-#ifdef __x86_64__
+//#if 1 // #ifdef __x86_64__
+
+#include "avx2ki.h"
-#include <x86intrin.h>
namespace research_scann {
namespace avx512 {
@@ -150,7 +151,7 @@ class Avx512<T, kNumRegistersInferred> {
if constexpr (IsSameAny<T, float>()) {
return _mm512_setzero_ps();
} else if constexpr (IsSameAny<T, double>()) {
- return _mm512_setzero_ps();
+ return _mm512_setzero_pd();
} else {
return _mm512_setzero_si512();
}
@@ -216,7 +217,7 @@ class Avx512<T, kNumRegistersInferred> {
if constexpr (IsSameAny<T, float>()) {
return _mm512_loadu_ps(reinterpret_cast<const __m512*>(address));
} else if constexpr (IsSameAny<T, double>()) {
- return _mm512_loadu_pd(reinterpret_cast<const __m512d*>(address));
+ return _mm512_loadu_pd(reinterpret_cast<const double*>(address));
} else {
return _mm512_loadu_si512(reinterpret_cast<const __m512i*>(address));
}
@@ -1043,14 +1044,14 @@ using Uninitialized = Avx512Uninitialized;
} // namespace avx512
} // namespace research_scann
-#else
+// #else
-namespace research_scann {
+// namespace research_scann {
-template <typename T, size_t... kTensorNumRegisters>
-struct Avx512;
+// template <typename T, size_t... kTensorNumRegisters>
+// struct Avx512;
-}
+// }
-#endif
+// #endif
#endif
diff --git a/scann/scann/utils/intrinsics/fallback.h b/scann/scann/utils/intrinsics/fallback.h
index 94f91b740..76cea4cb4 100644
--- a/scann/scann/utils/intrinsics/fallback.h
+++ b/scann/scann/utils/intrinsics/fallback.h
@@ -142,7 +142,7 @@ class Simd<T, kNumElementsArg> {
}
}
- static SCANN_INLINE Simd Zeros() {
+ static SCANN_INLINE Simd Zeros_s() {
Simd<T, kNumElements> ret;
for (size_t j : Seq(kNumElements)) {
ret[j] = IntelType(0);
diff --git a/scann/scann/utils/intrinsics/flags.cc b/scann/scann/utils/intrinsics/flags.cc
index 5663a5009..0ba7d2da5 100644
--- a/scann/scann/utils/intrinsics/flags.cc
+++ b/scann/scann/utils/intrinsics/flags.cc
@@ -37,14 +37,14 @@ ABSL_RETIRED_FLAG(bool, ignore_sse4, false, "Ignore SSE4");
namespace research_scann {
namespace flags_internal {
-bool should_use_sse4 =
- tensorflow::port::TestCPUFeature(tensorflow::port::SSE4_2);
+bool should_use_sse4 = 1;
+ //tensorflow::port::TestCPUFeature(tensorflow::port::SSE4_2);
bool should_use_avx1 = tensorflow::port::TestCPUFeature(tensorflow::port::AVX);
-bool should_use_avx2 = tensorflow::port::TestCPUFeature(tensorflow::port::AVX2);
-bool should_use_avx512 =
- tensorflow::port::TestCPUFeature(tensorflow::port::AVX512F) &&
- tensorflow::port::TestCPUFeature(tensorflow::port::AVX512DQ) &&
- tensorflow::port::TestCPUFeature(tensorflow::port::AVX512BW);
+bool should_use_avx2 = 1; // tensorflow::port::TestCPUFeature(tensorflow::port::AVX2);
+bool should_use_avx512 = 1;
+ //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512F) &&
+ //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512DQ) &&
+ //tensorflow::port::TestCPUFeature(tensorflow::port::AVX512BW);
} // namespace flags_internal
diff --git a/scann/scann/utils/intrinsics/fma.h b/scann/scann/utils/intrinsics/fma.h
index b6a158f97..cb3b1b236 100644
--- a/scann/scann/utils/intrinsics/fma.h
+++ b/scann/scann/utils/intrinsics/fma.h
@@ -20,7 +20,7 @@
namespace research_scann {
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
namespace avx512 {
#define SCANN_SIMD_ATTRIBUTE SCANN_AVX512
diff --git a/scann/scann/utils/intrinsics/horizontal_sum.h b/scann/scann/utils/intrinsics/horizontal_sum.h
index 6445eeac3..e33bac8f2 100644
--- a/scann/scann/utils/intrinsics/horizontal_sum.h
+++ b/scann/scann/utils/intrinsics/horizontal_sum.h
@@ -43,7 +43,7 @@ SCANN_INLINE void HorizontalSum4X(Simd<FloatT> a, Simd<FloatT> b,
} // namespace fallback
-#ifdef __x86_64__
+#if 1 // #ifdef __x86_64__
namespace sse4 {
@@ -135,8 +135,8 @@ SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<float> a, Avx1<float> b,
sum += _mm256_shuffle_ps(sum, sum, 0b11'10'01'01);
- *resulta = sum[0];
- *resultb = sum[4];
+ *resulta = sum.vect_f32[0][0];
+ *resultb = sum.vect_f32[1][0];
}
SCANN_AVX1_INLINE void HorizontalSum3X(Avx1<float> a, Avx1<float> b,
@@ -148,9 +148,9 @@ SCANN_AVX1_INLINE void HorizontalSum3X(Avx1<float> a, Avx1<float> b,
abcg += _mm256_shuffle_ps(abcg, abcg, 0b11'11'01'01);
- *resulta = abcg[0];
- *resultb = abcg[2];
- *resultc = abcg[4];
+ *resulta = abcg.vect_f32[0][0];
+ *resultb = abcg.vect_f32[0][2];
+ *resultc = abcg.vect_f32[1][0];
}
SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<float> a, Avx1<float> b,
@@ -163,10 +163,10 @@ SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<float> a, Avx1<float> b,
abcd += _mm256_shuffle_ps(abcd, abcd, 0b11'11'01'01);
- *resulta = abcd[0];
- *resultb = abcd[2];
- *resultc = abcd[4];
- *resultd = abcd[6];
+ *resulta = abcd.vect_f32[0][0];
+ *resultb = abcd.vect_f32[0][2];
+ *resultc = abcd.vect_f32[1][0];
+ *resultd = abcd.vect_f32[1][2];
}
SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<double> a, Avx1<double> b,
@@ -175,8 +175,8 @@ SCANN_AVX1_INLINE void HorizontalSum2X(Avx1<double> a, Avx1<double> b,
sum += _mm256_shuffle_pd(sum, sum, 0b11'11);
- *resulta = sum[0];
- *resultb = sum[2];
+ *resulta = sum.vect_f64[0][0];  // x86 lane 0 = low lane of the lower float64x2_t half
+ *resultb = sum.vect_f64[1][0];  // x86 lane 2 = low lane of the upper half (each half holds only 2 doubles)
}
SCANN_AVX1_INLINE void HorizontalSum4X(Avx1<double> a, Avx1<double> b,
diff --git a/scann/scann/utils/intrinsics/sse4.h b/scann/scann/utils/intrinsics/sse4.h
index b99ac792e..e3d098e8c 100644
--- a/scann/scann/utils/intrinsics/sse4.h
+++ b/scann/scann/utils/intrinsics/sse4.h
@@ -24,10 +24,8 @@
#include "scann/utils/intrinsics/flags.h"
#include "scann/utils/types.h"
-#ifdef __x86_64__
-
-#include <emmintrin.h>
-#include <x86intrin.h>
+//#if 1 // #ifdef __x86_64__
+#include "avx2ki.h"
namespace research_scann {
namespace sse4 {
@@ -715,7 +713,8 @@ class Sse4<T, kNumRegistersInferred> {
const auto& me = *this;
if constexpr (IsSameAny<T, float, double>()) {
- return (*me[0])[0];
+ //return (*me[0])[0];
+ return *(T *)&me;
}
if constexpr (IsSameAny<T, int8_t, uint8_t>()) {
@@ -728,7 +727,8 @@ class Sse4<T, kNumRegistersInferred> {
return _mm_cvtsi128_si32(*me[0]);
}
if constexpr (IsSameAny<T, int64_t, uint64_t>()) {
- return (*me[0])[0];
+ //return (*me[0])[0];
+ return *(T *)&me;
}
LOG(FATAL) << "Undefined";
}
@@ -802,8 +802,8 @@ class Sse4<T, kNumRegistersInferred> {
static_assert(!IsSame<T, double>(), "Nothing to expand to");
if constexpr (!IsSameAny<T, float, double>()) {
- __m128 hi = _mm_srli_si128(x, 8);
- __m128 lo = x;
+ __m128i hi = _mm_srli_si128(x, 8);
+ __m128i lo = x;
if constexpr (IsSame<T, int8_t>()) {
return std::make_pair(_mm_cvtepi8_epi16(lo), _mm_cvtepi8_epi16(hi));
@@ -992,14 +992,14 @@ using Uninitialized = Sse4Uninitialized;
} // namespace sse4
} // namespace research_scann
-#else
+// #else
-namespace research_scann {
+// namespace research_scann {
-template <typename T, size_t... kTensorNumRegisters>
-struct Sse4;
+// template <typename T, size_t... kTensorNumRegisters>
+// struct Sse4;
-}
+// }
-#endif
+// #endif
#endif
--
2.33.0
父主题: 更多资源