Initial Commit
This commit is contained in:
364
app.py
Normal file
364
app.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify, send_from_directory, render_template
|
||||
from flask_cors import CORS
|
||||
from werkzeug.utils import secure_filename
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
import tempfile
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
# Load environment variables (e.g. from a local .env file) before any
# os.getenv() lookup below reads them.
load_dotenv()

app = Flask(__name__)
CORS(app)

# Configuration
app.config['UPLOAD_FOLDER'] = os.getenv('UPLOAD_FOLDER', 'uploads')
app.config['SEGMENTED_FOLDER'] = os.getenv('SEGMENTED_FOLDER', 'segmented')

# Extensions are compared against a dotted, lower-case suffix such as '.mp4',
# so normalise whatever the operator typed: tolerate spaces after commas,
# upper-case entries, a missing leading dot and empty items ("a,,b").
app.config['ALLOWED_EXTENSIONS'] = {
    ext if ext.startswith('.') else f'.{ext}'
    for ext in (
        item.strip().lower()
        for item in os.getenv('ALLOWED_EXTENSIONS', '.mp4,.avi,.mov,.mkv').split(',')
    )
    if ext
}

# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SEGMENTED_FOLDER'], exist_ok=True)
|
||||
|
||||
# Initialize SAM2 model
|
||||
def initialize_sam2(model_size="vit_b", checkpoint_dir=None):
    """Build a Segment Anything predictor for the requested backbone.

    NOTE: despite "SAM2" in the function name and log messages, the model is
    loaded through the ``segment_anything`` package.

    Args:
        model_size: One of ``"vit_h"``, ``"vit_l"`` or ``"vit_b"``.
        checkpoint_dir: Directory holding the ``.pth`` checkpoint. When
            ``None`` the ``SAM_CHECKPOINT_DIR`` environment variable is used,
            and if that is unset the checkpoint is looked up relative to the
            current working directory (the previous behaviour).

    Returns:
        A ``SamPredictor`` wrapping the loaded model, moved to CUDA when
        ``torch.cuda.is_available()`` and to the CPU otherwise.

    Raises:
        ValueError: ``model_size`` is not a known backbone.
        ImportError: the ``segment_anything`` package cannot be imported.
        FileNotFoundError: the checkpoint file does not exist.
    """
    print(f"Initializing SAM2 model ({model_size})...")

    # Map model sizes to checkpoint files
    model_configs = {
        "vit_h": {
            "checkpoint": "sam_vit_h_4b8939.pth",
            "model_type": "vit_h",
        },
        "vit_l": {
            "checkpoint": "sam_vit_l_0b3195.pth",
            "model_type": "vit_l",
        },
        "vit_b": {
            "checkpoint": "sam_vit_b_01ec64.pth",
            "model_type": "vit_b",
        },
    }

    # Check the cheap argument first so a typo fails fast, before the heavy
    # import and any filesystem access.
    if model_size not in model_configs:
        raise ValueError(f"Unknown model size: {model_size}. Choose from: vit_h, vit_l, vit_b")

    try:
        from segment_anything import SamPredictor, sam_model_registry
    except ImportError as err:
        # Chain the original error so the real cause stays in the traceback.
        raise ImportError(
            "SAM is not installed. Please install it manually from GitHub:\n"
            "git clone https://github.com/facebookresearch/segment-anything.git\n"
            "cd segment-anything\n"
            "pip install -e .\n"
            "Then download the model checkpoint and place it in the root directory."
        ) from err

    config = model_configs[model_size]
    model_type = config["model_type"]

    if checkpoint_dir is None:
        checkpoint_dir = os.getenv('SAM_CHECKPOINT_DIR', '')
    # os.path.join('', name) == name, so the default keeps the old lookup.
    sam_checkpoint = os.path.join(checkpoint_dir, config["checkpoint"])

    # Check if checkpoint file exists
    if not os.path.isfile(sam_checkpoint):
        raise FileNotFoundError(
            f"SAM2 checkpoint file '{sam_checkpoint}' not found. "
            "Please download it from https://github.com/facebookresearch/segment-anything "
            "and place it in the root directory."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    print(f"SAM2 model ({model_type}) initialized on {device}")
    return predictor
|
||||
|
||||
# Shared predictor, created lazily: it stays None until the /segment handler
# builds it on first use.
sam_predictor = None

# Which SAM backbone to load. Taken from the SAM_MODEL_SIZE environment
# variable (e.g. set in .env); falls back to the ViT-B variant.
SAM_MODEL_SIZE = os.environ.get('SAM_MODEL_SIZE', 'vit_b')
|
||||
|
||||
def allowed_file(filename):
    """Return True when *filename* ends in one of the configured extensions.

    The text after the last '.' is lower-cased, prefixed with a dot and looked
    up in ``app.config['ALLOWED_EXTENSIONS']``. A name without any '.' is
    rejected outright.
    """
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return False

    # Dotted, lower-case form (e.g. '.mp4') to match the configured entries
    file_extension = '.' + suffix.lower()

    # Debug
    print(f"🔍 Checking extension: {file_extension}")
    print(f"📋 Allowed extensions: {app.config['ALLOWED_EXTENSIONS']}")

    permitted = app.config['ALLOWED_EXTENSIONS']
    return file_extension in permitted
|
||||
|
||||
@app.route('/')
def index():
    """Serve the landing page by rendering the 'index.html' template."""
    page = render_template('index.html')
    return page
|
||||
|
||||
@app.route('/test')
def test():
    """Liveness probe: reply with a fixed JSON payload."""
    payload = {'status': 'ok', 'message': 'Flask app is running'}
    return jsonify(payload)
|
||||
|
||||
@app.route('/upload', methods=['POST'])
def upload_video():
    """Accept a video upload, store it and return a first-frame preview.

    Expects a multipart form with a ``file`` part. On success the JSON reply
    carries ``message``, the stored ``filename`` (which the client must use in
    later requests) and ``preview`` — a base64 JPEG of the first frame, or
    ``None`` when no frame could be extracted.

    Returns 400 when the part is missing, has no filename, or has an extension
    that is not in ``app.config['ALLOWED_EXTENSIONS']``.
    """
    import uuid  # local import: only needed for the fallback filename below

    print("📤 Upload request received")

    if 'file' not in request.files:
        print("❌ No file part in request")
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    # Covers both '' and None, so secure_filename() never receives None.
    if not file.filename:
        print("❌ No selected file")
        return jsonify({'error': 'No selected file'}), 400

    print(f"📁 File received: {file.filename}")

    # Measure the upload by seeking rather than read()-ing it, so a large
    # video is not pulled into memory just to print a log line.
    file.seek(0, os.SEEK_END)
    print(f"📊 File size: {file.tell()} bytes")
    file.seek(0)  # Reset file pointer before saving

    raw_name = file.filename
    file_extension = raw_name.rsplit('.', 1)[1].lower() if '.' in raw_name else ''

    if not allowed_file(raw_name):
        print(f"❌ File type not allowed: {file_extension}")
        return jsonify({'error': f'File type .{file_extension} not allowed. Allowed types: {app.config["ALLOWED_EXTENSIONS"]}'}), 400

    filename = secure_filename(raw_name)
    if not allowed_file(filename):
        # Sanitising can strip the whole stem (e.g. a non-ASCII name) and
        # leave '' or a bare 'mp4'. Saving under that would either target the
        # upload directory itself or make unrelated uploads collide, so fall
        # back to a generated name that keeps the already-validated extension.
        filename = f"upload_{uuid.uuid4().hex}.{file_extension}"

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    print(f"✅ File saved: {filepath}")

    # Extract first frame for preview
    preview_frame = extract_first_frame(filepath)

    if preview_frame is None:
        print("⚠️ Could not extract preview frame, using placeholder")
        return jsonify({
            'message': 'File uploaded successfully (no preview available)',
            'filename': filename,
            'preview': None
        })

    print("🖼️ Preview frame extracted successfully")
    return jsonify({
        'message': 'File uploaded successfully',
        'filename': filename,
        'preview': preview_frame
    })
|
||||
|
||||
def extract_first_frame(video_path):
    """Return the first frame of a video as a base64-encoded JPEG string.

    This is a best-effort helper: every failure (missing file, unreadable
    video, encoding error, unexpected exception) is logged and reported as
    ``None`` instead of being raised.

    Args:
        video_path: Path of the video file to read.

    Returns:
        The base64 text of the JPEG-encoded first frame, or ``None``.
    """
    cap = None
    try:
        # Check if file exists
        if not os.path.exists(video_path):
            print(f"❌ Video file not found: {video_path}")
            return None

        cap = cv2.VideoCapture(video_path)

        # Check if video opened successfully
        if not cap.isOpened():
            print(f"❌ Could not open video file: {video_path}")
            return None

        ret, frame = cap.read()
        if not ret or frame is None:
            print(f"❌ Could not read first frame from {video_path}")
            return None

        # Convert to base64 for easy transfer
        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            print("❌ Failed to encode frame as JPEG")
            return None

        frame_base64 = base64.b64encode(buffer).decode('utf-8')
        print(f"✅ Successfully extracted first frame from {video_path}")
        return frame_base64
    except Exception as e:
        # Deliberately broad: a broken preview must not fail the upload.
        print(f"❌ Error extracting first frame: {e}")
        return None
    finally:
        # Release the capture on every exit path, including the early
        # returns and exceptions above.
        if cap is not None:
            cap.release()
|
||||
|
||||
@app.route('/segment', methods=['POST'])
def segment_object():
    """Segment an uploaded video from a set of foreground click points.

    Expects a JSON body ``{"filename": <name returned by /upload>,
    "points": [[x1, y1], [x2, y2], ...]}``.

    Returns:
        200 with ``message``, ``output_filename`` and two debug fields on
        success; 400 when the body is missing/malformed or ``points`` is not a
        non-empty list of numeric ``[x, y]`` pairs; 404 when the video is not
        in the upload folder; 500 with the error text when model loading or
        processing fails.
    """
    print("🎯 Segment request received")

    global sam_predictor

    # silent=True: a missing or malformed JSON body yields None, so the
    # request gets this handler's own 400 instead of a framework error page.
    data = request.get_json(silent=True)
    print(f"📥 Received data: {data}")

    if not isinstance(data, dict) or 'filename' not in data or 'points' not in data:
        print("❌ Missing required parameters")
        return jsonify({'error': 'Missing required parameters'}), 400

    points = data['points']  # Expecting [[x1, y1], [x2, y2], ...]
    points_ok = (
        isinstance(points, list)
        and len(points) > 0
        and all(
            isinstance(pt, (list, tuple))
            and len(pt) == 2
            # bool is an int subclass; reject it explicitly
            and all(isinstance(c, (int, float)) and not isinstance(c, bool) for c in pt)
            for pt in points
        )
    )
    if not points_ok:
        print("❌ Invalid points")
        return jsonify({'error': 'points must be a non-empty list of [x, y] pairs'}), 400

    # The name comes from the client: sanitise it so '../' sequences cannot
    # point outside the upload folder. Names handed out by /upload were
    # produced by secure_filename or generated there, so they pass unchanged.
    filename = secure_filename(str(data['filename']))
    video_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    if not filename or not os.path.isfile(video_path):
        return jsonify({'error': 'Video file not found'}), 404

    try:
        # Load the model lazily, after validation and inside the try block so
        # a missing checkpoint is reported as a JSON error too.
        if sam_predictor is None:
            print("🔧 Initializing SAM model...")
            sam_predictor = initialize_sam2(SAM_MODEL_SIZE)

        # Process the video
        output_path = process_video_segmentation(video_path, points)

        output_filename = os.path.basename(output_path)
        print(f"✅ Segmentation completed: {output_filename}")
        print(f"📁 Output file path: {output_path}")
        print(f"🔍 File exists: {os.path.exists(output_path)}")
        print(f"📊 File size: {os.path.getsize(output_path)} bytes")

        return jsonify({
            'message': 'Segmentation completed',
            'output_filename': output_filename,
            'debug_file_path': output_path,
            'debug_file_exists': os.path.exists(output_path)
        })

    except Exception as e:
        print(f"❌ Segmentation failed: {e}")
        return jsonify({'error': str(e)}), 500
|
||||
|
||||
def process_video_segmentation(video_path, points):
    """Write a copy of the video with the predicted mask overlaid per frame.

    The same point prompts are applied to every frame independently; nothing
    here tracks the object between frames. The module-level ``sam_predictor``
    must already be initialised.

    Args:
        video_path: Path of the source video.
        points: ``[[x, y], ...]`` pixel coordinates, all treated as
            foreground prompts.

    Returns:
        Path of the written ``<name>_segmented.mp4`` inside
        ``app.config['SEGMENTED_FOLDER']``.

    Raises:
        RuntimeError: the source video cannot be opened or the output writer
            cannot be created.
    """
    # Create output filename
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    output_filename = f"{base_name}_segmented.mp4"
    output_path = os.path.join(app.config['SEGMENTED_FOLDER'], output_filename)

    # The prompts do not change between frames, so build the arrays once.
    input_points = np.array(points)
    input_labels = np.array([1] * len(points))  # 1 means foreground point

    # Open video
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video file: {video_path}")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # The reported rate may be 0 or NaN; fall back to a sane rate
            # rather than hand the writer an unusable value.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Create video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Could not create output video: {output_path}")

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            print(f"Processing frame {frame_count}/{total_frames}")

            # Convert frame to RGB (SAM expects RGB)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Set image for SAM
            sam_predictor.set_image(frame_rgb)

            # Get masks
            masks, scores, logits = sam_predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False
            )

            # multimask_output=False was requested, so use the first mask;
            # scale the boolean mask to 0/255 for blending.
            mask = masks[0].astype(np.uint8) * 255

            # Apply mask to frame (simple approach - you can customize this)
            masked_frame = apply_mask_to_frame(frame, mask)

            # Write frame
            out.write(masked_frame)
    finally:
        # Release both handles even when a frame fails mid-way, so the output
        # file is closed and the source is not left locked.
        cap.release()
        if out is not None:
            out.release()

    return output_path
