/*
 * Decompiled with CFR 0.152.
 */
package org.openeuler.sm4.mode;

import java.security.AlgorithmParameters;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.ProviderException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import org.openeuler.BGMJCEProvider;
import org.openeuler.sm4.SM4Util;
import org.openeuler.sm4.StreamModeBaseCipher;

/**
 * SM4 in OCB (Offset Codebook) authenticated-encryption mode, structured after
 * RFC 7253 with SM4 as the 128-bit block cipher.
 *
 * <p>The nonce (at most 15 bytes) is supplied through {@link IvParameterSpec} or
 * {@link GCMParameterSpec}; the latter also selects the tag length (a multiple of
 * 8 between 64 and 128 bits). Without parameters, encryption generates a random
 * 15-byte nonce. The tag length defaults to 128 bits.
 *
 * <p>Input may be supplied over any number of {@code update} calls. Complete
 * blocks are processed as soon as they are available; the remainder (and, when
 * decrypting, the trailing tag-length bytes, which may turn out to be the tag)
 * is buffered until {@code doFinal}. Plaintext returned by {@code update} while
 * decrypting is released before the tag has been verified; plaintext produced by
 * {@code doFinal} is only handed out after verification succeeds.
 */
public class OCB extends StreamModeBaseCipher {
    /** SM4 block size in bytes. */
    private static final int BLOCK_SIZE = 16;
    /** Longest accepted nonce in bytes; also the length of generated nonces. */
    private static final int MAX_NONCE_LEN = 15;
    /** Tag length in bits used unless a GCMParameterSpec selects another one. */
    private static final int DEFAULT_TAG_LEN = 128;
    /** Folded into the last byte when doubling shifts out a 1 bit (x^7 + x^2 + x + 1). */
    private static final byte DOUBLING_CONSTANT = (byte) 0x87;
    private static final byte[] EMPTY = new byte[0];

    /** Additional authenticated data accumulated for the current message, or null. */
    private byte[] aad;
    /** Authentication tag length in bits, chosen at init time. */
    private int tLen = DEFAULT_TAG_LEN;
    /** L_* = E_K(0^128). */
    private byte[] H;
    /** L_$ = double(L_*). */
    private byte[] L_$;
    /** L_0 = double(L_$). */
    private byte[] L_0;
    /** Lazily extended table with L[i] = double^i(L_0); reset to null when the key constants change. */
    private byte[][] L;
    /** Running checksum over the plaintext blocks of the current message. */
    private byte[] checkSum;
    /** Offset for the most recently processed block of the current message. */
    private byte[] offset;
    /** Number of complete message blocks processed so far for the current message. */
    private long blocksProcessed;
    /** Input accepted by update() that has not been processed yet. */
    private byte[] buffered = EMPTY;

    /**
     * Initializes the cipher with a key and optional nonce/tag-length parameters.
     *
     * @param opmode cipher mode; only encryption may omit parameters
     * @param key    the SM4 key
     * @param params null, an {@link IvParameterSpec}, or a {@link GCMParameterSpec}
     * @param random randomness for nonce generation; the provider default is used if null
     * @throws InvalidAlgorithmParameterException if the parameters are missing,
     *         of an unsupported type, carry a nonce longer than 15 bytes, or an invalid tag length
     */
    @Override
    public void engineInit(int opmode, Key key, AlgorithmParameterSpec params, SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
        // Do not leave a previously initialized state usable if this init fails.
        this.isInitialized = false;
        this.init(opmode, key);
        if (params == null) {
            // Only encryption may invent a nonce; every other mode needs the sender's.
            if (this.opmode != Cipher.ENCRYPT_MODE) {
                throw new InvalidAlgorithmParameterException("need an IV");
            }
            SecureRandom rng = random != null ? random : BGMJCEProvider.getRandom();
            byte[] generated = new byte[MAX_NONCE_LEN];
            rng.nextBytes(generated);
            this.iv = generated;
            this.tLen = DEFAULT_TAG_LEN;
        } else if (params instanceof GCMParameterSpec) {
            GCMParameterSpec gcmParam = (GCMParameterSpec) params;
            this.checkTagLen(gcmParam);
            this.iv = checkNonce(gcmParam.getIV());
            this.tLen = gcmParam.getTLen();
        } else if (params instanceof IvParameterSpec) {
            this.iv = checkNonce(((IvParameterSpec) params).getIV());
            // A plain IV carries no tag length: fall back to the default instead of
            // silently reusing whatever an earlier init configured.
            this.tLen = DEFAULT_TAG_LEN;
        } else {
            throw new InvalidAlgorithmParameterException();
        }
        this.deriveKeyConstants();
        this.startMessage();
        this.isInitialized = true;
    }

