reels-app/scripts/find_chorus.py

308 lines
11 KiB
Python

#!/usr/bin/env python3
"""
find_chorus.py — automatic chorus detection in a music video.

Hybrid approach:
1. Whisper transcribes the song (word-level timestamps are requested, the
   analysis itself works on segment-level lines).
2. Repeated lines are found with Jaccard similarity over word bigrams.
3. Energy analysis via FFmpeg (RMS dB) — the chorus is usually louder.
4. Combine: repeated text + high energy = chorus.

Output: the best candidates, ranked (as JSON with --json).

Example:
    python3 find_chorus.py pesem.mp4
    python3 find_chorus.py pesem.mp4 --duration 30 --json
"""
import argparse
import json
import math
import re
import subprocess
import sys
import tempfile
from collections import Counter
from pathlib import Path
def extract_audio(video, sample_rate=16000):
    """Extract the audio track of ``video`` into a temporary mono WAV file.

    The audio is written as 16-bit PCM at ``sample_rate`` Hz by running the
    ``ffmpeg`` executable. The result is meant for both the Whisper
    transcription and the energy analysis.

    Args:
        video: Path of the input media file (anything ``str()`` can render).
        sample_rate: Sampling rate of the produced WAV file in Hz.

    Returns:
        The path of the created WAV file as a string. The caller owns the
        file and is responsible for deleting it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        OSError: If the ffmpeg executable cannot be started.
    """
    # delete=False: only a reserved file name is needed; ffmpeg (run with -y)
    # overwrites the empty placeholder.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    cmd = [
        "ffmpeg", "-y", "-i", str(video),
        "-vn", "-ac", "1", "-ar", str(sample_rate),
        "-c:a", "pcm_s16le", tmp.name,
    ]
    try:
        subprocess.run(cmd, check=True, stderr=subprocess.DEVNULL)
    except BaseException:
        # Do not leak the placeholder (or a partially written file) when the
        # extraction fails or is interrupted; the caller never sees its name.
        Path(tmp.name).unlink(missing_ok=True)
        raise
    return tmp.name
def transcribe(audio_path, lang=None, model_size="small"):
    """Transcribe ``audio_path`` with faster-whisper.

    Args:
        audio_path: Path of the audio file handed to the model.
        lang: Language code passed to the model; ``None`` lets the model
            detect the language.
        model_size: Name of the Whisper model to load.

    Returns:
        A ``(lines, language)`` tuple. ``lines`` holds one dict per segment
        with non-empty text (keys ``start``, ``end``, ``text``, ``duration``);
        ``language`` is the language reported by the model.
    """
    # Imported here rather than at module level, so importing this module
    # does not require faster-whisper.
    from faster_whisper import WhisperModel

    print(f"🧠 Whisper: {model_size}, lang={lang or 'auto'}", file=sys.stderr)
    whisper = WhisperModel(model_size, device="cpu", compute_type="int8")
    segments, info = whisper.transcribe(
        audio_path,
        language=lang,
        word_timestamps=True,
        vad_filter=True,
    )
    print(f" Detekcija: {info.language} (p={info.language_probability:.2f})", file=sys.stderr)

    # Collect line-level segments together with their timestamps.
    lines = []
    for segment in segments:
        stripped = segment.text.strip()
        if not stripped:
            continue
        lines.append({
            "start": segment.start,
            "end": segment.end,
            "text": stripped,
            "duration": segment.end - segment.start,
        })
    return lines, info.language
def normalize_text(s):
    """Normalize text for comparison.

    Lower-cases ``s``, removes every character that is neither a word
    character nor whitespace, collapses whitespace runs to a single space and
    strips the ends.
    """
    without_punctuation = re.sub(r"[^\w\s]", "", s.lower())
    return re.sub(r"\s+", " ", without_punctuation).strip()
