package fu.mi.fitting.fitters;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import fu.mi.fitting.distributions.Erlang;
import fu.mi.fitting.distributions.HyperErlang;
import fu.mi.fitting.distributions.HyperErlangBranch;
import fu.mi.fitting.parameters.FitParameters;
import fu.mi.fitting.sample.SampleCollection;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.stream.DoubleStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.util.FastMath;

/**
 * Fits a hyper-Erlang distribution (a weighted mixture of Erlang branches) to a sample collection.
 *
 * <p>The fit runs in three stages:
 * <ol>
 *   <li>{@code initFit()} partitions the samples with k-means++ ({@code k} is
 *       {@code FitParameters.getBranch()}) and fits one Erlang branch per partition with the
 *       fitter registered under {@code MomErlangFitter.FITTER_NAME}.</li>
 *   <li>{@code refinement(...)} repeatedly re-partitions the samples at random and keeps the
 *       candidates with the highest {@code logLikelihood()}.</li>
 *   <li>{@link #fit()} refits the branches on the best partition found and caches the result.</li>
 * </ol>
 *
 * <p>{@link #fit()} caches its result lazily without any synchronization.
 */
public class HyperErlangFitter extends Fitter {
    public static final String FITTER_NAME = "Hyper-Erlang";

    /** Cached fit; {@code null} until {@link #fit()} or {@code branchFit()} has run. */
    private HyperErlang fitResult;

    /** Partition of the samples, one collection per branch; must be set before {@code branchFit()}. */
    private List<SampleCollection> cluster;

    /**
     * Orders fitters by DESCENDING {@code logLikelihood()}, so {@code TreeSet.first()} is the best
     * candidate. Two fitters with equal log-likelihood compare as equal, so a {@code TreeSet} using
     * this comparator keeps only one of them.
     */
    private static final class HyperErlangComparator implements Comparator<HyperErlangFitter> {
        @Override
        public int compare(HyperErlangFitter left, HyperErlangFitter right) {
            return Double.compare(right.logLikelihood(), left.logLikelihood());
        }
    }

    /**
     * Creates a fitter for the given samples; the fit itself is deferred to {@link #fit()}.
     *
     * @param sampleCollection the samples to fit
     */
    public HyperErlangFitter(SampleCollection sampleCollection) {
        super(sampleCollection);
        this.fitResult = null;
    }

    /**
     * Returns the hyper-Erlang fit, computing it on the first call and caching it afterwards.
     *
     * @return the fitted distribution
     * @throws IllegalStateException if the refinement produced no candidate (see {@code refinement})
     */
    @Override
    public HyperErlang fit() {
        if (fitResult == null) {
            setCluster(refinement(initFit()).getCluster());
            branchFit();
        }
        return fitResult;
    }

    /**
     * Searches for a good partition of the samples by random reassignment.
     *
     * <p>Starts from {@code shuffle} candidates drawn from {@code hyperErlang}. Then, for
     * {@code reassign} rounds, draws {@code shuffle} new candidates from every retained candidate and
     * keeps only the {@code optimize} best of the old and new candidates together (all three counts
     * come from {@link FitParameters}).
     *
     * @param hyperErlang the starting fit whose branch densities drive the first reassignment
     * @return the candidate with the highest log-likelihood
     * @throws IllegalStateException if no candidate survives (e.g. shuffle or optimize is not positive)
     */
    private HyperErlangFitter refinement(HyperErlang hyperErlang) {
        FitParameters params = FitParameters.getInstance();
        int reassignRounds = params.getReassign();
        int shuffleCount = params.getShuffle();

        TreeSet<HyperErlangFitter> retained = new TreeSet<>(new HyperErlangComparator());
        retained.addAll(shuffle(hyperErlang, shuffleCount));

        TreeSet<HyperErlangFitter> offspring = new TreeSet<>(new HyperErlangComparator());
        for (int round = 0; round < reassignRounds; round++) {
            offspring.clear();
            for (HyperErlangFitter candidate : retained) {
                // fit() returns the result cached by branchFit() in fitterFromCluster.
                offspring.addAll(shuffle(candidate.fit(), shuffleCount));
            }
            retained.addAll(offspring);
            cleanResult(retained);
        }

        if (retained.isEmpty()) {
            throw new IllegalStateException(
                    "No candidate fit produced; check the shuffle and optimize fit parameters");
        }
        return retained.first();
    }

    /**
     * Trims {@code fitters} in place down to its {@code FitParameters.getOptimize()} best entries.
     * A non-positive optimize value empties the set.
     *
     * @param fitters candidates ordered best-first by {@link HyperErlangComparator}
     */
    private void cleanResult(TreeSet<HyperErlangFitter> fitters) {
        int keep = Math.max(FitParameters.getInstance().getOptimize(), 0);
        // Drop from the worst end; the set is ordered best-first.
        while (fitters.size() > keep) {
            fitters.pollLast();
        }
    }

