moly_kit/mcp/mcp_manager.rs

1#[cfg(not(target_arch = "wasm32"))]
2use rmcp::{
3    model::{CallToolRequestParam, CallToolResult},
4    service::{RoleClient, RunningService, ServiceExt},
5    transport::{
6        SseClientTransport, TokioChildProcess,
7        streamable_http_client::{StreamableHttpClientTransport, StreamableHttpClientWorker},
8    },
9};
10use serde_json::{Map, Value};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Mutex};
14
15use crate::protocol::{Tool, ToolCall, ToolResult};
16
/// Builds the registry key for a tool by joining the server id and the tool's
/// own name with a double underscore: `"<server_id>__<tool_name>"`.
///
/// Both parts are used verbatim; hyphens and letter case are kept as-is.
fn namespaced_name(server_id: &str, tool_name: &str) -> String {
    [server_id, tool_name].join("__")
}
22
/// Splits a namespaced tool name into `(server_id, tool_name)`.
///
/// The split happens at the *first* `"__"`, so any later double underscores
/// stay part of the tool name:
///
/// - `"filesystem__read_file"` -> `("filesystem", "read_file")`
/// - `"mcp-internet-speed__test-speed"` -> `("mcp-internet-speed", "test-speed")`
/// - `"a__b__c"` -> `("a", "b__c")`
///
/// As a consequence, a server id that itself contains `"__"` cannot be
/// recovered from the namespaced name alone.
///
/// # Errors
///
/// Returns an error when `namespaced_name` contains no `"__"` separator.
pub fn parse_namespaced_tool_name(
    namespaced_name: &str,
) -> Result<(String, String), Box<dyn std::error::Error>> {
    let Some((server_id, tool_name)) = namespaced_name.split_once("__") else {
        let message = format!(
            "Invalid namespaced tool name: '{}'. Expected format 'server_id__tool_name'",
            namespaced_name
        );
        return Err(message.into());
    };
    Ok((server_id.to_owned(), tool_name.to_owned()))
}
39
40/// Converts a namespaced tool name to a display-friendly format for UI
41/// "filesystem__read_file" -> "filesystem: read_file"
42/// "mcp-internet-speed__test-speed" -> "mcp-internet-speed: test-speed"
43pub fn display_name_from_namespaced(namespaced_name: &str) -> String {
44    if let Ok((server_id, tool_name)) = parse_namespaced_tool_name(namespaced_name) {
45        format!("{}: {}", server_id, tool_name)
46    } else {
47        // Fallback to original name if parsing fails
48        namespaced_name.to_string()
49    }
50}
51
52/// Parse tool arguments from JSON string to Map
53pub fn parse_tool_arguments(arguments: &str) -> Result<Map<String, Value>, String> {
54    match serde_json::from_str::<Value>(arguments) {
55        Ok(Value::Object(args)) => Ok(args),
56        Ok(_) => Err("Arguments must be a JSON object".to_string()),
57        Err(e) => Err(format!("Failed to parse arguments: {}", e)),
58    }
59}
60
/// One tool known to the [`ToolRegistry`], together with the server that provides it.
#[derive(Clone, Debug)]
pub struct ToolRegistryEntry {
    /// Id of the MCP server that exposes this tool.
    pub server_id: String,
    /// The tool's name exactly as reported by its server.
    pub original_name: String,
    /// Registry key: `server_id` and `original_name` joined by `"__"`
    /// (as built by `ToolRegistry::add_server_tools`).
    pub namespaced_name: String,
    /// The tool definition as discovered; its `name` is still the original,
    /// non-namespaced name.
    pub schema: Tool,
}
68
/// Index of the tools discovered from all MCP servers, keyed by namespaced name.
pub struct ToolRegistry {
    /// All tools, keyed by their namespaced name (`server_id__tool_name`).
    tools: HashMap<String, ToolRegistryEntry>,
    /// For each server id, the original (non-namespaced) names of the tools it registered.
    server_tools: HashMap<String, Vec<String>>,
}
75
76impl ToolRegistry {
77    fn new() -> Self {
78        Self {
79            tools: HashMap::new(),
80            server_tools: HashMap::new(),
81        }
82    }
83
84    fn add_server_tools(&mut self, server_id: &str, tools: Vec<Tool>) {
85        let mut tool_names = Vec::new();
86
87        for tool in tools {
88            let namespaced_name = namespaced_name(server_id, &tool.name);
89            let original_name = tool.name.clone();
90            let entry = ToolRegistryEntry {
91                server_id: server_id.to_string(),
92                original_name: original_name.clone(),
93                namespaced_name: namespaced_name.clone(),
94                schema: tool,
95            };
96
97            self.tools.insert(namespaced_name, entry);
98            tool_names.push(original_name);
99        }
100
101        self.server_tools.insert(server_id.to_string(), tool_names);
102    }
103
104    fn get_tool_entry(&self, namespaced_name: &str) -> Option<&ToolRegistryEntry> {
105        self.tools.get(namespaced_name)
106    }
107
108    fn get_all_tools(&self) -> Vec<Tool> {
109        self.tools
110            .values()
111            .map(|entry| {
112                let mut tool = entry.schema.clone();
113                tool.name = entry.namespaced_name.clone();
114                tool
115            })
116            .collect()
117    }
118
119    fn remove_server(&mut self, server_id: &str) {
120        if let Some(tool_names) = self.server_tools.remove(server_id) {
121            for tool_name in tool_names {
122                let namespaced_name = namespaced_name(server_id, &tool_name);
123                self.tools.remove(&namespaced_name);
124            }
125        }
126    }
127}
128
/// The transport used to reach an MCP server.
pub enum McpTransport {
    /// URL of a streamable-HTTP MCP endpoint.
    Http(String),
    /// URL of an SSE MCP endpoint.
    Sse(String),
    /// Command that launches the MCP server as a child process (native builds only).
    #[cfg(not(target_arch = "wasm32"))]
    Stdio(tokio::process::Command),
}
136
/// Type-erased rmcp client service, so that every transport yields the same
/// concrete service type.
#[cfg(not(target_arch = "wasm32"))]
type DynService = Box<dyn rmcp::service::DynService<RoleClient>>;

