moly_kit/clients/
openai_image.rs

1//! Client based on the OpenAI one, but hits the image generation API instead.
2
3use crate::protocol::Tool;
4use crate::protocol::*;
5use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream};
6use reqwest::header::{HeaderMap, HeaderName};
7use std::{
8    str::FromStr,
9    sync::{Arc, RwLock},
10};
11
12#[derive(Debug, Clone)]
13struct OpenAIImageClientInner {
14    url: String,
15    client: reqwest::Client,
16    headers: HeaderMap,
17}
18
19/// Specific OpenAI client to hit image generation endpoints.
20///
21/// If used as part of a [`crate::clients::MultiClient`], it's recommended to add this
22/// before the standard OpenAI client to ensure it get's priority. This is not strictly
23/// necessary if the OpenAI client recognizes and filters the image models you use.
24#[derive(Debug)]
25pub struct OpenAIImageClient(Arc<RwLock<OpenAIImageClientInner>>);
26
27impl Clone for OpenAIImageClient {
28    fn clone(&self) -> Self {
29        OpenAIImageClient(Arc::clone(&self.0))
30    }
31}
32
33impl OpenAIImageClient {
34    pub fn new(url: String) -> Self {
35        let headers = HeaderMap::new();
36        let client = default_client();
37
38        let inner = OpenAIImageClientInner {
39            url,
40            client,
41            headers,
42        };
43
44        OpenAIImageClient(Arc::new(RwLock::new(inner)))
45    }
46
47    pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> {
48        let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?;
49
50        let header_value = value.parse().map_err(|_| "Invalid header value")?;
51
52        self.0
53            .write()
54            .unwrap()
55            .headers
56            .insert(header_name, header_value);
57
58        Ok(())
59    }
60
61    pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> {
62        self.set_header("Authorization", &format!("Bearer {}", key))
63    }
64
65    pub fn get_url(&self) -> String {
66        self.0.read().unwrap().url.clone()
67    }
68
69    async fn generate_image(
70        &self,
71        bot_id: &BotId,
72        messages: &[Message],
73    ) -> Result<MessageContent, ClientError> {
74        let inner = self.0.read().unwrap().clone();
75
76        let prompt = messages
77            .last()
78            .map(|msg| msg.content.text.as_str())
79            .ok_or_else(|| {
80                ClientError::new(ClientErrorKind::Unknown, "No messages provided".to_string())
81            })?;
82
83        let url = format!("{}/images/generations", inner.url);
84
85        let request_json = serde_json::json!({
86            "model": bot_id.id(),
87            "prompt": prompt,
88            // "auto" is supported by `gpt-image` but not for `dall-e`.
89            "size": "1024x1024",
90            // `gpt-image` always returns base64, but `dall-e` supports
91            // and defaults to `url` response format.
92            "response_format": "b64_json"
93        });
94
95        let request = inner
96            .client
97            .post(&url)
98            .headers(inner.headers.clone())
99            .json(&request_json);
100
101        let response = request.send().await.map_err(|e| {
102            ClientError::new_with_source(
103                ClientErrorKind::Network,
104                format!(
105                    "Could not send request to {url}. Verify your connection and the server status."
106                ),
107                Some(e),
108            )
109        })?;
110
111        let status = response.status();
112        let text = response.text().await.unwrap_or_default();
113
114        if !status.is_success() {
115            return Err(ClientError::new(
116                ClientErrorKind::Response,
117                format!(
118                    "Request to {url} failed with status {} and content: {}",
119                    status, text
120                ),
121            ));
122        }
123
124        let response_json: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
125            ClientError::new_with_source(
126                ClientErrorKind::Format,
127                format!(
128                    "Failed to parse response from {url}. It does not match the expected format."
129                ),
130                Some(e),
131            )
132        })?;
133
134        let image_data = response_json
135            .get("data")
136            .and_then(|data| data.get(0))
137            .and_then(|item| item.get("b64_json"))
138            .and_then(|b64| b64.as_str())
139            .ok_or_else(|| {
140                ClientError::new(
141                    ClientErrorKind::Format,
142                    "Response does not contain expected 'b64_json' field".to_string(),
143                )
144            })?;
145
146        let attachment =
147            Attachment::from_base64("image.png".into(), Some("image/png".into()), image_data)
148                .map_err(|e| {
149                    ClientError::new_with_source(
150                        ClientErrorKind::Format,
151                        "Failed to create attachment from base64 data".to_string(),
152                        Some(e),
153                    )
154                })?;
155
156        let content = MessageContent {
157            text: String::new(),
158            attachments: vec![attachment],
159            ..Default::default()
160        };
161
162        Ok(content)
163    }
164}
165
166impl BotClient for OpenAIImageClient {
167    fn bots(&self) -> BoxPlatformSendFuture<'static, ClientResult<Vec<Bot>>> {
168        let inner = self.0.read().unwrap().clone();
169
170        // Hardcoded list of OpenAI-only image generation models that are currently
171        // available and supported.
172        let supported: Vec<Bot> = ["dall-e-2", "dall-e-3", "gpt-image-1"]
173            .into_iter()
174            .map(|id| Bot {
175                id: BotId::new(id, &inner.url),
176                name: id.to_string(),
177                avatar: Picture::Grapheme("I".into()),
178                capabilities: BotCapabilities::new(),
179            })
180            .collect();
181
182        Box::pin(futures::future::ready(ClientResult::new_ok(supported)))
183    }
184
185    fn send(
186        &mut self,
187        bot_id: &BotId,
188        messages: &[Message],
189        _tools: &[Tool],
190    ) -> BoxPlatformSendStream<'static, ClientResult<MessageContent>> {
191        let self_clone = self.clone();
192        let bot_id = bot_id.clone();
193        let messages = messages.to_vec();
194
195        Box::pin(async_stream::stream! {
196            match self_clone.generate_image(&bot_id, &messages).await {
197                Ok(content) => yield ClientResult::new_ok(content),
198                Err(e) => yield ClientResult::new_err(e.into()),
199            }
200        })
201    }
202
203    fn clone_box(&self) -> Box<dyn BotClient> {
204        Box::new(self.clone())
205    }
206}
207
208// TODO: Dedup from other clients.
209#[cfg(not(target_arch = "wasm32"))]
210fn default_client() -> reqwest::Client {
211    use std::time::Duration;
212
213    // On native, there are no default timeouts. Connection may hang if we don't
214    // configure them.
215    reqwest::Client::builder()
216        // Only considered while establishing the connection.
217        .connect_timeout(Duration::from_secs(90))
218        // Considered while reading the response and reset on every chunk
219        // received.
220        //
221        // Warning: Do not use normal `timeout` method as it doesn't consider
222        // this.
223        .read_timeout(Duration::from_secs(90))
224        .build()
225        .unwrap()
226}
227
/// Builds the HTTP client used on web targets.
///
/// On web, reqwest timeouts are not configurable, but it uses the browser's
/// fetch API under the hood, which handles connection issues properly.
#[cfg(target_arch = "wasm32")]
fn default_client() -> reqwest::Client {
    reqwest::Client::new()
}