/*
 * Decompiled with CFR 0.152.
 */
package io.kinference.ndarray.extensions;

import io.kinference.ndarray.arrays.MutableNDArrayCore;
import io.kinference.ndarray.arrays.NDArrayCore;
import io.kinference.ndarray.arrays.NDArrayUtilsKt;
import io.kinference.ndarray.arrays.Strides;
import io.kinference.ndarray.extensions.ArrayFactoriesKt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;

@Metadata(mv={1, 9, 0}, k=2, xi=48, d1={"\u00004\n\u0000\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u000b\n\u0000\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\b\u0004\u001a(\u0010\u0000\u001a\u00020\u00012\u0006\u0010\u0002\u001a\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0007H\u0002\u001a4\u0010\b\u001a\u00020\t*\u00020\n2\u0006\u0010\u000b\u001a\u00020\u00042\u0006\u0010\f\u001a\u00020\u00042\u0006\u0010\r\u001a\u00020\u00042\u0006\u0010\u000e\u001a\u00020\u000f2\u0006\u0010\u0010\u001a\u00020\u0004H\u0002\u001a\"\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\t0\u0012*\u00020\n2\u0006\u0010\u0013\u001a\u00020\u00042\u0006\u0010\u0014\u001a\u00020\u000fH\u0002\u001a,\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\t0\u0012*\u00020\n2\u0006\u0010\u0013\u001a\u00020\u00042\b\b\u0002\u0010\u0003\u001a\u00020\u00042\b\b\u0002\u0010\u0006\u001a\u00020\u0007\u001a*\u0010\u0015\u001a\b\u0012\u0004\u0012\u00020\t0\u0012*\u00020\n2\u0006\u0010\u0005\u001a\u00020\u00012\u0006\u0010\u0003\u001a\u00020\u00042\b\b\u0002\u0010\u0006\u001a\u00020\u0007\u00a8\u0006\u0016"}, d2={"computeSplitShape", "", "shape", "axis", "", "split", "keepDims", "", "splitFragment", "Lio/kinference/ndarray/arrays/MutableNDArrayCore;", "Lio/kinference/ndarray/arrays/NDArrayCore;", "beforeAxisDims", "fromAxisDims", "fragmentSize", "splitStrides", "Lio/kinference/ndarray/arrays/Strides;", "offset", "splitParts", "", "parts", "strides", "splitWithAxis", "ndarray-core"})
@SourceDebugExtension(value={"SMAP\nSplit.kt\nKotlin\n*S Kotlin\n*F\n+ 1 Split.kt\nio/kinference/ndarray/extensions/SplitKt\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n*L\n1#1,84:1\n1#2:85\n*E\n"})
public final class SplitKt {
    /**
     * Builds the shape of one split fragment from the shape of the array being split.
     *
     * <p>With {@code keepDims2} the result is a copy of {@code shape} whose entry at
     * {@code axis2} is replaced by {@code split}. Without it the entry at {@code axis2}
     * is removed altogether (the result is one element shorter) and {@code split} is not used.
     *
     * @param shape shape of the source array; never modified
     * @param axis2 index in {@code shape} of the axis being split
     * @param split size of the fragment along {@code axis2}
     * @param keepDims2 whether the split axis is kept in the resulting shape
     * @return a freshly allocated shape array
     */
    private static final int[] computeSplitShape(int[] shape, int axis2, int split, boolean keepDims2) {
        if (keepDims2) {
            int[] resized = shape.clone();
            resized[axis2] = split;
            return resized;
        }
        int[] reduced = new int[shape.length - 1];
        // Dimensions in front of the axis keep their position ...
        System.arraycopy(shape, 0, reduced, 0, axis2);
        // ... and those behind it move one slot to the left, overwriting the dropped axis.
        System.arraycopy(shape, axis2 + 1, reduced, axis2, shape.length - (axis2 + 1));
        return reduced;
    }

