Skip to content

Commit c4b5b68

Browse files
committed
Support for OpenAI o1 models
fixes #130
1 parent 7ee4c5c commit c4b5b68

File tree

1 file changed

+118
-13
lines changed

1 file changed

+118
-13
lines changed

org-ai-openai.el

Lines changed: 118 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,12 @@ For chat completion responses.")
237237
(defvar org-ai--current-request-callback nil
238238
"Internal var that stores the current request callback.")
239239

240+
(defvar org-ai--current-request-is-streamed nil
  "Whether we expect a streamed response or a single completion payload.
Set when a request is started and cleared in `org-ai-reset-stream-state'.")
242+
243+
(defvar org-ai--current-progress-reporter nil
  "Progress reporter shown while waiting for a non-streamed response.
Created by `org-ai--progress-reporter-until-request-done' and
finished/cleared in `org-ai-reset-stream-state'.")
245+
240246
(defvar org-ai-after-chat-insertion-hook nil
241247
"Hook that is called when a chat response is inserted.
242248
Note this is called for every stream response so it will typically
@@ -394,6 +400,7 @@ from the OpenAI API."
394400
(list (make-org-ai--response :type 'stop :payload stop-reason))))
395401
((string= response-type "message_stop") nil)
396402

403+
397404
;; try perplexity.ai
398405
((and (plist-get response 'model) (string-prefix-p "llama-" (plist-get response 'model)))
399406
(let ((choices (plist-get response 'choices)))
@@ -412,7 +419,20 @@ from the OpenAI API."
412419
(when finish-reason
413420
(list (make-org-ai--response :type 'stop :payload finish-reason))))))))
414421

415-
;; fallback to openai
422+
;; single message e.g. from non-streamed completion
423+
((let ((choices (plist-get response 'choices)))
424+
(and (= 1 (length choices))
425+
(plist-get (aref choices 0) 'message)))
426+
(let* ((choices (plist-get response 'choices))
427+
(choice (aref choices 0))
428+
(text (plist-get (plist-get choice 'message) 'content))
429+
(role (plist-get (plist-get choice 'message) 'role))
430+
(finish-reason (or (plist-get choice 'finish_reason) 'stop)))
431+
(list (make-org-ai--response :type 'role :payload role)
432+
(make-org-ai--response :type 'text :payload text)
433+
(make-org-ai--response :type 'stop :payload finish-reason))))
434+
435+
;; try openai streamed
416436
(t (let ((choices (plist-get response 'choices)))
417437
(cl-loop for choice across choices
418438
append (or (when-let ((role (plist-get (plist-get choice 'delta) 'role)))
@@ -519,6 +539,7 @@ penalty. `PRESENCE-PENALTY' is the presence penalty."
519539
(setq org-ai--currently-inside-code-markers nil)
520540
(setq service (or (if (stringp service) (org-ai--read-service-name service) service)
521541
org-ai-service))
542+
(setq stream (org-ai--stream-supported service model))
522543

523544
(let* ((url-request-extra-headers (org-ai--get-headers service))
524545
(url-request-method "POST")
@@ -532,25 +553,32 @@ penalty. `PRESENCE-PENALTY' is the presence penalty."
532553
:frequency-penalty frequency-penalty
533554
:presence-penalty presence-penalty
534555
:service service
535-
:stream t)))
556+
:stream stream)))
536557
(org-ai--check-model model endpoint)
537558

538559
;; (message "REQUEST %s %s" endpoint url-request-data)
539560

561+
(setq org-ai--current-request-is-streamed stream)
540562
(setq org-ai--current-request-callback callback)
563+
(when (not stream) (org-ai--progress-reporter-until-request-done))
541564

542565
(setq org-ai--current-request-buffer-for-stream
543566
(url-retrieve
544567
endpoint
545568
(lambda (_events)
569+
(with-current-buffer org-ai--current-request-buffer-for-stream
570+
(org-ai--url-request-on-change-function nil nil nil))
546571
(org-ai--maybe-show-openai-request-error org-ai--current-request-buffer-for-stream)
547572
(org-ai-reset-stream-state))))
548573

549574
;; (display-buffer-use-some-window org-ai--current-request-buffer-for-stream nil)
550575

551-
(unless (member 'org-ai--url-request-on-change-function after-change-functions)
576+
(if stream
577+
(unless (member 'org-ai--url-request-on-change-function after-change-functions)
578+
(with-current-buffer org-ai--current-request-buffer-for-stream
579+
(add-hook 'after-change-functions #'org-ai--url-request-on-change-function nil t)))
552580
(with-current-buffer org-ai--current-request-buffer-for-stream
553-
(add-hook 'after-change-functions #'org-ai--url-request-on-change-function nil t)))
581+
(remove-hook 'after-change-functions #'org-ai--url-request-on-change-function t)))
554582

555583
org-ai--current-request-buffer-for-stream))
556584

@@ -600,25 +628,36 @@ temperature of the distribution. `TOP-P' is the top-p value.
600628
`FREQUENCY-PENALTY' is the frequency penalty. `PRESENCE-PENALTY'
601629
is the presence penalty.
602630
`STREAM' is a boolean indicating whether to stream the response."
603-
(let ((extra-system-prompt))
631+
(let ((extra-system-prompt)
632+
(max-completion-tokens))
604633

605634
(when (eq service 'anthropic)
606635
(when (string-equal (plist-get (aref messages 0) :role) "system")
607636
(setq extra-system-prompt (plist-get (aref messages 0) :content))
608637
(cl-shiftf messages (cl-subseq messages 1)))
609638
(setq max-tokens (or max-tokens 4096)))
610639

640+
;; o1 models currently do not support a system prompt
(when (and (or (eq service 'openai) (eq service 'azure-openai))
           (string-prefix-p "o1-" model))
  (setq messages (cl-remove-if (lambda (msg)
                                 (string-equal (plist-get msg :role) "system"))
                               messages))
  ;; o1 does not accept max_tokens; carry the requested limit over to the
  ;; o1-specific parameter, defaulting to the 128k window when none was given.
  ;; NOTE: the limit must be read BEFORE max-tokens is cleared, otherwise the
  ;; user's value is silently replaced by the default.
  (setq max-completion-tokens (or max-tokens 128000))
  (setq max-tokens nil))
648+
611649
(let* ((input (if messages `(messages . ,messages) `(prompt . ,prompt)))
612650
;; TODO yet unsupported properties: n, stop, logit_bias, user
(data (map-filter (lambda (x _) x)
                  `(,input
                    (model . ,model)
                    ,@(when stream `((stream . ,stream)))
                    ,@(when max-tokens `((max_tokens . ,max-tokens)))
                    ;; the OpenAI API field is snake_case `max_completion_tokens',
                    ;; matching `max_tokens' above — a lisp-style dashed key
                    ;; would be ignored by the server
                    ,@(when max-completion-tokens `((max_completion_tokens . ,max-completion-tokens)))
                    ,@(when temperature `((temperature . ,temperature)))
                    ,@(when top-p `((top_p . ,top-p)))
                    ,@(when frequency-penalty `((frequency_penalty . ,frequency-penalty)))
                    ,@(when presence-penalty `((presence_penalty . ,presence-penalty)))))))
622661

623662
(when extra-system-prompt
624663
(setq data (append data `((system . ,extra-system-prompt)))))
@@ -652,7 +691,28 @@ and the length in chars of the pre-change text replaced by that range."
652691
;; (list (buffer-substring-no-properties (point-min) (point-max))
653692
;; (point)))))
654693

