/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.recommenders.jayes.inference.junctionTree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.Factor;
import org.eclipse.recommenders.jayes.inference.AbstractInferer;
import org.eclipse.recommenders.jayes.inference.junctionTree.JunctionTree;
import org.eclipse.recommenders.jayes.inference.junctionTree.JunctionTreeBuilder;
import org.eclipse.recommenders.jayes.util.ArrayUtils;
import org.eclipse.recommenders.jayes.util.BayesUtils;
import org.eclipse.recommenders.jayes.util.FlyWeight;
import org.eclipse.recommenders.jayes.util.Graph;
import org.eclipse.recommenders.jayes.util.MathUtils;
import org.eclipse.recommenders.jayes.util.NumericalInstabilityException;
import org.eclipse.recommenders.jayes.util.Pair;

/**
 * Exact inference engine that compiles a {@link BayesNet} into a junction tree
 * (built by {@link JunctionTreeBuilder}) and answers belief queries by
 * propagating messages over that tree: a collect phase towards a root cluster
 * followed by a distribute phase away from it.
 *
 * <p>Cluster and separator potentials containing more variables than
 * {@link #getLogThreshold()} are stored in log scale (see
 * {@link #createFactor(List)}); message passing converts between linear and log
 * scale where the two ends of a separator differ.
 *
 * <p>Beliefs are computed lazily: a full propagation runs on the first
 * {@link #getBeliefs(BayesNode)} call after the inherited
 * {@code beliefsValid} flag has been cleared, and each node's marginal is then
 * extracted from its query factor on demand.
 */
