|
from flask import Flask, request, jsonify, send_file, render_template_string |
|
import os |
|
import cv2 |
|
from rembg import new_session, remove |
|
from rembg.sessions import sessions_class |
|
import base64 |
|
import uuid |
|
from flask_cors import CORS |
|
|
|
app = Flask(__name__)
CORS(app)


def _download_all_models():
    """Call ``download_models()`` on every session class rembg registers.

    Runs once at import time so model weights are fetched before the server
    starts handling requests.
    """
    for session_cls in sessions_class:
        session_cls.download_models()


_download_all_models()
|
|
|
def process_image(file_path, mask, model, x, y):
    """Run rembg on the image at *file_path* and write the result to a new PNG.

    The image is decoded with OpenCV as 3-channel colour and re-encoded as PNG
    in memory before being handed to rembg, so any format OpenCV can read is
    accepted.

    Args:
        file_path: Path of the uploaded image on disk.
        mask: Output type; ``"Mask only"`` asks rembg for the mask only,
            any other value returns the cut-out image.
        model: rembg model name passed to ``new_session``.
        x, y: Optional point prompt (used by the ``sam`` model). The prompt is
            forwarded only when both values are provided.

    Returns:
        Path of a freshly written ``temp_output_<hex>.png`` in the current
        working directory. The caller owns this file and should delete it.

    Raises:
        ValueError: if OpenCV cannot decode or re-encode the image.
    """
    im = cv2.imread(file_path, cv2.IMREAD_COLOR)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {os.path.basename(file_path)}")

    # Encode to PNG in memory; equivalent to imwrite + read, but without a
    # temporary file that could be left behind if a later step fails.
    ok, encoded = cv2.imencode(".png", im)
    if not ok:
        raise ValueError("Could not encode image as PNG")
    input_data = encoded.tobytes()

    remove_kwargs = {"only_mask": mask == "Mask only"}
    if x is not None and y is not None:
        # Only send a point prompt when real coordinates exist; a
        # [None, None] point is never a meaningful prompt.
        remove_kwargs["sam_prompt"] = [
            {"type": "point", "data": [x, y], "label": 1}
        ]

    session = new_session(model)
    output = remove(input_data, session=session, **remove_kwargs)

    # Create the output file only after rembg succeeded, so failures do not
    # leave empty PNGs on disk.
    output_path = f"temp_output_{uuid.uuid4().hex}.png"
    with open(output_path, "wb") as o:
        o.write(output)

    return output_path
|
|
|
@app.route('/api/process', methods=['POST'])
def api_process():
    """Process an uploaded image with rembg and return the result as PNG.

    Multipart form fields:
        file:  the image to process (required).
        mask:  ``"Default"`` or ``"Mask only"`` (default ``"Default"``).
        model: rembg model name (default ``"isnet-general-use"``).
        x, y:  optional point prompt for the SAM model; if either value is
               not a number, both are ignored.

    Returns:
        The processed image (``image/png``) on success, otherwise a JSON body
        ``{"error": ...}`` with status 400 (no file) or 500 (processing error).
    """
    import io  # local import: only needed to serve the result from memory

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    file = request.files['file']
    if file.filename == '':
        # Browsers submit an empty part with an empty filename when no file
        # was selected.
        return jsonify({'error': 'No file uploaded'}), 400

    mask = request.form.get('mask', 'Default')
    model = request.form.get('model', 'isnet-general-use')
    x = request.form.get('x', None)
    y = request.form.get('y', None)

    try:
        x = float(x) if x is not None else None
        y = float(y) if y is not None else None
    except (TypeError, ValueError):
        # An unparsable coordinate invalidates the whole point prompt.
        x = None
        y = None

    temp_input = f"temp_{uuid.uuid4().hex}.png"
    output_path = None
    try:
        file.save(temp_input)
        output_path = process_image(temp_input, mask, model, x, y)
        # Load the result into memory so the temp file can be removed in
        # `finally`; otherwise every request leaves a PNG on disk.
        with open(output_path, 'rb') as fh:
            payload = io.BytesIO(fh.read())
        return send_file(payload, mimetype='image/png')
    except Exception as e:
        app.logger.exception("Image processing failed")
        return jsonify({'error': str(e)}), 500
    finally:
        for path in (temp_input, output_path):
            if path and os.path.exists(path):
                os.remove(path)
