Skip to content

Commit 2037303

Browse files
authored
Adds support for Connection/Hook discovery from providers (#12466)
* Adds support for Hook discovery from providers This PR extends providers discovery with the mechanism of retrieving mapping of connections from type to hook. Fixes #12456 * fixup! Adds support for Hook discovery from providers * fixup! fixup! Adds support for Hook discovery from providers
1 parent c9d1ea5 commit 2037303

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+462
-197
lines changed

β€Ž.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ repos:
457457
entry: ./scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py
458458
language: python
459459
require_serial: true
460-
files: provider.yaml$
460+
files: provider.yaml$|scripts/ci/pre_commit/pre_commit_check_provider_yaml_files.py$
461461
additional_dependencies: ['PyYAML==5.3.1', 'jsonschema==3.2.0', 'tabulate==0.8.7']
462462
- id: mermaid
463463
name: Generate mermaid images

β€Žairflow/cli/cli_parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,12 @@ class GroupCommand(NamedTuple):
11531153
),
11541154
)
11551155
PROVIDERS_COMMANDS = (
1156+
ActionCommand(
1157+
name='hooks',
1158+
help='List registered provider hooks',
1159+
func=lazy_load_command('airflow.cli.commands.provider_command.hooks_list'),
1160+
args=(ARG_OUTPUT,),
1161+
),
11561162
ActionCommand(
11571163
name='list',
11581164
help='List installed providers',

β€Žairflow/cli/commands/provider_command.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Providers sub-commands"""
18-
from typing import Dict, List
18+
from typing import Dict, List, Tuple
1919

2020
import pygments
2121
import yaml
@@ -44,8 +44,7 @@ def provider_get(args):
4444
"""Get a provider info."""
4545
providers = ProvidersManager().providers
4646
if args.provider_name in providers:
47-
provider_version = providers[args.provider_name][0]
48-
provider_info = providers[args.provider_name][1]
47+
provider_version, provider_info = providers[args.provider_name]
4948
print("#")
5049
print(f"# Provider: {args.provider_name}")
5150
print(f"# Version: {provider_version}")
@@ -64,3 +63,23 @@ def provider_get(args):
6463
def providers_list(args):
6564
"""Lists all providers at the command line"""
6665
print(_tabulate_providers(ProvidersManager().providers.values(), args.output))
66+
67+
68+
def _tabulate_hooks(hook_items: Tuple[str, Tuple[str, str]], tablefmt: str):
69+
tabulate_data = [
70+
{
71+
'Connection type': hook_item[0],
72+
'Hook class': hook_item[1][0],
73+
'Hook connection attribute name': hook_item[1][1],
74+
}
75+
for hook_item in hook_items
76+
]
77+
78+
msg = tabulate(tabulate_data, tablefmt=tablefmt, headers='keys')
79+
return msg
80+
81+
82+
def hooks_list(args):
83+
"""Lists all hooks at the command line"""
84+
msg = _tabulate_hooks(ProvidersManager().hooks.items(), args.output)
85+
print(msg)

β€Žairflow/models/connection.py

Lines changed: 3 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -30,70 +30,10 @@
3030
from airflow.exceptions import AirflowException, AirflowNotFoundException
3131
from airflow.models.base import ID_LEN, Base
3232
from airflow.models.crypto import get_fernet
33+
from airflow.providers_manager import ProvidersManager
3334
from airflow.utils.log.logging_mixin import LoggingMixin
3435
from airflow.utils.module_loading import import_string
3536

36-
# A map that assigns a connection type to a tuple that contains
37-
# the path of the class and the name of the conn_id key parameter.
38-
# PLEASE KEEP BELOW LIST IN ALPHABETICAL ORDER.
39-
CONN_TYPE_TO_HOOK = {
40-
"azure_batch": (
41-
"airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook",
42-
"azure_batch_conn_id",
43-
),
44-
"azure_cosmos": (
45-
"airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook",
46-
"azure_cosmos_conn_id",
47-
),
48-
"azure_data_lake": (
49-
"airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook",
50-
"azure_data_lake_conn_id",
51-
),
52-
"cassandra": ("airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook", "cassandra_conn_id"),
53-
"cloudant": ("airflow.providers.cloudant.hooks.cloudant.CloudantHook", "cloudant_conn_id"),
54-
"dataprep": ("airflow.providers.google.cloud.hooks.dataprep.GoogleDataprepHook", "dataprep_default"),
55-
"docker": ("airflow.providers.docker.hooks.docker.DockerHook", "docker_conn_id"),
56-
"elasticsearch": (
57-
"airflow.providers.elasticsearch.hooks.elasticsearch.ElasticsearchHook",
58-
"elasticsearch_conn_id",
59-
),
60-
"exasol": ("airflow.providers.exasol.hooks.exasol.ExasolHook", "exasol_conn_id"),
61-
"gcpcloudsql": (
62-
"airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook",
63-
"gcp_cloudsql_conn_id",
64-
),
65-
"gcpssh": (
66-
"airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook",
67-
"gcp_conn_id",
68-
),
69-
"google_cloud_platform": (
70-
"airflow.providers.google.cloud.hooks.bigquery.BigQueryHook",
71-
"bigquery_conn_id",
72-
),
73-
"grpc": ("airflow.providers.grpc.hooks.grpc.GrpcHook", "grpc_conn_id"),
74-
"hive_cli": ("airflow.providers.apache.hive.hooks.hive.HiveCliHook", "hive_cli_conn_id"),
75-
"hiveserver2": ("airflow.providers.apache.hive.hooks.hive.HiveServer2Hook", "hiveserver2_conn_id"),
76-
"imap": ("airflow.providers.imap.hooks.imap.ImapHook", "imap_conn_id"),
77-
"jdbc": ("airflow.providers.jdbc.hooks.jdbc.JdbcHook", "jdbc_conn_id"),
78-
"jira": ("airflow.providers.jira.hooks.jira.JiraHook", "jira_conn_id"),
79-
"kubernetes": ("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook", "kubernetes_conn_id"),
80-
"mongo": ("airflow.providers.mongo.hooks.mongo.MongoHook", "conn_id"),
81-
"mssql": ("airflow.providers.odbc.hooks.odbc.OdbcHook", "odbc_conn_id"),
82-
"mysql": ("airflow.providers.mysql.hooks.mysql.MySqlHook", "mysql_conn_id"),
83-
"odbc": ("airflow.providers.odbc.hooks.odbc.OdbcHook", "odbc_conn_id"),
84-
"oracle": ("airflow.providers.oracle.hooks.oracle.OracleHook", "oracle_conn_id"),
85-
"pig_cli": ("airflow.providers.apache.pig.hooks.pig.PigCliHook", "pig_cli_conn_id"),
86-
"postgres": ("airflow.providers.postgres.hooks.postgres.PostgresHook", "postgres_conn_id"),
87-
"presto": ("airflow.providers.presto.hooks.presto.PrestoHook", "presto_conn_id"),
88-
"redis": ("airflow.providers.redis.hooks.redis.RedisHook", "redis_conn_id"),
89-
"snowflake": ("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook", "snowflake_conn_id"),
90-
"sqlite": ("airflow.providers.sqlite.hooks.sqlite.SqliteHook", "sqlite_conn_id"),
91-
"tableau": ("airflow.providers.salesforce.hooks.tableau.TableauHook", "tableau_conn_id"),
92-
"vertica": ("airflow.providers.vertica.hooks.vertica.VerticaHook", "vertica_conn_id"),
93-
"wasb": ("airflow.providers.microsoft.azure.hooks.wasb.WasbHook", "wasb_conn_id"),
94-
}
95-
# PLEASE KEEP ABOVE LIST IN ALPHABETICAL ORDER.
96-
9737

9838
def parse_netloc_to_hostname(*args, **kwargs):
9939
"""This method is deprecated."""
@@ -326,7 +266,8 @@ def rotate_fernet_key(self):
326266

327267
def get_hook(self):
328268
"""Return hook based on conn_type."""
329-
hook_class_name, conn_id_param = CONN_TYPE_TO_HOOK.get(self.conn_type, (None, None))
269+
hook_class_name, conn_id_param = ProvidersManager().hooks.get(self.conn_type, (None, None))
270+
330271
if not hook_class_name:
331272
raise AirflowException(f'Unknown hook type "{self.conn_type}"')
332273
hook_class = import_string(hook_class_name)

β€Žairflow/plugins_manager.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import importlib_metadata
3030

3131
from airflow import settings
32+
from airflow.utils.entry_points import entry_points_with_dist
3233
from airflow.utils.file import find_path_from_directory
3334

3435
if TYPE_CHECKING:
@@ -169,23 +170,6 @@ def is_valid_plugin(plugin_obj):
169170
return False
170171

171172

172-
def entry_points_with_dist(group: str):
173-
"""
174-
Return EntryPoint objects of the given group, along with the distribution information.
175-
176-
This is like the ``entry_points()`` function from importlib.metadata,
177-
except it also returns the distribution the entry_point was loaded from.
178-
179-
:param group: FIlter results to only this entrypoint group
180-
:return: Generator of (EntryPoint, Distribution) objects for the specified groups
181-
"""
182-
for dist in importlib_metadata.distributions():
183-
for e in dist.entry_points:
184-
if e.group != group:
185-
continue
186-
yield (e, dist)
187-
188-
189173
def load_entrypoint_plugins():
190174
"""
191175
Load and register plugins AirflowPlugin subclasses from the entrypoints.

β€Žairflow/provider.yaml.schema.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@
173173
"python-module"
174174
]
175175
}
176+
},
177+
"hook-class-names": {
178+
"type": "array",
179+
"description": "Hook class names that provide connection types to core",
180+
"items": {
181+
"type": "string"
182+
}
176183
}
177184
},
178185
"additionalProperties": false,

