package neural;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import weka.core.json.JSONInstances;

/* loaded from: input_file:neural/NeuralNetwork.class */
/**
 * A layered feed-forward neural network composed of {@link NeuronLayer}s.
 *
 * <p>{@code neuronLayers.get(0)} is the input layer, the last element is the output layer, and
 * every layer in between is treated as a hidden layer. Layers built through {@link #createLayer}
 * receive {@link #neuronId} as their starting id, and {@link #createLayer} then advances it.
 *
 * <p>No {@code serialVersionUID} is declared, so the JVM derives one from the class structure;
 * adding or changing non-private members changes it.
 */
public class NeuralNetwork implements Serializable {
    /** Layers in feed-forward order: input first, output last, hidden layers in between. */
    private ArrayList<NeuronLayer> neuronLayers;
    /** Not used anywhere in this class; transient, so it is never serialized. */
    private transient BufferedWriter out;
    /** Id handed to the next layer created through {@link #createLayer}. */
    public int neuronId;

    /**
     * Creates a two-layer network: an input layer with a linear activation and an output layer.
     *
     * @param i number of input neurons
     * @param i2 number of output neurons
     * @param i3 extra neurons added to the input layer — presumably bias neurons (input patterns
     *     are checked against {@code size() - biasNumber()}); TODO confirm
     * @param iActivationFunction activation function of the output layer
     * @param z when {@code true}, every input neuron is connected to every output neuron with a
     *     generated weight
     * @throws Exception if a layer cannot be created or connected; the original exception is
     *     attached as the cause
     */
    public NeuralNetwork(int i, int i2, int i3, IActivationFunction iActivationFunction, boolean z) throws Exception {
        try {
            this.neuronId = 0;
            this.neuronLayers = new ArrayList<>();
            NeuronLayer input = createLayer(i, LayerType.input, this.neuronId, i3, new ActivationFunctionLinear());
            NeuronLayer output = createLayer(i2, LayerType.output, this.neuronId, 0, iActivationFunction);
            if (z) {
                fullyConnectLayers(input, output, true);
            }
            this.neuronLayers.add(input);
            this.neuronLayers.add(output);
        } catch (Exception e) {
            throw new Exception("NeuralNetwork: NeuralNetwork -> " + e.getMessage(), e);
        }
    }

    /**
     * Builds a new layer and advances {@link #neuronId} by {@code i + i3}.
     *
     * @param i number of regular neurons in the layer
     * @param layerType role of the layer (input, hidden or output)
     * @param i2 id passed to the {@code NeuronLayer} constructor — presumably the id of its first
     *     neuron; callers in this class pass {@link #neuronId}
     * @param i3 extra neurons counted towards the id range (presumably bias neurons)
     * @param iActivationFunction activation function for the layer's neurons
     * @return the new layer; it is NOT added to this network
     * @throws Exception whatever the {@code NeuronLayer} constructor throws
     */
    public NeuronLayer createLayer(int i, LayerType layerType, int i2, int i3, IActivationFunction iActivationFunction) throws Exception {
        NeuronLayer neuronLayer = new NeuronLayer(i, layerType, i2, i3, iActivationFunction);
        this.neuronId += i + i3;
        return neuronLayer;
    }

    /**
     * Connects every neuron of {@code neuronLayer} to every neuron of {@code neuronLayer2}.
     *
     * @param z forwarded to {@link #connectNeurons(Neuron, Neuron, boolean)}: whether each new
     *     synapse gets a generated weight
     */
    public void fullyConnectLayers(NeuronLayer neuronLayer, NeuronLayer neuronLayer2, boolean z) throws Exception {
        for (Neuron source : neuronLayer.neuronList()) {
            for (Neuron destination : neuronLayer2.neuronList()) {
                connectNeurons(source, destination, z);
            }
        }
    }

    /**
     * Inserts an already-built layer at position {@code i} without creating any synapses and
     * without touching {@link #neuronId}.
     *
     * @return {@code false} (and nothing is inserted) if {@code i} is not strictly between the
     *     input and output positions, or if the layer's type is input or output
     */
    public boolean addHiddenLayer(NeuronLayer neuronLayer, int i) {
        if (i < 1 || i >= this.neuronLayers.size() || neuronLayer.type() == LayerType.input || neuronLayer.type() == LayerType.output) {
            return false;
        }
        this.neuronLayers.add(i, neuronLayer);
        return true;
    }

