Coverage for class_generator/utils.py: 92%

135 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-07-29 12:31 +0300

1"""Utilities for class generator.""" 

2 

3import ast 

4from dataclasses import dataclass, field 

5from pathlib import Path 

6from typing import Union 

7 

8from simple_logger.logger import get_logger 

9 

10from class_generator.constants import PYTHON_KEYWORD_MAPPINGS, VERSION_PRIORITY 

11 

# Module-level logger, named after this module; used by the scanner below to report
# files that fail to parse or analyze.
LOGGER = get_logger(name=__name__)

13 

14 

def sanitize_python_name(name: str) -> tuple[str, str]:
    """Translate a name that clashes with a Python keyword into a safe identifier.

    The replacement comes from ``PYTHON_KEYWORD_MAPPINGS``; names not listed
    there are returned unchanged.

    Args:
        name: Candidate identifier.

    Returns:
        A ``(safe_name, original_name)`` pair, where ``original_name`` is
        always the ``name`` that was passed in.
    """
    safe_name = PYTHON_KEYWORD_MAPPINGS[name] if name in PYTHON_KEYWORD_MAPPINGS else name
    return safe_name, name

20 

21 

def get_latest_version(versions: list[str]) -> str:
    """
    Get the latest version from a list of Kubernetes API versions.

    Each entry is ranked through ``VERSION_PRIORITY`` using only the text after
    its last ``/`` (so ``"apps/v1"`` is ranked as ``"v1"``). Entries missing from
    ``VERSION_PRIORITY`` rank as 0. The documented precedence (newest to oldest)
    is v2 > v1 > v1beta2 > v1beta1 > v1alpha2 > v1alpha1.

    Args:
        versions: Candidate API versions, optionally prefixed with an API group.

    Returns:
        The highest-ranked entry, exactly as it appears in ``versions``. When
        several entries share the top rank, the first of them wins. Returns
        ``""`` for an empty list.
    """
    if not versions:
        return ""

    # max() is a single O(n) pass and, like the stable reverse sort it replaces,
    # returns the earliest of equally-ranked entries.
    return max(versions, key=lambda version: VERSION_PRIORITY.get(version.split("/")[-1], 0))

36 

37 

@dataclass
class ResourceInfo:
    """Everything discovered about one resource class while scanning a resources directory."""

    # --- identity and location ---
    name: str  # Class name, e.g. "Pod" or "Namespace"
    file_path: str  # Path of the module that defines the class
    base_class: str  # Either "Resource" or "NamespacedResource"

    # --- API coordinates (None when not declared on the class) ---
    api_version: Union[str, None] = field(default=None)
    api_group: Union[str, None] = field(default=None)

    # --- constructor parameters ---
    required_params: list[str] = field(default_factory=list)
    optional_params: list[str] = field(default_factory=list)
    has_containers: bool = field(default=False)

    # --- ephemeral resources: creating one yields a different resource type ---
    is_ephemeral: bool = field(default=False)  # e.g. ProjectRequest
    actual_resource_type: Union[str, None] = field(default=None)  # Type actually created, e.g. "Project"

52 

53 

54class ResourceScanner: 

55 """Scans ocp_resources directory to discover resource classes""" 

56 

57 def __init__(self, ocp_resources_path: str = "ocp_resources"): 

58 self.ocp_resources_path = Path(ocp_resources_path) 

59 self.exclude_files = {"__init__.py", "resource.py", "exceptions.py", "utils"} 

60 

61 def scan_resources(self) -> list[ResourceInfo]: 

62 """Scan ocp_resources directory and extract all resource classes""" 

63 resources = [] 

64 

65 for py_file in self.ocp_resources_path.glob("*.py"): 

66 if py_file.name in self.exclude_files: 

67 continue 

68 

69 try: 

70 resource_info = self._analyze_resource_file(py_file) 

71 if resource_info: 

72 resources.append(resource_info) 

73 except Exception as e: 

74 LOGGER.warning(f"Failed to analyze {py_file}: {e}") 

75 

76 return sorted(resources, key=lambda r: r.name) 

77 

78 def _analyze_resource_file(self, file_path: Path) -> Union[ResourceInfo, None]: 

79 """Analyze a single resource file to extract class information""" 

80 with open(file_path, "r", encoding="utf-8") as f: 

81 content = f.read() 

82 

83 # Only consider resources with the generated marker comment 

84 if ( 

85 "# Generated using https://github.com/RedHatQE/openshift-python-wrapper/blob/main/scripts/resource/README.md" 

86 not in content 

87 ): 

88 return None 

89 

90 try: 

91 tree = ast.parse(content) 

92 except SyntaxError as e: 

93 LOGGER.error(f"Syntax error in {file_path}: {e}") 

94 return None 

95 

96 # Find resource classes 

97 for node in ast.walk(tree): 

98 if isinstance(node, ast.ClassDef): 

99 # Check if it inherits from Resource or NamespacedResource 

100 base_classes = [] 

101 for base in node.bases: 

102 if isinstance(base, ast.Name): 

103 base_classes.append(base.id) 

104 elif isinstance(base, ast.Attribute): 

105 base_classes.append(base.attr) 

106 

107 if "Resource" in base_classes or "NamespacedResource" in base_classes: 

108 return self._extract_resource_info(node, file_path, content) 

109 