|
|
|
# Single-page browser UI served at "/". It is rendered with
# render_template_string (Jinja), so the markup must not contain the Jinja
# delimiters "{{", "{%" or "{#".
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <title>RemBG API</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        .container {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }
        .row {
            display: flex;
            gap: 20px;
        }
        .column {
            flex: 1;
        }
        img {
            max-width: 100%;
            height: auto;
            border: 1px solid #ddd;
        }
        .form-group {
            margin-bottom: 15px;
        }
        label {
            display: block;
            margin-bottom: 5px;
            font-weight: bold;
        }
        select, input, button {
            width: 100%;
            padding: 8px;
            box-sizing: border-box;
        }
        button {
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
            padding: 10px;
        }
        button:hover {
            background-color: #45a049;
        }
        #fetch-code {
            width: 100%;
            height: 150px;
            font-family: monospace;
            padding: 10px;
            box-sizing: border-box;
            background-color: #f5f5f5;
            border: 1px solid #ddd;
        }
        .coords-input {
            display: none;
        }
    </style>
</head>
<body>
    <h1>RemBG API</h1>
    <p>Upload an image to process with RemBG. Select options and click "Process Image".</p>

    <div class="container">
        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="file">Input Image:</label>
                    <input type="file" id="file" accept="image/*">
                </div>
                <img id="input-image" src="" alt="Input image will appear here">
            </div>
            <div class="column">
                <div class="form-group">
                    <label>Output Image:</label>
                    <img id="output-image" src="" alt="Output image will appear here">
                </div>
            </div>
        </div>

        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="mask">Output Type:</label>
                    <select id="mask">
                        <option value="Default">Default</option>
                        <option value="Mask only">Mask only</option>
                    </select>
                </div>
            </div>
            <div class="column">
                <div class="form-group">
                    <label for="model">Model Selection:</label>
                    <select id="model">
                        <option value="u2net">u2net</option>
                        <option value="u2netp">u2netp</option>
                        <option value="u2net_human_seg">u2net_human_seg</option>
                        <option value="u2net_cloth_seg">u2net_cloth_seg</option>
                        <option value="silueta">silueta</option>
                        <option value="isnet-general-use" selected>isnet-general-use</option>
                        <option value="isnet-anime">isnet-anime</option>
                        <option value="sam">sam</option>
                        <option value="birefnet-general">birefnet-general</option>
                        <option value="birefnet-general-lite">birefnet-general-lite</option>
                        <option value="birefnet-portrait">birefnet-portrait</option>
                        <option value="birefnet-dis">birefnet-dis</option>
                        <option value="birefnet-hrsod">birefnet-hrsod</option>
                        <option value="birefnet-cod">birefnet-cod</option>
                        <option value="birefnet-massive">birefnet-massive</option>
                    </select>
                </div>
            </div>
        </div>

        <div id="coords-section" style="display: none;">
            <h3>SAM Model Coordinates</h3>
            <p>Click on the image to set coordinates (for SAM model only)</p>
            <div class="row">
                <div class="column">
                    <div class="form-group">
                        <label for="x">X Coordinate:</label>
                        <input type="number" id="x" class="coords-input">
                    </div>
                </div>
                <div class="column">
                    <div class="form-group">
                        <label for="y">Y Coordinate:</label>
                        <input type="number" id="y" class="coords-input">
                    </div>
                </div>
            </div>
        </div>

        <button id="process-btn">Process Image</button>

        <div class="form-group">
            <label for="fetch-code">Fetch Code:</label>
            <textarea id="fetch-code" readonly></textarea>
        </div>
    </div>

    <script>
        const fileInput = document.getElementById('file');
        const inputImage = document.getElementById('input-image');
        const outputImage = document.getElementById('output-image');
        const maskSelect = document.getElementById('mask');
        const modelSelect = document.getElementById('model');
        const xInput = document.getElementById('x');
        const yInput = document.getElementById('y');
        const coordsSection = document.getElementById('coords-section');
        const processBtn = document.getElementById('process-btn');
        const fetchCodeTextarea = document.getElementById('fetch-code');

        // 画像プレビュー
        fileInput.addEventListener('change', function(e) {
            const file = e.target.files[0];
            if (file) {
                // Coordinates picked on a previous image do not apply to a new one.
                xInput.value = '';
                yInput.value = '';
                const reader = new FileReader();
                reader.onload = function(event) {
                    inputImage.src = event.target.result;
                    updateFetchCode();
                };
                reader.readAsDataURL(file);
            }
        });

        // モデル選択でSAMの場合は座標入力表示
        modelSelect.addEventListener('change', function() {
            const isSam = modelSelect.value === 'sam';
            coordsSection.style.display = isSam ? 'block' : 'none';
            document.querySelectorAll('.coords-input').forEach(el => {
                el.style.display = isSam ? 'block' : 'none';
            });
            updateFetchCode();
        });

        // 画像クリックで座標取得 (SAMモデルのみ)
        inputImage.addEventListener('click', function(e) {
            if (modelSelect.value === 'sam') {
                // The preview is scaled by CSS (max-width: 100%), but the server
                // receives the full-resolution file. Convert the click position
                // from displayed pixels to the image's own pixel grid.
                const displayWidth = inputImage.clientWidth;
                const displayHeight = inputImage.clientHeight;
                if (!displayWidth || !displayHeight || !inputImage.naturalWidth) {
                    return;
                }
                const scaleX = inputImage.naturalWidth / displayWidth;
                const scaleY = inputImage.naturalHeight / displayHeight;
                const x = Math.min(Math.max(e.offsetX * scaleX, 0), inputImage.naturalWidth - 1);
                const y = Math.min(Math.max(e.offsetY * scaleY, 0), inputImage.naturalHeight - 1);

                xInput.value = Math.round(x);
                yInput.value = Math.round(y);
                updateFetchCode();
            }
        });

        // その他の入力変更時
        [maskSelect, xInput, yInput].forEach(el => {
            el.addEventListener('change', updateFetchCode);
        });

        // 画像処理
        processBtn.addEventListener('click', async function() {
            if (!fileInput.files || fileInput.files.length === 0) {
                alert('Please select an image file');
                return;
            }

            const formData = new FormData();
            formData.append('file', fileInput.files[0]);
            formData.append('mask', maskSelect.value);
            formData.append('model', modelSelect.value);

            if (modelSelect.value === 'sam' && xInput.value && yInput.value) {
                formData.append('x', xInput.value);
                formData.append('y', yInput.value);
            }

            try {
                const response = await fetch('/api/process', {
                    method: 'POST',
                    body: formData
                });

                if (!response.ok) {
                    // A non-JSON error body (e.g. from a proxy) must not hide the failure.
                    const error = await response.json().catch(() => ({}));
                    throw new Error(error.error || 'Failed to process image');
                }

                const blob = await response.blob();
                // Release the previous result so repeated runs do not leak blobs.
                if (outputImage.src.startsWith('blob:')) {
                    URL.revokeObjectURL(outputImage.src);
                }
                outputImage.src = URL.createObjectURL(blob);
            } catch (error) {
                alert('Error: ' + error.message);
                console.error(error);
            }
        });

        // Fetchコード生成
        function updateFetchCode() {
            const file = fileInput.files && fileInput.files[0];
            if (!file) {
                fetchCodeTextarea.value = '// Select an image first';
                return;
            }

            const mask = maskSelect.value;
            const model = modelSelect.value;
            const x = xInput.value;
            const y = yInput.value;

            let code = `const formData = new FormData();\n`;
            code += `formData.append('file', fileInput.files[0]);\n`;
            code += `formData.append('mask', '${mask}');\n`;
            code += `formData.append('model', '${model}');\n`;

            if (model === 'sam' && x && y) {
                code += `formData.append('x', '${x}');\n`;
                code += `formData.append('y', '${y}');\n`;
            }

            code += `\n`;
            // Use the page's real origin so the snippet also works over https.
            code += `fetch('${window.location.origin}/api/process', {\n`;
            code += ` method: 'POST',\n`;
            code += ` body: formData\n`;
            code += `})\n`;
            code += `.then(response => {\n`;
            code += ` if (!response.ok) {\n`;
            code += ` return response.json().then(err => { throw new Error(err.error); });\n`;
            code += ` }\n`;
            code += ` return response.blob();\n`;
            code += `})\n`;
            code += `.then(blob => {\n`;
            code += ` // Handle the processed image blob\n`;
            code += ` const imgUrl = URL.createObjectURL(blob);\n`;
            code += ` document.getElementById('output-image').src = imgUrl;\n`;
            code += `})\n`;
            code += `.catch(error => {\n`;
            code += ` console.error('Error:', error);\n`;
            code += ` alert('Error: ' + error.message);\n`;
            code += `});`;

            fetchCodeTextarea.value = code;
        }

        // 初期化
        updateFetchCode();
    </script>
</body>
</html>
"""
|
|
|
@app.route('/', methods=['GET'])
def index():
    """Serve the browser UI by rendering HTML_TEMPLATE through Jinja."""
    page = render_template_string(HTML_TEMPLATE)
    return page
|
|
|
if __name__ == '__main__':
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')