/*
 * Decompiled with CFR 0.152.
 */
package org.apache.gluten.memory.memtarget.spark;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.gluten.memory.MemoryUsageStatsBuilder;
import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
import org.apache.gluten.memory.memtarget.KnownNameAndStats;
import org.apache.gluten.memory.memtarget.MemoryTargetUtil;
import org.apache.gluten.memory.memtarget.MemoryTargetVisitor;
import org.apache.gluten.memory.memtarget.Spiller;
import org.apache.gluten.memory.memtarget.Spillers;
import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
import org.apache.gluten.memory.memtarget.TreeMemoryTargets;
import org.apache.gluten.proto.MemoryUsageStats;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;

/**
 * Root of a Gluten memory-target tree, backed by a Spark {@link MemoryConsumer}.
 *
 * <p>Memory borrowed through this root, or through any {@link Node} created beneath it, is
 * obtained with {@link #acquireMemory(long)} and released with {@link #freeMemory(long)} as
 * off-heap memory of the owning {@link TaskMemoryManager}. When Spark asks this consumer to spill,
 * the request is forwarded to the tree via {@link TreeMemoryTargets#spillTree}.
 *
 * <p>NOTE(review): the child registries are plain {@link HashMap}s without synchronization; a tree
 * is presumably confined to a single task thread — confirm.
 */
