1#[cfg(not(target_arch = "wasm32"))]
2use rmcp::{
3 model::{CallToolRequestParam, CallToolResult},
4 service::{RoleClient, RunningService, ServiceExt},
5 transport::{
6 SseClientTransport, TokioChildProcess,
7 streamable_http_client::{StreamableHttpClientTransport, StreamableHttpClientWorker},
8 },
9};
10use serde_json::{Map, Value};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Mutex};
14
15use crate::protocol::{Tool, ToolCall, ToolResult};
16
/// Builds the registry key for a tool: `server_id` and `tool_name` joined by `__`.
///
/// The separator is not escaped, so a `server_id` that itself contains `__`
/// cannot be split back unambiguously by `parse_namespaced_tool_name`, which
/// splits on the first occurrence.
fn namespaced_name(server_id: &str, tool_name: &str) -> String {
    [server_id, tool_name].join("__")
}
22
/// Splits a namespaced tool name (`server_id__tool_name`) into
/// `(server_id, tool_name)`.
///
/// The split happens at the first `__`, so a tool name may itself contain
/// `__` (`"srv__a__b"` -> `("srv", "a__b")`).
///
/// # Errors
///
/// Returns an error if the name contains no `__` separator, or if either the
/// server id or the tool name part is empty (e.g. `"__tool"` or `"srv__"`).
pub fn parse_namespaced_tool_name(
    namespaced_name: &str,
) -> Result<(String, String), Box<dyn std::error::Error>> {
    match namespaced_name.split_once("__") {
        // Both halves must be non-empty for the name to identify a real server and tool.
        Some((server_id, tool_name)) if !server_id.is_empty() && !tool_name.is_empty() => {
            Ok((server_id.to_string(), tool_name.to_string()))
        }
        _ => Err(format!(
            "Invalid namespaced tool name: '{}'. Expected format 'server_id__tool_name'",
            namespaced_name
        )
        .into()),
    }
}
39
40pub fn display_name_from_namespaced(namespaced_name: &str) -> String {
44 if let Ok((server_id, tool_name)) = parse_namespaced_tool_name(namespaced_name) {
45 format!("{}: {}", server_id, tool_name)
46 } else {
47 namespaced_name.to_string()
49 }
50}
51
52pub fn parse_tool_arguments(arguments: &str) -> Result<Map<String, Value>, String> {
54 match serde_json::from_str::<Value>(arguments) {
55 Ok(Value::Object(args)) => Ok(args),
56 Ok(_) => Err("Arguments must be a JSON object".to_string()),
57 Err(e) => Err(format!("Failed to parse arguments: {}", e)),
58 }
59}
60
/// One tool exposed by one MCP server, as stored in the tool registry.
#[derive(Clone, Debug)]
pub struct ToolRegistryEntry {
    /// Id of the server that provides the tool.
    pub server_id: String,
    /// The tool's name as reported by the server (not namespaced).
    pub original_name: String,
    /// `server_id` and `original_name` joined with `__`; the key under which
    /// this entry is stored in the registry.
    pub namespaced_name: String,
    /// The tool definition as discovered from the server; its `name` is the
    /// original (non-namespaced) name.
    pub schema: Tool,
}
68
/// Index of the tools discovered on every connected MCP server.
///
/// Tools are keyed by their namespaced name. The registry also records which
/// tools each server registered, so they can all be removed together.
pub struct ToolRegistry {
    /// Entries keyed by namespaced name (`server_id__tool_name`).
    tools: HashMap<String, ToolRegistryEntry>,
    /// Original (non-namespaced) tool names registered for each server id.
    server_tools: HashMap<String, Vec<String>>,
}
75
76impl ToolRegistry {
77 fn new() -> Self {
78 Self {
79 tools: HashMap::new(),
80 server_tools: HashMap::new(),
81 }
82 }
83
84 fn add_server_tools(&mut self, server_id: &str, tools: Vec<Tool>) {
85 let mut tool_names = Vec::new();
86
87 for tool in tools {
88 let namespaced_name = namespaced_name(server_id, &tool.name);
89 let original_name = tool.name.clone();
90 let entry = ToolRegistryEntry {
91 server_id: server_id.to_string(),
92 original_name: original_name.clone(),
93 namespaced_name: namespaced_name.clone(),
94 schema: tool,
95 };
96
97 self.tools.insert(namespaced_name, entry);
98 tool_names.push(original_name);
99 }
100
101 self.server_tools.insert(server_id.to_string(), tool_names);
102 }
103
104 fn get_tool_entry(&self, namespaced_name: &str) -> Option<&ToolRegistryEntry> {
105 self.tools.get(namespaced_name)
106 }
107
108 fn get_all_tools(&self) -> Vec<Tool> {
109 self.tools
110 .values()
111 .map(|entry| {
112 let mut tool = entry.schema.clone();
113 tool.name = entry.namespaced_name.clone();
114 tool
115 })
116 .collect()
117 }
118
119 fn remove_server(&mut self, server_id: &str) {
120 if let Some(tool_names) = self.server_tools.remove(server_id) {
121 for tool_name in tool_names {
122 let namespaced_name = namespaced_name(server_id, &tool_name);
123 self.tools.remove(&namespaced_name);
124 }
125 }
126 }
127}
128
129pub enum McpTransport {
131 Http(String), Sse(String), #[cfg(not(target_arch = "wasm32"))]
134 Stdio(tokio::process::Command), }
136
/// Type-erased client-side service, so connections made over different
/// transports can be stored in one map.
#[cfg(not(target_arch = "wasm32"))]
type DynService = Box<dyn rmcp::service::DynService<RoleClient>>;

