Skip to content

Commit 8ed4a7b

Browse files
Allow specifying the BERT text classifier's model output tensor name via options.
PiperOrigin-RevId: 638652416
1 parent 3247254 commit 8ed4a7b

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

‎tensorflow_lite_support/cc/task/text/bert_nl_classifier.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
7474
}
7575
const TfLiteTensor* scores = FindTensorByName(
7676
output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
77-
kScoreTensorName);
77+
options_->has_output_tensor_name() ? options_->output_tensor_name()
78+
: kScoreTensorName);
7879

7980
// optional labels extracted from metadata
8081
return BuildResults(scores, /*labels=*/nullptr);

‎tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options.proto

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package tflite.task.text;
2020
import "tensorflow_lite_support/cc/task/core/proto/base_options.proto";
2121

2222
// Options for setting up a BertNLClassifier.
23-
// Next Id: 3
23+
// Next Id: 4
2424
message BertNLClassifierOptions {
2525
// Base options for configuring BertNLClassifier, such as specifying the
2626
// TfLite model file with metadata, accelerator options, etc.
@@ -31,4 +31,9 @@ message BertNLClassifierOptions {
3131
// Deprecated: max_seq_len is now read from the model (i.e. input tensor size)
3232
// automatically.
3333
optional int32 max_seq_len = 2 [default = 128];
34+
35+
// The name of the output tensor.
36+
//
37+
// If not provided, defaults to "probability".
38+
optional string output_tensor_name = 3;
3439
}

0 commit comments

Comments
 (0)