diff --git a/.github/workflows/sf_cli_integration.yml b/.github/workflows/sf_cli_integration.yml index b9ff172..5a515d1 100644 --- a/.github/workflows/sf_cli_integration.yml +++ b/.github/workflows/sf_cli_integration.yml @@ -200,6 +200,10 @@ jobs: echo "::error::testFunction/.datacustomcode_proj/sdk_config.json not found after function init." exit 1 } + test -f testFunction/payload/tests/test.json || { + echo "::error::testFunction/payload/tests/test.json not found after function init." + exit 1 + } # ── Function: scan ──────────────────────────────────────────────────────── @@ -251,14 +255,14 @@ jobs: # ── Function: run ───────────────────────────────────────────────────────── - - name: '[function] run — sf data-code-extension function run --entrypoint testFunction/payload/entrypoint.py -o dev1' + - name: '[function] run — sf data-code-extension function run --entrypoint testFunction/payload/entrypoint.py --test-with testFunction/payload/tests/test.json -o dev1' run: | sf data-code-extension function run \ --entrypoint testFunction/payload/entrypoint.py \ - -o dev1 || { - echo "::error::sf data-code-extension function run FAILED. Check mock server output above; the --entrypoint flag or SF CLI org auth contract may have changed." - exit 1 - } + --test-with testFunction/payload/tests/test.json || { + echo "::error::sf data-code-extension function run FAILED. Check mock server output above; the --entrypoint flag or SF CLI org auth contract may have changed." + exit 1 + } # ── Function: deploy ───────────────────────────────────────────────────── @@ -270,7 +274,6 @@ jobs: --description "Test function deploy" \ --package-dir testFunction/payload \ --cpu-size CPU_2XL \ - --function-invoke-opt UnstructuredChunking \ -o dev1 || { echo "::error::sf data-code-extension function deploy FAILED. Check mock server output above for which endpoint failed. The deploy command flags or API contract may have changed." 
exit 1 diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index c6e9c5c..84da911 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -27,6 +27,13 @@ from datacustomcode import AuthType from datacustomcode.auth import configure_oauth_tokens +from datacustomcode.constants import ( + CONFIG_FILE, + ENTRYPOINT_FILE, + PAYLOAD_DIR, + TEST_FILE, + TESTS_DIR, +) from datacustomcode.scan import find_base_directory, get_package_type @@ -74,6 +81,30 @@ def _configure_client_credentials( ) +def _generate_function_test_file(entrypoint_path: str) -> Optional[str]: + """Generate test.json file for a function. + + Args: + entrypoint_path: Path to the function's entrypoint.py + + Returns: + Path to generated test.json, or None if generation failed + """ + from datacustomcode.function_utils import generate_test_json + + tests_dir = os.path.join(os.path.dirname(entrypoint_path), TESTS_DIR) + os.makedirs(tests_dir, exist_ok=True) + test_json_path = os.path.join(tests_dir, TEST_FILE) + + try: + generate_test_json(entrypoint_path, test_json_path) + logger.debug(f"Generated test JSON at {test_json_path}") + return test_json_path + except Exception as e: + logger.warning(f"Could not generate test.json: {e}") + return None + + @cli.command() @click.option("--profile", default="default", help="Credential profile name") @click.option( @@ -162,7 +193,6 @@ def zip(path: str, network: str): Choose based on your workload requirements.""", ) -@click.option("--function-invoke-opt") @click.option( "--sf-cli-org", default=None, @@ -176,13 +206,14 @@ def deploy( cpu_size: str, profile: str, network: str, - function_invoke_opt: str, sf_cli_org: Optional[str], ): + from datacustomcode.constants import USE_IN_FEATURE_MAPPING_FOR_CONNECT_API from datacustomcode.deploy import ( COMPUTE_TYPES, CodeExtensionMetadata, deploy_full, + infer_use_in_feature, ) from datacustomcode.token_provider import ( CredentialsTokenProvider, @@ -211,15 +242,24 @@ def deploy( ) if 
package_type == "function": - if not function_invoke_opt: + # Infer use_in_feature from function signature + entrypoint_path = os.path.join(path, ENTRYPOINT_FILE) + use_in_feature = infer_use_in_feature(entrypoint_path) + if use_in_feature: + logger.info(f"Inferred use_in_feature: {use_in_feature}") + else: click.secho( - "Error: Function invoke options are required for function package type", + "Error: Could not infer function invoke options. " + "Please provide --use-in-feature", fg="red", ) raise click.Abort() - else: - function_invoke_options = function_invoke_opt.split(",") - metadata.functionInvokeOptions = function_invoke_options + + # Map user-provided feature names to API names + mapped_feature = USE_IN_FEATURE_MAPPING_FOR_CONNECT_API.get( + use_in_feature, use_in_feature + ) + metadata.functionInvokeOptions = [mapped_feature] try: if sf_cli_org: @@ -238,7 +278,12 @@ def deploy( @click.option( "--code-type", default="script", type=click.Choice(["script", "function"]) ) -def init(directory: str, code_type: str): +@click.option( + "--use-in-feature", + default="SearchIndexChunking", + help="Feature where this function will be used (only applicable for function).", +) +def init(directory: str, code_type: str, use_in_feature: Optional[str]): from datacustomcode.scan import ( dc_config_json_from_file, update_config, @@ -250,9 +295,9 @@ def init(directory: str, code_type: str): if code_type == "script": copy_script_template(directory) elif code_type == "function": - copy_function_template(directory) - entrypoint_path = os.path.join(directory, "payload", "entrypoint.py") - config_location = os.path.join(os.path.dirname(entrypoint_path), "config.json") + copy_function_template(directory, use_in_feature) + entrypoint_path = os.path.join(directory, PAYLOAD_DIR, ENTRYPOINT_FILE) + config_location = os.path.join(os.path.dirname(entrypoint_path), CONFIG_FILE) # Write package type to SDK-specific config sdk_config = {"type": code_type} @@ -265,6 +310,7 @@ def 
init(directory: str, code_type: str): updated_config_json = update_config(entrypoint_path) with open(config_location, "w") as f: json.dump(updated_config_json, f, indent=2) + click.echo( "Start developing by updating the code in " + click.style(entrypoint_path, fg="blue", bold=True) @@ -275,6 +321,24 @@ def init(directory: str, code_type: str): + " to automatically update config.json when you make changes to your code" ) + # Generate test.json for functions + if code_type == "function": + test_json_path = _generate_function_test_file(entrypoint_path) + if test_json_path: + click.echo( + "Generated test file at " + + click.style(test_json_path, fg="blue", bold=True) + ) + click.echo( + "Test your function locally with " + + click.style( + f"datacustomcode run {entrypoint_path} " + f"--test-with {test_json_path}", + fg="blue", + bold=True, + ) + ) + @cli.command() @click.argument("filename") @@ -286,7 +350,7 @@ def init(directory: str, code_type: str): def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): from datacustomcode.scan import update_config, write_requirements_file - config_location = config or os.path.join(os.path.dirname(filename), "config.json") + config_location = config or os.path.join(os.path.dirname(filename), CONFIG_FILE) click.echo( "Dumping scan results to config file: " + click.style(config_location, fg="blue", bold=True) @@ -312,6 +376,12 @@ def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): @click.option("--config-file", default=None) @click.option("--dependencies", default=[], multiple=True) @click.option("--profile", default="default") +@click.option( + "--test-with", + default=None, + type=click.Path(exists=True), + help="Path to test JSON file for function testing", +) @click.option( "--sf-cli-org", default=None, @@ -322,10 +392,16 @@ def run( config_file: Union[str, None], dependencies: List[str], profile: str, + test_with: Optional[str], sf_cli_org: Optional[str], ): from datacustomcode.run 
import run_entrypoint run_entrypoint( - entrypoint, config_file, dependencies, profile, sf_cli_org=sf_cli_org + entrypoint, + config_file, + dependencies, + profile, + test_file=test_with, + sf_cli_org=sf_cli_org, ) diff --git a/src/datacustomcode/constants.py b/src/datacustomcode/constants.py new file mode 100644 index 0000000..76b6a7c --- /dev/null +++ b/src/datacustomcode/constants.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Constants used throughout the datacustomcode package.""" + +# File and directory names +ENTRYPOINT_FILE = "entrypoint.py" +CONFIG_FILE = "config.json" +PAYLOAD_DIR = "payload" +TESTS_DIR = "tests" +TEST_FILE = "test.json" +REQUIREMENTS_FILE = "requirements.txt" + +# Default values +DEFAULT_PROFILE = "default" +DEFAULT_NETWORK = "default" +DEFAULT_CPU_SIZE = "CPU_2XL" + +# Feature to template folder mapping +FEATURE_TEMPLATE_MAPPING = { + "SearchIndexChunking": "chunking", +} + +# Feature name to Connect API name mapping +USE_IN_FEATURE_MAPPING_FOR_CONNECT_API = { + "SearchIndexChunking": "UnstructuredChunking", +} + +# Pydantic request/response type names to feature names +REQUEST_TYPE_TO_FEATURE = { + "SearchIndexChunkingV1Request": "SearchIndexChunking", + "SearchIndexChunkingV1Response": "SearchIndexChunking", +} diff --git a/src/datacustomcode/deploy.py b/src/datacustomcode/deploy.py index 114252a..65495e6 100644 --- a/src/datacustomcode/deploy.py +++ b/src/datacustomcode/deploy.py @@ -35,6 +35,7 @@ import requests from datacustomcode.cmd import cmd_output +from datacustomcode.constants import REQUEST_TYPE_TO_FEATURE from datacustomcode.scan import find_base_directory, get_package_type DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code" @@ -65,6 +66,40 @@ def _sanitize_api_name(name: str) -> str: return sanitized +def infer_use_in_feature(entrypoint_path: str) -> Union[str, None]: + """Infer the use_in_feature from function signature. + + Checks both the request parameter type and return type annotation. + Both must map to the same feature for a valid inference. + + Uses static AST parsing to avoid importing dependencies. 
+ + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + The feature name if both request and response match, None otherwise + """ + from datacustomcode.function_utils import inspect_function_types_static + + request_type_name, response_type_name = inspect_function_types_static( + entrypoint_path + ) + + if not request_type_name or not response_type_name: + return None + + # Look up features for both types + request_feature = REQUEST_TYPE_TO_FEATURE.get(request_type_name) + response_feature = REQUEST_TYPE_TO_FEATURE.get(response_type_name) + + # Both must be present and must match + if request_feature and response_feature and request_feature == response_feature: + return request_feature + + return None + + class CodeExtensionMetadata(BaseModel): name: str version: str diff --git a/src/datacustomcode/function/feature_types/chunking.py b/src/datacustomcode/function/feature_types/chunking.py index bdf0d91..1425921 100644 --- a/src/datacustomcode/function/feature_types/chunking.py +++ b/src/datacustomcode/function/feature_types/chunking.py @@ -14,76 +14,184 @@ # limitations under the License. 
""" -Pydantic models for byoc-function-proto (uds_chunking.proto) -Auto-generated - validation rules from buf.validate +Pydantic models for Search Index Chunking V1 """ - +from enum import Enum from typing import ( - Any, Dict, List, - Literal, + Optional, + Union, +) + +from pydantic import ( + BaseModel, + ConfigDict, + Field, ) -from pydantic import BaseModel, Field +class DocumentType(str, Enum): + """Document type enumeration""" -class DocElement(BaseModel): - """Document element to be chunked""" + TEXT = "text" + TITLE = "title" + TABLE = "table" + IMAGE = "image" + LIST_ITEM = "list_item" + CODE_SNIPPET = "code_snippet" + PAGE_METADATA = "page_metadata" + + +class ChunkType(str, Enum): + TEXT = "text" + + +class SearchIndexChunkingV1PrependField(BaseModel): + """Field to prepend to chunk content""" - text: str = Field(..., description="Text content to be chunked") - metadata: Dict[str, Any] = Field( - default_factory=dict, description="Source document metadata" + dmo_name: str = Field( + default="", description="Data Model Object name", examples=["udmo_1__dlm"] ) + field_name: str = Field( + default="", + description="Field name to prepend", + examples=["ResolvedFilePath__c"], + ) + value: str = Field( + default="", + description="Field value to prepend", + examples=["udlo_1__dll:quarterly_report.pdf"], + ) + model_config = ConfigDict(extra="ignore") -class ChunkOutput(BaseModel): - """Output chunk from the chunking process""" +class SearchIndexChunkingV1TranscriptField(BaseModel): + """Field to prepend to chunk content""" - chunk_id: str = Field(..., description="UUID for this chunk") - chunk_type: str = Field(..., description="Type: 'text'") - text: str = Field(..., description="Chunk text content") - seq_no: int = Field(..., description="Sequential chunk number (1-based)") - metadata: Dict[str, str] = Field( - default_factory=dict, description="Metadata from source (DMO fields)" + speaker: str = Field( + default="", + description="Speaker name for 
audio/video transcripts", + examples=["Agent"], ) - tag_metadata: Dict[str, Any] = Field( - default_factory=dict, description="Additional tags" + start_timestamp: str = Field( + default="", + description="Start timestamp in ISO8601 format: YYYY-MM-DDTHH:MM:SS.ffffff", + examples=["2026-03-25T02:01:24.918000"], ) - citations: Dict[str, Any] = Field( - default_factory=dict, description="Citation information" + end_timestamp: str = Field( + default="", + description="End timestamp in ISO8601 format: YYYY-MM-DDTHH:MM:SS.ffffff", + examples=["2026-03-25T02:01:30.500000"], ) + model_config = ConfigDict(extra="ignore") -class StatusResponse(BaseModel): - """Status response for operation""" +class SearchIndexChunkingV1Metadata(BaseModel): + """Metadata for input documents""" - status_type: str = Field(..., description="'success' or 'error'") - status_message: str = Field(..., description="Human-readable status") + type: DocumentType = Field( + default=DocumentType.TEXT, description="Document type (text)", examples=["text"] + ) + transcript_fields: SearchIndexChunkingV1TranscriptField = Field( + default_factory=SearchIndexChunkingV1TranscriptField, + description=( + "Transcript information. Will only be there in case of audio-video files" + ), + ) + page_number: int = Field( + default=0, + description="Page number in the source document (0-based)", + examples=[1], + ) + text_as_html: Optional[str] = Field( + default=None, + description="HTML representation of the document text", + examples=["

Online Remittance Instructions

"], + ) + source_dmo_fields: Dict[str, Union[str, int]] = Field( + default_factory=dict, + description=( + "Source Data Model Object fields as key-value pairs " + "(values can be string or int)" + ), + examples=[ + { + "FilePath__c": "quarterly_report.pdf", + "Size__c": 1377454, + "ContentType__c": "pdf", + "LastModified__c": "2026-03-25T02:01:24.918000", + } + ], + ) + prepend: List[SearchIndexChunkingV1PrependField] = Field( + default_factory=list, description="List of fields to prepend to each chunk" + ) + model_config = ConfigDict(extra="ignore") + + +class SearchIndexChunkingV1DocElement(BaseModel): + """Document element to be chunked""" + text: str = Field( + default="", + description="Text content to be chunked", + examples=[ + ( + "Online Remittance Instructions\n\n" + "Transfer proceeds from the sale of your ESOP/RSUs easily." + ) + ], + ) + metadata: SearchIndexChunkingV1Metadata = Field( + default_factory=SearchIndexChunkingV1Metadata, + description="Source document metadata", + ) + model_config = ConfigDict(extra="ignore") -class UdsChunkingV1BatchRequest(BaseModel): - """Batch request for UDS chunking""" - version: Literal["v1"] = Field( - default="v1", description="API version, must be 'v1'" +class SearchIndexChunkingV1Output(BaseModel): + """Output chunk from the chunking process""" + + text: str = Field( + default="", + description="Chunk text content", + examples=["Online Remittance Instructions"], ) - input: List[DocElement] = Field( - ..., min_length=1, description="List of documents (min 1)" + seq_no: int = Field( + default=0, description="Sequential chunk number (1-based)", ge=1, examples=[1] ) - max_characters: int = Field(..., description="Max chars per chunk (default: 100)") - additional_params: Dict[str, Any] = Field( - default_factory=dict, description="Future extension point" + chunk_id: str = Field( + default="", + description="Unique identifier for this chunk (UUID format)", + examples=["550e8400-e29b-41d4-a716-446655440000"], ) + 
chunk_type: ChunkType = Field( + default=ChunkType.TEXT, + description="Type of chunk (e.g., 'text')", + examples=["text"], + ) + citations: Dict[str, str] = Field( + default_factory=dict, + description="Citation information as key-value pairs", + examples=[{"source": "quarterly_report.pdf"}], + ) + model_config = ConfigDict(extra="ignore") -class UdsChunkingV1BatchResponse(BaseModel): - """Batch response for UDS chunking""" +class SearchIndexChunkingV1Request(BaseModel): + """Request for Search Index Chunking""" - version: Literal["v1"] = Field( - default="v1", description="API version, must be 'v1'" + input: List[SearchIndexChunkingV1DocElement] = Field( + default_factory=list, description="List of documents to be chunked" ) - output: List[ChunkOutput] = Field( + model_config = ConfigDict(extra="ignore") + + +class SearchIndexChunkingV1Response(BaseModel): + """Batch response for UDS chunking""" + + output: List[SearchIndexChunkingV1Output] = Field( default_factory=list, description="Flat list of chunks from all docs" ) - status: StatusResponse = Field(..., description="Overall operation status") + model_config = ConfigDict(extra="ignore") diff --git a/src/datacustomcode/function_utils.py b/src/datacustomcode/function_utils.py new file mode 100644 index 0000000..c499526 --- /dev/null +++ b/src/datacustomcode/function_utils.py @@ -0,0 +1,367 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utilities for inspecting and working with function entrypoints.""" + +import ast +import importlib.util +import inspect +import json +import typing +from typing import ( + Any, + Optional, + Tuple, +) + + +def load_function_module(entrypoint_path: str, module_name: str = "function_module"): + """Load a function entrypoint as a Python module. + + Args: + entrypoint_path: Path to the entrypoint.py file + module_name: Name to assign to the module + + Returns: + The loaded module + + Raises: + ImportError: If the module cannot be loaded + """ + spec = importlib.util.spec_from_file_location(module_name, entrypoint_path) + if spec is None or spec.loader is None: + raise ImportError(f"Could not load module from {entrypoint_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def get_function_callable(module): + """Get the 'function' callable from a module. + + Args: + module: The module to extract the function from + + Returns: + The function callable + + Raises: + AttributeError: If module doesn't have a 'function' attribute + """ + if not hasattr(module, "function"): + raise AttributeError("Module does not have a 'function' callable") + return module.function + + +def get_type_name(type_annotation: Any) -> Optional[str]: + """Extract the type name from a type annotation. + + Args: + type_annotation: A type annotation object + + Returns: + The type name as a string, or None if it cannot be determined + """ + if type_annotation == inspect.Parameter.empty: + return None + + if hasattr(type_annotation, "__name__"): + return str(type_annotation.__name__) + + return str(type_annotation) + + +def get_function_signature_types( + function_callable, +) -> Tuple[Optional[Any], Optional[Any], Optional[str], Optional[str]]: + """Extract request and response types from a function signature. 
+ + Args: + function_callable: The function to inspect + + Returns: + Tuple of (request_type, response_type, request_type_name, response_type_name) + Any of these can be None if not found + """ + sig = inspect.signature(function_callable) + params = list(sig.parameters.values()) + + request_type = None + request_type_name = None + if len(params) >= 1: + request_type = params[0].annotation + request_type_name = get_type_name(request_type) + + response_type = sig.return_annotation + response_type_name = get_type_name(response_type) + + return request_type, response_type, request_type_name, response_type_name + + +def inspect_function_types_static( + entrypoint_path: str, +) -> Tuple[Optional[str], Optional[str]]: + """Inspect function types using static AST parsing (no imports). + + This parses the Python file without executing it, so it doesn't + require dependencies to be installed. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + Tuple of (request_type_name, response_type_name) + """ + try: + with open(entrypoint_path, "r") as f: + tree = ast.parse(f.read(), filename=entrypoint_path) + + # Find the 'function' definition + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == "function": + # Get request type (first parameter annotation) + request_type_name = None + if node.args.args and len(node.args.args) > 0: + first_param = node.args.args[0] + if first_param.annotation: + request_type_name = _get_type_name_from_ast( + first_param.annotation + ) + + # Get response type (return annotation) + response_type_name = None + if node.returns: + response_type_name = _get_type_name_from_ast(node.returns) + + return request_type_name, response_type_name + + return None, None + except Exception: + return None, None + + +def _get_type_name_from_ast(annotation) -> Optional[str]: + """Extract type name from an AST annotation node.""" + if isinstance(annotation, ast.Name): + # Simple type: MyType + return annotation.id + 
elif isinstance(annotation, ast.Attribute): + # Module.Type - just return the type name + return annotation.attr + elif isinstance(annotation, ast.Subscript): + # Generic type: List[MyType], Optional[MyType] + # Return the base type name + return _get_type_name_from_ast(annotation.value) + return None + + +def _import_pydantic_model(entrypoint_path: str, type_name: str) -> Optional[Any]: + """Import a Pydantic model by finding its import statement. + + Parses the entrypoint to find where the type is imported from, + then imports just that module (not the entrypoint itself). + + Args: + entrypoint_path: Path to entrypoint.py + type_name: Name of the type to import (e.g., "SearchIndexChunkingV1Request") + + Returns: + The Pydantic model class, or None if not found + """ + try: + with open(entrypoint_path, "r") as f: + tree = ast.parse(f.read(), filename=entrypoint_path) + + # Find where this type is imported from + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom): + # from module import Type1, Type2 + for alias in node.names: + if alias.name == type_name: + # Found it! Import from the module + module_name = node.module + if module_name: + module = importlib.import_module(module_name) + return getattr(module, type_name, None) + + return None + except Exception: + return None + + +def inspect_function_types( + entrypoint_path: str, +) -> Tuple[Optional[str], Optional[str]]: + """Inspect a function entrypoint and extract type names. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + Tuple of (request_type_name, response_type_name) + Either can be None if not found or on error + + Example: + >>> request_name, response_name = inspect_function_types( + ... "payload/entrypoint.py" + ... 
) + >>> print(request_name) # "SearchIndexChunkingV1Request" + >>> print(response_name) # "SearchIndexChunkingV1Response" + """ + try: + module = load_function_module(entrypoint_path, "temp_module") + function_callable = get_function_callable(module) + _, _, request_type_name, response_type_name = get_function_signature_types( + function_callable + ) + return request_type_name, response_type_name + except Exception: + return None, None + + +def get_request_type(entrypoint_path: str) -> Any: + """Get the request type annotation from a function entrypoint. + + Args: + entrypoint_path: Path to the entrypoint.py file + + Returns: + The request type (Pydantic model class) + + Raises: + ImportError: If the module cannot be loaded + AttributeError: If the function is not found + ValueError: If the function signature is invalid + """ + module = load_function_module(entrypoint_path) + function_callable = get_function_callable(module) + + sig = inspect.signature(function_callable) + params = list(sig.parameters.values()) + + if len(params) < 1: + raise ValueError("Function must accept at least one parameter (request)") + + request_type = params[0].annotation + if request_type == inspect.Parameter.empty: + raise ValueError("Function request parameter must have a type annotation") + + return request_type + + +def _generate_model_sample_data(model_type): + """Generate sample data for all fields in a Pydantic model. 
+ + Args: + model_type: A Pydantic model class + + Returns: + Dictionary with sample data for all fields + """ + from pydantic_core import PydanticUndefined + + sample_data = {} + for field_name, field_info in model_type.model_fields.items(): + # Use examples if available + if field_info.examples and len(field_info.examples) > 0: + sample_data[field_name] = field_info.examples[0] + # Check if field has a real default value + elif field_info.default is not PydanticUndefined: + sample_data[field_name] = field_info.default + else: + # Required field or field without default - generate sample + sample_data[field_name] = generate_sample_value( + field_info.annotation, field_name + ) + return sample_data + + +def generate_sample_value(field_type, field_name: str): + """Generate a sample value based on field type. + + Args: + field_type: The type annotation of the field + field_name: The name of the field (used for contextual sample generation) + + Returns: + A sample value appropriate for the field type + """ + origin = typing.get_origin(field_type) + + if origin is list or field_type is list: + args = typing.get_args(field_type) + if args: + return [generate_sample_value(args[0], field_name)] + return [] + elif origin is dict or field_type is dict: + return {} + elif field_type is str or origin is typing.Literal: + if "version" in field_name.lower(): + return "v1" + return f"sample_{field_name}" + elif field_type is int: + if "max" in field_name.lower() or "characters" in field_name.lower(): + return 100 + return 1 + elif field_type is float: + return 1.0 + elif field_type is bool: + return True + elif hasattr(field_type, "model_fields"): + # Nested Pydantic model - use shared helper + return _generate_model_sample_data(field_type) + else: + return None + + +def generate_test_json(entrypoint_path: str, output_path: str) -> None: + """Generate a sample test.json file for a function. 
+ + First tries static AST parsing to get type names, then uses those + to import only the Pydantic model classes (not the entrypoint). + + Args: + entrypoint_path: Path to the function entrypoint.py + output_path: Output path for test.json + + Raises: + ImportError: If the Pydantic model cannot be loaded + ValueError: If the request type is not found or not a Pydantic model + """ + # First, get the type name using static parsing (no imports) + request_type_name, _ = inspect_function_types_static(entrypoint_path) + + if not request_type_name: + raise ValueError("Could not determine request type from function signature") + + # Now try to import the Pydantic model class + # Look for it in the entrypoint's imports + request_type = _import_pydantic_model(entrypoint_path, request_type_name) + + if not request_type: + raise ValueError(f"Could not import Pydantic model: {request_type_name}") + + # Check if it's a Pydantic model + if not hasattr(request_type, "model_fields"): + raise ValueError("Request parameter type must be a Pydantic model") + + # Generate sample data for ALL fields (use defaults where available) + sample_data = _generate_model_sample_data(request_type) + sample_instance = request_type(**sample_data) + + # Write to file + with open(output_path, "w") as f: + json.dump(sample_instance.model_dump(), f, indent=2) diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index 0e5052a..6322270 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -70,6 +70,7 @@ def run_entrypoint( config_file: Union[str, None], dependencies: List[str], profile: str, + test_file: Optional[str] = None, sf_cli_org: Optional[str] = None, ) -> None: """Run the entrypoint for script or function with the given config and dependencies. @@ -79,6 +80,7 @@ def run_entrypoint( config_file: The config file to use. dependencies: The dependencies to import. profile: The credentials profile to use. + test_file: Optional test JSON file for function testing. 
def run_function_with_test(entrypoint: str, test_file: str) -> None:
    """Run a function with test data from a JSON file.

    Dependencies are already loaded by this point, so we just import
    the entrypoint module and call the function directly.

    Args:
        entrypoint: Path to the function entrypoint.py
        test_file: Path to test JSON file containing request data
    """
    from datacustomcode.function_utils import (
        get_function_callable,
        get_request_type,
        load_function_module,
    )

    # Load the entrypoint in the current (dependency-complete) environment
    # and resolve both the callable and its Pydantic request model.
    entry_module = load_function_module(entrypoint, "entrypoint_module")
    target = get_function_callable(entry_module)
    request_cls = get_request_type(entrypoint)

    # Read the raw test payload.
    with open(test_file, "r") as handle:
        payload = json.load(handle)

    # Let Pydantic validate the payload against the request model.
    try:
        request_obj = request_cls(**payload)
    except Exception as exc:
        raise ValueError(
            f"Failed to parse test data as {request_cls.__name__}: {exc}"
        ) from exc

    # Imported lazily so the module stays importable without the runtime.
    from datacustomcode.function import Runtime

    print(f"Running function with test data from {test_file}...")
    outcome = target(request_obj, Runtime())

    # Pretty-print the result between banner lines.
    banner = "=" * 80
    print("\n" + banner)
    print("RESULT:")
    print(banner)
    if hasattr(outcome, "model_dump"):
        print(json.dumps(outcome.model_dump(), indent=2))
    else:
        print(outcome)
    print(banner)
f"Copying feature-specific directory {source} to {destination}..." + ) + shutil.copytree(source, destination, dirs_exist_ok=True) + else: + logger.debug( + f"Copying feature-specific file {source} to {destination}..." + ) + shutil.copy2(source, destination) diff --git a/src/datacustomcode/templates/__init__.py b/src/datacustomcode/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datacustomcode/templates/function/chunking/payload/config.json b/src/datacustomcode/templates/function/chunking/payload/config.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/payload/config.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py new file mode 100644 index 0000000..dd199a7 --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py @@ -0,0 +1,145 @@ +import logging +import uuid + +from datacustomcode.function import Runtime +from datacustomcode.function.feature_types.chunking import ( + ChunkType, + SearchIndexChunkingV1Output, + SearchIndexChunkingV1Request, + SearchIndexChunkingV1Response, +) + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +# Default max chunk size (can be overridden if contract adds max_characters field) +DEFAULT_MAX_CHUNK_SIZE = 50 + + +def split_text_into_chunks(text: str, max_size: int, overlap: int = 20): + """Split text into chunks with overlap, trying to break at natural boundaries. + + Tries to break at natural boundaries in order of preference: + 1. Paragraph boundaries (\\n\\n) + 2. Line boundaries (\\n) + 3. Sentence boundaries (. ! ?) + 4. Word boundaries (space) + 5. 
Hard cut if no good boundary found + + Args: + text: Text to split + max_size: Maximum characters per chunk + overlap: Number of characters to overlap between chunks + + Returns: + List of text chunks + """ + if len(text) <= max_size: + return [text] + + chunks = [] + start = 0 + + while start < len(text): + # Determine end position for this chunk + end = start + max_size + + if end >= len(text): + # Last chunk + chunks.append(text[start:]) + break + + # Try to find a good breaking point (in order of preference) + chunk_text = text[start:end] + break_point = None + + # Try to break at paragraph boundary (\n\n) + last_paragraph = chunk_text.rfind("\n\n") + if last_paragraph > max_size * 0.5: # Only if it's past halfway + break_point = start + last_paragraph + 2 # +2 to skip the \n\n + + # Try to break at line boundary (\n) + if break_point is None: + last_newline = chunk_text.rfind("\n") + if last_newline > max_size * 0.5: + break_point = start + last_newline + 1 + + # Try to break at sentence boundary (. ! ?) + if break_point is None: + for punct in [". ", "! ", "? "]: + last_sentence = chunk_text.rfind(punct) + if last_sentence > max_size * 0.5: + break_point = start + last_sentence + len(punct) + break + + # Try to break at word boundary (space) + if break_point is None: + last_space = chunk_text.rfind(" ") + if last_space > max_size * 0.5: + break_point = start + last_space + 1 + + # If no good breaking point, just hard cut + if break_point is None: + break_point = end + + chunks.append(text[start:break_point].strip()) + + # Move start position with overlap + start = max(break_point - overlap, start + 1) + + return chunks + + +def function( + request: SearchIndexChunkingV1Request, runtime: Runtime +) -> SearchIndexChunkingV1Response: + """Chunk documents into smaller pieces for search indexing. 
+ + Args: + request: SearchIndexChunkingV1Request with input documents + runtime: Runtime context (unused but required by contract) + + Returns: + SearchIndexChunkingV1Response with chunked output + """ + print(f"Received {len(request.input)} documents to chunk") + + chunks = [] + seq_no = 1 + + # Use default max chunk size + max_chunk_size = DEFAULT_MAX_CHUNK_SIZE + + # Process each document + for doc_idx, doc in enumerate(request.input): + text = doc.text + metadata = doc.metadata + + print(f"Processing document {doc_idx + 1}: {len(text)} characters") + + # Split the text using our simple chunking algorithm + text_chunks = split_text_into_chunks(text, max_chunk_size, overlap=20) + + # Create chunk outputs + for chunk_text in text_chunks: + # Create citations from source_dmo_fields if available + citations = {} + if metadata.source_dmo_fields: + for key, value in metadata.source_dmo_fields.items(): + citations[key] = str(value) + + chunk_output = SearchIndexChunkingV1Output( + chunk_id=str(uuid.uuid4()), + chunk_type=ChunkType.TEXT, + text=chunk_text.strip(), + seq_no=seq_no, + citations=citations, + ) + chunks.append(chunk_output) + + print(f"Chunk {seq_no}: {len(chunk_text)} chars") + seq_no += 1 + + print(f"Generated {len(chunks)} chunks total") + + return SearchIndexChunkingV1Response(output=chunks) diff --git a/src/datacustomcode/templates/function/chunking/requirements.txt b/src/datacustomcode/templates/function/chunking/requirements.txt new file mode 100644 index 0000000..219536a --- /dev/null +++ b/src/datacustomcode/templates/function/chunking/requirements.txt @@ -0,0 +1 @@ +# Packages required for the chunking function diff --git a/tests/test_cli.py b/tests/test_cli.py index e26cbdc..7765560 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -103,16 +103,19 @@ def test_deploy_command_success(self, mock_token_provider, mock_deploy_full): assert call_args[0][2].access_token == "test_token" assert call_args[0][2].instance_url == 
"https://instance.example.com" + @patch("datacustomcode.deploy.infer_use_in_feature") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_function_invoke_options( - self, mock_token_provider, mock_deploy_full + self, mock_token_provider, mock_deploy_full, mock_infer_feature ): """Test deploy command with function invoke options.""" mock_provider_instance = mock_token_provider.return_value mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) + # Mock infer_use_in_feature to return a valid feature + mock_infer_feature.return_value = "SearchIndexChunking" runner = CliRunner() with runner.isolated_filesystem(): @@ -122,16 +125,12 @@ def test_deploy_command_function_invoke_options( write_sdk_config(".", sdk_config) result = runner.invoke( deploy, - ["--name", "test-job", "--function-invoke-opt", "option1,option2"], + ["--name", "test-job"], ) assert result.exit_code == 0 mock_deploy_full.assert_called_once() - # Check that deploy_full was called with function invoke options - call_args = mock_deploy_full.call_args - assert call_args[0][1].functionInvokeOptions == ["option1", "option2"] - @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_credentials_error(self, mock_token_provider): """Test deploy command when credentials are not available.""" diff --git a/tests/test_function_utils.py b/tests/test_function_utils.py new file mode 100644 index 0000000..cc0f51d --- /dev/null +++ b/tests/test_function_utils.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import sys +import tempfile +import textwrap + +import pytest + +from datacustomcode import function_utils + + +@pytest.fixture +def sample_entrypoint(): + """Create a temporary entrypoint file with a function.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file: + entrypoint_content = textwrap.dedent( + """ + from typing import List + from pydantic import BaseModel + + class SampleRequest(BaseModel): + message: str + count: int = 5 + tags: List[str] = [] + version: str = "v1" + + class SampleResponse(BaseModel): + result: str + success: bool = True + + def function(request: SampleRequest) -> SampleResponse: + return SampleResponse(result=f"Processed {request.message}") + """ + ) + temp_file.write(entrypoint_content) + temp_file_path = temp_file.name + + yield temp_file_path + + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +@pytest.fixture +def entrypoint_no_annotations(): + """Create an entrypoint with no type annotations.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as temp_file: + entrypoint_content = textwrap.dedent( + """ + def function(request): + return {"result": "no annotations"} + """ + ) + temp_file.write(entrypoint_content) + temp_file_path = temp_file.name + + yield temp_file_path + + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +def test_get_function_signature_types(sample_entrypoint, entrypoint_no_annotations): + """Test extracting request and response types from function signatures.""" + 
module = function_utils.load_function_module(sample_entrypoint) + func = function_utils.get_function_callable(module) + req_type, resp_type, req_name, resp_name = ( + function_utils.get_function_signature_types(func) + ) + + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + assert req_type is not None + assert resp_type is not None + + module_no_annot = function_utils.load_function_module(entrypoint_no_annotations) + func_no_annot = function_utils.get_function_callable(module_no_annot) + req_type, resp_type, req_name, resp_name = ( + function_utils.get_function_signature_types(func_no_annot) + ) + + assert req_name is None + assert resp_name is None + + +def test_inspect_function_types_static(sample_entrypoint, entrypoint_no_annotations): + """Test static AST-based inspection of function types.""" + req_name, resp_name = function_utils.inspect_function_types_static( + sample_entrypoint + ) + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + + req_name, resp_name = function_utils.inspect_function_types_static( + entrypoint_no_annotations + ) + assert req_name is None + assert resp_name is None + + +def test_inspect_function_types(sample_entrypoint): + """Test dynamic inspection of function types.""" + req_name, resp_name = function_utils.inspect_function_types(sample_entrypoint) + assert req_name == "SampleRequest" + assert resp_name == "SampleResponse" + + req_name, resp_name = function_utils.inspect_function_types("/nonexistent/file.py") + assert req_name is None + assert resp_name is None + + +def test_get_request_type(sample_entrypoint, entrypoint_no_annotations): + """Test getting request type from entrypoint.""" + req_type = function_utils.get_request_type(sample_entrypoint) + assert req_type is not None + assert hasattr(req_type, "model_fields") + + with pytest.raises(ValueError, match="must have a type annotation"): + function_utils.get_request_type(entrypoint_no_annotations) + + +def 
def test_generate_test_json():
    """Test generating test.json with simple and complex nested types."""
    scratch = tempfile.mkdtemp()
    models_path = os.path.join(scratch, "test_models.py")

    try:
        # Shared Pydantic models imported by both generated entrypoints.
        with open(models_path, "w") as fh:
            fh.write(
                textwrap.dedent(
                    """
                    from pydantic import BaseModel
                    from typing import List

                    class SimpleRequest(BaseModel):
                        message: str
                        count: int = 5
                        tags: List[str] = []
                        version: str = "v1"

                    class NestedConfig(BaseModel):
                        host: str
                        port: int = 8080
                        enabled: bool = True

                    class ComplexRequest(BaseModel):
                        name: str
                        max_items: int = 100
                        config: NestedConfig
                        metadata: dict = {}
                    """
                )
            )

        # Test 1: simple request type.
        simple_entry = os.path.join(scratch, "entrypoint_simple.py")
        simple_out = os.path.join(scratch, "test_simple.json")
        with open(simple_entry, "w") as fh:
            fh.write(
                textwrap.dedent(
                    """
                    from test_models import SimpleRequest

                    def function(request: SimpleRequest):
                        return {"result": "ok"}
                    """
                )
            )

        # Make test_models importable by the generated entrypoints.
        sys.path.insert(0, scratch)

        function_utils.generate_test_json(simple_entry, simple_out)
        assert os.path.exists(simple_out)

        with open(simple_out, "r") as fh:
            simple_data = json.load(fh)

        assert "message" in simple_data
        assert simple_data["count"] == 5
        assert simple_data["version"] == "v1"
        assert simple_data["tags"] == []

        # Test 2: complex request type with nested models.
        complex_entry = os.path.join(scratch, "entrypoint_complex.py")
        complex_out = os.path.join(scratch, "test_complex.json")
        with open(complex_entry, "w") as fh:
            fh.write(
                textwrap.dedent(
                    """
                    from test_models import ComplexRequest

                    def function(request: ComplexRequest):
                        return {"result": "ok"}
                    """
                )
            )

        function_utils.generate_test_json(complex_entry, complex_out)
        assert os.path.exists(complex_out)

        with open(complex_out, "r") as fh:
            complex_data = json.load(fh)

        assert "name" in complex_data
        assert "max_items" in complex_data
        assert complex_data["max_items"] == 100
        assert "config" in complex_data
        assert isinstance(complex_data["config"], dict)
        assert "host" in complex_data["config"]
        assert "port" in complex_data["config"]
        assert complex_data["config"]["port"] == 8080
        assert complex_data["config"]["enabled"] is True
        assert "metadata" in complex_data
        assert complex_data["metadata"] == {}

    finally:
        if scratch in sys.path:
            sys.path.remove(scratch)
        if os.path.exists(scratch):
            shutil.rmtree(scratch)