diff --git a/main.py b/main.py
index dd3ac95..a62f918 100644
--- a/main.py
+++ b/main.py
@@ -52,7 +52,7 @@ def get_face_url(code):
 
 class AutoCrawler:
     def __init__(self, skip_already_exist=True, n_threads=4, do_google=True, do_naver=True, download_path='download',
-                 full_resolution=False, face=False, no_gui=False):
+                 full_resolution=False, face=False, no_gui=False, limit=0):
         """
         :param skip_already_exist: Skips keyword already downloaded before. This is needed when re-downloading.
         :param n_threads: Number of threads to download.
@@ -62,6 +62,7 @@ def __init__(self, skip_already_exist=True, n_threads=4, do_google=True, do_nave
         :param full_resolution: Download full resolution image instead of thumbnails (slow)
         :param face: Face search mode
         :param no_gui: No GUI mode. Acceleration for full_resolution mode.
+        :param limit: Maximum count of images to download. (0: infinite)
         """
 
         self.skip = skip_already_exist
@@ -72,6 +73,7 @@ def __init__(self, skip_already_exist=True, n_threads=4, do_google=True, do_nave
         self.full_resolution = full_resolution
         self.face = face
         self.no_gui = no_gui
+        self.limit = limit
 
         os.makedirs('./{}'.format(self.download_path), exist_ok=True)
 
@@ -158,13 +160,20 @@ def base64_to_object(src):
         data = base64.decodebytes(bytes(encoded, encoding='utf-8'))
         return data
 
-    def download_images(self, keyword, links, site_name):
+    def download_images(self, keyword, links, site_name, max_count=0):
         self.make_dir('{}/{}'.format(self.download_path, keyword.replace('"', '')))
         total = len(links)
+        success_count = 0
+
+        if max_count == 0:
+            max_count = total
 
         for index, link in enumerate(links):
+            if success_count >= max_count:
+                break
+
             try:
-                print('Downloading {} from {}: {} / {}'.format(keyword, site_name, index + 1, total))
+                print('Downloading {} from {}: {} / {}'.format(keyword, site_name, success_count + 1, max_count))
 
                 if str(link).startswith('data:image/jpeg;base64'):
                     response = self.base64_to_object(link)
@@ -183,12 +192,14 @@ def download_images(self, keyword, links, site_name):
                 path = no_ext_path + '.' + ext
                 self.save_object_to_file(response, path, is_base64=is_base64)
 
+                success_count += 1
                 del response
 
                 ext2 = self.validate_image(path)
                 if ext2 is None:
                     print('Unreadable file - {}'.format(link))
                     os.remove(path)
+                    success_count -= 1
                 else:
                     if ext != ext2:
                         path2 = no_ext_path + '.' + ext2
@@ -229,7 +240,7 @@ def download_from_site(self, keyword, site_code):
             links = []
 
         print('Downloading images from collected links... {} from {}'.format(keyword, site_name))
-        self.download_images(keyword, links, site_name)
+        self.download_images(keyword, links, site_name, max_count=self.limit)
 
         print('Done {} : {}'.format(site_name, keyword))
 
@@ -328,6 +339,7 @@ def imbalance_check(self):
     parser.add_argument('--no_gui', type=str, default='auto', help='No GUI mode. Acceleration for full_resolution mode. '
                                                                    'But unstable on thumbnail mode. '
                                                                    'Default: "auto" - false if full=false, true if full=true')
+    parser.add_argument('--limit', type=int, default=0, help='Maximum count of images to download per site. (0: infinite)')
     args = parser.parse_args()
 
     _skip = False if str(args.skip).lower() == 'false' else True
@@ -336,6 +348,7 @@ def imbalance_check(self):
     _naver = False if str(args.naver).lower() == 'false' else True
     _full = False if str(args.full).lower() == 'false' else True
     _face = False if str(args.face).lower() == 'false' else True
+    _limit = int(args.limit)
 
     no_gui_input = str(args.no_gui).lower()
     if no_gui_input == 'auto':
@@ -345,10 +358,10 @@
     else:
         _no_gui = False
 
-    print('Options - skip:{}, threads:{}, google:{}, naver:{}, full_resolution:{}, face:{}, no_gui:{}'
-          .format(_skip, _threads, _google, _naver, _full, _face, _no_gui))
+    print('Options - skip:{}, threads:{}, google:{}, naver:{}, full_resolution:{}, face:{}, no_gui:{}, limit:{}'
+          .format(_skip, _threads, _google, _naver, _full, _face, _no_gui, _limit))
 
     crawler = AutoCrawler(skip_already_exist=_skip, n_threads=_threads, do_google=_google, do_naver=_naver, full_resolution=_full,
-                          face=_face, no_gui=_no_gui)
+                          face=_face, no_gui=_no_gui, limit=_limit)
 
     crawler.do_crawling()
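
A minimal usage sketch of the new option (assuming main.py is still invoked directly and every other flag keeps its default): on the command line, python main.py --limit 100 caps each site at 100 images per keyword, while --limit 0 keeps the previous unlimited behaviour. The same cap can be set programmatically through the constructor argument added above:

    # Sketch only: the arguments mirror the signature changed in this diff.
    from main import AutoCrawler

    crawler = AutoCrawler(skip_already_exist=True, n_threads=4, do_google=True, do_naver=True,
                          full_resolution=False, face=False, no_gui=False, limit=100)
    crawler.do_crawling()

download_from_site forwards self.limit as max_count, max_count == 0 falls back to len(links), and success_count is decremented whenever validate_image rejects a saved file, so the cap counts only images that survive validation.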