# Helpers for uploading and downloading model files through Alibaba Cloud OSS.
import os

import oss2
class oss(object):
    """Object-storage helper for moving model files to and from Alibaba Cloud OSS.

    Instance attributes (set in ``__init__``):
        auth: the ``oss2.Auth`` built from the access key pair.
        bucket: the ``oss2.Bucket`` every transfer method goes through.
    """

    def __init__(self, access_key_id, access_key_secret, endpoint, bucket_name):
        """Create the credentials object and connect to the bucket.

        Args:
            access_key_id: access key id passed to ``oss2.Auth``.
            access_key_secret: access key secret passed to ``oss2.Auth``.
            endpoint: endpoint passed to ``oss2.Bucket``.
            bucket_name: bucket name passed to ``oss2.Bucket``.
        """
        self.auth = oss2.Auth(access_key_id, access_key_secret)
        # Connect to OSS.
        self.bucket = oss2.Bucket(self.auth, endpoint, bucket_name)

    def put_file(self, file_path, oss_path):
        """Upload a local file to the bucket.

        Args:
            file_path: path of the local file to read (opened in binary mode).
            oss_path: object key to store the file under in the bucket.

        Returns:
            True when the response status is 200, False otherwise.
        """
        with open(file_path, "rb") as f:
            put_result = self.bucket.put_object(oss_path, f)
        # A status of 200 means the upload succeeded.
        succeeded = put_result.status == 200
        # Report failures too, mirroring get_file, instead of staying silent.
        print("put success" if succeeded else "put failed")
        return succeeded

    def get_file(self, file_path, oss_path):
        """Download an object from the bucket to a local file.

        Args:
            file_path: local path (directory + file name) to save the object to.
            oss_path: key of the object inside the bucket.

        Returns:
            True when the response status is 200, False otherwise.
        """
        get_result = self.bucket.get_object_to_file(oss_path, file_path)
        succeeded = get_result.status == 200
        print("get success" if succeeded else "get failed")
        return succeeded
# Module-level OSS client; download_longfor_bert reads its ``bucket`` and
# calls its ``get_file``.
# NOTE(review): the original had an unbalanced ")" after the two key values,
# which is a syntax error. It may be the remnant of a stripped wrapper call
# (e.g. an environment lookup) — confirm how the credentials are meant to be
# supplied. The strings below are placeholders, not real credentials.
oss_server = oss(
    access_key_id="AccessKey",
    access_key_secret="AccessKeySecret",
    endpoint="EndPoint",
    bucket_name="Bucket",
)
def download_longfor_bert(pretrain_file, oss_get_path):
    """Download the files listed under an OSS prefix into a local directory.

    Iterates the objects of ``oss_server.bucket`` that match ``oss_get_path``
    (listed with ``/`` as delimiter), prints each entry, and saves every file
    entry locally under its base name.

    Args:
        pretrain_file: local directory the files are saved into. A trailing
            path separator is optional.
        oss_get_path: key prefix to list in the bucket.
    """
    import os  # local import keeps this function self-contained

    listing = oss2.ObjectIterator(
        oss_server.bucket, prefix=oss_get_path, delimiter='/'
    )
    for obj in listing:
        # is_prefix() tells whether the entry is a folder rather than a file.
        if obj.is_prefix():
            print('directory: ' + obj.key)
            continue

        print('file: ' + obj.key)
        file_name = str(obj.key).rsplit('/', 1)[-1]
        # A key ending in "/" yields an empty base name; nothing to save.
        if file_name:
            # os.path.join instead of "+" so a directory given without a
            # trailing separator is not fused with the file name.
            local_path = os.path.join(pretrain_file, file_name)
            oss_server.get_file(local_path, obj.key)
|