from __future__ import absolute_import

from sage.all import *

from .auxiliary import flatten
from .algorithm import Algorithm
from .ambiguity import Ambiguity
from .linear_algebra import *
from .mixed_polynomial import MixedPolynomial
from .signature_poly import SignaturePoly

import itertools
from time import time
    
############################################################################
############################################################################
# Matrix GVW Algorithm
############################################################################
############################################################################
class Matrix_GVW(Algorithm):
    """
    """
    def __init__(self, M, maxiter=10, maxdeg=-1, count_interval=10, sig_bound=None):
        """
        Set up a matrix GVW run by forwarding all arguments to ``Algorithm``.

        INPUT:

        - ``M`` -- passed through to ``Algorithm.__init__`` unchanged
          (presumably the input generators -- confirm against ``Algorithm``)
        - ``maxiter`` -- upper bound on the number of main-loop iterations
          performed by ``compute_basis``
        - ``maxdeg`` -- degree bound handed to ``Ambiguity.generate`` when
          critical pairs are computed
        - ``count_interval`` -- a progress line is printed every
          ``count_interval`` iterations
        - ``sig_bound`` -- if set, critical pairs and syzygies whose signature
          is not below this bound are discarded
        """
        # This subclass stores nothing of its own here; the base class keeps
        # the configuration and exposes it through accessor methods.
        super().__init__(
            M,
            maxiter,
            maxdeg,
            count_interval,
            sig_bound,
        )
        
############################################################################
    def compute_basis(self): 
        global zero_red
        reductions = zero_red = 0 
                
        pairs = copy(self.gens())
                         
        count = 1
        sig_bound = self.sig_bound()
        maxiter = self.maxiter()
        maxdeg = self.maxdeg()
        count_interval = self.count_interval()
        oldlen = 0
        prev_sig = None
                
        while count <= maxiter and len(pairs) > 0:
                       
            # sort critical pairs
            pairs.sort(key = lambda p : (p.sig(),p.lm()))
                                                                        
            # select pairs
            #d = len(pairs[0].sig())
            #P = [p for p in pairs if len(p.sig()) == d]
            P = pairs[:1]
                        
            pairs = pairs[len(P):]
            
            # check singular criterion
            if prev_sig and prev_sig == P[0].sig(): continue
            prev_sig = P[0].sig()
                                                                                
            # apply criteria
            P = self.criteria(P)

            if P:
                reductions += len(P)              
                # reduce pairs         
                new_poly,new_syz = self.reduction(P)
                                                                                                                                                                                                                                                      
                self.update_G(new_poly)
                self.update_H(new_syz)
            
                if new_poly:
                    new_pairs, new_syz = self.compute_crit_pairs(oldlen)
                                        
                    self.update_H(new_syz)
                    oldlen = len(self._G)
                                    
                    # delete pairs with too large signature
                    if sig_bound: new_pairs = [p for p in new_pairs if p.sig() < sig_bound]
                                                                                                                            
                    # append new pairs
                    pairs += new_pairs
                                                                                                           
            if count % count_interval == 0:
                print("Iteration " + str(count) + " finished. G has now " + str(len(self._G)) + " elements.\n")
            count += 1
            
        if count < maxiter:
            print("All critical pairs were reduced to 0.")
        
        self._H = list(set(self._H))
        
        print("reductions = %d" % reductions)
        print("reductions to 0 = %d " % zero_red)
                 
        return self._G, self._H
############################################################################
############################################################################
# Compute crit pairs
############################################################################
############################################################################         
    def compute_crit_pairs(self, oldlen=0):
        """
        Build the critical pairs that involve at least one new element of G.

        INPUT:

        - ``oldlen`` -- number of leading entries of ``self._G`` that were
          already paired up in earlier calls; only combinations old/new and
          new/new are generated

        OUTPUT:

        A pair ``(crit_pairs, syz)``: the de-duplicated critical pairs whose
        ``degree()`` is greater than ``-1``, and the de-duplicated signatures
        of the pairs whose ``degree()`` equals ``-1`` (treated as syzygies).

        Two timing lines are printed as a side effect.
        """
        degree_cap = self.maxdeg()

        # every leading monomial is tagged with the index of its element in G
        indexed_lms = [(idx, g.lm()) for idx, g in enumerate(self._G)]
        known = indexed_lms[:oldlen]
        fresh = indexed_lms[oldlen:]

        start = time()
        found = []
        # old element against new element
        for v, w in itertools.product(known, fresh):
            found.append(Ambiguity.generate(v, w, degree_cap))
        # new element against new element, including an element with itself
        for v, w in itertools.combinations_with_replacement(fresh, 2):
            found.append(Ambiguity.generate(v, w, degree_cap))
        amb = list(set(flatten(found)))

        print(str(len(amb)) + " ambiguities in total (computation took %.5f)" % (time()-start))

        start = time()
        candidates = [a.to_crit_pair(self._G[a.i()], self._G[a.j()]) for a in amb]

        # degree -1 marks a pair that only contributes a syzygy signature
        syz = list({c.sig() for c in candidates if c.degree() == -1})

        # everything else is a regular, non-zero S-polynomial
        crit_pairs = list({c for c in candidates if c.degree() > -1})

        print(str(len(crit_pairs)) + " critical pairs were generated (computation took %.5f)" % (time()-start))

        return crit_pairs, syz
