Coverage for src/workstack/cli/shell_integration/handler.py: 87%

87 statements  

Generated by coverage.py v7.11.0 on 2025-10-19 at 09:31 -0400.

1import os 

2import shlex 

3from dataclasses import dataclass 

4from pathlib import Path 

5from typing import Final 

6 

7from click.testing import CliRunner 

8 

9from workstack.cli.commands.create import create 

10from workstack.cli.commands.down import down_cmd 

11from workstack.cli.commands.jump import jump_cmd 

12from workstack.cli.commands.prepare_cwd_recovery import generate_recovery_script 

13from workstack.cli.commands.switch import switch_cmd 

14from workstack.cli.commands.up import up_cmd 

15from workstack.cli.debug import debug_log 

16from workstack.cli.shell_utils import ( 

17 STALE_SCRIPT_MAX_AGE_SECONDS, 

18 cleanup_stale_scripts, 

19 write_script_to_temp, 

20) 

21from workstack.core.context import create_context 

22 

# Sentinel token exported by this module. Nothing else in this file reads it;
# presumably the shell wrapper or CLI entry point compares against it —
# confirm at the call sites before changing the value.
PASSTHROUGH_MARKER: Final[str] = "__WORKSTACK_PASSTHROUGH__"

# Command names with no `--script` variant. For these, the handler writes a
# wrapper script that runs the real command and then performs cwd recovery,
# instead of simply passing the request through.
PASSTHROUGH_COMMANDS: Final[set[str]] = {"sync"}

25 

26 

27@dataclass(frozen=True) 

28class ShellIntegrationResult: 

29 """Result returned by shell integration handlers.""" 

30 

31 passthrough: bool 

32 script: str | None 

33 exit_code: int 

34 

35 

def _invoke_hidden_command(command_name: str, args: tuple[str, ...]) -> ShellIntegrationResult:
    """Invoke a workstack command with ``--script`` for shell integration.

    Help requests (``-h``/``--help``) and invocations that already pass
    ``--script`` are handed back to the regular command (passthrough).
    Commands without a ``--script`` variant get a generated passthrough
    wrapper script when listed in ``PASSTHROUGH_COMMANDS``, and plain
    passthrough otherwise. Every other command is run in-process with
    ``--script`` appended, and its stdout is taken as the path of the
    activation script it wrote.

    Args:
        command_name: First CLI token, e.g. ``"switch"``.
        args: Remaining CLI tokens, forwarded to the command unchanged.

    Returns:
        ``passthrough=True`` with the command's exit code when the command
        fails, so the regular command can re-run and show its own error.
        Otherwise ``passthrough=False`` with ``script`` set to the stripped
        stdout, or None when the command printed nothing.
    """
    # Help output and explicit --script calls must reach the real command.
    if "-h" in args or "--help" in args or "--script" in args:
        return ShellIntegrationResult(passthrough=True, script=None, exit_code=0)

    # Commands that understand --script and print an activation script path.
    command_map = {
        "switch": switch_cmd,
        "create": create,
        "jump": jump_cmd,
        "up": up_cmd,
        "down": down_cmd,
    }

    command = command_map.get(command_name)
    if command is None:
        if command_name in PASSTHROUGH_COMMANDS:
            return _build_passthrough_script(command_name, args)
        return ShellIntegrationResult(passthrough=True, script=None, exit_code=0)

    script_args = [*args, "--script"]

    debug_log(f"Handler: Invoking {command_name} with args: {script_args}")

    # Opportunistic cleanup of scripts left behind by earlier invocations.
    cleanup_stale_scripts(max_age_seconds=STALE_SCRIPT_MAX_AGE_SECONDS)

    result = CliRunner().invoke(
        command,
        script_args,
        obj=create_context(dry_run=False),
        standalone_mode=False,
    )

    exit_code = int(result.exit_code)

    # With standalone_mode=False, Click returns the code passed to ctx.exit(n)
    # as the command's return value instead of raising SystemExit, so the
    # runner reports exit_code 0. Recover that code so a failure is not
    # mistaken for success. bool is excluded because it subclasses int.
    returned = result.return_value
    if (
        exit_code == 0
        and isinstance(returned, int)
        and not isinstance(returned, bool)
        and returned != 0
    ):
        exit_code = returned

    # If the command failed, pass through so the regular command shows the error.
    if exit_code != 0:
        return ShellIntegrationResult(passthrough=True, script=None, exit_code=exit_code)

    # The command prints the script file path, not the script content.
    # Read stdout rather than Result.output: in Click >= 8.2, output also
    # contains everything written to stderr, which would corrupt the path.
    # On older Click with the default mix_stderr=True the two are identical.
    # Whitespace-only output becomes None rather than an empty-string path.
    script_path = result.stdout.strip() or None

    debug_log(f"Handler: Got script_path={script_path}, exit_code={exit_code}")
    if script_path is not None:
        debug_log(f"Handler: Script exists? {Path(script_path).exists()}")

    return ShellIntegrationResult(passthrough=False, script=script_path, exit_code=exit_code)

92 

93 

def handle_shell_request(args: tuple[str, ...]) -> ShellIntegrationResult:
    """Route a shell-integration request to the handler for its command.

    ``args`` is the original CLI invocation split into tokens. The first token
    names the command and the remaining tokens are forwarded to it. An empty
    invocation is passed straight through.
    """
    if args:
        head, *tail = args
        return _invoke_hidden_command(head, tuple(tail))
    return ShellIntegrationResult(passthrough=True, script=None, exit_code=0)

