 # coding: utf-8

"""
Package for computing short cofactor representations in noncommutative polynomial ideals.
See the paper "Short proofs of ideal membership" for further information
================

AUTHOR:

- Clemens Hofstadler (2023-02-01)
"""

#########################################################################################
#  Copyright (C) 2023 Clemens Hofstadler (clemens.hofstadler@mathematik.uni-kassel.de)  #
#                                                                                       #
#  Distributed under the terms of the GNU General Public License (GPL)                  #
#  either version 2, or (at your option) any later version                              #
#                                                                                       #
#  http://www.gnu.org/licenses/                                                         #
#########################################################################################

from time import time
import itertools
from collections import defaultdict
from ahocorasick import Automaton

from signature_gb import *

#########################################################################################
# Main algorithm
#########################################################################################

def find_short_proof(alpha,bound,M,weighted=False):
    """
    Compute a (weighted) l1-minimal cofactor representation.

    According to the original documentation this routine implements
    Algorithm 3 from the paper. Starting from the cofactor representation
    ``alpha`` it collects syzygies, prunes them, sets up a linear system
    and solves a linear program over the remaining module monomials.

    INPUT:

    - ``alpha`` -- an initial cofactor representation
    - ``bound`` -- a degree bound; only cofactor representations up to this
        degree bound are considered
    - ``M`` -- a labelled module generated by f_1,...,f_r
    - ``weighted`` (default:False) -- boolean; if False an l1-minimal
        representation is computed, otherwise each module monomial ``mon``
        gets the weight ``len(mon)+1`` in the objective function

    OUTPUT:

    An l1-minimal cofactor representation of the polynomial encoded in
    ``alpha`` up to the degree bound ``bound`` w.r.t. the generators of
    the labelled module ``M``.

    """
    t0 = time()

    # Signature bound: the first ring generator repeated (bound minus the
    # smallest length of a generator's ``lm()._mon``) times.
    shortest = min(len(f.lm()._mon) for f in M.gens)
    cofactor_bound = bound - shortest
    first_gen = str(M.Parent.gens()[0])
    sigma = Sig("",1,first_gen*cofactor_bound)

    print("Computing syzygy basis...")
    syzygies = find_syzygy_basis(alpha,M.H,M.labGB,sigma)
    print("Computed syzygy basis (took %.2f)" % (time()-t0))

    # The unpruned system is only built to report its size.
    unpruned_rows = module_mons_to_polies(supp(syzygies),M)
    unpruned_A,_ = set_up_linear_system(unpruned_rows,alpha._poly)
    print("Size of linear system without optimisaton " + str((unpruned_A.T).dimensions()))
    print("Nonzero elements in A " + str(len(unpruned_A.nonzero_positions())))
    print("Pruning the basis...")

    syzygies,alpha = treat_binomials(syzygies,alpha)

    # Apply both criteria until a full round removes nothing.
    while True:
        size_before = len(syzygies)
        syzygies = criterion(syzygies,alpha)
        syzygies = criterion2(syzygies,alpha)
        if len(syzygies) == size_before:
            break

    # Module monomials that may occur in the final representation.
    candidates = supp(syzygies)
    candidates.update(supp([alpha]))
    candidates = list(candidates)

    pruned_rows = module_mons_to_polies(candidates,M)
    A,b = set_up_linear_system(pruned_rows,alpha._poly)

    print("Size of linear system with optimisation " + str((A.T).dimensions()))
    print("Nonzero elements in A " + str(len(A.nonzero_positions())))

    if weighted:
        weights = [len(mon)+1 for mon in candidates]
    else:
        weights = None

    t0 = time()
    sol = setUpAndSolveLP(A.T,b,weights=weights)
    print("Linear programming took %.2f" % (time()-t0))

    return prepare_output(M,sol,alpha,candidates,bound)

#########################################################################################
# Auxiliary functions
#########################################################################################

