[WIP] Layer-wise unit test without PyTorch #1793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions layer_unit_test.ipynb
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "51ab58f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jax 0.015534584410488605 pt 0.015534583479166031\n"
]
}
],
"source": [
"\"\"\"\n",
"Copyright 2025 Google LLC\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License.\n",
"\"\"\"\n",
"\n",
"\"\"\" Tests for Llama4 Vision RoPE \"\"\"\n",
"from typing import Callable, NamedTuple, Optional, Tuple\n",
"import os.path\n",
"import sys\n",
"import math\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import jax\n",
"import unittest\n",
"import jax.numpy as jnp\n",
"from jax.sharding import Mesh\n",
"from MaxText.globals import PKG_DIR\n",
"from MaxText import pyconfig\n",
"from MaxText import maxtext_utils\n",
"from MaxText.layers import attentions, embeddings, llama4\n",
"import numpy as np\n",
"\n",
"Attention = attentions.Attention\n",
"\n",
"# pylint: disable=line-too-long, missing-function-docstring\n",
"\n",
"\"\"\" \n",
"Llama4 Vision RoPE \n",
"Details https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py\n",
"\"\"\"\n",
"\n",
"\n",
"def to_jax(pt_tensor: torch.Tensor) -> jax.Array:\n",
" return jnp.asarray(pt_tensor.detach().numpy())\n",
"\n",
"\n",
"### original Pytorch Reference implementation\n",
"def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):\n",
" \"\"\"Reshape the frequency tensor for broadcasting.\"\"\"\n",
" ndim = query.ndim\n",
" shape = [d if i in (1, ndim - 1) else 1 for i, d in enumerate(query.shape)]\n",
" return freqs_ci.view(*shape)\n",
"\n",
"\n",
"class Llama4UnfoldConvolutionTest(unittest.TestCase):\n",
" \"\"\"Test for the Llama4 Unfold Convolution implementation.\"\"\"\n",
"\n",
" def __copy_weights(self, pt_model, params):\n",
" \"\"\"Copy weights from PyTorch model to JAX model.\n",
"\n",
" Args:\n",
" pt_model: PyTorch Llama4UnfoldConvolution model\n",
" params: JAX model parameters\n",
" \"\"\"\n",
" # Create new params with copied weights\n",
" updated_params = jax.tree_util.tree_map(lambda x: x, params)\n",
" updated_params[\"params\"][\"vit_unfold_linear\"][\"kernel\"] = to_jax(pt_model.linear.weight).T\n",
" return updated_params\n",
"\n",
" def test_unfold_convolution(self):\n",
" \"\"\"Test for the Llama4 Unfold Convolution implementation.\"\"\"\n",
" # Test parameters\n",
" # following the llama4 config\n",
" # https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json\n",
" batch_size = 10\n",
" num_channels = 3\n",
" image_size = 336\n",
" patch_size = 14\n",
" hidden_size = 1408\n",
"\n",
" # Create random input tensor\n",
" inputs_pt = torch.randn(batch_size, num_channels, image_size, image_size)\n",
"\n",
" # PyTorch implementation\n",
" # following llama4 implementation in\n",
" # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1279\n",
" class Llama4UnfoldConvolution(nn.Module):\n",
" \"\"\"Llama4 Unfold Convolution implementation.\"\"\"\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.patch_size = patch_size\n",
" self.hidden_size = hidden_size\n",
" kernel_size = patch_size\n",
" if isinstance(kernel_size, int):\n",
" kernel_size = (kernel_size, kernel_size)\n",
" self.unfold = nn.Unfold(kernel_size=kernel_size, stride=patch_size)\n",
" self.linear = nn.Linear(num_channels * kernel_size[0] * kernel_size[1], hidden_size, bias=False)\n",
"\n",
" def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
" # num_patches = (self.image_size // self.patch_size) ** 2\n",
" # hidden_states shape: torch.Size([batch_size, num_channels, img, img])\n",
" hidden_states = self.unfold(hidden_states)\n",
" # hidden_states shape: torch.Size([batch_size, num_channels * patch_size * patch_size, num_patches])\n",
" hidden_states = hidden_states.permute(0, 2, 1)\n",
" # hidden_states shape: torch.Size([batch_size, num_patches, num_channels * patch_size * patch_size])\n",
" hidden_states = self.linear(hidden_states)\n",
" # hidden_states shape: torch.Size([batch_size, num_patches, hidden_size])\n",
" return hidden_states\n",
"\n",
" # Initialize PyTorch model\n",
" pt_model = Llama4UnfoldConvolution()\n",
" pt_model.eval()\n",
" pt_output = pt_model(inputs_pt)\n",
"\n",
" # JAX implementation\n",
" class JaxConfig:\n",
"\n",
" def __init__(self):\n",
" self.patch_size_for_vit = patch_size\n",
" self.hidden_size_for_vit = hidden_size\n",
" self.dtype_mm = jnp.float32\n",
"\n",
" # Initialize JAX model\n",
" jax_model = llama4.Llama4UnfoldConvolution(JaxConfig())\n",
" params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))\n",
"\n",
" # Copy weights from PyTorch to JAX\n",
" pt_params = self.__copy_weights(pt_model, params)\n",
"\n",
" # Run JAX forward pass with updated params\n",
" jax_output = jax_model.apply(pt_params, to_jax(inputs_pt))\n",
"\n",
" # Compare shapes\n",
" self.assertEqual(pt_output.shape, jax_output.shape)\n",
"\n",
" # Compare outputs with reasonable tolerances\n",
" np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)\n",
"\n",
"\n",
"class Llama4VisionPixelShuffleMLPTest(unittest.TestCase):\n",
" \"\"\"Test for the Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n",
"\n",
" def __copy_weights(self, pt_model, params):\n",
" \"\"\"Copy weights from PyTorch model to JAX model.\n",
"\n",
" Args:\n",
" pt_model: PyTorch Llama4VisionPixelShuffleMLP model\n",
" params: JAX model parameters\n",
" \"\"\"\n",
" # Create new params with copied weights\n",
" updated_params = jax.tree_util.tree_map(lambda x: x, params)\n",
" # Copy weights for both MLP layers\n",
" updated_params[\"params\"][\"pixel_shuffle_mlp\"][\"vit_pixel_shuffle_mlp_fc1\"][\"kernel\"] = to_jax(pt_model.mlp.fc1.weight).T\n",
" updated_params[\"params\"][\"pixel_shuffle_mlp\"][\"vit_pixel_shuffle_mlp_fc2\"][\"kernel\"] = to_jax(pt_model.mlp.fc2.weight).T\n",
" return updated_params\n",
"\n",
" def test_pixel_shuffle_mlp(self):\n",
" \"\"\"Test for the Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n",
" # Test parameters\n",
" # following config https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json\n",
" batch_size = 10\n",
" num_patches = 24 * 24 # 336/14 = 24 patches per side\n",
" hidden_size = 1408\n",
" intermediate_size = 5632\n",
" projector_input_dim = 4096\n",
" projector_output_dim = 4096\n",
" pixel_shuffle_ratio = 0.5\n",
" projector_dropout = 0.0\n",
"\n",
" def pixel_shuffle(input_tensor, shuffle_ratio):\n",
" # input_tensor: [batch_size, num_patches, channels]\n",
" batch_size, num_patches, channels = input_tensor.shape\n",
" patch_size = int(math.sqrt(num_patches))\n",
"\n",
" input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)\n",
" batch_size, height, width, channels = input_tensor.size()\n",
"\n",
" reshaped_tensor = input_tensor.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))\n",
" reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n",
"\n",
" reshaped_tensor = reshaped_tensor.view(\n",
" batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))\n",
" )\n",
" reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n",
"\n",
" output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])\n",
" return output_tensor\n",
"\n",
" # PyTorch implementation\n",
" class Llama4VisionMLP2(nn.Module):\n",
" \"\"\"Llama4 Vision MLP2 implementation.\"\"\"\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.hidden_size = config.hidden_size\n",
" self.intermediate_size = config.intermediate_size\n",
" self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False)\n",
" self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False)\n",
" self.activation_fn = nn.GELU()\n",
" self.dropout = config.projector_dropout\n",
"\n",
" def forward(self, hidden_states):\n",
" hidden_states = self.fc1(hidden_states)\n",
" hidden_states = self.activation_fn(hidden_states)\n",
" hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)\n",
" return self.activation_fn(self.fc2(hidden_states))\n",
"\n",
" class Llama4VisionPixelShuffleMLP(nn.Module):\n",
" \"\"\"Llama4 Vision Pixel Shuffle MLP implementation.\"\"\"\n",
"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.pixel_shuffle_ratio = config.pixel_shuffle_ratio\n",
" self.inner_dim = int(config.projector_input_dim // (self.pixel_shuffle_ratio**2))\n",
" self.output_dim = config.projector_output_dim\n",
" self.mlp = Llama4VisionMLP2(config)\n",
"\n",
" def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:\n",
" # encoded_patches shape: torch.Size([batch_size, num_patches, hidden_size])\n",
" encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)\n",
" return self.mlp(encoded_patches)\n",
" # result shape: torch.Size([batch_size, num_patches * (pixel_shuffle_rate**2), projector_output_dim])\n",
"\n",
" # Initialize PyTorch model\n",
" class Config:\n",
"\n",
" def __init__(self):\n",
" self.hidden_size = hidden_size\n",
" self.intermediate_size = intermediate_size\n",
" self.projector_input_dim = projector_input_dim\n",
" self.projector_output_dim = projector_output_dim\n",
" self.pixel_shuffle_ratio = pixel_shuffle_ratio\n",
" self.projector_dropout = projector_dropout\n",
"\n",
" # Create random input tensor\n",
" inputs_pt = torch.randn(batch_size, num_patches, hidden_size)\n",
"\n",
" pt_model = Llama4VisionPixelShuffleMLP(Config())\n",
" pt_model.eval()\n",
" pt_output = pt_model(inputs_pt)\n",
"\n",
" # JAX implementation\n",
" class JaxConfig:\n",
"\n",
" def __init__(self):\n",
" self.pixel_shuffle_ratio_for_vit = pixel_shuffle_ratio\n",
" self.projector_input_dim_for_vit = projector_input_dim\n",
" self.projector_output_dim_for_vit = projector_output_dim\n",
" self.dtype_mm = jnp.float32\n",
" self.projector_dropout_for_vit = projector_dropout\n",
"\n",
" # Initialize JAX model\n",
" jax_model = llama4.Llama4VisionPixelShuffleMLP(JaxConfig())\n",
" params = jax_model.init(jax.random.PRNGKey(0), to_jax(inputs_pt))\n",
"\n",
" # Copy weights from PyTorch to JAX\n",
" pt_params = self.__copy_weights(pt_model, params)\n",
"\n",
" # Run JAX forward pass with updated params\n",
" jax_output = jax_model.apply(pt_params, to_jax(inputs_pt), deterministic=True)\n",
"\n",
" # Compare shapes\n",
" self.assertEqual(pt_output.shape, jax_output.shape)\n",
"\n",
" # Compare outputs with reasonable tolerances\n",
" np.testing.assert_allclose(to_jax(pt_output), jax_output, rtol=1e-3, atol=0.05)\n",
" print(f\"jax {to_jax(pt_output).mean()} pt {pt_output.mean()}\")\n",
"\n",
"testclass = Llama4VisionPixelShuffleMLPTest()\n",
"testclass.test_pixel_shuffle_mlp()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}