Skip to content

Commit

Permalink
add db support
Browse files Browse the repository at this point in the history
  • Loading branch information
jianzfb committed Dec 26, 2024
1 parent eec7c99 commit 431919e
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 5 deletions.
Empty file.
45 changes: 45 additions & 0 deletions antgo/pipeline/application/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# -*- coding: UTF-8 -*-
# @Time : 2024/11/27 22:42
# @File : command.py
# @Author : jian<jian@mltalker.com>
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from antvis.client.httprpc import *
from antgo import config
import os
import cv2
import base64
import numpy as np

# 启动,进度
class BackgroundTask(object):
def __init__(self):
self.task_message = None # 记录到数据库 task_user, task_id, task_message
pass

def _info(self):
return ['db', 'user']

def _config(self):
# 自定义db设计
{
'table': 'task', # 构建新表名字
'link': 'user', # 关联表
'fields': {
'task_name': 'str',
'task_create_time': 'date'
}
}

def __call__(self, *args, db=None, user=None):
# input
# db, user, params

# 执行
# 独立线程,执行后台func(task_id, task_message, **params)


# output
# task_id, create_time
pass
21 changes: 21 additions & 0 deletions antgo/pipeline/application/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: UTF-8 -*-
# @Time : 2020-05-28 22:36
# @File : db.py
# @Author : jian<jian@mltalker.com>
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import scoped_session
from antgo.pipeline.application import orm


