365 lines
12 KiB
Python
365 lines
12 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
from flask import Flask, request, jsonify, send_from_directory, render_template
|
|
from flask_cors import CORS
|
|
from werkzeug.utils import secure_filename
|
|
from dotenv import load_dotenv
|
|
import torch
|
|
from segment_anything import SamPredictor, sam_model_registry
|
|
import tempfile
|
|
import base64
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
|
|
# Load environment variables (e.g. from a local .env file) before reading config
load_dotenv()

app = Flask(__name__)
CORS(app)

# Configuration
app.config['UPLOAD_FOLDER'] = os.getenv('UPLOAD_FOLDER', 'uploads')
app.config['SEGMENTED_FOLDER'] = os.getenv('SEGMENTED_FOLDER', 'segmented')
# Normalise every entry to the form ".ext" (trimmed, lower-case, single leading
# dot) so values such as ".mp4, .AVI" or "mp4" still match the lower-cased
# extension that allowed_file() compares against. Empty entries are dropped.
app.config['ALLOWED_EXTENSIONS'] = {
    '.' + ext.strip().lower().lstrip('.')
    for ext in os.getenv('ALLOWED_EXTENSIONS', '.mp4,.avi,.mov,.mkv').split(',')
    if ext.strip().strip('.')
}

# Ensure the upload and output directories exist before any request is served
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SEGMENTED_FOLDER'], exist_ok=True)
|
|
|
|
# Initialize the SAM predictor
def initialize_sam2(model_size="vit_b"):
    """Build a ``SamPredictor`` for the requested model variant.

    Despite the function name and log messages, the model is loaded through
    the ``segment_anything`` package using its ``sam_vit_*`` checkpoints.

    Args:
        model_size: One of ``"vit_h"``, ``"vit_l"`` or ``"vit_b"``.

    Returns:
        A ``SamPredictor`` wrapping the loaded model, placed on CUDA when
        ``torch.cuda.is_available()`` is true and on the CPU otherwise.

    Raises:
        ValueError: If ``model_size`` is not a known variant.
        FileNotFoundError: If the checkpoint file for the variant does not
            exist (the path is relative to the current working directory).
        ImportError: If the ``segment_anything`` package cannot be imported.
    """
    print(f"Initializing SAM2 model ({model_size})...")

    # Map model sizes to checkpoint files
    model_configs = {
        "vit_h": {
            "checkpoint": "sam_vit_h_4b8939.pth",
            "model_type": "vit_h",
        },
        "vit_l": {
            "checkpoint": "sam_vit_l_0b3195.pth",
            "model_type": "vit_l",
        },
        "vit_b": {
            "checkpoint": "sam_vit_b_01ec64.pth",
            "model_type": "vit_b",
        },
    }

    # Validate the cheap things first so a configuration mistake is reported
    # as such instead of being hidden behind an import failure.
    if model_size not in model_configs:
        raise ValueError(f"Unknown model size: {model_size}. Choose from: vit_h, vit_l, vit_b")

    config = model_configs[model_size]
    sam_checkpoint = config["checkpoint"]
    model_type = config["model_type"]

    # Check if checkpoint file exists
    if not os.path.exists(sam_checkpoint):
        raise FileNotFoundError(
            f"SAM2 checkpoint file '{sam_checkpoint}' not found. "
            f"Please download it from https://github.com/facebookresearch/segment-anything "
            f"and place it in the root directory."
        )

    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError as err:
        # Chain the original error so the underlying cause stays visible.
        raise ImportError(
            "SAM is not installed. Please install it manually from GitHub:\n"
            "git clone https://github.com/facebookresearch/segment-anything.git\n"
            "cd segment-anything\n"
            "pip install -e .\n"
            "Then download the model checkpoint and place it in the root directory."
        ) from err

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    print(f"SAM2 model ({model_type}) initialized on {device}")
    return predictor
|
|
|
|
# Module-wide predictor handle; it starts out as None and is filled in by the
# /segment handler the first time a segmentation is requested.
sam_predictor = None

