Support Async LLM Client, reuse llm_caption helper

hlohaus 2025-04-30 23:16:22 +02:00
parent 041be54471
commit 4a2f793869
5 changed files with 34 additions and 69 deletions

View file

@@ -8,7 +8,7 @@ from ._markitdown import (
     PRIORITY_SPECIFIC_FILE_FORMAT,
     PRIORITY_GENERIC_FILE_FORMAT,
 )
-from ._base_converter import DocumentConverterResult, DocumentConverter
+from ._base_converter import DocumentConverterResult, AsyncDocumentConverterResult, DocumentConverter
 from ._stream_info import StreamInfo
 from ._exceptions import (
     MarkItDownException,
@@ -23,6 +23,7 @@ __all__ = [
     "MarkItDown",
     "DocumentConverter",
     "DocumentConverterResult",
+    "AsyncDocumentConverterResult",
     "MarkItDownException",
     "MissingDependencyException",
     "FailedConversionAttempt",

View file

@@ -1,7 +1,4 @@
-import os
-import tempfile
-from warnings import warn
-from typing import Any, Union, BinaryIO, Optional, List
+from typing import Any, BinaryIO, Optional, Awaitable
 
 from ._stream_info import StreamInfo
 
@@ -41,6 +38,14 @@ class DocumentConverterResult:
         """Return the converted Markdown text."""
         return self.markdown
+
+class AsyncDocumentConverterResult:
+    """The result of converting a document to Markdown."""
+    def __init__(
+        self,
+        text_content: Awaitable[str],
+    ):
+        self.text_content = text_content
 
 
 class DocumentConverter:
     """Abstract superclass of all DocumentConverters."""

View file

@@ -1,11 +1,9 @@
-import copy
 import mimetypes
 import os
 import re
 import sys
 import shutil
-import tempfile
-import warnings
+import asyncio
 import traceback
 import io
 from dataclasses import dataclass
@@ -600,6 +598,9 @@ class MarkItDown:
                 finally:
                     file_stream.seek(cur_pos)
 
+            if asyncio.iscoroutine(res):
+                return res
+
             if res is not None:
                 # Normalize the content
                 res.text_content = "\n".join(
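One plausible reading of this pass-through: a registered converter may implement convert() as a coroutine function, and MarkItDown.convert() then hands the un-awaited coroutine straight back to the caller, skipping text normalization. A sketch under that assumption (EchoConverter, the notes.txt path, and the use of the standard register_converter API are illustrative, not part of this commit):

import asyncio

from markitdown import MarkItDown, DocumentConverter, DocumentConverterResult


class EchoConverter(DocumentConverter):
    """Illustrative converter whose convert() is a coroutine function."""

    def accepts(self, file_stream, stream_info, **kwargs) -> bool:
        return True

    async def convert(self, file_stream, stream_info, **kwargs) -> DocumentConverterResult:
        text = file_stream.read().decode("utf-8", errors="replace")
        return DocumentConverterResult(markdown=text)


async def main() -> None:
    md = MarkItDown()
    md.register_converter(EchoConverter())
    result = md.convert("notes.txt")  # hypothetical local file
    if asyncio.iscoroutine(result):   # an async converter handled it
        result = await result
    print(result.markdown)


asyncio.run(main())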

View file

@@ -1,9 +1,9 @@
-from typing import BinaryIO, Any, Union
-import base64
-import mimetypes
+from typing import BinaryIO, Any
+import asyncio
 from ._exiftool import exiftool_metadata
-from .._base_converter import DocumentConverter, DocumentConverterResult
+from .._base_converter import DocumentConverter, DocumentConverterResult, AsyncDocumentConverterResult
 from .._stream_info import StreamInfo
+from ._llm_caption import llm_caption
 
 ACCEPTED_MIME_TYPE_PREFIXES = [
     "image/jpeg",
@@ -69,7 +69,7 @@ class ImageConverter(DocumentConverter):
         llm_client = kwargs.get("llm_client")
         llm_model = kwargs.get("llm_model")
         if llm_client is not None and llm_model is not None:
-            llm_description = self._get_llm_description(
+            llm_description = llm_caption(
                 file_stream,
                 stream_info,
                 client=llm_client,
@@ -77,62 +77,14 @@
                 prompt=kwargs.get("llm_prompt"),
             )
 
+            if asyncio.iscoroutine(llm_description):
+                return AsyncDocumentConverterResult(
+                    llm_description,
+                )
+
             if llm_description is not None:
                 md_content += "\n# Description:\n" + llm_description.strip() + "\n"
 
         return DocumentConverterResult(
             markdown=md_content,
         )
-
-    def _get_llm_description(
-        self,
-        file_stream: BinaryIO,
-        stream_info: StreamInfo,
-        *,
-        client,
-        model,
-        prompt=None,
-    ) -> Union[None, str]:
-        if prompt is None or prompt.strip() == "":
-            prompt = "Write a detailed caption for this image."
-
-        # Get the content type
-        content_type = stream_info.mimetype
-        if not content_type:
-            content_type, _ = mimetypes.guess_type(
-                "_dummy" + (stream_info.extension or "")
-            )
-        if not content_type:
-            content_type = "application/octet-stream"
-
-        # Convert to base64
-        cur_pos = file_stream.tell()
-        try:
-            base64_image = base64.b64encode(file_stream.read()).decode("utf-8")
-        except Exception as e:
-            return None
-        finally:
-            file_stream.seek(cur_pos)
-
-        # Prepare the data-uri
-        data_uri = f"data:{content_type};base64,{base64_image}"
-
-        # Prepare the OpenAI API request
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": prompt},
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": data_uri,
-                        },
-                    },
-                ],
-            }
-        ]
-
-        # Call the OpenAI API
-        response = client.chat.completions.create(model=model, messages=messages)
-        return response.choices[0].message.content

View file

@@ -1,12 +1,13 @@
-from typing import BinaryIO, Any, Union
+from typing import BinaryIO, Any, Union, Awaitable
 import base64
 import mimetypes
+import asyncio
 
 from .._stream_info import StreamInfo
 
 
 def llm_caption(
     file_stream: BinaryIO, stream_info: StreamInfo, *, client, model, prompt=None
-) -> Union[None, str]:
+) -> Union[None, str, Awaitable[str]]:
     if prompt is None or prompt.strip() == "":
         prompt = "Write a detailed caption for this image."
@@ -47,4 +48,9 @@ def llm_caption(
 
     # Call the OpenAI API
     response = client.chat.completions.create(model=model, messages=messages)
+    if asyncio.iscoroutine(response):
+        async def read_content():
+            result = await response
+            return result.choices[0].message.content
+        return read_content()
     return response.choices[0].message.content
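Taken together, llm_caption now works with both client flavours: a synchronous OpenAI-style client yields a plain string, while an async client (whose chat.completions.create call returns a coroutine) yields an Awaitable[str]. A usage sketch, assuming openai.AsyncOpenAI as the async client; the markitdown.converters._llm_caption module path is taken from the upstream markitdown layout, and the image path and model name are placeholders:

import asyncio

from openai import AsyncOpenAI  # any client whose create() returns a coroutine would do

from markitdown._stream_info import StreamInfo
from markitdown.converters._llm_caption import llm_caption  # path assumed from upstream markitdown


async def main() -> None:
    with open("photo.jpg", "rb") as fh:  # hypothetical local image
        caption = llm_caption(
            fh,
            StreamInfo(mimetype="image/jpeg", extension=".jpg"),
            client=AsyncOpenAI(),
            model="gpt-4o",  # placeholder model name
        )
        # Async client: the helper returns the un-awaited coroutine.
        if asyncio.iscoroutine(caption):
            caption = await caption
    print(caption)


asyncio.run(main())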