# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy.  If not, see <http://www.gnu.org/licenses/>.

import ast
import functools
import numbers
import re

from collections import defaultdict, OrderedDict
# Mapping lives in collections.abc since Python 3.3; the alias in collections
# was removed in Python 3.10.
from collections.abc import Mapping
from fractions import Fraction

try:
    # fractions.gcd was deprecated in Python 3.5 and removed in Python 3.9.
    from math import gcd
except ImportError:  # Python < 3.5
    from fractions import gcd


# Public names exported by a star import of this module.
__all__ = [
    'Dummy',
    'LinExpr',
    'Rational',
    'Symbol',
    'symbols',
]


def _polymorphic(func):
    """
    Decorate a binary LinExpr method so that its right operand may be either
    a LinExpr or a plain rational number. Rational numbers are wrapped into
    Rational instances before calling func; any other operand makes the
    wrapper return NotImplemented, so Python tries the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper


class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients
        and the constant term must be rational numbers.

        For example, the linear expression x + 2*y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2y + 1')

        Terms with a zero coefficient are dropped first. Then a linear
        expression with a single symbol of coefficient 1 and no constant term
        is automatically subclassed as a Symbol instance, and a linear
        expression with no symbol, only a constant term, is automatically
        subclassed as a Rational instance. For instance, x - x yields
        Rational(0) and (x + y) - y yields the symbol x.

        Raise TypeError if a symbol is not a Symbol instance, if a coefficient
        or the constant term is not a rational number, or if a constant is
        given together with a string. Raise SyntaxError if a string argument
        cannot be parsed.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop zero terms *before* simplifying. Otherwise {x: 0} (e.g. from
        # x - x) would produce a symbol-less LinExpr instead of a Rational,
        # and {x: 1, y: 0} a LinExpr that compares equal to the Symbol x but
        # does not hash like it.
        coefficients = [(symbol, Fraction(coefficient))
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if len(coefficients) == 0:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        # Store the terms in a canonical order, so that equal expressions
        # have equal term tuples (see __hash__).
        coefficients.sort(key=lambda item: item[0].sortkey())
        self = object.__new__(cls)
        self._coefficients = OrderedDict(coefficients)
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if symbol is not a
        Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        # Terms are stored sorted by symbol (see __new__), so the item tuple
        # is a canonical representation of the linear part.
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = (
                (symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = (
                (symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal. Unlike methods
        LinExpr.__lt__(), LinExpr.__le__(), LinExpr.__ge__(), LinExpr.__gt__(),
        the result is a boolean value, not a polyhedron. To express that two
        linear expressions are equal or not equal, use functions Eq() and Ne()
        instead.
        """
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    @_polymorphic
    def __lt__(self, other):
        """
        Return the polyhedron built from the inequality expression
        other - self - 1 (strict comparison over integer values).
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self - 1])

    @_polymorphic
    def __le__(self, other):
        """
        Return the polyhedron built from the inequality expression
        other - self.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self])

    @_polymorphic
    def __ge__(self, other):
        """
        Return the polyhedron built from the inequality expression
        self - other.
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other])

    @_polymorphic
    def __gt__(self, other):
        """
        Return the polyhedron built from the inequality expression
        self - other - 1 (strict comparison over integer values).
        """
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other - 1])

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant term, so that all values
        become integers.
        """
        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs. Substitutions are performed
        simultaneously:

        >>> e.subs({x: y, y: x})
        2*x + y + 1
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for old in substitutions:
            if not isinstance(old, Symbol):
                raise TypeError('symbols must be Symbol instances')
        result = Rational(self._constant)
        for old, coefficient in self._coefficients.items():
            # Symbols without a substitution are kept unchanged.
            result += coefficient * substitutions.get(old, old)
        return result

    # AST node class produced by the parser for numeric literals: ast.Num
    # before Python 3.8, ast.Constant since then. ast.Num was deprecated in
    # Python 3.8 and removed in 3.14, so it is not referenced directly.
    _AST_NUMBER = type(ast.parse('0', mode='eval').body)

    @classmethod
    def _fromast(cls, node):
        """
        Build a linear expression from a Python AST node. Raise SyntaxError
        if the node is not a supported construct, and TypeError if the result
        would not be linear (e.g. the product of two symbols).
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node of ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, cls._AST_NUMBER):
            # ast.Constant also holds strings, bytes, booleans, None, ...;
            # only numbers are accepted. ast.Num stores its value in .n.
            value = node.value if hasattr(node, 'value') else node.n
            if isinstance(value, numbers.Number) and \
                    not isinstance(value, bool):
                return Rational(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis, followed by a name or an opening
    # parenthesis: used to insert implicit multiplications ('2x' -> '2*x').
    # The lookbehind keeps digits that are part of a name (as in 'x2y') from
    # being treated as a number.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Implicit multiplications such as
        '2x' or '3(x + 1)' are accepted, and surrounding whitespace is
        ignored. Raise SyntaxError if the string is not properly formatted,
        and TypeError if the expression is not linear.
        """
        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string.strip())
        # Parse a single expression. The mode must be passed by keyword: the
        # second positional parameter of ast.parse() is the file name.
        tree = ast.parse(string, mode='eval')
        expression = cls._fromast(tree)
        if not isinstance(expression, cls):
            raise SyntaxError('invalid syntax')
        return expression

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        # Wrap the string in parentheses unless the expression is a bare
        # constant or symbol (or always is True).
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expression):
        """
        Create a linear expression from a SymPy expression. Raise TypeError if
        the SymPy expression is not linear, contains dummy symbols, or cannot
        be converted to an instance of cls.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expression.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # We cannot properly convert dummy symbols with respect to
                # symbol equalities.
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(
                    expression))
        expression = LinExpr(coefficients, constant)
        if not isinstance(expression, cls):
            raise TypeError('cannot convert to a {} instance'.format(
                cls.__name__))
        return expression

    def tosympy(self):
        """
        Convert the linear expression to a SymPy expression. Values are
        converted to exact sympy.Rational numbers, so the result is a SymPy
        object even when the expression is constant.
        """
        import sympy

        def torational(value):
            return sympy.Rational(value.numerator, value.denominator)

        expression = torational(self.constant)
        for symbol, coefficient in self.coefficients():
            expression += torational(coefficient) * sympy.Symbol(symbol.name)
        return expression


class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    __slots__ = (
        '_name',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument. The name must
        be a single Python identifier, as recognized by the ast module.

        Raise TypeError if name is not a string, and SyntaxError if it is not
        a single identifier.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Parse in 'eval' mode so that exactly one expression is accepted.
        # Parsing in 'exec' mode made '' raise IndexError and silently
        # accepted inputs such as 'x; y' or 'y = x' as the symbol x.
        try:
            node = ast.parse(name, mode='eval').body
        except (SyntaxError, ValueError):
            # ValueError: null bytes in the source (raised by some Python
            # versions instead of SyntaxError).
            node = None
        if not isinstance(node, ast.Name):
            raise SyntaxError('invalid syntax')
        self = object.__new__(cls)
        self._name = node.id
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def _coefficients(self):
        # This is not implemented as an attribute, because __hash__ is not
        # callable in __new__ in class Dummy.
        return {self: Fraction(1)}

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Symbols compare by sort key (the name, plus an index for Dummy);
        # comparisons with other types are delegated.
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name


def symbols(names):
    """
    Return a tuple of Symbol instances, one per name. The names are given
    either as a single string, where they are separated by commas and/or
    whitespace, or as a sequence of strings. It is useful to define several
    symbols at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # Commas act as whitespace; split() then drops empty fields.
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))