# Which model variant to load: taken from the SAM_MODEL_SIZE environment
# variable, falling back to the ViT-B variant when it is not set.
SAM_MODEL_SIZE = os.environ.get('SAM_MODEL_SIZE', 'vit_b')
|
|
|
|
def allowed_file(filename):
    """Tell whether *filename* carries one of the configured extensions.

    The text after the last dot is lower-cased, prefixed with a dot and looked
    up in ``app.config['ALLOWED_EXTENSIONS']``. Names without any dot are
    rejected straight away.
    """
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return False

    # Extension including the leading dot (e.g. '.mp4')
    file_extension = '.' + suffix.lower()
    permitted = app.config['ALLOWED_EXTENSIONS']

    # Debug output
    print(f"🔍 Checking extension: {file_extension}")
    print(f"📋 Allowed extensions: {permitted}")

    return file_extension in permitted
|
|
|
|
@app.route('/')
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
|
|
|
@app.route('/test')
def test():
    """Respond with a fixed JSON payload confirming the app is reachable."""
    payload = {'status': 'ok', 'message': 'Flask app is running'}
    return jsonify(payload)
|
|
|
|
@app.route('/upload', methods=['POST'])
def upload_video():
    """Handle a video upload sent as the multipart form field ``file``.

    The file is stored in ``UPLOAD_FOLDER`` under a sanitised name and the
    first frame is extracted as a preview.

    Returns:
        On success, JSON with ``message``, ``filename`` (the stored name, to be
        passed to ``/segment``) and ``preview`` (base64 JPEG text, or ``None``
        when no frame could be extracted). On failure, JSON with ``error`` and
        HTTP status 400.
    """
    import uuid  # only needed for the fallback file name below

    print("📤 Upload request received")

    if 'file' not in request.files:
        print("❌ No file part in request")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    # Covers both an empty string and a missing (None) file name
    if not file.filename:
        print("❌ No selected file")
        return jsonify({'error': 'No selected file'}), 400

    original_name = file.filename
    print(f"📁 File received: {original_name}")

    # Measure the upload by seeking to its end instead of reading the whole
    # video into memory just to log its size.
    file.seek(0, os.SEEK_END)
    print(f"📊 File size: {file.tell()} bytes")
    file.seek(0)  # Reset file pointer before saving

    # Debug file extension
    file_extension = original_name.rsplit('.', 1)[1].lower() if '.' in original_name else ''
    print(f"🔍 File extension: .{file_extension}")
    print(f"📋 Allowed extensions: {app.config['ALLOWED_EXTENSIONS']}")

    if not allowed_file(original_name):
        print(f"❌ File type not allowed: {file_extension}")
        return jsonify({'error': f'File type .{file_extension} not allowed. Allowed types: {app.config["ALLOWED_EXTENSIONS"]}'}), 400

    filename = secure_filename(original_name)
    if not allowed_file(filename):
        # Sanitising removed the stem and with it the extension separator
        # (e.g. a name made only of non-ASCII characters). Generate a unique
        # name that keeps the already-validated extension, so such uploads
        # neither collide with each other nor lose their file type.
        filename = f"upload_{uuid.uuid4().hex}.{file_extension}"

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    print(f"✅ File saved: {filepath}")

    # Extract first frame for preview
    preview_frame = extract_first_frame(filepath)

    if preview_frame is None:
        print("⚠️ Could not extract preview frame, using placeholder")
        return jsonify({
            'message': 'File uploaded successfully (no preview available)',
            'filename': filename,
            'preview': None
        })

    print("🖼️ Preview frame extracted successfully")
    return jsonify({
        'message': 'File uploaded successfully',
        'filename': filename,
        'preview': preview_frame
    })
