Python模拟DTLS 1.0握手过程

Python模拟DTLS 1.0握手过程

实验拓扑

PC --------- AC(tplink)
用代码模拟 WTP 进程,与 AC 设备进行 DTLS 交互。

具体交互流程

(略,有时间再补)

脚本使用实例

from dtls_connect.CWSecurity import DtlsTunnel

def cw_assemble_discovery_request():
    """Build the Discovery Request byte stream.

    Placeholder in this usage example: the body is not implemented, so the
    function currently returns ``None``.
    """
    return None

if __name__ == '__main__':
    # Imported here because the usage example has no file-level imports for
    # these two standard-library modules.
    import socket
    import time

    application_layer_byte_stream = cw_assemble_discovery_request()

    # The context manager guarantees the socket is closed even if the
    # discovery exchange or the handshake raises.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        ttl = 64
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, ttl)

        # Both of the following attempts failed; so far only a unicast
        # Discovery Request is answered with a Discovery Response:
        #   sock.sendto(application_layer_byte_stream, ("255.255.255.255", 5246))
        #   sock.sendto(application_layer_byte_stream, ("192.168.1.255", 5246))

        # Send the Discovery Request to the AC.
        sock.sendto(application_layer_byte_stream, ("192.168.1.253", 5246))

        # Wait for the AC to answer with a Discovery Response.
        while True:
            data, addr = sock.recvfrom(4096)
            print(f"收到来自 {addr} 的报文,长度: {len(data)}")
            print("原始报文 (Hex):", data.hex())

            # Byte 11 of the datagram is treated as the message type
            # (2 means Discovery Response); shorter datagrams are ignored.
            if len(data) >= 12:
                message_type = data[11]
                if message_type == 2:
                    break
                print(f"消息类型: {message_type} (0x02表示Discovery Response)")

        time.sleep(5)
        dtls = DtlsTunnel(sock)
        dtls.do_handshake()
        time.sleep(3)
# CWSecurity.py
import os
import time
import socket
import hashlib
from enum import Enum
from cryptography import x509
from Crypto.PublicKey import RSA
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from Crypto.Util.number import bytes_to_long, long_to_bytes
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives.hashes import SHA1, MD5


# ============================================Define Enum=============================================
class DtlsV1HandShakeType(Enum):
    """Numeric codes of the DTLS 1.0 handshake message types."""

    DTLS_V1_HANDSHAKE_HELLO_REQUEST = 0
    DTLS_V1_HANDSHAKE_CLIENT_HELLO = 1
    DTLS_V1_HANDSHAKE_SERVER_HELLO = 2
    DTLS_V1_HANDSHAKE_HELLO_VERIFY_REQUEST = 3
    DTLS_V1_HANDSHAKE_CERTIFICATE = 11  # 0x0b
    DTLS_V1_HANDSHAKE_SERVER_KEY_EXCHANGE = 12  # 0x0c
    DTLS_V1_HANDSHAKE_CERTIFICATE_REQUEST = 13  # 0x0d
    DTLS_V1_HANDSHAKE_SERVER_HELLO_DONE = 14  # 0x0e
    DTLS_V1_HANDSHAKE_CERTIFICATE_VERIFY = 15  # 0x0f
    DTLS_V1_HANDSHAKE_CLIENT_KEY_EXCHANGE = 16  # 0x10
    DTLS_V1_HANDSHAKE_FINISHED = 20  # 0x14


class DtlsV1RecordType(Enum):
    """Numeric codes of the DTLS 1.0 record-layer content types."""

    DTLS_V1_RECORD_CHANGE_CIPHER_SPEC = 20  # 0x14
    DTLS_V1_RECORD_ALERT = 21  # 0x15
    DTLS_V1_RECORD_HANDSHAKE = 22  # 0x16
    DTLS_V1_RECORD_APPLICATION_DATA = 23  # 0x17


class DtlsConnectState(Enum):
    """Steps of the client-side handshake state machine, in the order they run."""

    # Outgoing: first ClientHello, sent without a cookie.
    DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE = 0x0
    # Incoming: HelloVerifyRequest carrying the cookie.
    DTLS_CONNECT_HELLO_VERIFY_REQUEST = 0x1
    # Outgoing: second ClientHello, echoing the cookie.
    DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE = 0x2
    # Incoming: ServerHello, Certificate, ServerHelloDone.
    DTLS_CONNECT_SERVER_HELLO = 0x3
    DTLS_CONNECT_CERTIFICATE = 0x4
    DTLS_CONNECT_SERVER_HELLO_DONE = 0x5
    # Outgoing: ClientKeyExchange + ChangeCipherSpec + Finished in one packet.
    DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED = 0x6
    # Incoming: the server's ChangeCipherSpec and Finished.
    DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE = 0x7
    DTLS_CONNECT_SERVER_FINISHED = 0x8


