diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 96882c62c2..50dc4b4b08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -138,6 +138,8 @@ class SparkSession private(
           SparkSession.sessionStateClassName(sparkContext.conf),
           self)
         initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) }
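+        // Register functions contributed through SparkSessionExtensions.injectFunction.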
+        extensions.registerFunctions(state)
         state
       }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index f99c108161..057dec1b4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql
 import scala.collection.mutable
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SessionState
 
 /**
  * :: Experimental ::
@@ -65,6 +67,8 @@ class SparkSessionExtensions {
   type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
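+  // A function name paired with the builder that creates the function's Expression.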
+  type FunctionDescription = (String, Seq[Expression] => Expression)
 
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
@@ -168,4 +172,27 @@
   def injectParser(builder: ParserBuilder): Unit = {
     parserBuilders += builder
   }
+
+  private[this] val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
+
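+  // Registers every injected function with the given session's FunctionRegistry. Invoked
+  // by SparkSession when the SessionState is first created.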
+  private[sql] def registerFunctions(sessionState: SessionState): Unit = {
+    for ((name, function) <- injectedFunctions) {
+      sessionState.functionRegistry.registerFunction(name, function)
+    }
+  }
+
+  /**
+   * Injects a custom function into the [[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]]
+   * at runtime for all sessions.
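+   *
+   * A minimal sketch (the function name `always_five` is only illustrative):
+   * {{{
+   *   extensions.injectFunction(("always_five", (_: Seq[Expression]) => Literal(5)))
+   * }}}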
+   */
+  def injectFunction(functionDescription: FunctionDescription): Unit = {
+    injectedFunctions += functionDescription
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 43db796633..af8782c9fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
 
 /**
  * Test cases for the [[SparkSessionExtensions]].
@@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject function") {
+    val extensions = create { extensions =>
+      extensions.injectFunction(MyExtensions.myFunction)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    }
+  }
+
   test("use custom class for extensions") {
     val session = SparkSession.builder()
       .master("local[1]")
@@ -98,6 +108,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
@@ -136,9 +148,14 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
     delegate.parseDataType(sqlText)
 }
 
+object MyExtensions {
+  val myFunction = ("myFunction", (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectFunction(MyExtensions.myFunction)
   }
 }
