Coverage for structured_tutorials / models / tutorial.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-26 12:41 +0100

1# Copyright (c) 2025 Mathias Ertl 

2# Licensed under the MIT License. See LICENSE file for details. 

3 

4"""Module containing main tutorial model and global configuration models.""" 

5 

6from pathlib import Path 

7from typing import Annotated, Any 

8 

9from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator, model_validator 

10from pydantic_core.core_schema import ValidationInfo 

11from yaml import safe_load 

12 

13from structured_tutorials.models.base import default_tutorial_root_factory 

14from structured_tutorials.models.parts import AlternativeModel, PartModels, PromptModel, part_discriminator 

15from structured_tutorials.typing import Self 

16 

17 

class DocumentationConfigurationModel(BaseModel):
    """Initial configuration used when the tutorial is rendered as documentation.

    After validation, ``context`` always contains ``run=False`` and ``doc=True``. The keys ``user``,
    ``host``, ``cwd`` and ``prompt_template`` receive fallback values only if the tutorial did not set them.
    """

    model_config = ConfigDict(extra="forbid", title="Documentation Configuration")

    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Key/value pairs for the initial context when rendering templates.",
    )
    alternative_names: dict[str, str] = Field(
        default_factory=dict,
        description="Names for alternative keys, used in tab titles. By default, the key itself is used.",
    )

    @model_validator(mode="after")
    def set_default_context(self) -> Self:
        """Add documentation-mode flags and prompt defaults to the template context."""
        ctx = self.context

        # The mode flags are forced (not defaulted), overriding anything the user supplied.
        ctx.update(run=False, doc=True)

        # Overridable fallbacks used to render shell prompts in the documentation.
        fallbacks = {
            "user": "user",
            "host": "host",
            "cwd": "~",
            "prompt_template": (
                "{{ user }}@{{ host }}:{{ cwd }}{% if user == 'root' %}#{% else %}${% endif %} "
            ),
        }
        for key, fallback in fallbacks.items():
            ctx.setdefault(key, fallback)
        return self

43 

44 

class RuntimeConfigurationModel(BaseModel):
    """Initial configuration used when the tutorial is actually run.

    After validation, ``context`` always contains ``doc=False`` and ``run=True``. It also contains ``cwd``,
    set to the process's current working directory at validation time. All three values override anything
    the user supplied for those keys.
    """

    model_config = ConfigDict(extra="forbid", title="Runtime Configuration")

    context: dict[str, Any] = Field(
        default_factory=dict, description="Key/value pairs for the initial context when rendering templates."
    )
    temporary_directory: bool = Field(
        default=False, description="Switch to an empty temporary directory before running the tutorial."
    )
    git_export: bool = Field(
        default=False,
        description="Export a git archive to a temporary directory before running the tutorial.",
    )
    environment: dict[str, str | None] = Field(
        default_factory=dict,
        # The two literals are concatenated implicitly; the trailing space keeps the sentences apart.
        description="Additional environment variables for all commands. "
        "Set a value to `None` to clear it from the global environment.",
    )
    clear_environment: bool = Field(default=False, description="Clear the environment for all commands.")

    @model_validator(mode="after")
    def set_default_context(self) -> Self:
        """Force runtime-mode flags and the current working directory into the template context."""
        self.context["doc"] = False
        self.context["run"] = True
        # Evaluated every time this model is validated, so it reflects the cwd at that moment.
        self.context["cwd"] = Path.cwd()
        return self

73 

74 

class ConfigurationModel(BaseModel):
    """Initial configuration of a tutorial, split into runtime (``run``) and documentation (``doc``) parts."""

    model_config = ConfigDict(extra="forbid", title="Tutorial Configuration")

    # Use default_factory so each sub-model is built (and its validator run) per instance. An instance
    # default would be built once at import time, freezing e.g. the runtime context's ``Path.cwd()`` of
    # that moment into every tutorial that does not configure ``run`` explicitly.
    run: RuntimeConfigurationModel = Field(default_factory=RuntimeConfigurationModel)
    doc: DocumentationConfigurationModel = Field(default_factory=DocumentationConfigurationModel)
    context: dict[str, Any] = Field(
        default_factory=dict, description="Initial context for both documentation and runtime."
    )

