20 | 20 | package org.apache.spark.shuffle.comet; |
21 | 21 |
|
22 | 22 | import java.io.IOException; |
23 | | -import java.util.BitSet; |
24 | 23 |
|
25 | 24 | import org.apache.spark.SparkConf; |
26 | 25 | import org.apache.spark.memory.MemoryConsumer; |
27 | 26 | import org.apache.spark.memory.MemoryMode; |
28 | | -import org.apache.spark.memory.SparkOutOfMemoryError; |
29 | 27 | import org.apache.spark.memory.TaskMemoryManager; |
30 | | -import org.apache.spark.sql.internal.SQLConf; |
31 | | -import org.apache.spark.unsafe.array.LongArray; |
32 | 28 | import org.apache.spark.unsafe.memory.MemoryBlock; |
33 | | -import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator; |
| 29 | +import org.apache.spark.util.Utils; |
34 | 30 |
|
35 | | -import org.apache.comet.CometSparkSessionExtensions$; |
| 31 | +import org.apache.comet.CometConf$; |
36 | 32 |
|
37 | 33 | /** |
38 | 34 | * A simple memory allocator used by `CometShuffleExternalSorter` to allocate memory blocks which |
39 | | - * store serialized rows. We don't rely on Spark memory allocator because we need to allocate |
40 | | - * off-heap memory no matter memory mode is on-heap or off-heap. This allocator is configured with |
41 | | - * fixed size of memory, and it will throw `SparkOutOfMemoryError` if the memory is not enough. |
42 | | - * |
43 | | - * <p>Some methods are copied from `org.apache.spark.unsafe.memory.TaskMemoryManager` with |
44 | | - * modifications. Most modifications are to remove the dependency on the configured memory mode. |
| 35 | + * store serialized rows. This class is simply an implementation of `MemoryConsumer` that delegates |
| 36 | + * memory allocation to the `TaskMemoryManager`. This requires that the `TaskMemoryManager` is |
| 37 | + * configured with `MemoryMode.OFF_HEAP`, i.e. it is using off-heap memory. |
45 | 38 | */ |
46 | | -public final class CometShuffleMemoryAllocator extends MemoryConsumer { |
47 | | - private final UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator(); |
48 | | - |
49 | | - private final long pageSize; |
50 | | - private final long totalMemory; |
51 | | - private long allocatedMemory = 0L; |
52 | | - |
53 | | - /** The number of bits used to address the page table. */ |
54 | | - private static final int PAGE_NUMBER_BITS = 13; |
55 | | - |
56 | | - /** The number of entries in the page table. */ |
57 | | - private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; |
58 | | - |
59 | | - private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; |
60 | | - private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); |
| 39 | +public final class CometShuffleMemoryAllocator extends CometShuffleMemoryAllocatorTrait { |
| 40 | + private static CometShuffleMemoryAllocatorTrait INSTANCE; |
61 | 41 |
|
62 | | - private static final int OFFSET_BITS = 51; |
63 | | - private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; |
64 | | - |
65 | | - private static CometShuffleMemoryAllocator INSTANCE; |
66 | | - |
67 | | - public static synchronized CometShuffleMemoryAllocator getInstance( |
| 42 | + /** |
| 43 | + * Returns the singleton instance of `CometShuffleMemoryAllocator`. This method should be used |
| 44 | + * instead of the constructor to ensure that only one instance of `CometShuffleMemoryAllocator` is |
| 45 | + * created. For Spark tests, this returns `CometTestShuffleMemoryAllocator` which is a test-only |
| 46 | + * allocator that should not be used in production. |
| 47 | + */ |
| 48 | + public static CometShuffleMemoryAllocatorTrait getInstance( |
68 | 49 | SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) { |
69 | | - if (INSTANCE == null) { |
70 | | - INSTANCE = new CometShuffleMemoryAllocator(conf, taskMemoryManager, pageSize); |
| 50 | + boolean isSparkTesting = Utils.isTesting(); |
| 51 | + boolean useUnifiedMemAllocator = |
| 52 | + (boolean) |
| 53 | + CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_UNIFIED_MEMORY_ALLOCATOR_IN_TEST().get(); |
| 54 | + |
| 55 | + if (isSparkTesting && !useUnifiedMemAllocator) { |
| 56 | + synchronized (CometShuffleMemoryAllocator.class) { |
| 57 | + if (INSTANCE == null) { |
| 58 | + // CometTestShuffleMemoryAllocator handles pages by itself so it can be a singleton. |
| 59 | + INSTANCE = new CometTestShuffleMemoryAllocator(conf, taskMemoryManager, pageSize); |
| 60 | + } |
| 61 | + } |
| 62 | + return INSTANCE; |
| 63 | + } else { |
| 64 | + if (taskMemoryManager.getTungstenMemoryMode() != MemoryMode.OFF_HEAP) { |
| 65 | + throw new IllegalArgumentException( |
| 66 | + "CometShuffleMemoryAllocator should be used with off-heap " |
| 67 | + + "memory mode, but got " |
| 68 | + + taskMemoryManager.getTungstenMemoryMode()); |
| 69 | + } |
| 70 | + |
| 71 | + // CometShuffleMemoryAllocator stores pages in TaskMemoryManager which is not singleton, |
| 72 | + // but one instance per task. So we need to create a new instance for each task. |
| 73 | + return new CometShuffleMemoryAllocator(taskMemoryManager, pageSize); |
71 | 74 | } |
72 | | - |
73 | | - return INSTANCE; |
74 | 75 | } |
75 | 76 |
|
76 | | - CometShuffleMemoryAllocator(SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) { |
| 77 | + CometShuffleMemoryAllocator(TaskMemoryManager taskMemoryManager, long pageSize) { |
77 | 78 | super(taskMemoryManager, pageSize, MemoryMode.OFF_HEAP); |
78 | | - this.pageSize = pageSize; |
79 | | - this.totalMemory = |
80 | | - CometSparkSessionExtensions$.MODULE$.getCometShuffleMemorySize(conf, SQLConf.get()); |
81 | | - } |
82 | | - |
83 | | - public synchronized long acquireMemory(long size) { |
84 | | - if (allocatedMemory >= totalMemory) { |
85 | | - throw new SparkOutOfMemoryError( |
86 | | - "Unable to acquire " |
87 | | - + size |
88 | | - + " bytes of memory, current usage " |
89 | | - + "is " |
90 | | - + allocatedMemory |
91 | | - + " bytes and max memory is " |
92 | | - + totalMemory |
93 | | - + " bytes"); |
94 | | - } |
95 | | - long allocationSize = Math.min(size, totalMemory - allocatedMemory); |
96 | | - allocatedMemory += allocationSize; |
97 | | - return allocationSize; |
98 | 79 | } |
99 | 80 |
|
100 | 81 | public long spill(long l, MemoryConsumer memoryConsumer) throws IOException { |
| 82 | + // JVM shuffle writer does not support spilling for other memory consumers |
101 | 83 | return 0; |
102 | 84 | } |
103 | 85 |
|
104 | | - public synchronized LongArray allocateArray(long size) { |
105 | | - long required = size * 8L; |
106 | | - MemoryBlock page = allocate(required); |
107 | | - return new LongArray(page); |
108 | | - } |
109 | | - |
110 | | - public synchronized void freeArray(LongArray array) { |
111 | | - if (array == null) { |
112 | | - return; |
113 | | - } |
114 | | - free(array.memoryBlock()); |
115 | | - } |
116 | | - |
117 | | - public synchronized MemoryBlock allocatePage(long required) { |
118 | | - long size = Math.max(pageSize, required); |
119 | | - return allocate(size); |
120 | | - } |
121 | | - |
122 | | - private synchronized MemoryBlock allocate(long required) { |
123 | | - if (required > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { |
124 | | - throw new TooLargePageException(required); |
125 | | - } |
126 | | - |
127 | | - long got = acquireMemory(required); |
128 | | - |
129 | | - if (got < required) { |
130 | | - allocatedMemory -= got; |
131 | | - |
132 | | - throw new SparkOutOfMemoryError( |
133 | | - "Unable to acquire " |
134 | | - + required |
135 | | - + " bytes of memory, got " |
136 | | - + got |
137 | | - + " bytes. Available: " |
138 | | - + (totalMemory - allocatedMemory)); |
139 | | - } |
140 | | - |
141 | | - int pageNumber = allocatedPages.nextClearBit(0); |
142 | | - if (pageNumber >= PAGE_TABLE_SIZE) { |
143 | | - allocatedMemory -= got; |
144 | | - |
145 | | - throw new IllegalStateException( |
146 | | - "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); |
147 | | - } |
148 | | - |
149 | | - MemoryBlock block = allocator.allocate(got); |
150 | | - |
151 | | - block.pageNumber = pageNumber; |
152 | | - pageTable[pageNumber] = block; |
153 | | - allocatedPages.set(pageNumber); |
154 | | - |
155 | | - return block; |
| 86 | + public synchronized MemoryBlock allocate(long required) { |
| 87 | + return this.allocatePage(required); |
156 | 88 | } |
157 | 89 |
|
158 | 90 | public synchronized void free(MemoryBlock block) { |
159 | | - if (block.pageNumber == MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) { |
160 | | - // Already freed block |
161 | | - return; |
162 | | - } |
163 | | - allocatedMemory -= block.size(); |
164 | | - |
165 | | - pageTable[block.pageNumber] = null; |
166 | | - allocatedPages.clear(block.pageNumber); |
167 | | - block.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; |
168 | | - |
169 | | - allocator.free(block); |
170 | | - } |
171 | | - |
172 | | - public synchronized long getAvailableMemory() { |
173 | | - return totalMemory - allocatedMemory; |
| 91 | + this.freePage(block); |
174 | 92 | } |
175 | 93 |
|
176 | 94 | /** |
177 | 95 | * Returns the offset in the page for the given page plus base offset address. Note that this |
178 | 96 | * method assumes that the page number is valid. |
179 | 97 | */ |
180 | 98 | public long getOffsetInPage(long pagePlusOffsetAddress) { |
181 | | - long offsetInPage = decodeOffset(pagePlusOffsetAddress); |
182 | | - int pageNumber = TaskMemoryManager.decodePageNumber(pagePlusOffsetAddress); |
183 | | - assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); |
184 | | - MemoryBlock page = pageTable[pageNumber]; |
185 | | - assert (page != null); |
186 | | - return page.getBaseOffset() + offsetInPage; |
187 | | - } |
188 | | - |
189 | | - public long decodeOffset(long pagePlusOffsetAddress) { |
190 | | - return pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS; |
| 99 | + return taskMemoryManager.getOffsetInPage(pagePlusOffsetAddress); |
191 | 100 | } |
192 | 101 |
|
193 | 102 | public long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { |
194 | | - assert (pageNumber >= 0); |
195 | | - return ((long) pageNumber) << OFFSET_BITS | offsetInPage & MASK_LONG_LOWER_51_BITS; |
| 103 | + return TaskMemoryManager.encodePageNumberAndOffset(pageNumber, offsetInPage); |
196 | 104 | } |
197 | 105 |
|
198 | 106 | public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { |
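
For context on how the rewritten allocator is used, the sketch below (not part of this change) shows the calling pattern implied by the new Javadoc: a shuffle writer obtains the allocator through `getInstance` and lets the off-heap `TaskMemoryManager` do the page bookkeeping. The wrapper class and `example` method are hypothetical, and the sketch assumes `CometShuffleMemoryAllocatorTrait` declares the `allocate`, `free`, and address-encoding methods implemented here; the `SparkConf`, the task's `TaskMemoryManager`, and the page size are supplied by the caller.

```java
import org.apache.spark.SparkConf;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocator;
import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
import org.apache.spark.unsafe.memory.MemoryBlock;

final class AllocatorUsageSketch {
  // Hypothetical helper: in the real shuffle writer, `conf`, the task's
  // TaskMemoryManager, and the page size come from the task context and the
  // shuffle configuration.
  static void example(SparkConf conf, TaskMemoryManager taskMemoryManager, long pageSize) {
    // Per-task allocator in production; a test-only singleton under Spark tests.
    CometShuffleMemoryAllocatorTrait allocator =
        CometShuffleMemoryAllocator.getInstance(conf, taskMemoryManager, pageSize);

    // Pages are acquired from the off-heap TaskMemoryManager instead of a
    // privately managed pool.
    MemoryBlock page = allocator.allocate(pageSize);

    // Row addresses keep Spark's encoding (13-bit page number, 51-bit offset;
    // see the constants removed above). Encode the start of the page:
    long rowAddress = allocator.encodePageNumberAndOffset(page, page.getBaseOffset());
    // Decoding back to an absolute off-heap address is delegated to TaskMemoryManager.
    long rawAddress = allocator.getOffsetInPage(rowAddress);
    assert rawAddress == page.getBaseOffset();

    // Return the page to the TaskMemoryManager once the write/spill is done.
    allocator.free(page);
  }
}
```

Note that, unlike the removed code, there is no Comet-specific memory cap or private page table here; all accounting goes through the task's off-heap `TaskMemoryManager`, which is why `getInstance` rejects any mode other than `MemoryMode.OFF_HEAP` outside of Spark tests.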
|