103 

104 

def _build_passthrough_script(command_name: str, args: tuple[str, ...]) -> ShellIntegrationResult:
    """Write a wrapper script that runs ``workstack <command_name> <args>``.

    The script dialect comes from the ``WORKSTACK_SHELL`` environment variable
    (lower-cased, defaulting to ``"bash"``). A recovery script is generated
    first. The wrapper sources it only if the current directory no longer
    exists after the command has run.

    Returns:
        A non-passthrough result whose ``script`` is the path of the written
        wrapper and whose ``exit_code`` is 0.
    """
    requested_shell = os.environ.get("WORKSTACK_SHELL", "bash").lower()
    recovery_path = generate_recovery_script(create_context(dry_run=False))

    rendered = _render_passthrough_script(requested_shell, command_name, args, recovery_path)
    written_path = write_script_to_temp(
        rendered,
        command_name=f"{command_name}-passthrough",
        comment="generated by __shell passthrough handler",
    )
    return ShellIntegrationResult(
        passthrough=False,
        script=str(written_path),
        exit_code=0,
    )

118 

119 

def _render_passthrough_script(
    shell_name: str,
    command_name: str,
    args: tuple[str, ...],
    recovery_path: Path | None,
) -> str:
    """Render the passthrough wrapper in the dialect matching ``shell_name``.

    Exactly ``"fish"`` selects fish syntax. Any other value falls back to
    POSIX shell syntax.
    """
    renderer = _render_fish_passthrough if shell_name == "fish" else _render_posix_passthrough
    return renderer(command_name, args, recovery_path)

130 

131 

132def _render_posix_passthrough( 

133 command_name: str, 

134 args: tuple[str, ...], 

135 recovery_path: Path | None, 

136) -> str: 

137 quoted_args = " ".join(shlex.quote(part) for part in (command_name, *args)) 

138 recovery_literal = shlex.quote(str(recovery_path)) if recovery_path is not None else "''" 

139 lines = [ 

140 f"command workstack {quoted_args}", 

141 "__workstack_exit=$?", 

142 f"__workstack_recovery={recovery_literal}", 

143 'if [ -n "$__workstack_recovery" ] && [ -f "$__workstack_recovery" ]; then', 

144 ' if [ ! -d "$PWD" ]; then', 

145 ' . "$__workstack_recovery"', 

146 " fi", 

147 ' if [ -z "$WORKSTACK_KEEP_SCRIPTS" ]; then', 

148 ' rm -f "$__workstack_recovery"', 

149 " fi", 

150 "fi", 

151 "return $__workstack_exit", 

152 ] 

153 return "\n".join(lines) + "\n" 

154 

155 

156def _quote_fish(arg: str) -> str: 

157 if not arg: 

158 return '""' 

159 

160 escape_map = { 

161 "\\": "\\\\", 

162 '"': '\\"', 

163 "$": "\\$", 

164 "`": "\\`", 

165 "~": "\\~", 

166 "*": "\\*", 

167 "?": "\\?", 

168 "{": "\\{", 

169 "}": "\\}", 

170 "[": "\\[", 

171 "]": "\\]", 

172 "(": "\\(", 

173 ")": "\\)", 

174 "<": "\\<", 

175 ">": "\\>", 

176 "|": "\\|", 

177 ";": "\\;", 

178 "&": "\\&", 

179 } 

180 escaped_parts: list[str] = [] 

181 for char in arg: 

182 if char == "\n": 

183 escaped_parts.append("\\n") 

184 continue 

185 if char == "\t": 

186 escaped_parts.append("\\t") 

187 continue 

188 escaped_parts.append(escape_map.get(char, char)) 

189 

190 escaped = "".join(escaped_parts) 

191 return f'"{escaped}"' 

192 

193 

def _render_fish_passthrough(
    command_name: str,
    args: tuple[str, ...],
    recovery_path: Path | None,
) -> str:
    """Build the fish flavour of the passthrough wrapper script.

    The script:
    - runs ``command workstack`` with fish-quoted arguments and saves
      ``$status``;
    - if a recovery script path was given and that file exists, sources it
      when ``$PWD`` is gone, and deletes it unless ``WORKSTACK_KEEP_SCRIPTS``
      is set;
    - ends with ``return`` of the saved status.

    The returned text ends with a newline.
    """
    quoted_words = " ".join([_quote_fish(word) for word in (command_name, *args)])
    if recovery_path is None:
        recovery_value = '""'
    else:
        recovery_value = _quote_fish(str(recovery_path))

    script_lines = (
        "command workstack " + quoted_words,
        "set __workstack_exit $status",
        "set __workstack_recovery " + recovery_value,
        'if test -n "$__workstack_recovery"',
        '  if test -f "$__workstack_recovery"',
        '    if not test -d "$PWD"',
        '      source "$__workstack_recovery"',
        "    end",
        "    if not set -q WORKSTACK_KEEP_SCRIPTS",
        '      rm -f "$__workstack_recovery"',
        "    end",
        "  end",
        "end",
        "return $__workstack_exit",
    )
    return "".join(line + "\n" for line in script_lines)