Skip to content

Commit

Permalink
[Bug] [flink-connector-jdbc]change jdbc Source connector to get fiel…
Browse files Browse the repository at this point in the history
…ds from jdbc meta data and support oracle database (apache#1781)

* get fields from jdbc meta data
* remove regex pattern
* use StringUtils
  • Loading branch information
gleiyu authored May 6, 2022
1 parent ee6a987 commit 1a368dc
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 95 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.flink.jdbc.input;

import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.BIG_DEC_TYPE_INFO;
import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.DOUBLE_TYPE_INFO;
import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.FLOAT_TYPE_INFO;
import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.STRING_TYPE_INFO;
import static org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO;

import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;

import java.util.HashMap;
import java.util.Map;

public class OracleTypeInformationMap implements TypeInformationMap {

private static final Map<String, TypeInformation<?>> INFORMATION_MAP = new HashMap<>();

static {
INFORMATION_MAP.put("NVARCHAR2", STRING_TYPE_INFO);
INFORMATION_MAP.put("VARCHAR2", STRING_TYPE_INFO);
INFORMATION_MAP.put("FLOAT", DOUBLE_TYPE_INFO);
INFORMATION_MAP.put("NUMBER", BIG_DEC_TYPE_INFO);
INFORMATION_MAP.put("LONG", STRING_TYPE_INFO);
INFORMATION_MAP.put("DATE", SqlTimeTypeInfo.TIMESTAMP);
INFORMATION_MAP.put("RAW", BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
INFORMATION_MAP.put("LONG RAW", BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
INFORMATION_MAP.put("NCHAR", STRING_TYPE_INFO);
INFORMATION_MAP.put("CHAR", STRING_TYPE_INFO);
INFORMATION_MAP.put("BINARY_FLOAT", FLOAT_TYPE_INFO);
INFORMATION_MAP.put("BINARY_DOUBLE", DOUBLE_TYPE_INFO);
INFORMATION_MAP.put("ROWID", STRING_TYPE_INFO);
INFORMATION_MAP.put("NCLOB", STRING_TYPE_INFO);
INFORMATION_MAP.put("CLOB", STRING_TYPE_INFO);
INFORMATION_MAP.put("BLOB", BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
INFORMATION_MAP.put("BFILE", BYTE_PRIMITIVE_ARRAY_TYPE_INFO);
INFORMATION_MAP.put("TIMESTAMP", SqlTimeTypeInfo.TIMESTAMP);
INFORMATION_MAP.put("TIMESTAMP WITH TIME ZONE", SqlTimeTypeInfo.TIMESTAMP);
INFORMATION_MAP.put("TIMESTAMP WITH LOCAL TIME ZONE", SqlTimeTypeInfo.TIMESTAMP);
}

@Override
public TypeInformation<?> getInformation(String datatype) {
return INFORMATION_MAP.get(datatype);
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand Down Expand Up @@ -39,15 +40,16 @@
import org.apache.seatunnel.flink.jdbc.input.DefaultTypeInformationMap;
import org.apache.seatunnel.flink.jdbc.input.JdbcInputFormat;
import org.apache.seatunnel.flink.jdbc.input.MysqlTypeInformationMap;
import org.apache.seatunnel.flink.jdbc.input.OracleTypeInformationMap;
import org.apache.seatunnel.flink.jdbc.input.PostgresTypeInformationMap;
import org.apache.seatunnel.flink.jdbc.input.TypeInformationMap;

import org.apache.seatunnel.shade.com.typesafe.config.Config;

import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.connector.jdbc.split.JdbcNumericBetweenParametersProvider;
import org.apache.flink.connector.jdbc.split.JdbcParameterValuesProvider;
Expand All @@ -56,16 +58,14 @@
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class JdbcSource implements FlinkBatchSource {

Expand All @@ -74,19 +74,11 @@ public class JdbcSource implements FlinkBatchSource {
private static final int DEFAULT_FETCH_SIZE = 10000;

private Config config;
private String tableName;
private String driverName;
private String dbUrl;
private String username;
private String password;
private int fetchSize = DEFAULT_FETCH_SIZE;
private int parallelism = -1;
private Set<String> fields;
private Map<String, TypeInformation<?>> tableFieldInfo;

private static final Pattern COMPILE = Pattern.compile("[\\s]*select[\\s]*(.*)from[\\s]*([\\S]+)(.*)",
Pattern.CASE_INSENSITIVE);

private JdbcInputFormat jdbcInputFormat;

@Override
Expand Down Expand Up @@ -115,13 +107,11 @@ public CheckResult checkConfig() {

@Override
public void prepare(FlinkEnvironment env) {
driverName = config.getString(DRIVER);
dbUrl = config.getString(URL);
username = config.getString(USERNAME);
String driverName = config.getString(DRIVER);
String dbUrl = config.getString(URL);
String username = config.getString(USERNAME);
String query = config.getString(QUERY);
Tuple2<String, Set<String>> tableNameAndFields = getTableNameAndFields(COMPILE, query);
tableName = tableNameAndFields.f0;
fields = tableNameAndFields.f1;

if (config.hasPath(PASSWORD)) {
password = config.getString(PASSWORD);
}
Expand All @@ -140,26 +130,26 @@ public void prepare(FlinkEnvironment env) {
}

try (Connection connection = DriverManager.getConnection(dbUrl, username, password)) {
tableFieldInfo = initTableField(connection);
tableFieldInfo = initTableField(connection, query);
RowTypeInfo rowTypeInfo = getRowTypeInfo();
JdbcInputFormat.JdbcInputFormatBuilder builder = JdbcInputFormat.buildFlinkJdbcInputFormat();
if (config.hasPath(PARTITION_COLUMN)) {
if (!tableFieldInfo.containsKey(config.getString(PARTITION_COLUMN))) {
throw new IllegalArgumentException(String.format("field %s not contain in table %s",
config.getString(PARTITION_COLUMN), tableName));
String partitionColumn = config.getString(PARTITION_COLUMN);
if (!tableFieldInfo.containsKey(partitionColumn)) {
throw new IllegalArgumentException(String.format("field %s not contain in query sql %s",
partitionColumn, query));
}
if (!isNumericType(rowTypeInfo.getTypeAt(config.getString(PARTITION_COLUMN)))) {
throw new IllegalArgumentException(String.format("%s is not numeric type", PARTITION_COLUMN));
if (!isNumericType(rowTypeInfo.getTypeAt(partitionColumn))) {
throw new IllegalArgumentException(String.format("%s is not numeric type", partitionColumn));
}
JdbcParameterValuesProvider jdbcParameterValuesProvider =
initPartition(config.getString(PARTITION_COLUMN), connection);
initPartition(partitionColumn, connection, query);
builder.setParametersProvider(jdbcParameterValuesProvider);
query = extendPartitionQuerySql(query, config.getString(PARTITION_COLUMN));
query = String.format("SELECT * FROM (%s) tt where " + partitionColumn + " >= ? AND " + partitionColumn + " < ?", query);
}
builder.setDrivername(driverName).setDBUrl(dbUrl).setUsername(username)
.setPassword(password).setQuery(query).setFetchSize(fetchSize)
.setRowTypeInfo(rowTypeInfo);

jdbcInputFormat = builder.finish();
} catch (SQLException e) {
throw new RuntimeException("jdbc connection init failed.", e);
Expand All @@ -171,23 +161,7 @@ public String getPluginName() {
return "JdbcSource";
}

private String extendPartitionQuerySql(String query, String column) {
Matcher matcher = COMPILE.matcher(query);
if (matcher.find()) {
String where = matcher.group(Integer.parseInt("3"));
if (where != null && where.trim().toLowerCase().startsWith("where")) {
// contain where
return query + " AND \"" + column + "\" BETWEEN ? AND ?";
} else {
// not contain where
return query + " WHERE \"" + column + "\" BETWEEN ? AND ?";
}
} else {
throw new IllegalArgumentException("sql statement format is incorrect :" + query);
}
}

private JdbcParameterValuesProvider initPartition(String columnName, Connection connection) throws SQLException {
private JdbcParameterValuesProvider initPartition(String columnName, Connection connection, String query) throws SQLException {
long max = Long.MAX_VALUE;
long min = Long.MIN_VALUE;
if (config.hasPath(PARTITION_UPPER_BOUND) && config.hasPath(PARTITION_LOWER_BOUND)) {
Expand All @@ -196,7 +170,7 @@ private JdbcParameterValuesProvider initPartition(String columnName, Connection
return new JdbcNumericBetweenParametersProvider(min, max).ofBatchNum(parallelism * 2);
}
try (ResultSet rs = connection.createStatement().executeQuery(String.format("SELECT MAX(%s),MIN(%s) " +
"FROM %s", columnName, columnName, tableName))) {
"FROM (%s) tt", columnName, columnName, query))) {
if (rs.next()) {
max = config.hasPath(PARTITION_UPPER_BOUND) ? config.getLong(PARTITION_UPPER_BOUND) :
Long.parseLong(rs.getString(1));
Expand All @@ -212,59 +186,47 @@ private boolean isNumericType(TypeInformation<?> type) {
|| type.equals(LONG_TYPE_INFO) || type.equals(BIG_INT_TYPE_INFO);
}

private Map<String, TypeInformation<?>> initTableField(Connection connection) {
Map<String, TypeInformation<?>> map = new LinkedHashMap<>();

private Map<String, TypeInformation<?>> initTableField(Connection connection, String selectSql) {
try {
TypeInformationMap informationMapping = getTypeInformationMap(driverName);
DatabaseMetaData metaData = connection.getMetaData();
ResultSet columns = metaData.getColumns(connection.getCatalog(), connection.getSchema(), tableName, "%");
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
String dataTypeName = columns.getString("TYPE_NAME");
if (fields == null || fields.contains(columnName)) {
map.put(columnName, informationMapping.getInformation(dataTypeName));
}
String databaseDialect = connection.getMetaData().getDatabaseProductName();
PreparedStatement preparedStatement = connection.prepareStatement(selectSql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
preparedStatement.setMaxRows(1);
ResultSetMetaData rsMeta = preparedStatement.getMetaData();
try {
return getRowInfo(rsMeta, databaseDialect);
} catch (SQLException e) {
ResultSet rs = preparedStatement.executeQuery();
return getRowInfo(rs.getMetaData(), databaseDialect);
}
} catch (Exception e) {
} catch (SQLException e) {
LOGGER.warn("get row type info exception", e);
}
return map;
return new LinkedHashMap<>();
}

private Tuple2<String, Set<String>> getTableNameAndFields(Pattern regex, String selectSql) {
Matcher matcher = regex.matcher(selectSql);
String tableName;
Set<String> fields = null;
if (matcher.find()) {
String var = matcher.group(1);
tableName = matcher.group(2);
if (!"*".equals(var.trim())) {
LinkedHashSet<String> vars = new LinkedHashSet<>();
String[] split = var.split(",");
for (String s : split) {
vars.add(s.trim());
}
fields = vars;
private Map<String, TypeInformation<?>> getRowInfo(ResultSetMetaData rsMeta, String databaseDialect) throws SQLException {
Map<String, TypeInformation<?>> map = new LinkedHashMap<>();
if (rsMeta == null) {
throw new SQLException("No result set metadata available to resolver row info!");
}
TypeInformationMap informationMapping = getTypeInformationMap(databaseDialect);
for (int i = 1; i <= rsMeta.getColumnCount(); i++) {
String columnName = rsMeta.getColumnLabel(i);
String columnTypeName = rsMeta.getColumnTypeName(i);
if (columnTypeName == null) {
throw new SQLException("Unsupported to get type info from result set metadata!");
}
return new Tuple2<>(tableName, fields);
} else {
throw new IllegalArgumentException("can't find tableName and fields in sql :" + selectSql);
map.put(columnName, informationMapping.getInformation(columnTypeName));
}
return map;
}

private RowTypeInfo getRowTypeInfo() {
int size = tableFieldInfo.size();
if (fields != null && fields.size() > 0) {
size = fields.size();
} else {
fields = tableFieldInfo.keySet();
}

Set<String> fields = tableFieldInfo.keySet();
TypeInformation<?>[] typeInformation = new TypeInformation<?>[size];
String[] names = new String[size];
int i = 0;

for (String field : fields) {
typeInformation[i] = tableFieldInfo.get(field);
names[i] = field;
Expand All @@ -273,12 +235,13 @@ private RowTypeInfo getRowTypeInfo() {
return new RowTypeInfo(typeInformation, names);
}

private TypeInformationMap getTypeInformationMap(String driverName) {
driverName = driverName.toLowerCase();
if (driverName.contains("mysql")) {
private TypeInformationMap getTypeInformationMap(String databaseDialect) {
if (StringUtils.containsIgnoreCase(databaseDialect, "mysql")) {
return new MysqlTypeInformationMap();
} else if (driverName.contains("postgresql")) {
} else if (StringUtils.containsIgnoreCase(databaseDialect, "postgresql")) {
return new PostgresTypeInformationMap();
} else if (StringUtils.containsIgnoreCase(databaseDialect, "oracle")) {
return new OracleTypeInformationMap();
} else {
return new DefaultTypeInformationMap();
}
Expand Down

0 comments on commit 1a368dc

Please sign in to comment.