//////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2010, 2025 Contributors to the Eclipse Foundation
//
// See the NOTICE file(s) distributed with this work for additional
// information regarding copyright ownership.
//
// This program and the accompanying materials are made available
// under the terms of the MIT License which is available at
// https://opensource.org/licenses/MIT
//
// SPDX-License-Identifier: MIT
//////////////////////////////////////////////////////////////////////////////

package org.eclipse.escet.cif.bdd.conversion.bitvectors;

import static org.eclipse.escet.common.java.Strings.fmt;

import java.util.Arrays;

import com.github.javabdd.BDD;
import com.github.javabdd.BDDDomain;
import com.github.javabdd.BDDFactory;

/**
 * BDD bit vector.
 *
 * <p>
 * A bit vector owns the BDDs that are stored in its {@link #bits}. They are {@link BDD#free freed} when they are
 * replaced, and when the bit vector itself is {@link #free freed}. Operations that produce a new bit vector or a new
 * BDD don't modify or free their operands, unless documented otherwise.
 * </p>
 *
 * @param <T> The type of BDD bit vector.
 * @param <TC> The type of BDD bit vector and carry.
 */
public abstract class CifBddBitVector<T extends CifBddBitVector<T, TC>, TC extends CifBddBitVectorAndCarry<T, TC>> {
    /** The BDD factory to use. Is {@code null} once the bit vector has been {@link #free freed}. */
    protected BDDFactory factory;

    /**
     * The BDDs for each of the bits of the bit vector. The lowest bit is at index zero. Is {@code null} once the bit
     * vector has been {@link #free freed}.
     */
    protected BDD[] bits;

    /**
     * Constructor for the {@link CifBddBitVector} class. The bits of the new bit vector are all {@code null}.
     *
     * @param factory The BDD factory to use.
     * @param length The number of bits of the bit vector.
     * @throws IllegalArgumentException If the length is negative, or not supported by the bit vector representation.
     */
    protected CifBddBitVector(BDDFactory factory, int length) {
        // Precondition check. Note that this invokes an overridable method, before derived classes are initialized.
        if (length < getMinimumLength()) {
            throw new IllegalArgumentException(fmt("Length is less than %d.", getMinimumLength()));
        }

        // Create.
        this.factory = factory;
        bits = new BDD[length];
    }

    /**
     * Returns the minimum length of bit vectors of this bit vector representation.
     *
     * <p>
     * This method is invoked by the constructor of this class, before the constructors of derived classes have run.
     * Implementations must therefore not depend on the instance state of derived classes.
     * </p>
     *
     * @return The minimum length.
     */
    protected abstract int getMinimumLength();

    /**
     * Creates an empty BDD bit vector (bits are all {@code null}) of the same representation as this bit vector, and
     * with the same {@link #factory}.
     *
     * @param length The number of bits of the bit vector.
     * @return The new bit vector.
     */
    protected abstract T createEmpty(int length);

    /**
     * Creates a copy of this bit vector. A new instance of the bit vector is created, that has the same length. Each
     * bit is {@link BDD#id copied} to the new bit vector.
     *
     * @return The copy.
     */
    public T copy() {
        T vector = createEmpty(bits.length);
        for (int i = 0; i < bits.length; i++) {
            vector.bits[i] = bits[i].id();
        }
        return vector;
    }

    /**
     * Modifies this bit vector to represent the given other bit vector. This bit vector and the given other bit vector
     * don't need to have the same length. This bit vector is first {@link #free freed}, then the bits of the other bit
     * vector are moved to this bit vector. The other bit vector is essentially {@link #free freed}, and can no longer
     * be used.
     *
     * <p>
     * Replacing a bit vector by itself has no effect. The bit vector then remains usable.
     * </p>
     *
     * @param other The other bit vector.
     */
    public void replaceBy(T other) {
        // Replacing a bit vector by itself is a no-op. Without this check, the bits of this bit vector would be freed,
        // and this bit vector would end up in the freed state.
        if (other == this) {
            return;
        }

        // Get the state of the other bit vector, before freeing this bit vector. If the other bit vector is 'null',
        // this fails before this bit vector is modified.
        BDDFactory otherFactory = other.factory;
        BDD[] otherBits = other.bits;

        // Free the current bits, and take over the state of the other bit vector.
        free();
        this.factory = otherFactory;
        this.bits = otherBits;

        // The other bit vector no longer owns its bits.
        other.factory = null;
        other.bits = null;
    }

    /**
     * Returns the length of the bit vector, in number of bits.
     *
     * @return The length of the bit vector, in number of bits.
     */
    public int length() {
        return bits.length;
    }

