/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.baseline.DummyClassifierModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

public final class DummyClassifierTrainer
implements Trainer<Label> {
    @Config(mandatory=true, description="Type of dummy classifier.")
    private DummyType dummyType;
    @Config(description="Label to use for the constant classifier.")
    private String constantLabel;
    @Config(description="Seed for the RNG.")
    private long seed = 1L;
    private int invocationCount = 0;

    private DummyClassifierTrainer() {
    }

    public void postConfig() {
        if (this.dummyType == DummyType.CONSTANT && this.constantLabel == null) {
            throw new PropertyException("", "constantLabel", "Please supply a label string when using the type CONSTANT.");
        }
    }

    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
        return this.train(examples, instanceProvenance, -1);
    }

    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
        if (invocationCount != -1) {
            this.invocationCount = invocationCount;
        }
        ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), this.getProvenance(), instanceProvenance);
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        ++this.invocationCount;
        switch (this.dummyType) {
            case CONSTANT: {
                MutableOutputInfo labelInfo = examples.getOutputInfo().generateMutableOutputInfo();
                Label constLabel = new Label(this.constantLabel);
                labelInfo.observe((Output)constLabel);
                return new DummyClassifierModel(provenance, featureMap, (ImmutableOutputInfo<Label>)labelInfo.generateImmutableOutputInfo(), constLabel);
            }
            case MOST_FREQUENT: {
                ImmutableOutputInfo immutableLabelInfo = examples.getOutputIDInfo();
                return new DummyClassifierModel(provenance, featureMap, (ImmutableOutputInfo<Label>)immutableLabelInfo);
            }
            case UNIFORM: 
            case STRATIFIED: {
                ImmutableOutputInfo immutableLabelInfo = examples.getOutputIDInfo();
                return new DummyClassifierModel(provenance, featureMap, (ImmutableOutputInfo<Label>)immutableLabelInfo, this.dummyType, this.seed);
            }
        }
        throw new IllegalStateException("Unknown dummyType " + (Object)((Object)this.dummyType));
    }

    public int getInvocationCount() {
        return this.invocationCount;
    }

    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.invocationCount = invocationCount;
    }

    public String toString() {
        switch (this.dummyType) {
            case CONSTANT: {
                return "DummyClassifierTrainer(dummyType=" + (Object)((Object)this.dummyType) + ",constantLabel=" + this.constantLabel + ")";
            }
            case MOST_FREQUENT: {
                return "DummyClassifierTrainer(dummyType=" + (Object)((Object)this.dummyType) + ")";
            }
            case UNIFORM: 
            case STRATIFIED: {
                return "DummyClassifierTrainer(dummyType=" + (Object)((Object)this.dummyType) + ",seed=" + this.seed + ")";
            }
        }
        return "DummyClassifierTrainer(dummyType=" + (Object)((Object)this.dummyType) + ")";
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }

    public static DummyClassifierTrainer createStratifiedTrainer(long seed) {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.STRATIFIED;
        trainer.seed = seed;
        return trainer;
    }

    public static DummyClassifierTrainer createConstantTrainer(String constantLabel) {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.CONSTANT;
        trainer.constantLabel = constantLabel;
        return trainer;
    }

    public static DummyClassifierTrainer createUniformTrainer(long seed) {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.UNIFORM;
        trainer.seed = seed;
        return trainer;
    }

    public static DummyClassifierTrainer createMostFrequentTrainer() {
        DummyClassifierTrainer trainer = new DummyClassifierTrainer();
        trainer.dummyType = DummyType.MOST_FREQUENT;
        return trainer;
    }

    public static enum DummyType {
        STRATIFIED,
        MOST_FREQUENT,
        UNIFORM,
        CONSTANT;

    }
}

