package org.eclipse.milo.opcua.stack.server.handlers;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteOrder;
import java.security.KeyPair;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.milo.opcua.stack.core.StatusCodes;
import org.eclipse.milo.opcua.stack.core.UaException;
import org.eclipse.milo.opcua.stack.core.application.CertificateManager;
import org.eclipse.milo.opcua.stack.core.application.CertificateValidator;
import org.eclipse.milo.opcua.stack.core.channel.ChannelSecurity;
import org.eclipse.milo.opcua.stack.core.channel.ExceptionHandler;
import org.eclipse.milo.opcua.stack.core.channel.SerializationQueue;
import org.eclipse.milo.opcua.stack.core.channel.ServerSecureChannel;
import org.eclipse.milo.opcua.stack.core.channel.headers.AsymmetricSecurityHeader;
import org.eclipse.milo.opcua.stack.core.channel.headers.HeaderDecoder;
import org.eclipse.milo.opcua.stack.core.channel.messages.ErrorMessage;
import org.eclipse.milo.opcua.stack.core.channel.messages.MessageType;
import org.eclipse.milo.opcua.stack.core.security.SecurityAlgorithm;
import org.eclipse.milo.opcua.stack.core.security.SecurityPolicy;
import org.eclipse.milo.opcua.stack.core.types.builtin.ByteString;
import org.eclipse.milo.opcua.stack.core.types.builtin.DateTime;
import org.eclipse.milo.opcua.stack.core.types.builtin.StatusCode;
import org.eclipse.milo.opcua.stack.core.types.builtin.unsigned.Unsigned;
import org.eclipse.milo.opcua.stack.core.types.enumerated.SecurityTokenRequestType;
import org.eclipse.milo.opcua.stack.core.types.structured.ChannelSecurityToken;
import org.eclipse.milo.opcua.stack.core.types.structured.EndpointDescription;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelRequest;
import org.eclipse.milo.opcua.stack.core.types.structured.OpenSecureChannelResponse;
import org.eclipse.milo.opcua.stack.core.types.structured.ResponseHeader;
import org.eclipse.milo.opcua.stack.core.util.BufferUtil;
import org.eclipse.milo.opcua.stack.core.util.NonceUtil;
import org.eclipse.milo.opcua.stack.server.tcp.UaTcpStackServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side Netty handler for the asymmetric phase of an OPC UA TCP connection.
 *
 * <p>It frames OpenSecureChannel (OPN) and CloseSecureChannel (CLO) messages and runs the security
 * checks carried by the {@link AsymmetricSecurityHeader}. It then issues or renews the channel's
 * security token. OPN chunks are buffered until the final ('F') chunk arrives and are then decoded
 * on the {@link SerializationQueue}. When the first OpenSecureChannelResponse is encoded, a
 * {@link UaTcpServerSymmetricHandler} is added to the front of the pipeline.
 */
public class UaTcpServerAsymmetricHandler extends ByteToMessageDecoder implements HeaderDecoder {

    /**
     * Lower bound applied to a client's requested token lifetime: 3,600,000, i.e. one hour if the
     * unit is milliseconds as for the OPC UA RequestedLifetime field.
     */
    private static final long SecureChannelLifetimeMin = 3600000;

    /**
     * Upper bound applied to a client's requested token lifetime: 86,400,000, i.e. one day in
     * milliseconds.
     */
    private static final long SecureChannelLifetimeMax = 86400000;

    /**
     * Size of the fixed message header: 3-byte message type, 1-byte chunk type and 4-byte message
     * size.
     */
    private static final int MESSAGE_HEADER_LENGTH = 8;

    /** Secure channel used by this connection; assigned only after an OPN chunk passes the channel checks. */
    private ServerSecureChannel secureChannel;

    private final int maxChunkCount;
    private final int maxChunkSize;
    private final UaTcpStackServer server;
    private final SerializationQueue serializationQueue;
    private final Logger logger = LoggerFactory.getLogger(getClass());

    private volatile boolean symmetricHandlerAdded = false;

    /** Retained chunks of the OPN message currently being received. */
    private List<ByteBuf> chunkBuffers = new ArrayList<>();

    /** Security header of the first chunk of the message in progress; later chunks must match it. */
    private final AtomicReference<AsymmetricSecurityHeader> headerRef = new AtomicReference<>();

    /**
     * Creates a handler for one connection.
     *
     * @param uaTcpStackServer   the server that owns the secure channels.
     * @param serializationQueue the queue used for decoding requests and encoding responses. It also
     *                           supplies the local max chunk count and receive buffer size limits.
     */
    public UaTcpServerAsymmetricHandler(UaTcpStackServer uaTcpStackServer, SerializationQueue serializationQueue) {
        this.server = uaTcpStackServer;
        this.serializationQueue = serializationQueue;
        this.maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        this.maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();
    }

