/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.dataset.DatasetView;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * A trainer which builds an ensemble by bagging.
 * <p>
 * Each call to {@code train} fits {@link #numMembers} models with
 * {@link #innerTrainer}, every one on its own bootstrap view of the supplied
 * dataset (see {@link #trainSingleModel}), and wraps them in a
 * {@link WeightedEnsembleModel} which aggregates their outputs with
 * {@link #combiner}.
 * <p>
 * The RNG and the invocation counter are only mutated while holding this
 * instance's monitor; the ensemble members themselves are trained outside
 * that lock.
 *
 * @param <T> The output type of the ensemble members.
 */
public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(BaggingTrainer.class.getName());

    /** The trainer used to fit each ensemble member. */
    @Config(mandatory=true, description="The trainer to use for each ensemble member.")
    protected Trainer<T> innerTrainer;

    /** The number of ensemble members built by a single train call. */
    @Config(mandatory=true, description="The number of ensemble members to train.")
    protected int numMembers;

    /** The seed used to (re)initialise {@link #rng}. */
    @Config(mandatory=true, description="The seed for the RNG.")
    protected long seed;

    /** The function which combines the members' outputs in the trained ensemble. */
    @Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.")
    protected EnsembleCombiner<T> combiner;

    /** The parent RNG; it is split once per train call. Guarded by {@code this}. */
    protected SplittableRandom rng;

    /** The number of train calls recorded so far. Guarded by {@code this}. */
    protected int trainInvocationCounter;

    /**
     * No-argument constructor which leaves every field unset.
     * NOTE(review): presumably for the OLCUT configuration system, which would
     * populate the {@code @Config} fields and then call {@link #postConfig()} - confirm.
     */
    protected BaggingTrainer() {
    }

    /**
     * Constructs a bagging trainer using the seed {@code 12345L}.
     * NOTE(review): this literal looks like the inlined Trainer default seed - confirm.
     *
     * @param trainer The trainer used for each ensemble member.
     * @param combiner The combination function for the members' outputs.
     * @param numMembers The number of ensemble members to train.
     */
    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) {
        this(trainer, combiner, numMembers, 12345L);
    }

    /**
     * Constructs a bagging trainer and initialises its RNG from the seed.
     *
     * @param trainer The trainer used for each ensemble member.
     * @param combiner The combination function for the members' outputs.
     * @param numMembers The number of ensemble members to train.
     * @param seed The seed for the RNG.
     */
    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) {
        this.innerTrainer = trainer;
        this.combiner = combiner;
        this.numMembers = numMembers;
        this.seed = seed;
        // Kept as a call to the overridable hook so that subclasses relying on
        // postConfig running during construction keep working.
        this.postConfig();
    }

    /**
     * Creates the RNG from the current value of {@link #seed}.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * The name given to the ensemble models produced by this trainer.
     *
     * @return The string {@code "bagging-ensemble"}.
     */
    protected String ensembleName() {
        return "bagging-ensemble";
    }

    /**
     * Describes this trainer and its configured fields.
     * <p>
     * Unset (null) fields are rendered as {@code null} rather than causing a
     * {@link NullPointerException}.
     *
     * @return A string of the form
     *         {@code BaggingTrainer(innerTrainer=...,combiner=...,numMembers=...,seed=...)}.
     */
    @Override
    public String toString() {
        return "BaggingTrainer(innerTrainer=" + this.innerTrainer
                + ",combiner=" + this.combiner
                + ",numMembers=" + this.numMembers
                + ",seed=" + this.seed
                + ")";
    }

    /**
     * Trains an ensemble with an empty run provenance map.
     *
     * @param examples The training data.
     * @return The trained ensemble.
     */
    @Override
    public EnsembleModel<T> train(Dataset<T> examples) {
        return this.train(examples, Collections.emptyMap());
    }

    /**
     * Trains an ensemble, advancing the invocation counter by one.
     *
     * @param examples The training data.
     * @param runProvenance Provenance recorded for this training run.
     * @return The trained ensemble.
     */
    @Override
    public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        // -1 tells the three argument overload to leave the RNG state alone
        // and just increment the counter.
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Trains an ensemble of {@link #numMembers} models on bootstrap views of the data.
     * <p>
     * If {@code invocationCount} is not {@code -1} the trainer state is first
     * reset via {@link #setInvocationCount(int)}. The parent RNG is then split to
     * obtain an RNG local to this call, the trainer provenance is captured and the
     * invocation counter is incremented, all under this instance's lock. Member
     * {@code i} is trained with the inner trainer at invocation count
     * {@code innerTrainer.getInvocationCount() + i}, where the inner count is read
     * once before the first member is trained.
     *
     * @param examples The training data.
     * @param runProvenance Provenance recorded for this training run.
     * @param invocationCount The invocation count to reset to before training,
     *                        or {@code -1} to continue from the current state.
     * @return The trained ensemble.
     * @throws IllegalArgumentException If {@code invocationCount} is negative and not {@code -1}.
     */
    @Override
    public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            // Captured before the increment so the provenance reflects the
            // state this call started from.
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> labelIDs = examples.getOutputIDInfo();
        ArrayList<Model<T>> models = new ArrayList<>();
        int initialInvocation = this.innerTrainer.getInvocationCount();
        for (int i = 0; i < this.numMembers; ++i) {
            logger.info("Building model " + i);
            models.add(this.trainSingleModel(examples, featureIDs, labelIDs, localRNG.nextInt(), runProvenance, initialInvocation + i));
        }
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(
                WeightedEnsembleModel.class.getName(),
                OffsetDateTime.now(),
                examples.getProvenance(),
                trainerProvenance,
                runProvenance,
                ListProvenance.createListProvenance(models));
        return new WeightedEnsembleModel<>(this.ensembleName(), provenance, featureIDs, labelIDs, models, this.combiner);
    }

    /**
     * Trains one ensemble member on a bootstrap view of the data.
     * <p>
     * The view has {@code examples.size()} elements and is created by
     * {@link DatasetView#createBootstrapView} from {@code randInt}
     * (presumably used as the sampling seed - confirm against DatasetView).
     *
     * @param examples The full training data.
     * @param featureIDs The feature domain passed to the bootstrap view.
     * @param labelIDs The output domain passed to the bootstrap view.
     * @param randInt The random value passed to the bootstrap view.
     * @param runProvenance Provenance recorded for this training run.
     * @param invocationCount The invocation count passed to the inner trainer.
     * @return The model trained by the inner trainer on the bootstrap view.
     */
    protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, int randInt, Map<String, Provenance> runProvenance, int invocationCount) {
        DatasetView<T> bag = DatasetView.createBootstrapView(examples, examples.size(), randInt, featureIDs, labelIDs);
        return this.innerTrainer.train(bag, runProvenance, invocationCount);
    }

    /**
     * Returns the number of train calls recorded by this trainer.
     * <p>
     * Synchronized so the read observes updates made under the same lock by
     * {@code train} and {@link #setInvocationCount(int)}.
     *
     * @return The current invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG from {@link #seed} and replays the state changes of
     * {@code invocationCount} train calls.
     * <p>
     * Each train call performs exactly one {@code split()} on the parent RNG
     * and one counter increment, so the same pair is repeated here
     * {@code invocationCount} times.
     *
     * @param invocationCount The invocation count to move to, must be non-negative.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < invocationCount; ++this.trainInvocationCounter) {
            // The child RNG is deliberately discarded; the call is made only
            // to move the parent RNG forward by one split.
            this.rng.split();
        }
    }

    /**
     * Builds the provenance describing this trainer's current state.
     *
     * @return A new {@link TrainerProvenanceImpl} constructed from this trainer.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