/// A running client connection to one MCP server.
#[cfg(not(target_arch = "wasm32"))]
type McpService = RunningService<RoleClient, DynService>;

/// Shared handle to a connection. Handles are cloned out of the map so the
/// map's lock does not have to be held across `.await`.
#[cfg(not(target_arch = "wasm32"))]
type McpServiceHandle = Arc<McpService>;

/// Live connections keyed by server id.
#[cfg(not(target_arch = "wasm32"))]
type McpServiceRegistry = HashMap<String, McpServiceHandle>;
148
/// Shared state behind every clone of `McpManagerClient`.
struct McpManagerInner {
    /// Live server connections keyed by server id.
    #[cfg(not(target_arch = "wasm32"))]
    services: Mutex<McpServiceRegistry>,
    /// Namespaced tool registry built from each server's discovered tools.
    #[cfg(not(target_arch = "wasm32"))]
    registry: Mutex<ToolRegistry>,
    /// Tools collected by the most recent native `list_tools` call
    /// (original, non-namespaced names).
    latest_tools: Mutex<Vec<Tool>>,
    /// Flag stored and read with `Relaxed` ordering through
    /// `set_dangerous_mode_enabled` / `get_dangerous_mode_enabled`.
    dangerous_mode_enabled: AtomicBool,
}
157
158pub struct McpManagerClient {
160 inner: Arc<McpManagerInner>,
161}
162
163impl Clone for McpManagerClient {
164 fn clone(&self) -> Self {
165 Self {
166 inner: Arc::clone(&self.inner),
167 }
168 }
169}
170
171impl McpManagerClient {
172 pub fn new() -> Self {
173 Self {
174 inner: Arc::new(McpManagerInner {
175 #[cfg(not(target_arch = "wasm32"))]
176 services: Mutex::new(HashMap::new()),
177 #[cfg(not(target_arch = "wasm32"))]
178 registry: Mutex::new(ToolRegistry::new()),
179 latest_tools: Mutex::new(Vec::new()),
180 dangerous_mode_enabled: AtomicBool::new(false),
181 }),
182 }
183 }
184
185 #[cfg(not(target_arch = "wasm32"))]
187 pub async fn add_server(
188 &self,
189 id: &str,
190 transport: McpTransport,
191 ) -> Result<(), Box<dyn std::error::Error>> {
192 let running_service = match transport {
193 McpTransport::Http(url) => {
194 let worker = StreamableHttpClientWorker::<reqwest::Client>::new_simple(url);
195 let transport = StreamableHttpClientTransport::spawn(worker);
196 ().into_dyn().serve(transport).await?
197 }
198 McpTransport::Sse(url) => {
199 let transport = SseClientTransport::start(url).await?;
200 ().into_dyn().serve(transport).await?
201 }
202 McpTransport::Stdio(command) => {
203 let transport = TokioChildProcess::new(command)?;
204 ().into_dyn().serve(transport).await?
205 }
206 };
207
208 self.inner
209 .services
210 .lock()
211 .unwrap()
212 .insert(id.to_string(), Arc::new(running_service));
213
214 match self.discover_tools_for_server(id).await {
216 Ok(tools) => {
217 self.inner
218 .registry
219 .lock()
220 .unwrap()
221 .add_server_tools(id, tools);
222 ::log::debug!("Successfully discovered tools for MCP server: {}", id);
223 }
224 Err(e) => {
225 ::log::warn!("Failed to discover tools for MCP server '{}': {}", id, e);
226 }
228 }
229
230 Ok(())
231 }
232
233 #[cfg(target_arch = "wasm32")]
234 pub async fn add_server(
235 &self,
236 id: &str,
237 _transport: McpTransport,
238 ) -> Result<(), Box<dyn std::error::Error>> {
239 let _ = id;
240 Err("MCP servers are not supported in web builds".into())
241 }
242
243 pub fn set_dangerous_mode_enabled(&self, enabled: bool) {
244 self.inner
245 .dangerous_mode_enabled
246 .store(enabled, Ordering::Relaxed);
247 }
248
249 pub fn get_dangerous_mode_enabled(&self) -> bool {
250 self.inner.dangerous_mode_enabled.load(Ordering::Relaxed)
251 }
252
253 #[cfg(not(target_arch = "wasm32"))]
255 async fn discover_tools_for_server(
256 &self,
257 server_id: &str,
258 ) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
259 let service = {
260 let services_guard = self.inner.services.lock().unwrap();
261 services_guard.get(server_id).map(|s| Arc::clone(s))
262 };
263
264 let Some(service) = service else {
265 return Err(format!("Server '{}' not found", server_id).into());
266 };
267
268 let list_tools_result = service.list_tools(Default::default()).await?;
269
270 let tools: Vec<Tool> = list_tools_result
271 .tools
272 .into_iter()
273 .map(|rmcp_tool| rmcp_tool.into())
274 .collect();
275
276 Ok(tools)
277 }
278
279 #[cfg(not(target_arch = "wasm32"))]
281 pub async fn list_tools(&self) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
282 let services: Vec<McpServiceHandle> = {
283 let services_guard = self.inner.services.lock().unwrap();
284 services_guard.values().map(|s| Arc::clone(s)).collect()
285 };
286
287 let mut all_tools = Vec::new();
288 for service in services {
289 match service.list_tools(Default::default()).await {
290 Ok(list_tools_result) => {
291 let converted_tools: Vec<Tool> = list_tools_result
293 .tools
294 .into_iter()
295 .map(|rmcp_tool| rmcp_tool.into())
296 .collect();
297 all_tools.extend(converted_tools);
298 }
299 Err(e) => {
300 ::log::warn!("Failed to list tools from server: {}", e);
301 }
303 }
304 }
305
306 *self.inner.latest_tools.lock().unwrap() = all_tools.clone();
307 Ok(all_tools)
308 }
309
310 #[cfg(target_arch = "wasm32")]
311 pub async fn list_tools(&self) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
312 Ok(Vec::new())
313 }
314
315 pub fn get_latest_tools(&self) -> Vec<Tool> {
316 self.inner.latest_tools.lock().unwrap().clone()
317 }
318
319 #[cfg(not(target_arch = "wasm32"))]
320 pub fn get_all_namespaced_tools(&self) -> Vec<Tool> {
321 self.inner.registry.lock().unwrap().get_all_tools()
322 }
323
324 #[cfg(target_arch = "wasm32")]
325 pub fn get_all_namespaced_tools(&self) -> Vec<Tool> {
326 Vec::new()
327 }
328
329 #[cfg(not(target_arch = "wasm32"))]
331 async fn call_tool(
332 &self,
333 namespaced_tool_name: &str,
334 arguments: serde_json::Map<String, serde_json::Value>,
335 ) -> Result<CallToolResult, Box<dyn std::error::Error>> {
336 let (server_id, original_tool_name) = parse_namespaced_tool_name(namespaced_tool_name)?;
338
339 let tool_entry = {
341 let registry = self.inner.registry.lock().unwrap();
342 registry.get_tool_entry(namespaced_tool_name).cloned()
343 };
344
345 let Some(_tool_entry) = tool_entry else {
346 return Err(format!("Tool '{}' not found in registry. Available tools can be retrieved with get_all_namespaced_tools()", namespaced_tool_name).into());
347 };
348
349 let service = {
351 let services_guard = self.inner.services.lock().unwrap();
352 services_guard.get(&server_id).map(|s| Arc::clone(s))
353 };
354
355 let Some(service) = service else {
356 return Err(format!("MCP server '{}' not found or disconnected", server_id).into());
357 };
358
359 let request = CallToolRequestParam {
363 name: original_tool_name.clone().into(),
364 arguments: Some(arguments),
365 };
366
367 match service.call_tool(request).await {
368 Ok(result) => Ok(result),
369 Err(e) => {
370 let error_message = format!(
371 "Tool '{}' failed on server '{}': {}",
372 original_tool_name, server_id, e
373 );
374 Err(error_message.into())
375 }
376 }
377 }
378
379 #[cfg(target_arch = "wasm32")]
380 pub async fn call_tool(
381 &self,
382 tool_name: &str,
383 _arguments: serde_json::Map<String, serde_json::Value>,
384 ) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
385 Err(format!(
386 "MCP servers are not yet supported in WASM builds. Cannot call tool '{}'",
387 tool_name
388 )
389 .into())
390 }
391
392 #[cfg(not(target_arch = "wasm32"))]
393 pub async fn remove_server(&self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
394 self.inner.services.lock().unwrap().remove(id);
395 self.inner.registry.lock().unwrap().remove_server(id);
396 Ok(())
397 }
398
399 #[cfg(target_arch = "wasm32")]
400 pub async fn remove_server(&self, _id: &str) -> Result<(), Box<dyn std::error::Error>> {
401 Ok(())
402 }
403
404 #[cfg(not(target_arch = "wasm32"))]
406 pub async fn execute_tool_call(
407 &self,
408 tool_name: &str,
409 tool_call_id: &str,
410 arguments: Map<String, Value>,
411 ) -> ToolResult {
412 match self.call_tool(tool_name, arguments).await {
413 Ok(result) => {
414 let content = result
416 .content
417 .iter()
418 .filter_map(|item| {
419 if let Ok(text) = serde_json::to_string(item) {
421 Some(text)
422 } else {
423 None
424 }
425 })
426 .collect::<Vec<_>>()
427 .join("\n");
428
429 ToolResult {
430 tool_call_id: tool_call_id.to_string(),
431 content,
432 is_error: false,
433 }
434 }
435 Err(e) => ToolResult {
436 tool_call_id: tool_call_id.to_string(),
437 content: e.to_string(),
438 is_error: true,
439 },
440 }
441 }
442
443 #[cfg(target_arch = "wasm32")]
444 pub async fn execute_tool_call(
445 &self,
446 tool_name: &str,
447 tool_call_id: &str,
448 _arguments: Map<String, Value>,
449 ) -> ToolResult {
450 ToolResult {
451 tool_call_id: tool_call_id.to_string(),
452 content: format!(
453 "MCP servers are not yet supported in WASM builds. Cannot call tool '{}'",
454 tool_name
455 ),
456 is_error: true,
457 }
458 }
459
460 pub async fn execute_tool_calls(&self, tool_calls: Vec<ToolCall>) -> Vec<ToolResult> {
462 let mut tool_results = Vec::new();
463
464 for tool_call in tool_calls {
466 let result = self
467 .execute_tool_call(&tool_call.name, &tool_call.id, tool_call.arguments.clone())
468 .await;
469 tool_results.push(result);
470 }
471
472 tool_results
473 }
474}