β€Žairflow/providers/apache/cassandra/hooks/cassandra.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ class CassandraHook(BaseHook, LoggingMixin):
8383
For details of the Cluster config, see cassandra.cluster.
8484
"""
8585

86-
def __init__(self, cassandra_conn_id: str = 'cassandra_default'):
86+
conn_name_attr = 'cassandra_conn_id'
87+
default_conn_name = 'cassandra_default'
88+
conn_type = 'cassandra'
89+
90+
def __init__(self, cassandra_conn_id: str = default_conn_name):
8791
super().__init__()
8892
conn = self.get_connection(cassandra_conn_id)
8993

β€Žairflow/providers/apache/cassandra/provider.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,6 @@ hooks:
4141
- integration-name: Apache Cassandra
4242
python-modules:
4343
- airflow.providers.apache.cassandra.hooks.cassandra
44+
45+
hook-class-names:
46+
- airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook

β€Žairflow/providers/apache/hive/hooks/hive.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,13 @@ class HiveCliHook(BaseHook):
7878
:type mapred_job_name: str
7979
"""
8080

81+
conn_name_attr = 'hive_cli_conn_id'
82+
default_conn_name = 'hive_cli_default'
83+
conn_type = 'hive_cli'
84+
8185
def __init__(
8286
self,
83-
hive_cli_conn_id: str = "hive_cli_default",
87+
hive_cli_conn_id: str = default_conn_name,
8488
run_as: Optional[str] = None,
8589
mapred_queue: Optional[str] = None,
8690
mapred_queue_priority: Optional[str] = None,
@@ -809,6 +813,7 @@ class HiveServer2Hook(DbApiHook):
809813

810814
conn_name_attr = 'hiveserver2_conn_id'
811815
default_conn_name = 'hiveserver2_default'
816+
conn_type = 'hiveserver2'
812817
supports_autocommit = False
813818

814819
def get_conn(self, schema: Optional[str] = None) -> Any:

β€Žairflow/providers/apache/hive/provider.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,7 @@ transfers:
6666
- source-integration-name: Microsoft SQL Server (MSSQL)
6767
target-integration-name: Apache Hive
6868
python-module: airflow.providers.apache.hive.transfers.mssql_to_hive
69+
70+
hook-class-names:
71+
- airflow.providers.apache.hive.hooks.hive.HiveCliHook
72+
- airflow.providers.apache.hive.hooks.hive.HiveServer2Hook

0 commit comments

Comments
 (0)