def find_syzygy_basis(alpha,H,G,sigma):
    """
    Collect all syzygies reachable from the module monomials of ``alpha``.

    Starting with the module monomials of ``alpha``, the routine repeatedly
    looks up non-trivial and trivial syzygies containing a monomial that has
    not been processed yet, and continues with the module monomials of the
    syzygies found, until no new module monomial shows up.

    The paper reference should be double-checked: this routine was
    documented as Algorithm 3, the same number cited by ``find_short_proof``.

    INPUT:

    - ``alpha`` -- A cofactor representation
    - ``H`` -- a Gröbner basis of the syzygy module up to signature ``sigma``
    - ``G``-- a signature Gröbner basis up to signature ``sigma``
    - ``sigma`` -- a signature; the bound until which we search for certificates

    OUTPUT:

    A list containing ``alpha`` and all syzygies found.
    As a side effect the module-level caches used by the search routines
    are (re)initialised.

    """
    global already_seen_a, already_seen_b, already_seen_triv, already_seen_nontriv, multiples

    def build_automaton(polys):
        # Group the elements by the string of each of their module monomials;
        # every key stores (length of the key, list of elements containing it).
        owners = {}
        for p in polys:
            for mon in p._module_mons:
                owners.setdefault(mon.my_str(),[]).append(p)
        automaton = Automaton()
        for word,plist in owners.items():
            automaton.add_word(word,(len(word),plist))
        automaton.make_automaton()
        return automaton

    A_H = build_automaton(H)
    A_G = build_automaton(G)

    # fresh caches for this search
    already_seen_a = dict()
    already_seen_b = dict()
    already_seen_triv = set()
    already_seen_nontriv = set()
    multiples = dict()

    pending = set(alpha._module_mons)
    processed = set()
    basis = {alpha}
    counter = 0

    while pending:
        processed.update(pending)
        counter += len(pending)
        print(counter)

        # first all non-trivial syzygies, then all trivial ones
        found = set()
        for m in pending:
            found.update(find_nontrivial_syzygies(m,A_H,sigma))
        for m in pending:
            found.update(find_trivial_syzygies(m,G,A_G,sigma))

        # continue with the module monomials that have not been seen so far
        pending = {mon for h in found for mon in h._module_mons}
        pending.difference_update(processed)
        basis.update(found)

    # release the (potentially large) cofactor caches
    already_seen_a = dict()
    already_seen_b = dict()

    return list(basis)

def find_nontrivial_syzygies(m,A_H,sigma):
    """
    Finds all non-trivial syzygies with signature smaller than ``sigma`` that contain a module monomial ``m``.

    INPUT:

    - ``m``-- a module monomial
    - ``A_H`` -- pyahocorasick automaton encoding all elements in a syzygy basis (up to signature ``sigma``);
      every key stores a pair (length of the key, list of elements containing the key)
    - ``sigma`` -- a signature

    OUTPUT:

    All multiples of elements encoded in ``A_H`` with signature less than ``sigma`` that contain ``m``.
    Multiples that were already returned by an earlier call (recorded in the
    module-level set ``already_seen_nontriv``) are not returned again.

    """
    global already_seen_nontriv

    # An automaton without keys is not finalised by make_automaton() and
    # pyahocorasick then refuses to iterate it; there is nothing to find anyway.
    if len(A_H) == 0:
        return []

    out = []
    m_str = m.my_str()

    # k is the index of the last character of a match, i the length of the matched key
    for k,(i,h_list) in A_H.iter(m_str):
        a = m_str[:k-i+1]
        b = m_str[k+1:]
        for h in h_list:
            if not h._sig.lrmul(a,b) < sigma:
                continue
            key = (a,h,b)
            if key in already_seen_nontriv:
                continue
            already_seen_nontriv.add(key)
            out.append(h.module_lrmul(a,b))
    return out
    
    
def find_trivial_syzygies(m,G,A_G,sigma):
    """
    Finds all trivial syzygies with signature bounded by ``sigma`` that contain a module monomial ``m``.

    INPUT:

    - ``m``-- a module monomial
    - ``G`` -- a signature Gröbner basis up to signature ``sigma``
    - ``A_G`` -- pyahocorasick automaton built from the module monomials of the elements in ``G``;
      every key stores a pair (length of the key, list of elements containing the key)
    - ``sigma`` -- a signature

    OUTPUT:

    All trivial syzygies of elements in ``G`` that contain ``m`` and are accepted by
    ``form_triv_syz`` for the bound ``sigma``. Combinations that were already
    handled in an earlier call (recorded in the module-level set
    ``already_seen_triv``) are skipped.

    """
    global already_seen_a, already_seen_b, already_seen_triv

    # An automaton without keys is not finalised by make_automaton() and
    # pyahocorasick then refuses to iterate it; there is nothing to find anyway.
    if len(A_G) == 0:
        return []

    out = []
    m_str = m.my_str()

    def collect(l,g1,n,g2,r):
        # form each combination at most once over the whole search
        key = (l,g1,n,g2,r)
        if key in already_seen_triv:
            return
        already_seen_triv.add(key)
        triv_syz = form_triv_syz(l,g1,n,g2,r,sigma)
        if triv_syz: out.append(triv_syz)

    # k is the index of the last character of a match, i the length of the matched key
    for k,(i,g_list) in A_G.iter(m_str):
        a = m_str[:k-i+1]
        b = m_str[k+1:]

        for g in g_list:
            if sigma < g._sig.lrmul(a,b): continue

            # second element g2 sitting inside the left cofactor a (cached per a)
            if a not in already_seen_a:
                already_seen_a[a] = [(l,n,g2) for g2 in G for l,n in all_suitable_multiples(a,g2,sigma)]
            for l,n,g2 in already_seen_a[a]:
                collect(l,g2,n,g,b)

            # second element g2 sitting inside the right cofactor b (cached per b)
            if b not in already_seen_b:
                already_seen_b[b] = [(g2,n,r) for g2 in G for n,r in all_suitable_multiples(b,g2,sigma)]
            for g2,n,r in already_seen_b[b]:
                collect(a,g,n,g2,r)

    return out
                                     
