# Post-change state of the definitions touched in
# packages/markitdown/src/markitdown/_markitdown.py, reformatted from the
# collapsed patch text.

# Optional data-URI preamble (e.g. "data:image/png;base64,") that may precede
# a base64 payload; it is stripped before validating or decoding.
_DATA_URI_PREFIX = r"^data:.*base64,"


def isBase64(sb: object) -> bool:
    """Return True if *sb* is a non-empty, canonical base64 payload.

    Args:
        - sb: a ``str`` (optionally starting with a ``data:...base64,``
          prefix, which is ignored) or a ``bytes`` object. Any other type
          is reported as not base64.

    Returns:
        True when decoding and re-encoding the payload reproduces it exactly;
        False for empty payloads, non-ASCII text, malformed base64 and
        unsupported types.

    NOTE: this is a heuristic. Short plain words whose characters all belong
    to the base64 alphabet and whose length is a multiple of four (e.g.
    "abcd") also round-trip and are therefore reported as base64.
    """
    if isinstance(sb, str):
        payload = re.sub(_DATA_URI_PREFIX, "", sb)
        try:
            sb_bytes = payload.encode("ascii")
        except UnicodeEncodeError:
            # base64 text is pure ASCII by definition.
            return False
    elif isinstance(sb, bytes):
        sb_bytes = sb
    else:
        return False

    # An empty payload round-trips trivially but carries no document.
    if not sb_bytes:
        return False

    try:
        decoded = base64.b64decode(sb_bytes, validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        return False
    return base64.b64encode(decoded) == sb_bytes


def _detect_input_type(source: object) -> str:
    """Map *source* to one of the concrete ``input_type`` names used by convert()."""
    if isinstance(source, str):
        if source.startswith(("http://", "https://", "file://")):
            return "url"
        if os.path.isfile(source):
            return "local_file"
        if isBase64(source):
            return "base64"
        # Any other string is treated as a local path, as it was before
        # input_type existed, so convert_local reports the problem itself.
        return "local_file"
    if isinstance(source, bytes):
        # Bytes that look like base64 text are decoded; anything else is raw.
        return "base64" if isBase64(source) else "bytes"
    if isinstance(source, Path):
        return "local_file"
    if isinstance(source, requests.Response):
        return "request_response"
    raise ValueError(f"Unable to determine input type: {type(source)}")


def _decode_base64(source: Union[str, bytes]) -> bytes:
    """Strip an optional data-URI prefix from *source* and base64-decode it."""
    if isinstance(source, (bytes, bytearray)):
        # Normalise to text so the str prefix pattern can be applied.
        source = bytes(source).decode("ascii")
    return base64.b64decode(re.sub(_DATA_URI_PREFIX, "", source))


class MarkItDown:
    """(In preview) An extremely simple text-based document reader, suitable for LLM use."""

    # NOTE: only ``convert`` is reproduced here. The other members it calls
    # (convert_local, convert_url, convert_response) are outside this excerpt.

    def convert(
        self,
        source: Union[str, bytes, requests.Response, Path],
        input_type: Literal[
            "auto", "local_file", "url", "base64", "bytes", "request_response"
        ] = "auto",
        **kwargs: Any,
    ) -> DocumentConverterResult:  # TODO: deal with kwargs
        """
        Args:
            - source: a string holding a local path, a url or base64 data (optionally with a "data:...base64," prefix), a pathlib path object, a bytes object, or a requests.response object
            - input_type: specifies the input type. If set to "auto", the function will try to automatically determine the type.
            - extension: specifies the file extension to use when interpreting the file. If None, infer from source (path, uri, content-type, etc.)

        Raises:
            ValueError: if input_type is "auto" and the type of source is not
                supported, or if input_type is not one of the accepted names.
        """
        if input_type == "auto":
            input_type = _detect_input_type(source)

        # Deliberately a fresh chain (not ``elif`` on the "auto" test) so the
        # detected type is actually dispatched.
        if input_type == "url":
            return self.convert_url(source, **kwargs)
        if input_type == "local_file":
            return self.convert_local(source, **kwargs)
        if input_type == "request_response":
            return self.convert_response(source, **kwargs)
        if input_type == "base64":
            return self._convert_bytes(_decode_base64(source), **kwargs)
        if input_type == "bytes":
            return self._convert_bytes(source, **kwargs)
        raise ValueError(f"Invalid input type: {input_type}")

    def _convert_bytes(self, data: bytes, **kwargs: Any) -> DocumentConverterResult:
        """Write *data* to a temporary file, convert that file, then delete it."""
        # Created outside the try block so the name is always bound in finally.
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            # Close the handle before converting: an open NamedTemporaryFile
            # cannot be reopened by name on Windows.
            with tmp_file:
                tmp_file.write(data)
            return self.convert_local(tmp_file.name, **kwargs)
        finally:
            if os.path.exists(tmp_file.name):
                os.remove(tmp_file.name)