#!/usr/bin/env python3

# This is an implementation of the algorithm described in
#
# [ACI10] C. Ancourt, F. Coelho and F. Irigoin, A modular static analysis
# approach to affine loop invariants detection (2010), pp. 3 - 16, NSAD 2010.
#
# to compute an affine (possibly over-approximating) reflexive transitive
# closure T* of an affine transformer T. A refined version of this algorithm
# is implemented in PIPS.

from linpy import Dummy, Eq, Ge, Polyhedron, symbols


class Transformer:
    """An affine transformer: a polyhedron relating two lists of symbols.

    ``range_symbols`` and ``domain_symbols`` are paired position by
    position. In :meth:`star`, the i-th domain symbol minus the i-th range
    symbol is treated as the displacement of one variable across a single
    application of the transformer.
    """

    def __new__(cls, polyhedron, range_symbols, domain_symbols):
        # Allocate through super() instead of instantiating a throwaway
        # ``object()`` just to reach ``object.__new__``.
        self = super().__new__(cls)
        self.polyhedron = polyhedron
        self.range_symbols = range_symbols
        self.domain_symbols = domain_symbols
        return self

    @property
    def symbols(self):
        """Return the range symbols followed by the domain symbols."""
        return self.range_symbols + self.domain_symbols

    def _link_deltas(self, polyhedron, delta_symbols):
        """Intersect *polyhedron* with ``dx == x' - x`` for each symbol pair."""
        for x, xprime, dx in zip(
                self.range_symbols, self.domain_symbols, delta_symbols):
            polyhedron &= Eq(dx, xprime - x)
        return polyhedron

    def star(self):
        """Return the closure T* of this transformer, following [ACI10].

        1. Introduce one fresh displacement symbol ``dx`` per symbol pair,
           constrained by ``dx == x' - x`` (``x'`` from ``domain_symbols``,
           ``x`` from ``range_symbols``). Then project away the original
           symbols, leaving constraints on the displacements only.
        2. Multiply the constant term of every remaining constraint by a
           fresh parameter ``k >= 0``. This scales the displacement
           polyhedron by ``k`` (``k`` steps). With ``k == 0`` the
           constraints become homogeneous, so the zero displacement (no
           step at all) is admitted.
        3. Project away ``k``, reconnect the displacements to the original
           symbols, and project the displacements away.

        ``k`` is only constrained by ``k >= 0`` (not to integers), so the
        result may over-approximate the exact closure.

        Returns:
            A new ``Transformer`` over the same range and domain symbols.

        Raises:
            ValueError: if ``range_symbols`` and ``domain_symbols`` differ in
                length. Without this check, ``zip`` would silently drop the
                unpaired symbols and yield a wrong closure.
        """
        if len(self.range_symbols) != len(self.domain_symbols):
            raise ValueError(
                'range_symbols and domain_symbols must have the same length')
        delta_symbols = [symbol.asdummy() for symbol in self.range_symbols]
        k = Dummy('k')
        # Step 1: keep only the constraints on the per-step displacements.
        deltas = self._link_deltas(self.polyhedron, delta_symbols)
        deltas = deltas.project(self.symbols)
        # Step 2: rewrite each constant term c as k*c (c + (k-1)*c).
        equalities = [
            equality + (k - 1) * equality.constant
            for equality in deltas.equalities
        ]
        inequalities = [
            inequality + (k - 1) * inequality.constant
            for inequality in deltas.inequalities
        ]
        polyhedron = Polyhedron(equalities, inequalities) & Ge(k, 0)
        polyhedron = polyhedron.project([k])
        # Step 3: express the displacements in terms of the original symbols.
        polyhedron = self._link_deltas(polyhedron, delta_symbols)
        polyhedron = polyhedron.project(delta_symbols)
        return Transformer(polyhedron, self.range_symbols, self.domain_symbols)


if __name__ == '__main__':
    # Demo: one step adds 2 to i and 1 to j (i0/j0 paired with i/j).
    i0, i, j0, j = symbols('i0 i j0 j')
    one_step = Eq(i, i0 + 2) & Eq(j, j0 + 1)
    transformer = Transformer(one_step, [i0, j0], [i, j])
    print('T  =', transformer.polyhedron)
    # Compute the closure only after T is printed, as before.
    closure = transformer.star()
    print('T* =', closure.polyhedron)
