Skip to content

Commit 40838f0

Browse files
committed
refactor(core): make the endpoint responsible for closing the frame
1 parent 711efc6 commit 40838f0

File tree

16 files changed

+244
-177
lines changed

16 files changed

+244
-177
lines changed

core/src/androidTest/java/io/github/thibaultbee/streampack/core/elements/endpoints/DummyEndpoint.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,14 @@ class DummyEndpoint : IEndpointInternal {
6363
_isOpenFlow.emit(false)
6464
}
6565

66-
override suspend fun write(frame: Frame, streamPid: Int) {
66+
override suspend fun write(frame: Frame, streamPid: Int, onFrameProcessed: (() -> Unit)) {
6767
Log.i(TAG, "write: $frame")
6868
_frameFlow.emit(frame)
6969
when {
7070
frame.isAudio -> numOfAudioFramesWritten++
7171
frame.isVideo -> numOfVideoFramesWritten++
7272
}
73+
onFrameProcessed()
7374
}
7475

7576
override fun addStreams(streamConfigs: List<CodecConfig>): Map<CodecConfig, Int> {

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/CombineEndpoint.kt

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,24 @@ import io.github.thibaultbee.streampack.core.elements.utils.combineStates
2323
import io.github.thibaultbee.streampack.core.elements.utils.extensions.intersect
2424
import io.github.thibaultbee.streampack.core.logger.Logger
2525
import io.github.thibaultbee.streampack.core.pipelines.outputs.encoding.EncodingPipelineOutputDispatcherProvider
26+
import io.github.thibaultbee.streampack.core.pipelines.utils.MultiThrowable
27+
import kotlinx.coroutines.CompletableDeferred
28+
import kotlinx.coroutines.CoroutineDispatcher
29+
import kotlinx.coroutines.CoroutineScope
30+
import kotlinx.coroutines.Deferred
31+
import kotlinx.coroutines.SupervisorJob
2632
import kotlinx.coroutines.flow.StateFlow
33+
import kotlinx.coroutines.launch
2734

2835

2936
/**
3037
* Combines multiple endpoints into one.
3138
*
3239
* @param endpoints Endpoints to combine
40+
* @param coroutineDispatcher Coroutine dispatcher to use for frame writing
3341
*/
34-
fun CombineEndpoint(vararg endpoints: IEndpointInternal) = CombineEndpoint(endpoints.toList())
42+
fun CombineEndpoint(vararg endpoints: IEndpointInternal, coroutineDispatcher: CoroutineDispatcher) =
43+
CombineEndpoint(endpoints.toList(), coroutineDispatcher)
3544

3645
/**
3746
* Combines multiple endpoints into one.
@@ -42,9 +51,17 @@ fun CombineEndpoint(vararg endpoints: IEndpointInternal) = CombineEndpoint(endpo
4251
*
4352
* For specific behavior like reconnecting your remote endpoint, you can create a custom endpoint that
4453
* inherits from [CombineEndpoint] and override [open], [close], [startStream], [stopStream].
54+
*
55+
* @param endpointInternals List of endpoints to combine
56+
* @param coroutineDispatcher Coroutine dispatcher to use for frame writing
4557
*/
46-
open class CombineEndpoint(protected val endpointInternals: List<IEndpointInternal>) :
58+
open class CombineEndpoint(
59+
protected val endpointInternals: List<IEndpointInternal>,
60+
coroutineDispatcher: CoroutineDispatcher
61+
) :
4762
IEndpointInternal {
63+
private val coroutineScope = CoroutineScope(SupervisorJob() + coroutineDispatcher)
64+
4865
/**
4966
* Internal map of endpoint streamId to real streamIds
5067
*/
@@ -185,31 +202,36 @@ open class CombineEndpoint(protected val endpointInternals: List<IEndpointIntern
185202
*
186203
* If all endpoints write fails, it throws the exception of the first endpoint that failed.
187204
*/
188-
override suspend fun write(frame: Frame, streamPid: Int) {
189-
val currentBufferPos = frame.buffer.position()
190-
var numOfThrowable = 0
191-
var throwable: Throwable? = null
205+
override suspend fun write(frame: Frame, streamPid: Int, onFrameProcessed: (() -> Unit)) {
206+
val throwables = mutableListOf<Throwable>()
192207

193-
endpointInternals.forEach { endpoint ->
208+
/**
209+
* Track the number of frames written and processed to call onFrameProcessed only once
210+
* when all endpoints have processed the frame.
211+
*/
212+
val deferreds = mutableListOf<Deferred<*>>()
213+
214+
endpointInternals.filter { it.isOpenFlow.value }.forEach { endpoint ->
194215
try {
195-
if (endpoint.isOpenFlow.value) {
196-
val endpointStreamId = endpointsToStreamIdsMap[Pair(endpoint, streamPid)]!!
197-
endpoint.write(frame, endpointStreamId)
216+
val deferred = CompletableDeferred<Unit>()
217+
val duplicatedFrame = frame.copy(rawBuffer = frame.rawBuffer.duplicate())
218+
val endpointStreamId = endpointsToStreamIdsMap[Pair(endpoint, streamPid)]!!
198219

199-
// Reset buffer position to write frame to next endpoint
200-
frame.buffer.position(currentBufferPos)
201-
}
220+
deferreds += deferred
221+
endpoint.write(duplicatedFrame, endpointStreamId, { deferred.complete(Unit) })
202222
} catch (t: Throwable) {
203-
Logger.e(TAG, "Failed to write frame to endpoint $endpoint", t)
204-
if (throwable == null) {
205-
throwable = t
206-
}
207-
numOfThrowable++
223+
Logger.e(TAG, "Failed to get stream id for endpoint $endpoint", t)
224+
throwables += t
208225
}
209226
}
210227

211-
if (numOfThrowable == endpointInternals.size) {
212-
throw throwable!!
228+
coroutineScope.launch {
229+
deferreds.forEach { it.await() }
230+
onFrameProcessed()
231+
}
232+
233+
if (throwables.isNotEmpty()) {
234+
throw MultiThrowable(throwables)
213235
}
214236
}
215237

@@ -245,7 +267,8 @@ class CombineEndpointFactory(private val endpointFactory: List<IEndpointInternal
245267
dispatcherProvider: EncodingPipelineOutputDispatcherProvider
246268
): IEndpointInternal {
247269
return CombineEndpoint(
248-
endpointFactory.map { it.create(context, dispatcherProvider) }
270+
endpointFactory.map { it.create(context, dispatcherProvider) },
271+
dispatcherProvider.defaultDispatcher
249272
)
250273
}
251274
}

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/DualEndpoint.kt

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package io.github.thibaultbee.streampack.core.elements.endpoints
1818
import android.content.Context
1919
import io.github.thibaultbee.streampack.core.configuration.mediadescriptor.MediaDescriptor
2020
import io.github.thibaultbee.streampack.core.pipelines.outputs.encoding.EncodingPipelineOutputDispatcherProvider
21+
import kotlinx.coroutines.CoroutineDispatcher
2122

2223
/**
2324
* An implementation of [CombineEndpoint] that combines two endpoints.
@@ -29,11 +30,15 @@ import io.github.thibaultbee.streampack.core.pipelines.outputs.encoding.Encoding
2930
*
3031
* @param mainEndpoint the main endpoint
3132
* @param secondEndpoint the second endpoint
33+
* @param coroutineDispatcher the coroutine dispatcher
3234
*/
3335
open class DualEndpoint(
34-
private val mainEndpoint: IEndpointInternal, private val secondEndpoint: IEndpointInternal
36+
private val mainEndpoint: IEndpointInternal,
37+
private val secondEndpoint: IEndpointInternal,
38+
coroutineDispatcher: CoroutineDispatcher
3539
) : CombineEndpoint(
36-
listOf(secondEndpoint, mainEndpoint)
40+
listOf(secondEndpoint, mainEndpoint),
41+
coroutineDispatcher
3742
) {
3843
/**
3944
* Opens the [mainEndpoint].
@@ -91,7 +96,8 @@ class DualEndpointFactory(
9196
): IEndpointInternal {
9297
return DualEndpoint(
9398
mainEndpointFactory.create(context, dispatcherProvider),
94-
secondEndpointFactory.create(context, dispatcherProvider)
99+
secondEndpointFactory.create(context, dispatcherProvider),
100+
dispatcherProvider.defaultDispatcher
95101
)
96102
}
97103
}

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/DynamicEndpoint.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ open class DynamicEndpoint(
9898

9999
override fun addStream(streamConfig: CodecConfig) = endpoint.addStream(streamConfig)
100100

101-
override suspend fun write(frame: Frame, streamPid: Int) = endpoint.write(frame, streamPid)
101+
override suspend fun write(frame: Frame, streamPid: Int, onFrameProcessed: () -> Unit) =
102+
endpoint.write(frame, streamPid, onFrameProcessed)
102103

103104
override suspend fun startStream() = endpoint.startStream()
104105

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/IEndpoint.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,18 @@ interface IEndpointInternal : IEndpoint, SuspendStreamable,
3939
/**
4040
* Writes a [Frame] to the [IEndpointInternal].
4141
*
42+
* The [onFrameProcessed] callback must be called when the frame has been processed and the [Frame.rawBuffer] is not used anymore.
43+
* The [IEndpointInternal] must called [onFrameProcessed] even if the frame is dropped or it somehow crashes.
44+
* Also, once [onFrameProcessed] is called, the [Frame.rawBuffer] must not be used anymore by the [IEndpointInternal].
45+
*
4246
* @param frame the [Frame] to write
4347
* @param streamPid the stream id the [Frame] belongs to
48+
* @param onFrameProcessed a callback called when the [Frame.rawBuffer] is not used anymore
4449
*/
4550
suspend fun write(
4651
frame: Frame,
47-
streamPid: Int
52+
streamPid: Int,
53+
onFrameProcessed: (() -> Unit)
4854
)
4955

5056
/**

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/MediaMuxerEndpoint.kt

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -140,40 +140,47 @@ class MediaMuxerEndpoint(
140140
}
141141

142142
override suspend fun write(
143-
frame: Frame, streamPid: Int
143+
frame: Frame, streamPid: Int, onFrameProcessed: () -> Unit
144144
) = withContext(ioDispatcher) {
145145
mutex.withLock {
146-
if (state != State.STARTED && state != State.PENDING_START) {
147-
Logger.w(TAG, "Trying to write while not started. Current state: $state")
148-
return@withContext
149-
}
146+
try {
147+
if (state != State.STARTED && state != State.PENDING_START) {
148+
Logger.w(TAG, "Trying to write while not started. Current state: $state")
149+
return@withContext
150+
}
150151

151-
val mediaMuxer = requireNotNull(mediaMuxer) { "MediaMuxer is not initialized" }
152+
val mediaMuxer = requireNotNull(mediaMuxer) { "MediaMuxer is not initialized" }
152153

153-
if ((state == State.PENDING_START) && (streamIdToTrackId.size < numOfStreams)) {
154-
addTrack(mediaMuxer, streamPid, frame.format)
155-
if (streamIdToTrackId.size == numOfStreams) {
156-
mediaMuxer.start()
157-
setState(State.STARTED)
154+
if ((state == State.PENDING_START) && (streamIdToTrackId.size < numOfStreams)) {
155+
addTrack(mediaMuxer, streamPid, frame.format)
156+
if (streamIdToTrackId.size == numOfStreams) {
157+
mediaMuxer.start()
158+
setState(State.STARTED)
159+
}
158160
}
159-
}
160161

161-
if (state == State.STARTED) {
162-
val trackId = streamIdToTrackId[streamPid]
163-
?: throw IllegalStateException("Could not find trackId for streamPid $streamPid: ${frame.format}")
164-
val info = BufferInfo().apply {
165-
set(
166-
0,
167-
frame.buffer.remaining(),
168-
frame.ptsInUs,
169-
if (frame.isKeyFrame) BUFFER_FLAG_KEY_FRAME else 0
170-
)
171-
}
172-
try {
173-
mediaMuxer.writeSampleData(trackId, frame.buffer, info)
174-
} catch (e: IllegalStateException) {
175-
Logger.w(TAG, "MediaMuxer is in an illegal state. ${e.message}")
162+
if (state == State.STARTED) {
163+
val trackId = streamIdToTrackId[streamPid]
164+
?: throw IllegalStateException("Could not find trackId for streamPid $streamPid: ${frame.format}")
165+
val info = BufferInfo().apply {
166+
set(
167+
0,
168+
frame.buffer.remaining(),
169+
frame.ptsInUs,
170+
if (frame.isKeyFrame) BUFFER_FLAG_KEY_FRAME else 0
171+
)
172+
}
173+
try {
174+
mediaMuxer.writeSampleData(trackId, frame.buffer, info)
175+
} catch (e: IllegalStateException) {
176+
Logger.w(TAG, "MediaMuxer is in an illegal state. ${e.message}")
177+
}
176178
}
179+
} catch (t: Throwable) {
180+
Logger.e(TAG, "Error while writing frame: ${t.message}")
181+
throw t
182+
} finally {
183+
onFrameProcessed()
177184
}
178185
}
179186
}

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/composites/CompositeEndpoint.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ class CompositeEndpoint(
7676

7777
override suspend fun write(
7878
frame: Frame,
79-
streamPid: Int
80-
) = muxer.write(frame, streamPid)
79+
streamPid: Int,
80+
onFrameProcessed: (() -> Unit)
81+
) = muxer.write(frame, streamPid, onFrameProcessed)
8182

8283
override fun addStreams(streamConfigs: List<CodecConfig>): Map<CodecConfig, Int> {
8384
mutex.tryLock()

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/composites/muxers/IMuxerInternal.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ interface IMuxerInternal :
3232
fun onOutputFrame(packet: Packet)
3333
}
3434

35-
fun write(frame: Frame, streamPid: Int)
35+
fun write(frame: Frame, streamPid: Int, onFrameProcessed: () -> Unit)
3636

3737
fun addStreams(streamsConfig: List<CodecConfig>): Map<CodecConfig, Int>
3838

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/composites/muxers/flv/FlvMuxer.kt

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,42 +41,46 @@ class FlvMuxer(
4141
override val streamConfigs: List<CodecConfig>
4242
get() = streams.map { it.config }
4343

44-
override fun write(frame: Frame, streamPid: Int) {
44+
override fun write(frame: Frame, streamPid: Int, onFrameProcessed: () -> Unit) {
4545
synchronized(this) {
46-
if (!hasFirstFrame) {
47-
/**
48-
* Wait for first video frame to start (only if video is present)
49-
*/
50-
if (hasVideo) {
51-
// Expected first video key frame
52-
if (frame.isVideo && frame.isKeyFrame) {
46+
try {
47+
if (!hasFirstFrame) {
48+
/**
49+
* Wait for first video frame to start (only if video is present)
50+
*/
51+
if (hasVideo) {
52+
// Expected first video key frame
53+
if (frame.isVideo && frame.isKeyFrame) {
54+
startUpTime = frame.ptsInUs
55+
hasFirstFrame = true
56+
} else {
57+
// Drop
58+
return
59+
}
60+
} else {
61+
// Audio only
5362
startUpTime = frame.ptsInUs
5463
hasFirstFrame = true
55-
} else {
56-
// Drop
57-
return
5864
}
59-
} else {
60-
// Audio only
61-
startUpTime = frame.ptsInUs
62-
hasFirstFrame = true
6365
}
64-
}
65-
}
6666

67-
if (frame.ptsInUs < startUpTime!!) {
68-
return
69-
}
67+
if (frame.ptsInUs < startUpTime!!) {
68+
return
69+
}
7070

71-
frame.ptsInUs -= startUpTime!!
72-
val stream = streams[streamPid]
73-
val sendHeader = stream.sendHeader
74-
stream.sendHeader = false
75-
val flvTags = AVTagsFactory(frame, stream.config, sendHeader).build()
76-
flvTags.forEach {
77-
listener?.onOutputFrame(
78-
Packet(it.write(), frame.ptsInUs)
79-
)
71+
frame.ptsInUs -= startUpTime!!
72+
val stream = streams[streamPid]
73+
val sendHeader = stream.sendHeader
74+
stream.sendHeader = false
75+
val flvTags = AVTagsFactory(frame, stream.config, sendHeader).build()
76+
flvTags.forEach {
77+
listener?.onOutputFrame(
78+
Packet(it.write(), frame.ptsInUs)
79+
)
80+
}
81+
} finally {
82+
onFrameProcessed()
83+
}
8084
}
8185
}
8286

core/src/main/java/io/github/thibaultbee/streampack/core/elements/endpoints/composites/muxers/mp4/Mp4Muxer.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,17 @@ class Mp4Muxer(
5656

5757
override val streamConfigs: List<CodecConfig>
5858
get() = tracks.map { it.config }
59-
60-
override fun write(frame: Frame, streamPid: Int) {
59+
60+
override fun write(frame: Frame, streamPid: Int, onFrameProcessed: () -> Unit) {
6161
synchronized(this) {
62-
if (segmenter!!.mustWriteSegment(frame)) {
63-
writeSegment()
62+
try {
63+
if (segmenter!!.mustWriteSegment(frame)) {
64+
writeSegment()
65+
}
66+
currentSegment!!.add(frame, streamPid)
67+
} finally {
68+
onFrameProcessed()
6469
}
65-
currentSegment!!.add(frame, streamPid)
6670
}
6771
}
6872

0 commit comments

Comments
 (0)