10 changes: 8 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -1926,7 +1926,8 @@ def __init__(
ignore_insert_ids=False,
# TODO(BEAM-11857): Switch the default when the feature is mature.
with_auto_sharding=False,
ignore_unknown_columns=False):
ignore_unknown_columns=False,
load_job_project_id=None):
"""Initialize a WriteToBigQuery transform.

Args:
@@ -2058,6 +2059,9 @@ def __init__(
which treats unknown values as errors. This option is only valid for
method=STREAMING_INSERTS. See reference:
https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll
load_job_project_id: Specifies an alternate GCP project id to use for
billing Batch File Loads. By default, the project id of the table is
used.
"""
self._table = table
self._dataset = dataset
@@ -2092,6 +2096,7 @@ def __init__(
self.schema_side_inputs = schema_side_inputs or ()
self._ignore_insert_ids = ignore_insert_ids
self._ignore_unknown_columns = ignore_unknown_columns
self.load_job_project_id = load_job_project_id

# Dict/schema methods were moved to bigquery_tools, but keep references
# here for backward compatibility.
@@ -2185,7 +2190,8 @@ def expand(self, pcoll):
schema_side_inputs=self.schema_side_inputs,
additional_bq_parameters=self.additional_bq_parameters,
validate=self._validate,
is_streaming_pipeline=is_streaming_pipeline)
is_streaming_pipeline=is_streaming_pipeline,
load_job_project_id=self.load_job_project_id)

def display_data(self):
res = {}
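For reference, a minimal usage sketch of the new argument on the WriteToBigQuery side; the table spec, schema, bucket, and project ids are illustrative placeholders, not taken from this PR:

import apache_beam as beam

# Hedged sketch: rows land in a table owned by 'table-project', while the
# batch load jobs are billed to 'billing-project' via load_job_project_id.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'name': 'a', 'value': 1}, {'name': 'b', 'value': 2}])
      | beam.io.WriteToBigQuery(
          table='table-project:dataset.table',
          schema='name:STRING,value:INTEGER',
          method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
          create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
          write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
          custom_gcs_temp_location='gs://example-bucket/tmp',  # placeholder
          load_job_project_id='billing-project'))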
39 changes: 28 additions & 11 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -345,12 +345,14 @@ def __init__(
test_client=None,
additional_bq_parameters=None,
step_name=None,
source_format=None):
source_format=None,
load_job_project_id=None):
self._test_client = test_client
self._write_disposition = write_disposition
self._additional_bq_parameters = additional_bq_parameters or {}
self._step_name = step_name
self._source_format = source_format
self._load_job_project_id = load_job_project_id

def setup(self):
self._bq_wrapper = bigquery_tools.BigQueryWrapper(client=self._test_client)
@@ -439,7 +441,8 @@ def process(self, element, schema_mod_job_name_prefix):
create_disposition='CREATE_NEVER',
additional_load_parameters=additional_parameters,
job_labels=self._bq_io_metadata.add_additional_bq_job_labels(),
source_format=self._source_format)
source_format=self._source_format,
load_job_project_id=self._load_job_project_id)
yield (destination, schema_update_job_reference)


@@ -462,13 +465,15 @@ def __init__(
create_disposition=None,
write_disposition=None,
test_client=None,
step_name=None):
step_name=None,
load_job_project_id=None):
self.create_disposition = create_disposition
self.write_disposition = write_disposition
self.test_client = test_client
self._observed_tables = set()
self.bq_io_metadata = None
self._step_name = step_name
self.load_job_project_id = load_job_project_id

def display_data(self):
return {
@@ -527,8 +532,12 @@ def process(self, element, job_name_prefix=None, unused_schema_mod_jobs=None):

if not self.bq_io_metadata:
self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)

project_id = (
copy_to_reference.projectId
if self.load_job_project_id is None else self.load_job_project_id)
job_reference = self.bq_wrapper._insert_copy_job(
copy_to_reference.projectId,
project_id,
copy_job_name,
copy_from_reference,
copy_to_reference,
@@ -559,14 +568,16 @@ def __init__(
temporary_tables=False,
additional_bq_parameters=None,
source_format=None,
step_name=None):
step_name=None,
load_job_project_id=None):
self.schema = schema
self.test_client = test_client
self.temporary_tables = temporary_tables
self.additional_bq_parameters = additional_bq_parameters or {}
self.source_format = source_format
self.bq_io_metadata = None
self._step_name = step_name
self.load_job_project_id = load_job_project_id
if self.temporary_tables:
# If we are loading into temporary tables, we rely on the default create
# and write dispositions, which mean that a new table will be created.
@@ -663,7 +674,8 @@ def process(self, element, load_job_name_prefix, *schema_side_inputs):
create_disposition=create_disposition,
additional_load_parameters=additional_parameters,
source_format=self.source_format,
job_labels=self.bq_io_metadata.add_additional_bq_job_labels())
job_labels=self.bq_io_metadata.add_additional_bq_job_labels(),
load_job_project_id=self.load_job_project_id)
yield (destination, job_reference)


@@ -789,7 +801,8 @@ def __init__(
schema_side_inputs=None,
test_client=None,
validate=True,
is_streaming_pipeline=False):
is_streaming_pipeline=False,
load_job_project_id=None):
self.destination = destination
self.create_disposition = create_disposition
self.write_disposition = write_disposition
@@ -823,6 +836,7 @@ def __init__(
self.schema_side_inputs = schema_side_inputs or ()

self.is_streaming_pipeline = is_streaming_pipeline
self.load_job_project_id = load_job_project_id
self._validate = validate
if self._validate:
self.verify()
Expand Down Expand Up @@ -1005,7 +1019,8 @@ def _load_data(
temporary_tables=True,
additional_bq_parameters=self.additional_bq_parameters,
source_format=self._temp_file_format,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
load_job_name_pcv,
*self.schema_side_inputs).with_outputs(
TriggerLoadJobs.TEMP_TABLES, main='main'))
@@ -1029,7 +1044,7 @@ def _load_data(
additional_bq_parameters=self.additional_bq_parameters,
step_name=step_name,
source_format=self._temp_file_format,
),
load_job_project_id=self.load_job_project_id),
schema_mod_job_name_pcv))

finished_schema_mod_jobs_pc = (
@@ -1046,7 +1061,8 @@ def _load_data(
create_disposition=self.create_disposition,
write_disposition=self.write_disposition,
test_client=self.test_client,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
copy_job_name_pcv,
pvalue.AsIter(finished_schema_mod_jobs_pc)))

@@ -1084,7 +1100,8 @@ def _load_data(
temporary_tables=False,
additional_bq_parameters=self.additional_bq_parameters,
source_format=self._temp_file_format,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
load_job_name_pcv,
*self.schema_side_inputs))

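The change applies one precedence rule in several places (TriggerCopyJobs above, and BigQueryWrapper.perform_load_job further below): bill the job to load_job_project_id when it is set, otherwise fall back to the destination table's project. A standalone sketch of that rule; the helper name is illustrative, since the PR inlines the expression:

def _billing_project(destination_project_id, load_job_project_id=None):
  # Same fallback the PR inlines: prefer the explicit override, else the
  # project that owns the destination table.
  return (
      destination_project_id
      if load_job_project_id is None else load_job_project_id)

assert _billing_project('table-project') == 'table-project'
assert _billing_project('table-project', 'billing-project') == 'billing-project'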
84 changes: 84 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -459,6 +459,90 @@ def test_records_traverse_transform_with_mocks(self):

assert_that(jobs, equal_to([job_reference]), label='CheckJobs')

def test_load_job_id_used(self):
job_reference = bigquery_api.JobReference()
job_reference.projectId = 'loadJobProject'
job_reference.jobId = 'job_name1'

result_job = bigquery_api.Job()
result_job.jobReference = job_reference

mock_job = mock.Mock()
mock_job.status.state = 'DONE'
mock_job.status.errorResult = None
mock_job.jobReference = job_reference

bq_client = mock.Mock()
bq_client.jobs.Get.return_value = mock_job

bq_client.jobs.Insert.return_value = result_job

transform = bqfl.BigQueryBatchFileLoads(
'project1:dataset1.table1',
custom_gcs_temp_location=self._new_tempdir(),
test_client=bq_client,
validate=False,
load_job_project_id='loadJobProject')

with TestPipeline('DirectRunner') as p:
outputs = p | beam.Create(_ELEMENTS) | transform
jobs = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS] \
| "GetJobs" >> beam.Map(lambda x: x[1])

assert_that(jobs, equal_to([job_reference]), label='CheckJobProjectIds')

def test_load_job_id_use_for_copy_job(self):
destination = 'project1:dataset1.table1'

job_reference = bigquery_api.JobReference()
job_reference.projectId = 'loadJobProject'
job_reference.jobId = 'job_name1'
result_job = mock.Mock()
result_job.jobReference = job_reference

mock_job = mock.Mock()
mock_job.status.state = 'DONE'
mock_job.status.errorResult = None
mock_job.jobReference = job_reference

bq_client = mock.Mock()
bq_client.jobs.Get.return_value = mock_job

bq_client.jobs.Insert.return_value = result_job
bq_client.tables.Delete.return_value = None

with TestPipeline('DirectRunner') as p:
outputs = (
p
| beam.Create(_ELEMENTS, reshuffle=False)
| bqfl.BigQueryBatchFileLoads(
destination,
custom_gcs_temp_location=self._new_tempdir(),
test_client=bq_client,
validate=False,
temp_file_format=bigquery_tools.FileFormat.JSON,
max_file_size=45,
max_partition_size=80,
max_files_per_partition=2,
load_job_project_id='loadJobProject'))

dest_copy_jobs = outputs[
bqfl.BigQueryBatchFileLoads.DESTINATION_COPY_JOBID_PAIRS]

copy_jobs = dest_copy_jobs | "GetCopyJobs" >> beam.Map(lambda x: x[1])

assert_that(
copy_jobs,
equal_to([
job_reference,
job_reference,
job_reference,
job_reference,
job_reference,
job_reference
]),
label='CheckCopyJobProjectIds')

@mock.patch('time.sleep')
def test_wait_for_job_completion(self, sleep_mock):
job_references = [bigquery_api.JobReference(), bigquery_api.JobReference()]
9 changes: 7 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -988,7 +988,8 @@ def perform_load_job(
create_disposition=None,
additional_load_parameters=None,
source_format=None,
job_labels=None):
job_labels=None,
load_job_project_id=None):
"""Starts a job to load data into BigQuery.

Returns:
@@ -1005,8 +1006,12 @@
'Only one of source_uris and source_stream may be specified. '
'Got both.')

project_id = (
destination.projectId
if load_job_project_id is None else load_job_project_id)

return self._insert_load_job(
destination.projectId,
project_id,
job_id,
destination,
source_uris=source_uris,
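At the wrapper level the override can be exercised directly. A hedged sketch with placeholder arguments (a real BigQuery client and GCS files are required; this mirrors what the new unit test below asserts through a mocked client):

from apache_beam.io.gcp import bigquery_tools

# Illustrative only: 'billing-project', the table spec, and the GCS glob are
# placeholders. With load_job_project_id set, the load job is inserted under
# 'billing-project' rather than under the table's own project.
wrapper = bigquery_tools.BigQueryWrapper()
wrapper.perform_load_job(
    destination=bigquery_tools.parse_table_reference(
        'table-project:dataset.table'),
    job_id='example_load_job',
    source_uris=['gs://example-bucket/files/*'],
    load_job_project_id='billing-project')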
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
@@ -430,6 +430,18 @@ def test_perform_load_job_with_source_stream(self):
upload = client.jobs.Insert.call_args[1]["upload"]
self.assertEqual(b'some,data', upload.stream.read())

def test_perform_load_job_with_load_job_id(self):
client = mock.Mock()
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

wrapper.perform_load_job(
destination=parse_table_reference('project:dataset.table'),
job_id='job_id',
source_uris=['gs://example.com/*'],
load_job_project_id='loadId')
call_args = client.jobs.Insert.call_args
self.assertEqual('loadId', call_args[0][0].projectId)

def verify_write_call_metric(
self, project_id, dataset_id, table_id, status, count):
"""Check if an metric was recorded for the BQ IO write API call."""