package org.apache.flink.runtime.state.heap;

import java.io.IOException;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/state/heap/HeapAggregatingState.class */
/**
 * {@link AggregatingState} implementation that keeps one accumulator per namespace in a
 * {@link StateTable} and updates it with the {@link AggregateFunction} supplied by the
 * {@link AggregatingStateDescriptor}.
 *
 * <p>Values passed to {@link #add(Object)} are folded into the accumulator; {@link #get()}
 * converts the stored accumulator into the output type via
 * {@code AggregateFunction#getResult}.
 *
 * @param <K> first type parameter of the backing {@link StateTable} (presumably the key type)
 * @param <N> type of the namespace used to address the state table
 * @param <IN> type of the values added to the state
 * @param <ACC> type of the accumulator stored in the state table
 * @param <OUT> type of the result returned by {@link #get()}
 */
public class HeapAggregatingState<K, N, IN, ACC, OUT> extends AbstractHeapMergingState<K, N, IN, OUT, ACC, AggregatingState<IN, OUT>, AggregatingStateDescriptor<IN, ACC, OUT>> implements InternalAggregatingState<N, IN, OUT> {

    /** Wraps the aggregate function so it can be handed to {@code StateTable#transform}. */
    private final AggregateTransformation<IN, ACC, OUT> aggregateTransformation;

    /**
     * Adapter that applies an {@link AggregateFunction} as a {@link StateTransformationFunction}:
     * the previous state is the accumulator, the incoming value is the element to add.
     *
     * @param <IN> type of the values added to the accumulator
     * @param <ACC> type of the accumulator
     * @param <OUT> result type of the wrapped aggregate function
     */
    static final class AggregateTransformation<IN, ACC, OUT> implements StateTransformationFunction<ACC, IN> {

        private final AggregateFunction<IN, ACC, OUT> aggFunction;

        /**
         * @param aggregateFunction the function to delegate to; must not be {@code null}
         */
        AggregateTransformation(AggregateFunction<IN, ACC, OUT> aggregateFunction) {
            // checkNotNull is generic, so no (raw) cast is needed to keep the type arguments.
            this.aggFunction = Preconditions.checkNotNull(aggregateFunction);
        }

        /**
         * Adds {@code in} to the accumulator, creating a fresh accumulator first when none exists.
         *
         * @param acc the current accumulator, or {@code null} if there is none yet
         * @param in the value to add
         * @return the accumulator that received the value (the same instance that was passed in,
         *     or the newly created one)
         * @throws Exception any exception raised by the wrapped aggregate function
         */
        @Override
        public ACC apply(ACC acc, IN in) throws Exception {
            if (acc == null) {
                acc = this.aggFunction.createAccumulator();
            }
            // The accumulator is updated through this call; the same reference is returned below.
            this.aggFunction.add(in, acc);
            return acc;
        }
    }

    /**
     * Creates the state and wraps the descriptor's aggregate function.
     *
     * @param aggregatingStateDescriptor descriptor providing the {@link AggregateFunction}
     * @param stateTable table holding the accumulators
     * @param typeSerializer serializer for type {@code K}
     * @param typeSerializer2 serializer for the namespace type {@code N}
     */
    public HeapAggregatingState(AggregatingStateDescriptor<IN, ACC, OUT> aggregatingStateDescriptor, StateTable<K, N, ACC> stateTable, TypeSerializer<K> typeSerializer, TypeSerializer<N> typeSerializer2) {
        super(aggregatingStateDescriptor, stateTable, typeSerializer, typeSerializer2);
        this.aggregateTransformation = new AggregateTransformation<>(aggregatingStateDescriptor.getAggregateFunction());
    }

    /**
     * Returns the aggregation result for the current namespace.
     *
     * @return the result computed from the stored accumulator, or {@code null} when the state
     *     table holds no accumulator for the current namespace
     */
    @Override
    public OUT get() {
        ACC accumulator = this.stateTable.get(this.currentNamespace);
        if (accumulator == null) {
            return null;
        }
        return this.aggregateTransformation.aggFunction.getResult(accumulator);
    }

    /**
     * Folds a value into the accumulator of the current namespace.
     *
     * <p>A {@code null} value does not reach the aggregate function; it triggers {@code clear()}
     * instead.
     *
     * @param in the value to add, or {@code null} to clear the state
     * @throws IOException if the state table transformation (and thus the aggregate function)
     *     throws; the original exception is attached as the cause
     */
    @Override
    public void add(IN in) throws IOException {
        if (in == null) {
            clear();
            return;
        }
        try {
            this.stateTable.transform(this.currentNamespace, in, this.aggregateTransformation);
        } catch (Exception e) {
            throw new IOException("Exception while applying AggregateFunction in aggregating state", e);
        }
    }

    /**
     * Merges two accumulators by delegating to {@code AggregateFunction#merge}.
     *
     * @param acc the first accumulator
     * @param acc2 the second accumulator
     * @return the merged accumulator as returned by the aggregate function
     * @throws Exception any exception raised by the aggregate function
     */
    @Override
    protected ACC mergeState(ACC acc, ACC acc2) throws Exception {
        return this.aggregateTransformation.aggFunction.merge(acc, acc2);
    }
}
