Zheng Junyi
4/30/2022
本周,Computer Security 这门课的作业是用代码实现 RSA 算法,要求加密 + 解密都要实现,语言不限。
这种东西网上一搜一大把,但因为 RSA 太重要了,老师额外要求要 “Defend your code!”(相当于当着他面,一对一做一次 Code Review)。没办法,老老实实的把相关课件重新学了一遍,然后动笔实现。害怕到时候讲不清楚,所以特地写一篇博客理一遍思路。
虽说语言不限,但能方便快捷的操作这么大的数 (老师要求使用至少 1024 bit 的质数!) 的语言还真不多,对比了一下,最终选择了 Python。
快速模幂运算,是与 RSA 相关的算法中最基础也是最核心的一个算法了。毕竟 RSA 加密和解密的本质就是做一个模幂运算;而求我们生成一个大质数的过程中也需要用到这个函数 (费马定理,验证这个数是否为质数)。
这个算法主要是为了减少计算机的计算量,毕竟要硬算 这样的大数的模幂可不容易。
我们都知道,关于取模,有这么一个公式:
再结合 ,我们可以很容易将我们的模幂运算分解:
但是,如何分解 最高效又成了问题。这里我们使用平方求幂原理。
首先将 表示成二进制,即:
可以表现为:
因此, 为:
具体实现的代码如下 (递归):
虽说递归的实现最为直观,但我在后面的实际使用过程中遇到了性能问题,不得已重新用 while 重构了一遍:
接着,我们需要一个函数用来生成一个固定位数的大数,这个数可能是质数也可能是合数。
我们直接用 Python 的 random 模块来生成一个 w 位的二进制字符串 (仅有 0 和 1 组成的字符串),然后将字符串转化为数字 (十进制):
需要注意的是,二进制字符串的首位必须位 '1',因为如果首位为 0,就不能保证转换后的数字是 w 位了。同时,末尾也必须为 '1',因为我们的这个函数的目的是生成 (伪) 质数,而所有的偶数都是合数。
接下来我们需要一个函数来对生成的 (伪) 质数进行验证,这里我们需要用到米勒-拉宾素性检验 (Miller-Rabin prime test)。
对于素数 和任意整数 ,有:
米勒-拉宾素性检验在费马定理的基础上增加了推论:对于任何大于 的素数 ,不存在 的「非平凡平方根」。
我们发现 和 总是得到 ,我们称 和 是 的「平凡平方根」。
让我们假设 是 的平方根之一 (),于是有:
因为 是质数, 和 均不可能为 。
那么唯一的可能性只能是:
所以:
我们获得推论:如果 是素数,则 的解为 或 。
但是对于所有数都使用米勒-拉宾素性检验有点不划算,因为这种算法的开销其实还是挺大的。
我们可以先将 至 以内的所有素数列出来,用这些素数对需要测试的数取模,其结果如果为 ,则必并是合数 (素数的倍数必定是合数);然后再对非零的情况使用米勒-拉宾素性检验:
需要注意的是,米勒-拉宾素性检验是个证明素数的只是一个必要但不充分条件,有极小的概率误判;不过当我们取多个 a 都能通过测试的话,我们可以将误判的概念忽略不计。
因此用一个函数 prime_test 将 prime_miller_rabin 包装一下,生成多个随机数 a,测试其是否都符合条件。
接着用 prime_test 来测试 fake_prime 生成的数是否是质数,如果不是,则重新生成,直到获得一个质数为止。
终于可以正式开始实现 RSA 算法了!
我们先来看一下公式:
是需要加密的信息。 是我们的公钥。其中 就是 与 的乘积,即:
而 是一个 小于 并且与 互质的正整数。
就是所谓的欧拉函数,公式如下:
的大小直接影响了加密的速度,所以一般不会很大,也不会很小,并且里面 (二进制) 的 1 要尽可能的少。这里我们选择 0b1000000000001 (即 4097) 作为默认的 。
这个 虽然不直接参与 RSA 加密,但其必须 互质;我们通过欧几里得算法求 与 的最大公约数,进而判断它们是否互质;如果不互质,就换一个 (e -= 1):
套用公式,实现对 的加密:
除了加密,我们还需要实现解密。
同样的,先来看一看公式:
这里的 是上一步得到的密文。 是我们的私钥。
已经解释过了,是 与 的乘积 ()。
这个 是 关于 的逆模元,即:
所以:
这里我们使用拓展欧几里得算法:
令 ,,,,,,
当 时:
当 时:
写成代码就是:
然后套用公式,就可以对 进行解密了:
到现在为止,虽然 RSA 算法已经可以用了,但仍然很慢;特别是解密,当质数的位数非常大时就非常容易卡死。 因此,我们需要使用中国剩余算法对解密进行优化。
中国剩余算法来自于《孙子算经》的第 26 题:
今有物不知其数,三三数之剩二,五五数之剩三,七七数之剩二,问物几何? 答曰:二十三。 术曰:三三数之剩二,置一百四十;五五数之剩三,置六十三,七七数之剩二,置三十,并之。得二百三十三,以二百一十减之,即得。凡三三数之剩一,则置七十;五> 五数之剩一,则置二十一;七七数之剩一,则置十五;一百六以上以一百五减之即得。
其核心思想就是辗转相除法。
对于给定的数 ,令 与其互质;令 满足 。此时,有且只有一个数 同时满足:
和
其在 RSA 解密中的应用如下:
计算 和 :
计算 和 :
使用 将其改写为同余的形式:
使用中国剩余定理计算 :
写成代码就是:
然后使用 chinese_remainder 代替 powmod 进行解密:
完整的代码如下:
def powmod(p, q, m):
if not q: return 1
if q & 1: return (p * pow_mod(p, q - 1, m)) % m
return pow_mod((p ** 2) % n, q >> 1, m)
def powmod(p, q, m):
r = 1
while q:
if q & 1:
r = (r * p) % m
q >>= 1
p = (p * p) % m
return r
import random
def fake_prime(w):
list = []
list.append('1')
for i in range(w - 2):
list.append(random.choice(['0', '1']))
list.append('1')
return int(''.join(list), 2)
def prime_miller_rabin(a, p):
prime_numbers = (2,3,5,7,11,13,17,19,23,29,31,37,41,
43,47,53,59,61,67,71,73,79,83,89,97)
for y in prime_numbers:
if p % y == 0:
return False
def fermat():
return powmod(a, p - 1, p) == 1
def miller_rabin():
d = p - 1
q = 0
while not(d & 1):
q = q+1
d >>= 1
m = d
for i in range(q):
u = m * (2**i)
tmp = powmod(a, u, p)
if tmp == 1 or tmp == p - 1:
return True
return False
return fermat() and miller_rabin()
import random
def prime_test(n, k):
if k <= 0: return True
a = random.randint(2, n - 1)
return prime_miller_rabin(a, n) and prime_test(n, k - 1)
def get_prime(bit):
prime_number = fake_prime(bit)
if prime_test(prime_number, 5):
return prime_number
return get_prime(bit)
def get_r_e(p, q, e = 4097):
def gcd(a, b):
if not b: return a
return gcd(b, a % b)
r = (p - 1) * (q - 1)
if gcd(e, r) == 1:
return (r, e)
return gen_orla(p, q, e - 1)
if __name__ == "__main__":
p = get_prime(1024)
q = get_prime(1024)
n = p * q
r, e = get_r_e(p, q)
M = int("7355608")
C = powmod(M, e, n)
def mod_inverse(e, r):
x0, x1, x2 = r, 0, 1
while r != 0:
x1, x2 = x2 - ((e // r) * x1), x1
e, r = r, e % r
if x2 < 0: x2 += x0
return x2
if __name__ == "__main__":
# p, q, C, r, n 来自于上一步
d = mod_inverse(e, r)
D = powmod(C, d, n)
def chinese_remainder(c, d, p, q):
dp = powmod(d, 1, p - 1)
dq = powmod(d, 1, q - 1)
m1 = powmod(c, dp, p)
m2 = powmod(c, dq, q)
qinv = mod_inverse(q, p)
h = (qinv * (m1 - m2)) % p
m = m2 + h * q
return m
if __name__ == "__main__":
# p, q, C, r, n 来自于上一步
d = mod_inverse(e, r)
D = chinese_remainder(C, d, p, q)
#!/usr/bin/env python
import random
def powmod(p, q, m):
"""
p^q mod n
"""
r = 1
while q:
if q & 1:
r = (r * p) % m
q >>= 1
p = (p * p) % m
return r
def fake_prime(w):
list = []
list.append('1')
for i in range(w - 2):
list.append(random.choice(['0', '1']))
list.append('1')
return int(''.join(list), 2)
def prime_miller_rabin(a, p):
prime_numbers = (2,3,5,7,11,13,17,19,23,29,31,37,41,
43,47,53,59,61,67,71,73,79,83,89,97)
for y in prime_numbers:
if p % y == 0:
return False
def fermat():
"""
a^(p-1) = 1 (mod p)
"""
return powmod(a, p - 1, p) == 1
def miller_rabin():
"""
a^2 = 1 (mod p)
a = 1 or a = p - 1
"""
d = p - 1
q = 0
while not (d & 1):
q = q+1
d >>= 1
m = d
for i in range(q):
u = m * (2**i)
tmp = powmod(a, u, p)
if tmp == 1 or tmp == p - 1:
return True
return False
return fermat() and miller_rabin()
def prime_test(n, k):
if k <= 0: return True
a = random.randint(2, n - 1)
return prime_miller_rabin(a, n) and prime_test(n, k - 1)
def get_prime(bit):
prime_number = fake_prime(bit)
if prime_test(prime_number, 5):
return prime_number
return get_prime(bit)
def get_r_e(p, q, e = 4097):
def gcd(a, b):
"""
Euclidean algorithm
"""
if not b: return a
return gcd(b, a % b)
def euler(p, q):
"""
Euler's totient function
"""
return (p - 1) * (q - 1)
r = euler(p, q)
if gcd(e, r) == 1:
return (r, e)
return get_r_e(p, q, e - 1)
def mod_inverse(e, r):
"""
Modular inverse operation\n
y1 = e^-1 mod r
"""
x0, x1, x2 = r, 0, 1
while r != 0:
x1, x2 = x2 - ((e // r) * x1), x1
e, r = r, e % r
if x2 < 0: x2 += x0
return x2
def chinese_remainder(c, d, p, q):
"""
Chinese remainder theorem
"""
dp = powmod(d, 1, p - 1)
dq = powmod(d, 1, q - 1)
m1 = powmod(c, dp, p)
m2 = powmod(c, dq, q)
qinv = mod_inverse(q, p)
h = (qinv * (m1 - m2)) % p
m = m2 + h * q
return m
if __name__ == "__main__":
p = get_prime(1024)
q = get_prime(1024)
n = p * q
r, e = get_r_e(p, q)
d = mod_inverse(e, r)
print('Private key:\n')
print('p: %d\n' % p)
print('q: %d\n' % q)
print('d: %d\n\n' % d)
print('Public key:\n')
print('n: %d\n' % n)
print('e: %d\n\n' % e)
M = int(input("Input message: "))
C = powmod(M, e, n) # encryption
print('\nAfter encryption is completed, the ciphertext obtained:\n%d\n' % C)
D = chinese_remainder(C, d, p, q) # decryption
print('Decryption is complete, and the plaintext obtained is:\n%d\n' % D)
因为 ,所以: