Improved quantize script #222

Merged 9 commits on Mar 19, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -147,7 +147,7 @@ python3 -m pip install torch numpy sentencepiece
python3 convert-pth-to-ggml.py models/7B/ 1

# quantize the model to 4-bits
-./quantize.sh 7B
+python3 quantize.py 7B

# run the inference
./main -m ./models/7B/ggml-model-q4_0.bin -n 128
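
# the new script also accepts several models at once and custom paths,
# using the flags defined in quantize.py below (the models path here is
# illustrative):
python3 quantize.py 7B 13B --remove-16      # also delete the f16 models afterwards
python3 quantize.py -m /path/to/models 7B   # models stored in a custom directory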
126 changes: 126 additions & 0 deletions quantize.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

"""Script to execute the "quantize" script on a given set of models."""

import subprocess
import argparse
import glob
import sys
import os


def main():
    """Select the quantize binary for the current platform, parse the
    command-line arguments, and run the binary on the requested models.
    """

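    # sys.platform is "linux" ("linux2" on older Pythons) on Linux, "darwin"
    # on macOS, and "win32" or "cygwin" on Windows, hence the substring checks.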
if "linux" in sys.platform or "darwin" in sys.platform:
quantize_script_binary = "quantize"

elif "win32" in sys.platform or "cygwin" in sys.platform:
quantize_script_binary = "quantize.exe"

else:
print("WARNING: Unknown platform. Assuming a UNIX-like OS.\n")
quantize_script_binary = "quantize"

parser = argparse.ArgumentParser(
prog='python3 quantize.py',
description='This script quantizes the given models by applying the '
f'"{quantize_script_binary}" script on them.'
)
parser.add_argument(
'models', nargs='+', choices=('7B', '13B', '30B', '65B'),
help='The models to quantize.'
)
parser.add_argument(
'-r', '--remove-16', action='store_true', dest='remove_f16',
help='Remove the f16 model after quantizing it.'
)
parser.add_argument(
'-m', '--models-path', dest='models_path',
default=os.path.join(os.getcwd(), "models"),
help='Specify the directory where the models are located.'
)
parser.add_argument(
'-q', '--quantize-script-path', dest='quantize_script_path',
default=os.path.join(os.getcwd(), quantize_script_binary),
help='Specify the path to the "quantize" script.'
)

# TODO: Revise this code
# parser.add_argument(
    #     '-t', '--threads', dest='threads', type=int,
# default=os.cpu_count(),
# help='Specify the number of threads to use to quantize many models at '
# 'once. Defaults to os.cpu_count().'
# )

args = parser.parse_args()

if not os.path.isfile(args.quantize_script_path):
print(
            f'The "{quantize_script_binary}" script was not found in the '
            "current location.\nIf it is located elsewhere, set its path "
            "with the --quantize-script-path argument."
)
sys.exit(1)

for model in args.models:
        # Each model may be split into several parts
        # (ggml-model-f16.bin, ggml-model-f16.bin.0, ggml-model-f16.bin.1, ...)
f16_model_path_base = os.path.join(
args.models_path, model, "ggml-model-f16.bin"
)

        # glob.glob already returns the full paths of the matching parts,
        # so no extra os.path.join is needed here.
        f16_model_parts_paths = glob.glob(f"{f16_model_path_base}*")

for f16_model_part_path in f16_model_parts_paths:
if not os.path.isfile(f16_model_part_path):
                print(
                    f"The f16 model {os.path.basename(f16_model_part_path)} "
                    f"was not found in {args.models_path}{os.path.sep}{model}"
                    ". If it is located elsewhere, set the models directory "
                    "with the --models-path argument."
)
sys.exit(1)

__run_quantize_script(
args.quantize_script_path, f16_model_part_path
)

if args.remove_f16:
os.remove(f16_model_part_path)


# This was extracted into a top-level function so that quantization could be
# parallelized later (a sketch follows this listing). See
# https://github.com/ggerganov/llama.cpp/pull/222/commits/f8db3d6cd91bf1a1342db9d29e3092bc12dd783c#r1140496406

def __run_quantize_script(script_path, f16_model_part_path):
    """Run the quantize script on a single f16 model part, given the path to
    the script and the path to the model part.
    """

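    # Derive the output path from the f16 part name; the trailing "2" passed
    # to the quantize binary selects the q4_0 quantization type.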
new_quantized_model_path = f16_model_part_path.replace("f16", "q4_0")
subprocess.run(
[script_path, f16_model_part_path, new_quantized_model_path, "2"],
check=True
)


if __name__ == "__main__":
try:
main()

except subprocess.CalledProcessError:
        print("\nAn error occurred while trying to quantize the models.")
sys.exit(1)

except KeyboardInterrupt:
sys.exit(0)

else:
        print("\nSuccessfully quantized all models.")
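
If the commented-out --threads option is ever revived, the extracted
__run_quantize_script() makes a parallel version straightforward. A minimal
sketch, assuming a multiprocessing pool; the function name and signature
below are assumptions, not part of this PR:

# Sketch only: quantize several f16 model parts concurrently.
import functools
import multiprocessing


def quantize_parts_in_parallel(script_path, f16_model_parts_paths, threads):
    # functools.partial over a module-level function keeps the callable
    # picklable, which multiprocessing.Pool requires.
    with multiprocessing.Pool(processes=threads) as pool:
        pool.map(
            functools.partial(__run_quantize_script, script_path),
            f16_model_parts_paths
        )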
15 changes: 0 additions & 15 deletions quantize.sh

This file was deleted.