    /** Initializes without parameters; only valid for encryption (a random nonce is generated). */
    @Override
    public void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
        try {
            this.engineInit(opmode, key, (AlgorithmParameterSpec) null, random);
        } catch (InvalidAlgorithmParameterException e) {
            throw new InvalidKeyException(e.getMessage(), e);
        }
    }

    /**
     * Initializes from {@link AlgorithmParameters}, preferring a GCM spec (nonce and
     * tag length) and falling back to a plain IV spec.
     */
    @Override
    public void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
        AlgorithmParameterSpec spec = null;
        if (params != null) {
            try {
                spec = params.getParameterSpec(GCMParameterSpec.class);
            } catch (InvalidParameterSpecException notGcm) {
                try {
                    spec = params.getParameterSpec(IvParameterSpec.class);
                } catch (InvalidParameterSpecException e) {
                    throw new InvalidAlgorithmParameterException("Wrong parameter type: GCM or IV expected", e);
                }
            }
        }
        this.engineInit(opmode, key, spec, random);
    }

    /**
     * Appends AAD to the data authenticated with the current message. Multiple
     * calls accumulate; the input is copied, so later changes to {@code src} have no effect.
     */
    @Override
    protected void engineUpdateAAD(byte[] src, int offset, int len) {
        this.requireInitialized();
        byte[] previous = this.aad == null ? EMPTY : this.aad;
        this.aad = concat(previous, src, offset, len);
    }

    /**
     * Returns the output size of a doFinal with {@code inputLen} more bytes,
     * including input already buffered by update().
     */
    @Override
    public int engineGetOutputSize(int inputLen) {
        int total = this.buffered.length + inputLen;
        int tagBytes = this.tLen / 8;
        if (this.opmode == Cipher.ENCRYPT_MODE) {
            return total + tagBytes;
        }
        if (this.opmode == Cipher.DECRYPT_MODE) {
            return Math.max(0, total - tagBytes);
        }
        return 0;
    }

    /** Accepts only "NoPadding" (case-insensitive, locale-independent). */
    @Override
    public void engineSetPadding(String padding) throws NoSuchPaddingException {
        if (!"NOPADDING".equalsIgnoreCase(padding)) {
            throw new NoSuchPaddingException("only nopadding can be used in this mode");
        }
        super.engineSetPadding(padding);
    }

    /**
     * Returns "SM4" parameters holding the nonce and tag length as a
     * {@link GCMParameterSpec}, or null when no nonce has been set yet.
     *
     * @throws ProviderException if SM4 AlgorithmParameters are unavailable or reject the spec
     */
    @Override
    public AlgorithmParameters engineGetParameters() {
        if (this.iv == null) {
            return null;
        }
        try {
            AlgorithmParameters sm4Parameters = AlgorithmParameters.getInstance("SM4");
            sm4Parameters.init(new GCMParameterSpec(this.tLen, this.iv));
            return sm4Parameters;
        } catch (NoSuchAlgorithmException | InvalidParameterSpecException e) {
            throw new ProviderException("cannot create SM4 AlgorithmParameters", e);
        }
    }

    /**
     * Processes as many complete blocks as are available and buffers the rest.
     *
     * @return the produced bytes, or null when nothing could be produced yet
     */
    @Override
    public byte[] engineUpdate(byte[] input, int inputOffset, int inputLen) {
        this.requireInitialized();
        if (input == null || inputLen == 0) {
            return null;
        }
        byte[] res = this.updateInternal(input, inputOffset, inputLen);
        return res.length == 0 ? null : res;
    }

    /**
     * Buffer-output variant of {@link #engineUpdate(byte[], int, int)}.
     *
     * @return number of bytes written to {@code output}
     * @throws ShortBufferException if {@code output} cannot hold the produced bytes;
     *         no state is changed in that case
     */
    @Override
    public int engineUpdate(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset) throws ShortBufferException {
        this.requireInitialized();
        if (input == null || inputLen == 0) {
            return 0;
        }
        int produced = this.updateOutputLength(inputLen);
        if (produced > 0 && output.length - outputOffset < produced) {
            throw new ShortBufferException("need " + produced + " bytes of output space");
        }
        byte[] res = this.updateInternal(input, inputOffset, inputLen);
        if (res.length > 0) {
            System.arraycopy(res, 0, output, outputOffset, res.length);
        }
        return res.length;
    }

    /**
     * Finishes the message: encryption returns ciphertext || tag, decryption returns
     * the plaintext after the tag verified. The cipher is reset afterwards, also on failure.
     *
     * @throws IllegalBlockSizeException if decrypting and the input is shorter than the tag
     * @throws javax.crypto.AEADBadTagException if decrypting and the tag does not match
     */
    @Override
    public byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) throws IllegalBlockSizeException, BadPaddingException {
        this.requireInitialized();
        try {
            return this.finishMessage(input, inputOffset, inputLen);
        } finally {
            this.reset();
        }
    }

    /**
     * Buffer-output variant of {@link #engineDoFinal(byte[], int, int)}. Nothing is
     * written to {@code output} unless the operation (including tag verification) succeeds.
     *
     * @return number of bytes written to {@code output}
     * @throws ShortBufferException if {@code output} is too small; the cipher state is kept
     *         so the call can be retried with a larger buffer
     */
    @Override
    public int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset) throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
        this.requireInitialized();
        int need = this.engineGetOutputSize(inputLen);
        if (need > 0 && output.length - outputOffset < need) {
            throw new ShortBufferException("need " + need + " bytes of output space");
        }
        byte[] res;
        try {
            res = this.finishMessage(input, inputOffset, inputLen);
        } finally {
            this.reset();
        }
        if (res == null || res.length == 0) {
            return 0;
        }
        System.arraycopy(res, 0, output, outputOffset, res.length);
        return res.length;
    }

    /** Number of bytes update() would produce if given {@code inputLen} more bytes. */
    private int updateOutputLength(int inputLen) {
        int total = this.buffered.length + inputLen;
        if (this.opmode == Cipher.ENCRYPT_MODE) {
            return total - total % BLOCK_SIZE;
        }
        if (this.opmode == Cipher.DECRYPT_MODE) {
            // The last tLen/8 bytes may be the tag, so they are never processed here.
            int available = total - this.tLen / 8;
            return available <= 0 ? 0 : available - available % BLOCK_SIZE;
        }
        return 0;
    }

    /** Appends input to the buffer, processes all releasable blocks and keeps the rest. */
    private byte[] updateInternal(byte[] input, int inputOffset, int inputLen) {
        int n = this.updateOutputLength(inputLen);
        // Working on a private copy also makes in-place (aliased) input/output safe.
        byte[] data = concat(this.buffered, input, inputOffset, inputLen);
        byte[] out = new byte[n];
        this.processBlocks(data, n, out);
        this.buffered = Arrays.copyOfRange(data, n, data.length);
        return out;
    }

    /** Combines buffered and final input and finishes the message for the current mode. */
    private byte[] finishMessage(byte[] input, int inputOffset, int inputLen) throws IllegalBlockSizeException, AEADBadTagException {
        byte[] data = concat(this.buffered, input, inputOffset, inputLen);
        this.buffered = EMPTY;
        if (this.opmode == Cipher.ENCRYPT_MODE) {
            return this.encryptFinal(data);
        }
        if (this.opmode == Cipher.DECRYPT_MODE) {
            return this.decryptFinal(data);
        }
        return null;
    }

    /** Encrypts the remaining plaintext and appends the (possibly truncated) tag. */
    private byte[] encryptFinal(byte[] data) {
        int tagBytes = this.tLen / 8;
        int full = data.length - data.length % BLOCK_SIZE;
        int rem = data.length - full;
        byte[] out = new byte[data.length + tagBytes];
        this.processBlocks(data, full, out);
        if (rem != 0) {
            // Final partial block: C_* = P_* xor leading bytes of E_K(Offset_* ).
            this.offset = this.sm4.xor16Byte(this.offset, this.H);
            byte[] pad = this.sm4.encrypt(this.rk, this.offset, 0);
            byte[] partial = this.sm4.xor(data, full, rem, pad, 0, BLOCK_SIZE);
            System.arraycopy(partial, 0, out, full, rem);
            this.checkSum = this.sm4.xor(this.checkSum, padBlock(data, full, rem));
        }
        System.arraycopy(this.computeTag(), 0, out, data.length, tagBytes);
        return out;
    }

    /** Decrypts the remaining ciphertext and returns it only if the trailing tag verifies. */
    private byte[] decryptFinal(byte[] data) throws IllegalBlockSizeException, AEADBadTagException {
        int tagBytes = this.tLen / 8;
        if (data.length < tagBytes) {
            throw new IllegalBlockSizeException("input too short to contain the authentication tag");
        }
        int ctLen = data.length - tagBytes;
        int full = ctLen - ctLen % BLOCK_SIZE;
        int rem = ctLen - full;
        // Decrypt into a private buffer so nothing reaches the caller before verification.
        byte[] plain = new byte[ctLen];
        this.processBlocks(data, full, plain);
        if (rem != 0) {
            this.offset = this.sm4.xor16Byte(this.offset, this.H);
            byte[] pad = this.sm4.encrypt(this.rk, this.offset, 0);
            byte[] partial = this.sm4.xor(data, full, rem, pad, 0, BLOCK_SIZE);
            System.arraycopy(partial, 0, plain, full, rem);
            this.checkSum = this.sm4.xor(this.checkSum, padBlock(partial, 0, rem));
        }
        byte[] received = Arrays.copyOfRange(data, ctLen, data.length);
        byte[] expected = Arrays.copyOfRange(this.computeTag(), 0, tagBytes);
        this.checkMac(received, expected);
        return plain;
    }

    /**
     * Runs the OCB main loop over the first {@code len} bytes of {@code in} (a multiple
     * of the block size), writing to {@code out} and advancing the offset, checksum and
     * block counter, so consecutive calls continue the same message.
     */
    private void processBlocks(byte[] in, int len, byte[] out) {
        for (int pos = 0; pos + BLOCK_SIZE <= len; pos += BLOCK_SIZE) {
            ++this.blocksProcessed;
            // Offset_i = Offset_{i-1} xor L_{ntz(i)}
            this.offset = this.sm4.xor(this.offset, 0, BLOCK_SIZE, this.lAt(ntz(this.blocksProcessed)), 0, BLOCK_SIZE);
            byte[] masked = this.sm4.xor(this.offset, 0, BLOCK_SIZE, in, pos, BLOCK_SIZE);
            byte[] result;
            if (this.opmode == Cipher.ENCRYPT_MODE) {
                result = this.sm4.xor(this.offset, 0, BLOCK_SIZE, this.sm4.encrypt(this.rk, masked, 0), 0, BLOCK_SIZE);
                this.checkSum = this.sm4.xor(this.checkSum, 0, BLOCK_SIZE, in, pos, BLOCK_SIZE);
            } else {
                result = this.sm4.xor(this.offset, this.sm4.decrypt(this.rk, masked, 0));
                this.checkSum = this.sm4.xor(this.checkSum, result);
            }
            System.arraycopy(result, 0, out, pos, BLOCK_SIZE);
        }
    }

    /** Tag = E_K(Checksum xor Offset xor L_$) xor HASH(K, A), using the current state. */
    private byte[] computeTag() {
        byte[] input = this.sm4.xor(this.L_$, this.sm4.xor(this.checkSum, this.offset));
        return this.sm4.xor(this.sm4.encrypt(this.rk, input, 0), this.hash());
    }

    /** HASH(K, A) over the accumulated AAD; all-zero when there is no AAD. */
    private byte[] hash() {
        byte[] sum = new byte[BLOCK_SIZE];
        if (this.aad == null || this.aad.length == 0) {
            return sum;
        }
        byte[] aadOffset = new byte[BLOCK_SIZE];
        int fullBlocks = this.aad.length / BLOCK_SIZE;
        for (int i = 0; i < fullBlocks; ++i) {
            aadOffset = this.sm4.xor(aadOffset, 0, BLOCK_SIZE, this.lAt(ntz(i + 1L)), 0, BLOCK_SIZE);
            byte[] block = this.sm4.xor(this.aad, i * BLOCK_SIZE, BLOCK_SIZE, aadOffset, 0, BLOCK_SIZE);
            sum = this.sm4.xor(sum, 0, BLOCK_SIZE, this.sm4.encrypt(this.rk, block, 0), 0, BLOCK_SIZE);
        }
        int rem = this.aad.length % BLOCK_SIZE;
        if (rem != 0) {
            aadOffset = this.sm4.xor(aadOffset, this.H);
            byte[] last = this.sm4.xor(padBlock(this.aad, fullBlocks * BLOCK_SIZE, rem), aadOffset);
            sum = this.sm4.xor(sum, this.sm4.encrypt(this.rk, last, 0));
        }
        return sum;
    }

    /** GF(2^128) doubling: shift left one bit and reduce when the top bit was set. */
    private byte[] double_(byte[] h) {
        byte[] res = this.moveLeftOneBit(h);
        if ((h[0] & 0x80) != 0) {
            res[res.length - 1] ^= DOUBLING_CONSTANT;
        }
        return res;
    }

    /** Shifts a big-endian byte string left by one bit, discarding the top bit. */
    private byte[] moveLeftOneBit(byte[] input) {
        byte[] res = new byte[input.length];
        for (int i = 0; i < input.length; ++i) {
            int carry = i + 1 < input.length ? (input[i + 1] & 0xFF) >>> 7 : 0;
            res[i] = (byte) ((input[i] << 1) | carry);
        }
        return res;
    }

    /** Number of trailing zero bits of a positive block index. */
    private static int ntz(long num) {
        return Long.numberOfTrailingZeros(num);
    }

    /** Returns L_i, extending the cached table by repeated doubling as needed. */
    private byte[] lAt(int i) {
        if (this.L == null) {
            this.L = new byte[][]{this.L_0};
        }
        if (i >= this.L.length) {
            int have = this.L.length;
            this.L = Arrays.copyOf(this.L, i + 1);
            for (int j = have; j <= i; ++j) {
                this.L[j] = this.double_(this.L[j - 1]);
            }
        }
        return this.L[i];
    }

    /** Recomputes L_*, L_$ and L_0 from the current round keys and drops the L cache. */
    private void deriveKeyConstants() {
        this.H = this.sm4.encrypt(this.rk, new byte[BLOCK_SIZE], 0);
        this.L_$ = this.double_(this.H);
        this.L_0 = this.double_(this.L_$);
        this.L = null;
    }

    /** Clears all per-message state and derives the initial offset from the nonce. */
    private void startMessage() {
        this.init();
        this.checkSum = new byte[BLOCK_SIZE];
        this.blocksProcessed = 0L;
        this.buffered = EMPTY;
        this.aad = null;
    }

    /**
     * Computes Offset_0 from the nonce ({@code iv}) and tag length:
     * Nonce = num2str(TAGLEN mod 128, 7) || zeros || 1 || N, Ktop = E_K(Nonce with its
     * low 6 bits cleared), Stretch = Ktop || (Ktop[1..64] xor Ktop[9..72]), and
     * Offset_0 = Stretch[1+bottom..128+bottom] where bottom is the nonce's low 6 bits.
     */
    private void init() {
        byte[] nonce = new byte[BLOCK_SIZE];
        nonce[0] = (byte) ((this.tLen % 128) << 1);
        int start = BLOCK_SIZE - this.iv.length;
        // The single 1 bit sits right before N; a 15-byte N puts it in byte 0's low bit.
        if (this.iv.length == MAX_NONCE_LEN) {
            nonce[0] |= 1;
        } else {
            nonce[start - 1] = 1;
        }
        System.arraycopy(this.iv, 0, nonce, start, this.iv.length);

        int bottom = nonce[BLOCK_SIZE - 1] & 0x3F;
        nonce[BLOCK_SIZE - 1] &= 0xC0;
        byte[] ktop = this.sm4.encrypt(this.rk, nonce, 0);
        byte[] stretch = Arrays.copyOf(ktop, BLOCK_SIZE + 8);
        for (int k = 0; k < 8; ++k) {
            stretch[BLOCK_SIZE + k] = (byte) (ktop[k] ^ ktop[k + 1]);
        }

        // Extract 128 bits starting at bit position `bottom` (0..63) of Stretch.
        int byteShift = bottom >>> 3;
        int bitShift = bottom & 7;
        this.offset = new byte[BLOCK_SIZE];
        for (int k = 0; k < BLOCK_SIZE; ++k) {
            int hi = stretch[k + byteShift] & 0xFF;
            int lo = stretch[k + byteShift + 1] & 0xFF;
            this.offset[k] = (byte) ((hi << bitShift) | (lo >>> (8 - bitShift)));
        }
    }

    /** Returns a block holding src[srcOff..srcOff+len) followed by 0x80 and zero padding. */
    private static byte[] padBlock(byte[] src, int srcOff, int len) {
        byte[] block = new byte[BLOCK_SIZE];
        System.arraycopy(src, srcOff, block, 0, len);
        block[len] = (byte) 0x80;
        return block;
    }

    /** Returns head followed by src[srcOff..srcOff+srcLen); src may be null when srcLen is 0. */
    private static byte[] concat(byte[] head, byte[] src, int srcOff, int srcLen) {
        byte[] joined = Arrays.copyOf(head, head.length + srcLen);
        if (srcLen > 0) {
            System.arraycopy(src, srcOff, joined, head.length, srcLen);
        }
        return joined;
    }

    /** Validates a nonce: present and at most 15 bytes long. */
    private static byte[] checkNonce(byte[] nonce) throws InvalidAlgorithmParameterException {
        if (nonce == null || nonce.length > MAX_NONCE_LEN) {
            throw new InvalidAlgorithmParameterException("IV no more than 15 bytes long.");
        }
        return nonce;
    }

    /** Fails fast when the cipher has not been (successfully) initialized. */
    private void requireInitialized() {
        if (!this.isInitialized) {
            throw new IllegalStateException("cipher uninitialized");
        }
    }

    /** Accepts tag lengths that are a multiple of 8 bits between 64 and 128 inclusive. */
    private void checkTagLen(GCMParameterSpec gcmParam) throws InvalidAlgorithmParameterException {
        int bits = gcmParam.getTLen();
        if (bits % 8 != 0 || bits < 64 || bits > 128) {
            throw new InvalidAlgorithmParameterException("invalid  mac size " + bits);
        }
    }

    /** Compares tags in constant time and signals a mismatch the way JCE AEAD ciphers do. */
    private void checkMac(byte[] T, byte[] _T) throws AEADBadTagException {
        if (!MessageDigest.isEqual(T, _T)) {
            throw new AEADBadTagException("mac check failed in OCB mode");
        }
    }

    /**
     * Returns the cipher to its post-init state: same key, nonce and tag length, with
     * AAD, buffered input, checksum and offset cleared.
     */
    @Override
    public void reset() {
        this.deriveKeyConstants();
        this.startMessage();
        super.reset();
    }
}

