[Taskflow] Fix the recognition bug of json format with both PIR suffix and id2label #10487

Open
wants to merge 1 commit into
base: develop

hanlintang
Contributor

PR types

Bug fixes

PR changes

APIs

Description

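Fixes a bug in loading static models with Taskflow: when the PIR model file and the id2label dictionary both use the .json suffix, Taskflow mistakenly reads id2label.json as the model file.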
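
To illustrate the kind of disambiguation described above, here is a minimal sketch, not the actual change in this PR. The function name `find_static_model_prefix`, the `AUXILIARY_JSON` set, and the rule "a model file must have a matching `.pdiparams` file" are all assumptions made for the example.

```python
from pathlib import Path
from typing import Optional

# Assumption: JSON files in an export directory that are metadata, not model programs.
AUXILIARY_JSON = {"id2label.json"}


def find_static_model_prefix(
    task_path: str,
    model_suffix: str = ".json",
    params_suffix: str = ".pdiparams",
) -> Optional[str]:
    """Return the file stem shared by the model file and its parameter file.

    Hypothetical helper: skips known auxiliary JSON files and only accepts a
    candidate that has a parameter file with the same stem next to it.
    Returns None when no such pair exists.
    """
    directory = Path(task_path)
    for candidate in sorted(directory.glob(f"*{model_suffix}")):
        if candidate.name in AUXILIARY_JSON:
            continue  # label dictionary, never a model program
        if candidate.with_suffix(params_suffix).exists():
            return candidate.stem
    return None
```

With a directory containing `id2label.json`, `model.json`, and `model.pdiparams`, this sketch returns `model` rather than `id2label`, which is the behavior the description asks for.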

Steps to reproduce:
After finishing multi_class text classification training, follow the documentation to run model prediction by calling Taskflow.

aistudio@jupyter-227232-8957468:~/PaddleNLP/slm/applications/text_classification/multi_class$ python
Python 3.8.10 (default, May 26 2023, 14:05:08) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from paddlenlp import Taskflow
/home/aistudio/.local/lib/python3.8/site-packages/_distutils_hack/__init__.py:26: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")
>>> cls = Taskflow("text_classification", task_path='checkpoint/export', is_static_model=True)
[2025-04-24 06:17:29,876] [ WARNING] - checkpoint/export includes more than one '.json' file.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/taskflow.py", line 869, in __init__
    self.task_instance = task_class(
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/text_classification.py", line 141, in __init__
    self._get_inference_model()
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/task.py", line 352, in _get_inference_model
    raise IOError(
OSError: checkpoint/export should include id2label.json and id2label.pdiparams while is_static_model is True

Renaming the model's .json file to id2label also raises an error:

>>> cls = Taskflow("text_classification", task_path='./checkpoint/export/static_model', is_static_model=True)
[2025-04-24 06:24:38,298] [ WARNING] - ./checkpoint/export/static_model includes more than one '.json' file.
W0424 06:24:38.350104 12434 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 12.6
W0424 06:24:38.354985 12434 gpu_resources.cc:164] device: 0, cuDNN Version: 9.5.
W0424 06:24:38.355019 12434 gpu_resources.cc:196] WARNING: device: 0. The installed Paddle is compiled with CUDA 12.6, but CUDA runtime version in your machine is 12.0, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDA version.
[2025-04-24 06:24:42,312] [    INFO] - Load id2label from ./checkpoint/export/static_model/id2label.json.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/taskflow.py", line 869, in __init__
    self.task_instance = task_class(
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/text_classification.py", line 142, in __init__
    self._construct_id2label()
  File "/home/aistudio/.local/lib/python3.8/site-packages/paddlenlp/taskflow/text_classification.py", line 249, in _construct_id2label
    self.id2label[int(i)] = id2label[i]
ValueError: invalid literal for int() with base 10: 'base_code'
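
The `ValueError` in the traceback above comes from converting a non-numeric key (`'base_code'`) to `int`, which suggests the file that was read is not a label map. As a hedged illustration only (the helper `load_id2label` below is hypothetical and not part of PaddleNLP), a loader could turn that into a clearer message:

```python
import json
from typing import Dict


def load_id2label(path: str) -> Dict[int, str]:
    """Load a label map whose JSON keys are integer ids stored as strings.

    Hypothetical helper: raises ValueError with an explicit message when the
    file does not look like an id2label dictionary.
    """
    with open(path, encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"{path} is not a JSON object, so it cannot be a label map")

    id2label = {}
    for key, label in raw.items():
        try:
            id2label[int(key)] = label
        except ValueError as exc:
            # A key such as 'base_code' means this is some other JSON file.
            raise ValueError(
                f"{path} does not look like a label map: key {key!r} is not an integer id"
            ) from exc
    return id2label
```

This does not fix the file mix-up itself; it only makes the failure easier to diagnose.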

After the fix, it runs correctly:

Python 3.8.10 (default, May 26 2023, 14:05:08) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from paddlenlp import Taskflow
/home/aistudio/.local/lib/python3.8/site-packages/_distutils_hack/__init__.py:26: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")
>>> cls = Taskflow("text_classification", task_path='checkpoint/export', is_static_model=True)
W0424 07:00:22.198881 15599 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 12.6
W0424 07:00:22.203166 15599 gpu_resources.cc:164] device: 0, cuDNN Version: 9.5.
W0424 07:00:22.203220 15599 gpu_resources.cc:196] WARNING: device: 0. The installed Paddle is compiled with CUDA 12.6, but CUDA runtime version in your machine is 12.0, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDA version.
[2025-04-24 07:00:26,160] [    INFO] - Load id2label from checkpoint/export/id2label.json.
>>> cls(["黑苦荞茶的功效与作用及食用方法","幼儿挑食的生理原因是"])
[{'predictions': [{'label': '功效作用', 'score': 0.985942790111639}], 'text': '黑苦荞茶的功效与作用及食用方法'}, {'predictions': [{'label': '病因分析', 'score': 0.6330634642946378}], 'text': '幼儿挑食的生理原因是'}]

@DrownFish19

paddle-bot bot commented Apr 24, 2025

Thanks for your contribution!

codecov bot commented Apr 24, 2025

Codecov Report

Attention: Patch coverage is 0% with 14 lines in your changes missing coverage. Please review.

Project coverage is 48.67%. Comparing base (e8a19d3) to head (80b33f8).
Report is 2 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| paddlenlp/taskflow/task.py | 0.00% | 14 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (0.00%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (48.67%) is below the target coverage (58.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop   #10487      +/-   ##
===========================================
- Coverage    48.67%   48.67%   -0.01%     
===========================================
  Files          768      768              
  Lines       126915   126921       +6     
===========================================
  Hits         61777    61777              
- Misses       65138    65144       +6     

@luotao1 added the HappyOpenSource label (Happy Open Source campaign issues and PRs) on Apr 24, 2025
Labels: contributor, HappyOpenSource (Happy Open Source campaign issues and PRs)