Explicitly specify an empty HF cache during testing of offline load #106
Conversation
Which Dependabot PR do you mean here? I didn't merge this one in as the test failed (#94) and it generally felt like a bit of toil I'd want us to avoid getting into.
Yes, that's what I meant. But it also failed on the ….
The weird thing is that it works fine locally - both on 3.10 and 3.11 (haven't tried 3.12). Though I (somewhat obviously) don't run these things in parallel.
Are we able somehow to narrow down exactly which network calls got made when they shouldn't? For example, in the service I noticed Gradio phones home, as they're nice enough to log that (I will disable that...). My guess is some other libs could do the same a bit randomly? If we are sure it's transformers, it would be cool to get that assertion in the test anyway.
I guess you're saying that if we had a network call, we wouldn't necessarily know that it was HF that made the call. And I suppose that's true. But the way it currently fails is that there is no network call attempted when I would have expected one to be attempted. The test currently tests the following things:
And the current issue is with the 3rd item from the above. There are no logged attempts for a network call.
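For illustration, a rough pytest-style sketch of what that "network call attempted (but refused)" check can look like - this is not the actual test in this repo, and `load_model_from` is a hypothetical placeholder for the real deserialisation call:

```python
import socket

import pytest


class NetworkAttemptRecorder:
    """Refuses every outgoing connection, but records that one was attempted."""

    def __enter__(self):
        self.attempts = 0
        self._orig_connect = socket.socket.connect

        def refuse(sock, addr):
            self.attempts += 1
            raise OSError("network disabled for this test")

        socket.socket.connect = refuse
        return self

    def __exit__(self, *exc):
        socket.socket.connect = self._orig_connect
        return False


def test_offline_load_attempts_download(tmp_path):
    with NetworkAttemptRecorder() as net:
        # Placeholder for the real call that deserialises the saved model.
        with pytest.raises(Exception):
            load_model_from(tmp_path / "saved_model")  # hypothetical helper
    # The assertion that currently fails in CI: no attempt was ever recorded.
    assert net.attempts > 0, "expected a (refused) network call"
```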
Right, yep, it's saying nothing got called! I misread the assert message :D
Force-pushed from c4f1061 to 677cfd4.
# in such a situation
@contextmanager
def force_hf_download():
    with tempfile.TemporaryDirectory() as temp_dir:
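For context, a sketch of what such a context manager could look like, assuming it points the Hugging Face cache at an empty temporary directory so nothing can be served from a warm cache (the actual implementation in this PR may differ):

```python
import os
import tempfile
from contextlib import contextmanager
from unittest import mock


@contextmanager
def force_hf_download():
    """Redirect the HF cache to an empty temp dir so a load must hit the network."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # HF_HOME and TRANSFORMERS_CACHE control where transformers looks for
        # cached weights. Note that some libraries read these at import time,
        # so in practice this may need to take effect before transformers is imported.
        with mock.patch.dict(os.environ, {
            "HF_HOME": temp_dir,
            "TRANSFORMERS_CACHE": os.path.join(temp_dir, "transformers"),
        }):
            yield temp_dir
```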
Ok, I'm officially out of my depth here, it feels like.
- "Forces the from_pretrained method to use force_download=True during the test" - yep, I get this one.
- "Asserts that a network call is attempted (but refused)" - yep, I see the assert for sure.
But when do you actually expect it to call transformers.BertModel.from_pretrained inside the code? I just see serialize to dill then deserialize, but the deserialize doesn't call from_pretrained, at least not in itself...
My gut feeling is that it should do something like "mc = deserialise(...); now do something with mc which calls from_pretrained", if this is anything like Java/C# serialization anyway - like the saved file is just the object state, loading it back won't trigger a constructor or anything.
But anyway - I'm sure there's a magic line that's likely to be found by people that know what they're doing :D
Forgot to press send on the above...
Looking at the new lines around the wait for async - I'd hope there's some function like "do something on mc, that waits until it's ready internally". Feels like there should be some way to use the deserialise function and rely on the async calls having finished, else anyone using this (not in a test) would be equally stuck.
Would always want to avoid the waits as it implies an underlying design issue, especially if we can fix it inside the library itself.
Within deserialisation it'll get to MetaCATAddon.load_existing, which in turn calls the MetaCATAddon.load method. That deserialises the underlying object, which should then call MetaCAT.__init__. And that calls MetaCAT.get_model, which (in case of a Bert-based MetaCAT like here) inits BertForMetaAnnotation. And that finally calls BertModel.from_pretrained.
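As a condensed sketch of that chain (heavily simplified, not the actual medcat source; `config.model_name` stands in for whatever the real config field is):

```python
from transformers import BertModel


class BertForMetaAnnotation:
    def __init__(self, config):
        # This is the step that needs the weights and can hit the network
        # when nothing is cached.
        self.bert = BertModel.from_pretrained(config.model_name)


class MetaCAT:
    def __init__(self, config):
        self.config = config
        self.model = self.get_model(config)

    def get_model(self, config):
        # For a Bert-based MetaCAT this builds the Bert wrapper, which in
        # turn calls BertModel.from_pretrained.
        return BertForMetaAnnotation(config)
```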
The thing is, this is built on a custom serialisation that is designed to trigger the init. The trivial implementation of pickling stuff would indeed avoid the calls to __init__. But we don't really want that since:
- some things break when you do that
- we don't want to save everything to disk in the same format that it is in memory (i.e. some things know how to save their bits better than I do)
- this allows us to be more backwards compatible in terms of loading older models (otherwise pickling would preserve the state of the class as well - not just its attributes)
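Roughly the distinction being described, as a toy sketch (not the actual medcat serialisation code): plain pickling restores attributes without running __init__, whereas the custom scheme persists only selected state and rebuilds the object through its constructor on load, so any load-time work (like from_pretrained) runs again.

```python
import pickle


class Component:
    def __init__(self, name):
        self.name = name
        # Load-time side effects live here (e.g. fetching model weights);
        # plain unpickling would skip them entirely.
        self.ready = True


def save_custom(obj, path):
    # Persist only the chosen bits of state, not the whole object.
    with open(path, "wb") as f:
        pickle.dump({"name": obj.name}, f)


def load_custom(path):
    with open(path, "rb") as f:
        state = pickle.load(f)
    # Rebuild via __init__, triggering the constructor-time work again.
    return Component(**state)
```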
> Looking at the new lines around the wait for async - I'd hope there's some function like "do something on mc, that waits until it's ready internally". Feels like there should be some way to use the deserialise function and rely on the async calls having finished, else anyone using this (not in a test) would be equally stuck.
> Would always want to avoid the waits as it implies an underlying design issue, especially if we can fix it inside the library itself.

I'm pretty sure this is an async issue because (when testing locally) the network call is done from the same process, but on a different thread.
Now, this isn't anything we've designed, it's something on the transformers side. I don't know this for certain, but my best guess is that they will wait for completion if/when the bit that's being downloaded is needed. Because - like you said - otherwise people would fail to use a model they've initialised / loaded.
Can we then call something on mc? Basically assert it works, not just that it's the right class:

mc = deserialise(self.mc_save_path)
cat = CAT(...)
cat.add_addon(mc)
result = cat.get_entities(...)
# assert result just to confirm it's worked
# (noting the real assertion is that this also hasn't made any network calls)

Feels like:
- Either the above works and something somewhere waits for the async calls
- OR no user can technically use this without risking it not being ready when they call get_entities
That's certainly a good idea!
I can have it run through a document and that should work if my assumption about the lazy loading from above is correct.
Added something to the workflow that runs through a document and an entity.
But it still fails.
So clearly it's actually getting the model from somewhere. And it does so without doing a network call (at least not in the way I'm guarding against).
EDIT: And now it bloody passes....
Alright - it looks like I finally figured out a way to mitigate this. By setting a different model variant before saving the model, I'm able to force the subsequent network call upon model loading off disk. My best guess is that the reason the test sometimes does and sometimes doesn't work otherwise is some sort of HF in-memory cache. Because other parts of the test suite will inevitably have used the same model, so it may be cached somewhere. But in some cases the cache is already invalidated by the time it comes to this test (which is why it intermittently does succeed).
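Roughly the idea, as a hypothetical pseudocode sketch - the config field, helper names and model variant below are illustrative, not necessarily medcat's actual API:

```python
# Build the MetaCAT addon however the test normally does it (hypothetical helper).
mc = build_meta_cat()

# Point it at a variant no other test in the suite uses, so neither the HF
# on-disk cache nor any in-memory cache can satisfy the later load.
mc.config.model.model_variant = "prajjwal1/bert-small"
serialise(mc, save_path)

# Deserialising now has to fetch weights for an uncached variant, so
# BertModel.from_pretrained must attempt a (refused) network call.
with force_hf_download():
    deserialise(save_path)
```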
I knew it would be one magic line!
Makes a lot of sense that there's some extra caching behaviour we don't know about.
Thanks for explaining a lot of this one to me, it's been good to learn about the internals here.
It seems like the test introduced by the offline BERT MetaCAT load (#85) is failing on (some) subsequent PRs.
Examples of failure:
https://github.com/CogStack/cogstack-nlp/actions/runs/17092576831/job/48478029082
The workflow seems to have run fine on the PR itself and subsequently upon merge. So perhaps the failure is due to the changes in the dependabot PR that updates the actions? Perhaps the (new versions of the) actions use a more comprehensive caching system across multiple runners? I.e. something that downloads on one runner and uses the cached version on the other, regardless of me forcing the download through the API?
In any case, I wanted the test to make sure that the offline loading does in fact work. And in order to do that I need to make sure the cache isn't used. So this PR tries to force that behaviour.
EDIT:
Found another failure:
https://github.com/CogStack/cogstack-nlp/actions/runs/17092444219/job/48469017414
So clearly not (only) to do with the dependabot PR.