package org.apache.flink.runtime.state.heap;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
import org.apache.flink.runtime.state.StateSnapshot;
import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
import org.apache.flink.util.Preconditions;

/**
 * {@link StateTable} that stores its state in nested {@link HashMap}s, partitioned by key group.
 *
 * <p>For every key group of the backend's key-group range there is one slot in {@link #state}
 * holding a map {@code namespace -> (key -> state)}. Slots are created lazily on the first write
 * and stay {@code null} until then. Methods address key groups by their absolute index;
 * {@link #keyGroupOffset} converts that index into a slot position.
 *
 * @param <K> type of the keys
 * @param <N> type of the namespaces
 * @param <S> type of the state values
 */
@Internal
public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {

    /** One {@code namespace -> (key -> state)} map per key group; {@code null} until first written. */
    private final Map<N, Map<K, S>>[] state;

    /** Absolute index of the first key group covered by this table. */
    private final int keyGroupOffset;

    /**
     * Snapshot of a {@link NestedMapsStateTable} that writes one key group at a time as a mapping
     * count followed by that many {@code (namespace, key, state)} records.
     *
     * <p>If a {@link StateSnapshotTransformer} is supplied, every state value is passed through it
     * before writing, and values it maps to {@code null} are dropped.
     *
     * <p>NOTE(review): the key-group maps are read when {@link #writeStateInKeyGroup} runs, not
     * copied when the snapshot is created, so writing presumably has to happen before the table is
     * modified again — confirm against the snapshot strategy that uses this class.
     */
    public static class NestedMapsStateTableSnapshot<K, N, S>
            extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
            implements StateSnapshot.StateKeyGroupWriter {

        private final TypeSerializer<K> keySerializer;
        private final TypeSerializer<N> namespaceSerializer;
        private final TypeSerializer<S> stateSerializer;

        /** Optional per-value filter/transformer; {@code null} means values are written unchanged. */
        private final StateSnapshotTransformer<S> snapshotFilter;

        NestedMapsStateTableSnapshot(
                NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformer<S> snapshotFilter) {
            super(owningTable);
            this.snapshotFilter = snapshotFilter;
            // Serializers are captured once, when the snapshot is created.
            this.keySerializer = owningStateTable.keyContext.getKeySerializer();
            this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
            this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
        }

        @Override
        @Nonnull
        public StateSnapshot.StateKeyGroupWriter getKeyGroupWriter() {
            return this;
        }

        @Override
        @Nonnull
        public StateMetaInfoSnapshot getMetaInfoSnapshot() {
            return owningStateTable.metaInfo.snapshot();
        }

        /**
         * Writes the mapping count of the given key group followed by its (possibly filtered)
         * mappings. A key group without any state is written as a count of {@code 0}.
         *
         * @param dov output to write to
         * @param keyGroupId absolute key-group index
         * @throws IOException if a serializer or the output fails
         */
        @Override
        public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
            final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
            if (keyGroupMap == null) {
                dov.writeInt(0);
                return;
            }
            final Map<N, Map<K, S>> mappings = filterMappingsInKeyGroupIfNeeded(keyGroupMap);
            // Count first, then exactly that many records, computed from the same (filtered) view.
            dov.writeInt(NestedMapsStateTable.countMappingsInKeyGroup(mappings));
            for (Map.Entry<N, Map<K, S>> namespaceEntry : mappings.entrySet()) {
                final N namespace = namespaceEntry.getKey();
                for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                    writeElement(namespace, keyEntry, dov);
                }
            }
        }

        /** Serializes one record as namespace, key, state — in that order. */
        private void writeElement(N namespace, Map.Entry<K, S> keyEntry, DataOutputView dov) throws IOException {
            namespaceSerializer.serialize(namespace, dov);
            keySerializer.serialize(keyEntry.getKey(), dov);
            stateSerializer.serialize(keyEntry.getValue(), dov);
        }

        /** Returns {@code keyGroupMap} itself when no filter is set, otherwise a filtered copy. */
        private Map<N, Map<K, S>> filterMappingsInKeyGroupIfNeeded(Map<N, Map<K, S>> keyGroupMap) {
            return snapshotFilter == null ? keyGroupMap : filterMappingsInKeyGroup(keyGroupMap);
        }

        /**
         * Builds a new map holding, per namespace, the transformed values of {@code keyGroupMap};
         * entries the filter maps to {@code null} are omitted. The input map is not modified.
         */
        private Map<N, Map<K, S>> filterMappingsInKeyGroup(Map<N, Map<K, S>> keyGroupMap) {
            final Map<N, Map<K, S>> filtered = new HashMap<>();
            for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
                final Map<K, S> filteredKeyedMap =
                        filtered.computeIfAbsent(namespaceEntry.getKey(), ignored -> new HashMap<>());
                for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                    final S transformed = snapshotFilter.filterOrTransform(keyEntry.getValue());
                    if (transformed != null) {
                        filteredKeyedMap.put(keyEntry.getKey(), transformed);
                    }
                }
            }
            return filtered;
        }
    }

    /**
     * Creates an empty table covering the key-group range of the given key context.
     *
     * @param keyContext supplies the current key, current key group and the key-group range
     * @param metaInfo meta information (serializers, snapshot transformer) of the registered state
     */
    @SuppressWarnings("unchecked")
    public NestedMapsStateTable(
            InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
        super(keyContext, metaInfo);
        this.keyGroupOffset = keyContext.getKeyGroupRange().getStartKeyGroup();
        // Generic arrays cannot be created directly; the array only ever holds Map<N, Map<K, S>>.
        this.state = (Map<N, Map<K, S>>[]) new Map[keyContext.getKeyGroupRange().getNumberOfKeyGroups()];
    }

    /** Returns the backing per-key-group array itself (not a copy). */
    @VisibleForTesting
    public Map<N, Map<K, S>>[] getState() {
        return state;
    }

    /**
     * Returns the {@code namespace -> (key -> state)} map of a key group.
     *
     * @param keyGroupIndex absolute key-group index
     * @return the map, or {@code null} if the key group has no state yet or lies outside this
     *     table's key-group range
     */
    @VisibleForTesting
    Map<N, Map<K, S>> getMapForKeyGroup(int keyGroupIndex) {
        final int pos = indexToOffset(keyGroupIndex);
        return isOffsetInRange(pos) ? state[pos] : null;
    }

    /**
     * Stores the map of a key group.
     *
     * @throws IllegalArgumentException if the key group lies outside this table's key-group range
     */
    private void setMapForKeyGroup(int keyGroupIndex, Map<N, Map<K, S>> map) {
        final int pos = indexToOffset(keyGroupIndex);
        if (!isOffsetInRange(pos)) {
            throw new IllegalArgumentException("Key group index " + keyGroupIndex
                    + " is out of range of key group range [" + keyGroupOffset + ", "
                    + (keyGroupOffset + state.length) + ").");
        }
        state[pos] = map;
    }

    /** Converts an absolute key-group index into a position in {@link #state}. */
    private int indexToOffset(int keyGroupIndex) {
        return keyGroupIndex - keyGroupOffset;
    }

    /** Whether {@code pos} is a valid position in {@link #state}. */
    private boolean isOffsetInRange(int pos) {
        return pos >= 0 && pos < state.length;
    }

    /** Returns the key group's map, creating and registering an empty one if it does not exist. */
    private Map<N, Map<K, S>> getOrCreateMapForKeyGroup(int keyGroupIndex) {
        Map<N, Map<K, S>> keyGroupMap = getMapForKeyGroup(keyGroupIndex);
        if (keyGroupMap == null) {
            keyGroupMap = new HashMap<>();
            setMapForKeyGroup(keyGroupIndex, keyGroupMap);
        }
        return keyGroupMap;
    }

    /** Returns the {@code key -> state} map of a namespace in a key group, or {@code null} if absent. */
    private Map<K, S> findKeyedMap(int keyGroupIndex, N namespace) {
        final Map<N, Map<K, S>> keyGroupMap = getMapForKeyGroup(keyGroupIndex);
        return keyGroupMap == null ? null : keyGroupMap.get(namespace);
    }

    /** Returns the total number of mappings over all key groups and namespaces. */
    @Override
    public int size() {
        int count = 0;
        for (Map<N, Map<K, S>> keyGroupMap : state) {
            if (keyGroupMap == null) {
                continue;
            }
            for (Map<K, S> keyedMap : keyGroupMap.values()) {
                if (keyedMap != null) {
                    count += keyedMap.size();
                }
            }
        }
        return count;
    }

    @Override
    public S get(N namespace) {
        return get(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public boolean containsKey(N namespace) {
        return containsKey(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public void put(N namespace, S value) {
        put(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, value);
    }

    @Override
    public S putAndGetOld(N namespace, S value) {
        return putAndGetOld(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, value);
    }

    @Override
    public void remove(N namespace) {
        remove(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    @Override
    public S removeAndGetOld(N namespace) {
        return removeAndGetOld(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
    }

    /** Looks up the state of an explicit key; its key group is computed from the key itself. */
    @Override
    public S get(K key, N namespace) {
        final int keyGroupIndex =
                KeyGroupRangeAssignment.assignToKeyGroup(key, keyContext.getNumberOfKeyGroups());
        return get(key, keyGroupIndex, namespace);
    }

    /** Returns all keys that have state under {@code namespace}, across all key groups. */
    @Override
    public Stream<K> getKeys(N namespace) {
        return Arrays.stream(state)
                .filter(Objects::nonNull)
                .map(keyGroupMap -> keyGroupMap.getOrDefault(namespace, Collections.emptyMap()))
                .flatMap(keyedMap -> keyedMap.keySet().stream());
    }

    private boolean containsKey(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<K, S> keyedMap = findKeyedMap(keyGroupIndex, namespace);
        return keyedMap != null && keyedMap.containsKey(key);
    }

    /** Returns the state for the given key, key group and namespace, or {@code null} if absent. */
    S get(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<K, S> keyedMap = findKeyedMap(keyGroupIndex, namespace);
        return keyedMap == null ? null : keyedMap.get(key);
    }

    @Override
    public void put(K key, int keyGroupIndex, N namespace, S value) {
        putAndGetOld(key, keyGroupIndex, namespace, value);
    }

    /** Stores {@code value}, creating key-group and namespace maps as needed; returns the old value. */
    private S putAndGetOld(K key, int keyGroupIndex, N namespace, S value) {
        checkKeyNamespacePreconditions(key, namespace);
        return getOrCreateMapForKeyGroup(keyGroupIndex)
                .computeIfAbsent(namespace, ignored -> new HashMap<>())
                .put(key, value);
    }

    private void remove(K key, int keyGroupIndex, N namespace) {
        removeAndGetOld(key, keyGroupIndex, namespace);
    }

    /**
     * Removes the mapping and returns its old value ({@code null} if there was none). A namespace
     * map that becomes empty is dropped from its key group.
     */
    private S removeAndGetOld(K key, int keyGroupIndex, N namespace) {
        checkKeyNamespacePreconditions(key, namespace);
        final Map<N, Map<K, S>> keyGroupMap = getMapForKeyGroup(keyGroupIndex);
        if (keyGroupMap == null) {
            return null;
        }
        final Map<K, S> keyedMap = keyGroupMap.get(namespace);
        if (keyedMap == null) {
            return null;
        }
        final S removed = keyedMap.remove(key);
        if (keyedMap.isEmpty()) {
            keyGroupMap.remove(namespace);
        }
        return removed;
    }

    /** @throws NullPointerException if {@code key} or {@code namespace} is {@code null} */
    private void checkKeyNamespacePreconditions(K key, N namespace) {
        Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
        Preconditions.checkNotNull(namespace, "Provided namespace is null.");
    }

    /** Returns the number of mappings under {@code namespace}, summed over all key groups. */
    @Override
    public int sizeOfNamespace(Object namespace) {
        int count = 0;
        for (Map<N, Map<K, S>> keyGroupMap : state) {
            if (keyGroupMap == null) {
                continue;
            }
            final Map<K, S> keyedMap = keyGroupMap.get(namespace);
            if (keyedMap != null) {
                count += keyedMap.size();
            }
        }
        return count;
    }

    /**
     * Replaces the current key's state in {@code namespace} with
     * {@code transformation.apply(oldState, value)}, where {@code oldState} is {@code null} if absent.
     *
     * @throws Exception whatever {@code transformation} throws; the table is then left without a
     *     new (empty) namespace map for this call
     */
    @Override
    public <T> void transform(N namespace, T value, StateTransformationFunction<S, T> transformation) throws Exception {
        final K key = keyContext.getCurrentKey();
        checkKeyNamespacePreconditions(key, namespace);
        final int keyGroupIndex = keyContext.getCurrentKeyGroupIndex();
        // Resolve the key-group slot first so an out-of-range key group fails before the
        // transformation runs.
        final Map<N, Map<K, S>> keyGroupMap = getOrCreateMapForKeyGroup(keyGroupIndex);
        final Map<K, S> existing = keyGroupMap.get(namespace);
        final S oldState = existing == null ? null : existing.get(key);
        // Apply before creating the namespace map: a throwing transformation must not leave an
        // empty namespace map behind (removeAndGetOld prunes such maps as well).
        final S newState = transformation.apply(oldState, value);
        keyGroupMap.computeIfAbsent(namespace, ignored -> new HashMap<>()).put(key, newState);
    }

    /** Sums the sizes of all {@code key -> state} maps of one key group's map. */
    public static <K, N, S> int countMappingsInKeyGroup(Map<N, Map<K, S>> keyGroupMap) {
        int count = 0;
        for (Map<K, S> keyedMap : keyGroupMap.values()) {
            count += keyedMap.size();
        }
        return count;
    }

    /** Creates a snapshot that uses the snapshot transformer from this table's meta info (may be null). */
    @Override
    @Nonnull
    public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
        return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
    }
}
