/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.service.UserName;
import dev.langchain4j.service.V;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
import java.io.InputStream;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Proxy;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class DefaultAiServices<T>
extends AiServices<T> {
    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
    private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

    DefaultAiServices(AiServiceContext context) {
        super(context);
    }

    static void validateParameters(Method method) {
        Parameter[] parameters = method.getParameters();
        if (parameters == null || parameters.length < 2) {
            return;
        }
        for (Parameter parameter : parameters) {
            V v = parameter.getAnnotation(V.class);
            UserMessage userMessage = parameter.getAnnotation(UserMessage.class);
            MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
            UserName userName = parameter.getAnnotation(UserName.class);
            if (v != null || userMessage != null || memoryId != null || userName != null) continue;
            throw IllegalConfigurationException.illegalConfiguration("Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId", parameter.getName(), method.getName());
        }
    }

    @Override
    public T build() {
        this.performBasicValidation();
        for (Method method : this.context.aiServiceClass.getMethods()) {
            if (method.isAnnotationPresent(Moderate.class) && this.context.moderationModel == null) {
                throw IllegalConfigurationException.illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
            }
            if (method.getReturnType() != Result.class && method.getReturnType() != List.class && method.getReturnType() != Set.class) continue;
            TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
        }
        Object proxyInstance = Proxy.newProxyInstance(this.context.aiServiceClass.getClassLoader(), new Class[]{this.context.aiServiceClass}, new InvocationHandler(){
            private final ExecutorService executor = Executors.newCachedThreadPool();

            @Override
            public Object invoke(Object proxy, Method method, Object[] args2) throws Exception {
                List<Object> messages;
                if (method.getDeclaringClass() == Object.class) {
                    return method.invoke((Object)this, args2);
                }
                DefaultAiServices.validateParameters(method);
                String memoryId = DefaultAiServices.findMemoryId(method, args2).orElse("default");
                Optional systemMessage = DefaultAiServices.this.prepareSystemMessage(memoryId, method, args2);
                dev.langchain4j.data.message.UserMessage userMessage = DefaultAiServices.prepareUserMessage(method, args2);
                AugmentationResult augmentationResult = null;
                if (DefaultAiServices.this.context.retrievalAugmentor != null) {
                    List<ChatMessage> chatMemory = DefaultAiServices.this.context.hasChatMemory() ? DefaultAiServices.this.context.chatMemory(memoryId).messages() : null;
                    Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
                    AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
                    augmentationResult = DefaultAiServices.this.context.retrievalAugmentor.augment(augmentationRequest);
                    userMessage = (dev.langchain4j.data.message.UserMessage)augmentationResult.chatMessage();
                }
                Type returnType = method.getGenericReturnType();
                String outputFormatInstructions = DefaultAiServices.this.serviceOutputParser.outputFormatInstructions(returnType);
                String text = userMessage.singleText() + outputFormatInstructions;
                userMessage = Utils.isNotNullOrBlank(userMessage.name()) ? dev.langchain4j.data.message.UserMessage.from(userMessage.name(), text) : dev.langchain4j.data.message.UserMessage.from(text);
                if (DefaultAiServices.this.context.hasChatMemory()) {
                    ChatMemory chatMemory = DefaultAiServices.this.context.chatMemory(memoryId);
                    systemMessage.ifPresent(chatMemory::add);
                    chatMemory.add(userMessage);
                }
                if (DefaultAiServices.this.context.hasChatMemory()) {
                    messages = DefaultAiServices.this.context.chatMemory(memoryId).messages();
                } else {
                    messages = new ArrayList();
                    systemMessage.ifPresent(messages::add);
                    messages.add(userMessage);
                }
                Future<Moderation> moderationFuture = this.triggerModerationIfNeeded(method, messages);
                if (returnType == TokenStream.class) {
                    return new AiServiceTokenStream(messages, DefaultAiServices.this.context, memoryId);
                }
                Response<AiMessage> response = DefaultAiServices.this.context.toolSpecifications == null ? DefaultAiServices.this.context.chatModel.generate(messages) : DefaultAiServices.this.context.chatModel.generate(messages, DefaultAiServices.this.context.toolSpecifications);
                TokenUsage tokenUsageAccumulator = response.tokenUsage();
                AiServices.verifyModerationIfNeeded(moderationFuture);
                int executionsLeft = 10;
                while (true) {
                    if (executionsLeft-- == 0) {
                        throw Exceptions.runtime("Something is wrong, exceeded %s sequential tool executions", 10);
                    }
                    AiMessage aiMessage = response.content();
                    if (DefaultAiServices.this.context.hasChatMemory()) {
                        DefaultAiServices.this.context.chatMemory(memoryId).add(aiMessage);
                    } else {
                        messages = new ArrayList<ChatMessage>(messages);
                        messages.add(aiMessage);
                    }
                    if (!aiMessage.hasToolExecutionRequests()) break;
                    for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                        ToolExecutor toolExecutor = DefaultAiServices.this.context.toolExecutors.get(toolExecutionRequest.name());
                        String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
                        ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, toolExecutionResult);
                        if (DefaultAiServices.this.context.hasChatMemory()) {
                            DefaultAiServices.this.context.chatMemory(memoryId).add(toolExecutionResultMessage);
                            continue;
                        }
                        messages.add(toolExecutionResultMessage);
                    }
                    if (DefaultAiServices.this.context.hasChatMemory()) {
                        messages = DefaultAiServices.this.context.chatMemory(memoryId).messages();
                    }
                    response = DefaultAiServices.this.context.chatModel.generate(messages, DefaultAiServices.this.context.toolSpecifications);
                    tokenUsageAccumulator = TokenUsage.sum(tokenUsageAccumulator, response.tokenUsage());
                }
                response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
                Object parsedResponse = DefaultAiServices.this.serviceOutputParser.parse(response, returnType);
                if (TypeUtils.typeHasRawClass(returnType, Result.class)) {
                    return Result.builder().content(parsedResponse).tokenUsage(tokenUsageAccumulator).sources(augmentationResult == null ? null : augmentationResult.contents()).finishReason(response.finishReason()).build();
                }
                return parsedResponse;
            }

            private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
                if (method.isAnnotationPresent(Moderate.class)) {
                    return this.executor.submit(() -> {
                        List<ChatMessage> messagesToModerate = AiServices.removeToolMessages(messages);
                        return DefaultAiServices.this.context.moderationModel.moderate(messagesToModerate).content();
                    });
                }
                return null;
            }
        });
        return (T)proxyInstance;
    }

    private Optional<dev.langchain4j.data.message.SystemMessage> prepareSystemMessage(Object memoryId, Method method, Object[] args2) {
        return this.findSystemMessageTemplate(memoryId, method).map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate).apply(DefaultAiServices.findTemplateVariables(systemMessageTemplate, method, args2)).toSystemMessage());
    }

    private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
        SystemMessage annotation = method.getAnnotation(SystemMessage.class);
        if (annotation != null) {
            return Optional.of(DefaultAiServices.getTemplate(method, "System", annotation.fromResource(), annotation.value(), annotation.delimiter()));
        }
        return this.context.systemMessageProvider.apply(memoryId);
    }

    private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args2) {
        Parameter[] parameters = method.getParameters();
        HashMap<String, Object> variables = new HashMap<String, Object>();
        for (int i = 0; i < parameters.length; ++i) {
            V annotation = parameters[i].getAnnotation(V.class);
            if (annotation == null) continue;
            String variableName = annotation.value();
            Object variableValue = args2[i];
            variables.put(variableName, variableValue);
        }
        if (template.contains("{{it}}") && !variables.containsKey("it")) {
            String itValue = DefaultAiServices.getValueOfVariableIt(parameters, args2);
            variables.put("it", itValue);
        }
        return variables;
    }

    private static String getValueOfVariableIt(Parameter[] parameters, Object[] args2) {
        Parameter parameter;
        if (!(parameters.length != 1 || (parameter = parameters[0]).isAnnotationPresent(MemoryId.class) || parameter.isAnnotationPresent(UserMessage.class) || parameter.isAnnotationPresent(UserName.class) || parameter.isAnnotationPresent(V.class) && !DefaultAiServices.isAnnotatedWithIt(parameter))) {
            return DefaultAiServices.toString(args2[0]);
        }
        for (int i = 0; i < parameters.length; ++i) {
            if (!DefaultAiServices.isAnnotatedWithIt(parameters[i])) continue;
            return DefaultAiServices.toString(args2[i]);
        }
        throw IllegalConfigurationException.illegalConfiguration("Error: cannot find the value of the prompt template variable \"{{it}}\".");
    }

    private static boolean isAnnotatedWithIt(Parameter parameter) {
        V annotation = parameter.getAnnotation(V.class);
        return annotation != null && "it".equals(annotation.value());
    }

    private static dev.langchain4j.data.message.UserMessage prepareUserMessage(Method method, Object[] args2) {
        String template = DefaultAiServices.getUserMessageTemplate(method, args2);
        Map<String, Object> variables = DefaultAiServices.findTemplateVariables(template, method, args2);
        Prompt prompt = PromptTemplate.from(template).apply(variables);
        Optional<String> maybeUserName = DefaultAiServices.findUserName(method.getParameters(), args2);
        return maybeUserName.map(userName -> dev.langchain4j.data.message.UserMessage.from(userName, prompt.text())).orElseGet(prompt::toUserMessage);
    }

    private static String getUserMessageTemplate(Method method, Object[] args2) {
        Optional<String> templateFromMethodAnnotation = DefaultAiServices.findUserMessageTemplateFromMethodAnnotation(method);
        Optional<String> templateFromParameterAnnotation = DefaultAiServices.findUserMessageTemplateFromAnnotatedParameter(method.getParameters(), args2);
        if (templateFromMethodAnnotation.isPresent() && templateFromParameterAnnotation.isPresent()) {
            throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' has multiple @UserMessage annotations. Please use only one.", method.getName());
        }
        if (templateFromMethodAnnotation.isPresent()) {
            return templateFromMethodAnnotation.get();
        }
        if (templateFromParameterAnnotation.isPresent()) {
            return templateFromParameterAnnotation.get();
        }
        Optional<String> templateFromTheOnlyArgument = DefaultAiServices.findUserMessageTemplateFromTheOnlyArgument(method.getParameters(), args2);
        if (templateFromTheOnlyArgument.isPresent()) {
            return templateFromTheOnlyArgument.get();
        }
        throw IllegalConfigurationException.illegalConfiguration("Error: The method '%s' does not have a user message defined.", method.getName());
    }

    private static Optional<String> findUserMessageTemplateFromMethodAnnotation(Method method) {
        return Optional.ofNullable(method.getAnnotation(UserMessage.class)).map(a -> DefaultAiServices.getTemplate(method, "User", a.fromResource(), a.value(), a.delimiter()));
    }

    private static Optional<String> findUserMessageTemplateFromAnnotatedParameter(Parameter[] parameters, Object[] args2) {
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserMessage.class)) continue;
            return Optional.of(DefaultAiServices.toString(args2[i]));
        }
        return Optional.empty();
    }

    private static Optional<String> findUserMessageTemplateFromTheOnlyArgument(Parameter[] parameters, Object[] args2) {
        if (parameters != null && parameters.length == 1 && parameters[0].getAnnotations().length == 0) {
            return Optional.of(DefaultAiServices.toString(args2[0]));
        }
        return Optional.empty();
    }

    private static Optional<String> findUserName(Parameter[] parameters, Object[] args2) {
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(UserName.class)) continue;
            return Optional.of(args2[i].toString());
        }
        return Optional.empty();
    }

    private static String getTemplate(Method method, String type, String resource, String[] value, String delimiter) {
        String messageTemplate;
        if (!resource.trim().isEmpty()) {
            messageTemplate = DefaultAiServices.getResourceText(method.getDeclaringClass(), resource);
            if (messageTemplate == null) {
                throw IllegalConfigurationException.illegalConfiguration("@%sMessage's resource '%s' not found", type, resource);
            }
        } else {
            messageTemplate = String.join((CharSequence)delimiter, value);
        }
        if (messageTemplate.trim().isEmpty()) {
            throw IllegalConfigurationException.illegalConfiguration("@%sMessage's template cannot be empty", type);
        }
        return messageTemplate;
    }

    private static String getResourceText(Class<?> clazz, String resource) {
        InputStream inputStream2 = clazz.getResourceAsStream(resource);
        if (inputStream2 == null) {
            inputStream2 = clazz.getResourceAsStream("/" + resource);
        }
        return DefaultAiServices.getText(inputStream2);
    }

    private static String getText(InputStream inputStream2) {
        if (inputStream2 == null) {
            return null;
        }
        try (Scanner scanner = new Scanner(inputStream2);){
            Scanner s = scanner.useDelimiter("\\A");
            try {
                String string;
                String string2 = string = s.hasNext() ? s.next() : "";
                if (s != null) {
                    s.close();
                }
                return string;
            }
            catch (Throwable throwable) {
                if (s != null) {
                    try {
                        s.close();
                    }
                    catch (Throwable throwable2) {
                        throwable.addSuppressed(throwable2);
                    }
                }
                throw throwable;
            }
        }
    }

    private static Optional<Object> findMemoryId(Method method, Object[] args2) {
        Parameter[] parameters = method.getParameters();
        for (int i = 0; i < parameters.length; ++i) {
            if (!parameters[i].isAnnotationPresent(MemoryId.class)) continue;
            Object memoryId = args2[i];
            if (memoryId == null) {
                throw Exceptions.illegalArgument("The value of parameter '%s' annotated with @MemoryId in method '%s' must not be null", parameters[i].getName(), method.getName());
            }
            return Optional.of(memoryId);
        }
        return Optional.empty();
    }

    private static String toString(Object arg) {
        if (arg.getClass().isArray()) {
            return DefaultAiServices.arrayToString(arg);
        }
        if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt(arg).text();
        }
        return arg.toString();
    }

    private static String arrayToString(Object arg) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(arg);
        for (int i = 0; i < length; ++i) {
            sb.append(DefaultAiServices.toString(Array.get(arg, i)));
            if (i >= length - 1) continue;
            sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}

