core.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. import logging
  2. import subprocess
  3. from threading import RLock
  4. from . import utils
  5. from . import exceptions as ex
  6. from pathlib import Path
  7. logger = logging.getLogger(__name__)
  8. class GitRepo:
  9. """
  10. Class to handle Git operations in a local Git repository.
  11. """
  12. __checkout_locks = dict()
  13. def __init__(self, local_path: str) -> None:
  14. """
  15. Initialize GitRepo with the path to the local Git repository.
  16. Parameters:
  17. local_path (str): Path to the Git repository
  18. """
  19. self.__set_local_path(local_path=local_path)
  20. self.__register_lock()
  21. logger.info(f"GitRepo initialised for {local_path}")
  22. def __eq__(self, other) -> bool:
  23. """
  24. Check if the instance is equal to the 'other' instance
  25. Parameters:
  26. other: the other Object to check
  27. """
  28. if type(other) is type(self):
  29. # return True if the paths of the repositories the objects
  30. # point to are equal, otherwise False
  31. return self.__local_path == other.__local_path
  32. return False
  33. def __hash__(self) -> int:
  34. """
  35. Return the hash value of the instance
  36. """
  37. return hash(self.__local_path)
  38. def __register_lock(self) -> None:
  39. """
  40. Initialize an RLock object for the instance in a shared dictionary
  41. """
  42. if not GitRepo.__checkout_locks.get(self):
  43. # create a Lock object for the instance, if is not already created
  44. GitRepo.__checkout_locks[self] = RLock()
  45. return
  46. def __set_local_path(self, local_path: str) -> None:
  47. """
  48. Set the path for the repository, ensuring it is a valid Git repo.
  49. Parameters:
  50. local_path (str): Path to the Git repository
  51. Raises:
  52. NonGitDirectoryError: If the directory is not a valid
  53. Git repository
  54. """
  55. if not utils.is_git_repo(local_path):
  56. raise ex.NonGitDirectoryError(directory=local_path)
  57. self.__local_path = local_path
  58. def get_local_path(self) -> str:
  59. """
  60. Return the local path of the repository.
  61. Returns:
  62. str: Path to the Git repository
  63. """
  64. return self.__local_path
  65. def get_checkout_lock(self) -> RLock:
  66. """
  67. Return the checkout lock object associated with the instance
  68. Returns:
  69. RLock: The lock object associated with the instance
  70. """
  71. lock = GitRepo.__checkout_locks.get(self)
  72. if lock is None:
  73. raise ex.LockNotInitializedError(
  74. lock_name="checkout_lock",
  75. path=self.__local_path
  76. )
  77. return lock
  78. def __checkout(self, commit_ref: str, force: bool = False) -> None:
  79. """
  80. Check out a specific commit.
  81. Parameters:
  82. commit_ref (str): Commit reference to check out
  83. force (bool): Force checkout (default is False)
  84. Raises:
  85. ValueError: If commit_ref is None
  86. """
  87. if commit_ref is None:
  88. raise ValueError("commit_ref is required, cannot be None.")
  89. cmd = ['git', 'checkout', commit_ref]
  90. if force:
  91. cmd.append('-f')
  92. logger.debug(f"Running {' '.join(cmd)}")
  93. logger.debug("Attempting to aquire checkout lock.")
  94. with self.get_checkout_lock():
  95. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  96. def __reset(self, commit_ref: str, hard: bool = False) -> None:
  97. """
  98. Reset to a specific commit.
  99. Parameters:
  100. commit_ref (str): Commit reference to reset to
  101. hard (bool): Use hard reset (default is False)
  102. Raises:
  103. ValueError: If commit_ref is None
  104. """
  105. if commit_ref is None:
  106. raise ValueError("commit_ref is required, cannot be None.")
  107. cmd = ['git', 'reset', commit_ref]
  108. if hard:
  109. cmd.append('--hard')
  110. logger.debug(f"Running {' '.join(cmd)}")
  111. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  112. def __force_recursive_clean(self) -> None:
  113. """
  114. Forcefully clean the working directory,
  115. removing untracked files and directories.
  116. """
  117. cmd = ['git', 'clean', '-xdff']
  118. logger.debug(f"Running {' '.join(cmd)}")
  119. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  120. def __remote_list(self) -> list[str]:
  121. """
  122. Retrieve a list of remotes added to the repository
  123. Returns:
  124. list[str]: List of remote names
  125. """
  126. cmd = ['git', 'remote']
  127. logger.debug(f"Running {' '.join(cmd)}")
  128. ret = subprocess.run(
  129. cmd, cwd=self.__local_path, shell=False, capture_output=True,
  130. encoding='utf-8', check=True
  131. )
  132. return ret.stdout.split('\n')[:-1]
  133. def __is_commit_present_locally(self, commit_ref: str) -> bool:
  134. """
  135. Check if a specific commit exists locally.
  136. Parameters:
  137. commit_ref (str): Commit hash to check
  138. Returns:
  139. bool: True if the commit exists locally, False otherwise
  140. """
  141. if commit_ref is None:
  142. raise ValueError("commit_ref is required, cannot be None.")
  143. cmd = ['git', 'diff-tree', commit_ref, '--no-commit-id', '--no-patch']
  144. logger.debug(f"Running {' '.join(cmd)}")
  145. ret = subprocess.run(cmd, cwd=self.__local_path, shell=False)
  146. return ret.returncode == 0
  147. def remote_set_url(self, remote: str, url: str) -> None:
  148. """
  149. Set the URL for a specific remote.
  150. Parameters:
  151. remote (str): Name of the remote
  152. url (str): URL to set for the remote
  153. Raises:
  154. ValueError: If remote or URL is None
  155. """
  156. if remote is None:
  157. raise ValueError("remote is required, cannot be None.")
  158. if url is None:
  159. raise ValueError("url is required, cannot be None.")
  160. cmd = ['git', 'remote', 'set-url', remote, url]
  161. logger.debug(f"Running {' '.join(cmd)}")
  162. subprocess.run(cmd, cwd=self.__local_path, check=True)
  163. def remote_get_url(self, remote: str) -> str:
  164. """
  165. Get the URL for a specific remote.
  166. Parameters:
  167. remote (str): Name of the remote
  168. Returns:
  169. str: The URL associated with the remote
  170. Raises:
  171. ValueError: If remote is None
  172. """
  173. if remote is None:
  174. raise ValueError("remote is required, cannot be None.")
  175. cmd = ['git', 'remote', 'get-url', remote]
  176. logger.debug(f"Running {' '.join(cmd)}")
  177. # Capture the output of the command
  178. result = subprocess.run(
  179. cmd,
  180. cwd=self.__local_path,
  181. check=True,
  182. capture_output=True,
  183. text=True
  184. )
  185. # Return the URL from the output
  186. return result.stdout.strip()
  187. def fetch_remote(self, remote: str, force: bool = False,
  188. tags: bool = False, recurse_submodules: bool = False,
  189. refetch: bool = False) -> None:
  190. """
  191. Fetch updates from a remote repository.
  192. Parameters:
  193. remote (str): Remote to fetch from; if None, fetches all
  194. force (bool): Force fetch (default is False)
  195. tags (bool): Fetch tags (default is False)
  196. recurse_submodules (bool): Recurse into submodules
  197. (default is False)
  198. refetch (bool): Re-fetch all objects (default is False)
  199. """
  200. cmd = ['git', 'fetch']
  201. if remote:
  202. cmd.append(remote)
  203. else:
  204. logger.info("fetch_remote: remote is None, fetching all remotes")
  205. cmd.append('--all')
  206. if force:
  207. cmd.append('--force')
  208. if tags:
  209. cmd.append('--tags')
  210. if refetch:
  211. cmd.append('--refetch')
  212. if recurse_submodules:
  213. cmd.append('--recurse-submodules')
  214. else:
  215. cmd.append('--no-recurse-submodules')
  216. logger.debug(f"Running {' '.join(cmd)}")
  217. subprocess.run(cmd, cwd=self.__local_path, shell=False)
  218. def __branch_create(self, branch_name: str,
  219. start_point: str = None) -> None:
  220. """
  221. Create a new branch starting from a given commit.
  222. Parameters:
  223. branch_name (str): Name of the branch to create
  224. start_point (str): Starting commit or branch (optional)
  225. """
  226. if branch_name is None:
  227. raise ValueError("branch_name is required, cannot be None.")
  228. cmd = ['git', 'branch', branch_name]
  229. if start_point:
  230. if not self.__is_commit_present_locally(commit_ref=start_point):
  231. raise ex.CommitNotFoundError(commit_ref=start_point)
  232. cmd.append(start_point)
  233. logger.debug(f"Running {' '.join(cmd)}")
  234. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  235. def __branch_delete(self, branch_name: str, force: bool = False) -> None:
  236. """
  237. Delete a local branch.
  238. Parameters:
  239. branch_name (str): Name of the branch to delete
  240. force (bool): Force delete (default is False)
  241. """
  242. if branch_name is None:
  243. raise ValueError("branch_name is required, cannot be None.")
  244. if not self.__is_commit_present_locally(commit_ref=branch_name):
  245. raise ex.CommitNotFoundError(commit_ref=branch_name)
  246. cmd = ['git', 'branch', '-d', branch_name]
  247. if force:
  248. cmd.append('--force')
  249. logger.debug(f"Running {' '.join(cmd)}")
  250. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  251. def commit_id_for_remote_ref(self, remote: str,
  252. commit_ref: str) -> str:
  253. """
  254. Get the commit ID for a specific commit reference from a remote.
  255. Parameters:
  256. remote (str): Name of the remote
  257. commit_ref (str): Reference to get the commit ID for
  258. Returns:
  259. str | None: Commit ID if found, None otherwise
  260. """
  261. if remote is None:
  262. raise ValueError("remote is required, cannot be None.")
  263. if remote not in self.__remote_list():
  264. raise ex.RemoteNotFoundError(remote=remote)
  265. if commit_ref is None:
  266. raise ValueError("commit_ref is required, cannot be None.")
  267. if utils.is_valid_hex_string(test_str=commit_ref):
  268. # skip conversion if commit_ref is already hex string
  269. return commit_ref
  270. # allow branches and tags only for now
  271. allowed_ref_types = ['tags', 'heads']
  272. split_ref = commit_ref.split('/', 2)
  273. if len(split_ref) != 3 or split_ref[0] != 'refs':
  274. raise ValueError(f"commit_ref '{commit_ref}' format is invalid.")
  275. _, ref_type, _ = split_ref
  276. if ref_type not in allowed_ref_types:
  277. raise ValueError(f"ref_type '{ref_type}' is not supported.")
  278. cmd = ['git', 'ls-remote', remote]
  279. logger.debug(f"Running {' '.join(cmd)}")
  280. ret = subprocess.run(
  281. cmd, cwd=self.__local_path, encoding='utf-8', capture_output=True,
  282. shell=False, check=True
  283. )
  284. for line in ret.stdout.split('\n')[:-1]:
  285. (commit_id, res_ref) = line.split('\t')
  286. if res_ref == commit_ref:
  287. return commit_id
  288. return None
  289. def __ensure_commit_fetched(self, remote: str, commit_id: str) -> None:
  290. """
  291. Ensure a specific commit is fetched from the remote repository.
  292. Parameters:
  293. remote (str): Remote name to fetch from
  294. commit_id (str): Commit ID to ensure it is available locally
  295. Raises:
  296. RemoteNotFoundError: If the specified remote does not exist
  297. CommitNotFoundError: If the commit cannot be fetched after
  298. multiple attempts
  299. """
  300. if remote is None:
  301. raise ValueError("remote is required, cannot be None.")
  302. if remote not in self.__remote_list():
  303. raise ex.RemoteNotFoundError(remote=remote)
  304. if commit_id is None:
  305. raise ValueError("commit_id is required, cannot be None.")
  306. if not utils.is_valid_hex_string(test_str=commit_id):
  307. raise ValueError(
  308. f"commit_id should be a hex string, got '{commit_id}'."
  309. )
  310. if self.__is_commit_present_locally(commit_ref=commit_id):
  311. # early return if commit is already fetched
  312. return
  313. self.fetch_remote(remote=remote, force=True, tags=True)
  314. # retry fetch with refetch option if the commit is still not found
  315. if not self.__is_commit_present_locally(commit_ref=commit_id):
  316. self.fetch_remote(
  317. remote=remote, force=True, tags=True, refetch=True
  318. )
  319. if not self.__is_commit_present_locally(commit_ref=commit_id):
  320. raise ex.CommitNotFoundError(commit_ref=commit_id)
  321. def checkout_remote_commit_ref(self, remote: str,
  322. commit_ref: str,
  323. force: bool = False,
  324. hard_reset: bool = False,
  325. clean_working_tree: bool = False) -> None:
  326. """
  327. Check out a specific commit from a remote repository.
  328. Parameters:
  329. remote (str): Remote name to check out from
  330. commit_ref (str): Commit reference to check out
  331. force (bool): Force the checkout (default is False)
  332. hard_reset (bool): Hard reset after checkout (default is False)
  333. clean_working_tree (bool): Clean untracked files after checkout
  334. (default is False)
  335. Raises:
  336. RemoteNotFoundError: If the specified remote does not exist
  337. CommitNotFoundError: If the specified commit cannot be found
  338. """
  339. if remote is None:
  340. logger.error("remote cannot be None for checkout to remote commit")
  341. raise ValueError("remote is required, cannot be None.")
  342. if remote not in self.__remote_list():
  343. raise ex.RemoteNotFoundError(remote=remote)
  344. if commit_ref is None:
  345. raise ValueError("commit_ref is required, cannot be None.")
  346. # retrieve the commit ID for the specified commit reference
  347. commit_id = self.commit_id_for_remote_ref(
  348. remote=remote, commit_ref=commit_ref
  349. )
  350. # ensure the commit is fetched from the remote repository
  351. self.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  352. # perform checkout on the specified commit using the commit ID
  353. # commit ID is used in place of branch name or tag name to make sure
  354. # do not check out the branch or tag from wrong remote
  355. self.__checkout(commit_ref=commit_id, force=force)
  356. # optional hard reset and clean of working tree after checkout
  357. if hard_reset:
  358. self.__reset(commit_ref=commit_id, hard=True)
  359. if clean_working_tree:
  360. self.__force_recursive_clean()
  361. def submodule_update(self, init: bool = False, recursive: bool = False,
  362. force: bool = False) -> None:
  363. """
  364. Update Git submodules for the repository.
  365. Parameters:
  366. init (bool): Initialize submodules if they are not initialized
  367. (default is False)
  368. recursive (bool): Update submodules recursively (default is False)
  369. force (bool): Force update even if there are changes
  370. (default is False)
  371. """
  372. cmd = ['git', 'submodule', 'update']
  373. if init:
  374. cmd.append('--init')
  375. if recursive:
  376. cmd.append('--recursive')
  377. if force:
  378. cmd.append('--force')
  379. logger.debug(f"Running {' '.join(cmd)}")
  380. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  381. def remote_add(self, remote: str, url: str) -> None:
  382. """
  383. Add a new remote to the Git repository.
  384. Parameters:
  385. remote (str): Name of the remote to add
  386. url (str): URL for the remote repository
  387. Raises:
  388. DuplicateRemoteError: If remote already exists and
  389. overwrite is not allowed
  390. """
  391. if remote is None:
  392. raise ValueError("remote is required, cannot be None.")
  393. if url is None:
  394. raise ValueError("url is required, cannot be None.")
  395. # Set the URL if the remote exists and overwrite is allowed
  396. if remote in self.__remote_list():
  397. raise ex.DuplicateRemoteError(remote)
  398. # Add the new remote
  399. cmd = ['git', 'remote', 'add', remote, url]
  400. logger.debug(f"Running {' '.join(cmd)}")
  401. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  402. def remote_add_bulk(self, remotes: tuple, force: bool = False) -> None:
  403. """
  404. Add multiple remotes to the Git repository at once.
  405. Parameters:
  406. remotes (tuple): Tuple of tuples containing remote name
  407. and url.
  408. E.g. (
  409. ('remote1', 'https://remote1_url'),
  410. ('remote2', 'https://remote2_url'),
  411. )
  412. force (bool): Force update the url if remote already exists.
  413. Raises:
  414. DuplicateRemoteError: If remote already exists and
  415. overwrite is not allowed.
  416. """
  417. logger.info(f"Remotes to add: {remotes}.")
  418. for (remote, url) in remotes:
  419. try:
  420. self.remote_add(remote=remote, url=url)
  421. except ex.DuplicateRemoteError:
  422. if not force:
  423. raise
  424. logger.info(f"Remote {remote} already exists. Updating url.")
  425. self.remote_set_url(remote=remote, url=url)
  426. logger.info(f"Remote {remote} added to repo with url {url}.")
  427. @staticmethod
  428. def clone(source: str,
  429. dest: str,
  430. branch: str = None,
  431. single_branch: bool = False,
  432. recurse_submodules: bool = False,
  433. shallow_submodules: bool = False) -> "GitRepo":
  434. """
  435. Clone a Git repository.
  436. Parameters:
  437. source (str): Source path of the repository to clone
  438. Can be local or a url.
  439. dest (str): Destination path for the clone
  440. branch (str): Specific branch to clone (optional)
  441. single_branch (bool): Only clone a single branch (default is False)
  442. recurse_submodules (bool): Recurse into submodules
  443. (default is False)
  444. shallow_submodules (bool): any cloned submodules will be shallow
  445. Returns:
  446. GitRepo: the cloned git repository
  447. """
  448. cmd = ['git', 'clone', source, dest]
  449. if branch:
  450. cmd.append('--branch=' + branch)
  451. if single_branch:
  452. cmd.append('--single-branch')
  453. if recurse_submodules:
  454. cmd.append('--recurse-submodules')
  455. if shallow_submodules:
  456. cmd.append('--shallow-submodules')
  457. logger.debug(f"Running {' '.join(cmd)}")
  458. subprocess.run(cmd, shell=False, check=True)
  459. return GitRepo(local_path=dest)
  460. @staticmethod
  461. def shallow_clone_at_commit_from_local(source: str,
  462. remote: str,
  463. commit_ref: str,
  464. dest: str) -> "GitRepo":
  465. """
  466. Perform a shallow clone of a repository at a specific commit.
  467. Parameters:
  468. source (str): Source path of the local repository
  469. remote (str): Remote name containing the commit
  470. commit_ref (str): Commit reference to clone
  471. dest (str): Destination path for the clone
  472. Returns:
  473. GitRepo: the cloned git repository
  474. Raises:
  475. RemoteNotFoundError: If the specified remote does not exist
  476. CommitNotFoundError: If the specified commit cannot be found
  477. """
  478. if remote is None:
  479. raise ValueError("remote is required, cannot be None.")
  480. if commit_ref is None:
  481. raise ValueError("commit_ref is required, cannot be None.")
  482. source_repo = GitRepo(local_path=source)
  483. # get the commit ID for the specified remote reference
  484. commit_id = source_repo.commit_id_for_remote_ref(
  485. remote=remote, commit_ref=commit_ref
  486. )
  487. source_repo.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  488. # create a temporary branch to point to the specified commit
  489. # as shallow clone needs a branch
  490. temp_branch_name = "temp-b-" + commit_id
  491. source_repo.__branch_create(
  492. branch_name=temp_branch_name, start_point=commit_id
  493. )
  494. # perform the clone from the source repository
  495. # using the temporary branch
  496. cloned_repo = GitRepo.clone(
  497. source=source,
  498. dest=dest,
  499. branch=temp_branch_name,
  500. single_branch=True,
  501. recurse_submodules=True,
  502. shallow_submodules=True
  503. )
  504. # add the remote containing the commit in cloned repo for reference
  505. url = source_repo.remote_get_url(remote=remote)
  506. cloned_repo.remote_add(remote=remote, url=url)
  507. # delete the temporary branch in source repository
  508. # after the clone operation
  509. source_repo.__branch_delete(branch_name=temp_branch_name, force=True)
  510. return cloned_repo
  511. @staticmethod
  512. def clone_if_needed(source: str,
  513. dest: str,
  514. branch: str = None,
  515. single_branch: bool = False,
  516. recurse_submodules: bool = False,
  517. shallow_submodules: bool = False) -> "GitRepo":
  518. """
  519. Clone from the given source if a git repository does not exist.
  520. Parameters:
  521. source (str): Source path of the repository to clone
  522. Can be local or a url.
  523. dest (str): Destination path for the clone
  524. branch (str): Specific branch to clone (optional)
  525. single_branch (bool): Only clone a single branch (default is False)
  526. recurse_submodules (bool): Recurse into submodules (default is
  527. False)
  528. shallow_submodules (bool): Any cloned submodules will be shallow
  529. Returns:
  530. GitRepo: the GitRepo object for the local git repository.
  531. Raises:
  532. NonGitDirectoryError: If a directory exists at `dest` but is not
  533. a git repository.
  534. """
  535. repo: GitRepo
  536. try:
  537. repo = GitRepo(
  538. local_path=dest,
  539. )
  540. except FileNotFoundError:
  541. logger.info(f"No git repo found at {dest}. Creating.")
  542. # create parent directory if not present
  543. p = Path(dest)
  544. Path.mkdir(p.parent, parents=True, exist_ok=True)
  545. # clone
  546. repo = GitRepo.clone(
  547. source=source,
  548. dest=dest,
  549. branch=branch,
  550. single_branch=single_branch,
  551. recurse_submodules=recurse_submodules,
  552. shallow_submodules=shallow_submodules,
  553. )
  554. return repo