core.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. import logging
  2. import subprocess
  3. from threading import RLock
  4. from . import utils
  5. from . import exceptions as ex
# Module-level logger, named after this module, used by GitRepo below.
logger = logging.getLogger(__name__)
  7. class GitRepo:
  8. """
  9. Class to handle Git operations in a local Git repository.
  10. """
  11. __checkout_locks = dict()
  12. def __init__(self, local_path: str) -> None:
  13. """
  14. Initialize GitRepo with the path to the local Git repository.
  15. Parameters:
  16. local_path (str): Path to the Git repository
  17. """
  18. self.__set_local_path(local_path=local_path)
  19. self.__register_lock()
  20. logger.info(f"GitRepo initialised for {local_path}")
  21. def __eq__(self, other) -> bool:
  22. """
  23. Check if the instance is equal to the 'other' instance
  24. Parameters:
  25. other: the other Object to check
  26. """
  27. if type(other) is type(self):
  28. # return True if the paths of the repositories the objects
  29. # point to are equal, otherwise False
  30. return self.__local_path == other.__local_path
  31. return False
  32. def __hash__(self) -> int:
  33. """
  34. Return the hash value of the instance
  35. """
  36. return hash(self.__local_path)
  37. def __register_lock(self) -> None:
  38. """
  39. Initialize an RLock object for the instance in a shared dictionary
  40. """
  41. if not GitRepo.__checkout_locks.get(self):
  42. # create a Lock object for the instance, if is not already created
  43. GitRepo.__checkout_locks[self] = RLock()
  44. return
  45. def __set_local_path(self, local_path: str) -> None:
  46. """
  47. Set the path for the repository, ensuring it is a valid Git repo.
  48. Parameters:
  49. local_path (str): Path to the Git repository
  50. Raises:
  51. NonGitDirectoryError: If the directory is not a valid
  52. Git repository
  53. """
  54. if not utils.is_git_repo(local_path):
  55. raise ex.NonGitDirectoryError(directory=local_path)
  56. self.__local_path = local_path
  57. def get_local_path(self) -> str:
  58. """
  59. Return the local path of the repository.
  60. Returns:
  61. str: Path to the Git repository
  62. """
  63. return self.__local_path
  64. def get_checkout_lock(self) -> RLock:
  65. """
  66. Return the checkout lock object associated with the instance
  67. Returns:
  68. RLock: The lock object associated with the instance
  69. """
  70. lock = GitRepo.__checkout_locks.get(self)
  71. if lock is None:
  72. raise ex.LockNotInitializedError(
  73. lock_name="checkout_lock",
  74. path=self.__local_path
  75. )
  76. return lock
  77. def __checkout(self, commit_ref: str, force: bool = False) -> None:
  78. """
  79. Check out a specific commit.
  80. Parameters:
  81. commit_ref (str): Commit reference to check out
  82. force (bool): Force checkout (default is False)
  83. Raises:
  84. ValueError: If commit_ref is None
  85. """
  86. if commit_ref is None:
  87. raise ValueError("commit_ref is required, cannot be None.")
  88. cmd = ['git', 'checkout', commit_ref]
  89. if force:
  90. cmd.append('-f')
  91. logger.debug(f"Running {' '.join(cmd)}")
  92. logger.debug("Attempting to aquire checkout lock.")
  93. with self.get_checkout_lock():
  94. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  95. def __reset(self, commit_ref: str, hard: bool = False) -> None:
  96. """
  97. Reset to a specific commit.
  98. Parameters:
  99. commit_ref (str): Commit reference to reset to
  100. hard (bool): Use hard reset (default is False)
  101. Raises:
  102. ValueError: If commit_ref is None
  103. """
  104. if commit_ref is None:
  105. raise ValueError("commit_ref is required, cannot be None.")
  106. cmd = ['git', 'reset', commit_ref]
  107. if hard:
  108. cmd.append('--hard')
  109. logger.debug(f"Running {' '.join(cmd)}")
  110. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  111. def __force_recursive_clean(self) -> None:
  112. """
  113. Forcefully clean the working directory,
  114. removing untracked files and directories.
  115. """
  116. cmd = ['git', 'clean', '-xdff']
  117. logger.debug(f"Running {' '.join(cmd)}")
  118. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  119. def __remote_list(self) -> list[str]:
  120. """
  121. Retrieve a list of remotes added to the repository
  122. Returns:
  123. list[str]: List of remote names
  124. """
  125. cmd = ['git', 'remote']
  126. logger.debug(f"Running {' '.join(cmd)}")
  127. ret = subprocess.run(
  128. cmd, cwd=self.__local_path, shell=False, capture_output=True,
  129. encoding='utf-8', check=True
  130. )
  131. return ret.stdout.split('\n')[:-1]
  132. def __is_commit_present_locally(self, commit_ref: str) -> bool:
  133. """
  134. Check if a specific commit exists locally.
  135. Parameters:
  136. commit_ref (str): Commit hash to check
  137. Returns:
  138. bool: True if the commit exists locally, False otherwise
  139. """
  140. if commit_ref is None:
  141. raise ValueError("commit_ref is required, cannot be None.")
  142. cmd = ['git', 'diff-tree', commit_ref, '--no-commit-id', '--no-patch']
  143. logger.debug(f"Running {' '.join(cmd)}")
  144. ret = subprocess.run(cmd, cwd=self.__local_path, shell=False)
  145. return ret.returncode == 0
  146. def remote_set_url(self, remote: str, url: str) -> None:
  147. """
  148. Set the URL for a specific remote.
  149. Parameters:
  150. remote (str): Name of the remote
  151. url (str): URL to set for the remote
  152. Raises:
  153. ValueError: If remote or URL is None
  154. """
  155. if remote is None:
  156. raise ValueError("remote is required, cannot be None.")
  157. if url is None:
  158. raise ValueError("url is required, cannot be None.")
  159. cmd = ['git', 'remote', 'set-url', remote, url]
  160. logger.debug(f"Running {' '.join(cmd)}")
  161. subprocess.run(cmd, cwd=self.__local_path, check=True)
  162. def fetch_remote(self, remote: str, force: bool = False,
  163. tags: bool = False, recurse_submodules: bool = False,
  164. refetch: bool = False) -> None:
  165. """
  166. Fetch updates from a remote repository.
  167. Parameters:
  168. remote (str): Remote to fetch from; if None, fetches all
  169. force (bool): Force fetch (default is False)
  170. tags (bool): Fetch tags (default is False)
  171. recurse_submodules (bool): Recurse into submodules
  172. (default is False)
  173. refetch (bool): Re-fetch all objects (default is False)
  174. """
  175. cmd = ['git', 'fetch']
  176. if remote:
  177. cmd.append(remote)
  178. else:
  179. logger.info("fetch_remote: remote is None, fetching all remotes")
  180. cmd.append('--all')
  181. if force:
  182. cmd.append('--force')
  183. if tags:
  184. cmd.append('--tags')
  185. if refetch:
  186. cmd.append('--refetch')
  187. if recurse_submodules:
  188. cmd.append('--recurse-submodules')
  189. else:
  190. cmd.append('--no-recurse-submodules')
  191. logger.debug(f"Running {' '.join(cmd)}")
  192. subprocess.run(cmd, cwd=self.__local_path, shell=False)
  193. def __branch_create(self, branch_name: str,
  194. start_point: str = None) -> None:
  195. """
  196. Create a new branch starting from a given commit.
  197. Parameters:
  198. branch_name (str): Name of the branch to create
  199. start_point (str): Starting commit or branch (optional)
  200. """
  201. if branch_name is None:
  202. raise ValueError("branch_name is required, cannot be None.")
  203. cmd = ['git', 'branch', branch_name]
  204. if start_point:
  205. if not self.__is_commit_present_locally(commit_ref=start_point):
  206. raise ex.CommitNotFoundError(commit_ref=start_point)
  207. cmd.append(start_point)
  208. logger.debug(f"Running {' '.join(cmd)}")
  209. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  210. def __branch_delete(self, branch_name: str, force: bool = False) -> None:
  211. """
  212. Delete a local branch.
  213. Parameters:
  214. branch_name (str): Name of the branch to delete
  215. force (bool): Force delete (default is False)
  216. """
  217. if branch_name is None:
  218. raise ValueError("branch_name is required, cannot be None.")
  219. if not self.__is_commit_present_locally(commit_ref=branch_name):
  220. raise ex.CommitNotFoundError(commit_ref=branch_name)
  221. cmd = ['git', 'branch', '-d', branch_name]
  222. if force:
  223. cmd.append('--force')
  224. logger.debug(f"Running {' '.join(cmd)}")
  225. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  226. def commit_id_for_remote_ref(self, remote: str,
  227. commit_ref: str) -> str:
  228. """
  229. Get the commit ID for a specific commit reference from a remote.
  230. Parameters:
  231. remote (str): Name of the remote
  232. commit_ref (str): Reference to get the commit ID for
  233. Returns:
  234. str | None: Commit ID if found, None otherwise
  235. """
  236. if remote is None:
  237. raise ValueError("remote is required, cannot be None.")
  238. if remote not in self.__remote_list():
  239. raise ex.RemoteNotFoundError(remote=remote)
  240. if commit_ref is None:
  241. raise ValueError("commit_ref is required, cannot be None.")
  242. if utils.is_valid_hex_string(test_str=commit_ref):
  243. # skip conversion if commit_ref is already hex string
  244. return commit_ref
  245. # allow branches and tags only for now
  246. allowed_ref_types = ['tags', 'heads']
  247. split_ref = commit_ref.split('/', 2)
  248. if len(split_ref) != 3 or split_ref[0] != 'refs':
  249. raise ValueError(f"commit_ref '{commit_ref}' format is invalid.")
  250. _, ref_type, _ = split_ref
  251. if ref_type not in allowed_ref_types:
  252. raise ValueError(f"ref_type '{ref_type}' is not supported.")
  253. cmd = ['git', 'ls-remote', remote]
  254. logger.debug(f"Running {' '.join(cmd)}")
  255. ret = subprocess.run(
  256. cmd, cwd=self.__local_path, encoding='utf-8', capture_output=True,
  257. shell=False, check=True
  258. )
  259. for line in ret.stdout.split('\n')[:-1]:
  260. (commit_id, res_ref) = line.split('\t')
  261. if res_ref == commit_ref:
  262. return commit_id
  263. return None
  264. def __ensure_commit_fetched(self, remote: str, commit_id: str) -> None:
  265. """
  266. Ensure a specific commit is fetched from the remote repository.
  267. Parameters:
  268. remote (str): Remote name to fetch from
  269. commit_id (str): Commit ID to ensure it is available locally
  270. Raises:
  271. RemoteNotFoundError: If the specified remote does not exist
  272. CommitNotFoundError: If the commit cannot be fetched after
  273. multiple attempts
  274. """
  275. if remote is None:
  276. raise ValueError("remote is required, cannot be None.")
  277. if remote not in self.__remote_list():
  278. raise ex.RemoteNotFoundError(remote=remote)
  279. if commit_id is None:
  280. raise ValueError("commit_id is required, cannot be None.")
  281. if not utils.is_valid_hex_string(test_str=commit_id):
  282. raise ValueError(
  283. f"commit_id should be a hex string, got '{commit_id}'."
  284. )
  285. if self.__is_commit_present_locally(commit_ref=commit_id):
  286. # early return if commit is already fetched
  287. return
  288. self.fetch_remote(remote=remote, force=True, tags=True)
  289. # retry fetch with refetch option if the commit is still not found
  290. if not self.__is_commit_present_locally(commit_ref=commit_id):
  291. self.fetch_remote(
  292. remote=remote, force=True, tags=True, refetch=True
  293. )
  294. if not self.__is_commit_present_locally(commit_ref=commit_id):
  295. raise ex.CommitNotFoundError(commit_ref=commit_id)
  296. def checkout_remote_commit_ref(self, remote: str,
  297. commit_ref: str,
  298. force: bool = False,
  299. hard_reset: bool = False,
  300. clean_working_tree: bool = False) -> None:
  301. """
  302. Check out a specific commit from a remote repository.
  303. Parameters:
  304. remote (str): Remote name to check out from
  305. commit_ref (str): Commit reference to check out
  306. force (bool): Force the checkout (default is False)
  307. hard_reset (bool): Hard reset after checkout (default is False)
  308. clean_working_tree (bool): Clean untracked files after checkout
  309. (default is False)
  310. Raises:
  311. RemoteNotFoundError: If the specified remote does not exist
  312. CommitNotFoundError: If the specified commit cannot be found
  313. """
  314. if remote is None:
  315. logger.error("remote cannot be None for checkout to remote commit")
  316. raise ValueError("remote is required, cannot be None.")
  317. if remote not in self.__remote_list():
  318. raise ex.RemoteNotFoundError(remote=remote)
  319. if commit_ref is None:
  320. raise ValueError("commit_ref is required, cannot be None.")
  321. # retrieve the commit ID for the specified commit reference
  322. commit_id = self.commit_id_for_remote_ref(
  323. remote=remote, commit_ref=commit_ref
  324. )
  325. # ensure the commit is fetched from the remote repository
  326. self.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  327. # perform checkout on the specified commit using the commit ID
  328. # commit ID is used in place of branch name or tag name to make sure
  329. # do not check out the branch or tag from wrong remote
  330. self.__checkout(commit_ref=commit_id, force=force)
  331. # optional hard reset and clean of working tree after checkout
  332. if hard_reset:
  333. self.__reset(commit_ref=commit_id, hard=True)
  334. if clean_working_tree:
  335. self.__force_recursive_clean()
  336. def submodule_update(self, init: bool = False, recursive: bool = False,
  337. force: bool = False) -> None:
  338. """
  339. Update Git submodules for the repository.
  340. Parameters:
  341. init (bool): Initialize submodules if they are not initialized
  342. (default is False)
  343. recursive (bool): Update submodules recursively (default is False)
  344. force (bool): Force update even if there are changes
  345. (default is False)
  346. """
  347. cmd = ['git', 'submodule', 'update']
  348. if init:
  349. cmd.append('--init')
  350. if recursive:
  351. cmd.append('--recursive')
  352. if force:
  353. cmd.append('--force')
  354. logger.debug(f"Running {' '.join(cmd)}")
  355. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  356. def remote_add(self, remote: str, url: str) -> None:
  357. """
  358. Add a new remote to the Git repository.
  359. Parameters:
  360. remote (str): Name of the remote to add
  361. url (str): URL for the remote repository
  362. Raises:
  363. DuplicateRemoteError: If remote already exists and
  364. overwrite is not allowed
  365. """
  366. if remote is None:
  367. raise ValueError("remote is required, cannot be None.")
  368. if url is None:
  369. raise ValueError("url is required, cannot be None.")
  370. # Set the URL if the remote exists and overwrite is allowed
  371. if remote in self.__remote_list():
  372. raise ex.DuplicateRemoteError(remote)
  373. # Add the new remote
  374. cmd = ['git', 'remote', 'add', remote, url]
  375. logger.debug(f"Running {' '.join(cmd)}")
  376. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  377. @staticmethod
  378. def clone(source: str,
  379. dest: str,
  380. branch: str = None,
  381. single_branch: bool = False,
  382. recurse_submodules: bool = False,
  383. shallow_submodules: bool = False) -> "GitRepo":
  384. """
  385. Clone a Git repository.
  386. Parameters:
  387. source (str): Source path of the repository to clone
  388. Can be local or a url.
  389. dest (str): Destination path for the clone
  390. branch (str): Specific branch to clone (optional)
  391. single_branch (bool): Only clone a single branch (default is False)
  392. recurse_submodules (bool): Recurse into submodules
  393. (default is False)
  394. shallow_submodules (bool): any cloned submodules will be shallow
  395. Returns:
  396. GitRepo: the cloned git repository
  397. """
  398. cmd = ['git', 'clone', source, dest]
  399. if branch:
  400. cmd.append('--branch=' + branch)
  401. if single_branch:
  402. cmd.append('--single-branch')
  403. if recurse_submodules:
  404. cmd.append('--recurse-submodules')
  405. if shallow_submodules:
  406. cmd.append('--shallow-submodules')
  407. logger.debug(f"Running {' '.join(cmd)}")
  408. subprocess.run(cmd, shell=False, check=True)
  409. return GitRepo(local_path=dest)
  410. @staticmethod
  411. def shallow_clone_at_commit_from_local(source: str,
  412. remote: str,
  413. commit_ref: str,
  414. dest: str) -> "GitRepo":
  415. """
  416. Perform a shallow clone of a repository at a specific commit.
  417. Parameters:
  418. source (str): Source path of the local repository
  419. remote (str): Remote name containing the commit
  420. commit_ref (str): Commit reference to clone
  421. dest (str): Destination path for the clone
  422. Returns:
  423. GitRepo: the cloned git repository
  424. Raises:
  425. RemoteNotFoundError: If the specified remote does not exist
  426. CommitNotFoundError: If the specified commit cannot be found
  427. """
  428. if remote is None:
  429. raise ValueError("remote is required, cannot be None.")
  430. if commit_ref is None:
  431. raise ValueError("commit_ref is required, cannot be None.")
  432. source_repo = GitRepo(local_path=source)
  433. # get the commit ID for the specified remote reference
  434. commit_id = source_repo.commit_id_for_remote_ref(
  435. remote=remote, commit_ref=commit_ref
  436. )
  437. source_repo.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  438. # create a temporary branch to point to the specified commit
  439. # as shallow clone needs a branch
  440. temp_branch_name = "temp-b-" + commit_id
  441. source_repo.__branch_create(
  442. branch_name=temp_branch_name, start_point=commit_id
  443. )
  444. # perform the clone from the source repository
  445. # using the temporary branch
  446. cloned_repo = GitRepo.clone(
  447. source=source,
  448. dest=dest,
  449. branch=temp_branch_name,
  450. single_branch=True,
  451. recurse_submodules=True,
  452. shallow_submodules=True
  453. )
  454. # delete the temporary branch in source repository
  455. # after the clone operation
  456. source_repo.__branch_delete(branch_name=temp_branch_name, force=True)
  457. return cloned_repo