from __future__ import absolute_import


from sage.all import *

from time import time
from bisect import bisect_right
        
############################################################################
############################################################################
# Linear algebra (for fields)
############################################################################
############################################################################

###################################################
# auxiliary
###################################################

def set_up_matrix(rows, columns):
    """
    Build the sparse coefficient matrix of a list of polynomials.

    Row ``r`` of the result holds the coefficients of ``rows[r]``. Each
    coefficient is placed in the column whose index is the position of its
    monomial in ``columns``.

    INPUT:

    - ``rows`` -- non-empty list of polynomials; the base ring of the parent
      of ``rows[0]`` is used as the matrix base ring
    - ``columns`` -- list of monomials indexing the columns; every monomial
      occurring in ``rows`` must be present, otherwise ``KeyError`` is raised

    OUTPUT: a sparse ``len(rows) x len(columns)`` matrix.

    NOTE(review): assumes ``f.coefficients()`` and ``f.monomials()`` list the
    terms in the same order (as for Sage multivariate polynomials) -- confirm
    for any other polynomial type passed in.
    """
    base = rows[0].parent().base_ring()
    height, width = len(rows), len(columns)
    M = matrix(base, height, width, sparse=True)

    # monomial -> column index lookup table
    column_of = dict((mon, idx) for idx, mon in enumerate(columns))

    for r, poly in enumerate(rows):
        terms = zip(poly.coefficients(), poly.monomials())
        for coeff, mon in terms:
            M[r, column_of[mon]] = coeff

    return M

###################################################
# signature safe row reduction
###################################################
def _row_leading_columns(A):
    """Map each nonzero row index of ``A`` to the column of its leftmost nonzero entry."""
    lead = {}
    for i, j in A._nonzero_positions_by_row(copy=False):
        if i not in lead or j < lead[i]:
            lead[i] = j
    return lead


def _next_block_start(blocks, row, nrows):
    """Return the first block start greater than ``row``, or ``nrows`` if ``row`` is in the last block."""
    k = bisect_right(blocks, row)
    return blocks[k] if k < len(blocks) else nrows


def signature_safe_reduction(A, blocks):
    """
    Row-reduce the matrix ``A`` in place, never combining rows of the same block.

    For each column ``c`` (left to right), the first row whose leading entry
    lies in column ``c`` becomes the pivot. It is scaled so that entry is 1,
    and multiples of it are subtracted from every row starting at the next
    block boundary, which clears column ``c`` there. Rows above the pivot and
    rows in the pivot's own block are never modified by it. After each pivot
    step, every nonzero row is scaled so its leading entry is 1.

    The result is therefore row-echelon-like, but not necessarily in fully
    reduced row echelon form.

    INPUT:

    - ``A`` -- sparse matrix over a field (entries are divided); modified
      in place
    - ``blocks`` -- row indices at which blocks start, assumed sorted
      ascending (required by ``bisect``). A trailing sentinel equal to the
      number of rows is optional; without it, the last block is taken to
      extend to the final row.

    OUTPUT: ``A`` itself.
    """
    nr, nc = A.dimensions()

    for c in range(nc):
        # pivot = smallest row index whose leading entry is in column c
        lead = _row_leading_columns(A)
        pivot = next((r for r in range(nr) if lead.get(r) == c), None)
        if pivot is None:
            continue

        # normalize the pivot row so that A[pivot, c] == 1
        A[pivot] *= 1/A[pivot, c]

        # eliminate column c only from rows in later blocks. start_col=c is
        # safe because the pivot row is zero before column c.
        for i in range(_next_block_start(blocks, pivot, nr), nr):
            s = A[i, c]
            if s:
                A.add_multiple_of_row(i, pivot, -s, start_col=c)

        # make every leading coefficient 1 (column 0 included)
        for i, j in _row_leading_columns(A).items():
            if A[i, j] != 1:
                A[i] /= A[i, j]

    return A