# Source code for pyxet.commit_transaction

import sys
import threading

import fsspec

from .file_interface import XetFile

# Size threshold for a single pooled transaction handler. In
# MultiCommitTransaction.get_handler_for_repo_info, a handler whose
# transaction_size() is >= this value is committed and restarted before it is
# handed out again.
TRANSACTION_FILE_LIMIT = 512


def _validate_repo_info_for_transaction(repo_info):
    """Check that ``repo_info`` names both a repository and a branch.

    The remote is checked before the branch, so when both are missing the
    repository error is the one reported.

    :param repo_info: object exposing ``remote`` and ``branch`` attributes
        (e.g. the result of ``pyxet.parse_url(url)``).
    :raises ValueError: if ``repo_info.remote`` or ``repo_info.branch`` is
        the empty string.
    """
    required = (
        ("remote", "No repository specified"),
        ("branch", "No branch specified"),
    )
    for attr_name, error_text in required:
        if getattr(repo_info, attr_name) == '':
            raise ValueError(error_text)

def repo_info_key(repo_info):
    """Return the string ``"<remote>/<branch>"`` for ``repo_info``.

    Used to key one transaction handler per repository branch.
    """
    remote, branch = repo_info.remote, repo_info.branch
    return "{}/{}".format(remote, branch)


