/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.ndarray.extensions;

import io.kinference.ndarray.arrays.IntNDArray;
import io.kinference.ndarray.arrays.LongNDArray;
import io.kinference.ndarray.arrays.MutableNDArrayCore;
import io.kinference.ndarray.arrays.NDArray;
import io.kinference.ndarray.arrays.NDArrayCore;
import io.kinference.ndarray.arrays.NDArrayUtilsKt;
import io.kinference.ndarray.arrays.pointers.IntPointer;
import io.kinference.ndarray.arrays.pointers.LongPointer;
import io.kinference.ndarray.arrays.tiled.IntTiledArray;
import io.kinference.ndarray.arrays.tiled.LongTiledArray;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import io.kinference.primitives.types.DataType;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 9, 0}, k=2, xi=48, d1={"\u0000*\n\u0000\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0010\b\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0003\u001a \u0010\u0000\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u0006H\u0000\u001a(\u0010\u0007\u001a\u00020\b2\u0006\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\t\u001a\u00020\nH\u0000\u001a \u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\f2\u0006\u0010\u0005\u001a\u00020\f2\b\b\u0002\u0010\u0003\u001a\u00020\u0004\u001a(\u0010\u000b\u001a\u00020\f2\u0006\u0010\r\u001a\u00020\f2\u0006\u0010\u0005\u001a\u00020\f2\b\b\u0002\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u000e\u001a\u00020\b\u00a8\u0006\u000f"}, d2={"computeGatherShape", "", "shape", "axis", "", "indices", "Lio/kinference/ndarray/arrays/NDArray;", "createGatherDstArray", "Lio/kinference/ndarray/arrays/MutableNDArrayCore;", "type", "Lio/kinference/primitives/types/DataType;", "gather", "Lio/kinference/ndarray/arrays/NDArrayCore;", "array", "dst", "ndarray-core"})
@SourceDebugExtension(value={"SMAP\nGather.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Gather.kt\nio/kinference/ndarray/extensions/GatherKt\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n+ 3 LongPointer.kt\nio/kinference/ndarray/arrays/pointers/LongPointerKt\n+ 4 IntPointer.kt\nio/kinference/ndarray/arrays/pointers/IntPointerKt\n*L\n1#1,82:1\n1#2:83\n195#3,18:84\n195#4,18:102\n*S KotlinDebug\n*F\n+ 1 Gather.kt\nio/kinference/ndarray/extensions/GatherKt\n*L\n49#1:84,18\n66#1:102,18\n*E\n"})
public final class GatherKt {
    @NotNull
    public static final int[] computeGatherShape(@NotNull int[] shape, int axis2, @NotNull NDArray indices) {
        Intrinsics.checkNotNullParameter((Object)shape, (String)"shape");
        Intrinsics.checkNotNullParameter((Object)indices, (String)"indices");
        int[] newShape = new int[shape.length + indices.getRank() - 1];
        ArraysKt.copyInto((int[])shape, (int[])newShape, (int)0, (int)0, (int)axis2);
        ArraysKt.copyInto$default((int[])indices.getShape(), (int[])newShape, (int)axis2, (int)0, (int)0, (int)12, null);
        ArraysKt.copyInto$default((int[])shape, (int[])newShape, (int)(axis2 + indices.getRank()), (int)(axis2 + 1), (int)0, (int)8, null);
        return newShape;
    }

    @NotNull
    public static final MutableNDArrayCore createGatherDstArray(int axis2, @NotNull NDArray indices, @NotNull int[] shape, @NotNull DataType type) {
        Intrinsics.checkNotNullParameter((Object)indices, (String)"indices");
        Intrinsics.checkNotNullParameter((Object)shape, (String)"shape");
        Intrinsics.checkNotNullParameter((Object)((Object)type), (String)"type");
        int[] newShape = GatherKt.computeGatherShape(shape, axis2, indices);
        return ArrayFactoriesKt.allocateNDArray(type, newShape);
    }

    @NotNull
    public static final NDArrayCore gather(@NotNull NDArrayCore array, @NotNull NDArrayCore indices, int axis2) {
        Intrinsics.checkNotNullParameter((Object)array, (String)"array");
        Intrinsics.checkNotNullParameter((Object)indices, (String)"indices");
        int actualAxis = NDArrayUtilsKt.indexAxis(array, axis2);
        MutableNDArrayCore dst = GatherKt.createGatherDstArray(actualAxis, indices, array.getShape(), array.getType());
        return GatherKt.gather(array, indices, axis2, dst);
    }

    public static /* synthetic */ NDArrayCore gather$default(NDArrayCore nDArrayCore, NDArrayCore nDArrayCore2, int n, int n2, Object object) {
        if ((n2 & 4) != 0) {
            n = 0;
        }
        return GatherKt.gather(nDArrayCore, nDArrayCore2, n);
    }

