Skip to content

Commit

Permalink
[CELEBORN-1847] Introduce local and DFS tier writer
Browse files Browse the repository at this point in the history
  • Loading branch information
FMX committed Jan 25, 2025
1 parent a77a64b commit e19bb65
Show file tree
Hide file tree
Showing 3 changed files with 490 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.service.deploy.worker.storage;

import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler;

public class TierWriterHelper {
public static MultipartUploadHandler getS3MultipartUploadHandler(
String bucketName,
String s3AccessKey,
String s3SecretKey,
String s3EndpointRegion,
String key,
int maxRetryies) {
return (MultipartUploadHandler)
DynConstructors.builder()
.impl(
"org.apache.celeborn.S3MultipartUploadHandler",
String.class,
String.class,
String.class,
String.class,
String.class,
Integer.class)
.build()
.newInstance(bucketName, s3AccessKey, s3SecretKey, s3EndpointRegion, key, maxRetryies);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@ package org.apache.celeborn.service.deploy.worker.storage

import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters.asScalaBufferConverter

import io.netty.buffer.{ByteBuf, CompositeByteBuf}

import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.AlreadyClosedException
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{FileInfo, MemoryFileInfo}
import org.apache.celeborn.common.meta.{DiskFileInfo, FileInfo, MemoryFileInfo}
import org.apache.celeborn.common.metrics.source.AbstractSource
import org.apache.celeborn.common.protocol.StorageInfo
import org.apache.celeborn.common.unsafe.Platform
import org.apache.celeborn.common.util.FileChannelUtils
import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler
import org.apache.celeborn.service.deploy.worker.WorkerSource
import org.apache.celeborn.service.deploy.worker.congestcontrol.{CongestionController, UserCongestionControlContext}
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager

abstract class TierWriterBase(
Expand Down Expand Up @@ -330,3 +336,275 @@ class MemoryTierWriter(
// memory tier write does not need flush tasks
}
}

class LocalTierWriter(
conf: CelebornConf,
metaHandler: PartitionMetaHandler,
numPendingWrites: AtomicInteger,
notifier: FlushNotifier,
flusher: Flusher,
source: AbstractSource,
diskFileInfo: DiskFileInfo,
storageType: StorageInfo.Type,
partitionDataWriterContext: PartitionDataWriterContext,
storageManager: StorageManager)
extends TierWriterBase(
conf,
metaHandler,
numPendingWrites,
notifier,
diskFileInfo,
source,
storageType,
partitionDataWriterContext.getPartitionLocation.getFileName,
partitionDataWriterContext.getShuffleKey,
storageManager) {
flusherBufferSize = conf.workerFlusherBufferSize
val flushWorkerIndex = flusher.getWorkerIndex
val userCongestionControlContext: UserCongestionControlContext =
if (CongestionController.instance != null)
CongestionController.instance.getUserCongestionContext(
partitionDataWriterContext.getUserIdentifier)
else
null

private val channel: FileChannel =
FileChannelUtils.createWritableFileChannel(diskFileInfo.getFilePath);

override def needEvict: Boolean = {
false
}

override def genFlushTask(finalFlush: Boolean, keepBuffer: Boolean): FlushTask = {
notifier.numPendingFlushes.incrementAndGet()
new LocalFlushTask(flushBuffer, channel, notifier, true)
}

override def writerInternal(buf: ByteBuf): Unit = {
val numBytes = buf.readableBytes()
val flushBufferReadableBytes = flushBuffer.readableBytes
if (flushBufferReadableBytes != 0 && flushBufferReadableBytes + numBytes >= flusherBufferSize) {
flush(false)
}
buf.retain()
try {
flushBuffer.addComponent(true, buf)
MemoryManager.instance.incrementDiskBuffer(numBytes)
if (userCongestionControlContext != null)
userCongestionControlContext.updateProduceBytes(numBytes)
} catch {
case oom: OutOfMemoryError =>
buf.release()
MemoryManager.instance().releaseDiskBuffer(numBytes)
throw oom;
}
}

override def evict(file: TierWriterBase): Unit = ???

override def finalFlush(): Unit = {
if (flushBuffer != null && flushBuffer.readableBytes() > 0) {
flush(true)
}
}

override def closeStreams(): Unit = {
// local disk file won't need to close streams
}

override def notifyFileCommitted(): Unit =
storageManager.notifyFileInfoCommitted(shuffleKey, filename, diskFileInfo)

override def closeResource(): Unit = {
try if (channel != null) channel.close()
catch {
case e: IOException =>
logWarning(
s"Close channel failed for file ${diskFileInfo.getFilePath} caused by ${e.getMessage}.")
}
}

override def cleanLocalOrDfsFiles(): Unit = {
diskFileInfo.deleteAllFiles(null)
}

override def takeBufferInternal(): CompositeByteBuf = {
flusher.takeBuffer()
}

override def returnBufferInternal(destroy: Boolean): Unit = {
if (flushBuffer != null) {
flusher.returnBuffer(flushBuffer, true)
flushBuffer = null
}
}

override def addFlushTask(task: FlushTask): Unit = {
if (!flusher.addTask(task, writerCloseTimeoutMs, flushWorkerIndex)) {
val e = new IOException("Add flush task timeout.")
notifier.setException(e)
throw e
}
}
}

class DfsTierWriter(
conf: CelebornConf,
metaHandler: PartitionMetaHandler,
numPendingWrites: AtomicInteger,
notifier: FlushNotifier,
flusher: Flusher,
source: AbstractSource,
hdfsFileInfo: DiskFileInfo,
storageType: StorageInfo.Type,
partitionDataWriterContext: PartitionDataWriterContext,
storageManager: StorageManager)
extends TierWriterBase(
conf,
metaHandler,
numPendingWrites,
notifier,
hdfsFileInfo,
source,
storageType,
partitionDataWriterContext.getPartitionLocation.getFileName,
partitionDataWriterContext.getShuffleKey,
storageManager) {
flusherBufferSize = conf.workerHdfsFlusherBufferSize
val flushWorkerIndex = flusher.getWorkerIndex
val hadoopFs = StorageManager.hadoopFs.get(storageType)
var deleted = false
var s3MultipartUploadHandler: MultipartUploadHandler = null
var partNumber: Int = 1

this.flusherBufferSize =
if (hdfsFileInfo.isS3()) {
conf.workerS3FlusherBufferSize
} else {
conf.workerHdfsFlusherBufferSize
}

try {
hadoopFs.create(hdfsFileInfo.getDfsPath, true).close()
if (hdfsFileInfo.isS3) {
val configuration = hadoopFs.getConf
val s3AccessKey = configuration.get("fs.s3a.access.key")
val s3SecretKey = configuration.get("fs.s3a.secret.key")
val s3EndpointRegion = configuration.get("fs.s3a.endpoint.region")

val uri = hadoopFs.getUri
val bucketName = uri.getHost
val index = hdfsFileInfo.getFilePath.indexOf(bucketName)
val key = hdfsFileInfo.getFilePath.substring(index + bucketName.length + 1)

this.s3MultipartUploadHandler = TierWriterHelper.getS3MultipartUploadHandler(
bucketName,
s3AccessKey,
s3SecretKey,
s3EndpointRegion,
key,
conf.s3MultiplePartUploadMaxRetries)
s3MultipartUploadHandler.startUpload()
}
} catch {
case _: IOException =>
try
// If create file failed, wait 10 ms and retry
Thread.sleep(10)
catch {
case ex: InterruptedException =>
throw new RuntimeException(ex)
}
hadoopFs.create(hdfsFileInfo.getDfsPath, true).close()
}

override def needEvict: Boolean = {
false
}

override def genFlushTask(finalFlush: Boolean, keepBuffer: Boolean): FlushTask = {
notifier.numPendingFlushes.incrementAndGet()
if (hdfsFileInfo.isHdfs) {
new HdfsFlushTask(flushBuffer, hdfsFileInfo.getDfsPath(), notifier, true)
} else {
val flushTask = new S3FlushTask(
flushBuffer,
notifier,
false,
s3MultipartUploadHandler,
partNumber,
finalFlush)
partNumber = partNumber + 1
flushTask
}
}

override def writerInternal(buf: ByteBuf): Unit = {
val numBytes = buf.readableBytes()
val flushBufferReadableBytes = flushBuffer.readableBytes
if (flushBufferReadableBytes != 0 && flushBufferReadableBytes + numBytes >= flusherBufferSize) {
flush(false)
}
buf.retain()
try {
flushBuffer.addComponent(true, buf)
MemoryManager.instance.incrementDiskBuffer(numBytes)
} catch {
case oom: OutOfMemoryError =>
buf.release()
MemoryManager.instance().releaseDiskBuffer(numBytes)
throw oom;
}
}

override def evict(file: TierWriterBase): Unit = ???

override def finalFlush(): Unit = {
if (flushBuffer != null && flushBuffer.readableBytes() > 0) {
flush(true)
}
}

override def closeStreams(): Unit = {
if (hadoopFs.exists(hdfsFileInfo.getDfsPeerWriterSuccessPath)) {
hadoopFs.delete(hdfsFileInfo.getDfsPath, false)
deleted = true
} else {
hadoopFs.create(hdfsFileInfo.getDfsWriterSuccessPath).close()
val indexOutputStream = hadoopFs.create(hdfsFileInfo.getDfsIndexPath)
indexOutputStream.writeInt(hdfsFileInfo.getReduceFileMeta.getChunkOffsets.size)
for (offset <- hdfsFileInfo.getReduceFileMeta.getChunkOffsets.asScala) {
indexOutputStream.writeLong(offset)
}
indexOutputStream.close()
}
}

override def notifyFileCommitted(): Unit =
storageManager.notifyFileInfoCommitted(shuffleKey, filename, hdfsFileInfo)

override def closeResource(): Unit = {}

override def cleanLocalOrDfsFiles(): Unit = {
hdfsFileInfo.deleteAllFiles(hadoopFs)
}

override def takeBufferInternal(): CompositeByteBuf = {
flusher.takeBuffer()
}

override def returnBufferInternal(destroy: Boolean): Unit = {
if (flushBuffer != null) {
flusher.returnBuffer(flushBuffer, true)
flushBuffer = null
}
}

override def addFlushTask(task: FlushTask): Unit = {
if (!flusher.addTask(task, writerCloseTimeoutMs, flushWorkerIndex)) {
val e = new IOException("Add flush task timeout.")
notifier.setException(e)
throw e
}
}
}
Loading

0 comments on commit e19bb65

Please sign in to comment.