public class TreeMemoryConsumer
extends MemoryConsumer
implements TreeMemoryTarget {
    /** Records bytes granted/released through this consumer; its snapshot backs {@link #stats()}. */
    private final SimpleMemoryUsageRecorder recorder = new SimpleMemoryUsageRecorder();
    /** Direct children, keyed by their unique display name. */
    private final Map<String, TreeMemoryTarget> children = new HashMap<>();
    private final String name = MemoryTargetUtil.toUniqueName("Gluten.Tree");

    /**
     * Creates the root consumer for one task.
     *
     * @param taskMemoryManager the Spark task memory manager to draw memory from; its page size is
     *     also used as this consumer's page size
     */
    TreeMemoryConsumer(TaskMemoryManager taskMemoryManager) {
        super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.OFF_HEAP);
    }

    /**
     * Acquires up to {@code size} off-heap bytes from Spark on behalf of the whole tree.
     *
     * @param size number of bytes requested
     * @return number of bytes actually granted by {@link #acquireMemory(long)}
     */
    @Override
    public long borrow(long size) {
        if (size == 0L) {
            // Short-circuit: never forward a zero-byte request to Spark's memory manager.
            return 0L;
        }
        long acquired = acquireMemory(size);
        recorder.inc(acquired);
        return acquired;
    }

    /**
     * Returns up to {@code size} bytes to Spark.
     *
     * @param size number of bytes the caller wants to release
     * @return number of bytes actually freed, capped at what this consumer currently holds
     * @throws IllegalStateException if the used-byte count becomes negative after freeing
     */
    @Override
    public long repay(long size) {
        if (size == 0L) {
            return 0L;
        }
        // Never release more than this consumer currently holds.
        long toFree = Math.min(getUsed(), size);
        freeMemory(toFree);
        long usedAfter = getUsed();
        Preconditions.checkState(usedAfter >= 0L, "Negative used bytes after repay: %s", usedAfter);
        recorder.inc(-toFree);
        return toFree;
    }

    @Override
    public String name() {
        return name;
    }

    /** Bytes currently held by this consumer, as tracked by Spark ({@link #getUsed()}). */
    @Override
    public long usedBytes() {
        return getUsed();
    }

    @Override
    public <T> T accept(MemoryTargetVisitor<T> visitor) {
        return visitor.visit(this);
    }

    /**
     * Builds a usage snapshot for this consumer and, recursively, all of its children.
     *
     * @throws IllegalStateException if the recorder's current value disagrees with Spark's
     *     {@link #getUsed()}
     */
    @Override
    public MemoryUsageStats stats() {
        Map<String, MemoryUsageStats> childrenStats = children.values().stream()
                .collect(Collectors.toMap(TreeMemoryTarget::name, TreeMemoryTarget::stats));
        Preconditions.checkState(childrenStats.size() == children.size(),
                "Children stats size mismatch: %s vs %s", childrenStats.size(), children.size());
        MemoryUsageStats stats = recorder.toStats(childrenStats);
        Preconditions.checkState(stats.getCurrent() == getUsed(),
                "Used bytes mismatch between gluten memory consumer and Spark task memory manager");
        return stats;
    }

    /**
     * Spark spill callback: forwards the request to the memory-target tree rooted here.
     *
     * @param size number of bytes Spark asks to release
     * @param trigger the consumer that triggered the spill; not used
     * @return bytes released, as reported by {@link TreeMemoryTargets#spillTree}
     * @throws IOException declared by {@link MemoryConsumer#spill}; not thrown directly here
     */
    @Override
    public long spill(long size, MemoryConsumer trigger) throws IOException {
        return TreeMemoryTargets.spillTree(this, size);
    }

    /**
     * Creates and registers a direct child of the root.
     *
     * @throws IllegalArgumentException if a child with the same display name already exists
     */
    @Override
    public TreeMemoryTarget newChild(String name, long capacity, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
        return registerChild(children, new Node(this, name, capacity, spiller, virtualChildren));
    }

    /**
     * Adds {@code child} to {@code registry} under its display name.
     *
     * @return {@code child}
     * @throws IllegalArgumentException if the name is already present in {@code registry}
     */
    private static Node registerChild(Map<String, ? super Node> registry, Node child) {
        String childName = child.name();
        if (registry.containsKey(childName)) {
            throw new IllegalArgumentException("Child already registered: " + childName);
        }
        registry.put(childName, child);
        return child;
    }

    /** Read-only view of the direct children. */
    @Override
    public Map<String, TreeMemoryTarget> children() {
        return Collections.unmodifiableMap(children);
    }

    /**
     * The root has no parent.
     *
     * @throws IllegalStateException always
     */
    @Override
    public TreeMemoryTarget parent() {
        throw new IllegalStateException("Unreachable code org.apache.gluten.memory.memtarget.spark.TreeMemoryConsumer.parent");
    }

    /** The root spills nothing itself; spilling is delegated to its children's spillers. */
    @Override
    public Spiller getNodeSpiller() {
        return Spillers.NOOP;
    }

    public TaskMemoryManager getTaskMemoryManager() {
        return taskMemoryManager;
    }

    /**
     * A non-root node of the memory-target tree.
     *
     * <p>A node never talks to Spark directly: {@link #borrow} and {@link #repay} delegate to the
     * parent, so every request is eventually served by the root {@link TreeMemoryConsumer}. Each
     * node enforces its own capacity and spills its subtree when a request would exceed it.
     */
    public static class Node
    implements TreeMemoryTarget,
    KnownNameAndStats {
        /** Direct children, keyed by their unique display name. */
        private final Map<String, Node> children = new HashMap<>();
        private final TreeMemoryTarget parent;
        private final String name;
        /** Upper bound on {@link #usedBytes()}; {@code Long.MAX_VALUE} means unbounded. */
        private final long capacity;
        private final Spiller spiller;
        /** Extra stats sources reported alongside real children in {@link #stats()}. */
        private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
        /** Records bytes borrowed/repaid through this node; the source of {@link #usedBytes()}. */
        private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder();

        private Node(TreeMemoryTarget parent, String name, long capacity, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
            this.parent = parent;
            this.capacity = capacity;
            String uniqueName = MemoryTargetUtil.toUniqueName(name);
            // Unbounded nodes keep the plain unique name; bounded ones advertise their capacity.
            this.name = capacity == Long.MAX_VALUE
                    ? uniqueName
                    : String.format("%s, %s", uniqueName, Utils.bytesToString(capacity));
            this.spiller = spiller;
            this.virtualChildren = virtualChildren;
        }

        /**
         * Borrows up to {@code size} bytes through the parent chain.
         *
         * <p>First tries to make room under this node's capacity by spilling its subtree. If that
         * cannot free enough, only the amount that still fits under the capacity is requested.
         *
         * @return number of bytes actually granted by the parent
         */
        @Override
        public long borrow(long size) {
            if (size == 0L) {
                return 0L;
            }
            // Best effort: the outcome is reflected in freeBytes(), which caps the request below.
            ensureFreeCapacity(size);
            return borrow0(Math.min(freeBytes(), size));
        }

        /** Remaining room under this node's capacity. */
        private long freeBytes() {
            return capacity - usedBytes();
        }

        /** Requests {@code size} bytes from the parent and records what was granted. */
        private long borrow0(long size) {
            long granted = parent.borrow(size);
            selfRecorder.inc(granted);
            return granted;
        }

        @Override
        public Spiller getNodeSpiller() {
            return spiller;
        }

        /**
         * Spills this node's subtree until at least {@code bytesNeeded} bytes fit under the capacity.
         *
         * @return {@code true} if enough room is available, {@code false} if a spill round released
         *     nothing before enough room was made
         * @throws IllegalStateException if free bytes or spilled bytes are ever negative
         */
        private boolean ensureFreeCapacity(long bytesNeeded) {
            while (true) {
                long freeBytes = freeBytes();
                Preconditions.checkState(freeBytes >= 0L, "Negative free bytes: %s", freeBytes);
                if (freeBytes >= bytesNeeded) {
                    return true;
                }
                long spilledBytes = TreeMemoryTargets.spillTree(this, bytesNeeded - freeBytes);
                Preconditions.checkState(spilledBytes >= 0L, "Negative spilled bytes: %s", spilledBytes);
                if (spilledBytes == 0L) {
                    // No progress was made; stop rather than loop forever.
                    return false;
                }
            }
        }

        /**
         * Returns up to {@code size} bytes through the parent chain.
         *
         * @return number of bytes the parent reports as freed; never more than this node holds
         */
        @Override
        public long repay(long size) {
            if (size == 0L) {
                return 0L;
            }
            long toFree = Math.min(usedBytes(), size);
            long freed = parent.repay(toFree);
            selfRecorder.inc(-freed);
            return freed;
        }

        @Override
        public long usedBytes() {
            return selfRecorder.current();
        }

        @Override
        public <T> T accept(MemoryTargetVisitor<T> visitor) {
            return visitor.visit(this);
        }

        @Override
        public String name() {
            return name;
        }

        /**
         * Builds a usage snapshot of this node, its real children and its virtual children.
         *
         * @throws IllegalArgumentException if a virtual child's name collides with an existing entry
         */
        @Override
        public MemoryUsageStats stats() {
            Map<String, MemoryUsageStats> childrenStats = new HashMap<>(children.values().stream()
                    .collect(Collectors.toMap(Node::name, Node::stats)));
            Preconditions.checkState(childrenStats.size() == children.size(),
                    "Children stats size mismatch: %s vs %s", childrenStats.size(), children.size());
            for (Map.Entry<String, MemoryUsageStatsBuilder> entry : virtualChildren.entrySet()) {
                String key = entry.getKey();
                if (childrenStats.containsKey(key)) {
                    throw new IllegalArgumentException("Child stats already exists: " + key);
                }
                childrenStats.put(key, entry.getValue().toStats());
            }
            return selfRecorder.toStats(childrenStats);
        }

        /**
         * Creates and registers a child of this node.
         *
         * <p>The child's capacity is capped at this node's own capacity.
         *
         * @throws IllegalArgumentException if a child with the same display name already exists
         */
        @Override
        public TreeMemoryTarget newChild(String name, long capacity, Spiller spiller, Map<String, MemoryUsageStatsBuilder> virtualChildren) {
            long cappedCapacity = Math.min(this.capacity, capacity);
            return registerChild(children, new Node(this, name, cappedCapacity, spiller, virtualChildren));
        }

        /** Read-only view of the direct children. */
        @Override
        public Map<String, TreeMemoryTarget> children() {
            return Collections.unmodifiableMap(children);
        }

        @Override
        public TreeMemoryTarget parent() {
            return parent;
        }
    }
}

