package com.intellijava.core.controller;

import com.intellijava.core.model.CohereLanguageResponse;
import com.intellijava.core.model.OpenaiLanguageResponse;
import com.intellijava.core.model.SupportedLangModels;
import com.intellijava.core.model.input.LanguageModelInput;
import com.intellijava.core.wrappers.CohereAIWrapper;
import com.intellijava.core.wrappers.OpenAIWrapper;
import com.rapidminer.extension.generative_ai.operator.TextGenerationOperator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:com/intellijava/core/controller/RemoteLanguageModel.class */
public class RemoteLanguageModel {
    private SupportedLangModels keyType;
    private OpenAIWrapper openaiWrapper;
    private CohereAIWrapper cohereWrapper;

    public RemoteLanguageModel(String str, String str2) {
        str2 = str2.isEmpty() ? SupportedLangModels.openai.toString() : str2;
        List<String> supportedModels = getSupportedModels();
        if (!supportedModels.contains(str2)) {
            throw new IllegalArgumentException("The received keyValue not supported. Send any model from: " + String.join(" - ", supportedModels));
        }
        initiate(str, SupportedLangModels.valueOf(str2));
    }

    public RemoteLanguageModel(String str, SupportedLangModels supportedLangModels) {
        initiate(str, supportedLangModels);
    }

    public List<String> getSupportedModels() {
        SupportedLangModels[] values = SupportedLangModels.values();
        ArrayList arrayList = new ArrayList();
        for (SupportedLangModels supportedLangModels : values) {
            arrayList.add(supportedLangModels.name());
        }
        return arrayList;
    }

    private void initiate(String str, SupportedLangModels supportedLangModels) {
        this.keyType = supportedLangModels;
        if (supportedLangModels.equals(SupportedLangModels.openai)) {
            this.openaiWrapper = new OpenAIWrapper(str);
        } else if (supportedLangModels.equals(SupportedLangModels.cohere)) {
            this.cohereWrapper = new CohereAIWrapper(str);
        }
    }

    public String generateText(LanguageModelInput languageModelInput) throws IOException {
        if (this.keyType.equals(SupportedLangModels.openai)) {
            return generateOpenaiText(languageModelInput.getModel(), languageModelInput.getPrompt(), languageModelInput.getTemperature(), languageModelInput.getMaxTokens(), languageModelInput.getNumberOfOutputs()).get(0);
        }
        if (this.keyType.equals(SupportedLangModels.cohere)) {
            return generateCohereText(languageModelInput.getModel(), languageModelInput.getPrompt(), languageModelInput.getTemperature(), languageModelInput.getMaxTokens(), languageModelInput.getNumberOfOutputs()).get(0);
        }
        throw new IllegalArgumentException("the keyType not supported");
    }

    public List<String> generateMultiText(LanguageModelInput languageModelInput) throws IOException {
        if (this.keyType.equals(SupportedLangModels.openai)) {
            return generateOpenaiText(languageModelInput.getModel(), languageModelInput.getPrompt(), languageModelInput.getTemperature(), languageModelInput.getMaxTokens(), languageModelInput.getNumberOfOutputs());
        }
        if (this.keyType.equals(SupportedLangModels.cohere)) {
            return generateCohereText(languageModelInput.getModel(), languageModelInput.getPrompt(), languageModelInput.getTemperature(), languageModelInput.getMaxTokens(), languageModelInput.getNumberOfOutputs());
        }
        throw new IllegalArgumentException("the keyType not supported");
    }

    private List<String> generateOpenaiText(String str, String str2, float f, int i, int i2) throws IOException {
        if (str.equals("")) {
            str = "text-davinci-003";
        }
        HashMap hashMap = new HashMap();
        hashMap.put(TextGenerationOperator.PARAMETER_MODEL, str);
        hashMap.put("prompt", str2);
        hashMap.put("temperature", Float.valueOf(f));
        hashMap.put(TextGenerationOperator.PARAMETER_MAX_TOKENS, Integer.valueOf(i));
        hashMap.put("n", Integer.valueOf(i2));
        OpenaiLanguageResponse openaiLanguageResponse = (OpenaiLanguageResponse) this.openaiWrapper.generateText(hashMap);
        ArrayList arrayList = new ArrayList();
        Iterator<OpenaiLanguageResponse.Choice> it = openaiLanguageResponse.getChoices().iterator();
        while (it.hasNext()) {
            arrayList.add(it.next().getText());
        }
        return arrayList;
    }

    private List<String> generateCohereText(String str, String str2, float f, int i, int i2) throws IOException {
        if (str.equals("")) {
            str = "xlarge";
        }
        HashMap hashMap = new HashMap();
        hashMap.put(TextGenerationOperator.PARAMETER_MODEL, str);
        hashMap.put("prompt", str2);
        hashMap.put("temperature", Float.valueOf(f));
        hashMap.put(TextGenerationOperator.PARAMETER_MAX_TOKENS, Integer.valueOf(i));
        hashMap.put("num_generations", Integer.valueOf(i2));
        CohereLanguageResponse cohereLanguageResponse = (CohereLanguageResponse) this.cohereWrapper.generateText(hashMap);
        ArrayList arrayList = new ArrayList();
        Iterator<CohereLanguageResponse.Generation> it = cohereLanguageResponse.getGenerations().iterator();
        while (it.hasNext()) {
            arrayList.add(it.next().getText());
        }
        return arrayList;
    }
}