    /**
     * Creates a hidden layer of {@code i} neurons, fully connects it to the layers on either side
     * of position {@code i2}, and inserts it there.
     *
     * <p>NOTE(review): synapses that already link the two neighbouring layers are left in place —
     * confirm that is intended.
     *
     * @param i number of neurons in the new layer
     * @param iActivationFunction activation function for the new layer
     * @param i2 insertion index; must satisfy {@code 1 <= i2 < size()}
     * @return {@code true} on success; {@code false} for an invalid index or if creating or
     *     connecting the layer throws
     */
    public boolean addHiddenLayer(int i, IActivationFunction iActivationFunction, int i2) {
        if (i2 < 1 || i2 >= this.neuronLayers.size()) {
            return false;
        }
        try {
            // Going through createLayer advances neuronId, so successive hidden layers get distinct
            // ids; constructing NeuronLayer directly left neuronId unchanged and reused the ids.
            NeuronLayer hidden = createLayer(i, LayerType.hidden, this.neuronId, 0, iActivationFunction);
            fullyConnectLayers(this.neuronLayers.get(i2 - 1), hidden, true);
            fullyConnectLayers(hidden, this.neuronLayers.get(i2), true);
            this.neuronLayers.add(i2, hidden);
            return true;
        } catch (Exception e) {
            // NOTE(review): synapses created before the failure are presumably still attached to
            // the neighbouring neurons; they are not rolled back here.
            return false;
        }
    }

    /** Adds a hidden layer of {@code i} neurons directly in front of the output layer. */
    public boolean addHiddenLayer(int i, IActivationFunction iActivationFunction) {
        return addHiddenLayer(i, iActivationFunction, this.neuronLayers.size() - 1);
    }

    /**
     * Creates a synapse from {@code neuron} to {@code neuron2}, optionally generating its weight.
     * The synapse is not stored here; presumably its constructor registers it on both neurons
     * ({@link #synapses()} later reads {@code incomingSynapses()}) — confirm.
     */
    public void connectNeurons(Neuron neuron, Neuron neuron2, boolean z) {
        Synapse synapse = new Synapse(neuron, neuron2);
        if (z) {
            synapse.generateWeight();
        }
    }

    /** Creates a synapse from {@code neuron} to {@code neuron2} with the explicit weight {@code d}. */
    public void connectNeurons(Neuron neuron, Neuron neuron2, double d) {
        new Synapse(neuron, neuron2, d);
    }

    /** Returns the first layer. */
    public NeuronLayer inputLayer() {
        return this.neuronLayers.get(0);
    }

    /** Returns the last layer. */
    public NeuronLayer outputLayer() {
        return this.neuronLayers.get(this.neuronLayers.size() - 1);
    }

    /** Returns a new list holding every layer between the input and the output layer. */
    public ArrayList<NeuronLayer> hiddenLayers() {
        return new ArrayList<>(this.neuronLayers.subList(1, this.neuronLayers.size() - 1));
    }

    /** Returns the internal, mutable layer list (not a copy). */
    public ArrayList<NeuronLayer> layers() {
        return this.neuronLayers;
    }

    /** Returns the number of layers, including input and output. */
    public int size() {
        return this.neuronLayers.size();
    }

    /** Collects the incoming synapses of every neuron in every layer into a new list. */
    public ArrayList<Synapse> synapses() {
        ArrayList<Synapse> result = new ArrayList<>();
        for (NeuronLayer layer : this.neuronLayers) {
            for (Neuron neuron : layer.neuronList()) {
                result.addAll(neuron.incomingSynapses());
            }
        }
        return result;
    }

    /**
     * Forward pass: for every non-input layer, in order, recomputes each neuron's net input,
     * output and derivative. The input layer's outputs must already be set (see
     * {@link #injectInput}).
     */
    public void bubbleThrough() {
        // Index 1 .. size-1 is exactly "hidden layers followed by the output layer".
        for (int l = 1; l < this.neuronLayers.size(); l++) {
            for (Neuron neuron : this.neuronLayers.get(l).neuronList()) {
                neuron.calculateNetInput();
                neuron.calculateOutput();
                neuron.calculateDerivative();
            }
        }
    }

    /** Sets the current slope of every synapse in the network to 0. */
    public void resetSlopes() {
        resetSlopes(synapses());
    }

    /** Sets the current slope of each given synapse to 0. */
    public void resetSlopes(ArrayList<Synapse> arrayList) {
        for (Synapse synapse : arrayList) {
            synapse.setCurrentSlope(0.0d);
        }
    }

    /** Sets the current delta of every neuron in every layer to 0. */
    public void resetDeltas() {
        for (NeuronLayer layer : this.neuronLayers) {
            for (Neuron neuron : layer.neuronList()) {
                neuron.setCurrentDelta(0.0d);
            }
        }
    }

    /**
     * Runs every pattern of {@code trainingSet} through the network.
     *
     * @param i width of the initially allocated rows; each row is replaced by the array returned
     *     from {@link #extractOutput(TrainingPattern)}, so it only matters for rows never reached
     * @return one output row per pattern, in iteration order
     */
    public double[][] extractOutput(TrainingSet trainingSet, int i) throws Exception {
        double[][] outputs = new double[trainingSet.size()][i];
        int row = 0;
        for (TrainingPattern pattern : trainingSet.getTraningSet()) {
            outputs[row++] = extractOutput(pattern);
        }
        return outputs;
    }

