Whitelist IP Addresses

This guide walks you through how to programmatically modify AWS security groups so that additional IP addresses can access Databricks. You can use this as a guide to help you:

  • Allow additional IP addresses to access the Databricks REST API.
  • Modify security group access for your data sources, such as RDS, Redshift, and others.

Set up the keys and region

You’ll need a set of AWS keys with permission to describe the security groups in your AWS account and to add new rules to them. We suggest the following permissions for your keys:

{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "Stmt1403287045000",
            "Effect": "Allow",
            "Action": [
                "ec2:DescribeSecurityGroups",
                "ec2:AuthorizeSecurityGroupIngress",
                "ec2:RevokeSecurityGroupIngress"
            ],
            "Resource": [
                "*"
            ]
        }
    ]
}
# AWS credentials and region used to connect to EC2; fill these in before running.
ACCESS_KEY = ""
SECRET_KEY = "" # use the raw secret; shouldn't have to encode or escape it
REGION = ""

# CIDR blocks to grant access to
IP_ADDRESSES_TO_ADD = ["1.2.3.4/32"]
PORTS_TO_ADD = ["80","443"] # only include individual ports, not ranges
PROTOCOLS_TO_ADD = ["tcp"]

Find the Databricks external services security group

import boto.ec2

# this will only find the first matching security group
def get_databricks_security_group():
  """Return the first security group whose name looks like a Databricks group.

  A group matches when its name starts with ``dbc`` or ``databricks`` and
  also contains ``-ExternalServices`` or ``-worker-unmanaged``.

  Reads the module-level REGION, ACCESS_KEY and SECRET_KEY settings.

  Returns:
    The first matching group from ``conn.get_all_security_groups()``, or
    None when no group matches.

  Raises:
    ValueError: if no EC2 connection could be created for REGION.
  """
  conn = boto.ec2.connect_to_region(REGION,
                                    aws_access_key_id=ACCESS_KEY,
                                    aws_secret_access_key=SECRET_KEY)
  if conn is None:
    # connect_to_region returns None for an unrecognized region name; fail
    # here with a clear message instead of an AttributeError on the next call.
    raise ValueError("Could not connect to EC2 region %r; check REGION" % (REGION,))
  for group in conn.get_all_security_groups():
    name = group.name
    # The required prefix guarantees the marker cannot sit at index 0, so a
    # plain containment test is equivalent to the old ``find(...) > 0``.
    if (name.startswith(('dbc', 'databricks'))
        and ('-ExternalServices' in name or '-worker-unmanaged' in name)):
      return group
  return None

databricks_security_group = get_databricks_security_group()
# Parenthesized single-argument print behaves the same on Python 2 and Python 3.
print(databricks_security_group)

Make sure the security group printed above seems reasonable. If it prints None, no matching security group was found.

Authorize the new IP addresses

from collections import defaultdict

def sg_rule_dict(security_group):
  """Index the CIDR blocks granted by a security group, keyed by rule.

  Each key has the form "<ip_protocol>:<from_port>-<to_port>" and maps to the
  list of ``cidr_ip`` values taken from that rule's grants, in the order the
  rules and grants are listed. Rules that share a key are merged into one
  list; a rule with no grants adds no key.

  Returns:
    A ``defaultdict(list)``, so looking up a key that is not present yields
    an empty list.
  """
  cidrs_by_rule = defaultdict(list)
  for rule in security_group.rules:
    rule_key = "%s:%s-%s" % (rule.ip_protocol, rule.from_port, rule.to_port)
    for grant in rule.grants:
      cidrs_by_rule[rule_key].append(grant.cidr_ip)
  return cidrs_by_rule

# Snapshot of the rules already on the group: "protocol:from-to" -> CIDR list.
existing_rules = sg_rule_dict(databricks_security_group)

# Authorize every (IP, port, protocol) combination that is not already present.
for ip_address in IP_ADDRESSES_TO_ADD:
  for port in PORTS_TO_ADD:
    for protocol in PROTOCOLS_TO_ADD:
      # Single-port rule, so from_port and to_port are both `port`.
      key = "%s:%s-%s" % (protocol, port, port)
      if ip_address in existing_rules[key]:
        continue
      databricks_security_group.authorize(ip_protocol=protocol,
                                          from_port=port,
                                          to_port=port,
                                          cidr_ip=ip_address)
      # Record the new grant so a repeated entry in the input lists is not
      # authorized a second time against the stale snapshot.
      existing_rules[key].append(ip_address)