
let BigQueryGetData operator take a query string and as_dict flag #24460

@MazrimT

Description


Today, `airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator` only lets you point at a specific dataset and table and say how many rows you want.

It already sets up a BigQueryHook, so it would be easy to also support running a custom query from a string.
It would also be convenient to have an as_dict flag that returns the result as a list of dicts.
I am not an expert in Python, but here is my attempt at a modification of the current code (from 8.0.0rc2):

import json
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import BigQueryUIColors

if TYPE_CHECKING:
    from airflow.utils.context import Context


class BigQueryGetDataOperatorX(BaseOperator):
    """
    Fetches data from a BigQuery table (alternatively, fetches data for selected columns) or
    from the result of a custom SQL query, and returns the data in a Python list. The number
    of elements in the returned list will be equal to the number of rows fetched. Each element
    in the list will again be a list, where the elements represent the column values for that
    row. If ``as_dict`` is True, each element is instead a dictionary keyed by column name.

    **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:BigQueryGetDataOperator`

    .. note::
        If you pass fields to ``selected_fields`` which are in different order than the
        order of columns already in
        BQ table, the data will still be in the order of BQ table.
        For example if the BQ table has 3 columns as
        ``[A,B,C]`` and you pass 'B,A' in the ``selected_fields``
        the data would still be of the form ``'A,B'``.

    **Example**: ::

        get_data = BigQueryGetDataOperator(
            task_id='get_data_from_bq',
            dataset_id='test_dataset',
            table_id='Transaction_partitions',
            max_results=100,
            selected_fields='DATE',
            gcp_conn_id='airflow-conn-id'
        )

    :param dataset_id: The dataset ID of the requested table. (templated)
    :param table_id: The table ID of the requested table. (templated)
    :param max_results: The maximum number of records (rows) to be fetched
        from the table. (templated)
    :param selected_fields: List of fields to return (comma-separated). If
        unspecified, all fields are returned.
    :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
        if any. For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :param location: The location used for the operation.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param query: (Optional) A SQL query to execute instead of reading from a specific table.
        If set, ``dataset_id`` and ``table_id`` are ignored.
    :param as_dict: If True, return the result as a list of dictionaries keyed by column name.
        Defaults to False.
    """

    template_fields: Sequence[str] = (
        'dataset_id',
        'table_id',
        'max_results',
        'selected_fields',
        'impersonation_chain',
    )
    ui_color = BigQueryUIColors.QUERY.value

    def __init__(
        self,
        *,
        dataset_id: Optional[str] = None,
        table_id: Optional[str] = None,
        max_results: Optional[int] = 100,
        selected_fields: Optional[str] = None,
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        location: Optional[str] = None,
        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
        query: Optional[str] = None,
        as_dict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.dataset_id = dataset_id
        self.table_id = table_id
        self.max_results = int(max_results)
        self.selected_fields = selected_fields
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.location = location
        self.impersonation_chain = impersonation_chain
        self.query = query
        self.as_dict = as_dict

        if not self.query and not (self.dataset_id and self.table_id):
            raise ValueError(
                'Please provide either a query string or both a dataset_id and a table_id.'
            )

    def execute(self, context: 'Context') -> list:
        if self.query:
            self.log.info('Fetching data from query, max results: %s', self.max_results)
        else:
            self.log.info(
                'Fetching data from %s.%s, max results: %s',
                self.dataset_id,
                self.table_id,
                self.max_results,
            )

        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
            location=self.location,
        )

        if not self.query:
            # Read rows directly from the table. If no fields were selected, take the column
            # names from the table schema so that as_dict has keys to work with.
            if not self.selected_fields:
                schema: Dict[str, list] = hook.get_schema(
                    dataset_id=self.dataset_id,
                    table_id=self.table_id,
                )
                if "fields" in schema:
                    self.selected_fields = ','.join(field["name"] for field in schema["fields"])

            rows = hook.list_rows(
                dataset_id=self.dataset_id,
                table_id=self.table_id,
                max_results=self.max_results,
                selected_fields=self.selected_fields,
            )

            if self.as_dict:
                field_names = self.selected_fields.split(',')
                table_data = [dict(zip(field_names, row)) for row in rows]
            else:
                table_data = [row.values() for row in rows]

        else:
            # Run the custom query through the hook's DB-API cursor and fetch at most
            # max_results rows.
            cursor = hook.get_conn().cursor()
            cursor.execute(self.query)
            rows = cursor.fetchmany(self.max_results)

            if self.as_dict:
                # Assumes the cursor exposes DB-API ``description`` tuples whose first item
                # is the column name.
                field_names = [field[0] for field in cursor.description]
                table_data = [dict(zip(field_names, row)) for row in rows]
            else:
                table_data = list(rows)

        self.log.info('Total extracted rows: %s', len(table_data))

        return table_data
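
For illustration, this is roughly how the operator could be used once the new parameters exist. The dataset, table, query, task ids and connection id below are made up, so treat it as a sketch rather than a tested DAG:

from datetime import datetime

from airflow import DAG

with DAG(
    dag_id='bq_get_data_example',
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    # Existing behaviour: read rows straight from a table.
    get_table_rows = BigQueryGetDataOperatorX(
        task_id='get_table_rows',
        dataset_id='test_dataset',
        table_id='Transaction_partitions',
        max_results=100,
        selected_fields='DATE',
        gcp_conn_id='airflow-conn-id',
    )

    # Proposed behaviour: run an arbitrary query and get the rows back as dicts.
    get_query_rows = BigQueryGetDataOperatorX(
        task_id='get_query_rows',
        query='SELECT DATE, amount FROM test_dataset.Transaction_partitions LIMIT 100',
        as_dict=True,
        gcp_conn_id='airflow-conn-id',
    )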

Use case/motivation

This would simplify getting data from BigQuery into Airflow, instead of having to first store the data in a separate table with BigQueryInsertJobOperator and then fetch that.
It would also simplify handling the data, since as_dict behaves the same way as many other database connectors in Python do.
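
As a rough sketch of what that could look like downstream (the task and column names come from the made-up example above), the operator's return value could be consumed straight from XCom, for example with the TaskFlow API:

from airflow.decorators import task

@task
def summarize(rows: list):
    # With as_dict=True each row is a plain dict keyed by column name,
    # so downstream code can address columns by name instead of position.
    for row in rows:
        print(row['DATE'])

# get_query_rows is the task from the example above; .output is its XCom.
summarize(get_query_rows.output)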

Related issues

No response

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!

Code of Conduct

  • I agree to follow this project's Code of Conduct
