moly_kit/clients/
openai_image.rs1use crate::protocol::Tool;
4use crate::protocol::*;
5use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream};
6use reqwest::header::{HeaderMap, HeaderName};
7use std::{
8 str::FromStr,
9 sync::{Arc, RwLock},
10};
11
12#[derive(Debug, Clone)]
13struct OpenAIImageClientInner {
14 url: String,
15 client: reqwest::Client,
16 headers: HeaderMap,
17}
18
19#[derive(Debug)]
25pub struct OpenAIImageClient(Arc<RwLock<OpenAIImageClientInner>>);
26
27impl Clone for OpenAIImageClient {
28 fn clone(&self) -> Self {
29 OpenAIImageClient(Arc::clone(&self.0))
30 }
31}
32
33impl OpenAIImageClient {
34 pub fn new(url: String) -> Self {
35 let headers = HeaderMap::new();
36 let client = default_client();
37
38 let inner = OpenAIImageClientInner {
39 url,
40 client,
41 headers,
42 };
43
44 OpenAIImageClient(Arc::new(RwLock::new(inner)))
45 }
46
47 pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> {
48 let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?;
49
50 let header_value = value.parse().map_err(|_| "Invalid header value")?;
51
52 self.0
53 .write()
54 .unwrap()
55 .headers
56 .insert(header_name, header_value);
57
58 Ok(())
59 }
60
61 pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> {
62 self.set_header("Authorization", &format!("Bearer {}", key))
63 }
64
65 pub fn get_url(&self) -> String {
66 self.0.read().unwrap().url.clone()
67 }
68
69 async fn generate_image(
70 &self,
71 bot_id: &BotId,
72 messages: &[Message],
73 ) -> Result<MessageContent, ClientError> {
74 let inner = self.0.read().unwrap().clone();
75
76 let prompt = messages
77 .last()
78 .map(|msg| msg.content.text.as_str())
79 .ok_or_else(|| {
80 ClientError::new(ClientErrorKind::Unknown, "No messages provided".to_string())
81 })?;
82
83 let url = format!("{}/images/generations", inner.url);
84
85 let request_json = serde_json::json!({
86 "model": bot_id.id(),
87 "prompt": prompt,
88 "size": "1024x1024",
90 "response_format": "b64_json"
93 });
94
95 let request = inner
96 .client
97 .post(&url)
98 .headers(inner.headers.clone())
99 .json(&request_json);
100
101 let response = request.send().await.map_err(|e| {
102 ClientError::new_with_source(
103 ClientErrorKind::Network,
104 format!(
105 "Could not send request to {url}. Verify your connection and the server status."
106 ),
107 Some(e),
108 )
109 })?;
110
111 let status = response.status();
112 let text = response.text().await.unwrap_or_default();
113
114 if !status.is_success() {
115 return Err(ClientError::new(
116 ClientErrorKind::Response,
117 format!(
118 "Request to {url} failed with status {} and content: {}",
119 status, text
120 ),
121 ));
122 }
123
124 let response_json: serde_json::Value = serde_json::from_str(&text).map_err(|e| {
125 ClientError::new_with_source(
126 ClientErrorKind::Format,
127 format!(
128 "Failed to parse response from {url}. It does not match the expected format."
129 ),
130 Some(e),
131 )
132 })?;
133
134 let image_data = response_json
135 .get("data")
136 .and_then(|data| data.get(0))
137 .and_then(|item| item.get("b64_json"))
138 .and_then(|b64| b64.as_str())
139 .ok_or_else(|| {
140 ClientError::new(
141 ClientErrorKind::Format,
142 "Response does not contain expected 'b64_json' field".to_string(),
143 )
144 })?;
145
146 let attachment =
147 Attachment::from_base64("image.png".into(), Some("image/png".into()), image_data)
148 .map_err(|e| {
149 ClientError::new_with_source(
150 ClientErrorKind::Format,
151 "Failed to create attachment from base64 data".to_string(),
152 Some(e),
153 )
154 })?;
155
156 let content = MessageContent {
157 text: String::new(),
158 attachments: vec![attachment],
159 ..Default::default()
160 };
161
162 Ok(content)
163 }
164}
165
166impl BotClient for OpenAIImageClient {
167 fn bots(&self) -> BoxPlatformSendFuture<'static, ClientResult<Vec<Bot>>> {
168 let inner = self.0.read().unwrap().clone();
169
170 let supported: Vec<Bot> = ["dall-e-2", "dall-e-3", "gpt-image-1"]
173 .into_iter()
174 .map(|id| Bot {
175 id: BotId::new(id, &inner.url),
176 name: id.to_string(),
177 avatar: Picture::Grapheme("I".into()),
178 capabilities: BotCapabilities::new(),
179 })
180 .collect();
181
182 Box::pin(futures::future::ready(ClientResult::new_ok(supported)))
183 }
184
185 fn send(
186 &mut self,
187 bot_id: &BotId,
188 messages: &[Message],
189 _tools: &[Tool],
190 ) -> BoxPlatformSendStream<'static, ClientResult<MessageContent>> {
191 let self_clone = self.clone();
192 let bot_id = bot_id.clone();
193 let messages = messages.to_vec();
194
195 Box::pin(async_stream::stream! {
196 match self_clone.generate_image(&bot_id, &messages).await {
197 Ok(content) => yield ClientResult::new_ok(content),
198 Err(e) => yield ClientResult::new_err(e.into()),
199 }
200 })
201 }
202
203 fn clone_box(&self) -> Box<dyn BotClient> {
204 Box::new(self.clone())
205 }
206}
207
208#[cfg(not(target_arch = "wasm32"))]
210fn default_client() -> reqwest::Client {
211 use std::time::Duration;
212
213 reqwest::Client::builder()
216 .connect_timeout(Duration::from_secs(90))
218 .read_timeout(Duration::from_secs(90))
224 .build()
225 .unwrap()
226}
227
/// Builds the HTTP client used on `wasm32`; timeouts are not configured here.
#[cfg(target_arch = "wasm32")]
fn default_client() -> reqwest::Client {
    let client = reqwest::Client::new();
    client
}