Skip to content

fix: fix concurrency problems #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
11a4792
fix: fix concurrency problems
bioinformatist Sep 23, 2023
7687747
fix: uses `Option<i64>` instead of a regular `i64` for `model_update_…
bioinformatist Sep 25, 2023
9c2834d
feat: remove `data_path` field
bioinformatist Sep 25, 2023
a12a982
fix: remove useless fields in metadata
bioinformatist Sep 25, 2023
59cf0a2
fix: use arc instead of clone
bioinformatist Sep 27, 2023
de6826f
test: add tests for Task and Inference (#12)
lazyky Sep 28, 2023
fc318c1
fix: change tokio version to 1.x
bioinformatist Oct 7, 2023
3eebcec
fix: change `build_sync` to pub
bioinformatist Oct 7, 2023
0ffd29b
fix: - bring task initial test back | - chang`&'a [&'a str]` to `&'a …
bioinformatist Oct 7, 2023
c47a513
feat: the lock on batch is optional now
bioinformatist Oct 7, 2023
aa772e9
fix: partially migrate `async_fn_in_trait` to `return_position_impl_t…
bioinformatist Oct 7, 2023
0bc963b
refactor: delete unnecessary code (#13)
lazyky Oct 7, 2023
4a63ec6
feat: Optimize the timing of task run() (#15)
lazyky Oct 14, 2023
9bb1631
feat: Merge `TaskConfig` into TDengine; Add `current_ts` (#16)
lazyky Dec 24, 2023
e3b0b1d
build: use nighly version in project
bioinformatist Jan 10, 2024
aac9678
build: switch `taos` back to version `*`
bioinformatist Jan 10, 2024
89c783f
refactor: rename `TDengine::from_dsn` to `TDengine::new`
bioinformatist Jan 10, 2024
0145a1d
refactor(core,-tdengine): replace `Task::run()` with `Task::prepare()`
bioinformatist Jan 17, 2024
787afca
style(core,tdengine): remove unnecessary trait bound
bioinformatist Jan 19, 2024
4dc0876
fix(cml-tdengine/src/core/task.rs): Remove the batch with working sta…
lazyky Jan 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ resolver = "2"
[workspace.dependencies]
anyhow = "1"
rand = "0.8.5"
dashmap = { git = "https://github.com/xacrimon/dashmap.git", branch = "master", features = ["inline"] }
derive_builder = "0.12.0"
deadpool = { version = "0.9.5", default-features = false, features = ["managed"] }
dashmap = { version = "5.5.0", features = ["inline"] }
chrono = "0.4.26"
burn = { git = "https://github.com/burn-rs/burn.git", branch = "main" }
burn = "0.9.0"
serde = { version = "1.0.163", features = ["derive"] }
typed-builder = "0.16.2"
derive-getters = "0.2.1"
6 changes: 3 additions & 3 deletions cml-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
dashmap = { workspace = true }
derive-getters = "0.2.1"
derive_builder = { workspace = true }
deadpool = { workspace = true }
derive-getters = { workspace = true }
typed-builder = { workspace = true }
deadpool = { version = "0.10.0", default-features = false, features = ["managed", "rt_tokio_1"] }
serde = { workspace = true }
chrono = { workspace = true }
1 change: 1 addition & 0 deletions cml-core/src/core.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod inference;
pub mod register;
pub mod task;
pub(crate) mod utils;
33 changes: 17 additions & 16 deletions cml-core/src/core/inference.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
use crate::metadata::MetaData;
use crate::{metadata::Metadata, SharedBatchState};
use anyhow::Result;
use deadpool::managed::{Manager, Pool};
use derive_getters::Getters;
use std::path::PathBuf;
use std::future::Future;
use typed_builder::TypedBuilder;

#[derive(Builder, Getters)]
#[derive(TypedBuilder, Getters)]
pub struct NewSample<F> {
data_path: PathBuf,
#[builder(default = "None")]
output: Option<F>,
#[builder(default = "None")]
#[builder(default, setter(strip_option))]
pub output: Option<F>,
#[builder(default, setter(strip_option))]
optional_fields: Option<Vec<F>>,
#[builder(default = "None")]
#[builder(default, setter(strip_option))]
optional_tags: Option<Vec<F>>,
}

pub trait Inference<M, F, T, C: Manager> {
async fn init_inference(
fn init_inference(
&self,
target_type: M,
optional_fields: Option<Vec<M>>,
optional_tags: Option<Vec<M>>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;

async fn inference<FN>(
fn inference<FN>(
&self,
metadata: MetaData<F>,
available_status: &[&str],
data: &mut Vec<NewSample<F>>,
metadata: &Metadata<F>,
data: Vec<NewSample<F>>,
current_ts: Option<T>,
batch_state: Option<&SharedBatchState>,
pool: &Pool<C>,
inference_fn: FN,
) -> Result<()>
) -> impl Future<Output = Result<()>>
where
FN: FnOnce(&mut Vec<NewSample<F>>, &str, T) -> Vec<NewSample<F>>;
FN: FnOnce(&mut [NewSample<F>], &str, T);
}
43 changes: 13 additions & 30 deletions cml-core/src/core/register.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,31 @@
use crate::metadata::MetaData;
use crate::{metadata::Metadata, SharedBatchState};
use anyhow::Result;
use dashmap::DashMap;
use deadpool::managed::{Manager, Pool};
use derive_getters::Getters;
use std::{
path::PathBuf,
sync::{Arc, Mutex},
};
use std::future::Future;
use typed_builder::TypedBuilder;

#[derive(Builder, Getters)]
#[derive(TypedBuilder, Getters, Clone)]
pub struct TrainData<F> {
data_path: PathBuf,
gt: F,
#[builder(default = "None")]
#[builder(default, setter(strip_option))]
optional_fields: Option<Vec<F>>,
}

pub struct BatchState {
pub map: DashMap<String, bool>,
}

pub type SharedBatchState = Arc<Mutex<BatchState>>;

impl BatchState {
pub fn create(num_shards: usize) -> SharedBatchState {
Arc::new(Mutex::new(BatchState {
map: DashMap::with_shard_amount(num_shards),
}))
}
}

pub trait Register<M, F, C: Manager> {
async fn init_register(
pub trait Register<M, F, T, C: Manager> {
fn init_register(
&self,
gt_type: M,
optional_fields: Option<Vec<M>>,
optional_tags: Option<Vec<M>>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;

async fn register(
fn register(
&self,
metadata: MetaData<F>,
metadata: &Metadata<F>,
train_data: Vec<TrainData<F>>,
batch_state: Arc<Mutex<BatchState>>,
current_ts: Option<T>,
batch_state: Option<&SharedBatchState>,
pool: &Pool<C>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;
}
25 changes: 4 additions & 21 deletions cml-core/src/core/task.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
use anyhow::Result;
use chrono::Duration;
use derive_getters::Getters;

#[derive(Builder, Getters, Clone)]
pub struct TaskConfig<'a> {
min_start_count: usize,
min_update_count: usize,
working_status: Vec<&'a str>,
limit_time: Duration,
}

use std::future::Future;
pub trait Task<M> {
async fn init_task(
fn init_task(
&self,
optional_fields: Option<Vec<M>>,
optional_tags: Option<Vec<M>>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;

fn run<FN>(
&self,
task_config: TaskConfig,
build_from_scratch_fn: FN,
fining_build_fn: FN,
) -> Result<()>
where
FN: Fn(&str) -> Result<()> + Send + Sync;
fn prepare(&self) -> Result<(Vec<String>, Vec<String>)>;
}
3 changes: 3 additions & 0 deletions cml-core/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub fn get_placeholders<T>(fields: &[T]) -> String {
vec!["?"; fields.len()].join(", ")
}
7 changes: 6 additions & 1 deletion cml-core/src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use anyhow::Result;
use std::future::Future;

pub trait Handler {
type Database;
async fn init(self, client: &Self::Database, db: Option<&str>) -> Result<()>;
fn init(
self,
client: &Self::Database,
db: Option<&str>,
) -> impl Future<Output = Result<()>> + Send;
}
19 changes: 12 additions & 7 deletions cml-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
#![allow(incomplete_features)]
#![feature(async_fn_in_trait)]

#[macro_use]
extern crate derive_builder;
use std::{
collections::HashSet,
sync::{Arc, Condvar, Mutex},
};

pub mod core;
pub mod handler;
pub mod metadata;
mod handler;
mod metadata;

pub type SharedBatchState = Arc<(Mutex<HashSet<String>>, Condvar)>;

pub use core::{register::Register, utils::get_placeholders};
pub use handler::Handler;
pub use metadata::Metadata;
23 changes: 5 additions & 18 deletions cml-core/src/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
use derive_getters::Getters;
use typed_builder::TypedBuilder;

#[derive(Builder, Getters)]
pub struct MetaData<F> {
model_update_time: i64,
batch: String,
inherent_field_num: usize,
inherent_tag_num: usize,
optional_field_num: usize,
#[builder(default = "None")]
#[derive(TypedBuilder, Getters, Clone)]
pub struct Metadata<F> {
pub batch: String,
#[builder(default, setter(strip_option))]
optional_tags: Option<Vec<F>>,
}

impl<F> MetaData<F> {
pub fn get_placeholders(&self) -> (String, String) {
(
vec!["?"; self.optional_tags.as_ref().map_or(0, |v| v.len()) + self.inherent_tag_num]
.join(", "),
vec!["?"; self.optional_field_num + self.inherent_field_num].join(", "),
)
}
}
11 changes: 6 additions & 5 deletions cml-tdengine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ taos-query = { version = "*" }
anyhow = { workspace = true }
rand = { workspace = true }
dashmap = { workspace = true }
derive_builder = { workspace = true }
deadpool = { workspace = true }
typed-builder = { workspace = true }
serde = { workspace = true }
chrono = { workspace = true }
rayon = "1.7.0"
tokio = { version = "1.29.1", features = ["rt", "macros"] }
tokio = { version = "1", features = ["rt", "macros"] }
derive-getters = { workspace = true }

[dev-dependencies]
burn = { workspace = true }
burn-autodiff = { git = "https://github.com/burn-rs/burn.git", branch = "main" }
burn-ndarray = { git = "https://github.com/burn-rs/burn.git", branch = "main" }
burn-autodiff = "0.9.0"
burn-ndarray = "0.9.0"
tempfile = "3.8.0"
Loading