Skip to content

Commit abdb958

Browse files
committed
Update LUT test, adding new simple unit test logic.
1 parent b0c32d6 commit abdb958

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

‎torchao/experimental/kernels/cpu/aarch64/tests/test_utils_lut.h

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace torchao::test_utils::lut {
1313

1414
enum class QuantizationGranularity { PER_TENSOR, PER_CHANNEL, PER_GROUP };
15-
enum class GroundTruthStrategy { IDEAL_DEQUANT, LUT_DEQUANT };
15+
enum class GroundTruthStrategy { IDEAL_DEQUANT, LUT_DEQUANT, RANDOM_LUT_DIRECT };
1616

1717
struct TestCaseConfig {
1818
int rows;
@@ -27,7 +27,7 @@ struct TestCaseConfig {
2727

2828

2929
std::pair<std::vector<int8_t>, std::vector<int8_t>>
30-
generate_simple_u_to_s_lut_and_indices(
30+
generate_random_int8_lut_and_indices(
3131
int weight_nbit,
3232
const std::vector<int8_t>& weight_qvals) {
3333
// 1. Define the offset used to map between signed and unsigned representations.
@@ -168,6 +168,11 @@ class CpuTestCaseGenerator {
168168
static lut_quantization_test_case<T_in, T_zp, T_lut> generate(const TestCaseConfig& config) {
169169
std::mt19937 gen(config.random_seed);
170170

171+
// Check if the user requested the simple, direct LUT generation strategy (Unit test for LUT).
172+
if (config.strategy == GroundTruthStrategy::RANDOM_LUT_DIRECT) {
173+
return generate_random_lut_direct(config, gen);
174+
}
175+
171176
// 1. Generate random floating-point data
172177
auto input_float = generate_random_data(config.rows * config.cols, gen);
173178

@@ -194,6 +199,57 @@ class CpuTestCaseGenerator {
194199
}
195200

196201
private:
202+
203+
204+
/**
205+
* @brief Helper for the simple RANDOM_LUT_DIRECT strategy.
206+
*
207+
* Bypasses quantization simulation. It directly creates a random LUT and
208+
* random indices, with the ground truth being the result of the lookup.
209+
* This path always generates a PER_TENSOR style test case.
210+
*/
211+
static lut_quantization_test_case<T_in, T_zp, T_lut>
212+
generate_random_lut_direct(const TestCaseConfig& config, std::mt19937& gen) {
213+
const int num_elements = config.rows * config.cols;
214+
const T_in q_min = -(1 << (config.nbit - 1));
215+
const T_in q_max = (1 << (config.nbit - 1)) - 1;
216+
const size_t lut_size = (1 << config.nbit);
217+
const int lut_idx_offset = q_min;
218+
219+
// 1. Generate a completely random LUT.
220+
std::vector<T_lut> lut(lut_size);
221+
std::uniform_real_distribution<float> lut_val_dist(-5.0f, 5.0f);
222+
for (size_t i = 0; i < lut_size; ++i) {
223+
lut[i] = static_cast<T_lut>(lut_val_dist(gen));
224+
}
225+
226+
// 2. Generate random quantized values to serve as our input data.
227+
std::vector<T_in> input_qvals(num_elements);
228+
std::uniform_int_distribution<int> q_val_dist(q_min, q_max);
229+
for (int i = 0; i < num_elements; ++i) {
230+
input_qvals[i] = static_cast<T_in>(q_val_dist(gen));
231+
}
232+
233+
// 3. The ground truth is simply the result of looking up the q_vals in the LUT.
234+
std::vector<float> expected_output(num_elements);
235+
for (int i = 0; i < num_elements; ++i) {
236+
T_in q_val = input_qvals[i];
237+
size_t lut_idx = q_val - lut_idx_offset; // Map q_val (e.g., -8..7) to index (e.g., 0..15)
238+
assert(lut_idx < lut.size());
239+
expected_output[i] = static_cast<float>(lut[lut_idx]);
240+
}
241+
242+
// 4. Create dummy scale/zero to satisfy the test case struct assertions.
243+
// This strategy is inherently per-tensor.
244+
std::vector<float> scales = {1.0f};
245+
std::vector<T_zp> zeros = {static_cast<T_zp>(0)};
246+
247+
return lut_quantization_test_case<T_in, T_zp, T_lut>(
248+
config.rows, config.cols, QuantizationGranularity::PER_TENSOR, /*group_size=*/-1, config.nbit,
249+
std::move(input_qvals), std::move(scales), std::move(zeros),
250+
std::move(lut), std::move(expected_output));
251+
}
252+
197253
// Helper to generate random float data
198254
static std::vector<float> generate_random_data(int size, std::mt19937& gen) {
199255
std::vector<float> data(size);

0 commit comments

Comments
 (0)