    /**
     * Returns the value count, the number of values that can be represented by the bit vector (if it can be represented
     * as a Java 'int').
     *
     * @return The value count.
     * @throws IllegalStateException If the bit vector has more than 30 bits.
     */
    int countInt() {
        // For 31 bits or more, the count is too high to be represented as a Java 'int'.
        if (bits.length > 30) {
            throw new IllegalStateException("More than 30 bits in vector.");
        }

        // Return the count.
        return 1 << bits.length;
    }

    /**
     * Returns the value count, the number of values that can be represented by the bit vector (if it can be represented
     * as a Java 'long').
     *
     * @return The value count.
     * @throws IllegalStateException If the bit vector has more than 62 bits.
     */
    long countLong() {
        // For 63 bits or more, the count is too high to be represented as a Java 'long'.
        if (bits.length > 62) {
            throw new IllegalStateException("More than 62 bits in vector.");
        }

        // Return the count.
        return 1L << bits.length;
    }

    /**
     * Returns the BDD for the bit with the given index. The lowest bit is at index zero. The returned BDD is the one
     * stored in the bit vector, not a copy.
     *
     * @param index The 0-based index of the bit.
     * @return The BDD for the bit with the given index.
     * @throws IndexOutOfBoundsException If the index is negative, or greater than or equal to {@link #length}.
     */
    public BDD getBit(int index) {
        return bits[index];
    }

    /**
     * Returns the value represented by the bit vector, if it is a constant bit vector, or {@code null} otherwise.
     *
     * @return The value represented by the bit vector, or {@code null}.
     * @throws IllegalStateException If the bit vector doesn't fit in a Java 'int'.
     */
    public abstract Integer getInt();

    /**
     * Returns the value represented by the bit vector, if it is a constant bit vector, or {@code null} otherwise.
     *
     * @return The value represented by the bit vector, or {@code null}.
     * @throws IllegalStateException If the bit vector doesn't fit in a Java 'long'.
     */
    public abstract Long getLong();

    /**
     * Updates the bit vector, setting the bit with the given index to a given BDD. The previous BDD stored at the bit,
     * if any, is first {@link BDD#free freed}, unless it is the given BDD itself.
     *
     * @param idx The 0-based index of the bit to set.
     * @param bdd The BDD to use as new value for the given bit.
     * @throws IndexOutOfBoundsException If the index is negative, or greater than or equal to {@link #length}.
     */
    public void setBit(int idx, BDD bdd) {
        BDD previous = bits[idx];

        // Don't free the previous BDD if the bit is not yet set. Also don't free it if the very same BDD instance is
        // stored again, as that would leave a freed BDD in the bit vector.
        if (previous != null && previous != bdd) {
            previous.free();
        }

        bits[idx] = bdd;
    }

    /**
     * Updates the bit vector, setting the bit with the given index to a given value. The previous BDD stored at the
     * bit, if any, is first {@link BDD#free freed}.
     *
     * @param idx The 0-based index of the bit to set.
     * @param value The boolean value to use as new value for the given bit.
     * @throws IndexOutOfBoundsException If the index is negative, or greater than or equal to {@link #length}.
     */
    public void setBit(int idx, boolean value) {
        setBit(idx, value ? factory.one() : factory.zero());
    }

    /**
     * Updates the bit vector, setting each bit to the given value. The BDDs that were previously stored, if any, are
     * first {@link BDD#free freed}.
     *
     * @param value The value to set for each bit.
     */
    public void setBitsToValue(boolean value) {
        for (int i = 0; i < bits.length; i++) {
            // Bits that are not yet set, have nothing to free.
            if (bits[i] != null) {
                bits[i].free();
            }
            bits[i] = value ? factory.one() : factory.zero();
        }
    }

    /**
     * Updates the bit vector to represent the given value. The BDDs that were previously stored, are first
     * {@link BDD#free freed}.
     *
     * @param value The value to which to set the bit vector.
     * @throws IllegalArgumentException If the value isn't supported by the bit vector representation.
     * @throws IllegalArgumentException If the value doesn't fit within the bit vector.
     */
    public abstract void setInt(int value);

    /**
     * Updates the bit vector to represent the given non-empty BDD domain. If the domain is smaller than the bit vector,
     * the higher bits of the bit vector are set to 'false'. The BDDs that were previously stored, are first
     * {@link BDD#free freed}.
     *
     * @param domain The domain to which to set the bit vector.
     * @throws IllegalArgumentException If the domain is empty.
     * @throws IllegalArgumentException If the domain doesn't fit in the bit vector.
     */
    public abstract void setDomain(BDDDomain domain);

