parse_prompt.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """
  2. Copyright © 2022 Noah Vogt <noah@noahvogt.com>
  3. This program is free software: you can redistribute it and/or modify
  4. it under the terms of the GNU General Public License as published by
  5. the Free Software Foundation, either version 3 of the License, or
  6. (at your option) any later version.
  7. This program is distributed in the hope that it will be useful,
  8. but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. GNU General Public License for more details.
  11. You should have received a copy of the GNU General Public License
  12. along with this program. If not, see <http://www.gnu.org/licenses/>.
  13. """
  14. from re import match
  15. from utils import (
  16. log,
  17. structure_as_list,
  18. get_unique_structure_elements,
  19. )
  20. def parse_prompt_input(slidegen) -> list:
  21. full_structure_list = structure_as_list(slidegen.metadata["structure"])
  22. calculated_prompt = generate_final_prompt(
  23. str(slidegen.chosen_structure), slidegen.metadata["structure"]
  24. )
  25. log(
  26. "chosen structure: {}".format(calculated_prompt),
  27. color="cyan",
  28. )
  29. return structure_as_list(calculated_prompt)
  30. def generate_final_prompt(structure_prompt_answer, full_song_structure) -> str:
  31. valid_prompt, calculated_prompt = is_and_give_prompt_input_valid(
  32. structure_prompt_answer, full_song_structure
  33. )
  34. if not valid_prompt:
  35. log(
  36. "warning: prompt input '{}' is invalid, defaulting to full song structure".format(
  37. structure_prompt_answer
  38. ),
  39. color="yellow",
  40. )
  41. calculated_prompt = full_song_structure
  42. return calculated_prompt
  43. def is_and_give_prompt_input_valid(
  44. prompt: str, full_structure: list
  45. ) -> tuple[bool, str]:
  46. if not match(
  47. r"^(([0-9]+|R)|[0-9]+-[0-9]+)(,(([0-9]+|R)|[0-9]+-[0-9]+))*$", prompt
  48. ):
  49. return False, ""
  50. allowed_elements = get_unique_structure_elements(full_structure)
  51. test_elements = prompt.split(",")
  52. print("test elemets before loops: {}".format(test_elements))
  53. for index, element in enumerate(test_elements):
  54. if "-" in element:
  55. splitted_dashpart = element.split("-")
  56. if splitted_dashpart[0] >= splitted_dashpart[1]:
  57. return False, ""
  58. if splitted_dashpart[0] not in allowed_elements:
  59. return False, ""
  60. if splitted_dashpart[1] not in allowed_elements:
  61. return False, ""
  62. dotted_part = calculate_dashed_prompt_part(
  63. full_structure, splitted_dashpart[0], splitted_dashpart[1]
  64. )
  65. dotted_part.reverse()
  66. test_elements[index] = dotted_part[0]
  67. for left_over_dotted_part_element in dotted_part[1:]:
  68. test_elements.insert(index, left_over_dotted_part_element)
  69. else:
  70. if element not in allowed_elements:
  71. return False, ""
  72. return True, ",".join(test_elements)
  73. def calculate_dashed_prompt_part(
  74. content: list, start_verse: str, end_verse: str
  75. ) -> list:
  76. content = list(content)
  77. for i in content:
  78. if i == ",":
  79. content.remove(i)
  80. start_index = content.index(start_verse)
  81. if start_index != 0:
  82. if content[0] == "R" and content[start_index - 1] == "R":
  83. start_index -= 1
  84. end_index = content.index(end_verse)
  85. if end_index != len(content) - 1:
  86. if content[-1] == "R" and content[end_index + 1] == "R":
  87. end_index += 1
  88. return content[start_index : end_index + 1]