@@ -426,11 +426,48 @@ def test_cloud_diff(self, mock_request, mock_os_environ, mock_print):
426
426
427
427
mock_request .assert_called_once ()
428
428
mock_print .assert_called_once ()
429
+ request_endpoint = mock_request .call_args [0 ][1 ]
429
430
request_data_dict = mock_request .call_args [1 ]["json" ]
430
431
self .assertEqual (
431
432
mock_request .call_args [1 ]["headers" ]["Authorization" ],
432
433
"Key " + expected_api_key ,
433
434
)
435
+ self .assertEqual (request_endpoint , f'https://app.datafold.com/api/v1/datadiffs' )
436
+ self .assertEqual (request_data_dict ["data_source1_id" ], expected_datasource_id )
437
+ self .assertEqual (request_data_dict ["data_source2_id" ], expected_datasource_id )
438
+ self .assertEqual (request_data_dict ["table1" ], prod_qualified_list )
439
+ self .assertEqual (request_data_dict ["table2" ], dev_qualified_list )
440
+ self .assertEqual (request_data_dict ["pk_columns" ], expected_primary_keys )
441
+
442
+ @patch ("data_diff.dbt.rich.print" )
443
+ @patch ("data_diff.dbt.os.environ" )
444
+ @patch ("data_diff.dbt.requests.request" )
445
+ def test_cloud_diff_host_name_override (self , mock_request , mock_os_environ , mock_print ):
446
+ expected_api_key = "an_api_key"
447
+ mock_response = Mock ()
448
+ mock_response .json .return_value = {"id" : 123 }
449
+ mock_request .return_value = mock_response
450
+ mock_os_environ .get .return_value = expected_api_key
451
+ dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
452
+ prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
453
+ expected_datasource_id = 1
454
+ expected_primary_keys = ["primary_key_column" ]
455
+ diff_vars = DiffVars (
456
+ dev_qualified_list , prod_qualified_list , expected_primary_keys , expected_datasource_id , None , None
457
+ )
458
+ host_name = "a_host_name"
459
+ _cloud_diff (diff_vars , host_name )
460
+
461
+ mock_request .assert_called_once ()
462
+ mock_print .assert_called_once ()
463
+
464
+ request_endpoint = mock_request .call_args [0 ][1 ]
465
+ request_data_dict = mock_request .call_args [1 ]["json" ]
466
+ self .assertEqual (
467
+ mock_request .call_args [1 ]["headers" ]["Authorization" ],
468
+ "Key " + expected_api_key ,
469
+ )
470
+ self .assertEqual (request_endpoint , f'https://{ host_name } /api/v1/datadiffs' )
434
471
self .assertEqual (request_data_dict ["data_source1_id" ], expected_datasource_id )
435
472
self .assertEqual (request_data_dict ["data_source2_id" ], expected_datasource_id )
436
473
self .assertEqual (request_data_dict ["table1" ], prod_qualified_list )
@@ -500,7 +537,35 @@ def test_diff_is_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, mock_
500
537
mock_dbt_parser_inst .get_models .assert_called_once ()
501
538
mock_dbt_parser_inst .set_connection .assert_not_called ()
502
539
503
- mock_cloud_diff .assert_called_once_with (expected_diff_vars )
540
+ mock_cloud_diff .assert_called_once_with (expected_diff_vars , None )
541
+ mock_local_diff .assert_not_called ()
542
+ mock_print .assert_called_once ()
543
+
544
+ @patch ("data_diff.dbt._get_diff_vars" )
545
+ @patch ("data_diff.dbt._local_diff" )
546
+ @patch ("data_diff.dbt._cloud_diff" )
547
+ @patch ("data_diff.dbt.DbtParser.__new__" )
548
+ @patch ("data_diff.dbt.rich.print" )
549
+ def test_diff_is_cloud (self , mock_print , mock_dbt_parser , mock_cloud_diff , mock_local_diff , mock_get_diff_vars ):
550
+ mock_dbt_parser_inst = Mock ()
551
+ mock_model = Mock ()
552
+ expected_dbt_vars_dict = {
553
+ "prod_database" : "prod_db" ,
554
+ "prod_schema" : "prod_schema" ,
555
+ "datasource_id" : 1 ,
556
+ }
557
+ host_name = 'a_host_name'
558
+
559
+ mock_dbt_parser .return_value = mock_dbt_parser_inst
560
+ mock_dbt_parser_inst .get_models .return_value = [mock_model ]
561
+ mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
562
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], 123 , None , None )
563
+ mock_get_diff_vars .return_value = expected_diff_vars
564
+ dbt_diff (is_cloud = True , cloud_host_name = host_name )
565
+ mock_dbt_parser_inst .get_models .assert_called_once ()
566
+ mock_dbt_parser_inst .set_connection .assert_not_called ()
567
+
568
+ mock_cloud_diff .assert_called_once_with (expected_diff_vars , host_name )
504
569
mock_local_diff .assert_not_called ()
505
570
mock_print .assert_called_once ()
506
571
0 commit comments