def line_similarity(a, b):
    """Return the Jaccard similarity of the word bigrams of two lines.

    Both texts are normalized with ``normalize_text`` and split into words.
    A one-word line is represented by a single 1-tuple instead of bigrams.
    The result is ``0.0`` when either line has no words, otherwise
    ``|A ∩ B| / |A ∪ B|`` over the two shingle sets.
    """
    def shingles(text):
        tokens = normalize_text(text).split()
        if len(tokens) > 1:
            return set(zip(tokens, tokens[1:]))
        # Zero tokens -> empty set, one token -> {(token,)}.
        return {(token,) for token in tokens}

    left = shingles(a)
    right = shingles(b)
    if not left or not right:
        return 0.0
    shared = left & right
    combined = left | right
    return len(shared) / len(combined)
def find_repeated_lines(lines, similarity_threshold=0.5):
    """Group transcript lines that repeat.

    Each not-yet-assigned line acts as a seed; every later unassigned line
    whose ``line_similarity`` to the seed reaches ``similarity_threshold``
    joins the seed's group.

    Args:
        lines: Sequence of dicts that carry a ``"text"`` entry.
        similarity_threshold: Minimum similarity for two lines to be grouped.

    Returns:
        A list of clusters; each cluster is a list of indices into ``lines``
        and contains at least two entries.
    """
    total = len(lines)
    assigned = set()
    groups = []
    for seed in range(total):
        if seed in assigned:
            continue
        assigned.add(seed)
        matches = [
            other
            for other in range(seed + 1, total)
            if other not in assigned
            and line_similarity(lines[seed]["text"], lines[other]["text"]) >= similarity_threshold
        ]
        assigned.update(matches)
        # Lines without any repetition are not reported.
        if matches:
            groups.append([seed] + matches)
    return groups
def _parse_astats_metadata(output, window_sec, silence_floor_db):
    """Turn ``ametadata=print`` text into a list of ``(timestamp, rms_db)``."""
    energies = []
    current_pts = None
    for raw_line in output.split("\n"):
        line = raw_line.strip()

        # Frame header: "frame:N pts:X pts_time:Y"
        pts_match = re.search(r"pts_time:(\S+)", line)
        if pts_match:
            try:
                current_pts = float(pts_match.group(1))
            except ValueError:
                pass  # keep the previous / synthesized timestamp
            continue

        # Value line: "lavfi.astats.Overall.RMS_level=-15.123"
        if "RMS_level=" not in line:
            continue
        try:
            rms = float(line.split("RMS_level=")[-1].strip())
        except ValueError:
            continue

        # Without a timestamp, synthesize one from the sample order.
        if current_pts is None:
            current_pts = len(energies) * window_sec

        # Digital silence is reported as "-inf"; clamp it so that averages
        # over the samples stay finite. NaN fails both tests and is dropped.
        if rms < silence_floor_db:
            rms = silence_floor_db
        if math.isfinite(rms):
            energies.append((current_pts, rms))

        # Advance for the next sample in case no further pts_time is printed.
        current_pts += window_sec
    return energies


def compute_energy(audio_path, window_sec=1.0, sample_rate=16000, silence_floor_db=-96.0):
    """Measure the RMS level of ``audio_path`` in consecutive windows.

    Runs ffmpeg with ``asetnsamples`` (to cut the stream into frames of one
    window each), ``astats`` (reset after every frame, so each value covers
    exactly one window) and ``ametadata=print`` (which writes the values to
    stdout via ``file=-``).

    Args:
        audio_path: Path of the audio file to analyse.
        window_sec: Window length in seconds.
        sample_rate: Sampling rate of ``audio_path`` in Hz; used to convert
            ``window_sec`` into a sample count. The default matches the
            default of ``extract_audio``.
        silence_floor_db: Lowest level reported; quieter windows (including
            ``-inf`` for digital silence) are clamped to this value.

    Returns:
        A list of ``(timestamp_seconds, rms_db)`` tuples, one per window. The
        list is empty when ffmpeg produced no measurements (its exit status
        is not checked).

    Raises:
        OSError: If the ffmpeg executable cannot be started.
    """
    samples_per_window = max(1, int(sample_rate * window_sec))
    audio_filter = (
        f"asetnsamples=n={samples_per_window}:p=0,"
        # reset counts frames, not seconds: 1 = fresh statistics per window.
        "astats=metadata=1:reset=1,"
        "ametadata=print:key=lavfi.astats.Overall.RMS_level:file=-"
    )
    cmd = ["ffmpeg", "-i", audio_path, "-af", audio_filter, "-f", "null", "-"]
    # errors="replace": ffmpeg's log may echo file metadata in any encoding.
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    # The metadata goes to stdout; stderr is scanned as well, as before.
    output = result.stdout + "\n" + result.stderr
    return _parse_astats_metadata(output, window_sec, silence_floor_db)
