import os
import cv2
import numpy as np
from flask import Flask, request, jsonify, send_from_directory, render_template
from flask_cors import CORS
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import torch
import tempfile
import base64
from io import BytesIO
from PIL import Image

try:
    from segment_anything import SamPredictor, sam_model_registry
except ImportError:
    # Don't crash at import time: initialize_sam2() repeats this import and
    # raises an ImportError with installation instructions when SAM is needed,
    # so the upload/preview UI still works without the package.
    SamPredictor = None
    sam_model_registry = None

# Load environment variables
load_dotenv()

app = Flask(__name__)
CORS(app)


def _parse_extensions(raw):
    """Normalize a comma-separated extension list to a set like {'.mp4', '.avi'}.

    Entries are stripped, lower-cased and given a leading dot, so values such
    as ``"mp4, .AVI"`` coming from the environment still match uploads.
    """
    extensions = set()
    for item in raw.split(','):
        item = item.strip().lower()
        if not item:
            continue
        extensions.add(item if item.startswith('.') else '.' + item)
    return extensions


# Configuration
app.config['UPLOAD_FOLDER'] = os.getenv('UPLOAD_FOLDER', 'uploads')
app.config['SEGMENTED_FOLDER'] = os.getenv('SEGMENTED_FOLDER', 'segmented')
app.config['ALLOWED_EXTENSIONS'] = _parse_extensions(
    os.getenv('ALLOWED_EXTENSIONS', '.mp4,.avi,.mov,.mkv')
)

# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SEGMENTED_FOLDER'], exist_ok=True)