############################################################################
############################################################################
# Reduction & Symbolic Preprocessing
############################################################################
############################################################################
    def symbolic_preprocessing(self, P):
        
        sigma = max(p.sig() for p in P)
                     
        done = set()
        T = {m for p in P for m in p.monomials()}
        R = set(P)
                        
        while len(T):
            t = T.pop()
            done.add(t)
            agb = self.find_reducer(t)                                                  
            if agb and agb.sig() < sigma:
                R.add(agb)
                T.update({m for m in agb.monomials() if m not in done})
        return list(R), list(done)
###########################################################################
    def find_reducer(self, t):
        G = self._G
        reducers = []
        for g in G:
            u = t / g.lm()
            if not u: continue
            w1 = g.lm().nc_mon()
            w2 = t.nc_mon()
            if w1:
                reducers += [ g.lrmul(w2[:k],w2[k+len(w1):]) * u  for k in w2.factor_occurrences_iterator(w1)]
            else:
                reducers += [ g.lrmul(w1,w2) * u ]
        
        if reducers: 
            return min(reducers,key=lambda p : p.sig())

        return None  
###########################################################################
    def reduction(self, P):
        """
        Reduce the pairs in ``P`` by row-reducing a matrix.

        The rows and columns come from ``symbolic_preprocessing``; columns
        are sorted in decreasing order and rows by increasing signature.  The
        matrix is reduced with ``signature_safe_reduction``, the result is
        turned back into polynomials and syzygy signatures, and polynomials
        rejected by ``delete_redundant`` are dropped.

        OUTPUT:

        A pair ``(new_poly, new_syz)``.
        """
        rows, cols = self.symbolic_preprocessing(P)
        cols.sort(reverse=True)
        rows.sort(key=lambda f: f.sig())

        # ``blocks`` holds the index at which every run of equal signatures
        # (except the first run) starts, followed by the number of rows.
        blocks = []
        current = rows[0].sig()
        for idx, row in enumerate(rows):
            s = row.sig()
            if not (s == current):
                blocks.append(idx)
                current = s
        blocks.append(len(rows))

        A = signature_safe_reduction(set_up_matrix(rows, cols), blocks)

        new_poly, new_syz = self.reconstruct_polynomials(A, rows, cols)

        return self.delete_redundant(new_poly), new_syz

###########################################################################
    def reconstruct_polynomials(self, M, rows, columns):
        global zero_red

        poly = []
        syz = []
        L = rows[0].parent()
        A = L.algebra()
                               
        for i in [M.nrows()-1]:
            s = rows[i].sig()
            pos = M.nonzero_positions_in_row(i)
            # non syzygy
            if pos:
                coeffs = [M[i,j] for j in pos]
                mons = [columns[j] for j in pos]
                coeffs.reverse()
                mons.reverse()
                f = MixedPolynomial(A,coeffs,mons)
                poly.append(SignaturePoly(L,f,s))
            # syzygy
            else:
                zero_red += 1
                syz.append(s)
                    
        return poly, syz  
############################################################################
    def delete_redundant(self, P):
                        
        out = []
        
        for p in P:
            agb = self.find_reducer(p.lm())
            if not agb or p.sig() < agb.sig():
                out.append(p)
                
        return out              
############################################################################
############################################################################
# Additional stuff
############################################################################
############################################################################ 
    def update_G(self, P):
        """
        Append the elements of ``P`` to ``self._G``.

        ``compute_basis`` calls this with the polynomials returned by
        ``reduction``.
        """
        self._G += P
############################################################################     
    def update_H(self, S):
        sigma = self.sig_bound()
        self._H += [s for s in S if not sigma or s < sigma]