def avg_energy_in_range(energies, start, end):
    """Return the mean RMS level of the samples timestamped within [start, end].

    ``energies`` is an iterable of ``(timestamp, rms_db)`` pairs; both interval
    bounds are inclusive. When no sample falls inside the interval the result
    is ``-60.0``, i.e. the window is treated as quiet.
    """
    levels = [level for stamp, level in energies if start <= stamp and stamp <= end]
    return sum(levels) / len(levels) if levels else -60.0
def _clip_window(line_start, target_duration, video_end):
    """Return ``(start, end)`` of a clip that never runs past ``video_end``."""
    start = line_start
    if start + target_duration > video_end:
        # Not enough material after the line: shift the clip backwards.
        start = max(0.0, video_end - target_duration)
    end = min(start + target_duration, video_end)
    return start, end


def _dedupe_candidates(ranked, min_gap_sec=20.0, limit=5):
    """Keep the first ``limit`` candidates whose starts are > ``min_gap_sec`` apart."""
    kept = []
    for candidate in ranked:
        if all(abs(candidate["start"] - other["start"]) > min_gap_sec for other in kept):
            kept.append(candidate)
            if len(kept) >= limit:
                break
    return kept


def find_chorus(video, lang=None, model_size="small", target_duration=30.0):
    """Detect likely chorus positions in ``video`` and return ranked candidates.

    Pipeline: extract the audio, transcribe it, cluster repeated lines,
    measure the RMS energy and score every occurrence of a repeated line as
    a possible clip start.

    Args:
        video: Path of the input video.
        lang: Language code for the transcription, or ``None`` to auto-detect.
        model_size: Whisper model name.
        target_duration: Desired clip length in seconds.

    Returns:
        On success a dict with ``language``, ``total_lines``,
        ``clusters_found`` and ``candidates`` (at most five, best first, with
        starts more than 20 s apart). Each candidate has ``start``, ``end``,
        ``duration``, ``score``, ``repetitions``, ``avg_rms_db``,
        ``energy_above_avg_db``, ``text_sample`` and ``cluster_id``.
        When nothing was transcribed or no line repeats, a dict with an
        ``error`` message and an empty ``candidates`` list is returned.

    Raises:
        Whatever ``extract_audio``, ``transcribe`` or ``compute_energy`` raise.
        The temporary audio file is removed in every case.
    """
    audio = extract_audio(video)
    try:
        lines, detected_lang = transcribe(audio, lang=lang, model_size=model_size)
        if not lines:
            return {"error": "Brez transkripcije", "candidates": []}
        print(f"📝 {len(lines)} vrstic transkripta", file=sys.stderr)

        clusters = find_repeated_lines(lines, similarity_threshold=0.5)
        print(f"🔁 {len(clusters)} ponavljajočih se sklopov", file=sys.stderr)
        if not clusters:
            return {
                "error": "Ni najdenih ponavljajočih se vrstic",
                "language": detected_lang,
                "candidates": [],
            }

        print("🔊 Analiza energije...", file=sys.stderr)
        energies = compute_energy(audio)
        avg_overall = sum(e for _, e in energies) / max(1, len(energies))
        print(f" Povprečje RMS: {avg_overall:.1f} dB", file=sys.stderr)

        # The end of the last transcribed line is taken as the end of the
        # usable material; it is the same for every candidate.
        video_end = max(line["end"] for line in lines)

        candidates = []
        for cluster_idx, cluster in enumerate(clusters):
            # Cluster representative = its longest line.
            rep = max(cluster, key=lambda i: len(lines[i]["text"]))
            rep_text = lines[rep]["text"]

            # A chorus tends to repeat its own words (low unique ratio), a
            # verse has mostly distinct words. Normalizing first keeps
            # punctuation from making repeated words look different.
            rep_words = normalize_text(rep_text).split()
            unique_word_ratio = len(set(rep_words)) / len(rep_words) if rep_words else 1.0
            chorus_signal = max(0, (1.0 - unique_word_ratio) * 20)
            repetition_score = len(cluster) * 8

            # Every occurrence of the line is a potential clip start.
            for inst_idx in cluster:
                start, end = _clip_window(lines[inst_idx]["start"], target_duration, video_end)
                avg_e = avg_energy_in_range(energies, start, end)
                # Only loudness above the overall average counts.
                energy_score = max(0, avg_e - avg_overall)
                score = repetition_score + energy_score * 10 + chorus_signal
                candidates.append({
                    "start": round(start, 2),
                    "end": round(end, 2),
                    "duration": round(end - start, 2),
                    "score": round(score, 2),
                    "repetitions": len(cluster),
                    "avg_rms_db": round(avg_e, 1),
                    "energy_above_avg_db": round(energy_score, 1),
                    "text_sample": rep_text[:80],
                    "cluster_id": cluster_idx,
                })

        # Best score first, then drop candidates that start close together.
        candidates.sort(key=lambda c: -c["score"])
        return {
            "language": detected_lang,
            "total_lines": len(lines),
            "clusters_found": len(clusters),
            "candidates": _dedupe_candidates(candidates, min_gap_sec=20.0, limit=5),
        }
    finally:
        Path(audio).unlink(missing_ok=True)
