Skip to content
65 changes: 65 additions & 0 deletions decorator_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
def simple_decorator(func):
def wrapper(*args, **kwargs):
if __name__ == '__main__':
print("simple_decorator")
return func(*args, **kwargs)
return wrapper


def decorator_with_parameters(print_func=False):
def decorator(func):
def wrapper(*args, **kwargs):
if __name__ == '__main__':
print("decorator_with_parameters")
if print_func:
print(f'{func = }')
return func(*args, **kwargs)
return wrapper
return decorator


def second_order_decorator(sub_dec, *args, enabled=False, **kwargs):
def wrapper(func):
if enabled:
if __name__ == '__main__':
print("switchable_decorator")
if args or kwargs:
return sub_dec(*args, **kwargs)(func)
else:
return sub_dec(func)
else:
return func
return wrapper


@second_order_decorator(simple_decorator, enabled=True)
def my_function0():
print("my function 0\n")


@second_order_decorator(decorator_with_parameters, enabled=True, print_func=True)
def my_function1():
print("my function 1\n")


@second_order_decorator(decorator_with_parameters, enabled=False, print_func=True)
def my_function2():
print("my function 2\n")


@simple_decorator
def my_function3():
print("my function 3\n")


@decorator_with_parameters(print_func=True)
def my_function4():
print("my function 4\n")


if __name__ == '__main__':
my_function0()
my_function1()
my_function2()
my_function3()
my_function4()
57 changes: 37 additions & 20 deletions model_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,58 @@
from qlib.contrib.data.handler import Alpha158
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.report.analysis_position.report import _calculate_report_data
from decorator_switch import second_order_decorator as sod


# 定义任务
@task(name="load_config")
def load_config():
with open("workflow_config_lightgbm_Alpha158.yaml", "r") as f:
config = yaml.safe_load(f)
return config

CFG = load_config()
USE_PREFECT = CFG["use_prefect"]

@task(name="model_data_init")

@sod(task, enabled=USE_PREFECT, name="model_data_init")
def model_data_init(config):
# Initialize QLib
provider_uri = config["qlib_init"]["provider_uri"]
region = config["qlib_init"]["region"]
qlib.init_qlib(provider_uri=provider_uri, region=region)
logger = get_run_logger()
logger.info("QLib initialized successfully")
use_prefect = config["use_prefect"]
qlib.init(provider_uri=provider_uri, region=region)
if use_prefect:
logger = get_run_logger()
logger.info("QLib initialized successfully")

# Initialize model
model_config = config["task"]["model"]
model = init_instance_by_config(model_config)
logger.info(f"Model initialized: {model}")
if use_prefect:
logger.info(f"Model initialized: {model}")

# Initialize data
data_config = config["task"]["dataset"]
dataset = init_instance_by_config(data_config)
logger.info(f"Dataset initialized: {dataset}")
# data_handler_config = config["task"]["dataset"]["kwargs"]["handler"]
data_handler_config = config["data_handler_config"]
hd = Alpha158(**data_handler_config)
dataset_conf = config["task"]["dataset"]
dataset_conf["kwargs"]["handler"] = hd

dataset = init_instance_by_config(dataset_conf)
if use_prefect:
logger.info(f"Dataset initialized: {dataset}")

# Reweighter = task_config.get("reweighter", None)
history = hd.fetch()
history = history.reset_index()
history.head()
execute_sql("history.db", "DROP TABLE IF EXISTS history_db")
execute_sql("history.db", "CREATE TABLE history_db AS SELECT * FROM history")
execute_sql("history.db", "SELECT * FROM history_db")

return model, dataset


@task(name="train_and_predict")

@sod(task, enabled=USE_PREFECT, name="train_and_predict")
def train_and_predict(model, dataset):
model.fit(dataset)
pred = model.predict(dataset)
Expand All @@ -71,7 +88,7 @@ def execute_sql(db_name, sql):
con.sql(sql)


@task(name="strategy")
@sod(task, enabled=USE_PREFECT, name="strategy")
def strategy_simulator(config, pred):
STRATEGY_CONFIG = config["port_analysis_config"]["strategy"]["kwargs"]
STRATEGY_CONFIG["signal"] = pred
Expand All @@ -83,14 +100,14 @@ def strategy_simulator(config, pred):
return strategy_obj, executor_obj


@task(name="backtest_record")
@sod(task, enabled=USE_PREFECT, name="backtest_record")
def backtest_record(config, strategy_obj, executor_obj):
backtest_config = config["port_analysis_config"]["backtest"]
portfolio_metric_dict, indicator_dict = backtest(
executor=executor_obj, strategy=strategy_obj, **backtest_config
)

FREQ = "day"
FREQ = backtest_config['freq']
analysis_freq = "{0}{1}".format(*qlib.utils.time.Freq.parse(FREQ))
portfolio_metrics = portfolio_metric_dict.get(analysis_freq)
report_normal = portfolio_metrics[0]
Expand Down Expand Up @@ -123,8 +140,8 @@ def backtest_record(config, strategy_obj, executor_obj):
return report_df, indicators_normal


@task(name="risk_analysis")
def riskanalysis(report_normal):
@sod(task, enabled=USE_PREFECT, name="risk_analysis")
def risk_analysis_(report_normal):
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"]
Expand All @@ -141,13 +158,13 @@ def riskanalysis(report_normal):


# 定义流程
@flow(name="qlib_workflow", description="Demo Prefect")
def run_workflow(name="qlib_workflow"):
@sod(flow, enabled=USE_PREFECT, name="qlib_workflow", description="Demo Prefect")
def run_workflow(config=CFG, name="qlib_workflow"):
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("zhanyuan")
with mlflow.start_run() as run:
mlflow.lightgbm.autolog()
config = load_config()

model, dataset = model_data_init(config)
pred, label = train_and_predict(model, dataset)
strategy_obj, executor_obj = strategy_simulator(config, pred)
Expand Down
5 changes: 4 additions & 1 deletion workflow_config_lightgbm_Alpha158.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ qlib_init:
region: cn
market: &market csi300
benchmark: &benchmark SH000300
freq: &freq day
use_prefect: True
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
Expand Down Expand Up @@ -36,6 +38,7 @@ port_analysis_config: &port_analysis_config
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
freq: *freq
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
Expand All @@ -46,7 +49,7 @@ port_analysis_config: &port_analysis_config
class: SimulatorExecutor
module_path: qlib.backtest.executor
kwargs:
time_per_step: day
time_per_step: *freq
generate_portfolio_metrics: True

task:
Expand Down