docs: add snippet for evaluating a boosted tree model #1154

Merged
merged 8 commits into from
Nov 20, 2024
33 changes: 28 additions & 5 deletions samples/snippets/classification_boosted_tree_model_test.py
@@ -48,19 +48,42 @@ def test_boosted_tree_model(random_model_id: str) -> None:
     y = training_data["income_bracket"]
 
     # create and train the model
-    census_model = ensemble.XGBClassifier(
+    tree_model = ensemble.XGBClassifier(
         n_estimators=1,
         booster="gbtree",
         tree_method="hist",
         max_iterations=1,  # For a more accurate model, try 50 iterations.
         subsample=0.85,
     )
-    census_model.fit(X, y)
+    tree_model.fit(X, y)
 
-    census_model.to_gbq(
-        your_model_id,  # For example: "your-project.census.census_model"
+    tree_model.to_gbq(
+        your_model_id,  # For example: "your-project.bqml_tutorial.tree_model"
         replace=True,
     )
     # [END bigquery_dataframes_bqml_boosted_tree_create]
+    # [START bigquery_dataframes_bqml_boosted_tree_explain]
+    # Select model you'll use for predictions. `read_gbq_model` loads model
+    # data from BigQuery, but you could also use the `tree_model` object
+    # from the previous step.
+    tree_model = bpd.read_gbq_model(
+        your_model_id,  # For example: "your-project.bqml_tutorial.tree_model"
+    )
+
+    # input_data is defined in an earlier step.
+    evaluation_data = input_data[input_data["dataframe"] == "evaluation"]
+    X = evaluation_data.drop(columns=["income_bracket", "dataframe"])
+    y = evaluation_data["income_bracket"]
+
+    # The score() method evaluates how the model performs compared to the
+    # actual data. Output DataFrame matches that of ML.EVALUATE().
+    score = tree_model.score(X, y)
+    score.peek()
+    # Output:
+    #    precision    recall  accuracy  f1_score  log_loss   roc_auc
+    # 0   0.671924  0.578804  0.839429  0.621897  0.344054  0.887335
+    # [END bigquery_dataframes_bqml_boosted_tree_explain]
+    assert tree_model is not None
+    assert evaluation_data is not None
+    assert score is not None
     assert input_data is not None
-    assert census_model is not None
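
For readers unfamiliar with the columns in the `score()` output, the following is a minimal sketch of how precision, recall, accuracy, and F1 relate to one another. It is plain Python, not BigQuery DataFrames code and not how `ML.EVALUATE()` is implemented. The helper names `f1_from` and `binary_metrics` and the toy labels are hypothetical and exist only for illustration.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def binary_metrics(y_true, y_pred):
    """Compute basic binary classification metrics for boolean labels.

    Hypothetical helper for illustration; True marks the positive class.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)          # true positives
    fp = sum(1 for t, p in pairs if not t and p)      # false positives
    fn = sum(1 for t, p in pairs if t and not p)      # false negatives
    tn = sum(1 for t, p in pairs if not t and not p)  # true negatives

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return {
        "precision": precision,
        "recall": recall,
        "accuracy": accuracy,
        "f1_score": f1_from(precision, recall),
    }


# Toy labels (made up): 3 positives, 5 negatives, 2 mistakes.
toy_true = [True, True, True, False, False, False, False, False]
toy_pred = [True, True, False, True, False, False, False, False]
toy = binary_metrics(toy_true, toy_pred)

# The F1 value in the snippet's sample output agrees with the harmonic
# mean of its reported precision and recall, up to rounding.
reported_f1 = f1_from(0.671924, 0.578804)
assert abs(reported_f1 - 0.621897) < 1e-5
```

The last check uses only the numbers shown in the snippet's output comment, so it illustrates how the `f1_score` column follows from `precision` and `recall`. It says nothing about `log_loss` or `roc_auc`, which depend on predicted probabilities rather than on hard labels.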