    /**
     * Splits inbound bytes into complete messages and dispatches each one by its message type.
     *
     * <p>Bytes of an incomplete trailing message are left in {@code byteBuf} until more data arrives.
     *
     * @throws UaException if a message declares a length smaller than its header or larger than
     *                     the local receive buffer size, or has a message type other than OPN/CLO.
     */
    @Override // io.netty.handler.codec.ByteToMessageDecoder
    public void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
        ByteBuf buffer = byteBuf.order(ByteOrder.LITTLE_ENDIAN);

        while (buffer.readableBytes() >= MESSAGE_HEADER_LENGTH) {
            int messageLength = getMessageLength(buffer);

            // Validate the declared length before waiting for the body. With a length of zero, the
            // CloseSecureChannel branch would skip nothing and loop forever. Anything shorter than the
            // header cannot be a valid chunk. An oversized length would make this decoder keep
            // accumulating data while it waits.
            if (messageLength < MESSAGE_HEADER_LENGTH) {
                throw new UaException(StatusCodes.Bad_DecodingError,
                    "invalid message length: " + messageLength);
            }
            if (messageLength > maxChunkSize) {
                throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk size exceeded (%s)", maxChunkSize));
            }
            if (buffer.readableBytes() < messageLength) {
                break; // wait for the rest of this message
            }

            MessageType messageType = MessageType.fromMediumInt(buffer.getMedium(buffer.readerIndex()));

            switch (messageType) {
                case OpenSecureChannel:
                    onOpenSecureChannel(ctx, buffer.readSlice(messageLength));
                    break;

                case CloseSecureChannel:
                    logger.debug("Received CloseSecureChannelRequest");
                    if (secureChannel != null) {
                        server.closeSecureChannel(secureChannel);
                    }
                    buffer.skipBytes(messageLength);
                    break;

                default:
                    throw new UaException(StatusCodes.Bad_TcpMessageTypeInvalid,
                        "unexpected MessageType: " + messageType);
            }
        }
    }

    /**
     * Handles one OPN chunk: it runs the channel and security checks, buffers the chunk, and starts
     * decoding when the final chunk arrives. An abort ('A') chunk discards everything buffered so far.
     *
     * @param ctx   the channel context.
     * @param chunk a slice containing exactly one chunk, header included.
     * @throws UaException if any channel, header or chunk-count check fails.
     */
    private void onOpenSecureChannel(ChannelHandlerContext ctx, ByteBuf chunk) throws UaException {
        chunk.skipBytes(3); // message type, already inspected by decode()
        char chunkType = (char) chunk.readByte();

        if (chunkType == 'A') {
            discardChunks();
            return;
        }

        chunk.skipBytes(4); // message size, already validated by decode()
        long secureChannelId = chunk.readUnsignedInt();
        AsymmetricSecurityHeader header = AsymmetricSecurityHeader.decode(chunk);

        // resolveSecureChannel() performs every channel check before this field is assigned. A
        // rejected request therefore never leaves this handler pointing at (and later closing) a
        // channel bound to another connection.
        secureChannel = resolveSecureChannel(ctx, secureChannelId, header);

        if (!headerRef.compareAndSet(null, header) && !header.equals(headerRef.get())) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                "subsequent AsymmetricSecurityHeader did not match");
        }

        SecurityPolicy securityPolicy = SecurityPolicy.fromUri(header.getSecurityPolicyUri());
        secureChannel.setSecurityPolicy(securityPolicy);

        if (!validateRemoteCertificate(ctx, header, securityPolicy)) {
            // An ErrorMessage has already been sent. Do not continue the handshake with an
            // untrusted certificate.
            discardChunks();
            return;
        }

        configureLocalCertificate(header);

        // Rewind to the start of the slice so that the stored chunk includes its header.
        chunk.readerIndex(0);
        chunkBuffers.add(chunk.retain());

        if (chunkBuffers.size() > maxChunkCount) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            List<ByteBuf> chunks = chunkBuffers;
            chunkBuffers = new ArrayList<>(maxChunkCount);
            headerRef.set(null);

            decodeOpenSecureChannelRequest(ctx, chunks, secureChannelId);
        }
    }

    /**
     * Finds or creates the secure channel an OPN chunk refers to, without touching any field.
     *
     * <p>An id of 0 opens a new channel for the endpoint matching the connection's endpoint URL and
     * the header's security policy. Any other id must name an existing channel whose certificate
     * matches the header and which is not bound to a different connection.
     *
     * @throws UaException if no endpoint matches, the id is unknown, or a renewal check fails.
     */
    private ServerSecureChannel resolveSecureChannel(
            ChannelHandlerContext ctx, long secureChannelId, AsymmetricSecurityHeader header) throws UaException {

        if (secureChannelId == 0) {
            String endpointUrl = (String) ctx.channel().attr(UaTcpServerHelloHandler.ENDPOINT_URL_KEY).get();
            EndpointDescription endpoint = findEndpoint(endpointUrl, header.getSecurityPolicyUri());

            if (endpoint == null) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, "SecurityPolicy URI did not match");
            }

            ServerSecureChannel issued = server.openSecureChannel();
            issued.setEndpointDescription(endpoint);
            return issued;
        }

        ServerSecureChannel existing = server.getSecureChannel(secureChannelId);

        if (existing == null) {
            throw new UaException(StatusCodes.Bad_TcpSecureChannelUnknown,
                "unknown secure channel id: " + secureChannelId);
        }

        if (!existing.getRemoteCertificateChainBytes().equals(header.getSenderCertificate())) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                "certificate requesting renewal did not match existing certificate.");
        }

        Channel boundChannel = (Channel) existing.attr(UaTcpStackServer.BoundChannelKey).get();

        if (boundChannel != null && boundChannel != ctx.channel()) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                "received a renewal request from channel other than the bound channel.");
        }

        return existing;
    }

    /**
     * Selects the endpoint for a new secure channel.
     *
     * <p>It prefers an endpoint whose URL path matches {@code endpointUrl} and whose security policy
     * matches. If there is none and strict endpoint URLs are disabled, it falls back to the first
     * endpoint with a matching security policy.
     *
     * @return the matching endpoint, or {@code null} if none matches.
     */
    private EndpointDescription findEndpoint(String endpointUrl, String securityPolicyUri) {
        // Parse the requested URL once, not once per configured endpoint.
        String requestedPath = pathOrUrl(endpointUrl);

        EndpointDescription endpoint = null;

        if (requestedPath != null) {
            endpoint = Arrays.stream(server.getEndpointDescriptions())
                .filter(e -> requestedPath.equals(pathOrUrl(e.getEndpointUrl()))
                    && e.getSecurityPolicyUri().equals(securityPolicyUri))
                .findFirst()
                .orElse(null);
        }

        if (endpoint == null && !server.getConfig().isStrictEndpointUrlsEnabled()) {
            endpoint = Arrays.stream(server.getEndpointDescriptions())
                .filter(e -> e.getSecurityPolicyUri().equals(securityPolicyUri))
                .findFirst()
                .orElse(null);
        }

        return endpoint;
    }

    /**
     * Validates the sender certificate in {@code header} with the server's {@link CertificateValidator}.
     *
     * <p>Validation is skipped when the header has no sender certificate or the policy is
     * {@link SecurityPolicy#None}. On failure, an ErrorMessage carrying the validator's status code and
     * the generic reason "security checks failed" is sent to the peer.
     *
     * @return {@code true} if processing may continue, or {@code false} if validation failed and the
     *         error has already been reported.
     * @throws UaException if the certificate cannot be stored on the secure channel.
     */
    private boolean validateRemoteCertificate(
            ChannelHandlerContext ctx, AsymmetricSecurityHeader header, SecurityPolicy securityPolicy) throws UaException {

        if (header.getSenderCertificate().isNull() || securityPolicy == SecurityPolicy.None) {
            return true;
        }

        secureChannel.setRemoteCertificate(header.getSenderCertificate().bytes());

        try {
            CertificateValidator certificateValidator = server.getCertificateValidator();
            certificateValidator.validate(secureChannel.getRemoteCertificate());
            certificateValidator.verifyTrustChain(
                secureChannel.getRemoteCertificate(), secureChannel.getRemoteCertificateChain());
            return true;
        } catch (UaException validationFailure) {
            try {
                UaException cause = new UaException(validationFailure.getStatusCode(), "security checks failed");
                ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

                // Log the original failure locally; the peer only sees the generic reason.
                logger.debug("[remote={}] {}.",
                    ctx.channel().remoteAddress(), errorMessage.getReason(), validationFailure);
            } catch (Exception sendFailure) {
                logger.error("Error sending ErrorMessage: {}", sendFailure.getMessage(), sendFailure);
            }
            return false;
        }
    }

    /**
     * Installs the local certificate chain and key pair selected by the header's receiver thumbprint.
     * Nothing is done if the thumbprint is null.
     *
     * @throws UaException if the certificate manager has no (or an empty) chain or no key pair for the
     *                     thumbprint.
     */
    private void configureLocalCertificate(AsymmetricSecurityHeader header) throws UaException {
        if (header.getReceiverThumbprint().isNull()) {
            return;
        }

        CertificateManager certificateManager = server.getCertificateManager();
        Optional<X509Certificate[]> certificateChain =
            certificateManager.getCertificateChain(header.getReceiverThumbprint());
        Optional<KeyPair> keyPair = certificateManager.getKeyPair(header.getReceiverThumbprint());

        if (!certificateChain.isPresent() || !keyPair.isPresent() || certificateChain.get().length == 0) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed, "no certificate for provided thumbprint");
        }

        X509Certificate[] chain = certificateChain.get();
        secureChannel.setLocalCertificate(chain[0]);
        secureChannel.setLocalCertificateChain(chain);
        secureChannel.setKeyPair(keyPair.get());
    }

    /**
     * Decodes the buffered chunks of a complete OPN message on the serialization queue and installs
     * the requested security token. Any failure is logged and the connection is closed. The decoded
     * buffer is always released and {@code chunks} is always cleared.
     */
    private void decodeOpenSecureChannelRequest(ChannelHandlerContext ctx, List<ByteBuf> chunks, long secureChannelId) {
        serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
            ByteBuf messageBuffer = null;

            try {
                messageBuffer = chunkDecoder.decodeAsymmetric(secureChannel, chunks);

                Object message = binaryDecoder.setBuffer(messageBuffer).decodeMessage(null);

                // A peer can put any message in an OPN chunk. Reject anything else explicitly rather
                // than letting a ClassCastException escape the queue task.
                if (!(message instanceof OpenSecureChannelRequest)) {
                    throw new UaException(StatusCodes.Bad_DecodingError,
                        "expected OpenSecureChannelRequest but decoded "
                            + (message != null ? message.getClass().getSimpleName() : "null"));
                }
                OpenSecureChannelRequest request = (OpenSecureChannelRequest) message;

                logger.debug("Received OpenSecureChannelRequest ({}, id={}).",
                    request.getRequestType(), secureChannelId);

                installSecurityToken(ctx, request, chunkDecoder.getLastRequestId());
            } catch (UaException | RuntimeException e) {
                logger.error("Error decoding asymmetric message: {}", e.getMessage(), e);
                ctx.close();
            } finally {
                if (messageBuffer != null) {
                    messageBuffer.release();
                }
                chunks.clear();
            }
        });
    }

    /**
     * Returns the path component of {@code str}.
     *
     * <p>If the URL has no path component (an opaque URI), or cannot be parsed, the URL itself is
     * returned. A {@code null} input returns {@code null}.
     */
    private String pathOrUrl(String str) {
        if (str == null) {
            return null;
        }

        try {
            String path = new URI(str).parseServerAuthority().getPath();
            return path != null ? path : str;
        } catch (URISyntaxException e) {
            logger.warn("Endpoint URL '{}' is not a valid URI: {}", str, e.getMessage(), e);
            return str;
        }
    }

    /**
     * Issues or renews the channel security token for {@code request} and sends the response.
     *
     * <p>The requested lifetime is clamped to
     * [{@code SecureChannelLifetimeMin}, {@code SecureChannelLifetimeMax}]. When symmetric signing is
     * enabled, the client nonce is checked, a local nonce is generated, and new key material is
     * derived.
     *
     * @param ctx       the channel context.
     * @param request   the decoded request.
     * @param requestId the request id from the chunk decoder, echoed in the response.
     * @throws UaException if a Renew changes the MessageSecurityMode or the client nonce is missing or
     *                     too short.
     */
    private void installSecurityToken(ChannelHandlerContext ctx, OpenSecureChannelRequest request, long requestId)
            throws UaException {

        SecurityTokenRequestType requestType = request.getRequestType();

        if (requestType == SecurityTokenRequestType.Issue) {
            secureChannel.setMessageSecurityMode(request.getSecurityMode());
        } else if (requestType == SecurityTokenRequestType.Renew
                && secureChannel.getMessageSecurityMode() != request.getSecurityMode()) {
            throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                "secure channel renewal requested a different MessageSecurityMode.");
        }

        long revisedLifetime = Math.max(SecureChannelLifetimeMin,
            Math.min(SecureChannelLifetimeMax, request.getRequestedLifetime().longValue()));

        ChannelSecurityToken token = new ChannelSecurityToken(
            Unsigned.uint(secureChannel.getChannelId()),
            Unsigned.uint(server.nextTokenId()),
            DateTime.now(),
            Unsigned.uint(revisedLifetime));

        ChannelSecurity.SecuritySecrets newSecrets = null;

        if (secureChannel.isSymmetricSigningEnabled()) {
            SecurityAlgorithm algorithm = secureChannel.getSecurityPolicy().getSymmetricEncryptionAlgorithm();
            ByteString clientNonce = request.getClientNonce();

            if (clientNonce == null || clientNonce.isNull()) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed, "remote nonce must be non-null");
            }

            int nonceLength = NonceUtil.getNonceLength(algorithm);

            if (clientNonce.length() < nonceLength) {
                throw new UaException(StatusCodes.Bad_SecurityChecksFailed,
                    String.format("remote nonce length must be at least %d bytes", nonceLength));
            }

            secureChannel.setLocalNonce(NonceUtil.generateNonce(nonceLength));
            secureChannel.setRemoteNonce(clientNonce);
            newSecrets = ChannelSecurity.generateKeyPair(
                secureChannel, secureChannel.getRemoteNonce(), secureChannel.getLocalNonce());
        }

        // The previous ChannelSecurity's current keys and token are passed into the new instance.
        // NOTE(review): presumably this keeps messages secured with the old token valid during renewal.
        ChannelSecurity oldSecurity = secureChannel.getChannelSecurity();
        secureChannel.setChannelSecurity(new ChannelSecurity(
            newSecrets,
            token,
            oldSecurity != null ? oldSecurity.getCurrentKeys() : null,
            oldSecurity != null ? oldSecurity.getCurrentToken() : null));

        ResponseHeader responseHeader = new ResponseHeader(
            DateTime.now(),
            request.getRequestHeader().getRequestHandle(),
            StatusCode.GOOD,
            null, null, null);

        OpenSecureChannelResponse response = new OpenSecureChannelResponse(
            responseHeader,
            Unsigned.uint(0L),
            token,
            secureChannel.getLocalNonce());

        sendOpenSecureChannelResponse(ctx, requestId, response);
    }

    /**
     * Encodes and writes {@code response} on the serialization queue.
     *
     * <p>Before the first response is written, a {@link UaTcpServerSymmetricHandler} is added to the
     * front of the pipeline. After the write, the server is notified of the issued or renewed token.
     * An encoding failure is logged and closes the connection. The encode buffer is always released.
     */
    private void sendOpenSecureChannelResponse(ChannelHandlerContext ctx, long requestId, OpenSecureChannelResponse response) {
        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.buffer();

            try {
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.encodeMessage(null, response);

                List<ByteBuf> chunks = chunkEncoder.encodeAsymmetric(
                    secureChannel, MessageType.OpenSecureChannel, messageBuffer, requestId);

                if (!symmetricHandlerAdded) {
                    ctx.pipeline().addFirst(new UaTcpServerSymmetricHandler(server, serializationQueue, secureChannel));
                    symmetricHandlerAdded = true;
                }

                chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
                ctx.flush();

                server.secureChannelIssuedOrRenewed(
                    secureChannel, response.getSecurityToken().getRevisedLifetime().longValue());

                logger.debug("Sent OpenSecureChannelResponse.");
            } catch (UaException e) {
                logger.error("Error encoding OpenSecureChannelResponse: {}", e.getMessage(), e);
                ctx.close();
            } finally {
                messageBuffer.release();
            }
        });
    }

    /** Releases every buffered chunk and forgets the security header of the message in progress. */
    private void discardChunks() {
        chunkBuffers.forEach(ByteBuf::release);
        chunkBuffers.clear();
        headerRef.set(null);
    }

    /**
     * Discards buffered chunks and reports {@code th}.
     *
     * <p>An {@link IOException} just closes the connection. Anything else is sent to the peer as an
     * ErrorMessage and logged: at debug level for {@link UaException}, otherwise at error level.
     */
    @Override // io.netty.channel.ChannelInboundHandlerAdapter, io.netty.channel.ChannelHandlerAdapter, io.netty.channel.ChannelHandler, io.netty.channel.ChannelInboundHandler
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable th) throws Exception {
        discardChunks();

        if (th instanceof IOException) {
            ctx.close();
            logger.debug("[remote={}] IOException caught; channel closed", ctx.channel().remoteAddress());
            return;
        }

        ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, th);

        if (th instanceof UaException) {
            logger.debug("[remote={}] UaException caught; sent {}",
                ctx.channel().remoteAddress(), errorMessage, th);
        } else {
            logger.error("[remote={}] Exception caught; sent {}",
                ctx.channel().remoteAddress(), errorMessage, th);
        }
    }
}
