Machine Learning For Cyber Security 6th International Conference, ML4CS 2024, Hangzhou, China, December 27-29, 2024

The document is the proceedings of the 6th International Conference on Machine Learning for Cyber Security (ML4CS 2024) held in Hangzhou, China, from December 27 to 29, 2024. It includes 30 accepted papers from 111 submissions, focusing on AI-driven security research. The conference was organized by Zhejiang Sci-Tech University and supported by various key laboratories and institutions in the field.


Yang Xiang

Jian Shen (Eds.)


LNCS 15566

Machine Learning
for Cyber Security
6th International Conference, ML4CS 2024
Hangzhou, China, December 27–29, 2024
Proceedings
Lecture Notes in Computer Science 15566
Founding Editors
Gerhard Goos
Juris Hartmanis

Editorial Board Members


Elisa Bertino, Purdue University, West Lafayette, IN, USA
Wen Gao, Peking University, Beijing, China
Bernhard Steffen, TU Dortmund University, Dortmund, Germany
Moti Yung, Columbia University, New York, NY, USA
The series Lecture Notes in Computer Science (LNCS), including its subseries Lecture Notes in Artificial Intelligence (LNAI) and Lecture Notes in Bioinformatics (LNBI), has established itself as a medium for the publication of new developments in computer science and information technology research, teaching, and education.
LNCS enjoys close cooperation with the computer science R & D community; the series counts many renowned academics among its volume editors and paper authors, and collaborates with prestigious societies. Its mission is to serve this international community by providing an invaluable service, mainly focused on the publication of conference and workshop proceedings and postproceedings. LNCS commenced publication in 1973.
Yang Xiang · Jian Shen
Editors

Machine Learning
for Cyber Security
6th International Conference, ML4CS 2024
Hangzhou, China, December 27–29, 2024
Proceedings
Editors

Yang Xiang
Swinburne University of Technology
Melbourne, VIC, Australia

Jian Shen
Zhejiang Sci-Tech University
Hangzhou, Zhejiang, China

ISSN 0302-9743 ISSN 1611-3349 (electronic)


Lecture Notes in Computer Science
ISBN 978-981-96-4565-7 ISBN 978-981-96-4566-4 (eBook)
https://doi.org/10.1007/978-981-96-4566-4

© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2025

This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether
the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of
illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission
or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar
methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication
does not imply, even in the absence of a specific statement, that such names are exempt from the relevant
protective laws and regulations and therefore free for general use.
The publisher, the authors and the editors are safe to assume that the advice and information in this book
are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the
editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors
or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in
published maps and institutional affiliations.

This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore

If disposing of this product, please recycle the paper.


Preface

The Sixth International Conference on Machine Learning for Cyber Security (ML4CS 2024) was held in Hangzhou, China, during December 27–29, 2024. ML4CS is a well-recognized annual international forum for AI-driven security researchers to exchange ideas and present their work. This volume contains the papers presented at ML4CS 2024.
The conference received 111 submissions. The committee accepted 30 regular
papers, with each paper receiving at least 3 double-blind reviews. The proceedings
contain revised versions of the accepted papers. While revisions were expected to take
the referees’ comments into account, this was not enforced, and the authors bear full
responsibility for the content of their papers.
ML4CS 2024 was organized by the School of Information Science and Engineering (School of Cyber Science and Technology), Zhejiang Sci-Tech University. Furthermore, ML4CS 2024 was supported by Zhejiang Provincial Key Laboratory of Digital Fashion and Data Governance, Zhejiang Key Laboratory of Artificial Intelligence of Things (AIoT) Network and Data Security, and Zhejiang Provincial International Cooperation Base for Science and Technology on Cloud Computing Security and Data Aggregation. The conference would not have been such a success without the support of these organizations, and we sincerely thank them for their continued assistance and support.
We would also like to thank the authors who submitted their papers to ML4CS 2024, and the conference attendees for their interest and support. We thank the Organizing Committee for the time and effort they dedicated to arranging the conference; this allowed us to focus on paper selection and the scientific program. We thank the Program Committee members and the external reviewers for their hard work in reviewing the submissions; the conference would not have been possible without their expert reviews. Finally, we thank the EasyChair system and its operators for making the entire process of managing the conference convenient.

December 2024

Yang Xiang
Jian Shen
Organization

General Chairs

Yang Xiang Swinburne University of Technology, Australia


Jian Shen Zhejiang Sci-Tech University, China

Program Co-chairs

Dan Dongseong Kim University of Queensland, Australia


Chao Chen RMIT, Australia

Publication Co-chairs

Shigang Liu Swinburne University of Technology, Australia


Yuantian Miao University of Newcastle, Australia

Publicity Co-chairs

Jianfeng Wang Xidian University, China


Shifeng Sun Shanghai Jiao Tong University, China
Leo Yu Zhang Griffith University, Australia

Web Chair

Wenying Zheng Zhejiang Sci-Tech University, China

Steering Committee

Yang Xiang Swinburne University of Technology, Australia


Xiaofeng Chen Western Washington University, USA
Jonathan Oliver VMware, USA
Jin Li Guangzhou University, China
Xingliang Yuan Monash University, Australia

PC Members

Silvio Barra University of Salerno, Italy


M. Z. Alam Bhuiyan Guangzhou University, China
Carlo Blundo University of Salerno, Italy
Yiqiao Cai Huaqiao University, China
Luigi Catuogno University of Salerno, Italy
Liang Chang Guilin University of Electronic Technology, China
Fei Chen Shenzhen University, China
Xiaofeng Chen Xidian University, China
Zhe Chen Singapore Management University, Singapore
Frédéric Cuppens Polytechnique Montréal, Canada
Changyu Dong Newcastle University, UK
Guangjie Dong East China Jiaotong University, China
Mohammed El-Abd American University of Kuwait, Kuwait
Wei Gao Yunnan Normal University, China
Dieter Gollmann Hamburg University of Technology, Germany
Zheng Gong South China Normal University, China
Zhitao Guan North China Electric Power University, China
Saeid Hosseini Singapore University of Technology & Design,
Singapore
Chingfang Hsu Huazhong University of Science and Technology,
China
Haibo Hu Hong Kong Polytechnic University, China
Teng Huang Guangzhou University, China
Xinyi Huang Fujian Normal University, China
Wenchao Jiang Guangdong University of Technology, China
Lutful Karim Seneca College of Applied Arts and Technology,
Canada
Hadis Karimipour University of Guelph, Canada
Sokratis Katsikas Open University of Cyprus, Cyprus
Neeraj Kumar Thapar Institute of Engineering and Technology,
India
Kangshun Li South China Agricultural University, China
Ping Li South China Normal University, China
Tong Li Nankai University, China
Wei Li Jiangxi University of Science and Technology,
China
Xuejun Li Anhui University, China
Kaitai Liang TU Delft, The Netherlands

Hui Liu University of Calgary, Canada


Wei Lu Sun Yat-sen University, China
Xiaobo Ma Xi’an Jiaotong University, China
Fabio Martinelli IIT-CNR, Italy
Ficco Massimo University of Salerno, Italy
Weizhi Meng Technical University of Denmark, Denmark
Vincenzo Moscato University of Naples “Federico II”, Italy
Francesco Palmieri University of Salerno, Italy
Fei Peng Hunan University, China
Lizhi Peng Jinan University, China
Umberto Petrillo Sapienza University of Rome, Italy
Lianyong Qi Qufu Normal University, China
Shahryar Rahnamayan University of Ontario Institute of Technology,
Canada
Khaled Riad Guangzhou University, China
Haowen Tan Zhejiang Sci-Tech University, China
Yu-an Tan Beijing Institute of Technology, China
Zhiyuan Tan Edinburgh Napier University, UK
Ming Tao Dongguan University of Technology, China
Donghai Tian Beijing Institute of Technology, China
Chen Wang Zhejiang Sci-Tech University, China
Chundong Wang Tianjin University of Technology, China
Ding Wang Peking University, China
Feng Wang Wuhan University, China
Hui Wang Nanchang Institute of Technology, China
Jianfeng Wang Xidian University, China
Jin Wang Soochow University, China
Licheng Wang Beijing University of Posts and
Telecommunications, China
Lingyu Wang Concordia University, Canada
Tianyin Wang Luoyang Normal University, China
Wei Wang Beijing Jiaotong University, China
Wenle Wang Jiangxi Normal University, China
Xiaolong Xu Nanjing University of Information Science &
Technology, China
Li Yang Xidian University, China
Huijie Yang Zhejiang Sci-Tech University, China
ShaoJun Yang Fujian Normal University, China
Zhe Yang Northwestern Polytechnical University, China
Xu Yuan University of Louisiana at Lafayette, USA
Yanqing Yao Beihang University, China
Qikun Zhang Beijing Institute of Technology, China

Xiao Zhang Beihang University, China


Xiaosong Zhang Tangshan University, China
Xuyun Zhang Macquarie University, Australia
Yuan Zhang Nanjing University, China
Xianfeng Zhao Chinese Academy of Sciences, China
Derek Wang Data61, CSIRO, Australia
Leo Yu Zhang Griffith University, Australia
Tianqi Zhou Zhejiang Sci-Tech University, China
Shifeng Sun Shanghai Jiao Tong University, China
Shigang Liu Swinburne University of Technology, Australia
Yuantian Miao University of Newcastle, Australia
Yanjun Zhang University of Technology Sydney, Australia
Mengmeng Ge University of Canterbury, New Zealand
Jin Hong University of Western Australia, Australia
Nour Moustafa UNSW Canberra, Australia
Abigail Koay University of the Sunshine Coast, Australia
Contents

Secure Resource Allocation via Constrained Deep Reinforcement Learning . . . . 1
  Jianfei Sun, Qiang Gao, Cong Wu, Yuxian Li, Jiacheng Wang, and Dusit Niyato

Efficient Two-Party Privacy-Preserving Ridge and Lasso Regression via SMPC . . . . 16
  Zongxiang Yi, Bo Li, Wanhui Zhang, Zhiqiang Lin, and Lishan Ke

A Decentralized Bitcoin Mixing Scheme Based on Multi-signature . . . . 43
  Mingdi Shen, Tianqi Zhou, Chen Wang, and Shijia Hong

Decentralized Continuous Group Key Agreement for UAV Ad-Hoc Network . . . . 56
  Shijia Hong, Tianqi Zhou, Huijie Yang, Mingdi Shen, and Wenying Zheng

Efficient Homomorphic Approximation of Max Pooling for Privacy-Preserving Deep Learning . . . . 70
  Peng Zhang, Dongyan Qiu, Ao Duan, and Hongwei Liu

Blockchain-Aided Revocable Threshold Group Signature Scheme for the Smart Grid . . . . 81
  Xiaozhi Deng, Qinqin Wu, Yi Tang, and Yongbao Wang

Privacy-Preserving Three-Factors Authentication and Key Agreement for Federated Learning . . . . 90
  Guojun Wang, Guixin Jiang, Yushuai Zhao, and Yinglin Ji

Blockchain-Based Anonymous Authentication Scheme with Traceable Pseudonym Management in ITS . . . . 104
  Mingliang Wang, Haowen Tan, and Wenying Zheng

Multi-keyword Searchable Data Auditing for Cloud-Based Machine Learning . . . . 118
  Haiyan Yu, Qingru Ma, Yilu Zhu, and Yuxin Cui

A Flexible Keyword-Based PIR Scheme with Customizable Data Scales for Multi-server Learning . . . . 133
  Jingang Li, Zhaoyi Liu, Huijie Yang, Tianqi Zhou, and Wenying Zheng

Automatic Software Vulnerability Detection in Binary Code . . . . 148
  Shigang Liu, Lin Li, Xinbo Ban, Chao Chen, Jun Zhang, Seyit Camtepe, and Yang Xiang

Malicious Code Detection Based on Generative Adversarial Model . . . . 167
  Jinzhihao Zhang, Jia Yang, and Weiqi Zhou

Construction of an AI Code Defect Detection and Repair Dataset Based on Chain of Thought . . . . 184
  Huimin Gong, Zongliang Shen, Hua Zhang, Lei Qiao, Huawei Wang, and Chi Zhang

Backdoor Attack on Android Malware Classifiers Based on Genetic Algorithms . . . . 197
  Zhenghua Cai, Yongji Wang, Hua Zhang, Lei Qiao, Huawei Wang, and Chi Zhang

A Malicious Websites Classifier Based on an Improved Relation Network . . . . 215
  Qianshi Wang, Chongjun Xu, Huayu Yang, Xilin Zhai, and Hua Zhang

Unknown Category Malicious Traffic Detection Based on Contrastive Learning . . . . 231
  Leiming Yan, Tao Zhou, and Xianyi Chen

SoftPromptAttack: Advancing Backdoor Attacks in Language Models Through Prompt Learning Paradigms . . . . 246
  Dixuan Chen, Hongyang Yan, Jiatong Lin, Fan Chen, and Yu Cheng

Removing Regional Steering Vectors to Achieve Knowledge Domain Forgetting in Large Language Models . . . . 258
  Wei Wu, Chen Wang, Qiuhao Xu, and Wei Kong

A Novel and Efficient Multi-scale Spatio-Temporal Residual Network for Multi-class Intrusion Detection . . . . 271
  Nan Li, Zhaojian Gao, Jiabin Ye, Wei Tang, Xun Che, and Yadang Chen

Provable Data Auditing Scheme from Trusted Execution Environment . . . . 284
  Yuluo Zeng, Xinlei Sheng, Kai Zhao, and Suliu Yang

Enhanced PIR Scheme Combining SimplePIR and Spiral: Achieving Higher Throughput Without Client Hints . . . . 299
  Haoyao Xu, Yitong Li, and Haibo Tian

A Two-Stage Image Blind Inpainting Algorithm Based on Gated Residual Connection . . . . 316
  Yiting Pan and Xuefeng Zhang

GAN-based Adaptive Trigger Generation and Target Gradient Alignment in Vertical Federated Learning Backdoor Attacks . . . . 332
  Kun Li, Hongyang Yan, Jiatong Lin, Fan Chen, Yu Cheng, and Dongyang Liang

Weakly Supervised Waste Classification with Adaptive Loss and Enhanced Class Activation Maps . . . . 347
  Wenzhang Dai and Le Sun

A Vehicle Asynchronous Communication Scheme Based on Federated Deep Reinforcement Learning . . . . 362
  Jinming Huang, Jin Wang, Shengyang Gao, and Zilong Jin

A Vehicles Scheduling Algorithm Based on Clustering Based Federated Learning . . . . 377
  Xin Zhang, Chi Zhang, Shuyan Liu, and Zilong Jin

A Cooperative Caching Strategy Based on Deep Q-Network for Mobile Edge Networks . . . . 392
  Chun Yang, Guoqing Xu, Liang Ma, and Zilong Jin

YOLO-LiteMax: An Improved Model for UAV Small Object Detection . . . . 408
  Jian Su, Chang Yang, and Jian Zhang

LMCF-FS: A Novel Lightweight Malware Classification Framework Driven by Feature Selection . . . . 423
  Cui Yun, Lei Zhou, Shuangshuang Xing, Ning Yang, Pan Zhao, and Zhiguo Chen

Rule Learning-Based Target Prediction for Efficient and Flexible Private Information Retrieval . . . . 438
  Weizhe Tu, Zheng Dong, Tao Zhang, and Wenying Zheng

Author Index . . . . 449


Secure Resource Allocation via Constrained Deep Reinforcement Learning

Jianfei Sun1, Qiang Gao2, Cong Wu3(B), Yuxian Li1, Jiacheng Wang3, and Dusit Niyato3

1 School of Computing and Information Systems, Singapore Management University, Singapore 188065, Singapore
{jfsun,yuxianli}@smu.edu.sg
2 School of Computing and Artificial Intelligence, Southwestern University of Finance and Economics, Chengdu 611130, China
[email protected]
3 College of Computing and Data Science, Nanyang Technological University, Singapore 639798, Singapore
{cong.wu,jiacheng.wang,dniyato}@ntu.edu.sg

Abstract. The proliferation of Internet of Things (IoT) devices and the advent of 6G technologies have introduced computationally intensive tasks that often surpass the processing capabilities of user devices. Efficient and secure resource allocation in serverless multi-cloud edge computing environments is essential for supporting these demands and advancing distributed computing. However, existing solutions frequently struggle with the complexity of multi-cloud infrastructures, robust security integration, and effective application of traditional deep reinforcement learning (DRL) techniques under system constraints. To address these challenges, we present SARMTO, a novel framework that integrates an action-constrained DRL model. SARMTO dynamically balances resource allocation, task offloading, security, and performance by utilizing a Markov decision process formulation, an adaptive security mechanism, and sophisticated optimization techniques. Extensive simulations across varying scenarios, including different task loads, data sizes, and MEC capacities, show that SARMTO consistently outperforms five baseline approaches, achieving up to a 40% reduction in system costs and a 41.5% improvement in energy efficiency over state-of-the-art methods. These enhancements highlight SARMTO’s potential to revolutionize resource management in intricate distributed computing environments, opening the door to more efficient and secure IoT and edge computing applications.

Keywords: Distributed Computing · Deep Reinforcement Learning · Resource Management · Serverless Multi-cloud · Internet of Things

J. Sun and Q. Gao contributed equally to this work.


© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 1–15, 2025.
https://doi.org/10.1007/978-981-96-4566-4_1

1 Introduction
The rapid expansion of Internet of Things (IoT) devices and the advent of 6G technologies have led to a surge in computationally demanding applications, such as augmented reality, autonomous vehicles, and real-time data analytics, which often exceed the processing capabilities of user devices. To address this, serverless multi-cloud edge computing has emerged as a promising solution, offering flexible, scalable, and proximity-aware resources [1–3]. However, the heterogeneous and distributed nature of these environments, as well as user mobility, pose significant challenges to secure and efficient resource allocation [4–6].
Significant progress has been made in addressing various aspects of resource allocation challenges. Tang et al. [7] explored distributed task scheduling using dueling double deep Q-networks for optimizing resource allocation in edge computing. Yao et al. [8] developed an experience-sharing deep reinforcement learning (DRL) method for function offloading in serverless edge environments. In terms of security, Min et al. [9] introduced a privacy-aware offloading scheme using reinforcement learning for healthcare IoT devices, while Huang et al. [10] proposed a secure and energy-efficient offloading strategy for service workflows in mobile edge computing. Additionally, Ko et al. [11] and Cicconetti et al. [12] investigated the integration of serverless computing with edge environments to enhance scalability and resource utilization.
Despite these advancements, critical research gaps persist. First, most solutions target single-cloud edge models, overlooking the complexities of multi-cloud environments with heterogeneous nodes [13]. Second, while DRL shows promise in resource management, current methods often rely on penalty-based rewards, resulting in suboptimal policies and convergence issues [14]. Third, the integration of robust security with efficient resource allocation in serverless multi-cloud environments is largely unexplored, particularly in the context of dynamic threats and varying task sensitivities [15]. These gaps underscore the need for a comprehensive framework that securely and efficiently allocates resources in complex, heterogeneous serverless multi-cloud edge environments.
In this paper, we propose secure adaptive resource management and task optimization (SARMTO), a novel framework for secure and efficient resource allocation in serverless multi-cloud edge environments. At its core, SARMTO employs an innovative action-constrained deep reinforcement learning (AC-DRL) model that optimizes task offloading and resource allocation while respecting system constraints. Using a Markov decision process (MDP) formulation, it models the resource allocation problem as a sequential decision-making process, allowing it to adapt to changing conditions and balance objectives such as latency, energy efficiency, and security.
SARMTO addresses key challenges through several novel design elements. First, it introduces a flexible state representation and action space to manage the heterogeneity of computing nodes and the complexity of multi-cloud environments. Second, an action constraint mechanism enforces system limitations during decision-making, avoiding the drawbacks of penalty-based approaches. Third, an adaptive security mechanism dynamically adjusts protection levels based on task sensitivity and the threat landscape, ensuring robust security without compromising efficiency. Finally, advanced techniques including prioritized experience replay and a dueling network architecture are integrated to enhance learning efficiency in this complex decision space.
In summary, our contributions are as follows:

– We propose SARMTO, a novel framework for secure and efficient resource allocation in serverless multi-cloud edge environments.
– We introduce AC-DQN, a novel action-constrained DRL algorithm that respects system constraints and dynamically balances security with computational overhead.
– We evaluate SARMTO through extensive simulations, showing significant improvements in performance, energy efficiency, and security over existing methods.

2 Related Work

Recent advances in edge-cloud computing have led to significant progress in optimizing resource allocation. Tang et al. [7] developed a distributed task scheduling algorithm using dueling double deep Q-networks for heterogeneous edge networks, while Yao et al. [8] introduced a DRL-based function offloading method for serverless edge computing, achieving better latency and success rates. However, both approaches fall short in addressing critical security concerns and managing energy costs effectively.

DRL has emerged as a powerful tool for dynamic resource management in complex environments. Xu et al. [16] applied DRL with deep Q-networks to optimize service offloading in vehicular edge computing, and Chen et al. [17] extended this by improving resource allocation in cloud-edge networks. Despite these advancements, many DRL-based methods rely on penalty functions to enforce constraints, which can lead to suboptimal outcomes. To address this, our work introduces an action constraint mechanism that integrates constraints directly during decision-making. Furthermore, while serverless computing in edge environments has been explored for its scalability and efficiency, as demonstrated by Ko et al. [11] and Cicconetti et al. [12], the security challenges and complexities of multi-cloud scenarios remain underexplored. Our research focuses on bridging these gaps by developing a secure resource allocation framework for serverless multi-cloud edge environments, accounting for heterogeneous nodes and dynamic threats.

3 System Model and Problem Formulation

Fig. 1. Serverless multi-cloud edge computing model

As illustrated in Fig. 1, we consider a heterogeneous serverless multi-cloud edge computing network Γ = (U, E, C). The network comprises a set of user devices (UDs) U = {u_1, . . . , u_N}, multi-access edge computing (MEC) nodes E = {e_1, . . . , e_K}, and cloud computing (CC) nodes C = {c_1, . . . , c_M}. The network topology is represented by a weighted graph G(V, E), where V = U ∪ E ∪ C and E represents the set of communication links. Each link (v_1, v_2) ∈ E is characterized by its bandwidth B_{v_1 v_2} and channel gain g_{v_1 v_2}.
Task Model: Each UD u_i ∈ U generates a set of computational tasks T_i = {τ_i1, . . . , τ_iL}. A task τ_ij is defined by the tuple (D_ij, C_ij, T_ij), where D_ij ∈ R+ denotes the input data size in bits, C_ij ∈ Z+ represents the required CPU cycles for computation, and T_ij ∈ R+ is the delay constraint in seconds. The relationship between C_ij and D_ij is modeled as:

C_ij = ζ(D_ij),    (1)

where ζ(·) is a function that maps data size to required CPU cycles, based on the application type as shown in Table 1 [14].

Table 1. Application complexity.

Application                  Label  CPU cycles/byte
Gzip                         A      330
Health monitoring            B      500
pdf2text (N900 data sheet)   C      960
x264 CBR encode              D      1900
html2text                    E      5900
pdf2text (E72 data sheet)    F      8900
Augmented reality            G      12000
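For concreteness, the mapping ζ(·) in Eq. (1) can be sketched as a per-byte lookup over Table 1. This is our own minimal reading of the table (the dictionary keys and the function name `required_cpu_cycles` are illustrative, not from the paper):

```python
# Cycles-per-byte figures restated from Table 1.
CYCLES_PER_BYTE = {
    "gzip": 330,
    "health_monitoring": 500,
    "pdf2text_n900": 960,
    "x264_cbr": 1900,
    "html2text": 5900,
    "pdf2text_e72": 8900,
    "augmented_reality": 12000,
}

def required_cpu_cycles(data_size_bits: int, app: str) -> int:
    """Eq. (1): C_ij = zeta(D_ij), with zeta taken as a linear
    cycles-per-byte factor that depends on the application type."""
    data_size_bytes = data_size_bits / 8
    return int(data_size_bytes * CYCLES_PER_BYTE[app])
```

For example, a 1 kB gzip task would require 330,000 CPU cycles under this reading, while the same payload for augmented reality would require 12,000,000.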

Computation Resources: Computation nodes include MEC nodes and CC nodes. Each MEC node e_k ∈ E is characterized by its computational capacity f_m^k (in CPU cycles/second) and unit energy cost μ_m^k. Similarly, each CC node c_j ∈ C is defined by its computational capacity f_c^j and unit energy cost μ_c^j. These parameters capture the heterogeneity of the computing resources available in the network.
Our objective is to find an optimal task offloading and resource allocation policy π* that minimizes the overall system cost while satisfying security and delay constraints. We introduce a binary decision variable x_ijk ∈ {0, 1} indicating whether task τ_ij is executed on node k (either MEC or CC). The optimization problem is formulated as:

min_π C(π) = α_1 T(π) + α_2 E(π)
s.t.  Σ_{k ∈ E∪C} x_ijk = 1,  ∀i, j
      T_ij(π) ≤ T_ij,  ∀i, j                                  (2)
      Σ_{i,j} x_ijk C_ij ≤ f_k,  ∀k ∈ E ∪ C
      x_ijk ∈ {0, 1},  ∀i, j, k

where C(π) represents the total system cost under policy π, which is a weighted sum of the total delay T(π) and total energy consumption E(π). The weights α_1 and α_2 (α_1 + α_2 = 1) allow for flexible prioritization between delay and energy objectives. T_ij(π) denotes the actual execution time of task τ_ij under policy π, and f_k is the computational capacity of node k.
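The constraints of problem (2) can be checked mechanically for a candidate assignment. The sketch below is ours, not the paper's implementation; all function and parameter names are illustrative:

```python
def is_feasible(x, C, T_actual, T_max, f):
    """Check the constraints of problem (2).
    x[i][k]     -- binary: 1 if task i runs on node k
    C[i]        -- CPU cycles required by task i
    T_actual[i] -- execution time of task i under the assignment
    T_max[i]    -- delay constraint of task i
    f[k]        -- computational capacity of node k
    """
    n_tasks, n_nodes = len(x), len(f)
    # Each task is placed on exactly one node.
    if any(sum(x[i]) != 1 for i in range(n_tasks)):
        return False
    # Each task meets its delay constraint.
    if any(T_actual[i] > T_max[i] for i in range(n_tasks)):
        return False
    # No node's capacity is exceeded.
    for k in range(n_nodes):
        if sum(x[i][k] * C[i] for i in range(n_tasks)) > f[k]:
            return False
    return True

def system_cost(T_total, E_total, alpha1=0.5, alpha2=0.5):
    """Objective of (2): weighted sum of delay and energy, with alpha1 + alpha2 = 1."""
    return alpha1 * T_total + alpha2 * E_total
```

The point of such a check is that SARMTO's action constraint mechanism rejects infeasible assignments outright rather than penalizing them in the reward.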
Delay and Energy Model: The total delay T(π) is calculated as the sum of computation time, communication time (if offloaded), and security overhead for all tasks, which is given as:

T(π) = Σ_{i,j,k} x_ijk (T_ij^comp + T_ij^comm + T_ij^sec).    (3)

Similarly, the total energy consumption E(π) is given by:

E(π) = Σ_{i,j,k} x_ijk (E_ij^comp + E_ij^comm + E_ij^sec),    (4)

where E_ij^comp, E_ij^comm, and E_ij^sec represent computation energy, communication energy (if offloaded), and security-related energy consumption, respectively.
Security Overhead Model: The security overhead, i.e., due to cryptographic methods, is an essential aspect of our model, capturing the additional time and energy costs associated with ensuring data integrity and privacy, which are modeled as:

T_ij^sec = (φ_ij^enc + ω_ij)/f_k + (φ_ij^dec + ω_ij)/f_k',    (5)
E_ij^sec = μ_k (φ_ij^enc + ω_ij) + μ_k' (φ_ij^dec + ω_ij),    (6)

where φ_ij^enc, φ_ij^dec, and ω_ij represent the CPU cycles required for encryption, decryption, and integrity verification, respectively. The nodes k and k' denote the source and destination nodes for offloading.
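Given fixed per-task cycle counts, Eqs. (5) and (6) reduce to straightforward arithmetic. The helper names below are ours, as a hedged restatement of the two formulas:

```python
def security_delay(phi_enc, phi_dec, omega, f_src, f_dst):
    """Eq. (5): encryption plus integrity verification run at the source
    node k (capacity f_src); decryption plus integrity verification run
    at the destination node k' (capacity f_dst)."""
    return (phi_enc + omega) / f_src + (phi_dec + omega) / f_dst

def security_energy(phi_enc, phi_dec, omega, mu_src, mu_dst):
    """Eq. (6): the same cycle counts weighted by each node's
    unit energy cost mu."""
    return mu_src * (phi_enc + omega) + mu_dst * (phi_dec + omega)
```

Note that the integrity-verification cycles ω_ij are charged at both ends, since both source and destination verify the payload.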

4 Design of SARMTO
4.1 Overview
SARMTO is a cutting-edge framework designed to optimize secure resource allocation in serverless multi-cloud edge environments. It employs an action-constrained DRL model that integrates five components: an MDP model for sequential decision-making, an action-constrained deep Q-network (AC-DQN) to respect system constraints, a robust security mechanism for data protection, an adaptive exploration strategy for efficient policy learning, and advanced performance optimization techniques. This enables SARMTO to dynamically adapt to heterogeneous computing conditions, effectively balancing task offloading, resource allocation, security, and performance requirements in complex distributed computing systems.
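The section does not spell out how the AC-DQN enforces constraints at decision time; one common realization, offered here only as an illustration, is to mask the Q-values of infeasible actions before taking the argmax, so constrained actions are never selected (rather than penalized in the reward):

```python
import math

def constrained_greedy_action(q_values, feasible):
    """Pick the highest-Q action among feasible actions only.
    q_values -- list of Q(s, a) estimates, one per action
    feasible -- list of booleans; False marks an action that would
                violate a system constraint (capacity, deadline, ...)
    """
    masked = [q if ok else -math.inf for q, ok in zip(q_values, feasible)]
    best = max(range(len(masked)), key=lambda a: masked[a])
    if masked[best] == -math.inf:
        raise ValueError("no feasible action in this state")
    return best
```

Under this scheme the learned policy never has to trade constraint violations against reward, which is the drawback of penalty-based formulations noted in Sect. 2.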

4.2 MDP Formulation


The resource allocation problem in serverless multi-cloud edge environments is
inherently dynamic and sequential. By modeling it as an MDP, we can capture
the temporal dependencies and uncertainties inherent in the system, allowing for
more effective decision-making over time. This approach enables SARMTO to
consider the long-term impact of allocation decisions rather than just immediate
rewards, leading to more robust and efficient resource management.
An MDP is defined as a tuple M = {S, A, P, R, γ}, where S represents the state
space, A the action space, P the state transition probabilities, R the reward
function, and γ the discount factor. The state st ∈ S at time t is defined as
st = {Di,t , Ci,t , Ti,t | i = 1, . . . , N }, capturing the data size, CPU cycles, and
delay constraints for all tasks, thereby allowing the model to account for all
relevant factors in allocation decisions. The action at ∈ A is represented as
at = {xi,k,t | i = 1, . . . , N ; k ∈ E ∪ C}, specifying task allocation decisions to
either edge or cloud nodes. The reward rt is defined as the negative of the
system cost: rt = −Ct = −(α1 Tt + α2 Et ), enabling SARMTO to optimize both
performance and energy efficiency. The weights α1 and α2 provide the flexibility
to adjust the trade-off between these objectives.
Algorithm 1 outlines the MDP formulation process for SARMTO. It begins
by defining the MDP tuple and initializing its components: the state space,
action space, transition probabilities, reward function, and discount factor. For
each time step, the algorithm observes the current state, which includes the data
size, CPU cycles, and delay constraints of all tasks. Based on this state, an action
is chosen, representing the task allocation decisions for each task to either edge
or cloud nodes. After executing the action, the algorithm observes the new state
and calculates the reward, which is defined as the negative of the system cost,
balancing delay and energy consumption. Finally, the MDP model is updated
based on the observed transition and reward, allowing SARMTO to refine its
decision-making process over time. This iterative process enables SARMTO to
adapt to the dynamic nature of the serverless multi-cloud edge environment and
continuously improve its resource allocation strategy.
Secure Resource Allocation via Constrained Deep Reinforcement Learning 7

Algorithm 1. MDP Formulation for SARMTO


1: Define MDP tuple M = {S, A, P, R, γ}
2: Initialize state space S, action space A, transition probabilities P , reward function
R, discount factor γ
3: for each time step t do
4: Observe current state st = {Di,t , Ci,t , Ti,t |i = 1, ..., N }
5: Choose action at = {xi,k,t |i = 1, ..., N ; k ∈ E ∪ C}
6: Execute action at
7: Observe new state st+1 and reward rt = −(α1 Tt + α2 Et )
8: Update MDP model based on observed transition and reward
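The per-step logic of Algorithm 1 can be sketched as follows; `policy` and `env` are hypothetical stand-ins for the learned allocation policy and the simulated environment, and the equal weights (0.5, 0.5) mirror the evaluation setup used later:

```python
# Minimal sketch of one iteration of Algorithm 1 (illustrative names).
def reward(T_t, E_t, alpha1=0.5, alpha2=0.5):
    # r_t = -C_t = -(alpha1 * T_t + alpha2 * E_t)
    return -(alpha1 * T_t + alpha2 * E_t)

def mdp_step(state, policy, env):
    # Observe s_t, choose a_t, execute it, observe s_{t+1} and r_t.
    action = policy(state)
    next_state, T_t, E_t = env(state, action)
    return (state, action, reward(T_t, E_t), next_state)
```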

4.3 AC-DQN
AC-DQN algorithm extends traditional deep Q-learning by incorporating an
action constraint mechanism, which is crucial for enforcing system limitations
in complex serverless environments. This mechanism is implemented through an
action constraint function fconstraint : A → R, which maps actions to penalty
values: 
−λ, if Ttotal (aj ) > Tmax ,
fconstraint (aj ) = (7)
0, otherwise,
where λ is a large positive constant (e.g., 1000), Ttotal (aj ) computes the total
delay for action aj , and Tmax is the maximum allowable delay. This function
effectively creates a discontinuity in the action-value space, steering the learning
process away from infeasible actions.
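Eq. (7) is a direct conditional; a transcription with illustrative names:

```python
# Direct transcription of Eq. (7); names are illustrative.
LAM = 1000.0  # large penalty constant lambda (the paper suggests e.g. 1000)

def f_constraint(total_delay_a, t_max, lam=LAM):
    # Penalise any action whose total delay exceeds T_max.
    return -lam if total_delay_a > t_max else 0.0
```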
The Q-network Q(s, a; θ) is parameterized by θ and architecturally consists
of an input layer Rd → Rn1 , where d is the state dimension, followed by two
hidden layers Rn1 → Rn2 → Rn3 , and an output layer Rn3 → R|A| , where |A| is
the cardinality of the action space. The network is optimized by minimizing the
loss function:
L(θ) = E_{(s,a,r,s')∼D} [(y − Q(s, a; θ))^2], (8)

where D is the experience replay buffer and y = r + γ max_{a'} (Q(s', a'; θ^-) + f_constraint(a')) is the target Q-value. Here, θ^- represents the parameters of a target network, which is periodically updated to stabilize training.
The inclusion of fconstraint (a ) in the target Q-value calculation is a key
innovation of AC-DQN. It allows the constraint information to be propagated
through the temporal difference learning process, effectively shaping the Q-
function landscape to inherently avoid infeasible actions. This approach con-
trasts with methods that apply constraints only at the action selection stage, as
it embeds the constraints into the learned value function itself.
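The constrained target computation can be sketched as follows (a minimal version, assuming the target-network Q-values and the per-action constraint penalties have already been evaluated):

```python
# Sketch of y = r + gamma * max_a' (Q(s', a'; theta-) + f_constraint(a')).
# `q_next` and `penalties` are illustrative per-action inputs.
def td_target(r, gamma, q_next, penalties, terminal=False):
    if terminal:
        return r
    return r + gamma * max(q + p for q, p in zip(q_next, penalties))
```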
Algorithm 2 outlines the AC-DQN training process. The algorithm interleaves interaction with the environment, storage of experiences, and neural network updates. The action selection step (lines 6–9) incorporates both the learned Q-values and the constraint function, ensuring that the agent respects system limitations even during exploration. The use of a separate target network in the target computation (line 17) and the periodic update of its parameters (lines 20–21) are standard techniques in deep Q-learning to improve stability.

Algorithm 2. AC-DQN Training Process


1: Initialize Q(s, a; θ) and Q(s, a; θ− ) with θ, θ− ∼ N (0, σ 2 )
2: Initialize replay memory D ← {} with capacity N
3: for episode e = 1 to M do
4: Initialize state s1 ∈ S
5: for t = 1 to T do
6: if Uniform(0, 1) < 1 −  then
7: at ← arg maxa∈A (Q(st , a; θ) + fconst (a))
8: else
9: at ← Uniform(A)
10: Execute at , observe rt ∈ R and st+1 ∈ S
11: D ← D ∪ {(st , at , rt , st+1 )}
12: Sample minibatch B ∼ Uniform(D) where |B| = B
13: for (sj , aj , rj , sj+1 ) ∈ B do
14: if sj+1 is terminal then
15: yj ← rj
16: else
17: yj ← rj + γ maxa ∈A (Q(sj+1 , a ; θ− ) + fconst (a ))
18: θ ← θ − α∇_θ (1/B) Σ_{j=1}^{B} (y_j − Q(s_j, a_j; θ))^2
19: st ← st+1
20: if e mod τ = 0 then
21: θ− ← θ

4.4 Security Mechanism Integration

To ensure data integrity and privacy in serverless multi-cloud edge environments, we devise a robust security mechanism that leverages asymmetric cryptography and cryptographic hash functions, optimizing the balance between security and computational overhead. The security protocol, defined as Π = (KeyGen,
Enc, Dec, Hash, Verify), includes RSA key generation, encryption, decryption, MD5 hashing, and verification functions. For a task τ_{ij} offloaded from MEC node k to CC node k', the security process is as follows: (1) KeyGen(1^λ) → (pk_{k'}, sk_{k'}); (2) c_{ij} = Enc(m_{ij}, pk_{k'}); (3) h_{ij} = Hash(m_{ij}); (4) Transmit (c_{ij}, h_{ij}) from node k to k'; (5) m'_{ij} = Dec(c_{ij}, sk_{k'}); (6) h'_{ij} = Hash(m'_{ij}); and (7) v_{ij} = Verify(h_{ij}, h'_{ij}). Here, m_{ij} represents the original task data, c_{ij} is the encrypted data, h_{ij} and h'_{ij} are hash values, and v_{ij} is the verification result.
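The seven-step flow can be sketched end to end as below. This is an illustrative sketch only: the RSA KeyGen/Enc/Dec functions are abstracted by a toy XOR cipher, and only the MD5 hashing step matches the paper's protocol literally.

```python
# Illustrative sketch of the encrypt -> hash -> transmit -> decrypt -> verify
# flow. RSA is replaced by a toy XOR cipher; do NOT use this for real security.
import hashlib

def keygen(key_byte=0x5A):            # stand-in for RSA KeyGen
    return key_byte, key_byte         # (pk, sk) of the destination node

def enc(m, pk):                       # stand-in for RSA Enc
    return bytes(b ^ pk for b in m)

def dec(c, sk):                       # stand-in for RSA Dec
    return bytes(b ^ sk for b in c)

def md5(m):
    return hashlib.md5(m).hexdigest()

def offload(m):
    pk, sk = keygen()                 # (1) key generation at destination
    c, h = enc(m, pk), md5(m)         # (2) encrypt, (3) hash at source
    # (4) transmit (c, h) from node k to k' ...
    m2 = dec(c, sk)                   # (5) decrypt at destination
    return m2, md5(m2) == h           # (6)-(7) re-hash and verify
```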
We model the security-related overhead in both time and energy domains. Let φ_{ij}^{enc}, φ_{ij}^{dec}, and ω_{ij} represent the number of CPU cycles required for encryption, decryption, and integrity verification (hashing and verification), respectively. The security time overhead T_{ij}^{sec} and energy overhead E_{ij}^{sec} for task τ_{ij} are formulated as:

T_{ij}^{sec} = (φ_{ij}^{enc} + ω_{ij})/f_k + (φ_{ij}^{dec} + ω_{ij})/f_{k'}, (9)

E_{ij}^{sec} = μ_k (φ_{ij}^{enc} + ω_{ij}) + μ_{k'} (φ_{ij}^{dec} + ω_{ij}), (10)
where f_k and f_{k'} are the CPU frequencies of nodes k and k' respectively, and μ_k and μ_{k'} are their respective energy consumption coefficients per CPU cycle.
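In code, Eqs. (9)-(10) reduce to simple arithmetic over the cycle counts; a sketch with illustrative parameter names (`f_k2` and `mu_k2` stand for f_{k'} and μ_{k'}):

```python
# Direct transcription of Eqs. (9)-(10); parameter names are illustrative.
def security_overhead(phi_enc, phi_dec, omega, f_k, f_k2, mu_k, mu_k2):
    # Time: encrypt+verify cycles on source node k at frequency f_k, plus
    # decrypt+verify cycles on destination node k' at frequency f_k'.
    t_sec = (phi_enc + omega) / f_k + (phi_dec + omega) / f_k2
    # Energy: cycles weighted by each node's per-cycle energy coefficient.
    e_sec = mu_k * (phi_enc + omega) + mu_k2 * (phi_dec + omega)
    return t_sec, e_sec
```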
The security overhead is incorporated into the overall system cost function,
allowing SARMTO to make informed decisions that balance security requirements with performance and energy efficiency:

C_t = α_1 (T_t + Σ_{i,j} T_{ij}^{sec}) + α_2 (E_t + Σ_{i,j} E_{ij}^{sec}). (11)

4.5 Adaptive Exploration Strategy

In reinforcement learning, the exploration-exploitation dilemma is a fundamental challenge. We address this through an adaptive ε-greedy exploration strategy, which dynamically adjusts the exploration rate based on the learning progress. Let π_ε(a|s) denote our ε-greedy policy, defined as:

π_ε(a|s) = 1 − ε + ε/|A| if a = argmax_{a'} Q(s, a'), and ε/|A| otherwise, (12)

where |A| is the cardinality of the action space. The exploration rate is adapted over time according to:

ε(t) = max(ε_min, ε_0 · ε_decay^t), (13)

where ε_0 is the initial exploration rate, ε_decay ∈ (0, 1) is the decay factor, and ε_min is the minimum exploration rate. This formulation ensures a gradual transition from exploration to exploitation while maintaining a baseline level of exploration throughout the learning process.
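Eqs. (12)-(13) can be sketched as follows; the default values for ε_0, ε_decay, and ε_min are illustrative, not taken from the paper:

```python
import random

def epsilon(t, eps0=1.0, decay=0.995, eps_min=0.05):
    # Eq. (13): eps(t) = max(eps_min, eps0 * decay**t)
    return max(eps_min, eps0 * decay ** t)

def select_action(q_values, eps):
    # Eq. (12): the greedy action is drawn with prob. 1 - eps here, and can
    # also be hit by the uniform branch, giving 1 - eps + eps/|A| in total.
    if random.random() < 1 - eps:
        return max(range(len(q_values)), key=q_values.__getitem__)
    return random.randrange(len(q_values))
```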

4.6 Performance Optimization Techniques

To enhance the stability and efficiency of learning in complex serverless multi-cloud edge environments, we incorporate several optimizations.
Experience Replay: We maintain a replay buffer D = {e_1, e_2, ..., e_N}, where each experience e_i = (s_i, a_i, r_i, s'_i) is a tuple of state, action, reward, and next state. During training, we sample mini-batches B ⊂ D uniformly: B ∼ Uniform(D, n), where n is the batch size. This technique breaks the temporal correlations in the data, reducing the variance of updates and improving learning stability.
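A minimal uniform replay buffer matching this description might look like the following (an illustrative sketch, not the paper's implementation):

```python
import random
from collections import deque

class ReplayBuffer:
    """Uniform experience replay; capacity-bounded FIFO (illustrative)."""

    def __init__(self, capacity):
        self.buf = deque(maxlen=capacity)  # oldest experiences evicted first

    def push(self, s, a, r, s_next):
        self.buf.append((s, a, r, s_next))

    def sample(self, n):
        # B ~ Uniform(D, n): uniform mini-batch without replacement
        return random.sample(self.buf, n)
```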

Fig. 2. Performance under different MEC computational capacities, including system cost (a), task completion time (b), energy efficiency (c), and offloading rate (d)

Target Network: We employ a target network Q(s, a; θ^-) alongside the primary Q-network Q(s, a; θ). The target network parameters θ^- are updated periodically using a soft update mechanism:

θ− ← τ θ + (1 − τ )θ− , (14)

where τ ∈ (0, 1] is the update rate. This approach stabilizes the learning targets
and mitigates the risk of divergence.
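Eq. (14) applied elementwise over the parameters can be sketched as (flat lists stand in for the network's weight tensors):

```python
# Eq. (14): theta- <- tau * theta + (1 - tau) * theta-, elementwise.
# Flat lists of weights stand in for tensors in this illustrative sketch.
def soft_update(theta_target, theta, tau=0.01):
    return [tau * w + (1 - tau) * w_t for w, w_t in zip(theta, theta_target)]
```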
Prioritized Experience Replay: We extend the experience replay mechanism by assigning priorities to experiences based on their temporal-difference (TD) error. The probability of sampling an experience e_i is given as P(i) = p_i^α / Σ_k p_k^α, where p_i = |δ_i| + ε is the priority of experience i, δ_i is the TD-error, ε is a small positive constant to ensure non-zero sampling probabilities, and α ∈ [0, 1] determines the degree of prioritization.
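The sampling distribution reduces to a normalized power of the priorities; a sketch with illustrative defaults:

```python
# P(i) = p_i**alpha / sum_k p_k**alpha, with p_i = |delta_i| + eps.
# Default alpha/eps values are illustrative.
def sampling_probs(td_errors, alpha=0.6, eps=1e-6):
    p = [(abs(d) + eps) ** alpha for d in td_errors]
    z = sum(p)
    return [pi / z for pi in p]
```

Note that alpha = 0 recovers uniform sampling, while alpha = 1 samples in direct proportion to the TD-error magnitudes.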
Dueling Network Architecture: We decompose the Q-function into separate value and advantage streams:

Q(s, a; θ, α, β) = V(s; θ, β) + (A(s, a; θ, α) − (1/|A|) Σ_{a'∈A} A(s, a'; θ, α)), (15)

where V(s; θ, β) is the state-value function and A(s, a; θ, α) is the advantage function. This architecture allows for more efficient learning of state values and reduces the overestimation of action values.
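The combination step of Eq. (15) is a simple mean-centering of the advantages around the state value; an illustrative sketch:

```python
# Eq. (15): Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a')).
def dueling_q(v, advantages):
    mean_a = sum(advantages) / len(advantages)
    return [v + a - mean_a for a in advantages]
```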

5 Experimental Results
5.1 Evaluation Setup

We conducted extensive simulations to evaluate the performance of the proposed SARMTO framework, using Python 3.9 for the simulation environment. To
provide a comprehensive comparison, we implemented five baseline approaches:
MEC Local Computation (LC), where all tasks are processed locally at the
MEC; CC Layer Computing Offload (CO), which randomly offloads tasks to the
CC layer; Random Computation (RC), which randomly offloads tasks between
MEC and CC; a standard Deep Q-Network (DQN) for resource allocation; and
an enhanced DQN with Back Propagation Neural Network (DQN+NN). The
simulation setup included key parameters such as a system bandwidth of 20 MHz,
MEC transmission power of 0.5 W, a channel loss factor of 4, MEC local noise
of 10−13 , MEC computation capacity of 10 GHz, and CC computation capacity
of 100 GHz. Task data sizes ranged from 1 to 5 GB, with delay requirements
between 700 and 800 ms. We assigned equal weights (0.5, 0.5) to delay and energy
consumption in our cost function. The network topology included one MEC node
and two CC nodes, positioned 1 km and 10 km from the MEC, respectively. To
ensure robust results, we ran the simulation for 1000 episodes, with each episode
consisting of 100 time slots.

5.2 Overall Performance

We first evaluate the overall performance of SARMTO as the number of computational tasks increases from 200 to 1000. Figure 3(a) illustrates the system
cost as the number of tasks increases. SARMTO consistently achieves the low-
est system cost across all task volumes. At 1000 tasks, SARMTO reduces the
system cost by 23.6% compared to DQN, 20.8% compared to DQN+NN, 30.0%
compared to LC, 40.0% compared to CO, and 38.2% compared to RC. This sig-
nificant improvement demonstrates SARMTO’s ability to make efficient allo-
cation decisions that balance both delay and energy consumption, even as the
system load increases. As shown in Fig. 3(b), it maintains a competitive average
delay performance, despite not always achieving the lowest delay. At 1000 tasks,
SARMTO’s average delay is only 14.3% higher than the best-performing CO
method. This slight increase in delay is a trade-off for significant improvements
in energy efficiency and overall system cost.

5.3 Impact of Average Data Size

We evaluate the impact of varying average task data sizes, ranging from 1 GB to
5 GB. As shown in Fig. 4(a), SARMTO consistently achieves the lowest system
cost across all data sizes. The performance gap widens as data size increases, with
SARMTO demonstrating approximately 26.3% lower system cost compared to
DQN at the 5 GB data size, highlighting its ability to efficiently manage larger
tasks through intelligent offloading decisions and security-aware resource allo-
cation. Interestingly, the DQN+NN approach shows improved performance for
medium-sized data (2–3 GB), slightly narrowing the gap with SARMTO; how-
ever, this advantage diminishes as data sizes increase.

Fig. 3. Performance under different number of tasks, including system cost (a) and
average delay (b)

Fig. 4. Performance under different data sizes, including system cost (a) and energy
consumption (b)

5.4 Energy Consumption Analysis


We further analyzed the energy consumption of different approaches as data
size increased, as shown in Fig. 4(b). SARMTO consistently demonstrated the
lowest energy consumption across all approaches, except for LC, which avoids
communication overhead but incurs higher system costs. At 5 GB, SARMTO
consumed 24.1% less energy than DQN, 22.3% less than DQN+NN, 41.5% less
than CO, and 38.7% less than RC. The energy consumption of CO and RC
exhibited greater variability, likely due to their random offloading decisions, par-
ticularly as data size increased. This underscores the importance of intelligent
resource allocation, where SARMTO excels by effectively balancing energy effi-
ciency with overall system performance, especially for larger tasks.

5.5 Impact of MEC Computational Resources


To evaluate SARMTO’s adaptability and efficiency under varying edge comput-
ing capabilities, we conducted experiments by adjusting the MEC computational
capacity from 5 GHz to 25 GHz. We maintained a constant workload of 1000
tasks, each with an average data size of 3 GB, across all MEC capacities. The
measured metrics included system cost, task completion time, energy efficiency,
and offloading rate.
As illustrated in Figure 2, SARMTO consistently outperforms other methods
across all MEC capacities. It achieves up to 34.7% lower system costs at lower
capacities, with a still significant advantage at higher capacities. Task completion
times are competitive, with SARMTO being only 5.8% slower than the best-
performing method at 5 GHz, narrowing to 2.1% at 25 GHz. In terms of energy
efficiency, SARMTO completes up to 27.3% more tasks per unit of energy at the
highest MEC capacity. Its adaptive offloading strategy is particularly effective,
offloading more tasks to the cloud at lower MEC capacities (78.3% at 5 GHz) and
decreasing to 41.2% at 25 GHz. This demonstrates SARMTO’s ability to make
context-aware decisions, utilizing available edge resources while balancing per-
formance and energy efficiency. These results highlight SARMTO’s robustness
and scalability across diverse edge computing scenarios, making it well-suited
for dynamic serverless multi-cloud environments where computational resources
vary across nodes.

6 Conclusion
This paper has introduced SARMTO, a framework for secure resource allocation
in serverless multi-cloud edge environments, leveraging action-constrained deep
reinforcement learning. SARMTO effectively balances task offloading, resource
allocation, and security requirements through its innovative algorithm, adap-
tive security mechanisms, and performance optimization techniques. Extensive
simulations demonstrated its superior performance, consistently outperforming
existing methods in system cost, energy efficiency, and adaptability across various scenarios. As edge computing evolves, SARMTO offers a promising direction for future research in serverless multi-cloud edge computing, potentially enabling more efficient and secure distributed systems at scale.

Acknowledgments. This work is supported by the Natural Science Foundation of Sichuan Province, China (Grant No. 2023NSFSC1411), and the Sichuan Science and
Technology Program, China (Grant No. 2023ZYD0145). Also, this research is sup-
ported by the National Research Foundation, Singapore, and Infocomm Media Devel-
opment Authority under its Future Communications Research & Development Pro-
gramme, Defence Science Organisation (DSO) National Laboratories under the AI
Singapore Programme (FCP-NTU-RG-2022-010 and FCP-ASTAR-TG-2022-003), Sin-
gapore Ministry of Education (MOE) Tier 1 (RG87/22), the NTU Centre for Compu-
tational Technologies in Finance (NTU-CCTF), and Seitee Pte Ltd.

Efficient Two-Party Privacy-Preserving
Ridge and Lasso Regression via SMPC

Zongxiang Yi1 , Bo Li2 , Wanhui Zhang3(B) , Zhiqiang Lin4 , and Lishan Ke4
1
School of Mathematics and Systems Science, Guangdong Polytechnic Normal
University, Guangzhou 510665, China
2
Guangzhou Construction Co., Ltd., Guangzhou 510030, China
3
Guangzhou Installation Group Co., Ltd., Guangzhou 510030, China
[email protected]
4
School of Mathematics and Information Science, Guangzhou University,
Guangzhou 510006, China
{linzhiqiang,kelishan}@gzhu.edu.cn

Abstract. Regularized regression techniques, like Ridge and Lasso regression, are extensively employed in machine learning for precise data analysis and forecasting. In the big data landscape, privacy concerns are more critical than ever, especially regarding the potential
cerns are more critical than ever, especially regarding the potential
for data breaches. Secure multi-party computation (SMPC) presents a
viable solution, facilitating collaborative machine learning projects with-
out compromising the security of private data.
This paper proposes two privacy-preserving algorithms tailored for
regularized regression in a two-party context: Ridge and Lasso regression.
The outlined approach for Ridge regression utilizes secret sharing pro-
tocol, multiplication triple protocol, and secure computation primitives.
For Lasso regression, the proposed algorithm integrates an additional
plain proximal gradient descent technique, significantly reducing compu-
tational and communication overhead. Both algorithms are proven to be
UC-secure against semi-honest adversaries. Experimental results under-
line their effectiveness, demonstrating that they can achieve an accu-
racy comparable to that of plaintext models. Compared to other similar
algorithms, our algorithms have significant advantages in terms of both
running time and communication overhead.

Keywords: Privacy-preserving · Regularization · Ridge regression ·


Lasso regression

1 Introduction
With the rapid development of artificial intelligence, machine learning, as a key
enabling technology, has been widely applied in various fields such as medi-
cal diagnosis, finance, transportation, and e-commerce. By training models to
extract patterns from data, computer systems can perform a range of intelli-
gent tasks, including image recognition, speech recognition, and natural language
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 16–42, 2025.
https://doi.org/10.1007/978-981-96-4566-4_2
processing. In this context, the extensive use of large datasets for model train-
ing to uncover valuable information has become an important resource driving
advancements in information technology. However, the widespread collection and
use of data have also brought privacy issues to the forefront. Machine learning
models typically require access to vast amounts of user data during the train-
ing process, which undoubtedly increases the risk of privacy breaches. The key
to resolving this contradiction lies in adopting appropriate privacy protection
technologies and strategies that effectively utilize decentralized data resources
while ensuring data privacy and security. Therefore, researching how to protect
data privacy while maintaining model performance has become a pressing issue.
Privacy-preserving machine learning employs various technical methods, such as
differential privacy, homomorphic encryption, and secure multi-party computa-
tion, to ensure the effectiveness and accuracy of machine learning models while
safeguarding user data privacy.
Machine learning can encounter overfitting issues when processing data that
contains noise, has imbalanced sample categories, or is high-dimensional, which
severely impacts the model’s generalization ability. The introduction of L2 regu-
larization and L1 regularization has significantly enhanced the model’s general-
ization capability. L2 regularized linear regression, also known as Ridge regres-
sion, was first proposed by statisticians Arthur E. Hoerl and Robert W. Ken-
nard as a method to address multicollinearity issues by adding the square of
the weights to the loss function to prevent model overfitting [25]. L1 regular-
ized linear regression, known as Lasso regression, was introduced by statisti-
cian Robert Tibshirani in a 1996 paper [39]. Lasso regression performs fea-
ture selection by incorporating the absolute sum of the weights into the loss
function, thereby improving the model’s sparsity and interpretability. Addition-
ally, Privacy-Preserving Logistic Regression (PPLR), as a significant direction
in privacy-preserving machine learning, combines L2 regularized linear regres-
sion and L1 regularized logistic regression with privacy protection technologies.
This approach can enhance the model’s generalization ability while effectively
preventing the leakage of private data from participating parties, thereby train-
ing superior predictive models from multi-party data and endowing the research
on privacy protection in regularized logistic regression algorithms with practical
significance.
The current mainstream privacy-preserving technologies include differential
privacy [3, 17, 18, 33, 43, 46], homomorphic encryption [9, 35], and secure multi-
party computation [6, 8, 11, 14, 15, 20, 30, 31, 37, 42, 45]. However, differential pri-
vacy may not achieve the same accuracy as models using plaintext, and homo-
morphic encryption involves excessive computational overhead. Therefore, we
focus on privacy computing solutions based on Secure Multi-Party Computation
(SMPC). The landscape of SMPC has seen considerable advancements, partic-
ularly in optimizing communication overhead and computational efficiency for
privacy-preserving regression analysis and machine learning. Several key stud-
ies have contributed to this evolving field, each addressing different aspects of
SMPC and enhancing the practicality of secure computations.
Bogdanov et al. attempted to reduce communication complexity by integrating the Sharemind framework with a one-round multiplication protocol [5, 6].
This combination represents an early effort to optimize secure computations by
streamlining one of the most communication-intensive operations-multiplication.
On a parallel front, Giacomelli et al. introduced a method for privacy-preserving
Ridge regression using linearly-homomorphic encryption, which significantly low-
ers both computational and communication overhead compared to fully homo-
morphic encryption. This approach not only improves data security but also
marks a key advancement in making secure data analysis more computationally
feasible [20].
Moving to more complex algorithms, Gascón et al. based their linear regres-
sion privacy-preserving scheme on Yao’s Garbled Circuits and inner product
protocols. While this method shows promise, the inherent complexity and limita-
tions in handling large-scale computational gates within Garbled Circuits suggest
that further developments are necessary to enhance its scalability and efficiency
[19]. In an effort to mitigate the high complexity associated with Garbled Cir-
cuits, Rouhani et al. proposed a secure deep learning framework for cloud servers
that utilizes low-overhead preprocessing techniques. This innovation effectively
reduces the total time required by Garbled Circuit protocols, making the app-
roach more practical for real-world applications [34].
Further innovations in SMPC frameworks have come from the ABY model,
proposed by Demmler et al., which integrates secret sharing and Garbled Cir-
cuits. This model facilitates secure two-party computation with secret multipli-
cation, balancing the need for security with communication efficiency [15]. Build-
ing on this, Mohassel and Zhang developed the SecureML scheme, which blends
homomorphic encryption, Garbled Circuits, and oblivious transfer to enable the
secure use of private data in machine learning algorithms like linear regression,
logistic regression, and neural networks. This approach is particularly notable
for its ability to operate under the assumption of semi-honest servers without
collusion, further expanding the applicability of SMPC in secure computations
[31].
To enhance the precision of training data in SMPC, Wu et al. implemented
a quantization method that converts data into fixed-point numbers, thereby
improving the accuracy of computations without significantly increasing com-
munication costs [45]. Meanwhile, Agrawal et al. broke new ground by pioneer-
ing the use of SMPC to train convolutional and residual layers within two-party
deep neural networks. This development represents a significant leap forward in
applying SMPC to more complex machine learning models [1].
The evolution of SMPC frameworks continued with Patra et al.’s develop-
ment of ABY2.0, which extends the original ABY model by adding a machine
learning module. This enhancement allows for the application of various secret-
sharing techniques and supports rapid conversion between them in a semi-honest
setting, thus improving both the flexibility and efficiency of secure computations
[32]. In exploring multi-party computation models, Mohassel and Rindal and
Wagh et al. focused on three-party server models that can tolerate up to one
malicious participant. Their research demonstrates that models like ABY3.0 and FALCON can maintain security integrity, even in more complex computational environments [30, 42]. While four-party server implementations have also been investigated for their efficiency, they require a minimum of three or four participants to function effectively. This requirement presents a limitation in comparison to the more commonly used two-party solutions [8, 11].
Lastly, in the realm of privacy-preserving machine learning on vertically par-
titioned data, Veugen et al. utilized the MPyC model based on Shamir’s secret
sharing. This model successfully implemented Ridge regression, Lasso regression,
and support vector machine learning schemes, thereby filling a critical gap in the
application of SMPC to Lasso regression [41].
In summary, the field of SMPC continues to advance through a combination
of innovative cryptographic techniques and efficient computation models. These
developments are progressively reducing communication overhead and expanding
the feasibility of secure data analysis in various machine learning contexts.
Our Contributions. We focus on the two-party setting and aim to reduce the
communication overhead in Lasso regression as discussed in [40] and the com-
putational complexity in Ridge regression as highlighted in [20]. Our approach
adopts the framework from [31], utilizing oblivious transfer for multiplication
triplet generation [4]. Our contributions are twofold:

1) We propose the selection of the regularization parameter using synthetic data rather than private data for both Ridge and Lasso regressions. This approach
enhances the accuracy of the regressions without compromising the privacy
of the underlying data.
2) For Lasso regression, we suggest computing the intricate aspects of the prox-
imal gradient descent algorithm in plain form. This significantly reduces
the communication and computational burden traditionally associated with
secure comparisons.

Organization. Section 2 presents some basic concepts and introduces the
SMPC primitives used in our schemes. Section 3 proposes a regularization
parameter selection algorithm applicable to both the Ridge regression and Lasso
regression schemes. Finally, the conclusion is provided in Sect. 4.

2 Preliminaries
2.1 Notations
The notations used in the paper are listed as follows:
1. N: The set of all natural numbers.
2. Partyi : The participants of the schemes, i = 0, 1.
3. Zn : The residue class ring {0, 1, . . . , n − 1}.
4. L: The bit length of the basic data type.1
1
It is not hard to transform the values of the original data into Z2L .

5. PRNG(s, i): The i-th output of the pseudo random number generator seeded
by s.
6. ⊥: The empty output or input for algorithms.
7. (output0 ; output1 ) = A(input0 , input1 ): Party0 sends input0 to the two-party
algorithm A and then receives output0 , while Party1 sends input1 to the same
two-party algorithm A and then receives output1 .

2.2 Some SMPC Primitives


Secret Share. For any x ∈ Z2L , [x]i denotes the secret share of x for party i.
We have
[x]0 + [x]1 = x.
Partyi sends a random number r ∈ Z2L to Partyj if it wants to share data x with
Partyj . Then [x]j = r and [x]i = x − r. When sharing many xi , it is secure to
send a seed s of a PRNG and use the seeded PRNG to generate enough random
numbers PRNG(s, i) as ri for the corresponding xi .
Lemma 1. The above secret sharing is linearly homomorphic over Z2L , i.e., for any
n ∈ N:
[x + y]i = [x]i + [y]i ,
[n ∗ x]i = n ∗ [x]i .
The secret share and recovery algorithm can be described as follows:

Algorithm 1: Secret-Share
Input: (x; ⊥)
Output: ([x]0 ; [x]1 )
begin
Party0 randomly selects r ∈ Z2L and sends r to Party1 ;
Party0 outputs [x]0 = x − r;
Party1 outputs [x]1 = r.

Algorithm 2: Recovery
Input: ([x]0 ; [x]1 )
Output: (x; x)
begin
Party0 sends [x]0 to Party1 ;
Party1 sends [x]1 to Party0 ;
Party0 receives [x]1 and outputs x = [x]0 + [x]1 ;
Party1 receives [x]0 and outputs x = [x]0 + [x]1 .
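The sharing and recovery steps above can be sketched as follows; this is a minimal illustration in which the choice L = 64 and all variable names are our own:

```python
import random

L = 64
MOD = 1 << L  # shares live in the ring Z_{2^L}

def share(x):
    """Algorithm 1: split x into two additive shares ([x]_0, [x]_1)."""
    r = random.randrange(MOD)   # the random value sent to Party_1
    return (x - r) % MOD, r

def recover(s0, s1):
    """Algorithm 2: both parties exchange shares and add them mod 2^L."""
    return (s0 + s1) % MOD

x, y = 12345, 678
x0, x1 = share(x)
y0, y1 = share(y)
assert recover(x0, x1) == x
# Lemma 1 (linear homomorphism): adding shares component-wise yields
# shares of the sum, with no interaction between the parties.
assert recover((x0 + y0) % MOD, (x1 + y1) % MOD) == (x + y) % MOD
```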

Oblivious Transfer. The Oblivious Transfer (OT) algorithm is a fundamental


cryptographic primitive used in secure multi-party computation. In this algo-
rithm, a sender Party0 has two pieces of information, k0 and k1 , while a receiver
Party1 wishes to obtain one of these pieces, denoted by a choice bit σ ∈ Z2 .
The protocol ensures that Party1 learns only the piece of information kσ cor-
responding to their choice bit, while Party0 remains oblivious to which piece
of information Party1 has obtained. The concrete OT algorithms can be found
in [2, 7], and the reader is referred to [47] for a survey of OT.

Algorithm 3: Oblivious-Transfer
Input: (k0 , k1 ; σ) where k0 , k1 ∈ {0, 1}∗ and σ ∈ {0, 1}
Output: (⊥; kσ )

Multiplication Triplets. To make multiplication over secret shares, we need


the help of multiplication triplets (MT) (a, b, c). The triplet (a, b, c) ∈ (Z2L )3
satisfies that a, b are randomly distributed over Z2L , and c = ab. Before the secret
multiplication computation, participants Party0 and Party1 hold the secret
shares of [a]0 , [b]0 , [c]0 and [a]1 , [b]1 , [c]1 , respectively, while the values of a, b,
and c remain confidential. The secret shares [a]0 , [b]0 , [a]1 and [b]1 can be ran-
domly generated by the participants themselves.

c = ab
= ([a]0 + [a]1 )([b]0 + [b]1 )
= [a]0 [b]0 + [a]0 [b]1 + [a]1 [b]0 + [a]1 [b]1
As shown in above equations, if participant Party0 wants to obtain the secret
share [c]0 , and participant Party1 wants to obtain the secret share [c]1 , then
[a]0 [b]0 and [a]1 [b]1 can be computed directly by the participants themselves,
while for [a]0 [b]1 and [a]1 [b]0 , the participants Party0 and Party1 need to use
Oblivious Transfer (OT) technology to obtain their corresponding secret shares.
This problem can be solved by the following algorithm. With the help of the
algorithm Product-Secret-Sharing, multiplication triplets can be generated.
Now with the help of multiplication triplets (a, b, c), the secure multiplication
based on arithmetic secret sharing can be achieved.

Algorithm 4: Product-Secret-Sharing
Input: (X = XM −1 XM −2 · · · X0 ; Y = YM −1 YM −2 · · · Y0 )
Output: ([XY ]0 ; [XY ]1 )
begin
for i = 0, 1, . . . , M − 1 do
Party1 generates a random number Ri0 and computes Ri1 = 2i Y + Ri0 ;
Party1 and Party0 run algorithm Oblivious-Transfer(Ri0 , Ri1 ; Xi ) and
Party0 obtains RiXi ;
Party0 outputs [XY ]0 = Σ_{i=0}^{M−1} RiXi ;
Party1 outputs [XY ]1 = −Σ_{i=0}^{M−1} Ri0 ;
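Algorithm 4 can be sketched as follows, with the OT step simulated by a direct selection (a real deployment would replace that line with an actual OT protocol, e.g. [2, 7]); the bit length M and all helper names are our own choices:

```python
import random

M = 16            # bit length of X (illustrative)
MOD = 1 << M

def product_secret_sharing(X, Y):
    """Party_0 holds X, Party_1 holds Y; outputs additive shares of XY."""
    out0, out1 = 0, 0
    for i in range(M):
        r0 = random.randrange(MOD)        # Party_1's random mask R_i^0
        r1 = ((1 << i) * Y + r0) % MOD    # R_i^1 = 2^i * Y + R_i^0
        x_i = (X >> i) & 1                # Party_0's choice bit X_i
        # OT(R_i^0, R_i^1; X_i): Party_0 learns only R_i^{X_i}.
        out0 = (out0 + (r1 if x_i else r0)) % MOD
        out1 = (out1 - r0) % MOD          # Party_1 accumulates -R_i^0
    return out0, out1

s0, s1 = product_secret_sharing(173, 88)
assert (s0 + s1) % MOD == (173 * 88) % MOD
```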

Algorithm 5: Multiplication-Triple-Generation
Input: (⊥; ⊥)
Output: ([a]0 , [b]0 , [c]0 ;[a]1 , [b]1 , [c]1 ) such that c = ab
begin
for i = 0, 1 do
Partyi randomly generates [a]i , [b]i ;
Party0 and Party1 run Product-Secret-Sharing([a]0 ; [b]1 ) and obtain
[[a]0 [b]1 ]0 and [[a]0 [b]1 ]1 , respectively;
Party0 and Party1 run Product-Secret-Sharing([b]0 ; [a]1 ) and obtain
[[b]0 [a]1 ]0 and [[b]0 [a]1 ]1 , respectively;
for i = 0, 1 do
Partyi output [c]i = [[a]0 [b]1 ]i + [[b]0 [a]1 ]i + [a]i [b]i ;

Algorithm 6: Secret-Sharing-Multiplication
Input: ([x]0 , [y]0 , [a]0 , [b]0 , [c]0 ; [x]1 , [y]1 , [a]1 , [b]1 , [c]1 ) where c = ab
Output: ([xy]0 ; [xy]1 )
begin
if ([a]0 = ⊥) ∧ ([a]1 = ⊥) then
for i = 0, 1 do
Partyi runs algorithm Multiplication-Triple-Generation with
input ⊥ and obtain [a]i , [b]i , [c]i ;
for i = 0, 1 do
Partyi computes [e]i = [x]i − [a]i , [f ]i = [y]i − [b]i ;
for j ∈ {e, f } do
Party0 and Party1 run algorithm Recovery([j]0 ; [j]1 ) and both
obtain j;
for i = 0, 1 do
Partyi outputs [xy]i = e[y]i + f [x]i + [c]i − i · e · f ;
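The Beaver-triplet multiplication of Algorithm 6 can be sketched as follows; for illustration the triplet is generated in the clear, whereas in the protocol it would come from Algorithm 5:

```python
import random

L = 64
MOD = 1 << L

def share(v):
    r = random.randrange(MOD)
    return (v - r) % MOD, r

def beaver_multiply(x_sh, y_sh, triple):
    """Given share pairs of x, y and of a triplet (a, b, c = ab),
    return additive shares of x*y."""
    (a0, a1), (b0, b1), (c0, c1) = triple
    # Open e = x - a and f = y - b (these reveal nothing about x, y,
    # since a and b are uniformly random masks).
    e = (x_sh[0] - a0 + x_sh[1] - a1) % MOD
    f = (y_sh[0] - b0 + y_sh[1] - b1) % MOD
    # [xy]_i = e*[y]_i + f*[x]_i + [c]_i - i*e*f
    z0 = (e * y_sh[0] + f * x_sh[0] + c0) % MOD
    z1 = (e * y_sh[1] + f * x_sh[1] + c1 - e * f) % MOD
    return z0, z1

a, b = random.randrange(MOD), random.randrange(MOD)
triple = (share(a), share(b), share(a * b % MOD))
z0, z1 = beaver_multiply(share(31337), share(271828), triple)
assert (z0 + z1) % MOD == (31337 * 271828) % MOD
```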

Linear Regression. The linear regression model equation aims to predict the
output y as a function of the input x and a weight w. The model equation is
given by:

y = xw
Here, y is the predicted value, x is the input variable, and w is the weight
parameter that we aim to learn through training. To obtain the optimal weight
parameter w, it usually use batch gradient descent method which involves a
loss function and a update function to update current found optimal weight
parameter w.
The loss function measures the difference between the predicted output of
the model and the actual target values. For linear regression, a common choice
is the Mean Squared Error (MSE) loss function, which is computed as:

J(w) = (1/(2m)) Σ_{i=1}^{m} (x_i w − y_i)^2
In this equation, m is the number of observations, xi w is the predicted value
for the i-th observation, yi is the actual target value for the i-th observation,
and J(w) represents the loss function computed over all observations.
To minimize the loss function, we use an optimization algorithm called gra-
dient descent. The update function for the weight parameter w at each iteration
is given by:

w = w − α · ∂J(w)/∂w
Here, α is the learning rate, a hyperparameter that defines the step size
during the update. The partial derivative of the loss function with respect to the
weight w, ∂J(w)/∂w, is computed as:
∂J(w)/∂w = (1/m) Σ_{i=1}^{m} x_i (x_i w − y_i)
Integrating the derivative back into the update function, we get:

w = w − α · (1/m) Σ_{i=1}^{m} x_i (x_i w − y_i)
By iteratively applying this update rule, the weight w is adjusted to minimize
the loss J(w), moving towards the optimal weight that best fits the data under

the linear regression model. Algorithm 7 is a batch gradient descent for linear
regression with bias term.

Algorithm 7: Plain-Linear-Regression
Input: Training dataset (xi , yi )m
i=1 , learning rate α, number of epochs
E, batch size B
Output: Weight parameter w after each update
Add a column of ones to X to form X̃a ;
Initialize weight w;
for e = 1, 2, · · · , E do
Shuffle the dataset X̃;
for each batch index B ⊆ [0, 1, . . . , m − 1] of size B do
Compute the gradient: ∇J(w) = B1 X̃TB (X̃B w − YB );
Update the weight: w = w − α∇J(w);
Output the updated weight w;

a
A column of ones is added to the input matrix X, resulting in an augmented
matrix X̃. This allows the linear regression model to include the bias term b
by treating it as part of the weight vector w.
For more details of regression models in machine learning, please refer to [29].
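Algorithm 7 can be sketched in one dimension as follows; the toy data and hyperparameter values are our own choices:

```python
import random

random.seed(0)
# Synthetic data: y = 3x + 2 exactly (our choice of toy problem).
xs = [(i - 50) / 25 for i in range(100)]
data = [(x, 3.0 * x + 2.0) for x in xs]

alpha, epochs, B = 0.2, 200, 20
w, b = 0.0, 0.0  # weight and bias (the "column of ones" trick folded out)

for _ in range(epochs):
    random.shuffle(data)                   # shuffle the dataset
    for start in range(0, len(data), B):   # one mini-batch at a time
        batch = data[start:start + B]
        # Gradient of the MSE loss (Algorithm 7's gradient) over the batch.
        gw = sum(x * (w * x + b - y) for x, y in batch) / len(batch)
        gb = sum((w * x + b - y) for x, y in batch) / len(batch)
        w -= alpha * gw
        b -= alpha * gb

assert abs(w - 3.0) < 0.05 and abs(b - 2.0) < 0.05
```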

2.3 Plain Ridge Regression Scheme

For Ridge regression, the loss function is defined as:

J(w) = (1/(2m)) Σ_{i=1}^{m} (x_i w − y_i)^2 + λ Σ_{j=1}^{n} w_j^2 ,

and hence the gradient of the loss function is:

∂J(w)/∂w = (1/m) Σ_{i=1}^{m} x_i (x_i w − y_i) + 2λw.

As a result, the update rule for the weight parameter w is:

w = w − α [ (1/m) Σ_{i=1}^{m} x_i (x_i w − y_i) + 2λw ].   (1)

So the iterative version of plain Ridge regression can be written as Algorithm 8.
Algorithm 8: Plain-Ridge-Regression
Input: Training dataset (xi , yi )m
i=1 , regularization parameter λ, learning
rate α, number of epochs E, batch size B
Output: Weight parameter w after each update
Add a column of ones to X to form X̃;
Initialize weight w;
for e = 1, 2, · · · , E do
Shuffle the dataset X̃;
for each batch index B ⊆ [0, 1, . . . , m − 1] of size B do
Compute the gradient: ∇J(w) = B1 X̃TB (X̃B w − YB );
Update the weight: w = (1 − 2λ)w − α∇J(w);
Output the updated weight w;

Remark 1. The L2 regularization term is folded into the update step, with the
coefficient α dropped, i.e., (1 − 2λ) is used in place of (1 − 2αλ).
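The folding in Remark 1 can be checked numerically; the values of w, ∇J(w), α, and λ below are arbitrary:

```python
# Eq. (1): w' = w - α(∇J(w) + 2λw); Algorithm 8 instead uses the folded
# rule w' = (1 - 2λ)w - α∇J(w), i.e. the shrinkage factor (1 - 2λ)
# replaces (1 - 2αλ).
w, grad, alpha, lam = 0.8, 0.3, 0.25, 0.01

eq1 = w - alpha * (grad + 2 * lam * w)     # exact gradient step
folded = (1 - 2 * lam) * w - alpha * grad  # Algorithm 8's variant

# The two rules differ by exactly 2λw(1 - α) per step.
assert abs((eq1 - folded) - 2 * lam * w * (1 - alpha)) < 1e-12
```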

2.4 Plain Lasso Regression Scheme


For Lasso regression, the loss function is defined as:

J(w) = (1/(2m)) Σ_{i=1}^{m} (x_i w − y_i)^2 + λ Σ_{j=1}^{n} |w_j | .

The loss function J(w) includes L1 regularization terms |wj |, which are non-differentiable,
so the classical gradient descent method cannot be used to find the
parameters that minimize the loss function. In this paper, we use the proximal
gradient descent (PGD) method [28] to handle the non-differentiable terms and
solve for the weight parameter w. The update rule for the weight parameter
w is:

w = Proxλ ( w − α · (1/m) Σ_{i=1}^{m} x_i (x_i w − y_i) ),   (2)

where the proximal function Proxλ , associated with the regularization parameter λ, is
defined as:

Proxλ (wi ) = wi − λ if wi > λ;  0 if |wi | ≤ λ;  wi + λ if wi < −λ.

So the iterative version of plain Lasso regression can be written as Algorithm 9.

Algorithm 9: Plain-Lasso-Regression
Input: Training dataset (xi , yi )m i=1 , regularization parameter λ, learning
rate α, number of epochs E, batch size B
Output: Weight parameter w for each batch
Add a column of ones to X to form X̃;
Initialize weight w;
for e = 1, 2, · · · , E do
Shuffle the dataset X̃;
for each batch B ⊆ [0, 1, . . . , m − 1] of size B do
Compute the gradient: ∇J(w) = B1 X̃TB (X̃B w − YB );
Update the weight for the first time: w = w − α∇J(w);
Update the weight for the second time: w = Proxλ (w);
Output the intermediate weight w;

Remark 2. Compared to plain Ridge regression (as described in Algorithm 8),
plain Lasso regression outputs a greater number of intermediate weights. Given that the Prox
function is non-differentiable, disclosing intermediate weights can facilitate real-
time monitoring of model accuracy across batches. Furthermore, this disclosure
can significantly enhance the algorithm’s convergence rate, offering a strategic
advantage in iterative optimization.
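The soft-thresholding operator Proxλ and the one-dimensional PGD iteration of Algorithm 9 can be sketched as follows; the data and hyperparameters are our own:

```python
def prox(w, lam):
    """Soft-thresholding: the proximal operator of lam*|w|."""
    if w > lam:
        return w - lam
    if w < -lam:
        return w + lam
    return 0.0

assert abs(prox(0.5, 0.1) - 0.4) < 1e-12   # shrink toward zero
assert prox(0.05, 0.1) == 0.0              # small weights are zeroed

# One-dimensional Lasso via proximal gradient descent (PGD).
xs = [0.1 * i for i in range(20)]
ys = [2.0 * x for x in xs]                 # true weight 2.0, no noise
w, alpha, lam = 0.0, 0.2, 0.01
for _ in range(500):
    grad = sum(x * (x * w - y) for x, y in zip(xs, ys)) / len(xs)
    w = prox(w - alpha * grad, lam)        # gradient step, then Prox

# The L1 penalty biases the solution slightly below 2.0.
assert 1.8 < w < 2.0
```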

3 Two Privacy-Preserving Regularization Regression Algorithms

In this section, we study how to design privacy-preserving algorithms to implement
the functionality of Algorithm 8 and Algorithm 9. For simplicity, we
assume that the two parties Party0 and Party1 agree on the following public
parameters for training the regression algorithms:

– the learning rate α;


– batch size B for each iteration;
– the sizes of the private datasets Xi and Yi owned by Partyi , which are (Ni , Di ) and (Ni , 1);
– the number of epochs EP , i.e., the number of iterations is IT = N · EP/B, where
N = N0 + N1 .
In this paper, we only consider horizontally partitioned datasets, i.e., D0 =
D1 , and let D = D0 . In fact, it is easy to extend the algorithm to vertically
partitioned datasets, since only secret sharings [X]0 and [X]1 of the combined
dataset are needed. Moreover, let X and Y denote the combined datasets
obtained by stacking X0 with X1 and Y0 with Y1 , respectively.

3.1 Privacy-Preserving Regularization Parameter Selection Algorithm
Since both Ridge and Lasso regression have a regularization term, we propose
a privacy-preserving parameter selection algorithm to select the regularization
parameter λ. This algorithm should be run before running our privacy-preserving
Ridge and Lasso regression algorithms. In Algorithm 11, the two parties should
agree on these public parameters: 1) the number Ti of candidate parameters λi
for party Partyi for i = 0, 1; 2) the regression algorithm A with a regularization
parameter; 3) the proportion p of the test dataset relative to the entire dataset.
In this context, A can be either plain Ridge regression (Algorithm 8) or plain
Lasso regression (Algorithm 9).
Remark 3. The reason for using candidate regularization parameters is that it
allows the accuracy of the privacy-preserving algorithm to remain as close as
possible to that of the plaintext algorithm.

Algorithm 10: SelectParameterByMSE


Input: the set of candidate parameters λs , the number of selected
parameters T , the entire dataset (X, Y ), the proportion p of the
test dataset relative to the entire dataset
Output: the set of selected parameters λt
Divide dataset (X, Y ) into training set (Xtrain , Ytrain ) and test set
(Xtest , Ytest ), with ratio 1 − p for training and p for test;
foreach candidate parameter λi ∈ λs do
Train the model: Run A((Xtrain , Ytrain ), α, EP, B) to obtain model
weights wi ;
Compute predictions: Ŷ = Xtest wiT ;
Compute the corresponding Mean Squared Error M SEi as:
M SEi = (1/#Ytest ) (Ŷ − Ytest )T (Ŷ − Ytest );

Select the parameters corresponding to the T smallest M SEi values as λt :

λt = {λi1 , λi2 , . . . , λiT },

where i1 , i2 , . . . , iT are the first T indices in argsort_{i ∈ {1,2,...,|λs |}} M SEi ;

return λt ;

Algorithm 11: Select-Regularization-Parameter


Input: (X0 , Y0 , λ0 ; X1 , Y1 )
Output: (λ;λ)
Party0 runs algorithm SelectParameterByMSE(λ0 , T1 , (X0 , Y0 ), p) to
obtain λ1 and then sends λ1 to Party1 ;
Party1 runs algorithm SelectParameterByMSE(λ1 , 1, (X1 , Y1 ), p) to
obtain λ and then sends λ to Party0 ;
Both Party0 and Party1 output λ;

In Algorithm 11, one party selects a subset of candidate parameters for the
other party, who then selects the final parameter. The candidates are chosen
based on the Mean Squared Error (MSE) on the test dataset using a
cross-validation method. Algorithm 10 is run offline, i.e., the two parties
do not communicate with each other during the parameter selection process. In
practice, the initial candidate parameters λ0 can be determined based on experience.
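The offline selection in Algorithms 10 and 11 can be sketched as follows; the one-dimensional ridge model standing in for A, the candidate grid, and the split ratio are all illustrative choices of ours:

```python
import random

random.seed(1)

def mse(w, data):
    return sum((x * w - y) ** 2 for x, y in data) / len(data)

def fit_ridge(data, lam, alpha=0.1, iters=300):
    """Stand-in for A: one-dimensional ridge regression via gradient descent."""
    w = 0.0
    for _ in range(iters):
        grad = sum(x * (x * w - y) for x, y in data) / len(data) + 2 * lam * w
        w -= alpha * grad
    return w

def select_parameters_by_mse(candidates, T, data, p=0.2):
    """Algorithm 10: keep the T candidates with the smallest test MSE."""
    split = int(len(data) * (1 - p))
    train, test = data[:split], data[split:]
    scored = sorted((mse(fit_ridge(train, lam), test), lam)
                    for lam in candidates)          # ascending = argsort
    return [lam for _, lam in scored[:T]]

data = [(x, 1.5 * x + random.gauss(0, 0.05))
        for x in (random.random() for _ in range(200))]
# Algorithm 11: Party_0 shortlists T1 = 3 candidates from its grid;
# Party_1 then picks the final parameter (T = 1) on its own data.
shortlist = select_parameters_by_mse([0.0, 0.001, 0.01, 0.1, 1.0], 3, data)
best = select_parameters_by_mse(shortlist, 1, data)[0]
assert len(shortlist) == 3 and best in shortlist
```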

3.2 Privacy-Preserving Ridge Regression Algorithm

To standardize the data, each party Partyi should add a column of ones to Xi ,
creating X̃i . They then use a secret sharing scheme (Algorithm 1) to distribute
X̃i and Yi to the other party Partyj , represented as X̃i = [X̃i ]0 + [X̃i ]1 and
Yi = [Yi ]0 + [Yi ]1 . Each party Partyi combines [X̃i ]i , [X̃j ]i , [Yi ]i , and [Yj ]i to
construct [X̃]i and [Y ]i for i = 0, 1.


In accordance with the Select-Regularization-Parameter algorithm (Algorithm 11),
parties Party0 and Party1 agree that A is Plain-Ridge-Regression (Algorithm 8)
and execute the Select-Regularization-Parameter algorithm with the inputs
tion parameter λ. Once the parameter λ is determined, they can proceed to run
the privacy-preserving Ridge regression algorithm, effectively implementing the
functionality of Algorithm 8 in a secure manner.

Algorithm 12: Privacy-Preserving-Ridge-Regression


Input: (X0 , Y0 ; X1 , Y1 )
Output: (w;w)
for i = 0, 1 do
Partyi Initializes the weight [w]i to zero ;
for e = 1, 2, · · · , EP do
for i = 0, 1 do
Shuffle the dataset: Partyi sets P = σ([0, 1, 2, . . . , n − 1]), where
σ is a random permutation function;
Partyi sets X̂i = P ([X]i ) and Ŷi = P ([Y ]i );
for each batch index B ⊆ [0, 1, . . . , n − 1] of size B do
Partyi runs algorithm Secret-Sharing-Multiplication with
input (X̂i , [w]i , ⊥, ⊥, ⊥) and obtain Ỹi ;
Partyi runs algorithm Secret-Sharing-Multiplication with
input (X̂iT , Ỹi − Ŷi , ⊥, ⊥, ⊥) and obtain Gi ;
Partyi updates the weight: [w]i = (1 − 2λ)[w]i − (α/B) · Gi ;

for i = 0, 1 do
Partyi runs algorithm Recovery with input [w]i and obtain wi ;
Partyi outputs wi ;

Security. We use the Universal Composability (UC) framework [10] to
prove the security of this two-party privacy-preserving algorithm, referred to as
the protocol in this subsection.
UC Framework: In UC framework, the simulation-based proof method is used
to prove the security. In this framework, executing a protocol is viewed as a set
of interacting Turing machines. Assume there exists a semi-honest adversary A,
and both parties interact according to protocol π. The environment Z selects the
input of the honest party and receives the output after the protocol execution.
When the interaction finishes, Z outputs a single bit. In general, it is assumed
that the virtual adversary A sends all received messages in the protocol to the
environment Z and completes the protocol steps under Z’s control.
Security Definition: The security definition in the UC framework is based on
the indistinguishability of outputs between the real and ideal protocols. Define
REAL[Z, A, π, κ] as the single bit output by environment Z when the real pro-
tocol is completed, where a probabilistic polynomial-time (PPT) adversary A
attacks the real protocol π, and κ is the security parameter of the protocol.
On the other hand, the ideal functionality Fml in a two-party secure machine
learning setting represents servers Party0 and Party1 computing the model
(y1 , y2 , . . . , ym ) = f (x1 , x2 , . . . , xm ), where yi is sent to the user. Here, xi is
the input from the user, and f is a function that can be repeatedly invoked. Let
IDEAL[Z, S, Fml , κ] denote the output of environment Z when a PPT simulator
S attacks the ideal protocol Fml , where κ is the security parameter.

If for any virtual adversary A attacking the real protocol π, there exists a
simulator S attacking the ideal protocol Fml , then in the running environment
Z, the following holds:

Pr[REAL[Z, A, π, κ] = 1] ≈ Pr[IDEAL[Z, S, Fml , κ] = 1].


If the outputs of the real protocol and the ideal protocol are computation-
ally indistinguishable for any input tested by the PPT environment under both
the adversary attacking the real protocol and the simulator attacking the ideal
protocol, then protocol π is UC-secure with respect to functionality F.
In the context of our privacy-preserving algorithm, we can define the ideal
functionality Fml as the plain Ridge regression algorithm (Algorithm 8) and
the real protocol π as the privacy-preserving Ridge regression algorithm (Algo-
rithm 12). The environment can be the two parties Party0 and Party1 and the
simulator can be a simulator that runs the privacy-preserving Ridge regression
algorithm on the same dataset. In the semi-honest model, participants follow the
protocol correctly but may attempt to infer additional information by analyzing
the messages they receive during the execution of the protocol. The adversary
can control one of the parties (say, Party0 or Party1 ) and obtain all the infor-
mation seen by that party.
Theorem 1. The protocol (Algorithm 12) is UC-secure with respect to the func-
tionality (Algorithm 8) against semi-honest adversaries.
Proof. Consider the following scenario: There exists a simulator S that attacks
the ideal protocol by submitting the colluding user’s input to Fml and receiv-
ing the output of the linear regression protocol, specifically the returned coef-
ficient w. The simulator S then runs A on behalf of the honest party. S, act-
ing as an honest user, sends a randomly secret-shared value to A. Clearly, this
step only involves the user’s input. In the remaining part of the protocol, the
honest party generates the random matrices corresponding to the secret shares
[X]i , [Y ]i , [a]i , [b]i , [c]i and interacts with the adversary A where [a]i , [b]i , [c]i is
a multiplication triplet. Finally, the simulator S iteratively updates the secret-
shared portion of the parameters of the honest server, so that the reconstructed
value is the real coefficient vector w that it obtained from the protocol.
The indistinguishability of the adversary’s output in both the real and ideal
protocols is based on the security of arithmetic secret sharing and the ran-
domization properties of the matrices or vectors generated in the offline phase.
Specifically, aside from reconstructing the model coefficient w, all messages sent,
received, and reconstructed in the protocol can be viewed as following the same
distribution, because both the real and ideal protocols use secret shares that are
generated uniformly at random. In the case of the private data matrix [X]i , to
hide the original data, it is blinded by adding a pseudo-random matrix [a]i , a
process known as blinding, where [a]i is the blinding factor. In the protocol, the
adversary A only accesses the blinded value of [X]i in one interaction and can-
not determine the blinding factor. After this interaction, the rest of the protocol
operations are completed between the honest server and the colluding server,
with no further interaction with the adversary A.

Experiments. To demonstrate the performance of the privacy-preserving Ridge


regression algorithm framework in this paper, we conducted experiments using
synthetic data, the Boston housing dataset, and other datasets designed to val-
idate the performance of the SMC-based privacy-preserving Ridge regression
framework. All experiments were repeated five times, and the average of the
five experimental results was recorded. Unless otherwise stated, the number of
iterations was set to traverse the training samples three times, with a batch size
of B = 128 samples, the regularization parameter obtained from Algorithm 11,
and the learning rate set to α = 0.25.
The experiments were conducted using virtual machines installed on a per-
sonal computer. To simulate the different participants Party0 and Party1 , dif-
ferent local ports were used to test the communication overhead between the
participants. The hardware environment for the experiments included a com-
puter running the Ubuntu 20.04 version of the Linux operating system, with
an Intel(R) Core(TM) i5-4200H CPU, 2.80GHz clock speed, 8GB of memory,
and 80GB of hard drive space. The software environment included C++ (g++)
version 9.4.0 as the development language, and cmake version 3.25.1 as the build
tool. The network used in the experiment is a gigabit local area network (LAN)
with a latency of 1 ms.
The datasets used for privacy-preserving Ridge regression training are listed
in Table 1. The reason for using synthetic datasets is that their size and the num-
ber of features are more controllable, which helps in analyzing the performance
of the framework. Other datasets were used to verify that the framework can
effectively train models even beyond synthetic datasets. All datasets were split
into 80% for training and 20% for testing, and were standardized. The training
set was horizontally split in half, assuming that each half was owned by the two
parties participating in the computation.
Synthetic Dataset: The synthetic dataset was generated using the same data
synthesis method as in previous works. The sample matrix X was randomly gen-
erated from a standard Gaussian distribution. The values of the d-dimensional
vector w were randomly selected from the range [0, 1]. The n-dimensional noise
vector δ was generated from a Gaussian distribution with a mean of 0 and a vari-
ance of σ 2 = 0.1. The n-dimensional vector Y was calculated using the formula
Y = Xw + δ.
The experiments are divided into two parts. In this subsubsection, we eval-
uate whether the privacy-preserving Ridge regression framework proposed in
this paper achieves comparable accuracy to the original Ridge regression algo-
rithm. Then, we demonstrate the time and communication overhead when the
framework processes small and large data inputs in the next subsubsection.
Due to precision handling, data truncation will inevitably result in the loss of
some data information. To illustrate that the framework can effectively handle
this issue, we introduce the root mean square error difference ΔRM SE, as defined
in Eq. (3). This is used to measure the accuracy difference between the model
trained under the framework and the model trained in a plaintext environment
under the same conditions.


ΔRM SE = (√M SE − √M̃SE) / √M̃SE   (3)
where M SE represents the average loss of the training data under the privacy-preserving
framework, and M̃SE represents the average loss of the training data
under plaintext calculation. The definition of the average loss is consistent with
the relevant definition in Subsect. 2.2 and Algorithm 10, i.e., the average
of the squared differences between all predicted labels and the corresponding
true labels of the test data. The closer the ΔRM SE value is to zero, the more
similar the training results of the model under the framework are to the plaintext
training results.
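A small helper for this metric; the function and variable names are ours, under our reading of ΔRMSE as a relative difference of root mean squared errors:

```python
import math

def delta_rmse(mse_private, mse_plain):
    """Relative difference between the RMSE of the privacy-preserving
    model and that of the plaintext model; values near zero mean the
    two training results are nearly identical."""
    return (math.sqrt(mse_private) - math.sqrt(mse_plain)) / math.sqrt(mse_plain)

assert delta_rmse(0.25, 0.25) == 0.0
# A 1% larger private-model MSE gives a value on the order of 5e-3.
assert abs(delta_rmse(0.2525, 0.25) - (math.sqrt(1.01) - 1.0)) < 1e-12
```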

Table 1. Accuracy Difference Results for Privacy-Preserving Ridge Regression

Dataset Name N∗ D∗ ΔRM SE ∗∗


Synthetic Dataset 1000 10 2.59 ×10−4
10000 10 4.80 ×10−4
100000 10 9.42 ×10−4
Boston Housing Dataset [24] 506 12 3.65 ×10−4
High School Student Grades [13] 395 48 3.21 ×10−4

∗
N and D denote the numbers of samples and features, respectively.
∗∗
ΔRM SE denotes the difference between the RM SE of the privacy-preserving
and plaintext training results.

The participants Party0 and Party1 used the privacy-preserving Ridge regres-
sion algorithm framework proposed in this paper to train on both synthetic sam-
ple datasets and publicly available datasets from the UCI database, obtaining
and testing Ridge regression models. Simultaneously, plaintext Ridge regression
training and testing were also performed under the same conditions. Table 1
records the test results. All the ΔRM SE values in the table are on the order of
10−4 or smaller, which suggests that the difference between the results of the
privacy-preserving algorithm framework and the plaintext algorithm is minimal.
Essentially, no information was lost during the privacy-preserving learning pro-
cess, and the Ridge regression model achieved almost the same accuracy as the
one under plaintext conditions.

Performance. To evaluate the performance of the privacy-preserving Ridge


regression algorithm framework, we split it into two phases. Phase 1 is the
offline phase, in which the participants Party0 and Party1 jointly precompute a
large number of multiplication triplets by running Algorithm 5, without loading
their private data. Phase 2 is the online phase, in which the participants Party0
and Party1 jointly compute the regularization parameter λ and the secret shares
of the private data, and update the model parameters using the gradient descent
algorithm.

To improve efficiency, the two parties exchange a short random number as


a seed for a Pseudorandom Number Generator (PRNG). Party Partyi holds a
PRNG seeded by the value sent by Partyj and thus can locally generate random
permutations and random matrices, saving a large amount of communication.
Performance is tested for different numbers of samples and different numbers of
features.
Remark 4. We note that the use of a PRNG does not compromise security, as
the distinguishability of our protocol would imply the distinguishability of the
PRNG. However, it is widely believed that a secure PRNG exists [12, 21, 22, 27],
meaning that the output of such a PRNG is indistinguishable from random.
Therefore, the use of a PRNG does not present a security issue in our protocol.

For Different Numbers of Samples

Fig. 1. Time and communication for privacy-preserving Ridge regression under differ-
ent sample sizes.

Figure 1 shows the running time and communication of the privacy-preserving
Ridge regression framework with different sample
sizes. Participants Party0 and Party1 input their synthetic sample datasets with
sample sizes of 1000, 10000 and 100000, and feature count of 10, into the frame-
work for training Ridge regression models. The test results are shown in Fig. 1. It
can be observed that both the running time and communication increase as the
sample size increases. In the offline phase, especially for the group with a sample
size of 100000, the communication reaches 1738MB. However, since no actual
data participation is required during the offline phase, it can be pre-generated,
making the cost acceptable.
For Different Numbers of Features
Participants Party0 and Party1 input their respective synthetic datasets, origi-
nally with 10,000 samples and with feature numbers of 10, 20, 30, and 40, into the

Fig. 2. Running Time and Communication of Privacy-Preserving Ridge Regression Framework with Different Feature Sizes.

privacy-preserving algorithm framework described in this paper to train Ridge


regression models. The test results for the models are shown in Fig. 2. It can
be seen that the runtime and communication of this framework increase with
the number of features. During operation, most of the time expenditure in the
online phase is due to the selection of regularization parameters. For example, for
the synthetic dataset with 10,000 samples and 40 features, if the regularization
parameter selection step is not included, the online phase only takes 0.0348 s.
Choosing optimal regularization parameters aims to obtain a better model, and
since a model can be used for multiple predictions once trained, the time spent
on regularization parameter selection is worthwhile from a long-term benefits
perspective.

Table 2. Our Privacy-Preserving Ridge Regression Training Time and Communication Overhead

Dataset Size Offline Phase Online Phase


Time Comm. Time Comm.
Boston Housing Prices (506, 13) 0.2587 11.34 0.0024 0.0964
High School Student Performance (395, 48) 0.5445 32.86 0.0034 0.2488

Time is measured in seconds (s), and communication (Comm.) is measured in megabytes (MB).

The Boston Housing Prices dataset and the High School Student Perfor-
mance dataset were input into this paper’s computational framework to train
Ridge regression models and evaluate their performance, as shown in Table 2.
Additionally, baseline Ridge regression was performed under the same conditions
for comparison.

The test results are shown in Table 2. For comparison, the scheme in [20]
required a total of 5.76 s on the Boston Housing Prices dataset and 39.76 s on
the High School Student Performance dataset. Since the regularization parameter
selection step was not included, the online phase times in Table 2 exclude the
time for this selection.
Comparison. The comparative results shown in Table 3 indicate that our
method has certain advantages in terms of runtime, and although the communi-
cation increases, the majority occurs during the offline phase. If conditions allow
for a reduction in the cubic term scaling in the offline phase, this framework
provides a good choice for privacy-preserving Ridge regression.

Table 3. Comparison of Time Overhead for Training in Different Secure Aggregation Schemes

Scheme Type   Dataset   Dataset Size   Total Time Overhead
The LHE scheme in [20] Boston House Prices∗ (506, 13) 5.76 s∗∗
Middle School Students’ (395, 48) 39.76 s∗∗
Performance
Our Algorithm 12 Boston House Prices∗ (506, 13) 0.2611 s
Middle School Students’ (395, 48) 0.5479 s
Performance

Due to certain circumstances [26], this dataset was once available through
the UCI repository [16]. However, it is no longer accessible there or via scikit-
learn [36]. It can still be found through the TensorFlow API [38] and on the
StatLib website [23]. ∗∗ To compute the total time, we sum the durations of
Phase 1 and Phase 2 as outlined in Table 2 [20].

3.3 Privacy-Preserving Lasso Regression Algorithm

For implementing privacy-preserving Lasso regression, we adopt the same frame-
work as used for Ridge regression, substituting the Ridge loss function (Eq. (1))
with the Lasso loss function (Eq. (2)). The resulting privacy-preserving Lasso
regression algorithm, Algorithm 13, securely realizes the functionality of its plain-
text counterpart, Algorithm 9.
A key aspect of Algorithm 9 is the use of a piecewise function Prox for the
second update of the weight parameter. Since evaluating a piecewise function in a
privacy-preserving context can be computationally expensive, we assign this task
to one party, Partyi . This party executes the evaluation and then communicates
the final result to the other party, Partyj . Consequently, Partyi has access to
the weight parameters w in plaintext throughout all iterations, while Partyj
only receives the final weight parameters. This approach significantly enhances
algorithm performance.
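For concreteness, the Prox step for the Lasso penalty is the elementwise soft-thresholding operator prox_t(w)_k = sign(w_k) · max(|w_k| − t, 0). The following is a minimal plain-Python sketch of this operator (illustrative, not the paper's implementation):

```python
import math

def prox_soft_threshold(w, threshold):
    """Elementwise soft-thresholding: the proximal operator of
    threshold * ||w||_1, used as the Prox step in Lasso training."""
    return [math.copysign(max(abs(x) - threshold, 0.0), x) for x in w]

# Coefficients with magnitude below the threshold are zeroed (this is what
# makes Lasso perform feature selection); the rest are shrunk toward zero.
print(prox_soft_threshold([3.0, -2.0, 0.5], 1.0))  # → [2.0, -1.0, 0.0]
```

Because the operator is a simple piecewise function of each coordinate, it is cheap in plaintext but expensive to evaluate obliviously, which motivates assigning it to a single party as described above.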
36 Z. Yi et al.

Algorithm 13: Privacy-Preserving-Lasso-Regression

Input: (X0, Y0; X1, Y1)
Output: (w; w)
for i = 0, 1 do
    Partyi initializes the weight [w]i to zero;
for e = 1, 2, ..., EP do
    for i = 0, 1 do
        Shuffle the dataset: Partyi sets P = σ([0, 1, 2, ..., n − 1]), where
        σ is a random permutation function;
        Partyi sets X̂i = P([X]i) and Ŷi = P([Y]i);
        for each batch index B ⊆ [0, 1, ..., n − 1] of size B do
            Partyi runs algorithm Secret-Sharing-Multiplication with
            input (X̂i, [w]i, ⊥, ⊥, ⊥) and obtains Ỹi;
            Partyi runs algorithm Secret-Sharing-Multiplication with
            input (X̂iT, Ỹi − Ŷi, ⊥, ⊥, ⊥) and obtains Gi;
            Partyi updates the weight: [w]i = [w]i − (α/B)·Gi;
Partyi runs algorithm Recovery with input [w]i and obtains w;
Partyi computes w = Proxλ(w) and outputs w;
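The Secret-Sharing-Multiplication subroutine invoked above is not reproduced in this excerpt. The following minimal sketch shows the standard Beaver multiplication-triple technique [4] on which such subroutines are typically built, over a toy prime field and with an in-process dealer standing in for the offline triple generation; names and parameters are illustrative assumptions, not the paper's exact protocol:

```python
import secrets

P = 2**61 - 1  # toy prime modulus for additive secret sharing

def share(x):
    """Split x into two additive shares: x = x0 + x1 (mod P)."""
    x0 = secrets.randbelow(P)
    return x0, (x - x0) % P

def beaver_multiply(x_shares, y_shares):
    """Multiply secret-shared x and y using one Beaver triple.
    A trusted dealer (illustrative) hands out shares of (a, b, c = a*b)."""
    a, b = secrets.randbelow(P), secrets.randbelow(P)
    a_sh, b_sh, c_sh = share(a), share(b), share((a * b) % P)
    # Each party opens e_i = x_i - a_i and f_i = y_i - b_i; e and f are public
    # but reveal nothing about x, y because a, b are uniformly random masks.
    e = (x_shares[0] - a_sh[0] + x_shares[1] - a_sh[1]) % P
    f = (y_shares[0] - b_sh[0] + y_shares[1] - b_sh[1]) % P
    # [xy] = [c] + e*[b] + f*[a] + e*f (the public constant added by one party).
    z0 = (c_sh[0] + e * b_sh[0] + f * a_sh[0] + e * f) % P
    z1 = (c_sh[1] + e * b_sh[1] + f * a_sh[1]) % P
    return z0, z1

x_sh, y_sh = share(21), share(2)
z0, z1 = beaver_multiply(x_sh, y_sh)
print((z0 + z1) % P)  # → 42
```

Correctness follows from c + e·b + f·a + e·f = ab + (x − a)b + (y − b)a + (x − a)(y − b) = xy.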

Security. We use the UC framework to prove the security of this two-party
privacy-preserving algorithm. In this context, we define the ideal functionality
Fml as the plaintext Lasso regression algorithm (Algorithm 9) and the real
protocol π as the privacy-preserving Lasso regression algorithm (Algorithm 13).
The environment can be the two parties Party0 and Party1, and the simulator
can be one that runs the privacy-preserving Lasso regression algorithm on the
same dataset. In the semi-honest model, participants follow the protocol correctly
but may attempt to infer additional information by analyzing the messages they
receive during the execution of the protocol. The adversary can control one of
the parties (say, Party0 or Party1) and obtain all the information seen by that
party.

Theorem 2. The protocol (Algorithm 13) is UC-secure with respect to the func-
tionality (Algorithm 9) against semi-honest adversaries.

Proof. The proof is similar to that of the privacy-preserving Ridge regression
algorithm. The only difference is that the simulator S outputs the weight w in
each batch, as required by the ideal functionality. Hence the execution remains
indistinguishable to any environment Z, and the protocol is UC-secure.

Remark 5. There is an attack called leakage from gradients [44, 48–50] on differ-
entiable learning models such as linear regression. This attack exploits the fact
that the input data can be reconstructed from the gradients of the loss function
and the model parameters. In Algorithm 13, Partyi knows the gradient of the
loss function with respect to the model parameters, which can be used to
infer the input data. However, the functionality of Algorithm 9 shows that this
leakage is acceptable. Therefore, we can still use the privacy-preserving Lasso
regression algorithm to train the model with respect to Algorithm 9.

Experiments. The experiments in this subsection are similar to those in
Subsubsect. 3.2. We use the Boston Housing dataset and the High School Student
Performance dataset to evaluate the performance of the privacy-preserving Lasso
regression algorithm.

Table 4. Accuracy Difference Results for Privacy-Preserving Lasso Regression

Dataset Name                N∗      D∗  ΔRMSE∗∗
Synthetic Dataset           1000    10  1.50 × 10−4
                            10000   10  4.46 × 10−4
                            100000  10  4.49 × 10−4
Boston Housing Dataset      506     13  6.70 × 10−4
High School Student Grades  395     48  3.48 × 10−5

∗ N and D denote the numbers of samples and features, respectively.
∗∗ ΔRMSE denotes the difference between the RMSE of privacy-preserving and
plaintext training results.

We conducted model training and testing on both a synthetic sample dataset
and two publicly available datasets. In parallel, plaintext Lasso regression training
and testing were performed under the same conditions. The results are recorded
in Table 4. The discrepancies, as indicated by the RMSE values in the table,
are all in the order of or less than 10−4 , indicating that the error between the
models trained with this algorithm and those trained in plaintext under the same
conditions is negligible. This demonstrates the effectiveness of the algorithm.
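The ΔRMSE metric of Table 4 can be computed as in the following minimal sketch, which evaluates the RMSE of a privacy-preserving model and a plaintext model on the same test set and takes the difference (toy data and variable names are illustrative):

```python
import math

def rmse(y_true, y_pred):
    """Root mean squared error between targets and predictions."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)

y = [1.0, 2.0, 3.0]
pred_private = [1.1, 1.9, 3.0]   # predictions of the securely trained model
pred_plain = [1.1, 1.9, 3.1]     # predictions of the plaintext-trained model
delta_rmse = abs(rmse(y, pred_private) - rmse(y, pred_plain))
print(delta_rmse < 1e-1)  # → True for these toy predictions
```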

Performance. The performance of the privacy-preserving Lasso regression
algorithm is evaluated in the same way as that of the privacy-preserving Ridge
regression algorithm. We compare the time and communication of the two
algorithms for different sample sizes and feature counts in the offline and online
phases, respectively.
The parties Party0 and Party1 input the synthetic sample dataset into the
privacy-preserving Lasso regression algorithm framework to train the Lasso
regression model. Since the offline phase multiplication triple generation algo-
rithm of the privacy-preserving Lasso regression algorithm is consistent with
that of the privacy-preserving Ridge regression algorithm, we do not record the
performance in terms of running time and communication during the offline
phase.

Fig. 3. Time and Communication vs. Different Numbers of Samples and Features for
Privacy-Preserving Lasso Regression

The results of the model testing are shown in Fig. 3. The left graph illustrates
the running time and communication of the privacy-preserving Lasso regression
protocol during the online phase under different sample sizes, and the right graph
shows the same under different feature sizes. It can be seen that the running time
and communication of this framework increase as the number of features in the
training dataset grows, and there is also a linear dependence on the sample
size. From the results, it is evident that the privacy-preserving Lasso regres-
sion learning framework has a greater overhead in the online phase compared
to the privacy-preserving Ridge regression learning framework when training on
the same dataset. For instance, when training on a synthetic dataset of size
(100000, 10), the former requires 18.87 s and produces 15.57 MB of communica-
tion overhead, while the latter requires 17.39 s and 15.41 MB. This indicates that
if the training dataset does not contain irrelevant features and there is no need
for feature selection or training a sparse model, the privacy-preserving Ridge
regression learning framework is a better choice.
Comparison. In [40], van Egmond et al. implement a privacy-preserving Lasso
regression scheme using the MPyC framework, based on Shamir's secret sharing.
It requires approximately 2000 s to complete the entire process on a dataset of
10,000 samples with 10 features, of which more than 1500 s are spent on the
training phase. In contrast, our proposed scheme requires only 5.57 s for training
on a dataset of the same size, a substantial performance improvement.

4 Conclusion
In this paper, we present two privacy-preserving regression algorithms, one for
Ridge regression and one for Lasso regression, based on Secure Multi-Party Com-
putation (SMPC). These methods enable two non-colluding parties to collabo-
ratively train machine learning models without revealing their private data. By
leveraging secret sharing and the multiplication triple protocol, we ensure secure
computation under a semi-honest model. For Lasso regression, we further opti-
mize the algorithm by incorporating proximal gradient descent, which reduces
the communication and computational load during model training. Our exper-
iments, conducted on both synthetic and real datasets, demonstrate that the
privacy-preserving models achieve similar accuracy to their plaintext counter-
parts.
Throughout this work, we encountered several challenges. Scaling the algo-
rithms to large datasets led to significant computational costs, especially during
the offline phase, where the precomputation of multiplication triplets increased
both time and communication overhead. Moreover, Lasso regression proved more
complex to implement in a privacy-preserving context because it involves piece-
wise evaluations in the proximal gradient method. These evaluations are expen-
sive to compute securely. We have yet to resolve these performance bottlenecks,
as current limitations in optimizing the communication and computational com-
plexity within SMPC frameworks hinder more efficient solutions.

Acknowledgments. The research of Bo Li, Wanhui Zhang and Zhiqiang Lin was
supported by the Guangzhou Municipal Construction Group Co., Ltd. Technology
Plan Project (2022-KJ023). The research of Lishan Ke was supported by the City
School Joint Funding Project of Guangzhou City (No. 2023A03J0117). The research
of Zongxiang Yi was supported by the Talent Special Project of Research Project
of Guangdong Polytechnic Normal University [Grant No. 2021SDKYA051], Scientific
Research Capacity Improvement Project of the Doctoral Program Construction Unit
of Guangdong Polytechnic Normal University [Grant No. 22GPNUZDJS31].

References
1. Agrawal, N., Shahin Shamsabadi, A., Kusner, M.J., Gascón, A.: Quotient: two-
party secure neural network training and prediction. In: Proceedings of the 2019
ACM SIGSAC Conference on Computer and Communications Security, pp. 1231–
1247 (2019)
2. Badrinarayanan, S., Masny, D., Mukherjee, P.: Efficient and tight oblivious transfer
from PKE with tight multi-user security. In: International Conference on Applied
Cryptography and Network Security, pp. 626–642. Springer, Cham (2022). https://
doi.org/10.1007/978-3-031-09234-3_31
3. Beaulieu-Jones, B.K., et al.: Privacy-preserving generative deep neural networks
support clinical data sharing. Circ. Cardiovasc. Qual. Outcomes 12(7), e005122
(2019). https://doi.org/10.1161/circoutcomes.118.005122
4. Beaver, D.: Efficient multiparty protocols using circuit randomization. In: Advances
in Cryptology-CRYPTO 1991: Proceedings 11, pp. 420–432. Springer (1992)
5. Bogdanov, D., Laur, S., Willemson, J.: Sharemind: a framework for fast privacy-
preserving computations. In: Jajodia, S., Lopez, J. (eds.) ESORICS 2008. LNCS,
vol. 5283, pp. 192–206. Springer, Heidelberg (2008). https://doi.org/10.1007/978-
3-540-88313-5_13
6. Bogdanov, D., Niitsoo, M., Toft, T., Willemson, J.: High-performance secure multi-
party computation for data mining applications. Int. J. Inf. Secur. 11, 403–418
(2012). https://doi.org/10.1007/s10207-012-0177-2
7. Branco, P., Fiolhais, L., Goulão, M., Martins, P., Mateus, P., Sousa, L.: ROTed:
random oblivious transfer for embedded devices. IACR Trans. Cryptogr. Hardw.
Embed. Syst. 215–238 (2021)
8. Byali, M., Chaudhari, H., Patra, A., Suresh, A.: FLASH: fast and robust frame-
work for privacy-preserving machine learning. Cryptology ePrint Archive (2019).
https://doi.org/10.2478/popets-2020-0036
9. Byun, J., Lee, W., Lee, J.: Parameter-free he-friendly logistic regression. Adv.
Neural. Inf. Process. Syst. 34, 8457–8468 (2021)
10. Canetti, R.: Universally composable security: a new paradigm for cryptographic
protocols. In: Proceedings 42nd IEEE Symposium on Foundations of Computer
Science, pp. 136–145. IEEE (2001). https://doi.org/10.1109/sfcs.2001.959888
11. Chaudhari, H., Rachuri, R., Suresh, A.: Trident: efficient 4PC framework for pri-
vacy preserving machine learning. arXiv preprint arXiv:1912.02631 (2019). https://
doi.org/10.14722/ndss.2020.23005
12. Chen, Y., Zhang, J.: The role of pseudorandomness in cryptographic protocols.
ACM Trans. Inf. Syst. Secur. 26(1), 1–25 (2023). https://doi.org/10.1145/3602978
13. Cortez, P., Silva, A.M., et al.: Using data mining to predict secondary school stu-
dent performance. In: Proceedings of 5th Annual Future Business Technology Con-
ference (2008)
14. De Cock, M., Dowsley, R., Nascimento, A., Railsback, D., Shen, J., Todoki, A.: Fast
secure logistic regression for high dimensional gene data. In: Privacy in Machine
Learning (PriML2019). Workshop at NeurIPS, pp. 1–7 (2019)
15. Demmler, D., Schneider, T., Zohner, M.: ABY - a framework for efficient mixed-
protocol secure two-party computation. In: NDSS (2015). https://doi.org/10.
14722/ndss.2015.23113
16. Dua, D., Graff, C.: UCI machine learning repository: housing data set (1997).
http://archive.ics.uci.edu/ml/datasets/housing. Accessed 27 Sept 2024
17. Dwork, C., McSherry, F., Nissim, K., Smith, A.: Calibrating noise to sensitivity in
private data analysis. In: Halevi, S., Rabin, T. (eds.) TCC 2006. LNCS, vol. 3876,
pp. 265–284. Springer, Heidelberg (2006). https://doi.org/10.1007/11681878_14
18. Dwork, C., Roth, A., et al.: The algorithmic foundations of differential privacy.
Found. Trends® Theor. Comput. Sci. 9(3–4), 211–407 (2014). https://doi.org/10.
1561/9781601988195
19. Gascón, A., Schoppmann, P., Balle, B., Raykova, M., Doerner, J., Zahur, S., Evans,
D.: Privacy-preserving distributed linear regression on high-dimensional data. In:
Proceedings on Privacy Enhancing Technologies, pp. 345–364 (2017)
20. Giacomelli, I., Jha, S., Joye, M., Page, C.D., Yoon, K.: Privacy-preserving
ridge regression with only linearly-homomorphic encryption. In: Preneel, B., Ver-
cauteren, F. (eds.) ACNS 2018. LNCS, vol. 10892, pp. 243–261. Springer, Cham
(2018). https://doi.org/10.1007/978-3-319-93387-0_13
21. Goldreich, O.: Foundations of Cryptography: Basic Tools, vol. 1. Cambridge
University Press (2001). ISBN 9780521791725. https://www.cambridge.org/core/
books/foundations-of-cryptography/0D0CAE9CA377D645DB9BDCCD743F7B27
22. Goldreich, O., Krawczyk, H., Luby, M.: On the existence of pseudorandom gener-
ators. In: Proceedings of the 18th Annual ACM Symposium on Theory of Com-
puting, pp. 12–24. ACM (1986). https://doi.org/10.1109/sfcs.1988.21917
23. Harrison, D., Rubinfeld, D.L.: Boston housing dataset. StatLib - Carnegie Mel-
lon University (1978). https://lib.stat.cmu.edu/datasets/boston. Accessed 27 Sept
2024
24. Harrison, D., Jr., Rubinfeld, D.L.: Hedonic prices and the demand for clean
air. J. Environ. Econ. Manag. 5(1), 81–102 (1978). https://doi.org/10.1016/0095-
0696(78)90006-2
25. Hoerl, A.E., Kennard, R.W.: Ridge regression: biased estimation for nonorthogonal
problems. Technometrics 12(1), 55–67 (1970). https://doi.org/10.2307/1267351
26. Intangible, D.: Racist data destruction (2020). https://medium.com/
@docintangible/racist-data-destruction-113e3eff54a8. Accessed 27 Sept 2024
27. Jain, A., Singh, R.: A survey of pseudorandom number generators in cryptogra-
phy. Cryptography 6(2), 12 (2022). https://doi.org/10.3390/cryptography6020012.
https://www.mdpi.com/2410-387X/6/2/12
28. Klosa, J., Simon, N., Westermark, P.O., Liebscher, V., Wittenburg, D.: Seagull:
lasso, group lasso and sparse-group lasso regularization for linear regression models
via proximal gradient descent. BMC Bioinform. 21, 1–8 (2020). https://doi.org/
10.1186/s12859-020-03725-w
29. Kumar, S., Bhatnagar, V.: A review of regression models in machine learning. J.
Intell. Syst. Comput. 3(1), 40–47 (2022)
30. Mohassel, P., Rindal, P.: ABY3: a mixed protocol framework for machine learning.
In: Proceedings of the 2018 ACM SIGSAC Conference on Computer and Commu-
nications Security, pp. 35–52 (2018)
31. Mohassel, P., Zhang, Y.: SecureML: a system for scalable privacy-preserving
machine learning. In: 2017 IEEE Symposium on Security and Privacy (SP), pp.
19–38. IEEE (2017). https://doi.org/10.1109/sp.2017.12
32. Patra, A., Schneider, T., Suresh, A., Yalame, H.: ABY2.0: improved mixed-
protocol secure two-party computation. In: 30th USENIX Security Symposium
(USENIX Security 2021), pp. 2165–2182 (2021)
33. Raff, E., Khanna, A., Lu, F.: Scaling up differentially private lasso regularized logis-
tic regression via faster frank-wolfe iterations. In: Advances in Neural Information
Processing Systems, vol. 36 (2024)
34. Rouhani, B.D., Riazi, M.S., Koushanfar, F.: DeepSecure: scalable provably-secure
deep learning. In: Proceedings of the 55th Annual Design Automation Conference,
pp. 1–6 (2018). https://doi.org/10.1109/dac.2018.8465894
35. Sarkar, E., Chielle, E., Gursoy, G., Chen, L., Gerstein, M., Maniatakos, M.:
Privacy-preserving cancer type prediction with homomorphic encryption. Sci. Rep.
13(1), 1661 (2023)
36. Scikit-learn. Boston housing dataset - scikit-learn API (2024). https://scikit-learn.
org/1.1/modules/generated/sklearn.datasets.load_boston.html. Accessed 27 Sept
2024
37. Shi, H., et al.: Secure multi-party computation grid logistic regression (SMAC-
GLORE). BMC Med. Inform. Decis. Mak. 16, 175–187 (2016). https://doi.org/10.
1186/s12911-016-0316-1
38. TensorFlow. Boston housing dataset - tensorflow keras API (2024). https://www.
tensorflow.org/api_docs/python/tf/keras/datasets/boston_housing/load_data.
Accessed 27 Sept 2024
39. Tibshirani, R.: Regression shrinkage and selection via the lasso. J. R. Stat. Soc.
Ser. B (Stat. Methodol.) 58(1), 267–288 (1996). https://doi.org/10.1111/j.2517-
6161.1996.tb02080.x
40. van Egmond, M.B., et al.: Privacy-preserving dataset combination and lasso regres-
sion for healthcare predictions. BMC Med. Inform. Decis. Making 21, 1–16 (2021)
41. Veugen, T., Kamphorst, B., van de L’Isle, N., van Egmond, M.B.: Privacy-
preserving coupling of vertically-partitioned databases and subsequent training
with gradient descent. In: Dolev, S., Margalit, O., Pinkas, B., Schwarzmann, A.
(eds.) CSCML 2021. LNCS, vol. 12716, pp. 38–51. Springer, Cham (2021). https://
doi.org/10.1007/978-3-030-78086-9_3
42. Wagh, S., Tople, S., Benhamouda, F., Kushilevitz, E., Mittal, P., Rabin, T.: Fal-
con: honest-majority maliciously secure framework for private deep learning. arXiv
preprint arXiv:2004.02229 (2020). https://doi.org/10.2478/popets-2021-0011
43. Wang, P., Zhang, H.: Differential privacy for sparse classification learning. Neuro-
computing 375, 91–101 (2020). https://doi.org/10.1016/j.neucom.2019.09.020
44. Wei, W., Liu, L.: Gradient leakage attack resilient deep learning. IEEE Trans. Inf.
Forensics Secur. 17, 303–316 (2021). https://doi.org/10.1109/tifs.2021.3139777
45. Wu, Y., Jiang, X., Kim, J., Ohno-Machado, L.: Grid binary LOgistic REgression
(GLORE): building shared models without sharing data. J. Am. Med. Inform.
Assoc. 19(5), 758–764 (2012). https://doi.org/10.1136/amiajnl-2012-000862
46. Xie, L., Lin, K., Wang, S., Wang, F., Zhou, J.: Differentially private generative
adversarial network. arXiv preprint arXiv:1802.06739 (2018)
47. Yadav, V.K., Andola, N., Verma, S., Venkatesan, S.: A survey of oblivious transfer
protocol. ACM Comput. Surv. (CSUR) 54(10s), 1–37 (2022). https://doi.org/10.
1145/3503045
48. Zhao, B., Mopuri, K.R., Bilen, H.: IDLG: improved deep leakage from gradients.
arXiv preprint arXiv:2001.02610 (2020)
49. Zheng, Y.: Dropout against deep leakage from gradients. arXiv preprint
arXiv:2108.11106 (2021)
50. Zhu, L., Liu, Z., Han, S.: Deep leakage from gradients. In: Advances in Neural
Information Processing Systems, vol. 32 (2019). https://doi.org/10.1007/978-3-
030-63076-8_2
A Decentralized Bitcoin Mixing Scheme
Based on Multi-signature

Mingdi Shen, Tianqi Zhou(B) , Chen Wang, and Shijia Hong

School of Information Science and Engineering, Zhejiang Sci-Tech University,
Hangzhou 310018, China
[email protected]

Abstract. Bitcoin mixing operations are commonly used to enhance
transaction privacy by obscuring the connection between inputs and
outputs. However, current methods have drawbacks such as relying on
trusted third parties, being susceptible to rogue key attacks, and ineffi-
ciency. To tackle these issues, our paper introduces a decentralized Bit-
coin mixing scheme CoinMixMultiSig (CMMS), which enables partici-
pants to collaborate in generating a public address. This scheme employs
a two-round public address generation process, wherein all participants
send bitcoins to the jointly generated addresses before signing them to
prevent fraud. Additionally, CMMS utilizes a multi-signature mechanism
that not only reduces the size of signatures but also significantly improves
the system’s scalability.

Keywords: Blockchain · Bitcoin · Coin mixing · BLS signature ·
Multi-signature

1 Introduction

Although distributed ledger technology in the Bitcoin blockchain provides trans-
parency and security, its original design compromises user privacy to some extent
[1]. Every transaction is publicly recorded and permanently stored, allowing any
individual or institution to trace the entire transaction history of an account,
which poses a potential threat to user privacy [2]. To address this issue, in recent
years, academia and industry have actively explored incorporating privacy pro-
tection mechanisms into the Bitcoin blockchain [3]. Among these, decentralized
coin mixing protocols become a research hotspot. The core of this technology lies
in blending the transaction inputs and outputs of multiple users, thus breaking
the direct link between transactions and increasing the difficulty of tracking the
transaction paths [4]. By interrupting this direct connection, it becomes chal-
lenging for outsiders to track where the money comes from and where it goes.
Using multi-signature technology, a mixed-coin operation requires signature
approval from multiple (or even all) participants in order to execute a transac-
tion [5]. This means that even if there is a malicious participant, he cannot carry

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 43–55, 2025.
https://doi.org/10.1007/978-981-96-4566-4_3
44 M. Shen et al.

out unauthorized transactions on his own because he needs the signature confir-
mation of other participants [6]. This effectively prevents any single participant
from attempting to spoof or unilaterally manipulate funds. The Bitcoin mixing
scheme adopted in this paper ingeniously combines multi-signature technology,
allowing multiple parties to jointly control funds, effectively mitigating the risk
of a single account being compromised, and enhancing transaction anonymity
[7]. The Bitcoin blockchain maintains its core decentralization characteristics by
integrating multi-signatures with decentralized signature protocols while ensur-
ing transaction privacy [8]. In this paper, we propose a Bitcoin mixing scheme.
Specifically, the contributions of this paper are as follows:
– We propose a novel mechanism for generating a public address, which ensures
that all members of the mixing group participate in the address generation
process through multiple rounds of interaction.
– Multi-signature technology is introduced into the mixing scheme, which
ensures that transactions are confirmed by all participants before they are exe-
cuted. The multi-signature incorporates hash values and indices to resist rogue
key attacks, effectively preventing malicious participants from deceiving other
participants by generating specially crafted public key pairs.
– Compared with traditional decentralized coin mixing schemes, the scheme pro-
posed in this paper demonstrates significant efficiency advantages in processing
mixing transactions: it reduces the communication and computation overhead
of the transaction process through an optimized transaction negotiation and
confirmation process.

2 Related Works
2.1 Centralized Mixing Protocols
The earliest Bitcoin mixing schemes relied on centralized mixing services. These
services aggregate users’ Bitcoins into a single pool for mixing and then redis-
tribute them to the users, thus obscuring the origin of the funds. Coinlayering [9]
is a centralized coin mixing scheme aimed at protecting privacy by mixing trans-
actions of multiple users, allowing users to select multiple available “mixers” and
use them to perform randomized transactions to protect their identity privacy.
However, Coinlayering also has some drawbacks. Firstly, it relies on trust
in the mixer, where users need to trust that the mixer will not disclose their
privacy or participate in attacks. Secondly, due to the centralized design, there
is a risk of a single point of failure in this scheme, and once the coin mixer is
attacked or fails, the system will be unable to operate. Finally, Coinlayering has
poor resistance to censorship, and coin mixers may review or reject transactions
from certain users under regulatory pressure.

2.2 Decentralized Mixing Protocols


To address the shortcomings of centralized services, researchers have proposed
decentralized coin mixing protocols.
A Decentralized Bitcoin Mixing Scheme Based on Multi-signature 45

CoinParty [10] combines decryption mix networks with threshold signature tech-
nology to exchange and shuffle users' bitcoins multiple times among multiple
mixing nodes, thus achieving a high degree of anonymity and security [11]. How-
ever, CoinParty has high communication and computational overheads and relies
on an honest-majority assumption, meaning that if the majority of participants
are not honest, the privacy and security of the system will be at risk.
The advanced scheme of [12] combines the ElGamal signature protocol [13] with
secret sharing techniques to build a distributed signature protocol. However, its
method of generating public addresses is highly risky, and its signature generation
is extremely complicated and incurs high overhead.
Like CoinParty and the advanced scheme of [12], CMMS is based on the princi-
ple of multi-party collaboration. Unlike these schemes, however, our method of
generating public addresses is secure and tamper-resistant. At the same time,
our scheme introduces a multi-signature that can resist rogue key attacks while
compressing the signature length, ensuring the security of participants' assets.

3 Preliminaries
3.1 Bilinear Pairing
Let G1, G2 and GT be multiplicative cyclic groups of prime order q. A mapping
e : G1 × G2 → GT is a bilinear map if it satisfies the following three properties.
1. Bilinearity: For all u ∈ G1, v ∈ G2, and a, b ∈ Zq:

e(u^a, v^b) = e(u, v)^(ab)

2. Non-degeneracy: There exist elements u ∈ G1 and v ∈ G2 such that:

e(u, v) ≠ 1_GT

3. Computability: There exists an efficient algorithm to compute e(u, v) for all
u ∈ G1 and v ∈ G2.

If G1 = G2, the bilinear pairing is symmetric; otherwise, it is asymmetric.

3.2 Multi-signatures
A multi-signature scheme [14] enables n parties to collectively produce a single
signature σ on a message M. The scheme typically involves four algorithms:
Setup, KeyGen, Sign, and Verify, described below:
1. Setup(1λ) → param: The Setup algorithm takes as input the security
parameter λ and outputs a system parameter param.
2. KeyGen(param) → (sk, pk): The KeyGen algorithm takes param as input
and outputs a user's private key sk and public key pk.
3. Sign(sk, M) → σ: The Sign algorithm is an interactive process. Suppose there
are n users, each with a public key pki, who want to sign a message M jointly.
Each user signs the message with their private key ski, generating an individual
signature σi. These signatures are then exchanged among the users. After each
user executes the algorithm, the output is a multi-signature σ.
4. Verify(PK, σ, M) → 0 / 1: The Verify algorithm takes a set of public keys
PK = {pk1, pk2, ..., pkn}, a multi-signature σ, and a message M as inputs.
The verifier checks whether σ is a valid multi-signature on the message M by
the owners of the public keys in PK. If the multi-signature σ is valid, the
algorithm outputs 1; otherwise, it outputs 0.

4 System Model and Design Goals


4.1 System Model

The system model of the mixing transaction process is shown in Fig. 1. First,
each participant broadcasts or otherwise publishes a request to create or join a
mixing group, forming a temporary mixing group. The group members then
collectively generate a public address, to which all Bitcoins must be sent. Next,
each node splits its genuine Bitcoin mixing request into several sub-requests
and secretly sends them to other group members while receiving sub-requests
from different nodes. All nodes in the group gather the sub-requests and verify
the transaction, ensuring that all requests are fulfilled and that the amounts
match. Once the transaction is validated, all nodes sign the transaction, perform
a hashing operation, and generate the transaction hash. Finally, this finalized
transaction is broadcast to the Bitcoin network, where it awaits confirmation by
miners. After the miners confirm the transaction, it is packaged into a new block
and ultimately recorded publicly on the blockchain.
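The splitting-and-verification step described above can be sketched as follows. This is a simplified single-process illustration (function names and the use of integer satoshi amounts are assumptions), omitting the secret routing of sub-requests between nodes:

```python
import random

def split_request(total_amount, n_parts):
    """Split a genuine mixing request into n_parts positive sub-amounts
    that sum exactly to total_amount (amounts treated as integer satoshis)."""
    cuts = sorted(random.sample(range(1, total_amount), n_parts - 1))
    bounds = [0] + cuts + [total_amount]
    return [b - a for a, b in zip(bounds, bounds[1:])]

def amounts_match(sub_requests, declared_total):
    """Group-side check: the gathered sub-requests must exactly cover the
    declared amount before any node signs the transaction."""
    return sum(sub_requests) == declared_total

subs = split_request(100_000, 4)
print(len(subs), amounts_match(subs, 100_000))  # → 4 True
```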

4.2 Design Goals

Based on the above system model and the security model, the design goals of
CMMS are as follows:
Decentralization: Implement a decentralized mixing mechanism that does not
rely on any centralized mixing service provider, thereby avoiding single points
of failure and reducing centralization risks.
Multi-signature: To leverage the benefits of multi-signature technology, which not
only reduces the size of signatures but also significantly improves the system’s
scalability.
Resistance to Rogue Key Attacks: The multi-signature mechanism used in our
scheme is designed to resist rogue key attacks, ensuring the validity and integrity
of the signatures even in the presence of malicious actors attempting to manip-
ulate the signing process.
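To illustrate why resistance to rogue key attacks is a design goal: under naive key aggregation, an attacker who chooses its public key after seeing the others' can cancel them out and fully control the aggregate key. A schematic demonstration in a toy multiplicative group Z_p* (not the elliptic-curve pairing groups used by BLS; all parameters are illustrative; requires Python 3.8+ for modular inverse via `pow`):

```python
# Toy demonstration of a rogue key attack on naive key aggregation.
p = 2**127 - 1          # a Mersenne prime, used as an illustrative modulus
g = 3                   # group generator (illustrative)

honest_secret_keys = [123456789, 987654321]
honest_pks = [pow(g, sk, p) for sk in honest_secret_keys]

# The attacker picks a secret t and publishes
# pk_rogue = g^t * (product of honest pks)^(-1),
# without knowing any discrete log for pk_rogue itself.
t = 424242
prod_honest = 1
for pk in honest_pks:
    prod_honest = (prod_honest * pk) % p
pk_rogue = (pow(g, t, p) * pow(prod_honest, -1, p)) % p

# Naive aggregation multiplies all public keys together...
naive_aggregate = (prod_honest * pk_rogue) % p

# ...and collapses to g^t: the attacker alone controls the aggregate key.
print(naive_aggregate == pow(g, t, p))  # → True
```

Binding each public key to a hash value and index, as in this scheme, prevents the attacker from choosing a key that cancels the honest ones.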
Fig. 1. The system model.

5 Proposed Scheme

1. Creating a temporary mixed group

When participants choose to engage in a mixing operation, they can use Bitcoin’s
public broadcast channels to announce their intention to join the mix. Each
participant has the choice of either starting a new mixing group and inviting
others or joining an existing group created by other users. This collaboration
helps in obscuring the origins of transactions.

2. Creating a public address

After a mixing group has been created, all of the bitcoins from members of the
group who want to participate in the mix must be sent to a public address that
they generate together. Our method for generating public addresses is described
in Algorithm 1.

3. Negotiation among members of the mixing group

Node_i creates a new account as the output address and generates several different
negotiation requests. Each negotiation request includes all output addresses and
the amount of bitcoins involved. Node_i randomly selects different nodes to send
each negotiation request while also receiving requests from other nodes. When
Node_j receives a negotiation request from Node_i, Node_j determines whether it
has enough bitcoins to fulfill Node_i's request. If it can fulfill the request, Node_j
generates a message indicating that it has enough bitcoins to meet Node_i's
requirement and broadcasts this message to the mixing group.
48 M. Shen et al.
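The splitting-and-matching check behind these negotiations can be sketched as a toy simulation (the amounts, node names, and the split_request helper are illustrative, not from the paper):

```python
import random

random.seed(7)  # deterministic toy run

def split_request(total, parts):
    """Split a mixing amount into `parts` positive sub-amounts summing to total."""
    cuts = sorted(random.sample(range(1, total), parts - 1))
    bounds = [0] + cuts + [total]
    return [b - a for a, b in zip(bounds, bounds[1:])]

# Each node splits its genuine mixing request into sub-requests that are
# secretly distributed among the other group members.
requests = {"node1": 50_000, "node2": 120_000, "node3": 80_000}
sub_requests = {n: split_request(amount, parts=3) for n, amount in requests.items()}

# Group-wide verification: every request must be fully covered, and the
# sub-request totals must match the originals -- otherwise the mix aborts.
assert all(sum(subs) == requests[n] for n, subs in sub_requests.items())
assert sum(map(sum, sub_requests.values())) == sum(requests.values())
```

The final group check mirrors step 4 of the scheme: comparing the aggregate of all sub-requests against the aggregate of all original requests reveals any dropped or tampered request.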

Algorithm 1. Creating a public address

Require: Number of nodes m, cryptographic function g, random values α_1, α_2, ..., α_m
Ensure: Public mixing address H(Z)
 1: L_0 ← ∅
 2: for each node i = 1 to m do
 3:   if i = 1 then
 4:     α_1 ← randomly select α_1
 5:     L_1 ← g^{α_1}
 6:     Send L_1 to Node_2
 7:   else
 8:     α_i ← randomly select α_i
 9:     L_i ← L_{i-1} ‖ g^{α_1 α_2 ... α_i}   {Concatenate L_{i-1} with g^{α_1 α_2 ... α_i}}
10:     Send L_i to Node_{i+1}
11:   end if
12: end for
13: Node_m constructs L_m = L_{m-1} ‖ g^{α_1 α_2 ... α_m} and sends L_m to Node_{m-1}
14: for each node i = m to 1 do
15:   if i = m then
16:     Z ← g^{α_1 α_2 ... α_m}
17:     Extract elements y from L_{m-1}
18:     Compute y^{α_m} for each y and construct M_m
19:     Send M_m to Node_{m-1}
20:   else
21:     Z ← (g^{α_1 α_2 ... α_{i-1} α_{i+1} ... α_m})^{α_i}
22:     Extract elements y from M_{i+1}
23:     Compute y^{α_i} for each y and construct M_i
24:     Send M_i to Node_{i-1}
25:   end if
26: end for
27: Node_1 computes Z = (g^{α_2 ... α_m})^{α_1}
28: Public Address ← H(Z)
29: return Public Address
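The invariant behind Algorithm 1 can be checked in a few lines: after the up- and down-flows, every node holds g raised to the product of all the other nodes' secrets, so raising that value to its own α_i yields the same Z, and H(Z) gives the common address. The sketch below verifies the invariant directly; the message-passing rounds are abstracted away, and the modulus and generator are toy values, not a secure group choice:

```python
import hashlib
import secrets

# Toy group parameters -- illustration only, not a secure choice.
P = 2**127 - 1          # Mersenne prime modulus
G = 3                   # generator stand-in for g

m = 4                   # number of mixing nodes
alphas = [secrets.randbelow(P - 3) + 2 for _ in range(m)]

def blind_for(i):
    """What the up/down passes deliver to node i: g^(product of the others' alphas)."""
    e = 1
    for j, a in enumerate(alphas):
        if j != i:
            e *= a
    return pow(G, e, P)

# Each node raises its blinded value to its own alpha -- all must agree on Z.
Z_values = [pow(blind_for(i), alphas[i], P) for i in range(m)]
assert len(set(Z_values)) == 1

# The shared public mixing address is the hash of Z (step 28: H(Z)).
address = hashlib.sha256(str(Z_values[0]).encode()).hexdigest()
```

Because exponentiation commutes, each node's final step lands on the same Z = g^{α_1···α_m}, without any node ever learning another node's α_j.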

4. Create the final transaction

After completing the above negotiations, each Node_i verifies the validity of
the information. First, Node_i checks whether its output address has received the
correct amount of bitcoins. Second, by comparing the total bitcoins in all the
satisfied and partially satisfied messages with the total requested in all mixing
requests, it determines whether any requests have been overlooked.

5. Sign the final transaction

The signing and verifying process is shown in Algorithm 2. If the verification
passes, indicating that the Mix-Trans is valid, the miner puts the final transaction
Mix-Trans and the signature Sign on the blockchain.

Algorithm 2. Signing and Verifying Transaction

Require: Hash functions H_0: {0, 1}* → G_2, H_1: {0, 1}* → Z_q, bilinear group parameters (q, G_1, G_2, G_t, e, g_1, g_2)
Ensure: Aggregated signature Sign, and verification result
 1: Key Generation:
 2: for each node i = 1 to m do
 3:   sk_i ← randomly select private key
 4:   pk_i ← g_2^{sk_i}
 5:   Publicize pk_i
 6: end for
 7: Aggregation Public Key (apk):
 8: The creator of the mixing group computes a_i = H_1(pk_i, PK), where PK is the set of public keys
 9: apk ← ∏_{i=1}^{m} pk_i^{a_i}
10: Broadcast apk within the group
11: Signature Generation:
12: for each node i = 1 to m do
13:   s_i ← H_0(Mix-Trans)^{a_i · sk_i}
14:   Broadcast s_i within the group
15: end for
16: Aggregated Signature:
17: The creator collects all s_i and computes Sign = ∏_{i=1}^{m} s_i
18: Broadcast Sign as the aggregated signature
19: Signature Verification:
20: for each node i = 1 to m do
21:   Verify whether e(Sign, g_2^{-1}) · e(H_0(Mix-Trans), apk) ?= 1_{G_t}
22: end for
23: return verification result

6 Security Analysis
6.1 Correctness

The correctness of Sign is as follows:

Sign = ∏_{i=1}^{m} s_i = ∏_{i=1}^{m} H_0(Mix-Trans)^{a_i · sk_i} = H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i · sk_i}

apk = ∏_{i=1}^{m} pk_i^{a_i} = ∏_{i=1}^{m} g_2^{a_i · sk_i} = g_2^{Σ_{i=1}^{m} a_i · sk_i}

e(Sign, g_2^{-1}) = e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i · sk_i}, g_2^{-1})

e(H_0(Mix-Trans), apk) = e(H_0(Mix-Trans), g_2^{Σ_{i=1}^{m} a_i · sk_i})

We can get

e(Sign, g_2^{-1}) · e(H_0(Mix-Trans), apk)
= e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i · sk_i}, g_2^{-1}) · e(H_0(Mix-Trans), g_2^{Σ_{i=1}^{m} a_i · sk_i})
= e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i · sk_i}, g_2^{-1}) · e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i · sk_i}, g_2)
= 1_{G_t}

This proves that our signing process is correct.
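The derivation above can be reproduced numerically. Since a bilinear pairing is not available in a standard library, the sketch below replaces the pairing check with the equivalent exponent comparison the proof establishes: both Sign and apk carry the same exponent sum Σ a_i · sk_i. The group parameters are toy stand-ins, not the scheme's real G_1/G_2:

```python
import hashlib
import secrets

# Toy multiplicative group -- stand-in for the bilinear groups of Algorithm 2.
P = 2**127 - 1      # Mersenne prime modulus (illustration only)
G2 = 5              # generator stand-in for g_2
BASE = 3            # base used by the toy hash-to-group H0

def H1(pk, all_pks):
    """a_i = H1(pk_i, PK): binds each signer to the full public-key set."""
    data = repr((pk, sorted(all_pks))).encode()
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % (P - 1)

def H0(msg):
    """Toy hash-to-group: maps the message to a group element."""
    e = int.from_bytes(hashlib.sha256(msg).digest(), "big") % (P - 1)
    return pow(BASE, e, P)

m = 5
sks = [secrets.randbelow(P - 2) + 1 for _ in range(m)]
pks = [pow(G2, sk, P) for sk in sks]

a = [H1(pk, pks) for pk in pks]
apk = 1
for pk_i, a_i in zip(pks, a):
    apk = apk * pow(pk_i, a_i, P) % P

h = H0(b"Mix-Trans")
partial_sigs = [pow(h, a_i * sk_i, P) for a_i, sk_i in zip(a, sks)]
Sign = 1
for s_i in partial_sigs:
    Sign = Sign * s_i % P

# Correctness: Sign and apk carry the same exponent sum S = sum(a_i * sk_i),
# which is exactly what the pairing equation e(Sign, g2^-1) * e(H0(M), apk) = 1
# checks in the real scheme.
S = sum(a_i * sk_i for a_i, sk_i in zip(a, sks))
assert Sign == pow(h, S, P)
assert apk == pow(G2, S, P)
```

A real verifier cannot compare exponents directly; the pairing e lets it test the same equality without learning S, which is what makes the aggregated check publicly verifiable.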

6.2 Resisting the Rogue-Key Attack

The rogue key attack [14] is a common attack in cryptography. The attacker
maliciously creates a special key pair so that their public key has a certain
relationship with the public keys of other participants. This allows the attacker
to deceive participants into participating unknowingly in the signing process.
First, in this scheme,

apk = ∏_{i=1}^{m} pk_i^{a_i} = ∏_{i=1}^{m} pk_i^{H_1(pk_i, {pk_1, ..., pk_m})}

This public key aggregation approach uses the hash function H_1 to bind the
public key of each participant. An attacker cannot make their public key maliciously
related to the public keys of other participants by generating special key pairs,
because the public key aggregation formula ensures that the contribution of each
public key is independent and tied to the set of all public keys.
Secondly, the signing process uses the following formula:

s_i = H_0(Mix-Trans)^{a_i · sk_i}

where a_i is generated by the hash function H_1, which ensures that each signer's
signature is associated with the public keys of all the other signers. If an attacker
uses a malicious public key to participate in the signing, the attack will fail when
the multiple signatures are finally aggregated, because the hash of the malicious
public key cannot match the hash relationship of the other signers.
Finally, the multi-signature verification process is checked by the following
equation:

e(Sign, g_2^{-1}) · e(H_0(Mix-Trans), apk) ?= 1_{G_t}
During the verification process, the aggregated signature Sign and the aggregated
public key apk are bound together and checked via the bilinear map e. Even if
attackers construct malicious key pairs [15], they cannot generate valid
signatures without the corresponding legitimate private keys.
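The role of the H_1 coefficients can be made concrete: against naive aggregation apk = ∏ pk_i (no coefficients), an attacker who publishes its key last can cancel all honest keys and control apk alone, while the coefficient-based aggregation blocks the cancellation. A toy demonstration in an illustrative group (not the scheme's real bilinear setting):

```python
import hashlib
import secrets

# Toy group (illustration only).
P = 2**127 - 1
G2 = 5

def inv(x):
    """Modular inverse in the toy group."""
    return pow(x, P - 2, P)

def H1(pk, all_pks):
    """CMMS-style coefficient a_i = H1(pk_i, PK)."""
    data = repr((pk, sorted(all_pks))).encode()
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % (P - 1)

# Three honest participants publish their keys first.
honest_sks = [secrets.randbelow(P - 2) + 1 for _ in range(3)]
honest_pks = [pow(G2, sk, P) for sk in honest_sks]
prod_honest = 1
for pk in honest_pks:
    prod_honest = prod_honest * pk % P

# Rogue key: the attacker knows t and publishes a key that cancels the honest
# keys under NAIVE aggregation apk = prod(pk_i), i.e. without H1 coefficients.
t = secrets.randbelow(P - 2) + 1
rogue_pk = pow(G2, t, P) * inv(prod_honest) % P

apk_naive = prod_honest * rogue_pk % P
assert apk_naive == pow(G2, t, P)   # the attacker alone controls the naive apk

# With the H1 coefficients of Algorithm 2, the rogue key is raised to an
# exponent that depends on the whole key set (including itself), so the
# cancellation fails except with negligible probability.
all_pks = honest_pks + [rogue_pk]
apk_cmms = 1
for pk in all_pks:
    apk_cmms = apk_cmms * pow(pk, H1(pk, all_pks), P) % P
assert apk_cmms != pow(G2, t, P)
```

Under naive aggregation the attacker could then forge Sign = H_0(M)^t for any message M and pass verification alone; with the coefficients, no such t is obtainable without solving the underlying discrete-log-style problem.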

7 Performance Evaluation
7.1 Theoretical Evaluation
We analyze the performance of this scheme by comparison. Table 1 compares the
computational overhead of CMMS with CoinParty, the Advanced scheme, and Coin-
Layering. CMMS has significant advantages in modular exponentiation and hash
operations, and its modular multiplication overhead is also lower than that of the
three schemes CoinParty, Advanced scheme, and CoinLayering. It should be noted
that CMMS involves elliptic curve pairing operations in the verification phase:
2m pairings in total, since each node needs to compute two pairings for one
verification. The other schemes do not involve pairing operations, so our scheme
incurs an extra verification overhead compared to them. However, overall, our
computational overhead is still small.

Table 1. Computation overhead of different protocols

| Scheme          | ME             | MM             | Hash           | EC scalar M | Bilinear Pairing |
|-----------------|----------------|----------------|----------------|-------------|------------------|
| Coinparty       | 4m · T_E       | 8m · T_M       | 4m · T_H       | 10m · T_R   | /                |
| Advanced scheme | (8m − 2) · T_E | (5m − 1) · T_M | (2m + 1) · T_H | /           | /                |
| Coinlayering    | 8m · T_E       | 4m · T_M       | 5m · T_H       | 3m · T_R    | /                |
| Ours            | 3m · T_E       | (3m − 2) · T_M | 2m · T_H       | /           | 2m · T_P         |

T_E denotes the time required for a modular exponentiation, T_M the time for a modular multiplication, T_H the time for a hash, T_R the time for an elliptic curve scalar multiplication, and T_P the time for a bilinear pairing.

As shown in Table 2, in the initialization phase, CMMS has a slightly higher
overhead than CoinParty and the Advanced scheme because it requires two rounds of
communication, but it still maintains better communication efficiency than
CoinLayering, which involves signatures. In the transaction generation phase,
CMMS optimizes the communication overhead through signature aggregation:
although each participant still needs to broadcast its signature, the group creator
only needs to broadcast a single aggregated signature in the end. CMMS thus shows
significant advantages over CoinParty's multi-layer cryptographic delivery. In the
verification and output phases, although the multi-signature design slightly
increases the verification overhead, the communication overhead of CMMS is still
low compared to CoinParty's complex multi-layered obfuscated verification.
Table 3 compares this scheme with existing representative schemes, namely
Coinparty, the Advanced scheme, and Coinlayering, where ✓ indicates that a
scheme has the corresponding property and ✗ indicates that it does not. In
summary, our scheme stands out on several performance indicators, especially its
smaller signature size, fewer interaction rounds, and stronger robustness,
demonstrating its superiority in practical applications. This gives the scheme
not only a significant performance advantage but also makes it more suitable
for efficient and secure operation in decentralized environments.

Table 2. Communication overhead of different protocols

| Scheme          | Initialization     | Transaction generation | Verification and output                                 |
|-----------------|--------------------|------------------------|---------------------------------------------------------|
| Coinparty       | m × l_pk           | m × l_Enadd × m        | m × (l_sig + l_trans)                                   |
| Advanced scheme | m × l_pk           | m × (l_Inadd + l_sig)  | m × l_sig                                               |
| Coinlayering    | m × (l_pk + l_sig) | 3m × l_commit          | 2m × l_proof                                            |
| Ours            | 2(m − 1) × l_pk    | m × l_sig + l_aggsig   | m × l_aggsig + m × (l_Inadd + l_Outadd + l_aggsig)      |

m represents the number of participants in the mix. l_pk indicates the size of the public key, l_sig the size of the signature, and similarly for the other lengths.

Table 3. Performance comparison of different schemes

| Property                 | Coinparty | Advanced scheme | Coinlayering | Ours |
|--------------------------|-----------|-----------------|--------------|------|
| Decentralization         | ✓         | ✓               | ✗            | ✓    |
| No additional expense    | ✓         | ✓               | ✗            | ✓    |
| Shorter signature size   | ✗         | ✗               | ✗            | ✓    |
| Fewer interaction rounds | ✗         | ✗               | ✗            | ✓    |
| Higher scalability       | ✗         | ✗               | ✗            | ✓    |
| Higher robustness        | ✗         | ✗               | ✗            | ✓    |

✓ indicates that the scheme satisfies this property and ✗ indicates that it does not.

7.2 Experimental Evaluation

All experiments were performed on a server with an AMD Ryzen 7 7840H
w/Radeon(TM) 780M Graphics at 3.80 GHz, 32 GB RAM, and Windows 11. We
used the Pairing-Based Cryptography (PBC) library to implement our
cryptographic techniques, where the chosen hash function is MapToPoint. We first
measure the computational overhead of this scheme for different numbers of mixing
participants: m = 8, m = 16, and m = 32. This includes modular exponentiation and
modular multiplication. Since the time overhead of hash operations and elliptic
curve scalar multiplication is very small, we ignore these times. The results are
shown in Fig. 2. We similarly measure the total computational overhead of each
scheme (supposing m = 10) in Fig. 3, and the computational overhead of different
participants in Fig. 4.

Fig. 2. Computational overhead of ours.

Fig. 3. Total computational overhead of different protocols.

Fig. 4. Computational overhead of different participants.

8 Conclusion

Our proposed decentralized Bitcoin mixing scheme CMMS effectively enhances
transaction privacy and security while eliminating the reliance on trusted third
parties. The integration of multi-signature mechanisms and a robust public
address generation process provides a scalable and efficient solution, addressing
key challenges faced by current mixing methods. Our scheme sets a foundation
for future research and practical applications in the realm of blockchain privacy
protection.

Acknowledgments. This work is supported by the National Key R&D Program of
China (No. 2023YFB2703700), the National Natural Science Foundation of China (Nos.
U21A20465, 62302457, 62402448), the Fundamental Research Funds of Zhejiang Sci-
Tech University under Grant No. 22222266-Y, the Program for Leading Innovative
Research Team of Zhejiang Province (No. 2023R01001), the Zhejiang Provincial Natural
Science Foundation of China (Nos. LQ24F020008, LQ24F020009), and the "Pioneer"
and "Leading Goose" R&D Program of Zhejiang (No. 2023C01119).

References
1. Liu, Y., et al.: A blockchain-based decentralized, fair and authenticated information
sharing scheme in zero trust internet-of-things. IEEE Trans. Comput. 72(2), 501–
512 (2022)
2. Guo, R., Li, K., Li, X., Zhang, Y., Li, X.: Compact multiple attribute-based sig-
natures with key aggregation and its application. IEEE Syst. J. 16(2), 3025–3035
(2022)
3. Wang, C., Zhou, T., Shen, J., Wang, W., Zhou, X.: Searchable and secure edge
pre-cache scheme for intelligent 6G wireless systems. Futur. Gener. Comput. Syst.
140, 129–137 (2023)
4. Wang, X., Liao, L., Cao, D.: mpXim: a decentralized mixing scheme based on
a multi-party discovering protocol. In: 2023 IEEE 23rd International Conference
on Software Quality, Reliability, and Security Companion (QRS-C), pp. 263–272
(2023)
5. Wang, C., Shen, J., Vijayakumar, P., Gupta, B.B.: Attribute-based secure data
aggregation for isolated IoT-enabled maritime transportation systems. IEEE Trans.
Intell. Transp. Syst. 24(2), 2608–2617 (2023)
6. Zhou, T., Shen, J., Vijayakumar, P., Bhuiyan, M.Z.A., Sivaraman, A.: Anonymous
authentication scheme for federated learning. In: IEEE INFOCOM 2023-IEEE
Conference on Computer Communications Workshops (INFOCOM WKSHPS), pp.
1–6. IEEE (2023)
7. Wang, X., Lin, C., Huang, X., He, D.: Anonymity-enhancing multi-hop locks for
monero-enabled payment channel networks. IEEE Trans. Inf. Forensics Secur. 19,
2438–2453 (2024)
8. Zhang, L., Zhu, T., Xiong, P., Zhou, W.: The price of unlearning: identifying
unlearning risk in edge computing. ACM Trans. Multimedia Comput. Commun.
Appl. (2024)
9. Lu, N., Chang, Y., Shi, W., Choo, K.: Coinlayering: an efficient coin mixing scheme
for large scale bitcoin transactions. IEEE Trans. Dependable Secure Comput.
19(3), 1974–1987 (2020)
10. Ziegeldorf, J.H., Grossmann, F., Henze, M., Inden, N., Wehrle, K.: Coinparty:
secure multi-party mixing of bitcoins. In: Proceedings of the 5th ACM Conference
on Data and Application Security and Privacy, pp. 75–86 (2015)

11. Xue, J., Shi, L., Liu, L., Zhang, X., Li, F.: Anonymity-enhancing decentralized
protocol for coin mixing based on ring signatures and key derivation. Peer-to-Peer
Netw. Appl. 16(6), 2761–2774 (2023)
12. Xiao, R., Ren, W., Zhu, T., Choo, K.: A mixing scheme using a decentralized signa-
ture protocol for privacy protection in bitcoin blockchain. IEEE Trans. Dependable
Secure Comput. 18(4), 1793–1803 (2019)
13. Harn, L., Ren, J.: Efficient identity-based RSA multisignatures. Comput. Secur.
27(1–2), 12–15 (2008)
14. Chen, X., Yang, A., Tong, Y., Weng, J., Weng, J., Li, T.: A multisignature-based
secure and OBU-friendly emergency reporting scheme in VANET. IEEE Internet
Things J. 9(22), 23130–23141 (2022)
15. Shukla, S., Patel, S.J.: A novel ECC-based provably secure and privacy-preserving
multi-factor authentication protocol for cloud computing. Computing 104(5),
1173–1202 (2022)
Decentralized Continuous Group Key
Agreement for UAV Ad-Hoc Network

Shijia Hong1, Tianqi Zhou1(B), Huijie Yang1, Mingdi Shen1, and Wenying Zheng2

1 School of Information Science and Engineering, Zhejiang Sci-Tech University,
Hangzhou 310018, China
[email protected]
2 School of Computer Science and Technology, Zhejiang Sci-Tech University,
Hangzhou 310018, China

Abstract. With the rise of the Industrial Internet of Things (IIoT),


Unmanned Aerial Vehicles (UAVs) have emerged as the preferred sens-
ing tools for data collection. However, resource-constrained UAVs often
face challenges in handling complex group key agreement processes, while
key management in the flying ad hoc network (FANET) is also susceptible
to single points of failure. To address these issues, this paper proposes a
blockchain-based continuous group key agreement protocol for FANET.
The blockchain is collectively maintained by UAVs and serves to record
dynamic changes within the FANET. Additionally, it employs a ratchet
tree to integrate a pseudo-random function generator with an updatable
public key encryption scheme, effectively mitigating the risk of single
points of failure, reducing communication overhead, and enhancing secu-
rity and self-healing capabilities. Simulation experiments demonstrate
that this protocol outperforms existing key agreement schemes in terms
of reduced delay and improved security.

Keywords: Unmanned Aerial Vehicles (UAVs) · flying ad hoc network


(FANET) · continuous group key agreement · ratchet tree · blockchain

1 Introduction

Unmanned Aerial Vehicles (UAVs), also known as drones, are aircraft that are
operated either by remote control or through embedded computer programs.
The development of UAV technology was initially driven by military needs [1],
and military drones still hold a significant position in the UAV market. Mili-
tary drones can be categorized based on their specific military applications and
operational missions, such as unmanned reconnaissance or surveillance aircraft,
combat drones [2], communication relay drones, electronic warfare drones, and
multi-purpose drones that integrate reconnaissance and strike capabilities. His-
torically, drones have been primarily used for military applications in hostile
territories, conducting long-range surveillance and armed attacks to reduce pilot
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 56–69, 2025.
https://doi.org/10.1007/978-981-96-4566-4_4
Decentralized Continuous Group Key Agreement for UAV Ad-Hoc Network 57

casualties [3]. Nowadays, advancements in drone manufacturing technology and


cost reductions have led to a surge in enthusiasm for using drones in civilian and
commercial applications, bringing them into the public eye [4]. Today, drones are
widely used in precision agriculture, traffic management, telecommunications,
and other fields [5].
Group Key Agreement (GKA) is a critical technology in UAV communica-
tions, ensuring secure communication among multiple drones. In Flying Ad-hoc
Networks (FANETs) or any multi-drone collaborative missions, GKA allows a
group of drones to share a common secure key, which is used for encrypting and
decrypting data, thereby protecting communications from eavesdropping or tam-
pering. As an emerging technology, Continuous Group Key Agreement (CGKA)
[6, 7] offers significant advantages in FANETs. First, CGKA can quickly adapt to
frequent changes in network topology, maintaining continuous and secure com-
munication even when members join or leave. Second, for resource-constrained
devices like sensor nodes or small drones, CGKA can optimize algorithms to
reduce computational and communication overhead, extending device operation
time and improving overall performance [1, 8].
However, existing solutions have several limitations. One issue is the lack
of consideration for key updates, meaning that keys generated during the
initialization phase remain unchanged throughout subsequent operations. Another
problem is the centralized key management approach, where a base station (BS)
handles all key management for the drones [9]. This method, while simple, is
vulnerable to single-point failures; if the base station is compromised by malicious
attacks or DoS attacks, the security of the entire network is at risk. Additionally,
some solutions incur excessive communication or computational overhead, which can
prematurely deplete the limited battery capacity of drones. The complexity of
the external environment can also cause temporary disconnections, leading to
group key agreement failures and compromising the security of FANET
communications [10, 11].
To address these challenges, we design a continuous group key agreement protocol
for FANET. This protocol guarantees that even if some members fail to join the key
agreement process, the whole cluster can still share a confidential communication
key. In addition, we develop a communication strategy that aims to reduce the
cost of UAV communication. Furthermore, to help UAVs quickly resynchronize
to the latest group key after experiencing a wireless link failure, we develop a
group key recovery procedure. The main contributions of this paper are
summarized as follows:
1. We design a blockchain maintained collectively by the drones, redefining
the block structure. The joining, revocation, and key updates of drones are
recorded as transactions in the blockchain, with the ground control center not
participating in the key agreement process, thus avoiding single-point failures.
2. We utilize a ratchet tree structure to reduce communication overhead, where
the leaf nodes are associated with drones, and the root node represents the
shared group key. The communication complexity for key agreement using
this binary tree structure is O(log(n)).
58 S. Hong et al.

3. We combine a pseudorandom function generator with an updatable public-key
encryption scheme to ensure forward security in the communications and
enable rapid self-healing.

2 Related Work
2.1 Continuous Group Key Agreement
Continuous group key agreement allows a group of users to obtain a shared key in
an asynchronous environment [12]. For the standard TreeKEM, Alwen et al. showed
in 2020 that TreeKEM has problems with forward secrecy because its users do not
erase old keys fast enough [13]; previously updated keys can thus be fully revealed
by an attacker after a state disclosure, and they presented an updatable public key
encryption scheme to address this. In addition, there are many variants of TreeKEM:
Klein et al. proposed Tainted TreeKEM (TTKEM) [14], which efficiently manages
the security of group keys by introducing a "tainting" mechanism that remains
effective even in the presence of dynamic changes in group membership. Furthermore,
Gordon et al. proposed a taxonomy for secure group messaging for asynchronous
collaborative applications supporting end-to-end encryption [15, 16]. The Causal
TreeKEM protocol, a new protocol for managing shared encryption keys in dynamic
groups without a centralized server, was also designed.
To summarize, continuous group key agreement has mostly been used for message
passing and has not been applied in wireless sensor networks so far. The
efficiency and security of continuous group key agreement are crucial for wireless
sensor networks because the devices are usually limited in resources and large
in number. The CGKA protocol ensures that communication is secure and
data remains private even when devices join and leave the group frequently.

2.2 Key Management for UAV Networks


In wireless sensor networks, the main purpose of group key agreement is to ensure
the security and confidentiality of data in the network. A great deal of research
has been conducted in this area. Li et al. [17] designed a self-healing group key
distribution scheme based on a private blockchain, where the blockchain is used
for distributing and storing group keys and managing dynamic membership in
UAANET. Tan et al. [18] proposed a blockchain-based distributed key manage-
ment scheme for heterogeneous FANETs, where the lead drone takes on critical
coordination tasks. If the lead drone fails or is attacked, the communication and
task coordination of the entire drone network can be severely affected, poten-
tially leading to network paralysis. Once the lead drone fails, its tasks cannot
be quickly transferred to other nodes, impacting the overall system’s resilience
and reliability. Zhang et al. [19] proposed an authenticated threshold group key
agreement protocol based on Shamir’s secret sharing, which generates a usable
group key even when some drones temporarily disconnect. However, the protocol
has high computational requirements during key re-negotiation, especially
when members join or leave, which can impose a significant burden on
resource-constrained devices like drones [20].
As a result, existing research solutions usually require large amounts of
computational and communication resources or face the danger of a single point
of failure; how UAV self-organizing networks can self-heal from the insecurity of
key leakage also remains an open question.

3 Preliminaries
3.1 Left Balanced Binary Tree
A left balanced binary tree is usually referred to as an AVL tree (Adelson-Velsky
and Landis tree) [21], which is a self-balancing binary search tree. In an AVL
tree, the heights of the left and right subtrees of each node differ by at most 1,
i.e., the balancing factor (the height of the left subtree minus the height of the
right subtree) of any node at any time must be −1, 0, or 1. This condition keeps
the height of the tree at the logarithmic level, which guarantees the efficiency
of operations. Two paths are used in the scheme of this paper:
1) The direct path dpath: the path from a leaf node to the root node. For each
leaf node v, dpath(v) = (v_0 = v, v_1, ..., v_l = v_root), and
2) The copath copath: the sequence of sibling nodes along the direct path,
copath(v) = (v'_0, v'_1, ..., v'_{l−1}), where v'_i is the sibling of v_i.
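The two paths can be computed with standard heap indexing, root at index 0 and children of node v at 2v+1 and 2v+2. This is a sketch under the simplifying assumption of a complete tree (leaf count a power of two), not the general left-balanced layout:

```python
def parent(v):
    """Index of the parent of node v in a heap-indexed tree."""
    return (v - 1) // 2

def sibling(v):
    """Index of the sibling of node v (left children have odd indices)."""
    return v + 1 if v % 2 == 1 else v - 1

def dpath(v):
    """Direct path: node indices from leaf v up to the root (index 0)."""
    path = [v]
    while v > 0:
        v = parent(v)
        path.append(v)
    return path

def copath(v):
    """Siblings of every non-root node on the direct path."""
    return [sibling(u) for u in dpath(v)[:-1]]

# Complete tree with 4 leaves -> 7 nodes; the leaves are indices 3..6.
assert dpath(3) == [3, 1, 0]
assert copath(3) == [4, 2]
assert len(dpath(3)) - 1 == 2   # O(log n) hops from leaf to root for n = 4 leaves
```

The length of the direct path is the tree height, which is what gives the O(log n) communication complexity claimed for the ratchet-tree key agreement.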

3.2 Cryptographic Tools


Pseudorandom Generator. A PseudoRandom Generator (PRG) generates a
series of seemingly random numbers that are statistically similar to true random
numbers but are actually computed deterministically from a fixed seed value. The
pseudo-random generator used in this paper adopts a hash function defined as
H: {0, 1}* → {0, 1}*.

Updatable Public Key Encryption. An updatable public key encryption
(UPKE) scheme is an encryption mechanism that allows periodic or on-demand
updating of public and private key pairs without compromising the private key.
This scheme is typically used in scenarios where long-term secure communication
is required and where it is desirable to minimize the risks associated with key
compromise. UPKE = (UKG, UE, UD): the key generation algorithm is
(pk_s, sk_s) ← UKG(κ), the encryption algorithm is (pk', c) ← UE(pk, m), and the
decryption algorithm is (sk', m) ← UD(sk, c), where κ is the security parameter
and (pk_s, sk_s) is the initial key pair; the UE and UD algorithms each also
produce a new public/private key, yielding a fresh key pair.

4 Our Proposal
4.1 System Initialization
Prior to the commissioning of the UAVs, the GCS is responsible for assigning
keys to the UAVs and constructing a shared set of parameters for the FANET. These
parameters are stored in advance in each UAV's storage system. The specific
steps are summarized below:
Step 1: The GCS first constructs an empty ratchet tree in which the root node is
set to the current group key, the inner nodes are labeled with key pairs for use
in an updatable public key encryption scheme, and the leaf nodes are labeled
similarly to the inner nodes, but their public keys are used to represent the
identities of the UAVs.
Step 2: The GCS selects a random integer s_i for each UAV_i, i ∈ (1, n), and
allocates a pair of public and private keys (pk_t, sk_t) ← UKG(s_i), t ∈ (1, 2n − 1),
to each node except the root. As for the root node, the GCS places a secret value
there as the group key. Each leaf node is associated with a drone, and the public
key on the leaf node can be used as the identity of the drone. Figure 1 shows the
specific distribution.

Fig. 1. Key distribution by GCS.

4.2 Block Building


1) The round hash and the transcript hash
In this paper, we employ the round hash and the transcript hash to assist
communication in the FANET. A round hash is a commitment to the dynamic
operations on the UAVs and the current round of the ratchet tree; a UAV can
verify the H_round(n) delivered by the miner, regardless of the number of updates.
The round hash is defined as:

H_round(n) = H(GK, τ, U, R, J)

where GK refers to the current group key, τ denotes the public keys of the ratchet
tree, and U, R and J represent the update, revocation and join operations of UAVs,
respectively.

Initially, H_trans(0) = 0, and for subsequent rounds, given a verified round hash:

H_trans(n) = H(H_trans(n − 1) ‖ H_round(n))

Therefore, we can ensure that the current legal UAVs have a consistent state,
which requires them to have a consistent view of the tree and to agree on the
group keys, members, and history.
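The two hashes chain together as follows; a minimal sketch with SHA-256 standing in for H and placeholder round contents of our own invention:

```python
import hashlib

def H(*parts):
    """Toy instantiation of the hash H over a tuple of inputs."""
    h = hashlib.sha256()
    for part in parts:
        h.update(repr(part).encode() + b"|")   # crude domain separation
    return h.hexdigest()

def round_hash(gk, tree_pks, updates, revocations, joins):
    """H_round(n) = H(GK, tau, U, R, J)."""
    return H(gk, tree_pks, updates, revocations, joins)

# Two rounds of group dynamics as recorded by the miners.
rounds = [
    round_hash("gk_1", ("pk_A", "pk_B"), ("update:A",), (), ()),
    round_hash("gk_2", ("pk_A2", "pk_B"), (), ("revoke:C",), ("join:D",)),
]

# H_trans(0) = 0; H_trans(n) = H(H_trans(n-1) || H_round(n)).
h_trans = "0"
for h_round in rounds:
    h_trans = H(h_trans, h_round)

# Any UAV replaying the same history from the blockchain reaches the same
# transcript hash, so a divergent view of the tree is detected immediately.
replay = "0"
for h_round in rounds:
    replay = H(replay, h_round)
assert replay == h_trans
```

Because each transcript hash commits to the entire prefix of rounds, two members agree on the full history of updates, revocations, and joins as soon as their current H_trans values match.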
2) The block structure. The block structure is shown in Fig. 2. The block
header consists of the hash of the previous block H(Block_{n−1}), the block
generator, a timestamp, the miner's signature on the block, and the transcript hash
representing the history of the group, H_trans(n − 1). The body consists of a series
of transactions, where each transaction includes the type of operation, the initiator
of the operation, the legal UAVs, the public keys of the ratchet tree, and the round
hash.

Fig. 2. The block structure.

4.3 Group Key Update


In some cases, such as the detection of a malicious drone or the leakage of a
UAV's secret key, the key pairs must be updated to enhance system security.
Suppose UAV_up wants to update its own key pair. UAV_up selects a seed value
s_up and feeds s_up into H to generate the private key sk_0 and the seed value of
the next node: (sk_0 ‖ s_1) ← H(s_up), and iterates in this way up to the root. When
UAV_up has updated the keys of all nodes on its path, it sends the newly computed
ratchet tree state and the update request to the miner of this round. The request
sends the secret value of each node on its path to the copath of that node, together
with the identifier of the update transaction that already exists in the block, and
each drone first confirms whether the transaction exists. After successful
verification, each drone decrypts the received message step by step until the shared
key in the root node is computed. The specific steps are as follows:
Step 1: UAV_up makes an update request

up_request = {request, Sig(sgk_up, enc(pk_m, request))}

to the miner in this round.
Step 2: After receiving the request, the miner UAV_m verifies the legal identity
of the UAV. If the verification fails, it discards the change request; after the
verification passes, the miner allows the update operation and broadcasts
Sig(sgk_m, pk_up, admit) to the group.
Step 3: After receiving the permission, UAV_up 1) chooses a uniformly random
s_up = s_0, 2) computes (sk_i ‖ s_{i+1}) ← H(s_i) for i ∈ (0, d − 1), then encrypts
the secrets and sends them to the copath, c_i ← UE(o'_i.pk, s_i), where o_i
represents a node on the path of UAV_up and o'_i represents the corresponding
copath node, and 3) UAV_up computes pk_i ← UKG(sk_i) and refreshes the public
keys of the ratchet tree.
Step 4: The UAVs obtain the ciphertext belonging to their own path,
s_i ← UD(o'_i.sk, c_i), and compute (sk_i ‖ s_{i+1}) ← H(s_i) until they compute
the GK.
Step 5: UAV_m generates the update transaction

(Update, pk_up, H_trans(k), Sig(sgk_m, Update, pk_up, H_trans(k)))

which includes the operation type, the initiator of the operation, and the transcript
hash. UAV_m uploads the transaction to the block. Every group member can
verify the transcript hash to ensure that their stored ratchet tree is up-to-date
and in a consistent state with the rest of the UAVs. Figure 3 shows the
ratchet tree in action:

Fig. 3. The ratchet tree.
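The hash-ratchet at the core of this update can be sketched as follows: the updater expands its seed along the direct path via (sk_i ‖ s_{i+1}) ← H(s_i), and any member that receives the seed for its level derives the same root key. The UPKE encryption to the copath is abstracted away here, and SHA-512 stands in for H:

```python
import hashlib

def prg_split(seed: bytes):
    """(sk_i || s_{i+1}) <- H(s_i): one SHA-512 output split into a node key
    and the seed for the next node up the path."""
    digest = hashlib.sha512(seed).digest()
    return digest[:32], digest[32:]

def derive_path(seed: bytes, depth: int):
    """Node keys for a direct path of `depth` nodes; the last key acts as GK."""
    keys, seeds, s = [], [], seed
    for _ in range(depth):
        seeds.append(s)
        sk, s = prg_split(s)
        keys.append(sk)
    return keys, seeds

depth = 3                                   # path length to the root (toy value)
s_up = b"fresh seed chosen by UAV_up"
keys, seeds = derive_path(s_up, depth)
gk_updater = keys[-1]                       # new group key at the root

# A member hanging off the level-1 copath node receives s_1 (delivered via
# UPKE in the real protocol; the encryption step is abstracted away here)
# and hashes upward to the very same group key.
received = seeds[1]
member_keys, _ = derive_path(received, depth - 1)
assert member_keys[-1] == gk_updater
```

Note that each copath recipient learns only the seeds at and above its own level, never the updater's lower path secrets, which is what limits the blast radius of any single compromise.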


Decentralized Continuous Group Key Agreement for UAV Ad-Hoc Network 63

4.4 Malicious Drone Revocation


Suppose UAVn detects malicious behavior in UAVx; UAVn then makes a
whistle-blowing request to the miner UAVm:

    wb_request = {request, Sig(sgkn, enc(pkx, request))}

When the number of revocation requests exceeds half of the total number of
drones, UAVm recalculates the key pairs on the path of the malicious drone as
if updating its own path, so the ratchet tree enters a new state in which group
communication excludes the malicious drone. The state of the ratchet tree after
revocation is shown in Fig. 4. The miner generates a revocation transaction
containing:

    (Revo, pkn, Htrans(k), Sig(sgkm, Revo, pkn, Htrans(k)))
A problem arises here: when the miner recalculates the path of the malicious
drone, it acquires secret values outside its own path, and those nodes become
contaminated by the miner. If a miner taints multiple paths, this can seriously
undermine the security of group communication. The tainted path is therefore
marked as pathx tainted by UAVm, and this mark is included in the Revo
transaction. When a new miner appears, the tainted paths are refreshed.
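The majority rule that triggers the revocation can be stated as a one-line check (a sketch; the function name and vote bookkeeping are illustrative, not part of the protocol specification):

```python
def revocation_approved(whistleblow_votes: int, total_drones: int) -> bool:
    """A revocation takes effect once the number of whistle-blowing
    requests exceeds half of the total number of drones."""
    return whistleblow_votes > total_drones / 2

assert revocation_approved(4, 7) is True
assert revocation_approved(3, 7) is False
assert revocation_approved(4, 8) is False  # exactly half does not suffice
```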

Fig. 4. Malicious drone revocation.

4.5 Drones Join


While the drone swarm is performing a mission, resources may unexpectedly run
short, at which point new drones need to join. Suppose UAVj needs to join the
current group: it first completes registration on the ground and flies to the
specified location, and the current miner adds a new leaf representing UAVj.
UAVj selects a uniformly random sj, derives all secret values on its path, and
sends cj ← UEnc(oj.pk, sj) to the copath. UAVm organizes the public state of
the entire tree and generates a Join transaction:

    (Join, pkj, Htrans(k), Sig(sgkm, Join, pkj, Htrans(k)))

5 Security Analysis
In this section, we conduct a detailed security analysis of the key agreement
protocol proposed in this paper.
Consistency: All group members obtain the same group key in each round,
regardless of key updates, member additions, or the revocation of malicious
drones. After each dynamic operation, a new transaction is generated by miners
and uploaded to the blockchain. This ensures that the current group key is
acknowledged by all drones, and once uploaded, the key becomes immutable
and traceable.
Forward Secrecy: Each round of group key updates is driven by the PRF, and the
update process follows (ski ‖ si+1) ← H(si). Since the PRF is forward-secure,
even if an adversary obtains ski or si+1, it cannot compute the previous secret
si. Furthermore, even if the current group key is compromised, the adversary
cannot decrypt past session data, as each round's key is independently
generated by the PRF.
Post-compromise Security: The entire group key agreement is structured based on
the ratchet tree, with the root node representing the shared session key of the
group. If the current session key is compromised, any group member can initiate
a new update operation. After the update, the ratchet tree regains security,
and the leaked session key no longer affects the security of future
communications.

5.1 Formal Security Analysis


The GSD Game. The Generalized Selective Decryption (GSD) game is a security
game for symmetric encryption schemes introduced by Panjwani in 2007 [5].
In the GSD game, there are n keys k1, k2, . . . , kn; the adversary can
adaptively corrupt keys and ask for encryptions Eki(kj) of keys under other
keys. The adversary's task is to distinguish keys (which it cannot trivially
compute) from random. The adversary A can make three types of queries during
the game:

1) Encryption query: At any time, the adversary A may issue an encryption
request of the form (i, j); in response, B encrypts kj under ki using fresh
randomness to obtain a ciphertext c, and makes c available to A; and
2) Corruption query: A may also request any key initially generated by B by
issuing a request of the form corrupt(i). Upon receiving the request, B reveals
the key ki to A; and
3) Challenge query: A is entitled to initiate a single challenge request of the
form challenge(i). For this request, B responds according to the bit b: if
b = 0, B provides the real key ki to A; if b = 1, B chooses a uniformly random
ri from the set {0, 1}^κ and sends ri to A.

At the end of the game, the adversary outputs a guess bit b′.
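The three oracles can be sketched as a toy challenger B (illustrative only: keys are random 16-byte strings, a one-time-pad XOR stands in for the encryption scheme, and a real reduction would additionally track the key graph to exclude trivial wins):

```python
import os

class GSDChallenger:
    """Toy challenger B for the GSD game over n symmetric keys."""

    def __init__(self, n: int, b: int):
        self.keys = [os.urandom(16) for _ in range(n)]
        self.b = b  # the hidden challenge bit

    def encrypt(self, i: int, j: int) -> bytes:
        # Models E_{k_i}(k_j); XOR stands in for the real encryption.
        return bytes(x ^ y for x, y in zip(self.keys[i], self.keys[j]))

    def corrupt(self, i: int) -> bytes:
        return self.keys[i]  # reveal k_i to the adversary

    def challenge(self, i: int) -> bytes:
        # b = 0: the real key k_i; b = 1: a uniformly random value r_i.
        return self.keys[i] if self.b == 0 else os.urandom(16)

B = GSDChallenger(n=4, b=0)
assert B.challenge(2) == B.corrupt(2)  # with b = 0 the real key comes back
```

The adversary wins if its guess bit b′ matches b noticeably better than chance; the theorem below bounds this advantage for UACGKA.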

Theorem 1. Assume that H is a (th, εh)-secure pseudo-random generator and Σ is
a (tcpa, εcpa)-CPA-secure updatable public key encryption scheme. Then UACGKA
is (t, c, n, ε)-adaptively GSD-secure, where ε = 2cn(εh + εcpa) and
t = th = tcpa.

Proof. A launches only one challenge request. Recall that the update operation
for a node of depth d processes a randomly chosen initial value s0 into a
sequence of values

    s0 →H (sk0 ‖ s1) →H (sk1 ‖ s2) →H · · · →H (skd−1 ‖ sd),

where GK = sd is the dynamic group key. In addition, the update process
utilizes the CPA-secure UPKE algorithm to encrypt each secret value si under
the key corresponding to the node on the copath:

    s1 →UE c1, s2 →UE c2, · · · , sd →UE cd.

The total hybrid loss from the CPA security of the UPKE scheme is

    Σ_{i=0}^{d−2} 2^i · εcpa = (2^{d−1} − 1)εcpa = (n/2 − 1)εcpa ≤ (n − 2)εcpa,    (1)

and the total hybrid loss from the pseudo-random generator is

    Σ_{i=1}^{d−1} 2^i · εh = (2^d − 2)εh = (n − 2)εh ≤ nεh.    (2)

From (1) and (2), we get

    εd ≤ ε0 + (n − 2)εcpa + nεh ≤ 1/2 + (n − 2)εcpa + nεh.

Therefore,

    ε := |Pr[A wins] − 1/2| = |εd − 1/2| ≤ (n − 2)εcpa + nεh.
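The counting in Eqs. (1) and (2) can be sanity-checked numerically for a full binary ratchet tree with n = 2^d leaves (the unit losses εcpa and εh are left symbolic; we only verify the multiplicities):

```python
def hybrid_counts(d: int):
    """Return the number of CPA hybrids (Eq. 1) and PRG hybrids (Eq. 2)
    for a full binary tree of depth d with n = 2**d leaves."""
    n = 2 ** d
    cpa = sum(2 ** i for i in range(0, d - 1))  # = 2^(d-1) - 1 = n/2 - 1
    prg = sum(2 ** i for i in range(1, d))      # = 2^d - 2   = n - 2
    assert cpa == n // 2 - 1 <= n - 2
    assert prg == n - 2 <= n
    return cpa, prg

assert hybrid_counts(4) == (7, 14)  # n = 16 leaves
```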

Table 1. Comparison of computation and communication costs.

    Scheme   | Comp. Cost                              | Comm. Cost
    ---------|-----------------------------------------|---------------
    CRA-DGK  | (2n + 2)Te + (3 + n)Tsm + 4Th           | 32(n − 1)²
    TGKA     | 8nTche + 4nTe/s + 17nTh + tTsr          | 300n − 192
    TAAGKA   | 6nTbp + 9nTsm + 12nTh                   | 428n + 96
    BGKA     | n(n − 1)(Tsm + Tsa + 2Te/s)             | 16n(n − 1)
    Ours     | (log(n) + 1)Te + (log(n) + 2)Th + Tsig  | 64 log(n) + 96

6 Performance Evaluation
In this part, we will assess the efficiency of the UACGKA protocol. Addition-
ally, we will contrast UACGKA with the current methods CRA-DGK, TGKA,
TAAGKA, and BGKA to demonstrate the practicality of UACGKA.

6.1 Experiment Settings

All experiments were conducted on a server equipped with an AMD Ryzen 7 6800H
with Radeon Graphics at 3.20 GHz, 16 GB RAM, and Windows 11. We used the
pairing-based cryptography library PBC to implement the underlying encryption
algorithms, where the asymmetric encryption uses the secp256k1 elliptic curve
to achieve a security strength of 128 bits. Considering both efficiency and the
ease of implementing signature aggregation, we chose the secp256k1-based
Schnorr signature for the signing and verification process. In addition, the
standard SHA-256 is used as the hash function.

6.2 Evaluation of the GKA Protocol Efficiency


1) Comp. Cost. For UACGKA, a UAV first initiates an update to the miner and
needs to refresh the secret values of all nodes on its own path and send them
to the copath. The drone's computation consists of log(n) + 1 encryptions,
log(n) + 2 hashes, and one signature, so the total computational overhead is
(log n + 1)Te + (log n + 2)Th + Tsig.
The evaluations of the other schemes are similar and are not detailed here.
Fig. 5 illustrates the computational costs of the different schemes as the
number of drones varies. It can be observed that UACGKA has significant
advantages over the other schemes in terms of computational overhead.

Fig. 5. Computation cost.

2) Comm. Cost. In UACGKA, each drone needs to send a request
Sig(sgkup, enc(pkm, Message)) to the miner and to send secret values to the
copath. The total size of these messages is 64 log(n) + 96. The remaining
schemes are analyzed similarly. Table 1 shows the relevant results, which are
also visualized in Fig. 6. It can be observed that as the number of UAVs grows,
the communication overhead of UACGKA rises accordingly, but its performance
advantage over existing methods remains substantial.
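The "Ours" row of Table 1 can be written directly as functions of n (a sketch assuming log is base-2, matching the ratchet-tree depth; Te, Th, and Tsig are kept symbolic as operation counts):

```python
import math

def uacgka_comp_cost(n: int) -> dict:
    """Per-update operation counts: (log n + 1) encryptions,
    (log n + 2) hashes, and one signature."""
    d = math.ceil(math.log2(n))
    return {"Te": d + 1, "Th": d + 2, "Tsig": 1}

def uacgka_comm_cost(n: int) -> int:
    """Message size in bytes: 64 * log(n) + 96."""
    return 64 * math.ceil(math.log2(n)) + 96

assert uacgka_comp_cost(8) == {"Te": 4, "Th": 5, "Tsig": 1}
assert uacgka_comm_cost(8) == 288
# Logarithmic growth: doubling the swarm adds only one unit of each cost.
assert uacgka_comm_cost(16) - uacgka_comm_cost(8) == 64
```

The logarithmic dependence on n is what separates the proposed scheme from the linear and quadratic rows of Table 1.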

Fig. 6. Communication cost.

7 Conclusion
The UACGKA protocol proposed in this paper offers a secure and efficient
key management solution for self-organizing UAV networks. By integrating

blockchain technology, we eliminate the risk of single points of failure and ensure
decentralized key management. The introduction of ratchet trees reduces com-
munication overhead, while the combination of pseudo-random generators and
updatable public-key encryption (UPKE) schemes enhances the forward security
of communications. Security proofs and performance evaluations further validate
the applicability and efficiency of UACGKA in dynamic environments. Future
work will focus on optimizing and practically deploying UACGKA across various
UAV application scenarios.

Acknowledgments. This work is supported by the National Key R&D Program of


China (No. 2023YFB2703700), the National Natural Science Foundation of China (Nos.
U21A20465, 62402444, 62402448), the Fundamental Research Funds of Zhejiang Sci-
Tech University under Grants Nos. 22222266-Y, 24222238-Y, the Program for Lead-
ing Innovative Research Team of Zhejiang Province (No. 2023R01001), the Zhejiang
Provincial Natural Science Foundation of China (Nos. LQ24F020009, LQ24F020012)
and the “Pioneer” and “Leading Goose” R&D Program of Zhejiang (No. 2023C01119).

References
1. Wang, C., Zhou, T., Shen, J., Wang, W., Zhou, X.: Searchable and secure edge
pre-cache scheme for intelligent 6G wireless systems. Futur. Gener. Comput. Syst.
140, 129–137 (2023)
2. Tian, C., Jiang, Q., Li, T., Zhang, J., Xi, N., Ma, J.: Reliable PUF-based mutual
authentication protocol for UAVS towards multi-domain environment. Comput.
Netw. 218, 1389–1286 (2022)
3. Tan, Y., Liu, J., Kato, N.: Blockchain-based lightweight authentication for resilient
UAV communications: architecture, scheme, and future directions. IEEE Wirel.
Commun. 29(3), 24–31 (2022)
4. Zhang, C., et al.: UAV swarm-enabled collaborative secure relay communications
with time-domain colluding eavesdropper. IEEE Trans. Mob. Comput. 23(9),
8601–8619 (2024)
5. Panjwani, S.: Tackling adaptive corruptions in multicast encryption protocols. In:
Proceedings of the 4th Conference on Theory of Cryptography, pp. 21–40 (2007)
6. Semal, B., Markantonakis, K., Akram, R.N.: A certificateless group authenticated
key agreement protocol for secure communication in untrusted UAV networks. In:
Proceedings of the 37th Digital Avionics Systems Conference (DASC), pp. 1–8
(2018)
7. Alwen, J., Coretti, S., Jost, D., Mularczyk, M.: Continuous group key agreement
with active security. IACR Cryptol. ePrint Arch. (2020)
8. Wang, C., Shen, J., Vijayakumar, P., Gupta, B.B.: Attribute-based secure data
aggregation for isolated IoT-enabled maritime transportation systems. IEEE Trans.
Intell. Transp. Syst. 24(2), 2608–2617 (2023)
9. Tan, Y., Wang, J., Liu, J., Kato, N.: Blockchain-assisted distributed and
lightweight authentication service for industrial unmanned aerial vehicles. IEEE
Internet Things J. 9(18), 16928–16940 (2022)
10. Xu, Z., Liang, W., Li, K.C., Xu, J., Zomaya, A.Y., Zhang, J.: A time-sensitive
token-based anonymous authentication and dynamic group key agreement scheme
for industry 5.0. IEEE Trans. Industr. Inf. 18(10), 7118–7127 (2021)

11. Wang, L., Tian, Y., Zhang, D., Yanhua, L.: Constant-round authenticated and
dynamic group key agreement protocol for D2D group communications. Inf. Sci.
503, 61–71 (2019)
12. Kajita, K., Emura, K., Ogawa, K., Nojima, R., Ohtake, G.: Continuous group key
agreement with flexible authorization and its applications. In: Proceedings of the
9th ACM International Workshop on Security and Privacy Analytics, pp. 3–13
(2023)
13. Alwen, J., et al.: Cocoa: concurrent continuous group key agreement. In: Proceed-
ings of the 41st Annual International Conference on the Theory and Applications
of Cryptographic Techniques, pp. 815–844 (2022)
14. Klein, K., et al.: Keep the dirt: tainted treekem, adaptively and actively secure
continuous group key agreement. In: Proceedings of the 2021 IEEE Symposium on
Security and Privacy (SP), pp. 268–284 (2021)
15. Zhang, L., Zhu, T., Zhang, H., Xiong, P., Zhou, W.: Fedrecovery: differentially
private machine unlearning for federated learning frameworks. IEEE Trans. Inf.
Forensics Secur. 18, 4732–4746 (2023)
16. Chen, H., Zhu, T., Liu, C., Shui, Yu., Zhou, W.: High-frequency matters: attack
and defense for image-processing model watermarking. IEEE Trans. Serv. Comput.
17(4), 1565–1579 (2024)
17. Zhang, C., et al.: UAV swarm-enabled collaborative secure relay communications
with time-domain colluding eavesdropper. IEEE Trans. Veh. Technol. 23(9), 1536–
1233 (2024)
18. Tan, Y., Liu, J., Kato, N.: Blockchain-based key management for heterogeneous
flying ad hoc network. IEEE Trans. Industr. Inf. 17(11), 7629–7638 (2020)
19. Zhang, Z., et al.: TAGKA: threshold authenticated group key agreement protocol
against member disconnect for UANET. IEEE Trans. Veh. Technol. 72(11), 14987–
15001 (2023)
20. Semal, B., Markantonakis, K., Akram, R.N.: A certificateless group authenticated
key agreement protocol for secure communication in untrusted UAV networks. In:
Proceedings of the 37th Digital Avionics Systems Conference (DASC), pp. 1–8
(2018)
21. Feng, C., Liu, B., Guo, Z., Yu, K., Qin, Z., Choo, K.: Blockchain-based cross-
domain authentication for intelligent 5G-enabled internet of drones. IEEE Internet
Things J. 9(8), 6224–6238 (2022)
22. Cohn-Gordon, K., Cremers, C., Garratt, L., Millican, J., Milner, K.: On ends-to-
ends encryption: asynchronous group messaging with strong security guarantees.
In: Proceedings of the 2018 ACM SIGSAC Conference on Computer and Commu-
nications Security, pp. 1802–1819 (2018)
Efficient Homomorphic Approximation of Max Pooling
for Privacy-Preserving Deep Learning

Peng Zhang1, Dongyan Qiu1, Ao Duan1, and Hongwei Liu2(B)

1 The Guangdong Key Laboratory of Intelligent Information Processing, College of
Electronics and Information Engineering, Shenzhen University, Shenzhen 518060,
Guangdong, China
[email protected], {2300432061,2110436077}@email.szu.edu.cn
2 College of Big Data and Internet, Shenzhen Technology University,
Shenzhen 518118, Guangdong, China
[email protected]

Abstract. Privacy-Preserving Deep Learning (PPDL) using Fully
Homomorphic Encryption (FHE) addresses potential data privacy exposure risks
associated with deploying deep learning models in untrusted cloud
environments. FHE-based PPDL enables users to encrypt their data locally,
allowing cloud service providers to perform computations directly on the
encrypted data without ever accessing it. However, FHE faces challenges in
efficiently handling nonlinear computations, which are essential to deep
learning frameworks, particularly in implementing activation functions and
max pooling layers. To address this challenge, we improve the homomorphic max
pooling approximation scheme HMaxPool and propose an efficient homomorphic
max pooling algorithm, denoted HMaxPool+. Specifically, we first redefine the
approximation of the maximum function, and then design a sub-algorithm to
address the potential output expansion issue. The experimental results
validate that both the homomorphic max pooling algorithm itself and the
corresponding homomorphic SqueezeNet neural network demonstrate improvements
in computational efficiency and accuracy.

Keywords: Privacy-Preserving Deep Learning · Fully Homomorphic
Encryption · Max Pooling · Homomorphic Approximation

1 Introduction

Deep learning, known for its ability to automatically extract features from large
datasets and achieve state-of-the-art performance, has become a transformative
force in numerous fields, including image recognition [14] and medical prediction
[15]. However, the training and inference processes in deep learning often depend
on sensitive user data, which raises significant privacy concerns. To address these
challenges, researchers have developed a range of techniques aimed at balancing
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 70–80, 2025.
https://doi.org/10.1007/978-981-96-4566-4_5
Efficient Homomorphic Approximation of Max Pooling 71

data privacy with deep learning computations. Among these approaches are
federated learning [21], differential privacy [1], secure multi-party computation
[18], and fully homomorphic encryption [12].
Deep learning models often rely on the cloud’s extensive computational
resources to perform inference on user data. However, to avoid exposing sen-
sitive data to untrusted cloud environments, Privacy-Preserving Deep Learning
(PPDL) utilizing Fully Homomorphic Encryption (FHE) offers a more secure
and effective solution for such scenarios [19]. FHE enables computations to be
performed on ciphertext, with the decrypted output yielding identical results to
those obtained from computations on plaintext [7]. FHE schemes, such as BGV
[3] and CKKS [6], support both homomorphic addition and multiplication, mak-
ing them suitable for executing a variety of linear transformations commonly
used in deep learning algorithms. However, FHE faces challenges in efficiently
handling nonlinear computations [17], which are essential to deep learning frame-
works, particularly in implementing activation functions and max pooling layers.
To address this issue, polynomials are utilized to approximate nonlinear
functions homomorphically. Cryptonets [8] replaces ReLU with the square activation
function. To reduce the approximation error, Chabanne et al. [4] and CryptoDL
[9] utilize high-order polynomials for approximating activation functions, which
inevitably leads to increased time cost. To strike a balance between approxi-
mation error and time efficiency, Wu et al. [22] propose trainable second-order
polynomials for this purpose. However, this additional training process is time-
consuming. Recently, Lee et al. [13] proposed to use a combination of low-order
minimax polynomials to efficiently and accurately approximate the activation
function.
As illustrated, previous research on homomorphic polynomial approximations
of nonlinear functions has primarily focused on activation functions, with less
attention given to the max pooling operation. Lee et al. [13] leverage the
homomorphic Sign function to develop an approximation algorithm, HMaxPool, for
homomorphic max pooling. Building on their work, we focus on developing a more
efficient and higher-precision approximation tailored specifically to the max
pooling function. Our main contributions are as follows:

– An efficient homomorphic max pooling algorithm, HMaxPool+, is proposed,
which redefines the approximation of the maximum function and mitigates
potential output expansion through a four-step process: initialization, loop
execution, power expansion, and recursive calls. Experimental results for
HMaxPool+ verify the improved computational efficiency and reduced
approximation error.
– We employ the SqueezeNet neural network model on the CIFAR-10 dataset to
evaluate performance. We first design a homomorphic version of SqueezeNet by
replacing the traditional max pooling layer with our proposed homomorphic
algorithm HMaxPool+. Experimental results demonstrate that homomorphic
SqueezeNet using HMaxPool+ achieves higher inference accuracy and
computational efficiency compared with HMaxPool.
72 P. Zhang et al.

2 Related Works
Deep learning on encrypted data is suitable for scenarios where users need to
delegate their data to untrusted parties for computation. It enables inference on
the user’s data while preserving their privacy. CryptoNets proposed by Gilad et
al. [8] is a pioneering work that applies deep learning models to encrypted data
for the first time. CryptoNets uses the fully homomorphic encryption scheme
of Bos et al. [2] and approximates ReLU with x², enabling neural networks to
evaluate ciphertext, thereby ensuring data privacy. This method not only proves
the feasibility of applying deep learning algorithms in ciphertext environments,
but also lays the foundation for subsequent research.
FHE is particularly good at processing linear calculations, such as addition
and multiplication, which makes it perform well when performing linear trans-
formations. However, nonlinear calculations, especially activation functions and
pooling operations common in deep learning, pose greater challenges to FHE.
CryptoNets [8] uses square activation function instead of ReLU and scaled aver-
age pooling instead of max pooling. Building on the foundation of CryptoNets,
Hesamifard et al. [9] employed a modified Chebyshev polynomial to approximate
the derivative of ReLU. Wu et al. [22] proposed a trainable second-order poly-
nomial approximation of the ReLU activation function, effectively reducing pre-
cision loss. Recently, Lee et al. [13] proposed to use a combination of low-order
minimax polynomials to efficiently and accurately approximate the activation
function. Phoenix [11] uses an approximate symbolic function called SgnHE
to homomorphically implement Argmax to replace softmax, and proposes a
homomorphic Argmax approximation algorithm called ArgmaxHE. Based on
ArgmaxHE, Zhang et al. [23] optimized the homomorphic Argmax approxima-
tion and proposed an improved homomorphic Argmax approximation algorithm,
which significantly reduces the inference delay.
However, most of these studies focused on the polynomial approximation of
the activation functions and Argmax, paying less attention to the max pooling
function. Lee et al. [13] leverage the homomorphic Sign function to develop an
approximation algorithm, HMaxPool, for homomorphic max pooling. Although this
method solves the basic problem of homomorphic max pooling, it involves a
large number of polynomial operations and has low computational efficiency.
Therefore, building on HMaxPool, we aim to propose a more efficient
homomorphic approximation of max pooling.

3 An Efficient Homomorphic Max Pooling Algorithm


The max pooling algorithm essentially finds the maximum value in a set of
vectors. In general, max(a, b) is a nonlinear function that determines the
maximum of two ciphertext values a and b. However, fully homomorphic
encryption (FHE) struggles with nonlinear computations on ciphertext. The
maximum function can be represented using the Sign function as
max(a, b) = ((a + b) + (a − b) · Sign(a − b))/2. In the homomorphic max pooling
algorithm HMaxPool proposed by Lee et al. [13], the nonlinear function Sign(x)
rithm HM axP ool proposed by Lee et al. [13], the non-linear function Sign(x)
Efficient Homomorphic Approximation of Max Pooling 73

is approximated by a polynomial pα(x), where α denotes the approximation
precision. Thus, the maximum function is approximated as:

    maxαapp(a, b) = ((a + b) + (a − b) · pα(a − b)) / 2    (1)
To enhance computational efficiency in the homomorphic max pooling algorithm,
we note that each comparison invokes the maxαapp(a, b) function, so reducing
the number of scalar multiplications improves overall performance. Therefore,
we redefine the approximation of the maximum function as:

    maxαapp(a, b) = (a + b) + (a − b) · pα(a − b)    (2)

Obviously, the output of Eq. 2 amplifies the correct result by a factor of two.
To obtain the correct max pooling result, the proposed homomorphic max pooling
function HMaxPool+(x1, . . . , xn) is defined as:

    HMaxPool+(x1, . . . , xn) = HMaxExt(x1, . . . , xn, n) / 2^⌈log n⌉    (3)

where n represents the number of elements involved in the comparison, and
x1, . . . , xn are the ciphertexts corresponding to μ1, . . . , μn.
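In plaintext, Eq. (2) can be sketched with any polynomial sign approximation. The snippet below uses the iterated polynomial f(x) = (3x − x³)/2 as a stand-in for pα; it pushes nonzero inputs in (−1, 1) toward ±1, whereas the actual scheme uses composite minimax polynomials.

```python
def p_alpha(x: float, rounds: int = 8) -> float:
    """Illustrative sign approximation on (-1, 1): iterating
    f(x) = (3x - x^3)/2 drives any nonzero x toward +1 or -1."""
    for _ in range(rounds):
        x = (3 * x - x ** 3) / 2
    return x

def max_app(a: float, b: float) -> float:
    """Eq. (2): the doubled maximum, saving the scalar division by 2."""
    return (a + b) + (a - b) * p_alpha(a - b)

# max_app returns roughly 2 * max(a, b) for inputs in (-1, 1):
assert abs(max_app(0.3, -0.5) - 2 * 0.3) < 1e-3
assert abs(max_app(-0.2, 0.6) - 2 * 0.6) < 1e-3
```

Dropping the division by 2 in every comparison is what removes most scalar multiplications; the accumulated factor of 2 per round is removed once at the end via Eq. (3).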
The challenge is that when an odd number of items are involved in a comparison
round, pairwise comparisons amplify their outputs, while the remaining item,
which does not participate in a comparison, remains unchanged. To address this
issue, we design the sub-algorithm HMaxExt(x1, . . . , xn, n) in Algorithm 1,
which can be divided into four steps:

– Initialization (Line 1): Set the parameters odd, group, and i.
– Loop Execution (Lines 2–8): While i < n, execute two conditional
statements:
  • The first condition compares pairs of elements via maxαapp(a, b) and
    outputs the larger (doubled) value.
  • The second condition handles the case where n is odd, grouping the
    remaining single element separately.
– Power Expansion (Lines 9–11): If the number of groups (group) is even
and n is odd, apply a power-of-2 expansion to the separately grouped element.
– Recursive Call (Lines 12–16): If the number of groups equals 1, return
X[0]; otherwise, recursively call the algorithm with group as the number of
elements to be compared in the next round.

Assuming the number of input ciphertexts is n = 5 and the initial input data
is X = [x1, x2, x3, x4, x5], the algorithm runs for three rounds.
In the first round, we invoke HMaxExt(x1, x2, x3, x4, x5, 5). After the
loop execution, we obtain X = [max1,2 = maxαapp(x1, x2), max3,4 =
maxαapp(x3, x4), x5], with odd = 1 and group = 3. Based on the values of group
and odd, we perform the power expansion, update X = [max1,2, max3,4, x5 ∗ 2],
and call HMaxExt again with the new X.

Algorithm 1. Extended Homomorphic Max Pooling HMaxExt(x1, . . . , xn, n)

Require: X = [x1, . . . , xn]
Ensure: For each xi ∈ X, return the ciphertext extension of the maximum value
among μ1, . . . , μn
1: Initialize: odd = 0, group = 0, i = 0
2: while i < n do
3:   if i ≤ n − 2 then
4:     X[group] = maxαapp(X[i], X[i + 1]); group = group + 1; i = i + 2
5:   else
6:     X[group] = X[i]; odd = odd + 1; group = group + 1; i = i + 1
7:   end if
8: end while
9: if group%2 = 0 and n%2 = 1 then
10:   X[group − 1] = X[group − 1] × 2^odd; odd = 0
11: end if
12: if group = 1 then
13:   return X[0]
14: else
15:   Invoke HMaxExt(X, group)
16: end if

In the second round, we execute HMaxExt(max1,2, max3,4, x5 ∗ 2, 3). After
the loop execution, we obtain X = [max1,2,3,4 = maxαapp(max1,2, max3,4), x5 ∗ 2],
with odd = 1 and group = 2. We again perform the power expansion on x5
and get X = [max1,2,3,4, x5 ∗ 2 ∗ 2].
In the third round, we execute HMaxExt(max1,2,3,4, x5 ∗ 2 ∗ 2, 2) and output
the result X[0] = maxαapp(max1,2,3,4, x5 ∗ 2 ∗ 2). In this case, we have:

    HMaxPool+(x1, . . . , x5) = X[0] / 2^⌈log 5⌉.    (4)
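The recursion can be mirrored in plaintext to check the 2^⌈log n⌉ scaling of Eq. (3). The sketch below simplifies Algorithm 1's bookkeeping: any element left unpaired in a round is simply doubled, so every survivor stays on the same power-of-two scale, and an exact doubled max stands in for maxαapp.

```python
import math

def hmax_ext(xs, max2):
    """Plaintext sketch of HMaxExt: pairwise 'doubled max' per round,
    doubling any unpaired element so all survivors share one scale."""
    n = len(xs)
    if n == 1:
        return xs[0]
    nxt = [max2(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
    if n % 2 == 1:
        nxt.append(xs[-1] * 2)  # power expansion for the leftover element
    return hmax_ext(nxt, max2)

def hmax_pool_plus(xs):
    result = hmax_ext(list(xs), lambda a, b: 2 * max(a, b))
    return result / 2 ** math.ceil(math.log2(len(xs)))

# n = 5 mirrors the worked example: x5 is boosted to x5 * 4 before the
# final comparison, and the result is rescaled by 2^ceil(log2(5)) = 8.
assert abs(hmax_pool_plus([0.1, -0.4, 0.7, 0.2, 0.35]) - 0.7) < 1e-9
assert hmax_pool_plus([1.0, 2.0, 3.0, 4.0]) == 4.0
```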

4 Experimental Analysis for Homomorphic Max Pooling


In this section, all experiments were conducted on the same server, equipped
with an Intel(R) Core(TM) i9-10900X CPU with 10 cores and 256 GB of RAM. We
utilized the Microsoft SEAL library [20] and employed the RNS-CKKS homomorphic
encryption scheme [5] for data encryption. To ensure a 128-bit security level
for all ciphertexts, we set the ring polynomial degree N to 131072.
Additionally, for the approximate sign polynomial Sign(a − b) used in the
maxαapp(a, b) function, we selected the precision parameter α = 13 with
composite polynomial degrees {15, 15, 27}.
Common max pooling kernels in neural networks typically include sizes of
2 × 2, 3 × 3, and 4 × 4. For each kernel size, assume the input μ is an array of
m × m randomly distributed points in the range of −1 to 1, and denote X as
the ciphertext of μ. The ciphertext X is then fed into the homomorphic max

pooling function. The experiment was conducted 1,000 times to calculate the
average computational time and error.
The computational time for HMaxPool and HMaxPool+ is presented in Table 2. It
is evident that the HMaxPool+ algorithm improves computational efficiency to
some extent. This improvement arises from a significant reduction in the
number of scalar multiplications. Specifically, with a pooling kernel size of
4 × 4, HMaxPool requires 15 homomorphic scalar multiplications, while the
newly proposed HMaxPool+ requires only 1, as shown in Table 1.

Table 1. Number of homomorphic scalar multiplications for homomorphic max
pooling algorithms

    Pooling Kernel Size | 2 × 2 | 3 × 3 | 4 × 4
    --------------------|-------|-------|------
    HMaxPool [13]       | 3     | 8     | 15
    HMaxPool+           | 1     | 2     | 1

However, compared with ciphertext-ciphertext multiplication, scalar
multiplication is far less time-consuming. As a result, the time spent on
scalar multiplications represents a relatively small fraction of the overall
computational time of the homomorphic max pooling algorithm, so the adoption
of HMaxPool+ yields only a marginal improvement in total computational time.

Table 2. Computational time for homomorphic max pooling algorithms

    Pooling Kernel Size | 2 × 2    | 3 × 3     | 4 × 4
    --------------------|----------|-----------|----------
    HMaxPool [13]       | 50805 ms | 134380 ms | 249409 ms
    HMaxPool+           | 50528 ms | 133913 ms | 248527 ms

The computational error of HMaxPool and HMaxPool+, defined as the difference
between the output value and the expected value, is presented in Table 3.
Regardless of the pooling kernel size, HMaxPool+ consistently achieves lower
computational error. This improvement is attributed to the fact that the
homomorphic maximum function maxαapp(a, b) amplifies X during each iteration,
making it easier for Sign to make accurate decisions. However, the overall
improvement remains limited, as both homomorphic max pooling algorithms use
high-precision parameters and therefore exhibit very low computational errors
in general.

Table 3. Computational error for homomorphic max pooling algorithms

    Pooling Kernel Size | 2 × 2    | 3 × 3    | 4 × 4
    --------------------|----------|----------|---------
    HMaxPool [13]       | 0.000005 | 0.000005 | 0.000004
    HMaxPool+           | 0.000004 | 0.000002 | 0.000003

5 Privacy-Preserving Deep Learning Utilizing Homomorphic Max Pooling

To compare the computational accuracy and time of privacy-preserving neu-


ral networks with different homomorphic max pooling algorithms, the network
model, experimental design, and experimental analysis will be described in this
section. The platform used for the experiments is consistent with that in Sect. 4.

5.1 Network Model

SqueezeNet is a lightweight convolutional neural network model proposed by


Forrest N. Iandola et al. in 2016 [10]. The model structure is shown in Fig. 1,
where the numbers on the arrows represent the number of channels. The basic
structure of SqueezeNet consists of a convolutional layer (Conv layer) followed
by multiple Fire modules connected in series. Each Fire module is composed of
a squeeze layer and an expand layer. The squeeze layer uses 1 × 1 convolutions
to encode the input, which greatly reduces the number of input channels, while
the expand layer consists of both 1 × 1 and 3 × 3 convolutions, enabling effi-
cient feature learning. Each convolution layer is followed by a ReLU activation
function. After certain Fire modules, a max pooling layer with a 3×3 kernel and
a stride of 2 is typically used to downsample the feature maps. After the 10th
Fire module, a Global Average Pooling is used to reduce the spatial dimension of
each feature map to a single value and generate a fixed-size output. The Softmax
layer converts the output into a probability distribution on the client side.

Fig. 1. The SqueezeNet network model.



To perform inference over encrypted data, a homomorphic approximation of the SqueezeNet model was designed in [16]. In SqueezeNet, both the convolution layer calculations and the Fire module calculations involve linear computations of features and model parameters, which are friendly to FHE, and the global average pooling layer is similarly FHE-friendly. As the ReLU activation function is nonlinear, we replace ReLU with a trainable quadratic polynomial activation function ax^2 + bx, where the parameters a and b are learned during the training phase. This replacement keeps the model homomorphic-friendly while still allowing it to capture the nonlinear characteristics of the data. For the max pooling layer, we replace it with the two homomorphic max pooling algorithms HMaxPool and HMaxPool+, respectively. Finally, the homomorphic SqueezeNet network model is shown in Fig. 2.

Fig. 2. Homomorphic SqueezeNet network model.
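As a quick illustration of this replacement (not the paper's trained model), the following compares ReLU with the quadratic activation ax^2 + bx for hypothetical coefficient values; in the actual model, a and b are learned during training.

```python
a, b = 0.25, 0.5   # hypothetical coefficients; the model learns these during training

def poly_act(x):
    # FHE-friendly activation: uses only additions and multiplications
    return a * x * x + b * x

def relu(x):
    # The original activation: nonlinear and not directly evaluable under FHE
    return max(x, 0.0)

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):.2f}  poly={poly_act(x):.2f}")
```

Unlike ReLU, the polynomial can be evaluated directly on CKKS ciphertexts because it is composed solely of homomorphic additions and multiplications.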

5.2 Experimental Design

The homomorphic SqueezeNet network model was trained on the MNIST and CIFAR-10 datasets. The MNIST dataset consists of 60,000 training examples and 10,000 testing examples of handwritten digit images, categorized into 10 classes representing the digits 0 to 9; each image is a grayscale picture size-normalized to 28 × 28 pixels. The CIFAR-10 dataset contains 50,000 training images and 10,000 test images, each 32 × 32 pixels with three RGB channels, divided into 10 categories such as airplanes, cars, and birds. The key training parameters include a learning rate of 0.001 and a weight decay of 0.0005. The training process was conducted over 200 iterations to optimize the parameters of each layer. Additionally, the coefficients a and b of the trainable polynomial in the activation layer were learned during training. These parameters together define the structure and functionality of the network.
During the inference phase, each prediction was performed on a single image,
generating a corresponding result. Once all predictions were made, the results
were grouped into 5 batches, each containing 128 predictions. For each batch, the success probability of the predictions was analyzed, and the overall accuracy was determined by averaging the success probabilities of the 5 batches. As two different homomorphic max pooling algorithms, HMaxPool and HMaxPool+, were utilized, the two experimental schemes are as follows.
1. PPNN+HMaxPool: Use the homomorphic SqueezeNet network model
with the max pooling algorithm HM axP ool proposed by Lee et al. [13].
2. PPNN+HMaxPool+ : Use the homomorphic SqueezeNet network model
with our proposed max pooling algorithm HM axP ool+ .
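Since all 5 batches have the same size, the batch-averaged accuracy described above reduces to a short computation; the per-batch success counts below are hypothetical, used only to show the arithmetic.

```python
BATCH_SIZE = 128
batch_hits = [126, 125, 126, 127, 125]   # hypothetical successes per batch

# Overall accuracy = mean of the per-batch success rates.
batch_acc = [hits / BATCH_SIZE for hits in batch_hits]
overall = sum(batch_acc) / len(batch_acc)
print(f"overall accuracy: {overall:.2%}")
```

With equal-size batches, this mean of rates coincides with the pooled rate over all 640 predictions.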

5.3 Experimental Analysis


The performance of the homomorphic approximation of SqueezeNet with HMaxPool and HMaxPool+ is shown in Table 4. Their accuracy is quite similar, while the computational time improves noticeably. The slight gain in accuracy is attributed to the increased input size for the homomorphic comparison function in the homomorphic max pooling algorithm; the overall gain remains limited, as the inference accuracy in PPNN primarily depends on the complexity of the network model. The improvement in computational time is mainly due to the reduction in the number of homomorphic scalar multiplications: the max pooling layer involves a large number of them, and the HMaxPool+ algorithm significantly reduces this number. Relative to the total computational time, however, the improvement is modest, because the largest contributors to computational time in PPNN are operations such as homomorphic multiplication and homomorphic rotation, whose cost is significantly higher than that of a single homomorphic scalar multiplication.

Table 4. Computational Accuracy and Time for Homomorphic SqueezeNet with Homomorphic Max Pooling Algorithms

Experimental Schemes  Dataset   Computational Accuracy  Computational Time (ms)
PPNN+HMaxPool         MNIST     98.21%                  87735
PPNN+HMaxPool+        MNIST     98.34%                  86017
PPNN+HMaxPool         CIFAR10   83.58%                  428532
PPNN+HMaxPool+        CIFAR10   83.61%                  427613

Overall, the definition of max^α_app(a, b) optimizes computation by minimizing scalar multiplications; in conjunction with our proposed algorithm HMaxPool+, it significantly enhances the efficiency of the homomorphic max pooling process. Moreover, as the difference between a and b increases, this definition offers improved approximation quality, which, together with our proposed algorithm, results in a slight enhancement in the accuracy of the homomorphic max pooling process. Furthermore, beyond SqueezeNet, our proposed algorithm is also applicable to homomorphic versions of other neural network models containing max pooling layers.

6 Conclusion

Aiming to reduce the number of homomorphic scalar multiplications required by the homomorphic max pooling algorithm, we first redefine the approximation of the maximum function and then design the HMaxExt algorithm to deal with the potential output expansion issue, yielding an efficient homomorphic max pooling algorithm denoted HMaxPool+. Theoretical analysis and experimental results both indicate that, compared to HMaxPool, our algorithm HMaxPool+ significantly reduces the number of homomorphic scalar multiplications across different pooling kernel sizes, thereby reducing computational time, while preserving the inference accuracy of the homomorphic SqueezeNet neural network.

Acknowledgments. This work was supported by the Science and Technology Pro-
gram Project of Shenzhen under Grant SZWD2021012.

References
1. Blanco-Justicia, A., Sánchez, D., Domingo-Ferrer, J., Muralidhar, K.: A critical
review on the use (and misuse) of differential privacy in machine learning. ACM
Comput. Surv. 55(8), 1–16 (2022)
2. Bos, J.W., Lauter, K., Loftus, J., Naehrig, M.: Improved security for a ring-based
fully homomorphic encryption scheme. In: Cryptography and Coding: 14th IMA
International Conference, IMACC 2013, Oxford, UK, 17–19 December 2013. Pro-
ceedings 14, pp. 45–64. Springer (2013)
3. Brakerski, Z., Gentry, C., Vaikuntanathan, V.: (leveled) fully homomorphic encryp-
tion without bootstrapping. ACM Trans. Comput. Theory (TOCT) 6(3), 1–36
(2014)
4. Chabanne, H., De Wargny, A., Milgram, J., Morel, C., Prouff, E.: Privacy-
preserving classification on deep neural network. Cryptology ePrint Archive (2017)
5. Cheon, J.H., Han, K., Kim, A., Kim, M., Song, Y.: A full RNS variant of approx-
imate homomorphic encryption. In: Selected Areas in Cryptography–SAC 2018:
25th International Conference, Calgary, AB, Canada, 15–17 August 2018, Revised
Selected Papers 25, pp. 347–368. Springer (2019)
6. Cheon, J.H., Kim, A., Kim, M., Song, Y.: Homomorphic encryption for arithmetic
of approximate numbers. In: Advances in Cryptology–ASIACRYPT 2017: 23rd
International Conference on the Theory and Applications of Cryptology and Infor-
mation Security, Hong Kong, China, 3–7 December 2017, Proceedings, Part I 23,
pp. 409–437. Springer (2017)
7. Gentry, C.: A fully homomorphic encryption scheme. Stanford University (2009)
8. Gilad-Bachrach, R., Dowlin, N., Laine, K., Lauter, K., Naehrig, M., Wernsing, J.:
Cryptonets: applying neural networks to encrypted data with high throughput and
accuracy. In: International Conference on Machine Learning, pp. 201–210. PMLR
(2016)
9. Hesamifard, E., Takabi, H., Ghasemi, M.: Cryptodl: deep neural networks over
encrypted data. arXiv preprint arXiv:1711.05189 (2017)
10. Iandola, F.N.: SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model size. arXiv preprint arXiv:1602.07360 (2016)
11. Jovanovic, N., Fischer, M., Steffen, S., Vechev, M.: Private and reliable neural net-
work inference. In: Proceedings of the 2022 ACM SIGSAC Conference on Computer
and Communications Security, pp. 1663–1677 (2022)
12. Lee, J.W., et al.: Privacy-preserving machine learning with fully homomorphic
encryption for deep neural network. IEEE Access 10, 30039–30054 (2022)
13. Lee, J., Lee, E., Lee, J.W., Kim, Y., Kim, Y.S., No, J.S.: Precise approximation of
convolutional neural networks for homomorphically encrypted data. IEEE Access
11, 62062–62076 (2023)
14. Li, Y.: Research and application of deep learning in image recognition. In: 2022
IEEE 2nd International Conference on Power, Electronics and Computer Applica-
tions (ICPECA), pp. 994–999. IEEE (2022)
15. Liu, T., Siegel, E., Shen, D.: Deep learning and medical image analysis for covid-19
diagnosis and prediction. Annu. Rev. Biomed. Eng. 24(1), 179–201 (2022)
16. Lou, Q., Jiang, L.: Hemet: a homomorphic-encryption-friendly privacy-preserving
mobile neural network architecture. In: International Conference on Machine
Learning, pp. 7102–7110. PMLR (2021)
17. Marcolla, C., Sucasas, V., Manzano, M., Bassoli, R., Fitzek, F.H., Aaraj, N.: Survey
on fully homomorphic encryption, theory, and applications. Proc. IEEE 110(10),
1572–1609 (2022)
18. Pillai, S.E.V.S., Polimetla, K.: Enhancing network privacy through secure multi-
party computation in cloud environments. In: 2024 International Conference on
Integrated Circuits and Communication Systems (ICICACS), pp. 1–6. IEEE (2024)
19. Ranbaduge, T., Vatsalan, D., Ding, M.: Privacy-preserving deep learning based
record linkage. IEEE Trans. Knowl. Data Eng. (2023)
20. Microsoft SEAL (release 4.1). Microsoft Research, Redmond, WA (2023). https://github.com/Microsoft/SEAL
21. Wen, J., Zhang, Z., Lan, Y., Cui, Z., Cai, J., Zhang, W.: A survey on federated
learning: challenges and applications. Int. J. Mach. Learn. Cybern. 14(2), 513–535
(2023)
22. Wu, W., Liu, J., Wang, H., Tang, F., Xian, M.: Ppolynets: achieving high prediction
accuracy and efficiency with parametric polynomial activations. IEEE Access 6,
72814–72823 (2018)
23. Zhang, P., Duan, A., Lu, H.: An efficient homomorphic argmax approximation for
privacy-preserving neural networks. Cryptography 8(2), 18 (2024)
Blockchain-Aided Revocable Threshold Group
Signature Scheme for the Smart Grid

Xiaozhi Deng1 , Qinqin Wu1 , Yi Tang1 , and Yongbao Wang2(B)


1 Power Dispatching and Controlling Center of Guangdong Power Grid Company Limited,
Guangzhou 510000, China
2 AISINO CO. LTD, Beijing 100195, China

[email protected]

Abstract. In the smart grid, electronic consumption data, including electric vehicle charging payments, residential electricity payments and so on, involves the privacy of each individual, and the leakage of private consumption data is a serious issue in the grid system. Therefore, aiming at secure and efficient data collection, this paper presents a threshold group signature scheme with a revocation mechanism suitable for the smart grid. Combining blockchain technology and group signatures, it addresses the problems of semi-trust among the authorities and the anonymity of entities. In addition, the revocation mechanism enables the tracking and punishment of malicious users. Finally, security analysis and simulations demonstrate its properties and advantages in terms of security features as well as computational and communication costs.

Keywords: Threshold Group Signature · Revocation · Anonymity · Smart Grid ·


Blockchain

1 Introduction
The smart grid utilizes the digital technology to improve the reliability, efficiency, and
sustainability of electricity services. It integrates various technologies such as smart
meters, sensors, and automated control systems to enhance the management of electricity
supply and demand [1–3].
The smart grid integrates the renewable energy sources, such as solar and wind power.
As these sources can be variable and decentralized, the smart grid’s real-time data anal-
ysis helps balance supply and demand, ensuring a stable energy supply. Additionally, it
enhances grid resilience by quickly identifying and responding to outages or disruptions.
This not only reduces downtime, but also improves overall customer satisfaction.
One of the key features of the smart grid is its ability to enable two-way communi-
cation between utilities and consumers. This allows for real-time monitoring of energy
usage, empowering consumers to make informed decisions about their energy consump-
tion [4, 5]. Smart meters and charging stations, for instance, provide detailed usage data,
which can lead to cost savings and better energy management.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 81–89, 2025.
https://doi.org/10.1007/978-981-96-4566-4_6
82 X. Deng et al.

With the continuous development of modern industry, the security of smart grids is
directly related to the economy and security of the country. As smart grids rely heav-
ily on advanced communication technologies and interconnected devices, the potential
for cyber threats increases. The smart grid still faces some intractable problems in secure data transmission. First, smart meters, charging stations, and other electrical equipment are installed in unattended public places, which makes them susceptible to hacker attacks, such as tampering with setup parameters and stealing the privacy of consumers. Second, all transactions are stored on traditional servers, making them a primary target for attacks and creating a single-point-of-failure issue in the smart grid. Finally, malicious consumers are difficult for the smart grid system to detect, and it is a challenge to hold them accountable and impose appropriate penalties afterwards.
To address these challenges above, this paper proposes a threshold group signature
scheme with revocation mechanism suitable for the smart grid, aiming to achieve a secure
and efficient data management. The main contributions are as follows.
1) This work adopts a threshold group signature scheme to realize the real-time commu-
nication and record the consumer’s behaviors. The approach enhances the security
of consumer management and supports the batch message authentication, thereby
improving the real-time data interactions in smart grid. Moreover, the pseudonym is
used for preserving privacy of the consumers, reducing the risk of privacy breaches.
Pseudonymized consumers can still be analyzed and processed without exposing the
true identities.
2) Blockchain technology is employed to eliminate the issue of partial trust between
the authorities and consumers. Additionally, it takes the distributed storage model,
which avoids the issue of a single point of failure.
3) The blacklist mechanism is applied to revoke the identities of malicious users with
disruptive behavior and expired users. Real-time monitoring is conducted to detect
the malicious consumers, ensuring the long-term secure and stable operation of the
smart grid system.

2 Related Work

The smart grid is one of the typical applications of the Internet of Things. Su et al. [6] proposed a distributed attribute-based signature scheme for distributed electricity trading, in which consumers can freely choose their trading counterparts without leaking their real identities. Taking
advantage of the immutability property of blockchain, it achieved signature verifiability
independent of dynamic changes in attributes. Zhu et al. [7] designed a lightweight ver-
ifiable certificate-based privacy-preserving data aggregation scheme without pairings
for smart grids. The protocol could collect the real-time energy consumption data to
provide services securely, while ensuring the privacy of individual users. Combining the
hash-based message authentication code and public key encryption with equality testing,
Zhang and Wei [8] presented a fine-grained privacy-preserving data aggregation scheme
supporting multifunctionality, such as data privacy, integrity, unlinkability, and fault tol-
erance. Tomar et al. [9] integrated fog computing and blockchain into the smart grid,
and provided a blockchain-based certificateless aggregate signature scheme that enjoyed
the authentication, integrity, and privacy. Verma et al. [10] designed the first certificate-
based data aggregation scheme for smart grid for the purpose of eliminating complex
certificate maintenance and key escrow issue. Additionally, this scheme can resist the
attacks launched by a malicious certification authority. Shen et al. [11] employed con-
sortium blockchains and identity-based signatures to construct a blockchain-assisted
secure device authentication for the industrial Internet of Things (IIoT). In their pro-
tocol, the blockchain functions as an information-sharing platform, establishing trust
between different domains. Xue et al. [12] proposed a secure and efficient cross-domain
authentication scheme based on two cooperative blockchains for medical consortium
systems. Moreover, the chameleon hash was used to redact the user's state, and an anonymity mechanism was utilized to enhance security and to trace malicious users.

3 Payment Model in Smart Grid

Fig. 1. Blockchain-Aided Payment System Model in Smart Grid

Fig. 1 illustrates the blockchain-based electricity payment architecture in the smart grid. We introduce the entities involved in this scenario as follows.
Power Plant (PP): The power plant is responsible for generating the power resource,
and then transmitting these resources to each power consumer for charging the vehicles,
lighting the houses, and so on.
Power Consumers (PC): These consumers have access to the electric resources in smart
grid. After being authenticated, they can pay the bill according to the used electricity.
Each consumer is equipped with an electricity meter to record the usage amount of the
electricity.
Payment Verifier (PV): The verifier collects the electricity usage amounts and payment vouchers. It then verifies the legality of these data and uploads the transactions to the blockchain. For illegal payments, the payment verifier notifies the grid administrator to take measures against the illegal consumers.
Grid Administrator (GA): The grid administrator is responsible for managing power consumers and maintaining the blockchain in the smart grid; it generates the group public key and holds the master private key. Moreover, it handles identity registration and revocation requests. In this context, the grid administrator is a fully trusted entity.
Blockchain: The blockchain primarily maintains the identity registration information of power consumers, providing immutability and addressing the trust issues among individuals. By periodically updating the registration information of the consumers, the blockchain ensures the validity and legitimacy of the non-revoked users.

4 The Concrete Constructions


This scheme employs a group signature scheme with a threshold value (t, n). The set ID = {id1, id2, · · ·, idt} consists of t members, where t ≤ n. In this section, the threshold group signature protocol involves the following participants: the Grid Administrator (GA), the Power Consumers (idi), the Payment Verifier (PV), and the electricity payment messages M = {m1, m2, . . . , mt} that need to be signed.

4.1 Setup
GA selects a cyclic group G with a generator P of prime order q. Then, it randomly chooses a ∈ Zq* to calculate the public key PKGA = aP, and sets the master private key MSK = a. Define two secure hash functions

H0 : {0, 1}* × Zq* → Zq*, H1 : {0, 1}* → Zq*. (1)

The common system parameters are

params = {G, q, P, PKGA, H0, H1}. (2)

4.2 Key Extraction


GA randomly selects a secret value κi ∈ Zq* for each group member idi, computes PKidi = κi P, and sets the secret key SKidi = κi. Next, it defines a (t − 1)-degree polynomial

fi(x) = ai0 + ai1 x + ai2 x^2 + · · · + ai(t−1) x^(t−1) (3)

with coefficients aij ∈ Zq*, j = 0, 1, · · ·, t − 1. Thus, the secret value of each member is SVidi = fi(0) = ai0, with the corresponding public value PVidi = ai0 P. At last, each group member idi generates a temporary pseudonym PIDi = H0(idi, ai0), which preserves the user's identity privacy and is stored on the blockchain.

4.3 Group Signature


Each group member idi uses its private key to sign the message

σi = SKidi · H0(idi, ai0) + H1(mi) · SVidi. (4)

Then, it broadcasts the signature set {σi}, where i ∈ [1, t].

4.4 Verification
PV receives the signatures {σi}, i ∈ [1, t], and needs to verify their correctness. PV computes

σ = Σ_{i=1}^{t} σi,  V = Σ_{i=1}^{t} H1(mi) PVidi (5)

and the verification message

Y = Σ_{i=1}^{t} PIDi PKidi. (6)

When the number of messages is greater than or equal to t, the validity of the signatures is checked by verifying

σP = Y + V. (7)
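As a sanity check of the algebra in (4)-(7), the following sketch instantiates the scheme in a toy additive group Zq with "scalar multiplication" kP = k·P mod q. This is insecure and for illustration only; a single SHA-256-based hash H stands in for both H0 and H1, and all identifiers and messages are hypothetical.

```python
import hashlib
import random

q = 2**61 - 1   # toy prime modulus standing in for the group order
P = 7           # toy "generator"; the group G is emulated as (Z_q, +)

def H(*parts):
    # Stand-in for H0/H1: hash the inputs to a scalar mod q.
    digest = hashlib.sha256("|".join(map(str, parts)).encode()).hexdigest()
    return int(digest, 16) % q

t = 3
members = []
for i in range(t):                     # key extraction for each member id_i
    sk = random.randrange(1, q)        # SK_id = kappa_i
    a0 = random.randrange(1, q)        # secret value a_i0 = f_i(0)
    members.append({
        "id": f"id{i}",
        "sk": sk, "pk": sk * P % q,    # PK_id = kappa_i * P
        "a0": a0, "pv": a0 * P % q,    # PV_id = a_i0 * P
        "pid": H(f"id{i}", a0),        # PID = H0(id, a_i0)
    })

msgs = [f"payment-{i}" for i in range(t)]
# Equation (4): sigma_i = SK_id * H0(id, a_i0) + H1(m_i) * SV_id
sigs = [(m["sk"] * m["pid"] + H(msg) * m["a0"]) % q
        for m, msg in zip(members, msgs)]

# Equations (5)-(7): aggregate and check sigma*P == Y + V
sigma = sum(sigs) % q
Y = sum(m["pid"] * m["pk"] for m in members) % q
V = sum(H(msg) * m["pv"] for m, msg in zip(members, msgs)) % q
assert sigma * P % q == (Y + V) % q
print("batch signature verifies")
```

The check succeeds because each term σi·P = κi·H0(idi, ai0)·P + H1(mi)·ai0·P contributes exactly one summand to Y and one to V, mirroring the derivation of (7).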

4.5 Revocation

Provided that the user with pseudonym PIDi* is a malicious participant, GA can directly reveal the member's true identity idi* by using the corresponding secret value ai0*, and add the tuple (idi*, σi*) to the blacklist BL = {(idi*, σi*)}.

5 Security Analysis
5.1 Immutability

Blockchain is built on distributed networks, where each node holds a copy of the chain. Adding a new block to the chain requires consensus from the majority of nodes in the network. Once data is encapsulated in a block, any attempt to tamper with it would be detected, making tampering virtually impossible.
5.2 Unlinkability
Suppose that an adversary can query the pseudonym of a target member idi. The generation of the pseudonym PIDi = H0(idi, ai0) involves a secret value ai0 generated for the member, so the attacker cannot associate PIDi with the user's real identity. This effectively prevents any direct link between the temporary identities used in the system and the actual identities of members, thereby maintaining unlinkability.

5.3 Unforgeability
This scheme applies group signature technology to verify the electricity payment messages efficiently. Because of the private key SKidi = κi and the secret value ai0, an attacker cannot forge a valid signature {σi} during signature generation. Hence, the scheme eliminates the possibility of payment data forgery and guarantees the authenticity and immutability of the payment data.

5.4 Privacy Protection


This scheme applies the temporary pseudonym PIDi to generate the valid signature
without disclosing its actual identity or other sensitive data. This mechanism significantly
enhances the ability to prevent identity theft and data breaches, thereby protecting the
privacy of the member in the smart grid.

5.5 Revocability
Provided that a malicious member idi forges the electricity payment messages M̃ = {m1, m2, . . . , mt}, its behavior is recorded and uploaded to the blacklist BL = {(idi, σi)}, and its corresponding access rights are revoked. Once idi is stored in the blacklist, it can no longer access any data or enjoy the electricity services in the smart grid system.

5.6 Man-in-the-Middle Attacks Resistance


The registration and pseudonym generation are controlled by an authoritative entity
GA, preventing the adversary from unauthorized access. Additionally, the payment data
and other sensitive messages are securely encapsulated on the blockchain. Since data
recorded on the blockchain is immutable, the adversary cannot tamper with and replace
the verified on-chain data. This significantly increases the difficulty of conducting man-
in-the-middle attacks.

5.7 DDoS Attacks Resistance


As a distributed ledger technology, the blockchain inherently possesses decentralized
characteristics. This decentralization helps mitigate the issue of single points of fail-
ure inherent in traditional centralized network structures. In this proposal, the use of
blockchain for data storage and authentication avoids reliance on a single central server,
thereby significantly reducing the potential impact of DDoS attacks.
6 Performance Evaluations
6.1 Computational Overhead
To evaluate the computational performance of this protocol, we compare it with the related schemes [11] and [12] in terms of the key generation and signature stages. The time costs of the various cryptographic operations in these schemes are denoted as follows.
Tsm : Execution time for scalar multiplication in G.
Th : Execution time for hash operations.
Tbp : Execution time for bilinear pairing operations.
Texp : Execution time for exponentiation operations.

Table 1. Computational overhead

Schemes  Key Extraction              Signature
[11]     3Tbp + 2Tsm + 2Texp + 4Th   3Tbp + 5Tsm + 2Texp + 4Th
[12]     6Tsm + 2Th                  10Tsm + 8Th
Ours     2Tsm + Th                   2Tsm + 2Th

These simulations were conducted on a PC equipped with an Intel Core i7-12700H CPU @ 2.3 GHz and 16 GB of RAM. The JPBC library was employed to measure the average computational overhead of each operation over 1000 runs: Tsm = 0.036 ms, Th = 0.002 ms, Tbp = 3.232 ms, and Texp = 2.131 ms. As shown in Table 1, the key generation overhead of Ref. [11] is 3Tbp + 2Tsm + 2Texp + 4Th = 14.038 ms, and its signature overhead is 3Tbp + 5Tsm + 2Texp + 4Th = 14.146 ms. In Ref. [12], the key extraction overhead is 6Tsm + 2Th = 0.22 ms and the signature cost is 10Tsm + 8Th = 0.376 ms. For our protocol, the computational costs are 2Tsm + Th = 0.074 ms for key extraction and 2Tsm + 2Th = 0.076 ms for signing.
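The Table 1 totals can be reproduced directly from these measured primitive timings; the following is just the arithmetic, not a benchmark.

```python
# Measured average timings of the cryptographic primitives, in ms.
T_sm, T_h, T_bp, T_exp = 0.036, 0.002, 3.232, 2.131

def cost(sm=0, h=0, bp=0, exp=0):
    # Total cost of a phase given its operation counts, rounded to 3 decimals.
    return round(sm * T_sm + h * T_h + bp * T_bp + exp * T_exp, 3)

print(cost(bp=3, sm=2, exp=2, h=4))  # [11] key extraction -> 14.038
print(cost(bp=3, sm=5, exp=2, h=4))  # [11] signature      -> 14.146
print(cost(sm=6, h=2))               # [12] key extraction -> 0.22
print(cost(sm=10, h=8))              # [12] signature      -> 0.376
print(cost(sm=2, h=1))               # ours, key extraction -> 0.074
print(cost(sm=2, h=2))               # ours, signature      -> 0.076
```

The bilinear pairings and exponentiations in [11] dominate its totals, which is why the pairing-free designs are two orders of magnitude faster.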

6.2 Communication Overhead


As shown in Fig. 2, assuming that idi and the timestamp are 8 bytes each, the hash output is 32 bytes, and an element of the cyclic group G is 64 bytes, the communication overhead of this scheme comes from the signature set {σi}, which contains t sub-signatures, where i ∈ [1, t] and t is the threshold value. Each sub-signature is 64 bytes, so transmitting the signatures requires 64t bytes; the signature length thus grows linearly with the threshold value. For Ref. [11], the messages exchanged between the device and the server include a 32-byte pseudo-anonymous identity, a 96-byte signature sent to the KGC, a 64-byte signing private key returned by the KGC, a 96-byte generated message transmitted to the server for signing, and a 96-byte signature received in return, for a total of 384n bytes, where n is the number of entities to be authenticated. In Ref. [12], the patient sends a message request {Sig(SKij), Kij, IDk, Xi, Xi} to the server, and the server responds with {Ak, βk, δk, T1}, resulting in a total communication overhead of 368k bytes, where k is the number of participants.

Fig. 2. The Comparisons on Communication Overhead
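The linear growth compared in Fig. 2 can be written down directly, using the per-message byte counts from the analysis above.

```python
# Communication overhead (bytes) as a function of the number of signers/entities.
def ours(t):
    return 64 * t    # t sub-signatures of 64 bytes each

def ref11(n):
    return 384 * n   # per-entity exchange in Ref. [11]

def ref12(k):
    return 368 * k   # per-participant exchange in Ref. [12]

for count in (1, 5, 10):
    print(count, ours(count), ref11(count), ref12(count))
```

All three schemes scale linearly, but the per-participant constant of this scheme (64 bytes) is roughly a sixth of those of [11] and [12].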

7 Conclusion

In this paper, we designed a revocable threshold group signature scheme for the blockchain-aided smart grid, addressing the shortcomings of existing electricity payment schemes in terms of privacy protection, semi-trust among the authorities, and malicious entity tracking. Additionally, the protocol possesses the immutability and unforgeability properties, resists various attacks, and significantly enhances security and reliability. Finally, it has been shown to be more efficient in a simulation environment.

Acknowledgments. This work is supported by China Southern Power Grid Technology Project Funding Grant No. 030000KC23040090 (GDKJXM20230408).

References
1. Wang, X., Blaabjerg, F.: Harmonic stability in power electronic-based power systems: concept,
modeling, and analysis. IEEE Trans. Smart Grid 10(3), 2858–2870 (2019). https://doi.org/
10.1109/TSG.2018.2812712
2. Su, H., Liao, P., Zhou, F., Yi, S., Yang, Q.: Research on intelligent defect identification model
of power grid edge side equipment based on edge computing. Electric Power Inf. Commun.
Technol. 19(04), 31–37 (2021). (in Chinese), https://doi.org/10.16543/j.2095-641x.electric.
power.ict.2021.04.005
3. Zhu, L., You, S., Yin, H., Zhao, Y., Li, F., Yao, W.: FNET/GridEye: a tool for situational aware-
ness of large power interconnection grids. In: IEEE PES Innovative Smart Grid Technologies
Europe, pp. 379–383. IEEE, The Hague, Netherlands (2020). https://doi.org/10.1109/ISGT-
Europe47291.2020.9248857
4. Al-Gburi, A., Al-Hasnawi, A., Lilien, L.: Differentiating security from privacy in Internet of Things: a survey of selected threats and controls. In: Daimi, K. (ed.) Computer and Network Security Essentials, pp. 153–172. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-58424-9_9
5. Chen, Y., Wang, W., Sun, L.: Smart grid terminal access authentication scheme based on
multiple encryption algorithms. Ind. Technol. Innov. 9(4), 93–101 (2022). (in Chinese), https://
doi.org/10.14103/j.issn.2095-8412.2022.08.012
6. Su, Q., Zhang, R., Xue, R., Sun, Y., Gao, S.: Distributed attribute-based signature with attribute
dynamic update for smart grid. IEEE Trans. Industr. Inf. 19(9), 9424–9435 (2023). https://
doi.org/10.1109/TII.2022.3228688
7. Zhu, F., Guo, D., Abuadbba, S., Yi, X., Luo, J., Kumari, S.: Lightweight verifiable privacy-
preserving data aggregation for smart grids. IEEE Internet Things J. 11(19), 31249–31259
(2024). https://doi.org/10.1109/JIOT.2024.3419161
8. Zhang, J., Wei, J.: PFDAM: Privacy-Preserving fine-grained data aggregation scheme sup-
porting multifunctionality in smart grid. IEEE Internet Things J. 11(15), 25520–25533 (2024).
https://doi.org/10.1109/JIOT.2024.3356593
9. Tomar, A., Tripathi, S., Arivarasan, K.: A blockchain-based certificateless aggregate signature
scheme for fog-enabled smart grid environment. IEEE Trans. Green Commun. Netw. 7(4),
1892–1905 (2023). https://doi.org/10.1109/TGCN.2023.3265608
10. Verma, G.K., Gope, P., Saxena, N., Kumar, N.: CB-DA: Lightweight and escrow-free
certificate-based data aggregation for smart grid. IEEE Trans. Depend. Secure Comput. 20(3),
2011–2024 (2023). https://doi.org/10.1109/TDSC.2022.3169952
11. Shen, M., Liu, H., Zhu, L., Xu, K., Yu, H., Du, X.: Blockchain-assisted secure device authen-
tication for cross-domain industrial IoT. IEEE J. Sel. Areas Commun. 38(5), 942–954 (2020).
https://doi.org/10.1109/JSAC.2020.2980916
12. Xue, L., Huang, H., Xiao, F., Wang, W.: A cross-domain authentication scheme based on
cooperative blockchains functioning with revocation for medical consortiums. IEEE Trans.
Netw. Serv. Manage. 19(3), 2409–2420 (2022). https://doi.org/10.1109/TNSM.2022.3146929
Privacy-Preserving Three-Factors
Authentication and Key Agreement
for Federated Learning

Guojun Wang1,2(B) , Guixin Jiang1 , Yushuai Zhao1 , and Yinglin Ji1


1 Yancheng Polytechnic College, Yancheng 224005, China
[email protected]
2 School of Electronics and Information Engineering, Nanjing University of Information Science and Technology, Nanjing 210044, China

Abstract. With the development of big data and distributed computing, federated learning has received widespread attention for protecting data privacy. Federated learning is a distributed machine learning framework that allows multiple participants to collaboratively train a shared model while keeping their data local. However, communication security and authentication between participants have emerged as key challenges affecting the security of federated learning systems. To address this challenge, we propose an authentication and key agreement protocol for the federated learning environment. The proposed protocol enables authentication between edge devices and the server and establishes a secure session key to protect communication between participants. Additionally, we verify the security of the session key with a formal verification tool. Finally, experimental evaluation results show that the protocol performs well in terms of computational overhead and communication latency and is suitable for large-scale federated learning systems with resource-constrained devices.

Keywords: Federated Learning · Authentication and Key Agreement · Secure Communication · Distributed Computing · Big Data
1 Introduction
As traditional centralized machine learning brings more and more data security problems, federated learning (FL) [1] is becoming widely used in distributed environments. Federated learning prevents data leakage by training models locally and transmitting only model updates. However, distributed devices are located in different networks and environments and risk being maliciously attacked. Therefore, it is necessary to verify the legitimacy of distributed devices in federated learning to prevent unauthorized devices from participating in model training or tampering with data.
Authentication and key agreement (AKA) [2] is essential for verifying the legitimacy of distributed devices. It not only ensures the authenticity of device identity, but also protects the privacy of communication between distributed devices. In federated learning, distributed devices authenticate each other through the AKA cryptographic primitive to ensure that every participant is legitimate. Distributed devices encrypt communication with a temporary session key to prevent malicious devices from impersonating their identities and stealing or tampering with model updates. Federated learning incurs a large communication overhead because the client and the server must communicate frequently to transmit locally trained model updates. Meanwhile, frequent communication is easy to attack, and tampered model updates lead to model performance degradation or data leakage.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 90–103, 2025.
https://doi.org/10.1007/978-981-96-4566-4_7
Most existing authentication and key agreement protocols [3] are designed for traditional client-server environments, and implementing them directly may lead to excessive computational and communication overheads. Therefore, an authentication protocol for federated learning should not only ensure secure data transmission and identity verification [4], but also avoid degrading device performance through excessive computational and communication overheads. Furthermore, client devices may drop out of or rejoin the system at any time due to their unstable connection status. The AKA protocol therefore needs flexible and adaptable dynamic authentication and key update mechanisms so that the system maintains stable security and communication efficiency despite constant changes in the participating clients.

1.1 Main Contributions
In this paper, we propose an AKA protocol for federated learning to protect communication security in distributed scenarios. The contributions are summarized as follows.

1. We propose an AKA protocol for federated learning with excellent computational and security performance. The proposed protocol ensures that only legitimate clients and the central server can participate in model training, preventing unauthorized access or adversary interference. Traditional authentication protocols are vulnerable to a single point of failure due to their reliance on a centralized trust authority. In contrast, the proposed protocol adopts a distributed mutual authentication approach to verify the legitimacy of each client before the shared model is updated. The proposed authentication protocol balances security and performance, providing both strong data protection and efficient computation for resource-constrained federated learning environments.
2. The session key used in federated learning communication is proved to resist both passive and active attacks. Federated learning communications are exposed to multiple attack threats due to their distributed nature. Eavesdropping attacks [5] and MITM attacks [6] can directly interfere with federated learning communications, affecting model quality and security. Through a full security analysis, we show that the session key effectively resists multiple passive and active attacks.
3. Security is further analyzed using the Automated Validation of Internet Security Protocols and Applications (AVISPA) tool. This application of AVISPA [7] ensures that our protocol meets rigorous security standards and provides strong guarantees against a wide range of potential vulnerabilities in federated learning communication. AVISPA employs a suite of formal methods, including model checking and constraint solving, to rigorously analyze protocols against sophisticated adversarial strategies.
4. To evaluate security and performance, the proposed protocol is compared with related works, demonstrating superior sustainability for federated learning. Our protocol stands out for its balance of security and performance, as evidenced by a comparative analysis with related works. The results demonstrate that it is more sustainable for federated learning (FL) environments, meeting the stringent demands of scalability, resource efficiency, and resilience against evolving attack vectors. By combining enhanced security, strong performance, and adaptability, our protocol outperforms related works, offering a sustainable solution tailored to federated learning's unique needs.

1.2 Road-Map of This Paper
The remainder of this paper is organized as follows: Sect. 2 provides the necessary preliminaries relevant to this study. Section 3 defines the system and threat models, establishing the context for the proposed protocol. Section 4 introduces the proposed authentication and key agreement protocol, detailing its architecture and mechanisms. The security analysis, formal verification, and performance analysis are presented in Sects. 5, 6 and 7, respectively. Finally, Sect. 8 concludes the paper.

2 Preliminaries
In this section, we present the background knowledge and definitions related to this paper.

2.1 Notations
This section provides a uniform description of the main symbols used in this
paper. The specific description of each notation is displayed in Table 1.

2.2 Fuzzy Extractor
The fuzzy extractor [8] consists of two algorithms, described as follows.
– Gen(bioi) = (σi, τi): Gen is the generation function of the fuzzy extractor, which extracts a stable key σi and auxiliary information τi from the input biometric bioi for subsequent key reproduction.
Table 1. Notations

Notation Description
Ui ith user of the edge device
EDi ith edge device
ES edge server
CSj jth cloud server
idi identifier of the ith user
pwi password of the ith user
bioi biometric of the ith user
H(·) one-way hash function
Ti timestamp

– Rep(bioi, τi) = σi: Rep is the reproduction function of the fuzzy extractor, which recovers the key σi consistent with the generation phase from the input biometric bioi and the auxiliary information τi.
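To make the Gen/Rep interface concrete, the following Python sketch implements a toy code-offset fuzzy extractor over bit lists, using a simple repetition code for error tolerance. The repetition parameter and the SHA-256 key derivation are illustrative assumptions, not the construction of [8].

```python
import hashlib
import secrets

REP = 5  # each key bit is repeated 5 times; tolerates up to 2 flips per group

def _encode(bits):
    # repetition-code encoding of the key bits
    return [b for b in bits for _ in range(REP)]

def _decode(coded):
    # majority vote per group of REP bits
    return [int(sum(coded[i:i + REP]) > REP // 2)
            for i in range(0, len(coded), REP)]

def gen(bio_bits):
    """Gen(bio) -> (sigma, tau): a stable key plus public helper data."""
    k = len(bio_bits) // REP
    key_bits = [secrets.randbelow(2) for _ in range(k)]
    tau = [b ^ c for b, c in zip(bio_bits, _encode(key_bits))]
    sigma = hashlib.sha256(bytes(key_bits)).hexdigest()
    return sigma, tau

def rep(bio_bits_noisy, tau):
    """Rep(bio', tau) -> sigma, provided bio' is close to the enrolled bio."""
    coded = [b ^ t for b, t in zip(bio_bits_noisy, tau)]
    return hashlib.sha256(bytes(_decode(coded))).hexdigest()
```

With a 50-bit biometric, flipping a single bit of the input still yields the same σi, because each group's majority vote absorbs the error.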

3 Specific Definitions
In this section, we give the specific definitions of the system, including the system model and the adversary model. Each model is described as follows.

3.1 System Model
The system model is shown in Fig. 1, which has three entities. The specific
description of each entity is given as follows.

– Edge device (EDi). In federated learning, edge devices generate data and perform distributed computation. Edge devices train models locally and upload encrypted model updates without transmitting raw data, thereby protecting data privacy.
– Edge server (ES). In federated learning, edge servers are important intermediaries connecting edge devices to the cloud server. The edge server collects and aggregates model updates from multiple edge devices to relieve the burden on the cloud server and lower communication overheads. Moreover, the edge server also handles registration for the cloud servers and edge devices in this protocol.
– Cloud server (CSj). The cloud server, as the central aggregation node, collects local model updates or gradients from distributed edge devices or edge servers. The cloud server generates the optimized global model through a global aggregation algorithm and distributes it to each participating device for the next round of iterative training.
Fig. 1. System Model.

3.2 Adversary Model
In this paper, the widely accepted Dolev-Yao (DY) threat model [9] is adopted. The adversary A is given powerful abilities to break the security of communications, described as follows.
– The adversary A has such powerful analysis ability that it can access infor-
mation stored in the edge device.
– The adversary A has strong eavesdropping ability and can tamper with infor-
mation transmitted over the public channel.

4 Proposed Protocol
There are three phases in the proposed protocol. The detailed construction of
each stage will be shown as follows (Fig. 2).

Fig. 2. User Registration.

4.1 Setup Phase
Edge server ES, as a trusted entity, is responsible for initializing the edge device EDi and cloud server CSj. For each cloud server CSj in federated learning, ES generates a random identifier Cidj and random numbers {γj, δj}. When cloud server CSj has finished registration, it stores {γj, δj} in its local database.

Fig. 3. Login and Authentication.

4.2 User Registration Phase
The user Ui needs to complete registration with edge server ES when a legitimate user of the edge device EDi wants to join the system. User registration consists of three steps, described as follows.
Step 1: Ui chooses idi, pwi and imprints the biometric bioi through a specific device. Then, Ui generates random numbers α, β ∈ Zp∗ and computes Gen(bioi) = (σi, τi), Ridi = H(idi ‖ σi ‖ α), and Rpwi = H(pwi ‖ σi) ⊕ β. Finally, {Ridi, Rpwi} is sent to edge server ES.
Step 2: When the registration request is received from Ui, ES chooses random numbers ni, mi for each legitimate user Ui. ES computes M1 = ni ⊕ Ridi ⊕ Rpwi, M2 = H(Ridi ‖ Rpwi ‖ ni) and M3 = H(mi ‖ Ridi). Finally, edge server ES stores {Ridi, ni, mi} in its database and sends the smartcard SCi = {M1, M2, M3, H(·)} to Ui.
Step 3: When the smartcard is received from ES, Ui computes M1 = M1 ⊕ β, M3 = M3 ⊕ H(idi ‖ pwi), M4 = α ⊕ H(idi ‖ σi). After the above computations are completed, Ui updates the smartcard to SCi = {M1, M2, M3, M4, τi, H(·)} to complete registration.
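The registration computations above can be sketched in a few lines; SHA-256 stands in for H(·), σi is treated as the 32-byte output of the fuzzy extractor, and all names are illustrative assumptions.

```python
import hashlib
import secrets

def H(*parts):
    # one-way hash over the concatenation of its byte-string inputs
    return hashlib.sha256(b"||".join(parts)).digest()

def xor(a, b):
    # bitwise XOR of equal-length byte strings
    return bytes(x ^ y for x, y in zip(a, b))

# --- user side (Step 1): sigma_i would come from Gen(bio_i) ---
id_i, pw_i, sigma_i = b"alice", b"pw123", secrets.token_bytes(32)
alpha, beta = secrets.token_bytes(32), secrets.token_bytes(32)
Rid = H(id_i, sigma_i, alpha)
Rpw = xor(H(pw_i, sigma_i), beta)      # Rpw_i = H(pw_i || sigma_i) XOR beta

# --- edge server side (Step 2) ---
n_i, m_i = secrets.token_bytes(32), secrets.token_bytes(32)
M1 = xor(xor(n_i, Rid), Rpw)
M2 = H(Rid, Rpw, n_i)
M3 = H(m_i, Rid)

# --- user side (Step 3): blind the card values before storing ---
M1 = xor(M1, beta)                     # removes beta: n_i ^ Rid ^ H(pw||sigma)
M3 = xor(M3, H(id_i, pw_i))
M4 = xor(alpha, H(id_i, sigma_i))
smartcard = {"M1": M1, "M2": M2, "M3": M3, "M4": M4}
```

After Step 3 the card value M1 equals ni ⊕ Ridi ⊕ H(pwi ‖ σi), which is exactly what the login phase unblinds to recover ni.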
4.3 Login and Authentication Phase

When the legitimate user Ui has logged into the edge device with their identity credentials, authentication is completed between EDi, ES and CSj. Authentication is divided into seven steps, and the detailed construction of each step is shown as follows.
Step 1: The user inserts the smartcard SCi into the reading device and inputs idi, pwi and bioi. EDi computes σi = Rep(bioi, τi), α∗ = M4 ⊕ H(idi ‖ σi), Rid∗i = H(idi ‖ σi ‖ α∗), Rpw∗i = H(pwi ‖ σi), n∗i = M1 ⊕ Rid∗i ⊕ Rpw∗i, M2∗ = H(Rid∗i ‖ Rpw∗i ‖ n∗i), M5 = M3 ⊕ H(idi ‖ pwi) and checks whether M2∗ = M2.
Step 2: If M2∗ = M2 holds, EDi generates a random number ri and the timestamp T1. Then, EDi chooses the accessed cloud Cidj and computes M6 = M5 ⊕ ri, M7 = H(Rid∗i ‖ ri ‖ M3 ‖ T1) and M8 = Cidj ⊕ H(M3 ‖ Rid∗i ‖ T1). After the above computations are completed, the authentication message {Ridi, M6, M7, M8, T1} is sent to edge server ES through the public channel.
Step 3: After receiving {Ridi, M6, M7, M8, T1} from EDi, ES checks whether |T1′ − T1| ≤ ΔT, where T1′ is the receiving time. If it holds, ES extracts {ni, mi} corresponding to Ridi. Then, ES computes M3∗ = H(mi ‖ Ridi), ri∗ = M6 ⊕ M3∗, and M7∗ = H(Ridi ‖ ri∗ ‖ M3∗ ‖ T1). ES checks whether M7∗ = M7.
Step 4: When the above computations are completed and the equation holds, ES generates a random number εi and timestamp T2. Then, ES computes Cid∗j = M8 ⊕ H(M3∗ ‖ Ridi ‖ T1) and extracts {γj, δj} corresponding to Cid∗j. Finally, ES computes M9 = εi ⊕ H(δj ‖ Cid∗j), M10 = H(γj ‖ εi) ⊕ ri∗, and M11 = H(Cid∗j ‖ εi ‖ ri∗ ‖ T2) and sends the authentication message {M9, M10, M11, T2} to the cloud server.
Step 5: After receiving {M9, M10, M11, T2} from ES, CSj checks whether |T2′ − T2| ≤ ΔT. If it holds, CSj computes ε∗i = M9 ⊕ H(δj ‖ Cid∗j), ri∗ = M10 ⊕ H(γj ‖ ε∗i), M11∗ = H(Cid∗j ‖ ε∗i ‖ ri∗ ‖ T2) and checks whether M11∗ = M11.
Step 6: If the above equation holds, CSj generates a random number ηj and timestamp T3. Then, CSj computes M12 = ηj ⊕ H(ri∗ ‖ Cid∗j), sk = H(Cid∗j ‖ ri∗ ‖ ηj ‖ T3) and M13 = H(sk ‖ M12 ‖ ηj ‖ T3). After the above parameters are computed, the authentication message {M12, M13, T3} is sent to EDi.
Step 7: After receiving the authentication message from CSj, EDi checks whether |T3′ − T3| ≤ ΔT. If it holds, EDi computes ηj∗ = M12 ⊕ H(ri ‖ Cidj) and sk∗ = H(Cidj ‖ ri ‖ ηj∗ ‖ T3). Finally, EDi computes H(sk∗ ‖ M12 ‖ ηj∗ ‖ T3) and verifies whether M13 = H(sk∗ ‖ M12 ‖ ηj∗ ‖ T3). If it holds, EDi stores sk∗ in the local database.
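The key agreement of Steps 6 and 7 can be sketched as follows; SHA-256 models H(·), and ri and Cidj are assumed to already be shared between EDi and CSj after Steps 1-5. Names and sizes are illustrative assumptions.

```python
import hashlib
import secrets
import time

def H(*parts):
    # one-way hash over the concatenation of its byte-string inputs
    return hashlib.sha256(b"||".join(parts)).digest()

def xor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))

Cid_j = hashlib.sha256(b"cloud-01").digest()  # cloud identifier (32 bytes)
r_i = secrets.token_bytes(32)   # EDi's session random, known to CSj via Step 5
T3 = str(time.time()).encode()

# --- CSj, Step 6: blind eta_j, derive sk, bind both into M13 ---
eta_j = secrets.token_bytes(32)
M12 = xor(eta_j, H(r_i, Cid_j))
sk_server = H(Cid_j, r_i, eta_j, T3)
M13 = H(sk_server, M12, eta_j, T3)

# --- EDi, Step 7: unblind eta_j, derive the same sk, verify M13 ---
eta_star = xor(M12, H(r_i, Cid_j))
sk_client = H(Cid_j, r_i, eta_star, T3)
M13_check = H(sk_client, M12, eta_star, T3)
```

Because only a party holding ri can unblind ηj from M12, both sides derive the same sk, and a matching M13 authenticates CSj to EDi.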

5 Security Analysis

This section demonstrates the proposed protocol's security against various known attacks.
5.1 Privileged-Insider Attack
In this protocol, Ui sends {Ridi, Rpwi} to ES for registration, where Ridi = H(idi ‖ σi ‖ α) and Rpwi = H(pwi ‖ σi) ⊕ β. Suppose an internal adversary A obtains Ridi and Rpwi. Without the random number β and the biometric key σi, A cannot derive any information about pwi. Additionally, the one-way property of the hash function keeps idi and pwi inaccessible. Therefore, this protocol effectively resists privileged-insider attacks.

5.2 User Anonymity and Untraceability
As illustrated in Fig. 3, the authentication information transmitted over the pub-
lic channel does not disclose the identity of Ui . Furthermore, none of the param-
eters within SCi reveal any details about Ui ’s identity. By leveraging one-way
functions for messages exchanged over public channels, the proposed protocol
ensures both user anonymity [10] and untraceability [11].

5.3 Stolen Smartcard Attack
In this attack scenario, it is assumed that Ui's SCi is stolen, allowing A to access its contents through advanced analysis techniques. The smartcard SCi
contains {M1 , M2 , M3 , M4 , τi , H(·)}, where H(·) is a one-way hash function and
τi is an auxiliary parameter generated by the fuzzy extractor. Without bioi , A
is unable to reconstruct σi . Additionally, ni and mi , which are random values
selected by ES, do not reveal meaningful information to A. Furthermore, M1 ,
M2 , M3 and M4 are protected by one-way functions or random blinding tech-
niques, preventing A from extracting any useful data. Thus, even if a smartcard
is compromised, the protocol’s security remains robust.

5.4 User Impersonation Attack
Suppose the authentication message {Ridi, M6, M7, M8, T1} sent by EDi is intercepted by A, who attempts an active attack to impersonate EDi. To succeed, A must forge valid values for M6 = M5 ⊕ ri, M7 = H(Rid∗i ‖ ri ‖ M3 ‖ T1), and M8 = Cidj ⊕ H(M3 ‖ Rid∗i ‖ T1), where M3 = H(mi ‖ Ridi) is a long-term secret value generated by ES. While A can select a random ri∗ and forge M6∗, it lacks the long-term secret H(mi ‖ Ridi) generated by ES, which is essential to create M5 and pass authentication. Thus, the protocol effectively resists the Ui impersonation attack.

5.5 ES Impersonation Attack
If an adversary, denoted as A, intercepts the authentication messages M9, M10, M11, and T2 sent by ES, it might attempt an active attack to impersonate the legitimate edge server. To succeed, A must forge valid messages as follows: M9 = εi ⊕ H(δj ‖ Cid∗j), M10 = H(γj ‖ εi) ⊕ ri∗ and M11 = H(Cid∗j ‖ εi ‖ ri∗ ‖ T2), where {γj, δj} are the long-term secrets shared between ES and CSj. Although A can select a random δj∗ and generate H(δj∗ ‖ Cid∗j), it lacks knowledge of the secret values γj and δj required to construct M11 and pass CSj's authentication. As a result, the protocol is resistant to impersonation attacks targeting ES.

5.6 CSj Impersonation Attack
If an adversary A intercepts the authentication messages M12, M13 and T3 sent by CSj, it may attempt an active attack to impersonate the legitimate cloud server. To succeed, A must forge valid messages, specifically M12 = ηj ⊕ H(ri∗ ‖ Cid∗j) and M13 = H(sk ‖ M12 ‖ ηj ‖ T3), where ηj is a random value generated by CSj. Although A can select a random value and compute H(ri∗ ‖ Cid∗j), it lacks access to ηj generated by CSj and to the genuine ri∗, making it impossible to forge M12 and M13 and pass EDi's authentication. Hence, the protocol effectively resists impersonation attacks targeting CSj.

5.7 Replay Attack
Assuming A intercepts all authentication messages on the public channel, any replayed message will fail the timestamp check |Ti′ − Ti| ≤ ΔT. As a result, the protocol is protected from replay attacks.
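The freshness check itself is a one-line predicate; the window ΔT below is an arbitrary illustrative value, to be tuned per deployment.

```python
import time

DELTA_T = 5.0  # freshness window Delta-T in seconds (deployment-specific)

def is_fresh(t_sent: float, t_recv: float, delta: float = DELTA_T) -> bool:
    """Accept a message only if its timestamp lies within the window."""
    return abs(t_recv - t_sent) <= delta
```

A message received within ΔT of its timestamp is accepted; a replayed copy arriving later fails the check and is dropped.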

5.8 MITM Attack
Assume that all authentication messages {Ridi, M6, M7, M8, T1}, {M9, M10, M11, T2} and {M12, M13, T3} can be intercepted by A. Taking {Ridi, M6, M7, M8, T1} as an example, A attempts to impersonate a legitimate edge device. To achieve this, A must forge valid values for M6 = M5 ⊕ ri, M7 = H(Rid∗i ‖ ri ‖ M3 ‖ T1) and M8 = Cidj ⊕ H(M3 ‖ Rid∗i ‖ T1). While A can select a random value ri∗ and forge M6 = M5 ⊕ ri∗, it does not know mi generated by ES, which is required to forge M3 and pass the authentication. Since A does not have access to the secret values, its messages cannot be verified. Hence, the protocol can defend against the MITM attack.

5.9 Mutual Authentication
In federated learning, mutual authentication is established among EDi, ES and CSj. ES verifies EDi by checking whether M7∗ = M7, while CSj authenticates ES by confirming M11∗ = M11. Similarly, EDi authenticates CSj by verifying M13 against H(sk∗ ‖ M12 ‖ ηj∗ ‖ T3). This ensures mutual authentication among all parties.

5.10 Three-Factors Security
Suppose A can obtain any two of the three factors. The following cases demonstrate why this is insufficient: 1) A knows idi and pwi: without σi, A cannot compute M2 = H(Ridi ‖ Rpwi ‖ ni). Recovering ni requires Ridi = H(idi ‖ σi ‖ α) and Rpwi = H(pwi ‖ σi) ⊕ β, but without σi, deriving Ridi is infeasible. 2) A knows idi and σi: without pwi, A cannot compute M2 = H(Ridi ‖ Rpwi ‖ ni). Computing Rpwi = H(pwi ‖ σi) ⊕ β and recovering ni rely on pwi, making it impossible to derive Rpwi. 3) A knows pwi and σi: without idi, A cannot compute M2 = H(Ridi ‖ Rpwi ‖ ni). Recovering Ridi requires idi, which is essential to compute Ridi = H(idi ‖ σi ‖ α).

5.11 Perfect Forward Security
Perfect forward security ensures the confidentiality of sk even if the long-term secrets are compromised. sk is expressed as H(Cid∗j ‖ ri∗ ‖ ηj ‖ T3), where ri∗ is randomly generated by EDi and ηj by CSj. Since it is challenging for A to obtain the random numbers ri∗ and ηj, the proposed protocol effectively achieves perfect forward security.

6 Formal Security Verification
AVISPA is a tool for detecting protocol security flaws. It supports the verification of various protocols and applications, such as TLS, IPSec, SSL, SSH, and Kerberos. With its user-friendly graphical interface, users can easily configure and execute validation processes. Additionally, it generates detailed reports that outline identified security issues and suggest corrective measures. The entities and environments of the proposed protocol are modeled in HLPSL and simulated using SPAN, with the results presented in Fig. 4. The findings demonstrate that the proposed protocol effectively resists replay and MITM attacks.

Fig. 4. Verification Result.
7 Performance Analysis
This section comprehensively evaluates the performance through theoretical
analysis and experimental validation.

7.1 Theoretical Analysis
In Tables 2 and 3, we compare the proposed protocol with related works in terms of security and computation. From Table 2, we can find that Zhang2023 [2], Gao2024 [3] and ours satisfy all the listed security properties. In Zhang2023, a Chebyshev-based AKA protocol is proposed that balances performance and security. Meanwhile, Gao2024 proposes a highly available AKA protocol using symmetric encryption to protect data privacy in the IoT environment. However, Ju2023 cannot effectively resist the stolen SC attack and the ES impersonation attack. Therefore, the proposed protocol has better security performance. From Table 3, we can find that Zhang2023's computational overhead on the device side is 4Th + 2Tc, where Th is the time to compute a hash operation and Tc is the time to compute a Chebyshev operation. The overhead on the node side is 6Th and the overhead on the server side is 5Th + 2Tc. In Gao2024, the computational overhead on the device side is 9Th + 2TE/D, where TE/D is the time to compute a symmetric encryption or decryption. The overhead on the node side is 6Th + TE/D and the overhead on the server side is 3Th + 2TE/D. Similar to the proposed protocol, Ju2023 uses only hash functions to complete the whole process. However, the total computational overhead of Ju2023 is 32Th, while that of the proposed protocol is 24Th. Therefore, compared with related works, our protocol has better computational performance.

Table 2. Security Comparison

Security                    Zhang2023 [2]  Gao2024 [3]  Ju2023 [12]  Ours
Privileged-insider attack   ✓              ✓            ✓            ✓
Untraceability              ✓              ✓            ✓            ✓
EDi anonymity               ✓              ✓            ✓            ✓
Stolen SC attack            ✓              ✓            ×            ✓
EDi impersonation attack    ✓              ✓            ✓            ✓
ES impersonation attack     ✓              ✓            ×            ✓
CSj impersonation attack    ✓              ✓            ✓            ✓
Replay attack               ✓              ✓            ✓            ✓
MITM attack                 ✓              ✓            ✓            ✓
Mutual authentication       ✓              ✓            ✓            ✓
Three-factors security      ✓              ✓            ×            ✓
Perfect forward security    ✓              ✓            ✓            ✓
Table 3. Computational Comparison

Phase Zhang2023 Gao2024 Ju2023 Ours
Device side 4Th + 2Tc 9Th + 2TE/D 10Th 12Th
Node side 6Th 6Th + TE/D 6Th 8Th
Server side 5Th + 2Tc 3Th + 2TE/D 16Th 6Th
Total cost 17Th + 4Tc 17Th + 2TE/D 32Th 24Th
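The totals in Table 3 can be checked mechanically. The per-operation timings below are illustrative placeholders, not measured values; only the operation counts come from the table.

```python
# Assumed per-operation timings in ms (illustrative, hardware-dependent):
T_H, T_C, T_ED = 0.005, 0.300, 0.050   # hash, Chebyshev map, sym. enc/dec

# Operation counts taken from Table 3 (totals per protocol).
protocols = {
    "Zhang2023": {"h": 17, "c": 4, "ed": 0},
    "Gao2024":   {"h": 17, "c": 0, "ed": 2},
    "Ju2023":    {"h": 32, "c": 0, "ed": 0},
    "Ours":      {"h": 24, "c": 0, "ed": 0},
}

def total_ms(ops):
    # weighted sum of operation counts and per-operation timings
    return ops["h"] * T_H + ops["c"] * T_C + ops["ed"] * T_ED

for name, ops in protocols.items():
    print(f"{name}: {total_ms(ops):.3f} ms")
```

Under any timing assignment where Tc and TE/D dominate Th, the hash-only protocols (Ju2023 and ours) come out cheapest, with ours at 24Th below Ju2023's 32Th.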

7.2 Experimental Analysis

The performance was analyzed in depth through simulations implemented in C. As shown in Fig. 5, the computational overhead of the edge server is modeled as the number of edge devices grows from 0 to 100. It can be seen that the computational overhead of Gao2024 increases rapidly with the number of edge devices. The other related works and ours use only lightweight hash operations, yielding lower computational overhead than Gao2024.

Fig. 5. Time cost of edge server.

The total time overhead is similarly compared with related works, and the comparison result is shown in Fig. 6. Ours has the lowest total computational overhead among the compared protocols. Ju2023 and ours have lower overhead than the other solutions, because Zhang2023 and Gao2024 involve Chebyshev polynomials and symmetric encryption, respectively, which raise their computational overhead. Overall, combining the above analysis results, the proposed protocol is more suitable for federated learning communication.
Fig. 6. Total time cost.

8 Conclusion

In this paper, we propose an authentication and key agreement protocol for federated learning environments to address the dual challenges of data privacy protection and communication security in distributed machine learning. Through an in-depth analysis of federated learning scenarios, we design an efficient and secure key agreement mechanism that establishes a secure communication channel between participants and provides data integrity and authentication. Finally, we rigorously verify the security of the protocol through formal methods and demonstrate its efficiency in terms of computational overhead and communication delay through simulation experiments.

References
1. Li, L., Fan, Y., Tse, M., Lin, K.Y.: A review of applications in federated learning.
Comput. Ind. Eng. 149, 106854 (2020)
2. Zhang, T., Shen, J., Yang, H., Pandi, V., Gupta, B.B., Arya, V.: Sustainable
authentication and key agreement protocol using chaotic maps for industry 5.0.
IEEE Trans. Consum. Electron. 70(1), 1580–1589 (2023)
3. Gao, Y., Zhou, T., Zheng, W., Yang, H., Zhang, T.: High-availability authentication
and key agreement for internet of things-based devices in industry 5.0. IEEE Trans.
Ind. Inf. 1–9 (2024)
4. Ahmad, A.: Fraud prevention in insurance: biometric identity verification and AI-
based risk assessment. In: 2024 International Conference on Knowledge Engineering
and Communication Systems, vol. 1, pp. 1–6 (2024)
5. Sun, Y., An, K., Luo, J., Zhu, Y., Zheng, G., Chatzinotas, S.: Intelligent reflecting
surface enhanced secure transmission against both jamming and eavesdropping
attacks. IEEE Trans. Veh. Technol. 70(10), 11017–11022 (2021)
6. Al-Shareeda, M.A., Manickam, S.: Man-in-the-middle attacks in mobile ad hoc networks (MANETs): analysis and evaluation. Symmetry 14(8), 1543 (2022)
7. Yogesh, P.R.: Formal verification of secure evidence collection protocol using BAN
logic and AVISPA. Procedia Comput. Sci. 167, 1334–1344 (2020)
8. Dodis, Y., Reyzin, L., Smith, A.: Fuzzy extractors: how to generate strong keys
from biometrics and other noisy data. In: Cachin, C., Camenisch, J.L. (eds.)
EUROCRYPT 2004. LNCS, vol. 3027, pp. 523–540. Springer, Heidelberg (2004).
https://doi.org/10.1007/978-3-540-24676-3_31
9. Cervesato, I.: The Dolev-Yao intruder is the most powerful attacker. In: 16th
Annual Symposium on Logic in Computer Science-LICS, vol. 1, pp. 1–2 (2001)
10. Li, B., Erdin, E., Gunes, M.H., Begis, G., Shipley, T.: An overview of anonymity
technology usage. Comput. Commun. 36(12), 1269–1283 (2013)
11. Sutrala, A.K., Obaidat, M.S., Saha, S., Das, A.K., Alazab, M., Park, Y.: Authen-
ticated key agreement scheme with user anonymity and untraceability for 5G-
enabled softwarized industrial cyber-physical systems. IEEE Trans. Intell. Transp.
Syst. 23(3), 2316–2330 (2022)
12. Ju, S., Park, Y.: Provably secure lightweight mutual authentication and key agree-
ment scheme for cloud-based IoT environments. Sensors 23(24), 9766 (2023)
Blockchain-Based Anonymous Authentication Scheme with Traceable Pseudonym Management in ITS

Mingliang Wang1, Haowen Tan1,2(B), and Wenying Zheng3

1 School of Information Science and Engineering, Zhejiang Sci-Tech University, Hangzhou 310018, China
[email protected]
2 Guangdong Key Laboratory of Blockchain Security, Guangzhou University, Guangzhou 510006, China
3 School of Computer Science and Technology, Zhejiang Sci-Tech University, Hangzhou 310018, China
Abstract. With the continuous development of intelligent transportation systems (ITS), Vehicular Ad-hoc Networks (VANETs), as a core component of smart transportation, have become a crucial technology for improving traffic safety and transportation efficiency. However, the open communication environment and highly dynamic nature of VANETs make them vulnerable to many security threats, and existing schemes struggle to balance anonymity and traceability. Therefore, this paper proposes a blockchain-based anonymous authentication scheme with a pseudonym management protocol. Using the decentralized nature of blockchain, the scheme incorporates a pseudonym management mechanism based on smart contracts, ensuring that a user's real identity can only be revealed upon consensus from a certain number of RSUs, thereby enhancing identity privacy protection. Additionally, using the immutability of smart contracts, a vehicle authentication protocol is designed to improve system decentralization, data security, and identity traceability. Security analysis demonstrates that the proposed scheme offers superior security against common threats, and performance analysis shows that its performance is improved.

Keywords: Blockchain · Authentication · Smart Contracts · VANETs · Intelligent Transportation Systems (ITS)

1 Introduction
With the rapid development of Intelligent Transportation Systems (ITS), Vehic-
ular Ad Hoc Networks (VANETs) [1], as a core component of ITS, have become
a key technology for improving traffic safety and transportation efficiency [2].
Particularly with the rise of autonomous driving technology, the widespread
adoption of autonomous vehicles and intelligent traffic management systems has
made VANETs crucial in enhancing both traffic safety and efficiency. However, the open communication environment and highly dynamic nature of VANETs make them vulnerable to various security threats, such as data tampering, identity forgery, and denial-of-service attacks [3].

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 104–117, 2025.
https://doi.org/10.1007/978-981-96-4566-4_8
Recent advancements in machine learning have greatly enhanced intelligent
transportation systems, enabling intelligent driving systems to detect real-time
environmental changes and make accurate decisions. Deep learning and reinforce-
ment learning applications, particularly in autonomous driving, improve traffic
flow predictions, road condition analysis, and vehicle behavior recognition, while
also aiding anomaly detection. However, the widespread use of machine learning
also raises data security concerns. Machine learning in intelligent driving and
VANET systems depends on large amounts of sensitive real-time data, such as
vehicle location, speed, and traffic flow [4]. If this data is tampered with or stolen,
attackers can mislead models, compromising the safety of autonomous vehicles
and traffic systems [5]. Additionally, machine learning models are vulnerable
to adversarial attacks, where slight modifications can cause incorrect decisions.
Recent research on graph neural networks has explored how these models can enhance machine learning privacy and security in such systems. By leveraging the
graph-based structure, such networks can better protect the flow of sensitive
data and infrastructure, ensuring data integrity and preventing manipulation
[6]. Furthermore, privacy-preserving frameworks such as PPFed are becoming
increasingly important in securing large datasets in autonomous driving [7].
PPFed allows for personalized models that are trained in a federated manner,
ensuring that sensitive data remains secure while still providing valuable insights
for intelligent systems [8].
In VANETs, data exchange among vehicles (V2V), infrastructure (V2I), and
pedestrian devices (V2P) supports applications like intelligent transportation
and cooperative driving [9]. However, the high mobility, heterogeneity, and open-
ness of VANETs create challenges such as unstable communication and limited
computational resources, exposing networks to threats like man-in-the-middle
and replay attacks [10]. Ensuring secure communication in this dynamic envi-
ronment is a critical issue.
Authentication and Key Agreement (AKA) protocols are essential for ver-
ifying entities and establishing shared keys for secure communication [11].
Blockchain technology, with its decentralization and transparency, offers a
novel approach to address security challenges in VANETs [12]. By integrating
blockchain with AKA processes, issues like single points of failure and data tam-
pering can be mitigated, improving security and reliability [13].

1.1 Motivation and Contributions

With the rise of network threats, traditional authentication protocols face grow-
ing vulnerabilities. Blockchain’s decentralization and transparency offer a secure
solution for authentication and key management, reducing centralized risks. Our
contributions are as follows:
106 M. Wang et al.

– We designed a blockchain-assisted authentication protocol. This protocol combines the decentralized characteristics and immutability of blockchain, allowing RSU to verify the legitimacy of vehicles by invoking smart contracts.
– Based on smart contracts, we designed a pseudonym management scheme for
vehicles. By employing a secret sharing scheme, the values used to compute
the root pseudonym are divided into n parts, preventing the TA from easily
recovering the user’s real ID. Additionally, the scheme includes a mechanism
for recovering real ID, enabling the tracking of malicious users.
– We proved the security of the proposed scheme through security analyses.
Additionally, by analyzing computational and communication overhead, we
demonstrated the performance advantages of our proposed scheme.

2 Related Work

In VANETs, authentication and key agreement protocols are critical for ensuring
secure communication between vehicles and between vehicles and infrastructure.
Over recent years, researchers have proposed numerous Authentication and Key
Agreement schemes to enhance security in vehicular networks. Below are some
notable contributions in this field.
Lightweight authentication schemes play a vital role in securing communica-
tion within VANET environments. Vasudev et al. [14] introduced a lightweight
mutual authentication protocol for V2V communication. This protocol employs
encryption operations to facilitate efficient mutual authentication between vehi-
cles and establish a secure communication channel for key agreement between
devices and servers. Similarly, Li et al. [15] proposed a lightweight privacy-
preserving authentication protocol that utilizes hash functions and XOR oper-
ations to enable anonymous authentication. Their protocol operates in a net-
work model comprising TA, RSU, and vehicles, allowing vehicles to authenticate
anonymously with RSU. Furthermore, Lee et al. [16] developed a secure and effi-
cient honey-list-based authentication protocol. This protocol incorporates vehicle
location information for key agreement and supports V2V communication both
within the same region and across regions.
In recent years, blockchain technology, as a decentralized, secure, and trans-
parent distributed ledger, has been increasingly applied to authentication and
key agreement protocols in VANETs. Xie et al. [17] proposed a blockchain-based
V2I handover authentication and V2V broadcast protocol for VANETs. The pro-
tocol ensures the anonymity and traceability of broadcast information through
dynamic anonymity and pseudonym embedding strategies, enabling verification
even without transportation infrastructure or trusted authorities. Xue et al. [18]
presented a distributed authentication scheme utilizing blockchain and smart
contracts to facilitate decentralized authentication between users and access
points in mobile Internet of Vehicles roaming services. Lin et al. [19] devel-
oped an efficient blockchain-based conditional privacy-preserving authentication
scheme. This approach balances anonymity, traceability, and key management

by incorporating key derivation, knowledge signatures, and smart contracts for efficient identity authentication and key management.
Despite the numerous studies proposing authentication protocols for vehicu-
lar networks, several shortcomings persist. Some protocols, although designed to
be lightweight, fail to effectively defend against impersonation attacks, insider
threats, and the risks associated with short-term key leakage, leaving room for
improvement in energy consumption and adaptability. Furthermore, communica-
tion overhead poses significant challenges for resource-constrained devices. Addi-
tionally, all submission and revocation operations rely on a centralized admin-
istrator, which exacerbates the risk of centralization. Overall, these protocols
demonstrate varying degrees of trade-offs between security, efficiency, and prac-
ticality, necessitating further optimization to meet real-world requirements.

3 Proposed Scheme
In this section, we give a detailed description of the proposed scheme with the
following details:

Fig. 1. System Model.

3.1 System Initialization


The system model is shown in Fig. 1, and the details of the proposed scheme are as follows:
1. TA selects a non-singular elliptic curve E: y² = x³ + ax + b (mod p) with a, b ∈ Fp, where p and q are two large primes. On E, TA chooses a generator P of order q, which generates the additive cyclic elliptic curve group G of order q.

2. TA selects a private key s ∈ Zq* and calculates the corresponding public key pubTA = s · P. Then, RSU broadcasts the public key to nearby vehicles.
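As a concrete illustration of this setup, the sketch below performs the key-generation scalar multiplication pub = s · P on a tiny textbook curve. The parameters (y² = x³ + 2x + 2 over F17, generator (5, 1) of prime order 19, s = 13) are toy values chosen for illustration only, not anything a real deployment would use.

```python
# Toy sketch (NOT the paper's parameters): key generation on the textbook
# curve y^2 = x^3 + 2x + 2 over F_17 with generator (5, 1) of prime order 19.
P_MOD = 17          # field prime p (toy size; real curves use ~256-bit primes)
A, B = 2, 2         # curve coefficients a, b

def ec_add(p1, p2):
    """Add two curve points; None represents the point at infinity."""
    if p1 is None:
        return p2
    if p2 is None:
        return p1
    (x1, y1), (x2, y2) = p1, p2
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return None                                  # P + (-P) = infinity
    if p1 == p2:
        lam = (3 * x1 * x1 + A) * pow(2 * y1, -1, P_MOD) % P_MOD
    else:
        lam = (y2 - y1) * pow(x2 - x1, -1, P_MOD) % P_MOD
    x3 = (lam * lam - x1 - x2) % P_MOD
    return (x3, (lam * (x1 - x3) - y1) % P_MOD)

def ec_mul(k, point):
    """Scalar multiplication k * point by double-and-add."""
    result = None
    while k:
        if k & 1:
            result = ec_add(result, point)
        point = ec_add(point, point)
        k >>= 1
    return result

G = (5, 1)                # generator P of order q = 19
s = 13                    # TA's private key s in Z_q*
pub_ta = ec_mul(s, G)     # public key pub_TA = s * P
```

The same double-and-add routine underlies every scalar multiplication in the protocol (θ = σ · P, Kvr1 = σ · pubrsu, and so on).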

3.2 Registration Phase


This section describes the registration of the vehicle, public key, and RSU as
follows:
1. OBU registration: The vehicle sends the vehicle ID and password to TA
through a secure channel. After receiving the password, TA loads it into the
on-board unit for the login of the vehicle, and TA deletes the local account
password.
2. Public key registration: The vehicle selects a private key ski and calculates the public key pubi = ski · P. The public key pubi is sent to TA, which uploads it together with the corresponding pseudonym list to the blockchain during the vehicle's pseudonym generation phase.
3. RSU registration: Each RSU as a blockchain node needs to register with TA,
each RSU selects a random number skrsu ∈ Zq∗ , and calculates the corre-
sponding public key pubrsu = skrsu · P , where the public key pubrsu will be
uploaded to the blockchain.

3.3 Login and Pseudonym Generation


Vehicles communicate using pseudonyms and store pseudonyms and correspond-
ing public keys on the blockchain. The specific steps are as follows:
1. The user enters the correct ID and password to log into the OBU. The user selects a random number σ and a sequence of random numbers α1, α2, α3, ..., αn, and computes the root pseudonym PIDroot = ID ⊕ σ. σ is divided into n shares according to the secret sharing scheme, the shares are sent to n blockchain nodes, and a threshold t is set. When the real ID needs to be recovered, TA determines whether the number of available shares meets the threshold; if so, σ is reconstructed from t shares and the real ID is computed. The pseudonyms are calculated as follows: PID11 = PIDroot ⊕ α1, PID21 = PIDroot ⊕ α2, PID31 = PIDroot ⊕ α3, ..., PIDn1 = PIDroot ⊕ αn. Finally, TA matches all pseudonyms with the corresponding public key, as shown in Table 1.
2. TA uploads the relevant pseudonyms to the vehicle’s corresponding public
key list publist by invoking a smart contract. The smart contract is shown as
Algorithm 1.
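The login and pseudonym-generation step above can be sketched in a few lines: σ is split into n Shamir shares with threshold t, and each pseudonym is an XOR of the root pseudonym with a fresh αj. The field prime and the helper names below are our illustrative choices, not part of the paper's specification.

```python
# Illustrative sketch: split sigma into n Shamir shares (threshold t) and
# derive pseudonyms PID_j = PID_root xor alpha_j. The 127-bit Mersenne prime
# and helper names are our choices, not the paper's exact parameters.
import secrets

PRIME = 2**127 - 1  # prime field for the shares

def split_secret(sigma, n, t):
    """Evaluate a random degree-(t-1) polynomial with f(0) = sigma at n points."""
    coeffs = [sigma] + [secrets.randbelow(PRIME) for _ in range(t - 1)]
    def f(x):
        return sum(c * pow(x, i, PRIME) for i, c in enumerate(coeffs)) % PRIME
    return [(x, f(x)) for x in range(1, n + 1)]

def derive_pseudonyms(vid, sigma, alphas):
    """PID_root = ID xor sigma; PID_j = PID_root xor alpha_j."""
    pid_root = vid ^ sigma
    return [pid_root ^ a for a in alphas]

vid = secrets.randbits(64)                  # vehicle's real ID
sigma = secrets.randbelow(PRIME)
shares = split_secret(sigma, n=5, t=3)      # one share per blockchain node
alphas = [secrets.randbits(64) for _ in range(4)]
pids = derive_pseudonyms(vid, sigma, alphas)
```

Given σ and the matching αj, XORing them back into a pseudonym returns the real ID, which is exactly the traceability property used later for revocation.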

3.4 Authentication Phase


This section describes the mutual authentication between the vehicle user and
the RSU, as well as the authentication of the vehicle Vi and RSU by the TA.
The specific process is as follows:

Table 1. Pseudonym List

Public Key | Pseudonym Sequence
pub1 | PID11, PID21, PID31, ..., PIDn1
pub2 | PID12, PID22, PID32, ..., PIDn2
...  | ...
pubn | PID1n, PID2n, PID3n, ..., PIDnn

Algorithm 1. Pseudonym registration


Require: PIDj, pubi
Ensure: bool
1: ▷ Check whether the function caller is a TA
2: ▷ TAs is an address array of TA
3: ▷ publist: mapping(bytes32 ⇒ string[])
4: if msg.sender ∉ TAs then
5:   return false
6: else
7:   for each PIDj do
8:     publist[pubi].append(PIDj)
9:     publist[pubi].check(PIDj) ⇒ true
10:  end for
11:  return true
12: end if

1. The vehicle Vi calculates PIDj1 = ID ⊕ σ ⊕ αj, θ = σ · P, Kvr1 = σ · pubrsu, and Aj = αj ⊕ H1(Kvr1‖PIDj1). It generates the timestamp T1 and calculates Kvr2 = H2(ski · pubrsu‖Kvr1‖T1) and Vh = Kvr2 ⊕ H3(PIDj1‖pubi‖Kvr1‖αj‖T1). The authentication request message M1 = {PIDj1, Aj, Vh, T1, θ} is sent to the RSU.

Algorithm 2. Validate PID Function


Require: PIDj
Ensure: pubi
1: if publist[pubi].check(PIDj) == true then
2:   return pubi
3: else
4:   return false
5: end if

2. After the RSU receives M1, it first verifies the timestamp |T1' − T1| ≤ ΔT. Then, it calls Smart Contract Algorithm 2 to check whether PIDj1 has been uploaded to the blockchain, verifying the validity of the pseudonym. If valid, it returns the corresponding public key pubi. The RSU calculates K'vr1 = skrsu · θ and αj = Aj ⊕ H1(K'vr1‖PIDj1), then computes K'vr2 = H2(skrsu · pubi‖K'vr1‖T1) and H3(PIDj1‖pubi‖K'vr1‖αj‖T1). It checks whether Vh ⊕ K'vr2 = H3(PIDj1‖pubi‖K'vr1‖αj‖T1) holds. If the equality is satisfied, it generates the timestamp T2 and calculates Vr1 = H3(αj‖pubi‖K'vr1‖K'vr2‖T2), where PIDj1 ⊕ αj = ID ⊕ σ. Next, it uses Eq. (1) to calculate Vr2. Finally, it sends the message M2 = {Vr2, T2} to the vehicle user.

Vr2 = Vr1 ⊕ H4((ID ⊕ σ)‖Vr1‖pubrsu)    (1)

3. After the vehicle Vi receives M2, it first verifies the timestamp |T2' − T2| ≤ ΔT. If valid, it calculates V'r1 = H3(αj‖pubi‖Kvr1‖Kvr2‖T2) and checks whether V'r1 = Vr1 holds. If the equality is satisfied, it computes H4((ID ⊕ σ)‖V'r1‖pubrsu) and V'r2 = V'r1 ⊕ H4((ID ⊕ σ)‖V'r1‖pubrsu), further verifying whether V'r2 = Vr2 holds. If this equality is satisfied, it generates the timestamp T3, takes the x-axis coordinate of Kvr1 as (Kvr1)x, and calculates μ = (Kvr1)x + αj · Vr1, r = H1(pubi‖ski), and R = r · P. It then computes Sig = H4(Kvr2‖αj‖pubrsu) · ski + σ · Vr1 + r and Vs = H5(μ‖T3‖Sig‖pubi). Next, it calculates Kvt = pubTA · ski, Vta = H6(pubi‖PIDj1‖Kvt‖pubrsu), and Vrta = Vta ⊕ μ. Finally, it sends the message M3 = {Sig, Vs, Vrta, T3} to the RSU.

Fig. 2. Authentication phase.

4. After the RSU receives M3, it first verifies the timestamp |T3' − T3| ≤ ΔT. If the verification passes, it calculates μ' = (K'vr1)x + αj · Vr1. It then uses K'vr2 and R to verify that the following equality holds:

Sig · P = H4(K'vr2‖αj‖pubrsu) · pubi + θ · Vr1 + R    (2)

If the equality holds, it calculates V's = H5(μ'‖T3‖Sig‖pubi) and further verifies whether V's = Vs holds. The RSU then generates a timestamp T4 and computes A = αj · P, Krt = pubTA · skrsu, Krvt = αj · pubTA, and Vrt1 = H7(pubi‖Krt‖Krvt‖Sig‖T4). To allow the TA to verify the vehicle and prove that the RSU has validated the vehicle's signature, the RSU calculates Vrta ⊕ μ' = Vta and Vvta = Vta ⊕ H1(pubi‖PIDj1). Finally, the RSU sends M4 = {PIDj1, Vrt1, Sig, pubrsu, A, Vvta, T4} to the TA.
5. After the TA receives M4, it first verifies the timestamp |T4' − T4| ≤ ΔT. If the timestamp is valid, it retrieves pubi by calling the smart contract with PIDj1, then calculates K'rt = pubrsu · skTA, K'rvt = skTA · A, and V'rt1 = H7(pubi‖K'rt‖K'rvt‖Sig‖T4). It verifies whether V'rt1 = Vrt1 holds. If this equality is satisfied, it computes H1(pubi‖PIDj1) ⊕ Vvta = Vta, K'vt = skTA · pubi, and V'ta = H6(pubi‖PIDj1‖K'vt‖pubrsu). If V'ta = Vta holds, the verification is successful (Fig. 2).
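The hash-and-XOR core of steps 1–2 can be illustrated with the elliptic-curve values stubbed out: since σ · pubrsu = skrsu · θ, both sides hold the same Kvr1, so the RSU can unmask αj from Aj and re-derive Vh. The function H below stands in for the paper's H1–H3, and all helper names are ours.

```python
# Sketch of the M1 exchange (Sect. 3.4, steps 1-2) with the shared ECDH values
# K_vr1, K_vr2 stubbed as given: sigma * pub_rsu = sk_rsu * theta, so both
# sides hold them. H stands in for H1-H3; all names are ours, not the paper's.
import hashlib
import secrets
import time

def H(*parts):
    """Hash of the concatenated parts, returned as an integer."""
    data = b"||".join(str(p).encode() for p in parts)
    return int.from_bytes(hashlib.sha256(data).digest(), "big")

# Secrets both sides derive from the elliptic-curve operations.
k_vr1 = secrets.randbits(256)
k_vr2 = secrets.randbits(256)

# --- vehicle side: build M1 ---
pid_j = secrets.randbits(64)
alpha_j = secrets.randbits(64)
t1 = int(time.time())
a_j = alpha_j ^ H(k_vr1, pid_j)                       # A_j = alpha_j xor H1(K_vr1 || PID)
v_h = k_vr2 ^ H(pid_j, "pub_i", k_vr1, alpha_j, t1)   # V_h binds every parameter

# --- RSU side: verify M1 ---
alpha_rec = a_j ^ H(k_vr1, pid_j)                     # unmask alpha_j
ok = (v_h ^ k_vr2) == H(pid_j, "pub_i", k_vr1, alpha_rec, t1)
```

Any tampering with PIDj1, Aj, or T1 changes the recomputed hash, so the Vh check fails, which is the integrity argument made in the security analysis.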

3.5 User Revocation


The proposed scheme provides a malicious user revocation mechanism. When a user is identified as malicious, the revocation smart contract is invoked to vote on that user, and the other RSUs cast their votes according to the submitted evidence. At the same time, the smart contract uploads the voting result, together with the shares of σ required to recover the malicious user's real identity ID, to the blockchain. TA then calls the user revocation smart contract to recover the malicious user's real identity ID and deletes the public key uploaded to the blockchain along with the corresponding pseudonyms.
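A minimal sketch of the traceability step behind this mechanism: once the vote threshold t is reached, any t shares reconstruct σ by Lagrange interpolation at x = 0, and the real ID follows as PIDroot ⊕ σ. The prime, the fixed polynomial, and the example values below are illustrative only.

```python
# Sketch of tracing a malicious user: t shares reconstruct sigma by Lagrange
# interpolation at x = 0, then ID = PID_root xor sigma. The prime, the fixed
# polynomial, and the example values are illustrative only.
PRIME = 2**61 - 1

def reconstruct(shares):
    """Lagrange interpolation at x = 0 over the prime field."""
    secret = 0
    for i, (xi, yi) in enumerate(shares):
        num, den = 1, 1
        for j, (xj, _) in enumerate(shares):
            if i != j:
                num = num * (-xj) % PRIME
                den = den * (xi - xj) % PRIME
        secret = (secret + yi * num * pow(den, -1, PRIME)) % PRIME
    return secret

vid = 987654321                    # the real ID to be traced
sigma = 123456789
pid_root = vid ^ sigma             # PID_root = ID xor sigma

# Shares of sigma from the degree-2 polynomial f(x) = sigma + 7x + 11x^2.
shares = [(x, (sigma + 7 * x + 11 * x * x) % PRIME) for x in (1, 2, 4, 5)]

recovered = reconstruct(shares[:3])    # any t = 3 shares suffice
real_id = pid_root ^ recovered         # recovers the vehicle's real ID
```

Fewer than t shares leave σ information-theoretically hidden, which is why the TA alone cannot de-anonymize honest users.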

4 Security Analysis
This section provides an informal security analysis of the proposed scheme.
1. Replay Attack: In this scheme, the messages M1 , M2 , M3 , and M4 carry
timestamps Ti during the authentication process. Upon receiving a mes-
sage, the receiver calculates the difference between the current time and the
received timestamp. It then checks whether the timestamp is valid. If the
timestamp passes the validation, the message is accepted; otherwise, it is
discarded.
2. Impersonation Attack: Assume that an attacker A intercepts the authentication request message of vehicle Vi and attempts to impersonate Vi by generating a valid request message M1 = {PIDj1, Aj, Vh, T1, θ}. The generation of this message depends on the long-term secret key ski, σ, and

Algorithm 3. User Revocation


Require: PIDj, pubi
Ensure: SUM
1: ▷ Check whether the function caller is an RSU
2: ▷ RSUs is an address array of RSU
3: ▷ Pro is the vehicle malicious behavior proof
4: ▷ SUM is the total number of votes
5: if msg.sender ∉ RSUs then
6:   return false
7: else
8:   for each RSU do
9:     if Pro is true then
10:      SUM = SUM + 1
11:    else
12:      SUM = SUM
13:    end if
14:  end for
15:  if SUM ≥ t then
16:    publist[pubi].del(PIDj)
17:    publist[pubi].check(PIDj) → false
18:    publist.del(pubi)
19:  end if
20:  return true
21: end if

the short-term random key αj. The calculation of Vh is given by Vh = Kvr2 ⊕ H3(PIDj1‖pubi‖Kvr1‖αj‖T1), where Kvr2 is calculated as Kvr2 = H2(ski · pubrsu‖Kvr1‖T1). Since the attacker knows neither ski nor σ, it cannot impersonate the legitimately registered vehicle Vi. Clearly, the proposed scheme can resist vehicle impersonation attacks.
3. Mutual Authentication: In the proposed scheme, the roadside unit (RSU)
authenticates the vehicle by verifying Vh in message M1 and the signature
Sig along with Vs in message M3 , while the vehicle authenticates the RSU
by validating Vr1 and Vr2 in message M2 . Additionally, the trusted authority
(TA) authenticates both the vehicle and the RSU by checking Vrt1 , Vta , and
the vehicle’s signature in message M4 . This process ensures the legitimacy
and integrity of all exchanged messages.
4. Conditional Anonymity: In the proposed scheme, the vehicle Vi commu-
nicates with the RSU and other vehicles using a pseudonym P IDj1 , thereby
providing anonymity for Vi . The identity IDi of the vehicle Vi is hidden within
P IDj1 = IDi ⊕ σ ⊕ αj , and the vehicle has multiple pseudonyms P IDj1 , which
change with the value of αj . Thus, an attacker cannot determine the actual
identity of the vehicle Vi . However, in the case of malicious behavior, the real
identity can be traced through the pseudonym.
5. Traceability and Revocation: The pseudonym P IDj1 of a vehicle Vi is
defined as P IDj1 = IDi ⊕σ⊕αj , where the recovery key σ is distributed among

RSUs using a secret sharing scheme. If Vi engages in malicious behavior or is reported, the RSUs assess the case and initiate a smart contract vote, sharing their portions of σ. Once the voting threshold is met, σ is reconstructed, enabling the identification of Vi and revocation of its blockchain-related information.
6. Modification Attack: The messages M1 , M2 , M3 , and M4 sent by entities
contain authentication hash codes Vh , Vr2 , Vs , and Vrt1 , respectively, to ensure
the integrity of the transmitted parameters. If any part of the message is
tampered with, the recipient will be unable to verify the hash value included
in the message. Therefore, under this scheme, modification attacks cannot be
successfully executed.

5 Performance Analysis
In this section, we analyze the performance of our proposed scheme in terms of communication and computational overhead. For smart contracts, the Solidity language is used to write them, Ganache serves as the blockchain environment, and the Truffle framework is utilized for contract deployment and testing.

5.1 Communication Cost Analysis


When calculating the communication overhead, the timestamp and RSU identifiers are set to 128 bits; the vehicle identity ID, random numbers, keys, and hash outputs are 256 bits; and elements of G are 320 bits. The proposed scheme involves the transmission of four messages, with the following communication overhead: M1 = {PIDj1, Aj, Vh, T1, θ} amounts to 256 + 256 + 256 + 128 + 320 = 1216 bits; M2 = {Vr2, T2} amounts to 256 + 128 = 384 bits; M3 = {Sig, Vs, Vrta, T3} amounts to 256 + 256 + 256 + 128 = 896 bits; and M4 = {PIDj1, Vrt1, Sig, pubrsu, A, Vvta, T4} amounts to 256 + 256 + 256 + 320 + 320 + 256 + 128 = 1792 bits. Thus, the total communication overhead is 1216 + 384 + 896 + 1792 = 4288 bits = 536 bytes. The communication cost comparison with other schemes is shown in Table 2, and the communication cost as the number of vehicles varies is shown in Fig. 3; our scheme has a low communication cost.

Table 2. Communication Cost Analysis

Scheme | Number of messages | Communication cost
Yang et al. [20] | 4 | 824 Bytes
Wang et al. [21] | 4 | 912 Bytes
Zhang et al. [22] | 6 | 696 Bytes
Zhou et al. [23] | 4 | 1096 Bytes
Our scheme | 4 | 536 Bytes
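The message totals above can be re-derived mechanically from the stated element sizes:

```python
# Re-deriving the Sect. 5.1 totals from the stated element sizes (bits):
# timestamps 128, IDs/keys/hash outputs 256, group elements of G 320.
TS, HASH, GROUP = 128, 256, 320

m1 = 3 * HASH + TS + GROUP          # M1 = {PID, A_j, V_h, T1, theta}
m2 = HASH + TS                      # M2 = {V_r2, T2}
m3 = 3 * HASH + TS                  # M3 = {Sig, V_s, V_rta, T3}
m4 = 4 * HASH + 2 * GROUP + TS      # M4 = {PID, V_rt1, Sig, pub_rsu, A, V_vta, T4}

total_bits = m1 + m2 + m3 + m4      # 1216 + 384 + 896 + 1792 = 4288
total_bytes = total_bits // 8       # 536 bytes
```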

Fig. 3. Communication Cost.

5.2 Computational Cost Analysis

Let Tecm , Tepa , Tmtp , Tbp , Texp , Th , and Tenc/dec represent elliptic curve mul-
tiplication, elliptic curve addition, scalar multiplication, bilinear pairing, mod-
ular exponentiation, hash function, and symmetric key encryption/decryption,
respectively. To better evaluate the computational cost, we use the pbc library to
simulate these operations. The time for each operation is as follows: Tecm = 0.826
ms, Tepa = 0.038 ms, Tmtp = 0.026 ms, Tbp = 2.952 ms, Texp = 1.507 ms,
Th = 0.037 ms, and Tenc/dec = 0.031 ms. Based on this, we analyzed the schemes
[20–23], and our proposed scheme. In our proposed scheme, the computational
cost at the vehicle side is 5Tecm + 9Th = 4.463 ms, the RSU side’s cost is
8Tecm + 10Th + 1Tepa = 7.016 ms, and the TA’s cost is 3Tecm + 3Th = 2.589 ms.
Thus, the total computational cost for our scheme is 14.068 ms. The comparison
results are shown in Table 3, and the computational overhead of different entities
is shown in Fig. 4.

Table 3. Computational Cost Analysis

Schemes | Vehicle | RSU | TA/CSP/RA | Total
Yang et al. [20] | 7Tecm + 5Th + 3Tbp = 14.823 | 4Tecm + 4Th + 7Tbp = 24.116 | - | 38.939
Wang et al. [21] | 4Tecm + 7Th + 9Texp + 2Tbp = 23.030 | 2Tecm + 4Th + 5Texp + 2Tbp = 15.239 | 1Tecm + 5Th + 7Texp + 6Tbp = 29.272 | 67.541
Zhang et al. [22] | 8Tecm + 12Th = 7.052 | 9Tecm + 10Th + 2Tepa = 7.880 | 3Tecm + 3Th = 2.589 | 17.521
Zhou et al. [23] | 5Tecm + 8Th = 4.426 | - | 4Tbp + 4Texp + 3Th = 17.947 | 36.031
Our scheme | 5Tecm + 9Th = 4.463 | 8Tecm + 10Th + 1Tepa = 7.016 | 3Tecm + 3Th = 2.589 | 14.068
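The per-entity figures for our scheme follow directly from the measured primitive timings:

```python
# Re-deriving our scheme's Sect. 5.2 costs from the measured primitive
# timings (milliseconds): T_ecm, T_epa, T_h as reported for the pbc library.
T_ECM, T_EPA, T_H = 0.826, 0.038, 0.037

vehicle = 5 * T_ECM + 9 * T_H              # 5T_ecm + 9T_h
rsu = 8 * T_ECM + 10 * T_H + 1 * T_EPA     # 8T_ecm + 10T_h + 1T_epa
ta = 3 * T_ECM + 3 * T_H                   # 3T_ecm + 3T_h
total = vehicle + rsu + ta
```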

Fig. 4. Computational Cost.

6 Conclusion
In our research, a blockchain-based pseudonym management and anonymous authentication protocol is proposed. By combining blockchain technology with a secret sharing scheme, it not only protects user privacy but also provides traceability and enhances the decentralization of the system. The protocol uses the RSU to invoke smart contracts for vehicle authentication and relies on the immutability of the blockchain to ensure the legitimacy of authenticated users, thereby ensuring the accuracy and reliability of information in VANETs. Security analysis shows that the protocol can effectively resist common attacks and meet the security requirements of VANETs. In terms of performance, our protocol has significant advantages over other schemes.

Acknowledgments. This work is supported by the National Key R&D Program of China (No. 2023YFB2703700), the National Natural Science Foundation of China (Nos. U21A20465, 62172292, 62402444, 62402446), the Fundamental Research Funds of Zhejiang Sci-Tech University under Grant No. 22222266-Y, the Program for Leading Innovative Research Team of Zhejiang Province (No. 2023R01001), the Zhejiang Provincial Natural Science Foundation of China (Nos. LQ24F020011, LQ24F020012), the "Pioneer" and "Leading Goose" R&D Program of Zhejiang (No. 2023C01119), and the Open Research Fund of Guangdong Key Laboratory of Blockchain Security, Guangzhou University.

References
1. Zhu, F., Li, Z., Chen, S., Xiong, G.: Parallel transportation management and con-
trol system and its applications in building smart cities. IEEE Trans. Intell. Transp.
Syst. 17(6), 1576–1585 (2016)
2. Tan, H., Zheng, W., Vijayakumar, P.: Secure and efficient authenticated key man-
agement scheme for UAV-assisted infrastructure-less IoVs. IEEE Trans. Intell.
Transp. Syst. 24(6), 6389–6400 (2023)

3. Zhou, Y., Wang, Z., Qiao, Z., Yang, B., Zhang, M.: An efficient and provably
secure identity authentication scheme for VANET. IEEE Internet Things J. 10(19),
17170–17183 (2023)
4. Tan, H., Zheng, W., Guan, Y., Lu, R.: A privacy-preserving attribute-based
authenticated key management scheme for accountable vehicular communications.
IEEE Trans. Veh. Technol. 72(3), 3622–3635 (2023)
5. Gupta, M., Benson, J., Patwa, F., Sandhu, R.: Secure V2V and V2I communication
in intelligent transportation using cloudlets. IEEE Trans. Serv. Comput. 15(4),
1912–1925 (2020)
6. Guan, F., Zhu, T., Zhou, W., Choo, K.: Graph neural networks: a survey on the
links between privacy and security. Artif. Intell. Rev. 57(2), 40 (2024)
7. Lv, S., Tan, H., Zheng, W., Zhang, T., Wang, M.: A dynamic conjunctive keywords
searchable symmetric encryption scheme for multiple users in cloud computing.
Comput. Commun. 209, 239–248 (2023)
8. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving
and personalized federated learning framework. IEEE Internet Things J. 11(11),
19380–19393 (2024)
9. Eiza, M.H., Owens, T., Ni, Q.: Secure and robust multi-constrained QoS aware
routing algorithm for VANETs. IEEE Trans. Dependable Secure Comput. 13(1),
32–45 (2015)
10. Lin, C., Huang, X., He, D.: Ebcpa: efficient blockchain-based conditional privacy-
preserving authentication for VANETs. IEEE Trans. Dependable Secure Comput.
20(3), 1818–1832 (2022)
11. Tan, H., Zheng, W., Vijayakumar, P., Sakurai, K., Kumar, N.: An efficient vehicle-
assisted aggregate authentication scheme for infrastructure-less vehicular networks.
IEEE Trans. Intell. Transp. Syst. 24(12), 15590–15600 (2023)
12. Son, S., Lee, J., Park, Y., Park, Y., Das, A.K.: Design of blockchain-based
lightweight V2I handover authentication protocol for VANET. IEEE Trans. Net-
work Sci. Eng. 9(3), 1346–1358 (2022)
13. Ma, Q., Tan, H., Zhou, T.: Mutual authentication scheme for smart devices in
IoT-enabled smart home systems. Comput. Stand. Interfaces 86, 103743 (2023)
14. Vasudev, H., Deshpande, V., Das, D., Das, S.K.: A lightweight mutual authenti-
cation protocol for V2V communication in internet of vehicles. IEEE Trans. Veh.
Technol. 69(6), 6709–6717 (2020)
15. Li, X., Liu, T., Obaidat, M.S., Wu, F., Vijayakumar, P., Kumar, N.: A lightweight
privacy-preserving authentication protocol for VANETs. IEEE Syst. J. 14(3),
3547–3557 (2020)
16. Lee, J., Kim, G., Das, A.K., Park, Y.: Secure and efficient honey list-based authen-
tication protocol for vehicular ad hoc networks. IEEE Trans. Network Sci. Eng.
8(3), 2412–2425 (2021)
17. Xie, Q., Ding, Z., Tang, W., He, D., Tan, X.: Provable secure and lightweight
blockchain-based V2I handover authentication and V2V broadcast protocol for
VANETs. IEEE Trans. Veh. Technol. 72(12), 15200–15212 (2023)
18. Xue, K., Luo, X., Ma, Y., Li, J., Liu, J., Wei, D.: A distributed authentication
scheme based on smart contract for roaming service in mobile vehicular networks.
IEEE Trans. Veh. Technol. 71(5), 5284–5297 (2022)
19. Lin, C., Huang, X., He, D.: Ebcpa: efficient blockchain-based conditional privacy-
preserving authentication for VANETs. IEEE Trans. Dependable Secure Comput.
20(3), 1818–1832 (2023)

20. Yang, A., Weng, J., Yang, K., Huang, C., Shen, X.: Delegating authentication
to edge: a decentralized authentication architecture for vehicular networks. IEEE
Trans. Intell. Transp. Syst. 23(2), 1284–1298 (2022)
21. Wang, Y., Ding, Y., Wu, Q., Wei, Y., Qin, B., Wang, H.: Privacy-preserving
cloud-based road condition monitoring with source authentication in VANETs.
IEEE Trans. Inf. Forensics Secur. 14(7), 1779–1790 (2019)
22. Zhang, J., Zhong, H., Cui, J., Tian, M., Xu, Y., Liu, L.: Edge computing-based
privacy-preserving authentication framework and protocol for 5G-enabled vehicu-
lar networks. IEEE Trans. Veh. Technol. 69(7), 7940–7954 (2020)
23. Zhou, Y., et al.: A novel cloud-assisted authentication key agreement protocol for
VANET. IEEE Trans. Veh. Technol. 73(9), 13526–13541 (2024)
Multi-keyword Searchable Data Auditing
for Cloud-Based Machine Learning

Haiyan Yu1,2,3 , Qingru Ma2,3,4(B) , Yilu Zhu1,2,3 , and Yuxin Cui1,2,3


1 School of Information Science and Engineering (School of Cyber Science and
Technology), Zhejiang Sci-Tech University, Hangzhou 310018, China
2 Zhejiang Key Laboratory of Digital Fashion and Data Governance, Zhejiang

Sci-Tech University, Hangzhou 310018, China


[email protected]
3 Zhejiang Provincial International Cooperation Base for Science and Technology on
Cloud Computing Security and Data Aggregation, Zhejiang Sci-Tech University,
Hangzhou 310018, China
4 School of Computer Science and Technology (School of Artificial Intelligence),

Zhejiang Sci-Tech University, Hangzhou 310018, China

Abstract. Machine learning (ML) relies on large amounts of data for


training, which often requires outsourcing data to cloud servers. This
raises concerns about data integrity and security. Preventing data from
being tampered with or lost during storage and transmission has become
a critical issue in the ML process. This paper presents a novel data audit-
ing scheme for cloud-based ML, which supports multi-keyword search
functionality. It allows searching one or more keywords to ensure the
integrity of data in model training. To avoid data privacy leakage, the
keywords are encrypted during the auditing process. Only the encrypted
keywords of the data required for model training need to be provided to
the third-party auditor (TPA). Subsequently, TPA can check and ensure
the integrity of the relevant data. Through security analysis and perfor-
mance evaluations, the results indicate that this scheme can effectively
and reliably audit outsourced data while ensuring the security and pri-
vacy of model training data.

Keywords: Machine Learning · Data Auditing · Multi-keyword ·


Outsourced Storage · Privacy Preservation

1 Introduction
In the information age, the rapid growth of data has driven the continuous rise in
the demand for large-scale data processing and analysis. Machine learning (ML)
has been extensively implemented and applied in various big data applications
across multiple fields, including image processing [1], information extraction [2],
data cleaning [3], and so on. ML efficiently extracts valuable knowledge from
large amounts of data. For resource-constrained clients, a common approach is
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 118–132, 2025.
https://doi.org/10.1007/978-981-96-4566-4_9
Multi-keyword Searchable Data Auditing 119

to outsource the storage of data to cloud service providers (CSPs). However,


CSPs are susceptible to various internal and external security attacks [4]. To
prevent financial losses and protect their reputation, CSPs may conceal data loss
issues from users, leading to serious privacy and security concerns. To address
this challenge, many researchers have focused on outsourced data auditing to
ensure the integrity of the training data maintained in the cloud storage.
Moreover, data transmission becomes more complex when model training data is outsourced to the cloud. During model training, relevant data needs to
be retrieved from big data stored in the cloud. Usually, only a portion of the
data with shared common or key features is needed for model training. In this
scenario, by adding specific keyword tags to the data, relevant data for model
training can be quickly located in the cloud, thereby reducing unnecessary data
transmission and processing time. At the same time, to enhance the reliability
of model training, retrieved datasets need to undergo auditing to verify their
integrity and security. Due to the large amount of outsourced data, the audit
process requires powerful computing power to support it. Most existing auditing
schemes tackle this issue by outsourcing auditing tasks to the third-party auditor
(TPA) and verifying data integrity through the TPA. This approach helps min-
imize computational costs and auditing overhead, thus improving the efficiency
of model training.

Motivation of This Paper: In model training, outsourcing data raises significant privacy concerns. First, CSPs may be dishonest. To conceal data loss or tampering, they may forge proofs to pass the TPA verification. Second, a semi-trusted TPA, who may have an interest in the training data, might infer sensitive information through keyword tags. Therefore, ensuring the integrity of the auditing process for outsourced data remains challenging. Additionally, the efficient retrieval of training-related data via keyword tags becomes a key issue for improving training efficiency.
In summary, our contributions are as follows:
1. In the context of intelligent computing, a searchable auditing scheme is pro-
posed for training data maintained in the cloud storage. This ensures the
integrity and security of training data, thus improving the quality of model
training.
2. A multi-keyword searchable auditing approach is supported. The cloud can
retrieve data using multiple keywords, which narrows the data range and
reduces redundant data retrieval compared to single-keyword searching, thus
enhancing the efficiency of model training.
3. During the auditing phase, the one-way property of hash functions guarantees
that the TPA cannot obtain the private information or real data of the model
training through analysis, thereby effectively ensuring privacy protection.
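A toy sketch of the hashed-keyword idea behind contributions 2 and 3: data blocks carry one-way keyword tags, so the cloud can match a multi-keyword query while the plaintext keywords stay hidden from the TPA. The tag construction and all names below are our illustrative assumptions, not the paper's exact scheme.

```python
# Toy illustration: blocks are indexed by hashed keyword tags H(k), and a
# multi-keyword query matches only blocks carrying ALL hashed tags. The
# one-way hash keeps the plaintext keywords out of the auditor's view.
import hashlib

def tag(keyword):
    """One-way keyword tag: H(keyword)."""
    return hashlib.sha256(keyword.encode()).hexdigest()

# Cloud-side index: block id -> set of hashed keyword tags.
index = {
    "block-1": {tag("traffic"), tag("speed")},
    "block-2": {tag("traffic"), tag("location")},
    "block-3": {tag("weather")},
}

def search(hashed_query):
    """Return blocks whose tag sets contain every hashed query keyword."""
    return sorted(b for b, tags in index.items() if hashed_query <= tags)

# The data owner hashes the query keywords before handing them over.
query = {tag("traffic"), tag("location")}
matches = search(query)   # only block-2 carries both tags
```

Multi-keyword conjunction narrows the result set before auditing, which is the efficiency gain over single-keyword search claimed in contribution 2.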

2 Related Work
The ML solutions have been driven by the explosion of data to effectively extract
and analyze useful insights from such data. Sabek et al. [5] introduced recent
120 H. Yu et al.

work on ML and big data. With the rapid development of ML, most of the rele-
vant research is concentrated on outsourced scenarios. There is growing concern
about the leakage of privacy information in training data under various models.
To achieve privacy-preserving machine learning (PPML), many scholars have
proposed various solutions. Liu et al. [6] proposed a scalable privacy-enhancing
aggregation scheme that effectively handles participant dropouts and improves
system efficiency. Mohassel et al. [7] introduced the SecureML scheme based on
Homomorphic Encryption and proposed an efficient privacy-preserving machine
learning protocol. Zhang et al. [8] proposed a privacy-preserving and personal-
ized federated learning framework (PPFed), which enhances personalized perfor-
mance while ensuring data privacy by adopting a model-layer partition strategy.
However, PPML typically requires significant computational resources and storage space. To minimize data processing and storage costs, data can be outsourced to CSPs. With the powerful computing capabilities and resource-sharing features
of the cloud, model training can also avoid the need for deploying expensive
hardware locally.
With the trend of outsourcing data for model training to cloud platforms, the
challenges of privacy preservation and data security have become progressively
more intricate. Chang et al. [9] introduced a defense approach based on gradi-
ent sparsification and pseudo-gradients, using cosine similarity to protect user
privacy during model training, thereby providing a certain level of privacy pro-
tection for cloud-based machine learning tasks. Although these methods reduce
the risk of privacy leakage, outsourcing data storage and processing in cloud
computing still presents many security vulnerabilities. Xiao et al. [10] pointed
out that attackers can efficiently crack encryption with minimal side-channel
data, further threatening user privacy and cloud data security. Furthermore, the
integrity of outsourced data during model training has also received considerable
attention. Numerous schemes for outsourced data auditing have been proposed.
Armknecht et al. [11] introduced the concept of Outsourced Proofs of Retrievability
(OPOR) and proposed the first outsourced auditing scheme, called Fortress. In
recent years, many auditing schemes for outsourcing data have been presented.
Shen et al. [4] proposed an efficient public auditing protocol that achieves secure
verification of outsourced data. To mitigate the risk of privacy breaches, Rao et al. [12] proposed a dynamic outsourced auditing scheme that can resist dishonest entities and collusion.
For data training and analysis in ML, quickly retrieving outsourced data from cloud servers has become an important issue. Boneh et al. [13] first proposed public-key encryption with keyword search (PEKS), enabling cloud servers to provide searchable encrypted data for user-selected keywords. Since then, several new PEKS-related schemes
set for user-selected keywords. Since then, several new PEKS-related schemes
have been proposed [14, 15]. However, the use of a single keyword search method
often returns a large number of irrelevant results, which undoubtedly signifi-
cantly affects the user’s search experience, reducing both the accuracy and effi-
ciency of the search. To address this challenge, PEKS mechanisms supporting multi-keyword search are more practical in the cloud.
Multi-keyword Searchable Data Auditing 121

Liu et al. [16] proposed a new searchable encryption scheme that enables multi-keyword search over encrypted data under a multi-writer/multi-reader setting. Zhang et al. [17] proposed a biometric-based multi-keyword search (BIB-MKS) scheme, which avoids complex certificate management.
To sum up, researchers have conducted extensive studies to advance the
development of ML. However, the storage security and retrieval efficiency of
outsourced training data often conflict with each other. The above methods do
not take both aspects into account simultaneously.

3 Preliminaries

This section reviews the preliminaries used in this paper: bilinear pairing and the discrete logarithm problem.

3.1 Bilinear Pairing

Let G_1 and G_2 be multiplicative cyclic groups of prime order q. A map e : G_1 × G_1 → G_2 is referred to as a bilinear map if it satisfies the following properties:

1) Bilinearity: for all a, b ∈ Z_q^* and u_1, u_2 ∈ G_1, e(u_1^a, u_2^b) = e(u_1, u_2)^{ab}.
2) Non-degeneracy: if u is a generator of G_1, then e(u, u) ≠ 1.
3) Computability: there exist efficient algorithms to compute e(g, v) for all g, v ∈ G_1.
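As an illustration, the bilinearity property can be checked numerically with an insecure toy simulation, in which a group element g^a is represented by its exponent a, so exponentiation becomes multiplication mod q and the pairing multiplies exponents (all parameters below are illustrative, not cryptographically secure):

```python
import random

q = 7919  # small toy prime (the group order); real schemes use ~256-bit primes

# Insecure simulation: the element g^a is represented by the exponent a mod q,
# so elem^s is s*elem mod q, and e(g^a, g^b) = e(g, g)^(ab) is a*b mod q
# in the exponent of the target group.
def g_exp(elem, s):
    return (elem * s) % q

def pair(a, b):
    return (a * b) % q

u1, u2 = 5, 11                      # two G1 elements (exponent form)
a = random.randrange(1, q)
b = random.randrange(1, q)

# Bilinearity: e(u1^a, u2^b) == e(u1, u2)^(ab)
lhs = pair(g_exp(u1, a), g_exp(u2, b))
rhs = g_exp(pair(u1, u2), (a * b) % q)
assert lhs == rhs
```

This only checks the algebra; a real implementation would use a pairing library such as PBC, as in the experiments of Sect. 8.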

3.2 Discrete Logarithm Problem

Given (g, g^α) ∈ G_1 × G_1 with α ∈ Z_q^*, the Discrete Logarithm (DL) problem is to compute α. The DL assumption holds in G_1 if the probability that any polynomial-time adversary solves the DL problem in G_1 is negligible.
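For intuition, the following sketch brute-forces a discrete logarithm in a deliberately tiny group; the loop takes on the order of |G| steps, which becomes infeasible at cryptographic group sizes (all parameters are illustrative):

```python
p = 1019                      # tiny toy prime; g = 2 generates Z_p^*
g = 2
alpha = 457                   # the secret exponent
target = pow(g, alpha, p)     # public value g^alpha

# Exhaustive search: feasible only because p is tiny; for a 256-bit group
# order this loop would take ~2^256 steps, which is the DL assumption.
recovered = next(a for a in range(1, p) if pow(g, a, p) == target)
assert recovered == alpha
```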

4 Scheme Model
This section describes the system model, security targets, and security model of the cloud-based machine learning data auditing scheme. These elements form the foundation and operational basis for the scheme design.

4.1 System Model

The system model mainly includes four entities: cloud, key generation center
(KGC), model trainer, and third-party auditor (TPA), as shown in Fig. 1. The
specific description is as follows:

Fig. 1. System Model

Cloud: The cloud serves to store datasets uploaded by model trainers, as well as
encrypted keyword sets and corresponding label sets. With powerful computing
capabilities and ample storage resources, the cloud provides an audit-proof to
verify that the files in the cloud are stored correctly and completely.

Key Generation Center: The KGC is used for creating the required keys. It
can generate keys based on the identity information sent by the model trainer
and return them through a secure channel.

Model Trainer: The model trainer is responsible for generating several key-
words for each data block and uploading the dataset, encrypted keywords, and
corresponding labels to the cloud.

Third Party Auditors: The TPA is responsible for verifying the accuracy
of the training data stored in the cloud. Once a request for auditing from the
model trainer is received, the TPA initiates an auditing challenge to the cloud.
Subsequently, the TPA verifies the validity of the proof provided by the cloud
and communicates the results to the model trainer.

4.2 Security Targets


Correctness: If the cloud completely stores the dataset uploaded by the model
trainer, the proof produced by the cloud should successfully withstand the veri-
fication procedure.

Soundness: If the cloud has not completely stored the dataset uploaded by the model trainer, the proof returned by the cloud cannot pass the TPA's verification.

Privacy Preservation: During the TPA audit process, no private data infor-
mation of the model trainer can be obtained from either the trainer’s delegated
tasks or the audit proof returned from the cloud.

4.3 Security Model


Honest and Trusted KGC: An honest and trusted KGC ensures the security
of the encryption process and communicates with the model trainer through a
secure channel to prevent key leakage or misuse.

Honest but Curious TPA: The TPA will honestly and correctly execute the
system’s algorithms. However, the TPA may be curious about the data searched
by the model trainer and could potentially infer sensitive information about the
data through keywords.

Semi-honest Cloud: A semi-honest cloud correctly executes the algorithms specified by the system. However, to protect its financial interests and uphold its reputation, the cloud may lie to users to conceal data damage or loss.

5 Our Proposed Scheme

The design details of data retrieval and auditing for outsourced model training
data in the cloud will be described in this section. It mainly includes seven stages.
Specifically, the details of each stage are outlined below.

Setup: The KGC executes this algorithm. Taking a security parameter k as input, the KGC generates the public parameters P and the master key x. The operations involved are as follows.

1) The KGC chooses a bilinear map e : G_1 × G_1 → G_2, where G_1 and G_2 are two multiplicative groups of prime order q. Let g, g_1, and u be generators of G_1. The KGC selects a master key x ∈ Z_q^* at random and computes Y = g^x.
2) The KGC chooses four distinct cryptographic hash functions for use in the system: h, h_2, h_3 : {0, 1}^* → Z_q^* and h_1 : {0, 1}^* → G_1.
3) The KGC makes the following system parameters publicly available: P = {q, g, g_1, u, G_1, G_2, h, h_1, h_2, h_3}.

Key Generation: The KGC executes the algorithm. The model trainer pro-
vides the identity ID, the system public parameters 𝑃, and the master key
𝑥. Using this information, the KGC generates the private key, which will be
returned to the model trainer through a secure channel. Specific operations are
as follows.

1) The model trainer provides the identity ID to the KGC. The KGC randomly selects r ∈ Z_q^* and computes R = g^r and k = x · h(ID ∥ R) mod q.
2) The KGC sends the private key k_ID = (R, k) to the model trainer through a secure channel.
3) For the received private key k_ID, the model trainer verifies its correctness by checking whether g^k = Y^{h(ID ∥ R)} holds. If it holds, the private key is correct; otherwise, the private key is incorrect.
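The check in step 3 follows from k = x · h(ID ∥ R) mod q and Y = g^x, since g^k = g^{x·h(ID∥R)} = Y^{h(ID∥R)}. A toy numerical sketch (illustrative parameters; h is instantiated with SHA-256, and we work directly in Z_p^* with exponents mod p−1 rather than in a prime-order subgroup):

```python
import hashlib
import random

p = 1019                      # toy prime; g = 2 generates Z_p^*
g = 2
order = p - 1                 # exponent modulus for this toy group

def h(data: bytes) -> int:    # stand-in for the scheme's hash h
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % order

x = random.randrange(1, order)        # KGC master key
Y = pow(g, x, p)                      # public value Y = g^x

ID = b"trainer-01"                    # hypothetical identity string
r = random.randrange(1, order)
R = pow(g, r, p)
k = (x * h(ID + R.to_bytes(2, "big"))) % order    # private key component

# Trainer-side verification: g^k == Y^{h(ID || R)}
assert pow(g, k, p) == pow(Y, h(ID + R.to_bytes(2, "big")), p)
```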

Tag Generation: The algorithm is run on the model trainer’s side, where the
data set, corresponding encrypted keywords, and their labels are uploaded to
the cloud. The specific steps are as follows.
1) Assume there are n data blocks in a certain model training task. The model trainer intends to upload these n data blocks to cloud storage, where B_x denotes a data block and I_x the corresponding index (1 ≤ x ≤ n). Each data block is associated with m keywords, denoted w_xy (1 ≤ y ≤ m).
2) The model trainer selects a number d ∈ Z_q^* at random and computes D = u^d. All keywords are encrypted as

   W_xy = h_1(w_xy)^d,

   and keyword labels are generated to facilitate searching in the cloud:

   V_xy = h_3(e(W_xy, D) ∥ Y ∥ I_x).

3) The label for each data block is computed as

   σ_x = g_1^{k / (h_2(R ∥ Y ∥ ∏_y W_xy) + g^{I_x})}.

4) The data block label set {σ_x}_{1≤x≤n}, the encrypted keyword set {W_xy}_{1≤x≤n, 1≤y≤m}, and the data blocks with their indices {I_x, B_x}_{1≤x≤n} are uploaded to the cloud.

Challenge Generation: Based on the audit task sent by the model trainer,
the TPA issues a challenge to the cloud.
1) When the model trainer wants to search for related datasets in the cloud that
contain the keyword 𝑤 ∗ , a number 𝑓 ∈ Z∗𝑞 is chosen at random, and 𝐹 = 𝑢 𝑓
is computed. The target keyword is encrypted as follows:

𝑊 = ( ℎ1 ( 𝑤 ∗ ) ) 𝑓

The calculation result is considered the search target’s trapdoor 𝑇𝑊 , and the
model trainer sends {𝑇𝑊 , 𝑃, 𝑅, 𝐹 } to the TPA.
2) The TPA chooses a number 𝛾 ∈ Z∗𝑞 at random, and computes 𝐸 = 𝑔 𝛾 .
3) The TPA transmits the challenge {𝑇𝑊 , 𝑃, 𝑅, 𝐹, 𝐸 } to the cloud.

Target Information Search: The cloud executes this algorithm. Based on the challenge issued by the TPA, the cloud checks the encrypted keyword labels and sends the model trainer the indices of the data blocks that contain the target keyword.
1) The cloud locates the set of data blocks uploaded by the model trainer based on the parameter Y, defining this as the searchable range. The keyword labels within this range are computed as

   {V_xy} = h_3(e(W_xy, F) ∥ Y ∥ I_x) (1 ≤ x ≤ n, 1 ≤ y ≤ m).

2) Then, the cloud checks the keyword labels within this range:

   {V_xy} ?= h_3(e(W, D) ∥ Y ∥ I_x).

   If this equality holds, there exists a set of data blocks containing the keyword label, denoted {I_x}, where x ∈ L. Otherwise, if the equality does not hold, the cloud returns 0 to the TPA.
3) The set of data indices {I_x}, where x ∈ L, is sent to the model trainer.
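The matching test in step 2 works because e(W_xy, F) = e(h_1(w_xy)^d, u^f) = e(h_1(w_xy), u)^{df}, while e(W, D) = e(h_1(w^*)^f, u^d) = e(h_1(w^*), u)^{fd}; the two agree exactly when w_xy = w^*. A toy sketch using the same insecure exponent-based pairing simulation (group elements are represented by their discrete logs; all parameters are illustrative):

```python
import hashlib
import random

q = 2**127 - 1   # a Mersenne prime as the toy group order

def g_exp(elem, s):            # elem^s in the simulated group
    return (elem * s) % q

def pair(a, b):                # e(g^a, g^b) in exponent form
    return (a * b) % q

def h1(word: str) -> int:      # stand-in for the hash-to-G1 function h_1
    return int.from_bytes(hashlib.sha256(word.encode()).digest(), "big") % q

u = 3                          # public generator (exponent form)
d = random.randrange(1, q); D = g_exp(u, d)   # chosen by the model trainer
f = random.randrange(1, q); F = g_exp(u, f)   # chosen for the query

def cloud_side(word):          # e(W_xy, F) with stored W_xy = h1(word)^d
    return pair(g_exp(h1(word), d), F)

def trapdoor_side(word):       # e(W, D) with trapdoor W = h1(word)^f
    return pair(g_exp(h1(word), f), D)

assert cloud_side("sensor") == trapdoor_side("sensor")   # same keyword matches
assert cloud_side("sensor") != trapdoor_side("motor")    # mismatch fails
```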

Proof Generation: The cloud runs this algorithm. It generates a proof based
on the data index set { 𝐼 𝑥 }, 𝑥 ∈ 𝐿 and returns it to the TPA.
1) The cloud retrieves the information for the data block set {I_x}, where x ∈ L, and computes the data block keyword label δ and the aggregated data block label μ:

   δ = E^{Σ_{x∈L} h_2(R ∥ Y ∥ ∏_y W_xy)} · E^{Σ_{x∈L} g^{I_x}},

   μ = ∏_{x∈L} σ_x.

2) The cloud sends the proof {δ, μ} to the TPA.

Proof Verification: The TPA executes this algorithm to verify whether the following equation holds, based on the proof returned by the cloud:

   e(δ, μ) ?= e(Y^{h(ID ∥ R)}, g_1)^γ.

If this equation holds, the TPA returns 1 to the model trainer, indicating that the required dataset is fully stored in the cloud; otherwise, the TPA returns 0, indicating that the data has been tampered with or corrupted.
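For a single challenged block (|L| = 1), the verification reduces to e(E^a, g_1^{k/a}) = e(g, g_1)^{γk} = e(Y^{h(ID∥R)}, g_1)^γ, with a = h_2(R∥Y∥∏_y W_xy) + g^{I_x}. The algebra can be checked with the same insecure exponent-based pairing simulation (all names and parameters are illustrative):

```python
import hashlib
import random

q = 2**127 - 1                        # toy prime group order

def g_exp(elem, s): return (elem * s) % q      # elem^s in the simulated group
def pair(a, b):     return (a * b) % q         # e(g^a, g^b) in exponent form

def H(data: bytes) -> int:                     # stand-in for the hashes h, h2
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % q

g, g1 = 2, 5                                   # generators (exponent form)
x = random.randrange(1, q)                     # KGC master key
Y = g_exp(g, x)
hid = H(b"trainer-01" + b"R")                  # h(ID || R) for fixed ID, R
k = (x * hid) % q                              # trainer's private key component

# One challenged block: a = h2(R||Y||prod W_xy) + g^{I_x}
a = (H(b"R||Y||W-labels") + g_exp(g, 42)) % q
sigma = g_exp(g1, (k * pow(a, -1, q)) % q)     # sigma_x = g1^{k/a}; here mu = sigma

gamma = random.randrange(1, q)                 # TPA challenge randomness
E = g_exp(g, gamma)
delta = g_exp(E, a)                            # E^{h2(...)} * E^{g^{I_x}} combined

lhs = pair(delta, sigma)                       # e(delta, mu)
rhs = g_exp(pair(g_exp(Y, hid), g1), gamma)    # e(Y^{h(ID||R)}, g1)^gamma
assert lhs == rhs
```

With more than one block, the σ_x aggregate into μ as in the scheme; the sketch only checks the pairing algebra for the simplest case.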

6 Our Proposed Further Scheme


The above scheme is designed for auditing a collection of data blocks for a sin-
gle keyword. However, a single keyword search may return many unnecessary
datasets, which would certainly reduce the effectiveness of model training. To

narrow the search scope and achieve more precise searching, we extend the above
scheme to support the retrieval and auditing of datasets based on multiple key-
words.
Assume the model trainer provides c keywords, denoted as the keyword set Δw^* = {w_1, w_2, ..., w_c}. The setup, key generation, and tag generation algorithms remain consistent with the previous scheme. The remaining stages are outlined in detail as follows:

Challenge Generation: The TPA issues a challenge to the cloud.


1) The model trainer selects a number f ∈ Z_q^* at random and calculates F = u^f to encrypt the keyword set Δw^*:

   ΔW_i = h_1(Δw_i^*)^f (1 ≤ i ≤ c).

   The set of calculation results is considered the search target's trapdoor set ΔT_W, and the model trainer sends {ΔT_W, P, R, F} to the TPA.
2) The TPA randomly selects a number 𝛾 ∈ Z∗𝑞 and calculates 𝐸 = 𝑔 𝛾 .
3) The TPA sends the challenge {Δ𝑇𝑊 , 𝐹, 𝑃, 𝑅, 𝐸 } to the cloud.

Target Information Search: The cloud executes this algorithm, searching for the data blocks that satisfy all keywords.
1) The cloud locates the set of data blocks uploaded by the model trainer based on the parameter Y, defining this as the searchable range. The keyword labels within this range are computed as

   {V_xy} = h_3(e(W_xy, F) ∥ Y ∥ I_x) (1 ≤ x ≤ n, 1 ≤ y ≤ m).

   Then, the cloud searches for the first keyword label within this range by checking

   V ?= h_3(e(ΔW_1, D) ∥ Y ∥ I_x).

2) The cloud records the data block index set matching the first keyword as {I_x}, where x ∈ L_1, and treats this set as the new searchable range. Within this new range, the keyword labels are computed as

   V_xy = h_3(e(W_xy, F) ∥ Y ∥ I_x) (x ∈ L_1, 1 ≤ y ≤ m),

   and the cloud searches for the second keyword label by checking

   V ?= h_3(e(ΔW_2, D) ∥ Y ∥ I_x) (x ∈ L_1).

3) Whenever the equality fails to hold for all blocks in the current range, the cloud returns 0 to the TPA; otherwise the process continues with ΔW_i until i = c, yielding the final set of data block indices {I_x}, where x ∈ L^*.
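The successive narrowing of the searchable range can be summarized without any cryptography: each keyword pass keeps only the indices that survive the equality test, shrinking {I_x} from L_1 down to L^*. A plain sketch (block indices and keyword sets are hypothetical):

```python
blocks = {                     # hypothetical index -> keyword set
    1: {"ml", "cloud", "audit"},
    2: {"ml", "iot"},
    3: {"ml", "cloud"},
    4: {"cloud", "audit"},
}

def multi_keyword_search(query):
    candidates = set(blocks)           # initial searchable range (all n blocks)
    for w in query:                    # one pass per keyword, i = 1..c
        candidates = {i for i in candidates if w in blocks[i]}
        if not candidates:             # equality fails everywhere: return 0
            return set()
    return candidates                  # the final index set L*

assert multi_keyword_search(["ml", "cloud"]) == {1, 3}
assert multi_keyword_search(["ml", "cloud", "audit"]) == {1}
assert multi_keyword_search(["iot", "audit"]) == set()
```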

Proof Generation: The cloud runs this algorithm and generates a proof based on the data index set {I_x}, x ∈ L^*, which is returned to the TPA.

1) The cloud retrieves the information for the data block set {I_x}, x ∈ L^*, then computes the data block keyword label δ and the aggregated data block label μ:

   δ = E^{Σ_{x∈L^*} h_2(R ∥ Y ∥ ∏_y W_xy)} · E^{Σ_{x∈L^*} g^{I_x}},

   μ = ∏_{x∈L^*} σ_x.

2) The cloud sends the proof {δ, μ} to the TPA.

Proof Verification: The TPA executes this algorithm to verify whether the following equation holds, based on the proof returned by the cloud:

   e(δ, μ) ?= e(Y^{h(ID ∥ R)}, g_1)^γ.

If this equation holds, the TPA returns 1 to the model trainer, indicating that the required dataset is fully stored in the cloud; otherwise, the TPA returns 0, indicating that the data has been tampered with or corrupted.

7 Security Analysis
7.1 Correctness

Theorem 1. If the model trainer, the TPA, and the cloud properly carry out
the protocol, and the data integrity proof formula holds, then the scheme is
considered to meet the requirements for audit correctness.
Based on the definition of bilinear pairing, the data integrity proof verifies as follows:

e(δ, μ) = e(E^{Σ_{x∈L^*} h_2(R ∥ Y ∥ ∏_y W_xy)} · E^{Σ_{x∈L^*} g^{I_x}}, ∏_{x∈L^*} σ_x)
        = e(E^{Σ_{x∈L^*} h_2(R ∥ Y ∥ ∏_y W_xy)} · E^{Σ_{x∈L^*} g^{I_x}}, g_1^{k / Σ_{x∈L^*}(h_2(R ∥ Y ∥ ∏_y W_xy) + g^{I_x})})
        = e(g^{γ · Σ_{x∈L^*}(h_2(R ∥ Y ∥ ∏_y W_xy) + g^{I_x})}, g_1)^{k · (1 / Σ_{x∈L^*}(h_2(R ∥ Y ∥ ∏_y W_xy) + g^{I_x}))}
        = e(g, g_1^k)^γ
        = e(g^k, g_1)^γ
        = e(Y^{h(ID ∥ R)}, g_1)^γ

Theorem 2. If the verification formula for the encrypted keyword tag search holds, the scheme considers the search verification to be correct.
When searching for a certain keyword w^*, the correctness of the search result is proved as follows:

V = h_3(e(W^*, F) ∥ Y ∥ I_x)
  = h_3(e(h_1(w^*)^d, u^f) ∥ Y ∥ I_x)
  = h_3(e(h_1(w^*)^f, u^d) ∥ Y ∥ I_x)
  = h_3(e(T_w, D) ∥ Y ∥ I_x)

7.2 Soundness

Theorem 3. If the cloud cannot forge a proof that passes the TPA's verification, the scheme is considered unforgeable.
Due to the hardness of the DL problem, the cloud cannot infer d from D and the generator u, making it impossible to forge encrypted keyword labels.

7.3 Privacy Protection

Theorem 4. The TPA is an honest but curious entity. During the auditing
process, the TPA is unable to obtain the private information of the data through
encrypted keywords.
Proof: Suppose adversary A is synonymous with the TPA and is curious about
data for model training. It can be proven that the TPA cannot obtain specific
information.
1) Adversary A receives the public parameters sent by challenger C. Due to the hardness of the DL problem, A cannot recover the key x from the public parameter Y = g^x and the generator g, and thus cannot obtain the model trainer's identity information.
2) Adversary A receives the encrypted keyword W = h_1(w^*)^f sent by challenger C. Because the hash function is one-way, adversary A cannot deduce the user's keyword w^* from the hash value.
In summary, a curious TPA cannot access any private details or real data dur-
ing the auditing process. Therefore, this guarantees data privacy during model
training.

8 Performance Analysis
In this section, we evaluate the performance of our proposed scheme and compare it with the experimental results of the schemes proposed by Wang X. et al. [18] and Wang M. et al. [19].

8.1 Computation Overhead

In the performance analysis, we define symbols to represent the different operations. H_1 and H_2 denote a hash operation in G_1 and Z_q^*, respectively. M_1 and M_2 denote multiplication operations in G_1 and Z_q^*, respectively. P_o denotes an exponentiation operation in G_1, A denotes an addition operation in Z_q^*, and P_a denotes a bilinear pairing operation. Assume there are m files, n data blocks, y keywords per block, and c challenged data blocks. The computational costs of the aforementioned schemes are compared through theoretical analysis, as presented in Table 1.

Table 1. Computational Cost Comparison.

Phase                    | Our scheme                  | Wang X. et al. [18]       | Wang M. et al. [19]
Tag/signature generation | n·(H_1 + y·A + 3·P_o)       | m·n·(2·P_o + H_1 + 2·M_2) | m·[(2n+1)·H_1 + 2n·M_1 + (2n+1)·P_o]
Keyword verification     | H_1 + P_a                   | A + M_2 + P_o + 2·P_a     | --
Proof generation         | c·((y+3)·A + H_1) + 2·P_o   | (2·c·m)·P_o + c·m·M_1     | c·P_o + (n+2c-2)·M_1 + (c·n-1)·A + (c·n)·M_2
Proof verification       | 2·P_a + H_1 + 2·P_o         | 3·P_a + P_o + M_1         | 2·P_a + (c+1)·M_1 + c·H_1 + H_2 + (c+2)·P_o

8.2 Experimental Results

The simulation for this experiment was run on the Ubuntu 22.04.5 system, con-
figured with an Intel Core i7-12650H CPU 2.68 GHz, and 3.8 GB of RAM.
The experiment utilized the GMP and PBC libraries and was implemented in C
language.
In the experiment, we assume there is one file, with each data block containing 10 keywords. The experiment covers four phases: tag/signature generation, keyword verification, proof generation, and proof verification. The specific experimental results are as follows.

Tag/Signature Generation Phase: In this phase, the cost depends on the total number of data blocks. We vary the number of data blocks from 100 to 800 in increments of 50. As shown in Fig. 2, our scheme is more cost-effective than Wang M. et al.'s scheme, because the latter introduces auxiliary labels that increase computational costs. In Wang X. et al.'s scheme, the signature processing step executed in this phase is relatively simple. Considering both security and efficiency, our proposed scheme offers more advantages.

Keyword Verification Phase: The cost in this phase is independent of the number of data blocks. We perform multiple repetitions and report the average, with the results presented in Fig. 3. As shown in the figure, our scheme has lower costs than the scheme proposed by Wang X. et al.

Fig. 2. Cost comparison of the tag/signature generation phase (x-axis: number of data blocks n, 100-800; y-axis: time cost in seconds)

Fig. 3. Cost comparison of the keyword verification phase (x-axis: number of repetitions r, 10-80; y-axis: time cost in seconds)

Proof Generation Phase: This phase’s cost is related to the count of data
blocks used for challenges. We configure the data block count for challenges
between 100 and 800 (thus, the total amount of blocks 𝑛 is configured to 1000),
increasing by 50 blocks at each step. When the data block count is small, our
scheme has a similar cost to the schemes proposed by Wang M. et al. and Wang
X. et al. as shown in Fig. 4. As the count of data blocks increases, our scheme
exhibits an evidently lower cost.

Proof Verification Phase: In this phase, for both our scheme and the scheme proposed by Wang X. et al., the cost is unaffected by the number of challenged data blocks. Compared with Wang M. et al.'s scheme, both significantly improve audit efficiency. As shown in Fig. 5, our scheme has lower time costs and higher efficiency than the scheme proposed by Wang X. et al.

Fig. 4. Cost comparison of the proof generation phase (x-axis: number of challenged data blocks c, 100-800; y-axis: time cost in seconds)

Fig. 5. Cost comparison of the proof verification phase (x-axis: number of challenged data blocks c, 100-800; y-axis: time cost in seconds)

In summary, the simulation experiments conducted in this section demon-


strate that our scheme offers advantages in both security and efficiency.

9 Conclusion

This paper proposes an innovative scheme for multi-keyword searchable data


auditing in the context of cloud-based machine learning. During the model train-
ing process, it is necessary to retrieve and search for data outsourced to cloud
storage. The scheme we proposed supports multi-keyword retrieval and audits
the security and integrity of the training data with the assistance of the TPA.
Moreover, the TPA cannot infer relevant data or private information from the
keywords, and the cloud cannot forge proofs. After the auditing process is com-
pleted, the model trainer can quickly extract the relevant dataset using data
block indices and proceed with model training. The theoretical analysis and
experimental results demonstrate that the scheme can effectively and reliably
protect the integrity and privacy of outsourced model training data.

Acknowledgments. The work is supported by the National Key R&D Program of


China (No. 2023YFB2703700), the National Natural Science Foundation of China (Nos.
U21A20465, 62302457, 62172292), the Zhejiang Provincial Natural Science Foundation
of China (No. LQ24F020008), the Program for Leading Innovative Research Team of
Zhejiang Province (No. 2023R01001), the Fundamental Research Funds of Zhejiang
Sci-Tech University (No. 22222266-Y) and the “Pioneer” and “Leading Goose” R&D
Program of Zhejiang (No. 2023C01119).

References
1. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep con-
volutional neural networks. In: Bartlett, P.L., Pereira, F.C.N., Burges, C.J.C.,
Bottou, L., Weinberger, K.Q. (eds.) Advances in Neural Information Processing
Systems 25: 26th Annual Conference on Neural Information Processing Systems
2012. Proceedings of a Meeting Held 3–6 December 2012, Lake Tahoe, Nevada,
United States, pp. 1106–1114, (2012)
2. De Sa, C., Ratner, A., Ré, C., Shin, J., Wang, F., Sen, W., Zhang, C.: Incremen-
tal knowledge base construction using deepdive. VLDB J. 26(1), 81–105 (2017).
https://doi.org/10.1007/S00778-016-0437-2
3. Rekatsinas, T., Chu, X., Ilyas, I.F., Ré, C.: HoloClean: holistic data repairs with
probabilistic inference. Proc. VLDB Endow. 10(11), 1190–1201 (2017). https://
doi.org/10.14778/3137628.3137631
4. Shen, J., Shen, J., Chen, X., Huang, X., Susilo, W.: An efficient public auditing
protocol with novel dynamic structure for cloud data. IEEE Trans. Inf. Forensics
Secur. 12(10), 2402–2415 (2017). https://doi.org/10.1109/TIFS.2017.2705620
5. Sabek, I., Mokbel, M.F.: Machine learning meets big spatial data. In: 2020 IEEE
36th International Conference on Data Engineering (ICDE), pp. 1782–1785 (2020).
https://doi.org/10.1109/ICDE48307.2020.00169

6. Liu, Z., Guo, J., Lam, K.-Y., Zhao, J.: Efficient dropout-resilient aggregation for
privacy-preserving machine learning. IEEE Trans. Inf. Forensics Secur. 18, 1839–
1854 (2023). https://doi.org/10.1109/TIFS.2022.3163592
7. Mohassel, P., Zhang, Y.: SecureML: a system for scalable privacy-preserving
machine learning. In: 2017 IEEE Symposium on Security and Privacy (SP), pp.
19–38 (2017). https://doi.org/10.1109/SP.2017.12
8. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving
and personalized federated learning framework. IEEE Internet Things J. 11(11),
19380–19393 (2024). https://doi.org/10.1109/JIOT.2024.3360153
9. Chang, W., Zhu, T.: Gradient-based defense methods for data leakage in vertical
federated learning. Comput. Secur. 139, 103744 (2024)
10. Xiao, Z., Wang, C., Shen, J., Jonathan Wu, Q.M., He, D.: Less traces are all it
takes: Efficient side-channel analysis on AES. IEEE Trans. Comput.-Aided Design
Integr. Circuits Syst. (2024). https://doi.org/10.1109/TCAD.2024.3518414
11. Armknecht, F., Bohli, J.-M., Karame, G.O., Liu, Z., Reuter, C.A.: Outsourced
proofs of retrievability. In: Proceedings of the 2014 ACM SIGSAC Conference
on Computer and Communications Security, CCS 2014, pp. 831–843. Association
for Computing Machinery, New York (2014). https://doi.org/10.1145/2660267.
2660310
12. Rao, L., Zhang, H., Tengfei, T.: Dynamic outsourced auditing services for cloud
storage based on batch-leaves-authenticated merkle hash tree. IEEE Trans. Serv.
Comput. 13(3), 451–463 (2020). https://doi.org/10.1109/TSC.2017.2708116
13. Boneh, D., Di Crescenzo, G., Ostrovsky, R., Persiano, G.: Public key encryption
with keyword search. In: Cachin, C., Camenisch, J.L. (eds.) EUROCRYPT 2004.
LNCS, vol. 3027, pp. 506–522. Springer, Heidelberg (2004). https://doi.org/10.
1007/978-3-540-24676-3_30
14. Zhang, X., Chunxiang, X., Wang, H., Zhang, Y., Wang, S.: FS-PEKS: lattice-based
forward secure public-key encryption with keyword search for cloud-assisted indus-
trial internet of things. IEEE Trans. Dependable Secure Comput. 18(3), 1019–1032
(2021). https://doi.org/10.1109/TDSC.2019.2914117
15. Liu, Q., et al.: Authorized keyword search on mobile devices in secure data out-
sourcing. IEEE Trans. Mob. Comput. 23(5), 4181–4195 (2024). https://doi.org/
10.1109/TMC.2023.3288160
16. Liu, X., Yang, G., Susilo, W., Tonien, J., Liu, X., Shen, J.: Privacy-preserving
multi-keyword searchable encryption for distributed systems. IEEE Trans. Parallel
Distrib. Syst. 32(3), 561–574 (2021)
17. Zhang, X., Huang, C., Dawu, G., Zhang, J., Wang, H.: BIB-MKS: post-quantum
secure biometric identity-based multi-keyword search over encrypted data in cloud
storage systems. IEEE Trans. Serv. Comput. 16(1), 122–133 (2023). https://doi.
org/10.1109/TSC.2021.3112779
18. Wang, X., Zhang, X., Zhang, X., Miao, Y., Xue, J.: Enabling anonymous autho-
rized auditing over keyword-based searchable ciphertexts in cloud storage systems.
IEEE Trans. Serv. Comput. 16(6), 4220–4232 (2023). https://doi.org/10.1109/
TSC.2023.3315972
19. Wang, M., Jia, Yu., Shen, W., Hao, R.: Privacy-preserving time-based auditing
for secure cloud storage. IEEE Trans. Inf. Forensics Secur. 19, 7866–7878 (2024).
https://doi.org/10.1109/TIFS.2024.3449095
A Flexible Keyword-Based PIR Scheme
with Customizable Data Scales
for Multi-server Learning

Jingang Li, Zhaoyi Liu, Huijie Yang(B) , Tianqi Zhou, and Wenying Zheng

School of Information Science and Engineering (School of Cyber Science and


Technology), Zhejiang Sci-Tech University, Hangzhou, China
[email protected], [email protected]

Abstract. As the digital era progresses, data has become an invalu-


able asset for businesses and organizations. Conducting in-depth analy-
ses of specific data categories requires a strategy that ensures data pri-
vacy while effectively leveraging distributed data resources. However,
existing technologies face two major challenges: first, accurately defin-
ing the boundary of training data is difficult, leading to the inclusion
of irrelevant data that can negatively impact model training accuracy;
second, user-side retrieval privacy is vulnerable to breaches, potentially
exposing the user’s training intentions to the server. To address these
challenges, we propose a novel user-proxy server-multi-server framework
within the context of Private Information Retrieval (PIR), designed to
protect user privacy while learning from specific datasets. Building on
this framework, we introduce a keyword-based PIR protocol tailored for
multi-server deep learning models. This protocol allows users to query
and retrieve targeted datasets from servers via a proxy server, enabling
high-precision model training with specific data subsets. Experimental
results demonstrate that, under various query volumes and server con-
figurations, the proposed framework significantly reduces response times
and improves query efficiency. Additionally, it exhibits remarkable scala-
bility and adaptability, handling dynamic database updates and varying
dataset sizes. This approach provides a secure and effective solution for
training deep learning models across multiple data sources, making it
particularly suitable for applications in smart factories and intelligent
manufacturing environments, where data privacy is critical.

Keywords: Keyword PIR · Deep Learning · Proxy server


Re-encryption · Multiple-Server

1 Introduction
Since the concept of PIR was proposed by Chor et al. [7], there has been growing focus on enhancing the efficiency and privacy of PIR systems for user queries. Depending on the data storage method in the server's database, PIR can be
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 133–147, 2025.
https://doi.org/10.1007/978-981-96-4566-4_10
134 J. Li et al.

divided into two main types: index-based PIR and keyword-based PIR. Index-
based PIR, the traditional technique, involves a user requesting specific entries
by their position to construct a query vector. Notable methods in index-based
PIR include those proposed by Melchor et al. [11], Angel et al. [3], and Ali et
al. [2].
However, as databases increasingly adopt key-value pair storage, recent
research has shifted towards keyword-based PIR protocols. In these protocols,
each database entry is associated with one or more keywords, allowing users to
retrieve information by specifying these keywords. Recent advancements in key-
word PIR have addressed several critical areas: reducing communication over-
head, supporting batch queries, optimizing performance for sparse databases,
and enabling multi-keyword searches. Mahdavi et al. [10] introduced constant-
weight equality operators that allowed for efficient single-round, single-server
keyword retrieval. Celi and Davidson [5] proposed the ChalametPIR protocol, which
used Binary Fuse filters to minimize bandwidth usage during keyword searches.
Patel et al. [13] presented SparsePIR to improve encoding time, offering better
cost-effectiveness for sparse databases. Meanwhile, Pandya et al. [12] developed
Bloom Swizzler, which enhanced query efficiency through efficient indirect layers
in sparse datasets. In terms of batch queries, Liu et al. [9] presented the PIRANA
system, which supports frequent updates while integrating tagged private set
intersection techniques. Tovey et al. [15] presented the DPIR protocol to facilitate
task distribution and ensure metadata privacy in distributed environments. Yang et al. [19]
proposed FGF-mkPIR, a scheme that employed batch oblivious pseudo-random
functions for flexible multi-keyword queries, ensuring data integrity and correct-
ness in AIoT settings. Tong et al. [14] implemented verifiable multi-keyword
queries on encrypted data using advanced hashing techniques like Bloom fil-
ters and Cuckoo hashing. The practical applications of keyword PIR protocols
have also expanded into various fields. Xie et al. [18] designed an optimized
two-step keyword PIR protocol that enabled lightweight Bitcoin users to effi-
ciently retrieve transactions from full nodes without revealing query details.
Vadapalli et al. [16] created PIRSONA, a digital content delivery system that
protected user consumption patterns while providing personalized recommenda-
tions. Ahmad et al. [1] introduced Coeus, a document retrieval system that com-
bined secure matrix-vector multiplication with keyword PIR for oblivious rank-
ing and retrieval. Vora and Hegde [17] developed a sublinear-time-complexity
scheme for keyword-based private searching on cloud data, including mecha-
nisms for keyword association and dissociation.
Furthermore, privacy protection and data processing in distributed environ-
ments have received increasing attention. Zhang et al. [21] explored the iden-
tification of unlearning risks in edge computing environments, which held sig-
nificant implications for our research, especially in defining the boundaries of
training data and protecting user retrieval privacy. Chen et al. [6] highlighted the
importance of high-frequency information in watermark attacks and defenses for
image-processing models, providing valuable insights for our study, particularly
in preventing model stealing attacks. Zhang et al. [20] proposed the PPFed
framework, offering a privacy-preserving and personalized federated learning
solution that guided our design of efficient and secure data access mechanisms,
especially for high-precision classification model training on specific datasets in
multi-server environments.

A Flexible Keyword PIR Scheme for Multi-server Learning 135
Despite significant advancements in enhancing the efficiency and privacy
protection of PIR protocols, several challenges remain. For instance, efficiently
handling multi-keyword queries while maintaining query privacy in large-scale
databases is still an open issue. We propose a novel PIR protocol aimed at fur-
ther optimizing these problems. Specifically, we propose a keyword-based PIR
protocol and innovatively construct a user-proxy server-multi servers framework,
enabling the user to access specific data across different servers, which is then
aggregated on the proxy server for model training.
In this framework, a keyword-based HashPIR protocol is proposed, utiliz-
ing hash functions to transform keywords into coordinates within a hypercube
structure. This approach effectively converts keyword PIR queries into index
PIR methods. Additionally, both the user and servers use the same hash map-
ping technique, ensuring the accuracy of private information retrieval. Finally,
the protocol is highly flexible, allowing server data to be updated without the
need for reindexing. The contributions of this manuscript are as follows:
– The user-proxy server-multi servers framework is designed in
keyword-based PIR scheme. This design leverages the intermediary role
of the proxy server to not only provide model training services for users lack-
ing computational resources but also to isolate unnecessary access by users
to server data.
– A keyword-based protocol HashPIR is proposed. This method uses a
hash function to map all entries in the server’s database onto a hypercube,
allowing users to reduce their desired keywords to corresponding coordinates
in the recursive query of the hypercube. Consequently, this transforms key-
word PIR into index PIR, streamlining the query process.
– The proposed HashPIR supports customizable data scales. By adopt-
ing a hypercube structure, our system can efficiently handle dynamic data
updates and datasets of varying sizes. The hypercube mapping ensures that
the system remains scalable and responsive to real-time changes without
requiring reordering or reindexing.
The remainder of this paper is organized as follows. The preliminaries about
Brakerski-Fan-Vercauteren Homomorphic Encryption (BFV) and structure of
hypercube are described in Sect. 2. The system model is introduced in Sect. 3.
The keyword-based PIR scheme is proposed in Sect. 4. The performance analysis
is presented in Sect. 5. Section 6 illustrates the conclusion of this paper.

2 Preliminary
2.1 Brakerski-Fan-Vercauteren Homomorphic Encryption
Brakerski-Fan-Vercauteren Homomorphic Encryption (BFV) [4, 8] is a
fully homomorphic encryption scheme based on the Ring Learning With Errors
(RLWE) problem. This allows the server to perform certain types of computa-
tions on the encrypted data without decrypting it, which is particularly useful
for protecting sensitive data.
BFV operates on the ring R = Z[x]/(x^n + 1), where n is the degree of the
polynomial (typically a power of two) and x^n + 1 is an irreducible polynomial.
A modulus q is chosen such that q > 2n^2 and q is an odd number.
During the generation of the public-private key pair, the private key s ∈ R
is selected first, followed by the construction of the public key pk = (a, b), where
a is a randomly chosen polynomial and b = a · s + e mod q, with e being a small
noise term.
The plaintext m is encoded as an element of the ring Rt = R/tR =
Zt[x]/(x^n + 1), where t is a small prime. The ciphertext c = (c0, c1) is then
computed as c0 = u · b + Δ · m + e′ and c1 = u · a + e′′, where u is a random
polynomial, Δ = ⌊q/t⌋ is the plaintext scaling factor, and e′, e′′ are small noise
polynomials.
Finally, decryption computes m = ⌊(t/q) · ((c0 + c1 · s) mod q)⌉ mod t to
recover the original plaintext polynomial m.
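The ring arithmetic underlying BFV can be sketched in plain Python. This toy (parameters far too small to be secure, and no actual encryption performed) shows addition and negacyclic multiplication in R_q = Z_q[x]/(x^n + 1); a real system would use a lattice library such as Microsoft SEAL, which the SealPIR calls used later in this paper build on.

```python
# Toy arithmetic in the ring R_q = Z_q[x]/(x^n + 1) used by BFV.
# Illustrative only: parameters are tiny and insecure, and no
# encryption/decryption is performed here.

def poly_add(a, b, q):
    """Coefficient-wise addition modulo q."""
    return [(x + y) % q for x, y in zip(a, b)]

def poly_mul(a, b, q):
    """Negacyclic convolution: product mod (x^n + 1) and mod q."""
    n = len(a)
    res = [0] * n
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            k = i + j
            if k < n:
                res[k] = (res[k] + ai * bj) % q
            else:
                # x^n ≡ -1 in this ring, so overflow wraps with a sign flip
                res[k - n] = (res[k - n] - ai * bj) % q
    return res

n, q = 4, 97                      # toy parameters (insecure!)
a = [1, 2, 0, 0]                  # 1 + 2x
b = [0, 1, 0, 0]                  # x
print(poly_mul(a, b, q))          # [0, 1, 2, 0], i.e. x + 2x^2
# Wrap-around: x^3 * x = x^4 ≡ -1  (mod x^4 + 1)
print(poly_mul([0, 0, 0, 1], b, q))   # [96, 0, 0, 0], i.e. -1 mod 97
```

The sign flip at the wrap-around is exactly what makes the polynomial modulus x^n + 1 (rather than x^n − 1) the natural choice for RLWE-based schemes.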

Fig. 1. Hypercube Recursion

2.2 The Structure of Hypercube

In the server, there are a total of n entries, each stored at a unique coordinate
within a hypercube denoted as D. The hypercube possesses d dimensions, with
each dimension having length l1, l2, . . . , ld, respectively, such that
l1 · l2 · · · ld ≥ n. When a user wishes to query specific data, it constructs corresponding
encrypted query vectors based on coordinates in each dimension and can then
perform the query recursively.
For example, in a two-dimensional database, Fig. 1 illustrates the process of
hypercube recursion. The user separately constructs horizontal encrypted query
vector vrow and vertical encrypted query vector vcol , with each coordinate of
hypercube D storing unique data. In the first round of recursion, the intermediate
result D2 is obtained by the dot product of vcol and D. In the second round of
recursion, the final result is obtained by the dot product of vrow and D2 , which
is then returned to the user.
The advantage of using a hypercube lies in the fact that only encrypted
query vectors of total length l1 + l2 + · · · + ld need to be transmitted, whereas
without this structure, an encrypted query vector of length n would have to be sent.
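The two-round recursion of Fig. 1 can be walked through with plaintext stand-ins. The 0/1 selection vectors below play the role of the encrypted vectors vcol and vrow; under BFV the very same dot products are evaluated homomorphically, so the server never learns which coefficient is 1. The helper name dot_select is illustrative.

```python
# Plaintext sketch of the two-round hypercube recursion of Fig. 1.

def dot_select(vec, rows):
    """Sum of vec[i] * rows[i]: with a one-hot vec this picks one row."""
    out = [0] * len(rows[0])
    for coeff, row in zip(vec, rows):
        out = [o + coeff * r for o, r in zip(out, row)]
    return out

l = 4
D = [[10 * r + c for c in range(l)] for r in range(l)]  # l x l database

v_col = [0, 0, 1, 0]              # request column 2
v_row = [0, 1, 0, 0]              # request row 1

# Round 1: dot product against the columns collapses one dimension.
D2 = dot_select(v_col, [list(col) for col in zip(*D)])
# Round 2: dot product along the remaining dimension yields the entry.
result = dot_select(v_row, [[x] for x in D2])
print(result[0])                  # 12 == D[1][2]
# Query cost: 2*l = 8 coefficients instead of l*l = 16 without the cube.
```

Even in this tiny example the saving is visible: the query shrinks from l² coefficients to 2l, and the gap widens rapidly for realistic database sizes.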

3 System Model and Formal Definition


3.1 User-Proxy Server-Multi Servers Framework

Assuming there exist N servers in a certain region, each server independently


stores data in key-value pairs. To effectively collect specific data from different
servers and avoid the interference of irrelevant data on the performance of clas-
sification models, the user specifies which data items, and how many, each
server should contribute to deep learning, and interacts with the servers through
a proxy server. We refer to this framework as the user-proxy server-multi servers
framework, which is shown in Fig. 2.
The user, lacking computational resources, can communicate with the par-
titioned servers via the proxy server to download trained models. The proxy
server not only forwards the query vectors from the user to each server but also
isolates the training data by using proxy server re-encryption, sending only the
trained weight files to the user. This approach effectively protects the privacy of
the training data from each server.
By incorporating PIR, the servers do not know the specific data content
specified by the users, thereby effectively preventing the risk of leaking research
intentions. Moreover, specifying only a subset of data for deep learning training,
rather than all data from the servers, significantly reduces the possibility of
server data features being stolen.
In this scheme, the deep learning model training phase occurs after the
Extract step of PIR. Since the aim of this scheme is to train models for users
without allowing them direct access to the data on the servers, the answer vector
from PIR is first decrypted on the proxy server using proxy server re-encryption.
Subsequently, deep learning training is conducted on the proxy server. Once
training is complete, the trained weight files are encrypted using a symmetric
encryption algorithm and returned to the user. The user then decrypts the files
using the symmetric encryption key to obtain the trained weights.

Fig. 2. User-proxy server-Multi Servers Framework

3.2 Formal Definition

Definition 1. The HashPIR scheme is defined over a key space K =
{k1, k2, . . . , km} and l-bit blocks, composed of an efficient set of algorithms
HashPIR = (Key Generation, Offline Setup, Query, Answer, Extract). To pro-
vide a more concise formal definition, the client in this description sends only
a single keyword request to the server; in the actual system model, however, the
client sends a set of keywords K. HashPIR is described as follows:

(pk, sk, keyPRE, keys) ← Key Generation(1^λ): It takes a security parameter
λ and outputs the homomorphic encryption public key pk, the client-held homomor-
phic encryption private key sk, the proxy server re-encryption key keyPRE, and the
symmetric encryption key keys.
D ← Offline Setup(n, d, l): Given the number of entries n, dimension d, and
length l such that l^d ≥ n, it outputs the hypercube D. Each entry of the
database D = {(k1, v1), . . . , (kn, vn)} is mapped to a position in a d-dimensional
hypercube using a hash function.
q ← Query(k, d, l): It takes the query keyword k, dimension d, and length
l as input, first generates a d-dimensional vector (v1 , v2 , ..., vd ) through a hash
function, then uses the Query function of SealPIR to generate the query vector q.
Finally, it encrypts the query vector q using the BFV homomorphic encryption
scheme, and sends the encrypted query vector to the proxy server.
A ← Answer(q, D): It takes encrypted query vectors from the user and the
database D in hypercube form, and through the recursive form of the hypercube,
it outputs the encrypted response vector A extracted from the database.
kvalue ← Extract(keyPRE, A): It decrypts all response vectors using the re-
encryption key keyPRE to recover the original data value kvalue.

Correctness. For any database D = {(k1, v1), . . . , (kn, vn)}, where k1, . . . , kn ∈
K are distinct and v1, . . . , vn ∈ {0, 1}^L, and for any query keyword k ∈ K, the
probability that the client fails to output the value D[k] associated with k is at
most negl(λ).

Client's Query Privacy. Client query privacy is defined with respect to a bit
b ∈ {0, 1} and an adversary A. For any database D and for all stateful, efficient
adversaries A, the probability that the adversary successfully distinguishes the
bit b in the experiment is bounded as in Eq. (1):

Pr[Exp^{client}_{b,A}(λ, D) = b] ≤ 1/2 + negl(λ).          (1)

4 HashPIR: A Keyword-Based PIR


Because, in practice, server databases predominantly store data as key-value
pairs, we propose a flexible keyword-based PIR protocol named HashPIR,
which supports dynamic databases. This protocol
converts keywords into coordinates within a hypercube using a hash function
and stores the corresponding values at these coordinates, thereby transforming
keyword PIR into index PIR. Index PIR is then invoked to perform standard
retrieval.

4.1 Init Phase


The pseudocode for key distribution in the Init phase is shown in Algorithm 1
and includes the following steps:

– Using the security parameter 1^λ, the user generates a homomorphic
encryption public-private key pair (pk, sk) and a symmetric encryption key
keys. The user sends the homomorphic encryption public key pk to the proxy
server, which uses it for proxy server re-encryption to generate keyPRE.
– The partitioned servers convert the database data into plaintext polynomial
form.
– The proxy server uses proxy server re-encryption to generate a re-encryption
key keyPRE that can decrypt BFV ciphertexts.

To efficiently and independently store ni entries in each server Si , each server


uses a hash function to map the data into an independent d-dimensional hyper-
cube. Each data key is mapped to a sparse vector p, which determines the posi-
tion of the data's value within the hypercube. The pseudocode for this process is
shown in Algorithm 1. Specifically, in each d-dimensional hypercube, the length
of each dimension is l, such that l^d ≥ ni.
For an entry (k, kvalue ) within the server, where k is the key, we first select
a hash function H to compute the hash value h = H(k) of the key k. For
each dimension dim ∈ {1, 2, . . . , d}, extract bits from position (dim − 1) ∗ 8 to
(dim ∗ 8 − 1) from h as hdim . Initialize a zero vector vdim of length l for each
dimension. Compute the index idxdim = hdim mod l using the extracted 8-bit
segment hdim, and set the idxdim-th position of vdim to 1. The position of the 1
in the vector represents the location of the data in that dimension.
In this way, positions are sequentially activated across each dimension, and
the value kvalue from the entry (k, kvalue) is stored at position (v1, v2, . . . , vd)
in the hypercube D. For example, if d = 2, v1 = [0 1], and v2 = [0 1], the value
is stored at the second position of the first dimension and the second position
of the second dimension. All values in each server are independently mapped
into a hypercube D based on their corresponding keys.

Algorithm 1. HashPIR.Init
1: Input: Security parameter 1^λ, number of entries ni, dimension d, length l such
   that l^d ≥ ni.
2: Output: Homomorphic encryption public key pk, symmetric encryption key keys,
   re-encryption key keyPRE, hypercubes D1, D2, . . . , Dm.
3: Key Generation:
4: Generate homomorphic encryption keys (pk, sk) using security parameter 1^λ.
5: Generate symmetric encryption key keys.
6: Send pk and keys to the proxy server.
7: The proxy server generates the re-encryption key keyPRE using pk.
8: Offline Setup in Each Server Si:
9: Each server Si transforms database data into plaintext polynomial form.
10: Initialize empty hypercube storage D1, D2, . . . , Dm ← ∅.
11: for each entry (k, kvalue) in Si do
12:   Compute hash value h ← H(k).
13:   Initialize d vectors {v1, v2, . . . , vd}, each of length l, with all zeros.
14:   for each dimension dim ∈ {1, 2, . . . , d} do
15:     Extract bits (dim − 1) ∗ 8 to (dim ∗ 8) − 1 from h as hdim.
16:     Calculate index idxdim ← hdim mod l.
17:     Set vdim[idxdim] ← 1.
18:   end for
19:   Store value kvalue in hypercube Di at position (v1, v2, . . . , vd).
20: end for
21: return pk, keys, keyPRE, D1, D2, . . . , Dm
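The keyword-to-coordinate mapping can be sketched concretely. This sketch assumes SHA-256 as the (unspecified) hash H, takes byte dim of the digest as the 8-bit segment for dimension dim, and ignores collisions (two keys landing in the same cell) as well as the encrypted layer.

```python
# Sketch of the keyword-to-coordinate mapping of HashPIR's offline setup.
# Assumption: SHA-256 stands in for the hash function H.
import hashlib

def keyword_to_coords(key, d, l):
    """Bits (dim*8)..(dim*8 + 7) of H(key), reduced mod l, per dimension."""
    h = hashlib.sha256(key.encode()).digest()
    return [h[dim] % l for dim in range(d)]

def coords_to_vectors(coords, l):
    """One-hot vectors v_1..v_d; the 1 marks the cell in each dimension."""
    return [[1 if i == c else 0 for i in range(l)] for c in coords]

d, l = 3, 64                      # matches the (3, 64) setting of Sect. 5
store = {}                        # the hypercube, kept sparse as a dict
for key, value in [("cat", b"img0"), ("dog", b"img1")]:
    store[tuple(keyword_to_coords(key, d, l))] = value

# The client applies the *same* mapping, so its query vectors address
# exactly the cell the server used -- no index ever changes hands.
assert store[tuple(keyword_to_coords("cat", d, l))] == b"img0"
```

Because both sides derive coordinates from H(k) alone, the keyword query reduces to an index query over (v1, . . . , vd), which is precisely the transformation that lets HashPIR reuse index-PIR machinery.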

4.2 Query Phase


In the Query phase, the user can specify a set of m keywords K =
{k1 , k2 , . . . , km }, and needs to construct query vectors v1 , v2 , . . . , vd for each
specified keyword k. The process by which the user constructs the query vector
is consistent with how the server constructs the storage vector, using the same
hash function H(k) to obtain (v1 , v2 , . . . , vd ). For each dimension of the query
vector, this step transforms the desired keyword into coordinates correspond-
ing to the hypercube, thereby converting keyword PIR into index PIR, which is
shown in Algorithm 2.

The user uses the Query function of SealPIR to generate the query vector q
from the d-dimensional vector (v1 , v2 , . . . , vd ), and then encrypts it using BFV.
Since the user can specify data from a total of m servers within this framework
to participate in training, they construct (q1 , q2 , . . . , qm ) and send them to the
proxy server.

Algorithm 2. HashPIR.Query (In one server)


1: Input: Set of keywords K = {k1 , k2 , . . . , km }, number of dimensions d, length l.
2: Output: Encrypted query vectors (q1 , q2 , . . . , qm ).
3: Initialize an empty list of encrypted query vectors Q ← ∅.
4: for each keyword kj in K do
5: Compute hash value hj ← H(kj ).
6: Initialize a vector vj ← [0, 0, . . . , 0] of length l for each dimension.
7: for each dimension dim ∈ {1, 2, . . . , d} do
8: Extract bits from position (dim − 1) ∗ 8 to (dim ∗ 8 − 1) from hj as hj,dim .
9: Calculate index idxj,dim ← hj,dim mod l.
10: Set the idxj,dim -th position of vj,dim to 1.
11: end for
12: qj ← SealPIR.Query(vj,1 , vj,2 , . . . , vj,d )
13: Encrypt qj using BFV.
14: Append qj to the list Q.
15: end for
16: Construct the vector (q1 , q2 , . . . , qm ).
17: Send (q1 , q2 , . . . , qm ) to the proxy server.
18: return (q1 , q2 , . . . , qm )

4.3 Answer and Decrypt Phase

In the Answer phase, the proxy server forwards q1 , q2 , . . . , qm to servers


S1 , S2 , . . . , Sm respectively for constructing response vectors. The pseudocode
for building the response vector is shown in Algorithm 3. In this process, the
database D, which was converted into plaintext polynomials in hypercube form
during the Init phase, is used. Since the d-dimensional encrypted query vectors
for index PIR were obtained in the Query phase, the response vector for the
keyword can be derived using a recursive method on the hypercube. Once all
queries on the server are completed, the results are returned to the proxy server.
In the Decrypt phase, the proxy server receives all response vectors and
decrypts them sequentially using proxy server re-encryption to obtain the plain-
text data retrieved by HashPIR from the m servers. Subsequently, deep learning
training is performed on the proxy server with the decrypted plaintext data.
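The d-dimensional recursion of the Answer phase generalizes the two-round example of Sect. 2.2: each round folds the outermost dimension with that dimension's selection vector. With one-hot vectors the homomorphic sum of coeff · subcube equals the single selected subcube, which this plaintext sketch (function names illustrative) exploits for readability.

```python
# Plaintext sketch of the recursive Answer computation over a hypercube.

def answer(cube, vectors):
    """Recursively collapse a nested-list hypercube, one dimension per round."""
    if not vectors:               # all d dimensions collapsed: one entry left
        return cube
    v, rest = vectors[0], vectors[1:]
    # Homomorphically this is sum(coeff * sub); with a one-hot v it is
    # simply the selected subcube.
    selected = [sub for coeff, sub in zip(v, cube) if coeff == 1][0]
    return answer(selected, rest)

# A 2x2x2 hypercube (d = 3) and query vectors for coordinates (1, 0, 1).
cube = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
vectors = [[0, 1], [1, 0], [0, 1]]
print(answer(cube, vectors))      # 5, i.e. cube[1][0][1]
```

In the actual protocol the intermediate results stay encrypted throughout, and only the proxy server, holding keyPRE, recovers the plaintext at the end.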

Algorithm 3. HashPIR.Answer and HashPIR.Decrypt


1: Input: Encrypted query vectors (q1 , q2 , . . . , qm ), database D in hypercube form.
2: Output: Decrypted data for deep learning training.
3: proxy server Forwards Queries:
4: for each server Sj and each query vector qj do
5: proxy server forwards qj to server Sj .
6: end for
7: HashPIR.Answer
8: for each server Sj do
9: Initialize an empty answer vector Aj ← ∅.
10: for each qj do
11: Retrieve the corresponding data from the hypercube D based on the index
given by qj .
12: Append the answer vector to Aj .
13: end for
14: Return Aj to the proxy server.
15: end for
16: HashPIR.Decrypt
17: Initialize an empty set of decrypted data Ddecrypted ← ∅.
18: for each answer vector Aj received from servers do
19: Use re-encryption key keyP RE to decrypt Aj .
20: Ddecrypted ← Ddecrypted ∪ {kvalue }
21: end for
22: Perform deep learning training on Ddecrypted .
23: return Ddecrypted

4.4 A Flexible Approach


The approach proposed in this manuscript utilizes a keyword PIR protocol to
protect user privacy during query retrieval in a distributed deep learning frame-
work. This system operates seamlessly even when the underlying databases are
updated. For each server Si , a set of mi entries is specified, and the data and
volume on the server may vary across different rounds of model training. In this
framework, the database D on each server can be dynamically updated, and the
HashPIR protocol ensures that the updated database is accurately mapped to
the hypercube without requiring changes to the underlying approach. Addition-
ally, the number of participating servers N can be scaled up or down, providing
flexibility for users. This method is highly adaptable to dynamic and distributed
data sources, offering secure and efficient ways for users to retrieve specific data
from multiple servers while maintaining privacy protection.
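The update flexibility claimed above can be demonstrated directly: because a cell address depends only on the inserted entry's own key, an update never moves existing data. As before, SHA-256 stands in for the scheme's hash H, and the key names are illustrative.

```python
# Sketch: dynamic inserts into the hypercube need no reindexing.
import hashlib

def cell(key, d, l):
    """Coordinates of key in the d-dimensional hypercube of side l."""
    h = hashlib.sha256(key.encode()).digest()
    return tuple(h[i] % l for i in range(d))

d, l = 3, 64
store = {cell(k, d, l): v for k, v in [("alpha", 1), ("beta", 2)]}
before = dict(store)              # coordinates before the update

store[cell("gamma", d, l)] = 3    # O(1) insert: no reordering, no reindexing

# Existing entries kept their coordinates, and the new one is queryable.
assert all(store[pos] == val for pos, val in before.items())
assert store[cell("gamma", d, l)] == 3
```

This is the property that lets the framework scale the number of participating servers and the per-round data volume without rebuilding any server's hypercube.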

4.5 Deep Learning Model Training Phase


The aggregated data on the proxy server is used to train three models: ResNet18,
ResNet34, and ResNet50. The Adam optimizer is used along with the cross-
entropy loss function for training.

In each step, the model processes one batch of the training set through three
stages: forward propagation, backpropagation, and model parameter updates.
Forward propagation is used to calculate the cross-entropy loss function, as
shown in Eq. (2):

L = − Σ_{j=1}^{C} yj log(ŷj)                    (2)

where L represents the cross-entropy loss function, C is the number of classes,
yj is the probability that the true label is class j (if the sample belongs to
class j, then yj = 1; otherwise, yj = 0), and ŷj is the predicted probability by
the model that the label is class j.
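A worked instance of Eq. (2) for a single sample with C = 3 classes: with a one-hot label the sum collapses to −log ŷj for the true class j. The probability values below are made up for illustration.

```python
# Numeric instance of the cross-entropy loss of Eq. (2).
import math

def cross_entropy(y, y_hat):
    """L = -sum_j y_j * log(y_hat_j) for one sample."""
    return -sum(yj * math.log(pj) for yj, pj in zip(y, y_hat))

y     = [0.0, 1.0, 0.0]           # ground-truth one-hot label (class j = 1)
y_hat = [0.2, 0.7, 0.1]           # model's predicted probabilities

loss = cross_entropy(y, y_hat)
print(round(loss, 4))             # 0.3567, i.e. -ln(0.7)
```

The loss is 0 only when the model assigns probability 1 to the true class, and it grows without bound as that probability approaches 0, which is what drives the convergence behavior reported in Sect. 5.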
To ensure security during transmission, the trained model ω is encrypted
using a symmetric encryption key before being sent to the user. The user decrypts
the model using the same symmetric key. This method allows users without
computational resources to complete deep learning model training via the proxy
server.

5 Performance and Evaluation


5.1 Experimental Platform

The experiment was conducted on a Windows platform equipped with Nvidia


RTX 4060 GPU, i5-12490F CPU, and 32 GB of RAM. The neural network
training and prediction algorithms were implemented using PyTorch 2.0.0.

5.2 Fashion-MNIST Dataset

The dataset used in this experiment is the classic Fashion-MNIST dataset, which
includes images of 10 categories such as t-shirt and trouser. Each image sample
consists of 784 pixels, with each pixel occupying 1 byte. Additionally, the sample
labels are quantized and occupy 1 byte, making a total of 785 bytes per sam-
ple. The training dataset contains 6000 samples per category, while the testing
dataset contains 1000 samples per category, making a total of 10 categories, with
samples from each category as shown in Fig. 4. For this experiment, the total
number of servers is set to N = 6, so the training set is divided into 6 equal
parts, each stored on one server.

Fig. 3. Validation Loss and Validation Accuracy across Epochs

5.3 Deep Learning Experiment

Assuming that the keywords requested by the user are present in the
servers, the user randomly selects 7000 entries from each server, for a total of
42000 entries gathered at the proxy server for training. Afterward, the trained
model is tested on the testing dataset.

Table 1. Performance Metrics of Different ResNet Models

Model F1 Score Accuracy Precision Recall


ResNet18 0.9137 0.9119 0.9206 0.9119
ResNet34 0.9117 0.9121 0.9138 0.9121
ResNet50 0.9248 0.9243 0.9262 0.9243

Training on the images retrieved by PIR at the proxy server yields the results
shown in Table 1 and Fig. 3. The loss functions for ResNet18, ResNet34, and
ResNet50 all progressively decreased to convergence on the validation set,
achieving good detection accuracy. Among these, ResNet50 exhibited the
highest detection accuracy.

Fig. 4. Fashion-MNIST Dataset Class Samples

5.4 Performance of HashPIR

The experimental results in Table 2 provide the performance metrics of the Hash-
PIR system across different dimensions d and data sizes n, including storage
ratio, user CPU costs, server CPU costs, and network costs. Query times are
consistently low across all dimensions, typically around 2–8 ms, indicating effi-
cient query processing. Extraction times increase slightly with dimension but
remain manageable, reaching up to 19 ms for d = 4, which is still within
acceptable limits for most users. Setup times are stable and relatively low,
ranging from 32.62 to 51.32 s. Answer times increase with higher dimensions
but remain feasible, with the highest being 53.72 s for d = 4 and n = 2^16.
This indicates that the system can handle increased computational demands
effectively. Network costs for queries are stable and low, generally around
90–181 KB. Answer network costs increase with higher dimensions but are
still acceptable.

Table 2. Performance Metrics for Different Dimensions and Data Sizes

Dims d and Length l           2, 512          3, 64           4, 24
Database Size n               2^15    2^16    2^15    2^16    2^15    2^16
Storage Ratio                 0.96    0.88    0.94    0.89    0.52    0.35
User CPU Costs (ms)
  Query                       2       2       3       4       8       6
  Extract                     1       1       6       7       19      19
  Serialized query            5       3       8       5       9       8
Server CPU Costs
  Setup (sec)                 32.62   32.67   32.64   34.27   48.96   51.32
  Expand (ms)                 625     598     130     128     81      66
  Answer (sec)                14.27   14.47   16.01   16.13   50.02   53.72
  Query Deserialization (ms)  1       0.89    1.69    1.49    2.66    2
  Number of Reply Ciphertexts 4       4       16      16      64      64
Network Costs (KB)
  Query                       90.70   90.73   136.15  136.41  181.53  181.31
  Answer                      181.02  181.08  724.64  724.40  2897.74 2898.52

6 Conclusion

We propose an innovative user-proxy server-multi servers framework that com-


bines hypercube data indexing and homomorphic encryption techniques to
achieve secure and efficient distributed model training. This framework addresses
two key challenges in multi-server environments: precise control over the scope of
training data and protection of users’ training intentions. By allowing the user to
specify and query target data subsets without directly accessing raw data, this
method effectively avoids interference from irrelevant data on model performance
and ensures the confidentiality of data distribution on the server side. Moreover,
the introduction of a proxy server centralizes the computational process, alle-
viating the burden on users in terms of computational resources. Experimental
results demonstrate that this framework has significant advantages in reducing
communication overhead, improving query efficiency, and enhancing scalability.
It is particularly suitable for distributed data application scenarios that empha-
size privacy protection. Future research could further explore the introduction of
database encoding technologies to improve data storage efficiency in large-scale
scenarios and optimize the allocation of computational resources.

Acknowledgments. This work is supported by the National Natural Science Foun-


dation of China (Nos. U21A20465, 62402448), the Fundamental Research Funds of
Zhejiang Sci-Tech University under Grants Nos. 22222266-Y, 24222238-Y, the Pro-
gram for Leading Innovative Research Team of Zhejiang Province (No. 2023R01001),
the Zhejiang Provincial Natural Science Foundation of China (No. LQ24F020009) and
the “Pioneer” and “Leading Goose” R&D Program of Zhejiang (No. 2023C01119).

References
1. Ahmad, I., Sarker, L., Agrawal, D., El Abbadi, A., Gupta, T.: Coeus: a system for
oblivious document ranking and retrieval. In: Proceedings of the ACM SIGOPS
28th Symposium on Operating Systems Principles, pp. 672–690 (2021)
2. Ali, A., et al.: Communication-computation trade-offs in PIR. In: 30th
USENIX Security Symposium (USENIX Security 2021), pp. 1811–1828 (2021)
3. Angel, S., Chen, H., Laine, K., Setty, S.: PIR with compressed queries and amor-
tized query processing. In: 2018 IEEE Symposium on Security and Privacy (SP),
pp. 962–979. IEEE (2018)
4. Brakerski, Z.: Fully homomorphic encryption without modulus switching from clas-
sical GapSVP. In: Annual Cryptology Conference, pp. 868–886. Springer (2012)

5. Celi, S., Davidson, A.: Call me by my name: simple, practical private informa-
tion retrieval for keyword queries. In: Proceedings of the 2024 on ACM SIGSAC
Conference on Computer and Communications Security, pp. 4107–4121 (2024)
6. Chen, H., Zhu, T., Liu, C., Yu, S., Zhou, W.: High-frequency matters: attack and
defense for image-processing model watermarking. IEEE Trans. Serv. Comput. 17,
1565–1579 (2024)
7. Chor, B., Kushilevitz, E., Goldreich, O., Sudan, M.: Private information retrieval.
J. ACM (JACM) 45(6), 965–981 (1998)
8. Fan, J., Vercauteren, F.: Somewhat practical fully homomorphic encryption. Cryp-
tology ePrint Archive (2012)
9. Liu, J., Li, J., Wu, D., Ren, K.: PIRANA: faster multi-query PIR via constant-
weight codes. In: 2024 IEEE Symposium on Security and Privacy (SP), pp. 4315–
4330. IEEE (2024)
10. Mahdavi, R.A., Kerschbaum, F.: Constant-weight PIR: single-round keyword
PIR via constant-weight equality operators. In: 31st USENIX Security Sympo-
sium (USENIX Security 22), pp. 1723–1740 (2022)
11. Melchor, C.A., Barrier, J., Fousse, L., Killijian, M.O.: XPIR: private information
retrieval for everyone. Proc. Priv. Enhanc. Technol. 2, 155–174 (2016)
12. Pandya, A.M.: Bloom swizzlers for efficient keyword-based private information
retrieval. Master's thesis, University of Calgary, Calgary, Canada (2024).
https://prism.ucalgary.ca
13. Patel, S., Seo, J.Y., Yeo, K.: Don't be dense: efficient keyword PIR for sparse
databases. In: 32nd USENIX Security Symposium (USENIX Security 2023), pp.
3853–3870 (2023)
14. Tong, Q., Li, X., Miao, Y., Wang, Y., Liu, X., Deng, R.H.: Beyond result verifi-
cation: efficient privacy-preserving spatial keyword query with suppressed leakage.
IEEE Trans. Inf. Forensics Secur. 19, 2746–2760 (2024)
15. Tovey, E., Weiss, J., Gilad, Y.: Distributed PIR: scaling private messaging via
the users’ machines. In: Proceedings of the 2024 on ACM SIGSAC Conference on
Computer and Communications Security, pp. 1967–1981 (2024)
16. Vadapalli, A., Bayatbabolghani, F., Henry, R.: You may also like... privacy: rec-
ommendation systems meet PIR. Proc. Priv. Enhanc. Technol. 2021(4), 30–53
(2021)
17. Vora, A.V., Hegde, S.: Keyword-based private searching on cloud data along with
keyword association and dissociation using cuckoo filter. Int. J. Inf. Secur. 18,
305–319 (2019)
18. Xie, Y., Zhang, C., Wei, L., Niu, Y., Wang, F.: Private transaction retrieval for
lightweight bitcoin client. In: 2019 IEEE International Conference on Blockchain
and Cryptocurrency (ICBC), pp. 440–446. IEEE (2019)
19. Yang, H., et al.: A flexible and verifiable keyword PIR scheme for cloud-edge-
terminal collaboration in AIoT. IEEE Internet Things J. 11, 18111–18122 (2024)
20. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving and
personalized federated learning framework. IEEE Internet Things J. 11, 19380–
19393 (2024)
21. Zhang, L., Zhu, T., Xiong, P., Zhou, W.: The price of unlearning: identifying
unlearning risk in edge computing. ACM Trans. Multimed. Comput. Commun.
Appl. 1–23 (2024)
Automatic Software Vulnerability Detection in Binary Code
Shigang Liu1,2(B), Lin Li2, Xinbo Ban2, Chao Chen3, Jun Zhang2,
Seyit Camtepe1, and Yang Xiang2

1 CSIRO's Data61, Clayton, Australia
  {shigang.liu,seyit.camtepe}@data61.csiro.au
2 Swinburne University of Technology, Melbourne, Australia
  {shigangliu,linli,junzhang,yxiang}@swin.edu.au
3 Royal Melbourne Institute of Technology, Melbourne, Australia
  [email protected]
Abstract. Cybersecurity is critical in today's digital world, where the severity of threats from software vulnerabilities grows significantly each year. Many techniques have been developed to analyze vulnerabilities in source code. However, source code is not always available (for example, most industry software is closed-source). As a result, analyzing vulnerabilities in binary code becomes necessary and more challenging. This paper presents a novel approach called BiVulD for detecting vulnerabilities at the binary level. BiVulD has three phases: generating assembly language instructions, learning good embeddings, and building a prediction model. First, we create a database of vulnerable binaries using CVE and NVD. Next, we propose using codeBERT to obtain good embeddings. Finally, we apply a bidirectional LSTM on top of codeBERT to build the predictive model. To demonstrate BiVulD's effectiveness, we compared it with several baselines, including source code-based, binary code-based, and machine learning-based techniques on real-world projects. The experimental results show that BiVulD outperforms the baselines and can detect more vulnerabilities. For instance, BiVulD achieves at least 20% improvement in Precision, Recall, and F-measure. We believe this work will serve as a foundation for future research in vulnerability detection using only binary code.

Keywords: Software Security · Binary Code · Deep Learning ·


Machine Learning

1 Introduction
Nowadays, computer software components have become a crucial part of the
modern world, from simple applications on smart devices to complex enterprise
software and critical embedded systems. According to James et al. [1], approx-
imately 80% of cyber-attacks are motivated by financial gain, leading to losses
across numerous industries, including autonomous technologies, the Internet of
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 148–166, 2025.
https://doi.org/10.1007/978-981-96-4566-4_11
Automatic Software Vulnerability Detection in Binary Code 149

Things (IoT), artificial intelligence (AI), and more. For example, vulnerabili-
ties in widely-used browser plugins like Adobe Flash Player and Oracle Java
[2], as well as open-source components such as Heartbleed (Codenomicon 2014)
[3], and ShellShock (Symantec Security Response 2014) [4], have posed signifi-
cant security threats to thousands of companies and their customers worldwide.
The number and severity of software vulnerabilities continue to rise, making it
essential to identify and address them before deploying any software [5].
Binary code analysis is a common tool in computer security, especially when
the source code is unavailable. For example, vendors may want to verify the com-
pliance of third-party software, such as drivers and plugins. In the IoT domain,
most devices are built using public drivers, and many of these applications and
libraries lack source code. This creates a significant challenge for assessing the
security of IoT devices when deploying them on a large scale [6]. Binary code
analysis is crucial in specialized fields, including detecting software vulnera-
bilities, identifying malware, and recognizing code clones at the binary level.
Research shows that interest in binary analysis remains strong, with no signs of
decline [7]. Binary analysis is widely applicable and appealing, not only because
source code is often inconvenient to use, but also because it helps avoid mis-
matches between source code and binary code [8]. As a result, there continues
to be strong demand for software vulnerability analysis in binary code.
Recently, machine learning, especially Deep Learning (DL), has been widely
used to solve software security problems [9]. However, previous works such as [10]
and [11] are mainly focused on detecting vulnerabilities in source code. These
methods first use word2vec to convert the source code into feature vectors, then
apply Deep Learning to learn the high-level features, which are used to train the
Deep Learning models. Although some works perform vulnerability detection on
binary code, such as VDiscover [12], Instruction2vec [13], Maximal Divergence
Sequential Auto-Encoder [14], A-BiLSTM [15], and Deep Cost-sensitive Kernel
Machine [16], the experimental results from this study show that these techniques
often lead to a high false negative rate (over 65%). Therefore, we are motivated to
develop an efficient automatic approach for vulnerability detection.
In this paper, we aim to use state-of-the-art techniques, specifically Code-
BERT [17] and attention-based Deep Learning [18], for vulnerability search in
binary code. At a high level, we first disassemble the binary software, then iden-
tify a set of vulnerable and non-vulnerable functions as ground truth. This
ground truth is fed into CodeBERT to generate effective embeddings, which
are then passed into attention-based neural networks to train the prediction
model. The model is used to detect vulnerabilities in real-world scenarios. We
conducted a series of comparative experiments to demonstrate the effectiveness
of our proposed BiVulD system, using real-world projects from various sources.
We will release our tool under an open-source license to enable other researchers
to evaluate our proposed strategy. Our main contributions are as follows:
150 S. Liu et al.

– We constructed the ground truth database according to the National Vulnerability
Database (NVD)1 and Common Vulnerabilities and Exposures (CVE)2. We
published the binary datasets for other researchers to use and contribute to
the binary-level vulnerability detection problem. The datasets are available
at https://github.com/wolong3385/BiVulD.
– We proposed to use codeBERT to attain good embeddings given binary code.
We first disassemble the binaries to extract binary level function instructions.
Then, we apply codeBERT to these assembly language instructions at the
functional level.
– We proposed to employ the attention mechanism for deep feature learning
given binary code, given that the vulnerability is connected to specific assem-
bly language instructions rather than the entire binary.
– We implemented a prototype of our proposed strategy called BiVulD. We car-
ried out a series of empirical studies based on datasets from several sources.
We compared our proposed approach with different baselines. Our experi-
ments show that our Deep Learning-based technique is able to learn rich
representations based on the assembly language instructions.

2 Related Work

Recently, machine learning techniques have been applied to detect software vul-
nerabilities, typically involving two main tasks: feature selection and model train-
ing. The approach to feature selection varies from study to study. For example,
DiscovRE [8] suggested using the control flow graph (CFG) to calculate func-
tion similarity. VDiscover [12] proposed to use machine learning for software
vulnerability detection based on lightweight static and dynamic features.
Meanwhile, Deep Learning has been widely applied in software vulnerability
detection. Lee et al. [13] employ Deep Learning for software bug detection in
assembly code. Chen et al. [19] proposed Atomic for vulnerability discovery in
cellular networks (e.g. LTE and 5G). Atomic first employed natural-language
processing technique to analyse LTE documents to recover a set of hazard indi-
cators, then it used Textual Entailment to predict whether there were potential
security flaws in cellular networks. Steenhoek et al. [20] assess nine state-of-the-
art DL models on widely used vulnerability datasets, exploring various research
questions to improve model understanding and performance. They highlight that
vulnerability detection tends to be more effective when tailored to a specific
vulnerability type, outperforming models designed to cover all vulnerabilities. Yang et al.
[21] propose a novel Deep Learning-enhancement architecture, implemented in
Asteria-Pro, which significantly improves bug detection performance, with 96.9%
faster calculation time and enhanced precision. Jiang et al. [22] argued that
while third-party libraries enhance productivity in software development, they

1 https://nvd.nist.gov/
2 https://cve.mitre.org/

can introduce security risks, necessitating effective software composition analysis
(SCA) techniques. They proposed BinaryAI, an LLM-based binary-to-source
SCA technique that utilizes a transformer model for intelligent function-level
matching, improving the detection precision and recall of reused libraries.
Moreover, attention-based DL [15] has been employed for software vulnerability
detection; however, there is still much room for improvement in classification
performance. Li et al. [23] proposed PalmTree, which is a pre-trained
assembly language model. Zuo et al. [24] proposed INNEREYE, which is a cross-
lingual DL approach for binary code similarity detection. Li et al. [25] developed
MirChecker, which is a static analysis tool for vulnerability bug search in Rust
programs based on Rust's Mid-level Intermediate Representation. Şahin and
Abualigah [26] developed a Deep Learning-based scheme for software vulnerabil-
ity detection based on the genetic feature selection using the clustering theory.
Later on, Jiang et al. [27] presented IFAttn which first utilized the attention
mechanism to convert the interpretable basic features as the adaptive semantic
features, and then developed an adaptive binary code similarity analysis scheme
based on the adaptive semantic features. Wang et al. [28] propose BinVulDet, a
binary vulnerability detection tool that combines decompilation and program
slicing with a BiLSTM-attention model, significantly outperforming existing methods. In
addition, Pei et al. [29] proposed TREX, a transfer-learning-based framework
that learns execution semantics from micro-traces to match semantically similar
functions across different architectures, optimizations, and obfuscations.

Fig. 1. The framework of BiVulD

3 Proposed BiVulD
From a high-level perspective, BiVulD leverages known vulnerabilities (vulnera-
ble functions) to build ground-truth data, which is then used to train the predic-
tion model using these binary vulnerabilities. Figure 1 provides an overview of

the proposed BiVulD system. BiVulD operates in four phases: database prepa-
ration, embeddings learning, prediction model training, and detection.
For database preparation, this work maps CVEs and NVD to vulnerability-
contributing commits in six real-world projects: Asterisk, FFmpeg, LibPNG,
LibTIFF, Pidgin, and VLC. The datasets for CWE119 and CWE399 [10] are
also considered. These projects are included not only because they are widely
used in everyday life, but also because they are significant in the research com-
munity [11,30].
For embedding learning, we use the state-of-the-art codeBERT to obtain
high-quality embeddings. These embeddings are then input into an attention-
based BiLSTM model to train the prediction model. Finally, the prediction
model will be used to detect potential vulnerabilities in any test cases. Dur-
ing the detection phase, given a test case, tools such as IDA PRO and objdump
(or other reverse engineering tools like Capstone) can be used to disassemble
binaries and obtain assembly-level instructions. Functions will then be extracted
without domain knowledge, as we have no vulnerability information about the
test case. Each function will be input into the model, which will predict whether
it is vulnerable or not.
Note that some distinctive design principles are included in BiVulD, which
are:

1. BiVulD can be applied to binary features directly. Different from previous
techniques, BiVulD can be applied to binary code directly rather than source code.
2. Automation. Given binary code, BiVulD uses the function level binary code
as features. We note that previously developed algorithms use different kinds
of features and suffer in performance, mainly owing to that previous machine
learning techniques rely on a good set of features for success. However, it is
hard to identify good features for different binary code and different types of
vulnerabilities.

BiVulD is composed of several components. We present these components
and give a high-level overview here. We outline the detailed description of these
components in the following subsections.

– Binary Code. To show the effectiveness of our proposed BiVulD system, we
source the original binary code from different sources.
– Binary Function Extraction. We need to extract functions from the
assembly language instructions. This is easy to obtain if the symbol table
is present. Also, we need to create ground-truth to train a prediction model.
We do this using information collected from NVD and CVE. Tool such as
objdump is used to extract assembly-level functions from binary executable
inputs.
– Good embeddings learning. We realize that there are high false positives
and false negatives in real-world projects, and we suspect this is because
previous studies have not been able to learn the most important features of

vulnerable functions from non-vulnerable ones. Therefore, we employ codeBERT
to attain good embeddings. Our experiments show that the proposed
scheme can learn the high feature representations, which can minimize the
distribution divergence between the vulnerable samples and non-vulnerable
samples.
– Model Building. We employ an attention model in the BiVulD system.
Our experimental results show that Deep Learning could automatically learn
useful feature representations given the ground truth.

3.1 Binary Code


In this work, we create the ground truth database by mapping CVE and
NVD entries to vulnerability-contributing GitHub commits for a given project.
Specifically, we collect the vulnerability data at source code level from the NVD
and the CVE. The vulnerabilities can then be identified by their CVE IDs,
because each vulnerability has a unique CVE ID according to NVD and CVE.
We extract the functions from the source code and manually label a function as
vulnerable if it contains at least one vulnerability. Any function related to a
vulnerable function is ignored in order to prepare a clean dataset, and only
functions that contain no vulnerability and have no connection with the
vulnerable functions are labeled as good ones. Finally, we compile these
functions into binary code for further consideration. We use the GNU GCC
(version 4.8.2) compiler, which supports the Intel x86, x64 and ARM
architectures. We compile the source under Windows GCC (MinGW) and Linux
GCC at optimization levels -O0, -O1, -O2, and -O3. The Windows operating
system is Microsoft Windows 10 Enterprise with an Intel Xeon W-2133 CPU @
3.60 GHz and 32 GB of memory. The Linux operating system is Ubuntu 16.04.3
with sixteen vCPUs and 32 GB of memory.
For open source projects, the source code should meet several conditions: 1)
it should be compilable on Windows and Linux operating systems; 2) it should
provide information regarding the vulnerabilities, which helps us identify
which functions are vulnerable; and 3) it should be publicly used in the community.
We collected six open source projects [31]: the Asterisk project3, LibTIFF
project4, FFmpeg project5, Pidgin project6, LibPNG project7, and VLC project8.
Asterisk, sponsored by Digium, is a free and open source communications
framework. LibTIFF is a widely used library for storing image data. FFmpeg is
a free and open-source tool that can change video format or quality. Pidgin is
an open source chat program that enables users to log into accounts on multiple
chat networks at the same time. LibPNG is an open source project that supports
3 https://www.asterisk.org/
4 http://www.libtiff.org/
5 https://www.ffmpeg.org/
6 https://pidginprojects.com/
7 http://www.libpng.org/pub/png/libpng.html
8 https://www.videolan.org/

Table 1. Datasets Information

Data      # Vul.  # Not Vul.  # Total
CWE119    10712   9878        20590
CWE399    1530    2164        3694
Asterisk  84      19476       19560
FFmpeg    298     8015        8313
LibTIFF   98      678         776
LibPNG    28      492         520
Pidgin    38      13338       13376
VLC       24      10330       10345

almost all PNG features. VLC media player is an open source tool which supports
many audio and video file formats, including DVDs, Audio CDs, and so on.
For each of the projects, we manually map the vulnerabilities from CVE and
NVD, which means we can tell where each vulnerability comes from. Take CVE
ID CVE-2018-12900 for example: we can see this is a buffer overflow by checking
the vulnerability categorization based on the information from NVD. The
analysis reveals that the issue is a heap-based buffer overflow in the
cpSeparateBufToContigBuf function, which is located in the tiffcp.c file. Therefore, we
treat the cpSeparateBufToContigBuf function as a buffer overflow vulnerability.
Moreover, the real-world projects include multiple types of vulnerabilities such
as buffer error (CWE119), resource management errors (CWE399), information
leak (CWE200), out-of-bound read (CWE125), input validation (CWE20), null
pointer dereference (CWE476), numeric errors (CWE189), out-of-bounds write
(CWE787), and insufficient information (NVD-CWE-noinfo). However, buffer
overflow takes the highest proportion of the vulnerabilities found. For example,
29 out of 71 vulnerabilities are buffer overflow in FFmpeg project (40.8%), and
9 out of 21 vulnerabilities are buffer overflow in LibTIFF project (42.9%). We
use 32-bit and 64-bit versions of GCC to compile the real-world projects under
Windows and Linux. Data information can be seen from Table 1.
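The buffer-overflow proportions quoted above can be checked directly:

```python
# Buffer-overflow share of the mapped vulnerabilities, using the counts
# reported in the text (29 of 71 for FFmpeg, 9 of 21 for LibTIFF).
ffmpeg_share = 29 / 71 * 100
libtiff_share = 9 / 21 * 100
print(round(ffmpeg_share, 1), round(libtiff_share, 1))  # 40.8 42.9
```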
To examine the effectiveness of our proposed scheme, the CWE119 and
CWE399 datasets are also considered in this work. CWE119 only contains buffer
overflow errors, and CWE399 only contains resource management errors. The
source code of the CWE119 and CWE399 datasets was used in [10]. In the orig-
inal CWE119 dataset, there are 10440 vulnerable functions and 29313 non-
vulnerable functions. We compile them with MINGW32 and MINGW64 on the
Windows and Linux system. However, not every function is compilable because
some libraries are missing. As a result, we were only able to successfully com-
pile 20,590 samples, which included 10,712 vulnerable samples and 9,878 non-
vulnerable samples of CWE119. In the case of CWE399, we could only compile
1,530 vulnerable samples along with 2,164 non-vulnerable samples.

3.2 Binary Function Extraction (Processed Datasets)

Since the BiVulD system is designed to detect binary vulnerabilities at the func-
tion level, we need to extract functions from assembly language instructions to
achieve this goal. We start by compiling the source code into binary code. After
compilation, we obtain 32-bit and 64-bit executable files for both Windows and
Linux systems (i.e., PE and ELF files) from six real-world projects, including
the CWE119 and CWE399 datasets. Next, we use objdump to display the binary
instructions in assembly language form. Objdump is a command-line program
used to extract information from object files on Unix-like systems. Finally, we
extract functions (i.e., sequences of assembly language instructions) from the
disassembled output.
For real-world projects, we compile each project under Windows and Linux
using 32-bit and 64-bit versions of GCC. The processed data samples are then
combined to form a new dataset. For example, LibTIFF contains 98 vulnerable
samples and 678 non-vulnerable samples from all the samples compiled under
Windows and Linux with 32-bit and 64-bit versions of the GCC compiler. It is
worth noting that we only present three real-world projects in the experimental
results section, as the results for the other projects are similar to those for
Asterisk, FFmpeg, and LibTIFF.
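The extraction step described above can be sketched as a small parser over `objdump -d` output. The sketch below targets GNU objdump's default AT&T-syntax layout on x86; real disassembly has more line shapes (section banners, data, padding), so treat this as illustrative rather than a complete parser:

```python
import re
from collections import OrderedDict

# "0000000000001139 <main>:" opens a function in `objdump -d` output.
FUNC_HEADER = re.compile(r"^[0-9a-fA-F]+ <(?P<name>[^>]+)>:$")
# "    1139:\t55 \tpush   %rbp" is one instruction line (address, bytes, mnemonic).
INSN_LINE = re.compile(r"^\s+[0-9a-fA-F]+:\s+(?:[0-9a-f]{2}\s+)+(?P<insn>\S.*)?$")

def extract_functions(disassembly):
    """Split `objdump -d` text into per-function assembly instruction lists."""
    functions = OrderedDict()
    current = None
    for line in disassembly.splitlines():
        header = FUNC_HEADER.match(line)
        if header:
            current = header.group("name")
            functions[current] = []
            continue
        if current is None:
            continue
        insn = INSN_LINE.match(line)
        if insn and insn.group("insn"):
            # Collapse the column padding between mnemonic and operands.
            functions[current].append(" ".join(insn.group("insn").split()))
    return functions
```

Each value in the returned mapping is the sequence of assembly language instructions for one function, ready to be tokenized in the next phase.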

3.3 Good Embeddings Learning

We use codeBERT [17] for embedding learning. codeBERT is a transformer-based
neural architecture, meaning it uses multi-layer bidirectional Transformers
to create effective embedding representations [17]. We integrate codeBERT into
the BiVulD system for two reasons: 1) codeBERT has been proven successful in
detecting software faults and vulnerabilities [17,32]; 2) codeBERT can produce
high-quality assembly embeddings, unlike the word2vector method (see model
explanation in Sect. 3.5).
In this work, we first use a fine-tuned codeBERT to obtain good embed-
dings from the training data. To fine-tune codeBERT, we split each dataset into
training, evaluation, and test data. Then, 60% of the samples from the six
real-world projects, the CWE119 dataset, and the CWE399 dataset were used
to retrain and optimize the original codeBERT model [17]. After obtaining the
good embeddings, we feed them into a BiLSTM Deep Learning model.
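A minimal sketch of this embedding step using the Hugging Face `transformers` API. The `microsoft/codebert-base` checkpoint name, the whitespace normalisation, and mean-pooling over the last hidden state are our assumptions for illustration; the paper fine-tunes codeBERT on its own 60% training split before extracting embeddings. The heavy imports are deferred so the text-preparation helper runs on its own:

```python
def tokens_to_text(asm_instructions):
    """Join a function's assembly instructions into one normalised string.

    The exact preprocessing is an assumption; the paper does not spell it out.
    """
    return " ".join(" ".join(insn.split()) for insn in asm_instructions)

def embed_function(asm_instructions, model_name="microsoft/codebert-base"):
    """Return a mean-pooled CodeBERT embedding for one assembly function.

    torch/transformers are imported lazily so tokens_to_text stays usable
    without them installed.
    """
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(tokens_to_text(asm_instructions),
                       return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state   # (1, seq_len, 768)
    return hidden.mean(dim=1).squeeze(0)             # (768,) per function
```

In practice one would batch the functions and cache the fine-tuned model rather than reload it per call.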

3.4 Model Building

We utilize a bidirectional LSTM with attentive position embeddings for
vulnerability detection. The reasons are three-fold: 1) Deep Learning based systems
achieve state-of-the-art performance on vulnerability detection [10,33]; 2) bidi-
rectional LSTMs can cope with token sequences of variable length; 3) attentive
position embeddings automatically identify the importance of each assembly
language instruction with respect to the target task.

Fig. 2. Figures (a) and (c) display the two-dimensional features of LibTIFF and
CWE119, and Figures (b) and (d) visualise the features of LibTIFF and CWE119
using t-Distributed Stochastic Neighbor Embedding (t-SNE). Note: label '0' means
the non-vulnerable class, label '1' means the vulnerable class.

At a high level, BiVulD takes as input the token sequences converted from
assembly functions. On the one hand, assembly instructions often differ despite
having similar semantics. On the other hand, similar data types with dissimilar
byte tokens should have similar representations, such as mov DWORD PTR [ebp-
0xc],0x0, and mov DWORD PTR [ebp-0x10],0x0. This phenomenon is similar to
constructing word representations in natural language processing. In light of this,
we represent a token embedding as a continuous vector e ∈ R^128. As a result,
we have |V| token embeddings, where V denotes the vocabulary.
The semantics of a function do not rely solely on the meaning of individual
tokens, but also on their context. Thus, we feed the embedding sequences into
a BiLSTM to generate more abstract representations by incorporating assembly
language instruction information. An LSTM is a recurrent neural network taking
the following form:
    h_i, c_i = LSTM(x_i, h_{i-1}, c_{i-1})                (1)

where x_i is the input at position i, and h_i ∈ R^64 and c_i ∈ R^64 are the hidden
state and cell state of the LSTM at position i, respectively. One LSTM is run in
the forward direction and another in the backward direction to form the BiLSTM.
We concatenate the hidden states generated in the two directions, →h_i and ←h_i,
as the output ĥ_i at position i.
To improve vulnerability detection, it is important to recognize that not all
tokens have the same contribution to the prediction. Therefore, we assess the
importance of each token feature at every position in order to identify the signifi-
cant tokens for vulnerability detection. Since the same token feature takes differ-
ent values at different token positions, they receive different importance scores. In
this way, each token at each position will get a vector of importance scores for its
features. We assume that important tokens should be represented by important
features. Thus we take the average of the feature importance scores for each
token as its final importance score. Since those scores are position and token
dependent, we refer to them as attentive position embeddings. The attentive
position embeddings are concatenated to the hidden representations generated
by the BiLSTM as the final representation of each token. All token representa-
tions are fed into a binary classifier, which is a feedforward neural network with
two hidden layers, to determine if the input function is vulnerable or not.
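Equation (1) and the bidirectional concatenation can be made concrete with a tiny pure-Python cell. Biases are omitted and the dimensions are far smaller than the paper's 64; the importance-averaging helper at the end reflects our reading of the attentive position embeddings described above:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(row, vec):
    return sum(r * v for r, v in zip(row, vec))

def lstm_step(x, h_prev, c_prev, W):
    """One step of Eq. (1): h_i, c_i = LSTM(x_i, h_{i-1}, c_{i-1}).

    W maps gate name ('i', 'f', 'o', 'g') to a weight matrix applied to
    the concatenation [x_i, h_{i-1}]; biases are omitted for brevity.
    """
    z = list(x) + list(h_prev)
    i = [sigmoid(dot(row, z)) for row in W["i"]]    # input gate
    f = [sigmoid(dot(row, z)) for row in W["f"]]    # forget gate
    o = [sigmoid(dot(row, z)) for row in W["o"]]    # output gate
    g = [math.tanh(dot(row, z)) for row in W["g"]]  # candidate cell
    c = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c_prev, i, g)]
    h = [oj * math.tanh(cj) for oj, cj in zip(o, c)]
    return h, c

def bilstm(xs, W, hidden):
    """Run the cell in both directions; ĥ_i concatenates →h_i and ←h_i."""
    def run(seq):
        h, c, out = [0.0] * hidden, [0.0] * hidden, []
        for x in seq:
            h, c = lstm_step(x, h, c, W)
            out.append(h)
        return out
    fwd = run(xs)
    bwd = run(xs[::-1])[::-1]
    return [hf + hb for hf, hb in zip(fwd, bwd)]

def attentive_scores(feature_importance):
    """Average per-feature importance into one score per token position."""
    return [sum(row) / len(row) for row in feature_importance]
```

The production model replaces these loops with a Keras BiLSTM and learned attention weights; the sketch only fixes the data flow.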

3.5 Feasibility Analysis: High-Representations Display

BiVulD is an automatic deep-learning-based software vulnerability detection
system. Therefore, it is important to understand the data distribution learned by
this system. In this section, we visualize the high-level representations of the
LibTIFF and CWE119 datasets. LibTIFF is a real-world project that contains
many types of vulnerabilities, while CWE119 only includes buffer overflow vul-
nerabilities.
Figures 2a and 2c display the two-dimensional features of LibTIFF and
CWE119, and Figs. 2b and 2d visualize these features using t-Distributed
Stochastic Neighbor Embedding (t-SNE) [34]. Figure 2a shows that the high
representations of LibTIFF are separable from each other, especially when we
set the threshold at 0.8. Most of the vulnerable samples (labeled as 1) are isolated
from the non-vulnerable samples (labeled as 0). The t-SNE plot of the LibTIFF
features in Fig. 2b further emphasizes that most of the learned high representa-
tions are linearly separable. We observe false positives (non-vulnerable predicted
as vulnerable) and false negatives (vulnerable predicted as non-vulnerable). This
is because LibTIFF contains multiple types of vulnerabilities, while the train-
ing datasets mainly include buffer errors and resource management errors. As
a result, there may be mismatches for certain vulnerabilities, such as “divide
by zero (CWE-369)” and “security features (CWE-254)”, since the patterns of
these vulnerabilities are different. We believe this can be addressed with enough
training samples. Figures 2c and 2d show that the high-level representations of
CWE119 are linearly separable. From Fig. 2d, we can see that vulnerable and
non-vulnerable samples are isolated from each other when we set a threshold
of about 0.52. This is because CWE119 represents a single type of buffer error
vulnerability, which allows the BiVulD system to learn the high-level feature rep-
resentations and reduce the data distribution divergence between vulnerable and
non-vulnerable samples. The experimental results also highlight the effectiveness
of our proposed methods for single types of vulnerabilities.

4 Experiments

In this section, we mainly focus on the following three research questions:

– How effective is BiVulD compared to other machine learning-based
approaches specifically developed for binary vulnerability detection?
– Is BiVulD practical and effective in detecting real software vulnerabilities?
– How effective is BiVulD compared to other machine learning tools used for
vulnerable function search?

4.1 Baselines

As far as we know, there are only a few research studies using Deep Learning for
vulnerability detection based on binary code. The closest works to ours are the

Text-CNN-based approach [13], VDiscover [12], DiscovRE [8], and A-BiLSTM
[15]. However, since DiscovRE was developed for cross-architecture bug discov-
ery, we only compare our work with the Text-CNN-based approach, VDiscover,
and A-BiLSTM.
There are several other machine learning tools available for function vulner-
ability search and detection. For example, Xu et al. [35] developed a tool called
Gemini, which uses a neural network to learn embeddings from the control flow
graph (CFG) of each binary function. Massarelli et al. [36] developed a tool called
SAFE, which uses self-attention function embeddings to measure the similarity
between two functions. Their work shows an improvement in performance com-
pared to Gemini. Since Gemini is specifically designed for cross-platform binary
code similarity detection, which is different from our scenario, we focus on SAFE
as one of the baselines for comparison, as SAFE can also be used for vulnerability
search.

4.2 Evaluation Metrics


To evaluate the effectiveness of the proposed system, four performance measures
are considered: Precision, True Positive Rate (TPR or Recall), F1-measure (F1),
and False Positive Rate (FPR). Since the False Negative Rate (FNR) is equal to
1 - Recall, we only report the TPR (Recall). Comparatively, we expect that the
BiVulD system will achieve higher TPR values, lower FPR values, and higher
F1 values, as F1 is a balanced measure of overall classifier performance. It is
important to note that a false positive (FP) means that BiVulD incorrectly
classifies a good function as vulnerable, while a false negative (FN) occurs when
BiVulD incorrectly classifies a vulnerable function as good.
Additionally, top-k precision (denoted as P@K) is also considered to demon-
strate the usefulness of the BiVulD system. This metric is widely used in both
information retrieval systems [33] and software vulnerability detection [33]. In
this work, we report top-k to show the proportion of vulnerabilities among the
top-k retrieved functions (i.e., the most vulnerable functions).
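These metrics follow directly from the confusion matrix; a small reference implementation (ranking by predicted score for P@K is our assumption of how the retrieval is done):

```python
def confusion(y_true, y_pred):
    """Counts for the positive class: 1 = vulnerable, 0 = non-vulnerable."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    return tp, fp, fn, tn

def metrics(y_true, y_pred):
    tp, fp, fn, tn = confusion(y_true, y_pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    tpr = tp / (tp + fn) if tp + fn else 0.0   # recall; FNR = 1 - TPR
    fpr = fp / (fp + tn) if fp + tn else 0.0
    f1 = 2 * precision * tpr / (precision + tpr) if precision + tpr else 0.0
    return {"precision": precision, "tpr": tpr, "fpr": fpr, "f1": f1}

def precision_at_k(y_true, scores, k):
    """P@K: fraction of truly vulnerable functions among the k highest-scored."""
    ranked = sorted(zip(scores, y_true), reverse=True)
    return sum(t for _, t in ranked[:k]) / k
```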

4.3 Implementation Details


We implemented the deep feature model using Keras (2.0.8) with TensorFlow
(1.3.0) as the backend. To pre-train the token embedding, we used the Gensim
package (3.0.1) with default settings. To prevent overfitting in the LSTM model,
we randomly split the data into three sets: 60% for training, 20% for validation,
and 20% for testing. We trained the model on the training set, validated it on
the validation set, and tested it on the remaining test data. Due to a significant
class imbalance problem in the LibTIFF and Asterisk datasets, we applied the
random oversampling (ROS) method with an oversampling ratio of 200%, as
suggested by previous work [12]. To ensure fairness, we also applied ROS to
the baseline models. It is important to note that ROS was applied only to the
training data to optimize the decision boundary and not to the validation or test
data.
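The split-then-oversample setup above can be sketched as follows; interpreting the 200% rate as "grow the minority class in the training portion to twice its original size" is our assumption about the ROS semantics:

```python
import random

def split_and_oversample(samples, labels, ratio=2.0, seed=42):
    """60/20/20 split, then random oversampling (ROS) of the minority class
    applied to the training portion only; validation/test stay untouched.
    """
    rng = random.Random(seed)
    idx = list(range(len(samples)))
    rng.shuffle(idx)
    n = len(idx)
    cut1, cut2 = int(0.6 * n), int(0.8 * n)
    train = [(samples[i], labels[i]) for i in idx[:cut1]]
    val = [(samples[i], labels[i]) for i in idx[cut1:cut2]]
    test = [(samples[i], labels[i]) for i in idx[cut2:]]

    # Duplicate randomly chosen minority-class training examples.
    minority = min((0, 1), key=lambda c: sum(1 for _, y in train if y == c))
    pool = [ex for ex in train if ex[1] == minority]
    extra = int((ratio - 1.0) * len(pool))
    train += [rng.choice(pool) for _ in range(extra)]
    rng.shuffle(train)
    return train, val, test
```

Keeping the oversampling inside the training split, as in the setup above, avoids leaking duplicated examples into the validation and test evaluations.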

Table 2. Experimental results comparing techniques based on binary level detection.

Data     Technique            Precision  TPR    FPR    F1
CWE119   Ins2vec-TCNN         0.704      0.584  0.249  0.638
         VDiscover            0.744      0.728  0.246  0.736
         A-BiLSTM             0.758      0.843  0.271  0.798
         BiVulD               0.972      0.864  0.028  0.915
CWE399   Ins2vec-TCNN         0.554      0.761  0.283  0.641
         VDiscover            0.723      0.736  0.209  0.730
         A-BiLSTM             0.790      0.690  0.205  0.740
         BiVulD               0.958      0.857  0.028  0.904
LibTIFF  Ins2vec-TCNN         0.333      0.333  0.096  0.333
         Ins2vec-TCNN + ROS   0.448      0.464  0.096  0.456
         VDiscover            0.529      0.474  0.096  0.500
         VDiscover + ROS      0.500      0.667  0.002  0.571
         A-BiLSTM             0.595      0.786  0.090  0.677
         BiVulD               0.952      0.940  0.007  0.943
FFmpeg   Ins2vec-TCNN         0.365      0.464  0.402  0.135
         Ins2vec-TCNN + ROS   0.544      0.402  0.302  0.188
         VDiscover            0.478      0.379  0.302  0.152
         VDiscover + ROS      0.562      0.408  0.301  0.195
         A-BiLSTM             0.705      0.252  0.134  0.210
         BiVulD               0.941      0.850  0.002  0.890
Asterisk Ins2vec-TCNN         0.833      0.217  0.000  0.345
         Ins2vec-TCNN + ROS   0.435      0.417  0.003  0.426
         VDiscover            0.107      0.348  0.014  0.163
         VDiscover + ROS      0.333      0.600  0.004  0.428
         A-BiLSTM             0.470      0.696  0.026  0.561
         BiVulD               0.921      0.800  0.001  0.837

4.4 Results

To ensure a systematic evaluation, we structured our analysis around the three research questions (RQ1–3) outlined in our study.
160 S. Liu et al.

RQ1: How effective is BiVulD compared to other machine learning-based approaches specifically developed for binary vulnerability detection? To demonstrate that the BiVulD system outperforms other machine learning-based approaches, we selected five datasets (CWE119, CWE399, LibTIFF, FFmpeg, and Asterisk) for comparative experiments.
The experimental results for BiVulD, Ins2vec-TCNN, VDiscover, and A-
BiLSTM on the CWE119, CWE399, LibTIFF, FFmpeg, and Asterisk datasets
are shown in Table 2. As seen from the table, BiVulD significantly outper-
forms the baselines across all five datasets, achieving a much higher F1-score
(91.5%) compared to Ins2vec+TCNN (63.8%), VDiscover (73.6%), and A-
BiLSTM (79.8%) on the CWE119 dataset. This is mainly due to BiVulD’s higher
precision and true positive rate (TPR) (i.e., much lower false negative rate), while
maintaining a low false positive rate (FPR) of only 2.8%, compared to 24.9% for
Ins2vec+TCNN, 24.6% for VDiscover, and 27.1% for A-BiLSTM.
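For reference, the F1-scores in Table 2 are the harmonic mean of precision and TPR; a quick check against BiVulD's CWE119 row:

```python
def f1(precision, tpr):
    # F1 is the harmonic mean of precision and recall (TPR).
    return 2 * precision * tpr / (precision + tpr)

print(round(f1(0.972, 0.864), 3))  # 0.915, matching Table 2
```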
For the other four datasets, BiVulD also achieves excellent results. For exam-
ple, in the Asterisk dataset, Ins2vec+TCNN attains a high precision of 83.3% and an FPR of 0% at the cost of a very low TPR, which results in a low F1-score of 34.5%. VDiscover achieves a
slightly better TPR (34.8%) than Ins2vec+TCNN (21.7%), but with a precision
of only 10.7%, leading to a very low F1-score of 16.3%. A-BiLSTM also yields a
low F1-score of 56.1%. In contrast, BiVulD achieves a much higher F1-score of
83.7%, with a low FPR (0.1%) and a higher TPR (80%).
Both Ins2vec+TCNN and VDiscover show improved performance when ROS
is used in the training process, as indicated by the results for Ins2vec-TCNN +
ROS and VDiscover + ROS. For instance, the F1-score of VDiscover + ROS
on the LibTIFF dataset is 7.1% higher than VDiscover alone due to its higher
TPR and lower FPR. However, the precision of VDiscover + ROS is 2.9% lower
than VDiscover, meaning that while VDiscover + ROS achieves a higher TPR,
it also generates more false positives. In comparison, BiVulD still achieves the
best results in terms of precision, TPR, and F1-score on the LibTIFF dataset.
We observe that all models perform better on CWE119, followed by LibTIFF
and Asterisk. We suspect this is due to the varying ratio of non-vulnerable and
vulnerable functions in each dataset. As shown in Table 1, the imbalance ratio
(i.e., the number of majority class samples over the number of minority class
samples) for CWE119 is nearly 1:1, while the ratios for LibTIFF and Asterisk
are about 7:1 and 244:1, respectively. This imbalance may cause the classifier
to be biased toward the non-vulnerable function class, resulting in high false
negatives. The classification performance on CWE119 is slightly better than on
CWE399. We believe this is because the data size of CWE399 (3,694 samples)
is smaller than that of CWE119 (20,590 samples).
We hypothesize that BiVulD outperforms the other baselines for three rea-
sons. First, CodeBERT can learn good embeddings [17], and a tailored deep-learning embedding generally performs better than word2vec [33]. We believe this is why BiVulD
outperforms VDiscover. Second, extracting only part of the opcodes as learning
features may not be sufficient for text-CNN at the function level. For example,
one of the samples in our dataset contains only two mov opcodes (i.e., mov ebp,
esp and mov eax, 0x0). In such cases, text-CNN may fail to learn high-quality
representations due to the limited number of opcode features. Third, we suspect that BiVulD's attention mechanism can focus more on the instructions that cause vulnerabilities. Since vulnerabilities are often related to specific parts of the code rather than the entire function, the attention mechanism in the decoder process can allocate varying amounts of attention to different parts of the code.

Table 3. Top-k precision for real-world data. # non vul/vul: the number of non-vulnerable and vulnerable functions in the test set. Each cell reports the number of vulnerable functions found (Vul.#) and the corresponding top-k precision.

Dataset          # non vul/vul  Accuracy  Top10   Top20    Top30    Top40    Top50
CWE119-LibTIFF   178, 26        96.5%     8, 80%  15, 75%  18, 60%  20, 50%  20, 40%
LibTIFF          55, 8          97.7%     6, 60%  7, 35%   –        –        –
Asterisk         1543, 8        99.7%     7, 70%  8, 40%   –        –        –

Dataset          # non vul/vul  Accuracy  Top100   Top200  Top300  Top400  Top500
CWE119-Asterisk  4980, 25       99.4%     11, 11%  14, 7%  15, 5%  16, 4%  18, 3.6%
RQ2: Is BiVulD practical and effective in detecting real software vul-
nerabilities? We use real-world projects (Asterisk and LibTIFF) to demon-
strate that BiVulD is practical for addressing real-world software security prob-
lems.
The experimental results of BiVulD on LibTIFF, Asterisk, CWE119-
LibTIFF, and CWE119-Asterisk are presented in Table 3. The test set for
LibTIFF consists of 55 functions, 8 of which are vulnerable. The test set for
Asterisk consists of 1543 functions, including 8 vulnerable ones. For CWE119-
LibTIFF and CWE119-Asterisk, the training dataset is CWE119, while the
test sets are LibTIFF and Asterisk, respectively. For LibTIFF, Asterisk, and
CWE119-LibTIFF, we report the top-10 to top-50 precision, while for CWE119-
Asterisk we report the top-100 to top-500 precision due to the larger number of
non-vulnerable functions in the test cases.
First, from Table 3, we can see that the top-k precision for CWE119-LibTIFF is at least 40% at every cutoff, with 8, 15, 18, 20, and 20 vulnerable functions identified, respectively. This means that 50 functions need to be examined to find 20 vulnerable functions. Second, with our proposed BiVulD, 500 functions need to be checked to find 18 vulnerable functions in the CWE119-Asterisk scenario due to false positives. We hypothesize that the high false positives in BiVulD are due to two factors: 1) the training data (CWE119) and test data (Asterisk) are
due to two factors: 1) the training data (CWE119) and test data (Asterisk) are
from different sources, meaning the data distributions may differ. In this case,
BiVulD may not be able to capture the underlying distribution of the test cases;
2) the ratios of non-vulnerable to vulnerable functions in CWE119 (nearly 1:1)
and Asterisk (nearly 244:1) are quite different, which may weaken the prediction
model trained on CWE119 when targeting Asterisk.
From Table 3, we can see that 6 and 7 vulnerable functions can be found
by examining the top 10 functions over the LibTIFF and Asterisk datasets,
respectively. When examining the top 20 functions, all the vulnerable functions
can be found in the Asterisk dataset, and 7 out of 8 vulnerable functions can
be identified in the LibTIFF dataset. Moreover, we can see that the accuracies
on LibTIFF and Asterisk are 97.7% and 99.7%, which are higher than those for
CWE119-LibTIFF (96.5%) and CWE119-Asterisk (99.4%). This indicates that
the prediction model built on the same dataset (or a dataset from the same
source) results in higher accuracy and can detect more vulnerable functions in
the top-k functions.
RQ3: How effective is BiVulD compared to other machine learning
tools used for vulnerable function search? To answer this question, we
use two benchmark datasets, CWE119 and CWE399, as well as two real-world software projects, Asterisk and LibTIFF. We mainly focus on detecting
vulnerable functions to demonstrate that BiVulD outperforms other machine
learning tools that can be used for software vulnerability search.
Table 4 shows the experimental results of BiVulD and SAFE based on
CWE119, CWE399, LibTIFF, and Asterisk. Because there is a class imbalance
in the LibTIFF and Asterisk datasets, we applied random oversampling (ROS)
to SAFE. The oversampling rate is 2. SAFE + ROS means we first use ROS
on the training data to balance the number of vulnerable and non-vulnerable
samples, and then we train the SAFE model on this balanced data. We used
the following parameters for SAFE: batch size 64, embedding size 256, and word
frequency 3.
From Table 4, we see that SAFE performs worst on the LibTIFF and Aster-
isk datasets. In particular, it cannot identify any vulnerabilities in the Aster-
isk dataset. However, SAFE + ROS shows some improvement. For example, in
LibTIFF, the Precision improved from 0.278 to 0.417, the FPR dropped from
0.200 to 0.108, and the F1-score increased from 0.385 to 0.500. In comparison,
BiVulD shows much better performance in terms of Precision, TPR, FPR, and
F1-score. For instance, BiVulD achieves the best F1-scores of 0.915 and 0.943
on the CWE119 and LibTIFF datasets, which are 22.2% and 55.8% higher than
SAFE.
For the CWE119 and CWE399 datasets, both SAFE and BiVulD performed
significantly better than on the Asterisk and LibTIFF datasets. However, BiVulD
outperforms SAFE in all performance measures. For example, in CWE119,
BiVulD has a Precision of 0.972 and a TPR of 0.864, which are about 33%
and 11% higher than SAFE, respectively. BiVulD also has an FPR of 0.028,
which is 40% lower than SAFE. As a result, BiVulD has a higher F1-score of
0.915, which is about 20% higher than SAFE.
Overall, we suspect BiVulD outperforms SAFE because SAFE is designed specifically for function similarity matching, whereas this paper focuses on binary-level vulnerability detection: not every vulnerable sample in the training data has a corresponding patched version, which can undermine SAFE's similarity matching.

Table 4. Experimental results based on SAFE and BiVulD.

Data      Tech.       Precision  TPR    FPR    F1
CWE119    SAFE        0.641      0.754  0.421  0.693
          BiVulD      0.972      0.864  0.028  0.915
CWE399    SAFE        0.611      0.500  0.226  0.550
          BiVulD      0.958      0.857  0.028  0.904
LibTIFF   SAFE        0.278      0.625  0.200  0.385
          SAFE + ROS  0.417      0.625  0.108  0.500
          BiVulD      0.952      0.940  0.007  0.943
Asterisk  SAFE        –          0.000  0.000  0.000
          SAFE + ROS  0.333      0.286  0.002  0.308
          BiVulD      0.921      0.800  0.001  0.837

5 Limitations
The BiVulD system has some limitations that we plan to address in the future.
One of these limitations is that it focuses on detecting vulnerabilities in
binary code on x86 architecture. However, with an increasing number of IoT
vendors compiling and deploying third-party code bases across different archi-
tectures such as MIPS, ARM, and PowerPC, it is more crucial than ever to
search for known vulnerabilities in binary code across different architectures.
Therefore, we believe it would be interesting to further develop BiVulD with
transfer learning that can be adapted to detect vulnerabilities across different
CPU architectures.
BiVulD is currently limited to an attention-based model. We plan to explore other interpretation methods [37,38] as well as hierarchical neural networks or Convolutional Neural Network (CNN)-based approaches, in which the binary code is rendered as 'images' that serve as the input to the CNN.
The scenario of function inlining may pose a challenge for BiVulD as it may
alter the structure of the binary code. To address this issue, we will investigate
and study the effect of function inlining on BiVulD’s performance in the future.
We are planning to retrain a customized embedding model for binary-level software vulnerability detection. Although this study employed CodeBERT to obtain good embeddings, CodeBERT is trained on source code rather than binary code. Our experiments showed that CodeBERT splits assembly tokens when learning feature representations. For instance, consider the EBP register (i.e., the extended base pointer), which is especially relevant for stack buffer overflows on x86/64: CodeBERT splits the EBP token into the two sub-tokens 'eb' and 'p'. Although BiVulD achieved high classification performance, splitting a register name into two tokens is unusual. We aim to develop a customized embedding model that preserves the semantics of the code and improves classification performance in the future.

6 Conclusions

In this paper, we introduced BiVulD, a novel approach for detecting vulnerabilities at the binary level. BiVulD consists of three phases: generating assembly language instructions, learning effective embeddings using CodeBERT, and
building a predictive model with a bidirectional LSTM. We created a database of
vulnerable binaries based on the Common Vulnerabilities and Exposures (CVE) list and the National Vulnerability Database (NVD), and demonstrated BiVulD's effectiveness by comparing it with several baseline techniques, including source code-based, binary code-based, and machine learning-based approaches on real-world
projects. Our experimental results indicate that BiVulD outperforms these base-
lines, achieving at least a 20% improvement in Precision, Recall, and F-measure.
This highlights BiVulD’s capability to detect more vulnerabilities effectively. We
believe this work provides a solid foundation for future research in binary-level
vulnerability detection and will be beneficial to the research community.

References
1. Lerums, J.E., La’Reshia, D.P., Dietz, J.E.: Simulation modeling cyber threats,
risks, and prevention costs. In: 2018 IEEE International Conference on Elec-
tro/Information Technology (EIT), pp. 0096–0101. IEEE (2018)
2. Bassi, D., Singh, H.: A systematic literature review on software vulnerability pre-
diction models. IEEE Access (2023)
3. Amodei, A., Capriglione, D., Cerro, G., Ferrigno, L., Miele, G., Tomasso, G.: A
measurement approach for inline intrusion detection of heartbleed-like attacks in
IoT frameworks. IEEE Trans. Instrum. Meas. (2023)
4. Hammi, B., Zeadally, S., Nebhen, J.: Security threats, countermeasures, and chal-
lenges of digital supply chains. ACM Comput. Surv. 55 (2023)
5. Bhuiyan, M.H.M., Parthasarathy, A.S., Vasilakis, N., Pradel, M., Staicu, C.-A.:
SecBench.js: an executable security benchmark suite for server-side JavaScript.
In: International Conference on Software Engineering (ICSE) (2023)
6. Xu, G., et al.: SoProtector: safeguard privacy for native so files in evolving mobile
IoT applications. IEEE Internet Things J. 7(4), 2539–2552 (2019)
7. Alrabaee, S., Debbabi, M., Wang, L.: A survey of binary code fingerprinting
approaches: taxonomy, methodologies, and features. ACM Comput. Surv. (CSUR)
55(1), 1–41 (2022)
8. Eschweiler, S., Yakdan, K., Gerhards-Padilla, E.: discovRE: efficient cross-
architecture identification of bugs in binary code. In: NDSS (2016)
9. Yang, Y., Xia, X., Lo, D., Grundy, J.: A survey on deep learning for software
engineering. ACM Comput. Surv. (CSUR) 54(10s), 1–73 (2022)
10. Li, Z., et al.: VulDeePecker: a deep learning-based system for vulnerability detec-
tion. In: 25th Annual Network and Distributed System Security Symposium (NDSS
2018), San Diego, California, USA, 18–21 February 2018 (EI/CCF-A) (2018)
11. Liu, S., Lin, G., Han, Q.-L., Wen, S., Zhang, J., Xiang, Y.: DeepBalance: deep-
learning and fuzzy oversampling for vulnerability detection. IEEE Trans. Fuzzy
Syst. 28(7), 1329–1343 (2019)
12. Grieco, G., Grinblat, G.L., Uzal, L., Rawat, S., Feist, J., Mounier, L.: Toward
large-scale vulnerability discovery using machine learning. In: Proceedings of the
Sixth ACM Conference on Data and Application Security and Privacy, pp. 85–96.
ACM (2016)
13. Lee, Y.J., Choi, S.-H., Kim, C., Park, K.-W.: Learning binary code with deep
learning to detect software weakness. In: KSII The 9th International Conference
on Internet 2017 Symposium, pp. 245–249 (2017)
14. Le, T., Nguyen, T., Le, T., Phung, D., Montague, P., De Vel, O., Qu, L.: Maximal
divergence sequential autoencoder for binary software vulnerability detection. In:
International Conference on Learning Representations (2018)
15. Liu, S., Dibaei, M., Tai, Y., Chen, C., Zhang, J., Xiang, Y.: Cyber vulnerability
intelligence for internet of things binary. IEEE Trans. Industr. Inf. 16(3), 2154–
2163 (2019)
16. Nguyen, T., et al.: Deep cost-sensitive kernel machine for binary software vulner-
ability detection. In: Pacific-Asia Conference on Knowledge Discovery and Data
Mining, pp. 164–177. Springer (2020)
17. Feng, Z., et al.: CodeBERT: a pre-trained model for programming and natural
languages. arXiv preprint arXiv:2002.08155 (2020)
18. Hajra, S., Alam, M., Saha, S., Picek, S., Mukhopadhyay, D.: On the instability
of softmax attention-based deep learning models in side-channel analysis. IEEE
Trans. Inf. Forensics Secur. 19, 514–528 (2024)
19. Chen, Y., et al.: Bookworm game: Automatic discovery of LTE vulnerabilities
through documentation analysis. In: 2021 IEEE Symposium on Security and Pri-
vacy (SP). IEEE (2021)
20. Steenhoek, B., Rahman, M.M., Jiles, R., Le, W.: An empirical study of deep learn-
ing models for vulnerability detection. In: 2023 IEEE/ACM 45th International
Conference on Software Engineering (ICSE), pp. 2237–2248. IEEE (2023)
21. Yang, S., et al.: Asteria-pro: Enhancing deep-learning based binary code similarity
detection by incorporating domain knowledge. ACM Trans. Softw. Eng. Methodol.
(2023)
22. Jiang, L., et al.: BinaryAI: binary software composition analysis via intelligent
binary source code matching. In: Proceedings of the IEEE/ACM 46th International
Conference on Software Engineering, pp. 1–13 (2024)
23. Li, X., Qu, Y., Yin, H.: PalmTree: learning an assembly language model for instruc-
tion embedding. In: Proceedings of the 2021 ACM SIGSAC Conference on Com-
puter and Communications Security, pp. 3236–3251 (2021)
24. Zuo, F., Li, X., Young, P., Luo, L., Zeng, Q., Zhang, Z.: Neural machine translation
inspired binary code similarity comparison beyond function pairs (2019)
25. Li, Z., Wang, J., Sun, M., Lui, J.C.S.: MirChecker: detecting bugs in rust pro-
grams via static analysis. In: Proceedings of the 2021 ACM SIGSAC Conference
on Computer and Communications Security, pp. 2183–2196 (2021)
26. Batur Şahin, C., Abualigah, L.: A novel deep learning-based feature selection model
for improving the static analysis of vulnerability detection. Neural Comput. Appl.
33(20), 14049–14067 (2021). https://doi.org/10.1007/s00521-021-06047-x
27. Jiang, S., Fu, C., Qian, Y., He, S., Lv, J., Han, L.: IFAttn: binary code similarity
analysis based on interpretable features with attention. Comput. Secur. 102804
(2022)
28. Wang, Y., Jia, P., Peng, X., Huang, C., Liu, J.: BinVulDet: detecting vulnerability
in binary program via decompiled pseudo code and biLSTM-attention. Comput.
Secur. 125, 103023 (2023)
29. Pei, K., Xuan, Z., Yang, J., Jana, S., Ray, B.: Trex: learning execution semantics
from micro-traces for binary similarity. arXiv preprint arXiv:2012.08680 (2020)
30. Fan, Y., Wan, C., Cai, F., Han, L., Hao, X.: VDoTR: vulnerability detection based
on tensor representation of comprehensive code graphs. Comput. Secur. 130,
103247 (2023)
31. Lin, G., Xiao, W., Zhang, L.Y., Gao, S., Tai, Y., Zhang, J.: Deep neural-based
vulnerability discovery demystified: data, model and performance. Neural Comput.
Appl. 33(20), 13287–13300 (2021)
32. Mashhadi, E., Hemmati, H.: Applying codeBERT for automated program repair
of java simple bugs. arXiv preprint arXiv:2103.11626 (2021)
33. Lin, G., et al.: Cross-project transfer representation learning for vulnerable function
discovery. IEEE Trans. Industr. Inf. 14(7), 3289–3297 (2018)
34. Cieslak, M.C., Castelfranco, A.M., Roncalli, V., Lenz, P.H., Hartline, D.K.: t-
distributed stochastic neighbor embedding (t-SNE): a tool for eco-physiological
transcriptomic analysis. Marine Genom. 51, 100723 (2020)
35. Xu, X., Liu, C., Feng, Q., Yin, H., Song, L., Song, D.: Neural network-based graph
embedding for cross-platform binary code similarity detection. In: Proceedings of
the 2017 ACM SIGSAC Conference on Computer and Communications Security,
pp. 363–376. ACM (2017)
36. Massarelli, L., Di Luna, G.A., Petroni, F., Baldoni, R., Querzoni, L.: SAFE: self-
attentive function embeddings for binary similarity. In: International Conference on
Detection of Intrusions and Malware, and Vulnerability Assessment, pp. 309–329.
Springer (2019)
37. Li, L., et al.: VulAnalyzer: explainable binary vulnerability detection with multi-
task learning and attentional graph convolution. ACM Trans. Priv. Secur. 26(3),
1–25 (2023)
38. Nauta, M., et al.: From anecdotal evidence to quantitative evaluation methods:
a systematic review on evaluating explainable AI. ACM Comput. Surv. 55(13s),
1–42 (2023)
Malicious Code Detection Based
on Generative Adversarial Model

Jinzhihao Zhang, Jia Yang(B) , and Weiqi Zhou

School of Computer Science and Engineering, Hubei University of Technology,


Wuhan 430000, China
[email protected]

Abstract. With the increasing diversity and complexity of malicious code, the threats posed by malicious code are continuously growing,
making malicious code detection increasingly challenging. Currently,
researchers typically extract malware features and use deep-learning
models for malware detection. Using image-based features for classifi-
cation can improve the accuracy and efficiency of malware detection.
However, traditional deep learning-based malware detection faces chal-
lenges such as data imbalance, insufficient data volume, and malware
obfuscation. We propose to augment the malicious code dataset using
Generative Adversarial Network (GAN) technology to improve the accu-
racy and efficiency of malicious code detection. First, we converted
malicious software binary files into grayscale images. We used Gener-
ative Adversarial Network (GAN), Wasserstein Generative Adversarial
Network (WGAN), and Wasserstein Generative Adversarial Networks-
Gradient Penalty (WGAN-GP) models to augment the grayscale image
dataset of malicious code. Subsequently, we obtained three augmented
datasets: w1, w2, and w3. These datasets were then fed into the ResNet50
neural network model for training. Experimental results indicate that
using the GAN model for data augmentation improved ResNet’s recogni-
tion accuracy by approximately 10%. When using the WGAN model for
data augmentation, ResNet’s recognition accuracy increased by 2%. Sim-
ilarly, using the WGAN-GP model for data augmentation also increased
ResNet’s recognition accuracy by 2%. Among the three models, the GAN
model demonstrated the strongest improvement effect. We demonstrate
that the ResNet50 network with GAN models can significantly improve
malware code detection.

Keywords: Generation Adversarial Network · Malicious Code


Detection · Resnet50 · Deep Learning · Data Augmentation

1 Introduction
The swift increase in malware variants, along with their rapid spread and stealthy
nature, poses substantial challenges to conventional malware detection tech-
niques. Malware developers often employ code obfuscation techniques to gener-
ate new variants, hiding certain features to evade detection. Deep learning tech-
niques offer a promising direction for tackling these issues in malware detection.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 167–183, 2025.
https://doi.org/10.1007/978-981-96-4566-4_12

However, training deep learning models requires large datasets, and insufficient
data or imbalanced classes can lead to increased time consumption. Indeed, numerous experimental studies have integrated GAN-related models with malicious code [1, 2].
To address the issues of class imbalance and insufficient data in malware
datasets, we consider using GAN to augment the dataset [3–5]. GAN provides a
powerful machine learning model capable of learning and generating the distribu-
tion of malware features. It can be used to generate sufficient samples to balance
our dataset. GAN can effectively augment malware datasets, saving training
costs and time while improving the accuracy of malicious code detection.
In this paper, we propose the use of GAN [6], WGAN [7], and WGAN-
GP [8] models to augment datasets related to malicious code. We selected the
Malimg dataset and the big2015 dataset for augmentation. Subsequently, we used
the original datasets and the augmented datasets (m1, m2, m3 augmented by
Malimg and b1, b2, b3 augmented by big2015) to train a ResNet50 network [9],
aiming to enhance its detection capabilities. In summary, we make the following
contributions:

1. This study focuses on the performance of different GAN models on various datasets. During the dataset construction process, we collected and curated a large number of authentic malicious code samples and integrated them, establishing a valuable data foundation.
2. For malicious code detection, we construct models based on Generative Adversarial Networks (GANs). Through adversarial training, these models augment the malicious code samples used to train the target detection algorithm. We employ the GAN, WGAN, and WGAN-GP models to demonstrate the augmentation effect of Generative Adversarial Networks.
3. For the target detection model, the ResNet50 network will be utilized to
determine and identify malicious code, comparing the detection performance
of the datasets enhanced by the three aforementioned models. After train-
ing the ResNet50 network model using the Malimg dataset, the detection
accuracy for related malicious code datasets was 82.35%. Using the BIG2015
dataset for training, the detection accuracy of the ResNet50 network model
for related malicious code datasets was 82.23%. When enhancing the datasets
using adversarial network models on the Malimg dataset, it was observed
that the ResNet50 network model’s detection capability could be signifi-
cantly improved. Specifically, the GAN network model demonstrated the best
enhancement effect, boosting the detection accuracy of ResNet50 to 95.18%.
Similar performance enhancements were observed with the ResNet50 model
on the BIG2015 dataset, where the GAN network model also showed the
most significant improvement, increasing the detection accuracy of ResNet50
to 89.46%.

2 Background
Pioronski [10] employed WGAN-GP to generate samples. This approach aimed to
address the challenge of data class imbalance and to enhance the classifier’s effi-
cacy in identifying malicious programs. Wan [11] utilized the Web Feature Sam-
ple Generation Adversarial Network (WFS-GAN). This network was designed
to create feature samples of malicious web pages. It effectively tackled the diffi-
culties associated with collecting samples and the extensive task of annotating a
large number of samples. Zhu [12] integrated image processing techniques with
the WGAN-GP model. They applied this combination to augment the malicious
code dataset. The researchers compared the recognition rates of various classi-
fier models, including decision trees. Their findings demonstrated that the aug-
mented dataset could boost the performance of classifiers. Moreover, it helped to
mitigate the issue of data imbalance across different categories within the original
dataset. Certainly, numerous experimental studies have integrated GAN-related
models with malicious code detection.
We use generative adversarial networks (GANs) to augment the malware dataset. A GAN consists of two components: a discriminator model that judges samples and a generator model that produces data. GANs achieve model learning through adversarial training. Within the GAN framework, two models are simultaneously trained:
one is the generator model G, which learns the distribution of data, and the other
is the discriminator model D, which distinguishes between real samples and fake
samples generated by G. The training process of generator G aims to maximize
the probability of discriminator D making mistakes, specifically by misclassifying
generated fake data as real samples. This means during training, the discrimina-
tor D is trained to distinguish real samples from generated ones as accurately as
possible, while the generator G is trained to produce increasingly realistic fake
samples to deceive D.
In this minimax game, there exists a unique equilibrium solution where the
generator G can generate a distribution similar to that of real samples, and
the discriminator D outputs a probability of 1/2 for any input. This
process can be expressed using formula (1):

min_G max_D V(D, G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1 − D(G(z)))]   (1)
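At that equilibrium, substituting D(·) = 1/2 into the objective gives V = log(1/2) + log(1 − 1/2) = −log 4, the well-known optimal value of the GAN game; a quick numerical check:

```python
import math

# Value of the objective in (1) when D outputs 1/2 for every input.
v_star = math.log(0.5) + math.log(1 - 0.5)
print(v_star, -math.log(4))  # both equal about -1.3863
```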

3 Model

To address class imbalance in malicious code datasets, this paper uses GAN,
WGAN, and WGAN-GP to augment the dataset. The augmented data trains a
ResNet50 model, which is then used for detection and testing. The design of the
GAN-based detection scheme is shown in Fig. 1.
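As a sketch of the first step of such a pipeline, a binary can be read byte by byte and reshaped into a grayscale image. The padding strategy and the 64 × 64 size here are illustrative assumptions, not the paper's exact conversion:

```python
import numpy as np

def binary_to_grayscale(raw: bytes, side: int = 64) -> np.ndarray:
    """Interpret each byte of a (malware) binary as a pixel intensity and
    reshape it into a side x side grayscale image."""
    buf = np.frombuffer(raw, dtype=np.uint8)
    buf = np.resize(buf, side * side)  # truncate long files, repeat-pad short ones
    return buf.reshape(side, side)

img = binary_to_grayscale(bytes(range(256)) * 20)  # 5120 bytes -> 64x64 image
```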

Fig. 1. Design of the malicious code detection scheme based on the adversarial generation model

Fig. 2. Generative Adversarial Networks

3.1 Dataset Augmentation Scheme Design

In this section, we utilize Generative Adversarial Networks (GAN), Wasserstein GAN (WGAN), and WGAN with Gradient Penalty (WGAN-GP) models to
enhance malicious code datasets. The original datasets are employed to train
these models, which subsequently generate synthetic malicious code images to
expand the dataset. We evaluate the impact of these enhancements on the per-
formance of classifiers and identify the optimal model among them. Detailed
information about each model will be presented in the following sections. The
model of the generative adversarial network is shown in Fig. 2.

GAN Network Model

Discriminator Implementation. The discriminator has three fully connected layers with LeakyReLU and Sigmoid activations, converting image features into
probability outputs. It processes a 4096-dim vector from a 64 × 64 grayscale
image, reducing dimensions to 512 and 256, outputting a probability of the
image being real.

Generator Implementation. The generator uses three fully connected layers, ReLU activation, and Tanh for the final output. It takes a 100-dim noise vector,
expanding it to 256, then 512, and finally outputting a 4096-dim vector, which
is transformed into a 64 × 64 grayscale image.

Loss Function. Training uses BCELoss to assess the discriminator's predictions on real and generated samples. The total loss is calculated, gradients are zeroed,
and backpropagation updates the discriminator. The generator’s loss is com-
puted similarly, updating its parameters.
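Assuming the dimensions just described, the two networks and the BCE objective can be sketched with plain NumPy forward passes. Weights are random, no training loop is shown, and all names are ours rather than the paper's:

```python
import numpy as np

rng = np.random.default_rng(0)

def leaky_relu(x, a=0.2):
    return np.where(x > 0, x, a * x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_mlp(dims):
    # Small random weights, zero biases: one (W, b) pair per layer.
    return [(rng.standard_normal((m, n)) * 0.02, np.zeros(n))
            for m, n in zip(dims, dims[1:])]

def discriminator(params, x):
    # 4096 -> 512 -> 256 -> 1: LeakyReLU hidden layers, Sigmoid output.
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        x = sigmoid(x) if i == len(params) - 1 else leaky_relu(x)
    return x

def generator(params, z):
    # 100 -> 256 -> 512 -> 4096: ReLU hidden layers, Tanh output in (-1, 1).
    for i, (W, b) in enumerate(params):
        z = z @ W + b
        z = np.tanh(z) if i == len(params) - 1 else np.maximum(z, 0)
    return z

def bce(p, y):
    # Binary cross-entropy, the criterion used to train both networks.
    eps = 1e-12
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

D = init_mlp([64 * 64, 512, 256, 1])
G = init_mlp([100, 256, 512, 64 * 64])
fake = generator(G, rng.standard_normal((8, 100)))  # 8 fake 64x64 "images"
d_fake = discriminator(D, fake)
# Generator loss: make the discriminator label fakes as real (target 1).
loss_g = bce(d_fake, np.ones_like(d_fake))
```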

WGAN Network Model

Discriminator Implementation. In the WGAN model, the discriminator consists of three fully connected layers with LeakyReLU activations. The input (4096
dimensions) is reduced to 512, then 256, and finally produces a single output.
Unlike traditional GANs, the WGAN discriminator lacks a sigmoid activation,
allowing real-valued scores. The input image is flattened and processed through
these layers to evaluate authenticity.

Generator Implementation. The WGAN generator has five linear layers, each fol-
lowed by LeakyReLU and ending with a Tanh activation. The 100-dimensional
noise vector is transformed through dimensions of 128, 256, 512, and 1024, with
normalization and activations at each step. The final layer outputs 4096 dimen-
sions, using Tanh to map values between −1 and 1, helping mitigate the gradient
vanishing issue and ensuring fault tolerance during generation.

Loss Function and Weight Clipping. The loss functions for the generator and discriminator are shown in Eq. (2); each network minimizes its own loss. RMSProp is used for optimization, and weight clipping is applied to the discriminator after each update to enforce the Lipschitz constraint required by the Wasserstein formulation. The clamp() method is typically used to constrain the discriminator's parameters within a specified range, ensuring training stability.

LossG = −D(G(z)), LossD = D(G(z)) − D(x) (2)
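Equation (2) and the clipping step can be sketched as one critic/generator update; the stand-in networks and batch are illustrative, while the learning rate (0.00005) and clip_value (0.01) follow Table 1:

```python
import torch
import torch.nn as nn

# Minimal critic (no sigmoid: unbounded real-valued scores) and generator
D = nn.Sequential(nn.Linear(4096, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))
G = nn.Sequential(nn.Linear(100, 512), nn.ReLU(), nn.Linear(512, 4096), nn.Tanh())

opt_d = torch.optim.RMSprop(D.parameters(), lr=5e-5)
opt_g = torch.optim.RMSprop(G.parameters(), lr=5e-5)
clip_value = 0.01

real, z = torch.randn(8, 4096), torch.randn(8, 100)

# Critic step: LossD = D(G(z)) - D(x), followed by weight clipping
opt_d.zero_grad()
loss_d = D(G(z).detach()).mean() - D(real).mean()
loss_d.backward()
opt_d.step()
for p in D.parameters():
    p.data.clamp_(-clip_value, clip_value)  # constrain weights to [-0.01, 0.01]

# Generator step: LossG = -D(G(z))
opt_g.zero_grad()
loss_g = -D(G(z)).mean()
loss_g.backward()
opt_g.step()
```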

WGAN-GP Network Model. The generator and discriminator in WGAN-GP are similar to those in WGAN, consisting of linear layers and activation functions. During training, they alternate, leading to the final model. The key difference between WGAN-GP and WGAN is the addition of a gradient penalty term in the discriminator’s loss function and the use of the Adam optimizer.

Gradient Penalty. The gradient penalty constrains the discriminator’s gradient, unlike weight clipping, which limits the network’s weights directly. In the code, it is implemented by interpolating between real and generated samples, feeding the interpolations into the discriminator, and calculating the gradient using the grad function. The gradient norm is computed, and the mean squared deviation from 1 is used as the penalty term, which is added to the discriminator’s loss to constrain its gradient.
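A sketch of the penalty computation as described; the critic is a stand-in and `gradient_penalty` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP penalty: mean squared deviation of the critic's gradient norm from 1."""
    eps = torch.rand(real.size(0), 1)                  # per-sample interpolation weight
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(outputs=scores, inputs=interp,
                                 grad_outputs=torch.ones_like(scores),
                                 create_graph=True)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()    # added to the critic's loss

critic = nn.Sequential(nn.Linear(4096, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))
gp = gradient_penalty(critic, torch.randn(8, 4096), torch.randn(8, 4096))
```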

3.2 Malicious Code Detection Scheme Design

172 J. Zhang et al.

This paper uses the ResNet50 network model as the target detection algorithm because it takes images directly as input and can train deeper neural networks. It employs a residual structure to preserve original features, enhancing both the precision and generalization ability of the network model. In this section, the augmented datasets obtained through adversarial network model enhancement and the original datasets are both used to train the ResNet50 model. After training, the corresponding model is used for detection on the test set.
The input to the ResNet50 model undergoes 5 stages to obtain the final
output, passing through 49 convolutional layers and 1 fully connected layer. The
overall network structure of ResNet50 is shown in Fig. 3.

Fig. 3. The network structure of the ResNet50 model

4 Experiment
First, we implemented GAN, WGAN, and WGAN-GP models to augment the
dataset, resulting in six datasets: m1, m2, m3 for Malimg, and b1, b2, b3 for
BIG2015. Each dataset contains approximately 1K samples. Then, these datasets
were all fed into ResNet50 for training. We demonstrate the effectiveness of
enhancing datasets with different GAN models. We utilize each trained ResNet
model for malicious code detection and evaluate metrics such as Accuracy, Recall,
Precision, and F1 score. We use these evaluation metrics to compare and assess
the performance improvement of ResNet in detecting malicious code after aug-
menting the dataset with various GAN models.
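For reference, the four metrics can be computed directly from confusion-matrix counts; the example numbers below are illustrative, not results from the paper:

```python
def classification_metrics(tp: int, fp: int, tn: int, fn: int):
    """Accuracy, Recall, Precision, and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    recall = tp / (tp + fn)              # fraction of malicious samples caught
    precision = tp / (tp + fp)           # fraction of alarms that are correct
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, recall, precision, f1

# Illustrative counts: 90 TP, 20 FP, 80 TN, 10 FN
acc, rec, prec, f1 = classification_metrics(tp=90, fp=20, tn=80, fn=10)
```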

4.1 Dataset
We use the Malimg dataset and the Microsoft Malware Classification Challenge
(BIG2015) dataset for our experiments. Specifically:

Malimg Dataset. The Malimg dataset consists of grayscale images of malicious code. We need to resize the grayscale images in the dataset to appropriate shapes to match the inputs required by each model for training. The preprocessing of the Malimg dataset involves three transformation steps.

Resizing. The first transformation step is resizing the images to a specified size. This ensures that all images are uniformly resized to the same dimensions, facilitating effective training.

Converting. The second transformation step involves converting the images to tensor format, as neural network models require input data in tensor form. This step converts PIL image objects to tensors, making subsequent processing and computation easier.

Normalizing. The final transformation step is to normalize the tensors obtained in the second step. This operation scales the pixel values from the [0, 255] range to the [−1, 1] range by subtracting the mean and dividing by the standard deviation. Normalization helps the model learn the image features more efficiently, improving training outcomes and performance.

BIG2015 Dataset. The BIG2015 dataset consists of malicious code binary data. We need to convert the binary byte files into grayscale images using the method described by L. Nataraj [13], and then apply the same transformation steps as those used for the Malimg dataset. The binary byte file conversion process for the BIG2015 dataset consists of the following steps:

Read. We first read the hexadecimal text of the bytes file, ignoring invalid characters and the address information in each line. In this step, we remove unnecessary information and keep only the bytes that will be turned into a grayscale image.

Fill. In the second step, we determine the image width according to the byte
file size, and then fill the byte data according to the determined width to obtain
the grayscale image corresponding to the binary file.
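The two steps can be sketched as follows; the width schedule is the file-size table proposed by Nataraj et al. [13], and the sample hex lines are illustrative, not real BIG2015 content:

```python
import numpy as np

def bytes_to_image(lines):
    """Convert lines of a BIG2015 .bytes hex dump into a 2D grayscale array."""
    pixels = []
    for line in lines:
        for tok in line.split()[1:]:      # drop the leading address column
            if tok != "??":               # skip invalid (unreadable) bytes
                pixels.append(int(tok, 16))
    # Choose the image width from the byte count (schedule from Nataraj et al. [13])
    kb = len(pixels) / 1024
    widths = [(10, 32), (30, 64), (60, 128), (100, 256),
              (200, 384), (500, 512), (1000, 768)]
    width = next((w for limit, w in widths if kb < limit), 1024)
    # Pad the tail with zeros so the data fills complete rows
    pad = (-len(pixels)) % width
    arr = np.array(pixels + [0] * pad, dtype=np.uint8)
    return arr.reshape(-1, width)

lines = ["00401000 4D 5A 90 00 03 00 00 00", "00401008 ?? 00 FF FF B8 00 00 00"]
img = bytes_to_image(lines)   # 15 valid bytes -> width 32 -> one 32-pixel row
```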

4.2 Model Training

In our experiments, we used the Malimg and BIG2015 datasets to train GAN, WGAN, and WGAN-GP, and then used the trained models to generate grayscale images to supplement the datasets. This brings the number of malicious code grayscale images in the test set close to the number of non-malicious code grayscale images, improving the test performance of ResNet50.
The ratio of grayscale images between the malicious code category and the non-malicious code category of the enhanced dataset is 1:1, and the training and test datasets each contain about 1.5K grayscale images.
Training GANs begins by loading pre-trained discriminator and generator
models. Then, we proceed with the training loop. In each epoch, we go through
each batch from the data loader. For each batch, we first update the discrim-
inator: the generator creates samples using random noise, and these generated
samples along with real samples (malicious grayscale images) are fed into the
discriminator. We calculate the discriminator’s loss function and perform back-
propagation to update the parameters. After that, we train the generator: again,
the generator creates samples using random noise, and after calculating the gen-
erator’s loss function, we update the generator. We use the Adam optimizer to
update the parameters of both the discriminator and the generator, and binary
cross-entropy loss is used to calculate the relevant losses.
For WGAN, we also train the discriminator, but the difference is that we
clip the parameters to restrict the weight range when updating them. After
training the discriminator for n_critic times, we update the generator once. The
optimization algorithm used in this process is RMSProp.
The training process for WGAN-GP is similar to WGAN, but the difference
is that after calculating the scores for real and generated samples, WGAN-GP
needs to calculate a gradient penalty term and add it to the discriminator’s
loss function for backpropagation to update the discriminator. The optimization
algorithm used here is Adam.
In the training of GAN, WGAN, and WGAN-GP, the Malimg dataset and the BIG2015 dataset are each divided into a training set and a test set. The ratio of grayscale images between the training set and the test set is 1:1, with about 1K images in each.
The parameters for GAN, WGAN, and WGAN-GP are shown in Table 1.

Table 1. Parameter Settings related to generative adversarial network model

Parameter    Description                                               GAN      WGAN     WGAN-GP
n_epochs     the number of training epochs                             500      1000     1000
batch_size   the size of each batch                                    256      256      256
lr           the learning rate for the optimization algorithm          0.0001   0.00005  0.0002
latent_dim   the dimension of the latent space                         100      100      100
img_size     the image size                                            64       256      256
n_critic     discriminator training steps per iteration                2        2        2
channels     the number of image channels                              1        1        1
clip_value   clipping limits for the discriminator weights             –        0.01     –

In the training of the ResNet50 model, the augmented Malimg and BIG2015 datasets are divided into a training set and a test set. The ratio of grayscale images between the training set and the test set is 1:1, with about 1.5K images in each. The ratio of malicious code grayscale images to non-malicious code grayscale images in both sets is 12:5. After each training epoch, the training set is used to evaluate the loss function value and accuracy; if the accuracy increases, the current model file is saved.

The parameters for ResNet50 are shown in Table 2.

Table 2. Parameter Settings related to ResNet50

Parameter    Description                                              Value
epochs       the total number of training epochs                      100
batch_size   the number of samples input to the model at one time     16
lr           the learning rate for the Adam optimization algorithm    0.1
img_size     the size of the input images                             244

4.3 Experiment Results


Experimental Results for GAN Models. For the Malimg and BIG2015 datasets, the distributions of the loss function values for the three models during training are shown in Figs. 4, 5, 6, 7, 8 and 9. Based on these distributions, we select an appropriate epoch for the GAN, WGAN, and WGAN-GP models and evaluate their performance using metrics such as Accuracy, Recall, Precision, and F1 score, as shown in Table 3.

Table 3. Table of model performance metrics for Generative adversarial networks

Model    Dataset  Epochs  Accuracy  Recall  Precision  F1 score
GAN      Malimg   400     0.976     1.0     0.953      0.965
         BIG2015  300     0.846     0.943   0.736      0.787
WGAN     Malimg   335     0.899     0.950   0.842      0.870
         BIG2015  430     0.736     0.904   0.528      0.615
WGAN-GP  Malimg   450     0.535     0.519   1.0        0.698
         BIG2015  655     0.767     0.963   0.557      0.646

During the training process, after selecting suitable GAN, WGAN, and WGAN-GP models, their respective generators were used to generate 2500 grayscale images of malicious code for each of the two datasets to augment their quantities. Before augmentation, the ratio of malicious to non-malicious grayscale images was 10:17. Following augmentation, the dataset was enriched with 200 generated malicious code images and 500 images from non-malicious code classes, and the ratio between grayscale images of the malicious and non-malicious code classes became 12:5.

Experimental Results of the ResNet50 Model. We trained ResNet50 models on the original Malimg and BIG2015 datasets; the loss function values are depicted in Figs. 10 and 11. We then augmented the two datasets with the GAN, WGAN, and WGAN-GP models respectively and trained ResNet50 on the augmented data. The loss function values for these cases are shown in Figs. 12, 13, 14, 15, 16 and 17.
The ROC and PR curves for the ResNet50 model on both the original Malimg
and BIG2015 datasets, as well as those augmented with the best-performing
GAN model, are depicted in Figs. 18, 19, 20 and 21.

ResNet50 Performance. According to Figs. 12, 13, 14, 15, 16 and 17, we select the ResNet50 models whose loss functions have converged. We assess ResNet50’s performance by measuring its accuracy and training efficiency before and after data augmentation. The relevant performance metrics can be found in Table 4. Table 4
demonstrates that training the ResNet50 model with a dataset augmented by
Generative Adversarial Networks leads to an improvement in classification accu-
racy compared to training with the original dataset. Furthermore, by comparing
the model training times in Table 4 before and after data augmentation, it is
observed that the efficiency of the model has also been improved.
In training the ResNet50 network model on the Malimg dataset, the loss function curves and the related performance tables show that the model trained on GAN-augmented data performs better, while the models trained on WGAN- and WGAN-GP-augmented data perform relatively poorly. The BIG2015 dataset shows the same pattern: the model trained on GAN-augmented data performs best, while the WGAN- and WGAN-GP-augmented models lag behind.

Table 4. Table of ResNet50 detection model performance metrics

Dataset  Augmentation model  Recognition accuracy  Time
Malimg   –                   82.35%                19.05 s
         GAN                 95.18%                17.84 s
         WGAN                84.07%                18.86 s
         WGAN-GP             84.64%                15.80 s
BIG2015  –                   82.23%                17.14 s
         GAN                 89.46%                16.83 s
         WGAN                83.49%                16.86 s
         WGAN-GP             84.89%                16.74 s

Fig. 4. Loss change plot of GAN trained on the Malimg dataset.
Fig. 5. Loss change plot of WGAN trained on the Malimg dataset.
Fig. 6. Loss change plot of WGAN-GP trained on the Malimg dataset.
Fig. 7. Loss change plot of GAN trained on the BIG2015 dataset.
Fig. 8. Loss change plot of WGAN trained on the BIG2015 dataset.
Fig. 9. Loss change plot of WGAN-GP trained on the BIG2015 dataset.
Fig. 10. Loss change plot for training ResNet50 on the unaugmented Malimg dataset.
Fig. 11. Loss change plot of ResNet50 trained on the unaugmented BIG2015 dataset.

Fig. 12. Loss change plot of ResNet50 trained on the Malimg dataset augmented with the GAN model.
Fig. 13. Loss change plot of ResNet50 trained on the Malimg dataset augmented with the WGAN model.
Fig. 14. Loss change plot of ResNet50 trained on the Malimg dataset augmented with the WGAN-GP model.
Fig. 15. Loss change plot of ResNet50 trained on the BIG2015 dataset augmented with the WGAN-GP model.
Fig. 16. Loss change plot of ResNet50 trained on the BIG2015 dataset augmented with the WGAN model.
Fig. 17. Loss change plot of ResNet50 trained on the BIG2015 dataset augmented with the WGAN-GP model.

When the Malimg dataset without data augmentation is used to train the
ResNet50 network model, the accuracy of detecting related malware datasets is
82.35%. When the BIG2015 dataset without data augmentation is used to train
the ResNet50 network model, the accuracy of detecting related malware datasets
is 82.23%.

After training the generative adversarial network models, GAN, WGAN, and
WGAN-GP models were used to generate a certain number of malware grayscale
images to augment the Malimg and BIG2015 datasets. As shown in Table 4, in
the Malimg dataset, augmenting the dataset with adversarial network models can
indeed improve the detection capability of the ResNet50 network model to some
extent. Among them, the GAN model has the best improvement effect, increasing
the detection accuracy of ResNet50 to 95.18%. The effects of WGAN-GP and
WGAN are slightly inferior to GAN, with WGAN-GP performing slightly better
than WGAN. In the BIG2015 dataset, the performance improvement of the
ResNet50 model is similar. The GAN model again has the best improvement
effect, increasing the detection accuracy of ResNet50 to 89.46%, followed by
WGAN-GP, with WGAN having the relatively poorer effect.

Comparison of Experimental Results. Agarap et al. [14] used CNN-SVM, GRU-SVM, and MLP-SVM to detect the Malimg dataset, with the highest recognition accuracy among them being GRU-SVM at approximately 84.92%. The recognition accuracy of the ResNet50 model trained on the Malimg dataset augmented with the GAN model is 95.18%, which is comparatively better. The recognition accuracies of the relevant experiments by Agarap [14] and of the ResNet50 model trained on the GAN-augmented Malimg dataset are shown in Table 5.
Zhu et al. [12] utilized generative adversarial networks to enhance the
BIG2015 dataset and employed decision trees, KNN, and random forests for
identification. Among them, random forests achieved the highest recognition
accuracy, with an accuracy of 90.33% on the original data and 94.11% on the
generated data. Through comparison, it is noted that the recognition accuracy
of the ResNet50 model trained on the BIG2015 dataset augmented with GAN
is 89.46%, which is superior to the performance of decision trees and KNN, but
relatively inferior to that of random forests. For the recognition accuracy of Zhu
[12]’s relevant experiments and the recognition accuracy of the ResNet50 model
trained on the BIG2015 dataset augmented with GAN, please refer to Table 6.

Table 5. The recognition accuracy of experiments related to the Malimg dataset

Dataset              CNN-SVM  GRU-SVM  MLP-SVM  ResNet50
Malimg               77.23%   84.92%   80.47%   82.35%
GAN Enhanced Malimg  –        –        –        95.18%

Table 6. The recognition accuracy of experiments related to the BIG2015 dataset

Dataset               Decision tree  KNN     Random forest  ResNet50
BIG2015               75.11%         80.67%  90.33%         82.23%
Zhu Enhanced BIG2015  84.33%         81.89%  94.11%         –
GAN Enhanced BIG2015  –              –       –              89.46%

Fig. 18. ROC and PR charts for the ResNet50 trained on the original Malimg dataset.
Fig. 19. ROC and PR charts for the ResNet50 trained on the original BIG2015 dataset.
Fig. 20. ROC and PR charts for the ResNet50 trained on the Malimg dataset augmented with GAN.
Fig. 21. ROC and PR charts for the ResNet50 trained on the BIG2015 dataset augmented with GAN.

4.4 Model Efficiency


Using GAN, it takes 6.5887 s per epoch to enhance the Malimg dataset and 17.7413 s per epoch to enhance the BIG2015 dataset. Using WGAN, it takes an average of 4.8248 s per epoch for the Malimg dataset and 16.4889 s per epoch for the BIG2015 dataset. Using WGAN-GP, it takes an average of 4.4215 s per epoch for the Malimg dataset and 16.1578 s per epoch for the BIG2015 dataset. These figures show that WGAN and WGAN-GP train faster than GAN; their training processes are indeed optimized relative to GAN. At the same time, according to Table 4, the classification efficiency of ResNet50 is also improved by using the enhanced datasets.

4.5 Experimental Analysis


In the training of generative adversarial network models, GAN performs better than WGAN and WGAN-GP on both the Malimg and BIG2015 datasets. After enhancing the datasets with generative adversarial network models, the performance of the ResNet50 model trained on the augmented Malimg and BIG2015 datasets improves compared to the unaugmented datasets, with GAN-based augmentation exhibiting the best performance. Although WGAN and WGAN-GP enhance training stability and improve the quality of generated samples over the traditional GAN, this does not automatically translate into better classification performance. The images generated by WGAN and WGAN-GP may look more realistic, but in some cases this realism does not directly lead to better classification outcomes, because classification models may instead exploit specific patterns or biases in the images created by GANs.
The experimental results indicate that generative adversarial network models can effectively enhance datasets, thereby improving the performance of detection models. Generative adversarial network models can generate more diverse and realistic samples of malicious code, aiding in training detection models to better capture various types of malicious code features. Enhancing malicious code datasets helps improve the robustness and generalization ability of malicious code detection models. Furthermore, generative adversarial network models can generate challenging samples, which may be unseen variants for detection models in the original dataset, better addressing unknown malicious code variants.
The generators of generative adversarial networks can provide more diverse and challenging data samples for malicious code detection models, thereby enhancing the performance of detection models. The use of these techniques helps strengthen the capabilities of malicious code detection systems, enabling them to better address increasingly complex and diverse malicious code threats.

5 Conclusions
The paper implements three types of generative adversarial models. Subse-
quently, each model is used to enhance the Malimg dataset and the BIG2015
malicious dataset separately. The augmented datasets are then fed into the
ResNet50 model for malicious code detection, and the detection accuracy before
and after augmentation is compared. The experiments demonstrate that generative adversarial networks can indeed improve the performance of the malicious code detection model, with GAN showing the best enhancement effect among the three models.

Acknowledgments. This work is supported by the National Natural Science Foundation of China under a grant of the corresponding author. Jia Yang is supported by China NSF (No. 62202146).

References
1. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 770–778 (2016)
2. Yang, Z., Deng, F., Han, L.: Flexible android malware detection model based on
generative adversarial networks with code tensor. In: 2022 International Conference
on Cyber-Enabled Distributed Computing and Knowledge Discovery (CyberC), pp.
19–28. IEEE (2022)
3. Bac, T.N., Duy, P.T., Pham, V.: PWDGAN: generating adversarial malicious URL
examples for deceiving black-box phishing website detector using GANs. In: 2021
IEEE International Conference on Machine Learning and Applied Network Tech-
nologies (ICMLANT), pp. 1–4. IEEE (2021)
4. Wang, J., Liu, M., Yin, X., et al.: Semi-supervised malicious traffic detection with
improved wasserstein generative adversarial network with gradient penalty. In:
2022 IEEE 6th Advanced Information Technology, Electronic and Automation
Control Conference (IAEAC), pp. 1916–1922. IEEE (2022)
5. Hu, W., Cheng, J., Chong, X., et al.: A GAN-based anti-obfuscation detection
method for malicious code. In: 2022 3rd International Conference on Pattern Recog-
nition and Machine Learning (PRML), pp. 484–488. IEEE (2022)

6. Goodfellow, I., Pouget-Abadie, J., Mirza, M., et al.: Generative adversarial nets.
Adv. Neural Inf. Process. Syst. 27 (2014)
7. Arjovsky, M., Chintala, S., Bottou, L.: Wasserstein generative adversarial networks.
In: International Conference on Machine Learning, pp. 214–223. PMLR (2017)
8. Gulrajani, I., Ahmed, F., Arjovsky, M., et al.: Improved training of Wasserstein
GANs. Adv. Neural Inf. Process. Syst. 30 (2017)
9. Wang, A., Ding, Y.: Network malicious traffic identification method based on
WGAN category balancing. In: 2021 IEEE International Conference on Signal
Processing, Communications and Computing (ICSPCC), pp. 1–6. IEEE (2021)
10. Pioroński, S., Górecki, T.: Using GAN to generate malicious samples suitable for
binary classifier training. In: 2022 IEEE International Conference on Big Data (Big
Data), pp. 6522–6527. IEEE (2022)
11. Wan, M., Yao, H., Yan, X.: Generation of malicious webpage samples based on
GAN. In: 2020 IEEE 19th International Conference on Trust, Security and Privacy
in Computing and Communications (TrustCom), pp. 864–869. IEEE (2020)
12. Zhu, X., Qian, L., Fu, W.: Method of enhancing malicious code based on generative
adversarial network. Comput. Eng. Design 42(11), 3032–3042 (2021)
13. Nataraj, L., Karthikeyan, S., Jacob, G., Manjunath, B.S.: Malware images: visu-
alization and automatic classification. In: Proceedings of the 8th International
Symposium on Visualization for Cyber Security, pp. 1–7 (2011)
14. Agarap, A.F.: Towards building an intelligent anti-malware system: a deep learning
approach using support vector machine (SVM) for malware classification. arXiv
preprint arXiv:1801.00318 (2017)
Construction of an AI Code Defect Detection
and Repair Dataset Based on Chain of Thought

Huimin Gong1 , Zongliang Shen1(B) , Hua Zhang1 , Lei Qiao2 , Huawei Wang2 ,
and Chi Zhang3
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing, China
{shenzongliang,zhanghua_288}@bupt.edu.cn
2 Beijing Institute of Control Engineering, Beijing, China
3 Information Center of China, North Industries Group Corporation, Beijing, China

Abstract. When detecting and repairing code defects, enhancing the generaliza-
tion ability and detection accuracy of models is a key challenge. This paper pro-
poses a data fine-tuning method based on Chain of Thought (CoT) fine-tuning to
improve the capabilities of models on defect detection in AI code. We constructed
a dataset that includes the CrossVul dataset and a manually created dataset of AI
code defects and repairs, improving data quality through techniques like context
free removal. In the experiments, we used Codeshell-7B, Qwencoder2.5-7B, and Llama3.1-7B as base models and trained them using LoRA fine-tuning techniques.
We compared different datasets and training methods to verify the model’s effec-
tiveness in detecting and repairing AI code defects. The results show that the CoT
fine-tuning model outperforms models without CoT fine-tuning in all aspects of
handling code defect tasks. Additionally, the specialized dataset we created for AI
code defects and repairs significantly enhances the model’s accuracy and repair
rate in AI code detection. Our experiments highlight the importance of construct-
ing targeted datasets for AI code defects and employing CoT fine-tuning strategies
in improving code defect detection.

Keywords: Chain of Thought · Code defect detection · Dataset construction · AI code · Fine-tuning

1 Introduction
Code defects are inevitable issues in software development, potentially leading to sys-
tem crashes, functional failures, or even security vulnerabilities. As modern software
scales and code complexity increases rapidly, traditional code defect detection methods
are becoming inadequate. Static analysis [1] and dynamic analysis [2] are two main
traditional approaches for detecting code security defects. Although static analysis can
quickly identify potential defects, it often has a high false positive rate and limited under-
standing of the actual execution behavior of the code. Dynamic analysis relies on runtime
information, which can capture actual defects but is inefficient and often provides poor
repair suggestions. Consequently, traditional methods struggle to meet expectations in
large-scale complex systems.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 184–196, 2025.
https://doi.org/10.1007/978-981-96-4566-4_13

Recently, generative AI, particularly large language models (LLMs), has achieved
significant success in natural language processing (NLP). Given their outstanding perfor-
mance in tasks such as text generation, translation, and question answering, researchers
have started applying these models to detect and repair code defects. LLMs, with their
powerful contextual understanding and generation capabilities, theoretically have the
potential to effectively capture complex logical defects in code and provide accurate
repair suggestions. However, the high specificity of AI code presents significant chal-
lenges. Firstly, AI code datasets are mostly closed-source and not publicly available, making relevant data very scarce. Additionally, existing large models
typically perform poorly on vulnerabilities related to AI code, failing to generalize
effectively. This exposes the limitations of current large models in code defect detection
and repair tasks. Moreover, quality issues in current datasets further constrain the performance of models. In large code datasets, the presence of redundant context, noise, and
irrelevant code makes it difficult for models to focus on the actual defect areas. Mean-
while, existing prompt design methods often struggle to guide models in generating
effective repair suggestions.
This paper aims to explore the application of large language models in detecting and
repairing AI code defects, constructing and processing diverse datasets to optimize the
fine-tuning process and ensure the model is exposed to a wide range of code defect types.
This paper primarily adopts two key dataset processing methods: context free removal
and Chain of Thought (CoT) prompting. These dataset processing methods gradually
enhance the model’s perception of code defects during fine-tuning, enabling the model to
have strong defect identification and repair capabilities when handling unseen code snip-
pets. Ultimately, a series of experiments were conducted to evaluate the performance of
the large language model fine-tuned with different dataset processing methods, verifying
the effectiveness of the proposed methods and the model’s generalization capabilities.

2 Background
2.1 Code Defect Identification and Repair
Traditional code defect identification methods include static analysis and dynamic anal-
ysis. Static analysis checks the source code to identify potential security defects, and
due to its efficiency and low cost, it is widely used in large-scale systems. However,
it exhibits certain limitations [3] when dealing with complex problems, as it cannot
accurately determine the execution environment of code in different contexts. It is also
difficult to effectively handle detection tasks involving complex structures like loops
and branches. In contrast, dynamic analysis monitors software behaviors during execu-
tion, allowing it to capture code defects under different execution environments more
accurately. However, since it requires software execution, it can consume significant
resources in large-scale systems and cannot comprehensively cover all paths, which
limits its applicability.
Machine learning models learn from large amounts of labeled code defect data,
which allows them to automatically extract features and identify known vulnerability
types. Unlike traditional detection methods, methods based on machine learning are not
limited to manually defined rules and can discover vulnerabilities by learning patterns

in data. For example, Raychev et al. proposed a statistical learning-based code analysis
tool [4], which significantly improves the automation of code defect detection. However,
machine learning depends on large amounts of labeled data and struggles with complex
contextual relationships in code. These models often perform poorly when encountering
unseen code snippets or new types of vulnerabilities, as they lack sufficient generalization
ability.
Meanwhile, large language models (LLMs) show great potential in code defect detec-
tion and repair. Chen et al. studied the ability of LLMs to understand and generate code
[5], offering insights into using LLMs for code-related tasks. Ribeiro et al. proposed a
technique using GPT-3 to automatically repair type errors in programs, which achieves
good results on OCaml programs [6]. Keller et al. proposed the Gemini model [7], which
automates the process from vulnerability discovery to repair code generation and testing.
With the deepening of these studies, more and more evidence suggests that LLMs have
enormous potential in code defect detection and repair.

2.2 AI Code Defect


AI code defects refer to errors or vulnerabilities present in code related to artificial intel-
ligence. This code is typically used to support the training, inference, and deployment of
AI models, involving complex data processing, model invocation, and interaction logic.
Due to the significant differences in design and implementation between AI code and
traditional code, AI code defects often exhibit unique characteristics and causes. The
high complexity and specificity of AI code are key reasons behind the occurrence of
these defects. AI code usually handles large amounts of data flow, model parameters,
and highly dynamic inference processes, making its structure more complex than that
of traditional code.
Research shows that the complexity and dependencies of AI code often lead to the
inability of traditional static analysis methods to effectively identify potential defects
[8]. However, in current AI code defect detection research, there is a lack of open-source
datasets and vulnerability collections specifically designed for AI code, which poses a
significant challenge for researchers. Most existing code defect datasets are designed for
traditional software development and program analysis, making it difficult to cover the
unique issues found in AI code.

2.3 Large Model Training


Incremental pre-training is widely considered to inject vast amounts of knowledge into
models, but its high computational resource demands and training costs make it suit-
able only for base models. While this approach enhances the model’s knowledge base
through extensive data, it incurs significant computational and time costs. In contrast,
fine-tuning and external knowledge base attachment are more practical and flexible
options. Fine-tuning refers to further training the model on specific tasks, adjusting
the model parameters to better adapt to different tasks. Attaching an external knowledge base involves adding external data resources to the model to enhance its capabilities. However, this method may cause the model to misinterpret and utilize external
knowledge incorrectly [9]. The Low-Rank Adaptation (LoRA) method is an efficient
Construction of an AI Code Defect Detection 187

strategy in the fine-tuning process. LoRA reduces the number of parameters that need
to be updated through low-rank matrix decomposition, thereby lowering computational
costs and resource requirements. This method is particularly suitable for researchers with
limited resources, providing new ideas for efficient model fine-tuning [10].
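The low-rank update at the core of LoRA can be sketched in a few lines; the dimensions, rank, and scaling factor below are illustrative values, not settings from the paper:

```python
import numpy as np

# Illustrative LoRA update: instead of updating the full d×k weight matrix W,
# train two small matrices B (d×r) and A (r×k) with rank r << min(d, k).
d, k, r, alpha = 64, 64, 4, 8

rng = np.random.default_rng(0)
W = rng.standard_normal((d, k))          # frozen pretrained weight
A = rng.standard_normal((r, k)) * 0.01   # trainable, small init
B = np.zeros((d, r))                     # trainable, zero init => no change at start

def effective_weight(W, B, A, alpha, r):
    """Weight actually used in the forward pass: W + (alpha / r) * B @ A."""
    return W + (alpha / r) * B @ A

# At initialization the adapted model is identical to the base model.
assert np.allclose(effective_weight(W, B, A, alpha, r), W)

# Trainable parameters: r * (d + k) instead of d * k.
full, lora = d * k, r * (d + k)
print(full, lora)  # 4096 vs 512 here
```

Only `B` and `A` receive gradient updates, which is why the approach fits resource-limited settings.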

2.4 Prompt Engineering

The purpose of prompt engineering is to design and use prompts to guide large models
in performing specific tasks. These prompts can be natural language descriptions, code
snippets, or specific instructions for the task, helping the model understand the task
requirements and generate appropriate outputs. Prompt engineering is widely used in
NLP and code generation tasks and has proven effective [11]. In code defect detection
and repair tasks, prompt engineering can generate clear prompts so that the model can
identify and fix code defects better. This improves task efficiency and accuracy [12].
However, traditional prompt engineering has limitations, especially when dealing with
multi-step reasoning tasks. Models often rely on pre-defined task formats. When the
problem changes slightly or multiple problems are combined, the model's performance
may not meet expectations. This rigid structure makes it difficult for models to flexibly
respond to complex contexts, leading to inadequate generalization. To address this issue,
the Chain of Thought (CoT) reasoning method has been widely applied in generative
AI in recent years. Wei et al. applied the Chain-of-Thought (CoT) reasoning method to
natural language reasoning tasks [13]. They demonstrated that this method significantly
improves the ability of large language models to handle reasoning tasks. Although the
focus of this paper is on natural language processing, the approach has also been applied
to the code generation domain. Yang et al. applied CoT to code completion [14], where
step-by-step reasoning helped generate high-quality and logically consistent code. This
method can be effectively transferred to code defect repair tasks, enhancing the accuracy
and logical consistency of repairs through the gradual analysis of errors in the code.

3 Methodology

This paper utilizes the CrossVul dataset [15] and a manually constructed AI code defect
and repair dataset to enhance the model's performance in detecting and repairing code
defects while addressing the model's shortcomings in understanding and handling AI-
generated code. To further optimize the application of our dataset, this paper proposes
two key processing methods: irrelevant context removal and Chain of Thought prompting
to enhance the model’s ability to perceive and repair code defects.

3.1 Dataset Construction

In this paper, two key datasets were used for model training and fine-tuning: the CrossVul
dataset and a manually constructed AI code defect and repair dataset, details are as
follows.

1. CrossVul Dataset
CrossVul is a publicly available dataset specifically designed for cross-language code
vulnerability detection. It collects code snippets from various projects in different
programming languages containing known security vulnerabilities from real-world
scenarios and extensive repositories. Each vulnerability sample is accompanied by
a detailed description, including the type of vulnerability, scope of impact, and cor-
responding secure code. As shown in Fig. 1, the dataset has a hierarchical structure
where the top-level directory is named after CWE-ID, differentiating various vulner-
ability types. The second-level directory is divided by programming language, and
the files are named to indicate whether the code snippet is vulnerable (bad) or secure
(good), organized into pairs by number. CrossVul contains over 40 programming
languages, 168 major types, and 13,738 vulnerability samples.

Fig. 1. Hierarchical Structure of the CrossVul Dataset

2. AI Code Defect and Repair Dataset


To enhance the model’s ability to detect and repair AI code defects, a set of code
snippets covering multiple languages and common AI code error types was manually
constructed. Unlike the CrossVul dataset, code snippets in this dataset mainly come
from real-world AI component codebases. We selected high-starred AI components
from GitHub with a large user base and extracted code snippets containing specific
functions. AI-specific code defect types were inserted into these code snippets to
construct erroneous samples, with corresponding repair codes provided. Each erro-
neous sample is paired with its repair sample to train the model to understand and
correct code defects. AI code-specific defect types include, but are not limited to, neu-
ral network connectivity errors, data processing API compatibility errors, gradient
convergence errors, and dimensional mismatch errors. These error types are further
divided into different subtypes, covering 36 types and 1,200 samples.
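As an illustration of what one such sample pair might look like, the following sketch shows a hypothetical "dimensional mismatch" defect and its repair; the snippet is our own example, not one taken from the dataset:

```python
import numpy as np

# Hypothetical defect/repair pair of the "dimensional mismatch" type:
# the defective snippet multiplies a (3, 4) activation by a (3, 2) weight,
# so the inner dimensions (4 and 3) do not align.
bad_code = (
    "import numpy as np\n"
    "x = np.ones((3, 4))\n"
    "w = np.ones((3, 2))\n"
    "y = x @ w\n"
)

# Repaired version: the weight's inner dimension matches the activation.
x = np.ones((3, 4))
w = np.ones((4, 2))
y = x @ w
assert y.shape == (3, 2)

# The defective sample really fails, which is what the model must learn to spot.
try:
    exec(bad_code)
    raised = False
except ValueError:
    raised = True
assert raised
```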

3.2 Dataset Processing


To improve the model’s performance in detecting code defects, this paper applied two key
processing methods to the dataset: context free removal and chain-of-thought prompting
design.

Context Free Removal. When processing the CrossVul dataset, we observed that each
vulnerability sample and its corresponding secure code often included context code unrelated to the defect. Such irrelevant code not only interferes with the model's training process but also reduces the precision of defect detection. Therefore, this paper proposes a context free removal method to enhance the dataset's effectiveness and simplify
the model’s learning task. As shown in Fig. 2, the steps for context free removal are as
follows.
1. Function Extraction
Each vulnerability sample typically consists of multiple functions or code seg-
ments. The first step involves extracting each function or logical code block from the
sample using an Abstract Syntax Tree (AST) parser.
2. Function Name and Parameter Matching
The extracted functions are parsed syntactically using the AST parser to identify
those related to the defect or its fix. The CrossVul dataset contains pairs of vulnera-
bility and repair code samples, and in most cases, the defect and fix occur within the
same function. By matching function names and parameters, we can filter out code
blocks directly related to the defect.
3. Semantic Analysis
The relevant functions undergo semantic analysis to determine which code blocks
are directly linked to the root cause of the defect. Code segments not involving the
defects (e.g., helper functions, initialization code) are marked as “irrelevant context”,
which will be eliminated in step 5.
4. Difference Comparison
The differences between the defect code and the repair code are compared to
extract the specific lines of code involved in the fix. This step uses the code comparison
method from diff tools to extract the modified parts from the repair and defect files,
ensuring the correspondence between defect and repair codes. Irrelevant code can be
excluded after this comparison.
5. Irrelevant Code Elimination
After extracting the core code related to the defect, irrelevant context is removed.
This step primarily relies on analyzing differences, function dependencies, and vari-
able usage to ensure that the remaining code focuses on the parts directly related to
the defect.
6. Code Reorganization and Formatting
After step 5, the remaining code fragments are reorganized and formatted to
ensure that each sample retains syntactical completeness, facilitating normal input
and learning for the model.
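Assuming functions can be parsed with a standard AST library and that the semantic analysis of step 3 is approximated by keeping only functions whose bodies actually changed, the pipeline above might be sketched for Python sources as follows (function and variable names are ours):

```python
import ast
import difflib

def extract_functions(source: str) -> dict:
    """Steps 1-2: parse the file with an AST and index functions by name/arity."""
    tree = ast.parse(source)
    funcs = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            key = (node.name, len(node.args.args))
            funcs[key] = ast.get_source_segment(source, node)
    return funcs

def defect_relevant_diff(bad_src: str, good_src: str) -> dict:
    """Steps 3-5: keep only functions present in both files whose bodies differ,
    and return their changed lines; everything else is irrelevant context."""
    bad, good = extract_functions(bad_src), extract_functions(good_src)
    changes = {}
    for key in bad.keys() & good.keys():
        if bad[key] != good[key]:
            diff = [line for line in difflib.unified_diff(
                        bad[key].splitlines(), good[key].splitlines(), lineterm="")
                    if line.startswith(("+", "-"))
                    and not line.startswith(("+++", "---"))]
            changes[key] = diff
    return changes

bad = "def helper():\n    return 1\n\ndef load(x):\n    return int(x)\n"
good = "def helper():\n    return 1\n\ndef load(x):\n    return int(x, 10)\n"
print(defect_relevant_diff(bad, good))
```

Here `helper` is dropped as irrelevant context and only the changed lines of `load` survive, mirroring steps 4 and 5.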

Chain of Thought Prompting Design. Building upon the context free removal, this
paper further adopts a chain-of-thought prompting design to enhance the model’s rea-
soning ability. The core of Chain-of-thought prompting is to guide the model through a
step-by-step process, encouraging deep thinking during tasks like multi-step reasoning.
This method allows the model to simulate human reasoning when faced with complex
code defects, making it more effective at solving problems. As shown in Pseudocode

Fig. 2. Context Free Removal Flowchart

Example of Error Analysis and Fixing Process, we added guiding information to the
dataset to prompt the model, including:
Problem Description: To help the model understand the nature of the error, the defect
and fix in each sample are accompanied by a brief, clear description. For
example, for a vulnerability involving incorrect data type input, the description might
state, “This code has a data type mismatch in the data processing API compatibility
error.” Such descriptions help the model focus on specific defect types during analysis.
Step-by-Step Reasoning Process: We designed step-by-step prompts to help the model
gradually derive the repair solution. This process aims to guide the model to identify the
type and location of the vulnerability, analyze possible solutions, and ultimately generate
repair code. For example, in code analysis, the model is first prompted to identify the
type of vulnerability, then to locate the vulnerability triggers in source code, and finally
to propose a feasible repair measure. This “chain-of-thought” prompting design allows
the model to understand and repair complex code defects in stages.
Code Examples and Repair Suggestions: In order to help the model better understand
the problem, we also provided relevant code snippets and repair suggestions after each
prompt. By integrating real code contexts, the model can more accurately infer the
repair process. The example code includes the original faulty code and its corresponding
corrected version, allowing the model to reason within a clear context and generate the
correct repair solution.

Pseudocode Example of Error Analysis and Fixing Process


// Define identity and conversation content
identity = "identity_17"

// Define conversation array
conversation = [
  {
    from: "human",
    content: """
      The error code is as follows:
      data['next'] = self.serializer.validated_data['code'].next
      ...
      data['next'] = self.serializer.validated_data['user'].login_code.next
    """
  },
  {
    from: "assistant",
    content: """
      Let's break this down step by step:
      1. Problem analysis:
         In the defective code, the `next` variable is being assigned between
         different data types (strings and integers), which could lead to a
         type mismatch error.
      2. Solution:
         The assignment of the `next` variable should be split into two steps
         to ensure consistent data types:
         - Change `self.serializer.validated_data['code'].next` to
           `self.serializer.validated_data['code']`
         - Change `self.serializer.validated_data['user'].login_code.next` to
           `self.serializer.validated_data['user']`
      3. Fixed code:
         data['next'] = self.serializer.validated_data['code']
         data['next'] = self.serializer.validated_data['user']
    """
  }
]

The specific steps for processing the dataset are as follows, using the CrossVul dataset
as an example.
1. CWE-ID Classification
The CrossVul dataset is classified according to CWE (Common Weakness Enumer-
ation) numbers, with each CWE-ID representing a specific type of code defect. This
classification method allows for systematic organization and management of different

code vulnerabilities, facilitating subsequent analysis and processing. The CWE-ID
classification is shown in Table 1.

Table 1. CWE-ID Classification of Code Defect Types

Code Defect Type | CWE-IDs
Data Type Errors | 428, 444, 681, 704, 843
Validation Errors | 20, 78, 79, 87, 88, 89, 90, 91, 93, 94, 98, 113, 114, 116, 117, 120, 129, 134, 178, 184, 185, 252, 269, 287, 288, 294, 295, 306, 307, 345, 346, 347, 352, 354, 369, 384, 426, 434, 436, 441, 494, 502, 522, 532, 552, 565, 639, 640, 862, 863
Concurrency Errors | 74, 77, 78, 79, 362, 367, 476, 754
Infinite Loops | 834, 835
Memory Errors | 22, 59, 119, 120, 121, 122, 125, 129, 400, 401, 404, 415, 416, 672, 674, 763, 787, 824, 908
Numerical Instability Errors | 190, 191, 193, 197, 311, 312, 319, 330, 471, 476, 665, 667, 682, 694, 770, 838, 915, 916, 917, 943, 1188, 1236

2. Matching Bad and Good Files


In the dataset, each bad file is paired with a corresponding good file of the same
number. Files with the same number represent a defect and its fix. The process begins
by matching bad_xx and good_xx files based on the number in the filename. Then,
code segments we need can be extracted respectively from bad and good files, and
the two code segments are compared to identify the part that has been repaired.
3. Function-Level Deduplication and Precise Comparison
Function-level code deduplication and difference comparison are used to precisely
locate the specific part of the code that was fixed. To avoid interference, this process
also involves eliminating irrelevant context, ensuring that only the core content related
to the defect and its repair is retained.
4. Data Annotation and Chain-of-Thought Dialogue Generation
Based on the classification of each CWE-ID, the extracted defect code is further
annotated with the error type. Subsequently, scripts are used to generate chain-of-
thought dialogues from the extracted defect code, specific error location, error type,
and repair solution, enabling the model to gradually understand and fix the problems
in the code.
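A script for step 4 might assemble each annotated defect record into a dialogue of the shape shown in the earlier pseudocode; the function and field names below are illustrative, not the paper's exact schema:

```python
def build_cot_sample(identity, bad_code, error_type, location, fix_steps, fixed_code):
    """Assemble one chain-of-thought training dialogue from an annotated
    defect record (hypothetical schema for illustration)."""
    reasoning = "Let's break this down step by step:\n"
    reasoning += f"1. Problem analysis:\n   {error_type} at {location}.\n"
    reasoning += "2. Solution:\n" + "".join(f"   - {s}\n" for s in fix_steps)
    reasoning += "3. Fixed code:\n" + fixed_code
    return {
        "id": identity,
        "conversation": [
            {"from": "human", "content": f"The error code is as follows:\n{bad_code}"},
            {"from": "assistant", "content": reasoning},
        ],
    }

sample = build_cot_sample(
    "identity_17",
    "data['next'] = self.serializer.validated_data['code'].next",
    "Data type mismatch (data processing API compatibility error)",
    "the assignment of the `next` variable",
    ["Drop the trailing `.next` so the assignment uses the validated value directly"],
    "data['next'] = self.serializer.validated_data['code']",
)
print(sample["conversation"][1]["content"])
```

Running such a script over every annotated record yields the problem description, step-by-step reasoning, and repair code as one conversation per sample.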

4 Experiments
The goal of this experiment is to evaluate the performance of different models in AI code
defect detection and repair tasks, with a focus on comparing the impact of the Chain-of-
Thought reasoning method. Additionally, the experiment explores the improvement in
the model’s ability to detect AI-specific defects through the constructed AI defect repair
dataset. For a comprehensive assessment, the experiment compares the performance of
the Codeshell-7B, Qwencoder2.5-7B, and Llama3.1-7B models.

4.1 Dataset Setup


The dataset used in this experiment contains a total of 2,347 AI code samples. Among
them, 1,471 samples are from the CrossVul dataset, and 876 samples are from our man-
ually constructed AI code defect and repair dataset. These data cover six programming
languages: C, C++, Java, Python, Matlab, and R.
Code snippets in the CrossVul dataset are generally long. Many contain extensive contextual information and irrelevant code segments, posing challenges for the model's learning and reasoning. Although we removed irrelevant context during data preprocessing, some code snippets still exceeded the model's maximum context length. Therefore, we only selected code snippets and their fixed versions with a total length shorter than 8192 tokens as valid samples to construct our dataset.
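The length filter can be sketched as follows; for simplicity this uses whitespace splitting as a stand-in for the model's real tokenizer, which the actual pipeline would use instead (e.g. `tokenizer.encode`):

```python
MAX_TOKENS = 8192

def count_tokens(text: str) -> int:
    """Proxy token counter; the real pipeline would call the model's own
    tokenizer rather than splitting on whitespace."""
    return len(text.split())

def keep_pair(bad_code: str, good_code: str) -> bool:
    """Keep a defect/fix pair only if their combined length fits the context."""
    return count_tokens(bad_code) + count_tokens(good_code) <= MAX_TOKENS

assert keep_pair("a b c", "a b")                    # short pair kept
assert not keep_pair("x " * 5000, "y " * 5000)      # 10,000 tokens dropped
```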
To more accurately assess the model’s performance in code defect detection and
repair tasks, we divided the dataset into a training set and a test set. Specifically, 90%
of the data (2,112 entries) were used for model training, and 10% of the data (235
entries) were used as the test set. The training set comprises a mixture of data from
the CrossVul dataset and our manually constructed AI defect repair dataset, covering
various programming languages and code defect types. To better assess the model’s
performance in different domains, the test set is also divided into CrossVul data and AI
defect data, with samples randomly selected from the corresponding datasets to ensure
their broadness and representativeness.
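A split along these lines, drawing 10% from each source dataset separately, could look like the following (the per-source counts are the ones reported above; the function name is ours):

```python
import random

def split_by_source(samples, test_frac=0.10, seed=42):
    """90/10 split drawn separately from each source dataset, so the test set
    contains both CrossVul and AI-defect samples."""
    rng = random.Random(seed)
    train, test = [], []
    by_source = {}
    for s in samples:
        by_source.setdefault(s["source"], []).append(s)
    for group in by_source.values():
        rng.shuffle(group)
        n_test = round(len(group) * test_frac)
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test

data = [{"source": "crossvul"}] * 1471 + [{"source": "ai_defect"}] * 876
train, test = split_by_source(data)
print(len(train), len(test))  # 2112 235 with these counts
```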

4.2 Model Fine-Tuning


For the three models used in this experiment—Codeshell-7B, Qwencoder2.5-7B, and
Llama3.1-7B—each model was fine-tuned in three different configurations to evaluate
their performance on the AI code defect detection and repair tasks. These configurations
include fine-tuning on the CrossVul dataset, fine-tuning with the CrossVul dataset com-
bined with Chain-of-Thought (CoT) prompts, and fine-tuning with both the CrossVul
dataset and our AI defect repair dataset along with CoT prompts.
Base Model (No Fine-Tuning): Each of the models was initially evaluated on the test
set without any fine-tuning, serving as a baseline. This allowed us to measure the initial
performance of the models before any domain-specific adjustments.
Model Fine-Tuned with CrossVul Dataset: We fine-tuned each model using a portion
of the CrossVul dataset to assess its ability to detect and repair traditional code defects.
This step focused on evaluating the model’s general performance on a well-established
code defect dataset.
Model Fine-Tuned with CrossVul+Chain-of-Thought (CoT): We fine-tuned the mod-
els with a portion of the CrossVul dataset, formatted using Chain-of-Thought (CoT)
prompts. This configuration allowed us to evaluate the impact of CoT reasoning on the
models’ performance in detecting and repairing code defects, particularly in the context
of AI-related code.
Model Fine-Tuned with CrossVul+AI Defect Repair Dataset+Chain-of-Thought
(CoT): In the final configuration, we fine-tuned the models using both the CrossVul
dataset and our manually constructed AI code defect repair dataset, augmented with
CoT prompts. This fine-tuning aimed to enhance the model’s ability to generalize to

AI-specific code defects and improve accuracy and logical consistency in code defect
repair tasks.
The fine-tuning process for each model was carried out using the official fine-tuning
script provided by the respective models (e.g., run_finetune.sh for Codeshell-7B), with
the following parameters:
Per Device Train Batch Size: 2 (moderately increases training speed).
Gradient Accumulation Steps: 4 (balances computational resources and training
efficacy).
Gradient Checkpointing: Enabled (reduces memory usage).
Learning Rate Scheduler Type: Cosine (gradually decreases learning rate).
Logging Steps: 50 (logs training progress every 50 steps).
Save Steps: 500 (saves the model every 500 steps to reduce storage frequency).
Learning Rate: 2e-5 (ensures training stability).
Number of Training Epochs: 5 (ensures sufficient training).
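These settings could be collected into a flag dictionary along the following lines; the flag names follow common Hugging Face Trainer conventions and are assumptions, not the exact interface of each model's fine-tuning script:

```python
# Hyperparameters from the paper, expressed as trainer-style flags
# (names are conventional, not the scripts' verified API).
finetune_args = {
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,   # effective batch size 2 * 4 = 8 per device
    "gradient_checkpointing": True,
    "lr_scheduler_type": "cosine",
    "logging_steps": 50,
    "save_steps": 500,
    "learning_rate": 2e-5,
    "num_train_epochs": 5,
}

# Render the dict as command-line flags for a script such as run_finetune.sh.
cli_flags = " ".join(
    f"--{k} {str(v).lower() if isinstance(v, bool) else v}"
    for k, v in finetune_args.items()
)
print(cli_flags)
```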

4.3 Evaluation Metrics

In evaluating the models on the test set, we used the following metrics to assess
performance in AI code defect detection and repair tasks:
Acc (Accuracy): Measures the proportion of correctly identified defects in the code. For
each test sample, if the model correctly identifies the defect location, it is considered
accurate.
Repair: Assesses the proportion of correct repair suggestions compared to the actual fixes.
For each sample where a defect is identified, the generated repair code is compared with
the actual fix, and a match is considered a successful repair.
BLEU: A text similarity measure based on n-grams, evaluating the degree of match
between generated code and reference code, with a score ranging from 0 to 1, where
1 indicates a perfect match. This metric compares the generated repair code with the
actual fix for each sample, calculating the n-gram match.
Exec. Correct (Execution Correctness): Measures whether the repaired code functions
correctly when executed in a real programming environment. This metric assesses
whether the generated repair code successfully resolves the defects in the original code.
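A minimal single-reference BLEU along these lines can be written directly; this is a sketch with crude smoothing, not the exact scorer used in the experiments:

```python
import math
from collections import Counter

def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def bleu(candidate, reference, max_n=4):
    """Single-reference BLEU sketch: geometric mean of clipped n-gram
    precisions times a brevity penalty, with smoothing for zero counts."""
    cand, ref = candidate.split(), reference.split()
    precisions = []
    for n in range(1, max_n + 1):
        cand_counts, ref_counts = Counter(ngrams(cand, n)), Counter(ngrams(ref, n))
        overlap = sum((cand_counts & ref_counts).values())  # clipped matches
        total = max(sum(cand_counts.values()), 1)
        precisions.append(max(overlap, 1e-9) / total)
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_n)
    bp = 1.0 if len(cand) >= len(ref) else math.exp(1 - len(ref) / max(len(cand), 1))
    return bp * geo_mean

fix = "data [ 'next' ] = validated_data [ 'code' ]"
assert abs(bleu(fix, fix) - 1.0) < 1e-6  # identical repair scores 1.0
assert bleu("return x", fix) < 0.1       # unrelated code scores near 0
```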

4.4 Experimental Analysis

Our experiments aimed to: (1) evaluate and compare the model’s performance under
different dataset compositions and processing strategies, and (2) analyze the effectiveness
of chain-of-thought fine-tuning in code defect detection tasks.
We formatted the test set data into model input prompts using pre-written scripts
and then fed them into both the baseline and fine-tuned models, including the traditional
static defect detection method. We then compared the outputs of each model. Table 2
summarizes the performance of each model on the test set, listing their performance on
the CrossVul dataset (A) and the manually constructed AI defect repair dataset (B). This
includes the performance of Codeshell-7B, Qwencoder2.5-7B, Llama3.1-7B, as well as
the Traditional Static Defect Detection method for comparison.

Table 2. Performance of Different Models on Different Test Sets. The CrossVul dataset extracts
10% of its data (147 samples) as Test Set A, while 10% of the data from the manually constructed
AI defect repair dataset (88 samples) is used as Test Set B.

Model | A-Acc | A-Repair | A-BLEU | A-Exec. Correct | B-Acc | B-Repair | B-BLEU | B-Exec. Correct
Codeshell-7B | 43.3% | 32.6% | 19.6 | 73.2% | 23.9% | 17.0% | 10.7 | 70.3%
Codeshell-7B+CrossVul | 44.7% | 35.5% | 20.4 | 69.3% | 22.7% | 15.9% | 9.8 | 65.5%
Codeshell-7B+CrossVul+CoT | 52.5% | 43.3% | 23.5 | 75.6% | 26.1% | 18.2% | 11.3 | 70.3%
Codeshell-7B+CrossVul+AI Data+CoT | 50.3% | 41.8% | 22.7 | 74.7% | 34.1% | 25.0% | 16.4 | 75.2%
Qwencoder2.5-7B | 47.1% | 40.1% | 21.3 | 77.9% | 27.1% | 19.3% | 12.7 | 73.2%
Qwencoder2.5-7B+CrossVul | 48.2% | 42.5% | 22.1 | 78.5% | 29.3% | 21.0% | 13.2 | 74.6%
Qwencoder2.5-7B+CrossVul+CoT | 55.4% | 46.7% | 25.0 | 80.1% | 31.5% | 23.8% | 15.2 | 77.3%
Qwencoder2.5-7B+CrossVul+AI Data+CoT | 53.8% | 45.2% | 24.2 | 79.8% | 36.8% | 28.3% | 17.0 | 79.0%
Llama3.1-7B | 46.3% | 38.4% | 20.6 | 75.2% | 25.6% | 19.2% | 12.1 | 71.4%
Llama3.1-7B+CrossVul | 45.6% | 37.2% | 19.9 | 74.3% | 24.7% | 17.9% | 11.5 | 70.6%
Llama3.1-7B+CrossVul+CoT | 53.1% | 44.1% | 22.4 | 77.5% | 28.2% | 22.5% | 13.8 | 73.5%
Llama3.1-7B+CrossVul+AI Data+CoT | 51.7% | 43.2% | 21.9 | 76.9% | 33.4% | 26.1% | 15.4 | 76.8%
Traditional Static Defect Detection | 23.9% | N/A | N/A | N/A | 19.3% | N/A | N/A | N/A

After training on the datasets, all models generally outperformed the baseline. How-
ever, models fine-tuned directly with the CrossVul dataset occasionally underperformed,
especially when dealing with complex AI code defects. In contrast, applying chain-of-
thought to fine-tune the model with the CrossVul dataset significantly improved the model's ability to detect defects in general code and, to some extent, enhanced its detection capability for AI code. Adding the manually constructed AI code defect and repair dataset
significantly improved the model’s performance on AI code defect detection and repair
tasks. The model fine-tuned with the combination of CrossVul and AI datasets, along
with chain-of-thought prompts, exhibited the best performance.

5 Conclusion
In this paper, we explored the task of code defect detection and repair using large mod-
els, with a particular focus on the impact of chain-of-thought fine-tuning on model
performance. We constructed a dataset comprising the CrossVul dataset and a manually
created AI code defect and repair dataset, fine-tuning the model with various datasets and
processing methods to evaluate its performance across different contexts. Experimen-
tal results indicate that chain-of-thought fine-tuning significantly improves the model’s
accuracy, repair rate, BLEU score and execution correctness. Models fine-tuned with
the AI code defect and repair dataset demonstrated excellent performance in handling
complex AI code defect detection tasks. In contrast, models trained on the unprocessed

CrossVul dataset and the original baseline model underperformed in some cases, high-
lighting the importance of targeted optimization and dataset construction for specific
scenarios.
During the experiments, we observed that directly fine-tuning with the unprocessed
CrossVul dataset, although yielding good results for some defect types, led to a noticeable
decline in detection performance when encountering new defect types, even falling short
of the baseline model. However, by introducing chain-of-thought prompts, the model
better understood the context and logical relationships in defect repair, enhancing its
generalization ability when faced with different defect types.

Acknowledgments. This work is supported by 3AD303F5 and NSFC (Grant No. 62472047).

References
1. Pistoia, M., Chandra, S., Fink, S.J., et al.: A survey of static analysis methods for identifying
security vulnerabilities in software systems. IBM Syst. J. 46(2), 265–288 (2007)
2. Wei, P.: The safety of software design loophole dynamic state examination technique analysis.
CD Technol. 2009(4), 16–17 (2009)
3. Zhao, Y. Z.: A static analysis-based system for detecting code security vulnerabilities.
University of Electronic Science and Technology of China (2012)
4. Raychev, V., Bielik, P., Vechev, M., et al.: Learning programs from noisy data. ACM Sigplan
Not. 51(1), 761–774 (2016)
5. Chen, M., Tworek, J., Jun, H., et al.: Evaluating large language models trained on code. arXiv
preprint arXiv:2107.03374 (2021)
6. Ribeiro, F., de Macedo, J.N.C., Tsushima, K., et al.: GPT-3-powered type error debugging:
Investigating the use of large language models for code repair. In: Proceedings of the 16th
ACM SIGPLAN International Conference on Software Language Engineering, pp. 111–124
(2023)
7. Keller, J., Nowakowski, J.: AI-powered patching: the future of automated vulnerability fixes.
Technical report (2024)
8. Amershi, S., Begel, A., Bird, C., et al.: Software engineering for machine learning: a case study.
In: IEEE/ACM 41st International Conference on Software Engineering: Software Engineering
in Practice (ICSE-SEIP), pp. 291–300 (2019)
9. Raffel, C., Shazeer, N., Roberts, A., et al.: Exploring the limits of transfer learning with a
unified text-to-text transformer. J. Mach. Learn. Res. 21(140), 1–67 (2020)
10. Hu, E.J., Shen, Y., Wallis, P., et al.: LoRA: low-rank adaptation of large language models.
arXiv preprint arXiv:2106.09685 (2021)
11. Devlin, J., Chang, M.W., Lee, K., et al.: BERT: pre-training of deep bidirectional transformers
for language understanding. arXiv preprint arXiv:1810.04805 (2018)
12. Liu, X., Zheng, Y., Du, Z., et al.: GPT understands, too. AI Open (2023)
13. Wei, J., Wang, X., Schuurmans, D., et al.: Chain-of-thought prompting elicits reasoning in
large language models. Adv. Neural. Inf. Process. Syst. 35, 24824–24837 (2022)
14. Yang, G., Zhou, Y., Chen, X., et al.: Chain-of-thought in neural code generation: from and
for lightweight language models. IEEE Trans. Softw. Eng. (2024)
15. Nikitopoulos, G., Dritsa, K., Louridas, P., et al.: CrossVul: a cross-language vulnerability
dataset with commit data. In: Proceedings of the 29th ACM Joint Meeting on European Soft-
ware Engineering Conference and Symposium on the Foundations of Software Engineering,
pp. 1565–1569 (2021)
Backdoor Attack on Android Malware
Classifiers Based on Genetic Algorithms

Zhenghua Cai1 , Yongji Wang1(B) , Hua Zhang1 , Lei Qiao2 , Huawei Wang2 ,
and Chi Zhang3
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing, China
{wyj2022,zhanghua_288}@bupt.edu.cn
2 Beijing Institute of Control Engineering, Beijing, China
[email protected]
3 Information Center of China, North Industries Group Corporation, Beijing, China

Abstract. The rapid rise of malware challenges traditional detection methods due
to code obfuscation and polymorphism. While machine learning classifiers offer
quick detection and can identify complex malicious features, they are suscepti-
ble to backdoor attacks. We introduce GAT, a genetic algorithm-based approach
for generating effective and stealthy Android backdoors. Using the SHAP inter-
pretability tool, we first select efficient features as primary backdoors. A fitness
function then enables iterative optimization through a genetic algorithm. Addi-
tionally, we propose a method to integrate backdoor features into the source code,
maintaining functionality while facilitating attacks on Android classifiers in real
data outsourcing scenarios. Our evaluation of the Drebin and Mamadroid mal-
ware detectors in data outsourcing scenarios indicates that an attack success rate
exceeding 70% can be achieved with only 5% poisoned samples and a minimal
number of trigger features, while keeping the false positive rate below 10% and
the label flipping rate below 30%. Additionally, the performance degradation of
the classifiers remains within 5%. This work provides new insights into backdoor
attack methodologies in Android malware classifiers.

Keywords: Malware detection · backdoor attacks · genetic algorithm · SHAP


interpretability

1 Introduction
The exponential growth of malware, along with techniques like code obfuscation and
polymorphism, has outpaced traditional detection methods. In response, security profes-
sionals are increasingly employing machine learning and deep learning to tackle large-
scale malware detection challenges. ML-based malware classifiers effectively counter
traditional evasion tactics and capture complex malicious features, often serving as the
first line of defense due to their rapid detection capabilities.
However, machine learning-based detection methods have attracted the attention of
attackers, leading to various AI-related security issues. Adversarial attacks can subtly

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 197–214, 2025.
https://doi.org/10.1007/978-981-96-4566-4_14
198 Z. Cai et al.

modify sample features during inference to bypass detection, while poisoning attacks
target the training phase. Since ML-based malware detectors require regular retraining
with samples from third-party sources, attackers can inject malicious samples into the
training set, compromising model decisions. Backdoor attacks, a variant of poisoning,
involve constructing specific triggers that cause the model to misclassify inputs con-
taining the trigger while correctly classifying normal samples. These backdoors can be
categorized into two types: label-flipping backdoors, which alter the original labels, and
clean-label backdoors, where original labels remain unchanged while backdoor features
are embedded into the original code.
Sasaki et al. [1] first introduced a backdoor attack on malware classifiers using reverse
gradient optimization to poison a specific malware family, causing misclassification.
However, their research was limited to feature space experiments and did not address
real-world poisoning scenarios. Li et al. [2] proposed a genetic algorithm-based backdoor
attack that relies on label flipping to create poisoned samples. This approach assumes
that attackers can arbitrarily modify training set labels, which is impractical in real-world
contexts.
Severi et al. [3] were the first to investigate clean-label backdoor attacks on malware
classifiers, using interpretable machine learning to compute feature marginal contribu-
tions for constructing backdoor triggers. They injected these triggers into benign sam-
ples, creating a “channel” that led to misclassification of malware as benign. While their
method achieved a high success rate, it left noticeable traces in the samples, compro-
mising stealth, particularly when using transformed abstract features. Building on this
work, Yang et al. [4] explored the stealthiness of clean-label attacks and proposed an
optimized algorithm targeting specific malware families, achieving better effectiveness
and stealth. However, this backdoor was limited in applicability, only being effective
against certain malware families.
Li et al. [2] classified Android features into static and dynamic categories, mapping
original code files to feature matrices by inserting dummy code. However, this approach
is vulnerable to static analysis, which can prune the backdoor code, rendering it inef-
fective. Severi et al. [3] focused on ensuring the normal functionality of PE software
when inserting backdoor triggers but did not address the preservation of functionality in
Android applications. Yang et al. [4] established a mapping between custom functional
statements and features, enabling rapid mapping from backdoor features to original code.
Nonetheless, this process inevitably introduced redundant code, resulting in the mapping
of extraneous features.
Current backdoor attack algorithms in Android malware classifiers face limitations,
including inapplicability to real attack scenarios and a trade-off between effectiveness
and stealth. Existing backdoor insertion methods are prone to pruning by static analyz-
ers and may introduce redundant code. We propose a clean-label backdoor generation
algorithm based on genetic algorithms to enhance both effectiveness and stealth. Ini-
tially, SHAP machine learning interpretability tools are used to select a set of efficient
features as the initial backdoor. A fitness function is then designed to optimize backdoor
effectiveness and stealth iteratively via the genetic algorithm. Additionally, features are
categorized into XML and DEX types for mapping, with methods defined for feature
deletion and addition. To prevent static analysis pruning, code obfuscation and reflection
Backdoor Attack on Android Malware Classifiers Based on Genetic Algorithms 199

techniques are employed. The proposed algorithms were evaluated against two malware
detectors, leading to the following contributions:
(1) A genetic algorithm-based Android backdoor generation method that demonstrates
strong attack efficacy and stealth across three attack scenarios and two malware
classifiers.
(2) An Android backdoor insertion method that successfully integrates backdoor fea-
tures into source code without compromising software functionality, enabling
real-world attacks against Android classifiers.

2 Related Works
2.1 Malware Detection Systems

Machine learning-based malware detection systems classify software as benign or malicious based on static or dynamic features, enabling more granular detection (e.g., malicious code localization). Dynamic features are gathered by executing applications in a
sandbox environment, while static features are extracted from decompiled source code
without execution.
The model training process involves extracting feature vectors X from binary software and corresponding labels Y to train the classifier. The classifier's objective is to predict the label y ∈ C = {0, 1} for an input x ∈ X, assuming input-output pairs are independently and identically distributed samples from a distribution D. The detector is represented as a function F_θ: X → C, with parameters θ optimized by minimizing a loss function L(x, y, θ) over a labeled training set D = {(x_i, y_i)}_{i=1…n}, as defined in Eq. 1.

argmin_θ − Σ_{i∈D} Σ_{j∈C} y_{ij} × log(Prob(pred = j | x_i, θ))    (1)
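Eq. 1 is the standard cross-entropy objective. As a minimal stdlib-Python sketch (the toy probabilities below are illustrative, not from the paper):

```python
import math

def cross_entropy(probs, labels, num_classes=2):
    """Eq. 1: sum over samples i and classes j of -y_ij * log Prob(pred=j|x_i).

    probs[i][j] is the model's predicted probability of class j for sample i;
    labels[i] is the ground-truth class index (the one-hot y_ij is implied).
    """
    loss = 0.0
    for p, y in zip(probs, labels):
        for j in range(num_classes):
            y_ij = 1.0 if j == y else 0.0
            if y_ij:  # zero terms contribute nothing, so skip them
                loss -= y_ij * math.log(p[j])
    return loss

# A confident correct prediction costs little; a confident wrong one costs a lot.
low = cross_entropy([[0.9, 0.1]], [0])   # -log(0.9)
high = cross_entropy([[0.1, 0.9]], [0])  # -log(0.1)
```

Training then amounts to searching for the parameters θ that minimize this quantity over the labeled set.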

2.2 Backdoor Attacks

The earliest backdoor attack method, BadNet, was proposed by Gu et al. [5], followed by
various adaptations (Liu et al. [6]; Yao et al. [7]; Wenger et al. [8]). We focus on clean-
label backdoor attacks, which are covert methods targeting machine learning models.
In this approach, attackers introduce samples with specific backdoor functionality into
the training set, labeling them with correct class labels (i.e., clean labels). In image
recognition, incorrect labels (e.g., bird-cat) may raise suspicion during manual inspec-
tions, prompting Shafahi et al. [9] to propose clean-label poisoning attacks that ensure
input labels remain correct. Inspired by this work, subsequent researchers have explored
additional clean-label attack strategies (Barni et al. [10]; Liu et al. [11]; Saha et al. [12];
Turner et al. [13]). This method is particularly covert in practical applications, as it can
evade manual detection. In the context of malware detection, attackers must conduct
poisoning attacks through online platforms, where multiple detection engines label soft-
ware. Modifying labels without bypassing all engines is unrealistic, leading attackers to
rely on clean-label strategies that align with the labels assigned by the detection engines.

3 Attack Methods

3.1 Threat Model


3.1.1 Attacker Capabilities
Table 1 outlines the attack scenarios considered in this paper. In Scenario 1, the attacker
has full control over the outsourced training data, is aware of the feature set used by
the model trainer, and has white-box access to the target model. The attacker ultimately
releases a backdoored pre-trained model for users. Scenario 2 presents an ideal case
where the attacker has white-box access to the model but can only modify a subset of the
training data, while other conditions remain the same. Scenario 3 imposes the strictest
limitations on the attacker, allowing modification of only a portion of the training data
and granting black-box access to the target model. This scenario closely resembles real-
world attacks, where the attacker may submit backdoored software to platforms like
VirusTotal, which then stores the software and its detection results in a database for
enterprises or API users to download as part of the training set. The specific process is
illustrated in Fig. 1.

Table 1. Attack Scenarios

Attack Scenario                            Model Parameters   Training Data   Feature Set   Label
Scenario 1: Model Outsourcing                     √                 √              √          ×
Scenario 2: Ideal Data Outsourcing                √                 ×              √          ×
Scenario 3: Real-World Data Outsourcing           ×                 ×              ×          ×

Fig. 1. Flowchart of Real-World Data Outsourcing Process



3.1.2 Attackers’ Goals


In the poisoning process of backdoor attacks, there are two primary goals: effectiveness
and stealth. Firstly, the backdoor must be effective. As indicated in Eq. 2, for the poisoned
model Fb , the normal model F, poisoned samples Xb , normal samples X, correct labels
y, and target labels yb , the poisoned classifier should not degrade in performance on the
normal dataset, which means Fb (X) = F(X). When the backdoor is activated, the poisoned
classifier’s prediction should yield the attacker’s desired target label, i.e., Fb (Xb ) = yb .
Secondly, the backdoor should exhibit stealth, which includes label stealth and soft-
ware stealth. Label stealth means that the label recognized by the original classifier
remains unchanged before and after the trigger is inserted, i.e., F(Xb ) = y. Software
stealth indicates that the software operates normally before and after the trigger is
inserted.

Fb(X) = F(X);  F(Xb) = y;  Fb(Xb) = yb ≠ y    (2)
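The three conditions of Eq. 2 can be checked mechanically. A hedged sketch with toy stand-in classifiers (the lambdas below are illustrative, not the paper's detectors):

```python
def goals_met(F, Fb, X, Xb, y, yb):
    """Verify Eq. 2 for a clean-label backdoor.

    F / Fb -- clean and poisoned classifiers (callables: sample -> label)
    X / Xb -- clean samples and trigger-carrying samples
    y      -- correct label of Xb; yb -- attacker's target label (yb != y)
    """
    utility = all(Fb(x) == F(x) for x in X)     # Fb(X) = F(X): no accuracy loss
    stealth = all(F(x) == y for x in Xb)        # F(Xb) = y: labels look correct
    effect = all(Fb(x) == yb for x in Xb)       # Fb(Xb) = yb: backdoor fires
    return utility and stealth and effect

# Toy models: the trigger substring "T" flips the poisoned model to benign (0).
F = lambda s: 1                                 # clean model: flags all malware
Fb = lambda s: 0 if "T" in s else 1             # poisoned model
ok = goals_met(F, Fb, X=["m1", "m2"], Xb=["m1T", "m2T"], y=1, yb=0)
```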

3.2 Attack Algorithm

We propose a backdoor generation algorithm based on genetic algorithms (GAT). Initially, the SHAP model interpretability tool is employed to extract feature candidates via two strategies, identifying potential backdoor regions. Subsequently, feature values for
these candidate regions are initialized, and a genetic algorithm is utilized to optimize the
backdoor trigger, resulting in the final backdoor trigger. The overall process is illustrated
in Fig. 2, with specific details provided in Algorithm 1.

Fig. 2. Flowchart of the Overall Process of the GAT Algorithm

3.2.1 Feature Selection Module


SHAP provides the confidence level of each feature in relation to the model's predictions. We calculate the confidence matrix for each software sample in the attack set and derive the average confidence matrix by taking the mean.
higher confidence suggests a stronger contribution to positive class predictions, while
lower confidence indicates a greater contribution to negative class predictions. Two fea-
ture selection strategies are proposed to connect backdoor triggers with benign labels, as
illustrated in Fig. 3. Strategy One selects features with high average confidence that sig-
nificantly contribute to benign predictions, reflecting key characteristics of the predicted
label. Strategy Two focuses on features with average confidence close to zero, indicating
minimal influence on predictions and corresponding to background areas unrelated to
the label. After identifying target feature regions, a genetic algorithm is employed for
selection and optimization, aiming to embed backdoor features in these areas to bias the
model’s decisions toward benign outcomes.

Fig. 3. Flowchart of Backdoor Feature Region Selection Strategy
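A minimal sketch of the two selection strategies, assuming per-sample SHAP values are available as a plain list-of-lists (the numbers, and the sign convention that positive means "pushes toward benign", are our illustrative assumptions):

```python
def select_regions(shap_rows, k, strategy):
    """Pick k candidate feature indices for the backdoor trigger.

    shap_rows -- one row of SHAP values per attack-set sample.
    Strategy 1: features with the highest mean benign contribution.
    Strategy 2: features whose mean contribution is closest to zero.
    """
    n_feat = len(shap_rows[0])
    means = [sum(row[f] for row in shap_rows) / len(shap_rows)
             for f in range(n_feat)]
    if strategy == 1:
        order = sorted(range(n_feat), key=lambda f: means[f], reverse=True)
    else:
        order = sorted(range(n_feat), key=lambda f: abs(means[f]))
    return order[:k]

rows = [[0.8, 0.01, -0.5, 0.02],
        [0.6, -0.01, -0.4, 0.00]]
s1 = select_regions(rows, 1, strategy=1)  # strongly benign feature(s)
s2 = select_regions(rows, 1, strategy=2)  # near-zero "background" feature(s)
```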

3.2.2 Genetic Algorithm Optimization Module


After identifying target feature regions using the two strategies, we employ a genetic
algorithm to iteratively generate feature values, forming the backdoor trigger. The pro-
cess begins by initializing the backdoor trigger based on selected feature regions. For
Strategy One, initial feature values are set to the average of benign samples, while for
Strategy Two, values are drawn randomly from a uniform distribution within constraints.
Next, K benign feature matrices from the attack set are combined with the initial backdoor
features, creating poisoned samples labeled as benign, which are added to the training
set. After retraining the model, the fitness score of the poisoned model is calculated. The

algorithm undergoes N generations of selection, crossover, and mutation until the fitness
score converges, indicating no improvement over e iterations. The backdoor trigger with
the highest fitness score is then considered optimal.
Selection: The algorithm selects the two backdoor triggers with the highest fitness
scores for crossover and mutation. The fitness function T is defined as the weighted
average of the stealth score S and effectiveness score V (Eq. 5). The attacker seeks
to maintain the original classes of backdoor samples under the normal model while
changing them to the target class in the poisoned model, without compromising the
recognition of normal samples. Stealth Score: This is defined as the number of backdoor
samples Xb that retain their original labels in the normal model (Eq. 3). Effectiveness
Score: This consists of two components: the number of samples whose labels change to
the attacker’s target label under the poisoned model, and the count of normal samples X
that retain their original labels in the poisoned model (Eq. 4).

S = count(F(Xb ) = original label) (3)

V = count(Fb (Xb ) = target label) + count(Fb (X ) = original label) (4)

T = a × S + b × V (a + b = 1) (5)
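The fitness computation of Eqs. 3-5 reduces to counting. A sketch under the paper's default weights a = b = 0.5 (the stand-in classifiers and feature-set samples are ours):

```python
def fitness(F, Fb, X, Xb, y, yb, a=0.5, b=0.5):
    """T = a*S + b*V with a + b = 1 (Eqs. 3-5)."""
    S = sum(1 for x in Xb if F(x) == y)              # Eq. 3: label stealth
    V = (sum(1 for x in Xb if Fb(x) == yb)           # Eq. 4: backdoor fires ...
         + sum(1 for x in X if Fb(x) == y))          # ... and utility is kept
    return a * S + b * V                             # Eq. 5

# Samples are feature-index sets; index 99 is the (hypothetical) trigger.
F = lambda s: 1                                      # clean model
Fb = lambda s: 0 if 99 in s else 1                   # poisoned model
t = fitness(F, Fb, X=[{1}, {2}, {3}], Xb=[{1, 99}, {2, 99}], y=1, yb=0)
```

Here S = 2 and V = 2 + 3 = 5, so T = 3.5; the GA keeps the triggers that maximize this score.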

Crossover: The algorithm randomly selects two backdoor triggers with higher fitness
scores for feature exchange. The number of exchanged features is 20% of the trigger
length, rounded up. If features are identical, only their values are swapped. If different,
both keys and values are exchanged, maintaining the overall length of the backdoor
trigger.
Mutation: The algorithm introduces random perturbations to the backdoor triggers
derived from selection and crossover. To ensure compliance with feature constraints and
maintain the link between backdoor features and benign labels, perturbation magnitudes
are calculated based on the distribution of benign samples in the attack sample set for
each feature.
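Crossover and mutation can be sketched as follows, with triggers represented as {feature: value} dicts. The helper names, the benign-range table, and the seeded RNG are our illustrative choices, and the sketch assumes the swapped keys are not already present in the receiving trigger:

```python
import math
import random

def crossover(t1, t2, frac=0.2, rng=random):
    """Exchange ceil(frac * len) feature positions between two triggers.

    Identical keys effectively swap values only; differing keys swap both
    key and value, so each trigger keeps its overall length.
    """
    t1, t2 = dict(t1), dict(t2)
    n = math.ceil(frac * len(t1))
    for a, b in zip(rng.sample(sorted(t1), n), rng.sample(sorted(t2), n)):
        va, vb = t1.pop(a), t2.pop(b)
        t1[b], t2[a] = vb, va
    return t1, t2

def mutate(trigger, benign_range, p=0.2, rng=random):
    """Perturb features with probability p, bounded by the benign samples'
    observed (min, max) range so the trigger stays tied to benign labels."""
    out = dict(trigger)
    for f in trigger:
        if f in benign_range and rng.random() < p:
            lo, hi = benign_range[f]
            out[f] = rng.uniform(lo, hi)
    return out

rng = random.Random(0)
c1, c2 = crossover({0: 1.0, 1: 0.0, 2: 1.0}, {3: 1.0, 4: 0.0, 5: 1.0}, rng=rng)
child = mutate(c1, {f: (0.0, 1.0) for f in c1}, rng=rng)
```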

3.3 Android Backdoor Insertion Program


To address real-world data outsourcing, we present an Android backdoor insertion program that translates trigger features into actual code, which is then inserted into the original smali files. After repackaging, it can poison malware detection platforms, resulting in a poisoned model trained
on the poisoned data collected by users. The process begins with reverse engineer-
ing the packaged Android application using Apktool, generating source files such as
class.dex and AndroidManifest.xml. The class.dex file contains the application’s source
code, while AndroidManifest.xml holds critical app information, including permissions.
Modifications to these files are classified into two types based on the training feature
sets of malware classifiers: changes to the AndroidManifest.xml and alterations to the
smali code from the decompiled dex files.

3.3.1 Modifying the AndroidManifest.xml File


Modifying the XML file is simpler than altering the smali source code files. As illustrated
in Fig. 4, declaring the permission “android.permission.TRANSMIT_IR” changes the
corresponding feature value from 0 (unused) to 1 (used). Conversely, to revert a feature
value from 1 (used) to 0 (unused), a regular expression match can be performed to
identify and remove the relevant information.

Fig. 4. Example of Modifying the XML File
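A sketch of the XML-side feature flip using only the standard library; the regex and helper name are ours, and a production tool would operate on the decompiled manifest produced by Apktool rather than on this toy text:

```python
import re

def set_permission(manifest_text, perm, used):
    """Flip a permission feature in decompiled AndroidManifest.xml text.

    used=True  -> declare the permission (feature value 0 -> 1)
    used=False -> strip any matching declaration via regex (1 -> 0)
    """
    pattern = re.compile(
        r'\s*<uses-permission[^>]*android:name="%s"[^>]*/>' % re.escape(perm))
    manifest_text = pattern.sub("", manifest_text)
    if used:
        decl = '\n    <uses-permission android:name="%s"/>' % perm
        manifest_text = manifest_text.replace("</manifest>", decl + "\n</manifest>")
    return manifest_text

PERM = "android.permission.TRANSMIT_IR"           # the example from Fig. 4
base = "<manifest>\n</manifest>"
added = set_permission(base, PERM, used=True)     # feature 0 -> 1
removed = set_permission(added, PERM, used=False) # feature 1 -> 0
```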

3.3.2 Modifying the class.dex File


Using Apktool, dex files can be converted into multiple smali files, each typically rep-
resenting a single class. Current malware classifiers extract two feature types: function
call features and function presence features. For instance, Mamadroid employs function
transition probabilities derived from function call features, while Drebin uses one-hot
encoding based on the Android API to create feature vectors from function presence
features.
We present modification algorithms for these extraction methods: for API addition,
we define Inc(x, t, d) as adding a method d in smali file x that calls API t. To enhance
stealth and evade static analyzers, two techniques are proposed. The first involves
using try-catch statements for code obfuscation, ensuring that the relevant code branch
does not execute during dynamic execution. As shown in Fig. 5, the callerEx function
attempts to invoke the callee function within a try block, but an array out-of-bounds error
occurs beforehand, preventing dynamic execution while still being detectable by static
analyzers. This method increases both function call and function presence features.

Fig. 5. Code Obfuscation Using try-catch Statements to Add Invalid Function Calls

The second method involves defining empty functions. As shown in Fig. 6, an empty
function funA is defined under Lcom/google/Myclass, while function funB is defined

under Lcom/malware/Myclass to call funA. By invoking funB through the startup func-
tion, this method not only adds a function call feature from the malware package to the
Google package but also enhances the function presence feature.

Fig. 6. Defining Empty Functions and Invoking Them to Add Invalid Function Calls

For API reduction, we define it as Dec(x, t, d), indicating that function d calls function
t within Smali file x. We utilize regular expressions to locate the call within the Smali file,
and then employ reflection techniques for code obfuscation to achieve the API reduction.
As illustrated in Fig. 7, function funA under Lcom/google/Myclass calls function funB
under Lcom/malware/Myclass via reflection, making it undetectable by static analyzers
that funA has called funB.

Fig. 7. Achieving API Reduction through Reflection Calls
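In the same spirit, Inc(x, t, d) and Dec(x, t, d) can be sketched as plain text transformations on decompiled smali. The templates, regexes, and method names below are our illustrative assumptions; the paper's actual Dec re-creates the call through reflection rather than merely removing it:

```python
import re

METHOD_TEMPLATE = """\
.method private {d}()V
    .locals 1
    # guarded so it is visible to static analysis but never runs (cf. Fig. 5)
    invoke-static {{}}, {t}
    return-void
.end method
"""

def inc(smali, t, d):
    """Inc(x, t, d): append a method d whose body (statically) calls API t."""
    return smali + METHOD_TEMPLATE.format(d=d, t=t)

def dec(smali, t):
    """Dec(x, t, d): remove the direct invocation of t (stand-in for hiding
    it behind a reflection call, as in Fig. 7)."""
    pat = r"^\s*invoke-\w+(?:/range)? \{[^}]*\}, %s\s*$" % re.escape(t)
    return re.sub(pat, "    # call rewritten to use reflection", smali, flags=re.M)

code = (".method public run()V\n"
        "    invoke-virtual {p0}, Lcom/malware/Myclass;->funB()V\n"
        "    return-void\n"
        ".end method\n")
grown = inc(code, "Lcom/google/Myclass;->funA()V", "callerEx")
pruned = dec(grown, "Lcom/malware/Myclass;->funB()V")
```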

4 Experiments
4.1 Dataset
The dataset used in this paper includes malicious samples from Drebin [14] and benign
and malicious applications collected from Android Zoo. A total of 27,445 benign appli-
cations and 23,839 malicious applications were collected. We randomly selected 5,489

benign applications and 4,768 malicious applications as the test dataset, while the
remaining applications were used for training.

4.2 Target Attack Systems


We evaluate two typical Android malware detectors, namely Drebin [14] and Mamadroid
[15]. We provide detailed information about the target malware detectors.
Drebin: Drebin extracts features from the AndroidManifest.xml and class.dex files
using static analysis tools. These features are categorized into eight groups, extracted by category, and encoded using one-hot encoding, with values indicating the presence of specific permissions, APIs, and components.
Mamadroid: Mamadroid aims to capture the behavior of Android applications by
collecting information about API call families or packages. It computes the Markov
transition probabilities between families and packages as training features.

4.3 Experimental Setup


Table 2 outlines the settings for the target classifiers in this study. The SVM and Ran-
dom Forest classifiers were trained based on parameters from the original research. To
evaluate the effectiveness of the proposed attack on neural networks, we constructed
a neural network NN with four fully connected layers: the first three utilized ReLU
activation functions, and the final layer employed a Sigmoid activation function. Batch
Normalization was applied to the first three layers, with a 50% dropout rate to miti-
gate overfitting. The Drebin classifier was trained using all original features, while the
Mamadroid classifier focused on Family and Package features. Both classifiers achieved
detection accuracies and F1 scores above 92%, with the NN slightly outperforming
SVM and RF. As described in Algorithm 1, hyperparameters for the backdoor trigger
generation include the sample size K, maximum iterations e, the weights of the fitness
function a and b, and crossover and mutation probabilities Cp and Mp, set to K = 20, e = 20, a = 0.5, b = 0.5, Cp = 0.2, and Mp = 0.2. Subsequent experiments utilized these
hyperparameters, evaluating the effectiveness and stealthiness of the backdoor trigger.

4.3.1 Effectiveness Metrics for Backdoor Triggers


ASR (Attack Success Rate): The proportion of malicious samples with backdoors mis-
classified as benign by the poisoned model, indicating the effectiveness of the backdoor
trigger; attackers aim for a high ASR.
Acc(Fb , X): The accuracy of the poisoned model on a clean test set, reflecting the
impact of poisoned data on original model performance; attackers prefer this value to
be high.
BST (Backdoor Selection Time): The time taken by the backdoor generation
algorithm to produce the optimal backdoor.

4.3.2 Stealthiness Metrics for Backdoor Triggers


FPb : The false positive rate of the poisoned model, which impacts security costs; attackers
prefer this rate to be low to enhance stealthiness.

Table 2. Relevant Settings for the Target Classifier

Target Attack System   Attack Model   Feature Set   F1 Score   FN Rate   FP Rate   Accuracy
Drebin NN ALL 98.48% 2.50% 0.50% 98.50%
SVM 97.98% 3% 1% 98%
RF 98.25% 2% 1.50% 98.25%
Mamadroid NN Family 92.60% 12.50% 1.50% 93%
SVM 92.06% 13% 2% 92.50%
RF 92.40% 12% 2.50% 92.75%
NN Package 96.60% 10.25% 1.50% 96.30%
SVM 95.06% 11% 2.4% 95.25%
RF 95.40% 13% 2.25% 95.74%

Acc(F, Xb ): The label flip rate, indicating the accuracy of the original model on mali-
cious samples with backdoors. A lower value signifies minimal disruption to detection
outcomes, which is desirable for maintaining stealth.
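The four metrics reduce to simple counts over predictions (BST is just the wall-clock time of the trigger search and is omitted). A stdlib sketch with labels 0 = benign and 1 = malicious; the toy models are ours:

```python
def metrics(F, Fb, clean, backdoored):
    """Compute ASR, Acc(Fb, X), FPb, and Acc(F, Xb) from predictions.

    clean      -- [(sample, true_label)] held-out test pairs
    backdoored -- [(sample, original_label)] trigger-carrying malware
    """
    asr = sum(Fb(x) == 0 for x, _ in backdoored) / len(backdoored)
    acc_fb = sum(Fb(x) == y for x, y in clean) / len(clean)
    benign = [x for x, y in clean if y == 0]
    fp_b = sum(Fb(x) == 1 for x in benign) / len(benign)
    acc_f_xb = sum(F(x) == y for x, y in backdoored) / len(backdoored)
    return {"ASR": asr, "Acc(Fb,X)": acc_fb, "FPb": fp_b, "Acc(F,Xb)": acc_f_xb}

F = lambda s: 1 if "mal" in s else 0             # clean detector
Fb = lambda s: 0 if s.endswith("T") else F(s)    # backdoored detector
m = metrics(F, Fb,
            clean=[("mal1", 1), ("app1", 0), ("app2", 0)],
            backdoored=[("mal2T", 1), ("mal3T", 1)])
```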

4.4 Experimental Results

4.4.1 Evaluation of the Effectiveness of the GAT Algorithm


4.4.1.1 Evaluation of the Effectiveness of Attacking Machine Learning Models
To validate the effectiveness and stealth of the proposed GAT (GA-based Trigger) algorithm, we set the trigger size to 30 and evaluated its performance across various poisoning rates in three attack scenarios. We compared GAT against three baseline
strategies from EXP [3]: GS (Greedy Selection), LC (LargeAbsSHAP × CountAb-
sSHAP), and LM (LargeAbsSHAP × MiniPopulation). In model outsourcing scenarios
with white-box access, both the substitute and target models were SVM. In black-box
scenarios, the substitute model was SVM and the target model was RF; the results are presented in Table 3. For the Drebin target, results indicate that in model outsourcing,
GAT consistently outperformed the baseline strategies, achieving an ASR above 0.65 at
poisoning rates of 0.5 to 0.8. In the ideal data outsourcing scenario, with poisoning rates
of 0.005 to 0.1, GAT demonstrated a superior ASR compared to EXP, with a minimal
difference of 0.007 at 0.05. In the black-box data outsourcing scenario, ASR decreased
with reduced attacker knowledge but remained comparable to the EXP strategies.
The results targeting Mamadroid with two feature granularities are presented in
Table 3. When using Family features, the ASR for the GAT algorithm and three baseline
strategies from EXP are significantly higher compared to those targeting Drebin. This is
due to the coarser granularity of Family features, which are more homogeneous, allowing
slight perturbations to significantly affect classification results. In all scenarios, the GAT
algorithm consistently outperformed the EXP strategies, approaching 100% ASR at a
model poisoning ratio of 0.8 and an ideal data poisoning ratio of 0.2. In contrast, with

Package feature granularity, the feature dimensionality increased to 197,136, leading to high time complexity for SHAP, which failed to operate normally. To mitigate this, we
employed an SVM model to reduce dimensionality to the top 500 important features,
alleviating SHAP’s computational burden while retaining key predictors. Results in
Table 3 show that, like Family features, all attack methods under Package granularity
yielded higher ASRs than those targeting Drebin. However, the GAT algorithm’s ASR
was generally lower than with Family features due to the greater feature variety in
Package granularity, resulting in a more robust model. Notably, ASRs in both ideal and
actual data outsourcing scenarios were similar, likely due to the dimensionality reduction
process, which may have omitted some significant features. Thus, the similarity between
the substitute model and the target model has a minimal impact on the final ASR, resulting
in comparable ASRs across both attack scenarios.
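The SVM-based reduction amounts to ranking features by the magnitude of the linear model's coefficients and keeping the top 500. A toy sketch (the weight vector below is illustrative, standing in for the trained linear SVM's 197,136-dimensional coefficient vector):

```python
def top_k_by_weight(weights, k):
    """Keep the k features with the largest absolute linear-SVM weight."""
    ranked = sorted(range(len(weights)), key=lambda i: abs(weights[i]),
                    reverse=True)
    return sorted(ranked[:k])  # return indices in ascending order

w = [0.01, -0.9, 0.3, 0.0, 0.5]
keep = top_k_by_weight(w, k=3)  # indices retained for SHAP computation
```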

Table 3. Comparison of ASR Across Three Attack Scenarios with Trigger Size of 30 and Attack Targets Drebin, Mamadroid Family, and Mamadroid Package Granularity.

Attack Scenario Model Outsourcing Ideal Data Outsourcing Real-World Data Outsourcing
Poisoned percentage 0.5 0.6 0.7 0.8 0.005 0.01 0.05 0.1 0.005 0.01 0.05 0.1
GS-Drebin 0.624 0.633 0.685 0.746 0.484 0.484 0.540 0.572 0.141 0.163 0.427 0.570
LC-Drebin 0.683 0.698 0.719 0.760 0.365 0.424 0.559 0.603 0.147 0.180 0.348 0.415
LM-Drebin 0.540 0.546 0.566 0.626 0.244 0.313 0.413 0.451 0.156 0.273 0.466 0.554
GAT-Drebin 0.697 0.708 0.747 0.797 0.526 0.526 0.552 0.622 0.211 0.232 0.462 0.592
GS-Mamadroid-F 0.883 0.898 0.919 0.96 0.865 0.875 0.889 0.904 0.747 0.78 0.849 0.865
LC-Mamadroid-F 0.834 0.853 0.905 0.936 0.704 0.744 0.81 0.872 0.644 0.664 0.678 0.71
LM-Mamadroid-F 0.84 0.846 0.906 0.926 0.744 0.763 0.813 0.891 0.656 0.673 0.746 0.794
GAT-Mamadroid-F 0.897 0.909 0.948 0.998 0.888 0.907 0.932 0.983 0.813 0.831 0.863 0.893
GS-Mamadroid-P 0.856 0.866 0.879 0.894 0.755 0.783 0.842 0.864 0.753 0.782 0.844 0.858
LC-Mamadroid-P 0.694 0.734 0.800 0.862 0.646 0.667 0.687 0.707 0.639 0.663 0.681 0.701
LM-Mamadroid-P 0.734 0.753 0.803 0.871 0.655 0.673 0.742 0.797 0.646 0.673 0.743 0.795
GAT-Mamadroid-P 0.878 0.897 0.922 0.973 0.814 0.833 0.864 0.893 0.809 0.815 0.855 0.876

4.4.1.2 Evaluation of the Effectiveness of Attacking Neural Networks


To comprehensively evaluate the attack algorithms proposed in this paper, we demon-
strate the GAT algorithm’s effectiveness against neural networks. We assess its per-
formance on neural network classifiers targeting Drebin and Mamadroid within a data
outsourcing scenario. The target classifiers include Drebin-NN, Mamadroid-F-NN, and
Mamadroid-P-NN, with SVM as the substitute model and poisoning ratios set at 0.005,
0.01, 0.05, and 0.1. Experimental results, as shown in Fig. 8, indicate that when attack-
ing the Mamadroid system with a trigger size of 5 and a poisoning ratio of 0.005, the
ASR exceeds 70% for both Family and Package feature granularities. This high ASR
is attributed to Mamadroid’s homogeneous training features and the neural network’s
strong fitting capability, allowing effective exploitation of a limited number of backdoor
samples. In contrast, the overall ASR for attacking Drebin is lower due to its greater fea-
ture diversity. Specifically, with a trigger size of 50 and a poisoning ratio of 0.1, the attack

success rates for Drebin and Mamadroid are 70% and 99%, respectively, highlighting
the GAT algorithm’s efficacy against neural networks.

Fig. 8. Attack Success Rates of the GAT Algorithm on Three Classifiers Under Data Outsourcing
Scenarios with Varying Trigger Sizes

4.4.2 Evaluation of the Stealth of GAT Algorithm


We evaluate the GAT algorithm's stealth in a data outsourcing attack with a poisoning ratio of 0.01, maintaining prior experimental conditions. Results show that for the Drebin classification system, attack success rates increase with trigger size, with the LM strategy surpassing GAT at a trigger size of 30. However, the label inversion rate for GAT is 23.32%, significantly lower than the 94% observed for the LM strategy in the EXP algorithm. The label inversion rates for trigger sizes of 16, 30, 64, and 128 remain
lower than all EXP strategies. Notably, poisoned models generated by GAT maintain
high accuracy (96.3% to 94.27%) and low false positive rates (2.27% to 4.73%), out-
performing the EXP algorithm. For the Mamadroid classification system, using Family
feature granularity shows higher FPb and Acc(F, Xb ) under the same trigger size, with
lower Acc(Fb, X) due to feature sensitivity. Overall, the GAT algorithm exhibits superior covert characteristics compared to existing strategies, as illustrated in Table 4, which details the performance metrics for Drebin and Mamadroid.

4.4.3 Effectiveness Verification of GAT Algorithm Improvement


4.4.3.1 Comparison of SHAP and LIME
The proposed GAT algorithm leverages machine learning interpretability methods to

Table 4. Comparison of ASR, FPb, Acc(Fb, X), and Acc(F, Xb) between the GAT Algorithm and the EXP Algorithm under the Drebin and Mamadroid Classification Systems

Attack method Trigger Size ASR FPb Acc(Fb , X) Acc(F, Xb )


GS-Drebin 16 4.86% 5.27% 93.50% 85.71%
30 16.40% 5.3% 92.25% 91.43%
64 27.06% 5.1% 93.48% 98.57%
128 48.66% 6.73% 91.67% 98.00%
LC-Drebin 16 11.41% 5.23% 93.40% 43.43%
30 17.96% 5.37% 93.25% 47.31%
64 38.02% 5.14% 91.48% 36.28%
128 59.39% 5.4% 91.43% 43.15%
LM-Drebin 16 22.36% 5.3% 93.30% 98.21%
30 27.28% 5.2% 92.15% 94.31%
64 50.38% 5.47% 93.38% 95.23%
128 60.19% 6.13% 91.27% 99.93%
GAT-Drebin 16 24.34% 2.27% 96.30% 30.45%
30 23.20% 2.3% 95.15% 23.32%
64 58.31% 3.1% 94.38% 21.87%
128 80.97% 4.73% 94.27% 27.43%
GS-Mamadroid-F 16 54.89% 7.35% 88.40% 86.58%
30 78% 7.42% 87.15% 92.12%
64 79.06% 7.18% 88.38% 99.01%
128 83.66% 8.83% 86.57% 98.45%
LC-Mamadroid-F 16 51.31% 7.31% 91.30% 44.62%
30 66.4% 7.46% 90.15% 48.79%
64 68.12% 7.23% 88.48% 37.23%
128 79.39% 7.5% 88.43% 44.66%
LM-Mamadroid-F 16 52.36% 7.39% 88.20% 98.37%
30 67.3% 7.28% 87.05% 95.52%
64 70.38% 7.56% 88.28% 96.33%
128 80.19% 8.21% 86.17% 99.87%
GAT-Mamadroid-F 16 78.34% 4.35% 94.30% 31.27%
30 87.1% 4.4% 93.15% 24.18%
64 96.31% 5.18% 91.38% 22.06%
128 97.97% 6.8% 90.27% 28.14%

optimize feature selection for backdoor attacks. While this study employs SHAP, we explore LIME as an alternative to assess differences in attack speed and effectiveness, as
illustrated in Table 5. LIME, which uses Gaussian perturbations and a local linear model
to identify significant features, demonstrates a faster BST compared to SHAP. However,
SHAP achieves higher ASR due to its capability to capture feature interactions based
on Shapley values. When applying SHAP to the Mamadroid Package feature set, we
first reduce dimensionality from 197,136 to 500 using SVM, enabling efficient feature
extraction within a reasonable timeframe. This approach increases ASR by 16.74%
compared to LIME, while keeping the time overhead for SHAP within 20 s of LIME.
Overall, despite its longer processing time, SHAP’s accuracy in feature importance
calculation makes it preferable for initial backdoor generation.

Table 5. Comparison of ASR and BST of Different Model Interpretation Tools in Data
Outsourcing Scenarios

Initial method   Target Attack System   Attack classifier   Trigger Size   Poisoned percentage   ASR   BST
LIME Drebin NN 30 0.50% 19.218% 30.338 s
Mamadroid-F NN 30 0.50% 78.315% 26.459 s
Mamadroid-P NN 30 0.50% 63.735% 97.337 s
SHAP Drebin NN 30 0.50% 23.439% 40.727 s
Mamadroid-F NN 30 0.50% 81.527% 38.538 s
Mamadroid-P NN 30 0.50% 80.475% 112.499 s

4.4.3.2 Comparison of Initial Backdoor Generation Strategies


We verify the GAT algorithm’s effectiveness through ablation experiments with six
groups. Using default hyperparameters, the attack scenario involves data outsourcing
targeting the Mamadroid Family, with a trigger size of 30 and a poisoning rate of 0.5%.
The first two groups compare backdoors generated by SHAP alone against those using
both SHAP and a genetic algorithm, measuring ASR, FPb , and accuracy metrics. Results
indicate that optimized backdoors are more effective and stealthy, as illustrated in Fig. 9.
The last four groups evaluate initialization strategies: S1, S2, random, and combined
(S1+S2), as illustrated in Fig. 10. The S1+S2 method outperforms the others, achieving
the highest ASR and lowest BST. Random initialization yields the lowest ASR, as it
often gets trapped in local optima. The S1+S2 strategy is superior by 2% in ASR and
reduces BST by 9 s, likely due to the diverse features from various APKs, which favor
optimal backdoor searches.

Fig. 9. Validation of the Effectiveness of Genetic Algorithm Optimization for Initial Backdoor
Generation

Fig. 10. Validation of the Effectiveness of Initial Backdoor Generation Strategies

4.4.4 Evaluation of the Effectiveness of the Backdoor Insertion Algorithm


We propose a backdoor feature insertion method for XML and DEX features. For XML
features, static analysis is sufficient to verify whether the perturbations generated by the
algorithm have been successfully inserted, as XML files do not affect the operation of
Android applications. For DEX features, it is necessary to not only confirm the successful
insertion of perturbations into DEX files but also to employ dynamic analysis to check
if these perturbations alter the functionality of the malware.

4.4.4.1 Whether Backdoor Features Were Pruned by Static Analyzers


We first verify whether XML and DEX backdoor features have been properly inserted. To
validate the insertion of XML backdoor features, the original and backdoor-embedded
XML feature matrices are extracted via static analysis, and the differences between the
matrices are compared to determine if the backdoor features were inserted correctly. For
DEX backdoor features, the original and backdoor samples’ function call matrices at
the Family and Package levels are analyzed similarly. The evaluation of 50 APK files
revealed that all generated perturbations were accurately injected into both XML and
smali files, indicating that the proposed backdoor insertion algorithm was not pruned by
static analyzers, and the backdoor features were successfully embedded in the malware.
Backdoor Attack on Android Malware Classifiers Based on Genetic Algorithms 213

4.4.4.2 Whether Software Functionality Remains Intact


Dynamic analysis was conducted to assess whether the inserted DEX backdoor features affected malware functionality. Original and backdoor-embedded samples were installed and executed on an Android Virtual Device (AVD), revealing that each original/backdoor pair performed identical operations and exhibited the same runtime UIs. To analyze the perturbations introduced by try-catch methods, empty functions, and reflection methods, we inserted log statements (LOG1, LOG2, LOG3) for try-catch blocks, LOG4 for empty functions, and LOG5 for reflection methods. Log analysis using Android Studio's LogCat on 50 APK files showed that LOG1, LOG3, LOG4, and LOG5 executed correctly, while LOG2 did not. This indicates that the empty functions and reflection calls operated normally and that the catch branches of the inserted try-catch blocks were never triggered, confirming that the malware's functionality remained unaffected. All backdoor-injected malware operated as intended.

5 Conclusion

We propose a genetic algorithm-based method for generating backdoors in Android applications and experimentally verify its effectiveness and stealth. By utilizing the
SHAP model for feature selection and designing an optimized fitness function, the gen-
erated backdoor triggers demonstrate strong efficacy across three attack scenarios. The
method involves modifying the AndroidManifest.xml and smali code from decompiled
class.dex files, ensuring seamless integration of backdoor features without disrupting
functionality. Experimental results show that the proposed attack algorithm achieves
an attack success rate exceeding 70% with only 5% poisoned samples and minimal
trigger features, while keeping the false positive rate below 10% and the label flipping
rate below 30%. This backdoor insertion method ensures normal software operation
and addresses pruning and redundancy issues, contributing valuable insights to Android
malware classifier research.

Acknowledgments. This work is supported by 3AD303F5 and NSFC (Grant No. 62472047).

References
1. Narisada, S., Sasaki, S., Hidano, S., et al.: Stronger targeted poisoning attacks against malware
detection. In: Krenn, S., Shulman, H., Vaudenay, S. (eds.) Cryptology and Network Security.
CANS 2020. Lecture Notes in Computer Science, vol. 12579, pp. 65–84. Springer, Cham
(2020). https://doi.org/10.1007/978-3-030-65411-5_4
2. Li, C., Chen, X., Wang, D., et al.: Backdoor attack on machine learning based android malware
detectors. IEEE Trans. Depend. Secure Comput. 19(5), 3357–3370 (2021)
3. Severi, G., Meyer, J., Coull, S., et al.: Explanation-Guided backdoor poisoning attacks against
malware classifiers. In: 30th USENIX Security Symposium (USENIX Security 21), pp. 1487–
1504 (2021)
4. Yang, L., Chen, Z., Cortellazzi, J., et al.: Jigsaw puzzle: selective backdoor attack to subvert
malware classifiers. In: 2023 IEEE Symposium on Security and Privacy (SP). IEEE, pp. 719–
736 (2023)

5. Gu, T., Liu, K., Dolan-Gavitt, B., et al.: Badnets: Evaluating backdooring attacks on deep
neural networks. IEEE Access 7, 47230–47244 (2019)
6. Liu, Y., Ma, S., Aafer, Y., et al.: Trojaning attack on neural networks. In: 25th Annual Network
and Distributed System Security Symposium (NDSS 2018). Internet Soc (2018)
7. Yao, Y., Li, H., Zheng, H., et al.: Latent backdoor attacks on deep neural networks. In: Pro-
ceedings of the 2019 ACM SIGSAC Conference on Computer and Communications Security,
pp. 2041-2055 (2019)
8. Wenger, E., Passananti, J., Bhagoji, A.N., et al.: Backdoor attacks against deep learning
systems in the physical world. In: Proceedings of the IEEE/CVF Conference on Computer
Vision and Pattern Recognition, pp. 6206–6215 (2021)
9. Shafahi, A., et al.: Poison frogs! targeted clean-label poisoning attacks on neural networks.
CoRR (2018)
10. Barni, M., Kallas, K., Tondi, B.: A new backdoor attack in CNNS by training set corruption
without label poisoning. In: Proceedings of ICIP (2019)
11. Shapira, T., Berend, D., Rosenberg, I., Liu, Y., Shabtai, A., Elovici, Y.: Being single has
benefits: instance poisoning to deceive malware classifiers. CoRR (2020)
12. Saha, A., Subramanya, A., Pirsiavash, H.: Hidden trigger backdoor attacks. In: Proceedings
of AAAI (2020)
13. Turner, A., Tsipras, D., Madry, A.: Label-consistent backdoor attacks. CoRR (2019)
14. Arp, D., Spreitzenbarth, M., Hubner, M., et al.: DREBIN: effective and explainable detection
of android malware in your pocket. NDSS 14, 23–26 (2014)
15. Mariconti, E., Onwuzurike, L., Andriotis, P., et al.: Mamadroid: detecting android malware
by building markov chains of behavioral models. arXiv preprint arXiv:1612.04433 (2016)
A Malicious Websites Classifier Based
on an Improved Relation Network

Qianshi Wang1 , Chongjun Xu2(B) , Huayu Yang2 , Xilin Zhai1 , and Hua Zhang1
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing 100876, China
{scrassy_1047,zhanghua_288}@bupt.edu.cn
2 Zhejiang Branch of National Computer Network Emergency Response Technical
Team/Coordination Center of China, Hangzhou, China
[email protected]

Abstract. Intelligent in-car operating systems integrate various web platforms, including browsers, and therefore face serious issues with malicious web content. Current detection models perform poorly in the Internet of Vehicles (IoV) environment due to simple model structures and a lack of data samples, which make it difficult for models to learn and extract multi-level and fine-grained sample features. In this paper, we propose a novel few-shot model, CarMaNet, to detect and classify malicious websites in the IoV environment based on website images. We import an inception module into the Relation Network, which enhances the ability of CarMaNet to learn multi-level features of images. At the same time, we design a feature pipeline featuring adaptive recalibration between the embedding module and the relation module based on an adaptive attention mechanism (FCAA), enhancing the model's ability to learn fine-grained features. In few-shot experiments on three datasets of pornography, gambling and counterfeit websites, CarMaNet outperforms current models.

Keywords: Internet of Vehicles · malicious website · few-shot · attention mechanism · relation networks

1 Introduction
As new energy vehicles continue to integrate various features, intelligent in-car systems are gradually becoming the core of vehicle control, information interaction and in-car entertainment. This brings serious cybersecurity challenges in the IoV environment. Web platforms integrated into in-car systems face the threat of malicious websites. Among these, pornography and gambling websites make up a significant portion, and counterfeit websites usually use fake pages to deceive users, potentially causing substantial financial losses. Therefore, identifying and classifying malicious websites in in-car systems is of great importance to protect the security of the IoV.
Traditional models for detecting malicious websites are mostly trained on datasets of
features such as website traffic or content. Zhao et al. [1] classified multi-level features of

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 215–230, 2025.
https://doi.org/10.1007/978-981-96-4566-4_15

malicious websites based on URL and content features and designed a classification model based on Bert-ResNet, achieving higher accuracy in detecting malicious websites. Afzal et al. [2] proposed URLdeepDetect, a deep learning model based on semantic embeddings, which recognizes malicious websites by analyzing the semantic information in their URLs and performs well in detection.
The methods above need a large number of website samples to train models; however, it is hard to collect a considerable number of malicious website images in the IoV environment when researching how to detect malicious websites from website images. Models may overfit without enough samples, as they are unable to learn a stable and accurate decision boundary in few-shot scenarios. To address the problem of insufficient image samples, existing research has proposed several approaches. Koch et al. [3] proposed the Siamese network based on metric learning, which employs weight sharing to compute the similarity of images and then assigns image categories according to the similarity between samples. Flood Sung et al. [4] developed the Relation Network, with an embedding module and a relation module, to classify samples by calculating the similarity between images; their approach demonstrates good performance in few-shot experiments. Although the methods above show great performance in few-shot image classification tasks, their models are still single-network structures, which may struggle to extract multi-level and fine-grained features from complex images, such as malicious website images, resulting in poor performance and even misjudgment.
To address the issues of insufficient samples and incomplete feature extraction, we propose CarMaNet, a few-shot classification model for malicious websites based on an improved Relation Network. Building upon the ability of few-shot models to complete classification tasks with few-shot training, we further import the inception block [5] into the traditional Relation Network and achieve malicious website classification based on an attention mechanism in few-shot scenarios. The specific contributions are as follows.
(1) We propose a few-shot model called CarMaNet to classify malicious websites based on an improved Relation Network. We introduce an attention mechanism into CarMaNet, insert a parallel SE module into each residual block and combine the outputs of the convolutional layer with those of the SE module. We call this process FCAA; it enables CarMaNet to better learn multi-level features of malicious websites in few-shot scenarios.
(2) We design and implement a multi-scale feature embedding module containing an Inception structure. This makes it easier for the model to learn fine-grained features of malicious website images, improving its ability to capture features in complex image samples.
(3) Experimental results show that CarMaNet achieves better feature extraction and classification performance on few-shot datasets than traditional models, significantly improving the accuracy of malicious website classification.

2 Related Work

Classifying malicious websites such as pornography, gambling and counterfeit websites is of great importance for enhancing the security and stability of intelligent in-car systems, which is crucial to IoV security. In terms of classification tools for malicious websites, VirusTotal [6], a security detection tool under Chronicle, can detect and classify malicious websites at a fine-grained level by merging over 70 third-party engines, categorizing websites into types such as pornography, gambling, and counterfeit websites. Tencent URL Safe [7] uses Tencent's cloud-based intelligent filtering engine and a malicious URL database to classify malicious websites into categories such as phishing, spam, pornography, and gambling. The Norton Safe Web service [8] classifies various types of malicious websites through methods such as signature matching, website content scanning, user and community ratings, and blacklists/whitelists.
In terms of research on malicious website classification, Chen et al. [9] utilized the Doc2Vec algorithm and an improved Bag of Visual Words (BoVW) model to learn features from HTML text and website images respectively, training text and image classifiers. They used a data fusion algorithm based on logistic regression to optimize the decision-making process, integrating the results of the text and image classifiers into a final prediction. Wang et al. [10] proposed a multi-mode model based on data fusion that recognizes gambling websites through website images and the semantic information in those images; they first used a residual network to extract image features and then extracted text features with a bidirectional Long Short-Term Memory (LSTM) network. Finally, a multi-mode data fusion method was applied for fine-grained classification of gambling websites, achieving high accuracy. Siddiq et al. [11] utilized a visual learning method based on deep learning to detect phishing websites. Zhang et al. [12] proposed DRSDetector, which combines multi-level features to detect gambling websites. DRSDetector is based on two stacked Transformer Encoder structures; it uses the LightGBM algorithm to learn the resource features of websites and the HAN model to learn their semantic features. Naru et al. [13] introduced a machine learning method combining logistic regression and multinomial naive Bayes to classify phishing websites, achieving an accuracy of 97% on their dataset.
The methods above take advantage of machine learning and deep learning, achieving great performance based on large amounts of labeled data. However, in the IoV environment, datasets of malicious websites are difficult to collect due to the privacy policies of vehicle manufacturers. Also, existing models mostly focus on only one or two types of malicious websites and hence perform poorly in multi-type classification tasks.

3 Our Model
In this paper, we propose CarMaNet, a classification model for malicious websites based on an improved Relation Network. To better capture the multi-level features of malicious websites and enhance the model's ability to learn and represent fine-grained image features, we add an attention mechanism module to the feature embedding module of the traditional Relation Network [14], enabling CarMaNet to relate different images

according to image similarity. The overall structure of CarMaNet mainly includes three
modules: embedding module, attention mechanism module and relation module. The
structure of CarMaNet is shown in Fig. 1.

Fig. 1. The Structure of CarMaNet

CarMaNet first utilizes the multi-scale embedding module to extract image features from the support set and the query set. In the embedding module, part of the convolutional blocks are replaced by Inception blocks, enabling the module to capture image features at different scales and achieve deeper and broader abstraction of these features. Image features are finally converted to embeddings via convolutional blocks and a pooling layer.
At the same time, we insert an improved SENet [15] attention mechanism module between the embedding module and the relation module to fully utilize the relations between modules and enhance the representation ability of CarMaNet. This attention mechanism module adds a parallel SE block to each residual block and combines the outputs of the conventional convolutional path and the SE block, providing a more accurate and diverse representation for subsequent metric computations.
The relation module contains 3 convolutional blocks and 2 fully connected layers; we replace the original activation function and loss function to accommodate the embedding module and the SENet attention mechanism module, further optimizing the training process of CarMaNet. Finally, the relation score is output using a scaling function, and the category with the highest score is chosen as the final prediction result.

3.1 Multi-scale Embedding Module

In CarMaNet, we use an Inception block to enhance the embedding module. The Inception block can capture multi-level image features, focusing not only on the overall structure of the input sample but also on fine-grained micro features. This characteristic gives the Inception block significant advantages when processing malicious website

images containing multi-level information. The structure of the embedding module is shown in Fig. 2.

Fig. 2. The Structure of the Embedding Module

The multi-scale embedding module is an improvement based upon the traditional


embedding module of the Relation Network. This module consists of three branches. The first branch retains the structure of the traditional Relation Network: it includes two 3 × 3 convolutional layers, several normalization layers, a Mish activation layer, and a max-pooling layer, ensuring that the network can deeply extract various features; its third and fourth convolutional blocks each consist of one 3 × 3 convolutional layer, several normalization layers, and max-pooling layers. In the second branch, we replace the original fourth convolutional block with an Inception structure, which enables the embedding module to better extract multi-level features and enhances the whole network's ability to capture features across different scales. In the third branch, we replace both the third and fourth convolutional blocks with the Inception structure. Furthermore, the number of channels in each branch is set to 64, 128, 256, and 512, ensuring that the network captures rich feature information at all levels.
The Inception structure processes features in parallel, allowing CarMaNet to capture more feature information within a single layer. The structure of the Inception module used in this paper is shown in Fig. 3. With 1 × 1 convolution kernels, the embedding module can decrease the number of parameters and the computational cost while maintaining the network's depth and complexity. Meanwhile, convolutional kernels of different sizes extract features at various scales, making the network more robust and efficient. We designed four branches in the Inception structure, each with a different block combination. The first branch uses a 1 × 1 convolution kernel to extract features and reduce the dimension of the embeddings, optimizing the number of parameters and the computational complexity, while a non-linear activation function enhances the model's ability to capture non-linear relationships.

Similar to the first branch, the second branch first applies a 1 × 1 convolution for dimensionality reduction, then uses a 3 × 3 convolutional kernel to expand the receptive field when extracting spatial features.
Building upon the second branch, the third branch replaces the 3 × 3 convolutional kernel with a 5 × 5 kernel, along with a 1 × 1 kernel to reduce the computational cost of the 5 × 5 kernel, enabling the branch to capture broader spatial features at lower cost in large-scale images.
The fourth branch employs 3 × 3 max-pooling to extract salient regional features and a 1 × 1 convolution to reduce the dimension, enhancing the model's spatial invariance to transformations of the input image samples and preparing features for subsequent layers.
Finally, the outputs of these four branches are merged. This four-branch structure enables the embedding module to learn and utilize features from different dimensions without increasing the computational cost, thereby improving the model's representation ability and performance. The combined feature map contains information from all branches, allowing the network to further process these multi-scale features in subsequent modules.

Fig. 3. The Structure of the Inception Block in This Paper
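To make the 1 × 1 dimensionality-reduction argument concrete, the following sketch compares the parameter count of a direct 5 × 5 convolution against a 1 × 1 reduction followed by the 5 × 5 convolution, as in the third branch. The channel counts here are hypothetical (the paper does not state the per-branch widths inside the Inception block):

```python
def conv_params(k, c_in, c_out):
    """Weights of a k x k convolution layer (bias omitted for simplicity)."""
    return k * k * c_in * c_out

# Hypothetical widths: 256 input channels, 128 output channels,
# and a 1 x 1 reduction down to 64 channels before the 5 x 5 conv.
direct = conv_params(5, 256, 128)                            # 5x5 applied directly
reduced = conv_params(1, 256, 64) + conv_params(5, 64, 128)  # 1x1 then 5x5

print(direct)   # 819200
print(reduced)  # 221184
print(round(direct / reduced, 1))  # 3.7  -> roughly 3.7x fewer parameters
```

The same receptive field is preserved, which is why the branch can afford the wide 5 × 5 kernel.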

3.2 FCAA Module

Traditional detection models primarily rely on textual features (such as URLs) or simple visual features, which are insufficient for capturing the nuanced, multi-level patterns often present in malicious website images. To fully utilize the relationships between channels and enhance the representation ability of the model, we make some improvements based on the traditional SENet. We propose an adaptive attention mechanism based on fused channels, FCAA, which utilizes an additional SE module connected to each residual block to learn the relationships between different features. Meanwhile, compared with models based on convolutional neural networks alone, the attention mechanism in CarMaNet introduces a targeted recalibration process that enhances feature representation, while the Inception module improves its capacity for multi-level feature aggregation. These design choices make CarMaNet particularly effective in scenarios requiring high-precision classification of malicious websites based on visual content. The structure of the improved SENet is shown in Fig. 4.

Fig. 4. The Structure of the Improved SENet

The SE module extracts global spatial information through global average pooling (GAP), denoted as S = GAP(F(X)), then learns the weight relationships between channels through fully connected layers (FC) and an activation function (such as Sigmoid). During this process, the weight of each channel is recalibrated to enhance useful features while suppressing irrelevant ones. This process can be denoted as Eq. (1).

SE(F(X)) = F(X) ⊗ σ(FC2(ReLU(FC1(S))))   (1)

In Eq. (1), σ represents the Sigmoid activation function, ReLU is a non-linear activation function, and FC1/FC2 are fully connected layers.
The structure of the traditional residual block bypasses the convolutional layers, allowing the input feature map X to be added directly to the output of these convolutional layers F(X), forming the final output Y = F(X) + X. We improve this structure by adding a parallel SE block to each residual block. Specifically, to merge different features, the original output of the residual block, F(X), is merged with the output of the parallel SE block, SE(F(X)), through the addition operation F(X) ⊕ SE(F(X)), where SE(·) denotes the Squeeze-and-Excitation operation. Additionally, a 1 × 1 convolutional layer is added to integrate and reallocate information across channels, achieving a more meaningful feature representation. This process can be represented as Eq. (2).

F′(X) = H(Conv1×1(F(X) ⊕ SE(F(X))))   (2)



The original output of the residual block can be denoted as Eq. (3).

Y = F(X) + X   (3)

The output of the improved SENet can be expressed as Eq. (4) and Eq. (5).

Y = H(Conv1×1(F(X) ⊕ SE(F(X)))) + X   (4)

SE(F(X)) = F(X) ⊗ σ(FC2(ReLU(FC1(S))))   (5)


By combining the output of the SE block and residual block, this improved SENet can
enhance the representation ability of our model. The attention mechanism selectively
emphasizes important features while suppressing less relevant ones, which improves
fine-grained feature extraction but introduces additional matrix operations and attention
weight computations, slightly increasing computational complexity while remaining
efficient enough for real-time applications in typical IoV environments with standard
hardware configurations.
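The recalibration pipeline of Eqs. (1)-(5) can be sketched in numpy as follows. All sizes and weights here are illustrative assumptions (randomly initialized FC weights standing in for learned parameters, the 1 × 1 convolution expressed as a channel-mixing matrix, and H(·) taken to be ReLU), not the paper's actual configuration:

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 8, 4, 4                     # channels, height, width (illustrative)
r = 2                                 # SE reduction ratio

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def se_block(f, w1, w2):
    """Eq. (1): channel-wise recalibration of a feature map f of shape (C, H, W)."""
    s = f.mean(axis=(1, 2))                       # GAP -> S, shape (C,)
    w = sigmoid(w2 @ np.maximum(w1 @ s, 0.0))     # sigma(FC2(ReLU(FC1(S))))
    return f * w[:, None, None]                   # channel-wise scaling

# Randomly initialized weights standing in for learned parameters.
w1 = rng.normal(size=(C // r, C))     # FC1 (squeeze)
w2 = rng.normal(size=(C, C // r))     # FC2 (excite)
conv1x1 = rng.normal(size=(C, C))     # 1x1 conv as a channel-mixing matrix

x = rng.normal(size=(C, H, W))        # residual-block input X
f = np.maximum(x, 0.0)                # stand-in for the convolutional path F(X)

merged = f + se_block(f, w1, w2)                     # F(X) (+) SE(F(X))
mixed = np.einsum('oc,chw->ohw', conv1x1, merged)    # Conv1x1(...)
y = np.maximum(mixed, 0.0) + x                       # Eq. (4), assuming H = ReLU

assert y.shape == x.shape             # the block preserves the feature-map shape
```

The key point this illustrates is that the SE path only rescales channels, so the merged output keeps the residual block's spatial dimensions and the skip connection Y = ... + X remains valid.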

3.3 Relation Module


The main task of the relation module is to calculate the relation score between samples in the support set and the query set. The features of these two types of samples are first extracted into embeddings by the embedding module; the embeddings are then efficiently merged via Inception blocks, constructing fine-grained features in detail. After the weights of these embeddings are recalibrated via the attention mechanism, they are finally sent to the relation module. The relation module calculates the similarity between the query set and the support set according to the input embeddings, and the relation score is output to identify the type of the input query sample. The process of calculating the relation score is shown in Eq. 6 and Eq. 7.
    
r_i,j = g_φ(H_w(C(f_φ(x_i), f_φ(x_j^m))))   (6)

φ, ϕ ← argmin_{φ,ϕ} Σ_{i=1}^{m} Σ_{j=1}^{n} (r_i,j − I(y_i = y_j))^2   (7)

In Eq. 6, r_i,j is the final relation score between a support set sample x_i and a query set sample x_j; i indexes the sample categories, m is the number of branches of the embedding network, and f_φ(x_i)/f_φ(x_j) are the feature embeddings of the support/query samples. C represents the feature merging process, which integrates the embeddings into a combined feature representation for calculating the relation score. H_w represents the weighting process based on the attention mechanism, and g_φ is the mapping of the relation module.
In Eq. 7, I is an indicator: its value is 1 if the support set sample and the query set sample belong to the same category, and 0 otherwise. Argmin is used to adjust the parameters of the module based on MSE. Specific modifications to the activation function and loss function are as follows.

(1) Activation Function


In CarMaNet, we use Mish in place of ReLU as the activation function to avoid the issue of the model ceasing to learn for negative input values, enhancing the generalization ability and stability of CarMaNet. The generalization ability of a model is of great importance when facing limited data in few-shot scenarios; meanwhile, the efficiency of Mish satisfies the real-time and accuracy requirements of malicious website classification. Mish is described in Eq. 8.

Mish(x) = x · tanh(ln(1 + e^x))   (8)

In Eq. 8, softplus(x) = ln(1 + e^x) maps the input value into the natural logarithmic range, avoiding the limitation of ReLU for negative input values. In few-shot scenarios, this increases the generalization ability of the model and reduces the risk of overfitting. Tanh further processes the output of softplus, mapping it into the range from −1 to 1. Compared with the ReLU function, the smoothness of tanh for negative values improves the model's performance in malicious website classification.
(2) Loss Function
To avoid the limitation that the MSE function is overly sensitive to outliers, we choose SmoothL1 as the loss function [16]. SmoothL1 combines squared error (L2 loss) and absolute error (L1 loss), reducing sensitivity to outliers while maintaining smoothness for smaller errors. The Smooth L1 loss is shown in Eqs. 9 and 10.

SmoothL1 = (1/n) Σ_{i=1}^{n} z_i   (9)

z_i = { (f(x_i) − y_i)^2 / 2,   if |f(x_i) − y_i| < 1
      { |f(x_i) − y_i| − 1/2,  otherwise   (10)

In Eqs. 9 and 10, f(x_i) is the predicted value and y_i is the true label of the sample. Unlike MSE, SmoothL1 maintains consistency and stability in penalization even when the error is large. This not only prevents excessive adjustments to outliers but also ensures gradient stability, reducing the risk of gradient explosion.
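The modified activation and loss can be written directly from Eqs. 8-10; the following is a plain-Python sketch of the two functions exactly as defined above:

```python
import math

def mish(x):
    """Eq. 8: Mish(x) = x * tanh(ln(1 + e^x))."""
    return x * math.tanh(math.log1p(math.exp(x)))

def smooth_l1(preds, targets):
    """Eqs. 9-10: mean of the piecewise Smooth L1 terms z_i."""
    total = 0.0
    for f_x, y in zip(preds, targets):
        e = abs(f_x - y)
        total += 0.5 * e * e if e < 1 else e - 0.5
    return total / len(preds)

print(mish(0.0))                          # 0.0
print(smooth_l1([0.5, 3.0], [0.0, 0.0]))  # (0.125 + 2.5) / 2 = 1.3125
```

The example values show the piecewise behavior: a small error (0.5) is penalized quadratically, while a large error (3.0) is penalized linearly, keeping the gradient bounded.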

4 Experiments and Analysis


4.1 Experiment Setup
(1) Datasets
In this section, we use self-constructed and authorized third-party datasets to conduct our experiments. Since pornography, gambling, and phishing are among the most prevalent types of malicious websites in today's cyberspace, individuals are highly likely to encounter and be affected by them. After comparison and validation, we chose the GSPD, GWDT and AGPD datasets as our experimental datasets. These datasets contain malicious website screenshots categorized into three main types: pornography, gambling, and phishing, with further subcategories under each type.

GSPD dataset: The GSPD dataset contains three main categories of malicious websites (pornography, gambling, and phishing) spanning a total of 16 subcategories. Each subcategory includes 100 screenshots of malicious websites at a resolution of 600 × 600 pixels.
GWDT dataset: The GWDT dataset contains 20 categories of malicious websites, with approximately 130 samples per category; the image resolution is 256 × 256 pixels.
AGPD dataset: The AGPD dataset includes 18 categories of malicious websites, each with 200 samples; the image resolution is 600 × 600 pixels.
(2) Baseline Models
In this section, we compare CarMaNet with four baseline models: Relation Net, MAML (Model-Agnostic Meta-Learning) [15], Prototypical Nets [17] and Matching Nets [18]. The baseline models are introduced as follows.
The Relation Network consists of an embedding module and a relation module; it first aggregates the embeddings of the input classes to form feature maps of all classes, then classifies samples by calculating relation scores from these feature maps.
MAML is a model-agnostic meta-learning algorithm. By training the model's initial parameters on a set of tasks, MAML can quickly learn from a small amount of new data, achieving good generalization performance with only a few gradient updates to the model parameters.
Prototypical Nets learn a mapping function that maps input samples to prototypes in a vector space, then classify new samples according to their distances to these prototypes.
Matching Nets map input samples into a vector space using neural networks and compute a similarity score between input samples and support set samples via attention-based weighting; this similarity score is then used to classify the input samples.
(3) Hyperparameters
In our experiments, we set the initial learning rate to 0.001 and use the Adam optimizer. We set the gamma decay factor to 0.5 and the number of epochs to 1000. Experiments are conducted in both 5-way 1-shot and 5-way 5-shot scenarios.
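For reference, the training settings above can be collected in a single configuration; the dictionary key names below are our own illustrative naming, not identifiers from the paper:

```python
# Training configuration as stated in the paper; key names are illustrative.
config = {
    "learning_rate": 0.001,   # initial learning rate
    "optimizer": "Adam",      # gradient-based optimizer
    "gamma": 0.5,             # learning-rate decay factor
    "epochs": 1000,
    "episodes": [             # evaluation scenarios
        {"n_way": 5, "k_shot": 1},
        {"n_way": 5, "k_shot": 5},
    ],
}

print(config["learning_rate"])  # 0.001
```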

4.2 Experimental Results and Analysis


4.2.1 Comparison Experiments of Different Few-Shot Models
To validate the superiority of CarMaNet, this experiment is conducted on the GSPD, GWDT and AGPD datasets in 5-way 1-shot and 5-way 5-shot scenarios to compare the performance of CarMaNet, Relation Net, MAML, Prototypical Nets and Matching Nets. Specifically, in every episode we randomly choose 1 or 5 samples from each of 5 categories to construct an N-way K-shot classification task. For the query set, we randomly choose samples from the remaining samples in these 5 categories. Experimental results are shown in Tables 1, 2 and 3.
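The episode construction described above (an N-way K-shot support set plus a query set drawn from the remaining samples of the same classes) can be sketched as follows; the toy dataset here is a stand-in for the real screenshot datasets:

```python
import random

def sample_episode(dataset, n_way=5, k_shot=1, n_query=3, seed=None):
    """Build one N-way K-shot episode: a support set and a disjoint query set."""
    rng = random.Random(seed)
    classes = rng.sample(sorted(dataset), n_way)   # pick N categories
    support, query = {}, {}
    for c in classes:
        samples = rng.sample(dataset[c], k_shot + n_query)
        support[c] = samples[:k_shot]              # K support samples per class
        query[c] = samples[k_shot:]                # queries from the remainder
    return support, query

# Toy stand-in: 8 classes with 20 sample ids each.
toy = {f"class_{i}": [f"img_{i}_{j}" for j in range(20)] for i in range(8)}
support, query = sample_episode(toy, n_way=5, k_shot=5, seed=0)

assert len(support) == 5 and all(len(v) == 5 for v in support.values())
# Support and query samples never overlap within a class.
assert all(not set(support[c]) & set(query[c]) for c in support)
```

Because support and query samples are drawn without replacement from the same class pool, the query set evaluates generalization to unseen instances of the episode's classes.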

(1) Experimental results on GSPD dataset

Table 1. The classification performance of different models on GSPD dataset

Model            | 5-way 1-shot   | 5-way 5-shot
MAML             | 61.25 ± 1.33%  | 75.09 ± 1.74%
PrototypicalNets | 64.74 ± 0.42%  | 72.83 ± 1.21%
MatchingNets     | 48.56 ± 1.32%  | 62.73 ± 0.68%
Relation Net     | 64.33 ± 0.87%  | 81.37 ± 1.26%
CarMaNet         | 67.95 ± 0.73%  | 89.69 ± 0.98%

As shown in Table 1, on the GSPD dataset CarMaNet achieves accuracies of 67.95% and 89.69% in the 5-way 1-shot and 5-way 5-shot scenarios respectively, an improvement of 3.21% over Prototypical Nets' 64.74% in the 1-shot scenario and of 8.32% over Relation Net's 81.37% in the 5-shot scenario.
(2) Experimental results on GWDT dataset

Table 2. The classification performance of different models on GWDT dataset

Model            | 5-way 1-shot   | 5-way 5-shot
MAML             | 58.47 ± 1.29%  | 73.66 ± 0.62%
PrototypicalNets | 62.36 ± 0.28%  | 80.73 ± 1.11%
MatchingNets     | 61.35 ± 1.26%  | 78.12 ± 0.33%
Relation Net     | 65.80 ± 1.59%  | 78.99 ± 0.56%
CarMaNet         | 72.89 ± 1.30%  | 92.94 ± 1.55%

From Table 2 we can conclude that on the GWDT dataset, CarMaNet reaches accuracies
of 72.89% and 92.94% in the 5-way 1-shot and 5-way 5-shot classification tasks,
respectively, improvements of 7.09% and 13.95% over Relation Net.
(3) Experimental results on AGPD dataset
From Table 3 we can conclude that on the AGPD dataset, CarMaNet reaches accuracies
of 68.91% and 88.37% in the 5-way 1-shot and 5-way 5-shot scenarios, respectively,
improving by 6.07% over Relation Net in the 1-shot scenario and by 9.05% over MAML
in the 5-shot scenario.
In conclusion, when performing few-shot classification tasks for malicious websites,
CarMaNet outperforms traditional few-shot models, which indicates that our improvements
to the Relation Network are effective.
226 Q. Wang et al.

Table 3. The classification performance of different models on AGPD dataset

Model              5-way 1-shot       5-way 5-shot
MAML               56.66 ± 1.30%      79.32 ± 1.66%
PrototypicalNets   59.81 ± 1.27%      74.75 ± 1.31%
MatchingNets       57.79 ± 0.79%      71.37 ± 0.75%
Relation Net       62.84 ± 0.92%      75.83 ± 0.70%
CarMaNet           68.91 ± 0.81%      88.37 ± 0.94%

4.2.2 Convergence Experiments


In this section, we conduct convergence experiments for CarMaNet in the 5-way 1-shot
and 5-way 5-shot scenarios on the GWDT, AGPD and GSPD datasets, using the traditional
Relation Network as the baseline. Experimental results are shown in Figs. 5 and 6.

Fig. 5. Convergence Experiment result in 5-way 1-shot scenario (panels: GWDT, AGPD, GSPD)

From Fig. 5 we can conclude that under the 5-way 1-shot condition, CarMaNet and the
Relation Network have similar accuracy in the early stage of training. After about 230
epochs, CarMaNet converges and becomes more stable, outperforming the traditional
Relation Network model in both accuracy and convergence efficiency.
From Fig. 6 we can conclude that CarMaNet's behavior is similar to that in the 5-way
1-shot setting. Under the 5-way 5-shot condition, CarMaNet performs similarly to the
original model around the 100th epoch and continues to improve during subsequent
training. Around the 230th epoch, CarMaNet converges and stabilizes, outperforming
the traditional Relation Network model in both accuracy and convergence efficiency.
Fig. 6. Convergence Experiment result in 5-way 5-shot scenario

Experimental results show that under both 5-way 1-shot and 5-way 5-shot conditions,
CarMaNet reaches higher classification accuracy while converging as fast as Relation
Net. This indicates that CarMaNet performs better in few-shot tasks with limited
training data.

4.2.3 Ablation Experiments


To explore the effects of the multi-scale embedding module and of FCAA under different
compression ratios, we conduct three experiments on the GSPD, GWDT and AGPD datasets
in this section:
(1) Testing the traditional Relation Network under 5-way 1-shot and 5-way 5-shot
environments.
(2) Testing the improved Relation Network model with the multi-scale embedding
module, under 5-way 1-shot and 5-way 5-shot conditions.
(3) Testing the model from (2) with the addition of an FCAA module with different
compression ratios (r) under 5-way 1-shot and 5-way 5-shot conditions. Experimental
results are shown in Table 4.

Table 4. The Results of Ablation Experiments on the GSPD, GWDT and AGPD datasets

Model                 GSPD                GWDT                AGPD
                      1-shot    5-shot    1-shot    5-shot    1-shot    5-shot
RN                    64.33%    81.37%    65.80%    78.99%    62.84%    75.83%
Imp_RN                65.74%    82.51%    66.69%    79.36%    63.55%    76.61%
Imp_RN + FCAA (r=8)   66.62%    82.12%    67.00%    82.31%    63.83%    77.25%
Imp_RN + FCAA (r=12)  67.95%    89.69%    72.89%    92.94%    68.91%    88.37%
Imp_RN + FCAA (r=16)  66.72%    82.85%    67.31%    82.44%    62.35%    76.56%

Analyzing Table 4, we find that when only the multi-scale embedding module (Imp_RN)
is activated, accuracy increases over the traditional Relation Network by 1.41%, 0.89%
and 0.71% in the 5-way 1-shot condition and by 1.14%, 0.37% and 0.78% in the 5-way
5-shot condition on GSPD, GWDT and AGPD, respectively. When we further activate the
FCAA module and set different compression ratios r, CarMaNet outperforms the Relation
Network, improving by 3.62%, 7.09% and 6.07% in the 5-way 1-shot environment and by
8.32%, 13.95% and 12.54% in the 5-way 5-shot environment. This demonstrates the
effectiveness of our multi-scale embedding module and the FCAA module.
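Since FCAA builds on SENet's squeeze-and-excitation design with a channel compression ratio r, the generic SE block being ablated here can be sketched as follows. This is the standard SENet mechanism in NumPy, not the full FCAA module, and the weight shapes are the usual SE convention.

```python
import numpy as np

def se_channel_attention(x, w1, w2):
    """Squeeze-and-excitation channel attention with compression ratio r.

    x  : (C, H, W) feature map
    w1 : (C // r, C) squeeze FC weights
    w2 : (C, C // r) excitation FC weights
    Returns the channel-reweighted feature map.
    """
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    squeeze = x.mean(axis=(1, 2))                          # global average pooling -> (C,)
    excite = sigmoid(w2 @ np.maximum(w1 @ squeeze, 0.0))   # FC -> ReLU -> FC -> sigmoid
    return x * excite[:, None, None]                       # scale each channel
```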
To further validate the superiority of FCAA, we also evaluate CBAM [19], ECANet [20]
and SENet as the attention module of our model. Each result is the average of 5 runs;
experimental results are shown in Table 5.

Table 5. Result of Ablation Experiments using different attention modules

Attention Module   GSPD              GWDT              AGPD
                   1-shot   5-shot   1-shot   5-shot   1-shot   5-shot
ECANet             65.56%   80.75%   64.77%   78.74%   59.75%   75.68%
SENet              63.35%   81.25%   66.64%   81.63%   63.42%   76.91%
CBAM               66.88%   85.63%   63.21%   77.33%   63.35%   77.73%
FCAA               67.95%   89.69%   72.89%   92.94%   68.91%   88.37%

From Table 5 we find that the model with the FCAA designed in this paper achieves
higher accuracy on the GSPD, GWDT and AGPD datasets than models using traditional
attention mechanisms such as SENet, ECANet, or CBAM. In the 5-way 1-shot scenario on
the GSPD dataset, the model with FCAA achieves an accuracy of 67.95%, which is 1.07%
higher than the model with CBAM. In the 5-way 5-shot scenario, the classification
accuracy of the model with FCAA on the GSPD dataset is 89.69%, surpassing CBAM by
4.06%. In the 5-way 5-shot scenario on the GWDT and AGPD datasets, the model with
FCAA outperforms CBAM by 15.61% and 10.64%, respectively. In the 5-way 5-shot
scenario on the GWDT dataset, the model with FCAA reaches an accuracy of 92.94%,
while the model with SENet reaches 81.63%, making FCAA approximately 11.31% more
accurate. In summary, the FCAA attention mechanism demonstrates superior performance
in classifying malicious websites, and thanks to its few-shot learning capability,
CarMaNet can better learn the features of malicious website images in IoV environments
with insufficient samples. This enables CarMaNet to exhibit good transferability
across other types of malicious website datasets.

5 Conclusions

In this paper, we mainly focus on issues in in-car intelligent systems such as the lack
of samples and insufficient feature extraction. To address these issues, we propose a
few-shot classification model for malicious websites based on an improved Relation
Network, called CarMaNet. First, we introduce the inception block into the traditional
Relation Network to increase its width, allowing it to better capture multi-scale
features. Second, we employ a fusion channel adaptive attention (FCAA) mechanism based
on SENet to deeply capture the feature information of samples. Finally, through several
experiments on the GSPD, GWDT and AGPD datasets, we demonstrate that CarMaNet
outperforms traditional few-shot models in classification accuracy while converging
well. This proves the effectiveness of CarMaNet in classifying malicious websites
under few-shot scenarios.
However, there is still work to be done to improve CarMaNet. First, more datasets
covering different types of malicious website images should be included in the
experiments to test the transferability of CarMaNet. Moreover, to further enhance the
detection efficiency of CarMaNet, integrating it with existing tools such as VirusTotal
or Norton Safe Web presents a promising direction. By leveraging their well-established
blacklisting mechanisms, CarMaNet could collaborate with these platforms to improve
its ability to detect malicious websites more comprehensively. This integration could
enable the model to utilize real-time threat intelligence and complement its few-shot
learning capabilities, providing a robust solution for identifying and classifying
malicious websites. Future research could explore the feasibility and implementation
of such an integrated approach.

Funding. The paper is supported by the Key R&D Program of Zhejiang Province (No. 2024C01012).

Disclosure of Interests. The authors declare that they have no known competing financial
interests or personal relationships that could have appeared to influence the work
reported in this paper.

References
1. Zhao, C.: Research on malicious website identification method integrating URL and page
information. Jiangsu Univ. Sci. Technol. (2022). https://doi.org/10.27171/d.cnki.ghdcc.2022.
000691
2. Afzal, S., Asim, M., Javed, A.R., Beg, M.O., Baker, T.: URLdeepDetect: a deep learning
approach for detecting Malicious URLs using semantic vector models. J. Netw. Syst. Manage.
29(3) (2021)
3. Koch, G., Zemel, R., Salakhutdinov, R.: Siamese neural networks for one-shot image
recognition. In: ICML Deep Learning Workshop, vol. 2, no. 1 (2015)
4. Sung, F., Yang, Y., Zhang, L., et al.: Learning to compare: relation network for few-shot
learning. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 1199–1208
(2018)
5. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: IEEE Conference on Computer
Vision and Pattern Recognition, pp. 7132–7141 (2018)
6. VirusTotal: VirusTotal [EB/OL]. 01-06-2004. https://www.virustotal.com/gui/home/url.
Accessed 10 Nov 2022
7. Tencent URL Security Detection Center: Tencent URL Security Detection Center [EB/OL].
(2023–01–10). https://urlsec.qq.com/check.html. Accessed 1 Nov 2022
8. Norton Safe Web: Norton Safe Web [EB/OL]. 20-12-2023. https://safeweb.norton.com/.
Accessed 20 Nov 2023

9. Chen, Y., Zheng, R., Zhou, A., et al.: Automatic detection of pornographic and gambling
websites based on visual and textual content using a decision mechanism. Sensors 20(14),
3989 (2020)
10. Wang, C., Zhang, M., Shi, F., et al.: A hybrid multimodal data fusion-based method for
identifying gambling websites. Electronics 11(16), 2489 (2022)
11. Siddiq, M.A.A., Arifuzzaman, M., Islam, M.S.: Phishing website detection using deep
learning. In: 2nd International Conference on Computing Advancements, pp. 83–88 (2022)
12. Zhang, Y., Fu, X., Yang, R., et al.: DRSDetector: detecting gambling websites by multi-level
feature fusion. In: IEEE Symposium on Computers and Communications (ISCC), pp. 1441–
1447. IEEE (2023)
13. Naru, P., Chinthala, S.K.R., Sekhar, P.G., et al.: Detection of fake websites using machine
learning techniques. In: 3rd International Conference on Smart Data Intelligence (ICSMDI),
pp. 477–482. IEEE (2023)
14. Chiramdasu, R., Srivastava, G., Bhattacharya, S., et al.: Malicious URL detection using
logistic regression. In: IEEE International Conference on Omni-Layer Intelligent Systems
(COINS), pp. 1–6. IEEE (2021)
15. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep
networks. In: International Conference on Machine Learning, PMLR, pp. 1126–1135 (2017)
16. Liu, C., Yu, S., Yu, M., et al.: Adaptive smooth L1 loss: a better way to regress scene texts
with extreme aspect ratios. In: IEEE Symposium on Computers and Communications (ISCC),
pp. 1–7. IEEE (2021)
17. Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: Advances
in Neural Information Processing Systems, vol. 30 (2017)
18. Vinyals, O., Blundell, C., Lillicrap, T., et al.: Matching networks for one-shot learning. In:
Advances in Neural Information Processing Systems, vol. 29 (2016)
19. Woo, S., Park, J., Lee, J.Y., et al.: CBAM: convolutional block attention module. In: Ferrari,
V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) Computer Vision – ECCV 2018. ECCV
2018. Lecture Notes in Computer Science(), vol. 11211, pp. 3–19. Springer, Cham (2018).
https://doi.org/10.1007/978-3-030-01234-2_1
20. Wang, Q., Wu, B., Zhu, P., et al.: ECA-Net: efficient channel attention for deep convolutional
neural networks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition
(CVPR), pp. 11534–11542 (2020)
Unknown Category Malicious Traffic Detection
Based on Contrastive Learning

Leiming Yan1,2(B) , Tao Zhou1,2 , and Xianyi Chen1,2


1 Engineering Research Center of Digital Forensics Ministry of Education, Nanjing University
of Information Science and Technology, Nanjing 210044, China
[email protected]
2 School of Computer Science and Cyber Science and Engineering, Nanjing University of

Information Science and Technology, Nanjing 210044, China

Abstract. In response to the continuous emergence of novel unknown malicious


traffic and the limitations of traditional detection methods, this paper presents
an unknown-category malicious traffic detection approach based on contrastive
learning. The proposed method utilizes contrastive learning to train an encoder
for distinguishing normal traffic and multiple types of known malicious traffic,
thereby enhancing inter-class separability. Subsequently, the model is fine-tuned
to optimize the representations of malicious traffic categories with limited sam-
ples, amplifying the separation between different classes in the feature space. This
ensures that the distance between unknown malicious traffic and other categories
exceeds a predefined threshold, thereby achieving effective detection. Experi-
mental results on the CICIoT2023 and UNSW-NB15 datasets demonstrate the
superior performance of the proposed method, achieving multi-class F1-scores of
93.23 and 89.93, respectively, outperforming traditional approaches. Additionally,
the method achieves an accuracy of over 83% in simulated scenarios involving
unknown-category malicious traffic detection.

Keywords: Intrusion Detection · Unknown Malicious Traffic · Contrastive


Learning

1 Introduction
With the increasing number of electronic devices and the growing complexity of network
environments, cybersecurity issues have become increasingly prominent, resulting in
significant losses for the network economy. To effectively prevent cyberattacks, the con-
struction of Intrusion Detection Systems (IDS) has become one of the primary methods.
Existing IDS can be broadly categorized into host-based and network-based systems:
the former primarily detects intrusions by monitoring log information, while the lat-
ter analyzes network traffic to determine the presence of intrusions. Although machine
learning (ML)-based IDS are widely applied, they face limitations in extracting deep
features and struggle to cope with increasingly sophisticated attack methods. Deep learn-
ing (DL), with its powerful feature extraction capabilities, has gradually become integral

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 231–245, 2025.
https://doi.org/10.1007/978-981-96-4566-4_16
232 L. Yan et al.

to IDS. However, it still confronts challenges such as imbalanced data categories. As


attackers continuously upgrade their strategies, the boundaries between malicious and
normal traffic are becoming increasingly blurred. Also, new types of malicious traffic
are emerging rapidly, limiting the effectiveness of existing detection methods, especially
against zero-day attacks.
To address the above issues, this paper proposes a novel method for detecting
unknown category malicious traffic based on contrastive learning to enhance the effec-
tiveness of intrusion detection. Contrastive learning trains the model by comparing pos-
itive and negative samples, enabling it to capture the underlying structures and patterns
between normal and abnormal traffic. This allows the model to learn more discriminative
features, providing a solid foundation for subsequent model fine-tuning, thus making the
fine-tuning process more efficient and accurate. This approach helps the model quickly
adapt to new tasks with limited data, especially in situations with insufficient samples,
effectively clarifying the differences between categories and enhancing classification
accuracy.
The main contributions of this work are as follows:
(1) We propose a contrastive learning-based method for detecting unknown category
malicious traffic, leveraging contrastive learning to obtain discriminative feature rep-
resentations, thereby increasing the differentiation between various traffic categories
and improving detection performance.
(2) We fine-tune the encoder model through contrastive learning pre-training. After
fine-tuning, the model can achieve accurate detection and classification with few
samples, alleviating the problem of class imbalance in traffic data and demonstrating
the capability to detect unknown category malicious traffic, which provides a new
perspective for network traffic intrusion detection.

2 Related Work
2.1 Deep Learning-Based Intrusion Detection Methods

In recent years, deep learning-based intrusion detection techniques have gradually


become mainstream. Reference [1] proposed an unsupervised anomaly detection method
called TadGAN based on Generative Adversarial Networks (GAN) and Long Short-Term
Memory (LSTM) networks. By performing cyclic consistency training on time-series
data and combining various reconstruction error calculations and critic outputs, TadGAN
effectively improves anomaly detection performance and generalization ability. Refer-
ence [2] introduced a network traffic anomaly detection method based on a Multi-Scale
Residual Classifier (MSRC). Through multi-scale observation and waveform transfor-
mation techniques, combined with a stacked autoencoder and lightweight classifier, this
method enhances the detection performance for abnormal network traffic. Reference [3]
proposed a network intrusion detection algorithm that combines hybrid sampling and
deep hierarchical networks. It uses OSS and SMOTE to handle data imbalance, while
employing CNN and Bidirectional Long Short-Term Memory (BiLSTM) to extract fea-
tures, significantly improving detection accuracy. This approach was validated on the
NSL-KDD and UNSW-NB15 datasets.

Reference [4] proposed a method based on feature selection and Conditional Wasser-
stein GAN (FCWGAN) combined with BiLSTM, which effectively improves the clas-
sification accuracy of network intrusion detection models and significantly reduces false
positives and false negatives. Reference [5] introduced an intrusion detection method
based on VAE-CWGAN and feature importance fusion, which significantly improves
model detection performance. Reference [6] proposed a Few-Shot Learning model based
on a Siamese Convolutional Neural Network (FSL-SCNN) that optimizes feature rep-
resentation and designs three specific loss functions to improve detection accuracy.
Reference [7] developed a multi-module integrated intrusion detection system named
GMM-WGAN, which combines a clustering algorithm based on the Gaussian Mix-
ture Model and Wasserstein GAN for handling imbalance, followed by a classification
module design based on CNN and LSTM. This system effectively enhances intrusion
detection performance on the NSL-KDD and UNSW-NB15 datasets.
However, the above methods have not fully addressed the issue of class imbalance
in the datasets, and deep learning models require large amounts of data for training
to achieve high accuracy. Consequently, some types of attacks are not easily learned
by deep learning models, resulting in lower detection performance. Therefore, it is
necessary to address the class imbalance problem in the datasets to further improve
detection performance.

2.2 Contrastive Learning

Contrastive learning has been successfully applied in fields such as image recognition,
natural language processing, and video understanding [8], outperforming conventional
methods. Unlike traditional deep learning approaches, contrastive learning focuses on
learning feature representations that better reflect the semantic relationships between
samples. Typical studies like Hjelm et al. [9] and Bachman et al. [10] constructed con-
trastive prediction tasks based on global-local relations. Van den Oord et al. [11], Henaff
[12], took the context relationship as the prediction task and proposed the contrastive pre-
diction coding model. The algorithm used the autoregressive model to predict the future
time step as the prediction task. Based on that, they could learn a hidden representation
of the data.
Contrastive methods consist of two steps: contrastive task construction and con-
trastive learning. In the contrastive task construction, the semantic relationships between
samples are predefined. To better capture these semantic relationships, it is often neces-
sary to construct contrastive learning task pairs using various data augmentation samples
[13–15] and optimize the contrastive loss. The objective is to optimize the embedding
function based on the contrastive loss, bringing samples closer to others of the same class
while pushing them farther from samples of different classes. Consequently, contrastive
learning can reinforce the differences between traffic from different classes, thereby
enhancing the detection model’s sensitivity to abnormal traffic.

3 Method

The paper proposes a contrastive learning-based method for detecting unknown mali-
cious traffic, aiming to address issues such as class imbalance in malicious traffic and
unclear boundaries between traffic categories. The method effectively detects unknown
types of malicious traffic and consists of two main components: contrastive learning
pre-training and fine-tuning the model classifier. The overall framework of the method
is illustrated in Fig. 1.

Fig. 1. Methodology Framework Diagram

3.1 Contrastive Learning Pre-training

This paper introduces contrastive learning to enhance the discriminative ability between
different categories of traffic, enabling the model to better learn the differences between
various types of traffic and thereby improving the accuracy and robustness of the intrusion
detection model. The pre-training process of contrastive learning is composed of data
augmentation, an encoder, and the contrastive learning task, as illustrated in Fig. 2.

Fig. 2. Contrastive Learning Task Diagram



When constructing a contrastive learning task, samples are usually categorized based
on their labels, with samples of the same category forming positive pairs, while samples
of different categories form negative pairs. To address the issue of insufficient data for
certain malicious traffic categories, this paper adopts a data augmentation method by
adding random masks to generate augmented samples that retain the same semantics as
the original traffic samples. These augmented samples, along with the original samples
from the same category, form positive sample pairs, while samples from other categories
form negative pairs. This method effectively enlarges the dataset and provides a more
reliable foundation for subsequent contrastive learning tasks. The data augmentation
method with random masking is shown in Fig. 3.

Fig. 3. Data Augmentation Method

 
Given a network traffic packet sequence $flow = \{p_1, p_2, \ldots, p_m\}$, define a mask
containing $k$ elements with a value of 0 as follows:

$$mask = \{mask_1, mask_2, \ldots, mask_m\}, \quad \text{where} \quad \sum_{i=1}^{m} mask_i = m - k \tag{1}$$

The mask $\{mask_1, mask_2, \ldots, mask_m\}$ contains $k$ zero values at random
positions, and the operation of applying the mask is denoted as:

$$flow_{masked} = flow * mask \tag{2}$$
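A minimal NumPy sketch of this masking operation (Eqs. (1)–(2)), assuming the flow is represented as a numeric vector of length m:

```python
import numpy as np

def random_mask(flow, k, seed=None):
    """Zero out k randomly chosen positions of a length-m traffic
    sequence (Eqs. (1)-(2)); the mask elements sum to m - k, and the
    augmented sample keeps the original sample's class label."""
    rng = np.random.default_rng(seed)
    flow = np.asarray(flow, dtype=float)
    mask = np.ones(len(flow))
    mask[rng.choice(len(flow), size=k, replace=False)] = 0.0
    return flow * mask
```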

The theoretical support for constructing contrastive learning tasks through data
augmentation methods by adding random masks is as follows:
(1) Although the data undergoes random masking, it still retains some features and
information of the original data, making the analysis of the masked data still effective.
Therefore, the random masking method can ensure the validity of traffic data to a
certain extent.
(2) In complex and volatile network environments, network fluctuations may lead to
occasional packet loss in network traffic data. By adding random masks to generate
traffic data, it is possible to effectively simulate real network environments and reflect
actual packet loss phenomena.

(3) In network intrusion detection tasks, the detection model should possess a certain
degree of robustness. Even when there are occasional missing packets in the traffic
sequence, a robust intrusion detection model should still be able to accurately identify
and detect potential intrusion behaviors.
After the data augmentation operation, the class label of the original sample is
assigned to the augmented sample. In this way, the augmented samples form positive
sample pairs with samples of the same class, and negative sample pairs with samples of
different classes, as illustrated in Fig. 4.

Fig. 4. Construct Positive and Negative Sample Pairs

To enhance the model’s classification capability, we map the samples into the fea-
ture space using an encoder to perform contrastive learning tasks. The core idea of
contrastive learning is to optimize the encoder through backpropagation using the con-
trastive loss function, enabling the model to effectively distinguish between different
classes of samples. Specifically, the objective of the contrastive loss is to minimize the
distance between positive sample pairs while maximizing the distance between neg-
ative sample pairs in the feature space, thereby improving the model’s discriminative
performance. The contrastive loss function is defined as follows:

$$L_{contrastive} = \sum_{i} L_{i}^{contrastive} = \sum_{i} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)} \tag{3}$$

In this context, zi is used as the anchor sample, forming positive pairs with samples
from the same class and negative pairs with samples from different classes. Let the index
of the anchor sample be denoted as i, and let P(i) and A(i) represent the set of positive pairs
and the set of all sample pairs for the anchor sample, respectively. The denominator of the
contrastive loss function is computed over the entire sample pair set, while the numerator

represents the similarity between the anchor sample and its positive pairs. To optimize
Eq. (3), it is necessary to maximize the dot product in the numerator while minimizing
the value of the denominator. The parameter τ ∈ R+ serves as a temperature coefficient,
aiming to guide the model to pay more attention to hard-to-distinguish negative samples.
The objective of the contrastive loss is to optimize the feature representation such that
the distance between positive pairs is minimized, while the distance between negative
pairs is increased, thereby enhancing the model’s classification performance.
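A NumPy sketch of this loss following the notation of Eq. (3), assuming z holds L2-normalized embeddings and P(i), A(i) are built from the labels:

```python
import numpy as np

def supervised_contrastive_loss(z, labels, tau=0.1):
    """Supervised contrastive loss of Eq. (3).

    z      : (n, d) array of L2-normalized embeddings
    labels : (n,) class labels; same-class pairs are positives
    tau    : temperature coefficient
    """
    z = np.asarray(z, dtype=float)
    labels = np.asarray(labels)
    n = len(z)
    sim = z @ z.T / tau                    # z_i . z_a / tau for all pairs
    np.fill_diagonal(sim, -np.inf)         # A(i) excludes the anchor itself
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    total = 0.0
    for i in range(n):
        pos = (labels == labels[i]) & (np.arange(n) != i)   # P(i)
        if pos.any():
            total += -log_prob[i, pos].mean()
    return total / n
```

With tightly clustered same-class embeddings the loss approaches zero; mixing classes in the feature space drives it up, which is exactly the gradient signal used to train the encoder.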

3.2 Fine-Tuned Model Classification

This paper proposes a method for identifying unknown malicious traffic. First, a training
set is constructed by selecting several samples closest to the cluster centers through
clustering. Next, the encoder obtained from the contrastive learning task is fine-tuned
by training the model to perform classification tasks, achieving traffic classification and
detection of unknown malicious traffic. The fine-tuning model is illustrated in Fig. 5.

Fig. 5. Fine-tuned Model Classifier

To enhance the performance of the classifier, a clustering-based strategy is introduced


to select the training set. Specifically, the training data is first clustered using K-means
clustering, and a number of representative samples closest to the cluster centers are
selected to form a new training set. During each training episode, a subset of remaining
samples is randomly selected as the test set to compute the loss and update the model
parameters. For each class C, the corresponding class center vector pc is computed. The
formula for calculating the class center vector is defined as:
$$p_c = \frac{1}{|S_c|} \sum_{(x_i, y_i) \in S_c} f_\phi(x_i) \tag{4}$$

Here $S_c$ denotes the set of training samples belonging to class $c$ (so $|S_c|$ is
its size), $f_\phi$ is the encoder used to extract the feature vectors of the traffic
data, and $(x_i, y_i)$ denotes a traffic sample with its corresponding label.
For each sample xi in the test set, it is necessary to compute the Euclidean distance
between its feature vector fφ (xi ) and each class center vector pc . The distance calculation

formula is defined as:


 
$$d(f_\phi(x), p_c) = \| f_\phi(x) - p_c \|_2 \tag{5}$$

Subsequently, the test sample x is classified into the category corresponding to the
class center with the shortest distance. The classification formula is defined as follows:
 
$$y = \arg\min_{c} d(f_\phi(x), p_c) \tag{6}$$

If a sample has a distance greater than a predefined threshold from all class centers,
it is identified as a new class sample, and the class centers are updated accordingly.
This ensures that the model can adapt to new categories of malicious traffic, thereby
improving the model’s flexibility and adaptability.
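Equations (4)–(6) together with the threshold rule can be sketched as follows; the threshold value itself is a deployment-specific choice, not fixed by the paper.

```python
import numpy as np

def class_prototypes(features, labels):
    """Eq. (4): the center p_c is the mean feature vector of class c."""
    return {c: features[labels == c].mean(axis=0) for c in np.unique(labels)}

def classify(x_feat, prototypes, threshold=np.inf):
    """Eqs. (5)-(6): assign x to the nearest center by Euclidean
    distance; if all distances exceed `threshold`, flag the sample
    as an unknown (new) category."""
    dists = {c: float(np.linalg.norm(x_feat - p)) for c, p in prototypes.items()}
    best = min(dists, key=dists.get)
    return "unknown" if dists[best] > threshold else best
```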
During the classification phase, the model computes the probability that the sample
to be detected belongs to each class based on Euclidean distance. Specifically, the smaller
the distance between the sample and the class center, the higher the probability that the
sample belongs to that class. To achieve this, the softmax function can be used to convert
the distances into probabilities. The detailed calculation process is as follows:
  
exp −d fφ (x), pc
P(y = c|x) =     (7)
c exp −d fφ (x), pc

For each sample $x$ in the test set with true class $y$, the cross-entropy loss is
computed and averaged over all samples to obtain the loss function, defined as follows:

$$L = -\frac{1}{|Q|} \sum_{(x, y) \in Q} \log \frac{\exp(-d(f_\phi(x), p_y))}{\sum_{c} \exp(-d(f_\phi(x), p_c))} \tag{8}$$

The loss function guides the update of model parameters during training, enabling
the model to perform classification tasks more accurately and effectively detect novel
attack traffic.
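Equations (7)–(8) translate directly into a softmax over negative prototype distances followed by a mean cross-entropy over the query set Q; a NumPy sketch:

```python
import numpy as np

def episode_loss(query_feats, query_labels, prototypes):
    """Eq. (7): P(y=c|x) is the softmax over -d(f(x), p_c); Eq. (8):
    mean cross-entropy of the true class over the query set Q.

    `prototypes` maps class label -> center vector (as in Eq. (4)).
    """
    classes = sorted(prototypes)
    centers = np.stack([np.asarray(prototypes[c], dtype=float) for c in classes])
    d = np.linalg.norm(query_feats[:, None, :] - centers[None, :, :], axis=2)
    logits = -d
    logits = logits - logits.max(axis=1, keepdims=True)        # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.array([classes.index(y) for y in query_labels])
    return -log_p[np.arange(len(idx)), idx].mean()
```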

4 Experiments and Analysis


In this section, we conduct experiments on the intrusion detection dataset to evaluate the
effectiveness of our proposed method. The hardware configuration for the experiments
includes an i7-13700KF CPU, NVIDIA GeForce RTX 3090 GPU, 32 GB of RAM, and
a 64-bit Windows 10 operating system, with programming carried out using Python 3.9.

4.1 Experimental Data


To evaluate the effectiveness of the proposed method in the context of intrusion detection
tasks, this study utilizes two public datasets: CICIoT2023 and UNSW-NB15. Given the
current deficiency of suitable datasets for intrusion detection strategies, this study extracts
a limited number of samples from two publicly available datasets to construct training

and testing sets for the performance evaluation of the proposed method. To address
the issue of class imbalance and enhance the detection performance and generalization
capability of the model, this study sets the sample size for each category in the training
set to 20. To simulate the imbalance present in real-world data, the sample sizes for each
category in the testing set are determined according to their distribution in the dataset.
For instance, in the CICIoT2023 dataset, the majority class (DDoS) is assigned 1,600
samples, while the Benign, DoS, and Mirai classes each have 800 samples. The Spoofing
and Recon classes are assigned 400 samples each, and the Web and BruteForce classes
are set to 100 samples. Furthermore, multiple testing sets containing different sample
configurations were constructed to validate the feasibility of the proposed method. A
similar strategy was employed for selecting the training and testing sets in the UNSW-
NB15 dataset. The distribution of classes in the datasets and the configurations of the
training and testing sets are detailed in Table 1.
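The sampling procedure described above can be sketched in a few lines. This is an illustrative reconstruction, not the authors' code; the `build_few_shot_split` helper and its arguments are hypothetical names:

```python
import random
from collections import defaultdict

def build_few_shot_split(samples, train_per_class=20, test_sizes=None):
    """Draw a fixed number of training samples per class and a test set
    whose per-class sizes mirror the real-world imbalance."""
    by_class = defaultdict(list)
    for features, label in samples:
        by_class[label].append((features, label))
    train, test = [], []
    for label, items in by_class.items():
        random.shuffle(items)
        train.extend(items[:train_per_class])            # 20 per class
        n_test = (test_sizes or {}).get(label, 0)        # imbalanced test set
        test.extend(items[train_per_class:train_per_class + n_test])
    return train, test

# Test-set sizes mirroring the CICIoT2023 configuration in Table 1
ciciot_test_sizes = {"DDoS": 1600, "Benign": 800, "DoS": 800, "Mirai": 800,
                     "Spoofing": 400, "Recon": 400, "Web": 100, "BruteForce": 100}
```

Fixing the per-class training size while letting the test set follow the natural class distribution is what lets the evaluation reflect real-world imbalance.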

Table 1. Distribution of Samples in Datasets, Training Set, and Testing Set

Dataset Attack Category Sample Size Training Set Testing Set


CICIoT2023 Benign 1098195 20 800
DDoS 33984560 20 1600
DoS 8090738 20 800
Mirai 2634124 20 800
Spoofing 486504 20 400
Recon 354565 20 400
Web 24829 20 100
BruteForce 13064 20 100
UNSW-NB15 Benign 2204107 20 1600
Generic 213718 20 1600
Backdoor 28300 20 800
Exploits 21514 20 800
Recon 11853 20 800
Analysis 3847 20 400
Shellcode 1511 20 400
DoS 622 20 100
Fuzzers 357 20 100
Worms 174 20 100

4.2 Evaluation Metrics


In the field of cybersecurity, network traffic data often exhibits a phenomenon of sam-
ple distribution imbalance, where normal samples significantly outnumber anomalous
240 L. Yan et al.

samples. Relying solely on accuracy as an evaluation metric fails to accurately reflect


the model’s performance in real-world scenarios. The primary objective of an intrusion
detection system is to detect anomalous samples as effectively as possible while minimiz-
ing the misclassification of normal samples as anomalous. Therefore, this study employs
three key metrics—Precision, Recall, and F1-Score—to evaluate the effectiveness of the
network intrusion detection system.
• Precision refers to the proportion of predicted positive samples that are actually pos-
itive. It reflects the model’s accuracy when predicting positive samples and measures
the reliability of positive predictions.
• Recall denotes the proportion of actual positive samples that are correctly predicted as
positive by the model. It assesses the model’s coverage capability concerning positive
samples and evaluates its ability to detect positive samples.
• F1-Score is the harmonic mean of Precision and Recall, serving as a method to balance
the two. It provides a more comprehensive assessment of the model’s performance,
especially when there is an imbalance between Precision and Recall, offering a more
balanced evaluation of both.

Recall = TP / (TP + FN)                                      (9)

Precision = TP / (TP + FP)                                   (10)

F1-Score = 2 · (Precision · Recall) / (Precision + Recall)   (11)
In this context, TP represents the number of normal samples correctly identified as
normal, FN refers to the number of normal samples incorrectly classified as anomalous,
TN indicates the number of anomalous samples correctly identified as anomalous, and
FP denotes the number of anomalous samples incorrectly classified as normal.
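Under the paper's convention (positive = normal traffic), the three metrics follow directly from Eqs. (9)–(11); the counting sketch below is a generic illustration, not tied to any particular library:

```python
def classification_metrics(y_true, y_pred, positive=0):
    """Precision, Recall and F1-Score for the designated positive class,
    counted exactly as in Eqs. (9)-(11)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

Because F1 is the harmonic mean, it always lies between Precision and Recall, which makes it a useful sanity check on reported tables.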

4.3 Experimental Process and Result Comparison

In this section, we validate the feasibility of the proposed method based on three exper-
imental scenarios: (1) binary classification tasks to determine whether samples belong
to normal traffic or attack traffic; (2) multi-class classification tasks to identify whether
samples are normal traffic or a specific type of attack traffic; and (3) detection of unknown
category malicious traffic to assess the model’s capability to detect zero-day attacks.
Given that the CICIoT2023 dataset is relatively new and currently lacks related
research papers, we conducted experiments using four classic machine learning meth-
ods on this dataset, comparing their performance with the proposed method. The machine
learning methods employed include Logistic Regression, Random Forest, Adaptive
Boosting (AdaBoost), and Perceptron. Through comparative experiments, we analyze
the performance of different methods on this dataset to further validate the effectiveness
and advantages of the proposed approach.

Comparison with Existing Binary Classification Research Methods. To validate
the effectiveness of our proposed method, we conducted a comparative analysis against
traditional machine learning and deep learning approaches. The experimental results are
presented in Table 2, showing that our method achieved a precision of 96.64%, a recall of
96.84%, and an F1-score of 96.74% on the CICIoT2023 dataset, all of which are higher
than those of the traditional machine learning methods. Similarly, on the UNSW-NB15
dataset, our method demonstrated excellent performance, with a precision of 93.73%, a
recall of 96.23%, and an F1-score of 94.98%.
Although our method’s recall is slightly lower than that of the FSL-SCNN
method, both its precision and F1-score are superior, giving it better overall
performance than the other models. In sum, our proposed method exhibits
stronger stability and versatility in the binary classification detection
task, leading to enhanced detection performance.

Table 2. Comparison Results of Binary Classification Methods

Dataset Method Precision Recall F1-Score


CICIoT2023 Logistic Regression 86.31 89.04 91.72
Random Forest 96.53 96.51 96.28
Adaptive Boosting 96.56 94.73 95.62
Perceptron 82.54 79.70 81.05
Our model 96.64 96.84 96.74
UNSW-NB15 Tad-GAN(2020) [1] 86.81 89.91 88.33
FSL-SCNN(2021) [6] 90.61 96.81 93.61
MSRC(2023) [2] 90.12 91.79 90.95
Our model 93.73 96.23 94.98

The Comparison of Multi-classification Methods. In this section, this paper further
conducts multi-classification experiments and compares the proposed method with other
multi-classification approaches, aiming to evaluate its effectiveness in identifying various
attack categories. The experimental results, as shown in Table 3, demonstrate that the
proposed method achieves the best detection performance on the CICIoT2023 dataset.
In contrast, traditional machine learning methods perform poorly, mainly due to the
imbalance in network traffic data, where the disparity in sample sizes across different
categories leads to the model’s tendency to predict the majority class, thus reducing the
classification performance for minority classes. This paper effectively addresses the data
imbalance issue through a few-shot learning strategy, significantly enhancing detection
performance.
On the UNSW-NB15 dataset, the proposed method also performs exceptionally
well, achieving a precision of 90.23%, a recall of 89.64%, and an F1 score of 89.93%,
surpassing other models across all metrics. This demonstrates that the proposed method
not only excels in binary classification tasks but also achieves outstanding performance
in multi-classification tasks.

Table 3. Comparison Results of Multi-Classification Methods

Dataset Method Precision Recall F1 Score


CICIoT2023 Logistic Regression 51.24 69.60 53.94
Random Forest 81.54 91.00 86.12
Adaptive Boosting 46.49 48.77 36.86
Perceptron 52.39 65.91 55.51
Our model 93.23 93.24 93.24
UNSW-NB15 CNN-BiLSTM(2020) [3] 82.63 79.91 81.25
FCWGAN-BiLSTM(2022) [4] 86.11 85.57 85.84
GMM-WGAN(2023) [7] 88.46 87.70 85.44
VAE-CWGAN(2024) [5] 89.22 87.58 88.39
Our model 90.23 89.64 89.93

Unknown Class Traffic Detection. This paper not only evaluates the effectiveness of
the proposed method in classifying known attack traffic categories but also further tests
the model’s performance in detecting unknown class malicious traffic. To achieve this,
the paper selects the detection rate as the core metric for assessing the effectiveness of
unknown class malicious traffic detection. The detection rate focuses on the proportion
of correctly identified attack samples by the model, making it particularly suitable for
evaluating the model’s performance against unknown class attacks. For unknown attack
traffic, the key is whether the model can effectively recognize these novel attacks. There-
fore, using the detection rate as an evaluation metric is more direct and appropriate. The
calculation formula for the detection rate is as follows:
Detection rate = TN / (TN + FP)                              (12)
In the CICIoT2023 dataset, five attack categories are defined as known attack cate-
gories and combined with normal traffic data to form the training set for model training.
At the same time, one majority class attack category (DoS) and one minority class attack
category (Spoofing) are defined as unknown attack categories, which are added to the
test set to simulate zero-day attacks. In the UNSW-NB15 dataset, seven attack categories
are selected as known attack categories and combined with normal traffic data to create
the training set. Additionally, one majority class attack category (Backdoor) and one
minority class attack category (Fuzzers) are defined as unknown attack categories and
included in the test set.
In the training set, the number of samples for each known attack category is set to 20,
and it does not contain any unknown class network traffic. Unknown attack categories
are included in the test set to validate the proposed method’s ability to detect zero-day
attacks. The configurations for the training and test sets are shown in Table 4.
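The setup above amounts to withholding the designated classes from training and scoring them with the detection rate of Eq. (12). A schematic sketch follows; the split helper and dataset layout are hypothetical, and the detector itself is out of scope here:

```python
def detection_rate(is_attack_true, is_attack_pred):
    """Eq. (12) in the paper's notation: TN counts attack samples correctly
    flagged as attack, FP counts attack samples missed by the detector."""
    tn = sum(1 for t, p in zip(is_attack_true, is_attack_pred) if t and p)
    fp = sum(1 for t, p in zip(is_attack_true, is_attack_pred) if t and not p)
    return tn / (tn + fp) if tn + fp else 0.0

def zero_day_split(dataset, unknown_classes):
    """Withhold the designated unknown classes from training; the test set
    keeps them to simulate zero-day attacks."""
    train = [s for s in dataset["train"] if s[1] not in unknown_classes]
    return train, dataset["test"]
```

Keeping the unknown classes entirely out of training is what makes the resulting detection rate a fair proxy for zero-day performance.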

Table 4. Distribution of Experimental Data for Unknown Class Traffic Detection.

Database Known Category Traffic Category Training Set Test Set


CICIoT2023 Known Benign 20 800
DDoS 20 1600
Mirai 20 800
Recon 20 400
Web 20 100
BruteForce 20 100
Unknown DoS 0 800
Spoofing 0 400
UNSW-NB15 Known Benign 20 1600
Generic 20 1600
Exploits 20 800
Recon 20 800
Shellcode 20 400
Analysis 20 400
DoS 20 100
Worms 20 100
Unknown Backdoor 0 800
Fuzzers 0 100

The experimental results are shown in Table 5. The method presented in this paper
achieved detection rates of 89.73% and 86.74% for the unknown attack categories DoS
and Spoofing on the CICIoT2023 dataset, respectively, demonstrating good detection
performance. In the UNSW-NB15 dataset, the detection rates for the unknown attack
categories Backdoor and Fuzzers were 85.88% and 83.29%, respectively, also showing
satisfactory detection effects. These results indicate that the method proposed in this
paper can effectively detect new types of attacks when facing unknown malicious traffic.

Table 5. Detection Results for Unknown Traffic Categories.

Database Known Category Traffic Category Detection Rate


CICIoT2023 Known Benign 98.88%
DDoS 98.78%
Mirai 95.12%
Recon 96.54%
Web 91.32%
BruteForce 90.13%
Unknown DoS 89.73%
Spoofing 86.74%
UNSW-NB15 Known Benign 97.87%
Generic 98.21%
Exploits 93.51%
Recon 93.47%
Shellcode 87.67%
Analysis 88.79%
DoS 86.99%
Worms 90.12%
Unknown Backdoor 85.88%
Fuzzers 83.29%

5 Conclusion
This paper proposes a novel unknown malicious traffic detection method based on con-
trastive learning. It not only effectively addresses the issue of class imbalance in traffic
datasets but also achieves precise detection of unknown malicious traffic. The method
was evaluated for binary and multi-class tasks using the CICIoT2023 and UNSW-NB15
datasets, and its advantages were verified through three key metrics: precision, recall,
and the F1 score. Compared with other methods, the proposed approach demonstrated
superior detection performance.
Future research will focus on enhancing the model’s applicability across various sce-
nario datasets to improve its generalization capabilities. Additionally, further increasing
the detection rate of unknown malicious traffic will be an important direction for future
study.

Acknowledgments. This work is supported, in part, by the National Natural Science Foundation
of China Grant No. 62172292 and 62472229.

Disclosure of Interests. The authors have no competing interests to declare that are relevant to
the content of this article.

References
1. Geiger, A., Liu, D., Alnegheimish, S., et al.: TadGAN: time series anomaly detection using
generative adversarial networks. In: 2020 IEEE International Conference on Big Data (Big
Data). IEEE, pp. 33–43 (2020)
2. Duan, X., Fu, Y., Wang, K.: Network traffic anomaly detection method based on multi-scale
residual classifier. Comput. Commun. 198, 206–216 (2023)
3. Jiang, K., Wang, W., Wang, A., et al.: Network intrusion detection combined hybrid sampling
with deep hierarchical network. IEEE Access 8, 32464–32476 (2020)
4. Ma, Z., Li, J., Song, Y., et al.: Network intrusion detection method based on FCWGAN and
BiLSTM. Comput. Intell. Neurosci. 2022(1), 6591140 (2022)
5. Liu, T., Fu, Y., Kun, W., et al.: Network intrusion detection method based on VAE-CWGAN
and fusion of statistical importance of feature. J. Commun./Tongxin Xuebao 45(2) (2024)
6. Zhou, X., Liang, W., Shimizu, S., et al.: Siamese neural network based few-shot learning
for anomaly detection in industrial cyber-physical systems. IEEE Trans. Industr. Inf. 17(8),
5790–5798 (2020)
7. Cui, J., Zong, L., Xie, J., et al.: A novel multi-module integrated intrusion detection system
for high-dimensional imbalanced data. Appl. Intell. 53(1), 272–288 (2023)
8. Chen, T., Kornblith, S., Norouzi, M., et al.: A simple framework for contrastive
learning of visual representations. In: International Conference on Machine
Learning, pp. 1597–1607. PMLR (2020)
9. Hjelm, R.D., Fedorov, A., Lavoie-Marchildon, S., et al.: Learning deep
representations by mutual information estimation and maximization. In:
International Conference on Learning Representations (2018)
10. Bachman, P., Hjelm, R.D., Buchwalter, W.: Learning representations by maximizing
mutual information across views. In: Neural Information Processing Systems (2019)
11. Oord, A., Li, Y., Vinyals, O.: Representation learning with contrastive predictive
coding. arXiv preprint (2018)
12. Hénaff, O.J., Srinivas, A., De Fauw, J., et al.: Data-efficient image recognition
with contrastive predictive coding. arXiv preprint (2019)
13. Liu, X., Zhang, F., Hou, Z., et al.: Self-supervised learning: Generative or contrastive. IEEE
Trans. Knowl. Data Eng. 35(1), 857–876 (2021)
14. Jaiswal, A., Babu, A.R., Zadeh, M.Z., et al.: A survey on contrastive self-supervised learning.
Technologies 9(1), 2 (2020)
15. Zhang, Y., Hooi, B., Hu, D., et al.: Unleashing the power of contrastive self-supervised visual
models via contrast-regularized fine-tuning. Adv. Neural. Inf. Process. Syst. 34, 29848–29860
(2021)
SoftPromptAttack: Advancing Backdoor
Attacks in Language Models Through
Prompt Learning Paradigms

Dixuan Chen, Hongyang Yan(B) , Jiatong Lin, Fan Chen, and Yu Cheng

School of Artificial Intelligence, Guangzhou University, Guangzhou 510700, China


[email protected], {linjiatong,2112306111}@e.gzhu.edu.cn

Abstract. Backdoor attacks on language models are a technique that
enables models to establish strong correlations between “poison samples”
and the “target class” expected by attackers. In recent years, as the
cat-and-mouse game between backdoor attacks on language models and
their corresponding defense techniques has evolved, various types of
backdoor attack methods, such as input triggers, prompt triggers,
instruction triggers, and example triggers, have achieved good performance.
However, existing methods suffer from issues such as triggers being easily
detected and low accuracy in predicting clean samples. To address these
issues, this paper introduces a novel backdoor attack approach grounded
in prompt learning. By employing the soft prompt template as both a
trigger and a tool for optimization, our method identifies the most effec-
tive soft prompts for diverse sample categories, achieving stealthy and
potent backdoor attacks. Our experiments indicate that this approach
outperforms existing methods in terms of effectiveness.

Keywords: Language Model · Prompt Learning · Clean Label ·


Prompt As Trigger · Backdoor Attack

1 Introduction

A language model backdoor attack is a technique where attackers maliciously
implant a “backdoor” into a model by constructing “triggers”. The poisoned
model performs almost identically to a clean model on clean samples; however,
when attackers feed toxic samples containing triggers into the model, the back-
door is activated, producing the results desired by the attacker. Most existing
methods that show good performance achieve backdoor implantation by explic-
itly inserting triggers into the training set or modifying the original labels of the
samples. However, these methods often lead to unnatural text, reduced fluency,
and high visibility of triggers, making them more susceptible to detection by
defense algorithms and significantly lowering the success rate of attacks. This
makes deploying them in real-world attack scenarios challenging. To address the
issue of evading detection, Zhao et al. [1] recently proposed a backdoor attack
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 246–257, 2025.
https://doi.org/10.1007/978-981-96-4566-4_17
Advancing Backdoor Attacks in LM Through Prompt Learning 247

technique based on prompt learning. Unlike previous methods, this approach
utilizes the widely adopted prompt learning paradigm, directly employing the
prompt template as the trigger during the training phase. Since no modifications
are made to the original samples and the sample labels remain unchanged, this
method can effectively resist detection by defense systems.
However, evading detection by defense algorithms is only part of the picture:
the key metrics for evaluating the success of backdoor attacks are the attack
success rate (ASR) and clean sample prediction accuracy (CA). Higher values
for these two indicators signify a more efficient and covert backdoor attack,
which significantly enhances the algorithm’s effectiveness in real-world
deployment. This article proposes a backdoor attack method based
on an adjustable and optimized soft prompt template as a trigger. The goal
of this method is to improve the success rate of backdoor attacks on language
models while maintaining the same clean sample prediction accuracy as regular
models, thereby meeting the practical requirements of various application sce-
narios. This approach uses a randomly initialized, parameterized soft prompt
template as a trigger, which is further categorized into clean and poisoned soft
prompt templates. Our objective is to use the poisoned prompt templates to
influence the output of language models. We conducted extensive experiments
to explore the backdoor attack on our language model based on prompt learning.
The results demonstrated that the soft prompt template itself can serve as an
effective trigger for the language model, with an attack success rate approaching
100%. The main contributions of this article are summarized as follows:
– We propose a novel clean-label language model backdoor attack method based
on soft prompt templates, called SoftPromptAttack, which directly uses
soft prompts as triggers to inject a backdoor into the language model.
– SoftPromptAttack demonstrates strong performance in real-world language
model backdoor attack scenarios. Notably, our method achieves an attack suc-
cess rate close to 100% without explicitly inserting triggers or altering labels,
while also improving the accuracy of clean sample predictions compared to
poisoning label attack methods.
– SoftPromptAttack highlights the potential risk of the prompt learning
paradigm being vulnerable to backdoor attacks. Through this study, we aim
to raise awareness of the need to prevent backdoor attacks on language models
and encourage better development in the natural language processing com-
munity.

2 Related Work
2.1 Textual Backdoor Attack

The method of text-based backdoor attacks in language models typically occurs
during the data collection phase of training, where the training datasets are
maliciously modified to embed a backdoor into the model. These modifications are
248 D. Chen et al.

designed to manipulate the model’s behavior when specific triggers are intro-
duced during inference, while maintaining the model’s performance on clean
data. For instance, Li et al. [2] introduced a model layer poisoning technique
known as PTM, which utilizes multiple character combinations as triggers. PTM
employs a shared linear layer to compute the poisoning loss output across each
layer of the model, making the model more sensitive to the poisoned data. In a
similar vein, Yang et al. [3] proposed using character-based triggers to poison the
training datasets, with particular focus on the model’s word embedding layer.
Their goal was to ensure that the trigger’s word embedding representation aligns
with the expected output when the trigger is activated. Additionally, Li et al. [4]
proposed two distinct attack strategies. The first method, known as character-
level perturbation, replaces certain characters in the text with visually similar
ones, a technique called homomorphic substitution, to serve as triggers. The
second method involves generating sentences using language models that closely
resemble the original text but contain subtle differences, which then function as
triggers. Zhang et al. [5] expanded on this concept by using a combination of
words and conjunctions to form more complex and varied triggers, enhancing the
stealthiness and effectiveness of the attack. Furthermore, Pan et al. [6] suggested
a more direct approach by either inserting or replacing specific words or phrases
directly in the text to create triggers, further diversifying the range of backdoor
attack techniques.

2.2 Prompt Learning


Prompt-based learning is a new paradigm for language model training that
originated with GPT-3 (Brown et al. [7]). Its underlying idea is that
ultra-large-scale models can fully exercise their reasoning and comprehension
abilities as long as they are paired with appropriate templates. Therefore,
unlike the traditional “pre-training and fine-tuning” paradigm (Radford et
al. [8]; Peters et al. [9]), this paradigm no longer adapts the pre-trained
model to downstream tasks through objective engineering. Instead, the input
and the “prompt” are combined into a template that serves as the model’s
overall input, allowing downstream tasks to adapt to the pre-trained model.

2.3 Backdoor Attack Based on Prompt Learning


The idea of backdoor attacks on language models based on prompt learning is to
maliciously tamper with the prompt templates themselves. Du et al. [10]
proposed a trainable prompt template (soft prompt) that embeds a backdoor
into the prompt “itself”. Specifically, rare words such as “CF” are added as
triggers to some samples of the training set to form toxic samples, whose
labels are changed to the expected target class, while the remaining samples
are left untouched as clean data. During the training phase, the model
parameters are frozen and the learnable prompt templates are trained. Drawing
on this indirect attack idea, Cai et al. [11] proposed a backdoor attack
called BadPrompt, further extending the attack possibilities of Du et
al. [10]’s work to few-shot settings.

Perez et al. [12] study whether malicious prompts can affect the output of the
model and induce it to leak the original prompt, proposing two attack methods:
“target hijacking” and “prompt leakage”. Considering that prompt-based
learning constructs prompt templates as part of the model input, Zhao et
al. [1] proposed a clever attack method that uses the prompt template itself
as a trigger: a specific prompt template is constructed for target-class
samples to serve as toxic samples, and another specific prompt template is
used for the remaining samples to serve as clean samples. That paper is also
the first to explore clean-label backdoor attacks on language models under the
prompt learning paradigm. Xu et al. [13] used multiple toxic instructions to
conduct backdoor attacks on the model without modifying the original text or
labels: as long as the model receives a toxic instruction, any input content
is ignored and the target class label expected by the attacker is output. Wang
et al. [14] proposed using adversarial examples as “demonstrations” to mislead
models into making incorrect predictions.
In summary, the language model backdoor attacks discussed above suffer from
two main problems. First, most existing methods tamper with the original
samples or labels during the data preprocessing stage. Such backdoor attacks
not only destroy the semantics of the original samples but are also easily
detected by defense algorithms, which makes them difficult to deploy in
practical applications. Second, even when the samples are left intact during
data processing, most language model backdoor attacks cannot achieve a high
attack success rate and high clean sample prediction accuracy at the same time.
Therefore, to address both problems, this article proposes the
SoftPromptAttack method, which preserves the efficiency of backdoor attacks
while improving the prediction accuracy on clean samples, achieving a high
attack success rate and strong concealment simultaneously. The SoftPromptAttack method
represents a significant step forward in language model security, offering a more
efficient and stealthy way to implement backdoor attacks without compromising
the overall performance of the model. By maintaining high prediction accuracy
for clean samples while achieving a near-perfect attack success rate, this method
holds great promise for both research and practical applications in the field of
AI security.

3 Proposed Method
3.1 Prompt Engineering
The goal of Prompt Engineering (PE) is to find the most suitable template for
the original sample; the underlying idea is to let downstream tasks adapt to
pre-trained language models. These manually crafted prompts can effectively
guide the model to make appropriate and accurate predictions. For example, in
‘The sentiment of the following sentence is <mask>: I love this movie!’, the
statement before the colon is the ‘prompt’, and the model makes its prediction
at the <mask> position.
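Constructing such a hard prompt is plain string templating; a minimal illustration (the template wording follows the example above, and the function name is hypothetical):

```python
def apply_hard_prompt(text,
                      template="The sentiment of the following sentence is <mask>: {x}"):
    """Wrap a raw sample in a human-readable prompt; the model then predicts
    a label word at the <mask> position."""
    return template.format(x=text)
```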

3.2 Hard Prompt and Soft Prompt


The prompts mentioned above are discrete (known as hard prompts), which have
the advantage of being understandable to humans as natural language. In
practice, however, such prompts are often not optimal: constructing suitable
statements requires considerable experience and linguistic expertise, and when
dealing with complex tasks, searching for the best discrete prompt can easily
get stuck in local optima.
Since the purpose of PE is to find a way for the LM to perform tasks
effectively, rather than to produce text for human consumption, there is no
need to limit prompts to natural language that humans can interpret. Hence
there are also methods using so-called soft prompts, which operate directly in
the embedding space. A soft prompt is no longer natural language, removing the
constraint that templates be built from the pre-trained LM’s vocabulary.
Instead, soft prompts have their own parameters, which can be tuned on
training data from downstream tasks.
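A soft prompt can be pictured as a freely trainable matrix prepended to the token embeddings. The numpy sketch below is illustrative only; the prompt length and embedding dimension are made-up hyperparameters, not values from the paper:

```python
import numpy as np

class SoftPrompt:
    """A (length x dim) block of free parameters prepended in embedding space."""

    def __init__(self, length=10, dim=768, seed=0):
        rng = np.random.default_rng(seed)
        self.params = rng.normal(scale=0.02, size=(length, dim))  # trainable

    def prepend(self, token_embeddings):
        """(seq_len, dim) -> (length + seq_len, dim); the prompt rows are
        optimized while the token embeddings come from the frozen LM."""
        return np.concatenate([self.params, token_embeddings], axis=0)
```

Because the prompt lives in embedding space rather than in the vocabulary, its rows need not correspond to any real token.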

3.3 Backdoor Attack Based on Soft Prompt


Specifically, as shown in Fig. 1, our method uses learnable soft prompts
themselves as triggers to construct a “shortcut” between the soft prompt
trigger and the target class, without requiring additional explicit triggers
or modifying the original sample labels, thereby influencing the predictive
behavior of the victim language model. During the training phase, a subset of
samples labeled as the target class is concatenated with poison soft prompt
templates, while the remaining samples are concatenated with clean soft prompt
templates. These two types of processed training samples then serve as the
overall input for prompt learning, yielding the victim language model. During
the testing phase, when the attacker wishes to activate the backdoor, they
only need to concatenate the poison soft prompt onto the test sample; when
they wish to hide the backdoor, they concatenate the clean soft prompt
instead, and the model behaves normally. Two points are worth noting:
First, the soft prompt template is randomly initialized into two categories:
a poison soft prompt and a clean soft prompt. The length of the template and
the dimension of the vectors within it are both adjustable hyperparameters.
Second, the dataset is divided into poisoned samples and clean samples
according to the poisoning rate, and the corresponding randomly initialized
template is attached to each sample. The two types of samples, spliced with
their prompt templates, are then fed into the victim model separately, and the
output probability distribution is obtained through the softmax function. By
fixing the parameters of the victim model and optimizing only the parameters
of the two soft prompt templates according to the following backdoor attack
objectives (1) and (2), the backdoor injection into the model is completed.

θpromptclean = argmin_θ E(X,Y)∈Dclean L(F(Xclean, promptclean), Y)        (1)

θpromptpoison = argmin_θ E(X,Y)∈Dpoison L(F(Xpoison, promptpoison), Yt)   (2)

where promptpoison represents the prompt used as the trigger, promptclean


denotes the prompt for clean samples, Y represents the original label of the clean
sample, while Yt represents the target label and L(·) denotes the cross-entropy
loss.
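Objectives (1) and (2) share one loop: the victim model stays frozen and only the prompt parameters receive gradients, with the clean prompt trained toward Y and the poison prompt toward Yt. The sketch below illustrates this on a toy frozen linear scorer W; everything here, including the mean-pooling readout, is an assumption for illustration, not the paper's architecture:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def train_soft_prompt(W, inputs, labels, prompt, lr=0.1, epochs=50):
    """Optimize ONLY the soft-prompt parameters against a frozen scorer W.
    Objectives (1) and (2) use this same loop with different target labels:
    the original label Y for the clean prompt, the target label Yt for the
    poison prompt."""
    for _ in range(epochs):
        for x, y in zip(inputs, labels):
            h = np.concatenate([prompt, x], axis=0).mean(axis=0)  # pooled input
            p = softmax(W @ h)
            p[y] -= 1.0                       # d(cross-entropy)/d(logits)
            grad_h = W.T @ p                  # backprop through frozen W
            n_rows = prompt.shape[0] + x.shape[0]
            prompt -= lr * grad_h / n_rows    # gradient reaches only the prompt
    return prompt
```

Running the loop twice, once per prompt with the corresponding labels, reproduces the clean/poison split of objectives (1) and (2).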

Fig. 1. In our backdoor attack method, the soft prompt template itself serves as a
trigger, and the original sample labels are not tampered with. Green represents clean
soft prompt templates, while red represents toxic soft prompt templates. (Color figure
online)

4 Experimental Results
This section will introduce the experimental details of the proposed backdoor
attack method, including the datasets, implementation details, and evaluation
metrics. Then, the experimental results of this method will be compared with
other attack methods and analyzed in detail.

4.1 Experimental Details


Datasets and Implementation Details. In this experiment, we selected three
text classification datasets: the SST-2 sentiment classification dataset, the
OLID offensive language classification dataset, and the AG’s News topic
classification dataset. The first two are binary classification datasets, and
the third is a multi-class classification dataset. These text datasets contain
rich semantic knowledge, so they can comprehensively and effectively evaluate
the backdoor attack performance of the proposed method.
Our experiments were conducted on the BERT [15] model, in two versions:
BERT_base and BERT_large. All experiments in this section were run on an
NVIDIA RTX 3090 GPU and programmed in Python using the PyTorch framework.

Evaluation Metrics. To evaluate the performance of our backdoor attack
method, we used two metrics in the comparative experiments: Clean Accuracy
(CA), which measures the prediction accuracy of the backdoored victim
language model on clean test samples, and Attack Success Rate (ASR), which
measures the percentage of poisoned test samples for which the victim
language model predicts the target class.
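Both metrics reduce to accuracies over different test sets; a minimal sketch, where the `model` callable is a placeholder for the backdoored classifier:

```python
def clean_accuracy(model, clean_samples):
    """CA: fraction of clean test samples (x, y) the victim model predicts
    correctly."""
    return sum(model(x) == y for x, y in clean_samples) / len(clean_samples)

def attack_success_rate(model, poisoned_samples, target_label):
    """ASR: fraction of trigger-carrying samples mapped to the attacker's
    target class."""
    return sum(model(x) == target_label
               for x in poisoned_samples) / len(poisoned_samples)
```

A stealthy backdoor should keep CA at the level of an unpoisoned model while pushing ASR toward 100%.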

4.2 Backdoor Attack Results


Comparison With Baseline Models.
We classified backdoor attack methods into two categories: Poison Labels and
Clean Labels. Poison Labels includes BadNet (Gu et al. [16]), LWS (Qi et
al. [17]), and SynAttack (Qi et al. [18]); the main idea of these three
methods is to tamper with the text or its syntactic structure to achieve
backdoor injection. RIPPLES (Kurita et al. [19]) activates the model backdoor
by manipulating the language model’s parameter weights using rare words. BToP
(Xu et al. [20]) is a backdoor attack algorithm based on prompt learning.
Poison Labels methods tamper with the dataset labels, so we consider them more
prone to injecting backdoors into the model, but also less able to evade
defense algorithms or even manual inspection. Clean Labels includes
Triggerless (Gan et al. [21]), a backdoor attack method that does not rely on
triggers, and ProAttack, the first work to use prompt templates as triggers
and the main comparison object of this article. Clean Labels methods do not
tamper with sample labels, so injecting a backdoor is more difficult; at the
same time, they are highly concealed and can largely evade detection by
defense algorithms.
We believe that poison-label attacks are more likely to achieve a high ASR,
while our attack method belongs to the clean-label category. According to the
experimental results in Table 1, our attack method maintains an ASR
equivalent to that of the poison-label methods. Since our method
Advancing Backdoor Attacks in LM Through Prompt Learning 253

Table 1. Comparison of our method with the results of existing baseline models.

Dataset Model BERT_base BERT_large


CA ASR CA ASR
SST-2 Normal 91.79 – 92.88 –
Prompt 91.61 – 92.67 –
PoisonLabels BadNet 90.92 100 – –
SynAttack 90.94 98.1 – –
LWS 88.66 97.2 90.01 97.4
RIPPLES 90.73 100 91.6 100
BToP 91.32 98.68 92.64 99.89
CleanLabels Triggerless 89.78 98 90.87 99.1
ProAttack 91.68 100 93 99.92
OurAttack 93.55 99.82 94.65 99.96
OLID Normal 84.02 – 84.58 –
Prompt 84.57 – 83.87 –
PoisonLabels BadNet 82.01 100 – –
SynAttack 82.52 99.1 – –
LWS 82.95 97.1 81.4 97.9
RIPPLES 83.37 100 83.7 100
BToP 84.73 98.33 85.08 99.16
CleanLabels Triggerless 83.11 99 82.54 100
ProAttack 84.94 100 84.57 100
OurAttack 85.75 100 85.65 100
AG’s News Normal 93.72 – 93.60 –
Prompt 93.85 – 93.74 –
PoisonLabels BadNet 93.96 100 – –
SynAttack 94.33 100 – –
LWS 92.01 99.6 92.6 99.5
RIPPLES 92.37 100 91.6 100
BToP 93.45 91.48 93.66 7.74
CleanLabels Triggerless 92.58 92.87 90.15 99.1
ProAttack 93.5 99.54 93.8 98.72
OurAttack 96.26 99.76 93.83 98.69

does not tamper with the labels, which increases the difficulty of the backdoor
attack, achieving an ASR equivalent to that of the poison-label methods is an
acceptable result. It is worth noting that our method achieved the highest ASR
on the OLID dataset. For the CA metric, our method showed the best results
on all three datasets. This means that our method can improve CA to a certain
extent without sacrificing ASR, which greatly helps it evade detection by
defense algorithms. Comparing against clean-label methods of the same
category leads to similar conclusions: considering ASR and CA together, our
method shows the best performance among the methods of its category.

Comparison With Reference Methods. To further evaluate the effectiveness
of our method, we conducted another controlled experiment. We analyzed the
impact of the number of poisoned samples on the two evaluation metrics, CA
and ASR, and compared against the reference method ProAttack [1]. This
experiment was conducted on two datasets and two versions of the BERT
model.

Fig. 2. Comparison experiment with ProAttack, where the horizontal axis represents
the number of poisons administered and the vertical axis represents ASR.

As shown in Fig. 2, as the poisoning rate increases, our method's ASR curve
rises more quickly and steeply and has smaller variance than the baseline. This
means that our method can reach the same ASR with fewer poisoned samples
and can inject the backdoor into the model efficiently and stably, effectively
inducing the target predictions.

Effectiveness Against Backdoor Attack Defense Methods. To further
verify whether our backdoor attack method can be deployed effectively in
real-world scenarios, we evaluated its effectiveness against two commonly used
backdoor defense methods: ONION [22] and SCPD [23].

Table 2. The results of different defense methods against OurAttack.

Dataset Model BERT_base BERT_large


CA ASR CA ASR
SST-2 OurAttack 93.55 99.82 94.65 99.96
SCPD 76.65 43.43 77.72 34.48
ONION 90.36 74.91 92.12 81.95
OLID OurAttack 85.75 100 85.65 100
SCPD 74.81 98.98 77.39 98.93
ONION 86.58 99.86 84.53 99.62
AG’s News OurAttack 96.26 99.76 93.83 98.69
SCPD 78.95 42.18 79.44 21.05
ONION 93.46 98.12 95.03 75.28

The experimental results are shown in Table 2. Under SCPD, our method's
ASR drops to some degree on the SST-2 and AG's News datasets, but CA
remains at a good level. Under ONION, our method achieves over 70% ASR
while maintaining over 80% CA. These results indicate that our method can
effectively evade detection by defense algorithms, maintain a high attack
success rate, and retain good concealment, which we believe is very
advantageous in practical, public-facing settings.

5 Conclusion
In this article, we propose a novel language model backdoor attack method
based on soft prompt templates with clean labels. To inject a backdoor into
the language model, it directly uses soft prompts as triggers. Without explicitly
inserting triggers or altering clean labels, the method achieves an attack success
rate close to 100%, making it the highest-performing method in its category.

Limitations. Although our method has demonstrated strong performance, we
acknowledge that it still has limitations. First, our attack method has primarily
been validated on text classification tasks. Given the diversity of current lan-
guage processing tasks, its performance on other tasks remains to be tested.
Second, while our method has shown promising results on autoencoding models
such as BERT, its effectiveness on models from other categories still requires
further validation. Finally, as new backdoor attack defense methods emerge, the
effectiveness of our approach will need to be reassessed.

Acknowledgments. This work was supported by the National Natural Science Foun-
dation of China (No. 62372130) and the Guangzhou Basic and Applied Basic Research
Project (2023A04J1725, 2024A03J0397).

References
1. Zhao, S., et al.: Prompt as triggers for backdoor attack: examining the vulnerability
in language models. arXiv preprint arXiv:2305.01219 (2023)
2. Li, L., et al.: Backdoor attacks on pre-trained models by layerwise weight poisoning.
arXiv preprint arXiv:2108.13888 (2021)
3. Yang, W., et al.: Be careful about poisoned word embeddings: exploring the vulner-
ability of the embedding layers in NLP models. arXiv preprint arXiv:2103.15543
(2021)
4. Li, S., et al.: Hidden backdoors in human-centric language models. In: Proceedings
of the 2021 ACM SIGSAC Conference on Computer and Communications Security
(2021)
5. Zhang, X., et al.: Trojaning language models for fun and profit. In: 2021 IEEE
European Symposium on Security and Privacy (EuroS&P). IEEE (2021)
6. Pan, X., et al.: Hidden trigger backdoor attack on NLP models via linguistic style
manipulation. In: 31st USENIX Security Symposium (USENIX Security 2022)
(2022)
7. Mann, B., et al.: Language models are few-shot learners. arXiv preprint
arXiv:2005.14165, vol. 1 (2020)
8. Radford, A.: Improving language understanding by generative pre-training (2018)
9. Peters, M.E., et al.: Deep contextualized word representations. arXiv:1802.05365
(2018)
10. Du, W., et al.: PPT: backdoor attacks on pre-trained models via poisoned prompt
tuning. IJCAI (2022)
11. Cai, X., et al.: BadPrompt: backdoor attacks on continuous prompt. In: Advances
in Neural Information Processing Systems, vol. 35 (2022)
12. Perez, F., Ribeiro, I.: Ignore previous prompt: attack techniques for language mod-
els. arXiv preprint arXiv:2211.09527 (2022)
13. Xu, J., et al.: Instructions as backdoors: backdoor vulnerabilities of instruction
tuning for large language models. arXiv preprint arXiv:2305.14710 (2023)
14. Wang, J., et al.: Adversarial demonstration attacks on large language models. arXiv
preprint arXiv:2305.14950 (2023)
15. Devlin, J., Chang, M.-W., Lee, K., Toutanova, K.: BERT: pre-training of deep
bidirectional transformers for language understanding. In: Proceedings of
NAACL-HLT, vol. 1 (2019)
16. Gu, T., Dolan-Gavitt, B., Garg, S.: BadNets: identifying vulnerabilities in the
machine learning model supply chain. arXiv preprint arXiv:1708.06733 (2017)
17. Qi, F., et al.: Turn the combination lock: learnable textual backdoor attacks via
word substitution. arXiv preprint arXiv:2106.06361 (2021)
18. Qi, F., et al.: Hidden killer: invisible textual backdoor attacks with syntactic
trigger. arXiv preprint arXiv:2105.12400 (2021)
19. Kurita, K., Michel, P., Neubig, G.: Weight poisoning attacks on pre-trained models.
arXiv preprint arXiv:2004.06660 (2020)
20. Xu, L., et al.: Exploring the universal vulnerability of prompt-based learning
paradigm. arXiv preprint arXiv:2204.05239 (2022)
21. Gan, L., et al.: Triggerless backdoor attack for NLP tasks with clean labels. arXiv
preprint arXiv:2111.07970 (2021)
22. Qi, F., et al.: ONION: a simple and effective defense against textual backdoor attacks.
arXiv preprint arXiv:2011.10369 (2020)

23. Qi, F., et al.: Hidden killer: invisible textual backdoor attacks with syntactic trig-
ger. arXiv preprint arXiv:2105.12400 (2021)
Removing Regional Steering Vectors
to Achieve Knowledge Domain Forgetting
in Large Language Models

Wei Wu1,2,3, Chen Wang1,2,3(B), Qiuhao Xu1,2,3, and Wei Kong1,2,3

1 School of Information Science and Engineering (School of Cyber Science and
Technology), Zhejiang Sci-Tech University, Hangzhou 310018, China
[email protected], [email protected]
2 Zhejiang Key Laboratory of Digital Fashion and Data Governance, Zhejiang
Sci-Tech University, Hangzhou 310018, China
3 Zhejiang Provincial International Cooperation Base for Science and Technology on
Cloud Computing Security and Data Aggregation, Zhejiang Sci-Tech University,
Hangzhou 310018, China

Abstract. Large language models (LLMs), a significant direction in the
development of artificial intelligence in recent years, are becoming increasingly
popular. These models demonstrate significant potential in applications such
as personalized services, simulated dialogues, role-playing, and specific
compliance requirements. However, as the practical applications of LLMs
continue to expand, selective knowledge management, in particular the ability
to make a model "forget" certain domain- or topic-specific knowledge without
impacting its overall performance or its knowledge in other areas, has emerged
as a critical issue that demands resolution. We propose an innovative
knowledge-domain forgetting method designed to enable models to selectively
forget specified knowledge without fine-tuning. The method identifies and
removes steering vectors associated with specific knowledge domains, thereby
achieving effective forgetting of specific knowledge. The proposed approach
has been evaluated on several popular open-source LLMs. Experimental
results show that the method achieves good knowledge forgetting across
diverse scenarios and has notable practical value.

Keywords: Large Language Model · Model Unlearning · Steering Vector · Artificial Intelligence · Machine Learning

1 Introduction
LLMs are seeing increasingly broad applications in the field of artificial
intelligence [1,2], not only advancing the frontiers of technology but also
playing a crucial role in numerous practical scenarios, including personalized
services, simulated dialogues, role-playing, and specific compliance
requirements. As the capabilities of these models continue to expand, their
application contexts have become increasingly diverse, ranging from everyday
communication to various professional domains.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 258–270, 2025.
https://doi.org/10.1007/978-981-96-4566-4_18
However, in numerous practical applications, situations arise in which models
must eliminate specific knowledge domains, particularly in role-playing
scenarios and contexts involving sensitive information [3]. In these cases, the
model must not retain knowledge of certain domains: when playing a role, for
example, a model should not be able to respond in arbitrary languages, nor
should it allow uncontrolled disclosure of sensitive information. Achieving
precise domain-specific knowledge forgetting during the data-preprocessing
phase of model training is very challenging for developers. Likewise, training a
model comprehensively from scratch, decoupled from existing open-source
LLMs, to guarantee the erasure of specific knowledge domains poses
considerable obstacles for users: the expense of full-scale model training is
untenable for most stakeholders [3].
Recent work aims to achieve better control over the knowledge domains
within models through methods such as low-rank fine-tuning [4,5] and prompt
engineering [3,6]. These methods are generally more practical than fully
training LLMs. However, as LLM users' demands for cost control and effective
knowledge forgetting grow rapidly, the effectiveness of these methods becomes
very limited. Exploring other methods of knowledge forgetting therefore has
significant research value.
In this work, we illustrate the universality of the proposed method by using
knowledge-domain forgetting to complete two typical objectives. First, the
safety guardrails of LLMs are not built solely by adjusting the training data;
alignment methods also steer the model weights, so residual-stream
activations that tend to refuse to answer have a distinct regional direction. In
Objective 1, we therefore remove the refusal-prone knowledge domain from
the model. Second, many popular open-source LLMs are multilingual, and the
residual-stream activations of each language also exhibit regional direction
characteristics; in Objective 2, the proposed method removes specific language
knowledge domains from multilingual models. Third, the choice of model
layer had a significant impact in Objectives 1 and 2, so in Objective 3 we
conducted comparative experiments on the effects of processing different
hidden layers of the model.
The contributions of this work are as follows:

1) A method for forgetting specific knowledge domains in LLMs by removing
regional steering vectors is proposed; it requires few resources, achieves
significant results, and demonstrates universality across multiple open-source
LLMs.
2) The proposed method can be applied to multiple target forgetting scenarios
of LLMs.
3) The effects of the method on different models and different hidden layers
were tested, and the likely causes of the differences were analyzed.

2 Related Works
2.1 Hidden Layer
The hidden layer refers to the part of a neural network located between the input
layer and the output layer, used for feature extraction and nonlinear transfor-
mation. In the transformer architecture adopted by most LLMs [7], the hidden
layer contains millions or even billions of parameters, with each layer learning
and representing different features. Not all hidden layers are suitable for feature
extraction [8]. Lower layers are typically used for extracting simpler and more
concrete features, while higher layers are generally used for extracting more
abstract and complex features [9]. In the first few hidden layers, LLMs mainly
handle simple tasks; it is only in the deeper layers that large models can tackle
more complex tasks [10]. However, due to the "black box" nature of LLMs [11],
hidden layers remain insufficiently explored.

2.2 Steering Vector

Steering vectors refer to a method for controlling the behavior of LLMs by adding
vectors that encode the desired behavior to modify the internal activations of
the model [12]. Steering vectors can be used to define the latent guiding space
of sentences under a language model and for the precise generation of sentences
[13]. During training, steering vectors can be used to adjust the model weights,
making the model focus more on key features.
Being more economical than fine-tuning [14] and more robust than prompt
engineering [15], steering vectors are a very promising method for controlling
the behavior of LLMs, though their applications still lack sufficient
exploration [12]. After the existence and extractability of steering vectors in
LLMs was demonstrated [13], they have been widely used in adversarial retrieval
of unlearned information in LLMs [16] and in removing the refusal direction of
LLMs [17]. In the work of adversarial retrieval of unlearned information, the
required steering vectors were obtained by calculating the average difference
between anonymous prompts and original prompts and then adding the steering
vectors back into the model during inference to achieve adversarial attacks. In
the work of removing the refusal direction of LLMs, the required steering vec-
tors were similarly obtained by calculating the average difference between refusal
prompts and original prompts, and then removing the steering vectors from the
model weights during inference, thus removing the model’s safety guardrails.
Many modification methods based on steering vector technology have achieved
significant improvements over previous methods in various aspects. However,
there is still room for improvement in the general applicability of these methods.

3 Dataset Processing and Removal of Steering Vectors


In this section, we first introduce how to create the dataset required for the
experiments. Then, we provide a detailed explanation of how this method uses
these datasets to extract and remove steering vectors.

3.1 Creating the Dataset

We used the harmless dataset Dharmless compiled in [17]. From the first 80%
of this dataset, we randomly sampled 200 items as the general-knowledge
prompt dataset; the remaining 20% was used as the test dataset.
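The split described above can be sketched as follows (a minimal illustration; the function name, seed, and toy data are ours, not from the paper):

```python
import random


def split_prompts(dataset, n_sample=200, train_frac=0.8, seed=0):
    """Sample steering prompts from the first 80% of the dataset and
    keep the remaining 20% untouched as the test set."""
    cut = int(len(dataset) * train_frac)
    rng = random.Random(seed)
    steer = rng.sample(dataset[:cut], n_sample)
    test = dataset[cut:]
    return steer, test
```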
For the first objective, we attempted to remove the knowledge domains in
which LLMs tend to refuse to answer. When the model detects issues such as
privacy and security, it often responds with "Sorry, I can't..." or "It's
against..." and so on; we refer to these as refusal-prone responses. Such
responses share obvious common characteristics, so for this objective the
removal of steering vectors can serve as the knowledge-domain forgetting
method. We used the harmful dataset Dharmful compiled in [17]. From the
first 80% of this dataset, we randomly sampled 200 items as the refusal-prone
prompt dataset; the remaining 20% was used as the test dataset, as shown in
Fig. 1.

Fig. 1. Preparation of the prompt dataset for removing LLMs' tendency to refuse to
answer.

For the second objective, we attempted to remove the knowledge domains of
specific languages from LLMs. Since responses in the same language share
obvious common characteristics of that language's knowledge domain, this
objective can also use the removal of steering vectors as the knowledge-domain
forgetting method. We again used the harmless dataset Dharmless as our
general-knowledge prompt dataset. We then machine-translated this dataset to
create datasets with the same semantics expressed in different languages, each
containing 200 prompts, as the prompt datasets for the knowledge domains to
be forgotten. Here, we used Chinese and Japanese as examples, named
Dchinese and Djapanese, as shown in Fig. 2.

Fig. 2. Data Processing Workflow for Specific Language Domains. We strive to ensure
that each group covers as broad a range of domains as possible within its language.
This way, each group can present more obvious residual stream data differences in
subsequent stages, leading to clearer steering vectors for the areas that need to be
forgotten.

3.2 Extracting Steering Vectors

To extract potential steering vectors from the model's residual-stream data, we
processed the general-knowledge dataset and the to-be-forgotten-domain
dataset using the difference-of-means method [18], which effectively isolates
key feature directions. Steering vectors extracted in this way [19] represent the
knowledge domains to be forgotten with clear characteristics [17]. We
therefore process them as follows:

\mu_i^{(l)} = \frac{1}{|D_{remove}|} \sum_{t \in D_{remove}} x_i^{(l)}(t),    (1)

\nu_i^{(l)} = \frac{1}{|D_{general}|} \sum_{t \in D_{general}} x_i^{(l)}(t).    (2)

We compute the mean activation \mu^{(l)} over the prompt dataset of the
knowledge domain to be removed and the mean activation \nu^{(l)} over the
general-knowledge prompt dataset. Subtracting \nu^{(l)} from \mu^{(l)} yields
the difference-of-means vector r^{(l)} = \mu^{(l)} - \nu^{(l)}. The direction
and length of this vector represent the relative position of the knowledge
domain in the model's residual stream; it is the regional steering vector of that
knowledge domain.
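Equations (1)–(2) and the resulting regional steering vector can be sketched in PyTorch (the function name is ours; the inputs are assumed to be per-prompt residual-stream activations collected at a fixed layer):

```python
import torch


def extract_steering_vector(acts_remove, acts_general):
    """Regional steering vector for one layer via the difference-of-means method.

    acts_remove:  (N_r, d) activations for prompts from the domain to forget
    acts_general: (N_g, d) activations for general-knowledge prompts
    """
    mu = acts_remove.mean(dim=0)   # Eq. (1): mean activation over D_remove
    nu = acts_general.mean(dim=0)  # Eq. (2): mean activation over D_general
    r = mu - nu                    # difference-of-means direction
    return r / r.norm()            # unit-norm vector used for removal
```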

3.3 Removing Regional Steering Vector


We remove the influence of the acquired regional steering vector from the
model's weights so that the model can no longer successfully access the
removed knowledge. Specifically, through weight orthogonalization we project
the regional steering vector out of the original model weights, ensuring that
the modified weights no longer represent the content of that region. We use
the middle layers of the model for extracting and removing the regional
steering vectors, as these layers achieve a good balance between abstraction
and concreteness, allowing a comprehensive representation of the data. With
\hat{r} denoting the unit-norm regional steering vector, each weight output x
is updated as

x \leftarrow x - \hat{r}\hat{r}^{T} x.    (3)
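The weight orthogonalization of Eq. (3) can be sketched for a single matrix (a simplified illustration; in practice this update is applied to every weight matrix that writes into the residual stream):

```python
import torch


def orthogonalize_weight(W, r_hat):
    """Remove the unit steering direction r_hat from a weight matrix W of
    shape (d_model, d_in) whose outputs live in the residual stream.
    After the update, W can no longer write any component along r_hat."""
    return W - torch.outer(r_hat, r_hat @ W)
```

For any input x, the projection of the new output onto r_hat is exactly zero, which is what prevents the model from writing the forgotten region's direction into the residual stream.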

4 Experiments
To demonstrate that our method effectively enables the model to forget
certain knowledge domains, we conducted extensive experiments using the
previously prepared datasets. We further extended the experiments to more
steering-vector extraction layers and more open-source LLMs to show the
universality of our method.
We ran all experiments on an Ubuntu 22.04 server with a vGPU-32GB
graphics card, using Python 3.12, PyTorch 2.3.0, and CUDA 12.1. We used the
1.5B-parameter Instruct version of the open-source model Qwen2 [20] and the
3.8B-parameter Instruct version of the open-source model Phi-3-mini as test
language models. These models were chosen for their good multilingual
capabilities and parameter sizes suitable for testing; the technique of
extracting steering vectors itself is widely applicable across language models.

Table 1. Test results for removing the refusal-prone knowledge domain

Model Layer No-Refusal Rate Test Amount


Qwen2-1.5B-Instruct(Original) / 26.53% 980
Qwen2-1.5B-Instruct 13 88.16% 980
Qwen2-1.5B-Instruct 14 94.08% 980
Qwen2-1.5B-Instruct 15 95.71% 980
Qwen2-1.5B-Instruct 16 93.88% 980

Fig. 3. In the test, we used the same system prompt and user prompt to compare the
original version and the interfered version of the same model. It can be seen that in the
response of the original model, the model refused to answer; whereas in the response
of the interfered model, the model successfully completed the answer.

4.1 Objective 1: Removing Refusal-Prone Knowledge Domains

We used 100 entries from the Dharmful dataset as a test set for the processed
model. Refusal-prone responses often include expressions such as "I can't
provide", "I can't generate", "I can't create", "I can't write", "I can't help", "I
can't assist", "I can't for you", "I can't support", "I can't", "I won't", and "I
strongly oppose"; in this objective, we consider responses containing any of
these expressions to be refusal-prone.
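The counting rule above amounts to a substring check; a minimal sketch (the marker list follows the expressions quoted above, lower-cased, and straight apostrophes are assumed after normalization):

```python
REFUSAL_MARKERS = [
    "i can't provide", "i can't generate", "i can't create", "i can't write",
    "i can't help", "i can't assist", "i can't for you", "i can't support",
    "i can't", "i won't", "i strongly oppose",
]


def is_refusal(response):
    """A response counts as refusal-prone if it contains any marker expression."""
    text = response.lower()
    return any(marker in text for marker in REFUSAL_MARKERS)


def no_refusal_rate(responses):
    """Share of responses that are NOT refusal-prone (the rate reported in Table 1)."""
    return sum(not is_refusal(r) for r in responses) / len(responses)
```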
As shown in Table 1, the processed Qwen2 model achieved a significant
increase in the non-refusal response rate in several selected hidden layers com-
pared to the original model. Examples of responses are illustrated in Fig. 3.

Table 2. Test data for removing specific language knowledge domains

Model Layer Remove Language Target Language Rate


Qwen2-1.5B-Instruct(Original) / Chinese 87.82%
Qwen2-1.5B-Instruct 19 Chinese 76.92%
Qwen2-1.5B-Instruct 20 Chinese 70.33%
Qwen2-1.5B-Instruct 21 Chinese 7.87%
Qwen2-1.5B-Instruct 22 Chinese 4.24%
Qwen2-1.5B-Instruct 23 Chinese 21.48%
Phi-3-mini-4k-instruct(Original) / Japanese 76.81%
Phi-3-mini-4k-instruct 27 Japanese 35.33%
Phi-3-mini-4k-instruct 28 Japanese 34.85%
Phi-3-mini-4k-instruct 29 Japanese 2.08%
Phi-3-mini-4k-instruct 30 Japanese 4.83%
Phi-3-mini-4k-instruct 31 Japanese 5.99%

4.2 Objective 2: Removing Knowledge Domains of Specific Languages
We used 100 entries from Dchinese as the test dataset for the processed
model. During testing, we found that extracting and removing steering vectors
from the middle layers of the model failed to achieve the desired effect, so we
selected the deeper hidden layers of the test model as the position for
extracting and removing the regional steering vectors.
In addition, since a specific-language knowledge domain is more ambiguous
than the refusal-prone knowledge domain, we selected 64 prompts to extract
the steering vectors of the knowledge domain, thereby targeting the
specific-language knowledge-domain vectors more precisely.
As shown in Table 2, the processed Qwen2 and Phi-3 models show
significantly reduced response rates in the removed language at several
selected hidden layers compared to the original models. Examples of responses
are illustrated in Figs. 4 and 5.
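The target-language rate can be approximated with a character-class ratio. A rough sketch for Chinese (the paper does not specify its exact detector, so the threshold and character range here are our own illustrative choices):

```python
def chinese_char_ratio(text):
    """Share of CJK-ideograph characters among alphanumeric characters,
    a rough proxy for 'the model responded in Chinese'."""
    chars = [c for c in text if c.isalnum()]
    if not chars:
        return 0.0
    cjk = sum(1 for c in chars if "\u4e00" <= c <= "\u9fff")
    return cjk / len(chars)


def target_language_rate(responses, threshold=0.5):
    """Fraction of responses judged to be in the target language."""
    return sum(chinese_char_ratio(r) > threshold for r in responses) / len(responses)
```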

4.3 Testing the Impact of Different Hidden Layers on the Effectiveness of Knowledge Domain Forgetting
In the objective of forgetting the refusal-prone knowledge domain, the middle
layers of the model achieved better results. However, in the objective of
removing specific-language knowledge domains, only the deep layers of the
model achieved good results. We therefore conducted additional tests on the
effect of each hidden layer for the refusal-prone forgetting objective, as shown
in Fig. 6 and Table 3. We speculate that the middle layers work well here
because they strike a relatively balanced, comprehensive performance between
concrete and abstract semantics.

Fig. 4. Remove Specific Language Result. By asking both the original model and the
processed model “Please speak to me in ...(a certain language)” and checking the pro-
portion of specified language text in the returned text after removing punctuation, we
tested the Qwen2 model and the Phi-3 model. In our tests, the Qwen2 model’s response
rate in Chinese was 87.82% in its original state, which dropped to a minimum of 4.24%
after processing at the 22nd layer. The Phi-3 model’s response rate in Japanese was
76.81% in its original state, which dropped to a minimum of 2.08% after processing at
the 29th layer.

In the objective of forgetting specific language knowledge domains, the deep


layers of the model achieved better results. We speculate that this is because
the deep layers exhibit more profound semantics and are more closely related
to the model’s decoder output, which is more closely aligned with the language
knowledge domain of the model output, thereby yielding better results.

4.4 Analysis

Based on the above experiments, the method of removing regional steering
vectors achieves excellent results in making the model forget both the
refusal-prone knowledge domain and specific-language knowledge domains.
Although parameters such as the choice of model layer may need to be
adjusted for different objectives, owing to the unique characteristics of
different hidden layers, we believe that this method of making LLMs forget
knowledge domains by removing regional steering vectors is broadly applicable.

Fig. 5. In the test, we ran experiments on the interfered model. The multilingual
model Qwen2 [20], originally capable of English, Chinese, and Japanese, failed to
output content in Chinese when asked to respond in Chinese after removal of the
Dchinese Chinese knowledge domain; several such responses are marked with
triangles in the figure. Additionally, the model mistakenly believed it was responding
in Chinese.

Fig. 6. Qwen2-1.5B-Instruct no-refusal results. In the experiment, we conducted more
tests on the Qwen2 model for removing the refusal-prone knowledge domain at
different hidden layers. In the original model, the no-refusal response rate was 26.53%,
while after processing at the 23rd layer, the no-refusal response rate increased to a
maximum of 100%. The results also show that deeper layers achieved better
processing effects.

Table 3. Test results for removing the refusal-prone knowledge domain at different
hidden layers

Model Layer No-Refusal Rate Test Amount


Qwen2-1.5B-Instruct(Original) / 26.53% 980
Qwen2-1.5B-Instruct 1 45.71% 980
Qwen2-1.5B-Instruct 2 44.29% 980
Qwen2-1.5B-Instruct 3 24.29% 980
Qwen2-1.5B-Instruct 4 48.37% 980
Qwen2-1.5B-Instruct 5 44.08% 980
Qwen2-1.5B-Instruct 6 54.69% 980
Qwen2-1.5B-Instruct 7 35.10% 980
Qwen2-1.5B-Instruct 8 43.27% 980
Qwen2-1.5B-Instruct 9 40.20% 980
Qwen2-1.5B-Instruct 10 49.59% 980
Qwen2-1.5B-Instruct 11 84.69% 980
Qwen2-1.5B-Instruct 12 57.96% 980
Qwen2-1.5B-Instruct 13 88.16% 980
Qwen2-1.5B-Instruct 14 94.08% 980
Qwen2-1.5B-Instruct 15 95.71% 980
Qwen2-1.5B-Instruct 16 93.88% 980
Qwen2-1.5B-Instruct 17 98.37% 980
Qwen2-1.5B-Instruct 18 78.57% 980
Qwen2-1.5B-Instruct 19 96.73% 980
Qwen2-1.5B-Instruct 20 79.59% 980
Qwen2-1.5B-Instruct 21 100% 980
Qwen2-1.5B-Instruct 22 99.39% 980
Qwen2-1.5B-Instruct 23 92.04% 980

5 Conclusion
We investigate a method for achieving knowledge domain forgetting in LLMs
by removing regional steering vectors. Based on this discovery, we design a
processing pipeline of dataset design, steering vector extraction, and steering
vector ablation to achieve knowledge forgetting in LLMs more generally. The
proposed method not only consumes fewer resources than existing mainstream
methods but is also more broadly applicable than previous research on steering
vectors. Experimental results demonstrate the effectiveness and general
applicability of our proposed method. In future work, we will focus on the
applicability of this method to more knowledge forgetting scenarios and LLMs.
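The extraction step of the pipeline above is commonly realized as a difference of means between activations on prompts that elicit the target behavior and prompts that do not. A toy sketch under that assumption (the "activations" are synthetic arrays with a planted direction, not real model states):

```python
import numpy as np

def extract_steering_vector(acts_with, acts_without):
    """Difference-of-means steering vector: mean activation on prompts that
    exhibit the behavior minus mean activation on prompts that do not."""
    return acts_with.mean(axis=0) - acts_without.mean(axis=0)

rng = np.random.default_rng(5)
direction = rng.normal(size=16)              # hidden "behavior" axis (planted)
acts_with = rng.normal(size=(40, 16)) + 2.0 * direction   # behavior present
acts_without = rng.normal(size=(40, 16))                  # behavior absent

v = extract_steering_vector(acts_with, acts_without)
# The recovered vector should align closely with the planted direction.
cos = v @ direction / (np.linalg.norm(v) * np.linalg.norm(direction))
print(round(float(cos), 2))
```

Ablating the extracted `v` from the chosen layer's hidden states then implements the forgetting step.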
Removing Regional Steering Vectors 269

Acknowledgments. The work is supported by the National Key R&D Program of
China (No. 2023YFB2703700), the National Natural Science Foundation of China (Nos.
U21A20465, 62302457, 62172292), the Zhejiang Provincial Natural Science Foundation
of China (No. LQ24F020008), the Program for Leading Innovative Research Team of
Zhejiang Province (No. 2023R01001), the Fundamental Research Funds of Zhejiang
Sci-Tech University (No. 22222266-Y) and the “Pioneer” and “Leading Goose” R&D
Program of Zhejiang (No. 2023C01119).

A Novel and Efficient Multi-scale
Spatio-Temporal Residual Network
for Multi-class Intrusion Detection

Nan Li1,2, Zhaojian Gao3, Jiabin Ye1(B), Wei Tang4, Xun Che5, and Yadang Chen6

1 Nanjing Big Data Group Co., Ltd., Nanjing 210003, China
[email protected]
2 School of Computer Science and Engineering, Southeast University, Nanjing 210096, China
3 Nanjing Zhixun Interconnection Technology Co., Nanjing 211100, China
4 Hangzhou SecLead Information Technology Co., Ltd., Hangzhou 310056, China
5 Jiangsu Ruizhihe Information Technology Co., Wuxi 214434, China
6 School of Computer Science and Cyberspace Security, Nanjing University of Information Science and Technology, Nanjing 210044, China

Abstract. With the development of network devices, network traffic exhibits
high-dimensional, large-scale, and complex characteristics, while network
threats and attacks continue to intensify. Existing network intrusion detection
models tend to disregard the extraction and learning of the temporal features
of the data, which greatly affects their accuracy. To address the shortcomings
of single-structure models that cannot comprehensively learn features, we
propose a novel Transformer-based network intrusion detection method that
accounts for both the relative importance of traffic features and the temporal
and spatial characteristics of intrusion detection data. Specifically, we first
adopt the Transformer to perform feature extraction and construct global
correlations on the input data; we then use an improved Inception module to
extract multi-scale features and weight the spatial features at different scales
with a self-attention module; in addition, a BiGRU is employed to enhance the
temporal features. Finally, the proposed model is validated on the publicly
available CIC-IDS-2017 and CIC-DDoS-2019 datasets. On four evaluation metrics
it outperforms existing state-of-the-art models, and it shows significant
performance improvements in both binary and multi-class classification tasks,
demonstrating the efficiency and effectiveness of our method.

Keywords: network traffic · transformer · temporal features · intrusion detection

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 271–283, 2025.
https://doi.org/10.1007/978-981-96-4566-4_19
272 N. Li et al.

1 Introduction
Currently, network intrusion detection models are mainly used for binary classi-
fication studies of normal and anomalous network data, and less often for mul-
ticlassification studies of anomalous data. However, multiclassification network
intrusion detection models can provide more detailed warning information and
are able to identify specific types of attacks. This allows the system to take direct
countermeasures against specific attack types. Failure to intercept these attacks
in a timely manner will cause irreparable damage to network devices and data.
The quality of the dataset and its data distribution have a significant impact
on model performance: high-quality datasets and features lead to better-performing
trained models. Network traffic datasets suffer from extreme imbalance, which
is especially common in multiclassification problems, and this unbalanced
distribution may adversely affect subsequent training. In addition, traditional
feature selection methods tend to perform well on the features of the majority
classes while ignoring those of the minority classes, leading to poor training
results. Effective methods are therefore needed to address data imbalance and
ensure that the model can classify and predict all classes well. To address this
shortcoming and improve the recognition rate of minority classes, we propose a
simple and effective sampling method: we optimize the sample distribution by
generating samples for the imbalanced minority classes and adding them to the
dataset.
With the continuous development of intrusion detection models, the International
Internet Engineering Task Group (IIETG) developed intrusion detection system
models and formed the Intrusion Detection Working Group (IDWG). This group
developed the Common Intrusion Detection Framework (CIDF). The structure of
the CIDF is divided into four logical parts, each of
which is described by events. In the field of machine learning, especially in the
field of intrusion detection, network traffic data can be regarded as events in
CIDF. Aiming at the characteristics of network traffic data and the shortcom-
ings of most current models that have a single structure and ignore temporal
features, we propose a new strategy that integrates the degree of importance of
traffic features and the temporal and spatial characteristics of intrusion detection
data.
The main contributions of the recommended model can be summarized as
follows:
– We propose an improved auxiliary classifier GAN (ACGAN)-based sampling
method. Facing high-dimensional feature data, the previous method generates
low-quality samples and trains unstably. To solve these problems, we improve
its generator: we add a self-attention module so that the generator can better
handle local details and global structure; we add a BiGRU module to improve
the quality and continuity of the generated samples; and we build the generator
on a residual network structure to accelerate convergence.

– We propose a Transformer-based network intrusion detection method.
Specifically, the encoder part of the Transformer is first used to capture
global correlations; the improved Inception module then extracts multi-scale
features, and the spatial features at different scales are weighted with a
self-attention mechanism; a BiGRU enhances the model's ability to extract
temporal features; and finally, classification is performed by a Softmax
activation function.
– The proposed method obtains competitive results compared to state-of-the-art
methods on multiple datasets; in particular, in multi-class scenarios our
method achieves significant performance improvements. This demonstrates the
effectiveness and efficiency of the proposed framework.

2 Related Work
2.1 Intrusion Detection Methods Based on Deep Learning

With the rapid development of hardware technology and neural networks, pre-
viously established signature databases are unable to detect all forms of attacks,
especially new types of attack variants. Deep learning based intrusion detection
methods [4, 13, 14] have received extensive attention and research. [7] proposed
to convert the original one-dimensional network intrusion data from the KDD-
CUP 99 dataset and the NSL-KDD dataset into two-dimensional data, and then
used an optimized convolutional neural network to learn the effective features,
and finally carried out multi-classification detection by a softmax classifier, but
the few types of network traffic in this dataset could not adapt to the new chal-
lenges. [19] proposed a traffic intrusion detection model, BAT, which combines
Bidirectional Long Short-Term Memory (BiLSTM) with an attention mechanism; in
addition, multiple convolutional layers are used to capture the local features of
the traffic data, which effectively improves intrusion detection performance. In [20], Graph
Convolutional Network (GCN) is utilized to obtain the spatial correlation among
data flows, BiLSTM to obtain the temporal correlation among them, and an
attention mechanism to extract the key information from the data; however, this
model's binary classification accuracy on the dataset is relatively low and
needs further improvement. A multi-
structed by [22] and an LSTM is introduced into it to improve the local feature
extraction capability of the model. In order to efficiently and effectively extract
spatial and temporal features from network traffic data, improved residual net-
work blocks and Bidirectional Gated Recurrent Unit (BiGRU) for detection and
classification, respectively, were proposed in [23], where the use of a pooling
layer was introduced into the residual network block structure to extract spatial
features from the data. It was shown that the fusion of spatio-temporal features
can improve the detection performance of the model in general. [10] proposed
a layered construction of CNN and Transformer for detecting spatio-temporal
features in network traffic data. [21] proposed to utilize the parallel processing
capability of Transformer to accelerate the speed of feature learning and improve


the detection accuracy of network malicious attacks.
Most current research on network intrusion detection performs only simple binary
classification experiments on network traffic data, distinguishing normal traffic
from attack traffic. It has not yet comprehensively considered the fine-grained
classification of attack traffic, which would give security personnel a more
detailed understanding of attacks and allow them to prepare the corresponding
countermeasures. Moreover, most models in this field have a relatively single
structure and fail to learn data features comprehensively, extracting only
temporal features or only spatial features. In addition, the standard Inception
module lacks the ability to capture multi-scale features of the data.

2.2 Sample Equalization and Quality of Datasets


In network traffic intrusion detection, normal traffic data far outnumber attack
traffic data, and these datasets usually lack sample balancing. For the category
imbalance problem, [1] combines the Adaptive Synthetic Sampling Algorithm
(ADASYN) and the Repeated Edit Nearest Neighbor
Algorithm (RENN) to construct a balanced dataset. In [8], the SMOTE-Tomek
algorithm is utilized to balance the dataset, which effectively improves the recog-
nition rate of a few attack categories. It was experimentally verified in [11, 12]
that the generated synthetic data can produce the same effect as the real data
in the intrusion detection model without any effect on the real data. After that,
[5] uses SMOTE and OSS methods for increasing the minority class samples and
decreasing the majority class samples respectively, which gave better results in
terms of performance metrics. [17] propose a deep sampling method to select a
smaller, sensitive and representative subset of samples from multiple classes with
the help of deep learning, which reduces the impact of class imbalance. [9] com-
bines the random sampling method with the integration method to effectively
solve the category imbalance problem in intrusion detection. The GAN, proposed
by [2], is an adversarial network structure consisting of a generator and a
discriminator that are trained against each other to jointly improve the
generator's generative ability and the discriminator's recognition ability [16].
The generator attempts to trick the discriminator so that it cannot accurately
determine the authenticity of the generated samples, while the discriminator
tries to improve its ability to distinguish real from fake samples. [15] proposed cap-
sule generative adversarial networks, which improves the learning of unbalanced
data by incorporating a majority class distribution structure in minority class
samples, and utilizes a feature matching loss function to train the generator,
which effectively improves the convergence of training. In [6], the quality of the
generated samples is effectively improved by introducing a self-attention mecha-
nism to extract global features of the attack samples. However, there still exists
the problem of insufficient dynamic feature representation ability and residual
learning ability [18].
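The SMOTE-style oversampling discussed above can be sketched as interpolation between a minority sample and one of its nearest minority-class neighbors. This is a generic illustration with random placeholder data, not the cited algorithms' exact implementations:

```python
import numpy as np

def smote_like(minority, n_new, k=3, seed=0):
    """Generate n_new synthetic points by interpolating chosen minority
    samples toward one of their k nearest minority neighbors (SMOTE-style)."""
    rng = np.random.default_rng(seed)
    out = []
    for _ in range(n_new):
        i = rng.integers(len(minority))
        x = minority[i]
        d = np.linalg.norm(minority - x, axis=1)  # distances to other samples
        d[i] = np.inf                             # exclude the point itself
        j = rng.choice(np.argsort(d)[:k])         # one of the k nearest
        lam = rng.random()                        # interpolation factor [0, 1)
        out.append(x + lam * (minority[j] - x))
    return np.array(out)

rng = np.random.default_rng(1)
minority = rng.normal(size=(10, 5))               # toy minority-class records
synthetic = smote_like(minority, n_new=20)
print(synthetic.shape)  # (20, 5)
```

Each synthetic point lies on a segment between two real minority samples, which is what makes the method simple but limited on high-dimensional traffic features, motivating GAN-based alternatives.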

The continuous evolution of network technologies and attack methods has


made it impossible for features in traditional datasets to fully reflect new types
of attacks in the real world. While traditional machine learning methods have
been effective in dealing with unbalanced datasets, GAN as well as other deep
learning techniques can be utilized to extract richer feature representations.

3 The Proposed Methodology


3.1 An Efficient and Effective Sampling Method to Address
the Challenge of Unbalanced Data

Fig. 1. The diagram of the improved ACGAN. The first layer employed is the Fully
Connected Layer, which serves to mix the input noise and labels and unfold them in a
high-dimensional space to provide a good starting point for the subsequent generation
process. Immediately following are multiple Deconvolution Layers, which are used in
conjunction with the Batch Normalization and ReLU activation functions to progres-
sively enlarge the feature maps while maintaining the stability of the training and the
quality of the generated samples.

Unbalanced data distribution refers to a large difference in the proportion
of data volume between data categories in a dataset, and the imbalance rate
is usually used to indicate the degree of data imbalance in a dataset. In the
case of category imbalance, the number of samples in some categories is much
larger than others, which leads to the learner being more likely to be dominated
by the more numerous categories during the training process, thus making it
difficult to accurately learn the features and patterns of a few categories. To
address the data distribution imbalance problem that exists in the CIC-IDS-
2017 and CIC-DDoS-2019 datasets, an improved ACGAN [3] model is used to
solve the problem. The ACGAN model is introduced with the aim of improving
the data imbalance by using the powerful ability of generative adversarial net-
works to create synthetic data samples that are close to the real distribution,
thus complementing the minority class samples. Based on the original ACGAN
generator, the self-attention mechanism and BiGRU model are added after each
layer, respectively, and the jump links between layers are realized through the
residual block structure, as shown in Fig. 1. In addition to this, we propose a
corresponding discriminator, a structure that is a core component of the model’s
decision-making process. The structure consists of one fully connected layer and
three convolutional layers. A LeakyReLU activation function is used in its net-
work and a Dropout layer is added to prevent the model from overfitting.
The improved ACGAN model is used for training and data augmentation on the
training set, while its performance is verified on a test set without data
augmentation. The specific implementation steps are as follows. Data
preprocessing: key operations such as feature selection, character-to-numeric
conversion, data cleaning, and data normalization, which optimize the quality
and format of the data and provide a good foundation. Training ACGAN on the
training set: data from the minority sample classes of the preprocessed training
set are extracted and used to train the ACGAN, whose parameters are adjusted to
ensure effective convergence, so that it learns to generate high-quality data.
Data augmentation using the improved ACGAN: classes with large differences in
data volume are augmented to balance the dataset for subsequent training of the
feature extraction model for classification detection.
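The three steps above can be outlined as follows. Here min-max normalization stands in for the full preprocessing stage, and a simple noisy-resampling function stands in for the trained ACGAN generator; the feature and label arrays are toy placeholders, not the actual datasets:

```python
import numpy as np
from collections import Counter

def minmax_normalize(X):
    """Scale each feature column to [0, 1] (part of data preprocessing)."""
    lo, hi = X.min(axis=0), X.max(axis=0)
    return (X - lo) / np.where(hi > lo, hi - lo, 1.0)

def augment_to_balance(X, y, generate):
    """Top up every class to the size of the largest one.

    generate(X_cls, n) returns n synthetic samples for one class; in the paper
    this would be the trained ACGAN generator, here it is a stand-in."""
    counts = Counter(y.tolist())
    target = max(counts.values())
    X_parts, y_parts = [X], [y]
    for cls, n in counts.items():
        if n < target:
            X_cls = X[y == cls]
            X_parts.append(generate(X_cls, target - n))
            y_parts.append(np.full(target - n, cls))
    return np.concatenate(X_parts), np.concatenate(y_parts)

rng = np.random.default_rng(2)
X = minmax_normalize(rng.normal(size=(100, 4)))
y = np.array([0] * 80 + [1] * 15 + [2] * 5)     # imbalanced toy labels

# Stand-in generator: resample real minority rows with small noise.
gen = lambda X_cls, n: (X_cls[rng.integers(len(X_cls), size=n)]
                        + rng.normal(scale=0.01, size=(n, X_cls.shape[1])))
Xb, yb = augment_to_balance(X, y, gen)
print(Counter(yb.tolist()))   # every class topped up to 80 samples
```

Only the training split is passed through `augment_to_balance`; the test split stays untouched, matching the evaluation protocol described above.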

3.2 A Novel Multi-scale Intrusion Detection Framework with Spatio-Temporal Information
Given the specific needs of the network intrusion detection task and the
fixed-length nature of each record in the dataset, we chose to use only the
encoder part of the Transformer architecture and to fine-tune its parameters to
better suit the task. This part of the architecture consists of a multi-head
attention mechanism and a feed-forward neural network. The core of multi-head
attention is dot-product attention, which operates through three key elements:
queries, keys, and values. First, the dot product between the query and the keys
is used to compute scores that indicate the importance weight of each value.
These weights are then multiplied with the values, and the final output is
computed as a weighted sum. A significant advantage of dot-product attention is
that it enables efficient parallel computation, significantly reducing the
training time of the model. In addition, by performing multiple attention
operations in parallel in different representation subspaces, the multi-head
attention mechanism captures the intrinsic multi-dimensional features of the
data, enhancing the model's ability to synthesize information. This mechanism
not only improves performance but also increases the model's flexibility and
efficiency on complex datasets. The feed-forward neural network then further
processes the attention-weighted output, enhancing the model's expressive
ability through nonlinear transformation. Combining the multi-head attention
mechanism with the feed-forward network, the proposed model accurately
recognizes and processes complex patterns and relationships in network intrusion
data while maintaining a high training speed. By carefully designing and tuning
the Transformer encoder, our framework optimizes intrusion detection
performance: it not only significantly improves processing speed but also
ensures high accuracy and a low false alarm rate, substantially enhancing
network security defense capability.
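The dot-product attention described above (scores from query-key dot products, softmax weights, weighted sum of values) can be written compactly. This is a generic single-head sketch with random toy matrices, not the exact configuration used in the paper:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)   # (L_q, L_k) scores
    weights = softmax(scores)                        # each row sums to 1
    return weights @ V, weights                      # weighted sum of values

rng = np.random.default_rng(3)
L, d = 6, 8
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))
out, w = dot_product_attention(Q, K, V)
print(out.shape, w.shape)                 # (6, 8) (6, 6)
print(np.allclose(w.sum(axis=-1), 1.0))   # True
```

Multi-head attention simply runs this computation in parallel over several learned projections of Q, K, and V and concatenates the results.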

Fig. 2. The illustration of the proposed intrusion detection method

As shown in Fig. 2, an innovative improvement to the Inception architecture
introduces parallel 1 × 1, 3 × 1, and 5 × 1 convolutional kernels for
multi-scale feature extraction, combined with pooling operations, aiming to
optimize the network structure while reducing computational complexity. Using
1 × 1 convolutional kernels and a max-pooling layer to reduce the dimensionality
of high-dimensional input data not only lowers the number of parameters and the
computational requirements of the network but also maintains high accuracy and
helps capture richer local feature information. In addition, by placing
convolutional kernels of different sizes side by side, this structure learns
multi-scale features at the same level, enhancing the network's adaptability to
features of different scales and expanding the receptive field. This parallel
multi-scale learning strategy effectively reduces the number of parameters while
increasing the depth and width of the network, reduces the computational burden,
and significantly improves the expressive capability of the model.
To further optimize the model, a BN layer is introduced immediately after each
convolutional layer to control the gradient explosion problem and accelerate
training and convergence. The data are then passed through a ReLU activation
function to optimize the feature distribution and enhance stability and
efficiency. When processing sequence data, one-dimensional rather than
high-dimensional convolution is used to preserve the original structure and
continuity of the data, avoid information loss, and maintain recognition
accuracy. Through the improved Inception block, which juxtaposes convolutional
kernels of different scales on the same layer, the network captures multi-scale
features of the input data and learns both fine-grained and higher-level
features. The self-attention mechanism placed after the convolutional layers
highlights important features and improves recognition accuracy. The fused
output of the multi-scale convolutions and the self-attention mechanism is then
passed into a BiGRU, which captures dependency information from both the forward
and backward directions of the sequence, removes redundancy, and enriches the
feature information; this enhances the Inception module's ability to extract
global temporal features of network intrusion data and provides an effective
means of identifying the key features of intrusion behaviors, improving
detection accuracy. Finally, element-wise addition is used for feature fusion:
it enhances the information content of individual elements while keeping the
feature dimensions stable, effectively integrating features from different
sources and improving the accuracy and generalization ability of the model,
while reducing the number of parameters in subsequent operations and ensuring
the information density and computational efficiency of the output features.
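The parallel 1 × 1 / 3 × 1 / 5 × 1 branch of the improved Inception block can be illustrated with one-dimensional convolutions applied in parallel and stacked. The single-channel filters below are random stand-ins for learned kernels, and the record is a toy vector:

```python
import numpy as np

def conv1d_same(x, kernel):
    """1-D convolution with zero same-padding (output length == input length)."""
    k = len(kernel)
    pad = k // 2
    xp = np.pad(x, (pad, pad))
    return np.array([xp[i:i + k] @ kernel for i in range(len(x))])

def multi_scale_features(x, rng):
    """Parallel 1x1 / 3x1 / 5x1 branches, stacked along the feature axis."""
    branches = [conv1d_same(x, rng.normal(size=k)) for k in (1, 3, 5)]
    return np.stack(branches)            # (3, len(x)): one row per scale

rng = np.random.default_rng(4)
x = rng.normal(size=32)                  # one flattened traffic record
feats = multi_scale_features(x, rng)
print(feats.shape)  # (3, 32)
```

Because every branch uses same-padding, all scales stay length-aligned, which is what makes the later element-wise-addition fusion dimensionally consistent.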

4 Experiments
4.1 Configuration
Ubuntu is adopted as the operating system, and our method is implemented with
the PyTorch framework. To comprehensively evaluate the performance of the
proposed model, Accuracy, Precision, Recall, F1-Score, the Confusion Matrix,
and the Training Loss Curve are used as evaluation metrics.
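For reference, the four scalar metrics listed above can be derived directly from a confusion matrix. The sketch below uses macro-averaging over classes; the paper does not state its averaging mode, so that choice (and the toy matrix) is an assumption:

```python
import numpy as np

def metrics_from_confusion(C):
    """Accuracy plus macro-averaged precision, recall, and F1 from an
    n_classes x n_classes confusion matrix (rows = true, cols = predicted)."""
    tp = np.diag(C).astype(float)
    precision = tp / np.maximum(C.sum(axis=0), 1)   # per-class, guard /0
    recall = tp / np.maximum(C.sum(axis=1), 1)
    f1 = np.where(precision + recall > 0,
                  2 * precision * recall / np.maximum(precision + recall, 1e-12),
                  0.0)
    return {"accuracy": tp.sum() / C.sum(),
            "precision": precision.mean(),
            "recall": recall.mean(),
            "f1": f1.mean()}

# toy 3-class confusion matrix (not from the paper's experiments)
C = np.array([[50, 2, 3],
              [4, 40, 1],
              [5, 0, 45]])
m = metrics_from_confusion(C)
print(round(m["accuracy"], 4))  # 0.9
```

Accuracy here is the trace over the total count; precision and recall come from column and row sums respectively.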

4.2 Binary Classification Task


This section focuses on performing a binary classification task on the CIC-IDS-
2017 and CIC-DDoS-2019 datasets to explore the distinction between normal
and attack data in network traffic. Considering that the amount of normal and
attack traffic data in the binary classification task is comparable and there is
no need to rely excessively on data enhancement techniques, this experiment
decides not to use the improved ACGAN method for data enhancement. As
shown in Fig. 3, from which it can be seen that the detection rate for the attack
data is extremely high, where the left figure of the confusion matrix based on the
CIC-IDS-2017 dataset reveals the accuracy and the false alarm rate of the model
in identifying various types of network intrusion behaviors, and the right figure
demonstrates the model’s performance on the CIC-DDoS-2019 dataset. With
these confusion matrices, researchers can visually assess the model’s ability to
recognize different types of attacks, thus providing a comprehensive analysis of
the model’s accuracy and reliability.

Fig. 3. The visualization of Confusion Matrix on CIC-IDS-2017 and CIC-DDoS-2019
Datasets

As shown in Fig. 4, these curves provide a visualization of the impact of
learning rate settings on model training effectiveness, revealing how learning rate
magnitude affects model convergence speed and stability. The curve variations
show how the model performance changes when the learning rate is gradually
reduced, helping researchers to find the best learning rate settings to optimize
the model training process and achieve faster convergence speed and higher
model accuracy. These graphs and comparative experiments not only provide
valuable data and insights for research in the field of intrusion detection, but
also emphasize the importance of careful consideration of different parameter
settings during model design and training to ensure that the research can move
forward efficiently, and at the same time improve the validity and reliability
of the intrusion detection model in practical applications. As can be seen from
Fig. 4, training on both datasets is more stable and faster when the learning
rate is 0.001, and all indicators are higher than at the other learning rates.

Fig. 4. The top row shows the comparison of training loss curves for CIC-IDS-2017
dataset at learning rates of 0.01, 0.001 and 0.0001, and the bottom row shows the
comparison of training loss curves for CIC-DDoS-2019 dataset at learning rates of
0.01, 0.001 and 0.0001.

4.3 Multi-class Classification Task

Table 1. The comparison of performance with different learning rates

Datasets Learning rate Accuracy Precision Recall F1-Score
CIC-IDS-2017 0.01 61.01 28.60 24.07 22.64
0.001 95.66 92.44 78.47 83.18
0.0001 89.52 88.63 71.99 76.17
CIC-DDoS-2019 0.01 59.77 25.08 22.36 20.59
0.001 96.48 95.70 86.73 90.05
0.0001 94.10 90.23 82.33 85.55

To evaluate the performance of multiclassification, a ten-class task is performed
on the dataset CIC-IDS-2017 and a seven-class task is performed on the CIC-
DDoS-2019 dataset to explore the distinction between normal and various types
of attack data in network traffic. From Fig. 5 and Table 1, two important con-
clusions can be drawn: the Transformer model and the Inception model have
very important roles in this experimental model. First, the Transformer model
performs significantly in the intrusion detection task. The Transformer model
utilizes the self-attention mechanism to better model and understand complex
patterns in network traffic by being able to capture longer-range dependencies
when processing sequential data. This allows the Transformer model to more

Fig. 5. The top row shows the comparison of training loss curves for CIC-IDS-2017
dataset at learning rates of 0.01, 0.001 and 0.0001, and the bottom row shows the
comparison of training loss curves for CIC-DDoS-2019 dataset at learning rates of
0.01, 0.001 and 0.0001.

accurately identify potential intrusions in the intrusion detection task. It can be


seen that the Inception model also plays a crucial role in this experiment. The
Inception model allows features at different scales to be extracted and fused
simultaneously through the use of multiple parallel convolutional and pooling
layers. This multi-scale feature extraction capability allows the model to better
capture the multi-level patterns and detailed features in the network traffic,
thus improving the robustness and accuracy of the model. For multi-class
experiments on datasets with uneven data distributions, data augmentation with
the improved ACGAN model effectively increases the detection rate of the model
and is particularly helpful in recognizing minority-class attack data. Finally,
as with the binary classification task, training was more stable and faster on
both datasets when the learning rate was set to 0.001.
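The accuracy and the macro-averaged precision, recall, and F1 in Table 1 can be reproduced from raw predictions with a short script. The sketch below is illustrative only: the class labels and predictions are made up, and macro averaging is our assumption about how the table was computed.

```python
from collections import Counter

def macro_metrics(y_true, y_pred, labels):
    """Compute accuracy and macro-averaged precision/recall/F1,
    the metrics reported for the multi-class intrusion detection task."""
    pairs = Counter(zip(y_true, y_pred))
    precisions, recalls, f1s = [], [], []
    for c in labels:
        tp = pairs[(c, c)]
        fp = sum(v for (t, q), v in pairs.items() if q == c and t != c)
        fn = sum(v for (t, q), v in pairs.items() if t == c and q != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec); recalls.append(rec); f1s.append(f1)
    acc = sum(t == q for t, q in zip(y_true, y_pred)) / len(y_true)
    k = len(labels)
    return acc, sum(precisions) / k, sum(recalls) / k, sum(f1s) / k

# Hypothetical predictions for a three-class subset (labels are made up).
y_true = ["BENIGN", "BENIGN", "DoS", "DoS", "DDoS", "DDoS"]
y_pred = ["BENIGN", "DoS",    "DoS", "DoS", "DDoS", "BENIGN"]
acc, prec, rec, f1 = macro_metrics(y_true, y_pred, ["BENIGN", "DoS", "DDoS"])
```

Because macro averaging weights every class equally, rare attack classes pull the averaged precision/recall down, which would be consistent with Table 1's recall (78.47) sitting well below its accuracy (95.66) on the imbalanced CIC-IDS-2017 data.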

5 Conclusion
To address the security challenges of current network environments as well as
the problem of imbalanced data distributions, we propose an improved ACGAN
model that provides enhanced augmentation of minority-class data, together with
a novel and efficient intrusion detection framework that enables binary and
multi-class detection of cyber-attack data while maintaining a high accuracy
rate. The problem of data imbalance with high-dimensional features is
particularly prominent, and it largely affects the efficacy of intrusion
detection models because they tend to favor the majority class, thus ignoring
the less common but usually more important minority-class samples. We propose
an imbalanced data sampling method based on auxiliary classifier generative
adversarial networks (ACGANs), which is specifically designed to address the
high-dimensional features and class imbalance of network traffic data. In addition,
considering the temporal and spatial characteristics of intrusion data, we propose
some effective modules to fuse spatio-temporal information to enhance the
detection accuracy of intrusions. Finally, the proposed framework is experimentally
validated on two publicly available state-of-the-art datasets, CIC-IDS-2017
and CIC-DDoS-2019, which demonstrates the innovation and effectiveness of this
research method.

References
1. Cao, B., Li, C., Song, Y., Qin, Y., Chen, C.: Network intrusion detection model
based on CNN and GRU. Appl. Sci. 4184 (2022). https://doi.org/10.3390/app12094184
2. Goodfellow, I., et al.: Generative adversarial nets. J. Jpn. Soc. Fuzzy Theory Intell.
Inform. 177 (2017)
3. Henderson, N., Min, W., Rowe, J., Lester, J.: Multimodal player affect modeling
with auxiliary classifier generative adversarial networks. In: Proceedings of the
AAAI Conference on Artificial Intelligence and Interactive Digital Entertainment,
vol. 16, no. 1, pp. 224–230 (October 2022). https://doi.org/10.1609/aiide.v16i1.7434
4. Jothi, B., Pushpalatha, M.: WILS-TRS — a novel optimized deep learning based
intrusion detection framework for IoT networks. Pers. Ubiquitous Comput. 27(3),
1285–1301 (2023). https://doi.org/10.1007/s00779-021-01578-5
5. Lan, J., Liu, X., Li, B., Zhao, J.: A novel hierarchical attention-based triplet
network with unsupervised domain adaptation for network intrusion detection.
Appl. Intell. 11705–11726 (2023). https://doi.org/10.1007/s10489-022-04076-0
6. Liao, J.: An intrusion detection model based on improved ACGAN in big data
environment. Secur. Commun. Netw. 2022, 1–9 (2022). https://doi.org/10.1155/2022/6821174
7. Liu, G., Zhang, J.: CNID: research of network intrusion detection based on
convolutional neural network. Discret. Dyn. Nat. Soc. 1–11 (2020). https://doi.org/10.1155/2020/4705982
8. Liu, R., Ji, C., Niu, J., Guo, B.: Research on intrusion detection method based
on 1D-ICNN-BiGRU. J. Phys.: Conf. Ser. 2347(1), 012001 (2022). https://doi.org/10.1088/1742-6596/2347/1/012001
9. Lu, Y., Cheung, Y.M., Tang, Y.Y.: Hybrid sampling with bagging for class
imbalance learning, pp. 14–26 (January 2016)
10. Luo, S., Zhao, Z., Hu, Q., Liu, Y.: A hierarchical CNN-Transformer model for
network intrusion detection. In: 2nd International Conference on Applied
Mathematics, Modelling, and Intelligent Computing (CAMMIC 2022), p. 278 (May 2022).
https://doi.org/10.1117/12.2639876
11. Lyu, J., Young Lee, H., Liu, H.: Color matching generation algorithm for
animation characters based on convolutional neural network. Comput. Intell.
Neurosci. 1–13 (2022). https://doi.org/10.1155/2022/3146488
12. Mozo, A., González-Prieto, A., Pastor, A., Gomez-Canaval, S., Talavera, E.:
Synthetic flow-based cryptomining attack generation through generative adversarial
networks. Sci. Rep. (2022). https://doi.org/10.1038/s41598-022-06057-2
13. Peng, W., Kong, X., Peng, G., Li, X., Wang, Z.: Network intrusion detection based
on deep learning. In: 2019 International Conference on Communications, Information
System and Computer Engineering (July 2019). https://doi.org/10.1109/cisce.2019.00102
14. Saviour, M.P.A., Samiappan, D.: IPFS based storage authentication and access
control model with optimization enabled deep learning for intrusion detection. Adv.
Eng. Softw. 176, 103369 (2023). https://doi.org/10.1016/j.advengsoft.2022.103369
15. Shamsolmoali, P., Zareapoor, M., Shen, L., Sadka, A., Yang, J.: Imbalanced data
learning by minority class augmentation using capsule adversarial networks.
Neurocomputing (2020)
16. Shao, S., Wang, P., Yan, R.: Generative adversarial networks for data augmentation
in machine fault diagnosis. Comput. Ind. 85–93 (2019). https://doi.org/10.1016/j.compind.2019.01.001
17. Shen, W., et al.: Boundary sampling to boost mutation testing for deep learning
models. Inf. Softw. Technol. 106413 (2021). https://doi.org/10.1016/j.infsof.2020.106413
18. Shi, G., et al.: Knowledge-guided synthetic medical image adversarial augmentation
for ultrasonography thyroid nodule classification. Comput. Methods Programs
Biomed. 105611 (2020). https://doi.org/10.1016/j.cmpb.2020.105611
19. Su, T., Sun, H., Zhu, J., Wang, S., Li, Y.: BAT: deep learning methods on
network intrusion detection using NSL-KDD dataset. IEEE Access 29575–29585
(2020). https://doi.org/10.1109/access.2020.2972627
20. Wang, X., Wang, Q.: An abnormal traffic detection method using GCN-BiLSTM-
Attention in the Internet of Vehicles environment. EURASIP J. Wirel. Commun.
Netw. (2023). https://doi.org/10.1186/s13638-023-02274-z
21. Wei, J., Chen, Y., Lai, Y., Wang, Y., Zhang, Z.: Domain adversarial neural
network-based intrusion detection system for in-vehicle network variant attacks.
IEEE Commun. Lett. 26(11), 2547–2551 (2022). https://doi.org/10.1109/lcomm.2022.3195486
22. Yin, X., Chen, L.: Network intrusion detection method based on multi-scale CNN
in Internet of Things. Mob. Inf. Syst. 1–8 (2022). https://doi.org/10.1155/2022/8124831
23. Yu, H., Kang, C., Xiao, Y., Yang, Y.: Network intrusion detection method based
on hybrid improved residual network blocks and bidirectional gated recurrent
units. IEEE Access 11, 68961–68971 (2023). https://doi.org/10.1109/access.2023.3271866
Provable Data Auditing Scheme
from Trusted Execution Environment

Yuluo Zeng, Xinlei Sheng(B) , Kai Zhao, and Suliu Yang

Zhejiang Sci-Tech University, Hangzhou 310018, China


[email protected]

Abstract. With the swift advancement of machine learning technology, the
volume of information is expanding at an exponential rate. An increasing
number of users choose to store data in cloud servers, so it is crucial to
ensure the integrity of that data. The provable data possession (PDP) scheme
obtains integrity proofs, like hash-based schemes, while consuming less
bandwidth. However, most existing PDP schemes rely on a third party auditor
(TPA) to verify the integrity of the data for users, which may lead to the
problem that the TPA steals user data or colludes with cloud servers. We
propose a scheme that replaces the TPA with a trusted execution environment
(TEE), forming a system framework consisting of users, cloud servers, and
the TEE. This system can effectively reduce the user's communication and
computing overhead and improve the credibility of the system model. Our
scheme places the two core stages of challenge and verification in a fully
trusted environment, which eliminates the hidden dangers brought by the TPA.
We use Intel SGX technology to deploy on a real cloud storage platform and
apply the Enclave area created by Intel SGX to the PDP scheme. Our
experiments show that the performance overhead is reasonable given the
elimination of TPA dependencies.

Keywords: Cloud storage · Provable data possession · TEE · Intel SGX

1 Introduction

Over the past few years, the demand for storage has been increasing significantly
with the explosive growth of user data and the widespread application of artificial
intelligence in various fields. As a distributed data storage service based on
the Internet, cloud storage allows users to store data, files, applications, etc.
in remote data centers without relying on local hard disks or physical storage
devices. Therefore, it is becoming an increasingly popular choice for users to
share data. According to Huawei's forecast, by 2030 the total amount of data
generated worldwide will reach 1 YB per year, and mankind will usher in the YB
data era. The total global general-purpose storage capacity will reach 37 ZB,
of which AI-related storage capacity will account for 63%.
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 284–298, 2025.
https://doi.org/10.1007/978-981-96-4566-4_20

However, once users transfer their data to the cloud, they no longer have
control over it. The management practices employed by cloud service providers
are often opaque, making it difficult for users to place complete trust in them. To
reduce storage requirements and costs, these providers may overlook certain user
data or delete files that are infrequently accessed [1]. Furthermore, cloud servers
are vulnerable to attacks, hardware or software failures, and other unforeseen
issues, which could lead to the corruption or loss of data stored in the cloud
[2]. In order to protect their reputation, cloud service providers may opt not to
disclose such incidents to users.
To mitigate these potential risks, one solution is to retrieve the stored infor-
mation from the cloud for integrity verification. However, given the vast amount
of data, this approach is impractical due to the significant bandwidth and local
storage space it would require. The provable data possession (PDP) scheme [3]
was introduced to provide a probabilistic proof for the integrity of files stored
by third parties. This scheme enables auditors to efficiently assess the integrity
of the data without the need to download it, offering benefits such as reduced
bandwidth usage and enhanced data security.
Machine learning must be trained on large-scale data to construct mathematical
models and statistical methods that discover patterns and laws in the data in
order to make predictions, inferences, or classifications. During the training
process of machine learning, especially in scenarios involving sensitive data
such as healthcare and finance, maintaining data security and privacy is an
important issue. Meanwhile, in distributed machine learning training scenarios,
data are usually held by multiple nodes, and ensuring the integrity of these
data without transmitting them is a key factor in the authenticity and
effectiveness of machine learning training. The provable data possession
scheme, as a cryptographic technique, can provide data integrity verification
to ensure the integrity and authenticity of the data relied on for machine
learning training and of the key parameters generated by training and
deployment, thereby improving the credibility of machine learning models.
In earlier PDP scheme frameworks, the client was responsible for generating
keys, computing verification metadata and tags, and then transmitting the tags
and files to the cloud server. During the challenge phase, the user would request
a verification of the data stored on the cloud server, which would then generate
proof of possession in response to the challenge. Finally, the client would confirm
the integrity of the data using the proof provided by the cloud server. In more
recent PDP frameworks, the role of a third-party auditor (TPA) was introduced,
with the TPA taking over the tasks of sending challenges to the cloud server and
validating the proof received from the server. However, there is a risk that the
TPA could misuse user data or collaborate with the cloud server to hide data
loss [4].
Therefore, we propose to replace TPA with the trusted execution environ-
ment (TEE) to prevent data leakage and attacks. Trusted execution environment
(TEE) is an isolation technology that combines software and hardware to ensure
that the execution process of data and programs is secure and reliable. TEE
provides a protective environment for running sensitive code and processing sen-
sitive data, in which external attacks or operating systems cannot access without
authorization.
We chose Intel SGX as our experimental platform. As a typical application
of TEE, Intel SGX provides strong security protection in multiple fields. Intel
SGX builds a trusted execution environment called Enclave. When the Enclave
is started, hardware and software jointly verify whether the Enclave code has
not been tampered with. The memory within the Enclave is encrypted using the
Memory Encryption Engine (MEE) to ensure that only authorized code can read
the Enclave memory. The code within the Enclave can run in an environment
that is completely isolated from the operating system and other applications. The
operating system or hypervisor cannot access the memory or code in the Enclave.
SGX provides a remote attestation mechanism that allows users to verify that
the code and data of the Enclave are secure and have not been tampered with.
The main contributions of our paper include the following:

1. We introduce a data integrity auditing framework that leverages the trusted
execution environment (TEE), enabling secure verification of data integrity.
By substituting the third-party auditor (TPA) with the TEE, we enhance the
security of the verification process and strengthen data privacy protection,
while eliminating the system’s reliance on external auditors.
2. We deployed it in a real cloud server scenario, ported some of the core functions
of the cryptographic library to Intel SGX, and conducted a comprehensive
evaluation using experiments. We compared the time required for each oper-
ation on different file sizes and the performance overhead of whether or not
to load it into the TEE, confirming the feasibility of the new PDP scheme.

2 Related Work

Currently, provable data possession (PDP) has become an important technology
to ensure the integrity of outsourced data in cloud storage. Since it was first
proposed by Ateniese et al. [3] in 2007, many researchers have proposed various
improvements based on it. The PDP model allows clients to verify whether an
untrusted server owns the original data without retrieving the data, by generat-
ing probabilistic proof of possession by randomly sampling data blocks from the
server.
To address the challenges presented by multi-cloud environments, Zhu et al.
introduced a collaborative provable data possession framework utilizing homo-
morphic verifiable responses and a hash index hierarchy [5]. This approach
enables multiple cloud providers to cooperate in storing and maintaining cus-
tomer data. Building upon this, Wang et al. proposed a cloud storage security
solution that supports public auditability and dynamic data management. Their
approach enhances block label authentication by refining the traditional Merkle
hash tree construction and extends it to a multi-user environment using bilinear
aggregate signature technology [1]. In order to meet the requirement that cus-
tomers can check remote data ownership in a public cloud environment, Wang
et al. introduced an authorized proxy, so that the client can delegate the task of
integrity verification to a trusted proxy [6].
Amidst the growth of cloud computing, a large amount of stored data needs
to be shared securely and efficiently [7, 8]. Therefore, PDP schemes have gradu-
ally begun to focus on enhancing efficiency, security, and privacy. Shen et al. [9]
proposed a global and sampled verification and batch audit protocol that sup-
ports cloud data. The protocol adopts a novel dynamic structure consisting of
a bidirectionally linked information table and a position array. To prevent data
leakage to third-party auditors, Ni et al. proposed an identity-based privacy-
preserving provable data possession scheme [10]. The scheme leverages the RSA
assumption to allow third-party auditors to verify the integrity of outsourced
files through the validation of homomorphic authentication tags, while also sup-
porting aggregate verification across multiple users.
To address the security challenges of existing certificate-independent PDP
schemes in the random oracle model and eliminate the risk of TPAs retrieving
user data blocks, the certificateless provable data possession scheme was intro-
duced. The CL-PDP framework proposed by Deng et al. is specifically designed
for cloud storage and provides provable security within the standard model [11].
Meanwhile, Shen et al.’s CL-PDP scheme supports multi-replica and multi-server
setups, introducing a novel data structure called the mapping version mark table
(MVMT) to facilitate dynamic block-level operations and ensure data traceabil-
ity [12].
With the development of blockchain technology, researchers have found that
the integration of PDP schemes and blockchain technology is a very promising
field. The immutability and transparency of blockchain provide a decentralized
platform for data integrity verification. Wang et al. introduced a decentralized
cloud storage platform SStore, which uses the immutability of blockchain to
securely store file tags, thereby guaranteeing the safety and dependability of the
cloud storage data integrity verification process [2]. Miao et al. introduced a
certificateless multi-copy public audit protocol based on blockchain technology
to achieve decentralized multi-copy data possession verification and fault loca-
tion, while improving the dependability of fault detection and the precision of
decision-making through the use of smart contracts [13]. Wang et al. introduced
a distributed provable data possession scheme based on blockchain technology,
which ensures synchronization and enhances security in multi-cloud storage envi-
ronments. This scheme also facilitates fault detection by utilizing smart contracts
and binary search techniques [14].

3 Preliminaries

In this section, we introduce the relevant knowledge and symbolic descriptions of
our provable data possession scheme based on a trusted execution environment.
Table 1 lists the symbols used in this scheme.

Table 1. Notations and descriptions

Symbols     Descriptions
p, q        Two safe primes
N           RSA modulus
Z*_N        All elements that have a multiplicative inverse modulo N
QR_N        Set of quadratic residues modulo N
ϕ(N)        Euler function
e           Integer coprime with ϕ(N)
d           The modular inverse of e
g           Generator, an element of QR_N
v           Secret random integer
pk          Public key
sk          Private key
f           A pseudo-random function
π           A pseudo-random permutation
h(), H()    Cryptographic hash functions
F           The stored file
b           Unique identifier of the user
n           The total number of blocks uploaded to the cloud server
m_i         The i-th file block
t_i         The timestamp of the i-th file block
ID          Unique identifier of the file
id_i        Unique identifier of the i-th file block
T_{ID,i}    The tag of the i-th file block of a file
c           The number of challenged file blocks
chal        Challenge
k1, k2      Two pseudo-random numbers
W           Log file

3.1 Trusted Execution Environment

The Trusted Execution Environment (TEE) is specifically designed to provide a
secure and isolated space for the storage, processing, and protection of sensitive
data. The TEE is designed to withstand not only attackers at the operating
system level, but also malicious parties with physical access to the platform.
Therefore, it guarantees execution isolation within the enclave by enforcing
mandatory security measures at the hardware level, safeguards the integrity and
confidentiality of data, and authenticates the identity of the code running on
the trusted platform using remote attestation mechanisms.
• Isolation. Memory isolation is the core of TEE. The SGX Enclave is a secure,
protected memory area provided by Intel hardware, which is physically located
in the computer's memory but logically separated from the normal process
address space. This memory area is controlled by the SGX extended hard-
ware when the process is running and can only be accessed through special
instructions. The code and data in the SGX Enclave are protected, and any
other process, including the operating system kernel, cannot directly access
or tamper with the contents therein.
• Attestation. The goal of remote attestation is to create a proof that includes
measurements of the software state, which is then signed with a certificate key
stored in the hardware. The remote verifier validates this proof by verifying
both the signature and the measurements.

3.2 Secure Transmission Channel Establishment

The establishment of a secure transmission channel involves creating a reliable
communication link between two parties to guarantee the privacy, integrity, and
trustworthiness of the data being transmitted. This is typically achieved by using
authentication mechanisms, key exchange protocols, and encryption technologies.
The goal is to prevent unauthorized access or tampering during data transmission,
especially man-in-the-middle attacks.

• Identity authentication mechanism. Identity authentication is the process
of ensuring the authenticity of the identities of the communicating parties;
only authenticated parties can exchange data. This can be done using digital
certificates, public keys, or other authentication methods.
• Key Exchange Protocol. A secure key exchange protocol [15] is used to
exchange shared keys that will be used to encrypt data during transmission.
• Encryption. More efficient symmetric encryption methods are usually used,
where a shared key is used as the key for both sides of the encrypted and
decrypted data to ensure the confidentiality of the data.
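The three ingredients above can be sketched end to end in a few lines of Python. This is a toy, stdlib-only illustration of the later KeyTran flow: the Diffie-Hellman group, the HMAC-based keystream cipher, and the omitted identity authentication are all stand-ins for real choices such as mutually authenticated TLS and AES-GCM, and every parameter below is illustrative.

```python
import hashlib, hmac, secrets

# Toy Diffie-Hellman group (illustrative only; a real deployment would use a
# vetted group or an elliptic-curve exchange inside an authenticated protocol).
P = 2**127 - 1          # toy prime modulus
G = 3                   # toy generator

def dh_keypair():
    priv = secrets.randbelow(P - 2) + 1
    return priv, pow(G, priv, P)

def shared_key(priv, peer_pub):
    # K = SHA-256(g^{ab} mod p): the shared key produced by the exchange
    return hashlib.sha256(str(pow(peer_pub, priv, P)).encode()).digest()

def keystream_xor(key, data):
    # Toy stand-in for the symmetric algorithm SKA(K, ·): XOR with an
    # HMAC-SHA256 counter keystream (same function encrypts and decrypts).
    out, counter = bytearray(), 0
    while len(out) < len(data):
        out += hmac.new(key, counter.to_bytes(8, "big"), hashlib.sha256).digest()
        counter += 1
    return bytes(a ^ b for a, b in zip(data, out))

# User and TEE each generate a keypair and derive the same shared key K.
u_priv, u_pub = dh_keypair()
t_priv, t_pub = dh_keypair()
K_user, K_tee = shared_key(u_priv, t_pub), shared_key(t_priv, u_pub)

sk = b"user-private-key-(e,d,v)"              # placeholder for the PDP private key
ciphertext = keystream_xor(K_user, sk)        # user encrypts sk ...
recovered = keystream_xor(K_tee, ciphertext)  # ... TEE decrypts it
```

Without the authentication step the exchange is vulnerable to a man-in-the-middle, which is exactly why the scheme pairs the key exchange with an identity authentication mechanism.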

4 Concrete Construction of the Proposed PDP Scheme


4.1 System Model

To ensure the integrity of the data stored at the end of the cloud server, and also
to prevent TPAs from leaking data or conspiring with cloud service providers, we
introduce a provable data possession scheme that leverages the trusted execution
environment. The system architecture of the scheme is illustrated in Fig. 1.
There are three main participants in the system: the user, the cloud server,
and the trusted execution environment. When a user needs to outsource a data
file to the cloud storage, he should first slice a large file into smaller chunks of
fixed size. The user computes the homomorphic verifiable tag of each block based
on the public and private keys and uploads the file, the public key, and the file
tags to the cloud server.

Fig. 1. System framework of the proposed PDP scheme.

A secure transmission channel is built between the user and the TEE, over which
the user transmits its private key to the TEE for authentication. The TEE, acting
as a trusted auditor, periodically sends a challenge to the cloud server. When
the cloud server receives the challenge, it generates a proof as a response based
on the data in its possession and returns it to the TEE. The TEE verifies the
proof and generates a verification log file. If the verification passes, the file
blocks related to the challenge are complete; if it fails, those blocks are
corrupted or missing. The TEE digitally signs the generated log file and provides
regular feedback to the user on the integrity of the data in the cloud.
The specific description of the PDP scheme we proposed above is as follows:
KeyGen: The client runs a probabilistic key generation algorithm, taking a secu-
rity parameter as input, to obtain a public-private key pair (pk, sk).
KeyTran: A secure transmission channel is established between the client and
TEE through identity authentication mechanism, key exchange protocol and
encryption technology to secretly transmit the client’s private key to TEE.
TagBlock : The client takes the file block and public and private keys as input to
generate verification metadata.
GenProof : The cloud server takes the public key, the stored file block, and the
challenge as input and generates a proof of possession.

CheckProof : TEE runs to verify the possession certificate sent by the cloud
server, using the public and private keys, challenge, and possession certificate as
input to determine whether the verification is successful.
Sig: TEE digitally signs the log file and returns the log file and signature pair
to the client through an untrusted channel. The client verifies the signature to
confirm the source and integrity of the information.

4.2 Concrete Construction

The specific algorithm and description of our proposed provable data possession
scheme based on a trusted execution environment are provided below.
We use b as a unique identifier of a client to identify the client information.
Assuming that the user wants to transmit a file F to the cloud storage, the user
first divides the file F into n file blocks, and F = {m_i}_{1 ≤ i ≤ n} represents
the ordered set of n file blocks. When dividing file F into file blocks m_i, a
timestamp t_i (1 ≤ i ≤ n) and a unique identifier id_i (1 ≤ i ≤ n) are generated
for each newly divided file block m_i on the client, which facilitates the
management of file blocks and data tracing. In large-volume data scenarios, it is
impossible to upload only one file at a time. In order to avoid conflicts, a
unique file identifier, which we denote ID, needs to be added to each file F.
In addition, we define h to be a secure hash function, and we also use a
pseudo-random function (PRF) f and a pseudo-random permutation (PRP) π.
They are described as follows:

f : Z∗N × {1, 2, · · · , n} → Z∗N


π : Z∗N × {1, 2, · · · , n} → {1, 2, · · · , n}
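For illustration, f and π could be instantiated with HMAC-SHA256 and a keyed Fisher-Yates shuffle; the paper does not fix a construction, so the following is a sketch under that assumption (key values and sizes are made up).

```python
import hashlib, hmac, random

def prf(key: bytes, j: int, N: int) -> int:
    """f : Z*_N x {1..n} -> Z*_N, instantiated here with HMAC-SHA256
    (output mapped into [1, N-1])."""
    digest = hmac.new(key, j.to_bytes(8, "big"), hashlib.sha256).digest()
    return int.from_bytes(digest, "big") % (N - 1) + 1

def prp(key: bytes, j: int, n: int) -> int:
    """π : Z*_N x {1..n} -> {1..n}, a permutation of block indices,
    instantiated here as a Fisher-Yates shuffle seeded by the key."""
    perm = list(range(1, n + 1))
    random.Random(key).shuffle(perm)
    return perm[j - 1]

# The challenger picks c, k1, k2; the prover derives the challenged indices
# i_j = π_{k1}(j) and the coefficients α_j = f_{k2}(j) for j = 1..c.
n, c = 10, 4
k1, k2 = b"k1-challenge-key", b"k2-challenge-key"
indices = [prp(k1, j, n) for j in range(1, c + 1)]
coeffs = [prf(k2, j, 2**64) for j in range(1, c + 1)]
```

Because both parties derive the same indices and coefficients from (c, k1, k2), the challenge message stays constant-size regardless of how many blocks are sampled.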

• Setup: We let KeyGen(K) → (p, q, N, e, d). The gcd function denotes the
greatest common divisor of two integers. Let a be a random integer in Z*_N
satisfying gcd(a ± 1, N) = 1. The generator g is an element of QR_N, computed
as g = a^2. Generate a secret random number v, where v ∈ Z_N.
• KeyGen: The user generates a public key pk = (N, g) and a private key
sk = (e, d, v), and sends the public key pk = (N, g) to the server and TEE.
• KeyTran: A secure transmission channel is established between the user and
TEE, and the user sends his private key to TEE.
1) After the user and TEE authenticate each other through the identity
authentication mechanism, they use the key negotiation protocol to gen-
erate a shared key K. Calculate AKE (θ, ω) → K, where the AKE func-
tion represents the authentication key exchange protocol, θ represents
the identity information verification parameter, and ω represents the key
negotiation protocol parameter.
2) Let SKA be a symmetric key algorithm. The user calculates SKA (K, sk)
to encrypt the private key sk. After receiving it, TEE decrypts it to obtain
the user’s private key.

• TagBlock: For the file F to be transmitted to the cloud server for
safekeeping, the user runs the algorithm to compute the homomorphic verifiable
tags of the file blocks and uploads them to the cloud together with file F.
1) For each file block m_i, for 1 ≤ i ≤ n:
   calculate x_i = m_i || ID || t_i,
   generate C_i = v || b || id_i,
   and compute the tag T_{ID,i} = (h(C_i) · g^{x_i})^d mod N.
2) The user outputs (C_i, T_{ID,i}) and transmits it to the cloud server.
• Challenge: TEE generates an element c (1 ≤ c ≤ n), which represents the
number of challenged file blocks. It sends the generated challenge chal =
(c, k1, k2) to the cloud server, where k1, k2 ∈ Z*_N.
• GenProof: The cloud server receives the challenge chal = (c, k1, k2) from
TEE, checks the corresponding file blocks it stores based on the challenge,
generates a proof of possession, and returns it to TEE.
1) For 1 ≤ j ≤ c:
   compute the index i_j = π_{k1}(j) of each block chosen by the challenge,
   and the coefficient α_j = f_{k2}(j).
2) Compute the tag
   T = T_{ID,i_1}^{α_1} · ... · T_{ID,i_c}^{α_c}
     = ∏_{j=1}^{c} (h(C_{i_j})^{α_j} · g^{α_j x_{i_j}})^d mod N.
3) Compute ρ = Σ_{j=1}^{c} α_j x_{i_j}.
4) Output the proof of possession (T, ρ).
• CheckProof: TEE verifies the possession proof returned by the cloud server
and generates a log file containing the verification process and results.
1) For 1 ≤ j ≤ c, compute i_j = π_{k1}(j), α_j = f_{k2}(j), and
   C_{i_j} = v || b || id_{i_j}, then let
   μ = T^e / ∏_{j=1}^{c} h(C_{i_j})^{α_j} mod N.
2) If μ = g^ρ mod N, the verification passes; otherwise it fails.
3) TEE generates the log file W.
• TEESG: TEE digitally signs the log file W using the Elliptic Curve Digital
Signature Algorithm (ECDSA) and sends it to the user.
Key Generation. For ECDSA, we let q' be a prime number, a' and b' be integers
over Z_{q'} defining the elliptic curve, G' the base point on the curve, and
n' the order of the point G'. Choose a random integer d' ∈ [1, n' − 1] and
compute Q' = d'G', where the public key is Q' and the private key is d'.
• SigVer: The user receives the signature pair (r', s') and log file W sent by
TEE and performs digital signature verification.

Algorithm 1. ECDSA
Input: Log file W, private key d′, curve parameters: q′, a′, b′, n′, G′
Output: (r′, s′)
Procedure:
1: k′ ←$ {1, 2, · · · , n′ − 1}
2: P′ = (x′, y′) = k′G′
3: r′ = x′ mod n′; if r′ = 0, return to step 1
4: t′ = k′^{-1} mod n′
5: e′ = H(W)
6: s′ = k′^{-1}(e′ + d′r′) mod n′; if s′ = 0, return to step 1

Algorithm 2. ECDSA Signature Verification
Input: Log file W, public key Q′, signature (r′, s′), curve parameters: q′, a′, b′, n′, G′
Output: TRUE or FALSE
Procedure:
1: if r′ ∉ [1, n′ − 1] or s′ ∉ [1, n′ − 1] then
2:    return FALSE
3: end if
4: e′ = H(W)
5: w′ = s′^{-1} mod n′
6: u_1 = e′w′ mod n′
7: u_2 = r′w′ mod n′
8: X′ = (x_1, y_1) = u_1 G′ + u_2 Q′
9: if X′ = O then
10:   return FALSE
11: else
12:   v′ = x_1 mod n′
13: end if
14: if v′ = r′ then
15:   return TRUE
16: else
17:   return FALSE
18: end if
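Algorithms 1 and 2 can be traced with a toy instantiation. The Python sketch below is our own illustration, not production code: it uses the textbook curve y^2 = x^3 + 2x + 2 over F_17 with base point G′ = (5, 1) of prime order n′ = 19, and SHA-256 as an assumed stand-in for H; all parameter choices are for demonstration only.

```python
import hashlib
import random

# Textbook toy curve: y^2 = x^3 + 2x + 2 over F_17, base point (5, 1) of order 19.
P_MOD, A_COEF = 17, 2
G = (5, 1)
N_ORD = 19
O = None                                   # point at infinity

def ec_add(P1, P2):
    """Affine point addition on the toy curve."""
    if P1 is O: return P2
    if P2 is O: return P1
    (x1, y1), (x2, y2) = P1, P2
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return O
    if P1 == P2:
        lam = (3 * x1 * x1 + A_COEF) * pow(2 * y1, -1, P_MOD) % P_MOD
    else:
        lam = (y2 - y1) * pow(x2 - x1, -1, P_MOD) % P_MOD
    x3 = (lam * lam - x1 - x2) % P_MOD
    return (x3, (lam * (x1 - x3) - y1) % P_MOD)

def ec_mul(k, P):
    """Double-and-add scalar multiplication kP."""
    R = O
    while k:
        if k & 1:
            R = ec_add(R, P)
        P = ec_add(P, P)
        k >>= 1
    return R

def H(W: bytes) -> int:                    # hash of the log file, reduced mod n'
    return int.from_bytes(hashlib.sha256(W).digest(), "big") % N_ORD

def sign(d_priv, W):
    """Algorithm 1: retry until r' != 0 and s' != 0."""
    while True:
        k = random.randrange(1, N_ORD)     # k' must be fresh for every signature
        x1, _ = ec_mul(k, G)
        r = x1 % N_ORD
        if r == 0: continue
        s = pow(k, -1, N_ORD) * (H(W) + d_priv * r) % N_ORD
        if s != 0:
            return (r, s)

def verify(Q, W, sig):
    """Algorithm 2: recompute X' = u1 G' + u2 Q' and compare x1 mod n' to r'."""
    r, s = sig
    if not (1 <= r < N_ORD and 1 <= s < N_ORD):
        return False
    w = pow(s, -1, N_ORD)
    X = ec_add(ec_mul(H(W) * w % N_ORD, G), ec_mul(r * w % N_ORD, Q))
    return X is not O and X[0] % N_ORD == r

d_priv = 7                                 # private key d' in [1, n'-1]
Q_pub = ec_mul(d_priv, G)                  # public key Q' = d'G'
sig = sign(d_priv, b"log file W")
print(verify(Q_pub, b"log file W", sig))   # True
```

Verification succeeds because X′ = w′(e′ + d′r′)G′ = k′G′, so x_1 mod n′ reproduces r′.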

4.3 Correctness Analysis


In the ownership verification algorithm, after TEE receives the possession proof
(T, ρ) from the cloud server, the verification equation μ ?= g^ρ mod N holds. The
relevant proof is given below.
First, we know that

   T = T_{ID,i_1}^{α_1} · ... · T_{ID,i_c}^{α_c} = (∏_{j=1}^{c} h(C_{i_j})^{α_j} · g^{α_j x_{i_j}})^d mod N

and ρ = Σ_{j=1}^{c} α_j x_{i_j}. Therefore, the left side of the equation gives

   μ = T^e / (∏_{j=1}^{c} h(C_{i_j})^{α_j}) mod N.                                  (1)
294 Y. Zeng et al.

Substituting into the equation, we get

   μ = (∏_{j=1}^{c} h(C_{i_j})^{α_j} · g^{α_j x_{i_j}}) / (∏_{j=1}^{c} h(C_{i_j})^{α_j}) mod N
     = ∏_{j=1}^{c} g^{α_j x_{i_j}} mod N.                                           (2)

The right side of the equation gives

   g^ρ mod N = g^{Σ_{j=1}^{c} α_j x_{i_j}} mod N = ∏_{j=1}^{c} g^{α_j x_{i_j}} mod N.   (3)

The right-hand sides of Eqs. (2) and (3) are equal, so the verification equation holds, which completes the proof.
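The chain TagBlock → GenProof → CheckProof can be exercised end to end with a toy RSA-style instantiation. The following Python sketch is our own illustration (not the paper's implementation): the primes are far too small for real security, and SHA-256 stands in for the hash h and for the outputs of π_{k_1} and f_{k_2}.

```python
import hashlib

# Toy RSA-style parameters (far too small for real security; illustration only).
p_, q_ = 1000003, 1000033
N = p_ * q_
phi = (p_ - 1) * (q_ - 1)
e = 65537
d = pow(e, -1, phi)                        # e*d = 1 (mod phi(N))
g = 2                                      # public element of Z_N*

def h(data: bytes) -> int:
    """Hash into Z_N (stand-in for the scheme's hash function h)."""
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % N or 1

# TagBlock: x_i derived from m_i || ID || t_i, C_i = v || b || id_i,
# and tag T_{ID,i} = (h(C_i) * g^{x_i})^d mod N.
blocks = [b"block-%d" % i for i in range(8)]
ID, t, v, b_ = b"user", b"2024", b"v", b"b"
xs = [int.from_bytes(hashlib.sha256(m + ID + t).digest(), "big") for m in blocks]
Cs = [v + b_ + b"%d" % i for i in range(len(blocks))]
tags = [pow(h(C) * pow(g, x, N), d, N) for C, x in zip(Cs, xs)]

# Challenge: pairs (i_j, alpha_j) that would come from pi_{k1} and f_{k2}.
challenge = [(3, 5), (6, 7), (1, 11)]

# GenProof: T = prod tags[i_j]^{alpha_j}, rho = sum alpha_j * x_{i_j}.
T = 1
for i, a in challenge:
    T = T * pow(tags[i], a, N) % N
rho = sum(a * xs[i] for i, a in challenge)

# CheckProof: mu = T^e / prod h(C_{i_j})^{alpha_j} mod N, accept iff mu == g^rho.
denom = 1
for i, a in challenge:
    denom = denom * pow(h(Cs[i]), a, N) % N
mu = pow(T, e, N) * pow(denom, -1, N) % N
print(mu == pow(g, rho, N))                # the possession proof verifies
```

Raising T to e strips the private exponent d, which is exactly the step that lets the verifier cancel the h(C_{i_j}) terms as in Eq. (2).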

5 Experiments
In this section, we analyze the computation overhead of our proposed PDP
scheme. We focus on the five stages of TagBlock, GenProof, CheckProof,
TEESG, and SigVer, because these stages are time-consuming and frequently
executed. Since hash operations, additions, and modular reductions have much
shorter execution times and minimal computation overhead compared to the
other operations, we ignore their costs. To simplify the expressions, we write
T_MExp for the computation overhead of a modular exponentiation in the
system, T_MMul for that of a modular multiplication, T_MRem for that of a
modular reduction, T_MInv for that of a modular inversion, and T_SMul for
that of a scalar multiplication in the elliptic curve digital signature algorithm.
Before the calculation, for a file F, we divide it into n equal-sized file blocks
and use c to represent the number of challenged file blocks. The computation
overhead described below covers one complete run of the protocol. In the
TagBlock stage, there are a total of 2n modular exponentiations and n modular
multiplications, giving a computation overhead of 2n·T_MExp + n·T_MMul.
In the GenProof stage, there are c modular exponentiations and 2c − 1 modular
multiplications, giving c·T_MExp + (2c − 1)·T_MMul. In the CheckProof stage,
there are c + 2 modular exponentiations and c modular multiplications, giving
(c + 2)·T_MExp + c·T_MMul. In the TEESG stage, there are 1 scalar
multiplication, 1 modular inversion, and 2 modular multiplications, giving
T_SMul + T_MInv + 2·T_MMul. In the SigVer stage, there are 2 scalar
multiplications and 2 modular multiplications, giving 2·T_SMul + 2·T_MMul
(Fig. 3).
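The per-stage counts above translate into a small cost model. In the hypothetical helper below (our illustration, not part of the paper's artifacts), the t_* arguments stand for measured values of T_MExp, T_MMul, T_MInv, and T_SMul.

```python
# Operation counts per stage, as derived above; n is the number of file blocks
# and c the number of challenged blocks.
def overhead(n, c, t_exp, t_mul, t_inv, t_smul):
    return {
        "TagBlock":   2 * n * t_exp + n * t_mul,
        "GenProof":   c * t_exp + (2 * c - 1) * t_mul,
        "CheckProof": (c + 2) * t_exp + c * t_mul,
        "TEESG":      t_smul + t_inv + 2 * t_mul,
        "SigVer":     2 * t_smul + 2 * t_mul,
    }

# With unit costs it is easy to see that TagBlock, the only stage linear in n,
# dominates for large databases.
print(overhead(n=65536, c=460, t_exp=1, t_mul=1, t_inv=1, t_smul=1))
```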
We use our personal host as the experimental platform to deploy the PDP
solution. The hardware and software configuration of the host includes Ubuntu
22.04 64-bit operating system, 32 GB physical memory, and Intel i5-12600KF
3.7 GHz CPU. At the same time, this experiment also uses the PBC library and
the GMP library to test the computation overhead of the PDP solution.

Fig. 2. Optimal file block size selection.

Regardless of the file size, when 1% of the data is lost or damaged, this can
be detected with probabilities of 99% and 95% by randomly sampling 460 and
300 file blocks, respectively [3]. Therefore, in this experiment, we only consider
the case where damaged or lost data accounts for 1% of the file blocks. We
selected a machine learning training dataset, divided it into several 1 GB files,
and used them as files uploaded by users to the cloud server for storage.
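The 95%/99% figures quoted from [3] follow from the detection probability of random sampling, roughly 1 − (1 − f)^c for corruption fraction f and c challenged blocks; a quick check:

```python
# Probability that at least one of c sampled blocks is corrupted when a
# fraction f of all blocks is damaged (with-replacement approximation,
# matching the analysis cited from [3]).
def detect_prob(f: float, c: int) -> float:
    return 1 - (1 - f) ** c

print(round(detect_prob(0.01, 300), 3))    # about 0.95
print(round(detect_prob(0.01, 460), 3))    # about 0.99
```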
Initially, we aim to determine the optimal size for the file blocks. The execu-
tion of TagBlock, GenProof, and CheckProof stages may be affected by the file
block size. To determine the optimal file block size, we challenge 460 file blocks
each time and calculate the sum of execution time of different file block sizes in
the three stages.

Fig. 3. Computation overhead of each stage of our PDP scheme.

Figure 2, plotted from our experimental results, shows that once the file block
size reaches 16 KB, the sum of the computation overhead of the TagBlock,
GenProof, and CheckProof stages remains essentially unchanged: the curve
flattens out markedly, showing a trend toward stabilization. Therefore, we
choose 16 KB as the file block size for our experiment.
According to the determined file block size, we experimentally obtain the
computation overhead of our PDP scheme when challenging 300 file blocks and
460 file blocks. Since TagBlock is not affected by the number of challenged file
blocks and its computation overhead is much larger than that of the other four
stages, the following figure does not show the experimental results of this stage.
The computation overhead of each stage in the experiment is shown in Table 2.
Compared with the experimental results of paper [3], it can be concluded that
the computation overhead of our PDP scheme is reasonable. At the same time,
the algorithm complexity of the challenge, CheckProof and TEESG stages is
low, the computation overhead is small, and they can be transplanted to TEE
for operation.

Table 2. Computation overhead

Number of challenged blocks  TagBlock  GenProof  CheckProof  TEESG    SigVer
c = 300                      4128.768  9.899     9.662       133.128  11.933
c = 460                      4128.768  15.179    14.782      133.128  11.933

6 Conclusions
To eliminate the risk of the TPA stealing user data or colluding with cloud
servers, we replace the TPA with a TEE and propose a data integrity auditing
scheme based on a trusted execution environment. The scheme effectively improves
data security while preserving the ability to verify data integrity. We use Intel SGX
technology to port the relevant cryptographic libraries to the Enclave and deploy
it in a real cloud server scenario for our PDP scheme experiments. The experi-
mental results show that the performance overhead incurred by TEE instead of
TPA is reasonable and can support applications in cloud storage environments.

References
1. Wang, Q., Wang, C., Ren, K., Lou, W., Li, J.: Enabling public auditability and data
dynamics for storage security in cloud computing. IEEE Trans. Parallel Distrib.
Syst. 22(5), 847–859 (2011)
2. Wang, L., Hu, M., Jia, Z., Guan, Z., Chen, Z.: Sstore: an efficient and secure
provable data auditing platform for cloud. IEEE Trans. Inf. Forensics Secur. 19,
4572–4584 (2024)
3. Ateniese, G., et al.: Provable data possession at untrusted stores. In: Proceedings
of the 14th ACM Conference on Computer and Communications Security (CCS
'07), pp. 598–609. ACM, New York (2007). https://doi.org/10.1145/1315245.1315318
4. He, Y., Xu, Y., Jia, X., Zhang, S., Liu, P., Chang, S.: EnclavePDP: a general
framework to verify data integrity in cloud using Intel SGX. In: 23rd International
Symposium on Research in Attacks, Intrusions and Defenses (RAID 2020),
pp. 195–208. USENIX Association (2020). https://www.usenix.org/conference/
raid2020/presentation/he
5. Zhu, Y., Hu, H., Ahn, G.-J., Yu, M.: Cooperative provable data possession for
integrity verification in multicloud storage. IEEE Trans. Parallel Distrib. Syst.
23(12), 2231–2244 (2012)
6. Wang, H.: Proxy provable data possession in public clouds. IEEE Trans. Serv.
Comput. 6(4), 551–559 (2013)
7. Shen, J., Zhou, T., Chen, X., Li, J., Susilo, W.: Anonymous and traceable group
data sharing in cloud computing. IEEE Trans. Inf. Forensics Secur. 13(4), 912–925
(2018)
8. Shen, J., Yang, H., Vijayakumar, P., Kumar, N.: A privacy-preserving and untrace-
able group data sharing scheme in cloud computing. IEEE Trans. Dependable
Secure Comput. 19(4), 2198–2210 (2022)
9. Shen, J., Shen, J., Chen, X., Huang, X., Susilo, W.: An efficient public auditing
protocol with novel dynamic structure for cloud data. IEEE Trans. Inf. Forensics
Secur. 12(10), 2402–2415 (2017)
10. Ni, J., Zhang, K., Yu, Y., Yang, T.: Identity-based provable data possession from
rsa assumption for secure cloud storage. IEEE Trans. Dependable Secure Comput.
19(3), 1753–1769 (2022)
11. Deng, L., Wang, B., Wang, T., Feng, S., Li, S.: Certificateless provable data pos-
session scheme with provable security in the standard model suitable for cloud
storage. IEEE Trans. Serv. Comput. 16(6), 3986–3998 (2023)

12. Shen, J., Zeng, P., Choo, K.-K.R., Li, C.: A certificateless provable data posses-
sion scheme for cloud-based ehrs. IEEE Trans. Inf. Forensics Secur. 18, 1156–1168
(2023)
13. Miao, Y., Huang, Q., Xiao, M., Susilo, W.: Blockchain assisted multi-copy provable
data possession with faults localization in multi-cloud storage. IEEE Trans. Inf.
Forensics Secur. 17, 3663–3676 (2022)
14. Wang, H., Wan, Z., He, D., Yu, J.: Synchronous blockchain-based distributed prov-
able data possession with forward-security. IEEE Trans. Serv. Comput. 17(3),
1227–1238 (2024)
15. Shen, J., Zhou, T., He, D., Zhang, Y., Sun, X., Xiang, Y.: Block design-based key
agreement for group data sharing in cloud computing. IEEE Trans. Dependable
Secure Comput. 16(6), 996–1010 (2019)
Enhanced PIR Scheme Combining
SimplePIR and Spiral: Achieving Higher
Throughput Without Client Hints

Haoyao Xu1,2 , Yitong Li1,2 , and Haibo Tian1,2(B)


1
School of Computer Science and Engineering, Sun Yat-sen University,
Guangzhou 510006, China
{xuhy85,liyt289}@mail2.sysu.edu.cn, [email protected]
2
Guangdong Province Key Laboratory of Information Security Technology,
Guangzhou 510006, China

Abstract. In recent years, Private Information Retrieval (PIR) schemes


capable of rapid response have gotten significant attention, with Sim-
plePIR, for instance, achieving a throughput of approximately 10 GB/s
in a single-threaded setting. However, these schemes often require users
to download data related to the database, known as client hints, mak-
ing them unsuitable for scenarios where the database undergoes frequent
updates. This paper, building upon the Spiral scheme, incorporates the
precomputation approach of SimplePIR to design a query method that
leverages simulatable homomorphic ciphertexts for offline precomputa-
tion. This enhancement not only improves the efficiency of the Spi-
ral scheme but also addresses the issue of SimplePIR’s necessity for
downloading client hints. Experimental results demonstrate that, on a
database with 220 entries of 256 bytes each, the proposed scheme achieves
a throughput 1.87 times higher than that of the original Spiral scheme.

Keywords: Private Information Retrieval · Precomputation ·


Simulatable Homomorphic Ciphertexts · SimplePIR · Spiral

1 Introduction
Private Information Retrieval (PIR) [7] is a cryptographic protocol that allows
a user to retrieve an item from a database server without revealing which item
is being retrieved. The core idea is to protect the privacy of the user’s query
from the database owner. PIR schemes are particularly useful in scenarios where
users need to access data without disclosing their interests, thus preserving their
privacy. In the realm of privacy preserving machine learning, the table lookup
method serves as a possible approach for nonlinear computations, thereby falling
within the application scope of PIR.
Recently, researchers have proposed SimplePIR [12], a PIR protocol with
extremely high throughput on the server side, and provided an example of its

c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 299–315, 2025.
https://doi.org/10.1007/978-981-96-4566-4_21
300 H. Xu et al.

application in Certificate Transparency, demonstrating the practicality of PIR


technology. However, Menon and Wu [19] pointed out that the SimplePIR pro-
tocol requires clients to pre-download client hints, which can lead to repeated
downloads of client hints when the database is frequently updated, thereby wast-
ing communication bandwidth. In response, they proposed a PIR scheme based
on key switching technique that enables silent preprocessing, eliminating the
need for communication between clients and servers during the offline phase.
Similarly, Henzinger et al. [11] and Li et al. [14] have also independently pro-
posed PIR schemes that do not require clients to download during the offline
phase, adopting the technique of uploading encrypted private keys instead.
We have observed that the Spiral scheme proposed by Menon and Wu [18] is
superior in terms of throughput, ranking second only to the SimplePIR scheme.
However, it has the advantage of not requiring clients to download any data.
Consequently, we try to integrate the precomputation technique from SimplePIR
into Spiral, with the aim of enhancing Spiral’s throughput while circumventing
the need for clients to download hints as required by SimplePIR.

1.1 Related Works


Chor et al. [7] introduced PIR in 1995 and presented a two-server PIR proto-
col. In 1997, Kushilevitz and Ostrovsky [13] proposed a single-server PIR based
on computational complexity assumptions. Since then, a series of works have
emerged aimed at reducing the communication cost [5, 6,10,15,21] and enhanc-
ing the throughput of single-server PIR [1, 3, 9,11,12,14,17–20]. Below, we will
introduce several works [11,12,14,18,19] that are closely related to our work.
In 2022, Menon and Wu [18] introduced Spiral, a scheme that integrates
the Regev encryption scheme with the Gentry-Sahai-Waters (GSW) encryption
scheme. This approach proposes two transformation methods: one that extends
the scalar Regev ciphertext (encrypting an element in a polynomial ring) to
matrix Regev ciphertext (encrypting a matrix over the same polynomial ring),
and another that converts scalar Regev ciphertexts into GSW ciphertexts. By
encoding row and column values into a single polynomial ring element, the band-
width required for query requests is reduced. On the server side, the coefficient
expansion algorithm is applied to generate two sets of scalar Regev ciphertexts,
one containing row query value and the other containing column query value.
The first set, representing row value, is transformed into a corresponding set
of matrix Regev ciphertexts. These matrix Regev ciphertexts are then used to
compute element-wise products with each column of the database, resulting in
a set of matrix Regev ciphertexts that collectively encode a single row of the
database. This set of matrix Regev ciphertexts undergoes a binary table lookup
operation, where the second set of scalar Regev ciphertexts (transformed into
GSW ciphertexts) serving as indices are used to retrieve the ciphertext contain-
ing the user’s queried data. Finally, modulus switching is employed to further
reduce the ciphertext length before returning it to the user. Spiral has achieved a
throughput record of 1.9 GB/s on their platform, which corresponds to a version
that adopts packing trick under the streaming setting.
Enhanced PIR Scheme Combining SimplePIR and Spiral 301

Spiral demonstrates outstanding performance in terms of communication


overhead, requiring only 14 KB for the request size. However, it does not incor-
porate a precomputation step, resulting in less competitive server throughput
compared to other schemes.
In 2023, Henzinger et al. [12] showed their SimplePIR and DoublePIR
schemes. Their schemes are based on Regev encryption over Z_q. In SimplePIR,
the server-side database is organized as a matrix D ∈ Z_p^{√N×√N}. During the
precomputation phase, the database matrix is multiplied by a random matrix
A ∈ Z_q^{√N×n}. The result D · A serves as the client hint, which is downloaded
by the client. In the online phase, the client generates a random encryption key
and uses Regev encryption to encrypt each element of a one-hot vector with
dimension √N, creating a vector over Z_q as the query request qu. The server
performs the matrix-vector multiplication D · qu to obtain an online response. The
client then utilizes the offline-downloaded hint D · A and the online-obtained
response D · qu to recover a row of the database. The fundamental idea behind
DoublePIR is to perform SimplePIR twice, once on the database and once on the
hints generated during the precomputation, thereby reducing the amount of data
that the client needs to download offline. In the same year, the team published
Tiptoe [11], a private search engine, which adopted an enhanced SimplePIR pro-
tocol. This approach involves encrypting the client’s decryption private key using
a fully homomorphic encryption algorithm before uploading it to the server. By
enabling the server to perform the computation of client hints and the encrypted
private key, the issue arising from the client downloading hints is circumvented.
SimplePIR introduces a precomputation approach to improve server compu-
tation throughput and proposes DoublePIR to reduce the size of offline hints.
However, the hint size in DoublePIR scales linearly with the size of the database
entries. For databases with entries around 1 byte or smaller, the hint size in
DoublePIR is 16 MB. For larger database entries, the hint size increases further.
This results in excessively large communication overhead for DoublePIR when
dealing with databases containing larger entries.
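The SimplePIR flow sketched above (offline hint D · A, online answer D · qu) can be reproduced with toy LWE parameters. This is our own illustration, not the authors' code, and the parameters are nowhere near secure:

```python
import random

# Toy SimplePIR-style sketch: D is a dim x dim matrix over Z_p, the client
# downloads the hint D*A offline, and the online answer D*qu reveals one
# full column of D (a database row under the paper's orientation).
rng = random.Random(0)
q, p, n_lwe, dim = 2**32, 256, 64, 8       # illustrative, insecure parameters
Delta = q // p

D = [[rng.randrange(p) for _ in range(dim)] for _ in range(dim)]
A = [[rng.randrange(q) for _ in range(n_lwe)] for _ in range(dim)]

def matmul(M, V):                          # (rows x k) * (k x cols) over Z_q
    return [[sum(M[i][t] * V[t][j] for t in range(len(V))) % q
             for j in range(len(V[0]))] for i in range(len(M))]

hint = matmul(D, A)                        # offline: client downloads D*A

# Online: Regev-encrypt the one-hot selector u_{j*} under secret s.
j_star = 3
s = [[rng.randrange(q)] for _ in range(n_lwe)]
As = matmul(A, s)
e = [rng.randrange(-2, 3) for _ in range(dim)]     # small noise
qu = [[(As[i][0] + e[i] + Delta * (1 if i == j_star else 0)) % q]
      for i in range(dim)]

ans = matmul(D, qu)                        # server: D * qu
hs = matmul(hint, s)                       # client: (D*A) * s, from the hint
rec = [round(((ans[i][0] - hs[i][0]) % q) / Delta) % p for i in range(dim)]
print(rec == [D[i][j_star] for i in range(dim)])   # entries recovered exactly
```

Subtracting (D·A)·s from D·qu leaves Δ·D·u plus the small noise D·e, which rounding removes, so the hint is the only database-dependent material the client needs.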
In 2024, Li et al. [14] introduced Hintless PIR, which, in contrast to the
enhanced SimplePIR scheme in Tiptoe, clarifies that in a PIR scheme requir-
ing only linear computations, based on lattice-based hard problems, the ran-
dom components can be precomputed during operations such as homomorphic
multiplication, key switching, and rotation operation. As a result, although the
approach also involves encrypting the client’s decryption private key using a
fully homomorphic encryption algorithm before uploading it to the server, it
further reduces the client’s communication overhead. For an 8 GB database,
Hintless PIR reports a throughput of 6.37 GB/s on their platform. In the same
year, Menon and Wu proposed the YPIR [19], which also achieved the goal of
eliminating the need to download client hints by treating the output of the Dou-
blePIR as ciphertexts on a polynomial ring and further compressing multiple
ciphertexts into one through a packing algorithm on the polynomial ring. On an
8 GB database, YPIR reported a throughput of 11.6 GB/s on their platform.

1.2 Contributions
The contributions of our work align with those of previous studies in [11,14,19]
in that we aim to eliminate the client-side download step required in the design
of SimplePIR, achieving efficient PIR without user downloads. However, our
approach differs significantly from existing methods. Specifically, we build upon
the Spiral protocol and integrate the design principles of SimplePIR into Spiral,
enabling us to further enhance Spiral’s throughput and eliminate the need for
client-side hints downloads. In detail:
1. We modify the steps in Spiral that retrieves a row of entries from the
database, reducing the server’s computation time and consequently improving
the throughput of the PIR scheme.
2. We implement the modified scheme and compare its performance with Dou-
blePIR and the original Spiral on the same platform. Our results confirm
that the modified scheme outperforms Spiral in terms of throughput and,
for frequently updated databases, surpasses DoublePIR in terms of overall
communication overhead.

2 Preliminaries
For a probability distribution χ, we use x ←R χ to indicate that x is a random
sample from χ. For a finite set S, we use x ←R S to denote sampling x uniformly
at random from S. Let λ be the security parameter. Let p and q be the plaintext
and the ciphertext modulus, respectively. For a positive integer m, we use [m]
to represent the integer set {1, 2, . . . , m}. For integers a, b ∈ Z, we write [a, b]
to denote the set {a, a + 1, . . . , b}. For a positive integer q ∈ N, we write Z_q to
denote the integers modulo q. We mainly use bold uppercase letters to denote
matrices (e.g., A, B) and bold lowercase letters to denote vectors (e.g., u, v).
Let R = Z[x]/(x^d + 1) be the cyclotomic ring where d is a power of two. For
a positive integer q ∈ N, we write R_q = R/qR. Let ⌈·⌋ denote rounding to the
nearest integer, ⌊·⌋ the floor operation, and ⌈·⌉ the ceiling operation. Let ||·||_∞
denote the infinity norm.

2.1 Homomorphic Encryption Schemes


The ring learning with errors (RLWE) assumption is defined as follows, where
multiple RLWE samples have the same secret s.
Definition 1 (The RLWE assumption [4,16]). Let λ be a security parame-
ter, and let R = Z[x]/(xd + 1) where d = d(λ) is a power of two. Let m = m(λ)
be the number of samples, q = q(λ) be a ring modulus, and χ = χ(λ) be an error
distribution over Rq . The decisional RLWE assumption states that the following
two distributions are computationally indistinguishable:
   (a, as + e) ≈_c (a, u),

where s ←R χ, a ←R R_q^m, e ←R χ^m, and u ←R R_q^m.
The matrix Regev is defined as follows [18].

Definition 2 (Matrix Regev encryption [18]). Let λ be a security parameter,
and let R = Z[x]/(x^d + 1) where d = d(λ) is a power of two. Let n = n(λ)
be the dimension, χ = χ(λ) be an error distribution over R_q, and q = q(λ) be
the ciphertext modulus. The matrix encryption scheme for matrices M ∈ R_p^{n×n}
with a plaintext modulus p = p(λ) is defined as follows:

– KeyGen(1^λ, 1^n): On input the security parameter λ and the message
  dimension n, sample s̃ ←R χ^n and output S = [−s̃ | I_n]^T ∈ R_q^{(n+1)×n}.
– Encrypt(S, M): On input the secret key S = [−s̃ | I_n]^T and a matrix
  M ∈ R_p^{n×n}, sample a ←R R_q^n, E ←R χ^{n×n}, compute Δ = ⌊q/p⌋, and
  output the ciphertext

     C = [ a^T ; s̃a^T + E ] + [ 0^{1×n} ; Δ·M ] ∈ R_q^{(n+1)×n}.

C ∈ R_q^{(n+1)×n} is a Regev ciphertext of M ∈ R_p^{n×n} with error E ∈ R_q^{n×n} and
secret key S ∈ R_q^{(n+1)×n} if S^T C = Δ·M + E. To get M from Z = S^T C, one
could compute M′ = ⌈Z/Δ⌋. If ||E||_∞ + (q mod p) < Δ/2, then M′ = M.


The following operations are supported by Regev ciphertexts:

– C ← Add(C_1, C_2): On input C_1, C_2 ∈ R_q^{(n+1)×n} encrypting matrices M_1
  and M_2 with errors E_1 and E_2, output C = C_1 + C_2 ∈ R_q^{(n+1)×n} encrypting
  M_1 + M_2 with error E satisfying ||E||_∞ ≤ ||E_1||_∞ + ||E_2||_∞. If E_1 and E_2
  are independent and subgaussian with standard deviations σ_1 and σ_2, the
  standard deviation of E is σ = √(σ_1² + σ_2²).
– C ← ScalarMul(C_1, M_2): On input C_1 ∈ R_q^{(n+1)×n} encrypting matrix M_1
  with error E_1 and a plaintext matrix M_2 ∈ R_q^{n×n}, output C = C_1 · M_2 ∈
  R_q^{(n+1)×n} encrypting M_1 · M_2 with error E satisfying ||E||_∞ ≤ d · n ·
  ||E_1||_∞ · ||M_2||_∞. If E_1 is subgaussian with standard deviation σ_1, the
  standard deviation of E is σ = √(n · d) · ||M_2||_∞ · σ_1.

When n = 1, the Regev ciphertext is called a scalar Regev ciphertext. Assuming
the RLWE assumption holds, scalar Regev is clearly semantically secure. This
can be extended to the matrix Regev case as shown in Spiral [18].
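As a concrete illustration of Definition 2 in the scalar case (n = 1), the sketch below implements Regev encryption over R_q = Z_q[x]/(x^d + 1) with toy parameters. It is our own sketch, not the Spiral implementation, and the error and secret distributions are illustrative.

```python
import random

rng = random.Random(42)
d, q, p = 8, 2**20, 4                      # toy ring degree and moduli
Delta = q // p                             # Delta = floor(q/p)

def poly_mul(f, g):
    """Negacyclic multiplication in Z_q[x]/(x^d + 1)."""
    res = [0] * d
    for i, fi in enumerate(f):
        for j, gj in enumerate(g):
            k = i + j
            if k < d:
                res[k] = (res[k] + fi * gj) % q
            else:                          # x^d = -1 wraps with a sign flip
                res[k - d] = (res[k - d] - fi * gj) % q
    return res

def keygen():
    return [rng.randrange(-1, 2) % q for _ in range(d)]    # small ternary secret

def encrypt(s, m):
    """Scalar Regev: (a, a*s + e + Delta*m), coefficients of m in Z_p."""
    a = [rng.randrange(q) for _ in range(d)]
    e = [rng.randrange(-2, 3) for _ in range(d)]
    b = [(asi + ei + Delta * mi) % q
         for asi, ei, mi in zip(poly_mul(a, s), e, m)]
    return (a, b)

def decrypt(s, ct):
    """Compute z = b - a*s = Delta*m + e, then round to recover m."""
    a, b = ct
    z = [(bi - asi) % q for bi, asi in zip(b, poly_mul(a, s))]
    return [round(zi / Delta) % p for zi in z]

msg = [1, 2, 3, 0, 1, 2, 3, 0]
s = keygen()
ct = encrypt(s, msg)
print(decrypt(s, ct) == msg)               # rounding removes the small error
```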

2.2 PIR Definition

A single-server PIR with hint is defined as follows, adapted from [12].


Definition 3 (Single-server PIR with hint).

– Chint ← HintGen(D, sd): On input the public random seed sd and the
  database D, the HintGen algorithm outputs the preprocessed hint Chint on
  the server.
– (qu, sk) ← Query(1λ , sd , idx): On input the security parameter λ, the public
seed sd , an index idx, the query algorithm outputs a query qu and a secret
key sk.

– Cres ← Answer(D, qu, Chint): On input the database D, a query qu, and the
  hint Chint, the answer algorithm outputs the rescaled ciphertext Cres as the
  server response.
– d ← Extract(sk, Cres ): On input the secret key sk, the response Cres , the
extract algorithm outputs the database entry d ∈ D indexed by idx.

The PIR scheme should satisfy the following properties of correctness, secu-
rity, and efficiency.

Definition 4 (Correctness). For every λ ∈ N, all polynomials N = N(λ),


p = p(λ), n = n(λ), and all databases D = {d_1, . . . , d_N} where each d_i ∈ R_p^{n×n},
and idx ∈ [N], we require that

   Pr[ D[idx] = d :  Chint ← HintGen(D, sd),
                     (qu, sk) ← Query(1^λ, sd, idx),
                     Cres ← Answer(D, qu, Chint),
                     d ← Extract(sk, Cres) ] = 1,                    (1)

where the probability is taken over all possible sources of randomness employed
by the algorithms.
Definition 5 (Security). For every λ, N ∈ N, and idx0 , idx1 ∈ [N], define the
distribution

   D_{λ,N,idx} := { qu : (qu, sk) ← Query(1^λ, sd, idx) },           (2)

and for an adversary A, define the adversary’s advantage as

   PIRadv[A](λ, N) := max_{idx_0,idx_1∈[N]} | Pr[A(1^λ, D_{λ,N,idx_0}) = 1] − Pr[A(1^λ, D_{λ,N,idx_1}) = 1] |.   (3)

The PIR scheme is computationally secure if for every efficient adversary A, the
quantity PIRadv[A](λ, N) is a negligible function of λ and N.
Definition 6 (Communication Efficiency). A PIR scheme is efficient if the
total communication cost of a query is smaller than the size of the database D.

3 The Enhanced PIR Scheme


We use simulatable homomorphic ciphertexts to modify the row extraction pro-
cedure of Spiral. Damgård and Nielsen [8] first proposed the notion of simulatable
public key encryption. Informally, it is a traditional public key encryption sys-
tem with extra public key sampling and ciphertext generation algorithms. In
particular, ciphertexts produced by the additional algorithms are computation-
ally indistinguishable from those produced by a standard encryption algorithm.
This concept was initially applied in the security proof and recently has been
used to construct deniable encryption schemes [2]. We apply the concept to
homomorphic ciphertexts for PIR precomputations.

3.1 Simulatable Homomorphic Ciphertexts

The Spiral scheme relies on the matrix Regev encryption. We observe that matrix
Regev ciphertexts are simulatable. We define an algorithm CipherGen as follows.

– C ← CipherGen(sd, q, n): It takes a random seed sd, an encrypting modulus
  q, and a dimension parameter n, producing a ciphertext C ∈ R_q^{(n+1)×n}
  whose components are random polynomials in R_q seeded by sd with a
  pseudorandom function PRF.
To use a simulated ciphertext, we define a HalfDecrypt algorithm as follows:

– Z ← HalfDecrypt(S, C): On input a Regev secret key S = [−s̃ | I_n]^T and a
  simulated ciphertext C ∈ R_q^{(n+1)×n}, output the half-plaintext
  Z = ⌈S^T C / Δ′⌋ = ⌈(MΔ + E) / Δ′⌋ ∈ R_{q/Δ′}^{n×n}.

If Δ′ > 2||E||_∞, then M can be recovered from Z. In other words, given an input
key S, a randomly generated Regev ciphertext C allows the derivation of a
message M.
M.
A simulated ciphertext should be indistinguishable from a ciphertext gen-
erated by a normal encryption algorithm. To that end, we provide the game
GameA,Ψ (1λ , 1n ):

– The adversary A is given inputs λ and n, and outputs a plaintext M.


– A secret key S is generated by running KeyGen(1λ , 1n ) of the Regev encryp-
tion, and a random bit b ∈ {0, 1} is chosen. Ciphertexts C0 ← Encrypt(S, M)
and C1 ← CipherGen(sd , q, n) are computed respectively. The challenge
ciphertext Cb is given to A.
– A outputs a bit b′.
– The output of the experiment is defined to be 1 if b′ = b, and 0 otherwise.

The ciphertexts are simulatable if for every efficient adversary A there is a
negligible function negl such that, for all λ ∈ N,

   Pr[Game_{A,Ψ}(1^λ, 1^n) = 1] ≤ 1/2 + negl(λ),                     (4)
where the probability is taken over the randomness used by A and the random-
ness used in the GameA,Ψ (1λ , 1n ).
Obviously, if the RLWE assumption holds, then the Regev ciphertext is indis-
tinguishable from a ciphertext generated randomly according to a uniform distri-
bution. Assuming that the random numbers output by the PRF function follow
a uniform distribution, then the ciphertext generated by the CipherGen algo-
rithm is indistinguishable from a ciphertext generated randomly according to
a uniform distribution. Therefore, it can be concluded that the Regev cipher-
text is indistinguishable from the random ciphertext generated by the CipherGen
algorithm.

3.2 The PIR Scheme

Fig. 1. An overview of the PIR.

As depicted in Fig. 1, we introduce modifications to the row extraction phase of


the Spiral. Specifically, we incorporate an offline phase where the HintGen algo-
rithm is executed to generate a hint Chint . Subsequently, in the online phase,
a user utilizes the Query algorithm to formulate a query, which encapsulates a
scalar Regev and a half-plaintext component concealing row and column values
respectively. Upon receiving this query, the server executes the Answer algorithm
to produce a matrix Regev that encrypts the requested data. Lastly, though not
illustrated in the figure, the user then employs the Extract algorithm to decrypt
the matrix Regev and retrieve the desired data. Notably, our primary modifications
are focused on the upper half of Fig. 1, while the rest has undergone parametric
optimizations and adjustments as necessary. Below, we outline the four algorithms
that comprise this scheme.

HintGen Algorithm. Algorithm 1 outlines the core procedures of the HintGen.


Its primary inputs consist of the security parameter λ, a database D and a
publicly available seed sd. The database D = {d_1, . . . , d_N} comprises N = 2^{v_1+v_2}
entries, structured in Spiral format as a (v_2 + 1)-dimensional hypercube with
dimensions 2^{v_1} × 2 × · · · × 2. To locate elements within this database, two indexing

methods are applicable: (i, j_1, . . . , j_{v_2}), where i ∈ [0, 2^{v_1} − 1] and each j_k ∈ {0, 1}
for k ∈ {1, . . . , v_2}, or (i, j), where i ranges similarly and j ∈ [0, 2^{v_2} − 1]. Each
entry d_i in the database belongs to R_p^{n×n}, ensuring that ||d_i||_∞ ≤ p/2. The
public seed sd can be an arbitrary value, such as a concatenation of the server’s
name and a date, encoded as a bitstring. This seed can be periodically updated
according to a predefined schedule, allowing any user to compute it without
requiring direct communication with the server.

Algorithm 1: HintGen
Input: λ, the security parameter; D, the database; sd, the public seed;
Output: Chint, the result of processing the first (large) dimension of the
database.
1 for i = 0 to 2^{v_1} − 1 do
2     Creg[i] ← CipherGen(1^λ, sd||i);
3 for j = 0 to 2^{v_2} − 1 do
4     Chint[j] = ScalarMul(Creg[0], D[0][j]);
5     for i = 1 to 2^{v_1} − 1 do
6         Chint[j] = Add(Chint[j], ScalarMul(Creg[i], D[i][j]));
7 return Chint;

The output of HintGen is a ciphertext vector Chint with a dimensionality of
2^{v_2}. Lines 1–2 generate a vector of random ciphertexts, denoted as Creg, which
has a dimensionality of 2^{v_1}. Each ciphertext in Creg is produced by the CipherGen
algorithm using distinct seeds. The output, formed in lines 3–6, is constructed
as follows: as depicted in Fig. 1, each ciphertext in the final output is obtained
by performing an element-wise multiplication of the vector Creg with a column
vector from the database D, followed by summing the resulting products.

Query Algorithm. Algorithm 2 summarizes the process of producing a query.
Its primary inputs include the security parameter λ, a public seed sd, and the
index idx of the data to be retrieved, where the column value is represented
in binary. The output comprises a half-plaintext Z_reg containing the row
value associated with idx, a scalar Regev ciphertext c encrypting the column
value, and the conversion key ck that facilitates converting the Regev
ciphertext of the encrypted column value into a GSW ciphertext.
308 H. Xu et al.

Algorithm 2: Query
Input: λ, the security parameter; sd, the public seed;
idx = (i*, j*_1, ..., j*_{v2}), the query index.
Output: The query qu = (Z_reg, c, ck), a random secret key sk.
1 Sample a random secret key S ← KeyGen(1^λ, 1^n) as sk for both matrix Regev
  and GSW encryption;
2 for i = 0 to 2^{v1} − 1 do
3     C_reg[i] ← CipherGen(1^λ, sd||i);
4     Z_reg[i] ← HalfDecrypt(S, C_reg[i]);
5 Z_reg[i*] ← (Z_reg[i*] − ⌊Δ/Δ′⌋ · I_n) mod ⌊q/Δ′⌋;
6 Pack and encrypt the tuple (j*_1, ..., j*_{v2}) as in Spiral to produce a
  scalar Regev ciphertext c;
7 Generate the conversion key ck as in Spiral to translate the scalar Regev
  with s to the GSW with S;
8 Set qu ← (Z_reg, c, ck);
9 return qu and sk;

The first line of the algorithm generates a key shared by both matrix Regev
and GSW encryption. Lines 2–4 produce random ciphertexts and then decrypt
them to obtain a vector of half-plaintexts. Line 5 embeds the row value into
the half-plaintext, where Δ = ⌊q/p⌋ and Δ′ is a configurable parameter that
determines the noise bound introduced by the HalfDecrypt operation. Lines 6–7
correspond to the operations depicted in the lower half of Fig. 1, adhering
to the procedures outlined in Spiral, with the exception that the packing and
encrypting operations solely manipulate the column value, allowing for
parameter optimizations. Notably, the computation of the conversion key must
be performed during the online phase.
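The row-embedding step of line 5 can be illustrated with scalars in place of matrices; the concrete values of q, p, and Δ′ below are illustrative only (and Δ′ is chosen to divide Δ for simplicity), not the paper's parameters:

```python
# Scalar toy of line 5 of Algorithm 2: the row value is embedded by
# subtracting floor(Delta/Delta') from the half-plaintext, working modulo
# floor(q/Delta'). Parameters are illustrative; dprime stands for Delta'.
q, p, dprime = 2**28, 2**8, 2**16
Delta = q // p                     # plaintext scaling factor (2^20 here)
m = 123_456_789 % q                # a stand-in "decrypted" value

z = m // dprime                              # HalfDecrypt: drop low bits
z_star = (z - Delta // dprime) % (q // dprime)   # embed the row marker

# Rescaling by Delta' (line 1 of Algorithm 3) recovers m - Delta up to a
# noise of at most Delta' coming from the dropped low-order bits.
recovered = (z_star * dprime) % q
assert (m - Delta - recovered) % q < dprime
```

This is exactly the mechanism the correctness proof relies on: at the queried row the rescaled half-plaintext differs from the original value by Δ plus small noise, which later surfaces as the Δ·I_n term in the i* case.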

Answer Algorithm. The algorithm for answering a query is summarized in
Algorithm 3. The inputs to this algorithm include the database D, the user
query qu, and C_hint, which is generated during the server's offline phase.
The output of the algorithm is C_res, which is formatted in the same way as
in Spiral.

Algorithm 3: Answer
Input: D, the database; qu = (Z_reg, c, ck), the query; C_hint, the hint.
Output: The rescaled response C_res.
1 Let Z_reg ← Z_reg · Δ′;
2 for j = 0 to 2^{v2} − 1 do
3     Z_hint[j] ← Z_reg[0] · D[0][j];
4     for i = 1 to 2^{v1} − 1 do
5         Z_hint[j] ← Z_hint[j] + Z_reg[i] · D[i][j];
6     h_j ← C_hint[j] − [0 | Z_hint[j]^T]^T;
7 Using h, the column query c, and the conversion key ck, the Spiral
  procedures Coefficient Expansion, RegevToGSW, Folding and ModulusSwitch are
  sequentially applied, yielding a rescaled response denoted as C_res, as
  shown in Fig. 1;
8 return C_res;

The first line multiplies Z_reg by Δ′ to restore the elements of Z_reg back to
R_q. Following that, lines 2 through 5 perform the same element-wise
multiplication and addition operations as in the offline phase, with the
distinction that here the element-wise multiplication only requires computing
the product of two square matrices of dimension n. Line 6 calculates the
difference between the server's hint C_hint and the Z_hint generated during
the online phase; to align their dimensions, padding with a zero vector is
applied. Line 7 briefly outlines the process within Spiral, which is depicted
in the bottom-right section of Fig. 1.

Extract Algorithm. It is the same as the Extract algorithm in Spiral. The
secret key S generated in the Query algorithm is fed in as the decryption key.

3.3 Correctness and Security

As shown in Fig. 1, to establish correctness, it suffices to demonstrate that
for a user with a private key S, the vector h generated in line 7 of the
Answer algorithm is a valid matrix Regev encryption of the i*-th row of the
database. Here, i* is the input of the Query algorithm. This claim is
formalized in the following lemma:

Lemma 1. From the viewpoint of a user with a private key S, the vector h in
line 7 of the Answer algorithm constitutes a matrix Regev encryption of the
i*-th row of the database, with a noise bounded by O(2^{v1} · d · n · p · Δ′),
where i* is the input of the Query algorithm.

Proof. Let D = {d_1, ..., d_N} be the database, where d_i ∈ R_p^{n×n}, and
take any index idx = (i*, j*_1, ..., j*_{v2}). Sample C_hint ← HintGen(D, sd)
and let (qu, sk) ← Query(1^λ, sd, idx), C_res ← Answer(D, qu, C_hint). We
note that qu = (Z_reg, c, ck) and sk = S. We analyze the vector h from the
simulated matrix Regev.
By definition, C_reg[i] ∈ R_q^{(n+1)×n} for any i ∈ {0, ..., 2^{v1} − 1} is a
simulated matrix Regev ciphertext. From the viewpoint of a user with private
key S = [−s̃ | I_n]^T, the random ciphertext looks like

    C_reg[i] = [ a_i^T ; s̃ a_i^T + E_i ] + [ 0^{1×n} ; M_i ] ∈ R_q^{(n+1)×n},

where the semicolon separates the two block rows and 2||E_i||_∞ < Δ′.


Then in the Query algorithm, a half-plaintext Z_reg[i] for any
i ∈ {0, ..., 2^{v1} − 1} is defined as

    Z_reg[i] = ⌊ S^T C_reg[i] / Δ′ ⌋ = ⌊ M_i / Δ′ ⌋ ∈ R_{⌊q/Δ′⌋}^{n×n}.

The Query algorithm encodes the index i* as

    Z_reg[i*] = ⌊ M_{i*} / Δ′ ⌋ − ⌊ Δ/Δ′ ⌋ · I_n
              ∈ [ ⌊ (M_{i*} − Δ·I_n) / Δ′ ⌋ − 1, ⌊ (M_{i*} − Δ·I_n) / Δ′ ⌋ + 1 ].

Now consider h_j for any j ∈ {0, ..., 2^{v2} − 1}; it is the difference
between the sum of the set {C_reg[i] · D[i][j] | i ∈ {0, ..., 2^{v1} − 1}}
computed in the HintGen algorithm and the sum of the set
{Δ′ · Z_reg[i] · D[i][j] | i ∈ {0, ..., 2^{v1} − 1}} computed in the Answer
algorithm. For any i ≠ i*,

    C_reg[i] · D[i][j] − Δ′ · [0 | Z_reg[i]^T]^T · D[i][j]
        = (C_reg[i] − Δ′ · [0 | Z_reg[i]^T]^T) · D[i][j].

From the viewpoint of the user with private key S = [−s̃ | I_n]^T,

    (C_reg[i] − Δ′ · [0 | Z_reg[i]^T]^T) · D[i][j]
        = [ a_i^T ; s̃ a_i^T + E_i′ ] · D[i][j],

where ||E_i′||_∞ ≤ ||E_i||_∞ + 0.5Δ′ < Δ′.


However, when i = i*, it is

    (C_reg[i*] − Δ′ · [0 | Z_reg[i*]^T]^T) · D[i*][j]
        = ( [ a_{i*}^T ; s̃ a_{i*}^T + E_{i*}′ ] + [ 0^{1×n} ; Δ·I_n ] ) · D[i*][j],

where ||E_{i*}′||_∞ ≤ ||E_{i*}||_∞ + 2Δ′ < 2.5Δ′.


Then it is clear that h_j is a matrix Regev encryption of D[i*][j] from the
viewpoint of the special user with the special private key S. So h in line 7
of the Answer algorithm constitutes a matrix Regev encryption of the i*-th
row of the database.
From the viewpoint of the user with the private key S, h_j is a sum of
ScalarMul results. The noise in each h_j is at most
d · n · (2.5Δ′) · (2^{v1} · p/2), which is bounded by O(2^{v1} · d · n · p · Δ′).
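Plugging the parameters reported in Sect. 4 into the bound of Lemma 1 gives a quick numerical sanity check; comparing against a decryption margin of Δ/2 = q/(2p) is our reading of the usual matrix-Regev correctness requirement, not a statement from the paper:

```python
# Evaluate the Lemma 1 noise expression d*n*(2.5*Delta')*(2^v1 * p/2) with
# the Sect. 4 parameters, and compare it to Delta/2 = q/(2p) (assumed to be
# the decryption margin).
v1, d, n = 9, 2048, 2
p, q, dprime = 2**8, 2**56, 2**16      # log p = 8, log q = 56, Delta' = 2^16

noise = d * n * int(2.5 * dprime) * (2**v1 * p // 2)
margin = q // (2 * p)                  # Delta / 2 = 2^47

assert noise < margin                  # 2.5 * 2^44 < 2^47: comfortable fit
```

So with the reported parameter set, the worst-case noise of Lemma 1 sits about a factor of three below this margin.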

Similar to the correctness property, for the security property we only need
to consider the privacy of the row value. Based on Definition 6, we present
the following lemma.

Lemma 2. If the CipherGen algorithm is modeled as a random oracle and the


matrix Regev is simulatable, then the enhanced PIR scheme is computationally
secure.

Proof. Let A be an efficient adversary, and CipherGen a random oracle. Given
a public seed sd, a counter i, the matrix Regev modulus q and a dimension
parameter n, this oracle returns a ciphertext C. Our goal is to prove that
the view of D_{λ,N,idx_0} is computationally indistinguishable from that of
D_{λ,N,idx_1}. We will demonstrate this through a sequence of games. We use
Z_reg[i*_0] to denote the row encoding of idx_0, and Z_reg[i*_1] to denote
the row encoding of idx_1.

– Game 0: This game follows the Query algorithm with inputs (λ, sd, idx_0).
  However, since CipherGen is modeled as a random oracle, random ciphertexts
  must be obtained by querying the oracle. The oracle simply runs the
  CipherGen algorithm to get a response.

– Game 1: This game is identical to Game 0, except that the oracle operates
  differently. The oracle gets a secret key S ← KeyGen(1^λ, 1^n). Then, for
  an oracle query (sd||i, q, n), it samples a random message
  Z_{reg,i} ∈ R_{⌊q/Δ′⌋}^{n×n} and returns a ciphertext
  C_reg[i] ← Encrypt(S, Z_{reg,i}), with the restriction that Δ′ is used
  instead of Δ. Due to the simulatability of matrix Regev, this game is
  computationally indistinguishable from Game 0.
– Game 2: This game is identical to Game 1 except that in the Query
  algorithm, Z_reg[i*_0] = Z_{reg,i*_0}. That is, the diagonal elements are
  not changed at all. Note that in Game 1,
  Z_reg[i*_0] = Z_{reg,i*_0} − ⌊Δ/Δ′⌋ · I_n mod ⌊q/Δ′⌋. Since ⌊Δ/Δ′⌋ · I_n
  is constant, the distribution of Z_reg[i*_0] is identical to that of
  Z_{reg,i*_0}. Then, this game is indistinguishable from Game 1.
– Game 3: This game is identical to Game 2 except that in the Query
  algorithm, Z_reg[i*_1] = Z_{reg,i*_1} − ⌊Δ/Δ′⌋ · I_n mod ⌊q/Δ′⌋. Due to
  the identical distribution of Z_reg[i*_1] and Z_{reg,i*_1}, Game 3 is
  computationally indistinguishable from Game 2.
– Game 4: This game follows the Query algorithm with inputs (λ, sd , idx1 )
and the oracle runs the CipherGen algorithm. Similar to Game 1, this game
is computationally indistinguishable from Game 3.

Now it is clear that the view of D_{λ,N,idx_0} is computationally
indistinguishable from the view of D_{λ,N,idx_1}.

The communication efficiency property will be clearly shown in the next


performance analysis section.

Fig. 2. Comparison of the amortized communication between our scheme and
DoublePIR, illustrating the relationship between the number of queries per
update and the corresponding amortized communication. The update occurs once
for every specified number of queries, and the graph shows the resulting
communication overhead during this period, for a 2^25 × 256 B database.

4 Performance Analysis
Based on the code of Spiral, we implemented the improved PIR scheme and
conducted performance tests focusing on communication and throughput. The
test platform was equipped with an Intel Xeon(R) Gold 5218R CPU, 512 GB of
RAM, and Ubuntu 22.04 as the operating system. The PIR scheme was primarily
implemented in the C and Go languages, with Clang 12, Go 1.18.1, and GCC 11.4
as the compilers. Both our scheme and the comparative schemes leveraged SIMD
instruction sets when possible. The performance data were obtained under
single-threaded conditions, and the results presented in the paper are
averages taken over five or more runs.
For the comparison, we employed the Spiral scheme and DoublePIR. Both of
these schemes return to the client only the content that the client wishes to
query, excluding any additional values. This feature is vital for PIR schemes
that prioritize server privacy. In terms of database configurations, we
utilized two databases resembling Spiral's architecture: one is a 256 MB
database with an entry size of 256 B, and the other is an 8 GB database with
an entry size of 32 KB. However, DoublePIR's offline download volume proved
excessively high when the number of database entries is relatively small.
Consequently, we introduced an additional 8 GB database with an entry size of
256 B, which effectively reduced the offline download volume to less than the
database's capacity. We note that schemes such as HintlessPIR and YPIR
[14, 19] are built on DoublePIR and share the same limitation, thus we have
not compared their performance. Note that YPIR claims to support large entry
sizes based on SimplePIR, but this approach leaks more information to clients
than they originally queried.
The comparison results are presented in Table 1. For each scheme, we compare
them across several dimensions: the user’s offline download volume, the size of
online query and answer, server answer time, and server throughput. The size of
online query and answer represents the communication overhead of the schemes,
while the server throughput and answer time reflect the computational load on
the server side.
Table 1 clearly demonstrates the following facts:

– Under the three database configurations, the throughput of the enhanced
  scheme outperforms that of the Spiral scheme. Specifically, for the 256 MB
  database, the throughput of the improved scheme is 1.87 times that of the
  Spiral scheme.
– Compared to Spiral, our scheme requires the client to send more data per
  query, thereby imposing a requirement on network bandwidth. Considering
  the server's answer time, our scheme only gains a practical advantage when
  the network bandwidth exceeds 27 MB/s.

Table 1. Comparison results of Spiral [18], DoublePIR [12], and our scheme.

| Database              | Metric      | Spiral   | DoublePIR | Ours     |
|-----------------------|-------------|----------|-----------|----------|
| 2^20 × 256 B (256 MB) | Download    | 0        | 3.5 GB    | 0        |
|                       | Query       | 14 KB    | 256 KB    | 21 MB    |
|                       | Answer      | 21 KB    | 7 MB      | 21 KB    |
|                       | Answer Time | 1.65 s   | 0.057 s   | 0.88 s   |
|                       | Throughput  | 155 MB/s | 4.4 GB/s  | 291 MB/s |
| 2^25 × 256 B (8 GB)   | Download    | 0        | 3.6 GB    | 0        |
|                       | Query       | 14 KB    | 345 KB    | 21 MB    |
|                       | Answer      | 21 KB    | 7 MB      | 21 KB    |
|                       | Answer Time | 33.75 s  | 1.61 s    | 30.48 s  |
|                       | Throughput  | 243 MB/s | 5.0 GB/s  | 269 MB/s |
| 2^18 × 32 KB (8 GB)   | Download    | 0        | 462 GB    | 0        |
|                       | Query       | 14 KB    | 341 KB    | 21 MB    |
|                       | Answer      | 84 KB    | 925 MB    | 84 KB    |
|                       | Answer Time | 22.67 s  | 1.75 s    | 20.33 s  |
|                       | Throughput  | 361 MB/s | 4.6 GB/s  | 403 MB/s |

– For our enhanced scheme, applied to real-time systems requiring continuous
  dynamic updates, the communication overhead remains relatively stable.
  When the database is updated once for every 200 queries on average, the
  amortized communication of our scheme outperforms that of DoublePIR, as
  shown in Fig. 2. Specifically, for a database configured as 2^25 × 256 B
  (8 GB), DoublePIR requires downloading 3.6 GB of data whenever the
  database entries are updated, whereas our scheme eliminates the need for
  such downloads entirely. This means the client in DoublePIR effectively
  downloads the entire database once every two updates, making our enhanced
  scheme significantly more practical in terms of amortized communication.
– Compared to DoublePIR, our scheme achieves lower communication overhead
  when the database entries are larger. This is mainly reflected in the
  offline download size for large entries. In scenarios with larger database
  entries, the offline download size required by DoublePIR increases
  linearly and may even exceed the size of the database itself. For example,
  with a database configured as 2^20 × 256 B (256 MB), DoublePIR requires an
  offline download of 3.5 GB, whereas our scheme requires no data downloads
  during the offline phase, as shown in Table 1 and Fig. 3.

Fig. 3. The behavior of DoublePIR and our scheme when processing a database
with 2^25 entries, and how the offline download size changes with the entry
size.

Lastly, we present some operational parameters for Spiral and the improved
scheme. The improved scheme employs parameters identical to those of the
Spiral scheme, with the additional introduction of Δ′ = 2^16. For the three
database configurations, the number of plaintext modulus bits is log p = 8,
the number of ciphertext modulus bits is log q = 56, the polynomial degree is
d = 2048, and the matrix dimension of the data is n = 2. Across the
configurations, the first dimension is v1 = 9, while the second dimension v2
takes the values 6, 11, and 9, respectively. The DoublePIR parameters for the
three database configurations share log q = 32 and n = 1024. Across the
configurations, the number of plaintext modulus bits is about 10, 9, and 9
bits, and the database dimensions analogous to (v1, v2) are (11, 16),
(16, 16), and (16, 16).
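As a sanity check on these parameters (our derivation, not the paper's): one R_p^{n×n} plaintext carries n²·d·log p bits, so each hypercube slot holds 8 KB, and the two 256 B-entry configurations pack exactly 32 entries per slot; the 32 KB-entry configuration spans several plaintexts per entry and is not covered by this simple packing check.

```python
# Capacity check: a single R_p^{n x n} plaintext holds n^2 * d * log2(p)
# bits with n = 2, d = 2048, log p = 8, i.e. 8 KB per hypercube slot.
n, d, log_p, v1 = 2, 2048, 8, 9

slot_bytes = n * n * d * log_p // 8          # 8192 B per plaintext slot
assert slot_bytes == 8192

# The two 256 B-entry configurations: (v2, #entries, entry size).
for v2, entries, entry_bytes in [(6, 2**20, 256), (11, 2**25, 256)]:
    slots = 2**(v1 + v2)
    assert slots * slot_bytes == entries * entry_bytes   # total sizes agree
    assert entries // slots == 32                        # entries per slot
```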

5 Conclusion

In this paper, we introduce the precomputation of SimplePIR into Spiral and
implement it based on simulatable homomorphic ciphertexts, providing detailed
precomputation, online query, and answer algorithms. We analyze the
correctness and security of these algorithms and compare them with existing
schemes. Our comparison results show that when the database entry size is
large, the DoublePIR approach suffers from communication-efficiency issues,
and under such circumstances the improved Spiral scheme achieves better
throughput. The primary limitation of our scheme lies in the large amount of
data uploaded by the client for each request, which necessitates a
high-bandwidth scenario to achieve an overall reduction in query time. For
future work, we aim to utilize the packing technique of Spiral in a streaming
setting to achieve better throughput.

Acknowledgement. This work is supported by the National Key R&D Program of


China under Grant No. 2022YFB2701500 and Guangdong Provincial Key Laboratory
of Information Security Technology (No.2023B1212060026).

References
1. Ali, A., et al.: Communication-computation trade-offs in PIR. In: USENIX
   Security, pp. 1811–1828 (2021)
2. An, Z., Tian, H., Chen, C., Zhang, F.: Deniable cryptosystems: simpler
   constructions and achieving leakage resilience. In: ESORICS, pp. 24–44.
   Springer (2023)
3. Angel, S., Chen, H., Laine, K., Setty, S.: PIR with compressed queries and
   amortized query processing. In: IEEE S&P, pp. 962–979 (2018)
4. Applebaum, B., Cash, D., Peikert, C., Sahai, A.: Fast cryptographic
   primitives and circular-secure encryption based on hard learning problems.
   In: CRYPTO, pp. 595–618. Springer (2009)
5. Cachin, C., Micali, S., Stadler, M.: Computationally private information
   retrieval with polylogarithmic communication. In: EUROCRYPT, pp. 402–414.
   Springer (1999)
6. Chang, Y.C.: Single database private information retrieval with
   logarithmic communication. In: ACISP, pp. 50–61. Springer (2004)
7. Chor, B., Goldreich, O., Kushilevitz, E., Sudan, M.: Private information
   retrieval. In: FOCS, pp. 41–45 (1995)
8. Damgård, I., Nielsen, J.B.: Improved non-committing encryption schemes
   based on a general complexity assumption. In: CRYPTO, pp. 432–450.
   Springer (2000)
9. Davidson, A., Pestana, G., Celi, S.: FrodoPIR: simple, scalable,
   single-server private information retrieval. Proc. Priv. Enhanc. Technol.
   (2023)
10. Gentry, C., Ramzan, Z.: Single-database private information retrieval
    with constant communication rate. In: ICALP, pp. 803–815. Springer (2005)
11. Henzinger, A., Dauterman, E., Corrigan-Gibbs, H., Zeldovich, N.: Private
    web search with Tiptoe. In: SOSP, pp. 396–416 (2023)
12. Henzinger, A., Hong, M.M., Corrigan-Gibbs, H., Meiklejohn, S.,
    Vaikuntanathan, V.: One server for the price of two: simple and fast
    single-server private information retrieval. In: USENIX Security,
    pp. 3889–3905 (2023)
13. Kushilevitz, E., Ostrovsky, R.: Replication is not needed: single
    database, computationally-private information retrieval. In: FOCS,
    pp. 364–373 (1997)
14. Li, B., Micciancio, D., Raykova, M., Schultz-Wu, M.: Hintless
    single-server private information retrieval. In: CRYPTO, pp. 183–217.
    Springer (2024)
15. Lipmaa, H.: An oblivious transfer protocol with log-squared
    communication. In: ISC, pp. 314–328. Springer (2005)
16. Lyubashevsky, V., Peikert, C., Regev, O.: On ideal lattices and learning
    with errors over rings. In: EUROCRYPT, pp. 1–23. Springer (2010)
17. Melchor, C.A., Barrier, J., Fousse, L., Killijian, M.O.: XPIR: private
    information retrieval for everyone. Proc. Priv. Enhanc. Technol.,
    155–174 (2016)
18. Menon, S.J., Wu, D.J.: Spiral: fast, high-rate single-server PIR via FHE
    composition. In: IEEE S&P, pp. 930–947 (2022)
19. Menon, S.J., Wu, D.J.: YPIR: high-throughput single-server PIR with
    silent preprocessing. In: USENIX Security, pp. 5985–6002 (2024)
20. Mughees, M.H., Chen, H., Ren, L.: OnionPIR: response efficient
    single-server PIR. In: ACM CCS, pp. 2292–2306 (2021)
21. Park, J., Tibouchi, M.: SHECS-PIR: somewhat homomorphic encryption-based
    compact and scalable private information retrieval. In: ESORICS,
    pp. 86–106. Springer (2020)
A Two-Stage Image Blind Inpainting Algorithm
Based on Gated Residual Connection

Yiting Pan and Xuefeng Zhang(B)

Xi’an University of Posts and Telecommunications, Xi’an 710121, China


[email protected], [email protected]

Abstract. To address the problem that image inpainting methods require both a
damaged image and a mask image, which limits their usage scenarios, a
two-stage image blind inpainting algorithm based on gated convolutional
residual blocks is proposed. The algorithm consists of two modules: a mask
prediction module and an image inpainting module. The mask prediction module
predicts potentially visually inconsistent regions in a given image and
outputs a predicted mask; the image inpainting module repairs the damaged
area of the image according to the predicted mask and the context of the
original image. To effectively utilize the contextual information of the
image, the algorithm adopts gated convolution in the residual blocks and
introduces a discriminator in the mask prediction module, effectively
improving the accuracy of the predicted masks. Experimental results on the
Places2 and CelebA-HQ datasets show that the image inpainting performance of
the proposed algorithm is better than that of the comparison methods, and
reliable repaired images are successfully generated.

Keywords: Image Blind Inpainting · Gated Convolution · Convolutional Neural


Networks · Residual Block · Generative Adversarial Networks

1 Introduction
With the increasing popularity of multimedia applications involving digital
images, video, and audio, information processing technologies for such
multimedia data are also becoming increasingly widespread.
The purpose of image inpainting is to restore damaged or missing image informa-
tion, generate visually continuous content, or achieve local modification and removal of
interference through image segmentation. With the continuous development of computer
and image processing technology, digital image inpainting technology has become an
emerging field. A large number of digital image inpainting methods [1–4] have been
proposed by researchers, which can improve the image inpainting effect to a certain
extent and enhance the efficiency of image inpainting.
Traditional image inpainting methods rely on mathematical and physical
theories to ensure image content similarity and texture consistency. These
methods accomplish the inpainting of small damaged areas by establishing
geometric models or using texture synthesis. Traditional image inpainting
methods are categorized into two types: patch-based image inpainting methods
[1, 2] and diffusion-based image inpainting methods [3, 4]. Traditional
methods are effective for images with rich texture and small damaged areas.
However, when dealing with complex and extensively missing images, their
effectiveness is not ideal.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 316–331, 2025.
https://doi.org/10.1007/978-981-96-4566-4_22
In recent years, with the deepening of research on Convolutional Neural
Networks (CNN) and Generative Adversarial Networks (GAN), image inpainting
methods based on deep learning have become a research hotspot [5]. CNNs [6–9]
and GANs [10] have shown strong processing capabilities and good results in
image feature learning and representation, and have been widely used in
large-scale image processing.
Deep learning image inpainting methods learn image semantics from large-scale
datasets and generate new pixels in damaged areas without relying on the
original image. According to whether a mask is required as input, deep
learning image inpainting methods can be divided into two types: non-blind
and blind image inpainting methods. When repairing damaged images with
non-blind methods, the damaged areas must be masked accurately. Drawing masks
is a complex task, and the accuracy of the masks directly impacts the
inpainting quality, making such methods less practical.
Therefore, this article proposes a two-stage blind inpainting algorithm to
solve the problem of limited usage scenarios caused by mask drawing in image
inpainting methods. The algorithm consists of two stages: a mask prediction
module and an image inpainting module. The mask prediction module contains
seven residual blocks and one convolutional block, while retaining bottleneck
information in the network, which can provide more effective information for
the next module. To improve the accuracy of the predicted masks, a
discriminator was introduced to evaluate them. The predicted mask and the
damaged image are then input into the image inpainting module together;
after 14 residual blocks, the repaired image is obtained.

1.1 Related Work

Patch-based methods search for the best matching patch in the visible area
of the image or in external datasets, and then copy this information to the
missing area to fill it. Criminisi et al. [1] proposed a patch-based region
filling and object removal algorithm for eliminating unwanted objects in
images. A similar patch search algorithm was proposed by Barnes et al. [2],
which randomly matches patch blocks in an image and continuously optimizes
these matches to achieve structural image editing.
Diffusion-based methods are also called variational partial differential
equation (PDE) based image inpainting algorithms. In these methods, the
image inpainting problem is transformed into a variational problem solved
via functional analysis, by establishing prior and data models of the image
and applying functional analysis and total variation theory. Bertalmio et
al. [3] proposed a PDE-based diffusion repair method, which considers the
contextual information of missing parts in the image, predicts damaged areas
using pixels with known surrounding information, and performs image
inpainting based on the smoothness of local information. Chan et al. [4]
proposed a total variation regularization method that achieves image
inpainting by minimizing the total variation of the image.

Among non-blind image inpainting methods, Pathak et al. [11] proposed the
Context Encoder (CE), which can generate new pixels compared to traditional
methods but can easily lead to excessive smoothing or edge loss. To address
the problem that traditional convolution treats all input pixels or features
as equally valid, the concept of gated convolution was proposed by Yu et al.
[12]. The core idea of gated convolution is to let the network learn the
mask automatically, enabling it to clearly distinguish between valid and
invalid pixels. Even in deep networks, gated convolution can learn
independently in each channel, highlighting masked areas and sketch
information, and more effectively generating repair results.
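The gating idea can be sketched as a minimal 1-D toy, assuming the common formulation output = tanh(feature) ⊙ sigmoid(gate); the kernels and signal below are illustrative, not the layers used in the paper:

```python
# Minimal 1-D gated convolution in the spirit of Yu et al. [12]: a feature
# branch and a gating branch share the input, and the sigmoid gate decides,
# per position, how much of the feature passes through.
import math

def conv1d(x, w):
    """Valid-mode 1-D convolution (really correlation) of x with kernel w."""
    k = len(w)
    return [sum(w[t] * x[i + t] for t in range(k)) for i in range(len(x) - k + 1)]

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gated_conv1d(x, w_feat, w_gate):
    feat = conv1d(x, w_feat)
    gate = conv1d(x, w_gate)
    # element-wise: activation(feature) * sigmoid(gate)
    return [math.tanh(f) * sigmoid(g) for f, g in zip(feat, gate)]

x = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]        # toy signal with a "valid" region
y = gated_conv1d(x, w_feat=[0.5, 0.5], w_gate=[4.0, 4.0])
assert len(y) == len(x) - 1
assert all(abs(v) <= 1.0 for v in y)       # tanh * sigmoid stays within (-1, 1)
```

Because the gate is learned per output position (and per channel in the 2-D case), positions whose receptive field is dominated by masked pixels can be suppressed toward zero, which is what distinguishes gated from ordinary convolution here.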
Blind image inpainting methods [13–26] only require the damaged image as
input, omitting the mask-drawing step and achieving automatic inpainting of
visual content. Wang et al. [13] proposed a Visual Consistency Network (VCN)
that predicts semantically inconsistent regions in an image and generates
corresponding damaged-area masks; spatial normalization is then used to
repair the predicted masked and damaged areas, reducing the complex step of
mask drawing. Zhao et al. [16] proposed a blind image inpainting method
using a single-stage Transformer-CNN hybrid autoencoder. In addition, a
Cross-layer Dissimilarity Prompt (CDP) method was designed to accelerate the
identification and repair of contaminated areas.

1.2 Contributions

The contributions of this article are as follows:

1. It solves the problem that traditional image inpainting methods require
   both damaged and mask images, thus breaking through the limitation of
   usage scenarios.
2. Introducing gated convolution into the residual blocks makes full use of
   the contextual information of the image and improves the inpainting
   effect.
3. Adding a discriminator to the mask prediction module significantly
   improves the accuracy of the predicted masks.
4. Experimental results show that the inpainting performance of this
   algorithm is superior to existing comparison methods on the Places2 and
   CelebA-HQ datasets, successfully generating high-quality and reliable
   repaired images.

2 Preliminaries

CNN is a multi-layer neural network structure mainly used in computer vision tasks,
such as image classification, object detection, and image segmentation. The fundamental
components of a CNN encompass the data input layer, convolutional layer, activation
function layer, pooling layer, and fully connected layer.
The workflow of CNN is as follows: Firstly, the original image is transformed into a
data form with three dimensions: length, width, and number of channels through the data
input layer. Then, the image data undergoes multiple convolutions, activation functions,
and pooling operations to extract the feature information of the image. Finally, the
extracted feature maps are input into the fully connected layer for final classification or
other tasks.

GAN, a deep learning model, revolves around the core concept of generating high-
quality samples through the interplay of two competing neural networks: a generator
and a discriminator. The training process improves the quality of samples generated
by the generator through competition between the generator and the discriminator. The
flowchart of GAN is shown in Fig. 1.

Fig. 1. GAN flowchart.

3 Materials and Methods


Building on the preceding analysis, a two-stage image blind inpainting
algorithm based on gated convolutional residual blocks is proposed. This
algorithm adopts joint training of a GAN and a CNN to generate more accurate
inpainting results. The whole network structure includes two discriminators,
aiming to improve the accuracy of both the predicted masks and the final
inpainting results. By introducing residual blocks, an efficient deep neural
network can be trained: this structure allows the input to propagate forward
faster through residual connections between layers, effectively mitigating
the problem of vanishing or exploding gradients.
The network architecture adopted by this algorithm is shown in Fig. 2.

Fig. 2. GAN and CNN joint training model.



3.1 Residual Block

It is commonly held that the learning capacity of image inpainting models
can be enhanced by increasing the number of convolutional blocks. However,
in practical applications, increasing the number of model layers does not
always improve performance; instead, it may cause problems. When the neural
network becomes very deep, the gradients in backpropagation may become very
small (gradient vanishing) or very large (gradient explosion), causing
weight updates to become very small or very large and making it difficult
for the model to learn effective representations. In addition, as the number
of layers increases, the number of model parameters grows significantly,
leading to higher memory utilization and more expensive training.
These challenges of deep neural networks are effectively mitigated by the
residual connections [27] introduced within residual blocks [28]. These
connections facilitate direct information transfer between layers, enabling
information to skip certain layers and be conveyed directly from input to
output. No matter how many layers the model has, the layers that learn well
are retained, while poorly performing layers are effectively skipped through
the residual connections. Using residual blocks therefore never performs
worse than before, so the model's performance can be steadily improved by
increasing the number of layers. The residual block structure is shown in
Fig. 3(b).
The formula for the residual block is as follows:

    f(x) = g(x) + x    (1)

In Eq. (1), f(x) is the output of the residual block, g(x) is the output of
the two convolutions inside the residual block, and x is the input.

Fig. 3. Convolutional block structure. In the convolutional block structure, (a) represents a regular
convolutional block, (b) represents a residual block, and (c) represents a residual block with
different input and output channels. In cases where the number of input and output channels is
different, a 1 × 1 convolution kernel is introduced, which has the same number of convolution
kernels as the number of output channels to ensure consistency in the number of channels.
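As a concrete illustration of Eq. (1), the sketch below implements a minimal residual block in NumPy. For brevity the two convolutions inside the block are 1 × 1 convolutions (per-pixel channel mixing) and the activation is ReLU; these choices, and the layer sizes, are illustrative assumptions rather than the paper's exact configuration.

```python
import numpy as np

def conv1x1(x, W):
    """1x1 convolution: x is (C_in, H, W), W is (C_out, C_in)."""
    return np.tensordot(W, x, axes=([1], [0]))

def residual_block(x, W1, W2, act=lambda z: np.maximum(z, 0.0)):
    """f(x) = g(x) + x, where g(x) is two convolutions (Eq. (1))."""
    g = conv1x1(act(conv1x1(x, W1)), W2)
    return g + x

# With zero-initialized convolutions, g(x) = 0 and the block reduces to the
# identity mapping -- a poorly performing layer is effectively skipped.
x = np.random.default_rng(0).normal(size=(8, 4, 4))
W1 = np.zeros((8, 8))
W2 = np.zeros((8, 8))
out = residual_block(x, W1, W2)
```

The zero-weight case makes the "no worse than before" argument concrete: the skip path guarantees the block can always fall back to the identity.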
A Two-Stage Image Blind Inpainting Algorithm 321

3.2 Mask Prediction Module


The main task of the mask prediction module is to locate discontinuous areas in the image
and output them in the form of masks. Since the quality of the repair depends on it,
the predicted mask must be highly accurate. The residual block supplements information
through skip connections that realize an identity mapping, which not only helps to
alleviate vanishing and exploding gradients but also effectively prevents model collapse.
In the subsequent stage of the network, a WGAN-GP discriminator [29] is introduced.
It ensures that the mask predicted by the mask prediction module locates and describes
discontinuous areas in the image more accurately, leading to more precise, higher-quality
repair results.
In a GAN [7], the goal is to reach a Nash equilibrium between the generator and the
discriminator. However, as training progresses, the discriminator improves and learns to
distinguish the generator's output from real samples. When the discriminator is fixed
and the generator continues to train, the generator's loss approaches a fixed value that
reflects the distance between the generated and true data distributions; the generator's
gradient then becomes zero and it cannot be optimized further. Discriminator training
must therefore stop at an intermediate state, which is difficult to control and varies across
models, so the output performance of the model cannot be reliably guaranteed.
Wasserstein GAN (WGAN) addresses this instability by introducing Wasserstein
distance, which provides a smoother, more continuous measure of the difference between
generated and real data distributions. To optimize this distance, WGAN-GP [18] adds
a Gradient Penalty (GP) term to maintain Lipschitz continuity in the discriminator,
preventing drastic changes and stabilizing training, improving the performance of the
generator. The discriminator structure is shown in Fig. 4.

Fig. 4. WGAN structure diagram

The WGAN-GP loss function is as follows:


Lgen = −Ez∼p(z) [D(G(z))] (2)

Ldisc = Ex∼pdata [D(x)] − Ez∼p(z) [D(G(z))] + λgp Ex̂∼px̂ [(||∇x̂ D(x̂)||2 − 1)2 ] (3)

In Eqs. (2) and (3), x̂ represents the linear interpolation between the generated image
and the real image, λgp is the weight of the gradient penalty, with a value of 10, and px̂
represents the distribution of uniform sampling between generated and real samples.
322 Y. Pan and X. Zhang

The adversarial loss comprises the generator loss and the discriminator loss; the
overall goal is to establish a game between the generator and the discriminator, so that
the generator learns to produce more realistic samples while the discriminator learns to
distinguish generated samples from real ones more accurately, namely:

Ladv = Lgen + Ldisc (4)

In Eq. (4) Ladv represents adversarial loss, Lgen represents generator loss, and Ldisc
represents discriminator loss.
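As a numerical illustration of Eqs. (2)-(4), the sketch below evaluates the WGAN-GP terms for a toy linear critic D(x) = w · x, whose input gradient is simply w, so the gradient penalty can be computed in closed form without autograd. The critic, the data, and the batch size are illustrative assumptions; the sign conventions follow the equations as written above.

```python
import numpy as np

rng = np.random.default_rng(0)
lambda_gp = 10.0  # gradient-penalty weight, set to 10 in the paper

def critic(x, w):
    """Toy linear critic D(x) = w . x, whose input gradient is w everywhere."""
    return x @ w

def gradient_penalty(x_real, x_fake, w):
    """lambda_gp * E[(||grad D(x_hat)||_2 - 1)^2] over interpolations x_hat."""
    eps = rng.uniform(size=(x_real.shape[0], 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake   # uniform interpolation x_hat
    grads = np.broadcast_to(w, x_hat.shape)       # analytic gradient of the linear critic
    norms = np.linalg.norm(grads, axis=1)
    return lambda_gp * np.mean((norms - 1.0) ** 2)

x_real = rng.normal(size=(16, 4))
x_fake = rng.normal(size=(16, 4))
w = np.array([1.0, 0.0, 0.0, 0.0])  # unit-norm weights: the penalty is exactly zero

l_gen = -np.mean(critic(x_fake, w))                                   # Eq. (2)
l_disc = (np.mean(critic(x_real, w)) - np.mean(critic(x_fake, w))
          + gradient_penalty(x_real, x_fake, w))                      # Eq. (3)
l_adv = l_gen + l_disc                                                # Eq. (4)
```

The penalty pulls the critic's gradient norm toward 1, which is how WGAN-GP enforces the Lipschitz constraint without weight clipping.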

3.3 Image Inpainting Module

The main task of the image inpainting module is to apply the mask predicted by the
mask prediction module to the damaged area and inpaint it. However, directly extracting
features from a damaged image may yield incomplete or invalid information. To extract
image features effectively, a gated residual connection method is adopted at this stage.
By introducing gated convolution into the residual blocks [12], the image inpainting
module can output more effective multi-scale feature information, consequently enhanc-
ing the accuracy of image inpainting. In order to obtain broader contextual information
in the image, some gated residual blocks are added with dilation factors to form dilated
gated residual blocks.
The gated convolutional structure is shown in Fig. 5.

Fig. 5. Gated convolutional structure diagram.

Traditional convolutions [11, 20] treat each pixel as valid for calculation, which
works well for classification and detection but fails in repair tasks, where distinguishing
between valid and invalid pixels is crucial. Gated convolution addresses this by using
learnable soft mask updates, allowing dynamic feature selection for each channel and
spatial position. This mechanism enables more flexible feature processing, improving
repair results. The formula for gated convolution is:

Gating_{y,x} = Wg · I (5)

Feature_{y,x} = Wf · I (6)

O_{y,x} = Φ(Feature_{y,x}) ⊙ σ(Gating_{y,x}) (7)

Among them, I is the input feature map and σ is the sigmoid function, so the output
gate values lie between 0 and 1. Φ is the activation function, which can be ReLU, Tanh,
etc. Wg and Wf are two different convolutional filters, and O_{y,x} is obtained by applying
the two convolutions to I and multiplying the resulting feature maps pixel by pixel.
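The gating of Eqs. (5)-(7) can be sketched as follows. For brevity the two filters Wg and Wf are 1 × 1 convolutions (pure channel mixing) and Φ is tanh; these are illustrative assumptions, not the paper's exact layer configuration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_conv1x1(I, Wf, Wg, act=np.tanh):
    """O = act(Wf . I) * sigmoid(Wg . I), Eqs. (5)-(7); I is (C_in, H, W)."""
    feature = np.tensordot(Wf, I, axes=([1], [0]))  # Feature_{y,x} = Wf . I
    gating = np.tensordot(Wg, I, axes=([1], [0]))   # Gating_{y,x}  = Wg . I
    return act(feature) * sigmoid(gating)           # learnable soft mask in (0, 1)

# A strongly negative gate drives sigmoid toward 0, so the layer can
# suppress invalid pixels regardless of the feature branch.
I = np.abs(np.random.default_rng(1).normal(size=(3, 4, 4))) + 0.1
Wf = np.ones((2, 3))
Wg = -100.0 * np.ones((2, 3))
O = gated_conv1x1(I, Wf, Wg)
```

Because the gate is learned per channel and per spatial position, it generalizes the hard valid/invalid distinction of partial convolutions into a soft, differentiable selection.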

3.4 Loss Function

During model training, five loss functions are used in this paper to make the generated
image closer to the real image: pixel-by-pixel reconstruction loss [21] Lrec , style loss
[22] Lstyle , perceptual loss [23] Lperc , ID-MRF loss [6] Lmrf , and adversarial loss [11]
Ladv . They are explained below.
Pixel-by-pixel reconstruction loss measures the difference between the generated
image and the real image; its goal is to minimize this difference so that the generated
image is as close to the real data as possible.
The damaged-area loss [11] calculates the difference between the generated and real
images within the damaged area. This loss is computed by applying a binary mask M
to the damaged area.
 
Lhole = ||M ⊙ (Iout − Igt )||1 (8)

Among them, Lhole represents the loss over the damaged area, Lvalid (defined below)
the loss over the undamaged area, Iout the image generated by the model, and Igt the
real target image.
The undamaged-area loss calculates the difference between the generated image and
the real image in the undamaged region. This part of the loss is computed by applying
the complementary binary mask 1 − M.
 
Lvalid = ||(1 − M) ⊙ (Iout − Igt )||1 (9)

The pixel-by-pixel reconstruction loss Lrec is the sum of these two losses:

Lrec = Lhole + Lvalid (10)
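A minimal sketch of Eqs. (8)-(10), treating ||·||1 as an element-wise L1 sum over the image; whether the loss is summed or averaged over pixels is an implementation detail not specified here.

```python
import numpy as np

def reconstruction_loss(i_out, i_gt, mask):
    """L_rec = L_hole + L_valid with a binary mask M (1 = damaged pixel)."""
    diff = i_out - i_gt
    l_hole = np.sum(np.abs(mask * diff))          # Eq. (8): damaged area
    l_valid = np.sum(np.abs((1 - mask) * diff))   # Eq. (9): undamaged area
    return l_hole + l_valid                       # Eq. (10)

i_gt = np.zeros((2, 2))
i_out = np.ones((2, 2))                # off by 1 at every pixel
mask = np.array([[1.0, 0.0],
                 [0.0, 1.0]])          # two damaged pixels
loss = reconstruction_loss(i_out, i_gt, mask)
```

Splitting the L1 term by mask lets the two regions be weighted differently if desired, although Eq. (10) simply adds them.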

The concept of style loss originates from the task of artistic style transfer, which
aims to make the generated image consistent in style with the real image.
Gij = (1/C) Σk φ(G)ik φ(G)jk (11)

Rij = (1/C) Σk φ(R)ik φ(R)jk (12)
324 Y. Pan and X. Zhang

where Gij and Rij represent the Gram matrices of the generated and real images,
respectively; i and j index the rows and columns of the Gram matrix, and C represents
the number of channels in the feature map. φ(·) represents the feature representation of
the selected intermediate convolutional layer, k indexes the positions of the flattened
feature map, and N represents the total number of pixels in that feature map. The formula
for style loss is:

Lstyle = (1/N²) Σi,j |Gij − Rij | (13)

The goal of perceptual loss is to measure the perceptual quality of generated images
by comparing the differences in feature representations between generated and real
images in some intermediate layers. The perceptual loss formula is:
Lperc = Σi (1/Ni ) ||φ(G)i − φ(R)i ||1 (14)

Among them, φ(G)i and φ(R)i represent the feature maps of the VGG19 network's
relu3_2 and relu4_2 layers for the generated and real images, respectively, and Ni
represents the total number of pixels on the i-th layer feature map.
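Eqs. (11)-(14) can be sketched as below, with feature maps flattened to (C, N) arrays; using random arrays in place of real VGG19 activations is an illustrative assumption.

```python
import numpy as np

def gram(F):
    """Gram matrix of a (C, N) feature map: G_ij = (1/C) sum_k F_ik F_jk."""
    C = F.shape[0]
    return (F @ F.T) / C                               # Eqs. (11)-(12)

def style_loss(Fg, Fr):
    """(1/N^2) * L1 distance between Gram matrices, Eq. (13)."""
    N = Fg.shape[1]                                    # pixels in the feature map
    return np.sum(np.abs(gram(Fg) - gram(Fr))) / N**2

def perceptual_loss(feats_g, feats_r):
    """Sum of per-layer mean absolute feature differences, Eq. (14)."""
    return sum(np.sum(np.abs(g - r)) / g.size
               for g, r in zip(feats_g, feats_r))

rng = np.random.default_rng(2)
Fg = rng.normal(size=(8, 16))   # stand-in for a VGG19 feature map (C=8, N=16)
Fr = rng.normal(size=(8, 16))
```

The Gram matrix discards spatial layout and keeps channel co-activations, which is why matching it transfers texture/style rather than content.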
The purpose of ID-MRF loss is to enhance the generation model’s ability to main-
tain structure during image generation, making the generated images more natural and
realistic.
Lmrf = LM (conv4_2) + Σ_{t=3}^{4} LM (convt_2) (15)

The overall loss function of the mask prediction module is:

Ltotal = λr1 Lrec + λa Ladv (16)

The overall loss function of the image inpainting module is:

Ltotal = λr2 Lrec + λs Lstyle + λp Lperc + λm Lmrf + λa Ladv (17)

4 Experimental Simulation and Analysis


4.1 Test Environment
The model described in this paper was implemented using PyTorch (version 1.13.0),
CUDA (version 11.6), and cuDNN (version 8302), running on a Windows 11 server with
an NVIDIA GeForce RTX 3080 GPU. The datasets used in the experiments are Places2
[13] and CelebA-HQ [28]. Places2 contains over 10 million images covering more than
400 unique scene categories; to construct the training and test sets, 30,000 images were
randomly selected for training and 500 for testing. CelebA-HQ contains 30,000 images
at a resolution of 1024 × 1024; 500 images were randomly selected as the test set, while
the remaining images were used for training. To ensure a consistent input resolution,
the CelebA-HQ images were downsampled to 256 × 256. To ensure fairness, all
experiments were conducted using the same training and test sets. During the experiments,
we used the Adam optimizer with a learning rate of 1e-4 and a training period of
T = 800000, with λr1 = λp = 0.5, λa = λm = 1e-3, λr2 = 1.4, and λs = 1e-4.

4.2 Results

To assess the efficacy of the proposed algorithm, CE [11], LBAM [30], Vcnet [13],
and TransCNN-HAEwCDP [16] were selected for comparison with the method proposed
in this paper. To ensure fairness, both our experiments and the comparative experiments
used the same training and test sets. Notably, Vcnet, TransCNN-HAEwCDP, and the
proposed method are blind inpainting algorithms, while CE and LBAM are non-blind.
CE fills damaged areas based on predefined rules without requiring masks. For a fair
comparison, the mask predicted by our mask prediction module is used as the input
mask for LBAM. The repair results of all five methods are shown in Fig. 6.
The performance of an image inpainting algorithm is directly reflected in the quality of
the inpainted result. Three evaluation metrics, Mean Absolute Error (MAE) [13], Peak
Signal-to-Noise Ratio (PSNR) [30], and Structural Similarity Index Measure (SSIM) [6],
were employed to evaluate the quality of the repaired images. The test results are
presented in Tables 1 and 2.

(a) (b) (c) (d) (e) (f) (g)


Fig. 6. Test results of CelebA-HQ datasets and Places2 datasets. (a) damaged images; (b) repair
results of CE; (c) repair results of LBAM; (d) repair results of Vcnet; (e) repair results of
TransCNN-HAEwCDP; (f) repair results of ours; (g) real images.

From Fig. 6, it can be seen that for images with relatively simple structures, all five
methods can generate reasonable content in the damaged area. For images with complex
textures, although some shortcomings remain in the repair of damaged areas, the
algorithm proposed in this paper achieves significant improvements in image texture
and detail compared with the other four methods. Our predict-then-repair approach
reconstructs lost information more effectively and produces more natural, coherent
results. The CE method uses reconstruction loss and adversarial loss to repair regular
regions, but it captures contextual semantic information poorly, so it cannot reconstruct
appropriate semantics or perform effective repairs. LBAM adopts a learnable attention
map module to learn feature re-normalization and mask updating in an end-to-end
manner; compared with CE, it can repair irregular damaged areas effectively, but
noticeable edge responses and color differences remain at the boundaries of damaged
pixels. Vcnet achieves good results through its blind inpainting network, but some details
are missing and visual artifacts persist. TransCNN-HAEwCDP adopts a single-stage
Transformer-CNN hybrid autoencoder with cross-layer prompting, using local context to
generate plausible, realistic content for damaged areas without a GAN; however, blurring
and artifacts can occur at the edges of complex regions. These observations indicate that
the method proposed in this paper has clear advantages when repairing images with
complex textures.

Table 1. Comparison Results of the Average Values on the CelebA-HQ Test Set

Method   CE      LBAM    Vcnet   TransCNN-HAEwCDP  ours
MAE(%)   10.229  8.671   3.026   1.485             1.340
PSNR     20.491  23.653  25.056  26.098            26.142
SSIM     0.677   0.710   0.880   0.901             0.901

Table 2. Comparison Results of the Average Values on the Places2 Test Set

Method   CE      LBAM    Vcnet   TransCNN-HAEwCDP  ours
MAE(%)   11.990  9.812   4.092   2.014             1.935
PSNR     19.728  22.419  24.368  26.047            26.050
SSIM     0.585   0.688   0.688   0.857             0.894

A smaller MAE value indicates a smaller difference between the predicted and true
values, suggesting a better model fit. A larger PSNR value indicates less distortion
between the repaired image and the real image, i.e., better image quality, and a larger
SSIM value implies a higher degree of structural similarity between them. Based on the
results in Tables 1 and 2, our method demonstrates superior performance in terms of
MAE, PSNR, and SSIM compared to the other four methods on both the CelebA-HQ
and Places2 datasets.
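For reference, MAE and PSNR can be computed as below (SSIM requires a windowed computation, e.g. `skimage.metrics.structural_similarity`, and is omitted). The 8-bit peak value of 255 and the use of MAE as a percentage of that peak are assumptions about the paper's evaluation setup.

```python
import numpy as np

def mae_percent(i_out, i_gt, max_val=255.0):
    """Mean absolute error as a percentage of the peak value."""
    return np.mean(np.abs(i_out - i_gt)) / max_val * 100.0

def psnr(i_out, i_gt, max_val=255.0):
    """Peak signal-to-noise ratio in dB; larger means less distortion."""
    mse = np.mean((i_out - i_gt) ** 2)
    return 10.0 * np.log10(max_val**2 / mse)

i_gt = np.zeros((4, 4))
i_out = np.full((4, 4), 10.0)   # uniform error of 10 gray levels
m = mae_percent(i_out, i_gt)    # 10/255 * 100, about 3.92 %
p = psnr(i_out, i_gt)           # 10*log10(255^2 / 100), about 28.13 dB
```

Note that PSNR diverges when the images are identical (MSE = 0), so implementations usually guard that case.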

4.3 Ablation Experiments

A discriminator is introduced in the mask prediction module to enhance the accuracy
of the predicted masks; at the same time, gated convolution is introduced into the
residual blocks of the image inpainting module to improve repair quality. To further
explore the impact of these two factors on algorithm performance, two sets of
comparative experiments were conducted. Specific tests include:
1) Whether a discriminator is introduced in the mask prediction module
The performance of the two configurations, with and without a discriminator, is
examined to further investigate the impact of the discriminator on the algorithm.

Fig. 7. Mask prediction results. (a) real masks; (b) masks predicted with the discriminator;
(c) masks predicted without the discriminator.

From Fig. 7, it can be seen that the accuracy of mask prediction is significantly
improved by introducing the WGAN-GP discriminator in the mask prediction module:
the mask contours predicted with the discriminator are clearer than those predicted
without it.
2) Whether gated convolution is introduced in the residual blocks
Gated convolution, as an improvement on traditional convolution, adopts learnable
soft-mask update rules that allow dynamic feature selection mechanisms to be learned
for each channel and spatial position, making the network more flexible in handling
irregular image inpainting tasks (Fig. 8).

Fig. 8. Inpainting results with and without gated convolution. (a) damaged images;
(b) repaired images without gated convolution; (c) repaired images with gated convolution.

4.4 Discussion
The data in Fig. 8 are from the Places2 and CelebA-HQ datasets. The results in
Fig. 8 show that in some complex areas, the network without gated convolution suffers
from limited feature extraction, edge artifacts, and blurred repair results when dealing
with irregular regions. In contrast, the network with gated convolution captures the
features of irregular regions more effectively, producing more realistic repaired images
and significantly improving the quality of the results.
From the results of the two comparative experiments mentioned above, it can be
concluded that the algorithm proposed in this paper can generate images with more
realistic content and better visual effects. This indicates that in the process of image
generation, the algorithm can effectively capture subtle details and features, making the
repaired results more realistic.

4.5 Time Complexity and Space Complexity


Time complexity is mainly related to the number of operations in the network and the
size of the input image. The time complexity of a single convolution operation is:

Oconv = H ∗ W ∗ K² ∗ Cin ∗ Cout (18)

Among them, H and W represent the height and width of the input feature map, K
represents the size of the convolution kernel, Cin is the number of input channels, and
Cout is the number of output channels.
A residual block contains two convolutions Oconv , one activation function OELU , one
batch normalization OBatchNorm , and one skip connection OProjection . The mask prediction
module contains 7 residual blocks and one convolution block, the image inpainting
module contains 14 residual blocks, and the discriminator contains 5 convolution blocks.
The total time complexity is:

OELU = OBatchNorm = OProjection = H ∗ W ∗ Cout (19)

OResBlock = 2 ∗ Oconv + OProjection + OBatchNorm + OELU (20)

Ototal = 7 ∗ OResBlock + Oconv + 14 ∗ OResBlock + 5 ∗ Oconv (21)

The spatial complexity is related to the number of parameters in the model and the
storage requirements for intermediate feature maps during runtime.

Pconv = K² ∗ Cin ∗ Cout + Cout (22)

PProjection = Cin ∗ Cout + Cout (23)

PResBlock = 2 ∗ Pconv + PProjection (24)

Among them, Pconv represents the number of parameters of a convolution, PProjection
the number of parameters of the skip connection, and PResBlock the number of parameters
of a residual block.
The total parameter count of the modules is:

Ptotal = 7 ∗ PResBlock + Pconv + 14 ∗ PResBlock + 5 ∗ Pconv (25)
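Eqs. (18) and (22)-(25) are simple closed-form counts and can be checked directly. The sketch below assumes, as the equations implicitly do, that both convolutions in a residual block share the same kernel size and channel counts; the example channel numbers are illustrative.

```python
def conv_ops(H, W, K, c_in, c_out):
    """Multiply-accumulate count of one convolution, Eq. (18)."""
    return H * W * K**2 * c_in * c_out

def conv_params(K, c_in, c_out):
    """Weights plus biases of one convolution, Eq. (22)."""
    return K**2 * c_in * c_out + c_out

def projection_params(c_in, c_out):
    """1x1 skip-connection projection, Eq. (23)."""
    return c_in * c_out + c_out

def resblock_params(K, c_in, c_out):
    """Two convolutions plus the projection, Eq. (24)."""
    return 2 * conv_params(K, c_in, c_out) + projection_params(c_in, c_out)

def total_params(K, c_in, c_out):
    """7 + 14 residual blocks and 1 + 5 convolution blocks, Eq. (25)."""
    return 21 * resblock_params(K, c_in, c_out) + 6 * conv_params(K, c_in, c_out)

# e.g. a 3x3 convolution with 64 -> 128 channels:
n = conv_params(3, 64, 128)   # 9*64*128 + 128 = 73856
```

These counts match the parameter totals reported by frameworks such as PyTorch for a biased `Conv2d`, which makes them a quick sanity check on a model definition.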

The total storage requirements include input, output, and intermediate feature maps.
The storage space for the total feature map of the residual block is:

MResBlock = 3 ∗ H ∗ W ∗ Cout (26)

The total feature map storage space of the convolutional block is:

Mconv = 2 ∗ H ∗ W ∗ Cout (27)

Total feature map storage space:

Mtotal = 7 ∗ MResBlock + Mconv + 14 ∗ MResBlock + 5 ∗ Mconv (28)

The overall spatial complexity is:

Ospace = O(Ptotal ) + O(Mtotal ) (29)

5 Conclusions

This paper proposes a two-stage blind repair algorithm using gated residual connec-
tions, consisting of a mask prediction module and an image inpainting module. The
damaged image first enters the mask prediction module, where a discriminator helps
improve mask accuracy. The predicted mask and damaged image are then processed by
the inpainting module, which replaces traditional convolutions with gated convolutions
to better utilize contextual information and generate realistic repairs. Experiments show
that the method predicts realistic masks without prior knowledge and produces visu-
ally plausible repaired images. The future work will explore techniques for restoring
high-resolution images. Additionally, considering the demands of practical application
scenarios, attention should also be paid to the computational efficiency and real-time
performance of the model.

Acknowledgments. This work was supported by Natural Science Basic Research Program of
Shaanxi (Program No. 2021JQ-722).

Disclosure of Interests. The authors declare that they have no known competing financial inter-
ests or personal relationships that could have appeared to influence the work reported in the
paper.

References
1. Criminisi, A., Pérez, P., Toyama, K.: Region filling and object removal by exemplar-based
image inpainting. IEEE Trans. Image Process. 13(9), 1200–1212 (2004). https://doi.org/10.
1109/TIP.2004.833105
2. Barnes, C., Shechtman, E., Finkelstein, A., et al.: PatchMatch: a randomized correspondence
algorithm for structural image editing. ACM Trans. Graph. 28(3), 24 (2009). https://doi.org/
10.1145/3596711.3596777
3. Bertalmio, M., Sapiro, G., Caselles, V., et al.: Image inpainting. In: Proceedings of the 27th
Annual Conference on Computer Graphics and Interactive Techniques, pp. 417–424 (2000).
https://doi.org/10.1145/344779.344972
4. Chan, T.F., Shen, J.: Nontexture inpainting by curvature-driven diffusions. J. Vis. Commun.
Image Represent. 12(4), 436–449 (2001). https://doi.org/10.1006/jvci.2001.0487
5. Hui, Z., Li, J., Wang, X., et al.: Image fine-grained inpainting (2020). arXiv preprint
arXiv:2002.02609. https://doi.org/10.48550/arXiv.2002.02609
6. Wang, Y., Tao, X., Qi, X., et al.: Image inpainting via generative multi-column convolutional
neural networks. Adv. Neural Inf. Process. Syst. 31 (2018). https://doi.org/10.48550/arXiv.
1810.08771
7. Yildirim, A.B., Pehlivan, H., Bilecen, B.B., et al.: Diverse inpainting and editing with gan
inversion. In: Proceedings of the IEEE/CVF International Conference on Computer Vision,
pp. 23120–23130 (2023). https://doi.org/10.48550/arXiv.2307.15033
8. Xiao, Q., Li, G., Chen, Q.: Deep inception generative network for cognitive image inpainting
(2018). arXiv preprint arXiv:1812.01458. https://doi.org/10.48550/arXiv.1812.01458
9. Zhao, L.L., Shen, L., Hong, R.C.: A review of research progress in image inpainting. Comput.
Sci. 48(03), 14–26 (2021). (in Chinese). https://doi.org/10.11896/jsjkx.210100048
10. Goodfellow, I., et al.: Generative adversarial nets. In: Advances in Neural Information Pro-
cessing Systems (NeurIPS), pp. 2672–2680 (2014). https://doi.org/10.3156/JSOFT.29.5_177_2
11. Pathak, D., Krahenbuhl, P., Donahue, J., et al.: Context encoders: feature learning by inpaint-
ing. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 2536–2544 (2016). https://doi.org/10.48550/arXiv.1604.07379
12. Yu, J., Lin, Z., Yang, J., et al.: Free-form image inpainting with gated convolution. In: Pro-
ceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4471–4480
(2019). https://doi.org/10.48550/arXiv.1806.03589
13. Wang, Y., Chen, Y.C., Tao, X., Jia, J.: VCNet: a robust approach to blind image inpainting. In:
Vedaldi, A., Bischof, H., Brox, T., Frahm, J.M. (eds.) Computer Vision – ECCV 2020. ECCV
2020. LNCS, vol. 12370, pp. 752–768. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58595-2_45
14. Li, W.J., Zhou, Z.P., et al.: Research on blind image inpainting method based on context
semantic hierarchical reasoning. J. Nanchang Hangkong Univ. (Natl. Sci.) 36(02), 36–43
(2022). (in Chinese). https://doi.org/10.3969/j.issn.2096-8566.2022.02.006
15. Liu, D.B., Wang, H.Q., Wang, K., et al.: Blind inpainting method for incomplete sparse text
images based on content style transfer. Laser Optoelectron. Progress 59(24), 106–117 (2022).
(in Chinese). https://doi.org/10.3788/LOP202259.2411001
16. Zhao, H., Gu, Z., Zheng, B., et al.: Transcnn-hae: transformer-cnn hybrid autoencoder for blind
image inpainting. In: Proceedings of the 30th ACM International Conference on Multimedia,
pp. 6813–6821 (2022). https://doi.org/10.1145/3503161.3547848
17. Li, X., Wang, Z., Chen, C., et al.: SemID: blind image inpainting with semantic inconsistency
detection. Tsinghua Sci. Technol. 29(4), 1053–1068 (2024). https://doi.org/10.26599/TST.
2023.9010079

18. Wang, J., Yuan, C., Li, B., et al.: Self-prior guided pixel adversarial networks for blind image
inpainting. IEEE Trans. Pattern Anal. Mach. Intell. (2023). https://doi.org/10.1109/TPAMI.
2023.3284431
19. Bai, Y., He, R., Tan, W., et al.: Fine-grained blind face inpainting with 3D face compo-
nent disentanglement. In: ICASSP 2023–2023 IEEE International Conference on Acoustics,
Speech and Signal Processing (ICASSP). IEEE, pp. 1–5 (2023). https://doi.org/10.1109/ICA
SSP49357.2023.10097082
20. Li, C.Y., Lin, Y.Y., Chiu, W.C.: Decontamination transformer for blind image inpainting.
In: ICASSP 2023–2023 IEEE International Conference on Acoustics, Speech and Signal
Processing (ICASSP). IEEE, pp. 1–5 (2023). https://doi.org/10.1109/ICASSP49357.2023.
10094950
21. Schmalfuss, J., Scheurer, E., Zhao, H., et al.: Blind image inpainting with sparse directional
filter dictionaries for lightweight CNNs. J. Math. Imaging Vis. 65(2), 323–339 (2023). https://
doi.org/10.1007/s10851-022-01119-6
22. Phutke, S.S., Kulkarni, A., Vipparthi, S.K., et al.: Blind image inpainting via omni-
dimensional gated attention and wavelet queries. In: Proceedings of the IEEE/CVF Con-
ference on Computer Vision and Pattern Recognition, pp. 1251–1260 (2023). https://doi.org/
10.1109/CVPRW59228.2023.00132
23. Gao, M., Kang, B.: Blind image inpainting using low-dimensional manifold regularization.
J. Circuits Syst. Comput. 31(12), 2250211 (2022). https://doi.org/10.1142/S02181266225
02115
24. Kim, H., Kim, C.I., Kim, H., et al.: Panoptic blind image inpainting. ISA Trans. 132, 208–221
(2023). https://doi.org/10.1016/j.isatra.2022.10.030
25. Gao, M., Kang, B., Feng, X., et al.: Low dimensional manifold regularization based blind
image inpainting and non-uniform impulse noise recovery. IEEE Access 8, 200551–200560
(2020). https://doi.org/10.1109/ACCESS.2020.3035532
26. He, S., Peng, X., Yuan, Z., et al.: Contour-context joint blind image inpainting network
for molecular sieve particle size measurement of SEM images. IEEE Trans. Instrum. Meas.
(2023). https://doi.org/10.1109/tim.2023.3279451
27. Russakovsky, O., Deng, J., Su, H., et al.: Imagenet large scale visual recognition challenge.
Int. J. Comput. Vis. 115, 211–252 (2015). https://doi.org/10.1007/s11263-015-0816-y
28. Pehlivan, H., Dalva, Y., Dundar, A.: Styleres: transforming the residuals for real image editing
with stylegan. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pp. 1828–1837 (2023). https://doi.org/10.48550/arXiv.2212.14359
29. Gulrajani, I., Ahmed, F., Arjovsky, M., et al.: Improved training of wasserstein gans. Adv.
Neural Inf. Process. Syst. 30 (2017). https://doi.org/10.48550/arXiv.1704.00028
30. Xie, C., Liu, S., Li, C., et al.: Image inpainting with learnable bidirectional attention maps. In:
Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8858–8867
(2019). https://doi.org/10.48550/arXiv.2104.12087
GAN-Based Adaptive Trigger Generation
and Target Gradient Alignment in Vertical
Federated Learning Backdoor Attacks

Kun Li, Hongyang Yan(B) , Jiatong Lin, Fan Chen, Yu Cheng, and Dongyang Liang

School of Artificial Intelligence, Guangzhou University, Guangzhou, China


[email protected], {linjiatong,2112306111,32206500008}@e.gzhu.edu.cn

Abstract. Vertical Federated Learning (VFL) allows multiple parties


holding the same samples but different attributes to collaboratively train
models without directly sharing raw data. However, a passive party lack-
ing label information can still launch backdoor attacks, posing significant,
yet underexplored, security threats. This paper proposes an Adaptive
Federated Backdoor Framework (AFBF) for VFL, integrating dynamic
trigger generation and efficient gradient alignment. We introduce an
Adaptive Trigger Generation Network (ATGN), a GAN-based module
trained jointly with VFL to dynamically produce triggers, enhancing
stealthiness and flexibility, especially for multi-class tasks. Building on
label-replacement methods, we further propose Gradient-Feature Corre-
lation Attack (GFCA) to align poisoned features and target gradients,
achieving high attack success rates without altering labels. Extensive
experiments on multiple datasets show that AFBF, combining ATGN
and GFCA, achieves nearly 100% success even with very few target sam-
ples, outperforming existing mainstream methods. Our findings highlight
backdoor risks in VFL and underscore the urgency for robust defenses
in sensitive domains like finance and healthcare.

Keywords: Vertical Federated Learning · Backdoor Attack · Dynamic Trigger ·
Gradient-Feature Correlation Attack · Security

1 Introduction
With the increasing demand for data privacy protection and multi-party col-
laboration, Federated Learning (FL) has become a research hotspot in the field
of machine learning [1]. Among them, Vertical Federated Learning (VFL), as
an important federated learning paradigm, allows multiple participants with the
same users but different features to collaboratively train a global model without
sharing raw data [2–7]. VFL has broad application potential in data-sensitive
fields such as financial risk control [8, 9] and medical diagnosis [10, 11].
Horizontal Federated Learning (HFL) and Vertical Federated Learning
(VFL) differ in data distribution: HFL participants share the same feature space
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 332–346, 2025.
https://doi.org/10.1007/978-981-96-4566-4_23

but hold different data samples, while VFL participants share the same data
samples but hold different feature subsets. Due to this structure, VFL is more
vulnerable to backdoor attacks. Attackers can exploit feature interactions with-
out label information, enabling subtle malicious behaviors that influence the
global model. Hence, traditional HFL backdoor methods do not directly apply
to VFL, demanding specialized strategies.
Significant progress has been made in HFL backdoor attacks [12–14], but
research on VFL backdoors, especially those launched by passive parties with-
out label access, remains limited [15–18]. In VFL, passive parties only hold par-
tial features, cannot modify labels, and have limited access to others’ data and
models, making attacks more challenging. Figure 1 illustrates a VFL training
architecture. Exploring how passive parties can implement efficient backdoor
attacks under these constraints is crucial.

Fig. 1. An example of the backdoor attack to VFL.

Most existing VFL backdoor approaches rely on label inference, tampering,


or large auxiliary datasets, which is often impractical. Moreover, fixed triggers
allow defenders to repair compromised models by extracting these triggers from
a single backdoor input and fine-tuning the model [19]. Thus, a stealthy back-
door method with minimal label information and few target samples is urgently
needed.
In response, we propose the Adaptive Federated Backdoor Framework
(AFBF), comprising an Adaptive Trigger Generation Network (ATGN) and a
Gradient-Feature Correlation Attack (GFCA). Their combination yields a flexible,
efficient, and practical backdoor attack strategy for passive VFL participants.
Specifically, ATGN, based on a Generative Adversarial Network (GAN),
dynamically generates triggers from random noise, adapting to model changes
and enhancing attack stealth. This dynamic approach avoids the limitations of
fixed triggers and suits complex, multi-class scenarios.
Meanwhile, GFCA replaces poisoned samples’ embedding features and gra-
dients with those of the target class, amplifying attack effectiveness without
altering labels. Together, ATGN and GFCA significantly improve the success
and subtlety of VFL backdoor attacks under constrained conditions.
We extensively evaluated our method on widely used public datasets (e.g.,
CIFAR-10 and CIFAR-100). The AFBF framework, integrating ATGN and
GFCA, achieved nearly 100% attack success in complex multi-class tasks using
as few as 50 target samples, with minimal impact on the model’s primary per-
formance. Compared to existing mainstream backdoor attack methods, AFBF
significantly improves both success rate and practicality.
The main contributions of this paper are summarised as follows:
– We propose an Adaptive Federated Backdoor Framework (AFBF) for VFL
that integrates dynamic trigger generation with target gradient alignment.
This flexible, efficient approach avoids label inference attacks and requires
minimal target label information, significantly enhancing VFL backdoor
attack practicality.
– Under the AFBF framework, we developed an Adaptive Trigger Generation
Network (ATGN) that uses a GAN structure to dynamically produce triggers
from random noise. Optimized alongside VFL training, ATGN significantly
enhances the flexibility and stealthiness of backdoor attacks, particularly in
multi-classification tasks.
– On multiple widely used datasets, AFBF (combining ATGN and GFCA)
achieves nearly 100% success in complex multi-class tasks, even with very few
target samples, surpassing existing mainstream backdoor attack methods.

2 Related Work
As Federated Learning (FL) continues to evolve, its security has garnered increas-
ing attention. Backdoor attacks, in particular, have emerged as a critical threat
and drawn widespread interest. However, in the realm of Vertical Federated
Learning (VFL), research on backdoor attacks is still in the early stages. Exist-
ing approaches often depend heavily on label inference or substantial auxiliary
information, which imposes significant limitations.

2.1 Backdoor Attacks in VFL


Currently, backdoor attack methods in VFL can be broadly categorized into
those occurring during the training phase and those emerging in the inference
phase. In training-phase attacks, adversaries embed malicious triggers directly
into the model as it learns, whereas inference-phase attacks involve crafting
adversarial inputs or modifications once the model is already deployed.
In training-phase attack methods, Zou et al. [15] proposed the Label Replacement Backdoor (GR) method, requiring a small number of target samples from
the training set. During VFL training, the attacker swaps local embedding fea-
tures and returns gradients of these target samples with those of poisoned data.
However, this method demands label information for each batch, making it less
practical as the dataset grows. Naseri et al. [16] introduced BadVFL, assuming
the attacker has limited labels from each class. The attacker performs semi-
supervised label inference and then implants subtle triggers. Yet, this approach
heavily relies on label inference and prior class knowledge, reducing its prac-
ticality. Xuan et al. [20] proposed a gradient-based label inference attack. The
attacker infers class information from gradient data, then replaces local inputs
and adds triggers. However, this depends on accurate gradient-based inference,
limiting its applicability in complex multi-class settings. Chen et al. [18] presented TECB, but its trigger update relies on a fixed optimization strategy and uses random noise replacement, which can destabilize training. He et al. [21] added triggers directly to embedding features, requiring many known target labels and thus diminishing stealthiness.
In inference-phase attack strategies, LR-BA [17] infers pseudo-labels for training data and then refines the model with backdoor representations. Yet, it requires labeled samples from each class, which is impractical for large-scale VFL. TPGD [22] generates adversarial perturbations during the inference phase by modifying the attacker's local embedding features. However, it assumes the attacker can alter labels in the active party, which is unrealistic in real-world VFL environments.

2.2 Dynamic Backdoor Attacks


Backdoor attack techniques are rapidly advancing, especially in dynamic back-
door attacks, where triggers adapt to different samples. Such dynamic triggers
enhance stealthiness and increase attack success rates by adjusting to various
data characteristics.
In traditional backdoor attacks, researchers have explored dynamic, sample-
specific triggers. Nguyen et al. [19] designed a trigger generation network that
tailors unique triggers for each sample, greatly improving stealthiness and effec-
tiveness. However, their approach modifies the training process and loss function,
posing practical challenges. Salem et al. [23] refined this concept by enabling dynamic trigger generation without altering the original training process, reducing deployment complexity. Li et al. [24] further advanced the field by proposing a trigger generation network that does not require joint training with the target model, simplifying implementation. They also employed image steganography, making triggers nearly invisible to humans and detectors, thereby enhancing stealthiness and practicality.
Despite these advances, current dynamic backdoor attacks often rely on data
label modifications or changes to the training process. In vertical federated learn-
ing, passive parties cannot alter labels or training procedures. This environment
restricts the use of existing dynamic backdoor methods, as participants only
control partial features and lack influence over the global training process. Con-
sequently, conventional dynamic backdoor strategies face significant challenges
in VFL scenarios.

3 Formulation and Threat Model


3.1 Problem Formulation
In a Vertical Federated Learning (VFL) framework, K participants (K ≥ 2) collaboratively train a classification model. Among them, K − 1 are passive parties, and one is the active party. Passive parties hold only subsets of the features, while the active party accesses both features and labels. For simplicity, assume the K-th participant is the active party. The training set is D = {(x_i, y_i)}_{i=1}^N, where each sample's feature vector x_i = {x_i^k}_{k=1}^K is distributed among the K participants, and the labels y_i ∈ {1, . . . , C} are held by the active party. VFL's objective is to use all participants' local data to train a global model.
During training, each participant k maintains a local model f_k with parameters θ_k, producing local output H_k = f_k(θ_k; x_k). The active party also holds a top model G with parameters θ_top that combines all local outputs to minimize the global loss L. Thus, the full parameter set is Θ = {θ_1, . . . , θ_K; θ_top}.
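To make this formulation concrete, the forward pass can be sketched as follows; the tiny linear "models", shapes, and variable names are illustrative stand-ins, not the ResNet-based architecture used later in the experiments:

```python
# Toy VFL forward pass with K = 2: each party embeds its local feature slice
# H_k = f_k(theta_k; x_k); the active party's top model G maps the concatenated
# embeddings to class scores. Linear maps stand in for the real networks.

def local_model(theta, x):
    # One linear layer: each row of theta produces one embedding dimension.
    return [sum(w * xi for w, xi in zip(row, x)) for row in theta]

def top_model(theta_top, embeddings):
    # G(H_1, ..., H_K; theta_top): concatenate embeddings, map to C logits.
    h = [v for emb in embeddings for v in emb]
    return [sum(w * hi for w, hi in zip(row, h)) for row in theta_top]

# One sample's 6 features, split between a passive and an active party.
x_passive, x_active = [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]
theta_1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]                # passive party, 2-dim H_1
theta_2 = [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]                # active party, 2-dim H_2
theta_top = [[1.0, 1.0, 1.0, 1.0], [1.0, -1.0, 1.0, -1.0]]  # C = 2 classes

H1 = local_model(theta_1, x_passive)
H2 = local_model(theta_2, x_active)
logits = top_model(theta_top, [H1, H2])
print(logits)  # one score per class; only the active party sees labels
```

Note that only the embeddings H_1 and H_2 cross party boundaries; raw features and labels stay local, which is exactly the constraint the attacker must work within.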
For VFL backdoor attacks, poisoned samples should be drawn from multiple
classes to increase stealth. However, as passive parties cannot manipulate labels,
traditional label-based attacks do not apply.
To address this, we propose a dynamic trigger generation mechanism using a
small number of target class samples. This method employs a Generative Adver-
sarial Network (GAN) to create and refine triggers that adapt to the VFL train-
ing process without modifying labels. It effectively improves attack success rates
in complex multi-class tasks and leverages VFL’s distributed nature for more
stealthy attacks.

3.2 Threat Model


In this work, we consider the scenario where the attacker operates as a passive
participant in the VFL system, with all other entities assumed to be trustworthy.
Since the active party directly controls label information and can easily perform
backdoor attacks, our threat model focuses on scenarios where the passive party
is the attacker.
Attacker’s Capabilities: The attacker follows the communication protocols of
VFL and does not tamper with or manipulate the data and information trans-
mission of other participants. The attacker can only participate in model training
through local features and receive gradient information returned by the active
party. The attacker does not possess specific knowledge about other participants’
data and models and cannot directly access label data.
Attacker’s Goals: In VFL settings involving multi-class classification, the
attacker’s objective is to introduce a backdoor during the training process. This
backdoor ensures that when the model is used for inference, any input containing
a specific trigger will be misclassified as a predetermined target class τ . At the
same time, the model’s accuracy on normal, trigger-free data remains unaffected.
Attacker’s Knowledge: In this type of attack, the attacker can carry out an
effective backdoor attack with access to only a minimal set of samples from
the target class. In our experiments, we utilized a limited number of target
class samples (e.g., only 50 target samples) and compromised the model by
dynamically creating triggers using the Adaptive Trigger Generation Network
(ATGN). Although this requirement for target samples differs from traditional
VFL scenarios, in practical applications, attackers can obtain a small number
of target samples through various means (e.g., data purchasing or accessing
public datasets). Additionally, the attacker has no further knowledge about other
participants’ data and models and can only rely on gradient information obtained
from the active party.

4 Methodology
This section details the dynamic trigger generation network method proposed for
backdoor attacks in Vertical Federated Learning (VFL). Figure 2 illustrates the
attack framework of the backdoor attack proposed in this paper. The method
comprises two core steps: Adaptive Trigger Generation Network (ATGN) and
Gradient-Feature Correlation Attack (GFCA). These two stages work collab-
oratively by generating dynamic triggers and performing feature and gradient
replacement to enhance the success rate of backdoor attacks.

Fig. 2. The framework of AFBF.

4.1 Adaptive Trigger Generation Network (ATGN)


In previous research [23], a trigger generation network was trained using the training loss, but updating on the training loss is not suitable for the VFL scenario. Inspired by this, we propose a neural network-based trigger generation mechanism applicable to VFL, called the Adaptive Trigger Generation Network (ATGN). Unlike traditional gradient optimisation methods, ATGN dynamically generates attack triggers δ by taking random noise as input and continuously optimises the parameters of the generation network through the received gradient information, achieving higher adaptability. ATGN is trained synchronously with the VFL model to ensure that trigger generation gradually adapts to the dynamic changes of the global model.
In each training round, ATGN first receives random noise z drawn from a normal distribution and generates a trigger δ through forward propagation:

δ = T(z; θ_TGN)    (1)

where T(z; θ_TGN) is the trigger generation network with parameters θ_TGN. The generated trigger δ is injected into the attacker's local target data, and the local model extracts embedding features to be sent to the active party for training.
During the VFL training process, the active party calculates the global loss L using the top model G(H_1, H_2, . . . , H_K; θ_top) based on the embedding features H_1, H_2, . . . , H_K transmitted by all participants. Subsequently, the active party computes the gradient of the loss with respect to each participant's embedding features, ∂L/∂H_i, and returns the corresponding gradient information to the passive parties. For the attacker, the received gradient ∂L/∂H_i^A contains the effect of the trigger δ on the local feature embedding H_i^A.
To extract the gradient related to the trigger, the attacker uses the chain rule to decompose the gradient of the global loss L with respect to the trigger δ into two parts: first, the gradient of the global loss with respect to the embedding features H_i^A, and then the gradient of the embedding features with respect to the trigger δ. This process can be expressed as:

∂L/∂δ = (1/M) Σ_{i=1}^{M} (∂L/∂H_i^A) · (∂H_i^A/∂δ)    (2)

where M represents the batch size, and ∂H_i^A/∂δ is the gradient of the embedding features H_i^A with respect to the trigger δ, which can be calculated through backpropagation of the attacker's local model f_A(x_i^A + δ; θ_A). Through this step, the attacker can extract the effect of the trigger δ on the global loss L.
Once the trigger-related gradient ∂L/∂δ is obtained, the attacker can use this information to update the parameters θ_TGN of the trigger generation network (ATGN). The parameter update is performed through backpropagation:

θ_TGN^(t+1) = θ_TGN^(t) − α · ∇_{θ_TGN} L    (3)

where α is the learning rate, and the gradient ∇_{θ_TGN} L is derived from the previously computed ∂L/∂δ, i.e., ∇_{θ_TGN} L = (∂L/∂δ) · (∂δ/∂θ_TGN). Through each training round, ATGN continuously optimizes its parameters so that the generated trigger δ can more effectively deceive the VFL model, causing samples with the trigger to be misclassified into the target class during the inference phase.
Throughout the training process, ATGN continuously receives gradient infor-
mation from the active party and combines this information for trigger genera-
tion and parameter optimisation. As training progresses, the generated trigger δ
gradually adapts to changes in the model, enhancing the stealthiness and effec-
tiveness of the attack.
This dynamic trigger generation method has stronger adaptability compared
to traditional static optimisation, allowing flexible adjustment of the trigger
during training, thereby improving the stealthiness and effectiveness of backdoor
attacks.
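As a worked illustration of Eqs. (1)–(3), one ATGN round can be reduced to scalars; the linear generator δ = θ·z and linear local model H = a·(x + δ) below are assumptions made purely for readability, not the paper's networks:

```python
# Scalar toy of one ATGN round: generate delta (Eq. 1), assemble dL/ddelta from
# the gradients the active party returns (Eq. 2), then update the generator
# parameters by the chain rule (Eq. 3). All values are illustrative.

alpha = 0.1        # learning rate
theta_tgn = 0.5    # generator parameter (a single scalar here)
a = 2.0            # local model slope: H_i = a * (x_i + delta), so dH_i/ddelta = a
z = 1.0            # noise input, fixed for determinism

delta = theta_tgn * z                   # Eq. (1): delta = T(z; theta_TGN)
x_local = [0.2, 0.4, 0.1]               # attacker's local features (batch M = 3)
H = [a * (x + delta) for x in x_local]  # embeddings sent to the active party

returned_grads = [0.3, -0.1, 0.2]       # dL/dH_i returned by the active party

# Eq. (2): dL/ddelta = (1/M) * sum_i (dL/dH_i) * (dH_i/ddelta)
dL_ddelta = sum(g * a for g in returned_grads) / len(returned_grads)

# Eq. (3): theta <- theta - alpha * dL/ddelta * ddelta/dtheta, with ddelta/dtheta = z
theta_tgn -= alpha * dL_ddelta * z
print(round(theta_tgn, 4))
```

The key point the sketch makes explicit is that the attacker never sees L or the labels directly: everything it needs arrives through the per-sample gradients ∂L/∂H_i that the protocol already delivers.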

4.2 Gradient-Feature Correlation Attack (GFCA)


In complex multi-classification tasks, relying solely on triggers generated from
a small number of target class samples may not guarantee a high success rate
for backdoor attacks. Therefore, inspired by the label replacement backdoor attack of Zou et al. [15], we propose the Gradient-Feature Correlation Attack (GFCA). By
performing feature and gradient replacement operations, GFCA gradually aligns
the features and gradients of poisoned data with those of the target class data.
When combined with ATGN, it achieves the same attack results as the origi-
nal label replacement attack that replaces 10% of known labels, using only the
existing labels.
In each training round, the attacker first selects a certain number of non-
target class samples from the current batch and generates poisoned data by
injecting the trigger δ. Next, the attacker performs feature replacement by replac-
ing the embedding features of the poisoned data with those of the target class
samples, ensuring that the features of both gradually align. This operation causes
the poisoned data to gradually behave like target class samples during training,
making them more likely to be misclassified into the target class during the
inference phase.
To further strengthen this alignment effect, the attacker performs gradient
replacement after receiving the gradient information returned by the active party.
Specifically, the attacker replaces the gradients of the poisoned data with those
of the target class samples and multiplies them by an amplification factor λ to
enhance the effect. Through this method, the poisoned data not only aligns with
the target class in feature space but also gradually approaches the target class
samples in gradient space, maximising the attack effect.
To gradually align the features and gradients, the optimization objective for GFCA can be reformulated as follows:

min_Θ E_{(x,y)∼D_p} [ L(F(x + T(z); Θ), τ) ]    (4)

Here, Θ represents the parameters of the global model, and the goal is to adjust these parameters to minimize the loss function L, which measures the difference between the model's output and the target class τ. The
poisoned dataset Dp contains samples from non-target classes. The function T
refers to a trigger generation network trained during the ATGN phase, and it
generates triggers by applying random noise z. The term T (z) represents the
generated trigger from the noise z, which is added to the input data, producing
the modified input x + T (z). The expression F (x + T (z); Θ) denotes the output
of the global model when processing the poisoned data, and τ is the class the
attacker intends for the poisoned samples to be misclassified into. By optimizing
Θ, the attacker ensures that the poisoned data with the generated trigger is
classified into the target class τ , successfully executing a backdoor attack during
inference.
GFCA achieves its goal through two primary steps. The first step involves
feature substitution, where the attacker systematically replaces the embedding
features of poisoned data with those of the target class samples in a one-to-one
correspondence. This manipulation causes the poisoned data to resemble the
target class samples during training, thereby misleading the model’s classification
behavior. The second step is gradient substitution, where the attacker replaces
the gradients of the poisoned samples with the gradients from the target class
data and amplifies them by a factor of λ. This amplification accelerates the
alignment of the poisoned data with the target class in gradient space. This dual
strategy of aligning both features and gradients not only strengthens the attack’s
effectiveness but also significantly increases the likelihood that the poisoned data
will be misclassified into the target class during the inference phase.
This strategy performs particularly well in complex multi-classification tasks
because it allows the poisoned data to gradually align with the target class
samples in both feature and gradient spaces, thereby significantly improving the
success rate of backdoor attacks.
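The two substitution steps can be sketched as follows, assuming per-sample embeddings and gradients arrive as plain vectors; the helper name and the λ value are illustrative:

```python
# GFCA's two steps: (1) swap poisoned samples' embeddings for target-class
# embeddings; (2) swap their gradients for target-class gradients scaled by
# the amplification factor lambda. Data here is dummy, one poisoned sample.

LAM = 2.0  # amplification factor lambda (illustrative value)

def gfca_substitute(poison_embs, target_embs, target_grads, lam=LAM):
    """Return (substituted embeddings, amplified substituted gradients)."""
    assert len(poison_embs) == len(target_embs) == len(target_grads)
    new_embs = [list(t) for t in target_embs]                       # step 1: feature substitution
    new_grads = [[lam * g for g in grad] for grad in target_grads]  # step 2: gradient substitution
    return new_embs, new_grads

embs, grads = gfca_substitute(
    poison_embs=[[0.9, 0.1]],
    target_embs=[[0.2, 0.8]],
    target_grads=[[0.05, -0.03]],
)
print(embs, grads)  # poisoned sample now mimics the target class in both spaces
```

Because both substitutions happen inside the attacker's local computation, nothing visible to the active party changes except the embeddings it was going to receive anyway.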

5 Experiments
To validate the effectiveness of our proposed Adaptive Federated Backdoor Frame-
work (AFBF) in Vertical Federated Learning (VFL), we conducted extensive
experiments on multiple public datasets. The selection criteria for these datasets
included diversity in class numbers, image complexity, and widespread use in the
research community to ensure the generalizability of our results. Additionally,
all datasets underwent standard preprocessing steps, including normalization of
pixel values to the [0,1] range, resizing images to 32×32 pixels if necessary, and
data augmentation techniques such as random cropping and horizontal flipping
to enhance model robustness. We compared AFBF with existing major backdoor
attack methods and performed an in-depth analysis of the impact of each compo-
nent and hyperparameters in the framework on the overall performance.

5.1 Experimental Setup


We selected three widely used image classification datasets: CIFAR-10 [25],
CIFAR-100 [25], and CINIC-10 [26]. CIFAR-10 contains 50,000 training images
and 10,000 test images across 10 classes, with images evenly distributed among
classes and sized at 32×32 pixels. CIFAR-100 is similar to CIFAR-10 but con-
tains 100 classes, each with 600 images, totalling 60,000 images, with 50,000 for
training and 10,000 for testing. CINIC-10 has a total of 270,000 images, with
180,000 for training and 90,000 for testing, also divided into 10 classes. This
dataset is an enhanced variant of CIFAR-10, designed to assess the performance
of algorithms on more extensive and intricate classification challenges.
In our experiments, we implemented a two-party VFL configuration, consist-
ing of one passive party (serving as the attacker) and one active party. Following
the approaches outlined in [15, 18], each image was divided along its vertical
midline, with each participant retaining half of the features. The passive party
was assigned the left half of the image, while the active party received the right
half along with the associated label information. Both participants employed
ResNet-18 [27] as the base model to extract local features. The top model used
by the active party is a four-layer fully connected neural network, with each
layer followed by batch normalization and ReLU activation, which integrates
the embedding features from both parties to perform the classification task.
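The feature partition just described amounts to a plain column split down the image's vertical midline; a minimal sketch, with nested lists standing in for image tensors:

```python
# Split each H x W image at the vertical midline: the passive party (attacker)
# keeps the left half, the active party keeps the right half plus the label.

def split_image(img):
    """img: list of H rows, each with W pixel values; returns (left, right)."""
    w = len(img[0])
    left = [row[: w // 2] for row in img]
    right = [row[w // 2:] for row in img]
    return left, right

img = [list(range(32)) for _ in range(32)]  # dummy 32x32 "image"
left, right = split_image(img)
print(len(left[0]), len(right[0]))  # 16 16: each party holds 16 of 32 columns
```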
We designated the backdoor target classes as “car,” “Apple,” and “car” for the
CIFAR-10, CIFAR-100, and CINIC-10 datasets, respectively. Within the AFBF
framework, the attacker requires only a small subset of target class samples.
Specifically, we randomly selected 50 samples from the target class to form the
target sample set Dt . Additionally, 50 samples were randomly chosen from the
remainder of the training set to create the poisoned sample set Dp . The datasets
were split into training and testing sets with an 80–20 ratio to maintain consis-
tency across experiments. Furthermore, we ensured that the selected poisoned
samples did not overlap with the validation set to prevent data leakage. To
assess the attack effectiveness of AFBF, we applied the generated triggers to
all non-target class samples in the test set, constructing a backdoor test set for
evaluation.
To measure the model’s performance, we used two commonly used metrics:
Attack Success Rate (ASR) and Main Task Accuracy (MTA) [28]. ASR repre-
sents the proportion of samples in the backdoor test set that are misclassified as
the target class by the model; MTA measures the classification accuracy of the
poisoned model on the clean test set. It should be noted that the CIFAR-10 and
CINIC-10 datasets are evaluated using Top-1 accuracy, while due to the large
number of classes in CIFAR-100, we use Top-5 accuracy [29].
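Both metrics reduce to simple ratios over the two test sets; a small sketch with dummy predictions (the target-class index τ used here is an arbitrary placeholder):

```python
# ASR: fraction of triggered non-target test samples predicted as target class tau.
# MTA: plain accuracy of the poisoned model on the clean test set.

def attack_success_rate(backdoor_preds, tau):
    return sum(p == tau for p in backdoor_preds) / len(backdoor_preds)

def main_task_accuracy(clean_preds, clean_labels):
    correct = sum(p == y for p, y in zip(clean_preds, clean_labels))
    return correct / len(clean_labels)

tau = 1                                       # placeholder target class
asr = attack_success_rate([1, 1, 1, 0], tau)  # 3 of 4 triggered inputs flip
mta = main_task_accuracy([0, 1, 2, 2], [0, 1, 2, 3])
print(asr, mta)  # 0.75 0.75
```

For CIFAR-100, the same MTA computation would count a prediction as correct if the true label appears in the model's top five classes, matching the Top-5 convention above.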
The experimental setup included a total of 100 training rounds to ensure suf-
ficient convergence of both the VFL model and the backdoor mechanisms. We
implemented early stopping based on validation performance to prevent overfit-
ting. The learning rate was set to 0.01 with a decay factor of 0.1 applied every
30 epochs. Batch size was maintained at 128 for all experiments.
In the comparison, we selected several typical VFL backdoor attack methods
as baselines, including both training-phase and inference-phase attack meth-
ods. The training-phase attack methods include GR [15] and BadVFL [16]; the
inference-phase attack method is LR-BA [17]. Additionally, we also compared it
with the standard VFL model without attack as the baseline. For fairness, we
adjusted the parameters of each method according to the settings in [17] and
[15]. For example, for GR and LR-BA, we set the number of poisoned samples
to be consistent with AFBF; for BadVFL, we selected the best combination of
original and target classes.
Each experiment was repeated five times independently to account for vari-
ability, and the results were averaged to ensure statistical significance. We also
conducted ablation studies to isolate the effects of ATGN and GFCA components
within the AFBF framework.

5.2 Performance Evaluation and Comparison


The experimental results are shown in Table 1. AFBF achieved the highest Attack Success Rate (ASR) on all datasets, while the Main Task Accuracy (MTA) showed only a minimal decrease. Specifically, AFBF achieved ASRs of 99.18%, 98.46%, and 92.41% on the CIFAR-10, CINIC-10, and CIFAR-100 datasets, respectively, a significant improvement over the other methods; the corresponding MTAs were 81.94%, 67.94%, and 77.67%, decreasing by only 0.31%, 7.62%, and 1.66% compared to the baseline model.

Table 1. Comparison of MTA and ASR for Different Datasets and Attacks

Dataset    |            MTA (%)               |            ASR (%)
           | Baseline GR    BadVFL LR-BA AFBF | Baseline GR    BadVFL LR-BA AFBF
CIFAR-10   | 82.25    74.89 76.00  80.05 81.94| 2.14     58.00 90.00  98.13 99.18
CINIC-10   | 75.56    65.35 66.00  67.28 67.94| 1.49     54.20 79.00  95.97 98.46
CIFAR-100  | 79.33    70.37 67.00  74.67 77.67| 1.98     55.80 82.00  88.63 92.41

Compared with BadVFL, AFBF has a significant advantage in ASR. On CIFAR-10, CINIC-10, and CIFAR-100, AFBF's ASR is higher than BadVFL's by 9.18%, 19.46%, and 10.41%, respectively. Moreover, AFBF requires only a small number of target class samples, while BadVFL requires labelled samples from each class, totalling a large amount of auxiliary data. In terms of MTA, AFBF also outperforms BadVFL by 5.94%, 1.94%, and 10.67%, respectively.
Compared to GR, AFBF performs better in handling complex and large-scale
datasets. Since GR’s method requires poisoned data and a target label in each
batch, when we set the generated poisoned data to be consistent with AFBF for
fairness, its ASR and MTA are significantly lower than those of AFBF. AFBF
achieves more efficient backdoor attacks through the effective combination of
dynamic trigger generation and target gradient alignment.
Although LR-BA achieved relatively high ASR on some datasets, it requires
a large amount of auxiliary data, needing labelled samples from each class, which
is not practical in real-world applications. AFBF only needs a small number of
target class samples and still has an advantage in ASR.
5.3 Attack Effectiveness Under Defense


To evaluate the robustness of AFBF, we implemented several typical backdoor
defence methods in the VFL framework, including Differential Privacy Gradi-
ent (DP-G) [27], Gradient Compression (GC) [8, 27], and Feature Purification
(CoPur) [22].

Fig. 3. Comparison of different datasets under attack methods

The experimental results are shown in Figure 3. These defence methods reduced the ASR of AFBF to some extent but also significantly decreased the main task accuracy. This indicates that existing defence methods have difficulty effectively resisting AFBF attacks without sacrificing model performance. AFBF still maintains high ASR and MTA in the face of these defence measures, demonstrating its robustness against defences.

5.4 Ablation Study


To gain a deeper understanding of the impact of each component in the AFBF
framework on the overall performance, we conducted ablation experiments.
Specifically, we evaluated the attack effects when using only Adaptive Trigger
Generation (ATGN) and only Gradient-Feature Correlation Attack (GFCA).
The experimental results are shown in Table 2. When using only ATGN, the
ASR results fluctuate greatly with a high standard deviation, indicating that
relying solely on dynamic trigger generation does not provide a stable attack
effect. When using only GFCA, the ASR is low on all datasets, possibly because,
lacking initial trigger generation, GFCA cannot effectively align the gradients
and features of the poisoned data. The combination of both can fully leverage
their respective advantages to achieve stable and efficient backdoor attacks.

Table 2. Comparison of MTA and ASR across different datasets and methods

Dataset    |               MTA                       |                ASR
           | ATGN-Only    GFCA-Only    AFBF          | ATGN-Only     GFCA-Only     AFBF
CIFAR-10   | 73.12 ± 4.34 63.79 ± 6.98 81.94 ± 1.03  | 75.78 ± 22.34 3.43 ± 0.89   99.18 ± 0.65
CINIC-10   | 64.98 ± 3.11 66.44 ± 0.32 67.94 ± 2.38  | 70.23 ± 20.56 29.15 ± 11.35 98.46 ± 1.22
CIFAR-100  | 75.99 ± 2.53 75.92 ± 1.37 77.67 ± 1.98  | 63.89 ± 19.54 4.31 ± 1.33   92.41 ± 4.57

6 Conclusions

This paper proposes an Adaptive Federated Backdoor Framework (AFBF) for Vertical Federated Learning (VFL). By combining an Adaptive Trigger Generation Network (ATGN) and a Gradient-Feature Correlation Attack (GFCA) strategy, it achieves efficient backdoor attacks without modifying labels. Compared with existing methods, AFBF requires only a small number of target class samples and utilizes Generative Adversarial Networks to dynamically generate triggers, significantly enhancing the stealthiness and flexibility of backdoor attacks. Experimental results demonstrate that AFBF achieves nearly 100% attack success rates on multiple public datasets, with minimal impact on the model's main task performance, verifying its effectiveness in complex multi-classification tasks. This study reveals potential backdoor attack risks in the VFL environment and emphasizes the necessity of strengthening the security of VFL models in practical applications. Future work will focus on an in-depth analysis of AFBF's attack mechanisms, theoretically exploring its influencing factors and investigating effective defense strategies to counter such backdoor attacks, promoting the secure and reliable development of VFL systems.

Acknowledgments. This work was supported by the National Natural Science Foun-
dation of China (No. 62372130, No. 62261160651) and the Guangzhou Basic and
Applied Basic Research Project (No. 2023A04J1725, No. 2024A03J0397).

References
1. McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.:
Communication-efficient learning of deep networks from decentralized data. In:
Artificial Intelligence and Statistics, pp. 1273–1282. PMLR (2017)
2. Cheng, K., et al.: SecureBoost: a lossless federated learning framework. IEEE Intell. Syst. 36(6), 87–98 (2021)
3. He, C., Annavaram, M., Avestimehr, S.: Group knowledge transfer: federated learning of large CNNs at the edge. Adv. Neural Inf. Process. Syst. 33, 14068–14080 (2020)
4. Hu, Y., Niu, D., Yang, J., Zhou, S.: FDML: a collaborative machine learning framework for distributed features. In: Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 2232–2240 (2019)
5. Yang, L., Tianjian, C., Qiang, Y.: Secure federated transfer learning. In: Proceed-
ings of the 24th ACM SIGKDD International Conference on Knowledge Discovery
and Data Mining (2018)
6. Liu, Y., et al.: A communication efficient collaborative learning framework for
distributed features. arXiv preprint arXiv:1912.11187 (2019)
7. Yang, Q., Liu, Y., Chen, T., Tong, Y.: Federated machine learning: concept and
applications. ACM Trans. Intell. Syst. Technol. (TIST) 10(2), 1–19 (2019)
8. Kairouz, P., et al.: Advances and open problems in federated learning. Found. Trends Mach. Learn. 14(1–2), 1–210 (2021)
9. Zhang, Y., Chen, X., Liu, X., Zhu, N.: Exploring trust transfer between internet enterprises and their affiliated internet-only banks: an adoption study of internet-only banks in China. Chin. Manag. Stud. 12(1), 56–78 (2018)
10. Matschinske, J., et al.: The featurecloud platform for federated learning in
biomedicine: unified approach. J. Med. Internet Res. 25, e42621 (2023)
11. Kaissis, G.A., Makowski, M.R., Rückert, D., Braren, R.F.: Secure, privacy-
preserving and federated machine learning in medical imaging. Nat. Mach. Intell.
2(6), 305–311 (2020)
12. Bagdasaryan, E., Veit, A., Hua, Y., Estrin, D., Shmatikov, V.: How to backdoor fed-
erated learning. In: International Conference on Artificial Intelligence and Statis-
tics, pp. 2938–2948. PMLR (2020)
13. Baruch, G., Baruch, M., Goldberg, Y.: A little is enough: circumventing defenses
for distributed learning. Adv. Neural Inf. Process. Syst. 32 (2019)
14. Shejwalkar, V., Houmansadr, A.: Manipulating the byzantine: optimizing model
poisoning attacks and defenses for federated learning. In: NDSS (2021)
15. Zou, T., et al.: Defending batch-level label inference and replacement attacks in
vertical federated learning. IEEE Trans. Big Data (2022)
346 K. Li et al.

16. Naseri, M., Han, Y., De Cristofaro, E.: Badvfl: backdoor attacks in vertical fed-
erated learning. In: 2024 IEEE Symposium on Security and Privacy (SP), pp.
2013–2028. IEEE (2024)
17. Gu, Y., Bai, Y.: LR-BA: backdoor attack against vertical federated learning using
local latent representations. Comput. Secur. 129, 103193 (2023)
18. Chen, P., Yang, J., Lin, J., Lu, Z., Duan, Q., Chai, H.: A practical clean-label
backdoor attack with limited information in vertical federated learning. In: 2023
IEEE International Conference on Data Mining (ICDM), pp. 41–50. IEEE (2023)
19. Nguyen, T.A., Tran, A.: Input-aware dynamic backdoor attack. Adv. Neural. Inf.
Process. Syst. 33, 3454–3464 (2020)
20. Xuan, Y., Chen, X., Zhao, Z., Tang, B., Dong, Y.: Practical and general back-
door attacks against vertical federated learning. In: Joint European Conference on
Machine Learning and Knowledge Discovery in Databases, pp. 402–417. Springer
(2023)
21. He, Y., et al.: Backdoor attack against split neural network-based vertical federated
learning. IEEE Trans. Inf. Forensics Secur. (2023)
22. Liu, J., Xie, C., Koyejo, S., Li, B.: Copur: certifiably robust collaborative inference
via feature purification. Adv. Neural. Inf. Process. Syst. 35, 26645–26657 (2022)
23. Salem, A., Wen, R., Backes, M., Ma, S., Zhang, Y.: Dynamic backdoor attacks
against machine learning models. In: 2022 IEEE 7th European Symposium on
Security and Privacy (EuroS&P), pp. 703–718. IEEE (2022)
24. Li, Y., Li, Y., Wu, B., Li, L., He, R., Lyu, S.: Invisible backdoor attack with sample-
specific triggers. In: Proceedings of the IEEE/CVF International Conference on
Computer Vision, pp. 16463–16472 (2021)
25. Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny
images (2009)
26. Darlow, L.N., Crowley, E.J., Antoniou, A., Storkey, A.J.: Cinic-10 is not imagenet
or cifar-10. arXiv preprint arXiv:1810.03505 (2018)
27. Fu, C., et al.: Label inference attacks against vertical federated learning. In: 31st
USENIX Security Symposium (USENIX Security 22), pp. 1397–1414 (2022)
28. Li, Y., Jiang, Y., Li, Z., Xia, S.T.: Backdoor learning: a survey. IEEE Trans. Neural
Netw. Learn. Syst. 35(1), 5–22 (2022)
29. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 770–778 (2016)
Weakly Supervised Waste Classification
with Adaptive Loss and Enhanced Class
Activation Maps

Wenzhang Dai(B) and Le Sun(B)

Department of Jiangsu Collaborative Innovation Center of Atmospheric Environment


and Equipment Technology (CICAEET), Nanjing University of Information Science
and Technology, Nanjing 210044, China
{202312490511,lesun1}@nuist.edu.cn

Abstract. Waste classification is crucial as it brings economic benefits to cities and promotes environmental sustainability. Traditional waste
classification studies focus on single-label classification tasks. However,
single-label classification is inefficient when dealing with images con-
taining multiple wastes. Supervised multi-label classification can man-
age more complex and realistic scenes, which contain a wider variety of
objects in a single image. However, it necessitates pre-labeling all waste
in each image, which is time-consuming. Therefore, we design a weakly
supervised multi-label classification framework called Adaptive Weakly-
supervised Waste Classification Framework (AWWCF). The AWWCF
consists of the Target Preprocessing Module (TPM), the Prediction Mod-
ule (PM), and the Computing Module (CM). Previous weakly-supervised
classification works primarily exploit inner connections among observed
labels while ignoring unannotated objects and unobserved labels. To
properly utilize unobserved labels, we integrate the PM and AM of
AWWCF with our newly proposed method, Adaptive Loss and Enhanced
Class (ALEC) activation maps. In PM, ALEC dynamically enhances the
attribution scores of the class activation maps to prevent predictions for
positive observed labels from being affected by unobserved labels. In the
CM, ALEC dynamically rejects and corrects unobserved labels with large
loss values. Experimental results demonstrate the effectiveness of our
framework in weakly supervised multi-label waste classification. ALEC
can be integrated with different deep classification models to form an
effective framework for a sustainable environment.

Keywords: Waste classification · Weakly supervised classification · Multi-label classification · Dynamic correction · Class activation maps

1 Introduction

With increasing awareness of environmental issues, there has been a growing movement toward waste classification in recent years [1]. Despite these efforts,
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 347–361, 2025.
https://doi.org/10.1007/978-981-96-4566-4_24
348 W. Dai and L. Sun

different types of waste can still be seen almost everywhere in cities. Sustain-
able waste management reduces the use of natural resources while searching for
solutions that do not negatively impact the environment or public health. The sustainable waste management process includes a few crucial phases: gathering waste, processing and sorting it, recovering materials or energy from it, and reusing it to create new goods. To promote waste recycling and environmental sustainability, many artificial intelligence-based approaches for waste classification
[14] have been developed. Some methods integrate additional feature extractor
modules for waste classification. These modules, such as optimized convolutions
[14] and feature pyramids [21], are used to capture waste at different scales in
images. Some methods design fine-tuning methods while retaining the original
model structure [6]. These methods can enable the model to achieve faster con-
vergence when the dataset images are insufficient. However, these methods are
all single-label classification methods. In real-life scenarios, there is often more
than just one type of waste.
In contrast, multi-label classification allows for the simultaneous classification
of multiple objects within a single image [8]. However, labeling a large-scale and
accurate multi-label dataset for training is time-consuming [24]. To address this
issue, Weakly Supervised Multi-label Classification (WSML) has emerged [16].
While WSML reduces the labor costs and time associated with labeling images,
it also faces challenges [4]. For example, these unannotated objects and their
unobserved labels often lead to large loss values, resulting in poor classification
performance [5].
To alleviate this problem, some methods focus on exploiting unobserved
labels with techniques like Assume Negative (AN) [5]. In most multi-label clas-
sification scenarios, the number of objects present in each image is smaller than
the total number of classes in the datasets [18]. Consequently, many true neg-
ative labels are incorporated into training. However, some false negative labels
can interfere with the model’s ability to learn the characteristics of objects. To
address the problem, an online estimation module is utilized to correct these
unobserved labels during training [5]. While this method can correct some false
negative labels, it requires prior knowledge of the mathematical expectation for
the number of positive labels in each image. Additionally, the approach neglects
that the model first fits true labels (early learning phase) and then fits false
labels (memorization phase) [11]. To utilize this finding, LL-R, LL-Ct, and LL-
Cp are proposed to reject and correct partial false negative labels [11]. However,
they struggle to handle false negative labels separately.
In this study, we propose a new weakly supervised multi-label waste clas-
sification framework, called the Adaptive Weakly-supervised Waste Classifica-
tion Framework (AWWCF). The AWWCF comprises three modules: the Target
Preprocessing Module (TPM), the Prediction Module (PM), and the Comput-
ing Module (CM). TPM is responsible for assigning unobserved labels as neg-
ative. To improve the effectiveness of AWWCF in utilizing unobserved labels
and reduce the impact of false negative labels on model learning features, we
integrate a novel method with the PM and CM of AWWCF, called Adaptive
Loss and Enhanced CAMs (ALEC). In the PM, ALEC identifies the areas of the CAMs with the largest attribution scores. These areas are associated with positive labels, but their attribution scores are diminished by other
false negative labels. ALEC enhances these areas. By increasing the attribution
scores in these areas, the probability of false negative labels being predicted as
positive is reduced, thereby improving classification performance. In the CM,
ALEC makes full use of unobserved labels by assigning them as negative. At
different stages of training, ALEC uses different methods to deal with large loss
values depending on the proportion of true negative and false negative labels. In
the early stages of training, ALEC gradually rejects false negative labels. When
classification performance stabilizes or mean average precision oscillates within
a range for a prolonged period, ALEC adopts more aggressive ways to correct
these false labels. It permanently modifies some of the false negative labels with
the largest loss values. The contributions of this paper are:
• We propose a multi-label waste classification framework called Adaptive
Weakly-supervised Waste Classification Framework (AWWCF). The frame-
work can classify multiple classes of waste in complex scenarios with partial
labels.
• We design a new method called Adaptive Loss and Enhanced CAMs (ALEC)
for weakly supervised multi-label classification. ALEC dynamically enhances
CAMs with high attribution scores to reduce false negative labels’ impact
on predictions for observed labels. Moreover, ALEC dynamically rejects and
corrects unobserved labels with large loss values.
• We conduct extensive experiments on three multi-label waste datasets. Exper-
imental results show that the framework incorporating ALEC has better clas-
sification performance compared with existing waste weakly supervised multi-
label classification framework. The proposed method ALEC can combine dif-
ferent backbones and improve their weakly supervised waste classification
performance.
The remaining sections of this paper are structured as follows: Sect. 2 intro-
duces related work, Sect. 3 describes the AWWCF modules in detail, Sect. 4
provides detailed comparative experiments and results, and Sect. 5 summarizes
all the work.

2 Related Work
2.1 Weakly Supervised Multi-label Classification
In weakly supervised multi-label classification tasks, there are various methods
used to train the model. Some studies fuse image features at different scales
[23]. Some address the issue of label distribution in datasets [7], which design a
network with two branches to improve classification performance on both head
classes and tail classes. Another proposal [5] suggests different ways to handle
extreme cases where there is only one single positive label in each image. The first
approach is to introduce a hyper-parameter for the ordinary BCELoss function.
This method can be associated with pseudo-negative sampling. However, this
method is more suitable for the case where there is only one positive label per
image. The second method combines ordinary cross-entropy with regularization
method label smoothing to prevent overfitting, but it is not as effective in other
training tasks [22]. The third approach is to design a separate loss function for
observed labels and use domain knowledge to avoid the model predicting all
labels as positive.
On the other hand, some researchers find that the model initially fits into true
negative labels and then transitions to fitting false negative labels during the late
learning phase [11]. Based on this finding, they propose three new methods that
effectively utilize unobserved labels according to the proportion of false negative
labels [12], called LL-R, LL-Ct, and LL-Cp. LL-R rejects all large loss values
during the whole training process. LL-Ct only uses the new loss calculated from
the modified label and doesn’t modify the actual labels. LL-Cp corrects these
false negative labels permanently during the whole training process. In all three
methods, whether it is a large loss value or not depends on the batch size,
the number of classes, and the percentage of clean labels. In LL-R and LL-Ct
methods, the number of large loss values keeps on going up with the number of
iterations, whereas, in the LL-Cp method, the number of labels modified in each
iteration remains the same as it is more aggressive. However, the researchers
do not consider the roles of these methods at different stages of training. Some
methods may be more suitable for the early stage of training, while others may
be more effective in handling large loss values at the late stage of training.
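The schedules described above can be illustrated with a small sketch; the exact formula is our reading of the description (Δrel in percent, counts per batch), not the original implementation:

```python
def num_large_losses(batch_size, num_classes, epoch, delta_rel, method="LL-R"):
    """Illustrative count of per-batch labels treated as large-loss.

    LL-R and LL-Ct ramp the count up with the epoch, while LL-Cp keeps
    it fixed (and is therefore more aggressive from the start).
    """
    if method in ("LL-R", "LL-Ct"):
        frac = (epoch - 1) * delta_rel   # grows as training proceeds
    elif method == "LL-Cp":
        frac = delta_rel                 # constant for the whole run
    else:
        raise ValueError(method)
    return int(batch_size * num_classes * frac / 100)
```

With a batch size of 16 and 46 classes, the LL-R count grows from 0 at the first epoch, while LL-Cp rejects the same number every epoch.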

2.2 Waste Classification


Many methods based on artificial intelligence have been proposed in the field
of waste classification [19]. Table 1 provides a comparison with related works.
Previous studies primarily focus on single-label classification [3, 10]. For instance,
convolutional neural networks (CNN) are utilized for waste classification tasks
[2]. It provides a comparison with respect to the classification performance of
DenseNet, MobileNet, and ResNet on two waste datasets. Although the models
achieve high accuracy, they do not make innovative improvements to network
structures or training methods. CNN can also be combined with Graph-LSTM
to classify municipal solid waste [15]. CNN is used as a feature extractor and the
features as nodes form the entire graph. Graph-LSTM is capable of capturing
temporal dependencies and long-range dependencies within the graph structure.
To address the issue of long training time, MLH-CNN simplifies the network
structure, which has fewer parameters and higher accuracy [3]. However, these
tasks typically involve a limited number of waste categories, usually no more than
10 categories, and only have one category per image. In real-world scenarios, it
proves to be inefficient when classifying a large number of wastes.
To address this, a serial attention frame (SAF) for multi-label bottle classifi-
cation is proposed [2]. The SAF is composed of channel attention, spatial atten-
tion, and self-attention, which are used to identify salient features. Although
their frame achieves relatively good performance, the task only focuses on one

Table 1. Comparison with related works

Paper                     Single-label  Multi-label  Datasets             Categories
Huang, L et al. [10]           ✓                      TrashNet             6
Masand, A et al. [3]           ✓                      TrashNet             6
Li, N., Chen, Y. [15]          ✓                      TrashNet             6
Ahmed, M.I.B et al. [2]        ✓                      TrashNet, VN-trash   6, 3
Xiao, J et al. [2]                          ✓         Private datasets     8
Yan, K. et al. [25]                         ✓         Private datasets     unknown

super category and does not consider other types of waste. Graph neural network
is also used to solve multi-label waste classification tasks [25]. They attach impor-
tance to correlations between labels instead of treating them independently.
By stacking multiple layers, their model can model the complex relationships
between labels. Recently, a program called “TACO” has been launched [17]. It
contains thousands of photos of waste in diverse environments and encompasses
sixty different waste categories. Therefore, we utilize the images and annotations
from this program to create a new waste dataset.

3 The Proposed Framework


3.1 Notation Description
Let the input image be denoted as x ∈ X, and the corresponding target value as y ∈ Y, where the target value for the i-th class label of the image is represented by yi. The sets X and Y together form the set D. The set S^p is the set of all positive labels, while S^u is the set of all unobserved labels. The target values after the Assume Negative method are denoted as yi^AN. The set of all yi^AN is called Y^AN; X and Y^AN collectively form the dataset D'. The function P represents the model's prediction values for all classes of a single image, while Pi represents the prediction value for the i-th class. The loss value calculated between the target value and the prediction value for the i-th class is denoted li. M represents the N CAMs of each image. Fa represents the parts of M that are greater than 0, i.e., the areas that receive the model's focus. Th is the h-th largest of all focus values; values greater than Th are processed accordingly.

3.2 Overview of the Framework


The processing flow of AWWCF is presented in Algorithm 1. The framework
takes images X and Target Y as input, and then outputs a trained model. The
TPM primarily processes the target values of each input image to include all
labeled objects in the training. TPM can thereby increase the number of training
samples (lines 1–2). In PM, the classification model loads pre-trained weights and
uses ALEC to make predictions (lines 3–6). The main task of the PM is to identify
objects in the input images and predict the existence probability of each class.
The CM mainly calculates the cross-entropy loss function between the prediction
values output by PM and the target values output by the TPM. It uses ALEC
to reject and correct large loss values to minimize their impact on the model
learning process (lines 7–8).

Algorithm 1: The entire processing flow of AWWCF

Input: Image X and Target Y
Output: trained model
1  for y in Y do
2      Y^AN ← Randomly retain one positive label and assume all other labels as negative ;  // TPM (lines 1-2)
3  Load pre-trained weights for the feature extractor ;
4  M ← CAMs obtained from the feature extractor and a 1 × 1 Conv ;
5  Ec ← Use ALEC to conduct local enhancement of the CAMs ;
6  P ← Use fully connected layers for prediction and classification ;  // PM (lines 3-6)
7  l ← Calculate the loss value using P and Y^AN ;
8  f ← Use ALEC to dynamically reject and correct labels ;  // CM (lines 7-8)
9  Back-propagate and update the gradients of the parameters
10 return trained model

3.3 Target Preprocessing Module


The TPM applies the assumed negative assumption to process the target val-
ues of each input image. Assuming there are N classification classes, the entire
labeled dataset is first manually set by discarding some observed labels and keep-
ing only one positive label while setting all other labels as ‘unknown’. The setting
creates a partially labeled dataset. Then all labels except the positive label for
each image are set as negative. Finally, only one positive label is retained among
the N classes, while all other labels are set as negative. Figure 1 illustrates the
process of TPM, where for simplicity, only 5 classes are shown. Unobserved labels
are represented by light green, positive labels by red, and negative labels by blue.
In WSML, consider an input x ∈ X and a target value y ∈ Y, where X and Y constitute a dataset D. Here, X represents a collection of images, while Y = {0, 1, u}^N, where x represents an image and y represents the corresponding target value. The value u denotes an unknown target value for a category, such as an unobserved label, and N represents the number of classification categories. For a given image, the target value y can be expressed via S^p = {i | yi = 1}, S^n = {i | yi = 0}, and S^u = {i | yi = u}, where S^p, S^n, and S^u represent the sets of positive, negative, and unknown labels, respectively. In some partially labeled

Fig. 1. Diagram of Target Preprocessing Module (Color figure online)

tasks, only a small proportion of the labels are known, and hence |S^p| + |S^n| < N. The expression for Y^AN is given by Eq. 1:

y_i^{AN} = \begin{cases} 1, & i \in S^p \\ 0, & i \in S^n \cup S^u \end{cases}    (1)

The set of all yi^AN constitutes Y^AN, and X and Y^AN together form the dataset D'. Each element in the sets {yi^AN | i ∈ S^p} and {yi^AN | i ∈ S^n} represents a true positive and a true negative, respectively. Correspondingly, the set {yi^AN | i ∈ S^u} contains both true negatives and false negatives. Once the Y^AN of all images are obtained, these target values can be incorporated into the computation of the binary cross-entropy loss function in the subsequent steps.
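As a concrete illustration, the Assume Negative rule of Eq. 1 can be sketched in NumPy; encoding the unobserved value u as -1 and the helper name are our assumptions, not the paper's code:

```python
import numpy as np

def assume_negative(y: np.ndarray) -> np.ndarray:
    """Apply the Assume Negative rule of Eq. 1.

    y is a length-N target vector with entries 1 (positive),
    0 (negative), or -1 (unobserved, i.e. the value u).
    Every non-positive entry is treated as negative.
    """
    return np.where(y == 1, 1, 0)  # positives stay 1, everything else becomes 0

# A 5-class example as in Fig. 1: one observed positive, rest unobserved.
y = np.array([-1, 1, -1, -1, -1])
print(assume_negative(y))  # -> [0 1 0 0 0]
```

Note that true negatives and false negatives in S^u are indistinguishable at this stage; separating them is deferred to the CM.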

3.4 Prediction Module

The PM takes images X as input and outputs the prediction, which is denoted as P ∈ R^N. Since our task involves multi-label classification, we pre-train the deep
learning model on multi-label datasets like Relabeled-ImageNet [27]. The feature
extractor can be replaced with any other deep-learning classification model. We
choose EfficientNet as our model’s feature extractor due to its strong classifi-
cation performance and lightweight nature. The weight of the feature extractor
can be frozen, only fine-tuning the weight of 1 × 1 convolution.
The process of the PM is shown in Algorithm 2. The feature extractor Φ and the 1×1 convolution Γ first convert the input image x ∈ R^{H×W} into CAMs M ∈ R^{N×h×w} (line 1). For each image, there are N CAMs. We regard the parts of the CAMs that are greater than 0 as the focus area of the model, Fa (lines 2–8). We then select the k largest focus values (line 9); k is related to the batch size b, the number of classes in the dataset N, the current epoch t, and Δrel. The CAM regions with higher scores are multiplied by α to enhance the model's prediction score for the corresponding class (lines 10–11), which prevents potentially false negative labels from impairing the model's multi-label learning. After passing through a global adaptive pooling layer (line 12), the prediction P ∈ R^N is obtained.

Algorithm 2: The processing flow of ALEC in PM

Input: Image X
Output: Predictions P
1  Φ ← pre-trained model weights
2  M ← Γ(Φ(x))
3  for i in h do
4      for j in w do
5          if M_ij > 0 then
6              Fa ← M_ij
7          else
8              Fa ← 0
9  Th ← topk(Fa.flatten(), b × N × t × Δrel)
10 Ec ← the part of M that is greater than Th
11 M ← M + α × Ec
12 P ← GlobalAdaptivePool(M)
13 return Predictions P
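The steps of Algorithm 2 can be sketched in NumPy for a single image; the function name and the single-image simplification are our assumptions, not the authors' implementation:

```python
import numpy as np

def enhance_cams(M: np.ndarray, t: int, batch: int, delta_rel: float,
                 alpha: float = 1.0) -> np.ndarray:
    """Sketch of ALEC's CAM enhancement for one image (Algorithm 2).

    M has shape (N, h, w): one CAM per class. Positive activations form
    the focus area Fa; the top-k focus values define the threshold Th,
    and regions above Th are boosted by alpha before global pooling.
    """
    N = M.shape[0]
    Fa = np.where(M > 0, M, 0.0)                 # focus area (lines 3-8)
    k = max(1, int(batch * N * t * delta_rel))   # size of the top-k set (line 9)
    flat = np.sort(Fa.ravel())[::-1]
    Th = flat[min(k, flat.size) - 1]             # k-th largest focus value
    Ec = np.where(M > Th, M, 0.0)                # regions to enhance (line 10)
    M_enh = M + alpha * Ec                       # local enhancement (line 11)
    return M_enh.mean(axis=(1, 2))               # global average pooling -> P
```

Classes whose CAMs carry the largest attribution scores receive a proportionally larger boost, which is the intended effect of lines 10–11.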

3.5 Computing Module


The CM is designed to calculate the cross-entropy loss between the prediction value P output by the PM and the target value Y^AN output by the TPM, and then reject or correct the larger loss values. The process of the CM is illustrated in Fig. 2.

Fig. 2. Diagram of Computing Module

After obtaining the target value Y^AN from the TPM and the prediction value P from the PM, the cross-entropy loss can be computed. Finally, the model is trained by minimizing the loss function L on the dataset D' = (X, Y^AN), as shown in Eq. 2:

L = \frac{1}{|D'|} \sum_{(x, y^{AN}) \in D'} \frac{1}{N} \sum_{i=1}^{N} \mathrm{BCELoss}(P_i, y_i^{AN})    (2)

Due to the modification of target values for some labels in the TPM, there may be a subset of labels with excessively large cross-entropy loss values, typically caused by the Assume Negative assumption; these labels are usually false negative labels. To prevent these labels from affecting model learning, ALEC introduces a weight λi for each loss value:

L = \frac{1}{|D'|} \sum_{(x, y^{AN}) \in D'} \frac{1}{N} \sum_{i=1}^{N} l_i \times \lambda_i    (3)
 
The definition of li is BCELoss(Pi, yi^AN), where the parameters are omitted for convenience. Likewise, λi is defined as a function λi = λ(Pi, yi^AN), with the parameters also omitted. λi is a weight that measures the contribution of li to the loss function L in Eq. 3. The setting of λi is shown in Eq. 4:

\lambda_i = \begin{cases} 0, & i \in S^u \text{ and } \gamma \le p_t \text{ and } l_i > R(t) \\ 1, & \text{otherwise} \end{cases}    (4)

Here, γ refers to the difference between the current epoch and the epoch of the last mAP maximum when using the rejection method, and p_t denotes the patience, i.e., after how many epochs of oscillation the correction method steps in. R(t) represents the loss threshold selecting the largest [(t−1)Δrel]% of loss values in the set {li | (x, y^AN) ∈ D', i ∈ S^u}. Δrel is a hyper-parameter that determines
the rate at which the rejection rate of large loss value samples increases. In the
early stage of training, the loss values of samples with excessively high loss are
set to zero in each batch, which can be understood as temporarily ignoring these
samples with excessively high loss. The zeroing ratio is related to the number of
classification categories in the dataset, the hyper-parameter batch size, and the
rejection rate. The more classes in the dataset, the larger the hyper-parameter
batch size, and the higher the rejection rate, the more noise labels will be ignored.
As the number of training iterations increases, the model gradually learns the
noise labels, so the zeroing ratio also increases. In the later stages of training, the targets with large loss values are permanently corrected and the loss values are recalculated. However, because the correction method is more aggressive, R(t) becomes a fixed threshold selecting the top Δrel% of the loss value set, rather than [(t − 1)Δrel]%. The targets of other, normal loss value samples are not changed, and their loss values are not recalculated.
Correspondingly, because the label modification method is used in the later stage of training, the target value yi^AN also needs to be modified. The specific setting is shown in Eq. 5:

y_i^{AN} = \begin{cases} 1, & i \in S^u \text{ and } t > p_t \text{ and } l_i > R(t) \\ \text{unchanged}, & \text{otherwise} \end{cases}    (5)

After a label yi^AN that satisfies the first condition of Eq. 5 is modified, the sets S^u and S^p are also updated according to Eq. 6:

S^u ← S^u − {i},    S^p ← S^p ∪ {i}    (6)
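A minimal sketch of the CM's rejection and permanent-correction behavior (Eqs. 3–5), under the assumptions that Δrel is given here as a fraction rather than a percent and that all names are illustrative, not from the authors' code:

```python
import numpy as np

def bce(p, y):
    """Element-wise binary cross-entropy, numerically clipped."""
    eps = 1e-7
    p = np.clip(p, eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def alec_cm_step(p, y_an, unobserved, t, delta_rel, correct_phase=False):
    """One CM step: reject large losses early, correct them permanently late.

    p, y_an: per-class predictions and Assume-Negative targets for one image.
    unobserved: boolean mask marking S^u. Returns (loss, possibly-updated y_an).
    """
    l = bce(p, y_an)
    lam = np.ones_like(l)
    # fraction of unobserved labels whose losses count as "large" (R(t))
    frac = delta_rel if correct_phase else (t - 1) * delta_rel
    k = int(np.ceil(frac * unobserved.sum()))
    if k > 0:
        u_idx = np.where(unobserved)[0]
        large = u_idx[np.argsort(l[u_idx])[::-1][:k]]  # largest-loss unobserved labels
        if correct_phase:
            y_an = y_an.copy()
            y_an[large] = 1          # permanent correction (Eq. 5)
            l = bce(p, y_an)         # recompute losses for corrected labels
        else:
            lam[large] = 0.0         # rejection: zero the weight (Eq. 4)
    return (l * lam).mean(), y_an
```

In the rejection phase the targets are left untouched and only the weights λi change; in the correction phase the flipped labels enter Eq. 6's set update.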

4 Experiment
4.1 Experimental Setting

We evaluate ALEC on three multi-label datasets: MW, MWS, and TACOS. The MW and MWS datasets consist of trash images collected from the internet, while the TACOS dataset comprises images from the TACO project. Due to the severe imbalance in the number of labels in the MW and TACOS datasets, we excluded object categories with fewer images to ensure sufficient images per category for training and testing. Detailed information about the three datasets is presented in Table 2, which lists the number of images and annotations in the training, validation, and test sets, and the number of object categories. Generally, we use 70% of the images and annotations to train the model, 20% to validate its performance, and the remaining 10% to test the model's classification performance. We use mean Average Precision (mAP) as the classification evaluation metric (Table 3).

Table 2. Statistics of datasets

Datasets  train imgs  train annos  val imgs  val annos  test imgs  test annos  classes
MW        3283        5474         938       1553       469        751         46
MWS       3262        5167         932       1459       466        721         30
TACOS     892         2640         255       758        126        290         20
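For reference, mAP as used in our evaluation can be computed from ranked prediction scores; this is a standard per-class average-precision sketch, not the authors' evaluation code:

```python
import numpy as np

def average_precision(scores, labels):
    """AP for one class: mean precision at each positive, ranked by score."""
    order = np.argsort(scores)[::-1]
    labels = np.asarray(labels, dtype=float)[order]
    hits = np.cumsum(labels)                 # positives found up to each rank
    ranks = np.arange(1, len(labels) + 1)
    pos = labels == 1
    return (hits[pos] / ranks[pos]).mean() if pos.any() else 0.0

def mean_average_precision(scores, labels):
    """mAP over classes; `scores` and `labels` are (num_images, num_classes)."""
    scores, labels = np.asarray(scores), np.asarray(labels)
    return float(np.mean([average_precision(scores[:, c], labels[:, c])
                          for c in range(scores.shape[1])]))
```

A perfect ranking of every class yields mAP = 1.0; each misplaced positive lowers the precision terms at the positives ranked below it.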

4.2 Baseline Methods and Main Results

• Full labels: Each image in the datasets has true labels for all classes, repre-
senting the best performance of multi-label classification on the datasets.
• Naive Assume Negative (Naive AN) [11]: Treat all unobserved objects as
non-existent.
• Regularized online label estimation (ROLE) [5]: Combine the EPR and online
estimates of the unobserved labels throughout training.
• Large Loss Rejection (LL-R) [11]: Gradually reject large loss value samples
according to the training stages.
• Large Loss Correction (temporary) (LL-Ct) [11]: Utilize the loss calculated
from the modified labels but do not correct the actual labels.
Weakly Supervised Waste Classification with ALEC 357

Table 3. Classification performance (mAP) of different methods under the End-to-End and LinearInit. settings

Methods         End-to-End                   LinearInit.
                MW      MWS     TACOS        MW      MWS     TACOS
Full labels     69.340  84.776  39.777       69.294  85.043  39.018
Naive AN        54.715  75.486  34.091       54.875  75.969  33.702
ROLE            55.008  75.332  35.018       55.113  75.482  34.092
LL-R            55.930  75.979  35.699       55.639  76.376  35.363
LL-Ct           52.247  71.073  34.476       51.983  71.572  33.837
LL-Cp           51.075  68.917  31.389       51.414  68.802  31.487
LL-R+BoostLU    54.475  76.926  33.582       54.822  77.385  33.412
ALEC            56.822  77.474  36.022       56.931  78.004  35.389

• Large Loss Correction (permanent) (LL-Cp) [11]: Utilize the loss calculated
from the modified labels and correct the actual labels permanently.
• LL-R+BoostLU [12]: Employing the LL-R method and enhancing all areas
where CAMs are greater than 0.

The results are presented in Table 3; mAP is used to evaluate the classification performance. ALEC outperforms all baseline methods on all three datasets. Naive AN assumes all unobserved labels correspond to objects absent in the
images, introducing substantial noise into model training and resulting in the
worst classification performance among baseline methods. WAN assumes that a
large proportion of unobserved labels are indeed absent in the images, meaning
only a small amount are false negative labels. Consequently, a fixed weight is
added to all assumed negative labels. However, WAN fails to consider the spe-
cific situation of each sample and only corrects for an overall perspective. ROLE
combines EPR with online estimates of unobserved labels, leveraging the find-
ing that the model can fit correct labels more quickly. However, it requires a
complex optimization process and is computationally expensive. LL-Cp is more
aggressive and permanently changes the labels of samples with large loss values
in subsequent training iterations. However, LL-Cp performs poorly on datasets
with lower prediction performance, such as MW and TACOS. On such challenging datasets, it is easy to correct true labels to false labels simply based on
loss values. Similarly, BoostLU enhances all parts of the CAMs greater than 0.
However, when the classification accuracy of the model is relatively low, this
simple method may blindly enhance irrelevant regions. In comparison, ALEC
corrects only in the later stages of training and corrects false labels when the
model exhibits relatively high classification performance.

4.3 Generalization of ALEC

Fig. 3. Training curves (validation mAP) and test results of different methods combined with four backbones

In this section, we conduct experiments to demonstrate the generalization of our proposed method, ALEC, in weakly supervised multi-label classification tasks
for waste. We integrate ALEC with four classical deep learning classification
models: RegNet [26], ResNet50 [9], ShuffleNet [13] and EfficientNet [20].
We compare our method with LL-R and LL-Cp, which are recently proposed
methods. The results are presented in Fig. 3(a), Fig. 3(b), Fig. 3(c), and Fig. 3(d).
They show the training process and test results after combining our method with RegNet, ResNet50, ShuffleNet, and EfficientNet, respectively. The lines in Fig. 3
represent the trend of the mean Average Precision (mAP) on the validation set,
and the scatter plots represent the final experimental results on the test set.
We observe that although the classification performance of different methods
varies widely when using different models, ALEC achieves a higher mAP both
in the training phase and on the testing datasets. During the training phase,
the highest mAP value is achieved after ALEC starts to correct labels. However,
there is a slight dip in the mAP, such as around 30 epochs in Fig. 3(b) and 60
epochs in Fig. 3(c). After a short struggle phase, ALEC takes the initiative to
correct the false negative labels in the dataset. These labels already have large
loss values after multiple epochs of training, so correcting their labels is highly
effective.

4.4 Visualization

In this section, we use CAM to visualize the regions of interest in images for
models trained under different conditions. In CAMs, the red areas represent the
parts that the model focuses on, and the brighter the red, the greater the impact
of that part on the model’s predictions. Under the condition of full labeling, the
model focuses on areas that are more concentrated and have higher attention
than the other two methods. When using LL-R, the model occasionally treats
the background noise in the image as a feature of the object. For example, when
classifying the two images below, the background of the entire image is covered
in red, but the main body of the object is not noticed by the model. ALEC
Weakly Supervised Waste Classification with ALEC 359

Fig. 4. Visualization of prediction results

almost never experiences this situation. Compared to using LL-R, the model
using ALEC can focus on a more comprehensive appearance of objects, such as
when classifying cigarettes (Fig. 4).
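The visualizations above follow the standard CAM construction: each final-layer feature map is weighted by the classifier weight of the target class and the weighted maps are summed, then normalized so brighter values mark more influential regions. A minimal sketch in plain Python (the paper uses enhanced CAMs; this shows only the vanilla form, and all names are illustrative):

```python
def class_activation_map(feature_maps, class_weights):
    """Vanilla CAM sketch: weight each final-conv feature map by the target
    class's classifier weight, sum them, and normalize to [0, 1] so larger
    values indicate regions with more influence on the prediction."""
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[sum(wk * fm[i][j] for wk, fm in zip(class_weights, feature_maps))
            for j in range(w)] for i in range(h)]
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    scale = (hi - lo) or 1.0  # avoid division by zero on a flat map
    return [[(v - lo) / scale for v in row] for row in cam]
```

The normalized map is typically rendered as a heat overlay (red = high) on the input image.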

5 Conclusion
In this study, we design a new method to utilize unobserved labels in weakly-
supervised multi-label waste classification, called Adaptive Loss and Enhanced
CAMs (ALEC). Moreover, we integrate ALEC with deep learning classifica-
tion models to form a new framework called the Adaptive Weakly-supervised
Waste Classification Framework (AWWCF). Our experiments on three multi-
label waste datasets show that ALEC outperforms other weakly supervised clas-
sification methods. Although ALEC demonstrates relatively better classification
performance, there is still a noticeable gap when compared to the results under
the full label scenario. In the future, we will explore other potential factors
that may affect the model’s performance in weakly supervised multi-label waste
classification, including the impact of label distribution and object scale on the
model’s classification effectiveness.

Disclosure of Interests. The authors have no competing interests to declare that
are relevant to the content of this article.

References
1. Adedeji, O., Wang, Z.: Intelligent waste classification system using deep learning
convolutional neural network. Procedia Manuf. 35, 607–612 (2019). https://doi.
org/10.1016/j.promfg.2019.05.086
2. Ahmed, M.I.B., et al.: Deep learning approach to recyclable products classification:
towards sustainable waste management. Sustainability 15(14) (2023). https://doi.
org/10.3390/su151411138
3. Alrayes, F.S., et al.: Waste classification using vision transformer based on multi-
layer hybrid convolution neural network. Urban Clim. 49, 101483 (2023). https://
doi.org/10.1016/j.uclim.2023.101483
4. Arachie, C., Huang, B.: Constrained labeling for weakly supervised learning. In:
de Campos, C., Maathuis, M.H. (eds.) Proceedings of the Thirty-Seventh Con-
ference on Uncertainty in Artificial Intelligence. Proceedings of Machine Learning
Research, vol. 161, pp. 236–246. PMLR (2021)
360 W. Dai and L. Sun

5. Cole, E., Mac Aodha, O., Lorieul, T., Perona, P., Morris, D., Jojic, N.: Multi-label
learning from single positive labels. In: Proceedings of the IEEE/CVF Conference
on Computer Vision and Pattern Recognition (CVPR), pp. 933–942 (2021)
6. Cárdenas-León, I., Koeva, M., Nourian, P., Davey, C.: Urban digital twin-based
solution using geospatial information for solid waste management. Sustain. Urban
Areas 115, 105798 (2024). https://doi.org/10.1016/j.scs.2024.105798
7. Guo, H., Wang, S.: Long-tailed multi-label visual recognition by collaborative train-
ing on uniform and re-balanced samplings. In: Proceedings of the IEEE/CVF Con-
ference on Computer Vision and Pattern Recognition (CVPR), pp. 15089–15098
(2021)
8. Han, M., Wu, H., Chen, Z., Li, M., Zhang, X.: A survey of multi-label classification
based on supervised and semi-supervised learning. Int. J. Mach. Learn. Cybern.
14(3), 697–724 (2023). https://doi.org/10.1007/s13042-022-01658-9
9. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
(CVPR) (2016)
10. Huang, L., Li, M., Xu, T., Dong, S.Q.: A waste classification method based on a
capsule network. Environ. Sci. Pollut. Res. 30(36), 86454–86462 (2023). https://
doi.org/10.1007/s11356-023-27970-7
11. Kim, Y., Kim, J.M., Akata, Z., Lee, J.: Large loss matters in weakly supervised
multi-label classification. In: Proceedings of the IEEE/CVF Conference on Com-
puter Vision and Pattern Recognition (CVPR), pp. 14156–14165 (2022)
12. Kim, Y., Kim, J.M., Jeong, J., Schmid, C., Akata, Z., Lee, J.: Bridging the gap
between model explanations in partially annotated multi-label classification. In:
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recog-
nition (CVPR), pp. 3408–3417 (2023)
13. Ma, N., Zhang, X., Zheng, H.T., Sun, J.: ShuffleNet V2: practical guidelines for
efficient CNN architecture design. In: Proceedings of the European Conference on
Computer Vision (ECCV) (2018)
14. Mao, W.L., Chen, W.C., Wang, C.T., Lin, Y.H.: Recycling waste classification
using optimized convolutional neural network. Resour. Conserv. Recycl. 164,
105132 (2021). https://doi.org/10.1016/j.resconrec.2020.105132
15. Masand, A., Chauhan, S., Jangid, M., Kumar, R., Roy, S.: ScrapNet: an efficient
approach to trash classification. IEEE Access 9, 130947–130958 (2021). https://
doi.org/10.1109/ACCESS.2021.3111230
16. Pati, P., et al.: Weakly supervised joint whole-slide segmentation and classification
in prostate cancer. Med. Image Anal. 89, 102915 (2023). https://doi.org/10.1016/
j.media.2023.102915
17. Proença, P.F., Simões, P.: TACO: trash annotations in context for litter detection.
CoRR abs/2003.06975 (2020). arXiv: 2003.06975
18. Ridnik, T., et al.: Asymmetric loss for multi-label classification. In: Proceedings of
the IEEE/CVF International Conference on Computer Vision (ICCV), pp. 82–91
(2021)
19. Shennib, F., Schmitt, K.: Data-driven technologies and artificial intelligence in
circular economy and waste management systems: a review. In: 2021 IEEE Inter-
national Symposium on Technology and Society (ISTAS), pp. 1–5 (2021). https://
doi.org/10.1109/ISTAS52410.2021.9629183
20. Soni, U., Roy, A., Verma, A., Jain, V.: Forecasting municipal solid waste generation
using artificial intelligence models-a case study in India. SN Appl. Sci. 1(2), 162
(2019). https://doi.org/10.1007/s42452-018-0157-x

21. Sun, L., Dai, W., Muhammad, G.: Multi-level graph memory network cluster con-
volutional recurrent network for traffic forecasting. Inf. Fusion 105, 102214 (2024).
https://doi.org/10.1016/j.inffus.2023.102214
22. Sun, L., Li, C., Liu, B., Zhang, Y.: Class-driven graph attention network for multi-
label time series classification in mobile health digital twins. IEEE J. Sel. Areas
Commun. 41(10), 3267–3278 (2023). https://doi.org/10.1109/JSAC.2023.3310064
23. Sun, L., Li, Y., Zheng, M., Zhong, Z., Zhang, Y.: MCnet: multiscale visible image
and infrared image fusion network. Signal Process. 208, 108996 (2023). https://
doi.org/10.1016/j.sigpro.2023.108996
24. Verelst, T., Rubenstein, P.K., Eichner, M., Tuytelaars, T., Berman, M.: Spatial
consistency loss for training multi-label classifiers from single-label annotations. In:
Proceedings of the IEEE/CVF Winter Conference on Applications of Computer
Vision (WACV), pp. 3879–3889 (2023)
25. Xiao, J., Xu, J., Tian, C., Han, P., You, L., Zhang, S.: A serial attention frame for
multi-label waste bottle classification. Appl. Sci. 12(3) (2022). https://doi.org/10.
3390/app12031742
26. Xu, J., Pan, Y., Pan, X., Hoi, S., Yi, Z., Xu, Z.: RegNet: self-regulated network
for image classification. IEEE Trans. Neural Netw. Learn. Syst. 34(11), 9562–9567
(2023). https://doi.org/10.1109/TNNLS.2022.3158966
27. Yun, S., Oh, S.J., Heo, B., Han, D., Choe, J., Chun, S.: Re-labeling ImageNet:
from single to multi-labels, from global to localized labels. In: 2021 IEEE/CVF
Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2340–2350
(2021). https://doi.org/10.1109/CVPR46437.2021.00237
A Vehicle Asynchronous Communication
Scheme Based on Federated Deep
Reinforcement Learning

Jinming Huang1 , Jin Wang2 , Shengyang Gao3 , and Zilong Jin1(B)


1 School of Software, Nanjing University of Information Science and Technology, Nanjing 210044, China
[email protected], [email protected]
2 China Telecom Corporation Taian Branch, Taian 271000, China
3 Zhejiang Sci-Tech University, Hangzhou 310018, China
[email protected]

Abstract. With the rapid development of network technologies such as
artificial intelligence and the Internet of Things, the Internet of Vehicles,
as an emerging information and communication technology, is widely
used in the field of intelligent transportation systems. As more attention
is paid to data privacy, an increasing number of vehicles are reluctant
to share local data. To address this issue, federated learning has been
introduced to the Internet of Vehicles as an emerging learning paradigm,
allowing models to be trained without directly accessing raw local data.
However, frequent communication between vehicles and roadside unit
nodes can lead to a decrease in the accuracy of the global model in fed-
erated learning, as well as slower model convergence and higher commu-
nication overhead during training. Therefore, this paper proposes a node-
based dynamic communication scheme based on federated deep reinforce-
ment learning. A global model accuracy optimization objective is con-
structed based on real-time network conditions, vehicle communication
capabilities, and task requirements. To address the vehicle node selec-
tion problem, an adaptive algorithm is designed to select the optimal set
of vehicles for global model aggregation. Then, a dynamic asynchronous
aggregation strategy is used to improve the efficiency of model training.
By comparing with other baseline algorithms on standard datasets, the
proposed method achieves an average accuracy improvement of 3% com-
pared to traditional methods, while reducing communication latency and
decreasing energy consumption by an average of 37%.

Keywords: Federated learning · Deep reinforcement learning · Optimization of communication costs · Vehicle networking · Delay optimization

1 Introduction
The widespread application of the Internet of Vehicles (IoV) enables data shar-
ing between vehicles and edge networks, bringing technologies like autonomous
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 362–376, 2025.
https://doi.org/10.1007/978-981-96-4566-4_25
A Vehicle Asynchronous Communication Scheme Based on FDRL 363

driving and smart transportation closer to reality. These advancements rely on
artificial intelligence and the collection of vast amounts of data from vehicle sen-
sors. While this data is useful for training machine learning models, its collection
and sharing place additional strain on network bandwidth and increase overall
network load.
Federated learning reduces communication costs by training models locally,
avoiding the need to transmit raw data to servers. However, in diverse vehicular
network environments, random selection of vehicles for training via roadside
units (RSUs) can lead to longer training times and communication instability.
In high-density urban areas, limited bandwidth and resource competition among
users can make stable connections difficult. Synchronous aggregation further
exacerbates the issue by making training time dependent on the slowest vehicle,
reducing overall efficiency.
To address these challenges, this paper proposes a trade-off optimization
method that simultaneously optimizes communication costs, transmission delays,
and training accuracy, while considering the limited bandwidth of vehicular net-
works. The main contributions of this paper are summarized as follows:
(1) A vehicle user selection algorithm is proposed, which jointly considers vehi-
cle user information and the currently available communication resources.
Deep Reinforcement Learning (DRL) techniques are employed to select vehi-
cle nodes to participate in the federated learning task, aiming to optimize
communication costs during model training.
(2) An asynchronous aggregation method is designed, where the set of vehicles
participating in the global model aggregation is dynamically adjusted in each
round, thereby optimizing the efficiency of global model training.
(3) Through experiments on two benchmark datasets, simulation results demon-
strate that the proposed approach outperforms other baseline algorithms in
terms of both training accuracy and model training efficiency.

2 Related Work
Federated learning (FL) enables multiple devices or data sources to collabora-
tively train models without sharing raw data, ensuring privacy [1]. Recently, FL
has gained significant attention, with Li et al. [2] introducing its core compo-
nents, such as data configuration, models, privacy measures, and communication
frameworks. Niknam et al. [3] focused on improving training speed by leveraging
edge computing, data caching, and spectrum management in wireless networks,
highlighting the reduction of traffic burdens through edge caching.
A key challenge in FL is communication overhead, particularly in large-scale
networks [4]. To address this, research has shifted toward data compression tech-
niques that reduce energy consumption and communication costs while main-
taining training quality. Federated learning depends on iterative communication
between IoT devices and edge servers to train global models, where increased
interaction can speed up convergence but also raise communication costs. Solu-
tions to mitigate these costs include distributed training, model simplification,
364 J. Huang et al.

and importance-based data updates. Haddadpour et al. [5] introduced redundant
data to enhance model diversity, reducing communication rounds, while Liu et
al. [6] proposed an intermediate model aggregator to minimize data transmission.
However, as the number of clients or servers grows, convergence time increases.
Yao et al. [7] introduced feature fusion to reduce communication costs and accel-
erate aggregation.
In large-scale FL systems, high communication demands can degrade per-
formance. To improve efficiency, research has focused on client selection, using
either static or dynamic strategies. Static methods assume no client failures,
while dynamic approaches prioritize clients based on utility and fast training.
Wei et al. [8] and Song et al. [9] used Shapley values to assess data quality,
enhancing selection efficiency. Chai [10] proposed a hierarchical FL framework
(TiFL) to categorize participants by training performance for faster convergence.
Lai et al. [11] prioritized clients with the highest utility for model accuracy and
training speed.
In the Internet of Vehicles (IoV) domain, FL facilitates distributed data
sharing, allowing vehicles to train models without exposing raw data. Ye et
al. [12] applied FL to image classification in vehicle IoT, proposing a selective
aggregation method based on local image quality and vehicle capabilities. Chai
et al. [13] introduced an FL framework combining blockchain to ensure secure
transactions between vehicles (as clients) and RSUs (as servers).
FL also enhances data privacy protection in IoV by enabling model training
without exposing personal or driving data. Lu et al. [14] proposed a dual-phase
method involving data transformation and leakage detection, securely converting
raw data into useful models. Pokhrel et al. [15] integrated FL with blockchain for
decentralized traffic management, ensuring privacy and reward validation. Sun
et al. [16] introduced privacy-enhancing techniques such as data subsampling
and model shuffling to balance communication, privacy, and accuracy.
Beyond privacy, FL supports vehicle-to-vehicle network resource manage-
ment. Wu et al. [17] proposed a parameter freezing algorithm to reduce trans-
mission and overhead. Samarakoon et al. [18] developed an FL approach for
ultra-reliable low-latency communication (URLLC) in vehicles, maintaining pri-
vacy while optimizing network control and resource allocation. Cao et al. [19]
introduced a mobile edge computing framework to optimize caching and com-
putation costs. Parekh et al. [20] proposed the GeFL model, which uses gradient
encryption to protect privacy in autonomous driving without adding computa-
tional costs.
While FL helps prevent data leakage in IoV, frequent communication with
RSUs increases costs, highlighting the need to balance privacy protection with
communication efficiency.

3 System Model
In this section, the system model of this paper is presented. As shown in Fig. 1,
the architecture comprises RSUs and vehicle users. Within the RSU’s cover-
age area, vehicle users first train their local models using their respective data,

then upload the updated model parameters to the RSU for aggregation. The
RSU aggregates these model parameters to create a global model, which is then
distributed to the vehicle users involved in the federated learning process. This
iterative process continues until the model reaches optimal training performance.

Fig. 1. System Model.

In the system model, model gradient exchanges between RSUs and vehicles
occur through wireless communication, which unavoidably causes delays. Com-
pared to the upload time, the download time is negligible because RSUs have
significantly higher downlink power than vehicles’ uplink power and allocate
greater downlink bandwidth for data distribution. Furthermore, since RSUs pos-
sess high computational capabilities, the delay caused by the aggregation phase
can also be ignored. The delay model is divided into two parts: upload delay and
model training delay.

(1) Upload Delay: Assuming that vehicles are covered by a base station, Orthog-
onal Frequency Division Multiple Access (OFDMA) is adopted as the
method for uploading local gradients. Considering limited communication
resources, it is assumed that each RSU distributes N subcarriers among W
vehicles (N < W ) for local parameter transmission, ensuring each vehicle
receives at most one subcarrier. Thus, the uplink transmission rate of vehicle
i is:  
up qi,k ui,k b−γ
i,k
νi.k = si,k B log2 l + , (1)
σ02
where B refers to the uplink channel bandwidth, qi,k denotes the transmis-
sion power of vehicle i in the k-th round, ui,k indicates the channel gain
between vehicle i and the RSU, σ02 is the noise level, bi,k represents the
366 J. Huang et al.

N
distance to the RSU, and γ is the path loss coefficient. sk,t = n=1 ai,n ,
where ai,n is the number of subcarriers allocated. If subcarrier n is assigned
to vehicle i, then ai,n = 1; otherwise, ai,n = 0.
Assume the model size is M . Then, the communication delay for vehicle i
at the k-th round is:
t_{i,k}^{\mathrm{up}} = \frac{M}{\nu_{i,k}^{\mathrm{up}}}. \quad (2)
(2) Model training delay: In each global iteration, vehicle i performs local
training on a dataset of size D_i, and the computation delay is defined as:

t_{i,k}^{\mathrm{com}} = \frac{E D_i c}{f_{i,k}}, \quad (3)
where E denotes the total number of local iteration rounds, c represents the
CPU cycles required to process one bit of data, and fi,k signifies the CPU
frequency of vehicle i during the k-th training round.
(3) The total delay calculation: The delay is designed as the sum of the com-
putation time of the vehicle client i and the model upload time. Thus, the
total delay in the k th round is:
t_{i,k}^{\mathrm{sum}} = t_{i,k}^{\mathrm{com}} + t_{i,k}^{\mathrm{up}}. \quad (4)
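The delay model in Eqs. (1)–(4) translates directly into a few helper functions. A sketch in Python (function names and the numbers used below are illustrative, not the paper's parameter values):

```python
import math

def uplink_rate(s_ik, B, q_ik, u_ik, b_ik, gamma, sigma0_sq):
    """Eq. (1): OFDMA uplink rate (bits/s); s_ik is the subcarrier indicator sum."""
    snr = q_ik * u_ik * b_ik ** (-gamma) / sigma0_sq
    return s_ik * B * math.log2(1 + snr)

def upload_delay(M, rate):
    """Eq. (2): time to upload an M-bit local model."""
    return M / rate

def compute_delay(E, D_i, c, f_ik):
    """Eq. (3): local training time for E epochs, D_i samples, c cycles/bit."""
    return E * D_i * c / f_ik

def round_delay(M, rate, E, D_i, c, f_ik):
    """Eq. (4): total per-round delay = computation + upload."""
    return compute_delay(E, D_i, c, f_ik) + upload_delay(M, rate)
```

With unit bandwidth and noise, tripling the received power term turns `log2(1 + snr)` from 1 into 2 bits/s per Hz, halving the upload delay for a fixed model size.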

3.1 Resource Consumption


To perform federated learning training, vehicles consume a significant amount
of energy during both the computation and communication phases. This
includes the consumption of local computation resources and uplink transmission
resources.
(1) Local computation resource consumption: During training round k, the
power consumed by vehicle i for local computation is given by:
g_{i,k}^{\mathrm{com}} = \mu (f_{i,k})^3, \quad (5)
where μ signifies the effective switching capacitance of the processing chip.
Based on Eq.(3), the energy required for local model generation is computed
as:
e_{i,k}^{\mathrm{com}} = g_{i,k}^{\mathrm{com}} t_{i,k}^{\mathrm{com}} = \mu E D_i c (f_{i,k})^2. \quad (6)
(2) Uplink transmission resource consumption: In training round k, the uplink
resource consumption power of vehicle i is defined as:
e_{i,k}^{\mathrm{up}} = q_{i,k} t_{i,k}^{\mathrm{up}} = \frac{q_{i,k} M}{s_{i,k} B \log_2\!\left(1 + \frac{q_{i,k} u_{i,k} b_{i,k}^{-\gamma}}{\sigma_0^2}\right)}. \quad (7)

Therefore, the total energy consumption of vehicle i during the k-th round of
training is:

e_{i,k}^{\mathrm{sum}} = e_{i,k}^{\mathrm{com}} + e_{i,k}^{\mathrm{up}}. \quad (8)
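The energy model in Eqs. (5)–(8) mirrors the delay model; a matching sketch (illustrative names, with the upload time passed in directly):

```python
def compute_power(mu, f_ik):
    """Eq. (5): local computation power, mu * f^3."""
    return mu * f_ik ** 3

def compute_energy(mu, E, D_i, c, f_ik):
    """Eq. (6): power * training time collapses to mu * E * D_i * c * f^2."""
    return mu * E * D_i * c * f_ik ** 2

def upload_energy(q_ik, t_up):
    """Eq. (7): transmit power times upload time."""
    return q_ik * t_up

def round_energy(mu, E, D_i, c, f_ik, q_ik, t_up):
    """Eq. (8): total per-round energy of vehicle i."""
    return compute_energy(mu, E, D_i, c, f_ik) + upload_energy(q_ik, t_up)
```

Note the trade-off Eqs. (3) and (6) encode: raising the CPU frequency f shortens training linearly but raises computation energy quadratically.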

3.2 Federated Learning Model

The federated learning process occurs between the RSU and the vehicles within
its coverage. Each vehicle in the set i = {1, 2, ..., n} has its local dataset Di =
{D1 , D2 , ..., Dn } and corresponding local model parameters θ = {θ1 , θ2 , ..., θn }.
For each vehicle i, the dataset Di = {xi , yi } consists of input data xi and the
associated output yi .
The federated learning process consists of several stages. In each round, there
are three key steps: First, vehicles train their local models and upload their model
parameters to the RSU. Next, the RSU aggregates the local models to generate
the global model parameters, which are then sent back to the vehicles. Finally,
based on the updated global model parameters, the vehicles initialize their local
training parameters for the next round.
During local training, the local loss function for the dataset Di of vehicle i
can be expressed as the loss calculated over the data sample set at vehicle i:
F_i(\theta) = \frac{1}{|D_i|} \sum_{(x_i, y_i) \in D_i} f(\theta_i; x_i, y_i). \quad (9)

Based on the above loss function, the local model parameters can be updated
and adjusted using standard gradient descent methods. After completing all the
local training, the local model parameters for the r-th round are obtained as
\{\theta_1(r), \theta_2(r), \ldots, \theta_n(r)\}. The aggregated global model parameters can be defined as:

\theta^*(r) = \frac{1}{\sum_{i=1}^{n} |D_i|} \sum_{i=1}^{n} |D_i|\, \theta_i(r). \quad (10)

The global loss function is defined as:


F(\theta) = \sum_{i=1}^{n} \varepsilon_i F_i(\theta), \quad (11)

where \varepsilon_i = |D_i| / |D|.
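Eqs. (10)–(11) are the familiar FedAvg-style size-weighted averages. A minimal sketch, with each model flattened to a list of parameters for illustration:

```python
def aggregate(local_params, sizes):
    """Eq. (10): global parameters as a |D_i|-weighted average of local models."""
    total = sum(sizes)
    dim = len(local_params[0])
    return [sum(s * p[j] for s, p in zip(sizes, local_params)) / total
            for j in range(dim)]

def global_loss(local_losses, sizes):
    """Eq. (11): global loss with weights eps_i = |D_i| / |D|."""
    total = sum(sizes)
    return sum(s / total * l for s, l in zip(sizes, local_losses))
```

A client holding three times as much data thus pulls the global model three times as strongly toward its local optimum.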

4 Optimization Scheme

Building on the above system model, the objective is to minimize the global loss
function of federated learning while maintaining constraints on delay and energy
consumption within the vehicular network, as follows:

P1: \min F(\theta)

s.t. C1: \max_i \left( t_{i,k}^{\mathrm{com}} + t_{i,k}^{\mathrm{up}} \right) \le t_{\max}, \quad (12)

     C2: \sum_{i=1}^{n} \left( e_{i,k}^{\mathrm{com}} + e_{i,k}^{\mathrm{up}} \right) \le \kappa.

Constraint C1 defines the delay limitation for the vehicle client, while con-
straint C2 sets the energy consumption limit, with κ representing the total
energy budget. Given the mobility of vehicles and the dynamic characteristics of
wireless channels, uncertainties may arise in the performance, delay, and energy
consumption of federated learning. For instance, if a vehicle moves out of the
RSU’s coverage area or encounters poor wireless conditions, it may fail to com-
plete its assigned task within the stipulated timeframe. Additionally, since the
loss function does not have a closed-form expression, problem P1 cannot be
directly solved. To address this, a client selection algorithm based on DRL is
proposed, utilizing its capability to determine optimal strategies in complex
environments. Furthermore, an asynchronous transmission method is employed
to upload local model updates to the RSU, effectively reducing client waiting
time and minimizing transmission delay.

Fig. 2. DRL Architecture Diagram.

4.1 Client Selection Scheme


Vehicle client selection is formulated as a deep reinforcement learning task. In
each federated learning round, the problem is represented as a Markov decision
process (MDP), with the state s encompassing model weights, communication
time, computational resources, and the local training loss for each device. The
process is illustrated in Fig. 2.
Upon observing the state st from the environment, the agent evaluates the
value of all potential actions in that state. Based on the policy π, the agent selects
the action at with the maximum value and takes it, resulting in the reward rt .
Our goal is to select the best client to participate in the training and minimize
the training loss. The following is a detailed introduction to the state, action,
and reward in MDP:
(1) State
For vehicle clients in the vehicular network, the state function at time t con-
sists of the total delay Lt , computation resources Et , model weights Wt , and

the difference in loss values between the t−1 and tth iterations, ΔHt . There-
fore, the state function at time t can be described as st = (Lt , Et , Wt , ΔHt ).
The state space can be represented as S = {si |i = 1, 2, ...}, where si is the
potential state at the ith time.
(2) Action
The action at =jt ∈ {1, 2, ..., n} is a positive integer representing the number
of clients selected. At the tth iteration, the DRL agent picks an action based
on the current state to carry out the global model update.
(3) Policy
The policy π maps the state space to the action space, i.e., at = π(st ).
To represent the probability distribution of actions, a neural network is
employed to model π(a|s).
(4) Reward
Following the execution of action at , the agent receives a reward rt , which
depends on both the loss function value and the resource consumption at
time t, and is given by:
r_t = -\xi^{\frac{E_t - \bar{E}_{t-1}}{\bar{E}_{t-1}}} + \frac{\Delta H_t}{F}, \quad t \in \{1, 2, \ldots, T\}, \quad (13)

where ξ is a constant that ensures r_t increases exponentially as the energy
consumption E_t decreases. \bar{E}_t denotes the moving average of the energy
consumption, satisfying \bar{E}_t = \tau E_t + (1 - \tau)\bar{E}_{t-1}, where τ ∈ (0, 1). At
time t, the agent is rewarded more when energy consumption is lower and
the loss function value is closer to convergence.
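The reward expression in Eq. (13) is partly illegible in the source; under the reading the surrounding text supports (an energy term that decays exponentially, assuming ξ > 1, plus a normalized loss-improvement term), a sketch looks like this, with all names and default values illustrative:

```python
def reward(E_t, E_bar_prev, delta_H, F, xi=2.0, tau=0.5):
    """Reward sketch for the DRL agent: the energy term penalizes consuming
    more than the running average (exponentially), while delta_H / F rewards
    progress of the loss toward convergence. Returns the reward and the
    updated moving-average energy."""
    r = -(xi ** ((E_t - E_bar_prev) / E_bar_prev)) + delta_H / F
    E_bar = tau * E_t + (1 - tau) * E_bar_prev  # moving-average energy update
    return r, E_bar
```

Spending less energy than the running average makes the exponent negative, shrinking the penalty term and raising the reward.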

Fig. 3. Comparison of Dynamic Asynchronous Federated Learning, Synchronous Federated Learning, and Asynchronous Federated Learning.

4.2 Asynchronous Aggregation Scheme


When deploying federated learning models in vehicular networks, it is crucial to
enhance model accuracy efficiently while maintaining control over energy con-
sumption. The varying computational resources among different vehicle clients

must be considered. Vehicle clients with higher computational capacity and lower
energy consumption should be prioritized for frequent participation in global
updates. Furthermore, the inherent instability in vehicular networks, such as
road congestion and device disconnections, necessitates dynamic adjustments to
the RSU’s global update strategy based on real-time network conditions.
To illustrate the asynchronous transmission process more clearly, Fig. 3 com-
pares dynamic asynchronous federated learning (left subfigure), synchronous fed-
erated learning (middle subfigure), and standard asynchronous federated learn-
ing (right subfigure).
In the IoV scenario with four vehicle clients, federated learning strategies
vary in their approach to global updates. In synchronous FL, the RSU waits
for all four vehicle clients to submit their updated models before initiating a
global update, ensuring equal contributions but potentially causing delays. On
the other hand, asynchronous FL enables the RSU to perform a global update as
soon as it receives updated parameters from any client that completes its local
training, minimizing waiting time. Dynamic asynchronous FL further enhances
adaptability by allowing the RSU to adjust the number of participating clients
dynamically for each iteration round; at time t, the RSU performs a global
update when it receives updated models from jt clients, where jt changes with
the iteration to accommodate the varying conditions of vehicular networks. As
illustrated in Fig. 3, at time t1 , the RSU conducts a global update with con-
tributions from two clients, while at time t2 , the update involves three clients,
demonstrating the flexibility of this approach in addressing the dynamic nature
of the IoV environment.
Firstly, each vehicle client i updates its model parameters locally during the
training phase r ∈ {1, 2, ...}:

θi (r) = θi (r − 1) − δ∇Fi (θi (r − 1)), ∀i ∈ N (14)

where m_{t,i} represents the local update index of vehicle client i at time t, and z_{t,i} ∈
{0, 1} is a binary variable that indicates whether vehicle client i participates in
the global update at time t, subject to the constraint \sum_i z_{t,i} \ge n_t, ∀i ∈ N.
The RSU then broadcasts the updated model to all vehicle clients, and these
clients update their local models following the equation θi (mt,i + 1) = θ (t), for
i ∈ N. The global loss function value at time t, denoted F(θ(t)), is given by:

F(\theta(t)) = \frac{\sum_{i \in N} z_{t,i} |D_i| F_i(\theta(t))}{\sum_{i \in N} z_{t,i} |D_i|}. \quad (15)
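Eq. (15) averages only over the clients whose updates arrived in the current round, i.e. those with z_{t,i} = 1. A sketch (illustrative names):

```python
def async_global_loss(z, sizes, losses):
    """Eq. (15): global loss over participating clients only (z_i = 1),
    weighted by their dataset sizes."""
    num = sum(zi * d * l for zi, d, l in zip(z, sizes, losses))
    den = sum(zi * d for zi, d in zip(z, sizes))
    return num / den
```

Non-participants contribute nothing to either the numerator or the denominator, so the weighting among arrivals still follows Eq. (11).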

Combining the client selection method based on deep reinforcement learn-
ing and the asynchronous aggregation scheme, a federated deep reinforcement
learning-based vehicle asynchronous communication scheme is proposed. The
specific algorithm is shown in Algorithm 1.

Algorithm 1: DRL-Based Vehicle Asynchronous Communication Optimization Scheme
Input: Communication rounds c, Local training rounds e, learning rate η,
maximum communication time T , state of the deep reinforcement
learning agent, etc.
Output: Number of clients jt participating in asynchronous
communication and the global model.
1 Initialize communication and local training rounds, learning rate,
maximum time, and DRL agent state;
2 for each epoch do
// At the RSU side
3 for each iteration do
4 Perform action at , initially randomly select a group of vehicle
clients to participate in the current training round;
5 Perform global aggregation using equation (15), calculate global
loss, resource consumption, and total delay;
6 Broadcast the global model to all vehicle clients;
7 Receive reward rt ;
8 Update state st+1 based on rt ;
9 end
// At the vehicle client side
10 for each iteration do
11 Receive the global updated model from RSU;
12 Perform local model update based on equation (14);
13 Calculate local loss based on equation (9);
14 end
15 Upload the updated model to RSU;
16 end
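The RSU-side loop of Algorithm 1 can be compressed into a runnable sketch. This is not the paper's implementation: client arrival order is simulated with a random draw, the local step follows Eq. (14), and the aggregation weights by dataset size as in Eq. (15); the choice of j_t by the DRL agent is taken as given.

```python
import random

def dynamic_async_round(clients, j_t, lr=0.1):
    """One dynamic-asynchronous round: aggregate as soon as the first j_t
    client updates arrive. `clients` is a list of (params, size, grad_fn)
    tuples; all names are illustrative placeholders."""
    arrived = random.sample(clients, j_t)  # simulate the first j_t arrivals
    updated = []
    for params, size, grad_fn in arrived:
        # Local update, Eq. (14): theta <- theta - lr * grad F_i(theta)
        new_p = [p - lr * g for p, g in zip(params, grad_fn(params))]
        updated.append((new_p, size))
    total = sum(s for _, s in updated)
    dim = len(updated[0][0])
    # Size-weighted aggregation over the j_t participants
    return [sum(s * p[j] for p, s in updated) / total for j in range(dim)]
```

Varying j_t between rounds is exactly what distinguishes the dynamic scheme from plain asynchronous FL (j_t = 1) and synchronous FL (j_t = n).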

5 Experiment and Result Analysis


5.1 Experiment Environment and Parameter Settings
This section validates the performance of the proposed DRL-based node asyn-
chronous communication optimization scheme through simulation experiments.
Table 1 presents the experimental parameters.

5.2 Datasets and Baseline Settings


(1) MNIST Dataset × CNN Model: The CNN model employed for the MNIST
dataset consists of two 5×5 convolutional layers. The first layer outputs 4
channels, and the second layer outputs 8 channels, both followed by a 2×2
max pooling layer. The MNIST dataset includes 60,000 training samples and
10,000 test samples of handwritten digits from 10 different classes.
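Assuming "valid" convolutions (no padding, stride 1) on 28×28 MNIST inputs — a padding assumption of ours, not stated in the paper — the feature-map sizes of the described CNN can be traced with simple shape arithmetic:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (or pooling) layer."""
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(28, 5)            # 24x24, 4 channels after the first 5x5 conv
s = conv_out(s, 2, stride=2)   # 12x12 after 2x2 max pooling
s = conv_out(s, 5)             # 8x8, 8 channels after the second 5x5 conv
s = conv_out(s, 2, stride=2)   # 4x4 after 2x2 max pooling
features = 8 * s * s           # 128 inputs to the classifier head
```

Under these assumptions, the classifier head receives an 8 × 4 × 4 = 128-dimensional feature vector.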

Table 1. Experimental Parameter Settings

Parameter Value
Vehicle User 30
RSU 1
Noise Power −100 dBm
Uplink Link Bandwidth 20 MHz
Subcarriers to be Allocated 20
Maximum Uplink Link Power of Vehicle 17 dBm
Local Computing Capability 3 GHz
Effective Switching Capacitance of Local Computation 10^{−29}

(2) FashionMNIST Dataset × LeNet-5 Model: The LeNet-5 model, which con-
sists of two convolutional layers and three fully connected layers, is applied
to train the FashionMNIST dataset. FashionMNIST, provided by Zalando,
contains 10 classes, 60,000 training samples, and 10,000 test samples of fash-
ion article images.
To evaluate the superiority of the proposed scheme, comparisons are made
with the following three baselines and the proposed algorithm:
– FedAvg [21]: A synchronous federated learning method that randomly selects
a fixed number of clients to participate in training.
– FedCS [22]: A synchronous federated learning method employing a greedy
client scheduling strategy to select as many clients as possible for each round
of federated learning training.
– FedAsync [23]: An asynchronous federated learning method allowing clients
to perform local model training and send updates to the server at different
times without waiting for other clients to complete their training.

Fig. 4. DRL Architecture Diagram.


A Vehicle Asynchronous Communication Scheme Based on FDRL 373

5.3 Performance Analysis

In this method, a DRL agent is trained across multiple datasets with 100 devices.
The DDQN model features two two-layer MLP networks, each with 512 hidden
units. The input size is 10,100, corresponding to 101 model weight vectors (the global model $w$ and the 100 local models $\theta(k)$), each reduced to a 100-dimensional representation. The second layer outputs 100 values, processed by a softmax
layer to determine the probability of selecting each device. The DDQN model is
fast, requiring only a few seconds per training iteration.
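A minimal sketch of the stated network dimensions (10,100 inputs, 512 hidden units, 100 softmax outputs) is given below. The class name, weight initialization, and ReLU activation are assumptions not stated in the text, and a complete DDQN would additionally require a replay buffer, target-network synchronization, and a temporal-difference loss.

```python
import numpy as np

def softmax(z):
    z = z - z.max()            # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

class QNetwork:
    """Two-layer MLP: 10100 inputs -> 512 hidden units -> 100 device scores."""
    def __init__(self, in_dim=10100, hidden=512, out_dim=100, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0, 0.01, (in_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(0, 0.01, (hidden, out_dim))
        self.b2 = np.zeros(out_dim)

    def forward(self, x):
        h = np.maximum(0.0, x @ self.W1 + self.b1)   # ReLU hidden layer
        return softmax(h @ self.W2 + self.b2)        # device selection probabilities

# DDQN maintains two such networks: an online net and a target net.
online, target = QNetwork(), QNetwork(seed=1)
state = np.zeros(10100)   # 101 weight vectors, each reduced to 100 dims
probs = online.forward(state)
print(probs.shape)
```

The softmax output can then be interpreted as the probability of selecting each of the 100 devices.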
Figure 4 illustrates the training process of the DRL agent across three differ-
ent datasets. For each learning task, it can be observed that the reward varies
significantly during the initial 100 rounds. As training progresses, the rewards
converge to a stable and high value. For instance, MNIST takes approximately
100 rounds to converge to the maximum reward, while FMNIST takes about 80
rounds. Before training, a target accuracy is preset: 98% for the MNIST dataset and 85% for the FashionMNIST dataset.

Fig. 5. Training results of the MNIST dataset.

To evaluate the performance of the proposed method on the MNIST and FashionMNIST datasets, Figs. 5 and 6 display the loss and accuracy values,
respectively. As shown in the figures, all methods eventually converge to a thresh-
old as the number of training rounds increases. For the MNIST dataset, the
proposed method begins to converge after approximately 30 rounds. Compared
to the baseline methods, the proposed method achieves accuracy improvements of
1.8%, 2.6%, 3%, and 3.4%, while reducing the loss by 0.023, 0.035, 0.048, and
0.062, respectively.
For the FashionMNIST dataset, the proposed method starts to converge after
approximately 58 training rounds. This is because handwritten digit images in
the MNIST dataset are generally simpler and more straightforward for the model
to learn due to their clear and consistent representation. In contrast, the Fashion-
MNIST dataset, which includes items with varying textures and shapes, requires
a more detailed process of feature extraction and representation learning. Compared with the baseline methods, the proposed method improves accuracy by 1.3%, 2.2%, 3.7%, and 4%, respectively, and reduces the loss by 0.036, 0.044, 0.059,
and 0.063, respectively. These results indicate that the proposed approach not
only accelerates model convergence but also reduces training loss and improves
training accuracy.

Fig. 6. Training results of FashionMNIST dataset.

Fig. 7. Energy consumption and latency comparison of different algorithms.

Figure 7 compares energy consumption and latency during model training for different algorithms. The proposed method demonstrates significantly lower
energy consumption and shorter total training latency compared to FedAvg,
FedCS, and FedAsync. For example, on the MNIST dataset, the proposed
approach requires approximately 3.3 energy units after convergence, which is
27%, 37%, and 46% less than FedAsync, FedCS, and FedAvg, respectively. This
improvement is due to the DRL-based vehicle selection and resource allocation,
which minimizes energy consumption while maintaining training quality.

6 Conclusion
This paper proposes a DRL-based dynamic node communication scheme to
address high communication costs due to frequent vehicle communications. It
first analyzes the problem by considering real-time network conditions, vehi-
cle capabilities, and task requirements, establishing a joint optimization objec-
tive for federated learning accuracy, resource consumption, and communication
latency. A DRL-based node selection algorithm is then designed, modeling the
selection process as a Markov decision process to identify the optimal node com-
bination for training. Finally, a dynamic asynchronous aggregation strategy is
introduced, enabling model updates from different vehicles to be received at
different times. Experimental results on standard datasets show that the pro-
posed method significantly improves training accuracy and reduces communica-
tion latency compared to baseline algorithms.

Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).

References
1. Elbir, A.M., Soner, B., Çöleri, S., Gündüz, D., Bennis, M.: Federated learning
in vehicular networks. In: 2022 IEEE International Mediterranean Conference on
Communications and Networking (MeditCom), pp. 72–77. IEEE (2022)
2. Li, Q., et al.: A survey on federated learning systems: vision, hype and reality
for data privacy and protection. IEEE Trans. Knowl. Data Eng. 35(4), 3347–3366
(2021)
3. Niknam, S., Dhillon, H.S., Reed, J.H.: Federated learning for wireless communi-
cations: motivation, opportunities, and challenges. IEEE Commun. Mag. 58(6),
46–51 (2020)
4. Liu, S., Yu, J., Deng, X., Wan, S.: FedCPF: an efficient-communication federated
learning approach for vehicular edge computing in 6G communication networks.
IEEE Trans. Intell. Transp. Syst. 23(2), 1616–1629 (2021)
5. Haddadpour, F., Kamani, M.M., Mahdavi, M., Cadambe, V.: Trading redundancy
for communication: speeding up distributed SGD for non-convex optimization. In:
International Conference on Machine Learning. PMLR, pp. 2545–2554 (2019)
6. Liu, L., Zhang, J., Song, S., Letaief, K.B.: Edge-assisted hierarchical federated
learning with non-IID data. arXiv preprint arXiv:1905.06641 (2019)
7. Yao, X., Huang, T., Wu, C., Zhang, R., Sun, L.: Towards faster and better federated
learning: a feature fusion approach. In: 2019 IEEE International Conference on
Image Processing (ICIP), pp. 175–179. IEEE (2019)
8. Wei, S., Tong, Y., Zhou, Z., Song, T.: Efficient and fair data valuation for horizontal
federated learning. In: Federated Learning: Privacy and Incentive, pp. 139–152
(2020)
9. Song, T., Tong, Y., Wei, S.: Profit allocation for federated learning. In: 2019 IEEE
International Conference on Big Data (Big Data), pp. 2577–2586. IEEE (2019)
10. Chai, Z., et al.: TiFL: a tier-based federated learning system. In: Proceedings of
the 29th International Symposium on High-Performance Parallel and Distributed
Computing, pp. 125–136 (2020)

11. Lai, F., Zhu, X., Madhyastha, H.V., Chowdhury, M.: Oort: efficient federated learn-
ing via guided participant selection. In: 15th USENIX Symposium on Operating
Systems Design and Implementation (OSDI 21), pp. 19–35 (2021)
12. Ye, D., Yu, R., Pan, M., Han, Z.: Federated learning in vehicular edge computing:
a selective model aggregation approach. IEEE Access 8, 23920–23935 (2020)
13. Chai, H., Leng, S., Chen, Y., Zhang, K.: A hierarchical blockchain-enabled feder-
ated learning algorithm for knowledge sharing in internet of vehicles. IEEE Trans.
Intell. Transp. Syst. 22(7), 3975–3986 (2020)
14. Lu, Y., Huang, X., Dai, Y., Maharjan, S., Zhang, Y.: Federated learning for data
privacy preservation in vehicular cyber-physical systems. IEEE Network 34(3),
50–56 (2020)
15. Pokhrel, S.R., Choi, J.: Federated learning with blockchain for autonomous vehi-
cles: analysis and design challenges. IEEE Trans. Commun. 68(8), 4734–4746
(2020)
16. Sun, K., et al.: Joint top-K sparsification and shuffle model for communication-
privacy-accuracy tradeoffs in federated learning-based IoV. IEEE Internet Things
J. (2024)
17. Wu, J., et al.: FedAPT: joint adaptive parameter freezing and resource allocation
for communication-efficient federated vehicular networks. IEEE Internet Things J.
(2024)
18. Samarakoon, S., Bennis, M., Saad, W., Debbah, M.: Federated learning for ultra-
reliable low-latency V2V communications. In: 2018 IEEE Global Communications
Conference (GLOBECOM), pp. 1–7. IEEE (2018)
19. Cao, J., Zhang, K., Wu, F., Leng, S.: Learning cooperation schemes for mobile
edge computing empowered internet of vehicles. In: 2020 IEEE Wireless Commu-
nications and Networking Conference (WCNC), pp. 1–6. IEEE (2020)
20. Parekh, R., et al.: GeFL: gradient encryption-aided privacy preserved federated
learning for autonomous vehicles. IEEE Access 11, 1825–1839 (2023)
21. McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.:
Communication-efficient learning of deep networks from decentralized data. In:
Artificial Intelligence and Statistics, pp. 1273–1282. PMLR (2017)
22. Liu, Y., Chang, S., Liu, Y.: FedCS: communication-efficient federated learning with
compressive sensing. In: 2022 IEEE 28th International Conference on Parallel and
Distributed Systems (ICPADS), pp. 17–24. IEEE (2023)
23. Xie, C., Koyejo, S., Gupta, I.: Asynchronous federated optimization. arXiv preprint
arXiv:1903.03934 (2019)
A Vehicles Scheduling Algorithm Based
on Clustering Based Federated Learning

Xin Zhang1 , Chi Zhang2 , Shuyan Liu3 , and Zilong Jin1(B)


1 Nanjing University of Information Science and Technology, Nanjing 210044, China
[email protected]
2 Nanjing Panda Handa Technology Co., Ltd., Nanjing 210001, China
3 Zhejiang Sci-Tech University, Hangzhou 310014, China
[email protected]

Abstract. In the federated learning enabled vehicular networks, network resources are relatively limited. If all vehicles participate in model
training, it will consume a large amount of network bandwidth, leading
to network congestion and affecting the efficiency of data transmission.
In this paper, a cluster-based federated learning approach for optimal
scheduling of vehicular network users is proposed. Vehicles are clustered
based on inference similarity to better learn information about the sur-
rounding environment. From the perspective of minimizing the commu-
nication delay, the proposed method considers four key factors: data
importance, communication rounds, channel quality, and single-round
communication delay, to determine the optimal probability of vehicle
selection. Considering that the variability of the transmission success
rate of different devices can lead to the failure of timely uploading of
valid local models, a weight of the success rate is set during RSU aggre-
gation to reduce the impact of model bias. The experimental results show
that the algorithm accuracy has been improved by 6% compared to tra-
ditional methods, and the delay has been reduced by 30% compared to
the optimal comparison method.

Keywords: Vehicle networking · Clustering · Federated learning · User scheduling · Transmission success rate

1 Introduction

With the rapid development of fifth-generation communication technology and the widespread deployment of the Internet of Things (IoT), vast amounts of data
are being collected, analyzed, and utilized with artificial intelligence (AI) for
various applications. In vehicular networks, AI enables real-time traffic updates,
intelligent driving assistance, and remote diagnostics [1]. However, data and com-
putational resources are often distributed across devices. Traditional centralized
machine learning methods require data upload to a central server, leading to

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 377–391, 2025.
https://doi.org/10.1007/978-981-96-4566-4_26
378 X. Zhang et al.

high communication overhead and privacy concerns [2]. Federated learning, a dis-
tributed machine learning algorithm, helps address these challenges by enabling
decentralized model training [3].
Model training depends on numerous parameters, and limited network
resources become a bottleneck for data transmission and real-time information
exchange [4]. This impacts federated learning’s efficiency, particularly in large
model transmission [5]. User scheduling within vehicular networks optimizes
resource allocation and reduces congestion [6]. Effective scheduling strategies
[7] can reduce transmission latency, ensuring timely information. To improve
model convergence and minimize update delays, this paper proposes a federated
learning-based vehicular user scheduling scheme. Focusing on minimizing com-
munication time, it integrates four key factors: data importance, communication
rounds, channel state, and communication delay per round, to determine the
optimal vehicle selection probability. The specific contributions of this paper are
as follows:

1) A reasoning similarity-based clustering method is proposed to address accuracy loss from data heterogeneity in vehicular networks. Cluster heads are
selected based on vehicle resources and mobility, matching vehicles with the
most suitable models.
2) To optimize communication resource consumption, the paper balances com-
munication rounds and delay by assigning optimized selection probabilities to
vehicles. A subset of vehicles is selected to transmit local models, minimizing
redundant transmission while maintaining accuracy.
3) Due to varying transmission success rates across devices, which can delay
effective model uploads, a model aggregation method based on transmission
success rates is introduced. This ensures the global model accurately reflects
local models, improving overall performance and contribution accuracy.

2 Related Work
2.1 Federated Learning in Vehicular Networks

Autonomous driving requires integrating advanced sensor technologies, real-time decision-making systems, high-precision maps, and the ability to handle complex
driving scenarios. Applying federated learning (FL) in vehicular networks can
address privacy and security concerns in vehicle data transmission. Kim et al. [8]
applied an adaptive clipping method to enhance local differential privacy in FL
models, showing that privacy enhancement has minimal impact on performance.
Feng et al. [9] incorporated a Laplace mechanism into the Federated Averaging
algorithm to maintain privacy while ensuring good network performance.
Since vehicular network data is often non-IID and unlabeled, standard FL
methods may not meet real-world requirements. Existing research has explored
semi-supervised FL, using unlabeled data for training. Li et al. [10] proposed
a semi-supervised FL method for vehicular networks, utilizing adaptive weight
hyperparameters to improve performance with homogeneous data. Li et al. [11]

developed a federated end-to-end learning framework for collaborative learning tasks in VANETs, addressing non-IID data in multi-hop cluster VANETs with
a distributed optimization problem. Zhou et al. [12] proposed a two-layer FL
architecture for 6G vehicular networks, using roadside units (RSUs) and global
cloud context to improve learning accuracy while preserving privacy.
In response to the computational demands of vehicular networks, Muhammad
et al. [13] introduced an incentive mechanism to encourage vehicles to participate
in federated learning. This problem is NP-Hard, and their solution used a bid
determination algorithm for sellers and a greedy algorithm for winner selection
in auction games, achieving the lowest delay under certain conditions. Bao et al.
[14] proposed a federated client selection scheme for vehicular IoT, integrating
edge computing and fuzzy logic to select vehicles as federated learning clients
based on factors like vehicle speed, distribution, and wireless link quality. While
much of the existing work has focused on privacy and computational challenges,
real-time delays in data collection and decision-making, crucial to autonomous
driving safety, have not been adequately addressed [15].

2.2 User Scheduling


User scheduling strategies in federated learning for vehicular networks typically
address resource constraints, communication costs, and heterogeneous environ-
ments. Liu et al. [19] proposed a federated learning algorithm designed for
resource-constrained edge networks, focusing on the limitations of vehicular
clients. Nishio et al. [20] studied vehicular client selection to maximize global
model update efficiency by considering device computing and communication
capabilities. Wang et al. [21] optimized resource usage by dynamically adjusting
device participation and update frequencies. Dinh et al. [22] addressed resource
allocation issues in federated learning, optimizing bandwidth and energy dis-
tribution. Yang et al. [23] proposed a dynamic scheduling scheme that adjusts
device participation based on real-time conditions like battery level and CPU
usage.
In federated learning for vehicular networks, communication cost remains
the bottleneck. Participating nodes must communicate with a central server over
multiple rounds to achieve model accuracy, exacerbated by unreliable device net-
works and imbalanced upload/download speeds. Reisizadeh et al. [25] proposed
a method where the central server selects only active clients, reducing communi-
cation rounds and data transmission by sending compressed local models. Yang
et al. [26] used deep reinforcement learning to optimize federated learning in
mobility-sensitive networks. Diao et al. [27] proposed dynamic client selection
based on device computing power and bandwidth to improve participation effi-
ciency. Nishio et al. [28] introduced a federated learning protocol with delay
constraints to manage resource-constrained terminals and prevent asynchronous
delays.
Despite these advances, many studies overlook the impact of data impor-
tance and channel quality on training accuracy and communication efficiency.
Additionally, deep learning-based optimization algorithms often fail to balance
computational resources, communication costs, and accuracy in practical scenarios. This paper aims to minimize communication time by optimizing four key
factors: data importance, communication rounds, channel status, and communi-
cation delay per round, to enhance training accuracy while reducing communi-
cation overhead in vehicular networks.

3 System Model

The federated learning user scheduling environment considered in this paper is shown in Fig. 1, which includes $K$ vehicles and multiple Road Side Units (RSUs). Let $V = \{\nu_1, \nu_2, \ldots, \nu_K\}$ denote the set of $K$ vehicles, and let $VC = \{VC_1, VC_2, \ldots, VC_{N_{VC}}\}$ denote the set of $N_{VC}$ vehicle clusters. Each vehicle cluster $VC_i$ consists of the cluster head $CH_i$ and its cluster members $\{\nu_{i,0}, \nu_{i,1}, \ldots, \nu_{i,\omega_i}\}$. Under the coordination of roadside units, multiple vehicles participating in model transmission train a global model. Each RSU provides services to vehicles within its communication range, meaning each vehicle is associated with its nearest RSU.

Fig. 1. System Architecture Diagram

Each vehicle has computational resources to convert raw data into parameter
models. RSUs cluster vehicles based on data similarity, with vehicles in the same
cluster sharing a global model. Vehicles with more computational resources and
stable movement are prioritized as cluster heads to reduce cluster instability and
communication overhead caused by frequent head changes.
To learn model features from distributed vehicular data, the RSU solves the
following distributed optimization problem:


\min_{w} F(w) = \sum_{k=1}^{K} q_k F_k(w),    (1)

where $w$ represents the learning model parameters, $q_k$ is the weight of the $k$-th device, with $q_k > 0$ and the constraint $\sum_{k=1}^{K} q_k = 1$. $F_k(w)$ denotes the local loss function at device $k$, which is expressed as:

F_k(w) = \frac{1}{n_k} \sum_{x \in D_k} f(w, x),    (2)

Here, the local dataset $D_k$ on device $k$ is non-independently and identically distributed (non-IID), and $n_k = |D_k|$ indicates the number of samples in $D_k$. Consequently, the weight $q_k$ in Eq. (1) can be set to $q_k = \frac{n_k}{n}$, where $n = \sum_{k=1}^{K} n_k$, and $f(w, x)$ represents the loss function for data sample $x$. Since the information is distributed across multiple vehicular devices, the RSU cannot directly solve Eq. (1). Therefore, the RSU and the vehicles collaborate through an iterative algorithm to learn the optimal model parameters $w^* = \arg\min_w F(w)$.
algorithm to learn the optimal model parameters w∗ = arg min F (w).

3.1 Clustering Method


The RSU performs inference on each client model, resulting in an $M \times N$ matrix $B_k = F_k(D_{server}; w_t^k)$, for $k = 1, \ldots, |S_t|$, where $M$ and $N$ are the number of auxiliary samples on the RSU and the number of neurons in the last fully connected layer, respectively. The RSU constructs an adjacency matrix $A_{i,j}^t$ using the inference results $B_k$ from each client. With the adjacency matrix $A_{i,j}^t$, a hard threshold operator $\Gamma$ is defined and applied as $\tilde{A}_{i,j} = \Gamma(A_{i,j}) = \mathrm{Sign}(A_{i,j} - \beta)$, where $\beta$ represents the threshold. The positive index values in each row of $\tilde{A}_{i,j}$ are then grouped into the same cluster.
All vehicles within a cluster send their computational resources Ck and mobil-
ity stability Mk , and other related data, to the nearby Road Side Unit (RSU).
The RSU calculates the objective function based on the weighted combination
of these two factors.

E(C_k, M_k) = \delta \cdot C_k + (1 - \delta) \cdot M_k,    (3)

where $\delta$ is a parameter that balances computational resources and mobility stability, with $0 \le \delta \le 1$. By adjusting the value of $\delta$, the best balance between
computational resources and mobility can be achieved.
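The clustering and cluster-head steps above can be sketched as follows. This is a hedged illustration: the similarity measure is assumed to be cosine similarity between inference matrices (consistent with Eq. (15) later in the paper), and all helper names and numeric values are invented for the example.

```python
import numpy as np

def cosine_sim(Bi, Bj):
    """A_{i,j}: cosine similarity of two inference matrices (Frobenius inner product)."""
    return float((Bi * Bj).sum() / (np.linalg.norm(Bi) * np.linalg.norm(Bj)))

def threshold_clusters(B, beta):
    """Hard-threshold the similarity matrix and group positive row entries."""
    n = len(B)
    A = np.array([[cosine_sim(B[i], B[j]) for j in range(n)] for i in range(n)])
    A_tilde = np.sign(A - beta)                 # Gamma operator: Sign(A - beta)
    clusters = []
    for row in A_tilde:
        members = frozenset(np.flatnonzero(row > 0))
        if members and members not in clusters:
            clusters.append(members)
    return clusters

def pick_head(cluster, C, M, delta=0.5):
    """Cluster head maximizes E(C_k, M_k) = delta*C_k + (1-delta)*M_k (Eq. 3)."""
    return max(cluster, key=lambda k: delta * C[k] + (1 - delta) * M[k])

rng = np.random.default_rng(0)
B = [rng.normal(size=(4, 3)) for _ in range(5)]      # per-vehicle inference outputs
B[1] = B[0] + 0.01 * rng.normal(size=(4, 3))         # vehicle 1 mimics vehicle 0
clusters = threshold_clusters(B, beta=0.9)
C = {k: rng.uniform(0.5, 1.5) for k in range(5)}     # compute resources (GHz)
M = {k: rng.uniform(0.0, 1.0) for k in range(5)}     # mobility stability
heads = [pick_head(c, C, M) for c in clusters]
print(clusters, heads)
```

With these toy inputs, the two near-identical vehicles fall into the same cluster, and each cluster's head is chosen by the weighted score.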

3.2 User Scheduling Strategy


The user’s probability scheduling aims to minimize the communication time in
each round of training, represented as:
T^{(t)} = \sum_{i=1}^{N_t} \left( T_B^{(i)} + T_U^{(i)} \right),    (4)

where $T^{(t)}$ represents the communication time required after $t$ rounds of updates, $N_t$ is the number of communication rounds still needed, and $T_B^{(i)}$ and $T_U^{(i)}$ represent the time to send the global model from the RSU to each cluster head and the upload time of the local model in round $i$, respectively; the sum of these two is the communication time for round $i$. The computation time is not considered in the scheduling strategy, as it can be neglected compared to $T_B^{(i)}$ and $T_U^{(i)}$. The minimization problem can be transformed into:
P1: \min_{p_1^{(t)}, \ldots, p_K^{(t)}} \; T_U^{(t)} + N_{t+1}^{R} T_U^{R},
\quad \text{s.t. } C1: \sum_{k=1}^{K} p_k^{(t)} = 1,    (5)
where $N_{t+1}^{R}$ represents the number of communication rounds still needed after $t$ rounds of training iterations, and $T_U^{R}$ represents the communication time for each future round.
Transforming P1 into a convex optimization problem, it can be solved using the KKT (Karush-Kuhn-Tucker) conditions. Assuming:

T_U^{(t)} = f\big(p_k^{(t)}, q_d, BR_k^{(t)}\big),    (6)

and taking the derivative:

\frac{\partial T_U^{(t)}}{\partial p_k^{(t)}} = -\frac{\lambda}{\varphi_t \frac{n_k}{n} \|w_k\|} \sqrt{\frac{q_d}{BR_k^{(t)}}} + \lambda^*,    (7)

Setting the above equation to 0 gives:

-\frac{\lambda}{\varphi_t \frac{n_k}{n} \|w_k\|} \sqrt{\frac{q_d}{BR_k^{(t)}}} + \lambda^* = 0.    (8)

Combining the constraint conditions and the expression for $p_k^{(t)}$ gives:

\sum_{k=1}^{K} \frac{\varphi_t \frac{n_k}{n} \|w_k\|}{\sqrt{q_d / BR_k^{(t)}} + \lambda^*} = 1.    (9)

Solving this equation yields the Lagrange multiplier $\lambda^*$. Through the above KKT conditions, the optimal solution can be obtained:

p_k^{(t)*} = \frac{\varphi_t \frac{n_k}{n} \|w_k\|}{\sqrt{q_d / BR_k^{(t)}} + \lambda^*},    (10)

where $\varphi_t$ is a scalar factor that adjusts the proportion between the parameter importance $\frac{n_k}{n} \|w_k\|$ and the transmission-rate term $\frac{1}{\sqrt{q_d / BR_k^{(t)}} + \lambda^*}$. These three quantities determine the optimal probability $p_k^{(t)*}$, which decreases continuously as $t$ increases. As training progresses and the number of communication rounds $t$ grows, the scheduling probability distribution becomes more inclined towards the transmission rate of the vehicle devices.

3.3 Aggregation Update


After being scheduled, vehicle device k receives the global model parameters from
the RSU at time $t$ and initializes its local model to $w_t^k = w_t$, where $w_t$ is the global model parameter sent by the RSU and $w_t^k$ is the local model parameter
at vehicle k. The device then performs E rounds of stochastic gradient descent
steps, updating its local model at each step. The updated parameter model after
each round can be represented as:

w_{t+r}^{k} = \begin{cases} w_t, & \text{if } r = 0 \\ v_{t+r}^{k}, & \text{if } r = 1, 2, \ldots, E-1. \end{cases}    (11)
After the $E$ rounds of stochastic gradient descent, the scheduled vehicle sends $v_{t+E}^{k}$ to the cluster head. Upon collecting the local models, the cluster head performs weighted aggregation and sends the result to the RSU. The success probability of vehicle $k$ being scheduled can be represented by $U_k$, defined as:

U_k = \mathbb{E}\left[ \mathbb{1}(\mathrm{SINR}_k > \theta) \mid k \in S_t \right].    (12)
The RSU updates the global model by assigning weights to vehicles, com-
pensating for system heterogeneity and varying communication capabilities. It
is assumed that downlink transmission is always successful, allowing cluster
head vehicles to receive the global model from the RSU. This assumption holds
because the RSU has higher power and more resources than vehicle devices. How-
ever, due to the power limitations of vehicles and finite communication resources,
uplink transmission during aggregation is not guaranteed to be fully successful.
To increase the success probability during the aggregation step, it is assumed that all scheduled vehicles attempt to transmit their local model parameters up to $l$ times, and the RSU uses the highest Signal-to-Interference-plus-Noise Ratio (SINR) to recover the local model parameters transmitted by the cluster head. This can be expressed as:

U_k = \mathbb{E}\left[ \mathbb{1}\left( \max\{\mathrm{SINR}_k(1), \ldots, \mathrm{SINR}_k(l)\} > \theta \right) \mid k \in S_t \right],    (13)
where $\mathrm{SINR}_k(r)$ represents the SINR received from vehicle $k$ during the $r$-th transmission attempt in the aggregation step. The SINR can be defined as:

\mathrm{SINR}_k(r) = \frac{h_k(r) g_k^{-\alpha}}{I(r) + \sigma^2} = \frac{h_k(r) g_k^{-\alpha}}{\sum_{x \in \Phi_I} h_x(r) x^{-\alpha} + \sigma^2},    (14)

where $\sigma^2$ is the normalized noise power (noise power divided by the device transmission power), $I(r) = \sum_{x \in \Phi_I} h_x(r) x^{-\alpha}$ is the interference during the $r$-th attempt, $\Phi_I$ represents the set of interferers, and $h_x(r)$ represents the small-scale fading gain between the vehicle at position $x$ and the RSU at the origin. Assuming Rayleigh fading, $h_x \sim \exp(1)$. $\alpha$ represents the path loss exponent, and since the set of interferers is the same across different transmission attempts in an aggregation step, the success rate during the aggregation period is time-dependent.
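The success probability in Eqs. (12)-(14) can be estimated by Monte Carlo simulation. In this sketch the link gain, interferer distances, and noise power are invented for illustration; fading is redrawn for each attempt while the interferer set stays fixed, matching the assumption in the text.

```python
import numpy as np

def success_prob(g_k, interferer_dists, alpha=4.0, sigma2=1e-8, theta_db=-15.0,
                 attempts=1, trials=200_000, seed=0):
    """Monte Carlo estimate of U_k = P(max over attempts of SINR_k > theta),
    with Rayleigh fading h ~ Exp(1) on the target link and each interferer."""
    rng = np.random.default_rng(seed)
    theta = 10 ** (theta_db / 10)                    # dB -> linear
    d = np.asarray(interferer_dists, float)          # fixed interferer positions
    sinr = np.empty((trials, attempts))
    for r in range(attempts):
        h = rng.exponential(1.0, trials)             # target-link fading, per attempt
        hI = rng.exponential(1.0, (trials, d.size))  # interferer fading, per attempt
        I = (hI * d ** -alpha).sum(axis=1)           # aggregate interference I(r)
        sinr[:, r] = h * g_k ** -alpha / (I + sigma2)
    return float((sinr.max(axis=1) > theta).mean())

u1 = success_prob(g_k=200.0, interferer_dists=[120.0, 180.0], attempts=1)
u3 = success_prob(g_k=200.0, interferer_dists=[120.0, 180.0], attempts=3)
print(u1, u3)   # retransmission attempts raise the success probability
```

Taking the maximum SINR over several attempts, as in Eq. (13), can only increase the estimated $U_k$.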

4 Algorithm Steps
Combining the clustering scheme based on inference similarity and the user
scheduling scheme that minimizes communication delay, a federated learning-
based vehicle user scheduling research scheme is proposed. Algorithm 1 provides
the pseudocode for the scheme.

Algorithm 1. Federated Learning-Based Vehicle User Scheduling Scheme
Require: Number of vehicles N, sampling rate R ∈ (0, 1], clustering threshold β, cluster head selection parameter δ
Ensure: $\{w_{t+E}^{k}\}_{k \in [K]}$
1: Initialize: $\{w_t^k\}_{k \in S_t}$, select a group of n vehicles $S_t$
2: for t = 0, 1, 2, . . . , T − 1 (communication rounds) do
3:   for all vehicles k do
4:     Obtain the current local model $w_t^k$
5:     for r = 0, 1, . . . , E (local iterations) do
6:       Update the local model $w_{t+r+1}^{k} \leftarrow w_{t+r}^{k}$
7:     end for
8:     Send the updated $w_{t+E}^{k}$ to the cluster head
9:   end for
10:  RSU does:
11:    Receive $w_{t+E}^{k}$ from each cluster head, construct the adjacency matrix
12:    Use the hard threshold operator to determine similarity for clustering
13:    $E(C_k, M_k) = \delta \cdot C_k + (1 - \delta) \cdot M_k$    ▷ Select cluster head
14:    return clusters $\{VC_j^{t+1}\}_{j=1}^{T_{t+1}}$ and cluster heads $\{CH_j\}_{j=1}^{N_{VC}}$
15:  Scheduler does:
16:    $p_k^{(t)*} = \frac{\varphi_t \frac{n_k}{n} \|w_k\|}{\sqrt{q_d / BR_k^{(t)}} + \lambda^*}$    ▷ Select users for scheduling within the cluster
17:    $w_{t+E} = w_t + \sum_{k=1}^{N} \frac{1}{U_k} \mathbb{1}(k \in S_t, \mathrm{SINR}_k > \theta)(v_{t+E}^{k} - w_t)$    ▷ RSU updates the global model
18: end for

In the first round, the RSU samples the vehicles and broadcasts the global
initial model parameters. Vehicles use the received model to train with their
own data and send it back to the RSU, which then clusters the vehicles. A
neighborhood matrix is constructed from the models obtained from the vehicle
clients:

A_{i,j}^{t} = \frac{\langle B_i^t, B_j^t \rangle_F}{\|B_i^t\|_F \, \|B_j^t\|_F}, \quad i, j = 1, \ldots, n,    (15)
where $A_{i,j}^{t}$ represents the similarity between vehicles $i$ and $j$ based on their model parameters. A threshold operator $\Gamma$ is defined and applied as:

\tilde{A}_{i,j}^{t} = \Gamma(A_{i,j}^{t}) = \mathrm{Sign}(A_{i,j}^{t} - \beta).    (16)



The RSU places each row's positive values of $\tilde{A}$ into the same cluster, forming clusters $\{VC_j^{t+1}\}_{j=1}^{T_{t+1}}$ of vehicles with similar data. Concurrently, the RSU selects a cluster head based on the computational resources and mobility stability of the vehicles within each cluster and sends these results back to the vehicles. The scheduler calculates the optimal probability $p_k^{(t)}$ for each vehicle to be selected, and vehicles are chosen for model training based on this probability distribution.
In the t-th iteration, the RSU samples data from the cluster heads and
updates the global model as follows:


w_{t+E} = w_t + \sum_{k=1}^{N} \frac{1}{U_k} \mathbb{1}(k \in S_t, \mathrm{SINR}_k > \theta) \left( v_{t+E}^{k} - w_t \right),    (17)

where $\mathrm{SINR}_k$ represents the Signal-to-Interference-plus-Noise Ratio received from vehicle $k$, and $\theta$ is the SINR threshold. After aggregation, the RSU distributes
the updated model parameters wt+E to the cluster heads, which then send them
to the vehicles in their clusters. Each vehicle then performs E steps of stochastic
gradient descent to update, and the updated parameters are sent to the RSU.
The local model update formula for each vehicle is as follows:
v_{t+i+1}^{k} = w_{t+i}^{k} - \eta_{t+i} \nabla F_k\big( w_{t+i}^{k}; \xi_{t+i}^{k} \big), \quad i = 0, \ldots, E-1,    (18)

k
where η(i+1) is the learning rate, and ξi+1 is a uniformly selected sample from
vehicle k’s local dataset. After updating the model parameters, vehicles send
them back to the cluster heads, which process them and send them to the RSU.
The RSU uses the updated parameters to form dynamic vehicle clusters with
similar data and calculates the new global model parameters for each cluster.
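The aggregation rule in Eq. (17) can be sketched directly. The inverse-probability weighting $1/U_k$ is what compensates for transmission failures: in expectation, $\mathbb{E}[\mathbb{1}(\text{received})/U_k] = 1$ for each scheduled vehicle. Variable names and the toy vectors below are illustrative.

```python
import numpy as np

def aggregate(w_t, local_models, received, U, S_t):
    """RSU global update (Eq. 17):
    w_{t+E} = w_t + sum_k (1/U_k) * 1(received_k) * (v_k - w_t)."""
    w = w_t.copy()
    for k in S_t:
        if received[k]:                        # upload decoded (SINR_k > theta)
            w += (local_models[k] - w_t) / U[k]
    return w

w_t = np.zeros(4)                              # current global model
local_models = {0: np.ones(4), 1: 2 * np.ones(4)}
U = {0: 0.5, 1: 1.0}                           # per-vehicle success probabilities
received = {0: True, 1: True}
w_next = aggregate(w_t, local_models, received, U, S_t=[0, 1])
print(w_next)
```

When vehicle 0's upload fails, its term simply drops out; over many rounds the $1/U_k$ factor keeps the aggregated update unbiased despite such losses.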

5 Experiment and Result Analysis


5.1 Experimental Setup

This paper utilizes the MNIST dataset [29], distributed in a non-independently and identically distributed (non-IID) manner [30]. Specifically, each vehicle is assigned samples of two digit classes, with varying quantities across different devices. The experiment also includes the vehicle network-related GTSRB
dataset [31], which comprises 39,209 images for training and 12,630 for testing,
covering 43 common traffic sign categories found on German roads, including
speed limits, no-entry signs, give-way signs, and various other regulatory, warn-
ing, and informational signs commonly encountered.
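The non-IID partition described above (each vehicle holding samples from only two digit classes, with uneven sample counts) can be sketched as follows; the helper name and the random split sizes are assumptions for illustration.

```python
import numpy as np

def non_iid_partition(labels, num_vehicles, classes_per_vehicle=2, seed=0):
    """Give each vehicle samples from only `classes_per_vehicle` digit classes,
    with a random (hence varying) number of samples per chosen class."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    by_class = {int(c): np.flatnonzero(labels == c) for c in np.unique(labels)}
    parts = []
    for _ in range(num_vehicles):
        chosen = rng.choice(sorted(by_class), size=classes_per_vehicle, replace=False)
        idx = []
        for c in chosen:
            pool = by_class[int(c)]
            take = int(rng.integers(1, len(pool) // 2 + 1))   # uneven sizes
            idx.extend(rng.choice(pool, size=take, replace=False))
        parts.append(np.array(idx))
    return parts

labels = np.repeat(np.arange(10), 100)   # toy stand-in for MNIST label vector
parts = non_iid_partition(labels, num_vehicles=5)
print([len(part) for part in parts])     # uneven local dataset sizes
```

Each resulting index set can then serve as one vehicle's local dataset $D_k$ in the simulation.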

Table 1. Simulation Experiment Parameters

Parameter Value
Transmission Bandwidth 5 MHz
Vehicle Computing Resources 0.5–1.5 GHz
Vehicle Transmission Power 1.3 W
Batch Size 64
Path Loss 0.0001
Number of Vehicles 100
Path Loss Exponent 4
SINR Threshold −15 dB

To validate the performance of the proposed algorithm on these datasets, a three-layer neural network is employed, with each hidden layer consisting of 300
hidden units. The specific experimental parameters are outlined in Table 1.
The proposed scheme in this paper is compared with four other methods:

1. FedAvg [32]: The central server performs a weighted average of the locally
trained model parameters to achieve an update of the global model.
2. Fed-Cluster: Incorporating an inference similarity-based clustering method
on top of FedAvg.
3. Importance Aware (IA): The probability of vehicle selection is solely
related to the importance of its data.
4. Importance and Channel Aware (ICA): Under the federated learning
framework, the server selects devices for scheduling considering both data
importance and communication channel status.

5.2 Experimental Results and Analysis

Figure 2 illustrates a comparison of the maximum uplink communication delay among vehicles during training. It can be observed that the scheme proposed in this paper outperforms the four baseline methods, offering lower uplink
delay and less fluctuation, with an average delay reduction of 50% compared
to FedAvg. This is attributed to the scheduling of vehicles within clusters in
each round, preventing all users from competing for communication resources,
thus avoiding waste. Since only cluster head vehicles transmit models in the Fed-
Cluster method, it conserves communication resources and alleviates congestion,
making it the total delay the lowest among the compared methods. Compared
to IA and ICA algorithms, the proposed algorithm considers data importance
and channel quality, integrating communication rounds and per-round delay to
achieve optimal training rounds and timing, resulting in relatively lower latency.
A Vehicles Scheduling Algorithm Based on Clustering Based FL 387

Fig. 2. Uplink Communication Delay

Fig. 3. Performance of the Algorithm on the MNIST Dataset

Figures 3 and 4 demonstrate the learning performance of the proposed scheme


compared to the four baseline methods on the MNIST and GTSRB datasets,
respectively. In the initial phase, all algorithms improve model accuracy at a
similar pace, with the proposed algorithm ultimately achieving the best precision.
While there is no significant difference in final accuracy between ICA and
FedAvg, ICA converges much faster, reaching a convergent state around 100 s,
whereas FedAvg shows convergence trends only after about 150 s; this is because
ICA considers data importance and channel status in user scheduling.

Fig. 4. Performance of the Algorithm on the GTSRB Dataset

Compared to IA and ICA, the method proposed in this paper excels
in convergence speed, showing signs of convergence around 60 s, as it takes into
account both communication time and training rounds. This indicates that, com-
pared to other algorithms, the scheduling strategy of this paper can more rapidly
aggregate the local models of distributed vehicles into a global model under lim-
ited communication resources, thus accelerating the overall federated learning
process.

6 Conclusion
This paper addresses the issue of limited communication resources by designing
a vehicle user scheduling scheme based on the federated learning framework. To
tackle the accuracy decline caused by heterogeneous data in vehicular networks,
an inference similarity-based clustering method is proposed, selecting appro-
priate cluster heads based on vehicles’ computational resources and mobility
stability. Before each round of parameter transmission, the scheme links data
importance with communication rounds, channel quality, and single-round delay
to solve for the optimal selection probability of vehicles within clusters, conserv-
ing communication resources. Considering the variability in device transmission
success rates that could lead to shifts in global model characteristics, the success
probability is incorporated at the RSU aggregation stage to enhance the effec-
tiveness of federated learning. Finally, experiments on real datasets verify the
superior performance of this algorithm in terms of accuracy, convergence speed,
and delay.

Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).

References
1. Han, S., Xiao, F., Cheng, W.: A review on the application of deep reinforcement
learning in automatic driving systems. J. Xihua Univ. (Nat. Sci. Ed.) 42(4), 25–31
(2023). https://doi.org/10.12198/j.issn.1673-159X.4740
2. Zijia, M., Zhipeng, G., Yang, Y., et al.: An efficient distributed model sharing strat-
egy for data privacy protection in Telematics. J. Commun. 43(4), 83–94 (2022).
https://doi.org/10.11959/j.issn.1000-436x.2022074
3. Guanglei, G., Bo, G., Ke, X., et al.: An overview of federated learning-enabled 6G
networks. J. Internet Things 7(2), 50–66 (2023). https://doi.org/10.11959/j.issn.
2096-3750.2023.00323
4. Sun, B., Liu, Y., Wang, T., et al.: A review of federated learning efficiency opti-
mization in mobile edge networks. Comput. Res. Dev. 65(7), 1439–1469 (2022).
https://doi.org/10.7544/issn1000-1239.20210119
5. Jiarui, W., Guoping, T., Siyuan, Z.: Split-cluster wireless federated learning algo-
rithm for high-speed vehicular networking scenarios. Comput. Appl. 41(6), 1546
(2021). https://doi.org/10.11772/j.issn.1001-9081.2020121912
6. Lin, F., Li, H., Luo, C., et al.: A joint computational offloading and resource
allocation algorithm for dependent tasks in vehicular networking. J. Chongqing
Univ. Posts Telecommun. (Nat. Sci. Ed.) 35(5) (2023). https://doi.org/10.3979/j.
issn.1673-825X.202210270296
7. Chengyu, Z., Yiting, Y., Hongbin, L., et al.: A review of optimal resource allocation
schemes for 5G Telematics. Telecommun. Sci. 39(7), 124–138 (2023). https://doi.
org/10.11959/j.issn.1000-0801.2023139
8. Kim, E.J., Lee, E.K.: Performance impact of differential privacy on federated
learning in vehicular networks. In: NOMS 2022-2022 IEEE/IFIP Network Oper-
ations and Management Symposium, pp. 1–5. IEEE (2022). https://doi.org/10.
1109/NOMS54207.2022.9789814
9. Huang, J., Xu, C., Ji, Z., et al.: AFLPC: an asynchronous federated learning
privacy-preserving computing model applied to 5G-V2X. Secur. Commun. Netw.
(2022). https://doi.org/10.1155/2022/9334943
10. Su, X., Huo, Y., Wang, X., et al.: An enhancing semi-supervised federated learning
framework for internet of vehicles. In: 2023 IEEE 98th Vehicular Technology Con-
ference (VTC2023-Fall), pp. 1–5. IEEE (2023). https://doi.org/10.1109/VTC2023-
fall60731.2023.10333466
11. Li, B., Jiang, Y., Pei, Q., et al.: FEEL: federated end-to-end learning with non-
IID data for vehicular ad hoc networks. IEEE Trans. Intell. Transp. Syst. 23(9),
16728–16740 (2022). https://doi.org/10.1109/TITS.2022.3190294
12. Zhou, X., Liang, W., She, J., et al.: Two-layer federated learning with heterogeneous
model aggregation for 6G supported internet of vehicles. IEEE Trans. Veh. Technol.
70(6), 5308–5317 (2021). https://doi.org/10.1109/TVT.2021.3077893
13. Bute, M.S., Fan, P., Luo, Q.: Incentive based federated learning data dissemination
for vehicular edge computing networks. In: 2023 IEEE 98th Vehicular Technol-
ogy Conference (VTC2023-Fall), pp. 1–5. IEEE (2023). https://doi.org/10.1109/
VTC2023-Fall60731.2023.10333479

14. Bao, W., Wu, C., Guleng, S., et al.: Edge computing-based joint client selection
and networking scheme for federated learning in vehicular IoT. China Commun.
18(6), 39–52 (2021). https://doi.org/10.23919/JCC.2021.06.004
15. Lim, W., Luong, N.C., Hoang, D.T., et al.: Federated learning in mobile edge
networks: a comprehensive survey. IEEE Commun. Surv. Tutor. 22(3), 2031–2063
(2020). https://doi.org/10.1109/COMST.2020.2986024
16. Li, Q., Wen, Z., Wu, Z., et al.: A survey on federated learning systems: vision,
hype and reality for data privacy and protection. IEEE Trans. Knowl. Data Eng.
(2021). https://doi.org/10.1109/TKDE.2021.3124599
17. Sai, Z., Tianrui, L., Wei, H.: A federated learning algorithm for communication
cost optimization. Comput. Appl. 43(1), 1 (2023). https://doi.org/10.11772/j.issn.
1001-9081.2021122054
18. Zhang, H., Zhang, Y., Cao, C.: A federated learning approach to solve the data het-
erogeneity problem. Appl. Res. Comput./Jisuanji Yingyong Yanjiu 41(3) (2024).
https://doi.org/10.19734/j.issn.1001-3695.2023.07.0296
19. Liu, J., Xu, H., Xu, Y., et al.: Communication-efficient asynchronous federated
learning in resource-constrained edge computing. Comput. Netw. 199, 108429
(2021). https://doi.org/10.1016/j.comnet.2021.108429
20. Nishio, T., Yonetani, R.: Client selection for federated learning with heterogeneous
resources in mobile edge. In: 2019 IEEE International Conference on Communica-
tions (ICC), pp. 1–7. IEEE (2019). https://doi.org/10.1109/ICC.2019.8761315
21. Wang, S., Tuor, T., Salonidis, T., et al.: Adaptive federated learning in resource
constrained edge computing systems. IEEE J. Sel. Areas Commun. 37(6), 1205–
1221 (2019). https://doi.org/10.1109/JSAC.2019.2904348
22. Dinh, C.T., Tran, N.H., Nguyen, M., et al.: Federated learning over wireless net-
works: convergence analysis and resource allocation. IEEE/ACM Trans. Netw.
29(1), 398–409 (2020). https://doi.org/10.1109/TNET.2020.3035770
23. Yang, H.H., Liu, Z., Quek, T., et al.: Scheduling policies for federated learning in
wireless networks. IEEE Trans. Commun. 68(1), 317–333 (2019). https://doi.org/
10.1109/TCOMM.2019.2944169
24. Liu, L., Zhang, J., Song, S.H., et al.: Client-edge-cloud hierarchical federated learn-
ing. In: 2020 IEEE International Conference on Communications (ICC), pp. 1–6.
IEEE (2020). https://doi.org/10.1109/ICC40277.2020.9148862
25. Reisizadeh, A., Mokhtari, A., Hassani, H., et al.: FedPAQ: a communication-
efficient federated learning method with periodic averaging and quantization. In:
International Conference on Artificial Intelligence and Statistics, PMLR, pp. 2021–
2031 (2020). https://doi.org/10.48550/arXiv.1909.13014
26. Yang, K., Jiang, T., Shi, Y., et al.: Federated learning via over-the-air computation.
IEEE Trans. Wireless Commun. 19(3), 2022–2035 (2020). https://doi.org/10.1109/
TWC.2019.2961673
27. Diao, E., Ding, J., Tarokh, V.: HeteroFL: computation and communication effi-
cient federated learning for heterogeneous clients. arXiv preprint arXiv:2010.01264
(2020). https://doi.org/10.48550/arXiv.2010.01264
28. Posner, J., Tseng, L., Aloqaily, M., et al.: Federated learning in vehicular networks:
opportunities and solutions. IEEE Netw. 35(2), 152–159 (2021). https://doi.org/
10.1109/MNET.011.2000430
29. LeCun, Y., Bottou, L., Bengio, Y., et al.: Gradient-based learning applied to docu-
ment recognition. Proc. IEEE 86(11), 2278–2324 (1998). https://doi.org/10.1109/
5.726791

30. Li, J., Shao, Y., Wei, K., et al.: Blockchain assisted decentralized federated learning
(BLADE-FL): performance analysis and resource allocation. IEEE Trans. Paral-
lel Distrib. Syst. 33(10), 2401–2415 (2021). https://doi.org/10.1109/TPDS.2021.
3138848
31. Stallkamp, J., Schlipsing, M., Salmen, J., et al.: The German traffic sign recognition
benchmark: a multi-class classification competition. In: The 2011 International
Joint Conference on Neural Networks, pp. 1453–1460. IEEE (2011). https://doi.
org/10.1109/IJCNN.2011.6033395
32. Konečný, J., Mcmahan, H.B., Yu, F.X., et al.: Federated learning: strategies
for improving communication efficiency. arXiv preprint arXiv:1610.05492 (2016).
https://doi.org/10.48550/arXiv.1610.05492
33. Chen, M., Poor, H.V., Saad, W., et al.: Convergence time optimization for federated
learning over wireless networks. IEEE Trans. Wireless Commun. 20(4), 2457–2471
(2020). https://doi.org/10.1109/TWC.2020.3042530
34. Ren, J., He, Y., Wen, D., et al.: Scheduling for cellular federated edge learning
with importance and channel awareness. IEEE Trans. Wireless Commun. 19(11),
7690–7703 (2020). https://doi.org/10.1109/TWC.2020.3015671
A Cooperative Caching Strategy Based
on Deep Q-Network for Mobile Edge
Networks

Chun Yang¹, Guoqing Xu², Liang Ma³, and Zilong Jin¹(B)

¹ School of Software, Nanjing University of Information Science and Technology,
Nanjing 210044, China
[email protected], [email protected]
² China Telecom Corporation Tongshan Branch, Tongshan 221100, China
³ Zhejiang Sci-Tech University, Hangzhou 310018, China
[email protected]

Abstract. The continuous growth of network traffic and mobile device


numbers poses significant challenges to traditional network architectures.
Edge computing, by decentralizing resources from the cloud to the net-
work edge, alleviates backhaul bandwidth pressure and enhances user
Quality of Service. However, user service requests are still negatively
impacted by limited cache space in hotspot areas. Efficiently utilizing idle
cache resources and formulating effective caching strategies are therefore
critical issues in edge caching. To address these challenges, this paper
proposes a cooperative edge caching strategy based on Deep Q-Network
(DQN). First, edge nodes are categorized into hotspot nodes and regular
nodes based on their network traffic. Then, the K-means++ algorithm is
used to cluster edge nodes based on their traffic data and distance, form-
ing cooperative domains where nodes collectively provide cache space
for requested content, reducing the latency associated with downloading
content from cloud servers. We then analyze the download latency and
caching costs associated with storing content on edge nodes and within
cooperative caching domains, formulating a content caching problem. To
minimize overall system costs, a DQN-based scheme is proposed to opti-
mize caching decisions. Simulation results demonstrate that the proposed
scheme can significantly improve system utility.

Keywords: Edge Computing · Edge Caching · Deep Q-Network · K-means++

1 Introduction
The surge in network traffic from mobile devices and bandwidth-intensive appli-
cations is overloading traditional cloud architectures, characterized by central-
ized processing and storage. This strain results in increased latency and dimin-
ished Quality of Experience (QoE), particularly affecting real-time applications.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 392–407, 2025.
https://doi.org/10.1007/978-981-96-4566-4_27

Edge computing [1, 2] offers a solution by decentralizing resources to the net-


work edge, reducing backhaul pressure, lowering latency, and improving overall
Quality of Service (QoS). Edge caching, a key feature of edge computing, stores
frequently requested content at edge nodes to enable faster user access. How-
ever, the limited caching capacity at edge nodes, especially in densely populated
“hotspot” areas where user demand is concentrated, remains a significant chal-
lenge.
While current edge caching strategies, including classical methods like Least
Recently Used (LRU) and Least Frequently Used (LFU) algorithms, as well as
more sophisticated approaches incorporating content popularity and user mobil-
ity, have shown some promise, they often fall short of achieving optimal perfor-
mance in dynamic mobile environments. These approaches frequently overlook
the uneven distribution of traffic loads among edge nodes and the complex spa-
tial heterogeneity of user demands. Many of these strategies also operate on the
assumption of static or predictable content access patterns, which is rarely accu-
rate in real-world scenarios. Moreover, they often focus on the performance of
individual edge nodes and do not effectively leverage collaborative caching
between neighboring nodes within a cooperative domain, leading to inefficient
resource utilization and increased latency. Thus, there is a critical need
for more intelligent caching strategies that can adapt to dynamic conditions,
leverage the collaborative potential, and make the most effective use of limited
resources.
To address the challenges of uneven traffic, fluctuating content popular-
ity, and varying user demands, this paper proposes a novel cooperative edge
caching strategy using Deep Q-Networks (DQN). Our approach aims to improve
resource utilization through a two-stage process. First, edge nodes are classified
as “hotspot” or “regular” based on their traffic patterns. Second, a K-means++
clustering algorithm groups nodes into collaborative domains, considering both
traffic data and geographic distance, facilitating the collaborative caching of
popular content within each domain. Finally, a DQN-based caching mechanism
dynamically adapts to network changes, optimizing caching decisions at each
node using a formulated objective that balances download latency and cache
costs. This two-stage approach effectively utilizes resources and provides neces-
sary adaptability for mobile edge networks.
The main contributions of this paper are summarized as follows:
1. Edge nodes are classified as “hotspot” or “regular” based on network traffic
data. Subsequently, K-means++ clustering groups these nodes into cooper-
ative domains based on traffic patterns and geographic proximity. Within
each domain, popular content is collaboratively cached to enhance resource
utilization.
2. This study analyzes latency and caching costs associated with content storage
at both individual edge nodes and cooperative domains. A Deep Q-Network
(DQN)-based approach is proposed to optimize the caching strategy, leverag-
ing reinforcement learning for optimal decision-making.

3. Simulation results, obtained through comparisons with four distinct base-


lines, demonstrate that the proposed strategy achieves superior performance
in caching effectiveness.

2 Related Work
Edge caching, a key research area in MEC [3, 4], alleviates core network band-
width pressure by caching popular content on edge servers. This reduces latency
and improves user QoS, crucial for low-latency, high-speed data transmission.
The core aims of edge caching strategies are to enhance cache space utilization
and reduce backhaul load, providing a superior experience for latency-sensitive
devices.
Latency and energy consumption are critical to user QoS. Pre-storing content
reduces both, improving QoS and extending device life. Thus, caching strategies
balancing latency and energy are vital. Li et al. [5] proposed a delay-constrained
sleep algorithm for energy conservation. Addressing service disruptions from
mobility and limited server coverage, Li et al. [6] presented an energy-delay bal-
anced cooperative caching strategy using a branch-and-bound algorithm. Kong
et al. [7] used a DDPG algorithm for a joint computation and caching framework,
minimizing energy. Vallero et al. [8] analyzed caching performance under vari-
ous traffic, capacity, and distribution conditions, achieving energy reductions.
Kang et al. [9] proposed a mobility-aware task scheduling strategy with data
caching, solved using an improved differential evolution algorithm. Zheng et al.
[10] transformed time-varying popularity into a Markov chain, implementing a
hybrid strategy for dynamic cache replacement and deterministic offloading. Ref-
erence [11] proposed a Multidimensional Cooperative Caching (MDCC) scheme
to minimize transmission latency.
Cache hit ratio is a key factor in caching decisions, representing the percent-
age of requests fulfilled by the cache. Higher hit ratios lead to shorter response
times, higher throughput, and better performance. Yuan et al. [12] established
a caching scheme maximizing the hit ratio under cooperative cost constraints,
using an ADMM-based distributed algorithm. Chen et al. [13] calculated video
request probabilities based on popularity, preferences, and video characteris-
tics, maximizing hit ratio under storage constraints. Algorithms in [14, 15] also
demonstrate good performance in terms of cache hit ratio.
Edge server storage is limited, making it difficult to cache all content. Pri-
oritizing content users prefer is necessary to improve resource utilization. User
satisfaction impacts CP revenue, making maximizing caching benefits from the
CP perspective crucial. Wang et al. [16] studied how CPs can maximize utility,
establishing a Stackelberg game with a reinforcement learning-based algorithm
for pricing strategies. He et al. [17] proposed CIPS to motivate content sharing,
selecting the most appropriate CP.
The aforementioned studies often lack consideration of the impact of time-
varying characteristics of mobile devices and traffic load on caching decisions.

Due to limited storage and computing resources in mobile edge networks, effec-
tive allocation and utilization are vital to handle the increasing demands of
services.

3 System Model
3.1 System Architecture

We consider an edge collaborative cache environment. It comprises E edge


nodes (each equipped with an edge server to provide computational and stor-
age resources), S mobile user devices, and a cloud server. The set of S user
devices is denoted as U S = {1, 2, . . . , S}, and EN = {1, 2, . . . , E} represents the
set of E edge nodes. An arbitrary edge node is denoted as en ∈ EN. The set
of all content items in the network is represented
by AF = {1, 2, . . . , F }, where the size of each content item can be expressed
as F S = {fs1 , fs2 , . . . , fsF }. Edge nodes are equipped with limited storage and
computational resources, enabling them to cache multimedia files or IoT data,
thereby reducing content request latency and backhaul link load. The cloud
server can cache content to edge nodes during off-peak network traffic peri-
ods. Edge nodes can gather real-time request information from user devices and
determine how to efficiently update content from the cloud server based on their
current caching status. The cooperation between the edge layer and the cloud
layer establishes a stable caching framework, where the cloud can push content
to edge nodes to alleviate bandwidth pressure, reduce transmission costs, and
minimize latency.
In this section, the caching decision matrix is represented as:

XEN ×AF = {xen,f , en ∈ EN, f ∈ AF }, (1)

where $x_{en,f} \in \{0, 1\}$ indicates whether content $f \in AF$ is cached at edge node $en$, with $x_{en,f} = 1$ denoting that content $f$ is cached at $en$. The caching constraint for each edge node is:

$$\sum_{f \in AF} x_{en,f} \, fs_f \leq C_{EN}, \tag{2}$$

where $C_{EN}$ represents the caching capacity of the edge node.


Assuming the association strategy is $y_{s,en} \in \{0, 1\}$, where $y_{s,en} = 1$ indicates that user device $s$ is associated with edge node $en$. Since one user device can only connect to a single edge node, we have $\sum_{en \in EN} y_{s,en} = 1$. Each edge node can serve only a limited number of users with its finite wireless radio resources, which satisfies:

$$\sum_{s \in \varphi_{en}} y_{s,en} \leq E_s, \tag{3}$$

where $\varphi_{en} \subseteq US$ represents the set of user devices covered by edge node $en$, and $E_s$ is the upper limit on the number of users an edge node can serve.

Similar to most existing studies, this paper assumes that content popularity follows a Zipf distribution. The access probability $q_f$ of content $f$ is therefore:

$$q_f = \frac{1/f^{\gamma}}{\sum_{f' \in AF} 1/f'^{\gamma}}, \tag{4}$$

where $\gamma$ is the Zipf parameter; as $\gamma$ increases, content access popularity becomes more skewed.
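Eq. (4) is straightforward to compute; a small helper, with assumed names:

```python
import numpy as np

def zipf_popularity(num_contents, gamma):
    """Content access probabilities q_f under a Zipf distribution (Eq. 4):
    weight 1/f^gamma for rank f, normalized over all contents."""
    ranks = np.arange(1, num_contents + 1, dtype=float)
    weights = ranks ** (-gamma)
    return weights / weights.sum()
```

With gamma = 0 the distribution is uniform; larger gamma concentrates requests on the top-ranked contents.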

3.2 Resource Consumption

When the requested content is cached at the local edge node, the delay for a user to download content $f$ directly from the edge node is:

$$d_{en}^{local} = \frac{fs_f}{r_{down}}, \tag{5}$$

where $fs_f \in FS$ represents the size of content $f$, and $r_{down} = B_{en,s} \log_2\!\left(1 + \frac{p \cdot h_{en,s}}{B \cdot N_{\sigma}}\right)$ is the downlink transmission rate; here $B_{en,s}$ denotes the bandwidth allocated between edge node $en$ and user device $s$, $p$ is the transmission power of the edge node, $h_{en,s}$ represents the channel gain, and $N_{\sigma}$ denotes the additive white Gaussian noise power.
When the requested content is cached at a neighboring node within the cooperative domain of the local edge node, transmitting the content between edge nodes incurs a delay:

$$d_r = \frac{fs_f}{r_{EN}}, \tag{6}$$

where $r_{EN}$ denotes the transmission rate between edge nodes within the cooperative domain, under the assumption that the queuing delay for requests is negligible. The total retrieval delay in this case is:

$$d_{en}^{CD} = d_r + d_{en}^{local}. \tag{7}$$

When the requested content is not cached at any edge node, it must be transferred from the cloud server to the local edge node. The transmission delay between the cloud server and the edge node is:

$$d_c = \frac{fs_f}{r_c}, \tag{8}$$

where $r_c$ represents the transmission rate per unit data between the edge node and the cloud server, with $r_c < r_{EN}$. Consequently, the delay to download content $f$ from the cloud server is:

$$d^{cloud} = d_c + d_{en}^{local}. \tag{9}$$

Therefore, the delay for retrieving content $f$ can be expressed as:

$$D_f = q_f \cdot fs_f \cdot \left[ x_{en,f} \left( y_{s,en} \cdot d_{en}^{local} + (1 - y_{s,en}) \cdot d_{en}^{CD} \right) + (1 - x_{en,f}) \cdot d^{cloud} \right]. \tag{10}$$

In summary, the system's total request delay cost can be expressed as:

$$D = \sum_{f \in AF} \sum_{s \in US} D_f \cdot o_d, \tag{11}$$

where $o_d$ represents the unit transmission cost.

3.3 Caching Cost


The caching cost of edge nodes is modeled as the total cost incurred by storing all cached content within the caching spaces of all edge nodes:

$$C = \sum_{en \in EN} \sum_{f \in AF} x_{en,f} \, fs_f \cdot o_c, \tag{12}$$

where $o_c$ represents the unit caching cost.
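As a sketch, the delay and cost model of Eqs. (5) to (12) can be evaluated for a single representative user; the flattened per-content dictionaries and parameter names below are simplifying assumptions, not the paper's notation:

```python
def content_delay(fs_f, cached_local, cached_in_domain, r_down, r_en, r_c):
    """Download delay for one content item, following Eqs. (5)-(9):
    local hit, cooperative-domain hit, or cloud fetch (r_c < r_en assumed)."""
    d_local = fs_f / r_down                       # Eq. (5)
    if cached_local:
        return d_local
    if cached_in_domain:
        return fs_f / r_en + d_local              # Eqs. (6)-(7)
    return fs_f / r_c + d_local                   # Eqs. (8)-(9)

def total_cost(contents, q, o_d, o_c):
    """System cost D + C (Eqs. (11)-(12)) for one representative user,
    over a list of per-content dicts (keys: fs/local/domain/r_down/r_en/r_c)."""
    D = sum(q[i] * c["fs"] * content_delay(c["fs"], c["local"], c["domain"],
                                           c["r_down"], c["r_en"], c["r_c"]) * o_d
            for i, c in enumerate(contents))
    C = sum(c["fs"] * o_c for c in contents if c["local"] or c["domain"])
    return D + C
```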

3.4 Problem Statement


Based on the aforementioned caching model, the objective of this paper is to minimize the total system cost, which comprises two parts: the content download delay cost and the caching cost. The objective function P1 is expressed as:

$$\text{P1:} \quad \min \; D + C, \tag{13}$$

$$\text{s.t.} \quad \text{C1:} \; \sum_{f \in AF} x_{en,f} \, fs_f \leq C_{EN},$$

$$\text{C2:} \; \sum_{s \in \varphi_{en}} y_{s,en} \leq E_s,$$

$$\text{C3:} \; \sum_{en \in EN} y_{s,en} = 1,$$

$$\text{C4:} \; x_{en,f} \in \{0, 1\}, \; y_{s,en} \in \{0, 1\},$$

where C1 specifies the storage capacity limit of each edge node; C2 restricts the number of users each edge node can serve to a predefined maximum; C3 stipulates that each user device must be associated with exactly one edge node; and C4 constrains the caching and association decision variables to be binary. As P1 is a mixed-integer nonlinear programming (MINLP) problem, its computational complexity increases with the dimensionality of the caching decisions. To address this challenge, this paper employs a DQN-based collaborative caching strategy to optimize the objective function.

4 Optimization Scheme
4.1 The Edge Node Clustering Scheme Based on K-Means++
In an edge collaboration environment, the geographical distribution of edge nodes is random, so the distances, and hence the data transmission delays, between nodes vary. In addition, because each node serves a different number of users, the traffic load is unevenly distributed, which can cause resource shortages (in both storage and computation) at hotspot nodes and significantly increase user request delay. We therefore classify nodes into hotspot nodes and regular nodes based on the network traffic of each edge node, denoted $NetT = \{n_1, n_2, \ldots, n_E\}$, where $n_i$ represents the traffic of edge node $i$. Hotspot nodes are defined as:

$$N_{hot} = \left\{ n \,\middle|\, n \geq \frac{\sum_{n_i \in NetT} n_i}{E}, \; n \in NetT \right\}. \tag{14}$$

Regular nodes are then:

$$N_{normal} = \left\{ n \,\middle|\, n < \frac{\sum_{n_i \in NetT} n_i}{E}, \; n \in NetT \right\}. \tag{15}$$
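The mean-traffic split of Eqs. (14) and (15) can be sketched directly:

```python
import numpy as np

def classify_nodes(net_traffic):
    """Split edge nodes into hotspot and regular sets (Eqs. 14-15):
    a node is 'hot' if its traffic is at least the network-wide mean."""
    traffic = np.asarray(net_traffic, dtype=float)
    mean = traffic.sum() / len(traffic)
    hot = np.where(traffic >= mean)[0]
    normal = np.where(traffic < mean)[0]
    return hot, normal
```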

To address the above issue, the K-means++ algorithm is used to cluster edge nodes into multiple collaborative domains based on their traffic data and inter-node distance. Here $CH = \{1, 2, \ldots, K\}$ denotes the set of cluster heads, and the collaborative domain headed by cluster head $k \in CH$ is denoted $CD_k = \{1, 2, \ldots, E_k\}$, where $E_k$ is the number of edge nodes in that domain. The clustering algorithm is primarily executed on the cloud server, and the clustering results are sent to each edge node. Figure 1 illustrates the K-means++ based edge node clustering scheme.
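A self-contained sketch of the clustering step, with k-means++ seeding over per-node feature vectors (traffic plus coordinates), follows; the feature layout and iteration count are assumptions:

```python
import numpy as np

def kmeans_pp_cluster(features, k, iters=20, seed=0):
    """Cluster edge nodes by (traffic, x, y) features with k-means++ seeding.

    Returns (labels, centers); each label is the node's cooperative domain."""
    rng = np.random.default_rng(seed)
    X = np.asarray(features, dtype=float)
    centers = [X[rng.integers(len(X))]]            # k-means++ initialization:
    for _ in range(1, k):                          # pick next seed with prob.
        d2 = np.min([((X - c) ** 2).sum(axis=1)    # proportional to squared
                     for c in centers], axis=0)    # distance to nearest seed
        centers.append(X[rng.choice(len(X), p=d2 / d2.sum())])
    centers = np.array(centers)
    for _ in range(iters):                         # standard Lloyd iterations
        labels = np.argmin(((X[:, None] - centers) ** 2).sum(axis=2), axis=1)
        for j in range(k):
            if np.any(labels == j):
                centers[j] = X[labels == j].mean(axis=0)
    return labels, centers
```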

4.2 DQN-Based Cache Strategy Solution

With the advancement of reinforcement learning in various domains, its powerful


learning and decision-making capabilities have become increasingly evident. The
Q-Learning algorithm, a classic in reinforcement learning, faces challenges in
maintaining the Q-table in complex edge network environments. To overcome
the limitations of Q-Learning and in response to the rapid development of deep
learning, the Deep Q Network (DQN) was introduced. DQN combines neural
networks with Q-Learning to approximate the optimal action-value function.
Unlike tabular Q-Learning, DQN introduces an experience replay mechanism to break the correlation between consecutive data samples: transitions are stored in a replay buffer, and small batches are repeatedly sampled from it to update the network parameters that train the agent. Additionally, DQN uses a fixed Q-target mechanism, employing two identically structured but separately parameterized neural networks to address instability in the updates.

Fig. 1. Edge Node Clustering Scheme Based on K-means++.

Similar to most cache methods based on Deep Reinforcement Learning


(DRL), we model the caching process of each edge node as a Markov Deci-
sion Process (MDP). When a user request is received, the edge node updates
the state of the content si based on the user’s request features and the caching
conditions of its local and neighboring nodes within the cooperation domain.
The edge node then determines the corresponding action ai according to the
policy πi , transitions the state si to the next state si+1 , and receives a reward
ri . Below is a brief description of the states, actions, and rewards in the MDP:

(1) States
The state space reflects the environmental state of the content delivery net-
work. Similar to most existing DRL-based caching algorithms, we represent
the state space S using the caching status S = {E1 , E2 , . . . , EN } of the edge
nodes, where Ei denotes the caching state of edge node i.
(2) Actions
The action space is represented by the caching decision matrix XEN ×AF .
Each edge node, as an agent, makes joint caching decisions based on its own
caching space.
(3) Rewards
The goal of reinforcement learning is to obtain the maximum reward, while
the objective function of this paper is to minimize the sum of content request
delay cost and caching cost. Therefore, we define the reward as:
$$R = \frac{1}{D + C}. \tag{16}$$

After determining the state space, action space, and reward, the edge col-
laboration network caching process based on DQN is described as follows: The
edge nodes act as agents in the DQN algorithm, interacting with the environ-
ment through states, actions, and rewards. When the agent receives a content
request from a user device, it seeks the optimal content caching placement strat-
egy based on the current caching status of each edge node. This is achieved by
training to obtain the Q-values for all actions, allowing the agent to select the
optimal action. After the action is executed, the environment changes, and the
agent receives the current environment and its corresponding reward, which are
stored in the experience replay buffer. Then, a mini-batch of data, represented as
the four-tuple (s, a, r, s ) is randomly sampled from the experience replay buffer
to update the network parameters. The DQN network parameters are updated
using the Temporal Difference (TD) algorithm, which requires calculating the
TD target value, expressed as:

ŷ = r + γ max_{a′} Q(s′, a′; θ′). (17)

In the DQN iteration process, the Q-values are trained by minimizing the
loss function to approximate the target values. The loss function is expressed as:

loss = E[(ŷ − Q(s, a; θ))²]. (18)

Through the above calculation, the update of the Q-value estimation net-
work parameters is achieved. Then, we periodically update the Q target network
parameters to ensure smooth changes in the Q-value estimation network param-
eters, enhancing the algorithm’s practicality. The DQN architecture is shown in
Fig. 2.

Fig. 2. DQN Architecture Diagram.
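The update loop just described can be condensed into a short, runnable sketch. It uses a tabular Q-function as a gradient-free stand-in for the paper's neural network; the state and action counts, discount factor, and step size are illustrative assumptions, while the replay-pool and mini-batch sizes follow Table 1.

```python
import random
from collections import deque

import numpy as np

rng = np.random.default_rng(0)
N_STATES, N_ACTIONS = 8, 4            # toy problem sizes (assumptions)
GAMMA, LR = 0.9, 0.1                  # discount and step size (assumptions)

Q = np.zeros((N_STATES, N_ACTIONS))   # stand-in for the Q estimation network
Q_target = Q.copy()                   # target network, synced periodically
replay = deque(maxlen=5000)           # experience replay pool (size from Table 1)

def td_update(batch):
    """Minimize (y_hat - Q(s, a))^2 over a mini-batch, per Eqs. (17)-(18)."""
    for s, a, r, s_next in batch:
        y_hat = r + GAMMA * Q_target[s_next].max()   # TD target, Eq. (17)
        Q[s, a] += LR * (y_hat - Q[s, a])            # step toward the target

# Fill the buffer with random (s, a, r, s') transitions; the reward follows
# Eq. (16), r = 1/(D + C), with a random stand-in for the cost D + C.
for _ in range(200):
    s, a = rng.integers(N_STATES), rng.integers(N_ACTIONS)
    s_next = rng.integers(N_STATES)
    r = 1.0 / (1.0 + rng.random())
    replay.append((int(s), int(a), float(r), int(s_next)))

for step in range(100):
    batch = random.sample(list(replay), 64)   # mini-batch from the replay pool
    td_update(batch)
    if step % 20 == 0:
        Q_target = Q.copy()                   # periodic target-network update
```

A real implementation would replace the table with the Q-network of Fig. 2 and an optimizer minimizing the loss of Eq. (18).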


A Cooperative Caching Strategy Based on DQN for Mobile Edge Networks 401

4.3 Optimization Scheme


The proposed research scheme, which combines the K-means++ based edge node
clustering scheme with the DQN-based caching strategy, is referred to as the
DQN-based collaborative caching strategy for edge networks. The pseudocode
for the proposed scheme is provided in Algorithm 1.

Algorithm 1. Cooperative Caching with DQN


1: Init: EN, US, AF, FS, X, K, CH, θ, θ′
2: Phase 1: Edge Node Clustering
3: Input: EN, US, AF, FS, X, K, CH, θ, θ′
4: Output: Core areas
5: if |CH| ≤ K then
6:    for each node do
7:       Calculate Dist(CH, N), add node with largest dist. to CH
8:    end for
9: end if
10: for each en ∈ EN do
11:    Calculate Dist(CH, en), add en to closest core area
12: end for
13: Calculate avg. quality in each area, closest node is core qual.
14: Phase 2: Caching Decision
15: Output: Final caching X
16: for each episode do
17:    for each step do
18:       Select action a via ε-greedy based on s
19:       Get s′, cache action, reward r, update buffer
20:       Sample transitions, update network via Eq. (17)
21:       Update loss, update θ
22:       Update target network: θ′ ← θ
23:    end for
24: end for
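Phase 1 of the algorithm can be sketched as follows. Note that the head-selection step ("add node with largest dist. to CH") is the deterministic farthest-point variant of K-means++ seeding, which is what this sketch implements; representing each edge node by a small feature vector such as [traffic, x, y] is an assumption of the sketch.

```python
import numpy as np

rng = np.random.default_rng(1)

def cluster_edge_nodes(features, k):
    """Pick k cluster heads by farthest-point seeding, then assign every
    edge node to its closest head. `features` is an (N, d) array of
    per-node attributes, e.g. [traffic, x, y] (assumed layout)."""
    heads = [int(rng.integers(len(features)))]        # first head at random
    while len(heads) < k:
        # distance of each node to its nearest existing head
        d = np.min([np.linalg.norm(features - features[h], axis=1)
                    for h in heads], axis=0)
        heads.append(int(np.argmax(d)))               # farthest node becomes a head
    centers = features[np.array(heads)]
    # assign each node to the closest core area
    dists = np.linalg.norm(features[:, None, :] - centers[None, :, :], axis=2)
    return np.array(heads), np.argmin(dists, axis=1)

nodes = rng.random((10, 3))                      # 10 edge nodes (Table 1)
heads, labels = cluster_edge_nodes(nodes, k=3)   # 3 centroids (Table 1)
```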

5 Experiment and Result Analysis


In this section, we set the simulation environment parameters and performance
metrics, and evaluate the performance of the proposed scheme by comparing it
with various baseline schemes.

5.1 Experiment Environment and Parameter Settings


The experimental environment in this paper consists of a Windows 10, 64-bit
operating system with an Intel Core i5-10500 CPU running at 3.10 GHz and
16 GB of RAM. The entire experimental process was implemented in a Python
3.7 development environment using PyCharm. The scenario assumed in this study
is an edge collaboration caching network, where a cloud server stores all content,

and multiple edge nodes, each equipped with an edge server, use their limited
cache space to store content of interest to users. Users with content requests
are distributed within the coverage area of the edge nodes. They can obtain
requested content either from the edge nodes or the cloud server. The specific
simulation environment parameters are shown in Table 1.

Table 1. Experimental Parameter Settings.

Parameter Value
System Bandwidth 10 MHz
Number of Edge Nodes 10
Edge Node Cache Space 100 MB
Number of Contents 500
Content Size 5 MB
Zipf Factor 0.6
Number of Centroids 3
Number of Neural Network Layers 2
Maximum Training Epochs 1000
Batch Sampling Size 64
Experience Replay Pool Size 5000
Learning Rate 0.001
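The Zipf factor in Table 1 controls how skewed content popularity is. Below is a minimal sketch of how a request trace with these parameters might be generated; the trace length and the rank-based sampling model are assumptions, since this section does not spell out the request model.

```python
import numpy as np

rng = np.random.default_rng(2)

def zipf_popularity(n_contents, zipf_factor):
    """Popularity of the f-th ranked content under a Zipf law:
    p_f = f^(-zipf_factor) / sum_j j^(-zipf_factor)."""
    ranks = np.arange(1, n_contents + 1, dtype=float)
    w = ranks ** (-zipf_factor)
    return w / w.sum()

p = zipf_popularity(500, 0.6)                 # Table 1: 500 contents, Zipf factor 0.6
requests = rng.choice(500, size=10_000, p=p)  # simulated request trace (length assumed)
```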

5.2 Performance Indicators


To better evaluate the performance of the proposed DQN-based edge collabo-
rative caching strategy, we use Cache Hit Ratio (CHR) and Content Download
Latency (CDL) as performance evaluation metrics [18]. The specific definitions
are as follows:

(1) The cache hit ratio is modeled as:


CHR = Σ_{f∈AF} N_{en,f} / N_f, (19)

where N_{en,f} represents the number of times content f is fetched from the
edge node, and N_f represents the total number of requests for content f.
(2) The content download latency (CDL) can be modeled as:
CDL = (1/S) Σ_{f∈AF} Σ_{s∈US} D_f, (20)

where CDL represents the average delay of user content download requests.
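Both metrics transcribe directly into code. The sketch below reads Eq. (19) literally (a sum of per-content hit fractions, so the value is not bounded by 1) and interprets S in Eq. (20) as the total number of requests; both readings are our interpretation of the formulas.

```python
import numpy as np

def cache_hit_ratio(edge_hits, total_requests):
    """Eq. (19) taken literally: CHR = sum over f of N_{en,f} / N_f."""
    edge_hits = np.asarray(edge_hits, dtype=float)
    total_requests = np.asarray(total_requests, dtype=float)
    mask = total_requests > 0                 # skip contents never requested
    return float((edge_hits[mask] / total_requests[mask]).sum())

def content_download_latency(delays, n_requests):
    """Eq. (20): CDL = (1/S) * sum of the per-request download delays D_f."""
    return float(np.sum(delays)) / n_requests
```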

5.3 Performance Analysis


This section evaluates the proposed scheme's performance based on Zipf param-
eter, edge node cache capacity, and content quantity, comparing it to various
baseline methods, which are as follows:
1. LFU (Least Frequently Used): Replaces the least frequently used content.
2. LRU (Least Recently Used): Removes the least recently used content.
3. FIFO (First In First Out): Prioritizes removing the content that has
been stored the longest.
4. Without cluster: This scenario does not consider edge node clustering.

Fig. 3. The Impact of Zipf Parameter.

Figure 3 illustrates the impact of different Zipf parameters γ on cache hit ratio
for the proposed caching scheme. As the Zipf parameter increases, the cache hit
ratio for all algorithms improves. The larger the value of parameter γ, the higher
the popularity of the corresponding content items, meaning the probability of
users requesting that content increases, thus improving the cache hit ratio. More-
over, it can be clearly observed that the proposed scheme consistently achieves
a higher cache hit ratio than the four baseline methods. For example, with Zipf
parameter γ = 0.8, the cache hit ratio of the proposed algorithm is 19%, 22%,
and 28% higher than LFU, LRU, and FIFO, respectively, demonstrating the
decision-making capability of the reinforcement learning algorithm. Compared
to the “Without cluster” scheme, it improves by 7%, as edge node clustering
allows better utilization of idle cache space, resulting in more user-interesting
content being cached at the edge, thus increasing the cache hit ratio.
Figure 4 shows the impact of different edge node cache capacities on CHR
and CDL. From Fig. 4(a), it can be observed that as the edge node cache capacity
increases, the overall content cache hit ratio of the algorithm also increases, with
a significant improvement. When the node cache capacity is 200 MB, the cache
hit ratios for LFU, LRU, FIFO, and Without cluster are 54.4%, 53%, 50.9%,
and 62.4%, respectively. The proposed algorithm improves the cache hit ratio by
23%, 27%, 32%, and 8% compared to the baseline algorithms. The smaller the
cache capacity, the more pronounced the performance advantage of the proposed
algorithm. This is because, when the cache capacity is larger, more content can
be cached within the edge node collaboration domain, making it easier to satisfy
user requests, resulting in a smaller performance gap. The algorithm without
clustering also performs significantly better than LFU, LRU, and FIFO, which
indirectly indicates the advantages of reinforcement learning in cache decision-
making.

Fig. 4. The Impact of Different Edge Node Cache Capacities.

Figure 4(b) shows the impact of edge node cache capacity on content down-
load latency. As the cache capacity of edge nodes increases, the overall content
download latency decreases. When the edge node cache capacity increases, more
content can be stored, reducing the probability of downloading content from the
cloud server, thus lowering the download latency. Additionally, when the edge
node capacity is smaller, the proposed algorithm performs better. When the
cache node capacity is limited, edge nodes can cluster based on network traffic
data, making full use of the idle cache space of ordinary nodes within the coop-
eration domain. This allows for the storage of richer popular content within the
collaboration domain, and the data transmission latency between edge nodes is
much smaller than that between edge nodes and the cloud. Therefore, under
limited resources, the proposed scheme is more practical.

Fig. 5. The Impact of Different Content Quantities.

Figure 5 shows how content quantity impacts Cache Hit Ratio (CHR) and
Content Download Latency (CDL), comparing the proposed scheme with the
“Without cluster” method. As seen in Fig. 5(a), increasing content quantity
decreases the CHR. With fixed cache space, more content means less can be
stored, reducing the CHR. However, the proposed scheme’s CHR declines slower
than “Without cluster”. With 500 and 1000 content items, the proposed scheme
shows 4.7% and 11.8% higher CHR, respectively, indicating better performance
as content quantity increases. The K-means++ clustering ensures better cache
space utilization within the domain, enabling more effective content caching and
improved performance.
In Fig. 5(b), it is evident that content download latency increases with the
content quantity. This is due to the limited cache capacity of edge nodes, which
results in many contents being cached only on the cloud server, leading to higher
download latency. Similarly, the proposed scheme performs better when the con-
tent quantity is larger. Compared to the “Without cluster” approach, the pro-
posed algorithm can make full use of the cache space of ordinary nodes, storing
more content and thus reducing content download latency.

6 Conclusion

This paper studies the caching strategy from the perspective of uneven traffic
load across edge nodes and proposes a DQN-based edge collaborative caching
algorithm. Firstly, based on the network traffic data of edge nodes, they are
classified into hotspot nodes and ordinary nodes. Using the K-means++ algo-
rithm, edge nodes are clustered based on their traffic data and distance. Then,
the delay and caching costs associated with caching content on edge nodes and
within the collaborative caching domain are analyzed to establish the content
caching problem. To minimize system costs, a DQN-based approach is proposed
to optimize caching decisions. Simulation results show that the DQN-based edge
collaborative caching scheme outperforms several baseline methods, demonstrating superior performance.

Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).

References
1. Abbas, N., Zhang, Y., Taherkordi, A., Skeie, T.: Mobile edge computing: a survey.
IEEE Internet Things J. 5(1), 450–465 (2017)
2. Wang, J., Zhao, L., Liu, J., Kato, N.: Smart resource allocation for mobile edge
computing: a deep reinforcement learning approach. IEEE Trans. Emerg. Top.
Comput. 9(3), 1529–1541 (2019)
3. Tran, T.X., Hajisami, A., Pandey, P., Pompili, D.: Collaborative mobile edge com-
puting in 5G networks: new paradigms, scenarios, and challenges. IEEE Commun.
Mag. 55(4), 54–61 (2017)
4. Peng, M., Sun, Y., Li, X., Mao, Z., Wang, C.: Recent advances in cloud radio access
networks: system architectures, key techniques, and open issues. IEEE Commun.
Surv. Tutor. 18(3), 2282–2308 (2016)
5. Li, P., Gong, S., Gao, S., Hu, Y., Pan, Z., You, X.: Delay-constrained sleeping
mechanism for energy saving in cache-aided ultra-dense network. Sci. China Inf.
Sci. 62, 1–14 (2019)
6. Li, C., Zhang, Y., Gao, X., Luo, Y.: Energy-latency tradeoffs for edge caching and
dynamic service migration based on DQN in mobile edge computing. J. Parallel
Distrib. Comput. 166, 15–31 (2022)
7. Kong, X., et al.: Deep reinforcement learning-based energy-efficient edge computing
for internet of vehicles. IEEE Trans. Industr. Inf. 18(9), 6308–6316 (2022)
8. Vallero, G., Deruyck, M., Joseph, W., Meo, M.: Caching at the edge in high energy-
efficient wireless access networks. In: ICC 2020-2020 IEEE International Conference
on Communications (ICC), pp. 1–7. IEEE (2020)
9. Kang, L., Tang, B., Zhang, L., Tang, L.: Mobility-aware and data caching-based
task scheduling strategy in mobile edge computing. In: 2019 IEEE International
Conference on Parallel & Distributed Processing with Applications, Big Data &
Cloud Computing, Sustainable Computing & Communications, Social Computing
& Networking (ISPA/BDCloud/SocialCom/SustainCom), pp. 1071–1077. IEEE
(2019)
10. Zheng, C., Liu, S., Huang, Y., Yang, L.: Hybrid policy learning for energy-latency
tradeoff in MEC-assisted VR video service. IEEE Trans. Veh. Technol. 70(9), 9006–
9021 (2021)
11. Lin, P., Song, Q., Jamalipour, A.: Multidimensional cooperative caching in CoMP-
integrated ultra-dense cellular networks. IEEE Trans. Wireless Commun. 19(3),
1977–1989 (2019)
12. Yuan, P., Shao, S., Geng, L., Zhao, X.: Caching hit ratio maximization in mobile
edge computing with node cooperation. Comput. Netw. 200, 108507 (2021)
13. Chen, X., He, L., Xu, S., Hu, S., Li, Q., Liu, G.: Hit ratio driven mobile edge caching
scheme for video on demand services. In: 2019 IEEE International Conference on
Multimedia and Expo (ICME), pp. 1702–1707. IEEE (2019)

14. Wei, X., Liu, J., Wang, Y., Tang, C., Hu, Y.: Wireless edge caching based on content
similarity in dynamic environments. J. Syst. Architect. 115, 102000 (2021)
15. Li, Z., Gao, X., Li, Q., Guo, J., Yang, B.: Edge caching enhancement for industrial
internet: a recommendation-aided approach. IEEE Internet Things J. 9(18), 16941–
16952 (2022)
16. Wang, Q., Guo, S., Liu, J., Pan, C., Yang, L.: MotiShare: incentive mechanisms
for content providers in heterogeneous time-varying edge content market. IEEE
Trans. Serv. Comput. 16(1), 452–465 (2021)
17. He, J., Wang, H., Chu, X., Zhang, T.: Incentive mechanism and content provider
selection for device-to-device-based content sharing. IEEE Trans. Veh. Technol.
68(3), 2946–2957 (2019)
18. Li, C., Zhang, Y., Song, M., Yan, X., Luo, Y.: An optimized content caching
strategy for video stream in edge-cloud environment. J. Netw. Comput. Appl. 191,
103158 (2021)
YOLO-LiteMax: An Improved Model for UAV Small Object Detection

Jian Su¹(B), Chang Yang², and Jian Zhang²

¹ School of Software, Nanjing University of Information Science and Technology, Nanjing 210044, China
[email protected]
² School of Computer Science, Nanjing University of Information Science and Technology, Nanjing 210044, China
[email protected]

Abstract. Drone image object detection is a crucial foundational technology across multiple research fields. However, due to challenges such as
small object sizes, dense distributions, and complex backgrounds com-
monly encountered during drone-based object recognition, the detec-
tion accuracy is often suboptimal. To address these challenges, we opti-
mize YOLOv8 and proposed a new object detection model specifically
designed for drone aerial imagery, named YOLO-LiteMax. Firstly, we
replace the Bottleneck structure in the Cross Stage Partial (CSP) module with a FasterNet Block structure, which effectively reduces parameters
while maintaining the original level of accuracy. Secondly, we introduce a
new feature pyramid structure called Small Target Scale Sequence Fusion
(STSSF), enhancing multi-scale feature fusion to better capture small
objects. Lastly, we replace the original detection head with the Shared
Convolution Precision Detection (SCPD) head, which uses shared convo-
lutions and group normalization to improve the efficiency and consistency
of feature extraction. Experiments on the VisDrone2019 dataset show
that YOLO-LiteMax significantly outperforms other baseline methods
in detection accuracy. Compared to YOLOv8, YOLO-LiteMax achieved
a remarkable 5.9% improvement in the mAP50 metric. These results
demonstrate that YOLO-LiteMax significantly improves small object
detection in drone imagery, making it more effective in complex envi-
ronments.

Keywords: UAVs · Small-Object Detection · Deep Learning · YOLOv8 · Feature Fusion

1 Introduction
With the increasing popularity of drones, their applications are expanding across
diverse fields such as agriculture [18], logistics [13], surveillance [16], environmen-
tal monitoring [2], and disaster response [6]. These applications take full advan-
tage of drones’ unique capabilities, such as high mobility, flexibility, and access
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 408–422, 2025.
https://doi.org/10.1007/978-981-96-4566-4_28

to hard-to-reach areas [4]. Accurate and efficient object detection is essential for
the effective use of drones in these fields. Object detection is one of the core
tasks in the field of computer vision, providing essential support for autonomous
navigation, precise operations, and situational awareness for drones by automat-
ically identifying and locating specific objects in images or videos. For example,
in agriculture, object detection technology can be used to assess crop health,
monitor pests, and evaluate harvest conditions, thereby improving the efficiency
and accuracy of agricultural management. In logistics and parcel delivery, object
detection facilitates automatic identification and tracking of packages, enabling
automated sorting and delivery. In disaster response and environmental moni-
toring, object detection helps identify critical objects in emergency situations,
such as distressed individuals, fire hotspots, or flooded areas, providing timely
and reliable information for rescue operations. Therefore, the performance of
object detection directly impacts the effectiveness of drone applications in these
scenarios. Advances in deep learning-based object detection algorithms have sig-
nificantly driven progress in this field, greatly improving the accuracy and effi-
ciency of visual recognition tasks [14, 20, 21]. Despite these advancements, drone-
based object detection still faces several challenges. The aerial perspective often
results in small object sizes and low resolution, which makes precise detection
difficult due to limited pixel information [24]. Additionally, drones’ constrained
computational resources restrict the deployment of complex and computation-
intensive models [7]. Moreover, the continuously changing environmental con-
ditions, dynamic backgrounds, and frequent changes in camera angles further
increase the complexity of reliable object detection from drones [24]. Although
deep learning techniques have led to significant improvements in object detec-
tion, these challenges still hinder real-time performance and accuracy in drone
applications. To address these limitations, this study aims to develop a more
efficient and accurate object detection model, specifically optimized for drone
imagery, to enhance its robustness in complex environments while maintaining
computational efficiency.
In deep learning-based object detection for drones, two main approaches
exist: two-stage detectors and one-stage detectors. Two-stage detectors, such
as Faster R-CNN [20], first generate candidate regions of interest in the image
through a region proposal network, which are then further analyzed for precise
object classification and bounding box refinement. This approach is well-suited
for complex drone imagery involving challenging backgrounds and densely dis-
tributed targets. However, the two-stage process demands substantial computa-
tional resources, limiting its practicality for real-time applications on resource-
constrained drone devices. Conversely, one-stage detectors, like You Only Look
Once (YOLO) [19] and Single Shot MultiBox Detector (SSD) [15], process the
entire image in a single pass, directly predicting class probabilities and bounding
boxes from a grid of predefined anchor boxes. By combining classification and
localization in one step, one-stage detectors achieve faster processing speeds.
Nonetheless, they often show reduced accuracy for small or densely clustered
objects, particularly in complex backgrounds, leading to a higher likelihood of
false positives and missed detections. In summary, although existing two-stage
detectors have advantages in terms of accuracy, their high computational complexity limits their use in UAV applications that require high real-time performance. On the other hand, one-stage detectors are faster but need improved
detection accuracy for small targets and in complex environments. Given the
real-time and computational constraints of drone applications, this study aims
to improve the accuracy of small object detection within one-stage detectors,
enhancing their robustness in complex environments while maintaining efficiency
to better meet the needs of drone-based applications.
In this study, we propose enhancements to the YOLOv8s architecture to
improve the accuracy of drone-based object detection. The main contributions
of this work are as follows:

(1) We redesign the CSP module by introducing the FasterNet Block, effec-
tively reducing redundant feature information and lowering model complex-
ity without compromising detection accuracy. This improvement addresses
the issue of excess computational load in previous models.
(2) We redesign the neck network by proposing the STSSF pyramid structure
to solve the problem that simple summation and concatenation in the orig-
inal feature pyramid cannot take full advantage of the correlation between
feature maps. This new structure incorporates the P2 detection scale and
optimizes feature fusion, significantly enhancing small object detection and
improving model robustness in complex scenes.
(3) To address the limitations of the original YOLOv8 detection head in han-
dling small batch sizes, effective feature fusion, and capturing both local
and global context, we propose the SCPD head. By incorporating shared
convolution and group normalization, SCPD head achieves higher detection
accuracy with fewer parameters by improving the stability and consistency
of feature extraction.
(4) To validate the effectiveness of the proposed methods, we conduct exten-
sive ablation studies and performance comparisons on the VisDrone2019
dataset. The dataset features various challenges, such as small object sizes,
low resolution, and dynamic backgrounds, which align well with the issues
discussed in this study. The results show that our approach offers significant
improvements over existing methods, especially in detecting small objects,
while maintaining computational efficiency, making it ideal for drone-based
applications.

2 Related Work
In drone-based object detection, several studies have proposed methods to
enhance detection accuracy and adaptability. Sahin et al. [22] introduced YOLO-
Drone, which improved the detection of multi-scale objects, particularly small
ones, by adding additional convolutional and detection layers. However, these
modifications increased computational requirements, limiting the model’s effec-
tiveness in handling complex occlusions. Albaba et al. [1] proposed SyNet, an
ensemble network combining the strengths of both multi-stage and single-stage
detectors for UAV images, which enhanced detection accuracy and robustness for
small and multi-scale objects. Despite these gains, SyNet’s high computational
demand and complex training process limit its suitability for dense scenes. Sun
et al. [23] presented RSOD, an extension of YOLOv3 that utilizes shallow fea-
tures for precise localization and combines local and global features, refining the
Squeeze-and-Excitation attention mechanism to improve small object detection.
Nevertheless, these enhancements increased computational needs, complicating
deployment in resource-limited environments. Zhao et al. [27] introduced MS-
YOLOv7, which incorporates multiple detection heads, attention mechanisms,
a Swin Transformer, and an improved pooling module to boost the detection
of densely packed objects. However, the increased computational requirements
and complexity hinder its deployment efficiency in constrained environments.
Similarly, Lou et al. [17] proposed DC-YOLOv8, which employs a novel down-
sampling approach, an improved feature fusion network, and increased network
depth to improve small object detection in complex scenarios. However, DC-
YOLOv8 also faces high computational costs and training challenges, limiting
its deployment in heavily occluded environments.

3 The Proposed YOLO-LiteMax Method


Figure 1 shows the overall architecture of YOLO-LiteMax, an improved ver-
sion of the YOLOv8s algorithm. The model comprises three main components:
the backbone feature extraction network, the neck feature fusion network, and
the detection head network. These components are designed to optimize feature
extraction, feature fusion, and detection classification, achieving an effective bal-
ance between detection performance and model complexity.

Fig. 1. The overall structure of the proposed improved model. There are two C3DSF modules in the figure: the first follows an orange route and the second a blue route.
route. The red dashed box represents a novel enhancement method for P2 feature maps.
(Color figure online)

3.1 Selective Convolution Block (SCB)


As the depth of neural networks increases, the number of feature map chan-
nels grows significantly, often leading to redundancy and high similarity between
channels. This redundancy not only increases computational costs but also intro-
duces unnecessary parameters that do not contribute meaningfully to model
performance. Therefore, addressing this issue is critical for optimizing network
efficiency. To mitigate the problem of feature redundancy while maintaining per-
formance, we propose the Selective Convolution Block (SCB) module, as shown
in Fig. 2.

Fig. 2. The structure of SCB.

In the SCB module, the FasterNet Block from FasterNet [5] replaces the orig-
inal Bottleneck structure in the CSP module. The core concept of the FasterNet
Block is Partial Convolution. In this approach, convolution is applied only to a
subset of the input channels, while the remaining channels are processed using
identity mapping. This reduces computational overhead by limiting the num-
ber of channels involved in convolution, while the non-convolved channels are
retained and passed through subsequent layers. As a result, FasterNet Block
minimizes computational complexity without sacrificing the flow of information
and feature propagation through the network. The SCB module operates as fol-
lows: First, the feature map is processed by a CBS module, which consists of
conventional convolution, batch normalization, and SiLU activation. The output
is then passed into a Split layer, which divides the feature map into two parts,
each containing half of the total channels, and stores the results in a list. The
output from each FasterNet Block is then fed into the next block and added
to the list. After passing through several FasterNet Blocks, all feature maps in
the list are concatenated along the channel dimension and processed by another
CBS module to produce the final output (Fig. 3).
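The Partial Convolution idea can be illustrated with a dependency-free sketch: only the first portion of the channels is convolved, while the rest pass through untouched. A 1×1 channel-mixing kernel stands in here for FasterNet's 3×3 spatial convolution, and the tensor sizes and the 1/4 split ratio are illustrative assumptions of this sketch.

```python
import numpy as np

rng = np.random.default_rng(3)

def partial_conv(x, weight, n_conv):
    """Partial Convolution sketch: convolve only the first `n_conv` channels;
    the remaining channels pass through as an identity mapping.
    x: (C, H, W) feature map; weight: (n_conv, n_conv) acting as a 1x1 kernel."""
    out = x.copy()
    n, h, w = n_conv, x.shape[1], x.shape[2]
    out[:n] = (weight @ x[:n].reshape(n, -1)).reshape(n, h, w)
    return out

x = rng.random((16, 8, 8))            # toy feature map (sizes are assumptions)
w = rng.random((4, 4))                # convolve 1/4 of the channels
y = partial_conv(x, w, n_conv=4)
```

Because three quarters of the channels skip the convolution entirely, the multiply-accumulate count drops accordingly while every channel still reaches the subsequent layers.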
Compared to the original YOLOv8, the SCB module introduces the Faster-
Net Block, which addresses the computational inefficiency of YOLOv8. In
YOLOv8, the Bottleneck structure convolves all input channels, leading to high
computational overhead. In contrast, the SCB module reduces this overhead by
applying convolution only to a subset of the channels, while ensuring that the
remaining channels are still retained and processed later. This selective convolu-
tion significantly reduces the computational complexity and number of parame-
ters while maintaining detection accuracy.

Fig. 3. The structure of Bottleneck and FasterNet Block.

3.2 Small Target Scale Sequence Fusion (STSSF)

In YOLOv8, the path aggregation network-feature pyramid network structure is used for feature fusion. However, since the feature pyramid primarily relies
on simple summation and concatenation to fuse feature maps of different scales,
this approach does not fully exploit the correlations between these feature maps.
Furthermore, it heavily depends on small-scale feature maps, which limits its
overall performance [11]. To address this issue, we propose an optimized feature
pyramid method, called Small Target Scale Sequence Fusion (STSSF), as shown
in Fig. 1. This approach introduces the Scale Sequence Fusion (SSF) pyramid
structure, based on ASF-YOLO [11], making it better suited for small object
detection. On top of ASF-YOLO, STSSF also integrates large-scale feature maps,
allowing the model to focus more effectively on small-object-related features.
The Convolution 3D Scale Fusion (C3DSF) module and the Scale Aware
Concatenation (SAC) module are adapted from the fusion modules in ASF-
YOLO. C3DSF treats features of different scales as a sequence and establishes
contextual relationships along the depth dimension, thereby further enhancing
feature fusion. As illustrated in Fig. 4, the C3DSF module first aligns the chan-
nel numbers and dimensions of the different size feature maps with those of
the large-size feature map. It then expands and stacks these feature maps along
the depth dimension, extracting their scale sequence features through 3D con-
volution. Compared to traditional 2D convolution, this design provides more
effective information fusion along the depth dimension between different scale
feature layers, capturing cross-scale correlations and preventing the information
loss or discontinuity issues caused by simple concatenation. In images, small
objects are often densely packed and easily overlooked. To address the problem that the upsampling and fusion mechanism in the traditional pyramid structure may ignore details in the large-size feature maps, SAC processes the large-, medium-, and small-size feature maps at a finer granularity. As shown in Fig. 4, the large-size feature map is downsampled
using a combination of max pooling and average pooling, while the small-size
feature map is upsampled to retain local details. This method strikes a bal-
ance between retaining feature map details and extracting global information,
while also preventing small-object features from being lost during upsampling.
Finally, the large, medium, and small-size feature maps are concatenated along
the channel dimension to form a complete feature map. Unlike traditional sim-
ple concatenation, the SAC module not only adds large-size feature layers but
also processes feature maps at different sizes differently. This design effectively
enhances the model’s ability to capture small-object and detailed information,
thereby improving overall detection performance.

Fig. 4. The structure of C3DSF.
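The SAC fusion described above can be sketched as follows. The equal-weight combination of the max- and average-pooled maps and the nearest-neighbour upsampling are assumptions of this sketch, chosen only to keep it short.

```python
import numpy as np

def pool2x(x, mode):
    """Downsample a (C, H, W) map by 2x2 max or average pooling."""
    c, h, w = x.shape
    blocks = x.reshape(c, h // 2, 2, w // 2, 2)
    return blocks.max(axis=(2, 4)) if mode == "max" else blocks.mean(axis=(2, 4))

def upsample2x(x):
    """Upsample a (C, H, W) map 2x by nearest-neighbour repetition."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def sac_fuse(large, medium, small):
    """SAC sketch: shrink the large map with combined max + average pooling,
    grow the small map to keep its local detail, then concatenate all three
    along the channel dimension."""
    down = 0.5 * pool2x(large, "max") + 0.5 * pool2x(large, "avg")
    return np.concatenate([down, medium, upsample2x(small)], axis=0)

rng = np.random.default_rng(4)
large = rng.random((8, 16, 16))       # pyramid-level sizes are illustrative
medium = rng.random((8, 8, 8))
small = rng.random((8, 4, 4))
fused = sac_fuse(large, medium, small)
```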

Since the targets in the analyzed images are often dense and small, we propose
a novel enhancement method for P2 feature maps in this section, as shown in the
red dashed box in Fig. 1. Different from the traditional way of adding small-scale
feature maps, we delay the upsampling step because the first C3DSF module has
already performed the preliminary fusion of the features in the P3, P4, and P5
layers, which completes the integration of the multi-scale features in advance
and reduces the unnecessary downsampling operation, which in turn reduces the
computational amount. Considering that the addition of the SAC module will
utilize the P1 layer feature maps, which will significantly increase the number
of parameter and computations, we choose to fuse the P2 layer through upsam-
pling and concatenate operation. Additionally, we introduce a second C3DSF
module to fuse the P2 layer with the processed P3 and P4 layer features. With
this design, the features from the P3 and P4 layers are subjected to deeper pro-
cessing after integrating more contextual information, and then merge with the
detailed information from the P2 layer. This enhances the model’s ability to
detect small objects. Compared to the SSF pyramid structure of ASF-YOLO,
this improvement not only retains high-resolution detail information when han-
dling small objects but also integrates deep semantic information, resulting in a
model that is more robust in complex scenarios.

3.3 Shared Convolution Precision Detection (SCPD) Head


The original detection head of YOLOv8 contains classification and regression
tasks, with each detection head having an independent convolution processing
path to generate the final detection results. The Shared Convolution Precision
Detection (SCPD) head is a novel object detection head, as shown in Fig. 5. Compared
YOLO-LiteMax: An Improved Model for UAV Small Object Detection 415

to the independent paths of the YOLOv8 detection head, SCPD significantly


reduces model parameters and improves detection accuracy through shared
convolution [12] and group normalization [26].

Fig. 5. The structure of SCPD. The modules in the red dashed line represent shared
convolution. (Color figure online)

We integrate group normalization into the convolution module, calling it the


CGS module. Group normalization is advantageous for handling high-resolution
and multi-channel feature maps, as it normalizes each group of channels, allowing
for finer adjustment of the feature value distribution. Furthermore, group nor-
malization has been proven in the FCOS [8] paper to improve the localization
and classification performance of detection heads.
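As a concrete illustration, group normalization can be sketched in a few lines of pure Python (omitting the learnable affine parameters that the real layer also carries; each channel is represented as a flat list of values here):

```python
import math

def group_norm(x, num_groups, eps=1e-5):
    """Minimal group normalization sketch.
    x: list of channels, each a list of values. Channels are split into
    num_groups groups; each group is standardized with its own mean/variance."""
    gsize = len(x) // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * gsize:(g + 1) * gsize]
        vals = [v for ch in group for v in ch]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        std = math.sqrt(var + eps)
        for ch in group:
            out.append([(v - mean) / std for v in ch])
    return out
```

Unlike batch normalization, the statistics depend only on the single sample, which is why the technique is stable for the small batch sizes typical of detection training.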
The CGS module performs convolution operations on different scales of fea-
ture maps from the feature pyramid network to further extract and enhance
these features. Subsequently, two CGS modules with shared convolution are
applied, which improves the stability of multi-scale object detection, allows
beneficial information to be shared across scales, and enhances adaptability to
variations in target appearance. Finally, through the decoupled head design, the
classification and regression modules share convolution parameters, further
enhancing feature coherence and detection performance in multi-scale object
detection. Through collaborative learning, the detection heads at different scales
jointly optimize these parameters during training, which effectively improves the
robustness and generalization ability of feature extraction and thus significantly
improves the overall performance of the model.
When using shared convolution, sharing a single set of parameters may cause
inconsistency in the scales of targets detected by the detection head. To address
this issue, we introduce a scaling factor to adjust the features. Experiments show
that this shared convolution detection head design effectively improves detec-
tion accuracy while reducing the parameter count.
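The shared-convolution-plus-scale-factor idea can be sketched as follows; the 1×1-convolution analog, the weight values, and the per-level scalars are illustrative assumptions, not values from the paper:

```python
def conv1x1(channels, weights, bias=0.0):
    """A 1x1 convolution at one spatial location: a dot product over channels."""
    return sum(c * w for c, w in zip(channels, weights)) + bias

# One weight vector reused by every pyramid level (the shared convolution).
shared_weights = [0.5, -0.25, 1.0]

# Learnable per-level scalars compensate for the different box-size ranges
# that each level regresses (the "scaling factor" described above).
scale_factors = {"P3": 1.0, "P4": 2.0, "P5": 4.0}

def regression_head(channels, level):
    raw = conv1x1(channels, shared_weights)   # identical weights at all levels
    return raw * scale_factors[level]         # rescale to the level's box range

feat = [1.0, 2.0, 0.5]   # the same feature vector seen at three levels
outputs = {lvl: regression_head(feat, lvl) for lvl in scale_factors}
```

The design choice is that the expensive branch (the convolution) is learned once, while the cheap per-level scalar restores scale consistency.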

4 Experiments and Results


4.1 Dataset

The VisDrone2019 dataset [9] is a widely used collection of drone aerial imagery,
compiled, annotated, and organized by the AISKYEYE data mining team at
Tianjin University. This dataset spans multiple cities in China and features
416 J. Su et al.

multi-angle, multi-scene, and multi-task imagery captured by various drone mod-


els. VisDrone2019 contains ten object categories, including pedestrians, cars, and
bicycles. It is rich in data, consisting of 6,471 images in the training set, 548 in
the validation set, 1,610 in the test set, and 1,580 in the competition set, with
image resolutions ranging from 2000×1500 pixels to 480×360 pixels. The dataset
is notable for the large number of small objects, which are often densely packed
and unevenly distributed.

4.2 Evaluation Indicators

In the experiments, we use Ubuntu 20.04 as the operating system, with PyTorch
2.2.2 and CUDA 12.1 as the software environment. The hardware configuration
includes an i9-13900KF CPU and an NVIDIA RTX 4090 GPU with 24 GB of
video memory. The input image size is 640 × 640, the number of training epochs
is 400, the batch size is 8, and the initial learning rate is 0.01. The evaluation
metrics include Precision, Recall, mAP, and the parameter count. Precision is a
key metric for evaluating the
performance of classification models, primarily used to measure how accurately
the model predicts positive samples. The formula is as follows.
\[ \text{Precision} = \frac{TP}{TP + FP} \times 100\% \tag{1} \]
where $TP$ (true positives) is the number of samples correctly predicted as
positive, and $FP$ (false positives) is the number of samples incorrectly
predicted as positive.
Recall measures a model’s ability to correctly identify positive samples. It
ranges from 0 to 1, with a higher value indicating that the model is better at
identifying positive samples. The formula is as follows.
\[ \text{Recall} = \frac{TP}{TP + FN} \times 100\% \tag{2} \]
where $FN$ (false negatives) represents the number of samples that are actually
positive but incorrectly predicted as negative.
Mean Average Precision (mAP) is a key metric for evaluating model perfor-
mance in object detection and information retrieval tasks. mAP considers both
precision and recall across multiple categories, providing a more comprehensive
assessment of the model’s performance. A higher mAP value indicates better
detection performance. The formula is as follows.
\[ AP = \sum_{n=1}^{N} (R_n - R_{n-1}) P_n \tag{3} \]

where $P_n$ and $R_n$ are the precision and recall at the $n$-th recall threshold,
respectively.

\[ \text{mAP} = \frac{1}{C} \sum_{i=1}^{C} AP_i \tag{4} \]
where $C$ is the total number of classes and $AP_i$ is the average precision of
class $i$.
The parameter count is an important metric for evaluating the complexity of
a deep learning model, representing the total number of learnable parameters.
It typically reflects the model's capacity, influencing its expressiveness and the
computational resources required for training.
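The four metrics above can be sketched directly from Eqs. (1)–(4). This is a minimal reference implementation, assuming the AP curve is supplied as (recall, precision) points with recall in ascending order:

```python
def precision(tp, fp):
    """Eq. (1): share of predicted positives that are correct, in percent."""
    return tp / (tp + fp) * 100

def recall(tp, fn):
    """Eq. (2): share of actual positives that are found, in percent."""
    return tp / (tp + fn) * 100

def average_precision(curve):
    """Eq. (3): AP over a list of (recall, precision) points, recall ascending."""
    ap, prev_r = 0.0, 0.0
    for r, p in curve:
        ap += (r - prev_r) * p
        prev_r = r
    return ap

def mean_ap(per_class_ap):
    """Eq. (4): mAP as the mean of the per-class AP values."""
    return sum(per_class_ap) / len(per_class_ap)
```

For example, a detector that is perfect up to 50% recall and then drops to 0.5 precision yields an AP of 0.75 under Eq. (3).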

4.3 Comparison Experiments

In this study, several enhancements are made to the original YOLOv8 model
to improve both the precision and efficiency of object detection. To verify the
effectiveness of these improvements, we design two sets of comparative exper-
iments. The first set compares the enhanced model with several YOLO series
algorithms, while the second set compares it with other high-performance mod-
els. The evaluation metrics include precision, recall, mAP, parameter count,
and model size.
Compared to YOLOv5s, YOLOv8s, and YOLOv10s [25], YOLO-LiteMax
improves mAP50 by 12%, 5.9%, and 4.1%, respectively, and enhances mAP50-95
by 9.3%, 3.8%, and 2.6%, while maintaining a smaller parameter count and model
size. Additionally, compared to the latest YOLOv11s algorithm, the improved
model achieves an increase of 4.9% in mAP50 and 3.1% in mAP50-95. These
results demonstrate that YOLO-LiteMax delivers the best overall performance
for detecting densely packed small objects. Despite a slight increase in inference
time, the improved model still achieves real-time detection capabilities (Table 1).

Table 1. Detection results of some YOLO series models and the proposed model. (The
bold data in the table indicate the best results.)

| Models | Precision/% | Recall/% | mAP50/% | mAP50-95/% | Params/M | Model Size/MB |
|---|---|---|---|---|---|---|
| YOLOv5s | 44.8 | 33.9 | 33.2 | 18 | 7.0 | 13.7 |
| YOLOv5m | 48.6 | 36.4 | 36.3 | 20.4 | 20.9 | 40.3 |
| YOLOv8s | 50.9 | 38.2 | 39.3 | 23.5 | 11.1 | 21.5 |
| YOLOv8m | 55.7 | 42.4 | 44 | 26.7 | 25.9 | 49.7 |
| YOLOv10s | 52.5 | 39.8 | 41.1 | 24.7 | 7.2 | 15.8 |
| YOLOv10m | 55.1 | 42.5 | 44.4 | 27.2 | 15.3 | 32.0 |
| YOLOv11s | 49.9 | 39.6 | 40.3 | 24.2 | 9.4 | 18.3 |
| Proposed | 56.3 | 42.7 | 45.2 | 27.3 | 6.1 | 12.2 |

As shown in the comparative experimental results in Table 2, YOLO-LiteMax


outperforms other advanced models in terms of detection performance. Faster R-
CNN [20] struggles to detect small objects due to network downsampling, which
reduces the resolution of feature maps. Additionally, the Region Proposal Net-
work has difficulty generating candidate regions that encompass small objects,
and the limited information from these small targets leads to suboptimal detec-
tion accuracy. Cascade R-CNN [3] employs higher IoU thresholds at later stages,
making it challenging for small objects to meet these criteria, resulting in their
exclusion and negatively affecting detection performance. RetinaNet [21] faces
challenges such as mismatched anchor box sizes, insufficient feature refinement,
imbalance between positive and negative samples, and inadequate contextual
information for small objects, making it difficult to accurately capture their fea-
tures. CenterNet [10] faces challenges in detecting densely packed small objects,
often leading to confusion among center points. Additionally, the lack of special-
ized processing for small targets and amplified regression errors make it difficult
to accurately detect and differentiate these objects.

Table 2. Detection results of the classical models and the proposed model. (The bold
data in the table indicate the best results.)

| Models | mAP50/% | mAP50-95/% | Params/M |
|---|---|---|---|
| Faster R-CNN | 37.2 | 21.9 | 41.39 |
| Cascade R-CNN | 39.1 | 24.3 | 68.95 |
| RetinaNet | 19.1 | 10.6 | 35.68 |
| CenterNet | 33.7 | 18.8 | 70.75 |
| Proposed | 45.2 | 27.3 | 6.1 |

Combining the results from both comparative experiments, the proposed


method demonstrates superior detection performance compared to other mod-
els. YOLO-LiteMax, in particular, shows significant advantages in small object
detection over the comparison models. Moreover, by introducing the SCB and
SCPD modules into the baseline model, the proposed method not only limits
resource consumption but also enhances detection performance.

4.4 Ablation Experiments

Ablation studies are commonly used in object detection to rigorously evaluate


the impact of different components of a deep learning model on detection perfor-
mance. To verify the effectiveness of the proposed modules in this study, ablation
experiments are conducted using the VisDrone2019 dataset. The specific exper-
imental results are shown in Table 3. Table 3 shows the mAP50 and parameter
count changes for each proposed improvement in the baseline model.

The results in Table 3 show that each proposed improvement contributes to


the baseline model’s detection performance to varying degrees. Introducing the
SCB module to replace the CSP module allows SCB to process only part of
the feature map channels, effectively reducing redundant information. Due to its
simpler structure, the SCB module reduces the model parameters by 2.8 million,
decreases the model size by 5.4 MB, and improves mAP50 by 0.2%. Building
on this, the inclusion of the STSSF structure in the neck network for feature
fusion further increases mAP50 by 4.7%. This adjustment effectively fuses small-
scale features with richer detail information, achieving a more comprehensive
integration of shallow and deep features, which significantly reduces the miss
rate for small objects. Finally, replacing the original detection head with the new
SCPD head enhances the model’s adaptability to targets, increasing mAP50 by
an additional 1% while also reducing the parameter count by 0.7 M.
Overall, the improved model achieves an average detection precision increase
of 5.9%, with significant optimizations observed across most detection metrics,
all while maintaining a lower parameter count and smaller model size compared
to the baseline. However, while adding small-scale feature layers improves detec-
tion performance, it also makes the model structure more complex, leading to
increased inference time.

Table 3. Detection results after the introduction of different improvement strategies.
(The bold data in the table indicate the best results.)

| Baseline | SCB | STSSF | SCPD | Precision/% | Recall/% | mAP50/% | mAP50-95/% | Params/M | Model Size/MB |
|---|---|---|---|---|---|---|---|---|---|
| YOLOv8s | | | | 50.9 | 38.2 | 39.3 | 23.5 | 11.1 | 21.5 |
| YOLOv8s | ✓ | | | 50.5 | 38.3 | 39.5 | 23.4 | 8.3 | 16.1 |
| YOLOv8s | ✓ | ✓ | | 55.8 | 41.7 | 44.2 | 26.5 | 6.9 | 13.6 |
| YOLOv8s | ✓ | ✓ | ✓ | 56.3 | 42.7 | 45.2 | 27.3 | 6.1 | 12.2 |

4.5 Visualization Analysis


To visually demonstrate the detection effectiveness of our method, inference
experiments are conducted to compare the performance of YOLO-LiteMax,
YOLOv8s, YOLOv10s, and YOLOv11s. We select two representative scenar-
ios as experimental data: public facilities and traffic junctions. These scenarios
contain a large number of small objects, making them well-suited for inference
experiments.

In Fig. 6, the white box at the top of the image contains several motorcycles.
None of the YOLOv8s, YOLOv10s, or YOLOv11s models detects the motor-
cycles, whereas YOLO-LiteMax identifies most of the targets. Additionally,
on the left side of the image, there is an adult and a child; YOLOv8s,
YOLOv10s, and YOLOv11s recognize only the adult, whereas YOLO-LiteMax
accurately identifies both.

Fig. 6. Public facilities. (a) Results of YOLOv8s; (b) results of YOLOv10s; (c) results
of YOLOv11s; (d) results of YOLO-LiteMax.

Figure 7 shows a traffic junction where four cyclists are marked with a white
box on the crosswalk. YOLOv8s and YOLOv11s detect only one individual,
YOLOv10s fails to detect any targets, while YOLO-LiteMax successfully detects
all four cyclists. Additionally, the traffic sign highlighted in the white box in
the center of the image is mistakenly identified as a car by YOLOv10s and
YOLOv11s, while both YOLOv8s and YOLO-LiteMax correctly avoid this mis-
classification.

Fig. 7. Traffic intersection. (a) Results of YOLOv8s; (b) results of YOLOv10s; (c)
results of YOLOv11s; (d) results of YOLO-LiteMax.

Disclosure of Interests. The authors have no competing interests to declare that


are relevant to the content of this article.

References
1. Albaba, B.M., Ozer, S.: SyNet: an ensemble network for object detection in UAV
images. In: 2020 25th International Conference on Pattern Recognition (ICPR),
pp. 10227–10234. IEEE (2021)
2. Burgués, J., Marco, S.: Environmental chemical sensing using small drones: a
review. Sci. Total Environ. 748, 141172 (2020)
3. Cai, Z., Vasconcelos, N.: Cascade R-CNN: delving into high quality object detec-
tion. In: Proceedings of the IEEE Conference on Computer Vision and Pattern
Recognition, pp. 6154–6162 (2018)
4. Chang, Y.C., Chen, H.T., Chuang, J.H., Liao, I.C.: Pedestrian detection in aerial
images using vanishing point transformation and deep learning. In: 2018 25th IEEE
International Conference on Image Processing (ICIP), pp. 1917–1921. IEEE (2018)
5. Chen, J., et al.: Run, don’t walk: chasing higher FLOPS for faster neural networks.
In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pp. 12021–12031 (2023)
6. Daud, S., et al.: Applications of drone in disaster management: a scoping review.
Sci. Justice 62(1), 30–42 (2022)
7. Deng, J., Shi, Z., Zhuo, C.: Energy-efficient real-time UAV object detection on
embedded platforms. IEEE Trans. Comput. Aided Des. Integr. Circuits Syst.
39(10), 3123–3127 (2019)

8. Tian, Z., Shen, C., Chen, H., He, T.: FCOS: a simple and strong anchor-free object
detector. IEEE Trans. Pattern Anal. Mach. Intell. 44(4) (2022)
9. Du, D., et al.: VisDrone-DET2019: the vision meets drone object detection in image
challenge results. In: Proceedings of the IEEE/CVF International Conference on
Computer Vision Workshops, pp. 0–0 (2019)
10. Duan, K., Bai, S., Xie, L., Qi, H., Huang, Q., Tian, Q.: CenterNet: keypoint triplets
for object detection. In: Proceedings of the IEEE/CVF International Conference
on Computer Vision, pp. 6569–6578 (2019)
11. Kang, M., Ting, C.M., Ting, F.F., Phan, R.: ASF-YOLO: a novel YOLO model
with attentional scale sequence fusion for cell instance segmentation. Image Vis.
Comput. 147, 105057 (2024)
12. Li, Y., Chen, Y., Wang, N., Zhang, Z.: Scale-aware trident networks for object
detection. In: Proceedings of the IEEE/CVF International Conference on Com-
puter Vision, pp. 6054–6063 (2019)
13. Li, Y., Liu, M., Jiang, D.: Application of unmanned aerial vehicles in logistics: a
literature review. Sustainability 14(21), 14473 (2022)
14. Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature
pyramid networks for object detection. In: Proceedings of the IEEE Conference on
Computer Vision and Pattern Recognition, pp. 2117–2125 (2017)
15. Liu, W., et al.: SSD: single shot MultiBox detector. In: Leibe, B., Matas, J., Sebe,
N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 21–37. Springer, Cham
(2016). https://doi.org/10.1007/978-3-319-46448-0_2
16. Lo, L.Y., Yiu, C.H., Tang, Y., Yang, A.S., Li, B., Wen, C.Y.: Dynamic object
tracking on autonomous UAV system for surveillance applications. Sensors 21(23),
7888 (2021)
17. Lou, H., et al.: DC-YOLOv8: small-size object detection algorithm based on camera
sensor. Electronics 12(10), 2323 (2023)
18. Maes, W.H., Steppe, K.: Perspectives for remote sensing with unmanned aerial
vehicles in precision agriculture. Trends Plant Sci. 24(2), 152–164 (2019)
19. Redmon, J., Divvala, S., Girshick, R., Farhadi, A.: You only look once: unified,
real-time object detection. In: Proceedings of the IEEE Conference on Computer
Vision and Pattern Recognition, pp. 779–788 (2016)
20. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object
detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell.
39(6), 1137–1149 (2016)
21. Lin, T.Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object
detection. In: Proceedings of the IEEE International Conference on Computer
Vision, pp. 2980–2988 (2017)
22. Sahin, O., Ozer, S.: YOLODrone: improved YOLO architecture for object detection
in drone images. In: 2021 44th International Conference on Telecommunications and
Signal Processing (TSP), pp. 361–365. IEEE (2021)
23. Sun, W., Dai, L., Zhang, X., Chang, P., He, X.: RSOD: real-time small object
detection algorithm in UAV-based traffic monitoring. Appl. Intell., 1–16 (2022)
24. Tang, G., Ni, J., Zhao, Y., Gu, Y., Cao, W.: A survey of object detection for UAVs
based on deep learning. Remote Sens. 16(1), 149 (2023)
25. Wang, A., et al.: YOLOv10: real-time end-to-end object detection. arXiv preprint:
arXiv:2405.14458 (2024)
26. Wu, Y., He, K.: Group normalization. In: Proceedings of the European Conference
on Computer Vision (ECCV), pp. 3–19 (2018)
27. Zhao, L., Zhu, M.: MS-YOLOv7: YOLOv7 based on multi-scale for object detection
on UAV aerial photography. Drones 7(3), 188 (2023)
LMCF-FS: A Novel Lightweight Malware
Classification Framework Driven
by Feature Selection

Cui Yun1 , Lei Zhou2,3 , Shuangshuang Xing2,3 , Ning Yang2,3 , Pan Zhao2,3 ,
and Zhiguo Chen2,3(B)
1 School of Computer, Jiangsu University of Science and Technology,
Zhenjiang 212100, China
[email protected]
2 Engineering Research Center of Digital Forensics, Ministry of Education, Nanjing
University of Information Science and Technology, Nanjing 210044, China
{zhoulei,shuangshuangxing,yangning,zhaopan}@nuist.edu.cn
3 School of Computer Science and School of Cyber Science and Engineering, Nanjing
University of Information Science and Technology, Nanjing 210044, China
[email protected]

Abstract. The rapid increase in the number of malware and its variants
poses a significant threat to internet security. While existing machine
learning-based malware classification methods can improve accuracy,
they often require complex feature engineering. In contrast, converting
executable files into images and classifying them using deep learning
models can reduce the reliance on prior knowledge of malware features
and simplify the feature processing. However, most studies have not ade-
quately considered the impact of image pixel features and size on classi-
fication results. In addition, the model structure of deep learning classi-
fication methods based on pre-trained parameters and transfer learning
is usually more complex, resulting in the classification model requiring
more memory and time overhead. To solve the above problems, this paper
proposes a novel lightweight malware classification framework LMCF-FS
driven by feature selection. This framework uses a proposed new fea-
ture selection method based on pixel pair features to improve feature
sparsity and optimize image texture features, thereby improving feature
expression capabilities. At the same time, an interpolation algorithm is
used to balance the image size. In addition, a new lightweight malware
classification model ConvInceptionNet is built based on the Inception-C
module to balance accuracy, computational cost, and number of param-
eters, thereby improving malware classification efficiency. Our proposed
LMCF-FS framework achieves a classification accuracy of 99.12% on the
Microsoft Malware Classification Challenge Dataset (BIG2015 Dataset),
and it takes only 0.92 ms to predict a 224 × 224 malware image, verifying the
effectiveness of LMCF-FS in malware classification.

Keywords: Malware Classification · Feature Selection · Lightweight


Model · Image Enhancement · Deep Learning
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 423–437, 2025.
https://doi.org/10.1007/978-981-96-4566-4_29
424 C. Yun et al.

1 Introduction
The widespread application of 5G networks, Internet of Things technology,
and cloud computing has brought great convenience to people's lives. However,
this convenience is accompanied by the growing threat of malware, which brings
security risks such as data leakage, malicious extortion, and
privacy theft to user terminals. According to statistics, the number of mal-
ware incidents has risen annually over the past decade, with over 1.35 billion
pieces detected by 2024 [1]. This phenomenon shows that cybercriminals are
constantly exploring new means and techniques to evade detection by security
measures. Faced with this challenge, security researchers are working to research
and develop more effective defenses against malware attacks.
Traditional malware static analysis technology uses machine learning meth-
ods to automatically learn malware characteristics and behavior patterns
through training models to achieve automated classification [7], thereby reducing
classification time and improving classification performance. Although machine
learning methods have brought certain improvements to malware classification,
traditional machine learning algorithms such as logistic regression, support vec-
tor machines, and decision trees [26] still require professional knowledge related
to the field of malware classification to carry out relevant feature extraction and
selection. This means that once malware developers learn the signatures used by
detection systems, they can evade detection relatively easily.
Compared with machine learning algorithms, deep learning technology, with
its multi-layered neural network structure, can automatically learn deep-level
distinguishing features in original binary files or code data without the need
for manual design by security researchers [4,28]. Therefore, deep learning tech-
nology is better suited to dealing with the challenges of new malware and its
variants, and can effectively improve classification accuracy and generalization
capabilities. Nevertheless, for malware that uses obfuscation technology, whether
its unpacking or decryption is successful will directly affect the performance of
the classification model. In order to avoid the impact of obfuscation technol-
ogy, many scholars explore visualizing executable programs as images and apply
deep learning models for malware classification [8,16,27]. However, there are
some issues with this method of converting malware into images. For example,
the impact of image pixel characteristics and size on classification results is not
fully considered during conversion, which may result in reduced classification
accuracy. In addition, some studies use methods such as pre-training parameters
and transfer learning to improve classification accuracy and model performance,
but these methods often make the model large, complex, and inflexible.
To mitigate the impact of obfuscation techniques on static malware clas-
sification research, while reducing model complexity and enhancing the accu-
racy and performance of malware classification tasks, we propose a lightweight
malware classification framework driven by feature selection, called LMCF-FS.
This framework introduces a novel feature selection method to optimize and
enhance image texture features, thereby improving the model’s classification per-
formance. On this basis, we design a lightweight model based on the Inception-C
module to further enhance the accuracy and speed of malware family sample
classification.
Summarizing the above, the contributions of this paper are as follows:

– We propose a new feature selection method based on pixel pair features,


which extracts key pixel pair features through probabilistic sorting to improve
feature sparsity and optimize texture features of malware images. At the same
time, the bicubic interpolation algorithm is used to enhance the malware
image after feature selection to solve the problem of image size imbalance
and improve classification accuracy.
– We propose a lightweight malware classification model ConvInceptionNet
based on the Inception-C module. This model has good feature extraction
capabilities, a short training time, and a small number of parameters, and is
suitable for resource-constrained environments.

2 Related Work
2.1 Malware Classification Based on Feature Selection

Feature selection methods, by extracting key malicious behavior features from


malware binary files and source code, can reduce redundant and irrelevant fea-
tures, thereby improving the classification accuracy and computational efficiency
of the model. For example, Kong et al. [12] effectively simplified opcode features
in the dataset by removing redundancies and forming 18 feature combinations.
The malware classification scheme based on this approach achieved a classifica-
tion accuracy of 98.6%. Additionally, Ni et al. [17] utilized principal block selec-
tion and bilinear interpolation techniques to convert malware code into SimHash
grayscale images, preserving key static features of the malware, resulting in an
accuracy of 98.862% with a classification speed of only 1.41 s. Darem et al. [5]
optimized core features such as opcode and pixel count through feature selection
and converted them into images, achieving an accuracy of 99.12%.
Although feature selection methods can effectively identify the most discrim-
inative features, reducing time costs, simplifying model structures, and enhanc-
ing detection efficiency, the complex structure of malware binary files and source
code poses challenges. Feature selection may overlook the intricate interactions
between structural and content features, potentially impacting overall classifica-
tion performance.

2.2 Malware Classification Based on Deep Learning

Deep learning methods possess the advantages of autonomously learning features


and handling complex nonlinear relationships in malware detection, allowing for
the effective capture of intricate interactions between features. By autonomously
learning the correlations between the structural and content features of malware,

deep learning can address the complex feature relationships posed by different
types of malware and their variants, exhibiting strong generalization capabil-
ities. For instance, Gibert et al. [6] achieved classification results comparable
to gradient boosting methods across multiple malware datasets by designing
a multimodal network model that combines various feature types, including
images and text. Vasan et al. [24] attained accuracies of 98.82% and 97.35%
on multiple datasets by converting raw binary files into color images and uti-
lizing a fine-tuned CNN architecture for classification. Additionally, Awan et
al. [3] effectively enhanced the model’s ability to extract malware image fea-
tures by combining a spatial attention mechanism with a convolutional neural
network (SACNN), achieving a classification accuracy of 97.42%. Kumar et al.
[13] applied deep learning to Internet of Things (IoT) malware classification
by integrating traditional convolutional neural networks with transfer learning,
attaining classification accuracies of 99.18% and 98.63% across multiple datasets
through fine-tuning the CNN. Additionally, Mallik et al. [15] combined convo-
lutional networks with recurrent networks to capture the temporal features of
malware, achieving a classification accuracy of 98.36% on the MalImg dataset,
showcasing the powerful capability of deep learning in handling malware behav-
ior patterns.
These studies demonstrate that deep learning models can autonomously learn
and construct hierarchical feature representations from raw data, thereby cap-
turing subtle differences in malware behavior for precise classification. However,
these models rely on pre-trained parameters and transfer learning, which signif-
icantly increase computational resource demands and complicate the architec-
ture, thus limiting their applicability in resource-constrained environments.

2.3 Malware Classification Based on Lightweight Model


To address the high computational resource demands of complex malware clas-
sification models, lightweight models optimize network structures and reduce
model parameters, thereby lowering model complexity while enhancing real-
time efficiency. This effectively mitigates the limitations of complex models that
are challenging to deploy in resource-constrained environments, suffer from slow
inference speeds, and incur significant computational overhead. For example,
Su et al. [20] achieved a classification accuracy of 94.0% by converting mal-
ware binary files into grayscale images using a lightweight convolutional model.
Additionally, the lightweight model proposed by Zou et al. [29] combined coor-
dinate attention with depthwise separable convolutions, achieving classification
accuracies of 99.785% and 98.942% across multiple datasets.
Although lightweight models may perform slightly worse than large deep-
learning models when handling complex malware and its variants, they sig-
nificantly reduce resource consumption and operate efficiently on resource-
constrained devices. Lightweight models offer advantages such as ease of deploy-
ment, flexible integration, and strong real-time performance, making them par-
ticularly suitable for resource-limited environments like mobile devices and edge
computing nodes, showcasing high practicality.

Fig. 1. LMCF-FS: The proposed malware classification framework

3 Methodology
3.1 Overview
The malware classification framework LMCF-FS proposed in this paper is shown
in Fig. 1 and is divided into four parts: image visualization, feature selection,
image enhancement, and model construction. The image visualization module
converts the malware’s Byte file into an RGB image. The feature selection mod-
ule sorts the pixel pairs of the RGB images by probability to extract key pixel
pair features and optimize the texture features of the malware images. The
image enhancement module uses the bicubic interpolation algorithm to equalize
the image size. The model-building module combines a convolutional neural
network with the Inception-C module to build a lightweight malware classi-
fication model, ConvInceptionNet, which improves the efficiency of malware
family classification.

3.2 Malware Visualization and Feature Selection


To avoid interference from obfuscation techniques on byte files, we use the
B2IMG method [23] to convert malware samples of the BIG2015 dataset into
RGB images. By analyzing the shared texture features within the same malware
family, we can effectively differentiate between various malware families using
image-based methods. According to the B2IMG algorithm, we first convert the
space-separated hexadecimal values in each row into decimal numbers. Then, we
remove the meaningless symbols and row numbers and load the decimal values
into a one-dimensional array. Next, the one-dimensional array is converted into
a three-dimensional matrix according to the image size conversion formula in
B2IMG. Finally, the three-dimensional matrix is converted into an RGB image.
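The conversion steps above can be sketched as follows. This is an illustrative simplification: the exact image-size formula of B2IMG [23] is not reproduced here, and the function name, square-side rule, and zero-padding are our own assumptions.

```python
import numpy as np

def bytes_to_rgb(byte_file_text: str) -> np.ndarray:
    """Sketch of a B2IMG-style conversion: hex dump -> 1-D array -> RGB image.
    The square-side rule and zero padding below are simplifying assumptions."""
    values = []
    for line in byte_file_text.splitlines():
        tokens = line.split()
        # Drop the leading address/row number; keep the hex byte values.
        for tok in tokens[1:]:
            # '??' marks unreadable bytes in BIG2015 dumps; map them to 0.
            values.append(0 if tok == "??" else int(tok, 16))
    arr = np.asarray(values, dtype=np.uint8)

    # Convert the 1-D array into a 3-D (side x side x 3) matrix.
    side = max(1, int(np.sqrt(arr.size / 3)))
    need = side * side * 3
    if arr.size < need:                       # zero-pad to fill the last row
        arr = np.concatenate([arr, np.zeros(need - arr.size, dtype=np.uint8)])
    return arr[:need].reshape(side, side, 3)  # interpret as an RGB image
```

For example, a short two-line dump yields a small square image whose first channel values are the first bytes of the file.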
However, the conversion process did not fully account for the impact of image
pixel features and dimensions on classification results. To address this, we pro-
pose a feature selection method based on pixel pair features to reduce redundant
features in RGB images, enhance feature sparsity, and optimize image texture
characteristics. This approach aims to improve feature representation and, con-
sequently, enhance malware classification efficiency. As shown in Fig. 2, first, a
256 × 256 matrix C is generated, where each element is a random value between
0 and 255, representing the color map [18]. Second, the pixel sequence (1-D
vector) of the RGB image is obtained, and each pair of consecutive elements in
the pixel sequence is represented as a pixel pair C[x][y] (a total of 256 × 256
pixel pair combinations). Then, P[x][y] denotes the ranking of these pixel pairs'
occurrence probability from high to low. Finally, the top-ranked pixel pair
features in P[x][y] are selected, where the number of selected features N is
given by Eq. 1:

N = (256 × 256) / 2^n    (1)

Among them, n = {0, 1, 2, . . . , 7}.
Through the above feature selection process, the pixel values of the selected
pixel pairs remain unchanged, while those of the unselected pixel pairs are
replaced with 0. The image after feature selection is shown in Fig. 3. This
process optimizes image texture features, improves feature sparsity, and
enhances the discrimination between malware classes, thereby improving the
effectiveness of malware classification algorithms.
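A minimal sketch of this selection step is given below. The function name and the non-overlapping pairing of consecutive bytes are our assumptions; the paper does not state whether pixel pairs overlap.

```python
import numpy as np

def select_pixel_pair_features(pixels: np.ndarray, n: int) -> np.ndarray:
    """Keep the N = (256*256) / 2**n most frequent pixel pairs (Eq. 1) and
    replace the values of all other pairs with 0. `pixels` is the flattened
    byte sequence of the image (even length assumed)."""
    pairs = pixels.reshape(-1, 2)                 # consecutive pixel pairs
    counts = np.zeros((256, 256), dtype=np.int64)
    np.add.at(counts, (pairs[:, 0], pairs[:, 1]), 1)

    # P[x][y]: rank pairs by occurrence probability, highest first.
    top_n = (256 * 256) >> n                      # n in {0, ..., 7}
    keep = np.zeros(256 * 256, dtype=bool)
    keep[np.argsort(counts, axis=None)[::-1][:top_n]] = True
    keep = keep.reshape(256, 256)

    # Selected pairs keep their values; unselected pairs become 0.
    out = pairs.copy()
    out[~keep[pairs[:, 0], pairs[:, 1]]] = 0
    return out.reshape(-1)
```

With n = 0 every pair is kept and the image is unchanged; as n grows toward 7, only the 512 most frequent pairs survive and the rest of the image is zeroed, sparsifying its texture.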
Due to differences in the content length of malware files, the size of the
generated RGB images is not uniform. To solve this problem, we use the bicubic
interpolation algorithm for image enhancement and adjust all images to the
standard size of 224 × 224 to improve the efficiency of malware classification.
Bicubic interpolation can preserve the details of the original image as much as
possible and produce smoother edges [11]. As shown in Fig. 4, this algorithm
uses the gray value of 16 points around the sampling point to perform cubic
interpolation, while taking into account the influence of the gray value of the
adjacent 4 points and the influence of the change rate of the gray value. The
formula for calculating the pixel value is as follows:

Fig. 2. The proposed feature selection method



W(x) = (a + 2)|x|^3 − (a + 3)|x|^2 + 1,   for |x| ≤ 1
W(x) = a|x|^3 − 5a|x|^2 + 8a|x| − 4a,     for 1 < |x| < 2        (2)
W(x) = 0,                                  otherwise

f(x, y) = Σ_{i=0}^{3} Σ_{j=0}^{3} f(x_i, y_j) W(x − x_i) W(y − y_j)        (3)

Here, x in Eq. 2 denotes the distance from the interpolated point to a sample
point, and (x_i, y_j) are the 16 nearest sample points around the pixel (x, y).
The parameter a generally takes −1 or −0.5. For the interpolated pixel point
(x, y) (where x and y can be floating-point numbers), the nearby 4 × 4 points
are selected for weighted summation according to Eq. 3.
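Eq. 2 and a 1-D slice of Eq. 3 can be implemented directly; the sketch below takes a = −0.5 as in Keys [11], and the border-clamping rule is our assumption.

```python
import numpy as np

A = -0.5  # free parameter a of Eq. 2; -0.5 is the common choice (Keys [11])

def W(x: float) -> float:
    """Bicubic convolution kernel of Eq. 2."""
    x = abs(x)
    if x <= 1:
        return (A + 2) * x**3 - (A + 3) * x**2 + 1
    if x < 2:
        return A * x**3 - 5 * A * x**2 + 8 * A * x - 4 * A
    return 0.0

def bicubic_1d(samples, x: float) -> float:
    """One axis of Eq. 3: a weighted sum over the 4 nearest sample points.
    The 2-D case applies the same weights along both axes (4 x 4 = 16 points)."""
    i0 = int(np.floor(x))
    total = 0.0
    for i in range(i0 - 1, i0 + 3):
        ii = min(max(i, 0), len(samples) - 1)   # clamp at the image border
        total += samples[ii] * W(x - i)
    return total
```

Interpolating at an existing sample point returns that sample exactly (W(0) = 1 while W(±1) and W(±2) vanish), and linear ramps are reproduced exactly, which is why bicubic resizing preserves smooth gradients when equalizing images to 224 × 224.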

3.3 ConvInceptionNet

The Inception model stands out in the field of deep learning for its powerful
feature extraction capabilities and efficient parallel computing capabilities, espe-
cially in image recognition tasks [21]. This model processes input data in parallel
through multiple convolution kernels of different scales and can capture differ-
ent levels of feature information, thus improving the accuracy of classification.
Therefore, we draw on its efficient feature extraction mechanism and combine
it with a lightweight network structure to design a lightweight network model
ConvInceptionNet to achieve high efficiency and accuracy in malware classifica-
tion and adapt to device resource-limited environments and mobile applications.
The model architecture is shown in Fig. 5.
Fig. 3. The malware images after feature selection, where N is the number of features

Fig. 4. Position of the last 16 pixels of point P.

Firstly, image features are input to three convolutional layers: the first
convolutional layer has 64 filters of size 5 × 5 with a stride of 2 × 2; the
second convolutional layer has 192 filters of size 1 × 1 with a stride of
2 × 2, followed by an activation function and a 3 × 3 max pooling layer with
a stride of 2 × 2; and the third convolutional layer has 256 filters of size
3 × 3 with a stride of 2 × 2. Secondly, two Inception-C modules, pooling layers, and specific
feature selection methods are used to improve the efficiency and performance
of the model. Each Inception-C module consists of multiple branches, includ-
ing 1 × 1 convolution, 1 × 3 convolution, 3 × 1 convolution, and average pooling.
Each branch produces 256 output channels (filters). Finally, the FC layer uses
a drop rate of 0.1. Our proposed model can learn features at different scales by
using multiple convolution kernels of different sizes and structures to process
input feature maps in parallel, thereby obtaining richer feature expressions and
achieving accurate classification of malware images.
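The stem described above can be sanity-checked with standard convolution arithmetic. The padding values below are our assumption (the paper does not state its padding scheme), so the spatial sizes are illustrative.

```python
def conv_out(size: int, kernel: int, stride: int, pad: int) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * pad - kernel) // stride + 1

def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights plus biases of a k x k convolution layer."""
    return k * k * c_in * c_out + c_out

s = 224
s = conv_out(s, 5, 2, 2)   # conv1: 64 filters, 5x5, stride 2   -> 112
s = conv_out(s, 1, 2, 0)   # conv2: 192 filters, 1x1, stride 2  -> 56
s = conv_out(s, 3, 2, 1)   # 3x3 max pooling, stride 2          -> 28
s = conv_out(s, 3, 2, 1)   # conv3: 256 filters, 3x3, stride 2  -> 14

stem_params = (conv_params(5, 3, 64)        # 4,864
               + conv_params(1, 64, 192)    # 12,480
               + conv_params(3, 192, 256))  # 442,624
```

Under these assumptions the three stem convolutions account for roughly 460 K of the 1,268,312 total parameters reported in Table 2; the remainder lies in the Inception-C modules and the FC layer.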
LMCF-FS 431

Fig. 5. The proposed ConvInceptionNet network

4 Experimental Evaluation
4.1 Data Set

To validate and evaluate the effectiveness of the LMCF-FS malware classification
scheme, we used the BIG2015 [19] dataset. This dataset was provided by
Microsoft during the 2015 Kaggle Malware Classification Challenge and has been
widely applied in the field of static malware analysis since 2016, becoming one
of the standard datasets in this domain. The BIG2015 dataset contains over
20,000 assembly and bytecode files, representing 9 malware families. The exper-
imental part of this paper only uses 10,868 malicious samples in its training set.
The BIG2015 dataset provides binary and assembly file formats of malicious
samples. This paper uses the binary files in this dataset for system analysis to
fully mine the feature information related to malware, thereby improving the
classification performance.

Table 1. The comparison of experimental results based on the number of features of
different pixel pairs. For each image type, the columns are Accuracy, Precision,
Recall, and F1-Score.

Number of features  Raw Image (Acc / Prec / Rec / F1)     224 × 224 Image (Acc / Prec / Rec / F1)
65536               98.43% / 98.46% / 98.34% / 98.40%     98.62% / 98.45% / 98.57% / 98.51%
32768               98.53% / 98.75% / 98.31% / 98.53%     98.76% / 98.81% / 98.61% / 98.71%
16384               98.67% / 98.40% / 98.67% / 98.53%     98.87% / 98.79% / 98.83% / 98.83%
8192                98.71% / 98.70% / 98.65% / 98.68%     98.85% / 98.84% / 98.82% / 98.39%
4096                98.76% / 98.48% / 97.35% / 97.91%     98.77% / 98.79% / 98.78% / 98.38%
2048                98.85% / 99.00% / 98.83% / 98.91%     98.97% / 98.97% / 98.97% / 98.96%
1024                98.89% / 98.75% / 98.76% / 98.76%     99.12% / 99.24% / 99.01% / 99.12%
512                 98.43% / 98.37% / 98.36% / 98.36%     98.71% / 98.63% / 98.58% / 98.61%

4.2 Ablation Experiment on the Effectiveness of Feature Selection and Image Enhancement

To reduce redundant features of RGB images and optimize image texture features,
we use the method described in Sect. 3.2 to perform feature selection on the
RGB images of the BIG2015 dataset. From the experimental results in Table 1, it
can be concluded that when feature selection is performed on the original image,
the accuracy increases as the number of features decreases. In particular, when
the number of pixel pair features is reduced from 65536 to 1024, the malware
classification accuracy increases from 98.43% to 98.89%, the precision from
98.46% to 98.75%, the recall from 98.34% to 98.76%, and the F1 score from
98.40% to 98.76%. However, when the number of features falls to 512, the
classification performance of the malware classification model decreases
significantly because key features of the image are lost. It can be concluded
that appropriately reducing the redundant pixel features of RGB images can
improve the efficiency of malware classification, but there is a trade-off
between the number of features and classification performance.
In addition, in order to reduce the impact of malware image size imbalance
on the classification effect, we use the bicubic interpolation algorithm introduced
in Sect. 3.2 to perform image enhancement on the RGB images of the BIG2015
dataset, providing more consistent data for subsequent malware classification.
The experimental results are shown in Table 1. By applying feature selection
and standardizing image sizes, the accuracy, recall, and F1 score of malware
classification increased by 0.19%, 0.23%, and 0.11%, respectively. In particular,
when the number of pixel pair features is reduced to 1024, the malware classifi-
cation accuracy reaches 99.12%, the precision reaches 99.24%, the recall reaches
99.01%, and the F1 score reaches 99.12%, achieving the best classification perfor-
mance. The results indicate that the combination of feature selection and image
enhancement can significantly improve the efficiency and accuracy of malware
classification systems.

4.3 Performance Comparison with Other Lightweight Models
To evaluate the performance of the lightweight model ConvInceptionNet pro-
posed in this paper, the ConvInceptionNet model was experimentally compared
with four typical lightweight network models on the original images of the
BIG2015 dataset.
To enhance recognition accuracy, the number of parameters in neural network
models has been continuously increasing, which imposes a significant burden
on the network bandwidth required for data transmission. Consequently, a key
research focus is how to reduce network complexity while maintaining model
accuracy. In 2016, Iandola et al. [10] proposed the SqueezeNet model, which
adopted the “fire module” structure, in which a squeeze layer reduces the
number of channels and an expand layer increases it. SqueezeNet has 50 times
fewer parameters than AlexNet, its model size is only 4.8 MB, and it achieves
recognition accuracy comparable to AlexNet.
In 2018, Ma et al. [14] proposed the ShuffleNetV2 model, which introduced
lightweight group convolution blocks and group convolutions, effectively reduc-
ing the model size and memory usage. In 2019, Howard et al. [9] proposed
the MobileNetV3 model, which adopted the inverted residual and squeeze-and-
excitation modules to improve the feature extraction capability and model per-
formance. Tan et al. [22] proposed the EfficientNetB0 model. This model uses
a network scaling strategy based on the composite scaling method to improve
performance by simultaneously adjusting the depth, width, and resolution of the
network.
The experimental results are shown in Table 2. In terms of accuracy, the
classification accuracy of the ConvInceptionNet model reached 98.43%, which is
higher than other lightweight models, proving its good feature extraction and
classification capabilities. It is worth noting that while maintaining high accu-
racy, ConvInceptionNet also has the smallest training time and prediction time
costs, which are 7 s/epoch and 0.92 ms respectively.

Table 2. Comparison of lightweight models with ConvInceptionNet for the BIG2015 dataset

Model             Accuracy  Training time per epoch  Prediction Time  Parameters  GPU Memory (MiB)
SqueezeNet        97.28%    9 s                      1.38 ms          740,041     3280
ShuffleNetV2      98.16%    10 s                     1.82 ms          2,720,383   2226
MobileNetV3       98.25%    10 s                     1.84 ms          1,521,463   2476
EfficientNetB0    98.25%    11 s                     1.87 ms          6,227,797   3840
ConvInceptionNet  98.43%    7 s                      0.92 ms          1,268,312   2604

In terms of the number of parameters and memory overhead, although SqueezeNet
has the smallest number of parameters and ranks second in memory
overhead, its accuracy is relatively low. In contrast, the ConvInceptionNet
model achieves higher accuracy while maintaining a smaller number of parame-
ters, showing a good balance. Although the EfficientNetB0 model has the same

accuracy as the MobileNetV3 model, its number of parameters is about 4.9 times
that of ConvInceptionNet. This shows that EfficientNetB0 sacrifices the lightweight
advantage of the model while pursuing performance improvement. The ConvIn-
ceptionNet model achieves a combination of high performance and lightweight
by optimizing the network structure.
Therefore, through comparative experiments, the ConvInceptionNet model
has demonstrated comprehensive performance superior to four typical light-
weight network models in key indicators such as accuracy, training/prediction
time, parameter volume, and memory usage.

4.4 Comparison with Other Malware Classification Methods
Table 3 shows related research work based on the BIG2015 data set. Kong et
al. [12] used the MalFSM framework to remove redundant opcode features and
achieved a classification accuracy of 98.60% on the BIG2015 dataset, and the
classification time was 7.76 s. Vasan et al. [25] proposed the IMCBL method,
which converts the original malware binary file into an image and transforms
the image using truncated singular value decomposition (SVD) to reduce the
feature-vector size. The converted feature vector is input into a broad
learning (BL) system, achieving a classification accuracy of 95.58%. Zou et
al. [30] proposed a capsule network-based malware classification system,
FACILE, which uses dynamic routing in capsule networks to achieve high-order
representations of malware features and achieves an accuracy of 97.2% on the
BIG2015 dataset.
Compared with the above methods, this paper adopts a feature selection method
based on pixel pair features that can improve the sparsity of features and opti-
mize texture features. On the test set, it achieved precision, recall, F1-score, and
accuracy of 99.24%, 99.01%, 99.12%, and 99.12% respectively.

Table 3. Literature comparison of the proposed classification framework on the BIG2015 dataset.

Authors             Year  Method          Precision  Recall  F1-Score  Accuracy  Prediction Time
Acharya et al. [2]  2021  EfficientNetB1  96.00%     97.00%  97.00%    98.57%    –
Mallik et al. [15]  2022  ConRec          99.40%     96.88%  98.12%    98.36%    –
Zou et al. [29]     2022  IMCLNet         98.94%     98.38%  98.64%    99.11%    0.84 ms
Kong et al. [12]    2023  MalFSM          98.40%     97.04%  97.65%    98.60%    7.76 s
Zou et al. [30]     2023  FACILE          93.99%     92.14%  92.63%    97.20%    –
Vasan et al. [25]   2024  IMCBL           95.31%     95.49%  95.36%    95.58%    –
Ours                –     LMCF-FS         99.24%     99.01%  99.12%    99.12%    0.92 ms

Mallik et al. [15] used a feature extractor based on VGG16 to extract visual
outliers, and merged the features with the output of the VGG16 layer through
two BiLSTM layers, achieving an accuracy of 98.36% on the BIG2015 dataset.
Acharya et al. [2] proposed a lightweight model based on EfficientNetB1, which

achieved an accuracy of 98.57% on the BIG2015 dataset. Zou et al. [29] pro-
posed a lightweight IMCLNet model that integrates coordinate attention, depth-
separable convolution, and global context embedding, achieving an accuracy of
99.11% on the BIG2015 dataset.
Unlike the above research methods, this paper performs feature selection and
image enhancement on the RGB images of malware. Then we input them into
the lightweight model ConvInceptionNet for classification. This method achieves
comprehensive improvements in precision, recall, F1 score, and accuracy while
maintaining high efficiency. It is worth noting that the IMCLNet model of Zou
et al. [29] predicts a 32 × 32 malware image in 0.84 ms, while the LMCF-FS
framework proposed in this paper needs only 0.92 ms to predict a 224 × 224
malware image. The malware images used in this paper are therefore larger and
more informative, yet the prediction time does not increase significantly.
This shows that LMCF-FS can maintain high classification accuracy while
keeping time consumption low when processing large images, making it more
suitable for resource-constrained scenarios and mobile devices.

5 Conclusion
In order to further improve the efficiency of malware classification, this paper
proposes a lightweight malware classification framework driven by feature
selection (LMCF-FS). This framework integrates visualization technology, feature
selection, image enhancement, and lightweight neural network models. We
convert the malware’s byte files into RGB images. Then we propose a new fea-
ture selection method based on pixel pair features. This method optimizes image
texture features by increasing the sparsity of features, which can significantly
improve the expressive ability of features. At the same time, we use the bicu-
bic interpolation algorithm to enhance the malware images, which effectively
solves the problem of the unbalanced size of the image dataset. In addition,
a lightweight model ConvInceptionNet based on the Inception-C module was
designed, which achieves a smaller parameter size and faster inference speed
while maintaining high performance. The LMCF-FS achieved 99.12%, 99.24%,
99.01%, and 99.12% accuracy, precision, recall, and F1 score respectively on the
BIG2015 data set, and the time to predict a 224 × 224 image is only 0.92 ms.

Acknowledgement. This work is supported by the National Natural Science Foundation of China Grant No. 62102190.

References
1. Av-test. (2020). https://www.av-test.org/en/statistics/malware/
2. Acharya, V., Ravi, V., Mohammad, N.: EfficientNet-based convolutional neural
networks for malware classification. In: 2021 12th International Conference on
Computing Communication and Networking Technologies (ICCCNT), pp. 1–6.
IEEE (2021)

3. Awan, M.J., et al.: Image-based malware classification using VGG19 network and
spatial convolutional attention. Electronics 10(19), 2444 (2021)
4. Basha, S.S., Dubey, S.R., Pulabaigari, V., Mukherjee, S.: Impact of fully connected
layers on performance of convolutional neural networks for image classification.
Neurocomputing 378, 112–119 (2020)
5. Darem, A., Abawajy, J., Makkar, A., Alhashmi, A., Alanazi, S.: Visualization and
deep-learning-based malware variant detection using opcode-level features. Futur.
Gener. Comput. Syst. 125, 314–323 (2021)
6. Gibert, D., Mateu, C., Planes, J.: Hydra: a multimodal deep learning framework
for malware classification. Comput. Secur. 95, 101873 (2020)
7. Gibert, D., Mateu, C., Planes, J.: The rise of machine learning for detection and
classification of malware: research developments, trends and challenges. J. Netw.
Comput. Appl. 153, 102526 (2020)
8. Hemalatha, J., Roseline, S.A., Geetha, S., Kadry, S., Damaševičius, R.: An efficient
DenseNet-based deep learning model for malware detection. Entropy 23(3), 344
(2021)
9. Howard, A., et al.: Searching for MobileNetV3. In: Proceedings of the IEEE/CVF
International Conference on Computer Vision, pp. 1314–1324 (2019)
10. Iandola, F.N., Han, S., Moskewicz, M.W., Ashraf, K., Dally, W.J., Keutzer, K.:
SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model
size. arXiv preprint: arXiv:1602.07360 (2016)
11. Keys, R.: Cubic convolution interpolation for digital image processing. IEEE Trans.
Acoust. Speech Signal Process. 29(6), 1153–1160 (1981)
12. Kong, Z., Xue, J., Wang, Y., Zhang, Q., Han, W., Zhu, Y.: MalFSM: feature subset
selection method for malware family classification. Chin. J. Electron. 32(1), 26–38
(2023)
13. Kumar, S., et al.: MCFT-CNN: malware classification with fine-tune convolution
neural networks using traditional and transfer learning in internet of things. Futur.
Gener. Comput. Syst. 125, 334–351 (2021)
14. Ma, N., Zhang, X., Zheng, H.T., Sun, J.: ShuffleNet V2: practical guidelines for
efficient CNN architecture design. In: Proceedings of the European Conference on
Computer Vision (ECCV), pp. 116–131 (2018)
15. Mallik, A., Khetarpal, A., Kumar, S.: ConRec: malware classification using convo-
lutional recurrence. J. Comput. Virol. Hack. Tech. 18(4), 297–313 (2022)
16. Nataraj, L., Karthikeyan, S., Jacob, G., Manjunath, B.S.: Malware images: visu-
alization and automatic classification. In: Proceedings of the 8th International
Symposium on Visualization for Cyber Security, pp. 1–7 (2011)
17. Ni, S., Qian, Q., Zhang, R.: Malware identification using visualization images and
deep learning. Comput. Secur. 77, 871–885 (2018)
18. Pinhero, A., et al.: Malware detection employed by visualization and deep neural
network. Comput. Secur. 105, 102247 (2021)
19. Ronen, R., Radu, M., Feuerstein, C., Yom-Tov, E., Ahmadi, M.: Microsoft malware
classification challenge. arXiv preprint: arXiv:1802.10135 (2018)
20. Su, J., Vasconcellos, D.V., Prasad, S., Sgandurra, D., Feng, Y., Sakurai, K.:
Lightweight classification of IoT malware based on image recognition. In: 2018
IEEE 42nd Annual Computer Software and Applications Conference (COMPSAC),
vol. 2, pp. 664–669. IEEE (2018)
21. Szegedy, C., Ioffe, S., Vanhoucke, V., Alemi, A.: Inception-v4, inception-ResNet
and the impact of residual connections on learning. In: Proceedings of the AAAI
Conference on Artificial Intelligence, vol. 31 (2017)

22. Tan, M., Le, Q.: EfficientNet: rethinking model scaling for convolutional neural net-
works. In: International Conference on Machine Learning, pp. 6105–6114. PMLR
(2019)
23. Tekerek, A., Yapici, M.M.: A novel malware classification and augmentation model
based on convolutional neural network. Comput. Secur. 112, 102515 (2022)
24. Vasan, D., Alazab, M., Wassan, S., Naeem, H., Safaei, B., Zheng, Q.: IMCFN:
image-based malware classification using fine-tuned convolutional neural network
architecture. Comput. Netw. 171, 107138 (2020)
25. Vasan, D., Hammoudeh, M., Alazab, M.: Broad learning: a GPU-free image-based
malware classification. Appl. Soft Comput. 154, 111401 (2024)
26. Wadkar, M., Di Troia, F., Stamp, M.: Detecting malware evolution using support
vector machines. Expert Syst. Appl. 143, 113022 (2020)
27. Yuan, B., Wang, J., Liu, D., Guo, W., Wu, P., Bao, X.: Byte-level malware clas-
sification based on Markov images and deep learning. Comput. Secur. 92, 101740
(2020)
28. Zhang, Z., Qi, P., Wang, W.: Dynamic malware analysis with feature engineer-
ing and feature learning. In: Proceedings of the AAAI Conference on Artificial
Intelligence, vol. 34, pp. 1210–1217 (2020)
29. Zou, B., Cao, C., Tao, F., Wang, L.: IMCLNet: a lightweight deep neural network
for image-based malware classification. J. Inf. Secur. Appl. 70, 103313 (2022)
30. Zou, B., Cao, C., Wang, L., Fu, S., Qiao, T., Sun, J.: FACILE: a capsule net-
work with fewer capsules and richer hierarchical information for malware image
classification. Comput. Secur. 137, 103606 (2024)
Rule Learning-Based Target Prediction for Efficient and Flexible Private Information Retrieval

Weizhe Tu1, Zheng Dong1, Tao Zhang1,2(B), and Wenying Zheng2

1 School of Information Science and Engineering (School of Cyber Science and Technology), Zhejiang Sci-Tech University, Hangzhou, China
[email protected]
2 School of Computer Science and Technology (School of Artificial Intelligence), Zhejiang Sci-Tech University, Hangzhou, China
Abstract. In recent years, machine learning has been widely used in many
aspects of social production and life, including transportation and finance,
and predicting unknown situations from existing information is a common
application of machine learning. The topic of prediction problems and machine
learning models has therefore long been a research hotspot. However, current
prediction models still have many problems. Prediction methods based on the
frequency and likelihood of data retrieval suffer from insufficient accuracy
and high error rates. The retrieval strategy obtained by training a common
model is relatively fixed and cannot change in real time as the original data
are updated, so its dynamics are poor. At the same time, predictions are made
without security protection, which easily causes privacy leakage during
interaction, i.e., poor security. To solve these problems, we propose a rule
learning prediction model based on private information retrieval (PIR-RL).
The model uses rule learning to realize the prediction function, and rule
learning helps extract rule features to achieve accurate prediction. Inspired
by SealPIR, this paper also proposes Target Prediction Private Information
Retrieval (TP-PIR) to achieve privacy protection. The dynamic nature and low
computational cost of rule learning offer clear advantages for practical
problems, and the lightweight TP-PIR achieves privacy protection while keeping
computation and communication overhead low. Theoretical analysis and
experimental results show that this model has good stability and practical
value, and can provide prediction services under privacy protection.

Keywords: Prediction · Rule Learning · PIR-RL · Privacy Protection

c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 438–448, 2025.
https://doi.org/10.1007/978-981-96-4566-4_30
A Efficient and Flexible Prediction Model of PIR-RL 439

1 Introduction
With the rapid development of information technology, data has become one
of the most valuable resources in modern society. As a data-driven computing
method, machine learning has been continuously developed and popularized.
More and more machine learning methods are gradually applied to the relevant
aspects of social life. Using accurate machine learning methods and a wide range
of data, analysis and prediction functions can be achieved through training and
learning. In order to make the majority of users realize the convenience brought
by machine learning, Machine Learning as a Service (MLaaS) has been proposed.
It is a new cloud-based service model that provides users with a lower-cost
machine learning environment, allowing them to access online prediction
services at any time and making machine learning more widely accessible.
However, because machine learning needs large amounts of data for learning and
analysis, some of this massive data is prone to privacy leakage. Such private
data may expose key or sensitive information, which is extremely harmful
to the security of user privacy. For example, in recent years, security incidents
such as data leakage and network attacks have occurred frequently, which has
become a major hidden danger worldwide. On June 3rd, the information of
students and faculty of the Chinese University of Hong Kong was stolen in
large quantities. On July 12th, AT&T suffered a serious data leakage incident,
resulting in the theft of call and SMS records of nearly 110 million users. The
disclosure of such information may cause serious impact and threat to individuals
and society. Similar incidents such as large-scale hacker attacks, data theft, and
user privacy leaks are still common, which seriously threaten personal privacy,
corporate interests, and national security. This also makes the importance of
information security and data protection increasingly prominent, reflecting the
importance of privacy computing in machine learning. Therefore, we need to
propose new methods to deal with privacy leakage to strengthen the protection of
data privacy, especially for data privacy security protection for machine learning.

1.1 Related Work


At present, there has been a lot of research on privacy protection methods for
machine learning, and more and more scholars are working on their improvement
and application. There are two main ideas. One is to perturb and obfuscate the
original data by adding noise; the common means is differential privacy. The
other is to encrypt the dataset or model, as in federated learning and secure
multi-party computation.
Differential Privacy. Ouadrhiri and Abdelhadi [1] conducted a survey on the
application of differential privacy and elaborated on common scenarios for deep
learning and federated learning. Jia et al. [2] combined differential privacy
with K-means clustering to achieve multiple data protection. Li et al. [3]
proposed stacking ensemble clustering based on score privacy protection to
address the insufficient security of clustering algorithms.
440 W. Tu et al.

Federated Learning and Secure Multi-party Computation. Guan et al. [4]
proposed a federated learning privacy protection method that uses multi-key
homomorphic encryption to train the machine learning model. On the basis of the
MK-CKKS multi-key homomorphic encryption protocol, Ma et al. [5] proposed an
improved version of xMK-CKKS, which uses aggregation to encrypt and decrypt
in the server, so as to improve the degree of privacy protection. Combined with
the privacy protection of graph data, Tang et al. [6] proposed a graph query
using secure multi-party computation for privacy protection. Zhou et al. [7] ana-
lyzed the application of federated learning by combining edge learning and other
technologies.
For common prediction models, Mi et al. [8] used deep learning to predict
the drought environment, but did not consider the issue of privacy protection.
Feng [9] and others use secure multi-party computation for privacy protection,
but they have relatively high computational complexity and communication
overhead. Liu et al. [10], Gao et al. [11] and others have applied decision trees to
realistic prediction environments, but they have not considered privacy security
issues. Gong et al. [12] conducted a comparative study on the specific techniques
of rule learning, but did not consider the actual prediction model. Ye et al.
[13] integrated differential privacy into a prediction model, but the prediction
process uses continuous queries rather than a specific prediction model, which
is likely to make the prediction results unstable. Xu et al. [14] also combined
edge learning with other technologies, but adjusting model parameters to
balance accuracy and security carries certain risks. It can be seen that
research on privacy protection for decision-making is still relatively scarce,
which also reflects the significance of this model.

1.2 Contribution
With the increasing attention to data privacy and security issues, especially in
sensitive areas such as medical care, finance, and government, the use of decision
trees for data prediction is facing certain challenges. Although the decision tree
model has been widely used in classification and regression problems because of
its easy understanding, simple model construction and good performance, it still
has some shortcomings in data privacy protection and computational efficiency.
Specifically, the main problems of the decision tree model are as follows.
First, since the decision tree model is trained and constructed over the entire
dataset, it must be retrained whenever the amount of data changes, which both
increases the computational cost and leads to insufficient dynamics. Second,
in the traditional decision tree training process, the risk of sensitive data
leakage is high. Especially in an open environment, the training data and
prediction data of the model may expose the private information of individuals
or organizations, which poses a potential threat to data security.
In order to deal with the above problems, we design and propose an innovative
PIR-based rule learning prediction model. Based on the existing rule learning,
the model not only strengthens the protection of data privacy, but also improves
the dynamics and computational efficiency of the prediction model by introducing
PIR technology. Rule learning can extract rules from the original data and make
predictions based on these rules. PIR is then used to encrypt to ensure privacy
security. The contributions of this paper are as follows:
An Extraction Method Based on Feature Values is Proposed. We use the
machine learning method of rule learning to extract features efficiently from
the original huge data, abstract the key points, and train and construct rules
according to these key points. This method effectively realizes the prediction
function and is both highly efficient and accurate.
Design a Model Using Rule Learning for Prediction. Compared with
the conventional decision tree prediction model, when the data are updated, only
the corresponding specific rules need to be changed, without retraining on all
the data, which greatly improves the dynamics of the model and reduces the
computational overhead.
Propose a New PIR Model for Prediction Combined with Practical
Problems. Our TP-PIR greatly enhances the data security and privacy protection
of both sides of the communication interaction through the introduction of
specific PIR modules, and builds a feasible and stable secure prediction
environment.
By combining PIR with rule learning, the proposed prediction model not only effectively protects data privacy but also keeps the computational overhead low when data is updated dynamically. Specifically, on each data update the model only performs partial data retrieval and rule updates in an encrypted environment, rather than retraining on the entire dataset. This greatly reduces computational resource consumption while maintaining high prediction accuracy.
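As a concrete illustration of this update behavior, the sketch below (our own toy example, not the paper's implementation) stores learned rules as a pattern-to-label table: updating one record touches only the matching rule, while every other rule, and hence the rest of the model, stays valid.

```python
# Toy rule table: each key is the attribute pattern a rule covers.
rules = {
    ("fever=yes", "cough=yes"): "flu",
    ("fever=yes", "cough=no"): "cold",
    ("fever=no", "cough=yes"): "allergy",
}

def predict(pattern):
    """Look up the rule matching this attribute pattern."""
    return rules.get(pattern, "unknown")

def update_rule(pattern, new_label):
    """Only the affected rule changes; no global retraining is needed."""
    rules[pattern] = new_label

assert predict(("fever=yes", "cough=no")) == "cold"
update_rule(("fever=yes", "cough=no"), "covid-screen")
assert predict(("fever=yes", "cough=no")) == "covid-screen"
```

A decision tree, by contrast, would typically require rebuilding the whole structure after such a change, which is the cost the proposed model avoids.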

2 System Model and Design Goals


2.1 System Model
The proposed rule learning prediction model consists of two entities: a server and a client. The server provides the query interface and constructs the original query table from the generated rules, while the client uses a security protocol to generate a query index, submits the query to the server through SealPIR, and finally obtains the required data or the expected prediction result. The system model is shown in Fig. 1.
For the Server. The server holds a large amount of data. We first apply machine learning to these data, obtaining a series of rules through rule learning, and construct the required query or prediction data set from these rules. The server then uses this rule dataset to build a PIR database and provides PIR services. When the user queries, the server runs a secure comparison protocol: by comparing the client's query attribute vector with the attribute partition thresholds stored on the server, it helps the client generate a vector encoding of its query attributes and an index for its query, and it returns the corresponding prediction value after the client sends the query information.

442 W. Tu et al.

Fig. 1. PIR-RL system model.
For the Client. The client, which requires the machine learning prediction service, first generates attribute values from the given query target and combines them into a query attribute vector. It then uses the secure comparison protocol to compare this vector against the set of attribute thresholds preset on the server side, ensuring the privacy of the comparison process. Based on the comparison results, the client encodes the query attribute vector and generates the query index. Next, the client processes the index according to the PIR scheme, constructs a PIR query request, and sends it to the server. Finally, the client receives the PIR query result returned by the server and obtains, by decryption, the value corresponding to its query attribute vector, i.e., the required prediction. In short, the client encodes its query attributes into a query index through vector coding and generates a PIR query request based on that index; after the server returns the query result, the client decrypts it to obtain the final prediction value. Throughout the process, the user's query characteristics and prediction results remain unknown to the server. The black-box-like environment constructed by PIR ensures data privacy while preserving computability and comparability, strengthening privacy and security while realizing the basic functions.
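To make the index-generation step concrete, here is a plaintext sketch of how a query attribute vector could be mapped to a flat database index once the threshold comparisons are known. In the actual scheme these comparisons run under the secure comparison protocol, so the server never sees the attribute values; the function name, the threshold layout, and the mixed-radix encoding below are our own illustrative assumptions, not the paper's code.

```python
import bisect

def encode_index(attrs, thresholds):
    """Map an attribute vector to a flat database index (plaintext sketch).

    thresholds[j] is the sorted list of partition points for attribute j,
    so attribute j falls into one of len(thresholds[j]) + 1 buckets.
    """
    # Bucket number per attribute: how many thresholds the value exceeds.
    digits = [bisect.bisect_right(t, a) for a, t in zip(attrs, thresholds)]
    # Flatten the per-attribute buckets into one index (mixed radix).
    index = 0
    for d, t in zip(digits, thresholds):
        index = index * (len(t) + 1) + d
    return index

# Example: attribute 0 partitioned into 3 buckets, attribute 1 into 2.
thresholds = [[10.0, 20.0], [0.5]]
print(encode_index([15.0, 0.7], thresholds))  # bucket (1, 1) -> 1*2 + 1 = 3
```

In the real protocol the bucket digits would be produced jointly by client and server through secure comparisons, with only the client learning the resulting index.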

2.2 Design Goals

We assume that the server strictly follows the provided scheme and security protocol, but that the service provider is curious: it would like to learn the user's specific query content, hoping to analyze the user's input features through the secure comparison interface and thereby indirectly obtain the user's model information and prediction data. At the same time, we assume a malicious data thief outside the overall service.

Rather than obtaining prediction results through the normal query steps, this attacker attempts to retrieve the original data directly, trying to bypass the authentication of the security protocol and access the central dataset. Prediction efficiency also deserves attention: the model must keep both computational and communication overhead as small as possible. The computational overhead depends on the way and mode of prediction, while the communication efficiency depends on the traffic between server and client and on the number of communication rounds.
In view of the above privacy and security risks, the proposed scheme must meet the following requirements: protection of the server-side rule learning model, protection of the user-side data privacy, and high efficiency in delivering prediction results. The goal of this paper is to design a safe, reliable, and accurate prediction scheme for MLaaS. Specifically, our scheme needs to achieve the following objectives:
Ensure the Security of Sensitive Information Data. The rules are obtained by training over and integrating the data on the server; the final rule set distills an important part of the original data and is core information belonging to the server. It is therefore necessary to keep the rule learning model confidential from ordinary users: no information about specific rules may be exposed through the secure comparison interface or the private retrieval service interface. That is, the user must be unable to learn the server's internal content from its interfaces, and the server must be unable to infer the user's query-related information from the interfaces and security protocols, so that both parties remain confidential.
Ensure Prediction Accuracy and Efficiency. We must also achieve accurate and efficient prediction. As the fundamental purpose of the whole system model, prediction must first meet a basic accuracy requirement. At the same time, to keep the whole model practical, both the computational and communication overhead must be kept as small as possible; the former depends on the way and mode of prediction, and the latter on the amount of communication and the number of communication rounds between server and client.
We therefore address the above problems and requirements by designing a concrete model. By focusing on prediction accuracy, model efficiency and dynamism, and privacy and confidentiality, we realize a new prediction scheme with more complete functionality and clear advantages.

3 The Proposed Scheme


At present, the most commonly used prediction models are decision trees. Compared with a conventional decision tree prediction scheme [15], we use security protocols to protect the model's attributes from exposure, and private information retrieval to ensure that the server cannot learn the user's query or its result, so our scheme guarantees security. Rule learning is also less sensitive to noise than decision trees. Moreover, when data changes, a decision tree must be retrained to regenerate the tree, whereas rule learning simply regenerates the affected rules; the rule set can thus be updated dynamically at a smaller computational overhead.

Fig. 2. PIR-RL Scheme Diagram.

The flow of the PIR-RL scheme designed in this paper is shown in Fig. 2. The server S is mainly responsible for training on its data to generate rule sets and for providing the private information retrieval service. The client C converts its original query into a PIR query, obtains the ciphertext result through the PIR service, and decrypts it to recover the corresponding plaintext. The key advantage of this scheme is that it preserves privacy throughout the query process while still exploiting rule learning to achieve real-time updates of data and rules together with accurate prediction. The specific steps are as follows:

1. The server trains on its own data to obtain rules, and then generates the corresponding feature attributes and rule sets from those rules.
2. The server uses the rule set to build a PIR database in which rules and attributes are associated with each other. Once construction is complete, the query service can be provided externally.
3. The client derives a query vector from its query attributes, obtains the secure comparison result through the security protocol, and thereby generates the query index.
4. The query index is encrypted as the retrieval target and passed to the server to initiate the PIR query request.
5. The server receives the query request and retrieves in the database DB according to the SealPIR protocol.
6. The server retrieves the target result for the query request, completing one PIR retrieval task, and sends the retrieval result (ciphertext) back to the client C.
7. The client receives the server's query result and decrypts it with its private key; the resulting plaintext is the user's expected result.
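The seven steps can be mocked end to end in plaintext. Everything cryptographic below is a stand-in of our own: the one-hot vector plays the role of SealPIR's encrypted query, and the dot product plays the role of the homomorphic multiply-and-add, so this sketch illustrates the data flow, not the cryptography.

```python
# Steps 1-2: the server's rule set, flattened into an indexed table.
rules = {0: 17, 1: 42, 2: 99}              # rule index -> prediction value
table = [rules[i] for i in range(len(rules))]

def client_query(index, n):
    # Steps 3-4: encode the chosen index as a one-hot "query ciphertext".
    return [int(i == index) for i in range(n)]

def server_answer(query, table):
    # Steps 5-6: the server combines query and database entry by entry,
    # never learning which index was selected.
    return sum(q * t for q, t in zip(query, table))

# Step 7: in the real scheme the client would now decrypt this value.
assert server_answer(client_query(1, len(table)), table) == 42
```

The correctness argument is the same as in the encrypted setting: the dot product of a one-hot selector with the table returns exactly the selected entry.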

The three phases of the scheme are explained in detail below.

3.1 Setup Phase

The preparation stage readies the system for subsequent queries and responses, mainly by pre-processing the original data already stored on the server. On the server side, the original data must exhibit certain associations. Rule learning is trained on these associations, yielding a series of rules in forms such as IF-THEN or ATTRIBUTE-AND. Depending on the type of prediction required, either a key-value correspondence or an interrelation is selected. The server integrates these rules into a data set, which can be understood as a rule database, and then applies PIR-related processing to turn it into a PIR database. This keeps the original rule data invisible, guaranteeing privacy: subsequent access and comparison only call and evaluate the PIR database interface. On the client side, the query vector value [X] is encrypted under the key K of the security protocol, adding a layer of protection for subsequent queries. At the same time, when the PIR database is constructed, the public key pk and private key sk used for the PIR service are prepared: pk is public, while sk is kept by the user and will be used to decrypt the retrieved information later.
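The rule-extraction step described above can be sketched with a miniature sequential-covering learner. This is a generic illustration of IF-THEN rule learning under our own simplifying assumptions (single-attribute conditions, exact matching, tiny data), not the algorithm actually used in the paper.

```python
def learn_rules(examples):
    """examples: list of (attrs: dict, label). Returns [(condition, label)].

    Greedy sequential covering: repeatedly pick a single-attribute condition
    that covers only one class among the remaining examples, emit it as an
    IF-THEN rule, and remove the examples it covers.
    """
    rules, remaining = [], list(examples)
    while remaining:
        attrs, label = remaining[0]
        for cond in attrs.items():          # try each condition of this example
            covered = [e for e in remaining if e[0].get(cond[0]) == cond[1]]
            if all(lbl == label for _, lbl in covered):
                rules.append((cond, label))
                remaining = [e for e in remaining if e not in covered]
                break
        else:
            # No clean condition found: fall back to an exact-match rule.
            rules.append((tuple(sorted(attrs.items())), label))
            remaining = remaining[1:]
    return rules

data = [({"fever": "yes"}, "flu"), ({"fever": "no"}, "healthy")]
print(learn_rules(data))  # -> [(('fever', 'yes'), 'flu'), (('fever', 'no'), 'healthy')]
```

The resulting (condition, label) pairs are exactly the kind of key-value table that the setup phase then packs into the PIR database.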
446 W. Tu et al.

3.2 Query Phase

In the query phase, the client will initiate a prediction request to the server
based on the eigenvalues of the query. The query attribute is processed into a
query vector, and then the query index Q is generated by encoding. After we
obtain the query index, the query generation interface of PIR will be called
to generate the corresponding PIR query request. Since the data storage mode
used by PIR is polynomial storage, we first give a polynomial coefficient N as
the length, and then expand the query vector q to a PIR query vector p with
a dimension of l according to the data volume, that is, p = P IR.expand(q).
Where p = {p0, p1, ..., pl − 1}, each element pi of the vector p is encrypted and
corresponds one-to-one with the elements at the corresponding position in the
PIR database DB. After the extension is completed, the server performs the PIR
retrieval process.Specifically, the server performs homomorphic multiplication on
each ciphertext element pi in the query vector p and the i th data item in the
corresponding dimension of the database, and performs homomorphic addition
and splicing on the results of all homomorphic multiplications, and finally obtains
the ciphertext search result a, that is, a = P IR.answer(p, DB). Finally, the
ciphertext a is returned to the client to complete the query process
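The expand-and-answer pair can be simulated in plaintext. The sketch below mirrors the shape of the computation (a flat index split into d one-hot vectors over a d-dimensional view of the database, then the database folded one dimension at a time by multiply-and-add); in SealPIR the vectors are BFV ciphertexts and the operations are homomorphic, whereas here everything is in the clear and the function names merely echo PIR.expand and PIR.answer.

```python
def expand(index, side, d):
    """Plaintext stand-in for PIR.expand: split a flat index into d one-hot
    vectors over a d-dimensional view of the database (side**d items)."""
    vecs = []
    for k in range(d - 1, -1, -1):
        block = side ** k
        vecs.append([int(j == index // block) for j in range(side)])
        index %= block
    return vecs

def answer(vecs, db):
    """Plaintext stand-in for PIR.answer: fold the database one dimension at
    a time with a dot product (homomorphic multiply-and-add in SealPIR)."""
    cur = list(db)
    for v in vecs:
        m = len(cur) // len(v)
        cur = [sum(v[j] * cur[j * m + i] for j in range(len(v)))
               for i in range(m)]
    return cur[0]

db = list(range(27))              # 27 = 3**3 items, viewed as a 3x3x3 cube
q = expand(19, side=3, d=3)       # total query length l = 3 * 3 = 9 << 27
assert answer(q, db) == 19        # the server never handled the index 19
```

This also shows why increasing d shrinks the query (l = d * n^(1/d) elements instead of n) at the cost of extra folding work on the server, which is consistent with the timings reported in Table 1.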

3.3 Response Phase

In the response phase, the server processes the PIR query through the SealPIR scheme and returns the result to the client, which decrypts it to obtain the prediction. At this point the retrieval result a is still in ciphertext form: the server cannot learn its plaintext content and simply returns a to the client as the protocol prescribes. Upon receiving a, the client decrypts it with the private key sk generated and stored earlier, obtaining the final plaintext prediction value. This value corresponds to the attribute vector of the client's query and represents the result of online prediction through the rule learning model. Throughout the process the server handles only encrypted query vectors and ciphertext data, so it can learn neither the user's query content nor the prediction result; because all operations run at the ciphertext level, the server obtains no plaintext information about the client's query. Conversely, the client obtains only the prediction value for its own query and cannot extract additional server data, so the server's privacy is not leaked either.
We ran experiments following the designed scheme, supplying a training data set with the required correlations and exercising the SealPIR pipeline. Under ideal conditions the prediction completes successfully with high accuracy, and the PIR process also runs normally; some PIR parameters are listed in Table 1. The results show that PIR executes successfully with normal parameter values, which verifies the correctness of our model.

Table 1. PIR efficiency parameters and query dimension relationship.

PIR parameter                                d=2        d=3        d=4
PIRServer pre-processing time                1735 ms    1783 ms    1851 ms
PIRClient query generation time              4 ms       7 ms       10 ms
PIRClient serialized query generation time   5 ms       10 ms      13 ms
PIRServer query deserialization time         2.54 ms    4.27 ms    5.33 ms
PIRServer reply generation time              1082 ms    1882 ms    4731 ms
PIRClient answer decode time                 3 ms       14 ms      71 ms
Query size                                   422.62 KB  633.74 KB  845.09 KB
Reply num ciphertexts                        6          36         216
Reply size                                   603.91 KB  3.54 MB    21.23 MB
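One consistency check on Table 1: in SealPIR, the reply to a d-dimensional query consists of F^(d-1) ciphertexts, where F is the ciphertext expansion factor of the underlying homomorphic encryption scheme. Reading F = 6 off the d=2 row reproduces the entire "Reply num ciphertexts" row (we infer F from the table here rather than from the paper's stated parameters).

```python
F = 6  # ciphertext expansion factor, inferred from the d=2 row of Table 1

for d in (2, 3, 4):
    print(d, F ** (d - 1))  # prints 6, 36, 216 -- matching Table 1
```

The same factor explains why reply size grows roughly sixfold per extra dimension while query size only grows linearly in d.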

4 Conclusion

In this paper, a new PIR-RL model is proposed for common prediction problems. Addressing the confidentiality issues of conventional prediction models and the shortcomings of common schemes such as encrypted decision trees, the rule learning prediction model PIR-RL, based on private information retrieval, first introduces the flexibility and dynamism of rule learning to remedy the defects of ordinary decision tree prediction. The integration of SealPIR further ensures the security of the overall model, so that prediction and privacy protection are both effectively realized and the scheme is highly practical.

Acknowledgments. The work is supported by the National Natural Science Foundation of China No. 62402444, the Zhejiang Provincial Natural Science Foundation of China No. LQ24F020012, and the "Pioneer" and "Leading Goose" R&D Program of Zhejiang No. 2023C01119.

References
1. Ouadrhiri, A.E., Abdelhadi, A.: Differential privacy for deep and federated learning: a survey. IEEE Access 10, 22359–22380 (2022). https://doi.org/10.1109/ACCESS.2022.3151670
2. Jia, B., Zhang, X., Liu, J., Zhang, Y., Huang, K., Liang, Y.: Blockchain-enabled federated learning data protection aggregation scheme with differential privacy and homomorphic encryption in IIoT. IEEE Trans. Ind. Inf. 18(6), 4049–4058 (2022). https://doi.org/10.1109/TII.2021.3085960
3. Li, S., Chang, J., Lv, M., Cai, K.: A stacking ensemble clustering algorithm based on differential privacy protection. Comput. Eng. Sci. 44(8), 1402–1408 (2022)
4. Guan, G., Zhi, T., Tao, Z., Cao, Y.: Federated learning privacy protection method for multi-key homomorphic encryption in the Internet of Things. Inf. Secur. Res. 10(10), 958–966 (2024)
5. Ma, J., Naas, S.-A., Sigg, S., Lyu, X.: Privacy-preserving federated learning based on multi-key homomorphic encryption. Int. J. Intell. Syst. 37, 5880–5901 (2022). https://doi.org/10.1002/int.22818
6. Tang, S., Yuan, Y.: Privacy-preserving graph query based on secure multi-party computation. Front. Data Comput. 5(5), 98–106 (2023). https://cstr.cn/32002.14.jfdc.CN10-1649/TP.2023.05.008
7. Chuanxin, Z., Yi, S., Degang, W., Huawei, G.: Survey of federated learning research. Chin. J. Netw. Inf. Secur. 7(5), 77–92 (2021)
8. Mi, Q., et al.: Application of deep learning method to drought prediction. J. Appl. Meteor. Sci. 33(1), 104–114 (2022). https://doi.org/10.11898/1001-7313.20220109
9. Zhang, F., Sun, X., Chang, H., Zhao, G.: Research on privacy-preserving two-party collaborative filtering recommendation. Acta Electron. Sin. 37(1), 84–89 (2009)
10. Liu, H., Lu, J., Peng, J., Qiao, D., Zhao, Y.: Research on redundant control of AMT system gear shifting process based on decision tree algorithm. Trans. Beijing Inst. Technol. 42(1), 63–73 (2022). https://doi.org/10.12172/202211150002
11. Gao, F.W., et al.: Forest fire prediction in Algeria based on decision tree algorithm in Spark MLlib. J. Sichuan For. Sci. Technol. 44(5), 24–31 (2023). https://doi.org/10.12172/202211150002
12. Gong, Z., et al.: Comparative study of mutation techniques based on rules and learning. J. Softw. 35(7) (2024)
13. Ye, A., et al.: Trajectory differential privacy protection mechanism based on prediction and sliding window. J. Commun. 41(4) (2020)
14. Xu, J., Lin, J., Li, Y., Xiong, Z.: Distributed user privacy protection adjustable personalized QoS prediction model for cloud services. J. Netw. Inf. Secur. 9(2), 70–80 (2023)
15. Yang, F.: Research on online prediction scheme of secure decision tree based on private information retrieval. Xi'an University of Electronic Science and Technology (2022). https://doi.org/10.27389/d.cnki.gxadu.2022.001965
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 449–450, 2025.
https://doi.org/10.1007/978-981-96-4566-4