
Fix slicing and get_plain() in GemLite #2288

Merged: 15 commits into pytorch:main, Jun 5, 2025

Conversation

mobicham (Collaborator) commented Jun 2, 2025

Contributions

  • GemLite 4.7.0 uses FMA mode by default to improve dequantization performance (W_q * s + z instead of (W_q - z) * s), so get_plain() needs updating to be compatible with both formats (see the sketch after this list).
  • Updated slicing to work directly on the packed data, since the older version, which relied on get_plain(), was causing issues in vLLM.
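
For reference, the two dequantization layouts differ only in where the zero-point is applied. A minimal sketch with illustrative tensors (the names, shapes, and the z_fma pre-folding are assumptions, not GemLite's internal layout or API; the arithmetic identity itself is exact):

```python
import torch

# Illustrative 4-bit codes with per-group scale and zero-point.
W_q = torch.randint(0, 16, (8, 64)).to(torch.float16)
s = torch.rand(8, 1, dtype=torch.float16) * 0.1
z = torch.randint(0, 16, (8, 1)).to(torch.float16)

# Pre-4.7.0 layout: subtract the zero-point, then scale.
W_old = (W_q - z) * s

# FMA layout (GemLite 4.7.0+): a single fused multiply-add. The two layouts
# agree when the stored zero-point is pre-folded as z_fma = -z * s.
z_fma = -z * s
W_fma = W_q * s + z_fma

assert torch.allclose(W_old, W_fma, atol=1e-3)
```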

Notes

  • gemlite.set_kernel_caching(True) gives wrong output under torchao but not when gemlite is used as a standalone module; the cause is unclear. Leaving caching off costs up to 10 tokens/sec at batch-size=1 (see the snippet below).
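
The flag named above can be toggled explicitly; a minimal sketch (that disabling caching is the safe setting under torchao is inferred from the note above, not confirmed):

```python
import gemlite

# Keep kernel caching disabled when running under torchao until the
# wrong-output issue is understood (trades up to ~10 tokens/sec at
# batch-size=1 for correct results, per the note above).
gemlite.set_kernel_caching(False)
```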

Tests

End-to-end test

https://gist.github.com/mobicham/54fed6f18bee590f615f18391b45b71e

Slicing Test

```python
import torch, gemlite
from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_
device = 'cuda:0'
dtype = torch.float16
gemlite.set_autotune("default")

torch.manual_seed(0)
in_features, out_features, group_size = 256, 512, 64

orig_shape = [out_features, in_features]
layer = torch.nn.Linear(in_features, out_features, bias=False, dtype=dtype, device=device)
layer.weight.data /= 10.
weight = layer.weight.data.clone()

quantize_(layer, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=group_size))

meta_args = layer.weight.tensor_impl.gemlite_kwargs['meta_args']
W_group_mode = meta_args[10]  # 4 => FMA dequant (W_q * s + z), otherwise (W_q - z) * s

# Test matmul
####################################################################################
torch.manual_seed(0)
x = torch.randn((1, layer.in_features), device=device, dtype=dtype) / 10.
y_ref = x @ weight.T
y_gem  = layer(x)
err = (y_ref - y_gem).abs().mean()
assert err < 5e-3, "Dot product mismatch. " + str(err)

# Test slicing
####################################################################################
def dequant(input_layer, in_features, orig_shape):
    # Unpack the 4-bit weights, then dequantize according to W_group_mode.
    int_data = input_layer.tensor_impl.packed_weight
    scale = input_layer.tensor_impl.scale
    zero_point = input_layer.tensor_impl.zero_point

    W_q = (
        gemlite.bitpack.unpack_over_rows(
            int_data, W_nbits=4, num_output_rows=in_features, dtype=torch.uint8
        )
        .T.contiguous()
        .view([-1, group_size])
    )

    s = scale.t().contiguous().view(-1, 1)
    z = zero_point.t().contiguous().view(-1, 1)

    if W_group_mode == 4:  # FMA
        W_deq = (W_q * s + z).view(orig_shape)
    else:
        W_deq = ((W_q - z) * s).view(orig_shape)

    return W_deq

W_r = dequant(layer.weight, layer.in_features, orig_shape)  # reference: ~= original weight


# Slice in half along each axis; narrow() takes (dim, start, length)
for slice_axis, start, length in [(0, 0, 256), (0, 256, 256), (1, 0, 128), (1, 128, 128)]:
    layer_sliced = layer.weight.narrow(slice_axis, start, length)

    if slice_axis == 0:
        num_rows, out_shape = layer.in_features, (orig_shape[0] // 2, orig_shape[1])
    else:
        num_rows, out_shape = layer.in_features // 2, (orig_shape[0], orig_shape[1] // 2)

    W_slice = dequant(layer_sliced, num_rows, out_shape)

    W_slice_ref = W_r[start:start + length, :] if slice_axis == 0 else W_r[:, start:start + length]
    assert (W_slice_ref - W_slice).abs().mean() == 0, f"Slicing {start}:{start + length} along axis={slice_axis} is incorrect"
```
pytorch-bot bot commented Jun 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2288

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit b2892ce with merge base 35ffb26:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jun 2, 2025
@mobicham marked this pull request as draft June 2, 2025 15:07
jerryzh168 (Contributor) commented:

Could you incorporate the test into `def test_slice_gemlite(self, device, dtype):` as well?

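A rough sketch of how the requested integration might look (only the test signature comes from the comment above; the layer shapes, config, and assertions are assumptions):

```python
import torch
from torchao.quantization import GemliteUIntXWeightOnlyConfig, quantize_

def test_slice_gemlite(self, device, dtype):
    # Sketch only: quantize a small linear layer and check that narrow()
    # on the packed weight works along both axes.
    layer = torch.nn.Linear(256, 512, bias=False, dtype=dtype, device=device)
    quantize_(layer, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=64))
    for dim, start, length in [(0, 0, 256), (1, 0, 128)]:
        sliced = layer.weight.narrow(dim, start, length)
        self.assertEqual(sliced.shape[dim], length)
```
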
@mobicham mobicham marked this pull request as ready for review June 2, 2025 17:44
mobicham (Collaborator, Author) commented Jun 3, 2025

Updated the test and successfully tested on vLLM.

mobicham (Collaborator, Author) commented Jun 3, 2025

In vLLM, I get _same_metadata() errors with models that have lm_head quantized; it seems to me that ao's vLLM implementation doesn't support that?

@mobicham requested a review from jerryzh168 June 3, 2025 18:22
@jerryzh168 added the topic: improvement label Jun 5, 2025
@jerryzh168 merged commit 0640474 into pytorch:main Jun 5, 2025
19 of 20 checks passed
Labels: CLA Signed, topic: improvement
3 participants