Add dynamic quantization support to gemlite layout #2327

Open · wants to merge 20 commits into base: main
22 changes: 20 additions & 2 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -85,6 +85,7 @@ def get_gemlite_aqt_kwargs(
     group_size=64,
     bit_width=4,
     packing_bitwidth=None,
+    mode="weight_only",
     use_hqq=True,
 ):
     if gemlite is None:
@@ -108,6 +109,10 @@ def get_gemlite_aqt_kwargs(
         f"Invalid packing bitwidth, got {packing_bitwidth}"
     )
 
+    assert mode in ["weight_only", "dynamic"], (
+        f"Invalid mode: should be either weight_only or dynamic, got {mode}"
+    )
+
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
 
@@ -116,6 +121,7 @@ def get_gemlite_aqt_kwargs(
         group_size=group_size,
         bit_width=bit_width,
         packing_bitwidth=packing_bitwidth,
+        mode=mode,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
@@ -126,6 +132,7 @@ class GemlitePackedLayout(Layout):
     group_size: Optional[int] = 128
     bit_width: int = 4
     packing_bitwidth: Optional[int] = None
+    mode: Optional[str] = "weight_only"
 
 
 @register_layout(GemlitePackedLayout)
@@ -202,13 +209,24 @@ def from_plain(
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
         packing_bitwidth = _layout.packing_bitwidth
+        mode = _layout.mode
 
         if bit_width == 8 and group_size == in_features:
-            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
+            processor = (
+                gemlite.helper.A8W8_int8_dynamic
+                if mode == "dynamic"
+                else gemlite.helper.A16W8
+            )
+            gemlite_linear = processor(device=int_data.device).from_weights(
                 int_data, scales=scale, bias=None
             )
         else:
-            gemlite_linear = gemlite.helper.A16Wn(
+            processor = (
+                gemlite.helper.A8Wn_dynamic
+                if mode == "dynamic"
+                else gemlite.helper.A16Wn
+            )
+            gemlite_linear = processor(
                 device=int_data.device, packing_bitwidth=packing_bitwidth
             ).from_weights(
                 int_data, scale, zero_point, bit_width, group_size, bias=None
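The change above routes `mode` to gemlite's helper classes: weight-only keeps fp16 activations (`A16W8` / `A16Wn`), while `"dynamic"` switches to the int8 dynamic-activation helpers (`A8W8_int8_dynamic` / `A8Wn_dynamic`). A minimal sketch of that dispatch in isolation (assuming gemlite is installed and exposes these helpers under `gemlite.helper`, with the names taken from the diff):

```python
import gemlite


def pick_gemlite_processor(bit_width: int, group_size: int, in_features: int, mode: str):
    """Mirror of the from_plain selection above (sketch; helper names per the diff)."""
    if bit_width == 8 and group_size == in_features:
        # 8-bit, per-channel weights: fp16 activations vs. dynamic int8 activations
        return (
            gemlite.helper.A8W8_int8_dynamic
            if mode == "dynamic"
            else gemlite.helper.A16W8
        )
    # grouped / sub-8-bit weights
    return gemlite.helper.A8Wn_dynamic if mode == "dynamic" else gemlite.helper.A16Wn
```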
8 changes: 7 additions & 1 deletion torchao/quantization/autoquant.py
@@ -742,10 +742,16 @@ def from_float(cls, weight):
 
         bit_width = 4
         packing_bitwidth = None
+        mode = "weight_only"
         use_hqq = True
 
         aqt_kwargs = get_gemlite_aqt_kwargs(
-            weight, cls.group_size, bit_width, packing_bitwidth, use_hqq
+            weight,
+            group_size=cls.group_size,
+            bit_width=bit_width,
+            packing_bitwidth=packing_bitwidth,
+            mode=mode,
+            use_hqq=use_hqq,
         )
         weight = to_affine_quantized_intx(weight, **aqt_kwargs)
         input_quant_func = _to_float16
11 changes: 9 additions & 2 deletions torchao/quantization/quant_api.py
@@ -986,13 +986,14 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
         size is more fine grained
     `bit_width`: bit width of the quantized weight.
     `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware.
-    `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
+    `mode`: if set to "dynamic", activations are quantized at runtime; default is "weight_only" (weight-only quantization).
     `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
     """
 
     group_size: Optional[int] = 128
     bit_width: int = 4
     packing_bitwidth: Optional[int] = None
+    mode: Optional[str] = "weight_only"
     set_inductor_config: bool = True
 
 
@@ -1007,6 +1008,7 @@ def _gemlite_uintx_weight_only_transform(
     group_size = config.group_size
     bit_width = config.bit_width
     packing_bitwidth = config.packing_bitwidth
+    mode = config.mode
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -1018,7 +1020,12 @@ def _gemlite_uintx_weight_only_transform(
     new_weight = to_affine_quantized_intx(
         weight,
         **get_gemlite_aqt_kwargs(
-            weight, group_size, bit_width, packing_bitwidth, use_hqq
+            weight,
+            group_size=group_size,
+            bit_width=bit_width,
+            packing_bitwidth=packing_bitwidth,
+            mode=mode,
+            use_hqq=use_hqq,
         ),
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
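For users, the new flag is surfaced through `GemliteUIntXWeightOnlyConfig`. A hedged usage sketch (assumes a CUDA device, gemlite installed, and a torchao build containing this change; the toy model is illustrative):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

# Toy fp16 model on GPU; the gemlite kernels target CUDA.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(
    device="cuda", dtype=torch.float16
)

config = GemliteUIntXWeightOnlyConfig(
    group_size=128,
    bit_width=4,
    packing_bitwidth=None,  # default; set 8 or 32 to force a packing width
    mode="dynamic",  # quantize activations at runtime; "weight_only" is the default
)
quantize_(model, config)
```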