diff --git a/.github/workflows/sf_cli_integration.yml b/.github/workflows/sf_cli_integration.yml index b9ff172..b5ff65b 100644 --- a/.github/workflows/sf_cli_integration.yml +++ b/.github/workflows/sf_cli_integration.yml @@ -200,6 +200,10 @@ jobs: echo "::error::testFunction/.datacustomcode_proj/sdk_config.json not found after function init." exit 1 } + test -f testFunction/payload/tests/test.json || { + echo "::error::testFunction/payload/tests/test.json not found after function init." + exit 1 + } # ── Function: scan ──────────────────────────────────────────────────────── @@ -251,14 +255,18 @@ jobs: # ── Function: run ───────────────────────────────────────────────────────── - - name: '[function] run — sf data-code-extension function run --entrypoint testFunction/payload/entrypoint.py -o dev1' + - name: Install testFunction/requirements.txt + run: | + pip install -r testFunction/requirements.txt + + - name: '[function] run — sf data-code-extension function run --entrypoint testFunction/payload/entrypoint.py --test-with testFunction/payload/tests/test.json ' run: | sf data-code-extension function run \ --entrypoint testFunction/payload/entrypoint.py \ - -o dev1 || { - echo "::error::sf data-code-extension function run FAILED. Check mock server output above; the --entrypoint flag or SF CLI org auth contract may have changed." - exit 1 - } + --test-with testFunction/payload/tests/test.json || { + echo "::error::sf data-code-extension function run FAILED. Check mock server output above; the --entrypoint flag or SF CLI org auth contract may have changed." + exit 1 + } # ── Function: deploy ───────────────────────────────────────────────────── @@ -270,7 +278,6 @@ jobs: --description "Test function deploy" \ --package-dir testFunction/payload \ --cpu-size CPU_2XL \ - --function-invoke-opt UnstructuredChunking \ -o dev1 || { echo "::error::sf data-code-extension function deploy FAILED. Check mock server output above for which endpoint failed. 
The deploy command flags or API contract may have changed." exit 1 diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index c6e9c5c..84da911 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -27,6 +27,13 @@ from datacustomcode import AuthType from datacustomcode.auth import configure_oauth_tokens +from datacustomcode.constants import ( + CONFIG_FILE, + ENTRYPOINT_FILE, + PAYLOAD_DIR, + TEST_FILE, + TESTS_DIR, +) from datacustomcode.scan import find_base_directory, get_package_type @@ -74,6 +81,30 @@ def _configure_client_credentials( ) +def _generate_function_test_file(entrypoint_path: str) -> Optional[str]: + """Generate test.json file for a function. + + Args: + entrypoint_path: Path to the function's entrypoint.py + + Returns: + Path to generated test.json, or None if generation failed + """ + from datacustomcode.function_utils import generate_test_json + + tests_dir = os.path.join(os.path.dirname(entrypoint_path), TESTS_DIR) + os.makedirs(tests_dir, exist_ok=True) + test_json_path = os.path.join(tests_dir, TEST_FILE) + + try: + generate_test_json(entrypoint_path, test_json_path) + logger.debug(f"Generated test JSON at {test_json_path}") + return test_json_path + except Exception as e: + logger.warning(f"Could not generate test.json: {e}") + return None + + @cli.command() @click.option("--profile", default="default", help="Credential profile name") @click.option( @@ -162,7 +193,6 @@ def zip(path: str, network: str): Choose based on your workload requirements.""", ) -@click.option("--function-invoke-opt") @click.option( "--sf-cli-org", default=None, @@ -176,13 +206,14 @@ def deploy( cpu_size: str, profile: str, network: str, - function_invoke_opt: str, sf_cli_org: Optional[str], ): + from datacustomcode.constants import USE_IN_FEATURE_MAPPING_FOR_CONNECT_API from datacustomcode.deploy import ( COMPUTE_TYPES, CodeExtensionMetadata, deploy_full, + infer_use_in_feature, ) from datacustomcode.token_provider import ( 
CredentialsTokenProvider, @@ -211,15 +242,24 @@ def deploy( ) if package_type == "function": - if not function_invoke_opt: + # Infer use_in_feature from function signature + entrypoint_path = os.path.join(path, ENTRYPOINT_FILE) + use_in_feature = infer_use_in_feature(entrypoint_path) + if use_in_feature: + logger.info(f"Inferred use_in_feature: {use_in_feature}") + else: click.secho( - "Error: Function invoke options are required for function package type", + "Error: Could not infer function invoke options. " + "Ensure the function's request and response type annotations use a supported feature type.", fg="red", ) raise click.Abort() - else: - function_invoke_options = function_invoke_opt.split(",") - metadata.functionInvokeOptions = function_invoke_options + + # Map user-provided feature names to API names + mapped_feature = USE_IN_FEATURE_MAPPING_FOR_CONNECT_API.get( + use_in_feature, use_in_feature + ) + metadata.functionInvokeOptions = [mapped_feature] try: if sf_cli_org: @@ -238,7 +278,12 @@ @click.option( "--code-type", default="script", type=click.Choice(["script", "function"]) ) -def init(directory: str, code_type: str): +@click.option( + "--use-in-feature", + default="SearchIndexChunking", + help="Feature where this function will be used (only applicable for function).", +) +def init(directory: str, code_type: str, use_in_feature: Optional[str]): from datacustomcode.scan import ( dc_config_json_from_file, update_config, @@ -250,9 +295,9 @@ def init(directory: str, code_type: str): if code_type == "script": copy_script_template(directory) elif code_type == "function": - copy_function_template(directory) - entrypoint_path = os.path.join(directory, "payload", "entrypoint.py") - config_location = os.path.join(os.path.dirname(entrypoint_path), "config.json") + copy_function_template(directory, use_in_feature) + entrypoint_path = os.path.join(directory, PAYLOAD_DIR, ENTRYPOINT_FILE) + config_location = os.path.join(os.path.dirname(entrypoint_path), CONFIG_FILE) # Write package type to SDK-specific config 
sdk_config = {"type": code_type} @@ -265,6 +310,7 @@ def init(directory: str, code_type: str): updated_config_json = update_config(entrypoint_path) with open(config_location, "w") as f: json.dump(updated_config_json, f, indent=2) + click.echo( "Start developing by updating the code in " + click.style(entrypoint_path, fg="blue", bold=True) @@ -275,6 +321,24 @@ def init(directory: str, code_type: str): + " to automatically update config.json when you make changes to your code" ) + # Generate test.json for functions + if code_type == "function": + test_json_path = _generate_function_test_file(entrypoint_path) + if test_json_path: + click.echo( + "Generated test file at " + + click.style(test_json_path, fg="blue", bold=True) + ) + click.echo( + "Test your function locally with " + + click.style( + f"datacustomcode run {entrypoint_path} " + f"--test-with {test_json_path}", + fg="blue", + bold=True, + ) + ) + @cli.command() @click.argument("filename") @@ -286,7 +350,7 @@ def init(directory: str, code_type: str): def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): from datacustomcode.scan import update_config, write_requirements_file - config_location = config or os.path.join(os.path.dirname(filename), "config.json") + config_location = config or os.path.join(os.path.dirname(filename), CONFIG_FILE) click.echo( "Dumping scan results to config file: " + click.style(config_location, fg="blue", bold=True) @@ -312,6 +376,12 @@ def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): @click.option("--config-file", default=None) @click.option("--dependencies", default=[], multiple=True) @click.option("--profile", default="default") +@click.option( + "--test-with", + default=None, + type=click.Path(exists=True), + help="Path to test JSON file for function testing", +) @click.option( "--sf-cli-org", default=None, @@ -322,10 +392,16 @@ def run( config_file: Union[str, None], dependencies: List[str], profile: str, + test_with: 
Optional[str], sf_cli_org: Optional[str], ): from datacustomcode.run import run_entrypoint run_entrypoint( - entrypoint, config_file, dependencies, profile, sf_cli_org=sf_cli_org + entrypoint, + config_file, + dependencies, + profile, + test_file=test_with, + sf_cli_org=sf_cli_org, ) diff --git a/src/datacustomcode/constants.py b/src/datacustomcode/constants.py new file mode 100644 index 0000000..76b6a7c --- /dev/null +++ b/src/datacustomcode/constants.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Constants used throughout the datacustomcode package.""" + +# File and directory names +ENTRYPOINT_FILE = "entrypoint.py" +CONFIG_FILE = "config.json" +PAYLOAD_DIR = "payload" +TESTS_DIR = "tests" +TEST_FILE = "test.json" +REQUIREMENTS_FILE = "requirements.txt" + +# Default values +DEFAULT_PROFILE = "default" +DEFAULT_NETWORK = "default" +DEFAULT_CPU_SIZE = "CPU_2XL" + +# Feature to template folder mapping +FEATURE_TEMPLATE_MAPPING = { + "SearchIndexChunking": "chunking", +} + +# Feature name to Connect API name mapping +USE_IN_FEATURE_MAPPING_FOR_CONNECT_API = { + "SearchIndexChunking": "UnstructuredChunking", +} + +# Pydantic request/response type names to feature names +REQUEST_TYPE_TO_FEATURE = { + "SearchIndexChunkingV1Request": "SearchIndexChunking", + "SearchIndexChunkingV1Response": "SearchIndexChunking", +} diff --git a/src/datacustomcode/deploy.py b/src/datacustomcode/deploy.py index 114252a..7fcd3ea 100644 --- a/src/datacustomcode/deploy.py +++ b/src/datacustomcode/deploy.py @@ -35,6 +35,7 @@ import requests from datacustomcode.cmd import cmd_output +from datacustomcode.constants import REQUEST_TYPE_TO_FEATURE from datacustomcode.scan import find_base_directory, get_package_type DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code" @@ -65,6 +66,40 @@ def _sanitize_api_name(name: str) -> str: return sanitized +def infer_use_in_feature(entrypoint_path: str) -> Union[str, None]: + """Infer the use_in_feature from function signature. + + Checks both the request parameter type and return type annotation. + Both must map to the same feature for a valid inference. + + Uses static AST parsing to avoid importing dependencies. 
+ + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + The feature name if both request and response match, None otherwise + """ + from datacustomcode.function_utils import inspect_function_types_static + + request_type_name, response_type_name = inspect_function_types_static( + entrypoint_path + ) + + if not request_type_name or not response_type_name: + return None + + # Look up features for both types + request_feature = REQUEST_TYPE_TO_FEATURE.get(request_type_name) + response_feature = REQUEST_TYPE_TO_FEATURE.get(response_type_name) + + # Both must be present and must match + if request_feature and response_feature and request_feature == response_feature: + return request_feature + + return None + + class CodeExtensionMetadata(BaseModel): name: str version: str @@ -205,6 +240,11 @@ def create_deployment( def prepare_dependency_archive( directory: str, docker_network: str, package_type: str ) -> None: + # The parent directory of 'directory' contains Dockerfile.dependencies, + # requirements.txt, and build_native_dependencies.sh + # (same location checked by has_nonempty_requirements_file) + parent_dir = os.path.dirname(directory) + cmd = f"docker images -q {DOCKER_IMAGE_NAME}" image_exists = cmd_output(cmd) @@ -213,7 +253,8 @@ def prepare_dependency_archive( if not image_exists: logger.info(f"Building docker image with docker network: {docker_network}...") cmd = docker_build_cmd(docker_network) - cmd_output(cmd, env=docker_env) + # Run docker build from parent_dir where Dockerfile.dependencies exists + cmd_output(cmd, env=docker_env, cwd=parent_dir) # ignore_cleanup_errors=True: on Windows, Docker creates files inside the # mounted volume whose permissions prevent the host from deleting them. 
@@ -223,9 +264,11 @@ def prepare_dependency_archive( logger.info( f"Building dependencies archive with docker network: {docker_network}" ) - shutil.copy("requirements.txt", temp_dir) - shutil.copy("build_native_dependencies.sh", temp_dir) + # Copy from parent_dir where files actually exist + shutil.copy(os.path.join(parent_dir, "requirements.txt"), temp_dir) + shutil.copy(os.path.join(parent_dir, "build_native_dependencies.sh"), temp_dir) cmd = docker_run_cmd(docker_network, temp_dir) + # Docker run doesn't need cwd since temp_dir is absolute and mounted cmd_output(cmd, env=docker_env) if package_type == "function": source_py_files = os.path.join(temp_dir, "py-files") diff --git a/src/datacustomcode/function/feature_types/chunking.py b/src/datacustomcode/function/feature_types/chunking.py index bdf0d91..1a2f1d7 100644 --- a/src/datacustomcode/function/feature_types/chunking.py +++ b/src/datacustomcode/function/feature_types/chunking.py @@ -22,13 +22,12 @@ Any, Dict, List, - Literal, ) from pydantic import BaseModel, Field -class DocElement(BaseModel): +class SearchIndexDocElement(BaseModel): """Document element to be chunked""" text: str = Field(..., description="Text content to be chunked") @@ -37,7 +36,7 @@ class DocElement(BaseModel): ) -class ChunkOutput(BaseModel): +class SearchIndexChunkOutput(BaseModel): """Output chunk from the chunking process""" chunk_id: str = Field(..., description="UUID for this chunk") @@ -55,20 +54,17 @@ class ChunkOutput(BaseModel): ) -class StatusResponse(BaseModel): +class SearchIndexStatusResponse(BaseModel): """Status response for operation""" status_type: str = Field(..., description="'success' or 'error'") status_message: str = Field(..., description="Human-readable status") -class UdsChunkingV1BatchRequest(BaseModel): +class SearchIndexChunkingV1Request(BaseModel): """Batch request for UDS chunking""" - version: Literal["v1"] = Field( - default="v1", description="API version, must be 'v1'" - ) - input: List[DocElement] = 
Field( + input: List[SearchIndexDocElement] = Field( ..., min_length=1, description="List of documents (min 1)" ) max_characters: int = Field(..., description="Max chars per chunk (default: 100)") @@ -77,13 +73,12 @@ ) -class UdsChunkingV1BatchResponse(BaseModel): +class SearchIndexChunkingV1Response(BaseModel): """Batch response for UDS chunking""" - version: Literal["v1"] = Field( - default="v1", description="API version, must be 'v1'" - ) - output: List[ChunkOutput] = Field( + output: List[SearchIndexChunkOutput] = Field( default_factory=list, description="Flat list of chunks from all docs" ) - status: StatusResponse = Field(..., description="Overall operation status") + status: SearchIndexStatusResponse = Field( + ..., description="Overall operation status" + ) diff --git a/src/datacustomcode/function_utils.py b/src/datacustomcode/function_utils.py new file mode 100644 index 0000000..8e6f12e --- /dev/null +++ b/src/datacustomcode/function_utils.py @@ -0,0 +1,364 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for inspecting and working with function entrypoints.""" + +import ast +import importlib.util +import inspect +import json +import typing +from typing import ( + Any, + Optional, + Tuple, +) + + +def load_function_module(entrypoint_path: str, module_name: str = "function_module"): + """Load a function entrypoint as a Python module. + + Args: + entrypoint_path: Path to the entrypoint.py file + module_name: Name to assign to the module + + Returns: + The loaded module + + Raises: + ImportError: If the module cannot be loaded + """ + spec = importlib.util.spec_from_file_location(module_name, entrypoint_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module from {entrypoint_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def get_function_callable(module): + """Get the 'function' callable from a module. + + Args: + module: The module to extract the function from + + Returns: + The function callable + + Raises: + AttributeError: If module doesn't have a 'function' attribute + """ + if not hasattr(module, "function"): + raise AttributeError("Module does not have a 'function' callable") + return module.function + + +def get_type_name(type_annotation: Any) -> Optional[str]: + """Extract the type name from a type annotation. + + Args: + type_annotation: A type annotation object + + Returns: + The type name as a string, or None if it cannot be determined + """ + if type_annotation == inspect.Parameter.empty: + return None + + if hasattr(type_annotation, "__name__"): + return str(type_annotation.__name__) + + return str(type_annotation) + + +def get_function_signature_types( + function_callable, +) -> Tuple[Optional[Any], Optional[Any], Optional[str], Optional[str]]: + """Extract request and response types from a function signature. 
+ + Args: + function_callable: The function to inspect + + Returns: + Tuple of (request_type, response_type, request_type_name, response_type_name) + Any of these can be None if not found + """ + sig = inspect.signature(function_callable) + params = list(sig.parameters.values()) + + request_type = None + request_type_name = None + if len(params) >= 1: + request_type = params[0].annotation + request_type_name = get_type_name(request_type) + + response_type = sig.return_annotation + response_type_name = get_type_name(response_type) + + return request_type, response_type, request_type_name, response_type_name + + +def inspect_function_types_static( + entrypoint_path: str, +) -> Tuple[Optional[str], Optional[str]]: + """Inspect function types using static AST parsing (no imports). + + This parses the Python file without executing it, so it doesn't + require dependencies to be installed. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + Tuple of (request_type_name, response_type_name) + """ + try: + with open(entrypoint_path, "r") as f: + tree = ast.parse(f.read(), filename=entrypoint_path) + + # Find the 'function' definition + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == "function": + # Get request type (first parameter annotation) + request_type_name = None + if node.args.args and len(node.args.args) > 0: + first_param = node.args.args[0] + if first_param.annotation: + request_type_name = _get_type_name_from_ast( + first_param.annotation + ) + + # Get response type (return annotation) + response_type_name = None + if node.returns: + response_type_name = _get_type_name_from_ast(node.returns) + + return request_type_name, response_type_name + + return None, None + except Exception: + return None, None + + +def _get_type_name_from_ast(annotation) -> Optional[str]: + """Extract type name from an AST annotation node.""" + if isinstance(annotation, ast.Name): + # Simple type: MyType + return annotation.id + 
elif isinstance(annotation, ast.Attribute): + # Module.Type - just return the type name + return annotation.attr + elif isinstance(annotation, ast.Subscript): + # Generic type: List[MyType], Optional[MyType] + # Return the base type name + return _get_type_name_from_ast(annotation.value) + return None + + +def _import_pydantic_model(entrypoint_path: str, type_name: str) -> Optional[Any]: + """Import a Pydantic model by finding its import statement. + + Parses the entrypoint to find where the type is imported from, + then imports just that module (not the entrypoint itself). + + Args: + entrypoint_path: Path to entrypoint.py + type_name: Name of the type to import (e.g., "SearchIndexChunkingV1Request") + + Returns: + The Pydantic model class, or None if not found + """ + try: + with open(entrypoint_path, "r") as f: + tree = ast.parse(f.read(), filename=entrypoint_path) + + # Find where this type is imported from + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + # from module import Type1, Type2 + for alias in node.names: + if alias.name == type_name: + # Found it! Import from the module + module_name = node.module + if module_name: + module = importlib.import_module(module_name) + return getattr(module, type_name, None) + + return None + except Exception: + return None + + +def inspect_function_types( + entrypoint_path: str, +) -> Tuple[Optional[str], Optional[str]]: + """Inspect a function entrypoint and extract type names. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + Tuple of (request_type_name, response_type_name) + Either can be None if not found or on error + + Example: + >>> request_name, response_name = inspect_function_types( + ... "payload/entrypoint.py" + ... 
) + >>> print(request_name) # "SearchIndexChunkingV1Request" + >>> print(response_name) # "SearchIndexChunkingV1Response" + """ + try: + module = load_function_module(entrypoint_path, "temp_module") + function_callable = get_function_callable(module) + _, _, request_type_name, response_type_name = get_function_signature_types( + function_callable + ) + return request_type_name, response_type_name + except Exception: + return None, None + + +def get_request_type(entrypoint_path: str) -> Any: + """Get the request type annotation from a function entrypoint. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + The request type (Pydantic model class) + + Raises: + ImportError: If the module cannot be loaded + AttributeError: If the function is not found + ValueError: If the function signature is invalid + """ + module = load_function_module(entrypoint_path) + function_callable = get_function_callable(module) + + sig = inspect.signature(function_callable) + params = list(sig.parameters.values()) + + if len(params) < 1: + raise ValueError("Function must accept at least one parameter (request)") + + request_type = params[0].annotation + if request_type == inspect.Parameter.empty: + raise ValueError("Function request parameter must have a type annotation") + + return request_type + + +def _generate_model_sample_data(model_type): + """Generate sample data for all fields in a Pydantic model. 
+ + Args: + model_type: A Pydantic model class + + Returns: + Dictionary with sample data for all fields + """ + from pydantic_core import PydanticUndefined + + sample_data = {} + for field_name, field_info in model_type.model_fields.items(): + # Check if field has a real default value + if field_info.default is not PydanticUndefined: + sample_data[field_name] = field_info.default + else: + # Required field or field without default - generate sample + sample_data[field_name] = generate_sample_value( + field_info.annotation, field_name + ) + return sample_data + + +def generate_sample_value(field_type, field_name: str): + """Generate a sample value based on field type. + + Args: + field_type: The type annotation of the field + field_name: The name of the field (used for contextual sample generation) + + Returns: + A sample value appropriate for the field type + """ + origin = typing.get_origin(field_type) + + if origin is list or field_type is list: + args = typing.get_args(field_type) + if args: + return [generate_sample_value(args[0], field_name)] + return [] + elif origin is dict or field_type is dict: + return {} + elif field_type is str or origin is typing.Literal: + if "version" in field_name.lower(): + return "v1" + return f"sample_{field_name}" + elif field_type is int: + if "max" in field_name.lower() or "characters" in field_name.lower(): + return 100 + return 1 + elif field_type is float: + return 1.0 + elif field_type is bool: + return True + elif hasattr(field_type, "model_fields"): + # Nested Pydantic model - use shared helper + return _generate_model_sample_data(field_type) + else: + return None + + +def generate_test_json(entrypoint_path: str, output_path: str) -> None: + """Generate a sample test.json file for a function. + + First tries static AST parsing to get type names, then uses those + to import only the Pydantic model classes (not the entrypoint). 
+ + Args: + entrypoint_path: Path to the function entrypoint.py + output_path: Output path for test.json + + Raises: + ImportError: If the Pydantic model cannot be loaded + ValueError: If the request type is not found or not a Pydantic model + """ + # First, get the type name using static parsing (no imports) + request_type_name, _ = inspect_function_types_static(entrypoint_path) + + if not request_type_name: + raise ValueError("Could not determine request type from function signature") + + # Now try to import the Pydantic model class + # Look for it in the entrypoint's imports + request_type = _import_pydantic_model(entrypoint_path, request_type_name) + + if not request_type: + raise ValueError(f"Could not import Pydantic model: {request_type_name}") + + # Check if it's a Pydantic model + if not hasattr(request_type, "model_fields"): + raise ValueError("Request parameter type must be a Pydantic model") + + # Generate sample data for ALL fields (use defaults where available) + sample_data = _generate_model_sample_data(request_type) + sample_instance = request_type(**sample_data) + + # Write to file + with open(output_path, "w") as f: + json.dump(sample_instance.model_dump(), f, indent=2) diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index 0e5052a..6322270 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -70,6 +70,7 @@ def run_entrypoint( config_file: Union[str, None], dependencies: List[str], profile: str, + test_file: Optional[str] = None, sf_cli_org: Optional[str] = None, ) -> None: """Run the entrypoint for script or function with the given config and dependencies. @@ -79,6 +80,7 @@ def run_entrypoint( config_file: The config file to use. dependencies: The dependencies to import. profile: The credentials profile to use. + test_file: Optional test JSON file for function testing. sf_cli_org: Optional SF CLI org alias or username. If provided, credentials are fetched via `sf org display` instead of from credentials.ini. 
""" @@ -138,7 +140,64 @@ def run_entrypoint( raise exc except (ModuleNotFoundError, AttributeError) as inner_exc: raise inner_exc from exc - runpy.run_path(entrypoint, init_globals=globals(), run_name="__main__") + + # Handle test file for functions + if test_file and package_type == "function": + run_function_with_test(entrypoint, test_file) + else: + runpy.run_path(entrypoint, init_globals=globals(), run_name="__main__") + + +def run_function_with_test(entrypoint: str, test_file: str) -> None: + """Run a function with test data from a JSON file. + + Dependencies are already loaded by this point, so we just import + the entrypoint module and call the function directly. + + Args: + entrypoint: Path to the function entrypoint.py + test_file: Path to test JSON file containing request data + """ + from datacustomcode.function_utils import ( + get_function_callable, + get_request_type, + load_function_module, + ) + + # Import the entrypoint module in the current environment + # (with all dependencies loaded) + module = load_function_module(entrypoint, "entrypoint_module") + function_callable = get_function_callable(module) + request_type = get_request_type(entrypoint) + + # Load and parse the test JSON + with open(test_file, "r") as f: + test_data = json.load(f) + + # Use Pydantic to parse and validate the request + try: + request = request_type(**test_data) + except Exception as e: + raise ValueError( + f"Failed to parse test data as {request_type.__name__}: {e}" + ) from e + + # Import Runtime + from datacustomcode.function import Runtime + + # Call the function with test data + print(f"Running function with test data from {test_file}...") + result = function_callable(request, Runtime()) + + # Pretty print the result + print("\n" + "=" * 80) + print("RESULT:") + print("=" * 80) + if hasattr(result, "model_dump"): + print(json.dumps(result.model_dump(), indent=2)) + else: + print(result) + print("=" * 80) def add_py_folder(entrypoint: str): diff --git 
a/src/datacustomcode/template.py b/src/datacustomcode/template.py index 195d4a2..6807510 100644 --- a/src/datacustomcode/template.py +++ b/src/datacustomcode/template.py @@ -14,9 +14,12 @@ # limitations under the License. import os import shutil +from typing import Optional from loguru import logger +from datacustomcode.constants import FEATURE_TEMPLATE_MAPPING + script_template_dir = os.path.join(os.path.dirname(__file__), "templates", "script") function_template_dir = os.path.join(os.path.dirname(__file__), "templates", "function") @@ -37,16 +40,42 @@ def copy_script_template(target_dir: str) -> None: shutil.copy2(source, destination) -def copy_function_template(target_dir: str) -> None: +def copy_function_template(target_dir: str, use_in_feature: Optional[str]) -> None: os.makedirs(target_dir, exist_ok=True) + # First, copy common files from base function template for item in os.listdir(function_template_dir): source = os.path.join(function_template_dir, item) destination = os.path.join(target_dir, item) + # Skip feature-specific subdirectories + if os.path.isdir(source) and item in FEATURE_TEMPLATE_MAPPING.values(): + continue + if os.path.isdir(source): logger.debug(f"Copying directory {source} to {destination}...") shutil.copytree(source, destination, dirs_exist_ok=True) else: logger.debug(f"Copying file {source} to {destination}...") shutil.copy2(source, destination) + + # Then, copy feature-specific files (overwriting if needed) + if use_in_feature and use_in_feature in FEATURE_TEMPLATE_MAPPING: + feature_function_template_dir = os.path.join( + function_template_dir, FEATURE_TEMPLATE_MAPPING[use_in_feature] + ) + + for item in os.listdir(feature_function_template_dir): + source = os.path.join(feature_function_template_dir, item) + destination = os.path.join(target_dir, item) + + if os.path.isdir(source): + logger.debug( + f"Copying feature-specific directory {source} to {destination}..." 
+ ) + shutil.copytree(source, destination, dirs_exist_ok=True) + else: + logger.debug( + f"Copying feature-specific file {source} to {destination}..." + ) + shutil.copy2(source, destination) diff --git a/src/datacustomcode/templates/__init__.py b/src/datacustomcode/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datacustomcode/templates/function/chunking/payload/config.json b/src/datacustomcode/templates/function/chunking/payload/config.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/payload/config.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py new file mode 100644 index 0000000..311a51b --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py @@ -0,0 +1,74 @@ +import logging + +from langchain_text_splitters import RecursiveCharacterTextSplitter + +from datacustomcode.function import Runtime +from datacustomcode.function.feature_types.chunking import ( + SearchIndexChunkingV1Request, + SearchIndexChunkingV1Response, + SearchIndexChunkOutput, + SearchIndexStatusResponse, +) + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def function( + request: SearchIndexChunkingV1Request, runtime: Runtime +) -> SearchIndexChunkingV1Response: + print(f"Received {len(request.input)} documents to chunk") + print(f"Max characters per chunk: {request.max_characters}") + + # Initialize RecursiveCharacterTextSplitter + # It tries to split on: "\n\n", "\n", " ", "" (in that order) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=request.max_characters, + chunk_overlap=20, # Small overlap to maintain context + length_function=len, + separators=["\n\n", "\n", " ", ""], + ) + + chunks = [] + chunk_id = 1 + + # Process each document + for doc_idx, doc in 
enumerate(request.input): + text = doc.text + metadata = doc.metadata if hasattr(doc.metadata, "__iter__") else {} + + print(f"📄 Processing document {doc_idx + 1}: {len(text)} characters") + + # Split the text using RecursiveCharacterTextSplitter + text_chunks = text_splitter.split_text(text) + + # Create chunk outputs + for chunk_text in text_chunks: + chunk_output = SearchIndexChunkOutput( + chunk_id=f"chunk_{chunk_id:04d}", + chunk_type="text", + text=chunk_text.strip(), + seq_no=chunk_id, + metadata={ + k: str(v) for k, v in (dict(metadata) if metadata else {}).items() + }, + tag_metadata={}, + citations={}, + ) + chunks.append(chunk_output) + + print(f"Chunk {chunk_id}: {len(chunk_text)} chars") + chunk_id += 1 + + print(f"Generated {len(chunks)} chunks total") + + return SearchIndexChunkingV1Response( + output=chunks, + status=SearchIndexStatusResponse( + status_type="success", + status_message=( + f"Successfully chunked {len(request.input)} documents " + f"into {len(chunks)} chunks" + ), + ), + ) diff --git a/src/datacustomcode/templates/function/chunking/requirements.txt b/src/datacustomcode/templates/function/chunking/requirements.txt new file mode 100644 index 0000000..7f5990c --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/requirements.txt @@ -0,0 +1,2 @@ +# Packages required for the chunking function +langchain-text-splitters>=0.3.0 diff --git a/src/datacustomcode/templates/script/requirements.txt b/src/datacustomcode/templates/script/requirements.txt index b4fc884..377c7cc 100644 --- a/src/datacustomcode/templates/script/requirements.txt +++ b/src/datacustomcode/templates/script/requirements.txt @@ -1 +1,2 @@ # Packages required for the custom code +langchain-text-splitters>=0.3.0 diff --git a/tests/test_cli.py b/tests/test_cli.py index e26cbdc..7765560 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -103,16 +103,19 @@ def test_deploy_command_success(self, mock_token_provider, mock_deploy_full): assert 
call_args[0][2].access_token == "test_token" assert call_args[0][2].instance_url == "https://instance.example.com" + @patch("datacustomcode.deploy.infer_use_in_feature") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_function_invoke_options( - self, mock_token_provider, mock_deploy_full + self, mock_token_provider, mock_deploy_full, mock_infer_feature ): """Test deploy command with function invoke options.""" mock_provider_instance = mock_token_provider.return_value mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) + # Mock infer_use_in_feature to return a valid feature + mock_infer_feature.return_value = "SearchIndexChunking" runner = CliRunner() with runner.isolated_filesystem(): @@ -122,16 +125,12 @@ def test_deploy_command_function_invoke_options( write_sdk_config(".", sdk_config) result = runner.invoke( deploy, - ["--name", "test-job", "--function-invoke-opt", "option1,option2"], + ["--name", "test-job"], ) assert result.exit_code == 0 mock_deploy_full.assert_called_once() - # Check that deploy_full was called with function invoke options - call_args = mock_deploy_full.call_args - assert call_args[0][1].functionInvokeOptions == ["option1", "option2"] - @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_credentials_error(self, mock_token_provider): """Test deploy command when credentials are not available.""" diff --git a/tests/test_deploy.py b/tests/test_deploy.py index 2fd3ce2..6b33cc7 100644 --- a/tests/test_deploy.py +++ b/tests/test_deploy.py @@ -59,6 +59,7 @@ class TestPrepareDependencyArchive: @patch("datacustomcode.deploy.shutil.copy") @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") 
@patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -67,6 +68,7 @@ def test_prepare_dependency_archive_image_exists( mock_docker_run_cmd, mock_docker_build_cmd, mock_makedirs, + mock_dirname, mock_join, mock_temp_dir, mock_copy, @@ -82,8 +84,34 @@ def test_prepare_dependency_archive_image_exists( # Mock cmd_output to return image ID (indicating image exists) mock_cmd_output.return_value = "abc123" - # Mock os.path.join for archive path - mock_join.return_value = "/tmp/test_dir/native_dependencies.tar.gz" + # Mock os.path.dirname to handle different calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + elif path == "payload/py-files": + return "payload" + else: + # For other paths, do simple string manipulation + return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join to handle different calls + def join_side_effect(*args): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + elif args == ("/tmp/test_dir", "native_dependencies.tar.gz"): + return "/tmp/test_dir/native_dependencies.tar.gz" + elif args == ("payload", "archives", "native_dependencies.tar.gz"): + return "payload/archives/native_dependencies.tar.gz" + elif args == ("payload", "archives"): + return "payload/archives" + else: + return "/".join(args) + + mock_join.side_effect = join_side_effect # Mock the docker command functions mock_docker_build_cmd.return_value = "mock build command" @@ -97,9 +125,9 @@ def test_prepare_dependency_archive_image_exists( # Verify docker build command was not called (since image already exists) mock_docker_build_cmd.assert_not_called() - # Verify files were copied to temp directory - mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir") - 
mock_copy.assert_any_call("build_native_dependencies.sh", "/tmp/test_dir") + # Verify files were copied to temp directory from parent directory + mock_copy.assert_any_call("/test/requirements.txt", "/tmp/test_dir") + mock_copy.assert_any_call("/test/build_native_dependencies.sh", "/tmp/test_dir") # Verify docker run command was called mock_docker_run_cmd.assert_called_once_with("default", "/tmp/test_dir") @@ -118,6 +146,7 @@ def test_prepare_dependency_archive_image_exists( @patch("datacustomcode.deploy.shutil.copy") @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") @patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -126,6 +155,7 @@ def test_prepare_dependency_archive_build_image( mock_docker_run_cmd, mock_docker_build_cmd, mock_makedirs, + mock_dirname, mock_join, mock_temp_dir, mock_copy, @@ -142,8 +172,33 @@ def test_prepare_dependency_archive_build_image( # and then return some value for subsequent calls mock_cmd_output.side_effect = [None, None, None, None] - # Mock os.path.join for archive path - mock_join.return_value = "/tmp/test_dir/native_dependencies.tar.gz" + # Mock os.path.dirname to handle different calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + elif path == "payload/py-files": + return "payload" + else: + return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join to handle different calls + def join_side_effect(*args): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + elif args == ("/tmp/test_dir", "native_dependencies.tar.gz"): + return "/tmp/test_dir/native_dependencies.tar.gz" + elif args == ("payload", "archives", 
"native_dependencies.tar.gz"): + return "payload/archives/native_dependencies.tar.gz" + elif args == ("payload", "archives"): + return "payload/archives" + else: + return "/".join(args) + + mock_join.side_effect = join_side_effect # Mock the docker command functions mock_docker_build_cmd.return_value = "mock build command" @@ -156,11 +211,11 @@ def test_prepare_dependency_archive_build_image( # Verify docker build command was called mock_docker_build_cmd.assert_called_once_with("default") - mock_cmd_output.assert_any_call("mock build command", env=ANY) + mock_cmd_output.assert_any_call("mock build command", env=ANY, cwd="/test") - # Verify files were copied to temp directory - mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir") - mock_copy.assert_any_call("build_native_dependencies.sh", "/tmp/test_dir") + # Verify files were copied to temp directory from parent directory + mock_copy.assert_any_call("/test/requirements.txt", "/tmp/test_dir") + mock_copy.assert_any_call("/test/build_native_dependencies.sh", "/tmp/test_dir") # Verify docker run command was called mock_docker_run_cmd.assert_called_once_with("default", "/tmp/test_dir") @@ -222,6 +277,7 @@ def test_prepare_dependency_archive_docker_build_failure( @patch("datacustomcode.deploy.shutil.copy") @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") @patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -230,6 +286,7 @@ def test_prepare_dependency_archive_docker_run_failure( mock_docker_run_cmd, mock_docker_build_cmd, mock_makedirs, + mock_dirname, mock_join, mock_temp_dir, mock_copy, @@ -252,15 +309,35 @@ def test_prepare_dependency_archive_docker_run_failure( ), # Run fails ] + # Mock os.path.dirname to handle different calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + else: + 
return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join to handle different calls + def join_side_effect(*args): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + else: + return "/".join(args) + + mock_join.side_effect = join_side_effect + with pytest.raises(CalledProcessError, match="Run failed"): prepare_dependency_archive("/test/dir", "default", "script") # Verify docker images command was called mock_cmd_output.assert_any_call(self.EXPECTED_DOCKER_IMAGES_CMD) - # Verify files were copied to temp directory - mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir") - mock_copy.assert_any_call("build_native_dependencies.sh", "/tmp/test_dir") + # Verify files were copied to temp directory from parent directory + mock_copy.assert_any_call("/test/requirements.txt", "/tmp/test_dir") + mock_copy.assert_any_call("/test/build_native_dependencies.sh", "/tmp/test_dir") # Verify docker run command was called mock_docker_run_cmd.assert_called_once_with("default", "/tmp/test_dir") @@ -269,6 +346,7 @@ def test_prepare_dependency_archive_docker_run_failure( @patch("datacustomcode.deploy.shutil.copy") @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") @patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -277,6 +355,7 @@ def test_prepare_dependency_archive_file_copy_failure( mock_docker_run_cmd, mock_docker_build_cmd, mock_makedirs, + mock_dirname, mock_join, mock_temp_dir, mock_copy, @@ -292,6 +371,26 @@ def test_prepare_dependency_archive_file_copy_failure( # Mock cmd_output to return image ID mock_cmd_output.return_value = "abc123" + # Mock os.path.dirname to handle different 
calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + else: + return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join to handle different calls + def join_side_effect(*args): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + else: + return "/".join(args) + + mock_join.side_effect = join_side_effect + # Mock shutil.copy to raise exception mock_copy.side_effect = FileNotFoundError("File not found") @@ -301,8 +400,8 @@ def test_prepare_dependency_archive_file_copy_failure( # Verify docker images command was called mock_cmd_output.assert_any_call(self.EXPECTED_DOCKER_IMAGES_CMD) - # Verify files were attempted to be copied - mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir") + # Verify files were attempted to be copied from parent directory + mock_copy.assert_any_call("/test/requirements.txt", "/tmp/test_dir") @patch("datacustomcode.deploy.cmd_output") @patch("datacustomcode.deploy.shutil.copytree") @@ -311,6 +410,7 @@ def test_prepare_dependency_archive_file_copy_failure( @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.exists") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") @patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -319,6 +419,7 @@ def test_prepare_dependency_archive_function_type( mock_docker_run_cmd, mock_docker_build_cmd, mock_makedirs, + mock_dirname, mock_join, mock_exists, mock_temp_dir, @@ -337,10 +438,27 @@ def test_prepare_dependency_archive_function_type( # Mock cmd_output to return image ID (indicating image exists) mock_cmd_output.return_value = "abc123" - # Mock os.path.join for py-files paths + # Mock 
os.path.dirname to handle different calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + elif path == "payload/py-files": + return "payload" + else: + return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join for all paths def join_side_effect(*args): - if args == ("/tmp/test_dir", "py-files"): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + elif args == ("/tmp/test_dir", "py-files"): return "/tmp/test_dir/py-files" + elif args == ("payload", "py-files"): + return "payload/py-files" return "/".join(args) mock_join.side_effect = join_side_effect @@ -367,9 +485,9 @@ def exists_side_effect(path): # Verify docker build command was not called (since image already exists) mock_docker_build_cmd.assert_not_called() - # Verify files were copied to temp directory - mock_copy.assert_any_call("requirements.txt", "/tmp/test_dir") - mock_copy.assert_any_call("build_native_dependencies.sh", "/tmp/test_dir") + # Verify files were copied to temp directory from parent directory + mock_copy.assert_any_call("/test/requirements.txt", "/tmp/test_dir") + mock_copy.assert_any_call("/test/build_native_dependencies.sh", "/tmp/test_dir") # Verify docker run command was called mock_docker_run_cmd.assert_called_once_with("default", "/tmp/test_dir") @@ -391,6 +509,7 @@ def exists_side_effect(path): @patch("datacustomcode.deploy.tempfile.TemporaryDirectory") @patch("datacustomcode.deploy.os.path.exists") @patch("datacustomcode.deploy.os.path.join") + @patch("datacustomcode.deploy.os.path.dirname") @patch("datacustomcode.deploy.os.makedirs") @patch("datacustomcode.deploy.docker_build_cmd") @patch("datacustomcode.deploy.docker_run_cmd") @@ -399,6 +518,7 @@ def test_prepare_dependency_archive_function_type_missing_pyfiles( mock_docker_run_cmd, mock_docker_build_cmd, 
mock_makedirs, + mock_dirname, mock_join, mock_exists, mock_temp_dir, @@ -418,9 +538,22 @@ def test_prepare_dependency_archive_function_type_missing_pyfiles( # Mock cmd_output to return image ID (indicating image exists) mock_cmd_output.return_value = "abc123" - # Mock os.path.join for py-files path + # Mock os.path.dirname to handle different calls + def dirname_side_effect(path): + if path == "/test/dir": + return "/test" + else: + return path.rsplit("/", 1)[0] if "/" in path else "" + + mock_dirname.side_effect = dirname_side_effect + + # Mock os.path.join for all paths def join_side_effect(*args): - if args == ("/tmp/test_dir", "py-files"): + if args == ("/test", "requirements.txt"): + return "/test/requirements.txt" + elif args == ("/test", "build_native_dependencies.sh"): + return "/test/build_native_dependencies.sh" + elif args == ("/tmp/test_dir", "py-files"): return "/tmp/test_dir/py-files" return "/".join(args) diff --git a/tests/test_function_utils.py b/tests/test_function_utils.py new file mode 100644 index 0000000..cc0f51d --- /dev/null +++ b/tests/test_function_utils.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import json +import os +import shutil +import sys +import tempfile +import textwrap + +import pytest + +from datacustomcode import function_utils + + +@pytest.fixture +def sample_entrypoint(): + """Create a temporary entrypoint file with a function.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file: + entrypoint_content = textwrap.dedent( + """ + from typing import List + from pydantic import BaseModel + + class SampleRequest(BaseModel): + message: str + count: int = 5 + tags: List[str] = [] + version: str = "v1" + + class SampleResponse(BaseModel): + result: str + success: bool = True + + def function(request: SampleRequest) -> SampleResponse: + return SampleResponse(result=f"Processed {request.message}") + """ + ) + temp_file.write(entrypoint_content) + temp_file_path = temp_file.name + + yield temp_file_path + + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +@pytest.fixture +def entrypoint_no_annotations(): + """Create an entrypoint with no type annotations.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file: + entrypoint_content = textwrap.dedent( + """ + def function(request): + return {"result": "no annotations"} + """ + ) + temp_file.write(entrypoint_content) + temp_file_path = temp_file.name + + yield temp_file_path + + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +def test_get_function_signature_types(sample_entrypoint, entrypoint_no_annotations): + """Test extracting request and response types from function signatures.""" + module = function_utils.load_function_module(sample_entrypoint) + func = function_utils.get_function_callable(module) + req_type, resp_type, req_name, resp_name = ( + function_utils.get_function_signature_types(func) + ) + + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + assert req_type is not None + assert resp_type is not None + + module_no_annot = 
function_utils.load_function_module(entrypoint_no_annotations) + func_no_annot = function_utils.get_function_callable(module_no_annot) + req_type, resp_type, req_name, resp_name = ( + function_utils.get_function_signature_types(func_no_annot) + ) + + assert req_name is None + assert resp_name is None + + +def test_inspect_function_types_static(sample_entrypoint, entrypoint_no_annotations): + """Test static AST-based inspection of function types.""" + req_name, resp_name = function_utils.inspect_function_types_static( + sample_entrypoint + ) + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + + req_name, resp_name = function_utils.inspect_function_types_static( + entrypoint_no_annotations + ) + assert req_name is None + assert resp_name is None + + +def test_inspect_function_types(sample_entrypoint): + """Test dynamic inspection of function types.""" + req_name, resp_name = function_utils.inspect_function_types(sample_entrypoint) + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + + req_name, resp_name = function_utils.inspect_function_types("/nonexistent/file.py") + assert req_name is None + assert resp_name is None + + +def test_get_request_type(sample_entrypoint, entrypoint_no_annotations): + """Test getting request type from entrypoint.""" + req_type = function_utils.get_request_type(sample_entrypoint) + assert req_type is not None + assert hasattr(req_type, "model_fields") + + with pytest.raises(ValueError, match="must have a type annotation"): + function_utils.get_request_type(entrypoint_no_annotations) + + +def test_generate_test_json(): + """Test generating test.json with simple and complex nested types.""" + temp_dir = tempfile.mkdtemp() + models_file = os.path.join(temp_dir, "test_models.py") + + try: + # Test 1: Simple request type + entrypoint_simple = os.path.join(temp_dir, "entrypoint_simple.py") + output_simple = os.path.join(temp_dir, "test_simple.json") + + with open(models_file, "w") as f: + 
models_content = textwrap.dedent( + """ + from pydantic import BaseModel + from typing import List + + class SimpleRequest(BaseModel): + message: str + count: int = 5 + tags: List[str] = [] + version: str = "v1" + + class NestedConfig(BaseModel): + host: str + port: int = 8080 + enabled: bool = True + + class ComplexRequest(BaseModel): + name: str + max_items: int = 100 + config: NestedConfig + metadata: dict = {} + """ + ) + f.write(models_content) + + with open(entrypoint_simple, "w") as f: + entrypoint_content = textwrap.dedent( + """ + from test_models import SimpleRequest + + def function(request: SimpleRequest): + return {"result": "ok"} + """ + ) + f.write(entrypoint_content) + + sys.path.insert(0, temp_dir) + + function_utils.generate_test_json(entrypoint_simple, output_simple) + assert os.path.exists(output_simple) + + with open(output_simple, "r") as f: + data = json.load(f) + + assert "message" in data + assert data["count"] == 5 + assert data["version"] == "v1" + assert data["tags"] == [] + + # Test 2: Complex request type with nested models + entrypoint_complex = os.path.join(temp_dir, "entrypoint_complex.py") + output_complex = os.path.join(temp_dir, "test_complex.json") + + with open(entrypoint_complex, "w") as f: + entrypoint_content = textwrap.dedent( + """ + from test_models import ComplexRequest + + def function(request: ComplexRequest): + return {"result": "ok"} + """ + ) + f.write(entrypoint_content) + + function_utils.generate_test_json(entrypoint_complex, output_complex) + assert os.path.exists(output_complex) + + with open(output_complex, "r") as f: + complex_data = json.load(f) + + assert "name" in complex_data + assert "max_items" in complex_data + assert complex_data["max_items"] == 100 + assert "config" in complex_data + assert isinstance(complex_data["config"], dict) + assert "host" in complex_data["config"] + assert "port" in complex_data["config"] + assert complex_data["config"]["port"] == 8080 + assert 
complex_data["config"]["enabled"] is True + assert "metadata" in complex_data + assert complex_data["metadata"] == {} + + finally: + if temp_dir in sys.path: + sys.path.remove(temp_dir) + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) diff --git a/tests/test_sf_cli_contract.py b/tests/test_sf_cli_contract.py index b96ab35..f53123e 100644 --- a/tests/test_sf_cli_contract.py +++ b/tests/test_sf_cli_contract.py @@ -188,25 +188,6 @@ def test_accepts_network_flag( result = runner.invoke(deploy, [*self._BASE_ARGS, "--network", "custom"]) assert result.exit_code != 2, result.output - @patch("datacustomcode.token_provider.SFCLITokenProvider") - @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.cli.find_base_directory") - @patch("datacustomcode.cli.get_package_type") - def test_accepts_function_invoke_opt_flag( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider - ): - mock_find_base.return_value = "payload" - mock_pkg_type.return_value = "function" - mock_provider_instance = mock_sf_cli_provider.return_value - mock_provider_instance.get_token.return_value = AccessTokenResponse( - access_token="tok", instance_url="https://example.com" - ) - runner = CliRunner() - result = runner.invoke( - deploy, [*self._BASE_ARGS, "--function-invoke-opt", "ASYNC"] - ) - assert result.exit_code != 2, result.output - class TestRunArgContract: """