
feat: HuggingFace integration #286

Open
thomaspinder opened this issue Jun 1, 2023 · 3 comments
Labels: documentation (Improvements or additions to documentation), enhancement (New feature or request), no-stale
Comments

@thomaspinder
Collaborator

Feature Request

Demonstrate how a GPJax Dataset may be used to hold a HuggingFace dataset.

Describe Preferred Solution

A generic method that can coerce a HuggingFace dataset to a GPJax dataset would be ideal. If this is infeasible, a notebook demonstrating the workflow would also be most welcome.

Tagging @ingmarschuster who has developed a prototype of this.

@thomaspinder added the documentation and enhancement labels on Jun 1, 2023
@thomaspinder added this to the v1.0.0 milestone on Jun 1, 2023
@ingmarschuster
Contributor

ingmarschuster commented Jul 10, 2023

If we could get HuggingFace datasets to be PyTrees directly, that would be the ideal, I guess. I haven't found out how to do this (maybe because I am not an expert on PyTrees). However, one can easily "view" HuggingFace datasets as dictionaries:

import datasets as ds
import jax.numpy as jnp

# dummy dataset for the demo
dat = ds.Dataset.from_dict({"X": jnp.arange(5), "y": jnp.arange(5) + 1})

# return indexed/batched rows as jax.Array values
dat.set_format("jax")

Now dat[0:3], and any (automatic) batching by the HuggingFace library, returns dicts like {'X': Array([0, 1, 2], dtype=int32), 'y': Array([1, 2, 3], dtype=int32)}, and one can take advantage of the underlying super-fast Apache Parquet/Apache Arrow storage.

An alternative would be to implement some function gpx.hfdata_to_dataset_fn that takes a huggingface dataset and turns it into the corresponding GPJax Dataset object, then use

dat.set_transform(gpx.hfdata_to_dataset_fn)
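A sketch of what such a conversion function could look like. Everything here is hypothetical: `hfdata_to_dataset_fn` is the name proposed above, the `Dataset` class is a minimal NamedTuple stand-in for `gpx.Dataset`, and NumPy stands in for `jax.numpy` to keep the sketch self-contained:

```python
from typing import NamedTuple

import numpy as np


class Dataset(NamedTuple):
    """Minimal stand-in for gpx.Dataset: X is (n, d), y is (n, 1)."""

    X: np.ndarray
    y: np.ndarray


def hfdata_to_dataset_fn(batch: dict) -> Dataset:
    """Coerce a HuggingFace-style dict batch into a Dataset."""
    X = np.asarray(batch["X"])
    if X.ndim == 1:
        X = X[:, None]  # promote a flat feature vector to a column
    y = np.asarray(batch["y"]).reshape(-1, 1)
    return Dataset(X=X, y=y)
```

With something like this, `dat.set_transform(hfdata_to_dataset_fn)` would make indexing the HuggingFace dataset yield GPJax-style Dataset objects directly.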

@ingmarschuster
Contributor

When making this change, we could introduce the possibility to use dicts for Dataset.X (and potentially Dataset.y), which blends nicely with the semantic naming of input (output) dimensions as done in HuggingFace datasets/pandas DataFrames.
To do this elegantly, we would let kernels store which keys they are responsible for. For example, the kernel
RBF(features=["LongLat"]) + PoweredExponential(features=["SecondsSince1970"])
would use an RBF on data.X['LongLat'] and a PoweredExponential on data.X['SecondsSince1970']. This is a more natural alternative to the active_dims parameter in kernel constructors.
We can automatically check that all the features needed by a model are present by collecting the list of features required by the combined kernel.
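The feature-keyed kernel idea above could be sketched roughly as follows. All names here are hypothetical (not GPJax API), NumPy stands in for jax.numpy, and a single squared-exponential function stands in for both kernel types just to keep the sketch short; data.X is assumed to be a dict mapping feature names to (n, d_k) arrays:

```python
import numpy as np


def sqexp(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel on (n, d) inputs -> (n, n) Gram matrix."""
    diff = x1[:, None, :] - x2[None, :, :]
    return np.exp(-0.5 * np.sum(diff**2, axis=-1) / lengthscale**2)


class FeatureKernel:
    """Wraps a base kernel and records which dict keys it acts on."""

    def __init__(self, base_fn, features):
        self.base_fn = base_fn
        self.features = list(features)

    def required_features(self):
        return set(self.features)

    def __call__(self, X1, X2):
        # select and concatenate only the named feature blocks
        a = np.concatenate([X1[f] for f in self.features], axis=-1)
        b = np.concatenate([X2[f] for f in self.features], axis=-1)
        return self.base_fn(a, b)

    def __add__(self, other):
        return SumKernel(self, other)


class SumKernel:
    """Sum of two feature-keyed kernels."""

    def __init__(self, k1, k2):
        self.k1, self.k2 = k1, k2

    def required_features(self):
        # the combined kernel needs the union of its parts' features
        return self.k1.required_features() | self.k2.required_features()

    def __call__(self, X1, X2):
        return self.k1(X1, X2) + self.k2(X1, X2)

    def __add__(self, other):
        return SumKernel(self, other)


# feature-keyed data, as suggested above
X = {"LongLat": np.zeros((4, 2)), "SecondsSince1970": np.ones((4, 1))}
kernel = FeatureKernel(sqexp, ["LongLat"]) + FeatureKernel(sqexp, ["SecondsSince1970"])
```

The automatic check then falls out for free: `kernel.required_features() <= set(X)` verifies that the data provides every feature the combined kernel needs.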


github-actions bot commented Sep 1, 2024

There has been no recent activity on this issue. To keep our issues log clean, we remove old and inactive issues.
Please update to the latest version of GPJax and check if that resolves the issue. Let us know if that works for you by leaving a comment.
This issue is now marked as stale and will be closed if no further activity occurs. If you believe that this is incorrect, please comment. Thank you!
