Skip to content

Commit

Permalink
chore: add support for dimensions parameter to OpenAIEmbedding (#2215)
Browse files Browse the repository at this point in the history
* adding support for dimensions parameter to OpenAIEmbedding

* fixing scalastyle error

* fix service to use for testing embedding dimensions

---------

Co-authored-by: Scott Graham <scgraham@microsoft.com>
Co-authored-by: Mark Hamilton <mhamilton723@gmail.com>
  • Loading branch information
3 people authored May 1, 2024
1 parent edb8c54 commit d0a2161
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

}

/** Parameters that apply only to the OpenAI embeddings endpoint.
  *
  * Adds the optional `dimensions` service parameter on top of the shared OpenAI parameters and
  * exposes the optional values that are set for a row as a request-ready map.
  */
trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion {

  /** Optional service parameter: number of dimensions for the output embeddings. */
  val dimensions: ServiceParam[Int] = new ServiceParam[Int](
    this, "dimensions", "Number of dimensions for output embeddings.", isRequired = false)

  def getDimensions: Int = getScalarParam(dimensions)

  def setDimensions(value: Int): this.type = setScalarParam(dimensions, value)

  /** Collects the optional embedding parameters that have a value for the given row.
    *
    * @param r the row the parameter values are resolved against
    * @return a map from the snake_case form of each parameter name to its value; parameters for
    *         which `getValueOpt` yields nothing are left out
    */
  private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
    // Every optional parameter of the embeddings request is listed here.
    val optionalServiceParams = Seq(dimensions)
    val resolved = for {
      serviceParam <- optionalServiceParams
      value <- getValueOpt(r, serviceParam)
    } yield GenerationUtils.camelToSnake(serviceParam.name) -> value
    resolved.toMap
  }
}

trait HasOpenAITextParams extends HasOpenAISharedParams {

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.ServiceParam
Expand All @@ -21,7 +22,7 @@ import scala.language.existentials
// Companion object of OpenAIEmbedding; extends ComplexParamsReadable[OpenAIEmbedding]
// (presumably to supply read/load support for saved instances — confirm against ComplexParamsReadable).
object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))
Expand Down Expand Up @@ -61,10 +62,16 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
s"${getUrl}openai/deployments/${getValue(row, deploymentName)}/embeddings"
}

/** Serializes the embeddings request body as a JSON string entity.
  *
  * @param text           value sent under the "input" key of the payload
  * @param optionalParams additional payload entries; an existing "input" entry is replaced by `text`
  * @return a `StringEntity` holding the compact JSON payload with content type application/json
  */
private[this] def getStringEntity[A](text: A, optionalParams: Map[String, Any]): StringEntity = {
  val requestJson = optionalParams.updated("input", text).toJson.compactPrint
  new StringEntity(requestJson, ContentType.APPLICATION_JSON)
}

/** Builds the HTTP request entity for a single input row.
  *
  * Reads the `text` value from the row and turns it into a JSON request body.
  * Throws `IllegalArgumentException("Please set textCol.")` when no text value is available.
  */
override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
// Lazy: the optional parameters are only resolved if the mapping function below actually runs.
lazy val optionalParams: Map[String, Any] = getOptionalParams(r)
getValueOpt(r, text)
// NOTE(review): this line and the next both map over the same Option; they look like the
// before/after sides of the change that introduced getStringEntity — confirm only one should remain.
.map(text => new StringEntity(Map("input" -> text).toJson.compactPrint, ContentType.APPLICATION_JSON))
.map(text => getStringEntity(text, optionalParams))
// Option.orElse takes its argument by name, so the throw only fires when no value was found.
.orElse(throw new IllegalArgumentException("Please set textCol."))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import org.scalactic.Equality

