diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index 197c2f13d807b..844858b09e3c0 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.TestUtils._ import org.apache.spark.api.plugin._ import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.memory.MemoryMode import org.apache.spark.resource.ResourceInformation import org.apache.spark.resource.ResourceUtils.GPU import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID} @@ -228,6 +229,53 @@ class PluginContainerSuite extends SparkFunSuite with LocalSparkContext { assert(driverResources.get(GPU).name === GPU) } } + + test("memory override in plugin") { + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]") + .set(PLUGINS, Seq(classOf[MemoryOverridePlugin].getName())) + + val sc = new SparkContext(conf) + val memoryManager = sc.env.memoryManager + + assert(memoryManager.tungstenMemoryMode == MemoryMode.OFF_HEAP) + assert(memoryManager.maxOffHeapStorageMemory == MemoryOverridePlugin.offHeapMemory) + + // Ensure all executors has started + TestUtils.waitUntilExecutorsUp(sc, 1, 60000) + + // Check executor memory is also updated + val execInfo = sc.statusTracker.getExecutorInfos.head + assert(execInfo.totalOffHeapStorageMemory() == MemoryOverridePlugin.offHeapMemory) + } +} + +class MemoryOverridePlugin extends SparkPlugin { + override def driverPlugin(): DriverPlugin = { + new DriverPlugin { + override def init(sc: SparkContext, pluginContext: PluginContext): JMap[String, String] = { + // Take the original executor memory, and set `spark.memory.offHeap.size` to be the + // same value. Also set `spark.memory.offHeap.enabled` to true. + val originalExecutorMemBytes = { + sc.conf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + } + sc.conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") + sc.conf.set(MEMORY_OFFHEAP_SIZE.key, s"${originalExecutorMemBytes}M") + MemoryOverridePlugin.offHeapMemory = sc.conf.getSizeAsBytes(MEMORY_OFFHEAP_SIZE.key) + Map.empty[String, String].asJava + } + } + } + + override def executorPlugin(): ExecutorPlugin = { + new ExecutorPlugin {} + } +} + +object MemoryOverridePlugin { + var offHeapMemory: Long = _ + var totalExecutorMemory: Long = _ } class NonLocalModeSparkPlugin extends SparkPlugin {