|
|
|
|
def extract_first_frame(video_path):
    """Return the first frame of a video as base64-encoded JPEG text.

    Args:
        video_path: Path of the video file to read.

    Returns:
        The JPEG-encoded first frame as a base64 ``str``, or ``None`` when the
        file is missing, cannot be opened, yields no frame, cannot be encoded,
        or any other exception occurs. Failures are logged, never raised.
    """
    cap = None
    try:
        # Check if file exists
        if not os.path.exists(video_path):
            print(f"❌ Video file not found: {video_path}")
            return None

        cap = cv2.VideoCapture(video_path)

        # Check if video opened successfully
        if not cap.isOpened():
            print(f"❌ Could not open video file: {video_path}")
            return None

        ret, frame = cap.read()
        if not ret or frame is None:
            print(f"❌ Could not read first frame from {video_path}")
            return None

        # Convert to base64 for easy transfer
        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            print("❌ Failed to encode frame as JPEG")
            return None

        frame_base64 = base64.b64encode(buffer).decode('utf-8')
        print(f"✅ Successfully extracted first frame from {video_path}")
        return frame_base64
    except Exception as e:
        # Best-effort preview: report the problem and let the caller continue
        print(f"❌ Error extracting first frame: {e}")
        return None
    finally:
        # Release the capture on every path, including early returns and errors
        if cap is not None:
            cap.release()
|
|
|
|
@app.route('/segment', methods=['POST'])
def segment_object():
    """Run segmentation on a previously uploaded video.

    Expects a JSON body with ``filename`` (the name returned by ``/upload``)
    and ``points`` (a non-empty list of ``[x, y]`` pairs used as foreground
    prompts).

    Returns:
        On success, JSON with ``message``, ``output_filename``,
        ``debug_file_path`` and ``debug_file_exists``. Otherwise JSON with
        ``error`` and status 400 (bad request body), 404 (video not found) or
        500 (model initialisation or processing failed).
    """
    print("🎯 Segment request received")

    global sam_predictor

    # silent=True turns a missing or malformed JSON body into None, so it is
    # reported through the JSON 400 response below instead of a framework error.
    data = request.get_json(silent=True)
    print(f"📥 Received data: {data}")

    if not isinstance(data, dict) or 'filename' not in data or 'points' not in data:
        print("❌ Missing required parameters")
        return jsonify({'error': 'Missing required parameters'}), 400

    points = data['points']  # Expecting [[x1, y1], [x2, y2], ...]
    points_ok = (
        isinstance(points, list)
        and len(points) > 0
        and all(
            isinstance(pt, (list, tuple))
            and len(pt) == 2
            and all(isinstance(coord, (int, float)) for coord in pt)
            for pt in points
        )
    )
    if not points_ok:
        print("❌ Invalid points parameter")
        return jsonify({'error': 'points must be a non-empty list of [x, y] pairs'}), 400

    # The name comes from the client: strip directory components so it cannot
    # point outside the upload folder. Names handed out by /upload are already
    # sanitised this way, so they pass through unchanged.
    filename = secure_filename(str(data['filename']))
    video_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    if not filename or not os.path.isfile(video_path):
        return jsonify({'error': 'Video file not found'}), 404

    try:
        # Load the model lazily, only once the request is known to be valid.
        # Doing it inside the try block reports load failures as JSON too.
        if sam_predictor is None:
            print("🔧 Initializing SAM model...")
            sam_predictor = initialize_sam2(SAM_MODEL_SIZE)

        # Process the video
        output_path = process_video_segmentation(video_path, points)

        output_filename = os.path.basename(output_path)
        print(f"✅ Segmentation completed: {output_filename}")
        print(f"📁 Output file path: {output_path}")
        print(f"🔍 File exists: {os.path.exists(output_path)}")
        print(f"📊 File size: {os.path.getsize(output_path)} bytes")

        return jsonify({
            'message': 'Segmentation completed',
            'output_filename': output_filename,
            'debug_file_path': output_path,
            'debug_file_exists': os.path.exists(output_path)
        })

    except Exception as e:
        # Request boundary: log the failure and report it to the client
        print(f"❌ Segmentation failed: {e}")
        return jsonify({'error': str(e)}), 500
