package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;

/* loaded from: input_file:ingrid-ibus-7.1.0-RC1/lib/x-pack-core-7.17.15.jar:org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.class */
/**
 * Output aggregator that combines the raw outputs of ensemble members into a probability vector.
 *
 * <p>Each member's output vector is scaled by that member's weight (or by {@code 1.0} when no
 * weights are configured) and the scaled vectors are summed element-wise. A multi-element sum is
 * passed through {@link Statistics#softMax}; a single-element sum {@code s} is passed through
 * {@link Statistics#sigmoid} and expanded to {@code [1 - sigmoid(s), sigmoid(s)]}.
 *
 * <p>The weights array is copied on construction and never handed out, so instances cannot be
 * mutated after they are built.
 */
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    /** Estimated shallow size of one instance, excluding the weights array (see {@link #ramBytesUsed()}). */
    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
    public static final ParseField NAME = new ParseField("logistic_regression");
    public static final ParseField WEIGHTS = new ParseField("weights");

    // createParser reads NAME and WEIGHTS, so these must stay declared after them (textual init order).
    private static final ConstructingObjectParser<LogisticRegression, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<LogisticRegression, Void> STRICT_PARSER = createParser(false);

    /** Per-member weights, or {@code null} when every member contributes with weight {@code 1.0}. */
    private final double[] weights;

    /**
     * Builds an x-content parser for this aggregator with an optional {@code weights} array.
     *
     * @param lenient {@code true} for the lenient parser, {@code false} for the strict one; passed
     *                straight to the {@link ConstructingObjectParser} constructor
     */
    private static ConstructingObjectParser<LogisticRegression, Void> createParser(boolean lenient) {
        // The parser declares WEIGHTS as a double array, so args[0] is a List<Double> (or null if absent).
        @SuppressWarnings("unchecked")
        Function<Object[], LogisticRegression> builder = args -> new LogisticRegression((List<Double>) args[0]);
        ConstructingObjectParser<LogisticRegression, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            builder
        );
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    /** Parses an instance with the strict parser. */
    public static LogisticRegression fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses an instance with the lenient parser. */
    public static LogisticRegression fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /** Creates an unweighted aggregator (every member has weight {@code 1.0}). */
    LogisticRegression() {
        this((double[]) null);
    }

    private LogisticRegression(List<Double> weights) {
        this(weights == null ? null : weights.stream().mapToDouble(Double::doubleValue).toArray());
    }

    /**
     * Creates an aggregator with the given per-member weights.
     *
     * @param weights one weight per ensemble member, or {@code null} for unweighted aggregation;
     *                the array is copied, so later changes by the caller have no effect
     */
    public LogisticRegression(double[] weights) {
        this.weights = weights == null ? null : weights.clone();
    }

    /** Deserializes an instance written by {@link #writeTo(StreamOutput)}. */
    public LogisticRegression(StreamInput in) throws IOException {
        this.weights = in.readBoolean() ? in.readDoubleArray() : null;
    }

    /** @return the number of weights, or {@code null} when unweighted */
    @Override
    public Integer expectedValueSize() {
        return weights == null ? null : Integer.valueOf(weights.length);
    }

    /**
     * Combines member outputs into a probability vector.
     *
     * @param values one output vector per ensemble member; all vectors must have the same non-zero length
     * @return the soft-max of the weighted sum when the vectors have more than one element; otherwise
     *         {@code [1 - p, p]} where {@code p} is the sigmoid of the single weighted sum
     * @throws NullPointerException     if {@code values} is {@code null}
     * @throws IllegalArgumentException if the number of vectors differs from the number of weights, if
     *                                  there are no vectors or the first one is empty, or if the vectors
     *                                  differ in length
     */
    @Override
    public double[] processValues(double[][] values) {
        Objects.requireNonNull(values, "values must not be null");
        if (weights != null && values.length != weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        // Guard the values[0] / weightedSum[0] accesses below, which would otherwise throw
        // ArrayIndexOutOfBoundsException instead of a meaningful error.
        if (values.length == 0 || values[0].length == 0) {
            throw new IllegalArgumentException("values must contain at least one non-empty entry");
        }
        double[] weightedSum = new double[values[0].length];
        for (int member = 0; member < values.length; member++) {
            double[] value = values[member];
            // Reject shorter as well as longer entries; a shorter one would silently skew the sum.
            if (value.length != weightedSum.length) {
                throw new IllegalArgumentException("value entries must have the same dimensions");
            }
            double weight = weights == null ? 1.0 : weights[member];
            for (int i = 0; i < value.length; i++) {
                weightedSum[i] += value[i] * weight;
            }
        }
        if (weightedSum.length > 1) {
            return Statistics.softMax(weightedSum);
        }
        double probability = Statistics.sigmoid(weightedSum[0]);
        assert 0.0 <= probability && probability <= 1.0 : "sigmoid out of range: " + probability;
        return new double[] { 1.0 - probability, probability };
    }

    /**
     * Returns the index of the largest element, as a {@code double}.
     *
     * <p>Ties resolve to the first index. NaN elements are never selected, because the comparison is
     * false for them. An empty or all-NaN array yields {@code 0}.
     *
     * @throws NullPointerException if {@code values} is {@code null}
     */
    @Override
    public double aggregate(double[] values) {
        Objects.requireNonNull(values, "values must not be null");
        int bestIndex = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > bestValue) {
                bestValue = values[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /** This aggregator accepts every target type. */
    @Override
    public boolean compatibleWith(TargetType targetType) {
        return true;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes a presence flag, followed by the weights only when they are non-null. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(weights != null);
        if (weights != null) {
            out.writeDoubleArray(weights);
        }
    }

    /** Renders as an object containing {@code weights} only when weights are present. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Arrays.equals(weights, ((LogisticRegression) obj).weights);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(weights);
    }

    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE + (weights == null ? 0L : RamUsageEstimator.sizeOf(weights));
    }
}
