import os
import sys

import numpy as np
import tensorflow as tf


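# Inference wrapper around the TransNetV2 shot-boundary detection model exported as a
# TensorFlow SavedModel. Videos are decoded into 48x27 RGB frames, the network produces a
# per-frame transition probability, and the probabilities are converted into [start, end]
# scene ranges.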
class TransNetV2:
    def __init__(self, model_dir=None):
        if model_dir is None:
            model_dir = os.path.join(os.path.dirname(__file__), "transnetv2-weights/")
            if not os.path.isdir(model_dir):
                raise FileNotFoundError(f"[TransNetV2] ERROR: {model_dir} is not a directory.")
            else:
                print(f"[TransNetV2] Using weights from {model_dir}.")

        self._input_size = (27, 48, 3)
        try:
            self._model = tf.saved_model.load(model_dir)
        except OSError as exc:
            raise IOError(
                f"[TransNetV2] It seems that files in {model_dir} are corrupted or missing. "
                f"Re-download them manually and retry. For more info, see: "
                f"https://github.com/soCzech/TransNetV2/issues/1#issuecomment-647357796"
            ) from exc
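
    # predict_raw runs the SavedModel on an already-batched input of shape
    # [batch, frames, 27, 48, 3] and returns sigmoid probabilities from the
    # single-frame head and the "many_hot" (all-frames) head.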

    def predict_raw(self, frames: np.ndarray):
        assert len(frames.shape) == 5 and frames.shape[2:] == self._input_size, \
            "[TransNetV2] Input shape must be [batch, frames, height, width, 3]."
        frames = tf.cast(frames, tf.float32)

        logits, dict_ = self._model(frames)
        single_frame_pred = tf.sigmoid(logits)
        all_frames_pred = tf.sigmoid(dict_["many_hot"])

        return single_frame_pred, all_frames_pred
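
    # predict_frames runs the model over the whole video using overlapping windows of
    # 100 frames with a stride of 50; only the middle 50 predictions of each window are
    # kept, so every frame is predicted with temporal context on both sides.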

    def predict_frames(self, frames: np.ndarray):
        assert len(frames.shape) == 4 and frames.shape[1:] == self._input_size, \
            "[TransNetV2] Input shape must be [frames, height, width, 3]."

        def input_iterator():
            # yield overlapping windows of 100 frames with a stride of 50; the video is padded
            # with copies of its first and last frame so that every real frame falls into the
            # central 50 frames of some window
            no_padded_frames_start = 25
            no_padded_frames_end = 25 + 50 - (len(frames) % 50 if len(frames) % 50 != 0 else 50)

            start_frame = np.expand_dims(frames[0], 0)
            end_frame = np.expand_dims(frames[-1], 0)
            padded_inputs = np.concatenate(
                ([start_frame] * no_padded_frames_start) + [frames] + ([end_frame] * no_padded_frames_end), 0
            )

            ptr = 0
            while ptr + 100 <= len(padded_inputs):
                out = padded_inputs[ptr:ptr + 100]
                ptr += 50
                yield out[np.newaxis]

        predictions = []
        processed_frames_count = 0

        for inp in input_iterator():
            single_frame_pred, all_frames_pred = self.predict_raw(inp)
            # keep only the middle 50 predictions of each 100-frame window
            predictions.append((single_frame_pred.numpy()[0, 25:75, 0],
                                all_frames_pred.numpy()[0, 25:75, 0]))

            processed_frames_count = min(len(predictions) * 50, len(frames))
            print("\r[TransNetV2] Processing video frames {}/{}".format(
                processed_frames_count, len(frames)), end="")

        if len(frames) > 0:
            print("")

        single_frame_pred_list = [single_ for single_, all_ in predictions]
        all_frames_pred_list = [all_ for single_, all_ in predictions]

        if not single_frame_pred_list:
            print(f"[TransNetV2 WARNING] No predictions generated from input_iterator for a video of "
                  f"length {len(frames)}. Returning zeros.", file=sys.stderr)
            empty_preds_shape = (len(frames),)
            if len(frames) == 0:
                return np.array([]), np.array([])
            return np.zeros(empty_preds_shape, dtype=np.float32), np.zeros(empty_preds_shape, dtype=np.float32)

        single_frame_pred_concatenated = np.concatenate(single_frame_pred_list)
        all_frames_pred_concatenated = np.concatenate(all_frames_pred_list)

        return single_frame_pred_concatenated[:len(frames)], all_frames_pred_concatenated[:len(frames)]
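
    # predict_video decodes the input file with ffmpeg into raw 48x27 RGB frames and feeds
    # them to predict_frames; on any extraction failure it returns three empty arrays.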

    def predict_video(self, video_fn: str):
        try:
            import ffmpeg
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "For `predict_video` function `ffmpeg` needs to be installed in order to extract "
                "individual frames from video file. Install `ffmpeg` command line tool and then "
                "install python wrapper by `pip install ffmpeg-python`.")

        print("[TransNetV2] Extracting frames from {}".format(video_fn))
        video_stream, err = None, None

        try:
            process = (
                ffmpeg
                .input(video_fn)
                .output("pipe:", format="rawvideo", pix_fmt="rgb24", s="48x27")
                .run_async(pipe_stdout=True, pipe_stderr=True)
            )
            video_stream, err = process.communicate()

            if process.returncode != 0:
                print(f"[TransNetV2 ERROR] ffmpeg process failed for {os.path.basename(video_fn)} "
                      f"with exit code {process.returncode}.", file=sys.stderr)
                if err:
                    print(f"[TransNetV2 ERROR] ffmpeg stderr:\n{err.decode(errors='ignore')}", file=sys.stderr)
                return np.array([]), np.array([]), np.array([])

            if err:
                decoded_err = err.decode(errors='ignore').strip()
                if decoded_err:
                    print(f"[TransNetV2 WARNING] ffmpeg stderr for {os.path.basename(video_fn)} "
                          f"(exit code {process.returncode}):\n{decoded_err}", file=sys.stderr)

            if not video_stream:
                print(f"[TransNetV2 ERROR] ffmpeg produced no video stream for {os.path.basename(video_fn)}. "
                      f"Cannot proceed.", file=sys.stderr)
                return np.array([]), np.array([]), np.array([])

        except ffmpeg.Error as e_ffmpeg:
            print(f"[TransNetV2 ERROR] ffmpeg.Error during frame extraction for {os.path.basename(video_fn)}:",
                  file=sys.stderr)
            if hasattr(e_ffmpeg, 'stderr') and e_ffmpeg.stderr:
                print(e_ffmpeg.stderr.decode(errors='ignore'), file=sys.stderr)
            else:
                print(str(e_ffmpeg), file=sys.stderr)
            return np.array([]), np.array([]), np.array([])
        except Exception as e_generic_ffmpeg:
            print(f"[TransNetV2 ERROR] Generic exception during ffmpeg processing for "
                  f"{os.path.basename(video_fn)}: {e_generic_ffmpeg}", file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        if video_stream is None:
            print(f"[TransNetV2 ERROR] video_stream is None after ffmpeg processing for "
                  f"{os.path.basename(video_fn)}.", file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        try:
            video = np.frombuffer(video_stream, np.uint8).reshape([-1, 27, 48, 3])
        except ValueError as e_reshape:
            print(f"[TransNetV2 ERROR] Failed to reshape video_stream for {os.path.basename(video_fn)}. "
                  f"Stream length: {len(video_stream)}. Error: {e_reshape}", file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        if video.shape[0] == 0:
            print(f"[TransNetV2 ERROR] Extracted 0 frames from {os.path.basename(video_fn)} after ffmpeg. "
                  f"Cannot proceed.", file=sys.stderr)
            return np.array([]), np.array([]), np.array([])

        return (video, *self.predict_frames(video))
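
    # predictions_to_scenes thresholds the per-frame transition probabilities and converts
    # runs of non-transition frames into [start, end] scene ranges (inclusive frame indices).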

    @staticmethod
    def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
        if predictions is None or predictions.size == 0:
            print("[TransNetV2 DEBUG] predictions_to_scenes received empty or None predictions. "
                  "Returning empty scenes.", file=sys.stderr)
            return np.array([], dtype=np.int32)

        predictions = (predictions > threshold).astype(np.uint8)

        scenes = []
        t_prev, start = 0, 0
        for i, t_current_frame_pred in enumerate(predictions):
            if t_prev == 1 and t_current_frame_pred == 0:
                start = i
            if t_prev == 0 and t_current_frame_pred == 1 and i != 0:
                scenes.append([start, i])
            t_prev = t_current_frame_pred

        if t_prev == 0 and len(predictions) > 0:
            if start < len(predictions):
                scenes.append([start, len(predictions) - 1])

        if not scenes and len(predictions) > 0:
            print("[TransNetV2 DEBUG] No scenes detected by transition logic, assuming single scene "
                  "for the entire video.", file=sys.stderr)
            return np.array([[0, len(predictions) - 1]], dtype=np.int32)

        if not scenes and len(predictions) == 0:
            print("[TransNetV2 DEBUG] No predictions, returning empty scenes array.", file=sys.stderr)
            return np.array([], dtype=np.int32)

        return np.array(scenes, dtype=np.int32)
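
    # visualize_predictions tiles the 48x27 frames into a grid of 25 frames per row and draws
    # each prediction value as a small vertical bar in the extra pixel columns added to the
    # right of every frame.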

    @staticmethod
    def visualize_predictions(frames: np.ndarray, predictions):
        from PIL import Image, ImageDraw

        if frames is None or frames.size == 0:
            print("[TransNetV2 WARNING] visualize_predictions received no frames. Skipping visualization.",
                  file=sys.stderr)
            return None

        if isinstance(predictions, np.ndarray):
            predictions = [predictions]

        valid_predictions = []
        for p_arr in predictions:
            if p_arr is not None and p_arr.size > 0:
                valid_predictions.append(p_arr)

        if not valid_predictions:
            print("[TransNetV2 WARNING] visualize_predictions received no valid prediction arrays. "
                  "Skipping visualization.", file=sys.stderr)
            return None
        predictions = valid_predictions

        ih, iw, ic = frames.shape[1:]
        width = 25

        # pad the video so its length is divisible by `width` (frames per row)
        pad_with = width - len(frames) % width if len(frames) % width != 0 else 0
        pad_with = max(0, pad_with)

        if frames.size > 0:
            # add one pixel row below every frame and one pixel column per prediction array
            # to the right of every frame, where the prediction bars are drawn
            frames = np.pad(frames, [(0, pad_with), (0, 1), (0, len(predictions)), (0, 0)])
        else:
            return None

        predictions = [np.pad(x, (0, pad_with)) for x in predictions]
        height = len(frames) // width

        if height == 0 or width == 0:
            print("[TransNetV2 WARNING] Cannot create visualization due to zero height or width after padding.",
                  file=sys.stderr)
            return None

        # tile the padded frames into a `height` x `width` grid image
        img = frames.reshape([height, width, ih + 1, iw + len(predictions), ic])
        img = np.concatenate(np.split(
            np.concatenate(np.split(img, height, axis=0), axis=2)[0], width, axis=1
        ), axis=2)[0, :-1]

        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)

        # draw each prediction value as a vertical bar next to its frame
        for i, pred_tuple in enumerate(zip(*predictions)):
            x_base, y_base = i % width, i // width
            x_offset = x_base * (iw + len(predictions)) + iw
            y_offset = y_base * (ih + 1) + ih - 1

            for j, p_value in enumerate(pred_tuple):
                color = [0, 0, 0]
                color[j % 3] = 255

                value_scaled = round(p_value * (ih - 1))
                if value_scaled != 0:
                    draw.line((x_offset + j, y_offset, x_offset + j, y_offset - value_scaled),
                              fill=tuple(color), width=1)
        return img
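

# Illustrative programmatic usage (comments only, not executed; the paths below are assumptions):
#
#   model = TransNetV2("transnetv2-weights/")
#   video_frames, single_frame_pred, all_frames_pred = model.predict_video("example.mp4")
#   scenes = model.predictions_to_scenes(single_frame_pred)  # array of [start, end] frame indices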


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("files", type=str, nargs="+",
                        help="path to video files to process")
    parser.add_argument("--weights", type=str, default=None,
                        help="path to TransNet V2 weights, tries to infer the location if not specified")
    parser.add_argument("--visualize", action="store_true",
                        help="save a png file with prediction visualization for each extracted video")
    args = parser.parse_args()

    try:
        model = TransNetV2(args.weights)
    except Exception as e_model_load:
        print(f"[TransNetV2 CRITICAL] Failed to load TransNetV2 model: {e_model_load}", file=sys.stderr)
        sys.exit(1)

    for file in args.files:
        print(f"[TransNetV2 INFO] Processing file: {file}", file=sys.stderr)
        if os.path.exists(file + ".predictions.txt") or os.path.exists(file + ".scenes.txt"):
            print(f"[TransNetV2] {file}.predictions.txt or {file}.scenes.txt already exists. "
                  f"Skipping video {file}.", file=sys.stderr)
            continue

        try:
            video_frames, single_frame_predictions, all_frame_predictions = model.predict_video(file)

            if video_frames is None or video_frames.size == 0:
                print(f"[TransNetV2 ERROR] predict_video returned no frames for {os.path.basename(file)}. "
                      f"Cannot save outputs.", file=sys.stderr)
                continue
            if single_frame_predictions is None or single_frame_predictions.size == 0:
                print(f"[TransNetV2 ERROR] predict_video returned no single_frame_predictions for "
                      f"{os.path.basename(file)}. Cannot save outputs.", file=sys.stderr)
                continue

            print(f"[TransNetV2 DEBUG] Returned from predict_video for {os.path.basename(file)}.", file=sys.stderr)
            print(f"[TransNetV2 DEBUG] video_frames shape: {video_frames.shape}", file=sys.stderr)
            print(f"[TransNetV2 DEBUG] single_frame_predictions shape: {single_frame_predictions.shape}",
                  file=sys.stderr)
            print(f"[TransNetV2 DEBUG] all_frame_predictions shape: {all_frame_predictions.shape}", file=sys.stderr)

            predictions = np.stack([single_frame_predictions, all_frame_predictions], 1)
            predictions_filepath = file + ".predictions.txt"
            try:
                np.savetxt(predictions_filepath, predictions, fmt="%.6f")
                print(f"[TransNetV2 DEBUG] Attempted to save predictions to {predictions_filepath}", file=sys.stderr)
                if not os.path.exists(predictions_filepath):
                    print(f"[TransNetV2 CRITICAL DEBUG] Saved predictions but file {predictions_filepath} "
                          f"does NOT exist!", file=sys.stderr)
                else:
                    print(f"[TransNetV2 DEBUG] File {predictions_filepath} successfully created.", file=sys.stderr)
            except Exception as e_pred_save:
                print(f"[TransNetV2 ERROR] Exception saving predictions file {predictions_filepath}: {e_pred_save}",
                      file=sys.stderr)
                continue

            scenes = model.predictions_to_scenes(single_frame_predictions)
            print(f"[TransNetV2 DEBUG] Scenes array for {os.path.basename(file)} "
                  f"(shape: {scenes.shape if scenes is not None else 'None'}):\n{scenes}", file=sys.stderr)

            if scenes is None or scenes.size == 0:
                print(f"[TransNetV2 WARNING] Scenes array is empty or None for {os.path.basename(file)}. "
                      f"No .scenes.txt will be saved, or it might be empty.", file=sys.stderr)

            scenes_filepath = file + ".scenes.txt"
            try:
                np.savetxt(scenes_filepath, scenes, fmt="%d")
                print(f"[TransNetV2 DEBUG] Attempted to save scenes to {scenes_filepath}", file=sys.stderr)
                if os.path.exists(scenes_filepath):
                    print(f"[TransNetV2 DEBUG] File {scenes_filepath} successfully created.", file=sys.stderr)
                    if scenes is not None and scenes.size == 0 and os.path.getsize(scenes_filepath) == 0:
                        print(f"[TransNetV2 DEBUG] {scenes_filepath} is empty as expected for empty scenes array.",
                              file=sys.stderr)
                    elif scenes is not None and scenes.size > 0 and os.path.getsize(scenes_filepath) == 0:
                        print(f"[TransNetV2 WARNING] {scenes_filepath} is unexpectedly empty despite non-empty "
                              f"scenes array.", file=sys.stderr)
                else:
                    print(f"[TransNetV2 CRITICAL DEBUG] Saved scenes but file {scenes_filepath} does NOT exist "
                          f"afterwards!", file=sys.stderr)
            except Exception as e_save:
                print(f"[TransNetV2 ERROR] Exception during np.savetxt for .scenes.txt ({scenes_filepath}): {e_save}",
                      file=sys.stderr)
                continue

            if args.visualize:
                vis_filepath = file + ".vis.png"
                if os.path.exists(vis_filepath):
                    print(f"[TransNetV2] {vis_filepath} already exists. "
                          f"Skipping visualization of video {file}.", file=sys.stderr)
                else:
                    print(f"[TransNetV2 DEBUG] Attempting to visualize predictions for {os.path.basename(file)}",
                          file=sys.stderr)
                    pil_image = model.visualize_predictions(
                        video_frames, predictions=(single_frame_predictions, all_frame_predictions))

                    if pil_image:
                        try:
                            pil_image.save(vis_filepath)
                            print(f"[TransNetV2 DEBUG] Saved visualization to {vis_filepath}", file=sys.stderr)
                        except Exception as e_vis_save:
                            print(f"[TransNetV2 ERROR] Exception saving visualization {vis_filepath}: {e_vis_save}",
                                  file=sys.stderr)
                    else:
                        print(f"[TransNetV2 WARNING] Visualization not generated for {os.path.basename(file)}.",
                              file=sys.stderr)

        except Exception as e_video_processing:
            print(f"[TransNetV2 ERROR] Unhandled exception during processing of video {file}: {e_video_processing}",
                  file=sys.stderr)

    print("[TransNetV2 INFO] Finished processing all files.", file=sys.stderr)


if __name__ == "__main__":
    main()
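
# Example command line (script and video names are illustrative):
#   python transnetv2.py my_video.mp4 --weights transnetv2-weights/ --visualize
# This writes my_video.mp4.predictions.txt, my_video.mp4.scenes.txt and, with --visualize,
# my_video.mp4.vis.png next to the input file.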