/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.query.calcite.KylinSumSplitter;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for rules that rewrite an {@link Aggregate} sitting on a {@link Project} whose aggregate
 * calls are applied to CASE WHEN expressions (the failure log calls this the "sum case when rule").
 *
 * <p>Aggregate calls accepted by {@link #isApplicableWithSumCaseRule} are rewritten into two levels:
 * <ol>
 *   <li>a bottom aggregate that groups by the original group inputs <em>and</em> the CASE condition
 *       inputs, and aggregates every branch value separately (using {@link #getBottomAggFunc});</li>
 *   <li>a top project that re-applies the CASE over the pre-aggregated branch values, followed by a top
 *       aggregate (using {@link #getTopAggFunc0}) over the original group keys / grouping sets.</li>
 * </ol>
 * If the original aggregate also contains calls that are not applicable, they are computed by a
 * separate aggregate over the same input, joined back on the group keys, and a final project restores
 * the original output column order.
 *
 * <p>Any exception or error raised while rewriting is logged and the original plan is left untouched.
 */
public abstract class AbstractAggCaseWhenFunctionRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(AbstractAggCaseWhenFunctionRule.class);
    /** Default name prefix for the aggregate calls generated by the bottom aggregate. */
    private static final String BOTTOM_AGG_PREFIX = "SUB_AGG$";

    protected AbstractAggCaseWhenFunctionRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory,
            String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Fires only when the subclass recognises the aggregate/project pair.
     * Operand 0 is the {@link Aggregate}; operand 1 is its input {@link Project}.
     */
    @Override
    public boolean matches(RelOptRuleCall ruleCall) {
        Aggregate oldAgg = (Aggregate) ruleCall.rel(0);
        Project oldProject = (Project) ruleCall.rel(1);
        return this.checkAggCaseExpression(oldAgg, oldProject);
    }

    /**
     * Splits the aggregate calls into applicable / non-applicable ones, rewrites the applicable part and,
     * when needed, joins it with an aggregate computing the rest.
     */
    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        try {
            RelBuilder relBuilder = ruleCall.builder().transform(c -> c.withPruneInputOfAggregate(false));
            Aggregate originalAgg = (Aggregate) ruleCall.rel(0);
            Project originalProject = (Project) ruleCall.rel(1);

            List<AggregateCall> applicableAggCalls = originalAgg.getAggCallList().stream()
                    .filter(aggCall -> this.isApplicableWithSumCaseRule(aggCall, originalProject))
                    .collect(Collectors.toList());
            if (applicableAggCalls.isEmpty()) {
                // Defensive: nothing to rewrite, so the result would only be a more expensive equivalent plan.
                return;
            }
            // removeAll uses AggregateCall.equals, so equal calls always land on the same side.
            List<AggregateCall> nonApplicableAggCalls = new ArrayList<>(originalAgg.getAggCallList());
            nonApplicableAggCalls.removeAll(applicableAggCalls);

            Aggregate sumCaseAgg = this.extractPartialAggregateCalls(relBuilder, originalAgg, originalProject,
                    applicableAggCalls);
            RelNode transformedSumCaseAgg = this.transformSumExprAggregate(relBuilder, sumCaseAgg,
                    (Project) sumCaseAgg.getInput(0));
            if (nonApplicableAggCalls.isEmpty()) {
                ruleCall.transformTo(transformedSumCaseAgg);
                return;
            }

            Aggregate nonSumCaseAgg = this.extractPartialAggregateCalls(relBuilder, originalAgg, originalProject,
                    nonApplicableAggCalls);
            RelNode joined = this.joinAggCaseWhenAndNonAggCaseWhenRel(relBuilder, transformedSumCaseAgg,
                    nonSumCaseAgg, originalAgg);
            ContextUtil.dumpCalcitePlan("new plan", joined, logger);
            ruleCall.transformTo(joined);
        } catch (Error | Exception e) {
            // Deliberate best effort: a failed rewrite keeps the original plan instead of failing the query.
            logger.error("sql cannot apply sum case when rule ", e);
        }
    }

    /**
     * Joins the non-rewritten aggregate (left) with the rewritten one (right) on their group-key columns,
     * then projects the group keys followed by the aggregate results in the original call order.
     *
     * <p>Keys are compared with {@code IS NOT DISTINCT FROM}: with a plain {@code =} an INNER join drops
     * every group whose key is NULL, including the rolled-up rows produced by grouping sets.
     */
    private RelNode joinAggCaseWhenAndNonAggCaseWhenRel(RelBuilder relBuilder, RelNode sumCaseAgg,
            Aggregate nonSumCaseAgg, Aggregate originalAgg) {
        relBuilder.push(nonSumCaseAgg);
        relBuilder.push(sumCaseAgg);
        List<RelDataTypeField> leftFields = nonSumCaseAgg.getRowType().getFieldList();
        List<RelDataTypeField> rightFields = sumCaseAgg.getRowType().getFieldList();
        List<AggregateCall> leftAggCalls = nonSumCaseAgg.getAggCallList();
        // Both sides start with the same group-key columns; aggregate results follow them.
        int groupKeyCount = leftFields.size() - leftAggCalls.size();
        // In the join row type, right-hand fields come after all left-hand fields.
        int rightOffset = leftFields.size();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();

        List<RexNode> joinConds = new ArrayList<>(groupKeyCount);
        for (int i = 0; i < groupKeyCount; ++i) {
            RexNode leftKey = rexBuilder.makeInputRef(leftFields.get(i).getType(), i);
            RexNode rightKey = rexBuilder.makeInputRef(rightFields.get(i).getType(), rightOffset + i);
            joinConds.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, leftKey, rightKey));
        }
        relBuilder.join(JoinRelType.INNER, joinConds);

        List<RexNode> projectNodes = new ArrayList<>(groupKeyCount + originalAgg.getAggCallList().size());
        for (int i = 0; i < groupKeyCount; ++i) {
            projectNodes.add(rexBuilder.makeInputRef(leftFields.get(i).getType(), i));
        }
        // Walk the original calls in order: the next left-hand call is taken when it matches, otherwise the
        // next right-hand result. equals (not ==) keeps this correct if RelBuilder re-created the calls.
        int leftIdx = 0;
        int rightIdx = 0;
        for (AggregateCall originalCall : originalAgg.getAggCallList()) {
            if (leftIdx < leftAggCalls.size() && originalCall.equals(leftAggCalls.get(leftIdx))) {
                int leftFieldIdx = groupKeyCount + leftIdx;
                projectNodes.add(rexBuilder.makeInputRef(leftFields.get(leftFieldIdx).getType(), leftFieldIdx));
                ++leftIdx;
            } else {
                int rightFieldIdx = groupKeyCount + rightIdx;
                projectNodes.add(rexBuilder.makeInputRef(rightFields.get(rightFieldIdx).getType(),
                        rightOffset + rightFieldIdx));
                ++rightIdx;
            }
        }
        relBuilder.project(projectNodes);
        return relBuilder.build();
    }

    /**
     * Recursively copies {@code rel} and all of its inputs, replacing each {@link HepRelVertex} by a copy of
     * the rel it currently wraps.
     */
    private RelNode cloneRelNode(RelNode rel) {
        if (rel instanceof HepRelVertex) {
            return this.cloneRelNode(((HepRelVertex) rel).getCurrentRel());
        }
        return rel.copy(rel.getTraitSet(), rel.getInputs().stream().map(this::cloneRelNode).collect(Collectors.toList()));
    }

    /**
     * Builds an aggregate equivalent to {@code oriAgg} (same group set and grouping sets) that computes only
     * {@code aggCalls}, over a cloned copy of {@code oriProject}.
     *
     * <p>Project columns not referenced by a group key, an argument, or a FILTER of {@code aggCalls} are
     * replaced with a zero literal of the same type, so the column layout (and every index) is unchanged.
     * Filter columns must be kept: zeroing them would change which rows a filtered aggregate sees.
     */
    private Aggregate extractPartialAggregateCalls(RelBuilder relBuilder, Aggregate oriAgg, Project oriProject,
            List<AggregateCall> aggCalls) {
        Set<Integer> usedInputIdxes = new HashSet<>(oriAgg.getGroupSet().asList());
        for (AggregateCall aggCall : aggCalls) {
            usedInputIdxes.addAll(aggCall.getArgList());
            if (aggCall.filterArg >= 0) {
                usedInputIdxes.add(aggCall.filterArg);
            }
        }

        relBuilder.push(this.cloneRelNode(oriProject.getInput()));
        List<RexNode> oriExprs = oriProject.getProjects();
        List<RexNode> newChildExps = new ArrayList<>(oriExprs.size());
        for (int i = 0; i < oriExprs.size(); ++i) {
            RexNode expr = oriExprs.get(i);
            newChildExps.add(usedInputIdxes.contains(i)
                    ? expr
                    : relBuilder.getRexBuilder().makeZeroLiteral(expr.getType()));
        }
        relBuilder.project(newChildExps);

        List<ImmutableBitSet> groupSets = oriAgg.getGroupSets() == null
                ? ImmutableList.of(oriAgg.getGroupSet())
                : oriAgg.getGroupSets();
        relBuilder.aggregate(relBuilder.groupKey(oriAgg.getGroupSet(), groupSets), aggCalls);
        return (Aggregate) relBuilder.build();
    }

    /**
     * Rewrites {@code oldAgg} (over {@code oldProject}) into bottom project -> bottom aggregate ->
     * top project -> top aggregate, as described on the class.
     *
     * <p>Side effect: fills the bottom/top input index arrays of the {@code AggExpressionUtil} expression
     * objects while building the plan; later steps read the indexes written by earlier ones.
     */
    private RelNode transformSumExprAggregate(RelBuilder relBuilder, Aggregate oldAgg, Project oldProject) {
        relBuilder.push(oldProject.getInput());
        List<AggExpressionUtil.AggExpression> aggExpressions = AggExpressionUtil.collectSumExpressions(oldAgg,
                oldProject);
        List<AggExpressionUtil.AggExpression> sumCaseExprs = aggExpressions.stream()
                .filter(this::isApplicableAggExpression)
                .collect(Collectors.toList());
        Pair<List<AggExpressionUtil.GroupExpression>, ImmutableList<ImmutableBitSet>> groups = AggExpressionUtil
                .collectGroupExprAndGroup(oldAgg, oldProject);
        List<AggExpressionUtil.GroupExpression> groupExpressions = groups.getFirst();
        List<ImmutableBitSet> newGroupSets = groups.getSecond();

        // Bottom project: group inputs, then per-expression condition / value inputs.
        relBuilder.project(this.buildBottomProject(relBuilder, oldProject, groupExpressions, aggExpressions));

        // Bottom aggregate groups by the group inputs and by every CASE condition input.
        ImmutableBitSet.Builder groupSetBuilder = ImmutableBitSet.builder();
        for (AggExpressionUtil.GroupExpression group : groupExpressions) {
            for (int bottomIdx : group.getBottomAggInput()) {
                groupSetBuilder.set(bottomIdx);
            }
        }
        for (AggExpressionUtil.AggExpression aggExpression : sumCaseExprs) {
            for (int conditionIdx : aggExpression.getBottomAggConditionsInput()) {
                groupSetBuilder.set(conditionIdx);
            }
        }
        ImmutableBitSet bottomAggGroupSet = groupSetBuilder.build();
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(bottomAggGroupSet);
        List<AggregateCall> aggCalls = this.buildBottomAggregate(relBuilder, aggExpressions,
                bottomAggGroupSet.cardinality());
        relBuilder.aggregate(groupKey, aggCalls);

        // Group columns of an aggregate are emitted in bit order; remap bottom-project indexes accordingly.
        for (AggExpressionUtil.GroupExpression groupExpression : groupExpressions) {
            for (int i = 0; i < groupExpression.getTopProjInput().length; ++i) {
                groupExpression.getTopProjInput()[i] = bottomAggGroupSet.indexOf(groupExpression.getBottomAggInput()[i]);
            }
        }
        for (AggExpressionUtil.AggExpression aggExpression : sumCaseExprs) {
            for (int i = 0; i < aggExpression.getTopProjConditionsInput().length; ++i) {
                aggExpression.getTopProjConditionsInput()[i] = bottomAggGroupSet
                        .indexOf(aggExpression.getBottomAggConditionsInput()[i]);
            }
        }

        relBuilder.project(this.buildTopProject(relBuilder, oldProject, aggExpressions, groupExpressions));

        // The top project emits the group expressions first, so they are fields [0, groupExpressions.size()).
        ImmutableBitSet topGroupSet = ImmutableBitSet.range(groupExpressions.size());
        List<AggregateCall> topAggregates = this.buildTopAggregate(oldAgg, topGroupSet.cardinality(), aggExpressions);
        RelBuilder.GroupKey topGroupKey = newGroupSets == null
                ? relBuilder.groupKey(topGroupSet)
                : relBuilder.groupKey(topGroupSet, newGroupSets);
        relBuilder.aggregate(topGroupKey, topAggregates);
        RelNode relNode = relBuilder.build();
        ContextUtil.dumpCalcitePlan("new plan", relNode, logger);
        return relNode;
    }

    /**
     * Builds the bottom project expressions and records, in each expression object, the bottom-project
     * index of every input it contributes (the {@code getBottomAgg*Input} arrays).
     */
    private List<RexNode> buildBottomProject(RelBuilder relBuilder, Project oldProject,
            List<AggExpressionUtil.GroupExpression> groupExpressions,
            List<AggExpressionUtil.AggExpression> aggExpressions) {
        List<RexNode> bottomProjectList = new ArrayList<>();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            int[] sourceInput = groupExpr.getBottomProjInput();
            for (int i = 0; i < sourceInput.length; ++i) {
                groupExpr.getBottomAggInput()[i] = bottomProjectList.size();
                bottomProjectList.add(rexBuilder.makeInputRef(oldProject.getInput(), sourceInput[i]));
            }
        }
        for (AggExpressionUtil.AggExpression aggExpression : aggExpressions) {
            if (this.isApplicableAggExpression(aggExpression)) {
                this.buildBottomAggExpression(rexBuilder, oldProject, bottomProjectList, aggExpression);
            } else if (aggExpression.getExpression() != null) {
                aggExpression.getBottomAggInput()[0] = bottomProjectList.size();
                bottomProjectList.add(aggExpression.getExpression());
            }
        }
        return bottomProjectList;
    }

    /**
     * Appends the condition inputs and the branch values of one CASE aggregate expression to
     * {@code bottomProjectList}, recording their positions in the expression object.
     */
    private void buildBottomAggExpression(RexBuilder rexBuilder, Project oldProject, List<RexNode> bottomProjectList,
            AggExpressionUtil.AggExpression aggExpression) {
        int[] conditionsInput = aggExpression.getBottomProjConditionsInput();
        for (int i = 0; i < conditionsInput.length; ++i) {
            aggExpression.getBottomAggConditionsInput()[i] = bottomProjectList.size();
            bottomProjectList.add(rexBuilder.makeInputRef(oldProject.getInput(), conditionsInput[i]));
        }
        List<RexNode> values = aggExpression.getValuesList();
        for (int i = 0; i < values.size(); ++i) {
            aggExpression.getBottomAggValuesInput()[i] = bottomProjectList.size();
            bottomProjectList.add(this.toBottomValue(rexBuilder, aggExpression, values.get(i)));
        }
    }

    /**
     * Maps one CASE branch value to the expression aggregated by the bottom aggregate.
     * A CAST is stripped unless the aggregate is a SUM over a non-numeric operand; values for which
     * {@link RuleUtils#isNotNullLiteral} is false are replaced by a BIGINT zero literal.
     */
    private RexNode toBottomValue(RexBuilder rexBuilder, AggExpressionUtil.AggExpression aggExpression, RexNode value) {
        if (RuleUtils.isCast(value)) {
            RexNode castOperand = ((RexCall) value).getOperands().get(0);
            DataType dataType = DataType.getType(castOperand.getType().getSqlTypeName().getName());
            boolean isSum = AggExpressionUtil.isSum(aggExpression.getAggCall().getAggregation().kind);
            return !isSum || dataType.isNumberFamily() || dataType.isIntegerFamily() ? castOperand : value;
        }
        if (RuleUtils.isNotNullLiteral(value)) {
            return value;
        }
        return rexBuilder.makeBigintLiteral(BigDecimal.ZERO);
    }

    /**
     * Creates the bottom aggregate calls: first a copy of each non-CASE call, then one call per valid
     * branch value of each CASE expression. Records each call's output index in the expression's
     * {@code getTopProj*Input} arrays ({@code bottomAggOffset} = number of bottom group columns).
     */
    private List<AggregateCall> buildBottomAggregate(RelBuilder relBuilder,
            List<AggExpressionUtil.AggExpression> aggExpressions, int bottomAggOffset) {
        List<AggregateCall> bottomAggCalls = new ArrayList<>();
        List<AggExpressionUtil.AggExpression> aggCaseExpressions = new ArrayList<>();
        for (AggExpressionUtil.AggExpression aggExpression : aggExpressions) {
            if (this.isApplicableAggExpression(aggExpression)) {
                aggCaseExpressions.add(aggExpression);
                continue;
            }
            aggExpression.getTopProjInput()[0] = bottomAggOffset + bottomAggCalls.size();
            AggregateCall oldAggCall = aggExpression.getAggCall();
            List<Integer> args = Arrays.stream(aggExpression.getBottomAggInput()).boxed().collect(Collectors.toList());
            // NOTE(review): filterArg still indexes the old project's columns, not the bottom project's;
            // confirm that filtered calls never reach this path.
            bottomAggCalls.add(oldAggCall.copy(args, oldAggCall.filterArg));
        }
        for (int aggCaseIdx = 0; aggCaseIdx < aggCaseExpressions.size(); ++aggCaseIdx) {
            AggExpressionUtil.AggExpression aggExpression = aggCaseExpressions.get(aggCaseIdx);
            List<RexNode> values = aggExpression.getValuesList();
            for (int valueIdx = 0; valueIdx < values.size(); ++valueIdx) {
                if (!this.isValidAggColumnExpr(values.get(valueIdx))) {
                    continue;
                }
                String aggName = this.getBottomAggPrefix() + aggCaseIdx + "$" + valueIdx;
                List<Integer> args = Collections.singletonList(aggExpression.getBottomAggValuesInput()[valueIdx]);
                aggExpression.getTopProjValuesInput()[valueIdx] = bottomAggOffset + bottomAggCalls.size();
                bottomAggCalls.add(AggregateCall.create(this.getBottomAggFunc(aggExpression.getAggCall()), false,
                        false, false, args, -1, null, RelCollations.EMPTY, bottomAggOffset, relBuilder.peek(),
                        null, aggName));
            }
        }
        return bottomAggCalls;
    }

    /**
     * Builds the top project: each group expression rebased onto the bottom aggregate's output, then, per
     * aggregate expression, either a rebuilt CASE over the pre-aggregated values or a reference to the
     * bottom aggregate result.
     */
    private List<RexNode> buildTopProject(RelBuilder relBuilder, Project oldProject,
            List<AggExpressionUtil.AggExpression> aggExpressions,
            List<AggExpressionUtil.GroupExpression> groupExpressions) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        List<RelDataTypeField> srcFields = oldProject.getInput().getRowType().getFieldList();
        List<RelDataTypeField> destFields = relBuilder.peek().getRowType().getFieldList();
        List<RexNode> topProjectList = new ArrayList<>();
        for (AggExpressionUtil.GroupExpression groupExpr : groupExpressions) {
            int[] aggAdjustments = AggExpressionUtil.generateAdjustments(groupExpr.getBottomProjInput(),
                    groupExpr.getTopProjInput());
            RexNode rebased = groupExpr.getExpression()
                    .accept(new RelOptUtil.RexInputConverter(rexBuilder, srcFields, destFields, aggAdjustments));
            topProjectList.add(rexBuilder.ensureType(groupExpr.getExpression().getType(), rebased, false));
        }
        for (AggExpressionUtil.AggExpression aggExpression : aggExpressions) {
            if (this.isApplicableAggExpression(aggExpression)) {
                topProjectList.add(this.buildTopCaseWhen(relBuilder, oldProject, aggExpression));
            } else {
                topProjectList.add(rexBuilder.makeInputRef(relBuilder.peek(), aggExpression.getTopProjInput()[0]));
            }
        }
        return topProjectList;
    }

    /**
     * Rebuilds {@code CASE WHEN c1 THEN v1 ... ELSE vn END} for the top project: conditions are rebased onto
     * the bottom aggregate's group columns and each branch value is mapped by {@link #toTopValue}.
     * The values list holds one entry per condition followed by the ELSE value.
     */
    private RexNode buildTopCaseWhen(RelBuilder relBuilder, Project oldProject,
            AggExpressionUtil.AggExpression aggExpression) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        List<RelDataTypeField> srcFields = oldProject.getInput().getRowType().getFieldList();
        List<RelDataTypeField> destFields = relBuilder.peek().getRowType().getFieldList();
        int[] adjustments = AggExpressionUtil.generateAdjustments(aggExpression.getBottomProjConditionsInput(),
                aggExpression.getTopProjConditionsInput());
        List<RexNode> conditions = aggExpression.getConditions();
        List<RexNode> valuesList = aggExpression.getValuesList();
        List<RexNode> newArgs = new ArrayList<>(2 * conditions.size() + 1);
        int whenIndex = 0;
        for (; whenIndex < conditions.size(); ++whenIndex) {
            newArgs.add(conditions.get(whenIndex)
                    .accept(new RelOptUtil.RexInputConverter(rexBuilder, srcFields, destFields, adjustments)));
            newArgs.add(this.toTopValue(relBuilder, aggExpression, valuesList.get(whenIndex), whenIndex));
        }
        // ELSE branch: the value following the last WHEN/THEN pair.
        newArgs.add(this.toTopValue(relBuilder, aggExpression, valuesList.get(whenIndex), whenIndex));
        return relBuilder.call(SqlStdOperatorTable.CASE, newArgs);
    }

    /**
     * Maps one branch value to its top-project form: a (possibly widened) CAST of the bottom aggregate
     * result when {@link #isNeedTackCast} says so, a plain reference to that result when
     * {@link RuleUtils#isNotNullLiteral} is true, otherwise the original value unchanged.
     */
    private RexNode toTopValue(RelBuilder relBuilder, AggExpressionUtil.AggExpression aggExpression, RexNode value,
            int valueIdx) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        if (this.isNeedTackCast(value)) {
            RelDataType expandedType = AbstractAggCaseWhenFunctionRule.expandCastDataType(relBuilder, aggExpression,
                    value);
            return rexBuilder.makeCast(expandedType,
                    rexBuilder.makeInputRef(relBuilder.peek(), aggExpression.getTopProjValuesInput()[valueIdx]));
        }
        if (RuleUtils.isNotNullLiteral(value)) {
            return rexBuilder.makeInputRef(relBuilder.peek(), aggExpression.getTopProjValuesInput()[valueIdx]);
        }
        return value;
    }

    /**
     * Returns the type a CAST branch value should be cast to in the top project: its own type, except for a
     * DECIMAL branch of a SUM, which uses the SUM's inferred return type for the aggregated expression.
     */
    private static RelDataType expandCastDataType(RelBuilder relBuilder, AggExpressionUtil.AggExpression aggExpression,
            RexNode thenOrElseNode) {
        RelDataType expandedType = thenOrElseNode.getType();
        SqlAggFunction func = aggExpression.getAggCall().getAggregation();
        if (SqlTypeUtil.isDecimal(thenOrElseNode.getType()) && func.getKind() == SqlKind.SUM) {
            expandedType = func.inferReturnType(relBuilder.getTypeFactory(),
                    Collections.singletonList(aggExpression.getExpression().getType()));
        }
        return expandedType;
    }

    /**
     * Creates one top aggregate call per original call, each reading the top-project column right after the
     * {@code groupOffset} group columns and keeping the original call's type; named {@code AGG$<index>}.
     */
    private List<AggregateCall> buildTopAggregate(Aggregate agg, int groupOffset,
            List<AggExpressionUtil.AggExpression> aggExpressions) {
        int aggCount = agg.getAggCallList().size();
        List<AggregateCall> topAggregates = new ArrayList<>(aggCount);
        for (int aggIndex = 0; aggIndex < aggCount; ++aggIndex) {
            AggregateCall aggCall = aggExpressions.get(aggIndex).getAggCall();
            String aggName = "AGG$" + aggIndex;
            List<Integer> args = Collections.singletonList(groupOffset + aggIndex);
            topAggregates.add(AggregateCall.create(this.getTopAggFunc0(aggCall), false, false, args, -1,
                    aggCall.getType(), aggName));
        }
        return topAggregates;
    }

    /** Whether this rule should fire for the given aggregate and its input project. */
    protected abstract boolean checkAggCaseExpression(Aggregate aggregate, Project project);

    /** Whether {@code aggCall} goes to the rewritten (CASE WHEN) side rather than the plain side. */
    protected abstract boolean isApplicableWithSumCaseRule(AggregateCall aggCall, Project project);

    /** Whether an expression collected by {@code AggExpressionUtil} is rewritten as a CASE aggregate. */
    protected abstract boolean isApplicableAggExpression(AggExpressionUtil.AggExpression aggExpression);

    /** Aggregate function applied to each branch value in the bottom aggregate. */
    protected abstract SqlAggFunction getBottomAggFunc(AggregateCall aggCall);

    /** Aggregate function used by the top aggregate for non-SUM calls. */
    protected abstract SqlAggFunction getTopAggFunc(AggregateCall aggCall);

    /**
     * Top-level aggregate function: {@link KylinSumSplitter#KYLIN_SUM} for SUM, otherwise
     * {@link #getTopAggFunc}. A null kind raises NullPointerException.
     */
    protected SqlAggFunction getTopAggFunc0(AggregateCall aggCall) {
        switch (aggCall.getAggregation().getKind()) {
        case SUM:
            return KylinSumSplitter.KYLIN_SUM;
        default:
            return this.getTopAggFunc(aggCall);
        }
    }

    /** Whether a branch value gets its own bottom aggregate call; the default accepts every value. */
    protected boolean isValidAggColumnExpr(RexNode rexNode) {
        return true;
    }

    /** Whether a branch value is re-wrapped in a CAST in the top project; by default when it is a CAST. */
    protected boolean isNeedTackCast(RexNode rexNode) {
        return RuleUtils.isCast(rexNode);
    }

    /** Name prefix for the bottom aggregate calls. */
    protected String getBottomAggPrefix() {
        return BOTTOM_AGG_PREFIX;
    }
}