    /**
     * Injects the pattern's inputs, runs a forward pass and returns the output-layer values.
     *
     * @throws Exception if the input pattern size does not match the input layer (previously the
     *     mismatch was ignored and stale inputs were propagated), or if the forward pass fails;
     *     the underlying exception is attached as the cause
     */
    public double[] extractOutput(TrainingPattern trainingPattern) throws Exception {
        try {
            if (!injectInput(trainingPattern.getInputPattern())) {
                throw new Exception("number of input pattern values doesn't match the input layer");
            }
            bubbleThrough();
            return extractOutput();
        } catch (Exception e) {
            throw new Exception("NeuralNetwork: extractOutput -> " + e.getMessage(), e);
        }
    }

    /** Returns the current output of each output-layer neuron, in layer order. */
    public double[] extractOutput() {
        NeuronLayer output = outputLayer();
        double[] values = new double[output.size()];
        int k = 0;
        for (Neuron neuron : output.neuronList()) {
            values[k++] = neuron.currentOutput();
        }
        return values;
    }

    /**
     * Returns {@code (output - d) * derivative} for output neuron {@code i}, where the derivative
     * is whatever {@code calculateDerivative()} returns for that neuron.
     *
     * @param d desired output for neuron {@code i}
     */
    public double calculateOutputResidualError(double d, int i) throws Exception {
        Neuron neuron = outputLayer().getNeuron(i);
        return (neuron.currentOutput() - d) * neuron.calculateDerivative();
    }

    /**
     * Returns the sum over output neurons of {@code (output - desired)^2} for one pattern, after
     * injecting its inputs and running a forward pass.
     *
     * @throws Exception if the pattern's input or desired-output count does not match the network
     */
    public double calculateSquaredError(TrainingPattern trainingPattern) throws Exception {
        Pattern desiredOutputs = trainingPattern.getDesiredOutputs();
        // Validate before running the (comparatively expensive) forward pass.
        if (!checkPatternOutputNumber(desiredOutputs)) {
            throw new Exception("NeuralNetwork: calculateSquaredError: number of training pattern outputs doesn't match the network");
        }
        if (!injectInput(trainingPattern.getInputPattern())) {
            throw new Exception("NeuralNetwork: calculateSquaredError: number of training pattern inputs doesn't match the network");
        }
        bubbleThrough();
        double sum = 0.0d;
        int k = 0;
        for (Neuron neuron : outputLayer().neuronList()) {
            double diff = neuron.currentOutput() - desiredOutputs.get(k++);
            sum += diff * diff;
        }
        return sum;
    }

    /** Returns the sum of {@link #calculateSquaredError(TrainingPattern)} over all patterns. */
    public double calculateSquaredError(TrainingSet trainingSet) throws Exception {
        double total = 0.0d;
        for (TrainingPattern pattern : trainingSet.getTraningSet()) {
            total += calculateSquaredError(pattern);
        }
        return total;
    }

    /**
     * Returns the summed squared error divided by the number of patterns. An empty training set
     * yields {@code NaN} (0.0 / 0).
     */
    public double calculateMeanSquaredError(TrainingSet trainingSet) throws Exception {
        return calculateSquaredError(trainingSet) / trainingSet.size();
    }

    /** Calls {@code storeCurrentSlope()} on every synapse in the network. */
    public void storeLastSlope() {
        storeLastSlope(synapses());
    }

    /** Calls {@code storeCurrentSlope()} on each given synapse. */
    public void storeLastSlope(ArrayList<Synapse> arrayList) {
        for (Synapse synapse : arrayList) {
            synapse.storeCurrentSlope();
        }
    }

    /** Checks that the first, last and all middle layers have the input, output and hidden types. */
    private boolean checkNeuralNetworkCorrectness() {
        return checkInputLayer() && checkOutputLayer() && checkHiddenLayers();
    }

    private boolean checkInputLayer() {
        return this.neuronLayers.get(0).type() == LayerType.input;
    }

    private boolean checkOutputLayer() {
        return this.neuronLayers.get(this.neuronLayers.size() - 1).type() == LayerType.output;
    }

    private boolean checkHiddenLayers() {
        for (NeuronLayer layer : hiddenLayers()) {
            if (!checkHiddenLayer(layer)) {
                return false;
            }
        }
        return true;
    }

    private boolean checkHiddenLayer(NeuronLayer neuronLayer) {
        return neuronLayer.type() == LayerType.hidden;
    }

    /** True when the pattern has one value per non-bias input neuron. */
    private boolean checkPatternInputNumber(Pattern pattern) {
        return pattern.size() == inputLayer().size() - inputLayer().biasNumber();
    }

