141 lines
4.0 KiB
Python
141 lines
4.0 KiB
Python
import binascii
|
||
from math import ceil
|
||
from .func import rotl, bytes_to_list
|
||
|
||
IV = [
|
||
1937774191, 1226093241, 388252375, 3666478592,
|
||
2842636476, 372324522, 3817729613, 2969243214,
|
||
]
|
||
|
||
T_j = [
|
||
2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 2043430169,
|
||
2043430169, 2043430169, 2043430169, 2043430169, 2043430169, 2043430169,
|
||
2043430169, 2043430169, 2043430169, 2043430169, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042, 2055708042, 2055708042,
|
||
2055708042, 2055708042, 2055708042, 2055708042
|
||
]
|
||
|
||
def sm3_ff_j(x, y, z, j):
|
||
if 0 <= j and j < 16:
|
||
ret = x ^ y ^ z
|
||
elif 16 <= j and j < 64:
|
||
ret = (x & y) | (x & z) | (y & z)
|
||
return ret
|
||
|
||
def sm3_gg_j(x, y, z, j):
|
||
if 0 <= j and j < 16:
|
||
ret = x ^ y ^ z
|
||
elif 16 <= j and j < 64:
|
||
#ret = (X | Y) & ((2 ** 32 - 1 - X) | Z)
|
||
ret = (x & y) | ((~ x) & z)
|
||
return ret
|
||
|
||
def sm3_p_0(x):
|
||
return x ^ (rotl(x, 9 % 32)) ^ (rotl(x, 17 % 32))
|
||
|
||
def sm3_p_1(x):
|
||
return x ^ (rotl(x, 15 % 32)) ^ (rotl(x, 23 % 32))
|
||
|
||
def sm3_cf(v_i, b_i):
|
||
w = []
|
||
for i in range(16):
|
||
weight = 0x1000000
|
||
data = 0
|
||
for k in range(i*4,(i+1)*4):
|
||
data = data + b_i[k]*weight
|
||
weight = int(weight/0x100)
|
||
w.append(data)
|
||
|
||
for j in range(16, 68):
|
||
w.append(0)
|
||
w[j] = sm3_p_1(w[j-16] ^ w[j-9] ^ (rotl(w[j-3], 15 % 32))) ^ (rotl(w[j-13], 7 % 32)) ^ w[j-6]
|
||
str1 = "%08x" % w[j]
|
||
w_1 = []
|
||
for j in range(0, 64):
|
||
w_1.append(0)
|
||
w_1[j] = w[j] ^ w[j+4]
|
||
str1 = "%08x" % w_1[j]
|
||
|
||
a, b, c, d, e, f, g, h = v_i
|
||
|
||
for j in range(0, 64):
|
||
ss_1 = rotl(
|
||
((rotl(a, 12 % 32)) +
|
||
e +
|
||
(rotl(T_j[j], j % 32))) & 0xffffffff, 7 % 32
|
||
)
|
||
ss_2 = ss_1 ^ (rotl(a, 12 % 32))
|
||
tt_1 = (sm3_ff_j(a, b, c, j) + d + ss_2 + w_1[j]) & 0xffffffff
|
||
tt_2 = (sm3_gg_j(e, f, g, j) + h + ss_1 + w[j]) & 0xffffffff
|
||
d = c
|
||
c = rotl(b, 9 % 32)
|
||
b = a
|
||
a = tt_1
|
||
h = g
|
||
g = rotl(f, 19 % 32)
|
||
f = e
|
||
e = sm3_p_0(tt_2)
|
||
|
||
a, b, c, d, e, f, g, h = map(
|
||
lambda x:x & 0xFFFFFFFF ,[a, b, c, d, e, f, g, h])
|
||
|
||
v_j = [a, b, c, d, e, f, g, h]
|
||
return [v_j[i] ^ v_i[i] for i in range(8)]
|
||
|
||
def sm3_hash(msg):
|
||
# print(msg)
|
||
len1 = len(msg)
|
||
reserve1 = len1 % 64
|
||
msg.append(0x80)
|
||
reserve1 = reserve1 + 1
|
||
# 56-64, add 64 byte
|
||
range_end = 56
|
||
if reserve1 > range_end:
|
||
range_end = range_end + 64
|
||
|
||
for i in range(reserve1, range_end):
|
||
msg.append(0x00)
|
||
|
||
bit_length = (len1) * 8
|
||
bit_length_str = [bit_length % 0x100]
|
||
for i in range(7):
|
||
bit_length = int(bit_length / 0x100)
|
||
bit_length_str.append(bit_length % 0x100)
|
||
for i in range(8):
|
||
msg.append(bit_length_str[7-i])
|
||
|
||
group_count = round(len(msg) / 64)
|
||
|
||
B = []
|
||
for i in range(0, group_count):
|
||
B.append(msg[i*64:(i+1)*64])
|
||
|
||
V = []
|
||
V.append(IV)
|
||
for i in range(0, group_count):
|
||
V.append(sm3_cf(V[i], B[i]))
|
||
|
||
y = V[i+1]
|
||
result = ""
|
||
for i in y:
|
||
result = '%s%08x' % (result, i)
|
||
return result
|
||
|
||
def sm3_kdf(z, klen): # z为16进制表示的比特串(str),klen为密钥长度(单位byte)
|
||
klen = int(klen)
|
||
ct = 0x00000001
|
||
rcnt = ceil(klen/32)
|
||
zin = [i for i in bytes.fromhex(z.decode('utf8'))]
|
||
ha = ""
|
||
for i in range(rcnt):
|
||
msg = zin + [i for i in binascii.a2b_hex(('%08x' % ct).encode('utf8'))]
|
||
ha = ha + sm3_hash(msg)
|
||
ct += 1
|
||
return ha[0: klen * 2]
|