/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers.util;

import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.function.DoubleUnaryOperator;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.optimisers.util.ShrinkingTensor;

/**
 * A {@link DenseVector} whose stored values are implicitly scaled by a multiplier, so that
 * scaling every element is an O(1) update of the multiplier rather than a pass over the array.
 * <p>
 * The logical value at index {@code i} is {@code elements[i] * multiplier}. When the
 * multiplier's magnitude drops below {@code TOLERANCE} it is folded back into the stored
 * elements and reset to 1.
 * <p>
 * The squared two-norm of the logical values is cached. It is kept up to date by
 * {@link #intersectAndAddInPlace} and {@link #scaleInPlace}, so {@link #twoNorm()} is O(1).
 * <p>
 * Each call to {@link #intersectAndAddInPlace} does three things:
 * <ol>
 *   <li>Shrinks the vector by {@code 1 - baseRate / iteration} (when {@code scaleShrinking}
 *       is true) or by {@code 1 - baseRate} (otherwise).</li>
 *   <li>Adds {@code f(x)} at each index present in the other vector.</li>
 *   <li>When reprojection is enabled, rescales the result so that its two-norm is at most
 *       {@code 1 / sqrt(lambda)}.</li>
 * </ol>
 * <p>
 * This class is mutable and performs no synchronization.
 * <p>
 * NOTE(review): methods inherited from {@link DenseVector} and not overridden here may read
 * or write the raw stored elements without applying the multiplier or updating the cached
 * norm. Confirm before relying on them.
 */
public class ShrinkingVector extends DenseVector implements ShrinkingTensor {
    /** Multiplier magnitude below which the multiplier is folded into the stored elements. */
    private static final double TOLERANCE = 1.0E-6;

    private final double baseRate;
    private final boolean scaleShrinking;
    private final double lambdaSqrt;
    private final boolean reproject;
    /** Cached squared two-norm of the logical (multiplied) values. */
    private double squaredTwoNorm;
    /** 1-based count of calls to intersectAndAddInPlace, used for the 1/t shrinkage schedule. */
    private int iteration;
    /** Implicit scale applied to every stored element. */
    private double multiplier;