class MultiCommitTransaction(fsspec.transaction.Transaction):
    """
    Handles a commit using the transaction interface.

    This transaction handler supports transactions across multiple branches
    by tracking them separately. Simultaneous changes across branches will
    require multiple actual transactions to complete.

    One underlying handler is kept per ``repo_info_key(repo_info)`` in
    ``self._transaction_pool``; reads and writes of the pool are serialized
    through ``self.lock``.
    """

    def __init__(self, fs, commit_message=None):
        """
        This class should not be used directly. It is preferred to use
        fs.transaction.

        :param fs: filesystem the transaction operates on. It is expected to
            expose an ``intrans`` flag and a
            ``_create_transaction_handler(repo_info, commit_message)`` method.
        :param commit_message: commit message to use; ``None`` selects the
            default "Commit <current datetime>" message.
        """
        self.commit_message = None
        # repo_info_key(repo_info) -> per-branch transaction handler.
        self._transaction_pool = {}
        self.fs = fs
        self._set_commit_message(commit_message)
        # Guards _transaction_pool and commit_message. threading.Lock is not
        # reentrant, so code holding it must not call another locking method.
        self.lock = threading.Lock()
        super().__init__(fs)

    def set_commit_message(self, commit_message):
        """
        Sets the commit message to be used.

        This applies to every current uncommitted transaction and future
        transactions. If commit_message is None, a default message
        "Commit [current datetime]" is used.
        """
        # NOTE(review): pooled handlers received the message that was current
        # when they were created; this only updates self.commit_message.
        # Confirm existing handlers actually pick up the new message.
        with self.lock:
            self._set_commit_message(commit_message)

    def _set_commit_message(self, commit_message):
        """Store ``commit_message``, or a timestamped default when it is None.

        Does not take the lock; callers that need it must hold it.
        """
        if commit_message is None:
            import datetime
            commit_message = "Commit " + datetime.datetime.now().isoformat()
        self.commit_message = commit_message

    def __repr__(self):
        with self.lock:
            # Keys are the "remote/branch" strings produced by repo_info_key.
            keys = ", ".join(self._transaction_pool)
        return f"MultiCommitTransaction for [{keys}]"

    def __str__(self):
        return self.__repr__()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """End transaction and commit, if exit is not due to exception"""
        # only commit if there was no exception
        self.complete(commit=exc_type is None)

    def get_handler_for_repo_info(self, repo_info):
        """
        Return an access token for the handler of ``repo_info``'s branch.

        The handler is created through ``self.fs._create_transaction_handler``
        on first use. If an existing handler's ``transaction_size()`` has
        reached TRANSACTION_FILE_LIMIT, it is committed and restarted first.
        """
        key = repo_info_key(repo_info)
        with self.lock:
            # A plain lookup (rather than try/except KeyError around the handler
            # calls) keeps a KeyError raised by transaction_size() or
            # commit_and_restart() from being mistaken for "no handler", which
            # would replace a handler that still has pending changes.
            tr = self._transaction_pool.get(key)
            if tr is None:
                tr = self.fs._create_transaction_handler(repo_info, self.commit_message)
                self._transaction_pool[key] = tr
            elif tr.transaction_size() >= TRANSACTION_FILE_LIMIT:
                tr.commit_and_restart()
            return tr.create_access_token()

    def open_for_write(self, repo_info):
        """
        Opens a file for writing.

        `repo_info` is the result of `pyxet.parse_url(url)`
        """
        handler = self.get_handler_for_repo_info(repo_info)
        return XetFile(handler.open_for_write(repo_info.path), handler)

    def start(self):
        """
        Starts the transaction

        :raises RuntimeError: if ``self.fs.intrans`` is already set.
        """
        if self.fs.intrans:
            raise RuntimeError("Transaction already in progress")
        self.fs.intrans = True

    def complete(self, commit=True):
        """
        Finalizes and commits or cancels this transaction.

        The transaction can be restarted with start()

        ``complete(commit)`` is attempted on every pooled handler even if an
        earlier one fails. Each failure is written to stderr. The pool, the
        ``fs.intrans`` flag and the commit message are then reset, and the
        first failure is re-raised.
        """
        with self.lock:
            # Should not be called while other things are in progress, but better be safe.
            first_error = None
            for key, handler in self._transaction_pool.items():
                try:
                    handler.complete(commit)
                except Exception as e:
                    sys.stderr.write(f"Failed to commit {key}: {e}\n")
                    sys.stderr.flush()
                    if first_error is None:
                        first_error = e

            # reset all the transaction state
            self._transaction_pool = {}
            self.fs.intrans = False
            self._set_commit_message(None)

        if first_error is not None:
            raise first_error

    def copy(self, src_repo_info, dest_repo_info):
        """
        Copies a file from src to dest.

        src_repo_info and dest_repo_info are the returned values from
        `pyxet.parse_url(url)`
        """
        # The destination branch's handler performs the copy; only the source
        # branch and path are forwarded (the source remote is not).
        handler = self.get_handler_for_repo_info(dest_repo_info)
        handler.copy(src_repo_info.branch, src_repo_info.path, dest_repo_info.path)

    def mv(self, src_repo_info, dest_repo_info):
        """
        Moves a file from src to dest.

        src_repo_info and dest_repo_info are the returned values from
        `pyxet.parse_url(url)`
        """
        # NOTE(review): only src_repo_info.path is forwarded, so the source
        # branch is ignored and the move is resolved within the destination
        # branch. Confirm cross-branch moves are rejected elsewhere.
        handler = self.get_handler_for_repo_info(dest_repo_info)
        handler.mv(src_repo_info.path, dest_repo_info.path)

    def rm(self, repo_info):
        """
        Removes a file.

        repo_info is the return value of `pyxet.parse_url(url)`
        """
        handler = self.get_handler_for_repo_info(repo_info)
        handler.delete(repo_info.path)

    def _set_do_not_commit(self):
        """
        Internal method for testing purposes.

        Flags all active transactions to not attempt to push the commit, but
        will silently succeed.
        """
        with self.lock:
            for v in self._transaction_pool.values():
                v.set_do_not_commit()

    def _set_error_on_commit(self):
        """
        Internal method for testing purposes.

        Flags all active transactions to not attempt to push the commit, but
        will just raise an exception
        """
        with self.lock:
            for v in self._transaction_pool.values():
                v.set_error_on_commit()

    def get_change_list(self):
        """
        Aggregate the pending changes of every pooled handler.

        :returns: dict with keys 'deletes', 'new_files', 'copies' and 'moves'.
            Each value is a list concatenating that attribute across all
            handlers in the pool.
        """
        deletes = []
        new_files = []
        copies = []
        moves = []
        # Hold the lock like every other pool reader. Otherwise a concurrent
        # get_handler_for_repo_info can add an entry mid-iteration and raise
        # "dictionary changed size during iteration".
        with self.lock:
            for v in self._transaction_pool.values():
                deletes.extend(v.deletes)
                new_files.extend(v.new_files)
                copies.extend(v.copies)
                moves.extend(v.moves)
        return {'deletes': deletes, 'new_files': new_files, 'copies': copies, 'moves': moves}