Skip to content

Moderations #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add moderations endpoint
  • Loading branch information
CJCrafter committed Feb 29, 2024
commit c41eeeb5351a21a336f69a6d3b33a878bf520eb5
14 changes: 14 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.message.TextAnnotation
import com.cjcrafter.openai.util.OpenAIDslMarker
Expand Down Expand Up @@ -135,6 +136,19 @@ interface OpenAI {
@Contract(pure = true)
fun files(): FileHandler = files

/**
* Returns the handler for the moderations endpoint. This handler can be used
* to create moderations.
*/
val moderations: ModerationHandler

/**
* Returns the handler for the moderations endpoint. This method is purely
* syntactic sugar for Java users.
*/
@Contract(pure = true)
fun moderations(): ModerationHandler = moderations

/**
* Returns the handler for the assistants endpoint. This handler can be used
* to create, retrieve, and delete assistants.
Expand Down
25 changes: 16 additions & 9 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.embeddings.EmbeddingsRequest
import com.cjcrafter.openai.embeddings.EmbeddingsResponse
import com.cjcrafter.openai.files.*
import com.cjcrafter.openai.moderations.ModerationHandler
import com.cjcrafter.openai.moderations.ModerationHandlerImpl
import com.cjcrafter.openai.threads.ThreadHandler
import com.cjcrafter.openai.threads.ThreadHandlerImpl
import com.fasterxml.jackson.databind.JavaType
Expand Down Expand Up @@ -127,23 +129,28 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
return requestHelper.executeRequest(httpRequest, EmbeddingsResponse::class.java)
}

private var files0: FileHandlerImpl? = null
override val files: FileHandler
get() = files0 ?: FileHandlerImpl(requestHelper, FILES_ENDPOINT).also { files0 = it }
override val files: FileHandler by lazy {
FileHandlerImpl(requestHelper, FILES_ENDPOINT)
}

private var assistants0: AssistantHandlerImpl? = null
override val assistants: AssistantHandler
get() = assistants0 ?: AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT).also { assistants0 = it }
override val moderations: ModerationHandler by lazy {
ModerationHandlerImpl(requestHelper, MODERATIONS_ENDPOINT)
}

private var threads0: ThreadHandlerImpl? = null
override val threads: ThreadHandler
get() = threads0 ?: ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT).also { threads0 = it }
override val assistants: AssistantHandler by lazy {
AssistantHandlerImpl(requestHelper, ASSISTANTS_ENDPOINT)
}

override val threads: ThreadHandler by lazy {
ThreadHandlerImpl(requestHelper, THREADS_ENDPOINT)
}

companion object {
const val COMPLETIONS_ENDPOINT = "v1/completions"
const val CHAT_ENDPOINT = "v1/chat/completions"
const val EMBEDDINGS_ENDPOINT = "v1/embeddings"
const val FILES_ENDPOINT = "v1/files"
const val MODERATIONS_ENDPOINT = "v1/moderations"
const val ASSISTANTS_ENDPOINT = "v1/assistants"
const val THREADS_ENDPOINT = "v1/threads"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.util.OpenAIDslMarker

/**
* Represents a request to create a new moderation request.
*
* @property input The input to moderate
* @property model The model to use for moderation
*/
data class CreateModerationRequest internal constructor(
var input: Any,
var model: String? = null
) {

@OpenAIDslMarker
class Builder internal constructor() {
private var input: Any? = null
private var model: String? = null

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: String) = apply { this.input = input }

/**
* Sets the input to moderate.
*
* @param input The input to moderate
*/
fun input(input: List<String>) = apply { this.input = input }

/**
* Sets the model to use for moderation.
*
* @param model The model to use for moderation
*/
fun model(model: String) = apply { this.model = model }

/**
* Builds the [CreateModerationRequest] instance.
*/
fun build(): CreateModerationRequest {
return CreateModerationRequest(
input = input ?: throw IllegalStateException("input must be defined to use CreateModerationRequest"),
model = model
)
}
}

companion object {
/**
* Returns a builder to construct a [CreateModerationRequest] instance.
*/
@JvmStatic
fun builder() = Builder()
}
}
31 changes: 31 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/moderations/Moderation.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.cjcrafter.openai.moderations

import com.fasterxml.jackson.annotation.JsonProperty

