diff --git a/mpcontribs-client/mpcontribs/client/__init__.py b/mpcontribs-client/mpcontribs/client/__init__.py
index 3a1daefc4..9e42291d8 100644
--- a/mpcontribs-client/mpcontribs/client/__init__.py
+++ b/mpcontribs-client/mpcontribs/client/__init__.py
@@ -61,6 +61,7 @@
 }
 SUPPORTED_FILETYPES = (Gz, Jpeg, Png, Gif, Tiff)
 SUPPORTED_MIMES = [t().mime for t in SUPPORTED_FILETYPES]
+DEFAULT_DOWNLOAD_DIR = Path.home() / "mpcontribs-downloads"
 j2h = Json2Html()
 pd.options.plotting.backend = "plotly"
@@ -428,6 +429,11 @@ def _is_valid_payload(self, model: str, data: dict):
 
         return True
 
+    def get_project_names(self) -> list:
+        """Retrieve list of project names."""
+        resp = self.projects.get_entries(_fields=["name"]).result()
+        return [p["name"] for p in resp["data"]]
+
     def get_project(self, name: str) -> Type[Dict]:
         """Retrieve full project entry
@@ -610,8 +616,7 @@ def init_columns(self, name: str, columns: dict) -> dict:
 
             existing_columns.add(k)
 
-        resp = self.projects.get_entries(_fields=["name"]).result()
-        valid_projects = {p["name"] for p in resp["data"]}
+        valid_projects = self.get_project_names()
 
         if name not in valid_projects:
             return {"error": f"{name} doesn't exist or you don't have access!"}
@@ -691,7 +696,8 @@ def delete_contributions(
             max_workers = MAX_WORKERS
             print(f"max_workers reset to max {MAX_WORKERS}")
 
-        cids = self.get_contributions(name)["ids"]
+        query = dict(project=name)
+        cids = self.get_all_ids(query=query).get(name, {}).get("ids", [])
         total = len(cids)
         # reset columns to be safe (sometimes not all are reset; BUGFIX?)
         self.projects.update_entry(pk=name, project={"columns": []}).result()
@@ -713,7 +719,7 @@ def delete_contributions(
             ]
             self._run_futures(futures, total=len(cids))
-            cids = self.get_contributions(name)["ids"]
+            cids = self.get_all_ids(query=query).get(name, {}).get("ids", [])
 
             if not retry:
                 break
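The new `get_project_names()` helper centralizes the project-name lookup that `init_columns()` previously inlined, and `delete_contributions()` now resolves contribution IDs through the query-based `get_all_ids()` introduced below. A minimal usage sketch, assuming the package's `Client` entry point (API key supplied via argument or environment) and a hypothetical `sandbox` project:

```python
from mpcontribs.client import Client

client = Client()  # assumes an API key is configured via argument or environment
names = client.get_project_names()

if "sandbox" in names:  # hypothetical project; same check init_columns() performs
    client.delete_contributions("sandbox")
```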
@@ -732,57 +738,68 @@ def delete_contributions(
 
     @sleep_and_retry
     @limits(calls=175, period=60)
-    def get_unique_identifiers_flag(self, name: str) -> bool:
-        """Retrieve value of `unique_identifiers` flag for a project
+    def get_totals(self, query: dict = None, per_page: int = 1) -> tuple:
+        """Retrieve total count and pages for contributions matching query
 
-        Args:
-            name (str): name of the project
-        """
-        return self.projects.get_entry(
-            pk=name, _fields=["unique_identifiers"]
-        ).result()["unique_identifiers"]
-
-    @sleep_and_retry
-    @limits(calls=175, period=60)
-    def get_total_pages(self, name: str, per_page: int) -> int:
-        """Retrieve total number of pages for contributions in a project
+        See `client.contributions.get_entries()` for query.
 
         Args:
-            name (str): name of the project
+            query (dict): query to select contributions
             per_page (int): number of contributions per page
         """
-        return self.contributions.get_entries(
-            project=name, per_page=per_page, _fields=["id"],
-        ).result()["total_pages"]
+        query = query or {}
+        [query.pop(k, None) for k in ["per_page", "_fields"]]
+        result = self.contributions.get_entries(
+            per_page=per_page, _fields=["id"], **query
+        ).result()
+        return result["total_count"], result["total_pages"]
 
-    def get_contributions(self, name: str) -> dict:
-        """Retrieve a list of existing contributions and their components for a project
+    def get_all_ids(self, query: dict, include: list = None) -> dict:
+        """Retrieve a list of existing contribution and component ObjectIds
 
         Args:
-            name (str): name of the project
+            query (dict): query to select contributions
+            include (list): components to include in response
+
+        Returns:
+            {"<project-name>": {
+                "ids": {<set of contribution ObjectIds>},
+                "identifiers": {<set of contribution identifiers>},
+                "structures": {<set of structure md5s>},
+                "tables": {<set of table md5s>},
+                "attachments": {<set of attachment md5s>},
+            }, ...}
         """
-        ret = defaultdict(set)
-        ret["unique_identifiers"] = self.get_unique_identifiers_flag(name)
-        pages = self.get_total_pages(name, 250)
-        id_fields = ["id", "identifier"]
-
-        @sleep_and_retry
-        @limits(calls=175, period=60)
-        def get_future(page):
-            future = session.get(
-                f"{self.url}/contributions/",
-                headers=self.headers,
-                params={
-                    "project": name,
-                    "page": page,
-                    "per_page": 250,
-                    "_fields": ",".join(id_fields + COMPONENTS)
-                },
-            )
-            setattr(future, "track_id", page)
-            return future
+        include = include or []
+        components = set(x for x in include if x in COMPONENTS)
+        if include and not components:
+            print(f"`include` must be subset of {COMPONENTS}!")
+            return
+
+        ret, per_page = {}, 250
+        query = query or {}
+        [query.pop(k, None) for k in ["page", "per_page", "_fields"]]
+        _, pages = self.get_totals(query=query, per_page=per_page)
+        id_fields = {"project", "id", "identifier"}
+        fields = ",".join(id_fields | components)
+        url = f"{self.url}/contributions/"
+
+        # convert lists in query to comma-separated
+        for k, v in query.items():
+            if isinstance(v, list):
+                query[k] = ",".join(v)
 
         with FuturesSession(max_workers=MAX_WORKERS) as session:
+
+            @sleep_and_retry
+            @limits(calls=175, period=60)
+            def get_future(page):
+                params = {"page": page, "per_page": per_page, "_fields": fields}
+                params.update(query)
+                future = session.get(url, headers=self.headers, params=params)
+                setattr(future, "track_id", page)
+                return future
+
             # bravado future doesn't work with concurrent.futures
             futures = [get_future(page + 1) for page in range(pages)]
@@ -791,12 +808,20 @@ def get_future(page):
 
             for resp in responses.values():
                 for contrib in resp["data"]:
-                    ret["ids"].add(contrib["id"])
-                    ret["identifiers"].add(contrib["identifier"])
+                    project = contrib["project"]
+                    if project not in ret:
+                        ret[project] = {k: set() for k in ["ids", "identifiers"]}
+
+                    ret[project]["ids"].add(contrib["id"])
+                    ret[project]["identifiers"].add(contrib["identifier"])
+
+                    for component in components:
+                        if component in contrib:
+                            if component not in ret[project]:
+                                ret[project][component] = set()
 
-                for component in COMPONENTS:
-                    md5s = set(d["md5"] for d in contrib[component])
-                    ret[component] |= md5s
+                            md5s = set(d["md5"] for d in contrib[component])
+                            ret[project][component] |= md5s
 
             futures = [
                 future
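`get_totals()` and `get_all_ids()` replace the per-project helpers removed above: both accept a raw contributions query (`project`, `project__in`, `id__in`, ...), and `get_all_ids()` groups contribution ObjectIds, identifiers, and requested component md5s by project. A sketch of the new call pattern, with the same hypothetical setup as before:

```python
from mpcontribs.client import Client

client = Client()  # hypothetical setup as above

# total matching contributions and number of pages at 250 per page
total, pages = client.get_totals(query={"project": "sandbox"}, per_page=250)

all_ids = client.get_all_ids(
    query={"project__in": ["sandbox"]}, include=["structures"]
)
sandbox = all_ids.get("sandbox", {})
cids = sandbox.get("ids", set())         # contribution ObjectIds
md5s = sandbox.get("structures", set())  # structure md5s, since requested via include
```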
@@ -837,15 +862,6 @@ def update_contributions(self, name: str, data: dict, query: dict = None) -> dic
 
         print(f"Updated {updated} contributions.")
 
-    def get_number_contributions(self, **query) -> int:
-        """Retrieve total number of contributions for query
-
-        See `client.contributions.get_entries()` for keyword arguments used in query.
-        """
-        return self.contributions.get_entries(
-            _fields=["id"], _limit=1, **query
-        ).result()["total_count"]
-
     def publish(self, name: str, recursive: bool = False) -> dict:
         """Publish a project and optionally its contributions
@@ -872,7 +888,6 @@ def publish(self, name: str, recursive: bool = False) -> dict:
 
     def submit_contributions(
         self,
         contributions: list,
-        skip_dupe_check: bool = False,
         ignore_dupes: bool = False,
         retry: bool = False,
         per_page: int = 100,
@@ -902,7 +917,6 @@ def submit_contributions(
 
         Args:
             contributions (list): list of contribution dicts to submit
-            skip_dupe_check (bool): skip check for duplicates of identifiers and components
             ignore_dupes (bool): force duplicate components to be submitted
             retry (bool): keep trying until all contributions successfully submitted
             per_page (int): number of contributions to submit in each chunk/request
@@ -920,7 +934,6 @@ def submit_contributions(
 
         # get existing contributions
         existing = defaultdict(set)
-        existing["unique_identifiers"] = True
         project_names = set()
         collect_ids = []
         require_one_of = {"data"} | set(COMPONENTS)
@@ -944,20 +957,23 @@ def submit_contributions(
 
         id2project = {}
         if collect_ids:
-            resp = self.contributions.get_entries(
-                id__in=collect_ids, _fields=["id", "project"]
-            ).result()
-            id2project = {c["id"]: c["project"] for c in resp["data"]}
-            project_names |= set(id2project.values())
+            resp = self.get_all_ids(query=dict(id__in=collect_ids))
+            project_names |= set(resp.keys())
+
+            for project_name, values in resp.items():
+                for cid in values["ids"]:
+                    id2project[cid] = project_name
 
+        print("get existing contributions ...")
         project_names = list(project_names)
+        existing = self.get_all_ids(query=dict(project__in=project_names), include=COMPONENTS)
+        resp = self.projects.get_entries(
+            name__in=project_names, _fields=["name", "unique_identifiers"]
+        ).result()
 
-        if not skip_dupe_check:
-            print("get existing contributions ...")
-            existing = {
-                project_name: self.get_contributions(project_name)
-                for project_name in project_names
-            }
+        for project in resp["data"]:
+            project_name = project["name"]
+            existing[project_name]["unique_identifiers"] = project["unique_identifiers"]
 
         # prepare contributions
         print("prepare contributions ...")
@@ -1140,7 +1156,12 @@ def put_future(cdct):
                 self._run_futures(futures, total=ncontribs)
 
                 if existing[project_name]["unique_identifiers"] and retry:
-                    existing[project_name] = self.get_contributions(project_name)
+                    existing[project_name] = self.get_all_ids(
+                        query=dict(project=project_name), include=COMPONENTS
+                    ).get(project_name, {"identifiers": set()})
+                    existing[project_name]["unique_identifiers"] = self.projects.get_entry(
+                        pk=project_name, _fields=["unique_identifiers"]
+                    ).result()["unique_identifiers"]
                     contribs[project_name] = [
                         c for c in contribs[project_name]
                         if c["identifier"] not in existing[project_name]["identifiers"]
@@ -1151,9 +1172,9 @@ def put_future(cdct):
             print("Please resubmit failed contributions manually.")
 
         end = datetime.utcnow()
-        updated_total = self.get_number_contributions(
+        updated_total, _ = self.get_totals(query=dict(
             project__in=project_names, last_modified__gt=start, last_modified__lt=end
-        )
+        ))
         toc = time.perf_counter()
         dt = (toc - tic) / 60
         self._load()
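With `get_number_contributions()` and the `skip_dupe_check` flag removed, counting goes through `get_totals()` and `submit_contributions()` always checks for duplicate identifiers and components. A rough equivalent of the removed counting helper, again with a hypothetical project name:

```python
from mpcontribs.client import Client

client = Client()  # hypothetical setup as above
# was: client.get_number_contributions(project="sandbox")
count, _ = client.get_totals(query={"project": "sandbox"})
```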
"mpcontribs-downloads", + ids: list, + outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR, overwrite: bool = False, - include: list = [] + include: list = None ) -> Path: - """Download a single contribution as .json.gz file(s) + """Download a list of contributions as .json.gz file(s) Args: - cid: contribution ObjectId - outdir: optional existing output directory - overwrite: force re-download of existing contribution/components - include: components to include in download + ids: list of contribution ObjectIds + outdir: optional output directory + overwrite: force re-download + include: components to include in downloads """ + include = include or [] outdir = Path(outdir) or Path(".") outdir.mkdir(parents=True, exist_ok=True) - if any(x for x in include if x not in COMPONENTS): + components = set(x for x in include if x in COMPONENTS) + if include and not components: print(f"`include` must be subset of {COMPONENTS}!") return - include = set(x for x in include if x in COMPONENTS) - fields = ["project"] + list(include) - contrib = self.contributions.get_entry(pk=cid, _fields=fields).result() - components = [] + all_ids = self.get_all_ids(query=dict(id__in=ids), include=components) - if include: - proj = self.projects.get_entry(pk=contrib["project"], _fields=["columns"]).result() - for column in proj["columns"]: - path = column["path"] - if path.split(".", 1)[0] in include: - components.append(path) + for name, values in all_ids.items(): + cids = list(values["ids"]) + print(name, self._download_resource( + resource="contributions", ids=cids, outdir=outdir, overwrite=overwrite + )) - model = self.get_model("ContributionsSchema") - fields = list(k for k in model._properties.keys() if k not in COMPONENTS) + components + for component in components: + ids = list(values[component]) + print(name, self._download_resource( + resource=component, ids=ids, outdir=outdir, overwrite=overwrite + )) - # download contribution - subdir = outdir / cid - subdir.mkdir(exist_ok=True) - path = subdir / "contribution.json.gz" + def download_structures( + self, + ids: list, + outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR, + overwrite: bool = False + ) -> Path: + """Download a list of structures as a .json.gz file + + Args: + ids: list of structure ObjectIds + outdir: optional output directory + overwrite: force re-download + + Returns: + path of output file + """ + return self._download_resource( + resource="structures", ids=ids, outdir=outdir, overwrite=overwrite + ) + + def download_tables( + self, + ids: list, + outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR, + overwrite: bool = False + ) -> Path: + """Download a list of tables as a .json.gz file + + Args: + ids: list of table ObjectIds + outdir: optional output directory + overwrite: force re-download + + Returns: + path of output file + """ + return self._download_resource( + resource="tables", ids=ids, outdir=outdir, overwrite=overwrite + ) + + def download_attachments( + self, + ids: list, + outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR, + overwrite: bool = False + ) -> Path: + """Download a list of attachments as a .json.gz file + + Args: + ids: list of attachment ObjectIds + outdir: optional output directory + overwrite: force re-download + + Returns: + path of output file + """ + return self._download_resource( + resource="attachments", ids=ids, outdir=outdir, overwrite=overwrite + ) + + def _download_resource( + self, + resource: str, + ids: list, + outdir: Union[str, Path] = DEFAULT_DOWNLOAD_DIR, + overwrite: bool = False + ) -> Path: + """Helper to 
download a list of resources as .json.gz file + + Args: + resource: type of resource + ids: list of resource ObjectIds + outdir: optional output directory + overwrite: force re-download + + Returns: + path to output file + """ + # TODO chunk ids and paginate + resources = ["contributions"] + COMPONENTS + if resource not in resources: + print(f"`resource` must be one of {resources}!") + return + + outdir = Path(outdir) or Path(".") + subdir = outdir / resource + subdir.mkdir(parents=True, exist_ok=True) + digest = get_md5({"ids": ids}) + path = subdir / f"{digest}.json.gz" if not path.exists() or overwrite: - content = self.contributions.download_entries( - id=cid, short_mime="gz", format="json", _fields=fields + model = self.get_model(f"{resource.capitalize()}Schema") + fields = list(model._properties.keys()) + content = getattr(self, resource).download_entries( + id__in=ids, short_mime="gz", format="json", _fields=fields ).result() path.write_bytes(content) - print(path) - - # download components - for component in components: - path = subdir / f"{component}.json.gz" - if not path.exists() or overwrite: - ids = [x["id"] for x in contrib[component]] - model = self.get_model(f"{component.capitalize()}Schema") - fields = list(model._properties.keys()) - content = getattr(self, component).download_entries( - id__in=ids, short_mime="gz", format="json", _fields=fields - ).result() - path.write_bytes(content) - print(path) - - return subdir + + return path + def _run_futures(self, futures, total=None): """helper to run futures/requests"""
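The single-contribution download is replaced by per-resource helpers that all funnel through `_download_resource()`, which writes one `<outdir>/<resource>/<md5-of-ids>.json.gz` file per call instead of one directory per contribution. A download sketch with hypothetical ObjectIds and output directory:

```python
from pathlib import Path
from mpcontribs.client import Client

client = Client()  # hypothetical setup as above
cids = ["62064ed2f7f94e35deca5f8a"]  # hypothetical contribution ObjectIds

# downloads the contributions and their tables, grouped per project
client.download_contributions(cids, outdir=Path("/tmp/mpc"), include=["tables"])

# component helpers return the path of the written .json.gz file
path = client.download_structures(["62064ed2f7f94e35deca5f8b"], overwrite=True)
```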