Merge c856a6c2c1 into 041be54471
This commit is contained in:
commit
69437fdc34
5 changed files with 34 additions and 69 deletions
|
|
@ -8,7 +8,7 @@ from ._markitdown import (
|
|||
PRIORITY_SPECIFIC_FILE_FORMAT,
|
||||
PRIORITY_GENERIC_FILE_FORMAT,
|
||||
)
|
||||
from ._base_converter import DocumentConverterResult, DocumentConverter
|
||||
from ._base_converter import DocumentConverterResult, AsyncDocumentConverterResult, DocumentConverter
|
||||
from ._stream_info import StreamInfo
|
||||
from ._exceptions import (
|
||||
MarkItDownException,
|
||||
|
|
@ -23,6 +23,7 @@ __all__ = [
|
|||
"MarkItDown",
|
||||
"DocumentConverter",
|
||||
"DocumentConverterResult",
|
||||
"AsyncDocumentConverterResult",
|
||||
"MarkItDownException",
|
||||
"MissingDependencyException",
|
||||
"FailedConversionAttempt",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,4 @@
|
|||
import os
|
||||
import tempfile
|
||||
from warnings import warn
|
||||
from typing import Any, Union, BinaryIO, Optional, List
|
||||
from typing import Any, BinaryIO, Optional, Awaitable
|
||||
from ._stream_info import StreamInfo
|
||||
|
||||
|
||||
|
|
@ -41,6 +38,14 @@ class DocumentConverterResult:
|
|||
"""Return the converted Markdown text."""
|
||||
return self.markdown
|
||||
|
||||
class AsyncDocumentConverterResult:
|
||||
"""The result of converting a document to Markdown."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_content: Awaitable[str],
|
||||
):
|
||||
self.text_content = text_content
|
||||
|
||||
class DocumentConverter:
|
||||
"""Abstract superclass of all DocumentConverters."""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import copy
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
import asyncio
|
||||
import traceback
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -600,6 +598,9 @@ class MarkItDown:
|
|||
finally:
|
||||
file_stream.seek(cur_pos)
|
||||
|
||||
if asyncio.iscoroutine(res):
|
||||
return res
|
||||
|
||||
if res is not None:
|
||||
# Normalize the content
|
||||
res.text_content = "\n".join(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import BinaryIO, Any, Union
|
||||
import base64
|
||||
import mimetypes
|
||||
from typing import BinaryIO, Any
|
||||
import asyncio
|
||||
from ._exiftool import exiftool_metadata
|
||||
from .._base_converter import DocumentConverter, DocumentConverterResult
|
||||
from .._base_converter import DocumentConverter, DocumentConverterResult, AsyncDocumentConverterResult
|
||||
from .._stream_info import StreamInfo
|
||||
from ._llm_caption import llm_caption
|
||||
|
||||
ACCEPTED_MIME_TYPE_PREFIXES = [
|
||||
"image/jpeg",
|
||||
|
|
@ -69,7 +69,7 @@ class ImageConverter(DocumentConverter):
|
|||
llm_client = kwargs.get("llm_client")
|
||||
llm_model = kwargs.get("llm_model")
|
||||
if llm_client is not None and llm_model is not None:
|
||||
llm_description = self._get_llm_description(
|
||||
llm_description = llm_caption(
|
||||
file_stream,
|
||||
stream_info,
|
||||
client=llm_client,
|
||||
|
|
@ -77,62 +77,14 @@ class ImageConverter(DocumentConverter):
|
|||
prompt=kwargs.get("llm_prompt"),
|
||||
)
|
||||
|
||||
if asyncio.iscoroutine(llm_description):
|
||||
return AsyncDocumentConverterResult(
|
||||
llm_description,
|
||||
)
|
||||
|
||||
if llm_description is not None:
|
||||
md_content += "\n# Description:\n" + llm_description.strip() + "\n"
|
||||
|
||||
return DocumentConverterResult(
|
||||
markdown=md_content,
|
||||
)
|
||||
|
||||
def _get_llm_description(
|
||||
self,
|
||||
file_stream: BinaryIO,
|
||||
stream_info: StreamInfo,
|
||||
*,
|
||||
client,
|
||||
model,
|
||||
prompt=None,
|
||||
) -> Union[None, str]:
|
||||
if prompt is None or prompt.strip() == "":
|
||||
prompt = "Write a detailed caption for this image."
|
||||
|
||||
# Get the content type
|
||||
content_type = stream_info.mimetype
|
||||
if not content_type:
|
||||
content_type, _ = mimetypes.guess_type(
|
||||
"_dummy" + (stream_info.extension or "")
|
||||
)
|
||||
if not content_type:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
# Convert to base64
|
||||
cur_pos = file_stream.tell()
|
||||
try:
|
||||
base64_image = base64.b64encode(file_stream.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
return None
|
||||
finally:
|
||||
file_stream.seek(cur_pos)
|
||||
|
||||
# Prepare the data-uri
|
||||
data_uri = f"data:{content_type};base64,{base64_image}"
|
||||
|
||||
# Prepare the OpenAI API request
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": data_uri,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Call the OpenAI API
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
return response.choices[0].message.content
|
||||
)
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import BinaryIO, Any, Union
|
||||
from typing import BinaryIO, Any, Union, Awaitable
|
||||
import base64
|
||||
import mimetypes
|
||||
import asyncio
|
||||
from .._stream_info import StreamInfo
|
||||
|
||||
|
||||
def llm_caption(
|
||||
file_stream: BinaryIO, stream_info: StreamInfo, *, client, model, prompt=None
|
||||
) -> Union[None, str]:
|
||||
) -> Union[None, str, Awaitable[str]]:
|
||||
if prompt is None or prompt.strip() == "":
|
||||
prompt = "Write a detailed caption for this image."
|
||||
|
||||
|
|
@ -47,4 +48,9 @@ def llm_caption(
|
|||
|
||||
# Call the OpenAI API
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
if asyncio.iscoroutine(response):
|
||||
async def read_content(response):
|
||||
response = await response
|
||||
return response.choices[0].message.content
|
||||
return read_content(response)
|
||||
return response.choices[0].message.content
|
||||
|
|
|
|||
Loading…
Reference in a new issue