Skip to content

Commit a6c08a1

Browse files
committed
refactored Gemini example
1 parent c47a793 commit a6c08a1

File tree

1 file changed

+120
-74
lines changed

1 file changed

+120
-74
lines changed

gemini_commandline/Main.hs

Lines changed: 120 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,143 @@
1-
import Control.Monad.IO.Class (liftIO)
1+
{-# LANGUAGE DeriveGeneric #-}
2+
{-# LANGUAGE OverloadedStrings #-}
3+
{-# LANGUAGE ScopedTypeVariables #-}
4+
25
import System.Environment (getArgs, getEnv)
3-
import qualified Data.Aeson as Aeson
4-
import Data.Aeson (FromJSON, ToJSON)
5-
import GHC.Generics
6+
import qualified Data.Aeson as Aeson -- Used for Aeson.encode, Aeson.object etc.
7+
import Data.Aeson (FromJSON, ToJSON, eitherDecode) -- Specific functions needed
8+
import GHC.Generics (Generic) -- Needed for deriving ToJSON/FromJSON
69
import Network.HTTP.Client.TLS (tlsManagerSettings)
7-
import Network.HTTP.Client (newManager, httpLbs, parseRequest, Request(..), RequestBody(..), responseBody, responseStatus)
10+
import Network.HTTP.Client (newManager, httpLbs, parseRequest, Manager, Request(..), RequestBody(..), Response(..), responseStatus)
811
import Network.HTTP.Types.Status (statusCode)
912
import qualified Data.Text as T
1013
import Data.Text.Encoding (encodeUtf8)
11-
import qualified Data.Vector as V
14+
import Control.Exception (SomeException, handle)
15+
16+
-- --- Request Data Types ---
17+
18+
data RequestPart = RequestPart
19+
{ reqText :: T.Text -- Using reqText to avoid name clash with Response Part's text
20+
} deriving (Show, Generic)
21+
22+
instance ToJSON RequestPart where
23+
toJSON (RequestPart t) = Aeson.object ["text" Aeson..= t]
24+
25+
data RequestContent = RequestContent
26+
{ reqParts :: [RequestPart] -- Using reqParts to avoid name clash
27+
} deriving (Show, Generic)
28+
29+
instance ToJSON RequestContent where
30+
toJSON (RequestContent p) = Aeson.object ["parts" Aeson..= p]
31+
32+
data GenerationConfig = GenerationConfig
33+
{ temperature :: Double
34+
, maxOutputTokens :: Int
35+
-- Add other config fields as needed (e.g., topP, topK)
36+
} deriving (Show, Generic, ToJSON)
1237

13-
data GeminiRequest = GeminiRequest
14-
{ prompt :: String
38+
data GeminiApiRequest = GeminiApiRequest
39+
{ contents :: [RequestContent]
40+
, generationConfig :: GenerationConfig
1541
} deriving (Show, Generic, ToJSON)
1642

17-
data GeminiResponse = GeminiResponse
18-
{ candidates :: [Candidate] -- Changed from choices to candidates
43+
44+
-- --- Response Data Types (mostly unchanged, renamed for clarity) ---
45+
46+
data ResponsePart = ResponsePart
47+
{ text :: String
1948
} deriving (Show, Generic, FromJSON)
2049

21-
data Candidate = Candidate
22-
{ content :: Content
50+
data ResponseContent = ResponseContent
51+
{ parts :: [ResponsePart]
2352
} deriving (Show, Generic, FromJSON)
2453

25-
data Content = Content
26-
{ parts :: [Part]
54+
data Candidate = Candidate
55+
{ content :: ResponseContent
2756
} deriving (Show, Generic, FromJSON)
2857

29-
data Part = Part
30-
  { text :: String
31-
  } deriving (Show, Generic, FromJSON, ToJSON)
58+
-- Assuming promptFeedback might be present at the top level of the response
59+
-- alongside candidates, adjust if it's nested differently.
60+
data SafetyRating = SafetyRating
61+
{ category :: String
62+
, probability :: String
63+
} deriving (Show, Generic, FromJSON)
3264

3365
data PromptFeedback = PromptFeedback
34-
  { blockReason :: Maybe String
35-
  , safetyRatings :: Maybe [SafetyRating]
36-
  } deriving (Show, Generic, FromJSON, ToJSON)
66+
{ blockReason :: Maybe String
67+
, safetyRatings :: Maybe [SafetyRating]
68+
} deriving (Show, Generic, FromJSON)
3769

38-
data SafetyRating = SafetyRating
39-
  { category :: String
40-
  , probability :: String
41-
  } deriving (Show, Generic, FromJSON, ToJSON)
70+
data GeminiApiResponse = GeminiApiResponse
71+
{ candidates :: [Candidate]
72+
, promptFeedback :: Maybe PromptFeedback -- Added optional promptFeedback
73+
} deriving (Show, Generic, FromJSON)
74+
75+
-- --- Completion Function ---
76+
77+
-- | Sends a prompt to the Gemini API and returns the completion text or an error.
78+
completion :: String -- ^ Google API Key
79+
-> Manager -- ^ HTTP Manager
80+
-> String -- ^ The user's prompt text
81+
-> IO (Either String String) -- ^ Left error message or Right completion text
82+
completion apiKey manager promptText = do
83+
initialRequest <- parseRequest "https://generativelanguage.googleapis.com/v1/models/gemini-2.0-flash:generateContent"
84+
let reqContent = RequestContent { reqParts = [RequestPart { reqText = T.pack promptText }] }
85+
let genConfig = GenerationConfig { temperature = 0.1, maxOutputTokens = 800 }
86+
let apiRequest = GeminiApiRequest { contents = [reqContent], generationConfig = genConfig }
87+
88+
let request = initialRequest
89+
{ requestHeaders =
90+
[ ("Content-Type", "application/json")
91+
, ("x-goog-api-key", encodeUtf8 $ T.pack apiKey)
92+
]
93+
, method = "POST"
94+
, requestBody = RequestBodyLBS $ Aeson.encode apiRequest
95+
}
96+
97+
response <- httpLbs request manager
98+
let status = responseStatus response
99+
body = responseBody response
100+
101+
if statusCode status == 200
102+
then do
103+
case eitherDecode body :: Either String GeminiApiResponse of
104+
Left err -> return $ Left ("Error decoding JSON response: " ++ err)
105+
Right geminiResponse ->
106+
case candidates geminiResponse of
107+
(candidate:_) ->
108+
case parts (content candidate) of
109+
(part:_) -> return $ Right (text part)
110+
[] -> return $ Left "Error: Received candidate with no parts."
111+
[] ->
112+
-- Check for blocking information if no candidates are present
113+
case promptFeedback geminiResponse of
114+
Just pf -> case blockReason pf of
115+
Just reason -> return $ Left ("API Error: Blocked - " ++ reason)
116+
Nothing -> return $ Left "Error: No candidates found and no block reason provided."
117+
Nothing -> return $ Left "Error: No candidates found in response."
118+
else do
119+
let err = "Error: API request failed with status " ++ show (statusCode status) ++ "\nBody: " ++ show body
120+
return $ Left err
121+
122+
-- --- Main Function ---
42123

43124
main :: IO ()
44125
main = do
45126
args <- getArgs
46127
case args of
47128
[] -> putStrLn "Error: Please provide a prompt as a command line argument."
48-
(arg:_) -> do -- Extract the argument directly
49-
apiKey <- getEnv "GOOGLE_API_KEY"
50-
51-
manager <- newManager tlsManagerSettings
52-
53-
initialRequest <- parseRequest "https://generativelanguage.googleapis.com/v1/models/gemini-2.0-flash:generateContent"
54-
55-
let geminiRequestBody = Aeson.object [
56-
     ("contents", Aeson.Array $ V.singleton $ Aeson.object [
57-
         ("parts", Aeson.Array $ V.singleton $ Aeson.object [
58-
             ("text", Aeson.String $ T.pack arg)
59-
         ])
60-
     ]),
61-
     ("generationConfig", Aeson.object [
62-
         ("temperature", Aeson.Number 0.1),
63-
         ("maxOutputTokens", Aeson.Number 800)
64-
     ])
65-
]
66-
67-
let request = initialRequest
68-
{ requestHeaders =
69-
[ ("Content-Type", "application/json")
70-
, ("x-goog-api-key", encodeUtf8 $ T.pack apiKey)
71-
]
72-
, method = "POST"
73-
, requestBody = RequestBodyLBS $ Aeson.encode geminiRequestBody
74-
}
75-
76-
response <- httpLbs request manager
77-
78-
let responseStatus' = responseStatus response
79-
80-
if statusCode responseStatus' == 200
81-
   then do
82-
     let maybeGeminiResponse = Aeson.decode (responseBody response) :: Maybe GeminiResponse
83-
     case maybeGeminiResponse of
84-
       Just geminiResponse -> do
85-
         case candidates geminiResponse of
86-
           (candidate:_) -> do
87-
             case parts (content candidate) of
88-
               (part:_) -> do -- Changed text_ to _ since it's unused
89-
                 liftIO $ putStrLn $ "Response:\n\n" ++ text part
90-
                [] -> do
91-
liftIO $ putStrLn "Error: No parts in content"
92-
           [] -> do
93-
             liftIO $ putStrLn "Error: No candidates in response"
94-
Nothing -> do
95-
liftIO $ putStrLn "Error: Failed to parse response"
96-
   else do
97-
     putStrLn $ "Error: " ++ show responseStatus'
129+
(promptArg:_) -> do
130+
apiKeyResult <- lookupEnv "GOOGLE_API_KEY" -- Using lookupEnv for safer handling
131+
case apiKeyResult of
132+
Nothing -> putStrLn "Error: GOOGLE_API_KEY environment variable not set."
133+
Just apiKey -> do
134+
manager <- newManager tlsManagerSettings
135+
result <- completion apiKey manager promptArg
136+
137+
case result of
138+
Left errMsg -> putStrLn $ "API Call Failed:\n" ++ errMsg
139+
Right completionText -> putStrLn $ "Response:\n\n" ++ completionText
140+
141+
-- Helper function (optional but good practice)
142+
lookupEnv :: String -> IO (Maybe String)
143+
lookupEnv name = handle (\(_ :: SomeException) -> return Nothing) $ Just <$> getEnv name

0 commit comments

Comments
 (0)