diff --git a/core/src.py b/core/src.py index 1773dc9..db62725 100644 --- a/core/src.py +++ b/core/src.py @@ -18,7 +18,7 @@ def __init__(self, proxy_port: str ): self._project_path: str = '' - self._base_path: str = '' + self._base_path: str = '' # 注意批量的时候要判断下host,不能直接用base_path self.all_paths:Union[Dict] = {} self._all_projects: Union[List, None] = None self._proxy_ip: str = proxy_ip @@ -220,8 +220,6 @@ def run(args) -> None: elif args.url_file: url_list=read_urls(args.url_file) - auto_exploit_swagger = AutoExploitSwagger(args.proxy_ip, args.proxy_port) - all_urls = {} for url in url_list: - all_urls.update(auto_exploit_swagger.get_all_urls(url)) - exploit_threads(auto_exploit_swagger, all_urls, args.exploit_threads) + auto_exploit_swagger = AutoExploitSwagger(args.proxy_ip, args.proxy_port) + exploit_threads(auto_exploit_swagger, auto_exploit_swagger.get_all_urls(url), args.exploit_threads) diff --git a/lib/common.py b/lib/common.py index 9c60a1c..71624c4 100644 --- a/lib/common.py +++ b/lib/common.py @@ -28,10 +28,16 @@ def banner(): def get_base_path(target_url: str) -> str: domain = urlparse(target_url) - domain = domain.scheme + "://" + domain.netloc + scheme = domain.scheme + domain = scheme + "://" + domain.netloc base_path = '' + host = '' try: res = json.loads(requests.get(url=target_url, timeout=5, verify=False).text) + if "host" in res.keys(): + host = scheme + "://" + res['host'] + else: + host = domain if "basePath" in res.keys(): base_path = res['basePath'] elif "servers" in res.keys(): @@ -41,4 +47,4 @@ def get_base_path(target_url: str) -> str: except Exception as e: logger.error("target_url timeout...") logger.error(e) - return (domain + base_path).rstrip('/') + return (host + base_path).rstrip('/') diff --git a/start.py b/start.py index 65d62f2..fb2a313 100644 --- a/start.py +++ b/start.py @@ -7,7 +7,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='Swagger API 自动化扫描工具') - parser.add_argument("-u", "--url", dest='target_url', required=True, help="swagger api地址") + parser.add_argument("-u", "--url", dest='target_url', help="swagger api地址") parser.add_argument("-i", "--ip", dest='proxy_ip', default='127.0.0.1', help="proxy ip") parser.add_argument("-p", "--port", dest='proxy_port', default='7777', help="proxy port") parser.add_argument("-t", "--threads", dest='exploit_threads', default=10, help="线程数目,默认是10")