    /**
     * Resizes the bit vector to have the given length. If the new length is larger than the current length, the way the
     * additional (most significant) bits are set depends on the bit vector representation. If the new length is smaller
     * than the current length, the most significant bits are dropped. The BDDs for dropped bits, are {@link BDD#free
     * freed}.
     *
     * @param length The new length of the bit vector.
     * @throws IllegalArgumentException If the new length is negative, or not supported by the bit vector
     *     representation.
     */
    public abstract void resize(int length);

    /**
     * Negates this bit vector. This operation returns a new bit vector. The vector that is negated is not modified or
     * {@link #free freed}.
     *
     * @return The result.
     * @throws UnsupportedOperationException If the bit vector representation doesn't support the operation.
     */
    public abstract TC negate();

    /**
     * Computes the absolute value of this bit vector. This operation returns a new bit vector. The vector for which the
     * absolute value is computed is not modified or {@link #free freed}.
     *
     * @return The result.
     */
    public abstract TC abs();

    /**
     * Adds the given bit vector to this bit vector. This operation returns a new bit vector and carry. The bit vectors
     * on which the operation is performed are not modified or {@link #free freed}.
     *
     * @param other The bit vector to add to this bit vector.
     * @return The result.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public abstract TC add(T other);

    /**
     * Subtracts the given bit vector from this bit vector. This operation returns a new bit vector and carry. The bit
     * vectors on which the operation is performed are not modified or {@link #free freed}.
     *
     * @param other The bit vector to subtract from this bit vector.
     * @return The result.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public abstract TC subtract(T other);

    /**
     * Computes the quotient ('div' result) of dividing this vector (the dividend) by the given value (the divisor).
     * This operation returns a new bit vector. The bit vector on which the operation is performed is not modified or
     * {@link #free freed}.
     *
     * @param divisor The value by which to divide this bit vector.
     * @return The quotient ('div' result).
     * @throws IllegalArgumentException If the divisor is not positive.
     * @throws IllegalArgumentException If the divisor doesn't fit within this bit vector.
     * @throws IllegalArgumentException If the operation can't be performed due to additional bit vector
     *     representation-specific constraints.
     */
    public abstract T div(int divisor);

    /**
     * Computes the remainder ('mod' result) of dividing this vector (the dividend) by the given value (the divisor).
     * This operation returns a new bit vector. The bit vector on which the operation is performed is not modified or
     * {@link #free freed}.
     *
     * @param divisor The value by which to divide this bit vector.
     * @return The remainder ('mod' result).
     * @throws IllegalArgumentException If the divisor is not positive.
     * @throws IllegalArgumentException If the divisor doesn't fit within this bit vector.
     * @throws IllegalStateException If the operation can't be performed due to additional bit vector
     *     representation-specific constraints.
     */
    public abstract T mod(int divisor);

    /**
     * Computes the bit vector resulting from shifting this bit vector {@code amount} bits to the left. The given
     * {@code carry} is shifted in {@code amount} times. This operation returns a new bit vector. The bit vector on
     * which the operation is performed is not modified or {@link #free freed}.
     *
     * @param amount The amount of bits to shift.
     * @param carry The carry to shift in. Only copies are used.
     * @return The shifted bit vector.
     * @throws IllegalArgumentException If the shift amount is negative.
     */
    public T shiftLeft(int amount, BDD carry) {
        // Precondition check.
        if (amount < 0) {
            throw new IllegalArgumentException("Amount is negative.");
        }

        // Compute result.
        T result = createEmpty(bits.length);

        // The lowest bits get the carry. If the shift amount exceeds the length, all bits get the carry.
        int numberOfCarryBits = Math.min(bits.length, amount);
        int i = 0;
        for (; i < numberOfCarryBits; i++) {
            result.bits[i] = carry.id();
        }

        // The remaining higher bits get the original bits, moved up by the shift amount.
        for (; i < bits.length; i++) {
            result.bits[i] = bits[i - amount].id();
        }

        return result;
    }

