core.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. import logging
  2. import subprocess
  3. from threading import RLock
  4. from . import utils
  5. from . import exceptions as ex
  6. logger = logging.getLogger(__name__)
  7. class GitRepo:
  8. """
  9. Class to handle Git operations in a local Git repository.
  10. """
  11. __checkout_locks = dict()
  12. def __init__(self, local_path: str) -> None:
  13. """
  14. Initialize GitRepo with the path to the local Git repository.
  15. Parameters:
  16. local_path (str): Path to the Git repository
  17. """
  18. self.__set_local_path(local_path=local_path)
  19. self.__register_lock()
  20. logger.info(f"GitRepo initialised for {local_path}")
  21. def __eq__(self, other) -> bool:
  22. """
  23. Check if the instance is equal to the 'other' instance
  24. Parameters:
  25. other: the other Object to check
  26. """
  27. if type(other) is type(self):
  28. # return True if the paths of the repositories the objects
  29. # point to are equal, otherwise False
  30. return self.__local_path == other.__local_path
  31. return False
  32. def __hash__(self) -> int:
  33. """
  34. Return the hash value of the instance
  35. """
  36. return hash(self.__local_path)
  37. def __register_lock(self) -> None:
  38. """
  39. Initialize an RLock object for the instance in a shared dictionary
  40. """
  41. if not GitRepo.__checkout_locks.get(self):
  42. # create a Lock object for the instance, if is not already created
  43. GitRepo.__checkout_locks[self] = RLock()
  44. return
  45. def __set_local_path(self, local_path: str) -> None:
  46. """
  47. Set the path for the repository, ensuring it is a valid Git repo.
  48. Parameters:
  49. local_path (str): Path to the Git repository
  50. Raises:
  51. NonGitDirectoryError: If the directory is not a valid
  52. Git repository
  53. """
  54. if not utils.is_git_repo(local_path):
  55. raise ex.NonGitDirectoryError(directory=local_path)
  56. self.__local_path = local_path
  57. def get_local_path(self) -> str:
  58. """
  59. Return the local path of the repository.
  60. Returns:
  61. str: Path to the Git repository
  62. """
  63. return self.__local_path
  64. def get_checkout_lock(self) -> RLock:
  65. """
  66. Return the checkout lock object associated with the instance
  67. Returns:
  68. RLock: The lock object associated with the instance
  69. """
  70. lock = GitRepo.__checkout_locks.get(self)
  71. if lock is None:
  72. raise ex.LockNotInitializedError(
  73. lock_name="checkout_lock",
  74. path=self.__local_path
  75. )
  76. return lock
  77. def __checkout(self, commit_ref: str, force: bool = False) -> None:
  78. """
  79. Check out a specific commit.
  80. Parameters:
  81. commit_ref (str): Commit reference to check out
  82. force (bool): Force checkout (default is False)
  83. Raises:
  84. ValueError: If commit_ref is None
  85. """
  86. if commit_ref is None:
  87. raise ValueError("commit_ref is required, cannot be None.")
  88. cmd = ['git', 'checkout', commit_ref]
  89. if force:
  90. cmd.append('-f')
  91. logger.debug(f"Running {' '.join(cmd)}")
  92. logger.debug("Attempting to aquire checkout lock.")
  93. with self.get_checkout_lock():
  94. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  95. def __reset(self, commit_ref: str, hard: bool = False) -> None:
  96. """
  97. Reset to a specific commit.
  98. Parameters:
  99. commit_ref (str): Commit reference to reset to
  100. hard (bool): Use hard reset (default is False)
  101. Raises:
  102. ValueError: If commit_ref is None
  103. """
  104. if commit_ref is None:
  105. raise ValueError("commit_ref is required, cannot be None.")
  106. cmd = ['git', 'reset', commit_ref]
  107. if hard:
  108. cmd.append('--hard')
  109. logger.debug(f"Running {' '.join(cmd)}")
  110. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  111. def __force_recursive_clean(self) -> None:
  112. """
  113. Forcefully clean the working directory,
  114. removing untracked files and directories.
  115. """
  116. cmd = ['git', 'clean', '-xdff']
  117. logger.debug(f"Running {' '.join(cmd)}")
  118. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  119. def __remote_list(self) -> list[str]:
  120. """
  121. Retrieve a list of remotes added to the repository
  122. Returns:
  123. list[str]: List of remote names
  124. """
  125. cmd = ['git', 'remote']
  126. logger.debug(f"Running {' '.join(cmd)}")
  127. ret = subprocess.run(
  128. cmd, cwd=self.__local_path, shell=False, capture_output=True,
  129. encoding='utf-8', check=True
  130. )
  131. return ret.stdout.split('\n')[:-1]
  132. def __is_commit_present_locally(self, commit_ref: str) -> bool:
  133. """
  134. Check if a specific commit exists locally.
  135. Parameters:
  136. commit_ref (str): Commit hash to check
  137. Returns:
  138. bool: True if the commit exists locally, False otherwise
  139. """
  140. if commit_ref is None:
  141. raise ValueError("commit_ref is required, cannot be None.")
  142. cmd = ['git', 'diff-tree', commit_ref, '--no-commit-id', '--no-patch']
  143. logger.debug(f"Running {' '.join(cmd)}")
  144. ret = subprocess.run(cmd, cwd=self.__local_path, shell=False)
  145. return ret.returncode == 0
  146. def remote_set_url(self, remote: str, url: str) -> None:
  147. """
  148. Set the URL for a specific remote.
  149. Parameters:
  150. remote (str): Name of the remote
  151. url (str): URL to set for the remote
  152. Raises:
  153. ValueError: If remote or URL is None
  154. """
  155. if remote is None:
  156. raise ValueError("remote is required, cannot be None.")
  157. if url is None:
  158. raise ValueError("url is required, cannot be None.")
  159. cmd = ['git', 'remote', 'set-url', remote, url]
  160. logger.debug(f"Running {' '.join(cmd)}")
  161. subprocess.run(cmd, cwd=self.__local_path, check=True)
  162. def remote_get_url(self, remote: str) -> str:
  163. """
  164. Get the URL for a specific remote.
  165. Parameters:
  166. remote (str): Name of the remote
  167. Returns:
  168. str: The URL associated with the remote
  169. Raises:
  170. ValueError: If remote is None
  171. """
  172. if remote is None:
  173. raise ValueError("remote is required, cannot be None.")
  174. cmd = ['git', 'remote', 'get-url', remote]
  175. logger.debug(f"Running {' '.join(cmd)}")
  176. # Capture the output of the command
  177. result = subprocess.run(
  178. cmd,
  179. cwd=self.__local_path,
  180. check=True,
  181. capture_output=True,
  182. text=True
  183. )
  184. # Return the URL from the output
  185. return result.stdout.strip()
  186. def fetch_remote(self, remote: str, force: bool = False,
  187. tags: bool = False, recurse_submodules: bool = False,
  188. refetch: bool = False) -> None:
  189. """
  190. Fetch updates from a remote repository.
  191. Parameters:
  192. remote (str): Remote to fetch from; if None, fetches all
  193. force (bool): Force fetch (default is False)
  194. tags (bool): Fetch tags (default is False)
  195. recurse_submodules (bool): Recurse into submodules
  196. (default is False)
  197. refetch (bool): Re-fetch all objects (default is False)
  198. """
  199. cmd = ['git', 'fetch']
  200. if remote:
  201. cmd.append(remote)
  202. else:
  203. logger.info("fetch_remote: remote is None, fetching all remotes")
  204. cmd.append('--all')
  205. if force:
  206. cmd.append('--force')
  207. if tags:
  208. cmd.append('--tags')
  209. if refetch:
  210. cmd.append('--refetch')
  211. if recurse_submodules:
  212. cmd.append('--recurse-submodules')
  213. else:
  214. cmd.append('--no-recurse-submodules')
  215. logger.debug(f"Running {' '.join(cmd)}")
  216. subprocess.run(cmd, cwd=self.__local_path, shell=False)
  217. def __branch_create(self, branch_name: str,
  218. start_point: str = None) -> None:
  219. """
  220. Create a new branch starting from a given commit.
  221. Parameters:
  222. branch_name (str): Name of the branch to create
  223. start_point (str): Starting commit or branch (optional)
  224. """
  225. if branch_name is None:
  226. raise ValueError("branch_name is required, cannot be None.")
  227. cmd = ['git', 'branch', branch_name]
  228. if start_point:
  229. if not self.__is_commit_present_locally(commit_ref=start_point):
  230. raise ex.CommitNotFoundError(commit_ref=start_point)
  231. cmd.append(start_point)
  232. logger.debug(f"Running {' '.join(cmd)}")
  233. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  234. def __branch_delete(self, branch_name: str, force: bool = False) -> None:
  235. """
  236. Delete a local branch.
  237. Parameters:
  238. branch_name (str): Name of the branch to delete
  239. force (bool): Force delete (default is False)
  240. """
  241. if branch_name is None:
  242. raise ValueError("branch_name is required, cannot be None.")
  243. if not self.__is_commit_present_locally(commit_ref=branch_name):
  244. raise ex.CommitNotFoundError(commit_ref=branch_name)
  245. cmd = ['git', 'branch', '-d', branch_name]
  246. if force:
  247. cmd.append('--force')
  248. logger.debug(f"Running {' '.join(cmd)}")
  249. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  250. def commit_id_for_remote_ref(self, remote: str,
  251. commit_ref: str) -> str:
  252. """
  253. Get the commit ID for a specific commit reference from a remote.
  254. Parameters:
  255. remote (str): Name of the remote
  256. commit_ref (str): Reference to get the commit ID for
  257. Returns:
  258. str | None: Commit ID if found, None otherwise
  259. """
  260. if remote is None:
  261. raise ValueError("remote is required, cannot be None.")
  262. if remote not in self.__remote_list():
  263. raise ex.RemoteNotFoundError(remote=remote)
  264. if commit_ref is None:
  265. raise ValueError("commit_ref is required, cannot be None.")
  266. if utils.is_valid_hex_string(test_str=commit_ref):
  267. # skip conversion if commit_ref is already hex string
  268. return commit_ref
  269. # allow branches and tags only for now
  270. allowed_ref_types = ['tags', 'heads']
  271. split_ref = commit_ref.split('/', 2)
  272. if len(split_ref) != 3 or split_ref[0] != 'refs':
  273. raise ValueError(f"commit_ref '{commit_ref}' format is invalid.")
  274. _, ref_type, _ = split_ref
  275. if ref_type not in allowed_ref_types:
  276. raise ValueError(f"ref_type '{ref_type}' is not supported.")
  277. cmd = ['git', 'ls-remote', remote]
  278. logger.debug(f"Running {' '.join(cmd)}")
  279. ret = subprocess.run(
  280. cmd, cwd=self.__local_path, encoding='utf-8', capture_output=True,
  281. shell=False, check=True
  282. )
  283. for line in ret.stdout.split('\n')[:-1]:
  284. (commit_id, res_ref) = line.split('\t')
  285. if res_ref == commit_ref:
  286. return commit_id
  287. return None
  288. def __ensure_commit_fetched(self, remote: str, commit_id: str) -> None:
  289. """
  290. Ensure a specific commit is fetched from the remote repository.
  291. Parameters:
  292. remote (str): Remote name to fetch from
  293. commit_id (str): Commit ID to ensure it is available locally
  294. Raises:
  295. RemoteNotFoundError: If the specified remote does not exist
  296. CommitNotFoundError: If the commit cannot be fetched after
  297. multiple attempts
  298. """
  299. if remote is None:
  300. raise ValueError("remote is required, cannot be None.")
  301. if remote not in self.__remote_list():
  302. raise ex.RemoteNotFoundError(remote=remote)
  303. if commit_id is None:
  304. raise ValueError("commit_id is required, cannot be None.")
  305. if not utils.is_valid_hex_string(test_str=commit_id):
  306. raise ValueError(
  307. f"commit_id should be a hex string, got '{commit_id}'."
  308. )
  309. if self.__is_commit_present_locally(commit_ref=commit_id):
  310. # early return if commit is already fetched
  311. return
  312. self.fetch_remote(remote=remote, force=True, tags=True)
  313. # retry fetch with refetch option if the commit is still not found
  314. if not self.__is_commit_present_locally(commit_ref=commit_id):
  315. self.fetch_remote(
  316. remote=remote, force=True, tags=True, refetch=True
  317. )
  318. if not self.__is_commit_present_locally(commit_ref=commit_id):
  319. raise ex.CommitNotFoundError(commit_ref=commit_id)
  320. def checkout_remote_commit_ref(self, remote: str,
  321. commit_ref: str,
  322. force: bool = False,
  323. hard_reset: bool = False,
  324. clean_working_tree: bool = False) -> None:
  325. """
  326. Check out a specific commit from a remote repository.
  327. Parameters:
  328. remote (str): Remote name to check out from
  329. commit_ref (str): Commit reference to check out
  330. force (bool): Force the checkout (default is False)
  331. hard_reset (bool): Hard reset after checkout (default is False)
  332. clean_working_tree (bool): Clean untracked files after checkout
  333. (default is False)
  334. Raises:
  335. RemoteNotFoundError: If the specified remote does not exist
  336. CommitNotFoundError: If the specified commit cannot be found
  337. """
  338. if remote is None:
  339. logger.error("remote cannot be None for checkout to remote commit")
  340. raise ValueError("remote is required, cannot be None.")
  341. if remote not in self.__remote_list():
  342. raise ex.RemoteNotFoundError(remote=remote)
  343. if commit_ref is None:
  344. raise ValueError("commit_ref is required, cannot be None.")
  345. # retrieve the commit ID for the specified commit reference
  346. commit_id = self.commit_id_for_remote_ref(
  347. remote=remote, commit_ref=commit_ref
  348. )
  349. # ensure the commit is fetched from the remote repository
  350. self.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  351. # perform checkout on the specified commit using the commit ID
  352. # commit ID is used in place of branch name or tag name to make sure
  353. # do not check out the branch or tag from wrong remote
  354. self.__checkout(commit_ref=commit_id, force=force)
  355. # optional hard reset and clean of working tree after checkout
  356. if hard_reset:
  357. self.__reset(commit_ref=commit_id, hard=True)
  358. if clean_working_tree:
  359. self.__force_recursive_clean()
  360. def submodule_update(self, init: bool = False, recursive: bool = False,
  361. force: bool = False) -> None:
  362. """
  363. Update Git submodules for the repository.
  364. Parameters:
  365. init (bool): Initialize submodules if they are not initialized
  366. (default is False)
  367. recursive (bool): Update submodules recursively (default is False)
  368. force (bool): Force update even if there are changes
  369. (default is False)
  370. """
  371. cmd = ['git', 'submodule', 'update']
  372. if init:
  373. cmd.append('--init')
  374. if recursive:
  375. cmd.append('--recursive')
  376. if force:
  377. cmd.append('--force')
  378. logger.debug(f"Running {' '.join(cmd)}")
  379. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  380. def remote_add(self, remote: str, url: str) -> None:
  381. """
  382. Add a new remote to the Git repository.
  383. Parameters:
  384. remote (str): Name of the remote to add
  385. url (str): URL for the remote repository
  386. Raises:
  387. DuplicateRemoteError: If remote already exists and
  388. overwrite is not allowed
  389. """
  390. if remote is None:
  391. raise ValueError("remote is required, cannot be None.")
  392. if url is None:
  393. raise ValueError("url is required, cannot be None.")
  394. # Set the URL if the remote exists and overwrite is allowed
  395. if remote in self.__remote_list():
  396. raise ex.DuplicateRemoteError(remote)
  397. # Add the new remote
  398. cmd = ['git', 'remote', 'add', remote, url]
  399. logger.debug(f"Running {' '.join(cmd)}")
  400. subprocess.run(cmd, cwd=self.__local_path, shell=False, check=True)
  401. def remote_add_bulk(self, remotes: tuple, force: bool = False) -> None:
  402. """
  403. Add multiple remotes to the Git repository at once.
  404. Parameters:
  405. remotes (tuple): Tuple of tuples containing remote name
  406. and url.
  407. E.g. (
  408. ('remote1', 'https://remote1_url'),
  409. ('remote2', 'https://remote2_url'),
  410. )
  411. force (bool): Force update the url if remote already exists.
  412. Raises:
  413. DuplicateRemoteError: If remote already exists and
  414. overwrite is not allowed.
  415. """
  416. logger.info(f"Remotes to add: {remotes}.")
  417. for (remote, url) in remotes:
  418. try:
  419. self.remote_add(remote=remote, url=url)
  420. except ex.DuplicateRemoteError:
  421. if not force:
  422. raise
  423. logger.info(f"Remote {remote} already exists. Updating url.")
  424. self.remote_set_url(remote=remote, url=url)
  425. logger.info(f"Remote {remote} added to repo with url {url}.")
  426. @staticmethod
  427. def clone(source: str,
  428. dest: str,
  429. branch: str = None,
  430. single_branch: bool = False,
  431. recurse_submodules: bool = False,
  432. shallow_submodules: bool = False) -> "GitRepo":
  433. """
  434. Clone a Git repository.
  435. Parameters:
  436. source (str): Source path of the repository to clone
  437. Can be local or a url.
  438. dest (str): Destination path for the clone
  439. branch (str): Specific branch to clone (optional)
  440. single_branch (bool): Only clone a single branch (default is False)
  441. recurse_submodules (bool): Recurse into submodules
  442. (default is False)
  443. shallow_submodules (bool): any cloned submodules will be shallow
  444. Returns:
  445. GitRepo: the cloned git repository
  446. """
  447. cmd = ['git', 'clone', source, dest]
  448. if branch:
  449. cmd.append('--branch=' + branch)
  450. if single_branch:
  451. cmd.append('--single-branch')
  452. if recurse_submodules:
  453. cmd.append('--recurse-submodules')
  454. if shallow_submodules:
  455. cmd.append('--shallow-submodules')
  456. logger.debug(f"Running {' '.join(cmd)}")
  457. subprocess.run(cmd, shell=False, check=True)
  458. return GitRepo(local_path=dest)
  459. @staticmethod
  460. def shallow_clone_at_commit_from_local(source: str,
  461. remote: str,
  462. commit_ref: str,
  463. dest: str) -> "GitRepo":
  464. """
  465. Perform a shallow clone of a repository at a specific commit.
  466. Parameters:
  467. source (str): Source path of the local repository
  468. remote (str): Remote name containing the commit
  469. commit_ref (str): Commit reference to clone
  470. dest (str): Destination path for the clone
  471. Returns:
  472. GitRepo: the cloned git repository
  473. Raises:
  474. RemoteNotFoundError: If the specified remote does not exist
  475. CommitNotFoundError: If the specified commit cannot be found
  476. """
  477. if remote is None:
  478. raise ValueError("remote is required, cannot be None.")
  479. if commit_ref is None:
  480. raise ValueError("commit_ref is required, cannot be None.")
  481. source_repo = GitRepo(local_path=source)
  482. # get the commit ID for the specified remote reference
  483. commit_id = source_repo.commit_id_for_remote_ref(
  484. remote=remote, commit_ref=commit_ref
  485. )
  486. source_repo.__ensure_commit_fetched(remote=remote, commit_id=commit_id)
  487. # create a temporary branch to point to the specified commit
  488. # as shallow clone needs a branch
  489. temp_branch_name = "temp-b-" + commit_id
  490. source_repo.__branch_create(
  491. branch_name=temp_branch_name, start_point=commit_id
  492. )
  493. # perform the clone from the source repository
  494. # using the temporary branch
  495. cloned_repo = GitRepo.clone(
  496. source=source,
  497. dest=dest,
  498. branch=temp_branch_name,
  499. single_branch=True,
  500. recurse_submodules=True,
  501. shallow_submodules=True
  502. )
  503. # add the remote containing the commit in cloned repo for reference
  504. url = source_repo.remote_get_url(remote=remote)
  505. cloned_repo.remote_add(remote=remote, url=url)
  506. # delete the temporary branch in source repository
  507. # after the clone operation
  508. source_repo.__branch_delete(branch_name=temp_branch_name, force=True)
  509. return cloned_repo