    /**
     * Creates a shrinking vector that never reprojects.
     *
     * @param v              the initial values
     * @param baseRate       the base shrinkage rate
     * @param scaleShrinking if true the shrinkage factor at iteration {@code t} is
     *                       {@code 1 - baseRate / t}, otherwise it is {@code 1 - baseRate}
     */
    public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = 0.0;
        this.reproject = false;
        this.squaredTwoNorm = sumOfSquares(this.elements);
        this.iteration = 1;
        this.multiplier = 1.0;
    }

    /**
     * Creates a shrinking vector with the {@code 1 - baseRate / t} shrinkage schedule.
     * After each update, it is rescaled so that its two-norm is at most {@code 1 / sqrt(lambda)}.
     *
     * @param v        the initial values
     * @param baseRate the base shrinkage rate
     * @param lambda   the parameter controlling the reprojection radius {@code 1 / sqrt(lambda)}
     */
    public ShrinkingVector(DenseVector v, double baseRate, double lambda) {
        super(v);
        this.baseRate = baseRate;
        this.scaleShrinking = true;
        this.lambdaSqrt = Math.sqrt(lambda);
        this.reproject = true;
        this.squaredTwoNorm = sumOfSquares(this.elements);
        this.iteration = 1;
        this.multiplier = 1.0;
    }

    /**
     * Internal constructor used by {@link #copy()}, which passes a freshly copied array.
     */
    private ShrinkingVector(double[] values, double baseRate, boolean scaleShrinking, double lambdaSqrt,
                            boolean reproject, double squaredTwoNorm, int iteration, double multiplier) {
        super(values);
        this.baseRate = baseRate;
        this.scaleShrinking = scaleShrinking;
        this.lambdaSqrt = lambdaSqrt;
        this.reproject = reproject;
        this.squaredTwoNorm = squaredTwoNorm;
        this.iteration = iteration;
        this.multiplier = multiplier;
    }

    /** Returns the sum of squares of {@code values}. */
    private static double sumOfSquares(double[] values) {
        double total = 0.0;
        for (double value : values) {
            total += value * value;
        }
        return total;
    }

    /**
     * Returns a plain {@link DenseVector} holding the logical (multiplied) values.
     */
    @Override
    public DenseVector convertToDense() {
        return DenseVector.createDenseVector(toArray());
    }

    /**
     * Returns a deep copy, including the multiplier, cached norm and iteration counter.
     */
    @Override
    public ShrinkingVector copy() {
        return new ShrinkingVector(Arrays.copyOf(elements, elements.length), baseRate, scaleShrinking,
                lambdaSqrt, reproject, squaredTwoNorm, iteration, multiplier);
    }

    /**
     * Returns a new array containing the logical (multiplied) values.
     */
    @Override
    public double[] toArray() {
        double[] newValues = new double[elements.length];
        for (int i = 0; i < newValues.length; i++) {
            newValues[i] = get(i);
        }
        return newValues;
    }

    /**
     * Returns the logical value at {@code index}, i.e. the stored element times the multiplier.
     */
    @Override
    public double get(int index) {
        return elements[index] * multiplier;
    }

    /**
     * Returns the sum of the logical values.
     */
    @Override
    public double sum() {
        double sum = 0.0;
        for (int i = 0; i < elements.length; i++) {
            sum += get(i);
        }
        return sum;
    }

    /**
     * Applies one shrink/update/reproject step and advances the iteration counter.
     *
     * @param other the update vector; must be an {@link SGDVector} (cast, so a different
     *              tensor type throws {@link ClassCastException})
     * @param f     applied to each value of {@code other} before it is added
     */
    @Override
    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
        double shrinkage = scaleShrinking ? 1.0 - baseRate / iteration : 1.0 - baseRate;
        // O(1): only the multiplier (and cached norm) change.
        scaleInPlace(shrinkage);
        SGDVector otherVec = (SGDVector) other;
        for (VectorTuple tuple : otherVec) {
            double update = f.applyAsDouble(tuple.value);
            double oldValue = elements[tuple.index] * multiplier;
            double newValue = oldValue + update;
            // Maintain the cached norm incrementally: replace the old contribution with the new.
            squaredTwoNorm -= oldValue * oldValue;
            squaredTwoNorm += newValue * newValue;
            // Store in the multiplier's frame. The multiplier is never 0 here, because
            // scaleInPlace resets it to 1 whenever |multiplier| < TOLERANCE.
            elements[tuple.index] = newValue / multiplier;
        }
        if (reproject) {
            double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm();
            if (projectionNormaliser < 1.0) {
                scaleInPlace(projectionNormaliser);
            }
        }
        iteration++;
    }

    /**
     * Returns the index of the largest logical value (the first one on ties).
     * Returns 0 for an empty or all-NaN vector.
     */
    @Override
    public int indexOfMax() {
        int index = 0;
        double value = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < elements.length; i++) {
            double tmp = get(i);
            if (tmp > value) {
                index = i;
                value = tmp;
            }
        }
        return index;
    }

    /**
     * Returns the dot product of the logical values with {@code other}, visiting only the
     * entries {@code other} iterates over.
     */
    @Override
    public double dot(SGDVector other) {
        double score = 0.0;
        for (VectorTuple tuple : other) {
            score += get(tuple.index) * tuple.value;
        }
        return score;
    }

    /**
     * Scales every logical value by {@code value} in O(1).
     * It updates the multiplier and the cached squared norm. If the multiplier becomes
     * smaller than {@code TOLERANCE} in magnitude, it is folded into the stored elements.
     */
    @Override
    public void scaleInPlace(double value) {
        multiplier *= value;
        // ||c * w||^2 == c^2 * ||w||^2; keeps twoNorm() consistent with the scaled values.
        squaredTwoNorm *= value * value;
        if (Math.abs(multiplier) < TOLERANCE) {
            reifyMultiplier();
        }
    }

    /** Folds the multiplier into the stored elements and resets it to 1 (values unchanged). */
    private void reifyMultiplier() {
        for (int i = 0; i < elements.length; i++) {
            elements[i] *= multiplier;
        }
        multiplier = 1.0;
    }

    /**
     * Returns the two-norm of the logical values from the cached squared norm.
     */
    @Override
    public double twoNorm() {
        // Incremental updates can leave a tiny negative residue through rounding; clamp it
        // so the result is not NaN.
        return Math.sqrt(Math.max(0.0, squaredTwoNorm));
    }

    /**
     * Returns the largest logical value.
     * NOTE(review): assumes DenseVector#maxValue/minValue read the raw stored elements
     * rather than calling get(int).
     */
    @Override
    public double maxValue() {
        // A negative multiplier reverses the ordering of the stored elements.
        return multiplier >= 0.0 ? multiplier * super.maxValue() : multiplier * super.minValue();
    }

    /**
     * Returns the smallest logical value.
     * NOTE(review): assumes DenseVector#maxValue/minValue read the raw stored elements
     * rather than calling get(int).
     */
    @Override
    public double minValue() {
        // A negative multiplier reverses the ordering of the stored elements.
        return multiplier >= 0.0 ? multiplier * super.minValue() : multiplier * super.maxValue();
    }

    /**
     * Returns an iterator over every index, yielding logical values.
     * The returned tuple is reused across calls.
     */
    @Override
    public VectorIterator iterator() {
        return new ShrinkingVectorIterator(this);
    }

    /**
     * Dense iterator over a {@link ShrinkingVector}. It reuses a single {@link VectorTuple}
     * instance, so callers must copy the tuple if they need to keep it.
     */
    private static class ShrinkingVectorIterator implements VectorIterator {
        private final ShrinkingVector vector;
        private final VectorTuple tuple;
        private int index;

        public ShrinkingVectorIterator(ShrinkingVector vector) {
            this.vector = vector;
            this.tuple = new VectorTuple();
            this.index = 0;
        }

        @Override
        public boolean hasNext() {
            return index < vector.size();
        }

        @Override
        public VectorTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator.");
            }
            tuple.index = index;
            tuple.value = vector.get(index);
            index++;
            return tuple;
        }

        @Override
        public VectorTuple getReference() {
            return tuple;
        }
    }
}