// Shared OpenAI connection settings: API keys, service names, and deployment/model names for
// the default service and for a second, GPT-4 service.
trait OpenAIAPIKey {
// Taken from the OPENAI_API_KEY environment variable when set, otherwise Secrets.OpenAIApiKey.
lazy val openAIAPIKey: String = sys.env.getOrElse("OPENAI_API_KEY", Secrets.OpenAIApiKey)
// NOTE(review): openAIServiceName is declared twice (hard-coded, then overridable through
// OPENAI_SERVICE_NAME) — these look like the old and new sides of a diff; confirm which one stays.
lazy val openAIServiceName: String = "synapseml-openai"
lazy val openAIServiceName: String = sys.env.getOrElse("OPENAI_SERVICE_NAME", "synapseml-openai")
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
// Taken from the OPENAI_API_KEY_2 environment variable when set, otherwise Secrets.OpenAIApiKeyGpt4.
lazy val openAIAPIKeyGpt4: String = sys.env.getOrElse("OPENAI_API_KEY_2", Secrets.OpenAIApiKeyGpt4)
// NOTE(review): same duplication as above, here with the OPENAI_SERVICE_NAME_2 override.
lazy val openAIServiceNameGpt4: String = "synapseml-openai-2"
lazy val openAIServiceNameGpt4: String = sys.env.getOrElse("OPENAI_SERVICE_NAME_2", "synapseml-openai-2")
lazy val deploymentNameGpt4: String = "gpt-4"
lazy val modelNameGpt4: String = "gpt-4"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope
})
}

// Embedding transformer that exercises the optional `dimensions` parameter: it targets the
// "text-embedding-3-small" deployment on the GPT-4 service/key pair and asks for 100 dimensions.
// Reads the "text" column and writes the result to "out".
lazy val embeddingExtra: OpenAIEmbedding = new OpenAIEmbedding()
.setSubscriptionKey(openAIAPIKeyGpt4)
.setDeploymentName("text-embedding-3-small")
.setApiVersion("2024-03-01-preview")
.setDimensions(100)
.setUser("testUser")
.setCustomServiceName(openAIServiceNameGpt4)
.setTextCol("text")
.setOutputCol("out")

// Every embedding produced by `embeddingExtra` must have exactly the requested 100 dimensions.
test("Extra Params Usage") {
  val embeddedRows = embeddingExtra.transform(df).collect()
  embeddedRows.foreach { row =>
    val embeddingVector = row.getAs[Vector]("out")
    assert(embeddingVector.size == 100)
  }
}


// Supplies the (transformer, DataFrame) pairs for this suite — presumably consumed by the
// TransformerFuzzing harness; confirm. Only `embedding` is listed; `embeddingExtra` is not included.
override def testObjects(): Seq[TestObject[OpenAIEmbedding]] =
Seq(new TestObject(embedding, df))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ from synapse.ml.core.platform import find_secret
service_name = "synapseml-openai"
deployment_name = "gpt-35-turbo"
deployment_name_embeddings = "text-embedding-ada-002"
deployment_name_embeddings_3 = "text-embedding-3-small"

key = find_secret(
secret_name="openai-api-key", keyvault="mmlspark-build-keys"
Expand Down Expand Up @@ -132,6 +133,29 @@ embedding = (
display(embedding.transform(df))
```

### Generating Text Embeddings with Reduced Dimensions

The Text-Embedding-3 models developed by OpenAI are trained with Matryoshka Representation Learning,
a technique that allows the embedding dimension to be reduced in exchange for some loss in performance.

```python
from synapse.ml.services.openai import OpenAIEmbedding

# Configure an embedding transformer that requests reduced-size embeddings from the
# text-embedding-3 deployment.
embedding = (
OpenAIEmbedding()
.setSubscriptionKey(key)
.setDeploymentName(deployment_name_embeddings_3)
.setCustomServiceName(service_name)
.setApiVersion("2024-03-01-preview")
# Number of dimensions for the output embeddings.
.setDimensions(256)
.setTextCol("prompt")
.setErrorCol("error")
.setOutputCol("embeddings")
)

# Run the transformer over the DataFrame and show the result.
display(embedding.transform(df))
```

### Chat Completion

Models such as ChatGPT and GPT-4 are capable of understanding chats instead of single prompts. The `OpenAIChatCompletion` transformer exposes this functionality at scale.
Expand Down

0 comments on commit d0a2161

Please sign in to comment.