RemBG-api / app.py
soiz1's picture
Update app.py
3e25509 verified
import base64
import io
import os
import uuid

import cv2
from flask import Flask, request, jsonify, send_file, render_template_string
from flask_cors import CORS
from rembg import new_session, remove
from rembg.sessions import sessions_class
app = Flask(__name__)
# Enable CORS on every route (flask-cors default settings) so the API can be
# called from pages served by other origins.
CORS(app)

# Session initialisation: ask every session class rembg registers to fetch
# its model files once at start-up -- presumably so requests do not have to
# download weights on first use. This covers all rembg sessions, not only
# the models listed in the web UI.
for session_cls in sessions_class:
    session_cls.download_models()
def process_image(file_path, mask, model, x, y):
    """Remove the background from the image at *file_path* using rembg.

    The image is decoded by OpenCV as 3-channel colour (so any alpha channel
    in the upload is discarded) and re-encoded to PNG in memory before being
    handed to rembg.

    Args:
        file_path: Path of the uploaded image on disk.
        mask: ``"Mask only"`` to return only the mask (``only_mask=True``);
            any other value returns the cut-out image.
        model: rembg session name, passed to ``new_session``.
        x, y: Optional point prompt for the SAM model. The prompt is only
            sent when both coordinates are not None.

    Returns:
        Path of a newly written PNG file containing the result. The caller
        owns this file and is responsible for deleting it.

    Raises:
        ValueError: if OpenCV cannot decode the image or re-encode it as PNG.
    """
    image = cv2.imread(file_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports unreadable/unsupported input by returning None
        # instead of raising; fail with a clear message.
        raise ValueError(f"Could not decode image: {file_path}")

    # Encode in memory; this yields the same PNG bytes cv2.imwrite would have
    # written, without a temporary input file that could be left behind.
    ok, encoded = cv2.imencode(".png", image)
    if not ok:
        raise ValueError("Could not re-encode image as PNG")

    extra_kwargs = {}
    if x is not None and y is not None:
        # Single point prompt (label 1 = foreground point) for the SAM session;
        # only sent when a complete coordinate pair is available.
        extra_kwargs["sam_prompt"] = [
            {"type": "point", "data": [x, y], "label": 1}
        ]

    session = new_session(model)
    output = remove(
        encoded.tobytes(),
        session=session,
        only_mask=(mask == "Mask only"),
        **extra_kwargs,
    )

    # Only create the output file once inference succeeded, so a failure
    # does not leave an empty file on disk.
    output_path = f"temp_output_{uuid.uuid4().hex}.png"
    with open(output_path, "wb") as out_file:
        out_file.write(output)
    return output_path
@app.route('/api/process', methods=['POST'])
def api_process():
    """Run background removal on an uploaded image and return the result.

    Expects a multipart form with:
        file:  the image to process (required).
        mask:  ``"Mask only"`` to get only the mask; defaults to ``"Default"``.
        model: rembg session name; defaults to ``"isnet-general-use"``.
        x, y:  optional point prompt (used by the SAM model). Both must be
               present and numeric, otherwise neither is used.

    Returns:
        The processed image as ``image/png`` on success; otherwise a JSON
        ``{"error": ...}`` body with status 400 (bad request) or 500
        (processing failure).
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    mask = request.form.get('mask', 'Default')
    model = request.form.get('model', 'isnet-general-use')

    # Reject unknown model names up front so they surface as a client error
    # rather than a 500 raised from inside rembg.
    known_models = {session_cls.name() for session_cls in sessions_class}
    if model not in known_models:
        return jsonify({'error': f'Unknown model: {model}'}), 400

    # A point prompt needs both coordinates: a missing or malformed value
    # drops the whole pair instead of forwarding a half-filled point.
    try:
        x = float(request.form.get('x'))
        y = float(request.form.get('y'))
    except (TypeError, ValueError):
        x = y = None

    # Save the upload to a uniquely named temporary file.
    temp_input = f"temp_{uuid.uuid4().hex}.png"
    file.save(temp_input)
    output_path = None
    try:
        output_path = process_image(temp_input, mask, model, x, y)
        with open(output_path, 'rb') as result_file:
            png_bytes = result_file.read()
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # and report the message to the client.
        app.logger.exception('Image processing failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove both the uploaded copy and the generated output; the
        # response is served from memory, so no file outlives the request.
        for path in (temp_input, output_path):
            if path and os.path.exists(path):
                os.remove(path)
    return send_file(io.BytesIO(png_bytes), mimetype='image/png')
# Single-page UI served at "/" via render_template_string, i.e. it passes
# through Jinja: the markup must not contain Jinja delimiters (double curly
# braces, brace-percent, brace-hash). This is a regular (non-raw) Python
# string, so any backslash escape meant for JavaScript must be doubled.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>RemBG API</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        .container {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }
        .row {
            display: flex;
            gap: 20px;
        }
        .column {
            flex: 1;
        }
        img {
            max-width: 100%;
            height: auto;
            border: 1px solid #ddd;
        }
        .form-group {
            margin-bottom: 15px;
        }
        label {
            display: block;
            margin-bottom: 5px;
            font-weight: bold;
        }
        select, input, button {
            width: 100%;
            padding: 8px;
            box-sizing: border-box;
        }
        button {
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
            padding: 10px;
        }
        button:hover {
            background-color: #45a049;
        }
        button:disabled {
            background-color: #9e9e9e;
            cursor: wait;
        }
        #fetch-code {
            width: 100%;
            height: 150px;
            font-family: monospace;
            padding: 10px;
            box-sizing: border-box;
            background-color: #f5f5f5;
            border: 1px solid #ddd;
        }
        .coords-input {
            display: none;
        }
    </style>
</head>
<body>
    <h1>RemBG API</h1>
    <p>Upload an image to process with RemBG. Select options and click "Process Image".</p>
    <div class="container">
        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="file">Input Image:</label>
                    <input type="file" id="file" accept="image/*">
                </div>
                <img id="input-image" src="" alt="Input image will appear here">
            </div>
            <div class="column">
                <div class="form-group">
                    <label>Output Image:</label>
                    <img id="output-image" src="" alt="Output image will appear here">
                </div>
            </div>
        </div>
        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="mask">Output Type:</label>
                    <select id="mask">
                        <option value="Default">Default</option>
                        <option value="Mask only">Mask only</option>
                    </select>
                </div>
            </div>
            <div class="column">
                <div class="form-group">
                    <label for="model">Model Selection:</label>
                    <select id="model">
                        <option value="u2net">u2net</option>
                        <option value="u2netp">u2netp</option>
                        <option value="u2net_human_seg">u2net_human_seg</option>
                        <option value="u2net_cloth_seg">u2net_cloth_seg</option>
                        <option value="silueta">silueta</option>
                        <option value="isnet-general-use" selected>isnet-general-use</option>
                        <option value="isnet-anime">isnet-anime</option>
                        <option value="sam">sam</option>
                        <option value="birefnet-general">birefnet-general</option>
                        <option value="birefnet-general-lite">birefnet-general-lite</option>
                        <option value="birefnet-portrait">birefnet-portrait</option>
                        <option value="birefnet-dis">birefnet-dis</option>
                        <option value="birefnet-hrsod">birefnet-hrsod</option>
                        <option value="birefnet-cod">birefnet-cod</option>
                        <option value="birefnet-massive">birefnet-massive</option>
                    </select>
                </div>
            </div>
        </div>
        <div id="coords-section" style="display: none;">
            <h3>SAM Model Coordinates</h3>
            <p>Click on the image to set coordinates (for SAM model only)</p>
            <div class="row">
                <div class="column">
                    <div class="form-group">
                        <label for="x">X Coordinate:</label>
                        <input type="number" id="x" class="coords-input">
                    </div>
                </div>
                <div class="column">
                    <div class="form-group">
                        <label for="y">Y Coordinate:</label>
                        <input type="number" id="y" class="coords-input">
                    </div>
                </div>
            </div>
        </div>
        <button id="process-btn">Process Image</button>
        <div class="form-group">
            <label for="fetch-code">Fetch Code:</label>
            <textarea id="fetch-code" readonly></textarea>
        </div>
    </div>
    <script>
        const fileInput = document.getElementById('file');
        const inputImage = document.getElementById('input-image');
        const outputImage = document.getElementById('output-image');
        const maskSelect = document.getElementById('mask');
        const modelSelect = document.getElementById('model');
        const xInput = document.getElementById('x');
        const yInput = document.getElementById('y');
        const coordsSection = document.getElementById('coords-section');
        const processBtn = document.getElementById('process-btn');
        const fetchCodeTextarea = document.getElementById('fetch-code');

        // Show a preview of the selected image.
        fileInput.addEventListener('change', function(e) {
            const file = e.target.files[0];
            if (file) {
                const reader = new FileReader();
                reader.onload = function(event) {
                    inputImage.src = event.target.result;
                    updateFetchCode();
                };
                reader.readAsDataURL(file);
            }
        });

        // Only the SAM model takes a point prompt, so only show the
        // coordinate inputs while it is selected.
        modelSelect.addEventListener('change', function() {
            const isSam = modelSelect.value === 'sam';
            coordsSection.style.display = isSam ? 'block' : 'none';
            document.querySelectorAll('.coords-input').forEach(el => {
                el.style.display = isSam ? 'block' : 'none';
            });
            updateFetchCode();
        });

        // Clicking the preview sets the SAM point (SAM model only). The
        // preview is scaled by CSS (max-width: 100%) while the server works
        // in the original image's pixels, so convert the click position.
        inputImage.addEventListener('click', function(e) {
            if (modelSelect.value !== 'sam' || !inputImage.naturalWidth) {
                return;
            }
            const rect = inputImage.getBoundingClientRect();
            // clientLeft/clientTop skip the border; clientWidth/clientHeight
            // give the displayed size of the image content.
            const px = e.clientX - rect.left - inputImage.clientLeft;
            const py = e.clientY - rect.top - inputImage.clientTop;
            const scaleX = inputImage.naturalWidth / inputImage.clientWidth;
            const scaleY = inputImage.naturalHeight / inputImage.clientHeight;
            const clamp = (v, max) => Math.min(Math.max(v, 0), max);
            xInput.value = clamp(Math.round(px * scaleX), inputImage.naturalWidth - 1);
            yInput.value = clamp(Math.round(py * scaleY), inputImage.naturalHeight - 1);
            updateFetchCode();
        });

        // Keep the generated snippet in sync with the other inputs.
        [maskSelect, xInput, yInput].forEach(el => {
            el.addEventListener('change', updateFetchCode);
        });

        // Send the image to the API and display the result.
        processBtn.addEventListener('click', async function() {
            if (!fileInput.files || fileInput.files.length === 0) {
                alert('Please select an image file');
                return;
            }
            const formData = new FormData();
            formData.append('file', fileInput.files[0]);
            formData.append('mask', maskSelect.value);
            formData.append('model', modelSelect.value);
            if (modelSelect.value === 'sam' && xInput.value && yInput.value) {
                formData.append('x', xInput.value);
                formData.append('y', yInput.value);
            }
            // Prevent duplicate submissions while a request is in flight.
            processBtn.disabled = true;
            try {
                const response = await fetch('/api/process', {
                    method: 'POST',
                    body: formData
                });
                if (!response.ok) {
                    let message = 'Failed to process image';
                    try {
                        const error = await response.json();
                        message = error.error || message;
                    } catch (parseError) {
                        // Error body was not JSON (e.g. a proxy error page).
                    }
                    throw new Error(message);
                }
                const blob = await response.blob();
                // Release the previous result's object URL before replacing it.
                if (outputImage.src.startsWith('blob:')) {
                    URL.revokeObjectURL(outputImage.src);
                }
                outputImage.src = URL.createObjectURL(blob);
            } catch (error) {
                alert('Error: ' + error.message);
                console.error(error);
            } finally {
                processBtn.disabled = false;
            }
        });

        // Build a copy-pasteable fetch() example for the current settings.
        function updateFetchCode() {
            const file = fileInput.files && fileInput.files[0];
            if (!file) {
                fetchCodeTextarea.value = '// Select an image first';
                return;
            }
            const mask = maskSelect.value;
            const model = modelSelect.value;
            const x = xInput.value;
            const y = yInput.value;
            const lines = [
                'const formData = new FormData();',
                "formData.append('file', fileInput.files[0]);",
                `formData.append('mask', '${mask}');`,
                `formData.append('model', '${model}');`
            ];
            if (model === 'sam' && x && y) {
                lines.push(`formData.append('x', '${x}');`);
                lines.push(`formData.append('y', '${y}');`);
            }
            lines.push(
                '',
                // Use the page's own origin so the snippet also works over https.
                `fetch('${window.location.origin}/api/process', {`,
                "  method: 'POST',",
                '  body: formData',
                '})',
                '.then(response => {',
                '  if (!response.ok) {',
                '    return response.json().then(err => { throw new Error(err.error); });',
                '  }',
                '  return response.blob();',
                '})',
                '.then(blob => {',
                '  // Handle the processed image blob',
                '  const imgUrl = URL.createObjectURL(blob);',
                "  document.getElementById('output-image').src = imgUrl;",
                '})',
                '.catch(error => {',
                "  console.error('Error:', error);",
                "  alert('Error: ' + error.message);",
                '});'
            );
            fetchCodeTextarea.value = lines.join('\\n');
        }

        // Initial state.
        updateFetchCode();
    </script>
</body>
</html>
"""
@app.route('/')
def index():
    """Serve the single-page browser UI for trying the processing API."""
    page = render_template_string(HTML_TEMPLATE)
    return page
if __name__ == '__main__':
    # Start Flask's built-in server listening on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')