def initialize_sam2(model_size="vit_b"):
    """Load a Segment Anything checkpoint and wrap it in a ``SamPredictor``.

    Despite the "SAM2" wording used in this app's messages, this loads the
    ``segment_anything`` package with its ViT-H/L/B checkpoints. The checkpoint
    file is looked up by bare filename, i.e. relative to the current working
    directory.

    Args:
        model_size: One of ``"vit_h"``, ``"vit_l"`` or ``"vit_b"``.

    Returns:
        A ``SamPredictor`` whose model lives on CUDA if available, else CPU.

    Raises:
        ImportError: if ``segment_anything`` is not installed.
        ValueError: if ``model_size`` is not one of the known sizes.
        FileNotFoundError: if the checkpoint file is missing.
    """
    print(f"Initializing SAM2 model ({model_size})...")

    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError as err:
        raise ImportError(
            "SAM is not installed. Please install it manually from GitHub:\n"
            "git clone https://github.com/facebookresearch/segment-anything.git\n"
            "cd segment-anything\n"
            "pip install -e .\n"
            "Then download the model checkpoint and place it in the root directory."
        ) from err

    # Map model sizes to checkpoint files
    model_configs = {
        "vit_h": {"checkpoint": "sam_vit_h_4b8939.pth", "model_type": "vit_h"},
        "vit_l": {"checkpoint": "sam_vit_l_0b3195.pth", "model_type": "vit_l"},
        "vit_b": {"checkpoint": "sam_vit_b_01ec64.pth", "model_type": "vit_b"},
    }

    if model_size not in model_configs:
        raise ValueError(
            f"Unknown model size: {model_size}. Choose from: vit_h, vit_l, vit_b"
        )

    config = model_configs[model_size]
    sam_checkpoint = config["checkpoint"]
    model_type = config["model_type"]

    # Check if checkpoint file exists
    if not os.path.exists(sam_checkpoint):
        raise FileNotFoundError(
            f"SAM2 checkpoint file '{sam_checkpoint}' not found. "
            f"Please download it from https://github.com/facebookresearch/segment-anything "
            f"and place it in the root directory."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    print(f"SAM2 model ({model_type}) initialized on {device}")
    return predictor


# Global predictor instance, created lazily on the first /segment request
sam_predictor = None

# Configuration for SAM model
SAM_MODEL_SIZE = os.getenv('SAM_MODEL_SIZE', 'vit_b')  # Read from .env or default to ViT-B


def allowed_file(filename):
    """Return True if ``filename``'s extension is in ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and uses the text after the last dot.
    """
    if not filename or '.' not in filename:
        return False

    # Get the file extension with dot (e.g., '.mp4')
    file_extension = '.' + filename.rsplit('.', 1)[1].lower()

    # Debug
    print(f"🔍 Checking extension: {file_extension}")
    print(f"📋 Allowed extensions: {app.config['ALLOWED_EXTENSIONS']}")

    return file_extension in app.config['ALLOWED_EXTENSIONS']


@app.route('/')
def index():
    """Main page"""
    return render_template('index.html')


@app.route('/test')
def test():
    """Test route"""
    return jsonify({'status': 'ok', 'message': 'Flask app is running'})


@app.route('/upload', methods=['POST'])
def upload_video():
    """Save an uploaded video and return a base64 JPEG of its first frame.

    Expects a multipart form with a ``file`` field. Responds 400 if the field
    is missing, empty, or has a disallowed extension. On success the JSON body
    contains ``filename`` (the sanitized name it was stored under) and
    ``preview`` (base64 JPEG string, or None if no frame could be read).
    """
    print("📤 Upload request received")

    if 'file' not in request.files:
        print("❌ No file part in request")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    if not file.filename:
        print("❌ No selected file")
        return jsonify({'error': 'No selected file'}), 400

    print(f"📁 File received: {file.filename}")
    # Measure the size by seeking rather than reading the whole upload into memory.
    file.seek(0, os.SEEK_END)
    print(f"📊 File size: {file.tell()} bytes")
    file.seek(0)  # Reset file pointer before saving

    # Validate the *sanitized* name: it is the name actually written to disk,
    # and secure_filename() may strip characters (even the extension's dot).
    filename = secure_filename(file.filename)
    file_extension = filename.rsplit('.', 1)[1].lower() if '.' in filename else ''
    print(f"🔍 File extension: .{file_extension}")
    print(f"📋 Allowed extensions: {app.config['ALLOWED_EXTENSIONS']}")

    if not allowed_file(filename):
        print(f"❌ File type not allowed: {file_extension}")
        return jsonify({
            'error': f'File type .{file_extension} not allowed. '
                     f'Allowed types: {app.config["ALLOWED_EXTENSIONS"]}'
        }), 400

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    print(f"✅ File saved: {filepath}")

    # Extract first frame for preview
    preview_frame = extract_first_frame(filepath)
    if preview_frame is None:
        print("⚠️ Could not extract preview frame, using placeholder")
        return jsonify({
            'message': 'File uploaded successfully (no preview available)',
            'filename': filename,
            'preview': None
        })

    print("🖼️ Preview frame extracted successfully")
    return jsonify({
        'message': 'File uploaded successfully',
        'filename': filename,
        'preview': preview_frame
    })


def extract_first_frame(video_path):
    """Return the first frame of ``video_path`` as a base64-encoded JPEG string.

    Best-effort: returns None (after logging the reason) if the file is
    missing, cannot be opened or decoded, or the frame cannot be JPEG-encoded.
    """
    try:
        # Check if file exists
        if not os.path.exists(video_path):
            print(f"❌ Video file not found: {video_path}")
            return None

        cap = cv2.VideoCapture(video_path)
        try:
            # Check if video opened successfully
            if not cap.isOpened():
                print(f"❌ Could not open video file: {video_path}")
                return None
            ret, frame = cap.read()
        finally:
            # Release the capture even if read() raises.
            cap.release()

        if not ret or frame is None:
            print(f"❌ Could not read first frame from {video_path}")
            return None

        # Convert to base64 for easy transfer
        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            print(f"❌ Failed to encode frame as JPEG")
            return None

        frame_base64 = base64.b64encode(buffer).decode('utf-8')
        print(f"✅ Successfully extracted first frame from {video_path}")
        return frame_base64
    except Exception as e:
        # Deliberate best-effort: a missing preview must not fail the upload.
        print(f"❌ Error extracting first frame: {e}")
        return None


def _valid_points(points):
    """Return True if ``points`` is a non-empty list of numeric [x, y] pairs."""
    if not isinstance(points, list) or not points:
        return False
    for point in points:
        if not isinstance(point, (list, tuple)) or len(point) != 2:
            return False
        for coord in point:
            # bool is a subclass of int; reject it explicitly.
            if isinstance(coord, bool) or not isinstance(coord, (int, float)):
                return False
    return True


@app.route('/segment', methods=['POST'])
def segment_object():
    """Segment the object marked by click points in an uploaded video.

    Expects a JSON body ``{"filename": str, "points": [[x, y], ...]}`` where
    ``filename`` is the name returned by /upload. Responds 400 on missing or
    malformed input, 404 if the video is not found, 500 if model loading or
    processing fails, otherwise JSON with ``output_filename``.
    """
    print("🎯 Segment request received")
    global sam_predictor

    # silent=True: a non-JSON body yields None (-> 400) instead of raising.
    data = request.get_json(silent=True)
    print(f"📥 Received data: {data}")

    if not isinstance(data, dict) or 'filename' not in data or 'points' not in data:
        print("❌ Missing required parameters")
        return jsonify({'error': 'Missing required parameters'}), 400

    # Uploads are stored under secure_filename(); re-sanitizing the client value
    # prevents path traversal (e.g. "../app.py") out of UPLOAD_FOLDER.
    filename = secure_filename(str(data['filename']))
    points = data['points']  # Expecting [[x1, y1], [x2, y2], ...]

    if not _valid_points(points):
        print("❌ Invalid points")
        return jsonify({'error': 'Invalid points: expected a non-empty list of [x, y] pairs'}), 400

    video_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    # isfile (not exists): an empty sanitized name would resolve to the folder itself.
    if not filename or not os.path.isfile(video_path):
        return jsonify({'error': 'Video file not found'}), 404

    try:
        # Load the model lazily, after validation, so bad requests stay cheap and
        # load failures (missing package/checkpoint) are reported as JSON.
        if sam_predictor is None:
            print("🔧 Initializing SAM model...")
            sam_predictor = initialize_sam2(SAM_MODEL_SIZE)

        # Process the video
        output_path = process_video_segmentation(video_path, points)
        output_filename = os.path.basename(output_path)

        print(f"✅ Segmentation completed: {output_filename}")
        print(f"📁 Output file path: {output_path}")
        print(f"🔍 File exists: {os.path.exists(output_path)}")
        print(f"📊 File size: {os.path.getsize(output_path)} bytes")

        return jsonify({
            'message': 'Segmentation completed',
            'output_filename': output_filename,
            'debug_file_path': output_path,
            'debug_file_exists': os.path.exists(output_path)
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500


def process_video_segmentation(video_path, points):
    """Overlay a SAM mask on every frame of ``video_path`` and write an MP4.

    The same prompt ``points`` (pixel coordinates, all labelled foreground)
    are applied independently to each frame; the object is not tracked, so if
    it moves away from the points the mask will not follow it.

    Returns:
        Path of the written ``<name>_segmented.mp4`` in SEGMENTED_FOLDER.

    Raises:
        IOError: if the input cannot be opened or the output cannot be created.
    """
    # Create output filename
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    output_filename = f"{base_name}_segmented.mp4"
    output_path = os.path.join(app.config['SEGMENTED_FOLDER'], output_filename)

    # The prompt is identical for every frame, so build the arrays once.
    input_points = np.array(points)
    input_labels = np.array([1] * len(points))  # 1 means foreground point

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video file: {video_path}")

    out = None
    try:
        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0/NaN; VideoWriter would then silently fail.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise IOError(f"Could not create output video: {output_path}")

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            print(f"Processing frame {frame_count}/{total_frames}")

            # Convert frame to RGB (SAM expects RGB)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            sam_predictor.set_image(frame_rgb)

            # multimask_output=False -> a single mask is returned
            masks, scores, logits = sam_predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False
            )
            mask = masks[0].astype(np.uint8) * 255

            out.write(apply_mask_to_frame(frame, mask))
    finally:
        # Always release native handles, even if SAM or OpenCV raises mid-loop.
        cap.release()
        if out is not None:
            out.release()

    return output_path


def apply_mask_to_frame(frame, mask, alpha=0.5):
    """Tint the masked region of ``frame`` red and return the new image.

    Args:
        frame: BGR image as read by OpenCV, shape (H, W, 3).
        mask: array of shape (H, W); non-zero where the object is
            (the caller passes 0/255 uint8).
        alpha: overlay opacity inside the mask, in [0, 1].

    Pixels outside the mask are returned unchanged; only the object is tinted.
    """
    # Create a colored version of the mask (red overlay; channel 2 is red in BGR)
    colored_mask = np.zeros_like(frame)
    colored_mask[:, :, 2] = mask

    blended = cv2.addWeighted(frame, 1 - alpha, colored_mask, alpha, 0)

    # Copy the blend only where the mask is set, so the background is not
    # darkened by (1 - alpha) as a whole-frame blend would do.
    result = frame.copy()
    region = mask > 0
    result[region] = blended[region]
    return result


@app.route('/download/<filename>')
def download_file(filename):
    """Download segmented video as an attachment."""
    return send_from_directory(
        app.config['SEGMENTED_FOLDER'],
        filename,
        as_attachment=True
    )


@app.route('/preview/<filename>')
def preview_video(filename):
    """Serve the original uploaded video."""
    return send_from_directory(
        app.config['UPLOAD_FOLDER'],
        filename
    )


@app.route('/segmented/<filename>')
def serve_segmented_video(filename):
    """Serve segmented video inline, with conditional/range request support."""
    file_path = os.path.join(app.config['SEGMENTED_FOLDER'], filename)

    print(f"🎬 Video request for: {filename}")
    print(f"📁 Looking for file at: {file_path}")
    print(f"🔍 File exists: {os.path.exists(file_path)}")

    if not os.path.isfile(file_path):
        print(f"❌ File not found: {file_path}")
        return jsonify({'error': f'File {filename} not found'}), 404

    print(f"✅ Serving file: {file_path}")

    # send_from_directory rejects paths escaping the folder; conditional=True
    # enables Range requests so browsers can seek within the video.
    return send_from_directory(
        app.config['SEGMENTED_FOLDER'],
        filename,
        conditional=True,
        mimetype='video/mp4'
    )


if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger; never expose it publicly.
    app.run(debug=True, port=5000)