# ====================================================================================================


class DtlsTunnel:
    def __init__(self, sock: socket.socket, target_ip="192.168.1.253"):
        """Set up the per-connection state for a handshake carried over ``sock``.

        Args:
            sock: Socket used for every ``sendto`` / ``recvfrom`` call.
            target_ip: Address of the peer; packets are sent to its port 5246.
        """
        # --- transport ---
        self.socket = sock
        self.target_ip = target_ip
        # Header bytes placed in front of every outgoing DTLS record.
        self.preamble = bytearray(b"\x01\x00\x00\x00")

        # --- handshake parameters, filled in as the exchange progresses ---
        self.client_random = None
        self.server_random = None
        self.handshake_cookie = None
        self.cipher_suite = None
        self.certificate = None
        self.pre_master_key = None
        # Raw handshake messages, keyed by name, later hashed for Finished.
        self.finished_verify_hash_source = {}
        self.next_state = DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE

        # --- key material derived while building the Finished message ---
        self.client_write_MAC_key = None
        self.server_write_MAC_key = None
        self.client_write_key = None
        self.server_write_key = None
        self.client_write_IV = None
        self.server_write_IV = None

        # --- record protection state ---
        self.explicit_iv = None
        # 48-bit record sequence number, kept as 6 big-endian bytes.
        self.sequence_number = bytes(6)

    def do_handshake(self):
        """Run the client handshake state machine until its final state.

        Each pass prints the current state name, runs that state's handler,
        advances ``self.next_state`` and, when the handler produced a record,
        sends it (wrapped by ``assemble_capwap_control_layer``) to
        ``(self.target_ip, 5246)``.

        Returns:
            True once the server-finished step has been processed.

        Raises:
            Exception: if the current state is unknown, if sending a record
                fails (the original error is chained as ``__cause__``), or if
                ``self.next_state`` is falsy when the loop is entered.
        """
        states = DtlsConnectState
        # state -> (handler, state that follows, handler returns a record to send)
        transitions = {
            states.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE: (
                self.assemble_dtls_client_hello_record_layer,
                states.DTLS_CONNECT_HELLO_VERIFY_REQUEST,
                True,
            ),
            states.DTLS_CONNECT_HELLO_VERIFY_REQUEST: (
                self.parse_record_layer_of_hello_verify_request,
                states.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE,
                False,
            ),
            states.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE: (
                self.assemble_dtls_client_hello_record_layer,
                states.DTLS_CONNECT_SERVER_HELLO,
                True,
            ),
            states.DTLS_CONNECT_SERVER_HELLO: (
                self.parse_record_layer_of_server_hello,
                states.DTLS_CONNECT_CERTIFICATE,
                False,
            ),
            states.DTLS_CONNECT_CERTIFICATE: (
                self.parse_record_layer_of_certificate,
                states.DTLS_CONNECT_SERVER_HELLO_DONE,
                False,
            ),
            states.DTLS_CONNECT_SERVER_HELLO_DONE: (
                self.parse_record_layer_of_server_hello_done,
                states.DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED,
                False,
            ),
            states.DTLS_CONNECT_CLIENT_KEY_AND_CIPHER_SPEC_CHANGE_AND_FINISHED: (
                self.assemble_dtls_client_key_exchange_and_change_cipher_spec_and_finished_packet,
                states.DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE,
                True,
            ),
            states.DTLS_CONNECT_SERVER_CIPHER_SPEC_CHANGE: (
                self.parse_record_layer_of_server_change_cipher_spec,
                states.DTLS_CONNECT_SERVER_FINISHED,
                False,
            ),
        }

        while self.next_state:
            print(self.next_state.name)
            if self.next_state == states.DTLS_CONNECT_SERVER_FINISHED:
                self.parse_record_layer_of_server_finished()
                return True

            connect_state = self.next_state
            if connect_state not in transitions:
                raise Exception("Unknown connect state")
            handler, following_state, sends_record = transitions[connect_state]

            # The handler must run BEFORE the state advances: the ClientHello
            # builders inspect self.next_state to decide whether to add a cookie.
            result = handler()
            record_data = result if sends_record else None
            self.next_state = following_state

            if record_data:
                try:
                    self.socket.sendto(self.assemble_capwap_control_layer(record_data), (self.target_ip, 5246))
                except Exception as err:
                    # Not a bare except: KeyboardInterrupt/SystemExit pass through,
                    # and the root cause stays attached to the raised error.
                    raise Exception(f"Error occurs in state[{connect_state.name}]") from err

        raise Exception("next connection state is None")

    def dtls_write(self, plain_text):
        """Protect ``plain_text`` as an application-data record and send it.

        The record is MACed with HMAC-SHA1 using ``client_write_MAC_key``,
        padded to a multiple of 16 bytes, encrypted with AES-CBC under
        ``client_write_key`` and a fresh random explicit IV, and sent to
        ``(self.target_ip, 5246)`` behind the preamble.

        Args:
            plain_text: Payload to send; must be bytes-like because it is
                concatenated with the MAC.

        Side effects:
            Consumes one value of ``self.sequence_number`` and overwrites
            ``self.explicit_iv``.
        """
        epoch = b'\x00\x01'
        # Use the current sequence number for this record, then advance it.
        seq_num = self.sequence_number
        self.add_sequence_number()

        content_type = b'\x17'
        version = b'\xfe\xff'

        self.explicit_iv = os.urandom(16)
        print("explicit iv:", self.explicit_iv.hex())

        # MAC input: epoch || sequence number || type || version || length || payload.
        mac = HMAC(self.client_write_MAC_key, SHA1())
        mac.update(epoch + seq_num + content_type + version + len(plain_text).to_bytes(2, "big") + plain_text)
        mac_result = mac.finalize()

        data_to_encrypt = plain_text + mac_result

        # pad_len bytes are appended, each holding the value pad_len - 1
        # (always at least one byte, so the last byte encodes the padding size).
        pad_len = 16 - len(data_to_encrypt) % 16
        padded_data = data_to_encrypt + bytes([pad_len - 1]) * pad_len

        # One encryptor context for both update() and finalize(); the original
        # finalized a second, unrelated context.
        encryptor = Cipher(algorithms.AES(self.client_write_key), modes.CBC(self.explicit_iv)).encryptor()
        ciphertext = encryptor.update(padded_data) + encryptor.finalize()

        record_payload = self.explicit_iv + ciphertext
        record_header = content_type + version + epoch + seq_num + len(record_payload).to_bytes(2, "big")

        self.socket.sendto(self.preamble + record_header + record_payload, (self.target_ip, 5246))

    def dtls_read(self, encrypted_text):
        """Decrypt one received record payload and return the plaintext as hex.

        ``encrypted_text`` is expected to be a 16-byte explicit IV followed by
        AES-CBC ciphertext produced with ``server_write_key``. The trailing
        20-byte MAC and the padding are stripped from the result.

        The explicit IV carried by the record itself is used for decryption,
        so the call no longer depends on (or corrupts) whatever was left in
        ``self.explicit_iv`` by an earlier operation.

        Args:
            encrypted_text: Bytes-like record payload (IV || ciphertext).

        Returns:
            The decrypted payload as a hexadecimal string.

        Raises:
            ValueError: if the payload is not at least two whole AES blocks,
                or if the padding plus MAC exceed the decrypted length.
        """
        block_size = 16
        mac_size = 20  # HMAC-SHA1 digest length

        if len(encrypted_text) < 2 * block_size or len(encrypted_text) % block_size:
            raise ValueError("Encrypted record must be an explicit IV plus at least one whole AES block")

        iv = bytes(encrypted_text[:block_size])
        ciphertext = bytes(encrypted_text[block_size:])

        cipher = Cipher(algorithms.AES(self.server_write_key), modes.CBC(iv), backend=default_backend())
        decryptor = cipher.decryptor()
        decrypted = decryptor.update(ciphertext) + decryptor.finalize()

        # The last byte holds (padding size - 1).
        pad_len = decrypted[-1] + 1
        if pad_len + mac_size > len(decrypted):
            raise ValueError("Decrypted record is too short for its padding and MAC")

        # NOTE(review): the MAC and the padding bytes are stripped but not
        # verified; the record's epoch/sequence number are not available here.
        self.explicit_iv = iv
        return decrypted[:-(mac_size + pad_len)].hex()

    # ===============================packet structure(capwap control layer)===============================
    def assemble_capwap_control_layer(self, record_data):
        return b"".join((self.preamble, record_data))

    # ===================================packet structure(record layer)===================================
    def assemble_dtls_client_hello_record_layer(self):
        """Build the record that carries the ClientHello for the current state.

        The record sequence number is 0 for the cookie-less ClientHello and 1
        for the ClientHello that carries the cookie. The record length is
        derived from the handshake message actually produced instead of being
        a fixed constant, so cookies of any size are framed correctly.

        Returns:
            The complete record (13-byte header + handshake message) as bytes.

        Raises:
            Exception: if ``self.next_state`` is not one of the two
                ClientHello states, or the handshake message does not fit in
                the 2-byte record length field.
        """
        with_cookie = self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE
        if self.next_state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE:
            sequence_num = (0).to_bytes(6, "big")
        elif with_cookie:
            sequence_num = (1).to_bytes(6, "big")
        else:
            raise Exception('[MESSAGE ASSEMBLE ERROR][DTLS STATUS ERROR]')

        content_type = bytes([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value])
        version = b"\xfe\xff"
        epoch = b"\x00\x00"

        dtls_handshake = self.assemble_dtls_client_hello_handshake_layer()
        if len(dtls_handshake) > 0xffff:
            raise Exception('[MESSAGE ASSEMBLE ERROR][LOC: DTLS_RECORD_LAYER]')
        length = len(dtls_handshake).to_bytes(2, "big")

        ret_value = b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))
        if with_cookie:
            # Only the ClientHello with the cookie is kept for the Finished hash.
            self.finished_verify_hash_source["ClientHello"] = dtls_handshake

        return ret_value

    def parse_record_layer_of_hello_verify_request(self):
        cookie = None
        for repeat_times in range(5):
            data, addr = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == 0x16:
                    print("The dtls handshake message has been identified")
                    if data[4 + 13 + 0] == 0x3:
                        print("The dtls handshake 'Hello Verify Request' has been identified")
                        cookie = data[4 + 13 + 15 + 0:]
                        print("cookie: ", cookie.hex())
                        break
            time.sleep(0.2)
        if cookie:
            self.handshake_cookie = cookie

    def parse_record_layer_of_server_hello(self):
        server_random = None
        cipher_suite = None
        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= (4 + 13) and data[4 + 0] == 0x16:
                    self.finished_verify_hash_source["ServerHello"] = data[4 + 13:]
                    if data[4 + 13 + 0] == 0x2:
                        print("The dtls handshake 'Server Hello' has been identified")
                        server_random = data[4 + 13 + 14 + 0: 4 + 13 + 14 + 32]
                        print("server random: ", server_random.hex())
                        cipher_suite = data[4 + 13 + 79 + 0: 4 + 13 + 79 + 2]
                        print("cipher suite: ", cipher_suite.hex())
                        break
            time.sleep(0.2 * repeat_times)

        if server_random:
            self.server_random = server_random
        if cipher_suite:
            self.cipher_suite = cipher_suite

    def parse_record_layer_of_certificate(self):
        """Receive the Certificate message and store the first certificate.

        Up to five datagrams are read. For the first one that starts with the
        0x01 preamble byte, carries a handshake record and has handshake type
        Certificate:

        * everything after the record header is stored under
          ``finished_verify_hash_source["Certificate"]``;
        * the first certificate of the list (its 3-byte length sits 15 bytes
          into the handshake message, its data 18 bytes in) is stored in
          ``self.certificate``.

        The transcript entry is only written once the message has been
        identified, and datagrams too short to contain the length fields are
        skipped instead of raising ``IndexError``.
        """
        handshake_offset = 4 + 13  # preamble + record header
        certificate_length_offset = handshake_offset + 15
        certificate_offset = handshake_offset + 18

        # NOTE(review): a Certificate message split over several handshake
        # fragments is not reassembled here - confirm the peer never fragments.
        certificate = None
        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                if len(data) >= certificate_offset and data[4 + 0] == DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value:
                    print("The dtls handshake message has been identified")
                    if data[handshake_offset] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CERTIFICATE.value:
                        print("The dtls handshake 'Certificate' has been identified")
                        self.finished_verify_hash_source["Certificate"] = data[handshake_offset:]
                        certificate_length = int.from_bytes(
                            data[certificate_length_offset:certificate_offset], "big")
                        certificate = data[certificate_offset: certificate_offset + certificate_length]
                        break
            time.sleep(0.2 * repeat_times)
        if certificate:
            self.certificate = certificate

    def parse_record_layer_of_server_hello_done(self):
        """Receive the ServerHelloDone message and record it for the Finished hash.

        Up to five datagrams are read. For the first one that starts with the
        0x01 preamble byte, carries a handshake record and has handshake type
        ServerHelloDone, everything after the record header is stored under
        ``finished_verify_hash_source["ServerHelloDone"]``.

        The transcript entry is only written once the message has been
        identified, and a datagram that ends right after the record header is
        skipped instead of raising ``IndexError``.
        """
        handshake_offset = 4 + 13  # preamble + record header

        for repeat_times in range(5):
            data, _ = self.socket.recvfrom(4096)
            if len(data) >= 4 and data[0] == 0x1:
                print("The dtls header has been identified")
                # Strictly longer than the headers, so the type byte exists.
                if len(data) > handshake_offset and data[4 + 0] == DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value:
                    print("The dtls handshake message has been identified")
                    if data[handshake_offset] == DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_SERVER_HELLO_DONE.value:
                        print("The dtls handshake 'Server Hello Done' has been identified")
                        self.finished_verify_hash_source["ServerHelloDone"] = data[handshake_offset:]
                        break
            time.sleep(0.2 * repeat_times)

    def assemble_dtls_client_key_exchange_and_change_cipher_spec_and_finished_packet(self):
        """Concatenate the ClientKeyExchange, ChangeCipherSpec and Finished records.

        Returns:
            The three records back to back, as a single bytes object.
        """
        # Order matters: building the ClientKeyExchange sets the pre-master
        # secret and transcript entry that the Finished builder reads.
        record_builders = (
            self.assemble_dtls_client_key_exchange_record_layer,
            self.assemble_dtls_change_cipher_spec_record_layer,
            self.assemble_dtls_finished_record_layer,
        )
        return b"".join(build_record() for build_record in record_builders)

    def assemble_dtls_client_key_exchange_record_layer(self):
        """Build the epoch-0 record (sequence number 2) carrying ClientKeyExchange.

        The record length is derived from the handshake message actually
        produced rather than fixed at 0x010e, so RSA keys of any size are
        framed correctly. The handshake message is also stored under
        ``finished_verify_hash_source["ClientKeyExchange"]``.

        Returns:
            The complete record (13-byte header + handshake message) as bytes.

        Raises:
            Exception: if the handshake message does not fit in the 2-byte
                record length field.
        """
        print("assemble_dtls_client_key_exchange_record_layer")
        content_type = bytes([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value])
        version = b"\xfe\xff"
        epoch = b"\x00\x00"
        sequence_num = (2).to_bytes(6, "big")

        dtls_handshake = self.assemble_dtls_client_key_exchange_handshake_layer()
        if len(dtls_handshake) > 0xffff:
            raise Exception('[MESSAGE ASSEMBLE ERROR][LOC: DTLS_RECORD_LAYER]')
        length = len(dtls_handshake).to_bytes(2, "big")

        ret_value = b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))
        self.finished_verify_hash_source["ClientKeyExchange"] = dtls_handshake
        return ret_value

    def assemble_dtls_change_cipher_spec_record_layer(self):
        """Build the fixed 14-byte ChangeCipherSpec record (epoch 0, sequence number 3)."""
        print("assemble_dtls_change_cipher_spec_record_layer")
        record_fields = [
            bytes([DtlsV1RecordType.DTLS_V1_RECORD_CHANGE_CIPHER_SPEC.value]),  # content type
            b"\xfe\xff",             # version
            b"\x00\x00",             # epoch
            (3).to_bytes(6, "big"),  # sequence number
            (1).to_bytes(2, "big"),  # length of the message below
            b"\x01",                 # change cipher spec message
        ]
        return b"".join(record_fields)

    def assemble_dtls_finished_record_layer(self):
        """Build the epoch-1 record (sequence number 0) carrying the encrypted Finished.

        The record length is taken from the encrypted payload returned by
        ``assemble_dtls_finished_handshake_layer`` instead of being fixed at
        0x0040, so the header always matches the payload that follows it.

        Returns:
            The complete record (13-byte header + encrypted payload) as bytes.

        Raises:
            Exception: if the payload does not fit in the 2-byte record
                length field.
        """
        print("assemble_dtls_finished_record_layer")
        content_type = bytes([DtlsV1RecordType.DTLS_V1_RECORD_HANDSHAKE.value])
        version = b"\xfe\xff"
        epoch = b"\x00\x01"
        sequence_num = (0).to_bytes(6, "big")

        dtls_handshake = self.assemble_dtls_finished_handshake_layer()
        if len(dtls_handshake) > 0xffff:
            raise Exception('[MESSAGE ASSEMBLE ERROR][LOC: DTLS_RECORD_LAYER]')
        length = len(dtls_handshake).to_bytes(2, "big")

        return b"".join((content_type, version, epoch, sequence_num, length, dtls_handshake))

    def parse_record_layer_of_server_change_cipher_spec(self):
        pass

    def parse_record_layer_of_server_finished(self):
        pass

    # =================================packet structure(handshake layer)==================================
    def assemble_dtls_client_hello_handshake_layer(self):
        """Build the ClientHello handshake message for the current state.

        In the cookie-less state a fresh client random is generated and the
        message sequence is 0. In the with-cookie state the stored random is
        reused, ``self.handshake_cookie`` is embedded and the message sequence
        is 1. All length fields are derived from the data actually emitted, so
        cookies of any permitted size (not just 16 bytes) are encoded
        correctly.

        Returns:
            The handshake message (12-byte header + body) as bytes.

        Raises:
            Exception: if ``self.next_state`` is not a ClientHello state, if
                the cookie is missing in the with-cookie state, or if the
                cookie does not fit in its 1-byte length field.
        """
        state = self.next_state
        if state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITHOUT_COOKIE:
            message_sequence = b"\x00\x00"
            self.client_random = bytearray(DtlsTunnel.generate_dtls_random())
            # The original passed the hex string as a second print() argument,
            # so the "%s" placeholder was printed literally.
            print("client random %s" % self.client_random.hex())
            cookie_value = b""
        elif state == DtlsConnectState.DTLS_CONNECT_CLIENT_HELLO_WITH_COOKIE:
            message_sequence = b"\x00\x01"
            if self.handshake_cookie is None:
                # Fail with a clear message rather than a TypeError from join().
                raise Exception('[MESSAGE ASSEMBLE ERROR][MISSING COOKIE]')
            cookie_value = bytes(self.handshake_cookie)
        else:
            raise Exception('[MESSAGE ASSEMBLE ERROR][DTLS STATUS ERROR]')

        if len(cookie_value) > 0xff:
            raise Exception('[MESSAGE ASSEMBLE ERROR][LOC: DTLS_HANDSHAKE_LAYER]')

        handshake_type = bytes([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CLIENT_HELLO.value])
        fragment_offset = b"\x00\x00\x00"

        dtls_version = b"\xfe\xff"
        session_id_length = b"\x00"
        cookie_length = bytes([len(cookie_value)])
        cipher_suites_length = b"\x00\x04"
        cipher_suites = b"\x00\x2f\x00\x0a"
        compression_methods_length = b"\x01"
        compression_methods = b"\x00"

        fragment_value = b"".join((
            dtls_version, self.client_random, session_id_length, cookie_length, cookie_value,
            cipher_suites_length, cipher_suites, compression_methods_length, compression_methods,
        ))

        # The message is sent unfragmented, so the total length and the
        # fragment length are the same value.
        length = len(fragment_value).to_bytes(3, "big")
        fragment_length = length

        return b"".join((handshake_type, length, message_sequence, fragment_offset, fragment_length, fragment_value))

    def assemble_dtls_client_key_exchange_handshake_layer(self):
        """Build the ClientKeyExchange handshake message (message sequence 2).

        A new 48-byte pre-master secret (bytes 0xfe 0xff followed by 46
        random bytes) is generated, stored in ``self.pre_master_key`` and
        encrypted via ``get_encrypted_pre_master``. The handshake length
        fields are derived from the encrypted blob instead of being fixed at
        0x000102, so the header stays correct for any RSA key size.

        Returns:
            The handshake message (12-byte header + 2-byte length + encrypted
            pre-master secret) as bytes.
        """
        handshake_type = bytes([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_CLIENT_KEY_EXCHANGE.value])
        message_sequence = b"\x00\x02"
        fragment_offset = b"\x00\x00\x00"

        # RSA-encrypted pre-master secret, prefixed with its 2-byte length.
        self.pre_master_key = b'\xfe\xff' + os.urandom(46)
        encrypted_pre_master = bytes(self.get_encrypted_pre_master())
        encrypted_pre_master_len = len(encrypted_pre_master).to_bytes(2, "big")

        # Sent unfragmented: total length and fragment length are identical.
        length = (len(encrypted_pre_master_len) + len(encrypted_pre_master)).to_bytes(3, "big")
        fragment_length = length

        return b"".join((handshake_type, length, message_sequence, fragment_offset, fragment_length,
                         encrypted_pre_master_len, encrypted_pre_master))

    def assemble_dtls_finished_handshake_layer(self):
        """Derive the session keys and return the encrypted Finished payload.

        Steps, in order:
          1. Derive the 48-byte master secret and a 104-byte key block with
             ``prf`` and split the key block into the six key/IV attributes.
          2. Compute the 12-byte verify data over the MD5 and SHA-1 digests of
             the five stored handshake messages.
          3. Build the plaintext Finished handshake message (message sequence 3).
          4. MAC it with HMAC-SHA1 under ``client_write_MAC_key`` (this
             consumes one value of ``self.sequence_number``).
          5. Encrypt random block || message || MAC || padding with AES-CBC.

        Returns:
            The encrypted bytes that form the payload of the Finished record.

        Raises:
            KeyError: if one of the five handshake messages was never stored
                in ``finished_verify_hash_source``.

        NOTE: the secrets and keys are printed in clear text for debugging.
        """
        # ------------------ calculate verify data ------------------
        master_secret = DtlsTunnel.prf(self.pre_master_key, "master secret", self.client_random + self.server_random, 48)
        key_block = DtlsTunnel.prf(master_secret, "key expansion", self.server_random + self.client_random, 104)

        print("master secret: ", master_secret.hex())
        print("len(master secret): ", len(master_secret))

        print("key_block: ", key_block.hex())
        print("len(key_block): ", len(key_block))

        # Split the key block: two 20-byte MAC keys, two 16-byte cipher keys
        # and two 16-byte IVs (client value first in each pair).
        self.client_write_MAC_key = key_block[0:20]
        self.server_write_MAC_key = key_block[20:40]
        self.client_write_key = key_block[40:56]
        self.server_write_key = key_block[56:72]
        self.client_write_IV = key_block[72:88]
        self.server_write_IV = key_block[88:104]

        print("client_write_key: ", self.client_write_key.hex())
        print("len(client_write_key): ", len(self.client_write_key))

        print("client_write_IV: ", self.client_write_IV.hex())
        print("len(client_write_IV): ", len(self.client_write_IV))

        print("server_write_key: ", self.server_write_key.hex())
        print("len(server_write_key): ", len(self.server_write_key))

        label = "client finished"
        # Handshake transcript, concatenated in the order the messages were exchanged.
        handshake_message = b"".join((
            self.finished_verify_hash_source["ClientHello"],
            self.finished_verify_hash_source["ServerHello"],
            self.finished_verify_hash_source["Certificate"],
            self.finished_verify_hash_source["ServerHelloDone"],
            self.finished_verify_hash_source["ClientKeyExchange"]
        ))
        handshake_hash = hashlib.md5(handshake_message).digest() + hashlib.sha1(handshake_message).digest()
        verify_data = DtlsTunnel.prf(master_secret, label, handshake_hash, 12)

        print("verify data: ", verify_data.hex())
        print("len(verify data): ", len(verify_data))

        # Plaintext Finished message: 12-byte handshake header + 12-byte verify data.
        handshake_type = bytearray([DtlsV1HandShakeType.DTLS_V1_HANDSHAKE_FINISHED.value, ])
        length = bytearray([0x0, 0x0, 0xc])
        message_sequence = bytearray([0x0, 0x3])
        fragment_offset = bytearray([0x0, 0x0, 0x0])
        fragment_length = bytearray([0x0, 0x0, 0xc])
        unencrypted_finished_message = b"".join(
            (handshake_type, length, message_sequence, fragment_offset, fragment_length, verify_data))

        epoch = b'\x00\x01'
        # The MAC covers the 2-byte epoch followed by the 6-byte record
        # sequence number; the sequence number is advanced right after use.
        seq_num = epoch + self.sequence_number
        self.add_sequence_number()

        content_type = b'\x16'
        version = b'\xfe\xff'
        length_bytes = len(unencrypted_finished_message).to_bytes(2, "big")
        mac_input = seq_num + content_type + version + length_bytes + unencrypted_finished_message

        h = HMAC(self.client_write_MAC_key, SHA1())
        # The MAC is computed over the record pseudo-header plus the message,
        # not over the message alone.
        h.update(mac_input)
        mac = h.finalize()
        print("MAC: ", mac.hex())
        print("len(MAC): ", len(mac))

        cipher = Cipher(algorithms.AES(self.client_write_key), modes.CBC(self.client_write_IV), backend=default_backend())
        encryptor = cipher.encryptor()

        self.explicit_iv = os.urandom(16)

        # A random 16-byte block is encrypted in front of the message, so the
        # first ciphertext block takes the place of the explicit IV on the wire.
        # The padding is hard-coded: 16 + 24 + 20 = 60 bytes, padded to 64 with
        # four bytes of value 0x03.
        finished_rec_data = self.explicit_iv + unencrypted_finished_message + mac + b'\x03\x03\x03\x03'
        encrypted_finished = encryptor.update(finished_rec_data) + encryptor.finalize()

        # Earlier variant, kept for reference: encrypt only message + MAC with
        # computed padding and no leading random block.
        # plain_text = unencrypted_finished_message + mac
        # pad_len = 16 - (len(plain_text) % 16)
        # padded_data = plain_text + bytes([pad_len] * pad_len)
        # encrypted_finished = encryptor.update(padded_data) + encryptor.finalize()
        print("encrypted finished: ", encrypted_finished.hex())
        print("len(encrypted finished): ", len(encrypted_finished))
        return encrypted_finished

    # ==================================packet structure(auxiliary func)==================================
    def get_encrypted_pre_master(self):
        if not self.certificate or not self.pre_master_key:
            raise Exception("[!!]")

        ps = []
        while len(ps) != 205:
            new_byte = os.urandom(1)
            if new_byte == 0x00:
                continue
            ps.append(new_byte)

        ps = b"".join(ps)
        print("certificate: %s" % self.certificate.hex())
        print("client key exchange random: %s" % self.pre_master_key.hex())
        print("padding: %s" % ps.hex())

        return DtlsTunnel.encrypt_pms_with_known_padding(self.certificate.hex(), self.pre_master_key.hex(), ps.hex())

    @staticmethod
    def prf(secret, label, seed, length):
        """Compute the TLS 1.0 style pseudo-random function (P_MD5 XOR P_SHA1).

        Args:
            secret: Secret bytes; split into two halves S1 and S2.
            label: ASCII label, as ``str`` or bytes.
            seed: Seed bytes (``bytearray`` is accepted and converted).
            length: Number of output bytes.

        Returns:
            ``length`` bytes: P_MD5(S1, label + seed) XOR P_SHA1(S2, label + seed).
        """
        # Make sure label and seed are both plain bytes.
        if isinstance(label, str):
            label = label.encode()  # str -> bytes
        if isinstance(seed, bytearray):
            seed = bytes(seed)  # bytearray -> bytes

        # Both halves are ceil(len / 2) bytes long; for an odd-length secret
        # they share the middle byte. (The original gave S1 one byte too few
        # in that case; even-length secrets are split exactly as before.)
        half = (len(secret) + 1) // 2
        s1 = secret[:half]
        s2 = secret[len(secret) - half:]

        p_md5 = DtlsTunnel.p_hash(s1, label + seed, MD5, length)
        p_sha1 = DtlsTunnel.p_hash(s2, label + seed, SHA1, length)

        # XOR the two streams byte by byte.
        return bytes(a ^ b for a, b in zip(p_md5, p_sha1))

    @staticmethod
    def p_hash(secret, seed, hash_alg, length):
        """Expand ``secret`` and ``seed`` into ``length`` bytes with an HMAC chain.

        With A(0) = seed and A(i) = HMAC(secret, A(i-1)), the output is
        HMAC(secret, A(1) + seed) || HMAC(secret, A(2) + seed) || ...,
        truncated to ``length`` bytes.

        Args:
            secret: HMAC key.
            seed: Seed bytes.
            hash_alg: Hash algorithm class; it is instantiated for every HMAC.
            length: Number of output bytes.

        Returns:
            ``length`` bytes of output (empty when ``length`` is not positive).
        """
        chunks = []
        produced = 0
        a = seed
        while produced < length:
            # A(i) = HMAC(secret, A(i-1)); an HMAC object is single-use.
            chain_step = HMAC(secret, hash_alg())
            chain_step.update(a)
            a = chain_step.finalize()

            # Output block i = HMAC(secret, A(i) + seed).
            block_mac = HMAC(secret, hash_alg())
            block_mac.update(a + seed)
            block = block_mac.finalize()

            chunks.append(block)
            produced += len(block)
        return b"".join(chunks)[:length]

    @staticmethod
    def generate_dtls_random():
        # 前4字节:Unix时间戳(大端序)
        timestamp = int(time.time()).to_bytes(4, byteorder='big')
        # 后28字节:强随机数
        random_bytes = os.urandom(28)
        return timestamp + random_bytes

    @staticmethod
    def encrypt_pms_with_known_padding(cert_hex: str, pms_hex: str, ps_hex: str) -> bytes:
        """
        构造与真实设备一致的 EncryptedPreMasterSecret,支持任意 RSA 模数(如2048-bit)。
        参数:
            cert_hex: 证书 DER 格式的 hex 字符串
            pms_hex: PreMasterSecret 的 48 字节 hex 字符串
            ps_hex: PKCS#1 v1.5 填充部分 PS 的 hex 字符串(不含头部0x00 0x02 和中间0x00)

        返回:
            加密后的 EncryptedPreMasterSecret(二进制字节串)
        """
        # 转换为 bytes
        pms = bytes.fromhex(pms_hex)
        ps = bytes.fromhex(ps_hex)

        if len(pms) != 48:
            raise ValueError("PreMasterSecret must be exactly 48 bytes")

        # 加载证书
        cert_der = bytes.fromhex(cert_hex)
        cert = x509.load_der_x509_certificate(cert_der)
        pub_key = cert.public_key()
        pub_bytes = pub_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

        # 导入 RSA 公钥并获取 modulus size
        rsa_key = RSA.importKey(pub_bytes)
        k = rsa_key.size_in_bytes()  # 公钥模数长度(单位字节,如 256 对应 2048-bit)

        # 检查填充长度是否匹配规范:0x00 0x02 || PS || 0x00 || PMS
        expected_ps_len = k - 3 - len(pms)
        if len(ps) != expected_ps_len:
            raise ValueError(f"Padding length error: expected {expected_ps_len} bytes, got {len(ps)} bytes")

        # 构造填充块 EM
        em = b'\x00\x02' + ps + b'\x00' + pms
        assert len(em) == k

        # 执行裸 RSA 加密(无填充)
        m = bytes_to_long(em)
        c = pow(m, rsa_key.e, rsa_key.n)
        encrypted = long_to_bytes(c, k)

        return encrypted

    def add_sequence_number(self):
        tmp = int.from_bytes(self.sequence_number, byteorder='big')
        self.sequence_number = (tmp + 1).to_bytes(6, byteorder='big')


if __name__ == '__main__':
    # Developer helper: prints the '='-padded banner lines used as section
    # separators in this module's comments.
    banner_title = "packet structure(capwap control layer)"
    banner_width = 100
    print("# " + banner_title.center(banner_width, '='))
    print("# " + '=' * banner_width)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值