def form_triv_syz(l,g1,m,g2,r,sigma):
    """
    Forms (the support of) the trivial syzygy of ``g1`` and ``g2`` with cofactors ``l``, ``m``, ``r``.

    With p1, p2 denoting the polynomials of ``g1``, ``g2``, the module
    monomials collected are those of l*g1*k for every monomial k of the
    polynomial m*p2*r, together with those of k*g2*r for every monomial k
    of the polynomial l*p1*m.

    INPUT:

    - ``l``-- a monomial; the left cofactor
    - ``g1`` -- a labelled polynomial; the first part of the trivial syzygy
    - ``m`` -- a monomial; the cofactor in the middle
    - ``g2`` -- a labelled polynomial; the second part of the trivial syzygy
    - ``r``-- a monomial; the right cofactor
    - ``sigma`` -- a signature

    OUTPUT:

    None if ``sigma`` is smaller than one of the two signatures computed
    from the last module monomials of ``g1`` and ``g2``. Otherwise a
    labelled polynomial with zero polynomial part whose module monomials
    are the ones described above; all its coefficients are set to zero,
    i.e. only the support is recorded.

    """
    global multiples

    def cached_support(g,left,right):
        # module monomials of left*g*right, memoised in ``multiples``
        key = (g,left,right)
        if key not in multiples:
            multiples[key] = g.module_lrmul(left,right)._module_mons
        return multiples[key]

    right_poly = g2._poly.lrmul(m,r)
    left_poly = g1._poly.lrmul(l,m)

    sig_first = g1._module_mons[-1].lrmul(l,right_poly._mons[-1]._mon)
    sig_second = g2._module_mons[-1].lrmul(left_poly._mons[-1]._mon,r)

    if sigma < sig_first or sigma < sig_second:
        return None

    support = set()
    for term in right_poly._mons:
        support.update(cached_support(g1,l,term._mon))
    for term in left_poly._mons:
        support.update(cached_support(g2,term._mon,r))
    support = list(support)

    # the smaller of the two signatures is stored as label
    if sig_first < sig_second:
        label = sig_first
    else:
        label = sig_second
    zero = NCPoly.zero()
    zero_coeffs = [QQ.zero() for _ in range(len(support))]

    return LabelledPoly(zero, label, label, support, zero_coeffs)
                                                           
def all_suitable_multiples(t,g,sigma):
    """
    All cofactors (a,b) such that lm(agb) = t and sig(agb) < sigma.

    Every (possibly overlapping) occurrence of the leading monomial of
    ``g`` inside ``t`` is tried.

    INPUT:

    - ``t``-- a monomial
    - ``g`` -- a labelled polynomial
    - ``sigma`` -- a signature

    OUTPUT:

    A list of all pairs (a,b), ordered by the position of the occurrence,
    such that lm(agb) = t and sig(agb) < sigma.

    """
    lead = g._poly._lm._mon
    width = len(lead)

    def occurrences():
        # start indices of all occurrences of ``lead`` in ``t``
        start = t.find(lead)
        while start >= 0:
            yield start
            start = t.find(lead,start+1)

    pairs = []
    for start in occurrences():
        left = t[:start]
        right = t[start+width:]
        if g._sig.lrmul(left,right) < sigma:
            pairs.append((left,right))
    return pairs