/**
* A moderation object returned by the moderations api.
*
* @property id The id of the moderation request. Always starts with "modr-".
* @property model The model which was used to moderate the content.
* @property results The results of the moderation request.
* @constructor Create empty Moderation
*/
data class Moderation(
@JsonProperty(required = true) val id: String,
@JsonProperty(required = true) val model: String,
@JsonProperty(required = true) val results: Results,
) {
/**
* The results of the moderation request.
*
* @property flagged If any categories were flagged.
* @property categories The categories that were flagged.
* @property categoryScores The scores of each category.
* @constructor Create empty Results
*/
data class Results(
@JsonProperty(required = true) val flagged: Boolean,
@JsonProperty(required = true) val categories: Map<String, Boolean>,
@JsonProperty("category_scores", required = true) val categoryScores: Map<String, Double>,
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.cjcrafter.openai.moderations

/**
* Handler used to interact with [Moderation] objects.
*/
interface ModerationHandler {

/**
* Creates a new moderation request with the given options.
*
* @param request The values of the moderation to create
* @return The created moderation
*/
fun create(request: CreateModerationRequest): Moderation
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.cjcrafter.openai.moderations

import com.cjcrafter.openai.RequestHelper

class ModerationHandlerImpl(
private val requestHelper: RequestHelper,
private val endpoint: String,
): ModerationHandler {
override fun create(request: CreateModerationRequest): Moderation {
val httpRequest = requestHelper.buildRequest(request, endpoint).build()
return requestHelper.executeRequest(httpRequest, Moderation::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package com.cjcrafter.openai.chat

import com.cjcrafter.openai.MockedTest
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
import com.cjcrafter.openai.chat.tool.FunctionToolCall
import com.cjcrafter.openai.chat.tool.Tool
import com.cjcrafter.openai.chat.tool.ToolCall
import okhttp3.mockwebserver.MockResponse
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -46,9 +49,9 @@ class MockedChatStreamTest : MockedTest() {

// Assertions
assertEquals(ChatUser.ASSISTANT, toolMessage.role, "Tool call should be from the assistant")
assertEquals(ToolType.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", toolMessage.toolCalls?.get(0)?.function?.name)
assertEquals("3/2", toolMessage.toolCalls?.get(0)?.function?.tryParseArguments()?.get("equation")?.asText())
assertEquals(Tool.Type.FUNCTION, toolMessage.toolCalls?.get(0)?.type, "Tool call should be a function")
assertEquals("solve_math_problem", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.name)
assertEquals("3/2", (toolMessage.toolCalls?.get(0) as? FunctionToolCall)?.function?.tryParseArguments()?.get("equation")?.asText())

assertEquals(ChatUser.ASSISTANT, message.role, "Message should be from the assistant")
assertEquals("The result of 3 divided by 2 is 1.5.", message.content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"d\\\"}\"}" // d is not a valid enum
Expand All @@ -37,7 +37,7 @@ class FunctionCallTest {
name("enum_checker")
description("This function is used to test the enum parameter")
addEnumParameter("enum", mutableListOf("a", "b", "c"))
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"a\\\"}\"}" // a is a valid enum
Expand All @@ -55,7 +55,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": \\\"not an integer\\\"}\"}" // not an integer
Expand All @@ -73,7 +73,7 @@ class FunctionCallTest {
name("integer_checker")
description("This function is used to test the integer parameter")
addIntegerParameter("integer", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": 1}\"}" // 1 is an integer
Expand All @@ -91,7 +91,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"boolean\\\": \\\"not a boolean\\\"}\"}" // not a boolean
Expand All @@ -109,7 +109,7 @@ class FunctionCallTest {
name("boolean_checker")
description("This function is used to test the boolean parameter")
addBooleanParameter("is_true", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"is_true\\\": true}\"}" // true is a boolean
Expand All @@ -128,7 +128,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"not_required\\\": true}\"}" // missing required parameter
Expand All @@ -147,7 +147,7 @@ class FunctionCallTest {
description("This function is used to test the required parameter")
addIntegerParameter("required", "test parameter", required = true)
addBooleanParameter("not_required", "test parameter")
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"required\\\": 1, \\\"not_required\\\": true}\"}" // has required parameter
Expand All @@ -165,7 +165,7 @@ class FunctionCallTest {
name("function_name_checker")
description("This function is used to test the function name")
noParameters()
}.toTool()
}
)
@Language("json")
val json = "{\"name\": \"invalid_function_name\", \"arguments\": \"{}\"}" // invalid function name
Expand Down