TransNetV2 for Detecting Complex Video Transitions
Published: 2025-06-09

Introduction to the TransNet V2 Model

  1. Goal:
    TransNet V2 aims to be a fast and accurate shot boundary detection (SBD) model. It recognizes the two main types of shot changes in video:
    • Hard cuts: abrupt changes between two shots.
    • Gradual transitions: dissolves, fades, wipes and similar effects that span several frames.
  • Model architecture

(figure: TransNet V2 model architecture)

  1. Core idea:
    TransNet V2 is a deep neural network. It directly processes a sequence of consecutive, low-resolution RGB video frames, predicts for every frame the probability that a shot transition occurs there, and can distinguish between transition types.

  2. Key architectural features:

    • Input:
      • The model takes N consecutive, low-resolution RGB frames as input (the paper typically uses N = 100 frames, each downscaled to 48x27 pixels).
      • Using raw, downsampled RGB frames directly (rather than precomputed features such as frame differences or optical flow) lets the network learn the optimal visual features end to end.
    • Frame encoding:
      • Each input frame first passes through two 2D convolution layers with max pooling to extract spatial features.
      • The resulting feature maps are flattened and stacked in temporal order, forming a feature sequence.
    • Temporal modeling (the core of the model):
      • The heart of TransNet V2 is a stack of 1D dilated convolution blocks operating on the frame-feature sequence (see the sketch after this list).
      • 1D convolutions: well suited to sequence data and computationally efficient.
      • Dilated convolutions: give the network a very large receptive field (a long temporal context) without significantly increasing the parameter count or compute cost (by avoiding excessive pooling), which is crucial for detecting gradual transitions that span many frames. Each dilated block typically contains a dilated 1D convolution, batch normalization (Batch Normalization), a ReLU activation, and dropout.
    • Dual prediction heads:
      Two branches (prediction heads) are attached to the output features of the last dilated convolution layer:
      1. Transition probability prediction: a 1D convolution with kernel size 1 followed by a sigmoid. For each frame of the input sequence it outputs a probability (0 to 1) that the frame is a shot boundary (a hard cut or the center of a gradual transition). It is trained with binary cross-entropy (BCE) loss.
      2. Transition type prediction (one-hot): another kernel-size-1 1D convolution with 3 output channels, each followed by a sigmoid. Per frame, the channels predict the probability of:
        • no shot transition
        • a hard cut
        • the center of a gradual transition
          These are treated as three independent binary classification problems, also trained with BCE loss. This auxiliary task helps the network learn richer features.
    • Combined loss:
      The total loss is a weighted sum of the two heads' BCE losses; the sketch below puts these pieces together.
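
To make this concrete, here is a minimal Keras sketch of the dilated blocks, the dual heads, and the weighted loss described above. It is an illustration, not the official implementation: the feature dimension, block count, dilation rates, dropout rate, and loss weights are assumptions.

import tensorflow as tf
from tensorflow.keras import layers

def dilated_block(x, filters=64, dilation=1):
    # dilated 1D conv -> batch norm -> ReLU -> dropout, as described above
    x = layers.Conv1D(filters, kernel_size=3, padding="same",
                      dilation_rate=dilation)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    return layers.Dropout(0.5)(x)

def build_sketch(seq_len=100, feat_dim=128):
    # input: per-frame feature vectors produced by the 2D frame encoder
    inp = tf.keras.Input(shape=(seq_len, feat_dim))
    x = inp
    for d in (1, 2, 4, 8):  # growing dilation rates -> large temporal receptive field
        x = dilated_block(x, dilation=d)
    # head 1: per-frame transition probability (kernel-size-1 conv + sigmoid)
    p_transition = layers.Conv1D(1, 1, activation="sigmoid", name="transition")(x)
    # head 2: three independent per-frame binary predictions
    # (no transition / hard cut / center of a gradual transition)
    p_type = layers.Conv1D(3, 1, activation="sigmoid", name="type")(x)
    return tf.keras.Model(inp, [p_transition, p_type])

model = build_sketch()
model.compile(optimizer="adam",
              loss={"transition": "binary_crossentropy",
                    "type": "binary_crossentropy"},
              loss_weights={"transition": 1.0, "type": 0.1})  # weighted BCE sum; weights assumed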
  3. Strengths:

    • High accuracy: state-of-the-art (SOTA) results on several standard SBD benchmarks (ClipShots, TRECVID IACC.3, RAI, BBC Planet Earth).
    • High speed: designed for efficiency; on modern GPUs it runs far faster than real time (roughly 150 FPS for network inference alone on an NVIDIA V100, excluding video decoding).
    • Robustness: handles hard cuts as well as a wide variety of complex gradual transitions.
    • End-to-end learning: learns features directly from low-resolution RGB frames, with no hand-crafted feature engineering.

Using TransNet V2 (Inference and Post-processing)

  1. Input video processing:

    • Decode the video into individual frames.
    • Resize each frame to the network's input resolution (e.g., 48x27 pixels).
  2. Sliding-window inference:

    • The video is processed in overlapping windows of N frames (e.g., 100).
    • For each window, the main transition-probability head outputs N probability values (one per frame in the window).
    • Because the windows overlap, the same video frame can receive several predictions; these are usually averaged to obtain a more stable per-frame probability.
  3. Post-processing the predictions (a minimal sketch follows this list):

    • Thresholding:
      • Compare the (averaged) per-frame probabilities against a predefined threshold P_thresh. Frames whose probability exceeds the threshold become candidate shot boundaries.
      • The paper suggests P_thresh = 0.5 as a reasonable default, to be tuned for the task at hand.
    • Minimum shot length (L_min):
      • To avoid producing many very short, likely spurious shots, a constraint is imposed: any detected boundary that would create a shot shorter than L_min frames is discarded.
      • This is usually implemented by walking through the candidate boundaries and keeping only those at least L_min frames away from the previously accepted boundary.
      • The paper mentions choosing L_min per dataset or desired granularity, e.g., 10 or 25 frames.
  4. Output:

    • The final output is a list of the frame indices at which shot boundaries were detected.
    • Optionally, the transition-type head can be used to classify the detected boundaries, though the main interest is usually the boundary detection itself.
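
A minimal sketch of the thresholding and minimum-shot-length steps, assuming the per-frame probabilities have already been averaged over overlapping windows (the function and parameter names here are illustrative):

import numpy as np

def postprocess(frame_probs: np.ndarray, p_thresh: float = 0.5, l_min: int = 10):
    """frame_probs: averaged per-frame transition probabilities.
    Returns the accepted boundary frame indices."""
    candidates = np.where(frame_probs > p_thresh)[0]  # thresholding
    boundaries = []
    last = -l_min  # so that a boundary at frame 0 can be accepted
    for idx in candidates:
        if idx - last >= l_min:  # enforce the minimum shot length
            boundaries.append(int(idx))
            last = idx
    return boundaries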
  • Detection results

(figure: example detection results)

Usage

Model Inference

  • The inference script is as follows:
import os
import sys  # needed so we can write to sys.stderr
import numpy as np
import tensorflow as tf


class TransNetV2:

    def __init__(self, model_dir=None):
        if model_dir is None:
            model_dir = os.path.join(os.path.dirname(__file__), "transnetv2-weights/")
            if not os.path.isdir(model_dir):
                raise FileNotFoundError(f"[TransNetV2] ERROR: {model_dir} is not a directory.")
            else:
                print(f"[TransNetV2] Using weights from {model_dir}.")

        self._input_size = (27, 48, 3)
        try:
            self._model = tf.saved_model.load(model_dir)
        except OSError as exc:
            raise IOError(f"[TransNetV2] It seems that files in {model_dir} are corrupted or missing. "
                          f"Re-download them manually and retry. For more info, see: "
                          f"https://github.com/soCzech/TransNetV2/issues/1#issuecomment-647357796") from exc

    def predict_raw(self, frames: np.ndarray):
        assert len(frames.shape) == 5 and frames.shape[2:] == self._input_size, \
            "[TransNetV2] Input shape must be [batch, frames, height, width, 3]."
        frames = tf.cast(frames, tf.float32)

        logits, dict_ = self._model(frames)
        single_frame_pred = tf.sigmoid(logits)
        all_frames_pred = tf.sigmoid(dict_["many_hot"])

        return single_frame_pred, all_frames_pred

    def predict_frames(self, frames: np.ndarray):
        assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \
            "[TransNetV2] Input shape must be [frames, height, width, 3]."

        def input_iterator():
            # return windows of size 100 where the first/last 25 frames are from the previous/next batch
            # the first and last window must be padded by copies of the first and last frame of the video
            no_padded_frames_start = 25
            no_padded_frames_end = 25 + 50 - (len(frames) % 50 if len(frames) % 50 != 0 else 50)  # 25 - 74

            start_frame = np.expand_dims(frames[0], 0)
            end_frame = np.expand_dims(frames[-1], 0)
            padded_inputs = np.concatenate(
                ([start_frame] * no_padded_frames_start) + [frames] + ([end_frame] * no_padded_frames_end), 0
            )

            ptr = 0
            while ptr + 100 <= len(padded_inputs):
                out = padded_inputs[ptr:ptr + 100]
                ptr += 50
                yield out[np.newaxis]

        predictions = []

        # initialize the processed-frames counter outside the loop
        processed_frames_count = 0

        for inp in input_iterator():
            single_frame_pred, all_frames_pred = self.predict_raw(inp)
            predictions.append((single_frame_pred.numpy()[0, 25:75, 0],
                                all_frames_pred.numpy()[0, 25:75, 0]))

            processed_frames_count = min(len(predictions) * 50, len(frames))
            print("\r[TransNetV2] Processing video frames {}/{}".format(
                processed_frames_count, len(frames)
            ), end="")

        # make sure the final newline is printed even when the video has fewer frames than one batch
        if len(frames) > 0:  # only print the newline when there are frames, so empty videos stay silent
            print("")

        single_frame_pred_list = [single_ for single_, all_ in predictions]
        all_frames_pred_list = [all_ for single_, all_ in predictions]

        if not single_frame_pred_list:  # predictions is empty (e.g., the video has very few frames)
            # build all-zero prediction arrays matching the length of `frames`;
            # this guarantees correctly shaped arrays are returned even when the
            # iterator yielded nothing (still not ideal for very short videos,
            # but it at least avoids concatenating an empty list)
            print(
                f"[TransNetV2 WARNING] No predictions generated from input_iterator for a video of length {len(frames)}. Returning zeros.",
                file=sys.stderr)
            empty_preds_shape = (len(frames),)  # or (0,) when len(frames) is 0
            if len(frames) == 0:  # edge case: a zero-frame video
                return np.array([]), np.array([])
            return np.zeros(empty_preds_shape, dtype=np.float32), np.zeros(empty_preds_shape, dtype=np.float32)

        single_frame_pred_concatenated = np.concatenate(single_frame_pred_list)
        all_frames_pred_concatenated = np.concatenate(all_frames_pred_list)

        # remove extra padded frames
        return single_frame_pred_concatenated[:len(frames)], all_frames_pred_concatenated[:len(frames)]

    def predict_video(self, video_fn: str):
        try:
            import ffmpeg
        except ModuleNotFoundError:
            raise ModuleNotFoundError("For `predict_video` function `ffmpeg` needs to be installed in order to extract "
                                      "individual frames from video file. Install `ffmpeg` command line tool and then "
                                      "install python wrapper by `pip install ffmpeg-python`.")

        print("[TransNetV2] Extracting frames from {}".format(video_fn))
        video_stream, err = None, None  # Initialize

        try:
            process = (
                ffmpeg
                .input(video_fn)
                .output("pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27")
                .run_async(pipe_stdout=True, pipe_stderr=True)
            )
            video_stream, err = process.communicate()  # Get output after process finishes

            if process.returncode != 0:
                print(
                    f"[TransNetV2 ERROR] ffmpeg process failed for {os.path.basename(video_fn)} with exit code {process.returncode}.",
                    file=sys.stderr)
                if err:
                    print(f"[TransNetV2 ERROR] ffmpeg stderr:\n{err.decode(errors='ignore')}", file=sys.stderr)
                return np.array([]), np.array([]), np.array([])  # Return empty arrays indicating failure

            if err:  # Even if return code is 0, check stderr
                # Not all messages on stderr are errors, but log them as warnings
                # Filter out common non-error messages if necessary, or log all for debugging
                decoded_err = err.decode(errors='ignore').strip()
                if decoded_err:  # Only print if there's actual content
                    print(
                        f"[TransNetV2 WARNING] ffmpeg stderr for {os.path.basename(video_fn)} (exit code {process.returncode}):\n{decoded_err}",
                        file=sys.stderr)

            if not video_stream:
                print(
                    f"[TransNetV2 ERROR] ffmpeg produced no video stream for {os.path.basename(video_fn)}. Cannot proceed.",
                    file=sys.stderr)
                return np.array([]), np.array([]), np.array([])  # Return empty

        except ffmpeg.Error as e_ffmpeg:  # Catch ffmpeg's own errors
            print(f"[TransNetV2 ERROR] ffmpeg.Error during frame extraction for {os.path.basename(video_fn)}:",
                  file=sys.stderr)
            if hasattr(e_ffmpeg, 'stderr') and e_ffmpeg.stderr:
                print(e_ffmpeg.stderr.decode(errors='ignore'), file=sys.stderr)
            else:
                print(str(e_ffmpeg), file=sys.stderr)
            return np.array([]), np.array([]), np.array([])  # Return empty
        except Exception as e_generic_ffmpeg:  # Catch other potential errors during ffmpeg processing
            print(
                f"[TransNetV2 ERROR] Generic exception during ffmpeg processing for {os.path.basename(video_fn)}: {e_generic_ffmpeg}",
                file=sys.stderr)
            return np.array([]), np.array([]), np.array([])  # Return empty

        if video_stream is None:  # Should have been caught above, but as a safeguard
            print(f"[TransNetV2 ERROR] video_stream is None after ffmpeg processing for {os.path.basename(video_fn)}.",
                  file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        try:
            video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3])
        except ValueError as e_reshape:
            print(
                f"[TransNetV2 ERROR] Failed to reshape video_stream for {os.path.basename(video_fn)}. Stream length: {len(video_stream)}. Error: {e_reshape}",
                file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        if video.shape[0] == 0:
            print(
                f"[TransNetV2 ERROR] Extracted 0 frames from {os.path.basename(video_fn)} after ffmpeg. Cannot proceed.",
                file=sys.stderr)
            return np.array([]), np.array([]), np.array([])  # Return empty

        return (video, *self.predict_frames(video))

    @staticmethod
    def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
        if predictions is None or predictions.size == 0:  # Handle empty or None predictions
            print(
                "[TransNetV2 DEBUG] predictions_to_scenes received empty or None predictions. Returning empty scenes.",
                file=sys.stderr)
            return np.array([], dtype=np.int32)  # Return empty array of correct type

        predictions = (predictions > threshold).astype(np.uint8)

        scenes = []
        t_prev, start = 0, 0  # previous prediction and start index of the current shot
        for i, t_current_frame_pred in enumerate(predictions):
            if t_prev == 1 and t_current_frame_pred == 0:
                start = i
            if t_prev == 0 and t_current_frame_pred == 1 and i != 0:
                scenes.append([start, i])
            t_prev = t_current_frame_pred

        # After the loop, check whether the video ends inside a shot
        # (t_prev is 0 if the last frame was part of a shot, 1 if it ended on a transition).
        if t_prev == 0 and len(predictions) > 0:  # Ensure there was at least one prediction
            # If a shot has started, close it at the last frame
            if start < len(predictions):  # Make sure start is a valid index
                scenes.append([start, len(predictions) - 1])  # Shot goes to the end

        # fix the case where all predictions are 1 (no transitions found, one scene from start to end)
        # or all predictions are 0 (also a single scene spanning the whole video)
        if not scenes and len(predictions) > 0:  # If no scenes were appended and there are predictions
            print(
                "[TransNetV2 DEBUG] No scenes detected by transition logic, assuming single scene for the entire video.",
                file=sys.stderr)
            return np.array([[0, len(predictions) - 1]], dtype=np.int32)

        if not scenes and len(predictions) == 0:  # If no predictions at all
            print("[TransNetV2 DEBUG] No predictions, returning empty scenes array.", file=sys.stderr)
            return np.array([], dtype=np.int32)

        return np.array(scenes, dtype=np.int32)

    @staticmethod
    def visualize_predictions(frames: np.ndarray, predictions):
        from PIL import Image, ImageDraw

        if frames is None or frames.size == 0:
            print("[TransNetV2 WARNING] visualize_predictions received no frames. Skipping visualization.",
                  file=sys.stderr)
            return None  # Or a placeholder image

        if isinstance(predictions, np.ndarray):
            predictions = [predictions]

        # Filter out None or empty prediction arrays
        valid_predictions = []
        for p_arr in predictions:
            if p_arr is not None and p_arr.size > 0:
                valid_predictions.append(p_arr)

        if not valid_predictions:
            print(
                "[TransNetV2 WARNING] visualize_predictions received no valid prediction arrays. Skipping visualization.",
                file=sys.stderr)
            return None  # Or a placeholder image
        predictions = valid_predictions

        ih, iw, ic = frames.shape[1:]
        width = 25

        # pad frames so that length of the video is divisible by width
        # pad frames also by len(predictions) pixels in width in order to show predictions
        pad_with = width - len(frames) % width if len(frames) % width != 0 else 0
        # Ensure pad_with is not negative if len(frames) is 0
        pad_with = max(0, pad_with)

        # Pad frames, ensuring frames is not empty
        if frames.size > 0:
            frames = np.pad(frames, [(0, pad_with), (0, 1), (0, len(predictions)), (0, 0)])
        else:  # Should not happen if caught earlier, but as a safeguard
            return None

        predictions = [np.pad(x, (0, pad_with)) for x in predictions]
        height = len(frames) // width

        if height == 0 or width == 0:  # Avoid division by zero or empty reshape
            print("[TransNetV2 WARNING] Cannot create visualization due to zero height or width after padding.",
                  file=sys.stderr)
            return None

        img = frames.reshape([height, width, ih + 1, iw + len(predictions), ic])
        img = np.concatenate(np.split(
            np.concatenate(np.split(img, height, axis=0), axis=2)[0], width, axis=1  # Corrected axis for split
        ), axis=2)[0, :-1]

        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)

        # iterate over all frames
        for i, pred_tuple in enumerate(zip(*predictions)):  # pred_tuple contains predictions for frame i
            x_base, y_base = i % width, i // width  # Top-left corner of the frame in the grid
            x_offset, y_offset = x_base * (iw + len(predictions)) + iw, y_base * (
                ih + 1) + ih - 1  # Bottom-right of frame content, before prediction lines

            # we can visualize multiple predictions per single frame
            for j, p_value in enumerate(pred_tuple):  # j is prediction type index, p_value is its value
                color = [0, 0, 0]
                # Cycle through R, G, B for different prediction types
                color[j % 3] = 255  # Use j for color to distinguish prediction types

                value_scaled = round(p_value * (ih - 1))  # Scale prediction to frame height
                if value_scaled != 0:
                    # Draw line upwards from the bottom edge of the frame visualization area
                    draw.line((x_offset + j, y_offset, x_offset + j, y_offset - value_scaled), fill=tuple(color),
                              width=1)
        return img


def main():
    import argparse  # sys is already imported at module level

    parser = argparse.ArgumentParser()
    parser.add_argument("files", type=str, nargs="+", help="path to video files to process")
    parser.add_argument("--weights", type=str, default=None,
                        help="path to TransNet V2 weights, tries to infer the location if not specified")
    parser.add_argument('--visualize', action="store_true",
                        help="save a png file with prediction visualization for each extracted video")
    args = parser.parse_args()

    try:
        model = TransNetV2(args.weights)
    except Exception as e_model_load:
        print(f"[TransNetV2 CRITICAL] Failed to load TransNetV2 model: {e_model_load}", file=sys.stderr)
        sys.exit(1)  # Exit if model fails to load

    for file in args.files:
        print(f"[TransNetV2 INFO] Processing file: {file}", file=sys.stderr)
        # This pre-check is fine, but the calling script test_and_cut_video.py now also does pre-cleanup
        if os.path.exists(file + ".predictions.txt") or os.path.exists(file + ".scenes.txt"):
            print(f"[TransNetV2] {file}.predictions.txt or {file}.scenes.txt already exists. "
                  f"Skipping video {file}.", file=sys.stderr)
            continue

        try:
            video_frames, single_frame_predictions, all_frame_predictions = \
                model.predict_video(file)

            # --- Debugging: Check outputs of predict_video ---
            if video_frames is None or video_frames.size == 0:
                print(
                    f"[TransNetV2 ERROR] predict_video returned no frames for {os.path.basename(file)}. Cannot save outputs.",
                    file=sys.stderr)
                continue  # Skip to next file
            if single_frame_predictions is None or single_frame_predictions.size == 0:
                print(
                    f"[TransNetV2 ERROR] predict_video returned no single_frame_predictions for {os.path.basename(file)}. Cannot save outputs.",
                    file=sys.stderr)
                continue  # Skip to next file
            # all_frame_predictions can sometimes be legitimately empty if single_frame_predictions is also empty.

            print(f"[TransNetV2 DEBUG] Returned from predict_video for {os.path.basename(file)}.", file=sys.stderr)
            print(f"[TransNetV2 DEBUG] video_frames shape: {video_frames.shape}", file=sys.stderr)
            print(f"[TransNetV2 DEBUG] single_frame_predictions shape: {single_frame_predictions.shape}",
                  file=sys.stderr)
            print(f"[TransNetV2 DEBUG] all_frame_predictions shape: {all_frame_predictions.shape}", file=sys.stderr)
            # --- End Debugging ---

            predictions = np.stack([single_frame_predictions, all_frame_predictions], 1)
            predictions_filepath = file + ".predictions.txt"
            try:
                np.savetxt(predictions_filepath, predictions, fmt="%.6f")
                print(f"[TransNetV2 DEBUG] Attempted to save predictions to {predictions_filepath}", file=sys.stderr)
                if not os.path.exists(predictions_filepath):
                    print(
                        f"[TransNetV2 CRITICAL DEBUG] Saved predictions but file {predictions_filepath} does NOT exist!",
                        file=sys.stderr)
                else:
                    print(f"[TransNetV2 DEBUG] File {predictions_filepath} successfully created.", file=sys.stderr)
            except Exception as e_pred_save:
                print(f"[TransNetV2 ERROR] Exception saving predictions file {predictions_filepath}: {e_pred_save}",
                      file=sys.stderr)
                continue  # Skip to next file if saving predictions fails

            scenes = model.predictions_to_scenes(single_frame_predictions)
            print(
                f"[TransNetV2 DEBUG] Scenes array for {os.path.basename(file)} (shape: {scenes.shape if scenes is not None else 'None'}):\n{scenes}",
                file=sys.stderr)

            if scenes is None or scenes.size == 0:  # Check if scenes array is empty
                print(f"[TransNetV2 WARNING] Scenes array is empty or None for {os.path.basename(file)}. "
                      f"No .scenes.txt will be saved, or it might be empty.", file=sys.stderr)
                # Depending on desired behavior, you might 'continue' here or let np.savetxt handle an empty array.
                # np.savetxt with an empty array will create an empty file.
                # If an empty scenes file is problematic for the parent script, handle it here.
                # For now, let it try to save, which will result in an empty file if scenes is empty.

            scenes_filepath = file + ".scenes.txt"
            try:
                np.savetxt(scenes_filepath, scenes, fmt="%d")
                print(f"[TransNetV2 DEBUG] Attempted to save scenes to {scenes_filepath}", file=sys.stderr)
                if os.path.exists(scenes_filepath):
                    print(f"[TransNetV2 DEBUG] File {scenes_filepath} successfully created.", file=sys.stderr)
                    # Optionally, check file size or content for empty scenes
                    if scenes is not None and scenes.size == 0 and os.path.getsize(scenes_filepath) == 0:
                        print(f"[TransNetV2 DEBUG] {scenes_filepath} is empty as expected for empty scenes array.",
                              file=sys.stderr)
                    elif scenes is not None and scenes.size > 0 and os.path.getsize(scenes_filepath) == 0:
                        print(
                            f"[TransNetV2 WARNING] {scenes_filepath} is unexpectedly empty despite non-empty scenes array.",
                            file=sys.stderr)
                else:
                    print(
                        f"[TransNetV2 CRITICAL DEBUG] Saved scenes but file {scenes_filepath} does NOT exist afterwards!",
                        file=sys.stderr)
            except Exception as e_save:
                print(f"[TransNetV2 ERROR] Exception during np.savetxt for .scenes.txt ({scenes_filepath}): {e_save}",
                      file=sys.stderr)
                continue  # Skip to next file if saving scenes fails

            if args.visualize:
                vis_filepath = file + ".vis.png"
                if os.path.exists(vis_filepath):
                    print(f"[TransNetV2] {vis_filepath} already exists. "
                          f"Skipping visualization of video {file}.", file=sys.stderr)
                else:
                    print(f"[TransNetV2 DEBUG] Attempting to visualize predictions for {os.path.basename(file)}",
                          file=sys.stderr)
                    pil_image = model.visualize_predictions(
                        video_frames, predictions=(single_frame_predictions, all_frame_predictions))

                    if pil_image:
                        try:
                            pil_image.save(vis_filepath)
                            print(f"[TransNetV2 DEBUG] Saved visualization to {vis_filepath}", file=sys.stderr)
                        except Exception as e_vis_save:
                            print(f"[TransNetV2 ERROR] Exception saving visualization {vis_filepath}: {e_vis_save}",
                                  file=sys.stderr)
                    else:
                        print(f"[TransNetV2 WARNING] Visualization not generated for {os.path.basename(file)}.",
                              file=sys.stderr)

        except Exception as e_video_processing:
            print(f"[TransNetV2 ERROR] Unhandled exception during processing of video {file}: {e_video_processing}",
                  file=sys.stderr)
            # Optionally, re-raise or sys.exit(1) if this should halt the script.
            # For now, it just prints the error and attempts the next file.

    print("[TransNetV2 INFO] Finished processing all files.", file=sys.stderr)


if __name__ == "__main__":
    main()
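
A typical invocation, given the weights directory layout assumed in __init__: python transnetv2.py my_video.mp4 --weights transnetv2-weights/ --visualize. For each input it writes my_video.mp4.predictions.txt and my_video.mp4.scenes.txt next to the video, plus my_video.mp4.vis.png when --visualize is set.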

Calling the Inference Script to Process Videos

  • benchmark

(figure: benchmark results)

  • Turn input videos into single-scene clips: detect shot transitions and split at them

After cutting, 500 ms are automatically trimmed from the start and end of each clip to remove transition animation.
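
At 25 fps, for example, the 500 ms trim corresponds to round(0.5 × 25) ≈ 13 frames removed from each end of every shot; the script below computes this as reduction_frames = round(padding_ms / 1000 * fps).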

# -*- coding: utf-8 -*-
"""
Time: 2025/6/7 19:00
Author: ZhaoQi Cao(Clara)
Version: V 1.4  # changelog: fixed parse_scenes_file to correctly parse the scenes-file format
File: test_and_cut_video.py
date: 2025/6/7
Describe: Written in Python at Tianjin
GitHub link: https://github.com/caozhaoqi
Blog link: https://caozhaoqi.github.io
WeChat Official Account: 码间拾遗 (Code Snippets)
Power by macOS on Mac mini m4 (2024)
"""
import os
import subprocess
import argparse
import cv2  # for reading the total frame count
import logging
from pathlib import Path  # more convenient path handling
from tqdm import tqdm  # progress bars
import shutil  # for moving files

# --- Logging configuration ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

# Supported video file extensions
SUPPORTED_VIDEO_EXTENSIONS = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv']


def find_video_files(input_dir, output_dir_to_skip=None):  # optional parameter to exclude the output directory
    """
    Recursively finds all video files in the given directory,
    optionally skipping a specified output directory.
    """
    video_files = []
    logger.info(f"Searching for video files in: {input_dir}")
    # normalize output_dir_to_skip for reliable path comparison
    normalized_output_skip_path = None
    if output_dir_to_skip:
        normalized_output_skip_path = os.path.normpath(os.path.abspath(output_dir_to_skip))

    for root, dirs, files in os.walk(input_dir):
        current_root_abs = os.path.normpath(os.path.abspath(root))

        # skip the current directory if it is the output directory or one of its subdirectories
        if normalized_output_skip_path and current_root_abs.startswith(normalized_output_skip_path):
            logger.info(f"Skipping scan of output directory: {root}")
            dirs[:] = []  # clear dirs so os.walk does not descend into subdirectories
            continue

        for file in files:
            # filter out macOS ._* files and other hidden dotfiles
            if file.startswith('.'):
                continue
            if any(file.lower().endswith(ext) for ext in SUPPORTED_VIDEO_EXTENSIONS):
                video_files.append(os.path.join(root, file))
    logger.info(f"Found {len(video_files)} video file(s).")
    return video_files


def run_transnet_inference(video_path, transnet_script_path, model_dir, working_directory="."):
    """
    Runs TransNetV2 inference to get shot boundaries.
    Returns the path to the '.scenes.txt' file located in the working_directory.
    """
    logger.info(f"Running TransNetV2 inference on: {video_path}")
    video_filename = os.path.basename(video_path)  # e.g., "myvideo.mp4"
    base_name = os.path.splitext(video_filename)[0]  # e.g., "myvideo" (for target filename in working_dir)
    original_video_dir = os.path.dirname(video_path)

    # --- Pre-cleanup: try to remove old output files for this video from its original directory ---
    # This should target the names transnetv2.py actually creates
    files_to_pre_cleanup = [
        video_path + ".scenes.txt",
        video_path + ".predictions.txt",
        video_path + ".vis.png"
    ]
    for old_file_path in files_to_pre_cleanup:
        if os.path.exists(old_file_path):
            try:
                os.remove(old_file_path)
                logger.info(f"Pre-emptively removed old file: {old_file_path}")
            except OSError as e:
                logger.warning(f"Could not pre-emptively remove old file {old_file_path}: {e}")
    # --- end of pre-cleanup ---

    command = [
        "python", transnet_script_path,
        video_path,
        "--weights", model_dir
    ]

    try:
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=working_directory)
        stdout_bytes, stderr_bytes = process.communicate(timeout=3600)

        process_stdout = stdout_bytes.decode(errors='ignore')
        process_stderr = stderr_bytes.decode(errors='ignore')

        if process.returncode != 0:
            logger.error(f"TransNetV2 inference failed for {video_filename}:")
            logger.error(f"STDOUT: {process_stdout}")
            logger.error(f"STDERR: {process_stderr}")
            return None

        logger.info(f"TransNetV2 inference successful for {video_filename} (exit code 0).")
        if process_stdout.strip():
            logger.info(f"TransNetV2 STDOUT for {video_filename}:\n{process_stdout.strip()}")
        if process_stderr.strip():
            logger.warning(
                f"TransNetV2 STDERR for {video_filename} (though exit code was 0):\n{process_stderr.strip()}")

    except subprocess.TimeoutExpired:
        logger.error(f"TransNetV2 inference timed out for {video_filename}.")
        if process:
            process.kill()
            try:
                stdout_bytes, stderr_bytes = process.communicate(timeout=5)
                process_stdout = stdout_bytes.decode(errors='ignore')
                process_stderr = stderr_bytes.decode(errors='ignore')
                if process_stdout.strip():
                    logger.error(f"STDOUT (on timeout): {process_stdout.strip()}")
                if process_stderr.strip():
                    logger.error(f"STDERR (on timeout): {process_stderr.strip()}")
            except Exception as e_comm:
                logger.error(f"Error getting output after timeout kill: {e_comm}")
        return None
    except Exception as e:
        logger.error(f"Exception during TransNetV2 inference for {video_filename}: {e}")
        return None

    # --- locate (and if necessary move) the output files ---
    target_scenes_file_in_working_dir = os.path.join(working_directory, f"{base_name}.scenes.txt")
    target_predictions_file_in_working_dir = os.path.join(working_directory, f"{base_name}.predictions.txt")

    expected_scenes_file_in_original_dir = video_path + ".scenes.txt"
    expected_predictions_file_in_original_dir = video_path + ".predictions.txt"

    scenes_file_found_path = None

    if os.path.exists(target_scenes_file_in_working_dir):
        logger.info(f"Found scenes file directly in working directory: {target_scenes_file_in_working_dir}")
        scenes_file_found_path = target_scenes_file_in_working_dir
    elif os.path.exists(expected_scenes_file_in_original_dir):
        logger.info(f"Found scenes file in original video directory: {expected_scenes_file_in_original_dir}")
        try:
            shutil.move(expected_scenes_file_in_original_dir, target_scenes_file_in_working_dir)
            logger.info(f"Moved scenes file to: {target_scenes_file_in_working_dir}")
            scenes_file_found_path = target_scenes_file_in_working_dir

            if os.path.exists(expected_predictions_file_in_original_dir):
                shutil.move(expected_predictions_file_in_original_dir, target_predictions_file_in_working_dir)
                logger.info(f"Moved predictions file to: {target_predictions_file_in_working_dir}")
        except Exception as e:
            logger.error(
                f"Failed to move scenes/predictions file from '{expected_scenes_file_in_original_dir}' to '{target_scenes_file_in_working_dir}': {e}")
            return None
    else:
        logger.error(f"Scenes file not found in working directory ({target_scenes_file_in_working_dir}) "
                     f"nor in original video directory ({expected_scenes_file_in_original_dir}).")
        return None

    return scenes_file_found_path


def get_total_frames(video_path):
    """Gets the total number of frames in a video."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Log OpenCV error if available (requires newer OpenCV versions for good messages).
        # For older versions, this might not give much info:
        # cv_error = cv2.getErrorMsg() if hasattr(cv2, 'getErrorMsg') else "OpenCV error"
        # logger.error(f"Could not open video: {video_path}. OpenCV: {cv_error}")
        logger.error(f"Could not open video: {video_path}")
        return None
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return total_frames


def parse_scenes_file(scenes_file_path, total_frames):
    """
    Parses the .scenes.txt file, which contains the start and end frame of each shot,
    one shot per line, space-separated.
    Returns a list of (start_frame, end_frame) tuples, one per shot.
    """
    if not os.path.exists(scenes_file_path):
        logger.error(f"Scenes file not found for parsing: {scenes_file_path}")
        return []

    shots = []
    try:
        with open(scenes_file_path, 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, 1):
                stripped_line = line.strip()
                if not stripped_line:  # Skip empty lines
                    continue

                parts = stripped_line.split()
                if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
                    start_frame = int(parts[0])
                    end_frame = int(parts[1])  # end frame of the shot (inclusive)

                    # Validate frames against total_frames
                    if start_frame < 0:
                        logger.warning(
                            f"Invalid start_frame {start_frame} < 0 in {scenes_file_path} line {line_number}. Clamping to 0.")
                        start_frame = 0

                    # end_frame from TransNetV2 is inclusive and 0-indexed.
                    # If total_frames is N, valid frames are 0 to N-1.
                    if end_frame >= total_frames and total_frames > 0:
                        logger.warning(
                            f"End_frame {end_frame} from scenes file is >= total_frames {total_frames} in {scenes_file_path} line {line_number}. Clamping to {total_frames - 1}.")
                        end_frame = total_frames - 1
                    elif total_frames == 0 and end_frame > 0:
                        logger.warning(
                            f"Invalid end_frame {end_frame} for video with 0 total_frames in {scenes_file_path} line {line_number}. Skipping shot.")
                        continue
                    elif total_frames == 0 and end_frame == 0 and start_frame == 0:
                        pass  # allow a (0, 0) shot for a 0-frame video, should it ever occur

                    if start_frame > end_frame:
                        logger.warning(
                            f"Invalid shot (start_frame {start_frame} > end_frame {end_frame}) in {scenes_file_path} line {line_number}. Skipping shot.")
                        continue

                    shots.append((start_frame, end_frame))
                else:
                    # This warning catches lines that are not two integers.
                    logger.warning(
                        f"Malformed line in scenes file '{scenes_file_path}' line {line_number}: '{stripped_line}'. Expected two integers separated by space.")

    except Exception as e:
        logger.error(f"Error reading or parsing scenes file {scenes_file_path}: {e}")
        return []

    if not shots:
        logger.warning(
            f"No valid shots derived from scenes file: {scenes_file_path}. Assuming single shot for the entire video.")
        if total_frames > 0:
            shots.append((0, total_frames - 1))
        # If total_frames is 0, shots stays empty, which is correct.
    else:
        # Sort shots by start frame, in case they are not ordered in the file
        shots.sort(key=lambda x: x[0])
        logger.info(
            f"Detected {len(shots)} shots for {os.path.basename(scenes_file_path)} (start_frame, end_frame_inclusive): {shots}")

    return shots


def cut_video_into_shots(video_path, shots, video_output_dir, padding_ms=500):
    """
    Cuts the video into shots using ffmpeg based on frame numbers.
    The start and end of each shot are REDUCED by padding_ms.
    (Note: the 'padding_ms' parameter is used here as a reduction amount.)
    Saves shots into the video-specific output directory.
    """
    video_basename = Path(video_path).stem
    logger.info(f"Starting to cut {len(shots)} shots for video: {video_basename}, REDUCING each end by {padding_ms}ms.")

    # Get video properties (FPS and total frames) once
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        logger.error(f"Could not open video {video_path} for properties. Skipping cutting.")
        return
    fps = cap.get(cv2.CAP_PROP_FPS)
    video_total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    if not fps > 0:
        logger.error(f"Could not determine FPS for {video_path}. Skipping reduction cut.")
        return
    if video_total_frames == 0:
        logger.error(f"Video {video_path} has 0 frames. Skipping cutting.")
        return

    # Convert the reduction amount from milliseconds to frames
    reduction_frames = round(padding_ms / 1000 * fps)

    for i, (start_frame, end_frame) in enumerate(tqdm(shots, desc=f"Cutting shots for {video_basename}", unit="shot")):
        start_frame = int(start_frame)
        end_frame = int(end_frame)  # inclusive end frame from parse_scenes_file

        # Calculate new target start and end frames after reduction
        target_start_frame = start_frame + reduction_frames
        target_end_frame = end_frame - reduction_frames  # also an inclusive frame index

        # Check if the shot is too short for the reduction, making it invalid
        if target_end_frame < target_start_frame:
            logger.warning(
                f"Shot {i + 1} (original: {start_frame}-{end_frame}) is too short to apply {padding_ms}ms reduction "
                f"from both ends (would result in invalid segment: {target_start_frame}-{target_end_frame}). "
                f"Original duration: {(end_frame - start_frame + 1) / fps:.2f}s. Skipping this shot."
            )
            continue

        # Clamp the target frames to the video's actual boundaries (0 to video_total_frames - 1)
        final_start_frame = max(0, target_start_frame)
        final_start_frame = min(final_start_frame, video_total_frames - 1)  # ensure start isn't past video end

        final_end_frame = min(video_total_frames - 1, target_end_frame)
        final_end_frame = max(0, final_end_frame)  # ensure end isn't before video start

        # Re-check validity after clamping;
        # this handles cases where clamping itself makes the segment invalid.
        if final_end_frame < final_start_frame:
            logger.warning(
                f"Shot {i + 1} (original: {start_frame}-{end_frame}) became invalid after reduction and clamping to video boundaries: "
                f"intended reduced ({target_start_frame}-{target_end_frame}), "
                f"clamped to ({final_start_frame}-{final_end_frame}). Skipping this shot."
            )
            continue

        output_shot_filename = f"{video_basename}_shot_{i + 1:03d}.mp4"
        output_shot_path = os.path.join(video_output_dir, output_shot_filename)

        logger.debug(
            f"Preparing to cut shot {i + 1}: original ({start_frame}-{end_frame}), "
            f"target reduced frames ({final_start_frame}-{final_end_frame}) -> {output_shot_path}")

        ffmpeg_command = [
            "ffmpeg",
            "-loglevel", "error",
            "-i", video_path,
            "-vf", f"select='between(n,{final_start_frame},{final_end_frame})',setpts=PTS-STARTPTS",
            "-af", f"aselect='between(n,{final_start_frame},{final_end_frame})',asetpts=PTS-STARTPTS",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-crf", "23",
            "-c:a", "aac",
            "-y",  # overwrite output files without asking
            output_shot_path
        ]

        try:
            process = subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdout_bytes, stderr_bytes = process.communicate(timeout=300)

            stdout = stdout_bytes.decode(errors='ignore')
            stderr = stderr_bytes.decode(errors='ignore')

            if process.returncode != 0:
                logger.error(f"ffmpeg failed for shot {output_shot_filename}:")
                if stderr.strip(): logger.error(f"FFMPEG STDERR: {stderr.strip()}")
                if stdout.strip(): logger.error(f"FFMPEG STDOUT: {stdout.strip()}")
            else:
                logger.info(f"Successfully created {output_shot_filename}")

        except subprocess.TimeoutExpired:
            logger.error(f"ffmpeg timed out for shot {output_shot_filename}.")
            if process:
                process.kill()
                try:
                    stdout_bytes, stderr_bytes = process.communicate(timeout=5)
                    stdout = stdout_bytes.decode(errors='ignore')
                    stderr = stderr_bytes.decode(errors='ignore')
                    if stderr.strip(): logger.error(f"FFMPEG STDERR (on timeout kill): {stderr.strip()}")
                    if stdout.strip(): logger.error(f"FFMPEG STDOUT (on timeout kill): {stdout.strip()}")
                except Exception as e_comm_kill:
                    logger.error(f"Error getting output after ffmpeg timeout kill: {e_comm_kill}")
        except Exception as e:
            logger.error(f"Exception during ffmpeg processing for shot {output_shot_filename}: {e}")

    logger.info(f"Finished cutting {len(shots)} shots for video: {video_basename}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Process videos using TransNetV2 to detect shot boundaries and cut the videos into shots.")
    parser.add_argument("input_source", help="Path to a directory containing video files to process.")
    parser.add_argument("--transnet_script", required=False, default="inference/transnetv2.py",
                        help="Path to the transnetv2.py script.")
    parser.add_argument("--model_dir", required=False, default="inference/transnetv2-weights",
                        help="Path to the TransNetV2 model directory.")
    parser.add_argument("--output_dir", required=False, default="./r",
                        help="Base directory to output processed videos. A subdirectory will be created for each video.")
    parser.add_argument("--log_file", required=False, default="my_processing_log.txt",
                        help="Path to the log file.")

    args = parser.parse_args()

    # --- Log file setup ---
    log_file_path = args.log_file
    log_dir = os.path.dirname(log_file_path)  # directory containing the log file
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)  # create the log directory if it does not exist

    # create a file handler that writes the log to a file
    file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
    file_handler.setLevel(logging.INFO)  # log level INFO

    # create a formatter and attach it to the handler
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(formatter)

    # attach the handler to the logger
    logger.addHandler(file_handler)

    logger.info("Script started.")
    logger.info(f"Arguments: {args}")

    input_source = args.input_source
    transnet_script = args.transnet_script
    model_dir = args.model_dir
    output_dir = args.output_dir

    # make sure the output directory exists
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        logger.info(f"Created base output directory: {output_dir}")
    else:
        logger.info(f"Base output directory already exists: {output_dir}")

    video_files = find_video_files(input_source, output_dir_to_skip=output_dir)  # exclude output_dir from the scan
    if not video_files:
        logger.warning("No video files found to process.")
        logger.info("All processing complete.")
        exit()

    logger.info(f"Processing {len(video_files)} video(s)...")

    for video_path in video_files:
        try:
            # --- per-video processing starts here ---
            video_basename = Path(video_path).stem
            logger.info(f"--- Starting processing for video: {video_path} ---")

            # create the processing/output directory for the current video
            video_output_dir = os.path.join(output_dir, video_basename)
            if not os.path.exists(video_output_dir):
                os.makedirs(video_output_dir)
                logger.info(f"Created processing/output directory for this video: {video_output_dir}")
            else:
                logger.info(f"Processing/output directory already exists: {video_output_dir}")

            total_frames = get_total_frames(video_path)
            if total_frames is None or total_frames == 0:
                logger.error(
                    f"Cannot process video {video_path} as it has no frames or could not be read. Skipping.")
                continue  # skip to the next video
            logger.info(f"Total frames in {os.path.basename(video_path)}: {total_frames}")

            # run TransNetV2 inference
            scenes_file = run_transnet_inference(video_path, transnet_script, model_dir,
                                                 working_directory=video_output_dir)

            if scenes_file is None:
                logger.error(f"Failed to get scenes file for {video_path}. Skipping this video.")
                continue  # skip to the next video

            # parse the scenes file
            shots = parse_scenes_file(scenes_file, total_frames)

            # cut the video into shots
            cut_video_into_shots(video_path, shots, video_output_dir, padding_ms=500)

        except Exception as e:
            logger.error(f"An unexpected error occurred while processing {video_path}: {e}")

    logger.info("All processing complete.")

  • Detection output

(figure: detection output)
