# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy.  If not, see <http://www.gnu.org/licenses/>.

import math
import numbers
import operator

from abc import ABC, abstractmethod, abstractproperty
from collections import OrderedDict
from collections.abc import Mapping

from .linexprs import Symbol


# Public API of the module: the names exported by a star-import. The
# Coordinates class defined below is not part of this list.
__all__ = [
    'GeometricObject',
    'Point',
    'Vector',
]


class GeometricObject(ABC):
    """
    GeometricObject is an abstract class to represent objects with a
    geometric representation in space. Subclasses of GeometricObject are
    Polyhedron, Domain and Point.

    Concrete subclasses must provide the symbols property and the
    aspolyhedron() method; dimension and asdomain() are derived from them.
    """

    # abc.abstractproperty is deprecated; stacking property on top of
    # abstractmethod is the supported way to declare an abstract property.
    @property
    @abstractmethod
    def symbols(self):
        """
        The tuple of symbols present in the object expression, sorted according
        to Symbol.sortkey().
        """
        pass

    @property
    def dimension(self):
        """
        The dimension of the object, i.e. the number of symbols present in it.
        """
        return len(self.symbols)

    @abstractmethod
    def aspolyhedron(self):
        """
        Return a Polyhedron object that approximates the geometric object.
        """
        pass

    def asdomain(self):
        """
        Return a Domain object that approximates the geometric object. The
        default implementation returns the result of aspolyhedron().
        """
        return self.aspolyhedron()


class Coordinates:
    """
    This class represents coordinate systems, i.e. associations between
    symbols and real values. The pairs are stored sorted according to
    Symbol.sortkey().
    """

    __slots__ = (
        '_coordinates',
    )

    def __new__(cls, coordinates):
        """
        Create a coordinate system from a dictionary or a sequence of pairs
        (symbol, coordinate) that maps the symbols to their coordinates.
        Coordinates must be real numbers, i.e. instances of numbers.Real.
        Raise TypeError if a symbol is not a Symbol instance or if a coordinate
        is not a real number.
        """
        if isinstance(coordinates, Mapping):
            coordinates = coordinates.items()
        # Call object.__new__ directly instead of going through a throwaway
        # object() instance.
        self = object.__new__(cls)
        pairs = []
        for symbol, coordinate in coordinates:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coordinate, numbers.Real):
                raise TypeError('coordinates must be real numbers')
            pairs.append((symbol, coordinate))
        # Sorting here is what allows symbols, equality and hashing to be
        # independent of the order in which the pairs were supplied.
        pairs.sort(key=lambda item: item[0].sortkey())
        self._coordinates = OrderedDict(pairs)
        return self

    @property
    def symbols(self):
        """
        The tuple of symbols present in the coordinate system, sorted according
        to Symbol.sortkey().
        """
        return tuple(self._coordinates)

    @property
    def dimension(self):
        """
        The dimension of the coordinate system, i.e. the number of symbols
        present in it.
        """
        return len(self.symbols)

    def coordinate(self, symbol):
        """
        Return the coordinate value of the given symbol. Raise TypeError if
        symbol is not a Symbol instance, and KeyError if the symbol is not
        involved in the coordinate system.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coordinates[symbol]

    # Allow the subscript notation coordinates[symbol].
    __getitem__ = coordinate

    def coordinates(self):
        """
        Iterate over the pairs (symbol, value) of coordinates in the coordinate
        system.
        """
        yield from self._coordinates.items()

    def values(self):
        """
        Iterate over the coordinate values in the coordinate system.
        """
        yield from self._coordinates.values()

    def __bool__(self):
        """
        Return True if not all coordinates are 0.
        """
        return any(self._coordinates.values())

    def __eq__(self, other):
        """
        Return True if two coordinate systems are equal, i.e. if they involve
        the same symbols with the same values. Return NotImplemented if other
        is not an instance of the class of self.
        """
        if isinstance(other, self.__class__):
            return self._coordinates == other._coordinates
        return NotImplemented

    def __hash__(self):
        """
        Return a hash computed from the sorted pairs (symbol, value).
        """
        return hash(tuple(self.coordinates()))

    def __repr__(self):
        string = ', '.join(['{!r}: {!r}'.format(symbol, coordinate)
                            for symbol, coordinate in self.coordinates()])
        return '{}({{{}}})'.format(self.__class__.__name__, string)

    def _map(self, func):
        """
        Iterate over the pairs (symbol, func(value)).
        """
        for symbol, coordinate in self.coordinates():
            yield symbol, func(coordinate)

    def _iter2(self, other):
        """
        Iterate over the triples (symbol, value in self, value in other). Raise
        ValueError if self and other do not involve the same symbols.
        """
        if self.symbols != other.symbols:
            raise ValueError('arguments must belong to the same space')
        coordinates1 = self._coordinates.values()
        coordinates2 = other._coordinates.values()
        yield from zip(self.symbols, coordinates1, coordinates2)

    def _map2(self, other, func):
        """
        Iterate over the pairs (symbol, func(value in self, value in other)).
        Raise ValueError if self and other do not involve the same symbols.
        """
        for symbol, coordinate1, coordinate2 in self._iter2(other):
            yield symbol, func(coordinate1, coordinate2)


class Point(Coordinates, GeometricObject):
    """
    This class represents points in space.

    Point instances are hashable and should be treated as immutable.
    """

    def isorigin(self):
        """
        Return True if all coordinates are 0.
        """
        return not bool(self)

    def __hash__(self):
        """
        Return the hash inherited from Coordinates.
        """
        return super().__hash__()

    def __add__(self, other):
        """
        Translate the point by a Vector object and return the resulting point.
        Raise ValueError if the point and the vector do not involve the same
        symbols.
        """
        if isinstance(other, Vector):
            coordinates = self._map2(other, operator.add)
            return Point(coordinates)
        return NotImplemented

    def __sub__(self, other):
        """
        If other is a point, subtract it from self and return the resulting
        vector. If other is a vector, translate the point by the opposite
        vector and return the resulting point. Raise ValueError if the operands
        do not involve the same symbols.
        """
        if isinstance(other, Point):
            coordinates = self._map2(other, operator.sub)
            return Vector(coordinates)
        elif isinstance(other, Vector):
            coordinates = self._map2(other, operator.sub)
            return Point(coordinates)
        return NotImplemented

    def aspolyhedron(self):
        """
        Return the Polyhedron object built from the expressions
        symbol - coordinate, one for each coordinate of the point.
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the polyhedra module -- confirm.
        from .polyhedra import Polyhedron
        equalities = [symbol - coordinate
                      for symbol, coordinate in self.coordinates()]
        return Polyhedron(equalities)