    /**
     * Generates {@code count} candidate fitters by randomly re-partitioning the samples.
     *
     * <p>Each sample is assigned to branch {@code j} with probability proportional to the
     * (unweighted) density of branch {@code j} at that sample. A fresh assignment is drawn for every
     * candidate. If every branch density is 0 at a sample, that sample is assigned uniformly at random.
     *
     * @param hyperErlang the fit whose branch densities drive the assignment
     * @param count number of candidates to generate; non-positive yields an empty list
     * @return the candidate fitters, each already fitted on its partition
     */
    private List<HyperErlangFitter> shuffle(HyperErlang hyperErlang, int count) {
        List<HyperErlangBranch> branches = hyperErlang.getBranches();
        int sampleCount = samples.size();
        int branchCount = branches.size();

        // Entry (row, col): probability of assigning sample `row` to branch `col`; rows sum to 1.
        Array2DRowRealMatrix probabilities = new Array2DRowRealMatrix(sampleCount, branchCount);
        double[] densities = new double[branchCount];
        for (int row = 0; row < sampleCount; row++) {
            for (int col = 0; col < branchCount; col++) {
                densities[col] = branches.get(col).dist.density(samples.getValue(row));
            }
            double total = DoubleStream.of(densities).sum();
            for (int col = 0; col < branchCount; col++) {
                // A zero (or NaN) total would produce a NaN row. findCluster would then always pick
                // the last branch, so fall back to a uniform assignment instead.
                double p = total > 0.0 ? densities[col] / total : 1.0 / branchCount;
                probabilities.setEntry(row, col, p);
            }
        }

        List<HyperErlangFitter> candidates = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            candidates.add(fitterFromCluster(discreteSamples(probabilities)));
        }
        return candidates;
    }

    /**
     * Builds a fitter over this fitter's samples with the given partition and fits its branches.
     *
     * @param list one sample collection per branch
     * @return the new, already-fitted fitter
     */
    private HyperErlangFitter fitterFromCluster(List<SampleCollection> list) {
        HyperErlangFitter fitter = new HyperErlangFitter(samples);
        fitter.setCluster(list);
        fitter.branchFit();
        return fitter;
    }

    /**
     * Fits one Erlang branch per partition in {@link #getCluster()} and stores the mixture in
     * {@code fitResult}. Each branch weight is the fraction of all samples that fall in its partition.
     */
    private void branchFit() {
        List<HyperErlangBranch> branches = new ArrayList<>();
        for (SampleCollection part : getCluster()) {
            // Floating-point division: integer division would truncate every weight below 1 to 0.
            double weight = (double) part.size() / samples.size();
            Erlang erlang = (Erlang) FitterFactory.getFitterByName(MomErlangFitter.FITTER_NAME, part).fit();
            branches.add(new HyperErlangBranch(weight, erlang));
        }
        fitResult = new HyperErlang(branches);
    }

    /**
     * Draws one random partition of the samples.
     *
     * @param realMatrix assignment probabilities: entry (i, j) is the probability of putting sample
     *     {@code i} in branch {@code j}; each row is expected to sum to 1
     * @return one sample collection per column of {@code realMatrix}, in column order
     */
    // Raw types: the element type returned by SampleCollection.getSample is not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private List<SampleCollection> discreteSamples(RealMatrix realMatrix) {
        ArrayListMultimap byBranch = ArrayListMultimap.create();
        for (int i = 0; i < samples.size(); i++) {
            byBranch.put(findCluster(i, realMatrix), samples.getSample(i));
        }

        List<SampleCollection> parts = new ArrayList<>();
        for (int col = 0; col < realMatrix.getColumnDimension(); col++) {
            // Copy out of the multimap's live view before handing it to SampleCollection.
            // NOTE(review): a branch may receive no samples; confirm that the Erlang fitter copes
            // with an empty SampleCollection.
            parts.add(new SampleCollection(new ArrayList(byBranch.get(col))));
        }
        return parts;
    }

    /**
     * Samples a branch index for sample {@code i} from row {@code i} of {@code realMatrix}.
     *
     * <p>Returns the first column whose cumulative row sum exceeds a uniform draw from
     * {@link FastMath#random()}. If rounding leaves the cumulative sum at or below the draw, the last
     * column is returned.
     *
     * @param i row (sample) index
     * @param realMatrix assignment probabilities, one row per sample
     * @return the chosen column index ({@code -1} only if the matrix has no columns)
     */
    private int findCluster(int i, RealMatrix realMatrix) {
        int lastColumn = realMatrix.getColumnDimension() - 1;
        double threshold = FastMath.random();
        double cumulative = 0.0;
        for (int col = 0; col < lastColumn; col++) {
            cumulative += realMatrix.getEntry(i, col);
            if (cumulative > threshold) {
                return col;
            }
        }
        return lastColumn;
    }

    /**
     * Builds the initial fit: k-means++ partitions the samples into
     * {@code FitParameters.getBranch()} clusters, and one Erlang branch is fitted per cluster,
     * weighted by the cluster's share of the samples.
     *
     * @return the initial hyper-Erlang fit
     */
    // Raw types: the Clusterable type produced by SampleCollection.getData is not visible here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private HyperErlang initFit() {
        KMeansPlusPlusClusterer clusterer =
                new KMeansPlusPlusClusterer(FitParameters.getInstance().getBranch());
        List<CentroidCluster> clusters = clusterer.cluster(samples.getData());

        List<HyperErlangBranch> branches = new ArrayList<>();
        for (CentroidCluster centroidCluster : clusters) {
            ErlangFitter fitter = (ErlangFitter) FitterFactory.getFitterByName(
                    MomErlangFitter.FITTER_NAME, new SampleCollection(centroidCluster.getPoints()));
            // Floating-point division: integer division would truncate every weight below 1 to 0.
            double weight = (double) fitter.samples.size() / samples.size();
            branches.add(new HyperErlangBranch(weight, (Erlang) fitter.fit()));
        }
        return new HyperErlang(branches);
    }

    @Override
    public String getName() {
        return FITTER_NAME;
    }

    /**
     * Returns a copy of the current partition, so callers cannot mutate this fitter's state.
     *
     * @return one sample collection per branch
     */
    public List<SampleCollection> getCluster() {
        return new ArrayList<>(cluster);
    }

    /** Replaces the partition used by {@code branchFit()}; the list is stored as given. */
    private void setCluster(List<SampleCollection> list) {
        this.cluster = list;
    }
}