public class JunctionTreeAlgorithm
extends AbstractInferer {
    /** Separator potential for each directed junction-tree edge. */
    protected final Map<Graph.Edge, Factor> sepSets = new HashMap<Graph.Edge, Factor>();
    protected Graph junctionTree;
    /** Potential of each junction-tree cluster, indexed by cluster id. */
    protected Factor[] nodePotentials;
    /** For each node id, the id of a cluster containing the node and all its parents. */
    protected int[] homeClusters;
    /** Prepared index mappings used to multiply a separator into the edge's target cluster. */
    protected final IdentityHashMap<Graph.Edge, int[]> preparedMultiplications = new IdentityHashMap<Graph.Edge, int[]>();
    /** For each node id, the ids of all clusters that contain that node. */
    protected List<Integer>[] concernedClusters;
    /** For each node id, the smallest cluster potential containing the node; beliefs are read from it. */
    protected Factor[] queryFactors;
    /** For each node id, the prepared index mapping for summing its query factor onto the node. */
    protected int[][] preparedQueries;
    /** Per-node flag telling whether {@code beliefs[nodeId]} is up to date. */
    protected boolean[] isBeliefValid;
    /** Snapshots of all potentials taken after the initial propagation; replayed before every update. */
    protected final List<Pair<Factor, double[]>> initializations = new ArrayList<Pair<Factor, double[]>>();
    /** For each cluster id, the ids of the nodes whose query factor is that cluster. */
    private final List<int[]> queryFactorReverseMapping = new ArrayList<int[]>();
    /** Clusters that had evidence selected into them during the current update. */
    protected final Set<Integer> clustersHavingEvidence = new HashSet<Integer>();
    private int logThreshold = Integer.MAX_VALUE;
    /** Reusable buffer large enough for the biggest separator potential. */
    protected double[] scratchpad;

    /**
     * Sets the maximum number of variables a potential may have before it is
     * stored in log scale. Only affects factors created by subsequent
     * {@link #setNetwork(BayesNet)} calls.
     *
     * @param logThreshold variable count above which potentials use log scale
     */
    public void setLogThreshold(int logThreshold) {
        this.logThreshold = logThreshold;
    }

    /**
     * @return the variable count above which potentials are stored in log scale
     *         ({@link Integer#MAX_VALUE}, i.e. never, unless changed)
     */
    public int getLogThreshold() {
        return this.logThreshold;
    }

    /**
     * Returns the marginal distribution of {@code node} given the current
     * evidence, running a full propagation first if the evidence changed.
     * Observed nodes get a one-hot distribution on their observed outcome.
     *
     * @param node the node to query
     * @return the node's belief vector, as returned by the superclass
     * @throws NumericalInstabilityException if the marginal cannot be normalized
     */
    @Override
    public double[] getBeliefs(BayesNode node) {
        if (!this.beliefsValid) {
            this.beliefsValid = true;
            this.updateBeliefs();
        }
        int nodeId = node.getId();
        if (!this.isBeliefValid[nodeId]) {
            this.isBeliefValid[nodeId] = true;
            if (this.evidence.containsKey(node)) {
                // Observed variable: all mass on the observed outcome.
                Arrays.fill(this.beliefs[nodeId], 0.0);
                this.beliefs[nodeId][node.getOutcomeIndex((String) this.evidence.get(node))] = 1.0;
            } else {
                this.validateBelief(nodeId);
            }
        }
        return super.getBeliefs(node);
    }

    /**
     * Recomputes {@code beliefs[nodeId]} by summing the node's query factor onto
     * the node, converting out of log scale if needed, and normalizing.
     */
    private void validateBelief(int nodeId) {
        Factor f = this.queryFactors[nodeId];
        f.sumPrepared(this.beliefs[nodeId], this.preparedQueries[nodeId]);
        if (f.isLogScale()) {
            MathUtils.exp(this.beliefs[nodeId]);
        }
        try {
            this.beliefs[nodeId] = MathUtils.normalize(this.beliefs[nodeId]);
        }
        catch (IllegalArgumentException exception) {
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + this.evidence + " and node : " + nodeId + ", consider setting the LogTreshold lower in the OptimizationHints", exception);
        }
    }

    /** Invalidates all cached per-node beliefs and re-propagates the current evidence. */
    @Override
    protected void updateBeliefs() {
        Arrays.fill(this.isBeliefValid, false);
        this.doUpdateBeliefs();
    }

    /**
     * Restores the post-initialization potentials, selects the evidence into
     * them and runs collect/distribute from a cluster holding evidence, skipping
     * subtrees that cannot influence the result.
     */
    private void doUpdateBeliefs() {
        this.replayFactorInitializations();
        this.incorporateAllEvidence();
        int propagationRoot = this.findPropagationRoot();
        this.collectEvidence(propagationRoot, this.skipCollection(propagationRoot));
        this.distributeEvidence(propagationRoot, this.skipDistribution(propagationRoot));
    }

    /**
     * @return the home cluster of the last observed node in the evidence map's
     *         iteration order, or cluster 0 when there is no evidence
     */
    private int findPropagationRoot() {
        int propagationRoot = 0;
        for (BayesNode n : this.evidence.keySet()) {
            propagationRoot = this.homeClusters[n.getId()];
        }
        return propagationRoot;
    }

    /** Selects every observed outcome into all clusters containing the observed node. */
    private void incorporateAllEvidence() {
        this.clustersHavingEvidence.clear();
        for (BayesNode n : this.evidence.keySet()) {
            this.incorporateEvidence(n);
        }
    }

    /** Copies the stored snapshot values back into each potential and clears its selections. */
    private void replayFactorInitializations() {
        for (Pair<Factor, double[]> init : this.initializations) {
            System.arraycopy(init.getSecond(), 0, init.getFirst().getValues(), 0, init.getSecond().length);
            init.getFirst().resetSelections();
        }
    }

    /** Selects the observed outcome of {@code node} in every cluster that contains it. */
    private void incorporateEvidence(BayesNode node) {
        int n = node.getId();
        for (Integer concernedCluster : this.concernedClusters[n]) {
            this.nodePotentials[concernedCluster].select(n, node.getOutcomeIndex((String) this.evidence.get(node)));
            this.clustersHavingEvidence.add(concernedCluster);
        }
    }

    /**
     * Computes the clusters the collect phase may skip: those whose entire
     * subtree (seen from {@code root}) contains no cluster with evidence.
     */
    private Set<Integer> skipCollection(int root) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipCollection(root, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    private void recursiveSkipCollection(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (Graph.Edge e : this.junctionTree.getIncidentEdges(node)) {
            if (visited.contains(e.getSecond())) {
                continue;
            }
            this.recursiveSkipCollection((Integer) e.getSecond(), visited, skipped);
            if (!skipped.contains(e.getSecond())) {
                areAllDescendantsSkipped = false;
            }
        }
        if (areAllDescendantsSkipped && !this.clustersHavingEvidence.contains(node)) {
            skipped.add(node);
        }
    }

    /**
     * Computes the clusters the distribute phase may skip: those whose entire
     * subtree (seen from {@code distNode}) contains no cluster serving as query
     * factor for an unobserved node.
     */
    private Set<Integer> skipDistribution(int distNode) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipDistribution(distNode, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    private void recursiveSkipDistribution(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (Graph.Edge e : this.junctionTree.getIncidentEdges(node)) {
            if (visited.contains(e.getSecond())) {
                continue;
            }
            this.recursiveSkipDistribution((Integer) e.getSecond(), visited, skipped);
            if (!skipped.contains(e.getSecond())) {
                areAllDescendantsSkipped = false;
            }
        }
        if (areAllDescendantsSkipped && !this.isQueryFactorOfUnobservedVariable(node)) {
            skipped.add(node);
        }
    }

    /** @return whether cluster {@code node} is the query factor of at least one unobserved node */
    private boolean isQueryFactorOfUnobservedVariable(int node) {
        for (int var : this.queryFactorReverseMapping.get(node)) {
            if (!this.evidence.containsKey(this.net.getNode(var))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collect phase: depth-first from {@code cluster}, passing a message from
     * each child back to its parent after the child's own subtree is done.
     * Clusters already in {@code marked} (visited or skipped) are not entered.
     */
    private void collectEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (Graph.Edge e : this.junctionTree.getIncidentEdges(cluster)) {
            if (marked.contains(e.getSecond())) {
                continue;
            }
            this.collectEvidence((Integer) e.getSecond(), marked);
            this.messagePass(e.getBackEdge());
        }
    }

    /**
     * Distribute phase: depth-first from {@code cluster}, passing a message to
     * each child before descending into it. Clusters already in {@code marked}
     * (visited or skipped) are not entered.
     */
    private void distributeEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (Graph.Edge e : this.junctionTree.getIncidentEdges(cluster)) {
            if (marked.contains(e.getSecond())) {
                continue;
            }
            this.messagePass(e);
            this.distributeEvidence((Integer) e.getSecond(), marked);
        }
    }

    /**
     * Passes one message along {@code sepSetEdge} from its first to its second
     * cluster.
     *
     * <p>Steps:
     * <ol>
     * <li>Save the old separator values into {@link #scratchpad}.</li>
     * <li>Sum the source potential into the separator's own value array
     *     (presumably a marginalization onto the separator variables).</li>
     * <li>Form the update ratio new/old, or new - old when both ends are in log
     *     scale.</li>
     * <li>Multiply the ratio into the target potential.</li>
     * </ol>
     * Messages over separators whose variables are all observed are skipped.
     */
    private void messagePass(Graph.Edge sepSetEdge) {
        Factor sepSet = this.sepSets.get(sepSetEdge);
        if (!this.needMessagePass(sepSet)) {
            return;
        }
        double[] newSepValues = sepSet.getValues();
        // Keep the previous message; newSepValues is overwritten in place below.
        System.arraycopy(newSepValues, 0, this.scratchpad, 0, newSepValues.length);
        int[] preparedOp = this.preparedMultiplications.get(sepSetEdge.getBackEdge());
        this.nodePotentials[(Integer) sepSetEdge.getFirst()].sumPrepared(newSepValues, preparedOp);
        if (this.isOnlyFirstLogScale(sepSetEdge)) {
            // Separator is linear-scale here, so bring the log-scale sum back to linear.
            MathUtils.exp(newSepValues);
        }
        if (this.areBothEndsLogScale(sepSetEdge)) {
            MathUtils.secureSubtract(newSepValues, this.scratchpad, this.scratchpad);
        } else {
            MathUtils.secureDivide(newSepValues, this.scratchpad, this.scratchpad);
        }
        if (this.isOnlySecondLogScale(sepSetEdge)) {
            // Target is log-scale: convert the linear-scale update before applying it.
            MathUtils.log(this.scratchpad);
        }
        this.nodePotentials[(Integer) sepSetEdge.getSecond()].multiplyPrepared(this.scratchpad, this.preparedMultiplications.get(sepSetEdge));
    }

    /** @return whether at least one variable of {@code sepSet} is unobserved */
    private boolean needMessagePass(Factor sepSet) {
        for (int var : sepSet.getDimensionIDs()) {
            if (!this.evidence.containsKey(this.net.getNode(var))) {
                return true;
            }
        }
        return false;
    }

    private boolean isOnlyFirstLogScale(Graph.Edge edge) {
        return this.nodePotentials[(Integer) edge.getFirst()].isLogScale() && !this.nodePotentials[(Integer) edge.getSecond()].isLogScale();
    }

    private boolean isOnlySecondLogScale(Graph.Edge edge) {
        return !this.nodePotentials[(Integer) edge.getFirst()].isLogScale() && this.nodePotentials[(Integer) edge.getSecond()].isLogScale();
    }

    /**
     * Compiles {@code bn} into a junction tree, runs an initial evidence-free
     * propagation and snapshots the resulting potentials so later updates can
     * restart from them. May be called again to switch networks.
     *
     * @param bn the network to perform inference on
     */
    @Override
    public void setNetwork(BayesNet bn) {
        super.setNetwork(bn);
        this.initializeFields();
        List<List<Integer>> clusters = this.buildJunctionTree().getClusters();
        this.setHomeClusters(clusters);
        this.setQueryFactors();
        this.prepareMultiplications();
        this.prepareScratch();
        this.invokeInitialBeliefUpdate();
        this.storePotentialValues();
    }

    /** Allocates {@link #scratchpad} to the size of the largest separator potential. */
    private void prepareScratch() {
        int maxSize = 0;
        for (Factor sepSet : this.sepSets.values()) {
            maxSize = Math.max(maxSize, sepSet.getValues().length);
        }
        this.scratchpad = new double[maxSize];
    }

    /** Allocates per-node arrays for the current network and drops state from any previous network. */
    @SuppressWarnings("unchecked") // generic array creation for concernedClusters
    private void initializeFields() {
        // These collections are only ever appended to while building; clear them so that a
        // repeated setNetwork call does not mix in entries that refer to the old junction tree.
        this.sepSets.clear();
        this.preparedMultiplications.clear();
        this.initializations.clear();
        this.queryFactorReverseMapping.clear();
        this.clustersHavingEvidence.clear();

        this.isBeliefValid = new boolean[this.beliefs.length];
        int numNodes = this.net.getNodes().size();
        this.homeClusters = new int[numNodes];
        this.queryFactors = new Factor[numNodes];
        this.preparedQueries = new int[numNodes][];
        this.concernedClusters = new List[numNodes];
        for (int i = 0; i < numNodes; i++) {
            this.concernedClusters[i] = new ArrayList<Integer>();
        }
    }

    /** Builds the junction tree for {@link #net} and creates its cluster and separator factors. */
    private JunctionTree buildJunctionTree() {
        JunctionTree jtree = JunctionTreeBuilder.fromNet(this.net);
        this.junctionTree = jtree.getGraph();
        this.initializeClusterFactors(jtree.getClusters());
        this.initializeSepsetFactors(jtree.getSepSets());
        return jtree;
    }

    /** Creates one potential per cluster and records, per variable, which clusters contain it. */
    private void initializeClusterFactors(List<List<Integer>> clusters) {
        this.nodePotentials = new Factor[clusters.size()];
        int clusterIndex = 0;
        for (List<Integer> cluster : clusters) {
            this.nodePotentials[clusterIndex] = this.createFactor(cluster);
            for (Integer var : cluster) {
                this.concernedClusters[var].add(clusterIndex);
            }
            ++clusterIndex;
        }
    }

    /** Creates the separator potential for each junction-tree edge. */
    private void initializeSepsetFactors(List<Pair<Graph.Edge, List<Integer>>> sepSets) {
        for (Pair<Graph.Edge, List<Integer>> sep : sepSets) {
            this.sepSets.put(sep.getFirst(), this.createFactor(sep.getSecond()));
        }
    }

    /**
     * Creates an (unfilled) factor over the given variables, switching it to log
     * scale when it has more than {@link #getLogThreshold()} variables.
     *
     * @param vars ids of the variables spanned by the factor, in dimension order
     * @return the new factor
     */
    protected Factor createFactor(List<Integer> vars) {
        Factor f = new Factor();
        List<Integer> dimensions = new ArrayList<Integer>(vars.size());
        for (Integer dim : vars) {
            dimensions.add(this.net.getNode(dim).getOutcomeCount());
        }
        f.setDimensions((int[]) ArrayUtils.toPrimitiveArray((Number[]) dimensions.toArray(new Integer[0])));
        f.setDimensionIDs((int[]) ArrayUtils.toPrimitiveArray((Number[]) vars.toArray(new Integer[0])));
        if (vars.size() > this.getLogThreshold()) {
            f.setLogScale(true);
        }
        return f;
    }

    /**
     * For every node, records as home cluster the first cluster containing the
     * node and all of its parents. A node for which no such cluster is found
     * keeps cluster 0.
     */
    private void setHomeClusters(List<List<Integer>> clusters) {
        for (BayesNode node : this.net.getNodes()) {
            List<Integer> nodeAndParents = BayesUtils.getNodeAndParentIds(node);
            int clusterIndex = 0;
            for (List<Integer> cluster : clusters) {
                if (cluster.containsAll(nodeAndParents)) {
                    this.homeClusters[node.getId()] = clusterIndex;
                    break;
                }
                ++clusterIndex;
            }
        }
    }

    /**
     * Chooses for every node the smallest containing cluster potential as its
     * query factor (the first one on ties), then builds the reverse mapping from
     * cluster to the nodes it serves.
     */
    private void setQueryFactors() {
        for (BayesNode n : this.net.getNodes()) {
            int nodeId = n.getId();
            for (Integer f : this.concernedClusters[nodeId]) {
                Factor current = this.queryFactors[nodeId];
                if (current == null || current.getValues().length > this.nodePotentials[f].getValues().length) {
                    this.queryFactors[nodeId] = this.nodePotentials[f];
                }
            }
        }
        for (Factor potential : this.nodePotentials) {
            List<Integer> queryVars = new ArrayList<Integer>();
            for (int var : potential.getDimensionIDs()) {
                if (this.queryFactors[var] == potential) {
                    queryVars.add(var);
                }
            }
            this.queryFactorReverseMapping.add((int[]) ArrayUtils.toPrimitiveArray((Number[]) queryVars.toArray(new Integer[0])));
        }
    }

    /** Precomputes index mappings for separator multiplications and belief queries. */
    private void prepareMultiplications() {
        FlyWeight flyWeight = new FlyWeight();
        this.prepareSepsetMultiplications(flyWeight);
        this.prepareQueries(flyWeight);
    }

    /** Prepares, for every directed edge, the mapping used to multiply its separator into the target cluster. */
    private void prepareSepsetMultiplications(FlyWeight flyWeight) {
        for (int node = 0; node < this.nodePotentials.length; node++) {
            for (Graph.Edge e : this.junctionTree.getIncidentEdges(node)) {
                int[] preparedMultiplication = this.nodePotentials[(Integer) e.getSecond()].prepareMultiplication(this.sepSets.get(e));
                this.preparedMultiplications.put(e, flyWeight.getInstance(preparedMultiplication));
            }
        }
    }

    /** Prepares, for every node, the mapping from its query factor onto a single-variable belief factor. */
    private void prepareQueries(FlyWeight flyWeight) {
        for (BayesNode node : this.net.getNodes()) {
            Factor beliefFactor = new Factor();
            beliefFactor.setDimensions(new int[]{node.getOutcomeCount()});
            beliefFactor.setDimensionIDs(new int[]{node.getId()});
            int[] preparedQuery = this.queryFactors[node.getId()].prepareMultiplication(beliefFactor);
            this.preparedQueries[node.getId()] = flyWeight.getInstance(preparedQuery);
        }
    }

    /** Fills the potentials with neutral values, multiplies in the CPTs and propagates without evidence. */
    private void invokeInitialBeliefUpdate() {
        this.initializePotentialValues();
        this.multiplyCPTsIntoPotentials();
        this.collectEvidence(0, new HashSet<Integer>());
        this.distributeEvidence(0, new HashSet<Integer>());
    }

    /**
     * Fills every potential with the multiplicative identity of its scale:
     * 0 for log-scale and 1 for linear-scale cluster potentials. Separators get
     * 0 only when both adjacent clusters are log-scale, else 1.
     */
    private void initializePotentialValues() {
        for (Factor f : this.nodePotentials) {
            f.fill(f.isLogScale() ? 0.0 : 1.0);
        }
        for (Map.Entry<Graph.Edge, Factor> sepSet : this.sepSets.entrySet()) {
            if (this.areBothEndsLogScale(sepSet.getKey())) {
                sepSet.getValue().fill(0.0);
            } else {
                sepSet.getValue().fill(1.0);
            }
        }
    }

    /** Multiplies each node's CPT into its home cluster's potential, in log scale where applicable. */
    private void multiplyCPTsIntoPotentials() {
        for (BayesNode node : this.net.getNodes()) {
            Factor nodeHome = this.nodePotentials[this.homeClusters[node.getId()]];
            if (nodeHome.isLogScale()) {
                nodeHome.multiplyCompatibleToLog(node.getFactor());
            } else {
                nodeHome.multiplyCompatible(node.getFactor());
            }
        }
    }

    private boolean areBothEndsLogScale(Graph.Edge edge) {
        return this.nodePotentials[(Integer) edge.getFirst()].isLogScale() && this.nodePotentials[(Integer) edge.getSecond()].isLogScale();
    }

    /** Snapshots the values of all cluster and separator potentials for later replay. */
    private void storePotentialValues() {
        for (Factor pot : this.nodePotentials) {
            this.initializations.add(new Pair<Factor, double[]>(pot, pot.getValues().clone()));
        }
        for (Factor sep : this.sepSets.values()) {
            this.initializations.add(new Pair<Factor, double[]>(sep, sep.getValues().clone()));
        }
    }
}

