use std::path::PathBuf; use serde::Deserialize; use serde::Serialize; #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum ShellType { Zsh, Bash, PowerShell, Sh, Cmd, } impl ShellType { pub fn name(self) -> &'static str { match self { Self::Zsh => "zsh", Self::Bash => "bash", Self::PowerShell => "powershell", Self::Sh => "sh", Self::Cmd => "cmd", } } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct DetectedShell { pub shell_type: ShellType, pub shell_path: PathBuf, } impl DetectedShell { pub fn name(&self) -> &'static str { self.shell_type.name() } } pub fn detect_shell_type(shell_path: impl AsRef) -> Option { let shell_path = shell_path.as_ref(); match shell_path.as_os_str().to_str() { Some("zsh") => Some(ShellType::Zsh), Some("sh") => Some(ShellType::Sh), Some("cmd") => Some(ShellType::Cmd), Some("bash") => Some(ShellType::Bash), Some("pwsh") => Some(ShellType::PowerShell), Some("powershell") => Some(ShellType::PowerShell), _ => { let shell_name = shell_path.file_stem(); if let Some(shell_name) = shell_name { let shell_name_path = std::path::Path::new(shell_name); if shell_name_path != shell_path { return detect_shell_type(shell_name_path); } } None } } } #[cfg(unix)] fn get_user_shell_path() -> Option { let uid = unsafe { libc::getuid() }; use std::ffi::CStr; use std::mem::MaybeUninit; use std::ptr; let mut passwd = MaybeUninit::::uninit(); // We cannot use getpwuid here: it returns pointers into libc-managed // storage, which is not safe to read concurrently on all targets (the musl // static build used by the CLI can segfault when parallel callers race on // that buffer). getpwuid_r keeps the passwd data in caller-owned memory. let suggested_buffer_len = unsafe { libc::sysconf(libc::_SC_GETPW_R_SIZE_MAX) }; let buffer_len = usize::try_from(suggested_buffer_len) .ok() .filter(|len| *len > 0) .unwrap_or(1024); let mut buffer = vec![0; buffer_len]; loop { let mut result = ptr::null_mut(); let status = unsafe { libc::getpwuid_r( uid, passwd.as_mut_ptr(), buffer.as_mut_ptr().cast(), buffer.len(), &mut result, ) }; if status == 0 { if result.is_null() { return None; } let passwd = unsafe { passwd.assume_init_ref() }; if passwd.pw_shell.is_null() { return None; } let shell_path = unsafe { CStr::from_ptr(passwd.pw_shell) } .to_string_lossy() .into_owned(); return Some(PathBuf::from(shell_path)); } if status != libc::ERANGE { return None; } // Retry with a larger buffer until libc can materialize the passwd entry. let new_len = buffer.len().checked_mul(2)?; if new_len > 1024 * 1024 { return None; } buffer.resize(new_len, 0); } } #[cfg(not(unix))] fn get_user_shell_path() -> Option { None } fn file_exists(path: &std::path::Path) -> Option { if std::fs::metadata(path).is_ok_and(|metadata| metadata.is_file()) { Some(PathBuf::from(path)) } else { None } } fn get_shell_path( shell_type: ShellType, provided_path: Option<&PathBuf>, binary_name: &str, fallback_paths: &[&str], ) -> Option { if let Some(path) = provided_path.and_then(|path| file_exists(path)) { return Some(path); } let default_shell_path = get_user_shell_path(); if let Some(default_shell_path) = default_shell_path && detect_shell_type(&default_shell_path) == Some(shell_type) && file_exists(&default_shell_path).is_some() { return Some(default_shell_path); } if let Ok(path) = which::which(binary_name) { return Some(path); } for path in fallback_paths { if let Some(path) = file_exists(std::path::Path::new(path)) { return Some(path); } } None } const ZSH_FALLBACK_PATHS: &[&str] = &["/bin/zsh"]; fn get_zsh_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Zsh, path, "zsh", ZSH_FALLBACK_PATHS); shell_path.map(|shell_path| DetectedShell { shell_type: ShellType::Zsh, shell_path, }) } const BASH_FALLBACK_PATHS: &[&str] = &["/bin/bash", "/usr/bin/bash"]; fn get_bash_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Bash, path, "bash", BASH_FALLBACK_PATHS); shell_path.map(|shell_path| DetectedShell { shell_type: ShellType::Bash, shell_path, }) } const SH_FALLBACK_PATHS: &[&str] = &["/bin/sh"]; fn get_sh_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Sh, path, "sh", SH_FALLBACK_PATHS); shell_path.map(|shell_path| DetectedShell { shell_type: ShellType::Sh, shell_path, }) } // Note the `pwsh` and `powershell` fallback paths are where the respective // shells are commonly installed on GitHub Actions Windows runners, but may not // be present on all Windows machines: // https://docs.github.com/en/actions/tutorials/build-and-test-code/powershell #[cfg(windows)] const PWSH_FALLBACK_PATHS: &[&str] = &[r#"C:\Program Files\PowerShell\7\pwsh.exe"#]; #[cfg(not(windows))] const PWSH_FALLBACK_PATHS: &[&str] = &["/usr/local/bin/pwsh"]; #[cfg(windows)] const POWERSHELL_FALLBACK_PATHS: &[&str] = &[r#"C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe"#]; #[cfg(not(windows))] const POWERSHELL_FALLBACK_PATHS: &[&str] = &[]; fn get_powershell_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::PowerShell, path, "pwsh", PWSH_FALLBACK_PATHS) .or_else(|| { get_shell_path( ShellType::PowerShell, path, "powershell", POWERSHELL_FALLBACK_PATHS, ) }); shell_path.map(|shell_path| DetectedShell { shell_type: ShellType::PowerShell, shell_path, }) } fn get_cmd_shell(path: Option<&PathBuf>) -> Option { let shell_path = get_shell_path(ShellType::Cmd, path, "cmd", &[]); shell_path.map(|shell_path| DetectedShell { shell_type: ShellType::Cmd, shell_path, }) } pub fn ultimate_fallback_shell() -> DetectedShell { if cfg!(windows) { DetectedShell { shell_type: ShellType::Cmd, shell_path: PathBuf::from("cmd.exe"), } } else { DetectedShell { shell_type: ShellType::Sh, shell_path: PathBuf::from("/bin/sh"), } } } pub fn get_shell_by_model_provided_path(shell_path: &PathBuf) -> DetectedShell { detect_shell_type(shell_path) .and_then(|shell_type| get_shell(shell_type, Some(shell_path))) .unwrap_or_else(ultimate_fallback_shell) } pub fn get_shell(shell_type: ShellType, path: Option<&PathBuf>) -> Option { match shell_type { ShellType::Zsh => get_zsh_shell(path), ShellType::Bash => get_bash_shell(path), ShellType::PowerShell => get_powershell_shell(path), ShellType::Sh => get_sh_shell(path), ShellType::Cmd => get_cmd_shell(path), } } pub fn default_user_shell() -> DetectedShell { default_user_shell_from_path(get_user_shell_path()) } pub fn default_user_shell_from_path(user_shell_path: Option) -> DetectedShell { if cfg!(windows) { get_shell(ShellType::PowerShell, /*path*/ None).unwrap_or_else(ultimate_fallback_shell) } else { let user_default_shell = user_shell_path .and_then(|shell| detect_shell_type(&shell)) .and_then(|shell_type| get_shell(shell_type, /*path*/ None)); let shell_with_fallback = if cfg!(target_os = "macos") { user_default_shell .or_else(|| get_shell(ShellType::Zsh, /*path*/ None)) .or_else(|| get_shell(ShellType::Bash, /*path*/ None)) } else { user_default_shell .or_else(|| get_shell(ShellType::Bash, /*path*/ None)) .or_else(|| get_shell(ShellType::Zsh, /*path*/ None)) }; shell_with_fallback.unwrap_or_else(ultimate_fallback_shell) } } #[cfg(test)] mod tests { use super::*; use pretty_assertions::assert_eq; #[test] fn test_detect_shell_type() { assert_eq!( detect_shell_type(PathBuf::from("zsh")), Some(ShellType::Zsh) ); assert_eq!( detect_shell_type(PathBuf::from("bash")), Some(ShellType::Bash) ); assert_eq!( detect_shell_type(PathBuf::from("pwsh")), Some(ShellType::PowerShell) ); assert_eq!( detect_shell_type(PathBuf::from("powershell")), Some(ShellType::PowerShell) ); assert_eq!(detect_shell_type(PathBuf::from("fish")), None); assert_eq!(detect_shell_type(PathBuf::from("other")), None); assert_eq!( detect_shell_type(PathBuf::from("/bin/zsh")), Some(ShellType::Zsh) ); assert_eq!( detect_shell_type(PathBuf::from("/bin/bash")), Some(ShellType::Bash) ); assert_eq!( detect_shell_type(PathBuf::from("/usr/bin/bash")), Some(ShellType::Bash) ); assert_eq!( detect_shell_type(PathBuf::from("powershell.exe")), Some(ShellType::PowerShell) ); assert_eq!( detect_shell_type(PathBuf::from(if cfg!(windows) { "C:\\windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe" } else { "/usr/local/bin/pwsh" })), Some(ShellType::PowerShell) ); assert_eq!( detect_shell_type(PathBuf::from("pwsh.exe")), Some(ShellType::PowerShell) ); assert_eq!( detect_shell_type(PathBuf::from("/usr/local/bin/pwsh")), Some(ShellType::PowerShell) ); assert_eq!( detect_shell_type(PathBuf::from("/bin/sh")), Some(ShellType::Sh) ); assert_eq!(detect_shell_type(PathBuf::from("sh")), Some(ShellType::Sh)); assert_eq!( detect_shell_type(PathBuf::from("cmd")), Some(ShellType::Cmd) ); assert_eq!( detect_shell_type(PathBuf::from("cmd.exe")), Some(ShellType::Cmd) ); } }