Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 204 additions & 45 deletions Stitch/Graph/StitchAI/Mapping/SyntaxToActions/deriveActions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,26 @@ extension SwiftUIViewParserResult {
viewStatePatchConnections: patchResult.stateVarConnections),
caughtErrors: self.caughtErrors + layerResults.caughtErrors)// + patchResults.caughtErrors)
}

@MainActor
func deriveStitchActionsSync(bindingDeclarations: [(String, SwiftParserInitializerType)]) throws -> SwiftSyntaxActionsResult {
// Extract layer data
let layerResults = self.viewStack.deriveStitchActions(bindingDeclarations: bindingDeclarations)

let interactionsPatchActionResult = layerResults.actions.getPatchResultsFromViewEvents()

let patchCodeStatements = try SwiftPatchClosureType.swiftPatchLogic(self.bindingDeclarations.getSwiftPatchCodeTypes())

// Prepend view event data for code from `updateLayerInputs`
let allPatchCode: [SwiftPatchClosureType] = interactionsPatchActionResult.map { .viewEvent($0) } + [patchCodeStatements]

let patchResult = allPatchCode.derivePatchNodesSync()

return .init(graphData: .init(layer_data_list: layerResults.actions,
patchNodes: patchResult.nodes,
viewStatePatchConnections: patchResult.stateVarConnections),
caughtErrors: self.caughtErrors + layerResults.caughtErrors)// + patchResults.caughtErrors)
}
}

extension Array where Element == SyntaxView {
Expand Down Expand Up @@ -529,6 +549,17 @@ extension SwiftPatchCodeType {
return nil
}
}

var containsJsRef: Bool {
switch self {
case .expression(.jsRef):
return true
case .subscriptType(let codeType, _):
return codeType.containsJsRef
default:
return false
}
}
}