655-
(while (and (not errored) (search-forward "data: " nil t))
694+
;; handle completion (non-streamed) response of a single json object
695+
(while (and (not org-ai--current-request-is-streamed)
696+
(not errored))
697+
(let ((json-object-type 'plist)
698+
(json-key-type 'symbol)
699+
(json-array-type 'vector))
700+
(condition-case _err
701+
(let ((data (json-read)))
702+
(when org-ai--current-request-callback
703+
(funcall org-ai--current-request-callback data)))
704+
(error
705+
(setq errored t))))
706+
(progn
707+
(when org-ai--current-request-callback
708+
(funcall org-ai--current-request-callback nil))
709+
(org-ai-reset-stream-state)
710+
(message "org-ai request done")))
711+
712+
;; handle stream completion response, multiple json objects prefixed with "data: "
713+
(while (and org-ai--current-request-is-streamed
714+
(not errored)
715+
(search-forward "data: " nil t))
656716
(let* ((line (buffer-substring-no-properties (point) (line-end-position))))
657717
;; (message "...found data: %s" line)
658718
(if (string= line "[DONE]")
@@ -676,6 +736,13 @@ and the length in chars of the pre-change text replaced by that range."
676736
(setq errored t)
677737
(goto-char org-ai--url-buffer-last-position-marker)))))))))))))
678738