    @NotNull
    public static final List<MutableNDArrayCore> splitWithAxis(@NotNull NDArrayCore $this$splitWithAxis, int parts, int axis2, boolean keepDims2) {
        Intrinsics.checkNotNullParameter((Object)$this$splitWithAxis, (String)"<this>");
        int actualAxis = NDArrayUtilsKt.indexAxis($this$splitWithAxis, axis2);
        if (!(0 <= actualAxis ? actualAxis < $this$splitWithAxis.getShape().length : false)) {
            boolean $i$a$-require-SplitKt$splitWithAxis$22 = false;
            String $i$a$-require-SplitKt$splitWithAxis$22 = "Index " + actualAxis + " out of shape bound: (0, " + ($this$splitWithAxis.getRank() - 1);
            throw new IllegalArgumentException($i$a$-require-SplitKt$splitWithAxis$22.toString());
        }
        int elementsByIndex = $this$splitWithAxis.getShape()[actualAxis];
        int mainSplit = (int)Math.ceil((double)elementsByIndex / (double)parts);
        int n = 0;
        int[] nArray = new int[parts];
        while (n < parts) {
            int n2 = n++;
            nArray[n2] = mainSplit;
        }
        int[] split = nArray;
        int tail = elementsByIndex % parts;
        if (tail != 0) {
            split[parts - 1] = tail;
        }
        return SplitKt.splitWithAxis($this$splitWithAxis, split, actualAxis, keepDims2);
    }

    /**
     * Compiler-generated bridge that fills in default arguments for the {@code int parts}
     * overload of {@code splitWithAxis}.
     *
     * <p>{@code n3} is a bit mask: when bit value 2 is set the axis argument is replaced by
     * {@code 0}; when bit value 4 is set the keep-dims argument is replaced by {@code true}.
     * {@code object} is not used.
     */
    public static /* synthetic */ List splitWithAxis$default(NDArrayCore nDArrayCore, int n, int n2, boolean bl, int n3, Object object) {
        int axis = (n3 & 2) == 0 ? n2 : 0;
        boolean keepDims = (n3 & 4) != 0 || bl;
        return SplitKt.splitWithAxis(nDArrayCore, n, axis, keepDims);
    }

    /*
     * WARNING - void declaration
     */
    @NotNull
    public static final List<MutableNDArrayCore> splitWithAxis(@NotNull NDArrayCore $this$splitWithAxis, @NotNull int[] split, int axis2, boolean keepDims2) {
        Intrinsics.checkNotNullParameter((Object)$this$splitWithAxis, (String)"<this>");
        Intrinsics.checkNotNullParameter((Object)split, (String)"split");
        int actualAxis = NDArrayUtilsKt.indexAxis($this$splitWithAxis, axis2);
        if (!(0 <= actualAxis ? actualAxis < $this$splitWithAxis.getShape().length : false)) {
            boolean $i$a$-require-SplitKt$splitWithAxis$32 = false;
            String $i$a$-require-SplitKt$splitWithAxis$32 = "Index " + actualAxis + " out of shape bound: (0, " + ($this$splitWithAxis.getRank() - 1);
            throw new IllegalArgumentException($i$a$-require-SplitKt$splitWithAxis$32.toString());
        }
        int beforeAxisDims = NDArrayUtilsKt.computeBlockSize$default($this$splitWithAxis, 0, actualAxis, 1, null);
        int fromAxisDims = NDArrayUtilsKt.computeBlockSize$default($this$splitWithAxis, actualAxis, 0, 2, null);
        int afterAxisDims = actualAxis + 1 == $this$splitWithAxis.getRank() ? 1 : NDArrayUtilsKt.computeBlockSize$default($this$splitWithAxis, actualAxis + 1, 0, 2, null);
        int inputOffset = 0;
        int n = split.length;
        ArrayList<MutableNDArrayCore> arrayList = new ArrayList<MutableNDArrayCore>(n);
        int n2 = 0;
        while (n2 < n) {
            void i;
            int n3;
            int n4 = n3 = n2++;
            ArrayList<MutableNDArrayCore> arrayList2 = arrayList;
            boolean bl = false;
            int splitSize = split[i];
            int[] outputDims = SplitKt.computeSplitShape($this$splitWithAxis.getStrides().getShape(), actualAxis, split[i], keepDims2);
            Strides outStrides = new Strides(outputDims);
            int fragmentSize = splitSize * afterAxisDims;
            MutableNDArrayCore dst = SplitKt.splitFragment($this$splitWithAxis, beforeAxisDims, fromAxisDims, fragmentSize, outStrides, inputOffset);
            inputOffset += fragmentSize;
            arrayList2.add(dst);
        }
        return arrayList;
    }