85 

86 

class TutorialModel(BaseModel):
    """Root structure for the entire tutorial."""

    model_config = ConfigDict(extra="forbid", title="Tutorial")

    # Absolute path to the YAML file; ``from_file`` fills it in and ``validate_path`` rejects relative paths.
    path: Path = Field(
        description="Absolute path to the tutorial file. This field is populated automatically while loading the tutorial.",  # noqa: E501
    )
    # User input must be relative to the tutorial file's directory; ``resolve_tutorial_root`` turns it into
    # an absolute, resolved path. The factory default is not passed through that validator, so it is
    # presumably already absolute — NOTE(review): confirm in default_tutorial_root_factory.
    tutorial_root: Path = Field(
        default_factory=default_tutorial_root_factory,
        description="Directory from which relative file paths are resolved. Defaults to the path of the "
        "tutorial file.",
    )
    # Each part is dispatched to a concrete model by ``part_discriminator``.
    parts: tuple[
        Annotated[
            PartModels
            | Annotated[PromptModel, Tag("prompt")]
            | Annotated[AlternativeModel, Tag("alternatives")],
            Discriminator(part_discriminator),
        ],
        ...,
    ] = Field(description="The individual parts of this tutorial.")
    # A fresh configuration per tutorial instead of one instance constructed at import time.
    configuration: ConfigurationModel = Field(default_factory=ConfigurationModel)

    @field_validator("path", mode="after")
    @classmethod
    def validate_path(cls, value: Path, info: ValidationInfo) -> Path:
        """Reject relative tutorial paths.

        Raises:
            ValueError: if ``value`` is not absolute.
        """
        if not value.is_absolute():
            raise ValueError(f"{value}: Must be an absolute path.")
        return value

    @field_validator("tutorial_root", mode="after")
    @classmethod
    def resolve_tutorial_root(cls, value: Path, info: ValidationInfo) -> Path:
        """Resolve a user-supplied tutorial root against the directory containing the tutorial file.

        Raises:
            ValueError: if ``value`` is an absolute path.
        """
        if value.is_absolute():
            raise ValueError(f"{value}: Must be a relative path (relative to the tutorial file).")

        # ``path`` is absent from ``info.data`` when it was missing or failed validation. Indexing it would
        # raise a bare KeyError, which pydantic does not turn into a ValidationError. Validation already
        # fails because of ``path`` in that case, so just pass the value through.
        path: Path | None = info.data.get("path")
        if path is None:
            return value

        return (path.parent / value).resolve()

    @model_validator(mode="after")
    def update_context(self) -> Self:
        """Expose the tutorial file and its directory to both the runtime and documentation contexts."""
        self.configuration.run.context["tutorial_path"] = self.path
        self.configuration.run.context["tutorial_dir"] = self.path.parent
        self.configuration.doc.context["tutorial_path"] = self.path
        self.configuration.doc.context["tutorial_dir"] = self.path.parent
        return self

    @model_validator(mode="after")
    def update_part_data(self) -> Self:
        """Store each part's position in ``parts`` and use it (as a string) as the id if none is set."""
        for part_no, part in enumerate(self.parts):
            part.index = part_no
            if not part.id:
                part.id = str(part_no)
        return self

    @classmethod
    def from_file(cls, path: Path) -> Self:
        """Load a tutorial from a YAML file.

        Args:
            path: The YAML file to load. It is resolved to an absolute path and stored as ``path``.

        Returns:
            The validated tutorial, as an instance of ``cls`` (so subclasses load as themselves).

        Raises:
            ValueError: if the top-level YAML value is not a mapping (e.g. the file is empty).
            pydantic.ValidationError: if the data does not describe a valid tutorial.
        """
        # YAML is a Unicode format; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as stream:
            tutorial_data = safe_load(stream)

        # e.g. an empty YAML file will return None
        if not isinstance(tutorial_data, dict):
            raise ValueError("File does not contain a mapping at top level.")

        tutorial_data["path"] = path.resolve()
        return cls.model_validate(tutorial_data, context={"path": path})