    /**
     * Computes the bit vector resulting from shifting this bit vector {@code amount} bits to the right. The given
     * {@code carry} is shifted in {@code amount} times. This operation returns a new bit vector. The bit vector on
     * which the operation is performed is not modified or {@link #free freed}.
     *
     * @param amount The amount of bits to shift.
     * @param carry The carry to shift in. Only copies are used.
     * @return The shifted bit vector.
     * @throws IllegalArgumentException If the shift amount is negative.
     */
    public T shiftRight(int amount, BDD carry) {
        // Precondition check.
        if (amount < 0) {
            throw new IllegalArgumentException("Amount is negative.");
        }

        // Compute result.
        T result = createEmpty(bits.length);

        // The lowest bits get the original bits, moved down by the shift amount. If the shift amount exceeds the
        // length, no original bits are preserved.
        int numberOfPreservedBits = Math.max(0, bits.length - amount);
        int i = 0;
        for (; i < numberOfPreservedBits; i++) {
            result.bits[i] = bits[i + amount].id();
        }

        // The remaining higher bits get the carry.
        for (; i < bits.length; i++) {
            result.bits[i] = carry.id();
        }

        return result;
    }

    /**
     * Computes an if-then-else with this bit vector as the 'then' value, and a given 'if' condition and 'else' value.
     * This operation returns a new bit vector. The bit vectors and condition on which the operation is performed are
     * not modified or {@link #free freed}.
     *
     * @param elseVector The 'else' bit vector.
     * @param condition The 'if' condition.
     * @return The result.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public T ifThenElse(T elseVector, BDD condition) {
        // Precondition check.
        if (this.bits.length != elseVector.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result, by applying the if-then-else to each pair of corresponding bits.
        T result = createEmpty(bits.length);
        for (int i = 0; i < bits.length; i++) {
            result.bits[i] = condition.ite(this.getBit(i), elseVector.getBit(i));
        }
        return result;
    }

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be strictly less than the given bit
     * vector. This operation returns a new BDD. The bit vectors on which the operation is performed are not modified or
     * {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be strictly less than the given bit
     *     vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public abstract BDD lessThan(T other);

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be less than or equal to the given
     * bit vector. This operation returns a new BDD. The bit vectors on which the operation is performed are not
     * modified or {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be less than or equal to the given
     *     bit vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public abstract BDD lessOrEqual(T other);

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be strictly greater than the given
     * bit vector. This operation returns a new BDD. The bit vectors on which the operation is performed are not
     * modified or {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be strictly greater than the given
     *     bit vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public BDD greaterThan(T other) {
        // Strictly greater than is the negation of less than or equal.
        BDD le = lessOrEqual(other);
        BDD gt = le.not();
        le.free();
        return gt;
    }

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be greater than or equal to the
     * given bit vector. This operation returns a new BDD. The bit vectors on which the operation is performed are not
     * modified or {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be greater than or equal to the
     *     given bit vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public BDD greaterOrEqual(T other) {
        // Greater than or equal is the negation of strictly less than.
        BDD lt = lessThan(other);
        BDD ge = lt.not();
        lt.free();
        return ge;
    }

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be equal to the given bit vector.
     * This operation returns a new BDD. The bit vectors on which the operation is performed are not modified or
     * {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be equal to the given bit vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public BDD equalTo(T other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result. The bit vectors are equal if and only if all pairs of corresponding bits are equal.
        BDD eq = factory.one();
        for (int i = 0; i < bits.length; i++) {
            BDD bit = this.bits[i].biimp(other.bits[i]);
            eq = eq.andWith(bit);
        }
        return eq;
    }

    /**
     * Returns a BDD indicating the conditions that must hold for this bit vector to be unequal to the given bit vector.
     * This operation returns a new BDD. The bit vectors on which the operation is performed are not modified or
     * {@link #free freed}.
     *
     * @param other The bit vector to compare against.
     * @return A BDD indicating the conditions that must hold for this bit vector to be unequal to the given bit vector.
     * @throws IllegalArgumentException If this bit vector and the given bit vector have a different length.
     */
    public BDD unequalTo(T other) {
        // Unequal is the negation of equal.
        BDD eq = this.equalTo(other);
        BDD uneq = eq.not();
        eq.free();
        return uneq;
    }

    /**
     * Frees the bit vector, {@link BDD#free freeing} the BDDs representing the bits. Bits that are not set
     * ({@code null}) are skipped. The bit vector should not be used after calling this method. Freeing a bit vector
     * that is already freed has no effect.
     */
    public void free() {
        // Nothing to do if already freed.
        if (bits == null) {
            return;
        }

        // Free the bits. Bits of a not yet fully populated bit vector may still be 'null'.
        for (BDD bit: bits) {
            if (bit != null) {
                bit.free();
            }
        }

        // Mark the bit vector as freed.
        factory = null;
        bits = null;
    }

    @Override
    public String toString() {
        if (bits == null) {
            return "freed";
        }
        return Arrays.toString(bits);
    }
}
