Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@
# Get the directory the script is running from
APP_ROOT = os.getenv("CHAINLIT_APP_ROOT", os.getcwd())

# Create the directory to store the uploaded files
FILES_DIRECTORY = Path(APP_ROOT) / ".files"
FILES_DIRECTORY.mkdir(exist_ok=True)

config_dir = os.path.join(APP_ROOT, ".chainlit")
public_dir = os.path.join(APP_ROOT, "public")
config_file = os.path.join(config_dir, "config.toml")
Expand Down Expand Up @@ -137,6 +133,8 @@
accept = ["*/*"]
max_files = 20
max_size_mb = 500
# Directory to store uploaded files. Relative paths are resolved from the app root.
# files_dir = ".files"

[features.audio]
# Enable audio features
Expand Down Expand Up @@ -290,6 +288,7 @@ class SpontaneousFileUploadFeature(BaseModel):
accept: Optional[Union[List[str], Dict[str, List[str]]]] = None
max_files: Optional[int] = None
max_size_mb: Optional[int] = None
files_dir: str = ".files"


class AudioFeature(BaseModel):
Expand Down Expand Up @@ -691,3 +690,12 @@ def lint_translations():


config = load_config()


def get_files_directory() -> Path:
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
"""Get the files directory from config. Relative paths resolve from APP_ROOT."""
feature = config.features.spontaneous_file_upload
files_dir = Path(feature.files_dir if feature else ".files")
if not files_dir.is_absolute():
files_dir = Path(APP_ROOT) / files_dir
return files_dir
10 changes: 7 additions & 3 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@
APP_ROOT,
BACKEND_ROOT,
DEFAULT_HOST,
FILES_DIRECTORY,
PACKAGE_ROOT,
ChainlitConfig,
config,
get_files_directory,
load_module,
public_dir,
reload_config,
Expand Down Expand Up @@ -188,8 +188,12 @@ async def watch_files_for_changes():
except asyncio.exceptions.CancelledError:
pass

if FILES_DIRECTORY.is_dir():
shutil.rmtree(FILES_DIRECTORY)
files_dir = get_files_directory()
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
if files_dir.is_dir():
app_root = Path(APP_ROOT).resolve()
resolved = files_dir.resolve()
if resolved == app_root / ".files" or app_root in resolved.parents:
shutil.rmtree(files_dir)

# Force exit the process to avoid potential AnyIO threads still running
os._exit(0)
Expand Down
6 changes: 3 additions & 3 deletions backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def __init__(

@property
def files_dir(self):
from chainlit.config import FILES_DIRECTORY
from chainlit.config import get_files_directory

return FILES_DIRECTORY / self.id
return get_files_directory() / self.id

async def persist_file(
self,
Expand All @@ -160,7 +160,7 @@ async def persist_file(
"Either path or content must be provided to persist a file"
)

self.files_dir.mkdir(exist_ok=True)
self.files_dir.mkdir(parents=True, exist_ok=True)

file_id = str(uuid.uuid4())

Expand Down
28 changes: 21 additions & 7 deletions backend/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def test_base_session_with_chat_profile(self):

def test_base_session_files_dir(self):
"""Test BaseSession files_dir property."""
with patch("chainlit.config.FILES_DIRECTORY", Path("/tmp/files")):
with patch(
"chainlit.config.get_files_directory", return_value=Path("/tmp/files")
):
session = BaseSession(
id="test_id",
client_type="webapp",
Expand All @@ -204,7 +206,9 @@ def test_base_session_files_dir(self):
async def test_base_session_persist_file_with_content(self):
"""Test persisting a file with content."""
with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = BaseSession(
id="test_id",
client_type="webapp",
Expand All @@ -230,7 +234,9 @@ async def test_base_session_persist_file_with_content(self):
async def test_base_session_persist_file_with_string_content(self):
"""Test persisting a file with string content."""
with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = BaseSession(
id="test_id",
client_type="webapp",
Expand Down Expand Up @@ -347,7 +353,9 @@ def test_http_session_initialization(self):
async def test_http_session_delete(self):
"""Test HTTPSession delete method."""
with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = HTTPSession(
id="http_id",
client_type="copilot",
Expand Down Expand Up @@ -445,7 +453,9 @@ async def test_websocket_session_delete(self):
from chainlit.session import ws_sessions_id, ws_sessions_sid

with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = WebsocketSession(
id="ws_id",
socket_id="socket_123",
Expand Down Expand Up @@ -567,7 +577,9 @@ def test_base_session_with_all_client_types(self):
async def test_persist_file_with_mime_extension(self):
"""Test that persist_file adds correct file extension based on MIME type."""
with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = BaseSession(
id="test_id",
client_type="webapp",
Expand Down Expand Up @@ -612,7 +624,9 @@ async def test_websocket_session_delete_with_mcp_sessions(self):
"""Test WebsocketSession delete with MCP sessions."""

with tempfile.TemporaryDirectory() as tmpdir:
with patch("chainlit.config.FILES_DIRECTORY", Path(tmpdir)):
with patch(
"chainlit.config.get_files_directory", return_value=Path(tmpdir)
):
session = WebsocketSession(
id="ws_id",
socket_id="socket_123",
Expand Down
Loading