# censorAll/app.py
# Listing header: 2025-12-11 23:26:21 +01:00 — 365 lines, 12 KiB, Python.
import os
import cv2
import numpy as np
from flask import Flask, request, jsonify, send_from_directory, render_template
from flask_cors import CORS
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
import torch
from segment_anything import SamPredictor, sam_model_registry
import tempfile
import base64
from io import BytesIO
from PIL import Image
# Load environment variables (e.g. from a local .env file) before reading config.
load_dotenv()
app = Flask(__name__)
CORS(app)


def _parse_allowed_extensions(raw):
    """Turn a comma-separated extension list into a normalized set.

    Each entry is stripped of whitespace, lower-cased and given a leading dot,
    so ".mp4, AVI ,mov" becomes {".mp4", ".avi", ".mov"}. Empty entries are
    ignored. This matches the form produced by allowed_file(), which compares
    a lower-cased, dot-prefixed extension.
    """
    extensions = set()
    for item in raw.split(','):
        item = item.strip().lower()
        if item:
            extensions.add(item if item.startswith('.') else '.' + item)
    return extensions


# Configuration
app.config['UPLOAD_FOLDER'] = os.getenv('UPLOAD_FOLDER', 'uploads')
app.config['SEGMENTED_FOLDER'] = os.getenv('SEGMENTED_FOLDER', 'segmented')
app.config['ALLOWED_EXTENSIONS'] = _parse_allowed_extensions(
    os.getenv('ALLOWED_EXTENSIONS', '.mp4,.avi,.mov,.mkv')
)
# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SEGMENTED_FOLDER'], exist_ok=True)
# Initialize SAM2 model
def initialize_sam2(model_size="vit_b", checkpoint_dir=""):
    """Load a Segment Anything checkpoint and wrap it in a SamPredictor.

    The log and error messages say "SAM2", but the package and checkpoint
    files used here are those of github.com/facebookresearch/segment-anything.

    Args:
        model_size: One of "vit_h", "vit_l" or "vit_b"; selects both the
            checkpoint file and the sam_model_registry entry.
        checkpoint_dir: Directory containing the checkpoint file. The default
            "" resolves the bare file name against the current working
            directory.

    Returns:
        A SamPredictor whose model lives on "cuda" if torch reports it
        available, otherwise on "cpu".

    Raises:
        ValueError: model_size is not one of the known variants.
        FileNotFoundError: the checkpoint file does not exist.
        ImportError: the segment_anything package cannot be imported.
    """
    print(f"Initializing SAM2 model ({model_size})...")
    # Checkpoint file per variant. The variant name is also the
    # sam_model_registry key.
    checkpoints = {
        "vit_h": "sam_vit_h_4b8939.pth",
        "vit_l": "sam_vit_l_0b3195.pth",
        "vit_b": "sam_vit_b_01ec64.pth",
    }
    # Run the cheap argument and file checks before the heavy import, so a
    # bad configuration is reported directly rather than hidden behind an
    # ImportError.
    if model_size not in checkpoints:
        raise ValueError(f"Unknown model size: {model_size}. Choose from: vit_h, vit_l, vit_b")
    sam_checkpoint = os.path.join(checkpoint_dir, checkpoints[model_size])
    if not os.path.exists(sam_checkpoint):
        raise FileNotFoundError(
            f"SAM2 checkpoint file '{sam_checkpoint}' not found. "
            f"Please download it from https://github.com/facebookresearch/segment-anything "
            f"and place it in the root directory."
        )
    # NOTE(review): segment_anything is also imported at module level, so a
    # missing package normally fails when the app is imported, before this
    # friendlier message can be raised.
    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError as err:
        raise ImportError(
            "SAM is not installed. Please install it manually from GitHub:\n"
            "git clone https://github.com/facebookresearch/segment-anything.git\n"
            "cd segment-anything\n"
            "pip install -e .\n"
            "Then download the model checkpoint and place it in the root directory."
        ) from err
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_size](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    print(f"SAM2 model ({model_size}) initialized on {device}")
    return predictor
# Model variant used when the predictor is first built. It is read from the
# environment and defaults to "vit_b".
SAM_MODEL_SIZE = os.getenv('SAM_MODEL_SIZE', 'vit_b')
# Process-wide predictor. It stays None until the first segmentation request
# initializes it.
# NOTE(review): one predictor is shared by all requests; if the server handles
# requests concurrently, set_image()/predict() calls could interleave — confirm.
sam_predictor = None
def allowed_file(filename):
    """Report whether *filename*'s final extension is in ALLOWED_EXTENSIONS.

    The text after the last dot is lower-cased and prefixed with ".", e.g.
    "clip.MP4" -> ".mp4", then looked up in app.config. A name containing
    no dot is rejected.
    """
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return False
    extension = f".{suffix.lower()}"
    allowed = app.config['ALLOWED_EXTENSIONS']
    # Debug
    print(f"🔍 Checking extension: {extension}")
    print(f"📋 Allowed extensions: {allowed}")
    return extension in allowed
@app.route('/')
def index():
    """Serve the main page by rendering the index.html template."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/test')
def test():
    """Liveness check: answer with a fixed JSON status payload."""
    payload = dict(status='ok', message='Flask app is running')
    return jsonify(payload)
@app.route('/upload', methods=['POST'])
def upload_video():
    """Accept a video from the multipart form field "file" and store it.

    The upload is saved in UPLOAD_FOLDER under a sanitized name. The JSON
    response carries that name ("filename") and a base64 JPEG of the first
    frame ("preview"), which is None when no frame could be extracted.

    Returns:
        A JSON response. Status 400 is returned when the field is missing or
        empty, or the file extension is not allowed.
    """
    import uuid  # only needed for the fallback name below

    print("📤 Upload request received")
    if 'file' not in request.files:
        print("❌ No file part in request")
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        print("❌ No selected file")
        return jsonify({'error': 'No selected file'}), 400
    print(f"📁 File received: {file.filename}")
    # Measure the size by seeking, instead of reading the whole upload into memory.
    file.stream.seek(0, os.SEEK_END)
    print(f"📊 File size: {file.stream.tell()} bytes")
    file.stream.seek(0)  # Rewind so save() writes the full content
    # Debug file extension
    filename = secure_filename(file.filename)
    file_extension = filename.rsplit('.', 1)[1].lower() if '.' in filename else ''
    print(f"🔍 File extension: .{file_extension}")
    print(f"📋 Allowed extensions: {app.config['ALLOWED_EXTENSIONS']}")
    if not allowed_file(file.filename):
        print(f"❌ File type not allowed: {file_extension}")
        return jsonify({'error': f'File type .{file_extension} not allowed. Allowed types: {app.config["ALLOWED_EXTENSIONS"]}'}), 400
    if not allowed_file(filename):
        # secure_filename() keeps only ASCII characters, so a name such as
        # "видео.mp4" comes back as "mp4" (or even ""). Generate a safe name
        # that keeps the extension validated above, rather than saving an
        # extensionless file or writing to the folder path itself.
        original_extension = file.filename.rsplit('.', 1)[1].lower()
        filename = f"upload_{uuid.uuid4().hex}.{original_extension}"
        print(f"⚠️ Sanitized name lost its extension, saving as: {filename}")
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    print(f"✅ File saved: {filepath}")
    # Extract first frame for preview
    preview_frame = extract_first_frame(filepath)
    if preview_frame is None:
        print("⚠️ Could not extract preview frame, using placeholder")
        return jsonify({
            'message': 'File uploaded successfully (no preview available)',
            'filename': filename,
            'preview': None
        })
    print("🖼️ Preview frame extracted successfully")
    return jsonify({
        'message': 'File uploaded successfully',
        'filename': filename,
        'preview': preview_frame
    })
def extract_first_frame(video_path):
    """Return the first frame of *video_path* as a base64-encoded JPEG string.

    Returns None, after printing the reason, if the file does not exist,
    cannot be opened, yields no frame, or the frame cannot be JPEG-encoded.
    Any exception raised along the way is printed and also mapped to None.
    """
    cap = None
    try:
        # Check if file exists
        if not os.path.exists(video_path):
            print(f"❌ Video file not found: {video_path}")
            return None
        cap = cv2.VideoCapture(video_path)
        # Check if video opened successfully
        if not cap.isOpened():
            print(f"❌ Could not open video file: {video_path}")
            return None
        ret, frame = cap.read()
        if not ret or frame is None:
            print(f"❌ Could not read first frame from {video_path}")
            return None
        # Convert to base64 for easy transfer
        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            print("❌ Failed to encode frame as JPEG")
            return None
        frame_base64 = base64.b64encode(buffer).decode('utf-8')
        print(f"✅ Successfully extracted first frame from {video_path}")
        return frame_base64
    except Exception as e:
        print(f"❌ Error extracting first frame: {e}")
        return None
    finally:
        # Release the capture on every path, including the early returns
        # and the exception path above.
        if cap is not None:
            cap.release()
def _is_valid_points(points):
    """Return True if *points* is a non-empty list of numeric [x, y] pairs."""
    if not isinstance(points, list) or not points:
        return False
    for point in points:
        if not isinstance(point, (list, tuple)) or len(point) != 2:
            return False
        for coord in point:
            # bool is a subclass of int but is never a meaningful coordinate.
            if isinstance(coord, bool) or not isinstance(coord, (int, float)):
                return False
    return True


@app.route('/segment', methods=['POST'])
def segment_object():
    """Segment the object marked by click points across an uploaded video.

    Expects the JSON body {"filename": <name returned by /upload>,
    "points": [[x1, y1], [x2, y2], ...]}. The points are passed to
    process_video_segmentation(), which writes its output to SEGMENTED_FOLDER.

    Returns:
        JSON with "output_filename" on success. Status 400 is returned for a
        malformed body or invalid points, 404 for an unknown video, and 500
        (with the error text) if model loading or processing fails.
    """
    global sam_predictor
    print("🎯 Segment request received")
    # silent=True: a missing or non-JSON body yields None, which is answered
    # with the 400 below instead of Flask raising its own error.
    data = request.get_json(silent=True)
    print(f"📥 Received data: {data}")
    if not isinstance(data, dict) or 'filename' not in data or 'points' not in data:
        print("❌ Missing required parameters")
        return jsonify({'error': 'Missing required parameters'}), 400
    points = data['points']  # Expecting [[x1, y1], [x2, y2], ...]
    if not _is_valid_points(points):
        print("❌ Invalid points")
        return jsonify({'error': 'points must be a non-empty list of [x, y] pairs'}), 400
    # The name is client-supplied: sanitize it so it cannot address files
    # outside UPLOAD_FOLDER (e.g. "../../secret").
    filename = secure_filename(str(data['filename']))
    video_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    if not filename or not os.path.exists(video_path):
        return jsonify({'error': 'Video file not found'}), 404
    try:
        # Build the model lazily, inside the try, so a missing checkpoint or
        # package is reported as a JSON 500 like any other failure.
        if sam_predictor is None:
            print("🔧 Initializing SAM model...")
            sam_predictor = initialize_sam2(SAM_MODEL_SIZE)
        # Process the video
        output_path = process_video_segmentation(video_path, points)
        output_filename = os.path.basename(output_path)
        print(f"✅ Segmentation completed: {output_filename}")
        print(f"📁 Output file path: {output_path}")
        print(f"🔍 File exists: {os.path.exists(output_path)}")
        print(f"📊 File size: {os.path.getsize(output_path)} bytes")
        return jsonify({
            'message': 'Segmentation completed',
            'output_filename': output_filename,
            'debug_file_path': output_path,
            'debug_file_exists': os.path.exists(output_path)
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
def process_video_segmentation(video_path, points):
    """Run SAM on every frame of *video_path* and write an overlaid MP4.

    The same foreground *points* are used as the prompt for every frame, so
    the prompt does not follow an object that moves away from them. Requires
    the module-level sam_predictor to be initialized first (segment_object
    does this).

    Args:
        video_path: Path of the source video.
        points: Sequence of [x, y] foreground prompts.

    Returns:
        Path of "<basename>_segmented.mp4" inside SEGMENTED_FOLDER.

    Raises:
        ValueError: the source video cannot be opened.
        RuntimeError: the output video writer cannot be opened.
    """
    # Used when the container reports no usable frame rate (0 or NaN).
    fallback_fps = 30.0

    # Create output filename
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    output_filename = f"{base_name}_segmented.mp4"
    output_path = os.path.join(app.config['SEGMENTED_FOLDER'], output_filename)

    # The prompt is identical for every frame, so build it once.
    input_points = np.array(points)
    input_labels = np.array([1] * len(points))  # 1 means foreground point

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = fallback_fps
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for: {output_path}")

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            print(f"Processing frame {frame_count}/{total_frames}")
            # SAM expects RGB, while OpenCV decodes frames as BGR.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            sam_predictor.set_image(frame_rgb)
            # multimask_output=False asks for a single mask.
            masks, _, _ = sam_predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False
            )
            mask = masks[0].astype(np.uint8) * 255
            out.write(apply_mask_to_frame(frame, mask))
    finally:
        # Release both handles even if SAM or OpenCV raises mid-video.
        cap.release()
        if out is not None:
            out.release()
    return output_path
def apply_mask_to_frame(frame, mask, alpha=0.5, color=(0, 0, 255)):
    """Tint the masked region of *frame* with a translucent colour.

    Args:
        frame: BGR image, as returned by cv2.VideoCapture.read().
        mask: Array with the frame's height and width; non-zero entries mark
            the object.
        alpha: Opacity of the tint, from 0 (invisible) to 1 (solid colour).
        color: BGR tint colour. The default (0, 0, 255) is red.

    Returns:
        A new image in which only the masked pixels are blended with *color*;
        every other pixel keeps its original value.
    """
    # Paint the colour onto a copy, then blend the copy with the original.
    # Only masked pixels differ between the two images, so the rest of the
    # frame keeps its full brightness. Blending against an all-black
    # background would dim it by `alpha`.
    overlay = frame.copy()
    overlay[mask > 0] = color
    return cv2.addWeighted(frame, 1 - alpha, overlay, alpha, 0)
@app.route('/download/<filename>')
def download_file(filename):
    """Return a file from SEGMENTED_FOLDER as a download (attachment).

    Path resolution is delegated to send_from_directory, which Flask
    documents as the safe way to serve user-named files from a folder.
    """
    folder = app.config['SEGMENTED_FOLDER']
    return send_from_directory(folder, filename, as_attachment=True)
@app.route('/preview/<filename>')
def preview_video(filename):
    """Serve an uploaded (unprocessed) video from UPLOAD_FOLDER."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
@app.route('/segmented/<filename>')
def serve_segmented_video(filename):
    """Serve a segmented video with the MIME type "video/mp4".

    conditional=True lets send_from_directory answer conditional and HTTP
    Range requests, which the route docstring relies on for proper range
    support. A JSON 404 is returned when the file does not exist.
    """
    folder = app.config['SEGMENTED_FOLDER']
    file_path = os.path.join(folder, filename)
    present = os.path.exists(file_path)
    print(f"🎬 Video request for: {filename}")
    print(f"📁 Looking for file at: {file_path}")
    print(f"🔍 File exists: {present}")
    if not present:
        print(f"❌ File not found: {file_path}")
        return jsonify({'error': f'File {filename} not found'}), 404
    print(f"✅ Serving file: {file_path}")
    return send_from_directory(folder, filename, conditional=True, mimetype='video/mp4')
if __name__ == '__main__':
    # Debug mode and port come from the environment (load_dotenv() has already
    # run). The defaults keep the original behaviour: debug on, port 5000.
    # NOTE: debug mode enables the interactive Werkzeug debugger, which can
    # execute arbitrary code; set FLASK_DEBUG=0 wherever the server is
    # reachable by others.
    debug = os.getenv('FLASK_DEBUG', '1').strip().lower() in ('1', 'true', 'yes', 'on')
    port = int(os.getenv('PORT', '5000'))
    app.run(debug=debug, port=port)