Skip to content
Merged
82 changes: 39 additions & 43 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
return _sentry_patched_create_async


def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
def _new_sync_embeddings_create(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)
Expand All @@ -1157,7 +1157,13 @@ def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
_set_embeddings_input_data(span, kwargs, integration)

response = yield f, args, kwargs
try:
response = f(*args, **kwargs)
except Exception as exc:
exc_info = sys.exc_info()
with capture_internal_exceptions():
_capture_exception(exc, manual_span_cleanup=False)
reraise(*exc_info)

_set_embeddings_output_data(
span, response, kwargs, integration, finish_span=False
Expand All @@ -1166,68 +1172,58 @@ def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "A
return response


def _wrap_embeddings_create(f: "Any") -> "Any":
def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
gen = _new_embeddings_create_common(f, *args, **kwargs)
async def _new_async_embeddings_create(
f: "Any", *args: "Any", **kwargs: "Any"
) -> "Any":
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return await f(*args, **kwargs)

try:
f, args, kwargs = next(gen)
except StopIteration as e:
return e.value
model = kwargs.get("model")

with sentry_sdk.start_span(
op=consts.OP.GEN_AI_EMBEDDINGS,
name=f"embeddings {model}",
origin=OpenAIIntegration.origin,
) as span:
span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
_set_embeddings_input_data(span, kwargs, integration)

try:
try:
result = f(*args, **kwargs)
except Exception as e:
exc_info = sys.exc_info()
with capture_internal_exceptions():
_capture_exception(e, manual_span_cleanup=False)
reraise(*exc_info)

return gen.send(result)
except StopIteration as e:
return e.value
response = await f(*args, **kwargs)
except Exception as exc:
exc_info = sys.exc_info()
with capture_internal_exceptions():
_capture_exception(exc, manual_span_cleanup=False)
reraise(*exc_info)

_set_embeddings_output_data(
span, response, kwargs, integration, finish_span=False
)

return response


def _wrap_embeddings_create(f: "Any") -> "Any":
@wraps(f)
def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return f(*args, **kwargs)

return _execute_sync(f, *args, **kwargs)
return _new_sync_embeddings_create(f, *args, **kwargs)

return _sentry_patched_create_sync


def _wrap_async_embeddings_create(f: "Any") -> "Any":
async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
gen = _new_embeddings_create_common(f, *args, **kwargs)

try:
f, args, kwargs = next(gen)
except StopIteration as e:
return await e.value

try:
try:
result = await f(*args, **kwargs)
except Exception as e:
exc_info = sys.exc_info()
with capture_internal_exceptions():
_capture_exception(e, manual_span_cleanup=False)
reraise(*exc_info)

return gen.send(result)
except StopIteration as e:
return e.value

@wraps(f)
async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
if integration is None:
return await f(*args, **kwargs)

return await _execute_async(f, *args, **kwargs)
return await _new_async_embeddings_create(f, *args, **kwargs)

return _sentry_patched_create_async

Expand Down
Loading