diff --git a/Cargo.lock b/Cargo.lock index 962be65..8b5b209 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -64,17 +64,6 @@ dependencies = [ "syn", ] -[[package]] -name = "async-trait" -version = "0.1.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "atomic-waker" version = "1.1.2" @@ -143,25 +132,17 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cbadv" -version = "2.0.1" +version = "2.0.2" dependencies = [ "assert-json-diff", - "async-trait", "base64", - "chrono", "futures", - "futures-util", - "hex", - "hmac", - "num-traits", "openssl", - "rand", "reqwest", "ring", "serde", "serde_json", "serde_with", - "sha2", "tokio", "tokio-test", "tokio-tungstenite", @@ -171,9 +152,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.2" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" +checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" dependencies = [ "shlex", ] @@ -186,16 +167,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", - "js-sys", "num-traits", "serde", - "wasm-bindgen", "windows-targets", ] @@ -293,7 +272,6 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", - "subtle", ] [[package]] @@ -334,9 +312,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fnv" @@ -521,20 +499,11 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - [[package]] name = "http" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" dependencies = [ "bytes", "fnv", @@ -846,9 +815,9 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" dependencies = [ "once_cell", "wasm-bindgen", @@ -856,9 +825,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.167" +version = "0.2.168" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" [[package]] name = "linux-raw-sys" @@ -872,16 +841,6 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" version = "0.4.22" @@ -1011,29 +970,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - [[package]] name = "percent-encoding" version = "2.3.1" @@ -1121,15 +1057,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "redox_syscall" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" -dependencies = [ - "bitflags", -] - [[package]] name = "reqwest" version = "0.12.9" @@ -1196,22 +1123,22 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.19" +version = "0.23.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" +checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ "once_cell", "rustls-pki-types", @@ -1231,9 +1158,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" [[package]] name = "rustls-webpki" @@ -1261,12 +1188,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "security-framework" version = "2.11.1" @@ -1292,18 +1213,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" +checksum = "0b9781016e935a97e8beecf0c933758c97a5520d32930e460142b4cd80c6338e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.215" +version = "1.0.216" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" +checksum = "46f859dbbf73865c6627ed570e78961cd3ac92407a2d117204c49232485da55e" dependencies = [ "proc-macro2", "quote", @@ -1384,32 +1305,12 @@ dependencies = [ "digest", ] -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "shlex" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook-registry" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" -dependencies = [ - "libc", -] - [[package]] name = "slab" version = "0.4.9" @@ -1546,9 +1447,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "itoa", @@ -1567,9 +1468,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -1587,17 +1488,15 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.1" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "bytes", "libc", "mio", - "parking_lot", "pin-project-lite", - "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -1626,20 +1525,19 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ "rustls", - "rustls-pki-types", "tokio", ] [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -1675,9 +1573,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -1825,18 +1723,6 @@ checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "rand", - "uuid-macro-internal", -] - -[[package]] -name = "uuid-macro-internal" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b91f57fe13a38d0ce9e28a03463d8d3c2468ed03d75375110ec71d93b449a08" -dependencies = [ - "proc-macro2", - "quote", - "syn", ] [[package]] @@ -1868,9 +1754,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" dependencies = [ "cfg-if", "once_cell", @@ -1879,13 +1765,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn", @@ -1894,9 +1779,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.47" +version = "0.4.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" +checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" dependencies = [ "cfg-if", "js-sys", @@ -1907,9 +1792,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1917,9 +1802,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", @@ -1930,15 +1815,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" [[package]] name = "web-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" +checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 558dd99..321d2a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "cbadv" -version = "2.0.1" +version = "2.0.2" edition = "2021" description = "Asynchronous Coinbase Advanced REST and WebSocket API" license = "MIT" @@ -12,43 +12,22 @@ categories = ["api-bindings", "cryptography::cryptocurrencies"] include = ["src/**", "Cargo.toml", "README.md", "LICENSE", "examples/**"] [features] -default = ["config"] +default = [] full = ["config"] config = ["dep:toml"] [dependencies] -# Core dependencies -reqwest = { version = "0.12.9", features = ["json"] } futures = "0.3.31" -tokio = { version = "1.41.1", features = ["full"] } - -# Cryptography and signing -hmac = "0.12.1" -sha2 = "0.10.8" -hex = "0.4.3" - -# Serialization and configuration +reqwest = { version = "0.12.9", features = ["json"] } +tokio-tungstenite = { version = "0.24.0", features = ["native-tls"] } +tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread"], default-features = false } serde = { version = "1.0.215", features = ["derive"] } serde_json = "1.0.133" serde_with = "3.11.0" toml = { version = "0.8.19", optional = true } - -# WebSocket support -tokio-tungstenite = { version = "0.24.0", features = ["native-tls"] } -futures-util = "0.3.31" -async-trait = "0.1.83" - -# Utilities -uuid = { version = "1.11.0", features = [ - "v4", - "fast-rng", - "macro-diagnostics", -] } -chrono = "0.4.38" -num-traits = "0.2.19" +uuid = { version = "1.11.0", features = ["v4", "fast-rng"] } base64 = "0.22.1" ring = "0.17.8" -rand = "0.8.5" openssl = "0.10.68" [[example]] @@ -103,17 +82,12 @@ required-features = ["config"] [[example]] name = "websocket" path = "examples/websocket.rs" -required-features = ["config"] [[example]] name = "websocket_user" path = "examples/websocket_user.rs" required-features = ["config"] -[[example]] -name = "watch_candles" -path = "examples/watch_candles.rs" - [[example]] name = "custom_config" path = "examples/custom_config.rs" diff --git a/README.md b/README.md index 5684d05..9061ef7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +

GitHub repo size

+ --- @@ -39,16 +41,18 @@ cbadv = { git = "https://github.com/ohkthx/cbadv-rs", branch = "main" } ## Table of Contents -- [Features](#features) -- [Documentation](#documentation) -- [Configuration](#configuration) -- [Examples](#examples) -- [API Coverage](#api-coverage) - - [WebSocket API](#websocket-api) - - [REST API](#rest-api) -- [TODO](#todo) -- [Contributing](#contributing) -- [Tips Appreciated!](#tips-appreciated) +- [Asynchronous CoinBase Advanced API](#asynchronous-coinbase-advanced-api) + - [Table of Contents](#table-of-contents) + - [Features](#features) + - [Documentation](#documentation) + - [API Coverage](#api-coverage) + - [WebSocket API](#websocket-api) + - [REST API](#rest-api) + - [Configuration](#configuration) + - [Examples](#examples) + - [TODO](#todo) + - [Contributing](#contributing) + - [Tips Appreciated](#tips-appreciated) --- @@ -58,7 +62,6 @@ cbadv = { git = "https://github.com/ohkthx/cbadv-rs", branch = "main" } - Authenticated and Public REST Endpoints. - Builders to create REST and WebSocket Clients. - Convenient configuration file support for API keys (`features = ["config"]`). -- Comprehensive coverage of all accessible REST and WebSocket endpoints (as of **20231206**). - Numerous examples for seamless integration and testing. --- diff --git a/examples/README.md b/examples/README.md index b8907ab..0a55876 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,3 +1,4 @@ +

GitHub repo size

+ # cbadv-rs: Coinbase Advanced Trading API Wrapper @@ -16,23 +18,24 @@ Welcome to **cbadv-rs**, a Rust crate for interacting with the Coinbase Advanced ## Table of Contents -- [Examples](#examples) - - [Account API](#account-api) - - [Product API](#product-api) - - [Fee API](#fee-api) - - [Order API](#order-api) - - [Portfolio API](#portfolio-api) - - [Payment API](#payment-api) - - [Convert API](#convert-api) - - [Data API](#data-api) - - [Public API](#public-api) - - [Sandbox API](#sandbox-api) - - [WebSocket API](#websocket-api) - - [User Orders (WebSocket API)](#user-orders-websocket-api) - - [Watch Candles (WebSocket API)](#watch-candles-websocket-api) - - [Custom Configurations](#custom-configurations) -- [Contributing](#contributing) -- [License](#license) +- [cbadv-rs: Coinbase Advanced Trading API Wrapper](#cbadv-rs-coinbase-advanced-trading-api-wrapper) + - [Table of Contents](#table-of-contents) + - [Examples](#examples) + - [Account API](#account-api) + - [Product API](#product-api) + - [Fee API](#fee-api) + - [Order API](#order-api) + - [Portfolio API](#portfolio-api) + - [Convert API](#convert-api) + - [Payment API](#payment-api) + - [Data API](#data-api) + - [Public API](#public-api) + - [Sandbox API](#sandbox-api) + - [WebSocket API](#websocket-api) + - [User Orders (WebSocket API)](#user-orders-websocket-api) + - [Custom Configurations](#custom-configurations) + - [Contributing](#contributing) + - [License](#license) --- @@ -182,18 +185,6 @@ cargo run --example websocket_user --features="config" --- -#### Watch Candles (WebSocket API) - -Learn how to watch candlestick data via the WebSocket API. Currently, only 5-minute granularity is supported (as of 2023-10-19). Example source: [watch_candles.rs](https://github.com/Ohkthx/cbadv-rs/tree/main/examples/watch_candles.rs) - -**Run the example**: - -```bash -cargo run --example watch_candles --features="config" -``` - ---- - ### Custom Configurations Learn how to create custom configuration files tailored to your integration needs. Example source: [custom_config.rs](https://github.com/Ohkthx/cbadv-rs/tree/main/examples/custom_config.rs) diff --git a/examples/account_api.rs b/examples/account_api.rs index a0c269a..de685c5 100644 --- a/examples/account_api.rs +++ b/examples/account_api.rs @@ -63,7 +63,7 @@ async fn main() { match accounts.iter().position(|r| r.currency == product_name) { Some(index) => { let account = accounts.get(index).unwrap(); - account_uuid = account.uuid.clone(); + account_uuid.clone_from(&account.uuid); } None => println!("Out of bounds, could not find account."), } diff --git a/examples/order_api.rs b/examples/order_api.rs index 6b3da9c..cd48110 100644 --- a/examples/order_api.rs +++ b/examples/order_api.rs @@ -10,38 +10,16 @@ use std::process::exit; use std::thread; +use std::time::Duration; use cbadv::config::{self, BaseConfig}; use cbadv::models::order::{ - OrderCancelRequest, OrderCreateBuilder, OrderEditRequest, OrderListQuery, OrderSide, - OrderStatus, OrderType, TimeInForce, + OrderCancelRequest, OrderCreateBuilder, OrderCreateRequest, OrderEditRequest, OrderListQuery, + OrderSide, OrderStatus, OrderType, TimeInForce, }; -use cbadv::RestClientBuilder; -use chrono::Duration; - -#[tokio::main] -async fn main() { - let create_new: bool = false; - let edit_created: bool = true; - let cancel_created: bool = true; - let cancel_all: bool = false; - let product_id: &str = "ETH-USDC"; - let mut created_order_id: Option = None; - let new_order = match OrderCreateBuilder::new(product_id, OrderSide::Buy) - .base_size(0.005) - .limit_price(100.0) - .post_only(true) - .order_type(OrderType::Limit) - .time_in_force(TimeInForce::GoodUntilCancelled) - .build() - { - Ok(order) => order, - Err(error) => { - println!("Unable to build order: {error}"); - exit(1); - } - }; +use cbadv::{RestClient, RestClientBuilder}; +fn init_client() -> RestClient { // Load the configuration file. let config: BaseConfig = match config::load("config.toml") { Ok(c) => c, @@ -60,53 +38,101 @@ async fn main() { }; // Create a client to interact with the API. - let mut client = match RestClientBuilder::new().with_config(&config).build() { + match RestClientBuilder::new().with_config(&config).build() { Ok(c) => c, Err(why) => { eprintln!("!ERROR! {why}"); exit(1) } - }; + } +} - if create_new { - println!( - "Creating Order with Client ID: {}", - new_order.client_order_id - ); - match client.order.create(&new_order).await { - Ok(summary) => { - if let Some(success) = &summary.success_response { - created_order_id = Some(success.order_id.clone()); - } - println!("Order creation result: {summary:#?}"); +async fn create_new_order( + client: &mut RestClient, + new_order: &OrderCreateRequest, +) -> Option { + let mut created_order_id: Option = None; + println!( + "Creating Order with Client ID: {}", + new_order.client_order_id + ); + + match client.order.create(new_order).await { + Ok(summary) => { + if let Some(success) = &summary.success_response { + created_order_id = Some(success.order_id.clone()); } - Err(error) => println!("Unable to create order: {error}"), + println!("Order creation result: {summary:#?}"); + } + Err(error) => println!("Unable to create order: {error}"), + } + + created_order_id +} + +async fn edit_created_order(client: &mut RestClient, order_id: &str) { + let edit_order = OrderEditRequest::new(order_id, 50.0, 0.006); + println!("\n\nEditing order for {order_id}."); + match client.order.edit(&edit_order).await { + Ok(result) => println!("{result:#?}"), + Err(error) => println!("Unable to edit order: {error}"), + } +} + +async fn cancel_created_order(client: &mut RestClient, order_id: &str) { + println!("\n\nCancelling Order with ID: {order_id}"); + match client + .order + .cancel(&OrderCancelRequest::new(&[order_id.to_string()])) + .await + { + Ok(summary) => println!("Order cancel result: {summary:#?}"), + Err(error) => println!("Unable to cancel order: {error}"), + } +} + +#[tokio::main] +async fn main() { + let create_new: bool = false; + let edit_created: bool = true; + let cancel_created: bool = true; + let cancel_all: bool = false; + let product_id: &str = "ETH-USDC"; + let mut created_order_id: Option = None; + let new_order = match OrderCreateBuilder::new(product_id, OrderSide::Buy) + .base_size(0.005) + .limit_price(100.0) + .post_only(true) + .order_type(OrderType::Limit) + .time_in_force(TimeInForce::GoodUntilCancelled) + .build() + { + Ok(order) => order, + Err(error) => { + println!("Unable to build order: {error}"); + exit(1); } + }; + + let mut client = init_client(); + + // Creates a new order from scratch, the resulting order id will be used for other operations. + if create_new { + created_order_id = create_new_order(&mut client, &new_order).await; } + // Edits the created order. if let Some(order_id) = &created_order_id { if create_new && edit_created { - thread::sleep(Duration::seconds(1).to_std().unwrap()); - let edit_order = OrderEditRequest::new(order_id, 50.0, 0.006); - println!("\n\nEditing order for {order_id}."); - match client.order.edit(&edit_order).await { - Ok(result) => println!("{result:#?}"), - Err(error) => println!("Unable to edit order: {error}"), - } + thread::sleep(Duration::from_secs(1)); + edit_created_order(&mut client, order_id).await; } } + // Cancels the created order. if let Some(order_id) = &created_order_id { if create_new && cancel_created { - println!("\n\nCancelling Order with ID: {order_id}"); - match client - .order - .cancel(&OrderCancelRequest::new(&[order_id.clone()])) - .await - { - Ok(summary) => println!("Order cancel result: {summary:#?}"), - Err(error) => println!("Unable to cancel order: {error}"), - } + cancel_created_order(&mut client, order_id).await; } } @@ -143,7 +169,7 @@ async fn main() { println!("Orders obtained: {:#?}", orders.orders.len()); match orders.orders.first() { Some(order) => { - order_id = order.order_id.clone(); + order_id.clone_from(&order.order_id); println!("{order:#?}"); } None => println!("Out of bounds, no orders exist."), diff --git a/examples/watch_candles.rs b/examples/watch_candles.rs deleted file mode 100644 index 5345205..0000000 --- a/examples/watch_candles.rs +++ /dev/null @@ -1,114 +0,0 @@ -//! # Watch Candle Example -//! -//! Shows how to: -//! - Create a user-defined struct that receives updates. -//! - Implement required trait `CandleCallback` for the user-defined struct. -//! - Initialize and watch candles via WebSocket. -//! - Process candles coming from API. - -use std::process::exit; - -use cbadv::models::product::{Candle, ProductListQuery}; -use cbadv::traits::CandleCallback; -use cbadv::{async_trait, RestClient, RestClientBuilder, WebSocketClientBuilder}; - -/// Example of user-defined struct to pass to the candle watcher. -pub struct UserStruct { - /// Total amount of candles seen. - processed: usize, -} - -#[async_trait] -impl CandleCallback for UserStruct { - async fn candle_callback(&mut self, current_start: u64, product_id: String, candle: Candle) { - self.processed += 1; - - let mut is_same = ""; - if current_start == candle.start { - is_same = "[MATCHES CURRENT START]"; - } - - // Processed | Product_Id | Candle Start | Current - println!( - "{:<5} {:>14} ({}): finished candle {}", - self.processed, product_id, candle.start, is_same - ); - } -} - -/// Obtain product names of candles to be obtained. -async fn get_products(client: &mut RestClient) -> Vec { - println!("Getting '*-USDC' products."); - - // Holds all of the product names. - let mut product_names: Vec = vec![]; - let query = ProductListQuery::new(); - - // Pull multiple products from the Product API. - match client.public.products(&query).await { - Ok(products) => { - product_names = products - .iter() - // Filter products to only containing *-USDC pairs. - .filter_map(|p| match p.quote_currency_id.as_str() { - "USDC" => Some(p.product_id.clone()), - _ => None, - }) - .collect(); - } - Err(error) => println!("Unable to get products: {error}"), - } - - product_names -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - // Create a client to interact with the API. - let mut rclient = match RestClientBuilder::new().build() { - Ok(c) => c, - Err(why) => { - eprintln!("!ERROR! {why}"); - exit(1) - } - }; - - // Create a client to interact with the API. - let wsclient = match WebSocketClientBuilder::new() - .auto_reconnect(true) - .max_retries(20) - .build() - { - Ok(c) => c, - Err(why) => { - eprintln!("!ERROR! {why}"); - exit(1) - } - }; - - // Products of interest. - let products = get_products(&mut rclient).await; - println!("Obtained {} products.\n", products.len()); - - // User struct to be passed to the watcher. - let mystruct: UserStruct = UserStruct { processed: 0 }; - - // Start watching candles. - println!("Starting candle watcher for {} products.", products.len()); - // let task = match websocket::watch_candles(&mut wsclient, &products, mystruct).await { - let task = match wsclient.watch_candles(&products, mystruct).await { - Ok(value) => value, - Err(err) => { - println!("Could not watch candles: {err}"); - exit(1); - } - }; - - // Wait to join the task. - match task.await { - Ok(()) => println!("Task is complete."), - Err(err) => println!("Task ended in error: {err}"), - }; - - Ok(()) -} diff --git a/examples/websocket.rs b/examples/websocket.rs index 5f9d1c6..ba6a1a9 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -7,21 +7,36 @@ //! - Unsubscribe to channels. use std::process::exit; +use std::time::{Duration, Instant}; -use cbadv::models::websocket::{Channel, EndpointType, Message}; +use cbadv::models::websocket::{Channel, EndpointStream, Events, Message}; use cbadv::types::CbResult; -use cbadv::{FunctionCallback, WebSocketClientBuilder}; +use cbadv::WebSocketClientBuilder; /// This is used to parse messages. It is passed to the `listen` function to pull Messages out of /// the stream. -fn message_callback(msg: CbResult) { +fn message_action(msg: CbResult) -> Result<(), String> { let rcvd = match msg { - Ok(message) => format!("{message:?}"), // Leverage Debug for all Message variants + Ok(Message { + events: Events::Candles(candles_events), + channel, + .. + }) => { + for ticker in candles_events { + println!("{ticker:?}"); + } + format!("this is a {channel:?} message") + } + Ok(message) => format!( + "this is not a candles message it is a {:?} message", + message.channel + ), // Leverage Debug for all Message variants Err(error) => format!("Error: {error}"), // Handle WebSocket errors }; // Update the callback object's properties and log the message. println!("{rcvd}\n"); + Ok(()) } #[tokio::main] @@ -37,28 +52,15 @@ async fn main() { }) .unwrap(); - // Assign the callback function to an object. - let callback = FunctionCallback::from_sync(message_callback); - // Connect to the websocket, a subscription needs to be sent within 5 seconds. // If a subscription is not sent, Coinbase will close the connection. - let mut readers = client + let readers = client .connect() .await .expect("Could not connect to WebSocket"); - let public = readers - .take_endpoint(&EndpointType::Public) - .expect("Could not get public reader"); - - let listened_client = client.clone(); - let listener = tokio::spawn(async move { - let mut listened_client = listened_client; - listened_client.listen(public, callback).await; - }); - // Products of interest. - let products = vec!["BTC-USD".to_string(), "ETH-USD".to_string()]; + let products = vec!["BTC-USDC".to_string(), "ETH-USDC".to_string()]; // Heartbeats is a great way to keep a connection alive and not timeout. client.subscribe(&Channel::Heartbeats, &[]).await.unwrap(); @@ -69,12 +71,40 @@ async fn main() { .await .unwrap(); + // Get updates (subscribe) on products and currencies. + client.subscribe(&Channel::Level2, &products).await.unwrap(); + // Stop obtaining (unsubscribe) updates on products and currencies. client .unsubscribe(&Channel::Status, &products) .await .unwrap(); - // Passes the parser callback and listens for messages. - listener.await.unwrap(); + let mut count = 0; + const TICK_RATE: u64 = 1000 / 60; + let mut last_tick = Instant::now(); + let mut stream: EndpointStream = readers.into(); + + loop { + // Fetch messages from the WebSocket stream. + let _ = client.fetch_sync(&mut stream, 100, |msg| { + count += 1; + print!("{count}: "); + message_action(msg) + }); + + // Calculate the time since the last tick and sleep for the remaining time to hit the tick rate. + let last_tick_ms = last_tick.elapsed().as_millis(); + let timeout = match u64::try_from(last_tick_ms) { + Ok(ms) => TICK_RATE.saturating_sub(ms), + Err(why) => { + eprintln!("Conversion error: {why}"); + TICK_RATE + } + }; + + // Sleep for the remaining time to hit the tick rate. Prevent busy loop. + tokio::time::sleep(Duration::from_millis(timeout)).await; + last_tick = Instant::now(); + } } diff --git a/examples/websocket_user.rs b/examples/websocket_user.rs index 04f1dfd..3001e68 100644 --- a/examples/websocket_user.rs +++ b/examples/websocket_user.rs @@ -8,11 +8,12 @@ use std::process::exit; +use tokio::sync::mpsc; + use cbadv::config::{self, BaseConfig}; -use cbadv::models::websocket::{Channel, EndpointType, Message}; -use cbadv::traits::MessageCallback; +use cbadv::models::websocket::{Channel, Message}; use cbadv::types::CbResult; -use cbadv::{async_trait, WebSocketClientBuilder}; +use cbadv::WebSocketClientBuilder; /// Example of an object with an attached callback function for messages. struct CallbackObject { @@ -20,11 +21,10 @@ struct CallbackObject { total_processed: usize, } -#[async_trait] -impl MessageCallback for CallbackObject { +impl CallbackObject { /// This is used to parse messages. It is passed to the `listen` function to pull Messages out of /// the stream. - async fn message_callback(&mut self, msg: CbResult) { + async fn message_action(&mut self, msg: CbResult) { let rcvd = match msg { Ok(message) => format!("{message:?}"), // Leverage Debug for all Message variants Err(error) => format!("Error: {error}"), // Handle WebSocket errors @@ -67,31 +67,40 @@ async fn main() { .unwrap(); // Callback Object. - let callback = CallbackObject { total_processed: 0 }; + let mut callback = CallbackObject { total_processed: 0 }; + + // Create an mpsc channel for communication. + let (tx, mut rx) = mpsc::channel::>(100); // Connect to the websocket, a subscription needs to be sent within 5 seconds. // If a subscription is not sent, Coinbase will close the connection. - let mut readers = client + let readers = client .connect() .await .expect("Could not connect to WebSocket."); - let user = readers - .take_endpoint(&EndpointType::User) - .expect("Could not get secure user reader."); + // Basic subscriptions. + client.subscribe(&Channel::Heartbeats, &[]).await.unwrap(); + client.subscribe(&Channel::User, &[]).await.unwrap(); - let listened_client = client.clone(); + // Spawn the listener task. let listener = tokio::spawn(async move { - let mut listened_client = listened_client; - listened_client.listen(user, callback).await; + client + .listen(readers, move |msg| { + let tx = tx.clone(); + async move { + if tx.send(msg).await.is_err() { + eprintln!("Receiver dropped. Exiting listener..."); + } + } + }) + .await; }); - // Heartbeats is a great way to keep a connection alive and not timeout. - client.subscribe(&Channel::Heartbeats, &[]).await.unwrap(); - - // Subscribe to user orders. - client.subscribe(&Channel::User, &[]).await.unwrap(); + // Process messages in the main task. + while let Some(msg) = rx.recv().await { + callback.message_action(msg).await; + } - // Passes the parser callback and listens for messages. listener.await.unwrap(); } diff --git a/src/candle_watcher.rs b/src/candle_watcher.rs deleted file mode 100644 index b73c52a..0000000 --- a/src/candle_watcher.rs +++ /dev/null @@ -1,164 +0,0 @@ -//! Candle Watcher is the underlying object used to track candle updates. - -use std::collections::HashMap; - -use async_trait::async_trait; -use chrono::Utc; - -use crate::constants::websocket::GRANULARITY; -use crate::models::product::Candle; -use crate::models::websocket::{CandleUpdate, Channel, Endpoint, Event, Message}; -use crate::traits::{CandleCallback, MessageCallback}; -use crate::types::CbResult; -use crate::WebSocketClient; - -/// Tracks the candle watcher task. -pub(crate) struct CandleWatcher -where - T: CandleCallback, -{ - /// Holds the most recent candle processed for each product. [key: Product Id, value: Candle] - candles: HashMap, - /// User-defined object that implements `CandleCallback`, triggered on completed candles. - user_watcher: T, -} - -impl CandleWatcher -where - T: CandleCallback, -{ - /// Starts the task that tracks candles for completion. - /// - /// # Arguments - /// - /// * `reader` - WebSocket reader to receive updates. - /// * `user_obj` - User object that implements `CandleCallback` to receive completed candles. - pub(crate) async fn start(mut client: WebSocketClient, endpoint: Endpoint, user_obj: T) - where - T: CandleCallback + Send + Sync + 'static, - { - let tracker = Self { - candles: HashMap::new(), - user_watcher: user_obj, - }; - - // Start the listener. - client.listen(endpoint, tracker).await; - } - - /// Returns a completed candle if a newer candle is received. - /// - /// # Arguments - /// - /// * `product_id` - The ID of the product this candle belongs to. - /// * `new_candle` - The new candle update received from the WebSocket. - fn check_candle(&mut self, product_id: &str, new_candle: Candle) -> Option { - // Retrieve the current candle for the product. - if let Some(existing_candle) = self.candles.get(product_id) { - if existing_candle.start < new_candle.start { - // A newer candle has been received; replace the existing candle. - let completed_candle = self.candles.remove(product_id).unwrap(); - self.candles.insert(product_id.to_string(), new_candle); - Some(completed_candle) // Return the completed candle. - } else { - // Update the existing candle without considering it complete. - self.candles.insert(product_id.to_string(), new_candle); - None - } - } else { - // No existing candle; add the new candle as the initial one. - self.candles.insert(product_id.to_string(), new_candle); - None - } - } - - /// Extracts candle updates from a WebSocket message. - /// - /// # Arguments - /// - /// * `message` - The WebSocket message to extract updates from. - /// - /// # Returns - /// - /// A vector of `CandleUpdate` sorted by timestamp (newest first). - fn extract_candle_updates(message: &Message) -> Vec { - let mut updates: Vec = message - .events - .iter() - .filter_map(|event| { - if let Event::Candles(candles_event) = event { - Some(candles_event.candles.clone()) - } else { - None - } - }) - .flatten() - .collect(); - - // Sort updates by timestamp (newest first). - updates.sort_by(|a, b| b.data.start.cmp(&a.data.start)); - updates - } - - /// Processes a vector of candle updates. - /// - /// # Arguments - /// - /// * `updates` - The sorted vector of `CandleUpdate` to process. - async fn process_candle_updates(&mut self, mut updates: Vec) { - if let Some(update) = updates.pop() { - let product_id = update.product_id.clone(); - let new_candle = update.data; - - if let Some(completed_candle) = self.check_candle(&product_id, new_candle) { - self.trigger_user_callback(product_id, completed_candle) - .await; - } - } - } - - /// Triggers the user's callback with a completed candle. - /// - /// # Arguments - /// - /// * `product_id` - The ID of the product associated with the candle. - /// * `completed_candle` - The completed candle to send to the callback. - async fn trigger_user_callback(&mut self, product_id: String, completed_candle: Candle) { - #![allow(clippy::cast_sign_loss)] - let now = Utc::now().timestamp() as u64; - let start_time = now - (now % (GRANULARITY * 2)); - - self.user_watcher - .candle_callback(start_time, product_id, completed_candle) - .await; - } -} - -#[async_trait] -impl MessageCallback for CandleWatcher -where - T: CandleCallback + Send + Sync, -{ - /// Handles incoming messages and processes candle updates. - async fn message_callback(&mut self, msg: CbResult) { - match msg { - Ok(message) => { - if message.channel != Channel::Candles { - return; // Ignore non-candle messages. - } - - // Extract candle updates and process them. - let updates = CandleWatcher::::extract_candle_updates(&message); - if updates.is_empty() { - return; // No updates to process. - } - - // Process the most recent update and handle completed candles. - self.process_candle_updates(updates).await; - } - Err(err) => { - eprintln!("!WEBSOCKET ERROR! {err}"); - } - } - } -} diff --git a/src/constants.rs b/src/constants.rs index 41b47cf..26dfc4f 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -69,10 +69,6 @@ pub(crate) mod public { pub(crate) mod websocket { pub(crate) const PUBLIC_ENDPOINT: &str = "wss://advanced-trade-ws.coinbase.com"; pub(crate) const SECURE_ENDPOINT: &str = "wss://advanced-trade-ws-user.coinbase.com"; - - /// Granularity of Candles from the WebSocket Candle subscription. - /// NOTE: This is a restriction by `CoinBase` and cannot be currently changed (20240125) - pub(crate) const GRANULARITY: u64 = 300; } /// Amount of tokens per second refilled. diff --git a/src/jwt.rs b/src/jwt.rs index 7e8beec..24eb3ca 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -250,10 +250,9 @@ impl Jwt { } // Implement serialization for Header to handle base64 encoding -impl<'a> Header<'a> { +impl Header<'_> { fn serialize_base64(&self) -> CbResult { let raw = serde_json::to_vec(self).map_err(|why| CbError::BadSignature(why.to_string()))?; Ok(URL_SAFE_NO_PAD.encode(&raw)) } } - diff --git a/src/lib.rs b/src/lib.rs index 29c33bb..02ca330 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,9 @@ pub mod config; #[macro_use] pub(crate) mod macros; -mod candle_watcher; +/// Re-export tokio for use in the library. +pub use tokio::{self, main as tokio_main}; + pub(crate) mod http_agent; pub(crate) mod jwt; mod token_bucket; @@ -34,7 +36,6 @@ pub mod time; pub mod traits; pub mod types; pub(crate) mod utils; -pub use utils::FunctionCallback; pub mod apis; pub mod models; @@ -43,6 +44,3 @@ mod rest; mod websocket; pub use rest::{RestClient, RestClientBuilder}; pub use websocket::{WebSocketClient, WebSocketClientBuilder}; - -// Re-export async_trait for the end-user. -pub use async_trait::async_trait; diff --git a/src/models/account.rs b/src/models/account.rs index 0c45c30..0ad07dc 100644 --- a/src/models/account.rs +++ b/src/models/account.rs @@ -108,7 +108,7 @@ impl Query for AccountListQuery { fn to_query(&self) -> String { QueryBuilder::new() .push("limit", self.limit) - .push_optional("cursor", &self.cursor) + .push_optional("cursor", self.cursor.as_ref()) .build() } } diff --git a/src/models/fee.rs b/src/models/fee.rs index 347e2f3..0c87dfb 100644 --- a/src/models/fee.rs +++ b/src/models/fee.rs @@ -97,7 +97,7 @@ impl Query for FeeTransactionSummaryQuery { fn to_query(&self) -> String { QueryBuilder::new() - .push_optional("product_type", &self.product_type) + .push_optional("product_type", self.product_type.as_ref()) .build() } } diff --git a/src/models/order/builders.rs b/src/models/order/builders.rs index be65c83..d61a1e0 100644 --- a/src/models/order/builders.rs +++ b/src/models/order/builders.rs @@ -429,7 +429,7 @@ impl OrderCreateBuilder { fn build_limit_gtd(&self) -> Result { let base_size = require_field(self.base_size, "base_size")?; let limit_price = require_field(self.limit_price, "limit_price")?; - let end_time = require_field_ref(&self.end_time, "end_time")?; + let end_time = require_field_ref(self.end_time.as_ref(), "end_time")?; Ok(OrderConfiguration::LimitGtd(LimitGtd { base_size, @@ -460,7 +460,7 @@ impl OrderCreateBuilder { let limit_price = require_field(self.limit_price, "limit_price")?; let stop_price = require_field(self.stop_price, "stop_price")?; let stop_direction = require_field(self.stop_direction, "stop_direction")?; - let end_time = require_field_ref(&self.end_time, "end_time")?; + let end_time = require_field_ref(self.end_time.as_ref(), "end_time")?; Ok(OrderConfiguration::StopLimitGtd(StopLimitGtd { base_size, @@ -478,8 +478,6 @@ fn require_field(field: Option, field_name: &str) -> Result { } /// Validates that a required field reference is present and returns it, or an error if it is missing. -fn require_field_ref<'a, T>(field: &'a Option, field_name: &str) -> Result<&'a T, CbError> { - field - .as_ref() - .ok_or_else(move || CbError::BadParse(format!("{field_name} is required."))) +fn require_field_ref<'a, T>(field: Option<&'a T>, field_name: &str) -> Result<&'a T, CbError> { + field.ok_or_else(|| CbError::BadParse(format!("{field_name} is required."))) } diff --git a/src/models/order/enums.rs b/src/models/order/enums.rs index 12be5d3..a474e7a 100644 --- a/src/models/order/enums.rs +++ b/src/models/order/enums.rs @@ -204,8 +204,8 @@ impl AsRef for TimeInForce { fn as_ref(&self) -> &str { match self { TimeInForce::Unknown => "UNKNOWN_TIME_IN_FORCE", - TimeInForce::GoodUntilCancelled => "GOOD_TIL_CANCELLED", - TimeInForce::GoodUntilDate => "GOOD_TIL_DATE_TIME", + TimeInForce::GoodUntilCancelled => "GOOD_UNTIL_CANCELLED", + TimeInForce::GoodUntilDate => "GOOD_UNTIL_DATE_TIME", TimeInForce::ImmediateOrCancel => "IMMEDIATE_OR_CANCEL", TimeInForce::FillOrKill => "FILL_OR_KILL", } diff --git a/src/models/order/queries.rs b/src/models/order/queries.rs index fccb2a1..4fcaedd 100644 --- a/src/models/order/queries.rs +++ b/src/models/order/queries.rs @@ -73,19 +73,19 @@ impl Query for OrderListQuery { /// Converts the object into HTTP request parameters. fn to_query(&self) -> String { QueryBuilder::new() - .push_optional_vec("order_ids", &self.order_ids) - .push_optional_vec("product_ids", &self.product_ids) - .push_optional("product_type", &self.product_type) - .push_optional_vec("order_status", &self.order_status) - .push_optional_vec("time_in_forces", &self.time_in_forces) - .push_optional_vec("order_types", &self.order_types) - .push_optional("order_side", &self.order_side) - .push_optional("start_date", &self.start_date) - .push_optional("end_date", &self.end_date) - .push_optional_vec("asset_filters", &self.asset_filters) - .push_optional("limit", &self.limit) - .push_optional("cursor", &self.cursor) - .push_optional("sort_by", &self.sort_by) + .push_optional_vec("order_ids", self.order_ids.as_ref()) + .push_optional_vec("product_ids", self.product_ids.as_ref()) + .push_optional("product_type", self.product_type.as_ref()) + .push_optional_vec("order_status", self.order_status.as_ref()) + .push_optional_vec("time_in_forces", self.time_in_forces.as_ref()) + .push_optional_vec("order_types", self.order_types.as_ref()) + .push_optional("order_side", self.order_side.as_ref()) + .push_optional("start_date", self.start_date.as_ref()) + .push_optional("end_date", self.end_date.as_ref()) + .push_optional_vec("asset_filters", self.asset_filters.as_ref()) + .push_optional("limit", self.limit.as_ref()) + .push_optional("cursor", self.cursor.as_ref()) + .push_optional("sort_by", self.sort_by.as_ref()) .build() } } @@ -225,14 +225,20 @@ impl Query for OrderListFillsQuery { /// Converts the object into HTTP request parameters. fn to_query(&self) -> String { QueryBuilder::new() - .push_optional_vec("order_ids", &self.order_ids) - .push_optional_vec("trade_ids", &self.trade_ids) - .push_optional_vec("product_ids", &self.product_ids) - .push_optional("start_sequence_timestamp", &self.start_sequence_timestamp) - .push_optional("end_sequence_timestamp", &self.end_sequence_timestamp) + .push_optional_vec("order_ids", self.order_ids.as_ref()) + .push_optional_vec("trade_ids", self.trade_ids.as_ref()) + .push_optional_vec("product_ids", self.product_ids.as_ref()) + .push_optional( + "start_sequence_timestamp", + self.start_sequence_timestamp.as_ref(), + ) + .push_optional( + "end_sequence_timestamp", + self.end_sequence_timestamp.as_ref(), + ) .push("limit", self.limit) - .push_optional("cursor", &self.cursor) - .push_optional("sort_by", &self.sort_by) + .push_optional("cursor", self.cursor.as_ref()) + .push_optional("sort_by", self.sort_by.as_ref()) .build() } } diff --git a/src/models/order/serde_utils.rs b/src/models/order/serde_utils.rs index 6dc8067..1955313 100644 --- a/src/models/order/serde_utils.rs +++ b/src/models/order/serde_utils.rs @@ -19,7 +19,7 @@ impl<'de> DeDeserialize<'de> for OrderType { struct OrderTypeVisitor; -impl<'de> Visitor<'de> for OrderTypeVisitor { +impl Visitor<'_> for OrderTypeVisitor { type Value = OrderType; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { diff --git a/src/models/portfolio.rs b/src/models/portfolio.rs index e6f235d..afa2eb4 100644 --- a/src/models/portfolio.rs +++ b/src/models/portfolio.rs @@ -328,7 +328,7 @@ impl Query for PortfolioListQuery { fn to_query(&self) -> String { QueryBuilder::new() - .push_optional("portfolio_type", &self.portfolio_type) + .push_optional("portfolio_type", self.portfolio_type.as_ref()) .build() } } @@ -360,7 +360,7 @@ impl Query for PortfolioBreakdownQuery { fn to_query(&self) -> String { QueryBuilder::new() - .push_optional("currency", &self.currency) + .push_optional("currency", self.currency.as_ref()) .build() } } diff --git a/src/models/product.rs b/src/models/product.rs index 4507e2f..f082a4f 100644 --- a/src/models/product.rs +++ b/src/models/product.rs @@ -160,16 +160,20 @@ pub struct Product { /// The trading pair. pub product_id: String, /// The current price for the product, in quote currency. - #[serde_as(as = "DisplayFromStr")] + #[serde_as(as = "DefaultOnError")] + #[serde(default)] pub price: f64, /// The amount the price of the product has changed, in percent, in the last 24 hours. - #[serde_as(as = "DisplayFromStr")] + #[serde_as(as = "DefaultOnError")] + #[serde(default)] pub price_percentage_change_24h: f64, /// The trading volume for the product in the last 24 hours. - #[serde_as(as = "DisplayFromStr")] + #[serde_as(as = "DefaultOnError")] + #[serde(default)] pub volume_24h: f64, /// The percentage amount the volume of the product has changed in the last 24 hours. - #[serde_as(as = "DisplayFromStr")] + #[serde_as(as = "DefaultOnError")] + #[serde(default)] pub volume_percentage_change_24h: f64, /// Minimum amount base value can be increased or decreased at once. #[serde_as(as = "DisplayFromStr")] @@ -388,12 +392,15 @@ impl Query for ProductListQuery { fn to_query(&self) -> String { QueryBuilder::new() - .push_optional("limit", &self.limit) - .push_optional("offset", &self.offset) - .push_optional("product_type", &self.product_type) - .push_optional_vec("product_ids", &self.product_ids) - .push_optional("get_all_products", &self.get_all_products) - .push_optional("get_tradability_status", &self.get_tradability_status) + .push_optional("limit", self.limit.as_ref()) + .push_optional("offset", self.offset.as_ref()) + .push_optional("product_type", self.product_type.as_ref()) + .push_optional_vec("product_ids", self.product_ids.as_ref()) + .push_optional("get_all_products", self.get_all_products.as_ref()) + .push_optional( + "get_tradability_status", + self.get_tradability_status.as_ref(), + ) .build() } } @@ -470,8 +477,8 @@ impl Query for ProductTickerQuery { fn to_query(&self) -> String { QueryBuilder::new() .push("limit", self.limit) - .push_optional("start", &self.start) - .push_optional("end", &self.end) + .push_optional("start", self.start.as_ref()) + .push_optional("end", self.end.as_ref()) .build() } } @@ -532,7 +539,7 @@ impl Query for ProductBidAskQuery { fn to_query(&self) -> String { QueryBuilder::new() - .push_optional_vec("product_ids", &Some(self.product_ids.clone())) + .push_optional_vec("product_ids", Some(self.product_ids.as_ref())) .build() } } @@ -584,10 +591,10 @@ impl Query for ProductBookQuery { fn to_query(&self) -> String { QueryBuilder::new() .push("product_id", &self.product_id) - .push_optional("limit", &self.limit) + .push_optional("limit", self.limit.as_ref()) .push_optional( "aggregation_price_increment", - &self.aggregation_price_increment, + self.aggregation_price_increment.as_ref(), ) .build() } diff --git a/src/models/websocket/enums.rs b/src/models/websocket/enums.rs index 8706281..af7f676 100644 --- a/src/models/websocket/enums.rs +++ b/src/models/websocket/enums.rs @@ -16,7 +16,8 @@ pub enum Channel { Ticker, /// Real-time price updates every 5000 milli-seconds. TickerBatch, - /// All updates and easiest way to keep order book snapshot + /// All updates and easiest way to keep order book snapshot. + #[serde(alias = "l2_data")] Level2, /// Real-time updates every time a market trade happens. MarketTrades, @@ -40,7 +41,11 @@ pub enum EventType { #[derive(Serialize, SerdeDeserialize, PartialEq, Debug)] #[serde(rename_all = "snake_case")] pub enum Level2Side { + /// Bids / Buy side. Bid, + /// Asks / Offer side. + /// NOTE: As of 20241209, the API has a typo and uses "offer" instead of "ask". + #[serde(alias = "offer")] Ask, } diff --git a/src/models/websocket/events.rs b/src/models/websocket/events.rs index 4665335..49f06d4 100644 --- a/src/models/websocket/events.rs +++ b/src/models/websocket/events.rs @@ -1,4 +1,4 @@ -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use super::{ CandleUpdate, EventType, FuturesBalanceSummaryUpdate, Level2Update, MarketTradesUpdate, @@ -6,43 +6,44 @@ use super::{ }; /// Events that could be received in a message. -#[derive(Debug)] -pub enum Event { - Status(StatusEvent), - Candles(CandlesEvent), - Ticker(TickerEvent), - TickerBatch(TickerEvent), - Level2(Level2Event), - User(UserEvent), - MarketTrades(MarketTradesEvent), - Heartbeats(HeartbeatsEvent), - Subscribe(SubscribeEvent), - FuturesBalanceSummary(FuturesSummaryBalanceEvent), +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Events { + Status(Vec), + Candles(Vec), + Ticker(Vec), + TickerBatch(Vec), + Level2(Vec), + User(Vec), + MarketTrades(Vec), + Heartbeats(Vec), + Subscribe(Vec), + FuturesBalanceSummary(Vec), } /// The status event containing updates to products. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct StatusEvent { pub r#type: EventType, pub products: Vec, } /// The candles event containing updates to candles. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct CandlesEvent { pub r#type: EventType, pub candles: Vec, } /// The ticker event containing updates to tickers. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct TickerEvent { pub r#type: EventType, pub tickers: Vec, } /// The level2 event containing updates to the order book. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Level2Event { pub r#type: EventType, pub product_id: String, @@ -50,34 +51,34 @@ pub struct Level2Event { } /// The user event containing updates to orders. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct UserEvent { pub r#type: EventType, pub orders: Vec, } /// The market trades event containing updates to trades. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct MarketTradesEvent { pub r#type: EventType, pub trades: Vec, } /// The heartbeats event containing the current time and heartbeat counter. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct HeartbeatsEvent { pub current_time: String, pub heartbeat_counter: u64, } /// The subscribe event containing the current subscriptions. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct SubscribeEvent { pub subscriptions: SubscribeUpdate, } /// The futures summary balance event containing the current futures account balance. -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct FuturesSummaryBalanceEvent { pub r#type: EventType, pub fcm_balance_summary: FuturesBalanceSummaryUpdate, diff --git a/src/models/websocket/message.rs b/src/models/websocket/message.rs index 62b7beb..58e67a8 100644 --- a/src/models/websocket/message.rs +++ b/src/models/websocket/message.rs @@ -1,15 +1,9 @@ -use std::fmt; +use serde::{Deserialize, Serialize}; -use serde::de::{self, Deserialize, Deserializer, MapAccess, Visitor}; -use serde_json::Value; - -use super::{ - CandlesEvent, Channel, Event, FuturesSummaryBalanceEvent, HeartbeatsEvent, Level2Event, - MarketTradesEvent, StatusEvent, SubscribeEvent, TickerEvent, UserEvent, -}; +use super::{Channel, Events}; /// Message from the WebSocket containing event updates. -#[derive(Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Message { /// The channel the message is from. pub channel: Channel, @@ -20,147 +14,32 @@ pub struct Message { /// The sequence number for the message pub sequence_num: u64, /// The events in the message. - pub events: Vec, + pub events: Events, } -/// Custom deserialization for Message. -impl<'de> Deserialize<'de> for Message { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_map(MessageVisitor) - } -} - -/// Visitor struct for custom deserialization for Message. -struct MessageVisitor; - -impl<'de> Visitor<'de> for MessageVisitor { - type Value = Message; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a WebSocket message") - } - - fn visit_map(self, mut map: M) -> Result - where - M: MapAccess<'de>, - { - let mut channel: Option = None; - let mut client_id: Option = None; - let mut timestamp: Option = None; - let mut sequence_num: Option = None; - let mut events_value: Option = None; - - // Extract common fields and store the raw events for later deserialization. - while let Some(key) = map.next_key::<&str>()? { - match key { - "channel" => { - if channel.is_some() { - return Err(de::Error::duplicate_field("channel")); +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn heartbeat_works() { + let data = r#" + { + "channel":"heartbeats", + "client_id":"", + "timestamp":"2025-01-14T22:11:18.791273556Z", + "sequence_num":17, + "events": + [ + { + "current_time":"2025-01-14 22:11:18.787177997 +0000 UTC m=+25541.571430466", + "heartbeat_counter":25539 } - channel = Some(map.next_value()?); - } - "client_id" => { - if client_id.is_some() { - return Err(de::Error::duplicate_field("client_id")); - } - client_id = Some(map.next_value()?); - } - "timestamp" => { - if timestamp.is_some() { - return Err(de::Error::duplicate_field("timestamp")); - } - timestamp = Some(map.next_value()?); - } - "sequence_num" => { - if sequence_num.is_some() { - return Err(de::Error::duplicate_field("sequence_num")); - } - sequence_num = Some(map.next_value()?); - } - "events" => { - if events_value.is_some() { - return Err(de::Error::duplicate_field("events")); - } - // Temporarily store events as serde_json::Value - events_value = Some(map.next_value()?); - } - _ => { - // Skip unknown fields or handle as needed. - let _ = map.next_value::()?; - } + ] } - } - - let channel = channel.ok_or_else(|| de::Error::missing_field("channel"))?; - let client_id = client_id.ok_or_else(|| de::Error::missing_field("client_id"))?; - let timestamp = timestamp.ok_or_else(|| de::Error::missing_field("timestamp"))?; - let sequence_num = sequence_num.ok_or_else(|| de::Error::missing_field("sequence_num"))?; - let events_value = events_value.ok_or_else(|| de::Error::missing_field("events"))?; - - // Deserialize events based on the channel. - let events = deserialize_events(&channel, events_value).map_err(de::Error::custom)?; - - Ok(Message { - channel, - client_id, - timestamp, - sequence_num, - events, - }) - } -} + "#; -/// Helper function to deserialize events based on the channel. -fn deserialize_events( - channel: &Channel, - events_value: Value, -) -> Result, Box> { - match channel { - Channel::Status => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Status).collect()) - } - Channel::Candles => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Candles).collect()) - } - Channel::Ticker => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Ticker).collect()) - } - Channel::TickerBatch => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::TickerBatch).collect()) - } - Channel::Level2 => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Level2).collect()) - } - Channel::User => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::User).collect()) - } - Channel::MarketTrades => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::MarketTrades).collect()) - } - Channel::Heartbeats => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Heartbeats).collect()) - } - Channel::Subscriptions => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events.into_iter().map(Event::Subscribe).collect()) - } - Channel::FuturesBalanceSummary => { - let events: Vec = serde_json::from_value(events_value)?; - Ok(events - .into_iter() - .map(Event::FuturesBalanceSummary) - .collect()) - } + let res: Result = serde_json::from_str(data); + assert!(res.is_ok()); } } diff --git a/src/models/websocket/responses.rs b/src/models/websocket/responses.rs index dbf4335..e837d5f 100644 --- a/src/models/websocket/responses.rs +++ b/src/models/websocket/responses.rs @@ -2,12 +2,12 @@ use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DefaultOnError, DisplayFromStr}; use crate::models::order::{OrderSide, OrderStatus, OrderType, TimeInForce, TriggerStatus}; -use crate::models::product::{Candle, ProductType}; +use crate::models::product::{Candle, Product, ProductType}; use super::Level2Side; #[serde_as] -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Level2Update { pub side: Level2Side, pub event_time: String, @@ -17,7 +17,7 @@ pub struct Level2Update { pub new_quantity: f64, } -#[derive(Deserialize, Debug, Default)] +#[derive(Serialize, Deserialize, Debug, Default)] pub struct SubscribeUpdate { #[serde(default)] pub status: Vec, @@ -64,6 +64,23 @@ pub struct ProductUpdate { pub min_market_funds: f64, } +impl From for ProductUpdate { + fn from(product: Product) -> Self { + ProductUpdate { + product_type: product.product_type, + id: product.product_id, + base_currency: product.base_currency_id, + quote_currency: product.quote_currency_id, + base_increment: product.base_increment, + quote_increment: product.quote_increment, + display_name: product.display_name, + status: product.status, + status_message: String::new(), + min_market_funds: product.quote_min_size, + } + } +} + /// Represents a Market Trade received from the Websocket API. #[serde_as] #[derive(Serialize, Deserialize, Debug, Clone)] @@ -208,7 +225,7 @@ pub struct FuturesBalanceSummaryUpdate { #[serde_as] #[derive(Debug, Deserialize, Serialize, Clone)] -struct MarginWindowMeasure { +pub struct MarginWindowMeasure { margin_window_type: String, margin_level: String, #[serde_as(as = "DisplayFromStr")] diff --git a/src/models/websocket/types.rs b/src/models/websocket/types.rs index b05c0d8..ba8c685 100644 --- a/src/models/websocket/types.rs +++ b/src/models/websocket/types.rs @@ -2,8 +2,8 @@ use std::collections::{HashMap, HashSet}; use std::pin::Pin; use std::sync::Arc; +use futures::stream::{self, SelectAll}; use futures::Stream; -use futures_util::stream::{self, SelectAll}; use serde::Serialize; use tokio::sync::Mutex; use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage}; @@ -110,6 +110,11 @@ impl WebSocketEndpoints { endpoints } + + /// Check if the `WebSocketEndpoints` is empty. + pub(crate) fn is_empty(&self) -> bool { + self.endpoints.is_empty() + } } /// Stores the current subscriptions for each channel for each endpoint. @@ -133,7 +138,6 @@ impl WebSocketSubscriptions { } /// Add subscriptions to the specified channel. - pub(crate) async fn add( &mut self, channel: &Channel, @@ -232,6 +236,12 @@ impl From> for EndpointStream { } } +impl From for EndpointStream { + fn from(mut endpoints: WebSocketEndpoints) -> Self { + endpoints.extract_to_vec().into() + } +} + impl Stream for EndpointStream { type Item = Result; diff --git a/src/traits.rs b/src/traits.rs index 10d4eab..9159b1f 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,36 +1,10 @@ //! Traits used to allow interfacing with advanced functionality for end-users. -use async_trait::async_trait; use reqwest::Response; use serde::Serialize; -use crate::models::{product::Candle, websocket::Message}; use crate::types::CbResult; -/// Used to pass to a callback to the candle watcher on a successful ejection. -#[async_trait] -pub trait CandleCallback { - /// Called when a candle is succesfully ejected. - /// - /// # Arguments - /// - /// * `current_start` - Current UTC timestamp for a start. - /// * `product_id` - Product the candle belongs to. - /// * `candle` - Candle that was recently completed. - async fn candle_callback(&mut self, current_start: u64, product_id: String, candle: Candle); -} - -/// Used to pass objects to the listener for greater control over message processing. -#[async_trait] -pub trait MessageCallback { - /// This is called when processing a message from the WebSocket. - /// - /// # Arguments - /// - /// * `msg` - Message or Error received from the WebSocket. - async fn message_callback(&mut self, msg: CbResult); -} - /// Used to pass query/paramters for a URL. pub(crate) trait Query { /// Checks that the query is valid and the required fields are present. diff --git a/src/types.rs b/src/types.rs index 67382e7..5604337 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,6 +1,6 @@ //! Contains custom / shorthand types to simplify end-user code. -use futures_util::stream::SplitStream; +use futures::stream::SplitStream; use tokio::net::TcpStream; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; diff --git a/src/utils.rs b/src/utils.rs index 7c5a4fd..38e0dca 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,15 +3,6 @@ //! `utils` is a collection of helpful tools that may be required throughout the rest of the API. use std::fmt::{Display, Write}; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use async_trait::async_trait; - -use crate::models::websocket::Message; -use crate::traits::MessageCallback; -use crate::types::CbResult; /// Builds the URL Query to be sent to the API. pub(crate) struct QueryBuilder { @@ -43,7 +34,7 @@ impl QueryBuilder { } /// Adds a key-value pair to the query string if the value is present. - pub(crate) fn push_optional(self, key: &str, value: &Option) -> Self { + pub(crate) fn push_optional(self, key: &str, value: Option<&T>) -> Self { if let Some(v) = value { self.push(key, v) } else { @@ -55,7 +46,7 @@ impl QueryBuilder { pub(crate) fn push_optional_vec( mut self, key: &str, - values: &Option>, + values: Option<&Vec>, ) -> Self { if let Some(values) = values { for value in values { @@ -70,67 +61,3 @@ impl QueryBuilder { self.query } } - -type BoxCallback = - Box) -> Pin + Send>> + Send + Sync>; - -/// Used to wrap callback functions for the WebSocket Client's `listen()` function.. -pub struct FunctionCallback { - callback: Arc, -} - -impl FunctionCallback { - /// Creates a new `FunctionCallback` from an asynchronous function. - /// - /// # Arguments - /// - /// * `async_fn` - The asynchronous function to be called when a message is received. - pub fn from_async(async_fn: F) -> Self - where - F: Fn(CbResult) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - { - let callback = move |msg: CbResult| -> Pin + Send>> { - let fut = async_fn(msg); - Box::pin(fut) - }; - - Self { - callback: Arc::new(Box::new(callback)), - } - } - - /// Creates a new `FunctionCallback` from a synchronous function. - /// - /// # Arguments - /// - /// * `sync_fn` - The synchronous function to be called when a message is received. - pub fn from_sync(sync_fn: F) -> Self - where - F: Fn(CbResult) + Send + Sync + 'static, - { - let sync_fn = Arc::new(sync_fn); - - let callback = { - let sync_fn = Arc::clone(&sync_fn); - move |msg: CbResult| -> Pin + Send>> { - let sync_fn = Arc::clone(&sync_fn); - Box::pin(async move { - (sync_fn)(msg); - }) - } - }; - - Self { - callback: Arc::new(Box::new(callback)), - } - } -} - -#[async_trait] -impl MessageCallback for FunctionCallback { - async fn message_callback(&mut self, msg: CbResult) { - let callback = Arc::clone(&self.callback); - (callback)(msg).await; - } -} diff --git a/src/websocket.rs b/src/websocket.rs index 451971d..7938058 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -4,17 +4,18 @@ //! Many parts of the REST API suggest using websockets instead due to ratelimits and being quicker //! for large amount of constantly changing data. +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; -use futures_util::stream::{self, SplitSink}; -use futures_util::{SinkExt, StreamExt}; +use futures::stream::{SplitSink, Stream}; +use futures::task::{noop_waker_ref, Context, Poll}; +use futures::{SinkExt, StreamExt}; use tokio::net::TcpStream; use tokio::sync::Mutex; -use tokio::task::JoinHandle; use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage}; use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; -use crate::candle_watcher::CandleWatcher; use crate::constants::websocket::{PUBLIC_ENDPOINT, SECURE_ENDPOINT}; use crate::errors::CbError; use crate::jwt::Jwt; @@ -24,7 +25,6 @@ use crate::models::websocket::{ }; use crate::time; use crate::token_bucket::{RateLimits, TokenBucket}; -use crate::traits::{CandleCallback, MessageCallback}; use crate::types::CbResult; #[cfg(feature = "config")] @@ -51,8 +51,8 @@ fn get_channel_endpoint(channel: &Channel) -> EndpointType { pub struct WebSocketClientBuilder { api_key: Option, api_secret: Option, - enable_public: bool, - enable_user: bool, + use_public: bool, + use_user: bool, max_retries: u32, public_bucket: Arc>, secure_bucket: Arc>, @@ -63,9 +63,9 @@ impl Default for WebSocketClientBuilder { Self { api_key: None, api_secret: None, - enable_public: true, // By default, enable public connection. - enable_user: false, // By default, do not enable secure connection. - max_retries: 0, // By default, do not auto-reconnect. + use_public: true, // By default, enable public connection. + use_user: false, // By default, disable user connection. + max_retries: 0, // By default, do not auto-reconnect. public_bucket: Arc::new(Mutex::new(TokenBucket::new( RateLimits::max_tokens(false, true), RateLimits::refresh_rate(false, true), @@ -96,7 +96,7 @@ impl WebSocketClientBuilder { { self.api_key = Some(config.coinbase().api_key.to_string()); self.api_secret = Some(config.coinbase().api_secret.to_string()); - self.enable_user = true; + self.use_user = true; self } @@ -109,7 +109,7 @@ impl WebSocketClientBuilder { pub fn with_authentication(mut self, key: &str, secret: &str) -> Self { self.api_key = Some(key.to_string()); self.api_secret = Some(secret.to_string()); - self.enable_user = true; + self.use_user = true; self } @@ -118,18 +118,8 @@ impl WebSocketClientBuilder { /// # Arguments /// /// * `enable` - Enable or disable the public connection. - pub fn enable_public(mut self, enable: bool) -> Self { - self.enable_public = enable; - self - } - - /// Enables or disables the secure user connection. - /// - /// # Arguments - /// - /// * `enable` - Enable or disable the secure user connection. - pub fn enable_user(mut self, enable: bool) -> Self { - self.enable_user = enable; + pub fn use_public(mut self, enable: bool) -> Self { + self.use_public = enable; self } @@ -164,14 +154,14 @@ impl WebSocketClientBuilder { /// Returns a `CbError` if the API key or secret are missing or if both public and secure connections are disabled. pub fn build(self) -> CbResult { // Ensure at least one connection is enabled. - if !self.enable_public && !self.enable_user { + if !self.use_public && !self.use_user { return Err(CbError::BadConnection( "At least one of public or secure connections must be enabled.".to_string(), )); } // Create JWT if user connection is enabled. - let jwt = if self.enable_user { + let jwt = if self.use_user { let key = self.api_key.ok_or_else(|| { CbError::BadPrivateKey("API key is required for authentication.".to_string()) })?; @@ -189,8 +179,8 @@ impl WebSocketClientBuilder { secure_bucket: self.secure_bucket, public_tx: Arc::new(Mutex::new(None)), secure_tx: Arc::new(Mutex::new(None)), - enable_public: self.enable_public, - enable_user: self.enable_user, + enable_public: self.use_public, + enable_user: self.use_user, max_retries: self.max_retries, subscriptions: Arc::new(Mutex::new(WebSocketSubscriptions::new())), }) @@ -241,7 +231,7 @@ impl WebSocketClient { /// # Errors /// /// Returns a `CbError` if the WebSocket connection fails. - pub async fn connect(&mut self) -> CbResult { + pub async fn connect(&self) -> CbResult { let mut endpoints = WebSocketEndpoints::default(); if self.enable_public { @@ -258,7 +248,7 @@ impl WebSocketClient { } /// Connects to the WebSocket endpoint. - async fn connect_endpoint(&mut self, endpoint_type: &EndpointType) -> CbResult { + async fn connect_endpoint(&self, endpoint_type: &EndpointType) -> CbResult { match endpoint_type { EndpointType::Public => { let (public_socket, _) = connect_async(PUBLIC_ENDPOINT).await.map_err(|why| { @@ -294,7 +284,7 @@ impl WebSocketClient { /// # Errors /// /// Returns a `CbError` if the WebSocket connection fails. - async fn reconnect(&mut self, endpoint_type: &EndpointType) -> CbResult { + async fn handle_reconnect(&mut self, endpoint_type: &EndpointType) -> CbResult { let endpoint = self.connect_endpoint(endpoint_type).await?; // Re-subscribe to previous channels for this endpoint. @@ -328,7 +318,7 @@ impl WebSocketClient { // Rety until max retries hit. while retries < self.max_retries { - match self.reconnect(endpoint_type).await { + match self.handle_reconnect(endpoint_type).await { Ok(endpoint) => return Ok(endpoint), Err(why) => { eprintln!( @@ -346,12 +336,32 @@ impl WebSocketClient { ))) } - /// Handles reconnection logic for endpoints. - async fn handle_reconnection(&mut self, stream: EndpointStream) -> Option { - match stream { + /// Reconnects to the WebSocket endpoint. Returns a new `EndpointStream`. + /// This is used when the WebSocket connection is lost. + /// + /// # Arguments + /// + /// * `stream` - The current `EndpointStream` that was being listened to. + /// + /// # Errors + /// + /// Returns a `CbError` if the WebSocket connection fails. + pub async fn reconnect(&mut self, stream: E) -> CbResult + where + E: Into, + { + let mut new_endpoints = WebSocketEndpoints::default(); + + match stream.into() { EndpointStream::Single(route, _) => { // Reconnect and return a new Single EndpointStream. - self.wait_on_reconnect(&route).await.ok().map(Into::into) + match self.wait_on_reconnect(&route).await { + Ok(endpoint) => { + new_endpoints.add(route, endpoint); + Ok(new_endpoints) + } + Err(why) => Err(why), + } } EndpointStream::Multiple(_) => { // Obtain all the endpoints that need to be reconnected. @@ -361,32 +371,24 @@ impl WebSocketClient { }; // Iterate over each endpoint and attempt to reconnect. - let mut new_endpoints = WebSocketEndpoints::default(); for endpoint_type in keys { - if let Ok(new_endpoint) = self.wait_on_reconnect(&endpoint_type).await { - new_endpoints.add(endpoint_type.clone(), new_endpoint); - } else { - eprintln!("Failed to reconnect: {endpoint_type:?}"); - return None; + match self.wait_on_reconnect(&endpoint_type).await { + Ok(new_endpoint) => { + new_endpoints.add(endpoint_type.clone(), new_endpoint); + } + Err(why) => { + return Err(why); + } } } - // Extract the readers (streams) from the new endpoints. - let streams = new_endpoints - .extract_to_vec() - .into_iter() - .map(|endpoint| match endpoint { - Endpoint::Public((_, reader)) | Endpoint::User((_, reader)) => reader, - }) - .collect::>(); - - // Create a new Multiple EndpointStream. - let mut select_all = stream::SelectAll::new(); - for stream in streams { - select_all.push(stream); + if new_endpoints.is_empty() { + return Err(CbError::BadConnection( + "Failed to reconnect to any endpoints.".to_string(), + )); } - Some(EndpointStream::Multiple(select_all)) + Ok(new_endpoints) } } } @@ -396,11 +398,12 @@ impl WebSocketClient { /// # Arguments /// /// * `endpoints` - A single `Endpoint` or multiple `WebSocketEndpoints`. - /// * `callback` - A callback object that implements the `MessageCallback` trait. - pub async fn listen(&mut self, endpoints: E, mut callback: T) + /// * `callback` - The asynchronous closure to invoke on each message. + pub async fn listen(&mut self, endpoints: E, mut callback: F) where - T: MessageCallback + Send + 'static, E: Into, + F: FnMut(CbResult) -> Fut + Send + 'static, + Fut: Future + Send, { let mut stream = endpoints.into(); @@ -409,20 +412,114 @@ impl WebSocketClient { if let Some(result) = Self::process_message(message) { if let Err(CbError::BadConnection(_)) = &result { // Handle reconnection logic. - if let Some(new_stream) = self.handle_reconnection(stream).await { - // Restart the loop with the new streams. - stream = new_stream; - break; + match self.reconnect(stream).await { + Ok(new_stream) => { + stream = new_stream.into(); + break; // Exit inner loop to reconnect. + } + Err(why) => { + eprintln!("Failed to reconnect: {why}"); + return; // Exit function if reconnection fails + } } + } + + // Invoke the asynchronous closure with the result. + callback(result).await; + } + } + } + } - // Reconnection failed, exit. - return; + /// Fetches messages from the WebSocket stream with a limit on the number of messages to fetch. + /// + /// NOTE: Adequate pauses / sleeps between calls should be added to prevent busy-looping. + /// + /// # Arguments + /// + /// * `stream` - The WebSocket stream to get messages from. + /// * `limit` - The maximum number of messages to fetch. Use `usize::MAX` to fetch all messages. + /// * `action` - The action to take on each message. + /// + /// # Errors + /// + /// Returns a `String` if the user returns an error within the action. + pub fn fetch_sync( + &self, + stream: &mut EndpointStream, + limit: usize, + mut action: F, + ) -> Result<(), String> + where + F: FnMut(CbResult) -> Result<(), String>, + { + let mut count = 0; + + while count <= limit || limit == usize::MAX { + // Use poll_next to check for available messages without waiting. + match Pin::new(&mut *stream).poll_next(&mut Context::from_waker(noop_waker_ref())) { + Poll::Ready(Some(message)) => { + // Process and add the message to the result vector if valid. + if let Some(result) = Self::process_message(message) { + action(result)?; } - callback.message_callback(result).await; + count += 1; + } + Poll::Ready(None) | Poll::Pending => { + // No more messages available or stream is pending; exit the loop. + break; } } } + + Ok(()) + } + + /// Asynchronously fetches messages from the WebSocket stream with a limit on the number of messages to fetch. + /// + /// NOTE: Adequate pauses / sleeps between calls should be added to prevent busy-looping. + /// + /// # Arguments + /// + /// * `stream` - The WebSocket stream to get messages from. + /// * `limit` - The maximum number of messages to fetch. Use `usize::MAX` to fetch all messages. + /// * `action` - The action to take on each message. + /// + /// # Errors + /// + /// Returns a `String` if the user returns an error within the action. + pub async fn fetch_async( + &self, + stream: &mut EndpointStream, + limit: usize, + mut action: F, + ) -> Result<(), String> + where + F: FnMut(CbResult) -> Fut, + Fut: Future>, + { + let mut count = 0; + + while count <= limit || limit == usize::MAX { + // Use poll_next to check for available messages without waiting. + match Pin::new(&mut *stream).poll_next(&mut Context::from_waker(noop_waker_ref())) { + Poll::Ready(Some(message)) => { + // Process and add the message to the result vector if valid. + if let Some(result) = Self::process_message(message) { + action(result).await?; + } + + count += 1; + } + Poll::Ready(None) | Poll::Pending => { + // No more messages available or stream is pending; exit the loop. + break; + } + } + } + + Ok(()) } /// Waits for a token to be consumable for the correct bucket. @@ -439,13 +536,12 @@ impl WebSocketClient { } } - /// Processes WebSocket messages and applies a callback. Created to ignore alternative message types. + /// Processes the WebSocket message and returns a `Message` if successful. /// /// # Arguments /// - /// * `message` - A WebSocket message to process. - /// * `callback` - A closure or function that processes parsed messages or errors. - fn process_message(message: Result) -> Option> { + /// * `message` - The WebSocket message to process. + pub fn process_message(message: Result) -> Option> { match message { Ok(msg) => match msg { WsMessage::Text(data) => { @@ -478,7 +574,7 @@ impl WebSocketClient { /// * `product_ids` - A vector of product IDs that are being changed. /// * `action` - The action being taken (either "subscribe" or "unsubscribe"). /// * `endpoint` - The endpoint type (either public or user). - pub(crate) async fn update( + async fn update( &mut self, channel: &Channel, product_ids: &[String], @@ -588,6 +684,7 @@ impl WebSocketClient { let mut subs = self.subscriptions.lock().await; subs.add(channel, product_ids, route).await; } + Ok(()) } @@ -629,45 +726,4 @@ impl WebSocketClient { } Ok(()) } - - /// Watches candles for a set of products, producing candles once they are considered complete. - /// - /// # Argument - /// - /// * `products` - Products to watch for candles for. - /// * `watcher` - User-defined struct that implements `CandleCallback` to send completed candles to. - /// - /// # Errors - /// - /// Returns a `CbError` if the public connection is not enabled. - pub async fn watch_candles( - mut self, - products: &[String], - watcher: T, - ) -> CbResult> - where - T: CandleCallback + Send + Sync + 'static, - { - if !self.enable_public { - return Err(CbError::BadConnection( - "Public connection is not enabled.".to_string(), - )); - } - - // Connect and spawn a task. - match self.connect().await?.take_endpoint(&EndpointType::Public) { - Some(public) => { - // Keep the connection open by subscribing to heartbeats and sub to candles. - self.subscribe(&Channel::Heartbeats, &[]).await?; - self.subscribe(&Channel::Candles, products).await?; - - // Start task to watch candles using user's watcher. - let listener = tokio::spawn(CandleWatcher::start(self, public, watcher)); - Ok(listener) - } - None => Err(CbError::BadConnection( - "Public connection is not connected.".to_string(), - )), - } - } }