Commit

add some features

GCS-ZHN committed Oct 7, 2022
1 parent 9ceede9 commit 5aa0462
Showing 5 changed files with 90 additions and 32 deletions.
18 changes: 14 additions & 4 deletions src/socube/__init__.py
@@ -85,6 +85,10 @@ def main(*args: str):
                             type=str,
                             default="balance",
                             help=help["basic_args"]["generate_mode"])
+    basic_args.add_argument("--generate-ratio",
+                            type=float,
+                            default=1.0,
+                            help=help["basic_args"]["generate_ratio"])
     basic_args.add_argument("--only-embedding",
                             action="store_true",
                             default=False,
@@ -127,6 +131,10 @@ def main(*args: str):
                             "-mp",
                             action="store_true",
                             help=help["model_args"]["enable_multiprocess"])
+    model_args.add_argument("--enable-ensemble",
+                            "-ee",
+                            action="store_true",
+                            help=help["model_args"]["enable_ensemble"])
 
     notice_args = parser.add_argument_group(help["notice_args"]["title"])
     notice_args.add_argument("--mail",
@@ -273,7 +281,8 @@ def main(*args: str):
             output_path=embedding_path,
             adj=args.adj_factor,
             seed=args.seed,
-            mode=args.generate_mode)
+            mode=args.generate_mode,
+            ratio=args.generate_ratio)
 
     samples = samples.T
     writeHdf(
@@ -343,11 +352,11 @@ def main(*args: str):
         label_file="TrainLabel.csv",
         threshold=args.threshold,
         k=args.k,
-        once=False,
         use_index=False,
         step=5,
         max_acc_limit=1,
-        multi_process=args.enable_multiprocess)
+        multi_process=args.enable_multiprocess,
+        once=(not args.enable_ensemble))

log("Inference", "Begin doublet detection output")
infer(data_dir=train_path,
Expand All @@ -361,6 +370,7 @@ def main(*args: str):
gpu_ids=gpu_ids,
with_eval=args.enable_validation,
seed=args.seed,
multi_process=args.enable_multiprocess)
multi_process=args.enable_multiprocess,
once=(not args.enable_ensemble))

em.setNormalInfo("Doublet detection finished")
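Taken together, the additions above wire two new behaviors through main(*args: str): --generate-ratio is passed through to the doublet generator, and --enable-ensemble flips the once flag of fit() and infer(). A minimal sketch of exercising the new flags through the package's own entry point; the input/output arguments a real run requires are omitted here, so this illustrates the flag wiring only:

    from socube import main

    # --generate-ratio 1.5: generate 1.5x as many in-silico doublets as singlets.
    # --enable-ensemble (-ee): keep all k fold models instead of only the first.
    # --enable-multiprocess (-mp): run the k folds in parallel worker processes.
    main("--generate-ratio", "1.5",
         "--enable-ensemble",
         "--enable-multiprocess")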
6 changes: 4 additions & 2 deletions src/socube/help.en_US.json
@@ -12,8 +12,9 @@
     "adj_factor": "The adjustment factor for the doublet expression level. By default it is assumed that the doublet expression level is twice the singlet's, but there are fluctuations in real data, and the expression level can be changed by adjusting this factor. Default 1.0.",
     "dim": "The target dimension of gene dimensionality reduction, which is also the number of channels for model training. Default 10.",
     "cube_id": "If you want to reuse the socube embedding features obtained earlier, just specify the embedding ID, which is a string like \"yyyymmdd-HHMMSS-xxx\", along with the original output path.",
+    "generate_ratio": "The ratio of the number of generated doublets to the number of singlets. Default 1.0.",
     "generate_mode": "The generation mode for in-silico doublets: \"balance\", \"heterotypic\" or \"homotypic\". Default \"balance\".",
-    "only_embedding": "This option is provided for users who only want to use socube embedding but do not require doublet detection"
+    "only_embedding": "This option is provided for users who only want to use socube embedding but do not require doublet detection. If this option is specified, model training and doublet detection will not be performed."
   },
"model_args": {
"title": "model training configuration",
Expand All @@ -24,7 +25,8 @@
"infer_batch_size": "Batch size of model inferring. Default 400.",
"threshold": "The classification threshold for doublet detection. The model outputs the probability score of doublet, which is greater than the threshold considered as doublet and vice versa for singlet. user can customize the threshold value. Default 0.5.",
"enable_validation": "This optional is provided for performance evaluation. You should input h5ad format data, and store label in its `obs` property. `obs` property is a `DataFrame` object and its label column named \"type\" and value is \"doublet\" and \"singlet\".",
"enable_multiprocess": "Enable multi process to make use of CPU's multiple cores."
"enable_multiprocess": "Enable multi process to make use of CPU's multiple cores.",
"enable_ensemble": "Enable ensemble model, which is the average of k models trained by k-fold cross-validation. Default False."
},
"notice_args": {
"title": "notice configuration",
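Read together, the threshold and enable_ensemble entries above describe a simple decision rule; a small numpy sketch of the described behavior (the score matrix is made up for illustration, 3 folds by 3 droplets):

    import numpy as np

    k_fold_scores = np.array([[0.2, 0.7, 0.9],   # fold 1 doublet probabilities
                              [0.3, 0.6, 0.8],   # fold 2
                              [0.1, 0.8, 0.7]])  # fold 3
    threshold = 0.5
    ensemble_score = k_fold_scores.mean(axis=0)  # ensemble = average of k models
    print(ensemble_score > threshold)            # [False  True  True]: above => doublet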
6 changes: 4 additions & 2 deletions src/socube/help.zh_CN.json
@@ -12,8 +12,9 @@
     "adj_factor": "二聚体表达水平的调整系数。默认情况下,假定二聚体的表达水平是单体的两倍,但实际情况存在波动,可以通过调整这个系数改变表达水平。默认为1.0。",
     "dim": "基因特征降维的目标维度,也是训练模型的通道数量。默认为10。",
     "cube_id": "如果你想重新使用先前获得的socube嵌入特征,只需指定embedding ID,这是一个类似于\"yyyymmdd-HHMMSS-xxx\"的字符串,位于embedding子目录下。",
+    "generate_ratio": "生成训练集中二聚体与单体的数量比,默认为1.0。",
     "generate_mode": "生成模拟二聚体的模式,可选值为\"balance\"、\"heterotypic\"或\"homotypic\"。默认为\"balance\"。",
-    "only_embedding": "这个选项提供给那些只想使用socube的特征嵌入功能的用户,使用后不会进行二聚体检测"
+    "only_embedding": "这个选项提供给那些只想使用socube的特征嵌入功能的用户,使用后不会进行二聚体检测。默认为False。"
   },
"model_args": {
"title": "模型训练配置",
Expand All @@ -24,7 +25,8 @@
"infer_batch_size": "模型推理的批量大小。默认为400。",
"threshold": "二聚体检测的分类阈值。该模型输出二聚体的概率分数,大于阈值的被认为是二聚体,反之为单体。用户可以自定义阈值。默认为0.5。",
"enable_validation": "这个选项是为性能评估提供的。你应该输入h5ad格式的数据,并在其`obs`属性中存储标签。`obs'属性是一个`DataFrame'对象,它的标签列名为 \"type\",值为 \"doublet \"\"singlet\"",
"enable_multiprocess": "启用多进程以利用CPU的多个核心。"
"enable_multiprocess": "启用多进程以利用CPU的多个核心。",
"enable_ensemble": "启用集成学习,将k-折交叉验证的k个模型集成成一个模型。"
},
"notice_args": {
"title": "消息通知配置",
2 changes: 1 addition & 1 deletion src/socube/task/doublet/data.py
@@ -151,7 +151,7 @@ def generateDoublet(samples: pd.DataFrame,
         the second is the negative (singlet) samples.
     """
     assert mode in ["heterotypic", "homotypic", "balance"], "mode must be one of 'balance', 'heterotypic', 'homotypic'"
-    log("Preprocess", "Generating doublet with mode: {}".format(mode))
+    log("Generate", "Generating doublet with mode: {}".format(mode))
     values = samples.values
     droplet_num = samples.shape[0]
     if size is None or size <= 0:
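Per the adj_factor and generate_ratio help text above, an in-silico doublet is modeled as the scaled sum of two droplets' expression. A toy sketch of that idea only, not socube's actual generateDoublet implementation (the mode and pairing logic are omitted, and the sampling details are assumptions):

    import numpy as np
    import pandas as pd

    def toy_generate_doublet(samples: pd.DataFrame, size: int,
                             adj: float = 1.0, seed: int = 0) -> pd.DataFrame:
        # A doublet is the adj-scaled sum of two randomly drawn droplets.
        rng = np.random.default_rng(seed)
        i = rng.integers(0, samples.shape[0], size)
        j = rng.integers(0, samples.shape[0], size)
        return pd.DataFrame(adj * (samples.values[i] + samples.values[j]),
                            columns=samples.columns)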
90 changes: 67 additions & 23 deletions src/socube/task/doublet/train.py
@@ -186,23 +186,38 @@ def fit(home_dir: str,
    # If free memory of cuda:0 is less than required, a RuntimeError will occur.
    # Otherwise, you will find TWO gpu devices used by your process in the `nvidia-smi` report.
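    # (Illustration only, not part of this diff) The fold-to-device mapping below is
    # round-robin over gpu_ids: with gpu_ids = ["cuda:0", "cuda:1"], folds 1..5 map
    # to cuda:0, cuda:1, cuda:0, cuda:1, cuda:0 via gpu_ids[(fold - 1) % len(gpu_ids)].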

-    with ParallelManager(paral_type="process", max_workers=k if multi_process else 1, verbose=True) as pm:
-        results: List[Future] = []
+    if multi_process:
+        log("Train", "Use multi-process to train the model")
+        with ParallelManager(paral_type="process", max_workers=k, verbose=False) as pm:
+            results: List[Future] = []
+            for fold, (train_set, valid_set) in enumerate(dataset.kFold, 1):
+                device = torch.device("cpu")
+                if gpu_ids is not None and len(gpu_ids) > 0 and torch.cuda.is_available():
+                    device = torch.device(gpu_ids[(fold-1)%len(gpu_ids)])
+                results.append(pm.submit(_fit, fold = fold, device=device, train_set=train_set, valid_set = valid_set, **train_kwargs))
+                # quit k-fold
+                if once:
+                    break
+
+            for result in results:
+                rep, train_rep = result.result()
+                for head in rep:
+                    report[head] = rep[head]
+                for head in train_rep:
+                    train_record[head] = train_rep[head]
+    else:
         for fold, (train_set, valid_set) in enumerate(dataset.kFold, 1):
             device = torch.device("cpu")
             if gpu_ids is not None and len(gpu_ids) > 0 and torch.cuda.is_available():
                 device = torch.device(gpu_ids[(fold-1)%len(gpu_ids)])
-            results.append(pm.submit(_fit, fold = fold, device=device, train_set=train_set, valid_set = valid_set, **train_kwargs))
-            # quit k-fold
-            if once:
-                break
-
-        for result in results:
-            rep, train_rep = result.result()
+            rep, train_rep = _fit(fold=fold, device=device, train_set=train_set, valid_set=valid_set, **train_kwargs)
             for head in rep:
                 report[head] = rep[head]
             for head in train_rep:
                 train_record[head] = train_rep[head]
+            # quit k-fold
+            if once:
+                break

report["average"] = report.mean(axis=1)
report["sample_stdev"] = report.std(axis=1)
@@ -437,7 +452,8 @@ def infer(data_dir: str,
           gpu_ids: List[str] = None,
           with_eval: bool = False,
           seed: Optional[int] = None,
-          multi_process: bool = False):
+          multi_process: bool = False,
+          once: bool = False):
"""
Model inference
@@ -467,6 +483,8 @@
         the seed for random
     multi_process: bool
         whether to use multi-process for inference
+    once: bool
+        whether to run only the first fold's model instead of the full k-fold ensemble
"""
dataset = ConvClassifyDataset(data_dir=data_dir,
labels=label_file,
@@ -485,13 +503,35 @@
mkDirs(plot_dir)
mkDirs(output_dir)
ensemble_score_list = []
-    with ParallelManager(max_workers=k if multi_process else 1, paral_type="process", verbose=True) as pm:
-        results: List[Future] = []
+    if multi_process:
+        with ParallelManager(max_workers=k, paral_type="process", verbose=True) as pm:
+            results: List[Future] = []
+            for fold in range(1, k + 1):
+                device = torch.device("cpu")
+                if gpu_ids is not None and len(gpu_ids) > 0 and torch.cuda.is_available():
+                    device = torch.device(gpu_ids[(fold-1)%len(gpu_ids)])
+                results.append(pm.submit(_infer,
+                                         in_channels,
+                                         model_dir,
+                                         dataset,
+                                         batch_size,
+                                         device,
+                                         with_eval,
+                                         plot_dir,
+                                         threshold,
+                                         fold,
+                                         output_dir))
+
+                if once:
+                    break
+            for result in results:
+                ensemble_score_list.append(result.result())
+    else:
         for fold in range(1, k + 1):
             device = torch.device("cpu")
             if gpu_ids is not None and len(gpu_ids) > 0 and torch.cuda.is_available():
                 device = torch.device(gpu_ids[(fold-1)%len(gpu_ids)])
-            results.append(pm.submit(_infer,
+            ensemble_score_list.append(_infer(
                 in_channels,
                 model_dir,
                 dataset,
                 batch_size,
                 device,
                 with_eval,
                 plot_dir,
                 threshold,
                 fold,
                 output_dir))
-        for result in results:
-            ensemble_score_list.append(result.result())
-    log("Inference", "Model ensembling")
-    ensemble_score = sum(ensemble_score_list) / k
+            if once:
+                break
+
+    if once:
+        ensemble_score = ensemble_score_list[0]
+    else:
+        log("Inference", "Model ensembling")
+        ensemble_score = np.mean(ensemble_score_list, axis=0)

-    if with_eval:
+    if with_eval and not once:
writeCsv(
evaluateReport(
label=dataset._labels.iloc[:, 0].values,
score=ensemble_score,
roc_plot_file=os.path.join(plot_dir, f"inference_roc_{threshold}.png"),
prc_plot_file=os.path.join(plot_dir, f"inference_prc_{threshold}.png"),
roc_plot_file=os.path.join(plot_dir, f"inference_roc_{threshold}.pdf"),
prc_plot_file=os.path.join(plot_dir, f"inference_prc_{threshold}.pdf"),
threshold=threshold),
os.path.join(output_dir, f"inference_report_{threshold}.csv"))

@@ -534,7 +578,7 @@ def _infer(
plot_dir,
threshold,
fold,
-        output_dir):
+        output_dir) -> np.ndarray:
"""
Internal function for inference
"""
@@ -551,9 +595,9 @@
label=label,
score=score,
         roc_plot_file=os.path.join(plot_dir,
-                                    f"inference_roc_{fold}_{threshold}.png"),
+                                    f"inference_roc_{fold}_{threshold}.pdf"),
         prc_plot_file=os.path.join(plot_dir,
-                                    f"inference_prc_{fold}_{threshold}.png"),
+                                    f"inference_prc_{fold}_{threshold}.pdf"),
threshold=threshold),
os.path.join(output_dir, f"inference_report_{fold}_{threshold}.csv"))

