实验拓扑
PC --------- AC(tplink)
本脚本模拟一个WTP进程,与AC设备进行CAPWAP发现(Discovery)及DTLS握手交互
具体交互流程
(略,有时间再补)
脚本使用示例
from dtls_connect.CWSecurity import DtlsTunnel
def cw_assemble_discovery_request():
    """Build the CAPWAP Discovery Request payload.

    Placeholder: the construction of the Discovery Request is omitted in this
    example, so the function currently returns None.
    """
    return None
if __name__ == '__main__':
    # The example does not otherwise import these modules.
    import socket
    import time

    application_layer_byte_stream = cw_assemble_discovery_request()
    # Context manager guarantees the UDP socket is closed when the example ends.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        ttl = 64
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, ttl)
        # Broadcast discovery currently fails; only a unicast Discovery Request
        # gets a Discovery Response. Attempts that failed:
        #   sock.sendto(application_layer_byte_stream, ("255.255.255.255", 5246))
        #   sock.sendto(application_layer_byte_stream, ("192.168.1.255", 5246))
        # NOTE(review): sending to a broadcast address requires
        #   sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        # which this socket never sets -- likely why those sends fail; verify.

        # Send the Discovery Request to the AC (unicast).
        sock.sendto(application_layer_byte_stream, ("192.168.1.253", 5246))
        # Wait for the AC's Discovery Response.
        while True:
            data, addr = sock.recvfrom(4096)
            print(f"收到来自 {addr} 的报文,长度: {len(data)}")
            print("原始报文 (Hex):", data.hex())
            # Assumes an 8-byte CAPWAP header (no optional fields) followed by the
            # 4-byte control-header Message Type; byte 11 is its low-order byte.
            if len(data) >= 12:
                message_type = data[11]
                print(f"消息类型: {message_type} (0x02表示Discovery Response)")
                if message_type == 2:
                    break
        time.sleep(5)
        dtls = DtlsTunnel(sock)
        dtls.do_handshake()
        time.sleep(3)
# CWSecurity.py
import os
import time
import socket
import hashlib
from enum import Enum
from cryptography import x509
from Crypto.PublicKey import RSA
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from Crypto.Util.number import bytes_to_long, long_to_bytes
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives.hashes import SHA1, MD5
# ============================================Define Enum=============================================
class DtlsV1HandShakeType(Enum):
    """DTLS 1.0 handshake message type codes (first byte of a handshake message)."""

    # Hello phase
    DTLS_V1_HANDSHAKE_HELLO_REQUEST = 0
    DTLS_V1_HANDSHAKE_CLIENT_HELLO = 1
    DTLS_V1_HANDSHAKE_SERVER_HELLO = 2
    DTLS_V1_HANDSHAKE_HELLO_VERIFY_REQUEST = 3  # DTLS-specific cookie exchange
    # Server certificate / key exchange
    DTLS_V1_HANDSHAKE_CERTIFICATE = 11
    DTLS_V1_HANDSHAKE_SERVER_KEY_EXCHANGE = 12
    DTLS_V1_HANDSHAKE_CERTIFICATE_REQUEST = 13
    DTLS_V1_HANDSHAKE_SERVER_HELLO_DONE = 14
    # Client key exchange and completion
    DTLS_V1_HANDSHAKE_CERTIFICATE_VERIFY = 15
    DTLS_V1_HANDSHAKE_CLIENT_KEY_EXCHANGE = 16
    DTLS_V1_HANDSHAKE_FINISHED = 20
class DtlsV1RecordType(Enum):
    """DTLS 1.0 record-layer content types (first byte of a record header)."""

    DTLS_V1_RECORD_CHANGE_CIPHER_SPEC = 20
    DTLS_V1_RECORD_ALERT = 21
    DTLS_V1_RECORD_HANDSHAKE = 22
    DTLS_V1_RECORD_APPLICATION_DATA = 23
