/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.measure.topn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.kylin.common.util.ByteArray;
import org.apache.kylin.common.util.Dictionary;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.dimension.DateDimEnc;
import org.apache.kylin.dimension.DictionaryDimEnc;
import org.apache.kylin.dimension.DimensionEncoding;
import org.apache.kylin.dimension.DimensionEncodingFactory;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.measure.MeasureAggregator;
import org.apache.kylin.measure.MeasureIngester;
import org.apache.kylin.measure.MeasureType;
import org.apache.kylin.measure.MeasureTypeFactory;
import org.apache.kylin.measure.topn.Counter;
import org.apache.kylin.measure.topn.TopNAggregator;
import org.apache.kylin.measure.topn.TopNCounter;
import org.apache.kylin.measure.topn.TopNCounterSerializer;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.metadata.datatype.DataTypeSerializer;
import org.apache.kylin.metadata.model.FunctionDesc;
import org.apache.kylin.metadata.model.MeasureDesc;
import org.apache.kylin.metadata.model.ParameterDesc;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.metadata.realization.CapabilityResult;
import org.apache.kylin.metadata.realization.SQLDigest;
import org.apache.kylin.metadata.tuple.Tuple;
import org.apache.kylin.metadata.tuple.TupleInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Measure type backing the {@code TOP_N} aggregate function.
 *
 * <p>Each measure value is a {@link TopNCounter} keyed by the concatenated encodings of the
 * function's literal columns. The counts come from the numeric column when the first parameter is
 * a column (answered as {@code SUM(col)}). Otherwise they come from a constant (answered as
 * {@code COUNT(1)}); see {@link #getTopnInternalMeasure(FunctionDesc)}.
 */
public class TopNMeasureType extends MeasureType<TopNCounter<ByteArray>> {
    private static final Logger logger = LoggerFactory.getLogger(TopNMeasureType.class);
    public static final String FUNC_TOP_N = "TOP_N";
    public static final String DATATYPE_TOPN = "topn";
    public static final String CONFIG_ENCODING_PREFIX = "topn.encoding.";
    public static final String CONFIG_ENCODING_VERSION_PREFIX = "topn.encoding_version.";

    /** Encoding name that, like an absent/empty encoding, selects dictionary encoding. */
    private static final String DICT_ENCODING = "dict";
    /** Key bytes counted per dictionary-encoded column by {@link #fixMeasureReturnType}. */
    private static final int DICT_ENCODING_LENGTH = 4;
    /** Encoding version used when none is configured. */
    private static final int DEFAULT_ENCODING_VERSION = 1;
    /** Inclusive bounds accepted for the precision of the topn return type. */
    private static final int MIN_PRECISION = 1;
    private static final int MAX_PRECISION = 10000;

    private final DataType dataType;

    /**
     * @param funcName function name passed by {@link Factory}; not used by this type
     * @param dataType return data type; its precision sizes the counters built by the ingester
     */
    public TopNMeasureType(String funcName, DataType dataType) {
        this.dataType = dataType;
    }

    @Override
    public void validate(FunctionDesc functionDesc) throws IllegalArgumentException {
        validate(functionDesc.getExpression(), functionDesc.getReturnDataType());
    }

    /**
     * Checks that the function name is exactly {@code TOP_N}, the return type is {@code topn},
     * and its precision lies within [{@value #MIN_PRECISION}, {@value #MAX_PRECISION}].
     *
     * @throws IllegalArgumentException if any check fails
     */
    private void validate(String funcName, DataType dataType) {
        if (!FUNC_TOP_N.equals(funcName)) {
            throw new IllegalArgumentException("Expected function " + FUNC_TOP_N + " but got: " + funcName);
        }
        if (!DATATYPE_TOPN.equals(dataType.getName())) {
            throw new IllegalArgumentException(
                    "Expected return type " + DATATYPE_TOPN + " but got: " + dataType.getName());
        }
        int precision = dataType.getPrecision();
        if (precision < MIN_PRECISION || precision > MAX_PRECISION) {
            throw new IllegalArgumentException("TopN precision must be between " + MIN_PRECISION + " and "
                    + MAX_PRECISION + " but got: " + precision);
        }
    }

    /** Returns the rewrite field name of the SUM/COUNT measure that {@code func} stands in for. */
    public static String getRewriteName(FunctionDesc func) {
        return getTopnInternalMeasure(func).getRewriteFieldName();
    }

    /**
     * Returns the plain aggregate a TopN function stands in for: {@code SUM(col)} when the first
     * parameter is a column, otherwise {@code COUNT(1)}. The return type is left {@code null}.
     */
    public static FunctionDesc getTopnInternalMeasure(FunctionDesc func) {
        ParameterDesc first = func.getParameters().get(0);
        if (first.isColumnType()) {
            return FunctionDesc.newInstance("SUM",
                    Lists.newArrayList(ParameterDesc.newInstance(first.getColRef())), null);
        }
        return FunctionDesc.newInstance("COUNT", Lists.newArrayList(ParameterDesc.newInstance("1")), null);
    }

    @Override
    public boolean isMemoryHungry() {
        return true;
    }

    /**
     * Creates an ingester that turns source rows into single-entry {@link TopNCounter}s and can
     * re-encode counter keys when dictionaries change. The returned ingester keeps lazily
     * initialised mutable state and has no synchronisation.
     */
    @Override
    public MeasureIngester<TopNCounter<ByteArray>> newIngester() {
        return new MeasureIngester<TopNCounter<ByteArray>>() {
            private DimensionEncoding[] dimensionEncodings = null;
            private List<TblColRef> literalCols = null;
            private int keyLength = 0;
            private DimensionEncoding[] newDimensionEncodings = null;
            private int newKeyLength = 0;
            private boolean needReEncode = true;

            /**
             * Builds a one-entry counter from a single row.
             *
             * @param values {@code values[0]} is the numeric weight ({@code null} counts as 0); the
             *               remaining entries are the literal column values in literal-column order
             * @throws IllegalArgumentException if {@code values} does not hold exactly one weight
             *                                  plus one value per literal column
             */
            @Override
            public TopNCounter<ByteArray> valueOf(String[] values, MeasureDesc measureDesc,
                    Map<TblColRef, Dictionary<String>> dictionaryMap) {
                if (dimensionEncodings == null) {
                    initEncodings(measureDesc.getFunction(), dictionaryMap);
                }
                // Validate every row, not just the first: a malformed row must fail fast instead of
                // overrunning the array or silently dropping trailing values.
                int expectedLength = literalCols.size() + 1;
                if (values.length != expectedLength) {
                    throw new IllegalArgumentException("Expected " + expectedLength + " values (1 numeric + "
                            + literalCols.size() + " literal) but got " + values.length);
                }
                double counter = values[0] == null ? 0.0 : Double.parseDouble(values[0]);

                ByteArray key = new ByteArray(keyLength);
                int offset = 0;
                for (int i = 0; i < dimensionEncodings.length; i++) {
                    DimensionEncoding encoding = dimensionEncodings[i];
                    int length = encoding.getLengthOfEncoding();
                    String value = values[i + 1];
                    if (value == null) {
                        // A null literal value is stored as an all-0xFF key segment.
                        Arrays.fill(key.array(), offset, offset + length, (byte) -1);
                    } else {
                        encoding.encode(value, key.array(), offset);
                    }
                    offset += length;
                }
                TopNCounter<ByteArray> topNCounter = new TopNCounter<ByteArray>(
                        TopNMeasureType.this.dataType.getPrecision() * TopNCounter.EXTRA_SPACE_RATE);
                topNCounter.offer(key, counter);
                return topNCounter;
            }

            /** Resolves literal columns and their encodings into locals, then publishes them together. */
            private void initEncodings(FunctionDesc function, Map<TblColRef, Dictionary<String>> dictionaryMap) {
                List<TblColRef> cols = TopNMeasureType.this.getTopNLiteralColumn(function);
                DimensionEncoding[] encodings = TopNMeasureType.getDimensionEncodings(function, cols, dictionaryMap);
                literalCols = cols;
                keyLength = TopNMeasureType.totalEncodingLength(encodings);
                dimensionEncodings = encodings;
            }

            /**
             * Re-encodes every key of {@code topNCounter} in place, from the old dictionaries to the
             * new ones. The counter is returned untouched when no literal column uses dictionary
             * encoding.
             */
            @Override
            public TopNCounter<ByteArray> reEncodeDictionary(TopNCounter<ByteArray> topNCounter,
                    MeasureDesc measureDesc, Map<TblColRef, Dictionary<String>> oldDicts,
                    Map<TblColRef, Dictionary<String>> newDicts) {
                if (newDimensionEncodings == null) {
                    FunctionDesc function = measureDesc.getFunction();
                    literalCols = TopNMeasureType.this.getTopNLiteralColumn(function);
                    dimensionEncodings = TopNMeasureType.getDimensionEncodings(function, literalCols, oldDicts);
                    keyLength = TopNMeasureType.totalEncodingLength(dimensionEncodings);
                    newDimensionEncodings = TopNMeasureType.getDimensionEncodings(function, literalCols, newDicts);
                    newKeyLength = TopNMeasureType.totalEncodingLength(newDimensionEncodings);
                    needReEncode = TopNMeasureType.hasDictionaryEncoding(dimensionEncodings);
                }
                if (!needReEncode) {
                    return topNCounter;
                }
                // All re-encoded keys share one backing buffer; each item is reset to its own slice.
                byte[] newIdBuf = new byte[topNCounter.size() * newKeyLength];
                int bufOffset = 0;
                for (Counter<ByteArray> counter : topNCounter) {
                    ByteArray item = counter.getItem();
                    int oldOffset = item.offset();
                    int newOffset = bufOffset;
                    for (int i = 0; i < dimensionEncodings.length; i++) {
                        int oldLength = dimensionEncodings[i].getLengthOfEncoding();
                        String dimValue = dimensionEncodings[i].decode(item.array(), oldOffset, oldLength);
                        newDimensionEncodings[i].encode(dimValue, newIdBuf, newOffset);
                        newOffset += newDimensionEncodings[i].getLengthOfEncoding();
                        oldOffset += oldLength;
                    }
                    item.reset(newIdBuf, bufOffset, newKeyLength);
                    bufOffset += newKeyLength;
                }
                return topNCounter;
            }
        };
    }

    @Override
    public MeasureAggregator<TopNCounter<ByteArray>> newAggregator() {
        return new TopNAggregator();
    }

    /** Returns the literal columns whose configured encoding is dictionary (absent, empty, or "dict"). */
    @Override
    public List<TblColRef> getColumnsNeedDictionary(FunctionDesc functionDesc) {
        List<TblColRef> columnsNeedDict = new ArrayList<>();
        for (TblColRef literalCol : getTopNLiteralColumn(functionDesc)) {
            if (isDictionaryEncoding(getEncoding(functionDesc, literalCol).getFirst())) {
                columnsNeedDict.add(literalCol);
            }
        }
        return columnsNeedDict;
    }

    /**
     * Decides whether this TopN measure can answer the query.
     *
     * <p>All of these must hold:
     * <ul>
     *   <li>no literal column is filtered on;</li>
     *   <li>the group-by columns equal the literal columns as a set;</li>
     *   <li>there is exactly one aggregation, and it is compatible with the TopN numeric column;</li>
     *   <li>the query orders by that aggregation, descending, with a limit no larger than the
     *       precision.</li>
     * </ul>
     *
     * <p>On success it removes the matched dimensions and aggregation from the unmatched
     * collections. It then returns an influence with cost multiplier 0.3; otherwise it returns
     * {@code null}.
     */
    @Override
    public CapabilityResult.CapabilityInfluence influenceCapabilityCheck(Collection<TblColRef> unmatchedDimensions,
            Collection<FunctionDesc> unmatchedAggregations, SQLDigest digest, final MeasureDesc topN) {
        List<TblColRef> literalCol = getTopNLiteralColumn(topN.getFunction());
        for (TblColRef colRef : literalCol) {
            if (digest.getFilterColumns().contains(colRef)) {
                return null;
            }
        }
        if (!sameColumnSet(digest.getGroupByColumns(), literalCol)) {
            return null;
        }
        if (digest.getAggregations().size() != 1) {
            return null;
        }
        FunctionDesc onlyFunction = digest.getAggregations().iterator().next();
        if (!isTopNCompatibleSum(topN.getFunction(), onlyFunction)) {
            return null;
        }
        if (!checkOrderByAndLimit(digest, topN.getFunction().getReturnDataType().getPrecision())) {
            return null;
        }
        unmatchedDimensions.removeAll(literalCol);
        unmatchedAggregations.remove(onlyFunction);
        return new CapabilityResult.CapabilityInfluence() {

            @Override
            public double suggestCostMultiplier() {
                return 0.3;
            }

            @Override
            public MeasureDesc getInvolvedMeasure() {
                return topN;
            }
        };
    }

    /**
     * Returns true when the limit does not exceed {@code topNPrecision}. The query must also sort by
     * exactly one column, descending, and that column must correspond to one of the aggregations.
     */
    private boolean checkOrderByAndLimit(SQLDigest digest, int topNPrecision) {
        if (digest.getLimit() > topNPrecision) {
            return false;
        }
        if (digest.getSortColumns().size() != 1
                || digest.getSortOrders().get(0) != SQLDigest.OrderEnum.DESCENDING) {
            return false;
        }
        TblColRef sortCol = digest.getSortColumns().get(0);
        for (FunctionDesc agg : digest.getAggregations()) {
            if (isSortedByAggregation(sortCol, agg)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code sortCol} refers to {@code agg}. The match is either by rewrite field
     * name, or as an inner column with the same operator name and operands.
     */
    private static boolean isSortedByAggregation(TblColRef sortCol, FunctionDesc agg) {
        if (sortCol.getName().equals(agg.getRewriteFieldName())) {
            return true;
        }
        return sortCol.isInnerColumn()
                && sortCol.getOperator() != null
                && sortCol.getOperator().getName().equals(agg.getExpression())
                && sortCol.getOperands().equals(agg.getColRefs());
    }

    /**
     * Returns true if {@code sum} can be answered by the TopN function {@code topN}. For a
     * constant-weighted TopN it must be a COUNT. Otherwise it must be a SUM over the TopN numeric
     * column.
     */
    private boolean isTopNCompatibleSum(FunctionDesc topN, FunctionDesc sum) {
        if (sum == null) {
            return false;
        }
        if (!isTopN(topN)) {
            return false;
        }
        TblColRef topnNumCol = getTopNNumericColumn(topN);
        if (topnNumCol == null) {
            return sum.isCount();
        }
        if (!sum.isSum()) {
            return false;
        }
        if (CollectionUtils.isEmpty(sum.getParameters()) || CollectionUtils.isEmpty(sum.getColRefs())) {
            return false;
        }
        TblColRef sumCol = sum.getColRefs().get(0);
        return sumCol.equals(topnNumCol);
    }

    @Override
    public boolean needRewrite() {
        return true;
    }

    /**
     * Rewrites the digest to aggregate with the TopN function. This happens only when the digest
     * has at most one aggregation (a compatible SUM/COUNT) and its group-by columns equal the TopN
     * literal columns. The literal columns then move from the group-by to the metric columns.
     */
    @Override
    public void adjustSqlDigest(MeasureDesc involvedMeasure, SQLDigest sqlDigest) {
        List<FunctionDesc> sqlDigestAggregations = sqlDigest.getAggregations();
        if (sqlDigestAggregations.size() > 1) {
            return;
        }
        FunctionDesc topnFunc = involvedMeasure.getFunction();
        List<TblColRef> topnLiteralCol = getTopNLiteralColumn(topnFunc);
        if (!sameColumnSet(sqlDigest.getGroupByColumns(), topnLiteralCol)) {
            return;
        }
        if (!sqlDigestAggregations.isEmpty()) {
            FunctionDesc origFunc = sqlDigestAggregations.iterator().next();
            if (!origFunc.isSum() && !origFunc.isCount()) {
                logger.warn("When query with topN, only SUM/Count function is allowed.");
                return;
            }
            if (!isTopNCompatibleSum(involvedMeasure.getFunction(), origFunc)) {
                return;
            }
            logger.info("Rewrite function {} to {}", origFunc, topnFunc);
        }
        sqlDigest.setAggregations(Lists.newArrayList(topnFunc));
        sqlDigest.getGroupByColumns().removeAll(topnLiteralCol);
        sqlDigest.getMetricColumns().addAll(topnLiteralCol);
    }

    @Override
    public boolean needAdvancedTupleFilling() {
        return true;
    }

    /** Always throws; TopN values are expanded via {@link #getAdvancedTupleFiller}. */
    @Override
    public void fillTupleSimply(Tuple tuple, int indexInTuple, Object measureValue) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns a filler that expands one {@link TopNCounter} into one tuple row per counter entry.
     * Each row gets the decoded literal column values and the entry's count.
     */
    @Override
    public MeasureType.IAdvMeasureFiller getAdvancedTupleFiller(FunctionDesc function, TupleInfo tupleInfo,
            Map<TblColRef, Dictionary<String>> dictionaryMap) {
        List<TblColRef> literalCols = getTopNLiteralColumn(function);
        final DimensionEncoding[] dimensionEncodings = getDimensionEncodings(function, literalCols, dictionaryMap);
        // -1 marks a literal column the tuple does not carry.
        final int[] literalTupleIdx = new int[literalCols.size()];
        for (int i = 0; i < literalCols.size(); i++) {
            TblColRef colRef = literalCols.get(i);
            literalTupleIdx[i] = tupleInfo.hasColumn(colRef) ? tupleInfo.getColumnIndex(colRef) : -1;
        }
        final int numericTupleIdx = resolveNumericTupleIndex(function, tupleInfo);

        return new MeasureType.IAdvMeasureFiller() {
            private TopNCounter<ByteArray> topNCounter;
            private Iterator<Counter<ByteArray>> topNCounterIterator;
            private int expectRow = 0;

            @Override
            @SuppressWarnings("unchecked") // measure values of this type are TopNCounter<ByteArray>
            public void reload(Object measureValue) {
                this.topNCounter = (TopNCounter<ByteArray>) measureValue;
                this.topNCounterIterator = this.topNCounter.iterator();
                this.expectRow = 0;
            }

            @Override
            public int getNumOfRows() {
                return topNCounter.size();
            }

            /** Fills the next row; rows must be requested in order starting from 0 after a reload. */
            @Override
            public void fillTuple(Tuple tuple, int row) {
                int expected = expectRow++;
                if (expected != row) {
                    throw new IllegalStateException(
                            "Rows must be filled in order: expected row " + expected + " but got " + row);
                }
                Counter<ByteArray> counter = topNCounterIterator.next();
                ByteArray item = counter.getItem();
                int offset = item.offset();
                for (int i = 0; i < dimensionEncodings.length; i++) {
                    int length = dimensionEncodings[i].getLengthOfEncoding();
                    // Skip columns absent from the tuple (index -1) rather than writing to an
                    // invalid index, but still advance past their key bytes.
                    if (literalTupleIdx[i] >= 0) {
                        String colValue = dimensionEncodings[i].decode(item.array(), offset, length);
                        tuple.setDimensionValue(literalTupleIdx[i], colValue);
                    }
                    offset += length;
                }
                if (numericTupleIdx >= 0) {
                    tuple.setMeasureValue(numericTupleIdx, counter.getCount());
                }
            }
        };
    }

    /**
     * Returns the tuple field index that receives the counts. For a column-weighted TopN this is the
     * SUM field, or -1 when the tuple lacks it. For a constant-weighted TopN it is the COUNT field.
     */
    private int resolveNumericTupleIndex(FunctionDesc function, TupleInfo tupleInfo) {
        TblColRef numericCol = getTopNNumericColumn(function);
        if (numericCol != null) {
            FunctionDesc sumFunc = FunctionDesc.newInstance("SUM",
                    Lists.newArrayList(ParameterDesc.newInstance(numericCol)), numericCol.getType().toString());
            String sumFieldName = sumFunc.getRewriteFieldName();
            return tupleInfo.hasField(sumFieldName) ? tupleInfo.getFieldIndex(sumFieldName) : -1;
        }
        FunctionDesc countFunction = FunctionDesc.newInstance("COUNT",
                Lists.newArrayList(ParameterDesc.newInstance("1")), "bigint");
        return tupleInfo.getFieldIndex(countFunction.getRewriteFieldName());
    }

    /**
     * Builds one {@link DimensionEncoding} per literal column. A dictionary-encoded column uses its
     * dictionary from {@code dictionaryMap}; any other column uses the configured encoding and
     * version.
     *
     * @throws RuntimeException if a configured encoding version is not an integer
     */
    private static DimensionEncoding[] getDimensionEncodings(FunctionDesc function, List<TblColRef> literalCols,
            Map<TblColRef, Dictionary<String>> dictionaryMap) {
        DimensionEncoding[] dimensionEncodings = new DimensionEncoding[literalCols.size()];
        for (int i = 0; i < literalCols.size(); i++) {
            TblColRef colRef = literalCols.get(i);
            Pair<String, String> topNEncoding = getEncoding(function, colRef);
            String encoding = topNEncoding.getFirst();
            if (isDictionaryEncoding(encoding)) {
                dimensionEncodings[i] = new DictionaryDimEnc(dictionaryMap.get(colRef));
                continue;
            }
            int encodingVersion = parseEncodingVersion(topNEncoding.getSecond(),
                    CONFIG_ENCODING_VERSION_PREFIX + colRef.getName() + " has to be an integer");
            Object[] encodingConf = DimensionEncoding.parseEncodingConf(encoding);
            String encodingName = (String) encodingConf[0];
            String[] encodingArgs = (String[]) encodingConf[1];
            encodingArgs = DateDimEnc.replaceEncodingArgs(encoding, encodingArgs, encodingName, colRef.getType());
            dimensionEncodings[i] = DimensionEncodingFactory.create(encodingName, encodingArgs, encodingVersion);
        }
        return dimensionEncodings;
    }

    /** Returns the numeric (weight) column, or {@code null} when the first parameter is not a column. */
    private TblColRef getTopNNumericColumn(FunctionDesc functionDesc) {
        if (functionDesc.getParameters().get(0).isColumnType()) {
            return functionDesc.getColRefs().get(0);
        }
        return null;
    }

    /**
     * Returns the literal columns: every column ref, minus the first one when the first parameter
     * is the numeric column. The result may be a view of {@code getColRefs()}.
     */
    private List<TblColRef> getTopNLiteralColumn(FunctionDesc functionDesc) {
        List<TblColRef> allColumns = functionDesc.getColRefs();
        if (!functionDesc.getParameters().get(0).isColumnType()) {
            return allColumns;
        }
        return allColumns.subList(1, allColumns.size());
    }

    private boolean isTopN(FunctionDesc functionDesc) {
        return FUNC_TOP_N.equalsIgnoreCase(functionDesc.getExpression());
    }

    /**
     * Looks up the configured (encoding, encoding version) for a column. The key is the column
     * identity; the column name is used as a fallback when no encoding is set under the identity.
     * Either element may be {@code null}.
     */
    public static Pair<String, String> getEncoding(FunctionDesc functionDesc, TblColRef tblColRef) {
        Map<String, String> configuration = functionDesc.getConfiguration();
        String encoding = configuration.get(CONFIG_ENCODING_PREFIX + tblColRef.getIdentity());
        String encodingVersion = configuration.get(CONFIG_ENCODING_VERSION_PREFIX + tblColRef.getIdentity());
        if (StringUtils.isEmpty(encoding)) {
            encoding = configuration.get(CONFIG_ENCODING_PREFIX + tblColRef.getName());
            encodingVersion = configuration.get(CONFIG_ENCODING_VERSION_PREFIX + tblColRef.getName());
        }
        return new Pair<>(encoding, encodingVersion);
    }

    /**
     * Recomputes the total key length of the literal columns (every parameter after the first). It
     * then rewrites the measure's return type with that length as the third {@link DataType}
     * argument.
     *
     * @throws RuntimeException if a configured encoding version is not an integer
     */
    public static void fixMeasureReturnType(MeasureDesc measureDesc) {
        FunctionDesc function = measureDesc.getFunction();
        Map<String, String> configuration = function.getConfiguration();
        List<ParameterDesc> parameters = function.getParameters();
        int keyLength = 0;
        for (ParameterDesc parameter : parameters.subList(1, parameters.size())) {
            String encoding = configuration.get(CONFIG_ENCODING_PREFIX + parameter.getValue());
            if (isDictionaryEncoding(encoding)) {
                keyLength += DICT_ENCODING_LENGTH;
                continue;
            }
            String encodingVersionStr = configuration.get(CONFIG_ENCODING_VERSION_PREFIX + parameter.getValue());
            int encodingVersion = parseEncodingVersion(encodingVersionStr,
                    "invalid encoding version: " + encodingVersionStr);
            Object[] encodingConf = DimensionEncoding.parseEncodingConf(encoding);
            DimensionEncoding dimensionEncoding = DimensionEncodingFactory.create((String) encodingConf[0],
                    (String[]) encodingConf[1], encodingVersion);
            keyLength += dimensionEncoding.getLengthOfEncoding();
        }
        DataType returnType = DataType.getType(function.getReturnType());
        DataType newReturnType = new DataType(returnType.getName(), returnType.getPrecision(), keyLength);
        function.setReturnType(newReturnType.toString());
    }

    /** True when {@code encoding} selects dictionary encoding: null, empty, or "dict". */
    private static boolean isDictionaryEncoding(String encoding) {
        return StringUtils.isEmpty(encoding) || DICT_ENCODING.equals(encoding);
    }

    /**
     * Parses a configured encoding version, defaulting to {@value #DEFAULT_ENCODING_VERSION} when
     * it is absent or empty.
     *
     * @throws RuntimeException with {@code errorMessage}, wrapping the parse failure as its cause
     */
    private static int parseEncodingVersion(String versionStr, String errorMessage) {
        if (StringUtils.isEmpty(versionStr)) {
            return DEFAULT_ENCODING_VERSION;
        }
        try {
            return Integer.parseInt(versionStr);
        } catch (NumberFormatException e) {
            throw new RuntimeException(errorMessage, e);
        }
    }

    /** Sum of the encoded lengths of all given encodings. */
    private static int totalEncodingLength(DimensionEncoding[] encodings) {
        int total = 0;
        for (DimensionEncoding encoding : encodings) {
            total += encoding.getLengthOfEncoding();
        }
        return total;
    }

    /** True if any of the given encodings is a {@link DictionaryDimEnc}. */
    private static boolean hasDictionaryEncoding(DimensionEncoding[] encodings) {
        for (DimensionEncoding encoding : encodings) {
            if (encoding instanceof DictionaryDimEnc) {
                return true;
            }
        }
        return false;
    }

    /** True if {@code a} and {@code b} contain the same columns, ignoring order and duplicates. */
    private static boolean sameColumnSet(Collection<? extends TblColRef> a, Collection<? extends TblColRef> b) {
        return new HashSet<TblColRef>(a).containsAll(b) && new HashSet<TblColRef>(b).containsAll(a);
    }

    /** Factory registering {@link TopNMeasureType} for the TOP_N function and topn data type. */
    public static class Factory extends MeasureTypeFactory<TopNCounter<ByteArray>> {
        @Override
        public MeasureType<TopNCounter<ByteArray>> createMeasureType(String funcName, DataType dataType) {
            return new TopNMeasureType(funcName, dataType);
        }

        @Override
        public String getAggrFunctionName() {
            return TopNMeasureType.FUNC_TOP_N;
        }

        @Override
        public String getAggrDataTypeName() {
            return TopNMeasureType.DATATYPE_TOPN;
        }

        @Override
        public Class<? extends DataTypeSerializer<TopNCounter<ByteArray>>> getAggrDataTypeSerializer() {
            return TopNCounterSerializer.class;
        }
    }
}

