Post
Topic
Board Development & Technical Discussion
Re: Need help understanding this modular inverse implementation
by
j2002ba2
on 11/06/2024, 09:12:09 UTC
I mean lots of words explaining it.

Here is the code I wrote back in 2019. It is slower than the 5x52 version (this one is 8x32), but it was faster than computing x^(p-2).
Code:
/*
 * invModP_ - replace `value` with its multiplicative inverse modulo _P.
 *
 * Binary (shift-and-subtract) inversion:
 *   1. u = a, v = p.
 *   2. x1 = 1, x2 = 0.
 *   3. While (u != 1 and v != 1):
 *        3.1 while u is even: u = u/2; x1 = x1/2 if x1 is even, else (x1 + p)/2
 *        3.2 while v is even: v = v/2; x2 = x2/2 if x2 is even, else (x2 + p)/2
 *        3.3 if u >= v: u = u - v, x1 = x1 - x2; else v = v - u, x2 = x2 - x1
 *   4. Return x1 if u == 1, otherwise x2.
 * Throughout, x1*a == u (mod p) and x2*a == v (mod p).
 *
 * Word order: element [7] is the least-significant 32-bit word (the constant 1
 * is spelled { 0,0,0,0, 0,0,0,1 } and parity is tested on element [7]).
 *
 * value  in:  a, expected to lie in [1, p-1]
 *        out: a^-1 mod p
 *
 * If a is 0 (or u collapses to 0 because a was a multiple of p) no inverse
 * exists. Without a guard, step 3.1 would halve a zero u forever, so the
 * function stores 0 in `value` and returns instead.
 *
 * NOTE(review): relies on helpers defined elsewhere; presumably add() returns
 * the carry out of the top word, sub() returns nonzero when a borrow occurred
 * (first operand < second), shift_right_c() shifts the given carry into the
 * top bit, and subModP() keeps its result reduced modulo p - confirm against
 * their definitions.
 */
__device__ static void invModP_(unsigned int value[8]) {
  // 1. u = a, v = p.
  unsigned int u[8], v[8];
  copyBigInt(value, u);
  copyBigInt(_P, v);

  // 2. x1 = 1, x2 = 0.
  unsigned int x1[8] = { 0,0,0,0, 0,0,0,1 };
  unsigned int x2[8] = { 0,0,0,0, 0,0,0,0 };

  unsigned int uIsOne = 0;

  // 3. While (u != 1 and v != 1) do
  for (;;) {
    const unsigned int uUpper = u[0] | u[1] | u[2] | u[3] | u[4] | u[5] | u[6];

    // u == 0 means the input was congruent to 0 mod p: there is no inverse
    // and the halving loop below would never terminate. Report 0.
    if ((uUpper | u[7]) == 0) {
      copyBigInt(u, value);
      return;
    }

    const unsigned int vUpper = v[0] | v[1] | v[2] | v[3] | v[4] | v[5] | v[6];
    uIsOne = (u[7] == 1) && (uUpper == 0);
    if (uIsOne || ((v[7] == 1) && (vUpper == 0)))
      break;

    // 3.1 While u is even: u = u/2 and halve x1 modulo p.
    while ((u[7] & 1) == 0) {
      shift_right(u);
      unsigned int carry = 0;
      // An odd x1 is made even by adding p; the carry out of that addition
      // becomes the top bit after the shift.
      if ((x1[7] & 1) != 0) {
        carry = add(x1, _P, x1);
      }
      shift_right_c(x1, carry);
    }

    // 3.2 While v is even: v = v/2 and halve x2 modulo p.
    while ((v[7] & 1) == 0) {
      shift_right(v);
      unsigned int carry = 0;
      if ((x2[7] & 1) != 0) {
        carry = add(x2, _P, x2);
      }
      shift_right_c(x2, carry);
    }

    // 3.3 Subtract the smaller of u, v from the larger. The difference u - v
    // is computed into a scratch buffer so u survives when the subtraction
    // borrows (u < v).
    unsigned int t[8];
    if (sub(u, v, t)) {
      // u < v: v = v - u, x2 = x2 - x1.
      sub(v, u, v);
      subModP(x2, x1, x2);
    } else {
      // u >= v: u = u - v, x1 = x1 - x2.
      copyBigInt(t, u);
      subModP(x1, x2, x1);
    }
  }

  // 4. If u = 1 then return x1; else (v = 1) return x2.
  if (uIsOne) {
    copyBigInt(x1, value);
  } else {
    copyBigInt(x2, value);
  }
}
/*
INPUT : Prime p and a in [1, p−1].
OUTPUT : a^−1 mod p.
1. u = a, v = p.
2. x1 = 1, x2 = 0.
3. While (u != 1 and v != 1) do
   3.1 While u is even do
       u = u/2.
       If x1 is even then x1 = x1/2; else x1 = (x1 + p)/2.
   3.2 While v is even do
       v = v/2.
       If x2 is even then x2 = x2/2; else x2 = (x2 + p)/2.
   3.3 If u >= v then: u = u − v, x1 = x1 − x2 ;
       Else: v = v − u, x2 = x2 − x1.
4. If u = 1 then return(x1 mod p); else return(x2 mod p).
*/