    /**
     * Compiler-generated bridge that fills in default arguments for the {@code int[] split}
     * overload of {@code splitWithAxis}.
     *
     * <p>{@code n2} is a bit mask: when bit value 4 is set the keep-dims argument is replaced
     * by {@code true}. {@code object} is not used.
     */
    public static /* synthetic */ List splitWithAxis$default(NDArrayCore nDArrayCore, int[] nArray, int n, boolean bl, int n2, Object object) {
        boolean keepDims = (n2 & 4) != 0 || bl;
        return SplitKt.splitWithAxis(nDArrayCore, nArray, n, keepDims);
    }

    /**
     * Allocates an array with the receiver's type and the given strides and fills it with one
     * fragment of the receiver.
     *
     * <p>If the fragment is as large as the whole from-axis block
     * ({@code fromAxisDims == fragmentSize}), the first {@code beforeAxisDims * fragmentSize}
     * elements are copied in a single call. Otherwise {@code beforeAxisDims} separate runs of
     * {@code fragmentSize} elements are copied; run {@code k} is read starting at
     * {@code offset + fromAxisDims * k} and written starting at {@code k * fragmentSize}.
     *
     * @param $this$splitFragment source array (the Kotlin extension receiver)
     * @param beforeAxisDims number of runs to copy
     * @param fromAxisDims distance in the source between the starts of consecutive runs
     * @param fragmentSize number of elements in each run
     * @param splitStrides strides used to allocate the result
     * @param offset start of the first run in the source
     * @return the newly allocated, filled array
     */
    private static final MutableNDArrayCore splitFragment(NDArrayCore $this$splitFragment, int beforeAxisDims, int fromAxisDims, int fragmentSize, Strides splitStrides, int offset) {
        MutableNDArrayCore result = ArrayFactoriesKt.allocateNDArray($this$splitFragment.getType(), splitStrides);
        if (fromAxisDims == fragmentSize) {
            // The fragment covers the entire axis, so one bulk copy is enough.
            result.copyFrom(0, $this$splitFragment, 0, beforeAxisDims * fragmentSize);
        } else {
            for (int run = 0; run < beforeAxisDims; ++run) {
                int srcStart = offset + fromAxisDims * run;
                result.copyFrom(run * fragmentSize, $this$splitFragment, srcStart, srcStart + fragmentSize);
            }
        }
        return result;
    }

    /**
     * Cuts the receiver's linear data into {@code parts} equally sized consecutive pieces, each
     * allocated with the receiver's type and the given strides.
     *
     * @param $this$splitParts source array (the Kotlin extension receiver)
     * @param parts number of pieces; must divide the receiver's linear size evenly
     * @param strides2 strides of every piece; its linear size must equal the receiver's linear
     *        size divided by {@code parts}
     * @return the pieces, in source order
     * @throws IllegalArgumentException if either size requirement is violated
     */
    private static final List<MutableNDArrayCore> splitParts(NDArrayCore $this$splitParts, int parts, Strides strides2) {
        if ($this$splitParts.getLinearSize() % parts != 0) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        if (strides2.getLinearSize() != $this$splitParts.getLinearSize() / parts) {
            throw new IllegalArgumentException("Failed requirement.");
        }
        final int chunkLength = strides2.getLinearSize();
        List<MutableNDArrayCore> chunks = new ArrayList<MutableNDArrayCore>(parts);
        // 'from' walks the source in steps of one chunk while 'index' counts the chunks.
        for (int index = 0, from = 0; index < parts; ++index, from += chunkLength) {
            MutableNDArrayCore chunk = ArrayFactoriesKt.allocateNDArray($this$splitParts.getType(), strides2);
            chunk.copyFrom(0, $this$splitParts, from, from + chunkLength);
            chunks.add(chunk);
        }
        return chunks;
    }
}