|
|
|
|
def process_video_segmentation(video_path, points):
    """Segment every frame of a video and write the overlaid result.

    Each frame is passed to the module-level ``sam_predictor`` with the same
    prompt points (all labelled as foreground). The predicted mask is blended
    onto the frame with ``apply_mask_to_frame`` and written to
    ``<SEGMENTED_FOLDER>/<input stem>_segmented.mp4``.

    Args:
        video_path: Path of the input video.
        points: Sequence of ``[x, y]`` prompt coordinates.

    Returns:
        Path of the written output video.

    Raises:
        RuntimeError: If the predictor has not been initialised, or the input
            video or the output writer cannot be opened.
    """
    if sam_predictor is None:
        raise RuntimeError("SAM predictor is not initialized")

    # Create output filename
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    output_filename = f"{base_name}_segmented.mp4"
    output_path = os.path.join(app.config['SEGMENTED_FOLDER'], output_filename)

    # Open video
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video file: {video_path}")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The frame rate was reported as zero or invalid. A writer cannot
            # use that, so fall back to a conventional default.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Could not open output video for writing: {output_path}")

        # The prompts are the same for every frame, so build the arrays once
        input_points = np.array(points)
        input_labels = np.array([1] * len(points))  # 1 means foreground point

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            print(f"Processing frame {frame_count}/{total_frames}")

            # Convert frame to RGB (SAM expects RGB)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Set image for SAM
            sam_predictor.set_image(frame_rgb)

            # Get masks
            masks, scores, logits = sam_predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False
            )

            # Turn the first returned mask into a 0/255 uint8 image
            mask = masks[0].astype(np.uint8) * 255

            # Apply mask to frame (simple approach - you can customize this)
            masked_frame = apply_mask_to_frame(frame, mask)

            # Write frame
            out.write(masked_frame)
    finally:
        # Always release the reader and writer, even when prediction fails
        cap.release()
        if out is not None:
            out.release()

    return output_path
|
|
|
|
def apply_mask_to_frame(frame, mask):
    """Blend a single-channel mask into *frame* as a colour overlay.

    An all-zero image shaped like *frame* gets *mask* written into channel
    index 2 (the red channel, assuming the usual OpenCV BGR layout). Frame and
    overlay are then mixed with equal weights by ``cv2.addWeighted``. The blend
    covers the whole image, so pixels outside the mask are scaled by the frame
    weight as well.
    """
    opacity = 0.5

    # Overlay that is black everywhere except for the mask in channel 2
    overlay = np.zeros_like(frame)
    overlay[..., 2] = mask

    return cv2.addWeighted(frame, 1 - opacity, overlay, opacity, 0)
|
|
|
|
@app.route('/download/<filename>')
def download_file(filename):
    """Send a file from the segmented-output folder as an attachment."""
    segmented_dir = app.config['SEGMENTED_FOLDER']
    return send_from_directory(segmented_dir, filename, as_attachment=True)
|
|
|
|
@app.route('/preview/<filename>')
def preview_video(filename):
    """Serve an uploaded (original) video from the upload folder."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
|
|
|
|
@app.route('/segmented/<filename>')
def serve_segmented_video(filename):
    """Serve a segmented video as ``video/mp4`` with conditional responses.

    Responds with a JSON 404 when the file is not present in the segmented
    folder. Otherwise delegates to ``send_from_directory`` with
    ``conditional=True``.
    """
    segmented_dir = app.config['SEGMENTED_FOLDER']
    file_path = os.path.join(segmented_dir, filename)

    print(f"🎬 Video request for: {filename}")
    print(f"📁 Looking for file at: {file_path}")
    print(f"🔍 File exists: {os.path.exists(file_path)}")

    if os.path.exists(file_path):
        print(f"✅ Serving file: {file_path}")
        # Explicit MIME type plus conditional handling for video playback
        send_options = {'conditional': True, 'mimetype': 'video/mp4'}
        return send_from_directory(segmented_dir, filename, **send_options)

    print(f"❌ File not found: {file_path}")
    return jsonify({'error': f'File {filename} not found'}), 404
|
|
|
|
if __name__ == '__main__':
    # Development entry point. The defaults keep the previous behaviour (debug
    # on, port 5000) but can now be overridden without editing the file, e.g.
    # FLASK_DEBUG=0 or FLASK_RUN_PORT=8080.
    # NOTE: debug mode enables the interactive debugger. Switch it off
    # whenever the server is reachable by anyone other than the developer.
    debug_enabled = os.getenv('FLASK_DEBUG', '1').strip().lower() not in ('', '0', 'false', 'no', 'off')
    app.run(debug=debug_enabled, port=int(os.getenv('FLASK_RUN_PORT', '5000')))
|