110 return None 

111 

112 def _extract_resource_info(self, class_node: ast.ClassDef, file_path: Path, content: str) -> ResourceInfo: 

113 """Extract detailed information from a resource class""" 

114 name = class_node.name 

115 # Determine base class type 

116 base_class = "Resource" 

117 for base in class_node.bases: 

118 if isinstance(base, ast.Name) and base.id == "NamespacedResource": 

119 base_class = "NamespacedResource" 

120 break 

121 if isinstance(base, ast.Attribute) and base.attr == "NamespacedResource": 

122 base_class = "NamespacedResource" 

123 break 

124 

125 # Analyze __init__ method for parameters 

126 required_params, optional_params, has_containers = self._analyze_init_method(class_node) 

127 

128 # Analyze to_dict method for truly required parameters (those that raise MissingRequiredArgumentError) 

129 truly_required_params = self._analyze_to_dict_method(class_node) 

130 

131 # Override required_params with what's actually required in to_dict() 

132 if truly_required_params: 

133 required_params = truly_required_params 

134 

135 # Extract API version and group from class attributes or content 

136 api_version, api_group = self._extract_api_info(class_node, content) 

137 

138 # Detect ephemeral resources 

139 is_ephemeral, actual_resource_type = self._handle_ephemeral_resource(name) 

140 

141 return ResourceInfo( 

142 name=name, 

143 file_path=str(file_path), 

144 base_class=base_class, 

145 api_version=api_version, 

146 api_group=api_group, 

147 required_params=required_params, 

148 optional_params=optional_params, 

149 has_containers=has_containers, 

150 is_ephemeral=is_ephemeral, 

151 actual_resource_type=actual_resource_type, 

152 ) 

153 

154 def _analyze_init_method(self, class_node: ast.ClassDef) -> tuple[list[str], list[str], bool]: 

155 """Analyze __init__ method to find required and optional parameters""" 

156 required_params = [] 

157 optional_params = [] 

158 has_containers = False 

159 

160 for node in class_node.body: 

161 if isinstance(node, ast.FunctionDef) and node.name == "__init__": 

162 # Skip 'self' and '**kwargs' 

163 for arg in node.args.args[1:]: 

164 if arg.arg == "kwargs": 

165 continue 

166 param_name = arg.arg 

167 

168 # Check if parameter has default value by looking at defaults 

169 # In AST, defaults align with the end of args list 

170 defaults_start_idx = len(node.args.args) - len(node.args.defaults) 

171 arg_idx = node.args.args.index(arg) 

172 

173 if arg_idx >= defaults_start_idx: 

174 optional_params.append(param_name) 

175 else: 

176 required_params.append(param_name) 

177 

178 if param_name == "containers": 

179 has_containers = True 

180 

181 return required_params, optional_params, has_containers 

182 

183 def _analyze_to_dict_method(self, class_node: ast.ClassDef) -> list[str]: 

184 """Analyze to_dict method to find truly required parameters""" 

185 truly_required = [] 

186 

187 for node in class_node.body: 

188 if isinstance(node, ast.FunctionDef) and node.name == "to_dict": 

189 # Look for MissingRequiredArgumentError raises 

190 for stmt in ast.walk(node): 

191 if isinstance(stmt, ast.Raise): 

192 # Check if raising MissingRequiredArgumentError 

193 if isinstance(stmt.exc, ast.Call): 

194 if (isinstance(stmt.exc.func, ast.Name) and 

195 stmt.exc.func.id == "MissingRequiredArgumentError"): 

196 # Extract the parameter name from the argument 

197 for keyword in stmt.exc.keywords: 

198 if keyword.arg == "argument": 

199 # Handle string format like "self.param_name" 

200 if isinstance(keyword.value, ast.Constant): 

201 param = keyword.value.value 

202 if param.startswith("self."): 

203 param = param[5:] # Remove "self." 

204 truly_required.append(param) 

205 

206 return truly_required 

207 

208 def _extract_api_info(self, class_node: ast.ClassDef, content: str) -> tuple[Union[str, None], Union[str, None]]: 

209 """Extract API version and group from class attributes""" 

210 api_version = None 

211 api_group = None 

212 

213 # Look for api_version or api_group class attributes 

214 for node in class_node.body: 

215 if isinstance(node, ast.Assign): 

216 for target in node.targets: 

217 if isinstance(target, ast.Name): 

218 if target.id == "api_version" and isinstance(node.value, ast.Attribute): 

219 # Extract version like Resource.ApiVersion.V1 

220 if isinstance(node.value.attr, str): 

221 api_version = node.value.attr.lower() 

222 elif target.id == "api_group" and isinstance(node.value, ast.Attribute): 

223 # Extract group like NamespacedResource.ApiGroup.APPS 

224 if isinstance(node.value.attr, str): 

225 api_group = node.value.attr.lower().replace("_", ".") 

226 

227 return api_version, api_group 

228 

229 def _handle_ephemeral_resource(self, name: str) -> tuple[bool, Union[str, None]]: 

230 """Check if resource is ephemeral and get actual resource type""" 

231 # Simple mapping for known ephemeral resources 

232 ephemeral_resources = { 

233 "ProjectRequest": "Project", 

234 } 

235 

236 if name in ephemeral_resources: 

237 return True, ephemeral_resources[name] 

238 

239 return False, None