extension SwiftPatchClosureType {
Expand Down Expand Up @@ -857,17 +888,11 @@ extension NodeEntity {

extension SwiftPatchCodeType {
@MainActor
func derivePatchData(document: StitchDocumentViewModel,
varName: String?,
varNameToCode: [String: SwiftPatchCodeType],
viewEvent: SyntaxViewEvent?,
existingStateVarConnections: [String: [NodeIOCoordinate]],
nodesDict: [UUID: NodeEntity]) async throws -> [PatchSyntaxResultType] {
guard let aiManager = document.aiManager else {
fatalErrorIfDebug()
return []
}

func derivePatchDataSync(varName: String?,
varNameToCode: [String: SwiftPatchCodeType],
viewEvent: SyntaxViewEvent?,
existingStateVarConnections: [String: [NodeIOCoordinate]],
nodesDict: [UUID: NodeEntity]) throws -> [PatchSyntaxResultType] {
switch self {
case .expression(let codeType):
switch codeType {
Expand Down Expand Up @@ -895,40 +920,16 @@ extension SwiftPatchCodeType {
}

// recursion
return try await refCode.derivePatchData(
document: document,
return try refCode.derivePatchDataSync(
varName: varName,
varNameToCode: varNameToCode,
viewEvent: viewEvent,
existingStateVarConnections: existingStateVarConnections,
nodesDict: nodesDict)

case .jsRef(let jsData):
guard let sourceCode = varNameToCode.get(jsData.fnName)?
.jsScript,
let varName = varName else {
fatalErrorIfDebug()
return []
}

// Get AI info
let jsNodeRequest = AIJSNodeSettingsFromScritptRequest(existingScript: sourceCode)

let jsSettings = try await jsNodeRequest
.request(document: document,
aiManager: aiManager)

// TODO: double check empty list below

return try SwiftPatchNodeCode(patch: .javascript,
ports: [])
.defaultNodeEntityData(varName: varName,
varNameToCode: varNameToCode,
groupNodeId: nil,
existingStateVarConnections: existingStateVarConnections,
nodesDict: nodesDict,
viewEvent: viewEvent,
jsSettings: jsSettings)
case .jsRef:
// Return empty for JS references in sync mode - async version will handle this
return []

case .portValuesInit(let args):
// Check for PortValueDescription
Expand Down Expand Up @@ -982,6 +983,56 @@ extension SwiftPatchCodeType {
return []
}
}

@MainActor
func derivePatchData(document: StitchDocumentViewModel,
varName: String?,
varNameToCode: [String: SwiftPatchCodeType],
viewEvent: SyntaxViewEvent?,
existingStateVarConnections: [String: [NodeIOCoordinate]],
nodesDict: [UUID: NodeEntity]) async throws -> [PatchSyntaxResultType] {
// Handle the async jsRef case
if case .expression(.jsRef(let jsData)) = self {
guard let aiManager = document.aiManager else {
fatalErrorIfDebug()
return []
}

guard let sourceCode = varNameToCode.get(jsData.fnName)?
.jsScript,
let varName = varName else {
fatalErrorIfDebug()
return []
}

// Get AI info
let jsNodeRequest = AIJSNodeSettingsFromScritptRequest(existingScript: sourceCode)

let jsSettings = try await jsNodeRequest
.request(document: document,
aiManager: aiManager)

// TODO: double check empty list below

return try SwiftPatchNodeCode(patch: .javascript,
ports: [])
.defaultNodeEntityData(varName: varName,
varNameToCode: varNameToCode,
groupNodeId: nil,
existingStateVarConnections: existingStateVarConnections,
nodesDict: nodesDict,
viewEvent: viewEvent,
jsSettings: jsSettings)
}

// For all other cases, delegate to the synchronous version
return try derivePatchDataSync(
varName: varName,
varNameToCode: varNameToCode,
viewEvent: viewEvent,
existingStateVarConnections: existingStateVarConnections,
nodesDict: nodesDict)
}
}

extension Sequence {
Expand All @@ -995,8 +1046,56 @@ extension Sequence {
}

extension Array where Element == SwiftPatchClosureType {
@MainActor
func derivePatchNodesSync() -> SwiftSyntaxPatchActionsResult {
var result = SwiftSyntaxPatchActionsResult(nodes: [],
stateVarConnections: [:],
caughtErrors: [])

for closureType in self {
let existingNodesDict = result.nodes.reduce(into: [:]) { result, node in
result.updateValue(node, forKey: node.id)
}

switch closureType {
case .swiftPatchLogic(let codeStatements):
let patchResult = codeStatements
.derivePatchNodesSync(existingStateVarConnections: result.stateVarConnections,
existingNodesDict: existingNodesDict,
viewEvent: nil)
result += patchResult

case .viewEvent(let swiftPatchViewEvent):
let viewEventData = swiftPatchViewEvent.viewEvent
let closureActionsResult = swiftPatchViewEvent
.codeStatements
.derivePatchNodesSync(existingStateVarConnections: result.stateVarConnections,
existingNodesDict: existingNodesDict,
viewEvent: viewEventData)
result += closureActionsResult
}
}

return result
}

@MainActor
func derivePatchNodes(document: StitchDocumentViewModel) async -> SwiftSyntaxPatchActionsResult {
// Check if any closure contains async operations
let hasAsyncOperations = self.contains { closureType in
switch closureType {
case .swiftPatchLogic(let codeStatements):
return codeStatements.contains { (_, code) in code.containsJsRef }
case .viewEvent(let swiftPatchViewEvent):
return swiftPatchViewEvent.codeStatements.contains { (_, code) in code.containsJsRef }
}
}

if !hasAsyncOperations {
// Use sync version if no async operations needed
return derivePatchNodesSync()
}

var result = SwiftSyntaxPatchActionsResult(nodes: [],
stateVarConnections: [:],
caughtErrors: [])
Expand All @@ -1008,7 +1107,6 @@ extension Array where Element == SwiftPatchClosureType {

switch closureType {
case .swiftPatchLogic(let codeStatements):

let patchResult = await codeStatements
.derivePatchNodes(document: document,
existingStateVarConnections: result.stateVarConnections,
Expand All @@ -1017,17 +1115,13 @@ extension Array where Element == SwiftPatchClosureType {
result += patchResult

case .viewEvent(let swiftPatchViewEvent):
// Create node for view event
let viewEventData = swiftPatchViewEvent.viewEvent

// Get data from closure actions
let closureActionsResult = await swiftPatchViewEvent
.codeStatements
.derivePatchNodes(document: document,
existingStateVarConnections: result.stateVarConnections,
existingNodesDict: existingNodesDict,
viewEvent: viewEventData)

result += closureActionsResult
}
}
Expand Down Expand Up @@ -1221,11 +1315,76 @@ extension Dictionary where Key == UUID, Value == NodeEntity {
}

extension Array where Element == (String, SwiftPatchCodeType) {
@MainActor
func derivePatchNodesSync(existingStateVarConnections: [String: [NodeIOCoordinate]],
existingNodesDict: [UUID: NodeEntity],
viewEvent: SyntaxViewEvent?) -> SwiftSyntaxPatchActionsResult {
// Create dictionary of self
let varNameToCode = self.reduce(into: [String: SwiftPatchCodeType]()) { result, data in
result.updateValue(data.1, forKey: data.0)
}

// Instantiate dictionary of nodes to return as array later
var nodesDict = [UUID: NodeEntity]()

// Tracks connections to state variables, used as layer inputs later
var stateVarConnections = [String: [NodeIOCoordinate]]()

var caughtErrors = [SwiftUISyntaxError]()

// Create patch nodes and input values
for (varName, code) in self {
do {
let mergedStateVarConnections = existingStateVarConnections
.merging(stateVarConnections) { $1 }
let mergedNodesDict = existingNodesDict
.merging(nodesDict) { $1 }

let events = try code.derivePatchDataSync(
varName: varName,
varNameToCode: varNameToCode,
viewEvent: viewEvent,
existingStateVarConnections: mergedStateVarConnections,
nodesDict: mergedNodesDict)

for event in events {
nodesDict.updateWithEventData(event,
layerInputCoordinate: nil,
varName: varName,
stateVarConnections: &stateVarConnections)
}

} catch let error as SwiftUISyntaxError {
caughtErrors.append(error)
} catch {
fatalErrorIfDebug(error.localizedDescription)
log("deriveStitchActions: error.localizedDescription: \(error.localizedDescription)")
continue
}
}

return .init(nodes: [NodeEntity](nodesDict.values),
stateVarConnections: stateVarConnections,
caughtErrors: caughtErrors)
}

@MainActor
func derivePatchNodes(document: StitchDocumentViewModel,
existingStateVarConnections: [String: [NodeIOCoordinate]],
existingNodesDict: [UUID: NodeEntity],
viewEvent: SyntaxViewEvent?) async -> SwiftSyntaxPatchActionsResult {
// Check if any element needs async processing (contains jsRef)
let hasAsyncOperations = self.contains { (_, code) in
code.containsJsRef
}

if !hasAsyncOperations {
// Use sync version if no async operations needed
return derivePatchNodesSync(existingStateVarConnections: existingStateVarConnections,
existingNodesDict: existingNodesDict,
viewEvent: viewEvent)
}

// Create dictionary of self
let varNameToCode = self.reduce(into: [String: SwiftPatchCodeType]()) { result, data in
result.updateValue(data.1, forKey: data.0)
Expand All @@ -1239,7 +1398,7 @@ extension Array where Element == (String, SwiftPatchCodeType) {

var caughtErrors = [SwiftUISyntaxError]()

// Create patch nodes and input values
// Create patch nodes and input values with async support
for (varName, code) in self {
do {
let mergedStateVarConnections = existingStateVarConnections
Expand Down