parse_prompt.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright © 2024 Noah Vogt <noah@noahvogt.com>
  2. # This program is free software: you can redistribute it and/or modify
  3. # it under the terms of the GNU General Public License as published by
  4. # the Free Software Foundation, either version 3 of the License, or
  5. # (at your option) any later version.
  6. # This program is distributed in the hope that it will be useful,
  7. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  8. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  9. # GNU General Public License for more details.
  10. # You should have received a copy of the GNU General Public License
  11. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  12. from re import match
  13. from utils import (
  14. log,
  15. structure_as_list,
  16. get_unique_structure_elements,
  17. )
  18. def parse_prompt_input(slidegen) -> list:
  19. calculated_prompt = generate_final_prompt(
  20. str(slidegen.chosen_structure), slidegen.metadata["structure"]
  21. )
  22. log(
  23. "chosen structure: {}".format(calculated_prompt),
  24. color="cyan",
  25. )
  26. return structure_as_list(calculated_prompt)
  27. def generate_final_prompt(structure_prompt_answer, full_song_structure) -> str:
  28. valid_prompt, calculated_prompt = is_and_give_prompt_input_valid(
  29. structure_prompt_answer, full_song_structure
  30. )
  31. if not valid_prompt:
  32. log(
  33. "warning: prompt input '{}' is invalid, defaulting to full".format(
  34. structure_prompt_answer
  35. )
  36. + " song structure...",
  37. color="cyan",
  38. )
  39. calculated_prompt = full_song_structure
  40. return calculated_prompt
  41. def is_and_give_prompt_input_valid(
  42. prompt: str, full_structure: list
  43. ) -> tuple[bool, str]:
  44. if not match(
  45. r"^(([0-9]+|R)|[0-9]+-[0-9]+)(,(([0-9]+|R)|[0-9]+-[0-9]+))*$", prompt
  46. ):
  47. return False, ""
  48. allowed_elements = get_unique_structure_elements(full_structure)
  49. test_elements = prompt.split(",")
  50. for index, element in enumerate(test_elements):
  51. if "-" in element:
  52. splitted_dashpart = element.split("-")
  53. if splitted_dashpart[0] >= splitted_dashpart[1]:
  54. return False, ""
  55. if splitted_dashpart[0] not in allowed_elements:
  56. return False, ""
  57. if splitted_dashpart[1] not in allowed_elements:
  58. return False, ""
  59. dotted_part = calculate_dashed_prompt_part(
  60. full_structure, splitted_dashpart[0], splitted_dashpart[1]
  61. )
  62. dotted_part.reverse()
  63. test_elements[index] = dotted_part[0]
  64. for left_over_dotted_part_element in dotted_part[1:]:
  65. test_elements.insert(index, left_over_dotted_part_element)
  66. else:
  67. if element not in allowed_elements:
  68. return False, ""
  69. return True, ",".join(test_elements)
  70. def calculate_dashed_prompt_part(
  71. content: list, start_verse: str, end_verse: str
  72. ) -> list:
  73. content = list(content)
  74. for i in content:
  75. if i == ",":
  76. content.remove(i)
  77. start_index = content.index(start_verse)
  78. if start_index != 0:
  79. if content[0] == "R" and content[start_index - 1] == "R":
  80. start_index -= 1
  81. end_index = content.index(end_verse)
  82. if end_index != len(content) - 1:
  83. if content[-1] == "R" and content[end_index + 1] == "R":
  84. end_index += 1
  85. return content[start_index : end_index + 1]