import ast
import functools
import numbers
import re

from collections import OrderedDict, defaultdict
from fractions import Fraction

try:
    from fractions import gcd
except ImportError:
    # fractions.gcd was removed in Python 3.9; math.gcd returns the same
    # value for non-negative arguments.
    from math import gcd


# Public names exported by ``from <this module> import *``.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]


def _polymorphic(func):
    """
    Decorate a binary method so its right operand is always an Expression.

    Expression operands are passed through unchanged, plain rational numbers
    (``numbers.Rational``: ints, Fractions, ...) are wrapped in ``Rational``,
    and any other operand makes the method return ``NotImplemented`` so that
    Python can try the reflected operation instead.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper


class Expression:
    """
    This class implements linear expressions.

    An instance represents ``sum(coefficient * symbol) + constant`` where the
    coefficients and the constant are rational numbers.  Terms are stored in
    an OrderedDict sorted by ``Symbol.sortkey()``, so equal expressions keep
    their terms in the same order.

    Construction normalizes its result: with no non-zero term it returns a
    ``Rational``, and for exactly ``1*symbol`` it returns the ``Symbol``
    itself.  Arithmetic accepts other expressions and plain rational numbers;
    a product or quotient that would not be linear raises ValueError.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with ``fromstring``), a dict mapping
            Symbol to coefficient, or any iterable of (Symbol, coefficient)
            pairs.  Coefficients may be rational numbers or Rational
            instances; zero coefficients are dropped.  ``None`` yields a
            constant expression.
        constant -- a rational number or a Rational instance.

        Raises TypeError if a key is not a Symbol or if a string is combined
        with a non-zero constant; when a full Expression is built, also if a
        coefficient or the constant is not rational.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        # The pairs are traversed more than once below; materialize them so a
        # one-shot iterator (e.g. a generator) is not exhausted by the
        # validation pass, which would silently drop every term.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                # 1*x is represented by the Symbol x itself
                return symbol
        self = object.__new__(cls)
        self._coefficients = OrderedDict()
        for symbol, coefficient in sorted(coefficients,
                key=lambda item: item[0].sortkey()):
            if isinstance(coefficient, Rational):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Rational instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Rational):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Rational instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of ``symbol`` (0 if it does not occur).

        Raises TypeError if ``symbol`` is not a Symbol.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbols with a non-zero coefficient, in sorted order."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return self._dimension

    def __hash__(self):
        # Built from the same data __eq__ compares: ordered terms + constant.
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        return False

    def issymbol(self):
        return False

    def values(self):
        """Yield every coefficient (in symbol order), then the constant."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        # A non-constant expression is never the zero expression.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        # int default: a symbol missing from self starts at 0 and stays a
        # plain number instead of going through Rational arithmetic.
        coefficients = defaultdict(int, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] += coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = defaultdict(int, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] -= coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # other - self; _polymorphic returns NotImplemented for operand
        # types it does not support, so Python reports the right TypeError.
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands are non-constant (the product
        would not be linear).
        """
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant but other is not: let other.__rmul__ handle it.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant, keeping coefficients exact (Fraction-based).

        Raises ValueError when the divisor is not constant and
        ZeroDivisionError when it is zero.
        """
        if other.isconstant():
            coefficients = {symbol: Rational(coefficient, other.constant)
                for symbol, coefficient in self.coefficients()}
            constant = Rational(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
            '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Compute ``other / self`` where ``other`` is a plain rational number.

        Only a constant ``self`` keeps the result linear; otherwise raises
        ValueError.
        """
        if self.isconstant():
            # other is an Expression here (wrapped by _polymorphic), so this
            # dispatches to Expression.__truediv__.
            return other / self
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Compare through the public accessors: Symbol and Rational do not
        # populate the '_coefficients' slot, so reading it on them would
        # raise AttributeError.  Terms are sorted, so equal expressions
        # yield their pairs in the same order.
        return isinstance(other, Expression) and \
            tuple(self.coefficients()) == tuple(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that every
        coefficient and the constant of the result are integers.
        """
        # fractions.gcd no longer exists on Python 3.9+; math.gcd agrees
        # with it for the positive denominators used here.
        import math
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
            (value.denominator for value in self.values()), 1)
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions for symbols.

        Accepts ``subs(symbol, expression)``, ``subs({symbol: expression})``
        or ``subs(iterable of (symbol, expression) pairs)``.  Pairs are
        applied one after another, so a later pair operates on the result
        of the earlier ones.
        """
        if expression is None:
            if isinstance(symbol, dict):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result.coefficients()
                if othersymbol != symbol]
            coefficient = result.coefficient(symbol)
            constant = result.constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Convert a parsed Python AST into an Expression.

        Supported: a module holding a single expression statement, names
        (Symbols), numeric literals, unary ``-``/``+`` and binary ``+``,
        ``-``, ``*``, ``/``.  Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif (isinstance(node, ast.Constant)
                and isinstance(node.value, numbers.Number)
                and not isinstance(node.value, bool)):
            # Python 3.8+ parses numeric literals to ast.Constant.
            return Rational(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        elif type(node).__name__ == 'Num':
            # Python < 3.8 emits ast.Num for numeric literals.  Matched by
            # name so newer Pythons (where ast.Num is deprecated, and removed
            # in 3.14) never touch the attribute.
            return Rational(node.n)
        raise SyntaxError('invalid syntax')

    # An integer literal or ')' followed by a name or '(' (e.g. '5x', '2 (',
    # ')y').  The look-behind keeps digits inside identifiers such as 'x2y'
    # from being split into 'x2*y'.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression from a string such as '2x + 3y - 1'.

        An integer or ')' immediately followed by a name or '(' is read as an
        implicit product ('5x' -> '5*x').  Raises SyntaxError for input that
        is not a supported expression and ValueError for non-linear input.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # ast.parse's second positional parameter is the *filename*, not the
        # mode; the default 'exec' mode yields an ast.Module, which _fromast
        # unwraps.
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __repr__(self):
        """Render as e.g. '2*x - y + 1/2' (terms in sorted symbol order)."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += '' if i == 0 else ' + '
                string += '{!r}'.format(symbol)
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
                string += '{!r}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{!r}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{!r}'.format(coefficient, symbol)
                else:
                    string += ' - {}*{!r}'.format(-coefficient, symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """str(self), parenthesized unless it is a constant or a symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an Expression from a linear sympy expression.

        Raises ValueError if a term is not a plain sympy Symbol or the
        constant.  Assumes coefficients are sympy Rationals (they are read
        through ``.p``/``.q``).
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr


class Symbol(Expression):
    """
    A named variable, behaving as the linear expression ``1*name``.

    Surrounding whitespace is stripped from the name.  Symbols compare equal
    by name and never compare equal to a Dummy; ``sortkey()`` orders them
    inside expressions.
    """

    __slots__ = ('_name',)

    def __new__(cls, name):
        if isinstance(name, str):
            instance = object.__new__(cls)
            instance._name = name.strip()
            return instance
        raise TypeError('name must be a string')

    @property
    def name(self):
        """The (stripped) name of the symbol."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def coefficient(self, symbol):
        """Return 1 for this symbol, 0 for any other Symbol."""
        if isinstance(symbol, Symbol):
            return 1 if symbol == self else 0
        raise TypeError('symbol must be a Symbol instance')

    def coefficients(self):
        """Yield the single (symbol, coefficient) pair: (self, 1)."""
        yield (self, 1)

    @property
    def constant(self):
        return 0

    @property
    def symbols(self):
        return (self,)

    @property
    def dimension(self):
        return 1

    def sortkey(self):
        """Ordering key within an expression: a 1-tuple holding the name."""
        return (self.name,)

    def issymbol(self):
        return True

    def values(self):
        """Yield the symbol's only coefficient, 1."""
        # NOTE(review): unlike Expression.values(), no trailing constant (0)
        # is yielded here — confirm this is intended.
        yield 1

    def __eq__(self, other):
        # Plain symbols are equal by name; a Dummy never equals a Symbol.
        if isinstance(other, Symbol) and not isinstance(other, Dummy):
            return self.name == other.name
        return False

    def asdummy(self):
        """Return a new Dummy carrying this symbol's name."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """Convert a parsed AST holding a single bare name into a Symbol."""
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    @classmethod
    def fromsympy(cls, expr):
        """Build a symbol from a sympy.Symbol (TypeError otherwise)."""
        import sympy
        if not isinstance(expr, sympy.Symbol):
            raise TypeError('expr must be a sympy.Symbol instance')
        return cls(expr.name)


class Dummy(Symbol):
    """
    A symbol that is only equal to itself.

    Each Dummy receives a unique index from a class-wide counter.  Equality
    uses only that index and hashing uses (name, index), so two dummies with
    the same name are still distinct, and a Dummy never equals a plain
    Symbol.
    """

    # Only the new attribute is declared: '_name' is already a slot of
    # Symbol, and re-declaring a base-class slot shadows it (the Python data
    # model documents that as leaving the program's meaning undefined).
    __slots__ = (
        '_index',
    )

    # Next index to hand out; it is read and incremented on Dummy itself, so
    # one counter is shared by every Dummy instance.
    _count = 0

    def __new__(cls, name=None):
        """
        Create a new dummy symbol.

        name -- optional string, stripped of surrounding whitespace; defaults
            to 'Dummy_<index>'.  Raises TypeError for a non-string name, as
            Symbol does.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        self._index = Dummy._count
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Ordering key: (name, index), so same-named dummies stay distinct."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)


def symbols(names):
    """
    Create several Symbols at once.

    ``names`` is either a string of names separated by commas and/or
    whitespace (e.g. 'x, y z'), or an iterable of name strings.  Returns a
    tuple of Symbols in the given order.
    """
    if isinstance(names, str):
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))