739+
(defun org-ai--stream-supported (service model)
  "Return non-nil when SERVICE and MODEL support streamed responses.
OpenAI o1 models (on both the OpenAI and Azure OpenAI services)
only deliver a single, non-streamed completion payload."
  (if (and (memq service '(openai azure-openai))
           (string-prefix-p "o1-" model))
      nil
    t))
745+
679746
(defun org-ai-interrupt-current-request ()
680747
"Interrupt the current request."
681748
(interactive)
@@ -694,7 +761,45 @@ and the length in chars of the pre-change text replaced by that range."
694761
(setq org-ai--url-buffer-last-position-marker nil)))
695762
(setq org-ai--current-request-callback nil)
696763
(setq org-ai--url-buffer-last-position-marker nil)
697-
(setq org-ai--current-chat-role nil))
764+
(setq org-ai--current-chat-role nil)
765+
(setq org-ai--current-request-is-streamed nil)
766+
(when org-ai--current-progress-reporter
767+
(progress-reporter-done org-ai--current-progress-reporter)
768+
(setq org-ai--current-progress-reporter nil)))
769+
770+
(defcustom org-ai--witty-messages
  '("Pondering imponderables... Almost there!"
    "`grep`ing the neural net for answers..."
    "Fetching witty AI response... In the meantime, have you tried Vim? Just kidding!"
    "Teaching AI the ways of the Lisp."
    "Consulting the sacred parentheses."
    "Hold tight! The AI is garbage collecting its thoughts."
    "Fetching clever reply from `/dev/ai`."
    "The AI is busy counting parentheses. Almost there!"
    "Running in an infinite loop... Just kidding! Processing your request."
    "The AI is stuck in a `(cl-labels ((loop () (loop))) (loop))`... Wait, no it's not.")
  "Messages to entertain while waiting for a non-streamed response.
One is picked at random by `org-ai--progress-reporter-until-request-done'."
  ;; `defcustom' should declare :type and :group so Customize renders it properly
  :type '(repeat string)
  :group 'org-ai)
782+
783+
(defun org-ai--progress-reporter-until-request-done ()
  "Start a progress reporter that spins until the current request is done.
Picks a random message from `org-ai--witty-messages' and keeps the
reporter updating from an idle timer until
`org-ai--current-progress-reporter' is cleared (normally by
`org-ai-reset-stream-state')."
  ;; finish any reporter left over from a previous request
  (when org-ai--current-progress-reporter
    (progress-reporter-done org-ai--current-progress-reporter))

  (setq org-ai--current-progress-reporter
        (let ((msg (or (nth (random (length org-ai--witty-messages))
                            org-ai--witty-messages)
                       "Waiting for a response")))
          (make-progress-reporter msg)))

  (run-with-idle-timer
   0 nil
   (lambda ()
     ;; capture the reporter locally: the original referenced an unbound
     ;; variable `reporter' here, signaling void-variable when the loop ended
     (let ((reporter org-ai--current-progress-reporter))
       (while org-ai--current-progress-reporter
         (progress-reporter-update org-ai--current-progress-reporter)
         (sit-for 0.1))
       (when reporter
         (progress-reporter-done reporter))))))
698803

699804
(defun org-ai-open-request-buffer ()
700805
"A debug helper that opens the url request buffer."

0 commit comments

Comments
 (0)