/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.oauth2.client.endpoint;

import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.proc.SecurityContext;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JwsHeader;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

/**
 * A {@link Converter} that produces the {@code client_assertion_type} and {@code client_assertion}
 * request parameters used to authenticate a client with a signed JWT.
 *
 * <p>It only applies to client registrations configured with
 * {@link ClientAuthenticationMethod#PRIVATE_KEY_JWT} or {@link ClientAuthenticationMethod#CLIENT_SECRET_JWT}.
 *
 * <p>The signing key is obtained from the {@code jwkResolver} supplied at construction. One
 * {@link JwtEncoder} is cached per registration id and is rebuilt whenever the resolved {@link JWK}
 * differs from the one the cached encoder was built with.
 *
 * @param <T> the type of authorization grant request being converted
 */
public final class NimbusJwtClientAuthenticationParametersConverter<T extends AbstractOAuth2AuthorizationGrantRequest>
        implements Converter<T, MultiValueMap<String, String>> {

    private static final String INVALID_KEY_ERROR_CODE = "invalid_key";

    private static final String INVALID_ALGORITHM_ERROR_CODE = "invalid_algorithm";

    private static final String CLIENT_ASSERTION_TYPE_VALUE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";

    /** Lifetime of a generated client assertion, measured from its {@code iat} claim. */
    private static final Duration CLIENT_ASSERTION_TIME_TO_LIVE = Duration.ofSeconds(60L);

    private final Function<ClientRegistration, JWK> jwkResolver;

    /** Encoders keyed by registration id, each paired with the JWK it was built from. */
    private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();

    /** Hook applied to the JWS header and claims before signing; a no-op by default. */
    private Consumer<JwtClientAuthenticationContext<T>> jwtClientAssertionCustomizer = (context) -> {
    };

    /**
     * Constructs a converter that signs client assertions with the key returned by {@code jwkResolver}.
     * @param jwkResolver resolves the signing {@link JWK} for a {@link ClientRegistration}; if it
     * returns {@code null}, {@link #convert} fails with error code {@code invalid_key}
     * @throws IllegalArgumentException if {@code jwkResolver} is {@code null}
     */
    public NimbusJwtClientAuthenticationParametersConverter(Function<ClientRegistration, JWK> jwkResolver) {
        Assert.notNull(jwkResolver, "jwkResolver cannot be null");
        this.jwkResolver = jwkResolver;
    }

    /**
     * Builds the JWT client-assertion parameters for {@code authorizationGrantRequest}.
     *
     * <p>The assertion is built with these claims:
     * <ul>
     * <li>{@code iss} and {@code sub} set to the client id</li>
     * <li>{@code aud} set to the provider's token URI</li>
     * <li>a random {@code jti}</li>
     * <li>{@code iat} of now, and {@code exp} 60 seconds later</li>
     * </ul>
     * The configured customizer may modify the header and claims before the JWT is signed.
     * @param authorizationGrantRequest the grant request whose client registration is used
     * @return a map containing {@code client_assertion_type} and {@code client_assertion}, or
     * {@code null} if the registration uses neither {@code private_key_jwt} nor
     * {@code client_secret_jwt}
     * @throws IllegalArgumentException if {@code authorizationGrantRequest} is {@code null}
     * @throws OAuth2AuthorizationException with error code {@code invalid_key} if no JWK could be
     * resolved, or {@code invalid_algorithm} if no JWS algorithm could be derived from the JWK
     */
    @Override
    public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
        Assert.notNull(authorizationGrantRequest, "authorizationGrantRequest cannot be null");
        // T is bounded by AbstractOAuth2AuthorizationGrantRequest, so no cast is needed.
        ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
        ClientAuthenticationMethod clientAuthenticationMethod = clientRegistration.getClientAuthenticationMethod();
        if (!ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(clientAuthenticationMethod)
                && !ClientAuthenticationMethod.CLIENT_SECRET_JWT.equals(clientAuthenticationMethod)) {
            // Not a JWT-based client authentication method: contribute no parameters.
            return null;
        }

        JWK jwk = this.jwkResolver.apply(clientRegistration);
        if (jwk == null) {
            OAuth2Error oauth2Error = new OAuth2Error(INVALID_KEY_ERROR_CODE,
                    "Failed to resolve JWK signing key for client registration '"
                            + clientRegistration.getRegistrationId() + "'.",
                    null);
            throw new OAuth2AuthorizationException(oauth2Error);
        }

        JwsAlgorithm jwsAlgorithm = resolveAlgorithm(jwk);
        if (jwsAlgorithm == null) {
            OAuth2Error oauth2Error = new OAuth2Error(INVALID_ALGORITHM_ERROR_CODE,
                    "Unable to resolve JWS (signing) algorithm from JWK associated to client registration '"
                            + clientRegistration.getRegistrationId() + "'.",
                    null);
            throw new OAuth2AuthorizationException(oauth2Error);
        }

        JwsHeader.Builder headersBuilder = JwsHeader.with(jwsAlgorithm);

        Instant issuedAt = Instant.now();
        Instant expiresAt = issuedAt.plus(CLIENT_ASSERTION_TIME_TO_LIVE);
        // @formatter:off
        JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
                .issuer(clientRegistration.getClientId())
                .subject(clientRegistration.getClientId())
                .audience(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()))
                .id(UUID.randomUUID().toString())
                .issuedAt(issuedAt)
                .expiresAt(expiresAt);
        // @formatter:on

        // Let callers adjust the header and claims before they are frozen and signed.
        JwtClientAuthenticationContext<T> jwtClientAssertionContext = new JwtClientAuthenticationContext<>(
                authorizationGrantRequest, headersBuilder, claimsBuilder);
        this.jwtClientAssertionCustomizer.accept(jwtClientAssertionContext);

        JwsHeader jwsHeader = headersBuilder.build();
        JwtClaimsSet jwtClaimsSet = claimsBuilder.build();

        JwtEncoder jwsEncoder = resolveJwsEncoder(clientRegistration, jwk);
        Jwt jws = jwsEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet));

        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
        parameters.set("client_assertion_type", CLIENT_ASSERTION_TYPE_VALUE);
        parameters.set("client_assertion", jws.getTokenValue());
        return parameters;
    }

    /**
     * Returns the encoder cached for the client registration.
     *
     * <p>A new encoder is built when none is cached yet, or when the cached one was built from a
     * JWK that is not equal to {@code jwk}, so a changed key takes effect on the next call.
     * @param clientRegistration the registration whose id keys the cache
     * @param jwk the signing key just resolved for that registration
     * @return an encoder backed by {@code jwk}
     */
    private JwtEncoder resolveJwsEncoder(ClientRegistration clientRegistration, JWK jwk) {
        JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
                (clientRegistrationId, currentJwsEncoderHolder) -> {
                    if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
                        return currentJwsEncoderHolder;
                    }
                    ImmutableJWKSet<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
                    return new JwsEncoderHolder(new NimbusJwtEncoder(jwkSource), jwk);
                });
        return jwsEncoderHolder.getJwsEncoder();
    }

    /**
     * Derives the JWS algorithm to sign with.
     *
     * <p>If the JWK's {@code alg} names a known {@link SignatureAlgorithm} or {@link MacAlgorithm},
     * that algorithm is used. Otherwise the algorithm is chosen from the key type: {@code RS256} for
     * RSA, {@code ES256} for EC, and {@code HS256} for octet keys.
     * @param jwk the signing key
     * @return the resolved algorithm, or {@code null} if none applies
     */
    private static JwsAlgorithm resolveAlgorithm(JWK jwk) {
        // Declared as JwsAlgorithm (not the raw Enum) so that the return statement type-checks.
        JwsAlgorithm jwsAlgorithm = null;
        if (jwk.getAlgorithm() != null) {
            String algorithmName = jwk.getAlgorithm().getName();
            jwsAlgorithm = SignatureAlgorithm.from(algorithmName);
            if (jwsAlgorithm == null) {
                jwsAlgorithm = MacAlgorithm.from(algorithmName);
            }
        }
        if (jwsAlgorithm == null) {
            KeyType keyType = jwk.getKeyType();
            if (KeyType.RSA.equals(keyType)) {
                jwsAlgorithm = SignatureAlgorithm.RS256;
            }
            else if (KeyType.EC.equals(keyType)) {
                jwsAlgorithm = SignatureAlgorithm.ES256;
            }
            else if (KeyType.OCT.equals(keyType)) {
                jwsAlgorithm = MacAlgorithm.HS256;
            }
        }
        return jwsAlgorithm;
    }

    /**
     * Sets the hook invoked with the JWS header and claims builders before the client assertion is signed.
     * @param jwtClientAssertionCustomizer the customizer to apply
     * @throws IllegalArgumentException if {@code jwtClientAssertionCustomizer} is {@code null}
     */
    public void setJwtClientAssertionCustomizer(Consumer<JwtClientAuthenticationContext<T>> jwtClientAssertionCustomizer) {
        Assert.notNull(jwtClientAssertionCustomizer, "jwtClientAssertionCustomizer cannot be null");
        this.jwtClientAssertionCustomizer = jwtClientAssertionCustomizer;
    }

    /**
     * Context passed to the client-assertion customizer.
     *
     * <p>It exposes the grant request being converted, plus the mutable header and claims builders
     * that will be used to sign the assertion.
     *
     * @param <T> the type of authorization grant request
     */
    public static final class JwtClientAuthenticationContext<T extends AbstractOAuth2AuthorizationGrantRequest> {

        private final T authorizationGrantRequest;

        private final JwsHeader.Builder headers;

        private final JwtClaimsSet.Builder claims;

        private JwtClientAuthenticationContext(T authorizationGrantRequest, JwsHeader.Builder headers,
                JwtClaimsSet.Builder claims) {
            this.authorizationGrantRequest = authorizationGrantRequest;
            this.headers = headers;
            this.claims = claims;
        }

        /**
         * Returns the grant request being converted.
         * @return the authorization grant request
         */
        public T getAuthorizationGrantRequest() {
            return this.authorizationGrantRequest;
        }

        /**
         * Returns the JWS header builder; changes are reflected in the signed assertion.
         * @return the header builder
         */
        public JwsHeader.Builder getHeaders() {
            return this.headers;
        }

        /**
         * Returns the claims builder; changes are reflected in the signed assertion.
         * @return the claims builder
         */
        public JwtClaimsSet.Builder getClaims() {
            return this.claims;
        }

    }

    /** Pairs a cached encoder with the JWK it was built from, so key changes can be detected. */
    private static final class JwsEncoderHolder {

        private final JwtEncoder jwsEncoder;

        private final JWK jwk;

        private JwsEncoderHolder(JwtEncoder jwsEncoder, JWK jwk) {
            this.jwsEncoder = jwsEncoder;
            this.jwk = jwk;
        }

        private JwtEncoder getJwsEncoder() {
            return this.jwsEncoder;
        }

        private JWK getJwk() {
            return this.jwk;
        }

    }

}

