/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Stack;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.physical.MapJoinResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.plan.HashTableDummyDesc;
import org.apache.hadoop.hive.ql.plan.HashTableSinkDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Factory for the {@link NodeProcessor}s that operate on a {@link MapJoinResolver.LocalMapJoinProcCtx}.
 *
 * <p>{@link LocalMapJoinProcessor} splits a {@link MapJoinOperator}: its small-table parents are re-wired
 * to a new {@link HashTableSinkOperator}, and inside the map join they are replaced by
 * {@link HashTableDummyOperator}s. {@link MapJoinFollowedByGroupByProcessor} records that a group-by was
 * found below the map join and sets that group-by's memory usage.
 */
public final class LocalMapJoinProcFactory {
    private static final Logger LOG = LoggerFactory.getLogger(LocalMapJoinProcFactory.class);

    /**
     * Returns the processor that rewrites a map join into a hash table sink plus dummy parents.
     *
     * @return a new {@link LocalMapJoinProcessor}
     */
    public static NodeProcessor getJoinProc() {
        return new LocalMapJoinProcessor();
    }

    /**
     * Returns the processor that reacts to a group-by found below a map join.
     *
     * @return a new {@link MapJoinFollowedByGroupByProcessor}
     */
    public static NodeProcessor getGroupByProc() {
        return new MapJoinFollowedByGroupByProcessor();
    }

