Browse Source

Merge pull request #13 from papayalove/master

更新io modules
drunkpig 1 year ago
parent
commit
c5424292d7
3 changed files with 128 additions and 20 deletions
  1. 10 8
      magic_pdf/io/AbsReaderWriter.py
  2. 49 0
      magic_pdf/io/DiskReaderWriter.py
  3. 69 12
      magic_pdf/io/S3ReaderWriter.py

+ 10 - 8
magic_pdf/io/AbsReaderWriter.py

@@ -1,20 +1,22 @@
-
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 
 
 
 
 class AbsReaderWriter(ABC):
 class AbsReaderWriter(ABC):
     """
     """
     Abstract base class for readers/writers that support both binary and text I/O.
     Abstract base class for readers/writers that support both binary and text I/O.
-    TODO
     """
     """
+
+    def __init__(self):
+        # Initialization code can be added here if needed.
+        pass
+
     @abstractmethod
     @abstractmethod
-    def read(self, path: str):
+    def read(self, path: str, mode="text"):
+        # Abstract: concrete subclasses supply the read logic.
+        # NOTE(review): the implementations visible in this change accept
+        # mode "text" or "binary" and raise ValueError otherwise.
         pass
         pass
 
 
     @abstractmethod
     @abstractmethod
-    def write(self, path: str, content: str):
+    def write(self, content: str, path: str, mode="text"):
+        # Abstract: concrete subclasses supply the write logic.
+        # NOTE(review): the argument order changed from (path, content) to
+        # (content, path); positional callers of the old signature must be updated.
+        # NOTE(review): `content` is annotated `str`, but the binary mode of the
+        # implementations in this change is given bytes - confirm the annotation.
         pass
         pass
-    
-    
-    
-    
+
+
+

+ 49 - 0
magic_pdf/io/DiskReaderWriter.py

@@ -0,0 +1,49 @@
+import os
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from loguru import logger
+class DiskReaderWriter(AbsReaderWriter):
+    """Read and write one local file, in either text or binary mode.
+
+    The target file is fixed at construction time; ``read`` and ``write``
+    always operate on ``self.path``.
+
+    NOTE(review): ``read``/``write`` take no ``path`` argument, unlike the
+    abstract signatures declared on ``AbsReaderWriter``. The signatures are
+    kept as they are so existing callers of this class keep working.
+    """
+
+    def __init__(self, parent_path, encoding='utf-8'):
+        """Store the target path and the text encoding.
+
+        Args:
+            parent_path: Path handed directly to ``open`` by ``read`` and
+                ``write``, so it must name a file.
+            encoding: Encoding used when the file is opened in text mode.
+        """
+        super().__init__()
+        self.path = parent_path
+        self.encoding = encoding
+
+    def read(self, mode="text"):
+        """Return the whole content of ``self.path``.
+
+        Args:
+            mode: ``"text"`` to decode with ``self.encoding``, ``"binary"``
+                to get the raw bytes.
+
+        Returns:
+            ``str`` in text mode, ``bytes`` in binary mode.
+
+        Raises:
+            FileNotFoundError: If ``self.path`` does not exist.
+            ValueError: If ``mode`` is neither ``"text"`` nor ``"binary"``.
+        """
+        if not os.path.exists(self.path):
+            logger.error(f"文件 {self.path} 不存在")
+            # FileNotFoundError is a subclass of Exception, so callers that
+            # catch the previously raised bare Exception still work.
+            raise FileNotFoundError(f"文件 {self.path} 不存在")
+        if mode == "text":
+            with open(self.path, 'r', encoding=self.encoding) as f:
+                return f.read()
+        if mode == "binary":
+            with open(self.path, 'rb') as f:
+                return f.read()
+        raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+
+    def write(self, data, mode="text"):
+        """Overwrite ``self.path`` with ``data``.
+
+        Args:
+            data: ``str`` in text mode, ``bytes`` in binary mode.
+            mode: ``"text"`` to encode with ``self.encoding``, ``"binary"``
+                to write the bytes unchanged.
+
+        Raises:
+            ValueError: If ``mode`` is neither ``"text"`` nor ``"binary"``.
+                The mode is checked before the file is opened, so an invalid
+                mode never truncates the file.
+        """
+        if mode == "text":
+            open_kwargs = {"mode": "w", "encoding": self.encoding}
+        elif mode == "binary":
+            open_kwargs = {"mode": "wb"}
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        with open(self.path, **open_kwargs) as f:
+            f.write(data)
+            logger.info(f"内容已成功写入 {self.path}")
+
+
+# Usage example
+if __name__ == "__main__":
+    file_path = "example.txt"
+    payload = b"Hello, World!"
+    disk_rw = DiskReaderWriter(file_path)
+
+    # Write the raw bytes to the file.
+    disk_rw.write(payload, "binary")
+
+    # Read the file back in text mode and log it when it is non-empty.
+    content = disk_rw.read(mode="text")
+    if content:
+        logger.info(f"从 {file_path} 读取的内容: {content}")
+
+

