@@ -19,9 +19,11 @@ package org.apache.spark.deploy.yarn
19
19
20
20
import scala .collection .mutable .{HashMap , HashSet , Set }
21
21
22
- import org .apache .hadoop .conf .Configuration
22
+ import org .apache .hadoop .fs .CommonConfigurationKeysPublic
23
+ import org .apache .hadoop .net .DNSToSwitchMapping
23
24
import org .apache .hadoop .yarn .api .records ._
24
25
import org .apache .hadoop .yarn .client .api .AMRMClient .ContainerRequest
26
+ import org .apache .hadoop .yarn .conf .YarnConfiguration
25
27
import org .mockito .Mockito ._
26
28
27
29
import org .apache .spark .{SparkConf , SparkFunSuite }
@@ -49,18 +51,22 @@ class LocalityPlacementStrategySuite extends SparkFunSuite {
49
51
}
50
52
51
53
private def runTest (): Unit = {
54
+ val yarnConf = new YarnConfiguration ()
55
+ yarnConf.setClass(
56
+ CommonConfigurationKeysPublic .NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY ,
57
+ classOf [MockResolver ], classOf [DNSToSwitchMapping ])
58
+
52
59
val resource = Resource .newInstance(8 * 1024 , 4 )
53
60
val strategy = new LocalityPreferredContainerPlacementStrategy (new SparkConf (),
54
- new Configuration () , resource)
61
+ yarnConf , resource)
55
62
56
63
val totalTasks = 32 * 1024
57
64
val totalContainers = totalTasks / 16
58
65
val totalHosts = totalContainers / 16
59
66
67
+ val mockId = mock(classOf [ContainerId ])
60
68
val hosts = (1 to totalHosts).map { i => (s " host_ $i" , totalTasks % i) }.toMap
61
- val containers = (1 to totalContainers).map { i =>
62
- ContainerId .fromString(s " container_12345678_0001_01_ $i" )
63
- }
69
+ val containers = (1 to totalContainers).map { i => mockId }
64
70
val count = containers.size / hosts.size / 2
65
71
66
72
val hostToContainerMap = new HashMap [String , Set [ContainerId ]]()
0 commit comments