/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.FloatToFloatFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/**
 * {@link FlatVectorsScorer} for scalar-quantized vectors that computes similarities directly on the
 * {@link MemorySegment}s exposed by a {@link MemorySegmentAccessInput}.
 *
 * <p>The memory-segment path is used for scorer suppliers and for {@code float[]} query targets,
 * and only when the vector values are {@link QuantizedByteVectorValues} whose slice is a
 * {@link MemorySegmentAccessInput}. Every other case, including all {@code byte[]} targets, is
 * delegated to a {@link Lucene99ScalarQuantizedVectorScorer} wrapping
 * {@link DefaultFlatVectorScorer#INSTANCE}.
 */
class Lucene99MemorySegmentScalarQuantizedVectorScorer
implements FlatVectorsScorer {
    /** Shared singleton instance. */
    static final Lucene99MemorySegmentScalarQuantizedVectorScorer INSTANCE = new Lucene99MemorySegmentScalarQuantizedVectorScorer();

    /** Fallback scorer used whenever the memory-segment path does not apply. */
    private static final FlatVectorsScorer DELEGATE = new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);

    /** Not instantiable outside this class; use {@link #INSTANCE}. */
    private Lucene99MemorySegmentScalarQuantizedVectorScorer() {
    }

    /**
     * Returns a supplier of updateable scorers that compare stored vectors against each other.
     *
     * @param similarityFunction the similarity to compute
     * @param vectorValues the vectors to score; the memory-segment path is used only for
     *     {@link QuantizedByteVectorValues} backed by a {@link MemorySegmentAccessInput}
     * @return a memory-segment based supplier, or the delegate's supplier otherwise
     * @throws IOException if reading the vector slice fails
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues quantized
                && quantized.getSlice() instanceof MemorySegmentAccessInput input) {
            return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input);
        }
        return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
    }

    /**
     * Returns a scorer that compares stored vectors against a {@code float[]} query. The query is
     * quantized once, up front.
     *
     * @param similarityFunction the similarity to compute
     * @param vectorValues the vectors to score
     * @param target the raw (unquantized) query vector
     * @return a memory-segment based scorer, or the delegate's scorer otherwise
     * @throws IOException if reading the vector slice fails
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues quantized
                && quantized.getSlice() instanceof MemorySegmentAccessInput input) {
            return new RandomVectorScorerImpl(similarityFunction, quantized, input, target);
        }
        return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /** {@code byte[]} query targets are always handled by the delegate. */
    @Override
    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException {
        return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    @Override
    public String toString() {
        return "Lucene99MemorySegmentScalarQuantizedVectorScorer()";
    }

    /** Supplier that creates a fresh {@link UpdateableRandomVectorScorerImpl} per call. */
    private record RandomVectorScorerSupplierImpl(VectorSimilarityFunction similarityFunction, QuantizedByteVectorValues values, MemorySegmentAccessInput input) implements RandomVectorScorerSupplier
    {
        @Override
        public UpdateableRandomVectorScorer scorer() {
            return new UpdateableRandomVectorScorerImpl(this.similarityFunction, this.values, this.input);
        }

        @Override
        public RandomVectorScorerSupplier copy() {
            return new RandomVectorScorerSupplierImpl(this.similarityFunction, this.values, this.input);
        }
    }

    /**
     * Scorer for a fixed {@code float[]} query. The query is quantized into {@link #targetBytes} at
     * construction. Because the query lives on the heap, the kernels use the "single packed"
     * variants for compressed int4 data: only the stored document side is packed.
     */
    private static class RandomVectorScorerImpl
    extends RandomVectorScorerBase {
        /** Quantized query vector (one byte per dimension). */
        private final byte[] targetBytes;
        /** Score offset produced when quantizing the query. */
        private final float queryOffset;

        RandomVectorScorerImpl(VectorSimilarityFunction similarityFunction, QuantizedByteVectorValues values, MemorySegmentAccessInput input, float[] target) {
            super(similarityFunction, values, input);
            this.targetBytes = new byte[target.length];
            this.queryOffset = ScalarQuantizedVectorScorer.quantizeQuery(target, this.targetBytes, similarityFunction, this.getQuantizer());
        }

        @Override
        public float score(int node) throws IOException {
            return this.scoreBody(node, this.queryOffset);
        }

        @Override
        int euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.uint8SquareDistance(this.targetBytes, doc);
        }

        @Override
        int int4Euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4SquareDistance(this.targetBytes, doc);
        }

        @Override
        int compressedInt4Euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4SquareDistanceSinglePacked(this.targetBytes, doc);
        }

        @Override
        int dotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.uint8DotProduct(this.targetBytes, doc);
        }

        @Override
        int int4DotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4DotProduct(this.targetBytes, doc);
        }

        @Override
        int compressedInt4DotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4DotProductSinglePacked(this.targetBytes, doc);
        }
    }

    /**
     * Scorer whose query is one of the stored vectors, chosen via {@link #setScoringOrdinal}. Query
     * and document share the stored layout, so compressed int4 data uses the "both packed" kernels.
     */
    private static class UpdateableRandomVectorScorerImpl
    extends RandomVectorScorerBase
    implements UpdateableRandomVectorScorer {
        /** Current query vector; set by {@link #setScoringOrdinal}. */
        private MemorySegment query;
        /** Stored score offset of the current query vector. */
        private float queryOffset;
        /** Private copy of the query when it could not be referenced in place; lazily allocated. */
        private byte[] queryScratch;

        UpdateableRandomVectorScorerImpl(VectorSimilarityFunction similarityFunction, QuantizedByteVectorValues values, MemorySegmentAccessInput input) {
            super(similarityFunction, values, input);
        }

        /**
         * Makes the stored vector at {@code ord} the query for subsequent {@link #score} calls.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        @Override
        public void setScoringOrdinal(int ord) throws IOException {
            Node node = this.getNode(ord);
            MemorySegment vector = node.vector();
            if (!vector.isNative()) {
                // getNode() falls back to a view over the base class's shared scratch array, which
                // the next getNode() (inside score) overwrites. Copy it so the query stays intact.
                int length = (int) vector.byteSize();
                if (this.queryScratch == null) {
                    this.queryScratch = new byte[length];
                }
                MemorySegment.copy(vector, ValueLayout.JAVA_BYTE, 0L, this.queryScratch, 0, length);
                vector = MemorySegment.ofArray(this.queryScratch);
            }
            this.query = vector;
            this.queryOffset = node.offset();
        }

        @Override
        public float score(int node) throws IOException {
            return this.scoreBody(node, this.queryOffset);
        }

        @Override
        int euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.uint8SquareDistance(this.query, doc);
        }

        @Override
        int int4Euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4SquareDistance(this.query, doc);
        }

        @Override
        int compressedInt4Euclidean(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4SquareDistanceBothPacked(this.query, doc);
        }

        @Override
        int dotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.uint8DotProduct(this.query, doc);
        }

        @Override
        int int4DotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4DotProduct(this.query, doc);
        }

        @Override
        int compressedInt4DotProduct(MemorySegment doc) {
            return PanamaVectorUtilSupport.int4DotProductBothPacked(this.query, doc);
        }
    }

    /**
     * Shared machinery for both scorers. It locates the stored node for an ordinal, selects the
     * integer similarity kernel, and maps the raw integer result to the final score.
     *
     * <p>Each node occupies {@code vectorByteSize + 4} bytes: the quantized vector followed by a
     * little-endian float score offset.
     */
    private static abstract class RandomVectorScorerBase
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final ScalarQuantizer quantizer;
        private final float constMultiplier;
        private final MemorySegmentAccessInput input;
        private final int vectorByteSize;
        private final int nodeSize;
        private final Scorer scorer;
        private final FloatToFloatFunction scaler;
        /** Reusable buffer for nodes that {@code segmentSliceOrNull} cannot return; lazily allocated. */
        private byte[] scratch;
        private static final ValueLayout.OfInt INT_UNALIGNED_LE = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);

        /**
         * @throws IllegalArgumentException if {@code input} is shorter than
         *     {@code nodeSize * maxOrd()} bytes
         */
        RandomVectorScorerBase(VectorSimilarityFunction similarityFunction, QuantizedByteVectorValues values, MemorySegmentAccessInput input) {
            super(values);
            this.quantizer = values.getScalarQuantizer();
            this.constMultiplier = this.quantizer.getConstantMultiplier();
            this.input = input;
            this.vectorByteSize = values.getVectorByteLength();
            this.nodeSize = this.vectorByteSize + Integer.BYTES;
            // int4 kernels are used for quantizers of at most 4 bits. The "compressed" variants are
            // used when the stored byte length differs from the dimension (presumably two values
            // packed per byte -- confirm against the writer).
            this.scorer = switch (similarityFunction) {
                case EUCLIDEAN -> {
                    if (this.quantizer.getBits() <= 4) {
                        if (this.vectorByteSize != values.dimension()) {
                            yield this::compressedInt4Euclidean;
                        }
                        yield this::int4Euclidean;
                    }
                    yield this::euclidean;
                }
                case DOT_PRODUCT, COSINE, MAXIMUM_INNER_PRODUCT -> {
                    if (this.quantizer.getBits() <= 4) {
                        if (this.vectorByteSize != values.dimension()) {
                            yield this::compressedInt4DotProduct;
                        }
                        yield this::int4DotProduct;
                    }
                    yield this::dotProduct;
                }
            };
            this.scaler = switch (similarityFunction) {
                case EUCLIDEAN -> VectorUtil::normalizeDistanceToUnitInterval;
                case DOT_PRODUCT, COSINE -> VectorUtil::normalizeToUnitInterval;
                case MAXIMUM_INNER_PRODUCT -> VectorUtil::scaleMaxInnerProductScore;
            };
            this.checkInvariants();
        }

        /** Verifies that the input is large enough to hold {@code maxOrd()} nodes. */
        final void checkInvariants() {
            if (this.input.length() < (long) this.nodeSize * (long) this.maxOrd()) {
                throw new IllegalArgumentException("input length is less than expected vector data");
            }
        }

        /** @throws IllegalArgumentException if {@code ord} is outside {@code [0, maxOrd())} */
        final void checkOrdinal(int ord) {
            if (ord < 0 || ord >= this.maxOrd()) {
                throw new IllegalArgumentException("illegal ordinal: " + ord);
            }
        }

        ScalarQuantizer getQuantizer() {
            return this.quantizer;
        }

        /**
         * Reads the node stored at {@code ord}.
         *
         * <p>The returned vector is a view over the input when possible. Otherwise it is a view over
         * this scorer's reusable scratch array, which the next call overwrites.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         * @throws IOException if reading from the input fails
         */
        Node getNode(int ord) throws IOException {
            this.checkOrdinal(ord);
            long byteOffset = (long) ord * (long) this.nodeSize;
            MemorySegment node = this.input.segmentSliceOrNull(byteOffset, this.nodeSize);
            if (node == null) {
                // No direct slice is available for this node, so copy it into the scratch buffer.
                if (this.scratch == null) {
                    this.scratch = new byte[this.nodeSize];
                }
                this.input.readBytes(byteOffset, this.scratch, 0, this.nodeSize);
                node = MemorySegment.ofArray(this.scratch);
            }
            float offset = Float.intBitsToFloat(node.get(INT_UNALIGNED_LE, (long) this.vectorByteSize));
            // asSlice, not reinterpret: reinterpret is restricted and throws
            // UnsupportedOperationException for heap segments such as the scratch-backed one above.
            return new Node(node.asSlice(0L, this.vectorByteSize), offset);
        }

        /**
         * Computes the final score of stored vector {@code ord} against the current query.
         *
         * @param ord ordinal of the document vector
         * @param queryOffset score offset associated with the query
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        float scoreBody(int ord, float queryOffset) throws IOException {
            Node node = this.getNode(ord);
            float raw = (float) this.scorer.score(node.vector()) * this.constMultiplier;
            return this.scaler.apply(raw + node.offset() + queryOffset);
        }

        abstract int euclidean(MemorySegment var1);

        abstract int int4Euclidean(MemorySegment var1);

        abstract int compressedInt4Euclidean(MemorySegment var1);

        abstract int dotProduct(MemorySegment var1);

        abstract int int4DotProduct(MemorySegment var1);

        abstract int compressedInt4DotProduct(MemorySegment var1);

        /** Integer similarity kernel applied to a stored document vector. */
        @FunctionalInterface
        private static interface Scorer {
            public int score(MemorySegment var1) throws IOException;
        }

        /** A stored vector together with its stored score offset. */
        record Node(MemorySegment vector, float offset) {
        }
    }
}

