Skip to content

Commit

Permalink
KAFKA-1683; persisting session information in Requests
Browse files Browse the repository at this point in the history
Author: Gwen Shapira <cshapi@gmail.com>

Reviewers: Sriharsha Chintalapa, Ismael Juma, Edward Ribeiro, Parth Brahmbhatt, Jun Rao

Closes apache#155 from gwenshap/KAFKA-1683
  • Loading branch information
gwenshap committed Aug 26, 2015
1 parent 436b7dd commit 8b538d6
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class PlaintextTransportLayer implements TransportLayer {
private static final Logger log = LoggerFactory.getLogger(PlaintextTransportLayer.class);
private final SelectionKey key;
private final SocketChannel socketChannel;
private final Principal principal = new KafkaPrincipal("ANONYMOUS");
private final Principal principal = KafkaPrincipal.ANONYMOUS;

public PlaintextTransportLayer(SelectionKey key) throws IOException {
this.key = key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLPeerUnverifiedException;

import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -600,7 +601,8 @@ public Principal peerPrincipal() throws IOException {
try {
return sslEngine.getSession().getPeerPrincipal();
} catch (SSLPeerUnverifiedException se) {
throw new IOException(String.format("Unable to retrieve getPeerPrincipal due to %s", se));
log.warn("SSL peer is not authenticated, returning ANONYMOUS instead");
return KafkaPrincipal.ANONYMOUS;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ private void close(KafkaChannel channel) {
this.sensors.connectionClosed.record();
}


/**
* check if channel is ready
*/
Expand All @@ -475,9 +476,11 @@ public boolean isChannelReady(String id) {
}

/**
* Get the channel associated with this numeric id
* Get the channel associated with this connection
* Exposing this to allow SocketServer get the Principal from the channel when creating a request
* without making Selector know about Principals
*/
private KafkaChannel channelForId(String id) {
public KafkaChannel channelForId(String id) {
KafkaChannel channel = this.channels.get(id);
if (channel == null)
throw new IllegalStateException("Attempt to write to socket for which there is no open connection. Connection id " + id + " existing connections " + channels.keySet().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.security.Principal;

public class KafkaPrincipal implements Principal {
public final static KafkaPrincipal ANONYMOUS = new KafkaPrincipal("ANONYMOUS");
private final String name;

public KafkaPrincipal(String name) {
Expand Down
16 changes: 10 additions & 6 deletions core/src/main/scala/kafka/network/RequestChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package kafka.network

import java.nio.ByteBuffer
import java.security.Principal
import java.util.concurrent._

import com.yammer.metrics.core.Gauge
Expand All @@ -29,11 +30,12 @@ import kafka.utils.{Logging, SystemTime}
import org.apache.kafka.common.network.Send
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
import org.apache.kafka.common.requests.{AbstractRequest, RequestHeader}
import org.apache.kafka.common.security.auth.KafkaPrincipal
import org.apache.log4j.Logger


object RequestChannel extends Logging {
val AllDone = new Request(processor = 1, connectionId = "2", buffer = getShutdownReceive(), startTimeMs = 0, securityProtocol = SecurityProtocol.PLAINTEXT)
val AllDone = new Request(processor = 1, connectionId = "2", new Session(KafkaPrincipal.ANONYMOUS, ""), buffer = getShutdownReceive(), startTimeMs = 0, securityProtocol = SecurityProtocol.PLAINTEXT)

def getShutdownReceive() = {
val emptyProducerRequest = new ProducerRequest(0, 0, "", 0, 0, collection.mutable.Map[TopicAndPartition, ByteBufferMessageSet]())
Expand All @@ -44,7 +46,9 @@ object RequestChannel extends Logging {
byteBuffer
}

case class Request(processor: Int, connectionId: String, private var buffer: ByteBuffer, startTimeMs: Long, securityProtocol: SecurityProtocol) {
case class Session(principal: Principal, host: String)

case class Request(processor: Int, connectionId: String, session: Session, private var buffer: ByteBuffer, startTimeMs: Long, securityProtocol: SecurityProtocol) {
@volatile var requestDequeueTimeMs = -1L
@volatile var apiLocalCompleteTimeMs = -1L
@volatile var responseCompleteTimeMs = -1L
Expand Down Expand Up @@ -113,11 +117,11 @@ object RequestChannel extends Logging {
}

if(requestLogger.isTraceEnabled)
requestLogger.trace("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d"
.format(requestDesc, connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime))
requestLogger.trace("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s"
.format(requestDesc, connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal))
else if(requestLogger.isDebugEnabled)
requestLogger.debug("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d"
.format(requestDesc, connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime))
requestLogger.debug("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s"
.format(requestDesc, connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal))
}
}

Expand Down
5 changes: 4 additions & 1 deletion core/src/main/scala/kafka/network/SocketServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ private[kafka] class Processor(val id: Int,
}
collection.JavaConversions.collectionAsScalaIterable(selector.completedReceives).foreach(receive => {
try {
val req = RequestChannel.Request(processor = id, connectionId = receive.source, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol)

val channel = selector.channelForId(receive.source);
val session = RequestChannel.Session(channel.principal, channel.socketDescription)
val req = RequestChannel.Request(processor = id, connectionId = receive.source, session = session, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol)
requestChannel.sendRequest(req)
} catch {
case e @ (_: InvalidRequestException | _: SchemaException) => {
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/kafka/server/KafkaApis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class KafkaApis(val requestChannel: RequestChannel,
*/
def handle(request: RequestChannel.Request) {
try{
trace("Handling request: " + request.requestObj + " from connection: " + request.connectionId)
trace("Handling request:%s from connection %s;securityProtocol:%s,principal:%s".
format(request.requestObj, request.connectionId, request.securityProtocol, request.session.principal))
request.requestId match {
case RequestKeys.ProduceKey => handleProducerRequest(request)
case RequestKeys.FetchKey => handleFetchRequest(request)
Expand Down
19 changes: 10 additions & 9 deletions core/src/test/scala/unit/kafka/network/SocketServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,15 @@

package kafka.network;

import java.io._

import java.net._
import javax.net.ssl._
import java.io._
import java.nio.ByteBuffer
import java.util.Random

import kafka.api.ProducerRequest
import kafka.cluster.EndPoint
import kafka.common.TopicAndPartition
import kafka.message.ByteBufferMessageSet
import kafka.producer.SyncProducerConfig
import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.network.NetworkSend
import org.apache.kafka.common.protocol.SecurityProtocol
import org.apache.kafka.common.security.auth.KafkaPrincipal
import org.apache.kafka.common.utils.SystemTime
import org.junit.Assert._
import org.junit._
Expand All @@ -43,7 +37,6 @@ import java.nio.ByteBuffer
import kafka.common.TopicAndPartition
import kafka.message.ByteBufferMessageSet
import kafka.server.KafkaConfig
import java.nio.channels.SelectionKey
import kafka.utils.TestUtils

import scala.collection.Map
Expand Down Expand Up @@ -230,4 +223,12 @@ class SocketServerTest extends JUnitSuite {
assertEquals(serializedBytes.toSeq, receiveResponse(sslSocket).toSeq)
overrideServer.shutdown()
}

@Test
def testSessionPrincipal(): Unit = {
val socket = connect()
val bytes = new Array[Byte](40)
sendRequest(socket, 0, bytes)
assertEquals(KafkaPrincipal.ANONYMOUS, server.requestChannel.receiveRequest().session.principal)
}
}

0 comments on commit 8b538d6

Please sign in to comment.