Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.io.Serializable;
import java.net.URLClassLoader;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand Down Expand Up @@ -451,6 +452,9 @@ public abstract static class DataSourceConfiguration implements Serializable {
@Pure
abstract @Nullable ClassLoader getDriverClassLoader();

@Pure
abstract @Nullable ValueProvider<String> getDriverJars();

@Pure
abstract @Nullable DataSource getDataSource();

Expand All @@ -476,6 +480,8 @@ abstract Builder setConnectionInitSqls(

abstract Builder setDriverClassLoader(ClassLoader driverClassLoader);

abstract Builder setDriverJars(ValueProvider<String> driverJars);

abstract Builder setDataSource(@Nullable DataSource dataSource);

abstract DataSourceConfiguration build();
Expand Down Expand Up @@ -583,17 +589,36 @@ public DataSourceConfiguration withDriverClassLoader(ClassLoader driverClassLoad
return builder().setDriverClassLoader(driverClassLoader).build();
}

/**
* Comma separated paths for JDBC drivers. This method is filesystem agnostic and can be used
* for all FileSystems supported by Beam If not specified, the default classloader is used to
* load the jars.
*
* <p>For example, gs://your-bucket/driver_jar1.jar,gs://your-bucket/driver_jar2.jar.
*/
public DataSourceConfiguration withDriverJars(String driverJars) {
checkArgument(driverJars != null, "driverJars can not be null");
return withDriverJars(ValueProvider.StaticValueProvider.of(driverJars));
}

/** Same as {@link #withDriverJars(String)} but accepting a ValueProvider. */
public DataSourceConfiguration withDriverJars(ValueProvider<String> driverJars) {
checkArgument(driverJars != null, "driverJars can not be null");
return builder().setDriverJars(driverJars).build();
}

void populateDisplayData(DisplayData.Builder builder) {
if (getDataSource() != null) {
builder.addIfNotNull(DisplayData.item("dataSource", getDataSource().getClass().getName()));
} else {
builder.addIfNotNull(DisplayData.item("jdbcDriverClassName", getDriverClassName()));
builder.addIfNotNull(DisplayData.item("jdbcUrl", getUrl()));
builder.addIfNotNull(DisplayData.item("username", getUsername()));
builder.addIfNotNull(DisplayData.item("driverJars", getDriverJars()));
}
}

DataSource buildDatasource() {
public DataSource buildDatasource() {
if (getDataSource() == null) {
BasicDataSource basicDataSource = new BasicDataSource();
if (getDriverClassName() != null) {
Expand Down Expand Up @@ -630,6 +655,11 @@ && getConnectionInitSqls().get() != null
if (getDriverClassLoader() != null) {
basicDataSource.setDriverClassLoader(getDriverClassLoader());
}
if (getDriverJars() != null) {
URLClassLoader classLoader =
URLClassLoader.newInstance(JdbcUtil.saveFilesLocally(getDriverJars().get()));
basicDataSource.setDriverClassLoader(classLoader);
}

return basicDataSource;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public Schema configurationSchema() {
// readQuery
.addNullableField("partitionColumn", FieldType.STRING)
.addNullableField("partitions", FieldType.INT16)
.addNullableField("maxConnections", FieldType.INT16)
.addNullableField("driverJars", FieldType.STRING)
.build();
}

Expand Down Expand Up @@ -202,11 +204,14 @@ protected JdbcIO.DataSourceConfiguration getDataSourceConfiguration() {
dataSourceConfiguration = dataSourceConfiguration.withConnectionInitSqls(initSqls);
}

if (config.getSchema().hasField("maxConnections")) {
@Nullable Integer maxConnections = config.getInt32("maxConnections");
if (maxConnections != null) {
dataSourceConfiguration = dataSourceConfiguration.withMaxConnections(maxConnections);
}
@Nullable Integer maxConnections = config.getInt32("maxConnections");
if (maxConnections != null) {
dataSourceConfiguration = dataSourceConfiguration.withMaxConnections(maxConnections);
}

@Nullable String driverJars = config.getString("driverJars");
if (driverJars != null) {
dataSourceConfiguration = dataSourceConfiguration.withDriverJars(driverJars);
}

return dataSourceConfiguration;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.file.Paths;
import java.sql.Date;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
Expand All @@ -38,19 +44,25 @@
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter;
import org.apache.beam.sdk.io.jdbc.JdbcIO.ReadWithPartitions;
import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.FixedPrecisionNumeric;
import org.apache.beam.sdk.schemas.logicaltypes.MicrosInstant;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.util.MimeTypes;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Files;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.DateTime;
import org.joda.time.Duration;
Expand All @@ -61,6 +73,42 @@
/** Provides utility functions for working with {@link JdbcIO}. */
class JdbcUtil {

private static final Logger LOG = LoggerFactory.getLogger(JdbcUtil.class);

/** Utility method to save jar files locally in the worker. */
static URL[] saveFilesLocally(String driverJars) {
List<String> listOfJarPaths = Splitter.on(',').trimResults().splitToList(driverJars);

final String destRoot = Files.createTempDir().getAbsolutePath();
List<URL> driverJarUrls = new ArrayList<>();
listOfJarPaths.stream()
.forEach(
jarPath -> {
try {
ResourceId sourceResourceId = FileSystems.matchNewResource(jarPath, false);
@SuppressWarnings("nullness")
File destFile = Paths.get(destRoot, sourceResourceId.getFilename()).toFile();
ResourceId destResourceId =
FileSystems.matchNewResource(destFile.getAbsolutePath(), false);
copy(sourceResourceId, destResourceId);
LOG.info("Localized jar: " + sourceResourceId + " to: " + destResourceId);
driverJarUrls.add(destFile.toURI().toURL());
} catch (IOException e) {
LOG.warn("Unable to copy " + jarPath, e);
}
});
return driverJarUrls.stream().toArray(URL[]::new);
}

/** utility method to copy binary (jar file) data from source to dest. */
private static void copy(ResourceId source, ResourceId dest) throws IOException {
try (ReadableByteChannel rbc = FileSystems.open(source)) {
try (WritableByteChannel wbc = FileSystems.create(dest, MimeTypes.BINARY)) {
ByteStreams.copy(rbc, wbc);
}
}
}

/** Generates an insert statement based on {@link Schema.Field}. * */
static String generateStatement(String tableName, List<Schema.Field> fields) {
String fieldNames =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

/** Provides utility functions for working with Beam {@link Schema} types. */
@Experimental(Kind.SCHEMAS)
class SchemaUtil {
public class SchemaUtil {
/**
* Interface implemented by functions that extract values of different types from a JDBC
* ResultSet.
Expand Down Expand Up @@ -178,7 +178,7 @@ private static BeamFieldConverter jdbcTypeToBeamFieldConverter(
}

/** Infers the Beam {@link Schema} from {@link ResultSetMetaData}. */
static Schema toBeamSchema(ResultSetMetaData md) throws SQLException {
public static Schema toBeamSchema(ResultSetMetaData md) throws SQLException {
Schema.Builder schemaBuilder = Schema.builder();

for (int i = 1; i <= md.getColumnCount(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
Expand All @@ -36,14 +42,18 @@
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test JdbcUtil. */
@RunWith(JUnit4.class)
public class JdbcUtilTest {

@Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();

// TODO(BEAM-13846): Support string-based partitioning once the transform supports modifying
// range properties (inclusive/exclusive).
static final JdbcReadWithPartitionsHelper<String> PROTOTYPE_STRING_PARTITIONER =
Expand Down Expand Up @@ -236,4 +246,25 @@ public void testLongPartitioningNotEnoughRanges() {
assertEquals(4, ranges.size());
assertArrayEquals(expectedRanges.toArray(), ranges.toArray());
}

@Test
public void testSavesFilesAsExpected() throws IOException {
File tempFile1 = temporaryFolder.newFile();
File tempFile2 = temporaryFolder.newFile();
String expectedContent1 = "hello world";
String expectedContent2 = "hello world 2";
Files.write(tempFile1.toPath(), expectedContent1.getBytes(StandardCharsets.UTF_8));
Files.write(tempFile2.toPath(), expectedContent2.getBytes(StandardCharsets.UTF_8));

URL[] urls =
JdbcUtil.saveFilesLocally(tempFile1.getAbsolutePath() + "," + tempFile2.getAbsolutePath());

assertEquals(2, urls.length);
assertEquals(
expectedContent1,
new String(Files.readAllBytes(Paths.get(urls[0].getFile())), StandardCharsets.UTF_8));
assertEquals(
expectedContent2,
new String(Files.readAllBytes(Paths.get(urls[1].getFile())), StandardCharsets.UTF_8));
}
}
24 changes: 22 additions & 2 deletions sdks/python/apache_beam/io/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,13 @@ def default_io_expansion_service(classpath=None):
('connection_init_sqls', typing.Optional[typing.List[str]]),
('read_query', typing.Optional[str]),
('write_statement', typing.Optional[str]),
('fetch_size', typing.Optional[int]),
('fetch_size', typing.Optional[np.int16]),
('output_parallelization', typing.Optional[bool]),
('autosharding', typing.Optional[bool]),
('partition_column', typing.Optional[str]),
('partitions', typing.Optional[np.int16])],
('partitions', typing.Optional[np.int16]),
('max_connections', typing.Optional[np.int16]),
('driver_jars', typing.Optional[str])],
)

DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16']
Expand Down Expand Up @@ -176,6 +178,8 @@ def __init__(
connection_properties=None,
connection_init_sqls=None,
autosharding=False,
max_connections=None,
driver_jars=None,
expansion_service=None,
classpath=None,
):
Expand All @@ -194,6 +198,11 @@ def __init__(
passed as list of strings
:param autosharding: enable automatic re-sharding of bundles to scale the
number of shards with the number of workers.
:param max_connections: sets the maximum total number of connections.
use a negative value for no limit.
:param driver_jars: comma separated paths for JDBC drivers. if not
specified, the default classloader is used to load the
driver jars.
:param expansion_service: The address (host:port) of the ExpansionService.
:param classpath: A list of JARs or Java packages to include in the
classpath for the expansion service. This option is
Expand Down Expand Up @@ -225,6 +234,8 @@ def __init__(
fetch_size=None,
output_parallelization=None,
autosharding=autosharding,
max_connections=max_connections,
driver_jars=driver_jars,
partitions=None,
partition_column=None))),
),
Expand Down Expand Up @@ -277,6 +288,8 @@ def __init__(
partitions=None,
connection_properties=None,
connection_init_sqls=None,
max_connections=None,
driver_jars=None,
expansion_service=None,
classpath=None,
):
Expand All @@ -299,6 +312,11 @@ def __init__(
[propertyName=property;]*
:param connection_init_sqls: required only for MySql and MariaDB.
passed as list of strings
:param max_connections: sets the maximum total number of connections.
use a negative value for no limit.
:param driver_jars: comma separated paths for JDBC drivers. if not
specified, the default classloader is used to load the
driver jars.
:param expansion_service: The address (host:port) of the ExpansionService.
:param classpath: A list of JARs or Java packages to include in the
classpath for the expansion service. This option is
Expand Down Expand Up @@ -330,6 +348,8 @@ def __init__(
fetch_size=fetch_size,
output_parallelization=output_parallelization,
autosharding=None,
max_connections=max_connections,
driver_jars=driver_jars,
partition_column=partition_column,
partitions=partitions))),
),
Expand Down