借助 LLM Inference API,您可以完全在设备端为 Android 应用运行大语言模型 (LLM),并使用这些模型执行各种任务,例如生成文本、以自然语言形式检索信息以及总结文档。该任务内置对多个文本到文本大语言模型的支持,因此您可以将最新的设备端生成式 AI 模型应用于 Android 应用。
如需快速将 LLM Inference API 添加到 Android 应用,请按照快速入门中的说明操作。如需查看运行 LLM Inference API 的 Android 应用的基本示例,请参阅示例应用。如需更深入地了解 LLM Inference API 的运作方式,请参阅配置选项、模型转换和 LoRA 调优部分。
您可以通过 MediaPipe Studio 演示查看此任务的实际运作方式。如需详细了解此任务的功能、模型和配置选项,请参阅概览。
快速入门
请按以下步骤将 LLM Inference API 添加到您的 Android 应用。 LLM Inference API 针对高端 Android 设备(例如 Pixel 8 和 Samsung S23 或更高版本)进行了优化,并且无法可靠地支持设备模拟器。
添加依赖项
LLM Inference API 使用 com.google.mediapipe:tasks-genai
库。将以下依赖项添加到 Android 应用的 build.gradle
文件中:
dependencies {
implementation 'com.google.mediapipe:tasks-genai:0.10.24'
}
下载模型
从 Hugging Face 下载采用 4 位量化格式的 Gemma-3 1B。如需详细了解可用模型,请参阅“模型”文档。
将 output_path 文件夹的内容推送到 Android 设备。
$ adb shell rm -r /data/local/tmp/llm/ # Remove any previously loaded models
$ adb shell mkdir -p /data/local/tmp/llm/
$ adb push output_path /data/local/tmp/llm/model_version.task
初始化任务
使用基本配置选项初始化任务:
// Set the configuration options for the LLM Inference task
val taskOptions = LlmInferenceOptions.builder()
.setModelPath('/data/local/tmp/llm/model_version.task')
.setMaxTopK(64)
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, taskOptions)
运行任务
使用 generateResponse()
方法生成文本回答。这会生成单个生成的回答。
val result = llmInference.generateResponse(inputPrompt)
logger.atInfo().log("result: $result")
如需流式传输响应,请使用 generateResponseAsync()
方法。
val options = LlmInference.LlmInferenceOptions.builder()
...
.setResultListener { partialResult, done ->
logger.atInfo().log("partial result: $partialResult")
}
.build()
llmInference.generateResponseAsync(inputPrompt)
示例应用
如需查看 LLM 推理 API 的实际运作情况,并探索一系列全面的设备端生成式 AI 功能,请查看 Google AI Edge Gallery 应用。
Google AI Edge Gallery 是一款开源 Android 应用,可供开发者进行互动式开发。其中展示了:
- 将 LLM Inference API 用于各种任务的实用示例,包括:
- 询问图片:上传图片并询问与其相关的问题。获取说明、解决问题或识别对象。
- 提示实验室:总结、重写、生成代码或使用自由形式提示来探索单轮 LLM 用例。
- AI 聊天:进行多轮对话。
- 能够从 Hugging Face LiteRT 社区和官方 Google 版本(例如 Gemma 3N)中发现、下载和实验各种经过 LiteRT 优化的模型。
- 不同模型的实时设备端性能基准(首次出现令牌的时间、解码速度等)。
- 如何导入和测试您自己的自定义
.task
模型。
此应用可帮助您了解 LLM 推理 API 的实际实现以及设备端生成式 AI 的潜力。从 Google AI Edge Gallery GitHub 代码库浏览源代码并下载应用。
配置选项
使用以下配置选项设置 Android 应用:
选项名称 | 说明 | 值范围 | 默认值 |
---|---|---|---|
modelPath |
模型在项目目录中的存储路径。 | 路径 | 不适用 |
maxTokens |
模型处理的词元(输入词元 + 输出词元)数量上限。 | 整数 | 512 |
topK |
模型在生成过程中每个步骤考虑的令牌数。 将预测限制为前 k 个概率最高的 token。 | 整数 | 40 |
temperature |
生成过程中引入的随机性程度。温度越高,生成的文本就越具创造性;温度越低,生成的文本就越具可预测性。 | 浮点数 | 0.8 |
randomSeed |
文本生成期间使用的随机种子。 | 整数 | 0 |
loraPath |
设备本地 LoRA 模型的绝对路径。注意:此功能仅适用于 GPU 型号。 | 路径 | 不适用 |
resultListener |
设置结果监听器以异步接收结果。 仅在使用异步生成方法时适用。 | 不适用 | 不适用 |
errorListener |
设置可选的错误监听器。 | 不适用 | 不适用 |
多模态提示
LLM Inference API Android API 支持多模态提示,其中的模型接受文本和图片输入。��用多模态功能后,用户可以在提示中同时包含图片和文字,LLM 会提供文本回答。
首先,使用与 MediaPipe 兼容的 Gemma 3n 变体:
- Gemma-3n E2B:Gemma-3n 系列的 2B 模型。
- Gemma-3n E4B:Gemma-3n 系列的 4B 模型。
如需了解详情,请参阅 Gemma-3n 文档。
如需在问题中提供图片,请先将输入图片或帧转换为 com.google.mediapipe.framework.image.MPImage
对象,然后再将其传递给 LLM 推理 API:
import com.google.mediapipe.framework.image.BitmapImageBuilder
import com.google.mediapipe.framework.image.MPImage
// Convert the input Bitmap object to an MPImage object to run inference
val mpImage = BitmapImageBuilder(image).build()
如需为 LLM Inference API 启用视觉支持,请在 Graph 选项中将 EnableVisionModality
配置选项设置为 true
:
LlmInferenceSession.LlmInferenceSessionOptions sessionOptions =
LlmInferenceSession.LlmInferenceSessionOptions.builder()
...
.setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
.build();
Gemma-3n 每次会话最多接受一张图片,因此请将 MaxNumImages
设置为 1。
LlmInferenceOptions options = LlmInferenceOptions.builder()
...
.setMaxNumImages(1)
.build();
以下是设置为处理视觉和文本输入的 LLM Inference API 的实现示例:
MPImage image = getImageFromAsset(BURGER_IMAGE);
LlmInferenceSession.LlmInferenceSessionOptions sessionOptions =
LlmInferenceSession.LlmInferenceSessionOptions.builder()
.setTopK(10)
.setTemperature(0.4f)
.setGraphOptions(GraphOptions.builder().setEnableVisionModality(true).build())
.build();
try (LlmInference llmInference =
LlmInference.createFromOptions(ApplicationProvider.getApplicationContext(), options);
LlmInferenceSession session =
LlmInferenceSession.createFromOptions(llmInference, sessionOptions)) {
session.addQueryChunk("Describe the objects in the image.");
session.addImage(image);
String result = session.generateResponse();
}
LoRA 自定义
LLM Inference API 支持使用 PEFT(参数高效微调)库进行 LoRA(低秩自适应)调优。LoRA 调优通过经济高效的训练流程自定义 LLM 的行为,根据新训练数据创建一小组可训练权重,而不是重新训练整个模型。
LLM Inference API 支持向 Gemma-2 2B、Gemma 2B 和 Phi-2 模型的注意力层添加 LoRA 权重。下载 safetensors
格式的模型。
基本模型必须采用 safetensors
格式,才能创建 LoRA 权重。完成 LoRA 训练后,您可以将模型转换为 FlatBuffers 格式,以便在 MediaPipe 上运行。
准备 LoRA 权重
使用 PEFT 中的 LoRA 方法指南,基于您自己的数据集训练经过微调的 LoRA 模型。
LLM Inference API 仅支持在注意力层上使用 LoRA,因此请仅在 LoraConfig
中指定注意力层:
# For Gemma
from peft import LoraConfig
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
# For Phi-2
config = LoraConfig(
r=LORA_RANK,
target_modules=["q_proj", "v_proj", "k_proj", "dense"],
)
使用准备好的数据集进行训练并保存模型后,adapter_model.safetensors
中会提供经过微调的 LoRA 模型权重。safetensors
文件是模型转换期间使用的 LoRA 检查点。
模型转换
使用 MediaPipe Python 软件包将模型权重转换为 Flatbuffer 格式。ConversionConfig
指定了基本模型选项以及其他 LoRA 选项。
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
config = converter.ConversionConfig(
# Other params related to base model
...
# Must use gpu backend for LoRA conversion
backend='gpu',
# LoRA related params
lora_ckpt=LORA_CKPT,
lora_rank=LORA_RANK,
lora_output_tflite_file=LORA_OUTPUT_FILE,
)
converter.convert_checkpoint(config)
转换器将生成两个 Flatbuffer 文件,一个用于基准模型,另一个用于 LoRA 模型。
LoRA 模型推理
Android 在初始化期间支持静态 LoRA。如需加载 LoRA 模型,请指定 LoRA 模型路径以及基础 LLM。
// Set the configuration options for the LLM Inference task
val options = LlmInferenceOptions.builder()
.setModelPath(BASE_MODEL_PATH)
.setMaxTokens(1000)
.setTopK(40)
.setTemperature(0.8)
.setRandomSeed(101)
.setLoraPath(LORA_MODEL_PATH)
.build()
// Create an instance of the LLM Inference task
llmInference = LlmInference.createFromOptions(context, options)
如需使用 LoRA 运行 LLM 推理,请使用与基准模型相同的 generateResponse()
或 generateResponseAsync()
方法。