This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 19fd88589 Port Add stream upload (multi-part upload) (#2147)
19fd88589 is described below
commit 19fd8858991bcf8e654c221e6956ce6a8b5a86e1
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Jul 26 10:20:27 2022 -0400
Port Add stream upload (multi-part upload) (#2147)
* feat: Add stream upload (multi-part upload) (#20)
* feat: Implement multi-part upload
Co-authored-by: Raphael Taylor-Davies <[email protected]>
* chore: simplify local file implementation
* chore: Remove pin-project
* feat: make cleanup_upload() top-level
* docs: Add some docs for upload
* chore: fix linting issue
* fix: rename to put_multipart
* feat: Implement multi-part upload for GCP
* fix: Get GCS test to pass
* chore: remove more upload language
* fix: Add guard to test so we don't run with fake gcs server
* chore: small tweaks
* fix: apply suggestions from code review
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* feat: switch to quick-xml
* feat: remove throttle implementation of multipart
* fix: rename from cleanup to abort
* feat: enforce upload not readable until shutdown
* fix: ensure we close files before moving them
* chore: fix lint issue
Co-authored-by: Raphael Taylor-Davies <[email protected]>
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* fmt
* RAT multipart
* Fix build
* fix: merge issue
Co-authored-by: Will Jones <[email protected]>
Co-authored-by: Raphael Taylor-Davies <[email protected]>
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
---
object_store/Cargo.toml | 5 +-
object_store/src/aws.rs | 231 ++++++++++++++++++++++++++-
object_store/src/azure.rs | 125 ++++++++++++++-
object_store/src/gcp.rs | 340 ++++++++++++++++++++++++++++++++++++---
object_store/src/lib.rs | 105 +++++++++++-
object_store/src/local.rs | 361 +++++++++++++++++++++++++++++++++++++++---
object_store/src/memory.rs | 69 +++++++-
object_store/src/multipart.rs | 195 +++++++++++++++++++++++
object_store/src/throttle.rs | 17 ++
9 files changed, 1392 insertions(+), 56 deletions(-)
diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml
index 613b6ab2e..741539891 100644
--- a/object_store/Cargo.toml
+++ b/object_store/Cargo.toml
@@ -44,6 +44,7 @@ chrono = { version = "0.4", default-features = false,
features = ["clock"] }
futures = "0.3"
serde = { version = "1.0", default-features = false, features = ["derive"],
optional = true }
serde_json = { version = "1.0", default-features = false, optional = true }
+quick-xml = { version = "0.23.0", features = ["serialize"], optional = true }
rustls-pemfile = { version = "1.0", default-features = false, optional = true }
ring = { version = "0.16", default-features = false, features = ["std"] }
base64 = { version = "0.13", default-features = false, optional = true }
@@ -59,7 +60,7 @@ rusoto_credential = { version = "0.48.0", optional = true,
default-features = fa
rusoto_s3 = { version = "0.48.0", optional = true, default-features = false,
features = ["rustls"] }
rusoto_sts = { version = "0.48.0", optional = true, default-features = false,
features = ["rustls"] }
snafu = "0.7"
-tokio = { version = "1.18", features = ["sync", "macros", "parking_lot",
"rt-multi-thread", "time"] }
+tokio = { version = "1.18", features = ["sync", "macros", "parking_lot",
"rt-multi-thread", "time", "io-util"] }
tracing = { version = "0.1" }
reqwest = { version = "0.11", optional = true, default-features = false,
features = ["rustls-tls"] }
parking_lot = { version = "0.12" }
@@ -70,7 +71,7 @@ walkdir = "2"
[features]
azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest"]
azure_test = ["azure", "azure_core/azurite_workaround",
"azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"]
-gcp = ["serde", "serde_json", "reqwest", "reqwest/json", "reqwest/stream",
"chrono/serde", "rustls-pemfile", "base64"]
+gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json",
"reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"]
aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper",
"hyper-rustls"]
[dev-dependencies] # In alphabetical order
diff --git a/object_store/src/aws.rs b/object_store/src/aws.rs
index 7ebcc2a88..3606a3806 100644
--- a/object_store/src/aws.rs
+++ b/object_store/src/aws.rs
@@ -16,7 +16,23 @@
// under the License.
//! An object store implementation for S3
+//!
+//! ## Multi-part uploads
+//!
+//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart]
method.
+//! Data passed to the writer is automatically buffered to meet the minimum
size
+//! requirements for a part. Multiple parts are uploaded concurrently.
+//!
+//! If the writer fails for any reason, you may have parts uploaded to AWS but
not
+//! used that you may be charged for. Use the [ObjectStore::abort_multipart]
method
+//! to abort the upload and drop those unneeded parts. In addition, you may
wish to
+//! consider implementing [automatic cleanup] of unused parts that are older
than one
+//! week.
+//!
+//! [automatic cleanup]:
https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/
+use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl,
UploadPart};
use crate::util::format_http_range;
+use crate::MultipartId;
use crate::{
collect_bytes,
path::{Path, DELIMITER},
@@ -26,6 +42,7 @@ use crate::{
use async_trait::async_trait;
use bytes::Bytes;
use chrono::{DateTime, Utc};
+use futures::future::BoxFuture;
use futures::{
stream::{self, BoxStream},
Future, Stream, StreamExt, TryStreamExt,
@@ -36,10 +53,12 @@ use rusoto_credential::{InstanceMetadataProvider,
StaticProvider};
use rusoto_s3::S3;
use rusoto_sts::WebIdentityProvider;
use snafu::{OptionExt, ResultExt, Snafu};
+use std::io;
use std::ops::Range;
use std::{
convert::TryFrom, fmt, num::NonZeroUsize, ops::Deref, sync::Arc,
time::Duration,
};
+use tokio::io::AsyncWrite;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::{debug, warn};
@@ -129,6 +148,32 @@ enum Error {
path: String,
},
+ #[snafu(display(
+ "Unable to upload data. Bucket: {}, Location: {}, Error: {} ({:?})",
+ bucket,
+ path,
+ source,
+ source,
+ ))]
+ UnableToUploadData {
+ source:
rusoto_core::RusotoError<rusoto_s3::CreateMultipartUploadError>,
+ bucket: String,
+ path: String,
+ },
+
+ #[snafu(display(
+ "Unable to cleanup multipart data. Bucket: {}, Location: {}, Error: {}
({:?})",
+ bucket,
+ path,
+ source,
+ source,
+ ))]
+ UnableToCleanupMultipartData {
+ source: rusoto_core::RusotoError<rusoto_s3::AbortMultipartUploadError>,
+ bucket: String,
+ path: String,
+ },
+
#[snafu(display(
"Unable to list data. Bucket: {}, Error: {} ({:?})",
bucket,
@@ -272,6 +317,71 @@ impl ObjectStore for AmazonS3 {
Ok(())
}
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ let bucket_name = self.bucket_name.clone();
+
+ let request_factory = move || rusoto_s3::CreateMultipartUploadRequest {
+ bucket: bucket_name.clone(),
+ key: location.to_string(),
+ ..Default::default()
+ };
+
+ let s3 = self.client().await;
+
+ let data = s3_request(move || {
+ let (s3, request_factory) = (s3.clone(), request_factory.clone());
+
+ async move { s3.create_multipart_upload(request_factory()).await }
+ })
+ .await
+ .context(UnableToUploadDataSnafu {
+ bucket: &self.bucket_name,
+ path: location.as_ref(),
+ })?;
+
+ let upload_id = data.upload_id.unwrap();
+
+ let inner = S3MultiPartUpload {
+ upload_id: upload_id.clone(),
+ bucket: self.bucket_name.clone(),
+ key: location.to_string(),
+ client_unrestricted: self.client_unrestricted.clone(),
+ connection_semaphore: Arc::clone(&self.connection_semaphore),
+ };
+
+ Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
+ }
+
+ async fn abort_multipart(
+ &self,
+ location: &Path,
+ multipart_id: &MultipartId,
+ ) -> Result<()> {
+ let request_factory = move || rusoto_s3::AbortMultipartUploadRequest {
+ bucket: self.bucket_name.clone(),
+ key: location.to_string(),
+ upload_id: multipart_id.to_string(),
+ ..Default::default()
+ };
+
+ let s3 = self.client().await;
+ s3_request(move || {
+ let (s3, request_factory) = (s3.clone(), request_factory);
+
+ async move { s3.abort_multipart_upload(request_factory()).await }
+ })
+ .await
+ .context(UnableToCleanupMultipartDataSnafu {
+ bucket: &self.bucket_name,
+ path: location.as_ref(),
+ })?;
+
+ Ok(())
+ }
+
async fn get(&self, location: &Path) -> Result<GetResult> {
Ok(GetResult::Stream(
self.get_object(location, None).await?.boxed(),
@@ -821,13 +931,131 @@ impl Error {
}
}
+struct S3MultiPartUpload {
+ bucket: String,
+ key: String,
+ upload_id: String,
+ client_unrestricted: rusoto_s3::S3Client,
+ connection_semaphore: Arc<Semaphore>,
+}
+
+impl CloudMultiPartUploadImpl for S3MultiPartUpload {
+ fn put_multipart_part(
+ &self,
+ buf: Vec<u8>,
+ part_idx: usize,
+ ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
+ // Get values to move into future; we don't want a reference to Self
+ let bucket = self.bucket.clone();
+ let key = self.key.clone();
+ let upload_id = self.upload_id.clone();
+ let content_length = buf.len();
+
+ let request_factory = move || rusoto_s3::UploadPartRequest {
+ bucket,
+ key,
+ upload_id,
+ // AWS part number is 1-indexed
+ part_number: (part_idx + 1).try_into().unwrap(),
+ content_length: Some(content_length.try_into().unwrap()),
+ body: Some(buf.into()),
+ ..Default::default()
+ };
+
+ let s3 = self.client_unrestricted.clone();
+ let connection_semaphore = Arc::clone(&self.connection_semaphore);
+
+ Box::pin(async move {
+ let _permit = connection_semaphore
+ .acquire_owned()
+ .await
+ .expect("semaphore shouldn't be closed yet");
+
+ let response = s3_request(move || {
+ let (s3, request_factory) = (s3.clone(),
request_factory.clone());
+ async move { s3.upload_part(request_factory()).await }
+ })
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ Ok((
+ part_idx,
+ UploadPart {
+ content_id: response.e_tag.unwrap(),
+ },
+ ))
+ })
+ }
+
+ fn complete(
+ &self,
+ completed_parts: Vec<Option<UploadPart>>,
+ ) -> BoxFuture<'static, Result<(), io::Error>> {
+ let parts =
+ completed_parts
+ .into_iter()
+ .enumerate()
+ .map(|(part_number, maybe_part)| match maybe_part {
+ Some(part) => {
+ Ok(rusoto_s3::CompletedPart {
+ e_tag: Some(part.content_id),
+ part_number: Some((part_number +
1).try_into().map_err(
+ |err| io::Error::new(io::ErrorKind::Other,
err),
+ )?),
+ })
+ }
+ None => Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("Missing information for upload part {:?}",
part_number),
+ )),
+ });
+
+ // Get values to move into future; we don't want a reference to Self
+ let bucket = self.bucket.clone();
+ let key = self.key.clone();
+ let upload_id = self.upload_id.clone();
+
+ let request_factory = move || -> Result<_, io::Error> {
+ Ok(rusoto_s3::CompleteMultipartUploadRequest {
+ bucket,
+ key,
+ upload_id,
+ multipart_upload: Some(rusoto_s3::CompletedMultipartUpload {
+ parts: Some(parts.collect::<Result<_, io::Error>>()?),
+ }),
+ ..Default::default()
+ })
+ };
+
+ let s3 = self.client_unrestricted.clone();
+ let connection_semaphore = Arc::clone(&self.connection_semaphore);
+
+ Box::pin(async move {
+ let _permit = connection_semaphore
+ .acquire_owned()
+ .await
+ .expect("semaphore shouldn't be closed yet");
+
+ s3_request(move || {
+ let (s3, request_factory) = (s3.clone(),
request_factory.clone());
+
+ async move {
s3.complete_multipart_upload(request_factory()?).await }
+ })
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ Ok(())
+ })
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::{
tests::{
get_nonexistent_object, list_uses_directories_correctly,
list_with_delimiter,
- put_get_delete_list, rename_and_copy,
+ put_get_delete_list, rename_and_copy, stream_get,
},
Error as ObjectStoreError, ObjectStore,
};
@@ -943,6 +1171,7 @@ mod tests {
check_credentials(list_uses_directories_correctly(&integration).await).unwrap();
check_credentials(list_with_delimiter(&integration).await).unwrap();
check_credentials(rename_and_copy(&integration).await).unwrap();
+ check_credentials(stream_get(&integration).await).unwrap();
}
#[tokio::test]
diff --git a/object_store/src/azure.rs b/object_store/src/azure.rs
index 75dafef86..25f311a9a 100644
--- a/object_store/src/azure.rs
+++ b/object_store/src/azure.rs
@@ -16,10 +16,21 @@
// under the License.
//! An object store implementation for Azure blob storage
+//!
+//! ## Streaming uploads
+//!
+//! [ObjectStore::put_multipart] will upload data in blocks and write a blob
from those
+//! blocks. Data is buffered internally to make blocks of at least 5MB and
blocks
+//! are uploaded concurrently.
+//!
+//! [ObjectStore::abort_multipart] is a no-op, since Azure Blob Store doesn't
provide
+//! a way to drop old blocks. Instead unused blocks are automatically cleaned
up
+//! after 7 days.
use crate::{
+ multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
path::{Path, DELIMITER},
util::format_prefix,
- GetResult, ListResult, ObjectMeta, ObjectStore, Result,
+ GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
};
use async_trait::async_trait;
use azure_core::{prelude::*, HttpClient};
@@ -32,12 +43,15 @@ use azure_storage_blobs::{
};
use bytes::Bytes;
use futures::{
+ future::BoxFuture,
stream::{self, BoxStream},
StreamExt, TryStreamExt,
};
use snafu::{ResultExt, Snafu};
use std::collections::BTreeSet;
+use std::io;
use std::{convert::TryInto, sync::Arc};
+use tokio::io::AsyncWrite;
use url::Url;
/// A specialized `Error` for Azure object store-related errors
@@ -232,6 +246,27 @@ impl ObjectStore for MicrosoftAzure {
Ok(())
}
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ let inner = AzureMultiPartUpload {
+ container_client: Arc::clone(&self.container_client),
+ location: location.to_owned(),
+ };
+ Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8))))
+ }
+
+ async fn abort_multipart(
+ &self,
+ _location: &Path,
+ _multipart_id: &MultipartId,
+ ) -> Result<()> {
+ // There is no way to drop blocks that have been uploaded. Instead,
they simply
+ // expire in 7 days.
+ Ok(())
+ }
+
async fn get(&self, location: &Path) -> Result<GetResult> {
let blob = self
.container_client
@@ -604,6 +639,94 @@ pub fn new_azure(
})
}
+// Relevant docs:
https://azure.github.io/Storage/docs/application-and-user-data/basics/azure-blob-storage-upload-apis/
+// In Azure Blob Store, parts are "blocks"
+// put_multipart_part -> PUT block
+// complete -> PUT block list
+// abort -> No equivalent; blocks are simply dropped after 7 days
+#[derive(Debug, Clone)]
+struct AzureMultiPartUpload {
+ container_client: Arc<ContainerClient>,
+ location: Path,
+}
+
+impl AzureMultiPartUpload {
+ /// Gets the block id corresponding to the part index.
+ ///
+ /// In Azure, the user determines what id each block has. They must be
+ /// unique within an upload and of consistent length.
+ fn get_block_id(&self, part_idx: usize) -> String {
+ format!("{:20}", part_idx)
+ }
+}
+
+impl CloudMultiPartUploadImpl for AzureMultiPartUpload {
+ fn put_multipart_part(
+ &self,
+ buf: Vec<u8>,
+ part_idx: usize,
+ ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
+ let client = Arc::clone(&self.container_client);
+ let location = self.location.clone();
+ let block_id = self.get_block_id(part_idx);
+
+ Box::pin(async move {
+ client
+ .as_blob_client(location.as_ref())
+ .put_block(block_id.clone(), buf)
+ .execute()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ Ok((
+ part_idx,
+ UploadPart {
+ content_id: block_id,
+ },
+ ))
+ })
+ }
+
+ fn complete(
+ &self,
+ completed_parts: Vec<Option<UploadPart>>,
+ ) -> BoxFuture<'static, Result<(), io::Error>> {
+ let parts =
+ completed_parts
+ .into_iter()
+ .enumerate()
+ .map(|(part_number, maybe_part)| match maybe_part {
+ Some(part) => {
+
Ok(azure_storage_blobs::blob::BlobBlockType::Uncommitted(
+ azure_storage_blobs::BlockId::new(part.content_id),
+ ))
+ }
+ None => Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("Missing information for upload part {:?}",
part_number),
+ )),
+ });
+
+ let client = Arc::clone(&self.container_client);
+ let location = self.location.clone();
+
+ Box::pin(async move {
+ let block_list = azure_storage_blobs::blob::BlockList {
+ blocks: parts.collect::<Result<_, io::Error>>()?,
+ };
+
+ client
+ .as_blob_client(location.as_ref())
+ .put_block_list(&block_list)
+ .execute()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ Ok(())
+ })
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::azure::new_azure;
diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs
index e836caba7..d740625bd 100644
--- a/object_store/src/gcp.rs
+++ b/object_store/src/gcp.rs
@@ -16,27 +16,44 @@
// under the License.
//! An object store implementation for Google Cloud Storage
+//!
+//! ## Multi-part uploads
+//!
+//! [Multi-part
uploads](https://cloud.google.com/storage/docs/multipart-uploads)
+//! can be initiated with the [ObjectStore::put_multipart] method.
+//! Data passed to the writer is automatically buffered to meet the minimum
size
+//! requirements for a part. Multiple parts are uploaded concurrently.
+//!
+//! If the writer fails for any reason, you may have parts uploaded to GCS but
not
+//! used that you may be charged for. Use the [ObjectStore::abort_multipart]
method
+//! to abort the upload and drop those unneeded parts. In addition, you may
wish to
+//! consider implementing automatic clean up of unused parts that are older
than one
+//! week.
use std::collections::BTreeSet;
use std::fs::File;
-use std::io::BufReader;
+use std::io::{self, BufReader};
use std::ops::Range;
+use std::sync::Arc;
use async_trait::async_trait;
-use bytes::Bytes;
+use bytes::{Buf, Bytes};
use chrono::{DateTime, Utc};
+use futures::future::BoxFuture;
use futures::{stream::BoxStream, StreamExt, TryStreamExt};
use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use reqwest::header::RANGE;
use reqwest::{header, Client, Method, Response, StatusCode};
use snafu::{ResultExt, Snafu};
+use tokio::io::AsyncWrite;
+use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl,
UploadPart};
use crate::util::format_http_range;
use crate::{
oauth::OAuthProvider,
path::{Path, DELIMITER},
token::TokenCache,
util::format_prefix,
- GetResult, ListResult, ObjectMeta, ObjectStore, Result,
+ GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
};
#[derive(Debug, Snafu)]
@@ -47,6 +64,14 @@ enum Error {
#[snafu(display("Unable to decode service account file: {}", source))]
DecodeCredentials { source: serde_json::Error },
+ #[snafu(display("Got invalid XML response for {} {}: {}", method, url,
source))]
+ InvalidXMLResponse {
+ source: quick_xml::de::DeError,
+ method: String,
+ url: String,
+ data: Bytes,
+ },
+
#[snafu(display("Error performing list request: {}", source))]
ListRequest { source: reqwest::Error },
@@ -139,9 +164,42 @@ struct Object {
updated: DateTime<Utc>,
}
+#[derive(serde::Deserialize, Debug)]
+#[serde(rename_all = "PascalCase")]
+struct InitiateMultipartUploadResult {
+ upload_id: String,
+}
+
+#[derive(serde::Serialize, Debug)]
+#[serde(rename_all = "PascalCase", rename(serialize = "Part"))]
+struct MultipartPart {
+ #[serde(rename = "$unflatten=PartNumber")]
+ part_number: usize,
+ #[serde(rename = "$unflatten=ETag")]
+ e_tag: String,
+}
+
+#[derive(serde::Serialize, Debug)]
+#[serde(rename_all = "PascalCase")]
+struct CompleteMultipartUpload {
+ #[serde(rename = "Part", default)]
+ parts: Vec<MultipartPart>,
+}
+
/// Configuration for connecting to [Google Cloud
Storage](https://cloud.google.com/storage/).
#[derive(Debug)]
pub struct GoogleCloudStorage {
+ client: Arc<GoogleCloudStorageClient>,
+}
+
+impl std::fmt::Display for GoogleCloudStorage {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "GoogleCloudStorage({})", self.client.bucket_name)
+ }
+}
+
+#[derive(Debug)]
+struct GoogleCloudStorageClient {
client: Client,
base_url: String,
@@ -155,13 +213,7 @@ pub struct GoogleCloudStorage {
max_list_results: Option<String>,
}
-impl std::fmt::Display for GoogleCloudStorage {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "GoogleCloudStorage({})", self.bucket_name)
- }
-}
-
-impl GoogleCloudStorage {
+impl GoogleCloudStorageClient {
async fn get_token(&self) -> Result<String> {
if let Some(oauth_provider) = &self.oauth_provider {
Ok(self
@@ -243,6 +295,61 @@ impl GoogleCloudStorage {
Ok(())
}
+ /// Initiate a multi-part upload
<https://cloud.google.com/storage/docs/xml-api/post-object-multipart>
+ async fn multipart_initiate(&self, path: &Path) -> Result<MultipartId> {
+ let token = self.get_token().await?;
+ let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded,
path);
+
+ let response = self
+ .client
+ .request(Method::POST, &url)
+ .bearer_auth(token)
+ .header(header::CONTENT_TYPE, "application/octet-stream")
+ .header(header::CONTENT_LENGTH, "0")
+ .query(&[("uploads", "")])
+ .send()
+ .await
+ .context(PutRequestSnafu)?
+ .error_for_status()
+ .context(PutRequestSnafu)?;
+
+ let data = response.bytes().await.context(PutRequestSnafu)?;
+ let result: InitiateMultipartUploadResult = quick_xml::de::from_reader(
+ data.as_ref().reader(),
+ )
+ .context(InvalidXMLResponseSnafu {
+ method: "POST".to_string(),
+ url,
+ data,
+ })?;
+
+ Ok(result.upload_id)
+ }
+
+ /// Cleanup unused parts
<https://cloud.google.com/storage/docs/xml-api/delete-multipart>
+ async fn multipart_cleanup(
+ &self,
+ path: &str,
+ multipart_id: &MultipartId,
+ ) -> Result<()> {
+ let token = self.get_token().await?;
+ let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded,
path);
+
+ self.client
+ .request(Method::DELETE, &url)
+ .bearer_auth(token)
+ .header(header::CONTENT_TYPE, "application/octet-stream")
+ .header(header::CONTENT_LENGTH, "0")
+ .query(&[("uploadId", multipart_id)])
+ .send()
+ .await
+ .context(PutRequestSnafu)?
+ .error_for_status()
+ .context(PutRequestSnafu)?;
+
+ Ok(())
+ }
+
/// Perform a delete request
<https://cloud.google.com/storage/docs/json_api/v1/objects/delete>
async fn delete_request(&self, path: &Path) -> Result<()> {
let token = self.get_token().await?;
@@ -401,14 +508,184 @@ impl GoogleCloudStorage {
}
}
+fn reqwest_error_as_io(err: reqwest::Error) -> io::Error {
+ if err.is_builder() || err.is_request() {
+ io::Error::new(io::ErrorKind::InvalidInput, err)
+ } else if err.is_status() {
+ match err.status() {
+ Some(StatusCode::NOT_FOUND) =>
io::Error::new(io::ErrorKind::NotFound, err),
+ Some(StatusCode::BAD_REQUEST) => {
+ io::Error::new(io::ErrorKind::InvalidInput, err)
+ }
+ Some(_) => io::Error::new(io::ErrorKind::Other, err),
+ None => io::Error::new(io::ErrorKind::Other, err),
+ }
+ } else if err.is_timeout() {
+ io::Error::new(io::ErrorKind::TimedOut, err)
+ } else if err.is_connect() {
+ io::Error::new(io::ErrorKind::NotConnected, err)
+ } else {
+ io::Error::new(io::ErrorKind::Other, err)
+ }
+}
+
+struct GCSMultipartUpload {
+ client: Arc<GoogleCloudStorageClient>,
+ encoded_path: String,
+ multipart_id: MultipartId,
+}
+
+impl CloudMultiPartUploadImpl for GCSMultipartUpload {
+ /// Upload an object part
<https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
+ fn put_multipart_part(
+ &self,
+ buf: Vec<u8>,
+ part_idx: usize,
+ ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
+ let upload_id = self.multipart_id.clone();
+ let url = format!(
+ "{}/{}/{}",
+ self.client.base_url, self.client.bucket_name_encoded,
self.encoded_path
+ );
+ let client = Arc::clone(&self.client);
+
+ Box::pin(async move {
+ let token = client
+ .get_token()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ let response = client
+ .client
+ .request(Method::PUT, &url)
+ .bearer_auth(token)
+ .query(&[
+ ("partNumber", format!("{}", part_idx + 1)),
+ ("uploadId", upload_id),
+ ])
+ .header(header::CONTENT_TYPE, "application/octet-stream")
+ .header(header::CONTENT_LENGTH, format!("{}", buf.len()))
+ .body(buf)
+ .send()
+ .await
+ .map_err(reqwest_error_as_io)?
+ .error_for_status()
+ .map_err(reqwest_error_as_io)?;
+
+ let content_id = response
+ .headers()
+ .get("ETag")
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "response headers missing ETag",
+ )
+ })?
+ .to_str()
+ .map_err(|err| io::Error::new(io::ErrorKind::InvalidData,
err))?
+ .to_string();
+
+ Ok((part_idx, UploadPart { content_id }))
+ })
+ }
+
+ /// Complete a multipart upload
<https://cloud.google.com/storage/docs/xml-api/post-object-complete>
+ fn complete(
+ &self,
+ completed_parts: Vec<Option<UploadPart>>,
+ ) -> BoxFuture<'static, Result<(), io::Error>> {
+ let client = Arc::clone(&self.client);
+ let upload_id = self.multipart_id.clone();
+ let url = format!(
+ "{}/{}/{}",
+ self.client.base_url, self.client.bucket_name_encoded,
self.encoded_path
+ );
+
+ Box::pin(async move {
+ let parts: Vec<MultipartPart> = completed_parts
+ .into_iter()
+ .enumerate()
+ .map(|(part_number, maybe_part)| match maybe_part {
+ Some(part) => Ok(MultipartPart {
+ e_tag: part.content_id,
+ part_number: part_number + 1,
+ }),
+ None => Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("Missing information for upload part {:?}",
part_number),
+ )),
+ })
+ .collect::<Result<Vec<MultipartPart>, io::Error>>()?;
+
+ let token = client
+ .get_token()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ let upload_info = CompleteMultipartUpload { parts };
+
+ let data = quick_xml::se::to_string(&upload_info)
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
+            // We cannot disable the escaping that transforms "/" to "&quote;" :(
+ // https://github.com/tafia/quick-xml/issues/362
+ // https://github.com/tafia/quick-xml/issues/350
+ .replace(""", "\"");
+
+ client
+ .client
+ .request(Method::POST, &url)
+ .bearer_auth(token)
+ .query(&[("uploadId", upload_id)])
+ .body(data)
+ .send()
+ .await
+ .map_err(reqwest_error_as_io)?
+ .error_for_status()
+ .map_err(reqwest_error_as_io)?;
+
+ Ok(())
+ })
+ }
+}
+
#[async_trait]
impl ObjectStore for GoogleCloudStorage {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
- self.put_request(location, bytes).await
+ self.client.put_request(location, bytes).await
+ }
+
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ let upload_id = self.client.multipart_initiate(location).await?;
+
+ let encoded_path =
+ percent_encode(location.to_string().as_bytes(),
NON_ALPHANUMERIC).to_string();
+
+ let inner = GCSMultipartUpload {
+ client: Arc::clone(&self.client),
+ encoded_path,
+ multipart_id: upload_id.clone(),
+ };
+
+ Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
+ }
+
+ async fn abort_multipart(
+ &self,
+ location: &Path,
+ multipart_id: &MultipartId,
+ ) -> Result<()> {
+ self.client
+ .multipart_cleanup(location.as_ref(), multipart_id)
+ .await?;
+
+ Ok(())
}
async fn get(&self, location: &Path) -> Result<GetResult> {
- let response = self.get_request(location, None, false).await?;
+ let response = self.client.get_request(location, None, false).await?;
let stream = response
.bytes_stream()
.map_err(|source| crate::Error::Generic {
@@ -421,14 +698,17 @@ impl ObjectStore for GoogleCloudStorage {
}
async fn get_range(&self, location: &Path, range: Range<usize>) ->
Result<Bytes> {
- let response = self.get_request(location, Some(range), false).await?;
+ let response = self
+ .client
+ .get_request(location, Some(range), false)
+ .await?;
Ok(response.bytes().await.context(GetRequestSnafu {
path: location.as_ref(),
})?)
}
async fn head(&self, location: &Path) -> Result<ObjectMeta> {
- let response = self.get_request(location, None, true).await?;
+ let response = self.client.get_request(location, None, true).await?;
let object = response.json().await.context(GetRequestSnafu {
path: location.as_ref(),
})?;
@@ -436,7 +716,7 @@ impl ObjectStore for GoogleCloudStorage {
}
async fn delete(&self, location: &Path) -> Result<()> {
- self.delete_request(location).await
+ self.client.delete_request(location).await
}
async fn list(
@@ -444,6 +724,7 @@ impl ObjectStore for GoogleCloudStorage {
prefix: Option<&Path>,
) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
let stream = self
+ .client
.list_paginated(prefix, false)?
.map_ok(|r| {
futures::stream::iter(
@@ -457,7 +738,7 @@ impl ObjectStore for GoogleCloudStorage {
}
async fn list_with_delimiter(&self, prefix: Option<&Path>) ->
Result<ListResult> {
- let mut stream = self.list_paginated(prefix, true)?;
+ let mut stream = self.client.list_paginated(prefix, true)?;
let mut common_prefixes = BTreeSet::new();
let mut objects = Vec::new();
@@ -482,11 +763,11 @@ impl ObjectStore for GoogleCloudStorage {
}
async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
- self.copy_request(from, to, false).await
+ self.client.copy_request(from, to, false).await
}
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
- self.copy_request(from, to, true).await
+ self.client.copy_request(from, to, true).await
}
}
@@ -537,13 +818,15 @@ pub fn new_gcs_with_client(
// environment variables. Set the environment variable explicitly so
// that we can optionally accept command line arguments instead.
Ok(GoogleCloudStorage {
- client,
- base_url: credentials.gcs_base_url,
- oauth_provider,
- token_cache: Default::default(),
- bucket_name,
- bucket_name_encoded: encoded_bucket_name,
- max_list_results: None,
+ client: Arc::new(GoogleCloudStorageClient {
+ client,
+ base_url: credentials.gcs_base_url,
+ oauth_provider,
+ token_cache: Default::default(),
+ bucket_name,
+ bucket_name_encoded: encoded_bucket_name,
+ max_list_results: None,
+ }),
})
}
@@ -568,7 +851,7 @@ mod test {
use crate::{
tests::{
get_nonexistent_object, list_uses_directories_correctly,
list_with_delimiter,
- put_get_delete_list, rename_and_copy,
+ put_get_delete_list, rename_and_copy, stream_get,
},
Error as ObjectStoreError, ObjectStore,
};
@@ -648,6 +931,11 @@ mod test {
list_uses_directories_correctly(&integration).await.unwrap();
list_with_delimiter(&integration).await.unwrap();
rename_and_copy(&integration).await.unwrap();
+ if integration.client.base_url == default_gcs_base_url() {
+ // Fake GCS server does not yet implement XML Multipart uploads
+ // https://github.com/fsouza/fake-gcs-server/issues/852
+ stream_get(&integration).await.unwrap();
+ }
}
#[tokio::test]
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index 2dc65069a..54d28273f 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -30,7 +30,7 @@
//!
//! This crate provides APIs for interacting with object storage services.
//!
-//! It currently supports PUT, GET, DELETE, HEAD and list for:
+//! It currently supports PUT (single or chunked/concurrent), GET, DELETE,
HEAD and list for:
//!
//! * [Google Cloud Storage](https://cloud.google.com/storage/)
//! * [Amazon S3](https://aws.amazon.com/s3/)
@@ -56,6 +56,8 @@ mod oauth;
#[cfg(feature = "gcp")]
mod token;
+#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
+mod multipart;
mod util;
use crate::path::Path;
@@ -68,16 +70,45 @@ use snafu::Snafu;
use std::fmt::{Debug, Formatter};
use std::io::{Read, Seek, SeekFrom};
use std::ops::Range;
+use tokio::io::AsyncWrite;
/// An alias for a dynamically dispatched object store implementation.
pub type DynObjectStore = dyn ObjectStore;
+/// Id type for multi-part uploads.
+pub type MultipartId = String;
+
/// Universal API to multiple object store services.
#[async_trait]
pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static {
/// Save the provided bytes to the specified location.
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()>;
+ /// Get a multi-part upload that allows writing data in chunks
+ ///
+ /// Most cloud-based uploads will buffer and upload parts in parallel.
+ ///
+ /// To complete the upload, [AsyncWrite::poll_shutdown] must be called
+ /// to completion.
+ ///
+ /// For some object stores (S3, GCS, and local in particular), if the
+ /// writer fails or panics, you must call [ObjectStore::abort_multipart]
+ /// to clean up partially written data.
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)>;
+
+ /// Cleanup an aborted upload.
+ ///
+ /// See documentation for individual stores for exact behavior, as
capabilities
+ /// vary by object store.
+ async fn abort_multipart(
+ &self,
+ location: &Path,
+ multipart_id: &MultipartId,
+ ) -> Result<()>;
+
/// Return the bytes that are stored at the specified location.
async fn get(&self, location: &Path) -> Result<GetResult>;
@@ -330,6 +361,7 @@ mod test_util {
mod tests {
use super::*;
use crate::test_util::flatten_list_stream;
+ use tokio::io::AsyncWriteExt;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Result<T, E = Error> = std::result::Result<T, E>;
@@ -497,6 +529,77 @@ mod tests {
Ok(())
}
+ fn get_vec_of_bytes(chunk_length: usize, num_chunks: usize) -> Vec<Bytes> {
+
std::iter::repeat(Bytes::from_iter(std::iter::repeat(b'x').take(chunk_length)))
+ .take(num_chunks)
+ .collect()
+ }
+
+ pub(crate) async fn stream_get(storage: &DynObjectStore) -> Result<()> {
+ let location = Path::from("test_dir/test_upload_file.txt");
+
+ // Can write to storage
+ let data = get_vec_of_bytes(5_000_000, 10);
+ let bytes_expected = data.concat();
+ let (_, mut writer) = storage.put_multipart(&location).await?;
+ for chunk in &data {
+ writer.write_all(chunk).await?;
+ }
+
+ // Object should not yet exist in store
+ let meta_res = storage.head(&location).await;
+ assert!(meta_res.is_err());
+ assert!(matches!(
+ meta_res.unwrap_err(),
+ crate::Error::NotFound { .. }
+ ));
+
+ writer.shutdown().await?;
+ let bytes_written = storage.get(&location).await?.bytes().await?;
+ assert_eq!(bytes_expected, bytes_written);
+
+ // Can overwrite some storage
+ let data = get_vec_of_bytes(5_000, 5);
+ let bytes_expected = data.concat();
+ let (_, mut writer) = storage.put_multipart(&location).await?;
+ for chunk in &data {
+ writer.write_all(chunk).await?;
+ }
+ writer.shutdown().await?;
+ let bytes_written = storage.get(&location).await?.bytes().await?;
+ assert_eq!(bytes_expected, bytes_written);
+
+ // We can abort an empty write
+ let location = Path::from("test_dir/test_abort_upload.txt");
+ let (upload_id, writer) = storage.put_multipart(&location).await?;
+ drop(writer);
+ storage.abort_multipart(&location, &upload_id).await?;
+ let get_res = storage.get(&location).await;
+ assert!(get_res.is_err());
+ assert!(matches!(
+ get_res.unwrap_err(),
+ crate::Error::NotFound { .. }
+ ));
+
+ // We can abort an in-progress write
+ let (upload_id, mut writer) = storage.put_multipart(&location).await?;
+ if let Some(chunk) = data.get(0) {
+ writer.write_all(chunk).await?;
+ let _ = writer.write(chunk).await?;
+ }
+ drop(writer);
+
+ storage.abort_multipart(&location, &upload_id).await?;
+ let get_res = storage.get(&location).await;
+ assert!(get_res.is_err());
+ assert!(matches!(
+ get_res.unwrap_err(),
+ crate::Error::NotFound { .. }
+ ));
+
+ Ok(())
+ }
+
pub(crate) async fn list_uses_directories_correctly(
storage: &DynObjectStore,
) -> Result<()> {
diff --git a/object_store/src/local.rs b/object_store/src/local.rs
index 8a9462eba..798edef6f 100644
--- a/object_store/src/local.rs
+++ b/object_store/src/local.rs
@@ -19,18 +19,23 @@
use crate::{
maybe_spawn_blocking,
path::{filesystem_path_to_url, Path},
- GetResult, ListResult, ObjectMeta, ObjectStore, Result,
+ GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
};
use async_trait::async_trait;
use bytes::Bytes;
+use futures::future::BoxFuture;
+use futures::FutureExt;
use futures::{stream::BoxStream, StreamExt};
use snafu::{ensure, OptionExt, ResultExt, Snafu};
-use std::collections::VecDeque;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::ops::Range;
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::Poll;
use std::{collections::BTreeSet, convert::TryFrom, io};
+use std::{collections::VecDeque, path::PathBuf};
+use tokio::io::AsyncWrite;
use url::Url;
use walkdir::{DirEntry, WalkDir};
@@ -233,24 +238,7 @@ impl ObjectStore for LocalFileSystem {
let path = self.config.path_to_filesystem(location)?;
maybe_spawn_blocking(move || {
- let mut file = match File::create(&path) {
- Ok(f) => f,
- Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
- let parent = path
- .parent()
- .context(UnableToCreateFileSnafu { path: &path, err
})?;
- std::fs::create_dir_all(&parent)
- .context(UnableToCreateDirSnafu { path: parent })?;
-
- match File::create(&path) {
- Ok(f) => f,
- Err(err) => {
- return Err(Error::UnableToCreateFile { path, err
}.into())
- }
- }
- }
- Err(err) => return Err(Error::UnableToCreateFile { path, err
}.into()),
- };
+ let mut file = open_writable_file(&path)?;
file.write_all(&bytes)
.context(UnableToCopyDataToFileSnafu)?;
@@ -260,6 +248,53 @@ impl ObjectStore for LocalFileSystem {
.await
}
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ let dest = self.config.path_to_filesystem(location)?;
+
+ // Generate an id in case of concurrent writes
+ let mut multipart_id = 1;
+
+ // Will write to a temporary path
+ let staging_path = loop {
+ let staging_path = get_upload_stage_path(&dest,
&multipart_id.to_string());
+
+ match std::fs::metadata(&staging_path) {
+ Err(err) if err.kind() == io::ErrorKind::NotFound => break
staging_path,
+ Err(err) => {
+ return Err(Error::UnableToCopyDataToFile { source: err
}.into())
+ }
+ Ok(_) => multipart_id += 1,
+ }
+ };
+ let multipart_id = multipart_id.to_string();
+
+ let file = open_writable_file(&staging_path)?;
+
+ Ok((
+ multipart_id.clone(),
+ Box::new(LocalUpload::new(dest, multipart_id, Arc::new(file))),
+ ))
+ }
+
+ async fn abort_multipart(
+ &self,
+ location: &Path,
+ multipart_id: &MultipartId,
+ ) -> Result<()> {
+ let dest = self.config.path_to_filesystem(location)?;
+ let staging_path: PathBuf = get_upload_stage_path(&dest, multipart_id);
+
+ maybe_spawn_blocking(move || {
+ std::fs::remove_file(&staging_path)
+ .context(UnableToDeleteFileSnafu { path: staging_path })?;
+ Ok(())
+ })
+ .await
+ }
+
async fn get(&self, location: &Path) -> Result<GetResult> {
let path = self.config.path_to_filesystem(location)?;
maybe_spawn_blocking(move || {
@@ -343,7 +378,12 @@ impl ObjectStore for LocalFileSystem {
Err(e) => Some(Err(e)),
Ok(None) => None,
Ok(entry @ Some(_)) => entry
- .filter(|dir_entry| dir_entry.file_type().is_file())
+ .filter(|dir_entry| {
+ dir_entry.file_type().is_file()
+ // Ignore file names with # in them, since they
might be in-progress uploads.
+ // They would be rejected anyways by
filesystem_to_path below.
+ &&
!dir_entry.file_name().to_string_lossy().contains('#')
+ })
.map(|entry| {
let location =
config.filesystem_to_path(entry.path())?;
convert_entry(entry, location)
@@ -400,6 +440,13 @@ impl ObjectStore for LocalFileSystem {
for entry_res in walkdir.into_iter().map(convert_walkdir_result) {
if let Some(entry) = entry_res? {
+ if entry.file_type().is_file()
+ // Ignore file names with # in them, since they might
be in-progress uploads.
+ // They would be rejected anyways by
filesystem_to_path below.
+ && entry.file_name().to_string_lossy().contains('#')
+ {
+ continue;
+ }
let is_directory = entry.file_type().is_dir();
let entry_location =
config.filesystem_to_path(entry.path())?;
@@ -475,6 +522,216 @@ impl ObjectStore for LocalFileSystem {
}
}
+fn get_upload_stage_path(dest: &std::path::Path, multipart_id: &MultipartId)
-> PathBuf {
+ let mut staging_path = dest.as_os_str().to_owned();
+ staging_path.push(format!("#{}", multipart_id));
+ staging_path.into()
+}
+
+enum LocalUploadState {
+ /// Upload is ready to send new data
+ Idle(Arc<std::fs::File>),
+ /// In the middle of a write
+ Writing(
+ Arc<std::fs::File>,
+ BoxFuture<'static, Result<usize, io::Error>>,
+ ),
+ /// In the middle of syncing data and closing file.
+ ///
+ /// Future will contain last reference to file, so it will call drop on
completion.
+ ShuttingDown(BoxFuture<'static, Result<(), io::Error>>),
+ /// File is being moved from its temporary location to the final location
+ Committing(BoxFuture<'static, Result<(), io::Error>>),
+ /// Upload is complete
+ Complete,
+}
+
+struct LocalUpload {
+ inner_state: LocalUploadState,
+ dest: PathBuf,
+ multipart_id: MultipartId,
+}
+
+impl LocalUpload {
+ pub fn new(
+ dest: PathBuf,
+ multipart_id: MultipartId,
+ file: Arc<std::fs::File>,
+ ) -> Self {
+ Self {
+ inner_state: LocalUploadState::Idle(file),
+ dest,
+ multipart_id,
+ }
+ }
+}
+
+impl AsyncWrite for LocalUpload {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> std::task::Poll<Result<usize, io::Error>> {
+ let invalid_state =
+ |condition: &str| -> std::task::Poll<Result<usize, io::Error>> {
+ Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ format!("Tried to write to file {}.", condition),
+ )))
+ };
+
+ if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+ let mut data: Vec<u8> = buf.to_vec();
+ let data_len = data.len();
+
+ loop {
+ match &mut self.inner_state {
+ LocalUploadState::Idle(file) => {
+ let file = Arc::clone(file);
+ let file2 = Arc::clone(&file);
+ let data: Vec<u8> = std::mem::take(&mut data);
+ self.inner_state = LocalUploadState::Writing(
+ file,
+ Box::pin(
+ runtime
+ .spawn_blocking(move ||
(&*file2).write_all(&data))
+ .map(move |res| match res {
+ Err(err) => {
+
Err(io::Error::new(io::ErrorKind::Other, err))
+ }
+ Ok(res) => res.map(move |_| data_len),
+ }),
+ ),
+ );
+ }
+ LocalUploadState::Writing(file, inner_write) => {
+ match inner_write.poll_unpin(cx) {
+ Poll::Ready(res) => {
+ self.inner_state =
+ LocalUploadState::Idle(Arc::clone(file));
+ return Poll::Ready(res);
+ }
+ Poll::Pending => {
+ return Poll::Pending;
+ }
+ }
+ }
+ LocalUploadState::ShuttingDown(_) => {
+ return invalid_state("when writer is shutting down");
+ }
+ LocalUploadState::Committing(_) => {
+ return invalid_state("when writer is committing data");
+ }
+ LocalUploadState::Complete => {
+ return invalid_state("when writer is complete");
+ }
+ }
+ }
+ } else if let LocalUploadState::Idle(file) = &self.inner_state {
+ let file = Arc::clone(file);
+ (&*file).write_all(buf)?;
+ Poll::Ready(Ok(buf.len()))
+ } else {
+ // Without a tokio runtime, writes complete synchronously, so
the only possible states here are Idle and Complete.
+ invalid_state("when writer is already complete.")
+ }
+ }
+
+ fn poll_flush(
+ self: Pin<&mut Self>,
+ _cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+ loop {
+ match &mut self.inner_state {
+ LocalUploadState::Idle(file) => {
+ // We are moving the file into the future, and it will be
dropped on its completion, closing the file.
+ let file = Arc::clone(file);
+ self.inner_state =
LocalUploadState::ShuttingDown(Box::pin(
+ runtime.spawn_blocking(move ||
(*file).sync_all()).map(
+ move |res| match res {
+ Err(err) => {
+
Err(io::Error::new(io::ErrorKind::Other, err))
+ }
+ Ok(res) => res,
+ },
+ ),
+ ));
+ }
+ LocalUploadState::ShuttingDown(fut) => match
fut.poll_unpin(cx) {
+ Poll::Ready(res) => {
+ res?;
+ let staging_path =
+ get_upload_stage_path(&self.dest,
&self.multipart_id);
+ let dest = self.dest.clone();
+ self.inner_state =
LocalUploadState::Committing(Box::pin(
+ runtime
+ .spawn_blocking(move || {
+ std::fs::rename(&staging_path, &dest)
+ })
+ .map(move |res| match res {
+ Err(err) => {
+
Err(io::Error::new(io::ErrorKind::Other, err))
+ }
+ Ok(res) => res,
+ }),
+ ));
+ }
+ Poll::Pending => {
+ return Poll::Pending;
+ }
+ },
+ LocalUploadState::Writing(_, _) => {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::InvalidInput,
+ "Tried to commit a file where a write is in
progress.",
+ )));
+ }
+ LocalUploadState::Committing(fut) => match
fut.poll_unpin(cx) {
+ Poll::Ready(res) => {
+ self.inner_state = LocalUploadState::Complete;
+ return Poll::Ready(res);
+ }
+ Poll::Pending => return Poll::Pending,
+ },
+ LocalUploadState::Complete => {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::Other,
+ "Already complete",
+ )))
+ }
+ }
+ }
+ } else {
+ let staging_path = get_upload_stage_path(&self.dest,
&self.multipart_id);
+ match &mut self.inner_state {
+ LocalUploadState::Idle(file) => {
+ let file = Arc::clone(file);
+ self.inner_state = LocalUploadState::Complete;
+ file.sync_all()?;
+ std::mem::drop(file);
+ std::fs::rename(&staging_path, &self.dest)?;
+ Poll::Ready(Ok(()))
+ }
+ _ => {
+ // Without a tokio runtime, writes complete synchronously, so
the only possible states here are Idle and Complete.
+ Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::Other,
+ "Already complete",
+ )))
+ }
+ }
+ }
+ }
+}
+
fn open_file(path: &std::path::PathBuf) -> Result<File> {
let file = File::open(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
@@ -492,6 +749,33 @@ fn open_file(path: &std::path::PathBuf) -> Result<File> {
Ok(file)
}
+fn open_writable_file(path: &std::path::PathBuf) -> Result<File> {
+ match File::create(&path) {
+ Ok(f) => Ok(f),
+ Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
+ let parent = path
+ .parent()
+ .context(UnableToCreateFileSnafu { path: &path, err })?;
+ std::fs::create_dir_all(&parent)
+ .context(UnableToCreateDirSnafu { path: parent })?;
+
+ match File::create(&path) {
+ Ok(f) => Ok(f),
+ Err(err) => Err(Error::UnableToCreateFile {
+ path: path.to_path_buf(),
+ err,
+ }
+ .into()),
+ }
+ }
+ Err(err) => Err(Error::UnableToCreateFile {
+ path: path.to_path_buf(),
+ err,
+ }
+ .into()),
+ }
+}
+
fn convert_entry(entry: DirEntry, location: Path) -> Result<ObjectMeta> {
let metadata = entry
.metadata()
@@ -548,11 +832,12 @@ mod tests {
use crate::{
tests::{
copy_if_not_exists, get_nonexistent_object,
list_uses_directories_correctly,
- list_with_delimiter, put_get_delete_list, rename_and_copy,
+ list_with_delimiter, put_get_delete_list, rename_and_copy,
stream_get,
},
Error as ObjectStoreError, ObjectStore,
};
use tempfile::TempDir;
+ use tokio::io::AsyncWriteExt;
#[tokio::test]
async fn file_test() {
@@ -564,6 +849,7 @@ mod tests {
list_with_delimiter(&integration).await.unwrap();
rename_and_copy(&integration).await.unwrap();
copy_if_not_exists(&integration).await.unwrap();
+ stream_get(&integration).await.unwrap();
}
#[test]
@@ -574,6 +860,7 @@ mod tests {
put_get_delete_list(&integration).await.unwrap();
list_uses_directories_correctly(&integration).await.unwrap();
list_with_delimiter(&integration).await.unwrap();
+ stream_get(&integration).await.unwrap();
});
}
@@ -770,4 +1057,34 @@ mod tests {
err
);
}
+
+ #[tokio::test]
+ async fn list_hides_incomplete_uploads() {
+ let root = TempDir::new().unwrap();
+ let integration =
LocalFileSystem::new_with_prefix(root.path()).unwrap();
+ let location = Path::from("some_file");
+
+ let data = Bytes::from("arbitrary data");
+ let (multipart_id, mut writer) =
+ integration.put_multipart(&location).await.unwrap();
+ writer.write_all(&data).await.unwrap();
+
+ let (multipart_id_2, mut writer_2) =
+ integration.put_multipart(&location).await.unwrap();
+ assert_ne!(multipart_id, multipart_id_2);
+ writer_2.write_all(&data).await.unwrap();
+
+ let list = flatten_list_stream(&integration, None).await.unwrap();
+ assert_eq!(list.len(), 0);
+
+ assert_eq!(
+ integration
+ .list_with_delimiter(None)
+ .await
+ .unwrap()
+ .objects
+ .len(),
+ 0
+ );
+ }
}
diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs
index ffd8e3a52..dc3967d99 100644
--- a/object_store/src/memory.rs
+++ b/object_store/src/memory.rs
@@ -16,6 +16,7 @@
// under the License.
//! An in-memory object store implementation
+use crate::MultipartId;
use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore,
Result};
use async_trait::async_trait;
use bytes::Bytes;
@@ -25,7 +26,12 @@ use parking_lot::RwLock;
use snafu::{ensure, OptionExt, Snafu};
use std::collections::BTreeMap;
use std::collections::BTreeSet;
+use std::io;
use std::ops::Range;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::Poll;
+use tokio::io::AsyncWrite;
/// A specialized `Error` for in-memory object store-related errors
#[derive(Debug, Snafu)]
@@ -67,7 +73,7 @@ impl From<Error> for super::Error {
/// storage provider.
#[derive(Debug, Default)]
pub struct InMemory {
- storage: RwLock<BTreeMap<Path, Bytes>>,
+ storage: Arc<RwLock<BTreeMap<Path, Bytes>>>,
}
impl std::fmt::Display for InMemory {
@@ -83,6 +89,29 @@ impl ObjectStore for InMemory {
Ok(())
}
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ Ok((
+ String::new(),
+ Box::new(InMemoryUpload {
+ location: location.clone(),
+ data: Vec::new(),
+ storage: Arc::clone(&self.storage),
+ }),
+ ))
+ }
+
+ async fn abort_multipart(
+ &self,
+ _location: &Path,
+ _multipart_id: &MultipartId,
+ ) -> Result<()> {
+ // Nothing to clean up
+ Ok(())
+ }
+
async fn get(&self, location: &Path) -> Result<GetResult> {
let data = self.get_bytes(location).await?;
@@ -211,7 +240,7 @@ impl InMemory {
let storage = storage.clone();
Self {
- storage: RwLock::new(storage),
+ storage: Arc::new(RwLock::new(storage)),
}
}
@@ -227,6 +256,39 @@ impl InMemory {
}
}
+struct InMemoryUpload {
+ location: Path,
+ data: Vec<u8>,
+ storage: Arc<RwLock<BTreeMap<Path, Bytes>>>,
+}
+
+impl AsyncWrite for InMemoryUpload {
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ _cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> std::task::Poll<Result<usize, io::Error>> {
+ self.data.extend_from_slice(buf);
+ Poll::Ready(Ok(buf.len()))
+ }
+
+ fn poll_flush(
+ self: Pin<&mut Self>,
+ _cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ _cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ let data = Bytes::from(std::mem::take(&mut self.data));
+ self.storage.write().insert(self.location.clone(), data);
+ Poll::Ready(Ok(()))
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -234,7 +296,7 @@ mod tests {
use crate::{
tests::{
copy_if_not_exists, get_nonexistent_object,
list_uses_directories_correctly,
- list_with_delimiter, put_get_delete_list, rename_and_copy,
+ list_with_delimiter, put_get_delete_list, rename_and_copy,
stream_get,
},
Error as ObjectStoreError, ObjectStore,
};
@@ -248,6 +310,7 @@ mod tests {
list_with_delimiter(&integration).await.unwrap();
rename_and_copy(&integration).await.unwrap();
copy_if_not_exists(&integration).await.unwrap();
+ stream_get(&integration).await.unwrap();
}
#[tokio::test]
diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs
new file mode 100644
index 000000000..c16022d37
--- /dev/null
+++ b/object_store/src/multipart.rs
@@ -0,0 +1,195 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use futures::{future::BoxFuture, stream::FuturesUnordered, Future, StreamExt};
+use std::{io, pin::Pin, sync::Arc, task::Poll};
+use tokio::io::AsyncWrite;
+
+use crate::Result;
+
+type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> +
Send>>;
+
+/// A trait that can be implemented by cloud-based object stores
+/// and used in combination with [`CloudMultiPartUpload`] to provide
+/// multipart upload support
+///
+/// Note: this does not use AsyncTrait as the lifetimes are difficult to manage
+pub(crate) trait CloudMultiPartUploadImpl {
+ /// Upload a single part
+ fn put_multipart_part(
+ &self,
+ buf: Vec<u8>,
+ part_idx: usize,
+ ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>>;
+
+ /// Complete the upload with the provided parts
+ ///
+ /// `completed_parts` is in order of part number
+ fn complete(
+ &self,
+ completed_parts: Vec<Option<UploadPart>>,
+ ) -> BoxFuture<'static, Result<(), io::Error>>;
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct UploadPart {
+ pub content_id: String,
+}
+
+pub(crate) struct CloudMultiPartUpload<T>
+where
+ T: CloudMultiPartUploadImpl,
+{
+ inner: Arc<T>,
+ /// A list of completed parts, in sequential order.
+ completed_parts: Vec<Option<UploadPart>>,
+ /// Part upload tasks currently running
+ tasks: FuturesUnordered<BoxedTryFuture<(usize, UploadPart)>>,
+ /// Maximum number of upload tasks to run concurrently
+ max_concurrency: usize,
+ /// Buffer that will be sent in next upload.
+ current_buffer: Vec<u8>,
+ /// Minimum size of a part in bytes
+ min_part_size: usize,
+ /// Index of current part
+ current_part_idx: usize,
+ /// The completion task
+ completion_task: Option<BoxedTryFuture<()>>,
+}
+
+impl<T> CloudMultiPartUpload<T>
+where
+ T: CloudMultiPartUploadImpl,
+{
+ pub fn new(inner: T, max_concurrency: usize) -> Self {
+ Self {
+ inner: Arc::new(inner),
+ completed_parts: Vec::new(),
+ tasks: FuturesUnordered::new(),
+ max_concurrency,
+ current_buffer: Vec::new(),
+ // TODO: Should this vary by provider?
+ // TODO: Should we automatically increase this when the part index
gets large?
+ min_part_size: 5_000_000,
+ current_part_idx: 0,
+ completion_task: None,
+ }
+ }
+
+ pub fn poll_tasks(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Result<(), io::Error> {
+ if self.tasks.is_empty() {
+ return Ok(());
+ }
+ let total_parts = self.completed_parts.len();
+ while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) {
+ let (part_idx, part) = res?;
+ self.completed_parts
+ .resize(std::cmp::max(part_idx + 1, total_parts), None);
+ self.completed_parts[part_idx] = Some(part);
+ }
+ Ok(())
+ }
+}
+
+impl<T> AsyncWrite for CloudMultiPartUpload<T>
+where
+ T: CloudMultiPartUploadImpl + Send + Sync,
+{
+ fn poll_write(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> std::task::Poll<Result<usize, io::Error>> {
+ // Poll current tasks
+ self.as_mut().poll_tasks(cx)?;
+
+ // If adding buf to pending buffer would trigger send, check
+ // whether we have capacity for another task.
+ let enough_to_send = (buf.len() + self.current_buffer.len()) >
self.min_part_size;
+ if enough_to_send && self.tasks.len() < self.max_concurrency {
+ // If we do, copy into the buffer and submit the task, and return
ready.
+ self.current_buffer.extend_from_slice(buf);
+
+ let out_buffer = std::mem::take(&mut self.current_buffer);
+ let task = self
+ .inner
+ .put_multipart_part(out_buffer, self.current_part_idx);
+ self.tasks.push(task);
+ self.current_part_idx += 1;
+
+ // We need to poll immediately after adding to setup waker
+ self.as_mut().poll_tasks(cx)?;
+
+ Poll::Ready(Ok(buf.len()))
+ } else if !enough_to_send {
+ self.current_buffer.extend_from_slice(buf);
+ Poll::Ready(Ok(buf.len()))
+ } else {
+ // Waker registered by call to poll_tasks at beginning
+ Poll::Pending
+ }
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ // Poll current tasks
+ self.as_mut().poll_tasks(cx)?;
+
+ // If current_buffer is not empty, see if it can be submitted
+ if !self.current_buffer.is_empty() && self.tasks.len() <
self.max_concurrency {
+ let out_buffer: Vec<u8> = std::mem::take(&mut self.current_buffer);
+ let task = self
+ .inner
+ .put_multipart_part(out_buffer, self.current_part_idx);
+ self.tasks.push(task);
+ }
+
+ self.as_mut().poll_tasks(cx)?;
+
+ // If tasks and current_buffer are empty, return Ready
+ if self.tasks.is_empty() && self.current_buffer.is_empty() {
+ Poll::Ready(Ok(()))
+ } else {
+ Poll::Pending
+ }
+ }
+
+ fn poll_shutdown(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), io::Error>> {
+ // First, poll flush
+ match self.as_mut().poll_flush(cx) {
+ Poll::Pending => return Poll::Pending,
+ Poll::Ready(res) => res?,
+ };
+
+ // If shutdown task is not set, set it
+ let parts = std::mem::take(&mut self.completed_parts);
+ let inner = Arc::clone(&self.inner);
+ let completion_task = self
+ .completion_task
+ .get_or_insert_with(|| inner.complete(parts));
+
+ Pin::new(completion_task).poll(cx)
+ }
+}
diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs
index 656029651..6789f0e68 100644
--- a/object_store/src/throttle.rs
+++ b/object_store/src/throttle.rs
@@ -20,11 +20,13 @@ use parking_lot::Mutex;
use std::ops::Range;
use std::{convert::TryInto, sync::Arc};
+use crate::MultipartId;
use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore,
Result};
use async_trait::async_trait;
use bytes::Bytes;
use futures::{stream::BoxStream, StreamExt};
use std::time::Duration;
+use tokio::io::AsyncWrite;
/// Configuration settings for throttled store
#[derive(Debug, Default, Clone, Copy)]
@@ -149,6 +151,21 @@ impl<T: ObjectStore> ObjectStore for ThrottledStore<T> {
self.inner.put(location, bytes).await
}
+ async fn put_multipart(
+ &self,
+ _location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ Err(super::Error::NotImplemented)
+ }
+
+ async fn abort_multipart(
+ &self,
+ _location: &Path,
+ _multipart_id: &MultipartId,
+ ) -> Result<()> {
+ Err(super::Error::NotImplemented)
+ }
+
async fn get(&self, location: &Path) -> Result<GetResult> {
sleep(self.config().wait_get_per_call).await;