def setUpLP(A,b,weights):
    """
    Build the linear program for (weighted) l1-minimal solutions x of A x = b.

    The solution is split as x = u - v with nonnegative u and v, and the
    objective is the sum of all u_i + v_i, each term multiplied by its
    weight if ``weights`` is given.

    INPUT:

    - ``A`` -- coefficient matrix
    - ``b`` -- right-hand side
    - ``weights`` -- weights for the coordinates of a solution in the
      objective function; if empty or None, all coordinates count equally

    OUTPUT:

    The linear program (a minimisation problem using the CPLEX solver)
    and the variables u,v such that x = u-v.

    """
    program = MixedIntegerLinearProgram(maximization=False, solver='CPLEX')
    u = program.new_variable(nonnegative=True)
    v = program.new_variable(nonnegative=True)

    pos_parts,neg_parts = add_constraints(program,A,b,u,v)

    # |x_i| is represented by u_i + v_i
    abs_values = [x+y for x,y in zip(pos_parts,neg_parts)]
    if weights:
        abs_values = [w*t for w,t in zip(weights,abs_values)]

    program.set_objective(program.sum(abs_values))

    return program,u,v


def preprocess_A(A,b):
    """
    Detects redundant columns in a coefficient matrix ``A``.

    A row with a single nonzero entry and right-hand side 0 forces the
    variable of that entry's column to vanish, so the whole column can be
    dropped. Dropping columns may create new such rows, hence the process
    is repeated until nothing changes.

    INPUT:

    - ``A`` -- coefficient matrix
    - ``b`` -- right-hand side

    OUTPUT:

    Nonzero positions in ``A`` which are relevant for computing (weighted) l1-minimal
    solutions of A x = b
    """
    # assumes the positions are grouped by row index, so that a row with a
    # single entry can be recognised by looking at its neighbours -- TODO confirm
    pos = A.nonzero_positions()
    keep_going = True

    while keep_going:
        last = len(pos) - 1
        to_delete_cols = set()
        for k,(row_idx,col_idx) in enumerate(pos):
            # the first and the last position only have one neighbour to compare with
            if k > 0 and pos[k-1][0] == row_idx: continue
            if k < last and pos[k+1][0] == row_idx: continue
            if b[row_idx] == 0:
                to_delete_cols.add(col_idx)
        keep_going = bool(to_delete_cols)
        pos = [(i,j) for i,j in pos if j not in to_delete_cols]

    return pos

def add_constraints(LP,A,b,u,v):
    """
    Adds the constraints A (u - v) = b to the linear program ``LP``.

    INPUT:

    - ``LP`` -- linear program
    - ``A`` -- coefficient matrix
    - ``b`` -- right-hand side
    - ``u`` -- variables of the linear program
    - ``v`` -- variables of the linear program

    OUTPUT:

    The lists [u[0],...,u[nc-1]] and [v[0],...,v[nc-1]], where nc is the
    number of columns of ``A``.
    As a side effect, one equality constraint per row of ``A`` is added
    directly to the backend of ``LP``.
    """
    n_rows,n_cols = A.dimensions()

    # All u[k] are accessed before any v[k]. The constraints below address
    # u[j] by index j and v[j] by index j + n_cols, which presumably relies
    # on the backend numbering variables in this order -- keep it.
    pos_vars = [u[k] for k in range(n_cols)]
    neg_vars = [v[k] for k in range(n_cols)]

    # row i collects (variable index, coefficient) pairs
    rows = [[] for _ in range(n_rows)]
    for (i,j),coeff in A.dict().items():
        rows[i].extend([(j,coeff),(j+n_cols,-coeff)])

    backend = LP.get_backend()

    for i in range(n_rows):
        if i % 5000 == 0: print(i)
        rhs = b[i]
        # equal lower and upper bound encode an equality
        backend.add_linear_constraint(rows[i],rhs,rhs)

    return pos_vars,neg_vars


def setUpAndSolveLP(A,b,weights=None):
    """
    Set up and solve the linear program for l1-minimal solutions x of A x = b.

    If weights are given, the coefficients of the solution vector are
    weighted by these weights. Timings of both phases are printed.

    INPUT:

    - ``A`` -- coefficient matrix
    - ``b`` -- right-hand side
    - ``weights`` (default: None) -- weights for the coordinates of a
      solution in the objective function

    OUTPUT:

    A list containing the coefficients u_i - v_i (converted to QQ) of the
    solution returned by the solver.

    """
    tic = time()
    print("Setting up the linear progam")
    LP,u,v = setUpLP(A,b,weights)
    print("Setting up LP took %.2f" % (time()-tic))

    print("Solving linear system...")
    tic = time()
    LP.solve()
    print("Solving took %.2f" % (time()-tic))

    plus = LP.get_values(u).values()
    minus = LP.get_values(v).values()

    # recombine x = u - v
    solution = []
    for p,q in zip(plus,minus):
        solution.append(QQ(p - q))
    return solution

