/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.AbstractFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.ParameterizedFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.Randoms;

/**
 * A factor over one binary discrete variable {@code var} and two continuous
 * parameter variables {@code theta1} and {@code theta2}.
 *
 * <p>The factor stores no numbers of its own. Whenever it is evaluated or
 * sliced, it reads the values of {@code theta1} and {@code theta2} from the
 * supplied assignment. It then builds a two-entry {@link TableFactor} over
 * {@code var} from {@code {theta1, theta2}} (in that order). {@link #sumGradLog}
 * pairs {@code theta1} with outcome 0 of {@code var} and {@code theta2} with
 * outcome 1.
 *
 * <p>Marginalization, max-extraction, normalization, sampling and index lookup
 * are not supported; they throw {@link UnsupportedOperationException}.
 */
public class BinaryUnaryFactor
extends AbstractFactor
implements ParameterizedFactor {
    /** Continuous parameter paired with outcome 0 of {@link #var}. */
    private final Variable theta1;
    /** Continuous parameter paired with outcome 1 of {@link #var}. */
    private final Variable theta2;
    /** The binary discrete variable this factor is defined over. */
    private final Variable var;

    /**
     * Creates a factor over {@code var} parameterized by {@code theta1} and {@code theta2}.
     *
     * @param var    discrete variable; must have exactly two outcomes
     * @param theta1 continuous parameter paired with outcome 0 of {@code var}
     * @param theta2 continuous parameter paired with outcome 1 of {@code var}
     * @throws IllegalArgumentException if {@code var} does not have exactly two
     *         outcomes, or if either parameter is not continuous
     */
    public BinaryUnaryFactor(Variable var, Variable theta1, Variable theta2) {
        super(BinaryUnaryFactor.combineVariables(theta1, theta2, var));
        this.theta1 = theta1;
        this.theta2 = theta2;
        this.var = var;
        if (var.getNumOutcomes() != 2) {
            throw new IllegalArgumentException("Discrete variable " + var + " in BinaryUnary must be binary.");
        }
        if (!theta1.isContinuous()) {
            throw new IllegalArgumentException("Parameter " + theta1 + " in BinaryUnary must be continuous.");
        }
        if (!theta2.isContinuous()) {
            throw new IllegalArgumentException("Parameter " + theta2 + " in BinaryUnary must be continuous.");
        }
    }

    /** Collects the three variables of the factor into a fresh {@link HashVarSet}. */
    private static VarSet combineVariables(Variable theta1, Variable theta2, Variable var) {
        HashVarSet ret = new HashVarSet();
        ret.add(theta1);
        ret.add(theta2);
        ret.add(var);
        return ret;
    }

    /** Not supported for this factor. */
    @Override
    protected Factor extractMaxInternal(VarSet varSet) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: the factor has no stored table to index into. */
    @Override
    protected double lookupValueInternal(int i) {
        throw new UnsupportedOperationException();
    }

    /** Not supported for this factor. */
    @Override
    protected Factor marginalizeInternal(VarSet varsToKeep) {
        throw new UnsupportedOperationException();
    }

    /**
     * Evaluates the factor at the iterator's current assignment.
     *
     * <p>The assignment must supply values for {@code theta1} and {@code theta2}.
     * It is presumably also expected to assign {@code var}, since the resulting
     * table is evaluated at the same assignment.
     */
    @Override
    public double value(AssignmentIterator it) {
        Assignment assn = it.assignment();
        Factor tbl = this.sliceForAlpha(assn);
        return tbl.value(assn);
    }

    /**
     * Builds the two-entry table over {@code var} whose entries are the values
     * that {@code assn} gives to {@code theta1} and {@code theta2}, in that order.
     */
    private Factor sliceForAlpha(Assignment assn) {
        double th1 = assn.getDouble(this.theta1);
        double th2 = assn.getDouble(this.theta2);
        double[] vals = new double[]{th1, th2};
        return new TableFactor(this.var, vals);
    }

    /** Not supported for this factor. */
    @Override
    public Factor normalize() {
        throw new UnsupportedOperationException();
    }

    /** Not supported for this factor. */
    @Override
    public Assignment sample(Randoms r) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the natural logarithm of {@link #value(AssignmentIterator)}.
     * Follows {@link Math#log} for non-positive values (-Infinity for 0, NaN for negatives).
     */
    @Override
    public double logValue(AssignmentIterator it) {
        return Math.log(this.value(it));
    }

    /**
     * Conditions the factor on {@code assn}.
     *
     * <p>First builds the table over {@code var} from the parameter values in
     * {@code assn}, then slices that table by {@code assn}. The assignment must
     * therefore contain values for both {@code theta1} and {@code theta2}.
     */
    @Override
    public Factor slice(Assignment assn) {
        Factor alphSlice = this.sliceForAlpha(assn);
        return alphSlice.slice(assn);
    }

    /** Returns a one-line description naming the discrete variable and both parameters. */
    @Override
    public String dumpToString() {
        // Local, single-threaded buffer: StringBuilder rather than the synchronized StringBuffer.
        StringBuilder buf = new StringBuilder();
        buf.append("[BinaryUnary : var=");
        buf.append(this.var);
        buf.append(" theta1=");
        buf.append(this.theta1);
        buf.append(" theta2=");
        buf.append(this.theta2);
        buf.append(" ]");
        return buf.toString();
    }

    /**
     * Returns the marginal probability, under {@code q}, of the outcome of
     * {@code var} that is paired with {@code param}. For {@code theta1} this is
     * outcome 0; for {@code theta2} it is outcome 1.
     *
     * <p>NOTE(review): {@code paramAssn} is unused. Because the table entries are
     * the raw parameter values, the derivative of {@code log f} with respect to
     * theta would normally carry a {@code 1/theta} factor. Confirm that callers
     * expect the bare marginal.
     *
     * @param q         distribution whose marginal over {@code var} is taken
     * @param param     must be {@code theta1} or {@code theta2} (compared by identity)
     * @param paramAssn unused
     * @throws IllegalArgumentException if {@code param} is neither parameter of this factor
     */
    @Override
    public double sumGradLog(Factor q, Variable param, Assignment paramAssn) {
        int outcome;
        if (param == this.theta1) {
            outcome = 0;
        } else if (param == this.theta2) {
            outcome = 1;
        } else {
            throw new IllegalArgumentException("Attempt to take gradient of " + this + " wrt " + param + " but factor does not depend on that variable.");
        }
        Factor q_xs = q.marginalize(this.var);
        return q_xs.value(new Assignment(this.var, outcome));
    }

    /** Returns a new factor over the same (shared, not copied) variables. */
    @Override
    public Factor duplicate() {
        return new BinaryUnaryFactor(this.var, this.theta1, this.theta2);
    }

    /**
     * Structural comparison via {@link #equals(Object)}. {@code epsilon} is
     * ignored because the factor holds no numeric values of its own.
     */
    @Override
    public boolean almostEquals(Factor p, double epsilon) {
        return this.equals(p);
    }

    /** Always false: the factor stores no numbers; parameter values come from assignments. */
    @Override
    public boolean isNaN() {
        return false;
    }

    /** Two factors are equal when they have the same class and equal {@code var}, {@code theta1} and {@code theta2}. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BinaryUnaryFactor that = (BinaryUnaryFactor) o;
        return nullSafeEquals(this.theta1, that.theta1)
            && nullSafeEquals(this.theta2, that.theta2)
            && nullSafeEquals(this.var, that.var);
    }

    /** Hash consistent with {@link #equals(Object)}; combines the three variables' hashes. */
    @Override
    public int hashCode() {
        int result = hashOrZero(this.theta1);
        result = 29 * result + hashOrZero(this.theta2);
        result = 29 * result + hashOrZero(this.var);
        return result;
    }

    /** Null-tolerant equality: true when both are null, or when {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a != null ? a.equals(b) : b == null;
    }

    /** Returns {@code o.hashCode()}, or 0 when {@code o} is null. */
    private static int hashOrZero(Object o) {
        return o != null ? o.hashCode() : 0;
    }
}