    /**
     * Writes the pattern's values into the outputs of the first {@code pattern.size()} input
     * neurons.
     *
     * @return {@code false}, without modifying any neuron, if the pattern size does not match the
     *     number of non-bias input neurons; {@code true} otherwise
     * @throws Exception if setting a neuron's output fails; the original exception is the cause
     */
    public boolean injectInput(Pattern pattern) throws Exception {
        if (!checkPatternInputNumber(pattern)) {
            return false;
        }
        NeuronLayer inputLayer = inputLayer();
        for (int i = 0; i < pattern.size(); i++) {
            try {
                inputLayer.getNeuron(i).setOutput(pattern.get(i));
            } catch (Exception e) {
                throw new Exception("NeuralNetwork: injectInput -> " + e.getMessage(), e);
            }
        }
        return true;
    }

    /** True when the pattern has one value per output neuron. */
    private boolean checkPatternOutputNumber(Pattern pattern) {
        return pattern.size() == outputLayer().size();
    }

    /** Prints {@code "E = <summed squared error>"} to standard output. */
    public void printError(TrainingSet trainingSet) throws Exception {
        try {
            System.out.println("E = " + calculateSquaredError(trainingSet));
        } catch (Exception e) {
            throw new Exception("NeuralNetwork: printError -> " + e.getMessage(), e);
        }
    }

    /**
     * Dumps the layer structure, neuron ids and outgoing synapse state to {@code str}.
     * Best-effort: any failure is reported on standard output and otherwise ignored.
     *
     * @param z {@code true} to append to the file, {@code false} to overwrite it
     */
    public void printNetworkToFile(String str, boolean z) {
        // try-with-resources closes the file even when a write fails part-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(str, z))) {
            writer.write("number of layers: " + this.neuronLayers.size());
            writer.newLine();
            int layerIndex = 0;
            for (NeuronLayer layer : this.neuronLayers) {
                writer.write("layer " + layerIndex++ + ":  " + layer.size() + " neurons");
                writer.newLine();
                for (Neuron neuron : layer.neuronList()) {
                    writer.write("    neuron " + neuron.id() + JSONInstances.SPARSE_SEPARATOR);
                    writer.newLine();
                    for (Synapse synapse : neuron.outgoingSynapses()) {
                        writer.write("             === (" + synapse.weight() + " / " + synapse.lastWeightChange() + " / " + synapse.currentSlope() + " / " + synapse.stepSize() + ")==> neuron " + synapse.destinationNeuron().id());
                        writer.newLine();
                        writer.write("             slope = " + synapse.currentSlope());
                        writer.write("    stepsize = " + synapse.stepSize());
                        writer.newLine();
                        writer.write("            delta w = " + synapse.lastWeightChange());
                        writer.write("    w = " + synapse.weight());
                        writer.newLine();
                    }
                }
                writer.write(".................................................");
                writer.newLine();
            }
            writer.write("======================================================");
            writer.newLine();
            writer.newLine();
        } catch (Exception e) {
            System.out.println("NeuralNetwork: printNetwork: " + e.getMessage());
        }
    }

    /**
     * Writes {@code "E = <summed squared error over the set>"} as one line to {@code str}.
     * The error is computed before the file is opened, so a failed computation leaves the file
     * untouched.
     *
     * @param z {@code true} to append, {@code false} to overwrite
     * @throws Exception wrapping any computation or I/O failure (original exception as cause)
     */
    public void printErrorToFile(String str, TrainingSet trainingSet, boolean z) throws Exception {
        try {
            writeErrorLine(str, calculateSquaredError(trainingSet), z);
        } catch (Exception e) {
            throw new Exception("NeuralNetwork: printErrorToFile -> " + e.getMessage(), e);
        }
    }

    /**
     * Writes {@code "E = <squared error for the pattern>"} as one line to {@code str}.
     * The error is computed before the file is opened, so a failed computation leaves the file
     * untouched.
     *
     * @param z {@code true} to append, {@code false} to overwrite
     * @throws Exception wrapping any computation or I/O failure (original exception as cause)
     */
    public void printErrorToFile(String str, TrainingPattern trainingPattern, boolean z) throws Exception {
        try {
            writeErrorLine(str, calculateSquaredError(trainingPattern), z);
        } catch (Exception e) {
            throw new Exception("NeuralNetwork: printErrorToFile -> " + e.getMessage(), e);
        }
    }

    /** Writes the single line {@code "E = <error>"} to {@code path}, always closing the file. */
    private static void writeErrorLine(String path, double error, boolean append) throws java.io.IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(path, append))) {
            writer.write("E = " + error);
            writer.newLine();
        }
    }

    /** Placeholder: currently does nothing. */
    public void printOutputToFile(String str, boolean z) {
    }
}
