/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.modality.nlp.generate.BatchTensorList;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Holds a batch of in-flight sequences and supports merging a new batch in, marking rows as
 * finished, and collecting/removing the finished rows.
 *
 * <p>Each row of the batch carries a uid ({@link #batchUid}) and an offset ({@link #offSets}).
 * The offset marks where that row's real content starts along the sequence dimension; columns
 * before it are padding introduced when batches of different lengths were merged.
 */
public class SeqBatcher {

    /**
     * Value used to fill the padded region of list element 0 when merging batches of different
     * lengths. Other list elements are zero-padded. NOTE(review): presumably element 0 holds token
     * ids and 220 is a model-specific pad/space token; confirm and consider making it
     * configurable.
     */
    private static final long PAD_TOKEN_ID = 220L;

    NDManager manager;
    long batchSize;
    long seqLength;
    NDArray batchUid;
    NDArray offSets;
    BatchTensorList data;
    /** Batch row index to sequence end position (exclusive) for rows that have finished. */
    private Map<Long, Long> exitIndexEndPosition;

    /**
     * Creates a batcher over {@code data}.
     *
     * @param data the batched tensors; batch size and sequence length are read from dims 0 and 1
     *     of {@code data.getPastOutputIds()}
     * @param batchUid per-row uids; reshaped to a column vector unless already 2-D
     * @param offSets per-row start offsets; reshaped to a column vector unless already 2-D
     * @param manager parent manager; this batcher allocates from a sub-manager of it
     */
    SeqBatcher(BatchTensorList data, NDArray batchUid, NDArray offSets, NDManager manager) {
        this.manager = manager.newSubManager();
        this.data = data;
        this.batchUid = batchUid.getShape().dimension() == 2 ? batchUid : batchUid.reshape(-1L, 1L);
        // Must test the rank (dimension()), not the shape's hashCode.
        this.offSets = offSets.getShape().dimension() == 2 ? offSets : offSets.reshape(-1L, 1L);
        this.batchSize = data.getPastOutputIds().getShape().get(0);
        this.seqLength = data.getPastOutputIds().getShape().get(1);
        this.exitIndexEndPosition = new ConcurrentHashMap<>();
    }

    /**
     * Returns the batched tensor data currently held.
     *
     * @return the batched data; {@code null} after {@link #collectAndTrim()} has removed every row
     */
    public BatchTensorList getData() {
        return data;
    }

    /**
     * Appends the rows of {@code seqBatcherNew} to this batch, padding the shorter of the two
     * along the sequence dimension so both line up at the end.
     *
     * @param seqBatcherNew the batch to merge in
     */
    public void addBatch(SeqBatcher seqBatcherNew) {
        merge(this, seqBatcherNew, seqLength - seqBatcherNew.seqLength);
    }

    /**
     * Merges two batchers into this one. Rows of the batcher with the longer sequence come first.
     * The shorter batcher's tensors are written into the right-aligned slice
     * {@code [seqDelta, longerLength)} of newly padded rows, and its offsets are shifted by
     * {@code seqDelta}.
     */
    private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) {
        if (seqDelta < 0) {
            // Normalize so seqBatcher1 always holds the longer sequences.
            SeqBatcher swapTmp = seqBatcher1;
            seqBatcher1 = seqBatcher2;
            seqBatcher2 = swapTmp;
            seqDelta = -seqDelta;
        }
        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();
            NDList list1 = seqBatcher1.data.getList();
            NDList list2 = seqBatcher2.data.getList();
            NDList merged = new NDList(list1.size());
            long[] seqDimOrder = seqBatcher1.data.getSeqDimOrder();
            for (int i = 0; i < list1.size(); i++) {
                NDArray batch1 = list1.get(i);
                NDArray batch2 = list2.get(i);
                if (seqDelta == 0) {
                    // Same length: a plain concatenation along the batch dimension suffices.
                    merged.add(batch1.concat(batch2, 0));
                    continue;
                }
                long[] shape1 = batch1.getShape().getShape();
                long[] shape2 = batch2.getShape().getShape();
                // Padding rows: batch1's shape, but with batch2's number of rows.
                long[] shapeDelta = batch1.getShape().getShape();
                shapeDelta[0] = shape2[0];
                NDArray deltaArray =
                        i == 0
                                ? manager.full(new Shape(shapeDelta), PAD_TOKEN_ID, batch1.getDataType())
                                : manager.zeros(new Shape(shapeDelta), batch1.getDataType());
                batch1 = batch1.concat(deltaArray, 0);

                NDIndex ndIndex;
                if (seqDimOrder[i] > 0) {
                    int seqDim = (int) seqDimOrder[i];
                    assert seqDelta + shape2[seqDim] == shape1[seqDim]
                            : "Wrong shapes. batch1 and batch2 are not mergable";
                    ndIndex =
                            sliceSeqDim(
                                    new NDIndex("{}:", seqBatcher1.batchSize),
                                    seqDimOrder[i],
                                    seqDelta,
                                    shape1[seqDim]);
                } else {
                    // No sequence dimension: overwrite the appended rows entirely.
                    ndIndex = new NDIndex("{}:, ...", seqBatcher1.batchSize);
                }
                batch1.set(ndIndex, batch2);
                merged.add(batch1);
            }
            data = data.fromList(merged, data.getSeqDimOrder());
            batchSize = seqBatcher1.batchSize + seqBatcher2.batchSize;
            batchUid = seqBatcher1.batchUid.concat(seqBatcher2.batchUid, 0);
            // Use non-mutating add so the argument batcher's offsets are left untouched.
            offSets = seqBatcher1.offSets.concat(seqBatcher2.offSets.add(seqDelta), 0);
            seqLength = seqBatcher1.seqLength;
            NDScope.unregister(batchUid, offSets);
            NDScope.unregister(merged);
        }
    }

    /**
     * Marks rows as finished. A row finishes when its generated length
     * ({@code seqLength - offset}) reaches {@code maxLength}, or when its latest output id equals
     * {@code eosTokenId}. The end position recorded is the current {@code seqLength}. A row
     * that is already marked keeps its original end position.
     *
     * @param outputIds latest output id per row
     * @param maxLength maximum generated length per row
     * @param eosTokenId end-of-sequence token id
     */
    public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) {
        long[] outputIdsArray = outputIds.toLongArray();
        long[] offSetsArray = offSets.toLongArray();
        for (int i = 0; i < outputIdsArray.length; i++) {
            boolean finished =
                    seqLength - offSetsArray[i] >= maxLength || outputIdsArray[i] == eosTokenId;
            if (finished) {
                // Key must be a Long; an int key would box to Integer and never match.
                exitIndexEndPosition.putIfAbsent((long) i, seqLength);
            }
        }
    }

    /**
     * Extracts the finished rows and removes them from the batch. It then drops leading columns
     * along the sequence dimension that are padding for every remaining row (the minimum
     * remaining offset).
     *
     * @return map from row uid to that row's {@code pastOutputIds} slice {@code [offset, end)};
     *     empty if no row was marked finished
     */
    public Map<Long, NDArray> collectAndTrim() {
        if (exitIndexEndPosition.isEmpty()) {
            return new ConcurrentHashMap<>();
        }
        Map<Long, NDArray> finishedSequences = new ConcurrentHashMap<>();
        try (NDScope scope = new NDScope()) {
            scope.suppressNotUsedWarning();
            HashSet<Long> exitIndices = new HashSet<>();
            for (Map.Entry<Long, Long> entry : exitIndexEndPosition.entrySet()) {
                long batchIndex = entry.getKey();
                long seqEndPosition = entry.getValue();
                long uid = batchUid.getLong(batchIndex);
                long offSet = offSets.getLong(batchIndex);
                NDArray output =
                        data.getPastOutputIds().get("{}, {}:{}", batchIndex, offSet, seqEndPosition);
                finishedSequences.put(uid, output);
                exitIndices.add(batchIndex);
                // The returned slice must outlive this scope.
                NDScope.unregister(output);
            }

            long[] keepIndices = new long[Math.toIntExact(batchSize) - exitIndices.size()];
            int j = 0;
            for (long i = 0; i < batchSize; i++) {
                if (!exitIndices.contains(i)) {
                    keepIndices[j++] = i;
                }
            }

            if (keepIndices.length == 0) {
                // Every row finished: reset to an empty batch.
                batchUid = manager.create(new Shape(0L, 1L), batchUid.getDataType());
                offSets = manager.create(new Shape(0L, 1L), offSets.getDataType());
                data = null;
                batchSize = 0L;
                seqLength = 0L;
                exitIndexEndPosition = new ConcurrentHashMap<>();
                NDScope.unregister(batchUid, offSets);
                return finishedSequences;
            }

            // Created once and reused for every row selection below.
            NDArray keep = manager.create(keepIndices);
            NDIndex rowIndex = new NDIndex("{}", keep);
            batchUid = batchUid.get(rowIndex).reshape(-1L, 1L);
            offSets = offSets.get(rowIndex).reshape(-1L, 1L);
            long trimSeq = offSets.min(new int[] {0}).toLongArray()[0];
            offSets = offSets.subi(trimSeq);

            NDList list = data.getList();
            NDList newList = new NDList(list.size());
            long[] seqDimOrder = data.getSeqDimOrder();
            for (int i = 0; i < list.size(); i++) {
                NDArray batch = list.get(i);
                NDIndex index;
                if (trimSeq != 0 && seqDimOrder[i] > 0) {
                    index = sliceSeqDim(new NDIndex("{}", keep), seqDimOrder[i], trimSeq, seqLength);
                } else {
                    index = new NDIndex("{}, ...", keep);
                }
                newList.add(batch.get(index));
            }
            data = data.fromList(newList, data.getSeqDimOrder());
            batchSize = keepIndices.length;
            seqLength -= trimSeq;
            exitIndexEndPosition = new ConcurrentHashMap<>();
            NDScope.unregister(newList);
            NDScope.unregister(batchUid, offSets);
            return finishedSequences;
        }
    }

    /**
     * Returns whether at least one row has been marked finished by
     * {@link #exitCriteria(NDArray, long, long)} and not yet collected.
     *
     * @return {@code true} if there are finished rows pending collection
     */
    public boolean sequenceComplete() {
        return !exitIndexEndPosition.isEmpty();
    }

    /**
     * Extends {@code index}, which already selects dimension 0, in three steps: a full selection
     * on each dimension in {@code [1, seqDim)}, then a {@code [start, end)} slice on dimension
     * {@code seqDim}, then a trailing ellipsis.
     */
    private static NDIndex sliceSeqDim(NDIndex index, long seqDim, long start, long end) {
        for (long d = 1; d < seqDim; d++) {
            index = index.addAllDim();
        }
        return index.addSliceDim(start, end).addEllipseDim();
    }
}