    @NotNull
    public static final NDArrayCore gather(@NotNull NDArrayCore array, @NotNull NDArrayCore indices, int axis2, @NotNull MutableNDArrayCore dst) {
        Intrinsics.checkNotNullParameter((Object)array, (String)"array");
        Intrinsics.checkNotNullParameter((Object)indices, (String)"indices");
        Intrinsics.checkNotNullParameter((Object)dst, (String)"dst");
        int[] gatherOutputShape = GatherKt.computeGatherShape(array.getShape(), axis2, indices);
        if (!Arrays.equals(dst.getShape(), gatherOutputShape)) {
            boolean $i$a$-require-GatherKt$gather$22 = false;
            String $i$a$-require-GatherKt$gather$22 = "Incorrect destination shape";
            throw new IllegalArgumentException($i$a$-require-GatherKt$gather$22.toString());
        }
        int actualAxis = NDArrayUtilsKt.indexAxis(array, axis2);
        int block = NDArrayUtilsKt.computeBlockSize$default(array, actualAxis + 1, 0, 2, null);
        int dataBatch = NDArrayUtilsKt.computeBlockSize$default(array, actualAxis, 0, 2, null);
        int indicesSize = indices.getStrides().getLinearSize();
        int gatheredBatch = indicesSize * block;
        int numBlocks = NDArrayUtilsKt.computeBlockSize$default(array, 0, actualAxis, 1, null);
        switch (WhenMappings.$EnumSwitchMapping$0[indices.getType().ordinal()]) {
            case 1: {
                LongNDArray cfr_ignored_0 = (LongNDArray)indices;
                LongPointer pointer = LongTiledArray.pointer$default(((LongNDArray)indices).getArray(), 0, 1, null);
                for (int numBatch = 0; numBatch < numBlocks; ++numBatch) {
                    int offset$iv;
                    long[] block$iv;
                    int index = 0;
                    LongPointer $this$forEach$iv = pointer;
                    boolean $i$f$forEach = false;
                    for (int end$iv = indicesSize; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                        block$iv = $this$forEach$iv.getCurrentBlock();
                        if (block$iv.length <= (offset$iv = $this$forEach$iv.getIndexInBlock()) + end$iv) {
                            $this$forEach$iv.blockIncrement();
                        } else {
                            $this$forEach$iv.setIndexInBlock($this$forEach$iv.getIndexInBlock() + end$iv);
                        }
                        int n = Math.min(block$iv.length, offset$iv + end$iv);
                        for (int index$iv = offset$iv; index$iv < n; ++index$iv) {
                            long it = block$iv[index$iv];
                            boolean bl = false;
                            int idx = (int)(it < 0L ? it + (long)array.getShape()[actualAxis] : it);
                            int srcOffset = numBatch * dataBatch + idx * block;
                            int n2 = index;
                            index = n2 + 1;
                            int dstOffset = numBatch * gatheredBatch + n2 * block;
                            dst.copyFrom(dstOffset, array, srcOffset, srcOffset + block);
                        }
                    }
                    pointer.setLinearIndex(0);
                }
                break;
            }
            case 2: {
                IntNDArray cfr_ignored_1 = (IntNDArray)indices;
                IntPointer pointer = IntTiledArray.pointer$default(((IntNDArray)indices).getArray(), 0, 1, null);
                for (int numBatch = 0; numBatch < numBlocks; ++numBatch) {
                    int offset$iv;
                    int[] block$iv;
                    int index = 0;
                    IntPointer $this$forEach$iv = pointer;
                    boolean $i$f$forEach = false;
                    for (int end$iv = indicesSize; end$iv > 0; end$iv -= block$iv.length - offset$iv) {
                        block$iv = $this$forEach$iv.getCurrentBlock();
                        if (block$iv.length <= (offset$iv = $this$forEach$iv.getIndexInBlock()) + end$iv) {
                            $this$forEach$iv.blockIncrement();
                        } else {
                            $this$forEach$iv.setIndexInBlock($this$forEach$iv.getIndexInBlock() + end$iv);
                        }
                        int n = Math.min(block$iv.length, offset$iv + end$iv);
                        for (int index$iv = offset$iv; index$iv < n; ++index$iv) {
                            int it = block$iv[index$iv];
                            boolean bl = false;
                            int idx = it < 0 ? it + array.getShape()[actualAxis] : it;
                            int srcOffset = numBatch * dataBatch + idx * block;
                            int n3 = index;
                            index = n3 + 1;
                            int dstOffset = numBatch * gatheredBatch + n3 * block;
                            dst.copyFrom(dstOffset, array, srcOffset, srcOffset + block);
                        }
                    }
                    pointer.setLinearIndex(0);
                }
                break;
            }
            default: {
                throw new IllegalStateException("Indices array must have Long or Int type");
            }
        }
        return dst;
    }

    public static /* synthetic */ NDArrayCore gather$default(NDArrayCore nDArrayCore, NDArrayCore nDArrayCore2, int n, MutableNDArrayCore mutableNDArrayCore, int n2, Object object) {
        if ((n2 & 4) != 0) {
            n = 0;
        }
        return GatherKt.gather(nDArrayCore, nDArrayCore2, n, mutableNDArrayCore);
    }

    @Metadata(mv={1, 9, 0}, k=3, xi=48)
    public final class WhenMappings {
        public static final /* synthetic */ int[] $EnumSwitchMapping$0;

        static {
            int[] nArray = new int[DataType.values().length];
            try {
                nArray[DataType.LONG.ordinal()] = 1;
            }
            catch (NoSuchFieldError noSuchFieldError) {
                // empty catch block
            }
            try {
                nArray[DataType.INT.ordinal()] = 2;
            }
            catch (NoSuchFieldError noSuchFieldError) {
                // empty catch block
            }
            $EnumSwitchMapping$0 = nArray;
        }
    }
}

