Skip to content
24 changes: 19 additions & 5 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from sys import version_info

import pydantic
from pydantic import Field, computed_field
from packaging import version
from pydantic import Field
from pydantic_core import from_json
from sqlglot import exp
from sqlglot.errors import ParseError
Expand Down Expand Up @@ -110,7 +110,14 @@ class ConnectionConfig(abc.ABC, BaseConfig):
catalog_type_overrides: t.Optional[t.Dict[str, str]] = None

# Whether to share a single connection across threads or create a new connection per thread.
shared_connection: t.ClassVar[bool] = False
#
# MyPy throws a "Decorators on top of @property are not supported" error despite this being a
# valid decoration, and Pydantic recommend disabling the MyPy hint for this reason - see:
# https://pydantic.dev/docs/validation/2.0/usage/computed_fields/
@computed_field # type: ignore[prop-decorator]
@property
def shared_connection(self) -> bool:
return False

@property
@abc.abstractmethod
Expand Down Expand Up @@ -311,7 +318,10 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):

token: t.Optional[str] = None

shared_connection: t.ClassVar[bool] = True
@computed_field # type: ignore[prop-decorator]
@property
def shared_connection(self) -> bool:
return True

_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}

Expand Down Expand Up @@ -820,11 +830,15 @@ class DatabricksConnectionConfig(ConnectionConfig):
DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3

shared_connection: t.ClassVar[bool] = True

_concurrent_tasks_validator = concurrent_tasks_validator
_http_headers_validator = http_headers_validator

@computed_field # type: ignore[prop-decorator]
@property
def shared_connection(self) -> bool:
"""The connection should only be shared if U2M OAuth is being used"""
return self.auth_type is not None and self.oauth_client_id is None

@model_validator(mode="before")
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
# SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
Expand Down
Loading
Loading