class Dummy(Symbol):
    """
    A variation of Symbol in which all symbols are unique and identified by
    an internal count index. If a name is not supplied then a string value
    of the count index will be used. This is useful when a unique, temporary
    variable is needed and the name of the variable used in the expression
    is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Index given to the next Dummy instance; shared by all instances.
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol with the name string given in argument,
        or 'Dummy_<index>' when no name is given.
        """
        index = Dummy._count
        if name is None:
            name = 'Dummy_{}'.format(index)
        instance = super().__new__(cls, name)
        # The counter only advances once the name has been validated.
        instance._index = index
        Dummy._count = index + 1
        return instance

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        # The index makes dummies with the same name distinct.
        return self._name, self._index

    def __repr__(self):
        return '_{}'.format(self.name)


class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    __slots__ = ('_coefficients', '_constant', '_symbols',
                 '_dimension') + Fraction.__slots__

    def __new__(cls, numerator=0, denominator=None):
        """
        Return the constant expression numerator/denominator. The arguments
        are interpreted by fractions.Fraction.
        """
        value = Fraction(numerator, denominator)
        instance = object.__new__(cls)
        # LinExpr state: no linear term, the value is the constant term.
        instance._coefficients = {}
        instance._constant = value
        instance._symbols = ()
        instance._dimension = 0
        # Fraction state: the normalized numerator and denominator.
        instance._numerator = value.numerator
        instance._denominator = value.denominator
        return instance

    def __hash__(self):
        return Fraction.__hash__(self)

    @property
    def constant(self):
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator != 1:
            return '{!r}/{!r}'.format(numerator, denominator)
        return '{!r}'.format(numerator)
