Machine Learning For Cyber Security 6th International Conference, ML4CS 2024, Hangzhou, China, December 27-29, 2024
Machine Learning
for Cyber Security
6th International Conference, ML4CS 2024
Hangzhou, China, December 27–29, 2024
Proceedings
Lecture Notes in Computer Science 15566
Founding Editors
Gerhard Goos
Juris Hartmanis
Editors
Yang Xiang
Swinburne University of Technology, Melbourne, VIC, Australia

Jian Shen
Zhejiang Sci-Tech University, Hangzhou, Zhejiang, China
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2025
This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether
the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of
illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission
or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar
methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication
does not imply, even in the absence of a specific statement, that such names are exempt from the relevant
protective laws and regulations and therefore free for general use.
The publisher, the authors and the editors are safe to assume that the advice and information in this book
are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the
editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors
or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in
published maps and institutional affiliations.
This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore
The Sixth International Conference on Machine Learning for Cyber Security (ML4CS
2024) was held in Hangzhou, China, during December 27–29, 2024. ML4CS is a
well-recognized annual international forum for AI-driven security researchers to
exchange ideas and present their work. This volume contains the papers presented
at ML4CS 2024.
The conference received 111 submissions. The committee accepted 30 regular
papers, with each paper receiving at least 3 double-blind reviews. The proceedings
contain revised versions of the accepted papers. While revisions were expected to take
the referees’ comments into account, this was not enforced, and the authors bear full
responsibility for the content of their papers.
ML4CS 2024 was organized by the School of Information Science and Engineering
(School of Cyber Science and Technology), Zhejiang Sci-Tech University. Furthermore,
ML4CS 2024 was supported by Zhejiang Provincial Key Laboratory of Digital Fash-
ion and Data Governance, Zhejiang Key Laboratory of Artificial Intelligence of Things
(AIoT) Network and Data Security, and Zhejiang Provincial International Cooperation
Base for Science and Technology on Cloud Computing Security and Data Aggrega-
tion. The conference would not have been such a success without the support of these
organizations, and we sincerely thank them for their continued assistance and support.
We would also like to thank the authors who submitted their papers to ML4CS 2024,
and the conference attendees for their interest and support. We thank the Organizing
Committee for the time and effort they dedicated to arranging the conference, which
allowed us to focus on paper selection and the scientific program. We thank the
Program Committee members and the external reviewers for their hard work in review-
ing the submissions; the conference would not have been possible without their expert
reviews. Finally, we thank the EasyChair system and its operators for making the
entire process of managing the conference convenient.
Secure Resource Allocation via Constrained Deep Reinforcement Learning

Jianfei Sun1, Qiang Gao2, Cong Wu3(B), Yuxian Li1, Jiacheng Wang3, and Dusit Niyato3

1 School of Computing and Information Systems, Singapore Management University,
Singapore 188065, Singapore
{jfsun,yuxianli}@smu.edu.sg
2 School of Computing and Artificial Intelligence, Southwestern University of
Finance and Economics, Chengdu 611130, China
[email protected]
3 College of Computing and Data Science, Nanyang Technological University,
Singapore 639798, Singapore
{cong.wu,jiacheng.wang,dniyato}@ntu.edu.sg
1 Introduction
The rapid expansion of Internet of Things (IoT) devices and the advent of 6G
technologies have led to a surge in computationally demanding applications, such
as augmented reality, autonomous vehicles, and real-time data analytics, which
often exceed the processing capabilities of user devices. To address this, server-
less multi-cloud edge computing has emerged as a promising solution, offering
flexible, scalable, and proximity-aware resources [1–3]. However, the heteroge-
neous and distributed nature of these environments as well as user mobility pose
significant challenges to secure and efficient resource allocation [4–6].
Significant progress has been made in addressing various aspects of resource
allocation challenges. Tang et al. [7] explored distributed task scheduling using
dueling double deep Q-networks for optimizing resource allocation in edge com-
puting. Yao et al. [8] developed an experience-sharing deep reinforcement learn-
ing (DRL) method for function offloading in serverless edge environments. In
terms of security, Min et al. [9] introduced a privacy-aware offloading scheme
using reinforcement learning for healthcare IoT devices, while Huang et al. [10]
proposed a secure and energy-efficient offloading strategy for service workflows
in mobile edge computing. Additionally, Ko et al. [11] and Cicconetti et al. [12]
investigated the integration of serverless computing with edge environments to
enhance scalability and resource utilization.
Despite these advancements, critical research gaps persist. First, most solu-
tions target single-cloud edge models, overlooking the complexities of multi-cloud
environments with heterogeneous nodes [13]. Second, while DRL shows promise
in resource management, current methods often rely on penalty-based rewards,
resulting in suboptimal policies and convergence issues [14]. Third, the inte-
gration of robust security with efficient resource allocation in serverless multi-
cloud environments is largely unexplored, particularly in the context of dynamic
threats and varying task sensitivities [15]. These gaps underscore the need for
a comprehensive framework that securely and efficiently allocates resources in
complex, heterogeneous serverless multi-cloud edge environments.
In this paper, we propose secure adaptive resource management and task
optimization (SARMTO), a novel framework for secure and efficient resource
allocation in serverless multi-cloud edge environments. At its core, SARMTO
employs an innovative action-constrained deep reinforcement learning (AC-DRL)
model that optimizes task offloading and resource allocation while respecting
system constraints. Using a Markov decision process (MDP) formulation, it models
the resource allocation problem as a sequential decision-making process, allowing
it to adapt to changing conditions and balance objectives such as latency, energy
efficiency, and security.
SARMTO addresses key challenges through several novel design elements.
First, it introduces a flexible state representation and action space to manage
the heterogeneity of computing nodes and the complexity of multi-cloud envi-
ronments. Second, an action constraint mechanism enforces system limitations
during decision-making, avoiding the drawbacks of penalty-based approaches.
Third, an adaptive security mechanism dynamically adjusts protection levels
based on task sensitivity and the threat landscape, ensuring robust security
without compromising efficiency. Finally, advanced techniques including prioritized
experience replay and dueling network architecture are integrated to enhance
learning efficiency in this complex decision space.
In summary, our contributions are as follows:
2 Related Work
3 System Model

We consider a serverless multi-cloud edge system comprising user devices (UDs)
U = {u1, ..., uN}, mobile edge computing (MEC) nodes E = {e1, ..., eK}, and cloud
computing (CC) nodes C = {c1, ..., cM}. The network topology is represented by a
weighted graph G(V, E), where V = U ∪ E ∪ C and E represents the set of
communication links. Each link (v1, v2) ∈ E is characterized by its bandwidth
B_{v1 v2} and channel gain g_{v1 v2}.
Task Model: Each UD ui ∈ U generates a set of computational tasks Ti =
{τi1 , ..., τiL }. A task τij is defined by the tuple (Dij , Cij , Tij ), where Dij ∈ R+
denotes the input data size in bits, Cij ∈ Z+ represents the required CPU
cycles for computation, and Tij ∈ R+ is the delay constraint in seconds. The
relationship between Cij and Dij is modeled as:

Cij = ζ(Dij),   (1)

where ζ(·) is a function that maps data size to required CPU cycles, based on
the application type as shown in Table 1 [14].
Computing Node Model: Each MEC node ek ∈ E is characterized by its computational
capacity f_m^k (in CPU cycles/second) and unit energy cost μ_m^k. Similarly, each
CC node cj ∈ C is defined by its computational capacity f_c^j and unit energy
cost μ_c^j. These parameters capture the heterogeneity of the computing resources
available in the network.
Our objective is to find an optimal task offloading and resource allocation
policy π ∗ that minimizes the overall system cost while satisfying security and
delay constraints. We introduce a binary decision variable xijk ∈ {0, 1} indicating
whether task τij is executed on node k (either MEC or CC). The optimization
problem is formulated as:

min_π C(π) = α1 T(π) + α2 E(π)

s.t.  Σ_{k ∈ E∪C} xijk = 1, ∀i, j,
      Tij(π) ≤ Tij, ∀i, j,                  (2)
      Σ_{i,j} xijk Cij ≤ fk, ∀k ∈ E ∪ C.
4 Design of SARMTO
4.1 Overview
SARMTO is a cutting-edge framework designed to optimize secure resource
allocation in serverless multi-cloud edge environments. It employs an
action-constrained DRL model that integrates five components: an MDP model for
sequential decision-making, an action-constrained deep Q-network (AC-DQN)
to respect system constraints, a robust security mechanism for data protection,
an adaptive exploration strategy for efficient policy learning, and advanced
performance optimization techniques. Together, these enable SARMTO to dynamically
adapt to heterogeneous computing conditions, effectively balancing task offloading,
resource allocation, security, and performance requirements in complex distributed
computing systems.
4.3 AC-DQN
The AC-DQN algorithm extends traditional deep Q-learning by incorporating an
action constraint mechanism, which is crucial for enforcing system limitations
in complex serverless environments. This mechanism is implemented through an
action constraint function f_constraint : A → R, which maps actions to penalty
values:

f_constraint(aj) = { −λ, if T_total(aj) > T_max,
                     0,  otherwise,              (7)

where λ is a large positive constant (e.g., 1000), T_total(aj) computes the total
delay for action aj, and T_max is the maximum allowable delay. This function
effectively creates a discontinuity in the action-value space, steering the learning
process away from infeasible actions.
The Q-network Q(s, a; θ) is parameterized by θ and architecturally consists
of an input layer Rd → Rn1 , where d is the state dimension, followed by two
hidden layers Rn1 → Rn2 → Rn3 , and an output layer Rn3 → R|A| , where |A| is
the cardinality of the action space. The network is optimized by minimizing the
loss function:
L(θ) = E_{(s,a,r,s′)∼D} [(y − Q(s, a; θ))²],   (8)

where D is the experience replay buffer and y = r + γ max_{a′} (Q(s′, a′; θ⁻) +
f_constraint(a′)) is the target Q-value. Here, θ⁻ represents the parameters of a
target network, which is periodically updated to stabilize training.
The inclusion of fconstraint (a ) in the target Q-value calculation is a key
innovation of AC-DQN. It allows the constraint information to be propagated
through the temporal difference learning process, effectively shaping the Q-
function landscape to inherently avoid infeasible actions. This approach con-
trasts with methods that apply constraints only at the action selection stage, as
it embeds the constraints into the learned value function itself.
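To make the constrained target concrete, here is a minimal Python sketch of Eqs. (7)–(8); the delay model, λ = 1000, and T_max = 0.5 s are illustrative assumptions rather than the paper's exact settings.

```python
# Sketch of the AC-DQN constrained target computation.
# LAMBDA, T_MAX, and GAMMA are assumed placeholder values.

LAMBDA = 1000.0   # large penalty constant λ
T_MAX = 0.5       # maximum allowable delay T_max (seconds, assumed)
GAMMA = 0.9       # discount factor γ

def f_constraint(total_delay: float) -> float:
    """Map an action's total delay to a penalty value (Eq. (7))."""
    return -LAMBDA if total_delay > T_MAX else 0.0

def constrained_target(reward, next_q_values, next_delays):
    """Target y = r + γ · max_a' (Q(s', a'; θ⁻) + f_constraint(a')) (Eq. (8))."""
    return reward + GAMMA * max(
        q + f_constraint(d) for q, d in zip(next_q_values, next_delays)
    )
```

Infeasible next actions receive a large negative offset, so the maximum is taken over feasible actions whenever any exist.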
Algorithm 2 outlines the AC-DQN training process. The algorithm interleaves
interaction with the environment, storage of experiences, and neural network
updates. The action selection process (line 6) incorporates both the learned
Q-values and the constraint function, ensuring that the agent respects system
limitations even during exploration. The use of a separate target network (lines
10–12) and the periodic update of its parameters (lines 13–17) are standard
techniques in deep Q-learning to improve stability.
T_ij^sec = (φ_ij^enc + ω_ij)/f_k + (φ_ij^dec + ω_ij)/f_k′,   (9)

E_ij^sec = μ_k (φ_ij^enc + ω_ij) + μ_k′ (φ_ij^dec + ω_ij),   (10)

where f_k and f_k′ are the CPU frequencies of nodes k and k′, respectively, and
μ_k and μ_k′ are their respective energy consumption coefficients per CPU cycle.
The security overhead is incorporated into the overall system cost function,
allowing SARMTO to make informed decisions that balance security requirements
with performance and energy efficiency:

C_t = α1 (T_t + Σ_{i,j} T_ij^sec) + α2 (E_t + Σ_{i,j} E_ij^sec).   (11)
The exploration rate is adapted over time; the target network parameters, in
turn, are softly updated as:

θ⁻ ← τ θ + (1 − τ) θ⁻,   (14)

where τ ∈ (0, 1] is the update rate. This approach stabilizes the learning targets
and mitigates the risk of divergence.
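The soft target-network update of Eq. (14) can be sketched as follows, treating the parameters as plain Python lists for clarity (a real implementation would update tensors in place):

```python
# Illustrative soft update θ⁻ ← τ·θ + (1 − τ)·θ⁻ (Eq. (14)).

def soft_update(theta, theta_target, tau=0.01):
    """Blend online parameters θ into target parameters θ⁻."""
    return [tau * t + (1.0 - tau) * tt for t, tt in zip(theta, theta_target)]
```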
Prioritized Experience Replay: We extend the experience replay mechanism by
assigning priorities to experiences based on their temporal-difference (TD)
error. The probability of sampling an experience e_i is given as
P(i) = p_i^α / Σ_k p_k^α, where p_i = |δ_i| + ε is the priority of experience i,
δ_i is the TD-error, ε is a small positive constant to ensure non-zero sampling
probabilities, and α ∈ [0, 1] determines the degree of prioritization.
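The sampling distribution can be sketched directly from the formula; a practical implementation would maintain a sum-tree for O(log n) sampling, which is omitted here, and the value of ε is an assumed placeholder.

```python
# Sketch of prioritized sampling probabilities: P(i) = p_i^α / Σ_k p_k^α,
# with priority p_i = |δ_i| + ε.

EPS = 1e-6  # small constant ε ensuring non-zero probabilities (assumed value)

def sampling_probabilities(td_errors, alpha=0.6):
    priorities = [abs(d) + EPS for d in td_errors]
    powered = [p ** alpha for p in priorities]
    total = sum(powered)
    return [p / total for p in powered]
```

With α = 0 this degenerates to uniform sampling; with α = 1 experiences are drawn in direct proportion to their TD-error.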
Dueling Network Architecture: We decompose the Q-function into separate value
and advantage streams:

Q(s, a; θ, α, β) = V(s; θ, β) + ( A(s, a; θ, α) − (1/|A|) Σ_{a′∈A} A(s, a′; θ, α) ),   (15)
5 Experimental Results
5.1 Evaluation Setup
We evaluate the impact of varying average task data sizes, ranging from 1 GB to
5 GB. As shown in Fig. 4(a), SARMTO consistently achieves the lowest system
cost across all data sizes. The performance gap widens as data size increases, with
SARMTO demonstrating approximately 26.3% lower system cost compared to
DQN at the 5 GB data size, highlighting its ability to efficiently manage larger
tasks through intelligent offloading decisions and security-aware resource allo-
cation. Interestingly, the DQN+NN approach shows improved performance for
medium-sized data (2–3 GB), slightly narrowing the gap with SARMTO; how-
ever, this advantage diminishes as data sizes increase.
Fig. 3. Performance under different number of tasks, including system cost (a) and
average delay (b)
Fig. 4. Performance under different data sizes, including system cost (a) and energy
consumption (b)
6 Conclusion
This paper has introduced SARMTO, a framework for secure resource allocation
in serverless multi-cloud edge environments, leveraging action-constrained deep
reinforcement learning. SARMTO effectively balances task offloading, resource
allocation, and security requirements through its innovative algorithm, adap-
tive security mechanisms, and performance optimization techniques. Extensive
simulations demonstrated its superior performance, consistently outperforming
existing methods in system cost, energy efficiency, and adaptability across
various scenarios. As edge computing evolves, SARMTO offers a promising direction
for secure and efficient resource allocation in complex distributed environments.
References
1. Wang, P., Di, B., Song, L., Jennings, N.R.: Multi-layer computation offloading in
distributed heterogeneous mobile edge computing networks. IEEE Trans. Cognitive
Commun. Networking 8(2), 1301–1315 (2022)
2. Zhang, R., Zhang, L., Wu, Q., Zhou, J.: Secure channel establishment scheme
for task delivery in vehicular cloud computing. IEEE Trans. Inf. Forensics Secur.
(2024)
3. Wang, L., Wu, W., Zhou, F., Yang, Z., Qin, Z., Wu, Q.: Adaptive resource alloca-
tion for semantic communication networks. IEEE Trans. Commun. (2024)
4. Zhang, L., Zou, Y., Wang, W., Jin, Z., Su, Y., Chen, H.: Resource allocation and
trust computing for blockchain-enabled edge computing system. Comput. Secur.
105, 102249 (2021)
5. Liang, F., et al.: Resource allocation and workload scheduling for large-scale
distributed deep learning: a survey. arXiv preprint arXiv:2406.08115 (2024)
6. Zhang, M., Chen, S., Shen, J., Susilo, W.: PrivacyEAFL: privacy-enhanced aggre-
gation for federated learning in mobile crowdsensing. IEEE Trans. Inf. Forensics
Secur. (2023)
7. Tang, Q., et al.: Distributed task scheduling in serverless edge computing networks
for the internet of things: a learning approach. IEEE Internet Things J. 9(20),
19634–19648 (2022)
8. Yao, X., Chen, N., Yuan, X., Ou, P.: Performance optimization of serverless edge
computing function offloading based on deep reinforcement learning. Futur. Gener.
Comput. Syst. 139, 74–86 (2023)
9. Min, M., et al.: Learning-based privacy-aware offloading for healthcare IoT with
energy harvesting. IEEE Internet Things J. 6(3), 4307–4316 (2018)
10. Huang, B., et al.: Security modeling and efficient computation offloading for service
workflow in mobile edge computing. Futur. Gener. Comput. Syst. 97, 755–774
(2019)
11. Ko, H., Pack, S., Leung, V.C.: Performance optimization of serverless computing
for latency-guaranteed and energy-efficient task offloading in energy-harvesting
industrial IoT. IEEE Internet Things J. 10(3), 1897–1907 (2021)
12. Cicconetti, C., Conti, M., Passarella, A.: Architecture and performance evaluation
of distributed computation offloading in edge computing. Simul. Model. Pract.
Theory 101, 102007 (2020)
13. Grozev, N., Buyya, R.: Multi-cloud provisioning and load distribution for three-tier
applications. ACM Trans. Auton. Adapt. Syst. 9(3), 1–21 (2014)
14. Zhang, H., Wang, J., Zhang, H., Bu, C.: Security computing resource allocation
based on deep reinforcement learning in serverless multi-cloud edge computing.
Futur. Gener. Comput. Syst. 151, 152–161 (2024)
15. Elgendy, I.A., Zhang, W.-Z., Zeng, Y., He, H., Tian, Y.-C., Yang, Y.: Efficient and
secure multi-user multi-task computation offloading for mobile-edge computing in
mobile IoT networks. IEEE Trans. Network Serv. Manag. 17(4), 2410–2422 (2020)
16. Xu, X., et al.: Service offloading with deep q-network for digital twinning-
empowered internet of vehicles in edge computing. IEEE Trans. Ind. Inform. 18(2),
1414–1423 (2020)
17. Chen, Q., Kuang, Z., Zhao, L.: Multiuser computation offloading and resource
allocation for cloud-edge heterogeneous network. IEEE Internet Things J. 9(5),
3799–3811 (2021)
Efficient Two-Party Privacy-Preserving Ridge and Lasso Regression via SMPC

Zongxiang Yi1, Bo Li2, Wanhui Zhang3(B), Zhiqiang Lin4, and Lishan Ke4

1 School of Mathematics and Systems Science, Guangdong Polytechnic Normal
University, Guangzhou 510665, China
2 Guangzhou Construction Co., Ltd., Guangzhou 510030, China
3 Guangzhou Installation Group Co., Ltd., Guangzhou 510030, China
[email protected]
4 School of Mathematics and Information Science, Guangzhou University,
Guangzhou 510006, China
{linzhiqiang,kelishan}@gzhu.edu.cn
1 Introduction
With the rapid development of artificial intelligence, machine learning, as a key
enabling technology, has been widely applied in various fields such as medi-
cal diagnosis, finance, transportation, and e-commerce. By training models to
extract patterns from data, computer systems can perform a range of intelligent
tasks, including image recognition, speech recognition, and natural language
processing.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 16–42, 2025.
https://doi.org/10.1007/978-981-96-4566-4_2

In this context, the extensive use of large datasets for model training
to uncover valuable information has become an important resource driving
advancements in information technology. However, the widespread collection and
use of data have also brought privacy issues to the forefront. Machine learning
models typically require access to vast amounts of user data during the train-
ing process, which undoubtedly increases the risk of privacy breaches. The key
to resolving this contradiction lies in adopting appropriate privacy protection
technologies and strategies that effectively utilize decentralized data resources
while ensuring data privacy and security. Therefore, researching how to protect
data privacy while maintaining model performance has become a pressing issue.
Privacy-preserving machine learning employs various technical methods, such as
differential privacy, homomorphic encryption, and secure multi-party computa-
tion, to ensure the effectiveness and accuracy of machine learning models while
safeguarding user data privacy.
Machine learning can encounter overfitting issues when processing data that
contains noise, has imbalanced sample categories, or is high-dimensional, which
severely impacts the model’s generalization ability. The introduction of L2 regu-
larization and L1 regularization has significantly enhanced the model’s general-
ization capability. L2 regularized linear regression, also known as Ridge regres-
sion, was first proposed by statisticians Arthur E. Hoerl and Robert W. Ken-
nard as a method to address multicollinearity issues by adding the square of
the weights to the loss function to prevent model overfitting [25]. L1 regular-
ized linear regression, known as Lasso regression, was introduced by statisti-
cian Robert Tibshirani in a 1996 paper [39]. Lasso regression performs fea-
ture selection by incorporating the absolute sum of the weights into the loss
function, thereby improving the model's sparsity and interpretability. Additionally,
Privacy-Preserving Logistic Regression (PPLR), a significant direction in
privacy-preserving machine learning, combines L2 regularized linear regression
and L1 regularized logistic regression with privacy protection technologies.
This approach enhances the model's generalization ability while effectively
preventing the leakage of the participating parties' private data, so that
superior predictive models can be trained from multi-party data; this lends
practical significance to research on privacy protection in regularized
regression algorithms.
The current mainstream privacy-preserving technologies include differential
privacy [3, 17, 18, 33, 43, 46], homomorphic encryption [9, 35], and secure multi-
party computation [6, 8, 11, 14, 15, 20, 30, 31, 37, 42, 45]. However, differential pri-
vacy may not achieve the same accuracy as models using plaintext, and homo-
morphic encryption involves excessive computational overhead. Therefore, we
focus on privacy computing solutions based on Secure Multi-Party Computation
(SMPC). The landscape of SMPC has seen considerable advancements, partic-
ularly in optimizing communication overhead and computational efficiency for
privacy-preserving regression analysis and machine learning. Several key stud-
ies have contributed to this evolving field, each addressing different aspects of
SMPC and enhancing the practicality of secure computations.
2 Preliminaries
2.1 Notations
The notations used in the paper are listed as follows:
1. N: The set of all natural numbers.
2. Partyi : The participants of the schemes, i = 0, 1.
3. Zn : The residue class ring {0, 1, . . . , n − 1}.
4. L: The bit length of the basic data type. (It is not hard to transform the
values of the original data into Z_{2^L}.)
5. PRNG(s, i): The i-th output of the pseudo random number generator seeded
by s.
6. ⊥: The empty output or input for algorithms.
7. (output0; output1) = A(input0, input1): Party0 sends input0 to the two-party
algorithm A and receives output0, while Party1 sends input1 to the same
two-party algorithm A and receives output1.
Algorithm 1: Secret-Share
Input: (x; ⊥)
Output: ([x]0 ; [x]1 )
begin
Party0 randomly selects r ∈ Z2L and sends r to Party1 ;
Party0 outputs [x]0 = x − r;
Party1 outputs [x]1 = r.
Algorithm 2: Recovery
Input: ([x]0 ; [x]1 )
Output: (x; x)
begin
Party0 sends [x]0 to Party1 ;
Party1 sends [x]1 to Party0 ;
Party0 receives [x]1 and outputs x = [x]0 + [x]1 ;
Party1 receives [x]0 and outputs x = [x]0 + [x]1 .
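Algorithms 1 and 2 can be illustrated in a few lines of Python, simulating both parties in a single process over Z_{2^L} with L = 64:

```python
# Sketch of Algorithms 1-2 (Secret-Share / Recovery) over Z_{2^L}.

import secrets

L = 64
MOD = 1 << L

def share(x):
    """Party0 splits x into additive shares: [x]_0 + [x]_1 = x (mod 2^L)."""
    r = secrets.randbelow(MOD)   # the random r sent to Party1
    return (x - r) % MOD, r      # ([x]_0, [x]_1)

def recover(s0, s1):
    """Both parties exchange shares and reconstruct x."""
    return (s0 + s1) % MOD
```

Each share alone is uniformly distributed in Z_{2^L} and reveals nothing about x; only the sum of both shares does.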
Algorithm 3: Oblivious-Transfer
Input: (k0 , k1 ; σ) where k0 , k1 ∈ {0, 1}∗ and σ ∈ {0, 1}
Output: (⊥; kσ )
c = ab
= ([a]0 + [a]1 )([b]0 + [b]1 )
= [a]0 [b]0 + [a]0 [b]1 + [a]1 [b]0 + [a]1 [b]1
As shown in the above equations, if participant Party0 wants to obtain the secret
share [c]0 and participant Party1 wants to obtain the secret share [c]1, then
[a]0 [b]0 and [a]1 [b]1 can be computed directly by the participants themselves,
while for [a]0 [b]1 and [a]1 [b]0, the participants Party0 and Party1 need to use
Oblivious Transfer (OT) technology to obtain their corresponding secret shares.
This problem can be solved by the following algorithm. With the help of algorithm
Product-Secret-Sharing, multiplication triples can be generated.
Now with the help of multiplication triplets (a, b, c), the secure multiplication
based on arithmetic secret sharing can be achieved.
Algorithm 4: Product-Secret-Sharing
Input: (X = X_{M−1} X_{M−2} · · · X_0 ; Y = Y_{M−1} Y_{M−2} · · · Y_0 )
Output: ([XY]_0 ; [XY]_1 )
begin
for i = 0, 1, . . . , M − 1 do
    Party1 generates a random number R_i^0 and computes R_i^1 = 2^i Y + R_i^0 ;
    Party1 and Party0 run algorithm Oblivious-Transfer(R_i^0, R_i^1 ; X_i) and
    Party0 obtains R_i^{X_i} ;
Party0 outputs [XY]_0 = Σ_{i=0}^{M−1} R_i^{X_i} ;
Party1 outputs [XY]_1 = −Σ_{i=0}^{M−1} R_i^0 .
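The following single-process simulation illustrates the idea behind Algorithm 4. The oblivious transfer is replaced by a plain selection, which is of course not private and serves only to check the arithmetic; the share signs are chosen so that the two outputs sum to X·Y (mod 2^L).

```python
# Insecure single-process simulation of the OT-based product sharing.

import secrets

L = 64
MOD = 1 << L

def product_shares(x, y):
    share0 = 0  # accumulated by Party0 (the OT receiver)
    share1 = 0  # accumulated by Party1 (the OT sender)
    for i in range(L):
        x_i = (x >> i) & 1               # i-th bit of X, Party0's OT choice
        r0 = secrets.randbelow(MOD)      # Party1's random mask R_i^0
        r1 = ((1 << i) * y + r0) % MOD   # R_i^1 = 2^i·Y + R_i^0
        share0 = (share0 + (r1 if x_i else r0)) % MOD  # OT output R_i^{X_i}
        share1 = (share1 - r0) % MOD
    return share0, share1
```

Summing the selected messages gives Σ (X_i·2^i·Y + R_i^0) = X·Y + Σ R_i^0, so subtracting Σ R_i^0 on Party1's side leaves exactly X·Y split across the two shares.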
Algorithm 5: Multiplication-Triple-Generation
Input: (⊥; ⊥)
Output: ([a]0 , [b]0 , [c]0 ;[a]1 , [b]1 , [c]1 ) such that c = ab
begin
for i = 0, 1 do
Partyi randomly generates [a]i , [b]i ;
Party0 and Party1 run Product-Secret-Sharing([a]0 ; [b]1 ) and obtain
[[a]0 [b]1 ]0 and [[a]0 [b]1 ]1 , respectively;
Party0 and Party1 run Product-Secret-Sharing([b]0 ; [a]1 ) and obtain
[[b]0 [a]1 ]0 and [[b]0 [a]1 ]1 , respectively;
for i = 0, 1 do
    Partyi outputs [c]i = [[a]0 [b]1 ]i + [[b]0 [a]1 ]i + [a]i [b]i ;
Algorithm 6: Secret-Sharing-Multiplication
Input: ([x]0 , [y]0 , [a]0 , [b]0 , [c]0 ; [x]1 , [y]1 , [a]1 , [b]1 , [c]1 ) where c = ab
Output: ([xy]0 ; [xy]1 )
begin
if ([a]0 = ⊥) ∧ ([a]1 = ⊥) then
    for i = 0, 1 do
        Partyi runs algorithm Multiplication-Triple-Generation with
        input ⊥ and obtains [a]i , [b]i , [c]i ;
for i = 0, 1 do
    Partyi computes [e]i = [x]i − [a]i , [f ]i = [y]i − [b]i ;
for j ∈ {e, f } do
    Party0 and Party1 run algorithm Recovery([j]0 ; [j]1 ) and both
    obtain j;
for i = 0, 1 do
    Partyi outputs [xy]i = e[y]i + f [x]i + [c]i − i · e · f ;
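A minimal sketch of Beaver-triple multiplication in the spirit of Algorithm 6, with both parties simulated locally and the triple produced by a trusted dealer instead of Algorithm 5:

```python
# Sketch of secure multiplication on additive shares in Z_{2^L}.
# The triple (a, b, c = ab) comes from a stand-in trusted dealer here.

import secrets

L = 64
MOD = 1 << L

def share(v):
    r = secrets.randbelow(MOD)
    return (v - r) % MOD, r

def beaver_multiply(x_sh, y_sh):
    # Dealer: generate and share a multiplication triple.
    a = secrets.randbelow(MOD)
    b = secrets.randbelow(MOD)
    a_sh, b_sh, c_sh = share(a), share(b), share((a * b) % MOD)
    # Each party opens e = x − a and f = y − b (these reveal nothing about x, y).
    e = sum((x_sh[i] - a_sh[i]) % MOD for i in range(2)) % MOD
    f = sum((y_sh[i] - b_sh[i]) % MOD for i in range(2)) % MOD
    # [xy]_i = e·[y]_i + f·[x]_i + [c]_i − i·e·f (only Party1 subtracts e·f).
    return tuple(
        (e * y_sh[i] + f * x_sh[i] + c_sh[i] - i * e * f) % MOD
        for i in range(2)
    )
```

Summing the two outputs gives e·y + f·x + c − e·f = xy (mod 2^L), matching the identity behind the last line of Algorithm 6.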
Linear Regression. The linear regression model equation aims to predict the
output y as a function of the input x and a weight w. The model equation is
given by:
y = xw
Here, y is the predicted value, x is the input variable, and w is the weight
parameter that we aim to learn through training. To obtain the optimal weight
parameter w, one usually uses the batch gradient descent method, which involves
a loss function and an update function that iteratively refines the current
weight parameter w.
The loss function measures the difference between the predicted output of
the model and the actual target values. For linear regression, a common choice
is the Mean Squared Error (MSE) loss function, which is computed as:
J(w) = (1/(2m)) Σ_{i=1}^{m} (xi w − yi)²
In this equation, m is the number of observations, xi w is the predicted value
for the i-th observation, yi is the actual target value for the i-th observation,
and J(w) represents the loss function computed over all observations.
To minimize the loss function, we use an optimization algorithm called gradient
descent. The update function for the weight parameter w at each iteration is
given by:

w = w − α ∂J(w)/∂w

Here, α is the learning rate, a hyperparameter that defines the step size
during the update. The partial derivative of the loss function with respect to
the weight w is computed as:

∂J(w)/∂w = (1/m) Σ_{i=1}^{m} xi (xi w − yi)

Integrating the derivative back into the update function, we get:

w = w − α (1/m) Σ_{i=1}^{m} xi (xi w − yi)
By iteratively applying this update rule, the weight w is adjusted to minimize
the loss J(w), moving towards the optimal weight that best fits the data under
the linear regression model. Algorithm 7 presents batch gradient descent for
linear regression with a bias term.
regression with bias term.
Algorithm 7: Plain-Linear-Regression
Input: Training dataset (xi , yi )m
i=1 , learning rate α, number of epochs
E, batch size B
Output: Weight parameter w after each update
Add a column of ones to X to form X̃a ;
Initialize weight w;
for e = 1, 2, · · · , E do
Shuffle the dataset X̃;
for each batch index B ⊆ [0, 1, . . . , m − 1] of size B do
Compute the gradient: ∇J(w) = B1 X̃TB (X̃B w − YB );
Update the weight: w = w − α∇J(w);
Output the updated weight w;
a
A column of ones is added to the input matrix X, resulting in an augmented
matrix X̃. This allows the linear regression model to include the bias term b
by treating it as part of the weight vector w.
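A minimal pure-Python sketch of the batch gradient descent in Algorithm 7, simplified to one input feature and full-batch updates, with the bias b handled explicitly instead of via the appended column of ones:

```python
# Full-batch gradient descent for y = x·w + b, one feature, pure Python.

def plain_linear_regression(xs, ys, alpha=0.1, epochs=2000):
    w, b = 0.0, 0.0
    m = len(xs)
    for _ in range(epochs):
        # Gradient of J(w) = (1/(2m)) * sum((x_i*w + b - y_i)^2)
        grad_w = sum(x * (x * w + b - y) for x, y in zip(xs, ys)) / m
        grad_b = sum((x * w + b - y) for x, y in zip(xs, ys)) / m
        w -= alpha * grad_w
        b -= alpha * grad_b
    return w, b
```

On points drawn from y = 2x + 1 the routine recovers w ≈ 2 and b ≈ 1.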
For more details of regression models in machine learning, please refer to [29].
Remark 1. The L2 regularization term is applied within the update process, with
a coefficient of α dropped.
The loss function J(w) includes L1 regularization terms |wj |, which are
non-differentiable. Therefore, the classical gradient descent method cannot be
used to solve for the function parameters that minimize the loss function. In
this paper, we use the proximal gradient descent (PGD) method for solving [28].
For the non-differentiable terms, we use the proximal gradient method to solve
for the weight parameter w. The update rule for the weight parameter w is:

w = Prox_λ( w − α (1/m) Σ_{i=1}^{m} xi (xi w − yi) ),   (2)
where the proximal function Prox_λ, associated with regularization parameter λ,
is defined component-wise as:

Prox_λ(wi) = { wi − λ, if wi > λ,
               0,      if |wi| ≤ λ,
               wi + λ, if wi < −λ.
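The proximal operator above is exactly the soft-thresholding function; component-wise it can be sketched as:

```python
# Soft-thresholding: shrink each weight toward zero by λ, clamping small
# weights to exactly zero (this is what induces Lasso's sparsity).

def prox(w, lam):
    return [wi - lam if wi > lam else wi + lam if wi < -lam else 0.0
            for wi in w]
```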
Algorithm 9: Plain-Lasso-Regression
Input: Training dataset (xi , yi )m i=1 , regularization parameter λ, learning
rate α, number of epochs E, batch size B
Output: Weight parameter w for each batch
Add a column of ones to X to form X̃;
Initialize weight w;
for e = 1, 2, · · · , E do
Shuffle the dataset X̃;
for each batch B ⊆ [0, 1, . . . , m − 1] of size B do
Compute the gradient: ∇J(w) = B1 X̃TB (X̃B w − YB );
Update the weight for the first time: w = w − α∇J(w);
Update the weight for the second time: w = Proxλ (w);
Output the intermediate weight w;
In Algorithm 11, one party selects partial candidate parameters λi for the
other party, who selects the final parameter. The candidate parameters are
selected based on the Mean Squared Error (MSE) on the test dataset using the
cross-validation method. Algorithm 10 is run offline, i.e., the two parties
do not communicate with each other during the parameter selection process. In
practice, the candidate parameters λ0 can be determined based on experience.
To standardize the data, each party Partyi should add a column of ones to Xi ,
creating X̃i . They then use a secret sharing scheme (Algorithm 1) to distribute
X̃i and Yi to the other party Partyj , represented as X̃i = [X̃i ]0 + [X̃i ]1 and
Yi = [Yi ]0 + [Yi ]1 . Each party Partyi combines [X̃i ]i , [X̃j ]i , [Yi ]i , and [Yj ]i to
for i = 0, 1 do
    Party_i runs algorithm Recovery with input [w]_i and obtains w_i;
    Party_i outputs w_i;
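The two-out-of-two additive sharing used above (X̃_i = [X̃_i]_0 + [X̃_i]_1) can be sketched as follows; the ring modulus Q is an illustrative assumption, since the scheme's actual parameters are not shown in this excerpt.

```python
import secrets

Q = 2**61 - 1  # assumed public modulus for the sharing ring

def share(x):
    """Split x into two additive shares with x = s0 + s1 (mod Q).
    Either share alone is uniformly random and reveals nothing about x."""
    s0 = secrets.randbelow(Q)
    return s0, (x - s0) % Q

def recover(s0, s1):
    """Recovery step: combine both shares to reconstruct the secret."""
    return (s0 + s1) % Q

# Party_i shares a value; only when both shares meet is it reconstructed
s0, s1 = share(42)
assert recover(s0, s1) == 42
```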
If for any adversary A attacking the real protocol π, there exists a simulator S
attacking the ideal protocol F_ml such that no environment Z can distinguish the
real execution from the ideal one, then the protocol is UC-secure.
    RMSE = (M̃SE − MSE) / MSE,    (3)

where M̃SE represents the average loss of the training data under the privacy-
preserving framework, and MSE represents the average loss of the training data
under plaintext calculation. The definition of the average loss is consistent with
the relevant definition in Sect. 2.2 and Algorithm 10, which is the average
of the squared differences between all predicted labels and the corresponding
true labels of the test data. The closer the RMSE value is to zero, the more
similar the training results of the model under the framework are to the plaintext
training results.
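Eq. (3) is a one-line computation (note that the paper's "RMSE" denotes a relative error, not a root mean square); a minimal sketch, with illustrative loss values:

```python
def relative_error(mse_private, mse_plain):
    """Eq. (3): RMSE = (MSE~ - MSE) / MSE, the relative gap between the
    privacy-preserving training loss (MSE~) and the plaintext loss (MSE)."""
    return (mse_private - mse_plain) / mse_plain

# a gap on the order of 1e-4 means the two trained models are essentially identical
r = relative_error(mse_private=1.0001, mse_plain=1.0)
```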
The participants Party0 and Party1 used the privacy-preserving Ridge regres-
sion algorithm framework proposed in this paper to train on both synthetic sam-
ple datasets and publicly available datasets from the UCI database, obtaining
and testing Ridge regression models. Simultaneously, plaintext Ridge regression
training and testing were also performed under the same conditions. Table 1
records the test results. All the RMSE values in the table are on the order of
10^{-4} or smaller, which suggests that the difference between the results of the
privacy-preserving algorithm framework and the plaintext algorithm is minimal.
Essentially, no information was lost during the privacy-preserving learning pro-
cess, and the Ridge regression model achieved almost the same accuracy as the
one under plaintext conditions.
Fig. 1. Time and communication for privacy-preserving Ridge regression under differ-
ent sample sizes.
Figure 1 shows the running time and communication of the privacy-preserving
Ridge regression framework under different sample sizes. Participants Party_0
and Party_1 input their synthetic sample datasets, with sample sizes of 1000,
10000 and 100000 and a feature count of 10, into the framework for training
Ridge regression models. The test results are shown in Fig. 1. It
can be observed that both the running time and communication increase as the
sample size increases. In the offline phase, especially for the group with a sample
size of 100000, the communication reaches 1738 MB. However, since no actual
data participates in the offline phase, this material can be pre-generated,
making the cost acceptable.
For Different Numbers of Features
Participants Party_0 and Party_1 input their respective synthetic datasets, originally
with 10,000 samples and with feature numbers of 10, 20, 30, and 40, into the
framework for training Ridge regression models.
The Boston Housing Prices dataset and the High School Student Perfor-
mance dataset were input into this paper’s computational framework to train
Ridge regression models and evaluate their performance, as shown in Table 2.
Additionally, baseline Ridge regression was performed under the same conditions
for comparison.
The test results are shown in Table 2. For the offline phase of the Boston
Housing Prices dataset, the total time was 5.76 s, and for the High School Student
Performance dataset with a training level plan, the total cost was 39.76 s. Since
the regularization parameter selection step was not included, the online phase
times in Table 2 exclude the time for this selection.
Comparison. The comparative results shown in Table 3 indicate that our
method has certain advantages in terms of runtime, and although the communi-
cation increases, the majority occurs during the offline phase. If conditions allow
for a reduction in the cubic term scaling in the offline phase, this framework
provides a good choice for privacy-preserving Ridge regression.
Theorem 2. The protocol (Algorithm 13) is UC-secure with respect to the func-
tionality (Algorithm 9) against semi-honest adversaries.
Proof. The proof is similar to that of the privacy-preserving Ridge regres-
sion algorithm. The only difference is that the simulator S outputs the weight
w in each batch, as required by the ideal functionality. Hence the execution is still
indistinguishable for any environment Z. Therefore, this protocol is UC-secure.
Remark 5. There is an attack called leakage from gradients [44,48–50] on differ-
entiable learning models such as linear regression. This attack exploits the fact
that the input data can be reconstructed from the gradients of the loss function
with respect to the model parameters. In Algorithm 13, Party_i knows the gradient of
the loss function with respect to the model parameters, which can be used to
infer the input data. However, the functionality of Algorithm 9 shows that it is
acceptable. Therefore, we can still use the privacy-preserving Ridge regression
algorithm to train the model with respect to Algorithm 9.
Fig. 3. Time and communication vs. different numbers of samples and features for
privacy-preserving Lasso regression
The results of the model testing are shown in Fig. 3. The left graph illustrates
the running time and communication of the privacy-preserving Lasso regression
protocol during the online phase under different sample sizes, and the right graph
shows the same under different feature sizes. It can be seen that the running time
and communication of this framework vary as the number of features in the
training dataset increases, and there is also a linear dependence on the sample
size. From the results, it is evident that the privacy-preserving Lasso regres-
sion learning framework has a greater overhead in the online phase compared
to the privacy-preserving Ridge regression learning framework when training on
the same dataset. For instance, when training on a synthetic dataset of size
(100000, 10), the former requires 18.87 s and produces 15.57 MB of communica-
tion overhead, while the latter requires 17.39 s and 15.41 MB. This indicates that
if the training dataset does not contain irrelevant features and there is no need
for feature selection or training a sparse model, the privacy-preserving Ridge
regression learning framework is a better choice.
Comparison. In [40], the privacy-preserving Lasso regression scheme of van
Egmond et al., implemented on the MPyC framework and based on Shamir's
secret sharing, requires approximately 2000 s to complete the entire process on
a dataset of 10,000 samples with 10 features; of this, more than 1500 s are
spent on the training phase. In contrast, our scheme requires only 5.57 s for
training on a dataset of the same size, a substantial performance improvement.
4 Conclusion
In this paper, we present two privacy-preserving regression algorithms, one for
Ridge regression and one for Lasso regression, based on Secure Multi-Party
Computation.
Acknowledgments. The research of Bo Li, Wanhui Zhang and Zhiqiang Lin was
supported by the Guangzhou Municipal Construction Group Co., Ltd. Technology
Plan Project (2022-KJ023). The research of Lishan Ke was supported by the City
School Joint Funding Project of Guangzhou City (No. 2023A03J0117). The research
of Zongxiang Yi was supported by the Talent Special Project of Research Project
of Guangdong Polytechnic Normal University [Grant No. 2021SDKYA051], Scientific
Research Capacity Improvement Project of the Doctoral Program Construction Unit
of Guangdong Polytechnic Normal University [Grant No. 22GPNUZDJS31].
References
1. Agrawal, N., Shahin Shamsabadi, A., Kusner, M.J., Gascón, A.: Quotient: two-
party secure neural network training and prediction. In: Proceedings of the 2019
ACM SIGSAC Conference on Computer and Communications Security, pp. 1231–
1247 (2019)
2. Badrinarayanan, S., Masny, D., Mukherjee, P.: Efficient and tight oblivious transfer
from PKE with tight multi-user security. In: International Conference on Applied
Cryptography and Network Security, pp. 626–642. Springer, Cham (2022). https://
doi.org/10.1007/978-3-031-09234-3_31
3. Beaulieu-Jones, B.K., et al.: Privacy-preserving generative deep neural networks
support clinical data sharing. Circ. Cardiovasc. Qual. Outcomes 12(7), e005122
(2019). https://doi.org/10.1161/circoutcomes.118.005122
4. Beaver, D.: Efficient multiparty protocols using circuit randomization. In: Advances
in Cryptology-CRYPTO 1991: Proceedings 11, pp. 420–432. Springer (1992)
5. Bogdanov, D., Laur, S., Willemson, J.: Sharemind: a framework for fast privacy-
preserving computations. In: Jajodia, S., Lopez, J. (eds.) ESORICS 2008. LNCS,
vol. 5283, pp. 192–206. Springer, Heidelberg (2008). https://doi.org/10.1007/978-
3-540-88313-5_13
6. Bogdanov, D., Niitsoo, M., Toft, T., Willemson, J.: High-performance secure multi-
party computation for data mining applications. Int. J. Inf. Secur. 11, 403–418
(2012). https://doi.org/10.1007/s10207-012-0177-2
7. Branco, P., Fiolhais, L., Goulão, M., Martins, P., Mateus, P., Sousa, L.: Roted:
random oblivious transfer for embedded devices. IACR Trans. Cryptogr. Hardw.
Embed. Syst. 215–238 (2021)
8. Byali, M., Chaudhari, H., Patra, A., Suresh, A.: Flash: fast and robust frame-
work for privacy-preserving machine learning. Cryptology ePrint Archive (2019).
https://doi.org/10.2478/popets-2020-0036
9. Byun, J., Lee, W., Lee, J.: Parameter-free he-friendly logistic regression. Adv.
Neural. Inf. Process. Syst. 34, 8457–8468 (2021)
10. Canetti, R.: Universally composable security: a new paradigm for cryptographic
protocols. In: Proceedings 42nd IEEE Symposium on Foundations of Computer
Science, pp. 136–145. IEEE (2001). https://doi.org/10.1109/sfcs.2001.959888
11. Chaudhari, H., Rachuri, R., Suresh, A.: Trident: efficient 4PC framework for pri-
vacy preserving machine learning. arXiv preprint arXiv:1912.02631 (2019). https://
doi.org/10.14722/ndss.2020.23005
12. Chen, Y., Zhang, J.: The role of pseudorandomness in cryptographic protocols.
ACM Trans. Inf. Syst. Secur. 26(1), 1–25 (2023). https://doi.org/10.1145/3602978
13. Cortez, P., Silva, A.M., et al.: Using data mining to predict secondary school stu-
dent performance. In: Proceedings of 5th Annual Future Business Technology Con-
ference (2008)
14. De Cock, M., Dowsley, R., Nascimento, A., Railsback, D., Shen, J., Todoki, A.: Fast
secure logistic regression for high dimensional gene data. In: Privacy in Machine
Learning (PriML2019). Workshop at NeurIPS, pp. 1–7 (2019)
15. Demmler, D., Schneider, T., Zohner, M.: ABY - a framework for efficient mixed-
protocol secure two-party computation. In: NDSS (2015). https://doi.org/10.
14722/ndss.2015.23113
16. Dua, D., Graff, C.: UCI machine learning repository: housing data set (1997).
http://archive.ics.uci.edu/ml/datasets/housing. Accessed 27 Sept 2024
17. Dwork, C., McSherry, F., Nissim, K., Smith, A.: Calibrating noise to sensitivity in
private data analysis. In: Halevi, S., Rabin, T. (eds.) TCC 2006. LNCS, vol. 3876,
pp. 265–284. Springer, Heidelberg (2006). https://doi.org/10.1007/11681878_14
18. Dwork, C., Roth, A., et al.: The algorithmic foundations of differential privacy.
Found. Trends Theor. Comput. Sci. 9(3–4), 211–407 (2014). https://doi.org/10.
1561/9781601988195
19. Gascón, A., Schoppmann, P., Balle, B., Raykova, M., Doerner, J., Zahur, S., Evans,
D.: Privacy-preserving distributed linear regression on high-dimensional data. In:
Proceedings on Privacy Enhancing Technologies, pp. 345–364 (2017)
20. Giacomelli, I., Jha, S., Joye, M., Page, C.D., Yoon, K.: Privacy-preserving
ridge regression with only linearly-homomorphic encryption. In: Preneel, B., Ver-
cauteren, F. (eds.) ACNS 2018. LNCS, vol. 10892, pp. 243–261. Springer, Cham
(2018). https://doi.org/10.1007/978-3-319-93387-0_13
21. Goldreich, O.: Foundations of Cryptography: Basic Tools, vol. 1. Cambridge
University Press (2001). ISBN 9780521791725. https://www.cambridge.org/core/
books/foundations-of-cryptography/0D0CAE9CA377D645DB9BDCCD743F7B27
22. Goldreich, O., Krawczyk, H., Luby, M.: On the existence of pseudorandom gener-
ators. In: Proceedings of the 18th Annual ACM Symposium on Theory of Com-
puting, pp. 12–24. ACM (1986). https://doi.org/10.1109/sfcs.1988.21917
23. Harrison, D., Rubinfeld, D.L.: Boston housing dataset. StatLib - Carnegie Mel-
lon University (1978). https://lib.stat.cmu.edu/datasets/boston. Accessed 27 Sept
2024
24. Harrison, D., Jr., Rubinfeld, D.L.: Hedonic prices and the demand for clean
air. J. Environ. Econ. Manag. 5(1), 81–102 (1978). https://doi.org/10.1016/0095-
0696(78)90006-2
25. Hoerl, A.E., Kennard, R.W.: Ridge regression: biased estimation for nonorthogonal
problems. Technometrics 12(1), 55–67 (1970). https://doi.org/10.2307/1267351
26. Intangible, D.: Racist data destruction (2020). https://medium.com/
@docintangible/racist-data-destruction-113e3eff54a8. Accessed 27 Sept 2024
27. Jain, A., Singh, R.: A survey of pseudorandom number generators in cryptogra-
phy. Cryptography 6(2), 12 (2022). https://doi.org/10.3390/cryptography6020012.
https://www.mdpi.com/2410-387X/6/2/12
28. Klosa, J., Simon, N., Westermark, P.O., Liebscher, V., Wittenburg, D.: Seagull:
lasso, group lasso and sparse-group lasso regularization for linear regression models
via proximal gradient descent. BMC Bioinform. 21, 1–8 (2020). https://doi.org/
10.1186/s12859-020-03725-w
29. Kumar, S., Bhatnagar, V.: A review of regression models in machine learning. J.
Intell. Syst. Comput. 3(1), 40–47 (2022)
30. Mohassel, P., Rindal, P.: ABY3: a mixed protocol framework for machine learning.
In: Proceedings of the 2018 ACM SIGSAC Conference on Computer and Commu-
nications Security, pp. 35–52 (2018)
31. Mohassel, P., Zhang, Y.: Secureml: a system for scalable privacy-preserving
machine learning. In: 2017 IEEE Symposium on Security and Privacy (SP), pp.
19–38. IEEE (2017). https://doi.org/10.1109/sp.2017.12
32. Patra, A., Schneider, T., Suresh, A., Yalame, H.: {ABY2. 0}: improved {Mixed-
Protocol} secure {Two-Party} computation. In: 30th USENIX Security Symposium
(USENIX Security 2021), pp. 2165–2182 (2021)
33. Raff, E., Khanna, A., Lu, F.: Scaling up differentially private lasso regularized logis-
tic regression via faster frank-wolfe iterations. In: Advances in Neural Information
Processing Systems, vol. 36 (2024)
34. Rouhani, B.D., Riazi, M.S., Koushanfar, F.: Deepsecure: scalable provably-secure
deep learning. In: Proceedings of the 55th Annual Design Automation Conference,
pp. 1–6 (2018). https://doi.org/10.1109/dac.2018.8465894
35. Sarkar, E., Chielle, E., Gursoy, G., Chen, L., Gerstein, M., Maniatakos, M.:
Privacy-preserving cancer type prediction with homomorphic encryption. Sci. Rep.
13(1), 1661 (2023)
36. Scikit-learn. Boston housing dataset - scikit-learn API (2024). https://scikit-learn.
org/1.1/modules/generated/sklearn.datasets.load_boston.html. Accessed 27 Sept
2024
37. Shi, H., et al.: Secure multi-party computation grid logistic regression (SMAC-
GLORE). BMC Med. Inform. Decis. Mak. 16, 175–187 (2016). https://doi.org/10.
1186/s12911-016-0316-1
38. TensorFlow. Boston housing dataset - tensorflow keras API (2024). https://www.
tensorflow.org/api_docs/python/tf/keras/datasets/boston_housing/load_data.
Accessed 27 Sept 2024
39. Tibshirani, R.: Regression shrinkage and selection via the lasso. J. R. Stat. Soc.
Ser. B (Stat. Methodol.) 58(1), 267–288 (1996). https://doi.org/10.1111/j.2517-
6161.1996.tb02080.x
40. van Egmond, M.B., et al.: Privacy-preserving dataset combination and lasso regres-
sion for healthcare predictions. BMC Med. Inform. Decis. Making 21, 1–16 (2021)
41. Veugen, T., Kamphorst, B., van de L’Isle, N., van Egmond, M.B.: Privacy-
preserving coupling of vertically-partitioned databases and subsequent training
with gradient descent. In: Dolev, S., Margalit, O., Pinkas, B., Schwarzmann, A.
(eds.) CSCML 2021. LNCS, vol. 12716, pp. 38–51. Springer, Cham (2021). https://
doi.org/10.1007/978-3-030-78086-9_3
42. Wagh, S., Tople, S., Benhamouda, F., Kushilevitz, E., Mittal, P., Rabin, T.: Fal-
con: honest-majority maliciously secure framework for private deep learning. arXiv
preprint arXiv:2004.02229 (2020). https://doi.org/10.2478/popets-2021-0011
43. Wang, P., Zhang, H.: Differential privacy for sparse classification learning. Neuro-
computing 375, 91–101 (2020). https://doi.org/10.1016/j.neucom.2019.09.020
44. Wei, W., Liu, L.: Gradient leakage attack resilient deep learning. IEEE Trans. Inf.
Forensics Secur. 17, 303–316 (2021). https://doi.org/10.1109/tifs.2021.3139777
45. Wu, Y., Jiang, X., Kim, J., Ohno-Machado, L.: Grid binary LOgistic REgression
(GLORE): building shared models without sharing data. J. Am. Med. Inform.
Assoc. 19(5), 758–764 (2012). https://doi.org/10.1136/amiajnl-2012-000862
46. Xie, L., Lin, K., Wang, S., Wang, F., Zhou, J.: Differentially private generative
adversarial network. arXiv preprint arXiv:1802.06739 (2018)
47. Yadav, V.K., Andola, N., Verma, S., Venkatesan, S.: A survey of oblivious transfer
protocol. ACM Comput. Surv. (CSUR) 54(10s), 1–37 (2022). https://doi.org/10.
1145/3503045
48. Zhao, B., Mopuri, K.R., Bilen, H.: IDLG: improved deep leakage from gradients.
arXiv preprint arXiv:2001.02610 (2020)
49. Zheng, Y.: Dropout against deep leakage from gradients. arXiv preprint
arXiv:2108.11106 (2021)
50. Zhu, L., Liu, Z., Han, S.: Deep leakage from gradients. In: Advances in Neural
Information Processing Systems, vol. 32 (2019). https://doi.org/10.1007/978-3-
030-63076-8_2
A Decentralized Bitcoin Mixing Scheme
Based on Multi-signature
1 Introduction
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 43–55, 2025.
https://doi.org/10.1007/978-981-96-4566-4_3
44 M. Shen et al.
out unauthorized transactions on his own because he needs the signature confir-
mation of other participants [6]. This effectively prevents any single participant
from attempting to spoof or unilaterally manipulate funds. The Bitcoin mixing
scheme adopted in this paper ingeniously combines multi-signature technology,
allowing multiple parties to jointly control funds, effectively mitigating the risk
of a single account being compromised, and enhancing transaction anonymity
[7]. The Bitcoin blockchain maintains its core decentralization characteristics by
integrating multi-signatures with decentralized signature protocols while ensur-
ing transaction privacy [8]. In this paper, we propose a Bitcoin mixing scheme.
Specifically, the contributions of this paper are as follows:
We propose a novel mechanism for generating a public address that ensures
that all members of the mixing group participate in the process of generating
public addresses through multiple interactions.
Multi-signature technology is introduced in this mixing scheme, which
ensures that transactions are confirmed by all participants before they are exe-
cuted. The multi-signature incorporates hash values and indices to
resist rogue key attacks, which effectively prevents malicious participants from deceiv-
ing other participants by generating special public key pairs.
Compared with traditional decentralized coin mixing schemes, the scheme
proposed in this paper demonstrates significant efficiency advantages in process-
ing mixing transactions. The scheme reduces the communication and computation
overhead during the transaction process through an optimized
transaction negotiation and confirmation process.
2 Related Works
2.1 Centralized Mixing Protocols
The earliest Bitcoin mixing schemes relied on centralized mixing services. These
services aggregate users’ Bitcoins into a single pool for mixing and then redis-
tribute them to the users, thus obscuring the origin of the funds. Coinlayering [9]
is a centralized coin mixing scheme aimed at protecting privacy by mixing trans-
actions of multiple users, allowing users to select multiple available “mixers” and
use them to perform randomized transactions to protect their identity privacy.
However, Coinlayering also has some drawbacks. Firstly, it relies on trust
in the mixer, where users need to trust that the mixer will not disclose their
privacy or participate in attacks. Secondly, due to the centralized design, there
is a risk of a single point of failure in this scheme, and once the coin mixer is
attacked or fails, the system will be unable to operate. Finally, Coinlayering has
poor resistance to censorship, and coin mixers may review or reject transactions
from certain users under regulatory pressure.
3 Preliminaries
3.1 Bilinear Pairing
Let G_1, G_2 and G_T be multiplicative cyclic groups of prime order q. A mapping
e : G_1 × G_2 → G_T is a bilinear map if it satisfies the following three properties.
1. Bilinearity: for all u ∈ G_1, v ∈ G_2, and a, b ∈ Z_q: e(u^a, v^b) = e(u, v)^{ab}.
2. Non-degeneracy: there exist u ∈ G_1 and v ∈ G_2 such that e(u, v) ≠ 1_{G_T}.
3. Computability: e(u, v) can be computed efficiently for all u ∈ G_1, v ∈ G_2.
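Bilinearity can be checked numerically with a toy bilinear map built from exponent groups. This construction is cryptographically useless (discrete logs in it are trivial) and every parameter below is illustrative, but it does satisfy the properties above.

```python
# Toy bilinear map (illustrative only, NOT secure): write G1 = G2 = Z_q
# additively (so "u^a" becomes a*u mod q), and take G_T to be the order-q
# subgroup of Z_p*. Then e(u, v) = gT^(u*v mod q) is bilinear.
p, q = 1019, 509              # q is prime and q | p - 1
gT = pow(2, (p - 1) // q, p)  # generator of the order-q subgroup of Z_p*

def e(u, v):
    return pow(gT, (u * v) % q, p)

u, v, a, b = 3, 7, 11, 13
assert e(a * u % q, b * v % q) == pow(e(u, v), a * b, p)  # bilinearity
assert e(1, 1) != 1                                       # non-degeneracy
```

Real schemes use pairings on elliptic curves, where the map is efficiently computable but the discrete logarithm problem remains hard.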
3.2 Multi-signatures
A multi-signature scheme [14] enables n parties to collectively produce a single sig-
nature σ on a message M. The scheme typically involves four algorithms: Setup,
KeyGen, Sign and Verify, described below:
1. Setup(1^λ) → param: takes as input the security parameter λ and
outputs a system parameter param.
2. KeyGen(param) → (sk, pk): takes param as input
and outputs a user's private key sk and public key pk.
The system model of the mixing transaction process is shown in Fig. 1. First,
each participant broadcasts or otherwise publishes a request to create or join a
mixing group, forming a temporary mixing group. The group members then
collectively generate a public address, to which all Bitcoins must be sent. Next,
each node splits its genuine Bitcoin mixing request into several sub-requests
and secretly sends them to other group members while receiving sub-requests
from different nodes. All nodes in the group gather the sub-requests and verify
the transaction, ensuring that all requests are fulfilled and that the amounts
match. Once the transaction is validated, all nodes sign the transaction, perform
a hashing operation, and generate the transaction hash. Finally, this finalized
transaction is broadcast to the Bitcoin network, where it awaits confirmation by
miners. After the miners confirm the transaction, it is packaged into a new block
and ultimately recorded publicly on the blockchain.
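The request-splitting step described above can be sketched as follows; the function names, satoshi amounts, and number of sub-requests are illustrative assumptions.

```python
import secrets

def split_request(amount_sats, k):
    """Split one mixing request of amount_sats into k sub-request amounts
    that sum to the total (sketch of the request-splitting step)."""
    # choose k-1 random cut points in [0, amount]; the gaps are the parts
    cuts = sorted(secrets.randbelow(amount_sats + 1) for _ in range(k - 1))
    return [hi - lo for lo, hi in zip([0] + cuts, cuts + [amount_sats])]

def amounts_match(sub_requests, total):
    """Group-side verification: the gathered sub-requests must cover the
    full requested amount."""
    return sum(sub_requests) == total

parts = split_request(100_000, 4)   # one node's request, split four ways
assert amounts_match(parts, 100_000)
```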
Based on the above system model and the security model, the design goals of
CMMS are as follows:
Decentralization: Implement a decentralized mixing mechanism that does not
rely on any centralized mixing service provider, thereby avoiding single points
of failure and reducing centralization risks.
Multi-signature: Leverage multi-signature technology, which not
only reduces the size of signatures but also significantly improves the system's
scalability.
Resistance to Rogue Key Attacks: The multi-signature mechanism used in our
scheme is designed to resist rogue key attacks, ensuring the validity and integrity
of the signatures even in the presence of malicious actors attempting to manip-
ulate the signing process.
5 Proposed Scheme
When participants choose to engage in a mixing operation, they can use Bitcoin’s
public broadcast channels to announce their intention to join the mix. Each
participant has the choice of either starting a new mixing group and inviting
others or joining an existing group created by other users. This collaboration
helps in obscuring the origins of transactions.
After a mixing group has been created, all of the bitcoins from members of the
group who want to participate in the mix must be sent to a public address that
they generate together. Our method for generating public addresses is described
in Algorithm 1.
Node_i creates a new account as the output address and generates several different
negotiation requests. Each negotiation request includes all output addresses and
the amount of bitcoins involved. Node_i randomly selects different nodes to send
each negotiation request to, while also receiving requests from other nodes. When
Node_j receives a negotiation request from Node_i, Node_j determines whether it
has enough bitcoins to fulfill Node_i's request. If it can fulfill the request, Node_j
generates a message indicating that it has enough bitcoins to meet Node_i's
requirement and broadcasts this message to the mixing group.
After completing the above negotiations, each Node_i will verify the validity of
the information. First, Node_i checks whether its output address has received the
correct amount of bitcoins. Second, by comparing the total bitcoins in all the
satisfied and partially satisfied messages with the total requested in all mixing
requests, it determines whether any requests have been overlooked.
6 Security Analysis

6.1 Correctness

    apk = ∏_{i=1}^{m} pk_i^{a_i} = ∏_{i=1}^{m} g_2^{a_i·sk_i} = g_2^{Σ_{i=1}^{m} a_i·sk_i}

    e(Sign, g_2^{−1}) = e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i·sk_i}, g_2^{−1})

    e(H_0(Mix-Trans), apk) = e(H_0(Mix-Trans), g_2^{Σ_{i=1}^{m} a_i·sk_i})

We can get

    e(Sign, g_2^{−1}) · e(H_0(Mix-Trans), apk)
    = e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i·sk_i}, g_2^{−1}) · e(H_0(Mix-Trans)^{Σ_{i=1}^{m} a_i·sk_i}, g_2)
    = 1_{G_T}
The rogue key attack [14] is a common attack in cryptography. The attacker
maliciously creates a special key pair so that their public key has a certain
relationship with the public keys of other participants. This allows the attacker
to deceive participants into participating unknowingly in the signing process.
First, in this scheme,

    apk = ∏_{i=1}^{m} pk_i^{a_i},  where a_i = H_1(pk_i, {pk_1, ..., pk_m}).
This public key aggregation approach uses the hash function H1 to bind the
public key of each participant. An attacker cannot make his public key some-
how maliciously related to the public keys of other participants by generating
special key pairs, because the public key aggregation formula ensures that the
contribution of each public key is independent and related to the set of all public
keys.
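The key-aggregation identity apk = ∏ pk_i^{a_i} = g^{Σ a_i·sk_i} can be sanity-checked numerically in a plain multiplicative group Z_p* (no pairing is needed for this particular identity). Deriving a_i with SHA-256 and the specific p and g below are illustrative assumptions, not the scheme's actual parameters.

```python
import hashlib
import secrets

p = (1 << 61) - 1   # a Mersenne prime; Z_p* is cyclic of order p - 1
g = 3               # illustrative base element

def h1(pk, pks):
    """a_i = H1(pk_i, {pk_1..pk_m}) as an exponent, via SHA-256 (assumed encoding).
    Each coefficient depends on the WHOLE set of public keys, which is what
    thwarts rogue-key relations between keys."""
    data = repr((pk, sorted(pks))).encode()
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % (p - 1)

m = 3
sks = [secrets.randbelow(p - 1) for _ in range(m)]   # private keys sk_i
pks = [pow(g, sk, p) for sk in sks]                  # public keys pk_i = g^sk_i
coeffs = [h1(pk, pks) for pk in pks]                 # a_i = H1(pk_i, {pk_1..pk_m})

# apk = prod pk_i^{a_i} equals g^{sum a_i * sk_i}, the identity used in Sect. 6.1
apk = 1
for pk, a in zip(pks, coeffs):
    apk = apk * pow(pk, a, p) % p
assert apk == pow(g, sum(a * sk for a, sk in zip(coeffs, sks)) % (p - 1), p)
```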
Secondly, the signing process uses the following formula:

    s_i = H_0(Mix-Trans)^{a_i·sk_i}
where a_i is generated by the hash function H_1, which ties each signer's
signature to the public keys of all the other signers. If an attacker uses a
malicious public key to participate in the signing, the attack fails when the
signatures are aggregated, because the hash of the malicious public key cannot
match the hash relationship of the other signers.
Finally, the multi-signature verification process checks the following equation:

    e(Sign, g_2^{−1}) · e(H_0(Mix-Trans), apk) ?= 1_{G_T}.
During the verification process, the aggregated signature Sign and the aggre-
gated public key apk are bound together and verified via the bilinear map e. Even
if attackers construct malicious key pairs [15], they cannot generate valid
signatures without the corresponding legitimate private keys.
7 Performance Evaluation
7.1 Theoretical Evaluation
We analyze the performance of this scheme by comparison. Table 1 compares the
computational overhead of CMMS with CoinParty, the Advanced scheme, and Coin-
Layering. CMMS has significant advantages in modular exponentiation and hash
operations, and its modular multiplication overhead is also lower than that of
the other three schemes. It should be noted that CMMS involves elliptic curve
pairing operations in the verification phase, amounting to 2m pairings
since each node computes two pairings per verification. The other
schemes do not involve pairing operations, so our scheme incurs an extra
verification cost compared to them. Overall, however, our
computational overhead is still small.
size, fewer interaction rounds, and stronger robustness, demonstrating its supe-
riority in practical applications. This makes the scheme not only a significant
performance advantage but also more suitable for efficient and secure operations
in decentralized environments.
8 Conclusion
Decentralized Continuous Group Key
Agreement for UAV Ad-Hoc Network
1 Introduction
Unmanned Aerial Vehicles (UAVs), also known as drones, are aircraft that are
operated either by remote control or through embedded computer programs.
The development of UAV technology was initially driven by military needs [1],
and military drones still hold a significant position in the UAV market. Mili-
tary drones can be categorized based on their specific military applications and
operational missions, such as unmanned reconnaissance or surveillance aircraft,
combat drones [2], communication relay drones, electronic warfare drones, and
multi-purpose drones that integrate reconnaissance and strike capabilities. His-
torically, drones have been primarily used for military applications in hostile
territories, conducting long-range surveillance and armed attacks to reduce pilot
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 56–69, 2025.
https://doi.org/10.1007/978-981-96-4566-4_4
2 Related Work
2.1 Continuous Group Key Agreement
Continuous group key agreement (CGKA) allows a group of users to obtain a shared key in
an asynchronous environment [12]. For the standard TreeKEM, Alwen et al. showed in 2020
that the protocol has problems with forward secrecy because its users do not erase old keys
quickly enough [13]: previously updated keys can be fully revealed by an attacker in the
event of a state disclosure. To remedy this, they presented an updatable public-key
encryption scheme. In addition, there are many variants of TreeKEM. Klein et al. proposed
Tainted TreeKEM (TTKEM) [14], which efficiently manages the security of group keys by
introducing a "tainting" mechanism that remains effective even under dynamic changes in
group membership. Furthermore, Gordon et al. proposed a taxonomy for secure group
messaging for asynchronous collaborative applications supporting end-to-end encryption
[15, 16], and designed the Causal TreeKEM protocol, a new protocol for managing shared
encryption keys in dynamic groups without a centralized server.
To summarize, continuous group key agreement has so far been used mainly for message
passing and has not yet been applied to wireless sensor networks. The efficiency and
security of CGKA are crucial for wireless sensor networks because the devices are usually
limited in resources and large in number. The CGKA protocol ensures that communication
remains secure and data stays private even when devices join and leave the group
frequently; such membership changes can impose a significant burden on
resource-constrained devices like drones [20].
As a result, existing research solutions usually require large amounts of computational
and communication resources or face the danger of a single point of failure, and it
remains an open question how UAV self-organizing networks can self-heal after a key
leakage.
3 Preliminaries
3.1 Left Balanced Binary Tree
A left balanced binary tree is usually referred to as an AVL tree (Adelson-Velsky
and Landis tree) [21], which is a self-balancing binary search tree. In an AVL
tree, the heights of the left and right subtrees of each node differ by at most 1,
i.e., the balancing factor (the height of the left subtree minus the height of the
right subtree) of any node at any time must be −1, 0, or 1. This condition keeps the
height of the tree at the logarithmic level, which guarantees the efficiency of
operations. Two paths are used in the scheme of this paper:
1) The direct path dpath: the path from a leaf node to the root node; for each leaf
node v, dpath(v) = (v_0 = v, v_1, . . . , v_l = v_root); and
2) The copath copath: the sequence of siblings of the nodes along the direct path,
copath(v) = (v'_0, v'_1, . . . , v'_{l−1}), where v'_i is the sibling of v_i.
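As a plaintext illustration, the two paths can be computed over an array-stored binary tree. The heap layout (root at index 0, parent of i at (i − 1) // 2) is an assumption for the sketch; the paper does not fix a node numbering.

```python
def dpath(v):
    """Direct path: node indices from leaf v up to the root, assuming a
    heap-style array layout (root at 0, parent of i is (i - 1) // 2)."""
    path = [v]
    while path[-1] != 0:
        path.append((path[-1] - 1) // 2)
    return path

def copath(v):
    """Copath: the sibling of each node on dpath(v); the root has none."""
    return [u - 1 if u % 2 == 0 else u + 1 for u in dpath(v)[:-1]]
```

For a complete 7-node tree with leaves 3-6, dpath(3) = [3, 1, 0] and copath(3) = [4, 2].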
4 Our Proposal
4.1 System Initialization
Prior to the commissioning of the UAVs, the GCS is responsible for assigning keys to
the UAVs and constructing a shared set of parameters for the FANET. These parameters
are stored in advance in each UAV's storage system. The specific steps are summarized
below:
Step 1: The GCS first constructs an empty ratchet tree in which the root node is set to
the current group key, the inner nodes are labeled with key pairs for use in an
updatable public-key encryption scheme, and the leaf nodes are labeled similarly to the
inner nodes, but their public keys are used to represent the identities of the UAVs.
Step 2: The GCS selects a random integer s_i for each UAV_i, i ∈ (1, n), and allocates
a pair of public and private keys (pk_t, sk_t) ← UKG(s_i), t ∈ (1, 2n − 1), to each
node except the root. For the root node, the GCS places a secret value as the group
key of the group. Each leaf node is associated with a drone, and the public key on the
leaf node can be used as the identity of that drone. Figure 1 shows the specific
distribution.
Here, GK refers to the current group key, τ denotes the public keys of the ratchet
tree, and U, R and J represent the update, revocation and join of a UAV, respectively.
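A minimal sketch of this setup, with UKG modeled as hash-based key derivation (the paper's UKG is an updatable public-key encryption keygen, and the per-node seed assignment below is an illustrative assumption):

```python
import hashlib

def ukg(seed: bytes):
    """Hypothetical stand-in for UKG: derive a (pk, sk) pair from a seed.
    A real deployment would run an updatable public-key encryption keygen."""
    sk = hashlib.sha256(b"sk" + seed).digest()
    pk = hashlib.sha256(b"pk" + sk).digest()
    return pk, sk

def init_ratchet_tree(seeds, group_key: bytes):
    """GCS setup: a ratchet tree over n leaves has 2n - 1 nodes (heap
    layout, root at index 0). The root stores the group key; every other
    node gets a key pair derived from a UAV seed and its node index."""
    n = len(seeds)
    tree = [None] * (2 * n - 1)
    tree[0] = group_key                        # root: current group key
    for t in range(1, 2 * n - 1):
        per_node = seeds[(t - 1) % n] + t.to_bytes(2, "big")
        tree[t] = ukg(per_node)                # (pk_t, sk_t)
    return tree
```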
Decentralized Continuous Group Key Agreement for UAV Ad-Hoc Network 61
Initially, H_trans(0) = 0; in each subsequent round, the transcript hash is updated
from the verified hash of the previous round. Therefore, we can ensure that the
current legal UAVs have a consistent state, which requires them to have a consistent
view of the tree and to agree on the group keys, members, and history.
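One plausible realization of the round-by-round transcript hash is a SHA-256 chain; the concrete byte encoding of each round's operations is not fixed above, so the format below is an assumption.

```python
import hashlib

def htrans_update(prev_hash: bytes, round_data: bytes) -> bytes:
    """Chain the transcript hash: each round commits to the previous
    hash and the round's verified operations (encoding is illustrative)."""
    return hashlib.sha256(prev_hash + round_data).digest()

# H_trans(0) = 0 (represented here as 32 zero bytes)
h = bytes(32)
for op in [b"update:uav1", b"join:uav7", b"revoke:uav3"]:
    h = htrans_update(h, op)
# Two UAVs replaying the same verified operations reach the same hash,
# so comparing H_trans values checks for a consistent view of history.
```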
1) The block structure. The block structure is shown in Fig. 2. The block header
consists of the hash of the previous block H(Block_{n−1}), the block generator, a
timestamp, the miner's signature on the block, and the transcript hash representing
the history of the group, H_trans(n − 1). The body consists of a series of
transactions, where each transaction includes the type of operation, the initiator of
the operation, the legal UAVs, the public keys of the ratchet tree and the round hash.
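The block layout can be sketched as plain data classes; the field names and hash serialization are illustrative assumptions, not the paper's wire format.

```python
import hashlib
from dataclasses import dataclass, field

@dataclass
class Transaction:
    op_type: str          # "update" | "join" | "revoke"
    initiator: str        # public key of the initiating UAV
    legal_uavs: list      # current legal group members
    tree_pks: list        # public keys of the ratchet tree
    round_hash: str

@dataclass
class Block:
    prev_hash: str        # H(Block_{n-1})
    generator: str        # block generator (miner)
    timestamp: int
    miner_sig: str        # miner's signature on the block
    htrans: str           # transcript hash H_trans(n-1)
    body: list = field(default_factory=list)  # list of Transaction

    def header_hash(self) -> str:
        """Hash of the header fields (serialization is illustrative)."""
        data = f"{self.prev_hash}{self.generator}{self.timestamp}{self.htrans}"
        return hashlib.sha256(data.encode()).hexdigest()
```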
copath of the node; each drone first confirms whether the transaction exists. After
successful verification, it decrypts the received message step by step until the
shared key in the root node is computed. The specific steps are as follows:
Step 1: UAV_up makes an update request, which includes the operation type, the
initiator of the operation and the transcript hash. UAV_m uploads the transaction to
the block. Every group member can verify the transcript hash to ensure that its
stored ratchet tree is up-to-date and in a consistent state with the rest of the
UAVs. Figure 3 shows the ratchet tree in action:
5 Security Analysis
In this section, we conduct a detailed security analysis of the key agreement
protocol proposed in this paper.
Consistency: All group members obtain the same group key in each round,
regardless of key updates, member additions, or the revocation of malicious
drones. After each dynamic operation, a new transaction is generated by miners
and uploaded to the blockchain. This ensures that the current group key is
acknowledged by all drones, and once uploaded, the key becomes immutable
and traceable.
Forward Secrecy: Each round of group key updates is driven by a PRF. The update
process follows (sk_i, s_{i+1}) ← H(s_i). Since the PRF is forward-secure, even if an
adversary obtains sk_i or s_{i+1}, it cannot compute the previous secret s_i.
Furthermore, even if the current group key is compromised, the adversary cannot
decrypt past session data, as each round's key is independently generated by the PRF.
Post-compromise Security: The entire group key agreement is structured based on the
ratchet tree, with the root node representing the shared session key of the group. If
the current session key is compromised, any group member can initiate a new update
operation. After the update, the ratchet tree restores security, and the leaked
session key no longer affects the security of future communications.
Proof. A launches only one challenge request. Recall that the update operation
for a node of depth d involves the processing of a randomly chosen initial value
s0 to produce a sequence of values:
s_0 --h--> (sk_0, s_1) --h--> (sk_1, s_2) --h--> · · · --h--> (sk_{d−1}, s_d)
where GK = s_d is the dynamic group key. In addition, the update process utilizes the
UPKE algorithm with CPA security to encrypt each secret value s under the key
corresponding to the node on the copath:

s_1 --UE--> c_1, s_2 --UE--> c_2, · · · , s_d --UE--> c_d.
Σ_{i=0}^{d−2} 2^i · ε_cpa = (2^{d−1} − 1) ε_cpa = (n/2 − 1) ε_cpa ≤ (n − 2) ε_cpa.  (1)
Σ_{i=1}^{d−1} 2^i · ε_h = (2^d − 2) ε_h = (n − 2) ε_h ≤ n ε_h.  (2)
ε_d ≤ ε_0 + (n − 2) ε_cpa + n ε_h ≤ 1/2 + (n − 2) ε_cpa + n ε_h.
Therefore,

ε := Pr[A wins] − 1/2 = ε_d − 1/2 ≤ (n − 2) ε_cpa + n ε_h.
6 Performance Evaluation
In this part, we assess the efficiency of the UACGKA protocol. Additionally, we
contrast UACGKA with the existing methods CRA-DGK, TGKA, TAAGKA, and BGKA to
demonstrate the practicality of UACGKA.
All experiments were conducted on a server equipped with an AMD Ryzen 7 6800H with
Radeon Graphics at 3.20 GHz, 16 GB RAM, and Windows 11. We used the Pairing-Based
Cryptography (PBC) library to implement the underlying encryption algorithms, where
the asymmetric encryption uses the secp256k1 elliptic curve to achieve a security
strength of 128 bits. Considering efficiency as well as ease of implementation of
aggregated signatures, we chose the secp256k1-based Schnorr signature to perform the
signing and verification process. In addition, the standard SHA-256 is used as the
hash function.
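To make the sign/verify step concrete, here is a toy Schnorr signature over a small prime-order subgroup. The parameters p = 2039, q = 1019, g = 4 are illustrative only and offer no security; the experiments above use the secp256k1 group instead.

```python
import hashlib, secrets

# Toy Schnorr signature in a prime-order subgroup of Z_p^*.
p, q, g = 2039, 1019, 4      # g = 2^2 has order q in Z_p^* (p = 2q + 1)

def keygen():
    x = secrets.randbelow(q - 1) + 1   # private key
    return x, pow(g, x, p)             # (x, y = g^x mod p)

def H(r: int, m: bytes) -> int:
    """Challenge hash e = SHA-256(r || m) reduced mod q."""
    return int.from_bytes(hashlib.sha256(str(r).encode() + m).digest(), "big") % q

def sign(x: int, m: bytes):
    k = secrets.randbelow(q - 1) + 1   # fresh nonce per signature
    r = pow(g, k, p)
    e = H(r, m)
    s = (k - x * e) % q
    return e, s

def verify(y: int, m: bytes, sig) -> bool:
    e, s = sig
    r = (pow(g, s, p) * pow(y, e, p)) % p   # g^s * y^e = g^k = r
    return H(r, m) == e
```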
7 Conclusion
The UACGKA protocol proposed in this paper offers a secure and efficient
key management solution for self-organizing UAV networks. By integrating
blockchain technology, we eliminate the risk of single points of failure and ensure
decentralized key management. The introduction of ratchet trees reduces com-
munication overhead, while the combination of pseudo-random generators and
updatable public-key encryption (UPKE) schemes enhances the forward security
of communications. Security proofs and performance evaluations further validate
the applicability and efficiency of UACGKA in dynamic environments. Future
work will focus on optimizing and practically deploying UACGKA across various
UAV application scenarios.
References
1. Wang, C., Zhou, T., Shen, J., Wang, W., Zhou, X.: Searchable and secure edge
pre-cache scheme for intelligent 6G wireless systems. Futur. Gener. Comput. Syst.
140, 129–137 (2023)
2. Tian, C., Jiang, Q., Li, T., Zhang, J., Xi, N., Ma, J.: Reliable PUF-based mutual
authentication protocol for UAVS towards multi-domain environment. Comput.
Netw. 218, 1389–1286 (2022)
3. Tan, Y., Liu, J., Kato, N.: Blockchain-based lightweight authentication for resilient
UAV communications: architecture, scheme, and future directions. IEEE Wirel.
Commun. 29(3), 24–31 (2022)
4. Zhang, C., et al.: UAV swarm-enabled collaborative secure relay communications
with time-domain colluding eavesdropper. IEEE Trans. Mob. Comput. 23(9),
8601–8619 (2024)
5. Panjwani, S.: Tackling adaptive corruptions in multicast encryption protocols. In:
Proceedings of the 4th Conference on Theory of Cryptography, pp. 21–40 (2007)
6. Semal, B., Markantonakis, K., Akram, R.N.: A certificateless group authenticated
key agreement protocol for secure communication in untrusted UAV networks. In:
Proceedings of the 37th Digital Avionics Systems Conference (DASC), pp. 1–8
(2018)
7. Alwen, J., Coretti, S., Jost, D., Mularczyk, M.: Continuous group key agreement
with active security. IACR Cryptol. ePrint Arch. (2020)
8. Wang, C., Shen, J., Vijayakumar, P., Gupta, B.B.: Attribute-based secure data
aggregation for isolated IoT-enabled maritime transportation systems. IEEE Trans.
Intell. Transp. Syst. 24(2), 2608–2617 (2023)
9. Tan, Y., Wang, J., Liu, J., Kato, N.: Blockchain-assisted distributed and
lightweight authentication service for industrial unmanned aerial vehicles. IEEE
Internet Things J. 9(18), 16928–16940 (2022)
10. Xu, Z., Liang, W., Li, K.C., Xu, J., Zomaya, A.Y., Zhang, J.: A time-sensitive
token-based anonymous authentication and dynamic group key agreement scheme
for industry 5.0. IEEE Trans. Industr. Inf. 18(10), 7118–7127 (2021)
11. Wang, L., Tian, Y., Zhang, D., Yanhua, L.: Constant-round authenticated and
dynamic group key agreement protocol for D2D group communications. Inf. Sci.
503, 61–71 (2019)
12. Kajita, K., Emura, K., Ogawa, K., Nojima, R., Ohtake, G.: Continuous group key
agreement with flexible authorization and its applications. In: Proceedings of the
9th ACM International Workshop on Security and Privacy Analytics, pp. 3–13
(2023)
13. Alwen, J., et al.: Cocoa: concurrent continuous group key agreement. In: Proceed-
ings of the 41st Annual International Conference on the Theory and Applications
of Cryptographic Techniques, pp. 815–844 (2022)
14. Klein, K., et al.: Keep the dirt: tainted treekem, adaptively and actively secure
continuous group key agreement. In: Proceedings of the 2021 IEEE Symposium on
Security and Privacy (SP), pp. 268–284 (2021)
15. Zhang, L., Zhu, T., Zhang, H., Xiong, P., Zhou, W.: Fedrecovery: differentially
private machine unlearning for federated learning frameworks. IEEE Trans. Inf.
Forensics Secur. 18, 4732–4746 (2023)
16. Chen, H., Zhu, T., Liu, C., Shui, Yu., Zhou, W.: High-frequency matters: attack
and defense for image-processing model watermarking. IEEE Trans. Serv. Comput.
17(4), 1565–1579 (2024)
17. Zhang, C., et al.: UAV swarm-enabled collaborative secure relay communications
with time-domain colluding eavesdropper. IEEE Trans. Veh. Technol. 23(9), 1536–
1233 (2024)
18. Tan, Y., Liu, J., Kato, N.: Blockchain-based key management for heterogeneous
flying ad hoc network. IEEE Trans. Industr. Inf. 17(11), 7629–7638 (2020)
19. Zhang, Z., et al.: TAGKA: threshold authenticated group key agreement protocol
against member disconnect for UANET. IEEE Trans. Veh. Technol. 72(11), 14987–
15001 (2023)
20. Semal, B., Markantonakis, K., Akram, R.N.: A certificateless group authenticated
key agreement protocol for secure communication in untrusted UAV networks. In:
Proceedings of the 37th Digital Avionics Systems Conference (DASC), pp. 1–8
(2018)
21. Feng, C., Liu, B., Guo, Z., Yu, K., Qin, Z., Choo, K.: Blockchain-based cross-
domain authentication for intelligent 5G-enabled internet of drones. IEEE Internet
Things J. 9(8), 6224–6238 (2022)
22. Cohn-Gordon, K., Cremers, C., Garratt, L., Millican, J., Milner, K.: On ends-to-
ends encryption: asynchronous group messaging with strong security guarantees.
In: Proceedings of the 2018 ACM SIGSAC Conference on Computer and Commu-
nications Security, pp. 1802–1819 (2018)
Efficient Homomorphic Approximation
of Max Pooling for Privacy-Preserving
Deep Learning
1 Introduction
Deep learning, known for its ability to automatically extract features from large
datasets and achieve state-of-the-art performance, has become a transformative
force in numerous fields, including image recognition [14] and medical prediction
[15]. However, the training and inference processes in deep learning often depend
on sensitive user data, which raises significant privacy concerns. To address these
challenges, researchers have developed a range of techniques aimed at balancing
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 70–80, 2025.
https://doi.org/10.1007/978-981-96-4566-4_5
data privacy with deep learning computations. Among these approaches are
federated learning [21], differential privacy [1], secure multi-party computation
[18], and fully homomorphic encryption [12].
Deep learning models often rely on the cloud’s extensive computational
resources to perform inference on user data. However, to avoid exposing sen-
sitive data to untrusted cloud environments, Privacy-Preserving Deep Learning
(PPDL) utilizing Fully Homomorphic Encryption (FHE) offers a more secure
and effective solution for such scenarios [19]. FHE enables computations to be
performed on ciphertext, with the decrypted output yielding identical results to
those obtained from computations on plaintext [7]. FHE schemes, such as BGV
[3] and CKKS [6], support both homomorphic addition and multiplication, mak-
ing them suitable for executing a variety of linear transformations commonly
used in deep learning algorithms. However, FHE faces challenges in efficiently
handling nonlinear computations [17], which are essential to deep learning frame-
works, particularly in implementing activation functions and max pooling layers.
To address this issue, polynomials are utilized to approximate nonlinear functions
homomorphically. CryptoNets [8] replaces ReLU with the square activation
function. To reduce the approximation error, Chabanne et al. [4] and CryptoDL
[9] utilize high-order polynomials for approximating activation functions, which
inevitably leads to increased time cost. To strike a balance between approxi-
mation error and time efficiency, Wu et al. [22] propose trainable second-order
polynomials for this purpose. However, this additional training process is time-
consuming. Recently, Lee et al. [13] proposed to use a combination of low-order
minimax polynomials to efficiently and accurately approximate the activation
function.
As illustrated, previous research on homomorphic polynomial approximations
of nonlinear functions has primarily focused on activation functions, with less
attention given to the max pooling operation. Lee et al. [13] leverage homomorphic
Sign to develop an approximation algorithm, HMaxPool, for the homomorphic max pooling
function. Building on their work, we focus on developing a more efficient and
high-precision approximation tailored specifically to the max pooling function. Our
main contributions are as follows:
2 Related Works
Deep learning on encrypted data is suitable for scenarios where users need to
delegate their data to untrusted parties for computation. It enables inference on
the user’s data while preserving their privacy. CryptoNets proposed by Gilad et
al. [8] is a pioneering work that applies deep learning models to encrypted data
for the first time. CryptoNets uses the fully homomorphic encryption scheme
of Bos et al. [2] and approximates ReLU with x2 , enabling neural networks to
evaluate ciphertext, thereby ensuring data privacy. This method not only proves
the feasibility of applying deep learning algorithms in ciphertext environments,
but also lays the foundation for subsequent research.
FHE is particularly good at processing linear calculations, such as addition
and multiplication, which makes it perform well when performing linear trans-
formations. However, nonlinear calculations, especially activation functions and
pooling operations common in deep learning, pose greater challenges to FHE.
CryptoNets [8] uses a square activation function instead of ReLU and scaled average
pooling instead of max pooling. Building on the foundation of CryptoNets,
Hesamifard et al. [9] employed a modified Chebyshev polynomial to approximate
the derivative of ReLU. Wu et al. [22] proposed a trainable second-order poly-
nomial approximation of the ReLU activation function, effectively reducing pre-
cision loss. Recently, Lee et al. [13] proposed to use a combination of low-order
minimax polynomials to efficiently and accurately approximate the activation
function. Phoenix [11] uses an approximate symbolic function called SgnHE
to homomorphically implement Argmax to replace softmax, and proposes a
homomorphic Argmax approximation algorithm called ArgmaxHE. Based on
ArgmaxHE, Zhang et al. [23] optimized the homomorphic Argmax approxima-
tion and proposed an improved homomorphic Argmax approximation algorithm,
which significantly reduces the inference delay.
However, most of these studies focused on the polynomial approximation of
the activation functions and Argmax, and paid less attention to the max pooling
function. Lee et al. [13] leverage homomorphic Sign to develop an approximation
algorithm, HMaxPool, for the homomorphic max pooling function. Although their method
solves the basic problem of homomorphic max pooling, it involves a large number of
polynomial operations and has low computational efficiency. Therefore, based on
HMaxPool, we aim to propose a more efficient homomorphic approximation of max pooling.
max^α_app(a, b) = ((a + b) + (a − b) · p^α(a − b)) / 2  (1)
To enhance computational efficiency in the homomorphic max pooling algorithm, we note
that each comparison invokes the max^α_app(a, b) function. Reducing the number of
scalar multiplications can lead to improved overall performance. Therefore, we
redefine the approximation of the maximum function as:
max^α_app(a, b) = (a + b) + (a − b) · p^α(a − b)  (2)
It is obvious that the output of Eq. 2 amplifies the correct result by a factor of
two. In order to obtain the correct result of max pooling, the proposed homomorphic
max pooling function HMaxPool+(x_1, . . . , x_n) is defined as:

HMaxPool+(x_1, . . . , x_n) = HMaxExt(x_1, . . . , x_n, n) / 2^⌈log n⌉  (3)
where n represents the number of elements involved in the comparison, and
x1 , . . . , xn are the ciphertexts corresponding to μ1 , . . . , μn .
The challenge is that when an odd number of items are involved in the comparison,
pairwise comparisons will amplify the output, while the remaining item, which does
not participate in the comparison, will remain unchanged. To address this potential
issue, we design the sub-algorithm HMaxExt(x_1, . . . , x_n, n) in Algorithm 1, which
can be divided into four steps:
Assuming the number of input ciphertexts is n = 5 and the initial input data is
X = [x_1, x_2, x_3, x_4, x_5], this algorithm is structured into three rounds.
In the first round, we invoke HMaxExt(x_1, x_2, x_3, x_4, x_5, 5). After the loop
execution, we obtain X = [max_{1,2} = max^α_app(x_1, x_2), max_{3,4} =
max^α_app(x_3, x_4), x_5], with odd = 1 and group = 3. Based on the values of group
and odd, we perform power expansion, update X = [max_{1,2}, max_{3,4}, x_5 · 2], and
call HMaxExt again with the new X.
HMaxPool+(x_1, . . . , x_5) = X[0] / 2^⌈log 5⌉.  (4)
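The rounds above can be simulated on plaintext. The sketch below substitutes the exact sign function for the minimax polynomial p^α, so results are exact; under a real low-degree polynomial approximation, small errors would appear.

```python
import math

def max_app(a, b, p=lambda t: (t > 0) - (t < 0)):
    """Scaled max approximation (Eq. 2): returns 2 * max(a, b) when p is
    the exact sign function. Homomorphically, p would be a low-degree
    polynomial approximation of Sign."""
    return (a + b) + (a - b) * p(a - b)

def hmax_ext(xs):
    """One round of pairwise comparisons; an unpaired trailing element is
    doubled so every survivor carries the same scale factor."""
    out = [max_app(xs[i], xs[i + 1]) for i in range(0, len(xs) - 1, 2)]
    if len(xs) % 2 == 1:
        out.append(xs[-1] * 2)
    return out

def hmax_pool_plus(xs):
    """Plaintext simulation of HMaxPool+ (Eq. 3): repeat rounds, then
    divide by 2^ceil(log2 n) to undo the amplification."""
    n = len(xs)
    while len(xs) > 1:
        xs = hmax_ext(xs)
    return xs[0] / (2 ** math.ceil(math.log2(n)))
```

For X = [3, 1, 4, 1, 5] the rounds give [6, 8, 10], [16, 20], [40], and 40 / 2^3 = 5, the true maximum.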
pooling function. The experiment was conducted 1,000 times to calculate the
average computational time and error.
The computational time for HMaxPool and HMaxPool+ is presented in Table 2. It is
evident that the HMaxPool+ algorithm enhances computational efficiency to some
extent. This improvement arises from a significant reduction in the number of scalar
multiplications. Specifically, with a pooling kernel size of 4 × 4, HMaxPool requires
15 homomorphic scalar multiplications, while the newly proposed algorithm, HMaxPool+,
requires only 1 homomorphic scalar multiplication, as shown in Table 1.
were grouped into 5 batches, with each batch containing 128 predictions. For each
batch, the success probability of the predictions was analyzed. The overall accuracy
was determined by averaging the success probabilities of the 5 batches. As two
different homomorphic max pooling algorithms, HMaxPool and HMaxPool+, were utilized,
the two experimental schemes are as follows.
1. PPNN+HMaxPool: the homomorphic SqueezeNet network model with the max pooling
algorithm HMaxPool proposed by Lee et al. [13].
2. PPNN+HMaxPool+: the homomorphic SqueezeNet network model with our proposed max
pooling algorithm HMaxPool+.
6 Conclusion
Aiming to reduce the number of homomorphic scalar multiplications required by the
homomorphic max pooling algorithm, we first redefine the approximation of the maximum
function and then design the HMaxExt algorithm to deal with the potential output
expansion issue, yielding an efficient homomorphic max pooling algorithm, HMaxPool+.
Theoretical analysis and experimental results both indicate that, compared to
HMaxPool, our algorithm HMaxPool+ significantly reduces the number of homomorphic
scalar multiplications across different pooling kernel sizes, thereby reducing
computational time while preserving the inference accuracy of the homomorphic
SqueezeNet neural network.
Acknowledgments. This work was supported by the Science and Technology Pro-
gram Project of Shenzhen under Grant SZWD2021012.
References
1. Blanco-Justicia, A., Sánchez, D., Domingo-Ferrer, J., Muralidhar, K.: A critical
review on the use (and misuse) of differential privacy in machine learning. ACM
Comput. Surv. 55(8), 1–16 (2022)
2. Bos, J.W., Lauter, K., Loftus, J., Naehrig, M.: Improved security for a ring-based
fully homomorphic encryption scheme. In: Cryptography and Coding: 14th IMA
International Conference, IMACC 2013, Oxford, UK, 17–19 December 2013. Pro-
ceedings 14, pp. 45–64. Springer (2013)
3. Brakerski, Z., Gentry, C., Vaikuntanathan, V.: (leveled) fully homomorphic encryp-
tion without bootstrapping. ACM Trans. Comput. Theory (TOCT) 6(3), 1–36
(2014)
4. Chabanne, H., De Wargny, A., Milgram, J., Morel, C., Prouff, E.: Privacy-
preserving classification on deep neural network. Cryptology ePrint Archive (2017)
5. Cheon, J.H., Han, K., Kim, A., Kim, M., Song, Y.: A full RNS variant of approx-
imate homomorphic encryption. In: Selected Areas in Cryptography–SAC 2018:
25th International Conference, Calgary, AB, Canada, 15–17 August 2018, Revised
Selected Papers 25, pp. 347–368. Springer (2019)
6. Cheon, J.H., Kim, A., Kim, M., Song, Y.: Homomorphic encryption for arithmetic
of approximate numbers. In: Advances in Cryptology–ASIACRYPT 2017: 23rd
International Conference on the Theory and Applications of Cryptology and Infor-
mation Security, Hong Kong, China, 3–7 December 2017, Proceedings, Part I 23,
pp. 409–437. Springer (2017)
7. Gentry, C.: A fully homomorphic encryption scheme. Stanford university (2009)
8. Gilad-Bachrach, R., Dowlin, N., Laine, K., Lauter, K., Naehrig, M., Wernsing, J.:
Cryptonets: applying neural networks to encrypted data with high throughput and
accuracy. In: International Conference on Machine Learning, pp. 201–210. PMLR
(2016)
9. Hesamifard, E., Takabi, H., Ghasemi, M.: Cryptodl: deep neural networks over
encrypted data. arXiv preprint arXiv:1711.05189 (2017)
10. Iandola, F.N.: Squeezenet: Alexnet-level accuracy with 50x fewer parameters and
< 0.5 mb model size. arXiv preprint arXiv:1602.07360 (2016)
11. Jovanovic, N., Fischer, M., Steffen, S., Vechev, M.: Private and reliable neural net-
work inference. In: Proceedings of the 2022 ACM SIGSAC Conference on Computer
and Communications Security, pp. 1663–1677 (2022)
12. Lee, J.W., et al.: Privacy-preserving machine learning with fully homomorphic
encryption for deep neural network. IEEE Access 10, 30039–30054 (2022)
13. Lee, J., Lee, E., Lee, J.W., Kim, Y., Kim, Y.S., No, J.S.: Precise approximation of
convolutional neural networks for homomorphically encrypted data. IEEE Access
11, 62062–62076 (2023)
14. Li, Y.: Research and application of deep learning in image recognition. In: 2022
IEEE 2nd International Conference on Power, Electronics and Computer Applica-
tions (ICPECA), pp. 994–999. IEEE (2022)
15. Liu, T., Siegel, E., Shen, D.: Deep learning and medical image analysis for covid-19
diagnosis and prediction. Annu. Rev. Biomed. Eng. 24(1), 179–201 (2022)
16. Lou, Q., Jiang, L.: Hemet: a homomorphic-encryption-friendly privacy-preserving
mobile neural network architecture. In: International Conference on Machine
Learning, pp. 7102–7110. PMLR (2021)
17. Marcolla, C., Sucasas, V., Manzano, M., Bassoli, R., Fitzek, F.H., Aaraj, N.: Survey
on fully homomorphic encryption, theory, and applications. Proc. IEEE 110(10),
1572–1609 (2022)
18. Pillai, S.E.V.S., Polimetla, K.: Enhancing network privacy through secure multi-
party computation in cloud environments. In: 2024 International Conference on
Integrated Circuits and Communication Systems (ICICACS), pp. 1–6. IEEE (2024)
19. Ranbaduge, T., Vatsalan, D., Ding, M.: Privacy-preserving deep learning based
record linkage. IEEE Trans. Knowl. Data Eng. (2023)
20. Microsoft SEAL (release 4.1). Microsoft Research, Redmond, WA (2023). https://
github.com/Microsoft/SEAL
21. Wen, J., Zhang, Z., Lan, Y., Cui, Z., Cai, J., Zhang, W.: A survey on federated
learning: challenges and applications. Int. J. Mach. Learn. Cybern. 14(2), 513–535
(2023)
22. Wu, W., Liu, J., Wang, H., Tang, F., Xian, M.: Ppolynets: achieving high prediction
accuracy and efficiency with parametric polynomial activations. IEEE Access 6,
72814–72823 (2018)
23. Zhang, P., Duan, A., Lu, H.: An efficient homomorphic argmax approximation for
privacy-preserving neural networks. Cryptography 8(2), 18 (2024)
Blockchain-Aided Revocable Threshold Group
Signature Scheme for the Smart Grid
Abstract. In the smart grid, electricity consumption data, including electric
vehicle charging payments, residential electricity payments and so on, involves the
privacy of each individual. The leakage of private consumption data is a serious
issue in the grid system. Therefore, aiming at secure and efficient data collection,
this paper presents a threshold group signature scheme with a revocation mechanism
suitable for the smart grid. Combining blockchain technology and group signatures, it
addresses the problems of semi-trust among the authorities and the anonymity of
entities. In addition, the revocation mechanism enables the tracking and punishment
of malicious users. Finally, security analysis and simulations demonstrate its
properties and advantages in terms of security features and computational and
communication costs.
1 Introduction
The smart grid utilizes digital technology to improve the reliability, efficiency, and
sustainability of electricity services. It integrates various technologies such as smart
meters, sensors, and automated control systems to enhance the management of electricity
supply and demand [1–3].
The smart grid integrates renewable energy sources such as solar and wind power. As
these sources can be variable and decentralized, the smart grid's real-time data
analysis helps balance supply and demand, ensuring a stable energy supply. Additionally,
it enhances grid resilience by quickly identifying and responding to outages or
disruptions. This not only reduces downtime but also improves overall customer
satisfaction.
One of the key features of the smart grid is its ability to enable two-way communi-
cation between utilities and consumers. This allows for real-time monitoring of energy
usage, empowering consumers to make informed decisions about their energy consump-
tion [4, 5]. Smart meters and charging stations, for instance, provide detailed usage data,
which can lead to cost savings and better energy management.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 81–89, 2025.
https://doi.org/10.1007/978-981-96-4566-4_6
82 X. Deng et al.
With the continuous development of modern industry, the security of smart grids is
directly related to national economic and security interests. As smart grids rely heavily
on advanced communication technologies and interconnected devices, the potential
for cyber threats increases. The smart grid still faces several intractable problems in
secure data transmission. First, smart meters, charging stations and other electrical
equipment are installed in unattended public places, making them susceptible to
hacker attacks, such as tampering with setup parameters, stealing consumers' private
data, and so on. Second, all transactions are stored on traditional servers,
making them a primary target for attacks and creating a single-point-of-failure issue in
the smart grid. Finally, malicious consumers are difficult for the smart grid system to
detect, and it is a challenge to hold them accountable and impose appropriate penalties
subsequently.
To address the challenges above, this paper proposes a threshold group signature
scheme with a revocation mechanism suitable for the smart grid, aiming to achieve
secure and efficient data management. The main contributions are as follows.
1) This work adopts a threshold group signature scheme to realize real-time communication
and record consumers' behaviors. The approach enhances the security
of consumer management and supports batch message authentication, thereby
improving real-time data interactions in the smart grid. Moreover, pseudonyms are
used to preserve the privacy of consumers, reducing the risk of privacy breaches.
Pseudonymized consumer data can still be analyzed and processed without exposing
true identities.
2) Blockchain technology is employed to eliminate the issue of partial trust between
the authorities and consumers. Additionally, it adopts a distributed storage model,
which avoids a single point of failure.
3) A blacklist mechanism is applied to revoke the identities of expired users and of
malicious users with disruptive behavior. Real-time monitoring is conducted to detect
malicious consumers, ensuring the long-term secure and stable operation of the
smart grid system.
2 Related Work
Smart grid is one of the typical applications of the Internet of Things. Su et al. [6] proposed a
distributed attribute-based signature scheme for distributed electricity trading, in which
consumers can choose their trading partners freely without leaking their real identities. Taking
advantage of the immutability property of blockchain, it achieved signature verifiability
independent of dynamic changes in attributes. Zhu et al. [7] designed a lightweight ver-
ifiable certificate-based privacy-preserving data aggregation scheme without pairings
for smart grids. The protocol could collect the real-time energy consumption data to
provide services securely, while ensuring the privacy of individual users. Combining the
hash-based message authentication code and public key encryption with equality testing,
Zhang and Wei [8] presented a fine-grained privacy-preserving data aggregation scheme
supporting multifunctionality, such as data privacy, integrity, unlinkability, and fault tol-
erance. Tomar et al. [9] integrated fog computing and blockchain into the smart grid,
and provided a blockchain-based certificateless aggregate signature scheme that enjoyed
the authentication, integrity, and privacy. Verma et al. [10] designed the first certificate-
based data aggregation scheme for smart grid for the purpose of eliminating complex
certificate maintenance and key escrow issue. Additionally, this scheme can resist the
attacks launched by a malicious certification authority. Shen et al. [11] employed con-
sortium blockchains and identity-based signatures to construct a blockchain-assisted
secure device authentication for the industrial Internet of Things (IIoT). In their pro-
tocol, the blockchain functions as an information-sharing platform, establishing trust
between different domains. Xue et al. [12] proposed a secure and efficient cross-domain
authentication scheme based on two cooperative blockchains for medical consortium
systems. Moreover, the chameleon hash was used to redact the state of the user, the
anonymity mechanism was utilized to enhance security, and to trace malicious users.
Power Plant (PP): The power plant is responsible for generating the power resource,
and then transmitting these resources to each power consumer for charging the vehicles,
lighting the houses, and so on.
Power Consumers (PC): These consumers have access to the electric resources in smart
grid. After being authenticated, they can pay the bill according to the used electricity.
Each consumer is equipped with an electricity meter to record the usage amount of the
electricity.
Payment Verifier (PV): The verifier collects the usage amount and payment voucher
of the electricity. After that, it verifies the legality of these data and uploads the transactions
to the blockchain. For illegal payments, the payment verifier notifies the grid
administrator to take measures against the illegal consumers.
Grid Administrator (GA): The grid administrator is responsible for managing power
consumers and maintaining the blockchain in smart grid, which generates the group
public key and holds the master private key. Moreover, it handles the identity registration
and revocation requests. In this context, the grid administrator is a fully trusted
entity.
Blockchain: The blockchain primarily maintains the identity registration information of
power consumers, which provides the immutability and addresses the trust issues among
the individuals. By updating the registration information of the consumers periodically,
the blockchain ensures the validity and legitimacy of non-revoked users.
4.1 Setup
GA selects a cyclic group G with a generator P of prime order q. Then, it randomly
chooses a ∈ Zq∗ to calculate the public key PKGA = aP, and sets the private key
MSK = a. Define three secure hash functions.
with its coefficients aij ∈ Zq*, j = 0, 1, ..., t − 1. Thus, the secret value of each
member is SVidi = fi(0) = ai0, and the corresponding public value is computed as
PVidi = ai0·P. Finally, each group member idi generates a temporary pseudonym
PIDi = H0(idi, ai0), which preserves the user's identity privacy and is stored on the blockchain.
4.4 Verification
PV receives the signatures {σi}i∈[1,t] and needs to verify their correctness.
PV computes

σ = Σ_{i=1}^{t} σi ,  V = Σ_{i=1}^{t} H1(mi)·PVidi  (5)

Y = Σ_{i=1}^{t} PIDi·PKidi .  (6)

When the number of messages is greater than or equal to t, the validity of the signatures
is checked by verifying

σ·P = Y + V.  (7)
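To illustrate how Eqs. (5)–(7) fit together, the following toy sketch replaces the elliptic-curve group with integers modulo a prime (taking P = 1, so scalar multiplication becomes multiplication mod q) and assumes the per-member signing rule σi = PIDi·κi + H1(mi)·ai0, which is inferred from Eq. (7) rather than stated in this excerpt. It is a demonstration of the algebra, not a secure implementation.

```python
import hashlib
import secrets

q = (1 << 61) - 1  # toy prime modulus; a real deployment uses an elliptic-curve group

def h1(m: str) -> int:
    return int.from_bytes(hashlib.sha256(m.encode()).digest(), "big") % q

# Toy additive group with P = 1: PK_i = kappa_i*P and PV_i = a_i0*P collapse to
# the scalars themselves, so the verifier's products PID_i*PK_i are pid*kappa.
t = 3
kappa = [secrets.randbelow(q - 1) + 1 for _ in range(t)]   # private keys SKid_i
a0    = [secrets.randbelow(q - 1) + 1 for _ in range(t)]   # secret values a_i0
pid   = [secrets.randbelow(q - 1) + 1 for _ in range(t)]   # pseudonyms as scalars
msgs  = ["bill-17.3kWh", "bill-4.1kWh", "bill-9.9kWh"]     # illustrative payments

# Inferred signing rule: sigma_i = PID_i*kappa_i + H1(m_i)*a_i0 (mod q).
sigmas = [(pid[i] * kappa[i] + h1(msgs[i]) * a0[i]) % q for i in range(t)]

# Verifier aggregates per Eqs. (5)-(6) and checks Eq. (7): sigma*P = Y + V.
sigma = sum(sigmas) % q
Y = sum(pid[i] * kappa[i] for i in range(t)) % q
V = sum(h1(msgs[i]) * a0[i] for i in range(t)) % q
assert sigma == (Y + V) % q          # batch verification of all t signatures passes

# A tampered message breaks the aggregate check.
V_bad = (V - h1(msgs[0]) * a0[0] + h1("bill-0.0kWh") * a0[0]) % q
assert sigma != (Y + V_bad) % q
```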
4.5 Revocation
Provided that the user with pseudonym PID*i is a malicious participant, GA can directly
reveal the member's true identity id*i by using the corresponding secret value a*i0, and
add the tuple (id*i, σ*i) to the blacklist BL = {(id*i, σ*i)}.
5 Security Analysis
5.1 Immutability
Blockchain is built on distributed networks, where each node holds a copy of the
chain. Adding a new block to the chain requires consensus from the majority of nodes
in the network. Once the data is encapsulated in a block, any attempt to tamper with it
would be detected, making tampering virtually impossible.
5.2 Unlinkability
Suppose that an adversary can query the pseudonym of the target member idi . However,
the generation of these pseudonyms PIDi = H0 (idi , ai0 ) involves a secret value ai0
generated by the member. Therefore, the attacker cannot associate PIDi with the
user's real identity, which effectively prevents any direct link between the temporary
identities used in the system and the actual identities of members, thereby maintaining
the property of unlinkability.
5.3 Unforgeability
This scheme applies the group signature technology to finish the verification on the
electricity payment messages efficiently. Because the private key SKidi = κi and
the secret value ai0 are kept secret, an attacker cannot forge a valid signature {σi}
during signature generation. Hence, the scheme eliminates the possibility of
payment data forgery and guarantees the authenticity and integrity of the payment data.
5.5 Revocability
Provided that there is a malicious member idi forging the electricity payment messages
M = {m1, m2, ..., mt}, its behaviors are recorded and uploaded to the blacklist BL =
{(idi, σi)}, and its corresponding access rights are revoked. Once idi is
blacklisted, it can no longer access any data or enjoy the electricity services in the
smart grid system.
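The revocation step amounts to maintaining BL and refusing service to blacklisted identities; a minimal sketch, with all identifiers illustrative rather than taken from the paper:

```python
# Once GA links a forged signature back to id_i, the tuple (id_i, sigma_i)
# enters BL and every later service request from that identity is refused.
blacklist: set[tuple[str, str]] = set()
revoked_ids: set[str] = set()

def revoke(user_id: str, sigma: str) -> None:
    """GA records the offending tuple and marks the identity as revoked."""
    blacklist.add((user_id, sigma))
    revoked_ids.add(user_id)

def grant_service(user_id: str) -> bool:
    """Service is granted only to identities absent from the blacklist."""
    return user_id not in revoked_ids

revoke("id_42", "sigma_42")
assert not grant_service("id_42")   # revoked user loses access
assert grant_service("id_7")        # other users are unaffected
```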
6 Performance Evaluations
6.1 Computational Overhead
To evaluate the computational performance of this protocol, we compare it with
the related schemes [11] and [12] in terms of the key generation and signature stages.
The symbols of the time cost for various cryptographic operations in these schemes are
described in detail as follows.
Tsm : Execution time for scalar multiplication in G.
Th : Execution time for hash operations.
Tbp : Execution time for bilinear pairing operations.
Texp : Execution time for exponentiation operations.
bytes to the server for signing and receives a 96-byte signature, so the total length
is 384n bytes, where n is the number of entities to be authenticated. In Ref. [12],
the patient sends a message request {Sig(SKij), Kij, IDk, Xi, Xi} to the server, and
the server responds with {Ak, βk, δk, T1}, resulting in a total communication overhead
of 368k bytes, where k is the number of participants.
7 Conclusion
In this paper, we designed a revocable threshold group signature scheme for the
blockchain-aided smart grid, addressing the shortcomings of existing electricity payment
schemes in terms of privacy protection, semi-trust among the authorities, and malicious
entity tracking. Additionally, this protocol possesses the immutability and unforgeability
properties, resists various attacks, and significantly enhances security and reliability.
Finally, it has been shown to be more effective in a simulation environment.
Acknowledgments. This work is supported by the China Southern Power Grid Technology Project
under Grant No. 030000KC23040090 (GDKJXM20230408).
References
1. Wang, X., Blaabjerg, F.: Harmonic stability in power electronic-based power systems: concept,
modeling, and analysis. IEEE Trans. Smart Grid 10(3), 2858–2870 (2019). https://doi.org/
10.1109/TSG.2018.2812712
2. Su, H., Liao, P., Zhou, F., Yi, S., Yang, Q.: Research on intelligent defect identification model
of power grid edge side equipment based on edge computing. Electric Power Inf. Commun.
Technol. 19(04), 31–37 (2021). (in Chinese), https://doi.org/10.16543/j.2095-641x.electric.
power.ict.2021.04.005
3. Zhu, L., You, S., Yin, H., Zhao, Y., Li, F., Yao, W.: FNET/GridEye: a tool for situational aware-
ness of large power interconnection grids. In: IEEE PES Innovative Smart Grid Technologies
Europe, pp. 379–383. IEEE, The Hague, Netherlands (2020). https://doi.org/10.1109/ISGT-
Europe47291.2020.9248857
4. Al-Gburi, A., Al-Hasnawi, A., Lilien, L.: Differentiating security from privacy in Internet of
Things: A survey of selected threats and controls. In: Daimi, K. (ed.) Computer and Network
Security Essentials, pp. 153–172. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-
58424-9_9
5. Chen, Y., Wang, W., Sun, L.: Smart grid terminal access authentication scheme based on
multiple encryption algorithms. Ind. Technol. Innov. 9(4), 93–101 (2022). (in Chinese), https://
doi.org/10.14103/j.issn.2095-8412.2022.08.012
6. Su, Q., Zhang, R., Xue, R., Sun, Y., Gao, S.: Distributed attribute-based signature with attribute
dynamic update for smart grid. IEEE Trans. Industr. Inf. 19(9), 9424–9435 (2023). https://
doi.org/10.1109/TII.2022.3228688
7. Zhu, F., Guo, D., Abuadbba, S., Yi, X., Luo, J., Kumari, S.: Lightweight verifiable privacy-
preserving data aggregation for smart grids. IEEE Internet Things J. 11(19), 31249–31259
(2024). https://doi.org/10.1109/JIOT.2024.3419161
8. Zhang, J., Wei, J.: PFDAM: Privacy-Preserving fine-grained data aggregation scheme sup-
porting multifunctionality in smart grid. IEEE Internet Things J. 11(15), 25520–25533 (2024).
https://doi.org/10.1109/JIOT.2024.3356593
9. Tomar, A., Tripathi, S., Arivarasan, K.: A blockchain-based certificateless aggregate signature
scheme for fog-enabled smart grid environment. IEEE Trans. Green Commun. Netw. 7(4),
1892–1905 (2023). https://doi.org/10.1109/TGCN.2023.3265608
10. Verma, G.K., Gope, P., Saxena, N., Kumar, N.: CB-DA: Lightweight and escrow-free
certificate-based data aggregation for smart grid. IEEE Trans. Depend. Secure Comput. 20(3),
2011–2024 (2023). https://doi.org/10.1109/TDSC.2022.3169952
11. Shen, M., Liu, H., Zhu, L., Xu, K., Yu, H., Du, X.: Blockchain-assisted secure device authen-
tication for cross-domain industrial IoT. IEEE J. Sel. Areas Commun. 38(5), 942–954 (2020).
https://doi.org/10.1109/JSAC.2020.2980916
12. Xue, L., Huang, H., Xiao, F., Wang, W.: A cross-domain authentication scheme based on
cooperative blockchains functioning with revocation for medical consortiums. IEEE Trans.
Netw. Serv. Manage. 19(3), 2409–2420 (2022). https://doi.org/10.1109/TNSM.2022.3146929
Privacy-Preserving Three-Factors
Authentication and Key Agreement
for Federated Learning
1 Introduction
As traditional centralized machine learning brings more and more data security
problems, federated learning (FL) [1] is becoming widely used in distributed
environments. Federated learning prevents data leakage by training models locally
and transmitting only model updates. However, distributed devices are located in
different networks and environments and risk being maliciously attacked. Therefore,
it is necessary to verify the legitimacy of distributed devices in federated learning
to prevent unauthorized devices from participating in model training or tampering
with data.
Authentication and key agreement (AKA) [2] is essential for verifying
the legitimacy of distributed devices. It not only ensures the authenticity of
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 90–103, 2025.
https://doi.org/10.1007/978-981-96-4566-4_7
device identity, but also protects the privacy of communication between distributed
devices. In federated learning, distributed devices authenticate each other through
the AKA cryptographic primitive to ensure every participant is legitimate. Distributed
devices encrypt communication by generating temporary session keys to prevent
malicious devices from impersonating their identities and stealing or tampering
with model updates. Frequent communication between the client and the server,
needed in federated learning to transmit locally trained model updates, brings a
large communication overhead. Meanwhile, frequent communication is easy to attack,
and tampered model updates will lead to model performance degradation or data leakage.
Most existing authentication and key agreement protocols [3] are designed
for traditional client-server environments. However, implementing these
protocols may lead to excessive computational and communication overheads.
Therefore, an authentication protocol for federated learning should not only ensure
secure data transmission and identity verification [4], but also avoid degrading
device performance through excessive computational and communication overheads.
Furthermore, client devices may drop out of or rejoin the system at any
time due to their unstable connection status. The AKA protocol therefore needs to be
flexible and adaptable, supporting dynamic authentication and key-update mechanisms
so that the system maintains stable security and communication efficiency
despite constant changes in the participating clients.
2 Preliminaries
In this section, we give definitions of the background knowledge related to this
paper. The specific description of definitions is shown as follows.
2.1 Notations
This section provides a uniform description of the main symbols used in this
paper. The specific description of each notation is displayed in Table 1.
Table 1. Notations
Notation Description
Ui ith user of the edge device
EDi ith edge device
ES edge server
CSj jth cloud server
idi identifier of the ith user
pwi password of the ith user
bioi biometric of the ith user
H(·) one-way hash function
Ti timestamp
3 Specific Definitions
In this section, we give the specific definitions of the system, including the
system model and the adversary model. The detailed description of each model is
presented as follows.
– Edge device (EDi ). In federated learning, edge devices generate data and perform
distributed computation. Edge devices train models locally and upload
encrypted model updates without transmitting raw data, protecting data privacy.
– Edge server (ES). In federated learning, edge servers are important intermediaries
connecting edge devices to the cloud server. They are responsible for collecting
and aggregating model updates from multiple edge devices to reduce the burden
on the cloud server and lower communication overheads. Moreover, the edge
server also handles registration for the cloud server and edge devices in this protocol.
– Cloud server (CSj ). The cloud server is responsible for collecting local model
updates or gradients from distributed edge devices or edge servers as a central
aggregation node. The cloud server generates the optimized global model through
the global aggregation algorithm and feeds it back to each participating device for
the next round of iterative training.
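The excerpt does not fix the aggregation rule the cloud server applies; federated averaging (FedAvg) is a common instantiation, sketched here under that assumption with illustrative parameter vectors:

```python
# FedAvg-style aggregation: the cloud server weights each client update by the
# number of local training samples it was computed on.
def fedavg(updates, sample_counts):
    """updates: one parameter vector (list of floats) per edge device."""
    total = sum(sample_counts)
    dim = len(updates[0])
    global_model = [0.0] * dim
    for update, n in zip(updates, sample_counts):
        for j in range(dim):
            global_model[j] += (n / total) * update[j]
    return global_model

# Two devices: the second trained on 3x as many samples, so it dominates.
w = fedavg([[1.0, 2.0], [3.0, 4.0]], sample_counts=[1, 3])
assert w == [2.5, 3.5]   # (1*1 + 3*3)/4 = 2.5, (1*2 + 3*4)/4 = 3.5
```

The weighted average is then broadcast back to the devices as the next global model.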
94 G. Wang et al.
4 Proposed Protocol
There are three phases in the proposed protocol. The detailed construction of
each stage will be shown as follows (Fig. 2).
When the legitimate user Ui has logged into the edge device with their identity
credentials, authentication is completed between EDi , ES and CSj . Authentication
is divided into 7 steps, and the detailed construction of each step is shown
as follows.
Step 1: The user inserts the smartcard SCi into the reading device and
inputs idi, pwi and bioi. EDi computes σi = Rep(bioi, τi), α* = M4 ⊕ H(idi ‖ σi),
Rid*i = H(idi ‖ σi ‖ α*), Rpw*i = H(pwi ‖ σi), n*i = M1 ⊕ Rid*i ⊕ Rpw*i,
M2* = H(Rid*i ‖ Rpw*i ‖ n*i), M5 = M3 ⊕ H(idi ‖ pwi), and checks whether M2* = M2.
Step 2: If M2* = M2 holds, EDi generates a random number ri and the timestamp
T1. Then, EDi chooses the accessed cloud Cidj and computes M6 =
M5 ⊕ ri, M7 = H(Rid*i ‖ ri ‖ M3 ‖ T1) and M8 = Cidj ⊕ H(M3 ‖ Rid*i ‖ T1).
After the above computations have been completed, the authentication message
{Ridi, M6, M7, M8, T1} is sent to the edge server ES through the public channel.
Step 3: After receiving {Ridi, M6, M7, M8, T1} from EDi, ES checks whether |T1' −
T1| ≤ ΔT. If so, ES extracts {ni, mi} corresponding to
Ridi. Then, ES computes M3* = H(mi ‖ Ridi), r*i = M6 ⊕ M3*, and M7* =
H(Ridi ‖ r*i ‖ M3*). ES checks whether M7* = M7.
Step 4: When the above computations are completed and the equation
holds, ES generates a random number εi and timestamp T2. Then, ES computes
Cid*j = M8 ⊕ H(M3* ‖ Ridi ‖ T1) and extracts {γj, δj} corresponding
to Cid*j. Finally, ES computes M9 = εi ⊕ H(δj ‖ Cid*j), M10 = H(γj ‖ εi) ⊕ r*i,
and M11 = H(Cid*j ‖ εi ‖ r*i ‖ T2), and sends the authentication message
{M9, M10, M11, T2} to the cloud server.
Step 5: After receiving {M9, M10, M11, T2} from ES, CSj checks whether |T2' − T2| ≤
ΔT. If the above equation holds, CSj computes ε*i = M9 ⊕ H(δj ‖ Cid*j),
r*i = M10 ⊕ H(γj ‖ ε*i), M11* = H(Cid*j ‖ ε*i ‖ r*i ‖ T2), and checks whether M11* = M11.
Step 6: If the above equation holds, CSj generates a random number ηj and timestamp
T3. Then, CSj computes M12 = ηj ⊕ H(r*i ‖ Cid*j), sk = H(Cid*j ‖ r*i ‖ ηj ‖ T3)
and M13 = H(sk ‖ M12 ‖ ηj ‖ T3). After the above parameters are computed,
the authentication message {M12, M13, T3} is sent to EDi.
Step 7: After receiving the authentication message from CSj, EDi checks whether
|T3' − T3| ≤ ΔT. If the above equation holds, EDi computes η*j =
M12 ⊕ H(ri ‖ Cidj) and sk* = H(Cidj ‖ ri ‖ η*j ‖ T3). Finally, EDi computes
H(sk* ‖ M12 ‖ η*j ‖ T3) and verifies whether M13 = H(sk* ‖ M12 ‖ η*j ‖ T3) holds.
If so, EDi stores sk* in the local database.
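The key-agreement core of Steps 6–7 can be sketched as follows, modeling H as SHA-256 over concatenated inputs and using illustrative values for Cidj, ri and T3; both ends derive the same sk even though ηj only crosses the channel masked inside M12:

```python
import hashlib
import secrets

def H(*parts) -> bytes:
    """H modeled as SHA-256 over the concatenation of the inputs."""
    h = hashlib.sha256()
    for p in parts:
        h.update(p if isinstance(p, bytes) else str(p).encode())
    return h.digest()

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

# Context assumed shared from the earlier steps: Cid_j, r_i, and timestamp T3.
cid_j = b"cloud-1"                 # illustrative cloud identifier
r_i = secrets.token_bytes(32)      # device nonce, recovered by CSj as r_i*
T3 = 1735430400                    # illustrative timestamp

# Step 6 (cloud side): CSj masks its nonce eta_j inside M12 and derives sk.
eta_j = secrets.token_bytes(32)
M12 = xor(eta_j, H(r_i, cid_j))
sk_cloud = H(cid_j, r_i, eta_j, T3)
M13 = H(sk_cloud, M12, eta_j, T3)

# Step 7 (device side): ED_i unmasks eta_j from M12 and derives the same key.
eta_recovered = xor(M12, H(r_i, cid_j))
sk_device = H(cid_j, r_i, eta_recovered, T3)
assert sk_device == sk_cloud                        # both ends share sk
assert H(sk_device, M12, eta_recovered, T3) == M13  # M13 check authenticates CSj
```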
5 Security Analysis
7 Performance Analysis
This section comprehensively evaluates the performance through theoretical
analysis and experimental validation.
The total time overhead is likewise compared with related works, and
the comparison result is shown in Fig. 6. Our protocol is the lowest in terms
of total computational overhead among the compared works; of the other protocols,
only Ju2023 comes close. Zhang2023 and Gao2024 rely on Chebyshev polynomials
and symmetric encryption, so their computational overhead is higher. Overall, combining
the results of the above analysis, the proposed protocol is more suitable for
federated learning communication.
8 Conclusion
References
1. Li, L., Fan, Y., Tse, M., Lin, K.Y.: A review of applications in federated learning.
Comput. Ind. Eng. 149, 106854 (2020)
2. Zhang, T., Shen, J., Yang, H., Pandi, V., Gupta, B.B., Arya, V.: Sustainable
authentication and key agreement protocol using chaotic maps for industry 5.0.
IEEE Trans. Consum. Electron. 70(1), 1580–1589 (2023)
3. Gao, Y., Zhou, T., Zheng, W., Yang, H., Zhang, T.: High-availability authentication
and key agreement for internet of things-based devices in industry 5.0. IEEE Trans.
Ind. Inf. 1–9 (2024)
4. Ahmad, A.: Fraud prevention in insurance: biometric identity verification and AI-
based risk assessment. In: 2024 International Conference on Knowledge Engineering
and Communication Systems, vol. 1, pp. 1–6 (2024)
5. Sun, Y., An, K., Luo, J., Zhu, Y., Zheng, G., Chatzinotas, S.: Intelligent reflecting
surface enhanced secure transmission against both jamming and eavesdropping
attacks. IEEE Trans. Veh. Technol. 70(10), 11017–11022 (2021)
6. Al-Shareeda, M.A., Manickam, S.: Man-in-the-middle attacks in mobile ad hoc
networks (MANETs): analysis and evaluation. Symmetry 14(8), 1543 (2022)
7. Yogesh, P.R.: Formal verification of secure evidence collection protocol using BAN
logic and AVISPA. Procedia Comput. Sci. 167, 1334–1344 (2020)
8. Dodis, Y., Reyzin, L., Smith, A.: Fuzzy extractors: how to generate strong keys
from biometrics and other noisy data. In: Cachin, C., Camenisch, J.L. (eds.)
EUROCRYPT 2004. LNCS, vol. 3027, pp. 523–540. Springer, Heidelberg (2004).
https://doi.org/10.1007/978-3-540-24676-3_31
9. Cervesato, I.: The Dolev-Yao intruder is the most powerful attacker. In: 16th
Annual Symposium on Logic in Computer Science-LICS, vol. 1, pp. 1–2 (2001)
10. Li, B., Erdin, E., Gunes, M.H., Begis, G., Shipley, T.: An overview of anonymity
technology usage. Comput. Commun. 36(12), 1269–1283 (2013)
11. Sutrala, A.K., Obaidat, M.S., Saha, S., Das, A.K., Alazab, M., Park, Y.: Authen-
ticated key agreement scheme with user anonymity and untraceability for 5G-
enabled softwarized industrial cyber-physical systems. IEEE Trans. Intell. Transp.
Syst. 23(3), 2316–2330 (2022)
12. Ju, S., Park, Y.: Provably secure lightweight mutual authentication and key agree-
ment scheme for cloud-based IoT environments. Sensors 23(24), 9766 (2023)
Blockchain-Based Anonymous
Authentication Scheme with Traceable
Pseudonym Management in ITS
1 Introduction
With the rapid development of Intelligent Transportation Systems (ITS), Vehic-
ular Ad Hoc Networks (VANETs) [1], as a core component of ITS, have become
a key technology for improving traffic safety and transportation efficiency [2].
Particularly with the rise of autonomous driving technology, the widespread
adoption of autonomous vehicles and intelligent traffic management systems has
made VANETs crucial in enhancing both traffic safety and efficiency. However,
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 104–117, 2025.
https://doi.org/10.1007/978-981-96-4566-4_8
With the rise of network threats, traditional authentication protocols face grow-
ing vulnerabilities. Blockchain’s decentralization and transparency offer a secure
solution for authentication and key management, reducing centralized risks. Our
contributions are as follows:
106 M. Wang et al.
2 Related Work
In VANETs, authentication and key agreement protocols are critical for ensuring
secure communication between vehicles and between vehicles and infrastructure.
Over recent years, researchers have proposed numerous Authentication and Key
Agreement schemes to enhance security in vehicular networks. Below are some
notable contributions in this field.
Lightweight authentication schemes play a vital role in securing communica-
tion within VANET environments. Vasudev et al. [14] introduced a lightweight
mutual authentication protocol for V2V communication. This protocol employs
encryption operations to facilitate efficient mutual authentication between vehi-
cles and establish a secure communication channel for key agreement between
devices and servers. Similarly, Li et al. [15] proposed a lightweight privacy-
preserving authentication protocol that utilizes hash functions and XOR oper-
ations to enable anonymous authentication. Their protocol operates in a net-
work model comprising TA, RSU, and vehicles, allowing vehicles to authenticate
anonymously with RSU. Furthermore, Lee et al. [16] developed a secure and effi-
cient honey-list-based authentication protocol. This protocol incorporates vehicle
location information for key agreement and supports V2V communication both
within the same region and across regions.
In recent years, blockchain technology, as a decentralized, secure, and trans-
parent distributed ledger, has been increasingly applied to authentication and
key agreement protocols in VANETs. Xie et al. [17] proposed a blockchain-based
V2I handover authentication and V2V broadcast protocol for VANETs. The pro-
tocol ensures the anonymity and traceability of broadcast information through
dynamic anonymity and pseudonym embedding strategies, enabling verification
even without transportation infrastructure or trusted authorities. Xue et al. [18]
presented a distributed authentication scheme utilizing blockchain and smart
contracts to facilitate decentralized authentication between users and access
points in mobile Internet of Vehicles roaming services. Lin et al. [19] devel-
oped an efficient blockchain-based conditional privacy-preserving authentication
scheme. This approach balances anonymity, traceability, and key management
3 Proposed Scheme
In this section, we give a detailed description of the proposed scheme with the
following details:
2. TA selects a private key s∈Zq∗ and calculates the corresponding public key
pubT A = s · P . Then, RSU broadcasts the public key to nearby vehicles.
Kvr1 = skrsu · θ, αj = Aj ⊕ H1(Kvr1 ‖ PIDj1), then computes Kvr2 = H2(skrsu ·
pubi ‖ Kvr1 ‖ T1) and H3(PIDj1 ‖ pubi ‖ Kvr1 ‖ αj ‖ T1). It checks whether Vh ⊕
Kvr2 = H3(PIDj1 ‖ pubi ‖ Kvr1 ‖ αj ‖ T1) holds. If the equality is satisfied, it generates
the timestamp T2 and calculates Vr1 = H3(αj ‖ pubi ‖ Kvr1 ‖ Kvr2 ‖ T2),
where PIDj1 ⊕ αj = IDi ⊕ σ. Next, it uses Eq. (1) to calculate Vr2. Finally, it
sends the message M2 = {Vr2, T2} to the vehicle user.
If the equality holds, it calculates Vs' = H5(u ‖ T3 ‖ Sig ‖ pubi) and further
verifies whether Vs' = Vs holds. The RSU then generates a timestamp T4 and computes
A = αj · P, Krt = pubTA · skrsu, Krvt = αj · pubTA, and Vrt1 =
H7(pubi ‖ Krt ‖ Krvt ‖ Sig ‖ T4). To allow the TA to verify the vehicle and prove
that the RSU has validated the vehicle's signature, the RSU calculates
Vrta ⊕ μ = Vta and Vvta = Vta ⊕ H1(pubi ‖ PIDj1). Finally, the RSU sends
M4 = {PIDj1, Vrt1, Sig, pubrsu, A, Vvta, T4} to the TA.
5. After the TA receives M4, it first verifies the timestamp |T4' − T4| ≤ ΔT.
If the timestamp is valid, it retrieves pubi by calling the smart contract
with PIDj1, then calculates Krt' = pubrsu · skTA, Krvt' = skTA · A, and
Vrt1' = H7(pubi ‖ Krt' ‖ Krvt' ‖ Sig ‖ T4). It verifies whether Vrt1' = Vrt1 holds. If
this equality is satisfied, it computes H1(pubi ‖ PIDj1) ⊕ Vvta = Vta, Kvt' =
skTA · pubi, and Vta' = H6(pubi ‖ PIDj1 ‖ Kvt' ‖ pubrsu). If Vta' = Vta holds, the
verification is successful (Fig. 2).
4 Security Analysis
This section provides an informal security analysis of the proposed scheme.
1. Replay Attack: In this scheme, the messages M1 , M2 , M3 , and M4 carry
timestamps Ti during the authentication process. Upon receiving a mes-
sage, the receiver calculates the difference between the current time and the
received timestamp. It then checks whether the timestamp is valid. If the
timestamp passes the validation, the message is accepted; otherwise, it is
discarded.
2. Impersonation Attack: Assume that an attacker A intercepts the authentication
request message of vehicle Vi and attempts to impersonate Vi by
generating a valid request message M1 = {PIDj1, Aj, Vh, T1, θ}. The
generation of this message depends on the long-term secret key ski, σ, and
the short-term random key αj. The calculation of Vh is given by Vh =
Kvr2 ⊕ H(PIDj1 ‖ pubi ‖ Kvr1 ‖ αj ‖ T1), where Kvr2 is calculated as Kvr2 =
H(ski · pubrsu ‖ Kvr1 ‖ T1). Therefore, the attacker cannot impersonate a legitimately
registered vehicle Vi. Clearly, the proposed scheme can resist vehicle
impersonation attacks.
3. Mutual Authentication: In the proposed scheme, the roadside unit (RSU)
authenticates the vehicle by verifying Vh in message M1 and the signature
Sig along with Vs in message M3 , while the vehicle authenticates the RSU
by validating Vr1 and Vr2 in message M2 . Additionally, the trusted authority
(TA) authenticates both the vehicle and the RSU by checking Vrt1 , Vta , and
the vehicle’s signature in message M4 . This process ensures the legitimacy
and integrity of all exchanged messages.
4. Conditional Anonymity: In the proposed scheme, the vehicle Vi commu-
nicates with the RSU and other vehicles using a pseudonym P IDj1 , thereby
providing anonymity for Vi . The identity IDi of the vehicle Vi is hidden within
P IDj1 = IDi ⊕ σ ⊕ αj , and the vehicle has multiple pseudonyms P IDj1 , which
change with the value of αj . Thus, an attacker cannot determine the actual
identity of the vehicle Vi . However, in the case of malicious behavior, the real
identity can be traced through the pseudonym.
5. Traceability and Revocation: The pseudonym P IDj1 of a vehicle Vi is
defined as P IDj1 = IDi ⊕σ⊕αj , where the recovery key σ is distributed among
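The pseudonym construction PIDj1 = IDi ⊕ σ ⊕ αj and its tracing property can be sketched directly; the identity bytes and key sizes below are illustrative choices, not taken from the paper:

```python
import secrets

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

vid = b"VEH-000042\x00\x00\x00\x00\x00\x00"  # 16-byte real identity (illustrative)
sigma = secrets.token_bytes(16)              # recovery key known to the tracer
alpha1 = secrets.token_bytes(16)             # per-session random alpha_j values
alpha2 = secrets.token_bytes(16)

# Pseudonyms change with alpha_j, so two pseudonyms of one vehicle look unrelated.
pid1 = xor(xor(vid, sigma), alpha1)          # PID = ID xor sigma xor alpha_1
pid2 = xor(xor(vid, sigma), alpha2)          # PID = ID xor sigma xor alpha_2
assert pid1 != pid2

# Tracing: holding sigma and the session's alpha_j, the real identity is recovered.
assert xor(xor(pid1, sigma), alpha1) == vid
```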
5 Performance Analysis
In this section, we analyze the performance of our proposed scheme in terms of communication overhead and computational cost. Smart contracts are written in Solidity, Ganache serves as the blockchain environment, and the Truffle framework is used for contract deployment and testing.
Let Tecm , Tepa , Tmtp , Tbp , Texp , Th , and Tenc/dec represent elliptic curve mul-
tiplication, elliptic curve addition, scalar multiplication, bilinear pairing, mod-
ular exponentiation, hash function, and symmetric key encryption/decryption,
respectively. To better evaluate the computational cost, we use the pbc library to
simulate these operations. The time for each operation is as follows: Tecm = 0.826
ms, Tepa = 0.038 ms, Tmtp = 0.026 ms, Tbp = 2.952 ms, Texp = 1.507 ms,
Th = 0.037 ms, and Tenc/dec = 0.031 ms. Based on this, we analyze the schemes [20–23] together with our proposed scheme. In our proposed scheme, the computational
cost at the vehicle side is 5Tecm + 9Th = 4.463 ms, the RSU side’s cost is
8Tecm + 10Th + 1Tepa = 7.016 ms, and the TA’s cost is 3Tecm + 3Th = 2.589 ms.
Thus, the total computational cost for our scheme is 14.068 ms. The comparison
results are shown in Table 3, and the computational overhead of different entities
is shown in Fig. 4.
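The per-entity totals can be reproduced by tallying the operation counts against the unit times above; a minimal sketch:

```python
# Unit times (ms) reported in the text for each cryptographic operation.
T = {"ecm": 0.826, "epa": 0.038, "mtp": 0.026, "bp": 2.952,
     "exp": 1.507, "h": 0.037, "enc": 0.031}

def cost(counts: dict) -> float:
    """Total time (ms) for a weighted mix of operations."""
    return round(sum(n * T[op] for op, n in counts.items()), 3)

vehicle = cost({"ecm": 5, "h": 9})              # 5 Tecm + 9 Th
rsu     = cost({"ecm": 8, "h": 10, "epa": 1})   # 8 Tecm + 10 Th + 1 Tepa
ta      = cost({"ecm": 3, "h": 3})              # 3 Tecm + 3 Th

assert vehicle == 4.463 and rsu == 7.016 and ta == 2.589
assert round(vehicle + rsu + ta, 3) == 14.068   # total for our scheme
```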
6 Conclusion
In our research, a blockchain-based pseudonym management and anonymous authentication protocol is proposed. By combining blockchain technology with a secret sharing scheme, it not only protects user privacy but also provides traceability and enhances the decentralization of the system. The protocol lets the RSU invoke a smart contract for vehicle authentication and exploits the tamper-resistance of the blockchain to ensure the legitimacy of authenticated users, thereby ensuring the accuracy and reliability of information in VANETs. Security analysis shows that the protocol can effectively resist common attacks and meet the security requirements of VANETs. In terms of performance, our protocol has significant advantages over other schemes.
References
1. Zhu, F., Li, Z., Chen, S., Xiong, G.: Parallel transportation management and con-
trol system and its applications in building smart cities. IEEE Trans. Intell. Transp.
Syst. 17(6), 1576–1585 (2016)
2. Tan, H., Zheng, W., Vijayakumar, P.: Secure and efficient authenticated key man-
agement scheme for UAV-assisted infrastructure-less IoVs. IEEE Trans. Intell.
Transp. Syst. 24(6), 6389–6400 (2023)
3. Zhou, Y., Wang, Z., Qiao, Z., Yang, B., Zhang, M.: An efficient and provably
secure identity authentication scheme for VANET. IEEE Internet Things J. 10(19),
17170–17183 (2023)
4. Tan, H., Zheng, W., Guan, Y., Lu, R.: A privacy-preserving attribute-based
authenticated key management scheme for accountable vehicular communications.
IEEE Trans. Veh. Technol. 72(3), 3622–3635 (2023)
5. Gupta, M., Benson, J., Patwa, F., Sandhu, R.: Secure V2V and V2I communication
in intelligent transportation using cloudlets. IEEE Trans. Serv. Comput. 15(4),
1912–1925 (2020)
6. Guan, F., Zhu, T., Zhou, W., Choo, K.: Graph neural networks: a survey on the
links between privacy and security. Artif. Intell. Rev. 57(2), 40 (2024)
7. Lv, S., Tan, H., Zheng, W., Zhang, T., Wang, M.: A dynamic conjunctive keywords
searchable symmetric encryption scheme for multiple users in cloud computing.
Comput. Commun. 209, 239–248 (2023)
8. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving
and personalized federated learning framework. IEEE Internet Things J. 11(11),
19380–19393 (2024)
9. Eiza, M.H., Owens, T., Ni, Q.: Secure and robust multi-constrained QoS aware
routing algorithm for VANETs. IEEE Trans. Dependable Secure Comput. 13(1),
32–45 (2015)
10. Lin, C., Huang, X., He, D.: EBCPA: efficient blockchain-based conditional privacy-
preserving authentication for VANETs. IEEE Trans. Dependable Secure Comput.
20(3), 1818–1832 (2022)
11. Tan, H., Zheng, W., Vijayakumar, P., Sakurai, K., Kumar, N.: An efficient vehicle-
assisted aggregate authentication scheme for infrastructure-less vehicular networks.
IEEE Trans. Intell. Transp. Syst. 24(12), 15590–15600 (2023)
12. Son, S., Lee, J., Park, Y., Park, Y., Das, A.K.: Design of blockchain-based
lightweight V2I handover authentication protocol for VANET. IEEE Trans. Net-
work Sci. Eng. 9(3), 1346–1358 (2022)
13. Ma, Q., Tan, H., Zhou, T.: Mutual authentication scheme for smart devices in
IoT-enabled smart home systems. Comput. Stand. Interfaces 86, 103743 (2023)
14. Vasudev, H., Deshpande, V., Das, D., Das, S.K.: A lightweight mutual authenti-
cation protocol for V2V communication in internet of vehicles. IEEE Trans. Veh.
Technol. 69(6), 6709–6717 (2020)
15. Li, X., Liu, T., Obaidat, M.S., Wu, F., Vijayakumar, P., Kumar, N.: A lightweight
privacy-preserving authentication protocol for VANETs. IEEE Syst. J. 14(3),
3547–3557 (2020)
16. Lee, J., Kim, G., Das, A.K., Park, Y.: Secure and efficient honey list-based authen-
tication protocol for vehicular ad hoc networks. IEEE Trans. Network Sci. Eng.
8(3), 2412–2425 (2021)
17. Xie, Q., Ding, Z., Tang, W., He, D., Tan, X.: Provable secure and lightweight
blockchain-based V2I handover authentication and V2V broadcast protocol for
VANETs. IEEE Trans. Veh. Technol. 72(12), 15200–15212 (2023)
18. Xue, K., Luo, X., Ma, Y., Li, J., Liu, J., Wei, D.: A distributed authentication
scheme based on smart contract for roaming service in mobile vehicular networks.
IEEE Trans. Veh. Technol. 71(5), 5284–5297 (2022)
19. Lin, C., Huang, X., He, D.: EBCPA: efficient blockchain-based conditional privacy-
preserving authentication for VANETs. IEEE Trans. Dependable Secure Comput.
20(3), 1818–1832 (2023)
20. Yang, A., Weng, J., Yang, K., Huang, C., Shen, X.: Delegating authentication
to edge: a decentralized authentication architecture for vehicular networks. IEEE
Trans. Intell. Transp. Syst. 23(2), 1284–1298 (2022)
21. Wang, Y., Ding, Y., Wu, Q., Wei, Y., Qin, B., Wang, H.: Privacy-preserving
cloud-based road condition monitoring with source authentication in VANETs.
IEEE Trans. Inf. Forensics Secur. 14(7), 1779–1790 (2019)
22. Zhang, J., Zhong, H., Cui, J., Tian, M., Xu, Y., Liu, L.: Edge computing-based
privacy-preserving authentication framework and protocol for 5G-enabled vehicu-
lar networks. IEEE Trans. Veh. Technol. 69(7), 7940–7954 (2020)
23. Zhou, Y., et al.: A novel cloud-assisted authentication key agreement protocol for
VANET. IEEE Trans. Veh. Technol. 73(9), 13526–13541 (2024)
Multi-keyword Searchable Data Auditing
for Cloud-Based Machine Learning
1 Introduction
In the information age, the rapid growth of data has driven the continuous rise in
the demand for large-scale data processing and analysis. Machine learning (ML)
has been extensively implemented and applied in various big data applications
across multiple fields, including image processing [1], information extraction [2],
data cleaning [3], and so on. ML efficiently extracts valuable knowledge from
large amounts of data. For resource-constrained clients, a common approach is to outsource data storage and model training to the cloud.
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 118–132, 2025.
https://doi.org/10.1007/978-981-96-4566-4_9
2 Related Work
ML solutions have been driven by the explosion of data and the need to effectively extract and analyze useful insights from it. Sabek et al. [5] reviewed recent work on ML and big data.
120 H. Yu et al.
With the rapid development of ML, most of the relevant research has concentrated on outsourced scenarios, and there is growing concern about the leakage of private information in training data under various models.
To achieve privacy-preserving machine learning (PPML), many scholars have
proposed various solutions. Liu et al. [6] proposed a scalable privacy-enhancing
aggregation scheme that effectively handles participant dropouts and improves
system efficiency. Mohassel et al. [7] introduced the SecureML scheme, an efficient privacy-preserving machine learning protocol based on secure two-party computation. Zhang et al. [8] proposed a privacy-preserving and personal-
ized federated learning framework (PPFed), which enhances personalized perfor-
mance while ensuring data privacy by adopting a model-layer partition strategy.
However, PPML typically requires significant computational resources and storage space. To minimize data processing and storage costs, data can be outsourced to cloud service providers (CSPs). With the powerful computing capabilities and resource-sharing features of the cloud, model training can also avoid the need to deploy expensive hardware locally.
With the trend of outsourcing data for model training to cloud platforms, the
challenges of privacy preservation and data security have become progressively
more intricate. Chang et al. [9] introduced a defense approach based on gradi-
ent sparsification and pseudo-gradients, using cosine similarity to protect user
privacy during model training, thereby providing a certain level of privacy pro-
tection for cloud-based machine learning tasks. Although these methods reduce
the risk of privacy leakage, outsourcing data storage and processing in cloud
computing still presents many security vulnerabilities. Xiao et al. [10] pointed
out that attackers can efficiently crack encryption with minimal side-channel
data, further threatening user privacy and cloud data security. Furthermore, the
integrity of outsourced data during model training has also received considerable
attention. Numerous schemes for outsourced data auditing have been proposed.
Armknecht et al. [11] introduced the concept of Outsourced Proofs of Retrievability (OPOR) and proposed the first outsourced auditing scheme, called Fortress. In
recent years, many auditing schemes for outsourcing data have been presented.
Shen et al. [4] proposed an efficient public auditing protocol that achieves secure
verification of outsourced data. To mitigate the risk of privacy breaches, Rao et al. [12] proposed a dynamic outsourced auditing scheme that resists dishonest entities and collusion.
Aiming at data training and analysis in ML, how to quickly retrieve outsourced data from cloud servers has become an important issue. Boneh et al. [13] first proposed public-key encryption with keyword search (PEKS), enabling cloud servers to provide a searchable encrypted data set for user-selected keywords. Since then, several new PEKS-related schemes
have been proposed [14, 15]. However, single-keyword search often returns a large number of irrelevant results, which significantly degrades both the accuracy and efficiency of the search. To address this challenge, PEKS mechanisms supporting multi-keyword search are more practical in the cloud. Liu et al. [16] proposed a privacy-preserving multi-keyword searchable encryption scheme for distributed systems.
3 Preliminaries
4 Scheme Model
This section describes the system model, security goals, and security model of the cloud-based machine learning data auditing scheme. These elements form the foundation and operational basis for the scheme's design.
The system model mainly includes four entities: cloud, key generation center
(KGC), model trainer, and third-party auditor (TPA), as shown in Fig. 1. The
specific description is as follows:
Cloud: The cloud stores datasets uploaded by model trainers, as well as encrypted keyword sets and corresponding label sets. With powerful computing capabilities and ample storage resources, the cloud provides an audit proof to show that the files in the cloud are stored correctly and completely.
Key Generation Center: The KGC is used for creating the required keys. It
can generate keys based on the identity information sent by the model trainer
and return them through a secure channel.
Model Trainer: The model trainer is responsible for generating several key-
words for each data block and uploading the dataset, encrypted keywords, and
corresponding labels to the cloud.
Third Party Auditors: The TPA is responsible for verifying the accuracy
of the training data stored in the cloud. Once a request for auditing from the
model trainer is received, the TPA initiates an auditing challenge to the cloud.
Subsequently, the TPA verifies the validity of the proof provided by the cloud
and communicates the results to the model trainer.
Soundness: The proof returned by the cloud cannot pass the TPA's verification if the cloud has not completely stored the dataset uploaded by the model trainer.
Privacy Preservation: During the TPA audit process, no private data infor-
mation of the model trainer can be obtained from either the trainer’s delegated
tasks or the audit proof returned from the cloud.
Honest but Curious TPA: The TPA will honestly and correctly execute the
system’s algorithms. However, the TPA may be curious about the data searched
by the model trainer and could potentially infer sensitive information about the
data through keywords.
Semi-honest Cloud: A semi-honest cloud can still correctly execute the algo-
rithms set by the system. However, to protect its financial interests and uphold
its reputation, CSPs are likely to lie to users to conceal issues of data damage
or loss.
The design details of data retrieval and auditing for outsourced model training
data in the cloud will be described in this section. It mainly includes seven stages.
Specifically, the details of each stage are outlined below.
Setup: The KGC executes this algorithm. Using the input security parameter k, the KGC generates the public parameters P and the master key x. The operations involved are outlined as follows.
Key Generation: The KGC executes this algorithm. Given the identity ID provided by the model trainer, the system public parameters P, and the master key x, the KGC generates the private key, which is returned to the model trainer through a secure channel. The specific operations are as follows.
1) The model trainer provides the identity ID to the KGC. The KGC randomly selects r ∈ Z*_q and computes R = g^r and k = x · h(ID ‖ R) mod q.
2) The KGC sends the private key k_ID = (R, k) to the model trainer through a secure channel.
3) For the received private key k_ID, the model trainer can verify its correctness by checking whether g^k = Y^(h(ID ‖ R)) holds. If it holds, the private key is correct; otherwise, it is incorrect.
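Steps 1)–3) can be sketched with toy parameters (a minimal sketch: SHA-256 stands in for h, a small schoolbook subgroup of Z_p^* with p = 2q + 1 replaces the real group, and ID ‖ R is encoded by simple byte concatenation — all illustrative assumptions):

```python
import hashlib

# Toy order-q subgroup of Z_p^* with p = 2q + 1 (illustrative sizes only).
q, p, g = 1019, 2039, 4  # g = 4 is a quadratic residue mod p, so it has order q

def h(data: bytes) -> int:
    """Hash to Z_q (stand-in for the scheme's h)."""
    return int.from_bytes(hashlib.sha256(data).digest(), "big") % q

# KGC setup: master key x, public value Y = g^x.
x = 777                       # master key (fixed here for reproducibility)
Y = pow(g, x, p)

# Key generation for identity ID: R = g^r, k = x * h(ID || R) mod q.
ID, r = b"trainer-01", 123
R = pow(g, r, p)
hid = h(ID + R.to_bytes(2, "big"))   # h(ID || R), encoding R as 2 bytes
k = (x * hid) % q

# Trainer-side correctness check: g^k == Y^{h(ID || R)}.
assert pow(g, k, p) == pow(Y, hid, p)
```

The check passes because Y^{h(ID ‖ R)} = g^{x·h(ID ‖ R)} and the exponent of g is only defined modulo its order q, which is exactly how k was reduced.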
Tag Generation: The algorithm is run on the model trainer’s side, where the
data set, corresponding encrypted keywords, and their labels are uploaded to
the cloud. The specific steps are as follows.
1) Assume there are n data blocks in a certain model training task. The model trainer intends to upload these n data blocks to cloud-based storage, where B_x denotes a data block and I_x the corresponding index of the data block (1 ≤ x ≤ n). Each data block is assigned m keywords, denoted w_xy, where 1 ≤ y ≤ m.
2) The model trainer selects a number d ∈ Z*_q at random and computes D = u^d. All the keywords are then encrypted as:
W_xy = (h1(w_xy))^d
Challenge Generation: Based on the audit task sent by the model trainer,
the TPA issues a challenge to the cloud.
1) When the model trainer wants to search the cloud for related datasets that contain the keyword w*, a number f ∈ Z*_q is chosen at random and F = u^f is computed. The target keyword is encrypted as:
W = (h1(w*))^f
The calculation result is taken as the search target's trapdoor T_W, and the model trainer sends {T_W, P, R, F} to the TPA.
2) The TPA chooses a number 𝛾 ∈ Z∗𝑞 at random, and computes 𝐸 = 𝑔 𝛾 .
3) The TPA transmits the challenge {𝑇𝑊 , 𝑃, 𝑅, 𝐹, 𝐸 } to the cloud.
Target Information Search: The cloud executes the algorithm. Based on the
challenge issued by the TPA, the cloud retrieves the encrypted keyword’s label
and transmits to the model trainer the index of the data block that holds the
target keyword.
1) The cloud takes the set of data blocks uploaded by the model trainer, identified via the parameter Y, as the searchable range. The keyword labels within this range are:
{V_xy} = h3(e(W_xy, F) ‖ Y ‖ I_x)  (1 ≤ x ≤ n, 1 ≤ y ≤ m)
2) Then, the cloud tests the data keyword labels within this range against:
{V_xy} ?= h3(e(W, D) ‖ Y ‖ I_x)
If this equality holds, there exists a set of data blocks containing the keyword label, whose indices are denoted {I_x}, x ∈ L. Otherwise, if the equality does not hold, the cloud returns 0 to the TPA.
3) The set of data indices { 𝐼 𝑥 }, where 𝑥 ∈ 𝐿, is sent to the model trainer.
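The matching test hinges on the bilinear identity e(h1(w)^d, u^f) = e(h1(w*)^f, u^d), which holds exactly when the keywords agree. A toy sketch of the single-keyword search, representing each group element by its exponent so the "pairing" is just exponent multiplication (insecure, for illustration only; u is taken as the generator, SHA-256 stands in for h1 and h3, and all parameters and block contents are made up):

```python
import hashlib

# Toy symmetric pairing: a group element g^a is represented by its exponent a
# (mod q); exponentiation multiplies by a scalar, and the pairing
# e(g^a, g^b) = e(g, g)^{ab} multiplies two exponents.
q = 2**61 - 1  # illustrative prime modulus

def h1(w: str) -> int:               # hash a keyword to a group element (exponent)
    return int.from_bytes(hashlib.sha256(w.encode()).digest(), "big") % q

def exp(elem: int, s: int) -> int:   # elem^s in the exponent representation
    return (elem * s) % q

def e(a: int, b: int) -> int:        # pairing of two elements
    return (a * b) % q

def h3(pv: int, Y: int, idx: int) -> bytes:  # label hash over e(...) || Y || I_x
    return hashlib.sha256(f"{pv}|{Y}|{idx}".encode()).digest()

Y, d, f = 42, 1234567, 7654321   # public parameter Y, secrets d and f (made up)
u = 1                             # take u = g, i.e. exponent 1
D, F = exp(u, d), exp(u, f)       # D = u^d, F = u^f

# Trainer side: per-block keyword tags W_xy = h1(w_xy)^d.
blocks = {1: ["cats", "ml"], 2: ["dogs"], 3: ["ml", "audit"]}
tags = {(x, w): exp(h1(w), d) for x, ws in blocks.items() for w in ws}

# Cloud side: trapdoor W = h1(w*)^f matches a tag iff the keywords agree, since
# e(W_xy, F) = e(g, g)^{h1(w_xy)*d*f} equals e(W, D) = e(g, g)^{h1(w*)*f*d}.
W = exp(h1("ml"), f)
hits = {x for (x, w), W_xy in tags.items()
        if h3(e(W_xy, F), Y, x) == h3(e(W, D), Y, x)}
assert hits == {1, 3}
```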
Proof Generation: The cloud runs this algorithm. It generates a proof based
on the data index set { 𝐼 𝑥 }, 𝑥 ∈ 𝐿 and returns it to the TPA.
1) The cloud gathers the information for the data block set {I_x}, x ∈ L, and computes the data block keyword label set δ and the data block label set μ:
δ = E^(Σ_{x∈L} h2(R ‖ Y ‖ Σ_y W_xy)) · E^(Σ_{x∈L} I_x),  μ = Π_{x∈L} σ_x
Proof Verification: The TPA executes the algorithm to verify, based on the proof returned by the cloud, whether the following equation is satisfied:
e(δ, μ) ?= e(Y^(h(ID ‖ R)), g1)^γ
If this equation holds, the TPA returns 1 to the model trainer, indicating that the required dataset is fully stored in the cloud; otherwise, the TPA returns 0, indicating that the data has been tampered with or corrupted.
To narrow the search scope and achieve more precise searching, we extend the above scheme to support the retrieval and auditing of datasets based on multiple keywords.
Assume the model trainer provides c keywords, denoted as a keyword set Δw* = {w1, w2, . . . , wc}. The setup, key generation, and tag generation algorithms remain consistent with the previous scheme. The remaining stages are outlined in detail as follows:
The set of calculation results is considered the search target’s trapdoor set
Δ𝑇𝑊 , and the model trainer sends {Δ𝑇𝑊 , 𝑃, 𝑅, 𝐹 } to the TPA.
2) The TPA randomly selects a number 𝛾 ∈ Z∗𝑞 and calculates 𝐸 = 𝑔 𝛾 .
3) The TPA sends the challenge {Δ𝑇𝑊 , 𝐹, 𝑃, 𝑅, 𝐸 } to the cloud.
Target Information Search: The cloud executes the algorithm. The cloud
searches for data that satisfies all keywords.
1) The cloud takes the set of data blocks uploaded by the model trainer, identified via the parameter Y, as the searchable range. The keyword labels within this range are:
{V_xy} = h3(e(W_xy, F) ‖ Y ‖ I_x)  (1 ≤ x ≤ n, 1 ≤ y ≤ m)
Then, the cloud performs a search for the first keyword label within the specified range by testing:
V ?= h3(e(W1, D) ‖ Y ‖ I_x)
2) The cloud records the data block index set matching the first keyword tag as {I_x}, x ∈ L1, and treats this set as a new searchable range. Within this new range, the keyword tags are calculated as:
V_xy = h3(e(W_xy, F) ‖ Y ‖ I_x)  (x ∈ L1, 1 ≤ y ≤ m)
Then, the cloud performs a search for the second keyword label within the specified range by testing:
V ?= h3(e(W2, D) ‖ Y ‖ I_x)  (x ∈ L1)
3) Whenever the equation does not hold, the cloud returns 0 to the TPA; if it holds, the process continues until i = c, yielding the final set of data block locations {I_x}, x ∈ L*.
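The round-by-round narrowing of steps 1)–3) can be mirrored over plaintext keyword sets (a sketch that replaces the encrypted tag comparisons with direct membership tests; block contents and keywords are made up):

```python
# Successive narrowing: each keyword filters the candidate index set produced
# by the previous keyword, so later rounds only scan surviving blocks.
def multi_keyword_search(blocks: dict, keywords: list) -> set:
    """blocks maps index -> keyword set; returns indices containing ALL keywords."""
    candidates = set(blocks)            # initial searchable range: all blocks
    for w in keywords:                  # one round per keyword w_i
        candidates = {x for x in candidates if w in blocks[x]}
        if not candidates:              # mirrors "return 0 to the TPA"
            return set()
    return candidates                   # L*: blocks matching every keyword

blocks = {1: {"ml", "audit", "cloud"}, 2: {"ml", "cloud"}, 3: {"audit"}}
assert multi_keyword_search(blocks, ["ml", "cloud"]) == {1, 2}
assert multi_keyword_search(blocks, ["ml", "audit"]) == {1}
assert multi_keyword_search(blocks, ["gpu"]) == set()
```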
Proof Generation: The cloud runs this algorithm and generates a proof based
on the data index set { 𝐼 𝑥 }, 𝑥 ∈ 𝐿 ∗ which will be returned to the TPA.
1) The cloud gathers the information for the data block set {I_x}, x ∈ L*, then computes the data block keyword label set δ and the data block label set μ:
δ = E^(Σ_{x∈L*} h2(R ‖ Y ‖ Σ_y W_xy)) · E^(Σ_{x∈L*} I_x),  μ = Π_{x∈L*} σ_x
Proof Verification: The TPA executes the algorithm to verify, based on the proof returned by the cloud, whether the following equation is satisfied:
e(δ, μ) ?= e(Y^(h(ID ‖ R)), g1)^γ
If this equation holds, the TPA returns 1 to the model trainer, indicating that the required dataset is fully stored in the cloud; otherwise, the TPA returns 0, indicating that the data has been tampered with or corrupted.
7 Security Analysis
7.1 Correctness
Theorem 1. If the model trainer, the TPA, and the cloud properly carry out
the protocol, and the data integrity proof formula holds, then the scheme is
considered to meet the requirements for audit correctness.
Based on the definition of bilinear pairing, the process of the data integrity proof is as follows:

e(δ, μ) = e( E^(Σ_{x∈L*} h2(R ‖ Y ‖ Σ_y W_xy)) · E^(Σ_{x∈L*} g^(I_x)), Π_{x∈L*} σ_x )
        = e( E^(Σ_{x∈L*} h2(R ‖ Y ‖ Σ_y W_xy)) · E^(Σ_{x∈L*} g^(I_x)), g1^(k · 1/Σ_{x∈L*}(h2(R ‖ Y ‖ Σ_y W_xy) + g^(I_x))) )
        = e( g^(γ · [Σ_{x∈L*}(h2(R ‖ Y ‖ Σ_y W_xy) + g^(I_x))] · k · 1/[Σ_{x∈L*}(h2(R ‖ Y ‖ Σ_y W_xy) + g^(I_x))]), g1 )
        = e(g^γ, g1^k)
        = e(g^k, g1)^γ
        = e(Y^(h(ID ‖ R)), g1)^γ
Theorem 2. If the verification formula for the encrypted keyword tag search is valid, the search verification is considered correct.
When searching for a certain keyword w*, the correctness of the search result is proved as follows:

V = h3(e(W*, F) ‖ Y ‖ I_x)
  = h3(e(h1(w*)^d, u^f) ‖ Y ‖ I_x)
  = h3(e(h1(w*)^f, u^d) ‖ Y ‖ I_x)
  = h3(e(T_w, D) ‖ Y ‖ I_x)
7.2 Soundness
Theorem 3. If the cloud cannot forge proof to pass the TPA verification, it is
considered that the scheme is unforgeable.
Due to the hardness of the DL problem, the cloud cannot infer d from D and the generator u, making it impossible to forge encrypted keyword labels.
Theorem 4. The TPA is an honest but curious entity. During the auditing
process, the TPA is unable to obtain the private information of the data through
encrypted keywords.
Proof: Suppose adversary A plays the role of the TPA and is curious about the model training data. It can be proven that the TPA cannot obtain any specific information.
1) Adversary A receives the public parameters sent by challenger C. Due to the
DL problem, A cannot crack the key x based on the public parameters 𝑌 = 𝑔 𝑥
and the generator 𝑔, meaning it cannot obtain the model trainer’s identity
information.
2) Adversary A receives the encrypted keyword 𝑊 = ℎ1 ( 𝑤 ∗ ) 𝑓 sent by the chal-
lenger C. Because the hash function is unidirectional, adversary A cannot
reverse deduce the keyword information 𝑤 ∗ of the user from the hash value.
In summary, a curious TPA cannot access any private details or real data dur-
ing the auditing process. Therefore, this guarantees data privacy during model
training.
8 Performance Analysis
In this section, we evaluate the performance of our proposed scheme and compare it with the experimental results of the schemes proposed by Wang X. et al. [18] and Wang M. et al. [19].
The simulation was run on Ubuntu 22.04.5, configured with an Intel Core i7-12650H CPU at 2.68 GHz and 3.8 GB of RAM. The experiment utilized the GMP and PBC libraries and was implemented in C.
In the experiment, we assumed one file, with each data block containing 10 keywords. The experiment was conducted in four phases: tag/signature generation, keyword verification, proof generation, and proof verification. The specific experimental results are as follows.
[Figures: time costs of our scheme, Wang X. et al.'s scheme, and Wang M. et al.'s scheme, plotted against the number of data blocks (n) and against the number of repetitions (r).]
Proof Generation Phase: The cost of this phase depends on the number of data blocks used in the challenge. We vary the challenged data block count from 100 to 800 (with the total number of blocks n fixed at 1000), increasing by 50 blocks at each step. When the challenged block count is small, our scheme has a cost similar to the schemes proposed by Wang M. et al. and Wang X. et al., as shown in Fig. 4. As the count of data blocks increases, our scheme exhibits an evidently lower cost.
Proof Verification Phase: In this phase, the costs of both our scheme and the scheme proposed by Wang X. et al. are unaffected by the challenged data block count. Compared to Wang M. et al.'s scheme, both significantly improve audit efficiency. Our scheme has lower time costs and higher efficiency than the scheme proposed by Wang X. et al., as shown in Fig. 5.
[Figures: time costs of our scheme, Wang X. et al.'s scheme, and Wang M. et al.'s scheme, plotted against the number of challenged data blocks (c).]
Fig. 4. Cost comparison of proof generation phase
Fig. 5. Cost comparison of proof verification phase
9 Conclusion
References
1. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep con-
volutional neural networks. In: Bartlett, P.L., Pereira, F.C.N., Burges, C.J.C.,
Bottou, L., Weinberger, K.Q. (eds.) Advances in Neural Information Processing
Systems 25: 26th Annual Conference on Neural Information Processing Systems
2012. Proceedings of a Meeting Held 3–6 December 2012, Lake Tahoe, Nevada,
United States, pp. 1106–1114, (2012)
2. De Sa, C., Ratner, A., Ré, C., Shin, J., Wang, F., Sen, W., Zhang, C.: Incremental knowledge base construction using DeepDive. VLDB J. 26(1), 81–105 (2017).
https://doi.org/10.1007/S00778-016-0437-2
3. Rekatsinas, T., Chu, X., Ilyas, I.F., Ré, C.: HoloClean: holistic data repairs with
probabilistic inference. Proc. VLDB Endow. 10(11), 1190–1201 (2017). https://
doi.org/10.14778/3137628.3137631
4. Shen, J., Shen, J., Chen, X., Huang, X., Susilo, W.: An efficient public auditing
protocol with novel dynamic structure for cloud data. IEEE Trans. Inf. Forensics
Secur. 12(10), 2402–2415 (2017). https://doi.org/10.1109/TIFS.2017.2705620
5. Sabek, I., Mokbel, M.F.: Machine learning meets big spatial data. In: 2020 IEEE
36th International Conference on Data Engineering (ICDE), pp. 1782–1785 (2020).
https://doi.org/10.1109/ICDE48307.2020.00169
6. Liu, Z., Guo, J., Lam, K.-Y., Zhao, J.: Efficient dropout-resilient aggregation for
privacy-preserving machine learning. IEEE Trans. Inf. Forensics Secur. 18, 1839–
1854 (2023). https://doi.org/10.1109/TIFS.2022.3163592
7. Mohassel, P., Zhang, Y.: SecureML: a system for scalable privacy-preserving
machine learning. In: 2017 IEEE Symposium on Security and Privacy (SP), pp.
19–38 (2017). https://doi.org/10.1109/SP.2017.12
8. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving
and personalized federated learning framework. IEEE Internet Things J. 11(11),
19380–19393 (2024). https://doi.org/10.1109/JIOT.2024.3360153
9. Chang, W., Zhu, T.: Gradient-based defense methods for data leakage in vertical
federated learning. Comput. Secur. 139, 103744 (2024)
10. Xiao, Z., Wang, C., Shen, J., Jonathan Wu, Q.M., He, D.: Less traces are all it
takes: Efficient side-channel analysis on AES. IEEE Trans. Comput.-Aided Design
Integr. Circuits Syst. (2024). https://doi.org/10.1109/TCAD.2024.3518414
11. Armknecht, F., Bohli, J.-M., Karame, G.O., Liu, Z., Reuter, C.A.: Outsourced
proofs of retrievability. In: Proceedings of the 2014 ACM SIGSAC Conference
on Computer and Communications Security, CCS 2014, pp. 831–843. Association
for Computing Machinery, New York (2014). https://doi.org/10.1145/2660267.
2660310
12. Rao, L., Zhang, H., Tengfei, T.: Dynamic outsourced auditing services for cloud
storage based on batch-leaves-authenticated merkle hash tree. IEEE Trans. Serv.
Comput. 13(3), 451–463 (2020). https://doi.org/10.1109/TSC.2017.2708116
13. Boneh, D., Di Crescenzo, G., Ostrovsky, R., Persiano, G.: Public key encryption
with keyword search. In: Cachin, C., Camenisch, J.L. (eds.) EUROCRYPT 2004.
LNCS, vol. 3027, pp. 506–522. Springer, Heidelberg (2004). https://doi.org/10.
1007/978-3-540-24676-3_30
14. Zhang, X., Chunxiang, X., Wang, H., Zhang, Y., Wang, S.: FS-PEKS: lattice-based
forward secure public-key encryption with keyword search for cloud-assisted indus-
trial internet of things. IEEE Trans. Dependable Secure Comput. 18(3), 1019–1032
(2021). https://doi.org/10.1109/TDSC.2019.2914117
15. Liu, Q., et al.: Authorized keyword search on mobile devices in secure data out-
sourcing. IEEE Trans. Mob. Comput. 23(5), 4181–4195 (2024). https://doi.org/
10.1109/TMC.2023.3288160
16. Liu, X., Yang, G., Susilo, W., Tonien, J., Liu, X., Shen, J.: Privacy-preserving
multi-keyword searchable encryption for distributed systems. IEEE Trans. Parallel
Distrib. Syst. 32(3), 561–574 (2021)
17. Zhang, X., Huang, C., Dawu, G., Zhang, J., Wang, H.: BIB-MKS: post-quantum
secure biometric identity-based multi-keyword search over encrypted data in cloud
storage systems. IEEE Trans. Serv. Comput. 16(1), 122–133 (2023). https://doi.
org/10.1109/TSC.2021.3112779
18. Wang, X., Zhang, X., Zhang, X., Miao, Y., Xue, J.: Enabling anonymous autho-
rized auditing over keyword-based searchable ciphertexts in cloud storage systems.
IEEE Trans. Serv. Comput. 16(6), 4220–4232 (2023). https://doi.org/10.1109/
TSC.2023.3315972
19. Wang, M., Jia, Yu., Shen, W., Hao, R.: Privacy-preserving time-based auditing
for secure cloud storage. IEEE Trans. Inf. Forensics Secur. 19, 7866–7878 (2024).
https://doi.org/10.1109/TIFS.2024.3449095
A Flexible Keyword-Based PIR Scheme
with Customizable Data Scales
for Multi-server Learning
Jingang Li, Zhaoyi Liu, Huijie Yang(B) , Tianqi Zhou, and Wenying Zheng
1 Introduction
Since the concept of PIR was proposed by Chor et al. [7], there has been a growing focus on enhancing the efficiency and privacy of PIR systems for user queries.
Depending on the data storage method in the server’s database, PIR can be
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 133–147, 2025.
https://doi.org/10.1007/978-981-96-4566-4_10
divided into two main types: index-based PIR and keyword-based PIR. Index-
based PIR, the traditional technique, involves a user requesting specific entries
by their position to construct a query vector. Notable methods in index-based
PIR include those proposed by Melchor et al. [11], Angel et al. [3], and Ali et
al. [2].
However, as databases increasingly adopt key-value pair storage, recent
research has shifted towards keyword-based PIR protocols. In these protocols,
each database entry is associated with one or more keywords, allowing users to
retrieve information by specifying these keywords. Recent advancements in key-
word PIR have addressed several critical areas: reducing communication over-
head, supporting batch queries, optimizing performance for sparse databases,
and enabling multi-keyword searches. Mahdavi et al. [10] introduced constant-
weight equality operators that allowed for efficient single-round, single-server
keyword retrieval. Davidson et al. [5] proposed the ChalametPIR protocol which
used Binary Fuse filters to minimize bandwidth usage during keyword searches.
Patel et al. [13] presented SparsePIR to improve encoding time, offering better
cost-effectiveness for sparse databases. Meanwhile, Pandya [12] developed
Bloom Swizzler, which enhanced query efficiency through efficient indirection layers
in sparse datasets. In terms of batch queries, Liu et al. [9] presented the PIRANA
system to support frequent updates while integrating tagged private set intersection
techniques. Tovey et al. [15] presented the DPIR protocol to facilitate task distribution
and ensure metadata privacy in distributed environments. Yang et al. [19]
proposed FGF-mkPIR, a scheme that employed batch oblivious pseudo-random
functions for flexible multi-keyword queries, ensuring data integrity and correct-
ness in AIoT settings. Tong et al. [14] implemented verifiable multi-keyword
queries on encrypted data using advanced hashing techniques like Bloom fil-
ters and Cuckoo hashing. The practical applications of keyword PIR protocols
have also expanded into various fields. Xie et al. [18] designed an optimized
two-step keyword PIR protocol that enabled lightweight Bitcoin users to effi-
ciently retrieve transactions from full nodes without revealing query details.
Vadapalli et al. [16] created PIRSONA, a digital content delivery system that
protected user consumption patterns while providing personalized recommenda-
tions. Ahmad et al. [1] introduced Coeus, a document retrieval system that com-
bined secure matrix-vector multiplication with keyword PIR for oblivious rank-
ing and retrieval. Vora and Hegde [17] developed a sublinear-time-complexity
scheme for keyword-based private searching on cloud data, including mecha-
nisms for keyword association and dissociation.
Furthermore, privacy protection and data processing in distributed environ-
ments have received increasing attention. Zhang et al. [21] explored the iden-
tification of unlearning risks in edge computing environments, which held sig-
nificant implications for our research, especially in defining the boundaries of
training data and protecting user retrieval privacy. Chen et al. [6] highlighted the
importance of high-frequency information in watermark attacks and defenses for
image-processing models, providing valuable insights for our study, particularly
in preventing model stealing attacks. Zhang et al. [20] proposed the PPFed framework
for privacy-preserving and personalized federated learning.
2 Preliminary
2.1 Brakerski-Fan-Vercauteren Homomorphic Encryption
Brakerski-Fan-Vercauteren Homomorphic Encryption (BFV) [4, 8], which is a
fully homomorphic encryption scheme based on the Ring Learning With Errors
(RLWE) problem. This allows the server to perform certain types of computa-
tions on the encrypted data without decrypting it, which is particularly useful
for protecting sensitive data.
BFV operates on the ring R = Z[x]/(x^n + 1), where n is the degree of the
polynomial and x^n + 1 is an irreducible polynomial. A modulus q is chosen such
that q > 2n^2 and q is an odd number.
During the generation of the public-private key pair, the private key s ∈ R
is selected first, followed by the construction of the public key pk = (a, b), where
a is a randomly chosen polynomial and b = as + e mod q, with e being a small
noise term.
The plaintext m is encoded as an element in the ring Rt = R/tR =
Zt[x]/(x^n + 1), where t is a small prime number. Subsequently, the ciphertext
c = (c0, c1) is computed, where c0 = u · b + Δ · m + e′ and c1 = u · a + e″, with
Δ = ⌊q/t⌋ and e′, e″ being small noise polynomials.
Finally, decryption involves computing m = ⌊(t/q) · (c0 + c1 · s mod q)⌉ mod t
to recover the original plaintext polynomial m.
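To make these operations concrete, the following toy sketch implements textbook BFV key generation, encryption, and decryption over R = Z[x]/(x^n + 1). The parameters n = 16, q = 2^20, t = 256 and the coefficient-wise {−1, 0, 1} noise sampling are illustrative and insecure choices of ours, not the paper's:

```python
import random

random.seed(0)
n, q, t = 16, 1 << 20, 256   # toy, insecure parameters
delta = q // t               # scaling factor Δ = ⌊q/t⌋

def polymul(a, b):
    # negacyclic convolution: multiplication mod (x^n + 1), coefficients mod q
    res = [0] * n
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            k = i + j
            if k < n:
                res[k] = (res[k] + ai * bj) % q
            else:
                res[k - n] = (res[k - n] - ai * bj) % q  # x^n = -1
    return res

def padd(a, b):
    return [(x + y) % q for x, y in zip(a, b)]

def small():
    # small noise/secret polynomial with coefficients in {-1, 0, 1}
    return [random.choice([-1, 0, 1]) for _ in range(n)]

def keygen():
    s = small()                                   # secret key
    a = [random.randrange(q) for _ in range(n)]   # uniform polynomial
    e = small()
    b = [(-x - y) % q for x, y in zip(polymul(a, s), e)]  # b = -(a·s + e)
    return s, (b, a)

def encrypt(pk, m):
    b, a = pk
    u, e1, e2 = small(), small(), small()
    c0 = padd(padd(polymul(b, u), e1), [(delta * mi) % q for mi in m])
    c1 = padd(polymul(a, u), e2)
    return c0, c1

def decrypt(sk, ct):
    c0, c1 = ct
    raw = padd(c0, polymul(c1, sk))         # ≈ Δ·m + small noise (mod q)
    out = []
    for x in raw:
        x = x - q if x > q // 2 else x      # centre into (-q/2, q/2]
        out.append(round(t * x / q) % t)    # rescale by t/q and round
    return out

sk, pk = keygen()
msg = [random.randrange(t) for _ in range(n)]
assert decrypt(sk, encrypt(pk, msg)) == msg
```

Because the noise stays far below Δ/2 at these sizes, the rounding step always recovers the plaintext exactly.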
In the server, there are a total of n entries, each stored at a unique coordinate
within a hypercube denoted as D. The hypercube possesses d dimensions, with
the dimensions having lengths l1, l2, . . . , ld, respectively, such that
l1 · l2 · · · ld ≥ n. When a user wishes to query specific data, it constructs corresponding
encrypted query vectors based on the coordinates in each dimension and can then
perform the query recursively.
For example, in a two-dimensional database, Fig. 1 illustrates the process of
hypercube recursion. The user separately constructs horizontal encrypted query
vector vrow and vertical encrypted query vector vcol , with each coordinate of
hypercube D storing unique data. In the first round of recursion, the intermediate
result D2 is obtained by the dot product of vcol and D. In the second round of
recursion, the final result is obtained by the dot product of vrow and D2 , which
is then returned to the user.
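The two-round recursion can be illustrated with a small plaintext sketch (the function name and toy database are ours; in the actual protocol v_row and v_col are BFV-encrypted, so the server evaluates these dot products homomorphically without learning the indices):

```python
def hypercube_query_2d(D, row, col):
    """Plaintext analogue of the two-round hypercube recursion of Fig. 1."""
    l1, l2 = len(D), len(D[0])
    v_col = [1 if j == col else 0 for j in range(l2)]  # one-hot column selector
    v_row = [1 if i == row else 0 for i in range(l1)]  # one-hot row selector
    # round 1: dot product of v_col with each row of D selects one column -> D2
    D2 = [sum(D[i][j] * v_col[j] for j in range(l2)) for i in range(l1)]
    # round 2: dot product of v_row with D2 selects the single requested entry
    return sum(v_row[i] * D2[i] for i in range(l1))

D = [[10, 11], [12, 13]]
assert hypercube_query_2d(D, 1, 0) == 12
```

Note that the client only transmits l1 + l2 selector entries here, instead of one entry per database cell.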
The advantage of using a hypercube lies in the fact that only an encrypted
query vector of length l1 + l2 + · · · + ld needs to be transmitted, whereas without this
structure, an encrypted query vector of length n would have to be sent.
Client’s Query Privacy. Client query privacy is defined with respect to a bit
b ∈ {0, 1} and an adversary A. For any database D and for all stateful, efficient
adversaries A, the probability that the adversary successfully distinguishes the
bit b in the experiment is bounded in Eq. (1).
Pr[Exp^client_{b,A}(λ, D) = b] ≤ 1/2 + negl(λ).    (1)
Algorithm 1. HashPIR.Init
1: Input: Security parameter 1λ, number of entries ni, dimension d, length l such
that l^d ≥ ni.
2: Output: Homomorphic encryption public key pk, symmetric encryption key keys,
re-encryption key keyPRE, hypercubes D1, D2, . . . , Dm.
3: Key Generation:
4: Generate homomorphic encryption keys (pk, sk) using security parameter 1λ .
5: Generate symmetric encryption key keys .
6: Send pk and keys to the proxy server.
7: The proxy server generates re-encryption key keyPRE using pk.
8: Offline Setup in Each Server Si :
9: Each server Si transforms database data into plaintext polynomial form.
10: Initialize an empty hypercube storage D1 , D2 , . . . , Dm ← ∅.
11: for each entry (k, kvalue ) in Si do
12: Compute hash value h ← H(k).
13: Initialize d vectors {v1 , v2 , . . . , vd } each of length l with all zeros.
14: for each dimension dim ∈ {1, 2, . . . , d} do
15: Extract bits (dim − 1) ∗ 8 to (dim ∗ 8) − 1 from h as hdim .
16: Calculate index idxdim ← hdim mod l.
17: Set vdim [idxdim ] ← 1.
18: end for
19: Store value kvalue in hypercube Di at position (v1 , v2 , . . . , vd ).
20: end for
21: return pk, keys, keyPRE, D1, D2, . . . , Dm
dimension. Compute the index idxdim = hdim mod l using the extracted 8-bit
segment hdim , and set the idxdim -th position of vdim to 1. The position of 1 in
the vector represents the location of the data in that dimension.
In this way, positions are sequentially activated across each dimension, and
the value kvalue from the entry (k, kvalue) is stored at position (v1, v2, . . . , vd) in the
hypercube D. For example, if d = 2, v1 = [0 1], and v2 = [0 1], it indicates
that the value kvalue should be stored at the second position of the first dimension
and the second position of the second dimension. All values in each server are
independently mapped into a hypercube D based on their corresponding keys.
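The coordinate derivation used in Algorithm 1 can be sketched as follows, with SHA-256 standing in for the unspecified hash function H (function and variable names are ours):

```python
import hashlib

def hypercube_position(key: str, d: int, l: int):
    """Map a keyword to one index per dimension, as in Algorithm 1's loop."""
    h = hashlib.sha256(key.encode()).digest()  # H(k) as a byte string
    coords = []
    for dim in range(1, d + 1):
        h_dim = h[dim - 1]          # bits (dim-1)*8 .. (dim*8)-1 of the hash
        coords.append(h_dim % l)    # idx_dim = h_dim mod l
    return coords

pos = hypercube_position("alice", d=3, l=64)
assert len(pos) == 3 and all(0 <= c < 64 for c in pos)
```

The returned indices mark, per dimension, which position of the one-hot vector v_dim is set to 1.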
The user uses the Query function of SealPIR to generate the query vector q
from the d-dimensional vector (v1 , v2 , . . . , vd ), and then encrypts it using BFV.
Since the user can specify data from a total of m servers within this framework
to participate in training, they construct (q1 , q2 , . . . , qm ) and send them to the
proxy server.
In each step, the model processes one batch of the training set through three
stages: forward propagation, backpropagation, and model parameter updates.
Forward propagation is used to calculate the cross-entropy loss function, as
shown in Eq. (2).
L = − Σ_{j=1}^{C} yj log(ŷj)    (2)
where L represents the cross-entropy loss function, yj is the probability that the
true label is class j (if the sample belongs to class j, then yj = 1; otherwise,
yj = 0), and ŷj is the predicted probability by the model that the label is class
j.
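As a minimal numeric check of Eq. (2) for a one-hot true label (function name is ours):

```python
import math

def cross_entropy(y, y_hat):
    """L = -sum_j y_j * log(yhat_j) over the C classes, as in Eq. (2)."""
    # terms with y_j = 0 contribute nothing, so they are skipped
    return -sum(yj * math.log(pj) for yj, pj in zip(y, y_hat) if yj > 0)

# one-hot true label (class 2 of 3) with predicted class probabilities
loss = cross_entropy([0, 1, 0], [0.2, 0.7, 0.1])
assert abs(loss - (-math.log(0.7))) < 1e-12
```

For a one-hot label the loss reduces to the negative log-probability the model assigns to the true class.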
To ensure security during transmission, the trained model ω is encrypted
using a symmetric encryption key before being sent to the user. The user decrypts
the model using the same symmetric key. This method allows users without
computational resources to complete deep learning model training via the proxy
server.
The dataset used in this experiment is the classic Fashion-MNIST dataset, which
includes images of 10 categories such as t-shirt and trouser. Each image sample
consists of 784 pixels, with each pixel occupying 1 byte. Additionally, the sample
labels are quantized and occupy 1 byte, making a total of 785 bytes per sam-
ple. The training dataset contains 6000 samples per category, while the testing
dataset contains 1000 samples per category, making a total of 10 categories, with
samples from each category as shown in Fig. 4. For this experiment, the total
number of servers is set to N = 6, so the training set is divided into 6 equal
parts, each stored on one server.
Assuming that the keywords requested by the user are certainly present in the
servers, the user randomly selects 7000 entries from each server for training
purposes, aggregating a total of 42000 entries to be gathered at the proxy server
for training. Afterward, the trained model is tested on the testing dataset.
Training on the images retrieved by PIR at the proxy server yields the results shown
in Table 1 and Fig. 3. The loss functions for ResNet18, ResNet34, and ResNet50 all
progressively decreased to convergence on the validation set, achieving good detection
accuracy. Among these, ResNet50 exhibited the highest detection accuracy.
The experimental results in Table 2 provide the performance metrics of the Hash-
PIR system across different dimensions d and data sizes n, including storage
ratio, user CPU costs, server CPU costs, and network costs. Query times are
consistently low across all dimensions, typically around 2–8 ms, indicating effi-
cient query processing. Extraction times increase slightly with dimension but
remain manageable, reaching up to 19 ms for d = 4, which is still within acceptable
limits for most users. Setup times are stable and relatively low, ranging from
32.622 to 51.317 s. Answer times increase with higher dimensions but remain
feasible, with the highest being 53.715 s for d = 4 and n = 2^16. This indicates
that the system can handle increased computational demands effectively. Net-
work costs for queries are stable and low, generally around 90–181 KB. Answer
network costs increase with higher dimensions but are still acceptable.
6 Conclusion
References
1. Ahmad, I., Sarker, L., Agrawal, D., El Abbadi, A., Gupta, T.: Coeus: a system for
oblivious document ranking and retrieval. In: Proceedings of the ACM SIGOPS
28th Symposium on Operating Systems Principles, pp. 672–690 (2021)
2. Ali, A., et al.: Communication-computation trade-offs in PIR. In: 30th
USENIX Security Symposium (USENIX Security 2021), pp. 1811–1828 (2021)
3. Angel, S., Chen, H., Laine, K., Setty, S.: PIR with compressed queries and amor-
tized query processing. In: 2018 IEEE Symposium on Security and Privacy (SP),
pp. 962–979. IEEE (2018)
4. Brakerski, Z.: Fully homomorphic encryption without modulus switching from clas-
sical GapSVP. In: Annual Cryptology Conference, pp. 868–886. Springer (2012)
5. Celi, S., Davidson, A.: Call me by my name: simple, practical private informa-
tion retrieval for keyword queries. In: Proceedings of the 2024 on ACM SIGSAC
Conference on Computer and Communications Security, pp. 4107–4121 (2024)
6. Chen, H., Zhu, T., Liu, C., Yu, S., Zhou, W.: High-frequency matters: attack and
defense for image-processing model watermarking. IEEE Trans. Serv. Comput. 17,
1565–1579 (2024)
7. Chor, B., Kushilevitz, E., Goldreich, O., Sudan, M.: Private information retrieval.
J. ACM (JACM) 45(6), 965–981 (1998)
8. Fan, J., Vercauteren, F.: Somewhat practical fully homomorphic encryption. Cryp-
tology ePrint Archive (2012)
9. Liu, J., Li, J., Wu, D., Ren, K.: PIRANA: faster multi-query PIR via constant-
weight codes. In: 2024 IEEE Symposium on Security and Privacy (SP), pp. 4315–
4330. IEEE (2024)
10. Mahdavi, R.A., Kerschbaum, F.: Constant-weight PIR: single-round keyword
PIR via constant-weight equality operators. In: 31st USENIX Security Sympo-
sium (USENIX Security 2022), pp. 1723–1740 (2022)
11. Melchor, C.A., Barrier, J., Fousse, L., Killijian, M.O.: XPIR: private information
retrieval for everyone. Proc. Priv. Enhanc. Technol. 2, 155–174 (2016)
12. Pandya, A.M.: Bloom swizzlers for efficient keyword-based private information
retrieval. Master’s thesis, University of Calgary, Calgary, Canada (2024). https://
prism.ucalgary.ca
13. Patel, S., Seo, J.Y., Yeo, K.: Don’t be dense: efficient keyword PIR for sparse
databases. In: 32nd USENIX Security Symposium (USENIX Security 2023), pp.
3853–3870 (2023)
14. Tong, Q., Li, X., Miao, Y., Wang, Y., Liu, X., Deng, R.H.: Beyond result verifi-
cation: efficient privacy-preserving spatial keyword query with suppressed leakage.
IEEE Trans. Inf. Forensics Secur. 19, 2746–2760 (2024)
15. Tovey, E., Weiss, J., Gilad, Y.: Distributed PIR: scaling private messaging via
the users’ machines. In: Proceedings of the 2024 on ACM SIGSAC Conference on
Computer and Communications Security, pp. 1967–1981 (2024)
16. Vadapalli, A., Bayatbabolghani, F., Henry, R.: You may also like... privacy: rec-
ommendation systems meet PIR. Proc. Priv. Enhanc. Technol. 2021(4), 30–53
(2021)
17. Vora, A.V., Hegde, S.: Keyword-based private searching on cloud data along with
keyword association and dissociation using cuckoo filter. Int. J. Inf. Secur. 18,
305–319 (2019)
18. Xie, Y., Zhang, C., Wei, L., Niu, Y., Wang, F.: Private transaction retrieval for
lightweight bitcoin client. In: 2019 IEEE International Conference on Blockchain
and Cryptocurrency (ICBC), pp. 440–446. IEEE (2019)
19. Yang, H., et al.: A flexible and verifiable keyword PIR scheme for cloud-edge-
terminal collaboration in AIoT. IEEE Internet Things J. 11, 18111–18122 (2024)
20. Zhang, G., Liu, B., Zhu, T., Ding, M., Zhou, W.: PPFed: a privacy-preserving and
personalized federated learning framework. IEEE Internet Things J. 11, 19380–
19393 (2024)
21. Zhang, L., Zhu, T., Xiong, P., Zhou, W.: The price of unlearning: identifying
unlearning risk in edge computing. ACM Trans. Multimed. Comput. Commun.
Appl. 1–23 (2024)
Automatic Software Vulnerability
Detection in Binary Code
Shigang Liu1,2(B) , Lin Li2 , Xinbo Ban2 , Chao Chen3 , Jun Zhang2 ,
Seyit Camtepe1 , and Yang Xiang2
1 CSIRO’s Data61, Clayton, Australia
{shigang.liu,seyit.camtepe}@data61.csiro.au
2 Swinburne University of Technology, Melbourne, Australia
{shigangliu,linli,junzhang,yxiang}@swin.edu.au
3 Royal Melbourne Institute of Technology, Melbourne, Australia
[email protected]
1 Introduction
Nowadays, computer software components have become a crucial part of the
modern world, from simple applications on smart devices to complex enterprise
software and critical embedded systems. According to James et al. [1], approx-
imately 80% of cyber-attacks are motivated by financial gain, leading to losses
across numerous industries, including autonomous technologies, the Internet of
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 148–166, 2025.
https://doi.org/10.1007/978-981-96-4566-4_11
Things (IoT), artificial intelligence (AI), and more. For example, vulnerabili-
ties in widely-used browser plugins like Adobe Flash Player and Oracle Java
[2], as well as open-source components such as Heartbleed (Codenomicon 2014)
[3], and ShellShock (Symantec Security Response 2014) [4], have posed signifi-
cant security threats to thousands of companies and their customers worldwide.
The number and severity of software vulnerabilities continue to rise, making it
essential to identify and address them before deploying any software [5].
Binary code analysis is a common tool in computer security, especially when
the source code is unavailable. For example, vendors may want to verify the com-
pliance of third-party software, such as drivers and plugins. In the IoT domain,
most devices are built using public drivers, and many of these applications and
libraries lack source code. This creates a significant challenge for assessing the
security of IoT devices when deploying them on a large scale [6]. Binary code
analysis is crucial in specialized fields, including detecting software vulnera-
bilities, identifying malware, and recognizing code clones at the binary level.
Research shows that interest in binary analysis remains strong, with no signs of
decline [7]. Binary analysis is widely applicable and appealing, not only because
source code is often inconvenient to use, but also because it helps avoid mis-
matches between source code and binary code [8]. As a result, there continues
to be strong demand for software vulnerability analysis in binary code.
Recently, machine learning, especially Deep Learning (DL), has been widely
used to solve software security problems [9]. However, previous works such as [10]
and [11] are mainly focused on detecting vulnerabilities in source code. These
methods first use word2vec to convert the source code into feature vectors, then
apply Deep Learning to learn the high-level features, which are used to train the
Deep Learning models. Although some works perform vulnerability detection on
binary code, such as VDiscover [12], Instruction2vec [13], Maximal Divergence
Sequential Auto-Encoder [14], A-BiLSTM [15], and Deep Cost-sensitive Kernel
Machine [16], the experimental results from this study show that these techniques
often lead to a high false negative rate (over 65%). Therefore, we are motivated to
develop an efficient automatic approach for vulnerability detection.
In this paper, we aim to use state-of-the-art techniques, specifically Code-
BERT [17] and attention-based Deep Learning [18], for vulnerability search in
binary code. At a high level, we first disassemble the binary software, then iden-
tify a set of vulnerable and non-vulnerable functions as ground truth. This
ground truth is fed into CodeBERT to generate effective embeddings, which
are then passed into attention-based neural networks to train the prediction
model. The model is used to detect vulnerabilities in real-world scenarios. We
conducted a series of comparative experiments to demonstrate the effectiveness
of our proposed BiVulD system, using real-world projects from various sources.
We will release our tool under an open-source license to enable other researchers
to evaluate our proposed strategy. Our main contributions are as follows:
2 Related Work
Recently, machine learning techniques have been applied to detect software vul-
nerabilities, typically involving two main tasks: feature selection and model train-
ing. The approach to feature selection varies from study to study. For example,
DiscovRE [8] suggested using the control flow graph (CFG) to calculate func-
tion similarity. VDiscover [12] proposed to use machine learning for software
vulnerability detection based on the lightweight static and dynamic features.
Meanwhile, Deep Learning has been widely applied in software vulnerability
detection. Lee et al. [13] employed Deep Learning for software bug detection in
assembly code. Chen et al. [19] proposed Atomic for vulnerability discovery in
cellular networks (e.g. LTE and 5G). Atomic first employed natural-language
processing technique to analyse LTE documents to recover a set of hazard indi-
cators, then it used Textual Entailment to predict whether there were potential
security flaws in cellular networks. Steenhoek et al. [20] assessed nine state-of-the-art
DL models on widely used vulnerability datasets, exploring various research
questions to improve model understanding and performance. They highlight that
vulnerability detection tends to be more effective when tailored to a specific
vulnerability type, outperforming models designed to cover all vulnerabilities. Yang
et al. [21] proposed a novel Deep Learning-enhancement architecture, implemented in
Asteria-Pro, which significantly improves bug detection performance, with 96.9%
faster calculation time and enhanced precision. Jiang et al. [22] argued that
while third-party libraries enhance productivity in software development, they
1 https://nvd.nist.gov/.
2 https://cve.mitre.org/.
3 Proposed BiVulD
From a high-level perspective, BiVulD leverages known vulnerabilities (vulnera-
ble functions) to build ground-truth data, which is then used to train the predic-
tion model using these binary vulnerabilities. Figure 1 provides an overview of
the proposed BiVulD system. BiVulD operates in four phases: database prepa-
ration, embeddings learning, prediction model training, and detection.
For database preparation, this work maps CVE and NVD entries to vulnerability-
contributing commits in six real-world projects: Asterisk, FFmpeg, LibPNG,
LibTIFF, Pidgin, and VLC. The datasets for CWE119 and CWE399 [10] are
also considered. These projects are included not only because they are widely
used in everyday life, but also because they are significant in the research com-
munity [11,30].
For embedding learning, we use the state-of-the-art CodeBERT to obtain
high-quality embeddings. These embeddings are then input into an attention-
based BiLSTM model to train the prediction model. Finally, the prediction
model will be used to detect potential vulnerabilities in any test cases. Dur-
ing the detection phase, given a test case, tools such as IDA PRO and objdump
(or other reverse engineering tools like Capstone) can be used to disassemble
binaries and obtain assembly-level instructions. Functions will then be extracted
without domain knowledge, as we have no vulnerability information about the
test case. Each function will be input into the model, which will predict whether
it is vulnerable or not.
Note that, some distinctive design principles are included in BiVulD, which
are:
almost all PNG features. VLC media player is an open source tool which supports
many audio and video file formats, including DVDs, Audio CDs, and so on.
For each of the projects, we manually map the vulnerabilities from CVE and
NVD, which means we can tell where each vulnerability comes from. Take
CVE-2018-12900 as an example: we can see this is a buffer overflow by checking
the vulnerability categorization based on the information from the NVD. The
analysis reveals that the issue is a heap-based buffer overflow in the
cpSeparateBufToContigBuf function, which is located in the tiffcp.c file. Therefore, we
treat the cpSeparateBufToContigBuf function as a buffer overflow vulnerability.
Moreover, the real-world projects include multiple types of vulnerabilities such
as buffer error (CWE119), resource management errors (CWE399), information
leak (CWE200), out-of-bound read (CWE125), input validation (CWE20), null
pointer dereference (CWE476), numeric errors (CWE189), out-of-bounds write
(CWE787), and insufficient information (NVD-CWE-noinfo). However, buffer
overflow takes the highest proportion of the vulnerabilities found. For example,
29 out of 71 vulnerabilities are buffer overflow in FFmpeg project (40.8%), and
9 out of 21 vulnerabilities are buffer overflow in LibTIFF project (42.9%). We
use 32-bit and 64-bit versions of GCC to compile the real-world projects under
Windows and Linux. Data information can be seen from Table 1.
To examine the effectiveness of our proposed scheme, the CWE119 and
CWE399 datasets are considered in this work: CWE119 only contains buffer
overflow errors, and CWE399 only contains resource management errors. The
source code of the CWE119 and CWE399 datasets was used in [10]. In the
original CWE119 dataset, there are 10,440 vulnerable functions and 29,313 non-
vulnerable functions. We compile them with MINGW32 and MINGW64 on
Windows and Linux systems. However, not every function is compilable because
some libraries are missing. As a result, we were only able to successfully com-
pile 20,590 samples, which included 10,712 vulnerable samples and 9,878 non-
vulnerable samples of CWE119. In the case of CWE399, we could only compile
1,530 vulnerable samples along with 2,164 non-vulnerable samples.
Since the BiVulD system is designed to detect binary vulnerabilities at the func-
tion level, we need to extract functions from assembly language instructions to
achieve this goal. We start by compiling the source code into binary code. After
compilation, we obtain 32-bit and 64-bit executable files for both Windows and
Linux systems (i.e., PE and ELF files) from six real-world projects, including
the CWE119 and CWE399 datasets. Next, we use objdump to display the binary
instructions in assembly language form. Objdump is a command-line program
used to extract information from object files on Unix-like systems. Finally, we
extract functions (i.e., sequences of assembly language instructions) from the
disassembled output.
For real-world projects, we compile each project under Windows and Linux
using 32-bit and 64-bit versions of GCC. The processed data samples are then
combined to form a new dataset. For example, LibTIFF contains 98 vulnerable
samples and 678 non-vulnerable samples from all the samples compiled under
Windows and Linux with 32-bit and 64-bit versions of the GCC compiler. It is
worth noting that we only present three real-world projects in the experimental
results section, as the results for the other projects are similar to those for
Asterisk, FFmpeg, and LibTIFF.
Fig. 2. Figures (a) and (c) display the two-dimensional features of LibTIFF and
CWE119, and Figures (b) and (d) visualise the features of LibTIFF and CWE119
using t-Distributed Stochastic Neighbor Embedding (t-SNE). Note: label ‘0’ means the
non-vulnerable class; label ‘1’ means the vulnerable class.
At a high level, BiVulD takes as input the token sequences converted from
assembly functions. On the one hand, assembly instructions often differ despite
having similar semantics. On the other hand, similar data types with dissimilar
byte tokens should have similar representations, such as mov DWORD PTR [ebp-
0xc],0x0, and mov DWORD PTR [ebp-0x10],0x0. This phenomenon is similar to
constructing word representations in natural language processing. In light of this,
we represent a token embedding as a continuous vector e ∈ R^128. As a result,
we have |V| token embeddings, where V denotes the vocabulary.
The semantics of a function do not rely solely on the meaning of individual
tokens, but also on their context. Thus, we feed the embedding sequences into
a BiLSTM to generate more abstract representations by incorporating assembly
language instruction information. An LSTM is a recurrent neural network taking
the following form:
hi, ci = LSTM(xi, hi−1, ci−1)    (1)
where xi is the input at position i, and hi ∈ R^64 and ci ∈ R^64 are the hidden
state and cell state of the LSTM at position i, respectively. The BiLSTM runs the
LSTM in both the forward and the backward direction; we concatenate the hidden
states generated in the two directions, →hi and ←hi, as the output ĥi at position i.
To improve vulnerability detection, it is important to recognize that not all
tokens have the same contribution to the prediction. Therefore, we assess the
importance of each token feature at every position in order to identify the signifi-
cant tokens for vulnerability detection. Since the same token feature takes differ-
ent values at different token positions, they receive different importance scores. In
this way, each token at each position will get a vector of importance scores for its
features. We assume that important tokens should be represented by important
features. Thus we take the average of the feature importance scores for each
token as its final importance score. Since those scores are position and token
dependent, we refer to them as attentive position embeddings. The attentive
position embeddings are concatenated to the hidden representations generated
by the BiLSTM as the final representation of each token. All token representa-
tions are fed into a binary classifier, which is a feedforward neural network with
two hidden layers, to determine if the input function is vulnerable or not.
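The shapes involved in this final representation can be sketched as follows; the random vectors merely stand in for the learned BiLSTM states and attention scores, and all names are ours:

```python
import random

random.seed(0)
T, d_h = 5, 64   # sequence length and per-direction hidden size

def rand_vec(d):
    return [random.gauss(0, 1) for _ in range(d)]

# stand-ins for the hidden states of Eq. (1), run in both directions
h_fwd = [rand_vec(d_h) for _ in range(T)]
h_bwd = [rand_vec(d_h) for _ in range(T)]
h_hat = [f + b for f, b in zip(h_fwd, h_bwd)]   # ĥ_i = [→h_i ; ←h_i], dim 128

# per-position feature-importance scores (learned in BiVulD; random here),
# averaged per token to give its attentive position embedding (a scalar)
scores = [rand_vec(128) for _ in range(T)]
attn = [sum(s) / len(s) for s in scores]

# final token representation: ĥ_i concatenated with its importance score
final = [h + [a] for h, a in zip(h_hat, attn)]
assert len(final) == T and len(final[0]) == 2 * d_h + 1
```

Each position thus contributes a 129-dimensional vector to the binary classifier.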
4 Experiments
4.1 Baselines
As far as we know, there are only a few research studies using Deep Learning for
vulnerability detection based on binary code. The closest works to ours are the
4.4 Results
Table 3. Top-k precision for real-world data. # non-vul/vul: the number of non-
vulnerable and vulnerable functions in the test set.
that BiVulD’s attention mechanism can focus more on the instructions that
cause vulnerabilities. Since vulnerabilities are often related to specific parts of
the code rather than the entire function, the attention mechanism in the decoder
process can allocate varying amounts of attention to different parts of the code.
RQ2: Is BiVulD practical and effective in detecting real software vul-
nerabilities? We use real-world projects (Asterisk and LibTIFF) to demon-
strate that BiVulD is practical for addressing real-world software security prob-
lems.
The experimental results of BiVulD on LibTIFF, Asterisk, CWE119-
LibTIFF, and CWE119-Asterisk are presented in Table 3. The test set for
LibTIFF consists of 55 functions, 8 of which are vulnerable. The test set for
Asterisk consists of 1543 functions, including 8 vulnerable ones. For CWE119-
LibTIFF and CWE119-Asterisk, the training dataset is CWE119, while the
test sets are LibTIFF and Asterisk, respectively. For LibTIFF, Asterisk, and
CWE119-LibTIFF, we report the top-10 to top-50 precision, while for CWE119-
Asterisk we report the top-100 to top-500 precision due to the larger number of
non-vulnerable functions in the test cases.
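The top-k precision reported in Table 3 can be computed as in this small sketch (function and variable names are ours):

```python
def top_k_precision(scores, labels, k):
    """Fraction of truly vulnerable functions among the k highest-scored ones."""
    ranked = sorted(zip(scores, labels), key=lambda p: -p[0])[:k]
    return sum(lbl for _, lbl in ranked) / k

# four functions scored by a hypothetical model; label 1 = vulnerable
assert top_k_precision([0.9, 0.8, 0.1, 0.7], [1, 0, 0, 1], k=2) == 0.5
```

A top-k precision of 40% at k = 50 thus means 20 vulnerable functions among the 50 examined.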
First, from Table 3, we can see that top-k precision for CWE119-LibTIFF
exceeds 40%, with 8, 15, 18, 20, and 20 vulnerable functions identified, respec-
tively. This means that 50 functions need to be examined to find 20 vulnera-
ble functions. Secondly, with our proposed BiVulD, 500 functions need to be
checked to find 18 vulnerable functions in the CWE119-Asterisk scenario due
to false positives. We hypothesise that the high false positives in BiVulD are
due to two factors: 1) the training data (CWE119) and test data (Asterisk) are
from different sources, meaning the data distributions may differ. In this case,
BiVulD may not be able to capture the underlying distribution of the test cases;
2) the ratios of non-vulnerable to vulnerable functions in CWE119 (nearly 1:1)
and Asterisk (nearly 244:1) are quite different, which may weaken the prediction
model trained on CWE119 when targeting Asterisk.
From Table 3, we can see that 6 and 7 vulnerable functions can be found
by examining the top 10 functions over the LibTIFF and Asterisk datasets,
respectively. When examining the top 20 functions, all the vulnerable functions
can be found in the Asterisk dataset, and 7 out of 8 vulnerable functions can
be identified in the LibTIFF dataset. Moreover, we can see that the accuracies
on LibTIFF and Asterisk are 97.7% and 99.7%, which are higher than those for
CWE119-LibTIFF (96.5%) and CWE119-Asterisk (99.4%). This indicates that
the prediction model built on the same dataset (or a dataset from the same
source) results in higher accuracy and can detect more vulnerable functions in
the top-k functions.
RQ3: How effective is BiVulD compared to other machine learning
tools used for vulnerable function search? To answer this question, we
use two types of datasets, including CWE119 and CWE399, as well as two real-
world software applications: Asterisk and LibTIFF. We mainly focus on detecting
vulnerable functions to demonstrate that BiVulD outperforms other machine
learning tools that can be used for software vulnerability search.
Table 4 shows the experimental results of BiVulD and SAFE based on
CWE119, CWE399, LibTIFF, and Asterisk. Because there is a class imbalance
in the LibTIFF and Asterisk datasets, we applied random oversampling (ROS)
to SAFE. The oversampling rate is 2. SAFE + ROS means we first use ROS
on the training data to balance the number of vulnerable and non-vulnerable
samples, and then we train the SAFE model on this balanced data. We used
the following parameters for SAFE: batch size 64, embedding size 256, and word
frequency 3.
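Random oversampling of the kind applied to SAFE can be sketched as follows (a plain-Python illustration; the paper does not specify the ROS implementation, so the duplication and shuffling details here are assumptions):

```python
import random

def random_oversample(samples, labels, rate=2, minority=1, seed=0):
    """Duplicate minority-class samples so that an oversampling rate of 2
    makes each minority sample appear twice in the training data."""
    rng = random.Random(seed)
    out_s, out_l = list(samples), list(labels)
    minority_samples = [s for s, l in zip(samples, labels) if l == minority]
    for _ in range(rate - 1):
        for s in minority_samples:
            out_s.append(s)
            out_l.append(minority)
    # Shuffle sample/label pairs so duplicates are not adjacent during training.
    pairs = list(zip(out_s, out_l))
    rng.shuffle(pairs)
    return [p[0] for p in pairs], [p[1] for p in pairs]

X = ["f1", "f2", "f3", "f4"]
y = [1, 0, 0, 0]          # one vulnerable, three non-vulnerable functions
Xb, yb = random_oversample(X, y, rate=2)
print(sum(yb), len(yb))   # 2 vulnerable samples out of 5
```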
From Table 4, we see that SAFE performs worst on the LibTIFF and Aster-
isk datasets. In particular, it cannot identify any vulnerabilities in the Aster-
isk dataset. However, SAFE + ROS shows some improvement. For example, in
LibTIFF, the Precision improved from 0.278 to 0.417, the FPR dropped from
0.200 to 0.108, and the F1-score increased from 0.385 to 0.500. In comparison,
BiVulD shows much better performance in terms of Precision, TPR, FPR, and
F1-score. For instance, BiVulD achieves the best F1-scores of 0.915 and 0.943
on the CWE119 and LibTIFF datasets, which are 22.2% and 55.8% higher than
SAFE.
For the CWE119 and CWE399 datasets, both SAFE and BiVulD performed
significantly better than on the Asterisk and LibTIFF datasets. However, BiVulD
outperforms SAFE in all performance measures. For example, in CWE119,
BiVulD has a Precision of 0.972 and a TPR of 0.864, which are about 33%
and 11% higher than SAFE, respectively. BiVulD also has an FPR of 0.028,
which is 40% lower than SAFE. As a result, BiVulD has a higher F1-score of
0.915, which is about 20% higher than SAFE.
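As a quick arithmetic check, the reported F1-score for CWE119 follows from the reported Precision and TPR, since F1 is their harmonic mean:

```python
# Internal consistency of the reported CWE119 numbers:
# F1 = 2 * Precision * TPR / (Precision + TPR).
precision, tpr = 0.972, 0.864
f1 = 2 * precision * tpr / (precision + tpr)
print(round(f1, 3))  # 0.915, matching the reported F1-score
```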
Overall, we suspect BiVulD outperforms SAFE because SAFE is specifically
designed for function similarity matching, whereas this paper focuses on
binary-level vulnerability detection. Consequently, not every vulnerability
sample has a fixed (patched) version in the training data, which may hinder
function similarity matching for SAFE.
Automatic Software Vulnerability Detection in Binary Code 163
5 Limitations
The BiVulD system has some limitations that we plan to address in the future.
One of these limitations is that it focuses on detecting vulnerabilities in
binary code on x86 architecture. However, with an increasing number of IoT
vendors compiling and deploying third-party code bases across different archi-
tectures such as MIPS, ARM, and PowerPC, it is more crucial than ever to
search for known vulnerabilities in binary code across different architectures.
Therefore, we believe it would be interesting to further develop BiVulD with
transfer learning that can be adapted to detect vulnerabilities across different
CPU architectures.
BiVulD is currently limited to employing an attention model. We plan to
employ other interpretation methods [37, 38] as well as hierarchical neural
network or Convolutional Neural Network (CNN)-based approaches, where the
binary code is rendered as ‘images’ for input to the CNN.
The scenario of function inlining may pose a challenge for BiVulD, as it can
alter the structure of the binary code. To address this issue, we will investigate
the effect of function inlining on BiVulD’s performance in future work.
We are planning to retrain a customized embedding model for binary-level
software vulnerability detection. Although this study employed codeBERT to
obtain good embeddings, codeBERT is trained on source code rather than binary
code. The experimental results in this study showed that codeBERT splits the
assembly code into subword tokens when learning feature representations. For
instance, consider the EBP register (i.e., the extended base pointer), which is
especially relevant for stack buffer overflows on x86/64. CodeBERT splits the
EBP token into the two tokens ‘eb’ and ‘p’. Although BiVulD achieved high
classification performance, splitting a register name into two tokens loses its
semantics. We aim to develop a customized
embedding model that maintains the semantics of the code and improves classi-
fication performance in future work.
6 Conclusions
References
1. Lerums, J.E., La’Reshia, D.P., Dietz, J.E.: Simulation modeling cyber threats,
risks, and prevention costs. In: 2018 IEEE International Conference on Elec-
tro/Information Technology (EIT), pp. 0096–0101. IEEE (2018)
2. Bassi, D., Singh, H.: A systematic literature review on software vulnerability pre-
diction models. IEEE Access (2023)
3. Amodei, A., Capriglione, D., Cerro, G., Ferrigno, L., Miele, G., Tomasso, G.: A
measurement approach for inline intrusion detection of heartbleed-like attacks in
IoT frameworks. IEEE Trans. Instrum. Meas. (2023)
4. Hammi, B., Zeadally, S., Nebhen, J.: Security threats, countermeasures, and chal-
lenges of digital supply chains. ACM Comput. Surv. 55 (2023)
5. Bhuiyan, M.H.M., Parthasarathy, A.S., Vasilakis, N., Pradel, M., Staicu, C.-A.:
SecBench.js: an executable security benchmark suite for server-side JavaScript.
In: International Conference on Software Engineering (ICSE) (2023)
6. Xu, G., et al.: SoProtector: safeguard privacy for native so files in evolving mobile
IoT applications. IEEE Internet Things J. 7(4), 2539–2552 (2019)
7. Alrabaee, S., Debbabi, M., Wang, L.: A survey of binary code fingerprinting
approaches: taxonomy, methodologies, and features. ACM Comput. Surv. (CSUR)
55(1), 1–41 (2022)
8. Eschweiler, S., Yakdan, K., Gerhards-Padilla, E.: discovRE: efficient cross-
architecture identification of bugs in binary code. In: NDSS (2016)
9. Yang, Y., Xia, X., Lo, D., Grundy, J.: A survey on deep learning for software
engineering. ACM Comput. Surv. (CSUR) 54(10s), 1–73 (2022)
10. Li, Z., et al.: VulDeePecker: a deep learning-based system for vulnerability detec-
tion. In: 25th Annual Network and Distributed System Security Symposium (NDSS
2018), San Diego, California, USA, 18–21 February 2018 (EI/CCF-A) (2018)
11. Liu, S., Lin, G., Han, Q.-L., Wen, S., Zhang, J., Xiang, Y.: DeepBalance: deep-
learning and fuzzy oversampling for vulnerability detection. IEEE Trans. Fuzzy
Syst. 28(7), 1329–1343 (2019)
12. Grieco, G., Grinblat, G.L., Uzal, L., Rawat, S., Feist, J., Mounier, L.: Toward
large-scale vulnerability discovery using machine learning. In: Proceedings of the
Sixth ACM Conference on Data and Application Security and Privacy, pp. 85–96.
ACM (2016)
13. Lee, Y.J., Choi, S.-H., Kim, C., Park, K.-W.: Learning binary code with deep
learning to detect software weakness. In: KSII The 9th International Conference
on Internet 2017 Symposium, pp. 245–249 (2017)
14. Le, T., Nguyen, T., Le, T., Phung, D., Montague, P., De Vel, O., Qu, L.: Maximal
divergence sequential autoencoder for binary software vulnerability detection. In:
International Conference on Learning Representations (2018)
15. Liu, S., Dibaei, M., Tai, Y., Chen, C., Zhang, J., Xiang, Y.: Cyber vulnerability
intelligence for internet of things binary. IEEE Trans. Industr. Inf. 16(3), 2154–
2163 (2019)
16. Nguyen, T., et al.: Deep cost-sensitive kernel machine for binary software vulner-
ability detection. In: Pacific-Asia Conference on Knowledge Discovery and Data
Mining, pp. 164–177. Springer (2020)
17. Feng, Z., et al.: CodeBERT: a pre-trained model for programming and natural
languages. arXiv preprint arXiv:2002.08155 (2020)
18. Hajra, S., Alam, M., Saha, S., Picek, S., Mukhopadhyay, D.: On the instability
of softmax attention-based deep learning models in side-channel analysis. IEEE
Trans. Inf. Forensics Secur. 19, 514–528 (2024)
19. Chen, Y., et al.: Bookworm game: Automatic discovery of LTE vulnerabilities
through documentation analysis. In: 2021 IEEE Symposium on Security and Pri-
vacy (SP). IEEE (2021)
20. Steenhoek, B., Rahman, M.M., Jiles, R., Le, W.: An empirical study of deep learn-
ing models for vulnerability detection. In: 2023 IEEE/ACM 45th International
Conference on Software Engineering (ICSE), pp. 2237–2248. IEEE (2023)
21. Yang, S., et al.: Asteria-pro: Enhancing deep-learning based binary code similarity
detection by incorporating domain knowledge. ACM Trans. Softw. Eng. Methodol.
(2023)
22. Jiang, L., et al.: BinaryAI: binary software composition analysis via intelligent
binary source code matching. In: Proceedings of the IEEE/ACM 46th International
Conference on Software Engineering, pp. 1–13 (2024)
23. Li, X., Qu, Y., Yin, H.: PalmTree: learning an assembly language model for instruc-
tion embedding. In: Proceedings of the 2021 ACM SIGSAC Conference on Com-
puter and Communications Security, pp. 3236–3251 (2021)
24. Zuo, F., Li, X., Young, P., Luo, L., Zeng, Q., Zhang, Z.: Neural machine translation
inspired binary code similarity comparison beyond function pairs (2019)
25. Li, Z., Wang, J., Sun, M., Lui, J.C.S.: MirChecker: detecting bugs in rust pro-
grams via static analysis. In: Proceedings of the 2021 ACM SIGSAC Conference
on Computer and Communications Security, pp. 2183–2196 (2021)
26. Batur Şahin, C., Abualigah, L.: A novel deep learning-based feature selection model
for improving the static analysis of vulnerability detection. Neural Comput. Appl.
33(20), 14049–14067 (2021). https://doi.org/10.1007/s00521-021-06047-x
27. Jiang, S., Fu, C., Qian, Y., He, S., Lv, J., Han, L.: IFAttn: binary code similarity
analysis based on interpretable features with attention. Comput. Secur. 102804
(2022)
28. Wang, Y., Jia, P., Peng, X., Huang, C., Liu, J.: BinVulDet: detecting vulnerability
in binary program via decompiled pseudo code and biLSTM-attention. Comput.
Secur. 125, 103023 (2023)
29. Pei, K., Xuan, Z., Yang, J., Jana, S., Ray, B.: Trex: learning execution semantics
from micro-traces for binary similarity. arXiv preprint arXiv:2012.08680 (2020)
30. Fan, Y., Wan, C., Cai, F., Han, L., Hao, X.: VDoTR: vulnerability detection based
on tensor representation of comprehensive code graphs. Comput. Secur. 130,
103247 (2023)
31. Lin, G., Xiao, W., Zhang, L.Y., Gao, S., Tai, Y., Zhang, J.: Deep neural-based
vulnerability discovery demystified: data, model and performance. Neural Comput.
Appl. 33(20), 13287–13300 (2021)
32. Mashhadi, E., Hemmati, H.: Applying codeBERT for automated program repair
of java simple bugs. arXiv preprint arXiv:2103.11626 (2021)
33. Lin, G., et al.: Cross-project transfer representation learning for vulnerable function
discovery. IEEE Trans. Industr. Inf. 14(7), 3289–3297 (2018)
34. Cieslak, M.C., Castelfranco, A.M., Roncalli, V., Lenz, P.H., Hartline, D.K.: t-
distributed stochastic neighbor embedding (t-SNE): a tool for eco-physiological
transcriptomic analysis. Marine Genom. 51, 100723 (2020)
35. Xu, X., Liu, C., Feng, Q., Yin, H., Song, L., Song, D.: Neural network-based graph
embedding for cross-platform binary code similarity detection. In: Proceedings of
the 2017 ACM SIGSAC Conference on Computer and Communications Security,
pp. 363–376. ACM (2017)
36. Massarelli, L., Di Luna, G.A., Petroni, F., Baldoni, R., Querzoni, L.: SAFE: self-
attentive function embeddings for binary similarity. In: International Conference on
Detection of Intrusions and Malware, and Vulnerability Assessment, pp. 309–329.
Springer (2019)
37. Li, L., et al.: VulAnalyzer: explainable binary vulnerability detection with multi-
task learning and attentional graph convolution. ACM Trans. Priv. Secur. 26(3),
1–25 (2023)
38. Nauta, M., et al.: From anecdotal evidence to quantitative evaluation methods:
a systematic review on evaluating explainable AI. ACM Comput. Surv. 55(13s),
1–42 (2023)
Malicious Code Detection Based
on Generative Adversarial Model
1 Introduction
The swift increase in malware variants, along with their rapid spread and stealthy
nature, poses substantial challenges to conventional malware detection tech-
niques. Malware developers often employ code obfuscation techniques to gener-
ate new variants, hiding certain features to evade detection. Deep learning tech-
niques offer a promising direction for tackling these issues in malware detection.
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 167–183, 2025.
https://doi.org/10.1007/978-981-96-4566-4_12
168 J. Zhang et al.
However, training deep learning models requires large datasets, and insufficient
data or imbalanced classes can degrade detection performance and prolong
training. Indeed, numerous experimental studies have integrated GAN-related
models with malicious code analysis [1, 2].
To address the issues of class imbalance and insufficient data in malware
datasets, we consider using GAN to augment the dataset [3–5]. GAN provides a
powerful machine learning model capable of learning and generating the distribu-
tion of malware features. It can be used to generate sufficient samples to balance
our dataset. GAN can effectively augment malware datasets, saving training
costs and time while improving the accuracy of malicious code detection.
In this paper, we propose the use of GAN [6], WGAN [7], and WGAN-
GP [8] models to augment datasets related to malicious code. We selected the
Malimg dataset and the BIG2015 dataset for augmentation. Subsequently, we used
the original datasets and the augmented datasets (m1, m2, m3 derived from
Malimg and b1, b2, b3 derived from BIG2015) to train a ResNet50 network [9],
aiming to enhance its detection capabilities. In summary, we make the following
contributions:
2 Background
Pioronski [10] employed WGAN-GP to generate samples. This approach aimed to
address the challenge of data class imbalance and to enhance the classifier’s effi-
cacy in identifying malicious programs. Wan [11] utilized the Web Feature Sam-
ple Generation Adversarial Network (WFS-GAN). This network was designed
to create feature samples of malicious web pages. It effectively tackled the diffi-
culties associated with collecting samples and the extensive task of annotating a
large number of samples. Zhu [12] integrated image processing techniques with
the WGAN-GP model. They applied this combination to augment the malicious
code dataset. The researchers compared the recognition rates of various classi-
fier models, including decision trees. Their findings demonstrated that the aug-
mented dataset could boost the performance of classifiers. Moreover, it helped to
mitigate the issue of data imbalance across different categories within the original
dataset. These studies show that GAN-based augmentation has become an
established strategy in malicious code detection.
We use generative adversarial networks to augment the malware dataset. A
GAN consists of two components: a discriminator model for judging samples and
a generator model for producing data. GANs achieve model learning through
adversarial training. Within the GAN framework, two models are simultaneously trained:
one is the generator model G, which learns the distribution of data, and the other
is the discriminator model D, which distinguishes between real samples and fake
samples generated by G. The training process of generator G aims to maximize
the probability of discriminator D making mistakes, specifically by misclassifying
generated fake data as real samples. This means during training, the discrimina-
tor D is trained to distinguish real samples from generated ones as accurately as
possible, while the generator G is trained to produce increasingly realistic fake
samples to deceive D.
In this minimax game, there exists a unique equilibrium at which the generator
G reproduces the distribution of the real samples and the discriminator D
outputs a probability of 1/2 for every input. This process can be expressed
using formula (1):
min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x)] + E_{z~p_z(z)}[log(1 − D(G(z)))]   (1)
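A Monte-Carlo estimate of the value function in formula (1), evaluated at the equilibrium where D outputs 1/2 everywhere, can be sketched in plain Python:

```python
import math

def value_fn(d_real, d_fake):
    """Monte-Carlo estimate of V(D, G) from discriminator outputs on
    real samples (D(x)) and on generated samples (D(G(z)))."""
    term_real = sum(math.log(p) for p in d_real) / len(d_real)
    term_fake = sum(math.log(1 - p) for p in d_fake) / len(d_fake)
    return term_real + term_fake

# At equilibrium, D outputs 1/2 everywhere, so
# V(D, G) = log(1/2) + log(1/2) = -2 log 2 ≈ -1.386.
v = value_fn([0.5] * 4, [0.5] * 4)
print(round(v, 3))  # -1.386
```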
3 Model
To address class imbalance in malicious code datasets, this paper uses GAN,
WGAN, and WGAN-GP to augment the dataset. The augmented data trains a
ResNet50 model, which is then used for detection and testing. The design of the
GAN-based detection scheme is shown in Fig. 1.
Generator Implementation. The WGAN generator has five linear layers, each fol-
lowed by LeakyReLU and ending with a Tanh activation. The 100-dimensional
noise vector is transformed through dimensions of 128, 256, 512, and 1024, with
normalization and activations at each step. The final layer outputs 4096 dimen-
sions, using Tanh to map values between −1 and 1, helping mitigate the gradient
vanishing issue and ensuring fault tolerance during generation.
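The generator described above can be sketched as a fully connected forward pass (illustrative only: biases, normalization layers, and trained weights are omitted, and the small random weights here are assumptions):

```python
import math
import random

def leaky_relu(x, a=0.2):
    return x if x > 0 else a * x

# Layer widths described above: 100-d noise -> 128 -> 256 -> 512 -> 1024 -> 4096.
DIMS = [100, 128, 256, 512, 1024, 4096]

def forward(z, weights):
    """Fully connected generator forward pass: LeakyReLU after each hidden
    layer, Tanh on the output layer so outputs land in [-1, 1]."""
    h = z
    for i, W in enumerate(weights):
        h = [sum(w * x for w, x in zip(row, h)) for row in W]
        if i == len(weights) - 1:
            h = [math.tanh(v) for v in h]
        else:
            h = [leaky_relu(v) for v in h]
    return h

rng = random.Random(0)
weights = [[[rng.random() * 0.02 - 0.01 for _ in range(DIMS[i])]
            for _ in range(DIMS[i + 1])] for i in range(len(DIMS) - 1)]
out = forward([rng.random() for _ in range(DIMS[0])], weights)
print(len(out), min(out) >= -1.0 and max(out) <= 1.0)  # 4096 True
```

The Tanh output guarantees every generated value stays in [-1, 1] regardless of the weights, which is why it is used on the final layer.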
Loss Function and Weight Clipping. The generator minimizes the WGAN loss,
while the discriminator (critic) maximizes it. The loss functions for the generator
and discriminator are shown in Eq. (2). RMSProp is used for optimization, and
weight clipping is applied to the discriminator’s parameters to enforce the
Lipschitz constraint required by the Wasserstein objective. The clamp() method
constrains the weights within a specified range, ensuring training stability.
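Weight clipping itself is a one-line operation; a minimal sketch follows (the clipping bound c = 0.01 is the common WGAN default, not a value stated in this paper; in PyTorch the equivalent is `p.data.clamp_(-c, c)` over the discriminator's parameters):

```python
def clamp_weights(weights, c=0.01):
    """WGAN weight clipping: after each discriminator update, clamp every
    parameter into [-c, c] to enforce the Lipschitz constraint."""
    return [max(-c, min(c, w)) for w in weights]

w = [-0.5, -0.005, 0.0, 0.02, 0.3]
print(clamp_weights(w))  # [-0.01, -0.005, 0.0, 0.01, 0.01]
```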
This paper uses the ResNet50 network model as the detection algorithm
because ResNet50 takes images as input and its architecture allows training
deeper neural networks. It employs a residual structure to preserve original fea-
tures, enhancing both the precision and generalization ability of the network
model. In this section, the augmented dataset obtained through adversarial net-
work model enhancement and the original dataset are both used to train the
ResNet50 model. After training, the corresponding model is used for detection
on the test set.
The input to the ResNet50 model undergoes 5 stages to obtain the final
output, passing through 49 convolutional layers and 1 fully connected layer. The
overall network structure of ResNet50 is shown in Fig. 3.
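The stated layer count follows from the standard ResNet50 layout (3, 4, 6, and 3 bottleneck blocks of 3 convolutions each, plus the stem convolution):

```python
# ResNet50 bottleneck stages contain 3, 4, 6, and 3 blocks; each bottleneck
# block has 3 convolutions, and there is one 7x7 stem convolution at the input.
stem = 1
blocks_per_stage = [3, 4, 6, 3]
convs = stem + 3 * sum(blocks_per_stage)
fully_connected = 1
print(convs, convs + fully_connected)  # 49 50
```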
4 Experiment
First, we implemented GAN, WGAN, and WGAN-GP models to augment the
dataset, resulting in six datasets: m1, m2, m3 for Malimg, and b1, b2, b3 for
BIG2015. Each dataset contains approximately 1K samples. Then, these datasets
were all fed into ResNet50 for training. We demonstrate the effectiveness of
enhancing datasets with different GAN models. We utilize each trained ResNet
model for malicious code detection and evaluate metrics such as Accuracy, Recall,
Precision, and F1 score. We use these evaluation metrics to compare and assess
the performance improvement of ResNet in detecting malicious code after aug-
menting the dataset with various GAN models.
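These metrics follow directly from confusion-matrix counts; a minimal sketch with toy counts (not taken from the paper's tables):

```python
def metrics(tp, fp, fn, tn):
    """Accuracy, Recall, Precision, and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, recall, precision, f1

# Toy counts for illustration only.
acc, rec, prec, f1 = metrics(tp=90, fp=10, fn=10, tn=90)
print(acc, rec, prec, round(f1, 2))  # 0.9 0.9 0.9 0.9
```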
4.1 Dataset
We use the Malimg dataset and the Microsoft Malware Classification Challenge
(BIG2015) dataset for our experiments. Specifically:
Read. We first need to read the hexadecimal text of the bytes file, ignoring
invalid characters and address information in each line. In this step, we remove
unnecessary information and keep only the malicious bytes from the file to turn
into a grayscale image.
Fill. In the second step, we determine the image width according to the byte
file size, and then fill the byte data according to the determined width to obtain
the grayscale image corresponding to the binary file.
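The Read and Fill steps above can be sketched as follows (illustrative Python; the width heuristic and the sample hex dump are assumptions, since the paper does not give its exact width rule):

```python
import math

def bytes_file_to_pixels(text):
    """Read step: parse a .bytes hex dump, dropping the leading address token
    on each line and skipping invalid entries such as '??'."""
    pixels = []
    for line in text.splitlines():
        tokens = line.split()
        if not tokens:
            continue
        for tok in tokens[1:]:  # tokens[0] is the address
            if len(tok) == 2 and all(c in "0123456789abcdefABCDEF" for c in tok):
                pixels.append(int(tok, 16))  # one byte -> one pixel, 0-255
    return pixels

def to_grayscale(pixels, width=None):
    """Fill step: choose a width from the data size, pad with zeros, and
    reshape the bytes into rows of a grayscale image."""
    if width is None:
        width = max(1, int(math.sqrt(len(pixels))))  # assumed heuristic
    pad = (-len(pixels)) % width
    pixels = pixels + [0] * pad
    return [pixels[i:i + width] for i in range(0, len(pixels), width)]

dump = "00401000 4D 5A 90 00\n00401010 ?? 5A FF 01"
img = to_grayscale(bytes_file_to_pixels(dump), width=4)
print(img)  # [[77, 90, 144, 0], [90, 255, 1, 0]]
```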
Additionally, we used the Malimg and BIG2015 datasets to train GAN, WGAN,
and WGAN-GP, and then used the trained generators to produce grayscale
images that supplement the dataset. This brings the number of grayscale images
of malicious code in the test set close to the number of grayscale images of
non-malicious code, improving the test performance of ResNet50.
The ratio of the number of gray-scale images between the malicious code
category and non-malicious code category of the enhanced dataset is 1:1, and
the training dataset and the test dataset each contain about 1.5K gray-scale
images.
Training GANs begins by loading pre-trained discriminator and generator
models. Then, we proceed with the training loop. In each epoch, we go through
each batch from the data loader. For each batch, we first update the discrim-
inator: the generator creates samples using random noise, and these generated
samples along with real samples (malicious grayscale images) are fed into the
discriminator. We calculate the discriminator’s loss function and perform back-
propagation to update the parameters. After that, we train the generator: again,
the generator creates samples using random noise, and after calculating the gen-
erator’s loss function, we update the generator. We use the Adam optimizer to
update the parameters of both the discriminator and the generator, and binary
cross-entropy loss is used to calculate the relevant losses.
For WGAN, we also train the discriminator, but the difference is that we
clip the parameters to restrict the weight range when updating them. After
training the discriminator for n_critic times, we update the generator once. The
optimization algorithm used in this process is RMSProp.
The training process for WGAN-GP is similar to WGAN, but the difference
is that after calculating the scores for real and generated samples, WGAN-GP
needs to calculate a gradient penalty term and add it to the discriminator’s
loss function for backpropagation to update the discriminator. The optimization
algorithm used here is Adam.
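The gradient penalty term can be sketched as follows, given the norms of the discriminator's gradients at points interpolated between real and generated samples (the coefficient lambda = 10 is the common WGAN-GP default from [8], not a value stated here):

```python
def gradient_penalty(grad_norms, lam=10.0):
    """WGAN-GP penalty: lam * E[(||grad D(x_hat)|| - 1)^2], where x_hat are
    interpolations between real and generated samples. Added to the
    discriminator loss before backpropagation."""
    return lam * sum((g - 1.0) ** 2 for g in grad_norms) / len(grad_norms)

# Gradient norms of exactly 1 incur no penalty; deviations are penalised.
print(gradient_penalty([1.0, 1.0]))  # 0.0
print(gradient_penalty([0.5, 1.5]))  # 10 * (0.25 + 0.25) / 2 = 2.5
```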
In the training of GAN, WGAN, and WGAN-GP, the Malimg dataset and
the BIG2015 dataset are divided into a training set and a test set. The ratio of
the number of grayscale images in the training set to that in the test set is 1:1,
with about 1K images in each set.
The parameters for GAN, WGAN, and WGAN-GP are shown in Table 1.
In the training of the ResNet model, the augmented Malimg dataset and
BIG2015 dataset are divided into a training set and a test set. The ratio of the
number of grayscale images in the training set to that in the test set is 1:1, with
about 1.5K images in each set. The ratio of grayscale images of malicious code
to grayscale images of non-malicious code in both sets is 12:5. After each training
epoch, the loss function value and accuracy are evaluated on the training set; if
the accuracy increases, the current model file is saved.
During the training process, after selecting suitable GAN, WGAN, and
WGAN-GP models, their respective generators were used to generate 2500
grayscale images of malicious code for each of the two datasets to augment
their quantities. Before augmentation, the ratio of grayscale images between
the two classes was 10:17. Following augmentation, the dataset was enriched with
200 generated malicious code images and 500 images from non-malicious code
classes; the post-augmentation ratio between grayscale images of malicious code
and non-malicious code classes became 12:5.
ResNet50 Performance. Based on Figs. 12, 13, 14, 15, 16 and 17, we select
the ResNet50 models whose loss functions converge. We assess ResNet50’s performance
by measuring its accuracy and training efficiency before and after data
augmentation. The relevant performance metrics can be found in Table 4. Table 4
demonstrates that training the ResNet50 model with a dataset augmented by
Generative Adversarial Networks leads to an improvement in classification accu-
racy compared to training with the original dataset. Furthermore, by comparing
the model training times in Table 4 before and after data augmentation, it is
observed that the efficiency of the model has also been improved.
In training the ResNet50 network model on the Malimg dataset, by observ-
ing the loss function value changes and related performance tables, the model
trained with GAN performs better, while the performance of models trained with
WGAN and WGAN-GP is relatively poor. Similarly, in the BIG2015 dataset,
by observing the loss function value changes and performance metrics tables, it
is also determined that the model trained with GAN performs better, while the
performance of models trained with WGAN and WGAN-GP is relatively poor.
Fig. 4. Loss change plot of GAN trained on Malimg dataset.
Fig. 5. Loss change plot of WGAN trained on Malimg dataset.
Fig. 6. Loss change plot of WGAN-GP trained on Malimg dataset.
Fig. 7. Loss change plot of GAN trained on BIG2015 dataset.
Fig. 8. Loss change plot of WGAN trained on BIG2015 dataset.
Fig. 9. Loss change plot of WGAN-GP trained on BIG2015 dataset.
Fig. 10. Loss change plot for training ResNet50 on the unaugmented Malimg dataset.
Fig. 11. Loss change plot of ResNet50 trained on the unaugmented BIG2015 dataset.
Fig. 12. Loss change plot of ResNet50 trained on Malimg dataset augmented with GAN model.
Fig. 13. Loss change plot of ResNet50 trained on Malimg dataset augmented with WGAN model.
Fig. 14. Loss change plot of ResNet50 trained with Malimg dataset augmented with WGAN-GP model.
Fig. 15. Loss change plot of ResNet50 trained on the BIG2015 dataset augmented with WGAN-GP model.
Fig. 16. Loss change plot of ResNet50 trained on BIG2015 dataset augmented with WGAN model.
Fig. 17. Loss change plot of ResNet50 trained on BIG2015 dataset augmented with WGAN-GP model.
When the Malimg dataset without data augmentation is used to train the
ResNet50 network model, the accuracy of detecting related malware datasets is
82.35%. When the BIG2015 dataset without data augmentation is used to train
the ResNet50 network model, the accuracy of detecting related malware datasets
is 82.23%.
After training the generative adversarial network models, GAN, WGAN, and
WGAN-GP models were used to generate a certain number of malware grayscale
images to augment the Malimg and BIG2015 datasets. As shown in Table 4, in
the Malimg dataset, augmenting the dataset with adversarial network models can
indeed improve the detection capability of the ResNet50 network model to some
extent. Among them, the GAN model has the best improvement effect, increasing
the detection accuracy of ResNet50 to 95.18%. The effects of WGAN-GP and
WGAN are slightly inferior to GAN, with WGAN-GP performing slightly better
than WGAN. In the BIG2015 dataset, the performance improvement of the
ResNet50 model is similar. The GAN model again has the best improvement
effect, increasing the detection accuracy of ResNet50 to 89.46%, followed by
WGAN-GP, with WGAN having the relatively poorer effect.
Fig. 18. ROC and PR charts for the ResNet50 trained on the original Malimg dataset.
Fig. 19. ROC and PR charts for the ResNet50 trained on the original BIG2015 dataset.
Fig. 20. ROC and PR charts for the ResNet50 trained on the Malimg dataset augmented with GAN.
Fig. 21. ROC and PR charts for the ResNet50 trained on the BIG2015 dataset augmented with GAN.
5 Conclusions
The paper implements three types of generative adversarial models. Subse-
quently, each model is used to enhance the Malimg dataset and the BIG2015
malicious dataset separately. The augmented datasets are then fed into the
ResNet50 model for malicious code detection, and the detection accuracy before
and after augmentation is compared. The experiments demonstrate that generative
adversarial networks can indeed improve the performance of the detection
model, with GAN showing the best enhancement effect among the three models.
References
1. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 770–778 (2016)
2. Yang, Z., Deng, F., Han, L.: Flexible android malware detection model based on
generative adversarial networks with code tensor. In: 2022 International Conference
on Cyber-Enabled Distributed Computing and Knowledge Discovery (CyberC), pp.
19–28. IEEE (2022)
3. Bac, T.N., Duy, P.T., Pham, V.: PWDGAN: generating adversarial malicious URL
examples for deceiving black-box phishing website detector using GANs. In: 2021
IEEE International Conference on Machine Learning and Applied Network Tech-
nologies (ICMLANT), pp. 1–4. IEEE (2021)
4. Wang, J., Liu, M., Yin, X., et al.: Semi-supervised malicious traffic detection with
improved wasserstein generative adversarial network with gradient penalty. In:
2022 IEEE 6th Advanced Information Technology, Electronic and Automation
Control Conference (IAEAC), pp. 1916–1922. IEEE (2022)
5. Hu, W., Cheng, J., Chong, X., et al.: A GAN-based anti-obfuscation detection
method for malicious code. In: 2022 3rd International Conference on Pattern Recog-
nition and Machine Learning (PRML), pp. 484–488. IEEE (2022)
6. Goodfellow, I., Pouget-Abadie, J., Mirza, M., et al.: Generative adversarial nets.
Adv. Neural Inf. Process. Syst. 27 (2014)
7. Arjovsky, M., Chintala, S., Bottou, L.: Wasserstein generative adversarial networks.
In: International Conference on Machine Learning, pp. 214–223. PMLR (2017)
8. Gulrajani, I., Ahmed, F., Arjovsky, M., et al.: Improved training of Wasserstein
GANs. Adv. Neural Inf. Process. Syst. 30 (2017)
9. Wang, A., Ding, Y.: Network malicious traffic identification method based on
WGAN category balancing. In: 2021 IEEE International Conference on Signal
Processing, Communications and Computing (ICSPCC), pp. 1–6. IEEE (2021)
10. Pioroński, S., Górecki, T.: Using GAN to generate malicious samples suitable for
binary classifier training. In: 2022 IEEE International Conference on Big Data (Big
Data), pp. 6522–6527. IEEE (2022)
11. Wan, M., Yao, H., Yan, X.: Generation of malicious webpage samples based on
GAN. In: 2020 IEEE 19th International Conference on Trust, Security and Privacy
in Computing and Communications (TrustCom), pp. 864–869. IEEE (2020)
12. Zhu, X., Qian, L., Fu, W.: Method of enhancing malicious code based on generative
adversarial network. Comput. Eng. Design 42(11), 3032–3042 (2021)
13. Nataraj, L., Karthikeyan, S., Jacob, G., Manjunath, B.S.: Malware images: visu-
alization and automatic classification. In: Proceedings of the 8th International
Symposium on Visualization for Cyber Security, pp. 1–7 (2011)
14. Agarap, A.F.: Towards building an intelligent anti-malware system: a deep learning
approach using support vector machine (SVM) for malware classification. arXiv
preprint arXiv:1801.00318 (2017)
Construction of an AI Code Defect Detection
and Repair Dataset Based on Chain of Thought
Huimin Gong1 , Zongliang Shen1(B) , Hua Zhang1 , Lei Qiao2 , Huawei Wang2 ,
and Chi Zhang3
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing, China
{shenzongliang,zhanghua_288}@bupt.edu.cn
2 Beijing Institute of Control Engineering, Beijing, China
3 Information Center of China, North Industries Group Corporation, Beijing, China
Abstract. When detecting and repairing code defects, enhancing the generalization
ability and detection accuracy of models is a key challenge. This paper
proposes a data fine-tuning method based on Chain of Thought (CoT) to
improve the capabilities of models for defect detection in AI code. We constructed
a dataset that includes the CrossVul dataset and a manually created dataset of AI
code defects and repairs, improving data quality through techniques like context
free removal. In the experiments, we used Codeshell-7B, Qwencoder2.5-7B, and
Llama3.1-7B as the base models and trained them using LoRA fine-tuning techniques.
We compared different datasets and training methods to verify the model’s effec-
tiveness in detecting and repairing AI code defects. The results show that the CoT
fine-tuning model outperforms models without CoT fine-tuning in all aspects of
handling code defect tasks. Additionally, the specialized dataset we created for AI
code defects and repairs significantly enhances the model’s accuracy and repair
rate in AI code detection. Our experiments highlight the importance of construct-
ing targeted datasets for AI code defects and employing CoT fine-tuning strategies
in improving code defect detection.
1 Introduction
Code defects are inevitable issues in software development, potentially leading to sys-
tem crashes, functional failures, or even security vulnerabilities. As modern software
scales and code complexity increases rapidly, traditional code defect detection methods
are becoming inadequate. Static analysis [1] and dynamic analysis [2] are two main
traditional approaches for detecting code security defects. Although static analysis can
quickly identify potential defects, it often has a high false positive rate and limited under-
standing of the actual execution behavior of the code. Dynamic analysis relies on runtime
information, which can capture actual defects but is inefficient and often provides poor
repair suggestions. Consequently, traditional methods struggle to meet expectations in
large-scale complex systems.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 184–196, 2025.
https://doi.org/10.1007/978-981-96-4566-4_13
Recently, generative AI, particularly large language models (LLMs), has achieved
significant success in natural language processing (NLP). Given their outstanding perfor-
mance in tasks such as text generation, translation, and question answering, researchers
have started applying these models to detect and repair code defects. LLMs, with their
powerful contextual understanding and generation capabilities, theoretically have the
potential to effectively capture complex logical defects in code and provide accurate
repair suggestions. However, the high specificity of AI code presents significant
challenges. Firstly, AI code datasets are mostly closed-source and not publicly
available, making relevant data very scarce. Additionally, existing large models
typically perform poorly on vulnerabilities related to AI code and fail to generalize
effectively, exposing the limitations of current large models in code defect detection
and repair tasks. Moreover, quality issues in current datasets further constrain model
performance: in large code datasets, the presence of redundant context, noise, and
irrelevant code makes it difficult for models to focus on the actual defect areas.
Meanwhile, existing prompt design methods often struggle to guide models toward
generating effective repair suggestions.
This paper aims to explore the application of large language models in detecting and
repairing AI code defects, constructing and processing diverse datasets to optimize the
fine-tuning process and ensure the model is exposed to a wide range of code defect types.
This paper primarily adopts two key dataset processing methods: context free removal
and Chain of Thought (CoT) prompting. These dataset processing methods gradually
enhance the model’s perception of code defects during fine-tuning, enabling the model to
have strong defect identification and repair capabilities when handling unseen code snip-
pets. Ultimately, a series of experiments were conducted to evaluate the performance of
the large language model fine-tuned with different dataset processing methods, verifying
the effectiveness of the proposed methods and the model’s generalization capabilities.
2 Background
2.1 Code Defect Identification and Repair
Traditional code defect identification methods include static analysis and dynamic anal-
ysis. Static analysis checks the source code to identify potential security defects, and
due to its efficiency and low cost, it is widely used in large-scale systems. However,
it exhibits certain limitations [3] when dealing with complex problems, as it cannot
accurately determine the execution environment of code in different contexts. It is also
difficult to effectively handle detection tasks involving complex structures like loops
and branches. In contrast, dynamic analysis monitors software behaviors during execu-
tion, allowing it to capture code defects under different execution environments more
accurately. However, since it requires software execution, it can consume significant
resources in large-scale systems and cannot comprehensively cover all paths, which
limits its applicability.
Machine learning models learn from large amounts of labeled code defect data,
which allows them to automatically extract features and identify known vulnerability
types. Unlike traditional detection methods, methods based on machine learning are not
limited to manually defined rules and can discover vulnerabilities by learning patterns
in data. For example, Raychev et al. proposed a statistical learning-based code analysis
tool [4], which significantly improves the automation of code defect detection. However,
machine learning depends on large amounts of labeled data and struggles with complex
contextual relationships in code. These models often perform poorly when encountering
unseen code snippets or new types of vulnerabilities, as they lack sufficient generalization
ability.
Meanwhile, large language models (LLMs) show great potential in code defect
detection and repair. Chen et al. studied the ability of LLMs to understand and generate code
[5], offering insights into using LLMs for code-related tasks. Ribeiro et al. proposed a
technique using GPT-3 to automatically repair type errors in programs, which achieves
good results on OCaml programs [6]. Keller et al. proposed the Gemini model [7], which
automates the process from vulnerability discovery to repair code generation and testing.
With the deepening of these studies, more and more evidence suggests that LLMs have
enormous potential in code defect detection and repair.
Hu et al. proposed LoRA (Low-Rank Adaptation) [10], a parameter-efficient
strategy in the fine-tuning process. LoRA reduces the number of parameters that need
to be updated through low-rank matrix decomposition, thereby lowering computational
costs and resource requirements. This method is particularly suitable for researchers with
limited resources, providing new ideas for efficient model fine-tuning.
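The parameter saving from the low-rank decomposition can be illustrated with a minimal numerical sketch; the dimensions, rank, and scaling below are illustrative assumptions, not values from the paper:

```python
import numpy as np

# Frozen pretrained weight matrix: d_out x d_in (sizes are illustrative).
d_out, d_in, r = 512, 512, 8
W = np.random.randn(d_out, d_in)

# LoRA trains only two low-rank factors B (d_out x r) and A (r x d_in);
# the effective weight is W + (alpha / r) * B @ A.
A = np.random.randn(r, d_in) * 0.01
B = np.zeros((d_out, r))  # B starts at zero so the update is initially a no-op
alpha = 16

W_eff = W + (alpha / r) * B @ A

full_params = W.size           # 512 * 512 = 262,144 trainable parameters
lora_params = A.size + B.size  # 2 * 512 * 8 = 8,192 (about 3% of the full count)
print(full_params, lora_params)
```

Only `A` and `B` receive gradient updates; the base weight `W` stays frozen, which is what lowers the computational and memory cost.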
The purpose of prompt engineering is to design and use prompts that guide large models
in performing specific tasks. These prompts can be natural language descriptions, code
snippets, or specific instructions for the task, helping the model understand the task
requirements and generate appropriate outputs. Prompt engineering is widely used in
NLP and code generation tasks and has proven effective [11]. In code defect detection
and repair tasks, prompt engineering can generate clear prompts so that the model can
identify and fix code defects better. This improves task efficiency and accuracy [12].
However, traditional prompt engineering has limitations, especially when dealing with
multi-step reasoning tasks. Models often rely on pre-defined task formats: when the
problem changes slightly or multiple problems are combined, the model's performance
may not meet expectations. This rigid structure makes it difficult for models to respond
flexibly to complex contexts, leading to inadequate generalization. To address this issue,
the Chain of Thought (CoT) reasoning method has been widely applied in generative
AI in recent years. Wei et al. applied the Chain-of-Thought (CoT) reasoning method to
natural language reasoning tasks [13]. They demonstrated that this method significantly
improves the ability of large language models to handle reasoning tasks. Although the
focus of this paper is on natural language processing, the approach has also been applied
to the code generation domain. Yang et al. applied CoT to code completion [14], where
step-by-step reasoning helped generate high-quality and logically consistent code. This
method can be effectively transferred to code defect repair tasks, enhancing the accuracy
and logical consistency of repairs through the gradual analysis of errors in the code.
3 Methodology
This paper utilizes the CrossVul dataset [15] and a manually constructed AI code defect
and repair dataset to enhance the model's performance in detecting and repairing code
defects while addressing the model’s shortcomings in understanding and handling AI-
generated code. To further optimize the application of our dataset, this paper proposes
two key processing methods: irrelevant context removal and Chain of Thought prompting
to enhance the model’s ability to perceive and repair code defects.
In this paper, two key datasets were used for model training and fine-tuning: the CrossVul
dataset and a manually constructed AI code defect and repair dataset. Details are as
follows.
1. CrossVul Dataset
CrossVul is a publicly available dataset specifically designed for cross-language code
vulnerability detection. It collects code snippets from various projects in different
programming languages containing known security vulnerabilities from real-world
scenarios and extensive repositories. Each vulnerability sample is accompanied by
a detailed description, including the type of vulnerability, scope of impact, and cor-
responding secure code. As shown in Fig. 1, the dataset has a hierarchical structure
where the top-level directory is named after CWE-ID, differentiating various vulner-
ability types. The second-level directory is divided by programming language, and
the files are named to indicate whether the code snippet is vulnerable (bad) or secure
(good), organized into pairs by number. CrossVul contains over 40 programming
languages, 168 major types, and 13,738 vulnerability samples.
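Under the layout described above (top-level CWE-ID directories, second-level language directories, files paired by number as bad/good), the dataset can be traversed to collect vulnerability/fix pairs. The helper below is a hypothetical sketch based on that description, not part of the CrossVul tooling, and assumes file names of the form `bad_N.<ext>` / `good_N.<ext>`:

```python
from pathlib import Path

def collect_pairs(root: str) -> list:
    """Pair each vulnerable (bad) file with its fixed (good) counterpart.

    Assumed layout: <root>/<CWE-ID>/<language>/bad_N.* and good_N.*,
    paired by the number N.
    """
    pairs = []
    for cwe_dir in Path(root).iterdir():
        if not cwe_dir.is_dir():
            continue
        for lang_dir in (d for d in cwe_dir.iterdir() if d.is_dir()):
            bad = {f.stem.split("_")[1]: f for f in lang_dir.glob("bad_*")}
            good = {f.stem.split("_")[1]: f for f in lang_dir.glob("good_*")}
            for num in bad.keys() & good.keys():
                pairs.append({"cwe": cwe_dir.name, "lang": lang_dir.name,
                              "bad": bad[num], "good": good[num]})
    return pairs
```

Each returned entry carries the CWE-ID and language alongside the file pair, which matches the information needed for the per-CWE processing described later.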
Irrelevant Context Removal. When processing the CrossVul dataset, we observed that each
vulnerability sample and its corresponding secure code often included context code
unrelated to the defect. This irrelevant code not only interfered with the model's training
process but also reduced the precision of defect detection. Therefore, this paper proposes
an irrelevant context removal method to enhance the dataset's effectiveness and simplify
the model's learning task. As shown in Fig. 2, the steps for irrelevant context removal
are as follows.
1. Function Extraction
Each vulnerability sample typically consists of multiple functions or code seg-
ments. The first step involves extracting each function or logical code block from the
sample using an Abstract Syntax Tree (AST) parser.
2. Function Name and Parameter Matching
The extracted functions are parsed syntactically using the AST parser to identify
those related to the defect or its fix. The CrossVul dataset contains pairs of vulnera-
bility and repair code samples, and in most cases, the defect and fix occur within the
same function. By matching function names and parameters, we can filter out code
blocks directly related to the defect.
3. Semantic Analysis
The relevant functions undergo semantic analysis to determine which code blocks
are directly linked to the root cause of the defect. Code segments not involving the
defects (e.g., helper functions, initialization code) are marked as “irrelevant context”,
which will be eliminated in step 5.
4. Difference Comparison
The differences between the defect code and the repair code are compared to
extract the specific lines of code involved in the fix. This step uses the code comparison
method from diff tools to extract the modified parts from the repair and defect files,
ensuring the correspondence between defect and repair codes. Irrelevant code can be
excluded after this comparison.
5. Irrelevant Code Elimination
After extracting the core code related to the defect, irrelevant context is removed.
This step primarily relies on analyzing differences, function dependencies, and vari-
able usage to ensure that the remaining code focuses on the parts directly related to
the defect.
6. Code Reorganization and Formatting
After step 5, the remaining code fragments are reorganized and formatted to
ensure that each sample retains syntactical completeness, facilitating normal input
and learning for the model.
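Steps 1 and 4 above (function extraction and difference comparison) can be sketched with Python's standard `ast` and `difflib` modules as stand-ins for the paper's AST parser and diff tool; the function names are illustrative:

```python
import ast
import difflib

def extract_functions(source: str) -> dict:
    """Step 1: extract each function body from a sample via the AST."""
    tree = ast.parse(source)
    return {node.name: ast.get_source_segment(source, node)
            for node in ast.walk(tree)
            if isinstance(node, ast.FunctionDef)}

def defect_related_lines(bad_src: str, good_src: str) -> list:
    """Step 4: diff the vulnerable and repaired versions and keep only the
    changed lines; everything else is candidate irrelevant context."""
    diff = difflib.unified_diff(bad_src.splitlines(), good_src.splitlines(),
                                lineterm="", n=0)
    return [line[1:] for line in diff
            if line[:1] in "+-" and line[:3] not in ("+++", "---")]
```

In this sketch, functions untouched by the diff (e.g., helper functions) would be marked as irrelevant context and dropped in step 5, while the changed lines localize the defect and its fix.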
Chain of Thought Prompting Design. Building upon irrelevant context removal, this
paper further adopts a chain-of-thought prompting design to enhance the model's rea-
soning ability. The core of chain-of-thought prompting is to guide the model through a
step-by-step process, encouraging deep thinking during tasks like multi-step reasoning.
This method allows the model to simulate human reasoning when faced with complex
code defects, making it more effective at solving problems. As shown in Pseudocode
Example of Error Analysis and Fixing Process, we added guiding information to the
dataset to prompt the model, including:
Problem Description: To help the model understand the nature of the error, the
defects and fixes of each sample are accompanied by a brief, clear description. For
example, for a vulnerability involving incorrect data type input, the description might
state, "This code has a data type mismatch that causes a compatibility error in the data
processing API." Such descriptions help the model focus on specific defect types during analysis.
Step-by-Step Reasoning Process: We designed step-by-step prompts to help the model
gradually derive the repair solution. This process aims to guide the model to identify the
type and location of the vulnerability, analyze possible solutions, and ultimately generate
repair code. For example, in code analysis, the model is first prompted to identify the
type of vulnerability, then to locate the vulnerability triggers in source code, and finally
to propose a feasible repair measure. This “chain-of-thought” prompting design allows
the model to understand and repair complex code defects in stages.
Code Examples and Repair Suggestions: In order to help the model better understand
the problem, we also provided relevant code snippets and repair suggestions after each
prompt. By integrating real code contexts, the model can more accurately infer the
repair process. The example code includes the original faulty code and its corresponding
corrected version, allowing the model to reason within a clear context and generate the
correct repair solution.
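A minimal sketch of assembling the three prompt components described above; the exact wording of the template is an assumption, not the paper's actual prompt:

```python
def build_cot_prompt(code: str, description: str) -> str:
    """Assemble a chain-of-thought training prompt for one sample.

    The three parts mirror the design above: a problem description,
    step-by-step reasoning instructions, and the code context.
    """
    steps = [
        "1. Identify the type of the vulnerability.",
        "2. Locate the lines of code that trigger it.",
        "3. Analyze possible solutions step by step.",
        "4. Generate the repaired code.",
    ]
    return (f"Problem description: {description}\n\n"
            f"Code:\n{code}\n\n"
            "Reason through the following steps before answering:\n"
            + "\n".join(steps))
```

During dataset construction, the target side of each sample would pair this prompt with the staged reasoning and the corrected code, so the model learns to produce the intermediate steps rather than only the final fix.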
The specific steps for processing the dataset are as follows, using the CrossVul dataset
as an example.
1. CWE-ID Classification
The CrossVul dataset is classified according to CWE (Common Weakness Enumer-
ation) numbers, with each CWE-ID representing a specific type of code defect. This
classification method allows for systematic organization and management of different
defect types.
4 Experiments
The goal of this experiment is to evaluate the performance of different models in AI code
defect detection and repair tasks, with a focus on comparing the impact of the Chain-of-
Thought reasoning method. Additionally, the experiment explores the improvement in
the model’s ability to detect AI-specific defects through the constructed AI defect repair
dataset. For a comprehensive assessment, the experiment compares the performance of
the Codeshell-7B, Qwencoder2.5-7B, and Llama3.1-7B models.
AI-specific code defects and improve accuracy and logical consistency in code defect
repair tasks.
The fine-tuning process for each model was carried out using the official fine-tuning
script provided by the respective models (e.g., run_finetune.sh for Codeshell-7B), with
the following parameters:
Per Device Train Batch Size: 2 (moderately increases training speed).
Gradient Accumulation Steps: 4 (balances computational resources and training
efficacy).
Gradient Checkpointing: Enabled (reduces memory usage).
Learning Rate Scheduler Type: Cosine (gradually decreases learning rate).
Logging Steps: 50 (logs training progress every 50 steps).
Save Steps: 500 (saves the model every 500 steps to reduce storage frequency).
Learning Rate: 2e-5 (ensures training stability).
Number of Training Epochs: 5 (ensures sufficient training).
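Collected together, the listed hyperparameters might look as follows in the style of Hugging Face's `TrainingArguments`; the field names are assumptions, since the actual per-model fine-tuning scripts (e.g. `run_finetune.sh`) may expose different flags:

```python
# Hyperparameters from the paper, expressed as keyword arguments
# (field names follow Hugging Face conventions and are assumed).
finetune_config = {
    "per_device_train_batch_size": 2,   # moderately increases training speed
    "gradient_accumulation_steps": 4,   # balances resources and efficacy
    "gradient_checkpointing": True,     # reduces memory usage
    "lr_scheduler_type": "cosine",      # gradually decreases learning rate
    "logging_steps": 50,
    "save_steps": 500,
    "learning_rate": 2e-5,
    "num_train_epochs": 5,
}

# Effective optimizer batch size per device under gradient accumulation:
effective_batch = (finetune_config["per_device_train_batch_size"]
                   * finetune_config["gradient_accumulation_steps"])
print(effective_batch)
```

With accumulation, gradients from 4 micro-batches of 2 samples are summed before each optimizer step, giving an effective batch size of 8 per device while keeping peak memory low.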
In evaluating the models on the test set, we used the following metrics to assess
performance in AI code defect detection and repair tasks:
Acc (Accuracy): Measures the proportion of correctly identified defects in the code. For
each test sample, if the model correctly identifies the defect location, it is considered
accurate.
Repair: Assesses the proportion of correct repair suggestions compared to the actual fixes.
For each sample where a defect is identified, the generated repair code is compared with
the actual fix, and a match is considered a successful repair.
BLEU: A text similarity measure based on n-grams, evaluating the degree of match
between generated code and reference code, with a score ranging from 0 to 1, where
1 indicates a perfect match. This metric compares the generated repair code with the
actual fix for each sample, calculating the n-gram match.
Exec. Correct (Execution Correctness): Measures whether the repaired code functions
correctly when executed in a real programming environment. This metric assesses
whether the generated repair code successfully resolves the defects in the original code.
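The BLEU metric described above can be sketched as a simplified token-level implementation with uniform n-gram weights and a brevity penalty; a production evaluation would normally use a standard BLEU library instead:

```python
import math
from collections import Counter

def bleu(candidate: str, reference: str, max_n: int = 4) -> float:
    """Simplified BLEU between generated and reference repair code,
    computed on whitespace tokens; returns a score in [0, 1]."""
    cand, ref = candidate.split(), reference.split()
    if not cand:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        cand_ngrams = Counter(tuple(cand[i:i + n]) for i in range(len(cand) - n + 1))
        ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        overlap = sum((cand_ngrams & ref_ngrams).values())  # clipped counts
        if overlap == 0:
            return 0.0
        log_precisions.append(math.log(overlap / max(sum(cand_ngrams.values()), 1)))
    # Brevity penalty discourages overly short candidates.
    brevity = min(1.0, math.exp(1 - len(ref) / len(cand)))
    return brevity * math.exp(sum(log_precisions) / max_n)
```

A generated repair identical to the reference fix scores 1.0; a repair sharing no tokens with the reference scores 0.0, matching the range stated above.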
Our experiments aimed to: (1) evaluate and compare the model’s performance under
different dataset compositions and processing strategies, and (2) analyze the effectiveness
of chain-of-thought fine-tuning in code defect detection tasks.
We formatted the test set data into model input prompts using pre-written scripts
and then fed them into both the baseline and fine-tuned models, including the traditional
static defect detection method. We then compared the outputs of each model. Table 2
summarizes the performance of each model on the test set, listing their performance on
the CrossVul dataset (A) and the manually constructed AI defect repair dataset (B). This
includes the performance of Codeshell-7B, Qwencoder2.5-7B, Llama3.1-7B, as well as
the Traditional Static Defect Detection method for comparison.
Table 2. Performance of Different Models on Different Test Sets. The CrossVul dataset extracts
10% of its data (147 samples) as Test Set A, while 10% of the data from the manually constructed
AI defect repair dataset (88 samples) is used as Test Set B.
After training on the datasets, all models generally outperformed the baseline. How-
ever, models fine-tuned directly with the CrossVul dataset occasionally underperformed,
especially when dealing with complex AI code defects. In contrast, applying chain-of-
thought to fine-tune the model with the CrossVul dataset significantly improved the
model’s ability to detect general code. It also enhanced its detection capability for AI
code to some extent. Adding the manually constructed AI code defect and repair dataset
significantly improved the model’s performance on AI code defect detection and repair
tasks. The model fine-tuned with the combination of CrossVul and AI datasets, along
with chain-of-thought prompts, exhibited the best performance.
5 Conclusion
In this paper, we explored the task of code defect detection and repair using large mod-
els, with a particular focus on the impact of chain-of-thought fine-tuning on model
performance. We constructed a dataset comprising the CrossVul dataset and a manually
created AI code defect and repair dataset, fine-tuning the model with various datasets and
processing methods to evaluate its performance across different contexts. Experimen-
tal results indicate that chain-of-thought fine-tuning significantly improves the model’s
accuracy, repair rate, BLEU score and execution correctness. Models fine-tuned with
the AI code defect and repair dataset demonstrated excellent performance in handling
complex AI code defect detection tasks. In contrast, models trained on the unprocessed
CrossVul dataset and the original baseline model underperformed in some cases, high-
lighting the importance of targeted optimization and dataset construction for specific
scenarios.
During the experiments, we observed that directly fine-tuning with the unprocessed
CrossVul dataset, although yielding good results for some defect types, led to a noticeable
decline in detection performance when encountering new defect types, even falling short
of the baseline model. However, by introducing chain-of-thought prompts, the model
better understood the context and logical relationships in defect repair, enhancing its
generalization ability when faced with different defect types.
Acknowledgments. This work is supported by 3AD303F5 and NSFC (Grant No. 62472047).
References
1. Pistoia, M., Chandra, S., Fink, S.J., et al.: A survey of static analysis methods for identifying
security vulnerabilities in software systems. IBM Syst. J. 46(2), 265–288 (2007)
2. Wei, P.: The safety of software design loophole dynamic state examination technique analysis.
CD Technol. 2009(4), 16–17 (2009)
3. Zhao, Y. Z.: A static analysis-based system for detecting code security vulnerabilities.
University of Electronic Science and Technology of China (2012)
4. Raychev, V., Bielik, P., Vechev, M., et al.: Learning programs from noisy data. ACM Sigplan
Not. 51(1), 761–774 (2016)
5. Chen, M., Tworek, J., Jun, H., et al.: Evaluating large language models trained on code. arXiv
preprint arXiv:2107.03374 (2021)
6. Ribeiro, F., de Macedo, J.N.C., Tsushima, K., et al.: GPT-3-powered type error debugging:
Investigating the use of large language models for code repair. In: Proceedings of the 16th
ACM SIGPLAN International Conference on Software Language Engineering, pp. 111–124
(2023)
7. Keller, J., Nowakowski, J.: AI-powered patching: the future of automated vulnerability fixes.
Technical report (2024)
8. Amershi, S., Begel, A., Bird, C., et al.: Software engineering for machine learning: a case study.
In: IEEE/ACM 41st International Conference on Software Engineering: Software Engineering
in Practice (ICSE-SEIP), pp. 291–300 (2019)
9. Raffel, C., Shazeer, N., Roberts, A., et al.: Exploring the limits of transfer learning with a
unified text-to-text transformer. J. Mach. Learn. Res. 21(140), 1–67 (2020)
10. Hu, E.J., Shen, Y., Wallis, P., et al.: LoRA: low-rank adaptation of large language models.
arXiv preprint arXiv:2106.09685 (2021)
11. Devlin, J., Chang, M.W., Lee, K., et al.: BERT: pre-training of deep bidirectional transformers
for language understanding. arXiv preprint arXiv:1810.04805 (2018)
12. Liu, X., Zheng, Y., Du, Z., et al.: GPT understands, too. AI Open (2023)
13. Wei, J., Wang, X., Schuurmans, D., et al.: Chain-of-thought prompting elicits reasoning in
large language models. Adv. Neural. Inf. Process. Syst. 35, 24824–24837 (2022)
14. Yang, G., Zhou, Y., Chen, X., et al.: Chain-of-thought in neural code generation: from and
for lightweight language models. IEEE Trans. Softw. Eng. (2024)
15. Nikitopoulos, G., Dritsa, K., Louridas, P., et al.: CrossVul: a cross-language vulnerability
dataset with commit data. In: Proceedings of the 29th ACM Joint Meeting on European Soft-
ware Engineering Conference and Symposium on the Foundations of Software Engineering,
pp. 1565–1569 (2021)
Backdoor Attack on Android Malware
Classifiers Based on Genetic Algorithms
Zhenghua Cai1 , Yongji Wang1(B) , Hua Zhang1 , Lei Qiao2 , Huawei Wang2 ,
and Chi Zhang3
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing, China
{wyj2022,zhanghua_288}@bupt.edu.cn
2 Beijing Institute of Control Engineering, Beijing, China
[email protected]
3 Information Center of China, North Industries Group Corporation, Beijing, China
Abstract. The rapid rise of malware challenges traditional detection methods due
to code obfuscation and polymorphism. While machine learning classifiers offer
quick detection and can identify complex malicious features, they are suscepti-
ble to backdoor attacks. We introduce GAT, a genetic algorithm-based approach
for generating effective and stealthy Android backdoors. Using the SHAP inter-
pretability tool, we first select efficient features as primary backdoors. A fitness
function then enables iterative optimization through a genetic algorithm. Addi-
tionally, we propose a method to integrate backdoor features into the source code,
maintaining functionality while facilitating attacks on Android classifiers in real
data outsourcing scenarios. Our evaluation of the Drebin and Mamadroid mal-
ware detectors in data outsourcing scenarios indicates that an attack success rate
exceeding 70% can be achieved with only 5% poisoned samples and a minimal
number of trigger features, while keeping the false positive rate below 10% and
the label flipping rate below 30%. Additionally, the performance degradation of
the classifiers remains within 5%. This work provides new insights into backdoor
attack methodologies in Android malware classifiers.
1 Introduction
The exponential growth of malware, along with techniques like code obfuscation and
polymorphism, has outpaced traditional detection methods. In response, security profes-
sionals are increasingly employing machine learning and deep learning to tackle large-
scale malware detection challenges. ML-based malware classifiers effectively counter
traditional evasion tactics and capture complex malicious features, often serving as the
first line of defense due to their rapid detection capabilities.
However, machine learning-based detection methods have attracted the attention of
attackers, leading to various AI-related security issues. Adversarial attacks can subtly
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 197–214, 2025.
https://doi.org/10.1007/978-981-96-4566-4_14
modify sample features during inference to bypass detection, while poisoning attacks
target the training phase. Since ML-based malware detectors require regular retraining
with samples from third-party sources, attackers can inject malicious samples into the
training set, compromising model decisions. Backdoor attacks, a variant of poisoning,
involve constructing specific triggers that cause the model to misclassify inputs con-
taining the trigger while correctly classifying normal samples. These backdoors can be
categorized into two types: label-flipping backdoors, which alter the original labels, and
clean-label backdoors, where original labels remain unchanged while backdoor features
are embedded into the original code.
Sasaki et al. [1] first introduced a backdoor attack on malware classifiers using reverse
gradient optimization to poison a specific malware family, causing misclassification.
However, their research was limited to feature space experiments and did not address
real-world poisoning scenarios. Li et al. [2] proposed a genetic algorithm-based backdoor
attack that relies on label flipping to create poisoned samples. This approach assumes
that attackers can arbitrarily modify training set labels, which is impractical in real-world
contexts.
Severi et al. [3] were the first to investigate clean-label backdoor attacks on malware
classifiers, using interpretable machine learning to compute feature marginal contribu-
tions for constructing backdoor triggers. They injected these triggers into benign sam-
ples, creating a “channel” that led to misclassification of malware as benign. While their
method achieved a high success rate, it left noticeable traces in the samples, compro-
mising stealth, particularly when using transformed abstract features. Building on this
work, Yang et al. [4] explored the stealthiness of clean-label attacks and proposed an
optimized algorithm targeting specific malware families, achieving better effectiveness
and stealth. However, this backdoor was limited in applicability, only being effective
against certain malware families.
Li et al. [2] classified Android features into static and dynamic categories, mapping
original code files to feature matrices by inserting dummy code. However, this approach
is vulnerable to static analysis, which can prune the backdoor code, rendering it inef-
fective. Severi et al. [3] focused on ensuring the normal functionality of PE software
when inserting backdoor triggers but did not address the preservation of functionality in
Android applications. Yang et al. [4] established a mapping between custom functional
statements and features, enabling rapid mapping from backdoor features to original code.
Nonetheless, this process inevitably introduced redundant code, resulting in the mapping
of extraneous features.
Current backdoor attack algorithms in Android malware classifiers face limitations,
including inapplicability to real attack scenarios and a trade-off between effectiveness
and stealth. Existing backdoor insertion methods are prone to pruning by static analyz-
ers and may introduce redundant code. We propose a clean-label backdoor generation
algorithm based on genetic algorithms to enhance both effectiveness and stealth. Ini-
tially, SHAP machine learning interpretability tools are used to select a set of efficient
features as the initial backdoor. A fitness function is then designed to optimize backdoor
effectiveness and stealth iteratively via the genetic algorithm. Additionally, features are
categorized into XML and DEX types for mapping, with methods defined for feature
deletion and addition. To prevent static analysis pruning, code obfuscation and reflection
techniques are employed. The proposed algorithms were evaluated against two malware
detectors, leading to the following contributions:
(1) A genetic algorithm-based Android backdoor generation method that demonstrates
strong attack efficacy and stealth across three attack scenarios and two malware
classifiers.
(2) An Android backdoor insertion method that successfully integrates backdoor fea-
tures into source code without compromising software functionality, enabling
real-world attacks against Android classifiers.
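The SHAP-based selection of initial trigger features (contribution 1) can be sketched as follows; producing the per-sample SHAP value matrix with an explainer (e.g. a tree explainer over the trained classifier) is assumed and not shown:

```python
import numpy as np

def top_backdoor_features(shap_values: np.ndarray, feature_names: list,
                          k: int = 8) -> list:
    """Rank features by mean absolute SHAP value and return the k most
    influential ones as the initial backdoor trigger.

    `shap_values` has shape (n_samples, n_features); obtaining it from a
    SHAP explainer is outside this sketch.
    """
    importance = np.abs(shap_values).mean(axis=0)  # per-feature contribution
    order = np.argsort(importance)[::-1][:k]
    return [feature_names[i] for i in order]
```

Features with large marginal contributions toward the benign class are efficient trigger candidates, since small changes to them move the classifier's decision the most.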
2 Related Works
2.1 Malware Detection Systems
The earliest backdoor attack method, BadNet, was proposed by Gu et al. [5], followed by
various adaptations (Liu et al. [6]; Yao et al. [7]; Wenger et al. [8]). We focus on clean-
label backdoor attacks, which are covert methods targeting machine learning models.
In this approach, attackers introduce samples with specific backdoor functionality into
the training set, labeling them with correct class labels (i.e., clean labels). In image
recognition, incorrect labels (e.g., bird-cat) may raise suspicion during manual inspec-
tions, prompting Shafahi et al. [9] to propose clean-label poisoning attacks that ensure
input labels remain correct. Inspired by this work, subsequent researchers have explored
additional clean-label attack strategies (Barni et al. [10]; Liu et al. [11]; Saha et al. [12];
Turner et al. [13]). This method is particularly covert in practical applications, as it can
evade manual detection. In the context of malware detection, attackers must conduct
poisoning attacks through online platforms, where multiple detection engines label soft-
ware. Modifying labels without bypassing all engines is unrealistic, leading attackers to
rely on clean-label strategies that align with the labels assigned by the detection engines.
3 Attack Methods
The genetic algorithm undergoes N generations of selection, crossover, and mutation until the fitness
score converges, indicating no improvement over e iterations. The backdoor trigger with
the highest fitness score is then considered optimal.
Selection: The algorithm selects the two backdoor triggers with the highest fitness
scores for crossover and mutation. The fitness function T is defined as the weighted
average of the stealth score S and effectiveness score V (Eq. 5). The attacker seeks
to maintain the original classes of backdoor samples under the normal model while
changing them to the target class in the poisoned model, without compromising the
recognition of normal samples. Stealth Score: This is defined as the number of backdoor
samples Xb that retain their original labels in the normal model (Eq. 3). Effectiveness
Score: This consists of two components: the number of samples whose labels change to
the attacker’s target label under the poisoned model, and the count of normal samples X
that retain their original labels in the poisoned model (Eq. 4).
T = a × S + b × V (a + b = 1) (5)
Crossover: The algorithm randomly selects two backdoor triggers with higher fitness
scores for feature exchange. The number of exchanged features is 20% of the trigger
length, rounded up. If features are identical, only their values are swapped. If different,
both keys and values are exchanged, maintaining the overall length of the backdoor
trigger.
Mutation: The algorithm introduces random perturbations to the backdoor triggers
derived from selection and crossover. To ensure compliance with feature constraints and
maintain the link between backdoor features and benign labels, perturbation magnitudes
are calculated based on the distribution of benign samples in the attack sample set for
each feature.
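The selection, crossover, and mutation steps above can be sketched in Python. This is an illustrative sketch, not the authors' code: triggers are simplified to `{feature: value}` dicts, the feature names and benign-value lists are hypothetical, and the fitness weights follow Eq. (5) with a + b = 1.

```python
import math
import random

def fitness(stealth_score, effectiveness_score, a=0.5, b=0.5):
    """Weighted fitness T = a * S + b * V with a + b = 1 (Eq. 5)."""
    assert abs(a + b - 1.0) < 1e-9
    return a * stealth_score + b * effectiveness_score

def crossover(trigger_a, trigger_b, ratio=0.2, rng=random):
    """Exchange ceil(20% of trigger length) features between two triggers.

    Identical features swap values only; different features exchange both
    key and value, preserving trigger length when sampled keys do not collide.
    """
    a, b = dict(trigger_a), dict(trigger_b)
    n_swap = math.ceil(len(a) * ratio)
    keys_a = rng.sample(sorted(a), n_swap)
    keys_b = rng.sample(sorted(b), n_swap)
    for ka, kb in zip(keys_a, keys_b):
        if ka == kb:
            a[ka], b[kb] = b[kb], a[ka]      # same feature: swap values
        else:
            a[kb], b[ka] = b.pop(kb), a.pop(ka)  # exchange key and value
    return a, b

def mutate(trigger, benign_values, rng=random):
    """Perturb one feature toward values observed in benign samples,
    keeping the link between backdoor features and benign labels."""
    t = dict(trigger)
    key = rng.choice(sorted(t))
    t[key] = rng.choice(benign_values[key])
    return t
```

In each generation, the two highest-fitness triggers would be selected, crossed over, mutated, and re-scored until the fitness score stops improving.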
Fig. 5. Code Obfuscation Using try-catch Statements to Add Invalid Function Calls
The second method involves defining empty functions. As shown in Fig. 6, an empty
function funA is defined under Lcom/google/Myclass, while function funB is defined
Backdoor Attack on Android Malware Classifiers Based on Genetic Algorithms 205
under Lcom/malware/Myclass to call funA. By invoking funB through the startup func-
tion, this method not only adds a function call feature from the malware package to the
Google package but also enhances the function presence feature.
Fig. 6. Defining Empty Functions and Invoking Them to Add Invalid Function Calls
We define API reduction as Dec(x, t, d), indicating that function d calls function
t within Smali file x. We use regular expressions to locate the call within the Smali file,
and then employ reflection techniques for code obfuscation to achieve the API reduction.
As illustrated in Fig. 7, function funA under Lcom/google/Myclass calls function funB
under Lcom/malware/Myclass via reflection, so that static analyzers cannot detect
that funA has called funB.
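The effect of the reflective call can be illustrated with a Python analogy (not the Smali transformation itself): a direct call names its target in the code, whereas a reflective call resolves the target from a string assembled at runtime, so a simple static scan of call sites never sees the callee's name.

```python
def funB():
    # Stand-in for the callee (funB under Lcom/malware/Myclass in Fig. 7).
    return "funB ran"

def funA_direct():
    # Direct call: a static analyzer scanning call sites sees funA -> funB.
    return funB()

def funA_reflective():
    # Reflective call in the spirit of Java's
    # Class.forName(...).getMethod(...).invoke(...): the target is looked up
    # from a runtime-built string, so no direct call site names funB.
    target = globals()["fun" + "B"]
    return target()
```

Both paths execute the same code, but only the direct version exposes the call edge to static analysis.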
4 Experiments
4.1 Dataset
The dataset used in this paper includes malicious samples from Drebin [14] and benign
and malicious applications collected from Android Zoo. A total of 27,445 benign appli-
cations and 23,839 malicious applications were collected. We randomly selected 5,489
benign applications and 4,768 malicious applications as the test dataset, while the
remaining applications were used for training.
Target Attack System   Attack Model   Feature Set   F1 Score   FN Rate   FP Rate   Accuracy
Drebin                 NN             ALL           98.48%     2.50%     0.50%     98.50%
Drebin                 SVM            ALL           97.98%     3%        1%        98%
Drebin                 RF             ALL           98.25%     2%        1.50%     98.25%
Mamadroid              NN             Family        92.60%     12.50%    1.50%     93%
Mamadroid              SVM            Family        92.06%     13%       2%        92.50%
Mamadroid              RF             Family        92.40%     12%       2.50%     92.75%
Mamadroid              NN             Package       96.60%     10.25%    1.50%     96.30%
Mamadroid              SVM            Package       95.06%     11%       2.4%      95.25%
Mamadroid              RF             Package       95.40%     13%       2.25%     95.74%
Acc(F, Xb): the accuracy of the original model on malicious samples carrying the backdoor
(equivalently, one minus the label flip rate). A low flip rate signifies minimal disruption
to detection outcomes, which is desirable for maintaining stealth.
Table 3. Comparison of ASR Across Three Attack Scenarios with Trigger Size of 30 and Attack
Targets Drebin, Mamadroid Family, and Mamadroid Package Granularity.
Attack Scenario Model Outsourcing Ideal Data Outsourcing Real-World Data Outsourcing
Poisoned percentage 0.5 0.6 0.7 0.8 0.005 0.01 0.05 0.1 0.005 0.01 0.05 0.1
GS-Drebin 0.624 0.633 0.685 0.746 0.484 0.484 0.540 0.572 0.141 0.163 0.427 0.570
LC-Drebin 0.683 0.698 0.719 0.760 0.365 0.424 0.559 0.603 0.147 0.180 0.348 0.415
LM-Drebin 0.540 0.546 0.566 0.626 0.244 0.313 0.413 0.451 0.156 0.273 0.466 0.554
GAT-Drebin 0.697 0.708 0.747 0.797 0.526 0.526 0.552 0.622 0.211 0.232 0.462 0.592
GS-Mamadroid-F 0.883 0.898 0.919 0.96 0.865 0.875 0.889 0.904 0.747 0.78 0.849 0.865
LC-Mamadroid-F 0.834 0.853 0.905 0.936 0.704 0.744 0.81 0.872 0.644 0.664 0.678 0.71
LM-Mamadroid-F 0.84 0.846 0.906 0.926 0.744 0.763 0.813 0.891 0.656 0.673 0.746 0.794
GAT-Mamadroid-F 0.897 0.909 0.948 0.998 0.888 0.907 0.932 0.983 0.813 0.831 0.863 0.893
GS-Mamadroid-P 0.856 0.866 0.879 0.894 0.755 0.783 0.842 0.864 0.753 0.782 0.844 0.858
LC-Mamadroid-P 0.694 0.734 0.800 0.862 0.646 0.667 0.687 0.707 0.639 0.663 0.681 0.701
LM-Mamadroid-P 0.734 0.753 0.803 0.871 0.655 0.673 0.742 0.797 0.646 0.673 0.743 0.795
GAT-Mamadroid-P 0.878 0.897 0.922 0.973 0.814 0.833 0.864 0.893 0.809 0.815 0.855 0.876
success rates for Drebin and Mamadroid are 70% and 99%, respectively, highlighting
the GAT algorithm’s efficacy against neural networks.
Fig. 8. Attack Success Rates of the GAT Algorithm on Three Classifiers Under Data Outsourcing
Scenarios with Varying Trigger Sizes
Table 4. Comparison of ASR, FPb , Acc(Fb , X), and Acc(F, Xb ) between the GAT Algorithm and
the EXP Algorithm under the Drebin Classification System
optimize feature selection for backdoor attacks. While this study employs SHAP, we
explore LIME as an alternative to assess differences in attack speed and effectiveness, as
illustrated in Table 5. LIME, which uses Gaussian perturbations and a local linear model
to identify significant features, demonstrates a faster BST compared to SHAP. However,
SHAP achieves higher ASR due to its capability to capture feature interactions based
on Shapley values. When applying SHAP to the Mamadroid Package feature set, we
first reduce dimensionality from 197,136 to 500 using SVM, enabling efficient feature
extraction within a reasonable timeframe. This approach increases ASR by 16.74%
compared to LIME, while keeping the time overhead for SHAP within 20 s of LIME.
Overall, despite its longer processing time, SHAP’s accuracy in feature importance
calculation makes it preferable for initial backdoor generation.
Table 5. Comparison of ASR and BST of Different Model Interpretation Tools in Data
Outsourcing Scenarios
Fig. 9. Validation of the Effectiveness of Genetic Algorithm Optimization for Initial Backdoor
Generation
5 Conclusion
Acknowledgments. This work is supported by 3AD303F5 and NSFC (Grant No. 62472047).
References
1. Narisada, S., Sasaki, S., Hidano, S., et al.: Stronger targeted poisoning attacks against malware
detection. In: Krenn, S., Shulman, H., Vaudenay, S. (eds.) Cryptology and Network Security.
CANS 2020. Lecture Notes in Computer Science, vol. 12579, pp. 65–84. Springer, Cham
(2020). https://doi.org/10.1007/978-3-030-65411-5_4
2. Li, C., Chen, X., Wang, D., et al.: Backdoor attack on machine learning based android malware
detectors. IEEE Trans. Depend. Secure Comput. 19(5), 3357–3370 (2021)
3. Severi, G., Meyer, J., Coull, S., et al.: Explanation-Guided backdoor poisoning attacks against
malware classifiers. In: 30th USENIX Security Symposium (USENIX Security 21), pp. 1487–
1504 (2021)
4. Yang, L., Chen, Z., Cortellazzi, J., et al.: Jigsaw puzzle: selective backdoor attack to subvert
malware classifiers. In: 2023 IEEE Symposium on Security and Privacy (SP). IEEE, pp. 719–
736 (2023)
5. Gu, T., Liu, K., Dolan-Gavitt, B., et al.: BadNets: evaluating backdooring attacks on deep
neural networks. IEEE Access 7, 47230–47244 (2019)
6. Liu, Y., Ma, S., Aafer, Y., et al.: Trojaning attack on neural networks. In: 25th Annual Network
and Distributed System Security Symposium (NDSS 2018). Internet Soc (2018)
7. Yao, Y., Li, H., Zheng, H., et al.: Latent backdoor attacks on deep neural networks. In: Pro-
ceedings of the 2019 ACM SIGSAC Conference on Computer and Communications Security,
pp. 2041-2055 (2019)
8. Wenger, E., Passananti, J., Bhagoji, A.N., et al.: Backdoor attacks against deep learning
systems in the physical world. In: Proceedings of the IEEE/CVF Conference on Computer
Vision and Pattern Recognition, pp. 6206–6215 (2021)
9. Shafahi, A., et al.: Poison frogs! targeted clean-label poisoning attacks on neural networks.
CoRR (2018)
10. Barni, M., Kallas, K., Tondi, B.: A new backdoor attack in CNNS by training set corruption
without label poisoning. In: Proceedings of ICIP (2019)
11. Shapira, T., Berend, D., Rosenberg, I., Liu, Y., Shabtai, A., Elovici, Y.: Being single has
benefits. Instance poisoning to deceive malware classifiers. CoRR (2020)
12. Saha, A., Subramanya, A., Pirsiavash, H.: Hidden trigger backdoor attacks. In: Proceedings
of AAAI (2020)
13. Turner, A., Tsipras, D., Madry, A.: Label-consistent backdoor attacks. CoRR (2019)
14. Arp, D., Spreitzenbarth, M., Hubner, M., et al.: DREBIN: effective and explainable detection
of android malware in your pocket. NDSS 14, 23–26 (2014)
15. Mariconti, E., Onwuzurike, L., Andriotis, P., et al.: Mamadroid: detecting android malware
by building markov chains of behavioral models. arXiv preprint arXiv:1612.04433 (2016)
A Malicious Websites Classifier Based
on an Improved Relation Network
Qianshi Wang1 , Chongjun Xu2(B) , Huayu Yang2 , Xilin Zhai1 , and Hua Zhang1
1 State Key Laboratory of Networking and Switching Technology, Beijing University of Posts
and Telecommunications, Beijing 100876, China
{scrassy_1047,zhanghua_288}@bupt.edu.cn
2 Zhejiang Branch of National Computer Network Emergency Response Technical
Team/Coordination Center of China, Hangzhou, China
[email protected]
1 Introduction
As new energy vehicles continue to integrate various features, intelligent in-car systems
are gradually becoming the core of vehicle control, information interaction, and in-car
entertainment, which brings serious cybersecurity challenges to the Internet of Vehicles
(IoV) environment. Web platforms integrated into in-car systems face the threat of malicious
websites. Among these, pornography and gambling websites make up a significant portion,
and counterfeit websites use fake pages to deceive users, potentially causing substantial
financial losses. Therefore, identifying and classifying malicious websites in in-car
systems is of great importance for protecting the security of the IoV.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 215–230, 2025.
https://doi.org/10.1007/978-981-96-4566-4_15
216 Q. Wang et al.
Traditional models for detecting malicious websites are mostly trained on datasets of
features such as website traffic or content. Zhao et al. [1] classified multi-level features
of malicious websites based on their URL and content features, designing a BERT-ResNet-based
classification model that achieves higher accuracy in detecting malicious websites. Afzal
et al. [2] proposed URLdeepDetect, a deep learning model based on semantic embeddings that
recognizes malicious websites by analyzing the semantic information in their URLs and
performs well in detection.
The methods above require a large number of website samples to train models; however,
it is hard to collect a considerable number of malicious website images in the IoV
environment when researching how to detect malicious websites from website images.
Without enough samples, models may overfit, as they cannot learn a stable and accurate
decision boundary in few-shot scenarios. To address the problem of insufficient image
samples, existing research has proposed several approaches. Koch et al. [3] proposed a
Siamese network based on metric learning, which employs weight sharing to compute the
similarity of images and then outputs image categories according to the similarity
between samples. Flood Sung et al. [4] developed a Relation Network with an embedding
module and a relation module that classifies categories by calculating the similarity
between sample images, and their approach demonstrates good performance in few-shot
experiments. Although these methods show great performance in few-shot image
classification tasks, their models are still single-network structures, which may
struggle to extract multi-level and fine-grained features from complex images such as
malicious website screenshots, resulting in poor performance and even misjudgment.
To address the issues of insufficient samples and incomplete feature extraction, this
paper proposes CarMaNet, a few-shot classification model for malicious websites based
on an improved Relation Network. Building upon the ability of few-shot models to complete
classification tasks with little training data, we further introduce the Inception
block [5] into the traditional Relation Network and achieve malicious website
classification based on an attention mechanism in few-shot scenarios. The specific
contributions are as follows.
(1) We propose a few-shot model called CarMaNet to classify malicious websites based
on an improved Relation Network. We bring an attention mechanism into CarMaNet by
inserting a parallel SE module into each residual block and combining the outputs of
the convolutional layer with those of the SE module. We call this process FCAA; it
enables CarMaNet to better learn multi-level features of malicious websites in few-shot
scenarios.
(2) We design and implement a multi-scale feature embedding module, which contains
an Inception structure. This makes it easier for the model to learn fine-grained fea-
tures of malicious website images, improving the model’s ability to capture features
in complex image samples.
(3) Experimental results show that CarMaNet demonstrates better performance in fea-
ture extraction and classification on few-shot datasets compared with traditional
models, significantly improving the accuracy of malicious website classification.
A Malicious Websites Classifier Based on an Improved Relation Network 217
2 Related Work
3 Our Model
In this paper, we propose CarMaNet, a classification model for malicious websites based
on an improved Relation Network. To better capture the multi-level features of malicious
websites and enhance the model's ability to learn and represent fine-grained image
features, we add an attention mechanism module to the feature embedding module of the
traditional Relation Network [14], enabling CarMaNet to relate different images
according to image similarity. The overall structure of CarMaNet mainly includes three
modules: embedding module, attention mechanism module and relation module. The
structure of CarMaNet is shown in Fig. 1.
CarMaNet first utilizes a multi-scale embedding module to extract image features from the
support set and the query set. In the embedding module, part of the convolutional blocks
are replaced by Inception blocks, enabling the module to capture image features at
different scales and achieve deeper and broader abstraction of these features. Image
features are finally converted to embeddings via convolutional blocks and a pooling layer.
At the same time, we insert an improved SENet [15] attention mechanism module
between the embedding module and the relation module to fully utilize the relations
between modules and enhance the representation ability of CarMaNet. This attention
mechanism module adds a parallel SE module to each residual block and combines
the outputs of the conventional convolutional module and the SE block, providing more
accurate and varied representations for subsequent metric computations.
The relation module contains three convolutional blocks and two fully connected layers;
we replace the original activation function and loss function to accommodate the embedding
module and the SENet attention mechanism module, which further optimizes the training of
CarMaNet. Finally, the relation score is output via a scaling function, and the category
with the highest score is chosen as the final prediction.
In CarMaNet, we use the Inception block to enhance the embedding module. The
Inception block can capture multi-level image features, focusing not only on the overall
structure of the input sample but also on fine-grained micro features. This characteristic
gives the Inception block significant advantages when processing malicious website images.
Similar to the first branch, the second branch first applies a 1 × 1 convolution for
dimensionality reduction, then uses a 3 × 3 convolutional kernel to widen the receptive
field when extracting spatial features.
Building upon the second branch, we replace the 3 × 3 convolutional kernel with a
5 × 5 kernel in the third branch, along with a 1 × 1 kernel to reduce the computational
cost of the 5 × 5 kernel, enabling the branch to capture broader spatial features at a lower
cost in large-scale images.
The fourth branch employs 3 × 3 max-pooling to extract salient regional features and
a 1 × 1 convolution to reduce the dimension, enhancing the model's spatial invariance to
transformations of the input image samples and preparing features for subsequent
layers.
Finally, the outputs of these four branches are merged. This four-branch structure
enables the embedding module to learn and utilize features from different dimensions
without increasing the computational cost, thereby improving the model’s representa-
tion ability and performance. The combined feature map contains information from all
branches, allowing the network to further process these multi-scale features in subsequent
modules.
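The four-branch merge can be sketched at the shape level in NumPy. This is a flow sketch, not the actual layer: the per-branch channel counts are hypothetical, and the 3 × 3 / 5 × 5 spatial convolutions and pooling are elided into a 1 × 1-style channel-mixing stand-in so that only the concatenation structure is shown.

```python
import numpy as np

def branch(x, out_channels):
    """Stand-in for a conv/pool branch that preserves spatial size (H, W)."""
    c, h, w = x.shape
    weights = np.random.randn(out_channels, c) * 0.01  # 1x1-conv-like mixing
    return np.einsum("oc,chw->ohw", weights, x)

def inception_block(x):
    b1 = branch(x, 16)   # branch 1: 1x1 conv
    b2 = branch(x, 24)   # branch 2: 1x1 reduce -> 3x3 conv (spatial part elided)
    b3 = branch(x, 8)    # branch 3: 1x1 reduce -> 5x5 conv (spatial part elided)
    b4 = branch(x, 8)    # branch 4: 3x3 max-pool -> 1x1 conv (pool elided)
    # Merge: concatenate along the channel axis, as in a standard Inception block.
    return np.concatenate([b1, b2, b3, b4], axis=0)

x = np.random.randn(32, 28, 28)
out = inception_block(x)
print(out.shape)   # (56, 28, 28): multi-scale features merged by channel
```

The concatenated map carries all four branches' channels forward, which is exactly what lets subsequent modules process the multi-scale features together.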
The SE module extracts global spatial information through global average pooling
(GAP), denoted as S = GAP(F(X)), then learns the weight relationships between channels
through a fully connected (FC) layer and an activation function (such as Sigmoid). Through
this process, the weight of each channel is recalibrated to enhance useful features while
suppressing irrelevant ones. This process can be denoted as Eq. (1).
The original output of the residual block can be denoted as Eq. (3).
Y = F(X) + X (3)
The output of the improved SENet can be expressed as Eq. (4) and Eq. (5).
Y = H(Conv_{1×1}(F(X) ⊕ SE(F(X)))) + X (4)
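The squeeze-excite-recalibrate path and the residual combination can be sketched in NumPy. This is an illustrative sketch under stated simplifications: the FC weights are random stand-ins for learned parameters, the ⊕ combination of F(X) with SE(F(X)) is modeled as addition, and the H(Conv_{1×1}(·)) fusion of Eq. (4) is omitted.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def se_module(f, w_fc):
    # Squeeze: global average pooling, S = GAP(F(X)) -- one scalar per channel.
    s = f.mean(axis=(1, 2))
    # Excitation: FC layer + sigmoid yields per-channel weights in (0, 1).
    weights = sigmoid(w_fc @ s)
    # Recalibrate: scale each channel of F(X) by its learned weight.
    return f * weights[:, None, None]

rng = np.random.default_rng(0)
c, h, w = 8, 4, 4
x = rng.standard_normal((c, h, w))       # identity shortcut X
f_x = rng.standard_normal((c, h, w))     # stand-in for conv output F(X)
w_fc = rng.standard_normal((c, c))       # random stand-in for FC weights

# Improved residual output in the spirit of Eq. (4): combine F(X) with its
# SE-recalibrated version, then add the shortcut X.
y = (f_x + se_module(f_x, w_fc)) + x
print(y.shape)   # (8, 4, 4)
```

Because the sigmoid weights lie strictly in (0, 1), the SE path can only attenuate channels, never amplify them; the parallel combination with the unscaled F(X) is what preserves the original signal.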
ϕ, φ ← argmin_{ϕ,φ} Σ_{i=1}^{m} Σ_{j=1}^{n} (r_{i,j} − I(y_i = y_j))² (7)
In Eq. 6, r_{i,j} is the final relation score between a sample in the support set and a
sample in the query set, where x_i is the sample from the support set and x_j is the sample
from the query set. i indexes the sample categories, m is the number of branches of the
embedding networks, and f(x_i)/f(x_j) are the feature embeddings of samples from
the support/query set. C represents the feature-merging process, which integrates
the embeddings to provide a combined feature representation for calculating the relation
score. H_w represents the attention-based weighting process, and g_φ is the
mapping of the relation module.
In Eq. 7, I is an indicator function: if the support set sample and the query set sample
belong to different categories, the value of I is 0; otherwise it is 1. The argmin adjusts
the parameters of the module based on the MSE. Specific modifications to the activation
function and loss function are as follows.
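The objective of Eq. (7) can be checked numerically with a small sketch. The relation scores below are illustrative values, not model outputs, and `mse_objective` is a hypothetical helper, not the authors' code.

```python
import numpy as np

def mse_objective(scores, support_labels, query_labels):
    """Sum over i, j of (r_ij - I(y_i = y_j))^2, following Eq. (7)."""
    # Indicator I(y_i = y_j): 1 where the support and query labels match.
    indicator = (np.asarray(support_labels)[:, None] ==
                 np.asarray(query_labels)[None, :]).astype(float)
    return float(((scores - indicator) ** 2).sum()), indicator

scores = np.array([[0.9, 0.2],
                   [0.1, 0.8]])   # illustrative relation scores r_ij
loss, ind = mse_objective(scores, support_labels=[0, 1], query_labels=[0, 1])
print(ind)    # ones on the diagonal: matching categories
print(loss)   # (0.1^2 + 0.2^2 + 0.1^2 + 0.2^2), approximately 0.1
```

Minimizing this loss pushes scores of same-category pairs toward 1 and all other pairs toward 0, which is exactly the regression target the relation module is trained on.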
GSPD dataset: the GSPD dataset contains three main categories of malicious websites
(pornography, gambling, and phishing), spanning a total of 16 subcategories. Each
subcategory includes 100 screenshots of malicious websites at a resolution of
600 × 600 pixels.
GWDT dataset: GWDT dataset contains 20 categories of malicious websites,
with approximately 130 samples per category. The image resolution is 256 × 256
pixels.
AGPD dataset: AGPD dataset includes 18 categories of malicious websites, each
with 200 samples, and the image resolution is 600 × 600 pixels.
(2) Baseline Models
In this section, we compare CarMaNet with four baseline models: Relation Net,
MAML (Model-Agnostic Meta-Learning) [15], Prototypical Nets [17], and Matching
Nets [18]. The baseline models are introduced as follows.
Relation Network consists of an embedding module and a relation module. It first
combines the embedded forms of the input classes into a feature map of all classes, then
classifies samples by calculating relation scores according to the feature map.
MAML (Model-Agnostic Meta-Learning) is a model-agnostic meta-learning
algorithm. By training the model’s initial parameters on a set of tasks, MAML can
quickly learn from a small amount of new data. MAML achieves good generalization
performance with only a few gradient updates to the model parameters.
Prototypical Nets learns a mapping function that maps input samples to prototypes
in a vector space, then classifies new samples according to their distances to these
prototypes.
Matching Nets maps input samples into a vector space using neural networks and
calculates a similarity score between input samples and support set samples using an
attention-based weighted computation. This similarity score is then used to classify
the input samples.
(3) Hyperparameters
In our experiments, we set the initial learning rate to 0.001, using Adam as the
optimizer. We set the gamma decay factor to 0.5, and the number of epochs to 1000.
Experiments are conducted in both 5-way 1-shot and 5-way 5-shot scenarios.
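The 5-way 1-shot / 5-shot protocol can be made concrete with a small episode sampler. This is a generic sketch, not the authors' pipeline; the dataset layout and the query count per class are hypothetical.

```python
import random

def sample_episode(dataset, n_way=5, k_shot=1, n_query=3, rng=random):
    """Draw one few-shot episode from a {class_name: [image_ids]} mapping."""
    classes = rng.sample(sorted(dataset), n_way)   # pick N classes
    support, query = [], []
    for label, cls in enumerate(classes):
        picked = rng.sample(dataset[cls], k_shot + n_query)
        support += [(img, label) for img in picked[:k_shot]]  # K per class
        query += [(img, label) for img in picked[k_shot:]]    # Q per class
    return support, query

# Toy dataset: 16 classes with 10 image ids each (mirroring GSPD's 16
# subcategories in shape only).
toy = {f"class_{i}": [f"img_{i}_{j}" for j in range(10)] for i in range(16)}
s, q = sample_episode(toy, n_way=5, k_shot=1, n_query=3, rng=random.Random(0))
print(len(s), len(q))   # 5 support and 15 query samples per episode
```

Each training epoch would then iterate over many such episodes, computing relation scores between every support-query pair.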
Model              5-way 1-shot     5-way 5-shot
MAML               61.25 ± 1.33%    75.09 ± 1.74%
Prototypical Nets  64.74 ± 0.42%    72.83 ± 1.21%
Matching Nets      48.56 ± 1.32%    62.73 ± 0.68%
Relation Net       64.33 ± 0.87%    81.37 ± 1.26%
CarMaNet           67.95 ± 0.73%    89.69 ± 0.98%
Model              5-way 1-shot     5-way 5-shot
MAML               58.47 ± 1.29%    73.66 ± 0.62%
Prototypical Nets  62.36 ± 0.28%    80.73 ± 1.11%
Matching Nets      61.35 ± 1.26%    78.12 ± 0.33%
Relation Net       65.80 ± 1.59%    78.99 ± 0.56%
CarMaNet           72.89 ± 1.30%    92.94 ± 1.55%
Model              5-way 1-shot     5-way 5-shot
MAML               56.66 ± 1.30%    79.32 ± 1.66%
Prototypical Nets  59.81 ± 1.27%    74.75 ± 1.31%
Matching Nets      57.79 ± 0.79%    71.37 ± 0.75%
Relation Net       62.84 ± 0.92%    75.83 ± 0.70%
CarMaNet           68.91 ± 0.81%    88.37 ± 0.94%
From Fig. 5 we can conclude that under the 5-way 1-shot condition, CarMaNet and the
Relation Network have similar accuracy in the early stage of training. After 230 epochs,
CarMaNet gradually converges and becomes more stable, outperforming the traditional
Relation Network model in accuracy and fitting efficiency.
From Fig. 6 we can conclude that CarMaNet's behavior is similar to that in the 5-way
1-shot setting. Under the 5-way 5-shot condition, CarMaNet performs similarly to the
original model around the 100th epoch and continues to improve during subsequent
training. Around the 230th epoch, CarMaNet converges and stabilizes, outperforming
the traditional Relation Network model in accuracy and fitting efficiency.
Experimental results show that under both 5-way 1-shot and 5-way 5-shot conditions,
CarMaNet reaches higher classification accuracy while maintaining the same fitting
speed as Relation Net. This indicates that CarMaNet performs better in few-shot tasks
with limited training data.
Analyzing Table 4, we find that under the 5-way 1-shot / 5-way 5-shot conditions, the
accuracy of CarMaNet increases by 1.41%, 0.89%, 0.71% and 1.14%, 0.37%, 0.78% compared
with the traditional Relation Network when we only activate the multi-scale embedding
module (Imp_RN). When we further activate the FCAA module and set different compression
ratios r, CarMaNet outperforms the Relation Network in accuracy, improving by 3.62%,
7.09%, 6.07% and 8.32%, 13.95%, 12.54% in the 5-way 1-shot and 5-way 5-shot environments.
This demonstrates the effectiveness of our multi-scale embedding module and the FCAA
module.
To further validate the superiority of FCAA, we also adopt CBAM [19], ECANet [20],
and SENet as the attention module of our model. Using the average of 5 experiments,
the results are shown in Table 5.
From Table 5 we find that the model with the FCAA designed in this paper achieves
higher accuracy on the GSPD, GWDT, and AGPD datasets compared with models using
traditional attention mechanisms such as SENet, ECANet, or CBAM. In the 5-way 1-shot
scenario on the GSPD dataset, the model with FCAA achieves an accuracy of 67.95%,
which is 1.07% higher than the model with CBAM. In the 5-way 5-shot scenario,
the classification accuracy of the model with FCAA on the GSPD dataset is 89.69%,
surpassing CBAM by 4.06%. On the GWDT and AGPD datasets, the model with FCAA
outperforms CBAM by 15.61% and 10.64% in the two scenarios, respectively. In the 5-way
1-shot scenario on the GWDT dataset, the model with FCAA reaches an accuracy of
92.94%, while the model with SENet reaches 81.63%, making the model with FCAA
approximately 11.31% more accurate. In summary, the FCAA attention mechanism
demonstrates superior performance in classifying malicious websites, and owing to its
few-shot learning capability, CarMaNet can better learn the features of malicious
website images in IoV environments with insufficient samples. This enables CarMaNet
to exhibit good transferability across other types of malicious website datasets.
5 Conclusions
In this paper, we mainly focus on issues in in-car intelligent systems such as the lack of
samples and insufficient feature extraction. To address these issues, we propose
CarMaNet, a few-shot classification model for malicious websites based on an improved
Relation Network. We introduce the Inception block into the traditional Relation
Network to increase its width, allowing it to better capture multi-scale features. Second,
Funding. This paper is supported by the Key R&D Programme of Zhejiang Province,
No. 2024C01012.
Disclosure of Interests. The authors declare that they have no known competing financial
interests or personal relationships that could have appeared to influence the work reported
in this paper.
References
1. Zhao, C.: Research on malicious website identification method integrating URL and page
information. Jiangsu Univ. Sci. Technol. (2022). https://doi.org/10.27171/d.cnki.ghdcc.2022.
000691
2. Afzal, S., Asim, M., Javed, A.R., Beg, M.O., Baker, T.: URLdeepDetect: a deep learning
approach for detecting Malicious URLs using semantic vector models. J. Netw. Syst. Manage.
29(3) (2021)
3. Koch, G., Zemel, R., Salakhutdinov, R.: Siamese neural networks for one-shot image
recognition. In: ICML Deep Learning Workshop, vol. 2, no. 1 (2015)
4. Sung, F., Yang, Y., Zhang, L., et al.: Learning to compare: relation network for few-shot
learning. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 1199–1208
(2018)
5. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: IEEE Conference on Computer
Vision and Pattern Recognition, pp. 7132–7141 (2018)
6. VirusTotal: VirusTotal [EB/OL]. 01-06-2004. https://www.virustotal.com/gui/home/url.
Accessed 10 Nov 2022
7. Tencent URL Security Detection Center: Tencent URL Security Detection Center [EB/OL].
(2023–01–10). https://urlsec.qq.com/check.html. Accessed 1 Nov 2022
8. Norton Safe Web: Norton Safe Web [EB/OL]. 20-12-2023. https://safeweb.norton.com/.
Accessed 20 Nov 2023
9. Chen, Y., Zheng, R., Zhou, A., et al.: Automatic detection of pornographic and gambling
websites based on visual and textual content using a decision mechanism. Sensors 20(14),
3989 (2020)
10. Wang, C., Zhang, M., Shi, F., et al.: A hybrid multimodal data fusion-based method for
identifying gambling websites. Electronics 11(16), 2489 (2022)
11. Siddiq, M.A.A., Arifuzzaman, M., Islam, M.S.: Phishing website detection using deep
learning. In: 2nd International Conference on Computing Advancements, pp. 83–88 (2022)
12. Zhang, Y., Fu, X., Yang, R., et al.: DRSDetector: detecting gambling websites by multi-level
feature fusion. In: IEEE Symposium on Computers and Communications (ISCC), pp. 1441–
1447. IEEE (2023)
13. Naru, P., Chinthala, S.K.R., Sekhar, P.G., et al.: Detection of fake websites using machine
learning techniques. In: 3rd International Conference on Smart Data Intelligence (ICSMDI),
pp. 477–482. IEEE (2023)
14. Chiramdasu, R., Srivastava, G., Bhattacharya, S., et al.: Malicious URL detection using
logistic regression. In: IEEE International Conference on Omni-Layer Intelligent Systems
(COINS), pp. 1–6. IEEE (2021)
15. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep
networks. In: International Conference on Machine Learning, PMLR, pp. 1126–1135 (2017)
16. Liu, C., Yu, S., Yu, M., et al.: Adaptive smooth L1 loss: a better way to regress scene texts
with extreme aspect ratios. In: IEEE Symposium on Computers and Communications (ISCC),
pp. 1–7. IEEE (2021)
17. Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: Advances
in Neural Information Processing Systems, vol. 30 (2017)
18. Vinyals, O., Blundell, C., Lillicrap, T., et al.: Matching networks for one-shot learning. In:
Advances in Neural Information Processing Systems, vol. 29 (2016)
19. Woo, S., Park, J., Lee, J.Y., et al.: CBAM: convolutional block attention module. In: Ferrari,
V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) Computer Vision – ECCV 2018. ECCV
2018. Lecture Notes in Computer Science(), vol. 11211, pp. 3–19. Springer, Cham (2018).
https://doi.org/10.1007/978-3-030-01234-2_1
20. Wang, Q., Wu, B., Zhu, P., et al.: ECA-Net: efficient channel attention for deep convolutional
neural networks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition
(CVPR), pp. 11534–11542 (2020)
Unknown Category Malicious Traffic Detection
Based on Contrastive Learning
1 Introduction
With the increasing number of electronic devices and the growing complexity of network
environments, cybersecurity issues have become increasingly prominent, resulting in
significant losses for the network economy. To effectively prevent cyberattacks, the con-
struction of Intrusion Detection Systems (IDS) has become one of the primary methods.
Existing IDS can be broadly categorized into host-based and network-based systems:
the former primarily detects intrusions by monitoring log information, while the lat-
ter analyzes network traffic to determine the presence of intrusions. Although machine
learning (ML)-based IDS are widely applied, they face limitations in extracting deep
features and struggle to cope with increasingly sophisticated attack methods. Deep learn-
ing (DL), with its powerful feature extraction capabilities, has gradually become integral
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 231–245, 2025.
https://doi.org/10.1007/978-981-96-4566-4_16
232 L. Yan et al.
2 Related Work
2.1 Deep Learning-Based Intrusion Detection Methods
Reference [4] proposed a method based on feature selection and Conditional Wasser-
stein GAN (FCWGAN) combined with BiLSTM, which effectively improves the clas-
sification accuracy of network intrusion detection models and significantly reduces false
positives and false negatives. Reference [5] introduced an intrusion detection method
based on VAE-CWGAN and feature importance fusion, which significantly improves
model detection performance. Reference [6] proposed a Few-Shot Learning model based
on a Siamese Convolutional Neural Network (FSL-SCNN) that optimizes feature rep-
resentation and designs three specific loss functions to improve detection accuracy.
Reference [7] developed a multi-module integrated intrusion detection system named
GMM-WGAN, which combines a clustering algorithm based on the Gaussian Mix-
ture Model and Wasserstein GAN for handling imbalance, followed by a classification
module design based on CNN and LSTM. This system effectively enhances intrusion
detection performance on the NSL-KDD and UNSW-NB15 datasets.
However, the above methods have not fully addressed the issue of class imbalance
in the datasets, and deep learning models require large amounts of data for training
to achieve high accuracy. Consequently, some types of attacks are not easily learned
by deep learning models, resulting in lower detection performance. Therefore, it is
necessary to address the class imbalance problem in the datasets to further improve
detection performance.
Contrastive learning has been successfully applied in fields such as image recognition,
natural language processing, and video understanding [8], outperforming conventional
methods. Unlike traditional deep learning approaches, contrastive learning focuses on
learning feature representations that better reflect the semantic relationships between
samples. Typical studies like Hjelm et al. [9] and Bachman et al. [10] constructed con-
trastive prediction tasks based on global-local relations. Van den Oord et al. [11] and Hénaff
[12] took the context relationship as the prediction task and proposed the contrastive
predictive coding model, which uses an autoregressive model to predict future time steps
as its pretext task and thereby learns a latent representation of the data.
Contrastive methods consist of two steps: contrastive task construction and con-
trastive learning. In the contrastive task construction, the semantic relationships between
samples are predefined. To better capture these semantic relationships, it is often neces-
sary to construct contrastive learning task pairs using various data augmentation samples
[13–15] and optimize the contrastive loss. The objective is to optimize the embedding
function based on the contrastive loss, bringing samples closer to others of the same class
while pushing them farther from samples of different classes. Consequently, contrastive
learning can reinforce the differences between traffic from different classes, thereby
enhancing the detection model’s sensitivity to abnormal traffic.
3 Method
The paper proposes a contrastive learning-based method for detecting unknown mali-
cious traffic, aiming to address issues such as class imbalance in malicious traffic and
unclear boundaries between traffic categories. The method effectively detects unknown
types of malicious traffic and consists of two main components: contrastive learning
pre-training and fine-tuning the model classifier. The overall framework of the method
is illustrated in Fig. 1.
This paper introduces contrastive learning to enhance the discriminative ability between
different categories of traffic, enabling the model to better learn the differences between
various types of traffic and thereby improving the accuracy and robustness of the intrusion
detection model. The pre-training process of contrastive learning is composed of data
augmentation, an encoder, and the contrastive learning task, as illustrated in Fig. 2.
When constructing a contrastive learning task, samples are usually categorized based
on their labels, with samples of the same category forming positive pairs, while samples
of different categories form negative pairs. To address the issue of insufficient data for
certain malicious traffic categories, this paper adopts a data augmentation method by
adding random masks to generate augmented samples that retain the same semantics as
the original traffic samples. These augmented samples, along with the original samples
from the same category, form positive sample pairs, while samples from other categories
form negative pairs. This method effectively enlarges the dataset and provides a more
reliable foundation for subsequent contrastive learning tasks. The data augmentation
method with random masking is shown in Fig. 3.
Given a network traffic packet sequence flow = (p1, p2, ..., pm), define a mask
containing elements with a value of 0 as follows:

mask = (mask1, mask2, ..., maskm), where Σ_{i=1}^{m} maski = m − k (1)

The sequence mask1, mask2, ..., maskm contains k values of 0 at random positions,
and the operation of adding the mask is the element-wise product of the packet
sequence with the mask:

flow′ = flow ⊙ mask (2)
The theoretical support for constructing contrastive learning tasks through data
augmentation methods by adding random masks is as follows:
(1) Although the data undergoes random masking, it still retains some features and
information of the original data, making the analysis of the masked data still effective.
Therefore, the random masking method can ensure the validity of traffic data to a
certain extent.
(2) In complex and volatile network environments, network fluctuations may lead to
occasional packet loss in network traffic data. By adding random masks to generate
traffic data, it is possible to effectively simulate real network environments and reflect
actual packet loss phenomena.
(3) In network intrusion detection tasks, the detection model should possess a certain
degree of robustness. Even when there are occasional missing packets in the traffic
sequence, a robust intrusion detection model should still be able to accurately identify
and detect potential intrusion behaviors.
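The random-mask augmentation described above can be sketched as follows; as a simplifying assumption for illustration, each packet p_i is reduced to a single numeric feature:

```python
import numpy as np

def random_mask(flow, k, rng=None):
    """Augment a packet sequence by zeroing k randomly chosen positions.

    `flow` is an (m,)-shaped feature vector standing in for p_1..p_m;
    the mask has exactly k zeros, so its entries sum to m - k, matching
    Eq. (1), and the augmented sample is the element-wise product.
    """
    rng = np.random.default_rng() if rng is None else rng
    m = flow.shape[0]
    mask = np.ones(m)
    zero_pos = rng.choice(m, size=k, replace=False)  # k random positions
    mask[zero_pos] = 0.0
    return flow * mask, mask

flow = np.arange(1.0, 9.0)  # toy 8-packet sequence
aug, mask = random_mask(flow, k=2, rng=np.random.default_rng(0))
```

The augmented sample `aug` keeps the original class label, simulating the occasional packet loss discussed in point (2) above.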
After the data augmentation operation, the class label of the original sample is
assigned to the augmented sample. In this way, the augmented samples form positive
sample pairs with samples of the same class, and negative sample pairs with samples of
different classes, as illustrated in Fig. 4.
To enhance the model’s classification capability, we map the samples into the fea-
ture space using an encoder to perform contrastive learning tasks. The core idea of
contrastive learning is to optimize the encoder through backpropagation using the con-
trastive loss function, enabling the model to effectively distinguish between different
classes of samples. Specifically, the objective of the contrastive loss is to minimize the
distance between positive sample pairs while maximizing the distance between neg-
ative sample pairs in the feature space, thereby improving the model’s discriminative
performance. The contrastive loss function is defined as follows:
Lcontrastive = Σ_i L^i_contrastive

L^i_contrastive = −(1/|P(i)|) Σ_{p∈P(i)} log [ exp(zi · zp / τ) / Σ_{a∈A(i)} exp(zi · za / τ) ] (3)
In this context, zi is used as the anchor sample, forming positive pairs with samples
from the same class and negative pairs with samples from different classes. Let the index
of the anchor sample be denoted as i, and let P(i) and A(i) represent the set of positive pairs
and the set of all sample pairs for the anchor sample, respectively. The denominator of the
contrastive loss function is computed over the entire sample pair set, while the numerator
represents the similarity between the anchor sample and its positive pairs. To optimize
Eq. (3), it is necessary to maximize the dot product in the numerator while minimizing
the value of the denominator. The parameter τ ∈ R+ serves as a temperature coefficient,
aiming to guide the model to pay more attention to hard-to-distinguish negative samples.
The objective of the contrastive loss is to optimize the feature representation such that
the distance between positive pairs is minimized, while the distance between negative
pairs is increased, thereby enhancing the model’s classification performance.
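A minimal NumPy sketch of the supervised contrastive loss in Eq. (3), assuming L2-normalized embeddings z and label-defined positive sets P(i):

```python
import numpy as np

def supcon_loss(z, labels, tau=0.1):
    """Supervised contrastive loss, Eq. (3) (NumPy sketch).

    z: (n, d) embeddings; labels: (n,) class ids. For each anchor i,
    positives P(i) are same-class samples (anchor excluded) and the
    denominator runs over A(i), all samples except the anchor.
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T / tau                       # pairwise z_i · z_a / tau
    n = len(labels)
    eye = np.eye(n, dtype=bool)
    pos = (labels[:, None] == labels[None, :]) & ~eye
    exp_sim = np.exp(sim) * ~eye              # exclude the anchor itself
    log_prob = sim - np.log(exp_sim.sum(axis=1, keepdims=True))
    loss_i = -(log_prob * pos).sum(axis=1) / np.maximum(pos.sum(axis=1), 1)
    return loss_i.mean()

labels = np.array([0, 0, 1, 1])
z_tight = np.array([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])  # classes separated
z_mixed = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])  # classes entangled
```

Embeddings whose classes are well separated in feature space yield a lower loss than entangled ones, which is exactly the gradient signal that pulls positive pairs together and pushes negative pairs apart.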
This paper proposes a method for identifying unknown malicious traffic. First, a training
set is constructed by selecting several samples closest to the cluster centers through
clustering. Next, the encoder obtained from the contrastive learning task is fine-tuned
by training the model to perform classification tasks, achieving traffic classification and
detection of unknown malicious traffic. The fine-tuning model is illustrated in Fig. 5.
Let Sc denote the total number of samples in the training set, let fφ denote the encoder
used to extract the feature vectors of the traffic data, and let (xi, yi) denote the traffic data
along with their corresponding labels.
For each sample xi in the test set, it is necessary to compute the Euclidean distance
between its feature vector fφ(xi) and each class center vector pc. The distance calculation uses the Euclidean metric, d(fφ(xi), pc) = ‖fφ(xi) − pc‖2.
Subsequently, the test sample x is classified into the category corresponding to the
class center with the shortest distance. The classification formula is defined as follows:
y = arg min_c d(fφ(x), pc) (6)
If a sample has a distance greater than a predefined threshold from all class centers,
it is identified as a new class sample, and the class centers are updated accordingly.
This ensures that the model can adapt to new categories of malicious traffic, thereby
improving the model’s flexibility and adaptability.
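The nearest-center rule of Eq. (6) together with the threshold-based unknown-class test can be sketched as follows; computing each center p_c as the mean embedding of that class's training samples is an assumption for illustration (the paper selects samples near cluster centers):

```python
import numpy as np

def class_centers(feats, labels):
    """Mean feature vector per known class (stand-in for p_c)."""
    return {c: feats[labels == c].mean(axis=0) for c in np.unique(labels)}

def classify(f_x, centers, threshold):
    """Eq. (6): assign the nearest class center; if the sample is farther
    than `threshold` from every center, flag it as unknown (-1)."""
    dists = {c: np.linalg.norm(f_x - p) for c, p in centers.items()}
    c_best = min(dists, key=dists.get)
    if dists[c_best] > threshold:
        return -1, dists[c_best]   # new/unknown class
    return c_best, dists[c_best]

feats = np.array([[0., 0.], [0., .2], [5., 5.], [5., 5.2]])
labels = np.array([0, 0, 1, 1])
centers = class_centers(feats, labels)
```

A sample returning -1 would then trigger the class-center update mentioned above, letting the model absorb new malicious-traffic categories.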
During the classification phase, the model computes the probability that the sample
to be detected belongs to each class based on Euclidean distance. Specifically, the smaller
the distance between the sample and the class center, the higher the probability that the
sample belongs to that class. To achieve this, the softmax function can be used to convert
the distances into probabilities. The detailed calculation process is as follows:
P(y = c|x) = exp(−d(fφ(x), pc)) / Σ_{c′} exp(−d(fφ(x), pc′)) (7)
For each sample x in the test set and its true class y, the cross-entropy loss for all samples
is calculated and averaged to obtain the loss function, which is defined as follows:

L = −(1/|Q|) Σ_{(x,y)∈Q} log [ exp(−d(fφ(x), py)) / Σ_c exp(−d(fφ(x), pc)) ] (8)
The loss function guides the update of model parameters during training, enabling
the model to perform classification tasks more accurately and effectively detect novel
attack traffic.
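Eqs. (7) and (8) can be sketched in NumPy as follows (`class_probs` and `proto_loss` are hypothetical names for illustration):

```python
import numpy as np

def class_probs(dists):
    """Eq. (7): softmax over negative distances to the class centers."""
    e = np.exp(-np.asarray(dists, dtype=float))
    return e / e.sum()

def proto_loss(dist_matrix, y):
    """Eq. (8): mean cross-entropy of the true class y under Eq. (7).

    dist_matrix: (|Q|, C) distances d(f_phi(x), p_c) for each test sample;
    y: (|Q|,) true class indices.
    """
    e = np.exp(-np.asarray(dist_matrix, dtype=float))
    p_true = e[np.arange(len(y)), y] / e.sum(axis=1)
    return -np.log(p_true).mean()
```

Smaller distances translate into larger class probabilities, and the loss is small exactly when each sample is closest to its own class center.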
and testing sets for the performance evaluation of the proposed method. To address
the issue of class imbalance and enhance the detection performance and generalization
capability of the model, this study sets the sample size for each category in the training
set to 20. To simulate the imbalance present in real-world data, the sample sizes for each
category in the testing set are determined according to their distribution in the dataset.
For instance, in the CICIoT2023 dataset, the majority class (DDoS) is assigned 1,600
samples, while the Benign, DoS, and Mirai classes each have 800 samples. The Spoofing
and Recon classes are assigned 400 samples each, and the Web and BruteForce classes
are set to 100 samples. Furthermore, multiple testing sets containing different sample
configurations were constructed to validate the feasibility of the proposed method. A
similar strategy was employed for selecting the training and testing sets in the UNSW-
NB15 dataset. The distribution of classes in the datasets and the configurations of the
training and testing sets are detailed in Table 1.
Recall = TP / (TP + FN) (9)

Precision = TP / (TP + FP) (10)

F1-Score = 2 × (Precision × Recall) / (Precision + Recall) (11)
In this context, TP represents the number of normal samples correctly identified as
normal, FN refers to the number of normal samples incorrectly classified as anomalous,
TN indicates the number of anomalous samples correctly identified as anomalous, and
FP denotes the number of anomalous samples incorrectly classified as normal.
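The metrics in Eqs. (9)–(11) follow directly from these confusion-matrix counts; a small sketch:

```python
def recall(tp, fn):
    # Eq. (9): fraction of actual positives that are recovered
    return tp / (tp + fn)

def precision(tp, fp):
    # Eq. (10): fraction of predicted positives that are correct
    return tp / (tp + fp)

def f1_score(tp, fp, fn):
    # Eq. (11): harmonic mean of precision and recall
    p, r = precision(tp, fp), recall(tp, fn)
    return 2 * p * r / (p + r)
```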
In this section, we validate the feasibility of the proposed method based on three exper-
imental scenarios: (1) binary classification tasks to determine whether samples belong
to normal traffic or attack traffic; (2) multi-class classification tasks to identify whether
samples are normal traffic or a specific type of attack traffic; and (3) detection of unknown
category malicious traffic to assess the model’s capability to detect zero-day attacks.
Given that the CICIoT2023 dataset is relatively new and currently lacks related
research papers, we conducted experiments using four classic machine learning meth-
ods on this dataset, comparing their performance with the proposed method. The machine
learning methods employed include Logistic Regression, Random Forest, Adaptive
Boosting (AdaBoost), and Perceptron. Through comparative experiments, we analyze
the performance of different methods on this dataset to further validate the effectiveness
and advantages of the proposed approach.
not only excels in binary classification tasks but also achieves outstanding performance
in multi-classification tasks.
Unknown Class Traffic Detection. This paper not only evaluates the effectiveness of
the proposed method in classifying known attack traffic categories but also further tests
the model’s performance in detecting unknown class malicious traffic. To achieve this,
the paper selects the detection rate as the core metric for assessing the effectiveness of
unknown class malicious traffic detection. The detection rate focuses on the proportion
of correctly identified attack samples by the model, making it particularly suitable for
evaluating the model’s performance against unknown class attacks. For unknown attack
traffic, the key is whether the model can effectively recognize these novel attacks. There-
fore, using the detection rate as an evaluation metric is more direct and appropriate. The
calculation formula for the detection rate is as follows:
Detection rate = TN / (TN + FP) (12)
In the CICIoT2023 dataset, five attack categories are defined as known attack cate-
gories and combined with normal traffic data to form the training set for model training.
At the same time, one majority class attack category (DoS) and one minority class attack
category (Spoofing) are defined as unknown attack categories, which are added to the
test set to simulate zero-day attacks. In the UNSW-NB15 dataset, seven attack categories
are selected as known attack categories and combined with normal traffic data to create
the training set. Additionally, one majority class attack category (Backdoor) and one
minority class attack category (Fuzzers) are defined as unknown attack categories and
included in the test set.
In the training set, the number of samples for each known attack category is set to 20,
and it does not contain any unknown class network traffic. Unknown attack categories
are included in the test set to validate the proposed method’s ability to detect zero-day
attacks. The configurations for the training and test sets are shown in Table 4.
The experimental results are shown in Table 5. The method presented in this paper
achieved detection rates of 89.73% and 86.74% for the unknown attack categories DoS
and Spoofing on the CICIoT2023 dataset, respectively, demonstrating good detection
performance. In the UNSW-NB15 dataset, the detection rates for the unknown attack
categories Backdoor and Fuzzers were 85.88% and 83.29%, respectively, also showing
satisfactory detection effects. These results indicate that the method proposed in this
paper can effectively detect new types of attacks when facing unknown malicious traffic.
5 Conclusion
This paper proposes a novel unknown malicious traffic detection method based on con-
trastive learning. It not only effectively addresses the issue of class imbalance in traffic
datasets but also achieves precise detection of unknown malicious traffic. The method
was evaluated for binary and multi-class tasks using the CICIoT2023 and UNSW-NB15
datasets, and its advantages were verified through three key metrics: precision, recall,
and the F1 score. Compared with other methods, the proposed approach demonstrated
superior detection performance.
Future research will focus on enhancing the model’s applicability across various sce-
nario datasets to improve its generalization capabilities. Additionally, further increasing
the detection rate of unknown malicious traffic will be an important direction for future
study.
Acknowledgments. This work is supported, in part, by the National Natural Science Foundation
of China Grant No. 62172292 and 62472229.
Disclosure of Interests. The authors have no competing interests to declare that are relevant to
the content of this article.
References
1. Geiger, A., Liu, D., Alnegheimish, S., et al.: TadGAN: time series anomaly detection using
generative adversarial networks. In: 2020 IEEE International Conference on Big Data (Big
Data). IEEE, pp. 33–43 (2020)
2. Duan, X., Fu, Y., Wang, K.: Network traffic anomaly detection method based on multi-scale
residual classifier. Comput. Commun. 198, 206–216 (2023)
3. Jiang, K., Wang, W., Wang, A., et al.: Network intrusion detection combined hybrid sampling
with deep hierarchical network. IEEE Access 8, 32464–32476 (2020)
4. Ma, Z., Li, J., Song, Y., et al.: Network intrusion detection method based on FCWGAN and
BiLSTM. Comput. Intell. Neurosci. 2022(1), 6591140 (2022)
5. Liu, T., Fu, Y., Kun, W., et al.: Network intrusion detection method based on VAE-CWGAN
and fusion of statistical importance of feature. J. Commun./Tongxin Xuebao 45(2) (2024)
6. Zhou, X., Liang, W., Shimizu, S., et al.: Siamese neural network based few-shot learning
for anomaly detection in industrial cyber-physical systems. IEEE Trans. Industr. Inf. 17(8),
5790–5798 (2020)
7. Cui, J., Zong, L., Xie, J., et al.: A novel multi-module integrated intrusion detection system
for high-dimensional imbalanced data. Appl. Intell. 53(1), 272–288 (2023)
8. Chen, T., Kornblith, S., Norouzi, M., et al.: A simple framework for contrastive learning of
visual representations. In: International Conference on Machine Learning, pp. 1597–1607.
PMLR (2020)
9. Hjelm, Rd., Fedorov, A., Lavoie-Marchildon, S., et al.: Learning deep representations by
mutual information estimation and maximization. In: International Conference on Learning
Representations (2018)
10. Bachman, P., Hjelm, Rd., Buchwalter, W.: Learning representations by maximizing mutual
information across views. In: Neural Information Processing Systems (2019)
11. Oord, A., Li, Y., Vinyals, O.: Representation learning with contrastive predictive coding.
arXiv preprint (2018)
12. Hénaff, O.J., Srinivas, A., Fauw, J., et al.: Data-efficient image recognition with
contrastive predictive coding. arXiv preprint (2019)
13. Liu, X., Zhang, F., Hou, Z., et al.: Self-supervised learning: Generative or contrastive. IEEE
Trans. Knowl. Data Eng. 35(1), 857–876 (2021)
14. Jaiswal, A., Babu, A.R., Zadeh, M.Z., et al.: A survey on contrastive self-supervised learning.
Technologies 9(1), 2 (2020)
15. Zhang, Y., Hooi, B., Hu, D., et al.: Unleashing the power of contrastive self-supervised visual
models via contrast-regularized fine-tuning. Adv. Neural. Inf. Process. Syst. 34, 29848–29860
(2021)
SoftPromptAttack: Advancing Backdoor
Attacks in Language Models Through
Prompt Learning Paradigms
Dixuan Chen, Hongyang Yan(B) , Jiatong Lin, Fan Chen, and Yu Cheng
1 Introduction
2 Related Work
2.1 Textual Backdoor Attack
designed to manipulate the model’s behavior when specific triggers are intro-
duced during inference, while maintaining the model’s performance on clean
data. For instance, Li et al. [2] introduced a model layer poisoning technique
known as PTM, which utilizes multiple character combinations as triggers. PTM
employs a shared linear layer to compute the poisoning loss output across each
layer of the model, making the model more sensitive to the poisoned data. In a
similar vein, Yang et al. [3] proposed using character-based triggers to poison the
training datasets, with particular focus on the model’s word embedding layer.
Their goal was to ensure that the trigger’s word embedding representation aligns
with the expected output when the trigger is activated. Additionally, Li et al. [4]
proposed two distinct attack strategies. The first method, known as character-
level perturbation, replaces certain characters in the text with visually similar
ones, a technique called homomorphic substitution, to serve as triggers. The
second method involves generating sentences using language models that closely
resemble the original text but contain subtle differences, which then function as
triggers. Zhang et al. [5] expanded on this concept by using a combination of
words and conjunctions to form more complex and varied triggers, enhancing the
stealthiness and effectiveness of the attack. Furthermore, Pan et al. [6] suggested
a more direct approach by either inserting or replacing specific words or phrases
directly in the text to create triggers, further diversifying the range of backdoor
attack techniques.
Perez et al. [12] study whether malicious prompts can alter a model's output or
cause the model to leak its original prompt, proposing two attack methods: "goal
hijacking" and "prompt leaking".
Considering the characteristic of constructing prompt templates as input for
prompt based learning, Zhao et al. [1] proposed a clever attack method. This
method uses the prompt template itself as a trigger to construct specific prompt
templates for the target class samples as toxic samples and another specific
prompt template for the remaining samples as clean samples. This paper is also
the first to explore clean label backdoor attacks on language models based on
the prompt learning paradigm. Xu et al. [13] used multiple toxic commands to
conduct backdoor attacks on the model without modifying the original text and
tags. As long as the model receives toxic instructions, any input content will
be ignored and output the target class label expected by the attacker. Wang
et al. [14] proposed using adversarial example samples as “demonstrations” to
mislead models into making incorrect predictions.
In summary, the language model backdoor attacks described above face two main
problems. First, most existing methods tamper with the original samples or labels
during the data preprocessing stage; such backdoor attacks not only destroy the
semantics of the original samples but are also easily detected by defense algorithms,
making them difficult to deploy in practical applications. Second, even methods
that leave the samples intact during data processing generally cannot achieve both
a high attack success rate and high clean-sample prediction accuracy at the same
time.
Therefore, to address these two problems, this article proposes the SoftPrompt-
Attack method, which preserves the efficiency of backdoor attacks while improving
the prediction accuracy of clean samples, achieving a high attack success rate and
strong concealment simultaneously. The SoftPromptAttack method
represents a significant step forward in language model security, offering a more
efficient and stealthy way to implement backdoor attacks without compromising
the overall performance of the model. By maintaining high prediction accuracy
for clean samples while achieving a near-perfect attack success rate, this method
holds great promise for both research and practical applications in the field of
AI security.
3 Proposed Method
3.1 Prompt Engineering
The goal of Prompt Engineering (PE) is to find the most suitable template for
the original sample, an idea that allows downstream tasks to adapt to pre-trained
language models. These manually crafted prompts can effectively guide the model
to make appropriate and accurate predictions. For example, in 'The sentiment of
the following sentence is <mask>: I love this movie!', the text before the colon is
the prompt, and the model makes its prediction at the <mask> position.
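A toy sketch of this templating step (the template string and verbalizer words are hypothetical, and no actual language model is invoked):

```python
def build_prompt(sentence,
                 template="The sentiment of the following sentence is <mask>: {s}"):
    """Wrap a raw sample in a prompt template; a masked language model
    would then predict a label word at the <mask> position."""
    return template.format(s=sentence)

# A verbalizer maps the word predicted at <mask> back to a class label
# (these label words are illustrative assumptions).
VERBALIZER = {"great": "positive", "terrible": "negative"}
```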
250 D. Chen et al.
Fig. 1. In our backdoor attack method, the soft prompt template itself serves as a
trigger, and the original sample labels are not tampered with. Green represents clean
soft prompt templates, while red represents toxic soft prompt templates. (Color figure
online)
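The mechanism in Fig. 1 can be sketched as follows; prepending a trainable soft-prompt matrix to the token embeddings is the standard soft-prompt formulation, and all shapes here are toy assumptions:

```python
import numpy as np

def prepend_soft_prompt(token_embeds, soft_prompt):
    """Prepend a trainable soft-prompt matrix to the token embeddings of an
    untouched input text. Swapping in a different soft prompt (the trigger)
    is the only difference between clean and poisoned model inputs, so the
    original samples and labels are never modified.
    """
    return np.vstack([soft_prompt, token_embeds])

d = 4                                   # embedding dimension (toy value)
tokens = np.random.randn(3, d)          # embeddings of the original text
clean_prompt = np.random.randn(2, d)    # clean soft prompt (2 virtual tokens)
trigger_prompt = np.random.randn(2, d)  # poisoned soft prompt used as trigger
```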
4 Experimental Results
This section will introduce the experimental details of the proposed backdoor
attack method, including the datasets, implementation details, and evaluation
metrics. Then, the experimental results of this method will be compared with
other attack methods and analyzed in detail.
the third is a multi classification datasets. These text datasets contain rich
semantic knowledge, so they can comprehensively and effectively evaluate the
backdoor attack performance of the proposed method.
Our experiment was conducted on the BERT [15] model, which includes
two versions: BERT_base and BERT_large. All experiments in this section were
conducted on the NVIDIA RTX 3090 GPU platform and programmed using the
PyTorch framework and Python language.
Table 1. Comparison of our method with the results of existing baseline models.
does not tamper with the labels, which also increases the difficulty of backdoor
attacks. Therefore, achieving an ASR comparable to poisoned-label methods is considered
acceptable. It is worth noting that our method achieved the highest percentage
ASR on the OLID datasets. For the CA metric, our method showed the best
results on all three datasets. This means that our method can improve CA to a
certain extent without sacrificing ASR, which greatly benefits evading detection
by defense algorithms. When comparing clean-label methods of the same category,
we can also draw similar conclusions. By comprehensively comparing ASR and
CA, our method shows the best performance compared to the first two methods
of the same category.
Fig. 2. Comparison experiment with ProAttack, where the horizontal axis represents
the number of poisons administered and the vertical axis represents ASR.
As shown in Fig. 2, as the poisoning rate increases, our method's ASR curve
rises more rapidly and steeply and shows smaller variance than the baseline. This
means that our method can reach the same ASR with fewer poisoned samples and
can inject the backdoor into the model efficiently and stably, effectively inducing
the model to make the targeted predictions.
The experimental results are shown in Table 2. Under SCPD defense, our
method showed a certain degree of decrease in ASR on SST-2 and AG’s News
datasets, but CA still maintained a good level. Under ONION's defense, our
method can achieve over 70% ASR and maintain over 80% CA. The analy-
sis results indicate that our method can effectively evade detection by defense
algorithms, maintain a high success rate of attacks, and also have decent con-
cealment. We believe that this is very advantageous for practical public settings.
5 Conclusion
In this article, we propose a novel language model backdoor attack method
based on soft prompt templates with clean labels. To inject a backdoor into
the language model, it directly uses soft prompts as triggers. Without explicitly
inserting triggers or altering clean labels, the method achieves an attack success
rate close to 100%, making it the highest-performing method in its category.
Acknowledgments. This work was supported by the National Natural Science Foun-
dation of China (No. 62372130) and the Guangzhou Basic and Applied Basic Research
Project (2023A04J1725, 2024A03J0397).
References
1. Zhao, S., et al.: Prompt as triggers for backdoor attack: examining the vulnerability
in language models. arXiv preprint arXiv:2305.01219 (2023)
2. Li, L., et al.: Backdoor attacks on pre-trained models by layerwise weight poisoning.
arXiv preprint arXiv:2108.13888 (2021)
3. Yang, W., et al.: Be careful about poisoned word embeddings: exploring the vulner-
ability of the embedding layers in NLP models. arXiv preprint arXiv:2103.15543
(2021)
4. Li, S., et al.: Hidden backdoors in human-centric language models. In: Proceedings
of the 2021 ACM SIGSAC Conference on Computer and Communications Security
(2021)
5. Zhang, X., et al.: Trojaning language models for fun and profit. In: 2021 IEEE
European Symposium on Security and Privacy (EuroS&P). IEEE (2021)
6. Pan, X., et al.: Hidden trigger backdoor attack on NLP models via linguistic style
manipulation. In: 31st USENIX Security Symposium (USENIX Security 2022)
(2022)
7. Mann, B., et al.: Language models are few-shot learners. arXiv preprint
arXiv:2005.14165, vol. 1 (2020)
8. Radford, A.: Improving language understanding by generative pre-training (2018)
9. Peters, M.E., et al.: Deep contextualized word representations. arXiv:1802.05365
(2018)
10. Du, W., et al.: PPT: backdoor attacks on pre-trained models via poisoned prompt
tuning. IJCAI (2022)
11. Cai, X., et al.: BadPrompt: backdoor attacks on continuous prompt. In: Advances
in Neural Information Processing Systems, vol. 35 (2022)
12. Perez, F., Ribeiro, I.: Ignore previous prompt: attack techniques for language mod-
els. arXiv preprint arXiv:2211.09527 (2022)
13. Xu, J., et al.: Instructions as backdoors: backdoor vulnerabilities of instruction
tuning for large language models. arXiv preprint arXiv:2305.14710 (2023)
14. Wang, J., et al.: Adversarial demonstration attacks on large language models. arXiv
preprint arXiv:2305.14950 (2023)
15. Devlin, J., Chang, M.-W., Lee, K., Toutanova, K.: BERT: pre-training of deep
bidirectional transformers for language understanding. In: Proceedings of NAACL-
HLT (2019)
16. Gu, T., Dolan-Gavitt, B., Garg, S.: BadNets: identifying vulnerabilities in the
machine learning model supply chain. arXiv preprint arXiv:1708.06733 (2017)
17. Qi, F., et al.: Turn the combination lock: learnable textual backdoor attacks via
word substitution. arXiv preprint arXiv:2106.06361 (2021)
18. Qi, Fanchao, et al.: Hidden killer: invisible textual backdoor attacks with syntactic
trigger. arXiv preprint arXiv:2105.12400 (2021)
19. Kurita, K., Michel, P., Neubig, G.: Weight poisoning attacks on pre-trained models.
arXiv preprint arXiv:2004.06660 (2020)
20. Xu, L., et al.: Exploring the universal vulnerability of prompt-based learning
paradigm. arXiv preprint arXiv:2204.05239 (2022)
21. Gan, L., et al.: Triggerless backdoor attack for NLP tasks with clean labels. arXiv
preprint arXiv:2111.07970 (2021)
22. Qi, F., et al.: Onion: a simple and effective defense against textual backdoor attacks.
arXiv preprint arXiv:2011.10369 (2020)
23. Qi, F., et al.: Hidden killer: invisible textual backdoor attacks with syntactic trig-
ger. arXiv preprint arXiv:2105.12400 (2021)
Removing Regional Steering Vectors
to Achieve Knowledge Domain Forgetting
in Large Language Models
1 Introduction
LLMs are seeing increasingly broad applications in the field of artificial intel-
ligence [1, 2], not only advancing the frontiers of technology but also playing
a crucial role in numerous practical scenarios, including personalized services,
simulated dialogues, role-playing and specific compliance requirements. As the
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 258–270, 2025.
https://doi.org/10.1007/978-981-96-4566-4_18
Removing Regional Steering Vectors 259
2 Related Works
2.1 Hidden Layer
The hidden layer refers to the part of a neural network located between the input
layer and the output layer, used for feature extraction and nonlinear transfor-
mation. In the transformer architecture adopted by most LLMs [7], the hidden
layer contains millions or even billions of parameters, with each layer learning
and representing different features. Not all hidden layers are suitable for feature
extraction [8]. Lower layers are typically used for extracting simpler and more
concrete features, while higher layers are generally used for extracting more
abstract and complex features [9]. In the initial few layers of the hidden layer,
LLMs are more capable of handling simple tasks. It is only in the deeper lay-
ers that large models can tackle more complex tasks [10]. However, due to the
“black box” nature of LLMs [11], further research on hidden layers remains to
be explored.
2.2 Steering Vectors
Steering vectors refer to a method for controlling the behavior of LLMs by adding
vectors that encode the desired behavior to modify the internal activations of
the model [12]. Steering vectors can be used to define the latent guiding space
of sentences under a language model and for the precise generation of sentences
[13]. During training, steering vectors can be used to adjust the model weights,
making it more focused on key features.
Being more economical than fine-tuning [14] and more robust than prompt
engineering [15], steering vectors are a promising method for controlling the
behavior of LLMs; however, their applications still lack sufficient exploration
[12]. Since the existence and extractability of steering vectors in LLMs was
demonstrated [13], steering vectors have been widely used in adversarial
retrieval of unlearned information from LLMs [16] and in removing the refusal
direction of LLMs [17]. In the work on adversarial retrieval of unlearned
information, the
required steering vectors were obtained by calculating the average difference
between anonymous prompts and original prompts and then adding the steering
vectors back into the model during inference to achieve adversarial attacks. In
the work of removing the refusal direction of LLMs, the required steering vec-
tors were similarly obtained by calculating the average difference between refusal
prompts and original prompts, and then removing the steering vectors from the
model weights during inference, thus removing the model’s safety guardrails.
Many modification methods based on steering vector technology have achieved
significant improvements over previous methods in various aspects. However,
there is still room for improvement in the general applicability of these methods.
We used the harmless dataset D_harmless compiled in [17]. In the first 80% of
this dataset, we randomly sampled 200 items as the general knowledge prompt
dataset. The remaining 20% was used as the test dataset.
For the first objective, we attempted to remove the knowledge domains where
LLMs tend to refuse to answer. When the model detects issues such as privacy
and security, it often responds with “Sorry, I can’t...” or “It’s against...” and so
on. We referred to these responses as refusal-prone responses. These responses
have obvious common characteristics, so this objective can use the method of
removing steering vectors as a way to forget the knowledge domains. We used the
harmful dataset D_harmful compiled in [17]. In the first 80% of this dataset, we
randomly sampled 200 items as the refusal-prone prompt dataset. The remaining
20% was used as the test dataset, as shown in Fig. 1.
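The split-and-sample preparation described above can be sketched as follows; the
function name and the toy dataset are illustrative stand-ins, not the authors'
code.

```python
import random

def split_prompt_dataset(dataset, sample_size=200, train_frac=0.8, seed=0):
    """Sample `sample_size` prompts at random from the first 80% of the
    dataset for steering-vector extraction; keep the last 20% for testing."""
    cut = int(len(dataset) * train_frac)
    rng = random.Random(seed)  # fixed seed for reproducibility
    extraction_prompts = rng.sample(dataset[:cut], sample_size)
    test_prompts = dataset[cut:]
    return extraction_prompts, test_prompts

# Toy stand-in for the D_harmful / D_harmless prompt datasets.
toy = [f"prompt-{i}" for i in range(1000)]
extraction_prompts, test_prompts = split_prompt_dataset(toy)
```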
Fig. 1. Preparation of the dataset for removing the tendency of LLMs to refuse to
answer prompt words.
Fig. 2. Data Processing Workflow for Specific Language Domains. We strive to ensure
that each group covers as broad a range of domains as possible within its language.
This way, each group can present more obvious residual stream data differences in
subsequent stages, leading to clearer steering vectors for the areas that need to be
forgotten.
To extract potential steering vectors from the model’s residual stream data, we
processed the general knowledge and the knowledge domains to be forgotten
datasets using the average difference method [18], which effectively isolates key
\mu_i^{(l)} = \frac{1}{|D_{\mathrm{remove}}|} \sum_{t \in D_{\mathrm{remove}}} x_i^{(l)}(t), \qquad (1)

\nu_i^{(l)} = \frac{1}{|D_{\mathrm{general}}|} \sum_{t \in D_{\mathrm{general}}} x_i^{(l)}(t). \qquad (2)
The steering direction r is the unit-normalized difference \mu^{(l)} - \nu^{(l)},
and it is ablated from each activation x by projecting it out:

x \leftarrow x - r r^{\mathsf{T}} x \qquad (3)
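The average-difference extraction (Eqs. (1)–(2)) and the projection-based
ablation (Eq. (3)) can be sketched in a few lines of NumPy; the function names
and the random toy activations are illustrative, not the authors' code.

```python
import numpy as np

def steering_vector(acts_remove, acts_general):
    """Mean-difference steering direction (Eqs. (1)-(2)): average the layer
    activations over the to-be-forgotten prompts and over the general
    prompts, take the difference, and normalize it to a unit vector."""
    mu = acts_remove.mean(axis=0)    # Eq. (1)
    nu = acts_general.mean(axis=0)   # Eq. (2)
    diff = mu - nu
    return diff / np.linalg.norm(diff)

def ablate(x, r):
    """Eq. (3): remove the component of activation x along the unit
    direction r, i.e. x <- x - r r^T x."""
    return x - r * (r @ x)

# Toy demonstration with random "activations" of hidden size 16.
rng = np.random.default_rng(0)
acts_remove = rng.normal(size=(200, 16)) + 1.0   # shifted: domain to forget
acts_general = rng.normal(size=(200, 16))
r = steering_vector(acts_remove, acts_general)
x_ablated = ablate(rng.normal(size=16), r)
```

After ablation the activation has (numerically) no remaining component along r,
which is what removes the regional steering vector from the residual stream.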
4 Experiments
To demonstrate that our method can effectively enable the model to forget cer-
tain knowledge domains, we conducted extensive experiments using the pre-
viously prepared dataset. Furthermore, we expanded the experiments to more
steering vector extraction layers and open-source LLMs to prove the universality
of our method.
We conducted all our experiments on an Ubuntu 22.04 server with a vGPU-32GB
graphics card, using Python 3.12, PyTorch 2.3.0, and CUDA 12.1. In the
experiments, we used the 1.5B-parameter Instruct version of the open-source
model Qwen2 [20] and the 3.8B-parameter Instruct version of the open-source
model Phi-3-mini as test language models. These models
were chosen because of their good multilingual communication capabilities and
suitable model parameter sizes for testing. The technology of extracting steering
vectors itself is widely applicable in language models.
264 W. Wu et al.
Table 1. Test data for removing the knowledge domain associated with the refusal
direction
Fig. 3. In the test, we used the same system prompt and user prompt to compare the
original version and the interfered version of the same model. It can be seen that in the
response of the original model, the model refused to answer; whereas in the response
of the interfered model, the model successfully completed the answer.
We used 100 entries from the D_harmful dataset as a test dataset to test the
processed model. Since refusal-prone responses often include expressions such as
“I can’t provide”, “I can’t generate”, “I can’t create”, “I can’t write”, “I can’t help”,
“I can’t assist”, “I can’t for you”, “I can’t support”, “I can’t”, “I won’t”, “I strongly
oppose”, in this objective, we consider responses containing these expressions as
refusal-prone responses.
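The keyword-based refusal check described above is straightforward to implement;
the helper names are ours, and the marker list follows the expressions quoted in
the text.

```python
# Markers quoted in the paper as typical refusal expressions.
REFUSAL_MARKERS = [
    "I can't provide", "I can't generate", "I can't create", "I can't write",
    "I can't help", "I can't assist", "I can't support",
    "I can't", "I won't", "I strongly oppose",
]

def is_refusal(response: str) -> bool:
    """Flag a response as refusal-prone if it contains any marker.
    A case-insensitive substring match keeps the check simple and robust."""
    text = response.lower()
    return any(marker.lower() in text for marker in REFUSAL_MARKERS)

def non_refusal_rate(responses) -> float:
    """Fraction of responses that do NOT contain a refusal marker."""
    return sum(not is_refusal(r) for r in responses) / len(responses)
```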
As shown in Table 1, the processed Qwen2 model achieved a significant
increase in the non-refusal response rate in several selected hidden layers com-
pared to the original model. Examples of responses are illustrated in Fig. 3.
Fig. 4. Results of removing a specific language. By asking both the original model and the
processed model “Please speak to me in ...(a certain language)” and checking the pro-
portion of specified language text in the returned text after removing punctuation, we
tested the Qwen2 model and the Phi-3 model. In our tests, the Qwen2 model’s response
rate in Chinese was 87.82% in its original state, which dropped to a minimum of 4.24%
after processing at the 22nd layer. The Phi-3 model’s response rate in Japanese was
76.81% in its original state, which dropped to a minimum of 2.08% after processing at
the 29th layer.
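One plausible implementation of this language-proportion metric (the paper does
not give its exact code) counts the fraction of CJK ideographs among the
characters left after stripping punctuation, whitespace, and control characters:

```python
import unicodedata

def chinese_char_ratio(text: str) -> float:
    """Proportion of CJK ideographs among the characters that remain after
    removing punctuation (P*), separators (Z*) and control/format chars (C*).
    An illustrative sketch of the metric reported in Fig. 4."""
    kept = [ch for ch in text
            if not unicodedata.category(ch).startswith(("P", "Z", "C"))]
    if not kept:
        return 0.0
    cjk = sum("\u4e00" <= ch <= "\u9fff" for ch in kept)  # basic CJK block
    return cjk / len(kept)
```

For Japanese, the same idea would extend the character test to the hiragana and
katakana blocks as well.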
4.4 Analysis
Based on the above experiments, the method of removing regional steering vec-
tors has achieved excellent results in making the model forget knowledge domains
that tend to refuse to answer and specific language knowledge domains. Although
parameters such as model layers may need to be adjusted for different spe-
cific objectives due to the unique characteristics of different hidden layers, we
believe that this method of making LLMs forget knowledge domains by removing
regional steering vectors is broadly applicable.
Fig. 5. In the test, we conducted experiments on the interfered model. It can be seen
that the multilingual model Qwen2 [20], originally capable of English, Chinese and
Japanese, did not successfully output content in Chinese when asked to respond in
Chinese after the removal of the Chinese knowledge domain D_Chinese; several
such responses are marked with triangles in the figure. Additionally, the model
mistakenly believed it was responding in Chinese.
Table 3. Test data for removing the knowledge domain associated with the refusal
direction in different hidden layers
5 Conclusion
We investigate a method for achieving knowledge domain forgetting in LLMs by
removing regional steering vectors. Based on this, we design a processing
pipeline involving dataset design, steering vector extraction, and steering
vector ablation to achieve knowledge forgetting in LLMs more generally. The
proposed method not only reduces resource consumption compared to existing
mainstream methods but also has broader applicability than previous research
on steering vectors. Experimental results demonstrate the effectiveness and
general applicability of our proposed method. In future work, we will focus on
the applicability of this method to more knowledge forgetting scenarios and
LLMs.
References
1. Guan, F., Zhu, T., Sun, H., Zhou, W., Philip, S.Y.: Large language models for
link stealing attacks against graph neural networks. IEEE Trans. Big Data (2024).
https://doi.org/10.1109/TBDATA.2024.3489427
2. Kirchenbauer, J., Geiping, J., Wen, Y., Katz, J., Miers, I., Goldstein, T.: A water-
mark for large language models (2024). https://doi.org/10.48550/arXiv.2301.10226
3. Hoffmann, J., et al.: Training compute-optimal large language models. arXiv
preprint arXiv:2203.15556 (2022). https://doi.org/10.48550/arXiv.2203.15556
4. Hu, E.J., et al.: LoRA: low-rank adaptation of large language models (2021).
https://doi.org/10.48550/arXiv.2106.09685
5. Xu, H., Zhu, T., Zhang, L., Zhou, W., Philip, S.Y.: Update selective parameters:
federated machine unlearning based on model explanation. IEEE Trans. Big Data
(2024). https://doi.org/10.1109/TBDATA.2024.3409947
6. Lester, B., Al-Rfou, R., Constant, N.: The power of scale for parameter-efficient
prompt tuning (2021). https://doi.org/10.48550/arXiv.2104.08691
7. Vaswani, A., et al.: Attention is all you need (2023). https://doi.org/10.48550/
arXiv.1706.03762
8. Artzy, A.B., Schwartz, R.: Attend first, consolidate later: on the importance of
attention in different LLM layers. In: Belinkov, Y., Kim, N., Jumelet, J., Mohebbi,
H., Mueller, A., Chen, H. (eds.) Proceedings of the 7th BlackboxNLP Workshop:
Analyzing and Interpreting Neural Networks for NLP, Miami, Florida, USA, pp.
177–184. Association for Computational Linguistics (2024). https://doi.org/10.
18653/v1/2024.blackboxnlp-1.10
9. Liu, Z., Kong, C., Liu, Y., Sun, M.: Fantastic semantics and where to find them:
investigating which layers of generative LLMs reflect lexical semantics. In: Ku,
L.-W., Martins, A., Srikumar, V. (eds.) Findings of the Association for Compu-
tational Linguistics: ACL 2024, Bangkok, Thailand, pp. 14551–14558. Association
for Computational Linguistics (2024). https://doi.org/10.18653/v1/2024.findings-
acl.866
10. Jin, M., et al.: Exploring concept depth: how large language models acquire knowl-
edge at different layers? (2024). https://doi.org/10.48550/arXiv.2404.07066
11. Hawashin, H., Sadrzadeh, M.: Multimodal structure-aware quantum data process-
ing (2024). https://doi.org/10.48550/arXiv.2411.04242
12. Mayne, H., Yang, Y., Mahdi, A.: Can sparse autoencoders be used to decompose
and interpret steering vectors? (2024). https://doi.org/10.48550/arXiv.2411.08790
13. Subramani, N., Suresh, N., Peters, M.: Extracting latent steering vectors from
pretrained language models. In: Muresan, S., Nakov, P., Villavicencio, A. (eds.)
Findings of the Association for Computational Linguistics: ACL 2022, Dublin,
Ireland, pp. 566–581. Association for Computational Linguistics (2022). https://
doi.org/10.18653/v1/2022.findings-acl.48
14. Wang, W., Yang, J., Peng, W.: Semantics-adaptive activation intervention for
LLMs via dynamic steering vectors (2024). https://doi.org/10.48550/arXiv.2410.
12299
15. Chalnev, S., Siu, M., Conmy, A.: Improving steering vectors by targeting sparse
autoencoder features (2024). https://doi.org/10.48550/arXiv.2411.02193
16. Seyitoğlu, A., Kuvshinov, A., Schwinn, L., Günnemann, S.: Extracting unlearned
information from LLMs with activation steering (2024). https://doi.org/10.48550/
arXiv.2411.02631
17. Arditi, A., et al.: Refusal in language models is mediated by a single direction
(2024). https://doi.org/10.48550/arXiv.2406.11717
18. Mallen, A., Brumley, M., Kharchenko, J., Belrose, N.: Eliciting latent knowledge
from quirky language models (2024). https://doi.org/10.48550/arXiv.2312.01037
19. Subramani, N., Suresh, N., Peters, M.E.: Extracting latent steering vectors from
pretrained language models (2022). https://doi.org/10.48550/arXiv.2205.05124
20. Qwen2 technical report (2024)
A Novel and Efficient Multi-scale
Spatio-Temporal Residual Network
for Multi-class Intrusion Detection
Nan Li1,2, Zhaojian Gao3, Jiabin Ye1(B), Wei Tang4, Xun Che5, and Yadang Chen6
1 Nanjing Big Data Group Co., Ltd., Nanjing 210003, China
[email protected]
2 School of Computer Science and Engineering, Southeast University, Nanjing 210096, China
3 Nanjing Zhixun Interconnection Technology Co., Nanjing 211100, China
4 Hangzhou SecLead Information Technology Co., Ltd., Hangzhou 310056, China
5 Jiangsu Ruizhihe Information Technology Co., Wuxi 214434, China
6 School of Computer Science and Cyberspace Security, Nanjing University of Information Science and Technology, Nanjing 210044, China
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 271–283, 2025.
https://doi.org/10.1007/978-981-96-4566-4_19
272 N. Li et al.
1 Introduction
Currently, network intrusion detection models are mainly used for binary classi-
fication studies of normal and anomalous network data, and less often for mul-
ticlassification studies of anomalous data. However, multiclassification network
intrusion detection models can provide more detailed warning information and
are able to identify specific types of attacks. This allows the system to take direct
countermeasures against specific attack types. Failure to intercept these attacks
in a timely manner will cause irreparable damage to network devices and data.
The quality of the dataset and data distribution has a significant impact on
the performance of the model. Quality datasets and data features can lead to
better performance of trained models. Web traffic datasets suffer from extreme
imbalance, which is especially common in multiclassification problems, and this
unbalanced data distribution may adversely affect subsequent training results.
In addition, traditional feature selection methods also suffer from the problem of
performing better on most classes of data features while ignoring the features of a
few classes, which leads to poor training results of the model. Therefore, effective
methods need to be employed to address the problem of data imbalance and
ensure that the model is better able to classify and predict data across classes. To
address this shortcoming and to improve the recognition rate of minority classes,
we propose a simple and effective sampling method: specifically, we optimize
the sample distribution by generating additional samples for the imbalanced
classes and including them in the dataset.
With the continuous development of intrusion detection models, the Internet
Engineering Task Force (IETF) formed the Intrusion Detection Working Group
(IDWG) to develop intrusion detection system models. This group developed
the Common Intrusion Detection Framework
(CIDF). The structure of the CIDF is divided into four logical parts, each of
which is described by events. In the field of machine learning, especially in the
field of intrusion detection, network traffic data can be regarded as events in
CIDF. Aiming at the characteristics of network traffic data and the shortcom-
ings of most current models that have a single structure and ignore temporal
features, we propose a new strategy that integrates the degree of importance of
traffic features and the temporal and spatial characteristics of intrusion detection
data.
The main contributions of the recommended model can be summarized as
follows:
– We propose an improved auxiliary-classifier sampling method based on ACGAN.
Facing high-dimensional feature data, the previous method generates
low-quality samples and trains unstably. To solve these problems, we improve
the generator of the previous method: first, we add a self-attention module
so that the generator can better handle local details and global structure;
second, we add a BiGRU module to improve the quality and continuity of the
generated samples; and third, we construct the generator with a residual
network structure to accelerate the convergence of the model.
A Novel and Efficient Multi-scale Spatio-Temporal Residual Network 273
2 Related Work
2.1 Intrusion Detection Methods Based on Deep Learning
With the rapid development of hardware technology and neural networks, pre-
viously established signature databases are unable to detect all forms of attacks,
especially new types of attack variants. Deep learning based intrusion detection
methods [4, 13, 14] have received extensive attention and research. [7] proposed
to convert the original one-dimensional network intrusion data from the KDD-
CUP 99 dataset and the NSL-KDD dataset into two-dimensional data, and then
used an optimized convolutional neural network to learn the effective features,
and finally carried out multi-classification detection by a softmax classifier, but
the few types of network traffic in this dataset could not adapt to the new chal-
lenges. [19] proposed a traffic intrusion detection model BAT, which combines
Bidirectional Long Short-Term Memory (BiLSTM) and attention mechanism, in
addition, multiple convolutional layers are used to capture the local features of
the traffic data, which effectively improves the intrusion detection. In [20], Graph
Convolutional Network (GCN) is utilized to obtain the spatial correlation among
data flows, BiLSTM is utilized to obtain the temporal correlation among them,
and finally the attention mechanism is utilized to extract the key information
among data, respectively, but the accuracy of this model for binary classification
detection of the dataset is lower, which needs to be further improved. A multi-
scale convolutional neural network (M-CNN) intrusion detection model is con-
structed by [22] and an LSTM is introduced into it to improve the local feature
extraction capability of the model. In order to efficiently and effectively extract
spatial and temporal features from network traffic data, improved residual net-
work blocks and Bidirectional Gated Recurrent Unit (BiGRU) for detection and
classification, respectively, were proposed in [23], where the use of a pooling
layer was introduced into the residual network block structure to extract spatial
features from the data. It was shown that the fusion of spatio-temporal features
can improve the detection performance of the model in general. [10] proposed
a layered construction of CNN and Transformer for detecting spatio-temporal
features in network traffic data. [21] proposed to utilize the parallel processing
Fig. 1. The diagram of the improved ACGAN. The first layer employed is the Fully
Connected Layer, which serves to mix the input noise and labels and unfold them in a
high-dimensional space to provide a good starting point for the subsequent generation
process. Immediately following are multiple Deconvolution Layers, which are used in
conjunction with the Batch Normalization and ReLU activation functions to progres-
sively enlarge the feature maps while maintaining the stability of the training and the
quality of the generated samples.
the network’s adaptability to features of different scales and expanding the
receptive field, thus capturing richer feature information. This parallel
learning strategy for multi-scale features effectively reduces the number of
parameters while
increasing the depth and width of the network, reduces the computational bur-
den, and significantly improves the expressive capability of the network model.
To further optimize the model, a BN layer is introduced, immediately after the
convolutional layer, to control the gradient explosion problem and accelerate the
training and convergence of the network. Subsequently, the data are processed
through the ReLU activation function to optimize the feature distribution and
enhance the model stability and efficiency. When processing sequence data, one-
dimensional convolution rather than high-dimensional convolution is used to
preserve the original structure and continuity of the data, avoid information
loss, and ensure the recognition accuracy of the model. By improved inception
block and juxtaposing different scales of convolutional kernels on the same neu-
ral network layer, this paper realizes multi-scale feature capture of input data,
thus allowing the network to learn both fine-grained and higher-level features.
In addition, the introduction of the self-attention mechanism behind the con-
volutional layers can highlight important features and improve the recognition
accuracy of the model. Combining the fusion output of multi-scale convolution
with the self-attention mechanism, the data is further passed into BiGRU, which
captures the dependent information from the forward and reverse directions of
the sequence at the same time, removes redundancy, and enriches the feature
information, which enhances the ability of the Inception module to extract the
global temporal features of network intrusion data, and provides an effective
means of identifying the key features in the network intrusion behaviors and
improving the detection accuracy. Finally, the element-wise addition function is
used for feature fusion. This function is based on the principle of item-wise addi-
tion, which can enhance the information content of individual elements while
maintaining the stability of feature dimensions. This feature fusion method can
effectively integrate features from different sources and improve the accuracy and
generalization ability of the model. This method effectively reduces the number
of parameters in the subsequent operations, while ensuring the information den-
sity and computational efficiency of the output features.
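The parallel multi-scale branches and the element-wise additive fusion can be
illustrated with a single-channel NumPy toy; the real model operates on
multi-channel feature maps with learned kernels, self-attention, and BiGRU,
none of which are reproduced in this sketch.

```python
import numpy as np

def conv1d_same(x, kernel):
    """1-D convolution with zero padding so the output length equals len(x)."""
    pad = len(kernel) // 2
    return np.convolve(np.pad(x, pad), kernel, mode="valid")[:len(x)]

def multiscale_fuse(x, kernels):
    """Run parallel 1-D convolutions at several kernel scales, apply a ReLU
    per branch, and fuse the branch outputs by element-wise addition so the
    feature dimension stays fixed -- the additive-fusion idea of the text."""
    out = np.zeros_like(x, dtype=float)
    for k in kernels:
        out += np.maximum(conv1d_same(x, np.asarray(k, dtype=float)), 0.0)
    return out

# Two identity-like branches (kernel sizes 1 and 3): for non-negative input
# the fused output is simply the element-wise sum of the two branches.
x = np.arange(5.0)
fused = multiscale_fuse(x, [[1.0], [0.0, 1.0, 0.0]])
```

Because fusion is additive rather than concatenative, the number of features
passed to the subsequent BiGRU stage does not grow with the number of branches,
which matches the parameter-saving argument made above.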
4 Experiments
4.1 Configuration
Ubuntu is adopted as the operating system, and our method is implemented with
the PyTorch framework. To comprehensively evaluate the performance of the
proposed model, Accuracy, Precision, Recall, F1-Score, the Confusion Matrix,
and the Training Loss Curve are used as evaluation metrics.
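For reference, accuracy and macro-averaged precision, recall, and F1 can all be
computed directly from a multi-class confusion matrix; this helper is a generic
sketch, not the authors' evaluation code.

```python
import numpy as np

def metrics_from_confusion(cm):
    """Accuracy plus macro-averaged precision, recall, and F1 from a
    multi-class confusion matrix (rows = true class, columns = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)                      # true positives per class
    col = cm.sum(axis=0)                  # predicted-positive counts
    row = cm.sum(axis=1)                  # actual-positive counts
    precision = tp / np.where(col == 0, 1.0, col)
    recall = tp / np.where(row == 0, 1.0, row)
    den = precision + recall
    f1 = np.where(den == 0, 0.0,
                  2 * precision * recall / np.where(den == 0, 1.0, den))
    return {
        "accuracy": tp.sum() / cm.sum(),
        "precision": precision.mean(),
        "recall": recall.mean(),
        "f1": f1.mean(),
    }

# Toy 2-class example (binary normal-vs-attack case).
scores = metrics_from_confusion([[40, 10], [20, 30]])
```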
and attack data in network traffic. Considering that the amount of normal and
attack traffic data in the binary classification task is comparable and there is
no need to rely excessively on data enhancement techniques, this experiment
decides not to use the improved ACGAN method for data enhancement. As
shown in Fig. 3, from which it can be seen that the detection rate for the attack
data is extremely high, where the left figure of the confusion matrix based on the
CIC-IDS-2017 dataset reveals the accuracy and the false alarm rate of the model
in identifying various types of network intrusion behaviors, and the right figure
demonstrates the model’s performance on the CIC-DDoS-2019 dataset. With
these confusion matrices, researchers can visually assess the model’s ability to
recognize different types of attacks, thus providing a comprehensive analysis of
the model’s accuracy and reliability.
Fig. 4. The top row shows the comparison of training loss curves for CIC-IDS-2017
dataset at learning rates of 0.01, 0.001 and 0.0001, and the bottom row shows the
comparison of training loss curves for CIC-DDoS-2019 dataset at learning rates of
0.01, 0.001 and 0.0001.
Fig. 5. The top row shows the comparison of training loss curves for CIC-IDS-2017
dataset at learning rates of 0.01, 0.001 and 0.0001, and the bottom row shows the
comparison of training loss curves for CIC-DDoS-2019 dataset at learning rates of
0.01, 0.001 and 0.0001.
5 Conclusion
To address the challenges of security in current network environments as well as
the problem of data distribution imbalance, we propose an improved ACGAN
model that realizes enhanced functionality for minority class data, and a novel
and efficient intrusion detection framework that enables binary and multi-
classification detection of cyber-attack data and maintains a high accuracy
rate. The problem of data imbalance with high-dimensional features is particu-
larly prominent, which largely affects the efficacy of intrusion detection models
because the models tend to favor the majority class, thus ignoring the less com-
mon but usually more important minority class samples. We propose an imbal-
References
1. Cao, B., Li, C., Song, Y., Qin, Y., Chen, C.: Network intrusion detection model
based on CNN and GRU. Appl. Sci. 4184 (2022). https://doi.org/10.3390/app12094184
2. Goodfellow, I., et al.: Generative adversarial nets. J. Jpn. Soc. Fuzzy Theory Intell.
Inform. 177 (2017)
3. Henderson, N., Min, W., Rowe, J., Lester, J.: Multimodal player affect modeling
with auxiliary classifier generative adversarial networks. In: Proceedings of the
AAAI Conference on Artificial Intelligence and Interactive Digital Entertainment,
vol. 16, no. 1, pp. 224–230 (October 2022). https://doi.org/10.1609/aiide.v16i1.
7434
4. Jothi, B., Pushpalatha, M.: WILS-TRS — a novel optimized deep learning based
intrusion detection framework for IoT networks. Pers. Ubiquitous Comput. 27(3),
1285–1301 (2023). https://doi.org/10.1007/s00779-021-01578-5
5. Lan, J., Liu, X., Li, B., Zhao, J.: A novel hierarchical attention-based triplet
network with unsupervised domain adaptation for network intrusion detection.
Appl. Intell. 11705–11726 (2023). https://doi.org/10.1007/s10489-022-04076-0
6. Liao, J.: An intrusion detection model based on improved ACGAN in big data
environment. Secur. Commun. Netw. 2022, 1–9 (2022). https://doi.org/10.1155/2022/6821174
7. Liu, G., Zhang, J.: CNID: research of network intrusion detection based on
convolutional neural network. Discret. Dyn. Nat. Soc. 1–11 (2020). https://doi.org/10.1155/2020/4705982
8. Liu, R., Ji, C., Niu, J., Guo, B.: Research on intrusion detection method based
on 1D-ICNN-BiGRU. J. Phys.: Conf. Ser. 2347(1), 012001 (2022). https://doi.org/10.1088/1742-6596/2347/1/012001
9. Lu, Y., Cheung, Y.M., Tang, Y.Y.: Hybrid sampling with bagging for class
imbalance learning, pp. 14–26 (2016)
10. Luo, S., Zhao, Z., Hu, Q., Liu, Y.: A hierarchical CNN-Transformer model for
network intrusion detection. In: 2nd International Conference on Applied
Mathematics, Modelling, and Intelligent Computing (CAMMIC 2022), p. 278 (2022).
https://doi.org/10.1117/12.2639876
11. Lyu, J., Young Lee, H., Liu, H.: Color matching generation algorithm for
animation characters based on convolutional neural network. Comput. Intell.
Neurosci. 1–13 (2022). https://doi.org/10.1155/2022/3146488
A Novel and Efficient Multi-scale Spatio-Temporal Residual Network 283
12. Mozo, A., González-Prieto, A., Pastor, A., Gomez-Canaval, S., Talavera, E.:
Synthetic flow-based cryptomining attack generation through generative
adversarial networks. Sci. Rep. (2022). https://doi.org/10.1038/s41598-022-06057-2
13. Peng, W., Kong, X., Peng, G., Li, X., Wang, Z.: Network intrusion detection
based on deep learning. In: 2019 International Conference on Communications,
Information System and Computer Engineering (2019). https://doi.org/10.1109/cisce.2019.00102
14. Saviour, M.P.A., Samiappan, D.: IPFS based storage authentication and access
control model with optimization enabled deep learning for intrusion detection.
Adv. Eng. Softw. 176, 103369 (2023). https://doi.org/10.1016/j.advengsoft.2022.103369
15. Shamsolmoali, P., Zareapoor, M., Shen, L., Sadka, A., Yang, J.: Imbalanced data
learning by minority class augmentation using capsule adversarial networks. Neu-
rocomputing (2020)
16. Shao, S., Wang, P., Yan, R.: Generative adversarial networks for data
augmentation in machine fault diagnosis. Comput. Ind. 85–93 (2019). https://doi.org/10.1016/j.compind.2019.01.001
17. Shen, W., et al.: Boundary sampling to boost mutation testing for deep learning
models. Inf. Softw. Technol. 106413 (2021). https://doi.org/10.1016/j.infsof.2020.106413
18. Shi, G., et al.: Knowledge-guided synthetic medical image adversarial
augmentation for ultrasonography thyroid nodule classification. Comput. Methods
Programs Biomed. 105611 (2020). https://doi.org/10.1016/j.cmpb.2020.105611
19. Su, T., Sun, H., Zhu, J., Wang, S., Li, Y.: BAT: deep learning methods on
network intrusion detection using NSL-KDD dataset. IEEE Access 29575–29585
(2020). https://doi.org/10.1109/access.2020.2972627
20. Wang, X., Wang, Q.: An abnormal traffic detection method using
GCN-BiLSTM-attention in the Internet of Vehicles environment. EURASIP J. Wirel.
Commun. Netw. (2023). https://doi.org/10.1186/s13638-023-02274-z
21. Wei, J., Chen, Y., Lai, Y., Wang, Y., Zhang, Z.: Domain adversarial neural
network-based intrusion detection system for in-vehicle network variant attacks.
IEEE Commun. Lett. 26(11), 2547–2551 (2022). https://doi.org/10.1109/lcomm.2022.3195486
22. Yin, X., Chen, L.: Network intrusion detection method based on multi-scale cnn
in internet of things. Mob. Inf. Syst. 1–8 (2022). https://doi.org/10.1155/2022/
8124831, http://dx.doi.org/10.1155/2022/8124831
23. Yu, H., Kang, C., Xiao, Y., Yang, Y.: Network intrusion detection method based
on hybrid improved residual network blocks and bidirectional gated recurrent
units. IEEE Access 11, 68961–68971 (2023). https://doi.org/10.1109/access.2023.
3271866
Provable Data Auditing Scheme
from Trusted Execution Environment
1 Introduction
Over the past few years, demand for storage has grown significantly with the
explosive growth of user data and the widespread application of artificial
intelligence across fields. As a distributed, Internet-based data storage
service, cloud storage allows users to store data, files, applications, etc.
in remote data centers without relying on local hard disks or physical storage
devices, making it an increasingly popular choice for sharing data. According
to Huawei's forecast, by 2030 the total amount of data generated worldwide
will reach 1 YB per year, ushering in the YB data era; global general-purpose
storage capacity will reach 37 ZB, of which AI-related storage capacity will
account for 63%.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 284–298, 2025.
https://doi.org/10.1007/978-981-96-4566-4_20
However, once users transfer their data to the cloud, they no longer have
control over it. The management practices employed by cloud service providers
are often opaque, making it difficult for users to place complete trust in them. To
reduce storage requirements and costs, these providers may overlook certain user
data or delete files that are infrequently accessed [1]. Furthermore, cloud servers
are vulnerable to attacks, hardware or software failures, and other unforeseen
issues, which could lead to the corruption or loss of data stored in the cloud
[2]. In order to protect their reputation, cloud service providers may opt not to
disclose such incidents to users.
To mitigate these potential risks, one solution is to retrieve the stored infor-
mation from the cloud for integrity verification. However, given the vast amount
of data, this approach is impractical due to the significant bandwidth and local
storage space it would require. The provable data possession (PDP) scheme [3]
was introduced to provide a probabilistic proof for the integrity of files stored
by third parties. This scheme enables auditors to efficiently assess the integrity
of the data without the need to download it, offering benefits such as reduced
bandwidth usage and enhanced data security.
Machine learning models must be trained on large-scale data to construct
mathematical and statistical models that discover patterns in the data and
make predictions, inferences, or classifications. During training, especially
in scenarios involving sensitive data such as healthcare and finance,
maintaining data security and privacy is an important issue. Meanwhile, in
distributed training scenarios, data are usually held by multiple nodes, and
ensuring the integrity of these data without transmitting them is key to the
authenticity and effectiveness of training. As a cryptographic technique, a
provable data possession scheme can verify the integrity and authenticity of
the data relied on for training, of the key parameters produced by training,
and of deployed models, improving the credibility of machine learning models.
In earlier PDP scheme frameworks, the client was responsible for generating
keys, computing verification metadata and tags, and then transmitting the tags
and files to the cloud server. During the challenge phase, the user would request
a verification of the data stored on the cloud server, which would then generate
proof of possession in response to the challenge. Finally, the client would confirm
the integrity of the data using the proof provided by the cloud server. In more
recent PDP frameworks, the role of a third-party auditor (TPA) was introduced,
with the TPA taking over the tasks of sending challenges to the cloud server and
validating the proof received from the server. However, there is a risk that the
TPA could misuse user data or collaborate with the cloud server to hide data
loss [4].
Therefore, we propose replacing the TPA with a trusted execution environment
(TEE) to prevent data leakage and attacks. A TEE is an isolation technology
that combines software and hardware to ensure that data and program execution
are secure and reliable. It provides a protected environment for running
sensitive code and processing sensitive data, which neither external attackers
nor the operating system can access without authorization.
We chose Intel SGX as our experimental platform. As a typical TEE
implementation, Intel SGX provides strong security protection across many
fields. SGX builds a trusted execution environment called an Enclave. When the
Enclave is started, hardware and software jointly verify that the Enclave code
has not been tampered with. Enclave memory is encrypted by the Memory
Encryption Engine (MEE) so that only authorized code can read it, and code
inside the Enclave runs in an environment completely isolated from the
operating system and other applications: neither the operating system nor a
hypervisor can access Enclave memory or code. SGX also provides a remote
attestation mechanism that allows users to verify that the Enclave's code and
data are secure and have not been tampered with.
The main contributions of our paper include the following:
2 Related Work
aggregate signature technology [1]. In order to meet the requirement that cus-
tomers can check remote data ownership in a public cloud environment, Wang
et al. introduced an authorized proxy, so that the client can delegate the task of
integrity verification to a trusted proxy [6].
Amidst the growth of cloud computing, a large amount of stored data needs
to be shared securely and efficiently [7, 8]. Therefore, PDP schemes have gradu-
ally begun to focus on enhancing efficiency, security, and privacy. Shen et al. [9]
proposed a global and sampled verification and batch audit protocol that sup-
ports cloud data. The protocol adopts a novel dynamic structure consisting of
a bidirectionally linked information table and a position array. To prevent data
leakage to third-party auditors, Ni et al. proposed an identity-based privacy-
preserving provable data possession scheme [10]. The scheme leverages the RSA
assumption to allow third-party auditors to verify the integrity of outsourced
files through the validation of homomorphic authentication tags, while also sup-
porting aggregate verification across multiple users.
To address the security challenges of existing certificate-independent PDP
schemes in the random oracle model and eliminate the risk of TPAs retrieving
user data blocks, the certificateless provable data possession scheme was intro-
duced. The CL-PDP framework proposed by Deng et al. is specifically designed
for cloud storage and provides provable security within the standard model [11].
Meanwhile, Shen et al.’s CL-PDP scheme supports multi-replica and multi-server
setups, introducing a novel data structure called the mapping version mark table
(MVMT) to facilitate dynamic block-level operations and ensure data traceabil-
ity [12].
With the development of blockchain technology, researchers have found that
the integration of PDP schemes and blockchain technology is a very promising
field. The immutability and transparency of blockchain provide a decentralized
platform for data integrity verification. Wang et al. introduced a decentralized
cloud storage platform SStore, which uses the immutability of blockchain to
securely store file tags, thereby guaranteeing the safety and dependability of the
cloud storage data integrity verification process [2]. Miao et al. introduced a
certificateless multi-copy public audit protocol based on blockchain technology
to achieve decentralized multi-copy data possession verification and fault loca-
tion, while improving the dependability of fault detection and the precision of
decision-making through the use of smart contracts [13]. Wang et al. introduced
a distributed provable data possession scheme based on blockchain technology,
which ensures synchronization and enhances security in multi-cloud storage envi-
ronments. This scheme also facilitates fault detection by utilizing smart contracts
and binary search techniques [14].
3 Preliminaries
Symbols and descriptions
p, q: two safe primes
N: RSA modulus
Z∗_N: all elements with a multiplicative inverse modulo N
QR_N: set of quadratic residues modulo N
ϕ(N): Euler's totient function
e: integer coprime with ϕ(N)
d: modular inverse of e
g: generator of QR_N
v: secret random integer
pk: public key
sk: private key
f: a pseudo-random function
π: a pseudo-random permutation
h(), H(): cryptographic hash functions
F: the stored file
b: unique identifier for the user
n: total number of blocks uploaded to the cloud server
m_i: the i-th file block
t_i: timestamp of the i-th file block
ID: unique identifier for the file
id_i: unique identifier of the i-th file block
T_{ID,i}: tag of the i-th file block of a file
c: number of challenged file blocks
chal: challenge
k_1, k_2: two pseudo-random numbers
W: log files
• Enclave. An Enclave occupies a protected region in the computer's memory
that is logically separated from the normal process address space. This
memory area is controlled by the SGX hardware extensions while the process is
running and can only be accessed through special instructions. The code and
data in an SGX Enclave are protected; no other process, including the
operating system kernel, can directly access or tamper with the contents.
• Attestation. The goal of remote attestation is to create a proof that includes
measurements of the software state, which is then signed with a certificate key
stored in the hardware. The remote verifier validates this proof by verifying
both the signature and the measurements.
To ensure the integrity of data stored on the cloud server, and to prevent
TPAs from leaking data or colluding with cloud service providers, we introduce
a provable data possession scheme that leverages the trusted execution
environment. The system architecture of the scheme is illustrated in Fig. 1.
There are three main participants in the system: the user, the cloud server,
and the trusted execution environment. When a user needs to outsource a data
file to cloud storage, he first slices the large file into smaller chunks of
fixed size. The user computes the homomorphic verifiable tag of the file based on
its public and private keys and uploads the file, the public key, and the file
tags to the cloud server. A secure transmission channel is built between the
user and the TEE, over which the user transmits his private key to the TEE for
authentication. The TEE, acting as a trusted auditor, periodically sends a
challenge to the cloud server. On receiving the challenge, the cloud server
generates a proof based on the data in its possession and returns it to the
TEE. The TEE verifies the proof and generates a verification log file: if
validation passes, the challenged file blocks are intact; if it fails, they
are corrupted or missing. The TEE digitally signs the generated log file and
provides regular feedback to the user on the integrity of the data in the
cloud.
The specific description of the PDP scheme we proposed above is as follows:
KeyGen: The client runs a probabilistic key generation algorithm, taking a secu-
rity parameter as input, to obtain a public-private key pair (pk, sk).
KeyTran: A secure transmission channel is established between the client and
TEE through identity authentication mechanism, key exchange protocol and
encryption technology to secretly transmit the client’s private key to TEE.
TagBlock : The client takes the file block and public and private keys as input to
generate verification metadata.
GenProof : The cloud server takes the public key, the stored file block, and the
challenge as input and generates a proof of possession.
CheckProof : TEE runs to verify the possession certificate sent by the cloud
server, using the public and private keys, challenge, and possession certificate as
input to determine whether the verification is successful.
Sig: The TEE digitally signs the log file and returns the log file together
with the signature pair to the client over an untrusted channel. The client
verifies the signature to confirm the source and integrity of the information.
The specific algorithm and description of our proposed provable data possession
scheme based on a trusted execution environment are provided below.
We use b as a unique identifier of a client. Suppose the user wants to
transmit a file F to cloud storage; the user first divides F into n file
blocks, with F = {m_i}_{1≤i≤n} denoting the ordered set of n file blocks. When
dividing F into blocks m_i, a timestamp t_i (1 ≤ i ≤ n) and a unique
identifier id_i (1 ≤ i ≤ n) are generated on the client for each newly created
block m_i, which aids file-block management and data tracing. In large-volume
data scenarios, many files are uploaded rather than a single one, so to avoid
conflicts each file F also carries a unique file identifier, denoted ID.
In addition, we define h to be a secure hash function, and we also use a
pseudo-random function (PRF) f and a pseudo-random permutation (PRP) π.
They are described as follows:
• Setup: Let KeyGen(K) → (p, q, N, e, d), where gcd denotes the greatest
common divisor of two integers. Let a be a random integer in Z∗_N satisfying
gcd(a ± 1, N) = 1. The generator g is an element of QR_N, computed as
g = a² mod N. Generate a secret random number v ∈ Z_N.
• KeyGen: The user generates a public key pk = (N, g) and a private key
sk = (e, d, v), and sends the public key pk = (N, g) to the server and TEE.
• KeyTran: A secure transmission channel is established between the user and
TEE, and the user sends his private key to TEE.
1) After the user and TEE authenticate each other through the identity
authentication mechanism, they use the key negotiation protocol to gen-
erate a shared key K. Calculate AKE (θ, ω) → K, where the AKE func-
tion represents the authentication key exchange protocol, θ represents
the identity information verification parameter, and ω represents the key
negotiation protocol parameter.
2) Let SKA be a symmetric key algorithm. The user calculates SKA (K, sk)
to encrypt the private key sk. After receiving it, TEE decrypts it to obtain
the user’s private key.
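As an illustration of step 2, the sketch below encrypts sk under the shared key K with a symmetric-key algorithm. A SHA-256 counter keystream stands in for a real authenticated cipher such as AES-GCM; the names (keystream, ska) and the sample key derivation are our own illustrative choices, not the paper's.

```python
import hashlib

# Illustrative SKA(K, sk): a SHA-256 counter keystream stands in for a real
# authenticated cipher such as AES-GCM (do not use this construction as-is).
def keystream(K: bytes, length: int) -> bytes:
    out = b""
    counter = 0
    while len(out) < length:
        out += hashlib.sha256(K + counter.to_bytes(8, "big")).digest()
        counter += 1
    return out[:length]

def ska(K: bytes, data: bytes) -> bytes:
    # XOR stream cipher: the same call encrypts and decrypts.
    return bytes(x ^ y for x, y in zip(data, keystream(K, len(data))))

K = hashlib.sha256(b"shared secret from AKE").digest()  # K from AKE(theta, omega)
sk = b"user-private-key"
ct = ska(K, sk)              # user encrypts sk
assert ska(K, ct) == sk      # TEE decrypts and recovers sk
```

Because the keystream depends only on K and a counter, decryption is the same XOR operation applied to the ciphertext.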
mod N.
3) Calculate ρ = Σ_{j=1}^{c} α_j x_{i_j}.
4) Output the proof of possession (T, ρ).
• CheckProof: TEE verifies the possession certificate returned by the cloud
server and generates a log file containing the verification process and results.
1) For 1 ≤ j ≤ c, compute i_j = π_{k_1}(j), α_j = f_{k_2}(j), and
C_{i_j} = v||b||id_{i_j}; let μ = ∏_{j=1}^{c} h(T^e, C_{i_j})^{α_j} mod N.
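Since parts of the tag and proof formulas above are garbled by extraction, the following toy sketch follows the classic RSA-based construction of Ateniese et al. [3] on which such schemes build; all parameters, helper names, and key sizes are illustrative only and far too small for real use.

```python
import hashlib

# Toy RSA-based PDP sketch in the style of Ateniese et al. [3].
p, q = 10007, 10009          # stand-ins for two safe primes
N = p * q                    # RSA modulus
phi = (p - 1) * (q - 1)
e = 65537
d = pow(e, -1, phi)
a = 2                        # satisfies gcd(a - 1, N) = gcd(a + 1, N) = 1
g = pow(a, 2, N)             # generator drawn from QR_N
v, b = 12345, 7              # secret random value v and user identifier b

def h(data: str) -> int:
    return int.from_bytes(hashlib.sha256(data.encode()).digest(), "big") % N

def tag_block(m_i: int, id_i: str) -> int:
    # T_i = (h(v||b||id_i) * g^{m_i})^d mod N
    return pow(h(f"{v}|{b}|{id_i}") * pow(g, m_i, N) % N, d, N)

def gen_proof(blocks, tags, challenge):
    # challenge: list of (block index i_j, coefficient alpha_j) pairs
    T, rho = 1, 0
    for i, alpha in challenge:
        T = T * pow(tags[i], alpha, N) % N   # T = prod T_{i_j}^{alpha_j}
        rho += alpha * blocks[i]             # rho = sum alpha_j * m_{i_j}
    return T, rho

def check_proof(ids, challenge, T, rho):
    # Accept iff T^e == g^rho * prod_j h(v||b||id_{i_j})^{alpha_j} (mod N)
    rhs = pow(g, rho, N)
    for i, alpha in challenge:
        rhs = rhs * pow(h(f"{v}|{b}|{ids[i]}"), alpha, N) % N
    return pow(T, e, N) == rhs

blocks = [5, 9, 13]
ids = ["id0", "id1", "id2"]
tags = [tag_block(m, i) for m, i in zip(blocks, ids)]
chal = [(0, 3), (2, 2)]
print(check_proof(ids, chal, *gen_proof(blocks, tags, chal)))  # True
```

The check works because T^e strips the private exponent d from each tag, leaving g raised to exactly the linear combination ρ of challenged blocks.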
Algorithm 1. ECDSA
Input: log file W, private key d, curve parameters q, a, b, n, G
Output: signature (r, s)
Steps:
1: k ←$ {1, 2, · · · , n − 1}
2: P = (x, y) = kG
3: r = x mod n; if r = 0, return to step 1
4: t = k⁻¹ mod n
5: e = H(W)
6: s = t · (e + d · r) mod n; if s = 0, return to step 1
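The signing steps of Algorithm 1 can be exercised end to end on a textbook-sized curve. The curve below (y² = x³ + 2x + 2 over F₁₇ with base point (5, 1) of prime order 19) and the verification routine are purely illustrative, standard ECDSA rather than anything taken from the paper, and the parameters are far too small for real use.

```python
import hashlib
import secrets

# Toy textbook curve: y^2 = x^3 + 2x + 2 over F_17, base point G = (5, 1)
# of prime order 19. Illustrative only; never use such small parameters.
P_MOD, A, B = 17, 2, 2
G = (5, 1)
N_ORD = 19

def inv(x, m):
    return pow(x, -1, m)

def point_add(P, Q):
    if P is None:
        return Q
    if Q is None:
        return P
    (x1, y1), (x2, y2) = P, Q
    if x1 == x2 and (y1 + y2) % P_MOD == 0:
        return None                      # point at infinity
    if P == Q:
        lam = (3 * x1 * x1 + A) * inv(2 * y1, P_MOD) % P_MOD
    else:
        lam = (y2 - y1) * inv(x2 - x1, P_MOD) % P_MOD
    x3 = (lam * lam - x1 - x2) % P_MOD
    return (x3, (lam * (x1 - x3) - y1) % P_MOD)

def scalar_mul(k, P):
    # Double-and-add scalar multiplication.
    R = None
    while k:
        if k & 1:
            R = point_add(R, P)
        P = point_add(P, P)
        k >>= 1
    return R

def hash_to_int(msg: bytes) -> int:
    return int.from_bytes(hashlib.sha256(msg).digest(), "big") % N_ORD

def sign(d, msg):
    e = hash_to_int(msg)                 # e = H(W)
    while True:
        k = secrets.randbelow(N_ORD - 1) + 1
        pt = scalar_mul(k, G)            # P = kG
        r = pt[0] % N_ORD
        if r == 0:
            continue                     # retry, as in step 3
        s = inv(k, N_ORD) * (e + d * r) % N_ORD
        if s != 0:
            return (r, s)                # retry if s = 0, as in step 6

def verify(Q, msg, sig):
    r, s = sig
    e = hash_to_int(msg)
    w = inv(s, N_ORD)
    pt = point_add(scalar_mul(e * w % N_ORD, G), scalar_mul(r * w % N_ORD, Q))
    return pt is not None and pt[0] % N_ORD == r

d = 7                        # private key
Q = scalar_mul(d, G)         # public key
sig = sign(d, b"log file W")
print(verify(Q, b"log file W", sig))  # True
```

The client-side check in SigVer is exactly the verify routine: two scalar multiplications plus modular arithmetic, matching the cost accounting in Sect. 5.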
5 Experiments
In this section, we analyze the computation overhead of our proposed PDP
scheme. We focus on five stages, TagBlock, GenProof, CheckProof, TEESG, and
SigVer, because they are time-consuming and frequently executed. Since hash,
addition, and modular-remainder operations have short execution times and
minimal cost compared with the other operations, we ignore their computational
costs. To simplify the expressions, we write T_MExp for the cost of a modular
exponentiation, T_MMul for a modular multiplication, T_MRem for a modular
remainder, T_MInv for a modular inversion, and T_SMul for a scalar
multiplication in the elliptic curve digital signature algorithm.
Before calculation, a file F is divided into n equal-sized file blocks m, and
c denotes the number of challenged file blocks. The computation overhead
described below covers one complete run of the system. The TagBlock stage
performs 2n modular exponentiations and n modular multiplications, for a cost
of 2n·T_MExp + n·T_MMul. The GenProof stage performs c modular exponentiations
and 2c − 1 modular multiplications, for a cost of c·T_MExp + (2c − 1)·T_MMul.
The CheckProof stage performs c + 2 modular exponentiations and c modular
multiplications, for a cost of (c + 2)·T_MExp + c·T_MMul. The TEESG stage
performs 1 scalar multiplication, 1 modular inversion, and 2 modular
multiplications, for a cost of T_SMul + T_MInv + 2·T_MMul. The SigVer stage
performs 2 scalar multiplications and 2 modular multiplications, for a cost of
2·T_SMul + 2·T_MMul (Fig. 3).
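The per-stage formulas above can be tallied mechanically. A small helper, where the operation timings are placeholder arguments rather than measured values:

```python
# Per-stage overhead formulas from the analysis above; t_mexp, t_mmul,
# t_minv, and t_smul are placeholder unit costs, not measured timings.
def stage_costs(n, c, t_mexp, t_mmul, t_minv, t_smul):
    return {
        "TagBlock":   2 * n * t_mexp + n * t_mmul,
        "GenProof":   c * t_mexp + (2 * c - 1) * t_mmul,
        "CheckProof": (c + 2) * t_mexp + c * t_mmul,
        "TEESG":      t_smul + t_minv + 2 * t_mmul,
        "SigVer":     2 * t_smul + 2 * t_mmul,
    }

# With unit costs, n = 1000 blocks and c = 460 challenged blocks:
print(stage_costs(1000, 460, 1, 1, 1, 1))
```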
We use our personal host as the experimental platform to deploy the PDP
solution. The hardware and software configuration of the host includes Ubuntu
22.04 64-bit operating system, 32 GB physical memory, and Intel i5-12600KF
3.7 GHz CPU. At the same time, this experiment also uses the PBC library and
the GMP library to test the computation overhead of the PDP solution.
Regardless of file size, when 1% of the data is lost or damaged, the loss can
be detected with probability 99% or 95% by randomly sampling 460 or 300 file
blocks, respectively [3]. Therefore, in this experiment we consider only the
case where 1% of the data is damaged or lost. We selected a machine learning
training dataset, divided it into several 1 GB files, and used them as the
files uploaded by users to the cloud server for storage.
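The sampling numbers follow from the detection probability P = 1 − (1 − f)^c for a corruption fraction f and c challenged blocks, as analyzed in [3]; a quick check:

```python
# Probability of hitting at least one corrupted block when challenging c
# blocks of a database with corruption fraction f: P = 1 - (1 - f)^c.
def detection_probability(f: float, c: int) -> float:
    return 1 - (1 - f) ** c

print(round(detection_probability(0.01, 460), 3))  # ~0.99
print(round(detection_probability(0.01, 300), 3))  # ~0.951
```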
Initially, we aim to determine the optimal size for the file blocks. The execu-
tion of TagBlock, GenProof, and CheckProof stages may be affected by the file
block size. To determine the optimal file block size, we challenge 460 file blocks
each time and calculate the sum of execution time of different file block sizes in
the three stages.
Figure 2, drawn from our experimental results, shows that when the file block
size is at least 16 KB, the total computation overhead of the TagBlock,
GenProof, and CheckProof stages remains essentially unchanged: the curve
flattens noticeably and trends toward stability. We therefore choose 16 KB as
the file block size for our experiments.
According to the determined file block size, we experimentally obtain the
computation overhead of our PDP scheme when challenging 300 file blocks and
460 file blocks. Since TagBlock is not affected by the number of challenged file
blocks and its computation overhead is much larger than that of the other four
stages, the following figure does not show the experimental results of this stage.
The computation overhead of each stage in the experiment is shown in Table 2.
Compared with the experimental results of paper [3], the computation overhead
of our PDP scheme is reasonable. Moreover, the challenge, CheckProof, and
TEESG stages have low algorithmic complexity and small computation overhead,
so they can be ported into the TEE for execution.
6 Conclusions
To eliminate the hidden danger of TPA stealing user data or conspiring with
cloud servers, we replace TPA with TEE and propose a data integrity auditing
scheme based on trusted execution environment. The scheme effectively improves
data security on the basis that data integrity can be verified. We use Intel SGX
technology to port the relevant cryptographic libraries to the Enclave and deploy
it in a real cloud server scenario for our PDP scheme experiments. The experi-
mental results show that the performance overhead incurred by TEE instead of
TPA is reasonable and can support applications in cloud storage environments.
References
1. Wang, Q., Wang, C., Ren, K., Lou, W., Li, J.: Enabling public auditability and data
dynamics for storage security in cloud computing. IEEE Trans. Parallel Distrib.
Syst. 22(5), 847–859 (2011)
2. Wang, L., Hu, M., Jia, Z., Guan, Z., Chen, Z.: SStore: an efficient and secure
provable data auditing platform for cloud. IEEE Trans. Inf. Forensics Secur. 19,
4572–4584 (2024)
3. Ateniese, G., et al.: Provable data possession at untrusted stores. In: Proceedings of
the 14th ACM Conference on Computer and Communications Security, ser. CCS
’07. New York, NY, USA: Association for Computing Machinery, pp. 598–609,
2007. https://doi.org/10.1145/1315245.1315318
4. He, Y., Xu, Y., Jia, X., Zhang, S., Liu, P., Chang, S.: EnclavePDP: a general
framework to verify data integrity in cloud using intel SGX. In: 23rd International
Symposium on Research in Attacks, Intrusions and Defenses (RAID 2020). San
Sebastian: USENIX Association, pp. 195–208, October 2020. https://www.usenix.
org/conference/raid2020/presentation/he
5. Zhu, Y., Hu, H., Ahn, G.-J., Yu, M.: Cooperative provable data possession for
integrity verification in multicloud storage. IEEE Trans. Parallel Distrib. Syst.
23(12), 2231–2244 (2012)
6. Wang, H.: Proxy provable data possession in public clouds. IEEE Trans. Serv.
Comput. 6(4), 551–559 (2013)
7. Shen, J., Zhou, T., Chen, X., Li, J., Susilo, W.: Anonymous and traceable group
data sharing in cloud computing. IEEE Trans. Inf. Forensics Secur. 13(4), 912–925
(2018)
8. Shen, J., Yang, H., Vijayakumar, P., Kumar, N.: A privacy-preserving and untrace-
able group data sharing scheme in cloud computing. IEEE Trans. Dependable
Secure Comput. 19(4), 2198–2210 (2022)
9. Shen, J., Shen, J., Chen, X., Huang, X., Susilo, W.: An efficient public auditing
protocol with novel dynamic structure for cloud data. IEEE Trans. Inf. Forensics
Secur. 12(10), 2402–2415 (2017)
10. Ni, J., Zhang, K., Yu, Y., Yang, T.: Identity-based provable data possession from
RSA assumption for secure cloud storage. IEEE Trans. Dependable Secure Comput.
19(3), 1753–1769 (2022)
11. Deng, L., Wang, B., Wang, T., Feng, S., Li, S.: Certificateless provable data pos-
session scheme with provable security in the standard model suitable for cloud
storage. IEEE Trans. Serv. Comput. 16(6), 3986–3998 (2023)
12. Shen, J., Zeng, P., Choo, K.-K.R., Li, C.: A certificateless provable data posses-
sion scheme for cloud-based EHRs. IEEE Trans. Inf. Forensics Secur. 18, 1156–1168
(2023)
13. Miao, Y., Huang, Q., Xiao, M., Susilo, W.: Blockchain assisted multi-copy provable
data possession with faults localization in multi-cloud storage. IEEE Trans. Inf.
Forensics Secur. 17, 3663–3676 (2022)
14. Wang, H., Wan, Z., He, D., Yu, J.: Synchronous blockchain-based distributed prov-
able data possession with forward-security. IEEE Trans. Serv. Comput. 17(3),
1227–1238 (2024)
15. Shen, J., Zhou, T., He, D., Zhang, Y., Sun, X., Xiang, Y.: Block design-based key
agreement for group data sharing in cloud computing. IEEE Trans. Dependable
Secure Comput. 16(6), 996–1010 (2019)
Enhanced PIR Scheme Combining
SimplePIR and Spiral: Achieving Higher
Throughput Without Client Hints
1 Introduction
Private Information Retrieval (PIR) [7] is a cryptographic protocol that allows
a user to retrieve an item from a database server without revealing which item
is being retrieved. The core idea is to protect the privacy of the user’s query
from the database owner. PIR schemes are particularly useful in scenarios where
users need to access data without disclosing their interests, thus preserving their
privacy. In the realm of privacy-preserving machine learning, the table-lookup
method is a possible approach for nonlinear computations, and thus falls
within the application scope of PIR.
Recently, researchers have proposed SimplePIR [12], a PIR protocol with
extremely high throughput on the server side, and provided an example of its
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 299–315, 2025.
https://doi.org/10.1007/978-981-96-4566-4_21
1.2 Contributions
The contributions of our work align with those of previous studies in [11,14,19]
in that we aim to eliminate the client-side download step required in the design
of SimplePIR, achieving efficient PIR without user downloads. However, our
approach differs significantly from existing methods. Specifically, we build upon
the Spiral protocol and integrate the design principles of SimplePIR into Spiral,
enabling us to further enhance Spiral’s throughput and eliminate the need for
client-side hints downloads. In detail:
1. We modify the steps in Spiral that retrieve a row of entries from the
database, reducing the server's computation time and consequently improving
the throughput of the PIR scheme.
2. We implement the modified scheme and compare its performance with Dou-
blePIR and the original Spiral on the same platform. Our results confirm
that the modified scheme outperforms Spiral in terms of throughput and,
for frequently updated databases, surpasses DoublePIR in terms of overall
communication overhead.
2 Preliminaries
For a probability distribution χ, we write x ←_R χ to indicate that x is a
random sample from χ. For a finite set S, we write x ←_R S to denote sampling
x uniformly at random from S. Let λ be the security parameter. Let p and q be
the plaintext and the ciphertext modulus, respectively. For a positive integer
m, we use [m] to represent the integer set {1, 2, . . . , m}. For integers
a, b ∈ Z, we write [a, b] to denote the set {a, a + 1, . . . , b}. For a
positive integer q ∈ N, we write Z_q to denote the integers modulo q. We
mainly use bold uppercase letters to denote matrices (e.g., A, B) and bold
lowercase letters to denote vectors (e.g., u, v). Let R = Z[x]/(x^d + 1) be
the cyclotomic ring where d is a power of two; for a positive integer q ∈ N,
we write R_q = R/qR. Let ⌈·⌋ denote rounding to the nearest integer, ⌊·⌋ the
floor operation, and ⌈·⌉ the ceiling operation. Let ||·||_∞ denote the
infinity norm.
C ∈ R_q^{(n+1)×n} is a Regev ciphertext of M ∈ R_p^{n×n} with error
E ∈ R_q^{n×n} and secret key S ∈ R_q^{(n+1)×n} if S^T C = Δ · M + E. To
recover M from Z = S^T C, one computes M′ = ⌈Z/Δ⌋. If
||E||_∞ + (q mod p) < Δ/2, then M′ = M.
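A minimal integer-LWE sketch of this decryption rule, using plain (non-ring, non-matrix) Regev encryption with illustrative and insecure parameters of our own choosing:

```python
import secrets

# Plain-LWE Regev-style encryption: decryption recovers m by the same
# rounding rule M' = round(Z / Delta), valid while the noise is small.
n, q, p = 64, 1 << 16, 256
DELTA = q // p               # Delta = floor(q / p)

def keygen():
    return [secrets.randbelow(q) for _ in range(n)]

def encrypt(s, m):
    a = [secrets.randbelow(q) for _ in range(n)]
    e = secrets.randbelow(7) - 3          # small noise in [-3, 3]
    b = (sum(ai * si for ai, si in zip(a, s)) + e + DELTA * m) % q
    return (a, b)

def decrypt(s, ct):
    a, b = ct
    z = (b - sum(ai * si for ai, si in zip(a, s))) % q   # z = Delta*m + e
    # Rounding z/Delta recovers m while ||e||_inf + (q mod p) < Delta/2.
    return round(z / DELTA) % p

s = keygen()
ct = encrypt(s, 42)
print(decrypt(s, ct))  # 42
```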
– Cres ← Answer(D, qu, Chint): On input the database D, a query qu, and the
hint Chint, the answer algorithm outputs the rescaled ciphertext Cres as the
server response.
– d ← Extract(sk, Cres): On input the secret key sk and the response Cres, the
extract algorithm outputs the database entry d ∈ D indexed by idx.
The PIR scheme should satisfy the following properties of correctness, secu-
rity, and efficiency.
where the probability is taken over all possible sources of randomness employed
by the algorithms.
Definition 5 (Security). For every λ, N ∈ N, and idx0 , idx1 ∈ [N], define the
distribution
The PIR scheme is computationally secure if for every efficient adversary A, the
quantity PIRadv[A](λ, N) is a negligible function of λ and N.
Definition 6 (Communication Efficiency). A PIR scheme is efficient if the
total communication cost of a query is smaller than the size of the database D.
The Spiral scheme relies on the matrix Regev encryption. We observe that matrix
Regev ciphertexts are simulatable. We define an algorithm CipherGen as follows.
If Δ > 2||E||_∞, then M = ⌈Z/Δ⌋. In other words, given an input key S, a
randomly generated Regev ciphertext C allows the derivation of a message M.
A simulated ciphertext should be indistinguishable from a ciphertext gen-
erated by a normal encryption algorithm. To that end, we provide the game
GameA,Ψ (1λ , 1n ):
If for every efficient adversary A there is a negligible function negl such
that, for all λ ∈ N,

Pr[Game_{A,Ψ}(1^λ, 1^n) = 1] ≤ 1/2 + negl(λ),    (4)

where the probability is taken over the randomness used by A and the
randomness used in Game_{A,Ψ}(1^λ, 1^n), then the simulated ciphertexts are
computationally indistinguishable from honestly generated ones.
If the RLWE assumption holds, a Regev ciphertext is indistinguishable from a
uniformly random ciphertext. Assuming the outputs of the PRF are uniformly
distributed, a ciphertext generated by the CipherGen algorithm is likewise
indistinguishable from a uniformly random one. It follows that a Regev
ciphertext is indistinguishable from the random ciphertext generated by the
CipherGen algorithm.
methods are applicable: (i, j_1, . . . , j_{v_2}), where i ∈ [0, 2^{v_1} − 1]
and each j_k ∈ {0, 1} for k ∈ {1, . . . , v_2}; or (i, j), where i ranges
similarly and j ∈ [0, 2^{v_2} − 1]. Each entry d_i in the database belongs to
R_p^{n×n}, ensuring that ||d_i||_∞ ≤ p/2. The
public seed sd can be an arbitrary value, such as a concatenation of the server’s
name and a date, encoded as a bitstring. This seed can be periodically updated
according to a predefined schedule, allowing any user to compute it without
requiring direct communication with the server.
Algorithm 1: HintGen
Input: λ, the security parameter; D, the database; sd , the public seed;
Output: Chint , the result of processing the first (large) dimension of the
database.
1 for i = 0 to 2v1 − 1 do
2 Creg [i] ← CipherGen(1λ , sd ||i);
3 for j = 0 to 2v2 − 1 do
4 Chint [j] = ScalarMul(Creg [0], D[0][j]);
5 for i = 1 to 2v1 − 1 do
6 Chint [j] = Add(Chint [j], ScalarMul(Creg [i], D[i][j]));
7 return Chint ;
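HintGen follows the SimplePIR pattern of folding the whole first dimension of the database against per-row ciphertext material once, offline; the online Answer then only needs one database-vector product. The miniature below reproduces that pattern over plain LWE rather than the ring/matrix-Regev setting of the actual algorithms, and every parameter and helper name is an illustrative assumption.

```python
import secrets

# Plain-LWE miniature of the first-dimension pattern: offline hint = D * A,
# online answer = D * qu, client removes the hint's contribution and rounds.
# Parameters are illustrative and insecure.
ROWS, COLS, N_LWE = 4, 4, 32
q, p = 1 << 20, 16
DELTA = q // p

def mat_mul(X, Y):
    # (|X| x |Y|) * (|Y| x |Y[0]|) matrix product over Z_q.
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) % q
             for j in range(len(Y[0]))] for i in range(len(X))]

def query(A, s, col):
    # qu[c] = A[c]*s + e + Delta * [c == col]: one LWE sample per column.
    qu = []
    for c in range(COLS):
        noise = secrets.randbelow(7) - 3
        val = sum(A[c][k] * s[k] for k in range(N_LWE)) + noise
        qu.append((val + (DELTA if c == col else 0)) % q)
    return qu

def answer(db, qu):
    # Online phase: one database-vector product, as in lines 2-5 of Answer.
    return [sum(db[r][c] * qu[c] for c in range(COLS)) % q for r in range(ROWS)]

def recover(hint, s, ans):
    # Subtract the hint's contribution (cf. line 6 of Answer) and round.
    col = []
    for r in range(ROWS):
        z = (ans[r] - sum(hint[r][k] * s[k] for k in range(N_LWE))) % q
        col.append(round(z / DELTA) % p)
    return col

db = [[(r * COLS + c) % p for c in range(COLS)] for r in range(ROWS)]
A = [[secrets.randbelow(q) for _ in range(N_LWE)] for _ in range(COLS)]
hint = mat_mul(db, A)                     # offline: fold D against A once
s = [secrets.randbelow(q) for _ in range(N_LWE)]
print(recover(hint, s, answer(db, query(A, s, 2))))  # column 2 of db
```

The offline fold is the expensive part and is query-independent, which is exactly why HintGen can be precomputed from the public seed.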
Algorithm 2: Query
Input: λ, the security parameter; sd , the public seed; idx = (i∗ , j1∗ , . . . , jv∗2 ),
the query index.
Output: The query qu = (Zreg , c, ck), a random secret key sk.
1 Sample a random secret key S ← KeyGen(1λ , 1n ) as sk for both matrix Regev
and GSW encryption;
2 for i = 0 to 2v1 − 1 do
3 Creg [i] ← CipherGen(1λ , sd ||i);
4 Zreg [i] ← HalfDecrypt(S, Creg [i]);
5 Zreg [i∗] ← (Zreg [i∗] − Δ/Δ′ · In) mod q/Δ′;
6 Pack and encrypt the tuple (j1∗ , . . . , jv∗2 ) as in the Spiral to produce a scalar
Regev c;
7 Generate the conversion key ck as in the Spiral to translate the scalar Regev
with s to the GSW with S;
8 Set qu ← (Zreg , c, ck);
9 return qu and sk;
The first line of the algorithm generates a key shared by both matrix Regev
and GSW encryption. Lines 2–4 produce random ciphertexts and then half-decrypt
them to obtain a vector of half-plaintexts. Line 5 embeds the row value into
the half-plaintext, where Δ = ⌈q/p⌋ and Δ′ is a configurable parameter that
determines the noise bound introduced by the HalfDecrypt operation. Lines 6–7
correspond to the operations depicted in the lower half of Fig. 1, following
the procedures outlined in Spiral, except that the packing and encryption
operations manipulate only the column value, allowing parameter
optimizations. Notably, the computation of the conversion key must be
performed during the online phase.
Algorithm 3: Answer
Input: D, the database; qu = (Z_reg, c, ck), the query; C_hint, the hint.
Output: The rescaled response C_res.
1 Let Z_reg ← Z_reg · Δ′;
2 for j = 0 to 2^{v2} − 1 do
3   Z_hint[j] ← Z_reg[0] · D[0][j];
4   for i = 1 to 2^{v1} − 1 do
5     Z_hint[j] ← Z_hint[j] + Z_reg[i] · D[i][j];
6   h_j ← C_hint[j] − [0 | Z_hint[j]^T]^T;
7 Using h, the column query c, and the conversion key ck, the Spiral procedures Coefficient Expansion, RegevToGSW, Folding and ModulusSwitch are sequentially applied, yielding a rescaled response denoted as C_res, as shown in Fig. 1;
8 return C_res;
Enhanced PIR Scheme Combining SimplePIR and Spiral 309
The first line multiplies Z_reg by Δ′ to restore the elements of Z_reg to R_q. Lines 2 through 5 then perform the same element-wise multiplication and addition operations as in the offline phase, with the distinction that here each element-wise multiplication only requires the product of two square matrices of dimension n. Line 6 calculates the difference between the server's hint C_hint and the Z_hint generated during the online phase; to align their dimensions, Z_hint[j] is padded with a zero vector. Line 7 briefly outlines the process within Spiral, which is depicted in the bottom-right section of Fig. 1.
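The way the hint, the client's half-decryptions, and the Answer recombination fit together can be checked on a drastically simplified scalar analogue (n = 1, integers mod q in place of ring matrices). Delta2 below plays the role of the configurable rescaling parameter from the Query algorithm; all parameters are illustrative placeholders, chosen only so the toy noise stays below Δ/2, and are not the paper's parameters.

```python
import random

# Scalar toy analogue of HintGen / Query / Answer.
q, p = 2**40, 2**8
Delta = q // p          # Δ = q/p, the plaintext scaling factor
Delta2 = 2**16          # the configurable rescaling parameter (Δ' in the text)
rows = 8                # first database dimension
rng = random.Random(1)

s = rng.randrange(q)                           # client secret key
col = [rng.randrange(p) for _ in range(rows)]  # one database column
i_star = 5                                     # row the client wants

# Offline: seed-derived "random ciphertexts" (A[i], B[i]) and the server hint.
A = [rng.randrange(q) for _ in range(rows)]
B = [rng.randrange(q) for _ in range(rows)]
hint = sum(B[i] * col[i] for i in range(rows)) % q

# Client: half-decrypt each random ciphertext, rescale by Δ', embed the row.
Z = [((B[i] - A[i] * s) % q + Delta2 // 2) // Delta2 for i in range(rows)]
Z[i_star] = (Z[i_star] - Delta // Delta2) % (q // Delta2)   # Query, line 5

# Server: h = hint - Δ'·Σ Z[i]·d_i, plus the matching "a" component.
h_a = sum(A[i] * col[i] for i in range(rows)) % q
h_b = (hint - Delta2 * sum(Z[i] * col[i] for i in range(rows))) % q

# Client: h now decrypts like an ordinary Regev ciphertext of Δ·d_{i*}.
noisy = (h_b - h_a * s) % q
recovered = ((noisy + Delta // 2) // Delta) % p
print(recovered == col[i_star])  # the selected row survives the recombination
```

The subtraction of (Δ/Δ′)·1 before the modulus is exactly what turns the i*-th random ciphertext into an encryption of the wanted entry after the server's rescaled subtraction.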
Lemma 1. From the viewpoint of a user with private key S, the vector h in line 7 of the Answer algorithm constitutes a matrix Regev encryption of the i*-th row of the database, with noise bounded by O(2^{v1} d n p Δ′), where i* is part of the input of the Query algorithm.
Proof. Let D = {d_1, . . . , d_N} be the database, where d_i ∈ R_p^{n×n}, and take any index idx = (i*, j*_1, . . . , j*_{v2}). Sample C_hint ← HintGen(D, sd) and let (qu, sk) ← Query(1^λ, sd, idx), C_res ← Answer(D, qu, C_hint). We note that qu = (Z_reg, c, ck) and sk = S. We analyze the vector h from the simulated matrix Regev ciphertexts.
By definition, C_reg[i] ∈ R_q^{(n+1)×n} for any i ∈ {0, . . . , 2^{v1} − 1} is a simulated matrix Regev ciphertext. From the viewpoint of a user with private key S = [−s̃ | I_n]^T, the random ciphertext looks like
C_reg[i] = [a_i^T ; s̃ a_i^T + E_i] + [0^{1×n} ; M_i] ∈ R_q^{(n+1)×n},
where each bracketed term stacks a 1×n first row on an n×n block (rows separated by ";").
Now consider h_j for any j ∈ {0, . . . , 2^{v2} − 1}. It is the difference between the sum of the set {C_reg[i] · D[i][j] | i ∈ {0, . . . , 2^{v1} − 1}} computed in the HintGen algorithm and the sum of the set {Δ′ · Z_reg[i] · D[i][j] | i ∈ {0, . . . , 2^{v1} − 1}} computed in the Answer algorithm. For any i ≠ i*,
C_reg[i] · D[i][j] − Δ′ · [0 | Z_reg[i]^T]^T · D[i][j] = (C_reg[i] − Δ′ · [0 | Z_reg[i]^T]^T) · D[i][j].
– Game 0: This game follows the Query algorithm with inputs (λ, sd, idx_0). However, since CipherGen is modeled as a random oracle, random ciphertexts must be obtained by querying the oracle. The oracle simply runs the CipherGen algorithm to produce a response.
– Game 1: This game is identical to Game 0, except that the oracle operates differently. The oracle gets a secret key S ← KeyGen(1^λ, 1^n). Then, for an oracle query (sd||i, q, n), it samples a random message Z_reg,i ∈ R_{q/Δ′}^{n×n} and
Fig. 2. Comparison of the amortized communication between our scheme and DoublePIR, illustrating the relationship between the number of queries per update and the corresponding amortized communication. The update occurs once for every specified number of queries, and the graph shows the resulting communication overhead during this period, for a 2^25 × 256 B database.
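The shape of the curves in Fig. 2 follows a simple amortization model: the offline hint download is paid once per database update and spread over the queries made before the next update, while each query adds a fixed online cost. The byte figures below are placeholders, not measured values from the paper.

```python
# Hypothetical amortization model: per-query cost = offline cost spread over
# the queries made before the next update, plus the fixed online cost.
def amortized_comm(offline_bytes: float, online_bytes: float, queries_per_update: int) -> float:
    return offline_bytes / queries_per_update + online_bytes

offline, online = 16 * 2**20, 300 * 2**10   # e.g. a 16 MiB hint, 300 KiB per query
for n in (1, 10, 100, 1000):
    print(n, amortized_comm(offline, online, n))  # decays toward the online cost
```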
4 Performance Analysis
Based on the code of Spiral, we implemented the improved PIR and conducted
performance tests focusing on communication and throughput. The test platform
was equipped with an Intel Xeon(R) Gold 5218R CPU, 512 GB of RAM, and
Ubuntu 22.04 as the operating system. The PIR was primarily implemented
using C and Go languages, with Clang 12, Go 1.18.1, and GCC 11.4 serving as
the compiling tools. Both our scheme and the comparative schemes leveraged
SIMD instruction sets when possible. The performance data were obtained under
single-threaded conditions, and the results presented in the paper are averages
taken after running the tests five times or more.
For the comparison, we employed the Spiral scheme and DoublePIR. Both of these schemes return to the client only the content that the client wishes to query, excluding any additional value. This feature is vital for PIR schemes that prioritize server privacy. In terms of database configurations, we utilized two databases resembling Spiral's architecture: one is a 256 MB database with an entry size of 256 B, and the other is an 8 GB database with an entry size of 32 KB. However, DoublePIR's offline download volume proved excessively high when the number of database entries is relatively small. Consequently, we introduced an additional 8 GB database with an entry size of 256 B, which effectively reduced the offline download volume to less than the database's capacity. We note that schemes such as HintlessPIR and YPIR [14,19] are built on DoublePIR and share the same limitation, so we have not conducted a comparison of their performance. Note that YPIR claims to support large entry sizes based on SimplePIR, but this approach leaks more information to clients than they originally queried.
The comparison results are presented in Table 1. For each scheme, we compare several dimensions: the user's offline download volume, the sizes of the online query and answer, the server answer time, and the server throughput. The sizes of the online query and answer represent the communication overhead of the schemes, while the server throughput and answer time reflect the computational load on the server side.
Table 1 clearly demonstrates the following facts:
Table 1. Comparison results of Spiral [18], DoublePIR [12], and our scheme.
Fig. 3. The behavior of DoublePIR and our scheme when processing a database with 2^25 entries, and how the offline download size changes with the variation in entry size.
Lastly, we present some operational parameters for Spiral and the improved scheme. The improved scheme employs parameters identical to those of the Spiral scheme, with the additional introduction of Δ′ = 2^16. For the three database configurations, the number of plaintext modulus bits is log p = 8, the number of ciphertext modulus bits is log q = 56, the polynomial degree is d = 2048, and the matrix dimension of the data is n = 2. Across the configurations, the first dimension is v1 = 9, while the second dimension v2 takes the values 6, 11 and 9, respectively. The parameters of DoublePIR for the three database configurations share the same log q = 32 and n = 1024. The number of plaintext modulus bits is about 10, 9 and 9 bits, respectively, and the database dimensions analogous to (v1, v2) are (11, 16), (16, 16) and (16, 16).
5 Conclusion
References
1. Ali, A., et al.: Communication-computation trade-offs in PIR. In: USENIX Security, pp. 1811–1828 (2021)
2. An, Z., Tian, H., Chen, C., Zhang, F.: Deniable cryptosystems: simpler construc-
tions and achieving leakage resilience. In: ESORICS, pp. 24–44. Springer (2023)
3. Angel, S., Chen, H., Laine, K., Setty, S.: PIR with compressed queries and amortized query processing. In: IEEE S&P, pp. 962–979 (2018)
4. Applebaum, B., Cash, D., Peikert, C., Sahai, A.: Fast cryptographic primitives
and circular-secure encryption based on hard learning problems. In: CRYPTO,
pp. 595–618. Springer (2009)
5. Cachin, C., Micali, S., Stadler, M.: Computationally private information retrieval
with polylogarithmic communication. In: EUROCRYPT, pp. 402–414. Springer
(1999)
6. Chang, Y.C.: Single database private information retrieval with logarithmic com-
munication. In: ACISP, pp. 50–61. Springer (2004)
7. Chor, B., Goldreich, O., Kushilevitz, E., Sudan, M.: Private information retrieval.
In: FOCS, pp. 41–45 (1995)
8. Damgård, I., Nielsen, J.B.: Improved non-committing encryption schemes based
on a general complexity assumption. In: CRYPTO, pp. 432–450. Springer (2000)
9. Davidson, A., Pestana, G., Celi, S.: FrodoPIR: simple, scalable, single-server private information retrieval. Proc. Priv. Enhanc. Technol. (2023)
10. Gentry, C., Ramzan, Z.: Single-database private information retrieval with constant
communication rate. In: ICALP, pp. 803–815. Springer (2005)
11. Henzinger, A., Dauterman, E., Corrigan-Gibbs, H., Zeldovich, N.: Private web
search with tiptoe. In: SOSP, pp. 396–416 (2023)
12. Henzinger, A., Hong, M.M., Corrigan-Gibbs, H., Meiklejohn, S., Vaikuntanathan,
V.: One server for the price of two: simple and fast single-server private information
retrieval. In: USENIX Security, pp. 3889–3905 (2023)
13. Kushilevitz, E., Ostrovsky, R.: Replication is not needed: single database,
computationally-private information retrieval. In: FOCS, pp. 364–373 (1997)
14. Li, B., Micciancio, D., Raykova, M., Schultz-Wu, M.: Hintless single-server private
information retrieval. In: CRYPTO, pp. 183–217. Springer (2024)
15. Lipmaa, H.: An oblivious transfer protocol with log-squared communication. In:
ISC, pp. 314–328. Springer (2005)
16. Lyubashevsky, V., Peikert, C., Regev, O.: On ideal lattices and learning with errors
over rings. In: EUROCRYPT, pp. 1–23. Springer (2010)
17. Melchor, C.A., Barrier, J., Fousse, L., Killijian, M.O.: XPIR: private information retrieval for everyone. Proc. Priv. Enhanc. Technol. 155–174 (2016)
18. Menon, S.J., Wu, D.J.: Spiral: fast, high-rate single-server PIR via FHE composition. In: IEEE S&P, pp. 930–947 (2022)
19. Menon, S.J., Wu, D.J.: YPIR: high-throughput single-server PIR with silent preprocessing. In: USENIX Security, pp. 5985–6002 (2024)
20. Mughees, M.H., Chen, H., Ren, L.: OnionPIR: response efficient single-server PIR. In: ACM CCS, pp. 2292–2306 (2021)
21. Park, J., Tibouchi, M.: SHECS-PIR: somewhat homomorphic encryption-based compact and scalable private information retrieval. In: ESORICS, pp. 86–106. Springer (2020)
A Two-Stage Image Blind Inpainting Algorithm
Based on Gated Residual Connection
Abstract. To solve the problem that image inpainting methods need to be provided with both damaged images and mask images, which limits their usage scenarios, a two-stage image blind inpainting algorithm based on gated convolutional residual blocks is proposed. The algorithm consists of two modules: a mask prediction module and an image inpainting module. The mask prediction module predicts potentially visually inconsistent regions in a given image and outputs a predicted mask; the image inpainting module repairs the damaged area of the image according to the predicted mask and the context of the original image. In order to effectively utilize
the contextual information of the image, the algorithm adopts gated convolution in
the residual block and introduces a discriminator in the mask prediction module,
effectively improving the accuracy of predicting masks. Experimental results on
Places2 and CelebA-HQ datasets show that the image repair performance of the
proposed algorithm is better than that of the comparison method, and reliable
repair images are successfully generated.
1 Introduction
With the increasing popularity of multimedia data applications such as digital images,
video and audio, information processing technologies for multimedia data such as digital
images are also becoming increasingly widespread.
The purpose of image inpainting is to restore damaged or missing image informa-
tion, generate visually continuous content, or achieve local modification and removal of
interference through image segmentation. With the continuous development of computer
and image processing technology, digital image inpainting technology has become an
emerging field. A large number of digital image inpainting methods [1–4] have been
proposed by researchers, which can improve the image inpainting effect to a certain
extent and enhance the efficiency of image inpainting.
Traditional image inpainting methods rely on mathematical and physical theories
to ensure image content similarity and texture consistency. These methods accomplish
the inpainting of small areas of damaged images by establishing geometric models or
using texture synthesis. Traditional image inpainting methods are categorized into two
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 316–331, 2025.
https://doi.org/10.1007/978-981-96-4566-4_22
types: patch-based image inpainting methods [1, 2] and diffusion-based image inpainting
methods [3, 4]. Traditional image inpainting methods are effective for images with high
texture and small area damage. However, when dealing with complex and extensively
missing images, the effectiveness of these methods is not ideal.
In recent years, with the increasing depth of research on Convolutional Neural Net-
works (CNN) and Generative Adversarial Networks (GAN), image inpainting methods
based on deep learning have become a research hotspot [5]. CNNs [6–9] and GANs [10] have shown strong processing capabilities and good results in image feature learning and representation, and have been widely used in large-scale image processing.
The image inpainting method of deep learning learns image semantics in large-
scale datasets, generates new pixels in damaged areas, and does not rely on the original
image. According to whether a mask is required for input, deep learning image inpainting
methods can be divided into two types: non-blind image inpainting methods and blind
image inpainting methods. In the process of repairing damaged images using non-blind restoration methods, it is necessary to accurately mask the damaged areas. Drawing masks is a complex task, and the accuracy of the masks directly impacts the repair quality, making these methods less practical.
Therefore, this article proposes a two-stage blind restoration algorithm to solve
the problem of limited usage scenarios caused by drawing masks in image restoration
methods. The algorithm consists of two stages: mask prediction module and image
restoration module. The mask prediction module contains seven residual blocks and one
convolutional block, while retaining bottleneck information in the network, which can
provide more effective information for the next module. In order to improve the accuracy
of predicting masks, a discriminator was introduced to evaluate the predicted masks. The
predicted mask and damaged image are simultaneously input into the image restoration
module. After 14 residual blocks, the repaired image was obtained.
The patch-based method searches for the best-matching patch in the visible area of the image or in external datasets, and then copies this information into the missing area for filling.
Criminisi et al. [1] proposed a patch-based region filling and target removal repair algo-
rithm for eliminating unnecessary objects in images. A similar patch search algorithm
was proposed by Barnes et al. [2] that randomly matches patch blocks in an image and
continuously optimizes these matches to achieve structural image editing.
The diffusion-based method is also called the variational partial differential equation (PDE) based image inpainting algorithm. In this method, the image inpainting problem is transformed into a variational problem over a functional by establishing prior and data models of the image and applying functional analysis and total variation theory. Bertalmio et al. [3] proposed a PDE-based diffusion repair method,
which considers the contextual information of missing parts in the image, predicts dam-
aged areas using pixels with known surrounding information, and performs image repair
based on the smoothness of local information. Chan et al. [4] proposed a total variation
regularization method that achieves image inpainting by minimizing the total variation
of the image.
318 Y. Pan and X. Zhang
In non-blind image inpainting methods, Pathak et al. [11] proposed the Context Encoder (CE), which can generate new pixels compared to traditional methods, but easily leads to excessive smoothing or edge loss. In order to solve the problem of traditional convolution treating all input pixels or features as equally effective, the concept of gated convolution was proposed by Yu et al. [12]. The core idea of gated convolution is to allow masks to be automatically learned by the network, enabling it to clearly distinguish between effective and invalid pixels. Even in deep networks, gated convolution can learn independently in each channel, highlighting masked areas and sketch information, and more effectively generate repair results.
Blind image inpainting methods [13–26] only require inputting the damaged image, omitting the step of mask drawing and achieving automatic inpainting of visual content. Wang et al. [13] proposed a Visual Consistency Network (VCN) that predicts semantically inconsistent regions in an image and generates corresponding damaged-area masks; spatial normalization is then used to repair the predicted masked and damaged areas, removing the complex step of mask drawing. Zhao et al. [16] proposed a blind image inpainting method using a single-stage Transformer-CNN Hybrid AutoEncoder. In addition, a Cross-layer Dissimilarity Prompt (CDP) method was designed to accelerate the identification and repair of contaminated areas.
1.2 Contributions
2 Preliminaries
CNN is a multi-layer neural network structure mainly used in computer vision tasks,
such as image classification, object detection, and image segmentation. The fundamental
components of a CNN encompass the data input layer, convolutional layer, activation
function layer, pooling layer, and fully connected layer.
The workflow of CNN is as follows: Firstly, the original image is transformed into a
data form with three dimensions: length, width, and number of channels through the data
input layer. Then, the image data undergoes multiple convolutions, activation functions,
and pooling operations to extract the feature information of the image. Finally, the
extracted feature maps are input into the fully connected layer for final classification or
other tasks.
GAN, a deep learning model, revolves around the core concept of generating high-
quality samples through the interplay of two competing neural networks: a generator
and a discriminator. The training process improves the quality of samples generated
by the generator through competition between the generator and the discriminator. The
flowchart of GAN is shown in Fig. 1.
It is commonly held that the learning capacity of image inpainting models can be enhanced by increasing the number of convolutional blocks. However, in practical applications, increasing the number of model layers does not always improve performance and may instead introduce problems. This is because when the neural
network becomes very deep, the gradients in the backpropagation process may become
very small (gradient disappearance) or very large (gradient explosion), which can cause
weight updates to become very small or very large, making it difficult for the model
to learn effective representations. In addition, as the number of layers increases, the
parameters of the model will also increase significantly, leading to an increase in memory
utilization and making training more expensive.
The challenges encountered in deep neural networks are effectively mitigated by the
introduction of residual connections [27] within the framework of residual blocks [28].
These connections facilitate direct information transfer between layers, thereby enabling
information to skip certain layers and be directly conveyed from input to output. No
matter how many layers the model has, the layers with good learning performance are
retained, while the layers with poor performance are directly skipped through residual
connections. The effect of using residual blocks will not be worse than before, so it
is possible to steadily improve the model’s performance by increasing the number of
layers. The residual block structure is shown in Fig. 3 (b).
The formula for the residual block is as follows:
f(x) = g(x) + x  (1)
In Eq. (1), f(x) is the output of the residual block, g(x) is the output of the two convolutions in the residual block, and x is the input sample.
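The skip connection in Eq. (1) can be sketched in a few lines; g stands in for the residual block's two convolutions, and the toy "badly trained" layer is an illustrative assumption.

```python
# f(x) = g(x) + x: the input is added back to the transformed features, so a
# layer that learns nothing useful (g ≈ 0) still passes x through unchanged.
def residual_block(g, x):
    return [gi + xi for gi, xi in zip(g(x), x)]

useless_layer = lambda v: [0.0 for _ in v]   # toy stand-in for a poorly trained g
print(residual_block(useless_layer, [1.0, -2.0, 3.0]))  # -> [1.0, -2.0, 3.0]
```

This is exactly why stacking residual blocks cannot make the model worse than the identity mapping: the skip path preserves the input even when g contributes nothing.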
Fig. 3. Convolutional block structure. In the convolutional block structure, (a) represents a regular
convolutional block, (b) represents a residual block, and (c) represents a residual block with
different input and output channels. In cases where the number of input and output channels is
different, a 1 × 1 convolution kernel is introduced, which has the same number of convolution
kernels as the number of output channels to ensure consistency in the number of channels.
L_disc = E_{x∼p_data}[D(x)] − E_{z∼p(z)}[D(G(z))] + λ_gp E_{x̂∼p_x̂}[(‖∇_x̂ D(x̂)‖_2 − 1)^2]  (3)
In Eqs. (2) and (3), x̂ represents the linear interpolation between the generated image and the real image, and λ_gp is the weight of the gradient penalty, set to 10. p_x̂ represents the distribution of uniform sampling along lines between generated and real samples.
The Adversarial loss comprises the generator loss and discriminator loss, with
the overall goal of establishing a game between the generator and discriminator. This
dynamic aims to empower the generator to generate more realistic samples, and the dis-
criminator can more accurately distinguish between generated samples and real samples,
namely:
L_adv = L_gen + L_disc  (4)
In Eq. (4), L_adv represents the adversarial loss, L_gen the generator loss, and L_disc the discriminator loss.
The main task of the image inpainting module is to apply the predicted mask in the
mask prediction module to the damaged area to achieve the goal of image inpainting.
However, directly extracting features from images may result in incomplete or invalid
information. In order to effectively extract image features, a gated residual connection
method was adopted in this stage.
By introducing gated convolution into the residual blocks [12], the image inpainting
module can output more effective multi-scale feature information, consequently enhanc-
ing the accuracy of image inpainting. In order to obtain broader contextual information
in the image, some gated residual blocks are added with dilation factors to form dilated
gated residual blocks.
The gated convolutional structure is shown in Fig. 5.
Traditional convolutions [11, 20] treat each pixel as valid for calculation, which
works well for classification and detection but fails in repair tasks, where distinguishing
between valid and invalid pixels is crucial. Gated convolution addresses this by using
learnable soft mask updates, allowing dynamic feature selection for each channel and
spatial position. This mechanism enables more flexible feature processing, improving
repair results. The formula for gated convolution is:
Gating_{y,x} = W_g · I  (5)
Feature_{y,x} = W_f · I  (6)
O_{y,x} = φ(Feature_{y,x}) ⊙ σ(Gating_{y,x})  (7)
Among them, I is the feature map and σ is the sigmoid function, so the output gate value lies between 0 and 1. φ is the activation function, which can be ReLU, Tanh, etc. W_g and W_f are two different convolutional filters, and O_{y,x} is obtained by performing the two different convolutions on I and then multiplying the resulting feature maps pixel by pixel.
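A minimal 1-D sketch of this mechanism (the two convolutions of Eqs. (5) and (6), followed by the pixel-wise product of the activated feature and the sigmoid gate); the filter weights are arbitrary illustrative values, not trained parameters.

```python
import math

# 1-D gated convolution sketch: W_f extracts features, W_g produces gates, and
# the sigmoid-squashed gate scales the activated feature at each position.
def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))

def conv1d(w, x):
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]

def gated_conv(wf, wg, x, phi=math.tanh):
    feature = conv1d(wf, x)   # Feature = W_f * I
    gating = conv1d(wg, x)    # Gating  = W_g * I
    return [phi(f) * sigmoid(g) for f, g in zip(feature, gating)]

print(gated_conv([0.5, -0.5], [1.0, 1.0], [1.0, 2.0, 0.0, 3.0]))
```

A strongly negative gating response drives σ toward 0, effectively marking that position as invalid regardless of the feature value.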
During model training, five loss functions are used in this paper to make the generated
image closer to the real image: pixel by pixel reconstruction loss [21] Lrec , style loss
[22] Lstyle , perception loss [23] Lperc , ID-MRF loss [6] Lmrf , and adversarial loss [11]
Ladv . The specific explanations are as follows:
Pixel by pixel reconstruction loss is a loss function used to measure the difference
between the model generated image and the real image. Its goal is to minimize the
difference between the generated image and the real image, so that the generated image
is as close to the real data as possible.
The damaged-area loss [11] calculates the difference between the image generated by the model and the real image in the damaged area. This loss is computed by applying a binary mask M over the damaged area.
L_hole = ‖M ⊙ (I_out − I_gt)‖_1  (8)
Among them, L_hole represents the loss over the damaged areas, L_valid represents the loss over the non-damaged areas, I_out represents the image generated by the model, and I_gt represents the real target image.
The non-damaged area loss calculates the difference between the generated image and the real image in the non-damaged area. This part of the loss is computed by applying the complementary binary mask 1 − M over the unbroken area.
L_valid = ‖(1 − M) ⊙ (I_out − I_gt)‖_1  (9)
The pixel-by-pixel reconstruction loss L_rec is the sum of these two losses:
L_rec = L_hole + L_valid  (10)
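Eqs. (8) and (9), with the reconstruction loss as their plain sum as the text states, can be sketched on flat pixel lists (illustrative values; real images would be tensors):

```python
# Masked L1 losses on flat pixel lists; M is 1 in the damaged area, 0 elsewhere.
def masked_l1(out, gt, w):
    return sum(wi * abs(o - g) for o, g, wi in zip(out, gt, w))

def rec_loss(out, gt, mask):
    l_hole = masked_l1(out, gt, mask)                    # damaged-area loss
    l_valid = masked_l1(out, gt, [1 - m for m in mask])  # non-damaged-area loss
    return l_hole + l_valid                              # L_rec = L_hole + L_valid

print(rec_loss([0.2, 0.9, 0.4], [0.0, 1.0, 0.4], [1, 1, 0]))  # ≈ 0.3
```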
The concept of style loss originates from the task of artistic style transfer, which
aims to make the generated image consistent in style with the real image.
G_ij = (1/C) Σ_k φ(G)_ik φ(G)_jk  (11)
R_ij = (1/C) Σ_k φ(R)_ik φ(R)_jk  (12)
where G_ij represents the Gram matrix of the generated image and R_ij represents the Gram matrix of the real image. i and j index the rows and columns of the Gram matrix, while C represents the number of channels in the feature map. φ(·) represents the feature representation of the selected intermediate convolutional layer, k runs over the channels of that feature map, and N represents the total number of pixels in the feature map of the convolutional layer. The formula for style loss is:
L_style = (1/N^2) Σ_{i,j} ‖G_ij − R_ij‖_1  (13)
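A small sketch of the Gram-matrix computation and the resulting style loss of Eqs. (11)–(13), on C × N feature maps (channels × pixels) with illustrative values:

```python
# Gram matrices of C x N feature maps and the 1/N^2-scaled L1 distance between them.
def gram(phi):
    C, N = len(phi), len(phi[0])
    return [[sum(phi[i][k] * phi[j][k] for k in range(N)) / C
             for j in range(C)] for i in range(C)]

def style_loss(phi_g, phi_r):
    G, R = gram(phi_g), gram(phi_r)
    N = len(phi_g[0])
    return sum(abs(G[i][j] - R[i][j])
               for i in range(len(G)) for j in range(len(G))) / N**2

fmap = [[1.0, 2.0], [0.0, 1.0]]   # 2 channels x 2 pixels
print(style_loss(fmap, fmap))     # identical maps give zero style loss
```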
The goal of perceptual loss is to measure the perceptual quality of generated images
by comparing the differences in feature representations between generated and real
images in some intermediate layers. The perceptual loss formula is:
L_perc = Σ_i (1/N_i) ‖φ(G)_i − φ(R)_i‖_1  (14)
Among them, φ(G)_i and φ(R)_i represent the feature maps of the VGG19 network's relu3_2 and relu4_2 layers for the generated and real images, respectively, and N_i represents the total number of pixels in the i-th layer feature map.
The purpose of ID-MRF loss is to enhance the generation model’s ability to main-
tain structure during image generation, making the generated images more natural and
realistic.
L_mrf = L_M(conv4_2) + Σ_{t=3}^{4} L_M(conv t_2)  (15)
used for training. To ensure consistent input image resolution, we downsampled the CelebA-HQ dataset, resulting in an image resolution of 256 × 256 after sampling. To ensure the fairness of the experimental results, all experiments were conducted using the same training and testing sets. During the experiments, we used the Adam optimizer with a learning rate of 1e−4, a training period of T = 800000, and loss weights λ_r1 = λ_p = 0.5, λ_a = λ_m = 1e−3, λ_r2 = 1.4, λ_s = 1e−4.
4.2 Results
To assess the efficacy of the proposed algorithm, CE [11], LBAM [30], Vcnet [13] and TransCNN-HAEwCDP [16] were selected for comparison with the method proposed in this paper. To ensure the fairness of the experiment, both this experiment and the comparative experiments used the same training and testing sets. Notably, Vcnet, TransCNN-HAEwCDP and the method proposed in this paper are blind repair algorithms, while CE and LBAM are non-blind repair algorithms. CE addresses damaged areas based on predefined rules without the requirement for masks. To ensure fairness, the mask predicted by the mask prediction module of this paper is used as the input mask for LBAM. The repair results of the five methods are shown in Fig. 6.
The performance of image inpainting algorithms is directly related to the quality of the inpainted results. In this paper, three evaluation metrics, Mean Absolute Error (MAE) [13], Peak Signal-to-Noise Ratio (PSNR) [30] and Structural Similarity Index Measure (SSIM) [6], were employed to evaluate the quality of the repaired images. The test results are presented in Tables 1 and 2.
From Fig. 6, it can be seen that for images with relatively simple structures, all five
comparison methods can generate reasonable content in the damaged area. However, for
images with complex textures, although there are still some shortcomings in the repair
effect of damaged areas, compared with the other four methods, the repair algorithm pro-
posed in this paper has made significant improvements in image texture and details. This
article adopts a method of predicting first and then repairing, which can more effectively
reconstruct lost information and make the repair results more natural and coherent. The
CE method uses reconstruction loss and adversarial loss to repair regular images, but the
capture of contextual semantic information is poor, making it impossible to reconstruct
appropriate semantic information and thus unable to perform effective repairs. LBAM adopts a learnable attention map module to learn feature re-normalization and mask updating in an end-to-end manner. Compared with the CE method, it can effectively repair irregular damaged areas. However, there are still significant edge responses and
color differences at the edges of damaged pixels. The Vcnet method can achieve good
repair results through the proposed blind repair network, but there are still some details
missing and visual artifacts. The TransCNN-HAEwCDP method adopts a single-stage Transformer-CNN hybrid autoencoder and a cross-layer prompting method, which utilizes local context to generate reasonable and realistic content for damaged areas without the need for a GAN. However, at the edges of complex areas, phenomena such as blurring and
artifacts may occur. This indicates that the method proposed in this article has certain
advantages in repairing images with complex textures.
Table 1. Comparison Results of the Average Values on the CelebA-HQ Test Set
Table 2. Comparison Results of the Average Values on the Places2 Test Set
A smaller MAE value indicates a smaller difference between the predicted and true
values of the model, suggesting a better fit of the model. On the other hand, a larger
PSNR value suggests less distortion between the repaired image and the real image,
indicating good image quality. Moreover, a larger SSIM value implies a higher degree
of similarity between the repaired image and the real image. Based on the findings depicted in Tables 1 and 2, our method demonstrates superior performance in terms of MAE, PSNR and SSIM compared to the other four methods on the CelebA-HQ and Places2 datasets.
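Two of the three metrics are simple enough to sketch directly on flat pixel lists; SSIM requires windowed mean/variance statistics and is omitted from this sketch.

```python
import math

# MAE and PSNR for images given as flat lists of pixel values in [0, 255].
def mae(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def psnr(a, b, peak=255.0):
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    return float("inf") if mse == 0 else 10.0 * math.log10(peak ** 2 / mse)

print(mae([0, 10, 255], [0, 12, 250]))   # mean absolute pixel error
print(psnr([0, 10, 255], [0, 12, 250]))  # higher means less distortion
```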
Fig. 7. Mask prediction results. (a) represents the real masks; (b) represents the prediction result with the discriminator; (c) represents the prediction result without the discriminator.
From Fig. 7, it can be seen that the accuracy of mask prediction is significantly improved by the introduction of the WGAN-GP discriminator in the mask prediction module. Compared with the network without a discriminator, the mask contours predicted with the discriminator are clearer.
2) Whether gated convolution is introduced in the residual blocks
Gated convolution, as an improvement on traditional convolution, adopts learnable
soft-mask update rules that allow dynamic feature selection mechanisms to be learned
for each channel and spatial position, making the network more flexible in handling
irregular image inpainting tasks (Fig. 8).
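The gating rule can be sketched in NumPy (our illustration, not the paper's implementation): each position runs two convolutions, a feature branch and a gating branch, and the sigmoid of the gate acts as a learnable soft mask on the features:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_conv2d(x, w_feat, w_gate):
    """Gated convolution ('valid' padding, stride 1).
    x: (H, W, Cin); w_feat, w_gate: (k, k, Cin, Cout).
    Output: feature(x) * sigmoid(gate(x)), a per-position, per-channel soft mask."""
    k = w_feat.shape[0]
    H, W, _ = x.shape
    Cout = w_feat.shape[-1]
    feat = np.zeros((H - k + 1, W - k + 1, Cout))
    gate = np.zeros_like(feat)
    for i in range(feat.shape[0]):
        for j in range(feat.shape[1]):
            patch = x[i:i + k, j:j + k, :]
            feat[i, j] = np.tensordot(patch, w_feat, axes=3)
            gate[i, j] = np.tensordot(patch, w_gate, axes=3)
    return feat * sigmoid(gate)
```

Because the mask is learned per channel and per position, valid regions can pass through while damaged regions are suppressed, which is why gated convolution suits irregular holes.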
Fig. 8. Inpainting results with and without gated convolution. (a) damaged images;
(b) repaired images without gated convolution; (c) repaired images with gated
convolution.
328 Y. Pan and X. Zhang
4.4 Discussion
The data in Fig. 10 are from the Places2 and CelebA-HQ datasets. The
simulation test results in Fig. 10 show that in some complex areas, the network without
introducing gated convolution has problems such as limited feature extraction, edge arti-
facts, and blurred repair results when dealing with irregular areas. In contrast, networks
that introduce gated convolution can more effectively capture the features of irregu-
lar regions, resulting in more realistic repaired images and significantly improving the
quality of the repaired results.
From the results of the two comparative experiments mentioned above, it can be
concluded that the algorithm proposed in this paper can generate images with more
realistic content and better visual effects. This indicates that in the process of image
generation, the algorithm can effectively capture subtle details and features, making the
repaired results more realistic.
The time complexity of a single convolution is O(H × W × k² × Cin × Cout), where
H and W represent the height and width of the input feature map, k represents the
size of the convolution kernel, Cin is the number of input channels, and Cout is the
number of output channels.
A residual block contains two convolutions Oconv, one activation function OELU, one
batch normalization OBatchNorm, and one skip connection OProjection. The mask prediction
module contains 7 residual blocks and one convolution block. The image inpainting
module contains 14 residual blocks, and the discriminator contains 5 convolution blocks.
The total spatial complexity is:
The spatial complexity is related to the number of parameters in the model and the
storage requirements for intermediate feature maps during runtime.
Among them, Pconv represents the number of parameters of a convolution, PProjection
represents the number of parameters of a skip connection, and PResBlock represents the
number of parameters of a residual block.
The parameter quantity of the module is:
The total storage requirements include input, output, and intermediate feature maps.
The storage space for the total feature map of the residual block is:
The total feature map storage space of the convolutional block is:
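The counting rules described above follow the standard conventions for convolutional layers; as an illustration (our helper names, not the paper's notation):

```python
def conv_params(k, c_in, c_out, bias=True):
    """Parameters of one k x k convolution: k*k*Cin*Cout weights (+ Cout biases)."""
    return k * k * c_in * c_out + (c_out if bias else 0)

def conv_time_complexity(h, w, k, c_in, c_out):
    """Multiply-accumulates of a stride-1 'same' convolution: O(H*W*k^2*Cin*Cout)."""
    return h * w * k * k * c_in * c_out

def feature_map_storage(h, w, c):
    """Elements stored for one H x W x C intermediate feature map."""
    return h * w * c
```

For example, a 3×3 convolution mapping 64 to 64 channels has 36,928 parameters, and applying it to a 32×32 map costs about 3.77 × 10⁷ multiply-accumulates.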
5 Conclusions
This paper proposes a two-stage blind repair algorithm using gated residual connec-
tions, consisting of a mask prediction module and an image inpainting module. The
damaged image first enters the mask prediction module, where a discriminator helps
improve mask accuracy. The predicted mask and damaged image are then processed by
the inpainting module, which replaces traditional convolutions with gated convolutions
to better utilize contextual information and generate realistic repairs. Experiments show
that the method predicts realistic masks without prior knowledge and produces
visually plausible repaired images. Future work will explore techniques for restoring
high-resolution images; considering the demands of practical application scenarios,
attention will also be paid to the computational efficiency and real-time performance
of the model.
Acknowledgments. This work was supported by Natural Science Basic Research Program of
Shaanxi (Program No. 2021JQ-722).
Disclosure of Interests. The authors declare that they have no known competing financial inter-
ests or personal relationships that could have appeared to influence the work reported in the
paper.
References
1. Criminisi, A., Pérez, P., Toyama, K.: Region filling and object removal by exemplar-based
image inpainting. IEEE Trans. Image Process. 13(9), 1200–1212 (2004). https://doi.org/10.
1109/TIP.2004.833105
2. Barnes, C., Shechtman, E., Finkelstein, A., et al.: PatchMatch: a randomized correspondence
algorithm for structural image editing. ACM Trans. Graph. 28(3), 24 (2009). https://doi.org/
10.1145/3596711.3596777
3. Bertalmio, M., Sapiro, G., Caselles, V., et al.: Image inpainting. In: Proceedings of the 27th
Annual Conference on Computer Graphics and Interactive Techniques, pp. 417–424 (2000).
https://doi.org/10.1145/344779.344972
4. Chan, T.F., Shen, J.: Nontexture inpainting by curvature-driven diffusions. J. Vis. Commun.
Image Represent. 12(4), 436–449 (2001). https://doi.org/10.1006/jvci.2001.0487
5. Hui, Z., Li, J., Wang, X., et al.: Image fine-grained inpainting (2020). arXiv preprint
arXiv:2002.02609. https://doi.org/10.48550/arXiv.2002.02609
6. Wang, Y., Tao, X., Qi, X., et al.: Image inpainting via generative multi-column convolutional
neural networks. Adv. Neural Inf. Process. Syst. 31 (2018). https://doi.org/10.48550/arXiv.
1810.08771
7. Yildirim, A.B., Pehlivan, H., Bilecen, B.B., et al.: Diverse inpainting and editing with gan
inversion. In: Proceedings of the IEEE/CVF International Conference on Computer Vision,
pp. 23120–23130 (2023). https://doi.org/10.48550/arXiv.2307.15033
8. Xiao, Q., Li, G., Chen, Q.: Deep inception generative network for cognitive image inpainting
(2018). arXiv preprint arXiv:1812.01458. https://doi.org/10.48550/arXiv.1812.01458
9. Zhao, L.L., Shen, L., Hong, R.C.: A review of research progress in image inpainting. Comput.
Sci. 48(03), 14–26 (2021). (in Chinese). https://doi.org/10.11896/jsjkx.210100048
10. Goodfellow, I., et al.: Generative adversarial nets. In: Advances in Neural Information
Processing Systems (NeurIPS), pp. 2672–2680 (2014)
11. Pathak, D., Krahenbuhl, P., Donahue, J., et al.: Context encoders: feature learning by inpaint-
ing. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 2536–2544 (2016). https://doi.org/10.48550/arXiv.1604.07379
12. Yu, J., Lin, Z., Yang, J., et al.: Free-form image inpainting with gated convolution. In: Pro-
ceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4471–4480
(2019). https://doi.org/10.48550/arXiv.1806.03589
13. Wang, Y., Chen, Y.C., Tao, X., Jia, J.: VCNet: a robust approach to blind image inpainting. In:
Vedaldi, A., Bischof, H., Brox, T., Frahm, J.M. (eds.) Computer Vision – ECCV 2020. ECCV
2020. LNCS, vol. 12370, pp. 752–768. Springer, Cham (2020). https://doi.org/10.1007/978-
3-030-58595-2_45
14. Li, W.J., Zhou, Z.P., et al.: Research on blind image inpainting method based on context
semantic hierarchical reasoning. J. Nanchang Hangkong Univ. (Natl. Sci.) 36(02), 36–43
(2022). (in Chinese). https://doi.org/10.3969/j.issn.2096-8566.2022.02.006
15. Liu, D.B., Wang, H.Q., Wang, K., et al.: Blind inpainting method for incomplete sparse text
images based on content style transfer. Laser Optoelectron. Progress 59(24), 106–117 (2022).
(in Chinese). https://doi.org/10.3788/LOP202259.2411001
16. Zhao, H., Gu, Z., Zheng, B., et al.: TransCNN-HAE: Transformer-CNN hybrid autoencoder
for blind image inpainting. In: Proceedings of the 30th ACM International Conference on
Multimedia, pp. 6813–6821 (2022). https://doi.org/10.1145/3503161.3547848
17. Li, X., Wang, Z., Chen, C., et al.: SemID: blind image inpainting with semantic inconsistency
detection. Tsinghua Sci. Technol. 29(4), 1053–1068 (2024). https://doi.org/10.26599/TST.
2023.9010079
18. Wang, J., Yuan, C., Li, B., et al.: Self-prior guided pixel adversarial networks for blind image
inpainting. IEEE Trans. Pattern Anal. Mach. Intell. (2023). https://doi.org/10.1109/TPAMI.
2023.3284431
19. Bai, Y., He, R., Tan, W., et al.: Fine-grained blind face inpainting with 3D face compo-
nent disentanglement. In: ICASSP 2023–2023 IEEE International Conference on Acoustics,
Speech and Signal Processing (ICASSP). IEEE, pp. 1–5 (2023). https://doi.org/10.1109/ICA
SSP49357.2023.10097082
20. Li, C.Y., Lin, Y.Y., Chiu, W.C.: Decontamination transformer for blind image inpainting.
In: ICASSP 2023–2023 IEEE International Conference on Acoustics, Speech and Signal
Processing (ICASSP). IEEE, pp. 1–5 (2023). https://doi.org/10.1109/ICASSP49357.2023.
10094950
21. Schmalfuss, J., Scheurer, E., Zhao, H., et al.: Blind image inpainting with sparse directional
filter dictionaries for lightweight CNNs. J. Math. Imaging Vis. 65(2), 323–339 (2023). https://
doi.org/10.1007/s10851-022-01119-6
22. Phutke, S.S., Kulkarni, A., Vipparthi, S.K., et al.: Blind image inpainting via omni-
dimensional gated attention and wavelet queries. In: Proceedings of the IEEE/CVF Con-
ference on Computer Vision and Pattern Recognition, pp. 1251–1260 (2023). https://doi.org/
10.1109/CVPRW59228.2023.00132
23. Gao, M., Kang, B.: Blind image inpainting using low-dimensional manifold regularization.
J. Circuits Syst. Comput. 31(12), 2250211 (2022). https://doi.org/10.1142/S02181266225
02115
24. Kim, H., Kim, C.I., Kim, H., et al.: Panoptic blind image inpainting. ISA Trans. 132, 208–221
(2023). https://doi.org/10.1016/j.isatra.2022.10.030
25. Gao, M., Kang, B., Feng, X., et al.: Low dimensional manifold regularization based blind
image inpainting and non-uniform impulse noise recovery. IEEE Access 8, 200551–200560
(2020). https://doi.org/10.1109/ACCESS.2020.3035532
26. He, S., Peng, X., Yuan, Z., et al.: Contour-context joint blind image inpainting network
for molecular sieve particle size measurement of SEM images. IEEE Trans. Instrum. Meas.
(2023). https://doi.org/10.1109/tim.2023.3279451
27. Russakovsky, O., Deng, J., Su, H., et al.: Imagenet large scale visual recognition challenge.
Int. J. Comput. Vis. 115, 211–252 (2015). https://doi.org/10.1007/s11263-015-0816-y
28. Pehlivan, H., Dalva, Y., Dundar, A.: Styleres: transforming the residuals for real image editing
with stylegan. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pp. 1828–1837 (2023). https://doi.org/10.48550/arXiv.2212.14359
29. Gulrajani, I., Ahmed, F., Arjovsky, M., et al.: Improved training of wasserstein gans. Adv.
Neural Inf. Process. Syst. 30 (2017). https://doi.org/10.48550/arXiv.1704.00028
30. Xie, C., Liu, S., Li, C., et al.: Image inpainting with learnable bidirectional attention maps. In:
Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8858–8867
(2019). https://doi.org/10.48550/arXiv.2104.12087
GAN-Based Adaptive Trigger Generation
and Target Gradient Alignment in Vertical
Federated Learning Backdoor Attacks
1 Introduction
With the increasing demand for data privacy protection and multi-party col-
laboration, Federated Learning (FL) has become a research hotspot in the field
of machine learning [1]. Within FL, Vertical Federated Learning (VFL) is an
important paradigm that allows multiple participants with the same users but
different features to collaboratively train a global model without sharing raw
data [2–7]. VFL has broad application potential in data-sensitive fields such as
financial risk control [8, 9] and medical diagnosis [10, 11].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 332–346, 2025.
https://doi.org/10.1007/978-981-96-4566-4_23
GAN-Based Adaptive Trigger Generation 333
Horizontal Federated Learning (HFL) and Vertical Federated Learning (VFL)
differ in data distribution: HFL participants share the same feature space but
hold different data samples, while VFL participants share the same data samples
but hold different feature subsets. Due to this structure, VFL is more
vulnerable to backdoor attacks. Attackers can exploit feature interactions with-
out label information, enabling subtle malicious behaviors that influence the
global model. Hence, traditional HFL backdoor methods do not directly apply
to VFL, demanding specialized strategies.
Significant progress has been made in HFL backdoor attacks [12–14], but
research on VFL backdoors, especially those launched by passive parties with-
out label access, remains limited [15–18]. In VFL, passive parties only hold par-
tial features, cannot modify labels, and have limited access to others’ data and
models, making attacks more challenging. Figure 1 illustrates a VFL training
architecture. Exploring how passive parties can implement efficient backdoor
attacks under these constraints is crucial.
2 Related Work
As Federated Learning (FL) continues to evolve, its security has garnered increas-
ing attention. Backdoor attacks, in particular, have emerged as a critical threat
and drawn widespread interest. However, in the realm of Vertical Federated
Learning (VFL), research on backdoor attacks is still in the early stages. Exist-
ing approaches often depend heavily on label inference or substantial auxiliary
information, which imposes significant limitations.
In training-phase attack methods, Liu et al. [15] proposed the Label Replace-
ment Backdoor (GR) method, requiring a small number of target samples from
the training set. During VFL training, the attacker swaps local embedding fea-
tures and returns gradients of these target samples with those of poisoned data.
However, this method demands label information for each batch, making it less
practical as the dataset grows. Naseri et al. [16] introduced BadVFL, assuming
the attacker has limited labels from each class. The attacker performs semi-
supervised label inference and then implants subtle triggers. Yet, this approach
heavily relies on label inference and prior class knowledge, reducing its prac-
ticality. Xuan et al. [20] proposed a gradient-based label inference attack. The
attacker infers class information from gradient data, then replaces local inputs
and adds triggers. However, this depends on accurate gradient-based inference,
limiting its applicability in complex multi-class settings. Peng et al. [18] presented
TECB, but its trigger update relies on a fixed optimization strategy and uses
random noise replacement, which can destabilize training. He et al. [21] added
triggers directly to embedding features, requiring many known target labels and
thus diminishing stealthiness.
In inference-phase attack strategies, LR-BA [17] infers pseudo-labels for
training data and then refines the model with backdoor representations. Yet,
it requires labeled samples from each class, which is impractical for large-scale
VFL. TPGD [22] generates adversarial perturbations during the inference phase
by modifying the attacker’s local embedding features. However, it assumes the
attacker can alter labels in the active party, which is unrealistic in real-world
VFL environments.
In VFL, passive attackers control only partial features and lack influence over the
global training process. Consequently, conventional dynamic backdoor strategies
face significant challenges in VFL scenarios.
The implanted backdoor ensures that when the model is used for inference, any input containing
a specific trigger will be misclassified as a predetermined target class τ . At the
same time, the model’s accuracy on normal, trigger-free data remains unaffected.
Attacker’s Knowledge: In this type of attack, the attacker can carry out an
effective backdoor attack with access to only a minimal set of samples from
the target class. In our experiments, we utilized a limited number of target
class samples (e.g., only 50 target samples) and compromised the model by
dynamically creating triggers using the Adaptive Trigger Generation Network
(ATGN). Although this requirement for target samples differs from traditional
VFL scenarios, in practical applications, attackers can obtain a small number
of target samples through various means (e.g., data purchasing or accessing
public datasets). Additionally, the attacker has no further knowledge about other
participants’ data and models and can only rely on gradient information obtained
from the active party.
4 Methodology
This section details the dynamic trigger generation network method proposed for
backdoor attacks in Vertical Federated Learning (VFL). Figure 2 illustrates the
attack framework of the backdoor attack proposed in this paper. The method
comprises two core steps: Adaptive Trigger Generation Network (ATGN) and
Gradient-Feature Correlation Attack (GFCA). These two stages work collab-
oratively by generating dynamic triggers and performing feature and gradient
replacement to enhance the success rate of backdoor attacks.
δ = T(z; θTGN) (1)
where T(z; θTGN) is the trigger generation network with parameters θTGN, mapping
random noise z to a trigger. The generated trigger δ is injected into the attacker's
local target data, and the local model extracts embedding features to be sent to the
active party for training.
During the VFL training process, the active party calculates the global loss
L using the top model G(H1 , H2 , . . . , HK ; θtop ) based on the embedding features
H1 , H2 , . . . , HK transmitted by all participants. Subsequently, the active party
computes the gradient of the loss with respect to each participant’s embedding
features ∂L/∂HiA and returns the corresponding gradient information to the passive
parties. For the attacker, the received gradient ∂L/∂HiA contains the effect of the
trigger δ on the local feature embedding HiA.
To extract the gradient related to the trigger, the attacker uses the chain rule
to decompose the gradient of the global loss L with respect to the trigger δ into
two parts: first, the gradient of the global loss with respect to the embedding
features HiA , and then the gradient of the embedding features with respect to
the trigger δ. This process can be expressed as:
∂L/∂δ = (1/M) Σ_{i=1}^{M} (∂L/∂HiA) · (∂HiA/∂δ) (2)
where M represents the batch size, and ∂HiA/∂δ represents the gradient of the
embedding features HiA with respect to the trigger δ, which can be calculated
through backpropagation of the attacker's local model fA(xiA + δ; θA). Through
this step, the attacker can extract the effect of the trigger δ on the global loss L.
Once the trigger-related gradient ∂L/∂δ is obtained, the attacker can use this
information to update the parameters θTGN of the trigger generation network
(ATGN). The specific parameter update is performed through backpropagation:
θTGN^(t+1) = θTGN^(t) − α · ∇θTGN L (3)
where α is the learning rate, and the gradient of the loss function L with respect to
θTGN is derived from the previously computed ∂L/∂δ via the chain rule, i.e.,
∇θTGN L = (∂L/∂δ) · (∂δ/∂θTGN). Through each training round, ATGN continuously
optimizes its parameters so that the generated trigger δ can more effectively deceive
the VFL model, causing samples with the trigger to be misclassified into the target
class during the inference phase.
Throughout the training process, ATGN continuously receives gradient information
from the active party and combines this information for trigger generation and
parameter optimization. As training progresses, the generated trigger δ
gradually adapts to changes in the model, enhancing the stealthiness and effec-
tiveness of the attack.
This dynamic trigger generation method has stronger adaptability than traditional
static optimization, allowing flexible adjustment of the trigger during training and
thereby improving the stealthiness and effectiveness of backdoor attacks.
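Equations (2) and (3) can be exercised end-to-end on a toy setup (entirely our construction: a linear local model standing in for fA, a linear trigger generator, a fixed noise seed, and a quadratic stand-in for the gradient signal the active party returns):

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 4, 3                                    # embedding dim, noise dim
W = np.eye(d) + 0.1 * rng.normal(size=(d, d))  # toy linear local model: H = W @ (x + delta)
theta = np.zeros((d, m))                       # ATGN parameters: delta = theta @ z
x = rng.normal(size=d)                         # attacker's local input
H_target = rng.normal(size=d)                  # embedding the attacker wants to imitate
z = rng.normal(size=m)                         # fixed noise seed for the trigger
alpha = 0.5 / (np.linalg.norm(z) ** 2 * np.linalg.norm(W, 2) ** 2)  # stable step size

for _ in range(300):
    delta = theta @ z                          # trigger from the generator
    H = W @ (x + delta)                        # embedding sent to the active party
    g = H - H_target                           # stand-in for dL/dH returned to the attacker
    dL_ddelta = W.T @ g                        # eq. (2): chain rule through the local model
    theta -= alpha * np.outer(dL_ddelta, z)    # eq. (3): d(delta)/d(theta) contributes z

loss = 0.5 * np.sum((W @ (x + theta @ z) - H_target) ** 2)
```

After a few hundred rounds the generated trigger drives the local embedding onto the target, mirroring how ATGN adapts the trigger using only returned gradients.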
The attack objective can be written as minΘ Σ_{x∈Dp} L(F(x + T(z); Θ), τ).
Here, Θ represents the parameters of the global model, and the goal is to minimize
the loss function L by adjusting these parameters. The loss function L
340 K. Li et al.
measures the difference between the model’s output and the target class τ . The
poisoned dataset Dp contains samples from non-target classes. The function T
refers to a trigger generation network trained during the ATGN phase, and it
generates triggers by applying random noise z. The term T (z) represents the
generated trigger from the noise z, which is added to the input data, producing
the modified input x + T (z). The expression F (x + T (z); Θ) denotes the output
of the global model when processing the poisoned data, and τ is the class the
attacker intends for the poisoned samples to be misclassified into. By optimizing
Θ, the attacker ensures that the poisoned data with the generated trigger is
classified into the target class τ , successfully executing a backdoor attack during
inference.
GFCA achieves its goal through two primary steps. The first step involves
feature substitution, where the attacker systematically replaces the embedding
features of poisoned data with those of the target class samples in a one-to-one
correspondence. This manipulation causes the poisoned data to resemble the
target class samples during training, thereby misleading the model’s classification
behavior. The second step is gradient substitution, where the attacker replaces
the gradients of the poisoned samples with the gradients from the target class
data and amplifies them by a factor of λ. This amplification accelerates the
alignment of the poisoned data with the target class in gradient space. This dual
strategy of aligning both features and gradients not only strengthens the attack’s
effectiveness but also significantly increases the likelihood that the poisoned data
will be misclassified into the target class during the inference phase.
This strategy performs particularly well in complex multi-classification tasks
because it allows the poisoned data to gradually align with the target class
samples in both feature and gradient spaces, thereby significantly improving the
success rate of backdoor attacks.
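The two GFCA steps amount to an in-batch rewrite on the attacker's side; a small sketch (function and variable names are ours, assuming disjoint poisoned and target index sets):

```python
import numpy as np

def gfca_replace(H, grads, poison_idx, target_idx, lam=2.0):
    """GFCA sketch on one batch of local embeddings.
    H: (B, D) embeddings; grads: (B, D) gradients returned by the active party.
    Step 1: one-to-one feature substitution with target-class samples.
    Step 2: gradient substitution, amplified by a factor lambda."""
    H, grads = H.copy(), grads.copy()
    H[poison_idx] = H[target_idx]          # poisoned embeddings mimic the target class
    grads[poison_idx] = lam * grads[target_idx]  # amplified target-class gradients
    return H, grads
```

The amplification factor λ controls how quickly the poisoned samples align with the target class in gradient space.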
5 Experiments
To validate the effectiveness of our proposed Adaptive Federated Backdoor Frame-
work (AFBF) in Vertical Federated Learning (VFL), we conducted extensive
experiments on multiple public datasets. The selection criteria for these datasets
included diversity in class numbers, image complexity, and widespread use in the
research community to ensure the generalizability of our results. Additionally,
all datasets underwent standard preprocessing steps, including normalization of
pixel values to the [0,1] range, resizing images to 32×32 pixels if necessary, and
data augmentation techniques such as random cropping and horizontal flipping
to enhance model robustness. We compared AFBF with existing major backdoor
attack methods and performed an in-depth analysis of the impact of each compo-
nent and hyperparameters in the framework on the overall performance.
CIFAR-10 contains 50,000 training images and 10,000 test images across 10 classes, with images evenly distributed among
classes and sized at 32×32 pixels. CIFAR-100 is similar to CIFAR-10 but con-
tains 100 classes, each with 600 images, totalling 60,000 images, with 50,000 for
training and 10,000 for testing. CINIC-10 has a total of 270,000 images, with
180,000 for training and 90,000 for testing, also divided into 10 classes. This
dataset is an enhanced variant of CIFAR-10, designed to assess the performance
of algorithms on more extensive and intricate classification challenges.
In our experiments, we implemented a two-party VFL configuration, consist-
ing of one passive party (serving as the attacker) and one active party. Following
the approaches outlined in [15, 18], each image was divided along its vertical
midline, with each participant retaining half of the features. The passive party
was assigned the left half of the image, while the active party received the right
half along with the associated label information. Both participants employed
ResNet-18 [27] as the base model to extract local features. The top model used
by the active party is a four-layer fully connected neural network, with each
layer followed by batch normalization and ReLU activation, which integrates
the embedding features from both parties to perform the classification task.
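The data partition used here is a simple vertical split of each image; as a sketch (assuming H × W × C arrays):

```python
import numpy as np

def vertical_split(img):
    """Two-party VFL partition: the passive party gets the left half of each image,
    the active party gets the right half (labels stay with the active party)."""
    w = img.shape[1]
    return img[:, : w // 2], img[:, w // 2:]
```

For 32×32 inputs, each party thus sees a 32×16 strip and extracts its own local embedding from it.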
We designated the backdoor target classes as "car," "apple," and "car" for the
CIFAR-10, CIFAR-100, and CINIC-10 datasets, respectively. Within the AFBF
framework, the attacker requires only a small subset of target class samples.
Specifically, we randomly selected 50 samples from the target class to form the
target sample set Dt . Additionally, 50 samples were randomly chosen from the
remainder of the training set to create the poisoned sample set Dp . The datasets
were split into training and testing sets with an 80–20 ratio to maintain consis-
tency across experiments. Furthermore, we ensured that the selected poisoned
samples did not overlap with the validation set to prevent data leakage. To
assess the attack effectiveness of AFBF, we applied the generated triggers to
all non-target class samples in the test set, constructing a backdoor test set for
evaluation.
To measure the model’s performance, we used two commonly used metrics:
Attack Success Rate (ASR) and Main Task Accuracy (MTA) [28]. ASR repre-
sents the proportion of samples in the backdoor test set that are misclassified as
the target class by the model; MTA measures the classification accuracy of the
poisoned model on the clean test set. It should be noted that the CIFAR-10 and
CINIC-10 datasets are evaluated using Top-1 accuracy, while due to the large
number of classes in CIFAR-100, we use Top-5 accuracy [29].
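Both metrics are straightforward to compute; a sketch with our helper names:

```python
import numpy as np

def attack_success_rate(preds, target_class):
    """ASR: fraction of backdoor-test samples predicted as the target class."""
    return float(np.mean(np.asarray(preds) == target_class))

def topk_accuracy(logits, labels, k=1):
    """Top-k MTA: a sample counts as correct if its label is among the k
    highest-scoring classes (Top-1 for CIFAR-10/CINIC-10, Top-5 for CIFAR-100)."""
    topk = np.argsort(logits, axis=1)[:, -k:]
    hits = [labels[i] in topk[i] for i in range(len(labels))]
    return float(np.mean(hits))
```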
The experimental setup included a total of 100 training rounds to ensure suf-
ficient convergence of both the VFL model and the backdoor mechanisms. We
implemented early stopping based on validation performance to prevent overfit-
ting. The learning rate was set to 0.01 with a decay factor of 0.1 applied every
30 epochs. Batch size was maintained at 128 for all experiments.
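The stated schedule is a standard step decay; as a one-line sketch:

```python
def lr_at(epoch, base_lr=0.01, decay=0.1, step=30):
    """Step-decay learning rate: multiplied by `decay` every `step` epochs."""
    return base_lr * decay ** (epoch // step)
```

So the learning rate is 0.01 for epochs 0–29, 0.001 for epochs 30–59, and so on.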
In the comparison, we selected several typical VFL backdoor attack methods
as baselines, including both training-phase and inference-phase attack meth-
ods. The training-phase attack methods include GR [15] and BadVFL [16]; the
inference-phase attack method is LR-BA [17]. Additionally, we compared with
the standard VFL model without attack as the baseline. For fairness, we
adjusted the parameters of each method according to the settings in [17] and
[15]. For example, for GR and LR-BA, we set the number of poisoned samples
to be consistent with AFBF; for BadVFL, we selected the best combination of
original and target classes.
Each experiment was repeated five times independently to account for vari-
ability, and the results were averaged to ensure statistical significance. We also
conducted ablation studies to isolate the effects of ATGN and GFCA components
within the AFBF framework.
Table 1. Comparison of MTA and ASR for Different Datasets and Attacks
main task accuracy. This indicates that existing defence methods have difficulty
effectively resisting AFBF attacks without sacrificing model performance. AFBF
still maintains high ASR and MTA in the face of these defence measures, demon-
strating its robustness against defences.
Table 2. Comparison of MTA and ASR across different datasets and methods
6 Conclusions
Acknowledgments. This work was supported by the National Natural Science Foun-
dation of China (No. 62372130, No. 62261160651) and the Guangzhou Basic and
Applied Basic Research Project (No. 2023A04J1725, No. 2024A03J0397).
References
1. McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.:
Communication-efficient learning of deep networks from decentralized data. In:
Artificial Intelligence and Statistics, pp. 1273–1282. PMLR (2017)
2. Cheng, K., et al.: SecureBoost: a lossless federated learning framework. IEEE Intell.
Syst. 36(6), 87–98 (2021)
3. He, C., Annavaram, M., Avestimehr, S.: Group knowledge transfer: federated learning
of large CNNs at the edge. Adv. Neural. Inf. Process. Syst. 33, 14068–14080
(2020)
4. Hu, Y., Niu, D., Yang, J., Zhou, S.: FDML: a collaborative machine learning
framework for distributed features. In: Proceedings of the 25th ACM SIGKDD
International Conference on Knowledge Discovery & Data Mining, pp. 2232–2240 (2019)
5. Yang, L., Tianjian, C., Qiang, Y.: Secure federated transfer learning. In: Proceed-
ings of the 24th ACM SIGKDD International Conference on Knowledge Discovery
and Data Mining (2018)
6. Liu, Y., et al.: A communication efficient collaborative learning framework for
distributed features. arXiv preprint arXiv:1912.11187 (2019)
7. Yang, Q., Liu, Y., Chen, T., Tong, Y.: Federated machine learning: concept and
applications. ACM Trans. Intell. Syst. Technol. (TIST) 10(2), 1–19 (2019)
8. Kairouz, P., et al.: Advances and open problems in federated learning. Found.
Trends Mach. Learn. 14(1–2), 1–210 (2021)
9. Zhang, Y., Chen, X., Liu, X., Zhu, N.: Exploring trust transfer between internet
enterprises and their affiliated internet-only banks: an adoption study of internet-only
banks in China. Chin. Manag. Stud. 12(1), 56–78 (2018)
10. Matschinske, J., et al.: The featurecloud platform for federated learning in
biomedicine: unified approach. J. Med. Internet Res. 25, e42621 (2023)
11. Kaissis, G.A., Makowski, M.R., Rückert, D., Braren, R.F.: Secure, privacy-
preserving and federated machine learning in medical imaging. Nat. Mach. Intell.
2(6), 305–311 (2020)
12. Bagdasaryan, E., Veit, A., Hua, Y., Estrin, D., Shmatikov, V.: How to backdoor fed-
erated learning. In: International Conference on Artificial Intelligence and Statis-
tics, pp. 2938–2948. PMLR (2020)
13. Baruch, G., Baruch, M., Goldberg, Y.: A little is enough: circumventing defenses
for distributed learning. Adv. Neural Inf. Process. Syst. 32 (2019)
14. Shejwalkar, V., Houmansadr, A.: Manipulating the byzantine: optimizing model
poisoning attacks and defenses for federated learning. In: NDSS (2021)
15. Zou, T., et al.: Defending batch-level label inference and replacement attacks in
vertical federated learning. IEEE Trans. Big Data (2022)
16. Naseri, M., Han, Y., De Cristofaro, E.: BadVFL: backdoor attacks in vertical
federated learning. In: 2024 IEEE Symposium on Security and Privacy (SP), pp.
2013–2028. IEEE (2024)
17. Gu, Y., Bai, Y.: LR-BA: backdoor attack against vertical federated learning using
local latent representations. Comput. Secur. 129, 103193 (2023)
18. Chen, P., Yang, J., Lin, J., Lu, Z., Duan, Q., Chai, H.: A practical clean-label
backdoor attack with limited information in vertical federated learning. In: 2023
IEEE International Conference on Data Mining (ICDM), pp. 41–50. IEEE (2023)
19. Nguyen, T.A., Tran, A.: Input-aware dynamic backdoor attack. Adv. Neural. Inf.
Process. Syst. 33, 3454–3464 (2020)
20. Xuan, Y., Chen, X., Zhao, Z., Tang, B., Dong, Y.: Practical and general back-
door attacks against vertical federated learning. In: Joint European Conference on
Machine Learning and Knowledge Discovery in Databases, pp. 402–417. Springer
(2023)
21. He, Y., et al.: Backdoor attack against split neural network-based vertical federated
learning. IEEE Trans. Inf. Forensics Secur. (2023)
22. Liu, J., Xie, C., Koyejo, S., Li, B.: CoPur: certifiably robust collaborative inference
via feature purification. Adv. Neural. Inf. Process. Syst. 35, 26645–26657 (2022)
23. Salem, A., Wen, R., Backes, M., Ma, S., Zhang, Y.: Dynamic backdoor attacks
against machine learning models. In: 2022 IEEE 7th European Symposium on
Security and Privacy (EuroS&P), pp. 703–718. IEEE (2022)
24. Li, Y., Li, Y., Wu, B., Li, L., He, R., Lyu, S.: Invisible backdoor attack with sample-
specific triggers. In: Proceedings of the IEEE/CVF International Conference on
Computer Vision, pp. 16463–16472 (2021)
25. Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny
images (2009)
26. Darlow, L.N., Crowley, E.J., Antoniou, A., Storkey, A.J.: Cinic-10 is not imagenet
or cifar-10. arXiv preprint arXiv:1810.03505 (2018)
27. Fu, C., et al.: Label inference attacks against vertical federated learning. In: 31st
USENIX Security Symposium (USENIX Security 22), pp. 1397–1414 (2022)
28. Li, Y., Jiang, Y., Li, Z., Xia, S.T.: Backdoor learning: a survey. IEEE Trans. Neural
Netw. Learn. Syst. 35(1), 5–22 (2022)
29. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,
pp. 770–778 (2016)
Weakly Supervised Waste Classification
with Adaptive Loss and Enhanced Class
Activation Maps
1 Introduction
Different types of waste can still be seen almost everywhere in cities. Sustain-
able waste management reduces the use of natural resources while searching for
solutions that do not negatively impact the environment or public health. A few
crucial phases are included in the sustainable waste management process: gather-
ing of waste, processing and sorting, material or energy recovery from waste, and
reusing to create new goods. To promote waste recycling and environmental sus-
tainability, many artificial intelligence-based approaches for waste classification
[14] have been developed. Some methods integrate additional feature extractor
modules for waste classification. These modules, such as optimized convolutions
[14] and feature pyramids [21], are used to capture waste at different scales in
images. Some methods design fine-tuning methods while retaining the original
model structure [6]. These methods can enable the model to achieve faster con-
vergence when the dataset images are insufficient. However, these methods are
all single-label classification methods. In real-life scenarios, there is often more
than just one type of waste.
In contrast, multi-label classification allows for the simultaneous classification
of multiple objects within a single image [8]. However, labeling a large-scale and
accurate multi-label dataset for training is time-consuming [24]. To address this
issue, Weakly Supervised Multi-label Classification (WSML) has emerged [16].
While WSML reduces the labor costs and time associated with labeling images,
it also faces challenges [4]. For example, these unannotated objects and their
unobserved labels often lead to large loss values, resulting in poor classification
performance [5].
To alleviate this problem, some methods focus on exploiting unobserved
labels with techniques like Assume Negative (AN) [5]. In most multi-label clas-
sification scenarios, the number of objects present in each image is smaller than
the total number of classes in the datasets [18]. Consequently, many true neg-
ative labels are incorporated into training. However, some false negative labels
can interfere with the model’s ability to learn the characteristics of objects. To
address the problem, an online estimation module is utilized to correct these
unobserved labels during training [5]. While this method can correct some false
negative labels, it requires prior knowledge of the mathematical expectation for
the number of positive labels in each image. Additionally, the approach neglects
that the model first fits true labels (early learning phase) and then fits false
labels (memorization phase) [11]. To utilize this finding, LL-R, LL-Ct, and LL-
Cp are proposed to reject and correct partial false negative labels [11]. However,
they struggle to handle false negative labels separately.
In this study, we propose a new weakly supervised multi-label waste clas-
sification framework, called the Adaptive Weakly-supervised Waste Classifica-
tion Framework (AWWCF). The AWWCF comprises three modules: the Target
Preprocessing Module (TPM), the Prediction Module (PM), and the Comput-
ing Module (CM). TPM is responsible for assigning unobserved labels as neg-
ative. To improve the effectiveness of AWWCF in utilizing unobserved labels
and reduce the impact of false negative labels on model learning features, we
integrate a novel method with the PM and CM of AWWCF, called Adaptive
Loss and Enhanced CAMs (ALEC). In the PM, ALEC calculates the areas of
CAMs, whose attribution scores are the largest part. These areas are associ-
ated with positive labels, but their attribution scores are diminished by other
false negative labels. ALEC enhances these areas. By increasing the attribution
scores in these areas, the probability of false negative labels being predicted as
positive is reduced, thereby improving classification performance. In the CM,
ALEC makes full use of unobserved labels by assigning them as negative. At
different stages of training, ALEC uses different methods to deal with large loss
values depending on the proportion of true negative and false negative labels. In
the early stages of training, ALEC gradually rejects false negative labels. When
classification performance stabilizes or mean average precision oscillates within
a range for a prolonged period, ALEC adopts more aggressive ways to correct
these false labels. It permanently modifies some of the false negative labels with
the largest loss values. The contributions of this paper are:
• We propose a multi-label waste classification framework called Adaptive
Weakly-supervised Waste Classification Framework (AWWCF). The frame-
work can classify multiple classes of waste in complex scenarios with partial
labels.
• We design a new method called Adaptive Loss and Enhanced CAMs (ALEC)
for weakly supervised multi-label classification. ALEC dynamically enhances
CAMs with high attribution scores to reduce false negative labels’ impact
on predictions for observed labels. Moreover, ALEC dynamically rejects and
corrects unobserved labels with large loss values.
• We conduct extensive experiments on three multi-label waste datasets. Exper-
imental results show that the framework incorporating ALEC has better clas-
sification performance than existing weakly supervised multi-label waste classification frameworks. The proposed method ALEC can combine different backbones and improve their weakly supervised waste classification
performance.
The remaining sections of this paper are structured as follows: Sect. 2 intro-
duces related work, Sect. 3 describes the AWWCF modules in detail, Sect. 4
provides detailed comparative experiments and results, and Sect. 5 summarizes
all the work.
2 Related Work
2.1 Weakly Supervised Multi-label Classification
In weakly supervised multi-label classification tasks, there are various methods
used to train the model. Some studies fuse image features at different scales
[23]. Some address the issue of label distribution in datasets [7], which design a
network with two branches to improve classification performance on both head
classes and tail classes. Another proposal [5] suggests different ways to handle
extreme cases where there is only one single positive label in each image. The first
approach is to introduce a hyper-parameter for the ordinary BCELoss function.
350 W. Dai and L. Sun
super category and does not consider other types of waste. Graph neural networks are also used to solve multi-label waste classification tasks [25]. They attach importance to correlations between labels instead of treating them independently. By stacking multiple layers, their model can capture the complex relationships
between labels. Recently, a program called “TACO” has been launched [17]. It
contains thousands of photos of waste in diverse environments and encompasses
sixty different waste categories. Therefore, we utilize the images and annotations
from this program to create a new waste dataset.
objects in the input images and predict the existence probability of each class.
The CM mainly calculates the cross-entropy loss function between the prediction
values output by PM and the target values output by the TPM. It uses ALEC
to reject and correct large loss values to minimize their impact on the model
learning process (lines 7–8).
tasks, only a small proportion of the labels are known, and hence |S p |+|S n | < N .
The expression for Y AN can be given by Eq. 1:
\[
y_i^{AN} =
\begin{cases}
1, & i \in S^p \\
0, & i \in S^n \cup S^u
\end{cases}
\tag{1}
\]
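As a concrete illustration, the AN target assignment of Eq. 1 can be sketched in a few lines of NumPy. This is a hypothetical sketch; the function name and array layout are our own assumptions, not the paper's code.

```python
import numpy as np

def assume_negative_targets(num_classes, observed_positive):
    """Build the AN target vector y^AN of Eq. 1.

    Labels in S^p (observed positives) become 1; labels in S^n
    (observed negatives) and S^u (unobserved) default to 0.
    """
    y_an = np.zeros(num_classes, dtype=np.float32)  # S^n and S^u -> 0
    y_an[list(observed_positive)] = 1.0             # S^p -> 1
    return y_an

# Example: 6 classes, only class 2 observed as positive.
y = assume_negative_targets(6, observed_positive={2})
```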
After obtaining the target value Y^{AN} from the TPM and the prediction value P from the PM, the cross-entropy loss can be computed. Finally, the model is trained by minimizing the loss function L on the dataset D = {X, Y^{AN}}, as shown in Eq. 2:
\[
\mathcal{L} = \frac{1}{|\mathcal{D}|} \sum_{(x,\, y^{AN}) \in \mathcal{D}} \frac{1}{N} \sum_{i=1}^{N} \operatorname{BCELoss}\!\left(P_i,\, y_i^{AN}\right)
\tag{2}
\]
Due to the modification of the target values of some labels, a subset of labels may have excessively large cross-entropy loss values, typically caused by the Assume Negative assumption, and these labels are usually false negative labels. To prevent these labels from affecting model learning, ALEC introduces a weight λ_i for each loss value:
\[
\mathcal{L} = \frac{1}{|\mathcal{D}|} \sum_{(x,\, y^{AN}) \in \mathcal{D}} \frac{1}{N} \sum_{i=1}^{N} l_i \times \lambda_i
\tag{3}
\]
The definition of l_i is BCELoss(P_i, y_i^{AN}), with the parameters P_i and y_i^{AN} omitted for convenience. Similarly, λ_i is defined as a function λ_i = λ(P_i, y_i^{AN}), with its parameters omitted as well. λ_i is a weight that measures the contribution of l_i to the loss function L in Eq. 3. The setting of λ_i is shown in Eq. 4:
\[
\lambda_i =
\begin{cases}
0, & i \in S^u \text{ and } \gamma \le p_t \text{ and } l_i > R(t) \\
1, & \text{otherwise}
\end{cases}
\tag{4}
\]
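The rejection phase of Eq. 4 can be sketched as follows. This is our own vectorised reading of the LL-R-style rule; the function signature and the percentage convention for [(t − 1)Δ_rel]% are assumptions.

```python
import numpy as np

def rejection_weights(losses, unobserved_mask, t, delta_rel, gamma, patience):
    """Sketch of the lambda_i weighting of Eq. 4.

    While rejection is active (gamma <= patience), the top
    [(t - 1) * delta_rel]% largest losses among unobserved labels
    get weight 0; every other label keeps weight 1.
    """
    lam = np.ones_like(losses)
    if gamma <= patience:
        reject_frac = min((t - 1) * delta_rel / 100.0, 1.0)
        unob_losses = losses[unobserved_mask]
        k = int(round(reject_frac * unob_losses.size))
        if k > 0:
            threshold = np.sort(unob_losses)[-k]  # R(t): k-th largest loss
            lam[unobserved_mask & (losses >= threshold)] = 0.0
    return lam
```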
Here, γ denotes the difference between the current epoch and the epoch of the last mAP maximum while the rejection method is in use, and p_t denotes the patience, i.e., the number of epochs of oscillation after which the correction method steps in.
R(t) represents the largest loss values in the loss set {li |(x, y AN ) ∈ D , i ∈ S u } for
the first [(t−1)Δrel ]% of the samples. Δrel is a hyper-parameter that determines
the rate at which the rejection rate of large loss value samples increases. In the
early stage of training, the loss values of samples with excessively high loss are
set to zero in each batch, which can be understood as temporarily ignoring these
samples with excessively high loss. The zeroing ratio is related to the number of
classification categories in the dataset, the hyper-parameter batch size, and the
rejection rate. The more classes in the dataset, the larger the hyper-parameter
batch size, and the higher the rejection rate, the more noise labels will be ignored.
As the number of training iterations increases, the model gradually learns the
noise labels, so the zeroing ratio will also increase. In the later stages of training,
the target with large loss values is permanently corrected and the loss values are
recalculated. However, because this correction method is aggressive, R(t) is fixed at the top Δ_rel% of the loss value set rather than growing as [(t − 1)Δ_rel]%. The targets of the remaining, normal-loss samples are not changed, and their loss values are not recalculated.
Correspondingly, because the label modification method is used in the later
stage of training, the target value y_i^{AN} also needs to be modified. The specific
setting is shown in Eq. 5:
\[
y_i^{AN} =
\begin{cases}
1, & i \in S^u \text{ and } t > p_t \text{ and } l_i > R(t) \\
\text{unchanged}, & \text{otherwise}
\end{cases}
\tag{5}
\]
After the label yiAN that satisfies the first condition of Eq. 5 is modified, sets
S u and S p will also be modified based on Eq. 6
\[
S^u \leftarrow S^u \setminus \{i\}, \qquad S^p \leftarrow S^p \cup \{i\}
\tag{6}
\]
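The permanent correction of Eqs. 5–6 might look like the following sketch, under our own reading of the rule; the names and the Δ_rel percentage convention are assumptions, and the caller would add the returned flipped indices to S^p.

```python
import numpy as np

def correct_large_losses(y_an, losses, unobserved, delta_rel):
    """Sketch of Eqs. 5-6: permanently flip the unobserved labels
    carrying the top delta_rel% of loss values to positive, and
    update the index set S^u accordingly."""
    s_u = set(np.flatnonzero(unobserved))
    unob_losses = losses[unobserved]
    if unob_losses.size == 0:
        return y_an.copy(), s_u, set()
    k = max(1, int(round(delta_rel / 100.0 * unob_losses.size)))
    threshold = np.sort(unob_losses)[-k]          # fixed R(t), late stage
    flipped = [i for i in s_u if losses[i] >= threshold]
    y_corr = y_an.copy()
    y_corr[flipped] = 1.0                         # Eq. 5: y_i^AN <- 1
    s_u -= set(flipped)                           # Eq. 6: S^u <- S^u \ {i}
    return y_corr, s_u, set(flipped)              # flipped indices join S^p
```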
4 Experiment
4.1 Experimental Setting
We evaluate ALEC on three multi-label datasets: MW, MWS, and TACOS.
The MW and MWS datasets consist of trash images collected from the internet,
while the TACOS dataset comprises images from the TACO project. Due to
the severe imbalance in the number of labels in the MW and TACOS datasets,
we excluded object categories with fewer images to ensure sufficient images per
category for training and testing. Detailed information about the three datasets
is presented in the table. The table lists the number of images in the training,
validation, and test sets, the number of annotations, and the number of object
categories. Generally, we use 70% of the images and annotations to train the
model, 20% to validate its performance, and the remaining 10% to test the
model’s classification performance. We use mean Average Precision (mAP) as
the classification evaluation metric (Table 3).
Datasets  Train imgs  Train annos  Val imgs  Val annos  Test imgs  Test annos  Classes
MW        3283        5474         938       1553       469        751         46
MWS       3262        5167         932       1459       466        721         30
TACOS     892         2640         255       758        126        290         20
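The mAP metric used for evaluation can be computed per class and then averaged, as in this standard sketch (not the authors' evaluation code):

```python
import numpy as np

def mean_average_precision(scores, targets):
    """Per-class average precision, averaged over classes (mAP).

    scores:  (num_images, num_classes) prediction scores
    targets: (num_images, num_classes) binary ground-truth labels
    """
    aps = []
    for c in range(scores.shape[1]):
        order = np.argsort(-scores[:, c])   # rank images by descending score
        hits = targets[order, c]
        if hits.sum() == 0:
            continue                        # class absent from this split
        precision_at_k = np.cumsum(hits) / (np.arange(len(hits)) + 1)
        aps.append((precision_at_k * hits).sum() / hits.sum())
    return float(np.mean(aps))
```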
• Full labels: Each image in the datasets has true labels for all classes, repre-
senting the best performance of multi-label classification on the datasets.
• Naive Assume Negative (Naive AN) [11]: Treat all unobserved objects as
non-existent.
• Regularized online label estimation (ROLE) [5]: Combine the EPR and online
estimates of the unobserved labels throughout training.
• Large Loss Rejection (LL-R) [11]: Gradually reject large loss value samples
according to the training stages.
• Large Loss Correction (temporary) (LL-Ct) [11]: Utilize the loss calculated
from the modified labels but do not correct the actual labels.
• Large Loss Correction (permanent) (LL-Cp) [11]: Utilize the loss calculated
from the modified labels and correct the actual labels permanently.
• LL-R+BoostLU [12]: Employing the LL-R method and enhancing all areas
where CAMs are greater than 0.
The results are presented in Table 3, where mAP is used to evaluate the classification performance. ALEC outperforms all baseline methods on all three datasets.
Naive AN assumes all unobserved labels correspond to objects absent in the
images, introducing substantial noise into model training and resulting in the
worst classification performance among baseline methods. WAN assumes that a
large proportion of unobserved labels are indeed absent in the images, meaning
only a small amount are false negative labels. Consequently, a fixed weight is
added to all assumed negative labels. However, WAN fails to consider the spe-
cific situation of each sample and only corrects for an overall perspective. ROLE
combines EPR with online estimates of unobserved labels, leveraging the find-
ing that the model can fit correct labels more quickly. However, it requires a
complex optimization process and is computationally expensive. LL-Cp is more
aggressive and permanently changes the labels of samples with large loss values
in subsequent training iterations. However, LL-Cp performs poorly on datasets
with lower prediction performance, such as MW and TACOS. On such challenging datasets, it is easy to correct true labels to false labels simply based on
loss values. Similarly, BoostLU enhances all parts of the CAMs greater than 0.
However, when the classification accuracy of the model is relatively low, this
simple method may blindly enhance irrelevant regions. In comparison, ALEC
corrects only in the later stages of training and corrects false labels when the
model exhibits relatively high classification performance.
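The contrast drawn here, between BoostLU enhancing every positive CAM location and ALEC enhancing only the highest-attribution areas, can be sketched as follows; the boost factor and quantile threshold are purely illustrative assumptions, not values from either paper.

```python
import numpy as np

def boost_lu(cam, alpha=2.0):
    """BoostLU-style sketch: scale every positive CAM activation."""
    return np.where(cam > 0, alpha * cam, cam)

def selective_boost(cam, alpha=2.0, q=0.9):
    """ALEC-style sketch: scale only the top-(1 - q) fraction of
    attribution scores, leaving low-attribution regions untouched."""
    threshold = np.quantile(cam, q)
    return np.where(cam >= threshold, alpha * cam, cam)
```

With a low-accuracy model, `boost_lu` amplifies background noise as well, while the selective variant only sharpens the regions the model already attributes most strongly.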
4.4 Visualization
In this section, we use CAM to visualize the regions of interest in images for
models trained under different conditions. In CAMs, the red areas represent the
parts that the model focuses on, and the brighter the red, the greater the impact
of that part on the model’s predictions. Under the condition of full labeling, the
model focuses on areas that are more concentrated and have higher attention
than the other two methods. When using LL-R, the model occasionally treats
the background noise in the image as a feature of the object. For example, when
classifying the two images below, the background of the entire image is covered
in red, but the main body of the object is not noticed by the model. ALEC
almost never experiences this situation. Compared to using LL-R, the model
using ALEC can focus on a more comprehensive appearance of objects, such as
when classifying cigarettes (Fig. 4).
5 Conclusion
In this study, we design a new method to utilize unobserved labels in weakly-
supervised multi-label waste classification, called Adaptive Loss and Enhanced
CAMs (ALEC). Moreover, we integrate ALEC with deep learning classifica-
tion model to form a new framework called the Adaptive Weakly-supervised
Waste Classification Framework (AWWCF). Our experiments on three multi-
label waste datasets show that ALEC outperforms other weakly supervised clas-
sification methods. Although ALEC demonstrates relatively better classification
performance, there is still a noticeable gap when compared to the results under
the full label scenario. In the future, we will explore other potential factors
that may affect the model’s performance in weakly supervised multi-label waste
classification, including the impact of label distribution and object scale on the
model’s classification effectiveness.
References
1. Adedeji, O., Wang, Z.: Intelligent waste classification system using deep learning
convolutional neural network. Procedia Manuf. 35, 607–612 (2019). https://doi.
org/10.1016/j.promfg.2019.05.086
2. Ahmed, M.I.B., et al.: Deep learning approach to recyclable products classification:
towards sustainable waste management. Sustainability 15(14) (2023). https://doi.
org/10.3390/su151411138
3. Alrayes, F.S., et al.: Waste classification using vision transformer based on multi-
layer hybrid convolution neural network. Urban Clim. 49, 101483 (2023). https://
doi.org/10.1016/j.uclim.2023.101483
4. Arachie, C., Huang, B.: Constrained labeling for weakly supervised learning. In:
de Campos, C., Maathuis, M.H. (eds.) Proceedings of the Thirty-Seventh Con-
ference on Uncertainty in Artificial Intelligence. Proceedings of Machine Learning
Research, vol. 161, pp. 236–246. PMLR (2021)
5. Cole, E., Mac Aodha, O., Lorieul, T., Perona, P., Morris, D., Jojic, N.: Multi-label
learning from single positive labels. In: Proceedings of the IEEE/CVF Conference
on Computer Vision and Pattern Recognition (CVPR), pp. 933–942 (2021)
6. Cárdenas-León, I., Koeva, M., Nourian, P., Davey, C.: Urban digital twin-based
solution using geospatial information for solid waste management. Sustain. Urban
Areas 115, 105798 (2024). https://doi.org/10.1016/j.scs.2024.105798
7. Guo, H., Wang, S.: Long-tailed multi-label visual recognition by collaborative train-
ing on uniform and re-balanced samplings. In: Proceedings of the IEEE/CVF Con-
ference on Computer Vision and Pattern Recognition (CVPR), pp. 15089–15098
(2021)
8. Han, M., Wu, H., Chen, Z., Li, M., Zhang, X.: A survey of multi-label classification
based on supervised and semi-supervised learning. Int. J. Mach. Learn. Cybern.
14(3), 697–724 (2023). https://doi.org/10.1007/s13042-022-01658-9
9. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In:
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition
(CVPR) (2016)
10. Huang, L., Li, M., Xu, T., Dong, S.Q.: A waste classification method based on a
capsule network. Environ. Sci. Pollut. Res. 30(36), 86454–86462 (2023). https://
doi.org/10.1007/s11356-023-27970-7
11. Kim, Y., Kim, J.M., Akata, Z., Lee, J.: Large loss matters in weakly supervised
multi-label classification. In: Proceedings of the IEEE/CVF Conference on Com-
puter Vision and Pattern Recognition (CVPR), pp. 14156–14165 (2022)
12. Kim, Y., Kim, J.M., Jeong, J., Schmid, C., Akata, Z., Lee, J.: Bridging the gap
between model explanations in partially annotated multi-label classification. In:
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recog-
nition (CVPR), pp. 3408–3417 (2023)
13. Ma, N., Zhang, X., Zheng, H.T., Sun, J.: ShuffleNet V2: practical guidelines for
efficient CNN architecture design. In: Proceedings of the European Conference on
Computer Vision (ECCV) (2018)
14. Mao, W.L., Chen, W.C., Wang, C.T., Lin, Y.H.: Recycling waste classification
using optimized convolutional neural network. Resour. Conserv. Recycl. 164,
105132 (2021). https://doi.org/10.1016/j.resconrec.2020.105132
15. Masand, A., Chauhan, S., Jangid, M., Kumar, R., Roy, S.: ScrapNet: an efficient
approach to trash classification. IEEE Access 9, 130947–130958 (2021). https://
doi.org/10.1109/ACCESS.2021.3111230
16. Pati, P., et al.: Weakly supervised joint whole-slide segmentation and classification
in prostate cancer. Med. Image Anal. 89, 102915 (2023). https://doi.org/10.1016/
j.media.2023.102915
17. Proença, P.F., Simões, P.: TACO: trash annotations in context for litter detection.
arXiv preprint arXiv:2003.06975 (2020)
18. Ridnik, T., et al.: Asymmetric loss for multi-label classification. In: Proceedings of
the IEEE/CVF International Conference on Computer Vision (ICCV), pp. 82–91
(2021)
19. Shennib, F., Schmitt, K.: Data-driven technologies and artificial intelligence in
circular economy and waste management systems: a review. In: 2021 IEEE Inter-
national Symposium on Technology and Society (ISTAS), pp. 1–5 (2021). https://
doi.org/10.1109/ISTAS52410.2021.9629183
20. Soni, U., Roy, A., Verma, A., Jain, V.: Forecasting municipal solid waste generation
using artificial intelligence models-a case study in India. SN Appl. Sci. 1(2), 162
(2019). https://doi.org/10.1007/s42452-018-0157-x
21. Sun, L., Dai, W., Muhammad, G.: Multi-level graph memory network cluster con-
volutional recurrent network for traffic forecasting. Inf. Fusion 105, 102214 (2024).
https://doi.org/10.1016/j.inffus.2023.102214
22. Sun, L., Li, C., Liu, B., Zhang, Y.: Class-driven graph attention network for multi-
label time series classification in mobile health digital twins. IEEE J. Sel. Areas
Commun. 41(10), 3267–3278 (2023). https://doi.org/10.1109/JSAC.2023.3310064
23. Sun, L., Li, Y., Zheng, M., Zhong, Z., Zhang, Y.: MCnet: multiscale visible image
and infrared image fusion network. Signal Process. 208, 108996 (2023). https://
doi.org/10.1016/j.sigpro.2023.108996
24. Verelst, T., Rubenstein, P.K., Eichner, M., Tuytelaars, T., Berman, M.: Spatial
consistency loss for training multi-label classifiers from single-label annotations. In:
Proceedings of the IEEE/CVF Winter Conference on Applications of Computer
Vision (WACV), pp. 3879–3889 (2023)
25. Xiao, J., Xu, J., Tian, C., Han, P., You, L., Zhang, S.: A serial attention frame for
multi-label waste bottle classification. Appl. Sci. 12(3) (2022). https://doi.org/10.
3390/app12031742
26. Xu, J., Pan, Y., Pan, X., Hoi, S., Yi, Z., Xu, Z.: RegNet: self-regulated network
for image classification. IEEE Trans. Neural Netw. Learn. Syst. 34(11), 9562–9567
(2023). https://doi.org/10.1109/TNNLS.2022.3158966
27. Yun, S., Oh, S.J., Heo, B., Han, D., Choe, J., Chun, S.: Re-labeling ImageNet:
from single to multi-labels, from global to localized labels. In: 2021 IEEE/CVF
Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2340–2350
(2021). https://doi.org/10.1109/CVPR46437.2021.00237
A Vehicle Asynchronous Communication
Scheme Based on Federated Deep
Reinforcement Learning
1 Introduction
The widespread application of the Internet of Vehicles (IoV) enables data shar-
ing between vehicles and edge networks, bringing technologies like autonomous
c The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 362–376, 2025.
https://doi.org/10.1007/978-981-96-4566-4_25
2 Related Work
Federated learning (FL) enables multiple devices or data sources to collabora-
tively train models without sharing raw data, ensuring privacy [1]. Recently, FL
has gained significant attention, with Li et al. [2] introducing its core compo-
nents, such as data configuration, models, privacy measures, and communication
frameworks. Niknam et al. [3] focused on improving training speed by leveraging
edge computing, data caching, and spectrum management in wireless networks,
highlighting the reduction of traffic burdens through edge caching.
A key challenge in FL is communication overhead, particularly in large-scale
networks [4]. To address this, research has shifted toward data compression tech-
niques that reduce energy consumption and communication costs while main-
taining training quality. Federated learning depends on iterative communication
between IoT devices and edge servers to train global models, where increased
interaction can speed up convergence but also raise communication costs. Solu-
tions to mitigate these costs include distributed training, model simplification,
364 J. Huang et al.
3 System Model
In this section, the system model of this paper is presented. As shown in Fig. 1,
the architecture comprises RSUs and vehicle users. Within the RSU’s cover-
age area, vehicle users first train their local models using their respective data,
then upload the updated model parameters to the RSU for aggregation. The
RSU aggregates these model parameters to create a global model, which is then
distributed to the vehicle users involved in the federated learning process. This
iterative process continues until the model reaches optimal training performance.
In the system model, model gradient exchanges between RSUs and vehicles
occur through wireless communication, which unavoidably causes delays. Com-
pared to the upload time, the download time is negligible because RSUs have
significantly higher downlink power than vehicles’ uplink power and allocate
greater downlink bandwidth for data distribution. Furthermore, since RSUs pos-
sess high computational capabilities, the delay caused by the aggregation phase
can also be ignored. The delay model is divided into two parts: upload delay and
model training delay.
(1) Upload Delay: Assuming that vehicles are covered by a base station, Orthog-
onal Frequency Division Multiple Access (OFDMA) is adopted as the
method for uploading local gradients. Considering limited communication
resources, it is assumed that each RSU distributes N subcarriers among W
vehicles (N < W ) for local parameter transmission, ensuring each vehicle
receives at most one subcarrier. Thus, the uplink transmission rate of vehicle
i is:
\[
\nu_{i,k}^{up} = s_{i,k}\, B \log_2\!\left(1 + \frac{q_{i,k}\, u_{i,k}\, b_{i,k}^{-\gamma}}{\sigma_0^2}\right)
\tag{1}
\]
where B refers to the uplink channel bandwidth, qi,k denotes the transmis-
sion power of vehicle i in the k-th round, ui,k indicates the channel gain
between vehicle i and the RSU, σ_0^2 is the noise level, b_{i,k} represents the distance to the RSU, and γ is the path loss coefficient. Here s_{i,k} = \sum_{n=1}^{N} a_{i,n} counts the subcarriers allocated to vehicle i, where a_{i,n} is a binary allocation indicator: if subcarrier n is assigned to vehicle i, then a_{i,n} = 1; otherwise, a_{i,n} = 0.
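Eq. 1 can be evaluated directly, as in this small sketch (the argument names and sample values are our own, chosen only for illustration):

```python
import math

def uplink_rate(s_ik, B, q_ik, u_ik, b_ik, gamma, sigma2):
    """Shannon-type uplink rate of Eq. 1:
    nu = s * B * log2(1 + q * u * b^(-gamma) / sigma^2)."""
    snr = q_ik * u_ik * b_ik ** (-gamma) / sigma2
    return s_ik * B * math.log2(1 + snr)

# Illustrative values: one subcarrier, unit bandwidth and gain.
rate = uplink_rate(s_ik=1, B=1.0, q_ik=3.0, u_ik=1.0,
                   b_ik=1.0, gamma=2.0, sigma2=1.0)
```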
Assume the model size is M . Then, the communication delay for vehicle i
at the k-th round is:
\[
t_{i,k}^{up} = \frac{M}{\nu_{i,k}^{up}}
\tag{2}
\]
(2) Model training delay: For vehicle i in the k-th global iteration, training locally on a dataset of size D_i, the computation delay is defined as:
\[
t_{i,k}^{com} = \frac{E\, D_i\, c}{f_{i,k}}
\tag{3}
\]
where E denotes the total number of local iteration rounds, c represents the
CPU cycles required to process one bit of data, and fi,k signifies the CPU
frequency of vehicle i during the k-th training round.
(3) The total delay calculation: The delay is designed as the sum of the com-
putation time of the vehicle client i and the model upload time. Thus, the
total delay in the k-th round is:
\[
t_{i,k}^{sum} = t_{i,k}^{com} + t_{i,k}^{up}
\tag{4}
\]
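Eqs. 2–4 combine into a simple delay computation, sketched below with our own argument names:

```python
def upload_delay(M, nu_up):
    """Eq. 2: t_up = M / nu_up."""
    return M / nu_up

def compute_delay(E, D_i, c, f_ik):
    """Eq. 3: t_com = E * D_i * c / f_ik."""
    return E * D_i * c / f_ik

def total_delay(M, nu_up, E, D_i, c, f_ik):
    """Eq. 4: total delay is upload delay plus local-training delay."""
    return upload_delay(M, nu_up) + compute_delay(E, D_i, c, f_ik)
```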
The federated learning process occurs between the RSU and the vehicles within
its coverage. Each vehicle i in the set {1, 2, . . . , n} has its own local dataset D_i, from the collection {D_1, D_2, . . . , D_n}, and corresponding local model parameters θ_i from {θ_1, θ_2, . . . , θ_n}. For each vehicle i, the dataset D_i = {x_i, y_i} consists of input data x_i and the associated output y_i.
The federated learning process consists of several stages. In each round, there
are three key steps: First, vehicles train their local models and upload their model
parameters to the RSU. Next, the RSU aggregates the local models to generate
the global model parameters, which are then sent back to the vehicles. Finally,
based on the updated global model parameters, the vehicles initialize their local
training parameters for the next round.
During local training, the local loss function for the dataset Di of vehicle i
can be expressed as the loss calculated over the data sample set at vehicle i:
\[
F_i(\theta) = \frac{1}{|D_i|} \sum_{(x_i,\, y_i) \in D_i} f(\theta_i;\, x_i,\, y_i)
\tag{9}
\]
Based on the above loss function, the local model parameters can be updated using standard gradient descent methods. After all local training is complete, the local model parameters for the r-th round are obtained as {θ_1(r), θ_2(r), . . . , θ_n(r)}. The aggregated global model parameters can be defined as:
\[
\theta^*(r) = \sum_{i=1}^{n} \varepsilon_i\, \theta_i(r) = \frac{1}{\sum_{i=1}^{n} |D_i|} \sum_{i=1}^{n} |D_i|\, \theta_i(r),
\tag{10}
\]
where \varepsilon_i = |D_i| / |D| and |D| = \sum_{i=1}^{n} |D_i|.
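The aggregation of Eq. 10 is a dataset-size-weighted average of the local parameters, as in this sketch (parameter vectors are flattened to 1-D arrays for simplicity):

```python
import numpy as np

def aggregate(local_params, dataset_sizes):
    """FedAvg-style aggregation of Eq. 10: global parameters are the
    |D_i|-weighted average of the local parameters theta_i(r)."""
    sizes = np.asarray(dataset_sizes, dtype=np.float64)
    weights = sizes / sizes.sum()            # epsilon_i = |D_i| / |D|
    stacked = np.stack(local_params)         # shape (n_clients, dim)
    return (weights[:, None] * stacked).sum(axis=0)
```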
4 Optimization Scheme
Building on the above system model, the objective is to minimize the global loss
function of federated learning while maintaining constraints on delay and energy
consumption within the vehicular network, as follows:
\[
\text{P1:} \quad \min F(\theta) \quad
\text{s.t.} \quad
\text{C1: } \max_i \left( t_{i,k}^{com} + t_{i,k}^{up} \right) \le t_{\max}, \qquad
\text{C2: } \sum_{i=1}^{n} \left( e_{i,k}^{com} + e_{i,k}^{up} \right) \le \kappa
\tag{12}
\]
Constraint C1 defines the delay limitation for the vehicle client, while constraint C2 sets the energy consumption limit, with κ representing the total
energy budget. Given the mobility of vehicles and the dynamic characteristics of
wireless channels, uncertainties may arise in the performance, delay, and energy
consumption of federated learning. For instance, if a vehicle moves out of the
RSU’s coverage area or encounters poor wireless conditions, it may fail to com-
plete its assigned task within the stipulated timeframe. Additionally, since the
loss function does not have a closed-form expression, problem P1 cannot be
directly solved. To address this, a client selection algorithm based on DRL is
proposed, utilizing its capability to determine optimal strategies in complex
environments. Furthermore, an asynchronous transmission method is employed
to upload local model updates to the RSU, effectively reducing client waiting
time and minimizing transmission delay.
the difference in loss values between the (t−1)-th and t-th iterations, ΔH_t. Therefore, the state at time t can be described as s_t = (L_t, E_t, W_t, ΔH_t). The state space can be represented as S = {s_i | i = 1, 2, . . .}, where s_i is the potential state at the i-th time step.
(2) Action
The action a_t = j_t ∈ {1, 2, . . . , n} is a positive integer representing the number of clients selected. At the t-th iteration, the DRL agent picks an action based on the current state to carry out the global model update.
(3) Policy
The policy π maps the state space to the action space, i.e., at = π(st ).
To represent the probability distribution of actions, a neural network is
employed to model π(a|s).
(4) Reward
Following the execution of action at , the agent receives a reward rt , which
depends on both the loss function value and the resource consumption at
time t, and is given by:
\[
r_t = -\xi^{\frac{E_t - \bar{E}_{t-1}}{\bar{E}_{t-1}}} + \frac{\Delta H_t}{F}, \quad t \in \{1, 2, \ldots, T\}
\tag{13}
\]
where ξ is a constant that ensures rt increases exponentially as the energy
consumption Et decreases. The energy consumption Et−1 is calculated as
the moving average, satisfying Et = τ Et + (1 − τ )Et−1 , where τ ∈ (0, 1). At
time t, the agent is rewarded more when energy consumption is lower and
the loss function value is closer to convergence.
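As a sketch of the reward in Eq. (13), the following Python fragment computes the moving-average energy and the resulting reward; the function names, the default ξ and τ, and the use of the initial loss as the normalizer F are illustrative assumptions, not part of the paper's implementation.

```python
def moving_avg_energy(e_bar_prev, e_t, tau=0.5):
    """Moving average of energy: E_bar_t = tau * E_t + (1 - tau) * E_bar_{t-1}."""
    return tau * e_t + (1 - tau) * e_bar_prev

def reward(delta_h, f_norm, e_bar_t, e_bar_prev, xi=2.0):
    """Reward of Eq. (13): the energy term grows exponentially with the
    relative increase of the moving-average energy, so lower energy
    consumption (and larger loss decrease delta_h) yields a larger reward."""
    energy_term = xi ** ((e_bar_t - e_bar_prev) / e_bar_prev)
    return -energy_term + delta_h / f_norm
```

A quick check confirms the intended monotonicity: with the same loss decrease, a round whose moving-average energy fell receives a strictly larger reward than one whose energy rose.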
must be considered. Vehicle clients with higher computational capacity and lower
energy consumption should be prioritized for frequent participation in global
updates. Furthermore, the inherent instability in vehicular networks, such as
road congestion and device disconnections, necessitates dynamic adjustments to
the RSU’s global update strategy based on real-time network conditions.
To illustrate the asynchronous transmission process more clearly, Fig. 3 com-
pares dynamic asynchronous federated learning (left subfigure), synchronous fed-
erated learning (middle subfigure), and standard asynchronous federated learn-
ing (right subfigure).
In the IoV scenario with four vehicle clients, federated learning strategies
vary in their approach to global updates. In synchronous FL, the RSU waits
for all four vehicle clients to submit their updated models before initiating a
global update, ensuring equal contributions but potentially causing delays. On
the other hand, asynchronous FL enables the RSU to perform a global update as
soon as it receives updated parameters from any client that completes its local
training, minimizing waiting time. Dynamic asynchronous FL further enhances
adaptability by allowing the RSU to adjust the number of participating clients
dynamically for each iteration round; at time t, the RSU performs a global
update when it receives updated models from jt clients, where jt changes with
the iteration to accommodate the varying conditions of vehicular networks. As
illustrated in Fig. 3, at time t1 , the RSU conducts a global update with con-
tributions from two clients, while at time t2 , the update involves three clients,
demonstrating the flexibility of this approach in addressing the dynamic nature
of the IoV environment.
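The dynamic asynchronous global update described above can be sketched as follows; this is a minimal illustration assuming models are flat lists of floats and that the RSU simply averages the first j_t arrivals, which is a simplification of the weighted aggregation used later in Eq. (15).

```python
def dynamic_async_round(updates_stream, j_t):
    """One dynamic asynchronous FL round: consume local models in arrival
    order from `updates_stream` (an iterable of (client_id, model) pairs)
    and trigger the global update as soon as j_t clients have reported."""
    received = []
    for client_id, local_model in updates_stream:
        received.append(local_model)
        if len(received) == j_t:  # enough clients for this round's update
            break
    # plain average of the j_t local models as the new global model
    n = len(received)
    return [sum(ws) / n for ws in zip(*received)]
```

Because the stream is consumed lazily, the slowest clients (here the third one) never block the round, which is exactly the waiting-time saving the comparison in Fig. 3 illustrates.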
Firstly, each vehicle client i updates its model parameters locally during the
training phase r ∈ {1, 2, ...}:
where mt,i represents the local update index of vehicle client i at time t, and zt,i ∈
{0, 1} is a binary variable that indicates whether vehicle client i participates in
the global update at time t, subject to the constraint \sum_{i \in N} z_{t,i} \ge n_t.
The RSU then broadcasts the updated model to all vehicle clients, and these
clients update their local models following the equation θi (mt,i + 1) = θ (t), for
i ∈ N. The global loss function value at time t, denoted F(θ(t)), is given by:

$$F(\theta(t)) = \frac{\sum_{i \in N} z_{t,i}\,|D_i|\,F_i(\theta(t))}{\sum_{i \in N} z_{t,i}\,|D_i|} \tag{15}$$
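The weighted global loss of Eq. (15) is straightforward to compute; in this sketch `participation` plays the role of the binary indicators z_{t,i} and `data_sizes` the role of |D_i| (names are illustrative).

```python
def global_loss(local_losses, data_sizes, participation):
    """Global loss of Eq. (15): local losses F_i averaged over participating
    clients (z_{t,i} = 1), weighted by local dataset size |D_i|."""
    num = sum(z * d * f for z, d, f in zip(participation, data_sizes, local_losses))
    den = sum(z * d for z, d in zip(participation, data_sizes))
    return num / den
```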
Parameter                                              Value
Vehicle User                                           30
RSU                                                    1
Noise Power                                            −100 dBm
Uplink Bandwidth                                       20 MHz
Subcarriers to be Allocated                            20
Maximum Uplink Power of Vehicle                        17 dBm
Local Computing Capability                             3 GHz
Effective Switching Capacitance of Local Computation   10^−29
(2) FashionMNIST Dataset × LeNet-5 Model: The LeNet-5 model, which con-
sists of two convolutional layers and three fully connected layers, is applied
to train the FashionMNIST dataset. FashionMNIST, provided by Zalando,
contains 10 classes, 60,000 training samples, and 10,000 test samples of fash-
ion article images.
To evaluate the superiority of the proposed scheme, comparisons are made
with the following three baselines and the proposed algorithm:
– FedAvg [21]: A synchronous federated learning method that randomly selects
a fixed number of clients to participate in training.
– FedCS [22]: A synchronous federated learning method employing a greedy
client scheduling strategy to select as many clients as possible for each round
of federated learning training.
– FedAsync [23]: An asynchronous federated learning method allowing clients
to perform local model training and send updates to the server at different
times without waiting for other clients to complete their training.
In this method, a DRL agent is trained across multiple datasets with 100 devices.
The DDQN model features two two-layer MLP networks, each with 512 hidden
units. The input size is 10,100, corresponding to 101 models (the global
model w and the 100 local models θ(k)), each reduced to a 100-dimensional
representation. The second layer outputs 100 values, processed by a softmax
layer to determine the probability of selecting each device. The DDQN model is
fast, requiring only a few seconds per training iteration.
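The final softmax stage described above maps the 100 output values to device-selection probabilities; a minimal numerically stable sketch (independent of any particular deep learning framework) is:

```python
import math

def selection_probs(q_values):
    """Softmax over the network's per-device outputs, yielding the
    probability of selecting each device (illustrative of the final layer
    described in the text)."""
    m = max(q_values)  # subtract the max for numerical stability
    exps = [math.exp(q - m) for q in q_values]
    s = sum(exps)
    return [e / s for e in exps]
```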
Figure 4 illustrates the training process of the DRL agent across three differ-
ent datasets. For each learning task, it can be observed that the reward varies
significantly during the initial 100 rounds. As training progresses, the rewards
converge to a stable and high value. For instance, MNIST takes approximately
100 rounds to converge to the maximum reward, while FMNIST takes about 80
rounds. Before training, a target accuracy is preset: 98% for the MNIST dataset
and 85% for the FashionMNIST dataset.
6 Conclusion
This paper proposes a DRL-based dynamic node communication scheme to
address high communication costs due to frequent vehicle communications. It
first analyzes the problem by considering real-time network conditions, vehi-
cle capabilities, and task requirements, establishing a joint optimization objec-
tive for federated learning accuracy, resource consumption, and communication
latency. A DRL-based node selection algorithm is then designed, modeling the
selection process as a Markov decision process to identify the optimal node com-
bination for training. Finally, a dynamic asynchronous aggregation strategy is
introduced, enabling model updates from different vehicles to be received at
different times. Experimental results on standard datasets show that the pro-
posed method significantly improves training accuracy and reduces communica-
tion latency compared to baseline algorithms.
Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).
References
1. Elbir, A.M., Soner, B., Çöleri, S., Gündüz, D., Bennis, M.: Federated learning
in vehicular networks. In: 2022 IEEE International Mediterranean Conference on
Communications and Networking (MeditCom), pp. 72–77. IEEE (2022)
2. Li, Q., et al.: A survey on federated learning systems: vision, hype and reality
for data privacy and protection. IEEE Trans. Knowl. Data Eng. 35(4), 3347–3366
(2021)
3. Niknam, S., Dhillon, H.S., Reed, J.H.: Federated learning for wireless communi-
cations: motivation, opportunities, and challenges. IEEE Commun. Mag. 58(6),
46–51 (2020)
4. Liu, S., Yu, J., Deng, X., Wan, S.: FedCPF: an efficient-communication federated
learning approach for vehicular edge computing in 6G communication networks.
IEEE Trans. Intell. Transp. Syst. 23(2), 1616–1629 (2021)
5. Haddadpour, F., Kamani, M.M., Mahdavi, M., Cadambe, V.: Trading redundancy
for communication: speeding up distributed SGD for non-convex optimization. In:
International Conference on Machine Learning. PMLR, pp. 2545–2554 (2019)
6. Liu, L., Zhang, J., Song, S., Letaief, K.B.: Edge-assisted hierarchical federated
learning with non-IID data. arXiv preprint arXiv:1905.06641 (2019)
7. Yao, X., Huang, T., Wu, C., Zhang, R., Sun, L.: Towards faster and better federated
learning: a feature fusion approach. In: 2019 IEEE International Conference on
Image Processing (ICIP), pp. 175–179. IEEE (2019)
8. Wei, S., Tong, Y., Zhou, Z., Song, T.: Efficient and fair data valuation for horizontal
federated learning. In: Federated Learning: Privacy and Incentive, pp. 139–152
(2020)
9. Song, T., Tong, Y., Wei, S.: Profit allocation for federated learning. In: 2019 IEEE
International Conference on Big Data (Big Data), pp. 2577–2586. IEEE (2019)
10. Chai, Z., et al.: TiFL: a tier-based federated learning system. In: Proceedings of
the 29th International Symposium on High-Performance Parallel and Distributed
Computing, pp. 125–136 (2020)
11. Lai, F., Zhu, X., Madhyastha, H.V., Chowdhury, M.: Oort: efficient federated learn-
ing via guided participant selection. In: 15th USENIX Symposium on Operating
Systems Design and Implementation (OSDI 21), pp. 19–35 (2021)
12. Ye, D., Yu, R., Pan, M., Han, Z.: Federated learning in vehicular edge computing:
a selective model aggregation approach. IEEE Access 8, 23920–23935 (2020)
13. Chai, H., Leng, S., Chen, Y., Zhang, K.: A hierarchical blockchain-enabled feder-
ated learning algorithm for knowledge sharing in internet of vehicles. IEEE Trans.
Intell. Transp. Syst. 22(7), 3975–3986 (2020)
14. Lu, Y., Huang, X., Dai, Y., Maharjan, S., Zhang, Y.: Federated learning for data
privacy preservation in vehicular cyber-physical systems. IEEE Network 34(3),
50–56 (2020)
15. Pokhrel, S.R., Choi, J.: Federated learning with blockchain for autonomous vehi-
cles: analysis and design challenges. IEEE Trans. Commun. 68(8), 4734–4746
(2020)
16. Sun, K., et al.: Joint top-K sparsification and shuffle model for communication-
privacy-accuracy tradeoffs in federated learning-based IoV. IEEE Internet Things
J. (2024)
17. Wu, J., et al.: FedAPT: joint adaptive parameter freezing and resource allocation
for communication-efficient federated vehicular networks. IEEE Internet Things J.
(2024)
18. Samarakoon, S., Bennis, M., Saad, W., Debbah, M.: Federated learning for ultra-
reliable low-latency V2V communications. In: 2018 IEEE Global Communications
Conference (GLOBECOM), pp. 1–7. IEEE (2018)
19. Cao, J., Zhang, K., Wu, F., Leng, S.: Learning cooperation schemes for mobile
edge computing empowered internet of vehicles. In: 2020 IEEE Wireless Commu-
nications and Networking Conference (WCNC), pp. 1–6. IEEE (2020)
20. Parekh, R., et al.: GeFL: gradient encryption-aided privacy preserved federated
learning for autonomous vehicles. IEEE Access 11, 1825–1839 (2023)
21. McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.:
Communication-efficient learning of deep networks from decentralized data. In:
Artificial Intelligence and Statistics, pp. 1273–1282. PMLR (2017)
22. Liu, Y., Chang, S., Liu, Y.: FedCS: communication-efficient federated learning with
compressive sensing. In: 2022 IEEE 28th International Conference on Parallel and
Distributed Systems (ICPADS), pp. 17–24. IEEE (2023)
23. Xie, C., Koyejo, S., Gupta, I.: Asynchronous federated optimization. arXiv preprint
arXiv:1903.03934 (2019)
A Vehicles Scheduling Algorithm Based
on Clustering Based Federated Learning
1 Introduction
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 377–391, 2025.
https://doi.org/10.1007/978-981-96-4566-4_26
378 X. Zhang et al.
high communication overhead and privacy concerns [2]. Federated learning, a dis-
tributed machine learning algorithm, helps address these challenges by enabling
decentralized model training [3].
Model training depends on numerous parameters, and limited network
resources become a bottleneck for data transmission and real-time information
exchange [4]. This impacts federated learning’s efficiency, particularly in large
model transmission [5]. User scheduling within vehicular networks optimizes
resource allocation and reduces congestion [6]. Effective scheduling strategies
[7] can reduce transmission latency, ensuring timely information. To improve
model convergence and minimize update delays, this paper proposes a federated
learning-based vehicular user scheduling scheme. Focusing on minimizing com-
munication time, it integrates four key factors: data importance, communication
rounds, channel state, and communication delay per round, to determine the
optimal vehicle selection probability. The specific contributions of this paper are
as follows:
2 Related Work
2.1 Federated Learning in Vehicular Networks
3 System Model
Each vehicle has computational resources to convert raw data into parameter
models. RSUs cluster vehicles based on data similarity, with vehicles in the same
cluster sharing a global model. Vehicles with more computational resources and
stable movement are prioritized as cluster heads to reduce cluster instability and
communication overhead caused by frequent head changes.
To learn model features from distributed vehicular data, the RSU solves the
following distributed optimization problem:
$$\min_{w} F(w) = \sum_{k=1}^{K} q_k F_k(w), \tag{1}$$
where w represents the learning model parameters, q_k is the weight of the k-th
device, with q_k > 0 and \sum_{k=1}^{K} q_k = 1. F_k(w) denotes
the local loss function at device k, which is expressed as:

$$F_k(w) = \frac{1}{n_k} \sum_{x \in D_k} f(w, x), \tag{2}$$
E(Ck , Mk ) = δ · Ck + (1 − δ) · Mk , (3)
where T(t) represents the communication time required after t rounds of updates,
N_t is the number of communication rounds still needed, and T_B^{(i)} and T_U^{(i)} represent
the time to send the global model from the RSU to each cluster head and the
upload time of the local model in round i, respectively. The sum of these two is
the communication time for round i. The computation time is not considered
in the scheduling strategy, as it is negligible compared to T_B^{(i)} and T_U^{(i)}. The
minimization problem can be transformed into:
$$\mathrm{P1:}\quad \min_{p_1^{(t)}, \ldots, p_K^{(t)}} \; T_U^{(t)} + N_{t+1}^{R}\, T_U^{R}, \qquad \text{s.t.}\ \mathrm{C1:}\ \sum_{k=1}^{K} p_k^{(t)} = 1, \tag{5}$$
where N_{t+1}^{R} represents the number of communication rounds still needed after
t rounds of training iterations, and T_U^{R} represents the communication time for
each future round.
Transforming P1 into a convex optimization problem, it can be solved using
the Karush-Kuhn-Tucker (KKT) conditions. Assuming:

$$T_U^{(t)} = f\!\left(p_k^{(t)}, q_d, BR_k^{(t)}\right), \tag{6}$$

and taking the derivative:

$$\frac{\partial T_U^{(t)}}{\partial p_k^{(t)}} = -\frac{\lambda\, q_d}{\varphi_t \frac{n_k}{n} w_k^{(t)} BR_k^{(t)}} + \lambda^{*}, \tag{7}$$

Setting the above equation to 0 gives:

$$-\frac{\lambda\, q_d}{\varphi_t \frac{n_k}{n} w_k^{(t)} BR_k^{(t)}} + \lambda^{*} = 0. \tag{8}$$

Combining the constraint conditions and the expression for p_k^{(t)} gives:

$$\sum_{k=1}^{K} \frac{\varphi_t \frac{n_k}{n} w_k^{(t)}}{\frac{q_d}{BR_k^{(t)}} + \lambda^{*}} = 1. \tag{9}$$

Solving this equation yields the Lagrange multiplier λ*. Through the above
KKT conditions, the optimal solution can be obtained:

$$p_k^{(t)*} = \frac{\varphi_t \frac{n_k}{n} w_k^{(t)}}{\frac{q_d}{BR_k^{(t)}} + \lambda^{*}}, \tag{10}$$
where φ_t is a scalar factor that adjusts the proportion between the parameter
importance \frac{n_k}{n} w_k^{(t)} and the transmission-rate term \frac{1}{q_d / BR_k^{(t)} + \lambda^{*}}. These parameters
determine the optimal probability p_k^{(t)*}, which decreases continuously as t
increases: as training progresses, the number of communication rounds t gradu-
ally increases, and the scheduling probability distribution becomes more inclined
towards the transmission rate of the vehicle devices.
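Equations (9)-(10) define the multiplier λ* implicitly, so in practice it can be found numerically. The sketch below uses bisection, exploiting that the sum in Eq. (9) decreases in λ*; all names are illustrative, `weights[k]` stands for the importance term (n_k/n)·w_k^{(t)}, and the code assumes the sum exceeds one at λ* = 0 so that the root is positive.

```python
def optimal_probs(weights, rates, phi_t, q_d, tol=1e-10):
    """Solve Eq. (9) for lam = lambda* by bisection, then return the
    probabilities of Eq. (10): p_k = phi_t * weights[k] / (q_d/rates[k] + lam).
    Assumes total(0) > 1, i.e. the root lam is positive."""
    def total(lam):
        return sum(phi_t * w / (q_d / r + lam) for w, r in zip(weights, rates))
    lo, hi = 0.0, 1.0
    while total(hi) > 1.0:      # expand the bracket until the sum drops below 1
        hi *= 2
    while hi - lo > tol:        # bisection: total() is strictly decreasing in lam
        mid = (lo + hi) / 2
        if total(mid) > 1.0:
            lo = mid
        else:
            hi = mid
    lam = (lo + hi) / 2
    return [phi_t * w / (q_d / r + lam) for w, r in zip(weights, rates)]
```

With equal importance weights, the vehicle with the higher transmission rate receives the larger selection probability, matching the discussion above.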
$$\mathrm{SINR}_k(r) = \frac{h_k(r)\, g_k^{-\alpha}}{I(r) + \sigma^2} = \frac{h_k(r)\, g_k^{-\alpha}}{\sum_{x \in \Phi_I} h_x(r)\, x^{-\alpha} + \sigma^2}, \tag{14}$$
where σ² is the normalized noise power (noise power divided by the device trans-
mission power), I(r) = \sum_{x \in \Phi_I} h_x(r) x^{-\alpha} is the interference during the first
attempt, ΦI represents the set of interferers, and hx (r) represents the small-
scale fading gain between the vehicle at position x and the RSU at the origin.
Assuming Rayleigh fading, hx ∼ exp(1). α represents the path loss exponent,
and since the set of interferers is the same across different transmission attempts
in an aggregation step, the success rate during the aggregation period is time-
dependent.
4 Algorithm Steps
Combining the clustering scheme based on inference similarity and the user
scheduling scheme that minimizes communication delay, a federated learning-
based vehicle user scheduling research scheme is proposed. Algorithm 1 provides
the pseudocode for the scheme.
In the first round, the RSU samples the vehicles and broadcasts the global
initial model parameters. Vehicles use the received model to train with their
own data and send it back to the RSU, which then clusters the vehicles. A
neighborhood matrix is constructed from the models obtained from the vehicle
clients:
$$A_{i,j}^{t} = \frac{\left\langle B_i^{t},\, B_j^{t} \right\rangle_F}{\left\| B_i^{t} \right\|_F \left\| B_j^{t} \right\|_F}, \qquad i, j = 1, \ldots, n, \tag{15}$$
where Ati,j represents the similarity between vehicles i and j based on their model
parameters. A threshold operator Γ is defined and applied as:
The RSU places each row's values of Ã into the same cluster to form clusters
{VC_j^{t+1}}, j = 1, . . . , T_{t+1}, of vehicles with similar data. Concurrently, the RSU
selects a cluster head based on the computational resources and mobility stability
of the vehicles within the cluster and sends these results back to the vehicles. The
scheduler calculates the optimal probability p_k^{(t)} for each vehicle to be selected,
and vehicles are chosen for model training based on this probability distribution.
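The similarity matrix of Eq. (15) is a pairwise cosine similarity over flattened model parameters; a minimal sketch (assuming each model is already flattened to a list of floats) is:

```python
import math

def similarity_matrix(models):
    """Pairwise similarity of Eq. (15): A[i][j] is the inner product of the
    flattened models of vehicles i and j, normalized by their norms."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))
    norms = [math.sqrt(dot(m, m)) for m in models]
    n = len(models)
    return [[dot(models[i], models[j]) / (norms[i] * norms[j])
             for j in range(n)] for i in range(n)]
```

Vehicles whose parameter vectors point in the same direction get similarity 1 and end up in the same cluster once the threshold operator is applied.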
In the t-th iteration, the RSU samples data from the cluster heads and
updates the global model as follows:
$$w_{(t+E)} = w_{t} + \frac{1}{U}\sum_{k=1}^{N} \mathbb{1}\!\left(k \in S_t,\ \mathrm{SINR}_k > \theta\right)\left(v_{(t+E)}^{k} - w_{t}\right), \tag{17}$$

$$v_{(t+i+1)}^{k} = w_{(t+i)}^{k} - \eta_{(t+i)} \nabla F_k\!\left(w_{(t+i)}^{k};\, \xi_{(t+i)}^{k}\right), \qquad i = 0, \ldots, E-1, \tag{18}$$
where η_{(t+i)} is the learning rate, and ξ_{(t+i)}^{k} is a uniformly selected sample from
vehicle k's local dataset. After updating the model parameters, vehicles send
them back to the cluster heads, which process them and send them to the RSU.
The RSU uses the updated parameters to form dynamic vehicle clusters with
similar data and calculates the new global model parameters for each cluster.
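The E local SGD steps of Eq. (18) can be sketched as follows; the gradient callback and the deterministic, in-order consumption of samples are illustrative simplifications of the uniform sampling in the text.

```python
def local_sgd(w, grad_fn, samples, eta):
    """E local steps of Eq. (18): v <- v - eta * grad F_k(v; xi), one sample
    xi per step. `grad_fn(v, xi)` returns the gradient as a list of floats."""
    v = list(w)  # start from the broadcast global model
    for xi in samples:
        g = grad_fn(v, xi)
        v = [vi - eta * gi for vi, gi in zip(v, g)]
    return v
```

For a scalar quadratic loss ½(v − ξ)² the gradient is v − ξ, so two steps with η = 0.5 from v = 0 toward ξ = 1 move the parameter to 0.75, as the test below confirms.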
Parameter                       Value
Transmission Bandwidth          5 MHz
Vehicle Computing Resources     0.5–1.5 GHz
Vehicle Transmission Power      1.3 W
Batch Size                      64
Path Loss                       0.0001
Number of Vehicles              100
Path Loss Exponent              4
SINR Threshold                  −15 dB
1. FedAvg [32]: The central server performs a weighted average of the locally
trained model parameters to achieve an update of the global model.
2. Fed-Cluster: Incorporating an inference similarity-based clustering method
on top of FedAvg.
3. Importance Aware (IA): The probability of vehicle selection is solely
related to the importance of its data.
4. Importance and Channel Aware (ICA): Under the federated learning
framework, the server selects devices for scheduling considering both data
importance and communication channel status.
scheduling. Compared to IA and ICA, the method proposed in this paper excels
in convergence speed, showing signs of convergence around 60 s, as it takes into
account both communication time and training rounds. This indicates that, com-
pared to other algorithms, the scheduling strategy of this paper can more rapidly
aggregate the local models of distributed vehicles into a global model under lim-
ited communication resources, thus accelerating the overall federated learning
process.
6 Conclusion
This paper addresses the issue of limited communication resources by designing
a vehicle user scheduling scheme based on the federated learning framework. To
tackle the accuracy decline caused by heterogeneous data in vehicular networks,
an inference similarity-based clustering method is proposed, selecting appro-
priate cluster heads based on vehicles’ computational resources and mobility
stability. Before each round of parameter transmission, the scheme links data
importance with communication rounds, channel quality, and single-round delay
to solve for the optimal selection probability of vehicles within clusters, conserv-
ing communication resources. Considering the variability in device transmission
success rates that could lead to shifts in global model characteristics, the success
probability is incorporated at the RSU aggregation stage to enhance the effec-
tiveness of federated learning. Finally, experiments on real datasets verify the
superior performance of this algorithm in terms of accuracy, convergence speed,
and delay.
Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).
References
1. Han, S., Xiao, F., Cheng, W.: A review on the application of deep reinforcement
learning in automatic driving systems. J. Xihua Univ. (Nat. Sci. Ed.) 42(4), 25–31
(2023). https://doi.org/10.12198/j.issn.1673-159X.4740
2. Zijia, M., Zhipeng, G., Yang, Y., et al.: An efficient distributed model sharing strat-
egy for data privacy protection in Telematics. J. Commun. 43(4), 83–94 (2022).
https://doi.org/10.11959/j.issn.1000-436x.2022074
3. Guanglei, G., Bo, G., Ke, X., et al.: An overview of federated learning-enabled 6G
networks. J. Internet Things 7(2), 50–66 (2023). https://doi.org/10.11959/j.issn.
2096-3750.2023.00323
4. Sun, B., Liu, Y., Wang, T., et al.: A review of federated learning efficiency opti-
mization in mobile edge networks. Comput. Res. Dev. 65(7), 1439–1469 (2022).
https://doi.org/10.7544/issn1000-1239.20210119
5. Jiarui, W., Guoping, T., Siyuan, Z.: Split-cluster wireless federated learning algo-
rithm for high-speed vehicular networking scenarios. Comput. Appl. 41(6), 1546
(2021). https://doi.org/10.11772/j.issn.1001-9081.2020121912
6. Lin, F., Li, H., Luo, C., et al.: A joint computational offloading and resource
allocation algorithm for dependent tasks in vehicular networking. J. Chongqing
Univ. Posts Telecommun. (Nat. Sci. Ed.) 35(5) (2023). https://doi.org/10.3979/j.
issn.1673-825X.202210270296
7. Chengyu, Z., Yiting, Y., Hongbin, L., et al.: A review of optimal resource allocation
schemes for 5G Telematics. Telecommun. Sci. 39(7), 124–138 (2023). https://doi.
org/10.11959/j.issn.1000-0801.2023139
8. Kim, E.J., Lee, E.K.: Performance impact of differential privacy on federated
learning in vehicular networks. In: NOMS 2022-2022 IEEE/IFIP Network Oper-
ations and Management Symposium, pp. 1–5. IEEE (2022). https://doi.org/10.
1109/NOMS54207.2022.9789814
9. Huang, J., Xu, C., Ji, Z., et al.: AFLPC: an asynchronous federated learning
privacy-preserving computing model applied to 5G-V2X. Secur. Commun. Netw.
(2022). https://doi.org/10.1155/2022/9334943
10. Su, X., Huo, Y., Wang, X., et al.: An enhancing semi-supervised federated learning
framework for internet of vehicles. In: 2023 IEEE 98th Vehicular Technology Con-
ference (VTC2023-Fall), pp. 1–5. IEEE (2023). https://doi.org/10.1109/VTC2023-
fall60731.2023.10333466
11. Li, B., Jiang, Y., Pei, Q., et al.: FEEL: federated end-to-end learning with non-
IID data for vehicular ad hoc networks. IEEE Trans. Intell. Transp. Syst. 23(9),
16728–16740 (2022). https://doi.org/10.1109/TITS.2022.3190294
12. Zhou, X., Liang, W., She, J., et al.: Two-layer federated learning with heterogeneous
model aggregation for 6G supported internet of vehicles. IEEE Trans. Veh. Technol.
70(6), 5308–5317 (2021). https://doi.org/10.1109/TVT.2021.3077893
13. Bute, M.S., Fan, P., Luo, Q.: Incentive based federated learning data dissemination
for vehicular edge computing networks. In: 2023 IEEE 98th Vehicular Technol-
ogy Conference (VTC2023-Fall), pp. 1–5. IEEE (2023). https://doi.org/10.1109/
VTC2023-Fall60731.2023.10333479
14. Bao, W., Wu, C., Guleng, S., et al.: Edge computing-based joint client selection
and networking scheme for federated learning in vehicular IoT. China Commun.
18(6), 39–52 (2021). https://doi.org/10.23919/JCC.2021.06.004
15. Lim, W., Luong, N.C., Hoang, D.T., et al.: Federated learning in mobile edge
networks: a comprehensive survey. IEEE Commun. Surv. Tutor. 22(3), 2031–2063
(2020). https://doi.org/10.1109/COMST.2020.2986024
16. Li, Q., Wen, Z., Wu, Z., et al.: A survey on federated learning systems: vision,
hype and reality for data privacy and protection. IEEE Trans. Knowl. Data Eng.
(2021). https://doi.org/10.1109/TKDE.2021.3124599
17. Sai, Z., Tianrui, L., Wei, H.: A federated learning algorithm for communication
cost optimization. Comput. Appl. 43(1), 1 (2023). https://doi.org/10.11772/j.issn.
1001-9081.2021122054
18. Zhang, H., Zhang, Y., Cao, C.: A federated learning approach to solve the data het-
erogeneity problem. Appl. Res. Comput./Jisuanji Yingyong Yanjiu 41(3) (2024).
https://doi.org/10.19734/j.issn.1001-3695.2023.07.0296
19. Liu, J., Xu, H., Xu, Y., et al.: Communication-efficient asynchronous federated
learning in resource-constrained edge computing. Comput. Netw. 199, 108429
(2021). https://doi.org/10.1016/j.comnet.2021.108429
20. Nishio, T., Yonetani, R.: Client selection for federated learning with heterogeneous
resources in mobile edge. In: 2019 IEEE International Conference on Communica-
tions (ICC), pp. 1–7. IEEE (2019). https://doi.org/10.1109/ICC.2019.8761315
21. Wang, S., Tuor, T., Salonidis, T., et al.: Adaptive federated learning in resource
constrained edge computing systems. IEEE J. Sel. Areas Commun. 37(6), 1205–
1221 (2019). https://doi.org/10.1109/JSAC.2019.2904348
22. Dinh, C.T., Tran, N.H., Nguyen, M., et al.: Federated learning over wireless net-
works: convergence analysis and resource allocation. IEEE/ACM Trans. Netw.
29(1), 398–409 (2020). https://doi.org/10.1109/TNET.2020.3035770
23. Yang, H.H., Liu, Z., Quek, T., et al.: Scheduling policies for federated learning in
wireless networks. IEEE Trans. Commun. 68(1), 317–333 (2019). https://doi.org/
10.1109/TCOMM.2019.2944169
24. Liu, L., Zhang, J., Song, S.H., et al.: Client-edge-cloud hierarchical federated learn-
ing. In: 2020 IEEE International Conference on Communications (ICC), pp. 1–6.
IEEE (2020). https://doi.org/10.1109/ICC40277.2020.9148862
25. Reisizadeh, A., Mokhtari, A., Hassani, H., et al.: FedPAQ: a communication-
efficient federated learning method with periodic averaging and quantization. In:
International Conference on Artificial Intelligence and Statistics, PMLR, pp. 2021–
2031 (2020). https://doi.org/10.48550/arXiv.1909.13014
26. Yang, K., Jiang, T., Shi, Y., et al.: Federated learning via over-the-air computation.
IEEE Trans. Wireless Commun. 19(3), 2022–2035 (2020). https://doi.org/10.1109/
TWC.2019.2961673
27. Diao, E., Ding, J., Tarokh, V.: HeteroFL: computation and communication effi-
cient federated learning for heterogeneous clients. arXiv preprint arXiv:2010.01264
(2020). https://doi.org/10.48550/arXiv.2010.01264
28. Posner, J., Tseng, L., Aloqaily, M., et al.: Federated learning in vehicular networks:
opportunities and solutions. IEEE Netw. 35(2), 152–159 (2021). https://doi.org/
10.1109/MNET.011.2000430
29. LeCun, Y., Bottou, L., Bengio, Y., et al.: Gradient-based learning applied to docu-
ment recognition. Proc. IEEE 86(11), 2278–2324 (1998). https://doi.org/10.1109/
5.726791
30. Li, J., Shao, Y., Wei, K., et al.: Blockchain assisted decentralized federated learning
(BLADE-FL), performance analysis and resource allocation. IEEE Trans. Paral-
lel Distrib. Syst. 33(10), 2401–2415 (2021). https://doi.org/10.1109/TPDS.2021.
3138848
31. Stallkamp, J., Schlipsing, M., Salmen, J., et al.: The German traffic sign recognition
benchmark: a multi-class classification competition. In: The 2011 International
Joint Conference on Neural Networks, pp. 1453–1460. IEEE (2011). https://doi.
org/10.1109/IJCNN.2011.6033395
32. Konečný, J., Mcmahan, H.B., Yu, F.X., et al.: Federated learning: strategies
for improving communication efficiency. arXiv preprint arXiv:1610.05492 (2016).
https://doi.org/10.48550/arXiv.1610.05492
33. Chen, M., Poor, H.V., Saad, W., et al.: Convergence time optimization for federated
learning over wireless networks. IEEE Trans. Wireless Commun. 20(4), 2457–2471
(2020). https://doi.org/10.1109/TWC.2020.3042530
34. Ren, J., He, Y., Wen, D., et al.: Scheduling for cellular federated edge learning
with importance and channel awareness. IEEE Trans. Wireless Commun. 19(11),
7690–7703 (2020). https://doi.org/10.1109/TWC.2020.3015671
A Cooperative Caching Strategy Based
on Deep Q-Network for Mobile Edge
Networks
1 Introduction
The surge in network traffic from mobile devices and bandwidth-intensive appli-
cations is overloading traditional cloud architectures, characterized by central-
ized processing and storage. This strain results in increased latency and dimin-
ished Quality of Experience (QoE), particularly affecting real-time applications.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 392–407, 2025.
https://doi.org/10.1007/978-981-96-4566-4_27
2 Related Work
Edge caching, a key research area in MEC [3, 4], alleviates core network band-
width pressure by caching popular content on edge servers. This reduces latency
and improves user QoS, crucial for low-latency, high-speed data transmission.
The core aims of edge caching strategies are to enhance cache space utilization
and reduce backhaul load, providing a superior experience for latency-sensitive
devices.
Latency and energy consumption are critical to user QoS. Pre-storing content
reduces both, improving QoS and extending device life. Thus, caching strategies
balancing latency and energy are vital. Li et al. [5] proposed a delay-constrained
sleep algorithm for energy conservation. Addressing service disruptions from
mobility and limited server coverage, Li et al. [6] presented an energy-delay bal-
anced cooperative caching strategy using a branch-and-bound algorithm. Kong
et al. [7] used a DDPG algorithm for a joint computation and caching framework,
minimizing energy. Vallero et al. [8] analyzed caching performance under vari-
ous traffic, capacity, and distribution conditions, achieving energy reductions.
Kang et al. [9] proposed a mobility-aware task scheduling strategy with data
caching, solved using an improved differential evolution algorithm. Zheng et al.
[10] transformed time-varying popularity into a Markov chain, implementing a
hybrid strategy for dynamic cache replacement and deterministic offloading. Ref-
erence [11] proposed a Multidimensional Cooperative Caching (MDCC) scheme
to minimize transmission latency.
Cache hit ratio is a key factor in caching decisions, representing the percent-
age of requests fulfilled by the cache. Higher hit ratios lead to shorter response
times, higher throughput, and better performance. Yuan et al. [12] established
a caching scheme maximizing the hit ratio under cooperative cost constraints,
using an ADMM-based distributed algorithm. Chen et al. [13] calculated video
request probabilities based on popularity, preferences, and video characteris-
tics, maximizing hit ratio under storage constraints. Algorithms in [14, 15] also
demonstrate good performance in terms of cache hit ratio.
Edge server storage is limited, making it difficult to cache all content. Pri-
oritizing content users prefer is necessary to improve resource utilization. User
satisfaction impacts CP revenue, making maximizing caching benefits from the
CP perspective crucial. Wang et al. [16] studied how CPs can maximize utility,
establishing a Stackelberg game with a reinforcement learning-based algorithm
for pricing strategies. He et al. [17] proposed CIPS to motivate content sharing,
selecting the most appropriate CP.
The aforementioned studies often lack consideration of the impact of time-
varying characteristics of mobile devices and traffic load on caching decisions.
A Cooperative Caching Strategy Based on DQN for Mobile Edge Networks 395
Due to limited storage and computing resources in mobile edge networks, effec-
tive allocation and utilization are vital to handle the increasing demands of
services.
3 System Model
3.1 System Architecture
where ϕen ∈ S represents the set of user devices covered by the EN , and Es is
the upper limit of the number of users an EN can serve.
396 C. Yang et al.
Similar to most existing studies, this paper assumes that the content popularity follows a Zipf distribution. Therefore, the content access probability $q_f$ is given as:
$q_f = \dfrac{f^{-\gamma}}{\sum_{f' \in AF} f'^{-\gamma}}$ .  (4)
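As a minimal sketch of the Zipf popularity model in Eq. (4), the following function normalizes the rank-based weights into access probabilities. The catalog size (500) and Zipf factor (0.6) are the values listed in Table 1; the function itself is illustrative, not the authors' implementation.

```python
# Sketch of the Zipf content-popularity model in Eq. (4).
def zipf_probabilities(num_contents, gamma):
    """Return the access probability q_f for each content f = 1..num_contents."""
    weights = [1.0 / (f ** gamma) for f in range(1, num_contents + 1)]
    total = sum(weights)
    return [w / total for w in weights]

q = zipf_probabilities(500, 0.6)  # catalog size and Zipf factor from Table 1
print(q[0] > q[1] > q[-1], round(sum(q), 6))  # → True 1.0
```

A larger γ concentrates probability mass on the highest-ranked contents, which is why the cache hit ratio rises with γ in the experiments of Fig. 3.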
When the requested content is cached at the local edge node, the delay for a
user to download content f directly from the edge node is given by:
$d^{local}_{en} = \dfrac{fs_f}{r_{down}}$ ,  (5)
where $fs_f \in FS$ represents the size of content $f$, and $r_{down} = B_{en,s} \log_2\!\left(1 + \frac{p \cdot h_{en,s}}{B \cdot N_\sigma}\right)$ is the downlink transmission rate, $B_{en,s}$ denotes the bandwidth allocated between the edge node $en$ and user device $s$, $p$ is the transmission power of the edge node, $h_{en,s}$ represents the channel gain, and $N_\sigma$ denotes the Gaussian noise power.
When the requested content is cached at other neighboring nodes within
the cooperative domain of the local edge node, the transmission of the content
between edge nodes incurs a delay, which is expressed as:
$d_r = \dfrac{fs_f}{r_{EN}}$ ,  (6)
where $r_{EN}$ denotes the transmission rate between edge nodes within the cooperative domain, and the queuing delay for requests is assumed to be negligible. The total delay in this case can be expressed as:
$d^{CD}_{en} = d_r + d^{local}_{en}$ .  (7)
When the requested content is not cached at the edge nodes, it must be
transferred from the cloud server to the local edge node. The transmission delay
between the cloud server and edge nodes is expressed as:
$d_c = \dfrac{fs_f}{r_c}$ ,  (8)
where rc represents the transmission rate per unit data between the edge node
and the cloud server, with rc < rEN . Consequently, the delay to download
content f from the cloud server can be expressed as:
$d^{cloud} = d_c + d^{local}_{en}$ .  (9)
In summary, the system’s total request delay cost can be expressed as:
$D = \sum_{f \in AF} \sum_{s \in US} D_f \cdot o_d$ .  (11)
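The tiered delay model of Eqs. (5)–(9) can be sketched as a small helper that dispatches on where the content is cached. The transmission rates below are illustrative placeholders, chosen only so that the cloud rate is the slowest, as the paper assumes ($r_c < r_{EN}$); they are not values from the paper.

```python
# Hedged sketch of the per-request delay model in Eqs. (5)-(9).
def request_delay(size_mb, location, r_down=50.0, r_en=100.0, r_cloud=20.0):
    """Delay (s) to deliver `size_mb` MB to the user from the given cache tier."""
    d_local = size_mb / r_down                  # Eq. (5): edge node -> user hop
    if location == "local":                     # cached at the local edge node
        return d_local
    if location == "domain":                    # Eq. (7): inter-edge hop (Eq. (6)) + local hop
        return size_mb / r_en + d_local
    return size_mb / r_cloud + d_local          # Eq. (9): cloud hop (Eq. (8)) + local hop

for tier in ("local", "domain", "cloud"):
    print(tier, request_delay(5, tier))         # 5 MB content, as in Table 1
```

The strict ordering local < domain < cloud is what makes edge caching worthwhile, and it is exactly the gap the DQN-based placement later tries to exploit.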
4 Optimization Scheme
4.1 The Edge Node Clustering Scheme Based on K-Means++
In an edge collaboration environment, the geographical distribution of edge nodes
is random, leading to varying distances between nodes. Consequently, the data
transmission delay between edge nodes can differ accordingly. In addition, due
to the differing numbers of users served by each node, there is an unequal distri-
bution of traffic load, which can result in resource shortages (both storage and
computation) at hotspot nodes, significantly increasing the delay of user requests.
Therefore, we classify the nodes into hotspot nodes and regular nodes based on
the network traffic data of each edge node, denoted as N etT = {n1 , n2 , . . . , nE },
where ni represents the traffic data of edge node i. Hotspot nodes are defined
as:
$N_{hot} = \left\{ n \;\middle|\; n \ge \frac{\sum_{n_i \in NetT} n_i}{E},\; n \in NetT \right\}$ .  (14)
Then, an ordinary node is denoted as:
$N_{normal} = \left\{ n \;\middle|\; n < \frac{\sum_{n_i \in NetT} n_i}{E},\; n \in NetT \right\}$ .  (15)
(1) States
The state space reflects the environmental state of the content delivery net-
work. Similar to most existing DRL-based caching algorithms, we represent
the state space S using the caching status S = {E1 , E2 , . . . , EN } of the edge
nodes, where Ei denotes the caching state of edge node i.
(2) Actions
The action space is represented by the caching decision matrix XEN ×AF .
Each edge node, as an agent, makes joint caching decisions based on its own
caching space.
(3) Rewards
The goal of reinforcement learning is to obtain the maximum reward, while
the objective function of this paper is to minimize the sum of content request
delay cost and caching cost. Therefore, we define the reward as:
$R = \dfrac{1}{D + C}$ .  (16)
After determining the state space, action space, and reward, the edge col-
laboration network caching process based on DQN is described as follows: The
edge nodes act as agents in the DQN algorithm, interacting with the environ-
ment through states, actions, and rewards. When the agent receives a content
request from a user device, it seeks the optimal content caching placement strat-
egy based on the current caching status of each edge node. This is achieved by
training to obtain the Q-values for all actions, allowing the agent to select the
optimal action. After the action is executed, the environment changes, and the
agent receives the current environment and its corresponding reward, which are
stored in the experience replay buffer. Then, a mini-batch of data, represented as
the four-tuple $(s, a, r, s')$, is randomly sampled from the experience replay buffer
to update the network parameters. The DQN network parameters are updated
using the Temporal Difference (TD) algorithm, which requires calculating the
TD target value, expressed as:
$\hat{y} = r + \gamma \max_{a'} Q(s', a'; \theta')$ .  (17)
In the DQN iteration process, the Q-values are trained by minimizing the
loss function to approximate the target values. The loss function is expressed as:
$L(\theta) = \left(\hat{y} - Q(s, a; \theta)\right)^2$ .  (18)
Through the above calculation, the update of the Q-value estimation network parameters is achieved. Then, we periodically update the Q target network parameters to ensure smooth changes in the Q-value estimation network parameters, enhancing the algorithm's stability. The DQN architecture is shown in
Fig. 2.
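The caching process described above can be sketched with a tabular stand-in for the two Q-networks: the estimation and target networks become small lookup tables, and the gradient step becomes a direct tabular update. This illustrates only the TD-target computation of Eq. (17) and the squared TD-error loss, not the paper's full pipeline (replay-buffer sampling, neural function approximation, periodic target synchronization); all sizes are illustrative.

```python
import numpy as np

# Tabular sketch of the DQN update built around the TD target of Eq. (17).
n_states, n_actions = 4, 3
gamma, lr = 0.9, 0.1
q_est = np.zeros((n_states, n_actions))      # Q-value estimation network
q_tgt = np.zeros((n_states, n_actions))      # Q target network (synced periodically)

def td_step(s, a, r, s_next):
    y_hat = r + gamma * q_tgt[s_next].max()  # TD target, Eq. (17)
    td_error = y_hat - q_est[s, a]
    q_est[s, a] += lr * td_error             # minimise the squared TD error
    return td_error ** 2                     # per-transition loss

loss = td_step(s=0, a=1, r=1.0, s_next=2)    # reward r = 1/(D + C), Eq. (16)
print(loss, q_est[0, 1])  # → 1.0 0.1
```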
and multiple edge nodes, each equipped with an edge server, use their limited
cache space to store content of interest to users. Users with content requests
are distributed within the coverage area of the edge nodes. They can obtain
requested content either from the edge nodes or the cloud server. The specific
simulation environment parameters are shown in Table 1.
Parameter Value
System Bandwidth 10 MHz
Number of Edge Nodes 10
Edge Node Cache Space 100 MB
Number of Contents 500
Content Size 5 MB
Zipf Factor 0.6
Number of Centroids 3
Number of Neural Network Layers 2
Maximum Training Epochs 1000
Batch Sampling Size 64
Experience Replay Pool Size 5000
Learning Rate 0.001
(1) The cache hit ratio (CHR) can be modeled as:

$CHR = \dfrac{\sum_{f \in AF} N_{en,f}}{\sum_{f \in AF} N_f}$ ,  (19)
where Nen,f represents the number of times content f is fetched from the
edge node, and Nf represents the total number of requests for content f .
(2) The content download latency (CDL) can be modeled as:
$CDL = \dfrac{1}{S} \sum_{f \in AF} \sum_{s \in US} D_f$ ,  (20)
where CDL represents the average delay of user content download requests.
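The two evaluation metrics can be sketched over a simple request log: the cache hit ratio is the share of requests served by an edge node, and the content download latency of Eq. (20) is the per-request average delay. The log format (a list of hit-flag/delay pairs) is an assumption made for illustration only.

```python
# Hedged sketch of the CHR and CDL evaluation metrics.
def evaluate(request_log):
    """request_log: list of (served_at_edge: bool, delay_s: float) tuples."""
    hits = sum(1 for at_edge, _ in request_log if at_edge)
    chr_ = hits / len(request_log)                           # cache hit ratio
    cdl = sum(d for _, d in request_log) / len(request_log)  # Eq. (20)
    return chr_, cdl

log = [(True, 0.10), (True, 0.15), (False, 0.35), (True, 0.10)]
chr_, cdl = evaluate(log)
print(chr_, cdl)  # 3 of 4 requests hit the edge cache
```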
Figure 3 illustrates the impact of different Zipf parameters γ on cache hit ratio
for the proposed caching scheme. As the Zipf parameter increases, the cache hit
ratio for all algorithms improves. The larger the value of parameter γ, the higher
the popularity of the corresponding content items, meaning the probability of
users requesting that content increases, thus improving the cache hit ratio. More-
over, it can be clearly observed that the proposed scheme consistently achieves
a higher cache hit ratio than the four baseline methods. For example, with Zipf
parameter γ = 0.8, the cache hit ratio of the proposed algorithm is 19%, 22%,
and 28% higher than LFU, LRU, and FIFO, respectively, demonstrating the
decision-making capability of the reinforcement learning algorithm. Compared
to the “Without cluster” scheme, it improves by 7%, as edge node clustering
allows better utilization of idle cache space, resulting in more user-interesting
content being cached at the edge, thus increasing the cache hit ratio.
Figure 4 shows the impact of different edge node cache capacities on CHR
and CDL. From Fig. 4(a), it can be observed that as the edge node cache capacity
increases, the overall content cache hit ratio of the algorithm also increases, with
a significant improvement. When the node cache capacity is 200MB, the cache
hit ratios for LFU, LRU, FIFO, and Without cluster are 54.4%, 53%, 50.9%,
and 62.4%, respectively. The proposed algorithm improves the cache hit ratio by
23%, 27%, 32%, and 8% compared to the baseline algorithms. The smaller the
cache capacity, the more pronounced the performance advantage of the proposed
algorithm. This is because, when the cache capacity is larger, more content can
be cached within the edge node collaboration domain, making it easier to satisfy
user requests, resulting in a smaller performance gap. The algorithm without
clustering also performs significantly better than LFU, LRU, and FIFO, which
indirectly indicates the advantages of reinforcement learning in cache decision-
making.
Figure 4(b) shows the impact of edge node cache capacity on content down-
load latency. As the cache capacity of edge nodes increases, the overall content
download latency decreases. When the edge node cache capacity increases, more
content can be stored, reducing the probability of downloading content from the
cloud server, thus lowering the download latency. Additionally, when the edge
node capacity is smaller, the proposed algorithm performs better. When the
cache node capacity is limited, edge nodes can cluster based on network traffic
data, making full use of the idle cache space of ordinary nodes within the coop-
eration domain. This allows for the storage of richer popular content within the
collaboration domain, and the data transmission latency between edge nodes is
much smaller than that between edge nodes and the cloud. Therefore, under
limited resources, the proposed scheme is more practical.
Figure 5 shows how content quantity impacts Cache Hit Ratio (CHR) and
Content Download Latency (CDL), comparing the proposed scheme with the
“Without cluster” method. As seen in Fig. 5(a), increasing content quantity
decreases the CHR. With fixed cache space, more content means less can be
stored, reducing the CHR. However, the proposed scheme’s CHR declines slower
than “Without cluster”. With 500 and 1000 content items, the proposed scheme
shows 4.7% and 11.8% higher CHR, respectively, indicating better performance
as content quantity increases. The K-means++ clustering ensures better cache
space utilization within the domain, enabling more effective content caching and
improved performance.
In Fig. 5(b), it is evident that content download latency increases with the
content quantity. This is due to the limited cache capacity of edge nodes, which
results in many contents being cached only on the cloud server, leading to higher
download latency. Similarly, the proposed scheme performs better when the con-
tent quantity is larger. Compared to the “Without cluster” approach, the pro-
posed algorithm can make full use of the cache space of ordinary nodes, storing
more content and thus reducing content download latency.
6 Conclusion
This paper studies the caching strategy from the perspective of uneven traffic
load across edge nodes and proposes a DQN-based edge collaborative caching
algorithm. Firstly, based on the network traffic data of edge nodes, they are
classified into hotspot nodes and ordinary nodes. Using the K-means++ algo-
rithm, edge nodes are clustered based on their traffic data and distance. Then,
the delay and caching costs associated with caching content on edge nodes and
within the collaborative caching domain are analyzed to establish the content
caching problem. To minimize system costs, a DQN-based approach is proposed
to optimize caching decisions. Simulation results show that the DQN-based edge collaborative caching scheme outperforms the baseline algorithms in both cache hit ratio and content download latency.
Acknowledgments. This work was sponsored by the National Natural Science Foun-
dation of China (Grant No. 62271264).
References
1. Abbas, N., Zhang, Y., Taherkordi, A., Skeie, T.: Mobile edge computing: a survey.
IEEE Internet Things J. 5(1), 450–465 (2017)
2. Wang, J., Zhao, L., Liu, J., Kato, N.: Smart resource allocation for mobile edge
computing: a deep reinforcement learning approach. IEEE Trans. Emerg. Top.
Comput. 9(3), 1529–1541 (2019)
3. Tran, T.X., Hajisami, A., Pandey, P., Pompili, D.: Collaborative mobile edge com-
puting in 5G networks: new paradigms, scenarios, and challenges. IEEE Commun.
Mag. 55(4), 54–61 (2017)
4. Peng, M., Sun, Y., Li, X., Mao, Z., Wang, C.: Recent advances in cloud radio access
networks: system architectures, key techniques, and open issues. IEEE Commun.
Surv. Tutor. 18(3), 2282–2308 (2016)
5. Li, P., Gong, S., Gao, S., Hu, Y., Pan, Z., You, X.: Delay-constrained sleeping
mechanism for energy saving in cache-aided ultra-dense network. Sci. China Inf.
Sci. 62, 1–14 (2019)
6. Li, C., Zhang, Y., Gao, X., Luo, Y.: Energy-latency tradeoffs for edge caching and
dynamic service migration based on DQN in mobile edge computing. J. Parallel
Distrib. Comput. 166, 15–31 (2022)
7. Kong, X., et al.: Deep reinforcement learning-based energy-efficient edge computing
for internet of vehicles. IEEE Trans. Industr. Inf. 18(9), 6308–6316 (2022)
8. Vallero, G., Deruyck, M., Joseph, W., Meo, M.: Caching at the edge in high energy-
efficient wireless access networks. In: ICC 2020-2020 IEEE International Conference
on Communications (ICC), pp. 1–7. IEEE (2020)
9. Kang, L., Tang, B., Zhang, L., Tang, L.: Mobility-aware and data caching-based
task scheduling strategy in mobile edge computing. In: 2019 IEEE International
Conference on Parallel & Distributed Processing with Applications, Big Data &
Cloud Computing, Sustainable Computing & Communications, Social Computing
& Networking (ISPA/BDCloud/SocialCom/SustainCom), pp. 1071–1077. IEEE
(2019)
10. Zheng, C., Liu, S., Huang, Y., Yang, L.: Hybrid policy learning for energy-latency
tradeoff in MEC-assisted VR video service. IEEE Trans. Veh. Technol. 70(9), 9006–
9021 (2021)
11. Lin, P., Song, Q., Jamalipour, A.: Multidimensional cooperative caching in CoMP-
integrated ultra-dense cellular networks. IEEE Trans. Wireless Commun. 19(3),
1977–1989 (2019)
12. Yuan, P., Shao, S., Geng, L., Zhao, X.: Caching hit ratio maximization in mobile
edge computing with node cooperation. Comput. Netw. 200, 108507 (2021)
13. Chen, X., He, L., Xu, S., Hu, S., Li, Q., Liu, G.: Hit ratio driven mobile edge caching
scheme for video on demand services. In: 2019 IEEE International Conference on
Multimedia and Expo (ICME), pp. 1702–1707. IEEE (2019)
14. Wei, X., Liu, J., Wang, Y., Tang, C., Hu, Y.: Wireless edge caching based on content
similarity in dynamic environments. J. Syst. Architect. 115, 102000 (2021)
15. Li, Z., Gao, X., Li, Q., Guo, J., Yang, B.: Edge caching enhancement for industrial
internet: a recommendation-aided approach. IEEE Internet Things J. 9(18), 16941–
16952 (2022)
16. Wang, Q., Guo, S., Liu, J., Pan, C., Yang, L.: MotiShare: incentive mechanisms
for content providers in heterogeneous time-varying edge content market. IEEE
Trans. Serv. Comput. 16(1), 452–465 (2021)
17. He, J., Wang, H., Chu, X., Zhang, T.: Incentive mechanism and content provider
selection for device-to-device-based content sharing. IEEE Trans. Veh. Technol.
68(3), 2946–2957 (2019)
18. Li, C., Zhang, Y., Song, M., Yan, X., Luo, Y.: An optimized content caching
strategy for video stream in edge-cloud environment. J. Netw. Comput. Appl. 191,
103158 (2021)
YOLO-LiteMax: An Improved Model
for UAV Small Object Detection
1 Introduction
With the increasing popularity of drones, their applications are expanding across
diverse fields such as agriculture [18], logistics [13], surveillance [16], environmen-
tal monitoring [2], and disaster response [6]. These applications take full advan-
tage of drones’ unique capabilities, such as high mobility, flexibility, and access
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 408–422, 2025.
https://doi.org/10.1007/978-981-96-4566-4_28
YOLO-LiteMax: An Improved Model for UAV Small Object Detection 409
to hard-to-reach areas [4]. Accurate and efficient object detection is essential for
the effective use of drones in these fields. Object detection is one of the core
tasks in the field of computer vision, providing essential support for autonomous
navigation, precise operations, and situational awareness for drones by automat-
ically identifying and locating specific objects in images or videos. For example,
in agriculture, object detection technology can be used to assess crop health,
monitor pests, and evaluate harvest conditions, thereby improving the efficiency
and accuracy of agricultural management. In logistics and parcel delivery, object
detection facilitates automatic identification and tracking of packages, enabling
automated sorting and delivery. In disaster response and environmental moni-
toring, object detection helps identify critical objects in emergency situations,
such as distressed individuals, fire hotspots, or flooded areas, providing timely
and reliable information for rescue operations. Therefore, the performance of
object detection directly impacts the effectiveness of drone applications in these
scenarios. Advances in deep learning-based object detection algorithms have sig-
nificantly driven progress in this field, greatly improving the accuracy and effi-
ciency of visual recognition tasks [14, 20, 21]. Despite these advancements, drone-
based object detection still faces several challenges. The aerial perspective often
results in small object sizes and low resolution, which makes precise detection
difficult due to limited pixel information [24]. Additionally, drones’ constrained
computational resources restrict the deployment of complex and computation-
intensive models [7]. Moreover, the continuously changing environmental con-
ditions, dynamic backgrounds, and frequent changes in camera angles further
increase the complexity of reliable object detection from drones [24]. Although
deep learning techniques have led to significant improvements in object detec-
tion, these challenges still hinder real-time performance and accuracy in drone
applications. To address these limitations, this study aims to develop a more
efficient and accurate object detection model, specifically optimized for drone
imagery, to enhance its robustness in complex environments while maintaining
computational efficiency.
In deep learning-based object detection for drones, two main approaches
exist: two-stage detectors and one-stage detectors. Two-stage detectors, such
as Faster R-CNN [20], first generate candidate regions of interest in the image
through a region proposal network, which are then further analyzed for precise
object classification and bounding box refinement. This approach is well-suited
for complex drone imagery involving challenging backgrounds and densely dis-
tributed targets. However, the two-stage process demands substantial computa-
tional resources, limiting its practicality for real-time applications on resource-
constrained drone devices. Conversely, one-stage detectors, like You Only Look
Once (YOLO) [19] and Single Shot MultiBox Detector (SSD) [15], process the
entire image in a single pass, directly predicting class probabilities and bounding
boxes from a grid of predefined anchor boxes. By combining classification and
localization in one step, one-stage detectors achieve faster processing speeds.
Nonetheless, they often show reduced accuracy for small or densely clustered
objects, particularly in complex backgrounds, leading to a higher likelihood of
false positives and missed detections. In summary, although existing two-stage and one-stage detectors each offer advantages, neither fully balances detection accuracy and computational efficiency for drone imagery. The main contributions of this work are summarized as follows:
410 J. Su et al.
(1) We redesign the CSP module by introducing the FasterNet Block, effec-
tively reducing redundant feature information and lowering model complex-
ity without compromising detection accuracy. This improvement addresses
the issue of excess computational load in previous models.
(2) We redesign the neck network by proposing the STSSF pyramid structure
to solve the problem that simple summation and concatenation in the orig-
inal feature pyramid cannot take full advantage of the correlation between
feature maps This new structure incorporates the P2 detection scale and
optimizes feature fusion, significantly enhancing small object detection and
improving model robustness in complex scenes.
(3) To address the limitations of the original YOLOv8 detection head in han-
dling small batch sizes, effective feature fusion, and capturing both local
and global context, we propose the SCPD head. By incorporating shared
convolution and group normalization, SCPD head achieves higher detection
accuracy with fewer parameters by improving the stability and consistency
of feature extraction.
(4) To validate the effectiveness of the proposed methods, we conduct exten-
sive ablation studies and performance comparisons on the VisDrone2019
dataset. The dataset features various challenges, such as small object sizes,
low resolution, and dynamic backgrounds, which align well with the issues
discussed in this study. The results show that our approach offers significant
improvements over existing methods, especially in detecting small objects,
while maintaining computational efficiency, making it ideal for drone-based
applications.
2 Related Work
In drone-based object detection, several studies have proposed methods to
enhance detection accuracy and adaptability. Sahin et al. [22] introduced YOLO-
Drone, which improved the detection of multi-scale objects, particularly small
ones, by adding additional convolutional and detection layers. However, these
modifications increased computational requirements, limiting the model’s effec-
tiveness in handling complex occlusions. Albaba et al. [1] proposed SyNet, an
Fig. 1. The overall structure of the proposed improved model. There are two C3DSF
modules in the figure, the first one is an orange route and the second one is a blue
route. The red dashed box represents a novel enhancement method for P2 feature maps.
(Color figure online)
In the SCB module, the FasterNet Block from FasterNet [5] replaces the orig-
inal Bottleneck structure in the CSP module. The core concept of the FasterNet
Block is Partial Convolution. In this approach, convolution is applied only to a
subset of the input channels, while the remaining channels are processed using
identity mapping. This reduces computational overhead by limiting the num-
ber of channels involved in convolution, while the non-convolved channels are
retained and passed through subsequent layers. As a result, FasterNet Block
minimizes computational complexity without sacrificing the flow of information
and feature propagation through the network. The SCB module operates as fol-
lows: First, the feature map is processed by a CBS module, which consists of
conventional convolution, batch normalization, and SiLU activation. The output
is then passed into a Split layer, which divides the feature map into two parts,
each containing half of the total channels, and stores the results in a list. The
output from each FasterNet Block is then fed into the next block and added
to the list. After passing through several FasterNet Blocks, all feature maps in
the list are concatenated along the channel dimension and processed by another
CBS module to produce the final output (Fig. 3).
Compared to the original YOLOv8, the SCB module introduces the Faster-
Net Block, which addresses the computational inefficiency of YOLOv8. In
YOLOv8, the Bottleneck structure convolves all input channels, leading to high
computational overhead. In contrast, the SCB module reduces this overhead by
applying convolution only to a subset of the channels, while ensuring that the
remaining channels are still retained and processed later. This selective convolu-
tion significantly reduces the computational complexity and number of parame-
ters while maintaining detection accuracy.
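The saving from convolving only a subset of the channels can be illustrated with a back-of-the-envelope FLOPs count. The 1/4 channel ratio below is the typical value suggested in the FasterNet paper, and the feature-map and channel sizes are made-up examples, not figures reported in this work.

```python
# Rough FLOPs comparison motivating the Partial Convolution in the FasterNet Block.
def conv_flops(h, w, k, c_in, c_out):
    """Multiply-accumulate count of a k×k convolution on an h×w map (bias ignored)."""
    return h * w * k * k * c_in * c_out

h, w, k, c = 80, 80, 3, 256
full = conv_flops(h, w, k, c, c)                # regular Bottleneck convolution
partial = conv_flops(h, w, k, c // 4, c // 4)   # convolve only 1/4 of the channels
print(full // partial)  # → 16, i.e. ~16x fewer FLOPs for the convolved part
```

Because the cost scales with the square of the channel count, a channel ratio of $c_p/c = 1/4$ cuts the convolution's FLOPs to roughly 1/16, while the untouched channels pass through by identity mapping.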
while also preventing small-object features from being lost during upsampling.
Finally, the large, medium, and small-size feature maps are concatenated along
the channel dimension to form a complete feature map. Unlike traditional sim-
ple concatenation, the SAC module not only adds large-size feature layers but
also processes feature maps at different sizes differently. This design effectively
enhances the model’s ability to capture small-object and detailed information,
thereby improving overall detection performance.
Since the targets in the analyzed images are often dense and small, we propose
a novel enhancement method for P2 feature maps in this section, as shown in the
red dashed box in Fig. 1. Different from the traditional way of adding small-scale
feature maps, we delay the upsampling step because the first C3DSF module has
already performed the preliminary fusion of the features in the P3, P4, and P5
layers, which completes the integration of the multi-scale features in advance
and reduces the unnecessary downsampling operation, which in turn reduces the
computational amount. Considering that the addition of the SAC module will
utilize the P1 layer feature maps, which will significantly increase the number
of parameter and computations, we choose to fuse the P2 layer through upsam-
pling and concatenate operation. Additionally, we introduce a second C3DSF
module to fuse the P2 layer with the processed P3 and P4 layer features. With
this design, the features from the P3 and P4 layers are subjected to deeper pro-
cessing after integrating more contextual information, and then merge with the
detailed information from the P2 layer. This enhances the model’s ability to
detect small objects. Compared to the SSF pyramid structure of ASF-YOLO,
this improvement not only retains high-resolution detail information when han-
dling small objects but also integrates deep semantic information, resulting in a
model that is more robust in complex scenarios.
Fig. 5. The structure of SCPD. The modules in the red dashed line represent shared
convolution. (Color figure online)
The VisDrone2019 dataset [9] is a widely used collection of drone aerial imagery,
compiled, annotated, and organized by the AISKYEYE data mining team at
Tianjin University. This dataset spans multiple cities in China and features
In the experiments, we use Ubuntu 20.04 as the operating system, with PyTorch 2.2.2 and CUDA 12.1 as the software environment. The main hardware configuration included an i9-13900KF CPU and an NVIDIA RTX 4090 GPU with 24 GB of video memory; some experiments were additionally run on a machine equipped with an NVIDIA RTX 3080 GPU and an i5-12600KF CPU. The input image size was 640 × 640, training ran for 400 epochs, the batch size was 8, and the initial learning rate was 0.01. The evaluation metrics include Precision, Recall, mAP, and the parameter count. Precision is a key metric for evaluating the performance of classification models, primarily used to measure how accurately the model predicts positive samples. The formula is as follows.
$Precision = \dfrac{TP}{TP + FP} \times 100\%$  (1)
where $TP$ denotes true positives, the number of samples correctly predicted as positive, and $FP$ denotes false positives, the number of samples incorrectly predicted as positive.
Recall measures a model’s ability to correctly identify positive samples. It
ranges from 0 to 1, with a higher value indicating that the model is better at
identifying positive samples. The formula is as follows.
$Recall = \dfrac{TP}{TP + FN} \times 100\%$  (2)
where $FN$ represents the number of samples that were actually positive but incorrectly predicted as negative.
Mean Average Precision (mAP) is a key metric for evaluating model perfor-
mance in object detection and information retrieval tasks. mAP considers both
accuracy and recall across multiple categories, providing a more comprehensive
assessment of the model’s performance. A higher mAP value indicates better
detection performance. The formula is as follows.
$AP = \sum_{n=1}^{N} (R_n - R_{n-1}) P_n$  (3)
where $P_n$ and $R_n$ denote the precision and recall corresponding to the $n$-th of the $N$ recall points, respectively.
$mAP = \dfrac{1}{C} \sum_{i=1}^{C} AP_i$  (4)
where C is the total number of classes and APi is the average precision of class
i.
The parameter count is an important metric for evaluating the complexity of a deep learning model, representing the total number of learnable parameters. Typically, it reflects the model's capacity, influencing its expressiveness and the computational resources required for training.
In this study, several enhancements are made to the original YOLOv8 model
to improve both the precision and efficiency of object detection. To verify the
effectiveness of these improvements, we design two sets of comparative exper-
iments. The first set compares the enhanced model with several YOLO series
algorithms, while the second set compares it with other high-performance models. The evaluation metrics include precision, recall, mAP, parameter count, and model size.
Compared to YOLOv5s, YOLOv8s, and YOLOv10s [25], YOLO-LiteMax
improves mAP50 by 12%, 5.9%, and 4.1%, respectively, and enhances mAP50-95
by 9.3%, 3.8%, and 2.6%, while maintaining a smaller parameter count and model
size. Additionally, compared to the latest YOLOv11s algorithm, the improved
model achieves an increase of 4.9% in mAP50 and 3.1% in mAP50-95. These
results demonstrate that YOLO-LiteMax delivers the best overall performance
for detecting densely packed small objects. Despite a slight increase in inference
time, the improved model still achieves real-time detection capabilities (Table 1).
Table 1. Detection results of some YOLO series models and the proposed model. (The
bold data in the table indicate the best results.)
Table 2. Detection results of the classical model and the proposed model. (The bold data in the table indicate the best results.)
Baseline SCB STSSF SCPD Precision/% Recall/% mAP50/% mAP50-95/% Params/M Model Size/MB
YOLOv8s 50.9 38.2 39.3 23.5 11.1 21.5
50.5 38.3 39.5 23.4 8.3 16.1
55.8 41.7 44.2 26.5 6.9 13.6
56.3 42.7 45.2 27.3 6.1 12.2
As shown in Figure 6, the white box at the top of the image contains several motorcycles.
None of the YOLOv8s, YOLOv10s, or YOLOv11s models detected the motor-
cycles, whereas YOLO-LiteMax is able to identify most of the targets. Addi-
tionally, on the left side of the image, there is an adult and a child; YOLOv8s,
YOLOv10s, and YOLOv11s only recognize the adult, whereas YOLO-LiteMax
accurately identifies both the adult and the child.
Fig. 6. Public facilities. (a) results of YOLOv8s; (b) results of YOLOv10s; (c) results of YOLOv11s; (d) results of YOLO-LiteMax.
Figure 7 shows a traffic junction where four cyclists are marked with a white
box on the crosswalk. YOLOv8s and YOLOv11s detect only one individual,
YOLOv10s fails to detect any targets, while YOLO-LiteMax successfully detects
all four cyclists. Additionally, the traffic sign highlighted in the white box in
the center of the image is mistakenly identified as a car by YOLOv10s and
YOLOv11s, while both YOLOv8s and YOLO-LiteMax correctly avoid this mis-
classification.
Fig. 7. Traffic intersection. (a) results of YOLOv8s; (b) results of YOLOv10s; (c) results of YOLOv11s; (d) results of YOLO-LiteMax.
References
1. Albaba, B.M., Ozer, S.: SyNet: an ensemble network for object detection in UAV
images. In: 2020 25th International Conference on Pattern Recognition (ICPR),
pp. 10227–10234. IEEE (2021)
2. Burgués, J., Marco, S.: Environmental chemical sensing using small drones: a
review. Sci. Total Environ. 748, 141172 (2020)
3. Cai, Z., Vasconcelos, N.: Cascade R-CNN: delving into high quality object detec-
tion. In: Proceedings of the IEEE Conference on Computer Vision and Pattern
Recognition, pp. 6154–6162 (2018)
4. Chang, Y.C., Chen, H.T., Chuang, J.H., Liao, I.C.: Pedestrian detection in aerial
images using vanishing point transformation and deep learning. In: 2018 25th IEEE
International Conference on Image Processing (ICIP), pp. 1917–1921. IEEE (2018)
5. Chen, J., et al.: Run, don’t walk: chasing higher flops for faster neural networks.
In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pp. 12021–12031 (2023)
6. Daud, S., et al.: Applications of drone in disaster management: a scoping review.
Sci. Justice 62(1), 30–42 (2022)
7. Deng, J., Shi, Z., Zhuo, C.: Energy-efficient real-time UAV object detection on
embedded platforms. IEEE Trans. Comput. Aided Des. Integr. Circuits Syst.
39(10), 3123–3127 (2019)
422 J. Su et al.
8. Tian, Z., Shen, C., Chen, H., He, T.: FCOS: a simple and strong anchor-free object
detector. IEEE Trans. Pattern Anal. Mach. Intell. 44(4) (2022)
9. Du, D., et al.: VisDrone-DET2019: the vision meets drone object detection in image
challenge results. In: Proceedings of the IEEE/CVF International Conference on
Computer Vision Workshops, pp. 0–0 (2019)
10. Duan, K., Bai, S., Xie, L., Qi, H., Huang, Q., Tian, Q.: CenterNet: keypoint triplets
for object detection. In: Proceedings of the IEEE/CVF International Conference
on Computer Vision, pp. 6569–6578 (2019)
11. Kang, M., Ting, C.M., Ting, F.F., Phan, R.: ASF-YOLO: a novel YOLO model
with attentional scale sequence fusion for cell instance segmentation. Image Vis.
Comput. 147, 105057 (2024)
12. Li, Y., Chen, Y., Wang, N., Zhang, Z.: Scale-aware trident networks for object
detection. In: Proceedings of the IEEE/CVF International Conference on Com-
puter Vision, pp. 6054–6063 (2019)
13. Li, Y., Liu, M., Jiang, D.: Application of unmanned aerial vehicles in logistics: a
literature review. Sustainability 14(21), 14473 (2022)
14. Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature
pyramid networks for object detection. In: Proceedings of the IEEE Conference on
Computer Vision and Pattern Recognition, pp. 2117–2125 (2017)
15. Liu, W., et al.: SSD: single shot MultiBox detector. In: Leibe, B., Matas, J., Sebe,
N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 21–37. Springer, Cham
(2016). https://doi.org/10.1007/978-3-319-46448-0_2
16. Lo, L.Y., Yiu, C.H., Tang, Y., Yang, A.S., Li, B., Wen, C.Y.: Dynamic object
tracking on autonomous UAV system for surveillance applications. Sensors 21(23),
7888 (2021)
17. Lou, H., et al.: DC-YOLOv8: small-size object detection algorithm based on camera
sensor. Electronics 12(10), 2323 (2023)
18. Maes, W.H., Steppe, K.: Perspectives for remote sensing with unmanned aerial
vehicles in precision agriculture. Trends Plant Sci. 24(2), 152–164 (2019)
19. Redmon, J.: You only look once: unified, real-time object detection. In: Proceedings
of the IEEE Conference on Computer Vision and Pattern Recognition (2016)
20. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object
detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell.
39(6), 1137–1149 (2016)
21. Lin, T.Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object
detection. In: Proceedings of the IEEE International Conference on Computer Vision,
pp. 2980–2988 (2017)
22. Sahin, O., Ozer, S.: YOLODrone: improved YOLO architecture for object detection in
drone images. In: 2021 44th International Conference on Telecommunications and
Signal Processing (TSP), pp. 361–365. IEEE (2021)
23. Sun, W., Dai, L., Zhang, X., Chang, P., He, X.: RSOD: real-time small object
detection algorithm in UAV-based traffic monitoring. Appl. Intell., 1–16 (2022)
24. Tang, G., Ni, J., Zhao, Y., Gu, Y., Cao, W.: A survey of object detection for UAVs
based on deep learning. Remote Sens. 16(1), 149 (2023)
25. Wang, A., et al.: YOLOv10: real-time end-to-end object detection. arXiv preprint:
arXiv:2405.14458 (2024)
26. Wu, Y., He, K.: Group normalization. In: Proceedings of the European Conference
on Computer Vision (ECCV), pp. 3–19 (2018)
27. Zhao, L., Zhu, M.: MS-YOLOv7: YOLOv7 based on multi-scale for object detection
on UAV aerial photography. Drones 7(3), 188 (2023)
LMCF-FS: A Novel Lightweight Malware
Classification Framework Driven
by Feature Selection
Cui Yun¹, Lei Zhou²,³, Shuangshuang Xing²,³, Ning Yang²,³, Pan Zhao²,³,
and Zhiguo Chen²,³(B)
¹ School of Computer, Jiangsu University of Science and Technology,
Zhenjiang 212100, China
[email protected]
² Engineering Research Center of Digital Forensics, Ministry of Education, Nanjing
University of Information Science and Technology, Nanjing 210044, China
{zhoulei,shuangshuangxing,yangning,zhaopan}@nuist.edu.cn
³ School of Computer Science and School of Cyber Science and Engineering, Nanjing
University of Information Science and Technology, Nanjing 210044, China
[email protected]
Abstract. The rapid increase in the number of malware and its variants
poses a significant threat to internet security. While existing machine
learning-based malware classification methods can improve accuracy,
they often require complex feature engineering. In contrast, converting
executable files into images and classifying them using deep learning
models can reduce the reliance on prior knowledge of malware features
and simplify the feature processing. However, most studies have not ade-
quately considered the impact of image pixel features and size on classi-
fication results. In addition, the model structure of deep learning classi-
fication methods based on pre-trained parameters and transfer learning
is usually more complex, resulting in the classification model requiring
more memory and time overhead. To solve the above problems, this paper
proposes a novel lightweight malware classification framework LMCF-FS
driven by feature selection. This framework uses a proposed new fea-
ture selection method based on pixel pair features to improve feature
sparsity and optimize image texture features, thereby improving feature
expression capabilities. At the same time, an interpolation algorithm is
used to balance the image size. In addition, a new lightweight malware
classification model ConvInceptionNet is built based on the Inception-C
module to balance accuracy, computational cost, and number of param-
eters, thereby improving malware classification efficiency. Our proposed
LMCF-FS framework achieves a classification accuracy of 99.12% on the
Microsoft Malware Classification Challenge Dataset (BIG2015 Dataset),
and it takes only 0.92 ms to predict a 224 × 224 malware image, verifying
the effectiveness of LMCF-FS in malware classification.
1 Introduction
With the widespread application of 5G networks, Internet of Things technology,
and cloud computing, people's lives have become far more convenient.
However, this convenience is also accompanied by the increasing threat of mal-
ware, which brings security risks such as data leakage, malicious extortion, and
privacy theft to user terminals. According to statistics, the number of mal-
ware incidents has risen annually over the past decade, with over 1.35 billion
pieces detected by 2024 [1]. This phenomenon shows that cybercriminals are
constantly exploring new means and techniques to evade detection by security
measures. Faced with this challenge, security researchers are working to research
and develop more effective defenses against malware attacks.
Traditional malware static analysis technology uses machine learning meth-
ods to automatically learn malware characteristics and behavior patterns
through training models to achieve automated classification [7], thereby reducing
classification time and improving classification performance. Although machine
learning methods have brought certain improvements to malware classification,
traditional machine learning algorithms such as logistic regression, support vec-
tor machines, and decision trees [26] still require professional knowledge related
to the field of malware classification to carry out relevant feature extraction and
selection. This means that once malware developers learn the signatures a system
uses, they can evade it relatively easily.
Compared with machine learning algorithms, deep learning technology, with
its multi-layered neural network structure, can automatically learn deep-level
distinguishing features in original binary files or code data without the need
for manual design by security researchers [4,28]. Therefore, deep learning tech-
nology is more suitable to deal with the challenges of new malware and its
variants, and can effectively improve classification accuracy and generalization
capabilities. Nevertheless, for malware that uses obfuscation technology, whether
its unpacking or decryption is successful will directly affect the performance of
the classification model. In order to avoid the impact of obfuscation technol-
ogy, many scholars explore visualizing executable programs as images and apply
deep learning models for malware classification [8,16,27]. However, there are
some issues with this method of converting malware into images. For example,
the impact of image pixel characteristics and size on classification results is not
fully considered during conversion, which may result in reduced classification
accuracy. In addition, some studies use methods such as pre-training parameters
and transfer learning to improve classification accuracy and model performance,
but these methods often cause the model to become large, complex, and lack
flexibility.
To mitigate the impact of obfuscation techniques on static malware clas-
sification research, while reducing model complexity and enhancing the accu-
racy and performance of malware classification tasks, we propose a lightweight
malware classification framework driven by feature selection, called LMCF-FS.
This framework introduces a novel feature selection method to optimize and
enhance image texture features, thereby improving the model's classification performance.
2 Related Work
2.1 Malware Classification Based on Feature Selection
Deep learning can address the complex feature relationships posed by different
types of malware and their variants, exhibiting strong generalization capabil-
ities. For instance, Gibert et al. [6] achieved classification results comparable
to gradient boosting methods across multiple malware datasets by designing
a multimodal network model that combines various feature types, including
images and text. Vasan et al. [24] attained accuracies of 98.82% and 97.35%
on multiple datasets by converting raw binary files into color images and uti-
lizing a fine-tuned CNN architecture for classification. Additionally, Awan et
al. [3] effectively enhanced the model’s ability to extract malware image fea-
tures by combining a spatial attention mechanism with a convolutional neural
network (SACNN), achieving a classification accuracy of 97.42%. Kumar et al.
[13] applied deep learning to Internet of Things (IoT) malware classification
by integrating traditional convolutional neural networks with transfer learning,
attaining classification accuracies of 99.18% and 98.63% across multiple datasets
through fine-tuning the CNN. Furthermore, Mallik et al. [15] combined convo-
lutional networks with recurrent networks to capture the temporal features of
malware, achieving a classification accuracy of 98.36% on the MalImg dataset,
showcasing the powerful capability of deep learning in handling malware behav-
ior patterns.
These studies demonstrate that deep learning models can autonomously learn
and construct hierarchical feature representations from raw data, thereby cap-
turing subtle differences in malware behavior for precise classification. However,
these models rely on pre-trained parameters and transfer learning, which signif-
icantly increase computational resource demands and complicate the architec-
ture, thus limiting their applicability in resource-constrained environments.
3 Methodology
3.1 Overview
The malware classification framework LMCF-FS proposed in this paper is shown
in Fig. 1 and is divided into four parts: image visualization, feature selection,
image enhancement, and model construction. The image visualization module
converts the malware’s Byte file into an RGB image. The feature selection mod-
ule sorts the pixel pairs of the RGB images by probability to extract key pixel
pair features and optimize the texture features of the malware images. The
image enhancement module uses the bicubic interpolation algorithm to equalize
the image size. The model-building module combines convolutional neural
networks with the Inception-C module to build a lightweight malware classification
model, ConvInceptionNet, which improves the efficiency of malware
family classification.
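The visualization step can be sketched as follows. This is a hedged illustration only: the fixed row width and zero-padding policy are assumptions, since the framework's exact layout rules are not restated here, and the function name is ours.

```python
import numpy as np

def bytes_to_rgb(data: bytes, width: int = 256) -> np.ndarray:
    """Sketch: pack a malware byte sequence into an H x W x 3 RGB image.
    The fixed width and zero-padding are illustrative assumptions."""
    arr = np.frombuffer(data, dtype=np.uint8)
    row_bytes = width * 3                      # 3 channels per pixel
    pad = (-len(arr)) % row_bytes              # zero-pad to whole rows
    arr = np.concatenate([arr, np.zeros(pad, dtype=np.uint8)])
    return arr.reshape(-1, width, 3)

img = bytes_to_rgb(bytes(range(256)) * 30, width=64)
print(img.shape)  # (40, 64, 3)
```

Because the file length determines the number of rows, images from different samples come out with different heights, which is exactly the size imbalance the bicubic interpolation step later corrects.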
N = (256 × 256) / 2^n    (1)
Among them, n ∈ {0, 1, 2, . . . , 7}.
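As a quick check, the feature counts allowed by Eq. (1) can be enumerated directly (a minimal sketch; the variable name is illustrative):

```python
# Candidate numbers of retained pixel-pair features from Eq. (1):
# N = (256 * 256) / 2**n for n in {0, 1, ..., 7}.
feature_counts = [(256 * 256) // 2 ** n for n in range(8)]
print(feature_counts)  # [65536, 32768, 16384, 8192, 4096, 2048, 1024, 512]
```

The two extremes, 65536 and 512, bracket the range explored in the experiments of Sect. 4.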
Through the above feature selection process, the pixel values of the selected
pixel pairs remain unchanged, while the pixel values of the unselected pixel
pairs are replaced with 0. The image after feature selection is shown
in Fig. 3. This process is designed to optimize image texture features, improve
feature sparsity, and enhance the discrimination of malware classification during
the feature selection process to improve the effectiveness of malware classification
algorithms.
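The selection step described above can be sketched as follows. This is a hedged illustration, not the paper's exact procedure: how pixel pairs are formed (here, adjacent bytes in flattened row-major order) and how channels are handled are our assumptions.

```python
import numpy as np

def select_pixel_pair_features(img: np.ndarray, n_keep: int = 1024) -> np.ndarray:
    """Keep only pixels whose value pair is among the n_keep most frequent of
    the 256 x 256 possible pairs; zero out the rest. The pair definition
    (adjacent bytes in flattened order) is an assumption for illustration."""
    flat = img.reshape(-1)
    # Encode each adjacent byte pair as an index in [0, 65535].
    pairs = flat[:-1].astype(np.int64) * 256 + flat[1:]
    counts = np.bincount(pairs, minlength=256 * 256)
    keep_mask = np.zeros(256 * 256, dtype=bool)
    keep_mask[np.argsort(counts)[::-1][:n_keep]] = True
    # A pixel survives if it starts a kept pair; others are replaced with 0.
    survives = np.pad(keep_mask[pairs], (0, 1))
    return np.where(survives, flat, 0).astype(np.uint8).reshape(img.shape)

img = np.random.default_rng(0).integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
print(select_pixel_pair_features(img, n_keep=1024).shape)  # (64, 64, 3)
```

Keeping only the most probable pairs increases sparsity while preserving the dominant texture patterns, which is the effect the paragraph above describes.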
Due to differences in the content length of malware files, the size of the
generated RGB images is not uniform. To solve this problem, we use the bicubic
interpolation algorithm for image enhancement and adjust all images to the
standard size of 224 × 224 to improve the efficiency of malware classification.
Bicubic interpolation can preserve the details of the original image as much as
possible and produce smoother edges [11]. As shown in Fig. 4, this algorithm
uses the gray value of 16 points around the sampling point to perform cubic
interpolation, while taking into account the influence of the gray value of the
adjacent 4 points and the influence of the change rate of the gray value. The
formula for calculating the pixel value is as follows:
W(x) = (a + 2)|x|^3 − (a + 3)|x|^2 + 1   for |x| ≤ 1
W(x) = a|x|^3 − 5a|x|^2 + 8a|x| − 4a     for 1 < |x| < 2
W(x) = 0                                 otherwise        (2)
f(x, y) = Σ_{i=0}^{3} Σ_{j=0}^{3} f(x_i, y_j) W(x − x_i) W(y − y_j)    (3)
Among them, W(x − x_i) and W(y − y_j) weight each of the 16 sample points
(x_i, y_j) nearest the pixel point (x, y); a generally takes −1 or −0.5. For the
interpolated pixel point (x, y) (x, y can be floating-point numbers), the nearby
4 × 4 points are selected for weighted summation according to Eq. 3.
3.3 ConvInceptionNet
The Inception model stands out in the field of deep learning for its powerful
feature extraction capabilities and efficient parallel computing capabilities, espe-
cially in image recognition tasks [21]. This model processes input data in parallel
through multiple convolution kernels of different scales and can capture differ-
ent levels of feature information, thus improving the accuracy of classification.
Therefore, we draw on its efficient feature extraction mechanism and combine
it with a lightweight network structure to design a lightweight network model
ConvInceptionNet to achieve high efficiency and accuracy in malware classifica-
tion and adapt to device resource-limited environments and mobile applications.
The model architecture is shown in Fig. 5.
Fig. 3. The malware images after feature selection, where N is the number of features.
Firstly, image features are input to three convolutional layers: the first convolutional
layer has 64 filters of size 5 × 5 with a stride of 2 × 2; the second convolutional
layer has 192 filters of size 1 × 1 with a stride of 2 × 2, followed by an activation
function and a 3 × 3 max pooling layer with a stride of 2 × 2; the third convolutional
layer has 256 filters of size 3 × 3 with a stride of 2 × 2. Secondly, two Inception-C modules, pooling layers, and specific
feature selection methods are used to improve the efficiency and performance
of the model. Each Inception-C module consists of multiple branches, includ-
ing 1 × 1 convolution, 1 × 3 convolution, 3 × 1 convolution, and average pooling.
Each branch produces 256 output channels (filters). Finally, the FC layer uses
a drop rate of 0.1. Our proposed model can learn features at different scales by
using multiple convolution kernels of different sizes and structures to process
input feature maps in parallel, thereby obtaining richer feature expressions and
achieving accurate classification of malware images.
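For reference, the parameter counts of the three stem convolutions described above follow from the standard formula filters × (k·k·in_channels + 1). The numbers below are derived from the text; bias terms and a 3-channel RGB input are our assumptions.

```python
def conv_params(filters: int, k: int, in_channels: int, bias: bool = True) -> int:
    """Parameter count of a square k x k 2-D convolution layer."""
    return filters * (k * k * in_channels + (1 if bias else 0))

# Stem of ConvInceptionNet as described in the text.
stem = [
    ("conv1: 64 filters, 5x5, stride 2", conv_params(64, 5, 3)),
    ("conv2: 192 filters, 1x1, stride 2", conv_params(192, 1, 64)),
    ("conv3: 256 filters, 3x3, stride 2", conv_params(256, 3, 192)),
]
for name, p in stem:
    print(f"{name}: {p:,} parameters")  # 4,864 / 12,480 / 442,624
```

The bulk of the stem's cost sits in the 3 × 3 layer, which is consistent with the model's reported total of about 1.27 M parameters once the Inception-C modules and the FC head are added.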
4 Experimental Evaluation
4.1 Data Set
To reduce redundant features of RGB images, and optimize image texture fea-
tures, we use the method described in Sect. 3.2 to perform feature selection on the
RGB images of the BIG2015 dataset. From the experimental results in Table 1, it
can be concluded that when feature selection is performed on the original image,
the accuracy increases as the number of features decreases. In particular, when
the number of pixel pair features is reduced from 65536 to 1024, the malware
classification accuracy increases from 98.43% to 98.89%, the precision increases
from 98.46% to 98.75%, the recall increases from 98.34% to 98.76%, and the
F1 score increases from 98.40% to 98.76%. However, when the number of
features is less than 512, the classification performance of the malware classifica-
tion model decreases significantly because the key features of the image are lost.
It can be concluded that appropriately reducing the redundant pixel features of
RGB images can improve the efficiency of malware classification, but there is a
trade-off between the number of features and classification performance.
In addition, in order to reduce the impact of malware image size imbalance
on the classification effect, we use the bicubic interpolation algorithm introduced
in Sect. 3.2 to perform image enhancement on the RGB images of the BIG2015
dataset, providing more consistent data for subsequent malware classification.
The experimental results are shown in Table 1. By applying feature selection
and standardizing image sizes, the accuracy, recall, and F1 score of malware
classification increased by 0.19%, 0.23%, and 0.11%, respectively. In particular,
when the number of pixel pair features is reduced to 1024, the malware classifi-
cation accuracy reaches 99.12%, the precision reaches 99.24%, the recall reaches
99.01%, and the F1 score reaches 99.12%, achieving the best classification perfor-
mance. The results indicate that the combination of feature selection and image
enhancement can significantly improve the efficiency and accuracy of malware
classification systems.
Model Accuracy Training time per epoch Prediction Time Parameters GPU Memory (MiB)
SqueezeNet 97.28% 9 s 1.38 ms 740,041 3280
ShuffleNetV2 98.16% 10 s 1.82 ms 2,720,383 2226
MobileNetV3 98.25% 10 s 1.84 ms 1,521,463 2476
EfficientNetB0 98.25% 11 s 1.87 ms 6,227,797 3840
ConvInceptionNet 98.43% 7 s 0.92 ms 1,268,312 2604
Although EfficientNetB0 achieves the same accuracy as the MobileNetV3 model, its number of parameters is 4.8 times that
of ConvInceptionNet. This shows that EfficientNetB0 sacrifices the lightweight
advantage of the model while pursuing performance improvement. The ConvIn-
ceptionNet model achieves a combination of high performance and lightweight
by optimizing the network structure.
Therefore, through comparative experiments, the ConvInceptionNet model
has demonstrated comprehensive performance superior to four typical light-
weight network models in key indicators such as accuracy, training/prediction
time, parameter volume, and memory usage.
Mallik et al. [15] used a feature extractor based on VGG16 to extract visual
outliers, and merged the features with the output of the VGG16 layer through
two BiLSTM layers, achieving an accuracy of 98.36% on the BIG2015 dataset.
Acharya et al. [2] proposed a lightweight model based on EfficientNetB1, which
achieved an accuracy of 98.57% on the BIG2015 dataset. Zou et al. [29] pro-
posed a lightweight IMCLNet model that integrates coordinate attention, depth-
separable convolution, and global context embedding, achieving an accuracy of
99.11% on the BIG2015 dataset.
Unlike the above research methods, this paper performs feature selection and
image enhancement on the RGB images of malware. Then we input them into
the lightweight model ConvInceptionNet for classification. This method achieves
comprehensive improvements in precision, recall, F1 score, and accuracy while
maintaining high efficiency. It is worth noting that the IMCLNet model of
Zou et al. [29] predicts a 32 × 32 malware image in 0.84 ms, while the
LMCF-FS framework proposed in this paper needs only 0.92 ms to predict a
224 × 224 malware image. Therefore, the malware images used in this paper
are larger and more informative, but the prediction time does not increase sig-
nificantly. This proves that LMCF-FS can maintain high classification accuracy
while effectively reducing time consumption when processing large-size images,
and is more suitable for resource-constrained scenes and mobile devices.
5 Conclusion
In order to further improve the efficiency of malware classification, this paper
proposes a lightweight malware classification framework driven by feature
selection (LMCF-FS). This framework integrates visualization technology, fea-
ture selection, image enhancement, and lightweight neural network models. We
convert the malware’s byte files into RGB images. Then we propose a new fea-
ture selection method based on pixel pair features. This method optimizes image
texture features by increasing the sparsity of features, which can significantly
improve the expressive ability of features. At the same time, we use the bicu-
bic interpolation algorithm to enhance the malware images, which effectively
solves the problem of the unbalanced size of the image dataset. In addition,
a lightweight model ConvInceptionNet based on the Inception-C module was
designed, which achieves a smaller parameter size and faster inference speed
while maintaining high performance. LMCF-FS achieved 99.12%, 99.24%,
99.01%, and 99.12% in accuracy, precision, recall, and F1 score, respectively, on
the BIG2015 dataset, and the time to predict a 224 × 224 image is only 0.92 ms.
References
1. Av-test. (2020). https://www.av-test.org/en/statistics/malware/
2. Acharya, V., Ravi, V., Mohammad, N.: EfficientNet-based convolutional neural
networks for malware classification. In: 2021 12th International Conference on
Computing Communication and Networking Technologies (ICCCNT), pp. 1–6.
IEEE (2021)
3. Awan, M.J., et al.: Image-based malware classification using VGG19 network and
spatial convolutional attention. Electronics 10(19), 2444 (2021)
4. Basha, S.S., Dubey, S.R., Pulabaigari, V., Mukherjee, S.: Impact of fully connected
layers on performance of convolutional neural networks for image classification.
Neurocomputing 378, 112–119 (2020)
5. Darem, A., Abawajy, J., Makkar, A., Alhashmi, A., Alanazi, S.: Visualization and
deep-learning-based malware variant detection using opcode-level features. Futur.
Gener. Comput. Syst. 125, 314–323 (2021)
6. Gibert, D., Mateu, C., Planes, J.: Hydra: a multimodal deep learning framework
for malware classification. Comput. Secur. 95, 101873 (2020)
7. Gibert, D., Mateu, C., Planes, J.: The rise of machine learning for detection and
classification of malware: research developments, trends and challenges. J. Netw.
Comput. Appl. 153, 102526 (2020)
8. Hemalatha, J., Roseline, S.A., Geetha, S., Kadry, S., Damaševičius, R.: An efficient
DenseNet-based deep learning model for malware detection. Entropy 23(3), 344
(2021)
9. Howard, A., et al.: Searching for MobileNetV3. In: Proceedings of the IEEE/CVF
International Conference on Computer Vision, pp. 1314–1324 (2019)
10. Iandola, F.N., Han, S., Moskewicz, M.W., Ashraf, K., Dally, W.J., Keutzer, K.:
SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model
size. arXiv preprint: arXiv:1602.07360 (2016)
11. Keys, R.: Cubic convolution interpolation for digital image processing. IEEE Trans.
Acoust. Speech Signal Process. 29(6), 1153–1160 (1981)
12. Kong, Z., Xue, J., Wang, Y., Zhang, Q., Han, W., Zhu, Y.: MalFSM: feature subset
selection method for malware family classification. Chin. J. Electron. 32(1), 26–38
(2023)
13. Kumar, S., et al.: MCFT-CNN: malware classification with fine-tune convolution
neural networks using traditional and transfer learning in internet of things. Futur.
Gener. Comput. Syst. 125, 334–351 (2021)
14. Ma, N., Zhang, X., Zheng, H.T., Sun, J.: ShuffleNet V2: practical guidelines for
efficient CNN architecture design. In: Proceedings of the European Conference on
Computer Vision (ECCV), pp. 116–131 (2018)
15. Mallik, A., Khetarpal, A., Kumar, S.: ConRec: malware classification using convo-
lutional recurrence. J. Comput. Virol. Hack. Tech. 18(4), 297–313 (2022)
16. Nataraj, L., Karthikeyan, S., Jacob, G., Manjunath, B.S.: Malware images: visu-
alization and automatic classification. In: Proceedings of the 8th International
Symposium on Visualization for Cyber Security, pp. 1–7 (2011)
17. Ni, S., Qian, Q., Zhang, R.: Malware identification using visualization images and
deep learning. Comput. Secur. 77, 871–885 (2018)
18. Pinhero, A., et al.: Malware detection employed by visualization and deep neural
network. Comput. Secur. 105, 102247 (2021)
19. Ronen, R., Radu, M., Feuerstein, C., Yom-Tov, E., Ahmadi, M.: Microsoft malware
classification challenge. arXiv preprint: arXiv:1802.10135 (2018)
20. Su, J., Vasconcellos, D.V., Prasad, S., Sgandurra, D., Feng, Y., Sakurai, K.:
Lightweight classification of IoT malware based on image recognition. In: 2018
IEEE 42nd Annual Computer Software and Applications Conference (COMPSAC),
vol. 2, pp. 664–669. IEEE (2018)
21. Szegedy, C., Ioffe, S., Vanhoucke, V., Alemi, A.: Inception-v4, inception-ResNet
and the impact of residual connections on learning. In: Proceedings of the AAAI
Conference on Artificial Intelligence, vol. 31 (2017)
22. Tan, M., Le, Q.: EfficientNet: rethinking model scaling for convolutional neural net-
works. In: International Conference on Machine Learning, pp. 6105–6114. PMLR
(2019)
23. Tekerek, A., Yapici, M.M.: A novel malware classification and augmentation model
based on convolutional neural network. Comput. Secur. 112, 102515 (2022)
24. Vasan, D., Alazab, M., Wassan, S., Naeem, H., Safaei, B., Zheng, Q.: IMCFN:
image-based malware classification using fine-tuned convolutional neural network
architecture. Comput. Netw. 171, 107138 (2020)
25. Vasan, D., Hammoudeh, M., Alazab, M.: Broad learning: a GPU-free image-based
malware classification. Appl. Soft Comput. 154, 111401 (2024)
26. Wadkar, M., Di Troia, F., Stamp, M.: Detecting malware evolution using support
vector machines. Expert Syst. Appl. 143, 113022 (2020)
27. Yuan, B., Wang, J., Liu, D., Guo, W., Wu, P., Bao, X.: Byte-level malware clas-
sification based on Markov images and deep learning. Comput. Secur. 92, 101740
(2020)
28. Zhang, Z., Qi, P., Wang, W.: Dynamic malware analysis with feature engineer-
ing and feature learning. In: Proceedings of the AAAI Conference on Artificial
Intelligence, vol. 34, pp. 1210–1217 (2020)
29. Zou, B., Cao, C., Tao, F., Wang, L.: IMCLNet: a lightweight deep neural network
for image-based malware classification. J. Inf. Secur. Appl. 70, 103313 (2022)
30. Zou, B., Cao, C., Wang, L., Fu, S., Qiao, T., Sun, J.: FACILE: a capsule net-
work with fewer capsules and richer hierarchical information for malware image
classification. Comput. Secur. 137, 103606 (2024)
Rule Learning-Based Target Prediction
for Efficient and Flexible Private
Information Retrieval
Abstract. In recent years, machine learning has been widely used in all
aspects of social production and life, including transportation, finance,
etc., and predicting unknown situations from existing information is a
common application of machine learning. Therefore, the
topic of prediction problems and machine learning models has always
been a long-standing research hotspot in machine learning. However,
common prediction models still face several problems. Prediction methods
based on the frequency and likelihood of data retrieval suffer from
insufficient accuracy and high error rates. The retrieval strategy a
typical model learns through training is relatively fixed and cannot
change in real time as the original data is updated, so its adaptability
is poor. At the same time, predictions are made without security
protection, which can easily cause privacy leakage during interaction,
that is, security is poor. To solve the above problems, we propose a rule
learning prediction model based on private information retrieval (PIR-
RL). The model uses rule learning to realize the prediction function;
rule learning helps extract rule features that improve prediction
accuracy. At the same time, inspired by SealPIR, this paper
proposes a Target Prediction Private Information Retrieval (TP-PIR)
to achieve privacy protection. The dynamic nature of rule learning and
its low computational cost offer advantages on practical problems, and
the lightweight TP-PIR achieves privacy protection while keeping
computation and communication overhead low. Theoretical analysis and
experimental results show that the model has good stability and
practical value, and can deliver prediction services under privacy
protection.
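For readers unfamiliar with the PIR primitive, the toy sketch below shows the classic two-server XOR-based construction. It is an illustration of the general idea only: TP-PIR itself follows SealPIR's single-server, homomorphic-encryption approach, and none of the names below come from the paper.

```python
import secrets

def pir_query(db_size: int, index: int):
    """Two-server XOR PIR: each server receives a set of indices that is
    individually uniformly random, so neither server alone learns `index`."""
    mask_a = {i for i in range(db_size) if secrets.randbits(1)}
    mask_b = mask_a ^ {index}      # symmetric difference flips only `index`
    return mask_a, mask_b

def pir_answer(db, mask):
    acc = 0
    for i in mask:
        acc ^= db[i]               # XOR of the requested records
    return acc

db = [7, 13, 42, 99, 5, 21, 8, 64]
mask_a, mask_b = pir_query(len(db), 2)
record = pir_answer(db, mask_a) ^ pir_answer(db, mask_b)
print(record)  # 42, i.e. db[2]: the shared indices cancel under XOR
```

Single-server schemes such as SealPIR achieve the same hiding property without a second server by encrypting the query vector homomorphically, at the cost of heavier computation.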
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 438–448, 2025.
https://doi.org/10.1007/978-981-96-4566-4_30
1 Introduction
With the rapid development of information technology, data has become one
of the most valuable resources in modern society. As a data-driven computing
method, machine learning has been continuously developed and popularized.
More and more machine learning methods are gradually applied to the relevant
aspects of social life. Using accurate machine learning methods and a wide range
of data, analysis and prediction functions can be achieved through training. To
let more users enjoy the convenience that machine learning brings, Machine
Learning as a Service (MLaaS) has been proposed: a new cloud-based service
model that provides users with a low-cost machine learning environment so that
they can use online prediction services at any time, making machine learning
more popular.
However, because machine learning methods must analyze large amounts of
data, this massive data often contains privacy-sensitive records. Such private
data may expose key and sensitive information, which is extremely harmful
to the security of user privacy. For example, in recent years, security incidents
such as data leakage and network attacks have occurred frequently, which has
become a major hidden danger worldwide. On June 3rd, the information of
students and faculty of the Chinese University of Hong Kong was stolen in
large quantities. On July 12th, AT&T suffered a serious data leakage incident,
resulting in the theft of nearly 110 million user calls and SMS records. The
disclosure of such information may cause serious impact and threat to individuals
and society. Similar incidents such as large-scale hacker attacks, data theft, and
user privacy leaks are still common, which seriously threaten personal privacy,
corporate interests, and national security. This also makes the importance of
information security and data protection increasingly prominent, reflecting the
importance of privacy computing in machine learning. Therefore, we need to
propose new methods to deal with privacy leakage to strengthen the protection of
data privacy, especially for data privacy security protection for machine learning.
1.2 Contribution
With growing attention to data privacy and security, especially in
sensitive areas such as medical care, finance, and government, using
decision trees for data prediction faces certain challenges. Although the
decision tree model is widely used in classification and regression
problems because it is easy to understand, simple to construct, and
performs well, it still falls short in data privacy protection and
computational efficiency. Specifically, the main problems of the decision
tree model are: first, a decision tree must be trained and constructed on
the entire data set, so whenever the data changes the model must be
retrained, which increases the computational cost and limits the model's
dynamism. Second, the traditional decision tree training process carries a
high risk of leaking sensitive data; especially in an open environment, the
model's training and prediction data may expose the private information of
individuals or organizations, posing a potential threat to data security.
To address the above problems, we design and propose an innovative
PIR-based rule learning prediction model. Building on existing rule
learning, the model not only strengthens the protection of data privacy,
but also improves
encoding of its query attributes and indexes the query, returning the
corresponding prediction value after the client sends the query information.
For the Client. A client requesting the machine learning prediction
service first generates attribute values from the given query target and
combines them into a query attribute vector. The client then uses the
secure comparison protocol to compare this vector against the set of
attribute thresholds preset on the server side, keeping the comparison
process private. Based on the comparison results, the client encodes the
query attribute vector into a query index. Next, the client processes the
index according to the PIR scheme, constructs the PIR query request, and
sends it to the server. Finally, the client receives the PIR query result
returned by the server and decrypts it to obtain the item corresponding to
the query attribute vector, that is, the required prediction value. In
short, the client encodes its query attributes into a query index through
vector coding and generates a PIR query request based on that index; after
the server returns the result, the client decrypts it to obtain the final
prediction value. Throughout the process, the user's query features and
prediction results remain unknown to the server. The black-box-like
environment constructed by PIR preserves data privacy while keeping the
data computable and comparable, strengthening privacy and security while
realizing the basic functions.
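The client-side step of turning query attributes into a query index can be sketched as follows. This is a minimal illustration, assuming one comparison bit per attribute packed into an integer index; in the real scheme the comparisons run under the secure comparison protocol rather than in the clear, and the function name is hypothetical.

```python
def encode_query_index(attributes, thresholds):
    """Encode a query attribute vector into a database index by
    comparing each attribute against its preset server threshold.
    NOTE: the paper's scheme performs these comparisons under a
    secure comparison protocol; plaintext comparison here is
    only for illustration."""
    bits = [1 if a >= t else 0 for a, t in zip(attributes, thresholds)]
    index = 0
    for b in bits:  # pack comparison bits into a single integer index
        index = (index << 1) | b
    return index
```

A client with attribute vector [5, 2] and server thresholds [3, 4] would obtain comparison bits (1, 0) and hence query index 2.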
We assume that the server follows the provided scheme and security
protocol strictly, but that the service provider is curious and wants to
learn the user's specific queries: it hopes to analyze the user's input
features through the secure comparison interface and thereby indirectly
obtain the user's model information and prediction data. At the same time,
we assume a malicious data thief outside the overall service.
Rather than obtaining prediction results through the normal query steps,
the thief tries to retrieve the original data directly, attempting to
bypass the authentication of the security protocol and access the central
dataset. Prediction efficiency also deserves attention: the model must keep
computational and communication overhead as small as possible. The
computational overhead depends on the mode of prediction, while the
communication efficiency depends on the traffic between the server and the
client and on the number of communication rounds.
Given the above privacy and security risks, the proposed scheme must
protect the server-side rule learning model and the user-side data privacy,
while remaining highly efficient in producing prediction results. The goal
of this paper is to design a safe, reliable, and accurate prediction scheme
for MLaaS. In general, the specific objectives our scheme must achieve are:
Ensure the Security of Sensitive Data. The rule learning model is obtained
by training on and integrating the data in the server; the final rules
derive from an important part of the original data and are core information
belonging to the server. It is therefore necessary to keep the rule
learning model confidential from ordinary users: the details of specific
rules must not be exposed through the secure comparison interface or the
private retrieval interface. That is, the user must not be able to extract
the server's internal content from its interfaces, and the server must not
be able to infer the user's query-related information from the interfaces
and security protocols, so that both parties remain confidential.
Ensure Prediction Accuracy and Efficiency. We also need accurate and
efficient prediction. As the fundamental purpose of the whole system model,
prediction must first meet a basic accuracy requirement. At the same time,
for the whole model to be practical, computational and communication
overhead must be kept as small as possible: the computational overhead
depends on the mode of prediction, and the communication efficiency depends
on the amount of traffic and the number of communication rounds between the
server and the client.
We therefore solve the above problems by designing a concrete method model:
focusing on prediction accuracy, model efficiency and dynamism, and privacy
and confidentiality, we realize a new prediction model with more complete
functions and clear advantages.
444 W. Tu et al.
The flow of the PIR-RL scheme designed in this paper is shown in Fig. 2.
The server S is mainly responsible for training on its data to generate
rule sets and for providing the private information retrieval service. The
client C converts its original query into a PIR query, obtains a ciphertext
result through the PIR service, and decrypts it to recover the
corresponding plaintext. The key advantage of this scheme is that it
ensures privacy during the query process while still using rule learning to
support real-time updates of data and rules and accurate prediction. The
specific steps are as follows:
1. The server trains on its own data to obtain rules, then generates the
   corresponding feature attributes and rule sets.
2. The server uses the rule set to build a PIR database in which rules and
   attributes are associated with each other. Once construction is
   complete, the query service can be offered externally.
3. The client forms a query vector from its query attributes, obtains the
   secure comparison result through the security protocol, and thereby
   generates the query index.
4. The query index is encrypted as the retrieval target and passed to the
   server to initiate the PIR query request.
5. The server receives the query and retrieves from the database DB
   according to the SealPIR procedure.
6. The server retrieves the target result for the query request, completing
   one PIR retrieval task, and sends the result (ciphertext) back to the
   client C.
7. The client receives the server's query result, decrypts it with the
   private key, and the resulting plaintext is the user's expected result.
The preparation stage readies the system for subsequent queries and
responses, mainly by pre-processing the original data already stored on the
server. On the server side, the original data must satisfy certain
association conditions; based on these associations, rule learning produces
rules in forms such as IF-THEN or ATTRIBUTE-AND. Depending on the type of
prediction required, a key-value or interrelated representation is
selected. The server integrates these rules into a data set, which can be
understood as a rule database, and then applies PIR-related processing to
turn it into a PIR database. This keeps the original rule data invisible
and thus private: subsequent access and comparison only call and evaluate
the PIR database interface. On the client side, the query vector value [X]
is encrypted under the key K of the security protocol, adding a layer of
protection to later queries. At the same time, when the PIR database is
constructed, the public key pk and private key sk used for the PIR service
are prepared: pk is public, while sk is kept by the user and later used to
decrypt the retrieved information.
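The preparation stage above, which turns learned rules into an index-addressable database, can be sketched as follows. The representation is an assumption for illustration (one comparison bit per attribute, with each rule's prediction stored at its encoded index); the actual PIR database additionally encodes entries as polynomials for SealPIR, and the function name is hypothetical.

```python
def build_rule_database(rules, num_attributes):
    """Build an index-addressable database from learned rules.
    Each rule is (condition_bits, prediction): a tuple of
    attribute-comparison outcomes (one bit per attribute, an
    IF-THEN condition) and the prediction it yields. The bits
    are packed into an integer that addresses the database slot."""
    db = [None] * (1 << num_attributes)  # one slot per bit pattern
    for condition_bits, prediction in rules:
        index = 0
        for b in condition_bits:
            index = (index << 1) | b
        db[index] = prediction
    return db

# Two toy IF-THEN rules over two attributes:
rules = [((1, 0), "high risk"), ((0, 1), "low risk")]
db = build_rule_database(rules, 2)
```

A query whose comparison bits encode to index 2 would then retrieve the first rule's prediction from `db[2]`.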
In the query phase, the client initiates a prediction request to the server
based on the feature values of its query. The query attributes are
processed into a query vector, which is encoded to produce the query index
Q. The PIR query-generation interface is then called to produce the
corresponding PIR query request. Since PIR stores data as polynomials, we
first fix a polynomial coefficient count N as the length, and then expand
the query vector q into a PIR query vector p of dimension l according to
the data volume, that is, p = PIR.Expand(q), where p = {p_0, p_1, ...,
p_{l-1}}. Each element p_i of p is encrypted and corresponds one-to-one
with the element at the same position in the PIR database DB. After the
expansion, the server performs the PIR retrieval: it homomorphically
multiplies each ciphertext element p_i of the query vector p with the i-th
data item of the database, homomorphically adds and combines all the
products, and finally obtains the ciphertext retrieval result a, that is,
a = PIR.Answer(p, DB). Finally, the ciphertext a is returned to the client,
completing the query process.
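The expand-and-answer arithmetic above can be sketched in plaintext. In SealPIR every p_i is a ciphertext and the products and sums are homomorphic operations; here plain integers stand in for ciphertexts, so the answer reduces to a dot product that selects the database item at the query index. The function names mirror the paper's notation but are otherwise illustrative.

```python
def pir_expand(index, length):
    """Expand a query index into a one-hot selection vector.
    In SealPIR each element would be encrypted, hiding which
    position holds the 1; here it is plaintext for clarity."""
    return [1 if i == index else 0 for i in range(length)]

def pir_answer(p, db):
    """Server side: multiply each p_i with the i-th database item
    and sum the products (homomorphic multiply-and-add in the real
    scheme). With a one-hot p this selects exactly db[index]."""
    return sum(pi * di for pi, di in zip(p, db))

db = [10, 20, 30, 40]         # toy PIR database
p = pir_expand(2, len(db))    # query vector for index 2
a = pir_answer(p, db)         # retrieval result: db[2]
```

Because p is one-hot, every product except p_2 * db[2] vanishes, which is exactly why the server can compute the answer without learning which item was selected once p is encrypted.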
In the response phase, the server processes the PIR query through the
SealPIR scheme and returns the result to the client, which decrypts it to
obtain the prediction. At this point the retrieval result a is still in
ciphertext form: the server cannot learn its plaintext content and simply
returns a to the client as prescribed. Upon receiving a, the client
decrypts it with the private key sk generated and kept earlier, obtaining
the final plaintext prediction value A, which corresponds to the attribute
vector of the client's query and represents the result of online prediction
through the rule learning model. Throughout the process, the server handles
only encrypted query vectors and ciphertext data, so it can learn neither
the user's query content nor the prediction result; since all operations
are performed on ciphertexts, the server obtains no plaintext information
about the client's query. Conversely, the client can obtain only the
predicted value for its own query and cannot obtain additional server data,
so the server's privacy is not leaked either.
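The client-side decryption step can be illustrated with a toy round trip. The additive one-time mask below is only a stand-in for the lattice-based encryption SealPIR actually uses, and all names here are hypothetical; the point is simply that only the holder of sk can turn the returned ciphertext a into the plaintext prediction.

```python
import secrets

MOD = 2 ** 32  # toy plaintext modulus

def toy_encrypt(value, sk):
    # Stand-in for SealPIR's lattice encryption: an additive
    # one-time mask under the client's secret key sk.
    return (value + sk) % MOD

def toy_decrypt(ciphertext, sk):
    # Client-side decryption with the saved private key sk
    # recovers the plaintext prediction value.
    return (ciphertext - sk) % MOD

sk = secrets.randbelow(MOD)        # kept by the client only
a = toy_encrypt(1234, sk)          # ciphertext result returned by server
prediction = toy_decrypt(a, sk)    # plaintext prediction value
```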
We carried out experiments following the designed scheme, supplied a
training data set with the required correlations, and exercised the SealPIR
process. The prediction completes successfully under ideal conditions with
high accuracy, and the PIR process also runs normally. Some PIR parameters
are listed in Table 1. The results show that PIR executes successfully with
normal parameter values, which verifies the correctness of our model.
4 Conclusion
In this paper, a new PIR-RL model is proposed for common prediction
problems. Addressing the confidentiality problems of conventional
prediction models and the deficiencies of common schemes such as encrypted
decision trees, the rule learning prediction model PIR-RL, based on private
information retrieval, first introduces the flexibility and dynamism of
rule learning to remedy the defects of common decision tree prediction. At
the same time, the combination with SealPIR ensures the security of the
overall model, effectively realizing both prediction and privacy protection
and giving the scheme strong practicability.
Author Index
B
Ban, Xinbo 148

C
Cai, Zhenghua 197
Camtepe, Seyit 148
Che, Xun 271
Chen, Chao 148
Chen, Dixuan 246
Chen, Fan 246, 332
Chen, Xianyi 231
Chen, Yadang 271
Chen, Zhiguo 423
Cheng, Yu 246, 332
Cui, Yuxin 118

D
Dai, Wenzhang 347
Deng, Xiaozhi 81
Dong, Zheng 438
Duan, Ao 70

G
Gao, Qiang 1
Gao, Shengyang 362
Gao, Zhaojian 271
Gong, Huimin 184

H
Hong, Shijia 43, 56
Huang, Jinming 362

J
Ji, Yinglin 90
Jiang, Guixin 90
Jin, Zilong 362, 377, 392

K
Ke, Lishan 16
Kong, Wei 258

L
Li, Bo 16
Li, Jingang 133
Li, Kun 332
Li, Lin 148
Li, Nan 271
Li, Yitong 299
Li, Yuxian 1
Liang, Dongyang 332
Lin, Jiatong 246, 332
Lin, Zhiqiang 16
Liu, Hongwei 70
Liu, Shigang 148
Liu, Shuyan 377
Liu, Zhaoyi 133

M
Ma, Liang 392
Ma, Qingru 118

N
Niyato, Dusit 1

P
Pan, Yiting 316

Q
Qiao, Lei 184, 197
Qiu, Dongyan 70

S
Shen, Mingdi 43, 56
Shen, Zongliang 184
Sheng, Xinlei 284
Su, Jian 408
Sun, Jianfei 1
Sun, Le 347
© The Editor(s) (if applicable) and The Author(s), under exclusive license
to Springer Nature Singapore Pte Ltd. 2025
Y. Xiang and J. Shen (Eds.): ML4CS 2024, LNCS 15566, pp. 449–450, 2025.
https://doi.org/10.1007/978-981-96-4566-4