 from airflow.models.xcom import XCom
 from airflow.providers.common.sql.operators.sql import (
     SQLCheckOperator,
+    SQLColumnCheckOperator,
     SQLIntervalCheckOperator,
+    SQLTableCheckOperator,
     SQLValueCheckOperator,
+    _get_failed_checks,
+    parse_boolean,
 )
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -520,6 +524,241 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         )


+class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
+    """
+    BigQueryColumnCheckOperator subclasses the SQLColumnCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    docstring for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryColumnCheckOperator`
+
+    :param table: the table name
+    :param column_mapping: a dictionary relating columns to their checks
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        column_mapping: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            table=table, column_mapping=column_mapping, partition_clause=partition_clause, **kwargs
+        )
+        self.table = table
+        self.column_mapping = column_mapping
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Perform checks on the given columns."""
+        hook = self.get_db_hook()
+        failed_tests = []
+        for column in self.column_mapping:
+            checks = [*self.column_mapping[column]]
+            checks_sql = ",".join([self.column_checks[check].replace("column", column) for check in checks])
+            partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+            self.sql = f"SELECT {checks_sql} FROM {self.table} {partition_clause_statement};"
+
+            job_id = hook.generate_job_id(
+                dag_id=self.dag_id,
+                task_id=self.task_id,
+                logical_date=context["logical_date"],
+                configuration={"query": {"query": self.sql}},
+            )
+            job = self._submit_job(hook, job_id=job_id)
+            context["ti"].xcom_push(key="job_id", value=job.job_id)
+            records = list(job.result().to_dataframe().values.flatten())
+
+            if not records:
+                raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+            self.log.info("Record: %s", records)
+
+            for idx, result in enumerate(records):
+                tolerance = self.column_mapping[column][checks[idx]].get("tolerance")
+
+                self.column_mapping[column][checks[idx]]["result"] = result
+                self.column_mapping[column][checks[idx]]["success"] = self._get_match(
+                    self.column_mapping[column][checks[idx]], result, tolerance
+                )
+
+            failed_tests.extend(_get_failed_checks(self.column_mapping[column], column))
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{''.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
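For context, a minimal usage sketch of the operator above. The DAG, table, and thresholds are illustrative, and the `column_mapping` format is assumed to follow the base `SQLColumnCheckOperator` (check names such as `null_check` and `min` mapped to comparison values):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryColumnCheckOperator

with DAG(dag_id="example_bq_column_check", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    # Hypothetical checks: "id" must contain no NULLs and "price" must be strictly positive.
    check_columns = BigQueryColumnCheckOperator(
        task_id="check_columns",
        table="my-project.my_dataset.my_table",
        column_mapping={
            "id": {"null_check": {"equal_to": 0}},
            "price": {"min": {"greater_than": 0}},
        },
        partition_clause="event_date = '2022-01-01'",  # optional WHERE filter
        use_legacy_sql=False,  # standard SQL for the project.dataset.table reference
    )
```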
+class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
+    """
+    BigQueryTableCheckOperator subclasses the SQLTableCheckOperator
+    in order to provide a job id for OpenLineage to parse. See base class
+    for usage.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BigQueryTableCheckOperator`
+
+    :param table: the table name
+    :param checks: a dictionary of check names and boolean SQL statements
+    :param partition_clause: a string SQL statement added to a WHERE clause
+        to partition data
+    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
+    :param use_legacy_sql: Whether to use legacy SQL (true)
+        or standard SQL (false).
+    :param location: The geographic location of the job. See details at:
+        https://cloud.google.com/bigquery/docs/locations#specifying_your_location
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param labels: a dictionary containing labels for the table, passed to BigQuery
+    """
+
+    def __init__(
+        self,
+        *,
+        table: str,
+        checks: dict,
+        partition_clause: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        use_legacy_sql: bool = True,
+        location: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        labels: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
+        self.table = table
+        self.checks = checks
+        self.partition_clause = partition_clause
+        self.gcp_conn_id = gcp_conn_id
+        self.use_legacy_sql = use_legacy_sql
+        self.location = location
+        self.impersonation_chain = impersonation_chain
+        self.labels = labels
+        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
+        self.sql = ""
+
+    def _submit_job(
+        self,
+        hook: BigQueryHook,
+        job_id: str,
+    ) -> BigQueryJob:
+        """Submit a new job and get the job id for polling the status using Trigger."""
+        configuration = {"query": {"query": self.sql}}
+
+        return hook.insert_job(
+            configuration=configuration,
+            project_id=hook.project_id,
+            location=self.location,
+            job_id=job_id,
+            nowait=False,
+        )
+
+    def execute(self, context=None):
+        """Execute the given checks on the table."""
+        hook = self.get_db_hook()
+        checks_sql = " UNION ALL ".join(
+            [
+                self.sql_check_template.replace("check_statement", value["check_statement"])
+                .replace("_check_name", check_name)
+                .replace("table", self.table)
+                for check_name, value in self.checks.items()
+            ]
+        )
+        partition_clause_statement = f"WHERE {self.partition_clause}" if self.partition_clause else ""
+        self.sql = (
+            f"SELECT check_name, check_result FROM ({checks_sql}) "
+            f"AS check_table {partition_clause_statement};"
+        )
+
+        job_id = hook.generate_job_id(
+            dag_id=self.dag_id,
+            task_id=self.task_id,
+            logical_date=context["logical_date"],
+            configuration={"query": {"query": self.sql}},
+        )
+        job = self._submit_job(hook, job_id=job_id)
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
+        records = job.result().to_dataframe()
+
+        if records.empty:
+            raise AirflowException(f"The following query returned zero rows: {self.sql}")
+
+        records.columns = records.columns.str.lower()
+        self.log.info("Record:\n%s", records)
+
+        for row in records.iterrows():
+            check = row[1].get("check_name")
+            result = row[1].get("check_result")
+            self.checks[check]["success"] = parse_boolean(str(result))
+
+        failed_tests = _get_failed_checks(self.checks)
+        if failed_tests:
+            raise AirflowException(
+                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+                "The following tests have failed:"
+                f"\n{', '.join(failed_tests)}"
+            )
+
+        self.log.info("All tests have passed")
+
+
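As with the column operator, a minimal usage sketch under the same assumptions (table name and check statements are illustrative; each entry in `checks` supplies a boolean `check_statement` that is evaluated over the whole table):

```python
from airflow.providers.google.cloud.operators.bigquery import BigQueryTableCheckOperator

# Hypothetical task: fail the run if the table is empty or total revenue is not positive.
check_table = BigQueryTableCheckOperator(
    task_id="check_table",
    table="my-project.my_dataset.my_table",
    checks={
        "row_count_check": {"check_statement": "COUNT(*) >= 1"},
        "revenue_check": {"check_statement": "SUM(revenue) > 0"},
    },
    use_legacy_sql=False,
)
```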
 class BigQueryGetDataOperator(BaseOperator):
     """
     Fetches the data from a BigQuery table (alternatively fetch data for selected columns)