    /**
     * Returns a no-op processor, used for nodes that match no rule.
     *
     * @return a processor whose {@code process} always returns {@code null}
     */
    public static NodeProcessor getDefaultProc() {
        return new NodeProcessor() {

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs)
                    throws SemanticException {
                return null;
            }
        };
    }

    /** Static factory only; not instantiable. */
    private LocalMapJoinProcFactory() {
    }

    /**
     * Splits a {@link MapJoinOperator} into a {@link HashTableSinkOperator} fed by the small tables and a
     * map join whose small-table parents are {@link HashTableDummyOperator}s.
     */
    public static class LocalMapJoinProcessor implements NodeProcessor {

        /**
         * Rewrites one map join; nodes whose name is not {@code "MAPJOIN"} are ignored.
         *
         * <p>For every parent other than the big table:
         * <ul>
         *   <li>if non-staged joins are allowed (MR engine, vectorization off) and the parent is a
         *       {@link TableScanOperator} or {@link MapJoinOperator}, it is "directly fetchable": its
         *       key/value/filter entries in the sink descriptor are set to {@code null} and it is detached
         *       from the sink;</li>
         *   <li>otherwise it becomes a parent of the sink, and its value expressions are pruned.</li>
         * </ul>
         * In the map join, each such parent is replaced by a new {@link HashTableDummyOperator}. The dummy
         * parents, and any directly fetchable parents, are then registered with the context.
         *
         * @param nd the visited node
         * @param stack the walker's node stack (unused)
         * @param ctx must be a {@link MapJoinResolver.LocalMapJoinProcCtx}
         * @param nodeOutputs outputs of child processors (unused)
         * @return always {@code null}
         * @throws SemanticException if a parent without a schema is not a {@link TableScanOperator}
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx ctx, Object... nodeOutputs)
                throws SemanticException {
            MapJoinResolver.LocalMapJoinProcCtx context = (MapJoinResolver.LocalMapJoinProcCtx) ctx;
            if (!"MAPJOIN".equals(nd.getName())) {
                return null;
            }
            MapJoinOperator mapJoinOp = (MapJoinOperator) nd;
            try {
                this.hasGroupBy(mapJoinOp, context);
            } catch (Exception e) {
                // Best effort: on failure the followed-by-group-by flag is simply left as it was.
                LOG.warn("Could not determine whether the map join is followed by a group by", e);
            }
            MapJoinDesc mapJoinDesc = (MapJoinDesc) mapJoinOp.getConf();
            mapJoinDesc.resetOrder();
            HiveConf conf = context.getParseCtx().getConf();
            // NOTE(review): nothing in this file resets the flag to false; if the context is shared across
            // several map joins, a later join inherits an earlier "true" — confirm the context's lifetime.
            boolean followedByGroupBy = context.isFollowedByGroupBy();
            float hashtableMemoryUsage = followedByGroupBy
                    ? conf.getFloatVar(HiveConf.ConfVars.HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE)
                    : conf.getFloatVar(HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);
            mapJoinDesc.setHashTableMemoryUsage(hashtableMemoryUsage);
            LOG.info("Setting max memory usage to {} for table sink {}followed by group by",
                    hashtableMemoryUsage, followedByGroupBy ? "" : "not ");

            HashTableSinkDesc hashTableSinkDesc = new HashTableSinkDesc(mapJoinDesc);
            HashTableSinkOperator hashTableSinkOp = (HashTableSinkOperator) OperatorFactory.get(
                    mapJoinOp.getCompilationOpContext(), hashTableSinkDesc);
            int bigTable = mapJoinDesc.getPosBigTable();
            // Direct (non-staged) fetching of small tables is only allowed on MR with vectorization off.
            boolean useNonStaged = conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOINUSENONSTAGED)
                    && "mr".equals(conf.getVar(HiveConf.ConfVars.HIVE_EXECUTION_ENGINE))
                    && !conf.getBoolVar(HiveConf.ConfVars.HIVE_VECTORIZATION_ENABLED);

            // Both lists are indexed by parent position; slots that do not apply hold null.
            List<Operator<? extends OperatorDesc>> smallTablesParentOp =
                    new ArrayList<Operator<? extends OperatorDesc>>();
            List<Operator<? extends OperatorDesc>> directOperators =
                    new ArrayList<Operator<? extends OperatorDesc>>();
            List<HashTableDummyOperator> dummyOperators = new ArrayList<HashTableDummyOperator>();
            List<? extends Operator<? extends OperatorDesc>> parentsOp = mapJoinOp.getParentOperators();
            // i stays a byte: it is autoboxed to Byte when used as a key into the sink descriptor's maps.
            for (byte i = 0; i < parentsOp.size(); i++) {
                if (i == bigTable) {
                    smallTablesParentOp.add(null);
                    directOperators.add(null);
                    continue;
                }
                Operator<? extends OperatorDesc> operator = parentsOp.get(i);
                boolean directFetchable = useNonStaged
                        && (operator instanceof TableScanOperator || operator instanceof MapJoinOperator);
                if (directFetchable) {
                    smallTablesParentOp.add(null);
                    directOperators.add(operator);
                    hashTableSinkDesc.getKeys().put(i, null);
                    hashTableSinkDesc.getExprs().put(i, null);
                    hashTableSinkDesc.getFilters().put(i, null);
                } else {
                    smallTablesParentOp.add(operator);
                    directOperators.add(null);
                    pruneValueExprs(mapJoinDesc, hashTableSinkDesc, i);
                }
                operator.replaceChild(mapJoinOp, hashTableSinkOp);
                if (directFetchable) {
                    // Directly fetched parents do not feed the sink; detach them completely.
                    operator.setChildOperators(null);
                }
                HashTableDummyOperator dummyOp = createDummyParent(operator);
                mapJoinOp.replaceParent(operator, dummyOp);
                ArrayList<Operator<? extends OperatorDesc>> dummyChildren =
                        new ArrayList<Operator<? extends OperatorDesc>>();
                dummyChildren.add(mapJoinOp);
                dummyOp.setChildOperators(dummyChildren);
                dummyOperators.add(dummyOp);
            }
            hashTableSinkOp.setParentOperators(smallTablesParentOp);
            for (HashTableDummyOperator dummyOp : dummyOperators) {
                context.addDummyParentOp(dummyOp);
            }
            if (hasAnyDirectFetch(directOperators)) {
                context.addDirectWorks(mapJoinOp, directOperators);
            }
            return null;
        }

        /**
         * Keeps only the value expressions of small table {@code pos} whose value index is negative.
         * Does nothing when the map join has no value index for that position.
         */
        private static void pruneValueExprs(MapJoinDesc mapJoinDesc, HashTableSinkDesc hashTableSinkDesc,
                byte pos) {
            int[] valueIndex = mapJoinDesc.getValueIndex(pos);
            if (valueIndex == null) {
                return;
            }
            // NOTE(review): presumably a non-negative index means the value can be recovered elsewhere
            // (e.g. from the key) and need not be stored in the hash table — confirm.
            List<ExprNodeDesc> values = hashTableSinkDesc.getExprs().get(pos);
            ArrayList<ExprNodeDesc> retained = new ArrayList<ExprNodeDesc>();
            for (int index = 0; index < values.size(); ++index) {
                if (valueIndex[index] < 0) {
                    retained.add(values.get(index));
                }
            }
            hashTableSinkDesc.getExprs().put(pos, retained);
        }

        /**
         * Creates a {@link HashTableDummyOperator} whose table descriptor describes {@code operator}'s
         * output. The descriptor comes from the operator's row schema, or from the table scan's own
         * descriptor when there is no schema.
         *
         * @throws SemanticException if {@code operator} has no schema and is not a {@link TableScanOperator}
         */
        private static HashTableDummyOperator createDummyParent(Operator<? extends OperatorDesc> operator)
                throws SemanticException {
            HashTableDummyDesc desc = new HashTableDummyDesc();
            HashTableDummyOperator dummyOp =
                    (HashTableDummyOperator) OperatorFactory.get(operator.getCompilationOpContext(), desc);
            TableDesc tbl;
            RowSchema rowSchema = operator.getSchema();
            if (rowSchema != null) {
                tbl = PlanUtils.getIntermediateFileTableDesc(
                        PlanUtils.getFieldSchemasFromRowSchema(rowSchema, ""));
            } else if (operator instanceof TableScanOperator) {
                tbl = ((TableScanOperator) operator).getTableDesc();
            } else {
                throw new SemanticException("Expected parent operator of type TableScanOperator. Found "
                        + operator.getClass().getName() + " instead.");
            }
            ((HashTableDummyDesc) dummyOp.getConf()).setTbl(tbl);
            return dummyOp;
        }

        /** Returns {@code true} if at least one entry of {@code directOperators} is non-null. */
        private boolean hasAnyDirectFetch(List<? extends Operator<?>> directOperators) {
            for (Operator<?> operator : directOperators) {
                if (operator != null) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Walks the operators below {@code mapJoinOp}, dispatching group-by operators to
         * {@link MapJoinFollowedByGroupByProcessor}. That processor marks {@code localMapJoinProcCtx} as
         * followed by a group-by. Despite its name, this method returns nothing; the result is recorded
         * in the context.
         *
         * @param mapJoinOp the operator whose descendants are walked
         * @param localMapJoinProcCtx context updated by the group-by processor
         * @throws Exception if the graph walk fails
         */
        public void hasGroupBy(Operator<? extends OperatorDesc> mapJoinOp,
                MapJoinResolver.LocalMapJoinProcCtx localMapJoinProcCtx) throws Exception {
            List<? extends Operator<? extends OperatorDesc>> childOps = mapJoinOp.getChildOperators();
            if (childOps == null || childOps.isEmpty()) {
                return;
            }
            LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("R1", GroupByOperator.getOperatorName() + "%"),
                    LocalMapJoinProcFactory.getGroupByProc());
            DefaultRuleDispatcher disp = new DefaultRuleDispatcher(
                    LocalMapJoinProcFactory.getDefaultProc(), opRules, localMapJoinProcCtx);
            DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
            ArrayList<Node> topNodes = new ArrayList<Node>(childOps);
            ogw.startWalking(topNodes, null);
        }
    }

    /**
     * Marks the context as followed by a group-by and sets that group-by's memory usage from
     * {@code HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY}.
     */
    public static class MapJoinFollowedByGroupByProcessor implements NodeProcessor {

        /**
         * Handles nodes named {@code "GBY"}; all other nodes are ignored.
         *
         * @return always {@code null}
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx ctx, Object... nodeOutputs)
                throws SemanticException {
            MapJoinResolver.LocalMapJoinProcCtx context = (MapJoinResolver.LocalMapJoinProcCtx) ctx;
            if (!"GBY".equals(nd.getName())) {
                return null;
            }
            context.setFollowedByGroupBy(true);
            GroupByOperator groupByOp = (GroupByOperator) nd;
            float groupByMemoryUsage = context.getParseCtx().getConf()
                    .getFloatVar(HiveConf.ConfVars.HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY);
            ((GroupByDesc) groupByOp.getConf()).setGroupByMemoryUsage(groupByMemoryUsage);
            return null;
        }
    }
}

