⚡️ A fast, plug-and-play, model-agnostic quantization technique delivering state-of-the-art performance for Large Language Models without sacrificing accuracy.
💡 Want to run a large model on your GPU but don’t have enough memory? With SINQ, you can deploy models that would otherwise be too big, drastically reducing memory usage while preserving LLM quality.
⏱️ SINQ quantizes Qwen3-14B in just ~21 sec and DeepSeekV2.5-236B in ~5 min
🆕 [14/11/2025] PR to integrate SINQ into HF transformers! 🤗
An issue and a PR have been opened to integrate SINQ into HF transformers. You can follow the discussion here!
Note: We’re also actively working to add support for popular frameworks such as vLLM, SGLang, and llama.cpp to enable fast SINQ-ference.
In the meantime, you can ⭐️ star and watch the repo to stay updated!
SINQ (Sinkhorn-Normalized Quantization) is a novel, fast, and high-quality quantization method designed to make any Large Language Model smaller while keeping its accuracy almost intact.
- 1. How does SINQ work?
- 2. Why should I use SINQ?
- 3. Quantize (and save) any LLM with SINQ
- 4. Run pre-quantized SINQ models from Hugging Face
- 5. How to reproduce paper results
- 6. Ongoing updates on new features and integrations
- 7. How to Cite This Work
- 8. Related Repositories
| Feature | SINQ | HQQ | A-SINQ | AWQ |
|---|---|---|---|---|
| 🎯 Calibration | Calibration-free | Calibration-free | Calibrated | Calibrated |
| 🧮 Quantization Type | Symmetric & Asymmetric | Asymmetric only | Symmetric & Asymmetric | Symmetric & Asymmetric |
| 📦 NF4 Support | Yes | No | Yes | No |
| ⚡ Quantization Speed | ~2× Faster than HQQ | Slower | ~4× Faster than AWQ | Slower |
| 📈 Model Quality | Higher | Lower | Higher | Lower |
📄 Want to know more? Read our paper on arXiv!
A quick explanation of SINQ’s core idea:
Conventional quantization uses a single set of scale factors along one dimension of each weight matrix, which makes models vulnerable to outliers: large weights that distort the scaling and cause significant errors.
SINQ solves this by introducing dual scaling: separate scale factors for rows and columns. This flexibility redistributes outlier influence and keeps quantization errors smaller and more balanced.
With standard single-scale quantization, errors tend to cluster around outliers.
With SINQ, they become spread out and less severe, preserving model accuracy even at 3-bit precision. This improvement is driven by SINQ’s Sinkhorn-normalized optimization, which iteratively rescales rows and columns to balance their variance, a process inspired by Sinkhorn matrix normalization. By reducing the overall matrix imbalance (refer to the paper for more info), weights become inherently easier to quantize, leading to more stable behavior across layers and consistently higher accuracy even at very low bit-widths.
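To make the dual-scaling idea concrete, here is a minimal NumPy sketch. It is not the SINQ implementation: the helper names (`balance_rows_cols`, `imbalance`, `quantize_rtn`), the plain standard-deviation normalization, the iteration count, and the single-scale round-to-nearest quantizer are all assumptions chosen for illustration. It only shows how alternating row and column rescaling yields a row scale, a column scale, and a better-balanced matrix to quantize.

```python
import numpy as np


def balance_rows_cols(W, n_iter=16, eps=1e-8):
    """Alternately divide rows and columns by their standard deviation.

    Returns (W_bal, s_row, s_col) such that
    W == s_row[:, None] * W_bal * s_col[None, :] up to floating-point error.
    """
    W_bal = np.asarray(W, dtype=np.float64).copy()
    s_row = np.ones(W_bal.shape[0])
    s_col = np.ones(W_bal.shape[1])
    for _ in range(n_iter):
        r = W_bal.std(axis=1) + eps  # per-row spread
        W_bal /= r[:, None]
        s_row *= r
        c = W_bal.std(axis=0) + eps  # per-column spread
        W_bal /= c[None, :]
        s_col *= c
    return W_bal, s_row, s_col


def imbalance(W):
    """Ratio between the largest and smallest row/column standard deviation."""
    stds = np.concatenate([W.std(axis=0), W.std(axis=1)])
    return stds.max() / stds.min()


def quantize_rtn(X, nbits=4):
    """Uniform round-to-nearest quantization with one scale and shift for X."""
    qmax = 2**nbits - 1
    lo, hi = X.min(), X.max()
    scale = max((hi - lo) / qmax, 1e-12)
    q = np.clip(np.round((X - lo) / scale), 0, qmax)
    return q * scale + lo


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64))
W[3, :] *= 20.0  # an outlier row
W[:, 5] *= 20.0  # an outlier column

# Quantize the balanced matrix, then re-apply both scale vectors.
W_bal, s_row, s_col = balance_rows_cols(W)
W_hat = s_row[:, None] * quantize_rtn(W_bal) * s_col[None, :]

print(f"imbalance before: {imbalance(W):.2f}")
print(f"imbalance after:  {imbalance(W_bal):.2f}")
print(f"mean abs reconstruction error: {np.abs(W - W_hat).mean():.4f}")
```

The two scale vectors are stored alongside the quantized matrix, so dequantization is just the element-wise product in the `W_hat` line. Refer to the paper for the actual optimization and its imbalance measure.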
A quick explanation of why you should use SINQ to quantize your LLM:
- Higher LLM quality and ~2× faster quantization than HQQ
- >31× faster quantization process and comparable or better LLM quality compared to AWQ / GPTQ
- Model-agnostic: works without knowing the specific LLM architecture, unlike QuaRot
- Training-free: it does not require end-to-end training, unlike SpinQuant or KurTail
- Additionally, A-SINQ (calibrated) further beats AWQ, GPTQ, and Hadamard+GPTQ on quality while achieving >4× faster quantization time.
Example
- ⏱️ SINQ quantizes Qwen3-14B in just ~21 sec and DeepSeekV2.5-236B in ~5 min on a single GPU
- 💾 Enables you to run DeepSeekV2.5-236B on a single GPU with ~110 GB of memory (vs ~472 GB) while losing < 1 ppl on WikiText2 and C4
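As a rough sanity check on these memory figures, the sketch below estimates weight storage from parameter count and bits per weight. The helpers `weight_memory_gb` and `effective_bits` are hypothetical, and the 32 auxiliary bits per group (one 16-bit scale plus one 16-bit shift) are an assumption for illustration, not SINQ's exact storage layout. The 16-bit baseline reproduces the ~472 GB quoted above; the quantized footprint depends on bit-width, group size, and which layers stay in higher precision, so the sketch does not try to match the ~110 GB figure exactly.

```python
def weight_memory_gb(n_params, bits_per_weight):
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9


def effective_bits(nbits, group_size, aux_bits_per_group=32):
    """Bits per weight including per-group auxiliary parameters.

    aux_bits_per_group=32 assumes one 16-bit scale and one 16-bit shift per
    group. This is an illustrative assumption, not SINQ's storage format.
    """
    return nbits + aux_bits_per_group / group_size


n_params = 236e9  # DeepSeekV2.5-236B
print(f"16-bit baseline: {weight_memory_gb(n_params, 16):.0f} GB")
for nbits in (3, 4):
    bpw = effective_bits(nbits, group_size=64)
    gb = weight_memory_gb(n_params, bpw)
    print(f"{nbits}-bit, group size 64: ~{gb:.0f} GB ({bpw} bits/weight)")
```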
First, install the dependencies and set up the package:
# 1. Clone the repository
git clone https://github.com/huawei-csl/SINQ.git
cd SINQ
# 2. Install dependencies
pip install -r req.txt
# 3. Install SINQ
pip install .
Quantizing any 🤗 Hugging Face model with SINQ is simple and takes only a few lines of code:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sinq.patch_model import AutoSINQHFModel
from sinq.sinqlinear import BaseQuantizeConfig
model_name = "Qwen/Qwen3-1.7B"
device = "cuda:0"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
quant_cfg = BaseQuantizeConfig(
    nbits=4,            # quantization bit-width
    group_size=64,      # group size
    tiling_mode="1D",   # tiling strategy
    method="sinq",      # quantization method ("asinq" for the calibrated version)
)
qmodel = AutoSINQHFModel.quantize_model(
    model,
    tokenizer=tokenizer,
    quant_config=quant_cfg,
    compute_dtype=torch.bfloat16,
    device=device,
)
✅ That’s it. Your model is now quantized with SINQ and ready for inference or saving.
You can further customize the quantization process to balance accuracy and memory for your needs.
Here’s a summary of the main arguments you can tune:
| Flag | Description | Options | Default |
|---|---|---|---|
| --nbits | Bit-width for weight quantization | 2, 3, 4, 5, 6, 8 | 4 |
| --tiling_mode | Weight matrix tiling strategy | 1D, 2D | 1D |
| --group_size | Weights per quantization group | 64, 128 | 64 |
| --method | Quantization method | sinq, asinq | sinq |
💡 Tip: For most cases, the defaults (--nbits 4 --tiling_mode 1D --group_size 64 --method sinq) provide an excellent trade-off between compression and accuracy.
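To build some intuition for how the bit-width and group size interact, here is a small NumPy sketch of plain group-wise round-to-nearest quantization. It is a generic illustration, not SINQ itself: `quantize_groups` is a hypothetical helper, and it uses a simple min/max scale and shift per group on random data.

```python
import numpy as np


def quantize_groups(w, nbits=4, group_size=64):
    """Round-to-nearest quantization with one scale and shift per group.

    `w` is a 1-D array whose length is a multiple of `group_size`.
    Returns the dequantized weights and the per-group step sizes.
    """
    qmax = 2**nbits - 1
    groups = w.reshape(-1, group_size)
    lo = groups.min(axis=1, keepdims=True)
    hi = groups.max(axis=1, keepdims=True)
    step = np.maximum((hi - lo) / qmax, 1e-12)
    q = np.clip(np.round((groups - lo) / step), 0, qmax)
    return (q * step + lo).reshape(w.shape), step.ravel()


rng = np.random.default_rng(0)
w = rng.standard_normal(4096)

for nbits in (3, 4, 8):
    for group_size in (64, 128):
        w_hat, step = quantize_groups(w, nbits, group_size)
        err = np.abs(w - w_hat).mean()
        print(f"nbits={nbits} group_size={group_size} mean abs error={err:.4f}")
```

Within each group the error is bounded by half that group's step size, so a smaller group (narrower value range) or a higher bit-width (more levels) tightens the bound, at the cost of more scale parameters or more bits per weight.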
If you want to reuse a quantized model later, save it to disk in HF-style sharded safetensors and reload without needing base FP weights.
Requires: pip install safetensors and pip install gemlite==0.5.1.post1
# --- Save to a folder (sharded safetensors) ---
from sinq.patch_model import AutoSINQHFModel
import torch
save_dir = "qwen3-1.7b-sinq-4bit" # any path
# 'qmodel' must already be SINQ-quantized (e.g., via AutoSINQHFModel.quantize_model)
AutoSINQHFModel.save_quantized_safetensors(
qmodel,
tokenizer,
save_dir,
verbose=True,
max_shard_size="4GB", # typical HF shard size (use "8GB" if you prefer)
)
# --- Reload later ---
from transformers import AutoTokenizer
from sinq.patch_model import AutoSINQHFModel
import torch
tokenizer = AutoTokenizer.from_pretrained(save_dir)
device = "cuda:0"
qmodel = AutoSINQHFModel.from_quantized_safetensors(
save_dir,
device=device,
compute_dtype=torch.bfloat16,
)
✅ Your model is now loaded and ready for inference!
You can optionally compile the model’s forward pass using torch.compile, which can provide a significant speed boost (especially after the first run):
# Warm up to initialize CUDA graphs
_ = qmodel.forward(torch.tensor([[0]], device=device))
# Compile for faster inference
qmodel.forward = torch.compile(
qmodel.forward,
dynamic=True,
fullgraph=False,
backend="inductor",
mode="reduce-overhead",
)
⏱️ The first run will take longer because PyTorch compiles optimized kernels, but subsequent runs will be much faster.
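If you want to see the warm-up effect for yourself, a small timing helper is enough. The sketch below is generic and uses only the standard library: `time_calls` and `dummy_workload` are hypothetical names, and the dummy workload is a stand-in so the snippet runs without a GPU. With a real model you would pass a callable that runs the forward pass and calls `torch.cuda.synchronize()` before returning, since CUDA kernels launch asynchronously.

```python
import time


def time_calls(fn, n_calls=5):
    """Call `fn` n_calls times and return the per-call latencies in seconds."""
    latencies = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)
    return latencies


# Stand-in workload so the sketch runs anywhere. With a quantized model,
# replace it with a function that runs the forward pass and synchronizes.
def dummy_workload():
    return sum(i * i for i in range(10_000))


latencies = time_calls(dummy_workload)
print(f"first call:  {latencies[0] * 1e3:.3f} ms")
print(f"later calls: {min(latencies[1:]) * 1e3:.3f} ms (best of {len(latencies) - 1})")
```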
Alternative: save & reload as a single .pt file
# --- Save to a folder (.pt) ---
from sinq.patch_model import AutoSINQHFModel
save_dir = "qwen3-1.7b-sinq-4bit" # any path
AutoSINQHFModel.save_quantized(qmodel, tokenizer, save_dir, verbose=True)  # creates qmodel.pt
# --- Reload later from .pt ---
from transformers import AutoTokenizer
from sinq.patch_model import AutoSINQHFModel
import torch
tokenizer = AutoTokenizer.from_pretrained(save_dir)
device = "cuda:0"
qmodel = AutoSINQHFModel.from_quantized(
save_dir,
device=device,
compute_dtype=torch.bfloat16,
)
Compatible with the lm-eval evaluation framework
Below is a minimal example showing how to evaluate a SINQ-quantized model on a benchmark dataset:
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
# Wrap the already quantized model and tokenizer with HFLM
lm = HFLM(pretrained=qmodel, tokenizer=tokenizer, device=device)
# Evaluate (many tasks available on lm-eval such as MMLU and HellaSwag)
results = evaluator.simple_evaluate(
model=lm,
tasks=["lambada_openai"], # small and fast benchmark
device=device
)
We’re publishing a growing collection of pre-quantized SINQ models on 🤗 Hugging Face: huawei-csl / SINQ collection!
import torch
from transformers import AutoTokenizer
from sinq.patch_model import AutoSINQHFModel
model_name = "huawei-csl/<model_name>" # pick a model from the collection
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda:0"
qmodel = AutoSINQHFModel.from_quantized_safetensors(
model_name,
device=device,
compute_dtype=torch.bfloat16,
)
For additional speed (on top of what gemlite already provides), do a quick warm-up and JIT-compile the forward pass:
# Warm-up to build shapes
_ = qmodel.forward(torch.tensor([[0]], device=device))
# Compile the forward pass
qmodel.forward = torch.compile(
qmodel.forward,
dynamic=True,
fullgraph=False,
backend="inductor",
mode="reduce-overhead",
)
prompt = "Explain neural network quantization in one sentence."
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.inference_mode():
    out_ids = qmodel.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))
⏱️ The first run will be slower due to kernel/graph compilation. Subsequent runs are much faster!
Commands to reproduce the paper results:
First, install the dependencies and set up the package:
# 1. Clone the repository
git clone https://github.com/huawei-csl/SINQ.git
cd SINQ
# 2. Install dependencies
pip install -r req.txt
# 3. Install SINQ
pip install .
Then run the following commands to quantize Qwen3-1.7B out of the box:
cd tests
python quant_model_eval.py
By default, this will run SINQ with the following settings:
- ✅ 4-bit weight quantization
- ✅ Dual-scale + shift parameterization
- ✅ 1D tiling
- ✅ Group size = 64
Reproduce the core SINQ results (as shown in Table 1 of the paper):
python quant_model_eval.py --model_name Qwen/Qwen3-1.7B
This uses INT4 uniform quantization without calibration, the main benchmark setting of the paper.
Try SINQ with non-uniform quantization (e.g., NF4):
python quant_model_eval.py --method sinq_nf4 --model_name Qwen/Qwen3-1.7B
Combine SINQ with activation-aware calibration (AWQ) for higher accuracy:
python quant_model_eval.py --method asinq --model_name Qwen/Qwen3-1.7B
Customize experiments with the following command-line arguments:
| Flag | Description | Options | Default |
|---|---|---|---|
| --nbits | Number of bits used to quantize model weights | 2, 3, 4, 8 | 4 |
| --tiling_mode | Strategy for tiling weight matrices during quantization | 1D, 2D | 1D |
| --group_size | Number of weights processed together as a quantization group | 64, 128 | 64 |
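To run a sweep over these arguments, a short script can generate the command lines. The sketch below only builds and prints the commands, using the flag names and option values from the table above. `build_commands` is a hypothetical helper, and it assumes every combination is accepted by quant_model_eval.py, which may not hold for all settings.

```python
import itertools
import shlex

# Option values taken from the table above.
NBITS = (2, 3, 4, 8)
TILING_MODES = ("1D", "2D")
GROUP_SIZES = (64, 128)


def build_commands(model_name="Qwen/Qwen3-1.7B", script="quant_model_eval.py"):
    """Return one command (as an argument list) per option combination."""
    commands = []
    for nbits, tiling, group in itertools.product(NBITS, TILING_MODES, GROUP_SIZES):
        commands.append([
            "python", script,
            "--model_name", model_name,
            "--nbits", str(nbits),
            "--tiling_mode", tiling,
            "--group_size", str(group),
        ])
    return commands


commands = build_commands()
for cmd in commands:
    print(shlex.join(cmd))
# To launch the runs from the tests directory, each argument list could be
# passed to subprocess.run(cmd, check=True).
```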
📝 Note: All results reported in the paper were obtained using the evaluation framework from Efficient-ML/Qwen3-Quantization rather than lm-eval.
We are actively expanding SINQ with new features and integrations. Stay tuned here for the latest updates:
- [26/09/2025] - SINQ paper released on arXiv
- [30/09/2025] - SINQ GitHub repository made public
- [02/10/2025] - SINQ paper featured on 🤗 Hugging Face Papers
- [17/10/2025] - First pre-quantized SINQ models available on 🤗 Hugging Face Hub! 🆕
- [23/10/2025] - Faster inference with gemlite backend (4-bit 1D tiling)
- 🔜 Coming soon - 🤗 Integration with Hugging Face Transformers
- 🔜 Coming soon - Support for Conv2D layers and timm models for computer vision tasks
- 🔜 Work in progress - Support for mixed-precision quantization (combine multiple bitwidths for optimal accuracy-efficiency balance)
- 🔜 Work in progress - We’re actively working to provide support for popular frameworks such as vLLM, SGLang, and llama.cpp.
If you find SINQ useful in your research or applications, please cite our paper:
@misc{muller2025sinq,
title={SINQ: Sinkhorn-Normalized Quantization for Calibration-Free Low-Precision LLM Weights},
author={Lorenz K. Muller and Philippe Bich and Jiawei Zhuang and Ahmet Celik and Luca Benfenati and Lukas Cavigelli},
year={2025},
eprint={2509.22944},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={http://arxiv.org/abs/2509.22944}
}
This project builds upon and extends the excellent work from the following open-source projects:
- Qwen3-Quantization - Base implementation and evaluation scripts for Qwen3 quantization.
- HQQ - High-quality calibration-free quantization baseline.
📜 You can find their original licenses in the corresponding LICENSE files in these repositories.