+ 69 - 12
magic_pdf/io/S3ReaderWriter.py

@@ -1,18 +1,75 @@
 
 
 
 
-from magic_pdf.io import AbsReaderWriter
+from magic_pdf.io.AbsReaderWriter import AbsReaderWriter
+from magic_pdf.libs.commons import parse_aws_param, parse_bucket_key
+import boto3
+from loguru import logger
+from boto3.s3.transfer import TransferConfig
+from botocore.config import Config
 
 
 
 
-class DiskReaderWriter(AbsReaderWriter):
-    def __init__(self, parent_path, encoding='utf-8'):
-        self.path = parent_path
-        self.encoding = encoding
+class S3ReaderWriter(AbsReaderWriter):
+    """Read and write S3 objects as text or raw bytes through a boto3 client.
+
+    NOTE(review): ``read``/``write`` name their path parameter ``s3_path``
+    while the abstract base names it ``path``; keyword callers going through
+    the base interface would need ``s3_path=`` - confirm this is intended.
+    """
+
+    def __init__(self, s3_profile):
+        """Build the S3 client from ``s3_profile`` and keep it on ``self.client``."""
+        self.client = self._get_client(s3_profile)
 
 
-    def read(self):
-        with open(self.path, 'rb') as f:
-            return f.read()
+    def _get_client(self, s3_profile):
+        """Create a boto3 S3 client from the given profile.
+
+        ``parse_aws_param`` turns ``s3_profile`` into the access key, secret
+        key, endpoint URL and addressing style handed to ``boto3.client``.
+        The client is configured with up to 5 attempts in 'standard' retry mode.
+        """
 
 
-    def write(self, data):
-        with open(self.path, 'wb') as f:
-            f.write(data)
-            
+        ak, sk, end_point, addressing_style = parse_aws_param(s3_profile)
+        s3_client = boto3.client(
+            service_name="s3",
+            aws_access_key_id=ak,
+            aws_secret_access_key=sk,
+            endpoint_url=end_point,
+            config=Config(s3={"addressing_style": addressing_style},
+                          retries={'max_attempts': 5, 'mode': 'standard'}),
+        )
+
+        return s3_client
+    def read(self, s3_path, mode="text", encoding="utf-8"):
+        """Download the object at ``s3_path`` and return its content.
+
+        ``parse_bucket_key`` splits ``s3_path`` into bucket name and key
+        (presumably an ``s3://bucket/key`` URL - confirm against the helper).
+        Returns ``str`` decoded with ``encoding`` in 'text' mode, or the raw
+        bytes in 'binary' mode. Raises ValueError for any other mode; note the
+        object has already been fetched by the time the mode is checked.
+        """
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        res = self.client.get_object(Bucket=bucket_name, Key=bucket_key)
+        body = res["Body"].read()
+        if mode == 'text':
+            data = body.decode(encoding)  # Decode bytes to text
+        elif mode == 'binary':
+            data = body
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        return data
+
+    def write(self, data, s3_path, mode="text", encoding="utf-8"):
+        """Upload ``data`` to the object at ``s3_path``.
+
+        In 'text' mode ``data`` is encoded with ``encoding`` first; in 'binary'
+        mode it is passed to ``put_object`` unchanged. Raises ValueError for
+        any other mode, before anything is uploaded.
+        """
+        if mode == 'text':
+            body = data.encode(encoding)  # Encode text data as bytes
+        elif mode == 'binary':
+            body = data
+        else:
+            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
+        bucket_name, bucket_key = parse_bucket_key(s3_path)
+        self.client.put_object(Body=body, Bucket=bucket_name, Key=bucket_key)
+        logger.info(f"内容已写入 {s3_path} ")
+
+
+if __name__ == "__main__":
+    # Usage example: round-trips one text object and one binary object.
+    # Fill in real credentials and an endpoint before running.
+    # NOTE(review): the profile carries no addressing-style entry - confirm
+    # that parse_aws_param supplies a default for it.
+    profile = {
+        'ak': '',
+        'sk': '',
+        'endpoint': ''
+    }
+    # Create an S3ReaderWriter object
+    s3_reader_writer = S3ReaderWriter(profile)
+
+    # Write text data to S3
+    text_data = "This is some text data"
+    s3_reader_writer.write(data=text_data, s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+
+    # Read text data from S3
+    text_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test.json", mode='text')
+    logger.info(f"Read text data from S3: {text_data_read}")
+    # Write binary data to S3 (the bytes payload, not the text string above)
+    binary_data = b"This is some binary data"
+    s3_reader_writer.write(data=binary_data, s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
+
+    # Read binary data from S3
+    binary_data_read = s3_reader_writer.read(s3_path = "s3://bucket_name/ebook/test/test2.json", mode='binary')
+    logger.info(f"Read binary data from S3: {binary_data_read}")