/// A running client connection to one MCP server.
#[cfg(not(target_arch = "wasm32"))]
type McpService = RunningService<RoleClient, DynService>;

/// Shared handle to a running connection. Callers clone it out of the
/// registry so the registry lock is not held while awaiting requests.
#[cfg(not(target_arch = "wasm32"))]
type McpServiceHandle = Arc<McpService>;

/// Running connections keyed by server id.
#[cfg(not(target_arch = "wasm32"))]
type McpServiceRegistry = HashMap<String, McpServiceHandle>;
148
/// State shared by every clone of [`McpManagerClient`].
struct McpManagerInner {
    /// Running server connections keyed by the id passed to `add_server`.
    #[cfg(not(target_arch = "wasm32"))]
    services: Mutex<McpServiceRegistry>,
    /// Namespaced tool index used to route tool calls to their server.
    #[cfg(not(target_arch = "wasm32"))]
    registry: Mutex<ToolRegistry>,
    /// Tools cached by the most recent `list_tools` call (names as reported by the servers).
    latest_tools: Mutex<Vec<Tool>>,
    /// "Dangerous mode" flag. In the code shown here it is only set and read back
    /// through `set_/get_dangerous_mode_enabled`.
    /// NOTE(review): presumably lets tool calls run without user confirmation; confirm with callers.
    dangerous_mode_enabled: AtomicBool,
}
157
/// Manages MCP servers and provides a unified interface for tool discovery and invocation.
///
/// This is a cheap handle: cloning it clones an `Arc`, so all clones share the
/// same servers, tool registry and settings.
pub struct McpManagerClient {
    inner: Arc<McpManagerInner>,
}
162
163impl Clone for McpManagerClient {
164    fn clone(&self) -> Self {
165        Self {
166            inner: Arc::clone(&self.inner),
167        }
168    }
169}
170
171impl McpManagerClient {
172    pub fn new() -> Self {
173        Self {
174            inner: Arc::new(McpManagerInner {
175                #[cfg(not(target_arch = "wasm32"))]
176                services: Mutex::new(HashMap::new()),
177                #[cfg(not(target_arch = "wasm32"))]
178                registry: Mutex::new(ToolRegistry::new()),
179                latest_tools: Mutex::new(Vec::new()),
180                dangerous_mode_enabled: AtomicBool::new(false),
181            }),
182        }
183    }
184
185    /// Registers a new MCP server in the registry, and discovers tools from the server.
186    #[cfg(not(target_arch = "wasm32"))]
187    pub async fn add_server(
188        &self,
189        id: &str,
190        transport: McpTransport,
191    ) -> Result<(), Box<dyn std::error::Error>> {
192        let running_service = match transport {
193            McpTransport::Http(url) => {
194                let worker = StreamableHttpClientWorker::<reqwest::Client>::new_simple(url);
195                let transport = StreamableHttpClientTransport::spawn(worker);
196                ().into_dyn().serve(transport).await?
197            }
198            McpTransport::Sse(url) => {
199                let transport = SseClientTransport::start(url).await?;
200                ().into_dyn().serve(transport).await?
201            }
202            McpTransport::Stdio(command) => {
203                let transport = TokioChildProcess::new(command)?;
204                ().into_dyn().serve(transport).await?
205            }
206        };
207
208        self.inner
209            .services
210            .lock()
211            .unwrap()
212            .insert(id.to_string(), Arc::new(running_service));
213
214        // Discover tools from the newly added server
215        match self.discover_tools_for_server(id).await {
216            Ok(tools) => {
217                self.inner
218                    .registry
219                    .lock()
220                    .unwrap()
221                    .add_server_tools(id, tools);
222                ::log::debug!("Successfully discovered tools for MCP server: {}", id);
223            }
224            Err(e) => {
225                ::log::warn!("Failed to discover tools for MCP server '{}': {}", id, e);
226                // Don't fail the entire server addition if tool discovery fails
227            }
228        }
229
230        Ok(())
231    }
232
233    #[cfg(target_arch = "wasm32")]
234    pub async fn add_server(
235        &self,
236        id: &str,
237        _transport: McpTransport,
238    ) -> Result<(), Box<dyn std::error::Error>> {
239        let _ = id;
240        Err("MCP servers are not supported in web builds".into())
241    }
242
243    pub fn set_dangerous_mode_enabled(&self, enabled: bool) {
244        self.inner
245            .dangerous_mode_enabled
246            .store(enabled, Ordering::Relaxed);
247    }
248
249    pub fn get_dangerous_mode_enabled(&self) -> bool {
250        self.inner.dangerous_mode_enabled.load(Ordering::Relaxed)
251    }
252
253    /// Discovers tools from an MCP server.
254    #[cfg(not(target_arch = "wasm32"))]
255    async fn discover_tools_for_server(
256        &self,
257        server_id: &str,
258    ) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
259        let service = {
260            let services_guard = self.inner.services.lock().unwrap();
261            services_guard.get(server_id).map(|s| Arc::clone(s))
262        };
263
264        let Some(service) = service else {
265            return Err(format!("Server '{}' not found", server_id).into());
266        };
267
268        let list_tools_result = service.list_tools(Default::default()).await?;
269
270        let tools: Vec<Tool> = list_tools_result
271            .tools
272            .into_iter()
273            .map(|rmcp_tool| rmcp_tool.into())
274            .collect();
275
276        Ok(tools)
277    }
278
279    /// Lists and caches tools from all connected MCP servers.
280    #[cfg(not(target_arch = "wasm32"))]
281    pub async fn list_tools(&self) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
282        let services: Vec<McpServiceHandle> = {
283            let services_guard = self.inner.services.lock().unwrap();
284            services_guard.values().map(|s| Arc::clone(s)).collect()
285        };
286
287        let mut all_tools = Vec::new();
288        for service in services {
289            match service.list_tools(Default::default()).await {
290                Ok(list_tools_result) => {
291                    // Convert rmcp tools to our unified Tool type
292                    let converted_tools: Vec<Tool> = list_tools_result
293                        .tools
294                        .into_iter()
295                        .map(|rmcp_tool| rmcp_tool.into())
296                        .collect();
297                    all_tools.extend(converted_tools);
298                }
299                Err(e) => {
300                    ::log::warn!("Failed to list tools from server: {}", e);
301                    // Continue with other servers
302                }
303            }
304        }
305
306        *self.inner.latest_tools.lock().unwrap() = all_tools.clone();
307        Ok(all_tools)
308    }
309
310    #[cfg(target_arch = "wasm32")]
311    pub async fn list_tools(&self) -> Result<Vec<Tool>, Box<dyn std::error::Error>> {
312        Ok(Vec::new())
313    }
314
315    pub fn get_latest_tools(&self) -> Vec<Tool> {
316        self.inner.latest_tools.lock().unwrap().clone()
317    }
318
319    #[cfg(not(target_arch = "wasm32"))]
320    pub fn get_all_namespaced_tools(&self) -> Vec<Tool> {
321        self.inner.registry.lock().unwrap().get_all_tools()
322    }
323
324    #[cfg(target_arch = "wasm32")]
325    pub fn get_all_namespaced_tools(&self) -> Vec<Tool> {
326        Vec::new()
327    }
328
329    /// Calls a tool on an MCP server.
330    #[cfg(not(target_arch = "wasm32"))]
331    async fn call_tool(
332        &self,
333        namespaced_tool_name: &str,
334        arguments: serde_json::Map<String, serde_json::Value>,
335    ) -> Result<CallToolResult, Box<dyn std::error::Error>> {
336        // Parse the namespaced tool name to get server_id and original tool name
337        let (server_id, original_tool_name) = parse_namespaced_tool_name(namespaced_tool_name)?;
338
339        // Get the tool entry from registry for validation
340        let tool_entry = {
341            let registry = self.inner.registry.lock().unwrap();
342            registry.get_tool_entry(namespaced_tool_name).cloned()
343        };
344
345        let Some(_tool_entry) = tool_entry else {
346            return Err(format!("Tool '{}' not found in registry. Available tools can be retrieved with get_all_namespaced_tools()", namespaced_tool_name).into());
347        };
348
349        // Get the specific server
350        let service = {
351            let services_guard = self.inner.services.lock().unwrap();
352            services_guard.get(&server_id).map(|s| Arc::clone(s))
353        };
354
355        let Some(service) = service else {
356            return Err(format!("MCP server '{}' not found or disconnected", server_id).into());
357        };
358
359        // TODO: Add argument validation against tool_entry.schema here
360
361        // Call the tool directly on the service
362        let request = CallToolRequestParam {
363            name: original_tool_name.clone().into(),
364            arguments: Some(arguments),
365        };
366
367        match service.call_tool(request).await {
368            Ok(result) => Ok(result),
369            Err(e) => {
370                let error_message = format!(
371                    "Tool '{}' failed on server '{}': {}",
372                    original_tool_name, server_id, e
373                );
374                Err(error_message.into())
375            }
376        }
377    }
378
379    #[cfg(target_arch = "wasm32")]
380    pub async fn call_tool(
381        &self,
382        tool_name: &str,
383        _arguments: serde_json::Map<String, serde_json::Value>,
384    ) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
385        Err(format!(
386            "MCP servers are not yet supported in WASM builds. Cannot call tool '{}'",
387            tool_name
388        )
389        .into())
390    }
391
392    #[cfg(not(target_arch = "wasm32"))]
393    pub async fn remove_server(&self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
394        self.inner.services.lock().unwrap().remove(id);
395        self.inner.registry.lock().unwrap().remove_server(id);
396        Ok(())
397    }
398
399    #[cfg(target_arch = "wasm32")]
400    pub async fn remove_server(&self, _id: &str) -> Result<(), Box<dyn std::error::Error>> {
401        Ok(())
402    }
403
404    /// Executes a tool call and returns the result
405    #[cfg(not(target_arch = "wasm32"))]
406    pub async fn execute_tool_call(
407        &self,
408        tool_name: &str,
409        tool_call_id: &str,
410        arguments: Map<String, Value>,
411    ) -> ToolResult {
412        match self.call_tool(tool_name, arguments).await {
413            Ok(result) => {
414                // Convert result to content string
415                let content = result
416                    .content
417                    .iter()
418                    .filter_map(|item| {
419                        // Convert ContentPart to text - for now we just serialize it
420                        if let Ok(text) = serde_json::to_string(item) {
421                            Some(text)
422                        } else {
423                            None
424                        }
425                    })
426                    .collect::<Vec<_>>()
427                    .join("\n");
428
429                ToolResult {
430                    tool_call_id: tool_call_id.to_string(),
431                    content,
432                    is_error: false,
433                }
434            }
435            Err(e) => ToolResult {
436                tool_call_id: tool_call_id.to_string(),
437                content: e.to_string(),
438                is_error: true,
439            },
440        }
441    }
442
443    #[cfg(target_arch = "wasm32")]
444    pub async fn execute_tool_call(
445        &self,
446        tool_name: &str,
447        tool_call_id: &str,
448        _arguments: Map<String, Value>,
449    ) -> ToolResult {
450        ToolResult {
451            tool_call_id: tool_call_id.to_string(),
452            content: format!(
453                "MCP servers are not yet supported in WASM builds. Cannot call tool '{}'",
454                tool_name
455            ),
456            is_error: true,
457        }
458    }
459
460    /// Executes multiple tool calls sequentially and returns the results
461    pub async fn execute_tool_calls(&self, tool_calls: Vec<ToolCall>) -> Vec<ToolResult> {
462        let mut tool_results = Vec::new();
463
464        // Execute all tool calls sequentially
465        for tool_call in tool_calls {
466            let result = self
467                .execute_tool_call(&tool_call.name, &tool_call.id, tool_call.arguments.clone())
468                .await;
469            tool_results.push(result);
470        }
471
472        tool_results
473    }
474}