class DtlsConnectState(Enum):
    """Client-side handshake steps, in the order DtlsTunnel.do_handshake visits them."""

    DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE = 0  # send initial ClientHello
    DTLS_CONNECT_HELLO_VERIFY_REQUEST = 1  # receive cookie
    DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE = 2  # resend ClientHello echoing cookie
    DTLS_CONNECT_SERVER_HELLO = 3  # receive ServerHello
    DTLS_CONNECT_CERTIFICATE = 4  # receive server Certificate
    DTLS_CONNECT_SERVER_HELLO_DONE = 5  # receive ServerHelloDone
    DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED = 6  # send client flight
    DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE = 7  # receive server ChangeCipherSpec
    DTLS_CONNECT_SERVER_FINISHED = 8  # receive server Finished
# ====================================================================================================
class DtlsTunnel:
    """Minimal DTLS 1.0 client that handshakes with a CAPWAP AC over UDP.

    Every datagram carries the 4-byte CAPWAP DTLS header (``self.preamble``)
    in front of the DTLS records and is sent to ``(target_ip, target_port)``.
    :meth:`do_handshake` drives the client side of the handshake as a state
    machine over ``DtlsConnectState``. Key derivation and record protection
    are hard-wired for TLS_RSA_WITH_AES_128_CBC_SHA (0x002f): 20-byte
    HMAC-SHA1 MAC keys, 16-byte AES-128 keys, CBC with an explicit IV.
    """

    def __init__(self, sock: socket.socket, target_ip="192.168.1.253", port=5246):
        """Wrap an existing UDP socket.

        :param sock: UDP socket used for both sending and receiving.
        :param target_ip: address of the AC that all records are sent to.
        :param port: CAPWAP control port of the AC (5246 by default).
        """
        self.socket = sock
        self.client_random = None
        self.server_random = None
        self.handshake_cookie = None
        self.cipher_suite = None
        self.certificate = None  # DER bytes of the first certificate in the server chain
        self.pre_master_key = None
        # Handshake messages (12-byte DTLS handshake header included) hashed into Finished.
        self.finished_verify_hash_source = dict()
        self.next_state = DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE
        self.target_ip = target_ip
        self.target_port = port
        # CAPWAP DTLS header: preamble 0x01 (version 0, type 1 = DTLS follows) + 3 reserved bytes.
        self.preamble = bytearray([0x1, 0x0, 0x0, 0x0])
        self.client_write_MAC_key = None
        self.server_write_MAC_key = None
        self.client_write_key = None
        self.server_write_key = None
        self.client_write_IV = None
        self.server_write_IV = None
        self.explicit_iv = None
        # 48-bit record sequence number for epoch-1 records; advanced after each use.
        self.sequence_number = b'\x00\x00\x00\x00\x00\x00'

    def do_handshake(self):
        """Run the client side of the DTLS handshake.

        Each iteration performs the step named by ``self.next_state`` (assemble and
        send a flight, or receive and parse a message) and then advances the state.

        :return: True once the server-Finished step has run.
        :raises Exception: on an unknown state, or when sending a flight fails
            (the underlying ``OSError`` is chained as the cause).
        """
        # Plain Enum members are always truthy, so this loop only exits via return/raise.
        while self.next_state:
            print(self.next_state.name)
            record_data = None
            if self.next_state == DtlsConnectState.DTLS_CONNECT_SERVER_FINISHED:
                self.parse_record_layer_of_server_finished()
                return True
            connect_state = self.next_state
            if self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE:
                record_data = self.assemble_dtls_client_hello_record_layer()
                self.next_state = DtlsConnectState.DTLS_CONNECT_HELLO_VERIFY_REQUEST
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_HELLO_VERIFY_REQUEST:
                self.parse_record_layer_of_hello_verify_request()
                self.next_state = DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE:
                record_data = self.assemble_dtls_client_hello_record_layer()
                self.next_state = DtlsConnectState.DTLS_CONNECT_SERVER_HELLO
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_SERVER_HELLO:
                self.parse_record_layer_of_server_hello()
                self.next_state = DtlsConnectState.DTLS_CONNECT_CERTIFICATE
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_CERTIFICATE:
                self.parse_record_layer_of_certificate()
                self.next_state = DtlsConnectState.DTLS_CONNECT_SERVER_HELLO_DONE
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_SERVER_HELLO_DONE:
                self.parse_record_layer_of_server_hello_done()
                self.next_state = DtlsConnectState.DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED:
                record_data = self.assemble_dtls_client_key_exchange_and_change_cipher_spec_and_finished_packet()
                self.next_state = DtlsConnectState.DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE
            elif self.next_state == DtlsConnectState.DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE:
                self.parse_record_layer_of_server_change_cipher_spec()
                self.next_state = DtlsConnectState.DTLS_CONNECT_SERVER_FINISHED
            else:
                raise Exception("Unknown connect state")
            if record_data:
                try:
                    self.socket.sendto(self.assemble_capwap_control_layer(record_data),
                                       (self.target_ip, self.target_port))
                except OSError as err:
                    raise Exception(f"Error occurs in state[{connect_state.name}]") from err
        raise Exception("next connection state is None")

    def dtls_write(self, plain_text):
        """Encrypt ``plain_text`` as one epoch-1 application-data record and send it.

        Uses the current ``self.sequence_number`` and then advances it. Protection:
        HMAC-SHA1 over (epoch || seq || type || version || length || data), TLS CBC
        padding, AES-128-CBC under a fresh random explicit IV that is sent in clear
        ahead of the ciphertext.

        :param plain_text: bytes to send.
        """
        epoch = b'\x00\x01'
        seq_num = self.sequence_number
        self.add_sequence_number()
        content_type = b'\x17'
        version = b'\xfe\xff'
        self.explicit_iv = os.urandom(16)
        print("explicit iv:", self.explicit_iv.hex())
        mac = HMAC(self.client_write_MAC_key, SHA1())
        mac.update(epoch + seq_num + content_type + version + len(plain_text).to_bytes(2, "big") + plain_text)
        padded_data = DtlsTunnel._tls_cbc_pad(plain_text + mac.finalize())
        # A single encryptor context must serve both update() and finalize().
        encryptor = Cipher(algorithms.AES(self.client_write_key), modes.CBC(self.explicit_iv),
                           backend=default_backend()).encryptor()
        ciphertext = encryptor.update(padded_data) + encryptor.finalize()
        record_payload = self.explicit_iv + ciphertext
        record_header = content_type + version + epoch + seq_num + len(record_payload).to_bytes(2, "big")
        self.socket.sendto(self.preamble + record_header + record_payload, (self.target_ip, self.target_port))

    def dtls_read(self, encrypted_text):
        """Decrypt one AES-128-CBC protected record fragment from the server.

        ``encrypted_text`` is presumably the record fragment (explicit IV block
        followed by ciphertext) without CAPWAP preamble or DTLS record header --
        TODO confirm against callers. The whole fragment is CBC-decrypted and the
        first 16-byte block is discarded, so the IV handed to the cipher does not
        affect the returned data.

        :return: the plaintext as a hex string, with the 20-byte MAC and the
            padding stripped. NOTE(review): MAC and padding are not verified.
        """
        # Any IV works because the first decrypted block is dropped; use zeros
        # when nothing has been written yet (explicit_iv is still None).
        iv = self.explicit_iv if self.explicit_iv is not None else bytes(16)
        cipher = Cipher(algorithms.AES(self.server_write_key), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()
        decrypted_hex = (decryptor.update(encrypted_text) + decryptor.finalize()).hex()
        pad_len = int(decrypted_hex[-2:], base=16) + 1
        plain_text = decrypted_hex[32:-(40 + pad_len * 2)]
        self.explicit_iv = bytes.fromhex(decrypted_hex[:32])
        return plain_text

    # ===============================packet structure(capwap control layer)===============================
    def assemble_capwap_control_layer(self, record_data):
        """Prefix DTLS record bytes with the 4-byte CAPWAP DTLS header."""
        return b"".join((self.preamble, record_data))

    # ===================================packet structure(record layer)===================================
    def assemble_dtls_client_hello_record_layer(self):
        """Build the epoch-0 record carrying a ClientHello.

        Record sequence 0 for the initial hello, 1 for the hello that echoes the
        cookie. Only the cookie-bearing hello is saved for the Finished hash, since
        the first ClientHello/HelloVerifyRequest exchange is excluded from it.

        :raises Exception: if called in any other handshake state.
        """
        content_type = bytearray([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value, ])
        version = bytearray([0xfe, 0xff])
        epoch = bytearray([0x0, 0x0])
        if self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE:
            sequence_num = bytearray([0x0, 0x0, 0x0, 0x0, 0x0, 0x0])
        elif self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE:
            sequence_num = bytearray([0x0, 0x0, 0x0, 0x0, 0x0, 0x1])
        else:
            raise Exception('[MESSAGE ASSEMBLE ERROR][DTLS STATUS ERROR]')
        dtls_handshake = self.assemble_dtls_client_hello_handshake_layer()
        length = len(dtls_handshake).to_bytes(2, "big")
        ret_value = b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))
        if self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE:
            self.finished_verify_hash_source["ClientHello"] = dtls_handshake
        return ret_value

    def parse_record_layer_of_hello_verify_request(self):
        """Receive a HelloVerifyRequest and store its cookie in ``self.handshake_cookie``.

        Tries up to 5 datagrams, sleeping 0.2 s after each non-matching one. Leaves
        the cookie unchanged if none is found.
        """
        cookie = None
        for _ in range(5):
            data, _addr = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == 0x16:
                    print("The dtls handshake message has been identified")
                    handshake = self._record_payload(data)
                    # Body after the 12-byte header: server_version(2) | cookie_len(1) | cookie.
                    if (len(handshake) >= 15 and
                            handshake[0] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_HELLO_VERIFY_REQUEST.value):
                        print("The dtls handshake 'Hello Verify Request' has been identified")
                        cookie_length = handshake[14]
                        cookie = handshake[15:15 + cookie_length]
                        print("cookie: ", cookie.hex())
                        break
            time.sleep(0.2)
        if cookie:
            self.handshake_cookie = cookie

    def parse_record_layer_of_server_hello(self):
        """Receive ServerHello; store server random, chosen cipher suite and hash input.

        Tries up to 5 datagrams with a growing back-off.
        """
        server_random = None
        cipher_suite = None
        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == 0x16:
                    handshake = self._record_payload(data)
                    # Body: version(2) | random(32) | session_id_len(1) | session_id | cipher_suite(2) ...
                    if len(handshake) > 46 and handshake[0] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_SERVER_HELLO.value:
                        print("The dtls handshake 'Server Hello' has been identified")
                        self.finished_verify_hash_source["ServerHello"] = handshake
                        server_random = handshake[14:46]
                        print("server random: ", server_random.hex())
                        suite_offset = 47 + handshake[46]  # skip the variable-length session id
                        cipher_suite = handshake[suite_offset:suite_offset + 2]
                        print("cipher suite: ", cipher_suite.hex())
                        if cipher_suite != b'\x00\x2f':
                            print("WARNING: only TLS_RSA_WITH_AES_128_CBC_SHA (0x002f) is implemented")
                        break
            time.sleep(0.2 * repeat_times)
        if server_random:
            self.server_random = server_random
        if cipher_suite:
            self.cipher_suite = cipher_suite

    def parse_record_layer_of_certificate(self):
        """Receive the server Certificate; store the first certificate (DER) and hash input.

        Tries up to 5 datagrams with a growing back-off. Assumes the message is not
        fragmented across datagrams.
        """
        certificate = None
        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == 0x16:
                    print("The dtls handshake message has been identified")
                    handshake = self._record_payload(data)
                    # Body: certificates_len(3) | cert_len(3) | cert | ...
                    if len(handshake) >= 18 and handshake[0] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CERTIFICATE.value:
                        print("The dtls handshake 'Certificate' has been identified")
                        self.finished_verify_hash_source["Certificate"] = handshake
                        certificate_length = int.from_bytes(handshake[15:18], "big")
                        certificate = handshake[18:18 + certificate_length]
                        break
            time.sleep(0.2 * repeat_times)
        if certificate:
            self.certificate = certificate

    def parse_record_layer_of_server_hello_done(self):
        """Receive ServerHelloDone and store it as Finished hash input.

        Tries up to 5 datagrams with a growing back-off.
        """
        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value:
                    print("The dtls handshake message has been identified")
                    handshake = self._record_payload(data)
                    if (len(handshake) >= 1 and
                            handshake[0] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_SERVER_HELLO_DONE.value):
                        print("The dtls handshake 'Server Hello Done' has been identified")
                        self.finished_verify_hash_source["ServerHelloDone"] = handshake
                        break
            time.sleep(0.2 * repeat_times)

    def assemble_dtls_client_key_exchange_and_change_cipher_spec_and_finished_packet(self):
        """Concatenate ClientKeyExchange, ChangeCipherSpec and Finished records into one flight.

        Order matters: ClientKeyExchange creates the pre-master secret that the
        Finished step derives its keys from.
        """
        return b"".join((
            self.assemble_dtls_client_key_exchange_record_layer(),
            self.assemble_dtls_change_cipher_spec_record_layer(),
            self.assemble_dtls_finished_record_layer(),
        ))

    def assemble_dtls_client_key_exchange_record_layer(self):
        """Build the epoch-0, sequence-2 record carrying ClientKeyExchange and save it for the Finished hash."""
        print("assemble_dtls_client_key_exchange_record_layer")
        content_type = bytearray([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value, ])
        version = bytearray([0xfe, 0xff])
        epoch = bytearray([0x0, 0x0])
        sequence_num = bytearray([0x0, 0x0, 0x0, 0x0, 0x0, 0x2])
        dtls_handshake = self.assemble_dtls_client_key_exchange_handshake_layer()
        length = len(dtls_handshake).to_bytes(2, "big")
        ret_value = b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))
        self.finished_verify_hash_source["ClientKeyExchange"] = dtls_handshake
        return ret_value

    def assemble_dtls_change_cipher_spec_record_layer(self):
        """Build the epoch-0, sequence-3 ChangeCipherSpec record (single byte 0x01)."""
        print("assemble_dtls_change_cipher_spec_record_layer")
        content_type = bytearray([DtlsV1RecordType.DTLS_V1_RECORD_CHANGE_CIPHER_SPEC.value, ])
        version = bytearray([0xfe, 0xff])
        epoch = bytearray([0x0, 0x0])
        sequence_num = bytearray([0x0, 0x0, 0x0, 0x0, 0x0, 0x3])
        length = bytearray([0x0, 0x1])
        change_message = bytearray([0x1, ])
        return b"".join((content_type, version, epoch, sequence_num, length, change_message))

    def assemble_dtls_finished_record_layer(self):
        """Build the epoch-1 record carrying the encrypted client Finished."""
        print("assemble_dtls_finished_record_layer")
        content_type = bytearray([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value, ])
        version = bytearray([0xfe, 0xff])
        epoch = bytearray([0x0, 0x1])
        # Capture the sequence number before the handshake layer consumes it for the
        # MAC, so the record header and the MAC input use the same value.
        sequence_num = bytes(self.sequence_number)
        dtls_handshake = self.assemble_dtls_finished_handshake_layer()
        length = len(dtls_handshake).to_bytes(2, "big")
        return b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))

    def parse_record_layer_of_server_change_cipher_spec(self):
        """Not implemented yet: returns without reading from the socket."""
        pass

    def parse_record_layer_of_server_finished(self):
        """Not implemented yet: returns without reading or verifying the server Finished."""
        pass

    # =================================packet structure(handshake layer)==================================
    def assemble_dtls_client_hello_handshake_layer(self):
        """Build a ClientHello handshake message (12-byte DTLS handshake header + body).

        The initial hello (message_seq 0) generates ``self.client_random`` and has
        an empty cookie; the second hello (message_seq 1) reuses that random and
        echoes ``self.handshake_cookie``. Offers cipher suites 0x002f and 0x000a
        with null compression only. Lengths are computed from the body, so any
        cookie length works.

        :raises Exception: in any other handshake state, or when no cookie was received.
        """
        handshake_type = bytearray([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CLIENT_HELLO.value, ])
        if self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE:
            message_sequence = bytearray([0x0, 0x0])
            self.client_random = bytearray(DtlsTunnel.generate_dtls_random())
            print("client random: %s" % self.client_random.hex())
            cookie_value = b""
        elif self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE:
            message_sequence = bytearray([0x0, 0x1])
            if self.handshake_cookie is None:
                raise Exception('[MESSAGE ASSEMBLE ERROR][NO HELLO VERIFY COOKIE]')
            cookie_value = bytes(self.handshake_cookie)
        else:
            raise Exception('[MESSAGE ASSEMBLE ERROR][DTLS STATUS ERROR]')
        fragment_offset = bytearray([0x0, 0x0, 0x0])
        dtls_version = bytearray([0xfe, 0xff])
        session_id_length = bytearray([0x0, ])
        cookie_length = len(cookie_value).to_bytes(1, "big")
        cipher_suites_length = bytearray([0x0, 0x4])
        cipher_suites = bytearray([0x0, 0x2f, 0x0, 0xa])
        compression_methods_length = bytearray([0x1, ])
        compression_methods = bytearray([0x0, ])
        fragment_value = b"".join((
            dtls_version, self.client_random, session_id_length, cookie_length, cookie_value,
            cipher_suites_length, cipher_suites, compression_methods_length, compression_methods,
        ))
        # Unfragmented message: total length == fragment length.
        fragment_length = len(fragment_value).to_bytes(3, "big")
        return b"".join((handshake_type, fragment_length, message_sequence, fragment_offset,
                         fragment_length, fragment_value))

    def assemble_dtls_client_key_exchange_handshake_layer(self):
        """Build ClientKeyExchange (message_seq 2) carrying the RSA-encrypted pre-master secret.

        Generates a fresh 48-byte pre-master secret (client version 0xfeff + 46
        random bytes), encrypts it, and prefixes the ciphertext with its 2-byte
        length. Header lengths follow the ciphertext size, so any RSA modulus works.
        """
        handshake_type = bytearray([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CLIENT_KEY_EXCHANGE.value, ])
        message_sequence = bytearray([0x0, 0x2])
        fragment_offset = bytearray([0x0, 0x0, 0x0])
        self.pre_master_key = b'\xfe\xff' + os.urandom(46)
        encrypted_pre_master = bytes(self.get_encrypted_pre_master())
        body = len(encrypted_pre_master).to_bytes(2, "big") + encrypted_pre_master
        body_length = len(body).to_bytes(3, "big")
        return b"".join((handshake_type, body_length, message_sequence, fragment_offset, body_length, body))

    def _derive_session_keys(self):
        """Derive the master secret and split the key block into per-direction keys.

        Sets 20-byte MAC keys, 16-byte AES keys and 16-byte IVs for both directions
        from a 104-byte key block and returns the 48-byte master secret.
        """
        master_secret = DtlsTunnel.prf(self.pre_master_key, "master secret",
                                       self.client_random + self.server_random, 48)
        key_block = DtlsTunnel.prf(master_secret, "key expansion",
                                   self.server_random + self.client_random, 104)
        print("master secret: ", master_secret.hex())
        print("len(master secret): ", len(master_secret))
        print("key_block: ", key_block.hex())
        print("len(key_block): ", len(key_block))
        self.client_write_MAC_key = key_block[0:20]
        self.server_write_MAC_key = key_block[20:40]
        self.client_write_key = key_block[40:56]
        self.server_write_key = key_block[56:72]
        self.client_write_IV = key_block[72:88]
        self.server_write_IV = key_block[88:104]
        print("client_write_key: ", self.client_write_key.hex())
        print("len(client_write_key): ", len(self.client_write_key))
        print("client_write_IV: ", self.client_write_IV.hex())
        print("len(client_write_IV): ", len(self.client_write_IV))
        print("server_write_key: ", self.server_write_key.hex())
        print("len(server_write_key): ", len(self.server_write_key))
        return master_secret

    def assemble_dtls_finished_handshake_layer(self):
        """Derive session keys and return the encrypted client Finished fragment.

        verify_data = PRF(master_secret, "client finished",
        MD5(handshake_messages) || SHA1(handshake_messages))[:12]. The fragment is
        AES-128-CBC encrypted as R || Finished || MAC || padding, where R is a
        random block; the first ciphertext block then serves as the explicit IV.
        Consumes (and advances) ``self.sequence_number`` for the MAC.
        """
        master_secret = self._derive_session_keys()
        label = "client finished"
        handshake_message = b"".join(
            self.finished_verify_hash_source[name]
            for name in ("ClientHello", "ServerHello", "Certificate", "ServerHelloDone", "ClientKeyExchange")
        )
        handshake_hash = hashlib.md5(handshake_message).digest() + hashlib.sha1(handshake_message).digest()
        verify_data = DtlsTunnel.prf(master_secret, label, handshake_hash, 12)
        print("verify data: ", verify_data.hex())
        print("len(verify data): ", len(verify_data))
        handshake_type = bytearray([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_FINISHED.value, ])
        length = bytearray([0x0, 0x0, 0xc])
        message_sequence = bytearray([0x0, 0x3])
        fragment_offset = bytearray([0x0, 0x0, 0x0])
        fragment_length = bytearray([0x0, 0x0, 0xc])
        unencrypted_finished_message = b"".join(
            (handshake_type, length, message_sequence, fragment_offset, fragment_length, verify_data))
        # MAC sequence number is the 64-bit epoch || 48-bit sequence.
        epoch = b'\x00\x01'
        seq_num = epoch + self.sequence_number
        self.add_sequence_number()
        content_type = b'\x16'
        version = b'\xfe\xff'
        length_bytes = len(unencrypted_finished_message).to_bytes(2, "big")
        mac_input = seq_num + content_type + version + length_bytes + unencrypted_finished_message
        h = HMAC(self.client_write_MAC_key, SHA1())
        h.update(mac_input)
        mac = h.finalize()
        print("MAC: ", mac.hex())
        print("len(MAC): ", len(mac))
        cipher = Cipher(algorithms.AES(self.client_write_key), modes.CBC(self.client_write_IV),
                        backend=default_backend())
        encryptor = cipher.encryptor()
        self.explicit_iv = os.urandom(16)
        finished_rec_data = self.explicit_iv + DtlsTunnel._tls_cbc_pad(unencrypted_finished_message + mac)
        encrypted_finished = encryptor.update(finished_rec_data) + encryptor.finalize()
        print("encrypted finished: ", encrypted_finished.hex())
        print("len(encrypted finished): ", len(encrypted_finished))
        return encrypted_finished

    # ==================================packet structure(auxiliary func)==================================
    @staticmethod
    def _record_payload(data):
        """Return the fragment of the first DTLS record in a CAPWAP-DTLS datagram.

        Layout: 4-byte CAPWAP DTLS header, 13-byte DTLS record header whose last two
        bytes are the big-endian fragment length, then the fragment itself.
        """
        start = 4 + 13
        fragment_length = int.from_bytes(data[4 + 11:4 + 13], "big")
        return bytes(data[start:start + fragment_length])

    @staticmethod
    def _tls_cbc_pad(data, block_size=16):
        """Append TLS CBC padding: ``n`` bytes each of value ``n - 1`` (at least one byte)."""
        pad_len = block_size - len(data) % block_size
        return data + bytes([pad_len - 1]) * pad_len

    @staticmethod
    def _generate_nonzero_padding(length):
        """Return ``length`` random bytes, none equal to 0x00 (PKCS#1 v1.5 PS string)."""
        padding = bytearray()
        while len(padding) < length:
            padding.extend(b for b in os.urandom(length - len(padding)) if b != 0)
        return bytes(padding)

    def get_encrypted_pre_master(self):
        """RSA-encrypt ``self.pre_master_key`` under the server certificate's public key.

        Builds a PKCS#1 v1.5 padding string of non-zero random bytes sized to the
        certificate's modulus (k - 3 - len(pre-master) bytes) and delegates the
        raw encryption to :meth:`encrypt_pms_with_known_padding`.

        :raises Exception: if the certificate or pre-master secret is missing.
        """
        if not self.certificate or not self.pre_master_key:
            raise Exception("[!!]")
        public_key = x509.load_der_x509_certificate(bytes(self.certificate)).public_key()
        modulus_len = (public_key.key_size + 7) // 8
        ps = DtlsTunnel._generate_nonzero_padding(modulus_len - 3 - len(self.pre_master_key))
        print("certificate: %s" % self.certificate.hex())
        print("client key exchange random: %s" % self.pre_master_key.hex())
        print("padding: %s" % ps.hex())
        return DtlsTunnel.encrypt_pms_with_known_padding(self.certificate.hex(), self.pre_master_key.hex(), ps.hex())

    @staticmethod
    def prf(secret, label, seed, length):
        """TLS 1.0/1.1 PRF: P_MD5(S1, label + seed) XOR P_SHA1(S2, label + seed).

        S1/S2 are the first and last ceil(len(secret) / 2) bytes of ``secret``; for
        odd lengths they share the middle byte.

        :param label: str (UTF-8 encoded here) or bytes.
        :param seed: bytes or bytearray.
        :return: ``length`` bytes.
        """
        if isinstance(label, str):
            label = label.encode()
        if isinstance(seed, bytearray):
            seed = bytes(seed)
        half = (len(secret) + 1) // 2
        s1, s2 = secret[:half], secret[len(secret) - half:]
        p_md5 = DtlsTunnel.p_hash(s1, label + seed, MD5, length)
        p_sha1 = DtlsTunnel.p_hash(s2, label + seed, SHA1, length)
        return bytes(a ^ b for a, b in zip(p_md5, p_sha1))

    @staticmethod
    def p_hash(secret, seed, hash_alg, length):
        """TLS P_hash: concatenate HMAC(secret, A(i) + seed), A(0) = seed, A(i) = HMAC(secret, A(i-1)).

        :param hash_alg: a ``cryptography`` hash class (e.g. MD5, SHA1), instantiated per HMAC.
        :return: the first ``length`` output bytes.
        """
        output = b""
        a = seed
        while len(output) < length:
            # HMAC contexts are single-use after finalize(), so build a new one each time.
            hmac_obj = HMAC(secret, hash_alg())
            hmac_obj.update(a)
            a = hmac_obj.finalize()
            hmac_obj = HMAC(secret, hash_alg())
            hmac_obj.update(a + seed)
            output += hmac_obj.finalize()
        return output[:length]

    @staticmethod
    def generate_dtls_random():
        """Return a 32-byte hello random: 4-byte big-endian Unix time + 28 random bytes."""
        timestamp = int(time.time()).to_bytes(4, byteorder='big')
        random_bytes = os.urandom(28)
        return timestamp + random_bytes

    @staticmethod
    def encrypt_pms_with_known_padding(cert_hex: str, pms_hex: str, ps_hex: str) -> bytes:
        """Build the EncryptedPreMasterSecret with a caller-supplied PKCS#1 v1.5 padding string.

        Works for any RSA modulus size (e.g. 2048-bit).

        :param cert_hex: hex string of the DER-encoded certificate.
        :param pms_hex: hex string of the 48-byte PreMasterSecret.
        :param ps_hex: hex string of the PS padding (without the leading 0x00 0x02
            and the 0x00 separator); must be k - 3 - 48 bytes for a k-byte modulus.
        :return: the RSA ciphertext as k bytes.
        :raises ValueError: on a wrong PreMasterSecret or padding length.
        """
        pms = bytes.fromhex(pms_hex)
        ps = bytes.fromhex(ps_hex)
        if len(pms) != 48:
            raise ValueError("PreMasterSecret must be exactly 48 bytes")
        cert_der = bytes.fromhex(cert_hex)
        cert = x509.load_der_x509_certificate(cert_der)
        pub_key = cert.public_key()
        pub_bytes = pub_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        rsa_key = RSA.importKey(pub_bytes)
        k = rsa_key.size_in_bytes()  # modulus length in bytes (256 for 2048-bit)
        # Encoded message layout: 0x00 0x02 || PS || 0x00 || PMS
        expected_ps_len = k - 3 - len(pms)
        if len(ps) != expected_ps_len:
            raise ValueError(f"Padding length error: expected {expected_ps_len} bytes, got {len(ps)} bytes")
        em = b'\x00\x02' + ps + b'\x00' + pms
        assert len(em) == k  # guaranteed by the length check above
        # Raw (textbook) RSA on the already-padded block.
        m = bytes_to_long(em)
        c = pow(m, rsa_key.e, rsa_key.n)
        encrypted = long_to_bytes(c, k)
        return encrypted

    def add_sequence_number(self):
        """Increment the 48-bit big-endian ``self.sequence_number``."""
        tmp = int.from_bytes(self.sequence_number, byteorder='big')
        self.sequence_number = (tmp + 1).to_bytes(6, byteorder='big')
if __name__ == '__main__':
    # Developer helper: print the section-divider comment lines used in this file.
    title = "packet structure(capwap control layer)"
    for divider in (title.center(100, '='), '=' * 100):
        print(f"# {divider}")