############################################################################ 
############################################################################
############################################################################
# Elimination criteria
############################################################################
############################################################################    
    def F5_criterion(self, p):
        a,j,b = p.sig().aib()
        b = b * a.comm_mon()
        for g in self.G():
            m = g.lm()
            if g.sig().ei() < j and (m.divides(a) or m.divides(b)):
                return True
        return False
            
###########################################################################
    def syzygy_criterion(self, P):
    
        H = self.H()
        out = [p for p in P if not any(p.sig() / h for h in H)]
                 
        return out  
############################################################################               
    def criteria(self, P):
        
        P = self.syzygy_criterion(P)        
        P = [p for p in P if not (self.F5_criterion(p) or self.cover_criterion(p))]
        
        return P
############################################################################            
    def cover_criterion(self, p):
    
        sigma = p.sig()
        m = p.lm()
        for g in self.G():
            uab = sigma / g.sig()
            if uab:
                u,a,b = uab
                lm = g.lm().lrmul(a,b) * u
                if lm < m: 
                    return True
        return False
        
# ###########################################################################
#     cdef list reconstruct_syzygies(Matrix_GVW self):
#         cdef list pivot_rows, rows, cols, syz, coeffs, mons, sigGB, F
#         cdef Matrix_rational_sparse A, T, B
#         cdef SigPoly g
#         cdef Sig s, m
#         cdef str a, b
#         cdef Py_ssize_t i, j, k 
#         cdef LabelledPoly r
#         cdef Rational c, c2
#         cdef mpq_t prod, v
#         
#         mpq_init(prod)
#         mpq_init(v)
#      
#         d = defaultdict(Rational)
#         
#         sigGB = self.G
#         self.G = self.labGB
#         
#         self.lm_automaton.clear()
#         for j,g in enumerate(self.G):
#             mon = g.lm()._mon
#             self.lm_automaton.add_word(mon,(j,mon))
#         self.lm_automaton.make_automaton()
#         
#         syz = self.H
#         self.H = []
#         F = [None] + copy(self.gens)    
#         P = [F[s._ei].lrmul(s._a,s._b) for s in syz]
#         
#             
#         rows,cols = self.symbolic_preprocessing_matrix(P)            
#         set_P = set(P)
#         pivot_rows = [f for f in rows if f not in set_P]        
#         cols.sort(reverse=True)
#         pivot_rows.sort(key=lambda g: g._sig)
#                                         
#         blocks = [len(pivot_rows),len(rows)]                                          
#         A = set_up_matrix(pivot_rows + P,cols)
#         A = augment(A)
#         block_rational_echelon(A,blocks)
#         T = A.matrix_from_rows_and_columns(range(len(pivot_rows),A._nrows),range(A._ncols-A._nrows,A._ncols))
#         
#         cdef mpq_vector* w
#         nonzero_column = [False for _ in range(len(pivot_rows))]
#         for i from 0 <= i < T._nrows:
#             w = &(T._matrix[i])
#             for j from 0 <= j < w.num_nonzero-1:
#                 nonzero_column[w.positions[j]] = True
#                 
#         for j from 0 <= j < len(pivot_rows):
#             if not nonzero_column[j]: continue
#             r = pivot_rows[j]
#             a,_,b = r._pseudo_sig.aib()
#             r._module_mons = [m.lrmul(a,b) for m in r._module_mons]
#         
#         for i from 0 <= i < T._nrows:
#             d.clear()
#             pos,values = zip(*mpq_vector_to_list(&T._matrix[i]))
#             d[syz[i]] = values[-1]
#             for j from 0 <= j < len(pos)-1:
#                 r = pivot_rows[pos[j]]
#                 mpq_set(v, (<Rational>values[j]).value)
#                 for k from 0 <= k < len(r._module_coeffs):
#                     c = <Rational>r._module_coeffs[k]
#                     m = <Sig>r._module_mons[k]
#                     c2 = d[m]
#                     sig_on()
#                     mpq_mul(prod, c.value, v)
#                     mpq_add(c2.value, prod, c2.value)
#                     sig_off()
#                     d[m] = c2
#             
#             coeffs = [c for c in d.values() if c]
#             mons = [m for m in d if d[m]]
#    
#             r = LabelledPoly(NCPoly.zero(), syz[i], syz[i], mons, coeffs)
#             self.H.append(r)
#         
#         mpq_clear(prod)
#         mpq_clear(v)
#                                    
#         self.G = sigGB  
#         return self.H