def create_db(db_url, **db_kwargs):
session_factory = orm.new_session_factory(
db_url,
reset=False,
echo=False,
**db_kwargs
)
_scoped_session = scoped_session(session_factory)
return _scoped_session()
224 changes: 224 additions & 0 deletions antgo/pipeline/application/orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""sqlalchemy ORM tools for the state of the constellation of processes"""
import datetime
import json
from sqlalchemy.types import TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT,TEXT
from sqlalchemy import (
inspect,
Column, Integer, ForeignKey, Unicode, Boolean,
DateTime
)
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import FLOAT
from sqlalchemy import JSON
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.pool import StaticPool
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import QueuePool
from sqlalchemy.sql.expression import bindparam
from sqlalchemy import create_engine, Table
from sqlalchemy import LargeBinary
from sqlalchemy.types import JSON
from sqlalchemy.types import TypeDecorator, VARCHAR

from sqlalchemy import and_, or_
from sqlalchemy.orm import backref
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.pool import SingletonThreadPool
import time

Base = declarative_base()
Base.log = app_log


class APIToken(Base):
"""An API token"""
__tablename__ = 'api_tokens'

@declared_attr
def user_id(cls):
return Column(Integer, ForeignKey('users.id', ondelete="CASCADE"), nullable=True)
user = relationship('User', back_populates="api_token")

id = Column(Integer, primary_key=True)
hashed = Column(Unicode(1024))
prefix = Column(Unicode(1024))
prefix_length = 4
algorithm = "sha512"
rounds = 16384
salt_bytes = 8
create_time = Column(DateTime, default=datetime.datetime.now)
expire_time = Column(DateTime, default=datetime.datetime.now)

@property
def token(self):
raise AttributeError("token is write-only")

@token.setter
def token(self, token):
"""Store the hashed value and prefix for a token"""
self.prefix = token[:self.prefix_length]
self.hashed = hash_token(token, rounds=self.rounds, salt=self.salt_bytes, algorithm=self.algorithm)

def __repr__(self):
kind = ''
name = ''
if self.user is not None:
kind = 'user'
name = self.user.name
else:
# this shouldn't happen
kind = 'owner'
name = 'unknown'

return "<{cls}('{pre}...', {kind}='{name}')>".format(
cls=self.__class__.__name__,
pre=self.prefix,
kind=kind,
name=name,
)

@classmethod
def find(cls, db, token, *, kind=None):
"""Find a token object by value.
Returns None if not found.
`kind='user'` only returns API tokens for users
`kind='service'` only returns API tokens for services
"""
prefix = token[:cls.prefix_length]
# since we can't filter on hashed values, filter on prefix
# so we aren't comparing with all tokens
prefix_match = db.query(cls).filter(bindparam('prefix', prefix).startswith(cls.prefix))

if kind == 'user':
prefix_match = prefix_match.filter(cls.user_id != None)
elif kind is not None:
raise ValueError("kind must be 'user', 'service', or None, not %r" % kind)

for orm_token in prefix_match:
if orm_token.match(token):
return orm_token

def match(self, token):
"""Is this my token?"""
return compare_token(self.hashed, token)

@classmethod
def new(cls, token=None, user=None):
"""Generate a new API token for a user or service"""
# assert user or service or taskapp
# assert not (user and service and service)
db = None
db = inspect(user).session

if token is None:
token = new_token()
else:
if len(token) < 8:
raise ValueError("Tokens must be at least 8 characters, got %r" % token)
found = APIToken.find(db, token)
if found:
raise ValueError("Collision on token: %s..." % token[:4])

orm_token = APIToken(token=token)
assert user.id is not None
orm_token.user_id = user.id

orm_token.create_time = datetime.datetime.now()
db.add(orm_token)
db.commit()
return token


class Task(Base):
__tablename__ = 'task'
id = Column(Integer, primary_key=True)
task_name = Column(Unicode(50), default="", unique=True)

task_progress = Column(Unicode(2048), default="")
task_create_time = Column(DateTime, default=datetime.datetime.now)
task_stop_time = Column(DataTime, default=datetime.datetime.now)
task_is_finish = Column(Boolean, default=False)

@declared_attr
def user_id(cls):
return Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=True)
user = relationship('User', back_populates="task")

def __repr__(self):
return "<{cls}('{name}')>".format(
cls=self.__class__.__name__,
name=self.task_name
)


class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(Unicode(50), unique=True) # 用户名
password = Column(Unicode(4096), default="") # 用户密码

task = relationship('Task',
back_populates='user',
cascade="all,delete, delete-orphan") # 用户创建的项目

login_time = Column(DateTime, default=datetime.datetime.now)
last_login_time = Column(DateTime, default=datetime.datetime.now)

api_token = relationship("APIToken", back_populates="user")
api_token_str = Column(Unicode(4096), default="")
cookie_id = Column(Unicode(1023), default="")

@classmethod
def find(cls, db, name):
"""Find a user by name.
Returns None if not found.
"""
return db.query(cls).filter(cls.name == name).first()

def new_api_token(self, token=None, reset=True):
"""Create a new API token
If `token` is given, load that token.
"""
token_str = APIToken.new(token=token, user=self)
if reset:
self.api_token_str = token_str
db = inspect(self).session
db.commit()
return token_str

def __repr__(self):
return "<{cls}('{name}')>".format(
cls=self.__class__.__name__,
name=self.name
)


def new_session_factory(url="sqlite:///:memory:", reset=False, **kwargs):
"""Create a new session at url"""
if url.startswith('sqlite'):
kwargs.setdefault('connect_args', {'check_same_thread': False})
kwargs.setdefault('poolclass', NullPool)

if url.startswith('mysql'):
kwargs.setdefault('pool_recycle', 3600)
kwargs.setdefault('pool_size', 10)

if url.endswith(':memory:'):
# If we're using an in-memory database, ensure that only one connection
# is ever created.
kwargs.setdefault('poolclass', NullPool)

engine = create_engine(url, **kwargs)

if reset:
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

session_factory = sessionmaker(bind=engine)
return session_factory
28 changes: 28 additions & 0 deletions antgo/pipeline/application/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: UTF-8 -*-
# @Time : 2024/11/27 22:42
# @File : user.py
# @Author : jian<jian@mltalker.com>
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from antvis.client.httprpc import *
from antgo import config
import os
import cv2
import base64
import numpy as np


# 单点认证,认证通过执行,不通过暂停管线执行退出
class CasAuth(object):
def __init__(self):
pass

def __call__(self, *args):
# input
# ticket, token
#

# output
# user
pass
10 changes: 10 additions & 0 deletions antgo/pipeline/engine/execution/base_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def __apply__(self, *arg, **kws):
self._op._index = self._index
if isinstance(self._op, RemoteApiOp):
self._op._index = self._index

if hasattr(self._op, '_info'):
info = self._op.info()
for key in self._op._info():
kws.update({key: getattr(arg[0], key, None)})
return self._op(*args, **kws)

def __call__(self, *arg, **kws):
Expand Down Expand Up @@ -68,6 +73,11 @@ def __call__(self, *arg, **kws):
self._op._index = self._index
if isinstance(self._op, RemoteApiOp):
self._op._index = self._index

if hasattr(self._op, '_info'):
info = self._op.info()
for key in self._op._info():
kws.update({key: getattr(arg[0], key, None)})
res = self._op(*arg, **kws)
return res
except Exception:
Expand Down
5 changes: 3 additions & 2 deletions antgo/pipeline/functional/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

__graph_info = {}

def add_op_info(op_name, op_index, op_args, op_kwargs):
def add_op_info(op_name, op_index, op_args, op_kwargs, op_config=None):
global __graph_info
if 'op' not in __graph_info:
__graph_info['op'] = []
Expand All @@ -18,7 +18,8 @@ def add_op_info(op_name, op_index, op_args, op_kwargs):
'op_name': op_name, # 算子名称
'op_index': op_index, # 上下游数据流转
'op_args': op_args, # 算子参数
'op_kwargs': op_kwargs # 算子参数
'op_kwargs': op_kwargs, # 算子参数
'op_config': op_config
})


Expand Down
6 changes: 3 additions & 3 deletions antgo/pipeline/functional/data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def wrapper(*arg, **kws):
# pylint: disable=protected-access
path = hp._name
index = hp._index

# 添加算子节点信息到全局配置
add_op_info(path, index, arg, kws)
# 解析算子对象
op = self.resolve(path, index, *arg, **kws)
# 添加算子节点信息到全局配置
add_op_info(path, index, arg, kws, op._config() if getattr(op, '_config', None) else None)
return self.map(op)

return getattr(wrapper, name)
Expand Down
Loading

0 comments on commit 431919e

Please sign in to comment.