def set_up_linear_system(rows,f):
    """
    Encode the polynomials ``rows`` and the polynomial ``f`` as a linear system.

    The columns correspond to the monomials occurring in ``rows``, sorted
    in decreasing order; row i of the matrix holds the coefficients of
    ``rows[i]``. Any solution x of A^T x = b encodes a representation of
    ``f`` in terms of the elements in ``rows``.

    INPUT:

    - ``rows`` -- a list of polynomials
    - ``f`` -- a polynomial; all its monomials have to occur in ``rows``

    OUTPUT:

    A sparse matrix ``A`` and a sparse vector ``b`` over QQ such that any
    solution of A^T x = b encodes a cofactor representation of ``f`` in
    terms of the elements in ``rows``.
    """
    monomials = {m for r in rows for m in r._mons}
    columns = sorted(monomials,reverse=True)

    n_rows = len(rows)
    n_cols = len(columns)

    # column index of every monomial
    index_of = dict(zip(columns,range(n_cols)))

    A = matrix(QQ,n_rows,n_cols,sparse=True)
    b = vector(QQ,n_cols,sparse=True)

    for i,r in enumerate(rows):
        for coeff,mon in zip(r._coeffs,r._mons):
            A[i,index_of[mon]] = coeff

    for coeff,mon in zip(f._coeffs,f._mons):
        b[index_of[mon]] = coeff

    return A,b

def module_mons_to_polies(V,M):
    """
    Transforms module monomials a e_i b into the polynomials a f_i b.

    INPUT:

    - ``V`` -- an iterable of module monomials
    - ``M`` -- a labelled module generated by f_1,...,f_r

    OUTPUT:

    A list of polynomials, in the iteration order of ``V``.

    """
    polies = []
    for mon in V:
        # ``_ei`` is 1-based, ``M.gens`` is 0-based
        generator = M.gens[mon._ei-1]
        polies.append(generator._poly.lrmul(mon._a,mon._b))
    return polies
            

def prepare_output(M,sol,alpha,V,bound):
    """
    Turns a coefficient vector into a cofactor representation.

    INPUT:

    - ``M`` -- a labelled module (not used here)
    - ``sol`` -- list of coefficients; ``sol[i]`` belongs to ``V[i]``
    - ``alpha`` -- cofactor representation providing the polynomial
    - ``V`` -- list of module monomials
    - ``bound`` -- the degree bound, only used for the printed message

    OUTPUT:

    A labelled polynomial built from the polynomial of ``alpha`` and the
    module monomials of ``V`` with nonzero coefficient in ``sol``.
    """
    # keep only the module monomials with nonzero coefficient
    mons = []
    coeffs = []
    for i,x in enumerate(sol):
        if x:
            mons.append(V[i])
            coeffs.append(x)

    proof = LabelledPoly(alpha._poly,max(mons),None,mons,coeffs)
    print("Found cofactor representation (bound set to %d):" % bound)
    print("Weight (= #terms): %d " % len(proof))

    return proof

def supp(F):
    """
    Return the set of all module monomials occurring in the module representations
    of the labelled polynomials in ``F``.
    """
    return {m for f in F for m in f._module_mons}
      
def treat_binomials(V,alpha):
    """
    Remove binomial syzygies.

    Every element of ``V`` with exactly two module monomials is used as a
    replacement rule "larger monomial -> smaller monomial". Chains of rules
    are followed to their end, the rules are applied to the module
    monomials of all remaining elements and of ``alpha``, and the
    binomials themselves are dropped.

    INPUT:

    - ``V`` -- a list of labelled polynomials
    - ``alpha`` -- a cofactor representation

    OUTPUT:

    The list of non-binomial elements of ``V`` and ``alpha``. The
    ``_module_mons`` of these objects are rewritten in place; coefficients
    are not touched.
    """
    binomials = [f for f in V if len(f._module_mons) == 2]

    replacement = {}
    for f in binomials:
        larger = max(f._module_mons)
        smaller = min(f._module_mons)
        # a rule m -> m carries no information and would make the chain
        # resolution below run forever
        if larger != smaller:
            replacement[larger] = smaller

    def resolve(mon):
        # follow the chain of rules to its end; membership is tested instead
        # of truthiness so that falsy monomials are valid targets as well
        while mon in replacement:
            mon = replacement[mon]
        return mon

    final = {source: resolve(target) for source,target in replacement.items()}

    binomials_set = set(binomials)
    remaining = []
    for f in V:
        if f in binomials_set: continue
        f._module_mons = [final.get(m,m) for m in f._module_mons]
        remaining.append(f)

    alpha._module_mons = [final.get(m,m) for m in alpha._module_mons]

    return remaining,alpha
    
      