class Vector(Coordinates):
    """
    This class represents vectors in space.

    Vector instances are hashable and should be treated as immutable.
    """

    def __new__(cls, initial, terminal=None):
        """
        Create a vector from a dictionary or a sequence that maps the symbols
        to their coordinates, or as the displacement between two points. In the
        latter case, the vector goes from initial to terminal, and ValueError
        is raised if the two points do not involve the same symbols.
        """
        # Going through Point validates the symbols and the values.
        if not isinstance(initial, Point):
            initial = Point(initial)
        if terminal is None:
            coordinates = initial._coordinates
        else:
            if not isinstance(terminal, Point):
                terminal = Point(terminal)
            coordinates = terminal._map2(initial, operator.sub)
        return super().__new__(cls, coordinates)

    def isnull(self):
        """
        Return True if all coordinates are 0.
        """
        return not bool(self)

    def __hash__(self):
        """
        Return the hash inherited from Coordinates.
        """
        return super().__hash__()

    def __add__(self, other):
        """
        If other is a point, translate it with the vector self and return the
        resulting point. If other is a vector, return the vector self + other.
        Raise ValueError if the operands do not involve the same symbols.
        """
        if isinstance(other, (Point, Vector)):
            coordinates = self._map2(other, operator.add)
            return other.__class__(coordinates)
        return NotImplemented

    def __sub__(self, other):
        """
        If other is a point, subtract its coordinates from those of the vector
        self and return the result as a point. If other is a vector, return the
        vector self - other. Raise ValueError if the operands do not involve
        the same symbols.
        """
        if isinstance(other, (Point, Vector)):
            coordinates = self._map2(other, operator.sub)
            return other.__class__(coordinates)
        return NotImplemented

    def __neg__(self):
        """
        Return the vector -self.
        """
        coordinates = self._map(operator.neg)
        return Vector(coordinates)

    def __mul__(self, other):
        """
        Multiply the vector by a scalar value and return the result as a
        vector.
        """
        if isinstance(other, numbers.Real):
            coordinates = self._map(lambda coordinate: other * coordinate)
            return Vector(coordinates)
        return NotImplemented

    # Scalar multiplication is commutative: scalar * vector == vector * scalar.
    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Divide the vector by the specified scalar and return the result as a
        vector.
        """
        if isinstance(other, numbers.Real):
            coordinates = self._map(lambda coordinate: coordinate / other)
            return Vector(coordinates)
        return NotImplemented

    def angle(self, other):
        """
        Return the angle between the vector and the vector passed in argument.
        The result is an unsigned angle in radians, ranging between 0 and pi.
        Raise TypeError if other is not a Vector instance, ValueError if the
        vectors do not involve the same symbols, and ZeroDivisionError if
        either vector is null.
        """
        if not isinstance(other, Vector):
            raise TypeError('argument must be a Vector instance')
        cosinus = self.dot(other) / (self.norm()*other.norm())
        # Rounding errors may push the ratio slightly outside [-1, 1] (e.g.
        # for two collinear vectors), which would make math.acos() fail with
        # a domain error; clamp it back into the valid range.
        if cosinus > 1.0:
            cosinus = 1.0
        elif cosinus < -1.0:
            cosinus = -1.0
        return math.acos(cosinus)

    def cross(self, other):
        """
        Compute the cross product of two 3D vectors. If either one of the
        vectors is not three-dimensional, or if the vectors do not involve the
        same symbols, a ValueError exception is raised.
        """
        if not isinstance(other, Vector):
            raise TypeError('other must be a Vector instance')
        if self.dimension != 3 or other.dimension != 3:
            raise ValueError('arguments must be three-dimensional vectors')
        if self.symbols != other.symbols:
            raise ValueError('arguments must belong to the same space')
        x, y, z = self.symbols
        coordinates = [
            (x, self[y]*other[z] - self[z]*other[y]),
            (y, self[z]*other[x] - self[x]*other[z]),
            (z, self[x]*other[y] - self[y]*other[x]),
        ]
        return Vector(coordinates)

    def dot(self, other):
        """
        Compute the dot product of two vectors. Raise TypeError if other is not
        a Vector instance, and ValueError if the vectors do not involve the
        same symbols.
        """
        if not isinstance(other, Vector):
            raise TypeError('argument must be a Vector instance')
        return sum(coordinate1 * coordinate2
                   for _, coordinate1, coordinate2 in self._iter2(other))

    def norm(self):
        """
        Return the norm of the vector, as a float.
        """
        return math.sqrt(self.norm2())

    def norm2(self):
        """
        Return the squared norm of the vector.
        """
        return sum(coordinate ** 2
                   for coordinate in self._coordinates.values())

    def asunit(self):
        """
        Return the normalized vector, i.e. the vector of same direction but
        with norm 1. Raise ZeroDivisionError if the vector is null.
        """
        return self / self.norm()