|
||||
|
||||
def apply_mask_to_frame(frame, mask):
    """Blend a single-channel mask into *frame* as a half-strength overlay.

    The mask is copied into channel index 2 of an otherwise all-zero image
    shaped like *frame* (the red channel, per the original comment), and that
    image is mixed 50/50 with the frame. The blend covers the whole frame,
    not only the masked region.
    """
    overlay_weight = 0.5
    frame_weight = 1 - overlay_weight

    # Overlay image: zeros everywhere except the channel carrying the mask
    overlay = np.zeros_like(frame)
    overlay[:, :, 2] = mask  # Red channel

    blended = cv2.addWeighted(frame, frame_weight, overlay, overlay_weight, 0)
    return blended
|
||||
|
||||
@app.route('/download/<filename>')
def download_file(filename):
    """Send a file from the segmented-output folder as an attachment."""
    segmented_dir = app.config['SEGMENTED_FOLDER']
    return send_from_directory(segmented_dir, filename, as_attachment=True)
|
||||
|
||||
@app.route('/preview/<filename>')
def preview_video(filename):
    """Serve an original (un-segmented) video from the upload folder."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
|
||||
|
||||
@app.route('/segmented/<filename>')
def serve_segmented_video(filename):
    """Serve a segmented video as 'video/mp4' for in-browser playback.

    Replies with a JSON 404 when the file is not in the segmented folder.
    Otherwise the file is sent with ``conditional=True``, which the original
    docstring describes as providing range request support.
    """
    segmented_dir = app.config['SEGMENTED_FOLDER']
    file_path = os.path.join(segmented_dir, filename)

    print(f"🎬 Video request for: {filename}")
    print(f"📁 Looking for file at: {file_path}")
    print(f"🔍 File exists: {os.path.exists(file_path)}")

    if os.path.exists(file_path):
        print(f"✅ Serving file: {file_path}")
        # Use send_from_directory with proper MIME type for video
        return send_from_directory(
            segmented_dir,
            filename,
            conditional=True,
            mimetype='video/mp4'
        )

    print(f"❌ File not found: {file_path}")
    return jsonify({'error': f'File {filename} not found'}), 404
|
||||
|
||||
if __name__ == '__main__':
    # debug=True turns on the interactive Werkzeug debugger, which lets anyone
    # who can reach the server execute code. Make it an explicit opt-in
    # (FLASK_DEBUG=1) instead of the hard-coded default.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    # The port can be overridden with PORT; 5000 remains the default.
    app.run(debug=debug_enabled, port=int(os.getenv('PORT', '5000')))
|
||||
Reference in New Issue
Block a user