def criterion(V,alpha):
    """
    Criterion to remove redundant syzygies from a set ``V``.

    An element is removed if at most half of its module monomials occur
    anywhere else, where occurrences are counted over all elements of
    ``V`` plus the module monomials of ``alpha``. Removing an element
    lowers the counts, so the test is repeated until a full pass removes
    nothing.

    ``alpha`` gets no special treatment: if it is an element of ``V`` it is
    judged like every other element.

    INPUT:

    - ``V`` -- a list of labelled polynomials; it is pruned in place
    - ``alpha`` -- a cofactor representation

    OUTPUT:

    The list ``V`` (the same object) without the removed elements.
    """
    count = defaultdict(int)

    for f in V:
        for m in f._module_mons: count[m] += 1
    for m in alpha._module_mons: count[m] += 1

    removed_any = True
    while removed_any:
        survivors = []
        for f in V:
            shared = sum(1 for m in f._module_mons if count[m] > 1)
            if 2*shared <= len(f._module_mons):
                # the counts are lowered right away, so that elements tested
                # later in the same pass already see the removal
                for m in f._module_mons:
                    count[m] -= 1
                    assert count[m] >= 0
            else:
                survivors.append(f)
        removed_any = len(survivors) != len(V)
        # one in-place replacement instead of popping the elements one by one
        V[:] = survivors

    return V
    
    
def criterion2(V,alpha):
    """
    Criterion to remove redundant syzygies from a set ``V``.

    Elements of ``V`` in which more than 30% of the module monomials occur
    nowhere else (counted over ``V`` and ``alpha``) are paired up whenever
    two of them share a module monomial; these pairs are handed to
    ``extended_criterion`` as removal candidates.

    INPUT:

    - ``V`` -- a list of labelled polynomials
    - ``alpha`` -- a cofactor representation

    OUTPUT:

    The list returned by ``extended_criterion``.
    """
    count = defaultdict(int)
    for f in V:
        for m in f._module_mons: count[m] += 1
    for m in alpha._module_mons: count[m] += 1

    almosts = []
    for f in V:
        n_unique = sum(1 for m in f._module_mons if count[m] == 1)
        if n_unique / len(f._module_mons) > 0.3:
            almosts.append(f)

    # supports as sets, built once per element
    supports = [set(f._module_mons) for f in almosts]

    overlaps = []
    for i,f in enumerate(almosts):
        for j in range(i+1,len(almosts)):
            if not supports[i].isdisjoint(supports[j]):
                overlaps.append((f,almosts[j]))

    return extended_criterion(V,overlaps,alpha)
    
      
def extended_criterion(V,candidates,alpha):
    """
    Criterion to remove redundant syzygies from a set ``V``.

    A candidate group is removed as a whole if, for each of its members,
    the number of module monomials occurring only once overall (counted
    over ``V`` and ``alpha``) is at least the number of its module
    monomials that also occur outside of the group.

    INPUT:

    - ``V`` -- a list of labelled polynomials
    - ``candidates`` -- groups (e.g. pairs) of elements considered for removal
    - ``alpha`` -- a cofactor representation

    OUTPUT:

    A new list with the elements of ``V`` not belonging to a removed group.
    """
    total = defaultdict(int)
    for f in V:
        for m in f._module_mons: total[m] += 1
    for m in alpha._module_mons: total[m] += 1

    doomed = set()

    for group in candidates:
        # occurrences inside the group only
        local = defaultdict(int)
        for f in group:
            for m in f._module_mons: local[m] += 1

        removable = True
        for f in group:
            n_unique = sum(1 for m in f._module_mons if total[m] == 1)
            n_outside = sum(1 for m in f._module_mons if local[m] < total[m])
            if n_unique < n_outside:
                removable = False
                break

        if removable:
            doomed.update(group)

    return [f for f in V if f not in doomed]