Connecting to SQL Databases using JDBC

You can use Databricks to query many SQL databases using JDBC drivers.

Databricks Runtime contains the following drivers for MySQL:

  • Databricks Runtime 3.4 and above include org.mariadb.jdbc.
  • Databricks Runtime 3.3 and below include com.mysql.jdbc.

Databricks Runtime 3.4 and above contain drivers for Microsoft SQL Server and Azure SQL Database. See the Databricks Runtime Release Notes for the complete list of JDBC libraries included in Databricks Runtime.

You can use other SQL databases as well, including (but not limited to) PostgreSQL and Oracle. See Libraries to learn how to install a JDBC library JAR for databases whose drivers are not available in Databricks.

This topic covers how to use the DataFrame API to connect to SQL databases using JDBC and how to control the parallelism of reads through the JDBC interface. It provides detailed examples using the Scala API, with abbreviated Python and Spark SQL examples at the end. For all of the supported arguments for connecting to SQL databases using JDBC, see the JDBC section of the Spark SQL programming guide.

Important

The examples in this topic do not include usernames and passwords in JDBC URLs. Instead, they assume that you follow the Secrets user guide to store your database credentials as secrets, and then reference them in a notebook to populate your credentials in a java.util.Properties object. For example:

val jdbcUsername = dbutils.secrets.get(scope = "jdbc", key = "username")
val jdbcPassword = dbutils.secrets.get(scope = "jdbc", key = "password")

For a full example of secret management, see Secret Workflow Example.

Establish cloud connectivity

Databricks VPCs are configured to allow only Spark clusters. To connect to another infrastructure, the best practice is to use VPC peering. Once VPC peering is established, you can verify connectivity with the netcat utility on the cluster:

%sh nc -vz <jdbcHostname> <jdbcPort>

Establish connectivity to MySQL

This example queries MySQL using its JDBC driver.

Step 1: Check that the JDBC driver is available

This statement checks that the driver class exists in your classpath. You can use the %scala magic command to test this in other notebook types, such as Python.

Class.forName("org.mariadb.jdbc.Driver") // |DBR| 3.4 and above

Class.forName("com.mysql.jdbc.Driver") // |DBR| 3.3 and below

Step 2: Create the JDBC URL

val jdbcHostname = "<hostname>"
val jdbcPort = 3306
val jdbcDatabase = "<database>"

// Create the JDBC URL without passing in the user and password parameters.
val jdbcUrl = s"jdbc:mysql://${jdbcHostname}:${jdbcPort}/${jdbcDatabase}"

// Create a Properties() object to hold the parameters.
import java.util.Properties
val connectionProperties = new Properties()

connectionProperties.put("user", s"${jdbcUsername}")
connectionProperties.put("password", s"${jdbcPassword}")

Step 3: Check connectivity to the MySQL database

import java.sql.DriverManager
val connection = DriverManager.getConnection(jdbcUrl, jdbcUsername, jdbcPassword)
connection.isClosed()

Establish connectivity to SQL Server

This example queries SQL Server using its JDBC driver.

Step 1: Check that the JDBC driver is available

Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver")

Step 2: Create the JDBC URL

val jdbcHostname = "<hostname>"
val jdbcPort = 1433
val jdbcDatabase = "<database>"

// Create the JDBC URL without passing in the user and password parameters.
val jdbcUrl = s"jdbc:sqlserver://${jdbcHostname}:${jdbcPort};database=${jdbcDatabase}"

// Create a Properties() object to hold the parameters.
import java.util.Properties
val connectionProperties = new Properties()

connectionProperties.put("user", s"${jdbcUsername}")
connectionProperties.put("password", s"${jdbcPassword}")

Step 3: Check connectivity to the SQL Server database

val driverClass = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
connectionProperties.setProperty("Driver", driverClass)
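
The snippet above only records the driver class in the connection properties. To actually verify connectivity, you can open and close a test connection, mirroring the MySQL example. This is a minimal sketch that assumes the jdbcUrl, jdbcUsername, and jdbcPassword values defined above:

import java.sql.DriverManager
// Open a test connection, confirm it is live, and close it again.
val connection = DriverManager.getConnection(jdbcUrl, jdbcUsername, jdbcPassword)
connection.isClosed() // returns false while the connection is open
connection.close()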

Read data from JDBC

This section loads data from a database table. It uses a single JDBC connection to pull the table into the Spark environment. For parallel reads, see Manage parallelism.

val employees_table = spark.read.jdbc(jdbcUrl, "employees", connectionProperties)

Spark automatically reads the schema from the database table and maps its types back to Spark SQL types.

employees_table.printSchema

You can run queries against this JDBC table:

display(employees_table.select("age", "salary").groupBy("age").avg("salary"))

Write data to JDBC

This section shows how to write data to a database from an existing Spark SQL table named diamonds.

%sql -- quick test that this test table exists
select * from diamonds limit 5

The following code saves the data into a database table named diamonds. Using column names that are reserved keywords can trigger an exception. The example table has a column named table, so rename it with withColumnRenamed() before pushing it to the JDBC API.

spark.table("diamonds").withColumnRenamed("table", "table_number")
     .write
     .jdbc(jdbcUrl, "diamonds", connectionProperties)

Spark automatically creates a database table with the appropriate schema determined from the DataFrame schema.

The default behavior is to create a new table and to throw an error if a table with the same name already exists. You can use the Spark SQL SaveMode feature to change this behavior. For example, here’s how to append more rows to the table:

import org.apache.spark.sql.SaveMode

spark.sql("select * from diamonds limit 10").withColumnRenamed("table", "table_number")
     .write
     .mode(SaveMode.Append) // <--- Append to the existing table
     .jdbc(jdbcUrl, "diamonds", connectionProperties)

You can also overwrite an existing table:

spark.table("diamonds").withColumnRenamed("table", "table_number")
     .write
     .mode(SaveMode.Overwrite) // <--- Overwrite the existing table
     .jdbc(jdbcUrl, "diamonds", connectionProperties)

Push down a query to the database engine

You can push down an entire query to the database and return just the result. The table parameter identifies the JDBC table to read. You can use anything that is valid in a SQL query FROM clause.

// Note: The parentheses are required.
val pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
val df = spark.read.jdbc(url=jdbcUrl, table=pushdown_query, properties=connectionProperties)
display(df)

Push down optimization

In addition to ingesting an entire table, you can push down a query to the database to leverage it for processing, and return only the results.

// Explain plan with no column selection returns all columns
spark.read.jdbc(jdbcUrl, "diamonds", connectionProperties).explain(true)

You can prune columns and push down query predicates to the database with DataFrame methods.

// Explain plan with column selection will prune columns and just return the ones specified
// Notice that only the 3 specified columns are in the explain plan
spark.read.jdbc(jdbcUrl, "diamonds", connectionProperties).select("carat", "cut", "price").explain(true)
// You can push query predicates down too
// Notice the filter at the top of the physical plan
spark.read.jdbc(jdbcUrl, "diamonds", connectionProperties).select("carat", "cut", "price").where("cut = 'Good'").explain(true)

Manage parallelism

In the Spark UI, you can see that the numPartitions value dictates the number of tasks that are launched. Each task is spread across the executors, which can increase the parallelism of the reads and writes through the JDBC interface. See the Spark SQL programming guide for other parameters, such as fetchsize, that can help with performance.
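
These parameters can also be passed through the option-based form of the DataFrame reader. The following is a minimal sketch rather than a definitive configuration; it assumes the jdbcUrl, jdbcUsername, and jdbcPassword values defined earlier, and the partition bounds and fetch size are illustrative placeholders:

val parallelDf = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("dbtable", "employees")
  .option("user", jdbcUsername)
  .option("password", jdbcPassword)
  .option("partitionColumn", "emp_no") // numeric column used to split the read
  .option("lowerBound", "1")           // stride bounds, not row filters
  .option("upperBound", "100000")
  .option("numPartitions", "100")      // number of concurrent JDBC queries
  .option("fetchsize", "1000")         // rows fetched per round trip
  .load()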

JDBC reads

You can provide split boundaries based on the dataset’s column values.

These options specify the read parallelism and must all be specified if any of them is specified. lowerBound and upperBound decide the partition stride, but do not filter the rows in the table; Spark partitions and returns all rows in the table.

The following example splits the table read across executors on the emp_no column using the columnName, lowerBound, upperBound, and numPartitions parameters.

val df = (spark.read.jdbc(url=jdbcUrl,
    table="employees",
    columnName="emp_no",
    lowerBound=1L,
    upperBound=100000L,
    numPartitions=100,
    connectionProperties=connectionProperties))
display(df)

JDBC writes

Spark’s partitions dictate the number of connections used to push data through the JDBC API. You can control the parallelism by calling coalesce(<N>) or repartition(<N>) depending on the existing number of partitions. Call coalesce when reducing the number of partitions, and repartition when increasing the number of partitions.

import org.apache.spark.sql.SaveMode

val df = spark.table("diamonds")
println(df.rdd.partitions.length)

// Given the number of partitions above, you can reduce the partition value by calling coalesce() or increase it by calling repartition() to manage the number of connections.
df.repartition(10).write.mode(SaveMode.Append).jdbc(jdbcUrl, "diamonds", connectionProperties)
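
Conversely, if the DataFrame already has more partitions than the number of concurrent connections your database can comfortably handle, reduce the partition count before writing. A minimal sketch under the same assumptions (the diamonds table, jdbcUrl, and connectionProperties from earlier), with an illustrative partition count:

// Coalesce to 8 partitions, and therefore at most 8 concurrent JDBC connections, before writing.
df.coalesce(8).write.mode(SaveMode.Append).jdbc(jdbcUrl, "diamonds", connectionProperties)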

Python example

The following Python examples cover some of the same tasks as those provided for Scala.

Create the JDBC URL

jdbcHostname = "<hostname>"
jdbcDatabase = "employees"
jdbcPort = 3306
jdbcUrl = "jdbc:mysql://{0}:{1}/{2}?user={3}&password={4}".format(jdbcHostname, jdbcPort, jdbcDatabase, username, password)

Alternatively, you can pass in a dictionary that contains the credentials and driver class, similar to the Scala example above:

jdbcUrl = "jdbc:mysql://{0}:{1}/{2}".format(jdbcHostname, jdbcPort, jdbcDatabase)
connectionProperties = {
  "user" : jdbcUsername,
  "password" : jdbcPassword,
  "driver" : "com.mysql.jdbc.Driver"
}

Push down a query to the database engine

pushdown_query = "(select * from employees where emp_no < 10008) emp_alias"
df = spark.read.jdbc(url=jdbcUrl, table=pushdown_query, properties=connectionProperties)
display(df)

Read from JDBC connections across multiple workers

df = spark.read.jdbc(url=jdbcUrl, table="employees", column="emp_no", lowerBound=1, upperBound=100000, numPartitions=100)
display(df)

Spark SQL example

You can define a Spark SQL table or view that uses a JDBC connection. For details, see Create Table and Create View.

%sql
CREATE TABLE jdbcTable
USING org.apache.spark.sql.jdbc
OPTIONS (
  url "jdbc:<databaseServerType>://<jdbcHostname>:<jdbcPort>",
  table "<jdbcDatabase>.atable",
  user "<jdbcUsername>",
  password "<jdbcPassword>"
)

Append data into the database table using Spark SQL:

%sql
INSERT INTO diamonds
SELECT * FROM diamonds LIMIT 10 -- append 10 records to the table
%sql
SELECT count(*) record_count FROM diamonds --count increased by 10

Overwrite data in the database table using Spark SQL. This causes the database to drop and create the diamonds table:

%sql
INSERT OVERWRITE TABLE diamonds
SELECT carat, cut, color, clarity, depth, `table` AS table_number, price, x, y, z FROM diamonds
%sql
SELECT count(*) record_count FROM diamonds --count returned to original value (10 less)

Optimize performance when reading data

If reads from an external JDBC database are slow, here are some suggestions to improve performance:

Determine whether the JDBC unload is occurring in parallel

To load data in parallel, the Spark JDBC data source must be configured with appropriate partitioning information so that it can issue multiple concurrent queries to the external database. If you neglect to configure partitioning, all the data is fetched on the driver using a single JDBC query, which risks causing the driver to throw an OOM exception.

There are two APIs for specifying partitioning: high level and low level.

The high-level API takes the name of a numeric column (columnName), two range endpoints (lowerBound, upperBound), and a target numPartitions, and generates Spark tasks by splitting the specified range evenly into numPartitions tasks. This works well if your database table has an indexed numeric column with fairly evenly distributed values, such as an auto-incrementing primary key; it works somewhat less well if the numeric column is extremely skewed, leading to imbalanced tasks.

The low-level API, accessible in Scala, accepts an array of WHERE conditions that can be used to define custom partitions; this is useful for partitioning on non-numeric columns or for dealing with skew. When defining custom partitions, remember to account for NULL when the partition columns are nullable. Avoid manually defining partitions using more than two columns, since writing the boundary predicates requires much more complex logic.
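
The low-level API is the predicates variant of spark.read.jdbc. The following is a minimal sketch that assumes the employees table and the jdbcUrl and connectionProperties values from earlier; the hire_date column and its boundaries are illustrative placeholders:

// Each element of the array becomes the WHERE clause of one partition's query.
val predicates = Array(
  "hire_date < '1990-01-01'",
  "hire_date >= '1990-01-01' AND hire_date < '2000-01-01'",
  "hire_date >= '2000-01-01' OR hire_date IS NULL" // account for NULLs in nullable columns
)
val partitionedDf = spark.read.jdbc(jdbcUrl, "employees", predicates, connectionProperties)
display(partitionedDf)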

Tune the JDBC fetchSize parameter

JDBC drivers have a fetchSize parameter that controls the number of rows fetched at a time from the remote JDBC database. If this value is set too low, your workload may become latency-bound because of the large number of round-trip requests between Spark and the external database needed to fetch the full result set. If the value is set too high, you risk OOMs. The optimal value is workload-dependent (it depends on the result schema, the size of strings in the results, and so on), but increasing it even slightly from the default can result in huge performance gains.

Oracle’s default fetchSize is 10. Increasing it even slightly, to 100, gives massive performance gains, and going up to a higher value, like 2000, gives an additional improvement. For example:

// conn is an existing java.sql.Connection to the Oracle database.
PreparedStatement stmt = null;
ResultSet rs = null;

try {
  stmt = conn.prepareStatement("select a, b, c from table");
  stmt.setFetchSize(100); // fetch 100 rows per round trip instead of the driver default

  rs = stmt.executeQuery();
  while (rs.next()) {
    // ... process each row
  }
} finally {
  if (rs != null) rs.close();
  if (stmt != null) stmt.close();
}

See Make your java run faster for a more general discussion of this tuning parameter for Oracle JDBC drivers.

Consider the impact of indexes

If you are reading in parallel (using one of the partitioning techniques), Spark issues concurrent queries to the JDBC database. If these queries require full table scans, they can bottleneck the remote database and become extremely slow. So consider the impact of indexes when choosing a partitioning column, and pick a column whose individual partition queries can be executed reasonably efficiently in parallel.

Important

Make sure that the database has an index on the partitioning column.

When a single-column index is not defined on the source table, you can still choose the leading (leftmost) column of a composite index as the partitioning column. When only composite indexes are available, most databases can use a concatenated index when searching with the leading (leftmost) columns, so the leading column of a multi-column index can also serve as the partitioning column.

Consider whether the number of partitions is appropriate

Using too many partitions when reading from the external database risks overloading it with too many queries. Most DBMSs have limits on the number of concurrent connections. As a starting point, aim for a number of partitions close to the number of cores or task slots in your Spark cluster in order to maximize parallelism while keeping the total number of queries capped at a reasonable limit. If you need a lot of parallelism after fetching the JDBC rows (because you’re doing something CPU-bound in Spark) but don’t want to issue too many concurrent queries to your database, consider using a lower numPartitions for the JDBC read and then doing an explicit repartition() in Spark, as shown in the sketch below.
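
A minimal sketch of that pattern, assuming the jdbcUrl and connectionProperties values from earlier and using illustrative partition counts:

// Issue only a modest number of concurrent JDBC queries against the database...
val modestDf = spark.read.jdbc(url=jdbcUrl,
    table="employees",
    columnName="emp_no",
    lowerBound=1L,
    upperBound=100000L,
    numPartitions=8,
    connectionProperties=connectionProperties)

// ...then repartition in Spark for CPU-bound downstream work.
val wideDf = modestDf.repartition(64)
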
Consider database-specific tuning techniques

The database vendor may have a guide on tuning performance for ETL / bulk access workloads.