def main():
    """Command-line entry point.

    Parses the arguments, runs ``find_chorus`` on the input file and prints
    the result either as JSON (``--json``) or as a human-readable ranking.
    Exits with status 1 when the input file does not exist and with status 2
    when, in text mode, the result carries an error and no candidates.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input")
    parser.add_argument("--lang", default=None)
    parser.add_argument(
        "--model",
        default="small",
        choices=["tiny", "base", "small", "medium", "large-v3"],
    )
    parser.add_argument(
        "--duration",
        type=float,
        default=30.0,
        help="Ciljna dolžina reel-a v s",
    )
    parser.add_argument("--json", action="store_true", help="JSON output")
    args = parser.parse_args()

    src = Path(args.input)
    if not src.exists():
        print(f"{src} ne obstaja", file=sys.stderr)
        sys.exit(1)

    result = find_chorus(
        src,
        lang=args.lang,
        model_size=args.model,
        target_duration=args.duration,
    )

    # Machine-readable mode: dump the result as-is and stop.
    if args.json:
        print(json.dumps(result, ensure_ascii=False, indent=2))
        return

    if "error" in result and not result.get("candidates"):
        print(f"{result['error']}", file=sys.stderr)
        sys.exit(2)

    print(f"\n🎵 Jezik: {result.get('language', '?')}")
    print(f"📋 {result['total_lines']} vrstic, {result['clusters_found']} ponavljanj\n")
    print("🏆 Najboljši kandidati za refren:\n")
    for rank, cand in enumerate(result["candidates"], 1):
        # Render the start offset as minutes:seconds.
        whole_minutes = int(cand["start"] // 60)
        seconds = cand["start"] - whole_minutes * 60
        print(f" {rank}. {whole_minutes}:{seconds:05.2f} → +{cand['duration']:.0f}s "
              f"(score={cand['score']}, ponovitev={cand['repetitions']}, "
              f"energija={cand['energy_above_avg_db']:+.1f} dB)")
        print(f" '{cand['text_sample']}'\n")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()