class Rational(Expression):
    """
    A constant linear expression: a rational number with no symbol terms.

    The value is stored as a ``fractions.Fraction`` and exposed through the
    inherited ``constant`` property.  A Rational compares and hashes like its
    numeric value, e.g. ``Rational(3) == 3`` and
    ``hash(Rational(3)) == hash(3)``.
    """

    # '_constant' is already a slot of Expression; re-declaring a base-class
    # slot would shadow it (undefined per the Python data model).
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a rational constant.

        numerator, denominator -- anything ``fractions.Fraction`` accepts
            (e.g. ints, Fractions, or a string such as '1/2' when the
            denominator is omitted), or Rational instances, which are
            unwrapped.  Raises ZeroDivisionError for a zero denominator.
        """
        if isinstance(numerator, Rational):
            numerator = numerator.constant
        if isinstance(denominator, Rational):
            denominator = denominator.constant
        self = object.__new__(cls)
        self._constant = Fraction(numerator, denominator)
        return self

    def __hash__(self):
        # Same hash as the underlying number, consistent with __eq__ below.
        return hash(self.constant)

    def coefficient(self, symbol):
        """Return 0: a constant has no symbol terms."""
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 0

    def coefficients(self):
        """Yield nothing: a constant has no symbol terms."""
        yield from ()

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    def values(self):
        """Yield the constant value."""
        yield self._constant

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Rational) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """Parse a number such as '3', '-1/2' or '0.25' (via Fraction)."""
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        return Rational(Fraction(string))

    @classmethod
    def fromsympy(cls, expr):
        """Build from a sympy.Rational or a plain rational number."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Rational(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')
