package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.util.Objects;

/**
 * {@code XavierInitializer} fills a parameter with random values whose spread is scaled by the
 * parameter's fan-in and/or fan-out (Xavier/Glorot-style initialization).
 *
 * <p>For a shape {@code (d0, d1, d2, ..., dn)} let {@code r = d2 * ... * dn} ({@code r = 1} for a
 * 2-D shape). Then {@code fanIn = d1 * r} and {@code fanOut = d0 * r}. A factor is picked from
 * those according to {@link FactorType}, and {@code scale = sqrt(magnitude / factor)}. Values are
 * drawn either uniformly from {@code [-scale, scale]} or from a normal distribution centred at 0
 * with scale {@code scale}, according to {@link RandomType}.
 *
 * <p>With the defaults ({@code UNIFORM}, {@code AVG}, magnitude 3) the uniform bound equals {@code
 * sqrt(6 / (fanIn + fanOut))}.
 */
public class XavierInitializer implements Initializer {

    private final RandomType randomType;
    private final FactorType factorType;
    private final float magnitude;

    /** Selects which fan value(s) the scale is derived from. */
    public enum FactorType {
        /** Average of fan-in and fan-out. */
        AVG,
        /** Fan-in only. */
        IN,
        /** Fan-out only. */
        OUT
    }

    /** Selects the distribution the initial values are drawn from. */
    public enum RandomType {
        /** Uniform distribution over {@code [-scale, scale]}. */
        UNIFORM,
        /** Normal distribution centred at 0 with scale {@code scale}. */
        GAUSSIAN
    }

    /**
     * Creates an initializer with the given distribution, factor type and magnitude.
     *
     * @param randomType the distribution to draw values from
     * @param factorType which fan value(s) the scale is derived from
     * @param magnitude the numerator in {@code scale = sqrt(magnitude / factor)}
     * @throws NullPointerException if {@code randomType} or {@code factorType} is null
     * @throws IllegalArgumentException if {@code magnitude} is negative, NaN or infinite
     */
    public XavierInitializer(RandomType randomType, FactorType factorType, float magnitude) {
        this.randomType = Objects.requireNonNull(randomType, "randomType");
        this.factorType = Objects.requireNonNull(factorType, "factorType");
        // A negative or non-finite magnitude would turn the scale into NaN or infinity.
        if (!Float.isFinite(magnitude) || magnitude < 0f) {
            throw new IllegalArgumentException(
                    "XavierInitializer magnitude must be finite and non-negative, but was: "
                            + magnitude);
        }
        this.magnitude = magnitude;
    }

    /**
     * Creates an initializer using {@link RandomType#UNIFORM}, {@link FactorType#AVG} and a
     * magnitude of 3.
     */
    public XavierInitializer() {
        this(RandomType.UNIFORM, FactorType.AVG, 3.0f);
    }

    /**
     * Creates a randomly initialized array of the given shape and data type on the manager's
     * device.
     *
     * @param manager the manager used to create the array
     * @param shape the shape of the array; must have at least 2 dimensions
     * @param dataType the data type of the array
     * @return the initialized array
     * @throws IllegalArgumentException if {@code shape} has fewer than 2 dimensions
     * @throws IllegalStateException if the computed factor is not strictly positive
     */
    @Override // ai.djl.training.initializer.Initializer
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        long dimension = shape.dimension();
        if (dimension < 2) {
            throw new IllegalArgumentException("XavierInitializer cannot be applied to Shape with dimension: " + dimension + ", it requires shape to be at least 2D.");
        }
        float factor = computeFactor(shape, dimension);
        // Rejects zero as before, and also negative or NaN factors (e.g. from non-positive
        // dimensions), which would otherwise silently yield a NaN scale.
        if (!(factor > 0f)) {
            throw new IllegalStateException(
                    "Xavier initializer factor must be positive but was "
                            + factor
                            + ", please check your input shape: "
                            + shape);
        }
        float scale = (float) StrictMath.sqrt(magnitude / factor);
        switch (randomType) {
            case UNIFORM:
                return manager.randomUniform(-scale, scale, shape, dataType, manager.getDevice());
            case GAUSSIAN:
                return manager.randomNormal(0.0f, scale, shape, dataType, manager.getDevice());
            default:
                throw new IllegalArgumentException("Invalid randomType");
        }
    }

    /** Computes the fan-based factor the magnitude is divided by, according to the factor type. */
    private float computeFactor(Shape shape, long dimension) {
        // Product of all dimensions after the first two; 1 for a plain 2-D shape.
        float trailing = dimension == 2 ? 1.0f : (float) shape.slice(2).size();
        float fanIn = ((float) shape.get(1)) * trailing;
        float fanOut = ((float) shape.head()) * trailing;
        switch (factorType) {
            case AVG:
                return (fanIn + fanOut) / 2.0f;
            case IN:
                return fanIn;
            case OUT:
                return fanOut;
            default:
                throw new IllegalArgumentException("Invalid factor type, valid types are: avg, in, out");
        }
    }
}
