parse_prompt.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. """
  2. Copyright © 2022 Noah Vogt <noah@noahvogt.com>
  3. This program is free software: you can redistribute it and/or modify
  4. it under the terms of the GNU General Public License as published by
  5. the Free Software Foundation, either version 3 of the License, or
  6. (at your option) any later version.
  7. This program is distributed in the hope that it will be useful,
  8. but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. GNU General Public License for more details.
  11. You should have received a copy of the GNU General Public License
  12. along with this program. If not, see <http://www.gnu.org/licenses/>.
  13. """
  14. from re import match
  15. from utils import (
  16. log,
  17. structure_as_list,
  18. get_unique_structure_elements,
  19. )
  20. def parse_prompt_input(slidegen) -> list:
  21. calculated_prompt = generate_final_prompt(
  22. str(slidegen.chosen_structure), slidegen.metadata["structure"]
  23. )
  24. log(
  25. "chosen structure: {}".format(calculated_prompt),
  26. color="cyan",
  27. )
  28. return structure_as_list(calculated_prompt)
  29. def generate_final_prompt(structure_prompt_answer, full_song_structure) -> str:
  30. valid_prompt, calculated_prompt = is_and_give_prompt_input_valid(
  31. structure_prompt_answer, full_song_structure
  32. )
  33. if not valid_prompt:
  34. log(
  35. "warning: prompt input '{}' is invalid, defaulting to full".format(
  36. structure_prompt_answer
  37. )
  38. + " song structure...",
  39. color="cyan",
  40. )
  41. calculated_prompt = full_song_structure
  42. return calculated_prompt
  43. def is_and_give_prompt_input_valid(
  44. prompt: str, full_structure: list
  45. ) -> tuple[bool, str]:
  46. if not match(
  47. r"^(([0-9]+|R)|[0-9]+-[0-9]+)(,(([0-9]+|R)|[0-9]+-[0-9]+))*$", prompt
  48. ):
  49. return False, ""
  50. allowed_elements = get_unique_structure_elements(full_structure)
  51. test_elements = prompt.split(",")
  52. for index, element in enumerate(test_elements):
  53. if "-" in element:
  54. splitted_dashpart = element.split("-")
  55. if splitted_dashpart[0] >= splitted_dashpart[1]:
  56. return False, ""
  57. if splitted_dashpart[0] not in allowed_elements:
  58. return False, ""
  59. if splitted_dashpart[1] not in allowed_elements:
  60. return False, ""
  61. dotted_part = calculate_dashed_prompt_part(
  62. full_structure, splitted_dashpart[0], splitted_dashpart[1]
  63. )
  64. dotted_part.reverse()
  65. test_elements[index] = dotted_part[0]
  66. for left_over_dotted_part_element in dotted_part[1:]:
  67. test_elements.insert(index, left_over_dotted_part_element)
  68. else:
  69. if element not in allowed_elements:
  70. return False, ""
  71. return True, ",".join(test_elements)
  72. def calculate_dashed_prompt_part(
  73. content: list, start_verse: str, end_verse: str
  74. ) -> list:
  75. content = list(content)
  76. for i in content:
  77. if i == ",":
  78. content.remove(i)
  79. start_index = content.index(start_verse)
  80. if start_index != 0:
  81. if content[0] == "R" and content[start_index - 1] == "R":
  82. start_index -= 1
  83. end_index = content.index(end_verse)
  84. if end_index != len(content) - 1:
  85. if content[-1] == "R" and content[end_index + 1] == "R":
  86. end_index += 1
  87. return content[start_index : end_index + 1]