Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import PackageDescription

let package = Package(
name: "swift-numerics-differentiable",
platforms: [
.macOS(.v13),
],
products: [
.library(
name: "NumericsDifferentiable",
Expand All @@ -22,6 +25,12 @@ let package = Package(
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.2"),
],
targets: [
.executableTarget(name: "CodeGeneratorExecutable"),
.plugin(
name: "CodeGeneratorPlugin",
capability: .buildTool,
dependencies: ["CodeGeneratorExecutable"]
),
.target(
name: "NumericsDifferentiable",
dependencies: [
Expand All @@ -34,6 +43,9 @@ let package = Package(
name: "RealModuleDifferentiable",
dependencies: [
.product(name: "RealModule", package: "swift-numerics"),
],
plugins: [
"CodeGeneratorPlugin",
]
),
.target(
Expand Down
36 changes: 36 additions & 0 deletions Plugins/CodeGeneratorPlugin.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import Foundation
import PackagePlugin

@main
struct CodeGeneratorPlugin: BuildToolPlugin {
func createBuildCommands(context: PackagePlugin.PluginContext, target _: PackagePlugin.Target) async throws -> [PackagePlugin.Command] {
let output = context.pluginWorkDirectoryURL

let floatingPointTypes: [String] = ["Float", "Double"]
let simdWidths = [2, 4, 8, 16, 32, 64]

let outputFiles = floatingPointTypes.flatMap { floatingPointType in
simdWidths.flatMap { simdWidth in
[
output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift"),
output.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift"),
]
} + [
output.appending(component: "\(floatingPointType)+RealFunctions+Derivatives.swift"),
]
} + [
output.appending(component: "SIMD+RealFunctions.swift"),
]

return [
.buildCommand(
displayName: "Generate Code",
executable: try context.tool(named: "CodeGeneratorExecutable").url,
arguments: [output.relativePath],
environment: [:],
inputFiles: [],
outputFiles: outputFiles
),
]
}
}
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# swift-numerics-differentiable

This package attempts to add more Differentiable capabilities to the existing [swift-numerics](https://github.com/apple/swift-numerics) package. Every target in swift-numerics has a Differentiable counterpart that `@_exported import`s the original module such that when you import `NumericsDifferentiable` you will also get all the contents of the `Numerics` module from swift-numerics.

## RealModule Differentiable
- Registers derivatives to the `Float` and `Double` conformances to `ElementaryFunctions` and `RealFunctions` from swift-numerics.
- Conforms all `SIMD{n}` types to `ElementaryFunctions` and adds most of the protocol requirements from `RealFunctions` as well (`signGamma` is not implementable)
- Registers derivatives for all the provided `ElementaryFunctions` and `RealFunctions` implementations on SIMD{n}
- Tries to leverage Apple's `simd` framework to accelerate these operations where possible on Apple platforms.

## Contributing
### Code Formatting
This package makes use of [SwiftFormat](https://github.com/nicklockwood/SwiftFormat?tab=readme-ov-file#command-line-tool), which you can install
Expand Down
91 changes: 91 additions & 0 deletions Sources/CodeGeneratorExecutable/CodeGenerator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import Foundation

@main
struct CodeGenerator {
static func main() throws {
guard CommandLine.arguments.count == 2 else {
throw CodeGeneratorError.invalidArguments
}
// arguments[0] is the path to this command line tool
let output = URL(filePath: CommandLine.arguments[1])

// generate default implementations of RealFunctions for SIMD protocol
let realFunctionSIMDFileURL = output.appending(component: "SIMD+RealFunctions.swift")
let realFunctionsSIMDExtension = RealFunctionsGenerator.realFunctionsExtension(
objectType: "SIMD",
type: "Self",
whereClause: true,
simdAccelerated: false
)
try realFunctionsSIMDExtension.write(to: realFunctionSIMDFileURL, atomically: true, encoding: .utf8)

let floatingPointTypes: [String] = ["Float", "Double"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: there are aliases Float16, Float32, Float64, so you could make this an array of bit widths rather than strings, which might clean up some other parts of the code?

let simdWidths: [Int] = [2, 4, 8, 16, 32, 64]

for floatingPointType in floatingPointTypes {
// Generator Derivatives for RealFunctions for floating point types
let realFunctionDerivativesFileURL = output.appending(
component: "\(floatingPointType)+RealFunctions+Derivatives.swift",
directoryHint: .notDirectory
)
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(
type: floatingPointType,
floatingPointType: floatingPointType
)
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)

for simdWidth in simdWidths {
let realFunctionFileURL = output.appending(
component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions.swift",
directoryHint: .notDirectory
)
let simdType = "SIMD\(simdWidth)<\(floatingPointType)>"

// no simd methods exist for simd size >= 16 and scalar > Float so we don't add acceleration to those.
let simdAccelerated = simdWidth < 16 || (simdWidth == 16 && floatingPointType == "Float")

// Generate RealFunctions implementations on concrete SIMD types to attach derivatives to
let realFunctionsExtensionCode = RealFunctionsGenerator.realFunctionsExtension(
objectType: simdType,
type: simdType,
whereClause: false,
simdAccelerated: simdAccelerated
)
try realFunctionsExtensionCode.write(to: realFunctionFileURL, atomically: true, encoding: .utf8)

// Generate RealFunctions derivatives for concrete SIMD types
let realFunctionDerivativesFileURL = output
.appending(component: "SIMD\(simdWidth)+\(floatingPointType)+RealFunctions+Derivatives.swift")
let type = "SIMD\(simdWidth)<\(floatingPointType)>"
let realFunctionsDerivativesExtensionCode = RealFunctionsDerivativesGenerator.realFunctionsDerivativesExtension(
type: type,
floatingPointType: floatingPointType
)
try realFunctionsDerivativesExtensionCode.write(to: realFunctionDerivativesFileURL, atomically: true, encoding: .utf8)
}
}
}
}

struct RealFunction {
var name: String
var simdName: String?
var arguments: [Argument]

struct Argument {
var name: String
var label: String?
var type: String?
}

init(name: String, simdName: String? = nil, arguments: [Argument] = [.init(name: "x", label: "_")]) {
self.name = name
self.simdName = simdName
self.arguments = arguments
}
}

enum CodeGeneratorError: Error {
case invalidArguments
case invalidData
}
Loading